├── .gitignore ├── Figures └── network.png ├── README.md ├── configs ├── Ev3D_pretrain.yaml └── default.yaml ├── environment.yml ├── eval_gs.py ├── lib ├── __pycache__ │ ├── losses.cpython-37.pyc │ ├── recorder.cpython-310.pyc │ ├── recorder.cpython-37.pyc │ └── utils.cpython-37.pyc ├── config │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── config.cpython-310.pyc │ │ ├── config.cpython-37.pyc │ │ ├── config.cpython-39.pyc │ │ ├── yacs.cpython-310.pyc │ │ ├── yacs.cpython-37.pyc │ │ └── yacs.cpython-39.pyc │ ├── config.py │ └── yacs.py ├── dataset │ ├── Ev3D.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── Ev3D.cpython-310.pyc │ │ ├── Ev3D.cpython-37.pyc │ │ ├── Ev3D.cpython-39.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── utils.cpython-37.pyc │ │ └── utils.cpython-39.pyc │ └── utils.py ├── losses.py ├── network │ ├── __init__.py │ ├── __pycache__ │ │ ├── ASNet.cpython-37.pyc │ │ ├── ASNet_utils.cpython-37.pyc │ │ ├── SegNet.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── asnet.cpython-37.pyc │ │ ├── asnet_utils.cpython-37.pyc │ │ ├── densenet.cpython-37.pyc │ │ ├── dfanet.cpython-37.pyc │ │ ├── eventgaussian.cpython-37.pyc │ │ ├── firenet.cpython-37.pyc │ │ ├── gsregressor.cpython-37.pyc │ │ ├── mobilenetv2.cpython-37.pyc │ │ ├── pspnet.cpython-37.pyc │ │ ├── recon_net.cpython-37.pyc │ │ ├── resnet.cpython-37.pyc │ │ ├── submodules.cpython-37.pyc │ │ ├── swin.cpython-37.pyc │ │ └── unet.cpython-37.pyc │ ├── asnet.py │ ├── asnet_utils.py │ ├── eventgaussian.py │ ├── firenet.py │ ├── gsregressor.py │ ├── neurons.py │ ├── recon_net.py │ ├── resnet.py │ ├── snn.py │ ├── submodules.py │ ├── swin.py │ └── unet.py ├── recorder.py ├── renderer │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── gaussian_render.cpython-37.pyc │ │ ├── gaussian_render.cpython-39.pyc │ │ └── rend_utils.cpython-37.pyc │ ├── gaussian_render.py │ └── rend_utils.py └── utils.py ├── pretrain_ckpt └── download.sh └── train_gs.py /.gitignore: -------------------------------------------------------------------------------- 1 | ./experiments/ 2 | pretrain_ckpt/*.pth 3 | .history/ 4 | -------------------------------------------------------------------------------- /Figures/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/Figures/network.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EvGGS: A Collaborative Learning Framework for Event-based Generalizable Gaussian Splatting 2 | [A Collaborative Learning Framework for Event-based Generalizable Gaussian Splatting](https://arxiv.org/abs/2405.14959v1) 3 | 4 | Jiaxu Wang, Junhao He, Ziyi Zhang, Mingyuan Sun, Jingkai Sun, Renjing Xu* 5 | 6 |

7 | ![EvGGS network overview](Figures/network.png) 8 | Fig 1. The main pipeline overview of the proposed EvGGS framework. 9 |

10 | 11 | # Create environment 12 | ``` 13 | conda env create --file environment.yml 14 | conda activate evggs 15 | ``` 16 | Then, compile the diff-gaussian-rasterization submodule from the 3DGS repository: 17 | ``` 18 | git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive 19 | cd gaussian-splatting/ 20 | pip install -e submodules/diff-gaussian-rasterization 21 | cd .. 22 | ``` 23 | # Download models 24 | Download the pretrained models from [OneDrive](https://hkustgz-my.sharepoint.com/:u:/g/personal/jwang457_connect_hkust-gz_edu_cn/ESAMKY3oHDRBr2-zeNb3L8IBKnFGiJCAgyRv3HBs6esFaQ?e=O7bili) and place them in ```./pretrain_ckpt```. This directory includes two warmup ckpts and a ckpt pretrained on the synthetic dataset. 25 | 26 | # Running the code 27 | 28 | ## Download dataset 29 | 30 | - Ev3D-S 31 | 32 | A large-scale synthetic event-based dataset with varying textures and materials, accompanied by well-calibrated frames, depth maps, and ground truths. 33 | 34 | You can download the dataset from [OneDrive](https://hkustgz-my.sharepoint.com/:u:/g/personal/jwang457_connect_hkust-gz_edu_cn/EYszUyxQnzRMkC0u5GxDOvEB_NhmBaVe2vBnpMH2ctSWxA?e=kJDwRz) and unzip it. About 50 GB of storage space is required. 35 | 36 | 37 | - Ev3D-R 38 | 39 | A large-scale realistic event-based 3D dataset containing various objects captured by a real DVXplorer event camera. 40 | 41 | Due to licensing restrictions, access to this dataset currently requires a private application to the authors. 42 | 43 | ## Training 44 | 45 | ``` 46 | python train_gs.py 47 | ``` 48 | 49 | ## Evaluation 50 | 51 | ``` 52 | python eval_gs.py 53 | ``` 54 | 55 | The primary settings, such as the experiment name and your customized dataset path, are defined in ```configs/Ev3D_pretrain.yaml```; please check them before running. 56 | 57 | # Citation 58 | 59 | Please cite our work if you use this dataset. 60 | 61 | ``` 62 | @misc{wang2024evggscollaborativelearningframework, 63 | title={EvGGS: A Collaborative Learning Framework for Event-based Generalizable Gaussian Splatting}, 64 | author={Jiaxu Wang and Junhao He and Ziyi Zhang and Mingyuan Sun and Jingkai Sun and Renjing Xu}, 65 | year={2024}, 66 | eprint={2405.14959}, 67 | archivePrefix={arXiv}, 68 | primaryClass={cs.CV}, 69 | url={https://arxiv.org/abs/2405.14959}, 70 | } 71 | ``` 72 | 73 | # Reference 74 | 75 | EventNeRF: [https://github.com/r00tman/EventNeRF?tab=readme-ov-file](https://github.com/r00tman/EventNeRF?tab=readme-ov-file). 76 | 3D Gaussian Splatting: [https://github.com/graphdeco-inria/gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting). 
77 | GPS-GS: [https://github.com/aipixel/GPS-Gaussian](https://github.com/aipixel/GPS-Gaussian) 78 | PAEvD3d: [https://github.com/Mercerai/PAEv3d](https://github.com/Mercerai/PAEv3d) 79 | -------------------------------------------------------------------------------- /configs/Ev3D_pretrain.yaml: -------------------------------------------------------------------------------- 1 | exp_name: experimental_name 2 | lr: 0.0005 3 | wdecay: 1e-5 4 | target_epoch: 100 5 | num_steps: 1000000 6 | cs: [0, 480, 64, 576] 7 | 8 | dataset: 9 | base_folder: 10 | train_batch_size: 1 11 | val_batch_size: 1 12 | ratio: 0.75 13 | 14 | model: 15 | max_depth_plane: 64 16 | max_depth_value: 0.8 17 | num_bins: 5 18 | 19 | record: 20 | loss_freq: 50 21 | eval_freq: 5000 22 | save_freq: 10000 23 | 24 | restore_ckpt: None 25 | depth_warmup_ckpt: "./pretrain_ckpt/depth_warmup.pth" 26 | intensity_warmup_ckpt: "./pretrain_ckpt/intensity_warmup.pth" 27 | pretrain_ckpt: "./pretrain_ckpt/pretrain_evggs.pth" 28 | 29 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/configs/default.yaml -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: evggs 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/cloud/conda-forge/ 6 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/pkgs/free/ 7 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/pkgs/main/ 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_kmp_llvm 12 | - blas=1.0=mkl 13 | - brotli-python=1.0.9=py37hd23a5d3_7 14 | - bzip2=1.0.8=hd590300_5 15 | - ca-certificates=2023.11.17=hbcca054_0 16 | - certifi=2023.11.17=pyhd8ed1ab_0 17 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 18 | - colorama=0.4.6=pyhd8ed1ab_0 19 | - cudatoolkit=11.6.2=hfc3e2af_12 20 | - ffmpeg=4.3=hf484d3e_0 21 | - freetype=2.12.1=h267a509_2 22 | - gmp=6.3.0=h59595ed_0 23 | - gnutls=3.6.13=h85f3911_1 24 | - icu=73.2=h59595ed_0 25 | - idna=3.6=pyhd8ed1ab_0 26 | - jpeg=9e=h0b41bf4_3 27 | - lame=3.100=h166bdaf_1003 28 | - lcms2=2.14=h6ed2654_0 29 | - ld_impl_linux-64=2.40=h41732ed_0 30 | - lerc=4.0.0=h27087fc_0 31 | - libdeflate=1.14=h166bdaf_0 32 | - libffi=3.3=h58526e2_2 33 | - libgcc-ng=13.2.0=h807b86a_3 34 | - libhwloc=2.9.3=default_h554bfaf_1009 35 | - libiconv=1.17=hd590300_1 36 | - libpng=1.6.39=h753d276_0 37 | - libsqlite=3.44.2=h2797004_0 38 | - libstdcxx-ng=13.2.0=h7e041cc_3 39 | - libtiff=4.4.0=h82bc61c_5 40 | - libwebp-base=1.3.2=hd590300_0 41 | - libxcb=1.13=h7f98852_1004 42 | - libxml2=2.11.6=h232c23b_0 43 | - libzlib=1.2.13=hd590300_5 44 | - llvm-openmp=17.0.6=h4dfa4b3_0 45 | - mkl=2021.4.0=h8d4b97c_729 46 | - mkl-service=2.4.0=py37h402132d_0 47 | - mkl_fft=1.3.1=py37h3e078e5_1 48 | - mkl_random=1.2.2=py37h219a48f_0 49 | - ncurses=6.4=h59595ed_2 50 | - nettle=3.6=he412f7d_0 51 | - numpy=1.21.5=py37h6c91a56_3 52 | - numpy-base=1.21.5=py37ha15fc14_3 53 | - openh264=2.1.1=h780b84a_0 54 | - openjpeg=2.5.0=h7d73246_1 55 | - openssl=1.1.1w=hd590300_0 56 | - pip=22.3.1=pyhd8ed1ab_0 57 | - plyfile=0.8.1=pyhd8ed1ab_0 58 | - pthread-stubs=0.4=h36c2ea0_1001 59 | - pysocks=1.7.1=py37h89c1867_5 60 | - python=3.7.13=haa1d7c7_1 61 | - python_abi=3.7=2_cp37m 62 | - 
pytorch=1.12.1=py3.7_cuda11.6_cudnn8.3.2_0 63 | - pytorch-mutex=1.0=cuda 64 | - readline=8.2=h8228510_1 65 | - requests=2.31.0=pyhd8ed1ab_0 66 | - setuptools=68.2.2=pyhd8ed1ab_0 67 | - six=1.16.0=pyh6c4a22f_0 68 | - sqlite=3.44.2=h2c6b66d_0 69 | - tbb=2021.11.0=h00ab1b0_0 70 | - tk=8.6.13=noxft_h4845f30_101 71 | - torchaudio=0.12.1=py37_cu116 72 | - torchvision=0.13.1=py37_cu116 73 | - tqdm=4.66.1=pyhd8ed1ab_0 74 | - typing_extensions=4.7.1=pyha770c72_0 75 | - urllib3=2.1.0=pyhd8ed1ab_0 76 | - wheel=0.42.0=pyhd8ed1ab_0 77 | - xorg-libxau=1.0.11=hd590300_0 78 | - xorg-libxdmcp=1.1.3=h7f98852_0 79 | - xz=5.2.6=h166bdaf_0 80 | - zlib=1.2.13=hd590300_5 81 | - zstd=1.5.5=hfc55251_0 82 | - pip: 83 | - absl-py==2.0.0 84 | - addict==2.4.0 85 | - ansi2html==1.9.1 86 | - attrs==23.2.0 87 | - backcall==0.2.0 88 | - cachetools==5.3.2 89 | - click==8.1.7 90 | - comm==0.1.4 91 | - configargparse==1.7 92 | - cycler==0.11.0 93 | - dash==2.14.2 94 | - dash-core-components==2.0.0 95 | - dash-html-components==2.0.0 96 | - dash-table==5.0.0 97 | - decorator==5.1.1 98 | - diff-gaussian-rasterization==0.0.0 99 | - fastjsonschema==2.19.1 100 | - flask==2.2.5 101 | - fonttools==4.38.0 102 | - google-auth==2.25.2 103 | - google-auth-oauthlib==0.4.6 104 | - grpcio==1.60.0 105 | - h5py==3.8.0 106 | - imageio==2.31.2 107 | - importlib-metadata==6.7.0 108 | - importlib-resources==5.12.0 109 | - ipython==7.34.0 110 | - ipywidgets==8.1.1 111 | - itsdangerous==2.1.2 112 | - jedi==0.19.1 113 | - jinja2==3.1.2 114 | - joblib==1.3.2 115 | - jsonschema==4.17.3 116 | - jupyter-core==4.12.0 117 | - jupyterlab-widgets==3.0.9 118 | - kiwisolver==1.4.5 119 | - lpips==0.1.4 120 | - markdown==3.4.4 121 | - markupsafe==2.1.3 122 | - matplotlib==3.5.3 123 | - matplotlib-inline==0.1.6 124 | - natsort==8.4.0 125 | - nbformat==5.7.0 126 | - nest-asyncio==1.5.8 127 | - oauthlib==3.2.2 128 | - open3d==0.17.0 129 | - opencv-python==4.9.0.80 130 | - packaging==23.2 131 | - pandas==1.3.5 132 | - parso==0.8.3 133 | - pexpect==4.9.0 134 | - pickleshare==0.7.5 135 | - pillow==9.5.0 136 | - pkgutil-resolve-name==1.3.10 137 | - plotly==5.18.0 138 | - prompt-toolkit==3.0.43 139 | - protobuf==3.20.3 140 | - ptyprocess==0.7.0 141 | - pyasn1==0.5.1 142 | - pyasn1-modules==0.3.0 143 | - pygments==2.17.2 144 | - pyparsing==3.1.1 145 | - pyquaternion==0.9.9 146 | - pyrsistent==0.19.3 147 | - python-dateutil==2.8.2 148 | - pytz==2023.3.post1 149 | - pyyaml==6.0.1 150 | - requests-oauthlib==1.3.1 151 | - retrying==1.3.4 152 | - rsa==4.9 153 | - scikit-learn==1.0.2 154 | - scipy==1.7.3 155 | - simple-knn==0.0.0 156 | - tenacity==8.2.3 157 | - tensorboard==2.11.2 158 | - tensorboard-data-server==0.6.1 159 | - tensorboard-plugin-wit==1.8.1 160 | - threadpoolctl==3.1.0 161 | - traitlets==5.9.0 162 | - wcwidth==0.2.12 163 | - werkzeug==2.2.3 164 | - widgetsnbextension==4.0.9 165 | - yacs==0.1.8 166 | - zipp==3.15.0 167 | -------------------------------------------------------------------------------- /eval_gs.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg, args 2 | from lib.dataset import EventDataloader 3 | from lib.recorder import Logger, file_backup 4 | from lib.network import model_loss_light, model_loss, EventGaussian 5 | from lib.renderer import pts2render, depth2pc 6 | from lib.utils import depth2img 7 | from lib.losses import psnr, ssim 8 | import numpy as np 9 | import imageio 10 | import cv2 11 | import os 12 | from pathlib import Path 13 | from tqdm import tqdm 14 | import logging 
15 | import torch 16 | from torch import optim 17 | from torch.utils.data import DataLoader 18 | import torch.nn.functional as F 19 | import lpips 20 | import time 21 | 22 | cs = cfg.cs 23 | 24 | class Trainer: 25 | def __init__(self) -> None: 26 | device = torch.device('cuda:{}'.format(cfg.local_rank)) 27 | self.device = device 28 | which_test = "val" 29 | 30 | self.train_loader = None 31 | self.val_loader = EventDataloader(cfg.dataset.base_folder, split=which_test, num_workers=1,\ 32 | batch_size=1, shuffle=False) 33 | 34 | self.len_val = len(self.val_loader) 35 | self.model = EventGaussian().to(self.device) 36 | # self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=cfg.wdecay, eps=1e-8) 37 | dpt_params = list(map(id,self.model.depth_estimator.parameters())) + list(map(id,self.model.intensity_estimator.parameters())) 38 | rest_params = filter(lambda x:id(x) not in dpt_params,self.model.parameters()) 39 | self.optimizer = optim.Adam([ 40 | {'params':self.model.depth_estimator.parameters(), 'lr':1}, 41 | {'params':self.model.intensity_estimator.parameters(), 'lr':1}, 42 | {'params':rest_params, 'lr':1}, 43 | ], lr=0.001, weight_decay=cfg.wdecay, eps=1e-8) 44 | self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=10000, gamma=0.9) 45 | self.logger = Logger(self.scheduler, cfg.record) 46 | 47 | self.total_steps = 0 48 | self.target_epoch = cfg.target_epoch 49 | 50 | if cfg.restore_ckpt: 51 | self.load_ckpt(cfg.restore_ckpt) 52 | 53 | def to_cuda(self, batch): 54 | for k in batch: 55 | if isinstance(batch[k], tuple) or isinstance(batch[k], list): 56 | batch[k] = [b.to(self.device) for b in batch[k]] 57 | elif isinstance(batch[k], dict): 58 | batch[k] = {key: self.to_cuda(batch[k][key]) for key in batch[k]} 59 | else: 60 | batch[k] = batch[k].to(self.device) 61 | return batch 62 | 63 | def run_eval(self): 64 | print(f"Doing validation ...") 65 | torch.cuda.empty_cache() 66 | loss_fn_vgg = lpips.LPIPS(net='vgg').to(self.device) 67 | l1_list = [] 68 | psnr_list = [] 69 | ssim_list = [] 70 | lpips_list = [] 71 | show_idx = list(range(self.len_val)) 72 | count = 0 73 | scene_num = 0 74 | os.makedirs(r'%s/%d/' % (cfg.record.show_path, scene_num), exist_ok=True) 75 | for idx, batch in enumerate(tqdm(self.val_loader)): 76 | if count == 201: 77 | scene_num += 1 78 | count = 0 79 | os.makedirs(r'%s/%d/' % (cfg.record.show_path, scene_num), exist_ok=True) 80 | with torch.no_grad(): 81 | batch = self.to_cuda(batch) 82 | gt = batch["cim"] 83 | 84 | batch["left_event_tensor"] = torch.cat([batch["leframe"], batch["left_voxel"]], dim=1) 85 | batch["right_event_tensor"] = torch.cat([batch["reframe"], batch["right_voxel"]], dim=1) 86 | 87 | start_time = time.time() 88 | data = self.model(batch) 89 | 90 | data["target"] = {"H":batch["H"], 91 | "W":batch["W"], 92 | "FovX":batch["FovX"], 93 | "FovY":batch["FovY"], 94 | 'world_view_transform': batch["world_view_transform"], 95 | 'full_proj_transform': batch["full_proj_transform"], 96 | 'camera_center': batch["camera_center"]} 97 | 98 | data["lview"]["pts"] = depth2pc(data["lview"]["depth"], torch.inverse(batch["lpose"]), batch["intrinsic"]) 99 | data["rview"]["pts"] = depth2pc(data["rview"]["depth"], torch.inverse(batch["rpose"]), batch["intrinsic"]) 100 | 101 | pred = pts2render(data, [0.,0.,0.])[:,0] 102 | end_time = time.time() 103 | execution_time = end_time - start_time 104 | 105 | pred = pred[:,None] 106 | loss = F.l1_loss(pred.squeeze(), gt.squeeze()) 107 | l1_list.append(loss.item()) 108 | 109 | count += 1 110 | # if 
idx == show_idx: 111 | psnr_list.append(torch.mean(psnr(pred, gt)).item()) 112 | ssim_list.append(ssim(pred, gt).item()) 113 | lpips_list.append(torch.mean(loss_fn_vgg(pred*2-1, gt*2-1)).item()) 114 | if idx in show_idx: 115 | tmp_gt = (gt[0]*255.0).cpu().numpy().astype(np.uint8).squeeze() 116 | tmp_pred = (pred[0]*255.0).cpu().numpy().astype(np.uint8).squeeze() 117 | tmp_img_name = '%s/%d/step%s_idx%d.jpg' % (cfg.record.show_path, scene_num, self.total_steps, idx) 118 | imageio.imsave(tmp_img_name, np.concatenate([tmp_pred, tmp_gt], axis=0)) 119 | 120 | val_psnr = np.round(np.mean(np.array(psnr_list)), 8) 121 | val_ssim = np.round(np.mean(np.array(ssim_list)), 8) 122 | val_lpips = np.round(np.mean(np.array(lpips_list)), 8) 123 | print(f"Non masked and selected Metrics ({self.total_steps}):, psnr {val_psnr}, ssim {val_ssim}, lpips {val_lpips}") 124 | self.logger.write_dict({'NO masked psnr on val set': val_psnr}, write_step=self.total_steps) 125 | torch.cuda.empty_cache() 126 | 127 | def save_ckpt(self, save_path, show_log=True): 128 | if show_log: 129 | print(f"Save checkpoint to {save_path} ...") 130 | 131 | torch.save({ 132 | 'total_steps': self.total_steps, 133 | 'network': self.model.state_dict(), 134 | 'optimizer': self.optimizer.state_dict(), 135 | 'scheduler': self.scheduler.state_dict() 136 | }, save_path) 137 | 138 | def load_ckpt(self, load_path, load_optimizer=True, strict=True): 139 | assert os.path.exists(load_path) 140 | print(f"Loading checkpoint from {load_path} ...") 141 | ckpt = torch.load(load_path, map_location='cuda') 142 | 143 | self.model.load_state_dict(ckpt['network'], strict=strict) 144 | print(f"Parameter loading done") 145 | if load_optimizer: 146 | self.total_steps = ckpt['total_steps'] + 1 147 | self.logger.total_steps = self.total_steps 148 | self.optimizer.load_state_dict(ckpt['optimizer']) 149 | self.scheduler.load_state_dict(ckpt['scheduler']) 150 | print(f"Optimizer loading done") 151 | 152 | if __name__ == "__main__": 153 | trainer = Trainer() 154 | trainer.load_ckpt(cfg.pretrain_ckpt, load_optimizer=False) 155 | trainer.model.eval() 156 | trainer.run_eval() 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /lib/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/recorder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/__pycache__/recorder.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/recorder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/__pycache__/recorder.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/__pycache__/utils.cpython-37.pyc 
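A minimal usage sketch (not part of the repository): `eval_gs.py` and `train_gs.py` read their settings through the yacs-based config system in `lib/config/config.py`, which accepts a `--cfg_file` argument (default `configs/Ev3D_pretrain.yaml`) and merges any trailing KEY VALUE pairs into the config via `merge_from_list`. Assuming `dataset.base_folder` has already been filled in the YAML, overrides such as the experiment name, learning rate, or checkpoint path could be passed on the command line as below; the values shown are placeholders, not prescribed settings.
```
# Hypothetical invocations; values and paths are placeholders.
python train_gs.py --cfg_file configs/Ev3D_pretrain.yaml exp_name my_run lr 0.0003
python eval_gs.py --cfg_file configs/Ev3D_pretrain.yaml pretrain_ckpt ./pretrain_ckpt/pretrain_evggs.pth
```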
-------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import cfg, args -------------------------------------------------------------------------------- /lib/config/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/yacs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/yacs.cpython-310.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/yacs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/yacs.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/yacs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/yacs.cpython-39.pyc -------------------------------------------------------------------------------- /lib/config/config.py: -------------------------------------------------------------------------------- 1 | from .yacs import CfgNode as CN 2 | import 
argparse 3 | import os 4 | import numpy as np 5 | from . import yacs 6 | from datetime import datetime 7 | from pathlib import Path 8 | 9 | cfg = CN() 10 | cfg.task = 'hello' 11 | cfg.gpus = [0] 12 | cfg.exp_name = 'depth_pred' 13 | 14 | cfg.record = CN() 15 | 16 | def parse_cfg(cfg, args): 17 | if len(cfg.task) == 0: 18 | raise ValueError('task must be specified') 19 | 20 | # assign the gpus 21 | # if -1 not in cfg.gpus: 22 | # os.environ['CUDA_VISIBLE_DEVICES'] = ', '.join([str(gpu) for gpu in cfg.gpus]) 23 | 24 | if 'bbox' in cfg: 25 | bbox = np.array(cfg.bbox).reshape((2, 3)) 26 | center, half_size = np.mean(bbox, axis=0), (bbox[1]-bbox[0]).max().item() / 2. 27 | bbox = np.stack([center-half_size, center+half_size]) 28 | cfg.bbox = bbox.reshape(6).tolist() 29 | 30 | print('EXP NAME: ', cfg.exp_name) 31 | 32 | cfg.local_rank = args.local_rank 33 | 34 | modules = [key for key in cfg if '_module' in key] 35 | for module in modules: 36 | cfg[module.replace('_module', '_path')] = cfg[module].replace('.', '/') + '.py' 37 | 38 | def make_cfg(args): 39 | def merge_cfg(cfg_file, cfg): 40 | with open(cfg_file, 'r') as f: 41 | current_cfg = yacs.load_cfg(f) 42 | if 'parent_cfg' in current_cfg.keys(): 43 | cfg = merge_cfg(current_cfg.parent_cfg, cfg) 44 | cfg.merge_from_other_cfg(current_cfg) 45 | else: 46 | cfg.merge_from_other_cfg(current_cfg) 47 | print(cfg_file) 48 | return cfg 49 | cfg_ = merge_cfg(args.cfg_file, cfg) 50 | try: 51 | index = args.opts.index('other_opts') 52 | cfg_.merge_from_list(args.opts[:index]) 53 | except: 54 | cfg_.merge_from_list(args.opts) 55 | parse_cfg(cfg_, args) 56 | return cfg_ 57 | 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--cfg_file", default="configs/Ev3D_pretrain.yaml", type=str) 60 | parser.add_argument('--local_rank', type=int, default=0) 61 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER) 62 | args = parser.parse_args() 63 | 64 | cfg = make_cfg(args) 65 | 66 | dt = datetime.today() 67 | cfg.exp_name = '%s_%s%s' % (cfg.exp_name, str(dt.month).zfill(2), str(dt.day).zfill(2)) 68 | cfg.record.ckpt_path = "experiments/%s/ckpt" % cfg.exp_name 69 | cfg.record.show_path = "experiments/%s/show" % cfg.exp_name 70 | cfg.record.logs_path = "experiments/%s/logs" % cfg.exp_name 71 | cfg.record.file_path = "experiments/%s/file" % cfg.exp_name 72 | 73 | for path in [cfg.record.ckpt_path, cfg.record.show_path, cfg.record.logs_path, cfg.record.file_path]: 74 | Path(path).mkdir(exist_ok=True, parents=True) -------------------------------------------------------------------------------- /lib/config/yacs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | ############################################################################## 15 | 16 | """YACS -- Yet Another Configuration System is designed to be a simple 17 | configuration management system for academic and industrial research 18 | projects. 19 | 20 | See README.md for usage and examples. 21 | """ 22 | 23 | import copy 24 | import io 25 | import logging 26 | import os 27 | from ast import literal_eval 28 | 29 | import yaml 30 | 31 | 32 | # Flag for py2 and py3 compatibility to use when separate code paths are necessary 33 | # When _PY2 is False, we assume Python 3 is in use 34 | _PY2 = False 35 | 36 | # Filename extensions for loading configs from files 37 | _YAML_EXTS = {"", ".yaml", ".yml"} 38 | _PY_EXTS = {".py"} 39 | 40 | # py2 and py3 compatibility for checking file object type 41 | # We simply use this to infer py2 vs py3 42 | try: 43 | _FILE_TYPES = (file, io.IOBase) 44 | _PY2 = True 45 | except NameError: 46 | _FILE_TYPES = (io.IOBase,) 47 | 48 | # CfgNodes can only contain a limited set of valid types 49 | _VALID_TYPES = {tuple, list, str, int, float, bool} 50 | # py2 allow for str and unicode 51 | if _PY2: 52 | _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 53 | 54 | # Utilities for importing modules from file paths 55 | if _PY2: 56 | # imp is available in both py2 and py3 for now, but is deprecated in py3 57 | import imp 58 | else: 59 | import importlib.util 60 | 61 | logger = logging.getLogger(__name__) 62 | 63 | 64 | class CfgNode(dict): 65 | """ 66 | CfgNode represents an internal node in the configuration tree. It's a simple 67 | dict-like container that allows for attribute-based access to keys. 68 | """ 69 | 70 | IMMUTABLE = "__immutable__" 71 | DEPRECATED_KEYS = "__deprecated_keys__" 72 | RENAMED_KEYS = "__renamed_keys__" 73 | 74 | def __init__(self, init_dict=None, key_list=None): 75 | # Recursively convert nested dictionaries in init_dict into CfgNodes 76 | init_dict = {} if init_dict is None else init_dict 77 | key_list = [] if key_list is None else key_list 78 | for k, v in init_dict.items(): 79 | if type(v) is dict: 80 | # Convert dict to CfgNode 81 | init_dict[k] = CfgNode(v, key_list=key_list + [k]) 82 | else: 83 | # Check for valid leaf type or nested CfgNode 84 | _assert_with_logging( 85 | _valid_type(v, allow_cfg_node=True), 86 | "Key {} with value {} is not a valid type; valid types: {}".format( 87 | ".".join(key_list + [k]), type(v), _VALID_TYPES 88 | ), 89 | ) 90 | super(CfgNode, self).__init__(init_dict) 91 | # Manage if the CfgNode is frozen or not 92 | self.__dict__[CfgNode.IMMUTABLE] = False 93 | # Deprecated options 94 | # If an option is removed from the code and you don't want to break existing 95 | # yaml configs, you can add the full config key as a string to the set below. 96 | self.__dict__[CfgNode.DEPRECATED_KEYS] = set() 97 | # Renamed options 98 | # If you rename a config option, record the mapping from the old name to the new 99 | # name in the dictionary below. Optionally, if the type also changed, you can 100 | # make the value a tuple that specifies first the renamed key and then 101 | # instructions for how to edit the config file. 
102 | self.__dict__[CfgNode.RENAMED_KEYS] = { 103 | # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow 104 | # 'EXAMPLE.OLD.KEY': ( # A more complex example to follow 105 | # 'EXAMPLE.NEW.KEY', 106 | # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or " 107 | # + "'foo:bar' -> ('foo', 'bar')" 108 | # ), 109 | } 110 | 111 | def __getattr__(self, name): 112 | if name in self: 113 | return self[name] 114 | else: 115 | raise AttributeError(name) 116 | 117 | def __setattr__(self, name, value): 118 | if self.is_frozen(): 119 | raise AttributeError( 120 | "Attempted to set {} to {}, but CfgNode is immutable".format( 121 | name, value 122 | ) 123 | ) 124 | 125 | _assert_with_logging( 126 | name not in self.__dict__, 127 | "Invalid attempt to modify internal CfgNode state: {}".format(name), 128 | ) 129 | _assert_with_logging( 130 | _valid_type(value, allow_cfg_node=True), 131 | "Invalid type {} for key {}; valid types = {}".format( 132 | type(value), name, _VALID_TYPES 133 | ), 134 | ) 135 | 136 | self[name] = value 137 | 138 | def __str__(self): 139 | def _indent(s_, num_spaces): 140 | s = s_.split("\n") 141 | if len(s) == 1: 142 | return s_ 143 | first = s.pop(0) 144 | s = [(num_spaces * " ") + line for line in s] 145 | s = "\n".join(s) 146 | s = first + "\n" + s 147 | return s 148 | 149 | r = "" 150 | s = [] 151 | for k, v in sorted(self.items()): 152 | seperator = "\n" if isinstance(v, CfgNode) else " " 153 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 154 | attr_str = _indent(attr_str, 2) 155 | s.append(attr_str) 156 | r += "\n".join(s) 157 | return r 158 | 159 | def __repr__(self): 160 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) 161 | 162 | def dump(self): 163 | """Dump to a string.""" 164 | self_as_dict = _to_dict(self) 165 | return yaml.safe_dump(self_as_dict) 166 | 167 | def merge_from_file(self, cfg_filename): 168 | """Load a yaml config file and merge it this CfgNode.""" 169 | with open(cfg_filename, "r") as f: 170 | cfg = load_cfg(f) 171 | self.merge_from_other_cfg(cfg) 172 | 173 | def merge_from_other_cfg(self, cfg_other): 174 | """Merge `cfg_other` into this CfgNode.""" 175 | _merge_a_into_b(cfg_other, self, self, []) 176 | 177 | def merge_from_list(self, cfg_list): 178 | """Merge config (keys, values) in a list (e.g., from command line) into 179 | this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`. 
180 | """ 181 | _assert_with_logging( 182 | len(cfg_list) % 2 == 0, 183 | "Override list has odd length: {}; it must be a list of pairs".format( 184 | cfg_list 185 | ), 186 | ) 187 | root = self 188 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 189 | if root.key_is_deprecated(full_key): 190 | continue 191 | if root.key_is_renamed(full_key): 192 | root.raise_key_rename_error(full_key) 193 | key_list = full_key.split(".") 194 | d = self 195 | for subkey in key_list[:-1]: 196 | _assert_with_logging( 197 | subkey in d, "Non-existent key: {}".format(full_key) 198 | ) 199 | d = d[subkey] 200 | subkey = key_list[-1] 201 | _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key)) 202 | value = _decode_cfg_value(v) 203 | value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key) 204 | d[subkey] = value 205 | 206 | def freeze(self): 207 | """Make this CfgNode and all of its children immutable.""" 208 | self._immutable(True) 209 | 210 | def defrost(self): 211 | """Make this CfgNode and all of its children mutable.""" 212 | self._immutable(False) 213 | 214 | def is_frozen(self): 215 | """Return mutability.""" 216 | return self.__dict__[CfgNode.IMMUTABLE] 217 | 218 | def _immutable(self, is_immutable): 219 | """Set immutability to is_immutable and recursively apply the setting 220 | to all nested CfgNodes. 221 | """ 222 | self.__dict__[CfgNode.IMMUTABLE] = is_immutable 223 | # Recursively set immutable state 224 | for v in self.__dict__.values(): 225 | if isinstance(v, CfgNode): 226 | v._immutable(is_immutable) 227 | for v in self.values(): 228 | if isinstance(v, CfgNode): 229 | v._immutable(is_immutable) 230 | 231 | def clone(self): 232 | """Recursively copy this CfgNode.""" 233 | return copy.deepcopy(self) 234 | 235 | def register_deprecated_key(self, key): 236 | """Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated 237 | keys a warning is generated and the key is ignored. 238 | """ 239 | _assert_with_logging( 240 | key not in self.__dict__[CfgNode.DEPRECATED_KEYS], 241 | "key {} is already registered as a deprecated key".format(key), 242 | ) 243 | self.__dict__[CfgNode.DEPRECATED_KEYS].add(key) 244 | 245 | def register_renamed_key(self, old_name, new_name, message=None): 246 | """Register a key as having been renamed from `old_name` to `new_name`. 247 | When merging a renamed key, an exception is thrown alerting to user to 248 | the fact that the key has been renamed. 
249 | """ 250 | _assert_with_logging( 251 | old_name not in self.__dict__[CfgNode.RENAMED_KEYS], 252 | "key {} is already registered as a renamed cfg key".format(old_name), 253 | ) 254 | value = new_name 255 | if message: 256 | value = (new_name, message) 257 | self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value 258 | 259 | def key_is_deprecated(self, full_key): 260 | """Test if a key is deprecated.""" 261 | if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]: 262 | logger.warning("Deprecated config key (ignoring): {}".format(full_key)) 263 | return True 264 | return False 265 | 266 | def key_is_renamed(self, full_key): 267 | """Test if a key is renamed.""" 268 | return full_key in self.__dict__[CfgNode.RENAMED_KEYS] 269 | 270 | def raise_key_rename_error(self, full_key): 271 | new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key] 272 | if isinstance(new_key, tuple): 273 | msg = " Note: " + new_key[1] 274 | new_key = new_key[0] 275 | else: 276 | msg = "" 277 | raise KeyError( 278 | "Key {} was renamed to {}; please update your config.{}".format( 279 | full_key, new_key, msg 280 | ) 281 | ) 282 | 283 | 284 | def load_cfg(cfg_file_obj_or_str): 285 | """Load a cfg. Supports loading from: 286 | - A file object backed by a YAML file 287 | - A file object backed by a Python source file that exports an attribute 288 | "cfg" that is either a dict or a CfgNode 289 | - A string that can be parsed as valid YAML 290 | """ 291 | _assert_with_logging( 292 | isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)), 293 | "Expected first argument to be of type {} or {}, but it was {}".format( 294 | _FILE_TYPES, str, type(cfg_file_obj_or_str) 295 | ), 296 | ) 297 | if isinstance(cfg_file_obj_or_str, str): 298 | return _load_cfg_from_yaml_str(cfg_file_obj_or_str) 299 | elif isinstance(cfg_file_obj_or_str, _FILE_TYPES): 300 | return _load_cfg_from_file(cfg_file_obj_or_str) 301 | else: 302 | raise NotImplementedError("Impossible to reach here (unless there's a bug)") 303 | 304 | 305 | def _load_cfg_from_file(file_obj): 306 | """Load a config from a YAML file or a Python source file.""" 307 | _, file_extension = os.path.splitext(file_obj.name) 308 | if file_extension in _YAML_EXTS: 309 | return _load_cfg_from_yaml_str(file_obj.read()) 310 | elif file_extension in _PY_EXTS: 311 | return _load_cfg_py_source(file_obj.name) 312 | else: 313 | raise Exception( 314 | "Attempt to load from an unsupported file type {}; " 315 | "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS)) 316 | ) 317 | 318 | 319 | def _load_cfg_from_yaml_str(str_obj): 320 | """Load a config from a YAML string encoding.""" 321 | cfg_as_dict = yaml.safe_load(str_obj) 322 | return CfgNode(cfg_as_dict) 323 | 324 | 325 | def _load_cfg_py_source(filename): 326 | """Load a config from a Python source file.""" 327 | module = _load_module_from_file("yacs.config.override", filename) 328 | _assert_with_logging( 329 | hasattr(module, "cfg"), 330 | "Python module from file {} must have 'cfg' attr".format(filename), 331 | ) 332 | VALID_ATTR_TYPES = {dict, CfgNode} 333 | _assert_with_logging( 334 | type(module.cfg) in VALID_ATTR_TYPES, 335 | "Imported module 'cfg' attr must be in {} but is {} instead".format( 336 | VALID_ATTR_TYPES, type(module.cfg) 337 | ), 338 | ) 339 | if type(module.cfg) is dict: 340 | return CfgNode(module.cfg) 341 | else: 342 | return module.cfg 343 | 344 | 345 | def _to_dict(cfg_node): 346 | """Recursively convert all CfgNode objects to dict objects.""" 347 | 348 | def convert_to_dict(cfg_node, key_list): 349 | if not 
isinstance(cfg_node, CfgNode): 350 | _assert_with_logging( 351 | _valid_type(cfg_node), 352 | "Key {} with value {} is not a valid type; valid types: {}".format( 353 | ".".join(key_list), type(cfg_node), _VALID_TYPES 354 | ), 355 | ) 356 | return cfg_node 357 | else: 358 | cfg_dict = dict(cfg_node) 359 | for k, v in cfg_dict.items(): 360 | cfg_dict[k] = convert_to_dict(v, key_list + [k]) 361 | return cfg_dict 362 | 363 | return convert_to_dict(cfg_node, []) 364 | 365 | 366 | def _valid_type(value, allow_cfg_node=False): 367 | return (type(value) in _VALID_TYPES) or (allow_cfg_node and type(value) == CfgNode) 368 | 369 | 370 | def _merge_a_into_b(a, b, root, key_list): 371 | """Merge config dictionary a into config dictionary b, clobbering the 372 | options in b whenever they are also specified in a. 373 | """ 374 | _assert_with_logging( 375 | isinstance(a, CfgNode), 376 | "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode), 377 | ) 378 | _assert_with_logging( 379 | isinstance(b, CfgNode), 380 | "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode), 381 | ) 382 | 383 | for k, v_ in a.items(): 384 | full_key = ".".join(key_list + [k]) 385 | # a must specify keys that are in b 386 | if k not in b: 387 | if root.key_is_deprecated(full_key): 388 | continue 389 | elif root.key_is_renamed(full_key): 390 | root.raise_key_rename_error(full_key) 391 | else: 392 | v = copy.deepcopy(v_) 393 | v = _decode_cfg_value(v) 394 | b.update({k: v}) 395 | else: 396 | v = copy.deepcopy(v_) 397 | v = _decode_cfg_value(v) 398 | v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) 399 | 400 | # Recursively merge dicts 401 | if isinstance(v, CfgNode): 402 | try: 403 | _merge_a_into_b(v, b[k], root, key_list + [k]) 404 | except BaseException: 405 | raise 406 | else: 407 | b[k] = v 408 | 409 | 410 | def _decode_cfg_value(v): 411 | """Decodes a raw config value (e.g., from a yaml config files or command 412 | line argument) into a Python object. 413 | """ 414 | # Configs parsed from raw yaml will contain dictionary keys that need to be 415 | # converted to CfgNode objects 416 | if isinstance(v, dict): 417 | return CfgNode(v) 418 | # All remaining processing is only applied to strings 419 | if not isinstance(v, str): 420 | return v 421 | # Try to interpret `v` as a: 422 | # string, number, tuple, list, dict, boolean, or None 423 | try: 424 | v = literal_eval(v) 425 | # The following two excepts allow v to pass through when it represents a 426 | # string. 427 | # 428 | # Longer explanation: 429 | # The type of v is always a string (before calling literal_eval), but 430 | # sometimes it *represents* a string and other times a data structure, like 431 | # a list. In the case that v represents a string, what we got back from the 432 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 433 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 434 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 435 | # will raise a SyntaxError. 436 | except ValueError: 437 | pass 438 | except SyntaxError: 439 | pass 440 | return v 441 | 442 | 443 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): 444 | """Checks that `replacement`, which is intended to replace `original` is of 445 | the right type. The type is correct if it matches exactly or is one of a few 446 | cases in which the type can be easily coerced. 
447 | """ 448 | original_type = type(original) 449 | replacement_type = type(replacement) 450 | 451 | # The types must match (with some exceptions) 452 | if replacement_type == original_type: 453 | return replacement 454 | 455 | # Cast replacement from from_type to to_type if the replacement and original 456 | # types match from_type and to_type 457 | def conditional_cast(from_type, to_type): 458 | if replacement_type == from_type and original_type == to_type: 459 | return True, to_type(replacement) 460 | else: 461 | return False, None 462 | 463 | # Conditionally casts 464 | # list <-> tuple 465 | casts = [(tuple, list), (list, tuple)] 466 | # For py2: allow converting from str (bytes) to a unicode string 467 | try: 468 | casts.append((str, unicode)) # noqa: F821 469 | except Exception: 470 | pass 471 | 472 | for (from_type, to_type) in casts: 473 | converted, converted_value = conditional_cast(from_type, to_type) 474 | if converted: 475 | return converted_value 476 | 477 | raise ValueError( 478 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " 479 | "key: {}".format( 480 | original_type, replacement_type, original, replacement, full_key 481 | ) 482 | ) 483 | 484 | 485 | def _assert_with_logging(cond, msg): 486 | if not cond: 487 | logger.debug(msg) 488 | assert cond, msg 489 | 490 | 491 | def _load_module_from_file(name, filename): 492 | if _PY2: 493 | module = imp.load_source(name, filename) 494 | else: 495 | spec = importlib.util.spec_from_file_location(name, filename) 496 | module = importlib.util.module_from_spec(spec) 497 | spec.loader.exec_module(module) 498 | return module -------------------------------------------------------------------------------- /lib/dataset/Ev3D.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/home/jiaxu/jx/EvGGS/") 3 | from natsort import natsorted 4 | import open3d as o3d 5 | import h5py 6 | import os 7 | import numpy as np 8 | import torch 9 | from .utils import events_to_voxel_grid 10 | from torch.utils.data import Dataset, DataLoader 11 | import cv2 12 | import glob 13 | from tqdm import tqdm 14 | from torch.utils.data import ConcatDataset 15 | from lib.renderer.rend_utils import getProjectionMatrix, getWorld2View2, focal2fov 16 | from lib.config import cfg, args 17 | 18 | def depth2pc_np_ours(depth, extrinsic, intrinsic, isdisparity=False): 19 | H, W = depth.shape 20 | x_ref, y_ref = np.meshgrid(np.arange(0, W), np.arange(0, H)) 21 | x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) 22 | 23 | xyz_ref = np.matmul(np.linalg.inv(intrinsic[:3, :3]), 24 | np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth.reshape([-1])) 25 | xyz_world = np.matmul(np.linalg.inv(extrinsic), np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] 26 | 27 | return xyz_world.transpose((1, 0)).astype(np.float32) 28 | 29 | def find_files(dir, exts): 30 | if os.path.isdir(dir): 31 | files_grabbed = [] 32 | for ext in exts: 33 | files_grabbed.extend(glob.glob(os.path.join(dir, ext))) 34 | if len(files_grabbed) > 0: 35 | files_grabbed = sorted(files_grabbed) 36 | return files_grabbed 37 | else: 38 | return [] 39 | 40 | def parse_txt(filename, shape): 41 | assert os.path.isfile(filename) 42 | nums = open(filename).read().split() 43 | return np.array([float(x) for x in nums]).reshape(shape).astype(np.float32) 44 | 45 | def concatenate_datasets_ratio(base_folders, dataset_type, split, dataset_kwargs={}): 46 | scene_lists = natsorted(os.listdir(os.path.join(base_folders, 'Event'))) 47 | n_scenes = 
len(scene_lists) 48 | ratio = int(cfg.dataset.ratio * n_scenes) 49 | 50 | if split == "train": 51 | scene_lists = scene_lists[:ratio] 52 | elif split == "val": 53 | scene_lists = scene_lists[ratio:] 54 | 55 | dataset_list = [] 56 | for i in range(len(scene_lists)): 57 | dataset_list.append(dataset_type(base_folders, scene_lists[i], **dataset_kwargs)) 58 | return ConcatDataset(dataset_list) 59 | 60 | def concatenate_datasets_split(base_folders, dataset_type, split, dataset_kwargs={}): 61 | if split == "train": 62 | scenes_path = os.path.join(base_folders, "train_scenes.txt") 63 | elif split == "test": 64 | scenes_path = os.path.join(base_folders, "test_scenes.txt") 65 | elif split == "val": 66 | scenes_path = os.path.join(base_folders, "val_scenes.txt") 67 | 68 | with open(scenes_path, 'r', encoding='utf-8') as file: 69 | scene_lists = [line.strip() for line in file.readlines()] 70 | 71 | dataset_list = [] 72 | for i in range(len(scene_lists)): 73 | dataset_list.append(dataset_type(base_folders, scene_lists[i], **dataset_kwargs)) 74 | return ConcatDataset(dataset_list) 75 | 76 | T = np.array([[1,0,0,0], 77 | [0,-1,0,0], 78 | [0,0,-1,0], 79 | [0,0,0,1]]) 80 | 81 | class EventDataloader(DataLoader): 82 | def __init__(self, base_folders, split, num_workers, batch_size, shuffle=True): 83 | dataset = concatenate_datasets_split(base_folders, ReadEventFromH5, split=split) 84 | super().__init__(dataset, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle) 85 | 86 | cs = cfg.cs 87 | 88 | class ReadEventFromH5(Dataset): 89 | def __init__(self, base_folder, scene, polarity_offset=0): 90 | self.base_folder = base_folder 91 | self.scene = scene 92 | self.polarity_offset = polarity_offset 93 | self.H, self.W = 480, 640 94 | self.cropped_H, self.cropped_W = cs[1]-cs[0], cs[3] - cs[2] 95 | self.event_slices() 96 | 97 | def event_slices(self): 98 | ## load .h5 event files and generate event frames and voxels 99 | self.event_files_path = os.path.join(self.base_folder, "Event", self.scene) 100 | scene_files_path = os.path.join(self.base_folder, "Scenes", self.scene) 101 | self.pose_files = find_files('{}/Poses'.format(self.base_folder), exts=['*.txt']) 102 | self.num_views = len(self.pose_files) 103 | intrinsic_files = find_files('{}/Intrinsics'.format(self.base_folder), exts=['*.txt'])[:self.num_views] 104 | self.npz_files = find_files('{}'.format(scene_files_path), exts=["*.npz"])[:self.num_views] 105 | self.rgb_files = find_files('{}'.format(scene_files_path), exts=['*.png'])[:self.num_views] 106 | 107 | self.intrinsics = parse_txt(intrinsic_files[0], (4, 4)) 108 | 109 | def __len__(self): 110 | return len(self.pose_files) 111 | 112 | def events_to_voxel(self, events): 113 | # generate a voxel grid from input events using temporal bilinear interpolation. 114 | x, y, t, p = events 115 | x = x.astype(np.int32) 116 | y = y.astype(np.int32) 117 | p = p.astype(np.int32) 118 | mask_pos = p.copy() 119 | mask_neg = p.copy() 120 | mask_pos[p < 0] = 0 121 | mask_neg[p > 0] = 0 122 | frame1 = self.events_to_image(x, y, p * mask_pos) 123 | frame2 = self.events_to_image(x, y, p * mask_neg) 124 | frame3 = frame1 - frame2 125 | # cv2.imwrite('1.png', 128+frame1) 126 | # cv2.imwrite('2.png', 128-frame2) 127 | # cv2.imwrite('3.png', 128+frame3) 128 | return np.stack(((128+frame1)/255, (128-frame2)/255, (128 + frame3)/255), axis=2) 129 | 130 | def events_to_image(self, xs, ys, ps): 131 | # accumulate events into an image. 
132 | img = np.zeros((self.H, self.W)) 133 | np.add.at(img, (ys, xs), ps) 134 | 135 | # img = np.clip(img, -5, 5) 136 | # print(img) 137 | return img 138 | 139 | def find_depth(self, npz_files, idx): 140 | npz = np.load(npz_files[idx], allow_pickle=True) 141 | depth = npz['depth_map'] 142 | depth = self.prepare_depth(depth) 143 | return depth 144 | 145 | def find_pose(self, npz_files, idx): 146 | npz = np.load(npz_files[idx], allow_pickle=True) 147 | poses = npz['object_poses'] 148 | for obj in poses: 149 | obj_name = obj['name'] 150 | obj_mat = obj['pose'] 151 | if obj_name == 'Camera': 152 | pose = obj_mat.astype(np.float32) 153 | break 154 | return pose @ T 155 | 156 | def prepare_depth(self, depth): 157 | # adjust depth maps generated by vision blender 158 | INVALID_DEPTH = -1 159 | depth[depth == INVALID_DEPTH] = 0 160 | 161 | return depth 162 | 163 | def accumulate_events_edited(self, events): 164 | x, y, t, p = events 165 | 166 | def events_to_frame(self, events): 167 | # generate a voxel grid from input events using temporal bilinear interpolation. 168 | x, y, t, p = events 169 | x = x.astype(np.int32) 170 | y = y.astype(np.int32) 171 | p = p.astype(np.int32) 172 | mask_pos = p.copy() 173 | mask_neg = p.copy() 174 | mask_pos[p < 0] = 0 175 | mask_neg[p > 0] = 0 176 | frame1 = self.events_to_image(x, y, p * mask_pos) 177 | frame2 = self.events_to_image(x, y, p * mask_neg) 178 | frame3 = frame1 - frame2 179 | 180 | return np.stack(((128 + frame1)/255, (128-frame2)/255, (128 + frame3)/255), axis=2) 181 | 182 | def events_to_image(self, xs, ys, ps): 183 | # accumulate events into an image. 184 | img = np.zeros((self.H, self.W)) 185 | np.add.at(img, (ys, xs), ps) 186 | # print(img.max(), img.min()) 187 | # img = np.clip(img, -5, 5) 188 | # print(img) 189 | return img 190 | 191 | def accumulate_events(self, events, resolution_level=1, polarity_offset=0): 192 | x, y, t, p = events 193 | acc_frm = np.zeros((self.H, self.W)) 194 | np.add.at(acc_frm, (y // resolution_level, x // resolution_level), p + polarity_offset) 195 | return acc_frm 196 | 197 | def __getitem__(self, idx): 198 | index = str(idx).zfill(4) 199 | left_event1 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(index))) 200 | 201 | # left_event_voxel = events_to_voxel_grid(left_event.transpose((1,0)), cfg.model.num_bins, self.W, self.H) 202 | # left_pose = parse_txt(self.pose_files[idx], (4,4)) 203 | 204 | left_pose = self.find_pose(self.npz_files, idx) 205 | left_depth_gt = self.find_depth(self.npz_files, idx) 206 | left_mask = (left_depth_gt > 0) 207 | left_img = cv2.cvtColor(cv2.imread(self.rgb_files[idx])[...,:3] * left_mask[..., np.newaxis], cv2.COLOR_BGR2GRAY) / 255. 208 | 209 | if idx + 1 < len(self.pose_files): 210 | left_event2 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(str(idx+1).zfill(4)))) 211 | center_depth_gt = self.find_depth(self.npz_files, idx+1) 212 | center_mask = (center_depth_gt > 0) 213 | # center_pose = parse_txt(self.pose_files[idx+1], (4,4)) 214 | center_pose = self.find_pose(self.npz_files, idx+1) 215 | 216 | # try: 217 | int_img = cv2.cvtColor(cv2.imread(self.rgb_files[idx+1])[...,:3] * center_mask[..., np.newaxis], cv2.COLOR_BGR2GRAY) / 255. 
218 | 219 | else: 220 | left_event2 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(str(0).zfill(4)))) 221 | center_depth_gt = self.find_depth(self.npz_files, 0) 222 | center_mask = (center_depth_gt > 0) 223 | # center_pose = parse_txt(self.pose_files[0], (4,4)) 224 | center_pose = self.find_pose(self.npz_files, 0) 225 | int_img = cv2.cvtColor(cv2.imread(self.rgb_files[0])[...,:3] * center_mask[..., np.newaxis], cv2.COLOR_BGR2GRAY) /255. 226 | 227 | center_extrinsics = np.linalg.inv(center_pose) 228 | 229 | left_event_frame = self.events_to_frame(np.hstack((left_event1, left_event2))) 230 | left_event_voxel = events_to_voxel_grid(np.hstack((left_event1, left_event2)).transpose((1,0)), cfg.model.num_bins, self.W, self.H) 231 | if idx + 2 < len(self.pose_files): 232 | r_id = idx + 2 233 | r_index = str(r_id).zfill(4) 234 | right_event1 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(r_index))) 235 | # right_pose = parse_txt(self.pose_files[r_id], (4,4)) 236 | right_pose = self.find_pose(self.npz_files, r_id) 237 | right_depth_gt = self.find_depth(self.npz_files, r_id) 238 | else: 239 | r_id = (idx + 2) % len(self.pose_files) 240 | r_index = str(r_id).zfill(4) 241 | right_event1 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(r_index))) 242 | # right_pose = parse_txt(self.pose_files[r_id], (4,4)) 243 | right_pose = self.find_pose(self.npz_files, r_id) 244 | right_depth_gt = self.find_depth(self.npz_files, r_id) 245 | 246 | if idx + 3 < len(self.pose_files): 247 | r_id2 = idx + 3 248 | r_index2 = str(r_id2).zfill(4) 249 | right_event2 = np.load(os.path.join(self.event_files_path,'{}.npy'.format(r_index2))) 250 | else: 251 | r_id2 = (idx + 3) % len(self.pose_files) 252 | r_index2 = str(r_id2).zfill(4) 253 | right_event2 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(r_index2))) 254 | 255 | 256 | # pr = depth2pc_np_ours(right_depth_gt, np.linalg.inv(right_pose), self.intrinsics) 257 | # pl = depth2pc_np_ours(left_depth_gt, np.linalg.inv(left_pose), self.intrinsics) 258 | # pc = np.concatenate([pr, pl], axis=0) 259 | # pcd = o3d.geometry.PointCloud() 260 | # pcd.points = o3d.utility.Vector3dVector(pc) 261 | # o3d.io.write_point_cloud("pts.ply", pcd) 262 | # print(np.hstack((right_event1, right_event2)).shape) 263 | right_event_frame = self.events_to_frame(np.hstack((right_event1, right_event2))) 264 | right_event_voxel = events_to_voxel_grid(np.hstack((right_event1, right_event2)).transpose((1,0)), cfg.model.num_bins, self.W, self.H) 265 | # right_event_voxel = events_to_voxel_grid(right_event.transpose((1,0)), cfg.model.num_bins, self.W, self.H) 266 | 267 | right_mask = (right_depth_gt > 0) 268 | right_img = cv2.cvtColor(cv2.imread(self.rgb_files[r_id])[...,:3] * right_mask[..., np.newaxis], cv2.COLOR_BGR2GRAY) / 255. 
269 | 270 | center_event = np.hstack((left_event2, right_event1)) 271 | center_event_frame = self.events_to_frame(center_event) 272 | center_event_voxel = events_to_voxel_grid(center_event.transpose((1,0)), cfg.model.num_bins, self.W, self.H) 273 | 274 | intrinsic = self.intrinsics 275 | 276 | intrinsic[0,2] = (self.cropped_W - 1) / 2 277 | intrinsic[1,2] = (self.cropped_H - 1) / 2 278 | 279 | projection_matrix = getProjectionMatrix(znear=0.01, zfar=0.99, K=intrinsic, h=self.cropped_H, w=self.cropped_W).transpose(0, 1) 280 | world_view_transform = torch.tensor(getWorld2View2(center_extrinsics[:3,:3].reshape(3, 3).transpose(1, 0)\ 281 | , center_extrinsics[:3, 3])).transpose(0, 1) 282 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0) 283 | camera_center = world_view_transform.inverse()[3, :3] 284 | 285 | item = { 286 | 'cim': int_img.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]][np.newaxis], #[1, H, W] 287 | 'lim': left_img.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]][np.newaxis], 288 | 'rim': right_img.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]][np.newaxis], 289 | 'leframe': left_event_frame.transpose((2,0,1)).astype(np.float32)[:, cs[0]:cs[1], cs[2]:cs[3]], #[3, H, W] 290 | 'reframe': right_event_frame.transpose((2,0,1)).astype(np.float32)[:, cs[0]:cs[1], cs[2]:cs[3]], 291 | 'ceframe': center_event_frame.transpose((2,0,1)).astype(np.float32)[:, cs[0]:cs[1], cs[2]:cs[3]], 292 | 'lmask': left_mask[cs[0]:cs[1], cs[2]:cs[3]], #[H, W] 293 | 'rmask': right_mask[cs[0]:cs[1], cs[2]:cs[3]], 294 | 'cmask': center_mask[cs[0]:cs[1], cs[2]:cs[3]], 295 | 'lpose': left_pose.astype(np.float32), #[4, 4] 296 | 'rpose': right_pose.astype(np.float32), 297 | 'intrinsic': intrinsic.astype(np.float32), #[4, 4] 298 | 'ldepth': left_depth_gt.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]], # #[H, W] 299 | 'rdepth': right_depth_gt.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]], 300 | 'cdepth': center_depth_gt.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]], 301 | 'center_voxel':center_event_voxel[:, cs[0]:cs[1], cs[2]:cs[3]], #[5, H, W] 302 | 'right_voxel':right_event_voxel[:, cs[0]:cs[1], cs[2]:cs[3]], 303 | 'left_voxel':left_event_voxel[:, cs[0]:cs[1], cs[2]:cs[3]], 304 | ### target view rendering parameters ### 305 | "H":self.cropped_H, 306 | "W":self.cropped_W, 307 | "FovX":focal2fov(intrinsic[0, 0], self.cropped_W), 308 | "FovY":focal2fov(intrinsic[1, 1], self.cropped_H), 309 | 'world_view_transform': world_view_transform, #[4, 4] 310 | 'full_proj_transform': full_proj_transform, #[4, 4] 311 | 'camera_center': camera_center #[3] 312 | } 313 | return item 314 | 315 | 316 | 317 | 318 | 319 | # if __name__ == "__main__": 320 | # dataset = ReadEventFromH5(r"/home/lsf_storage/dataset/EV3D5/", "AK47",0) 321 | # dataset[0] 322 | # # dataloader = EventDataloader(r"/home/lsf_storage/dataset/EV3D/", split="full", batch_size=1, num_workers=1, shuffle=False) 323 | # # print(len(dataloader)) 324 | 325 | # # for idx, batch in enumerate(dataloader): 326 | # # print(batch["ldepth"].shape) 327 | # # break -------------------------------------------------------------------------------- /lib/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg, args 2 | 3 | 4 | from .Ev3D import EventDataloader -------------------------------------------------------------------------------- /lib/dataset/__pycache__/Ev3D.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/Ev3D.cpython-310.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/Ev3D.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/Ev3D.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/Ev3D.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/Ev3D.cpython-39.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /lib/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def events_to_voxel_grid(events, num_bins, width, height): 5 | """ 6 | Build a voxel grid with bilinear interpolation in the time domain from a set of events. 
7 | 8 | :param events: a [N x 4] NumPy array containing one event per row in the form: [timestamp, x, y, polarity] 9 | :param num_bins: number of bins in the temporal axis of the voxel grid 10 | :param width, height: dimensions of the voxel grid 11 | """ 12 | 13 | assert(events.shape[1] == 4) 14 | assert(num_bins > 0) 15 | assert(width > 0) 16 | assert(height > 0) 17 | 18 | voxel_grid = np.zeros((num_bins, height, width), np.float32).ravel() 19 | 20 | # normalize the event timestamps so that they lie between 0 and num_bins 21 | last_stamp = events[-1, 2] 22 | first_stamp = events[0, 2] 23 | deltaT = last_stamp - first_stamp 24 | 25 | if deltaT == 0: 26 | deltaT = 1.0 27 | 28 | events[:, 2] = (num_bins - 1) * (events[:, 2] - first_stamp) / deltaT 29 | ts = events[:, 2] 30 | xs = events[:, 0].astype(np.int) 31 | ys = events[:, 1].astype(np.int) 32 | pols = events[:, 3] 33 | pols[pols == 0] = -1 # polarity should be +1 / -1 34 | 35 | tis = ts.astype(np.int) 36 | dts = ts - tis 37 | vals_left = pols * (1.0 - dts) 38 | vals_right = pols * dts 39 | 40 | valid_indices = tis < num_bins 41 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width 42 | + tis[valid_indices] * width * height, vals_left[valid_indices]) 43 | 44 | valid_indices = (tis + 1) < num_bins 45 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width 46 | + (tis[valid_indices] + 1) * width * height, vals_right[valid_indices]) 47 | 48 | voxel_grid = np.reshape(voxel_grid, (num_bins, height, width)) 49 | 50 | return voxel_grid 51 | 52 | 53 | -------------------------------------------------------------------------------- /lib/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | from math import exp 5 | 6 | 7 | def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9): 8 | """ Loss function defined over sequence of flow predictions """ 9 | 10 | n_predictions = len(flow_preds) 11 | flow_loss = 0.0 12 | 13 | valid = (valid >= 0.5) 14 | assert not torch.isinf(flow_gt[valid.bool()]).any() 15 | 16 | for i in range(n_predictions): 17 | # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations 18 | adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1)) 19 | i_weight = adjusted_loss_gamma**(n_predictions - i - 1) 20 | i_loss = (flow_preds[i] - flow_gt).abs() 21 | flow_loss += i_weight * i_loss[valid.bool()].mean() 22 | 23 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 24 | epe = epe.view(-1)[valid.view(-1)] 25 | 26 | metrics = { 27 | 'train_epe': epe.mean().item(), 28 | 'train_1px': (epe < 1).float().mean().item(), 29 | 'train_3px': (epe < 3).float().mean().item() 30 | } 31 | 32 | return flow_loss, metrics 33 | 34 | 35 | def l1_loss(network_output, gt): 36 | return torch.abs((network_output - gt)).mean() 37 | 38 | def mse_loss(out, gt, msk=None): 39 | if msk is None: 40 | loss = torch.mean((out - gt) ** 2) 41 | else: 42 | # loss = torch.mean((out[msk] - gt[msk]) ** 2) 43 | loss = torch.mean((out - gt) ** 2) 44 | return loss 45 | 46 | def gaussian(window_size, sigma): 47 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 48 | return gauss / gauss.sum() 49 | 50 | 51 | def create_window(window_size, channel): 52 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 53 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 54 | window = 
Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 55 | return window 56 | 57 | 58 | def ssim(img1, img2, window_size=11, size_average=True): 59 | channel = img1.size(-3) 60 | window = create_window(window_size, channel) 61 | 62 | if img1.is_cuda: 63 | window = window.cuda(img1.get_device()) 64 | window = window.type_as(img1) 65 | 66 | return _ssim(img1, img2, window, window_size, channel, size_average) 67 | 68 | 69 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 70 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 71 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 72 | 73 | mu1_sq = mu1.pow(2) 74 | mu2_sq = mu2.pow(2) 75 | mu1_mu2 = mu1 * mu2 76 | 77 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 78 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 79 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 80 | 81 | C1 = 0.01 ** 2 82 | C2 = 0.03 ** 2 83 | 84 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 85 | 86 | if size_average: 87 | return ssim_map.mean() 88 | else: 89 | return ssim_map.mean(1).mean(1).mean(1) 90 | 91 | 92 | def psnr(img1, img2): 93 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 94 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 95 | -------------------------------------------------------------------------------- /lib/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .asnet_utils import model_loss, model_loss_light 2 | from .recon_net import E2IM, E2DPT, E2Msk 3 | from .eventgaussian import EventGaussian -------------------------------------------------------------------------------- /lib/network/__pycache__/ASNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/ASNet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/ASNet_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/ASNet_utils.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/SegNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/SegNet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/asnet.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/asnet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/asnet_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/asnet_utils.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/densenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/densenet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/dfanet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/dfanet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/eventgaussian.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/eventgaussian.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/firenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/firenet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/gsregressor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/gsregressor.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/mobilenetv2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/mobilenetv2.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/pspnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/pspnet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/recon_net.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/recon_net.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/submodules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/submodules.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/swin.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/swin.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/__pycache__/unet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/unet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/network/asnet.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | import math 4 | import torch.nn as nn 5 | import torch.utils.data 6 | import torch.nn.functional as F 7 | from .asnet_utils import * 8 | 9 | 10 | class hourglass2D(nn.Module): 11 | def __init__(self, in_channels): 12 | super(hourglass2D, self).__init__() 13 | 14 | self.expanse_ratio = 2 15 | 16 | self.conv1 = MobileV2_Residual(in_channels, in_channels * 2, stride=2, expanse_ratio=self.expanse_ratio) 17 | 18 | self.conv2 = MobileV2_Residual(in_channels * 2, in_channels * 2, stride=1, expanse_ratio=self.expanse_ratio) 19 | 20 | self.conv3 = MobileV2_Residual(in_channels * 2, in_channels * 4, stride=2, expanse_ratio=self.expanse_ratio) 21 | 22 | self.conv4 = MobileV2_Residual(in_channels * 4, in_channels * 4, stride=1, expanse_ratio=self.expanse_ratio) 23 | 24 | self.conv5 = nn.Sequential( 25 | nn.ConvTranspose2d(in_channels * 4, in_channels * 2, 3, padding=1, output_padding=1, stride=2, bias=False), 26 | nn.BatchNorm2d(in_channels * 2)) 27 | 28 | self.conv6 = nn.Sequential( 29 | nn.ConvTranspose2d(in_channels * 2, in_channels, 3, padding=1, output_padding=1, stride=2, bias=False), 30 | nn.BatchNorm2d(in_channels)) 31 | 32 | self.redir1 = MobileV2_Residual(in_channels, in_channels, stride=1, expanse_ratio=self.expanse_ratio) 33 | self.redir2 = MobileV2_Residual(in_channels * 2, in_channels * 2, stride=1, expanse_ratio=self.expanse_ratio) 34 | 35 | def forward(self, x): 36 | conv1 = self.conv1(x) 37 | conv2 = self.conv2(conv1) 38 | 39 | conv3 = self.conv3(conv2) 40 | conv4 = self.conv4(conv3) 41 | 42 | conv5 = F.relu(self.conv5(conv4) + self.redir2(conv2), inplace=True) 43 | conv6 = F.relu(self.conv6(conv5) + self.redir1(x), inplace=True) 44 | 45 | return conv6 46 | 47 | 48 | class ASNet(nn.Module): 49 | def __init__(self, maxdisp): 50 | 51 | super(ASNet, self).__init__() 52 | 53 | self.maxdisp = maxdisp 54 | 55 | self.num_groups = 1 56 | 57 | self.volume_size = 48 58 | 59 | self.hg_size = 32 60 | 61 | self.dres_expanse_ratio = 3 62 | 63 | self.feature_extraction0 = feature_extraction() 64 | 65 | 66 | self.volume1 = volume_build(self.volume_size) 67 | 68 | 69 | self.dres0 = 
nn.Sequential(MobileV2_Residual(self.volume_size, self.hg_size, 1, self.dres_expanse_ratio), 70 | nn.ReLU(inplace=True), 71 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio), 72 | nn.ReLU(inplace=True)) 73 | 74 | self.dres1 = nn.Sequential(MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio), 75 | nn.ReLU(inplace=True), 76 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio)) 77 | 78 | self.encoder_decoder1 = hourglass2D(self.hg_size) 79 | self.encoder_decoder2 = hourglass2D(self.hg_size) 80 | self.encoder_decoder3 = hourglass2D(self.hg_size) 81 | 82 | 83 | self.classif0 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1, 86 | bias=False, dilation=1)) 87 | self.classif1 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1), 88 | nn.ReLU(inplace=True), 89 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1, 90 | bias=False, dilation=1)) 91 | self.classif2 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1), 92 | nn.ReLU(inplace=True), 93 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1, 94 | bias=False, dilation=1)) 95 | self.classif3 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1), 96 | nn.ReLU(inplace=True), 97 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1, 98 | bias=False, dilation=1)) 99 | 100 | 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 104 | m.weight.data.normal_(0, math.sqrt(2. / n)) 105 | elif isinstance(m, nn.Conv3d): 106 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 107 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | elif isinstance(m, nn.BatchNorm3d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | elif isinstance(m, nn.Linear): 115 | m.bias.data.zero_() 116 | 117 | 118 | def forward(self, L, R): 119 | 120 | featL = self.feature_extraction0(L) 121 | featR = self.feature_extraction0(R) 122 | 123 | xALL0 = self.volume1(featL, featR) 124 | 125 | cost0 = self.dres0(xALL0) 126 | cost0 = self.dres1(cost0) + cost0 127 | 128 | out1 = self.encoder_decoder1(cost0) 129 | out2 = self.encoder_decoder2(out1) 130 | out3 = self.encoder_decoder3(out2) 131 | 132 | if self.training: 133 | cost0 = self.classif0(cost0) 134 | cost1 = self.classif1(out1) 135 | cost2 = self.classif2(out2) 136 | cost3 = self.classif3(out3) 137 | 138 | cost0 = torch.unsqueeze(cost0, 1) 139 | cost0 = F.interpolate(cost0, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 140 | 141 | cost0 = torch.squeeze(cost0, 1) 142 | pred0 = F.softmax(cost0, dim=1) 143 | pred0 = disparity_regression(pred0, self.maxdisp) 144 | 145 | cost1 = torch.unsqueeze(cost1, 1) 146 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 147 | cost1 = torch.squeeze(cost1, 1) 148 | pred1 = F.softmax(cost1, dim=1) 149 | pred1 = disparity_regression(pred1, self.maxdisp) 150 | 151 | cost2 = torch.unsqueeze(cost2, 1) 152 | cost2 = F.interpolate(cost2, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 153 | cost2 = torch.squeeze(cost2, 1) 154 | pred2 = F.softmax(cost2, dim=1) 155 | pred2 = disparity_regression(pred2, self.maxdisp) 156 | 157 | cost3 = torch.unsqueeze(cost3, 1) 158 | cost3 = F.interpolate(cost3, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 159 | cost3 = torch.squeeze(cost3, 1) 160 | pred3 = F.softmax(cost3, dim=1) 161 | pred3 = disparity_regression(pred3, self.maxdisp) 162 | 163 | return [pred0, pred1, pred2, pred3] 164 | 165 | else: 166 | cost3 = self.classif3(out3) 167 | cost3 = torch.unsqueeze(cost3, 1) 168 | cost3 = F.interpolate(cost3, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 169 | cost3 = torch.squeeze(cost3, 1) 170 | pred3 = F.softmax(cost3, dim=1) 171 | pred3 = disparity_regression(pred3, self.maxdisp) 172 | 173 | return [pred3] 174 | 175 | 176 | 177 | class ASNet_light(nn.Module): 178 | def __init__(self, maxdisp): 179 | 180 | super(ASNet_light, self).__init__() 181 | 182 | self.maxdisp = maxdisp 183 | 184 | self.num_groups = 1 185 | 186 | self.volume_size = 48 187 | 188 | self.hg_size = 16 189 | 190 | self.dres_expanse_ratio = 3 191 | 192 | self.feature_extraction0 = feature_extraction() 193 | 194 | self.volume1 = volume_build(self.volume_size) 195 | 196 | 197 | self.dres0 = nn.Sequential(MobileV2_Residual(self.volume_size, self.hg_size, 1, self.dres_expanse_ratio), 198 | nn.ReLU(inplace=True), 199 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio), 200 | nn.ReLU(inplace=True)) 201 | 202 | self.dres1 = nn.Sequential(MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio), 203 | nn.ReLU(inplace=True), 204 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio)) 205 | 206 | self.encoder_decoder1 = hourglass2D(self.hg_size) 207 | # self.encoder_decoder2 = hourglass2D(self.hg_size) 208 | # self.encoder_decoder3 = hourglass2D(self.hg_size) 209 | 210 | self.classif0 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1), 211 | nn.ReLU(inplace=True), 212 | nn.Conv2d(self.hg_size, 
self.hg_size, kernel_size=3, padding=1, stride=1, 213 | bias=False, dilation=1)) 214 | self.classif1 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1), 215 | nn.ReLU(inplace=True), 216 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1, 217 | bias=False, dilation=1)) 218 | 219 | 220 | for m in self.modules(): 221 | if isinstance(m, nn.Conv2d): 222 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 223 | m.weight.data.normal_(0, math.sqrt(2. / n)) 224 | elif isinstance(m, nn.Conv3d): 225 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 226 | m.weight.data.normal_(0, math.sqrt(2. / n)) 227 | elif isinstance(m, nn.BatchNorm2d): 228 | m.weight.data.fill_(1) 229 | m.bias.data.zero_() 230 | elif isinstance(m, nn.BatchNorm3d): 231 | m.weight.data.fill_(1) 232 | m.bias.data.zero_() 233 | elif isinstance(m, nn.Linear): 234 | m.bias.data.zero_() 235 | 236 | def forward(self, L, R): 237 | featL = self.feature_extraction0(L) 238 | featR = self.feature_extraction0(R) 239 | 240 | xALL0 = self.volume1(featL, featR) 241 | 242 | cost0 = self.dres0(xALL0) 243 | cost0 = self.dres1(cost0) + cost0 244 | 245 | out1 = self.encoder_decoder1(cost0) 246 | 247 | if self.training: 248 | cost0 = self.classif0(cost0) 249 | cost1 = self.classif1(out1) 250 | 251 | cost0 = torch.unsqueeze(cost0, 1) 252 | cost0 = F.interpolate(cost0, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 253 | 254 | cost0 = torch.squeeze(cost0, 1) 255 | pred0 = F.softmax(cost0, dim=1) 256 | pred0 = disparity_regression(pred0, self.maxdisp) 257 | 258 | cost1 = torch.unsqueeze(cost1, 1) 259 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 260 | cost1 = torch.squeeze(cost1, 1) 261 | pred1 = F.softmax(cost1, dim=1) 262 | pred1 = disparity_regression(pred1, self.maxdisp) 263 | 264 | return [pred0, pred1] 265 | else: 266 | cost1 = self.classif1(out1) 267 | cost1 = torch.unsqueeze(cost1, 1) 268 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 269 | cost1 = torch.squeeze(cost1, 1) 270 | pred1 = F.softmax(cost1, dim=1) 271 | pred1 = disparity_regression(pred1, self.maxdisp) 272 | 273 | return [pred1] 274 | 275 | def get_features(self, L, R): 276 | featL = self.feature_extraction0(L) 277 | featR = self.feature_extraction0(R) 278 | 279 | xALL0 = self.volume1(featL, featR) 280 | 281 | cost0 = self.dres0(xALL0) 282 | cost0 = self.dres1(cost0) + cost0 283 | 284 | out1 = self.encoder_decoder1(cost0) 285 | 286 | cost1 = self.classif1(out1) 287 | cost1 = torch.unsqueeze(cost1, 1) 288 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 289 | cost1 = torch.squeeze(cost1, 1) 290 | pred1 = F.softmax(cost1, dim=1) 291 | pred1 = disparity_regression(pred1, self.maxdisp) 292 | 293 | return pred1, F.interpolate(torch.unsqueeze(out1, 1), [self.hg_size, L.size()[2], L.size()[3]], mode='trilinear').squeeze(1) 294 | 295 | 296 | 297 | class ASNet_mask(nn.Module): 298 | def __init__(self, maxdisp): 299 | 300 | super(ASNet_mask, self).__init__() 301 | 302 | self.maxdisp = maxdisp 303 | 304 | self.num_groups = 1 305 | 306 | self.volume_size = 48 307 | 308 | self.hg_size = 32 309 | 310 | self.dres_expanse_ratio = 3 311 | 312 | self.feature_extraction0 = feature_extraction() 313 | 314 | self.volume1 = volume_build(self.volume_size) 315 | 316 | self.dres0 = nn.Sequential(MobileV2_Residual(self.volume_size, self.hg_size, 1, self.dres_expanse_ratio), 317 | nn.ReLU(inplace=True), 
318 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio), 319 | nn.ReLU(inplace=True)) 320 | 321 | self.dres1 = nn.Sequential(MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio), 322 | nn.ReLU(inplace=True), 323 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio)) 324 | 325 | self.encoder_decoder1 = hourglass2D(self.hg_size) 326 | # self.encoder_decoder2 = hourglass2D(self.hg_size) 327 | # self.encoder_decoder3 = hourglass2D(self.hg_size) 328 | 329 | self.output_head = nn.Sequential(convbn(self.maxdisp, 1, 3, 1, 1, 1), 330 | nn.Sigmoid()) 331 | 332 | self.classif1 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1), 333 | nn.ReLU(inplace=True), 334 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1, 335 | bias=False, dilation=1)) 336 | 337 | for m in self.modules(): 338 | if isinstance(m, nn.Conv2d): 339 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 340 | m.weight.data.normal_(0, math.sqrt(2. / n)) 341 | elif isinstance(m, nn.Conv3d): 342 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 343 | m.weight.data.normal_(0, math.sqrt(2. / n)) 344 | elif isinstance(m, nn.BatchNorm2d): 345 | m.weight.data.fill_(1) 346 | m.bias.data.zero_() 347 | elif isinstance(m, nn.BatchNorm3d): 348 | m.weight.data.fill_(1) 349 | m.bias.data.zero_() 350 | elif isinstance(m, nn.Linear): 351 | m.bias.data.zero_() 352 | 353 | def forward(self, L, R): 354 | featL = self.feature_extraction0(L) 355 | featR = self.feature_extraction0(R) 356 | 357 | xALL0 = self.volume1(featL, featR) 358 | 359 | cost0 = self.dres0(xALL0) 360 | cost0 = self.dres1(cost0) + cost0 361 | 362 | out1 = self.encoder_decoder1(cost0) 363 | cost1 = self.classif1(out1) 364 | cost1 = torch.unsqueeze(cost1, 1) 365 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 366 | cost1 = torch.squeeze(cost1, 1) 367 | pred = torch.clamp_max(self.output_head(cost1), 1.0) 368 | 369 | return pred 370 | 371 | 372 | 373 | class ASNet_color(nn.Module): 374 | def __init__(self, maxdisp): 375 | 376 | super(ASNet_color, self).__init__() 377 | 378 | self.maxdisp = maxdisp 379 | 380 | self.num_groups = 1 381 | 382 | self.volume_size = 48 383 | 384 | self.hg_size = 32 385 | 386 | self.dres_expanse_ratio = 3 387 | 388 | self.feature_extraction0 = feature_extraction() 389 | 390 | self.volume1 = volume_build(self.volume_size) 391 | 392 | self.dres0 = nn.Sequential(MobileV2_Residual(self.volume_size, self.hg_size, 1, self.dres_expanse_ratio), 393 | nn.ReLU(inplace=True), 394 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio), 395 | nn.ReLU(inplace=True)) 396 | 397 | self.dres1 = nn.Sequential(MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio), 398 | nn.ReLU(inplace=True), 399 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio)) 400 | 401 | # self.encoder_decoder1 = hourglass2D(self.hg_size) 402 | # self.encoder_decoder2 = hourglass2D(self.hg_size) 403 | # self.encoder_decoder3 = hourglass2D(self.hg_size) 404 | 405 | self.output_head = nn.Sequential(convbn(self.maxdisp, 1, 3, 1, 1, 1), 406 | nn.Sigmoid()) 407 | 408 | self.classif1 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1), 409 | nn.ReLU(inplace=True), 410 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1, 411 | bias=False, dilation=1)) 412 | 413 | for m in self.modules(): 414 | if isinstance(m, nn.Conv2d): 415 | n = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels 416 | m.weight.data.normal_(0, math.sqrt(2. / n)) 417 | elif isinstance(m, nn.Conv3d): 418 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 419 | m.weight.data.normal_(0, math.sqrt(2. / n)) 420 | elif isinstance(m, nn.BatchNorm2d): 421 | m.weight.data.fill_(1) 422 | m.bias.data.zero_() 423 | elif isinstance(m, nn.BatchNorm3d): 424 | m.weight.data.fill_(1) 425 | m.bias.data.zero_() 426 | elif isinstance(m, nn.Linear): 427 | m.bias.data.zero_() 428 | 429 | def forward(self, L, R): 430 | featL = self.feature_extraction0(L) 431 | featR = self.feature_extraction0(R) 432 | 433 | xALL0 = self.volume1(featL, featR) 434 | 435 | cost0 = self.dres0(xALL0) 436 | cost0 = self.dres1(cost0) + cost0 437 | # out1 = self.encoder_decoder1(cost0) 438 | cost1 = self.classif1(cost0) 439 | cost1 = torch.unsqueeze(cost1, 1) 440 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear') 441 | cost1 = torch.squeeze(cost1, 1) 442 | pred = self.output_head(cost1) 443 | 444 | return pred 445 | -------------------------------------------------------------------------------- /lib/network/asnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from lib.config import cfg 5 | 6 | 7 | ############################################################################### 8 | """ Fundamental Building Blocks """ 9 | ############################################################################### 10 | 11 | 12 | def convbn(in_channels, out_channels, kernel_size, stride, pad, dilation): 13 | return nn.Sequential( 14 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 15 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False), 16 | nn.BatchNorm2d(out_channels) 17 | ) 18 | 19 | 20 | def convbn_dws(inp, oup, kernel_size, stride, pad, dilation, second_relu=True): 21 | if second_relu: 22 | return nn.Sequential( 23 | # dw 24 | nn.Conv2d(inp, inp, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, 25 | dilation=dilation, groups=inp, bias=False), 26 | nn.BatchNorm2d(inp), 27 | nn.ReLU6(inplace=True), 28 | # pw 29 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 30 | nn.BatchNorm2d(oup), 31 | nn.ReLU6(inplace=False) 32 | ) 33 | else: 34 | return nn.Sequential( 35 | # dw 36 | nn.Conv2d(inp, inp, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, 37 | dilation=dilation, groups=inp, bias=False), 38 | nn.BatchNorm2d(inp), 39 | nn.ReLU6(inplace=True), 40 | # pw 41 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 42 | nn.BatchNorm2d(oup) 43 | ) 44 | 45 | class MobileV1_Residual(nn.Module): 46 | expansion = 1 47 | 48 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 49 | super(MobileV1_Residual, self).__init__() 50 | 51 | self.stride = stride 52 | self.downsample = downsample 53 | self.conv1 = convbn_dws(inplanes, planes, 3, stride, pad, dilation) 54 | self.conv2 = convbn_dws(planes, planes, 3, 1, pad, dilation, second_relu=False) 55 | 56 | def forward(self, x): 57 | out = self.conv1(x) 58 | out = self.conv2(out) 59 | 60 | if self.downsample is not None: 61 | x = self.downsample(x) 62 | 63 | out += x 64 | 65 | return out 66 | 67 | 68 | 69 | class MobileV2_Residual(nn.Module): 70 | def __init__(self, inp, oup, stride, expanse_ratio, dilation=1): 71 | super(MobileV2_Residual, self).__init__() 72 | 
self.stride = stride 73 | assert stride in [1, 2] 74 | 75 | hidden_dim = int(inp * expanse_ratio) 76 | self.use_res_connect = self.stride == 1 and inp == oup 77 | pad = dilation 78 | 79 | if expanse_ratio == 1: 80 | self.conv = nn.Sequential( 81 | # dw 82 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, pad, dilation=dilation, groups=hidden_dim, bias=False), 83 | nn.BatchNorm2d(hidden_dim), 84 | nn.ReLU6(inplace=True), 85 | # pw-linear 86 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 87 | nn.BatchNorm2d(oup), 88 | ) 89 | else: 90 | self.conv = nn.Sequential( 91 | # pw 92 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 93 | nn.BatchNorm2d(hidden_dim), 94 | nn.ReLU6(inplace=True), 95 | # dw 96 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, pad, dilation=dilation, groups=hidden_dim, bias=False), 97 | nn.BatchNorm2d(hidden_dim), 98 | nn.ReLU6(inplace=True), 99 | # pw-linear 100 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 101 | nn.BatchNorm2d(oup), 102 | ) 103 | 104 | def forward(self, x): 105 | if self.use_res_connect: 106 | return x + self.conv(x) 107 | else: 108 | return self.conv(x) 109 | 110 | 111 | 112 | class InsideBlockConv(nn.Module): 113 | def __init__(self, in_features, out_features): 114 | super(InsideBlockConv, self).__init__() 115 | self.double_conv = nn.Sequential( 116 | nn.Conv2d(in_features, out_features, kernel_size=3, padding=(1,1)), # For same padding: pad=1 for filter=3 117 | nn.BatchNorm2d(out_features), 118 | nn.ReLU(inplace=True), # inplace=True doesn't create additonal memory. Not always correct operation. But here there is no issue 119 | nn.Conv2d(out_features, out_features, kernel_size=3, padding=(1,1)), 120 | nn.BatchNorm2d(out_features), 121 | nn.ReLU(inplace=True) 122 | ) 123 | 124 | def forward(self, x1): 125 | return self.double_conv(x1) 126 | 127 | ############################################################################### 128 | """ Feature Extraction """ 129 | ############################################################################### 130 | 131 | 132 | 133 | 134 | 135 | class feature_extraction(nn.Module): 136 | def __init__(self): 137 | super(feature_extraction, self).__init__() 138 | 139 | self.expanse_ratio = 3 140 | self.inplanes = 32 141 | 142 | self.firstconv0 = nn.Sequential(MobileV2_Residual(1, 4, 2, self.expanse_ratio), 143 | nn.ReLU(inplace=True), 144 | MobileV2_Residual(4, 16, 1, self.expanse_ratio), 145 | nn.ReLU(inplace=True), 146 | MobileV2_Residual(16, 32, 1, self.expanse_ratio), 147 | nn.ReLU(inplace=True) 148 | ) 149 | self.firstconv1 = nn.Sequential(MobileV2_Residual(1, 4, 2, self.expanse_ratio), 150 | nn.ReLU(inplace=True), 151 | MobileV2_Residual(4, 16, 1, self.expanse_ratio), 152 | nn.ReLU(inplace=True), 153 | MobileV2_Residual(16, 32, 1, self.expanse_ratio), 154 | nn.ReLU(inplace=True) 155 | ) 156 | self.firstconv2 = nn.Sequential(MobileV2_Residual(1, 4, 2, self.expanse_ratio), 157 | nn.ReLU(inplace=True), 158 | MobileV2_Residual(4, 16, 1, self.expanse_ratio), 159 | nn.ReLU(inplace=True), 160 | MobileV2_Residual(16, 32, 1, self.expanse_ratio), 161 | nn.ReLU(inplace=True) 162 | ) 163 | 164 | 165 | self.conv3d = nn.Sequential(nn.Conv3d(1, 1, kernel_size=(3, 5, 5), stride=[3, 1, 1], padding=[0, 2, 2]), 166 | nn.BatchNorm3d(1), 167 | nn.ReLU()) 168 | 169 | self.layer1 = self._make_layer(MobileV1_Residual, 32, 3, 1, 1, 1) 170 | self.layer2 = self._make_layer(MobileV1_Residual, 64, 16, 2, 1, 1)# 171 | self.layer3 = self._make_layer(MobileV1_Residual, 128, 3, 1, 1, 1) 172 | self.layer4 = self._make_layer(MobileV1_Residual, 
128, 3, 1, 1, 2) 173 | 174 | self.preconv11 = nn.Sequential( 175 | convbn(320, 256, 1, 1, 0, 1), 176 | nn.ReLU(inplace=True), 177 | convbn(256, 128, 1, 1, 0, 1), 178 | nn.ReLU(inplace=True), 179 | convbn(128, 64, 1, 1, 0, 1), 180 | nn.ReLU(inplace=True), 181 | nn.Conv2d(64, 32, 1, 1, 0, 1) 182 | ) 183 | 184 | 185 | 186 | def _make_layer(self, block, planes, blocks, stride, pad, dilation): 187 | downsample = None 188 | 189 | if stride != 1 or self.inplanes != planes: 190 | downsample = nn.Sequential( 191 | nn.Conv2d(self.inplanes, planes, 192 | kernel_size=1, stride=stride, bias=False), 193 | nn.BatchNorm2d(planes), 194 | ) 195 | 196 | layers = [block(self.inplanes, planes, stride, downsample, pad, dilation)] 197 | self.inplanes = planes 198 | for i in range(1, blocks): 199 | layers.append(block(self.inplanes, planes, 1, None, pad, dilation)) 200 | 201 | return nn.Sequential(*layers) 202 | 203 | def forward(self, x): 204 | 205 | x0 = torch.unsqueeze(x[:,0,:,:], 1) 206 | x1 = torch.unsqueeze(x[:,1,:,:], 1) 207 | x2 = torch.unsqueeze(x[:,2,:,:], 1) 208 | 209 | x0 = self.firstconv0(x0) 210 | x1 = self.firstconv1(x1) 211 | x2 = self.firstconv2(x2) 212 | 213 | B, C, H, W = x0.shape 214 | interwoven_features = x0.new_zeros([B, 3 * C, H, W]) 215 | xall = interweave_tensors3(interwoven_features, x0, x1, x2) 216 | 217 | xall = torch.unsqueeze(xall, 1) 218 | xall = self.conv3d(xall) 219 | xall = torch.squeeze(xall, 1) 220 | 221 | 222 | 223 | 224 | 225 | xall = self.layer1(xall) 226 | xall2 = self.layer2(xall) 227 | xall3 = self.layer3(xall2) 228 | xall4 = self.layer4(xall3) 229 | 230 | feature_volume = torch.cat((xall2, xall3, xall4), dim=1) 231 | 232 | xALL = self.preconv11(feature_volume) 233 | 234 | 235 | 236 | return xALL 237 | 238 | 239 | 240 | 241 | class volume_build(nn.Module): 242 | def __init__(self, volume_size): 243 | super(volume_build, self).__init__() 244 | self.num_groups = 1 245 | self.volume_size = volume_size 246 | 247 | 248 | 249 | 250 | 251 | self.volume11 = nn.Sequential( 252 | convbn(16, 1, 1, 1, 0, 1), 253 | nn.ReLU(inplace=True)) 254 | self.conv3d = nn.Sequential(nn.Conv3d(1, 16, kernel_size=(8, 3, 3), stride=[8, 1, 1], padding=[0, 1, 1]), 255 | nn.BatchNorm3d(16), 256 | nn.ReLU(), 257 | nn.Conv3d(16, 32, kernel_size=(4, 3, 3), stride=[4, 1, 1], padding=[0, 1, 1]), 258 | nn.BatchNorm3d(32), 259 | nn.ReLU(), 260 | nn.Conv3d(32, 16, kernel_size=(2, 3, 3), stride=[2, 1, 1], padding=[0, 1, 1]), 261 | nn.BatchNorm3d(16), 262 | nn.ReLU()) 263 | 264 | def forward(self, featL,featR): 265 | 266 | 267 | 268 | 269 | B, C, H, W = featL.shape 270 | volume = featL.new_zeros([B, self.num_groups, self.volume_size, H, W]) 271 | 272 | 273 | 274 | interwoven_features = featL.new_zeros([B, 2 * C, H, W]) 275 | for i in range(self.volume_size): 276 | 277 | if i > 0: 278 | x = interweave_tensors(interwoven_features, featL[:, :, :, :-i], featR[:, :, :, i:]) 279 | x = torch.unsqueeze(x, 1) 280 | x = self.conv3d(x) 281 | x = torch.squeeze(x, 2) 282 | x = self.volume11(x) 283 | volume[:, :, i, :, i:] = x 284 | else: 285 | x = interweave_tensors(interwoven_features, featL, featR) 286 | x = torch.unsqueeze(x, 1) 287 | x = self.conv3d(x) 288 | x = torch.squeeze(x, 2) 289 | x = self.volume11(x) 290 | volume[:, :, i, :, :] = x 291 | 292 | volume = volume.contiguous() 293 | volume = torch.squeeze(volume, 1) 294 | 295 | return volume 296 | 297 | 298 | 299 | 300 | 301 | ############################################################################## 302 | """ Disparity Regression Function """ 303 | 
############################################################################### 304 | 305 | 306 | def disparity_regression(x, maxdisp): 307 | assert len(x.shape) == 4 308 | disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device) 309 | disp_values = disp_values.view(1, maxdisp, 1, 1) 310 | return torch.sum(x * disp_values, 1, keepdim=False) / maxdisp * cfg.model.max_depth_value 311 | 312 | 313 | 314 | def interweave_tensors(interwoven_features, refimg_fea, targetimg_fea): 315 | B, C, H, W = refimg_fea.shape 316 | interwoven_features = interwoven_features[:, :, :, 0:W] 317 | interwoven_features = interwoven_features*0 318 | interwoven_features[:,::2,:,:] = refimg_fea 319 | interwoven_features[:,1::2,:,:] = targetimg_fea 320 | interwoven_features = interwoven_features.contiguous() 321 | return interwoven_features 322 | def interweave_tensors3(interwoven_features, refimg_fea, targetimg_fea, targetimg2_fea): 323 | B, C, H, W = refimg_fea.shape 324 | interwoven_features = interwoven_features[:, :, :, 0:W] 325 | interwoven_features = interwoven_features*0 326 | interwoven_features[:,::3,:,:] = refimg_fea 327 | interwoven_features[:,1::3,:,:] = targetimg_fea 328 | interwoven_features[:,2::3,:,:] = targetimg2_fea 329 | interwoven_features = interwoven_features.contiguous() 330 | return interwoven_features 331 | 332 | 333 | ############################################################################### 334 | """ Loss Function """ 335 | ############################################################################### 336 | 337 | 338 | def model_loss(disp_ests, disp_gt, mask): 339 | weights = [0.5, 0.5, 0.7, 1.0] 340 | all_losses = [] 341 | for disp_est, weight in zip(disp_ests, weights): 342 | # all_losses.append(weight * F.smooth_l1_loss(disp_est[mask], disp_gt[mask], reduction='mean')) 343 | all_losses.append(weight * F.l1_loss(disp_est[mask], disp_gt[mask], reduction='mean')) 344 | return sum(all_losses) 345 | 346 | def model_loss_light(disp_ests, disp_gt, mask): 347 | weights = [0.5, 1.0] 348 | all_losses = [] 349 | for disp_est, weight in zip(disp_ests, weights): 350 | all_losses.append(weight * F.l1_loss(disp_est[mask], disp_gt[mask], reduction='mean')) 351 | return sum(all_losses) -------------------------------------------------------------------------------- /lib/network/eventgaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from lib.config import cfg, args 4 | # from lib.network.asnet import ASNet 5 | from lib.network.recon_net import E2IM, E2DPT 6 | from lib.network.gsregressor import GSRegressor 7 | 8 | class EventGaussian(nn.Module): 9 | def __init__(self): 10 | super(EventGaussian, self).__init__() 11 | self.depth_estimator = E2DPT(num_input_channels=8) 12 | self.intensity_estimator = E2IM(num_input_channels=32+8) 13 | self.regressor = GSRegressor(input_dim=2 + 32 + 8) 14 | self.gt_depth = False 15 | self.us_mask = "net" 16 | # self.proj1 = nn.Sequential(nn.Conv2D()) 17 | 18 | def forward(self, batch): 19 | leT = torch.cat([batch["leframe"], batch["left_voxel"]], dim=1) 20 | riT = torch.cat([batch["reframe"], batch["right_voxel"]], dim=1) 21 | b = leT.shape[0] 22 | inp = torch.cat([leT, riT], dim=0) 23 | 24 | #only available for debugging 25 | if not self.gt_depth: 26 | depths, masks, dfeats = self.depth_estimator.get_features(inp) 27 | depthL, depthR = depths[:b], depths[b:] 28 | masksL, masksR = masks[:b], masks[b:] 29 | dfeatsL, dfeatsR = dfeats[:b], dfeats[b:] #[b, 32, H, W] 30 | 
else: #debug only 31 | depthL, depthR = batch["ldepth"].unsqueeze(1), batch["rdepth"].unsqueeze(1) 32 | 33 | # only available for debugging 34 | if self.us_mask == "gt": 35 | maskL, maskR = batch["lmask"].unsqueeze(1), batch["rmask"].unsqueeze(1) 36 | elif self.us_mask == "net": 37 | maskL, maskR = masksL, masksR 38 | elif self.us_mask == "none": 39 | maskL, maskR = torch.ones_like(depthL).to(depthL.device), torch.ones_like(depthR).to(depthL.device) 40 | 41 | depthL = depthL * maskL 42 | depthR = depthR * maskR 43 | # 44 | L_img_inp = torch.cat([dfeatsL ,leT], dim=1) #depthFeat, frame, voxel 45 | R_img_inp = torch.cat([dfeatsR ,riT], dim=1) 46 | img_inp = torch.cat([L_img_inp, R_img_inp], dim=0) 47 | img, ifeats = self.intensity_estimator.get_features(img_inp) 48 | imgL, imgR = torch.split(img, b, dim=0) 49 | ifeatL, ifeatR = torch.split(ifeats, b, dim=0) 50 | # 51 | imgL = imgL * maskL 52 | imgR = imgR * maskR 53 | # imgL, imgR = batch["lim"], batch["rim"] 54 | L_gs_inp = torch.cat([depthL, imgL, ifeatL, leT], dim=1) # depthL, imgL, ifeatL, leT 55 | R_gs_inp = torch.cat([depthR, imgR, ifeatR, riT], dim=1) 56 | gs_inp = torch.cat([L_gs_inp, R_gs_inp], dim=0) 57 | # 58 | rot, scale, opacity = self.regressor(gs_inp) 59 | 60 | return { 61 | "lview":{ 62 | "depth":depthL, 63 | "mask":maskL, 64 | "pts_valid":maskL.squeeze().reshape(b, -1), 65 | "img": imgL, 66 | "rot":rot[:b], 67 | "scale":scale[:b], 68 | "opacity":opacity[:b] 69 | }, 70 | "rview":{ 71 | "depth":depthR, 72 | "mask":maskR, 73 | "pts_valid":maskR.squeeze().reshape(b, -1), 74 | "img": imgR, 75 | "rot":rot[b:], 76 | "scale":scale[b:], 77 | "opacity":opacity[b:] 78 | } 79 | } 80 | 81 | 82 | -------------------------------------------------------------------------------- /lib/network/firenet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch 4 | from .unet import UNet, UNetRecurrent, UNetFire, UNetStatic 5 | from os.path import join 6 | from .submodules import ConvLSTM, ResidualBlock, ConvLayer, UpsampleConvLayer, TransposedConvLayer 7 | 8 | 9 | import logging 10 | import numpy as np 11 | 12 | class BaseModel(nn.Module): 13 | """ 14 | Base class for all models 15 | """ 16 | def __init__(self, config): 17 | super(BaseModel, self).__init__() 18 | self.config = config 19 | self.logger = logging.getLogger(self.__class__.__name__) 20 | 21 | def forward(self, *input): 22 | """ 23 | Forward pass logic 24 | 25 | :return: Model output 26 | """ 27 | raise NotImplementedError 28 | 29 | def summary(self): 30 | """ 31 | Model summary 32 | """ 33 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 34 | params = sum([np.prod(p.size()) for p in model_parameters]) 35 | self.logger.info('Trainable parameters: {}'.format(params)) 36 | self.logger.info(self) 37 | 38 | class BaseE2VID(BaseModel): 39 | def __init__(self, config): 40 | super().__init__(config) 41 | 42 | assert('num_bins' in config) 43 | self.num_bins = int(config['num_bins']) # number of bins in the voxel grid event tensor 44 | 45 | try: 46 | self.skip_type = str(config['skip_type']) 47 | except KeyError: 48 | self.skip_type = 'sum' 49 | 50 | try: 51 | self.num_encoders = int(config['num_encoders']) 52 | except KeyError: 53 | self.num_encoders = 4 54 | 55 | try: 56 | self.base_num_channels = int(config['base_num_channels']) 57 | except KeyError: 58 | self.base_num_channels = 32 59 | 60 | try: 61 | self.num_residual_blocks = int(config['num_residual_blocks']) 62 | except KeyError: 63 | 
self.num_residual_blocks = 2 64 | 65 | try: 66 | self.norm = str(config['norm']) 67 | except KeyError: 68 | self.norm = None 69 | 70 | try: 71 | self.use_upsample_conv = bool(config['use_upsample_conv']) 72 | except KeyError: 73 | self.use_upsample_conv = True 74 | 75 | 76 | class E2VID(BaseE2VID): 77 | def __init__(self, config): 78 | super(E2VID, self).__init__(config) 79 | 80 | self.unet = UNet(num_input_channels=self.num_bins, 81 | num_output_channels=1, 82 | skip_type=self.skip_type, 83 | activation='sigmoid', 84 | num_encoders=self.num_encoders, 85 | base_num_channels=self.base_num_channels, 86 | num_residual_blocks=self.num_residual_blocks, 87 | norm=self.norm, 88 | use_upsample_conv=self.use_upsample_conv) 89 | 90 | def forward(self, event_tensor, prev_states=None): 91 | """ 92 | :param event_tensor: N x num_bins x H x W 93 | :return: a predicted image of size N x 1 x H x W, taking values in [0,1]. 94 | """ 95 | return self.unet.forward(event_tensor), None 96 | 97 | 98 | class E2VIDRecurrent(BaseE2VID): 99 | """ 100 | Recurrent, UNet-like architecture where each encoder is followed by a ConvLSTM or ConvGRU. 101 | """ 102 | 103 | def __init__(self, config): 104 | super(E2VIDRecurrent, self).__init__(config) 105 | 106 | try: 107 | self.recurrent_block_type = str(config['recurrent_block_type']) 108 | except KeyError: 109 | self.recurrent_block_type = 'convlstm' # or 'convgru' 110 | 111 | self.unetrecurrent = UNetRecurrent(num_input_channels=self.num_bins, 112 | num_output_channels=1, 113 | skip_type=self.skip_type, 114 | recurrent_block_type=self.recurrent_block_type, 115 | activation='sigmoid', 116 | num_encoders=self.num_encoders, 117 | base_num_channels=self.base_num_channels, 118 | num_residual_blocks=self.num_residual_blocks, 119 | norm=self.norm, 120 | use_upsample_conv=self.use_upsample_conv) 121 | 122 | def forward(self, event_tensor, prev_states): 123 | """ 124 | :param event_tensor: N x num_bins x H x W 125 | :param prev_states: previous ConvLSTM state for each encoder module 126 | :return: reconstructed image, taking values in [0,1]. 127 | """ 128 | img_pred, states = self.unetrecurrent.forward(event_tensor, prev_states) 129 | return img_pred, states 130 | 131 | 132 | class FireNet(BaseE2VID): 133 | """ 134 | Model from the paper: "Fast Image Reconstruction with an Event Camera", Scheerlinck et. al., 2019. 135 | The model is essentially a lighter version of E2VID, which runs faster (~2-3x faster) and has considerably less parameters (~200x less). 136 | However, the reconstructions are not as high quality as E2VID: they suffer from smearing artefacts, and initialization takes longer. 
137 | """ 138 | def __init__(self, config): 139 | super().__init__(config) 140 | self.recurrent_block_type = str(config.get('recurrent_block_type', 'convgru')) 141 | kernel_size = config.get('kernel_size', 3) 142 | recurrent_blocks = config.get('recurrent_blocks', {'resblock': [0]}) 143 | self.net = UNetFire(self.num_bins, 144 | num_output_channels=1, 145 | skip_type=self.skip_type, 146 | recurrent_block_type=self.recurrent_block_type, 147 | base_num_channels=self.base_num_channels, 148 | num_residual_blocks=self.num_residual_blocks, 149 | norm=self.norm, 150 | kernel_size=kernel_size, 151 | recurrent_blocks=recurrent_blocks) 152 | 153 | def forward(self, event_tensor, prev_states): 154 | img, states = self.net.forward(event_tensor, prev_states) 155 | return img, states 156 | 157 | 158 | class FireNet_static(BaseE2VID): 159 | """ 160 | Model from the paper: "Fast Image Reconstruction with an Event Camera", Scheerlinck et. al., 2019. 161 | The model is essentially a lighter version of E2VID, which runs faster (~2-3x faster) and has considerably less parameters (~200x less). 162 | However, the reconstructions are not as high quality as E2VID: they suffer from smearing artefacts, and initialization takes longer. 163 | """ 164 | def __init__(self, config): 165 | super().__init__(config) 166 | self.recurrent_block_type = str(config.get('recurrent_block_type', 'convgru')) 167 | kernel_size = config.get('kernel_size', 3) 168 | recurrent_blocks = config.get('recurrent_blocks', {'resblock': [0]}) 169 | self.net = UNetStatic(self.num_bins, 170 | num_output_channels=1, 171 | skip_type=self.skip_type, 172 | recurrent_block_type=self.recurrent_block_type, 173 | base_num_channels=self.base_num_channels, 174 | num_residual_blocks=self.num_residual_blocks, 175 | norm=self.norm, 176 | kernel_size=kernel_size, 177 | recurrent_blocks=recurrent_blocks) 178 | 179 | def forward(self, event_tensor, placeholder=None): 180 | img = self.net.forward(event_tensor) 181 | return img, placeholder -------------------------------------------------------------------------------- /lib/network/gsregressor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ResidualBlock(nn.Module): 6 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 7 | super(ResidualBlock, self).__init__() 8 | 9 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 10 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | num_groups = planes // 8 14 | 15 | if norm_fn == 'group': 16 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 17 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | if not (stride == 1 and in_planes == planes): 19 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 20 | 21 | elif norm_fn == 'batch': 22 | self.norm1 = nn.BatchNorm2d(planes) 23 | self.norm2 = nn.BatchNorm2d(planes) 24 | if not (stride == 1 and in_planes == planes): 25 | self.norm3 = nn.BatchNorm2d(planes) 26 | 27 | elif norm_fn == 'instance': 28 | self.norm1 = nn.InstanceNorm2d(planes) 29 | self.norm2 = nn.InstanceNorm2d(planes) 30 | if not (stride == 1 and in_planes == planes): 31 | self.norm3 = nn.InstanceNorm2d(planes) 32 | 33 | elif norm_fn == 'none': 34 | self.norm1 = nn.Sequential() 35 | self.norm2 = nn.Sequential() 36 | if not (stride == 1 and in_planes == planes): 37 | self.norm3 = nn.Sequential() 38 
| 39 | if stride == 1 and in_planes == planes: 40 | self.downsample = None 41 | 42 | else: 43 | self.downsample = nn.Sequential( 44 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 45 | 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.conv1(y) 50 | y = self.norm1(y) 51 | y = self.relu(y) 52 | y = self.conv2(y) 53 | y = self.norm2(y) 54 | y = self.relu(y) 55 | 56 | if self.downsample is not None: 57 | x = self.downsample(x) 58 | 59 | return self.relu(x+y) 60 | 61 | class GSRegressor(nn.Module): 62 | def __init__(self, input_dim=1+1+8, hidden_dim = 256, norm_fn='group'): 63 | super().__init__() 64 | self.embedding = nn.Conv2d(input_dim, hidden_dim, kernel_size=1, stride=1) 65 | 66 | self.res1 = ResidualBlock(hidden_dim, hidden_dim // 4, norm_fn=norm_fn) 67 | 68 | self.rot_head = nn.Sequential( 69 | nn.Conv2d(hidden_dim // 4, hidden_dim // 4, kernel_size=3, padding=1), 70 | nn.ReLU(inplace=True), 71 | nn.Conv2d(hidden_dim // 4, 4, kernel_size=1), 72 | ) 73 | self.scale_head = nn.Sequential( 74 | nn.Conv2d(hidden_dim // 4, hidden_dim // 4, kernel_size=3, padding=1), 75 | nn.ReLU(inplace=True), 76 | nn.Conv2d(hidden_dim // 4, 3, kernel_size=1), 77 | nn.Softplus(beta=100) 78 | ) 79 | self.opacity_head = nn.Sequential( 80 | nn.Conv2d(hidden_dim // 4, hidden_dim // 4, kernel_size=3, padding=1), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(hidden_dim // 4, 1, kernel_size=1), 83 | nn.Sigmoid() 84 | ) 85 | 86 | def forward(self, x): 87 | """ 88 | x intensity [B,1,H,W] 89 | depth [B,1,H,W] 90 | eframe [B,3,H,W] 91 | img_feat [B,320,H,W] 92 | """ 93 | x = self.embedding(x) 94 | out = self.res1(x) 95 | 96 | rot_out = self.rot_head(out) 97 | rot_out = torch.nn.functional.normalize(rot_out, dim=1) 98 | 99 | # scale head 100 | scale_out = torch.clamp_max(self.scale_head(out), 0.001) 101 | 102 | # opacity head 103 | opacity_out = self.opacity_head(out) 104 | 105 | return rot_out, scale_out, opacity_out 106 | 107 | 108 | -------------------------------------------------------------------------------- /lib/network/neurons.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Callable 3 | import torch 4 | import torch.nn as nn 5 | from spikingjelly.clock_driven import neuron, surrogate, base, layer 6 | import math 7 | try: 8 | import cupy 9 | from . import neuron_kernel, cu_kernel_opt 10 | except ImportError: 11 | neuron_kernel = None 12 | 13 | 14 | class BaseNode(base.MemoryModule): 15 | def __init__(self, v_threshold: float = 1., v_reset: float = 0., 16 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): 17 | assert isinstance(v_reset, float) or v_reset is None 18 | assert isinstance(v_threshold, float) 19 | assert isinstance(detach_reset, bool) 20 | super().__init__() 21 | 22 | if v_reset is None: 23 | self.register_memory('v', 0.) 24 | self.register_memory('spike', 0.) 25 | else: 26 | self.register_memory('v', v_reset) 27 | self.register_memory('spike', 0.) 
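# Note on the reset convention used throughout these neuron classes: when
# v_reset is None the membrane potential 'v' starts at 0 and neuronal_reset()
# later subtracts v_threshold on a spike (soft reset); with a float v_reset the
# potential is pulled back to v_reset after every spike (hard reset).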
28 | 29 | self.v_threshold = v_threshold 30 | self.v_reset = v_reset 31 | 32 | self.detach_reset = detach_reset 33 | self.surrogate_function = surrogate_function 34 | 35 | @abstractmethod 36 | def neuronal_charge(self, x: torch.Tensor): 37 | raise NotImplementedError 38 | 39 | def neuronal_fire(self): 40 | self.spike = self.surrogate_function(self.v - self.v_threshold) 41 | 42 | def neuronal_reset(self): 43 | 44 | if self.detach_reset: 45 | spike = self.spike.detach() 46 | else: 47 | spike = self.spike 48 | 49 | if self.v_reset is None: 50 | # soft reset 51 | self.v = self.v - spike * self.v_threshold 52 | 53 | else: 54 | # hard reset 55 | self.v = (1. - spike) * self.v + spike * self.v_reset 56 | 57 | def extra_repr(self): 58 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}' 59 | 60 | def forward(self, x: torch.Tensor): 61 | 62 | self.neuronal_charge(x) 63 | self.neuronal_fire() 64 | self.neuronal_reset() 65 | return self.spike 66 | 67 | class BaseNode_adaspike(base.MemoryModule): 68 | def __init__(self, v_threshold: float = 1., v_reset: float = 0., 69 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): 70 | 71 | assert isinstance(v_reset, float) or v_reset is None 72 | assert isinstance(v_threshold, float) 73 | assert isinstance(detach_reset, bool) 74 | super().__init__() 75 | 76 | if v_reset is None: 77 | self.register_memory('v', 0.) 78 | self.register_memory('spike', 0.) 79 | else: 80 | self.register_memory('v', v_reset) 81 | self.register_memory('spike', 0.) 82 | 83 | self.v_threshold = v_threshold 84 | self.v_reset = v_reset 85 | 86 | self.detach_reset = detach_reset 87 | self.surrogate_function = surrogate_function 88 | 89 | @abstractmethod 90 | def neuronal_charge(self, x: torch.Tensor): 91 | 92 | raise NotImplementedError 93 | 94 | def neuronal_fire(self): 95 | 96 | self.spike = self.surrogate_function(self.v - self.v_threshold) 97 | 98 | def neuronal_reset(self): 99 | 100 | if self.detach_reset: 101 | spike = self.spike.detach() 102 | else: 103 | spike = self.spike 104 | 105 | if self.v_reset is None: 106 | # soft reset 107 | self.v = self.v - spike * self.v_threshold 108 | 109 | else: 110 | # hard reset 111 | self.v = (1. - spike) * self.v + spike * self.v_reset 112 | 113 | def extra_repr(self): 114 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}' 115 | 116 | def forward(self, x: torch.Tensor, s: torch.Tensor): 117 | 118 | self.neuronal_charge(x, s) 119 | self.neuronal_fire() 120 | self.neuronal_reset() 121 | return self.spike 122 | 123 | class MpNode(base.MemoryModule): 124 | def __init__(self, v_threshold: float = 1., v_reset: float = 0., 125 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): 126 | 127 | assert isinstance(v_reset, float) or v_reset is None 128 | assert isinstance(v_threshold, float) 129 | assert isinstance(detach_reset, bool) 130 | super().__init__() 131 | 132 | 133 | if v_reset is None: 134 | self.register_memory('v', 0.) 135 | self.register_memory('spike', 0.) 136 | else: 137 | self.register_memory('v', v_reset) 138 | self.register_memory('spike', 0.) 
139 | 140 | 141 | self.v_threshold = v_threshold 142 | self.v_reset = v_reset 143 | 144 | self.detach_reset = detach_reset 145 | self.surrogate_function = surrogate_function 146 | 147 | @abstractmethod 148 | def neuronal_charge(self, x: torch.Tensor): 149 | raise NotImplementedError 150 | 151 | def neuronal_fire(self): 152 | self.spike = self.surrogate_function(self.v - self.v_threshold) 153 | 154 | def neuronal_reset(self): 155 | if self.detach_reset: 156 | spike = self.spike.detach() 157 | else: 158 | spike = self.spike 159 | 160 | if self.v_reset is None: 161 | # soft reset 162 | self.v = self.v - spike * self.v_threshold 163 | 164 | else: 165 | # hard reset 166 | self.v = (1. - spike) * self.v + spike * self.v_reset 167 | 168 | def extra_repr(self): 169 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}' 170 | 171 | def forward(self, x: torch.Tensor, last_mem: torch.Tensor): 172 | if last_mem is None: 173 | self.neuronal_charge(x) 174 | else: 175 | self.register_memory('v', last_mem) 176 | self.neuronal_charge(x) 177 | return self.v 178 | 179 | class Ada_MpNode(base.MemoryModule): 180 | def __init__(self, v_threshold: float = 1., v_reset: float = 0., 181 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): 182 | 183 | assert isinstance(v_reset, float) or v_reset is None 184 | assert isinstance(v_threshold, float) 185 | assert isinstance(detach_reset, bool) 186 | super().__init__() 187 | 188 | 189 | if v_reset is None: 190 | self.register_memory('v', 0.) 191 | self.register_memory('spike', 0.) 192 | else: 193 | self.register_memory('v', v_reset) 194 | self.register_memory('spike', 0.) 195 | 196 | 197 | self.v_threshold = v_threshold 198 | self.v_reset = v_reset 199 | 200 | self.detach_reset = detach_reset 201 | self.surrogate_function = surrogate_function 202 | 203 | @abstractmethod 204 | def neuronal_charge(self, x: torch.Tensor): 205 | raise NotImplementedError 206 | 207 | def neuronal_fire(self): 208 | self.spike = self.surrogate_function(self.v - self.v_threshold) 209 | 210 | def neuronal_reset(self): 211 | if self.detach_reset: 212 | spike = self.spike.detach() 213 | else: 214 | spike = self.spike 215 | 216 | if self.v_reset is None: 217 | # soft reset 218 | self.v = self.v - spike * self.v_threshold 219 | 220 | else: 221 | # hard reset 222 | self.v = (1. - spike) * self.v + spike * self.v_reset 223 | 224 | def extra_repr(self): 225 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}' 226 | 227 | def forward(self, x: torch.Tensor, last_mem: torch.Tensor, w: torch.Tensor): 228 | if last_mem is None: 229 | self.neuronal_charge(x, w) 230 | else: 231 | self.register_memory('v', last_mem) 232 | self.neuronal_charge(x, w) 233 | 234 | return self.v 235 | 236 | class Ada_MpNode_adaspike(base.MemoryModule): 237 | def __init__(self, v_threshold: float = 1., v_reset: float = 0., 238 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): 239 | assert isinstance(v_reset, float) or v_reset is None 240 | assert isinstance(v_threshold, float) 241 | assert isinstance(detach_reset, bool) 242 | super().__init__() 243 | 244 | 245 | if v_reset is None: 246 | self.register_memory('v', 0.) 247 | self.register_memory('spike', 0.) 248 | else: 249 | self.register_memory('v', v_reset) 250 | self.register_memory('spike', 0.) 
251 | 252 | 253 | self.v_threshold = v_threshold 254 | self.v_reset = v_reset 255 | 256 | self.detach_reset = detach_reset 257 | self.surrogate_function = surrogate_function 258 | 259 | @abstractmethod 260 | def neuronal_charge(self, x: torch.Tensor): 261 | raise NotImplementedError 262 | 263 | def neuronal_fire(self): 264 | self.spike = self.surrogate_function(self.v - self.v_threshold) 265 | 266 | def neuronal_reset(self): 267 | if self.detach_reset: 268 | spike = self.spike.detach() 269 | else: 270 | spike = self.spike 271 | 272 | if self.v_reset is None: 273 | # soft reset 274 | self.v = self.v - spike * self.v_threshold 275 | 276 | else: 277 | # hard reset 278 | self.v = (1. - spike) * self.v + spike * self.v_reset 279 | 280 | def extra_repr(self): 281 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}' 282 | 283 | def forward(self, x: torch.Tensor, last_mem: torch.Tensor, w: torch.Tensor, s: torch.Tensor): 284 | if last_mem is None: 285 | self.neuronal_charge(x, w, s) 286 | else: 287 | self.register_memory('v', last_mem) 288 | self.neuronal_charge(x, w, s) 289 | return self.v 290 | 291 | class Multi_Node(base.MemoryModule): 292 | def __init__(self, v_threshold: float = 1., v_reset: float = 0., 293 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): 294 | assert isinstance(v_reset, float) or v_reset is None 295 | assert isinstance(v_threshold, float) 296 | assert isinstance(detach_reset, bool) 297 | super().__init__() 298 | 299 | if v_reset is None: 300 | self.register_memory('v', 0.) 301 | self.register_memory('spike', 0.) 302 | else: 303 | self.register_memory('v', v_reset) 304 | self.register_memory('spike', 0.) 305 | 306 | self.v_threshold = v_threshold 307 | self.v_reset = v_reset 308 | 309 | self.detach_reset = detach_reset 310 | self.surrogate_function = surrogate_function 311 | 312 | @abstractmethod 313 | def neuronal_charge(self, x: torch.Tensor): 314 | raise NotImplementedError 315 | 316 | def neuronal_fire(self): 317 | self.spike = self.surrogate_function(self.v - self.v_threshold) 318 | 319 | def neuronal_reset(self): 320 | if self.detach_reset: 321 | spike = self.spike.detach() 322 | else: 323 | spike = self.spike 324 | 325 | if self.v_reset is None: 326 | # soft reset 327 | self.v = self.v - spike * self.v_threshold 328 | 329 | else: 330 | # hard reset 331 | self.v = (1. - spike) * self.v + spike * self.v_reset 332 | 333 | def extra_repr(self): 334 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}' 335 | 336 | def forward(self, x: torch.Tensor, last_mem: torch.Tensor): 337 | if last_mem is None: 338 | self.neuronal_charge(x) 339 | self.neuronal_fire() 340 | self.neuronal_reset() 341 | else: 342 | self.register_memory('v', last_mem) 343 | self.neuronal_charge(x) 344 | self.neuronal_fire() 345 | self.neuronal_reset() 346 | 347 | return self.spike, self.v 348 | 349 | class MpLIFNode(MpNode): 350 | def __init__(self, tau: float = 2., v_threshold: float = 1., 351 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), 352 | detach_reset: bool = False): 353 | assert isinstance(tau, float) and tau > 1. 
354 | 355 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 356 | self.tau = tau 357 | 358 | def extra_repr(self): 359 | return super().extra_repr() + f', tau={self.tau}' 360 | 361 | def neuronal_charge(self, x: torch.Tensor): 362 | if self.v_reset is None: 363 | self.v = self.v + (x - self.v) / self.tau 364 | 365 | else: 366 | if isinstance(self.v_reset, float) and self.v_reset == 0.: 367 | self.v = self.v + (x - self.v) / self.tau 368 | else: 369 | self.v = self.v + (x - (self.v - self.v_reset)) / self.tau 370 | 371 | class Mp_AdaLIFNode(Ada_MpNode): 372 | def __init__(self, tau: float = 2., v_threshold: float = 1., 373 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), 374 | detach_reset: bool = False): 375 | 376 | assert isinstance(tau, float) and tau > 1. 377 | self.tau = tau 378 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 379 | 380 | 381 | def extra_repr(self): 382 | return super().extra_repr() + f', tau={self.tau}' 383 | 384 | def neuronal_charge(self, x: torch.Tensor, w: torch.Tensor): 385 | tau = w.sigmoid() 386 | if self.v_reset is None: 387 | self.v = self.v + (x - self.v) * tau 388 | 389 | else: 390 | if isinstance(self.v_reset, float) and self.v_reset == 0.: 391 | self.v = self.v + (x - self.v) * tau 392 | else: 393 | self.v = self.v + (x - (self.v - self.v_reset)) * tau 394 | 395 | class Mp_AdaLIFNode_adaspike(Ada_MpNode_adaspike): 396 | def __init__(self, tau: float = 2., v_threshold: float = 1., 397 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), 398 | detach_reset: bool = False): 399 | 400 | assert isinstance(tau, float) and tau > 1. 401 | self.tau = tau 402 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 403 | 404 | 405 | def extra_repr(self): 406 | return super().extra_repr() + f', tau={self.tau}' 407 | 408 | def neuronal_charge(self, x: torch.Tensor, w: torch.Tensor, s: torch.Tensor): 409 | tau = w.sigmoid() 410 | if self.v_reset is None: 411 | self.v = self.v + s * (x - self.v) * tau 412 | 413 | else: 414 | if isinstance(self.v_reset, float) and self.v_reset == 0.: 415 | self.v = self.v + (x - self.v) * tau 416 | else: 417 | self.v = self.v + (x - (self.v - self.v_reset)) * tau 418 | 419 | class MpIFNode(MpNode): 420 | def __init__(self, v_threshold: float = 1., v_reset: float = 0., 421 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): 422 | 423 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 424 | 425 | def neuronal_charge(self, x: torch.Tensor): 426 | self.v = self.v + x 427 | 428 | class Mp_ParametricLIFNode(MpNode): 429 | def __init__(self, init_tau: float = 2.0, v_threshold: float = 1., 430 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), 431 | detach_reset: bool = False): 432 | 433 | assert isinstance(init_tau, float) and init_tau > 1. 434 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 435 | init_w = - math.log(init_tau - 1.) 
436 | self.w = nn.Parameter(torch.as_tensor(init_w)) 437 | 438 | def extra_repr(self): 439 | with torch.no_grad(): 440 | tau = self.w.sigmoid() #.sigmoid() 441 | return super().extra_repr() + f', tau={tau}' 442 | 443 | def neuronal_charge(self, x: torch.Tensor): 444 | if self.v_reset is None: 445 | self.v = self.v + (x - self.v) * self.w.sigmoid() 446 | else: 447 | if self.v_reset == 0.: 448 | self.v = self.v + (x - self.v) * self.w.sigmoid() 449 | else: 450 | self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid() 451 | 452 | class Mp_ParametricLIFNode_modify(MpNode): 453 | def __init__(self, size_h, size_w, init_tau: float = 2.0, v_threshold: float = 1., 454 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), 455 | detach_reset: bool = False): 456 | assert isinstance(init_tau, float) and init_tau > 1. 457 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 458 | init_w = - math.log(init_tau - 1.) 459 | self.w = nn.Parameter(torch.ones(size=[size_w, size_h])* init_w) 460 | #self.w = self.w * init_w # test 461 | 462 | def extra_repr(self): 463 | with torch.no_grad(): 464 | tau = self.w.sigmoid() #.sigmoid() 465 | return super().extra_repr() + f', tau={tau}' 466 | 467 | def neuronal_charge(self, x: torch.Tensor): 468 | if self.v_reset is None: 469 | self.v = self.v + (x - self.v) * self.w.sigmoid() 470 | else: 471 | if self.v_reset == 0.: 472 | self.v = self.v + (x - self.v) * self.w.sigmoid() 473 | else: 474 | self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid() 475 | 476 | class IFNode(BaseNode): 477 | def __init__(self, v_threshold: float = 1., v_reset: float = 0., 478 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): 479 | 480 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 481 | 482 | def neuronal_charge(self, x: torch.Tensor): 483 | self.v = self.v + x 484 | 485 | class LIFNode(BaseNode): 486 | def __init__(self, tau: float = 2., v_threshold: float = 1., 487 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), 488 | detach_reset: bool = False): 489 | 490 | assert isinstance(tau, float) and tau > 1. 491 | 492 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 493 | self.tau = tau 494 | 495 | def extra_repr(self): 496 | return super().extra_repr() + f', tau={self.tau}' 497 | 498 | def neuronal_charge(self, x: torch.Tensor): 499 | if self.v_reset is None: 500 | self.v = self.v + (x - self.v) / self.tau 501 | 502 | else: 503 | if isinstance(self.v_reset, float) and self.v_reset == 0.: 504 | self.v = self.v + (x - self.v) / self.tau 505 | else: 506 | self.v = self.v + (x - (self.v - self.v_reset)) / self.tau 507 | 508 | class LIFNode_adaspike(BaseNode_adaspike): 509 | def __init__(self, tau: float = 2., v_threshold: float = 1., 510 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), 511 | detach_reset: bool = False): 512 | 513 | assert isinstance(tau, float) and tau > 1. 
514 | 515 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 516 | self.tau = tau 517 | 518 | def extra_repr(self): 519 | return super().extra_repr() + f', tau={self.tau}' 520 | 521 | def neuronal_charge(self, x: torch.Tensor, s: torch.Tensor): 522 | if self.v_reset is None: 523 | self.v = self.v + s*(x - self.v) / self.tau 524 | 525 | else: 526 | if isinstance(self.v_reset, float) and self.v_reset == 0.: 527 | self.v = self.v + (x - self.v) / self.tau 528 | else: 529 | self.v = self.v + (x - (self.v - self.v_reset)) / self.tau 530 | 531 | 532 | class ParametricLIFNode(BaseNode): 533 | def __init__(self, init_tau: float = 2.0, v_threshold: float = 1., 534 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), 535 | detach_reset: bool = False): 536 | 537 | assert isinstance(init_tau, float) and init_tau > 1. 538 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) 539 | init_w = - math.log(init_tau - 1.) 540 | self.w = nn.Parameter(torch.as_tensor(init_w)) 541 | 542 | def extra_repr(self): 543 | with torch.no_grad(): 544 | tau = self.w.sigmoid() 545 | return super().extra_repr() + f', tau={tau}' 546 | 547 | def neuronal_charge(self, x: torch.Tensor): 548 | if self.v_reset is None: 549 | self.v = self.v + (x - self.v) * self.w.sigmoid() 550 | else: 551 | if self.v_reset == 0.: 552 | self.v = self.v + (x - self.v) * self.w.sigmoid() 553 | else: 554 | self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid() -------------------------------------------------------------------------------- /lib/network/recon_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .unet import UNet 4 | from lib.config import cfg, args 5 | from .submodules import ConvLayer 6 | 7 | class E2IM(nn.Module): 8 | def __init__(self, num_input_channels=6, 9 | num_output_channels=1, 10 | skip_type="sum", 11 | activation='sigmoid', 12 | num_encoders=4, 13 | base_num_channels=32, 14 | num_residual_blocks=2, 15 | norm="BN", 16 | use_upsample_conv=True): 17 | super(E2IM, self).__init__() 18 | 19 | self.unet = UNet(num_input_channels=num_input_channels, 20 | num_output_channels=num_output_channels, 21 | skip_type=skip_type, 22 | activation=activation, 23 | num_encoders=num_encoders, 24 | base_num_channels=base_num_channels, 25 | num_residual_blocks=num_residual_blocks, 26 | norm=norm, 27 | use_upsample_conv=use_upsample_conv) 28 | 29 | def forward(self, event_tensor): 30 | """ 31 | :param event_tensor: N x num_bins x H x W 32 | :return: a predicted image of size N x 1 x H x W, taking values in [0,1]. 
33 | """ 34 | return self.unet.forward(event_tensor) 35 | 36 | def get_features(self, event_tensor): 37 | img, feat = self.unet.get_features(event_tensor) 38 | return img, feat 39 | 40 | 41 | class E2DPT(nn.Module): 42 | def __init__(self, num_input_channels=6, 43 | num_output_channels=1, 44 | skip_type="sum", 45 | activation='sigmoid', 46 | num_encoders=4, 47 | base_num_channels=32, 48 | num_residual_blocks=2, 49 | norm="BN", 50 | use_upsample_conv=True): 51 | super(E2DPT, self).__init__() 52 | 53 | self.unet = UNet(num_input_channels=num_input_channels, 54 | num_output_channels=num_output_channels, 55 | skip_type=skip_type, 56 | activation=activation, 57 | num_encoders=num_encoders, 58 | base_num_channels=base_num_channels, 59 | num_residual_blocks=num_residual_blocks, 60 | norm=norm, 61 | use_upsample_conv=use_upsample_conv) 62 | 63 | self.mask_head = nn.Sequential(ConvLayer(base_num_channels, base_num_channels // 2, kernel_size=3, padding=1, norm=norm), 64 | ConvLayer(base_num_channels // 2, 1,kernel_size=1, activation=activation, norm=norm)) 65 | 66 | 67 | def forward(self, event_tensor): 68 | """ 69 | :param event_tensor: N x num_bins x H x W 70 | :return: a predicted image of size N x 1 x H x W, taking values in [0,1]. 71 | """ 72 | 73 | depth, feat = self.unet.get_features(event_tensor) 74 | mask = self.mask_head(feat) 75 | return depth * cfg.model.max_depth_value, mask 76 | 77 | def get_features(self, event_tensor): 78 | depth, feat = self.unet.get_features(event_tensor) 79 | mask = self.mask_head(feat) 80 | depth = depth * cfg.model.max_depth_value 81 | 82 | return depth, mask, feat 83 | 84 | 85 | class E2Msk(nn.Module): 86 | def __init__(self, num_input_channels=6, 87 | num_output_channels=1, 88 | skip_type="sum", 89 | activation='sigmoid', 90 | num_encoders=4, 91 | base_num_channels=32, 92 | num_residual_blocks=2, 93 | norm="BN", 94 | use_upsample_conv=True): 95 | super(E2Msk, self).__init__() 96 | 97 | self.unet = UNet(num_input_channels=num_input_channels, 98 | num_output_channels=num_output_channels, 99 | skip_type=skip_type, 100 | activation=activation, 101 | num_encoders=num_encoders, 102 | base_num_channels=base_num_channels, 103 | num_residual_blocks=num_residual_blocks, 104 | norm=norm, 105 | use_upsample_conv=use_upsample_conv) 106 | 107 | 108 | def forward(self, event_tensor): 109 | """ 110 | :param event_tensor: N x num_bins x H x W 111 | :return: a predicted image of size N x 1 x H x W, taking values in [0,1]. 
112 | """ 113 | depth = self.unet.forward(event_tensor) 114 | return depth * cfg.model.max_depth_value 115 | 116 | def get_features(self, event_tensor): 117 | depth, feat = self.unet.get_features(event_tensor) 118 | depth = depth * cfg.model.max_depth_value 119 | 120 | return depth, feat 121 | 122 | -------------------------------------------------------------------------------- /lib/network/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | from torch.hub import load_state_dict_from_url 5 | 6 | 7 | model_urls = { 8 | 'resnet50': 'https://github.com/bubbliiiing/pspnet-pytorch/releases/download/v1.0/resnet50s-a75c83cf.pth', 9 | } 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | "3x3 convolution with padding" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | 16 | class Bottleneck(nn.Module): 17 | expansion = 4 18 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1, 19 | norm_layer=None): 20 | super(Bottleneck, self).__init__() 21 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 22 | self.bn1 = norm_layer(planes) 23 | 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) 25 | self.bn2 = norm_layer(planes) 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 28 | self.bn3 = norm_layer(planes * 4) 29 | 30 | self.relu = nn.ReLU(inplace=True) 31 | 32 | self.downsample = downsample 33 | self.dilation = dilation 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv3(out) 48 | out = self.bn3(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | return out 56 | 57 | 58 | class ResNet(nn.Module): 59 | def __init__(self, block, layers, num_classes=1000, dilated=False, deep_base=True, norm_layer=nn.BatchNorm2d): 60 | self.inplanes = 128 if deep_base else 64 61 | super(ResNet, self).__init__() 62 | if deep_base: 63 | self.conv1 = nn.Sequential( 64 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False), 65 | norm_layer(64), 66 | nn.ReLU(inplace=True), 67 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), 68 | norm_layer(64), 69 | nn.ReLU(inplace=True), 70 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False), 71 | ) 72 | else: 73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 74 | bias=False) 75 | self.bn1 = norm_layer(self.inplanes) 76 | self.relu = nn.ReLU(inplace=True) 77 | 78 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 79 | 80 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 81 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 82 | if dilated: 83 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 84 | dilation=2, norm_layer=norm_layer) 85 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 86 | dilation=4, norm_layer=norm_layer) 87 | else: 88 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 89 | norm_layer=norm_layer) 90 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 91 | 
norm_layer=norm_layer) 92 | 93 | self.avgpool = nn.AvgPool2d(7, stride=1) 94 | self.fc = nn.Linear(512 * block.expansion, num_classes) 95 | 96 | for m in self.modules(): 97 | if isinstance(m, nn.Conv2d): 98 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 99 | m.weight.data.normal_(0, math.sqrt(2. / n)) 100 | elif isinstance(m, norm_layer): 101 | m.weight.data.fill_(1) 102 | m.bias.data.zero_() 103 | 104 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, multi_grid=False): 105 | downsample = None 106 | if stride != 1 or self.inplanes != planes * block.expansion: 107 | downsample = nn.Sequential( 108 | nn.Conv2d(self.inplanes, planes * block.expansion, 109 | kernel_size=1, stride=stride, bias=False), 110 | norm_layer(planes * block.expansion), 111 | ) 112 | 113 | layers = [] 114 | multi_dilations = [4, 8, 16] 115 | if multi_grid: 116 | layers.append(block(self.inplanes, planes, stride, dilation=multi_dilations[0], 117 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 118 | elif dilation == 1 or dilation == 2: 119 | layers.append(block(self.inplanes, planes, stride, dilation=1, 120 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 121 | elif dilation == 4: 122 | layers.append(block(self.inplanes, planes, stride, dilation=2, 123 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 124 | else: 125 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 126 | 127 | self.inplanes = planes * block.expansion 128 | for i in range(1, blocks): 129 | if multi_grid: 130 | layers.append(block(self.inplanes, planes, dilation=multi_dilations[i], 131 | previous_dilation=dilation, norm_layer=norm_layer)) 132 | else: 133 | layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation, 134 | norm_layer=norm_layer)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | 149 | x = self.avgpool(x) 150 | x = x.view(x.size(0), -1) 151 | x = self.fc(x) 152 | 153 | return x 154 | 155 | def resnet50(pretrained=False, **kwargs): 156 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 157 | if pretrained: 158 | model.load_state_dict(load_state_dict_from_url(model_urls['resnet50'], "./model_data"), strict=False) 159 | return model -------------------------------------------------------------------------------- /lib/network/submodules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | from torch.nn import init 5 | 6 | 7 | class ConvLayer(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 9 | super(ConvLayer, self).__init__() 10 | 11 | bias = False if norm == 'BN' else True 12 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 13 | if activation=='relu' or activation == "ReLU": 14 | self.activation = getattr(torch, activation, 'relu') 15 | elif activation=='sigmoid' or activation == 'Sigmoid': 16 | self.activation = getattr(torch, activation, 'sigmoid') 17 | else: 18 | self.activation = None 19 | 20 | self.norm = norm 21 | if norm == 'BN': 22 | self.norm_layer = nn.BatchNorm2d(out_channels) 23 | elif norm 
== 'IN': 24 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 25 | 26 | def forward(self, x): 27 | out = self.conv2d(x) 28 | 29 | if self.norm in ['BN', 'IN']: 30 | out = self.norm_layer(out) 31 | 32 | if self.activation is not None: 33 | out = self.activation(out) 34 | 35 | return out 36 | 37 | 38 | class TransposedConvLayer(nn.Module): 39 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 40 | super(TransposedConvLayer, self).__init__() 41 | 42 | bias = False if norm == 'BN' else True 43 | self.transposed_conv2d = nn.ConvTranspose2d( 44 | in_channels, out_channels, kernel_size, stride=2, padding=padding, output_padding=1, bias=bias) 45 | 46 | if activation is not None: 47 | self.activation = getattr(torch, activation, 'relu') 48 | else: 49 | self.activation = None 50 | 51 | self.norm = norm 52 | if norm == 'BN': 53 | self.norm_layer = nn.BatchNorm2d(out_channels) 54 | elif norm == 'IN': 55 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 56 | 57 | def forward(self, x): 58 | out = self.transposed_conv2d(x) 59 | 60 | if self.norm in ['BN', 'IN']: 61 | out = self.norm_layer(out) 62 | 63 | if self.activation is not None: 64 | out = self.activation(out) 65 | 66 | return out 67 | 68 | 69 | class UpsampleConvLayer(nn.Module): 70 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 71 | super(UpsampleConvLayer, self).__init__() 72 | 73 | bias = False if norm == 'BN' else True 74 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 75 | 76 | if activation is not None: 77 | self.activation = getattr(torch, activation, 'relu') 78 | else: 79 | self.activation = None 80 | 81 | self.norm = norm 82 | if norm == 'BN': 83 | self.norm_layer = nn.BatchNorm2d(out_channels) 84 | elif norm == 'IN': 85 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 86 | 87 | def forward(self, x): 88 | x_upsampled = f.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 89 | out = self.conv2d(x_upsampled) 90 | 91 | if self.norm in ['BN', 'IN']: 92 | out = self.norm_layer(out) 93 | 94 | if self.activation is not None: 95 | out = self.activation(out) 96 | 97 | return out 98 | 99 | 100 | class RecurrentConvLayer(nn.Module): 101 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 102 | recurrent_block_type='convlstm', activation='relu', norm=None): 103 | super(RecurrentConvLayer, self).__init__() 104 | 105 | assert(recurrent_block_type in ['convlstm', 'convgru']) 106 | self.recurrent_block_type = recurrent_block_type 107 | if self.recurrent_block_type == 'convlstm': 108 | RecurrentBlock = ConvLSTM 109 | else: 110 | RecurrentBlock = ConvGRU 111 | self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm) 112 | self.recurrent_block = RecurrentBlock(input_size=out_channels, hidden_size=out_channels, kernel_size=3) 113 | 114 | def forward(self, x, prev_state): 115 | x = self.conv(x) 116 | state = self.recurrent_block(x, prev_state) 117 | x = state[0] if self.recurrent_block_type == 'convlstm' else state 118 | return x, state 119 | 120 | 121 | class DownsampleRecurrentConvLayer(nn.Module): 122 | def __init__(self, in_channels, out_channels, kernel_size=3, recurrent_block_type='convlstm', padding=0, activation='relu'): 123 | super(DownsampleRecurrentConvLayer, self).__init__() 124 | 125 | self.activation = 
getattr(torch, activation, 'relu') 126 | 127 | assert(recurrent_block_type in ['convlstm', 'convgru']) 128 | self.recurrent_block_type = recurrent_block_type 129 | if self.recurrent_block_type == 'convlstm': 130 | RecurrentBlock = ConvLSTM 131 | else: 132 | RecurrentBlock = ConvGRU 133 | self.recurrent_block = RecurrentBlock(input_size=in_channels, hidden_size=out_channels, kernel_size=kernel_size) 134 | 135 | def forward(self, x, prev_state): 136 | state = self.recurrent_block(x, prev_state) 137 | x = state[0] if self.recurrent_block_type == 'convlstm' else state 138 | x = f.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) 139 | return self.activation(x), state 140 | 141 | 142 | # Residual block 143 | class ResidualBlock(nn.Module): 144 | def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm=None): 145 | super(ResidualBlock, self).__init__() 146 | bias = False if norm == 'BN' else True 147 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=bias) 148 | self.norm = norm 149 | if norm == 'BN': 150 | self.bn1 = nn.BatchNorm2d(out_channels) 151 | self.bn2 = nn.BatchNorm2d(out_channels) 152 | elif norm == 'IN': 153 | self.bn1 = nn.InstanceNorm2d(out_channels) 154 | self.bn2 = nn.InstanceNorm2d(out_channels) 155 | 156 | self.relu = nn.ReLU(inplace=True) 157 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 158 | self.downsample = downsample 159 | 160 | def forward(self, x): 161 | residual = x 162 | out = self.conv1(x) 163 | if self.norm in ['BN', 'IN']: 164 | out = self.bn1(out) 165 | out = self.relu(out) 166 | out = self.conv2(out) 167 | if self.norm in ['BN', 'IN']: 168 | out = self.bn2(out) 169 | 170 | if self.downsample: 171 | residual = self.downsample(x) 172 | 173 | out += residual 174 | out = self.relu(out) 175 | return out 176 | 177 | 178 | class ConvLSTM(nn.Module): 179 | """Adapted from: https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py """ 180 | 181 | def __init__(self, input_size, hidden_size, kernel_size): 182 | super(ConvLSTM, self).__init__() 183 | 184 | self.input_size = input_size 185 | self.hidden_size = hidden_size 186 | pad = kernel_size // 2 187 | 188 | # cache a tensor filled with zeros to avoid reallocating memory at each inference step if --no-recurrent is enabled 189 | self.zero_tensors = {} 190 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad) 191 | 192 | def forward(self, input_, prev_state=None): 193 | 194 | # get batch and spatial sizes 195 | batch_size = input_.data.size()[0] 196 | spatial_size = input_.data.size()[2:] 197 | 198 | # generate empty prev_state, if None is provided 199 | if prev_state is None: 200 | 201 | # create the zero tensor if it has not been created already 202 | state_size = tuple([batch_size, self.hidden_size] + list(spatial_size)) 203 | if state_size not in self.zero_tensors: 204 | # allocate a tensor with size `spatial_size`, filled with zero (if it has not been allocated already) 205 | self.zero_tensors[state_size] = ( 206 | torch.zeros(state_size).to(input_.device), 207 | torch.zeros(state_size).to(input_.device) 208 | ) 209 | 210 | prev_state = self.zero_tensors[tuple(state_size)] 211 | 212 | prev_hidden, prev_cell = prev_state 213 | 214 | # data size is [batch, channel, height, width] 215 | stacked_inputs = torch.cat((input_, prev_hidden), 1) 216 | gates = self.Gates(stacked_inputs) 217 | 218 | # chunk across channel dimension 
219 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1) 220 | 221 | # apply sigmoid non linearity 222 | in_gate = torch.sigmoid(in_gate) 223 | remember_gate = torch.sigmoid(remember_gate) 224 | out_gate = torch.sigmoid(out_gate) 225 | 226 | # apply tanh non linearity 227 | cell_gate = torch.tanh(cell_gate) 228 | 229 | # compute current cell and hidden state 230 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate) 231 | hidden = out_gate * torch.tanh(cell) 232 | 233 | return hidden, cell 234 | 235 | 236 | class ConvGRU(nn.Module): 237 | """ 238 | Generate a convolutional GRU cell 239 | Adapted from: https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py 240 | """ 241 | 242 | def __init__(self, input_size, hidden_size, kernel_size): 243 | super().__init__() 244 | padding = kernel_size // 2 245 | self.input_size = input_size 246 | self.hidden_size = hidden_size 247 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 248 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 249 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 250 | 251 | init.orthogonal_(self.reset_gate.weight) 252 | init.orthogonal_(self.update_gate.weight) 253 | init.orthogonal_(self.out_gate.weight) 254 | init.constant_(self.reset_gate.bias, 0.) 255 | init.constant_(self.update_gate.bias, 0.) 256 | init.constant_(self.out_gate.bias, 0.) 257 | 258 | def forward(self, input_, prev_state): 259 | 260 | # get batch and spatial sizes 261 | batch_size = input_.data.size()[0] 262 | spatial_size = input_.data.size()[2:] 263 | 264 | # generate empty prev_state, if None is provided 265 | if prev_state is None: 266 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 267 | prev_state = torch.zeros(state_size).to(input_.device) 268 | 269 | # data size is [batch, channel, height, width] 270 | stacked_inputs = torch.cat([input_, prev_state], dim=1) 271 | update = torch.sigmoid(self.update_gate(stacked_inputs)) 272 | reset = torch.sigmoid(self.reset_gate(stacked_inputs)) 273 | out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1))) 274 | new_state = prev_state * (1 - update) + out_inputs * update 275 | 276 | return new_state 277 | 278 | class RecurrentResidualLayer(nn.Module): 279 | def __init__(self, in_channels, out_channels, 280 | recurrent_block_type='convlstm', norm=None): 281 | super(RecurrentResidualLayer, self).__init__() 282 | 283 | assert(recurrent_block_type in ['convlstm', 'convgru']) 284 | self.recurrent_block_type = recurrent_block_type 285 | if self.recurrent_block_type == 'convlstm': 286 | RecurrentBlock = ConvLSTM 287 | else: 288 | RecurrentBlock = ConvGRU 289 | self.conv = ResidualBlock(in_channels=in_channels, 290 | out_channels=out_channels, 291 | norm=norm) 292 | self.recurrent_block = RecurrentBlock(input_size=out_channels, 293 | hidden_size=out_channels, 294 | kernel_size=3) 295 | 296 | def forward(self, x, prev_state): 297 | x = self.conv(x) 298 | state = self.recurrent_block(x, prev_state) 299 | x = state[0] if self.recurrent_block_type == 'convlstm' else state 300 | return x, state 301 | 302 | -------------------------------------------------------------------------------- /lib/network/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | from torch.nn import init 5 | from 
.submodules import ConvLayer, UpsampleConvLayer, TransposedConvLayer, RecurrentConvLayer, ResidualBlock, ConvLSTM, ConvGRU, RecurrentResidualLayer 6 | 7 | 8 | def skip_concat(x1, x2): 9 | return torch.cat([x1, x2], dim=1) 10 | 11 | 12 | def skip_sum(x1, x2): 13 | return x1 + x2 14 | 15 | 16 | class BaseUNet(nn.Module): 17 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid', 18 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True): 19 | super(BaseUNet, self).__init__() 20 | 21 | self.num_input_channels = num_input_channels 22 | self.num_output_channels = num_output_channels 23 | self.skip_type = skip_type 24 | self.apply_skip_connection = skip_sum if self.skip_type == 'sum' else skip_concat 25 | self.activation = activation 26 | self.norm = norm 27 | 28 | if use_upsample_conv: 29 | print('Using UpsampleConvLayer (slow, but no checkerboard artefacts)') 30 | self.UpsampleLayer = UpsampleConvLayer 31 | else: 32 | print('Using TransposedConvLayer (fast, with checkerboard artefacts)') 33 | self.UpsampleLayer = TransposedConvLayer 34 | 35 | self.num_encoders = num_encoders 36 | self.base_num_channels = base_num_channels 37 | self.num_residual_blocks = num_residual_blocks 38 | self.max_num_channels = self.base_num_channels * pow(2, self.num_encoders) 39 | 40 | assert(self.num_input_channels > 0) 41 | assert(self.num_output_channels > 0) 42 | 43 | self.encoder_input_sizes = [] 44 | for i in range(self.num_encoders): 45 | self.encoder_input_sizes.append(self.base_num_channels * pow(2, i)) 46 | 47 | self.encoder_output_sizes = [self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)] 48 | 49 | self.activation = getattr(torch, self.activation, 'sigmoid') 50 | 51 | def build_resblocks(self): 52 | self.resblocks = nn.ModuleList() 53 | for i in range(self.num_residual_blocks): 54 | self.resblocks.append(ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm)) 55 | 56 | def build_decoders(self): 57 | decoder_input_sizes = list(reversed([self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)])) 58 | 59 | self.decoders = nn.ModuleList() 60 | for input_size in decoder_input_sizes: 61 | self.decoders.append(self.UpsampleLayer(input_size if self.skip_type == 'sum' else 2 * input_size, 62 | input_size // 2, 63 | kernel_size=5, padding=2, norm=self.norm)) 64 | 65 | def build_prediction_layer(self): 66 | self.pred = ConvLayer(self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels, 67 | self.num_output_channels, 1, activation=None, norm=self.norm) 68 | 69 | 70 | class UNet(BaseUNet): 71 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid', 72 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True): 73 | super(UNet, self).__init__(num_input_channels, num_output_channels, skip_type, activation, 74 | num_encoders, base_num_channels, num_residual_blocks, norm, use_upsample_conv) 75 | 76 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels, 77 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W 78 | 79 | self.encoders = nn.ModuleList() 80 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes): 81 | self.encoders.append(ConvLayer(input_size, output_size, kernel_size=5, 82 | stride=2, padding=2, norm=self.norm)) 83 | 84 | self.build_resblocks() 85 | self.build_decoders() 86 | 
self.build_prediction_layer() 87 | 88 | def forward(self, x): 89 | """ 90 | :param x: N x num_input_channels x H x W 91 | :return: N x num_output_channels x H x W 92 | """ 93 | 94 | # head 95 | x = self.head(x) 96 | head = x 97 | 98 | # encoder 99 | blocks = [] 100 | for i, encoder in enumerate(self.encoders): 101 | x = encoder(x) 102 | blocks.append(x) 103 | 104 | # residual blocks 105 | for resblock in self.resblocks: 106 | x = resblock(x) 107 | 108 | # decoder 109 | for i, decoder in enumerate(self.decoders): 110 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1])) 111 | 112 | img = self.activation(self.pred(self.apply_skip_connection(x, head))) 113 | 114 | return img 115 | 116 | def get_features(self, x): 117 | """ 118 | :param x: N x num_input_channels x H x W 119 | :return: N x num_output_channels x H x W 120 | """ 121 | 122 | # head 123 | x = self.head(x) 124 | head = x 125 | 126 | # encoder 127 | blocks = [] 128 | for i, encoder in enumerate(self.encoders): 129 | x = encoder(x) 130 | blocks.append(x) 131 | 132 | # residual blocks 133 | for resblock in self.resblocks: 134 | x = resblock(x) 135 | 136 | mid_feat = x.clone() 137 | 138 | # decoder 139 | for i, decoder in enumerate(self.decoders): 140 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1])) 141 | 142 | final_feature = x.clone() 143 | img = self.activation(self.pred(self.apply_skip_connection(x, head))) 144 | 145 | return img, final_feature 146 | 147 | 148 | class UNetRecurrent(BaseUNet): 149 | """ 150 | Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block, 151 | such as a ConvLSTM or a ConvGRU. 152 | Symmetric, skip connections on every encoding layer. 153 | """ 154 | 155 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', 156 | recurrent_block_type='convlstm', activation='sigmoid', num_encoders=4, base_num_channels=32, 157 | num_residual_blocks=2, norm=None, use_upsample_conv=True): 158 | super(UNetRecurrent, self).__init__(num_input_channels, num_output_channels, skip_type, activation, 159 | num_encoders, base_num_channels, num_residual_blocks, norm, 160 | use_upsample_conv) 161 | 162 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels, 163 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W 164 | 165 | self.encoders = nn.ModuleList() 166 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes): 167 | self.encoders.append(RecurrentConvLayer(input_size, output_size, 168 | kernel_size=5, stride=2, padding=2, 169 | recurrent_block_type=recurrent_block_type, 170 | norm=self.norm)) 171 | 172 | self.build_resblocks() 173 | self.build_decoders() 174 | self.build_prediction_layer() 175 | 176 | def forward(self, x, prev_states): 177 | """ 178 | :param x: N x num_input_channels x H x W 179 | :param prev_states: previous LSTM states for every encoder layer 180 | :return: N x num_output_channels x H x W 181 | """ 182 | 183 | # head 184 | x = self.head(x) 185 | head = x 186 | 187 | if prev_states is None: 188 | prev_states = [None] * self.num_encoders 189 | 190 | # encoder 191 | blocks = [] 192 | states = [] 193 | for i, encoder in enumerate(self.encoders): 194 | x, state = encoder(x, prev_states[i]) 195 | blocks.append(x) 196 | states.append(state) 197 | 198 | # residual blocks 199 | for resblock in self.resblocks: 200 | x = resblock(x) 201 | 202 | # decoder 203 | for i, decoder in enumerate(self.decoders): 204 | x = 
decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1])) 205 | 206 | # tail 207 | img = self.activation(self.pred(self.apply_skip_connection(x, head))) 208 | 209 | return img, states 210 | 211 | 212 | class UNetFire(BaseUNet): 213 | """ 214 | """ 215 | 216 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', 217 | recurrent_block_type='convgru', base_num_channels=16, 218 | num_residual_blocks=2, norm=None, kernel_size=3, 219 | recurrent_blocks={'resblock': [0]}): 220 | super(UNetFire, self).__init__(num_input_channels=num_input_channels, 221 | num_output_channels=num_output_channels, 222 | skip_type=skip_type, 223 | base_num_channels=base_num_channels, 224 | num_residual_blocks=num_residual_blocks, 225 | norm=norm) 226 | 227 | self.kernel_size = kernel_size 228 | self.recurrent_blocks = recurrent_blocks 229 | self.head = RecurrentConvLayer(self.num_input_channels, 230 | self.base_num_channels, 231 | kernel_size=self.kernel_size, 232 | padding=self.kernel_size // 2, 233 | recurrent_block_type=recurrent_block_type, 234 | norm=self.norm) 235 | 236 | self.num_recurrent_units = 1 237 | self.resblocks = nn.ModuleList() 238 | recurrent_indices = self.recurrent_blocks.get('resblock', []) 239 | for i in range(self.num_residual_blocks): 240 | if i in recurrent_indices or -1 in recurrent_indices: 241 | self.resblocks.append(RecurrentResidualLayer( 242 | in_channels=self.base_num_channels, 243 | out_channels=self.base_num_channels, 244 | recurrent_block_type=recurrent_block_type, 245 | norm=self.norm)) 246 | self.num_recurrent_units += 1 247 | else: 248 | self.resblocks.append(ResidualBlock(self.base_num_channels, 249 | self.base_num_channels, 250 | norm=self.norm)) 251 | 252 | self.pred = ConvLayer(2 * self.base_num_channels if self.skip_type == 'concat' else self.base_num_channels, 253 | self.num_output_channels, kernel_size=1, padding=0, activation=None, norm=None) 254 | 255 | def forward(self, x, prev_states): 256 | """ 257 | :param x: N x num_input_channels x H x W 258 | :param prev_states: previous LSTM states for every encoder layer 259 | :return: N x num_output_channels x H x W 260 | """ 261 | 262 | if prev_states is None: 263 | prev_states = [None] * (self.num_recurrent_units) 264 | 265 | states = [] 266 | state_idx = 0 267 | 268 | # head 269 | x, state = self.head(x, prev_states[state_idx]) 270 | state_idx += 1 271 | states.append(state) 272 | 273 | # residual blocks 274 | recurrent_indices = self.recurrent_blocks.get('resblock', []) 275 | for i, resblock in enumerate(self.resblocks): 276 | if i in recurrent_indices or -1 in recurrent_indices: 277 | x, state = resblock(x, prev_states[state_idx]) 278 | state_idx += 1 279 | states.append(state) 280 | else: 281 | x = resblock(x) 282 | 283 | # tail 284 | img = self.pred(x) 285 | return img, states 286 | 287 | 288 | 289 | class UNetStatic(BaseUNet): 290 | """ 291 | """ 292 | 293 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', 294 | recurrent_block_type='convgru', base_num_channels=16, 295 | num_residual_blocks=2, norm=None, kernel_size=3, 296 | recurrent_blocks={'resblock': [0]}): 297 | super(UNetStatic, self).__init__(num_input_channels=num_input_channels, 298 | num_output_channels=num_output_channels, 299 | skip_type=skip_type, 300 | base_num_channels=base_num_channels, 301 | num_residual_blocks=num_residual_blocks, 302 | norm=norm) 303 | 304 | self.kernel_size = kernel_size 305 | self.recurrent_blocks = recurrent_blocks 306 | self.head = 
ConvLayer(self.num_input_channels, 307 | self.base_num_channels, 308 | kernel_size=self.kernel_size, 309 | padding=self.kernel_size // 2, 310 | norm=self.norm) 311 | 312 | self.num_recurrent_units = 1 313 | self.resblocks = nn.ModuleList() 314 | 315 | self.resblocks.append(ResidualBlock(self.base_num_channels, 316 | self.base_num_channels, 317 | norm=self.norm)) 318 | 319 | self.pred = ConvLayer(2 * self.base_num_channels if self.skip_type == 'concat' else self.base_num_channels, 320 | self.num_output_channels, kernel_size=1, padding=0, activation='relu', norm=None) 321 | 322 | def forward(self, x): 323 | """ 324 | :param x: N x num_input_channels x H x W 325 | :param prev_states: previous LSTM states for every encoder layer 326 | :return: N x num_output_channels x H x W 327 | """ 328 | # head 329 | x = self.head(x) 330 | 331 | # residual blocks 332 | for i, resblock in enumerate(self.resblocks): 333 | x = resblock(x) 334 | 335 | # tail 336 | img = self.pred(x) 337 | img = torch.clamp_max(img, 1.0) 338 | return img -------------------------------------------------------------------------------- /lib/recorder.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import shutil 5 | import logging 6 | from pathlib import Path 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | def file_backup(exp_path, cfg, train_script): 11 | shutil.copy(train_script, exp_path) 12 | shutil.copytree('core', os.path.join(exp_path, 'core'), dirs_exist_ok=True) 13 | shutil.copytree('config', os.path.join(exp_path, 'config'), dirs_exist_ok=True) 14 | shutil.copytree('gaussian_renderer', os.path.join(exp_path, 'gaussian_renderer'), dirs_exist_ok=True) 15 | for sub_dir in ['lib']: 16 | files = os.listdir(sub_dir) 17 | for file in files: 18 | Path(os.path.join(exp_path, sub_dir)).mkdir(exist_ok=True, parents=True) 19 | if file[-3:] == '.py': 20 | shutil.copy(os.path.join(sub_dir, file), os.path.join(exp_path, sub_dir)) 21 | 22 | json_file_name = exp_path + '/cfg.json' 23 | with open(json_file_name, 'w') as json_file: 24 | json.dump(cfg, json_file, indent=2) 25 | 26 | class Logger: 27 | def __init__(self, scheduler, cfg): 28 | self.scheduler = scheduler 29 | self.sum_freq = cfg.loss_freq 30 | self.log_dir = cfg.logs_path 31 | self.total_steps = 0 32 | self.running_loss = {} 33 | self.writer = SummaryWriter(log_dir=self.log_dir) 34 | 35 | def _print_training_status(self): 36 | metrics_data = [self.running_loss[k] / self.sum_freq for k in sorted(self.running_loss.keys())] 37 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps, self.scheduler.get_last_lr()[0]) 38 | metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data) 39 | 40 | # print the training status 41 | # print(" in print training status : ", f"steps : {self.total_steps}): {training_str + metrics_str}") 42 | # logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}") 43 | 44 | if self.writer is None: 45 | self.writer = SummaryWriter(log_dir=self.log_dir) 46 | 47 | for k in self.running_loss: 48 | self.writer.add_scalar(k, self.running_loss[k] / self.sum_freq, self.total_steps) 49 | self.running_loss[k] = 0.0 50 | 51 | def push(self, metrics): 52 | for key in metrics: 53 | if key not in self.running_loss: 54 | self.running_loss[key] = 0.0 55 | 56 | self.running_loss[key] += metrics[key] 57 | 58 | if self.total_steps and self.total_steps % self.sum_freq == 0: 59 | self._print_training_status() 60 | self.running_loss = {} 61 | 62 | 
self.total_steps += 1 63 | 64 | 65 | def write_dict(self, results, write_step): 66 | if self.writer is None: 67 | self.writer = SummaryWriter(log_dir=self.log_dir) 68 | 69 | for key in results: 70 | self.writer.add_scalar(key, results[key], write_step) 71 | 72 | def close(self): 73 | self.writer.close() -------------------------------------------------------------------------------- /lib/renderer/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_render import pts2render 2 | from .rend_utils import depth2pc -------------------------------------------------------------------------------- /lib/renderer/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/renderer/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/renderer/__pycache__/gaussian_render.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/gaussian_render.cpython-37.pyc -------------------------------------------------------------------------------- /lib/renderer/__pycache__/gaussian_render.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/gaussian_render.cpython-39.pyc -------------------------------------------------------------------------------- /lib/renderer/__pycache__/rend_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/rend_utils.cpython-37.pyc -------------------------------------------------------------------------------- /lib/renderer/gaussian_render.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 4 | 5 | def render(cam, idx, pts_xyz, pts_rgb, rotations, scales, opacity, bg_color): 6 | """ 7 | Render the scene. 8 | Background tensor (bg_color) must be on GPU! 9 | """ 10 | 11 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 12 | bg_color = torch.tensor(bg_color, dtype=torch.float32, device=pts_xyz.device) 13 | screenspace_points = torch.zeros_like(pts_xyz, dtype=torch.float32, requires_grad=True, device=pts_xyz.device) + 0 14 | try: 15 | screenspace_points.retain_grad() 16 | except: 17 | pass 18 | 19 | # Set up rasterization configuration 20 | tanfovx = math.tan(cam['FovX'][idx] * 0.5) 21 | tanfovy = math.tan(cam['FovY'][idx] * 0.5) 22 | 23 | raster_settings = GaussianRasterizationSettings( 24 | image_height=int(cam['H'][idx]), 25 | image_width=int(cam['W'][idx]), 26 | tanfovx=tanfovx, 27 | tanfovy=tanfovy, 28 | bg=bg_color, 29 | scale_modifier=1.0, 30 | viewmatrix=cam['world_view_transform'][idx], 31 | projmatrix=cam['full_proj_transform'][idx], 32 | sh_degree=3, 33 | campos=cam['camera_center'][idx], 34 | prefiltered=False, 35 | debug=True 36 | ) 37 | 38 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 39 | 40 | rendered_image, _ = rasterizer( 41 | means3D=pts_xyz, 42 | means2D=screenspace_points, 43 | shs=None, 44 | colors_precomp=pts_rgb, 45 | opacities=opacity, 46 | scales=scales, 47 | rotations=rotations, 48 | cov3D_precomp=None) 49 | # print("render image shape : ", rendered_image.shape) 50 | 51 | return rendered_image 52 | 53 | # def pts2render(cam, pcd, gs_scaling, gs_opacity, gs_rotation, bg_color): 54 | # bs = pcd.shape[0] 55 | # # print(gs_data) 56 | # render_novel_list = [] 57 | # for i in range(bs): 58 | # xyz_i = pcd[i, :3, :].permute(1, 0) 59 | # rgb_i = pcd[i, 3:6, :].permute(1, 0) 60 | # scale_i = gs_scaling[i].permute(1, 0) 61 | # opacity_i = gs_opacity[i].permute(1, 0) 62 | # rot_i = gs_rotation[i].permute(1, 0) 63 | # render_novel_i = render(cam, i, xyz_i, rgb_i, rot_i, scale_i, opacity_i, bg_color=bg_color) 64 | # render_novel_list.append(render_novel_i) 65 | 66 | # return torch.stack(render_novel_list, dim=0) 67 | 68 | def pts2render(data, bg_color): 69 | bs = data['lview']['img'].shape[0] 70 | render_novel_list = [] 71 | for i in range(bs): 72 | xyz_i_valid = [] 73 | rgb_i_valid = [] 74 | rot_i_valid = [] 75 | scale_i_valid = [] 76 | opacity_i_valid = [] 77 | for view in ['lview', 'rview']: 78 | valid_i = data[view]['pts_valid'][i, :].bool() 79 | xyz_i = data[view]['pts'][i, :, :] 80 | rgb_i = data[view]['img'][i, :, :, :].permute(1, 2, 0).view(-1, 1) 81 | rot_i = data[view]['rot'][i, :, :, :].permute(1, 2, 0).view(-1, 4) 82 | scale_i = data[view]['scale'][i, :, :, :].permute(1, 2, 0).view(-1, 3) 83 | opacity_i = data[view]['opacity'][i, :, :, :].permute(1, 2, 0).view(-1, 1) 84 | 85 | xyz_i_valid.append(xyz_i[valid_i].view(-1, 3)) 86 | rgb_i_valid.append(rgb_i[valid_i].view(-1, 1)) 87 | rot_i_valid.append(rot_i[valid_i].view(-1, 4)) 88 | scale_i_valid.append(scale_i[valid_i].view(-1, 3)) 89 | opacity_i_valid.append(opacity_i[valid_i].view(-1, 1)) 90 | 91 | pts_xyz_i = torch.concat(xyz_i_valid, dim=0) 92 | pts_rgb_i = torch.concat(rgb_i_valid, dim=0).repeat((1,3)) 93 | # pts_rgb_i = pts_rgb_i * 0.5 + 0.5 94 | rot_i = torch.concat(rot_i_valid, dim=0) 95 | scale_i = torch.concat(scale_i_valid, dim=0) 96 | opacity_i = torch.concat(opacity_i_valid, dim=0) 97 | 98 | render_novel_i = render(data["target"], i, pts_xyz_i, pts_rgb_i, rot_i, scale_i, opacity_i, bg_color=bg_color) 99 | render_novel_list.append(render_novel_i.unsqueeze(0)) 100 | 101 | return torch.concat(render_novel_list, dim=0) 102 | -------------------------------------------------------------------------------- 
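The renderer above consumes per-view dictionaries of per-pixel Gaussian parameters (as produced by the `GSRegressor` heads) plus a `target` camera dictionary built by `preprocess_render`. The following is a minimal sketch of that data layout, not code from the repository: the 260x346 resolution, the random tensors, and the helper name `make_view` are illustrative assumptions, and actually calling `pts2render` additionally requires the compiled `diff-gaussian-rasterization` extension and a CUDA device.

```
# Illustrative sketch of the inputs pts2render(data, bg_color) expects (assumed shapes).
import torch

B, H, W = 1, 260, 346  # assumed event-camera resolution, not taken from the repo

def make_view(device="cuda"):
    """One source view in the layout pts2render indexes (all values here are placeholders)."""
    return {
        "img":       torch.rand(B, 1, H, W, device=device),         # predicted intensity in [0, 1]
        "pts":       torch.rand(B, H * W, 3, device=device),        # unprojected 3D points, e.g. from depth2pc
        "pts_valid": torch.ones(B, H * W, device=device),           # per-pixel validity mask (cast to bool inside)
        "rot":       torch.rand(B, 4, H, W, device=device),         # quaternions from the rotation head
        "scale":     torch.rand(B, 3, H, W, device=device) * 1e-3,  # per-axis scales (kept small upstream)
        "opacity":   torch.rand(B, 1, H, W, device=device),         # opacities in [0, 1]
    }

data = {
    "lview": make_view(),
    "rview": make_view(),
    # "target" must be the camera dict built by preprocess_render(batch): it supplies
    # H, W, FovX, FovY, world_view_transform, full_proj_transform and camera_center.
}
# With the rasterization extension compiled and "target" filled in:
#   novel = pts2render(data, bg_color=[0.0, 0.0, 0.0])   # -> [B, 3, H, W] rendered novel views
```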
/lib/renderer/rend_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as Rot 5 | from scipy.spatial.transform import Slerp 6 | 7 | 8 | def focal2fov(focal, pixels): 9 | return 2 * math.atan(pixels / (2 * focal)) 10 | 11 | 12 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 13 | Rt = np.zeros((4, 4)) 14 | Rt[:3, :3] = R.transpose() 15 | Rt[:3, 3] = t 16 | Rt[3, 3] = 1.0 17 | 18 | C2W = np.linalg.inv(Rt) 19 | cam_center = C2W[:3, 3] 20 | cam_center = (cam_center + translate) * scale 21 | C2W[:3, 3] = cam_center 22 | Rt = np.linalg.inv(C2W) 23 | return np.float32(Rt) 24 | 25 | 26 | def getProjectionMatrix(znear, zfar, K, h, w): 27 | near_fx = znear / K[0, 0] 28 | near_fy = znear / K[1, 1] 29 | left = - (w - K[0, 2]) * near_fx 30 | right = K[0, 2] * near_fx 31 | bottom = (K[1, 2] - h) * near_fy 32 | top = K[1, 2] * near_fy 33 | 34 | P = torch.zeros(4, 4) 35 | z_sign = 1.0 36 | P[0, 0] = 2.0 * znear / (right - left) 37 | P[1, 1] = 2.0 * znear / (top - bottom) 38 | P[0, 2] = (right + left) / (right - left) 39 | P[1, 2] = (top + bottom) / (top - bottom) 40 | P[3, 2] = z_sign 41 | P[2, 2] = z_sign * zfar / (zfar - znear) 42 | P[2, 3] = -(zfar * znear) / (zfar - znear) 43 | return P 44 | 45 | 46 | def preprocess_render(batch): 47 | H, W = batch['H'][0], batch['W'][0] 48 | extrs = batch["cam_extrinsics"] 49 | intrs = batch["cam_intrinsics"] 50 | # znear, zfar = batch["znear"], batch["zfar"] 51 | znear, zfar = 0.5, 10000 52 | B = extrs.shape[0] # ` 53 | 54 | proj_mat = [getProjectionMatrix(znear, zfar, intrs[i], H, W).transpose(0, 1) for i in range(B)] 55 | world_view_transform = [ 56 | getWorld2View2(extrs[i][:3, :3].reshape(3, 3).transpose(1, 0), extrs[i][:3, 3]).transpose(0, 1) for i in 57 | range(B)] 58 | proj_mat = torch.stack(proj_mat, dim=0) # [4,4] 59 | # print("proj mat = ", proj_mat) 60 | 61 | world_view_transform = torch.stack(world_view_transform, dim=0) # [4,4] 62 | 63 | full_proj_transform = (world_view_transform.bmm(proj_mat)) 64 | camera_center = world_view_transform.inverse()[:, 3, :3] 65 | 66 | FovX = [torch.FloatTensor([focal2fov(intrs[i][0, 0], W)]) for i in range(B)] 67 | 68 | # print("111",FovX[0]) 69 | FovY = [torch.FloatTensor([focal2fov(intrs[i][1, 1], H)]) for i in range(B)] 70 | 71 | return {"projection_matrix": proj_mat, 72 | "world_view_transform": world_view_transform, 73 | "full_proj_transform": full_proj_transform, 74 | "camera_center": camera_center, 75 | "H": torch.ones(B) * H, 76 | "W": torch.ones(B) * W, 77 | "FovX": torch.stack(FovX, dim=0), 78 | "FovY": torch.stack(FovY, dim=0) 79 | } 80 | 81 | def depth2pc(depth, extrinsic, intrinsic): 82 | B, C, H, W = depth.shape 83 | depth = depth[:, 0, :, :] 84 | rot = extrinsic[:, :3, :3] 85 | trans = extrinsic[:, :3, 3:] 86 | 87 | y, x = torch.meshgrid(torch.linspace(0.5, H-0.5, H, device=depth.device), torch.linspace(0.5, W-0.5, W, device=depth.device)) 88 | pts_2d = torch.stack([x, y, torch.ones_like(x)], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1) # B S S 3 89 | 90 | pts_2d[..., 2] = depth 91 | pts_2d[:, :, :, 0] -= intrinsic[:, None, None, 0, 2] 92 | pts_2d[:, :, :, 1] -= intrinsic[:, None, None, 1, 2] 93 | pts_2d_xy = pts_2d[:, :, :, :2] * pts_2d[:, :, :, 2:] 94 | pts_2d = torch.cat([pts_2d_xy, pts_2d[..., 2:]], dim=-1) 95 | 96 | pts_2d[..., 0] /= intrinsic[:, 0, 0][:, None, None] 97 | pts_2d[..., 1] /= intrinsic[:, 1, 1][:, None, None] 98 | 99 | pts_2d = 
    pts_2d = pts_2d.view(B, -1, 3).permute(0, 2, 1)
    rot_t = rot.permute(0, 2, 1)
    pts = torch.bmm(rot_t, pts_2d) - torch.bmm(rot_t, trans)

    return pts.permute(0, 2, 1)
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
import matplotlib as mpl
from matplotlib import cm
import cv2

to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)


def get_vertical_colorbar(h, vmin, vmax, cmap_name='jet', label=None):
    fig = Figure(figsize=(1.2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)
    cmap = cm.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(ax, cmap=cmap,
                                    norm=norm,
                                    ticks=tick_loc,
                                    orientation='vertical')
    tick_label = ['{:3.2f}'.format(x) for x in tick_loc]
    cb1.set_ticklabels(tick_label)
    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)
    fig.tight_layout()
    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()
    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    im = im[:, :, :3].astype(np.float32) / 255.
    if h != im.shape[0]:
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
    return im


def colorize_np(x, cmap_name='jet', mask=None, append_cbar=False):
    TINY_NUMBER = 1e-6
    if mask is not None:
        # vmin, vmax = np.percentile(x[mask], (1, 99))
        vmin = np.min(x[mask])
        vmax = np.max(x[mask])
        vmin = vmin - np.abs(vmin) * 0.01
        x[np.logical_not(mask)] = vmin
        x = np.clip(x, vmin, vmax)
    else:
        vmin = x.min()
        vmax = x.max() + TINY_NUMBER
    x = (x - vmin) / (vmax - vmin)
    # x = np.clip(x, 0., 1.)
    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]
    if mask is not None:
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.zeros_like(x_new) * (1. - mask)
    cbar = get_vertical_colorbar(h=x.shape[0], vmin=vmin, vmax=vmax, cmap_name=cmap_name)
    if append_cbar:
        x_new = np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1)
        return x_new
    else:
        return x_new, cbar


# Traditional depth2img of taichi-nerf
# def depth2img(depth):
#     depth = (depth - depth.min()) / (depth.max() - depth.min())
#     depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),
#                                   cv2.COLORMAP_TURBO)
#     return depth_img

### depth2img from EventNeRF ###
def depth2img(depth):
    im, cbar = colorize_np(depth, cmap_name='jet', append_cbar=False)
    im = to8b(im)
    return im
--------------------------------------------------------------------------------
/pretrain_ckpt/download.sh:
--------------------------------------------------------------------------------
### download pretrain files ###
--------------------------------------------------------------------------------
/train_gs.py:
--------------------------------------------------------------------------------
from lib.config import cfg, args
from lib.dataset import EventDataloader
from lib.recorder import Logger, file_backup
from lib.network import model_loss_light, model_loss, EventGaussian
from lib.renderer import pts2render, depth2pc
from lib.utils import depth2img
import numpy as np
import imageio
import cv2
import os
from pathlib import Path
from tqdm import tqdm
import logging
import torch
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

cs = cfg.cs


class Trainer:
    def __init__(self) -> None:
        device = torch.device('cuda:{}'.format(cfg.local_rank))
        self.device = device
        which_test = "val"
        self.train_loader = EventDataloader(cfg.dataset.base_folder, split="train", num_workers=1,
                                            batch_size=1, shuffle=False)

        self.val_loader = EventDataloader(cfg.dataset.base_folder, split=which_test, num_workers=1,
                                          batch_size=1, shuffle=False)

        self.len_val = len(self.val_loader)
        self.model = EventGaussian().to(self.device)
") 35 | d_warmup = False 36 | int_warmup = False 37 | if cfg.depth_warmup_ckpt is not None: 38 | self.model.depth_estimator.load_state_dict(torch.load(cfg.depth_warmup_ckpt)["network"]) 39 | d_warmup = True 40 | if cfg.intensity_warmup_ckpt is not None: 41 | self.model.intensity_estimator.load_state_dict(torch.load(cfg.intensity_warmup_ckpt)["network"]) 42 | int_warmup = True 43 | print(f" Using depth warm up {d_warmup} ; intensity warm up {int_warmup}") 44 | 45 | # self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=cfg.wdecay, eps=1e-8) 46 | dpt_params = list(map(id,self.model.depth_estimator.parameters())) + list(map(id,self.model.intensity_estimator.parameters())) 47 | rest_params = filter(lambda x:id(x) not in dpt_params,self.model.parameters()) 48 | self.optimizer = optim.Adam([ 49 | {'params':self.model.depth_estimator.parameters(), 'lr':0.00001}, 50 | {'params':self.model.intensity_estimator.parameters(), 'lr':0.00001}, 51 | {'params':rest_params, 'lr':0.0005}, 52 | ], lr=0.0005, weight_decay=cfg.wdecay, eps=1e-8) 53 | 54 | # self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, 0.001, 1000000 + 100, 55 | # pct_start=0.01, cycle_momentum=False, anneal_strategy='linear') 56 | self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=15000, gamma=0.9) 57 | 58 | self.logger = Logger(self.scheduler, cfg.record) 59 | 60 | self.total_steps = 0 61 | self.target_epoch = cfg.target_epoch 62 | 63 | if cfg.restore_ckpt: 64 | self.load_ckpt(cfg.restore_ckpt) 65 | 66 | self.model.train() 67 | 68 | def train(self): 69 | # self.model.eval() 70 | # self.run_eval() 71 | # self.model.train() 72 | for e in range(self.target_epoch): 73 | for idx, batch in enumerate(tqdm(self.train_loader)): 74 | batch = self.to_cuda(batch) 75 | 76 | ### model and loss computing ## 77 | gt = {} 78 | gt["cim"] = batch["cim"] 79 | 80 | gt["lim"], gt["rim"], gt["ldepth"], gt["rdepth"], gt["lmask"], gt["rmask"] \ 81 | = batch["lim"], batch["rim"], batch["ldepth"], batch["rdepth"], batch["lmask"], batch["rmask"] 82 | 83 | batch["left_event_tensor"] = torch.cat([batch["leframe"], batch["left_voxel"]], dim=1) 84 | batch["right_event_tensor"] = torch.cat([batch["reframe"], batch["right_voxel"]], dim=1) 85 | 86 | data = self.model(batch) 87 | 88 | data["target"] = {"H":batch["H"], 89 | "W":batch["W"], 90 | "FovX":batch["FovX"], 91 | "FovY":batch["FovY"], 92 | 'world_view_transform': batch["world_view_transform"], 93 | 'full_proj_transform': batch["full_proj_transform"], 94 | 'camera_center': batch["camera_center"]} 95 | 96 | imgL, depthL, maskL = data["lview"]["img"], data["lview"]["depth"], data["lview"]["mask"] 97 | imgR, depthR, maskR = data["rview"]["img"], data["rview"]["depth"], data["rview"]["mask"] 98 | 99 | data["lview"]["pts"] = depth2pc(data["lview"]["depth"], torch.inverse(batch["lpose"]), batch["intrinsic"]) 100 | data["rview"]["pts"] = depth2pc(data["rview"]["depth"], torch.inverse(batch["rpose"]), batch["intrinsic"]) 101 | 102 | pred = pts2render(data, [0.,0.,0.])[:,0:1] 103 | loss = F.l1_loss(pred, gt["cim"]) 104 | 105 | imgloss = torch.mean((imgL - gt["lim"])**2) + torch.mean((imgR - gt["rim"])**2) 106 | depthloss = F.l1_loss(depthL, gt["ldepth"]) + F.l1_loss(depthR, gt["rdepth"]) 107 | maskloss = F.binary_cross_entropy(maskL.reshape(-1, 1), gt["lmask"].reshape(-1, 1).float()) + \ 108 | F.binary_cross_entropy(maskR.reshape(-1, 1), gt["rmask"].reshape(-1, 1).float()) 109 | loss = loss + 0.33*imgloss + 0.33*depthloss + 0.33*maskloss 110 | # msk = 
                metrics = {
                    "l1loss": loss.item()  # total training loss logged under this key
                }

                if self.total_steps and self.total_steps % cfg.record.loss_freq == 0:
                    self.logger.writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], self.total_steps)
                    print(f"{cfg.exp_name} epoch {e} step {self.total_steps} L1loss {loss.item()} lr {self.optimizer.param_groups[0]['lr']}")
                    self.logger.push(metrics)

                if self.total_steps and self.total_steps % cfg.record.save_freq == 0:
                    self.save_ckpt(save_path=Path('%s/%d_%d.pth' % (cfg.record.ckpt_path, e, self.total_steps)), show_log=False)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                if self.total_steps and self.total_steps % cfg.record.eval_freq == 0:
                    self.model.eval()
                    self.run_eval()
                    self.model.train()

                self.total_steps += 1

        print("FINISHED TRAINING")
        self.logger.close()
        self.save_ckpt(save_path=Path('%s/%s_final.pth' % (cfg.record.ckpt_path, cfg.exp_name)))

    def to_cuda(self, batch):
        # Recursively move every tensor in the batch (including tensors nested in
        # lists, tuples and dicts) onto the training device.
        for k in batch:
            if isinstance(batch[k], tuple) or isinstance(batch[k], list):
                batch[k] = [b.to(self.device) for b in batch[k]]
            elif isinstance(batch[k], dict):
                batch[k] = self.to_cuda(batch[k])
            else:
                batch[k] = batch[k].to(self.device)
        return batch

    def run_eval(self):
        print("Doing validation ...")
        torch.cuda.empty_cache()

        l1_list = []

        # Pick one random validation sample whose prediction will be saved as an image.
        show_idx = [np.random.choice(list(range(self.len_val)), 1)]
        # show_idx = 0
        # show_idx = list(range(self.len_val))
        for idx, batch in enumerate(self.val_loader):
            with torch.no_grad():
                batch = self.to_cuda(batch)
                gt = batch["cim"]

                batch["left_event_tensor"] = torch.cat([batch["leframe"], batch["left_voxel"]], dim=1)
                batch["right_event_tensor"] = torch.cat([batch["reframe"], batch["right_voxel"]], dim=1)

                data = self.model(batch)

                data["target"] = {"H": batch["H"],
                                  "W": batch["W"],
                                  "FovX": batch["FovX"],
                                  "FovY": batch["FovY"],
                                  'world_view_transform': batch["world_view_transform"],
                                  'full_proj_transform': batch["full_proj_transform"],
                                  'camera_center': batch["camera_center"]}

                data["lview"]["pts"] = depth2pc(data["lview"]["depth"], torch.inverse(batch["lpose"]), batch["intrinsic"])
                data["rview"]["pts"] = depth2pc(data["rview"]["depth"], torch.inverse(batch["rpose"]), batch["intrinsic"])

                pred = pts2render(data, [0., 0., 0.])[:, 0]
                loss = F.l1_loss(pred.squeeze(), gt.squeeze())
                l1_list.append(loss.item())

                if idx in show_idx:
                    print("show idx is ", idx)
                    tmp_gt = (gt[0] * 255.0).cpu().numpy().astype(np.uint8).squeeze()
                    tmp_pred = (pred[0] * 255.0).cpu().numpy().astype(np.uint8).squeeze()
                    tmp_img_name = '%s/step%s_idx%d.jpg' % (cfg.record.show_path, self.total_steps, idx)
                    imageio.imsave(tmp_img_name, np.concatenate([tmp_pred, tmp_gt], axis=0))

        val_l1 = np.round(np.mean(np.array(l1_list)), 4)
        print(f"Validation Metrics ({self.total_steps}): L1 {val_l1}")
        self.logger.write_dict({'val_l1': val_l1}, write_step=self.total_steps)
        torch.cuda.empty_cache()

    def save_ckpt(self, save_path, show_log=True):
        if show_log:
            print(f"Save checkpoint to {save_path} ...")
        torch.save({
            'total_steps': self.total_steps,
            'network': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }, save_path)

    def load_ckpt(self, load_path, load_optimizer=True, strict=True):
        assert os.path.exists(load_path)
        print(f"Loading checkpoint from {load_path} ...")
        ckpt = torch.load(load_path, map_location='cuda')
        self.model.load_state_dict(ckpt['network'], strict=strict)
        print("Parameter loading done")
        if load_optimizer:
            self.total_steps = ckpt['total_steps'] + 1
            self.logger.total_steps = self.total_steps
            self.optimizer.load_state_dict(ckpt['optimizer'])
            self.scheduler.load_state_dict(ckpt['scheduler'])
            print("Optimizer loading done")


if __name__ == "__main__":
    # L = torch.randn((1, 3, 640, 480)).cuda()
    # R = torch.randn((1, 3, 640, 480)).cuda()
    # # net = Net(int(5)).cuda()
    # net = ASNet_light(int(5)).cuda()
    # out = net(L, R)
    # print(len(out))
    # Input = torch.randn((1, 5, 640, 480)).cuda()
    # fnet = FireNet({"num_bins":5}).cuda()
    # out = fnet(Input, None)
    # print(out[0].shape)

    trainer = Trainer()
    trainer.train()
--------------------------------------------------------------------------------