├── LICENSE
├── README.md
├── base
│   ├── __init__.py
│   ├── base_data_loader.py
│   ├── base_model.py
│   └── base_trainer.py
├── configs
│   ├── evaluate
│   │   ├── eval_mr_effinetb5.json
│   │   └── eval_mr_resnet18.json
│   └── train
│       ├── train_mr_effinetb5.json
│       └── train_mr_resnet18.json
├── create_pointcloud.py
├── data_loader
│   ├── __init__.py
│   ├── data_loaders.py
│   ├── kitti_odometry_dataset.py
│   └── scripts
│       ├── __init__.py
│       └── preprocess_kitti_transfer_gtdepth_to_odom.py
├── depth_proc_tools
│   └── plot_depth_utils.py
├── evaluate.py
├── evaluater
│   ├── __init__.py
│   └── evaluater.py
├── logger
│   ├── __init__.py
│   ├── logger.py
│   ├── logger_config.json
│   └── visualization.py
├── model
│   ├── __init__.py
│   ├── dymultidepth
│   │   ├── __init__.py
│   │   ├── base_models.py
│   │   ├── ccf_modules.py
│   │   └── dymultidepth_model.py
│   ├── layers.py
│   ├── loss.py
│   ├── loss_functions
│   │   ├── __init__.py
│   │   ├── common_losses.py
│   │   ├── dymultidepth_loss.py
│   │   └── virtual_normal_loss.py
│   ├── metric.py
│   ├── metric_functions
│   │   ├── __init__.py
│   │   ├── completeness_metrics.py
│   │   ├── dense_metrics.py
│   │   └── sparse_metrics.py
│   └── model.py
├── parse_config.py
├── pictures
│   ├── dynamic_depth_result.gif
│   └── overview.jpg
├── requirements.txt
├── train.py
├── trainer
│   ├── __init__.py
│   └── trainer.py
└── utils
    ├── __init__.py
    ├── parse_config.py
    ├── ply_utils.py
    └── util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Rui Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic-multiframe-depth
2 |
3 | This is the code repository for the paper:
4 | > **Learning to Fuse Monocular and Multi-view Cues for Multi-frame Depth Estimation in Dynamic Scenes**
5 | >
6 | > [Rui Li](https://ruili3.github.io/), [Dong Gong](https://donggong1.github.io/index.html), [Wei Yin](https://yvanyin.net/), [Hao Chen](https://stan-haochen.github.io/), Yu Zhu, Kaixuan Wang, Xiaozhi Chen, Jinqiu Sun and Yanning Zhang
7 | >
8 | > **CVPR 2023 [[Project](https://ruili3.github.io/dymultidepth/index.html)] [[arXiv](https://arxiv.org/abs/2304.08993)] [[Video](https://www.youtube.com/watch?v=0ViYXt2bpuM)]**
9 |
10 | 
11 |
12 | If you use any content of this repo in your work, please cite our paper:
13 | ```
14 | @inproceedings{li2023learning,
15 | title={Learning to Fuse Monocular and Multi-view Cues for Multi-frame Depth Estimation in Dynamic Scenes},
16 | author={Li, Rui and Gong, Dong and Yin, Wei and Chen, Hao and Zhu, Yu and Wang, Kaixuan and Chen, Xiaozhi and Sun, Jinqiu and Zhang, Yanning},
17 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
18 | pages={21539--21548},
19 | year={2023}
20 | }
21 | ```
22 |
23 | ## Introduction
24 | Multi-frame depth estimation generally achieves high accuracy by relying on multi-view geometric consistency. When applied in dynamic scenes, this consistency is usually violated in the dynamic areas, leading to corrupted estimations. Many multi-frame methods handle dynamic areas by identifying them with explicit masks and compensating the multi-view cues with monocular cues represented as local monocular depth or features. The improvements are limited due to the uncontrolled quality of the masks and the underutilized benefits of the fusion of the two types of cues. In this paper, we propose a novel method that learns to fuse the multi-view and monocular cues encoded as volumes, without needing heuristically crafted masks. As unveiled in our analyses, the multi-view cues capture more accurate geometric information in static areas, while the monocular cues capture more useful contexts in dynamic areas. To let the geometric perception learned from multi-view cues in static areas propagate to the monocular representation in dynamic areas, and to let the monocular cues enhance the representation of the multi-view cost volume, we propose a cross-cue fusion (CCF) module, which includes cross-cue attention (CCA) to encode the spatially non-local relative intra-relations from each source to enhance the representation of the other. Experiments on real-world datasets demonstrate the significant effectiveness and generalization ability of the proposed method.
25 |
26 |
27 |
28 | Overview of the proposed network.
29 |
30 |
31 |
32 | ## Environment Setup
33 |
34 | You can set up your own `conda` virtual environment by running the commands below.
35 | ```shell
36 | # create a clean conda environment from scratch
37 | conda create --name dymultidepth python=3.7
38 | conda activate dymultidepth
39 | # install ipython and pip
40 | conda install ipython
41 | conda install pip
42 | # install required packages
43 | pip install -r requirements.txt
44 | ```
45 |
46 |
47 |
48 | ## KITTI Odometry Data
49 |
50 | We mainly use the KITTI Odometry dataset for training and testing, and you can follow the steps below to prepare the dataset.
51 |
52 | 1. Download the color images and calibration files from the [official website](http://www.cvlibs.net/datasets/kitti/eval_odometry.php). We use the [improved ground truth depth](http://www.cvlibs.net/datasets/kitti/eval_depth_all.php) for training and evaluation.
53 | 2. Unzip the color images and calibration files into ```../data```. Convert the annotated lidar depth maps to the expected format using the script ```data_loader/scripts/preprocess_kitti_transfer_gtdepth_to_odom.py``` (an example invocation is given after this list).
54 |
55 | 3. The estimated poses can be downloaded
56 | from [here](https://vision.in.tum.de/_media/research/monorec/poses_dvso.zip) and should be placed under ``../data/{kitti_path}/poses_dvso``. This folder structure is created automatically when the zip file is unpacked in the ``{kitti_path}`` directory.
57 |
58 |
59 | 4. The dynamic object masks can be downloaded from [here](https://vision.in.tum.de/_media/research/monorec/mvobj_mask.zip). Unpack the .zip file in the ``{kitti_path}`` directory; the data should end up in ``../data/{kitti_path}/sequences/{seq_num}/mvobj_mask``.
60 |
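A possible invocation of the preprocessing script (the input path below is a placeholder for your download of the KITTI depth archive; the flags come from the script's argument parser):
```shell
python data_loader/scripts/preprocess_kitti_transfer_gtdepth_to_odom.py \
    --input /path/to/data_depth_annotated.zip \
    --output ../data/dataset \
    --depth_folder image_depth_annotated
```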
61 | The dataset should be organized as follows:
62 | ```
63 | data
64 | └── dataset
65 |     ├── poses_dvso
66 |     │   ├── 00.txt
67 |     │   ├── 01.txt
68 |     │   ├── ...
69 |     └── sequences
70 |         ├── 00
71 |         │   ├── calib.txt
72 |         │   ├── image_2
73 |         │   ├── image_depth_annotated
74 |         │   ├── mvobj_mask
75 |         │   └── times.txt
76 |         ├── ...
77 | ```
78 |
79 |
80 |
81 | ## Training
82 |
83 | To train the model from scratch, first set the configuration file.
84 | 1. Set the `dataset_dir` to the directory where the KITTI dataset is located.
85 | 2. Set the `save_dir` to the directory where you want to store the trained models.
86 |
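Both fields live in `configs/train/train_mr_resnet18.json` (and its EfficientNet-B5 counterpart); a trimmed excerpt showing where they sit (`dataset_dir` also appears under `val_data_loader.args`):
```json
{
  "data_loader": { "args": { "dataset_dir": "./data/dataset/" } },
  "trainer": { "save_dir": "./saved_model/" }
}
```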
87 | Then run the following command to train the model:
88 |
89 | ```shell
90 | # Here we train our model with ResNet-18 backbone as an example.
91 | python train.py --config configs/train/train_mr_resnet18.json
92 | ```
93 |
94 | ## Evaluation
95 | We provide KITTI-trained checkpoints of our model:
96 |
97 | | Backbone | Resolution | Download |
98 | | --- | --- | --- |
99 | | ResNet-18 | 256 x 512 | [Link](https://drive.google.com/file/d/1IhrBx3bj6H26UDxMNvRxF7xC0C9qqI1u/view?usp=sharing) |
100 | | EfficientNet-B5 | 256 x 512 | [Link](https://drive.google.com/file/d/1jS1pbCKfYuuoawZ1nnejtGwQV3FhXGcg/view?usp=sharing) |
101 |
102 |
103 | Place the downloaded checkpoints in `./ckpt`. To reproduce the evaluation results in the paper, run the following commands:
104 | ```shell
105 | python evaluate.py --config configs/evaluate/eval_mr_resnet18.json
106 | ```
107 | In the `.json` file, set `checkpoint_location` to the path of the downloaded checkpoint and set `save_dir` to the directory where the evaluation scores should be saved.
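For reference, the two keys live in different parts of the evaluation config; a trimmed excerpt of `configs/evaluate/eval_mr_resnet18.json` showing only the relevant fields:
```json
{
  "models": [
    { "args": { "checkpoint_location": ["./ckpt/mr_kitti_effib5.pth"] } }
  ],
  "evaluater": { "save_dir": "./save_dir/" }
}
```
Note that `checkpoint_location` is a list, so the checkpoint path goes inside the brackets.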
108 | ## Acknowledgements
109 | Our method is implemented based on [MonoRec](https://github.com/Brummi/MonoRec). We thank the authors for their open-source code.
110 |
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | from .base_model import *
3 | from .base_trainer import *
4 |
--------------------------------------------------------------------------------
/base/base_data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataloader import default_collate
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 |
6 |
7 | class BaseDataLoader(DataLoader):
8 | """
9 | Base class for all data loaders
10 | """
11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
12 | self.validation_split = validation_split
13 | self.shuffle = shuffle
14 |
15 | self.batch_idx = 0
16 | self.n_samples = len(dataset)
17 |
18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
19 |
20 | self.init_kwargs = {
21 | 'dataset': dataset,
22 | 'batch_size': batch_size,
23 | 'shuffle': self.shuffle,
24 | 'collate_fn': collate_fn,
25 | 'num_workers': num_workers
26 | }
27 | super().__init__(sampler=self.sampler, **self.init_kwargs)
28 |
29 | def _split_sampler(self, split):
30 | if split == 0.0:
31 | return None, None
32 |
33 | idx_full = np.arange(self.n_samples)
34 |
35 | np.random.seed(0)
36 | np.random.shuffle(idx_full)
37 |
38 | if isinstance(split, int):
39 | assert split > 0
40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
41 | len_valid = split
42 | else:
43 | len_valid = int(self.n_samples * split)
44 |
45 | valid_idx = idx_full[0:len_valid]
46 | train_idx = np.delete(idx_full, np.arange(0, len_valid))
47 |
48 | train_sampler = SubsetRandomSampler(train_idx)
49 | valid_sampler = SubsetRandomSampler(valid_idx)
50 |
51 | # turn off shuffle option which is mutually exclusive with sampler
52 | self.shuffle = False
53 | self.n_samples = len(train_idx)
54 |
55 | return train_sampler, valid_sampler
56 |
57 | def split_validation(self):
58 | if self.valid_sampler is None:
59 | return None
60 | else:
61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
62 |
--------------------------------------------------------------------------------
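A minimal sketch of how the validation split behaves, using a toy `TensorDataset` that is not part of this repo: a non-zero `validation_split` turns off shuffling and builds two `SubsetRandomSampler`s, and `split_validation()` wraps the held-out indices in a second `DataLoader`.
```python
import torch
from torch.utils.data import TensorDataset

from base import BaseDataLoader

# 100 toy samples; 20% are held out for validation.
toy = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))
loader = BaseDataLoader(toy, batch_size=10, shuffle=True, validation_split=0.2, num_workers=0)
val_loader = loader.split_validation()

print(len(loader.sampler), len(val_loader.sampler))  # 80 20
```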
/base/base_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | from abc import abstractmethod
4 |
5 |
6 | class BaseModel(nn.Module):
7 | """
8 | Base class for all models
9 | """
10 | @abstractmethod
11 | def forward(self, *inputs):
12 | """
13 | Forward pass logic
14 |
15 | :return: Model output
16 | """
17 | raise NotImplementedError
18 |
19 | def __str__(self):
20 | """
21 | Model prints with number of trainable parameters
22 | """
23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
24 | params = sum([np.prod(p.size()) for p in model_parameters])
25 | return super().__str__() + '\nTrainable parameters: {}'.format(params)
26 |
--------------------------------------------------------------------------------
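Because of the `__str__` override, printing any subclass also reports its trainable parameter count. A small sketch with a hypothetical `TinyModel` (not part of this repo):
```python
import torch.nn as nn

from base import BaseModel


class TinyModel(BaseModel):
    def __init__(self):
        super().__init__()
        # 8 * 3 * 3 * 3 weights + 8 biases = 224 trainable parameters
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


print(TinyModel())  # module summary followed by "Trainable parameters: 224"
```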
/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from abc import abstractmethod
3 | from numpy import inf
4 | from logger import TensorboardWriter
5 | from utils import filter_state_dict
6 |
7 | class BaseTrainer:
8 | """
9 | Base class for all trainers
10 | """
11 | def __init__(self, model, loss, metrics, optimizer, config):
12 | self.config = config
13 |
14 | if "trainer" in config.config:
15 | self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
16 | cfg_trainer = config['trainer']
17 | self.epochs = cfg_trainer['epochs']
18 | self.save_period = cfg_trainer['save_period']
19 | self.monitor = cfg_trainer.get('monitor', 'off')
20 |
21 | else:
22 | self.logger = config.get_logger("evaluater")
23 | self.monitor = "off"
24 |
25 | # setup GPU device if available, move model into configured device
26 | self.device, self.device_ids = self._prepare_device(config['n_gpu'])
27 | self.model = model.to(self.device)
28 | if len(self.device_ids) > 1:
29 | self.model = torch.nn.DataParallel(model, device_ids=self.device_ids)
30 |
31 | self.loss = loss
32 | self.metrics = metrics
33 | self.optimizer = optimizer
34 | self.save_multiple = True
35 |
36 | # configuration to monitor model performance and save best
37 | if self.monitor == 'off':
38 | self.mnt_mode = 'off'
39 | self.mnt_best = 0
40 | else:
41 | self.mnt_mode, self.mnt_metric = self.monitor.split()
42 | assert self.mnt_mode in ['min', 'max']
43 |
44 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf
45 | self.early_stop = cfg_trainer.get('early_stop', inf)
46 |
47 | self.start_epoch = 1
48 |
49 | self.checkpoint_dir = config.save_dir
50 |
51 | # setup visualization writer instance
52 | if "trainer" in config.config:
53 | self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
54 |
55 | if config.resume is not None:
56 | self._resume_checkpoint(config.resume)
57 |
58 | @abstractmethod
59 | def _train_epoch(self, epoch):
60 | """
61 | Training logic for an epoch
62 |
63 | :param epoch: Current epoch number
64 | """
65 | raise NotImplementedError
66 |
67 | def train(self):
68 | """
69 | Full training logic
70 | """
71 | not_improved_count = 0
72 | for epoch in range(self.start_epoch, self.epochs + 1):
73 | result = self._train_epoch(epoch)
74 |
75 | # save logged informations into log dict
76 | log = {'epoch': epoch}
77 | for key, value in result.items():
78 | if key == 'metrics':
79 | log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
80 | elif key == 'val_metrics':
81 | log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
82 | else:
83 | log[key] = value
84 |
85 | # print logged informations to the screen
86 | for key, value in log.items():
87 | self.logger.info(' {:15s}: {}'.format(str(key), value))
88 |
89 | # evaluate model performance according to configured metric, save best checkpoint as model_best
90 | best = False
91 | if self.mnt_mode != 'off':
92 | try:
93 | # check whether model performance improved or not, according to specified metric(mnt_metric)
94 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
95 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
96 | except KeyError:
97 | self.logger.warning("Warning: Metric '{}' is not found. "
98 | "Model performance monitoring is disabled.".format(self.mnt_metric))
99 | self.mnt_mode = 'off'
100 | improved = False
101 |
102 | if improved:
103 | self.mnt_best = log[self.mnt_metric]
104 | not_improved_count = 0
105 | best = True
106 | else:
107 | not_improved_count += 1
108 |
109 | if not_improved_count > self.early_stop:
110 | self.logger.info("Validation performance didn\'t improve for {} epochs. "
111 | "Training stops.".format(self.early_stop))
112 | break
113 |
114 | if epoch % self.save_period == 0:
115 | self._save_checkpoint(epoch, save_best=best)
116 |
117 | def _prepare_device(self, n_gpu_use):
118 | """
119 | setup GPU device if available, move model into configured device
120 | """
121 | n_gpu = torch.cuda.device_count()
122 | if n_gpu_use > 0 and n_gpu == 0:
123 | self.logger.warning("Warning: There\'s no GPU available on this machine,"
124 | "training will be performed on CPU.")
125 | n_gpu_use = 0
126 | if n_gpu_use > n_gpu:
127 | self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available "
128 | "on this machine.".format(n_gpu_use, n_gpu))
129 | n_gpu_use = n_gpu
130 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
131 | list_ids = list(range(n_gpu_use))
132 | return device, list_ids
133 |
134 | def _save_checkpoint(self, epoch, save_best=False):
135 | """
136 | Saving checkpoints
137 |
138 | :param epoch: current epoch number
139 | :param log: logging information of the epoch
140 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
141 | """
142 | arch = type(self.model).__name__
143 | state = {
144 | 'arch': arch,
145 | 'epoch': epoch,
146 | 'state_dict': self.model.state_dict(),
147 | 'optimizer': self.optimizer.state_dict(),
148 | 'monitor_best': self.mnt_best,
149 | 'config': self.config
150 | }
151 | if self.save_multiple:
152 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
153 | torch.save(state, filename)
154 | self.logger.info("Saving checkpoint: {} ...".format(filename))
155 | else:
156 | filename = str(self.checkpoint_dir / 'checkpoint.pth')
157 | torch.save(state, filename)
158 | self.logger.info("Saving checkpoint: {} ...".format(filename))
159 | if save_best:
160 | best_path = str(self.checkpoint_dir / 'model_best.pth')
161 | torch.save(state, best_path)
162 | self.logger.info("Saving current best: model_best.pth ...")
163 |
164 | def _resume_checkpoint(self, resume_path):
165 | """
166 | Resume from saved checkpoints
167 |
168 | :param resume_path: Checkpoint path to be resumed
169 | """
170 | resume_path = str(resume_path)
171 | self.logger.info("Loading checkpoint: {} ...".format(resume_path))
172 | checkpoint = torch.load(resume_path)
173 | self.start_epoch = checkpoint['epoch'] + 1
174 | self.mnt_best = checkpoint['monitor_best']
175 |
176 | # load architecture params from checkpoint.
177 | if checkpoint['config']['arch'] != self.config['arch']:
178 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
179 | "checkpoint. This may yield an exception while state_dict is being loaded.")
180 | # self.model.load_state_dict(checkpoint['state_dict'])
181 | checkpoint_state_dict = filter_state_dict(checkpoint["state_dict"], checkpoint["arch"] == "DataParallel" and len(self.device_ids) == 1)
182 | self.model.load_state_dict(checkpoint_state_dict, strict=False)
183 |
184 | # load optimizer state from checkpoint only when optimizer type is not changed.
185 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
186 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
187 | "Optimizer parameters not being resumed.")
188 | else:
189 | self.optimizer.load_state_dict(checkpoint['optimizer'])
190 |
191 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
192 |
--------------------------------------------------------------------------------
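`BaseTrainer.train()` drives the epoch loop and expects `_train_epoch` to return a dict of logged values; entries under `metrics`/`val_metrics` are expanded using the `__name__` of each metric function. A minimal sketch of that contract with a hypothetical subclass (the repo's real implementation is `trainer/trainer.py`):
```python
from base import BaseTrainer


class SketchTrainer(BaseTrainer):
    """Illustrates only the _train_epoch contract, not the actual training logic."""

    def _train_epoch(self, epoch):
        # ... iterate over the data loader, compute self.loss, step self.optimizer ...
        # The returned dict is logged; 'metrics' must be ordered like self.metrics.
        return {
            "loss": 0.0,
            "metrics": [0.0 for _ in self.metrics],
        }
```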
/configs/evaluate/eval_mr_effinetb5.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Eval",
3 | "n_gpu": 8,
4 | "timestamp_replacement": "00",
5 | "models": [
6 | {
7 | "type": "DyMultiDepthModel",
8 | "args": {
9 | "inv_depth_min_max": [
10 | 0.33,
11 | 0.0025
12 | ],
13 | "checkpoint_location": [
14 | "./ckpt/mr_kitti_effib5.pth"
15 | ],
16 | "pretrain_mode": 1,
17 | "pretrain_dropout": 0,
18 | "use_stereo": false,
19 | "use_mono": true,
20 | "use_ssim": 1,
21 | "fusion_type": "ccf_fusion",
22 | "input_size": [256, 512],
23 | "freeze_backbone": false,
24 | "backbone_type": "efficientnetb5"
25 | }
26 | }
27 | ],
28 | "data_loader": {
29 | "type": "KittiOdometryDataloader",
30 | "args": {
31 | "dataset_dir": "./data/dataset/",
32 | "depth_folder": "image_depth_annotated",
33 | "batch_size": 1,
34 | "frame_count": 2,
35 | "shuffle": false,
36 | "validation_split": 0,
37 | "num_workers": 8,
38 | "sequences": [
39 | "00",
40 | "04",
41 | "05",
42 | "07"
43 | ],
44 | "target_image_size": [
45 | 256,
46 | 512
47 | ],
48 | "use_color": true,
49 | "use_color_augmentation": false,
50 | "use_dso_poses": true,
51 | "lidar_depth": true,
52 | "dso_depth": false,
53 | "return_stereo": false,
54 |
55 | "return_mvobj_mask": 1
56 | }
57 | },
58 | "loss": "abs_silog_loss_virtualnormal",
59 | "metrics": [
60 | "abs_rel_sparse_metric",
61 | "sq_rel_sparse_metric",
62 | "rmse_sparse_metric",
63 | "rmse_log_sparse_metric",
64 | "a1_sparse_metric",
65 | "a2_sparse_metric",
66 | "a3_sparse_metric"
67 | ],
68 | "evaluater": {
69 | "save_dir": "./save_dir/",
70 | "max_distance": 80,
71 | "verbosity": 2,
72 | "log_step": 20,
73 | "tensorboard": false,
74 | "eval_mono": false
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/configs/evaluate/eval_mr_resnet18.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Eval",
3 | "n_gpu": 8,
4 | "timestamp_replacement": "00",
5 | "models": [
6 | {
7 | "type": "DyMultiDepthModel",
8 | "args": {
9 | "inv_depth_min_max": [
10 | 0.33,
11 | 0.0025
12 | ],
13 | "checkpoint_location": [
14 | "./ckpt/mr_kitti_effib5.pth"
15 | ],
16 | "pretrain_mode": 1,
17 | "pretrain_dropout": 0,
18 | "use_stereo": false,
19 | "use_mono": true,
20 | "use_ssim": 1,
21 | "fusion_type": "ccf_fusion",
22 | "input_size": [256, 512]
23 | }
24 | }
25 | ],
26 | "data_loader": {
27 | "type": "KittiOdometryDataloader",
28 | "args": {
29 | "dataset_dir": "./data/dataset/",
30 | "depth_folder": "image_depth_annotated",
31 | "batch_size": 1,
32 | "frame_count": 2,
33 | "shuffle": false,
34 | "validation_split": 0,
35 | "num_workers": 8,
36 | "sequences": [
37 | "00",
38 | "04",
39 | "05",
40 | "07"
41 | ],
42 | "target_image_size": [
43 | 256,
44 | 512
45 | ],
46 | "use_color": true,
47 | "use_color_augmentation": false,
48 | "use_dso_poses": true,
49 | "lidar_depth": true,
50 | "dso_depth": false,
51 | "return_stereo": false,
52 | "return_mvobj_mask": 1
53 | }
54 | },
55 | "loss": "abs_silog_loss_virtualnormal",
56 | "metrics": [
57 | "abs_rel_sparse_metric",
58 | "sq_rel_sparse_metric",
59 | "rmse_sparse_metric",
60 | "rmse_log_sparse_metric",
61 | "a1_sparse_metric",
62 | "a2_sparse_metric",
63 | "a3_sparse_metric"
64 | ],
65 | "evaluater": {
66 | "save_dir": "./save_dir/",
67 | "max_distance": 80,
68 | "verbosity": 2,
69 | "log_step": 20,
70 | "tensorboard": false,
71 | "eval_mono": false
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/configs/train/train_mr_effinetb5.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "dy_multi_depth",
3 | "n_gpu": 8,
4 | "arch": {
5 | "type": "DyMultiDepthModel",
6 | "args": {
7 | "pretrain_mode": 1,
8 | "pretrain_dropout": 0.0,
9 | "augmentation": "depth",
10 | "use_mono": true,
11 | "use_stereo": false,
12 | "checkpoint_location": [],
13 | "fusion_type": "ccf_fusion",
14 | "input_size": [256, 512],
15 | "freeze_backbone": false,
16 | "backbone_type": "efficientnetb5"
17 | }
18 | },
19 | "data_loader": {
20 | "type": "KittiOdometryDataloader",
21 | "args": {
22 | "dataset_dir": "./data/dataset/",
23 | "depth_folder": "image_depth_annotated",
24 | "batch_size": 8,
25 | "frame_count": 2,
26 | "shuffle": true,
27 | "validation_split": 0,
28 | "num_workers": 16,
29 | "sequences": [
30 | "01",
31 | "02",
32 | "06",
33 | "08",
34 | "09",
35 | "10"
36 | ],
37 | "target_image_size": [
38 | 256,
39 | 512
40 | ],
41 | "use_color": true,
42 | "use_color_augmentation": true,
43 | "use_dso_poses": true,
44 | "lidar_depth": true,
45 | "dso_depth": false,
46 | "return_stereo": false,
47 | "return_mvobj_mask": true
48 | }
49 | },
50 | "val_data_loader": {
51 | "type": "KittiOdometryDataloader",
52 | "args": {
53 | "dataset_dir": "./data/dataset/",
54 | "depth_folder": "image_depth_annotated",
55 | "batch_size": 16,
56 | "frame_count": 2,
57 | "shuffle": false,
58 | "validation_split": 0,
59 | "num_workers": 2,
60 | "sequences": [
61 | "00",
62 | "04",
63 | "05",
64 | "07"
65 | ],
66 | "target_image_size": [
67 | 256,
68 | 512
69 | ],
70 | "max_length": 32,
71 | "use_color": true,
72 | "use_color_augmentation": true,
73 | "use_dso_poses": true,
74 | "lidar_depth": true,
75 | "dso_depth": false,
76 | "return_stereo": false,
77 |
78 |
79 | "return_mvobj_mask": true
80 | }
81 | },
82 | "optimizer": {
83 | "type": "Adam",
84 | "args": {
85 | "lr": 1e-4,
86 | "weight_decay": 0,
87 | "amsgrad": true
88 | }
89 | },
90 |
91 |
92 |
93 | "loss": "abs_silog_loss_virtualnormal",
94 |
95 |
96 | "metrics": [
97 | "a1_sparse_metric",
98 | "abs_rel_sparse_metric",
99 | "rmse_sparse_metric"
100 | ],
101 | "lr_scheduler": {
102 | "type": "StepLR",
103 | "args": {
104 | "step_size": 65,
105 | "gamma": 0.1
106 | }
107 | },
108 | "trainer": {
109 | "compute_mask": false,
110 | "compute_stereo_pred": false,
111 | "epochs": 80,
112 | "save_dir": "./saved_model/",
113 | "save_period": 1,
114 | "verbosity": 2,
115 | "log_step": 4800,
116 | "val_log_step": 40,
117 | "alpha": 0.5,
118 | "max_distance": 80,
119 | "monitor": "min abs_rel_sparse_metric",
120 | "timestamp_replacement": "00",
121 | "tensorboard": true
122 | }
123 | }
124 |
--------------------------------------------------------------------------------
/configs/train/train_mr_resnet18.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "dy_multi_depth",
3 | "n_gpu": 8,
4 | "arch": {
5 | "type": "DyMultiDepthModel",
6 | "args": {
7 | "pretrain_mode": 1,
8 | "pretrain_dropout": 0.0,
9 | "augmentation": "depth",
10 | "use_mono": true,
11 | "use_stereo": false,
12 | "checkpoint_location": [],
13 | "fusion_type": "ccf_fusion",
14 | "input_size": [256, 512]
15 | }
16 | },
17 | "data_loader": {
18 | "type": "KittiOdometryDataloader",
19 | "args": {
20 | "dataset_dir": "./data/dataset/",
21 | "depth_folder": "image_depth_annotated",
22 | "batch_size": 8,
23 | "frame_count": 2,
24 | "shuffle": true,
25 | "validation_split": 0,
26 | "num_workers": 16,
27 | "sequences": [
28 | "01",
29 | "02",
30 | "06",
31 | "08",
32 | "09",
33 | "10"
34 | ],
35 | "target_image_size": [
36 | 256,
37 | 512
38 | ],
39 | "use_color": true,
40 | "use_color_augmentation": true,
41 | "use_dso_poses": true,
42 | "lidar_depth": true,
43 | "dso_depth": false,
44 | "return_stereo": false,
45 |
46 |
47 | "return_mvobj_mask": true
48 | }
49 | },
50 | "val_data_loader": {
51 | "type": "KittiOdometryDataloader",
52 | "args": {
53 | "dataset_dir": "./data/dataset/",
54 | "depth_folder": "image_depth_annotated",
55 | "batch_size": 16,
56 | "frame_count": 2,
57 | "shuffle": false,
58 | "validation_split": 0,
59 | "num_workers": 2,
60 | "sequences": [
61 | "00",
62 | "04",
63 | "05",
64 | "07"
65 | ],
66 | "target_image_size": [
67 | 256,
68 | 512
69 | ],
70 | "max_length": 32,
71 | "use_color": true,
72 | "use_color_augmentation": true,
73 | "use_dso_poses": true,
74 | "lidar_depth": true,
75 | "dso_depth": false,
76 | "return_stereo": false,
77 |
78 |
79 | "return_mvobj_mask": true
80 | }
81 | },
82 | "optimizer": {
83 | "type": "Adam",
84 | "args": {
85 | "lr": 1e-4,
86 | "weight_decay": 0,
87 | "amsgrad": true
88 | }
89 | },
90 |
91 |
92 |
93 | "loss": "abs_silog_loss_virtualnormal",
94 |
95 |
96 | "metrics": [
97 | "a1_sparse_metric",
98 | "abs_rel_sparse_metric",
99 | "rmse_sparse_metric"
100 | ],
101 | "lr_scheduler": {
102 | "type": "StepLR",
103 | "args": {
104 | "step_size": 65,
105 | "gamma": 0.1
106 | }
107 | },
108 | "trainer": {
109 | "compute_mask": false,
110 | "compute_stereo_pred": false,
111 | "epochs": 80,
112 | "save_dir": "./saved_model/",
113 | "save_period": 1,
114 | "verbosity": 2,
115 | "log_step": 4800,
116 | "val_log_step": 40,
117 | "alpha": 0.5,
118 | "max_distance": 80,
119 | "monitor": "min abs_rel_sparse_metric",
120 | "timestamp_replacement": "00",
121 | "tensorboard": true
122 | }
123 | }
124 |
--------------------------------------------------------------------------------
/create_pointcloud.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 |
8 | import data_loader.data_loaders as module_data
9 | import model.model as module_arch
10 | from utils.parse_config import ConfigParser
11 | from utils import to, PLYSaver, DS_Wrapper
12 |
13 | import torch.nn.functional as F
14 |
15 |
16 | def main(config):
17 | logger = config.get_logger('test')
18 |
19 | output_dir = Path(config.config.get("output_dir", "saved"))
20 | output_dir.mkdir(exist_ok=True, parents=True)
21 | file_name = config.config.get("file_name", "pc.ply")
22 | use_mask = config.config.get("use_mask", True)
23 | roi = config.config.get("roi", None)
24 |
25 | max_d = config.config.get("max_d", 30)
26 | min_d = config.config.get("min_d", 3)
27 |
28 | start = config.config.get("start", 0)
29 | end = config.config.get("end", -1)
30 |
31 | # setup data_loader instances
32 | data_loader = DataLoader(DS_Wrapper(config.initialize('data_set', module_data), start=start, end=end), batch_size=1, shuffle=False, num_workers=8)
33 |
34 | # build model architecture
35 | model = config.initialize('arch', module_arch)
36 | logger.info(model)
37 |
38 | if config['n_gpu'] > 1:
39 | model = torch.nn.DataParallel(model)
40 |
41 | # prepare model for testing
42 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43 | model = model.to(device)
44 | model.eval()
45 |
46 | mask_fill = 32
47 |
48 | n = data_loader.batch_size
49 |
50 | target_image_size = data_loader.dataset.dataset.target_image_size
51 |
52 | plysaver = PLYSaver(target_image_size[0], target_image_size[1], min_d=min_d, max_d=max_d, batch_size=n, roi=roi, dropout=.75)
53 | plysaver.to(device)
54 |
55 | pose_buffer = []
56 | intrinsics_buffer = []
57 | mask_buffer = []
58 | keyframe_buffer = []
59 | depth_buffer = []
60 |
61 | buffer_length = 5
62 | min_hits = 1
63 | key_index = buffer_length // 2
64 |
65 | with torch.no_grad():
66 | for i, (data, target) in enumerate(tqdm(data_loader)):
67 | data = to(data, device)
68 | # if not torch.any(pose_distance_thresh(data, spatial_thresh=1)):
69 | # continue
70 | result = model(data)
71 | if not isinstance(result, dict):
72 | result = {"result": result[0]}
73 | output = result["result"]
74 | if "cv_mask" not in result:
75 | result["cv_mask"] = output.new_zeros(output.shape)
76 | # mask = ((result["cv_mask"] >= .1) & (output >= 1 / max_d)).to(dtype=torch.float32)
77 | mask = (result["cv_mask"] >= .1).to(dtype=torch.float32)
78 | mask = (F.conv2d(mask, mask.new_ones((1, 1, mask_fill+1, mask_fill+1)), padding=mask_fill // 2) < 1).to(dtype=torch.float32)
79 |
80 | pose_buffer += data["keyframe_pose"]
81 | intrinsics_buffer += [data["keyframe_intrinsics"]]
82 | mask_buffer += [mask]
83 | keyframe_buffer += [data["keyframe"]]
84 | depth_buffer += [output]
85 |
86 | if len(pose_buffer) >= buffer_length:
87 | pose = pose_buffer[key_index]
88 | intrinsics = intrinsics_buffer[key_index]
89 | keyframe = keyframe_buffer[key_index]
90 | depth = depth_buffer[key_index]
91 |
92 | mask = (torch.sum(torch.stack(mask_buffer), dim=0) > buffer_length - min_hits).to(dtype=torch.float32)
93 | if use_mask:
94 | depth *= mask
95 |
96 | plysaver.add_depthmap(depth, keyframe, intrinsics, pose)
97 |
98 | del pose_buffer[0]
99 | del intrinsics_buffer[0]
100 | del mask_buffer[0]
101 | del keyframe_buffer[0]
102 | del depth_buffer[0]
103 |
104 | with open(output_dir / file_name, "wb") as f:
105 | plysaver.save(f)
106 |
107 |
108 | if __name__ == '__main__':
109 | args = argparse.ArgumentParser(description='PyTorch Template')
110 |
111 | args.add_argument('-r', '--resume', default=None, type=str,
112 | help='path to latest checkpoint (default: None)')
113 | args.add_argument('-c', '--config', default=None, type=str,
114 | help='config file path (default: None)')
115 | args.add_argument('-d', '--device', default=None, type=str,
116 | help='indices of GPUs to enable (default: all)')
117 |
118 | config = ConfigParser(args)
119 | main(config)
120 |
--------------------------------------------------------------------------------
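The script is driven by the same `ConfigParser` command line as the other entry points (`-c/--config`, `-r/--resume`, `-d/--device`) and reads `output_dir`, `file_name`, `use_mask`, `roi`, `min_d`/`max_d`, and `start`/`end` from the config. A possible invocation; the config path below is a placeholder, since no point-cloud config ships in `configs/` here:
```shell
python create_pointcloud.py --config configs/pointcloud/pointcloud_kitti.json --device 0
```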
/data_loader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/data_loader/__init__.py
--------------------------------------------------------------------------------
/data_loader/data_loaders.py:
--------------------------------------------------------------------------------
1 | from base import BaseDataLoader
2 |
3 | from .kitti_odometry_dataset import *
4 |
5 | class KittiOdometryDataloader(BaseDataLoader):
6 |
7 | def __init__(self, batch_size=1, shuffle=True, validation_split=0.0, num_workers=4, **kwargs):
8 | self.dataset = KittiOdometryDataset(**kwargs)
9 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
--------------------------------------------------------------------------------
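A minimal usage sketch, assuming the KITTI Odometry data is laid out as described in the README. Keyword arguments other than the loader's own are forwarded to `KittiOdometryDataset`:
```python
from data_loader.data_loaders import KittiOdometryDataloader

loader = KittiOdometryDataloader(
    batch_size=2,
    shuffle=True,
    validation_split=0.0,
    num_workers=4,
    # forwarded to KittiOdometryDataset
    dataset_dir="./data/dataset/",
    sequences=["00"],
    depth_folder="image_depth_annotated",
    lidar_depth=True,
    dso_depth=False,
)
data, depth = next(iter(loader))  # data is a dict of tensors, depth the sparse inverse ground-truth depth
```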
/data_loader/kitti_odometry_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import pykitti
6 | import torch
7 | import torchvision
8 | from PIL import Image
9 | from scipy import sparse
10 | from skimage.transform import resize
11 | from torch.utils.data import Dataset
12 |
13 | from utils import map_fn
14 | import torchvision.transforms.functional as F
15 |
16 |
17 | class KittiOdometryDataset(Dataset):
18 |
19 | def __init__(self, dataset_dir, frame_count=2, sequences=None, depth_folder="image_depth_annotated",
20 | target_image_size=(256, 512), max_length=None, dilation=1, offset_d=0, use_color=True, use_dso_poses=False, use_color_augmentation=False, lidar_depth=False, dso_depth=True, annotated_lidar=True, return_stereo=False, return_mvobj_mask=False, use_index_mask=()):
21 | """
22 | Dataset implementation for KITTI Odometry.
23 | :param dataset_dir: Top level folder for KITTI Odometry (should contain the folders sequences, poses, and poses_dvso (if available)).
24 | :param frame_count: Number of frames used per sample (excluding the keyframe). By default, the keyframe is in the middle of those frames. (Default=2)
25 | :param sequences: Which sequences to use. Should be tuple of strings, e.g. ("00", "01", ...)
26 | :param depth_folder: The folder within the sequence folder that contains the depth information (e.g. sequences/00/{depth_folder})
27 | :param target_image_size: Desired image size (correct processing of depths is only guaranteed for default value). (Default=(256, 512))
28 | :param max_length: Maximum length per sequence. Useful for splitting up sequences and testing. (Default=None)
29 | :param dilation: Spacing between the frames (Default 1)
30 | :param offset_d: Index offset for frames (offset_d=0 means keyframe is centered). (Default=0)
31 | :param use_color: Use color (camera 2) or greyscale (camera 0) images (default=True)
32 | :param use_dso_poses: Use poses provided by d(v)so instead of KITTI poses. Requires the poses_dvso folder. (Default=False)
33 | :param use_color_augmentation: Use color jitter augmentation. The same transformation is applied to all frames in a sample. (Default=False)
34 | :param lidar_depth: Use depth information from (annotated) velodyne data. (Default=False)
35 | :param dso_depth: Use depth information from d(v)so. (Default=True)
36 | :param annotated_lidar: If lidar_depth=True, then this determines whether to use annotated or non-annotated depth maps. (Default=True)
37 | :param return_mvobj_mask: Return additional moving object mask.
38 | """
39 | self.dataset_dir = Path(dataset_dir)
40 | self.frame_count = frame_count
41 | self.sequences = sequences
42 | self.depth_folder = depth_folder
43 | self.lidar_depth = lidar_depth
44 | self.annotated_lidar = annotated_lidar
45 | self.dso_depth = dso_depth
46 | self.target_image_size = target_image_size
47 | self.use_index_mask = use_index_mask
48 | self.offset_d = offset_d
49 | if self.sequences is None:
50 | self.sequences = [f"{i:02d}" for i in range(11)]
51 | self._datasets = [pykitti.odometry(dataset_dir, sequence) for sequence in self.sequences]
52 | self._offset = (frame_count // 2) * dilation
53 | extra_frames = frame_count * dilation
54 | if self.annotated_lidar and self.lidar_depth:
55 | extra_frames = max(extra_frames, 10)
56 | self._offset = max(self._offset, 5)
57 | self._dataset_sizes = [
58 | len((dataset.cam0_files if not use_color else dataset.cam2_files)) - (extra_frames if self.use_index_mask is None else 0) for dataset in
59 | self._datasets]
60 | if self.use_index_mask is not None:
61 | index_masks = []
62 | for sequence_length, sequence in zip(self._dataset_sizes, self.sequences):
63 | index_mask = {i:True for i in range(sequence_length)}
64 | for index_mask_name in self.use_index_mask:
65 | with open(self.dataset_dir / "sequences" / sequence / (index_mask_name + ".json")) as f:
66 | m = json.load(f)
67 | for k in list(index_mask.keys()):
68 | if not str(k) in m or not m[str(k)]:
69 | del index_mask[k]
70 | index_masks.append(index_mask)
71 | self._indices = [
72 | list(sorted([int(k) for k in sorted(index_mask.keys()) if index_mask[k] and int(k) >= self._offset and int(k) < dataset_size + self._offset - extra_frames]))
73 | for index_mask, dataset_size in zip(index_masks, self._dataset_sizes)
74 | ]
75 | self._dataset_sizes = [len(indices) for indices in self._indices]
76 | if max_length is not None:
77 | self._dataset_sizes = [min(s, max_length) for s in self._dataset_sizes]
78 | self.length = sum(self._dataset_sizes)
79 |
80 | intrinsics_box = [self.compute_target_intrinsics(dataset, target_image_size, use_color) for dataset in
81 | self._datasets]
82 | self._crop_boxes = [b for _, b in intrinsics_box]
83 | if self.dso_depth:
84 | self.dso_depth_parameters = [self.get_dso_depth_parameters(dataset) for dataset in self._datasets]
85 | elif not self.lidar_depth:
86 | self._depth_crop_boxes = [
87 | self.compute_depth_crop(self.dataset_dir / "sequences" / s / depth_folder) for s in
88 | self.sequences]
89 | self._intrinsics = [format_intrinsics(i, self.target_image_size) for i, _ in intrinsics_box]
90 | self.dilation = dilation
91 | self.use_color = use_color
92 | self.use_dso_poses = use_dso_poses
93 | self.use_color_augmentation = use_color_augmentation
94 | if self.use_dso_poses:
95 | for dataset in self._datasets:
96 | dataset.pose_path = self.dataset_dir / "poses_dvso"
97 | dataset._load_poses()
98 | if self.use_color_augmentation:
99 | self.color_transform = ColorJitterMulti(brightness=.2, contrast=.2, saturation=.2, hue=.1)
100 | self.return_stereo = return_stereo
101 | self.return_mvobj_mask = return_mvobj_mask
102 |
103 | def get_dataset_index(self, index: int):
104 | for dataset_index, dataset_size in enumerate(self._dataset_sizes):
105 | if index >= dataset_size:
106 | index = index - dataset_size
107 | else:
108 | return dataset_index, index
109 | return None, None
110 |
111 | def preprocess_image(self, img: Image.Image, crop_box=None):
112 | if crop_box:
113 | img = img.crop(crop_box)
114 | if self.target_image_size:
115 | img = img.resize((self.target_image_size[1], self.target_image_size[0]), resample=Image.BILINEAR)
116 | if self.use_color_augmentation:
117 | img = self.color_transform(img)
118 | image_tensor = torch.tensor(np.array(img).astype(np.float32))
119 | image_tensor = image_tensor / 255 - .5
120 | if not self.use_color:
121 | image_tensor = torch.stack((image_tensor, image_tensor, image_tensor))
122 | else:
123 | image_tensor = image_tensor.permute(2, 0, 1)
124 | del img
125 | return image_tensor
126 |
127 | def preprocess_depth(self, depth: np.ndarray, crop_box=None):
128 | if crop_box:
129 | if crop_box[1] >= 0 and crop_box[3] <= depth.shape[0]:
130 | depth = depth[int(crop_box[1]):int(crop_box[3]), :]
131 | else:
132 | depth_ = np.ones((crop_box[3] - crop_box[1], depth.shape[1]))
133 | depth_[-crop_box[1]:-crop_box[1]+depth.shape[0], :] = depth
134 | depth = depth_
135 | if crop_box[0] >= 0 and crop_box[2] <= depth.shape[1]:
136 | depth = depth[:, int(crop_box[0]):int(crop_box[2])]
137 | else:
138 | depth_ = np.ones((depth.shape[0], crop_box[2] - crop_box[0]))
139 | depth_[:, -crop_box[0]:-crop_box[0]+depth.shape[1]] = depth
140 | depth = depth_
141 | if self.target_image_size:
142 | depth = resize(depth, self.target_image_size, order=0)
143 | return torch.tensor(1 / depth)
144 |
145 | def preprocess_depth_dso(self, depth: Image.Image, dso_depth_parameters, crop_box=None):
146 | h, w, f_x = dso_depth_parameters
147 | depth = np.array(depth, dtype=np.float)
148 | indices = np.array(np.nonzero(depth), dtype=np.float)
149 | indices[0] = np.clip(indices[0] / depth.shape[0] * h, 0, h-1)
150 | indices[1] = np.clip(indices[1] / depth.shape[1] * w, 0, w-1)
151 |
152 | depth = depth[depth > 0]
153 | depth = (w * depth / (0.54 * f_x * 65535))
154 |
155 | data = np.concatenate([indices, np.expand_dims(depth, axis=0)], axis=0)
156 |
157 | if crop_box:
158 | data = data[:, (crop_box[1] <= data[0, :]) & (data[0, :] < crop_box[3]) & (crop_box[0] <= data[1, :]) & (data[1, :] < crop_box[2])]
159 | data[0, :] -= crop_box[1]
160 | data[1, :] -= crop_box[0]
161 | crop_height = crop_box[3] - crop_box[1]
162 | crop_width = crop_box[2] - crop_box[0]
163 | else:
164 | crop_height = h
165 | crop_width = w
166 |
167 | data[0] = np.clip(data[0] / crop_height * self.target_image_size[0], 0, self.target_image_size[0]-1)
168 | data[1] = np.clip(data[1] / crop_width * self.target_image_size[1], 0, self.target_image_size[1]-1)
169 |
170 | depth = np.zeros(self.target_image_size)
171 | depth[np.around(data[0]).astype(np.int), np.around(data[1]).astype(np.int)] = data[2]
172 |
173 | return torch.tensor(depth, dtype=torch.float32)
174 |
175 | def preprocess_depth_annotated_lidar(self, depth: Image.Image, crop_box=None):
176 | depth = np.array(depth, dtype=np.float)
177 | h, w = depth.shape
178 | indices = np.array(np.nonzero(depth), dtype=np.float)
179 |
180 | depth = depth[depth > 0]
181 | depth = 256.0 / depth
182 |
183 | data = np.concatenate([indices, np.expand_dims(depth, axis=0)], axis=0)
184 |
185 | if crop_box:
186 | data = data[:, (crop_box[1] <= data[0, :]) & (data[0, :] < crop_box[3]) & (crop_box[0] <= data[1, :]) & (
187 | data[1, :] < crop_box[2])]
188 | data[0, :] -= crop_box[1]
189 | data[1, :] -= crop_box[0]
190 | crop_height = crop_box[3] - crop_box[1]
191 | crop_width = crop_box[2] - crop_box[0]
192 | else:
193 | crop_height = h
194 | crop_width = w
195 |
196 | data[0] = np.clip(data[0] / crop_height * self.target_image_size[0], 0, self.target_image_size[0] - 1)
197 | data[1] = np.clip(data[1] / crop_width * self.target_image_size[1], 0, self.target_image_size[1] - 1)
198 |
199 | depth = np.zeros(self.target_image_size)
200 | depth[np.around(data[0]).astype(np.int), np.around(data[1]).astype(np.int)] = data[2]
201 |
202 | return torch.tensor(depth, dtype=torch.float32)
203 |
204 | def __getitem__(self, index: int):
205 | dataset_index, index = self.get_dataset_index(index)
206 | if dataset_index is None:
207 | raise IndexError()
208 |
209 | if self.use_index_mask is not None:
210 | index = self._indices[dataset_index][index] - self._offset
211 |
212 | sequence_folder = self.dataset_dir / "sequences" / self.sequences[dataset_index]
213 | depth_folder = sequence_folder / self.depth_folder
214 |
215 | if self.use_color_augmentation:
216 | self.color_transform.fix_transform()
217 |
218 | dataset = self._datasets[dataset_index]
219 | keyframe_intrinsics = self._intrinsics[dataset_index]
220 | if not (self.lidar_depth or self.dso_depth):
221 | keyframe_depth = self.preprocess_depth(np.load(depth_folder / f"{(index + self._offset):06d}.npy"), self._depth_crop_boxes[dataset_index]).type(torch.float32).unsqueeze(0)
222 | else:
223 | if self.lidar_depth:
224 | if not self.annotated_lidar:
225 | lidar_depth = 1 / torch.tensor(sparse.load_npz(depth_folder / f"{(index + self._offset):06d}.npz").todense()).type(torch.float32).unsqueeze(0)
226 | lidar_depth[torch.isinf(lidar_depth)] = 0
227 | keyframe_depth = lidar_depth
228 | else:
229 | keyframe_depth = self.preprocess_depth_annotated_lidar(Image.open(depth_folder / f"{(index + self._offset):06d}.png"), self._crop_boxes[dataset_index]).unsqueeze(0)
230 | else:
231 | keyframe_depth = torch.zeros(1, self.target_image_size[0], self.target_image_size[1], dtype=torch.float32)
232 |
233 | if self.dso_depth:
234 | dso_depth = self.preprocess_depth_dso(Image.open(depth_folder / f"{(index + self._offset):06d}.png"), self.dso_depth_parameters[dataset_index], self._crop_boxes[dataset_index]).unsqueeze(0)
235 | mask = dso_depth == 0
236 | dso_depth[mask] = keyframe_depth[mask]
237 | keyframe_depth = dso_depth
238 |
239 | keyframe = self.preprocess_image(
240 | (dataset.get_cam0 if not self.use_color else dataset.get_cam2)(index + self._offset),
241 | self._crop_boxes[dataset_index])
242 | keyframe_pose = torch.tensor(dataset.poses[index + self._offset], dtype=torch.float32)
243 |
244 | frames = [self.preprocess_image((dataset.get_cam0 if not self.use_color else dataset.get_cam2)(index + self._offset + i + self.offset_d),
245 | self._crop_boxes[dataset_index]) for i in
246 | range(-(self.frame_count // 2) * self.dilation, ((self.frame_count + 1) // 2) * self.dilation + 1, self.dilation) if i != 0]
247 | intrinsics = [self._intrinsics[dataset_index] for _ in range(self.frame_count)]
248 | poses = [torch.tensor(dataset.poses[index + self._offset + i + self.offset_d], dtype=torch.float32) for i in
249 | range(-(self.frame_count // 2) * self.dilation, ((self.frame_count + 1) // 2) * self.dilation + 1, self.dilation) if i != 0]
250 |
251 | data = {
252 | "keyframe": keyframe,
253 | "keyframe_pose": keyframe_pose,
254 | "keyframe_intrinsics": keyframe_intrinsics,
255 | "frames": frames,
256 | "poses": poses,
257 | "intrinsics": intrinsics,
258 | "sequence": torch.tensor([int(self.sequences[dataset_index])], dtype=torch.int32),
259 | "image_id": torch.tensor([int(index + self._offset)], dtype=torch.int32)
260 | }
261 |
262 | if self.return_mvobj_mask > 0:
263 | mask = torch.tensor(np.load(sequence_folder / "mvobj_mask" / f"{index + self._offset:06d}.npy"), dtype=torch.float32).unsqueeze(0)
264 | data["mvobj_mask"] = mask
265 |
266 | return data, keyframe_depth
267 |
268 | def __len__(self) -> int:
269 | return self.length
270 |
271 | def compute_depth_crop(self, depth_folder):
272 | # This function is only used for dense gt depth maps.
273 | example_dm = np.load(depth_folder / "000000.npy")
274 | ry = example_dm.shape[0] / self.target_image_size[0]
275 | rx = example_dm.shape[1] / self.target_image_size[1]
276 | if ry < 1 or rx < 1:
277 | if ry >= rx:
278 | o_w = example_dm.shape[1]
279 | w = int(np.ceil(ry * self.target_image_size[1]))
280 | h = example_dm.shape[0]
281 | return ((o_w - w) // 2, 0, (o_w - w) // 2 + w, h)
282 | else:
283 | o_h = example_dm.shape[0]
284 | h = int(np.ceil(rx * self.target_image_size[0]))
285 | w = example_dm.shape[1]
286 | return (0, (o_h - h) // 2, w, (o_h - h) // 2 + h)
287 | if ry >= rx:
288 | o_h = example_dm.shape[0]
289 | h = rx * self.target_image_size[0]
290 | w = example_dm.shape[1]
291 | return (0, (o_h - h) // 2, w, (o_h - h) // 2 + h)
292 | else:
293 | o_w = example_dm.shape[1]
294 | w = ry * self.target_image_size[1]
295 | h = example_dm.shape[0]
296 | return ((o_w - w) // 2, 0, (o_w - w) // 2 + w, h)
297 |
298 | def compute_target_intrinsics(self, dataset, target_image_size, use_color):
299 | # Because of cropping and resizing of the frames, we need to recompute the intrinsics
300 | P_cam = dataset.calib.P_rect_00 if not use_color else dataset.calib.P_rect_20
301 | orig_size = tuple(reversed((dataset.cam0 if not use_color else dataset.cam2).__next__().size))
302 |
303 | r_orig = orig_size[0] / orig_size[1]
304 | r_target = target_image_size[0] / target_image_size[1]
305 |
306 | if r_orig >= r_target:
307 | new_height = r_target * orig_size[1]
308 | box = (0, (orig_size[0] - new_height) // 2, orig_size[1], orig_size[0] - (orig_size[0] - new_height) // 2)
309 |
310 | c_x = P_cam[0, 2] / orig_size[1]
311 | c_y = (P_cam[1, 2] - (orig_size[0] - new_height) / 2) / new_height
312 |
313 | rescale = orig_size[1] / target_image_size[1]
314 |
315 | else:
316 | new_width = orig_size[0] / r_target
317 | box = ((orig_size[1] - new_width) // 2, 0, orig_size[1] - (orig_size[1] - new_width) // 2, orig_size[0])
318 |
319 | c_x = (P_cam[0, 2] - (orig_size[1] - new_width) / 2) / new_width
320 | c_y = P_cam[1, 2] / orig_size[0]
321 |
322 | rescale = orig_size[0] / target_image_size[0]
323 |
324 | f_x = P_cam[0, 0] / target_image_size[1] / rescale
325 | f_y = P_cam[1, 1] / target_image_size[0] / rescale
326 |
327 | intrinsics = (f_x, f_y, c_x, c_y)
328 |
329 | return intrinsics, box
330 |
331 | def get_dso_depth_parameters(self, dataset):
332 | # Info required to process d(v)so depths
333 | P_cam = dataset.calib.P_rect_20
334 | orig_size = tuple(reversed(dataset.cam2.__next__().size))
335 | return orig_size[0], orig_size[1], P_cam[0, 0]
336 |
337 | def get_index(self, sequence, index):
338 | for i in range(len(self.sequences)):
339 | if int(self.sequences[i]) != sequence:
340 | index += self._dataset_sizes[i]
341 | else:
342 | break
343 | return index
344 |
345 |
346 | def format_intrinsics(intrinsics, target_image_size):
347 | intrinsics_mat = torch.zeros(4, 4)
348 | intrinsics_mat[0, 0] = intrinsics[0] * target_image_size[1]
349 | intrinsics_mat[1, 1] = intrinsics[1] * target_image_size[0]
350 | intrinsics_mat[0, 2] = intrinsics[2] * target_image_size[1]
351 | intrinsics_mat[1, 2] = intrinsics[3] * target_image_size[0]
352 | intrinsics_mat[2, 2] = 1
353 | intrinsics_mat[3, 3] = 1
354 | return intrinsics_mat
355 |
356 |
357 |
358 | # if torch.__version__ == "1.5.0":
359 | torchvision_version = torchvision.__version__.split(".")[1]
360 |
361 | if int(torchvision_version) < 9:
362 | class ColorJitterMulti(torchvision.transforms.ColorJitter):
363 | def fix_transform(self):
364 | self.transform = self.get_params(self.brightness, self.contrast,
365 | self.saturation, self.hue)
366 |
367 | def __call__(self, x):
368 | return map_fn(x, self.transform)
369 | else:
370 | class ColorJitterMulti(torchvision.transforms.ColorJitter):
371 | def fix_transform(self):
372 | self.params = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
373 |
374 | def __call__(self, img):
375 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.params
376 | for fn_id in fn_idx:
377 | if fn_id == 0 and brightness_factor is not None:
378 | img = F.adjust_brightness(img, brightness_factor)
379 | elif fn_id == 1 and contrast_factor is not None:
380 | img = F.adjust_contrast(img, contrast_factor)
381 | elif fn_id == 2 and saturation_factor is not None:
382 | img = F.adjust_saturation(img, saturation_factor)
383 | elif fn_id == 3 and hue_factor is not None:
384 | img = F.adjust_hue(img, hue_factor)
385 | return img
--------------------------------------------------------------------------------
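Each `__getitem__` call returns a `(data, target)` pair. A sketch of the structure, assuming the default `target_image_size=(256, 512)` and `frame_count=2`:
```python
from data_loader.kitti_odometry_dataset import KittiOdometryDataset

dataset = KittiOdometryDataset(
    dataset_dir="./data/dataset/",
    sequences=["00"],
    lidar_depth=True,
    dso_depth=False,
)
data, depth = dataset[0]
print(sorted(data.keys()))
# ['frames', 'image_id', 'intrinsics', 'keyframe', 'keyframe_intrinsics',
#  'keyframe_pose', 'poses', 'sequence']
print(data["keyframe"].shape, depth.shape)  # torch.Size([3, 256, 512]) torch.Size([1, 256, 512])
```
With `return_mvobj_mask=True`, an additional `mvobj_mask` entry (the moving-object mask as stored on disk, with a leading channel dimension) is included.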
/data_loader/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/data_loader/scripts/__init__.py
--------------------------------------------------------------------------------
/data_loader/scripts/preprocess_kitti_transfer_gtdepth_to_odom.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import shutil
3 | from pathlib import Path
4 | from zipfile import ZipFile
5 |
6 | mapping = {
7 | "2011_10_03_drive_0027": "00",
8 | "2011_10_03_drive_0042": "01",
9 | "2011_10_03_drive_0034": "02",
10 | "2011_09_26_drive_0067": "03",
11 | "2011_09_30_drive_0016": "04",
12 | "2011_09_30_drive_0018": "05",
13 | "2011_09_30_drive_0020": "06",
14 | "2011_09_30_drive_0027": "07",
15 | "2011_09_30_drive_0028": "08",
16 | "2011_09_30_drive_0033": "09",
17 | "2011_09_30_drive_0034": "10"
18 | }
19 |
20 | def main():
21 | parser = argparse.ArgumentParser(description='''
22 | This script creates depth images from annotated velodyne data.
23 | ''')
24 | parser.add_argument("--output", "-o", help="Path of KITTI odometry dataset", default="../../../data/dataset")
25 | parser.add_argument("--input", "-i", help="Path to KITTI depth dataset (zipped)", required=True)
26 | parser.add_argument("--depth_folder", "-d", help="Name of depth map folders for the respective sequences", default="image_depth_annotated")
27 |
28 | args = parser.parse_args()
29 | input = Path(args.input)
30 | output = Path(args.output)
31 | depth_folder = args.depth_folder
32 |
33 | drives = mapping.keys()
34 |
35 | print("Creating folder structure")
36 | for drive in drives:
37 | sequence = mapping[drive]
38 | folder = output/ "sequences" / sequence / depth_folder
39 | folder.mkdir(parents=True, exist_ok=True)
40 | print(folder)
41 |
42 | print("Extracting enhanced depth maps")
43 |
44 | with ZipFile(input) as depth_archive:
45 | for name in depth_archive.namelist():
46 | if name[0] == "t":
47 | drive = name[6:27]
48 | else:
49 | drive = name[4:25]
50 | cam = name[-16]
51 | img = name[-10:]
52 |
53 | # the first frame of seq-08 corresponds to frame 1100 of the KITTI raw sequence
54 | if drive=='2011_09_30_drive_0028':
55 | raw_img_id = img.split('.')[0]
56 | raw_img_id = int(raw_img_id)
57 | if raw_img_id < 1100:
58 | continue
59 | else:
60 | img = "{:06d}.png".format(raw_img_id - 1100)
61 |
62 | if cam == '2' and drive in drives:
63 | to = output / "sequences" / mapping[drive] / depth_folder / img
64 | print(name, " -> ", to)
65 | with depth_archive.open(name) as i, open(to, 'wb') as o:
66 | shutil.copyfileobj(i, o)
67 |
68 | main()
--------------------------------------------------------------------------------
/depth_proc_tools/plot_depth_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL.Image as pil
4 | import matplotlib as mpl
5 | import matplotlib.cm as cm
6 | import matplotlib.pyplot as plt
7 | import cv2
8 | from PIL import Image
9 | # import jax
10 |
11 | def save_pseudo_color_depth(depth:np.ndarray, colormap='rainbow', save_path="./"):
12 | return True
13 |
14 | def disp_to_depth_metric(disp, min_depth=0.1, max_depth=100.0):
15 | """Convert network's sigmoid output into depth prediction, ref: Monodepth2
16 | """
17 | min_disp = 1 / max_depth
18 | max_disp = 1 / min_depth
19 | scaled_disp = min_disp + (max_disp - min_disp) * disp
20 | depth = 1 / scaled_disp
21 | return scaled_disp, depth
22 |
23 |
24 |
25 | def save_color_imgs(image:np.ndarray, save_id=None, post_fix="img", save_path="./"):
26 | '''
27 | image: with shape (h,w,c=3)
28 | save_id = specify the name of the saved image
29 | '''
30 | if not isinstance(image, np.ndarray):
31 | raise Exception("Input image is not a np.ndarray")
32 | if not len(image.shape) == 3:
33 | raise Exception("Wong input shape. It should be a 3-dim image vector of shape (h,w,c)")
34 |
35 | if save_id is None:
36 | dirnames = os.listdir(save_path)
37 | save_id = len(dirnames)
38 |
39 | save_name = os.path.join(save_path,"{}_{}.jpg".format(save_id,post_fix))
40 |
41 | # for pytorch
42 | if image.shape[-1]==3 or image.shape[-1]==1:
43 | plt.imsave(save_name, image)
44 | else:
45 | raise Exception("invalid color channel of the last dim")
46 |
47 | print(f"successfully saved {save_name}!")
48 |
49 |
50 | def save_pseudo_color(input:np.ndarray, save_id=None, post_fix="error", pseudo_color="rainbow", save_path="./", vmax=None):
51 | '''
52 | input: with shape (h,w,c=1)
53 | save_id = specify the name of the saved error map
54 | '''
55 | if not isinstance(input, np.ndarray):
56 | raise Exception("Input input is not a np.ndarray")
57 | if not len(input.shape) == 3:
58 | raise Exception("Wong input shape. It should be a 3-dim input vector of shape (h,w,c)")
59 | if save_id is None:
60 | dirnames = os.listdir(save_path)
61 | save_id = len(dirnames)
62 |
63 | save_name = os.path.join(save_path,"{}_{}.jpg".format(save_id,post_fix))
64 |
65 | # for pytorch
66 | if input.shape[-1]==1:
67 | disp_resized_np = input.squeeze(-1)
68 | # if "error_" in post_fix:
69 | # print("save/photomtric error {} map:{},max:{},min:{},mean:{}".format(post_fix, disp_resized_np[:20,0],disp_resized_np.max(),disp_resized_np.min(),
70 | # disp_resized_np.mean()))
71 | vmax = np.percentile(disp_resized_np, 95) if vmax is None else vmax
72 | normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax)
73 | mapper = cm.ScalarMappable(norm=normalizer, cmap=pseudo_color)
74 | colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8)
75 | im = pil.fromarray(colormapped_im)
76 | im.save(save_name)
77 | else:
78 | raise Exception("invalid color channel of the last dim")
79 |
80 | print(f"successfully saved {save_name}!")
81 |
82 |
83 | def numpy_intensitymap_to_pcolor(input, vmin=None, vmax=None, colormap='rainbow'):
84 | '''
85 | input: h,w,1
86 | '''
87 | if input.shape[-1]==1:
88 | colormapped_im = numpy_1d_to_coloruint8(input, vmin, vmax, colormap)
89 | im = pil.fromarray(colormapped_im.astype(np.uint8))
90 | return im
91 | else:
92 | raise Exception("invalid color channel of the last dim")
93 |
94 |
95 | def numpy_1d_to_coloruint8(input, vmin=None, vmax=None, colormap='rainbow'):
96 | '''
97 | input: h,w,1
98 | '''
99 | if input.shape[-1]==1:
100 | input = input.squeeze(-1)
101 | invalid_mask = (input == 0).astype(float)
102 | vmax = np.percentile(input, 95) if vmax is None else vmax
103 | vmin = 1e-3 if vmin is None else vmin # vmin = input.min() if vmin is None else vmin
104 | normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
105 | mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap)
106 | colormapped_im = (mapper.to_rgba(input)[:, :, :3] * 255).astype(np.uint8)
107 | invalid_mask = np.expand_dims(invalid_mask,-1)
108 | colormapped_im = colormapped_im * (1-invalid_mask) + (invalid_mask * 255)
109 | return colormapped_im
110 | else:
111 | raise Exception("invalid color channel of the last dim")
112 |
113 |
114 | def numpy_distancemap_validmask_to_pil(input, vmax=None, colormap='rainbow'):
115 | if input.shape[-1]==1:
116 |         # For a PIL image, squeeze the additional dim so the array shape is strictly (h, w)
117 | input = input.squeeze(-1)
118 | validmask = (input != 0).astype(float)
119 | mask_img = (validmask*255.0).astype(np.uint8)
120 | im = pil.fromarray(mask_img)
121 | return im
122 | else:
123 | raise Exception("invalid color channel of the last dim")
124 |
125 | def numpy_rgb_to_pil(input):
126 | if input.shape[-1]==3:
127 | if input.max()<=1:
128 | colormapped_im = (input[:, :, :3] * 255).astype(np.uint8)
129 | im = pil.fromarray(colormapped_im)
130 | else:
131 | im = pil.fromarray(input)
132 | return im
133 | else:
134 | raise Exception("invalid color channel of the last dim")
135 |
136 |
137 | def get_error_map_value(pred_depth, gt_depth, grag_crop=True, median_scaling=True):
138 | '''
139 | input shape: h,w,c
140 | '''
141 | validmask = (gt_depth!=0).astype(bool)
142 | h,w,_ = gt_depth.shape
143 |
144 |     if grag_crop:  # standard Garg crop region used in KITTI depth evaluation
145 | valid_area = (int(0.40810811 * h), int(0.99189189 * h), int(0.03594771 * w), int(0.96405229 * w))
146 | area_mask = np.zeros(gt_depth.shape)
147 | area_mask[valid_area[0]:valid_area[1],valid_area[2]:valid_area[3],:] = 1.0
148 | validmask = (validmask * area_mask).astype(bool)
149 |
150 | if median_scaling:
151 | pred_median = np.median(pred_depth[validmask])
152 | gt_median = np.median(gt_depth[validmask])
153 | ratio = gt_median/pred_median
154 | pred_depth_rescale = pred_depth * ratio
155 | else:
156 | pred_depth_rescale = pred_depth
157 |
158 | absrel_map = np.zeros(gt_depth.shape)
159 | absrel_map[validmask] = np.abs(gt_depth[validmask]-pred_depth_rescale[validmask]) / gt_depth[validmask]
160 |
161 | absrel_val = absrel_map.sum() / validmask.sum()
162 |
163 | return absrel_map, absrel_val
164 |
165 |
166 | # def save_concat_res(out_path, pred_disp, img, gt_depth):
167 | # '''
168 | # input shape: h,w,c
169 | # output concatenated img-gt-depth
170 | # '''
171 | # h,w,_ = gt_depth.shape
172 |
173 | # if img.shape[1]!=gt_depth.shape[1]:
174 | # img = jax.image.resize(img, (gt_depth.shape[0],gt_depth.shape[1], 3),"bilinear")
175 | # pred_disp = jax.image.resize(pred_disp, (gt_depth.shape[0],gt_depth.shape[1], 1),"bilinear")
176 |
177 | # pred_disp = np.array(pred_disp)
178 | # img = np.array(img)
179 | # gt_depth = np.array(gt_depth).squeeze(-1)
180 |
181 | # kernel = np.ones((5, 5), np.uint8)
182 | # gt_depth = cv2.dilate(gt_depth, kernel, iterations=1)
183 |
184 | # gt_depth = np.expand_dims(gt_depth,-1)
185 |
186 | # # get pil outputs
187 | # _, pred_depth = disp_to_depth_metric(pred_disp,min_depth=0.001,max_depth=80.0)
188 |
189 | # error_map, error_val = get_error_map_value(pred_depth, gt_depth)
190 |
191 | # error_pil = numpy_intensitymap_to_pcolor(error_map, vmin=0, vmax=0.5,colormap='jet')
192 | # pred_pil = numpy_intensitymap_to_pcolor(pred_depth)
193 | # gt_pil = numpy_intensitymap_to_pcolor(gt_depth)
194 | # img_pil = numpy_rgb_to_pil(img)
195 |
196 |
197 | # save_id = len(os.listdir(out_path))
198 | # save_name = os.path.join(out_path,"{}.png".format(save_id))
199 |
200 | # dst = Image.new('RGB', (w, h*4))
201 | # dst.paste(img_pil, (0, 0))
202 | # dst.paste(pred_pil, (0, h))
203 | # dst.paste(gt_pil, (0, 2*h))
204 | # dst.paste(error_pil, (0, 3*h))
205 |
206 |
207 |
208 | # dst.save(save_name)
209 |
210 |
211 | def directly_save_intensitymap(input, out_path, save_id=None, post_fix="error"):
212 | '''
213 | input shape: h,w
214 | '''
215 | im = Image.fromarray(input)
216 | im.save(os.path.join(out_path, "{:06}_{}.png".format(save_id,post_fix)))
217 |
--------------------------------------------------------------------------------
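
A minimal usage sketch for the helpers above, on dummy arrays (shapes and values made up for illustration, and assuming the repository root is on `PYTHONPATH`): convert a sigmoid disparity map to metric depth, compute the abs-rel error map against a sparse ground-truth map, and render both as pseudo-colored images.

```python
import numpy as np

from depth_proc_tools.plot_depth_utils import (
    disp_to_depth_metric, get_error_map_value, numpy_intensitymap_to_pcolor)

# dummy (h, w, 1) inputs, as the helpers expect
h, w = 64, 128
pred_disp = np.random.rand(h, w, 1).astype(np.float32)             # sigmoid output in [0, 1]
gt_depth = np.random.uniform(0.0, 80.0, (h, w, 1)).astype(np.float32)
gt_depth[gt_depth < 40.0] = 0.0                                     # emulate a sparse ground truth

# sigmoid disparity -> metric depth (Monodepth2-style rescaling)
_, pred_depth = disp_to_depth_metric(pred_disp, min_depth=0.1, max_depth=100.0)

# abs-rel error map, restricted to valid (non-zero) ground-truth pixels
absrel_map, absrel_val = get_error_map_value(pred_depth, gt_depth)
print(f"abs-rel: {absrel_val:.3f}")

# pseudo-colored PIL images for quick visual inspection
numpy_intensitymap_to_pcolor(pred_depth, colormap='rainbow').save("pred_depth_vis.jpg")
numpy_intensitymap_to_pcolor(absrel_map, vmin=0, vmax=0.5, colormap='jet').save("absrel_vis.jpg")
```
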
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import pathlib
4 | from pathlib import Path
5 |
6 | import numpy as np
7 |
8 | import data_loader.data_loaders as module_data
9 | import model.loss as module_loss
10 | import model.metric as module_metric
11 | import model.model as module_arch
12 | from evaluater import Evaluater
13 | from utils.parse_config import ConfigParser
14 | import torch
15 |
16 |
17 | torch.backends.cuda.matmul.allow_tf32 = False
18 | torch.backends.cudnn.allow_tf32 = False
19 | torch.backends.cudnn.benchmark = True
20 |
21 |
22 | def main(config: ConfigParser):
23 | logger = config.get_logger('train')
24 |
25 | # setup data_loader instances
26 | data_loader = config.initialize('data_loader', module_data)
27 |
28 | # get function handles of loss and metrics
29 | loss = getattr(module_loss, config['loss'])
30 | metrics = [getattr(module_metric, met) for met in config['metrics']]
31 |
32 | # build model architecture, then print to console
33 |
34 | if "arch" in config.config:
35 | models = [config.initialize('arch', module_arch)]
36 | else:
37 | models = config.initialize_list("models", module_arch)
38 |
39 | results = []
40 |
41 | for i, model in enumerate(models):
42 | model_dict = dict(model.__dict__)
43 | keys = list(model_dict.keys())
44 | for k in keys:
45 | if k.startswith("_"):
46 | model_dict.__delitem__(k)
47 | elif type(model_dict[k]) == np.ndarray:
48 | model_dict[k] = list(model_dict[k])
49 |
50 |
51 | dataset_dict = dict(data_loader.dataset.__dict__)
52 | keys = list(dataset_dict.keys())
53 | for k in keys:
54 | if k.startswith("_"):
55 | dataset_dict.__delitem__(k)
56 | elif type(dataset_dict[k]) == np.ndarray:
57 | dataset_dict[k] = list(dataset_dict[k])
58 | elif isinstance(dataset_dict[k], pathlib.PurePath):
59 | dataset_dict[k] = str(dataset_dict[k])
60 |
61 |
62 | logger.info(model_dict)
63 | logger.info(dataset_dict)
64 |
65 | logger.info(f"{sum(p.numel() for p in model.parameters())} total parameters")
66 |
67 | evaluater = Evaluater(model, loss, metrics, config=config, data_loader=data_loader)
68 | result = evaluater.eval(i)
69 | result["metrics"] = result["metrics"]
70 | del model
71 | result["metrics_info"] = [metric.__name__ for metric in metrics]
72 | logger.info(result)
73 | results.append({
74 | "model": model_dict,
75 | "dataset": dataset_dict,
76 | "result": result
77 | })
78 |
79 | save_file = Path(config.log_dir) / "results.json"
80 | with open(save_file, "w") as f:
81 | json.dump(results, f, indent=4)
82 | logger.info("Finished")
83 |
84 |
85 | if __name__ == "__main__":
86 | args = argparse.ArgumentParser(description='Deeptam Evaluation')
87 | args.add_argument('-c', '--config', default=None, type=str,
88 | help='config file path (default: None)')
89 | args.add_argument('-d', '--device', default=None, type=str,
90 | help='indices of GPUs to enable (default: all)')
91 | args.add_argument('-r', '--resume', default=None, type=str,
92 | help='path to latest checkpoint (default: None)')
93 | config = ConfigParser(args)
94 | print(config.config)
95 | main(config)
96 |
--------------------------------------------------------------------------------
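
`evaluate.py` dumps one entry per evaluated model (with `model`, `dataset` and `result` fields) to `results.json` inside the run's log directory; the metric values in `result["metrics"]` line up with the names in `result["metrics_info"]`. A small sketch of reading that file back (the path below is a placeholder for an actual log directory):

```python
import json
from pathlib import Path

results_file = Path("saved/log/<run_name>/results.json")   # placeholder path

with open(results_file) as f:
    results = json.load(f)

for entry in results:
    result = entry["result"]
    # pair each metric name with its dataset-averaged value
    for name, value in zip(result["metrics_info"], result["metrics"]):
        print(f"{name}: {value}")
    print(f"loss: {result['loss']}, valid batches: {result['valid_batches']}")
```
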
/evaluater/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluater import *
--------------------------------------------------------------------------------
/evaluater/evaluater.py:
--------------------------------------------------------------------------------
1 | import operator
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from base import BaseTrainer
7 | from utils import operator_on_dict, median_scaling
8 |
9 |
10 | class Evaluater(BaseTrainer):
11 | """
12 |     Evaluater class
13 | 
14 |     Note:
15 |         Inherited from BaseTrainer
16 | """
17 | def __init__(self, model, loss, metrics, config, data_loader):
18 | super().__init__(model, loss, metrics, None, config)
19 | self.config = config
20 | self.data_loader = data_loader
21 | self.log_step = config["evaluater"].get("log_step", int(np.sqrt(data_loader.batch_size)))
22 | self.model = model
23 | self.loss = loss
24 | self.metrics = metrics
25 | self.len_data = len(self.data_loader)
26 |
27 | if isinstance(loss, torch.nn.Module):
28 | self.loss.to(self.device)
29 | if len(self.device_ids) > 1:
30 | self.loss = torch.nn.DataParallel(self.loss, self.device_ids)
31 |
32 | self.roi = config["evaluater"].get("roi", None)
33 | self.alpha = config["evaluater"].get("alpha", None)
34 | self.max_distance = config["evaluater"].get("max_distance", None)
35 | self.correct_length = config["evaluater"].get("correct_length", False)
36 | self.median_scaling = config["evaluater"].get("median_scaling", False)
37 | self.eval_mono = config["evaluater"].get("eval_mono", False)
38 |
39 | def _eval_metrics(self, data_dict):
40 | acc_metrics = np.zeros(len(self.metrics))
41 | acc_metrics_mv = np.zeros(len(self.metrics))
42 | for i, metric in enumerate(self.metrics):
43 | if self.median_scaling:
44 | data_dict = median_scaling(data_dict)
45 | acc_metrics[i] += metric(data_dict, self.roi, self.max_distance, eval_mono=self.eval_mono)
46 | acc_metrics_mv[i] += metric(data_dict, self.roi, self.max_distance, use_cvmask=True, eval_mono = self.eval_mono)
47 | #self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i])
48 | if np.any(np.isnan(acc_metrics)):
49 | acc_metrics = np.zeros(len(self.metrics))
50 | valid = np.zeros(len(self.metrics))
51 | else:
52 | valid = np.ones(len(self.metrics))
53 |
54 | if np.any(np.isnan(acc_metrics_mv)):
55 | acc_metrics_mv = np.zeros(len(self.metrics))
56 | valid_mv = np.zeros(len(self.metrics))
57 | else:
58 | valid_mv = np.ones(len(self.metrics))
59 |
60 | return acc_metrics, valid, acc_metrics_mv, valid_mv
61 |
62 | def eval(self, model_index):
63 | """
64 |         Evaluation logic for a single model
65 | 
66 |         :param model_index: Index of the model being evaluated.
67 | :return: A log that contains all information you want to save.
68 |
69 | Note:
70 | If you have additional information to record, for example:
71 | > additional_log = {"x": x, "y": y}
72 | merge it with log before return. i.e.
73 | > log = {**log, **additional_log}
74 | > return log
75 |
76 | The metrics in log must have the key 'metrics'.
77 | """
78 | self.model.eval()
79 |
80 | total_loss = 0
81 | total_loss_dict = {}
82 | total_metrics = np.zeros(len(self.metrics))
83 | total_metrics_valid = np.zeros(len(self.metrics))
84 |
85 | total_metrics_mv = np.zeros(len(self.metrics))
86 | total_metrics_valid_mv = np.zeros(len(self.metrics))
87 |
88 | total_metrics_runningavg = np.zeros(len(self.metrics))
89 | num_samples = 0
90 |
91 | for batch_idx, (data, target) in enumerate(self.data_loader):
92 | data, target = to(data, self.device), to(target, self.device)
93 | data["target"] = target
94 |
95 | with torch.no_grad():
96 | data = self.model(data)
97 | loss_dict = {"loss": torch.tensor([0])}
98 | loss = loss_dict["loss"]
99 |
100 | output = data["result"]
101 |
102 | #self.writer.set_step((model_index - 1) * self.len_data + batch_idx)
103 | #self.writer.add_scalar('loss', loss.item())
104 | total_loss += loss.item()
105 | total_loss_dict = operator_on_dict(total_loss_dict, loss_dict, operator.add)
106 | metrics, valid, metrics_mv, valid_mv = self._eval_metrics(data)
107 | total_metrics += metrics
108 | total_metrics_valid += valid
109 |
110 | total_metrics_mv += metrics_mv
111 | total_metrics_valid_mv += valid_mv
112 |
113 | batch_size = target.shape[0]
114 | if num_samples == 0:
115 | total_metrics_runningavg += metrics
116 | else:
117 | total_metrics_runningavg = total_metrics_runningavg * (num_samples / (num_samples + batch_size)) + \
118 | metrics * (batch_size / (num_samples + batch_size))
119 | num_samples += batch_size
120 |
121 | if batch_idx % self.log_step == 0:
122 | self.logger.debug(f'Evaluating {self._progress(batch_idx)} Loss: {loss.item() / (batch_idx + 1):.6f} Metrics: {list(total_metrics / (batch_idx + 1))}')
123 | #self.writer.add_image('input', make_grid(to(data["keyframe"], "cpu"), nrow=3, normalize=True))
124 | #self.writer.add_image('output', make_grid(to(torch.clamp(1 / output, 0, 100), "cpu") , nrow=3, normalize=True))
125 | #self.writer.add_image('ground_truth', make_grid(to(torch.clamp(1 / target, 0, 100), "cpu"), nrow=3, normalize=True))
126 |
127 | if batch_idx == self.len_data:
128 | break
129 |
130 | log = {
131 | 'loss': total_loss / self.len_data,
132 | 'metrics': self.save_digits((total_metrics / total_metrics_valid).tolist()),
133 | 'metrics_mv': self.save_digits((total_metrics_mv / total_metrics_valid_mv).tolist()),
134 | 'metrics_correct': self.save_digits(total_metrics_runningavg.tolist()),
135 | 'valid_batches': total_metrics_valid[0],
136 | 'valid_batches_mv': total_metrics_valid_mv[0]
137 | }
138 | for loss_component, v in total_loss_dict.items():
139 | log[f"loss_{loss_component}"] = v.item() / self.len_data
140 |
141 | return log
142 |
143 | def save_digits(self, input_list):
144 | return [float('{:.3f}'.format(i)) for i in input_list]
145 |
146 | def _progress(self, batch_idx):
147 | base = '[{}/{} ({:.0f}%)]'
148 | if hasattr(self.data_loader, 'n_samples'):
149 | current = batch_idx * self.data_loader.batch_size
150 | total = self.data_loader.n_samples
151 | else:
152 | current = batch_idx
153 | total = self.len_data
154 | return base.format(current, total, 100.0 * current / total)
155 |
156 |
157 | def to(data, device):
158 | if isinstance(data, dict):
159 | return {k: to(data[k], device) for k in data.keys()}
160 | elif isinstance(data, list):
161 | return [to(v, device) for v in data]
162 | else:
163 | return data.to(device)
164 |
--------------------------------------------------------------------------------
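
The eval loop keeps two aggregates: a plain per-batch sum (`total_metrics`, later divided by the number of valid batches) and a sample-weighted running average (`total_metrics_runningavg`). Assuming each per-batch metric is already the mean over that batch's samples, the running-average update is just an incremental weighted mean, as this small sketch illustrates:

```python
import numpy as np

# per-batch metric values (assumed per-batch means) and the corresponding batch sizes
batch_metrics = [0.10, 0.20, 0.40]
batch_sizes = [8, 8, 4]

running = 0.0
num_samples = 0
for m, b in zip(batch_metrics, batch_sizes):
    if num_samples == 0:
        running += m
    else:
        running = running * (num_samples / (num_samples + b)) + m * (b / (num_samples + b))
    num_samples += b

# both are ~0.2: the incremental update equals the per-sample weighted mean
print(running)
print(np.average(batch_metrics, weights=batch_sizes))
```
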
/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 | from .visualization import *
--------------------------------------------------------------------------------
/logger/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | from pathlib import Path
4 | from utils import read_json
5 |
6 |
7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
8 | """
9 | Setup logging configuration
10 | """
11 | log_config = Path(log_config)
12 | if log_config.is_file():
13 | config = read_json(log_config)
14 | # modify logging paths based on run config
15 | for _, handler in config['handlers'].items():
16 | if 'filename' in handler:
17 | handler['filename'] = str(save_dir / handler['filename'])
18 |
19 | logging.config.dictConfig(config)
20 | else:
21 |         print("Warning: logging configuration file not found at {}.".format(log_config))
22 | logging.basicConfig(level=default_level)
23 |
--------------------------------------------------------------------------------
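
A minimal usage sketch, assuming it is called from the repository root (so the default `logger/logger_config.json` path resolves) and that the save directory exists for the rotating file handler:

```python
import logging
from pathlib import Path

from logger import setup_logging

save_dir = Path("saved/log/demo")            # placeholder run directory
save_dir.mkdir(parents=True, exist_ok=True)

setup_logging(save_dir)                      # console output + saved/log/demo/info.log
logging.getLogger("trainer").info("logging configured")
```
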
/logger/logger_config.json:
--------------------------------------------------------------------------------
1 |
2 | {
3 | "version": 1,
4 | "disable_existing_loggers": false,
5 | "formatters": {
6 | "simple": {"format": "%(message)s"},
7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
8 | },
9 | "handlers": {
10 | "console": {
11 | "class": "logging.StreamHandler",
12 | "level": "DEBUG",
13 | "formatter": "simple",
14 | "stream": "ext://sys.stdout"
15 | },
16 | "info_file_handler": {
17 | "class": "logging.handlers.RotatingFileHandler",
18 | "level": "INFO",
19 | "formatter": "datetime",
20 | "filename": "info.log",
21 | "maxBytes": 10485760,
22 | "backupCount": 20, "encoding": "utf8"
23 | }
24 | },
25 | "root": {
26 | "level": "INFO",
27 | "handlers": [
28 | "console",
29 | "info_file_handler"
30 | ]
31 | }
32 | }
--------------------------------------------------------------------------------
/logger/visualization.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from utils import Timer
3 |
4 |
5 | class TensorboardWriter():
6 | def __init__(self, log_dir, logger, enabled):
7 | self.writer = None
8 | self.selected_module = ""
9 |
10 | if enabled:
11 | log_dir = str(log_dir)
12 |
13 | # Retrieve vizualization writer.
14 | succeeded = False
15 | for module in ["torch.utils.tensorboard", "tensorboardX"]:
16 | try:
17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir)
18 | succeeded = True
19 | break
20 | except ImportError:
21 | succeeded = False
22 | self.selected_module = module
23 |
24 | if not succeeded:
25 |                 message = "Warning: visualization (Tensorboard) is configured to be used, but is not installed on " \
26 |                     "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch " \
27 |                     "to version >= 1.1 to use 'torch.utils.tensorboard', or turn off the option in " \
28 |                     "the 'config.json' file."
29 | logger.warning(message)
30 |
31 | self.step = 0
32 | self.mode = ''
33 |
34 | self.tb_writer_ftns = {
35 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
36 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
37 | }
38 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
39 |
40 | self.timer = Timer()
41 |
42 | def set_step(self, step, mode='train'):
43 | self.mode = mode
44 | self.step = step
45 | if step == 0:
46 | self.timer.reset()
47 | else:
48 | duration = self.timer.check()
49 | self.add_scalar('steps_per_sec', 1 / duration)
50 |
51 | def __getattr__(self, name):
52 | """
53 |         If visualization is configured to be used:
54 | return add_data() methods of tensorboard with additional information (step, tag) added.
55 | Otherwise:
56 | return a blank function handle that does nothing
57 | """
58 | if name in self.tb_writer_ftns:
59 | add_data = getattr(self.writer, name, None)
60 |
61 | def wrapper(tag, data, *args, **kwargs):
62 | if add_data is not None:
63 | # add mode(train/valid) tag
64 | if name not in self.tag_mode_exceptions:
65 | tag = '{}/{}'.format(tag, self.mode)
66 | add_data(tag, data, self.step, *args, **kwargs)
67 | return wrapper
68 | else:
69 | # default action for returning methods defined in this class, set_step() for instance.
70 | try:
71 |                 attr = object.__getattribute__(self, name)
72 | except AttributeError:
73 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
74 | return attr
75 |
--------------------------------------------------------------------------------
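
A short sketch of how trainer-side code is expected to drive this writer (the log directory is a placeholder). With `enabled=True`, either `torch.utils.tensorboard` or `tensorboardX` must be importable; otherwise every call degrades to a no-op after the warning above.

```python
import logging

from logger import TensorboardWriter

logger = logging.getLogger("trainer")
writer = TensorboardWriter("saved/runs/demo", logger, enabled=True)   # placeholder log dir

for step in range(3):
    writer.set_step(step, mode='train')            # also logs steps_per_sec for step > 0
    writer.add_scalar('loss', 1.0 / (step + 1))    # written under the tag 'loss/train'
```
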
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/model/__init__.py
--------------------------------------------------------------------------------
/model/dymultidepth/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/model/dymultidepth/__init__.py
--------------------------------------------------------------------------------
/model/dymultidepth/ccf_modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torchvision
5 | from torch import nn
6 |
7 |
8 | class CrossCueFusion(nn.Module):
9 | def __init__(self, cv_hypo_num=32, mid_dim=32, input_size=(256,512)):
10 | super().__init__()
11 | self.cv_hypo_num = cv_hypo_num
12 | self.mid_dim = mid_dim
13 | self.residual_connection =True
14 | self.is_reduce = True if input_size[1]>650 else False
15 |
16 | if not self.is_reduce:
17 | self.mono_expand = nn.Sequential(
18 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
19 | nn.BatchNorm2d(self.mid_dim),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
22 | nn.BatchNorm2d(self.mid_dim),
23 | nn.ReLU(inplace=True)
24 | )
25 | self.multi_expand = nn.Sequential(
26 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
27 | nn.BatchNorm2d(self.mid_dim),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
30 | nn.BatchNorm2d(self.mid_dim),
31 | nn.ReLU(inplace=True)
32 | )
33 | else:
34 | self.mono_expand = nn.Sequential(
35 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
36 | nn.BatchNorm2d(self.mid_dim),
37 | nn.ReLU(inplace=True),
38 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
39 | nn.BatchNorm2d(self.mid_dim),
40 | nn.ReLU(inplace=True),
41 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
42 | nn.BatchNorm2d(self.mid_dim),
43 | nn.ReLU(inplace=True)
44 | )
45 |
46 | self.multi_expand = nn.Sequential(
47 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
48 | nn.BatchNorm2d(self.mid_dim),
49 | nn.ReLU(inplace=True),
50 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
51 | nn.BatchNorm2d(self.mid_dim),
52 | nn.ReLU(inplace=True),
53 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
54 | nn.BatchNorm2d(self.mid_dim),
55 | nn.ReLU(inplace=True)
56 | )
57 | self.kq_dim = self.mid_dim //4 if self.mid_dim>128 else self.mid_dim
58 |
59 | self.lin_mono_k = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1)
60 | self.lin_mono_q = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1)
61 | self.lin_mono_v = nn.Conv2d(self.mid_dim, self.mid_dim, kernel_size=1)
62 |
63 | self.lin_multi_k = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1)
64 | self.lin_multi_q = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1)
65 | self.lin_multi_v = nn.Conv2d(self.mid_dim, self.mid_dim, kernel_size=1)
66 |
67 | self.softmax = nn.Softmax(dim=-1)
68 |
69 | if self.residual_connection:
70 | self.mono_reg = nn.Sequential(
71 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1),
72 | nn.BatchNorm2d(self.mid_dim),
73 | nn.ReLU(inplace=True),
74 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1),
75 | nn.BatchNorm2d(self.mid_dim),
76 | nn.ReLU(inplace=True)
77 | )
78 | self.multi_reg = nn.Sequential(
79 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=1, padding=0),
80 | nn.BatchNorm2d(self.mid_dim),
81 | nn.ReLU(inplace=True)
82 | )
83 | self.gamma = nn.Parameter(torch.zeros(1))
84 |
85 | def forward(self, mono_pseudo_cost, cost_volume):
86 | init_b, init_c, init_h, init_w = cost_volume.shape
87 | mono_feat = self.mono_expand(mono_pseudo_cost)
88 | multi_feat = self.multi_expand(cost_volume)
89 | b,c,h,w = multi_feat.shape
90 |
91 | # cross-cue attention
92 | mono_q = self.lin_mono_q(mono_feat).view(b,-1,h*w).permute(0,2,1)
93 | mono_k = self.lin_mono_k(mono_feat).view(b,-1,h*w)
94 | mono_score = torch.bmm(mono_q, mono_k)
95 | mono_atten = self.softmax(mono_score)
96 |
97 | multi_q = self.lin_multi_q(multi_feat).view(b,-1,h*w).permute(0,2,1)
98 | multi_k = self.lin_multi_k(multi_feat).view(b,-1,h*w)
99 | multi_score = torch.bmm(multi_q, multi_k)
100 | multi_atten = self.softmax(multi_score)
101 |
102 | mono_v = self.lin_mono_v(mono_feat).view(b,-1,h*w)
103 | mono_out = torch.bmm(mono_v, multi_atten.permute(0,2,1))
104 | mono_out = mono_out.view(b,self.mid_dim, h,w)
105 |
106 | multi_v = self.lin_multi_v(multi_feat).view(b,-1,h*w)
107 | multi_out = torch.bmm(multi_v, mono_atten.permute(0,2,1))
108 | multi_out = multi_out.view(b,self.mid_dim, h,w)
109 |
110 |
111 | # concatenate and upsample
112 | fused = torch.cat((multi_out,mono_out), dim=1)
113 | fused = torch.nn.functional.interpolate(fused, size=(init_h,init_w))
114 |
115 | if self.residual_connection:
116 | mono_residual = self.mono_reg(mono_pseudo_cost)
117 | multi_residual = self.multi_reg(cost_volume)
118 | fused_cat = torch.cat((mono_residual,multi_residual), dim=1)
119 | fused = fused_cat + self.gamma * fused
120 |
121 | return fused
122 |
123 |
124 |
125 | class MultiGuideMono(nn.Module):
126 | def __init__(self, cv_hypo_num=32, mid_dim=32, input_size=(256,512)):
127 | super().__init__()
128 | self.cv_hypo_num = cv_hypo_num
129 | self.mid_dim = mid_dim
130 | self.residual_connection =True
131 | self.is_reduce = True if input_size[1]>650 else False
132 |
133 | if not self.is_reduce:
134 | self.mono_expand = nn.Sequential(
135 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
136 | nn.BatchNorm2d(self.mid_dim),
137 | nn.ReLU(inplace=True),
138 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
139 | nn.BatchNorm2d(self.mid_dim),
140 | nn.ReLU(inplace=True)
141 | )
142 | self.multi_expand = nn.Sequential(
143 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
144 | nn.BatchNorm2d(self.mid_dim),
145 | nn.ReLU(inplace=True),
146 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
147 | nn.BatchNorm2d(self.mid_dim),
148 | nn.ReLU(inplace=True)
149 | )
150 | else:
151 | self.mono_expand = nn.Sequential(
152 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
153 | nn.BatchNorm2d(self.mid_dim),
154 | nn.ReLU(inplace=True),
155 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
156 | nn.BatchNorm2d(self.mid_dim),
157 | nn.ReLU(inplace=True),
158 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
159 | nn.BatchNorm2d(self.mid_dim),
160 | nn.ReLU(inplace=True)
161 | )
162 |
163 | self.multi_expand = nn.Sequential(
164 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
165 | nn.BatchNorm2d(self.mid_dim),
166 | nn.ReLU(inplace=True),
167 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
168 | nn.BatchNorm2d(self.mid_dim),
169 | nn.ReLU(inplace=True),
170 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
171 | nn.BatchNorm2d(self.mid_dim),
172 | nn.ReLU(inplace=True)
173 | )
174 | self.kq_dim = self.mid_dim //4 if self.mid_dim>128 else self.mid_dim
175 |
176 | self.lin_mono_v = nn.Conv2d(self.mid_dim, self.mid_dim, kernel_size=1)
177 |
178 | self.lin_multi_k = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1)
179 | self.lin_multi_q = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1)
180 |
181 | self.softmax = nn.Softmax(dim=-1)
182 |
183 | if self.residual_connection:
184 | self.mono_reg = nn.Sequential(
185 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1),
186 | nn.BatchNorm2d(self.mid_dim),
187 | nn.ReLU(inplace=True),
188 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1),
189 | nn.BatchNorm2d(self.mid_dim),
190 | nn.ReLU(inplace=True)
191 | )
192 | self.gamma = nn.Parameter(torch.zeros(1))
193 |
194 | def forward(self, mono_pseudo_cost, cost_volume):
195 | init_b, init_c, init_h, init_w = cost_volume.shape
196 | mono_feat = self.mono_expand(mono_pseudo_cost)
197 | multi_feat = self.multi_expand(cost_volume)
198 | b,c,h,w = multi_feat.shape
199 |
200 | # multi attention
201 |
202 | multi_q = self.lin_multi_q(multi_feat).view(b,-1,h*w).permute(0,2,1)
203 | multi_k = self.lin_multi_k(multi_feat).view(b,-1,h*w)
204 | multi_score = torch.bmm(multi_q, multi_k)
205 | multi_atten = self.softmax(multi_score)
206 |
207 | mono_v = self.lin_mono_v(mono_feat).view(b,-1,h*w)
208 | mono_out = torch.bmm(mono_v, multi_atten.permute(0,2,1))
209 | mono_out = mono_out.view(b,self.mid_dim, h,w)
210 |
211 | # upsample
212 | fused = torch.nn.functional.interpolate(mono_out, size=(init_h,init_w))
213 |
214 | if self.residual_connection:
215 | mono_residual = self.mono_reg(mono_pseudo_cost)
216 | fused = mono_residual + self.gamma * fused
217 |
218 | return fused
219 |
220 |
221 | class MonoGuideMulti(nn.Module):
222 | def __init__(self, cv_hypo_num=32, mid_dim=32, input_size=(256,512)):
223 | super().__init__()
224 | self.cv_hypo_num = cv_hypo_num
225 | self.mid_dim = mid_dim
226 | self.residual_connection =True
227 | self.is_reduce = True if input_size[1]>650 else False
228 |
229 | if not self.is_reduce:
230 | self.mono_expand = nn.Sequential(
231 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
232 | nn.BatchNorm2d(self.mid_dim),
233 | nn.ReLU(inplace=True),
234 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
235 | nn.BatchNorm2d(self.mid_dim),
236 | nn.ReLU(inplace=True)
237 | )
238 | self.multi_expand = nn.Sequential(
239 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
240 | nn.BatchNorm2d(self.mid_dim),
241 | nn.ReLU(inplace=True),
242 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
243 | nn.BatchNorm2d(self.mid_dim),
244 | nn.ReLU(inplace=True)
245 | )
246 | else:
247 | self.mono_expand = nn.Sequential(
248 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
249 | nn.BatchNorm2d(self.mid_dim),
250 | nn.ReLU(inplace=True),
251 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
252 | nn.BatchNorm2d(self.mid_dim),
253 | nn.ReLU(inplace=True),
254 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
255 | nn.BatchNorm2d(self.mid_dim),
256 | nn.ReLU(inplace=True)
257 | )
258 |
259 | self.multi_expand = nn.Sequential(
260 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
261 | nn.BatchNorm2d(self.mid_dim),
262 | nn.ReLU(inplace=True),
263 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
264 | nn.BatchNorm2d(self.mid_dim),
265 | nn.ReLU(inplace=True),
266 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2),
267 | nn.BatchNorm2d(self.mid_dim),
268 | nn.ReLU(inplace=True)
269 | )
270 | self.kq_dim = self.mid_dim //4 if self.mid_dim>128 else self.mid_dim
271 |
272 | self.lin_mono_k = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1)
273 | self.lin_mono_q = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1)
274 |
275 | self.lin_multi_v = nn.Conv2d(self.mid_dim, self.mid_dim, kernel_size=1)
276 |
277 | self.softmax = nn.Softmax(dim=-1)
278 |
279 | if self.residual_connection:
280 | self.multi_reg = nn.Sequential(
281 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=1, padding=0),
282 | nn.BatchNorm2d(self.mid_dim),
283 | nn.ReLU(inplace=True)
284 | )
285 | self.gamma = nn.Parameter(torch.zeros(1))
286 |
287 | def forward(self, mono_pseudo_cost, cost_volume):
288 | init_b, init_c, init_h, init_w = cost_volume.shape
289 | mono_feat = self.mono_expand(mono_pseudo_cost)
290 | multi_feat = self.multi_expand(cost_volume)
291 | b,c,h,w = multi_feat.shape
292 |
293 | # mono attention
294 | mono_q = self.lin_mono_q(mono_feat).view(b,-1,h*w).permute(0,2,1)
295 | mono_k = self.lin_mono_k(mono_feat).view(b,-1,h*w)
296 | mono_score = torch.bmm(mono_q, mono_k)
297 | mono_atten = self.softmax(mono_score)
298 |
299 | multi_v = self.lin_multi_v(multi_feat).view(b,-1,h*w)
300 | multi_out = torch.bmm(multi_v, mono_atten.permute(0,2,1))
301 | multi_out = multi_out.view(b,self.mid_dim, h,w)
302 |
303 |
304 | # upsample
305 | fused = torch.nn.functional.interpolate(multi_out, size=(init_h,init_w))
306 |
307 | if self.residual_connection:
308 | multi_residual = self.multi_reg(cost_volume)
309 | fused = multi_residual + self.gamma * fused
310 |
311 | return fused
312 |
--------------------------------------------------------------------------------
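
A quick shape check for the fusion blocks above with their default settings (`cv_hypo_num = mid_dim = 32`), on dummy tensors: both inputs are depth-hypothesis volumes with 32 channels, the attention runs at 1/4 resolution, and the output keeps the input resolution. `CrossCueFusion` returns `2 * mid_dim` channels (matching `depthmodule_in_chn = self.ccf_mid_dim * 2` for the `ccf_fusion` option in `dymultidepth_model.py` below), while the one-directional variants keep `mid_dim` channels. Assumes the repository root is on `PYTHONPATH`.

```python
import torch

from model.dymultidepth.ccf_modules import CrossCueFusion, MonoGuideMulti, MultiGuideMono

b, d, h, w = 2, 32, 64, 128                 # d = number of depth hypotheses
mono_pseudo_cost = torch.rand(b, d, h, w)   # one-hot-like pseudo cost from the mono branch
cost_volume = torch.rand(b, d, h, w)        # multi-view photometric cost volume

fused = CrossCueFusion(cv_hypo_num=d, mid_dim=32)(mono_pseudo_cost, cost_volume)
print(fused.shape)                                                            # [2, 64, 64, 128]
print(MonoGuideMulti(cv_hypo_num=d)(mono_pseudo_cost, cost_volume).shape)     # [2, 32, 64, 128]
print(MultiGuideMono(cv_hypo_num=d)(mono_pseudo_cost, cost_volume).shape)     # [2, 32, 64, 128]
```
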
/model/dymultidepth/dymultidepth_model.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import kornia.augmentation as K
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision
8 | from torch import nn
9 |
10 | from model.layers import point_projection, PadSameConv2d, ConvReLU2, ConvReLU, Upconv, Refine, SSIM, Backprojection
11 | from utils import conditional_flip, filter_state_dict
12 |
13 | from utils import parse_config
14 | from .base_models import DepthAugmentation, MaskAugmentation, CostVolumeModule, DepthModule, ResnetEncoder,MonoDepthModule, EfficientNetEncoder
15 | from .ccf_modules import *
16 |
17 | class DyMultiDepthModel(nn.Module):
18 | def __init__(self, inv_depth_min_max=(0.33, 0.0025), cv_depth_steps=32, pretrain_mode=False, pretrain_dropout=0.0, pretrain_dropout_mode=0,
19 | augmentation=None, use_mono=True, use_stereo=False, use_ssim=True, sfcv_mult_mask=True,
20 | simple_mask=False, mask_use_cv=True, mask_use_feats=True, cv_patch_size=3, depth_large_model=False, no_cv=False,
21 | freeze_backbone=True, freeze_module=(), checkpoint_location=None, mask_cp_loc=None, depth_cp_loc=None,
22 | fusion_type = 'ccf_fusion', input_size=[256, 512], ccf_mid_dim=32, use_img_in_depthnet=True,
23 | backbone_type='resnet18'):
24 | """
25 | :param inv_depth_min_max: Min / max (inverse) depth. (Default=(0.33, 0.0025))
26 | :param cv_depth_steps: Number of depth steps for the cost volume. (Default=32)
27 | :param pretrain_mode: Which pretrain mode to use:
28 | 0 / False: Run full network.
29 | 1 / True: Only run depth module. In this mode, dropout can be activated to zero out patches from the
30 | unmasked cost volume. Dropout was not used for the paper.
31 | 2: Only run mask module. In this mode, the network will return the mask as the main result.
32 | 3: Only run depth module, but use the auxiliary masks to mask the cost volume. This mode was not used in
33 | the paper. (Default=0)
34 | :param pretrain_dropout: Dropout rate used in pretrain_mode=1. (Default=0)
35 | :param augmentation: Which augmentation module to use. "mask"=MaskAugmentation, "depth"=DepthAugmentation. The
36 | exact way to use this is very context dependent. Refer to the training scripts for more details. (Default="none")
37 | :param use_mono: Use monocular frames during the forward pass. (Default=True)
38 | :param use_stereo: Use stereo frame during the forward pass. (Default=False)
39 | :param use_ssim: Use SSIM during cost volume computation. (Default=True)
40 | :param sfcv_mult_mask: For the single frame cost volumes: If a pixel does not have a valid reprojection at any
41 | depth step, all depths get invalidated. (Default=True)
42 | :param simple_mask: Use the standard cost volume instead of multiple single frame cost volumes in the mask
43 | module. (Default=False)
44 | :param cv_patch_size: Patchsize, over which the ssim errors get averaged. (Default=3)
45 | :param freeze_module: Freeze given string list of modules. (Default=())
46 | :param checkpoint_location: Load given list of checkpoints. (Default=None)
47 | :param mask_cp_loc: Load list of checkpoints for the mask module. (Default=None)
48 | :param depth_cp_loc: Load list of checkpoints for the depth module. (Default=None)
49 | """
50 | super().__init__()
51 | self.inv_depth_min_max = inv_depth_min_max
52 | self.cv_depth_steps = cv_depth_steps
53 | self.use_mono = use_mono
54 | self.use_stereo = use_stereo
55 | self.use_ssim = use_ssim
56 | self.sfcv_mult_mask = sfcv_mult_mask
57 | self.pretrain_mode = int(pretrain_mode)
58 | self.pretrain_dropout = pretrain_dropout
59 | self.pretrain_dropout_mode = pretrain_dropout_mode
60 | self.augmentation = augmentation
61 | self.simple_mask = simple_mask
62 | self.mask_use_cv = mask_use_cv
63 | self.mask_use_feats = mask_use_feats
64 | self.cv_patch_size = cv_patch_size
65 | self.no_cv = no_cv
66 | self.depth_large_model = depth_large_model
67 | self.checkpoint_location = checkpoint_location
68 | self.mask_cp_loc = mask_cp_loc
69 | self.depth_cp_loc = depth_cp_loc
70 | self.freeze_module = freeze_module
71 | self.freeze_backbone = freeze_backbone
72 |
73 | self.fusion_type = fusion_type
74 | self.input_size = input_size
75 | self.ccf_mid_dim = ccf_mid_dim
76 | self.use_img_in_depthnet = use_img_in_depthnet
77 | self.backbone_type = backbone_type
78 |
79 | assert self.backbone_type in ["resnet18", "efficientnetb5"]
80 |
81 | self.depthmodule_in_chn = self.cv_depth_steps
82 | if fusion_type == 'ccf_fusion':
83 | self.extra_input_dim = 0
84 | self.fusion_module = CrossCueFusion(cv_hypo_num=self.cv_depth_steps, mid_dim=32, input_size=self.input_size)
85 | self.depthmodule_in_chn = self.ccf_mid_dim * 2
86 | elif fusion_type == 'mono_guide_multi':
87 | self.extra_input_dim = 0
88 | self.fusion_module = MonoGuideMulti(cv_hypo_num=self.cv_depth_steps, mid_dim=32, input_size=self.input_size)
89 | self.depthmodule_in_chn = self.ccf_mid_dim
90 | elif fusion_type == 'multi_guide_mono':
91 | self.extra_input_dim = 0
92 | self.fusion_module = MultiGuideMono(cv_hypo_num=self.cv_depth_steps, mid_dim=32, input_size=self.input_size)
93 | self.depthmodule_in_chn = self.ccf_mid_dim
94 |
95 | if self.backbone_type == 'resnet18':
96 | self._feature_extractor = ResnetEncoder(num_layers=18, pretrained=True)
97 | elif self.backbone_type == 'efficientnetb5':
98 | self._feature_extractor = EfficientNetEncoder(pretrained=True)
99 |
100 | if self.freeze_backbone:
101 | for p in self._feature_extractor.parameters(True):
102 | p.requires_grad_(False)
103 |
104 | self.cv_module = CostVolumeModule(use_mono=use_mono, use_stereo=use_stereo, use_ssim=use_ssim, sfcv_mult_mask=self.sfcv_mult_mask, patch_size=cv_patch_size)
105 |
106 | self.depth_module = DepthModule(self.depthmodule_in_chn, feature_channels=self._feature_extractor.num_ch_enc,
107 | large_model=self.depth_large_model, use_input_img=self.use_img_in_depthnet)
108 | self.mono_module = MonoDepthModule(extra_input_dim=self.extra_input_dim,
109 | feature_channels=self._feature_extractor.num_ch_enc, large_model=self.depth_large_model)
110 |
111 | if self.checkpoint_location is not None:
112 | if not isinstance(checkpoint_location, list):
113 | checkpoint_location = [checkpoint_location]
114 | for cp in checkpoint_location:
115 | checkpoint = torch.load(cp, map_location=torch.device("cpu"))
116 | checkpoint_state_dict = checkpoint["state_dict"]
117 | checkpoint_state_dict = filter_state_dict(checkpoint_state_dict, checkpoint["arch"] == "DataParallel")
118 | self.load_state_dict(checkpoint_state_dict, strict=True)
119 |
120 |
121 | for module_name in self.freeze_module:
122 | module = self.__getattr__(module_name + "_module")
123 | module.eval()
124 | for param in module.parameters(True):
125 | param.requires_grad_(False)
126 |
127 | if self.augmentation == "depth":
128 | self.augmenter = DepthAugmentation()
129 | elif self.augmentation == "mask":
130 | self.augmenter = MaskAugmentation()
131 | else:
132 | self.augmenter = None
133 |
134 | def forward(self, data_dict):
135 | keyframe = data_dict["keyframe"]
136 |
137 | data_dict["inv_depth_min"] = keyframe.new_tensor([self.inv_depth_min_max[0]])
138 | data_dict["inv_depth_max"] = keyframe.new_tensor([self.inv_depth_min_max[1]])
139 | data_dict["cv_depth_steps"] = keyframe.new_tensor([self.cv_depth_steps], dtype=torch.int32)
140 |
141 | with torch.no_grad():
142 | data_dict = self.cv_module(data_dict)
143 |
144 | if self.augmenter is not None and self.training:
145 | self.augmenter(data_dict)
146 |
147 |         # different from MonoRec: shift the zero-centered keyframe back by +0.5 before feature extraction
148 | data_dict["image_features"] = self._feature_extractor(data_dict["keyframe"] + .5)
149 |
150 | data_dict["cost_volume_init"] = data_dict["cost_volume"]
151 |
152 | data_dict = self.mono_module(data_dict)
153 | data_dict["predicted_inverse_depths_mono"] = [(1-pred) * self.inv_depth_min_max[1] + pred * self.inv_depth_min_max[0]
154 | for pred in data_dict["predicted_inverse_depths_mono"]]
155 | mono_depth_pred = torch.clamp(1.0 / data_dict["predicted_inverse_depths_mono"][0], min=1e-3, max=80.0).detach()
156 |
157 | b, c, h, w = keyframe.shape
158 |
159 |
160 | pseudo_mono_cost = self.pseudocost_from_mono(mono_depth_pred,
161 | depth_hypothesis = data_dict["cv_bin_steps"].view(1, -1, 1, 1).expand(b, -1, h, w).detach()).detach()
162 |
163 |
164 | if self.training:
165 | if self.pretrain_dropout_mode == 0:
166 | cv_mask = keyframe.new_ones(b, 1, h // 8, w // 8, requires_grad=False)
167 | F.dropout(cv_mask, p=1 - self.pretrain_dropout, training=self.training, inplace=True)
168 | cv_mask = (cv_mask!=0).float()
169 |                 cv_mask = F.interpolate(cv_mask, (h, w))
170 | else:
171 | cv_mask = keyframe.new_ones(b, 1, 1, 1, requires_grad=False)
172 | F.dropout(cv_mask, p = 1 - self.pretrain_dropout, training=self.training, inplace=True)
173 | cv_mask = cv_mask.expand(-1, -1, h, w)
174 | else:
175 | cv_mask = keyframe.new_zeros(b, 1, h, w, requires_grad=False)
176 | data_dict["cv_mask"] = cv_mask
177 |
178 |
179 | data_dict["cost_volume"] = (1 - data_dict["cv_mask"]) * self.fusion_module(pseudo_mono_cost, data_dict["cost_volume"])
180 |
181 | data_dict = self.depth_module(data_dict)
182 |
183 | data_dict["predicted_inverse_depths"] = [(1-pred) * self.inv_depth_min_max[1] + pred * self.inv_depth_min_max[0]
184 | for pred in data_dict["predicted_inverse_depths"]]
185 |
186 | if self.augmenter is not None and self.training:
187 | self.augmenter.revert(data_dict)
188 |
189 | data_dict["result"] = data_dict["predicted_inverse_depths"][0]
190 | data_dict["result_mono"] = data_dict["predicted_inverse_depths_mono"][0]
191 | data_dict["mask"] = data_dict["cv_mask"]
192 |
193 | return data_dict
194 |
195 |
196 | def pseudocost_from_mono(self, monodepth, depth_hypothesis):
197 | abs_depth_diff = torch.abs(monodepth - depth_hypothesis)
198 | # find the closest depth bin that the monodepth correlate with
199 | min_diff_index = torch.argmin(abs_depth_diff, dim=1, keepdim=True)
200 | pseudo_cost = depth_hypothesis.new_zeros(depth_hypothesis.shape)
201 | ones = depth_hypothesis.new_ones(depth_hypothesis.shape)
202 |
203 | pseudo_cost.scatter_(dim = 1, index = min_diff_index, src = ones)
204 |
205 | return pseudo_cost
206 |
207 | def find_mincost_depth(self, cost_volume, depth_hypos):
208 | argmax = torch.argmax(cost_volume, dim=1, keepdim=True)
209 | mincost_depth = torch.gather(input=depth_hypos, dim=1, index=argmax)
210 | return mincost_depth
--------------------------------------------------------------------------------
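
`pseudocost_from_mono` turns the monocular depth prediction into a one-hot "pseudo cost volume": for every pixel, the depth hypothesis closest to the predicted depth gets a 1 and all others stay 0. A standalone numeric sketch of the same scatter-based construction (it re-implements the method's body on toy tensors; it is not tied to the model class):

```python
import torch

def pseudocost_from_mono(monodepth, depth_hypothesis):
    # monodepth: (B, 1, H, W); depth_hypothesis: (B, D, H, W)
    abs_depth_diff = torch.abs(monodepth - depth_hypothesis)
    min_diff_index = torch.argmin(abs_depth_diff, dim=1, keepdim=True)   # (B, 1, H, W)
    pseudo_cost = depth_hypothesis.new_zeros(depth_hypothesis.shape)
    ones = depth_hypothesis.new_ones(depth_hypothesis.shape)
    pseudo_cost.scatter_(dim=1, index=min_diff_index, src=ones)          # one-hot along D
    return pseudo_cost

# toy example: 4 depth hypotheses, a single 1x2 "image"
hypos = torch.tensor([1.0, 2.0, 4.0, 8.0]).view(1, 4, 1, 1).expand(1, 4, 1, 2)
mono = torch.tensor([1.2, 5.0]).view(1, 1, 1, 2)

print(pseudocost_from_mono(mono, hypos)[0, :, 0, :])
# tensor([[1., 0.],
#         [0., 0.],
#         [0., 1.],
#         [0., 0.]])
```
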
/model/layers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import math
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import functional as F, Conv2d, LeakyReLU, Upsample, Sigmoid, ConvTranspose2d
9 |
10 |
11 | class ConvBlock(nn.Module):
12 | """Layer to perform a convolution followed by ELU
13 | """
14 | def __init__(self, in_channels, out_channels):
15 | super(ConvBlock, self).__init__()
16 |
17 | self.conv = Conv3x3(in_channels, out_channels)
18 | self.nonlin = nn.ELU(inplace=True)
19 |
20 | def forward(self, x):
21 | out = self.conv(x)
22 | out = self.nonlin(out)
23 | return out
24 |
25 |
26 | class Conv3x3(nn.Module):
27 | """Layer to pad and convolve input
28 | """
29 | def __init__(self, in_channels, out_channels, use_refl=True):
30 | super(Conv3x3, self).__init__()
31 |
32 | if use_refl:
33 | self.pad = nn.ReflectionPad2d(1)
34 | else:
35 | self.pad = nn.ZeroPad2d(1)
36 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
37 |
38 | def forward(self, x):
39 | out = self.pad(x)
40 | out = self.conv(out)
41 | return out
42 |
43 | class Backprojection(nn.Module):
44 | def __init__(self, batch_size, height, width):
45 | super(Backprojection, self).__init__()
46 |
47 | self.N, self.H, self.W = batch_size, height, width
48 |
49 | yy, xx = torch.meshgrid([torch.arange(0., float(self.H)), torch.arange(0., float(self.W))])
50 | yy = yy.contiguous().view(-1)
51 | xx = xx.contiguous().view(-1)
52 | self.ones = nn.Parameter(torch.ones(self.N, 1, self.H * self.W), requires_grad=False)
53 | self.coord = torch.unsqueeze(torch.stack([xx, yy], 0), 0).repeat(self.N, 1, 1)
54 | self.coord = nn.Parameter(torch.cat([self.coord, self.ones], 1), requires_grad=False)
55 |
56 | def forward(self, depth, inv_K) :
57 | cam_p_norm = torch.matmul(inv_K[:, :3, :3], self.coord[:depth.shape[0], :, :])
58 | cam_p_euc = depth.view(depth.shape[0], 1, -1) * cam_p_norm
59 | cam_p_h = torch.cat([cam_p_euc, self.ones[:depth.shape[0], :, :]], 1)
60 |
61 | return cam_p_h
62 |
63 | def point_projection(points3D, batch_size, height, width, K, T):
64 | N, H, W = batch_size, height, width
65 | cam_coord = torch.matmul(torch.matmul(K, T)[:, :3, :], points3D)
66 | img_coord = cam_coord[:, :2, :] / (cam_coord[:, 2:3, :] + 1e-7)
67 | img_coord[:, 0, :] /= W - 1
68 | img_coord[:, 1, :] /= H - 1
69 | img_coord = (img_coord - 0.5) * 2
70 | img_coord = img_coord.view(N, 2, H, W).permute(0, 2, 3, 1)
71 | return img_coord
72 |
73 | def upsample(x):
74 | """Upsample input tensor by a factor of 2
75 | """
76 | return F.interpolate(x, scale_factor=2, mode="nearest")
77 |
78 |
79 | class GaussianAverage(nn.Module):
80 | def __init__(self) -> None:
81 | super().__init__()
82 | self.window = torch.Tensor([
83 | [0.0947, 0.1183, 0.0947],
84 | [0.1183, 0.1478, 0.1183],
85 | [0.0947, 0.1183, 0.0947]])
86 |
87 | def forward(self, x):
88 | kernel = self.window.to(x.device).to(x.dtype).repeat(x.shape[1], 1, 1, 1)
89 | return F.conv2d(x, kernel, padding=0, groups=x.shape[1])
90 |
91 | class SSIM(nn.Module):
92 | """Layer to compute the SSIM loss between a pair of images
93 | """
94 | def __init__(self, pad_reflection=True, gaussian_average=False, comp_mode=False):
95 | super(SSIM, self).__init__()
96 | self.comp_mode = comp_mode
97 |
98 | if not gaussian_average:
99 | self.mu_x_pool = nn.AvgPool2d(3, 1)
100 | self.mu_y_pool = nn.AvgPool2d(3, 1)
101 | self.sig_x_pool = nn.AvgPool2d(3, 1)
102 | self.sig_y_pool = nn.AvgPool2d(3, 1)
103 | self.sig_xy_pool = nn.AvgPool2d(3, 1)
104 | else:
105 | self.mu_x_pool = GaussianAverage()
106 | self.mu_y_pool = GaussianAverage()
107 | self.sig_x_pool = GaussianAverage()
108 | self.sig_y_pool = GaussianAverage()
109 | self.sig_xy_pool = GaussianAverage()
110 |
111 | if pad_reflection:
112 | self.pad = nn.ReflectionPad2d(1)
113 | else:
114 | self.pad = nn.ZeroPad2d(1)
115 |
116 | self.C1 = 0.01 ** 2
117 | self.C2 = 0.03 ** 2
118 |
119 | def forward(self, x, y):
120 | x = self.pad(x)
121 | y = self.pad(y)
122 |
123 | mu_x = self.mu_x_pool(x)
124 | mu_y = self.mu_y_pool(y)
125 | mu_x_sq = mu_x ** 2
126 | mu_y_sq = mu_y ** 2
127 | mu_x_y = mu_x * mu_y
128 |
129 | sigma_x = self.sig_x_pool(x ** 2) - mu_x_sq
130 | sigma_y = self.sig_y_pool(y ** 2) - mu_y_sq
131 | sigma_xy = self.sig_xy_pool(x * y) - mu_x_y
132 |
133 | SSIM_n = (2 * mu_x_y + self.C1) * (2 * sigma_xy + self.C2)
134 | SSIM_d = (mu_x_sq + mu_y_sq + self.C1) * (sigma_x + sigma_y + self.C2)
135 |
136 | if not self.comp_mode:
137 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
138 | else:
139 | return torch.clamp((1 - SSIM_n / SSIM_d), 0, 1) / 2
140 |
141 |
142 | def ssim(x, y, pad_reflection=True, gaussian_average=False, comp_mode=False):
143 | ssim_ = SSIM(pad_reflection, gaussian_average, comp_mode)
144 | return ssim_(x, y)
145 |
146 |
147 | class ResidualImage(nn.Module):
148 | def __init__(self):
149 | super().__init__()
150 | self.residual_image = ResidualImageModule()
151 |
152 | def forward(self, keyframe: torch.Tensor, keyframe_pose: torch.Tensor, keyframe_intrinsics: torch.Tensor,
153 | depths: torch.Tensor, frames: list, poses: list, intrinsics: list):
154 | data_dict = {"keyframe": keyframe, "keyframe_pose": keyframe_pose, "keyframe_intrinsics": keyframe_intrinsics,
155 | "predicted_inverse_depths": [depths], "frames": frames, "poses": poses, "list": list,
156 | "intrinsics": intrinsics, "inv_depth_max": 0, "inv_depth_min": 1}
157 | data_dict = self.residual_image(data_dict)
158 | return data_dict["residual_image"]
159 |
160 |
161 | class ResidualImageModule(nn.Module):
162 | def __init__(self, use_mono=True, use_stereo=False):
163 | super().__init__()
164 | self.use_mono = use_mono
165 | self.use_stereo = use_stereo
166 | self.ssim = SSIM()
167 |
168 | def forward(self, data_dict):
169 | keyframe = data_dict["keyframe"]
170 | keyframe_intrinsics = data_dict["keyframe_intrinsics"]
171 | keyframe_pose = data_dict["keyframe_pose"]
172 | depths = (1-data_dict["predicted_inverse_depths"][0]) * data_dict["inv_depth_max"] + data_dict["predicted_inverse_depths"][0] * data_dict["inv_depth_min"]
173 |
174 | frames = []
175 | intrinsics = []
176 | poses = []
177 |
178 | if self.use_mono:
179 | frames += data_dict["frames"]
180 | intrinsics += data_dict["intrinsics"]
181 | poses += data_dict["poses"]
182 | if self.use_stereo:
183 | frames += [data_dict["stereoframe"]]
184 | intrinsics += [data_dict["stereoframe_intrinsics"]]
185 | poses += [data_dict["stereoframe_pose"]]
186 |
187 | n, c, h, w = keyframe.shape
188 |
189 | backproject_depth = Backprojection(n, h, w)
190 | backproject_depth.to(keyframe.device)
191 |
192 | inv_k = torch.inverse(keyframe_intrinsics)
193 |         cam_points = (inv_k[:, :3, :3] @ backproject_depth.coord)  # homogeneous pixel grid stored by Backprojection
194 | cam_points = cam_points / depths.view(n, 1, -1)
195 | cam_points = torch.cat([cam_points, backproject_depth.ones], 1)
196 |
197 | masks = []
198 | residuals = []
199 |
200 | for i, image in enumerate(frames):
201 | t = torch.inverse(poses[i]) @ keyframe_pose
202 | pix_coords = point_projection(cam_points, n, h, w, intrinsics[i], t)
203 | warped_image = F.grid_sample(image + 1, pix_coords)
204 | mask = torch.any(warped_image == 0, dim=1, keepdim=True)
205 | warped_image -= .5
206 | residual = self.ssim(warped_image, keyframe + .5)
207 | masks.append(mask)
208 | residuals.append(residual)
209 |
210 | masks = torch.stack(masks, dim=1)
211 | residuals = torch.stack(residuals, dim=1)
212 | residuals[masks.expand(-1, -1 , c, -1, -1)] = float("inf")
213 |
214 | residual_image = torch.min(torch.mean(residuals, dim=2, keepdim=True), dim=1)[0]
215 | residual_image[torch.min(masks, dim=1)[0]] = 0
216 | data_dict["residual_image"] = residual_image
217 | return data_dict
218 |
219 |
220 | class PadSameConv2d(torch.nn.Module):
221 | def __init__(self, kernel_size, stride=1):
222 | """
223 | Imitates padding_mode="same" from tensorflow.
224 | :param kernel_size: Kernelsize of the convolution, int or tuple/list
225 | :param stride: Stride of the convolution, int or tuple/list
226 | """
227 | super().__init__()
228 | if isinstance(kernel_size, (tuple, list)):
229 | self.kernel_size_y = kernel_size[0]
230 | self.kernel_size_x = kernel_size[1]
231 | else:
232 | self.kernel_size_y = kernel_size
233 | self.kernel_size_x = kernel_size
234 | if isinstance(stride, (tuple, list)):
235 | self.stride_y = stride[0]
236 | self.stride_x = stride[1]
237 | else:
238 | self.stride_y = stride
239 | self.stride_x = stride
240 |
241 | def forward(self, x: torch.Tensor):
242 | _, _, height, width = x.shape
243 |
244 |         # For the convolution we want to achieve an output size of (n_h, n_w) = (math.ceil(h / s_y), math.ceil(w / s_x)).
245 |         # Therefore we need to apply n_h convolution kernels with stride s_y. We will have n_h - 1 offsets of size s_y.
246 |         # Additionally, we need to add the size of our kernel. This is the height we require to get n_h. We need to pad
247 |         # the difference between this and the old height. We will pad math.floor(pad_y / 2) on the top and
248 |         # math.ceil(pad_y / 2) on the bottom. Same for pad_x (left / right) respectively.
249 | padding_y = (self.stride_y * (math.ceil(height / self.stride_y) - 1) + self.kernel_size_y - height) / 2
250 | padding_x = (self.stride_x * (math.ceil(width / self.stride_x) - 1) + self.kernel_size_x - width) / 2
251 | padding = [math.floor(padding_x), math.ceil(padding_x), math.floor(padding_y), math.ceil(padding_y)]
252 | return F.pad(input=x, pad=padding)
253 |
254 |
255 | class PadSameConv2dTransposed(torch.nn.Module):
256 | def __init__(self, stride):
257 | """
258 | Imitates padding_mode="same" from tensorflow.
259 | :param stride: Stride of the convolution_transposed, int or tuple/list
260 | """
261 | super().__init__()
262 | if isinstance(stride, (tuple, list)):
263 | self.stride_y = stride[0]
264 | self.stride_x = stride[1]
265 | else:
266 | self.stride_y = stride
267 | self.stride_x = stride
268 |
269 | def forward(self, x: torch.Tensor, orig_shape: torch.Tensor):
270 | target_shape = x.new_tensor(list(orig_shape))
271 | target_shape[-2] *= self.stride_y
272 | target_shape[-1] *= self.stride_x
273 | oversize = target_shape[-2:] - x.new_tensor(x.shape)[-2:]
274 | if oversize[0] > 0 and oversize[1] > 0:
275 | x = F.pad(x, [math.floor(oversize[1] / 2), math.ceil(oversize[1] / 2), math.floor(oversize[0] / 2),
276 | math.ceil(oversize[0] / 2)])
277 | elif oversize[0] > 0 >= oversize[1]:
278 | x = F.pad(x, [0, 0, math.floor(oversize[0] / 2), math.ceil(oversize[0] / 2)])
279 | x = x[:, :, :, math.floor(-oversize[1] / 2):-math.ceil(-oversize[1] / 2)]
280 | elif oversize[0] <= 0 < oversize[1]:
281 | x = F.pad(x, [math.floor(oversize[1] / 2), math.ceil(oversize[1] / 2)])
282 | x = x[:, :, math.floor(-oversize[0] / 2):-math.ceil(-oversize[0] / 2), :]
283 | else:
284 | x = x[:, :, math.floor(-oversize[0] / 2):-math.ceil(-oversize[0] / 2),
285 | math.floor(-oversize[1] / 2):-math.ceil(-oversize[1] / 2)]
286 | return x
287 |
288 |
289 | class ConvReLU2(torch.nn.Module):
290 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, leaky_relu_neg_slope=0.1):
291 | """
292 | Performs two convolutions and a leaky relu. The first operation only convolves in y direction, the second one
293 | only in x direction.
294 | :param in_channels: Number of input channels
295 | :param out_channels: Number of output channels
296 | :param kernel_size: Kernel size for the convolutions, first in y direction, then in x direction
297 | :param stride: Stride for the convolutions, first in y direction, then in x direction
298 | """
299 | super().__init__()
300 | self.pad_0 = PadSameConv2d(kernel_size=(kernel_size, 1), stride=(stride, 1))
301 | self.conv_y = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1),
302 | stride=(stride, 1))
303 | self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope)
304 | self.pad_1 = PadSameConv2d(kernel_size=(1, kernel_size), stride=(1, stride))
305 | self.conv_x = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, kernel_size),
306 | stride=(1, stride))
307 |
308 | def forward(self, x: torch.Tensor):
309 | t = self.pad_0(x)
310 | t = self.conv_y(t)
311 | t = self.leaky_relu(t)
312 | t = self.pad_1(t)
313 | t = self.conv_x(t)
314 | return self.leaky_relu(t)
315 |
316 |
317 | class ConvReLU(torch.nn.Module):
318 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, leaky_relu_neg_slope=0.1):
319 | """
320 |         Performs a single padded convolution followed by a leaky ReLU.
321 |         Padding imitates tensorflow's "same" behaviour.
322 |         :param in_channels: Number of input channels
323 |         :param out_channels: Number of output channels
324 |         :param kernel_size: Kernel size of the convolution
325 |         :param stride: Stride of the convolution
326 | """
327 | super().__init__()
328 | self.pad = PadSameConv2d(kernel_size=kernel_size, stride=stride)
329 | self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
330 | self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope)
331 |
332 | def forward(self, x: torch.Tensor):
333 | t = self.pad(x)
334 | t = self.conv(t)
335 | return self.leaky_relu(t)
336 |
337 |
338 | class Upconv(torch.nn.Module):
339 | def __init__(self, in_channels, out_channels):
340 |         """
341 |         Upsamples the input by a factor of 2 (nearest-neighbour interpolation) and then
342 |         applies a single convolution with "same"-style padding, which doubles the
343 |         spatial resolution of the feature map.
344 |         The kernel size (2) and stride (1) are fixed and not configurable.
345 |         :param in_channels: Number of input channels
346 |         :param out_channels: Number of output channels
347 |         """
348 | super().__init__()
349 | self.upsample = Upsample(scale_factor=2)
350 | self.pad = PadSameConv2d(kernel_size=2)
351 | self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=1)
352 |
353 | def forward(self, x: torch.Tensor):
354 | t = self.upsample(x)
355 | t = self.pad(t)
356 | return self.conv(t)
357 |
358 |
359 | class ConvSig(torch.nn.Module):
360 | def __init__(self, in_channels, out_channels, kernel_size, stride=1):
361 |         """
362 |         Performs a single convolution with "same" padding, followed by a sigmoid
363 |         activation, so the output values lie in (0, 1).
364 |         :param in_channels: Number of input channels
365 |         :param out_channels: Number of output channels
366 |         :param kernel_size: Kernel size for the convolution
367 |         :param stride: Stride for the convolution
368 |         """
369 | super().__init__()
370 | self.pad = PadSameConv2d(kernel_size=kernel_size, stride=stride)
371 | self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
372 | self.sig = Sigmoid()
373 |
374 | def forward(self, x: torch.Tensor):
375 | t = self.pad(x)
376 | t = self.conv(t)
377 | return self.sig(t)
378 |
379 |
380 | class Refine(torch.nn.Module):
381 | def __init__(self, in_channels, out_channels, leaky_relu_neg_slope=0.1):
382 |         """
383 |         Performs a transposed conv2d (kernel_size=4, stride=2), a leaky ReLU, and a crop/pad
384 |         step that imitates tensorflow "same" behaviour; optional skip features passed to
385 |         forward() are concatenated along the channel dimension.
386 |         :param in_channels: Channels that go into the conv2d_transposed
387 |         :param out_channels: Channels that come out of the conv2d_transposed
388... 
388 | super().__init__()
389 | self.conv2d_t = ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2)
390 | self.pad = PadSameConv2dTransposed(stride=2)
391 | self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope)
392 |
393 | def forward(self, x: torch.Tensor, features_direct=None):
394 |         orig_shape = x.shape
395 | x = self.conv2d_t(x)
396 | x = self.leaky_relu(x)
397 | x = self.pad(x, orig_shape)
398 | if features_direct is not None:
399 | x = torch.cat([x, features_direct], dim=1)
400 | return x
401 |
--------------------------------------------------------------------------------
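The blocks above in `model/layers.py` imitate TensorFlow's `padding="same"` behaviour: `PadSameConv2d` pads so that a stride-1 convolution preserves the spatial size, while `PadSameConv2dTransposed` crops/pads the output of a transposed convolution back to exactly `stride ×` the original size. A minimal shape-check sketch, assuming the repository root is on the Python path (channel and kernel values are arbitrary):

```
import torch
from model.layers import ConvReLU2, Refine

x = torch.rand(1, 16, 31, 45)             # odd spatial sizes on purpose

conv = ConvReLU2(in_channels=16, out_channels=32, kernel_size=3)
print(conv(x).shape)                      # stride 1 keeps H x W: [1, 32, 31, 45]

refine = Refine(in_channels=32, out_channels=16)
print(refine(conv(x)).shape)              # stride-2 transposed conv doubles H x W: [1, 16, 62, 90]
```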
/model/loss.py:
--------------------------------------------------------------------------------
1 | from .loss_functions.dymultidepth_loss import mask_loss, depth_loss, mask_refinement_loss, depth_refinement_loss, depth_aux_mask_loss, silog_loss, silog_vn_loss, silog_vn_loss_update, abs_silog_loss_virtualnormal
--------------------------------------------------------------------------------
/model/loss_functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/model/loss_functions/__init__.py
--------------------------------------------------------------------------------
/model/loss_functions/common_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from model.layers import Backprojection, point_projection, ssim
7 | from utils import create_mask, mask_mean
8 |
9 |
10 | def compute_errors(img0, img1, mask=None):
11 | errors = .85 * torch.mean(ssim(img0, img1, pad_reflection=False, gaussian_average=True, comp_mode=True), dim=1) + .15 * torch.mean(torch.abs(img0 - img1), dim=1)
12 | if mask is not None: return errors, mask
13 | else: return errors
14 |
15 |
16 | def reprojection_loss(depth_prediction: torch.Tensor, data_dict, automasking=False,
17 | error_function=compute_errors, error_function_weight=None, use_mono=True, use_stereo=False,
18 | reduce=True, combine_frames="min", mono_auto=False, border=0):
19 | keyframe = data_dict["keyframe"]
20 | keyframe_pose = data_dict["keyframe_pose"]
21 | keyframe_intrinsics = data_dict["keyframe_intrinsics"]
22 |
23 | frames = []
24 | poses = []
25 | intrinsics = []
26 |
27 | if use_mono:
28 | frames += data_dict["frames"]
29 | poses += data_dict["poses"]
30 | intrinsics += data_dict["intrinsics"]
31 | if use_stereo:
32 | frames += [data_dict["stereoframe"]]
33 | poses += [data_dict["stereoframe_pose"]]
34 | intrinsics += [data_dict["stereoframe_intrinsics"]]
35 |
36 | batch_size, channels, height, width = keyframe.shape
37 | frame_count = len(frames)
38 | keyframe_extrinsics = torch.inverse(keyframe_pose)
39 | extrinsics = [torch.inverse(pose) for pose in poses]
40 |
41 | reprojections = []
42 | if border > 0:
43 | masks = [create_mask(batch_size, height, width, border, keyframe.device) for _ in range(frame_count)]
44 | warped_masks = []
45 |
46 | backproject_depth = Backprojection(batch_size, height, width)
47 | backproject_depth.to(keyframe.device)
48 |
49 | for i, (frame, extrinsic, intrinsic) in enumerate(zip(frames, extrinsics, intrinsics)):
50 | cam_points = backproject_depth(1 / depth_prediction, torch.inverse(keyframe_intrinsics))
51 | pix_coords = point_projection(cam_points, batch_size, height, width, intrinsic, extrinsic @ keyframe_pose)
52 | reprojections.append(F.grid_sample(frame + 1.5, pix_coords, padding_mode="zeros"))
53 | if border > 0:
54 | warped_masks.append(F.grid_sample(masks[i], pix_coords, padding_mode="zeros"))
55 |
56 | reprojections = torch.stack(reprojections, dim=1).view(batch_size * frame_count, channels, height, width)
57 | mask = reprojections[:, 0, :, :] == 0
58 | reprojections -= 1.0
59 |
60 | if border > 0:
61 | mask = ~(torch.stack(warped_masks, dim=1).view(batch_size * frame_count, height, width) > .5)
62 |
63 | keyframe_expanded = (keyframe + .5).unsqueeze(1).expand(-1, frame_count, -1, -1, -1).reshape(batch_size * frame_count, channels, height, width)
64 |
65 | loss = 0
66 |
67 | if type(error_function) != list:
68 | error_function = [error_function]
69 | if error_function_weight is None:
70 | error_function_weight = [1 for i in range(len(error_function))]
71 |
72 | for ef, w in zip(error_function, error_function_weight):
73 | errors, n_mask = ef(reprojections, keyframe_expanded, mask)
74 | n_height, n_width = n_mask.shape[1:]
75 | errors = errors.view(batch_size, frame_count, n_height, n_width)
76 |
77 | n_mask = n_mask.view(batch_size, frame_count, n_height, n_width)
78 | errors[n_mask] = float("inf")
79 |
80 | if automasking:
81 | frames_stacked = torch.stack(frames, dim=1).view(batch_size * frame_count, channels, height, width) + .5
82 | errors_nowarp = ef(frames_stacked, keyframe_expanded).view(batch_size, frame_count, n_height, n_width)
83 | errors[errors_nowarp < errors] = float("inf")
84 |
85 | if mono_auto:
86 | keyframe_expanded_ = (keyframe + .5).unsqueeze(1).expand(-1, len(data_dict["frames"]), -1, -1, -1).reshape(batch_size * len(data_dict["frames"]), channels, height, width)
87 | frames_stacked = (torch.stack(data_dict["frames"], dim=1) + .5).view(batch_size * len(data_dict["frames"]), channels, height, width)
88 | errors_nowarp = ef(frames_stacked, keyframe_expanded_).view(batch_size, len(data_dict["frames"]), n_height, n_width)
89 | errors_nowarp = torch.mean(errors_nowarp, dim=1, keepdim=True)
90 | errors_nowarp[torch.all(n_mask, dim=1, keepdim=True)] = float("inf")
91 | errors = torch.min(errors, errors_nowarp.expand(-1, frame_count, -1, -1))
92 |
93 | if combine_frames == "min":
94 | errors = torch.min(errors, dim=1)[0]
95 | n_mask = torch.isinf(errors)
96 | elif combine_frames == "avg":
97 | n_mask = torch.isinf(errors)
98 | hits = torch.sum((~n_mask).to(dtype=torch.float32), dim=1)
99 | errors[n_mask] = 0
100 | errors = torch.sum(errors, dim=1) / hits
101 | n_mask = hits == 0
102 | errors[n_mask] = float("inf")
103 | elif combine_frames == "rnd":
104 | index = torch.randint(frame_count, (batch_size, 1, 1, 1), device=keyframe.device).expand(-1, 1, n_height, n_width)
105 | errors = torch.gather(errors, dim=1, index=index).squeeze(1)
106 | n_mask = torch.gather(n_mask, dim=1, index=index).squeeze(1)
107 | else:
108 | raise ValueError("Combine frames must be \"min\", \"avg\" or \"rnd\".")
109 |
110 | if reduce:
111 | loss += w * mask_mean(errors, n_mask)
112 | else:
113 | loss += w * errors
114 | return loss
115 |
116 |
117 | def edge_aware_smoothness_loss(depth_prediction, input, reduce=True):
118 | keyframe = input["keyframe"]
119 | depth_prediction = depth_prediction / torch.mean(depth_prediction, dim=[2, 3], keepdim=True)
120 |
121 | d_dx = torch.abs(depth_prediction[:, :, :, :-1] - depth_prediction[:, :, :, 1:])
122 | d_dy = torch.abs(depth_prediction[:, :, :-1, :] - depth_prediction[:, :, 1:, :])
123 |
124 | k_dx = torch.mean(torch.abs(keyframe[:, :, :, :-1] - keyframe[:, :, :, 1:]), 1, keepdim=True)
125 | k_dy = torch.mean(torch.abs(keyframe[:, :, :-1, :] - keyframe[:, :, 1:, :]), 1, keepdim=True)
126 |
127 | d_dx *= torch.exp(-k_dx)
128 | d_dy *= torch.exp(-k_dy)
129 |
130 | if reduce:
131 | return d_dx.mean() + d_dy.mean()
132 | else:
133 | return F.pad(d_dx, pad=(0, 1), mode='constant', value=0) + F.pad(d_dy, pad=(0, 0, 0, 1), mode='constant', value=0)
134 |
135 |
136 | def sparse_depth_loss(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, l2=False, reduce=True):
137 | """
138 | :param depth_prediction:
139 | :param depth_gt: (N, 1, H, W)
140 | :return:
141 | """
142 | n, c, h, w = depth_prediction.shape
143 | mask = depth_gt == 0
144 | if not l2:
145 | errors = torch.abs(depth_prediction - depth_gt)
146 | else:
147 | errors = (depth_prediction - depth_gt) ** 2
148 |
149 | if reduce:
150 | loss = mask_mean(errors, mask)
151 | loss[torch.isnan(loss)] = 0
152 | return loss
153 | else:
154 | return errors, mask
155 |
156 |
157 | def selfsup_loss(depth_prediction: torch.Tensor, input=None, scale=0, automasking=True, error_function=None, error_function_weight=None, use_mono=True, use_stereo=False, reduce=True, combine_frames="min", mask_border=0):
158 | reprojection_l = reprojection_loss(depth_prediction, input, automasking=automasking, error_function=error_function, error_function_weight=error_function_weight, use_mono=use_mono, use_stereo=use_stereo, reduce=reduce, combine_frames=combine_frames, border=mask_border)
159 | reprojection_l[torch.isnan(reprojection_l)] = 0
160 | edge_aware_smoothness_l = edge_aware_smoothness_loss(depth_prediction, input)
161 | edge_aware_smoothness_l[torch.isnan(edge_aware_smoothness_l)] = 0
162 | loss = reprojection_l + edge_aware_smoothness_l * 1e-3 / (2 ** scale)
163 | return loss
164 |
165 |
166 | class PerceptualError(nn.Module):
167 | def __init__(self, small_features=False):
168 | super().__init__()
169 | self.small_features = small_features
170 | vgg16 = torchvision.models.vgg16(True, True)
171 | self.feature_extractor = nn.Sequential(*list(vgg16.features.children())[:4 if self.small_features else 9])
172 | for p in self.feature_extractor.parameters(True):
173 | p.requires_grad_(False)
174 | self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False)
175 | self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False)
176 |
177 | def forward(self, img0: torch.Tensor, img1: torch.Tensor, mask=None):
178 | n, c, h, w = img0.shape
179 | c = torchvision.models.vgg.cfgs["D"][1 if self.small_features else 4]
180 | if not self.small_features:
181 | h //= 2
182 | w //= 2
183 |
184 | img0 = (img0 - self.mean) / self.std
185 | img1 = (img1 - self.mean) / self.std
186 |
187 | if mask is not None:
188 | img0[mask.unsqueeze(1).expand(-1, 3, -1, -1)] = 0
189 | img1[mask.unsqueeze(1).expand(-1, 3, -1, -1)] = 0
190 |
191 | input = torch.cat([img0, img1], dim=0)
192 | features = self.feature_extractor(input)
193 | features = features.view(2, n, c, h, w)
194 | errors = torch.mean((features[1] - features[0]) ** 2, dim=1)
195 |
196 | if mask is not None:
197 | if not self.small_features:
198 | mask = F.upsample(mask.to(dtype=torch.float).unsqueeze(1), (h, w), mode="bilinear").squeeze(1)
199 | mask = mask > 0
200 | return errors, mask
201 | else:
202 | return errors
203 |
--------------------------------------------------------------------------------
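`edge_aware_smoothness_loss` above penalises depth gradients, down-weighted where the keyframe image itself has strong gradients. A minimal sketch with random tensors; the keyframe is roughly zero-centred here to match the `keyframe + .5` convention used elsewhere in this file, and the script is assumed to run from the repository root so the `model`/`utils` imports resolve:

```
import torch
from model.loss_functions.common_losses import edge_aware_smoothness_loss

depth = torch.rand(1, 1, 64, 96) + 0.1               # fake (inverse) depth prediction, [B, 1, H, W]
data = {"keyframe": torch.rand(1, 3, 64, 96) - 0.5}  # zero-centred RGB keyframe

loss = edge_aware_smoothness_loss(depth, data)       # scalar: image-weighted mean of depth gradients
print(loss.item())
```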
/model/loss_functions/virtual_normal_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import cv2
5 |
6 |
7 | class VNL_Loss(nn.Module):
8 | """
9 | Virtual Normal Loss Function.
10 | """
11 | def __init__(self, focal_x, focal_y, input_size,
12 | delta_cos=0.867, delta_diff_x=0.01,
13 | delta_diff_y=0.01, delta_diff_z=0.01,
14 | delta_z=0.0001, sample_ratio=0.15):
15 | super(VNL_Loss, self).__init__()
16 | self.fx = torch.tensor([focal_x], dtype=torch.float32).cuda()
17 | self.fy = torch.tensor([focal_y], dtype=torch.float32).cuda()
18 | self.input_size = input_size
19 | self.u0 = torch.tensor(input_size[1] // 2, dtype=torch.float32).cuda()
20 | self.v0 = torch.tensor(input_size[0] // 2, dtype=torch.float32).cuda()
21 | self.init_image_coor()
22 | self.delta_cos = delta_cos
23 | self.delta_diff_x = delta_diff_x
24 | self.delta_diff_y = delta_diff_y
25 | self.delta_diff_z = delta_diff_z
26 | self.delta_z = delta_z
27 | self.sample_ratio = sample_ratio
28 |
29 | def init_image_coor(self):
30 | x_row = np.arange(0, self.input_size[1])
31 | x = np.tile(x_row, (self.input_size[0], 1))
32 | x = x[np.newaxis, :, :]
33 | x = x.astype(np.float32)
34 | x = torch.from_numpy(x.copy()).cuda()
35 | self.u_u0 = x - self.u0
36 |
37 | y_col = np.arange(0, self.input_size[0]) # y_col = np.arange(0, height)
38 | y = np.tile(y_col, (self.input_size[1], 1)).T
39 | y = y[np.newaxis, :, :]
40 | y = y.astype(np.float32)
41 | y = torch.from_numpy(y.copy()).cuda()
42 | self.v_v0 = y - self.v0
43 |
44 | def transfer_xyz(self, depth):
45 | x = self.u_u0 * torch.abs(depth) / self.fx
46 | y = self.v_v0 * torch.abs(depth) / self.fy
47 | z = depth
48 | pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c]
49 | return pw
50 |
51 | def select_index(self):
52 | valid_width = self.input_size[1]
53 | valid_height = self.input_size[0]
54 | num = valid_width * valid_height
55 | p1 = np.random.choice(num, int(num * self.sample_ratio), replace=True)
56 | np.random.shuffle(p1)
57 | p2 = np.random.choice(num, int(num * self.sample_ratio), replace=True)
58 | np.random.shuffle(p2)
59 | p3 = np.random.choice(num, int(num * self.sample_ratio), replace=True)
60 | np.random.shuffle(p3)
61 |
62 |         p1_x = p1 % self.input_size[1]
63 |         p1_y = (p1 / self.input_size[1]).astype(int)  # np.int is deprecated; the builtin int behaves identically here
64 |
65 |         p2_x = p2 % self.input_size[1]
66 |         p2_y = (p2 / self.input_size[1]).astype(int)
67 |
68 |         p3_x = p3 % self.input_size[1]
69 |         p3_y = (p3 / self.input_size[1]).astype(int)
70 | p123 = {'p1_x': p1_x, 'p1_y': p1_y, 'p2_x': p2_x, 'p2_y': p2_y, 'p3_x': p3_x, 'p3_y': p3_y}
71 | return p123
72 |
73 | def form_pw_groups(self, p123, pw):
74 | """
75 |         Form 3D point groups, with 3 points in each group.
76 | :param p123: points index
77 | :param pw: 3D points
78 | :return:
79 | """
80 | p1_x = p123['p1_x']
81 | p1_y = p123['p1_y']
82 | p2_x = p123['p2_x']
83 | p2_y = p123['p2_y']
84 | p3_x = p123['p3_x']
85 | p3_y = p123['p3_y']
86 |
87 | pw1 = pw[:, p1_y, p1_x, :]
88 | pw2 = pw[:, p2_y, p2_x, :]
89 | pw3 = pw[:, p3_y, p3_x, :]
90 | # [B, N, 3(x,y,z), 3(p1,p2,p3)]
91 | pw_groups = torch.cat([pw1[:, :, :, np.newaxis], pw2[:, :, :, np.newaxis], pw3[:, :, :, np.newaxis]], 3)
92 | return pw_groups
93 |
94 | def filter_mask(self, p123, gt_xyz, delta_cos=0.867,
95 | delta_diff_x=0.005,
96 | delta_diff_y=0.005,
97 | delta_diff_z=0.005):
98 | pw = self.form_pw_groups(p123, gt_xyz)
99 | pw12 = pw[:, :, :, 1] - pw[:, :, :, 0]
100 | pw13 = pw[:, :, :, 2] - pw[:, :, :, 0]
101 | pw23 = pw[:, :, :, 2] - pw[:, :, :, 1]
102 |         ### ignore (nearly) collinear point triplets
103 | pw_diff = torch.cat([pw12[:, :, :, np.newaxis], pw13[:, :, :, np.newaxis], pw23[:, :, :, np.newaxis]],
104 | 3) # [b, n, 3, 3]
105 | m_batchsize, groups, coords, index = pw_diff.shape
106 | proj_query = pw_diff.view(m_batchsize * groups, -1, index).permute(0, 2, 1) # (B* X CX(3)) [bn, 3(p123), 3(xyz)]
107 | proj_key = pw_diff.view(m_batchsize * groups, -1, index) # B X (3)*C [bn, 3(xyz), 3(p123)]
108 | q_norm = proj_query.norm(2, dim=2)
109 | nm = torch.bmm(q_norm.view(m_batchsize * groups, index, 1), q_norm.view(m_batchsize * groups, 1, index)) #[]
110 | energy = torch.bmm(proj_query, proj_key) # transpose check [bn, 3(p123), 3(p123)]
111 | norm_energy = energy / (nm + 1e-8)
112 | norm_energy = norm_energy.view(m_batchsize * groups, -1)
113 |         mask_cos = torch.sum((norm_energy > delta_cos) + (norm_energy < -delta_cos), 1) > 3  # ignore nearly collinear triplets
114 |         mask_cos = mask_cos.view(m_batchsize, groups)
115 |         ## ignore padding and invalid depth
116 | mask_pad = torch.sum(pw[:, :, 2, :] > self.delta_z, 2) == 3
117 |
118 | ###ignore near
119 | mask_x = torch.sum(torch.abs(pw_diff[:, :, 0, :]) < delta_diff_x, 2) > 0
120 | mask_y = torch.sum(torch.abs(pw_diff[:, :, 1, :]) < delta_diff_y, 2) > 0
121 | mask_z = torch.sum(torch.abs(pw_diff[:, :, 2, :]) < delta_diff_z, 2) > 0
122 |
123 | mask_ignore = (mask_x & mask_y & mask_z) | mask_cos
124 | mask_near = ~mask_ignore
125 | mask = mask_pad & mask_near
126 |
127 | return mask, pw
128 |
129 | def select_points_groups(self, gt_depth, pred_depth):
130 | pw_gt = self.transfer_xyz(gt_depth)
131 | pw_pred = self.transfer_xyz(pred_depth)
132 | B, C, H, W = gt_depth.shape
133 | p123 = self.select_index()
134 | # mask:[b, n], pw_groups_gt: [b, n, 3(x,y,z), 3(p1,p2,p3)]
135 | mask, pw_groups_gt = self.filter_mask(p123, pw_gt,
136 | delta_cos=0.867,
137 | delta_diff_x=0.005,
138 | delta_diff_y=0.005,
139 | delta_diff_z=0.005)
140 |
141 | # [b, n, 3, 3]
142 | pw_groups_pred = self.form_pw_groups(p123, pw_pred)
143 | pw_groups_pred[pw_groups_pred[:, :, 2, :] == 0] = 0.0001
144 | mask_broadcast = mask.repeat(1, 9).reshape(B, 3, 3, -1).permute(0, 3, 1, 2)
145 | pw_groups_pred_not_ignore = pw_groups_pred[mask_broadcast].reshape(1, -1, 3, 3)
146 | pw_groups_gt_not_ignore = pw_groups_gt[mask_broadcast].reshape(1, -1, 3, 3)
147 |
148 | return pw_groups_gt_not_ignore, pw_groups_pred_not_ignore
149 |
150 | def forward(self, gt_depth, pred_depth, select=True):
151 |         """
152 |         Virtual normal loss.
153 |         :param gt_depth: ground-truth depth map, [B, 1, H, W]
154 |         :param pred_depth: predicted depth map, [B, 1, H, W]
155 |         :return: scalar loss (with select=True, the 25% smallest residuals are discarded before averaging)
156 |         """
157 | gt_points, dt_points = self.select_points_groups(gt_depth, pred_depth)
158 |
159 | gt_p12 = gt_points[:, :, :, 1] - gt_points[:, :, :, 0]
160 | gt_p13 = gt_points[:, :, :, 2] - gt_points[:, :, :, 0]
161 | dt_p12 = dt_points[:, :, :, 1] - dt_points[:, :, :, 0]
162 | dt_p13 = dt_points[:, :, :, 2] - dt_points[:, :, :, 0]
163 |
164 | gt_normal = torch.cross(gt_p12, gt_p13, dim=2)
165 | dt_normal = torch.cross(dt_p12, dt_p13, dim=2)
166 | dt_norm = torch.norm(dt_normal, 2, dim=2, keepdim=True)
167 | gt_norm = torch.norm(gt_normal, 2, dim=2, keepdim=True)
168 | dt_mask = dt_norm == 0.0
169 | gt_mask = gt_norm == 0.0
170 | dt_mask = dt_mask.to(torch.float32)
171 | gt_mask = gt_mask.to(torch.float32)
172 | dt_mask *= 0.01
173 | gt_mask *= 0.01
174 | gt_norm = gt_norm + gt_mask
175 | dt_norm = dt_norm + dt_mask
176 | gt_normal = gt_normal / gt_norm
177 | dt_normal = dt_normal / dt_norm
178 | loss = torch.abs(gt_normal - dt_normal)
179 | loss = torch.sum(torch.sum(loss, dim=2), dim=0)
180 | if select:
181 | loss, indices = torch.sort(loss, dim=0, descending=False)
182 | loss = loss[int(loss.size(0) * 0.25):]
183 | loss = torch.mean(loss)
184 | return loss
--------------------------------------------------------------------------------
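`VNL_Loss` above samples random point triplets, lifts them to 3D with the camera intrinsics, and compares the virtual surface normals of prediction and ground truth. A minimal usage sketch; the focal lengths and depth range are made-up values, and a CUDA device is required because the class moves its coordinate grids to the GPU in `__init__`:

```
import torch
from model.loss_functions.virtual_normal_loss import VNL_Loss

height, width = 96, 160
vnl = VNL_Loss(focal_x=370.0, focal_y=370.0, input_size=(height, width))

gt_depth = torch.rand(2, 1, height, width).cuda() * 80.0 + 1.0    # [B, 1, H, W], made-up metric depths
pred_depth = torch.rand(2, 1, height, width).cuda() * 80.0 + 1.0

loss = vnl(gt_depth, pred_depth)    # scalar tensor
print(loss.item())
```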
/model/metric.py:
--------------------------------------------------------------------------------
1 | from .metric_functions.sparse_metrics import *
2 | from .metric_functions.dense_metrics import *
3 | from .metric_functions.completeness_metrics import *
--------------------------------------------------------------------------------
/model/metric_functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/model/metric_functions/__init__.py
--------------------------------------------------------------------------------
/model/metric_functions/completeness_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils import mask_mean
4 |
5 |
6 | def completeness_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None):
7 | return torch.mean((depth_prediction != 0).to(dtype=torch.float32))
8 |
9 |
10 | def covered_gt_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None):
11 | gt_mask = depth_gt != 0
12 | return mask_mean(((depth_prediction != 0)).to(dtype=torch.float32), gt_mask)
13 |
--------------------------------------------------------------------------------
/model/metric_functions/dense_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils import preprocess_roi, get_positive_depth, get_absolute_depth
4 |
5 |
6 | def sc_inv_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None):
7 | """
8 |     Computes the scale-invariant metric described in (14)
9 | :param depth_prediction: Depth prediction computed by the network
10 | :param depth_gt: GT Depth
11 | :param roi: Specify a region of interest on which the metric should be computed
12 | :return: metric (mean over batch_size)
13 | """
14 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
15 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
16 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
17 |
18 | n = depth_gt.shape[2] * depth_gt.shape[3]
19 | E = torch.log(depth_prediction) - torch.log(depth_gt)
20 | E[torch.isnan(E)] = 0
21 | batch_metric = torch.sqrt(1 / n * torch.sum(E**2, dim=[2, 3]) - 1 / (n**2) * (torch.sum(E, dim=[2, 3])**2))
22 | batch_metric[torch.isnan(batch_metric)] = 0
23 | result = torch.mean(batch_metric)
24 | return result
25 |
26 |
27 | def l1_rel_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None):
28 | """
29 | Computes the L1-rel metric described in (15)
30 | :param depth_prediction: Depth prediction computed by the network
31 | :param depth_gt: GT Depth
32 | :param roi: Specify a region of interest on which the metric should be computed
33 | :return: metric (mean over batch_size)
34 | """
35 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
36 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
37 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
38 |
39 | return torch.mean(torch.abs(depth_prediction - depth_gt) / depth_gt)
40 |
41 |
42 | def l1_inv_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None):
43 | """
44 | Computes the L1-inv metric described in (16)
45 | :param depth_prediction: Depth prediction computed by the network
46 | :param depth_gt: GT Depth
47 | :param roi: Specify a region of interest on which the metric should be computed
48 | :return: metric (mean over batch_size)
49 | """
50 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
51 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
52 |
53 |
54 | return torch.mean(torch.abs(depth_prediction - depth_gt))
--------------------------------------------------------------------------------
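For reference, the quantity computed by `sc_inv_metric` above is sqrt(1/n * sum(E^2) - 1/n^2 * (sum E)^2) with E = log(pred) - log(gt), evaluated per image and then averaged over the batch. A standalone numeric sketch with random positive depths (no repository imports needed; the clamp only guards against tiny negative values from floating-point rounding):

```
import torch

pred = torch.rand(2, 1, 4, 5) + 0.5   # fake positive depths, [B, 1, H, W]
gt   = torch.rand(2, 1, 4, 5) + 0.5

n = gt.shape[2] * gt.shape[3]
E = torch.log(pred) - torch.log(gt)
per_image = torch.sqrt(torch.clamp((E ** 2).sum(dim=[2, 3]) / n - (E.sum(dim=[2, 3]) ** 2) / n ** 2, min=0))
print(per_image.mean())               # batch-averaged scale-invariant error
```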
/model/metric_functions/sparse_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils import preprocess_roi, get_positive_depth, get_absolute_depth, get_mask, mask_mean
4 |
5 | from depth_proc_tools.plot_depth_utils import *
6 | import cv2
7 | from PIL import Image
8 | import os
9 |
10 |
11 | def a1_metric(data_dict: dict, roi=None, max_distance=None):
12 | depth_prediction = data_dict["result"]
13 | depth_gt = data_dict["target"]
14 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
15 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
16 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
17 |
18 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt))
19 | return torch.mean((thresh < 1.25).type(torch.float))
20 |
21 |
22 | def a2_metric(data_dict: dict, roi=None, max_distance=None):
23 | depth_prediction = data_dict["result"]
24 | depth_gt = data_dict["target"]
25 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
26 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
27 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
28 |
29 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)).type(torch.float)
30 | return torch.mean((thresh < 1.25 ** 2).type(torch.float))
31 |
32 |
33 | def a3_metric(data_dict: dict, roi=None, max_distance=None):
34 | depth_prediction = data_dict["result"]
35 | depth_gt = data_dict["target"]
36 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
37 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
38 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
39 |
40 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)).type(torch.float)
41 | return torch.mean((thresh < 1.25 ** 3).type(torch.float))
42 |
43 |
44 | def rmse_metric(data_dict: dict, roi=None, max_distance=None):
45 | depth_prediction = data_dict["result"]
46 | depth_gt = data_dict["target"]
47 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
48 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
49 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
50 |
51 | se = (depth_prediction - depth_gt) ** 2
52 | return torch.mean(torch.sqrt(torch.mean(se, dim=[1, 2, 3])))
53 |
54 |
55 | def rmse_log_metric(data_dict: dict, roi=None, max_distance=None):
56 | depth_prediction = data_dict["result"]
57 | depth_gt = data_dict["target"]
58 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
59 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
60 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
61 |
62 | sle = (torch.log(depth_prediction) - torch.log(depth_gt)) ** 2
63 | return torch.mean(torch.sqrt(torch.mean(sle, dim=[1, 2, 3])))
64 |
65 |
66 | def abs_rel_metric(data_dict: dict, roi=None, max_distance=None):
67 | depth_prediction = data_dict["result"]
68 | depth_gt = data_dict["target"]
69 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
70 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
71 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
72 |
73 | return torch.mean(torch.abs(depth_prediction - depth_gt) / depth_gt)
74 |
75 |
76 | def sq_rel_metric(data_dict: dict, roi=None, max_distance=None):
77 | depth_prediction = data_dict["result"]
78 | depth_gt = data_dict["target"]
79 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
80 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
81 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
82 |
83 | return torch.mean(((depth_prediction - depth_gt) ** 2) / depth_gt)
84 |
85 |
86 | def find_mincost_depth(cost_volume, depth_hypos):
87 | argmax = torch.argmax(cost_volume, dim=1, keepdim=True)
88 | mincost_depth = torch.gather(input=depth_hypos, dim=1, index=argmax)
89 | return mincost_depth
90 |
91 | def a1_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False):
92 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"]
93 | depth_gt = data_dict["target"]
94 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
95 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid)
96 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5)
97 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
98 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
99 |
100 | return a1_base(depth_prediction, depth_gt, mask)
101 |
102 |
103 |
104 | def a2_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False):
105 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"]
106 | depth_gt = data_dict["target"]
107 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
108 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid)
109 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5)
110 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
111 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
112 | return a2_base(depth_prediction, depth_gt, mask)
113 |
114 |
115 | def a3_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False):
116 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"]
117 | depth_gt = data_dict["target"]
118 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
119 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid)
120 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5)
121 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
122 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
123 | return a3_base(depth_prediction, depth_gt, mask)
124 |
125 |
126 | def rmse_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False):
127 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"]
128 | depth_gt = data_dict["target"]
129 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
130 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid)
131 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5)
132 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
133 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
134 | return rmse_base(depth_prediction, depth_gt, mask)
135 |
136 |
137 | def rmse_log_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False):
138 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"]
139 | depth_gt = data_dict["target"]
140 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
141 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid)
142 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5)
143 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
144 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
145 | return rmse_log_base(depth_prediction, depth_gt, mask)
146 |
147 |
148 | def abs_rel_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False):
149 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"]
150 | depth_gt = data_dict["target"]
151 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
152 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid)
153 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5)
154 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
155 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
156 | return abs_rel_base(depth_prediction, depth_gt, mask)
157 |
158 |
159 | def sq_rel_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False):
160 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"]
161 | depth_gt = data_dict["target"]
162 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi)
163 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid)
164 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5)
165 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt)
166 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance)
167 | return sq_rel_base(depth_prediction, depth_gt, mask)
168 |
169 |
170 |
171 | def save_results(path, name, img, gt_depth, pred_depth, validmask, cv_mask, costvolume):
172 | savepath = os.path.join(path, name)
173 | device=img.device
174 | bs,_,h,w = img.shape
175 | img = img[0,...].permute(1,2,0).detach().cpu().numpy() + 0.5
176 | gt_depth = gt_depth[0,...].permute(1,2,0).detach().cpu().numpy()
177 | gt_depth[gt_depth==80] = 0
178 | pred_depth = pred_depth[0,...].permute(1,2,0).detach().cpu().numpy()
179 | validmask = validmask[0,0,...].detach().cpu().numpy()
180 | cv_mask = cv_mask[0,0,...].detach().cpu().numpy()
181 |
182 | img = img #* np.expand_dims(cv_mask, axis=-1).astype(float)
183 |
184 | error_map, _ = get_error_map_value(pred_depth,gt_depth, grag_crop=False, median_scaling=False)
185 |
186 | errorpil = numpy_intensitymap_to_pcolor(error_map,vmin=0,vmax=0.5,colormap='jet')
187 | pred_pil = numpy_intensitymap_to_pcolor(pred_depth)
188 | gt_pil = numpy_intensitymap_to_pcolor(gt_depth)
189 | img_pil = numpy_rgb_to_pil(img)
190 |
191 | # generate pil validmask
192 | validmask_pil = Image.fromarray((validmask * 255.0).astype(np.uint8))
193 | cv_mask_pil = Image.fromarray((cv_mask * 255.0).astype(np.uint8))
194 |
195 | # cost
196 | print(bs,h,w)
197 | # print(f"cost volume shape:{costvolume.shape}")
198 | depths = (1 / torch.linspace(0.0025, 0.33, 32,device=device)).cuda().view(1, -1, 1, 1).expand(bs, -1, h, w)
199 | cost_volume_depth = find_mincost_depth(costvolume, depths)
200 | # print(f"cost shape:{cost_volume_depth.shape}")
201 | cost_volume_depth = cost_volume_depth[0,...].permute(1,2,0).detach().cpu().numpy()
202 |
203 | cv_depth_pil = numpy_intensitymap_to_pcolor(cost_volume_depth)
204 |
205 |
206 | h,w,_ = gt_depth.shape
207 | dst = Image.new('RGB', (w, h*3))
208 | dst.paste(img_pil, (0, 0))
209 | dst.paste(pred_pil, (0, h))
210 | dst.paste(gt_pil, (0, 2*h))
211 | # dst.paste(errorpil, (0, 3*h))
212 | # dst.paste(validmask_pil,(0,4*h))
213 | # dst.paste(cv_mask_pil,(0,5*h))
214 | # dst.paste(cv_depth_pil,(0,2*h))
215 |
216 | dst.save(savepath)
217 | print(f"saved to {savepath}")
218 |
219 |
220 |
221 | def a1_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None):
222 | return a1_sparse_metric(data_dict, roi, max_distance, False)
223 |
224 |
225 | def a2_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None):
226 | return a2_sparse_metric(data_dict, roi, max_distance, False)
227 |
228 |
229 | def a3_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None):
230 | return a3_sparse_metric(data_dict, roi, max_distance, False)
231 |
232 |
233 | def rmse_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None):
234 | return rmse_sparse_metric(data_dict, roi, max_distance, False)
235 |
236 |
237 | def rmse_log_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None):
238 | return rmse_log_sparse_metric(data_dict, roi, max_distance, False)
239 |
240 |
241 | def abs_rel_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None):
242 | return abs_rel_sparse_metric(data_dict, roi, max_distance, False)
243 |
244 |
245 | def sq_rel_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None):
246 | return sq_rel_sparse_metric(data_dict, roi, max_distance, False)
247 |
248 |
249 | def a1_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None):
250 | return a1_sparse_metric(data_dict, roi, max_distance, use_cvmask=True)
251 |
252 |
253 | def a2_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None):
254 | return a2_sparse_metric(data_dict, roi, max_distance, use_cvmask=True)
255 |
256 |
257 | def a3_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None):
258 | return a3_sparse_metric(data_dict, roi, max_distance, use_cvmask=True)
259 |
260 |
261 | def rmse_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None):
262 | return rmse_sparse_metric(data_dict, roi, max_distance, use_cvmask=True)
263 |
264 |
265 | def rmse_log_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None):
266 | return rmse_log_sparse_metric(data_dict, roi, max_distance, use_cvmask=True)
267 |
268 |
269 | def abs_rel_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None):
270 | return abs_rel_sparse_metric(data_dict, roi, max_distance, use_cvmask=True)
271 |
272 |
273 | def sq_rel_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None):
274 | return sq_rel_sparse_metric(data_dict, roi, max_distance, use_cvmask=True)
275 |
276 |
277 | def a1_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask):
278 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt))
279 | return mask_mean((thresh < 1.25).type(torch.float), mask)
280 |
281 |
282 | def a2_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask):
283 | depth_gt[mask] = 1
284 | depth_prediction[mask] = 1
285 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)).type(torch.float)
286 | return mask_mean((thresh < 1.25 ** 2).type(torch.float), mask)
287 |
288 |
289 | def a3_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask):
290 | depth_gt[mask] = 1
291 | depth_prediction[mask] = 1
292 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)).type(torch.float)
293 | return mask_mean((thresh < 1.25 ** 3).type(torch.float), mask)
294 |
295 |
296 | def rmse_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask):
297 | depth_gt[mask] = 1
298 | depth_prediction[mask] = 1
299 | se = (depth_prediction - depth_gt) ** 2
300 | return torch.mean(torch.sqrt(mask_mean(se, mask, dim=[1, 2, 3])))
301 |
302 |
303 | def rmse_log_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask):
304 | depth_gt[mask] = 1
305 | depth_prediction[mask] = 1
306 | sle = (torch.log(depth_prediction) - torch.log(depth_gt)) ** 2
307 | return torch.mean(torch.sqrt(mask_mean(sle, mask, dim=[1, 2, 3])))
308 |
309 |
310 | def abs_rel_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask):
311 | return mask_mean(torch.abs(depth_prediction - depth_gt) / depth_gt, mask)
312 |
313 |
314 | def sq_rel_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask):
315 | return mask_mean(((depth_prediction - depth_gt) ** 2) / depth_gt, mask)
--------------------------------------------------------------------------------
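The `a1`/`a2`/`a3` metrics above are the standard threshold accuracies: the fraction of pixels whose ratio max(gt/pred, pred/gt) falls below 1.25, 1.25², and 1.25³ respectively. A tiny worked example with hand-picked numbers (standalone re-implementation for illustration, not an import from the repository):

```
import torch

pred = torch.tensor([[10.0, 20.0, 35.0]])
gt   = torch.tensor([[11.0, 19.0, 50.0]])

ratio = torch.max(gt / pred, pred / gt)          # per-pixel ratios: 1.10, 1.05, 1.43
a1 = (ratio < 1.25).float().mean()               # fraction within a factor of 1.25
a2 = (ratio < 1.25 ** 2).float().mean()
a3 = (ratio < 1.25 ** 3).float().mean()
print(a1.item(), a2.item(), a3.item())           # ~0.667 1.0 1.0
```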
/model/model.py:
--------------------------------------------------------------------------------
1 | from .dymultidepth.dymultidepth_model import DyMultiDepthModel
--------------------------------------------------------------------------------
/parse_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from pathlib import Path
4 | from functools import reduce
5 | from operator import getitem
6 | from datetime import datetime
7 | from logger import setup_logging
8 | from utils import read_json, write_json
9 |
10 |
11 | class ConfigParser:
12 | def __init__(self, args, options='', timestamp=True):
13 | # parse default and custom cli options
14 | for opt in options:
15 | args.add_argument(*opt.flags, default=None, type=opt.type)
16 | args = args.parse_args()
17 | self.args = args
18 |
19 | if args.device:
20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
21 | if args.resume is None:
22 |             msg_no_cfg = "A configuration file needs to be specified. Add '-c config.json', for example."
23 | assert args.config is not None, msg_no_cfg
24 | self.cfg_fname = Path(args.config)
25 | config = read_json(self.cfg_fname)
26 | self.resume = None
27 | else:
28 | self.resume = Path(args.resume)
29 | resume_cfg_fname = self.resume.parent / 'config.json'
30 | config = read_json(resume_cfg_fname)
31 | if args.config is not None:
32 | config.update(read_json(Path(args.config)))
33 |
34 | # load config file and apply custom cli options
35 | self._config = _update_config(config, options, args)
36 |
37 | # set save_dir where trained model and log will be saved.
38 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
39 |
40 | if "trainer" in self.config:
41 | save_dir = Path(self.config['trainer']['save_dir'])
42 | if "timestamp_replacement" in self.config["trainer"]:
43 | timestamp = self.config["trainer"]["timestamp_replacement"]
44 |
45 | elif "evaluater" in self.config:
46 | save_dir = Path(self.config['evaluater']['save_dir'])
47 | if "timestamp_replacement" in self.config["evaluater"]:
48 | timestamp = self.config["evaluater"]["timestamp_replacement"]
49 | elif "save_dir" in self.config:
50 | save_dir = Path(self.config["save_dir"])
51 | else:
52 | save_dir = Path("../saved")
53 |
54 | exper_name = self.config['name']
55 | self._save_dir = save_dir / 'models' / exper_name / timestamp
56 | self._log_dir = save_dir / 'log' / exper_name / timestamp
57 |
58 | self.save_dir.mkdir(parents=True, exist_ok=True)
59 | self.log_dir.mkdir(parents=True, exist_ok=True)
60 |
61 | # save updated config file to the checkpoint dir
62 | write_json(self.config, self.save_dir / 'config.json')
63 |
64 | # configure logging module
65 | setup_logging(self.log_dir)
66 | self.log_levels = {
67 | 0: logging.WARNING,
68 | 1: logging.INFO,
69 | 2: logging.DEBUG
70 | }
71 |
72 | def initialize(self, name, module, *args, **kwargs):
73 | """
74 | finds a function handle with the name given as 'type' in config, and returns the
75 | instance initialized with corresponding keyword args given as 'args'.
76 | """
77 | module_name = self[name]['type']
78 | module_args = dict(self[name]['args'])
79 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
80 | module_args.update(kwargs)
81 | return getattr(module, module_name)(*args, **module_args)
82 |
83 | def initialize_list(self, name, module, *args, **kwargs):
84 | l = self[name]
85 | for to_init in l:
86 | module_name = to_init["type"]
87 | module_args = dict(to_init["args"])
88 | module_args.update(kwargs)
89 | yield getattr(module, module_name)(*args, **module_args)
90 |
91 | def __getitem__(self, name):
92 | return self.config[name]
93 |
94 | def get_logger(self, name, verbosity=2):
95 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
96 | assert verbosity in self.log_levels, msg_verbosity
97 | logger = logging.getLogger(name)
98 | logger.setLevel(self.log_levels[verbosity])
99 | return logger
100 |
101 | # setting read-only attributes
102 | @property
103 | def config(self):
104 | return self._config
105 |
106 | @property
107 | def save_dir(self):
108 | return self._save_dir
109 |
110 | @property
111 | def log_dir(self):
112 | return self._log_dir
113 |
114 | # helper functions used to update config dict with custom cli options
115 | def _update_config(config, options, args):
116 | for opt in options:
117 | value = getattr(args, _get_opt_name(opt.flags))
118 | if value is not None:
119 | _set_by_path(config, opt.target, value)
120 | return config
121 |
122 | def _get_opt_name(flags):
123 | for flg in flags:
124 | if flg.startswith('--'):
125 | return flg.replace('--', '')
126 | return flags[0].replace('--', '')
127 |
128 | def _set_by_path(tree, keys, value):
129 | """Set a value in a nested object in tree by sequence of keys."""
130 | _get_by_path(tree, keys[:-1])[keys[-1]] = value
131 |
132 | def _get_by_path(tree, keys):
133 | """Access a nested object in tree by sequence of keys."""
134 | return reduce(getitem, keys, tree)
135 |
--------------------------------------------------------------------------------
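`ConfigParser.initialize` above resolves a config section of the form `{"type": ..., "args": {...}}` to `getattr(module, type)(*args, **args_dict)`. A standalone sketch of that lookup for an optimizer section; the section values are illustrative examples, not taken from the repository's config files:

```
import torch

section = {"type": "Adam", "args": {"lr": 1e-4, "weight_decay": 0.0}}
params = [torch.nn.Parameter(torch.zeros(3))]

# Same pattern as config.initialize('optimizer', torch.optim, params)
optimizer = getattr(torch.optim, section["type"])(params, **section["args"])
print(type(optimizer).__name__)   # Adam
```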
/pictures/dynamic_depth_result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/pictures/dynamic_depth_result.gif
--------------------------------------------------------------------------------
/pictures/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/pictures/overview.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | torchvision==0.8.2
3 | kornia
4 | opencv-python
5 | scipy
6 | scikit-image
7 | tqdm
8 | tensorboard
9 | tensorboardx
10 | pykitti
11 | colour_demosaicing
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import torch
4 | import data_loader.data_loaders as module_data
5 | import model.loss as module_loss
6 | import model.metric as module_metric
7 | import model.model as module_arch
8 | from utils import seed_rng
9 | from utils.parse_config import ConfigParser
10 | from trainer.trainer import Trainer
11 |
12 |
13 | torch.backends.cuda.matmul.allow_tf32 = False
14 | torch.backends.cudnn.allow_tf32 = False
15 | torch.backends.cudnn.benchmark = True
16 |
17 | def main(config, options=[]):
18 | seed_rng(0)
19 | logger = config.get_logger('train')
20 |
21 | # setup data_loader instances
22 | data_loader = config.initialize('data_loader', module_data)
23 | if "val_data_loader" in config.config:
24 | valid_data_loader = config.initialize("val_data_loader", module_data)
25 | else:
26 | valid_data_loader = data_loader.split_validation()
27 |
28 | # build model architecture, then print to console
29 | model = config.initialize('arch', module_arch)
30 | logger.info(model)
31 | logger.info(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters")
32 | logger.info(f"{sum(p.numel() for p in model.parameters())} total parameters")
33 |
34 | # get function handles of loss and metrics
35 | if "loss_module" in config.config:
36 | loss = config.initialize("loss_module", module_loss)
37 | else:
38 | loss = getattr(module_loss, config['loss'])
39 | metrics = [getattr(module_metric, met) for met in config['metrics']]
40 |
41 | # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
42 | trainable_params = filter(lambda p: p.requires_grad, model.parameters())
43 | optimizer = config.initialize('optimizer', torch.optim, trainable_params)
44 |
45 | lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer)
46 |
47 | trainer = Trainer(model, loss, metrics, optimizer,
48 | config=config,
49 | data_loader=data_loader,
50 | valid_data_loader=valid_data_loader,
51 | lr_scheduler=lr_scheduler,
52 | options=options)
53 |
54 | trainer.train()
55 |
56 |
57 | if __name__ == '__main__':
58 | args = argparse.ArgumentParser(description='PyTorch Template')
59 | args.add_argument('-c', '--config', default=None, type=str,
60 | help='config file path (default: None)')
61 | args.add_argument('-r', '--resume', default=None, type=str,
62 | help='path to latest checkpoint (default: None)')
63 | args.add_argument('-d', '--device', default=None, type=str,
64 | help='indices of GPUs to enable (default: all)')
65 | args.add_argument('-o', '--options', default=[], nargs='+')
66 |
67 | # custom cli options to modify configuration from default values given in json file.
68 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
69 | options = [
70 | CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
71 | CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size'))
72 | ]
73 | config = ConfigParser(args, options)
74 | print(config.config)
75 | main(config, config.args.options)
76 |
--------------------------------------------------------------------------------
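Based on the argparse flags defined above and the config files listed in the repository tree, a training run might be launched as follows; `--lr` and `--bs` are the `CustomArgs` overrides that map onto `optimizer.args.lr` and `data_loader.args.batch_size`, and the numeric values here are placeholders rather than recommended settings:

```
python train.py -c configs/train/train_mr_resnet18.json -d 0 --bs 4 --lr 1e-4
```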
/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/trainer/__init__.py
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import operator
2 |
3 | import numpy as np
4 | import torch
5 | from torchvision.utils import make_grid
6 |
7 | from base import BaseTrainer
8 | from utils import inf_loop, map_fn, operator_on_dict, LossWrapper, ValueFader
9 |
10 | import time
11 |
12 | class Trainer(BaseTrainer):
13 | def __init__(self, model, loss, metrics, optimizer, config, data_loader,
14 | valid_data_loader=None, lr_scheduler=None, options=[]):
15 | super().__init__(model, loss, metrics, optimizer, config)
16 | self.config = config
17 | self.data_loader = data_loader
18 |
19 | len_epoch = config["trainer"].get("len_epoch", None)
20 |
21 | if len_epoch is None:
22 | # epoch-based training
23 | self.len_epoch = len(self.data_loader)
24 | else:
25 | # iteration-based training
26 | self.data_loader = inf_loop(data_loader)
27 | self.len_epoch = len_epoch
28 |
29 | self.valid_data_loader = valid_data_loader
30 | self.do_validation = self.valid_data_loader is not None
31 | self.lr_scheduler = lr_scheduler
32 | self.log_step = config["trainer"].get("log_step", int(np.sqrt(data_loader.batch_size)))
33 | self.val_log_step = config["trainer"].get("val_step", 1)
34 | self.roi = config["trainer"].get("roi")
35 | self.roi_train = config["trainer"].get("roi_train", self.roi)
36 | self.alpha = config["trainer"].get("alpha", None)
37 | self.max_distance = config["trainer"].get("max_distance", None)
38 | self.val_avg = config["trainer"].get("val_avg", True)
39 | self.save_multiple = config["trainer"].get("save_multiple", False)
40 | self.invert_output_images = config["trainer"].get("invert_output_images", True)
41 | self.wrap_loss_in_module = config["trainer"].get("wrap_loss_in_module", False)
42 | self.value_faders = config["trainer"].get("value_faders", {})
43 | self.options = options
44 |
45 | if self.wrap_loss_in_module:
46 | self.loss = LossWrapper(loss_function=self.loss, roi=self.roi, options=self.options)
47 |
48 | if isinstance(loss, torch.nn.Module) or self.wrap_loss_in_module:
49 | self.module_loss = True
50 | self.loss.to(self.device)
51 | if len(self.device_ids) > 1:
52 | self.loss.num_devices = len(self.device_ids)
53 | self.model = torch.nn.DataParallel(torch.nn.Sequential(self.model.module, self.loss), self.device_ids)
54 | else:
55 | self.model = torch.nn.Sequential(self.model, self.loss)
56 | else:
57 | self.module_loss = False
58 |
59 | self.value_faders = {k: ValueFader(v[0], v[1]) for k, v in self.value_faders.items()}
60 |
61 | def _eval_metrics(self, data_dict, training=False):
62 | acc_metrics = np.zeros(len(self.metrics))
63 | for i, metric in enumerate(self.metrics):
64 | acc_metrics[i] += metric(data_dict, self.roi, self.max_distance)
65 | if (not self.val_avg) or training:
66 | self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i])
67 | if np.any(np.isnan(acc_metrics)):
68 | acc_metrics = np.zeros(len(self.metrics))
69 | valid = np.zeros(len(self.metrics))
70 | else:
71 | valid = np.ones(len(self.metrics))
72 | return acc_metrics, valid
73 |
74 | def _train_epoch(self, epoch):
75 | self.model.train()
76 |
77 | total_loss = 0
78 | total_loss_dict = {}
79 | total_metrics = np.zeros(len(self.metrics))
80 | total_metrics_valid = np.zeros(len(self.metrics))
81 |
82 | fade_values = {k: torch.tensor([fader.get_value(epoch)]) for k, fader in self.value_faders.items()}
83 |
84 | for batch_idx, (data, target) in enumerate(self.data_loader):
85 | data.update(fade_values)
86 | data, target = to(data, self.device), to(target, self.device)
87 | data["target"] = target
88 | # data["optimizer"] = self.optimizer
89 |
90 | start_time = time.time()
91 |
92 | self.optimizer.zero_grad()
93 |
94 | if not self.module_loss:
95 | data = self.model(data)
96 | loss_dict = self.loss(data, self.alpha, self.roi_train, options=self.options)
97 | else:
98 | data, loss_dict = self.model(data)
99 |
100 | loss_dict = map_fn(loss_dict, torch.sum)
101 |
102 | loss = loss_dict["loss"]
103 | if loss.requires_grad:
104 | loss.backward()
105 |
106 | self.optimizer.step()
107 |
108 | # print("Forward time: ", (time.time() - start_time))
109 |
110 | loss_dict = map_fn(loss_dict, torch.detach)
111 |
112 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
113 |
114 | self.writer.add_scalar('loss', loss.item())
115 | for loss_component, v in loss_dict.items():
116 | self.writer.add_scalar(f"loss_{loss_component}", v.item())
117 |
118 | total_loss += loss.item()
119 | total_loss_dict = operator_on_dict(total_loss_dict, loss_dict, operator.add)
120 | metrics, valid = self._eval_metrics(data, True)
121 | total_metrics += metrics
122 | total_metrics_valid += valid
123 |
124 | if self.writer.step % self.log_step == 0:
125 | img_count = min(data["keyframe"].shape[0], 8)
126 |
127 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f} Loss_dict: {}'.format(
128 | epoch,
129 | self._progress(batch_idx),
130 | loss.item(),
131 | loss_dict))
132 |
133 | if "mask" in data:
134 | if self.invert_output_images:
135 | result = torch.clamp(1 / data["result"][:img_count], 0, 100).cpu()
136 | result /= torch.max(result) * 2 / 3
137 | else:
138 | result = data["result"][:img_count].cpu()
139 | mask = data["mask"][:img_count].cpu()
140 | img = torch.cat([result, mask], dim=2)
141 | else:
142 | if self.invert_output_images:
143 | img = torch.clamp(1 / data["result"][:img_count], 0, 100).cpu()
144 | else:
145 | img = data["result"][:img_count].cpu()
146 |
147 | self.writer.add_image('input', make_grid(to(data["keyframe"][:img_count], "cpu"), nrow=2, normalize=True))
148 | self.writer.add_image('output', make_grid(img , nrow=2, normalize=True))
149 | self.writer.add_image('ground_truth', make_grid(to(torch.clamp(infnan_to_zero(1 / data["target"][:img_count]), 0, 100), "cpu"), nrow=2, normalize=True))
150 |
151 | if batch_idx == self.len_epoch:
152 | break
153 |
154 | log = {
155 | 'loss': total_loss / self.len_epoch,
156 | 'metrics': (total_metrics / total_metrics_valid).tolist()
157 | }
158 | for loss_component, v in total_loss_dict.items():
159 | log[f"loss_{loss_component}"] = v.item() / self.len_epoch
160 |
161 | if self.do_validation:
162 | val_log = self._valid_epoch(epoch)
163 | log.update(val_log)
164 |
165 | if self.lr_scheduler is not None:
166 | self.lr_scheduler.step()
167 |
168 | return log
169 |
170 | def _valid_epoch(self, epoch):
171 | self.model.eval()
172 |
173 | total_val_loss = 0
174 | total_val_loss_dict = {}
175 | total_val_metrics = np.zeros(len(self.metrics))
176 | total_val_metrics_valid = np.zeros(len(self.metrics))
177 |
178 | with torch.no_grad():
179 | for batch_idx, (data, target) in enumerate(self.valid_data_loader):
180 | data, target = to(data, self.device), to(target, self.device)
181 | data["target"] = target
182 |
183 | if not self.module_loss:
184 | data = self.model(data)
185 | loss_dict = self.loss(data, self.alpha, self.roi_train, options=self.options)
186 | else:
187 | data, loss_dict = self.model(data)
188 |
189 | loss_dict = map_fn(loss_dict, torch.sum)
190 | loss = loss_dict["loss"]
191 |
192 | img_count = min(data["keyframe"].shape[0], 8)
193 |
194 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
195 | if not self.val_avg:
196 | self.writer.add_scalar('loss', loss.item())
197 | for loss_component, v in loss_dict.items():
198 | self.writer.add_scalar(f"loss_{loss_component}", v.item())
199 | total_val_loss += loss.item()
200 | total_val_loss_dict = operator_on_dict(total_val_loss_dict, loss_dict, operator.add)
201 | metrics, valid = self._eval_metrics(data)
202 | total_val_metrics += metrics
203 | total_val_metrics_valid += valid
204 | if batch_idx % self.val_log_step == 0:
205 | if "mask" in data:
206 | if self.invert_output_images:
207 | result = torch.clamp(1 / data["result"][:img_count], 0, 100).cpu()
208 | result /= torch.max(result) * 2 / 3
209 | else:
210 | result = data["result"][:img_count].cpu()
211 | mask = data["mask"][:img_count].cpu()
212 | img = torch.cat([result, mask], dim=2)
213 | else:
214 | if self.invert_output_images:
215 | img = torch.clamp(1 / data["result"][:img_count], 0, 100).cpu()
216 | else:
217 | img = data["result"][:img_count].cpu()
218 |
219 | self.writer.add_image('input', make_grid(to(data["keyframe"][:img_count], "cpu"), nrow=2, normalize=True))
220 | self.writer.add_image('output', make_grid(img, nrow=2, normalize=True))
221 | self.writer.add_image('ground_truth', make_grid(to(torch.clamp(infnan_to_zero(1 / data["target"][:img_count]), 0, 100), "cpu"), nrow=2, normalize=True))
222 |
223 | if self.val_avg:
224 | len_val = len(self.valid_data_loader)
225 | self.writer.add_scalar('loss', total_val_loss / len_val)
226 | for i, metric in enumerate(self.metrics):
227 | self.writer.add_scalar('{}'.format(metric.__name__), total_val_metrics[i] / len_val)
228 | for loss_component, v in total_val_loss_dict.items():
229 | self.writer.add_scalar(f"loss_{loss_component}", v.item() / len_val)
230 |
231 | result = {
232 | 'val_loss': total_val_loss / len(self.valid_data_loader),
233 | 'val_metrics': (total_val_metrics / total_val_metrics_valid).tolist()
234 | }
235 |
236 | for loss_component, v in total_val_loss_dict.items():
237 | result[f"val_loss_{loss_component}"] = v.item() / len(self.valid_data_loader)
238 |
239 | return result
240 |
241 | def _progress(self, batch_idx):
242 | base = '[{}/{} ({:.0f}%)]'
243 | if hasattr(self.data_loader, 'n_samples'):
244 | current = batch_idx * self.data_loader.batch_size
245 | total = self.data_loader.n_samples
246 | else:
247 | current = batch_idx
248 | total = self.len_epoch
249 | return base.format(current, total, 100.0 * current / total)
250 |
251 |
252 | def to(data, device):
253 | if isinstance(data, dict):
254 | return {k: to(data[k], device) for k in data.keys()}
255 | elif isinstance(data, list):
256 | return [to(v, device) for v in data]
257 | else:
258 | return data.to(device)
259 |
260 |
261 | def infnan_to_zero(t: torch.Tensor):
262 | t[torch.isinf(t)] = 0
263 | t[torch.isnan(t)] = 0
264 | return t
--------------------------------------------------------------------------------
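A minimal, self-contained sketch of the `Sequential(model, loss)` pattern used in the trainer above. The `ToyModel`/`ToyLoss` modules are hypothetical stand-ins for the repository's network and `LossWrapper`; they only illustrate the data-dict interface (model writes `"result"`, loss returns `(data, loss_dict)`), which is what makes wrapping both in `nn.Sequential` (and optionally `DataParallel`) work in `_train_epoch`.

```python
import torch

class ToyModel(torch.nn.Module):
    # Stand-in network: consumes the data dict, adds a "result" entry.
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 1)

    def forward(self, data):
        data["result"] = self.net(data["keyframe"])
        return data

class ToyLoss(torch.nn.Module):
    # Stand-in loss module: consumes the data dict, returns (data, loss_dict),
    # mirroring the interface expected by the module_loss branch of _train_epoch.
    def forward(self, data):
        loss = torch.nn.functional.mse_loss(data["result"], data["target"])
        return data, {"loss": loss}

model = torch.nn.Sequential(ToyModel(), ToyLoss())
data = {"keyframe": torch.randn(2, 4), "target": torch.randn(2, 1)}
data, loss_dict = model(data)        # mirrors: data, loss_dict = self.model(data)
loss_dict["loss"].backward()
```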
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import *
2 | from .ply_utils import *
--------------------------------------------------------------------------------
/utils/parse_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from pathlib import Path
4 | from functools import reduce
5 | from operator import getitem
6 | from datetime import datetime
7 | from logger import setup_logging
8 | from utils import read_json, write_json
9 |
10 |
11 | class ConfigParser:
12 | def __init__(self, args, options='', timestamp=True):
13 | # parse default and custom cli options
14 | for opt in options:
15 | args.add_argument(*opt.flags, default=None, type=opt.type)
16 | args = args.parse_args()
17 | self.args = args
18 |
19 | if args.device:
20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
21 | if args.resume is None:
22 |             msg_no_cfg = "A configuration file needs to be specified. Add '-c config.json', for example."
23 | assert args.config is not None, msg_no_cfg
24 | self.cfg_fname = Path(args.config)
25 | config = read_json(self.cfg_fname)
26 | self.resume = None
27 | else:
28 | self.resume = Path(args.resume)
29 | resume_cfg_fname = self.resume.parent / 'config.json'
30 | config = read_json(resume_cfg_fname)
31 | if args.config is not None:
32 | config.update(read_json(Path(args.config)))
33 |
34 | # load config file and apply custom cli options
35 | self._config = _update_config(config, options, args)
36 |
37 | # set save_dir where trained model and log will be saved.
38 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
39 |
40 | if "trainer" in self.config:
41 | save_dir = Path(self.config['trainer']['save_dir'])
42 | if "timestamp_replacement" in self.config["trainer"]:
43 | timestamp = self.config["trainer"]["timestamp_replacement"]
44 |
45 | elif "evaluater" in self.config:
46 | save_dir = Path(self.config['evaluater']['save_dir'])
47 | if "timestamp_replacement" in self.config["evaluater"]:
48 | timestamp = self.config["evaluater"]["timestamp_replacement"]
49 | elif "save_dir" in self.config:
50 | save_dir = Path(self.config["save_dir"])
51 | else:
52 | save_dir = Path("../saved")
53 |
54 | exper_name = self.config['name']
55 | self._save_dir = save_dir / 'models' / exper_name / timestamp
56 | self._log_dir = save_dir / 'log' / exper_name / timestamp
57 |
58 | self.save_dir.mkdir(parents=True, exist_ok=True)
59 | self.log_dir.mkdir(parents=True, exist_ok=True)
60 |
61 | # save updated config file to the checkpoint dir
62 | write_json(self.config, self.save_dir / 'config.json')
63 |
64 | # configure logging module
65 | setup_logging(self.log_dir)
66 | self.log_levels = {
67 | 0: logging.WARNING,
68 | 1: logging.INFO,
69 | 2: logging.DEBUG
70 | }
71 |
72 | def initialize(self, name, module, *args, **kwargs):
73 | """
74 | finds a function handle with the name given as 'type' in config, and returns the
75 | instance initialized with corresponding keyword args given as 'args'.
76 | """
77 | module_name = self[name]['type']
78 | module_args = dict(self[name]['args'])
79 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
80 | module_args.update(kwargs)
81 | return getattr(module, module_name)(*args, **module_args)
82 |
83 | def initialize_list(self, name, module, *args, **kwargs):
84 | l = self[name]
85 | for to_init in l:
86 | module_name = to_init["type"]
87 | module_args = dict(to_init["args"])
88 | module_args.update(kwargs)
89 | yield getattr(module, module_name)(*args, **module_args)
90 |
91 | def __getitem__(self, name):
92 | return self.config[name]
93 |
94 | def get_logger(self, name, verbosity=2):
95 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
96 | assert verbosity in self.log_levels, msg_verbosity
97 | logger = logging.getLogger(name)
98 | logger.setLevel(self.log_levels[verbosity])
99 | return logger
100 |
101 | # setting read-only attributes
102 | @property
103 | def config(self):
104 | return self._config
105 |
106 | @property
107 | def save_dir(self):
108 | return self._save_dir
109 |
110 | @property
111 | def log_dir(self):
112 | return self._log_dir
113 |
114 | # helper functions used to update config dict with custom cli options
115 | def _update_config(config, options, args):
116 | for opt in options:
117 | value = getattr(args, _get_opt_name(opt.flags))
118 | if value is not None:
119 | _set_by_path(config, opt.target, value)
120 | return config
121 |
122 | def _get_opt_name(flags):
123 | for flg in flags:
124 | if flg.startswith('--'):
125 | return flg.replace('--', '')
126 | return flags[0].replace('--', '')
127 |
128 | def _set_by_path(tree, keys, value):
129 | """Set a value in a nested object in tree by sequence of keys."""
130 | _get_by_path(tree, keys[:-1])[keys[-1]] = value
131 |
132 | def _get_by_path(tree, keys):
133 | """Access a nested object in tree by sequence of keys."""
134 | return reduce(getitem, keys, tree)
135 |
--------------------------------------------------------------------------------
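For illustration, a small sketch of the `type`/`args` convention that `ConfigParser.initialize` resolves against a module. The `config` dict and the standalone `initialize` helper below are hypothetical and only mirror the lookup logic in the class above.

```python
import torch
import torch.optim as optim

# Hypothetical config entry following the {"type": ..., "args": {...}} convention.
config = {"optimizer": {"type": "Adam", "args": {"lr": 1e-4, "weight_decay": 0.0}}}

def initialize(cfg, name, module, *args, **kwargs):
    # Mirrors ConfigParser.initialize: look up the class named by 'type' on the
    # given module and construct it with the config 'args' merged with kwargs.
    entry = cfg[name]
    return getattr(module, entry["type"])(*args, **dict(entry["args"], **kwargs))

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = initialize(config, "optimizer", optim, params)
# equivalent to: optim.Adam(params, lr=1e-4, weight_decay=0.0)
```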
/utils/ply_utils.py:
--------------------------------------------------------------------------------
1 | from array import array
2 |
3 | import torch
4 |
5 | from model.layers import Backprojection
6 |
7 |
8 | class PLYSaver(torch.nn.Module):
9 | def __init__(self, height, width, min_d=3, max_d=400, batch_size=1, roi=None, dropout=0):
10 | super(PLYSaver, self).__init__()
11 | self.min_d = min_d
12 | self.max_d = max_d
13 | self.roi = roi
14 | self.dropout = dropout
15 | self.data = array('f')
16 |
17 | self.projector = Backprojection(batch_size, height, width)
18 |
19 | def save(self, file):
20 | length = len(self.data) // 6
21 | header = "ply\n" \
22 | "format binary_little_endian 1.0\n" \
23 | f"element vertex {length}\n" \
24 | f"property float x\n" \
25 | f"property float y\n" \
26 | f"property float z\n" \
27 | f"property float red\n" \
28 | f"property float green\n" \
29 | f"property float blue\n" \
30 | f"end_header\n"
31 | file.write(header.encode(encoding="ascii"))
32 | self.data.tofile(file)
33 |
34 | def add_depthmap(self, depth: torch.Tensor, image: torch.Tensor, intrinsics: torch.Tensor,
35 | extrinsics: torch.Tensor):
36 | depth = 1 / depth
37 | image = (image + .5) * 255
38 | mask = (self.min_d <= depth) & (depth <= self.max_d)
39 | if self.roi is not None:
40 | mask[:, :, :self.roi[0], :] = False
41 | mask[:, :, self.roi[1]:, :] = False
42 | mask[:, :, :, :self.roi[2]] = False
43 | mask[:, :, :, self.roi[3]:] = False
44 | if self.dropout > 0:
45 | mask = mask & (torch.rand_like(depth) > self.dropout)
46 |
47 | coords = self.projector(depth, torch.inverse(intrinsics))
48 | coords = extrinsics @ coords
49 | coords = coords[:, :3, :]
50 | data_batch = torch.cat([coords, image.view_as(coords)], dim=1).permute(0, 2, 1)
51 | data_batch = data_batch[mask.view(depth.shape[0], 1, -1).permute(0, 2, 1).expand(-1, -1, 6)]
52 |
53 | self.data.extend(data_batch.cpu().tolist())
54 |
--------------------------------------------------------------------------------
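A hedged usage sketch for `PLYSaver`. The tensor shapes and value conventions used here (inverse depth of shape `[B, 1, H, W]`, images roughly in `[-0.5, 0.5]`, homogeneous 4x4 intrinsics and camera-to-world extrinsics) are assumptions inferred from the code above, not documented interfaces; run from the repository root so `model.layers.Backprojection` is importable.

```python
import torch
from utils.ply_utils import PLYSaver

B, H, W = 1, 256, 512
saver = PLYSaver(H, W, min_d=3, max_d=80, batch_size=B)

inv_depth = torch.full((B, 1, H, W), 1 / 10.0)   # predicted inverse depth (assumed convention)
image = torch.zeros(B, 3, H, W)                  # keyframe, assumed roughly in [-0.5, 0.5]
intrinsics = torch.eye(4).unsqueeze(0)           # assumed homogeneous 4x4 intrinsics
extrinsics = torch.eye(4).unsqueeze(0)           # assumed camera-to-world pose

saver.add_depthmap(inv_depth, image, intrinsics, extrinsics)
with open("pointcloud.ply", "wb") as f:
    saver.save(f)
```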
/utils/util.py:
--------------------------------------------------------------------------------
1 | import json
2 | from functools import partial
3 | from pathlib import Path
4 | from datetime import datetime
5 | from itertools import repeat
6 | from collections import OrderedDict
7 |
8 | import torch
9 | from PIL import Image
10 | import numpy as np
11 | import torch.nn.functional as F
12 | from kornia.geometry.camera import pixel2cam
13 | from kornia.geometry.depth import DepthWarper
14 |
15 | def map_fn(batch, fn):
16 | if isinstance(batch, dict):
17 | for k in batch.keys():
18 | batch[k] = map_fn(batch[k], fn)
19 | return batch
20 | elif isinstance(batch, list):
21 | return [map_fn(e, fn) for e in batch]
22 | else:
23 | return fn(batch)
24 |
25 | def to(data, device):
26 | if isinstance(data, dict):
27 | return {k: to(data[k], device) for k in data.keys()}
28 | elif isinstance(data, list):
29 | return [to(v, device) for v in data]
30 | else:
31 | return data.to(device)
32 |
33 |
34 | eps = 1e-6
35 |
36 | def preprocess_roi(depth_prediction, depth_gt: torch.Tensor, roi):
37 | if roi is not None:
38 | if isinstance(depth_prediction, list):
39 | depth_prediction = [dpr[:, :, roi[0]:roi[1], roi[2]:roi[3]] for dpr in depth_prediction]
40 | else:
41 | depth_prediction = depth_prediction[:, :, roi[0]:roi[1], roi[2]:roi[3]]
42 | depth_gt = depth_gt[:, :, roi[0]:roi[1], roi[2]:roi[3]]
43 | return depth_prediction, depth_gt
44 |
45 |
46 | def get_absolute_depth(depth_prediction, depth_gt: torch.Tensor, max_distance=None):
47 | if max_distance is not None:
48 | if isinstance(depth_prediction, list):
49 | depth_prediction = [torch.clamp_min(dpr, 1 / max_distance) for dpr in depth_prediction]
50 | else:
51 | depth_prediction = torch.clamp_min(depth_prediction, 1 / max_distance)
52 | depth_gt = torch.clamp_min(depth_gt, 1 / max_distance)
53 | if isinstance(depth_prediction, list):
54 | return [1 / dpr for dpr in depth_prediction], 1 / depth_gt
55 | else:
56 | return 1 / depth_prediction, 1 / depth_gt
57 |
58 |
59 | def get_positive_depth(depth_prediction: torch.Tensor, depth_gt: torch.Tensor):
60 | if isinstance(depth_prediction, list):
61 | depth_prediction = [torch.nn.functional.relu(dpr) for dpr in depth_prediction]
62 | else:
63 | depth_prediction = torch.nn.functional.relu(depth_prediction)
64 | depth_gt = torch.nn.functional.relu(depth_gt)
65 | return depth_prediction, depth_gt
66 |
67 |
68 | def depthmap_to_points(depth: torch.Tensor, intrinsics: torch.Tensor, flatten=False):
69 | n, c, h, w = depth.shape
70 | grid = DepthWarper._create_meshgrid(h, w).expand(n, -1, -1, -1).to(depth.device)
71 | points = pixel2cam(depth, torch.inverse(intrinsics), grid)
72 | if not flatten:
73 | return points
74 | else:
75 | return points.view(n, h * w, 3)
76 |
77 |
78 | def save_frame_for_tsdf(dir: Path, index, keyframe, depth, pose, crop=None, min_distance=None, max_distance=None):
79 | if crop is not None:
80 | keyframe = keyframe[:, crop[0]:crop[1], crop[2]:crop[3]]
81 | depth = depth[crop[0]:crop[1], crop[2]:crop[3]]
82 | keyframe = ((keyframe + .5) * 255).to(torch.uint8).permute(1, 2, 0)
83 | depth = (1 / depth * 100).to(torch.int16)
84 | depth[depth < 0] = 0
85 | if min_distance is not None:
86 | depth[depth < min_distance * 100] = 0
87 | if max_distance is not None:
88 | depth[depth > max_distance * 100] = 0
89 | Image.fromarray(keyframe.numpy()).save(dir / f"frame-{index:06d}.color.jpg")
90 | Image.fromarray(depth.numpy()).save(dir / f"frame-{index:06d}.depth.png")
91 | np.savetxt(dir / f"frame-{index:06d}.pose.txt", torch.inverse(pose).numpy())
92 |
93 |
94 | def save_intrinsics_for_tsdf(dir: Path, intrinsics, crop=None):
95 | if crop is not None:
96 | intrinsics[0, 2] -= crop[2]
97 | intrinsics[1, 2] -= crop[0]
98 | np.savetxt(dir / f"camera-intrinsics.txt", intrinsics[:3, :3].numpy())
99 |
100 |
101 | def get_mask(pred: torch.Tensor, gt: torch.Tensor, max_distance=None, pred_all_valid=True):
102 | mask = gt == 0
103 | if max_distance:
104 | mask |= (gt < 1 / max_distance)
105 | if not pred_all_valid:
106 | mask |= pred == 0
107 | return mask
108 |
109 |
110 | def mask_mean(t: torch.Tensor, m: torch.Tensor, dim=None):
111 | t = t.clone()
112 | t[m] = 0
113 | els = 1
114 | if dim is None:
115 | dim = list(range(len(t.shape)))
116 | for d in dim:
117 | els *= t.shape[d]
118 | return torch.sum(t, dim=dim) / (els - torch.sum(m.to(torch.float), dim=dim))
119 |
120 |
121 | def conditional_flip(x, condition, inplace=True):
122 | if inplace:
123 | x[condition, :, :, :] = x[condition, :, :, :].flip(3)
124 | else:
125 | flipped_x = x.clone()
126 | flipped_x[condition, :, :, :] = x[condition, :, :, :].flip(3)
127 | return flipped_x
128 |
129 |
130 | def create_mask(c: int, height: int, width: int, border_radius: int, device):
131 | mask = torch.ones(c, 1, height - 2 * border_radius, width - 2 * border_radius, device=device)
132 | return torch.nn.functional.pad(mask, [border_radius, border_radius, border_radius, border_radius])
133 |
134 |
135 | def median_scaling(data_dict):
136 | target = data_dict["target"]
137 | prediction = data_dict["result"]
138 | mask = target > 0
139 | ratios = mask.new_tensor([torch.median(target[i, mask[i]]) / torch.median(prediction[i, mask[i]]) for i in range(target.shape[0])], dtype=torch.float32)
140 | data_dict = dict(data_dict)
141 | data_dict["result"] = prediction * ratios.view(-1, 1, 1, 1)
142 | return data_dict
143 |
144 |
145 | unsqueezer = partial(torch.unsqueeze, dim=0)
146 |
147 |
148 | class DS_Wrapper(torch.utils.data.Dataset):
149 | def __init__(self, dataset, start=0, end=-1, every_nth=1):
150 | super().__init__()
151 | self.dataset = dataset
152 | self.start = start
153 | if end == -1:
154 | self.end = len(self.dataset)
155 | else:
156 | self.end = end
157 | self.every_nth = every_nth
158 |
159 | def __getitem__(self, index: int):
160 | return self.dataset.__getitem__(index * self.every_nth + self.start)
161 |
162 | def __len__(self):
163 | return (self.end - self.start) // self.every_nth + (1 if (self.end - self.start) % self.every_nth != 0 else 0)
164 |
165 | class DS_Merger(torch.utils.data.Dataset):
166 | def __init__(self, datasets):
167 | super().__init__()
168 | self.datasets = datasets
169 |
170 | def __getitem__(self, index: int):
171 |         # DS_Merger defines no start offset; return one sample per merged dataset.
172 |         return tuple(ds.__getitem__(index) for ds in self.datasets)
172 |
173 | def __len__(self):
174 | return len(self.datasets[0])
175 |
176 |
177 | class LossWrapper(torch.nn.Module):
178 | def __init__(self, loss_function, **kwargs):
179 | super().__init__()
180 | self.kwargs = kwargs
181 | self.loss_function = loss_function
182 | self.num_devices = 1.0
183 |
184 | def forward(self, data):
185 | loss_dict = self.loss_function(data, **self.kwargs)
186 | loss_dict = map_fn(loss_dict, lambda x: (x / self.num_devices))
187 | if loss_dict["loss"].requires_grad:
188 | loss_dict["loss"].backward()
189 | loss_dict["loss"].detach_()
190 | return data, loss_dict
191 |
192 |
193 | class ValueFader:
194 | def __init__(self, steps, values):
195 | self.steps = steps
196 | self.values = values
197 | self.epoch = 0
198 |
199 | def set_epoch(self, epoch):
200 | self.epoch = epoch
201 |
202 | def get_value(self, epoch=None):
203 | if epoch is None:
204 | epoch = self.epoch
205 | if epoch >= self.steps[-1]:
206 | return self.values[-1]
207 |
208 | step_index = 0
209 |
210 | while step_index < len(self.steps)-1 and epoch >= self.steps[step_index+1]:
211 | step_index += 1
212 |
213 | p = float(epoch - self.steps[step_index]) / float(self.steps[step_index+1] - self.steps[step_index])
214 | return (1-p) * self.values[step_index] + p * self.values[step_index+1]
215 |
216 |
217 | def pose_distance_thresh(data_dict, spatial_thresh=.6, rotational_thresh=.05):
218 | poses = torch.stack([data_dict["keyframe_pose"]] + data_dict["poses"], dim=1)
219 | forward = poses.new_tensor([0, 0, 1], dtype=torch.float32)
220 | spatial_expanse = torch.norm(torch.max(poses[..., :3, 3], dim=1)[0] - torch.min(poses[..., :3, 3], dim=1)[0], dim=1)
221 | rotational_expanse = torch.norm(torch.max(poses[..., :3, :3] @ forward, dim=1)[0] - torch.min(poses[..., :3, :3] @ forward, dim=1)[0], dim=1)
222 | return (spatial_expanse > spatial_thresh) | (rotational_expanse > rotational_thresh)
223 |
224 |
225 | def dilate_mask(m: torch.Tensor, size: int = 15):
226 | k = m.new_ones((1, 1, size, size), dtype=torch.float32)
227 | dilated_mask = F.conv2d((m >= 0.5).to(dtype=torch.float32), k, padding=(size//2, size//2))
228 | return dilated_mask > 0
229 |
230 |
231 | def operator_on_dict(dict_0: dict, dict_1: dict, operator, default=0):
232 | keys = set(dict_0.keys()).union(set(dict_1.keys()))
233 | results = {}
234 | for k in keys:
235 | v_0 = dict_0[k] if k in dict_0 else default
236 | v_1 = dict_1[k] if k in dict_1 else default
237 | results[k] = operator(v_0, v_1)
238 | return results
239 |
240 |
241 | numbers = [f"{i:d}" for i in range(1, 10, 1)]
242 |
243 |
244 | def filter_state_dict(state_dict, data_parallel=False):
245 | if data_parallel:
246 | state_dict = {k[7:]: state_dict[k] for k in state_dict}
247 | state_dict = {(k[2:] if k.startswith("0") else k): state_dict[k] for k in state_dict if not k[0] in numbers}
248 | return state_dict
249 |
250 |
251 | def seed_rng(seed):
252 | torch.manual_seed(seed)
253 | import random
254 | random.seed(seed)
255 |     np.random.seed(seed)
256 |
257 |
258 | def ensure_dir(dirname):
259 | dirname = Path(dirname)
260 | if not dirname.is_dir():
261 | dirname.mkdir(parents=True, exist_ok=False)
262 |
263 | def read_json(fname):
264 | with fname.open('rt') as handle:
265 | return json.load(handle, object_hook=OrderedDict)
266 |
267 | def write_json(content, fname):
268 | with fname.open('wt') as handle:
269 | json.dump(content, handle, indent=4, sort_keys=False)
270 |
271 | def inf_loop(data_loader):
272 | ''' wrapper function for endless data loader. '''
273 | for loader in repeat(data_loader):
274 | yield from loader
275 |
276 | class Timer:
277 | def __init__(self):
278 | self.cache = datetime.now()
279 |
280 | def check(self):
281 | now = datetime.now()
282 | duration = now - self.cache
283 | self.cache = now
284 | return duration.total_seconds()
285 |
286 | def reset(self):
287 | self.cache = datetime.now()
288 |
--------------------------------------------------------------------------------
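A small illustration of `ValueFader` from `utils/util.py`: the returned value is linearly interpolated between consecutive entries of `values` according to the epoch positions in `steps`, and held constant after the last step. The concrete steps and values below are arbitrary examples, not settings used by the repository's configs.

```python
from utils.util import ValueFader

fader = ValueFader(steps=[0, 10, 20], values=[0.0, 1.0, 0.5])
print(fader.get_value(0))    # 0.0
print(fader.get_value(5))    # 0.5   (halfway between 0.0 and 1.0)
print(fader.get_value(15))   # 0.75  (halfway between 1.0 and 0.5)
print(fader.get_value(30))   # 0.5   (held at the final value)
```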