├── .gitignore
├── Figures
│   └── network.png
├── README.md
├── configs
│   ├── Ev3D_pretrain.yaml
│   └── default.yaml
├── environment.yml
├── eval_gs.py
├── lib
│   ├── __pycache__
│   │   ├── losses.cpython-37.pyc
│   │   ├── recorder.cpython-310.pyc
│   │   ├── recorder.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── config
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── config.cpython-310.pyc
│   │   │   ├── config.cpython-37.pyc
│   │   │   ├── config.cpython-39.pyc
│   │   │   ├── yacs.cpython-310.pyc
│   │   │   ├── yacs.cpython-37.pyc
│   │   │   └── yacs.cpython-39.pyc
│   │   ├── config.py
│   │   └── yacs.py
│   ├── dataset
│   │   ├── Ev3D.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── Ev3D.cpython-310.pyc
│   │   │   ├── Ev3D.cpython-37.pyc
│   │   │   ├── Ev3D.cpython-39.pyc
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── utils.cpython-37.pyc
│   │   │   └── utils.cpython-39.pyc
│   │   └── utils.py
│   ├── losses.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── ASNet.cpython-37.pyc
│   │   │   ├── ASNet_utils.cpython-37.pyc
│   │   │   ├── SegNet.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── asnet.cpython-37.pyc
│   │   │   ├── asnet_utils.cpython-37.pyc
│   │   │   ├── densenet.cpython-37.pyc
│   │   │   ├── dfanet.cpython-37.pyc
│   │   │   ├── eventgaussian.cpython-37.pyc
│   │   │   ├── firenet.cpython-37.pyc
│   │   │   ├── gsregressor.cpython-37.pyc
│   │   │   ├── mobilenetv2.cpython-37.pyc
│   │   │   ├── pspnet.cpython-37.pyc
│   │   │   ├── recon_net.cpython-37.pyc
│   │   │   ├── resnet.cpython-37.pyc
│   │   │   ├── submodules.cpython-37.pyc
│   │   │   ├── swin.cpython-37.pyc
│   │   │   └── unet.cpython-37.pyc
│   │   ├── asnet.py
│   │   ├── asnet_utils.py
│   │   ├── eventgaussian.py
│   │   ├── firenet.py
│   │   ├── gsregressor.py
│   │   ├── neurons.py
│   │   ├── recon_net.py
│   │   ├── resnet.py
│   │   ├── snn.py
│   │   ├── submodules.py
│   │   ├── swin.py
│   │   └── unet.py
│   ├── recorder.py
│   ├── renderer
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── gaussian_render.cpython-37.pyc
│   │   │   ├── gaussian_render.cpython-39.pyc
│   │   │   └── rend_utils.cpython-37.pyc
│   │   ├── gaussian_render.py
│   │   └── rend_utils.py
│   └── utils.py
├── pretrain_ckpt
│   └── download.sh
└── train_gs.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ./experiments/
2 | pretrain_ckpt/*.pth
3 | .history/
4 |
--------------------------------------------------------------------------------
/Figures/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/Figures/network.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EvGGS: A Collaborative Learning Framework for Event-based Generalizable Gaussian Splatting
2 | [A Collaborative Learning Framework for Event-based Generalizable Gaussian Splatting](https://arxiv.org/abs/2405.14959v1)
3 |
4 | Jiaxu Wang, Junhao He, Ziyi Zhang, Mingyuan Sun, Jingkai Sun, Renjing Xu*
5 |
6 |
7 | 
8 | Fig 1. The main pipeline overview of the proposed EvGGS framework.
9 |
10 |
11 | # Create environment
12 | ```
13 | conda env create --file environment.yml
14 | conda activate evggs
15 | ```
16 | Then, compile the diff-gaussian-rasterization submodule from the 3DGS repository:
17 | ```
18 | git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive
19 | cd gaussian-splatting/
20 | pip install -e submodules/diff-gaussian-rasterization
21 | cd ..
22 | ```
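If the build succeeds, the rasterizer should import cleanly inside the ```evggs``` environment. A minimal optional sanity check (assuming a visible CUDA GPU):
```
import torch
from diff_gaussian_rasterization import GaussianRasterizer  # the extension compiled above

print(torch.version.cuda, torch.cuda.is_available())
```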
23 | # Download models
24 | Download the pretrained models from [OneDrive](https://hkustgz-my.sharepoint.com/:u:/g/personal/jwang457_connect_hkust-gz_edu_cn/ESAMKY3oHDRBr2-zeNb3L8IBKnFGiJCAgyRv3HBs6esFaQ?e=O7bili) and place them in ```./pretrain_ckpt```. This directory includes two warmup ckpts and one ckpt pretrained on the synthetic dataset.
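Given the checkpoint paths referenced in ```configs/Ev3D_pretrain.yaml```, the directory is expected to look like this after downloading:
```
pretrain_ckpt/
├── download.sh
├── depth_warmup.pth
├── intensity_warmup.pth
└── pretrain_evggs.pth
```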
25 |
26 | # Running the code
27 |
28 | ## Download dataset
29 |
30 | - Ev3D-S
31 |
32 | A large-scale synthetic Event-based dataset with varying textures and materials, accompanied by well-calibrated frames, depth maps, and ground truth.
33 |
34 | You can download the dataset from [OneDrive](https://hkustgz-my.sharepoint.com/:u:/g/personal/jwang457_connect_hkust-gz_edu_cn/EYszUyxQnzRMkC0u5GxDOvEB_NhmBaVe2vBnpMH2ctSWxA?e=kJDwRz) and unzip it. About 50 GB of storage space is required.
35 |
36 |
37 | - Ev3D-R
38 |
39 | A large-scale realistic Event-based 3D dataset containing various objects, captured by a real DVXplorer event camera.
40 |
41 | Due to licensing restrictions, access to this dataset currently requires a private application to us.
42 |
43 | ## Training
44 |
45 | ```
46 | python train_gs.py
47 | ```
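Outputs are collected under ```experiments/``` as configured in ```lib/config/config.py```; each run creates a datestamped folder (subfolder roles inferred from the code):
```
experiments/<exp_name>_<MMDD>/
├── ckpt   # checkpoints (every record.save_freq steps)
├── show   # rendered prediction vs. ground-truth images
├── logs   # training logs
└── file   # backup of source files
```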
48 |
49 | ## Evaluation
50 |
51 | ```
52 | python eval_gs.py
53 | ```
54 |
55 | Several primary settings, such as the experiment name and the path to your dataset, are defined in ```configs/Ev3D_pretrain.yaml```; please check them before running.
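Since ```lib/config/config.py``` merges trailing key/value pairs from the command line via yacs (```merge_from_list```), these settings can also be overridden without editing the file, for example (the paths shown are illustrative):
```
python eval_gs.py --cfg_file configs/Ev3D_pretrain.yaml exp_name my_eval dataset.base_folder /path/to/Ev3D-S
```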
56 |
57 | # Citation
58 |
59 | Please cite our work if you use this dataset.
60 |
61 | ```
62 | @misc{wang2024evggscollaborativelearningframework,
63 | title={EvGGS: A Collaborative Learning Framework for Event-based Generalizable Gaussian Splatting},
64 | author={Jiaxu Wang and Junhao He and Ziyi Zhang and Mingyuan Sun and Jingkai Sun and Renjing Xu},
65 | year={2024},
66 | eprint={2405.14959},
67 | archivePrefix={arXiv},
68 | primaryClass={cs.CV},
69 | url={https://arxiv.org/abs/2405.14959},
70 | }
71 | ```
72 |
73 | # Reference
74 |
75 | EventNeRF: [https://github.com/r00tman/EventNeRF?tab=readme-ov-file](https://github.com/r00tman/EventNeRF?tab=readme-ov-file).
76 | 3D Gaussian Splatting: [https://github.com/graphdeco-inria/gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting).
77 | GPS-GS: [https://github.com/aipixel/GPS-Gaussian](https://github.com/aipixel/GPS-Gaussian).
78 | PAEv3d: [https://github.com/Mercerai/PAEv3d](https://github.com/Mercerai/PAEv3d).
79 |
--------------------------------------------------------------------------------
/configs/Ev3D_pretrain.yaml:
--------------------------------------------------------------------------------
1 | exp_name: experimental_name
2 | lr: 0.0005
3 | wdecay: 1e-5
4 | target_epoch: 100
5 | num_steps: 1000000
6 | cs: [0, 480, 64, 576]  # crop window [row_start, row_end, col_start, col_end] on the 480x640 frames (-> 480x512)
7 |
8 | dataset:
9 | base_folder:
10 | train_batch_size: 1
11 | val_batch_size: 1
12 | ratio: 0.75
13 |
14 | model:
15 | max_depth_plane: 64
16 | max_depth_value: 0.8
17 | num_bins: 5
18 |
19 | record:
20 | loss_freq: 50
21 | eval_freq: 5000
22 | save_freq: 10000
23 |
24 | restore_ckpt: None  # the literal "None" is decoded to Python None by lib/config/yacs.py; set a path to resume
25 | depth_warmup_ckpt: "./pretrain_ckpt/depth_warmup.pth"
26 | intensity_warmup_ckpt: "./pretrain_ckpt/intensity_warmup.pth"
27 | pretrain_ckpt: "./pretrain_ckpt/pretrain_evggs.pth"
28 |
29 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/configs/default.yaml
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: evggs
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/cloud/conda-forge/
6 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/pkgs/free/
7 | - https://mirrors.sjtug.sjtu.edu.cn/anaconda/pkgs/main/
8 | - defaults
9 | dependencies:
10 | - _libgcc_mutex=0.1=conda_forge
11 | - _openmp_mutex=4.5=2_kmp_llvm
12 | - blas=1.0=mkl
13 | - brotli-python=1.0.9=py37hd23a5d3_7
14 | - bzip2=1.0.8=hd590300_5
15 | - ca-certificates=2023.11.17=hbcca054_0
16 | - certifi=2023.11.17=pyhd8ed1ab_0
17 | - charset-normalizer=3.3.2=pyhd8ed1ab_0
18 | - colorama=0.4.6=pyhd8ed1ab_0
19 | - cudatoolkit=11.6.2=hfc3e2af_12
20 | - ffmpeg=4.3=hf484d3e_0
21 | - freetype=2.12.1=h267a509_2
22 | - gmp=6.3.0=h59595ed_0
23 | - gnutls=3.6.13=h85f3911_1
24 | - icu=73.2=h59595ed_0
25 | - idna=3.6=pyhd8ed1ab_0
26 | - jpeg=9e=h0b41bf4_3
27 | - lame=3.100=h166bdaf_1003
28 | - lcms2=2.14=h6ed2654_0
29 | - ld_impl_linux-64=2.40=h41732ed_0
30 | - lerc=4.0.0=h27087fc_0
31 | - libdeflate=1.14=h166bdaf_0
32 | - libffi=3.3=h58526e2_2
33 | - libgcc-ng=13.2.0=h807b86a_3
34 | - libhwloc=2.9.3=default_h554bfaf_1009
35 | - libiconv=1.17=hd590300_1
36 | - libpng=1.6.39=h753d276_0
37 | - libsqlite=3.44.2=h2797004_0
38 | - libstdcxx-ng=13.2.0=h7e041cc_3
39 | - libtiff=4.4.0=h82bc61c_5
40 | - libwebp-base=1.3.2=hd590300_0
41 | - libxcb=1.13=h7f98852_1004
42 | - libxml2=2.11.6=h232c23b_0
43 | - libzlib=1.2.13=hd590300_5
44 | - llvm-openmp=17.0.6=h4dfa4b3_0
45 | - mkl=2021.4.0=h8d4b97c_729
46 | - mkl-service=2.4.0=py37h402132d_0
47 | - mkl_fft=1.3.1=py37h3e078e5_1
48 | - mkl_random=1.2.2=py37h219a48f_0
49 | - ncurses=6.4=h59595ed_2
50 | - nettle=3.6=he412f7d_0
51 | - numpy=1.21.5=py37h6c91a56_3
52 | - numpy-base=1.21.5=py37ha15fc14_3
53 | - openh264=2.1.1=h780b84a_0
54 | - openjpeg=2.5.0=h7d73246_1
55 | - openssl=1.1.1w=hd590300_0
56 | - pip=22.3.1=pyhd8ed1ab_0
57 | - plyfile=0.8.1=pyhd8ed1ab_0
58 | - pthread-stubs=0.4=h36c2ea0_1001
59 | - pysocks=1.7.1=py37h89c1867_5
60 | - python=3.7.13=haa1d7c7_1
61 | - python_abi=3.7=2_cp37m
62 | - pytorch=1.12.1=py3.7_cuda11.6_cudnn8.3.2_0
63 | - pytorch-mutex=1.0=cuda
64 | - readline=8.2=h8228510_1
65 | - requests=2.31.0=pyhd8ed1ab_0
66 | - setuptools=68.2.2=pyhd8ed1ab_0
67 | - six=1.16.0=pyh6c4a22f_0
68 | - sqlite=3.44.2=h2c6b66d_0
69 | - tbb=2021.11.0=h00ab1b0_0
70 | - tk=8.6.13=noxft_h4845f30_101
71 | - torchaudio=0.12.1=py37_cu116
72 | - torchvision=0.13.1=py37_cu116
73 | - tqdm=4.66.1=pyhd8ed1ab_0
74 | - typing_extensions=4.7.1=pyha770c72_0
75 | - urllib3=2.1.0=pyhd8ed1ab_0
76 | - wheel=0.42.0=pyhd8ed1ab_0
77 | - xorg-libxau=1.0.11=hd590300_0
78 | - xorg-libxdmcp=1.1.3=h7f98852_0
79 | - xz=5.2.6=h166bdaf_0
80 | - zlib=1.2.13=hd590300_5
81 | - zstd=1.5.5=hfc55251_0
82 | - pip:
83 | - absl-py==2.0.0
84 | - addict==2.4.0
85 | - ansi2html==1.9.1
86 | - attrs==23.2.0
87 | - backcall==0.2.0
88 | - cachetools==5.3.2
89 | - click==8.1.7
90 | - comm==0.1.4
91 | - configargparse==1.7
92 | - cycler==0.11.0
93 | - dash==2.14.2
94 | - dash-core-components==2.0.0
95 | - dash-html-components==2.0.0
96 | - dash-table==5.0.0
97 | - decorator==5.1.1
98 | - diff-gaussian-rasterization==0.0.0
99 | - fastjsonschema==2.19.1
100 | - flask==2.2.5
101 | - fonttools==4.38.0
102 | - google-auth==2.25.2
103 | - google-auth-oauthlib==0.4.6
104 | - grpcio==1.60.0
105 | - h5py==3.8.0
106 | - imageio==2.31.2
107 | - importlib-metadata==6.7.0
108 | - importlib-resources==5.12.0
109 | - ipython==7.34.0
110 | - ipywidgets==8.1.1
111 | - itsdangerous==2.1.2
112 | - jedi==0.19.1
113 | - jinja2==3.1.2
114 | - joblib==1.3.2
115 | - jsonschema==4.17.3
116 | - jupyter-core==4.12.0
117 | - jupyterlab-widgets==3.0.9
118 | - kiwisolver==1.4.5
119 | - lpips==0.1.4
120 | - markdown==3.4.4
121 | - markupsafe==2.1.3
122 | - matplotlib==3.5.3
123 | - matplotlib-inline==0.1.6
124 | - natsort==8.4.0
125 | - nbformat==5.7.0
126 | - nest-asyncio==1.5.8
127 | - oauthlib==3.2.2
128 | - open3d==0.17.0
129 | - opencv-python==4.9.0.80
130 | - packaging==23.2
131 | - pandas==1.3.5
132 | - parso==0.8.3
133 | - pexpect==4.9.0
134 | - pickleshare==0.7.5
135 | - pillow==9.5.0
136 | - pkgutil-resolve-name==1.3.10
137 | - plotly==5.18.0
138 | - prompt-toolkit==3.0.43
139 | - protobuf==3.20.3
140 | - ptyprocess==0.7.0
141 | - pyasn1==0.5.1
142 | - pyasn1-modules==0.3.0
143 | - pygments==2.17.2
144 | - pyparsing==3.1.1
145 | - pyquaternion==0.9.9
146 | - pyrsistent==0.19.3
147 | - python-dateutil==2.8.2
148 | - pytz==2023.3.post1
149 | - pyyaml==6.0.1
150 | - requests-oauthlib==1.3.1
151 | - retrying==1.3.4
152 | - rsa==4.9
153 | - scikit-learn==1.0.2
154 | - scipy==1.7.3
155 | - simple-knn==0.0.0
156 | - tenacity==8.2.3
157 | - tensorboard==2.11.2
158 | - tensorboard-data-server==0.6.1
159 | - tensorboard-plugin-wit==1.8.1
160 | - threadpoolctl==3.1.0
161 | - traitlets==5.9.0
162 | - wcwidth==0.2.12
163 | - werkzeug==2.2.3
164 | - widgetsnbextension==4.0.9
165 | - yacs==0.1.8
166 | - zipp==3.15.0
167 |
--------------------------------------------------------------------------------
/eval_gs.py:
--------------------------------------------------------------------------------
1 | from lib.config import cfg, args
2 | from lib.dataset import EventDataloader
3 | from lib.recorder import Logger, file_backup
4 | from lib.network import model_loss_light, model_loss, EventGaussian
5 | from lib.renderer import pts2render, depth2pc
6 | from lib.utils import depth2img
7 | from lib.losses import psnr, ssim
8 | import numpy as np
9 | import imageio
10 | import cv2
11 | import os
12 | from pathlib import Path
13 | from tqdm import tqdm
14 | import logging
15 | import torch
16 | from torch import optim
17 | from torch.utils.data import DataLoader
18 | import torch.nn.functional as F
19 | import lpips
20 | import time
21 |
22 | cs = cfg.cs
23 |
24 | class Trainer:
25 | def __init__(self) -> None:
26 | device = torch.device('cuda:{}'.format(cfg.local_rank))
27 | self.device = device
28 | which_test = "val"
29 |
30 | self.train_loader = None
31 | self.val_loader = EventDataloader(cfg.dataset.base_folder, split=which_test, num_workers=1,\
32 | batch_size=1, shuffle=False)
33 |
34 | self.len_val = len(self.val_loader)
35 | self.model = EventGaussian().to(self.device)
36 | # self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=cfg.wdecay, eps=1e-8)
37 | dpt_params = list(map(id,self.model.depth_estimator.parameters())) + list(map(id,self.model.intensity_estimator.parameters()))
38 | rest_params = filter(lambda x:id(x) not in dpt_params,self.model.parameters())
39 |         self.optimizer = optim.Adam([  # built only so the scheduler/Logger can be constructed; never stepped in this eval script
40 | {'params':self.model.depth_estimator.parameters(), 'lr':1},
41 | {'params':self.model.intensity_estimator.parameters(), 'lr':1},
42 | {'params':rest_params, 'lr':1},
43 | ], lr=0.001, weight_decay=cfg.wdecay, eps=1e-8)
44 | self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=10000, gamma=0.9)
45 | self.logger = Logger(self.scheduler, cfg.record)
46 |
47 | self.total_steps = 0
48 | self.target_epoch = cfg.target_epoch
49 |
50 | if cfg.restore_ckpt:
51 | self.load_ckpt(cfg.restore_ckpt)
52 |
53 | def to_cuda(self, batch):
54 | for k in batch:
55 | if isinstance(batch[k], tuple) or isinstance(batch[k], list):
56 | batch[k] = [b.to(self.device) for b in batch[k]]
57 | elif isinstance(batch[k], dict):
58 |                 batch[k] = self.to_cuda(batch[k])  # recurse into nested dicts
59 | else:
60 | batch[k] = batch[k].to(self.device)
61 | return batch
62 |
63 | def run_eval(self):
64 | print(f"Doing validation ...")
65 | torch.cuda.empty_cache()
66 | loss_fn_vgg = lpips.LPIPS(net='vgg').to(self.device)
67 | l1_list = []
68 | psnr_list = []
69 | ssim_list = []
70 | lpips_list = []
71 | show_idx = list(range(self.len_val))
72 | count = 0
73 | scene_num = 0
74 | os.makedirs(r'%s/%d/' % (cfg.record.show_path, scene_num), exist_ok=True)
75 | for idx, batch in enumerate(tqdm(self.val_loader)):
76 |             if count == 201:  # assumes each scene contributes 201 views; start a new output folder
77 | scene_num += 1
78 | count = 0
79 | os.makedirs(r'%s/%d/' % (cfg.record.show_path, scene_num), exist_ok=True)
80 | with torch.no_grad():
81 | batch = self.to_cuda(batch)
82 | gt = batch["cim"]
83 |
84 | batch["left_event_tensor"] = torch.cat([batch["leframe"], batch["left_voxel"]], dim=1)
85 | batch["right_event_tensor"] = torch.cat([batch["reframe"], batch["right_voxel"]], dim=1)
86 |
87 | start_time = time.time()
88 | data = self.model(batch)
89 |
90 | data["target"] = {"H":batch["H"],
91 | "W":batch["W"],
92 | "FovX":batch["FovX"],
93 | "FovY":batch["FovY"],
94 | 'world_view_transform': batch["world_view_transform"],
95 | 'full_proj_transform': batch["full_proj_transform"],
96 | 'camera_center': batch["camera_center"]}
97 |
98 | data["lview"]["pts"] = depth2pc(data["lview"]["depth"], torch.inverse(batch["lpose"]), batch["intrinsic"])
99 | data["rview"]["pts"] = depth2pc(data["rview"]["depth"], torch.inverse(batch["rpose"]), batch["intrinsic"])
100 |
101 | pred = pts2render(data, [0.,0.,0.])[:,0]
102 | end_time = time.time()
103 | execution_time = end_time - start_time
104 |
105 | pred = pred[:,None]
106 | loss = F.l1_loss(pred.squeeze(), gt.squeeze())
107 | l1_list.append(loss.item())
108 |
109 | count += 1
110 | # if idx == show_idx:
111 | psnr_list.append(torch.mean(psnr(pred, gt)).item())
112 | ssim_list.append(ssim(pred, gt).item())
113 | lpips_list.append(torch.mean(loss_fn_vgg(pred*2-1, gt*2-1)).item())
114 | if idx in show_idx:
115 | tmp_gt = (gt[0]*255.0).cpu().numpy().astype(np.uint8).squeeze()
116 | tmp_pred = (pred[0]*255.0).cpu().numpy().astype(np.uint8).squeeze()
117 | tmp_img_name = '%s/%d/step%s_idx%d.jpg' % (cfg.record.show_path, scene_num, self.total_steps, idx)
118 | imageio.imsave(tmp_img_name, np.concatenate([tmp_pred, tmp_gt], axis=0))
119 |
120 | val_psnr = np.round(np.mean(np.array(psnr_list)), 8)
121 | val_ssim = np.round(np.mean(np.array(ssim_list)), 8)
122 | val_lpips = np.round(np.mean(np.array(lpips_list)), 8)
123 |         print(f"Non-masked metrics on selected views (step {self.total_steps}): psnr {val_psnr}, ssim {val_ssim}, lpips {val_lpips}")
124 | self.logger.write_dict({'NO masked psnr on val set': val_psnr}, write_step=self.total_steps)
125 | torch.cuda.empty_cache()
126 |
127 | def save_ckpt(self, save_path, show_log=True):
128 | if show_log:
129 | print(f"Save checkpoint to {save_path} ...")
130 |
131 | torch.save({
132 | 'total_steps': self.total_steps,
133 | 'network': self.model.state_dict(),
134 | 'optimizer': self.optimizer.state_dict(),
135 | 'scheduler': self.scheduler.state_dict()
136 | }, save_path)
137 |
138 | def load_ckpt(self, load_path, load_optimizer=True, strict=True):
139 | assert os.path.exists(load_path)
140 | print(f"Loading checkpoint from {load_path} ...")
141 | ckpt = torch.load(load_path, map_location='cuda')
142 |
143 | self.model.load_state_dict(ckpt['network'], strict=strict)
144 | print(f"Parameter loading done")
145 | if load_optimizer:
146 | self.total_steps = ckpt['total_steps'] + 1
147 | self.logger.total_steps = self.total_steps
148 | self.optimizer.load_state_dict(ckpt['optimizer'])
149 | self.scheduler.load_state_dict(ckpt['scheduler'])
150 | print(f"Optimizer loading done")
151 |
152 | if __name__ == "__main__":
153 | trainer = Trainer()
154 | trainer.load_ckpt(cfg.pretrain_ckpt, load_optimizer=False)
155 | trainer.model.eval()
156 | trainer.run_eval()
157 |
158 |
159 |
160 |
--------------------------------------------------------------------------------
/lib/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/recorder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/__pycache__/recorder.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/recorder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/__pycache__/recorder.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import cfg, args
--------------------------------------------------------------------------------
/lib/config/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/config.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/config.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/yacs.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/yacs.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/yacs.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/yacs.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/yacs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/config/__pycache__/yacs.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/config/config.py:
--------------------------------------------------------------------------------
1 | from .yacs import CfgNode as CN
2 | import argparse
3 | import os
4 | import numpy as np
5 | from . import yacs
6 | from datetime import datetime
7 | from pathlib import Path
8 |
9 | cfg = CN()
10 | cfg.task = 'hello'
11 | cfg.gpus = [0]
12 | cfg.exp_name = 'depth_pred'
13 |
14 | cfg.record = CN()
15 |
16 | def parse_cfg(cfg, args):
17 | if len(cfg.task) == 0:
18 | raise ValueError('task must be specified')
19 |
20 | # assign the gpus
21 | # if -1 not in cfg.gpus:
22 | # os.environ['CUDA_VISIBLE_DEVICES'] = ', '.join([str(gpu) for gpu in cfg.gpus])
23 |
24 | if 'bbox' in cfg:
25 | bbox = np.array(cfg.bbox).reshape((2, 3))
26 | center, half_size = np.mean(bbox, axis=0), (bbox[1]-bbox[0]).max().item() / 2.
27 | bbox = np.stack([center-half_size, center+half_size])
28 | cfg.bbox = bbox.reshape(6).tolist()
29 |
30 | print('EXP NAME: ', cfg.exp_name)
31 |
32 | cfg.local_rank = args.local_rank
33 |
34 | modules = [key for key in cfg if '_module' in key]
35 | for module in modules:
36 | cfg[module.replace('_module', '_path')] = cfg[module].replace('.', '/') + '.py'
37 |
38 | def make_cfg(args):
39 | def merge_cfg(cfg_file, cfg):
40 | with open(cfg_file, 'r') as f:
41 | current_cfg = yacs.load_cfg(f)
42 | if 'parent_cfg' in current_cfg.keys():
43 | cfg = merge_cfg(current_cfg.parent_cfg, cfg)
44 | cfg.merge_from_other_cfg(current_cfg)
45 | else:
46 | cfg.merge_from_other_cfg(current_cfg)
47 | print(cfg_file)
48 | return cfg
49 | cfg_ = merge_cfg(args.cfg_file, cfg)
50 | try:
51 | index = args.opts.index('other_opts')
52 | cfg_.merge_from_list(args.opts[:index])
53 |     except ValueError:  # no 'other_opts' marker in the override list
54 | cfg_.merge_from_list(args.opts)
55 | parse_cfg(cfg_, args)
56 | return cfg_
57 |
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument("--cfg_file", default="configs/Ev3D_pretrain.yaml", type=str)
60 | parser.add_argument('--local_rank', type=int, default=0)
61 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)
62 | args = parser.parse_args()
63 |
64 | cfg = make_cfg(args)
65 |
66 | dt = datetime.today()
67 | cfg.exp_name = '%s_%s%s' % (cfg.exp_name, str(dt.month).zfill(2), str(dt.day).zfill(2))
68 | cfg.record.ckpt_path = "experiments/%s/ckpt" % cfg.exp_name
69 | cfg.record.show_path = "experiments/%s/show" % cfg.exp_name
70 | cfg.record.logs_path = "experiments/%s/logs" % cfg.exp_name
71 | cfg.record.file_path = "experiments/%s/file" % cfg.exp_name
72 |
73 | for path in [cfg.record.ckpt_path, cfg.record.show_path, cfg.record.logs_path, cfg.record.file_path]:
74 | Path(path).mkdir(exist_ok=True, parents=True)
--------------------------------------------------------------------------------
/lib/config/yacs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | """YACS -- Yet Another Configuration System is designed to be a simple
17 | configuration management system for academic and industrial research
18 | projects.
19 |
20 | See README.md for usage and examples.
21 | """
22 |
23 | import copy
24 | import io
25 | import logging
26 | import os
27 | from ast import literal_eval
28 |
29 | import yaml
30 |
31 |
32 | # Flag for py2 and py3 compatibility to use when separate code paths are necessary
33 | # When _PY2 is False, we assume Python 3 is in use
34 | _PY2 = False
35 |
36 | # Filename extensions for loading configs from files
37 | _YAML_EXTS = {"", ".yaml", ".yml"}
38 | _PY_EXTS = {".py"}
39 |
40 | # py2 and py3 compatibility for checking file object type
41 | # We simply use this to infer py2 vs py3
42 | try:
43 | _FILE_TYPES = (file, io.IOBase)
44 | _PY2 = True
45 | except NameError:
46 | _FILE_TYPES = (io.IOBase,)
47 |
48 | # CfgNodes can only contain a limited set of valid types
49 | _VALID_TYPES = {tuple, list, str, int, float, bool}
50 | # py2 allow for str and unicode
51 | if _PY2:
52 | _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821
53 |
54 | # Utilities for importing modules from file paths
55 | if _PY2:
56 | # imp is available in both py2 and py3 for now, but is deprecated in py3
57 | import imp
58 | else:
59 | import importlib.util
60 |
61 | logger = logging.getLogger(__name__)
62 |
63 |
64 | class CfgNode(dict):
65 | """
66 | CfgNode represents an internal node in the configuration tree. It's a simple
67 | dict-like container that allows for attribute-based access to keys.
68 | """
69 |
70 | IMMUTABLE = "__immutable__"
71 | DEPRECATED_KEYS = "__deprecated_keys__"
72 | RENAMED_KEYS = "__renamed_keys__"
73 |
74 | def __init__(self, init_dict=None, key_list=None):
75 | # Recursively convert nested dictionaries in init_dict into CfgNodes
76 | init_dict = {} if init_dict is None else init_dict
77 | key_list = [] if key_list is None else key_list
78 | for k, v in init_dict.items():
79 | if type(v) is dict:
80 | # Convert dict to CfgNode
81 | init_dict[k] = CfgNode(v, key_list=key_list + [k])
82 | else:
83 | # Check for valid leaf type or nested CfgNode
84 | _assert_with_logging(
85 | _valid_type(v, allow_cfg_node=True),
86 | "Key {} with value {} is not a valid type; valid types: {}".format(
87 | ".".join(key_list + [k]), type(v), _VALID_TYPES
88 | ),
89 | )
90 | super(CfgNode, self).__init__(init_dict)
91 | # Manage if the CfgNode is frozen or not
92 | self.__dict__[CfgNode.IMMUTABLE] = False
93 | # Deprecated options
94 | # If an option is removed from the code and you don't want to break existing
95 | # yaml configs, you can add the full config key as a string to the set below.
96 | self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
97 | # Renamed options
98 | # If you rename a config option, record the mapping from the old name to the new
99 | # name in the dictionary below. Optionally, if the type also changed, you can
100 | # make the value a tuple that specifies first the renamed key and then
101 | # instructions for how to edit the config file.
102 | self.__dict__[CfgNode.RENAMED_KEYS] = {
103 | # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow
104 | # 'EXAMPLE.OLD.KEY': ( # A more complex example to follow
105 | # 'EXAMPLE.NEW.KEY',
106 | # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or "
107 | # + "'foo:bar' -> ('foo', 'bar')"
108 | # ),
109 | }
110 |
111 | def __getattr__(self, name):
112 | if name in self:
113 | return self[name]
114 | else:
115 | raise AttributeError(name)
116 |
117 | def __setattr__(self, name, value):
118 | if self.is_frozen():
119 | raise AttributeError(
120 | "Attempted to set {} to {}, but CfgNode is immutable".format(
121 | name, value
122 | )
123 | )
124 |
125 | _assert_with_logging(
126 | name not in self.__dict__,
127 | "Invalid attempt to modify internal CfgNode state: {}".format(name),
128 | )
129 | _assert_with_logging(
130 | _valid_type(value, allow_cfg_node=True),
131 | "Invalid type {} for key {}; valid types = {}".format(
132 | type(value), name, _VALID_TYPES
133 | ),
134 | )
135 |
136 | self[name] = value
137 |
138 | def __str__(self):
139 | def _indent(s_, num_spaces):
140 | s = s_.split("\n")
141 | if len(s) == 1:
142 | return s_
143 | first = s.pop(0)
144 | s = [(num_spaces * " ") + line for line in s]
145 | s = "\n".join(s)
146 | s = first + "\n" + s
147 | return s
148 |
149 | r = ""
150 | s = []
151 | for k, v in sorted(self.items()):
152 |             separator = "\n" if isinstance(v, CfgNode) else " "
153 |             attr_str = "{}:{}{}".format(str(k), separator, str(v))
154 | attr_str = _indent(attr_str, 2)
155 | s.append(attr_str)
156 | r += "\n".join(s)
157 | return r
158 |
159 | def __repr__(self):
160 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
161 |
162 | def dump(self):
163 | """Dump to a string."""
164 | self_as_dict = _to_dict(self)
165 | return yaml.safe_dump(self_as_dict)
166 |
167 | def merge_from_file(self, cfg_filename):
168 | """Load a yaml config file and merge it this CfgNode."""
169 | with open(cfg_filename, "r") as f:
170 | cfg = load_cfg(f)
171 | self.merge_from_other_cfg(cfg)
172 |
173 | def merge_from_other_cfg(self, cfg_other):
174 | """Merge `cfg_other` into this CfgNode."""
175 | _merge_a_into_b(cfg_other, self, self, [])
176 |
177 | def merge_from_list(self, cfg_list):
178 | """Merge config (keys, values) in a list (e.g., from command line) into
179 | this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
180 | """
181 | _assert_with_logging(
182 | len(cfg_list) % 2 == 0,
183 | "Override list has odd length: {}; it must be a list of pairs".format(
184 | cfg_list
185 | ),
186 | )
187 | root = self
188 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
189 | if root.key_is_deprecated(full_key):
190 | continue
191 | if root.key_is_renamed(full_key):
192 | root.raise_key_rename_error(full_key)
193 | key_list = full_key.split(".")
194 | d = self
195 | for subkey in key_list[:-1]:
196 | _assert_with_logging(
197 | subkey in d, "Non-existent key: {}".format(full_key)
198 | )
199 | d = d[subkey]
200 | subkey = key_list[-1]
201 | _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
202 | value = _decode_cfg_value(v)
203 | value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
204 | d[subkey] = value
205 |
206 | def freeze(self):
207 | """Make this CfgNode and all of its children immutable."""
208 | self._immutable(True)
209 |
210 | def defrost(self):
211 | """Make this CfgNode and all of its children mutable."""
212 | self._immutable(False)
213 |
214 | def is_frozen(self):
215 | """Return mutability."""
216 | return self.__dict__[CfgNode.IMMUTABLE]
217 |
218 | def _immutable(self, is_immutable):
219 | """Set immutability to is_immutable and recursively apply the setting
220 | to all nested CfgNodes.
221 | """
222 | self.__dict__[CfgNode.IMMUTABLE] = is_immutable
223 | # Recursively set immutable state
224 | for v in self.__dict__.values():
225 | if isinstance(v, CfgNode):
226 | v._immutable(is_immutable)
227 | for v in self.values():
228 | if isinstance(v, CfgNode):
229 | v._immutable(is_immutable)
230 |
231 | def clone(self):
232 | """Recursively copy this CfgNode."""
233 | return copy.deepcopy(self)
234 |
235 | def register_deprecated_key(self, key):
236 | """Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated
237 | keys a warning is generated and the key is ignored.
238 | """
239 | _assert_with_logging(
240 | key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
241 | "key {} is already registered as a deprecated key".format(key),
242 | )
243 | self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)
244 |
245 | def register_renamed_key(self, old_name, new_name, message=None):
246 | """Register a key as having been renamed from `old_name` to `new_name`.
247 | When merging a renamed key, an exception is thrown alerting to user to
248 | the fact that the key has been renamed.
249 | """
250 | _assert_with_logging(
251 | old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
252 | "key {} is already registered as a renamed cfg key".format(old_name),
253 | )
254 | value = new_name
255 | if message:
256 | value = (new_name, message)
257 | self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value
258 |
259 | def key_is_deprecated(self, full_key):
260 | """Test if a key is deprecated."""
261 | if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
262 | logger.warning("Deprecated config key (ignoring): {}".format(full_key))
263 | return True
264 | return False
265 |
266 | def key_is_renamed(self, full_key):
267 | """Test if a key is renamed."""
268 | return full_key in self.__dict__[CfgNode.RENAMED_KEYS]
269 |
270 | def raise_key_rename_error(self, full_key):
271 | new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
272 | if isinstance(new_key, tuple):
273 | msg = " Note: " + new_key[1]
274 | new_key = new_key[0]
275 | else:
276 | msg = ""
277 | raise KeyError(
278 | "Key {} was renamed to {}; please update your config.{}".format(
279 | full_key, new_key, msg
280 | )
281 | )
282 |
283 |
284 | def load_cfg(cfg_file_obj_or_str):
285 | """Load a cfg. Supports loading from:
286 | - A file object backed by a YAML file
287 | - A file object backed by a Python source file that exports an attribute
288 | "cfg" that is either a dict or a CfgNode
289 | - A string that can be parsed as valid YAML
290 | """
291 | _assert_with_logging(
292 | isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
293 | "Expected first argument to be of type {} or {}, but it was {}".format(
294 | _FILE_TYPES, str, type(cfg_file_obj_or_str)
295 | ),
296 | )
297 | if isinstance(cfg_file_obj_or_str, str):
298 | return _load_cfg_from_yaml_str(cfg_file_obj_or_str)
299 | elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
300 | return _load_cfg_from_file(cfg_file_obj_or_str)
301 | else:
302 | raise NotImplementedError("Impossible to reach here (unless there's a bug)")
303 |
304 |
305 | def _load_cfg_from_file(file_obj):
306 | """Load a config from a YAML file or a Python source file."""
307 | _, file_extension = os.path.splitext(file_obj.name)
308 | if file_extension in _YAML_EXTS:
309 | return _load_cfg_from_yaml_str(file_obj.read())
310 | elif file_extension in _PY_EXTS:
311 | return _load_cfg_py_source(file_obj.name)
312 | else:
313 | raise Exception(
314 | "Attempt to load from an unsupported file type {}; "
315 | "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
316 | )
317 |
318 |
319 | def _load_cfg_from_yaml_str(str_obj):
320 | """Load a config from a YAML string encoding."""
321 | cfg_as_dict = yaml.safe_load(str_obj)
322 | return CfgNode(cfg_as_dict)
323 |
324 |
325 | def _load_cfg_py_source(filename):
326 | """Load a config from a Python source file."""
327 | module = _load_module_from_file("yacs.config.override", filename)
328 | _assert_with_logging(
329 | hasattr(module, "cfg"),
330 | "Python module from file {} must have 'cfg' attr".format(filename),
331 | )
332 | VALID_ATTR_TYPES = {dict, CfgNode}
333 | _assert_with_logging(
334 | type(module.cfg) in VALID_ATTR_TYPES,
335 | "Imported module 'cfg' attr must be in {} but is {} instead".format(
336 | VALID_ATTR_TYPES, type(module.cfg)
337 | ),
338 | )
339 | if type(module.cfg) is dict:
340 | return CfgNode(module.cfg)
341 | else:
342 | return module.cfg
343 |
344 |
345 | def _to_dict(cfg_node):
346 | """Recursively convert all CfgNode objects to dict objects."""
347 |
348 | def convert_to_dict(cfg_node, key_list):
349 | if not isinstance(cfg_node, CfgNode):
350 | _assert_with_logging(
351 | _valid_type(cfg_node),
352 | "Key {} with value {} is not a valid type; valid types: {}".format(
353 | ".".join(key_list), type(cfg_node), _VALID_TYPES
354 | ),
355 | )
356 | return cfg_node
357 | else:
358 | cfg_dict = dict(cfg_node)
359 | for k, v in cfg_dict.items():
360 | cfg_dict[k] = convert_to_dict(v, key_list + [k])
361 | return cfg_dict
362 |
363 | return convert_to_dict(cfg_node, [])
364 |
365 |
366 | def _valid_type(value, allow_cfg_node=False):
367 | return (type(value) in _VALID_TYPES) or (allow_cfg_node and type(value) == CfgNode)
368 |
369 |
370 | def _merge_a_into_b(a, b, root, key_list):
371 | """Merge config dictionary a into config dictionary b, clobbering the
372 | options in b whenever they are also specified in a.
373 | """
374 | _assert_with_logging(
375 | isinstance(a, CfgNode),
376 | "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
377 | )
378 | _assert_with_logging(
379 | isinstance(b, CfgNode),
380 | "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
381 | )
382 |
383 | for k, v_ in a.items():
384 | full_key = ".".join(key_list + [k])
385 | # a must specify keys that are in b
386 | if k not in b:
387 | if root.key_is_deprecated(full_key):
388 | continue
389 | elif root.key_is_renamed(full_key):
390 | root.raise_key_rename_error(full_key)
391 | else:
392 | v = copy.deepcopy(v_)
393 | v = _decode_cfg_value(v)
394 | b.update({k: v})
395 | else:
396 | v = copy.deepcopy(v_)
397 | v = _decode_cfg_value(v)
398 | v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
399 |
400 | # Recursively merge dicts
401 | if isinstance(v, CfgNode):
402 | try:
403 | _merge_a_into_b(v, b[k], root, key_list + [k])
404 | except BaseException:
405 | raise
406 | else:
407 | b[k] = v
408 |
409 |
410 | def _decode_cfg_value(v):
411 | """Decodes a raw config value (e.g., from a yaml config files or command
412 | line argument) into a Python object.
413 | """
414 | # Configs parsed from raw yaml will contain dictionary keys that need to be
415 | # converted to CfgNode objects
416 | if isinstance(v, dict):
417 | return CfgNode(v)
418 | # All remaining processing is only applied to strings
419 | if not isinstance(v, str):
420 | return v
421 | # Try to interpret `v` as a:
422 | # string, number, tuple, list, dict, boolean, or None
423 | try:
424 | v = literal_eval(v)
425 | # The following two excepts allow v to pass through when it represents a
426 | # string.
427 | #
428 | # Longer explanation:
429 | # The type of v is always a string (before calling literal_eval), but
430 | # sometimes it *represents* a string and other times a data structure, like
431 | # a list. In the case that v represents a string, what we got back from the
432 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
433 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
434 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
435 | # will raise a SyntaxError.
436 | except ValueError:
437 | pass
438 | except SyntaxError:
439 | pass
440 | return v
441 |
442 |
443 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
444 | """Checks that `replacement`, which is intended to replace `original` is of
445 | the right type. The type is correct if it matches exactly or is one of a few
446 | cases in which the type can be easily coerced.
447 | """
448 | original_type = type(original)
449 | replacement_type = type(replacement)
450 |
451 | # The types must match (with some exceptions)
452 | if replacement_type == original_type:
453 | return replacement
454 |
455 | # Cast replacement from from_type to to_type if the replacement and original
456 | # types match from_type and to_type
457 | def conditional_cast(from_type, to_type):
458 | if replacement_type == from_type and original_type == to_type:
459 | return True, to_type(replacement)
460 | else:
461 | return False, None
462 |
463 | # Conditionally casts
464 | # list <-> tuple
465 | casts = [(tuple, list), (list, tuple)]
466 | # For py2: allow converting from str (bytes) to a unicode string
467 | try:
468 | casts.append((str, unicode)) # noqa: F821
469 | except Exception:
470 | pass
471 |
472 | for (from_type, to_type) in casts:
473 | converted, converted_value = conditional_cast(from_type, to_type)
474 | if converted:
475 | return converted_value
476 |
477 | raise ValueError(
478 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
479 | "key: {}".format(
480 | original_type, replacement_type, original, replacement, full_key
481 | )
482 | )
483 |
484 |
485 | def _assert_with_logging(cond, msg):
486 | if not cond:
487 | logger.debug(msg)
488 | assert cond, msg
489 |
490 |
491 | def _load_module_from_file(name, filename):
492 | if _PY2:
493 | module = imp.load_source(name, filename)
494 | else:
495 | spec = importlib.util.spec_from_file_location(name, filename)
496 | module = importlib.util.module_from_spec(spec)
497 | spec.loader.exec_module(module)
498 | return module
--------------------------------------------------------------------------------
/lib/dataset/Ev3D.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("/home/jiaxu/jx/EvGGS/")  # developer-specific path; adjust to your local repo root
3 | from natsort import natsorted
4 | import open3d as o3d
5 | import h5py
6 | import os
7 | import numpy as np
8 | import torch
9 | from .utils import events_to_voxel_grid
10 | from torch.utils.data import Dataset, DataLoader
11 | import cv2
12 | import glob
13 | from tqdm import tqdm
14 | from torch.utils.data import ConcatDataset
15 | from lib.renderer.rend_utils import getProjectionMatrix, getWorld2View2, focal2fov
16 | from lib.config import cfg, args
17 |
18 | def depth2pc_np_ours(depth, extrinsic, intrinsic, isdisparity=False):
19 | H, W = depth.shape
20 | x_ref, y_ref = np.meshgrid(np.arange(0, W), np.arange(0, H))
21 | x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])
22 |
23 | xyz_ref = np.matmul(np.linalg.inv(intrinsic[:3, :3]),
24 | np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth.reshape([-1]))
25 | xyz_world = np.matmul(np.linalg.inv(extrinsic), np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]
26 |
27 | return xyz_world.transpose((1, 0)).astype(np.float32)
28 |
29 | def find_files(dir, exts):
30 | if os.path.isdir(dir):
31 | files_grabbed = []
32 | for ext in exts:
33 | files_grabbed.extend(glob.glob(os.path.join(dir, ext)))
34 | if len(files_grabbed) > 0:
35 | files_grabbed = sorted(files_grabbed)
36 | return files_grabbed
37 | else:
38 | return []
39 |
40 | def parse_txt(filename, shape):
41 | assert os.path.isfile(filename)
42 | nums = open(filename).read().split()
43 | return np.array([float(x) for x in nums]).reshape(shape).astype(np.float32)
44 |
45 | def concatenate_datasets_ratio(base_folders, dataset_type, split, dataset_kwargs={}):
46 | scene_lists = natsorted(os.listdir(os.path.join(base_folders, 'Event')))
47 | n_scenes = len(scene_lists)
48 | ratio = int(cfg.dataset.ratio * n_scenes)
49 |
50 | if split == "train":
51 | scene_lists = scene_lists[:ratio]
52 | elif split == "val":
53 | scene_lists = scene_lists[ratio:]
54 |
55 | dataset_list = []
56 | for i in range(len(scene_lists)):
57 | dataset_list.append(dataset_type(base_folders, scene_lists[i], **dataset_kwargs))
58 | return ConcatDataset(dataset_list)
59 |
60 | def concatenate_datasets_split(base_folders, dataset_type, split, dataset_kwargs={}):
61 | if split == "train":
62 | scenes_path = os.path.join(base_folders, "train_scenes.txt")
63 | elif split == "test":
64 | scenes_path = os.path.join(base_folders, "test_scenes.txt")
65 | elif split == "val":
66 | scenes_path = os.path.join(base_folders, "val_scenes.txt")
67 |
68 | with open(scenes_path, 'r', encoding='utf-8') as file:
69 | scene_lists = [line.strip() for line in file.readlines()]
70 |
71 | dataset_list = []
72 | for i in range(len(scene_lists)):
73 | dataset_list.append(dataset_type(base_folders, scene_lists[i], **dataset_kwargs))
74 | return ConcatDataset(dataset_list)
75 |
76 | T = np.array([[1,0,0,0],
77 | [0,-1,0,0],
78 | [0,0,-1,0],
79 | [0,0,0,1]])
80 |
81 | class EventDataloader(DataLoader):
82 | def __init__(self, base_folders, split, num_workers, batch_size, shuffle=True):
83 | dataset = concatenate_datasets_split(base_folders, ReadEventFromH5, split=split)
84 | super().__init__(dataset, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle)
85 |
86 | cs = cfg.cs
87 |
88 | class ReadEventFromH5(Dataset):
89 | def __init__(self, base_folder, scene, polarity_offset=0):
90 | self.base_folder = base_folder
91 | self.scene = scene
92 | self.polarity_offset = polarity_offset
93 | self.H, self.W = 480, 640
94 | self.cropped_H, self.cropped_W = cs[1]-cs[0], cs[3] - cs[2]
95 | self.event_slices()
96 |
97 | def event_slices(self):
98 | ## load .h5 event files and generate event frames and voxels
99 | self.event_files_path = os.path.join(self.base_folder, "Event", self.scene)
100 | scene_files_path = os.path.join(self.base_folder, "Scenes", self.scene)
101 | self.pose_files = find_files('{}/Poses'.format(self.base_folder), exts=['*.txt'])
102 | self.num_views = len(self.pose_files)
103 | intrinsic_files = find_files('{}/Intrinsics'.format(self.base_folder), exts=['*.txt'])[:self.num_views]
104 | self.npz_files = find_files('{}'.format(scene_files_path), exts=["*.npz"])[:self.num_views]
105 | self.rgb_files = find_files('{}'.format(scene_files_path), exts=['*.png'])[:self.num_views]
106 |
107 | self.intrinsics = parse_txt(intrinsic_files[0], (4, 4))
108 |
109 | def __len__(self):
110 | return len(self.pose_files)
111 |
112 | def events_to_voxel(self, events):
113 |         # accumulate events into a 3-channel frame (positive, negative, combined); duplicates events_to_frame below
114 | x, y, t, p = events
115 | x = x.astype(np.int32)
116 | y = y.astype(np.int32)
117 | p = p.astype(np.int32)
118 | mask_pos = p.copy()
119 | mask_neg = p.copy()
120 | mask_pos[p < 0] = 0
121 | mask_neg[p > 0] = 0
122 | frame1 = self.events_to_image(x, y, p * mask_pos)
123 | frame2 = self.events_to_image(x, y, p * mask_neg)
124 | frame3 = frame1 - frame2
125 | # cv2.imwrite('1.png', 128+frame1)
126 | # cv2.imwrite('2.png', 128-frame2)
127 | # cv2.imwrite('3.png', 128+frame3)
128 | return np.stack(((128+frame1)/255, (128-frame2)/255, (128 + frame3)/255), axis=2)
129 |
130 | def events_to_image(self, xs, ys, ps):
131 | # accumulate events into an image.
132 | img = np.zeros((self.H, self.W))
133 | np.add.at(img, (ys, xs), ps)
134 |
135 | # img = np.clip(img, -5, 5)
136 | # print(img)
137 | return img
138 |
139 | def find_depth(self, npz_files, idx):
140 | npz = np.load(npz_files[idx], allow_pickle=True)
141 | depth = npz['depth_map']
142 | depth = self.prepare_depth(depth)
143 | return depth
144 |
145 | def find_pose(self, npz_files, idx):
146 | npz = np.load(npz_files[idx], allow_pickle=True)
147 | poses = npz['object_poses']
148 | for obj in poses:
149 | obj_name = obj['name']
150 | obj_mat = obj['pose']
151 | if obj_name == 'Camera':
152 | pose = obj_mat.astype(np.float32)
153 | break
154 | return pose @ T
155 |
156 | def prepare_depth(self, depth):
157 | # adjust depth maps generated by vision blender
158 | INVALID_DEPTH = -1
159 | depth[depth == INVALID_DEPTH] = 0
160 |
161 | return depth
162 |
163 |     def accumulate_events_edited(self, events):  # unused, incomplete development stub
164 | x, y, t, p = events
165 |
166 | def events_to_frame(self, events):
167 |         # accumulate events into a 3-channel event frame (positive, negative, combined polarity).
168 | x, y, t, p = events
169 | x = x.astype(np.int32)
170 | y = y.astype(np.int32)
171 | p = p.astype(np.int32)
172 | mask_pos = p.copy()
173 | mask_neg = p.copy()
174 | mask_pos[p < 0] = 0
175 | mask_neg[p > 0] = 0
176 | frame1 = self.events_to_image(x, y, p * mask_pos)
177 | frame2 = self.events_to_image(x, y, p * mask_neg)
178 | frame3 = frame1 - frame2
179 |
180 | return np.stack(((128 + frame1)/255, (128-frame2)/255, (128 + frame3)/255), axis=2)
181 |
182 | def events_to_image(self, xs, ys, ps):
183 | # accumulate events into an image.
184 | img = np.zeros((self.H, self.W))
185 | np.add.at(img, (ys, xs), ps)
186 | # print(img.max(), img.min())
187 | # img = np.clip(img, -5, 5)
188 | # print(img)
189 | return img
190 |
191 | def accumulate_events(self, events, resolution_level=1, polarity_offset=0):
192 | x, y, t, p = events
193 | acc_frm = np.zeros((self.H, self.W))
194 | np.add.at(acc_frm, (y // resolution_level, x // resolution_level), p + polarity_offset)
195 | return acc_frm
196 |
197 | def __getitem__(self, idx):
198 | index = str(idx).zfill(4)
199 | left_event1 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(index)))
200 |
201 | # left_event_voxel = events_to_voxel_grid(left_event.transpose((1,0)), cfg.model.num_bins, self.W, self.H)
202 | # left_pose = parse_txt(self.pose_files[idx], (4,4))
203 |
204 | left_pose = self.find_pose(self.npz_files, idx)
205 | left_depth_gt = self.find_depth(self.npz_files, idx)
206 | left_mask = (left_depth_gt > 0)
207 | left_img = cv2.cvtColor(cv2.imread(self.rgb_files[idx])[...,:3] * left_mask[..., np.newaxis], cv2.COLOR_BGR2GRAY) / 255.
208 |
209 | if idx + 1 < len(self.pose_files):
210 | left_event2 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(str(idx+1).zfill(4))))
211 | center_depth_gt = self.find_depth(self.npz_files, idx+1)
212 | center_mask = (center_depth_gt > 0)
213 | # center_pose = parse_txt(self.pose_files[idx+1], (4,4))
214 | center_pose = self.find_pose(self.npz_files, idx+1)
215 |
216 | # try:
217 | int_img = cv2.cvtColor(cv2.imread(self.rgb_files[idx+1])[...,:3] * center_mask[..., np.newaxis], cv2.COLOR_BGR2GRAY) / 255.
218 |
219 | else:
220 | left_event2 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(str(0).zfill(4))))
221 | center_depth_gt = self.find_depth(self.npz_files, 0)
222 | center_mask = (center_depth_gt > 0)
223 | # center_pose = parse_txt(self.pose_files[0], (4,4))
224 | center_pose = self.find_pose(self.npz_files, 0)
225 | int_img = cv2.cvtColor(cv2.imread(self.rgb_files[0])[...,:3] * center_mask[..., np.newaxis], cv2.COLOR_BGR2GRAY) /255.
226 |
227 | center_extrinsics = np.linalg.inv(center_pose)
228 |
229 | left_event_frame = self.events_to_frame(np.hstack((left_event1, left_event2)))
230 | left_event_voxel = events_to_voxel_grid(np.hstack((left_event1, left_event2)).transpose((1,0)), cfg.model.num_bins, self.W, self.H)
231 | if idx + 2 < len(self.pose_files):
232 | r_id = idx + 2
233 | r_index = str(r_id).zfill(4)
234 | right_event1 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(r_index)))
235 | # right_pose = parse_txt(self.pose_files[r_id], (4,4))
236 | right_pose = self.find_pose(self.npz_files, r_id)
237 | right_depth_gt = self.find_depth(self.npz_files, r_id)
238 | else:
239 | r_id = (idx + 2) % len(self.pose_files)
240 | r_index = str(r_id).zfill(4)
241 | right_event1 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(r_index)))
242 | # right_pose = parse_txt(self.pose_files[r_id], (4,4))
243 | right_pose = self.find_pose(self.npz_files, r_id)
244 | right_depth_gt = self.find_depth(self.npz_files, r_id)
245 |
246 | if idx + 3 < len(self.pose_files):
247 | r_id2 = idx + 3
248 | r_index2 = str(r_id2).zfill(4)
249 | right_event2 = np.load(os.path.join(self.event_files_path,'{}.npy'.format(r_index2)))
250 | else:
251 | r_id2 = (idx + 3) % len(self.pose_files)
252 | r_index2 = str(r_id2).zfill(4)
253 | right_event2 = np.load(os.path.join(self.event_files_path, '{}.npy'.format(r_index2)))
254 |
255 |
256 | # pr = depth2pc_np_ours(right_depth_gt, np.linalg.inv(right_pose), self.intrinsics)
257 | # pl = depth2pc_np_ours(left_depth_gt, np.linalg.inv(left_pose), self.intrinsics)
258 | # pc = np.concatenate([pr, pl], axis=0)
259 | # pcd = o3d.geometry.PointCloud()
260 | # pcd.points = o3d.utility.Vector3dVector(pc)
261 | # o3d.io.write_point_cloud("pts.ply", pcd)
262 | # print(np.hstack((right_event1, right_event2)).shape)
263 | right_event_frame = self.events_to_frame(np.hstack((right_event1, right_event2)))
264 | right_event_voxel = events_to_voxel_grid(np.hstack((right_event1, right_event2)).transpose((1,0)), cfg.model.num_bins, self.W, self.H)
265 | # right_event_voxel = events_to_voxel_grid(right_event.transpose((1,0)), cfg.model.num_bins, self.W, self.H)
266 |
267 | right_mask = (right_depth_gt > 0)
268 | right_img = cv2.cvtColor(cv2.imread(self.rgb_files[r_id])[...,:3] * right_mask[..., np.newaxis], cv2.COLOR_BGR2GRAY) / 255.
269 |
270 | center_event = np.hstack((left_event2, right_event1))
271 | center_event_frame = self.events_to_frame(center_event)
272 | center_event_voxel = events_to_voxel_grid(center_event.transpose((1,0)), cfg.model.num_bins, self.W, self.H)
273 |
274 | intrinsic = self.intrinsics
275 |
276 | intrinsic[0,2] = (self.cropped_W - 1) / 2
277 | intrinsic[1,2] = (self.cropped_H - 1) / 2
278 |
279 | projection_matrix = getProjectionMatrix(znear=0.01, zfar=0.99, K=intrinsic, h=self.cropped_H, w=self.cropped_W).transpose(0, 1)
280 | world_view_transform = torch.tensor(getWorld2View2(center_extrinsics[:3,:3].reshape(3, 3).transpose(1, 0)\
281 | , center_extrinsics[:3, 3])).transpose(0, 1)
282 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
283 | camera_center = world_view_transform.inverse()[3, :3]
284 |
285 | item = {
286 | 'cim': int_img.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]][np.newaxis], #[1, H, W]
287 | 'lim': left_img.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]][np.newaxis],
288 | 'rim': right_img.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]][np.newaxis],
289 | 'leframe': left_event_frame.transpose((2,0,1)).astype(np.float32)[:, cs[0]:cs[1], cs[2]:cs[3]], #[3, H, W]
290 | 'reframe': right_event_frame.transpose((2,0,1)).astype(np.float32)[:, cs[0]:cs[1], cs[2]:cs[3]],
291 | 'ceframe': center_event_frame.transpose((2,0,1)).astype(np.float32)[:, cs[0]:cs[1], cs[2]:cs[3]],
292 | 'lmask': left_mask[cs[0]:cs[1], cs[2]:cs[3]], #[H, W]
293 | 'rmask': right_mask[cs[0]:cs[1], cs[2]:cs[3]],
294 | 'cmask': center_mask[cs[0]:cs[1], cs[2]:cs[3]],
295 | 'lpose': left_pose.astype(np.float32), #[4, 4]
296 | 'rpose': right_pose.astype(np.float32),
297 | 'intrinsic': intrinsic.astype(np.float32), #[4, 4]
298 | 'ldepth': left_depth_gt.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]], # #[H, W]
299 | 'rdepth': right_depth_gt.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]],
300 | 'cdepth': center_depth_gt.astype(np.float32)[cs[0]:cs[1], cs[2]:cs[3]],
301 | 'center_voxel':center_event_voxel[:, cs[0]:cs[1], cs[2]:cs[3]], #[5, H, W]
302 | 'right_voxel':right_event_voxel[:, cs[0]:cs[1], cs[2]:cs[3]],
303 | 'left_voxel':left_event_voxel[:, cs[0]:cs[1], cs[2]:cs[3]],
304 | ### target view rendering parameters ###
305 | "H":self.cropped_H,
306 | "W":self.cropped_W,
307 | "FovX":focal2fov(intrinsic[0, 0], self.cropped_W),
308 | "FovY":focal2fov(intrinsic[1, 1], self.cropped_H),
309 | 'world_view_transform': world_view_transform, #[4, 4]
310 | 'full_proj_transform': full_proj_transform, #[4, 4]
311 | 'camera_center': camera_center #[3]
312 | }
313 | return item
314 |
315 |
316 |
317 |
318 |
319 | # if __name__ == "__main__":
320 | # dataset = ReadEventFromH5(r"/home/lsf_storage/dataset/EV3D5/", "AK47",0)
321 | # dataset[0]
322 | # # dataloader = EventDataloader(r"/home/lsf_storage/dataset/EV3D/", split="full", batch_size=1, num_workers=1, shuffle=False)
323 | # # print(len(dataloader))
324 |
325 | # # for idx, batch in enumerate(dataloader):
326 | # # print(batch["ldepth"].shape)
327 | # # break
--------------------------------------------------------------------------------
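Note on the camera convention in `__getitem__` above: both the projection and the world-to-view matrices are stored transposed (the 3DGS row-vector convention), which is why the camera center is read from the last row of the inverse. A minimal, self-contained sanity check of that identity, using a toy rotation and translation rather than values from the dataset:

```
import math
import torch

# toy world-to-view matrix: x_cam = R @ x_world + t
c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])
t = torch.tensor([1., 2., 3.])
V = torch.eye(4)
V[:3, :3], V[:3, 3] = R, t

T = V.transpose(0, 1)            # stored transposed, as in the dataset code
cam_center = T.inverse()[3, :3]  # last row of the inverse holds the camera center
assert torch.allclose(cam_center, -R.T @ t, atol=1e-5)
```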
/lib/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from lib.config import cfg, args
2 |
3 |
4 | from .Ev3D import EventDataloader
--------------------------------------------------------------------------------
/lib/dataset/__pycache__/Ev3D.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/Ev3D.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/dataset/__pycache__/Ev3D.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/Ev3D.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/dataset/__pycache__/Ev3D.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/Ev3D.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/dataset/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/dataset/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/dataset/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/dataset/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/dataset/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/dataset/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def events_to_voxel_grid(events, num_bins, width, height):
5 | """
6 | Build a voxel grid with bilinear interpolation in the time domain from a set of events.
7 |
8 |     :param events: a [N x 4] NumPy array containing one event per row in the form: [x, y, timestamp, polarity]
9 | :param num_bins: number of bins in the temporal axis of the voxel grid
10 | :param width, height: dimensions of the voxel grid
11 | """
12 |
13 | assert(events.shape[1] == 4)
14 | assert(num_bins > 0)
15 | assert(width > 0)
16 | assert(height > 0)
17 |
18 | voxel_grid = np.zeros((num_bins, height, width), np.float32).ravel()
19 |
20 |     # normalize the event timestamps so that they lie between 0 and num_bins - 1
21 | last_stamp = events[-1, 2]
22 | first_stamp = events[0, 2]
23 | deltaT = last_stamp - first_stamp
24 |
25 | if deltaT == 0:
26 | deltaT = 1.0
27 |
28 | events[:, 2] = (num_bins - 1) * (events[:, 2] - first_stamp) / deltaT
29 | ts = events[:, 2]
30 |     xs = events[:, 0].astype(np.int32)  # np.int was removed in recent NumPy
31 |     ys = events[:, 1].astype(np.int32)
32 | pols = events[:, 3]
33 | pols[pols == 0] = -1 # polarity should be +1 / -1
34 |
35 |     tis = ts.astype(np.int32)
36 | dts = ts - tis
37 | vals_left = pols * (1.0 - dts)
38 | vals_right = pols * dts
39 |
40 | valid_indices = tis < num_bins
41 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width
42 | + tis[valid_indices] * width * height, vals_left[valid_indices])
43 |
44 | valid_indices = (tis + 1) < num_bins
45 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width
46 | + (tis[valid_indices] + 1) * width * height, vals_right[valid_indices])
47 |
48 | voxel_grid = np.reshape(voxel_grid, (num_bins, height, width))
49 |
50 | return voxel_grid
51 |
52 |
53 |
--------------------------------------------------------------------------------
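A minimal usage sketch for `events_to_voxel_grid` with synthetic events; note the `[x, y, timestamp, polarity]` column order that the function actually indexes (assuming the repo root is on `PYTHONPATH`):

```
import numpy as np
from lib.dataset.utils import events_to_voxel_grid

rng = np.random.default_rng(0)
events = np.stack([
    rng.integers(0, 64, 100),   # x
    rng.integers(0, 48, 100),   # y
    np.sort(rng.random(100)),   # timestamps, ascending
    rng.integers(0, 2, 100),    # polarity in {0, 1}
], axis=1).astype(np.float64)

voxel = events_to_voxel_grid(events, num_bins=5, width=64, height=48)
print(voxel.shape)  # (5, 48, 64)
```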
/lib/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | from math import exp
5 |
6 |
7 | def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9):
8 | """ Loss function defined over sequence of flow predictions """
9 |
10 | n_predictions = len(flow_preds)
11 | flow_loss = 0.0
12 |
13 | valid = (valid >= 0.5)
14 | assert not torch.isinf(flow_gt[valid.bool()]).any()
15 |
16 | for i in range(n_predictions):
17 |         # adjust loss_gamma for any iteration count; max() guards against n_predictions == 1
18 |         adjusted_loss_gamma = loss_gamma**(15/(max(n_predictions - 1, 1)))
19 |         i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
20 | i_loss = (flow_preds[i] - flow_gt).abs()
21 | flow_loss += i_weight * i_loss[valid.bool()].mean()
22 |
23 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
24 | epe = epe.view(-1)[valid.view(-1)]
25 |
26 | metrics = {
27 | 'train_epe': epe.mean().item(),
28 | 'train_1px': (epe < 1).float().mean().item(),
29 | 'train_3px': (epe < 3).float().mean().item()
30 | }
31 |
32 | return flow_loss, metrics
33 |
34 |
35 | def l1_loss(network_output, gt):
36 | return torch.abs((network_output - gt)).mean()
37 |
38 | def mse_loss(out, gt, msk=None):
39 | if msk is None:
40 | loss = torch.mean((out - gt) ** 2)
41 | else:
42 |         # restrict the loss to the pixels selected by the mask
43 |         loss = torch.mean((out[msk] - gt[msk]) ** 2)
44 | return loss
45 |
46 | def gaussian(window_size, sigma):
47 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
48 | return gauss / gauss.sum()
49 |
50 |
51 | def create_window(window_size, channel):
52 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
53 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
54 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
55 | return window
56 |
57 |
58 | def ssim(img1, img2, window_size=11, size_average=True):
59 | channel = img1.size(-3)
60 | window = create_window(window_size, channel)
61 |
62 | if img1.is_cuda:
63 | window = window.cuda(img1.get_device())
64 | window = window.type_as(img1)
65 |
66 | return _ssim(img1, img2, window, window_size, channel, size_average)
67 |
68 |
69 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
70 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
71 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
72 |
73 | mu1_sq = mu1.pow(2)
74 | mu2_sq = mu2.pow(2)
75 | mu1_mu2 = mu1 * mu2
76 |
77 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
78 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
79 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
80 |
81 | C1 = 0.01 ** 2
82 | C2 = 0.03 ** 2
83 |
84 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
85 |
86 | if size_average:
87 | return ssim_map.mean()
88 | else:
89 | return ssim_map.mean(1).mean(1).mean(1)
90 |
91 |
92 | def psnr(img1, img2):
93 |     mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
94 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
95 |
--------------------------------------------------------------------------------
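A quick smoke test for the image metrics above, assuming `[B, C, H, W]` tensors with values in `[0, 1]` and the repo root on `PYTHONPATH`:

```
import torch
from lib.losses import l1_loss, ssim, psnr

pred = torch.rand(2, 1, 64, 64)
gt = (pred + 0.05 * torch.randn_like(pred)).clamp(0, 1)

print('L1  :', l1_loss(pred, gt).item())
print('SSIM:', ssim(pred, gt).item())         # close to 1 for near-identical images
print('PSNR:', psnr(pred, gt).mean().item())  # one value per batch element, averaged
```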
/lib/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .asnet_utils import model_loss, model_loss_light
2 | from .recon_net import E2IM, E2DPT, E2Msk
3 | from .eventgaussian import EventGaussian
--------------------------------------------------------------------------------
/lib/network/__pycache__/ASNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/ASNet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/ASNet_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/ASNet_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/SegNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/SegNet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/asnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/asnet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/asnet_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/asnet_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/densenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/densenet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/dfanet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/dfanet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/eventgaussian.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/eventgaussian.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/firenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/firenet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/gsregressor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/gsregressor.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/mobilenetv2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/mobilenetv2.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/pspnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/pspnet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/recon_net.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/recon_net.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/submodules.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/submodules.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/swin.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/swin.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/__pycache__/unet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/network/__pycache__/unet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/network/asnet.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import print_function
3 | import math
4 | import torch.nn as nn
5 | import torch.utils.data
6 | import torch.nn.functional as F
7 | from .asnet_utils import *
8 |
9 |
10 | class hourglass2D(nn.Module):
11 | def __init__(self, in_channels):
12 | super(hourglass2D, self).__init__()
13 |
14 | self.expanse_ratio = 2
15 |
16 | self.conv1 = MobileV2_Residual(in_channels, in_channels * 2, stride=2, expanse_ratio=self.expanse_ratio)
17 |
18 | self.conv2 = MobileV2_Residual(in_channels * 2, in_channels * 2, stride=1, expanse_ratio=self.expanse_ratio)
19 |
20 | self.conv3 = MobileV2_Residual(in_channels * 2, in_channels * 4, stride=2, expanse_ratio=self.expanse_ratio)
21 |
22 | self.conv4 = MobileV2_Residual(in_channels * 4, in_channels * 4, stride=1, expanse_ratio=self.expanse_ratio)
23 |
24 | self.conv5 = nn.Sequential(
25 | nn.ConvTranspose2d(in_channels * 4, in_channels * 2, 3, padding=1, output_padding=1, stride=2, bias=False),
26 | nn.BatchNorm2d(in_channels * 2))
27 |
28 | self.conv6 = nn.Sequential(
29 | nn.ConvTranspose2d(in_channels * 2, in_channels, 3, padding=1, output_padding=1, stride=2, bias=False),
30 | nn.BatchNorm2d(in_channels))
31 |
32 | self.redir1 = MobileV2_Residual(in_channels, in_channels, stride=1, expanse_ratio=self.expanse_ratio)
33 | self.redir2 = MobileV2_Residual(in_channels * 2, in_channels * 2, stride=1, expanse_ratio=self.expanse_ratio)
34 |
35 | def forward(self, x):
36 | conv1 = self.conv1(x)
37 | conv2 = self.conv2(conv1)
38 |
39 | conv3 = self.conv3(conv2)
40 | conv4 = self.conv4(conv3)
41 |
42 | conv5 = F.relu(self.conv5(conv4) + self.redir2(conv2), inplace=True)
43 | conv6 = F.relu(self.conv6(conv5) + self.redir1(x), inplace=True)
44 |
45 | return conv6
46 |
47 |
48 | class ASNet(nn.Module):
49 | def __init__(self, maxdisp):
50 |
51 | super(ASNet, self).__init__()
52 |
53 | self.maxdisp = maxdisp
54 |
55 | self.num_groups = 1
56 |
57 | self.volume_size = 48
58 |
59 | self.hg_size = 32
60 |
61 | self.dres_expanse_ratio = 3
62 |
63 | self.feature_extraction0 = feature_extraction()
64 |
65 |
66 | self.volume1 = volume_build(self.volume_size)
67 |
68 |
69 | self.dres0 = nn.Sequential(MobileV2_Residual(self.volume_size, self.hg_size, 1, self.dres_expanse_ratio),
70 | nn.ReLU(inplace=True),
71 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio),
72 | nn.ReLU(inplace=True))
73 |
74 | self.dres1 = nn.Sequential(MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio),
75 | nn.ReLU(inplace=True),
76 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio))
77 |
78 | self.encoder_decoder1 = hourglass2D(self.hg_size)
79 | self.encoder_decoder2 = hourglass2D(self.hg_size)
80 | self.encoder_decoder3 = hourglass2D(self.hg_size)
81 |
82 |
83 | self.classif0 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1),
84 | nn.ReLU(inplace=True),
85 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1,
86 | bias=False, dilation=1))
87 | self.classif1 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1),
88 | nn.ReLU(inplace=True),
89 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1,
90 | bias=False, dilation=1))
91 | self.classif2 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1),
92 | nn.ReLU(inplace=True),
93 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1,
94 | bias=False, dilation=1))
95 | self.classif3 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1),
96 | nn.ReLU(inplace=True),
97 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1,
98 | bias=False, dilation=1))
99 |
100 |
101 | for m in self.modules():
102 | if isinstance(m, nn.Conv2d):
103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
104 | m.weight.data.normal_(0, math.sqrt(2. / n))
105 | elif isinstance(m, nn.Conv3d):
106 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
107 | m.weight.data.normal_(0, math.sqrt(2. / n))
108 | elif isinstance(m, nn.BatchNorm2d):
109 | m.weight.data.fill_(1)
110 | m.bias.data.zero_()
111 | elif isinstance(m, nn.BatchNorm3d):
112 | m.weight.data.fill_(1)
113 | m.bias.data.zero_()
114 | elif isinstance(m, nn.Linear):
115 | m.bias.data.zero_()
116 |
117 |
118 | def forward(self, L, R):
119 |
120 | featL = self.feature_extraction0(L)
121 | featR = self.feature_extraction0(R)
122 |
123 | xALL0 = self.volume1(featL, featR)
124 |
125 | cost0 = self.dres0(xALL0)
126 | cost0 = self.dres1(cost0) + cost0
127 |
128 | out1 = self.encoder_decoder1(cost0)
129 | out2 = self.encoder_decoder2(out1)
130 | out3 = self.encoder_decoder3(out2)
131 |
132 | if self.training:
133 | cost0 = self.classif0(cost0)
134 | cost1 = self.classif1(out1)
135 | cost2 = self.classif2(out2)
136 | cost3 = self.classif3(out3)
137 |
138 | cost0 = torch.unsqueeze(cost0, 1)
139 | cost0 = F.interpolate(cost0, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
140 |
141 | cost0 = torch.squeeze(cost0, 1)
142 | pred0 = F.softmax(cost0, dim=1)
143 | pred0 = disparity_regression(pred0, self.maxdisp)
144 |
145 | cost1 = torch.unsqueeze(cost1, 1)
146 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
147 | cost1 = torch.squeeze(cost1, 1)
148 | pred1 = F.softmax(cost1, dim=1)
149 | pred1 = disparity_regression(pred1, self.maxdisp)
150 |
151 | cost2 = torch.unsqueeze(cost2, 1)
152 | cost2 = F.interpolate(cost2, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
153 | cost2 = torch.squeeze(cost2, 1)
154 | pred2 = F.softmax(cost2, dim=1)
155 | pred2 = disparity_regression(pred2, self.maxdisp)
156 |
157 | cost3 = torch.unsqueeze(cost3, 1)
158 | cost3 = F.interpolate(cost3, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
159 | cost3 = torch.squeeze(cost3, 1)
160 | pred3 = F.softmax(cost3, dim=1)
161 | pred3 = disparity_regression(pred3, self.maxdisp)
162 |
163 | return [pred0, pred1, pred2, pred3]
164 |
165 | else:
166 | cost3 = self.classif3(out3)
167 | cost3 = torch.unsqueeze(cost3, 1)
168 | cost3 = F.interpolate(cost3, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
169 | cost3 = torch.squeeze(cost3, 1)
170 | pred3 = F.softmax(cost3, dim=1)
171 | pred3 = disparity_regression(pred3, self.maxdisp)
172 |
173 | return [pred3]
174 |
175 |
176 |
177 | class ASNet_light(nn.Module):
178 | def __init__(self, maxdisp):
179 |
180 | super(ASNet_light, self).__init__()
181 |
182 | self.maxdisp = maxdisp
183 |
184 | self.num_groups = 1
185 |
186 | self.volume_size = 48
187 |
188 | self.hg_size = 16
189 |
190 | self.dres_expanse_ratio = 3
191 |
192 | self.feature_extraction0 = feature_extraction()
193 |
194 | self.volume1 = volume_build(self.volume_size)
195 |
196 |
197 | self.dres0 = nn.Sequential(MobileV2_Residual(self.volume_size, self.hg_size, 1, self.dres_expanse_ratio),
198 | nn.ReLU(inplace=True),
199 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio),
200 | nn.ReLU(inplace=True))
201 |
202 | self.dres1 = nn.Sequential(MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio),
203 | nn.ReLU(inplace=True),
204 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio))
205 |
206 | self.encoder_decoder1 = hourglass2D(self.hg_size)
207 | # self.encoder_decoder2 = hourglass2D(self.hg_size)
208 | # self.encoder_decoder3 = hourglass2D(self.hg_size)
209 |
210 | self.classif0 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1),
211 | nn.ReLU(inplace=True),
212 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1,
213 | bias=False, dilation=1))
214 | self.classif1 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1),
215 | nn.ReLU(inplace=True),
216 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1,
217 | bias=False, dilation=1))
218 |
219 |
220 | for m in self.modules():
221 | if isinstance(m, nn.Conv2d):
222 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
223 | m.weight.data.normal_(0, math.sqrt(2. / n))
224 | elif isinstance(m, nn.Conv3d):
225 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
226 | m.weight.data.normal_(0, math.sqrt(2. / n))
227 | elif isinstance(m, nn.BatchNorm2d):
228 | m.weight.data.fill_(1)
229 | m.bias.data.zero_()
230 | elif isinstance(m, nn.BatchNorm3d):
231 | m.weight.data.fill_(1)
232 | m.bias.data.zero_()
233 | elif isinstance(m, nn.Linear):
234 | m.bias.data.zero_()
235 |
236 | def forward(self, L, R):
237 | featL = self.feature_extraction0(L)
238 | featR = self.feature_extraction0(R)
239 |
240 | xALL0 = self.volume1(featL, featR)
241 |
242 | cost0 = self.dres0(xALL0)
243 | cost0 = self.dres1(cost0) + cost0
244 |
245 | out1 = self.encoder_decoder1(cost0)
246 |
247 | if self.training:
248 | cost0 = self.classif0(cost0)
249 | cost1 = self.classif1(out1)
250 |
251 | cost0 = torch.unsqueeze(cost0, 1)
252 | cost0 = F.interpolate(cost0, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
253 |
254 | cost0 = torch.squeeze(cost0, 1)
255 | pred0 = F.softmax(cost0, dim=1)
256 | pred0 = disparity_regression(pred0, self.maxdisp)
257 |
258 | cost1 = torch.unsqueeze(cost1, 1)
259 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
260 | cost1 = torch.squeeze(cost1, 1)
261 | pred1 = F.softmax(cost1, dim=1)
262 | pred1 = disparity_regression(pred1, self.maxdisp)
263 |
264 | return [pred0, pred1]
265 | else:
266 | cost1 = self.classif1(out1)
267 | cost1 = torch.unsqueeze(cost1, 1)
268 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
269 | cost1 = torch.squeeze(cost1, 1)
270 | pred1 = F.softmax(cost1, dim=1)
271 | pred1 = disparity_regression(pred1, self.maxdisp)
272 |
273 | return [pred1]
274 |
275 | def get_features(self, L, R):
276 | featL = self.feature_extraction0(L)
277 | featR = self.feature_extraction0(R)
278 |
279 | xALL0 = self.volume1(featL, featR)
280 |
281 | cost0 = self.dres0(xALL0)
282 | cost0 = self.dres1(cost0) + cost0
283 |
284 | out1 = self.encoder_decoder1(cost0)
285 |
286 | cost1 = self.classif1(out1)
287 | cost1 = torch.unsqueeze(cost1, 1)
288 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
289 | cost1 = torch.squeeze(cost1, 1)
290 | pred1 = F.softmax(cost1, dim=1)
291 | pred1 = disparity_regression(pred1, self.maxdisp)
292 |
293 | return pred1, F.interpolate(torch.unsqueeze(out1, 1), [self.hg_size, L.size()[2], L.size()[3]], mode='trilinear').squeeze(1)
294 |
295 |
296 |
297 | class ASNet_mask(nn.Module):
298 | def __init__(self, maxdisp):
299 |
300 | super(ASNet_mask, self).__init__()
301 |
302 | self.maxdisp = maxdisp
303 |
304 | self.num_groups = 1
305 |
306 | self.volume_size = 48
307 |
308 | self.hg_size = 32
309 |
310 | self.dres_expanse_ratio = 3
311 |
312 | self.feature_extraction0 = feature_extraction()
313 |
314 | self.volume1 = volume_build(self.volume_size)
315 |
316 | self.dres0 = nn.Sequential(MobileV2_Residual(self.volume_size, self.hg_size, 1, self.dres_expanse_ratio),
317 | nn.ReLU(inplace=True),
318 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio),
319 | nn.ReLU(inplace=True))
320 |
321 | self.dres1 = nn.Sequential(MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio),
322 | nn.ReLU(inplace=True),
323 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio))
324 |
325 | self.encoder_decoder1 = hourglass2D(self.hg_size)
326 | # self.encoder_decoder2 = hourglass2D(self.hg_size)
327 | # self.encoder_decoder3 = hourglass2D(self.hg_size)
328 |
329 | self.output_head = nn.Sequential(convbn(self.maxdisp, 1, 3, 1, 1, 1),
330 | nn.Sigmoid())
331 |
332 | self.classif1 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1),
333 | nn.ReLU(inplace=True),
334 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1,
335 | bias=False, dilation=1))
336 |
337 | for m in self.modules():
338 | if isinstance(m, nn.Conv2d):
339 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
340 | m.weight.data.normal_(0, math.sqrt(2. / n))
341 | elif isinstance(m, nn.Conv3d):
342 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
343 | m.weight.data.normal_(0, math.sqrt(2. / n))
344 | elif isinstance(m, nn.BatchNorm2d):
345 | m.weight.data.fill_(1)
346 | m.bias.data.zero_()
347 | elif isinstance(m, nn.BatchNorm3d):
348 | m.weight.data.fill_(1)
349 | m.bias.data.zero_()
350 | elif isinstance(m, nn.Linear):
351 | m.bias.data.zero_()
352 |
353 | def forward(self, L, R):
354 | featL = self.feature_extraction0(L)
355 | featR = self.feature_extraction0(R)
356 |
357 | xALL0 = self.volume1(featL, featR)
358 |
359 | cost0 = self.dres0(xALL0)
360 | cost0 = self.dres1(cost0) + cost0
361 |
362 | out1 = self.encoder_decoder1(cost0)
363 | cost1 = self.classif1(out1)
364 | cost1 = torch.unsqueeze(cost1, 1)
365 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
366 | cost1 = torch.squeeze(cost1, 1)
367 | pred = torch.clamp_max(self.output_head(cost1), 1.0)
368 |
369 | return pred
370 |
371 |
372 |
373 | class ASNet_color(nn.Module):
374 | def __init__(self, maxdisp):
375 |
376 | super(ASNet_color, self).__init__()
377 |
378 | self.maxdisp = maxdisp
379 |
380 | self.num_groups = 1
381 |
382 | self.volume_size = 48
383 |
384 | self.hg_size = 32
385 |
386 | self.dres_expanse_ratio = 3
387 |
388 | self.feature_extraction0 = feature_extraction()
389 |
390 | self.volume1 = volume_build(self.volume_size)
391 |
392 | self.dres0 = nn.Sequential(MobileV2_Residual(self.volume_size, self.hg_size, 1, self.dres_expanse_ratio),
393 | nn.ReLU(inplace=True),
394 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio),
395 | nn.ReLU(inplace=True))
396 |
397 | self.dres1 = nn.Sequential(MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio),
398 | nn.ReLU(inplace=True),
399 | MobileV2_Residual(self.hg_size, self.hg_size, 1, self.dres_expanse_ratio))
400 |
401 | # self.encoder_decoder1 = hourglass2D(self.hg_size)
402 | # self.encoder_decoder2 = hourglass2D(self.hg_size)
403 | # self.encoder_decoder3 = hourglass2D(self.hg_size)
404 |
405 | self.output_head = nn.Sequential(convbn(self.maxdisp, 1, 3, 1, 1, 1),
406 | nn.Sigmoid())
407 |
408 | self.classif1 = nn.Sequential(convbn(self.hg_size, self.hg_size, 3, 1, 1, 1),
409 | nn.ReLU(inplace=True),
410 | nn.Conv2d(self.hg_size, self.hg_size, kernel_size=3, padding=1, stride=1,
411 | bias=False, dilation=1))
412 |
413 | for m in self.modules():
414 | if isinstance(m, nn.Conv2d):
415 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
416 | m.weight.data.normal_(0, math.sqrt(2. / n))
417 | elif isinstance(m, nn.Conv3d):
418 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
419 | m.weight.data.normal_(0, math.sqrt(2. / n))
420 | elif isinstance(m, nn.BatchNorm2d):
421 | m.weight.data.fill_(1)
422 | m.bias.data.zero_()
423 | elif isinstance(m, nn.BatchNorm3d):
424 | m.weight.data.fill_(1)
425 | m.bias.data.zero_()
426 | elif isinstance(m, nn.Linear):
427 | m.bias.data.zero_()
428 |
429 | def forward(self, L, R):
430 | featL = self.feature_extraction0(L)
431 | featR = self.feature_extraction0(R)
432 |
433 | xALL0 = self.volume1(featL, featR)
434 |
435 | cost0 = self.dres0(xALL0)
436 | cost0 = self.dres1(cost0) + cost0
437 | # out1 = self.encoder_decoder1(cost0)
438 | cost1 = self.classif1(cost0)
439 | cost1 = torch.unsqueeze(cost1, 1)
440 | cost1 = F.interpolate(cost1, [self.maxdisp, L.size()[2], L.size()[3]], mode='trilinear')
441 | cost1 = torch.squeeze(cost1, 1)
442 | pred = self.output_head(cost1)
443 |
444 | return pred
445 |
--------------------------------------------------------------------------------
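A shape-level smoke test for `ASNet_light`. This is a sketch, not part of the repo: it assumes the repo root is on `PYTHONPATH` and that the loaded config defines `model.max_depth_value`, which `disparity_regression` reads. The cost volume slides over `W/4` columns, so `W/4` must be at least the hard-coded volume size of 48, and the hourglass needs `H` and `W` divisible by 16:

```
import torch
from lib.network.asnet import ASNet_light

net = ASNet_light(maxdisp=48).eval()
L = torch.rand(1, 3, 64, 192)  # three stacked event frames per view
R = torch.rand(1, 3, 64, 192)
with torch.no_grad():
    pred, = net(L, R)          # eval mode returns a single disparity/depth map
print(pred.shape)              # torch.Size([1, 64, 192])
```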
/lib/network/asnet_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from lib.config import cfg
5 |
6 |
7 | ###############################################################################
8 | """ Fundamental Building Blocks """
9 | ###############################################################################
10 |
11 |
12 | def convbn(in_channels, out_channels, kernel_size, stride, pad, dilation):
13 | return nn.Sequential(
14 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
15 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False),
16 | nn.BatchNorm2d(out_channels)
17 | )
18 |
19 |
20 | def convbn_dws(inp, oup, kernel_size, stride, pad, dilation, second_relu=True):
21 | if second_relu:
22 | return nn.Sequential(
23 | # dw
24 | nn.Conv2d(inp, inp, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad,
25 | dilation=dilation, groups=inp, bias=False),
26 | nn.BatchNorm2d(inp),
27 | nn.ReLU6(inplace=True),
28 | # pw
29 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
30 | nn.BatchNorm2d(oup),
31 | nn.ReLU6(inplace=False)
32 | )
33 | else:
34 | return nn.Sequential(
35 | # dw
36 | nn.Conv2d(inp, inp, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad,
37 | dilation=dilation, groups=inp, bias=False),
38 | nn.BatchNorm2d(inp),
39 | nn.ReLU6(inplace=True),
40 | # pw
41 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
42 | nn.BatchNorm2d(oup)
43 | )
44 |
45 | class MobileV1_Residual(nn.Module):
46 | expansion = 1
47 |
48 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
49 | super(MobileV1_Residual, self).__init__()
50 |
51 | self.stride = stride
52 | self.downsample = downsample
53 | self.conv1 = convbn_dws(inplanes, planes, 3, stride, pad, dilation)
54 | self.conv2 = convbn_dws(planes, planes, 3, 1, pad, dilation, second_relu=False)
55 |
56 | def forward(self, x):
57 | out = self.conv1(x)
58 | out = self.conv2(out)
59 |
60 | if self.downsample is not None:
61 | x = self.downsample(x)
62 |
63 | out += x
64 |
65 | return out
66 |
67 |
68 |
69 | class MobileV2_Residual(nn.Module):
70 | def __init__(self, inp, oup, stride, expanse_ratio, dilation=1):
71 | super(MobileV2_Residual, self).__init__()
72 | self.stride = stride
73 | assert stride in [1, 2]
74 |
75 | hidden_dim = int(inp * expanse_ratio)
76 | self.use_res_connect = self.stride == 1 and inp == oup
77 | pad = dilation
78 |
79 | if expanse_ratio == 1:
80 | self.conv = nn.Sequential(
81 | # dw
82 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, pad, dilation=dilation, groups=hidden_dim, bias=False),
83 | nn.BatchNorm2d(hidden_dim),
84 | nn.ReLU6(inplace=True),
85 | # pw-linear
86 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
87 | nn.BatchNorm2d(oup),
88 | )
89 | else:
90 | self.conv = nn.Sequential(
91 | # pw
92 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
93 | nn.BatchNorm2d(hidden_dim),
94 | nn.ReLU6(inplace=True),
95 | # dw
96 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, pad, dilation=dilation, groups=hidden_dim, bias=False),
97 | nn.BatchNorm2d(hidden_dim),
98 | nn.ReLU6(inplace=True),
99 | # pw-linear
100 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
101 | nn.BatchNorm2d(oup),
102 | )
103 |
104 | def forward(self, x):
105 | if self.use_res_connect:
106 | return x + self.conv(x)
107 | else:
108 | return self.conv(x)
109 |
110 |
111 |
112 | class InsideBlockConv(nn.Module):
113 | def __init__(self, in_features, out_features):
114 | super(InsideBlockConv, self).__init__()
115 | self.double_conv = nn.Sequential(
116 | nn.Conv2d(in_features, out_features, kernel_size=3, padding=(1,1)), # For same padding: pad=1 for filter=3
117 | nn.BatchNorm2d(out_features),
118 |             nn.ReLU(inplace=True), # inplace=True avoids allocating additional memory; safe here since the input is not reused
119 | nn.Conv2d(out_features, out_features, kernel_size=3, padding=(1,1)),
120 | nn.BatchNorm2d(out_features),
121 | nn.ReLU(inplace=True)
122 | )
123 |
124 | def forward(self, x1):
125 | return self.double_conv(x1)
126 |
127 | ###############################################################################
128 | """ Feature Extraction """
129 | ###############################################################################
130 |
131 |
132 |
133 |
134 |
135 | class feature_extraction(nn.Module):
136 | def __init__(self):
137 | super(feature_extraction, self).__init__()
138 |
139 | self.expanse_ratio = 3
140 | self.inplanes = 32
141 |
142 | self.firstconv0 = nn.Sequential(MobileV2_Residual(1, 4, 2, self.expanse_ratio),
143 | nn.ReLU(inplace=True),
144 | MobileV2_Residual(4, 16, 1, self.expanse_ratio),
145 | nn.ReLU(inplace=True),
146 | MobileV2_Residual(16, 32, 1, self.expanse_ratio),
147 | nn.ReLU(inplace=True)
148 | )
149 | self.firstconv1 = nn.Sequential(MobileV2_Residual(1, 4, 2, self.expanse_ratio),
150 | nn.ReLU(inplace=True),
151 | MobileV2_Residual(4, 16, 1, self.expanse_ratio),
152 | nn.ReLU(inplace=True),
153 | MobileV2_Residual(16, 32, 1, self.expanse_ratio),
154 | nn.ReLU(inplace=True)
155 | )
156 | self.firstconv2 = nn.Sequential(MobileV2_Residual(1, 4, 2, self.expanse_ratio),
157 | nn.ReLU(inplace=True),
158 | MobileV2_Residual(4, 16, 1, self.expanse_ratio),
159 | nn.ReLU(inplace=True),
160 | MobileV2_Residual(16, 32, 1, self.expanse_ratio),
161 | nn.ReLU(inplace=True)
162 | )
163 |
164 |
165 | self.conv3d = nn.Sequential(nn.Conv3d(1, 1, kernel_size=(3, 5, 5), stride=[3, 1, 1], padding=[0, 2, 2]),
166 | nn.BatchNorm3d(1),
167 | nn.ReLU())
168 |
169 | self.layer1 = self._make_layer(MobileV1_Residual, 32, 3, 1, 1, 1)
170 |         self.layer2 = self._make_layer(MobileV1_Residual, 64, 16, 2, 1, 1)
171 | self.layer3 = self._make_layer(MobileV1_Residual, 128, 3, 1, 1, 1)
172 | self.layer4 = self._make_layer(MobileV1_Residual, 128, 3, 1, 1, 2)
173 |
174 | self.preconv11 = nn.Sequential(
175 | convbn(320, 256, 1, 1, 0, 1),
176 | nn.ReLU(inplace=True),
177 | convbn(256, 128, 1, 1, 0, 1),
178 | nn.ReLU(inplace=True),
179 | convbn(128, 64, 1, 1, 0, 1),
180 | nn.ReLU(inplace=True),
181 | nn.Conv2d(64, 32, 1, 1, 0, 1)
182 | )
183 |
184 |
185 |
186 | def _make_layer(self, block, planes, blocks, stride, pad, dilation):
187 | downsample = None
188 |
189 | if stride != 1 or self.inplanes != planes:
190 | downsample = nn.Sequential(
191 | nn.Conv2d(self.inplanes, planes,
192 | kernel_size=1, stride=stride, bias=False),
193 | nn.BatchNorm2d(planes),
194 | )
195 |
196 | layers = [block(self.inplanes, planes, stride, downsample, pad, dilation)]
197 | self.inplanes = planes
198 | for i in range(1, blocks):
199 | layers.append(block(self.inplanes, planes, 1, None, pad, dilation))
200 |
201 | return nn.Sequential(*layers)
202 |
203 | def forward(self, x):
204 |
205 | x0 = torch.unsqueeze(x[:,0,:,:], 1)
206 | x1 = torch.unsqueeze(x[:,1,:,:], 1)
207 | x2 = torch.unsqueeze(x[:,2,:,:], 1)
208 |
209 | x0 = self.firstconv0(x0)
210 | x1 = self.firstconv1(x1)
211 | x2 = self.firstconv2(x2)
212 |
213 | B, C, H, W = x0.shape
214 | interwoven_features = x0.new_zeros([B, 3 * C, H, W])
215 | xall = interweave_tensors3(interwoven_features, x0, x1, x2)
216 |
217 | xall = torch.unsqueeze(xall, 1)
218 | xall = self.conv3d(xall)
219 | xall = torch.squeeze(xall, 1)
220 |
221 |
222 |
223 |
224 |
225 | xall = self.layer1(xall)
226 | xall2 = self.layer2(xall)
227 | xall3 = self.layer3(xall2)
228 | xall4 = self.layer4(xall3)
229 |
230 | feature_volume = torch.cat((xall2, xall3, xall4), dim=1)
231 |
232 | xALL = self.preconv11(feature_volume)
233 |
234 |
235 |
236 | return xALL
237 |
238 |
239 |
240 |
241 | class volume_build(nn.Module):
242 | def __init__(self, volume_size):
243 | super(volume_build, self).__init__()
244 | self.num_groups = 1
245 | self.volume_size = volume_size
246 |
247 |
248 |
249 |
250 |
251 | self.volume11 = nn.Sequential(
252 | convbn(16, 1, 1, 1, 0, 1),
253 | nn.ReLU(inplace=True))
254 | self.conv3d = nn.Sequential(nn.Conv3d(1, 16, kernel_size=(8, 3, 3), stride=[8, 1, 1], padding=[0, 1, 1]),
255 | nn.BatchNorm3d(16),
256 | nn.ReLU(),
257 | nn.Conv3d(16, 32, kernel_size=(4, 3, 3), stride=[4, 1, 1], padding=[0, 1, 1]),
258 | nn.BatchNorm3d(32),
259 | nn.ReLU(),
260 | nn.Conv3d(32, 16, kernel_size=(2, 3, 3), stride=[2, 1, 1], padding=[0, 1, 1]),
261 | nn.BatchNorm3d(16),
262 | nn.ReLU())
263 |
264 | def forward(self, featL,featR):
265 |
266 |
267 |
268 |
269 | B, C, H, W = featL.shape
270 | volume = featL.new_zeros([B, self.num_groups, self.volume_size, H, W])
271 |
272 |
273 |
274 | interwoven_features = featL.new_zeros([B, 2 * C, H, W])
275 | for i in range(self.volume_size):
276 |
277 | if i > 0:
278 | x = interweave_tensors(interwoven_features, featL[:, :, :, :-i], featR[:, :, :, i:])
279 | x = torch.unsqueeze(x, 1)
280 | x = self.conv3d(x)
281 | x = torch.squeeze(x, 2)
282 | x = self.volume11(x)
283 | volume[:, :, i, :, i:] = x
284 | else:
285 | x = interweave_tensors(interwoven_features, featL, featR)
286 | x = torch.unsqueeze(x, 1)
287 | x = self.conv3d(x)
288 | x = torch.squeeze(x, 2)
289 | x = self.volume11(x)
290 | volume[:, :, i, :, :] = x
291 |
292 | volume = volume.contiguous()
293 | volume = torch.squeeze(volume, 1)
294 |
295 | return volume
296 |
297 |
298 |
299 |
300 |
301 | ##############################################################################
302 | """ Disparity Regression Function """
303 | ###############################################################################
304 |
305 |
306 | def disparity_regression(x, maxdisp):
307 | assert len(x.shape) == 4
308 | disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device)
309 | disp_values = disp_values.view(1, maxdisp, 1, 1)
310 | return torch.sum(x * disp_values, 1, keepdim=False) / maxdisp * cfg.model.max_depth_value
311 |
312 |
313 |
314 | def interweave_tensors(interwoven_features, refimg_fea, targetimg_fea):
315 | B, C, H, W = refimg_fea.shape
316 | interwoven_features = interwoven_features[:, :, :, 0:W]
317 | interwoven_features = interwoven_features*0
318 | interwoven_features[:,::2,:,:] = refimg_fea
319 | interwoven_features[:,1::2,:,:] = targetimg_fea
320 | interwoven_features = interwoven_features.contiguous()
321 | return interwoven_features
322 | def interweave_tensors3(interwoven_features, refimg_fea, targetimg_fea, targetimg2_fea):
323 | B, C, H, W = refimg_fea.shape
324 | interwoven_features = interwoven_features[:, :, :, 0:W]
325 | interwoven_features = interwoven_features*0
326 | interwoven_features[:,::3,:,:] = refimg_fea
327 | interwoven_features[:,1::3,:,:] = targetimg_fea
328 | interwoven_features[:,2::3,:,:] = targetimg2_fea
329 | interwoven_features = interwoven_features.contiguous()
330 | return interwoven_features
331 |
332 |
333 | ###############################################################################
334 | """ Loss Function """
335 | ###############################################################################
336 |
337 |
338 | def model_loss(disp_ests, disp_gt, mask):
339 | weights = [0.5, 0.5, 0.7, 1.0]
340 | all_losses = []
341 | for disp_est, weight in zip(disp_ests, weights):
342 | # all_losses.append(weight * F.smooth_l1_loss(disp_est[mask], disp_gt[mask], reduction='mean'))
343 | all_losses.append(weight * F.l1_loss(disp_est[mask], disp_gt[mask], reduction='mean'))
344 | return sum(all_losses)
345 |
346 | def model_loss_light(disp_ests, disp_gt, mask):
347 | weights = [0.5, 1.0]
348 | all_losses = []
349 | for disp_est, weight in zip(disp_ests, weights):
350 | all_losses.append(weight * F.l1_loss(disp_est[mask], disp_gt[mask], reduction='mean'))
351 | return sum(all_losses)
--------------------------------------------------------------------------------
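`interweave_tensors` and `interweave_tensors3` alternate the two (or three) views' feature maps channel by channel before the 3D convolution collapses them. A self-contained illustration of the strided-slice trick:

```
import torch

featL = torch.ones(1, 2, 4, 4)       # left features
featR = torch.zeros(1, 2, 4, 4)      # right features
woven = featL.new_zeros(1, 4, 4, 4)  # 2 * C channels

woven[:, ::2] = featL                # even channels <- left
woven[:, 1::2] = featR               # odd channels  <- right
print(woven[0, :, 0, 0])             # tensor([1., 0., 1., 0.])
```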
/lib/network/eventgaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from lib.config import cfg, args
4 | # from lib.network.asnet import ASNet
5 | from lib.network.recon_net import E2IM, E2DPT
6 | from lib.network.gsregressor import GSRegressor
7 |
8 | class EventGaussian(nn.Module):
9 | def __init__(self):
10 | super(EventGaussian, self).__init__()
11 | self.depth_estimator = E2DPT(num_input_channels=8)
12 | self.intensity_estimator = E2IM(num_input_channels=32+8)
13 | self.regressor = GSRegressor(input_dim=2 + 32 + 8)
14 | self.gt_depth = False
15 | self.us_mask = "net"
16 | # self.proj1 = nn.Sequential(nn.Conv2D())
17 |
18 | def forward(self, batch):
19 | leT = torch.cat([batch["leframe"], batch["left_voxel"]], dim=1)
20 | riT = torch.cat([batch["reframe"], batch["right_voxel"]], dim=1)
21 | b = leT.shape[0]
22 | inp = torch.cat([leT, riT], dim=0)
23 |
24 | #only available for debugging
25 | if not self.gt_depth:
26 | depths, masks, dfeats = self.depth_estimator.get_features(inp)
27 | depthL, depthR = depths[:b], depths[b:]
28 | masksL, masksR = masks[:b], masks[b:]
29 | dfeatsL, dfeatsR = dfeats[:b], dfeats[b:] #[b, 32, H, W]
30 | else: #debug only
31 | depthL, depthR = batch["ldepth"].unsqueeze(1), batch["rdepth"].unsqueeze(1)
32 |
33 | # only available for debugging
34 | if self.us_mask == "gt":
35 | maskL, maskR = batch["lmask"].unsqueeze(1), batch["rmask"].unsqueeze(1)
36 | elif self.us_mask == "net":
37 | maskL, maskR = masksL, masksR
38 | elif self.us_mask == "none":
39 |             maskL, maskR = torch.ones_like(depthL), torch.ones_like(depthR)  # ones_like already matches device/dtype
40 |
41 | depthL = depthL * maskL
42 | depthR = depthR * maskR
43 | #
44 | L_img_inp = torch.cat([dfeatsL ,leT], dim=1) #depthFeat, frame, voxel
45 | R_img_inp = torch.cat([dfeatsR ,riT], dim=1)
46 | img_inp = torch.cat([L_img_inp, R_img_inp], dim=0)
47 | img, ifeats = self.intensity_estimator.get_features(img_inp)
48 | imgL, imgR = torch.split(img, b, dim=0)
49 | ifeatL, ifeatR = torch.split(ifeats, b, dim=0)
50 | #
51 | imgL = imgL * maskL
52 | imgR = imgR * maskR
53 | # imgL, imgR = batch["lim"], batch["rim"]
54 | L_gs_inp = torch.cat([depthL, imgL, ifeatL, leT], dim=1) # depthL, imgL, ifeatL, leT
55 | R_gs_inp = torch.cat([depthR, imgR, ifeatR, riT], dim=1)
56 | gs_inp = torch.cat([L_gs_inp, R_gs_inp], dim=0)
57 | #
58 | rot, scale, opacity = self.regressor(gs_inp)
59 |
60 | return {
61 | "lview":{
62 | "depth":depthL,
63 | "mask":maskL,
64 | "pts_valid":maskL.squeeze().reshape(b, -1),
65 | "img": imgL,
66 | "rot":rot[:b],
67 | "scale":scale[:b],
68 | "opacity":opacity[:b]
69 | },
70 | "rview":{
71 | "depth":depthR,
72 | "mask":maskR,
73 | "pts_valid":maskR.squeeze().reshape(b, -1),
74 | "img": imgR,
75 | "rot":rot[b:],
76 | "scale":scale[b:],
77 | "opacity":opacity[b:]
78 | }
79 | }
80 |
81 |
82 |
--------------------------------------------------------------------------------
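`EventGaussian` pushes both views through each shared sub-network in a single pass by concatenating them along the batch axis and splitting the result, rather than calling the network twice. A self-contained sketch of that pattern with a stand-in backbone:

```
import torch
from torch import nn

backbone = nn.Conv2d(8, 16, 3, padding=1)  # stand-in for a shared sub-network
left = torch.randn(2, 8, 32, 32)           # b = 2
right = torch.randn(2, 8, 32, 32)

both = torch.cat([left, right], dim=0)     # [2b, 8, H, W], one forward pass
feats = backbone(both)
featL, featR = torch.split(feats, left.shape[0], dim=0)
```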
/lib/network/firenet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import torch
4 | from .unet import UNet, UNetRecurrent, UNetFire, UNetStatic
5 | from os.path import join
6 | from .submodules import ConvLSTM, ResidualBlock, ConvLayer, UpsampleConvLayer, TransposedConvLayer
7 |
8 |
9 | import logging
10 | import numpy as np
11 |
12 | class BaseModel(nn.Module):
13 | """
14 | Base class for all models
15 | """
16 | def __init__(self, config):
17 | super(BaseModel, self).__init__()
18 | self.config = config
19 | self.logger = logging.getLogger(self.__class__.__name__)
20 |
21 | def forward(self, *input):
22 | """
23 | Forward pass logic
24 |
25 | :return: Model output
26 | """
27 | raise NotImplementedError
28 |
29 | def summary(self):
30 | """
31 | Model summary
32 | """
33 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
34 | params = sum([np.prod(p.size()) for p in model_parameters])
35 | self.logger.info('Trainable parameters: {}'.format(params))
36 | self.logger.info(self)
37 |
38 | class BaseE2VID(BaseModel):
39 | def __init__(self, config):
40 | super().__init__(config)
41 |
42 | assert('num_bins' in config)
43 | self.num_bins = int(config['num_bins']) # number of bins in the voxel grid event tensor
44 |
45 | try:
46 | self.skip_type = str(config['skip_type'])
47 | except KeyError:
48 | self.skip_type = 'sum'
49 |
50 | try:
51 | self.num_encoders = int(config['num_encoders'])
52 | except KeyError:
53 | self.num_encoders = 4
54 |
55 | try:
56 | self.base_num_channels = int(config['base_num_channels'])
57 | except KeyError:
58 | self.base_num_channels = 32
59 |
60 | try:
61 | self.num_residual_blocks = int(config['num_residual_blocks'])
62 | except KeyError:
63 | self.num_residual_blocks = 2
64 |
65 | try:
66 | self.norm = str(config['norm'])
67 | except KeyError:
68 | self.norm = None
69 |
70 | try:
71 | self.use_upsample_conv = bool(config['use_upsample_conv'])
72 | except KeyError:
73 | self.use_upsample_conv = True
74 |
75 |
76 | class E2VID(BaseE2VID):
77 | def __init__(self, config):
78 | super(E2VID, self).__init__(config)
79 |
80 | self.unet = UNet(num_input_channels=self.num_bins,
81 | num_output_channels=1,
82 | skip_type=self.skip_type,
83 | activation='sigmoid',
84 | num_encoders=self.num_encoders,
85 | base_num_channels=self.base_num_channels,
86 | num_residual_blocks=self.num_residual_blocks,
87 | norm=self.norm,
88 | use_upsample_conv=self.use_upsample_conv)
89 |
90 | def forward(self, event_tensor, prev_states=None):
91 | """
92 | :param event_tensor: N x num_bins x H x W
93 | :return: a predicted image of size N x 1 x H x W, taking values in [0,1].
94 | """
95 | return self.unet.forward(event_tensor), None
96 |
97 |
98 | class E2VIDRecurrent(BaseE2VID):
99 | """
100 | Recurrent, UNet-like architecture where each encoder is followed by a ConvLSTM or ConvGRU.
101 | """
102 |
103 | def __init__(self, config):
104 | super(E2VIDRecurrent, self).__init__(config)
105 |
106 | try:
107 | self.recurrent_block_type = str(config['recurrent_block_type'])
108 | except KeyError:
109 | self.recurrent_block_type = 'convlstm' # or 'convgru'
110 |
111 | self.unetrecurrent = UNetRecurrent(num_input_channels=self.num_bins,
112 | num_output_channels=1,
113 | skip_type=self.skip_type,
114 | recurrent_block_type=self.recurrent_block_type,
115 | activation='sigmoid',
116 | num_encoders=self.num_encoders,
117 | base_num_channels=self.base_num_channels,
118 | num_residual_blocks=self.num_residual_blocks,
119 | norm=self.norm,
120 | use_upsample_conv=self.use_upsample_conv)
121 |
122 | def forward(self, event_tensor, prev_states):
123 | """
124 | :param event_tensor: N x num_bins x H x W
125 | :param prev_states: previous ConvLSTM state for each encoder module
126 | :return: reconstructed image, taking values in [0,1].
127 | """
128 | img_pred, states = self.unetrecurrent.forward(event_tensor, prev_states)
129 | return img_pred, states
130 |
131 |
132 | class FireNet(BaseE2VID):
133 | """
134 |     Model from the paper: "Fast Image Reconstruction with an Event Camera", Scheerlinck et al., 2019.
135 |     The model is essentially a lighter version of E2VID: it runs ~2-3x faster and has considerably fewer parameters (~200x fewer).
136 | However, the reconstructions are not as high quality as E2VID: they suffer from smearing artefacts, and initialization takes longer.
137 | """
138 | def __init__(self, config):
139 | super().__init__(config)
140 | self.recurrent_block_type = str(config.get('recurrent_block_type', 'convgru'))
141 | kernel_size = config.get('kernel_size', 3)
142 | recurrent_blocks = config.get('recurrent_blocks', {'resblock': [0]})
143 | self.net = UNetFire(self.num_bins,
144 | num_output_channels=1,
145 | skip_type=self.skip_type,
146 | recurrent_block_type=self.recurrent_block_type,
147 | base_num_channels=self.base_num_channels,
148 | num_residual_blocks=self.num_residual_blocks,
149 | norm=self.norm,
150 | kernel_size=kernel_size,
151 | recurrent_blocks=recurrent_blocks)
152 |
153 | def forward(self, event_tensor, prev_states):
154 | img, states = self.net.forward(event_tensor, prev_states)
155 | return img, states
156 |
157 |
158 | class FireNet_static(BaseE2VID):
159 | """
160 |     Model from the paper: "Fast Image Reconstruction with an Event Camera", Scheerlinck et al., 2019.
161 |     The model is essentially a lighter version of E2VID: it runs ~2-3x faster and has considerably fewer parameters (~200x fewer).
162 | However, the reconstructions are not as high quality as E2VID: they suffer from smearing artefacts, and initialization takes longer.
163 | """
164 | def __init__(self, config):
165 | super().__init__(config)
166 | self.recurrent_block_type = str(config.get('recurrent_block_type', 'convgru'))
167 | kernel_size = config.get('kernel_size', 3)
168 | recurrent_blocks = config.get('recurrent_blocks', {'resblock': [0]})
169 | self.net = UNetStatic(self.num_bins,
170 | num_output_channels=1,
171 | skip_type=self.skip_type,
172 | recurrent_block_type=self.recurrent_block_type,
173 | base_num_channels=self.base_num_channels,
174 | num_residual_blocks=self.num_residual_blocks,
175 | norm=self.norm,
176 | kernel_size=kernel_size,
177 | recurrent_blocks=recurrent_blocks)
178 |
179 | def forward(self, event_tensor, placeholder=None):
180 | img = self.net.forward(event_tensor)
181 | return img, placeholder
--------------------------------------------------------------------------------
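Only `num_bins` is mandatory in the config dict; every other key falls back to the defaults read in `BaseE2VID`. A construction sketch (the exact `UNetStatic` signature lives in `lib/network/unet.py`, which is outside this file):

```
import torch
from lib.network.firenet import FireNet_static

config = {'num_bins': 5}             # remaining keys use the defaults above
net = FireNet_static(config)
voxel = torch.randn(1, 5, 128, 128)  # N x num_bins x H x W event tensor
img, _ = net(voxel)                  # reconstructed intensity image
```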
/lib/network/gsregressor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class ResidualBlock(nn.Module):
6 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
7 | super(ResidualBlock, self).__init__()
8 |
9 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
10 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | num_groups = planes // 8
14 |
15 | if norm_fn == 'group':
16 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
17 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | if not (stride == 1 and in_planes == planes):
19 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
20 |
21 | elif norm_fn == 'batch':
22 | self.norm1 = nn.BatchNorm2d(planes)
23 | self.norm2 = nn.BatchNorm2d(planes)
24 | if not (stride == 1 and in_planes == planes):
25 | self.norm3 = nn.BatchNorm2d(planes)
26 |
27 | elif norm_fn == 'instance':
28 | self.norm1 = nn.InstanceNorm2d(planes)
29 | self.norm2 = nn.InstanceNorm2d(planes)
30 | if not (stride == 1 and in_planes == planes):
31 | self.norm3 = nn.InstanceNorm2d(planes)
32 |
33 | elif norm_fn == 'none':
34 | self.norm1 = nn.Sequential()
35 | self.norm2 = nn.Sequential()
36 | if not (stride == 1 and in_planes == planes):
37 | self.norm3 = nn.Sequential()
38 |
39 | if stride == 1 and in_planes == planes:
40 | self.downsample = None
41 |
42 | else:
43 | self.downsample = nn.Sequential(
44 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
45 |
46 |
47 | def forward(self, x):
48 | y = x
49 | y = self.conv1(y)
50 | y = self.norm1(y)
51 | y = self.relu(y)
52 | y = self.conv2(y)
53 | y = self.norm2(y)
54 | y = self.relu(y)
55 |
56 | if self.downsample is not None:
57 | x = self.downsample(x)
58 |
59 | return self.relu(x+y)
60 |
61 | class GSRegressor(nn.Module):
62 | def __init__(self, input_dim=1+1+8, hidden_dim = 256, norm_fn='group'):
63 | super().__init__()
64 | self.embedding = nn.Conv2d(input_dim, hidden_dim, kernel_size=1, stride=1)
65 |
66 | self.res1 = ResidualBlock(hidden_dim, hidden_dim // 4, norm_fn=norm_fn)
67 |
68 | self.rot_head = nn.Sequential(
69 | nn.Conv2d(hidden_dim // 4, hidden_dim // 4, kernel_size=3, padding=1),
70 | nn.ReLU(inplace=True),
71 | nn.Conv2d(hidden_dim // 4, 4, kernel_size=1),
72 | )
73 | self.scale_head = nn.Sequential(
74 | nn.Conv2d(hidden_dim // 4, hidden_dim // 4, kernel_size=3, padding=1),
75 | nn.ReLU(inplace=True),
76 | nn.Conv2d(hidden_dim // 4, 3, kernel_size=1),
77 | nn.Softplus(beta=100)
78 | )
79 | self.opacity_head = nn.Sequential(
80 | nn.Conv2d(hidden_dim // 4, hidden_dim // 4, kernel_size=3, padding=1),
81 | nn.ReLU(inplace=True),
82 | nn.Conv2d(hidden_dim // 4, 1, kernel_size=1),
83 | nn.Sigmoid()
84 | )
85 |
86 | def forward(self, x):
87 |         """
88 |         x: channel-wise concatenation of the per-pixel inputs, e.g.
89 |            intensity [B,1,H,W], depth [B,1,H,W],
90 |            event frame / voxel channels and depth-net features,
91 |            input_dim channels in total.
92 |         """
93 | x = self.embedding(x)
94 | out = self.res1(x)
95 |
96 | rot_out = self.rot_head(out)
97 | rot_out = torch.nn.functional.normalize(rot_out, dim=1)
98 |
99 | # scale head
100 | scale_out = torch.clamp_max(self.scale_head(out), 0.001)
101 |
102 | # opacity head
103 | opacity_out = self.opacity_head(out)
104 |
105 | return rot_out, scale_out, opacity_out
106 |
107 |
108 |
--------------------------------------------------------------------------------
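A quick shape-and-range check for the regressor heads, using the default `input_dim` of 10 rather than the 42 channels `EventGaussian` feeds it:

```
import torch
from lib.network.gsregressor import GSRegressor

reg = GSRegressor(input_dim=10)
x = torch.randn(2, 10, 64, 64)
rot, scale, opacity = reg(x)
print(rot.shape, scale.shape, opacity.shape)  # [2,4,H,W], [2,3,H,W], [2,1,H,W]
print(rot.norm(dim=1).mean())                 # ~1: rotations are unit quaternions
print(scale.max() <= 1e-3)                    # scales clamped to at most 0.001
```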
/lib/network/neurons.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Callable
3 | import torch
4 | import torch.nn as nn
5 | from spikingjelly.clock_driven import surrogate, base
6 | import math
7 | try:
8 | import cupy
9 | from . import neuron_kernel, cu_kernel_opt
10 | except ImportError:
11 |     cupy = neuron_kernel = cu_kernel_opt = None
12 |
13 |
14 | class BaseNode(base.MemoryModule):
15 | def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
16 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
17 | assert isinstance(v_reset, float) or v_reset is None
18 | assert isinstance(v_threshold, float)
19 | assert isinstance(detach_reset, bool)
20 | super().__init__()
21 |
22 | if v_reset is None:
23 | self.register_memory('v', 0.)
24 | self.register_memory('spike', 0.)
25 | else:
26 | self.register_memory('v', v_reset)
27 | self.register_memory('spike', 0.)
28 |
29 | self.v_threshold = v_threshold
30 | self.v_reset = v_reset
31 |
32 | self.detach_reset = detach_reset
33 | self.surrogate_function = surrogate_function
34 |
35 | @abstractmethod
36 | def neuronal_charge(self, x: torch.Tensor):
37 | raise NotImplementedError
38 |
39 | def neuronal_fire(self):
40 | self.spike = self.surrogate_function(self.v - self.v_threshold)
41 |
42 | def neuronal_reset(self):
43 |
44 | if self.detach_reset:
45 | spike = self.spike.detach()
46 | else:
47 | spike = self.spike
48 |
49 | if self.v_reset is None:
50 | # soft reset
51 | self.v = self.v - spike * self.v_threshold
52 |
53 | else:
54 | # hard reset
55 | self.v = (1. - spike) * self.v + spike * self.v_reset
56 |
57 | def extra_repr(self):
58 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'
59 |
60 | def forward(self, x: torch.Tensor):
61 |
62 | self.neuronal_charge(x)
63 | self.neuronal_fire()
64 | self.neuronal_reset()
65 | return self.spike
66 |
67 | class BaseNode_adaspike(base.MemoryModule):
68 | def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
69 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
70 |
71 | assert isinstance(v_reset, float) or v_reset is None
72 | assert isinstance(v_threshold, float)
73 | assert isinstance(detach_reset, bool)
74 | super().__init__()
75 |
76 | if v_reset is None:
77 | self.register_memory('v', 0.)
78 | self.register_memory('spike', 0.)
79 | else:
80 | self.register_memory('v', v_reset)
81 | self.register_memory('spike', 0.)
82 |
83 | self.v_threshold = v_threshold
84 | self.v_reset = v_reset
85 |
86 | self.detach_reset = detach_reset
87 | self.surrogate_function = surrogate_function
88 |
89 | @abstractmethod
90 |     def neuronal_charge(self, x: torch.Tensor, s: torch.Tensor):
91 |
92 | raise NotImplementedError
93 |
94 | def neuronal_fire(self):
95 |
96 | self.spike = self.surrogate_function(self.v - self.v_threshold)
97 |
98 | def neuronal_reset(self):
99 |
100 | if self.detach_reset:
101 | spike = self.spike.detach()
102 | else:
103 | spike = self.spike
104 |
105 | if self.v_reset is None:
106 | # soft reset
107 | self.v = self.v - spike * self.v_threshold
108 |
109 | else:
110 | # hard reset
111 | self.v = (1. - spike) * self.v + spike * self.v_reset
112 |
113 | def extra_repr(self):
114 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'
115 |
116 | def forward(self, x: torch.Tensor, s: torch.Tensor):
117 |
118 | self.neuronal_charge(x, s)
119 | self.neuronal_fire()
120 | self.neuronal_reset()
121 | return self.spike
122 |
123 | class MpNode(base.MemoryModule):
124 | def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
125 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
126 |
127 | assert isinstance(v_reset, float) or v_reset is None
128 | assert isinstance(v_threshold, float)
129 | assert isinstance(detach_reset, bool)
130 | super().__init__()
131 |
132 |
133 | if v_reset is None:
134 | self.register_memory('v', 0.)
135 | self.register_memory('spike', 0.)
136 | else:
137 | self.register_memory('v', v_reset)
138 | self.register_memory('spike', 0.)
139 |
140 |
141 | self.v_threshold = v_threshold
142 | self.v_reset = v_reset
143 |
144 | self.detach_reset = detach_reset
145 | self.surrogate_function = surrogate_function
146 |
147 | @abstractmethod
148 | def neuronal_charge(self, x: torch.Tensor):
149 | raise NotImplementedError
150 |
151 | def neuronal_fire(self):
152 | self.spike = self.surrogate_function(self.v - self.v_threshold)
153 |
154 | def neuronal_reset(self):
155 | if self.detach_reset:
156 | spike = self.spike.detach()
157 | else:
158 | spike = self.spike
159 |
160 | if self.v_reset is None:
161 | # soft reset
162 | self.v = self.v - spike * self.v_threshold
163 |
164 | else:
165 | # hard reset
166 | self.v = (1. - spike) * self.v + spike * self.v_reset
167 |
168 | def extra_repr(self):
169 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'
170 |
171 | def forward(self, x: torch.Tensor, last_mem: torch.Tensor):
172 | if last_mem is None:
173 | self.neuronal_charge(x)
174 | else:
175 |             self.v = last_mem  # restore membrane potential from the previous step
176 | self.neuronal_charge(x)
177 | return self.v
178 |
179 | class Ada_MpNode(base.MemoryModule):
180 | def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
181 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
182 |
183 | assert isinstance(v_reset, float) or v_reset is None
184 | assert isinstance(v_threshold, float)
185 | assert isinstance(detach_reset, bool)
186 | super().__init__()
187 |
188 |
189 | if v_reset is None:
190 | self.register_memory('v', 0.)
191 | self.register_memory('spike', 0.)
192 | else:
193 | self.register_memory('v', v_reset)
194 | self.register_memory('spike', 0.)
195 |
196 |
197 | self.v_threshold = v_threshold
198 | self.v_reset = v_reset
199 |
200 | self.detach_reset = detach_reset
201 | self.surrogate_function = surrogate_function
202 |
203 | @abstractmethod
204 |     def neuronal_charge(self, x: torch.Tensor, w: torch.Tensor):
205 | raise NotImplementedError
206 |
207 | def neuronal_fire(self):
208 | self.spike = self.surrogate_function(self.v - self.v_threshold)
209 |
210 | def neuronal_reset(self):
211 | if self.detach_reset:
212 | spike = self.spike.detach()
213 | else:
214 | spike = self.spike
215 |
216 | if self.v_reset is None:
217 | # soft reset
218 | self.v = self.v - spike * self.v_threshold
219 |
220 | else:
221 | # hard reset
222 | self.v = (1. - spike) * self.v + spike * self.v_reset
223 |
224 | def extra_repr(self):
225 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'
226 |
227 | def forward(self, x: torch.Tensor, last_mem: torch.Tensor, w: torch.Tensor):
228 | if last_mem is None:
229 | self.neuronal_charge(x, w)
230 | else:
231 |             self.v = last_mem  # restore membrane potential from the previous step
232 | self.neuronal_charge(x, w)
233 |
234 | return self.v
235 |
236 | class Ada_MpNode_adaspike(base.MemoryModule):
237 | def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
238 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
239 | assert isinstance(v_reset, float) or v_reset is None
240 | assert isinstance(v_threshold, float)
241 | assert isinstance(detach_reset, bool)
242 | super().__init__()
243 |
244 |
245 | if v_reset is None:
246 | self.register_memory('v', 0.)
247 | self.register_memory('spike', 0.)
248 | else:
249 | self.register_memory('v', v_reset)
250 | self.register_memory('spike', 0.)
251 |
252 |
253 | self.v_threshold = v_threshold
254 | self.v_reset = v_reset
255 |
256 | self.detach_reset = detach_reset
257 | self.surrogate_function = surrogate_function
258 |
259 | @abstractmethod
260 |     def neuronal_charge(self, x: torch.Tensor, w: torch.Tensor, s: torch.Tensor):
261 | raise NotImplementedError
262 |
263 | def neuronal_fire(self):
264 | self.spike = self.surrogate_function(self.v - self.v_threshold)
265 |
266 | def neuronal_reset(self):
267 | if self.detach_reset:
268 | spike = self.spike.detach()
269 | else:
270 | spike = self.spike
271 |
272 | if self.v_reset is None:
273 | # soft reset
274 | self.v = self.v - spike * self.v_threshold
275 |
276 | else:
277 | # hard reset
278 | self.v = (1. - spike) * self.v + spike * self.v_reset
279 |
280 | def extra_repr(self):
281 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'
282 |
283 | def forward(self, x: torch.Tensor, last_mem: torch.Tensor, w: torch.Tensor, s: torch.Tensor):
284 | if last_mem is None:
285 | self.neuronal_charge(x, w, s)
286 | else:
287 |             self.v = last_mem  # restore membrane potential from the previous step
288 | self.neuronal_charge(x, w, s)
289 | return self.v
290 |
291 | class Multi_Node(base.MemoryModule):
292 | def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
293 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
294 | assert isinstance(v_reset, float) or v_reset is None
295 | assert isinstance(v_threshold, float)
296 | assert isinstance(detach_reset, bool)
297 | super().__init__()
298 |
299 | if v_reset is None:
300 | self.register_memory('v', 0.)
301 | self.register_memory('spike', 0.)
302 | else:
303 | self.register_memory('v', v_reset)
304 | self.register_memory('spike', 0.)
305 |
306 | self.v_threshold = v_threshold
307 | self.v_reset = v_reset
308 |
309 | self.detach_reset = detach_reset
310 | self.surrogate_function = surrogate_function
311 |
312 | @abstractmethod
313 | def neuronal_charge(self, x: torch.Tensor):
314 | raise NotImplementedError
315 |
316 | def neuronal_fire(self):
317 | self.spike = self.surrogate_function(self.v - self.v_threshold)
318 |
319 | def neuronal_reset(self):
320 | if self.detach_reset:
321 | spike = self.spike.detach()
322 | else:
323 | spike = self.spike
324 |
325 | if self.v_reset is None:
326 | # soft reset
327 | self.v = self.v - spike * self.v_threshold
328 |
329 | else:
330 | # hard reset
331 | self.v = (1. - spike) * self.v + spike * self.v_reset
332 |
333 | def extra_repr(self):
334 | return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'
335 |
336 | def forward(self, x: torch.Tensor, last_mem: torch.Tensor):
337 | if last_mem is None:
338 | self.neuronal_charge(x)
339 | self.neuronal_fire()
340 | self.neuronal_reset()
341 | else:
342 |             self.v = last_mem  # restore membrane potential from the previous step
343 | self.neuronal_charge(x)
344 | self.neuronal_fire()
345 | self.neuronal_reset()
346 |
347 | return self.spike, self.v
348 |
349 | class MpLIFNode(MpNode):
350 | def __init__(self, tau: float = 2., v_threshold: float = 1.,
351 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
352 | detach_reset: bool = False):
353 | assert isinstance(tau, float) and tau > 1.
354 |
355 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
356 | self.tau = tau
357 |
358 | def extra_repr(self):
359 | return super().extra_repr() + f', tau={self.tau}'
360 |
361 | def neuronal_charge(self, x: torch.Tensor):
362 | if self.v_reset is None:
363 | self.v = self.v + (x - self.v) / self.tau
364 |
365 | else:
366 | if isinstance(self.v_reset, float) and self.v_reset == 0.:
367 | self.v = self.v + (x - self.v) / self.tau
368 | else:
369 | self.v = self.v + (x - (self.v - self.v_reset)) / self.tau
370 |
371 | class Mp_AdaLIFNode(Ada_MpNode):
372 | def __init__(self, tau: float = 2., v_threshold: float = 1.,
373 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
374 | detach_reset: bool = False):
375 |
376 |         assert isinstance(tau, float) and tau > 1.
377 |         super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
378 |         self.tau = tau
379 |
380 |
381 | def extra_repr(self):
382 | return super().extra_repr() + f', tau={self.tau}'
383 |
384 | def neuronal_charge(self, x: torch.Tensor, w: torch.Tensor):
385 |         tau = w.sigmoid()  # learned per-element decay factor (acts as 1/tau)
386 | if self.v_reset is None:
387 | self.v = self.v + (x - self.v) * tau
388 |
389 | else:
390 | if isinstance(self.v_reset, float) and self.v_reset == 0.:
391 | self.v = self.v + (x - self.v) * tau
392 | else:
393 | self.v = self.v + (x - (self.v - self.v_reset)) * tau
394 |
395 | class Mp_AdaLIFNode_adaspike(Ada_MpNode_adaspike):
396 | def __init__(self, tau: float = 2., v_threshold: float = 1.,
397 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
398 | detach_reset: bool = False):
399 |
400 |         assert isinstance(tau, float) and tau > 1.
401 |         super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
402 |         self.tau = tau
403 |
404 |
405 | def extra_repr(self):
406 | return super().extra_repr() + f', tau={self.tau}'
407 |
408 | def neuronal_charge(self, x: torch.Tensor, w: torch.Tensor, s: torch.Tensor):
409 |         tau = w.sigmoid()  # learned per-element decay factor (acts as 1/tau)
410 | if self.v_reset is None:
411 |             self.v = self.v + s * (x - self.v) * tau  # gate s is applied only in the soft-reset branch
412 |
413 | else:
414 | if isinstance(self.v_reset, float) and self.v_reset == 0.:
415 | self.v = self.v + (x - self.v) * tau
416 | else:
417 | self.v = self.v + (x - (self.v - self.v_reset)) * tau
418 |
419 | class MpIFNode(MpNode):
420 | def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
421 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
422 |
423 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
424 |
425 | def neuronal_charge(self, x: torch.Tensor):
426 | self.v = self.v + x
427 |
428 | class Mp_ParametricLIFNode(MpNode):
429 | def __init__(self, init_tau: float = 2.0, v_threshold: float = 1.,
430 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
431 | detach_reset: bool = False):
432 |
433 | assert isinstance(init_tau, float) and init_tau > 1.
434 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
435 | init_w = - math.log(init_tau - 1.)
436 | self.w = nn.Parameter(torch.as_tensor(init_w))
437 |
438 | def extra_repr(self):
439 | with torch.no_grad():
440 |             tau = 1. / self.w.sigmoid()  # report the effective time constant
441 | return super().extra_repr() + f', tau={tau}'
442 |
443 | def neuronal_charge(self, x: torch.Tensor):
444 | if self.v_reset is None:
445 | self.v = self.v + (x - self.v) * self.w.sigmoid()
446 | else:
447 | if self.v_reset == 0.:
448 | self.v = self.v + (x - self.v) * self.w.sigmoid()
449 | else:
450 | self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid()
451 |
452 | class Mp_ParametricLIFNode_modify(MpNode):
453 | def __init__(self, size_h, size_w, init_tau: float = 2.0, v_threshold: float = 1.,
454 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
455 | detach_reset: bool = False):
456 | assert isinstance(init_tau, float) and init_tau > 1.
457 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
458 | init_w = - math.log(init_tau - 1.)
459 |         self.w = nn.Parameter(torch.ones(size=[size_w, size_h]) * init_w)
460 |         # per-pixel learnable decay; the [size_w, size_h] shape must broadcast against the input's trailing spatial dims
461 |
462 | def extra_repr(self):
463 | with torch.no_grad():
464 |             tau = 1. / self.w.sigmoid()  # report the effective time constant
465 | return super().extra_repr() + f', tau={tau}'
466 |
467 | def neuronal_charge(self, x: torch.Tensor):
468 | if self.v_reset is None:
469 | self.v = self.v + (x - self.v) * self.w.sigmoid()
470 | else:
471 | if self.v_reset == 0.:
472 | self.v = self.v + (x - self.v) * self.w.sigmoid()
473 | else:
474 | self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid()
475 |
476 | class IFNode(BaseNode):
477 | def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
478 | surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
479 |
480 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
481 |
482 | def neuronal_charge(self, x: torch.Tensor):
483 | self.v = self.v + x
484 |
485 | class LIFNode(BaseNode):
486 | def __init__(self, tau: float = 2., v_threshold: float = 1.,
487 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
488 | detach_reset: bool = False):
489 |
490 | assert isinstance(tau, float) and tau > 1.
491 |
492 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
493 | self.tau = tau
494 |
495 | def extra_repr(self):
496 | return super().extra_repr() + f', tau={self.tau}'
497 |
498 | def neuronal_charge(self, x: torch.Tensor):
499 | if self.v_reset is None:
500 | self.v = self.v + (x - self.v) / self.tau
501 |
502 | else:
503 | if isinstance(self.v_reset, float) and self.v_reset == 0.:
504 | self.v = self.v + (x - self.v) / self.tau
505 | else:
506 | self.v = self.v + (x - (self.v - self.v_reset)) / self.tau
507 |
508 | class LIFNode_adaspike(BaseNode_adaspike):
509 | def __init__(self, tau: float = 2., v_threshold: float = 1.,
510 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
511 | detach_reset: bool = False):
512 |
513 | assert isinstance(tau, float) and tau > 1.
514 |
515 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
516 | self.tau = tau
517 |
518 | def extra_repr(self):
519 | return super().extra_repr() + f', tau={self.tau}'
520 |
521 | def neuronal_charge(self, x: torch.Tensor, s: torch.Tensor):
522 | if self.v_reset is None:
523 |             self.v = self.v + s * (x - self.v) / self.tau  # gate s is applied only in the soft-reset branch
524 |
525 | else:
526 | if isinstance(self.v_reset, float) and self.v_reset == 0.:
527 | self.v = self.v + (x - self.v) / self.tau
528 | else:
529 | self.v = self.v + (x - (self.v - self.v_reset)) / self.tau
530 |
531 |
532 | class ParametricLIFNode(BaseNode):
533 | def __init__(self, init_tau: float = 2.0, v_threshold: float = 1.,
534 | v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
535 | detach_reset: bool = False):
536 |
537 | assert isinstance(init_tau, float) and init_tau > 1.
538 | super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
539 | init_w = - math.log(init_tau - 1.)
540 | self.w = nn.Parameter(torch.as_tensor(init_w))
541 |
542 | def extra_repr(self):
543 | with torch.no_grad():
544 |             tau = 1. / self.w.sigmoid()  # report the effective time constant
545 | return super().extra_repr() + f', tau={tau}'
546 |
547 | def neuronal_charge(self, x: torch.Tensor):
548 | if self.v_reset is None:
549 | self.v = self.v + (x - self.v) * self.w.sigmoid()
550 | else:
551 | if self.v_reset == 0.:
552 | self.v = self.v + (x - self.v) * self.w.sigmoid()
553 | else:
554 | self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid()
--------------------------------------------------------------------------------
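The charge/fire/reset decomposition above reduces to a few tensor ops per step. A self-contained sketch of the LIF update that `LIFNode` implements (a hard Heaviside spike stands in for the sigmoid surrogate, which only differs in the backward pass):

```
import torch

def lif_step(v, x, tau=2.0, v_threshold=1.0, v_reset=0.0):
    # charge: leak towards the input (the v_reset == 0 branch above)
    v = v + (x - v) / tau
    # fire: hard threshold; the module uses a sigmoid surrogate for gradients
    spike = (v >= v_threshold).float()
    # reset
    if v_reset is None:
        v = v - spike * v_threshold              # soft reset
    else:
        v = (1. - spike) * v + spike * v_reset   # hard reset
    return spike, v

v = torch.zeros(4)
for t in range(5):
    spike, v = lif_step(v, torch.full((4,), 0.8))
    print(t, spike.tolist(), [round(u, 3) for u in v.tolist()])
```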
/lib/network/recon_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .unet import UNet
4 | from lib.config import cfg, args
5 | from .submodules import ConvLayer
6 |
7 | class E2IM(nn.Module):
8 | def __init__(self, num_input_channels=6,
9 | num_output_channels=1,
10 | skip_type="sum",
11 | activation='sigmoid',
12 | num_encoders=4,
13 | base_num_channels=32,
14 | num_residual_blocks=2,
15 | norm="BN",
16 | use_upsample_conv=True):
17 | super(E2IM, self).__init__()
18 |
19 | self.unet = UNet(num_input_channels=num_input_channels,
20 | num_output_channels=num_output_channels,
21 | skip_type=skip_type,
22 | activation=activation,
23 | num_encoders=num_encoders,
24 | base_num_channels=base_num_channels,
25 | num_residual_blocks=num_residual_blocks,
26 | norm=norm,
27 | use_upsample_conv=use_upsample_conv)
28 |
29 | def forward(self, event_tensor):
30 | """
31 | :param event_tensor: N x num_bins x H x W
32 | :return: a predicted image of size N x 1 x H x W, taking values in [0,1].
33 | """
34 | return self.unet.forward(event_tensor)
35 |
36 | def get_features(self, event_tensor):
37 | img, feat = self.unet.get_features(event_tensor)
38 | return img, feat
39 |
40 |
41 | class E2DPT(nn.Module):
42 | def __init__(self, num_input_channels=6,
43 | num_output_channels=1,
44 | skip_type="sum",
45 | activation='sigmoid',
46 | num_encoders=4,
47 | base_num_channels=32,
48 | num_residual_blocks=2,
49 | norm="BN",
50 | use_upsample_conv=True):
51 | super(E2DPT, self).__init__()
52 |
53 | self.unet = UNet(num_input_channels=num_input_channels,
54 | num_output_channels=num_output_channels,
55 | skip_type=skip_type,
56 | activation=activation,
57 | num_encoders=num_encoders,
58 | base_num_channels=base_num_channels,
59 | num_residual_blocks=num_residual_blocks,
60 | norm=norm,
61 | use_upsample_conv=use_upsample_conv)
62 |
63 | self.mask_head = nn.Sequential(ConvLayer(base_num_channels, base_num_channels // 2, kernel_size=3, padding=1, norm=norm),
64 | ConvLayer(base_num_channels // 2, 1,kernel_size=1, activation=activation, norm=norm))
65 |
66 |
67 | def forward(self, event_tensor):
68 | """
69 | :param event_tensor: N x num_bins x H x W
70 |         :return: predicted depth (N x 1 x H x W, scaled to [0, cfg.model.max_depth_value]) and a foreground mask (N x 1 x H x W).
71 | """
72 |
73 | depth, feat = self.unet.get_features(event_tensor)
74 | mask = self.mask_head(feat)
75 | return depth * cfg.model.max_depth_value, mask
76 |
77 | def get_features(self, event_tensor):
78 | depth, feat = self.unet.get_features(event_tensor)
79 | mask = self.mask_head(feat)
80 | depth = depth * cfg.model.max_depth_value
81 |
82 | return depth, mask, feat
83 |
84 |
85 | class E2Msk(nn.Module):
86 | def __init__(self, num_input_channels=6,
87 | num_output_channels=1,
88 | skip_type="sum",
89 | activation='sigmoid',
90 | num_encoders=4,
91 | base_num_channels=32,
92 | num_residual_blocks=2,
93 | norm="BN",
94 | use_upsample_conv=True):
95 | super(E2Msk, self).__init__()
96 |
97 | self.unet = UNet(num_input_channels=num_input_channels,
98 | num_output_channels=num_output_channels,
99 | skip_type=skip_type,
100 | activation=activation,
101 | num_encoders=num_encoders,
102 | base_num_channels=base_num_channels,
103 | num_residual_blocks=num_residual_blocks,
104 | norm=norm,
105 | use_upsample_conv=use_upsample_conv)
106 |
107 |
108 | def forward(self, event_tensor):
109 | """
110 | :param event_tensor: N x num_bins x H x W
111 |         :return: predicted depth of size N x 1 x H x W, scaled to [0, cfg.model.max_depth_value].
112 | """
113 | depth = self.unet.forward(event_tensor)
114 | return depth * cfg.model.max_depth_value
115 |
116 | def get_features(self, event_tensor):
117 | depth, feat = self.unet.get_features(event_tensor)
118 | depth = depth * cfg.model.max_depth_value
119 |
120 | return depth, feat
121 |
122 |
--------------------------------------------------------------------------------
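A usage sketch for the reconstruction heads (assumes `lib.config` imports cleanly, since this module reads `cfg` at import time; shapes are illustrative, and H, W must be divisible by 2**num_encoders = 16 for the skip connections to align):

```
import torch
from lib.network.recon_net import E2IM

events = torch.randn(1, 6, 256, 256)   # 6-bin event tensor

net = E2IM(num_input_channels=6, num_output_channels=1)
img = net(events)                      # [1, 1, 256, 256], sigmoid output in [0, 1]
img2, feat = net.get_features(events)  # feat: [1, 32, 256, 256] decoder features
```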
/lib/network/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch.nn as nn
4 | from torch.hub import load_state_dict_from_url
5 |
6 |
7 | model_urls = {
8 | 'resnet50': 'https://github.com/bubbliiiing/pspnet-pytorch/releases/download/v1.0/resnet50s-a75c83cf.pth',
9 | }
10 |
11 | def conv3x3(in_planes, out_planes, stride=1):
12 | "3x3 convolution with padding"
13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
14 | padding=1, bias=False)
15 |
16 | class Bottleneck(nn.Module):
17 | expansion = 4
18 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1,
19 | norm_layer=None):
20 | super(Bottleneck, self).__init__()
21 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
22 | self.bn1 = norm_layer(planes)
23 |
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
25 | self.bn2 = norm_layer(planes)
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
28 | self.bn3 = norm_layer(planes * 4)
29 |
30 | self.relu = nn.ReLU(inplace=True)
31 |
32 | self.downsample = downsample
33 | self.dilation = dilation
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | residual = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv3(out)
48 | out = self.bn3(out)
49 |
50 | if self.downsample is not None:
51 | residual = self.downsample(x)
52 |
53 | out += residual
54 | out = self.relu(out)
55 | return out
56 |
57 |
58 | class ResNet(nn.Module):
59 | def __init__(self, block, layers, num_classes=1000, dilated=False, deep_base=True, norm_layer=nn.BatchNorm2d):
60 | self.inplanes = 128 if deep_base else 64
61 | super(ResNet, self).__init__()
62 | if deep_base:
63 | self.conv1 = nn.Sequential(
64 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
65 | norm_layer(64),
66 | nn.ReLU(inplace=True),
67 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
68 | norm_layer(64),
69 | nn.ReLU(inplace=True),
70 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
71 | )
72 | else:
73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
74 | bias=False)
75 | self.bn1 = norm_layer(self.inplanes)
76 | self.relu = nn.ReLU(inplace=True)
77 |
78 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
79 |
80 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
81 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
82 | if dilated:
83 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
84 | dilation=2, norm_layer=norm_layer)
85 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
86 | dilation=4, norm_layer=norm_layer)
87 | else:
88 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
89 | norm_layer=norm_layer)
90 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
91 | norm_layer=norm_layer)
92 |
93 | self.avgpool = nn.AvgPool2d(7, stride=1)
94 | self.fc = nn.Linear(512 * block.expansion, num_classes)
95 |
96 | for m in self.modules():
97 | if isinstance(m, nn.Conv2d):
98 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
99 | m.weight.data.normal_(0, math.sqrt(2. / n))
100 | elif isinstance(m, norm_layer):
101 | m.weight.data.fill_(1)
102 | m.bias.data.zero_()
103 |
104 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, multi_grid=False):
105 | downsample = None
106 | if stride != 1 or self.inplanes != planes * block.expansion:
107 | downsample = nn.Sequential(
108 | nn.Conv2d(self.inplanes, planes * block.expansion,
109 | kernel_size=1, stride=stride, bias=False),
110 | norm_layer(planes * block.expansion),
111 | )
112 |
113 | layers = []
114 | multi_dilations = [4, 8, 16]
115 | if multi_grid:
116 | layers.append(block(self.inplanes, planes, stride, dilation=multi_dilations[0],
117 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
118 | elif dilation == 1 or dilation == 2:
119 | layers.append(block(self.inplanes, planes, stride, dilation=1,
120 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
121 | elif dilation == 4:
122 | layers.append(block(self.inplanes, planes, stride, dilation=2,
123 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
124 | else:
125 | raise RuntimeError("=> unknown dilation size: {}".format(dilation))
126 |
127 | self.inplanes = planes * block.expansion
128 | for i in range(1, blocks):
129 | if multi_grid:
130 | layers.append(block(self.inplanes, planes, dilation=multi_dilations[i],
131 | previous_dilation=dilation, norm_layer=norm_layer))
132 | else:
133 | layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation,
134 | norm_layer=norm_layer))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | x = self.conv1(x)
140 | x = self.bn1(x)
141 | x = self.relu(x)
142 | x = self.maxpool(x)
143 |
144 | x = self.layer1(x)
145 | x = self.layer2(x)
146 | x = self.layer3(x)
147 | x = self.layer4(x)
148 |
149 | x = self.avgpool(x)
150 | x = x.view(x.size(0), -1)
151 | x = self.fc(x)
152 |
153 | return x
154 |
155 | def resnet50(pretrained=False, **kwargs):
156 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
157 | if pretrained:
158 | model.load_state_dict(load_state_dict_from_url(model_urls['resnet50'], "./model_data"), strict=False)
159 | return model
--------------------------------------------------------------------------------
/lib/network/submodules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as f
4 | from torch.nn import init
5 |
6 |
7 | class ConvLayer(nn.Module):
8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
9 | super(ConvLayer, self).__init__()
10 |
11 | bias = False if norm == 'BN' else True
12 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
13 |         if activation == 'relu' or activation == 'ReLU':
14 |             self.activation = torch.relu
15 |         elif activation == 'sigmoid' or activation == 'Sigmoid':
16 |             self.activation = torch.sigmoid
17 |         else:
18 |             self.activation = None
19 |
20 | self.norm = norm
21 | if norm == 'BN':
22 | self.norm_layer = nn.BatchNorm2d(out_channels)
23 | elif norm == 'IN':
24 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
25 |
26 | def forward(self, x):
27 | out = self.conv2d(x)
28 |
29 | if self.norm in ['BN', 'IN']:
30 | out = self.norm_layer(out)
31 |
32 | if self.activation is not None:
33 | out = self.activation(out)
34 |
35 | return out
36 |
37 |
38 | class TransposedConvLayer(nn.Module):
39 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
40 | super(TransposedConvLayer, self).__init__()
41 |
42 | bias = False if norm == 'BN' else True
43 | self.transposed_conv2d = nn.ConvTranspose2d(
44 |             in_channels, out_channels, kernel_size, stride=2, padding=padding, output_padding=1, bias=bias)  # fixed 2x upsampling; the stride argument is unused
45 |
46 | if activation is not None:
47 |             self.activation = getattr(torch, activation, torch.relu)
48 | else:
49 | self.activation = None
50 |
51 | self.norm = norm
52 | if norm == 'BN':
53 | self.norm_layer = nn.BatchNorm2d(out_channels)
54 | elif norm == 'IN':
55 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
56 |
57 | def forward(self, x):
58 | out = self.transposed_conv2d(x)
59 |
60 | if self.norm in ['BN', 'IN']:
61 | out = self.norm_layer(out)
62 |
63 | if self.activation is not None:
64 | out = self.activation(out)
65 |
66 | return out
67 |
68 |
69 | class UpsampleConvLayer(nn.Module):
70 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
71 | super(UpsampleConvLayer, self).__init__()
72 |
73 | bias = False if norm == 'BN' else True
74 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
75 |
76 | if activation is not None:
77 |             self.activation = getattr(torch, activation, torch.relu)
78 | else:
79 | self.activation = None
80 |
81 | self.norm = norm
82 | if norm == 'BN':
83 | self.norm_layer = nn.BatchNorm2d(out_channels)
84 | elif norm == 'IN':
85 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
86 |
87 | def forward(self, x):
88 | x_upsampled = f.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
89 | out = self.conv2d(x_upsampled)
90 |
91 | if self.norm in ['BN', 'IN']:
92 | out = self.norm_layer(out)
93 |
94 | if self.activation is not None:
95 | out = self.activation(out)
96 |
97 | return out
98 |
99 |
100 | class RecurrentConvLayer(nn.Module):
101 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
102 | recurrent_block_type='convlstm', activation='relu', norm=None):
103 | super(RecurrentConvLayer, self).__init__()
104 |
105 | assert(recurrent_block_type in ['convlstm', 'convgru'])
106 | self.recurrent_block_type = recurrent_block_type
107 | if self.recurrent_block_type == 'convlstm':
108 | RecurrentBlock = ConvLSTM
109 | else:
110 | RecurrentBlock = ConvGRU
111 | self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm)
112 | self.recurrent_block = RecurrentBlock(input_size=out_channels, hidden_size=out_channels, kernel_size=3)
113 |
114 | def forward(self, x, prev_state):
115 | x = self.conv(x)
116 | state = self.recurrent_block(x, prev_state)
117 | x = state[0] if self.recurrent_block_type == 'convlstm' else state
118 | return x, state
119 |
120 |
121 | class DownsampleRecurrentConvLayer(nn.Module):
122 | def __init__(self, in_channels, out_channels, kernel_size=3, recurrent_block_type='convlstm', padding=0, activation='relu'):
123 | super(DownsampleRecurrentConvLayer, self).__init__()
124 |
125 |         self.activation = getattr(torch, activation, torch.relu)
126 |
127 | assert(recurrent_block_type in ['convlstm', 'convgru'])
128 | self.recurrent_block_type = recurrent_block_type
129 | if self.recurrent_block_type == 'convlstm':
130 | RecurrentBlock = ConvLSTM
131 | else:
132 | RecurrentBlock = ConvGRU
133 | self.recurrent_block = RecurrentBlock(input_size=in_channels, hidden_size=out_channels, kernel_size=kernel_size)
134 |
135 | def forward(self, x, prev_state):
136 | state = self.recurrent_block(x, prev_state)
137 | x = state[0] if self.recurrent_block_type == 'convlstm' else state
138 | x = f.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
139 | return self.activation(x), state
140 |
141 |
142 | # Residual block
143 | class ResidualBlock(nn.Module):
144 | def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm=None):
145 | super(ResidualBlock, self).__init__()
146 | bias = False if norm == 'BN' else True
147 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=bias)
148 | self.norm = norm
149 | if norm == 'BN':
150 | self.bn1 = nn.BatchNorm2d(out_channels)
151 | self.bn2 = nn.BatchNorm2d(out_channels)
152 | elif norm == 'IN':
153 | self.bn1 = nn.InstanceNorm2d(out_channels)
154 | self.bn2 = nn.InstanceNorm2d(out_channels)
155 |
156 | self.relu = nn.ReLU(inplace=True)
157 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
158 | self.downsample = downsample
159 |
160 | def forward(self, x):
161 | residual = x
162 | out = self.conv1(x)
163 | if self.norm in ['BN', 'IN']:
164 | out = self.bn1(out)
165 | out = self.relu(out)
166 | out = self.conv2(out)
167 | if self.norm in ['BN', 'IN']:
168 | out = self.bn2(out)
169 |
170 | if self.downsample:
171 | residual = self.downsample(x)
172 |
173 | out += residual
174 | out = self.relu(out)
175 | return out
176 |
177 |
178 | class ConvLSTM(nn.Module):
179 | """Adapted from: https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py """
180 |
181 | def __init__(self, input_size, hidden_size, kernel_size):
182 | super(ConvLSTM, self).__init__()
183 |
184 | self.input_size = input_size
185 | self.hidden_size = hidden_size
186 | pad = kernel_size // 2
187 |
188 |         # cache zero state tensors to avoid reallocating memory at every forward step
189 | self.zero_tensors = {}
190 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad)
191 |
192 | def forward(self, input_, prev_state=None):
193 |
194 | # get batch and spatial sizes
195 | batch_size = input_.data.size()[0]
196 | spatial_size = input_.data.size()[2:]
197 |
198 | # generate empty prev_state, if None is provided
199 | if prev_state is None:
200 |
201 | # create the zero tensor if it has not been created already
202 | state_size = tuple([batch_size, self.hidden_size] + list(spatial_size))
203 | if state_size not in self.zero_tensors:
204 | # allocate a tensor with size `spatial_size`, filled with zero (if it has not been allocated already)
205 | self.zero_tensors[state_size] = (
206 | torch.zeros(state_size).to(input_.device),
207 | torch.zeros(state_size).to(input_.device)
208 | )
209 |
210 |             prev_state = self.zero_tensors[state_size]
211 |
212 | prev_hidden, prev_cell = prev_state
213 |
214 | # data size is [batch, channel, height, width]
215 | stacked_inputs = torch.cat((input_, prev_hidden), 1)
216 | gates = self.Gates(stacked_inputs)
217 |
218 | # chunk across channel dimension
219 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
220 |
221 | # apply sigmoid non linearity
222 | in_gate = torch.sigmoid(in_gate)
223 | remember_gate = torch.sigmoid(remember_gate)
224 | out_gate = torch.sigmoid(out_gate)
225 |
226 | # apply tanh non linearity
227 | cell_gate = torch.tanh(cell_gate)
228 |
229 | # compute current cell and hidden state
230 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
231 | hidden = out_gate * torch.tanh(cell)
232 |
233 | return hidden, cell
234 |
235 |
236 | class ConvGRU(nn.Module):
237 | """
238 | Generate a convolutional GRU cell
239 | Adapted from: https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py
240 | """
241 |
242 | def __init__(self, input_size, hidden_size, kernel_size):
243 | super().__init__()
244 | padding = kernel_size // 2
245 | self.input_size = input_size
246 | self.hidden_size = hidden_size
247 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
248 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
249 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
250 |
251 | init.orthogonal_(self.reset_gate.weight)
252 | init.orthogonal_(self.update_gate.weight)
253 | init.orthogonal_(self.out_gate.weight)
254 | init.constant_(self.reset_gate.bias, 0.)
255 | init.constant_(self.update_gate.bias, 0.)
256 | init.constant_(self.out_gate.bias, 0.)
257 |
258 | def forward(self, input_, prev_state):
259 |
260 | # get batch and spatial sizes
261 | batch_size = input_.data.size()[0]
262 | spatial_size = input_.data.size()[2:]
263 |
264 | # generate empty prev_state, if None is provided
265 | if prev_state is None:
266 | state_size = [batch_size, self.hidden_size] + list(spatial_size)
267 | prev_state = torch.zeros(state_size).to(input_.device)
268 |
269 | # data size is [batch, channel, height, width]
270 | stacked_inputs = torch.cat([input_, prev_state], dim=1)
271 | update = torch.sigmoid(self.update_gate(stacked_inputs))
272 | reset = torch.sigmoid(self.reset_gate(stacked_inputs))
273 | out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
274 | new_state = prev_state * (1 - update) + out_inputs * update
275 |
276 | return new_state
277 |
278 | class RecurrentResidualLayer(nn.Module):
279 | def __init__(self, in_channels, out_channels,
280 | recurrent_block_type='convlstm', norm=None):
281 | super(RecurrentResidualLayer, self).__init__()
282 |
283 | assert(recurrent_block_type in ['convlstm', 'convgru'])
284 | self.recurrent_block_type = recurrent_block_type
285 | if self.recurrent_block_type == 'convlstm':
286 | RecurrentBlock = ConvLSTM
287 | else:
288 | RecurrentBlock = ConvGRU
289 | self.conv = ResidualBlock(in_channels=in_channels,
290 | out_channels=out_channels,
291 | norm=norm)
292 | self.recurrent_block = RecurrentBlock(input_size=out_channels,
293 | hidden_size=out_channels,
294 | kernel_size=3)
295 |
296 | def forward(self, x, prev_state):
297 | x = self.conv(x)
298 | state = self.recurrent_block(x, prev_state)
299 | x = state[0] if self.recurrent_block_type == 'convlstm' else state
300 | return x, state
301 |
302 |
--------------------------------------------------------------------------------
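How the recurrent cells are threaded through time: the state starts as `None` (zeros are created lazily inside the cell) and the returned state is fed back on the next step. A minimal `ConvGRU` sketch (shapes illustrative):

```
import torch
from lib.network.submodules import ConvGRU

cell = ConvGRU(input_size=16, hidden_size=32, kernel_size=3)

state = None                        # zero state is created on the first call
for t in range(4):
    frame = torch.randn(1, 16, 64, 64)
    state = cell(frame, state)      # new hidden state, [1, 32, 64, 64]

print(state.shape)
```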
/lib/network/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as f
4 | from torch.nn import init
5 | from .submodules import ConvLayer, UpsampleConvLayer, TransposedConvLayer, RecurrentConvLayer, ResidualBlock, ConvLSTM, ConvGRU, RecurrentResidualLayer
6 |
7 |
8 | def skip_concat(x1, x2):
9 | return torch.cat([x1, x2], dim=1)
10 |
11 |
12 | def skip_sum(x1, x2):
13 | return x1 + x2
14 |
15 |
16 | class BaseUNet(nn.Module):
17 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid',
18 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True):
19 | super(BaseUNet, self).__init__()
20 |
21 | self.num_input_channels = num_input_channels
22 | self.num_output_channels = num_output_channels
23 | self.skip_type = skip_type
24 | self.apply_skip_connection = skip_sum if self.skip_type == 'sum' else skip_concat
25 | self.activation = activation
26 | self.norm = norm
27 |
28 | if use_upsample_conv:
29 | print('Using UpsampleConvLayer (slow, but no checkerboard artefacts)')
30 | self.UpsampleLayer = UpsampleConvLayer
31 | else:
32 | print('Using TransposedConvLayer (fast, with checkerboard artefacts)')
33 | self.UpsampleLayer = TransposedConvLayer
34 |
35 | self.num_encoders = num_encoders
36 | self.base_num_channels = base_num_channels
37 | self.num_residual_blocks = num_residual_blocks
38 | self.max_num_channels = self.base_num_channels * pow(2, self.num_encoders)
39 |
40 | assert(self.num_input_channels > 0)
41 | assert(self.num_output_channels > 0)
42 |
43 | self.encoder_input_sizes = []
44 | for i in range(self.num_encoders):
45 | self.encoder_input_sizes.append(self.base_num_channels * pow(2, i))
46 |
47 | self.encoder_output_sizes = [self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)]
48 |
49 |         self.activation = getattr(torch, self.activation, torch.sigmoid)
50 |
51 | def build_resblocks(self):
52 | self.resblocks = nn.ModuleList()
53 | for i in range(self.num_residual_blocks):
54 | self.resblocks.append(ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm))
55 |
56 | def build_decoders(self):
57 | decoder_input_sizes = list(reversed([self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)]))
58 |
59 | self.decoders = nn.ModuleList()
60 | for input_size in decoder_input_sizes:
61 | self.decoders.append(self.UpsampleLayer(input_size if self.skip_type == 'sum' else 2 * input_size,
62 | input_size // 2,
63 | kernel_size=5, padding=2, norm=self.norm))
64 |
65 | def build_prediction_layer(self):
66 | self.pred = ConvLayer(self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels,
67 | self.num_output_channels, 1, activation=None, norm=self.norm)
68 |
69 |
70 | class UNet(BaseUNet):
71 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid',
72 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True):
73 | super(UNet, self).__init__(num_input_channels, num_output_channels, skip_type, activation,
74 | num_encoders, base_num_channels, num_residual_blocks, norm, use_upsample_conv)
75 |
76 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels,
77 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W
78 |
79 | self.encoders = nn.ModuleList()
80 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes):
81 | self.encoders.append(ConvLayer(input_size, output_size, kernel_size=5,
82 | stride=2, padding=2, norm=self.norm))
83 |
84 | self.build_resblocks()
85 | self.build_decoders()
86 | self.build_prediction_layer()
87 |
88 | def forward(self, x):
89 | """
90 | :param x: N x num_input_channels x H x W
91 | :return: N x num_output_channels x H x W
92 | """
93 |
94 | # head
95 | x = self.head(x)
96 | head = x
97 |
98 | # encoder
99 | blocks = []
100 | for i, encoder in enumerate(self.encoders):
101 | x = encoder(x)
102 | blocks.append(x)
103 |
104 | # residual blocks
105 | for resblock in self.resblocks:
106 | x = resblock(x)
107 |
108 | # decoder
109 | for i, decoder in enumerate(self.decoders):
110 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1]))
111 |
112 | img = self.activation(self.pred(self.apply_skip_connection(x, head)))
113 |
114 | return img
115 |
116 | def get_features(self, x):
117 | """
118 | :param x: N x num_input_channels x H x W
119 | :return: N x num_output_channels x H x W
120 | """
121 |
122 | # head
123 | x = self.head(x)
124 | head = x
125 |
126 | # encoder
127 | blocks = []
128 | for i, encoder in enumerate(self.encoders):
129 | x = encoder(x)
130 | blocks.append(x)
131 |
132 | # residual blocks
133 | for resblock in self.resblocks:
134 | x = resblock(x)
135 |
136 | mid_feat = x.clone()
137 |
138 | # decoder
139 | for i, decoder in enumerate(self.decoders):
140 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1]))
141 |
142 | final_feature = x.clone()
143 | img = self.activation(self.pred(self.apply_skip_connection(x, head)))
144 |
145 | return img, final_feature
146 |
147 |
148 | class UNetRecurrent(BaseUNet):
149 | """
150 | Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block,
151 | such as a ConvLSTM or a ConvGRU.
152 | Symmetric, skip connections on every encoding layer.
153 | """
154 |
155 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum',
156 | recurrent_block_type='convlstm', activation='sigmoid', num_encoders=4, base_num_channels=32,
157 | num_residual_blocks=2, norm=None, use_upsample_conv=True):
158 | super(UNetRecurrent, self).__init__(num_input_channels, num_output_channels, skip_type, activation,
159 | num_encoders, base_num_channels, num_residual_blocks, norm,
160 | use_upsample_conv)
161 |
162 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels,
163 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W
164 |
165 | self.encoders = nn.ModuleList()
166 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes):
167 | self.encoders.append(RecurrentConvLayer(input_size, output_size,
168 | kernel_size=5, stride=2, padding=2,
169 | recurrent_block_type=recurrent_block_type,
170 | norm=self.norm))
171 |
172 | self.build_resblocks()
173 | self.build_decoders()
174 | self.build_prediction_layer()
175 |
176 | def forward(self, x, prev_states):
177 | """
178 | :param x: N x num_input_channels x H x W
179 | :param prev_states: previous LSTM states for every encoder layer
180 | :return: N x num_output_channels x H x W
181 | """
182 |
183 | # head
184 | x = self.head(x)
185 | head = x
186 |
187 | if prev_states is None:
188 | prev_states = [None] * self.num_encoders
189 |
190 | # encoder
191 | blocks = []
192 | states = []
193 | for i, encoder in enumerate(self.encoders):
194 | x, state = encoder(x, prev_states[i])
195 | blocks.append(x)
196 | states.append(state)
197 |
198 | # residual blocks
199 | for resblock in self.resblocks:
200 | x = resblock(x)
201 |
202 | # decoder
203 | for i, decoder in enumerate(self.decoders):
204 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1]))
205 |
206 | # tail
207 | img = self.activation(self.pred(self.apply_skip_connection(x, head)))
208 |
209 | return img, states
210 |
211 |
212 | class UNetFire(BaseUNet):
213 |     """FireNet-style network without an encoder/decoder pyramid: a recurrent
214 |     head followed by (optionally recurrent) residual blocks and a 1x1 prediction layer."""
215 |
216 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum',
217 | recurrent_block_type='convgru', base_num_channels=16,
218 | num_residual_blocks=2, norm=None, kernel_size=3,
219 | recurrent_blocks={'resblock': [0]}):
220 | super(UNetFire, self).__init__(num_input_channels=num_input_channels,
221 | num_output_channels=num_output_channels,
222 | skip_type=skip_type,
223 | base_num_channels=base_num_channels,
224 | num_residual_blocks=num_residual_blocks,
225 | norm=norm)
226 |
227 | self.kernel_size = kernel_size
228 | self.recurrent_blocks = recurrent_blocks
229 | self.head = RecurrentConvLayer(self.num_input_channels,
230 | self.base_num_channels,
231 | kernel_size=self.kernel_size,
232 | padding=self.kernel_size // 2,
233 | recurrent_block_type=recurrent_block_type,
234 | norm=self.norm)
235 |
236 | self.num_recurrent_units = 1
237 | self.resblocks = nn.ModuleList()
238 | recurrent_indices = self.recurrent_blocks.get('resblock', [])
239 | for i in range(self.num_residual_blocks):
240 | if i in recurrent_indices or -1 in recurrent_indices:
241 | self.resblocks.append(RecurrentResidualLayer(
242 | in_channels=self.base_num_channels,
243 | out_channels=self.base_num_channels,
244 | recurrent_block_type=recurrent_block_type,
245 | norm=self.norm))
246 | self.num_recurrent_units += 1
247 | else:
248 | self.resblocks.append(ResidualBlock(self.base_num_channels,
249 | self.base_num_channels,
250 | norm=self.norm))
251 |
252 | self.pred = ConvLayer(2 * self.base_num_channels if self.skip_type == 'concat' else self.base_num_channels,
253 | self.num_output_channels, kernel_size=1, padding=0, activation=None, norm=None)
254 |
255 | def forward(self, x, prev_states):
256 | """
257 | :param x: N x num_input_channels x H x W
258 | :param prev_states: previous LSTM states for every encoder layer
259 | :return: N x num_output_channels x H x W
260 | """
261 |
262 | if prev_states is None:
263 | prev_states = [None] * (self.num_recurrent_units)
264 |
265 | states = []
266 | state_idx = 0
267 |
268 | # head
269 | x, state = self.head(x, prev_states[state_idx])
270 | state_idx += 1
271 | states.append(state)
272 |
273 | # residual blocks
274 | recurrent_indices = self.recurrent_blocks.get('resblock', [])
275 | for i, resblock in enumerate(self.resblocks):
276 | if i in recurrent_indices or -1 in recurrent_indices:
277 | x, state = resblock(x, prev_states[state_idx])
278 | state_idx += 1
279 | states.append(state)
280 | else:
281 | x = resblock(x)
282 |
283 | # tail
284 | img = self.pred(x)
285 | return img, states
286 |
287 |
288 |
289 | class UNetStatic(BaseUNet):
290 |     """Static (non-recurrent) counterpart of UNetFire: a convolutional head,
291 |     residual blocks, and a 1x1 ReLU prediction layer clamped to at most 1."""
292 |
293 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum',
294 | recurrent_block_type='convgru', base_num_channels=16,
295 | num_residual_blocks=2, norm=None, kernel_size=3,
296 | recurrent_blocks={'resblock': [0]}):
297 | super(UNetStatic, self).__init__(num_input_channels=num_input_channels,
298 | num_output_channels=num_output_channels,
299 | skip_type=skip_type,
300 | base_num_channels=base_num_channels,
301 | num_residual_blocks=num_residual_blocks,
302 | norm=norm)
303 |
304 | self.kernel_size = kernel_size
305 | self.recurrent_blocks = recurrent_blocks
306 | self.head = ConvLayer(self.num_input_channels,
307 | self.base_num_channels,
308 | kernel_size=self.kernel_size,
309 | padding=self.kernel_size // 2,
310 | norm=self.norm)
311 |
312 | self.num_recurrent_units = 1
313 | self.resblocks = nn.ModuleList()
314 |
315 | self.resblocks.append(ResidualBlock(self.base_num_channels,
316 | self.base_num_channels,
317 | norm=self.norm))
318 |
319 | self.pred = ConvLayer(2 * self.base_num_channels if self.skip_type == 'concat' else self.base_num_channels,
320 | self.num_output_channels, kernel_size=1, padding=0, activation='relu', norm=None)
321 |
322 | def forward(self, x):
323 | """
324 | :param x: N x num_input_channels x H x W
325 |         :return: N x num_output_channels x H x W,
326 |                  clamped to at most 1.0
327 | """
328 | # head
329 | x = self.head(x)
330 |
331 | # residual blocks
332 | for i, resblock in enumerate(self.resblocks):
333 | x = resblock(x)
334 |
335 | # tail
336 | img = self.pred(x)
337 | img = torch.clamp_max(img, 1.0)
338 | return img
--------------------------------------------------------------------------------
/lib/recorder.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import shutil
5 | import logging
6 | from pathlib import Path
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 |
10 | def file_backup(exp_path, cfg, train_script):
11 | shutil.copy(train_script, exp_path)
12 |     for src_dir in ['core', 'config', 'gaussian_renderer']:
13 |         if os.path.isdir(src_dir):  # only back up directories that exist in this repo
14 |             shutil.copytree(src_dir, os.path.join(exp_path, src_dir), dirs_exist_ok=True)
15 | for sub_dir in ['lib']:
16 | files = os.listdir(sub_dir)
17 | for file in files:
18 | Path(os.path.join(exp_path, sub_dir)).mkdir(exist_ok=True, parents=True)
19 | if file[-3:] == '.py':
20 | shutil.copy(os.path.join(sub_dir, file), os.path.join(exp_path, sub_dir))
21 |
22 | json_file_name = exp_path + '/cfg.json'
23 | with open(json_file_name, 'w') as json_file:
24 | json.dump(cfg, json_file, indent=2)
25 |
26 | class Logger:
27 | def __init__(self, scheduler, cfg):
28 | self.scheduler = scheduler
29 | self.sum_freq = cfg.loss_freq
30 | self.log_dir = cfg.logs_path
31 | self.total_steps = 0
32 | self.running_loss = {}
33 | self.writer = SummaryWriter(log_dir=self.log_dir)
34 |
35 | def _print_training_status(self):
36 | metrics_data = [self.running_loss[k] / self.sum_freq for k in sorted(self.running_loss.keys())]
37 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps, self.scheduler.get_last_lr()[0])
38 | metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
39 |
40 | # print the training status
41 | # print(" in print training status : ", f"steps : {self.total_steps}): {training_str + metrics_str}")
42 | # logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")
43 |
44 | if self.writer is None:
45 | self.writer = SummaryWriter(log_dir=self.log_dir)
46 |
47 | for k in self.running_loss:
48 | self.writer.add_scalar(k, self.running_loss[k] / self.sum_freq, self.total_steps)
49 | self.running_loss[k] = 0.0
50 |
51 | def push(self, metrics):
52 | for key in metrics:
53 | if key not in self.running_loss:
54 | self.running_loss[key] = 0.0
55 |
56 | self.running_loss[key] += metrics[key]
57 |
58 | if self.total_steps and self.total_steps % self.sum_freq == 0:
59 | self._print_training_status()
60 | self.running_loss = {}
61 |
62 | self.total_steps += 1
63 |
64 |
65 | def write_dict(self, results, write_step):
66 | if self.writer is None:
67 | self.writer = SummaryWriter(log_dir=self.log_dir)
68 |
69 | for key in results:
70 | self.writer.add_scalar(key, results[key], write_step)
71 |
72 | def close(self):
73 | self.writer.close()
--------------------------------------------------------------------------------
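A minimal driver sketch for `Logger` (the `SimpleNamespace` config is a stand-in for the repo's yacs `cfg`; any scheduler with `get_last_lr()` works):

```
import torch
from types import SimpleNamespace
from lib.recorder import Logger

cfg = SimpleNamespace(loss_freq=10, logs_path='./experiments/logs')

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=100)

logger = Logger(sched, cfg)
for step in range(100):
    logger.push({'l1': 0.1, 'psnr': 30.0})  # flushed to TensorBoard every loss_freq steps
logger.close()
```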
/lib/renderer/__init__.py:
--------------------------------------------------------------------------------
1 | from .gaussian_render import pts2render
2 | from .rend_utils import depth2pc
--------------------------------------------------------------------------------
/lib/renderer/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/renderer/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/renderer/__pycache__/gaussian_render.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/gaussian_render.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/renderer/__pycache__/gaussian_render.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/gaussian_render.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/renderer/__pycache__/rend_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mercerai/EvGGS/4c97c29da989a8555e04e2a0020b46ada8861665/lib/renderer/__pycache__/rend_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/renderer/gaussian_render.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
4 |
5 | def render(cam, idx, pts_xyz, pts_rgb, rotations, scales, opacity, bg_color):
6 | """
7 | Render the scene.
8 | Background tensor (bg_color) must be on GPU!
9 | """
10 |
11 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
12 | bg_color = torch.tensor(bg_color, dtype=torch.float32, device=pts_xyz.device)
13 | screenspace_points = torch.zeros_like(pts_xyz, dtype=torch.float32, requires_grad=True, device=pts_xyz.device) + 0
14 | try:
15 | screenspace_points.retain_grad()
16 |     except Exception:
17 |         pass  # retain_grad can fail when autograd is disabled (e.g. in eval)
18 |
19 | # Set up rasterization configuration
20 | tanfovx = math.tan(cam['FovX'][idx] * 0.5)
21 | tanfovy = math.tan(cam['FovY'][idx] * 0.5)
22 |
23 | raster_settings = GaussianRasterizationSettings(
24 | image_height=int(cam['H'][idx]),
25 | image_width=int(cam['W'][idx]),
26 | tanfovx=tanfovx,
27 | tanfovy=tanfovy,
28 | bg=bg_color,
29 | scale_modifier=1.0,
30 | viewmatrix=cam['world_view_transform'][idx],
31 | projmatrix=cam['full_proj_transform'][idx],
32 | sh_degree=3,
33 | campos=cam['camera_center'][idx],
34 | prefiltered=False,
35 | debug=True
36 | )
37 |
38 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
39 |
40 | rendered_image, _ = rasterizer(
41 | means3D=pts_xyz,
42 | means2D=screenspace_points,
43 | shs=None,
44 | colors_precomp=pts_rgb,
45 | opacities=opacity,
46 | scales=scales,
47 | rotations=rotations,
48 | cov3D_precomp=None)
49 | # print("render image shape : ", rendered_image.shape)
50 |
51 | return rendered_image
52 |
53 | # def pts2render(cam, pcd, gs_scaling, gs_opacity, gs_rotation, bg_color):
54 | # bs = pcd.shape[0]
55 | # # print(gs_data)
56 | # render_novel_list = []
57 | # for i in range(bs):
58 | # xyz_i = pcd[i, :3, :].permute(1, 0)
59 | # rgb_i = pcd[i, 3:6, :].permute(1, 0)
60 | # scale_i = gs_scaling[i].permute(1, 0)
61 | # opacity_i = gs_opacity[i].permute(1, 0)
62 | # rot_i = gs_rotation[i].permute(1, 0)
63 | # render_novel_i = render(cam, i, xyz_i, rgb_i, rot_i, scale_i, opacity_i, bg_color=bg_color)
64 | # render_novel_list.append(render_novel_i)
65 |
66 | # return torch.stack(render_novel_list, dim=0)
67 |
68 | def pts2render(data, bg_color):
69 | bs = data['lview']['img'].shape[0]
70 | render_novel_list = []
71 | for i in range(bs):
72 | xyz_i_valid = []
73 | rgb_i_valid = []
74 | rot_i_valid = []
75 | scale_i_valid = []
76 | opacity_i_valid = []
77 | for view in ['lview', 'rview']:
78 | valid_i = data[view]['pts_valid'][i, :].bool()
79 | xyz_i = data[view]['pts'][i, :, :]
80 | rgb_i = data[view]['img'][i, :, :, :].permute(1, 2, 0).view(-1, 1)
81 | rot_i = data[view]['rot'][i, :, :, :].permute(1, 2, 0).view(-1, 4)
82 | scale_i = data[view]['scale'][i, :, :, :].permute(1, 2, 0).view(-1, 3)
83 | opacity_i = data[view]['opacity'][i, :, :, :].permute(1, 2, 0).view(-1, 1)
84 |
85 | xyz_i_valid.append(xyz_i[valid_i].view(-1, 3))
86 | rgb_i_valid.append(rgb_i[valid_i].view(-1, 1))
87 | rot_i_valid.append(rot_i[valid_i].view(-1, 4))
88 | scale_i_valid.append(scale_i[valid_i].view(-1, 3))
89 | opacity_i_valid.append(opacity_i[valid_i].view(-1, 1))
90 |
91 | pts_xyz_i = torch.concat(xyz_i_valid, dim=0)
92 | pts_rgb_i = torch.concat(rgb_i_valid, dim=0).repeat((1,3))
93 | # pts_rgb_i = pts_rgb_i * 0.5 + 0.5
94 | rot_i = torch.concat(rot_i_valid, dim=0)
95 | scale_i = torch.concat(scale_i_valid, dim=0)
96 | opacity_i = torch.concat(opacity_i_valid, dim=0)
97 |
98 | render_novel_i = render(data["target"], i, pts_xyz_i, pts_rgb_i, rot_i, scale_i, opacity_i, bg_color=bg_color)
99 | render_novel_list.append(render_novel_i.unsqueeze(0))
100 |
101 | return torch.concat(render_novel_list, dim=0)
102 |
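103 | # Data layout consumed by pts2render, inferred from the indexing above (a
104 | # reading aid, not an authoritative spec):
105 | #   data['lview'] / data['rview']: per-pixel Gaussian parameters --
106 | #     'pts' [B, H*W, 3], 'pts_valid' [B, H*W], 'img' [B, 1, H, W],
107 | #     'rot' [B, 4, H, W], 'scale' [B, 3, H, W], 'opacity' [B, 1, H, W].
108 | #   data['target']: camera dict with 'H', 'W', 'FovX', 'FovY',
109 | #     'world_view_transform', 'full_proj_transform', 'camera_center'.
110 | #   The single intensity channel is tiled to RGB via .repeat((1, 3)) before
111 | #   rasterization, and the returned batch is [B, 3, H, W].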
--------------------------------------------------------------------------------
/lib/renderer/rend_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as Rot
5 | from scipy.spatial.transform import Slerp
6 |
7 |
8 | def focal2fov(focal, pixels):
9 | return 2 * math.atan(pixels / (2 * focal))
10 |
11 |
12 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
13 | Rt = np.zeros((4, 4))
14 | Rt[:3, :3] = R.transpose()
15 | Rt[:3, 3] = t
16 | Rt[3, 3] = 1.0
17 |
18 | C2W = np.linalg.inv(Rt)
19 | cam_center = C2W[:3, 3]
20 | cam_center = (cam_center + translate) * scale
21 | C2W[:3, 3] = cam_center
22 | Rt = np.linalg.inv(C2W)
23 | return np.float32(Rt)
24 |
25 |
26 | def getProjectionMatrix(znear, zfar, K, h, w):
27 | near_fx = znear / K[0, 0]
28 | near_fy = znear / K[1, 1]
29 | left = - (w - K[0, 2]) * near_fx
30 | right = K[0, 2] * near_fx
31 | bottom = (K[1, 2] - h) * near_fy
32 | top = K[1, 2] * near_fy
33 |
34 | P = torch.zeros(4, 4)
35 | z_sign = 1.0
36 | P[0, 0] = 2.0 * znear / (right - left)
37 | P[1, 1] = 2.0 * znear / (top - bottom)
38 | P[0, 2] = (right + left) / (right - left)
39 | P[1, 2] = (top + bottom) / (top - bottom)
40 | P[3, 2] = z_sign
41 | P[2, 2] = z_sign * zfar / (zfar - znear)
42 | P[2, 3] = -(zfar * znear) / (zfar - znear)
43 | return P
44 |
45 |
46 | def preprocess_render(batch):
47 | H, W = batch['H'][0], batch['W'][0]
48 | extrs = batch["cam_extrinsics"]
49 | intrs = batch["cam_intrinsics"]
50 | # znear, zfar = batch["znear"], batch["zfar"]
51 | znear, zfar = 0.5, 10000
52 |     B = extrs.shape[0]
53 |
54 | proj_mat = [getProjectionMatrix(znear, zfar, intrs[i], H, W).transpose(0, 1) for i in range(B)]
55 |     world_view_transform = [
56 |         torch.from_numpy(getWorld2View2(extrs[i][:3, :3].reshape(3, 3).transpose(1, 0), extrs[i][:3, 3])).transpose(0, 1)
57 |         for i in range(B)]  # getWorld2View2 returns numpy, so convert for the torch.stack below
58 |     proj_mat = torch.stack(proj_mat, dim=0)  # [B, 4, 4]
59 |     # print("proj mat = ", proj_mat)
60 | 
61 |     world_view_transform = torch.stack(world_view_transform, dim=0)  # [B, 4, 4]
62 |
63 | full_proj_transform = (world_view_transform.bmm(proj_mat))
64 | camera_center = world_view_transform.inverse()[:, 3, :3]
65 |
66 | FovX = [torch.FloatTensor([focal2fov(intrs[i][0, 0], W)]) for i in range(B)]
67 |
68 | # print("111",FovX[0])
69 | FovY = [torch.FloatTensor([focal2fov(intrs[i][1, 1], H)]) for i in range(B)]
70 |
71 | return {"projection_matrix": proj_mat,
72 | "world_view_transform": world_view_transform,
73 | "full_proj_transform": full_proj_transform,
74 | "camera_center": camera_center,
75 | "H": torch.ones(B) * H,
76 | "W": torch.ones(B) * W,
77 | "FovX": torch.stack(FovX, dim=0),
78 | "FovY": torch.stack(FovY, dim=0)
79 | }
80 |
81 | def depth2pc(depth, extrinsic, intrinsic):
82 | B, C, H, W = depth.shape
83 | depth = depth[:, 0, :, :]
84 | rot = extrinsic[:, :3, :3]
85 | trans = extrinsic[:, :3, 3:]
86 |
87 | y, x = torch.meshgrid(torch.linspace(0.5, H-0.5, H, device=depth.device), torch.linspace(0.5, W-0.5, W, device=depth.device))
88 |     pts_2d = torch.stack([x, y, torch.ones_like(x)], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1)  # [B, H, W, 3]
89 |
90 | pts_2d[..., 2] = depth
91 | pts_2d[:, :, :, 0] -= intrinsic[:, None, None, 0, 2]
92 | pts_2d[:, :, :, 1] -= intrinsic[:, None, None, 1, 2]
93 | pts_2d_xy = pts_2d[:, :, :, :2] * pts_2d[:, :, :, 2:]
94 | pts_2d = torch.cat([pts_2d_xy, pts_2d[..., 2:]], dim=-1)
95 |
96 | pts_2d[..., 0] /= intrinsic[:, 0, 0][:, None, None]
97 | pts_2d[..., 1] /= intrinsic[:, 1, 1][:, None, None]
98 |
99 | pts_2d = pts_2d.view(B, -1, 3).permute(0, 2, 1)
100 | rot_t = rot.permute(0, 2, 1)
101 | pts = torch.bmm(rot_t, pts_2d) - torch.bmm(rot_t, trans)
102 |
103 | return pts.permute(0, 2, 1)
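104 | 
105 | 
106 | if __name__ == "__main__":
107 |     # Smoke test for depth2pc with placeholder camera parameters (hypothetical
108 |     # values, not taken from the dataset): back-project a random depth map.
109 |     B, H, W = 2, 480, 640
110 |     depth = torch.rand(B, 1, H, W)
111 |     extr = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)   # world-to-camera
112 |     intr = torch.tensor([[500., 0., 320.],
113 |                          [0., 500., 240.],
114 |                          [0., 0., 1.]]).unsqueeze(0).repeat(B, 1, 1)
115 |     pts = depth2pc(depth, extr, intr)
116 |     print(pts.shape)  # expected: torch.Size([2, 307200, 3])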
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import matplotlib as mpl
4 | from matplotlib import cm
5 | from matplotlib.backends.backend_agg import FigureCanvasAgg
6 | from matplotlib.figure import Figure
7 | 
8 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
9 | def get_vertical_colorbar(h, vmin, vmax, cmap_name='jet', label=None):
10 | fig = Figure(figsize=(1.2, 8), dpi=100)
11 | fig.subplots_adjust(right=1.5)
12 | canvas = FigureCanvasAgg(fig)
13 | # Do some plotting.
14 | ax = fig.add_subplot(111)
15 | cmap = cm.get_cmap(cmap_name)
16 | norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
17 | tick_cnt = 6
18 | tick_loc = np.linspace(vmin, vmax, tick_cnt)
19 | cb1 = mpl.colorbar.ColorbarBase(ax, cmap=cmap,
20 | norm=norm,
21 | ticks=tick_loc,
22 | orientation='vertical')
23 | tick_label = ['{:3.2f}'.format(x) for x in tick_loc]
24 | cb1.set_ticklabels(tick_label)
25 | cb1.ax.tick_params(labelsize=18, rotation=0)
26 | if label is not None:
27 | cb1.set_label(label)
28 | fig.tight_layout()
29 | canvas.draw()
30 | s, (width, height) = canvas.print_to_buffer()
31 | im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
32 | im = im[:, :, :3].astype(np.float32) / 255.
33 | if h != im.shape[0]:
34 | w = int(im.shape[1] / im.shape[0] * h)
35 | im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
36 | return im
37 |
38 | def colorize_np(x, cmap_name='jet', mask=None, append_cbar=False):
39 | HUGE_NUMBER = 1e10
40 | TINY_NUMBER = 1e-6
41 | if mask is not None:
42 | # vmin, vmax = np.percentile(x[mask], (1, 99))
43 | vmin = np.min(x[mask])
44 | vmax = np.max(x[mask])
45 | vmin = vmin - np.abs(vmin) * 0.01
46 | x[np.logical_not(mask)] = vmin
47 | x = np.clip(x, vmin, vmax)
48 | # print(vmin, vmax)
49 | else:
50 | vmin = x.min()
51 | vmax = x.max() + TINY_NUMBER
52 | x = (x - vmin) / (vmax - vmin)
53 | # x = np.clip(x, 0., 1.)
54 | cmap = cm.get_cmap(cmap_name)
55 | x_new = cmap(x)[:, :, :3]
56 | if mask is not None:
57 | mask = np.float32(mask[:, :, np.newaxis])
58 | x_new = x_new * mask + np.zeros_like(x_new) * (1. - mask)
59 | cbar = get_vertical_colorbar(h=x.shape[0], vmin=vmin, vmax=vmax, cmap_name=cmap_name)
60 | if append_cbar:
61 | x_new = np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1)
62 | return x_new
63 | else:
64 | return x_new, cbar
65 |
66 | # Traditional depth2img of taichi-nerf
67 | # def depth2img(depth):
68 | # depth = (depth - depth.min()) / (depth.max() - depth.min())
69 | # depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),
70 | # cv2.COLORMAP_TURBO)
71 | # return depth_img
72 |
73 | ### depth2img from EventNeRF ###
74 | def depth2img(depth):
75 |     im, cbar = colorize_np(depth, cmap_name='jet', append_cbar=False)
76 | im = to8b(im)
77 | return im
78 |
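79 | if __name__ == "__main__":
80 |     # Smoke test for the colorization helpers on a synthetic ramp (the output
81 |     # file name is arbitrary): colorize a fake depth map and save it.
82 |     fake_depth = np.linspace(0., 1., 480 * 640).reshape(480, 640).astype(np.float32)
83 |     vis = depth2img(fake_depth)  # uint8 [480, 640, 3], jet-colormapped
84 |     cv2.imwrite("depth_vis_demo.png", cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))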
--------------------------------------------------------------------------------
/pretrain_ckpt/download.sh:
--------------------------------------------------------------------------------
1 | ### download pretrain files ###
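2 | # Expected contents: the warm-up checkpoints that train_gs.py loads via
3 | # cfg.depth_warmup_ckpt and cfg.intensity_warmup_ckpt -- plain torch.save()
4 | # *.pth files placed in this directory.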
--------------------------------------------------------------------------------
/train_gs.py:
--------------------------------------------------------------------------------
1 | from lib.config import cfg, args
2 | from lib.dataset import EventDataloader
3 | from lib.recorder import Logger, file_backup
4 | from lib.network import model_loss_light, model_loss, EventGaussian
5 | from lib.renderer import pts2render, depth2pc
6 | from lib.utils import depth2img
7 | import numpy as np
8 | import imageio
9 | import cv2
10 | import os
11 | from pathlib import Path
12 | from tqdm import tqdm
13 | import logging
14 | import torch
15 | from torch import optim
16 | from torch.utils.data import DataLoader
17 | import torch.nn.functional as F
18 |
19 | cs = cfg.cs
20 |
21 | class Trainer:
22 | def __init__(self) -> None:
23 | device = torch.device('cuda:{}'.format(cfg.local_rank))
24 | self.device = device
25 | which_test = "val"
26 |         self.train_loader = EventDataloader(cfg.dataset.base_folder, split="train",
27 |                                             num_workers=1, batch_size=1, shuffle=False)
28 | 
29 |         self.val_loader = EventDataloader(cfg.dataset.base_folder, split=which_test,
30 |                                           num_workers=1, batch_size=1, shuffle=False)
31 |
32 | self.len_val = len(self.val_loader)
33 | self.model = EventGaussian().to(self.device)
34 | print(" Load warm up parameters ... ")
35 | d_warmup = False
36 | int_warmup = False
37 | if cfg.depth_warmup_ckpt is not None:
38 | self.model.depth_estimator.load_state_dict(torch.load(cfg.depth_warmup_ckpt)["network"])
39 | d_warmup = True
40 | if cfg.intensity_warmup_ckpt is not None:
41 | self.model.intensity_estimator.load_state_dict(torch.load(cfg.intensity_warmup_ckpt)["network"])
42 | int_warmup = True
43 | print(f" Using depth warm up {d_warmup} ; intensity warm up {int_warmup}")
44 |
45 | # self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=cfg.wdecay, eps=1e-8)
46 |         warmup_ids = set(map(id, self.model.depth_estimator.parameters())) | set(map(id, self.model.intensity_estimator.parameters()))
47 |         rest_params = [p for p in self.model.parameters() if id(p) not in warmup_ids]
48 | self.optimizer = optim.Adam([
49 | {'params':self.model.depth_estimator.parameters(), 'lr':0.00001},
50 | {'params':self.model.intensity_estimator.parameters(), 'lr':0.00001},
51 | {'params':rest_params, 'lr':0.0005},
52 | ], lr=0.0005, weight_decay=cfg.wdecay, eps=1e-8)
53 |
54 | # self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, 0.001, 1000000 + 100,
55 | # pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
56 | self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=15000, gamma=0.9)
57 |
58 | self.logger = Logger(self.scheduler, cfg.record)
59 |
60 | self.total_steps = 0
61 | self.target_epoch = cfg.target_epoch
62 |
63 | if cfg.restore_ckpt:
64 | self.load_ckpt(cfg.restore_ckpt)
65 |
66 | self.model.train()
67 |
68 | def train(self):
69 | # self.model.eval()
70 | # self.run_eval()
71 | # self.model.train()
72 | for e in range(self.target_epoch):
73 | for idx, batch in enumerate(tqdm(self.train_loader)):
74 | batch = self.to_cuda(batch)
75 |
76 |                 ### model and loss computing ###
77 | gt = {}
78 | gt["cim"] = batch["cim"]
79 |
80 | gt["lim"], gt["rim"], gt["ldepth"], gt["rdepth"], gt["lmask"], gt["rmask"] \
81 | = batch["lim"], batch["rim"], batch["ldepth"], batch["rdepth"], batch["lmask"], batch["rmask"]
82 |
83 | batch["left_event_tensor"] = torch.cat([batch["leframe"], batch["left_voxel"]], dim=1)
84 | batch["right_event_tensor"] = torch.cat([batch["reframe"], batch["right_voxel"]], dim=1)
85 |
86 | data = self.model(batch)
87 |
88 | data["target"] = {"H":batch["H"],
89 | "W":batch["W"],
90 | "FovX":batch["FovX"],
91 | "FovY":batch["FovY"],
92 | 'world_view_transform': batch["world_view_transform"],
93 | 'full_proj_transform': batch["full_proj_transform"],
94 | 'camera_center': batch["camera_center"]}
95 |
96 | imgL, depthL, maskL = data["lview"]["img"], data["lview"]["depth"], data["lview"]["mask"]
97 | imgR, depthR, maskR = data["rview"]["img"], data["rview"]["depth"], data["rview"]["mask"]
98 |
99 | data["lview"]["pts"] = depth2pc(data["lview"]["depth"], torch.inverse(batch["lpose"]), batch["intrinsic"])
100 | data["rview"]["pts"] = depth2pc(data["rview"]["depth"], torch.inverse(batch["rpose"]), batch["intrinsic"])
101 |
102 | pred = pts2render(data, [0.,0.,0.])[:,0:1]
103 | loss = F.l1_loss(pred, gt["cim"])
104 |
105 | imgloss = torch.mean((imgL - gt["lim"])**2) + torch.mean((imgR - gt["rim"])**2)
106 | depthloss = F.l1_loss(depthL, gt["ldepth"]) + F.l1_loss(depthR, gt["rdepth"])
107 | maskloss = F.binary_cross_entropy(maskL.reshape(-1, 1), gt["lmask"].reshape(-1, 1).float()) + \
108 | F.binary_cross_entropy(maskR.reshape(-1, 1), gt["rmask"].reshape(-1, 1).float())
109 | loss = loss + 0.33*imgloss + 0.33*depthloss + 0.33*maskloss
110 | # msk = torch.ones_like(gt).to(bool)
111 |                 metrics = {
112 |                     "l1loss": loss.item()  # total objective, not only the render L1 term
113 |                 }
114 |
115 |                 if self.total_steps and self.total_steps % cfg.record.loss_freq == 0:
116 |                     self.logger.write_dict({'lr': self.optimizer.param_groups[0]['lr']}, write_step=self.total_steps)
117 |                     print(f"{cfg.exp_name} epoch {e} step {self.total_steps} loss {loss.item():.4f} lr {self.optimizer.param_groups[0]['lr']}")
118 |                 self.logger.push(metrics)
119 |
120 | if self.total_steps and self.total_steps % cfg.record.save_freq == 0:
121 | self.save_ckpt(save_path=Path('%s/%d_%d.pth' % (cfg.record.ckpt_path, e, self.total_steps)), show_log=False)
122 |
123 | self.optimizer.zero_grad()
124 | loss.backward()
125 | self.optimizer.step()
126 | self.scheduler.step()
127 |
128 | if self.total_steps and self.total_steps % cfg.record.eval_freq == 0:
129 | self.model.eval()
130 | self.run_eval()
131 | self.model.train()
132 |
133 | self.total_steps += 1
134 |
135 | print("FINISHED TRAINING")
136 | self.logger.close()
137 | self.save_ckpt(save_path=Path('%s/%s_final.pth' % (cfg.record.ckpt_path, cfg.exp_name)))
138 |
139 | def to_cuda(self, batch):
140 | for k in batch:
141 | if isinstance(batch[k], tuple) or isinstance(batch[k], list):
142 | batch[k] = [b.to(self.device) for b in batch[k]]
143 |             elif isinstance(batch[k], dict):
144 |                 batch[k] = self.to_cuda(batch[k])  # recurse on the sub-dict itself, not its tensor values
145 | else:
146 | batch[k] = batch[k].to(self.device)
147 | return batch
148 |
149 | def run_eval(self):
150 | print(f"Doing validation ...")
151 | torch.cuda.empty_cache()
152 |
153 | l1_list = []
154 |
155 |         show_idx = np.random.choice(self.len_val, size=1).tolist()  # visualize one random validation sample
156 | # show_idx = 0
157 | # show_idx = list(range(self.len_val))
158 | for idx, batch in enumerate(self.val_loader):
159 | with torch.no_grad():
160 | batch = self.to_cuda(batch)
161 | gt = batch["cim"]
162 |
163 | batch["left_event_tensor"] = torch.cat([batch["leframe"], batch["left_voxel"]], dim=1)
164 | batch["right_event_tensor"] = torch.cat([batch["reframe"], batch["right_voxel"]], dim=1)
165 |
166 | data = self.model(batch)
167 |
168 | data["target"] = {"H":batch["H"],
169 | "W":batch["W"],
170 | "FovX":batch["FovX"],
171 | "FovY":batch["FovY"],
172 | 'world_view_transform': batch["world_view_transform"],
173 | 'full_proj_transform': batch["full_proj_transform"],
174 | 'camera_center': batch["camera_center"]}
175 |
176 | data["lview"]["pts"] = depth2pc(data["lview"]["depth"], torch.inverse(batch["lpose"]), batch["intrinsic"])
177 | data["rview"]["pts"] = depth2pc(data["rview"]["depth"], torch.inverse(batch["rpose"]), batch["intrinsic"])
178 |
179 | pred = pts2render(data, [0.,0.,0.])[:,0]
180 | loss = F.l1_loss(pred.squeeze(), gt.squeeze())
181 | l1_list.append(loss.item())
182 |
183 | # if idx == show_idx:
184 | if idx in show_idx:
185 | print("show idx is ", idx)
186 | tmp_gt = (gt[0]*255.0).cpu().numpy().astype(np.uint8).squeeze()
187 | tmp_pred = (pred[0]*255.0).cpu().numpy().astype(np.uint8).squeeze()
188 | # tmp_gt = tmp_pred
189 | tmp_img_name = '%s/step%s_idx%d.jpg' % (cfg.record.show_path, self.total_steps, idx)
190 | imageio.imsave(tmp_img_name, np.concatenate([tmp_pred, tmp_gt], axis=0))
191 |
192 | val_l1 = np.round(np.mean(np.array(l1_list)), 4)
193 | print(f"Validation Metrics ({self.total_steps}):, L1 {val_l1}")
194 | self.logger.write_dict({'val_l1': val_l1}, write_step=self.total_steps)
195 | torch.cuda.empty_cache()
196 |
197 | def save_ckpt(self, save_path, show_log=True):
198 | if show_log:
199 | print(f"Save checkpoint to {save_path} ...")
200 | torch.save({
201 | 'total_steps': self.total_steps,
202 | 'network': self.model.state_dict(),
203 | 'optimizer': self.optimizer.state_dict(),
204 | 'scheduler': self.scheduler.state_dict()
205 | }, save_path)
206 |
207 | def load_ckpt(self, load_path, load_optimizer=True, strict=True):
208 | assert os.path.exists(load_path)
209 | print(f"Loading checkpoint from {load_path} ...")
210 | ckpt = torch.load(load_path, map_location='cuda')
211 | self.model.load_state_dict(ckpt['network'], strict=strict)
212 | print(f"Parameter loading done")
213 | if load_optimizer:
214 | self.total_steps = ckpt['total_steps'] + 1
215 | self.logger.total_steps = self.total_steps
216 | self.optimizer.load_state_dict(ckpt['optimizer'])
217 | self.scheduler.load_state_dict(ckpt['scheduler'])
218 | print(f"Optimizer loading done")
219 |
220 |
221 | if __name__ == "__main__":
222 | # L = torch.randn((1,3,640,480)).cuda()
223 | # R = torch.randn((1,3,640,480)).cuda()
224 | # # net = Net(int(5)).cuda()
225 | # net = ASNet_light(int(5)).cuda()
226 | # out = net(L,R)
227 | # print(len(out))
228 | # Input = torch.randn((1, 5, 640, 480)).cuda()
229 | # fnet = FireNet({"num_bins":5}).cuda()
230 | # out = fnet(Input, None)
231 | # print(out[0].shape)
232 |
233 | trainer = Trainer()
234 | trainer.train()
235 |
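236 | # Objective recap (as composed in train() above): render L1 against the target
237 | # intensity image plus 0.33 * (per-view intensity MSE + depth L1 + mask BCE).
238 | # Launch sketch -- the CLI flag name comes from lib.config and is assumed here,
239 | # not verified against config.py:
240 | #   python train_gs.py --cfg_file configs/default.yaml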
--------------------------------------------------------------------------------