├── lib
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   └── config.py
│   ├── interactive
│   │   ├── __init__.py
│   │   ├── fonts
│   │   │   ├── ZillaSlab-Regular.ttf
│   │   │   └── Caskaydia Cove Nerd Font Complete.ttf
│   │   └── render_options.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── make_network.py
│   │   └── enerf
│   │       ├── feature_net.py
│   │       ├── cost_reg_net.py
│   │       ├── res_unet.py
│   │       ├── cost_reg_net_.py
│   │       ├── nerf.py
│   │       ├── nerf_.py
│   │       ├── network.py
│   │       └── network_human.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── collate_batch.py
│   │   ├── enerf_utils.py
│   │   ├── make_dataset.py
│   │   ├── nerf
│   │   │   └── enerf.py
│   │   └── samplers.py
│   ├── evaluators
│   │   ├── __init__.py
│   │   ├── make_evaluator.py
│   │   ├── nerf.py
│   │   ├── enerf_composite.py
│   │   └── enerf_human.py
│   ├── train
│   │   ├── trainers
│   │   │   ├── __init__.py
│   │   │   ├── make_trainer.py
│   │   │   └── trainer.py
│   │   ├── __init__.py
│   │   ├── optimizer.py
│   │   ├── scheduler.py
│   │   ├── losses
│   │   │   ├── nerf.py
│   │   │   ├── vgg_perceptual_loss.py
│   │   │   └── enerf.py
│   │   └── recorder.py
│   ├── visualizers
│   │   ├── __init__.py
│   │   ├── make_visualizer.py
│   │   ├── enerf_interactive.py
│   │   ├── enerf.py
│   │   └── nerf.py
│   └── utils
│       ├── data_config.py
│       ├── mesh_utils.py
│       ├── colmap
│       │   ├── test_read_write_dense.py
│       │   ├── clang_format_code.py
│       │   ├── test_read_write_fused_vis.py
│       │   ├── merge_ply_files.py
│       │   ├── export_inlier_pairs.py
│       │   ├── export_inlier_matches.py
│       │   ├── build_windows_app.py
│       │   ├── test_read_write_model.py
│       │   ├── bundler_to_ply.py
│       │   ├── nvm_to_ply.py
│       │   ├── read_write_fused_vis.py
│       │   ├── read_write_dense.py
│       │   └── crawl_camera_specs.py
│       ├── vis_utils.py
│       ├── optimizer
│       │   └── lr_scheduler.py
│       ├── enerf
│       │   └── val_data_utils.py
│       ├── img_utils.py
│       ├── base_utils.py
│       └── rend_utils.py
├── configs
│   ├── default.yaml
│   ├── enerf
│   │   ├── dtu
│   │   │   └── scan114.yaml
│   │   ├── enerf_outdoor
│   │   │   ├── actor1_path.yaml
│   │   │   └── actor1.yaml
│   │   ├── nerf
│   │   │   └── lego.yaml
│   │   ├── nerf_eval.yaml
│   │   ├── llff
│   │   │   ├── flower.yaml
│   │   │   └── fortress.yaml
│   │   ├── llff_eval.yaml
│   │   ├── zjumocap_eval.yaml
│   │   ├── zjumocap
│   │   │   └── zjumocap_train.yaml
│   │   ├── interactive
│   │   │   └── zjumocap.yaml
│   │   ├── dtu_pretrain_nocascade.yaml
│   │   └── dtu_pretrain.yaml
│   └── nerf
│       ├── nerf_pl.yaml
│       ├── colmapvideo.yaml
│       ├── nerf_ngp.yaml
│       ├── nerf.yaml
│       └── nerf_313.yaml
├── data
│   ├── .gitignore
│   └── mvsnerf
│       ├── pairs.th
│       ├── dtu_val_all.txt
│       └── dtu_train_all.txt
├── assets
│   ├── enerf_outdoor.jpg
│   ├── enerf_outdoor_calib.png
│   └── ENeRF-Outdoor_Agreement.pdf
├── .gitignore
├── imgui.ini
├── requirements.txt
├── docs
│   └── enerf_outdoor.md
├── LICENSE
├── run.py
└── train_net.py

/lib/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
--------------------------------------------------------------------------------
/lib/config/__init__.py:
--------------------------------------------------------------------------------
from .config import cfg, args
--------------------------------------------------------------------------------
/lib/interactive/__init__.py:
--------------------------------------------------------------------------------
from .render_options import opt
--------------------------------------------------------------------------------
/lib/networks/__init__.py:
--------------------------------------------------------------------------------
from .make_network import make_network
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
from .make_dataset import make_data_loader
--------------------------------------------------------------------------------
/lib/evaluators/__init__.py:
--------------------------------------------------------------------------------
from .make_evaluator import make_evaluator
--------------------------------------------------------------------------------
/lib/train/trainers/__init__.py:
--------------------------------------------------------------------------------
from .make_trainer import make_trainer
--------------------------------------------------------------------------------
/lib/visualizers/__init__.py:
--------------------------------------------------------------------------------
from .make_visualizer import make_visualizer
--------------------------------------------------------------------------------
/data/mvsnerf/pairs.th:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zju3dv/ENeRF/HEAD/data/mvsnerf/pairs.th
--------------------------------------------------------------------------------
/assets/enerf_outdoor.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zju3dv/ENeRF/HEAD/assets/enerf_outdoor.jpg
--------------------------------------------------------------------------------
/assets/enerf_outdoor_calib.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zju3dv/ENeRF/HEAD/assets/enerf_outdoor_calib.png
--------------------------------------------------------------------------------
/assets/ENeRF-Outdoor_Agreement.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zju3dv/ENeRF/HEAD/assets/ENeRF-Outdoor_Agreement.pdf
--------------------------------------------------------------------------------
/lib/interactive/fonts/ZillaSlab-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zju3dv/ENeRF/HEAD/lib/interactive/fonts/ZillaSlab-Regular.ttf
--------------------------------------------------------------------------------
/lib/interactive/fonts/Caskaydia Cove Nerd Font Complete.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zju3dv/ENeRF/HEAD/lib/interactive/fonts/Caskaydia Cove Nerd Font Complete.ttf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
.idea/
.ipynb_checkpoints/
*.py[cod]
*.so
*.orig
*.o
*.json
*.pth
*.npy
*.ipynb
*.png
--------------------------------------------------------------------------------
/imgui.ini:
--------------------------------------------------------------------------------
[Window][Debug##Default]
Pos=60,60
Size=400,400
Collapsed=0

[Window][Render Backend: cuda:0]
Pos=537,55
Size=310,365
Collapsed=0

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
pyyaml
opencv-python
imgaug
plyfile
tqdm
kornia
ipdb
lpips
tensorboardX
glfw
pyglm
pyopengl
imgui
termcolor
trimesh
--------------------------------------------------------------------------------
/lib/train/__init__.py:
--------------------------------------------------------------------------------
from .trainers import make_trainer
from .optimizer import make_optimizer
from .scheduler import make_lr_scheduler, set_lr_scheduler
from .recorder import make_recorder
--------------------------------------------------------------------------------
/lib/utils/data_config.py:
--------------------------------------------------------------------------------
import numpy as np

# ImageNet RGB statistics, used to normalize network input images.
mean_rgb = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3).astype(np.float32)
std_rgb = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3).astype(np.float32)
--------------------------------------------------------------------------------
/data/mvsnerf/dtu_val_all.txt:
--------------------------------------------------------------------------------
scan1
scan8
scan21
scan30
scan31
scan34
scan38
scan40
scan41
scan45
scan55
scan63
scan82
scan103
scan110
scan114
--------------------------------------------------------------------------------
/lib/networks/make_network.py:
--------------------------------------------------------------------------------
import os
import imp


def make_network(cfg):
    module = cfg.network_module
    path = cfg.network_path
    network = imp.load_source(module, path).Network()
    return network
--------------------------------------------------------------------------------
/lib/visualizers/make_visualizer.py:
--------------------------------------------------------------------------------
import os
import imp


def make_visualizer(cfg):
    module = cfg.visualizer_module
    path = cfg.visualizer_path
    visualizer = imp.load_source(module, path).Visualizer()
    return visualizer
--------------------------------------------------------------------------------
/configs/enerf/dtu/scan114.yaml:
--------------------------------------------------------------------------------
parent_cfg: configs/enerf/dtu_pretrain.yaml
exp_name: dtu_ft_scan114
enerf:
    test_input_views: 4
train_dataset:
    scene: scan114
test_dataset:
    scene: scan114
train:
    epoch: 150 # pretrained epoch + 11
save_ep: 1
eval_ep: 1
--------------------------------------------------------------------------------
/lib/evaluators/make_evaluator.py:
--------------------------------------------------------------------------------
import imp
import os


def _evaluator_factory(cfg):
    module = cfg.evaluator_module
    path = cfg.evaluator_path
    evaluator = imp.load_source(module, path).Evaluator()
    return evaluator


def make_evaluator(cfg):
    if cfg.skip_eval:
        return None
    else:
        return _evaluator_factory(cfg)
--------------------------------------------------------------------------------
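
Note on the factories above (make_network, make_visualizer, make_evaluator, and make_trainer below): they all load a module from an explicit file path with `imp.load_source`, which has been deprecated since Python 3.4 and was removed in Python 3.12. A minimal sketch of an equivalent loader built on the standard `importlib` machinery — `load_source` here is an illustrative helper name, not part of the repo:

import importlib.util
import sys


def load_source(module_name, path):
    # Equivalent of the removed imp.load_source: import a module from an
    # explicit .py file path and register it under the given name.
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # imp.load_source also registered the module
    spec.loader.exec_module(module)
    return module


def make_network(cfg):
    # Same contract as the factory above: cfg.network_path names a .py file
    # that defines a Network class.
    return load_source(cfg.network_module, cfg.network_path).Network()
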
/configs/enerf/enerf_outdoor/actor1_path.yaml:
--------------------------------------------------------------------------------
parent_cfg: 'configs/enerf/enerf_outdoor/actor1.yaml'
test_dataset_module: lib.datasets.enerf_outdoor.enerf_path
visualizer_module: lib.visualizers.enerf
enerf:
    test_input_views: 3
    cas_config:
        render_if: [False, True]
test_dataset:
    frames: [0, 1000, 1]
    num_circle_view: 100
    input_views: [0, -1, 1]
--------------------------------------------------------------------------------
/configs/nerf/nerf_pl.yaml:
--------------------------------------------------------------------------------
parent_cfg: configs/nerf/nerf.yaml
task: nerf
exp_name: 'nerf_pl'

train_dataset_module: lib.datasets.nerf.synthetic.Dataset
test_dataset_module: lib.datasets.nerf.synthetic
network_module: lib.networks.nerf.network.Network
loss_module: lib.train.losses.nerf.NetworkWrapper
evaluator_module: lib.evaluators.nerf
visualizer_module: lib.visualizers.nerf
--------------------------------------------------------------------------------
/lib/datasets/collate_batch.py:
--------------------------------------------------------------------------------
from torch.utils.data.dataloader import default_collate
import torch
import numpy as np
from lib.config import cfg

_collators = {}


def make_collator(cfg, is_train):
    collator = cfg.train.collator if is_train else cfg.test.collator
    if collator in _collators:
        return _collators[collator]
    else:
        return default_collate
--------------------------------------------------------------------------------
/configs/enerf/nerf/lego.yaml:
--------------------------------------------------------------------------------
parent_cfg: configs/enerf/nerf_eval.yaml
exp_name: nerf_ft_lego
enerf:
    test_input_views: 4
    train_input_views: [3, 4]
    train_input_views_prob: [0.4, 0.6]
    cas_config:
        render_if: [True, True]
train_dataset:
    scene: lego
test_dataset:
    scene: lego
train:
    epoch: 147 # pretrained epoch + 8
save_ep: 1
eval_ep: 1
--------------------------------------------------------------------------------
/lib/train/trainers/make_trainer.py:
--------------------------------------------------------------------------------
from .trainer import Trainer
import imp


def _wrapper_factory(cfg, network, train_loader=None):
    module = cfg.loss_module
    path = cfg.loss_path
    network_wrapper = imp.load_source(module, path).NetworkWrapper(network, train_loader)
    return network_wrapper


def make_trainer(cfg, network, train_loader=None):
    network = _wrapper_factory(cfg, network, train_loader)
    return Trainer(network)
--------------------------------------------------------------------------------
/configs/enerf/nerf_eval.yaml:
--------------------------------------------------------------------------------
parent_cfg: configs/enerf/dtu_pretrain.yaml

train_dataset_module: lib.datasets.nerf.enerf
test_dataset_module: lib.datasets.nerf.enerf

enerf:
    cas_config:
        render_if: [False, True]

train_dataset:
    data_root: 'nerf_synthetic'
    split: 'train'
    batch_size: 1
    input_ratio: 1.

test_dataset:
    data_root: 'nerf_synthetic'
    split: 'test'
    batch_size: 1
    input_ratio: 1.
--------------------------------------------------------------------------------
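
Most configs in this dump chain to a parent via `parent_cfg` (e.g. `nerf_eval.yaml` above inherits everything from `dtu_pretrain.yaml` and only overrides a few keys). The actual merging lives in `lib/config/config.py`, which is not included here; a rough sketch of how such recursive resolution typically works, assuming plain `pyyaml` and dict semantics — `deep_merge` and `load_cfg` are illustrative names, not the repo's API:

import yaml


def deep_merge(parent, child):
    # Child values override parent values; nested dicts are merged key-wise.
    out = dict(parent)
    for k, v in child.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = deep_merge(out[k], v)
        else:
            out[k] = v
    return out


def load_cfg(path):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    parent = cfg.pop('parent_cfg', None)
    return deep_merge(load_cfg(parent), cfg) if parent else cfg
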
/configs/enerf/llff/flower.yaml:
--------------------------------------------------------------------------------
parent_cfg: configs/enerf/llff_eval.yaml
exp_name: llff_ft_flower

enerf:
    test_input_views: 4
    train_input_views: [3, 4]
    train_input_views_prob: [0.4, 0.6]
    cas_config:
        render_if: [True, True]

train_dataset:
    scene: flower
test_dataset:
    scene: flower
train:
    epoch: 147 # pretrained epoch + 8
    sampler_meta:
        input_views_num: [3, 4]
        input_views_prob: [0.4, 0.6]
save_ep: 1
eval_ep: 1
--------------------------------------------------------------------------------
/configs/enerf/llff/fortress.yaml:
--------------------------------------------------------------------------------
parent_cfg: configs/enerf/llff_eval.yaml
exp_name: llff_ft_fortress

enerf:
    test_input_views: 4
    train_input_views: [3, 4]
    train_input_views_prob: [0.4, 0.6]
    cas_config:
        render_if: [True, True]

train_dataset:
    scene: fortress
test_dataset:
    scene: fortress
train:
    epoch: 147 # pretrained epoch + 8
    sampler_meta:
        input_views_num: [3, 4]
        input_views_prob: [0.4, 0.6]
save_ep: 1
eval_ep: 1
--------------------------------------------------------------------------------
/configs/nerf/colmapvideo.yaml:
--------------------------------------------------------------------------------
parent_cfg: configs/nerf/nerf.yaml
scene: 'IMG_0622'

train_dataset_module: lib.datasets.nerf.colmapvideo
test_dataset_module: lib.datasets.nerf.colmapvideo

task_arg:
    cascade_samples: [64, 64]

network:
    nerf:
        W: 256
        D: 8
        V_D: 1

train_dataset:
    data_root: 'data/iphonevideo'
    split: 'train'
    input_ratio: 0.5
    cams: [0, -1, 2]

test_dataset:
    data_root: 'data/iphonevideo'
    split: 'test'
    input_ratio: 0.25
    cams: [1, -1, 4]
--------------------------------------------------------------------------------
/configs/enerf/llff_eval.yaml:
--------------------------------------------------------------------------------
parent_cfg: configs/enerf/dtu_pretrain.yaml

train_dataset_module: lib.datasets.llff.enerf
test_dataset_module: lib.datasets.llff.enerf

enerf:
    eval_center: True
    cas_config:
        render_if: [False, True]
        volume_planes: [32, 8]

train_dataset:
    data_root: 'nerf_llff_data'
    split: 'train'
    input_h_w: [640, 960]
    batch_size: 1
    input_ratio: 1.

test_dataset:
    data_root: 'nerf_llff_data'
    split: 'test'
    batch_size: 1
    input_h_w: [640, 960]
    input_ratio: 1.
--------------------------------------------------------------------------------
/lib/train/optimizer.py:
--------------------------------------------------------------------------------
import torch
from lib.utils.optimizer.radam import RAdam


_optimizer_factory = {
    'adam': torch.optim.Adam,
    'radam': RAdam,
    'sgd': torch.optim.SGD
}


def make_optimizer(cfg, net):
    params = []
    lr = cfg.train.lr
    weight_decay = cfg.train.weight_decay
    eps = cfg.train.eps

    for key, value in net.named_parameters():
        if not value.requires_grad:
            continue
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay, "eps": eps}]

    if 'adam' in cfg.train.optim:
        optimizer = _optimizer_factory[cfg.train.optim](params, lr, weight_decay=weight_decay, eps=eps)
    else:
        optimizer = _optimizer_factory[cfg.train.optim](params, lr, momentum=0.9)

    return optimizer
--------------------------------------------------------------------------------
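
A note on `make_optimizer` above: it puts every trainable tensor in its own parameter group, which makes it easy to override learning rates per parameter later. A minimal usage sketch, run inside the repo (the stub `cfg` fields mirror what the function reads; the `eps` value and the toy `nn.Linear` are illustrative, not from the configs):

from types import SimpleNamespace
import torch.nn as nn

cfg = SimpleNamespace(train=SimpleNamespace(lr=5e-4, weight_decay=0., eps=1e-8, optim='adam'))
net = nn.Linear(3, 3)                  # toy stand-in for the ENeRF network
optimizer = make_optimizer(cfg, net)
print(len(optimizer.param_groups))     # 2: one group per parameter tensor (weight, bias)
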
/docs/enerf_outdoor.md:
--------------------------------------------------------------------------------
# ENeRF-Outdoor
## Introduction
**ENeRF-Outdoor** is a dynamic multi-view dataset of outdoor scenes, captured by 18 synchronized cameras. Each sequence has roughly 1,000 frames and contains complex motion.
![](../assets/enerf_outdoor.jpg)


Below is a COLMAP visualization produced during calibration, which shows the spatial distribution of the cameras. The front row of cameras in the picture comes from a mobile-phone scan, which provides additional matches during calibration.
![](../assets/enerf_outdoor_calib.png)


## Download

To download the **ENeRF-Outdoor** dataset, please fill in the [agreement](https://github.com/zju3dv/ENeRF/blob/master/assets/ENeRF-Outdoor_Agreement.pdf), then email me (haotongl@zju.edu.cn), cc'ing Xiaowei Zhou (xwzhou@zju.edu.cn) and Sida Peng (pengsida@zju.edu.cn), to request the download link.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
////////////////////////////////////////////////////////////////////////////
// Copyright 2022-2023 the 3D Vision Group at the State Key Lab of CAD&CG,
// Zhejiang University. All Rights Reserved.
//
// For more information see
// If you use this code, please cite the corresponding publications as
// listed on the above website.
//
// Permission to use, copy, modify and distribute this software and its
// documentation for educational, research and non-profit purposes only.
// Any modification based on this work must be open source and prohibited
// for commercial use.
// You must retain, in the source form of any derivative works that you
// distribute, all copyright, patent, trademark, and attribution notices
// from the source form of this work.
//
//
////////////////////////////////////////////////////////////////////////////
--------------------------------------------------------------------------------
/data/mvsnerf/dtu_train_all.txt:
--------------------------------------------------------------------------------
scan3
scan4
scan5
scan6
scan9
scan10
scan11
scan12
scan13
scan14
scan15
scan16
scan17
scan18
scan19
scan20
scan22
scan23
scan24
scan28
scan32
scan33
scan35
scan36
scan37
scan42
scan43
scan44
scan46
scan47
scan48
scan49
scan50
scan52
scan53
scan59
scan60
scan61
scan62
scan64
scan65
scan66
scan67
scan68
scan69
scan70
scan71
scan72
scan74
scan75
scan76
scan77
scan84
scan85
scan86
scan87
scan88
scan89
scan90
scan91
scan92
scan93
scan94
scan95
scan96
scan97
scan98
scan99
scan100
scan101
scan102
scan104
scan105
scan106
scan107
scan108
scan109
scan118
scan119
scan120
scan121
scan122
scan123
scan124
scan125
scan126
scan127
scan128
--------------------------------------------------------------------------------
/lib/train/scheduler.py:
--------------------------------------------------------------------------------
from collections import Counter
from lib.utils.optimizer.lr_scheduler import WarmupMultiStepLR, MultiStepLR, ExponentialLR


def make_lr_scheduler(cfg, optimizer):
    cfg_scheduler = cfg.train.scheduler
    if cfg_scheduler.type == 'multi_step':
        scheduler = MultiStepLR(optimizer,
                                milestones=cfg_scheduler.milestones,
                                gamma=cfg_scheduler.gamma)
    elif cfg_scheduler.type == 'exponential':
        scheduler = ExponentialLR(optimizer,
                                  decay_epochs=cfg_scheduler.decay_epochs,
                                  gamma=cfg_scheduler.gamma)
    return scheduler


def set_lr_scheduler(cfg, scheduler):
    cfg_scheduler = cfg.train.scheduler
    if cfg_scheduler.type == 'multi_step':
        scheduler.milestones = Counter(cfg_scheduler.milestones)
    elif cfg_scheduler.type == 'exponential':
        scheduler.decay_epochs = cfg_scheduler.decay_epochs
        scheduler.gamma = cfg_scheduler.gamma
--------------------------------------------------------------------------------
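
For intuition, the 'exponential' branch above (backed by the ExponentialLR class in lib/utils/optimizer/lr_scheduler.py later in this dump) decays the learning rate continuously as lr(e) = base_lr * gamma ** (e / decay_epochs). With the settings from dtu_pretrain.yaml (lr 5e-4, gamma 0.5, decay_epochs 50), the rate halves every 50 epochs; a quick sanity check:

# lr(e) = base_lr * gamma ** (e / decay_epochs)
for e in (0, 50, 100):
    print(e, 5e-4 * 0.5 ** (e / 50))  # 0.0005, 0.00025, 0.000125
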
/lib/interactive/render_options.py:
--------------------------------------------------------------------------------
# This is the render options object -- just a dot-accessible dict.
# It controls all render options that can be modified through the imgui UI.

from lib.utils.base_utils import DotDict
from lib.config import cfg


opt = DotDict()

# -----------------------------------------------------------------------------
# * Interactive Rendering Related
# -----------------------------------------------------------------------------
opt.fps_cnter_int = 1  # interval (in seconds) between FPS counter updates
opt.render_level = 1  # index of the rendering scale
opt.type = 0  # index into type_mapping below
opt.type_mapping = ['pred', 'depth', 'seg', 'bbox']

if cfg.test_dataset.scene == 'taekwondo' or cfg.test_dataset.scene == 'walking':
    opt.window_hw = [320, 640]
elif 'cook' in cfg.test_dataset.scene or 'flame' in cfg.test_dataset.scene or 'coffee' in cfg.test_dataset.scene:
    opt.window_hw = [448, 640]
else:
    opt.window_hw = [512, 512]
# opt.window_hw = 512, 512
opt.font_filepath = 'lib/interactive/fonts/Caskaydia Cove Nerd Font Complete.ttf'
# opt.lock_fxfy = True
opt.autoplay = True

opt.smoothing_term = 0.1
--------------------------------------------------------------------------------
/configs/enerf/zjumocap_eval.yaml:
--------------------------------------------------------------------------------
parent_cfg: 'configs/enerf/dtu_pretrain.yaml'
exp_name: dtu_pretrain


train_dataset_module: lib.datasets.zjumocap.enerf
test_dataset_module: lib.datasets.zjumocap.enerf
network_module: lib.networks.enerf.network_human
evaluator_module: lib.evaluators.enerf_human

enerf:
    sample_on_mask: True
    train_input_views: [2, 3]
    train_input_views_prob: [0.9, 0.1]
    test_input_views: 2
    cas_config:
        train_img: [False, False]
        patch_size: [-1, 64]
        num_rays: [4096, 16384]
        num_patchs: [0, 4]
        volume_planes: [32, 8]
        render_if: [False, True]

train_dataset:
    data_root: 'zju_mocap'
    scene: 'CoreView_313'
    split: train
    frames: [0, 600, 1]
    input_views: [0, -1, 2]
    render_views: [0, -1, 2]
    input_ratio: 0.5

test_dataset:
    data_root: 'zju_mocap'
    scene: 'CoreView_313'
    split: test
    frames: [0, 600, 100]
    input_views: [0, -1, 2]
    render_views: [1, -1, 2]
    input_ratio: 0.5

train:
    batch_size: 1
    lr: 5e-4
    epoch: 100
    scheduler:
        type: 'exponential'
        gamma: 0.5
        decay_epochs: 10

eval_ep: 1
--------------------------------------------------------------------------------
/configs/nerf/nerf_ngp.yaml:
--------------------------------------------------------------------------------
task: nerf
gpus: [0]
exp_name: 'nerf'
scene: 'lego'

train_dataset_module: lib.datasets.nerf.synthetic
test_dataset_module: lib.datasets.nerf.synthetic
network_module: lib.networks.nerf.network
loss_module: lib.train.losses.nerf
evaluator_module: lib.evaluators.nerf
visualizer_module: lib.visualizers.nerf

task_arg:
    N_rays: 1024
    chunk_size: 4096
    white_bkgd: True
    cascade_samples: [64, 128]

network:
    nerf:
        W: 64
        D: 1
        V_D: 2
        xyz_encoder:
            type: 'cuda_hashgrid'
            input_dim: 3
        dir_encoder:
            type: 'frequency'
            input_dim: 3
            freq: 2

train_dataset:
    data_root: 'data/nerf_synthetic'
    split: 'train'
    input_ratio: 1.
    cams: [0, -1, 1]

test_dataset:
    data_root: 'data/nerf_synthetic'
    split: 'test'
    input_ratio: 1.
    cams: [0, -1, 100]

train:
    batch_size: 1
    lr: 5e-4
    weight_decay: 0.
    epoch: 400
    scheduler:
        type: 'exponential'
        gamma: 0.1
        decay_epochs: 1000
    num_workers: 4

test:
    batch_size: 1

ep_iter: 500
save_ep: 20
eval_ep: 20 # 10000 iterations
save_latest_ep: 5 # 2500 iterations
log_interval: 10
--------------------------------------------------------------------------------
/configs/nerf/nerf.yaml:
--------------------------------------------------------------------------------
task: nerf
gpus: [0]
exp_name: 'nerf'
scene: 'lego'

train_dataset_module: lib.datasets.nerf.synthetic
test_dataset_module: lib.datasets.nerf.synthetic
network_module: lib.networks.nerf.network
loss_module: lib.train.losses.nerf
evaluator_module: lib.evaluators.nerf
visualizer_module: lib.visualizers.nerf

task_arg:
    N_rays: 1024
    chunk_size: 4096
    white_bkgd: True
    cascade_samples: [64, 128]

network:
    nerf:
        W: 256
        D: 8
        V_D: 1
        xyz_encoder:
            type: 'frequency'
            input_dim: 3
            freq: 10
        dir_encoder:
            type: 'frequency'
            input_dim: 3
            freq: 4

train_dataset:
    data_root: 'data/nerf_synthetic'
    split: 'train'
    input_ratio: 1.
    cams: [0, -1, 1]

test_dataset:
    data_root: 'data/nerf_synthetic'
    split: 'test'
    input_ratio: 1.
    cams: [0, -1, 100]

train:
    batch_size: 1
    lr: 5e-4
    weight_decay: 0.
    epoch: 400
    scheduler:
        type: 'exponential'
        gamma: 0.1
        decay_epochs: 1000
    num_workers: 4

test:
    batch_size: 1

ep_iter: 500
save_ep: 20
eval_ep: 20 # 10000 iterations
save_latest_ep: 5 # 2500 iterations
log_interval: 10
--------------------------------------------------------------------------------
/lib/train/losses/nerf.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from lib.utils import net_utils
from lib.config import cfg


class NetworkWrapper(nn.Module):
    def __init__(self, net):
        super(NetworkWrapper, self).__init__()
        self.net = net
        self.color_crit = nn.MSELoss(reduction='mean')
        self.mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.]))

    def forward(self, batch):
        output = self.net(batch)

        scalar_stats = {}
        loss = 0
        color_loss = self.color_crit(output['rgb_0'], batch['rgb'])
        scalar_stats.update({'color_mse_0': color_loss})
        loss += color_loss

        psnr = -10. * torch.log(color_loss.detach()) / \
               torch.log(torch.Tensor([10.]).to(color_loss.device))
        scalar_stats.update({'psnr_0': psnr})

        if len(cfg.task_arg.cascade_samples) > 1:
            color_loss = self.color_crit(output['rgb_1'], batch['rgb'])
            scalar_stats.update({'color_mse_1': color_loss})
            loss += color_loss

            psnr = -10. * torch.log(color_loss.detach()) / \
                   torch.log(torch.Tensor([10.]).to(color_loss.device))
            scalar_stats.update({'psnr_1': psnr})

        scalar_stats.update({'loss': loss})
        image_stats = {}

        return output, loss, scalar_stats, image_stats
--------------------------------------------------------------------------------
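
The PSNR logged by the wrapper above follows directly from the color MSE: psnr = -10 * log10(mse) for images in [0, 1]. As a worked example, mse = 1e-3 gives 30 dB and mse = 1e-4 gives 40 dB, which is a quick way to sanity-check the `psnr_0`/`psnr_1` stats against `color_mse_0`/`color_mse_1` in the training logs.
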
/configs/nerf/nerf_313.yaml:
--------------------------------------------------------------------------------
task: nerf
gpus: [0]
exp_name: 'nerf_313'

N_rays: 1024
chunk_size: 4096
cascade_samples: [64]
white_bkgd: False
cache_items: 100000

network:
    nerf:
        W: 256
        D: 8
        V_D: 1
        xyz_encoder:
            type: 'frequency'
            input_dim: 3
            freq: 10
        dir_encoder:
            type: 'frequency'
            input_dim: 3
            freq: 4

train_dataset_module: lib.datasets.light_stage
test_dataset_module: lib.datasets.light_stage
network_module: lib.networks.nerf.network
loss_module: lib.train.losses.nerf
evaluator_module: lib.evaluators.nerf
visualizer_module: lib.visualizers.nerf

train_dataset:
    data_root: 'data/light_stage/CoreView_313'
    split: 'train'
    frames: [0, 1, 1] # start:end:skip
    cameras: [0, -1, 1]
    input_ratio: 0.5

test_dataset:
    data_root: 'data/light_stage/CoreView_313'
    split: 'test'
    frames: [0, 1, 1]
    cameras: [0, -1, 1]
    input_ratio: 0.5

train:
    batch_size: 1
    lr: 5e-4
    weight_decay: 0.
    epoch: 400
    scheduler:
        type: 'exponential'
        gamma: 0.1
        decay_epochs: 1000
    num_workers: 2

test:
    batch_size: 1

ep_iter: 500
save_ep: 20
eval_ep: 5
save_latest_ep: 5
log_interval: 10
--------------------------------------------------------------------------------
/configs/enerf/zjumocap/zjumocap_train.yaml:
--------------------------------------------------------------------------------
parent_cfg: 'configs/enerf/dtu_pretrain.yaml'
exp_name: zjumocap
pretrain: dtu_pretrain

train_dataset_module: lib.datasets.zjumocap.enerf
test_dataset_module: lib.datasets.zjumocap.enerf
network_module: lib.networks.enerf.network_human
evaluator_module: lib.evaluators.enerf_human

enerf:
    sample_on_mask: True
    train_input_views: [2, 3]
    train_input_views_prob: [0.9, 0.1]
    test_input_views: 2
    cas_config:
        train_img: [False, False]
        patch_size: [-1, 64]
        num_rays: [4096, 16384]
        num_patchs: [0, 4]
        volume_planes: [32, 8]
        render_if: [True, True]

train_dataset:
    data_root: 'zju_mocap'
    scene: 'CoreView_313'
    split: train
    frames: [0, 600, 1]
    input_views: [0, -1, 2]
    render_views: [0, -1, 2]
    input_ratio: 0.5

test_dataset:
    data_root: 'zju_mocap'
    scene: 'CoreView_313'
    split: test
    frames: [0, 600, 100]
    input_views: [0, -1, 2]
    render_views: [1, -1, 2]
    input_ratio: 0.5

train:
    batch_size: 1
    lr: 5e-4
    epoch: 100
    scheduler:
        type: 'exponential'
        gamma: 0.5
        decay_epochs: 10
    sampler_meta:
        input_views_num: [2, 3]
        input_views_prob: [0.9, 0.1]

eval_ep: 1
--------------------------------------------------------------------------------
/configs/enerf/enerf_outdoor/actor1.yaml:
--------------------------------------------------------------------------------
parent_cfg: 'configs/enerf/dtu_pretrain.yaml'
exp_name: 'actor1'

# module
train_dataset_module: lib.datasets.enerf_outdoor.enerf
test_dataset_module: lib.datasets.enerf_outdoor.enerf
network_module: lib.networks.enerf.network_composite
loss_module: lib.train.losses.enerf
evaluator_module: lib.evaluators.enerf_composite
visualizer_module: lib.visualizers.enerf


num_fg_layers: 1
# task config
enerf:
    train_input_views: [2, 3, 4]
    train_input_views_prob: [0.2, 0.6, 0.2]
    test_input_views: 3
    viewdir_agg: False
    cas_config:
        volume_planes: [32, 8]
        num_samples: [2, 1]

train_dataset:
    data_root: 'enerf_outdoor'
    frames: [0, 1000, 1]
    input_ratio: 0.75
    input_h_w: [768, 1024]
    input_views: [0, -1, 1]
    render_views: [0, -1, 1]
    split: 'train'
    scene: 'actor1'

test_dataset:
    data_root: 'enerf_outdoor'
    frames: [0, 1000, 400]
    input_ratio: 0.75
    input_h_w: [768, 1024]
    input_views: [1, -1, 1]
    render_views: [0, 1, 1]
    split: 'test'
    scene: 'actor1'

train:
    lr: 5e-4
    epoch: 50
    scheduler:
        type: 'exponential'
        gamma: 0.1
        decay_epochs: 50
    sampler_meta:
        input_views_num: [2, 3, 4]
        input_views_prob: [0.2, 0.6, 0.2]
--------------------------------------------------------------------------------
/configs/enerf/interactive/zjumocap.yaml:
--------------------------------------------------------------------------------
parent_cfg: 'configs/enerf/dtu_pretrain.yaml'
exp_name: dtu_pretrain


train_dataset_module: lib.datasets.zjumocap.enerf_interactive
test_dataset_module: lib.datasets.zjumocap.enerf_interactive
network_module: lib.networks.enerf.network_human
evaluator_module: lib.evaluators.enerf_human
visualizer_module: lib.visualizers.enerf_interactive

enerf:
    sample_on_mask: True
    train_input_views: [2, 3]
    train_input_views_prob: [0.9, 0.1]
    test_input_views: 2
    cas_config:
        train_img: [False, False]
        patch_size: [-1, 64]
        num_rays: [4096, 16384]
        num_patchs: [0, 4]
        volume_planes: [32, 8]
        render_if: [False, True]

train_dataset:
    data_root: 'zju_mocap'
    scene: 'CoreView_313'
    split: train
    frames: [0, 600, 1]
    input_views: [0, -1, 2]
    render_views: [0, -1, 2]
    input_ratio: 0.5

test_dataset:
    data_root: 'zju_mocap'
    scene: 'CoreView_313'
    split: test
    frames: [0, 100, 10] # Render frames
    input_views: [0, -1, 1] # use 21 views
    render_views: [1, -1, 2] # not important
    input_ratio: 0.5

train:
    batch_size: 1
    lr: 5e-4
    epoch: 100
    scheduler:
        type: 'exponential'
        gamma: 0.5
        decay_epochs: 10

eval_ep: 1
--------------------------------------------------------------------------------
/lib/utils/mesh_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from tqdm import tqdm
import trimesh
from skimage import measure


def extract_mesh(queryfn, level, bbox, output_path='test.ply', N=256):
    bbox = np.array(bbox).reshape((2, 3))

    voxel_grid_origin = np.mean(bbox, axis=0)
    volume_size = bbox[1] - bbox[0]
    s = volume_size[0]

    # np.int was removed in NumPy 1.24; the builtin int is the drop-in replacement
    overall_index = np.arange(0, N ** 3, 1).astype(int)
    xyz = np.zeros([N ** 3, 3])

    # transform the first 3 columns to be the x, y, z index
    # (integer division keeps the grid indices exact)
    xyz[:, 2] = overall_index % N
    xyz[:, 1] = (overall_index // N) % N
    xyz[:, 0] = ((overall_index // N) // N) % N

    # transform the first 3 columns to be the x, y, z coordinate
    xyz[:, 0] = (xyz[:, 0] * (s / (N - 1))) + bbox[0][0]
    xyz[:, 1] = (xyz[:, 1] * (s / (N - 1))) + bbox[0][1]
    xyz[:, 2] = (xyz[:, 2] * (s / (N - 1))) + bbox[0][2]

    xyz = torch.from_numpy(xyz).float()

    batch_size = 8192
    density = []
    for i in tqdm(range(N ** 3 // batch_size)):
        start = i * batch_size
        end = (i + 1) * batch_size
        density.append(queryfn(xyz[start: end].cuda())[..., 0].detach().cpu())

    density = torch.cat(density, dim=-1)
    density = density.view(N, N, N)
    # marching_cubes_lewiner was removed from scikit-image (>= 0.19);
    # measure.marching_cubes runs the same Lewiner algorithm
    vertices, faces, normals, _ = measure.marching_cubes(density.numpy(), level=level, spacing=[float(v) / N for v in volume_size])
    vertices += voxel_grid_origin
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    mesh.export(output_path)
--------------------------------------------------------------------------------
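
A usage sketch for `extract_mesh` above: `queryfn` is expected to map a CUDA tensor of points (M, 3) to a per-point density with at least one channel, and a CUDA device is required because the function calls `.cuda()`. The toy density field and bbox below are illustrative stand-ins, not the repo's network API:

import torch

def queryfn(xyz):                  # xyz: (M, 3) CUDA tensor
    # toy density field: a unit sphere centered at the origin
    sigma = (1.0 - xyz.norm(dim=-1, keepdim=True)).clamp(min=0.0)
    return sigma                   # (M, 1); extract_mesh reads channel 0

bbox = [-1, -1, -1, 1, 1, 1]       # flattened (2, 3) min/max corners
extract_mesh(queryfn, level=0.5, bbox=bbox, output_path='sphere.ply', N=128)
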
/lib/visualizers/enerf_interactive.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
from lib.utils import data_utils
from lib.utils import img_utils
import numpy as np
import torch.nn.functional as F
from lib.config import cfg
import cv2
import torch
import imageio
from skimage.metrics import peak_signal_noise_ratio as psnr


class Visualizer:
    def __init__(self,):
        self.HW = [512, 512]

    def visualize(self, output, batch):
        B = 1
        for b in range(B):
            # H, W = batch['meta']['h'][b].item(), batch['meta']['w'][b].item()
            # H, W = self.HW
            _, _, _, H, W = batch['src_inps'].shape
            i = cfg.enerf.cas_config.num - 1
            # gt_img = batch[f'rgb_{i}'][b].reshape(H, W, 3).detach().cpu().numpy()
            pred_img = output[f'rgb_level{i}'][b].reshape(H, W, 3)
            # imageio.imwrite('test.png', pred_img.detach().cpu().numpy())
            ret = {'pred': pred_img}
            if 'vis_ret' in output and False:  # debug visualizations, currently disabled
                seg = (output['vis_ret']['layer_0_weight'] > output['vis_ret']['layer_3_weight'])[b].float()[..., None].repeat(1, 1, 3)
                depth = output['vis_ret']['depth'][b][..., None].repeat(1, 1, 3)
                bbox = batch['masks'][b, 0].float()[..., None].repeat(1, 1, 3)
                ret.update({'seg': seg, 'depth': depth, 'bbox': bbox})
            # src_inps = (batch['src_inps'][b] * 0.5 + 0.5).detach().cpu()
            # idx = 0
            # for src_inp in src_inps:
            #     cv2.imshow(f'src_{idx}', ((src_inp.permute(1, 2, 0).numpy()[..., [2,1,0]])*255).astype(np.uint8))
            #     cv2.waitKey(1)
            #     idx += 1
            # ret.update({'bbox': batch['masks'][0][0][..., None].repeat(1, 1, 3)})

            return ret
            # psnr_item = psnr(gt_img, pred_img, data_range=1.)
            # print(psnr_item)
            # plt.imshow(np.concatenate([gt_img, pred_img], axis=1))
            # plt.show()
--------------------------------------------------------------------------------
/lib/visualizers/enerf.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
from lib.utils import data_utils
from lib.utils import img_utils
from lib.config import cfg
import numpy as np
import torch.nn.functional as F
import torch
import imageio
import os


class Visualizer:
    def __init__(self,):
        self.write_video = cfg.write_video
        self.imgs = []
        self.depths = []
        self.imgs_coarse = []
        os.system('mkdir -p {}'.format(cfg.result_dir))
        os.system('mkdir -p {}'.format(cfg.result_dir + '/imgs'))

    def visualize(self, output, batch):
        B, S, _, H, W = batch['src_inps'].shape
        i = cfg.enerf.cas_config.num - 1
        render_scale = cfg.enerf.cas_config.render_scale[i]
        h, w = int(H * render_scale), int(W * render_scale)
        assert(B == 1)
        pred_rgb = output[f'rgb_level{i}'].reshape(h, w, 3).detach().cpu().numpy()
        depth = output[f'depth_level{i}'].reshape(h, w).detach().cpu().numpy()
        crop_h, crop_w = int(h * 0.1), int(w * 0.1)
        pred_rgb = pred_rgb[crop_h:, crop_w:-crop_w]
        depth = depth[crop_h:, crop_w:-crop_w]
        self.imgs.append(pred_rgb)
        self.depths.append(depth)
        if cfg.save_result:
            frame_id = batch['meta']['frame_id'][0].item()
            imageio.imwrite(os.path.join(cfg.result_dir, 'imgs/{:06d}_rgb.jpg'.format(frame_id)), pred_rgb)
            imageio.imwrite(os.path.join(cfg.result_dir, 'imgs/{:06d}_dpt.jpg'.format(frame_id)), ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8))

    def summarize(self):
        imageio.mimwrite(os.path.join(cfg.result_dir, 'color.mp4'), self.imgs, fps=cfg.fps)
        d_min, d_max = np.array(self.depths).min(), np.array(self.depths).max()
        self.depths = [(dpt - d_min) / (d_max - d_min) for dpt in self.depths]
        imageio.mimwrite(os.path.join(cfg.result_dir, 'depth.mp4'), self.depths, fps=cfg.fps)
        print('Save visualization results into {}'.format(cfg.result_dir))
--------------------------------------------------------------------------------
/lib/evaluators/nerf.py:
--------------------------------------------------------------------------------
import numpy as np
from lib.config import cfg
import os
import imageio
from lib.utils import img_utils
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch.nn.functional as F
import torch
import lpips
import cv2


class Evaluator:

    def __init__(self,):
        self.psnrs = []
        self.psnrs_0 = []
        os.system('mkdir -p ' + cfg.result_dir)

    def evaluate(self, output, batch):
        B, N_rays = batch['rays'].shape[:2]
        for b in range(B):
            gt_rgb = batch['rgb'][b].reshape(-1, 3).detach().cpu().numpy()
            pred_rgb = output['rgb_1'][b].detach().cpu().numpy()
            self.psnrs.append(psnr(pred_rgb, gt_rgb, data_range=1.))
            pred_rgb = output['rgb_0'][b].detach().cpu().numpy()
            psnr_item = psnr(gt_rgb, pred_rgb, data_range=1.)
            self.psnrs_0.append(psnr_item)
            if cfg.save_result:
                h, w = batch['meta']['h'][b].item(), batch['meta']['w'][b].item()
                gt_rgb = batch['rgb'][b].reshape(h, w, 3).detach().cpu().numpy()
                pred_rgb_coarse = output['rgb_0'][b].reshape(h, w, 3).detach().cpu().numpy()
                pred_rgb_fine = output['rgb_1'][b].reshape(h, w, 3).detach().cpu().numpy()
                save_path = os.path.join(cfg.result_dir, 'view{:06d}'.format(batch['meta']['idx'][b].item()))
                save_path = save_path + '_{}.jpg'
                imageio.imwrite(save_path.format('gt'), gt_rgb)
                imageio.imwrite(save_path.format('coarse'), pred_rgb_coarse)
                imageio.imwrite(save_path.format('fine'), pred_rgb_fine)

    def summarize(self):
        ret = {}
        ret.update({'psnr': np.mean(self.psnrs)})
        if len(self.psnrs_0) != 0:
            ret.update({'psnr_0': np.mean(self.psnrs_0)})
            self.psnrs_0 = []
        print(ret)
        self.psnrs = []
        return ret
--------------------------------------------------------------------------------
/lib/train/losses/vgg_perceptual_loss.py:
--------------------------------------------------------------------------------
import torch
import torchvision


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=False):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
--------------------------------------------------------------------------------
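
A usage sketch for `VGGPerceptualLoss` above — inputs are (B, C, H, W) tensors in [0, 1], and grayscale inputs are automatically tiled to 3 channels; the shapes below are illustrative (note that constructing the module downloads the VGG-16 weights on first use):

import torch

loss_fn = VGGPerceptualLoss(resize=False)
pred = torch.rand(1, 3, 64, 64, requires_grad=True)   # e.g. a rendered patch
target = torch.rand(1, 3, 64, 64)                     # the ground-truth patch
loss = loss_fn(pred, target, feature_layers=[0, 1, 2, 3])
loss.backward()  # gradients flow to pred, not to the frozen VGG weights
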
/lib/networks/enerf/feature_net.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from .utils import *


class FeatureNet(nn.Module):
    def __init__(self, norm_act=nn.BatchNorm2d):
        super(FeatureNet, self).__init__()
        self.conv0 = nn.Sequential(
            ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act),
            ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))
        self.conv1 = nn.Sequential(
            ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act),
            ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))
        self.conv2 = nn.Sequential(
            ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act),
            ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))

        self.toplayer = nn.Conv2d(32, 32, 1)
        self.lat1 = nn.Conv2d(16, 32, 1)
        self.lat0 = nn.Conv2d(8, 32, 1)

        self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
        self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)

    def _upsample_add(self, x, y):
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        feat2 = self.toplayer(conv2)
        feat1 = self._upsample_add(feat2, self.lat1(conv1))
        feat0 = self._upsample_add(feat1, self.lat0(conv0))
        feat1 = self.smooth1(feat1)
        feat0 = self.smooth0(feat0)
        return feat2, feat1, feat0


class CNNRender(nn.Module):
    def __init__(self, norm_act=nn.BatchNorm2d):
        super(CNNRender, self).__init__()
        self.conv0 = ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act)
        self.conv1 = ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act)
        self.conv2 = nn.Conv2d(8, 16, 1)
        self.conv3 = nn.Conv2d(16, 3, 1)

    def _upsample_add(self, x, y):
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self._upsample_add(conv1, self.conv2(conv0))
        conv3 = self.conv3(conv2)
        return torch.clamp(conv3 + x, 0., 1.)
--------------------------------------------------------------------------------
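
For reference, `FeatureNet` above is a small FPN: two stride-2 convolutions give three scales, and the top-down `_upsample_add` pathway mixes them back together. A shape check, assuming the class is imported from lib.networks.enerf.feature_net inside the repo (the input resolution is illustrative; H and W must be divisible by 4):

import torch

net = FeatureNet()
x = torch.rand(2, 3, 128, 160)      # (B, 3, H, W)
feat2, feat1, feat0 = net(x)
print(feat2.shape)                  # torch.Size([2, 32, 32, 40])  -> 1/4 resolution
print(feat1.shape)                  # torch.Size([2, 16, 64, 80])  -> 1/2 resolution
print(feat0.shape)                  # torch.Size([2, 8, 128, 160]) -> full resolution
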
/lib/visualizers/nerf.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
from lib.utils import data_utils
from lib.utils import img_utils
from lib.config import cfg
import numpy as np
import torch.nn.functional as F
import torch
import imageio
import os


class Visualizer:
    def __init__(self,):
        self.write_video = cfg.write_video
        self.imgs = []
        self.depths = []
        self.imgs_coarse = []
        os.system('mkdir -p {}'.format(cfg.result_dir))
        os.system('mkdir -p {}'.format(cfg.result_dir + '/imgs'))

    def visualize(self, output, batch):
        B, N_rays = batch['rays'].shape[:2]
        for b in range(B):
            h, w = batch['meta']['h'][b].item(), batch['meta']['w'][b].item()
            if 'fine_s_rgb_map_1' in output and cfg.render_static:
                img = output['fine_s_rgb_map_1'][b].reshape(h, w, 3).detach().cpu().numpy()
                depth = output['fine_s_depth_map_1'][b].reshape(h, w).detach().cpu().numpy()
            else:
                depth = output['depth_1'][b].reshape(h, w).detach().cpu().numpy()
                img = output['rgb_1'][b].reshape(h, w, 3).detach().cpu().numpy()
            img_coarse = output['rgb_0'][b].reshape(h, w, 3).detach().cpu().numpy()
            self.imgs_coarse.append(img_coarse)
            idx = batch['meta']['seq_id'][b].item()
            imageio.imwrite(os.path.join(cfg.result_dir, 'imgs/{:06d}_rgb.png'.format(idx)), img)
            imageio.imwrite(os.path.join(cfg.result_dir, 'imgs/{:06d}_rgb_coarse.png'.format(idx)), img_coarse)
            imageio.imwrite(os.path.join(cfg.result_dir, 'imgs/{:06d}_dpt.png'.format(idx)), ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8))
            self.imgs.append(img)
            self.depths.append(depth)

    def summarize(self):
        imageio.mimwrite(os.path.join(cfg.result_dir, 'color_coarse.mp4'), self.imgs_coarse, fps=cfg.fps)
        imageio.mimwrite(os.path.join(cfg.result_dir, 'color.mp4'), self.imgs, fps=cfg.fps)
        d_min, d_max = np.array(self.depths).min(), np.array(self.depths).max()
        self.depths = [(dpt - d_min) / (d_max - d_min) for dpt in self.depths]
        imageio.mimwrite(os.path.join(cfg.result_dir, 'depth.mp4'), self.depths, fps=cfg.fps)
--------------------------------------------------------------------------------
/configs/enerf/dtu_pretrain_nocascade.yaml:
--------------------------------------------------------------------------------
task: enerf
gpus: [0, 1, 2, 3]
exp_name: 'dtu_pretrain_nocascade'

# module
train_dataset_module: lib.datasets.dtu.enerf
test_dataset_module: lib.datasets.dtu.enerf
network_module: lib.networks.enerf.network
loss_module: lib.train.losses.enerf
evaluator_module: lib.evaluators.enerf
visualizer_module: lib.visualizers.enerf

save_result: False
eval_lpips: True

# task config
enerf:
    train_input_views: [2, 3, 4]
    train_input_views_prob: [0.1, 0.8, 0.1]
    test_input_views: 3
    viewdir_agg: True
    chunk_size: 1000000
    white_bkgd: False
    eval_depth: False
    eval_center: False # only for llff evaluation (same as MVSNeRF: https://github.com/apchenstu/mvsnerf/blob/1fdf6487389d0872dade614b3cea61f7b099406e/renderer.ipynb)
    sample_on_mask: False # only for ZJU-MoCap/DynamicCap
    cas_config:
        num: 1
        depth_inv: [True]
        volume_scale: [0.25]
        volume_planes: [48]
        im_feat_scale: [0.25]
        im_ibr_scale: [1.]
        render_scale: [1.0]
        render_im_feat_level: [2]
        nerf_model_feat_ch: [8]
        render_if: [True]
        num_samples: [2]
        num_rays: [32768]
        num_patchs: [0]
        train_img: [True]
        patch_size: [-1]
        loss_weight: [1.]

train_dataset:
    data_root: 'dtu'
    ann_file: 'data/mvsnerf/dtu_train_all.txt'
    split: 'train'
    batch_size: 2
    input_ratio: 1.

test_dataset:
    data_root: 'dtu'
    ann_file: 'data/mvsnerf/dtu_val_all.txt'
    split: 'test'
    batch_size: 1
    input_ratio: 1.

train:
    batch_size: 1
    lr: 5e-4
    weight_decay: 0.
    epoch: 300
    scheduler:
        type: 'exponential'
        gamma: 0.5
        decay_epochs: 50
    batch_sampler: 'enerf'
    collator: 'enerf'
    sampler_meta:
        input_views_num: [2, 3, 4]
        input_views_prob: [0.1, 0.8, 0.1]
    num_workers: 4

test:
    batch_size: 1
    collator: 'enerf'
    batch_sampler: 'enerf'
    sampler_meta:
        input_views_num: [3]
        input_views_prob: [1.]

ep_iter: 1000
save_ep: 5
eval_ep: 5
save_latest_ep: 1
log_interval: 1
--------------------------------------------------------------------------------
/configs/enerf/dtu_pretrain.yaml:
--------------------------------------------------------------------------------
task: enerf
gpus: [0, 1, 2, 3]
exp_name: 'dtu_pretrain'

# module
train_dataset_module: lib.datasets.dtu.enerf
test_dataset_module: lib.datasets.dtu.enerf
network_module: lib.networks.enerf.network
loss_module: lib.train.losses.enerf
evaluator_module: lib.evaluators.enerf
visualizer_module: lib.visualizers.enerf

save_result: False
eval_lpips: True

# task config
enerf:
    train_input_views: [2, 3, 4]
    train_input_views_prob: [0.1, 0.8, 0.1]
    test_input_views: 3
    viewdir_agg: True
    chunk_size: 1000000
    white_bkgd: False
    eval_depth: False
    eval_center: False # only for llff evaluation (same as MVSNeRF: https://github.com/apchenstu/mvsnerf/blob/1fdf6487389d0872dade614b3cea61f7b099406e/renderer.ipynb)
    sample_on_mask: False # only for ZJU-MoCap/DynamicCap
    cas_config:
        num: 2
        depth_inv: [True, False]
        volume_scale: [0.125, 0.5]
        volume_planes: [64, 8]
        im_feat_scale: [0.25, 0.5]
        im_ibr_scale: [0.25, 1.]
        render_scale: [0.25, 1.0]
        render_im_feat_level: [0, 2]
        nerf_model_feat_ch: [32, 8]
        render_if: [True, True]
        num_samples: [8, 2]
        num_rays: [4096, 32768]
        num_patchs: [0, 0]
        train_img: [True, True]
        patch_size: [-1, -1]
        loss_weight: [0.1, 1.]

train_dataset:
    data_root: 'dtu'
    ann_file: 'data/mvsnerf/dtu_train_all.txt'
    split: 'train'
    batch_size: 2
    input_ratio: 1.

test_dataset:
    data_root: 'dtu'
    ann_file: 'data/mvsnerf/dtu_val_all.txt'
    split: 'test'
    batch_size: 1
    input_ratio: 1.

train:
    batch_size: 1
    lr: 5e-4
    weight_decay: 0.
    epoch: 300
    scheduler:
        type: 'exponential'
        gamma: 0.5
        decay_epochs: 50
    batch_sampler: 'enerf'
    collator: 'enerf'
    sampler_meta:
        input_views_num: [2, 3, 4]
        input_views_prob: [0.1, 0.8, 0.1]
    num_workers: 4

test:
    batch_size: 1
    collator: 'enerf'
    batch_sampler: 'enerf'
    sampler_meta:
        input_views_num: [3]
        input_views_prob: [1.]

ep_iter: 1000
save_ep: 5
eval_ep: 5
save_latest_ep: 1
log_interval: 1
--------------------------------------------------------------------------------
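
The `cas_config` block in dtu_pretrain.yaml above is a struct-of-lists: every field holds one value per cascade level (`num: 2` means index 0 is the coarse level and index 1 the fine level, which is why `dtu_pretrain_nocascade.yaml` uses single-element lists). The network and renderer code that consumes these lists is not part of this dump; schematically, it indexes each list by level — the field names and values below are from the YAML, but the loop itself is only illustrative:

# Schematic per-level loop over the cascade settings in dtu_pretrain.yaml.
for i in range(2):                 # cas_config.num levels: 0 = coarse, 1 = fine
    planes = [64, 8][i]            # cas_config.volume_planes: depth hypotheses
    scale = [0.25, 1.0][i]         # cas_config.render_scale: output resolution
    samples = [8, 2][i]            # cas_config.num_samples: points per ray
    weight = [0.1, 1.0][i]         # cas_config.loss_weight: share of the loss
    print(f'level {i}: {planes} planes, render x{scale}, {samples} samples, w={weight}')
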
/lib/utils/colmap/test_read_write_dense.py:
--------------------------------------------------------------------------------
# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)

import numpy as np
from read_write_dense import read_array, write_array


def main():
    import sys
    if len(sys.argv) != 3:
        print("Usage: python test_read_write_dense.py "
              "path/to/dense/input.bin path/to/dense/output.bin")
        return

    print("Checking consistency of reading and writing dense arrays "
          + "(depth maps / normal maps) ...")

    path_to_dense_input = sys.argv[1]
    path_to_dense_output = sys.argv[2]

    dense_input = read_array(path_to_dense_input)
    print("Input shape: " + str(dense_input.shape))

    write_array(dense_input, path_to_dense_output)
    dense_output = read_array(path_to_dense_output)

    np.testing.assert_array_equal(dense_input, dense_output)

    print("... dense arrays are equal.")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
dense arrays are equal.") 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /lib/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def get_bound_corners(bounds): 5 | min_x, min_y, min_z = bounds[0] 6 | max_x, max_y, max_z = bounds[1] 7 | corners_3d = np.array([ 8 | [min_x, min_y, min_z], 9 | [min_x, min_y, max_z], 10 | [min_x, max_y, min_z], 11 | [min_x, max_y, max_z], 12 | [max_x, min_y, min_z], 13 | [max_x, min_y, max_z], 14 | [max_x, max_y, min_z], 15 | [max_x, max_y, max_z], 16 | ]) 17 | return corners_3d 18 | 19 | def project(xyz, K, RT): 20 | """ 21 | xyz: [N, 3] 22 | K: [3, 3] 23 | RT: [3, 4] 24 | """ 25 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T 26 | 27 | xyz = np.dot(xyz, K.T) 28 | xy = xyz[:, :2] / xyz[:, 2:] 29 | return xy 30 | 31 | row_col_ = { 32 | 2: (2, 1), 33 | 7: (2, 4), 34 | 8: (2, 4), 35 | 9: (3, 3), 36 | 26: (4, 7) 37 | } 38 | 39 | row_col_square = { 40 | 2: (2, 1), 41 | 7: (3, 3), 42 | 8: (3, 3), 43 | 9: (3, 3), 44 | 26: (5, 5) 45 | } 46 | 47 | def get_row_col(l, square): 48 | if square and l in row_col_square.keys(): 49 | return row_col_square[l] 50 | if l in row_col_.keys(): 51 | return row_col_[l] 52 | else: 53 | from math import sqrt 54 | row = int(sqrt(l) + 0.5) 55 | col = int(l / row + 0.5) 56 | if row*col < l: 57 | col = col + 1 58 | if row > col: 59 | row, col = col, row 60 | return row, col 61 | 62 | def merge(images, row=-1, col=-1, resize=False, ret_range=False, square=False, **kwargs): 63 | if row == -1 and col == -1: 64 | row, col = get_row_col(len(images), square) 65 | height = images[0].shape[0] 66 | width = images[0].shape[1] 67 | # special case 68 | if height > width: 69 | if len(images) == 3: 70 | row, col = 1, 3 71 | if len(images[0].shape) > 2: 72 | ret_img = np.zeros((height * row, width * col, images[0].shape[2]), dtype=np.uint8) + 255 73 | else: 74 | ret_img = np.zeros((height * row, width * col), dtype=np.uint8) + 255 75 | ranges = [] 76 | for i in range(row): 77 | for j in range(col): 78 | if i*col + j >= len(images): 79 | break 80 | img = images[i * col + j] 81 | # resize the image size 82 | img = cv2.resize(img, (width, height)) 83 | ret_img[height * i: height * (i+1), width * j: width * (j+1)] = img 84 | ranges.append((width*j, height*i, width*(j+1), height*(i+1))) 85 | if resize: 86 | min_height = 1000 87 | if ret_img.shape[0] > min_height: 88 | scale = min_height/ret_img.shape[0] 89 | ret_img = cv2.resize(ret_img, None, fx=scale, fy=scale) 90 | if ret_range: 91 | return ret_img, ranges 92 | return ret_img 93 | -------------------------------------------------------------------------------- /lib/utils/optimizer/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | from collections import Counter 3 | 4 | import torch 5 | 6 | 7 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | milestones, 12 | gamma=0.1, 13 | warmup_factor=1.0 / 3, 14 | warmup_iters=5, 15 | warmup_method="linear", 16 | last_epoch=-1, 17 | ): 18 | if not list(milestones) == sorted(milestones): 19 | raise ValueError( 20 | "Milestones should be a list of increasing integers. 
21 | "Got {}".format(milestones) 22 | ) 23 | 24 | if warmup_method not in ("constant", "linear"): 25 | raise ValueError( 26 | "Only 'constant' or 'linear' warmup_method accepted, " 27 | "got {}".format(warmup_method) 28 | ) 29 | self.milestones = milestones 30 | self.gamma = gamma 31 | self.warmup_factor = warmup_factor 32 | self.warmup_iters = warmup_iters 33 | self.warmup_method = warmup_method 34 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 35 | 36 | def get_lr(self): 37 | warmup_factor = 1 38 | if self.last_epoch < self.warmup_iters: 39 | if self.warmup_method == "constant": 40 | warmup_factor = self.warmup_factor 41 | elif self.warmup_method == "linear": 42 | alpha = float(self.last_epoch) / self.warmup_iters 43 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 44 | return [ 45 | base_lr 46 | * warmup_factor 47 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 48 | for base_lr in self.base_lrs 49 | ] 50 | 51 | 52 | class MultiStepLR(torch.optim.lr_scheduler._LRScheduler): 53 | 54 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): 55 | self.milestones = Counter(milestones) 56 | self.gamma = gamma 57 | super(MultiStepLR, self).__init__(optimizer, last_epoch) 58 | 59 | def get_lr(self): 60 | if self.last_epoch not in self.milestones: 61 | return [group['lr'] for group in self.optimizer.param_groups] 62 | return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] 63 | for group in self.optimizer.param_groups] 64 | 65 | 66 | class ExponentialLR(torch.optim.lr_scheduler._LRScheduler): 67 | 68 | def __init__(self, optimizer, decay_epochs, gamma=0.1, last_epoch=-1): 69 | self.decay_epochs = decay_epochs 70 | self.gamma = gamma 71 | super(ExponentialLR, self).__init__(optimizer, last_epoch) 72 | 73 | def get_lr(self): 74 | return [base_lr * self.gamma ** (self.last_epoch / self.decay_epochs) 75 | for base_lr in self.base_lrs] 76 | -------------------------------------------------------------------------------- /lib/utils/enerf/val_data_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import random 5 | from lib.networks.enerf.utils import unpreprocess 6 | 7 | def validate(batch): 8 | B = len(batch['tar_img']) 9 | num_points = 10 10 | batch_src_inps = unpreprocess(batch['src_inp']).cpu().numpy() 11 | for b in range(B): 12 | rgb = batch['tar_img'][b].cpu().numpy() # rgb 13 | gray = cv2.cvtColor((rgb*255.).astype(np.uint8), cv2.COLOR_RGB2GRAY) 14 | sift = cv2.xfeatures2d.SIFT_create() 15 | keypoints = sift.detectAndCompute(gray, None)[0] 16 | points = [keypoint.pt for keypoint in keypoints] 17 | points = np.stack(random.sample(points, num_points)) 18 | near, far = batch['near_far'][b][0].item(), batch['near_far'][b][1].item() 19 | points_near = np.concatenate((points, near * np.ones_like(points[:, :1])), axis=-1) 20 | points_far = np.concatenate((points, far * np.ones_like(points[:, :1])), axis=-1) 21 | 22 | src_inps = batch_src_inps[b] 23 | S = len(src_inps) 24 | 25 | ax = plt.subplot(1, 1+S, 1) 26 | ax.axis('off') 27 | ax.set_title('1') 28 | plt.imshow(rgb) 29 | for i in range(len(points)): 30 | plt.plot(points[i, 0][None], points[i, 1][None], '.') 31 | 32 | for s in range(S): 33 | points_near_s = transform(points_near.copy(), batch, b, s) 34 | points_far_s = transform(points_far.copy(), batch, b, s) 35 | lines = [] 36 | for point_near, point_far in zip(points_near_s, points_far_s): 37 | 
lines.append(np.concatenate((point_near[:2][None], point_far[:2][None]))) 38 | 39 | ax = plt.subplot(1, 1+S, s+2) 40 | ax.axis('off') 41 | ax.set_title('{}_{}'.format(1+S, s+1)) 42 | src_inp = batch_src_inps[b][s].transpose(1, 2, 0) 43 | plt.imshow(src_inp) 44 | for i in range(len(lines)): 45 | plt.plot(lines[i][:, 0], lines[i][:, 1]) 46 | 47 | plt.subplots_adjust(left=0.,bottom=0.,top=1.,right=1.,hspace=0.,wspace=0.) 48 | plt.show() 49 | 50 | def transform(points, batch, b, s): 51 | tar_ext = batch['tar_ext'][b].cpu().numpy() 52 | tar_ixt = batch['tar_ixt'][b].cpu().numpy() 53 | src_ext = batch['src_ext'][b, s].cpu().numpy() 54 | src_ixt = batch['src_ixt'][b, s].cpu().numpy() 55 | c2w = np.linalg.inv(tar_ext) 56 | points[..., :2] = points[..., :2] * points[..., 2:] 57 | points = points @ np.linalg.inv(tar_ixt).transpose(-1, -2) 58 | points = np.concatenate((points, np.ones_like(points[..., :1])), axis=-1) 59 | points = points @ np.linalg.inv(tar_ext).transpose(-1, -2) 60 | points = points @ src_ext.transpose(-1, -2) 61 | points = points[..., :3] @ src_ixt.transpose(-1, -2) 62 | points[..., :2] = points[..., :2] / points[..., 2:] 63 | return points 64 | -------------------------------------------------------------------------------- /lib/utils/colmap/clang_format_code.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | import os 33 | import string 34 | import argparse 35 | import subprocess 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--path", required=True) 41 | parser.add_argument("--exts", default=".h,.cc") 42 | parser.add_argument("--style", default="File") 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def main(): 48 | args = parse_args() 49 | 50 | exts = [ext.lower() for ext in args.exts.split(",")] 51 | 52 | for root, subdirs, files in os.walk(args.path): 53 | for f in files: 54 | name, ext = os.path.splitext(f) 55 | if ext.lower() in exts: 56 | file_path = os.path.join(root, f) 57 | proc = subprocess.Popen(["clang-format", "--style", 58 | args.style, file_path], 59 | stdout=subprocess.PIPE) 60 | 61 | text = proc.stdout.read().decode() 62 | 63 | with open(file_path, "w") as fd: 64 | fd.write(text) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /lib/utils/colmap/test_read_write_fused_vis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L.
Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | import filecmp 33 | from read_write_fused_vis import read_fused, write_fused 34 | 35 | 36 | def main(): 37 | import sys 38 | if len(sys.argv) != 5: 39 | print("Usage: python test_read_write_fused_vis.py " 40 | "path/to/input_fused.ply path/to/input_fused.ply.vis " 41 | "path/to/output_fused.ply path/to/output_fused.ply.vis") 42 | return 43 | 44 | print("Checking consistency of reading and writing fused.ply and fused.ply.vis files ...") 45 | 46 | path_to_fused_ply_input = sys.argv[1] 47 | path_to_fused_ply_vis_input = sys.argv[2] 48 | path_to_fused_ply_output = sys.argv[3] 49 | path_to_fused_ply_vis_output = sys.argv[4] 50 | 51 | mesh_points = read_fused(path_to_fused_ply_input, path_to_fused_ply_vis_input) 52 | write_fused(mesh_points, path_to_fused_ply_output, path_to_fused_ply_vis_output) 53 | 54 | assert filecmp.cmp(path_to_fused_ply_input, path_to_fused_ply_output) 55 | assert filecmp.cmp(path_to_fused_ply_vis_input, path_to_fused_ply_vis_output) 56 | 57 | print("... Results are equal.") 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | 63 | -------------------------------------------------------------------------------- /lib/utils/colmap/merge_ply_files.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | # This script merges multiple homogeneous PLY files into a single PLY file. 
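#
# A hypothetical invocation, for illustration only (the paths are not part of
# the repository; the flags match the argparse definition below):
#
#   python merge_ply_files.py --folder_path path/to/ply_parts --merged_path path/to/merged.ply
#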
33 | 34 | import os 35 | import glob 36 | import argparse 37 | import numpy as np 38 | import plyfile 39 | 40 | 41 | def parse_args(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--folder_path", required=True) 44 | parser.add_argument("--merged_path", required=True) 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def main(): 50 | args = parse_args() 51 | 52 | files = [] 53 | for file_name in os.listdir(args.folder_path): 54 | if len(file_name) < 4 or file_name[-4:].lower() != ".ply": 55 | continue 56 | 57 | print("Reading file", file_name) 58 | file = plyfile.PlyData.read(os.path.join(args.folder_path, file_name)) 59 | for element in file.elements: 60 | files.append(element.data) 61 | 62 | print("Merging files") 63 | merged_file = np.concatenate(files, -1) 64 | merged_el = plyfile.PlyElement.describe(merged_file, 'vertex') 65 | 66 | print("Writing merged file") 67 | plyfile.PlyData([merged_el]).write(args.merged_path) 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /lib/train/losses/enerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lib.utils import net_utils 4 | from lib.config import cfg 5 | from lib.train.losses.vgg_perceptual_loss import VGGPerceptualLoss 6 | 7 | class NetworkWrapper(nn.Module): 8 | def __init__(self, net, train_loader): 9 | super(NetworkWrapper, self).__init__() 10 | self.device = torch.device('cuda:{}'.format(cfg.local_rank)) 11 | self.net = net 12 | self.color_crit = nn.MSELoss(reduction='mean') 13 | self.mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 14 | self.perceptual_loss = VGGPerceptualLoss().to(self.device) 15 | 16 | def forward(self, batch): 17 | output = self.net(batch) 18 | 19 | scalar_stats = {} 20 | loss = 0 21 | for i in range(cfg.enerf.cas_config.num): 22 | color_loss = self.color_crit(batch[f'rgb_{i}'], output[f'rgb_level{i}']) 23 | scalar_stats.update({f'color_mse_{i}': color_loss}) 24 | loss += cfg.enerf.cas_config.loss_weight[i] * color_loss 25 | 26 | psnr = -10. 
* torch.log(color_loss) / torch.log(torch.Tensor([10.]).to(color_loss.device)) 27 | scalar_stats.update({f'psnr_{i}': psnr}) 28 | 29 | num_patchs = cfg.enerf.cas_config.num_patchs[i] 30 | if cfg.enerf.cas_config.train_img[i]: 31 | render_scale = cfg.enerf.cas_config.render_scale[i] 32 | B, S, C, H, W = batch['src_inps'].shape 33 | H, W = int(H * render_scale), int(W * render_scale) 34 | inp = output[f'rgb_level{i}'].reshape(B, H, W, 3).permute(0, 3, 1, 2) 35 | tar = batch[f'rgb_{i}'].reshape(B, H, W, 3).permute(0, 3, 1, 2) 36 | perceptual_loss = self.perceptual_loss(inp, tar) 37 | loss += 0.01 * perceptual_loss * cfg.enerf.cas_config.loss_weight[i] 38 | scalar_stats.update({f'perceptual_loss_{i}': perceptual_loss.detach()}) 39 | elif num_patchs > 0: 40 | patch_size = cfg.enerf.cas_config.patch_size[i] 41 | num_rays = cfg.enerf.cas_config.num_rays[i] 42 | patch_rays = int(patch_size ** 2) 43 | inp = torch.empty((0, 3, patch_size, patch_size)).to(self.device) 44 | tar = torch.empty((0, 3, patch_size, patch_size)).to(self.device) 45 | for j in range(num_patchs): 46 | inp = torch.cat([inp, output[f'rgb_level{i}'][:, num_rays+j*patch_rays:num_rays+(j+1)*patch_rays, :].reshape(-1, patch_size, patch_size, 3).permute(0, 3, 1, 2)]) 47 | tar = torch.cat([tar, batch[f'rgb_{i}'][:, num_rays+j*patch_rays:num_rays+(j+1)*patch_rays, :].reshape(-1, patch_size, patch_size, 3).permute(0, 3, 1, 2)]) 48 | perceptual_loss = self.perceptual_loss(inp, tar) 49 | 50 | loss += 0.01 * perceptual_loss * cfg.enerf.cas_config.loss_weight[i] 51 | scalar_stats.update({f'perceptual_loss_{i}': perceptual_loss.detach()}) 52 | 53 | scalar_stats.update({'loss': loss}) 54 | image_stats = {} 55 | 56 | return output, loss, scalar_stats, image_stats 57 | 58 | -------------------------------------------------------------------------------- /lib/networks/enerf/cost_reg_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import * 3 | 4 | class CostRegNet(nn.Module): 5 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d): 6 | super(CostRegNet, self).__init__() 7 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act) 8 | 9 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act) 10 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act) 11 | 12 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act) 13 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 14 | 15 | self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act) 16 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act) 17 | 18 | self.conv7 = nn.Sequential( 19 | nn.ConvTranspose3d(64, 32, 3, padding=1, output_padding=1, 20 | stride=2, bias=False), 21 | norm_act(32)) 22 | 23 | self.conv9 = nn.Sequential( 24 | nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, 25 | stride=2, bias=False), 26 | norm_act(16)) 27 | 28 | self.conv11 = nn.Sequential( 29 | nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1, 30 | stride=2, bias=False), 31 | norm_act(8)) 32 | self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False)) 33 | self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False)) 34 | 35 | def forward(self, x): 36 | conv0 = self.conv0(x) 37 | conv2 = self.conv2(self.conv1(conv0)) 38 | conv4 = self.conv4(self.conv3(conv2)) 39 | x = self.conv6(self.conv5(conv4)) 40 | x = conv4 + self.conv7(x) 41 | del conv4 42 | x = conv2 + self.conv9(x) 43 | del conv2 44 | x = conv0 + self.conv11(x) 45 | del conv0 46 | feat = self.feat_conv(x) 47 | 
depth = self.depth_conv(x) 48 | return feat, depth.squeeze(1) 49 | 50 | 51 | class MinCostRegNet(nn.Module): 52 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d): 53 | super(MinCostRegNet, self).__init__() 54 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act) 55 | 56 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act) 57 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act) 58 | 59 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act) 60 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 61 | 62 | self.conv9 = nn.Sequential( 63 | nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, 64 | stride=2, bias=False), 65 | norm_act(16)) 66 | 67 | self.conv11 = nn.Sequential( 68 | nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1, 69 | stride=2, bias=False), 70 | norm_act(8)) 71 | 72 | self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False)) 73 | self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False)) 74 | 75 | def forward(self, x): 76 | conv0 = self.conv0(x) 77 | conv2 = self.conv2(self.conv1(conv0)) 78 | conv4 = self.conv4(self.conv3(conv2)) 79 | x = conv4 80 | x = conv2 + self.conv9(x) 81 | del conv2 82 | x = conv0 + self.conv11(x) 83 | del conv0 84 | feat = self.feat_conv(x) 85 | depth = self.depth_conv(x) 86 | return feat, depth.squeeze(1) 87 | -------------------------------------------------------------------------------- /lib/utils/colmap/export_inlier_pairs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | # This script exports inlier image pairs from a COLMAP database to a text file. 
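#
# A hypothetical invocation, for illustration only (database.db is a standard
# COLMAP database produced by feature matching; the output path is arbitrary,
# and the flags match the argparse definition below):
#
#   python export_inlier_pairs.py --database_path database.db \
#       --match_list_path pairs.txt --min_num_matches 15
#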
33 | 34 | import sqlite3 35 | import argparse 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--database_path", required=True) 41 | parser.add_argument("--match_list_path", required=True) 42 | parser.add_argument("--min_num_matches", type=int, default=15) 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def pair_id_to_image_ids(pair_id): 48 | image_id2 = pair_id % 2147483647 49 | image_id1 = (pair_id - image_id2) / 2147483647 50 | return image_id1, image_id2 51 | 52 | 53 | def main(): 54 | args = parse_args() 55 | 56 | connection = sqlite3.connect(args.database_path) 57 | cursor = connection.cursor() 58 | 59 | # Get a mapping between image ids and image names 60 | image_id_to_name = dict() 61 | cursor.execute('SELECT image_id, name FROM images;') 62 | for row in cursor: 63 | image_id = row[0] 64 | name = row[1] 65 | image_id_to_name[image_id] = name 66 | 67 | # Iterate over entries in the two_view_geometries table 68 | output = open(args.match_list_path, 'w') 69 | cursor.execute('SELECT pair_id, rows FROM two_view_geometries;') 70 | for row in cursor: 71 | pair_id = row[0] 72 | rows = row[1] 73 | 74 | if rows < args.min_num_matches: 75 | continue 76 | 77 | image_id1, image_id2 = pair_id_to_image_ids(pair_id) 78 | image_name1 = image_id_to_name[image_id1] 79 | image_name2 = image_id_to_name[image_id2] 80 | 81 | output.write("%s %s\n" % (image_name1, image_name2)) 82 | 83 | output.close() 84 | cursor.close() 85 | connection.close() 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /lib/networks/enerf/res_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ResidualConv(nn.Module): 5 | def __init__(self, input_dim, output_dim, stride, padding): 6 | super(ResidualConv, self).__init__() 7 | 8 | self.conv_block = nn.Sequential( 9 | nn.BatchNorm2d(input_dim), 10 | nn.ReLU(), 11 | nn.Conv2d( 12 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding 13 | ), 14 | nn.BatchNorm2d(output_dim), 15 | nn.ReLU(), 16 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), 17 | ) 18 | self.conv_skip = nn.Sequential( 19 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), 20 | nn.BatchNorm2d(output_dim), 21 | ) 22 | 23 | def forward(self, x): 24 | 25 | return self.conv_block(x) + self.conv_skip(x) 26 | 27 | class Upsample(nn.Module): 28 | def __init__(self, input_dim, output_dim, kernel, stride): 29 | super(Upsample, self).__init__() 30 | 31 | self.upsample = nn.ConvTranspose2d( 32 | input_dim, output_dim, kernel_size=kernel, stride=stride 33 | ) 34 | 35 | def forward(self, x): 36 | return self.upsample(x) 37 | 38 | 39 | 40 | class ResUnet(nn.Module): 41 | def __init__(self, channel=3, filters=[16, 32, 64, 128]): 42 | super(ResUnet, self).__init__() 43 | 44 | self.input_layer = nn.Sequential( 45 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), 46 | nn.BatchNorm2d(filters[0]), 47 | nn.ReLU(), 48 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 49 | ) 50 | self.input_skip = nn.Sequential( 51 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) 52 | ) 53 | 54 | self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1) 55 | self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1) 56 | 57 | self.bridge = ResidualConv(filters[2], filters[3], 2, 1) 58 | 59 | self.upsample_1 = 
Upsample(filters[3], filters[3], 2, 2) 60 | # self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1) 61 | 62 | # self.upsample_2 = Upsample(filters[2], filters[2], 2, 2) 63 | # self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1) 64 | 65 | # self.upsample_3 = Upsample(filters[1], filters[1], 2, 2) 66 | # self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1) 67 | 68 | self.output_layer = nn.Sequential( 69 | nn.Conv2d(filters[2]+filters[3], 32, 1, 1) 70 | ) 71 | 72 | def forward(self, x): 73 | # Encode 74 | B, S, C, H, W = x.shape 75 | x = x.view(B*S, C, H, W) 76 | x1 = self.input_layer(x) + self.input_skip(x) 77 | x2 = self.residual_conv_1(x1) 78 | x3 = self.residual_conv_2(x2) 79 | # Bridge 80 | x4 = self.bridge(x3) 81 | # Decode 82 | x4 = self.upsample_1(x4) 83 | x5 = torch.cat([x4, x3], dim=1) 84 | 85 | # x6 = self.up_residual_conv1(x5) 86 | 87 | # x6 = self.upsample_2(x6) 88 | # x7 = torch.cat([x6, x2], dim=1) 89 | 90 | # x8 = self.up_residual_conv2(x7) 91 | 92 | # x8 = self.upsample_3(x8) 93 | # x9 = torch.cat([x8, x1], dim=1) 94 | 95 | # x10 = self.up_residual_conv3(x9) 96 | 97 | output = self.output_layer(x5) 98 | output = output.view(B, S, 32, H//4, W//4) 99 | return output 100 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg, args 2 | import numpy as np 3 | import os 4 | 5 | def run_dataset(): 6 | from lib.datasets import make_data_loader 7 | import tqdm 8 | 9 | cfg.train.num_workers = 0 10 | data_loader = make_data_loader(cfg, is_train=False) 11 | for batch in tqdm.tqdm(data_loader): 12 | pass 13 | 14 | def run_network(): 15 | from lib.networks import make_network 16 | from lib.datasets import make_data_loader 17 | from lib.utils.net_utils import load_network 18 | from lib.utils.data_utils import to_cuda 19 | import tqdm 20 | import torch 21 | import time 22 | 23 | network = make_network(cfg).cuda() 24 | load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch) 25 | network.eval() 26 | 27 | data_loader = make_data_loader(cfg, is_train=False) 28 | total_time = 0 29 | for batch in tqdm.tqdm(data_loader): 30 | batch = to_cuda(batch) 31 | with torch.no_grad(): 32 | torch.cuda.synchronize() 33 | start = time.time() 34 | network(batch) 35 | torch.cuda.synchronize() 36 | total_time += time.time() - start 37 | print(total_time / len(data_loader)) 38 | 39 | def run_evaluate(): 40 | from lib.datasets import make_data_loader 41 | from lib.evaluators import make_evaluator 42 | import tqdm 43 | import torch 44 | from lib.networks import make_network 45 | from lib.utils import net_utils 46 | import time 47 | 48 | network = make_network(cfg).cuda() 49 | net_utils.load_network(network, 50 | cfg.trained_model_dir, 51 | resume=cfg.resume, 52 | epoch=cfg.test.epoch) 53 | network.eval() 54 | 55 | data_loader = make_data_loader(cfg, is_train=False) 56 | evaluator = make_evaluator(cfg) 57 | net_time = [] 58 | for batch in tqdm.tqdm(data_loader): 59 | for k in batch: 60 | if k != 'meta': 61 | batch[k] = batch[k].cuda() 62 | with torch.no_grad(): 63 | torch.cuda.synchronize() 64 | start_time = time.time() 65 | output = network(batch) 66 | torch.cuda.synchronize() 67 | end_time = time.time() 68 | net_time.append(end_time - start_time) 69 | evaluator.evaluate(output, batch) 70 | evaluator.summarize() 71 | if len(net_time) > 1: 72 | # print('net_time: ', 
np.mean(net_time[1:])) 73 | print('FPS: ', 1./np.mean(net_time[1:])) 74 | else: 75 | # print('net_time: ', np.mean(net_time)) 76 | print('FPS: ', 1./np.mean(net_time)) 77 | 78 | 79 | def run_visualize(): 80 | from lib.networks import make_network 81 | from lib.datasets import make_data_loader 82 | from lib.utils.net_utils import load_network 83 | from lib.utils import net_utils 84 | import tqdm 85 | import torch 86 | from lib.visualizers import make_visualizer 87 | from lib.utils.data_utils import to_cuda 88 | 89 | network = make_network(cfg).cuda() 90 | load_network(network, 91 | cfg.trained_model_dir, 92 | resume=cfg.resume, 93 | epoch=cfg.test.epoch) 94 | network.eval() 95 | 96 | data_loader = make_data_loader(cfg, is_train=False) 97 | visualizer = make_visualizer(cfg) 98 | for batch in tqdm.tqdm(data_loader): 99 | batch = to_cuda(batch) 100 | with torch.no_grad(): 101 | output = network(batch) 102 | visualizer.visualize(output, batch) 103 | visualizer.summarize() 104 | 105 | if __name__ == '__main__': 106 | globals()['run_' + args.type]() 107 | -------------------------------------------------------------------------------- /lib/utils/colmap/export_inlier_matches.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | # This script exports inlier matches from a COLMAP database to a text file. 
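#
# A hypothetical invocation, for illustration only (flags match the argparse
# definition below); the output file lists each image pair followed by its
# inlier feature correspondences:
#
#   python export_inlier_matches.py --database_path database.db \
#       --output_path matches.txt --min_num_matches 15
#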
33 | 34 | import os 35 | import argparse 36 | import sqlite3 37 | import numpy as np 38 | 39 | 40 | def parse_args(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--database_path", required=True) 43 | parser.add_argument("--output_path", required=True) 44 | parser.add_argument("--min_num_matches", type=int, default=15) 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def pair_id_to_image_ids(pair_id): 50 | image_id2 = pair_id % 2147483647 51 | image_id1 = (pair_id - image_id2) / 2147483647 52 | return image_id1, image_id2 53 | 54 | 55 | def main(): 56 | args = parse_args() 57 | 58 | connection = sqlite3.connect(args.database_path) 59 | cursor = connection.cursor() 60 | 61 | images = {} 62 | cursor.execute("SELECT image_id, camera_id, name FROM images;") 63 | for row in cursor: 64 | image_id = row[0] 65 | image_name = row[2] 66 | images[image_id] = image_name 67 | 68 | with open(os.path.join(args.output_path), "w") as fid: 69 | cursor.execute( 70 | "SELECT pair_id, data FROM two_view_geometries WHERE rows>=?;", 71 | (args.min_num_matches,)) 72 | for row in cursor: 73 | pair_id = row[0] 74 | inlier_matches = np.fromstring(row[1], 75 | dtype=np.uint32).reshape(-1, 2) 76 | image_id1, image_id2 = pair_id_to_image_ids(pair_id) 77 | image_name1 = images[image_id1] 78 | image_name2 = images[image_id2] 79 | fid.write("%s %s %d\n" % (image_name1, image_name2, 80 | inlier_matches.shape[0])) 81 | for i in range(inlier_matches.shape[0]): 82 | fid.write("%d %d\n" % tuple(inlier_matches[i])) 83 | 84 | cursor.close() 85 | connection.close() 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /lib/networks/enerf/cost_reg_net_.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import * 3 | 4 | class CostRegNet(nn.Module): 5 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d): 6 | super(CostRegNet, self).__init__() 7 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act) 8 | 9 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act) 10 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act) 11 | 12 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act) 13 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 14 | 15 | self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act) 16 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act) 17 | 18 | self.conv7 = nn.Sequential( 19 | nn.ConvTranspose3d(64, 32, 3, padding=1, output_padding=1, 20 | stride=2, bias=False), 21 | norm_act(32)) 22 | 23 | self.conv9 = nn.Sequential( 24 | nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, 25 | stride=2, bias=False), 26 | norm_act(16)) 27 | 28 | self.conv11 = nn.Sequential( 29 | nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1, 30 | stride=2, bias=False), 31 | norm_act(8)) 32 | self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False)) 33 | self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False)) 34 | 35 | def forward(self, x): 36 | conv0 = self.conv0(x) 37 | conv2 = self.conv2(self.conv1(conv0)) 38 | conv4 = self.conv4(self.conv3(conv2)) 39 | x = self.conv6(self.conv5(conv4)) 40 | x = conv4 + self.conv7(x) 41 | del conv4 42 | x = conv2 + self.conv9(x) 43 | del conv2 44 | x = conv0 + self.conv11(x) 45 | del conv0 46 | feat = self.feat_conv(x) 47 | depth = self.depth_conv(x) 48 | return feat, depth.squeeze(1) 49 | 50 | 51 | class 
MinCostRegNet(nn.Module): 52 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d): 53 | super(MinCostRegNet, self).__init__() 54 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act) 55 | 56 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act) 57 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act) 58 | 59 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act) 60 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 61 | 62 | self.conv9 = ConvBnReLU3D(32, 16, norm_act=norm_act) 63 | self.conv11 = ConvBnReLU3D(16, 8, norm_act=norm_act) 64 | 65 | # self.conv9 = nn.Sequential( 66 | # nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, 67 | # stride=2, bias=False), 68 | # norm_act(16)) 69 | 70 | # self.conv11 = nn.Sequential( 71 | # nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1, 72 | # stride=2, bias=False), 73 | # norm_act(8)) 74 | 75 | self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False)) 76 | self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False)) 77 | 78 | def forward(self, x): 79 | conv0 = self.conv0(x) 80 | conv2 = self.conv2(self.conv1(conv0)) 81 | conv4 = self.conv4(self.conv3(conv2)) 82 | x = conv4 83 | x = conv2 + self.conv9(F.interpolate(x, scale_factor=2., align_corners=True, mode='trilinear')) 84 | del conv2 85 | x = conv0 + self.conv11(F.interpolate(x, scale_factor=2., align_corners=True, mode='trilinear')) 86 | del conv0 87 | feat = self.feat_conv(x) 88 | depth = self.depth_conv(x) 89 | return feat, depth.squeeze(1) 90 | -------------------------------------------------------------------------------- /lib/datasets/enerf_utils.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg 2 | import cv2 3 | import numpy as np 4 | 5 | def sample_patch(num_patch, patch_size, H, W, msk_sample): 6 | half_patch_size = patch_size // 2 7 | if msk_sample.sum() > 0: 8 | num_fg_patch = num_patch 9 | non_zero = msk_sample.nonzero() 10 | permutation = np.random.permutation(msk_sample.sum())[:num_fg_patch].astype(np.int32) 11 | X_, Y_ = non_zero[1][permutation], non_zero[0][permutation] 12 | X_ = np.clip(X_, half_patch_size, W-half_patch_size) 13 | Y_ = np.clip(Y_, half_patch_size, H-half_patch_size) 14 | else: 15 | num_fg_patch = 0 16 | num_patch = num_patch - num_fg_patch 17 | X = np.random.randint(low=half_patch_size, high=W-half_patch_size, size=num_patch) 18 | Y = np.random.randint(low=half_patch_size, high=H-half_patch_size, size=num_patch) 19 | if num_fg_patch > 0: 20 | X = np.concatenate([X, X_]).astype(np.int32) 21 | Y = np.concatenate([Y, Y_]).astype(np.int32) 22 | grid = np.meshgrid(np.arange(patch_size)-half_patch_size, np.arange(patch_size)-half_patch_size) 23 | return np.concatenate([grid[0].reshape(-1) + x for x in X]), np.concatenate([grid[1].reshape(-1) + y for y in Y]) 24 | 25 | def build_rays(tar_img, tar_ext, tar_ixt, tar_msk, level, split): 26 | scale = cfg.enerf.cas_config.render_scale[level] 27 | if scale != 1.: 28 | tar_img = cv2.resize(tar_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) 29 | tar_msk = cv2.resize(tar_msk, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST) 30 | tar_ixt = tar_ixt.copy() 31 | tar_ixt[:2] *= scale 32 | H, W = tar_img.shape[:2] 33 | c2w = np.linalg.inv(tar_ext) 34 | if split == 'train' and not cfg.enerf.cas_config.train_img[level]: 35 | if cfg.enerf.sample_on_mask: # 313 36 | msk_sample = tar_msk 37 | num_fg_rays = int(min(cfg.enerf.cas_config.num_rays[level]*0.75, 
tar_msk.sum()*0.95)) 38 | non_zero = msk_sample.nonzero() 39 | permutation = np.random.permutation(tar_msk.sum())[:num_fg_rays].astype(np.int32) 40 | X_, Y_ = non_zero[1][permutation], non_zero[0][permutation] 41 | else: 42 | num_fg_rays = 0 43 | msk_sample = np.zeros_like(tar_msk) 44 | num_rays = cfg.enerf.cas_config.num_rays[level] - num_fg_rays 45 | X = np.random.randint(low=0, high=W, size=num_rays) 46 | Y = np.random.randint(low=0, high=H, size=num_rays) 47 | if num_fg_rays > 0: 48 | X = np.concatenate([X, X_]).astype(np.int32) 49 | Y = np.concatenate([Y, Y_]).astype(np.int32) 50 | if cfg.enerf.cas_config.num_patchs[level] > 0: 51 | X_, Y_ = sample_patch(cfg.enerf.cas_config.num_patchs[level], cfg.enerf.cas_config.patch_size[level], H, W, msk_sample) 52 | X = np.concatenate([X, X_]).astype(np.int32) 53 | Y = np.concatenate([Y, Y_]).astype(np.int32) 54 | num_rays = len(X) 55 | rays_o = c2w[:3, 3][None].repeat(num_rays, 0) 56 | XYZ = np.concatenate((X[:, None], Y[:, None], np.ones_like(X[:, None])), axis=-1) 57 | XYZ = XYZ @ (np.linalg.inv(tar_ixt).T @ c2w[:3, :3].T) 58 | rays = np.concatenate((rays_o, XYZ, X[..., None], Y[..., None]), axis=-1) 59 | rgb = tar_img[Y, X] 60 | msk = tar_msk[Y, X] 61 | else: 62 | rays_o = c2w[:3, 3][None, None] 63 | X, Y = np.meshgrid(np.arange(W), np.arange(H)) 64 | XYZ = np.concatenate((X[:, :, None], Y[:, :, None], np.ones_like(X[:, :, None])), axis=-1) 65 | XYZ = XYZ @ (np.linalg.inv(tar_ixt).T @ c2w[:3, :3].T) 66 | rays_o = rays_o.repeat(H, axis=0) 67 | rays_o = rays_o.repeat(W, axis=1) 68 | rays = np.concatenate((rays_o, XYZ, X[..., None], Y[..., None]), axis=-1) 69 | rgb = tar_img 70 | msk = tar_msk 71 | return rays.astype(np.float32).reshape(-1, 8), rgb.reshape(-1, 3), msk.reshape(-1) 72 | 73 | 74 | -------------------------------------------------------------------------------- /lib/datasets/make_dataset.py: -------------------------------------------------------------------------------- 1 | from . 
import samplers 2 | import torch 3 | import torch.utils.data 4 | import imp 5 | import os 6 | from .collate_batch import make_collator 7 | import numpy as np 8 | import time 9 | from lib.config.config import cfg 10 | from torch.utils.data import DataLoader, ConcatDataset 11 | import cv2 12 | cv2.setNumThreads(1) 13 | 14 | 15 | # torch.multiprocessing.set_sharing_strategy('file_system') 16 | 17 | def _dataset_factory(is_train, is_val): 18 | if is_val: 19 | module = cfg.val_dataset_module 20 | path = cfg.val_dataset_path 21 | elif is_train: 22 | module = cfg.train_dataset_module 23 | path = cfg.train_dataset_path 24 | else: 25 | module = cfg.test_dataset_module 26 | path = cfg.test_dataset_path 27 | dataset = imp.load_source(module, path).Dataset 28 | return dataset 29 | 30 | 31 | def make_dataset(cfg, is_train=True): 32 | if is_train: 33 | args = cfg.train_dataset 34 | module = cfg.train_dataset_module 35 | path = cfg.train_dataset_path 36 | else: 37 | args = cfg.test_dataset 38 | module = cfg.test_dataset_module 39 | path = cfg.test_dataset_path 40 | dataset = imp.load_source(module, path).Dataset 41 | dataset = dataset(**args) 42 | return dataset 43 | 44 | 45 | def make_data_sampler(dataset, shuffle, is_distributed): 46 | if is_distributed: 47 | return samplers.DistributedSampler(dataset, shuffle=shuffle) 48 | if shuffle: 49 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 50 | else: 51 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 52 | return sampler 53 | 54 | 55 | def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter, 56 | is_train): 57 | if is_train: 58 | batch_sampler = cfg.train.batch_sampler 59 | sampler_meta = cfg.train.sampler_meta 60 | else: 61 | batch_sampler = cfg.test.batch_sampler 62 | sampler_meta = cfg.test.sampler_meta 63 | if batch_sampler == 'default': 64 | batch_sampler = torch.utils.data.sampler.BatchSampler( 65 | sampler, batch_size, drop_last) 66 | elif batch_sampler == 'image_size': 67 | batch_sampler = samplers.ImageSizeBatchSampler(sampler, batch_size, 68 | drop_last, sampler_meta) 69 | elif batch_sampler == 'enerf': 70 | batch_sampler = samplers.EnerfBatchSampler(sampler, batch_size, drop_last, sampler_meta) 71 | if max_iter != -1: 72 | batch_sampler = samplers.IterationBasedBatchSampler( 73 | batch_sampler, max_iter) 74 | return batch_sampler 75 | 76 | 77 | def worker_init_fn(worker_id): 78 | np.random.seed(worker_id + (int(round(time.time() * 1000) % (2**16)))) 79 | 80 | 81 | def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1): 82 | if is_train: 83 | batch_size = cfg.train.batch_size 84 | # shuffle = True 85 | shuffle = cfg.train.shuffle 86 | drop_last = False 87 | else: 88 | batch_size = cfg.test.batch_size 89 | shuffle = True if is_distributed else False 90 | drop_last = False 91 | 92 | dataset = make_dataset(cfg, is_train) 93 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 94 | batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size, 95 | drop_last, max_iter, is_train) 96 | num_workers = cfg.train.num_workers 97 | collator = make_collator(cfg, is_train) 98 | data_loader = DataLoader(dataset, 99 | batch_sampler=batch_sampler, 100 | num_workers=num_workers, 101 | collate_fn=collator, 102 | worker_init_fn=worker_init_fn) 103 | 104 | return data_loader 105 | -------------------------------------------------------------------------------- /lib/utils/colmap/build_windows_app.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | import os 33 | import glob 34 | import shutil 35 | import argparse 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--install_path", required=True, 41 | help="The installation prefix, e.g., build/__install__") 42 | parser.add_argument("--app_path", required=True, 43 | help="The application path, e.g., " 44 | "build/COLMAP-dev-windows") 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def mkdir_if_not_exists(path): 50 | assert os.path.exists(os.path.dirname(os.path.abspath(path))) 51 | if not os.path.exists(path): 52 | os.makedirs(path) 53 | 54 | 55 | def main(): 56 | args = parse_args() 57 | 58 | mkdir_if_not_exists(args.app_path) 59 | mkdir_if_not_exists(os.path.join(args.app_path, "bin")) 60 | mkdir_if_not_exists(os.path.join(args.app_path, "lib")) 61 | mkdir_if_not_exists(os.path.join(args.app_path, "lib/platforms")) 62 | 63 | # Copy batch scripts to app directory. 64 | shutil.copyfile( 65 | os.path.join(args.install_path, "COLMAP.bat"), 66 | os.path.join(args.app_path, "COLMAP.bat")) 67 | shutil.copyfile( 68 | os.path.join(args.install_path, "RUN_TESTS.bat"), 69 | os.path.join(args.app_path, "RUN_TESTS.bat")) 70 | 71 | # Copy executables to app directory. 72 | exe_files = glob.glob(os.path.join(args.install_path, "bin/*.exe")) 73 | for exe_file in exe_files: 74 | shutil.copyfile(exe_file, os.path.join(args.app_path, "bin", 75 | os.path.basename(exe_file))) 76 | 77 | # Copy shared libraries to app directory. 
78 | dll_files = glob.glob(os.path.join(args.install_path, "lib/*.dll")) 79 | for dll_file in dll_files: 80 | shutil.copyfile(dll_file, os.path.join(args.app_path, "lib", 81 | os.path.basename(dll_file))) 82 | shutil.copyfile( 83 | os.path.join(args.install_path, "lib/platforms/qwindows.dll"), 84 | os.path.join(args.app_path, "lib/platforms/qwindows.dll")) 85 | 86 | # Create zip archive for deployment. 87 | shutil.make_archive(args.app_path, "zip", root_dir=args.app_path) 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg, args 2 | from lib.networks import make_network 3 | from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler 4 | from lib.datasets import make_data_loader 5 | from lib.utils.net_utils import load_model, save_model, load_network, save_trained_config, load_pretrain 6 | from lib.evaluators import make_evaluator 7 | import torch.multiprocessing 8 | import torch 9 | import torch.distributed as dist 10 | import os 11 | # torch.autograd.set_detect_anomaly(True) 12 | 13 | if cfg.fix_random: 14 | torch.manual_seed(0) 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | 18 | 19 | def train(cfg, network): 20 | train_loader = make_data_loader(cfg, 21 | is_train=True, 22 | is_distributed=cfg.distributed, 23 | max_iter=cfg.ep_iter) 24 | if cfg.skip_eval: 25 | val_loader = None 26 | else: 27 | val_loader = make_data_loader(cfg, is_train=False) 28 | trainer = make_trainer(cfg, network, train_loader) 29 | optimizer = make_optimizer(cfg, network) 30 | scheduler = make_lr_scheduler(cfg, optimizer) 31 | recorder = make_recorder(cfg) 32 | evaluator = make_evaluator(cfg) 33 | 34 | begin_epoch = load_model(network, 35 | optimizer, 36 | scheduler, 37 | recorder, 38 | cfg.trained_model_dir, 39 | resume=cfg.resume) 40 | if begin_epoch == 0 and cfg.pretrain != '': 41 | load_pretrain(network, cfg.pretrain) 42 | 43 | set_lr_scheduler(cfg, scheduler) 44 | for epoch in range(begin_epoch, cfg.train.epoch): 45 | recorder.epoch = epoch 46 | if cfg.distributed: 47 | train_loader.batch_sampler.sampler.set_epoch(epoch) 48 | 49 | train_loader.dataset.epoch = epoch 50 | 51 | trainer.train(epoch, train_loader, optimizer, recorder) 52 | scheduler.step() 53 | 54 | if (epoch + 1) % cfg.save_ep == 0 and cfg.local_rank == 0: 55 | save_model(network, optimizer, scheduler, recorder, 56 | cfg.trained_model_dir, epoch) 57 | 58 | if (epoch + 1) % cfg.save_latest_ep == 0 and cfg.local_rank == 0: 59 | save_model(network, 60 | optimizer, 61 | scheduler, 62 | recorder, 63 | cfg.trained_model_dir, 64 | epoch, 65 | last=True) 66 | 67 | if not cfg.skip_eval and (epoch + 1) % cfg.eval_ep == 0 and cfg.local_rank == 0: 68 | trainer.val(epoch, val_loader, evaluator, recorder) 69 | 70 | return network 71 | 72 | 73 | def test(cfg, network): 74 | trainer = make_trainer(cfg, network) 75 | val_loader = make_data_loader(cfg, is_train=False) 76 | evaluator = make_evaluator(cfg) 77 | epoch = load_network(network, 78 | cfg.trained_model_dir, 79 | resume=cfg.resume, 80 | epoch=cfg.test.epoch) 81 | trainer.val(epoch, val_loader, evaluator) 82 | 83 | def synchronize(): 84 | """ 85 | Helper function to synchronize (barrier) among all processes when 86 | using distributed training 87 | """ 88 | if not dist.is_available(): 89 | return 90 | if not 
dist.is_initialized(): 91 | return 92 | world_size = dist.get_world_size() 93 | if world_size == 1: 94 | return 95 | dist.barrier() 96 | 97 | def main(): 98 | if cfg.distributed: 99 | cfg.local_rank = int(os.environ['RANK']) % torch.cuda.device_count() 100 | torch.cuda.set_device(cfg.local_rank) 101 | torch.distributed.init_process_group(backend="nccl", 102 | init_method="env://") 103 | synchronize() 104 | 105 | network = make_network(cfg) 106 | if args.test: 107 | test(cfg, network) 108 | else: 109 | train(cfg, network) 110 | if cfg.local_rank == 0: 111 | print('Success!') 112 | print('='*80) 113 | os.system('kill -9 {}'.format(os.getpid())) 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /lib/train/recorder.py: -------------------------------------------------------------------------------- 1 | from collections import deque, defaultdict 2 | import torch 3 | from tensorboardX import SummaryWriter 4 | import os 5 | from lib.config.config import cfg 6 | 7 | from termcolor import colored 8 | 9 | 10 | class SmoothedValue(object): 11 | """Track a series of values and provide access to smoothed values over a 12 | window or the global series average. 13 | """ 14 | 15 | def __init__(self, window_size=20): 16 | self.deque = deque(maxlen=window_size) 17 | self.total = 0.0 18 | self.count = 0 19 | 20 | def update(self, value): 21 | self.deque.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def avg(self): 32 | d = torch.tensor(list(self.deque)) 33 | return d.mean().item() 34 | 35 | @property 36 | def global_avg(self): 37 | return self.total / self.count 38 | 39 | 40 | def process_volsdf(image_stats): 41 | for k, v in image_stats.items(): 42 | image_stats[k] = torch.clamp(v[0].permute(2, 0, 1), min=0., max=1.) 
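    # at this point each entry holds a [C, H, W] tensor clamped to [0, 1],
    # the layout SummaryWriter.add_image expects in Recorder.record below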
43 | return image_stats 44 | 45 | process_neus = process_volsdf 46 | 47 | class Recorder(object): 48 | def __init__(self, cfg): 49 | if cfg.local_rank > 0: 50 | return 51 | 52 | log_dir = cfg.record_dir 53 | if not cfg.resume: 54 | print(colored('remove contents of directory %s' % log_dir, 'red')) 55 | os.system('rm -r %s/*' % log_dir) 56 | self.writer = SummaryWriter(log_dir=log_dir) 57 | 58 | # scalars 59 | self.epoch = 0 60 | self.step = 0 61 | self.loss_stats = defaultdict(SmoothedValue) 62 | self.batch_time = SmoothedValue() 63 | self.data_time = SmoothedValue() 64 | 65 | # images 66 | self.image_stats = defaultdict(object) 67 | if 'process_' + cfg.task in globals(): 68 | self.processor = globals()['process_' + cfg.task] 69 | else: 70 | self.processor = None 71 | 72 | def update_loss_stats(self, loss_dict): 73 | if cfg.local_rank > 0: 74 | return 75 | for k, v in loss_dict.items(): 76 | self.loss_stats[k].update(v.detach().cpu()) 77 | 78 | def update_image_stats(self, image_stats): 79 | if cfg.local_rank > 0: 80 | return 81 | if self.processor is None: 82 | return 83 | image_stats = self.processor(image_stats) 84 | for k, v in image_stats.items(): 85 | self.image_stats[k] = v.detach().cpu() 86 | 87 | def record(self, prefix, step=-1, loss_stats=None, image_stats=None): 88 | if cfg.local_rank > 0: 89 | return 90 | 91 | pattern = prefix + '/{}' 92 | step = step if step >= 0 else self.step 93 | loss_stats = loss_stats if loss_stats else self.loss_stats 94 | 95 | for k, v in loss_stats.items(): 96 | if isinstance(v, SmoothedValue): 97 | self.writer.add_scalar(pattern.format(k), v.median, step) 98 | else: 99 | self.writer.add_scalar(pattern.format(k), v, step) 100 | 101 | if self.processor is None: 102 | return 103 | image_stats = self.processor(image_stats) if image_stats else self.image_stats 104 | for k, v in image_stats.items(): 105 | self.writer.add_image(pattern.format(k), v, step) 106 | 107 | def state_dict(self): 108 | if cfg.local_rank > 0: 109 | return 110 | scalar_dict = {} 111 | scalar_dict['step'] = self.step 112 | return scalar_dict 113 | 114 | def load_state_dict(self, scalar_dict): 115 | if cfg.local_rank > 0: 116 | return 117 | self.step = scalar_dict['step'] 118 | 119 | def __str__(self): 120 | if cfg.local_rank > 0: 121 | return 122 | loss_state = [] 123 | for k, v in self.loss_stats.items(): 124 | loss_state.append('{}: {:.4f}'.format(k, v.avg)) 125 | loss_state = ' '.join(loss_state) 126 | 127 | recording_state = ' '.join(['epoch: {}', 'step: {}', '{}', 'data: {:.4f}', 'batch: {:.4f}']) 128 | return recording_state.format(self.epoch, self.step, loss_state, self.data_time.avg, self.batch_time.avg) 129 | 130 | 131 | def make_recorder(cfg): 132 | return Recorder(cfg) 133 | -------------------------------------------------------------------------------- /lib/evaluators/enerf_composite.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.config import cfg 3 | import os 4 | import imageio 5 | from lib.utils import img_utils 6 | from skimage.metrics import structural_similarity as ssim 7 | from skimage.metrics import peak_signal_noise_ratio as psnr 8 | import torch.nn.functional as F 9 | import torch 10 | import lpips 11 | import imageio 12 | from lib.utils import img_utils 13 | import cv2 14 | 15 | 16 | class Evaluator: 17 | 18 | def __init__(self,): 19 | self.psnrs = [] 20 | self.ssims = [] 21 | self.lpips = [] 22 | self.scene_psnrs = {} 23 | self.scene_ssims = {} 24 | self.scene_lpips = {} 25 | 
self.loss_fn_vgg = lpips.LPIPS(net='vgg') 26 | self.loss_fn_vgg.cuda() 27 | os.system('mkdir -p ' + cfg.result_dir) 28 | 29 | def evaluate(self, output, batch): 30 | B, S, _, H, W = batch['src_inps'].shape 31 | for i in range(cfg.enerf.cas_config.num): 32 | if not cfg.enerf.cas_config.render_if[i]: 33 | continue 34 | render_scale = cfg.enerf.cas_config.render_scale[i] 35 | h, w = int(H*render_scale), int(W*render_scale) 36 | pred_rgb = output[f'rgb_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy() 37 | gt_rgb = batch[f'rgb_{i}'].reshape(B, h, w, 3).detach().cpu().numpy() 38 | pred_depth = output[f'depth_level{i}'].reshape(B, h, w).detach().cpu().numpy()[..., None].repeat(3, -1) 39 | pred_depth -= pred_depth.min() 40 | pred_depth /= pred_depth.max() 41 | 42 | for b in range(B): 43 | if not batch['meta']['scene'][b]+f'_level{i}' in self.scene_psnrs: 44 | self.scene_psnrs[batch['meta']['scene'][b]+f'_level{i}'] = [] 45 | self.scene_ssims[batch['meta']['scene'][b]+f'_level{i}'] = [] 46 | self.scene_lpips[batch['meta']['scene'][b]+f'_level{i}'] = [] 47 | # self.scene_lpips[batch['meta']['scene'][b]] = [] 48 | if cfg.save_result and i == 1: 49 | img = img_utils.horizon_concate(gt_rgb[b], pred_rgb[b]) 50 | img = img_utils.horizon_concate(img, pred_depth[b]) 51 | img_path = os.path.join(cfg.result_dir, '{}_{}_{}.png'.format(batch['meta']['scene'][b], batch['meta']['tar_view'][b].item(), batch['meta']['frame_id'][b].item())) 52 | imageio.imwrite(img_path, (img*255.).astype(np.uint8)) 53 | 54 | psnr_item = psnr(gt_rgb[b], pred_rgb[b], data_range=1.) 55 | if i == cfg.enerf.cas_config.num-1: 56 | self.psnrs.append(psnr_item) 57 | self.scene_psnrs[batch['meta']['scene'][b]+f'_level{i}'].append(psnr_item) 58 | ssim_item = ssim(gt_rgb[b], pred_rgb[b], multichannel=True) 59 | if i == cfg.enerf.cas_config.num-1: 60 | self.ssims.append(ssim_item) 61 | self.scene_ssims[batch['meta']['scene'][b]+f'_level{i}'].append(ssim_item) 62 | if cfg.eval_lpips: 63 | gt, pred = torch.Tensor(gt_rgb[b])[None].permute(0, 3, 1, 2), torch.Tensor(pred_rgb[b])[None].permute(0, 3, 1, 2) 64 | gt, pred = (gt-0.5)*2., (pred-0.5)*2. 
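# LPIPS expects inputs in [-1, 1], so the line above maps the [0, 1] RGB tensors accordingly before scoring.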
65 | lpips_item = self.loss_fn_vgg(gt.cuda(), pred.cuda()).item() 66 | if i == cfg.enerf.cas_config.num-1: 67 | self.lpips.append(lpips_item) 68 | self.scene_lpips[batch['meta']['scene'][b]+f'_level{i}'].append(lpips_item) 69 | 70 | def summarize(self): 71 | ret = {} 72 | ret.update({'psnr': np.mean(self.psnrs)}) 73 | ret.update({'ssim': np.mean(self.ssims)}) 74 | if cfg.eval_lpips: 75 | ret.update({'lpips': np.mean(self.lpips)}) 76 | print('='*30) 77 | for scene in self.scene_psnrs: 78 | if cfg.eval_lpips: 79 | print(scene.ljust(16), 'psnr: {:.2f} ssim: {:.3f} lpips:{:.3f}'.format(np.mean(self.scene_psnrs[scene]), np.mean(self.scene_ssims[scene]), np.mean(self.scene_lpips[scene]))) 80 | else: 81 | print(scene.ljust(16), 'psnr: {:.2f} ssim: {:.3f} '.format(np.mean(self.scene_psnrs[scene]), np.mean(self.scene_ssims[scene]))) 82 | print('='*30) 83 | print(ret) 84 | self.psnrs = [] 85 | self.ssims = [] 86 | self.lpips = [] 87 | self.scene_psnrs = {} 88 | self.scene_ssims = {} 89 | self.scene_lpips = {} 90 | if cfg.save_result: 91 | print('Save visualization results to: {}'.format(cfg.result_dir)) 92 | return ret 93 | -------------------------------------------------------------------------------- /lib/evaluators/enerf_human.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.config import cfg 3 | import os 4 | import imageio 5 | from lib.utils import img_utils 6 | from skimage.metrics import structural_similarity as ssim 7 | from skimage.metrics import peak_signal_noise_ratio as psnr 8 | import torch.nn.functional as F 9 | import torch 10 | import lpips 11 | import imageio 12 | from lib.utils import img_utils 13 | import cv2 14 | 15 | 16 | class Evaluator: 17 | 18 | def __init__(self,): 19 | self.psnrs = [] 20 | self.ssims = [] 21 | self.lpips = [] 22 | self.scene_psnrs = {} 23 | self.scene_ssims = {} 24 | self.scene_lpips = {} 25 | self.loss_fn_vgg = lpips.LPIPS(net='vgg') 26 | self.loss_fn_vgg.cuda() 27 | os.system('mkdir -p ' + cfg.result_dir) 28 | 29 | def evaluate(self, output, batch): 30 | B, S, _, H, W = batch['src_inps'].shape 31 | for i in range(cfg.enerf.cas_config.num): 32 | if not cfg.enerf.cas_config.render_if[i]: 33 | continue 34 | render_scale = cfg.enerf.cas_config.render_scale[i] 35 | h, w = int(H*render_scale), int(W*render_scale) 36 | pred_rgb = output[f'rgb_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy() 37 | gt_rgb = batch[f'rgb_{i}'].reshape(B, h, w, 3).detach().cpu().numpy() 38 | 39 | if i == cfg.enerf.cas_config.num-1: 40 | masks = batch['mask_at_box'].detach().cpu().numpy() 41 | else: 42 | masks = np.ones_like(pred_rgb[..., 0]) 43 | 44 | for b in range(B): 45 | if not batch['meta']['scene'][b]+f'_level{i}' in self.scene_psnrs: 46 | self.scene_psnrs[batch['meta']['scene'][b]+f'_level{i}'] = [] 47 | self.scene_ssims[batch['meta']['scene'][b]+f'_level{i}'] = [] 48 | self.scene_lpips[batch['meta']['scene'][b]+f'_level{i}'] = [] 49 | if cfg.save_result: 50 | img = img_utils.horizon_concate(gt_rgb[b], pred_rgb[b]) 51 | img_path = os.path.join(cfg.result_dir, '{}_{}_{}.png'.format(batch['meta']['scene'][b], batch['meta']['tar_view'][b].item(), batch['meta']['frame_id'][b].item())) 52 | imageio.imwrite(img_path, (img*255.).astype(np.uint8)) 53 | 54 | mask = masks[b] == 1 55 | gt_rgb[b][mask==False] = 0. 56 | pred_rgb[b][mask==False] = 0. 57 | 58 | psnr_item = psnr(gt_rgb[b][mask], pred_rgb[b][mask], data_range=1.) 
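# PSNR is computed only over pixels inside the body mask (mask_at_box), so the zeroed background cannot inflate the score.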
59 | if i == cfg.enerf.cas_config.num-1: 60 | self.psnrs.append(psnr_item) 61 | self.scene_psnrs[batch['meta']['scene'][b]+f'_level{i}'].append(psnr_item) 62 | 63 | 64 | x, y, w, h = cv2.boundingRect(mask.astype(np.uint8)) 65 | 66 | ssim_item = ssim(gt_rgb[b][y:y+h, x:x+w], pred_rgb[b][y:y+h, x:x+w], multichannel=True) 67 | if i == cfg.enerf.cas_config.num-1: 68 | self.ssims.append(ssim_item) 69 | self.scene_ssims[batch['meta']['scene'][b]+f'_level{i}'].append(ssim_item) 70 | 71 | if cfg.eval_lpips: 72 | gt, pred = torch.Tensor(gt_rgb[b][y:y+h, x:x+w])[None].permute(0, 3, 1, 2), torch.Tensor(pred_rgb[b][y:y+h, x:x+w])[None].permute(0, 3, 1, 2) 73 | gt, pred = (gt-0.5)*2., (pred-0.5)*2. 74 | lpips_item = self.loss_fn_vgg(gt.cuda(), pred.cuda()).item() 75 | if i == cfg.enerf.cas_config.num-1: 76 | self.lpips.append(lpips_item) 77 | self.scene_lpips[batch['meta']['scene'][b]+f'_level{i}'].append(lpips_item) 78 | 79 | def summarize(self): 80 | ret = {} 81 | ret.update({'psnr': np.mean(self.psnrs)}) 82 | ret.update({'ssim': np.mean(self.ssims)}) 83 | if cfg.eval_lpips: 84 | ret.update({'lpips': np.mean(self.lpips)}) 85 | 86 | 87 | print('='*30) 88 | for scene in self.scene_psnrs: 89 | # print(scene.ljust(8), 'psnr: {:.2f} ssim: {:.2f} lpips: {:.3f}'.format(np.mean(self.scene_psnrs[scene]), np.mean(self.scene_ssims[scene]), np.mean(self.scene_lpips[scene]))) 90 | if cfg.eval_lpips: 91 | print(scene.ljust(16), 'psnr: {:.2f} ssim: {:.3f} lpips:{:.3f}'.format(np.mean(self.scene_psnrs[scene]), np.mean(self.scene_ssims[scene]), np.mean(self.scene_lpips[scene]))) 92 | else: 93 | print(scene.ljust(16), 'psnr: {:.2f} ssim: {:.3f} '.format(np.mean(self.scene_psnrs[scene]), np.mean(self.scene_ssims[scene]))) 94 | print(ret) 95 | print('='*30) 96 | self.psnrs = [] 97 | self.ssims = [] 98 | self.lpips = [] 99 | self.scene_psnrs = {} 100 | self.scene_ssims = {} 101 | self.scene_lpips = {} 102 | if cfg.save_result: 103 | print('Save visualization results to: {}'.format(cfg.result_dir)) 104 | return ret 105 | -------------------------------------------------------------------------------- /lib/utils/colmap/test_read_write_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | import numpy as np 33 | from read_write_model import read_model, write_model 34 | from tempfile import mkdtemp 35 | 36 | 37 | def compare_cameras(cameras1, cameras2): 38 | assert len(cameras1) == len(cameras2) 39 | for camera_id1 in cameras1: 40 | camera1 = cameras1[camera_id1] 41 | camera2 = cameras2[camera_id1] 42 | assert camera1.id == camera2.id 43 | assert camera1.width == camera2.width 44 | assert camera1.height == camera2.height 45 | assert np.allclose(camera1.params, camera2.params) 46 | 47 | 48 | def compare_images(images1, images2): 49 | assert len(images1) == len(images2) 50 | for image_id1 in images1: 51 | image1 = images1[image_id1] 52 | image2 = images2[image_id1] 53 | assert image1.id == image2.id 54 | assert np.allclose(image1.qvec, image2.qvec) 55 | assert np.allclose(image1.tvec, image2.tvec) 56 | assert image1.camera_id == image2.camera_id 57 | assert image1.name == image2.name 58 | assert np.allclose(image1.xys, image2.xys) 59 | assert np.array_equal(image1.point3D_ids, image2.point3D_ids) 60 | 61 | 62 | def compare_points(points3D1, points3D2): 63 | for point3D_id1 in points3D1: 64 | point3D1 = points3D1[point3D_id1] 65 | point3D2 = points3D2[point3D_id1] 66 | assert point3D1.id == point3D2.id 67 | assert np.allclose(point3D1.xyz, point3D2.xyz) 68 | assert np.array_equal(point3D1.rgb, point3D2.rgb) 69 | assert np.allclose(point3D1.error, point3D2.error) 70 | assert np.array_equal(point3D1.image_ids, point3D2.image_ids) 71 | assert np.array_equal(point3D1.point2D_idxs, point3D2.point2D_idxs) 72 | 73 | 74 | def main(): 75 | import sys 76 | if len(sys.argv) != 3: 77 | print("Usage: python read_model.py " 78 | "path/to/model/folder/txt path/to/model/folder/bin") 79 | return 80 | 81 | print("Comparing text and binary models ...") 82 | 83 | path_to_model_txt_folder = sys.argv[1] 84 | path_to_model_bin_folder = sys.argv[2] 85 | cameras_txt, images_txt, points3D_txt = \ 86 | read_model(path_to_model_txt_folder, ext=".txt") 87 | cameras_bin, images_bin, points3D_bin = \ 88 | read_model(path_to_model_bin_folder, ext=".bin") 89 | compare_cameras(cameras_txt, cameras_bin) 90 | compare_images(images_txt, images_bin) 91 | compare_points(points3D_txt, points3D_bin) 92 | 93 | print("... text and binary models are equal.") 94 | print("Saving text model and reloading it ...") 95 | 96 | tmpdir = mkdtemp() 97 | write_model(cameras_bin, images_bin, points3D_bin, tmpdir, ext='.txt') 98 | cameras_txt, images_txt, points3D_txt = \ 99 | read_model(tmpdir, ext=".txt") 100 | compare_cameras(cameras_txt, cameras_bin) 101 | compare_images(images_txt, images_bin) 102 | compare_points(points3D_txt, points3D_bin) 103 | 104 | print("... 
saved text and loaded models are equal.") 105 | print("Saving binary model and reloading it ...") 106 | 107 | write_model(cameras_bin, images_bin, points3D_bin, tmpdir, ext='.bin') 108 | cameras_bin, images_bin, points3D_bin = \ 109 | read_model(tmpdir, ext=".bin") 110 | compare_cameras(cameras_txt, cameras_bin) 111 | compare_images(images_txt, images_bin) 112 | compare_points(points3D_txt, points3D_bin) 113 | 114 | print("... saved binary and loaded models are equal.") 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /lib/utils/colmap/bundler_to_ply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | # This script converts a Bundler reconstruction file to a PLY point cloud. 
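# Layout assumed by main() below: a Bundler .out file typically begins with a "# Bundle file v0.3" comment and a "<num_cameras> <num_points>" line; each camera then takes five lines (focal/distortion, three rotation rows, translation) and each point three lines (XYZ, RGB, view list).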
33 | 34 | import argparse 35 | import numpy as np 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--bundler_path", required=True) 41 | parser.add_argument("--ply_path", required=True) 42 | parser.add_argument("--normalize", type=bool, default=True)  # NOTE: argparse's type=bool is truthy for any non-empty string, including "False" 43 | parser.add_argument("--normalize_p0", type=float, default=0.2) 44 | parser.add_argument("--normalize_p1", type=float, default=0.8) 45 | parser.add_argument("--min_track_length", type=int, default=3) 46 | args = parser.parse_args() 47 | return args 48 | 49 | 50 | def main(): 51 | args = parse_args() 52 | 53 | with open(args.bundler_path, "r") as fid: 54 | line = fid.readline() 55 | line = fid.readline() 56 | num_images, num_points = map(int, line.split()) 57 | 58 | for i in range(5 * num_images): 59 | fid.readline() 60 | 61 | xyz = np.zeros((num_points, 3), dtype=np.float64) 62 | rgb = np.zeros((num_points, 3), dtype=np.uint16) 63 | track_lengths = np.zeros((num_points,), dtype=np.uint32) 64 | 65 | for i in range(num_points): 66 | if i % 1000 == 0: 67 | print("Reading point", i, "/", num_points) 68 | xyz[i] = list(map(float, fid.readline().split()))  # list(...) is required on Python 3, where map() returns a lazy iterator 69 | rgb[i] = list(map(int, fid.readline().split())) 70 | track_lengths[i] = int(fid.readline().split()[0]) 71 | 72 | mask = track_lengths >= args.min_track_length 73 | xyz = xyz[mask] 74 | rgb = rgb[mask] 75 | 76 | if args.normalize: 77 | sorted_x = np.sort(xyz[:, 0]) 78 | sorted_y = np.sort(xyz[:, 1]) 79 | sorted_z = np.sort(xyz[:, 2]) 80 | 81 | num_coords = sorted_x.size 82 | min_coord = int(args.normalize_p0 * num_coords) 83 | max_coord = int(args.normalize_p1 * num_coords) 84 | mean_coords = xyz.mean(0) 85 | 86 | bbox_min = np.array([sorted_x[min_coord], sorted_y[min_coord], 87 | sorted_z[min_coord]]) 88 | bbox_max = np.array([sorted_x[max_coord], sorted_y[max_coord], 89 | sorted_z[max_coord]]) 90 | 91 | extent = np.linalg.norm(bbox_max - bbox_min) 92 | scale = 10.0 / extent 93 | 94 | xyz -= mean_coords 95 | xyz *= scale 96 | 97 | xyz[:, 2] *= -1 98 | 99 | with open(args.ply_path, "w") as fid: 100 | fid.write("ply\n") 101 | fid.write("format ascii 1.0\n") 102 | fid.write("element vertex %d\n" % xyz.shape[0]) 103 | fid.write("property float x\n") 104 | fid.write("property float y\n") 105 | fid.write("property float z\n") 106 | fid.write("property float nx\n") 107 | fid.write("property float ny\n") 108 | fid.write("property float nz\n") 109 | fid.write("property uchar diffuse_red\n") 110 | fid.write("property uchar diffuse_green\n") 111 | fid.write("property uchar diffuse_blue\n") 112 | fid.write("end_header\n") 113 | for i in range(xyz.shape[0]): 114 | if i % 1000 == 0: 115 | print("Writing point", i, "/", xyz.shape[0]) 116 | fid.write("%f %f %f 0 0 0 %d %d %d\n" % (xyz[i, 0], xyz[i, 1], 117 | xyz[i, 2], rgb[i, 0], 118 | rgb[i, 1], rgb[i, 2])) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /lib/utils/colmap/nvm_to_ply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer.
9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | # This script converts a VisualSfM reconstruction file to a PLY point cloud. 33 | 34 | import os 35 | import argparse 36 | import numpy as np 37 | 38 | 39 | def parse_args(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--nvm_path", required=True) 42 | parser.add_argument("--ply_path", required=True) 43 | parser.add_argument("--normalize", type=bool, default=True)  # NOTE: argparse's type=bool is truthy for any non-empty string, including "False" 44 | parser.add_argument("--normalize_p0", type=float, default=0.2) 45 | parser.add_argument("--normalize_p1", type=float, default=0.8) 46 | parser.add_argument("--min_track_length", type=int, default=3) 47 | args = parser.parse_args() 48 | return args 49 | 50 | 51 | def main(): 52 | args = parse_args() 53 | 54 | with open(args.nvm_path, "r") as fid: 55 | line = fid.readline() 56 | line = fid.readline() 57 | num_images = int(fid.readline()) 58 | 59 | for i in range(num_images + 1): 60 | fid.readline() 61 | 62 | num_points = int(fid.readline()) 63 | 64 | xyz = np.zeros((num_points, 3), dtype=np.float64) 65 | rgb = np.zeros((num_points, 3), dtype=np.uint16) 66 | track_lengths = np.zeros((num_points,), dtype=np.uint32) 67 | 68 | for i in range(num_points): 69 | if i % 1000 == 0: 70 | print("Reading point", i, "/", num_points) 71 | elems = fid.readline().split() 72 | xyz[i] = list(map(float, elems[0:3]))  # list(...) is required on Python 3, where map() returns a lazy iterator 73 | rgb[i] = list(map(int, elems[3:6])) 74 | track_lengths[i] = int(elems[6]) 75 | 76 | mask = track_lengths >= args.min_track_length 77 | xyz = xyz[mask] 78 | rgb = rgb[mask] 79 | 80 | if args.normalize: 81 | sorted_x = np.sort(xyz[:, 0]) 82 | sorted_y = np.sort(xyz[:, 1]) 83 | sorted_z = np.sort(xyz[:, 2]) 84 | 85 | num_coords = sorted_x.size 86 | min_coord = int(args.normalize_p0 * num_coords) 87 | max_coord = int(args.normalize_p1 * num_coords) 88 | mean_coords = xyz.mean(0) 89 | 90 | bbox_min = np.array([sorted_x[min_coord], sorted_y[min_coord], 91 | sorted_z[min_coord]]) 92 | bbox_max = np.array([sorted_x[max_coord], sorted_y[max_coord], 93 | sorted_z[max_coord]]) 94 | 95 | extent = np.linalg.norm(bbox_max - bbox_min) 96 | scale = 10.0 / extent 97 | 98 | xyz -= mean_coords 99 | xyz *= scale 100 | 101 | with open(args.ply_path, "w") as fid: 102 | fid.write("ply\n")
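# Note that the header written below declares nx/ny/nz normal properties even though NVM stores no normals; every vertex is therefore emitted with placeholder "0 0 0" normals.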
103 | fid.write("format ascii 1.0\n") 104 | fid.write("element vertex %d\n" % xyz.shape[0]) 105 | fid.write("property float x\n") 106 | fid.write("property float y\n") 107 | fid.write("property float z\n") 108 | fid.write("property float nx\n") 109 | fid.write("property float ny\n") 110 | fid.write("property float nz\n") 111 | fid.write("property uchar diffuse_red\n") 112 | fid.write("property uchar diffuse_green\n") 113 | fid.write("property uchar diffuse_blue\n") 114 | fid.write("end_header\n") 115 | for i in range(xyz.shape[0]): 116 | if i % 1000 == 0: 117 | print("Writing point", i, "/", xyz.shape[0]) 118 | fid.write("%f %f %f 0 0 0 %d %d %d\n" % (xyz[i, 0], xyz[i, 1], 119 | xyz[i, 2], rgb[i, 0], 120 | rgb[i, 1], rgb[i, 2])) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /lib/train/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import torch 4 | import tqdm 5 | from torch.nn import DataParallel 6 | from torch.nn.parallel import DistributedDataParallel 7 | from lib.config import cfg 8 | from lib.utils.data_utils import to_cuda 9 | 10 | 11 | class Trainer(object): 12 | def __init__(self, network): 13 | device = torch.device('cuda:{}'.format(cfg.local_rank)) 14 | network = network.to(device) 15 | if cfg.distributed: 16 | network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(network) 17 | network = DistributedDataParallel( 18 | network, 19 | device_ids=[cfg.local_rank], 20 | output_device=cfg.local_rank, 21 | find_unused_parameters=True 22 | ) 23 | self.network = network 24 | self.local_rank = cfg.local_rank 25 | self.device = device 26 | self.global_step = 0 27 | 28 | def reduce_loss_stats(self, loss_stats): 29 | reduced_losses = {k: torch.mean(v) for k, v in loss_stats.items()} 30 | return reduced_losses 31 | 32 | def to_cuda(self, batch): 33 | for k in batch: 34 | if isinstance(batch[k], tuple) or isinstance(batch[k], list): 35 | #batch[k] = [b.cuda() for b in batch[k]] 36 | batch[k] = [b.to(self.device) for b in batch[k]] 37 | elif isinstance(batch[k], dict): 38 | batch[k] = {key: self.to_cuda(batch[k][key]) for key in batch[k]} 39 | else: 40 | # batch[k] = batch[k].cuda() 41 | batch[k] = batch[k].to(self.device) 42 | return batch 43 | 44 | def train(self, epoch, data_loader, optimizer, recorder): 45 | max_iter = len(data_loader) 46 | self.network.train() 47 | end = time.time() 48 | if self.global_step == 0: 49 | self.global_step = cfg.ep_iter * epoch 50 | for iteration, batch in enumerate(data_loader): 51 | data_time = time.time() - end 52 | iteration = iteration + 1 53 | 54 | batch = to_cuda(batch, self.device) 55 | batch['step'] = 0 56 | output, loss, loss_stats, image_stats = self.network(batch) 57 | 58 | # training stage: loss; optimizer; scheduler 59 | loss = loss.mean() 60 | optimizer.zero_grad() 61 | loss.backward() 62 | torch.nn.utils.clip_grad_value_(self.network.parameters(), 40) 63 | optimizer.step() 64 | 65 | if cfg.local_rank > 0: 66 | continue 67 | 68 | # data recording stage: loss_stats, time, image_stats 69 | recorder.step += 1 70 | 71 | loss_stats = self.reduce_loss_stats(loss_stats) 72 | recorder.update_loss_stats(loss_stats) 73 | 74 | batch_time = time.time() - end 75 | end = time.time() 76 | recorder.batch_time.update(batch_time) 77 | recorder.data_time.update(data_time) 78 | 79 | self.global_step += 1 80 | if iteration % cfg.log_interval == 0 or iteration == (max_iter - 1): 
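# ETA below is the globally averaged batch time multiplied by the remaining iterations, formatted as H:MM:SS by datetime.timedelta.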
81 | # print training state 82 | eta_seconds = recorder.batch_time.global_avg * (max_iter - iteration) 83 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 84 | lr = optimizer.param_groups[0]['lr'] 85 | memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 86 | 87 | training_state = ' '.join(['eta: {}', '{}', 'lr: {:.6f}', 'max_mem: {:.0f}']) 88 | training_state = training_state.format(eta_string, str(recorder), lr, memory) 89 | print(training_state) 90 | 91 | # record loss_stats and image_dict 92 | recorder.update_image_stats(image_stats) 93 | recorder.record('train') 94 | 95 | def val(self, epoch, data_loader, evaluator=None, recorder=None): 96 | self.network.eval() 97 | torch.cuda.empty_cache() 98 | val_loss_stats = {} 99 | image_stats = {} 100 | data_size = len(data_loader) 101 | for batch in tqdm.tqdm(data_loader): 102 | batch = to_cuda(batch, self.device) 103 | batch['step'] = recorder.step 104 | with torch.no_grad(): 105 | output, loss, loss_stats, _ = self.network(batch) 106 | if evaluator is not None: 107 | image_stats_ = evaluator.evaluate(output, batch) 108 | if image_stats_ is not None: 109 | image_stats.update(image_stats_) 110 | 111 | loss_stats = self.reduce_loss_stats(loss_stats) 112 | for k, v in loss_stats.items(): 113 | val_loss_stats.setdefault(k, 0) 114 | val_loss_stats[k] += v 115 | 116 | loss_state = [] 117 | for k in val_loss_stats.keys(): 118 | val_loss_stats[k] /= data_size 119 | loss_state.append('{}: {:.4f}'.format(k, val_loss_stats[k])) 120 | print(loss_state) 121 | 122 | if evaluator is not None: 123 | result = evaluator.summarize() 124 | val_loss_stats.update(result) 125 | 126 | if recorder: 127 | recorder.record('val', epoch, val_loss_stats, image_stats) 128 | -------------------------------------------------------------------------------- /lib/utils/img_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from matplotlib import cm 3 | import matplotlib.pyplot as plt 4 | import matplotlib.patches as patches 5 | import numpy as np 6 | import cv2 7 | from lib.utils import data_config 8 | 9 | 10 | def unnormalize_img(img, mean, std): 11 | """ 12 | img: [3, h, w] 13 | """ 14 | img = img.detach().cpu().clone() 15 | # img = img / 255. 
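# Reverses torchvision-style normalization (x * std + mean), then min-max rescales the result to [0, 1] purely for visualization.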
16 | img *= torch.tensor(std).view(3, 1, 1) 17 | img += torch.tensor(mean).view(3, 1, 1) 18 | min_v = torch.min(img) 19 | img = (img - min_v) / (torch.max(img) - min_v) 20 | return img 21 | 22 | 23 | def bgr_to_rgb(img): 24 | return img[:, :, [2, 1, 0]] 25 | 26 | 27 | def horizon_concate(inp0, inp1): 28 | h0, w0 = inp0.shape[:2] 29 | h1, w1 = inp1.shape[:2] 30 | if inp0.ndim == 3: 31 | inp = np.zeros((max(h0, h1), w0 + w1, 3), dtype=inp0.dtype) 32 | inp[:h0, :w0, :] = inp0 33 | inp[:h1, w0:(w0 + w1), :] = inp1 34 | else: 35 | inp = np.zeros((max(h0, h1), w0 + w1), dtype=inp0.dtype) 36 | inp[:h0, :w0] = inp0 37 | inp[:h1, w0:(w0 + w1)] = inp1 38 | return inp 39 | 40 | 41 | def vertical_concate(inp0, inp1): 42 | h0, w0 = inp0.shape[:2] 43 | h1, w1 = inp1.shape[:2] 44 | if inp0.ndim == 3: 45 | inp = np.zeros((h0 + h1, max(w0, w1), 3), dtype=inp0.dtype) 46 | inp[:h0, :w0, :] = inp0 47 | inp[h0:(h0 + h1), :w1, :] = inp1 48 | else: 49 | inp = np.zeros((h0 + h1, max(w0, w1)), dtype=inp0.dtype) 50 | inp[:h0, :w0] = inp0 51 | inp[h0:(h0 + h1), :w1] = inp1 52 | return inp 53 | 54 | 55 | def transparent_cmap(cmap): 56 | """Copy colormap and set alpha values""" 57 | mycmap = cmap 58 | mycmap._init() 59 | mycmap._lut[:,-1] = 0.3 60 | return mycmap 61 | 62 | cmap = transparent_cmap(plt.get_cmap('jet')) 63 | 64 | 65 | def set_grid(ax, h, w, interval=8): 66 | ax.set_xticks(np.arange(0, w, interval)) 67 | ax.set_yticks(np.arange(0, h, interval)) 68 | ax.grid() 69 | ax.set_yticklabels([]) 70 | ax.set_xticklabels([]) 71 | 72 | 73 | color_list = np.array( 74 | [ 75 | 0.000, 0.447, 0.741, 76 | 0.850, 0.325, 0.098, 77 | 0.929, 0.694, 0.125, 78 | 0.494, 0.184, 0.556, 79 | 0.466, 0.674, 0.188, 80 | 0.301, 0.745, 0.933, 81 | 0.635, 0.078, 0.184, 82 | 0.300, 0.300, 0.300, 83 | 0.600, 0.600, 0.600, 84 | 1.000, 0.000, 0.000, 85 | 1.000, 0.500, 0.000, 86 | 0.749, 0.749, 0.000, 87 | 0.000, 1.000, 0.000, 88 | 0.000, 0.000, 1.000, 89 | 0.667, 0.000, 1.000, 90 | 0.333, 0.333, 0.000, 91 | 0.333, 0.667, 0.000, 92 | 0.333, 1.000, 0.000, 93 | 0.667, 0.333, 0.000, 94 | 0.667, 0.667, 0.000, 95 | 0.667, 1.000, 0.000, 96 | 1.000, 0.333, 0.000, 97 | 1.000, 0.667, 0.000, 98 | 1.000, 1.000, 0.000, 99 | 0.000, 0.333, 0.500, 100 | 0.000, 0.667, 0.500, 101 | 0.000, 1.000, 0.500, 102 | 0.333, 0.000, 0.500, 103 | 0.333, 0.333, 0.500, 104 | 0.333, 0.667, 0.500, 105 | 0.333, 1.000, 0.500, 106 | 0.667, 0.000, 0.500, 107 | 0.667, 0.333, 0.500, 108 | 0.667, 0.667, 0.500, 109 | 0.667, 1.000, 0.500, 110 | 1.000, 0.000, 0.500, 111 | 1.000, 0.333, 0.500, 112 | 1.000, 0.667, 0.500, 113 | 1.000, 1.000, 0.500, 114 | 0.000, 0.333, 1.000, 115 | 0.000, 0.667, 1.000, 116 | 0.000, 1.000, 1.000, 117 | 0.333, 0.000, 1.000, 118 | 0.333, 0.333, 1.000, 119 | 0.333, 0.667, 1.000, 120 | 0.333, 1.000, 1.000, 121 | 0.667, 0.000, 1.000, 122 | 0.667, 0.333, 1.000, 123 | 0.667, 0.667, 1.000, 124 | 0.667, 1.000, 1.000, 125 | 1.000, 0.000, 1.000, 126 | 1.000, 0.333, 1.000, 127 | 1.000, 0.667, 1.000, 128 | 0.167, 0.000, 0.000, 129 | 0.333, 0.000, 0.000, 130 | 0.500, 0.000, 0.000, 131 | 0.667, 0.000, 0.000, 132 | 0.833, 0.000, 0.000, 133 | 1.000, 0.000, 0.000, 134 | 0.000, 0.167, 0.000, 135 | 0.000, 0.333, 0.000, 136 | 0.000, 0.500, 0.000, 137 | 0.000, 0.667, 0.000, 138 | 0.000, 0.833, 0.000, 139 | 0.000, 1.000, 0.000, 140 | 0.000, 0.000, 0.167, 141 | 0.000, 0.000, 0.333, 142 | 0.000, 0.000, 0.500, 143 | 0.000, 0.000, 0.667, 144 | 0.000, 0.000, 0.833, 145 | 0.000, 0.000, 1.000, 146 | 0.000, 0.000, 0.000, 147 | 0.143, 0.143, 0.143, 148 | 0.286, 0.286, 0.286, 149 | 0.429, 
0.429, 0.429, 150 | 0.571, 0.571, 0.571, 151 | 0.714, 0.714, 0.714, 152 | 0.857, 0.857, 0.857, 153 | 1.000, 1.000, 1.000, 154 | 0.50, 0.5, 0 155 | ] 156 | ).astype(np.float32) 157 | colors = color_list.reshape((-1, 3)) * 255 158 | colors = np.array(colors, dtype=np.uint8).reshape(len(colors), 1, 1, 3) 159 | 160 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): 161 | """ 162 | depth: (H, W) 163 | """ 164 | x = np.nan_to_num(depth) # change nan to 0 165 | if minmax is None: 166 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 167 | ma = np.max(x) 168 | else: 169 | mi,ma = minmax 170 | 171 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 172 | x = (255*x).astype(np.uint8) 173 | x_ = cv2.applyColorMap(x, cmap) 174 | return x_, [mi,ma] 175 | -------------------------------------------------------------------------------- /lib/utils/colmap/read_write_fused_vis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # 12 | # * Redistributions in binary form must reproduce the above copyright 13 | # notice, this list of conditions and the following disclaimer in the 14 | # documentation and/or other materials provided with the distribution. 15 | # 16 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 17 | # its contributors may be used to endorse or promote products derived 18 | # from this software without specific prior written permission. 19 | # 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 23 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 24 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 25 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 26 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 27 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 28 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 29 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 30 | # POSSIBILITY OF SUCH DAMAGE. 31 | # 32 | # Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de) 33 | 34 | import os 35 | import collections 36 | import numpy as np 37 | import pandas as pd 38 | from pyntcloud import PyntCloud 39 | 40 | from read_write_model import read_next_bytes, write_next_bytes 41 | 42 | 43 | MeshPoint = collections.namedtuple( 44 | "MeshingPoint", ["position", "color", "normal", "num_visible_images", "visible_image_idxs"]) 45 | 46 | 47 | def read_fused(path_to_fused_ply, path_to_fused_ply_vis): 48 | """ 49 | see: src/mvs/meshing.cc 50 | void ReadDenseReconstruction(const std::string& path 51 | """ 52 | assert os.path.isfile(path_to_fused_ply) 53 | assert os.path.isfile(path_to_fused_ply_vis) 54 | 55 | point_cloud = PyntCloud.from_file(path_to_fused_ply) 56 | xyz_arr = point_cloud.points.loc[:, ["x", "y", "z"]].to_numpy() 57 | normal_arr = point_cloud.points.loc[:, ["nx", "ny", "nz"]].to_numpy() 58 | color_arr = point_cloud.points.loc[:, ["red", "green", "blue"]].to_numpy() 59 | 60 | with open(path_to_fused_ply_vis, "rb") as fid: 61 | num_points = read_next_bytes(fid, 8, "Q")[0] 62 | mesh_points = [0] * num_points 63 | for i in range(num_points): 64 | num_visible_images = read_next_bytes(fid, 4, "I")[0] 65 | visible_image_idxs = read_next_bytes( 66 | fid, num_bytes=4*num_visible_images, 67 | format_char_sequence="I"*num_visible_images) 68 | visible_image_idxs = np.array(tuple(map(int, visible_image_idxs))) 69 | mesh_point = MeshPoint( 70 | position=xyz_arr[i], 71 | color=color_arr[i], 72 | normal=normal_arr[i], 73 | num_visible_images=num_visible_images, 74 | visible_image_idxs=visible_image_idxs) 75 | mesh_points[i] = mesh_point 76 | return mesh_points 77 | 78 | 79 | def write_fused_ply(mesh_points, path_to_fused_ply): 80 | columns = ["x", "y", "z", "nx", "ny", "nz", "red", "green", "blue"] 81 | points_data_frame = pd.DataFrame( 82 | np.zeros((len(mesh_points), len(columns))), 83 | columns=columns) 84 | 85 | positions = np.asarray([point.position for point in mesh_points]) 86 | normals = np.asarray([point.normal for point in mesh_points]) 87 | colors = np.asarray([point.color for point in mesh_points]) 88 | 89 | points_data_frame.loc[:, ["x", "y", "z"]] = positions 90 | points_data_frame.loc[:, ["nx", "ny", "nz"]] = normals 91 | points_data_frame.loc[:, ["red", "green", "blue"]] = colors 92 | 93 | points_data_frame = points_data_frame.astype({ 94 | "x": positions.dtype, "y": positions.dtype, "z": positions.dtype, 95 | "red": colors.dtype, "green": colors.dtype, "blue": colors.dtype, 96 | "nx": normals.dtype, "ny": normals.dtype, "nz": normals.dtype}) 97 | 98 | point_cloud = PyntCloud(points_data_frame) 99 | point_cloud.to_file(path_to_fused_ply) 100 | 101 | 102 | def write_fused_ply_vis(mesh_points, path_to_fused_ply_vis): 103 | """ 104 | see: src/mvs/fusion.cc 105 | void WritePointsVisibility(const std::string& path, const std::vector<std::vector<int>>& points_visibility) 106 | """ 107 | with open(path_to_fused_ply_vis, "wb") as fid: 108 | write_next_bytes(fid, len(mesh_points), "Q") 109 | for point in mesh_points: 110 | write_next_bytes(fid, point.num_visible_images, "I") 111 | format_char_sequence = "I"*point.num_visible_images 112 | write_next_bytes(fid, [*point.visible_image_idxs], format_char_sequence) 113 | 114 | 115 | def write_fused(points, path_to_fused_ply, path_to_fused_ply_vis): 116 | write_fused_ply(points, path_to_fused_ply) 117 | write_fused_ply_vis(points, path_to_fused_ply_vis) 118 | -------------------------------------------------------------------------------- /lib/utils/base_utils.py:
-------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import numpy as np 4 | import cv2 5 | import time 6 | from termcolor import colored 7 | import importlib 8 | import torch.distributed as dist 9 | import math 10 | import torch  # required by perf_timer's torch.cuda.synchronize() calls below 11 | class perf_timer: 12 | def __init__(self, msg="Elapsed time: {}s", logf=lambda x: print(colored(x, 'yellow')), sync_cuda=True, use_ms=False, disabled=False): 13 | self.logf = logf 14 | self.msg = msg 15 | self.sync_cuda = sync_cuda 16 | self.use_ms = use_ms 17 | self.disabled = disabled 18 | 19 | self.loggedtime = None 20 | 21 | def __enter__(self,): 22 | if self.sync_cuda: 23 | torch.cuda.synchronize() 24 | self.start = time.perf_counter() 25 | 26 | def __exit__(self, exc_type, exc_value, traceback): 27 | if self.sync_cuda: 28 | torch.cuda.synchronize() 29 | self.logtime(self.msg) 30 | 31 | def logtime(self, msg=None, logf=None): 32 | if self.disabled: 33 | return 34 | # SAME CLASS, DIFFERENT FUNCTIONALITY, is this good? 35 | # call the logger for timing code sections 36 | if self.sync_cuda: 37 | torch.cuda.synchronize() 38 | 39 | # always remember current time 40 | prev = self.loggedtime 41 | self.loggedtime = time.perf_counter() 42 | 43 | # print it if we've remembered previous time 44 | if prev is not None and msg: 45 | logf = logf or self.logf 46 | diff = self.loggedtime-prev 47 | diff *= 1000 if self.use_ms else 1 48 | logf(msg.format(diff)) 49 | 50 | return self.loggedtime 51 | 52 | def read_pickle(pkl_path): 53 | with open(pkl_path, 'rb') as f: 54 | return pickle.load(f) 55 | 56 | 57 | def save_pickle(data, pkl_path): 58 | os.system('mkdir -p {}'.format(os.path.dirname(pkl_path))) 59 | with open(pkl_path, 'wb') as f: 60 | pickle.dump(data, f) 61 | 62 | 63 | def project(xyz, K, RT): 64 | """ 65 | xyz: [N, 3] 66 | K: [3, 3] 67 | RT: [3, 4] 68 | """ 69 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T 70 | xyz = np.dot(xyz, K.T) 71 | xy = xyz[:, :2] / xyz[:, 2:] 72 | return xy 73 | 74 | def get_bbox_2d(bbox, K, RT): 75 | pts = np.array([[bbox[0, 0], bbox[0, 1], bbox[0, 2]], 76 | [bbox[0, 0], bbox[0, 1], bbox[1, 2]], 77 | [bbox[0, 0], bbox[1, 1], bbox[0, 2]], 78 | [bbox[0, 0], bbox[1, 1], bbox[1, 2]], 79 | [bbox[1, 0], bbox[0, 1], bbox[0, 2]], 80 | [bbox[1, 0], bbox[0, 1], bbox[1, 2]], 81 | [bbox[1, 0], bbox[1, 1], bbox[0, 2]], 82 | [bbox[1, 0], bbox[1, 1], bbox[1, 2]], 83 | ]) 84 | pts_2d = project(pts, K, RT) 85 | return [pts_2d[:, 0].min(), pts_2d[:, 1].min(), pts_2d[:, 0].max(), pts_2d[:, 1].max()] 86 | 87 | 88 | def get_bound_corners(bounds): 89 | min_x, min_y, min_z = bounds[0] 90 | max_x, max_y, max_z = bounds[1] 91 | corners_3d = np.array([ 92 | [min_x, min_y, min_z], 93 | [min_x, min_y, max_z], 94 | [min_x, max_y, min_z], 95 | [min_x, max_y, max_z], 96 | [max_x, min_y, min_z], 97 | [max_x, min_y, max_z], 98 | [max_x, max_y, min_z], 99 | [max_x, max_y, max_z], 100 | ]) 101 | return corners_3d 102 | 103 | def get_bound_2d_mask(bounds, K, pose, H, W): 104 | corners_3d = get_bound_corners(bounds) 105 | corners_2d = project(corners_3d, K, pose) 106 | corners_2d = np.round(corners_2d).astype(int) 107 | mask = np.zeros((H, W), dtype=np.uint8) 108 | cv2.fillPoly(mask, [corners_2d[[0, 1, 3, 2, 0]]], 1) 109 | cv2.fillPoly(mask, [corners_2d[[4, 5, 7, 6, 5]]], 1) 110 | cv2.fillPoly(mask, [corners_2d[[0, 1, 5, 4, 0]]], 1) 111 | cv2.fillPoly(mask, [corners_2d[[2, 3, 7, 6, 2]]], 1) 112 | cv2.fillPoly(mask, [corners_2d[[0, 2, 6, 4, 0]]], 1) 113 | cv2.fillPoly(mask, [corners_2d[[1, 3, 7, 5, 1]]], 1) 114 | return mask
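# Illustrative usage (a sketch, not called anywhere in this file): given a 3x3 intrinsic matrix K and a 3x4 world-to-camera pose [R|t], rasterize the silhouette of a unit cube:
#   bounds = np.array([[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]])
#   silhouette = get_bound_2d_mask(bounds, K, pose, H=512, W=512)  # uint8 {0, 1} mask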
115 | 116 | def load_object(module_name, module_args, **extra_args): 117 | module_path = '.'.join(module_name.split('.')[:-1]) 118 | module = importlib.import_module(module_path) 119 | name = module_name.split('.')[-1] 120 | obj = getattr(module, name)(**extra_args, **module_args) 121 | return obj 122 | 123 | 124 | 125 | def get_indices(length): 126 | num_replicas = dist.get_world_size() 127 | rank = dist.get_rank() 128 | num_samples = int(math.ceil(length * 1.0 / num_replicas)) 129 | total_size = num_samples * num_replicas 130 | indices = np.arange(length).tolist() 131 | indices += indices[: (total_size - len(indices))] 132 | offset = num_samples * rank 133 | indices = indices[offset:offset+num_samples] 134 | return indices 135 | 136 | 137 | class DotDict(dict): 138 | """ 139 | a dictionary that supports dot notation 140 | as well as dictionary access notation 141 | usage: d = DotDict() or d = DotDict({'val1':'first'}) 142 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 143 | get attributes: d.val2 or d['val2'] 144 | """ 145 | __getattr__ = dict.__getitem__ 146 | __setattr__ = dict.__setitem__ 147 | __delattr__ = dict.__delitem__ 148 | 149 | def __init__(self, dct=None): 150 | if dct is not None: 151 | for key, value in dct.items(): 152 | if hasattr(value, 'keys'): 153 | value = DotDict(value) 154 | self[key] = value  # actually store the (possibly nested) value so the documented DotDict(dct) constructor works 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /lib/networks/enerf/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from lib.config import cfg 5 | 6 | class NeRF(nn.Module): 7 | def __init__(self, hid_n=64, feat_ch=16+3): 8 | """ 9 | """ 10 | super(NeRF, self).__init__() 11 | self.hid_n = hid_n 12 | self.agg = Agg(feat_ch) 13 | self.lr0 = nn.Sequential(nn.Linear(8+16, hid_n), 14 | nn.ReLU()) 15 | self.lrs = nn.ModuleList([ 16 | nn.Sequential(nn.Linear(hid_n, hid_n), nn.ReLU()) for i in range(0) 17 | ]) 18 | self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus()) 19 | self.color = nn.Sequential( 20 | nn.Linear(64+24+feat_ch+4, hid_n), 21 | nn.ReLU(), 22 | nn.Linear(hid_n, 1), 23 | nn.ReLU()) 24 | self.lr0.apply(weights_init) 25 | self.lrs.apply(weights_init) 26 | self.sigma.apply(weights_init) 27 | self.color.apply(weights_init) 28 | 29 | def forward(self, vox_feat, img_feat_rgb_dir): 30 | B, N_points, N_views = img_feat_rgb_dir.shape[:-1] 31 | img_feat = self.agg(img_feat_rgb_dir) 32 | S = img_feat_rgb_dir.shape[2] 33 | vox_img_feat = torch.cat((vox_feat, img_feat), dim=-1) 34 | x = self.lr0(vox_img_feat) 35 | for i in range(len(self.lrs)): 36 | x = self.lrs[i](x) 37 | sigma = self.sigma(x) 38 | x = torch.cat((x, vox_img_feat), dim=-1) 39 | x = x.view(B, -1, 1, x.shape[-1]).repeat(1, 1, S, 1) 40 | x = torch.cat((x, img_feat_rgb_dir), dim=-1) 41 | color_weight = F.softmax(self.color(x), dim=-2) 42 | color = torch.sum((img_feat_rgb_dir[..., -7:-4] * color_weight), dim=-2) 43 | return torch.cat([color, sigma], dim=-1) 44 | 45 | class Agg(nn.Module): 46 | def __init__(self, feat_ch): 47 | """ 48 | """ 49 | super(Agg, self).__init__() 50 | self.feat_ch = feat_ch 51 | if cfg.enerf.viewdir_agg: 52 | self.view_fc = nn.Sequential( 53 | nn.Linear(4, feat_ch), 54 | nn.ReLU(), 55 | ) 56 | self.view_fc.apply(weights_init) 57 | self.global_fc = nn.Sequential( 58 | nn.Linear(feat_ch*3, 32), 59 | nn.ReLU(), 60 | ) 61 | 62 | self.agg_w_fc = nn.Sequential( 63 | nn.Linear(32, 1), 64 | nn.ReLU(),
65 | ) 66 | self.fc = nn.Sequential( 67 | nn.Linear(32, 16), 68 | nn.ReLU(), 69 | ) 70 | self.global_fc.apply(weights_init) 71 | self.agg_w_fc.apply(weights_init) 72 | self.fc.apply(weights_init) 73 | 74 | def forward(self, img_feat_rgb_dir): 75 | B, S = len(img_feat_rgb_dir), img_feat_rgb_dir.shape[-2] 76 | if cfg.enerf.viewdir_agg: 77 | view_feat = self.view_fc(img_feat_rgb_dir[..., -4:]) 78 | img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_feat 79 | else: 80 | img_feat_rgb = img_feat_rgb_dir[..., :-4] 81 | 82 | var_feat = torch.var(img_feat_rgb, dim=-2).view(B, -1, 1, self.feat_ch).repeat(1, 1, S, 1) 83 | avg_feat = torch.mean(img_feat_rgb, dim=-2).view(B, -1, 1, self.feat_ch).repeat(1, 1, S, 1) 84 | 85 | feat = torch.cat([img_feat_rgb, var_feat, avg_feat], dim=-1) 86 | global_feat = self.global_fc(feat) 87 | agg_w = F.softmax(self.agg_w_fc(global_feat), dim=-2) 88 | im_feat = (global_feat * agg_w).sum(dim=-2) 89 | return self.fc(im_feat) 90 | 91 | class MVSNeRF(nn.Module): 92 | def __init__(self, hid_n=64, feat_ch=16+3): 93 | """ 94 | """ 95 | super(MVSNeRF, self).__init__() 96 | self.hid_n = hid_n 97 | self.lr0 = nn.Sequential(nn.Linear(8+feat_ch*3, hid_n), 98 | nn.ReLU()) 99 | self.lrs = nn.ModuleList([ 100 | nn.Sequential(nn.Linear(hid_n, hid_n), nn.ReLU()) for i in range(0) 101 | ]) 102 | self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus()) 103 | self.color = nn.Sequential( 104 | nn.Linear(hid_n, hid_n), 105 | nn.ReLU(), 106 | nn.Linear(hid_n, 3)) 107 | self.lr0.apply(weights_init) 108 | self.lrs.apply(weights_init) 109 | self.sigma.apply(weights_init) 110 | self.color.apply(weights_init) 111 | 112 | def forward(self, vox_feat, img_feat_rgb_dir): 113 | B, N_points, N_views = img_feat_rgb_dir.shape[:-1] 114 | # img_feat = self.agg(img_feat_rgb_dir) 115 | img_feat = torch.cat([img_feat_rgb_dir[..., i, :-4] for i in range(N_views)] , dim=-1) 116 | S = img_feat_rgb_dir.shape[2] 117 | vox_img_feat = torch.cat((vox_feat, img_feat), dim=-1) 118 | x = self.lr0(vox_img_feat) 119 | for i in range(len(self.lrs)): 120 | x = self.lrs[i](x) 121 | sigma = self.sigma(x) 122 | # x = torch.cat((x, vox_img_feat), dim=-1) 123 | # x = x.view(B, -1, 1, x.shape[-1]).repeat(1, 1, S, 1) 124 | # x = torch.cat((x, img_feat_rgb_dir), dim=-1) 125 | color = torch.sigmoid(self.color(x)) 126 | return torch.cat([color, sigma], dim=-1) 127 | 128 | 129 | 130 | def weights_init(m): 131 | if isinstance(m, nn.Linear): 132 | nn.init.kaiming_normal_(m.weight.data) 133 | if m.bias is not None: 134 | nn.init.zeros_(m.bias.data) 135 | 136 | -------------------------------------------------------------------------------- /lib/networks/enerf/nerf_.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from lib.config import cfg 5 | 6 | class NeRF(nn.Module): 7 | def __init__(self, hid_n=64, feat_ch=16+3): 8 | """ 9 | """ 10 | super(NeRF, self).__init__() 11 | self.hid_n = hid_n 12 | self.agg = Agg(feat_ch) 13 | self.lr0 = nn.Sequential(nn.Linear(16, hid_n), 14 | nn.ReLU()) 15 | self.lrs = nn.ModuleList([ 16 | nn.Sequential(nn.Linear(hid_n, hid_n), nn.ReLU()) for i in range(0) 17 | ]) 18 | self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus()) 19 | self.color = nn.Sequential( 20 | nn.Linear(64+16+feat_ch+4, hid_n), 21 | nn.ReLU(), 22 | nn.Linear(hid_n, 1), 23 | nn.ReLU()) 24 | self.lr0.apply(weights_init) 25 | self.lrs.apply(weights_init) 26 | self.sigma.apply(weights_init) 27 | 
self.color.apply(weights_init) 28 | 29 | def forward(self, vox_feat, img_feat_rgb_dir): 30 | B, N_points, N_views = img_feat_rgb_dir.shape[:-1] 31 | img_feat = self.agg(img_feat_rgb_dir) 32 | S = img_feat_rgb_dir.shape[2] 33 | # vox_img_feat = torch.cat((vox_feat, img_feat), dim=-1) 34 | vox_img_feat = img_feat 35 | x = self.lr0(vox_img_feat) 36 | for i in range(len(self.lrs)): 37 | x = self.lrs[i](x) 38 | sigma = self.sigma(x) 39 | x = torch.cat((x, vox_img_feat), dim=-1) 40 | x = x.view(B, -1, 1, x.shape[-1]).repeat(1, 1, S, 1) 41 | x = torch.cat((x, img_feat_rgb_dir), dim=-1) 42 | color_weight = F.softmax(self.color(x), dim=-2) 43 | color = torch.sum((img_feat_rgb_dir[..., -7:-4] * color_weight), dim=-2) 44 | return torch.cat([color, sigma], dim=-1) 45 | 46 | class Agg(nn.Module): 47 | def __init__(self, feat_ch): 48 | """ 49 | """ 50 | super(Agg, self).__init__() 51 | self.feat_ch = feat_ch 52 | if cfg.enerf.viewdir_agg: 53 | self.view_fc = nn.Sequential( 54 | nn.Linear(4, feat_ch), 55 | nn.ReLU(), 56 | ) 57 | self.view_fc.apply(weights_init) 58 | self.global_fc = nn.Sequential( 59 | nn.Linear(feat_ch*3, 32), 60 | nn.ReLU(), 61 | ) 62 | 63 | self.agg_w_fc = nn.Sequential( 64 | nn.Linear(32, 1), 65 | nn.ReLU(), 66 | ) 67 | self.fc = nn.Sequential( 68 | nn.Linear(32, 16), 69 | nn.ReLU(), 70 | ) 71 | self.global_fc.apply(weights_init) 72 | self.agg_w_fc.apply(weights_init) 73 | self.fc.apply(weights_init) 74 | 75 | def forward(self, img_feat_rgb_dir): 76 | B, S = len(img_feat_rgb_dir), img_feat_rgb_dir.shape[-2] 77 | if cfg.enerf.viewdir_agg: 78 | view_feat = self.view_fc(img_feat_rgb_dir[..., -4:]) 79 | img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_feat 80 | else: 81 | img_feat_rgb = img_feat_rgb_dir[..., :-4] 82 | 83 | var_feat = torch.var(img_feat_rgb, dim=-2).view(B, -1, 1, self.feat_ch).repeat(1, 1, S, 1) 84 | avg_feat = torch.mean(img_feat_rgb, dim=-2).view(B, -1, 1, self.feat_ch).repeat(1, 1, S, 1) 85 | 86 | feat = torch.cat([img_feat_rgb, var_feat, avg_feat], dim=-1) 87 | global_feat = self.global_fc(feat) 88 | agg_w = F.softmax(self.agg_w_fc(global_feat), dim=-2) 89 | im_feat = (global_feat * agg_w).sum(dim=-2) 90 | return self.fc(im_feat) 91 | 92 | class MVSNeRF(nn.Module): 93 | def __init__(self, hid_n=64, feat_ch=16+3): 94 | """ 95 | """ 96 | super(MVSNeRF, self).__init__() 97 | self.hid_n = hid_n 98 | self.lr0 = nn.Sequential(nn.Linear(8+feat_ch*3, hid_n), 99 | nn.ReLU()) 100 | self.lrs = nn.ModuleList([ 101 | nn.Sequential(nn.Linear(hid_n, hid_n), nn.ReLU()) for i in range(0) 102 | ]) 103 | self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus()) 104 | self.color = nn.Sequential( 105 | nn.Linear(hid_n, hid_n), 106 | nn.ReLU(), 107 | nn.Linear(hid_n, 3)) 108 | self.lr0.apply(weights_init) 109 | self.lrs.apply(weights_init) 110 | self.sigma.apply(weights_init) 111 | self.color.apply(weights_init) 112 | 113 | def forward(self, vox_feat, img_feat_rgb_dir): 114 | B, N_points, N_views = img_feat_rgb_dir.shape[:-1] 115 | # img_feat = self.agg(img_feat_rgb_dir) 116 | img_feat = torch.cat([img_feat_rgb_dir[..., i, :-4] for i in range(N_views)] , dim=-1) 117 | S = img_feat_rgb_dir.shape[2] 118 | vox_img_feat = torch.cat((vox_feat, img_feat), dim=-1) 119 | x = self.lr0(vox_img_feat) 120 | for i in range(len(self.lrs)): 121 | x = self.lrs[i](x) 122 | sigma = self.sigma(x) 123 | # x = torch.cat((x, vox_img_feat), dim=-1) 124 | # x = x.view(B, -1, 1, x.shape[-1]).repeat(1, 1, S, 1) 125 | # x = torch.cat((x, img_feat_rgb_dir), dim=-1) 126 | color = torch.sigmoid(self.color(x)) 127 
| return torch.cat([color, sigma], dim=-1) 128 | 129 | 130 | 131 | def weights_init(m): 132 | if isinstance(m, nn.Linear): 133 | nn.init.kaiming_normal_(m.weight.data) 134 | if m.bias is not None: 135 | nn.init.zeros_(m.bias.data) 136 | 137 | -------------------------------------------------------------------------------- /lib/networks/enerf/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from .feature_net import FeatureNet 6 | from .cost_reg_net import CostRegNet, MinCostRegNet 7 | from . import utils 8 | from lib.config import cfg 9 | from .nerf import NeRF 10 | 11 | class Network(nn.Module): 12 | def __init__(self,): 13 | super(Network, self).__init__() 14 | self.feature_net = FeatureNet() 15 | for i in range(cfg.enerf.cas_config.num): 16 | if i == 0: 17 | cost_reg_l = MinCostRegNet(int(32 * (2**(-i)))) 18 | else: 19 | cost_reg_l = CostRegNet(int(32 * (2**(-i)))) 20 | setattr(self, f'cost_reg_{i}', cost_reg_l) 21 | nerf_l = NeRF(feat_ch=cfg.enerf.cas_config.nerf_model_feat_ch[i]+3) 22 | setattr(self, f'nerf_{i}', nerf_l) 23 | 24 | def render_rays(self, rays, **kwargs): 25 | level, batch, im_feat, feat_volume, nerf_model = kwargs['level'], kwargs['batch'], kwargs['im_feat'], kwargs['feature_volume'], kwargs['nerf_model'] 26 | world_xyz, uvd, z_vals = utils.sample_along_depth(rays, N_samples=cfg.enerf.cas_config.num_samples[level], level=level) 27 | B, N_rays, N_samples = world_xyz.shape[:3] 28 | rgbs = utils.unpreprocess(batch['src_inps'], render_scale=cfg.enerf.cas_config.render_scale[level]) 29 | up_feat_scale = cfg.enerf.cas_config.render_scale[level] / cfg.enerf.cas_config.im_ibr_scale[level] 30 | if up_feat_scale != 1.: 31 | B, S, C, H, W = im_feat.shape 32 | im_feat = F.interpolate(im_feat.reshape(B*S, C, H, W), None, scale_factor=up_feat_scale, align_corners=True, mode='bilinear').view(B, S, C, int(H*up_feat_scale), int(W*up_feat_scale)) 33 | 34 | img_feat_rgb = torch.cat((im_feat, rgbs), dim=2) 35 | H_O, W_O = kwargs['batch']['src_inps'].shape[-2:] 36 | B, H, W = len(uvd), int(H_O * cfg.enerf.cas_config.render_scale[level]), int(W_O * cfg.enerf.cas_config.render_scale[level]) 37 | uvd[..., 0], uvd[..., 1] = (uvd[..., 0]) / (W-1), (uvd[..., 1]) / (H-1) 38 | vox_feat = utils.get_vox_feat(uvd.reshape(B, -1, 3), feat_volume) 39 | img_feat_rgb_dir = utils.get_img_feat(world_xyz, img_feat_rgb, batch, self.training, level) # B * N * S * (8+3+4) 40 | net_output = nerf_model(vox_feat, img_feat_rgb_dir) 41 | net_output = net_output.reshape(B, -1, N_samples, net_output.shape[-1]) 42 | outputs = utils.raw2outputs(net_output, z_vals, cfg.enerf.white_bkgd) 43 | return outputs 44 | 45 | def batchify_rays(self, rays, **kwargs): 46 | all_ret = {} 47 | chunk = cfg.enerf.chunk_size 48 | for i in range(0, rays.shape[1], chunk): 49 | ret = self.render_rays(rays[:, i:i + chunk], **kwargs) 50 | for k in ret: 51 | if k not in all_ret: 52 | all_ret[k] = [] 53 | all_ret[k].append(ret[k]) 54 | all_ret = {k: torch.cat(all_ret[k], dim=1) for k in all_ret} 55 | return all_ret 56 | 57 | 58 | def forward_feat(self, x): 59 | B, S, C, H, W = x.shape 60 | x = x.view(B*S, C, H, W) 61 | feat2, feat1, feat0 = self.feature_net(x) 62 | feats = { 63 | 'level_2': feat0.reshape((B, S, feat0.shape[1], H, W)), 64 | 'level_1': feat1.reshape((B, S, feat1.shape[1], H//2, W//2)), 65 | 'level_0': feat2.reshape((B, S, feat2.shape[1], H//4, W//4)), 66 | } 67 | return feats 68 | 
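# forward_feat above returns a three-level feature pyramid: 'level_2' at full resolution, 'level_1' at 1/2, and 'level_0' at 1/4, consumed by the coarse-to-fine cascade in forward(). NOTE: forward_render below references self.cnn_renderer, which is never created in __init__; it is only reachable through the commented-out call in forward().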
69 | def forward_render(self, ret, batch): 70 | B, _, _, H, W = batch['src_inps'].shape 71 | rgb = ret['rgb'].reshape(B, H, W, 3).permute(0, 3, 1, 2) 72 | rgb = self.cnn_renderer(rgb) 73 | ret['rgb'] = rgb.permute(0, 2, 3, 1).reshape(B, H*W, 3) 74 | 75 | 76 | def forward(self, batch): 77 | feats = self.forward_feat(batch['src_inps']) 78 | ret = {} 79 | depth, std, near_far = None, None, None 80 | for i in range(cfg.enerf.cas_config.num): 81 | feature_volume, depth_values, near_far = utils.build_feature_volume( 82 | feats[f'level_{i}'], 83 | batch, 84 | D=cfg.enerf.cas_config.volume_planes[i], 85 | depth=depth, 86 | std=std, 87 | near_far=near_far, 88 | level=i) 89 | feature_volume, depth_prob = getattr(self, f'cost_reg_{i}')(feature_volume) 90 | depth, std = utils.depth_regression(depth_prob, depth_values, i, batch) 91 | if not cfg.enerf.cas_config.render_if[i]: 92 | continue 93 | rays = utils.build_rays(depth, std, batch, self.training, near_far, i) 94 | # UV(2) + ray_o (3) + ray_d (3) + ray_near_far (2) + volume_near_far (2) 95 | im_feat_level = cfg.enerf.cas_config.render_im_feat_level[i] 96 | ret_i = self.batchify_rays( 97 | rays=rays, 98 | feature_volume=feature_volume, 99 | batch=batch, 100 | im_feat=feats[f'level_{im_feat_level}'], 101 | nerf_model=getattr(self, f'nerf_{i}'), 102 | level=i) 103 | # if i == 1: 104 | # self.forward_render(ret_i, batch) 105 | if cfg.enerf.cas_config.depth_inv[i]: 106 | ret_i.update({'depth_mvs': 1./depth}) 107 | else: 108 | ret_i.update({'depth_mvs': depth}) 109 | ret_i.update({'std': std}) 110 | if ret_i['rgb'].isnan().any(): 111 | __import__('ipdb').set_trace() 112 | ret.update({key+f'_level{i}': ret_i[key] for key in ret_i}) 113 | return ret 114 | -------------------------------------------------------------------------------- /lib/utils/colmap/read_write_dense.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # 12 | # * Redistributions in binary form must reproduce the above copyright 13 | # notice, this list of conditions and the following disclaimer in the 14 | # documentation and/or other materials provided with the distribution. 15 | # 16 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 17 | # its contributors may be used to endorse or promote products derived 18 | # from this software without specific prior written permission. 19 | # 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 23 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
24 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30 | # POSSIBILITY OF SUCH DAMAGE.
31 | #
32 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
33 | 
34 | import argparse
35 | import numpy as np
36 | import os
37 | import struct
38 | 
39 | 
40 | def read_array(path):
41 |     with open(path, "rb") as fid:
42 |         width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
43 |                                                  usecols=(0, 1, 2), dtype=int)
44 |         fid.seek(0)
45 |         num_delimiter = 0
46 |         byte = fid.read(1)
47 |         while True:
48 |             if byte == b"&":
49 |                 num_delimiter += 1
50 |                 if num_delimiter >= 3:
51 |                     break
52 |             byte = fid.read(1)
53 |         array = np.fromfile(fid, np.float32)
54 |     array = array.reshape((width, height, channels), order="F")
55 |     return np.transpose(array, (1, 0, 2)).squeeze()
56 | 
57 | 
58 | def write_array(array, path):
59 |     """
60 |     see: src/mvs/mat.h
61 |     void Mat::Write(const std::string& path)
62 |     """
63 |     assert array.dtype == np.float32
64 |     if len(array.shape) == 2:
65 |         height, width = array.shape
66 |         channels = 1
67 |     elif len(array.shape) == 3:
68 |         height, width, channels = array.shape
69 |     else:
70 |         assert False
71 | 
72 |     with open(path, "w") as fid:
73 |         fid.write(str(width) + "&" + str(height) + "&" + str(channels) + "&")
74 | 
75 |     with open(path, "ab") as fid:
76 |         if len(array.shape) == 2:
77 |             array_trans = np.transpose(array, (1, 0))
78 |         elif len(array.shape) == 3:
79 |             array_trans = np.transpose(array, (1, 0, 2))
80 |         else:
81 |             assert False
82 |         data_1d = array_trans.reshape(-1, order="F")
83 |         data_list = data_1d.tolist()
84 |         endian_character = "<"
85 |         format_char_sequence = "".join(["f"] * len(data_list))
86 |         byte_data = struct.pack(endian_character + format_char_sequence, *data_list)
87 |         fid.write(byte_data)
88 | 
89 | 
90 | def parse_args():
91 |     parser = argparse.ArgumentParser()
92 |     parser.add_argument("-d", "--depth_map",
93 |                         help="path to depth map", type=str, required=True)
94 |     parser.add_argument("-n", "--normal_map",
95 |                         help="path to normal map", type=str, required=True)
96 |     parser.add_argument("--min_depth_percentile",
97 |                         help="minimum visualization depth percentile",
98 |                         type=float, default=5)
99 |     parser.add_argument("--max_depth_percentile",
100 |                         help="maximum visualization depth percentile",
101 |                         type=float, default=95)
102 |     args = parser.parse_args()
103 |     return args
104 | 
105 | 
106 | def main():
107 |     args = parse_args()
108 | 
109 |     if args.min_depth_percentile > args.max_depth_percentile:
110 |         raise ValueError("min_depth_percentile should be less than or equal "
111 |                          "to the max_depth_percentile.")
112 | 
113 |     # Read depth and normal maps corresponding to the same image.
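# Both files use COLMAP's dense .bin layout handled by read_array above: an ASCII
# "width&height&channels&" header followed by little-endian float32 samples stored
# in column-major (Fortran) order.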
114 | if not os.path.exists(args.depth_map): 115 | raise FileNotFoundError("File not found: {}".format(args.depth_map)) 116 | 117 | if not os.path.exists(args.normal_map): 118 | raise FileNotFoundError("File not found: {}".format(args.normal_map)) 119 | 120 | depth_map = read_array(args.depth_map) 121 | normal_map = read_array(args.normal_map) 122 | 123 | min_depth, max_depth = np.percentile( 124 | depth_map, [args.min_depth_percentile, args.max_depth_percentile]) 125 | depth_map[depth_map < min_depth] = min_depth 126 | depth_map[depth_map > max_depth] = max_depth 127 | 128 | import pylab as plt 129 | 130 | # Visualize the depth map. 131 | plt.figure() 132 | plt.imshow(depth_map) 133 | plt.title("depth map") 134 | 135 | # Visualize the normal map. 136 | plt.figure() 137 | plt.imshow(normal_map) 138 | plt.title("normal map") 139 | 140 | plt.show() 141 | 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /lib/utils/colmap/crawl_camera_specs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de)
31 | 
32 | import re
33 | import argparse
34 | import requests
35 | from lxml.html import soupparser
36 | 
37 | 
38 | MAX_REQUEST_TRIALS = 10
39 | 
40 | 
41 | def parse_args():
42 |     parser = argparse.ArgumentParser()
43 |     parser.add_argument("--lib_path", required=True)
44 |     args = parser.parse_args()
45 |     return args
46 | 
47 | 
48 | def request_trial(func, *args, **kwargs):
49 |     for i in range(MAX_REQUEST_TRIALS):
50 |         try:
51 |             response = func(*args, **kwargs)
52 |         except:
53 |             continue
54 |         else:
55 |             return response
56 | 
57 |     raise SystemError
58 | 
59 | 
60 | def main():
61 |     args = parse_args()
62 | 
63 |     ##########################################################################
64 |     # Header file
65 |     ##########################################################################
66 | 
67 |     with open(args.lib_path + ".h", "w") as f:
68 |         f.write("#include <string>\n")
69 |         f.write("#include <unordered_map>\n")
70 |         f.write("#include <vector>\n\n")
71 |         f.write("// { make1 : ({ model1 : sensor-width in mm }, ...), ... }\n")
72 |         f.write("typedef std::vector<std::pair<std::string, float>> make_specs_t;\n")
73 |         f.write("typedef std::unordered_map<std::string, make_specs_t> camera_specs_t;\n\n")
74 |         f.write("camera_specs_t InitializeCameraSpecs();\n\n")
75 | 
76 |     ##########################################################################
77 |     # Source file
78 |     ##########################################################################
79 | 
80 |     makes_response = requests.get("http://www.digicamdb.com")
81 |     makes_tree = soupparser.fromstring(makes_response.text)
82 |     makes_node = makes_tree.find(".//select[@id=\"select_brand\"]")
83 |     makes = [b.attrib["value"] for b in makes_node.iter("option")]
84 | 
85 |     with open(args.lib_path + ".cc", "w") as f:
86 |         f.write("camera_specs_t InitializeCameraSpecs() {\n")
87 |         f.write("  camera_specs_t specs;\n\n")
88 |         for make in makes:
89 |             f.write("  {\n")
90 |             f.write("    auto& make_specs = specs[\"%s\"];\n" % make.lower().replace(" ", ""))
91 | 
92 |             models_response = request_trial(
93 |                 requests.post,
94 |                 "http://www.digicamdb.com/inc/ajax.php",
95 |                 data={"b": make, "role": "header_search"})
96 | 
97 |             models_tree = soupparser.fromstring(models_response.text)
98 |             models_code = ""
99 |             num_models = 0
100 |             for model_node in models_tree.iter("option"):
101 |                 model = model_node.attrib.get("value")
102 |                 model_name = model_node.text
103 |                 if model is None:
104 |                     continue
105 | 
106 |                 url = "http://www.digicamdb.com/specs/{0}_{1}" \
107 |                     .format(make, model)
108 |                 specs_response = request_trial(requests.get, url)
109 | 
110 |                 specs_tree = soupparser.fromstring(specs_response.text)
111 |                 for spec in specs_tree.findall(".//td[@class=\"info_key\"]"):
112 |                     if spec.text.strip() == "Sensor:":
113 |                         sensor_text = spec.find("..").find("./td[@class=\"bold\"]")
114 |                         sensor_text = sensor_text.text.strip()
115 |                         m = re.match(".*?([\d.]+) x ([\d.]+).*?", sensor_text)
116 |                         sensor_width = m.group(1)
117 |                         data = (model_name.lower().replace(" ", ""),
118 |                                 float(sensor_width.replace(" ", "")))
119 |                         models_code += "    make_specs.emplace_back(\"%s\", %.4ff);\n" % data
120 | 
121 |                         print(make, model_name)
122 |                         print("  ", sensor_text)
123 | 
124 |                 num_models += 1
125 | 
126 |             f.write("    make_specs.reserve(%d);\n" % num_models)
127 |             f.write(models_code)
128 |             f.write("  }\n\n")
129 | 
130 |         f.write("  return specs;\n")
131 |         f.write("}\n")
132 | 
133 | 
134 | if __name__ == "__main__":
135 |     main()
136 | 
--------------------------------------------------------------------------------
/lib/utils/rend_utils.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def normalize(x): 5 | return x / np.linalg.norm(x) 6 | 7 | def ptstocam(pts, c2w): 8 | tt = np.matmul(c2w[:3, :3].T, (pts-c2w[:3, 3])[..., np.newaxis])[..., 0] 9 | return tt 10 | 11 | def viewmatrix(z, up, pos): 12 | vec2 = normalize(z) 13 | vec0_avg = up 14 | vec1 = normalize(np.cross(vec2, vec0_avg)) 15 | vec0 = normalize(np.cross(vec1, vec2)) 16 | m = np.stack([vec0, vec1, vec2, pos], 1) 17 | return m 18 | 19 | def gen_path(RT, center=None, num_views=100): 20 | lower_row = np.array([[0., 0., 0., 1.]]) 21 | 22 | # transfer RT to camera_to_world matrix 23 | RT = np.array(RT) 24 | RT[:] = np.linalg.inv(RT[:]) 25 | 26 | RT = np.concatenate([RT[:, :, 1:2], RT[:, :, 0:1], 27 | -RT[:, :, 2:3], RT[:, :, 3:4]], 2) 28 | 29 | up = normalize(RT[:, :3, 0].sum(0)) # average up vector 30 | z = normalize(RT[0, :3, 2]) 31 | vec1 = normalize(np.cross(z, up)) 32 | vec2 = normalize(np.cross(up, vec1)) 33 | z_off = 0 34 | 35 | if center is None: 36 | center = RT[:, :3, 3].mean(0) 37 | z_off = 1.3 38 | 39 | c2w = np.stack([up, vec1, vec2, center], 1) 40 | 41 | # get radii for spiral path 42 | tt = ptstocam(RT[:, :3, 3], c2w).T 43 | rads = np.percentile(np.abs(tt), 80, -1) 44 | rads = rads * 1.3 45 | rads = np.array(list(rads) + [1.]) 46 | 47 | render_w2c = [] 48 | for theta in np.linspace(0., 2 * np.pi, num_views + 1)[:-1]: 49 | # camera position 50 | cam_pos = np.array([0, np.sin(theta), np.cos(theta), 1] * rads) 51 | cam_pos_world = np.dot(c2w[:3, :4], cam_pos) 52 | # z axis 53 | z = normalize(cam_pos_world - 54 | np.dot(c2w[:3, :4], np.array([z_off, 0, 0, 1.]))) 55 | # vector -> 3x4 matrix (camera_to_world) 56 | mat = viewmatrix(z, up, cam_pos_world) 57 | 58 | mat = np.concatenate([mat[:, 1:2], mat[:, 0:1], 59 | -mat[:, 2:3], mat[:, 3:4]], 1) 60 | mat = np.concatenate([mat, lower_row], 0) 61 | mat = np.linalg.inv(mat) 62 | render_w2c.append(mat) 63 | 64 | return render_w2c 65 | 66 | def create_center_radius(center, radius=5., up='y', ranges=[0, 360, 36], angle_x=0, **kwargs): 67 | center = np.array(center).reshape(1, 3) 68 | thetas = np.deg2rad(np.linspace(*ranges)) 69 | st = np.sin(thetas) 70 | ct = np.cos(thetas) 71 | zero = np.zeros_like(st) 72 | Rotx = cv2.Rodrigues(np.deg2rad(angle_x) * np.array([1., 0., 0.]))[0] 73 | if up == 'z': 74 | center = np.stack([radius*ct, radius*st, zero], axis=1) + center 75 | R = np.stack([-st, ct, zero, zero, zero, zero-1, -ct, -st, zero], axis=-1) 76 | elif up == 'y': 77 | center = np.stack([radius*ct, zero, radius*st, ], axis=1) + center 78 | R = np.stack([ 79 | +st, zero, -ct, 80 | zero, zero-1, zero, 81 | -ct, zero, -st], axis=-1) 82 | R = R.reshape(-1, 3, 3) 83 | R = np.einsum('ab,fbc->fac', Rotx, R) 84 | center = center.reshape(-1, 3, 1) 85 | T = - R @ center 86 | RT = np.dstack([R, T]) 87 | return RT 88 | 89 | 90 | def gen_path(RT, center=None, radius=3.2, up='z', ranges=[0, 360, 90], angle_x=27, **kwargs): 91 | ranges = [0, 360, kwargs['num_views']] 92 | c2ws = np.linalg.inv(RT[:]) 93 | if center is None: 94 | center = c2ws[:, :3, 3].mean(0) 95 | center[0], center[1] = 0., 0. 
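# NOTE: this second gen_path definition shadows the spiral-path variant defined at
# line 19 of this file, so importers of gen_path get the circular orbit below.
# Zeroing x/y keeps only the average camera height: the orbit assumes the capture
# rig circles the scene around the z-axis (the default up='z').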
96 | 
97 |     RTs = []
98 |     center = np.array(center).reshape(1, 3)
99 |     thetas = np.deg2rad(np.linspace(*ranges))
100 |     st = np.sin(thetas)
101 |     ct = np.cos(thetas)
102 |     zero = np.zeros_like(st)
103 |     Rotx = cv2.Rodrigues(np.deg2rad(angle_x) * np.array([1., 0., 0.]))[0]
104 |     if up == 'z':
105 |         center = np.stack([radius*ct, radius*st, zero], axis=1) + center
106 |         R = np.stack([-st, ct, zero, zero, zero, zero-1, -ct, -st, zero], axis=-1)
107 |     elif up == 'y':
108 |         center = np.stack([radius*ct, zero, radius*st, ], axis=1) + center
109 |         R = np.stack([
110 |             +st, zero, -ct,
111 |             zero, zero-1, zero,
112 |             -ct, zero, -st], axis=-1)
113 |     R = R.reshape(-1, 3, 3)
114 |     R = np.einsum('ab,fbc->fac', Rotx, R)
115 |     center = center.reshape(-1, 3, 1)
116 |     T = - R @ center
117 |     RT = np.dstack([R, T])
118 |     RT_bottom = np.zeros_like(RT[:, :1])
119 |     RT_bottom[:, :, 3] = 1
120 |     # __import__('ipdb').set_trace()
121 |     ext = np.concatenate([RT, RT_bottom], axis=1)
122 |     c2w = np.linalg.inv(ext)
123 |     # __import__('ipdb').set_trace()
124 |     # import matplotlib.pyplot as plt
125 |     # plt.plot(c2ws[:, 0, 3], c2ws[:, 1, 3], '.')
126 |     # plt.plot(c2w[:, 0, 3], c2w[:, 1, 3], '.')
127 |     # plt.show()
128 |     return ext
129 | 
130 | def gen_nerf_path(c2ws, depth_ranges, rads_scale=.5, N_views=60):
131 |     c2w = poses_avg(c2ws)
132 |     up = normalize(c2ws[:, :3, 1].sum(0))
133 | 
134 |     close_depth, inf_depth = depth_ranges
135 |     dt = .75
136 |     mean_dz = 1./(( (1.-dt)/close_depth + dt/inf_depth ))
137 |     focal = mean_dz
138 | 
139 |     shrink_factor = .8
140 |     zdelta = close_depth * .2
141 |     tt = c2ws[:, :3, 3] - c2w[:3, 3][None]  # camera-centre offsets from the average pose
142 |     rads = np.percentile(np.abs(tt), 70, 0)*rads_scale
143 | 
144 |     render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
145 |     return render_poses
146 | 
147 | def poses_avg(poses):
148 |     center = poses[:, :3, 3].mean(0)
149 |     vec2 = normalize(poses[:, :3, 2].sum(0))
150 |     up = poses[:, :3, 1].sum(0)
151 |     c2w = viewmatrix(vec2, up, center)
152 |     return c2w
153 | 
154 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
155 |     render_poses = []
156 |     rads = np.array(list(rads) + [1.])
157 | 
158 |     for theta in np.linspace(0., 2.
* np.pi * N_rots, N+1)[:-1]: 159 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 160 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 161 | render_poses.append(viewmatrix(z, up, c)) 162 | return render_poses 163 | -------------------------------------------------------------------------------- /lib/datasets/nerf/enerf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from glob import glob 4 | from lib.utils.data_utils import load_K_Rt_from_P, read_cam_file 5 | from lib.config import cfg 6 | import imageio 7 | import tqdm 8 | from multiprocessing import Pool 9 | import copy 10 | import cv2 11 | import random 12 | from lib.config import cfg 13 | from lib.utils import data_utils 14 | from PIL import Image 15 | import torch 16 | import json 17 | from lib.datasets import enerf_utils 18 | 19 | class Dataset: 20 | def __init__(self, **kwargs): 21 | super(Dataset, self).__init__() 22 | self.data_root = os.path.join(cfg.workspace, kwargs['data_root']) 23 | self.split = kwargs['split'] 24 | if 'scene' in kwargs: 25 | self.scenes = [kwargs['scene']] 26 | else: 27 | self.scenes = [] 28 | self.build_metas() 29 | 30 | def build_metas(self): 31 | if len(self.scenes) == 0: 32 | scenes = ['chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship'] 33 | else: 34 | scenes = self.scenes 35 | self.scene_infos = {} 36 | self.metas = [] 37 | pairs = torch.load('data/mvsnerf/pairs.th') 38 | for scene in scenes: 39 | json_info = json.load(open(os.path.join(self.data_root, scene,'transforms_train.json'))) 40 | b2c = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 41 | scene_info = {'ixts': [], 'exts': [], 'img_paths': []} 42 | for idx in range(len(json_info['frames'])): 43 | c2w = np.array(json_info['frames'][idx]['transform_matrix']) 44 | c2w = c2w @ b2c 45 | ext = np.linalg.inv(c2w) 46 | ixt = np.eye(3) 47 | ixt[0][2], ixt[1][2] = 400., 400. 
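# The Blender synthetic renders are 800x800 pixels, so the principal point is fixed
# at (400, 400); the focal length on the next line is recovered from the dataset's
# camera_angle_x field.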
48 | focal = .5 * 800 / np.tan(.5 * json_info['camera_angle_x']) 49 | ixt[0][0], ixt[1][1] = focal, focal 50 | scene_info['ixts'].append(ixt.astype(np.float32)) 51 | scene_info['exts'].append(ext.astype(np.float32)) 52 | img_path = os.path.join(self.data_root, scene, 'train/r_{}.png'.format(idx)) 53 | scene_info['img_paths'].append(img_path) 54 | self.scene_infos[scene] = scene_info 55 | train_ids, render_ids = pairs[f'{scene}_train'], pairs[f'{scene}_val'] 56 | if self.split == 'train': 57 | render_ids = train_ids 58 | c2ws = np.stack([np.linalg.inv(scene_info['exts'][idx]) for idx in train_ids]) 59 | for idx in render_ids: 60 | c2w = np.linalg.inv(scene_info['exts'][idx]) 61 | distance = np.linalg.norm((c2w[:3, 3][None] - c2ws[:, :3, 3]), axis=-1) 62 | 63 | argsorts = distance.argsort() 64 | argsorts = argsorts[1:] if idx in train_ids else argsorts 65 | 66 | input_views_num = cfg.enerf.train_input_views[1] + 1 if self.split == 'train' else cfg.enerf.test_input_views 67 | src_views = [train_ids[i] for i in argsorts[:input_views_num]] 68 | self.metas += [(scene, idx, src_views)] 69 | 70 | def __getitem__(self, index_meta): 71 | index, input_views_num = index_meta 72 | scene, tar_view, src_views = self.metas[index] 73 | if self.split == 'train': 74 | if np.random.random() < 0.1: 75 | src_views = src_views + [tar_view] 76 | src_views = random.sample(src_views, input_views_num) 77 | scene_info = self.scene_infos[scene] 78 | scene_info['scene_name'] = scene 79 | tar_img, tar_ext, tar_ixt = self.read_tar(scene_info, tar_view) 80 | src_inps, src_exts, src_ixts = self.read_src(scene_info, src_views) 81 | 82 | ret = {'src_inps': src_inps.transpose(0, 3, 1, 2), 83 | 'src_exts': src_exts, 84 | 'src_ixts': src_ixts} 85 | tar_mask = np.ones_like(tar_img[..., 0]).astype(np.uint8) 86 | H, W = tar_img.shape[:2] 87 | ret.update({'tar_ext': tar_ext, 88 | 'tar_ixt': tar_ixt}) 89 | if self.split != 'train': 90 | ret.update({'tar_img': tar_img, 91 | 'tar_mask': tar_mask}) 92 | near_far = np.array([2.5, 5.5]).astype(np.float32) 93 | ret.update({'near_far': np.array(near_far).astype(np.float32)}) 94 | ret.update({'meta': {'scene': scene, 'tar_view': tar_view, 'frame_id': 0}}) 95 | 96 | for i in range(cfg.enerf.cas_config.num): 97 | rays, rgb, msk = enerf_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split) 98 | ret.update({f'rays_{i}': rays, f'rgb_{i}': rgb.astype(np.float32), f'msk_{i}': msk}) 99 | s = cfg.enerf.cas_config.volume_scale[i] 100 | ret['meta'].update({f'h_{i}': int(H*s), f'w_{i}': int(W*s)}) 101 | return ret 102 | 103 | def read_src(self, scene, src_views): 104 | src_ids = src_views 105 | ixts, exts, imgs = [], [], [] 106 | for idx in src_ids: 107 | img = self.read_image(scene, idx) 108 | imgs.append((img*2-1).astype(np.float32)) 109 | ixt, ext = self.read_cam(scene, idx) 110 | ixts.append(ixt) 111 | exts.append(ext) 112 | return np.stack(imgs), np.stack(exts), np.stack(ixts) 113 | 114 | def read_tar(self, scene, view_idx): 115 | img = self.read_image(scene, view_idx) 116 | ixt, ext = self.read_cam(scene, view_idx) 117 | return img, ext, ixt 118 | 119 | def read_cam(self, scene, view_idx): 120 | ext = scene['exts'][view_idx] 121 | ixt = scene['ixts'][view_idx] 122 | return ixt, ext 123 | 124 | def read_image(self, scene, view_idx): 125 | img_path = scene['img_paths'][view_idx] 126 | img = (np.array(imageio.imread(img_path)) / 255.).astype(np.float32) 127 | img = (img[..., :3] * img[..., -1:] + (1 - img[..., -1:])).astype(np.float32) 128 | return img 129 | 130 | def __len__(self): 
131 |         return len(self.metas)
132 | 
133 | 
134 | def get_K_from_params(params):
135 |     K = np.zeros((3, 3)).astype(np.float32)
136 |     K[0][0], K[0][2], K[1][2] = params[:3]
137 |     K[1][1] = K[0][0]
138 |     K[2][2] = 1.
139 |     return K
140 | 
--------------------------------------------------------------------------------
/lib/datasets/samplers.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.sampler import Sampler
2 | from torch.utils.data.sampler import BatchSampler
3 | import numpy as np
4 | import torch
5 | import math
6 | import torch.distributed as dist
7 | from lib.config import cfg
8 | import random  # required by the cfg.fix_random branch below
9 | class EnerfBatchSampler(Sampler):
10 |     def __init__(self, sampler, batch_size, drop_last, sampler_meta):
11 |         self.sampler = sampler
12 |         self.batch_size = batch_size
13 |         self.drop_last = drop_last
14 |         self.input_views = sampler_meta.input_views_num
15 |         self.views_prob = sampler_meta.input_views_prob
16 |         if cfg.fix_random:
17 |             random.seed(0)
18 | 
19 |     def __iter__(self):
20 |         batch = []
21 |         input_views_num = np.random.choice(self.input_views, 1, p=self.views_prob)
22 |         for idx in self.sampler:
23 |             batch.append((idx, input_views_num.item()))
24 |             if len(batch) == self.batch_size:
25 |                 input_views_num = np.random.choice(self.input_views, 1, p=self.views_prob)
26 |                 yield batch
27 |                 batch = []
28 |         if len(batch) > 0 and not self.drop_last:
29 |             yield batch
30 | 
31 |     def __len__(self):
32 |         if self.drop_last:
33 |             return len(self.sampler) // self.batch_size
34 |         else:
35 |             return (len(self.sampler) + self.batch_size - 1) // self.batch_size
36 | 
37 | 
38 | class ImageSizeBatchSampler(Sampler):
39 |     def __init__(self, sampler, batch_size, drop_last, sampler_meta):
40 |         self.sampler = sampler
41 |         self.batch_size = batch_size
42 |         self.drop_last = drop_last
43 |         self.strategy = sampler_meta.strategy
44 |         self.hmin, self.wmin = sampler_meta.min_hw
45 |         self.hmax, self.wmax = sampler_meta.max_hw
46 |         self.divisor = 32
47 |         if cfg.fix_random:
48 |             np.random.seed(0)
49 | 
50 |     def generate_height_width(self):
51 |         if self.strategy == 'origin':
52 |             return -1, -1
53 |         h = np.random.randint(self.hmin, self.hmax + 1)
54 |         w = np.random.randint(self.wmin, self.wmax + 1)
55 |         h = (h | (self.divisor - 1)) + 1
56 |         w = (w | (self.divisor - 1)) + 1
57 |         return h, w
58 | 
59 |     def __iter__(self):
60 |         batch = []
61 |         h, w = self.generate_height_width()
62 |         for idx in self.sampler:
63 |             batch.append((idx, h, w))
64 |             if len(batch) == self.batch_size:
65 |                 h, w = self.generate_height_width()
66 |                 yield batch
67 |                 batch = []
68 |         if len(batch) > 0 and not self.drop_last:
69 |             yield batch
70 | 
71 |     def __len__(self):
72 |         if self.drop_last:
73 |             return len(self.sampler) // self.batch_size
74 |         else:
75 |             return (len(self.sampler) + self.batch_size - 1) // self.batch_size
76 | 
77 | 
78 | class IterationBasedBatchSampler(BatchSampler):
79 |     """
80 |     Wraps a BatchSampler, resampling from it until
81 |     a specified number of iterations have been sampled
82 |     """
83 | 
84 |     def __init__(self, batch_sampler, num_iterations, start_iter=0):
85 |         self.batch_sampler = batch_sampler
86 |         self.sampler = self.batch_sampler.sampler
87 |         self.num_iterations = num_iterations
88 |         self.start_iter = start_iter
89 | 
90 |     def __iter__(self):
91 |         iteration = self.start_iter
92 |         while iteration <= self.num_iterations:
93 |             for batch in self.batch_sampler:
94 |                 iteration += 1
95 |                 if iteration > self.num_iterations:
96 |                     break
97 |                 yield batch
98 | 
99 |     def __len__(self):
100 |         return
self.num_iterations 101 | 102 | 103 | class DistributedSampler(Sampler): 104 | """Sampler that restricts data loading to a subset of the dataset. 105 | It is especially useful in conjunction with 106 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 107 | process can pass a DistributedSampler instance as a DataLoader sampler, 108 | and load a subset of the original dataset that is exclusive to it. 109 | .. note:: 110 | Dataset is assumed to be of constant size. 111 | Arguments: 112 | dataset: Dataset used for sampling. 113 | num_replicas (optional): Number of processes participating in 114 | distributed training. 115 | rank (optional): Rank of the current process within num_replicas. 116 | """ 117 | 118 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 119 | if num_replicas is None: 120 | if not dist.is_available(): 121 | raise RuntimeError("Requires distributed package to be available") 122 | num_replicas = dist.get_world_size() 123 | if rank is None: 124 | if not dist.is_available(): 125 | raise RuntimeError("Requires distributed package to be available") 126 | rank = dist.get_rank() 127 | self.dataset = dataset 128 | self.num_replicas = num_replicas 129 | self.rank = rank 130 | self.epoch = 0 131 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 132 | self.total_size = self.num_samples * self.num_replicas 133 | self.shuffle = shuffle 134 | 135 | def __iter__(self): 136 | if self.shuffle: 137 | # deterministically shuffle based on epoch 138 | g = torch.Generator() 139 | g.manual_seed(self.epoch) 140 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 141 | else: 142 | indices = torch.arange(len(self.dataset)).tolist() 143 | 144 | # add extra samples to make it evenly divisible 145 | indices += indices[: (self.total_size - len(indices))] 146 | assert len(indices) == self.total_size 147 | 148 | # subsample 149 | offset = self.num_samples * self.rank 150 | indices = indices[offset:offset+self.num_samples] 151 | assert len(indices) == self.num_samples 152 | 153 | return iter(indices) 154 | 155 | def __len__(self): 156 | return self.num_samples 157 | 158 | def set_epoch(self, epoch): 159 | self.epoch = epoch 160 | -------------------------------------------------------------------------------- /lib/networks/enerf/network_human.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from .feature_net import FeatureNet, CNNRender 6 | from .cost_reg_net import CostRegNet, MinCostRegNet 7 | from . 
import utils 8 | from lib.config import cfg 9 | from .nerf import NeRF 10 | 11 | class Network(nn.Module): 12 | def __init__(self,): 13 | super(Network, self).__init__() 14 | self.feature_net = FeatureNet() 15 | for i in range(cfg.enerf.cas_config.num): 16 | if i == 0: 17 | cost_reg_l = MinCostRegNet(int(32 * (2**(-i)))) 18 | else: 19 | cost_reg_l = CostRegNet(int(32 * (2**(-i)))) 20 | setattr(self, f'cost_reg_{i}', cost_reg_l) 21 | nerf_l = NeRF(feat_ch=cfg.enerf.cas_config.nerf_model_feat_ch[i]+3) 22 | setattr(self, f'nerf_{i}', nerf_l) 23 | 24 | def render_rays(self, rays, **kwargs): 25 | level, batch, im_feat, feat_volume, nerf_model = kwargs['level'], kwargs['batch'], kwargs['im_feat'], kwargs['feature_volume'], kwargs['nerf_model'] 26 | world_xyz, uvd, z_vals = utils.sample_along_depth(rays, N_samples=cfg.enerf.cas_config.num_samples[level], level=level) 27 | B, N_rays, N_samples = world_xyz.shape[:3] 28 | rgbs = utils.unpreprocess(batch['src_inps'], render_scale=cfg.enerf.cas_config.render_scale[level]) 29 | up_feat_scale = cfg.enerf.cas_config.render_scale[level] / cfg.enerf.cas_config.im_ibr_scale[level] 30 | if up_feat_scale != 1.: 31 | B, S, C, H, W = im_feat.shape 32 | im_feat = F.interpolate(im_feat.reshape(B*S, C, H, W), None, scale_factor=up_feat_scale, align_corners=True, mode='bilinear').view(B, S, C, int(H*up_feat_scale), int(W*up_feat_scale)) 33 | 34 | img_feat_rgb = torch.cat((im_feat, rgbs), dim=2) 35 | H_O, W_O = kwargs['batch']['src_inps'].shape[-2:] 36 | B, H, W = len(uvd), int(H_O * cfg.enerf.cas_config.render_scale[level]), int(W_O * cfg.enerf.cas_config.render_scale[level]) 37 | uvd[..., 0], uvd[..., 1] = (uvd[..., 0]) / (W-1), (uvd[..., 1]) / (H-1) 38 | vox_feat = utils.get_vox_feat(uvd.reshape(B, -1, 3), feat_volume) 39 | img_feat_rgb_dir = utils.get_img_feat(world_xyz, img_feat_rgb, batch, self.training, level) # B * N * S * (8+3+4) 40 | net_output = nerf_model(vox_feat, img_feat_rgb_dir) 41 | net_output = net_output.reshape(B, -1, N_samples, net_output.shape[-1]) 42 | outputs = utils.raw2outputs(net_output, z_vals, cfg.enerf.white_bkgd) 43 | return outputs 44 | 45 | def batchify_rays(self, rays, **kwargs): 46 | all_ret = {} 47 | chunk = cfg.enerf.chunk_size 48 | for i in range(0, rays.shape[1], chunk): 49 | ret = self.render_rays(rays[:, i:i + chunk], **kwargs) 50 | for k in ret: 51 | if k not in all_ret: 52 | all_ret[k] = [] 53 | all_ret[k].append(ret[k]) 54 | all_ret = {k: torch.cat(all_ret[k], dim=1) for k in all_ret} 55 | return all_ret 56 | 57 | 58 | def forward_feat(self, x): 59 | B, S, C, H, W = x.shape 60 | x = x.view(B*S, C, H, W) 61 | feat2, feat1, feat0 = self.feature_net(x) 62 | feats = { 63 | 'level_2': feat0.reshape((B, S, feat0.shape[1], H, W)), 64 | 'level_1': feat1.reshape((B, S, feat1.shape[1], H//2, W//2)), 65 | 'level_0': feat2.reshape((B, S, feat2.shape[1], H//4, W//4)), 66 | } 67 | return feats 68 | 69 | def forward(self, batch): 70 | feats = self.forward_feat(batch['src_inps']) 71 | ret = {} 72 | depth, std, near_far = None, None, None 73 | for i in range(cfg.enerf.cas_config.num): 74 | feature_volume, depth_values, near_far = utils.build_feature_volume( 75 | feats[f'level_{i}'], 76 | batch, 77 | D=cfg.enerf.cas_config.volume_planes[i], 78 | depth=depth, 79 | std=std, 80 | near_far=near_far, 81 | level=i) 82 | feature_volume, depth_prob = getattr(self, f'cost_reg_{i}')(feature_volume) 83 | depth, std = utils.depth_regression(depth_prob, depth_values, i, batch) 84 | if not cfg.enerf.cas_config.render_if[i]: 85 | continue 86 | rays = 
utils.build_rays(depth, std, batch, self.training, near_far, i) 87 | # UV(2) + ray_o (3) + ray_d (3) + ray_near_far (2) + volume_near_far (2) 88 | im_feat_level = cfg.enerf.cas_config.render_im_feat_level[i] 89 | 90 | if 'mask_at_box' in batch and not self.training and i == cfg.enerf.cas_config.num -1: 91 | mask_at_box = batch['mask_at_box'].bool().reshape(1, -1) 92 | # mask_at_box = batch['mask_at_box'].reshape(1, -1) 93 | rays = rays[mask_at_box][None] 94 | 95 | ret_i = self.batchify_rays( 96 | rays=rays, 97 | feature_volume=feature_volume, 98 | batch=batch, 99 | im_feat=feats[f'level_{im_feat_level}'], 100 | nerf_model=getattr(self, f'nerf_{i}'), 101 | level=i) 102 | if 'mask_at_box' in batch and not self.training and i == cfg.enerf.cas_config.num - 1: 103 | rgb = torch.zeros_like(batch['mask_at_box'].reshape(1, -1))[..., None].repeat(1, 1, 3).float() 104 | mask_at_box = batch['mask_at_box'].bool().reshape(1, -1) 105 | if mask_at_box.sum() > 1: 106 | rgb[mask_at_box] = ret_i['rgb'][0] 107 | ret_i['rgb'] = rgb 108 | # ret_i['rgb'].update({'rgb': rgb}) 109 | # if i == 1: 110 | # self.forward_render(ret_i, batch) 111 | if cfg.enerf.cas_config.depth_inv[i]: 112 | ret_i.update({'depth_mvs': 1./depth}) 113 | else: 114 | ret_i.update({'depth_mvs': depth}) 115 | ret_i.update({'std': std}) 116 | # if ret_i['rgb'].isnan().any(): 117 | # __import__('ipdb').set_trace() 118 | ret.update({key+f'_level{i}': ret_i[key] for key in ret_i}) 119 | return ret 120 | -------------------------------------------------------------------------------- /lib/config/config.py: -------------------------------------------------------------------------------- 1 | from .yacs import CfgNode as CN 2 | import argparse 3 | import os 4 | import numpy as np 5 | from . import yacs 6 | 7 | 8 | cfg = CN() 9 | 10 | cfg.workspace = os.environ['workspace'] 11 | print('Workspace: ', cfg.workspace) 12 | 13 | # extract_mesh 14 | cfg.level = 32. 15 | cfg.resolution = 256 16 | 17 | cfg.vis_encoder = '' 18 | cfg.feat_vis_len = 8 19 | 20 | cfg.cache_data = False 21 | cfg.sample_keypoints_epoch = -1 22 | 23 | cfg.write_video = False 24 | cfg.interested_mask = False 25 | cfg.render_path = False 26 | cfg.render_emb = 0 27 | cfg.render_ixt = 0 28 | cfg.code_id = -1 29 | cfg.time_weight = 0. 30 | cfg.render_static = True 31 | cfg.pretrain_path = '' 32 | cfg.scene = 'test' 33 | cfg.last_view = False 34 | cfg.exp_hard = False 35 | cfg.pos_encoding_t = False 36 | cfg.render_time = False 37 | cfg.render_time_skip = [0, -1, 1] 38 | cfg.start_time = [2009, 8, 1] 39 | cfg.end_time = [2013, 12, 1] 40 | cfg.discrete_3views = False 41 | cfg.fps = 24 42 | cfg.dcat = False 43 | cfg.min_y = -100000000. 44 | cfg.time_discrete = -1 45 | cfg.render_day = False 46 | cfg.render_date = [2013, 1, 1] 47 | cfg.rand_t = -1. 
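# The defaults above and below are experiment-level toggles; individual runs
# override them through the YAML files merged by make_cfg at the bottom of
# this file.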
48 | cfg.semantic_mask = False
49 | cfg.product_combine = False
50 | cfg.unisample = False
51 | cfg.render_emb_2 = -1
52 | cfg.render_num = 30
53 | cfg.render_ext = 0
54 | cfg.time_geo = False
55 | cfg.reg_beta = False
56 | cfg.fix_beta = False
57 | cfg.hard_lap = False
58 | cfg.render_octree = False
59 | cfg.render_mask = False
60 | cfg.environment_map = False
61 | 
62 | cfg.save_result = False
63 | cfg.clear_result = False
64 | cfg.save_tag = 'default'
65 | # module
66 | cfg.train_dataset_module = 'lib.datasets.dtu.neus'
67 | cfg.test_dataset_module = 'lib.datasets.dtu.neus'
68 | cfg.val_dataset_module = 'lib.datasets.dtu.neus'
69 | cfg.network_module = 'lib.networks.neus.neus'
70 | cfg.loss_module = 'lib.train.losses.neus'
71 | cfg.evaluator_module = 'lib.evaluators.neus'
72 | 
73 | # experiment name
74 | cfg.exp_name = 'gitbranch_hello'
75 | cfg.exp_name_tag = ''
76 | cfg.pretrain = ''
77 | 
78 | # network
79 | cfg.distributed = False
80 | 
81 | # task
82 | cfg.task = 'hello'
83 | 
84 | # gpus
85 | cfg.gpus = list(range(4))
86 | # if load the pretrained network
87 | cfg.resume = True
88 | 
89 | # epoch
90 | cfg.ep_iter = -1
91 | cfg.save_ep = 1
92 | cfg.save_latest_ep = 1
93 | cfg.eval_ep = 1
94 | cfg.log_interval = 20
95 | 
96 | 
97 | cfg.task_arg = CN()
98 | cfg.task_arg.sample_more_on_mask = -1.
99 | cfg.task_arg.sample_on_mask = False
100 | 
101 | # -----------------------------------------------------------------------------
102 | # train
103 | # -----------------------------------------------------------------------------
104 | cfg.train = CN()
105 | cfg.train.epoch = 10000
106 | cfg.train.num_workers = 8
107 | cfg.train.collator = 'default'
108 | cfg.train.batch_sampler = 'default'
109 | cfg.train.sampler_meta = CN({})
110 | cfg.train.shuffle = True
111 | cfg.train.eps = 1e-8
112 | 
113 | # use adam as default
114 | cfg.train.optim = 'adam'
115 | cfg.train.lr = 5e-4
116 | cfg.train.weight_decay = 0.
117 | cfg.train.scheduler = CN({'type': 'multi_step', 'milestones': [80, 120, 200, 240], 'gamma': 0.5})
118 | cfg.train.batch_size = 4
119 | 
120 | # test
121 | cfg.test = CN()
122 | cfg.test.batch_size = 1
123 | cfg.test.collator = 'default'
124 | cfg.test.epoch = -1
125 | cfg.test.batch_sampler = 'default'
126 | cfg.test.sampler_meta = CN({})
127 | 
128 | # trained model
129 | cfg.trained_model_dir = os.path.join(os.environ['workspace'], 'trained_model')
130 | cfg.clean_tag = 'debug'
131 | 
132 | # recorder
133 | cfg.record_dir = os.path.join(os.environ['workspace'], 'record')
134 | 
135 | # result
136 | cfg.result_dir = os.path.join(os.environ['workspace'], 'result')
137 | 
138 | # evaluation
139 | cfg.skip_eval = False
140 | 
141 | cfg.fix_random = False
142 | 
143 | def parse_cfg(cfg, args):
144 |     if len(cfg.task) == 0:
145 |         raise ValueError('task must be specified')
146 | 
147 |     # assign the gpus
148 |     if -1 not in cfg.gpus:
149 |         os.environ['CUDA_VISIBLE_DEVICES'] = ', '.join([str(gpu) for gpu in cfg.gpus])
150 | 
151 |     if 'bbox' in cfg:
152 |         bbox = np.array(cfg.bbox).reshape((2, 3))
153 |         center, half_size = np.mean(bbox, axis=0), (bbox[1]-bbox[0]).max().item() / 2.
154 |         bbox = np.stack([center-half_size, center+half_size])
155 |         cfg.bbox = bbox.reshape(6).tolist()
156 | 
157 |     if len(cfg.exp_name_tag) != 0:
158 |         cfg.exp_name += ('_' + cfg.exp_name_tag)
159 |     cfg.exp_name = cfg.exp_name.replace('gitbranch', os.popen('git describe --all').readline().strip()[6:])
160 |     cfg.exp_name = cfg.exp_name.replace('gitcommit', os.popen('git describe --tags --always').readline().strip())
161 |     print('EXP NAME: ', cfg.exp_name)
162 |     cfg.trained_model_dir = os.path.join(cfg.trained_model_dir, cfg.task, cfg.exp_name)
163 |     cfg.record_dir = os.path.join(cfg.record_dir, cfg.task, cfg.exp_name)
164 |     cfg.result_dir = os.path.join(cfg.result_dir, cfg.task, cfg.exp_name, cfg.save_tag)
165 |     cfg.local_rank = args.local_rank
166 |     modules = [key for key in cfg if '_module' in key]
167 |     for module in modules:
168 |         cfg[module.replace('_module', '_path')] = cfg[module].replace('.', '/') + '.py'
169 | 
170 | def make_cfg(args):
171 |     def merge_cfg(cfg_file, cfg):
172 |         with open(cfg_file, 'r') as f:
173 |             current_cfg = yacs.load_cfg(f)
174 |         if 'parent_cfg' in current_cfg.keys():
175 |             cfg = merge_cfg(current_cfg.parent_cfg, cfg)
176 |             cfg.merge_from_other_cfg(current_cfg)
177 |         else:
178 |             cfg.merge_from_other_cfg(current_cfg)
179 |         print(cfg_file)
180 |         return cfg
181 |     cfg_ = merge_cfg(args.cfg_file, cfg)
182 |     try:
183 |         index = args.opts.index('other_opts')
184 |         cfg_.merge_from_list(args.opts[:index])
185 |     except ValueError:  # no 'other_opts' sentinel in the overrides
186 |         cfg_.merge_from_list(args.opts)
187 |     parse_cfg(cfg_, args)
188 |     return cfg_
189 | 
190 | 
191 | parser = argparse.ArgumentParser()
192 | parser.add_argument("--cfg_file", default="configs/default.yaml", type=str)
193 | parser.add_argument('--test', action='store_true', dest='test', default=False)
194 | parser.add_argument("--type", type=str, default="")
195 | parser.add_argument('--det', type=str, default='')
196 | parser.add_argument('--local_rank', type=int, default=0)
197 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)
198 | args = parser.parse_args()
199 | if len(args.type) > 0:
200 |     cfg.task = "run"
201 | cfg = make_cfg(args)
202 | 
--------------------------------------------------------------------------------
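A closing usage note: lib/config/config.py builds the global cfg at import time, parsing --cfg_file from the command line and recursively merging any parent_cfg chain inside make_cfg. Below is a minimal sketch of driving it programmatically; it is an illustration rather than repo code, and it assumes the workspace environment variable is set and that the placeholder YAML path points at a config defining whatever fields the experiment needs.

import os, sys

os.environ.setdefault('workspace', '/tmp/enerf_ws')           # config.py reads this env var at import
sys.argv = ['demo', '--cfg_file', 'path/to/experiment.yaml',  # placeholder config path
            'exp_name', 'demo_run']                           # trailing opts merged via merge_from_list
from lib.config.config import cfg                             # importing the module runs make_cfg(args)
print(cfg.exp_name, cfg.trained_model_dir)

Because the module parses sys.argv on import, the repo's entry scripts simply import cfg and inherit whatever flags they were launched with.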