├── LICENSE ├── README.md ├── base ├── __init__.py ├── base_data_loader.py ├── base_model.py └── base_trainer.py ├── configs ├── evaluate │ ├── eval_mr_effinetb5.json │ └── eval_mr_resnet18.json └── train │ ├── train_mr_effinetb5.json │ └── train_mr_resnet18.json ├── create_pointcloud.py ├── data_loader ├── __init__.py ├── data_loaders.py ├── kitti_odometry_dataset.py └── scripts │ ├── __init__.py │ └── preprocess_kitti_transfer_gtdepth_to_odom.py ├── depth_proc_tools └── plot_depth_utils.py ├── evaluate.py ├── evaluater ├── __init__.py └── evaluater.py ├── logger ├── __init__.py ├── logger.py ├── logger_config.json └── visualization.py ├── model ├── __init__.py ├── dymultidepth │ ├── __init__.py │ ├── base_models.py │ ├── ccf_modules.py │ └── dymultidepth_model.py ├── layers.py ├── loss.py ├── loss_functions │ ├── __init__.py │ ├── common_losses.py │ ├── dymultidepth_loss.py │ └── virtual_normal_loss.py ├── metric.py ├── metric_functions │ ├── __init__.py │ ├── completeness_metrics.py │ ├── dense_metrics.py │ └── sparse_metrics.py └── model.py ├── parse_config.py ├── pictures ├── dynamic_depth_result.gif └── overview.jpg ├── requirements.txt ├── train.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py ├── parse_config.py ├── ply_utils.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Rui Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic-multiframe-depth 2 | 3 | This is the code repository for the paper: 4 | > **Learning to Fuse Monocular and Multi-view Cues for Multi-frame Depth Estimation in Dynamic Scenes** 5 | > 6 | > [Rui Li](https://ruili3.github.io/), [Dong Gong](https://donggong1.github.io/index.html), [Wei Yin](https://yvanyin.net/), [Hao Chen](https://stan-haochen.github.io/), Yu Zhu, Kaixuan Wang, Xiaozhi Chen, Jinqiu Sun and Yanning Zhang 7 | > 8 | > **CVPR 2023 [[Project](https://ruili3.github.io/dymultidepth/index.html)] [[arXiv](https://arxiv.org/abs/2304.08993)] [[Video](https://www.youtube.com/watch?v=0ViYXt2bpuM)]** 9 | 10 | ![](./pictures/dynamic_depth_result.gif) 11 | 12 | If you use any content of this repo in your work, please cite our paper: 13 | ``` 14 | @inproceedings{li2023learning, 15 | title={Learning to Fuse Monocular and Multi-view Cues for Multi-frame Depth Estimation in Dynamic Scenes}, 16 | author={Li, Rui and Gong, Dong and Yin, Wei and Chen, Hao and Zhu, Yu and Wang, Kaixuan and Chen, Xiaozhi and Sun, Jinqiu and Zhang, Yanning}, 17 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 18 | pages={21539--21548}, 19 | year={2023} 20 | } 21 | ``` 22 | 23 | ## Introduction 24 | Multi-frame depth estimation generally achieves high accuracy by relying on multi-view geometric consistency. When applied in dynamic scenes, this consistency is usually violated in the dynamic areas, leading to corrupted estimations. Many multi-frame methods handle dynamic areas by identifying them with explicit masks and compensating the multi-view cues with monocular cues represented as local monocular depth or features. The improvements are limited due to the uncontrolled quality of the masks and the underutilized benefits of the fusion of the two types of cues. In this paper, we propose a novel method to learn to fuse the multi-view and monocular cues encoded as volumes without needing the heuristically crafted masks. As unveiled in our analyses, the multi-view cues capture more accurate geometric information in static areas, and the monocular cues capture more useful contexts in dynamic areas. To let the geometric perception learned from multi-view cues in static areas propagate to the monocular representation in dynamic areas, and to let the monocular cues enhance the representation of the multi-view cost volume, we propose a cross-cue fusion (CCF) module, which includes the cross-cue attention (CCA) to encode the spatially non-local relative intra-relations from each source to enhance the representation of the other. Experiments on real-world datasets prove the significant effectiveness and generalization ability of the proposed method. 25 |
26 | ![](./pictures/overview.jpg) 27 | 28 | Overview of the proposed network. 29 |
30 | 31 | 32 | ## Environment Setup 33 | 34 | You can set up your own `conda` virtual environment by running the commands below. 35 | ```shell 36 | # create a clean conda environment from scratch 37 | conda create --name dymultidepth python=3.7 38 | conda activate dymultidepth 39 | # install pip 40 | conda install ipython 41 | conda install pip 42 | # install required packages 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | 47 | 48 | ## KITTI Odometry Data 49 | 50 | We mainly use the KITTI Odometry dataset for training and testing. You can follow the steps below to prepare the dataset. 51 | 52 | 1. Download the color images and calibration files from the [official website](http://www.cvlibs.net/datasets/kitti/eval_odometry.php). We use the [improved ground truth depth](http://www.cvlibs.net/datasets/kitti/eval_depth_all.php) for training and evaluation. 53 | 2. Unzip the color images and calibration files into ```../data```. Then transfer the annotated ground-truth depth maps into the KITTI Odometry folder structure using the script ```data_loader/scripts/preprocess_kitti_transfer_gtdepth_to_odom.py```. 54 | 55 | 3. The estimated poses can be downloaded 56 | from [here](https://vision.in.tum.de/_media/research/monorec/poses_dvso.zip) and should be placed under ``../data/{kitti_path}/poses_dvso``. This folder structure is created automatically when the zip file is unpacked in the ``{kitti_path}`` directory. 57 | 58 | 59 | 4. The dynamic object masks can be downloaded from [here](https://vision.in.tum.de/_media/research/monorec/mvobj_mask.zip). Unpack the .zip file in the ``{kitti_path}`` directory; the masks should end up in ``../data/{kitti_path}/sequences/{seq_num}/mvobj_mask``. 60 | 61 | The dataset should be organized as follows: 62 | ``` 63 | data 64 |    └── dataset 65 |       ├── poses_dvso 66 |       │   ├── 00.txt 67 |       │   ├── 01.txt 68 |       │   ├── ... 69 |       └── sequences 70 |       ├── 00 71 |       |   ├── calib.txt 72 |       |   ├── image_2 73 |       |   ├── image_depth_annotated 74 |       |   ├── mvobj_mask 75 |       |   └── times.txt 76 |       ├── ... 77 | ``` 78 | 79 | 80 | 81 | ## Training 82 | 83 | To train the model from scratch, first adjust the configuration file: 84 | 1. Set `dataset_dir` to the directory where the KITTI dataset is located. 85 | 2. Set `save_dir` to the directory where you want to store the trained models. 86 | 87 | Then run the following command to train the model: 88 | 89 | ```shell 90 | # Here we train our model with the ResNet-18 backbone as an example. 91 | python train.py --config configs/train/train_mr_resnet18.json 92 | ``` 93 | 94 | ## Evaluation 95 | We provide KITTI-trained checkpoints of our model: 96 | 97 | | Backbone | Resolution | Download | 98 | | --- | --- | --- | 99 | | ResNet-18 | 256 x 512 | [Link](https://drive.google.com/file/d/1IhrBx3bj6H26UDxMNvRxF7xC0C9qqI1u/view?usp=sharing) | 100 | | EfficientNet-B5 | 256 x 512 | [Link](https://drive.google.com/file/d/1jS1pbCKfYuuoawZ1nnejtGwQV3FhXGcg/view?usp=sharing) | 101 | 102 | 103 | The downloaded checkpoints can be placed in `./ckpt`. To reproduce the evaluation results in the paper, run the following command: 104 | ```shell 105 | python evaluate.py --config configs/evaluate/eval_mr_resnet18.json 106 | ``` 107 | In the `.json` file, set `checkpoint_location` to the path of the model checkpoint and set `save_dir` to the directory where the evaluation scores will be saved. A minimal programmatic usage sketch is provided at the end of this README. 108 | ## Acknowledgements 109 | Our method is implemented based on [MonoRec](https://github.com/Brummi/MonoRec). We thank the authors for their open-source code.
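## Programmatic Usage (Sketch)
For quick experiments outside of `evaluate.py`, the snippet below sketches how a downloaded checkpoint could be loaded and applied to one KITTI sample. It is a minimal, untested sketch: the constructor arguments are copied from `configs/evaluate/eval_mr_resnet18.json`, the output handling follows `create_pointcloud.py`, and the import path of `DyMultiDepthModel` as well as the checkpoint filename are assumptions that may need to be adapted to your local setup.

```python
import torch

from data_loader.data_loaders import KittiOdometryDataloader
# Assumed import path; the class is defined in model/dymultidepth/dymultidepth_model.py.
from model.dymultidepth.dymultidepth_model import DyMultiDepthModel

# Arguments mirror configs/evaluate/eval_mr_resnet18.json. The checkpoint path is a
# placeholder for wherever the downloaded weights were stored.
model = DyMultiDepthModel(
    inv_depth_min_max=[0.33, 0.0025],
    checkpoint_location=["./ckpt/mr_kitti_resnet18.pth"],  # assumed filename
    pretrain_mode=1,
    pretrain_dropout=0,
    use_stereo=False,
    use_mono=True,
    use_ssim=1,
    fusion_type="ccf_fusion",
    input_size=[256, 512],
)
model.eval()

# One keyframe plus two source frames per sample, as in the evaluation configs.
loader = KittiOdometryDataloader(
    dataset_dir="./data/dataset/",
    depth_folder="image_depth_annotated",
    batch_size=1,
    frame_count=2,
    shuffle=False,
    validation_split=0,
    num_workers=0,
    sequences=["07"],
    target_image_size=(256, 512),
    use_color=True,
    use_dso_poses=True,
    lidar_depth=True,
    dso_depth=False,
    return_mvobj_mask=1,
)

data, gt_depth = next(iter(loader))
with torch.no_grad():
    result = model(data)
# create_pointcloud.py reads the prediction from result["result"] (an inverse-depth map).
pred = result["result"] if isinstance(result, dict) else result[0]
print(pred.shape, gt_depth.shape)
```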
110 | -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_loader import * 2 | from .base_model import * 3 | from .base_trainer import * 4 | -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super().__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | 51 | # turn off shuffle option which is mutually exclusive with sampler 52 | self.shuffle = False 53 | self.n_samples = len(train_idx) 54 | 55 | return train_sampler, valid_sampler 56 | 57 | def split_validation(self): 58 | if self.valid_sampler is None: 59 | return None 60 | else: 61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 62 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from numpy import inf 4 | from logger import TensorboardWriter 5 | from utils import filter_state_dict 
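# BaseTrainer bundles the training machinery shared by the concrete trainer and
# evaluater classes: device selection based on config['n_gpu'] (with optional
# DataParallel), metric monitoring with early stopping, Tensorboard logging, and
# checkpoint saving/resuming. Subclasses only implement _train_epoch().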
6 | 7 | class BaseTrainer: 8 | """ 9 | Base class for all trainers 10 | """ 11 | def __init__(self, model, loss, metrics, optimizer, config): 12 | self.config = config 13 | 14 | if "trainer" in config.config: 15 | self.logger = config.get_logger('trainer', config['trainer']['verbosity']) 16 | cfg_trainer = config['trainer'] 17 | self.epochs = cfg_trainer['epochs'] 18 | self.save_period = cfg_trainer['save_period'] 19 | self.monitor = cfg_trainer.get('monitor', 'off') 20 | 21 | else: 22 | self.logger = config.get_logger("evaluater") 23 | self.monitor = "off" 24 | 25 | # setup GPU device if available, move model into configured device 26 | self.device, self.device_ids = self._prepare_device(config['n_gpu']) 27 | self.model = model.to(self.device) 28 | if len(self.device_ids) > 1: 29 | self.model = torch.nn.DataParallel(model, device_ids=self.device_ids) 30 | 31 | self.loss = loss 32 | self.metrics = metrics 33 | self.optimizer = optimizer 34 | self.save_multiple = True 35 | 36 | # configuration to monitor model performance and save best 37 | if self.monitor == 'off': 38 | self.mnt_mode = 'off' 39 | self.mnt_best = 0 40 | else: 41 | self.mnt_mode, self.mnt_metric = self.monitor.split() 42 | assert self.mnt_mode in ['min', 'max'] 43 | 44 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 45 | self.early_stop = cfg_trainer.get('early_stop', inf) 46 | 47 | self.start_epoch = 1 48 | 49 | self.checkpoint_dir = config.save_dir 50 | 51 | # setup visualization writer instance 52 | if "trainer" in config.config: 53 | self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) 54 | 55 | if config.resume is not None: 56 | self._resume_checkpoint(config.resume) 57 | 58 | @abstractmethod 59 | def _train_epoch(self, epoch): 60 | """ 61 | Training logic for an epoch 62 | 63 | :param epoch: Current epoch number 64 | """ 65 | raise NotImplementedError 66 | 67 | def train(self): 68 | """ 69 | Full training logic 70 | """ 71 | not_improved_count = 0 72 | for epoch in range(self.start_epoch, self.epochs + 1): 73 | result = self._train_epoch(epoch) 74 | 75 | # save logged informations into log dict 76 | log = {'epoch': epoch} 77 | for key, value in result.items(): 78 | if key == 'metrics': 79 | log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)}) 80 | elif key == 'val_metrics': 81 | log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)}) 82 | else: 83 | log[key] = value 84 | 85 | # print logged informations to the screen 86 | for key, value in log.items(): 87 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 88 | 89 | # evaluate model performance according to configured metric, save best checkpoint as model_best 90 | best = False 91 | if self.mnt_mode != 'off': 92 | try: 93 | # check whether model performance improved or not, according to specified metric(mnt_metric) 94 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 95 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 96 | except KeyError: 97 | self.logger.warning("Warning: Metric '{}' is not found. " 98 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 99 | self.mnt_mode = 'off' 100 | improved = False 101 | 102 | if improved: 103 | self.mnt_best = log[self.mnt_metric] 104 | not_improved_count = 0 105 | best = True 106 | else: 107 | not_improved_count += 1 108 | 109 | if not_improved_count > self.early_stop: 110 | self.logger.info("Validation performance didn\'t improve for {} epochs. 
" 111 | "Training stops.".format(self.early_stop)) 112 | break 113 | 114 | if epoch % self.save_period == 0: 115 | self._save_checkpoint(epoch, save_best=best) 116 | 117 | def _prepare_device(self, n_gpu_use): 118 | """ 119 | setup GPU device if available, move model into configured device 120 | """ 121 | n_gpu = torch.cuda.device_count() 122 | if n_gpu_use > 0 and n_gpu == 0: 123 | self.logger.warning("Warning: There\'s no GPU available on this machine," 124 | "training will be performed on CPU.") 125 | n_gpu_use = 0 126 | if n_gpu_use > n_gpu: 127 | self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available " 128 | "on this machine.".format(n_gpu_use, n_gpu)) 129 | n_gpu_use = n_gpu 130 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 131 | list_ids = list(range(n_gpu_use)) 132 | return device, list_ids 133 | 134 | def _save_checkpoint(self, epoch, save_best=False): 135 | """ 136 | Saving checkpoints 137 | 138 | :param epoch: current epoch number 139 | :param log: logging information of the epoch 140 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 141 | """ 142 | arch = type(self.model).__name__ 143 | state = { 144 | 'arch': arch, 145 | 'epoch': epoch, 146 | 'state_dict': self.model.state_dict(), 147 | 'optimizer': self.optimizer.state_dict(), 148 | 'monitor_best': self.mnt_best, 149 | 'config': self.config 150 | } 151 | if self.save_multiple: 152 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) 153 | torch.save(state, filename) 154 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 155 | else: 156 | filename = str(self.checkpoint_dir / 'checkpoint.pth') 157 | torch.save(state, filename) 158 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 159 | if save_best: 160 | best_path = str(self.checkpoint_dir / 'model_best.pth') 161 | torch.save(state, best_path) 162 | self.logger.info("Saving current best: model_best.pth ...") 163 | 164 | def _resume_checkpoint(self, resume_path): 165 | """ 166 | Resume from saved checkpoints 167 | 168 | :param resume_path: Checkpoint path to be resumed 169 | """ 170 | resume_path = str(resume_path) 171 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 172 | checkpoint = torch.load(resume_path) 173 | self.start_epoch = checkpoint['epoch'] + 1 174 | self.mnt_best = checkpoint['monitor_best'] 175 | 176 | # load architecture params from checkpoint. 177 | if checkpoint['config']['arch'] != self.config['arch']: 178 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 179 | "checkpoint. This may yield an exception while state_dict is being loaded.") 180 | # self.model.load_state_dict(checkpoint['state_dict']) 181 | checkpoint_state_dict = filter_state_dict(checkpoint["state_dict"], checkpoint["arch"] == "DataParallel" and len(self.device_ids) == 1) 182 | self.model.load_state_dict(checkpoint_state_dict, strict=False) 183 | 184 | # load optimizer state from checkpoint only when optimizer type is not changed. 185 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 186 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 187 | "Optimizer parameters not being resumed.") 188 | else: 189 | self.optimizer.load_state_dict(checkpoint['optimizer']) 190 | 191 | self.logger.info("Checkpoint loaded. 
Resume training from epoch {}".format(self.start_epoch)) 192 | -------------------------------------------------------------------------------- /configs/evaluate/eval_mr_effinetb5.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Eval", 3 | "n_gpu": 8, 4 | "timestamp_replacement": "00", 5 | "models": [ 6 | { 7 | "type": "DyMultiDepthModel", 8 | "args": { 9 | "inv_depth_min_max": [ 10 | 0.33, 11 | 0.0025 12 | ], 13 | "checkpoint_location": [ 14 | "./ckpt/mr_kitti_effib5.pth" 15 | ], 16 | "pretrain_mode": 1, 17 | "pretrain_dropout": 0, 18 | "use_stereo": false, 19 | "use_mono": true, 20 | "use_ssim": 1, 21 | "fusion_type": "ccf_fusion", 22 | "input_size": [256, 512], 23 | "freeze_backbone": false, 24 | "backbone_type": "efficientnetb5" 25 | } 26 | } 27 | ], 28 | "data_loader": { 29 | "type": "KittiOdometryDataloader", 30 | "args": { 31 | "dataset_dir": "./data/dataset/", 32 | "depth_folder": "image_depth_annotated", 33 | "batch_size": 1, 34 | "frame_count": 2, 35 | "shuffle": false, 36 | "validation_split": 0, 37 | "num_workers": 8, 38 | "sequences": [ 39 | "00", 40 | "04", 41 | "05", 42 | "07" 43 | ], 44 | "target_image_size": [ 45 | 256, 46 | 512 47 | ], 48 | "use_color": true, 49 | "use_color_augmentation": false, 50 | "use_dso_poses": true, 51 | "lidar_depth": true, 52 | "dso_depth": false, 53 | "return_stereo": false, 54 | 55 | "return_mvobj_mask": 1 56 | } 57 | }, 58 | "loss": "abs_silog_loss_virtualnormal", 59 | "metrics": [ 60 | "abs_rel_sparse_metric", 61 | "sq_rel_sparse_metric", 62 | "rmse_sparse_metric", 63 | "rmse_log_sparse_metric", 64 | "a1_sparse_metric", 65 | "a2_sparse_metric", 66 | "a3_sparse_metric" 67 | ], 68 | "evaluater": { 69 | "save_dir": "./save_dir/", 70 | "max_distance": 80, 71 | "verbosity": 2, 72 | "log_step": 20, 73 | "tensorboard": false, 74 | "eval_mono": false 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /configs/evaluate/eval_mr_resnet18.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Eval", 3 | "n_gpu": 8, 4 | "timestamp_replacement": "00", 5 | "models": [ 6 | { 7 | "type": "DyMultiDepthModel", 8 | "args": { 9 | "inv_depth_min_max": [ 10 | 0.33, 11 | 0.0025 12 | ], 13 | "checkpoint_location": [ 14 | "./ckpt/mr_kitti_effib5.pth" 15 | ], 16 | "pretrain_mode": 1, 17 | "pretrain_dropout": 0, 18 | "use_stereo": false, 19 | "use_mono": true, 20 | "use_ssim": 1, 21 | "fusion_type": "ccf_fusion", 22 | "input_size": [256, 512] 23 | } 24 | } 25 | ], 26 | "data_loader": { 27 | "type": "KittiOdometryDataloader", 28 | "args": { 29 | "dataset_dir": "./data/dataset/", 30 | "depth_folder": "image_depth_annotated", 31 | "batch_size": 1, 32 | "frame_count": 2, 33 | "shuffle": false, 34 | "validation_split": 0, 35 | "num_workers": 8, 36 | "sequences": [ 37 | "00", 38 | "04", 39 | "05", 40 | "07" 41 | ], 42 | "target_image_size": [ 43 | 256, 44 | 512 45 | ], 46 | "use_color": true, 47 | "use_color_augmentation": false, 48 | "use_dso_poses": true, 49 | "lidar_depth": true, 50 | "dso_depth": false, 51 | "return_stereo": false, 52 | "return_mvobj_mask": 1 53 | } 54 | }, 55 | "loss": "abs_silog_loss_virtualnormal", 56 | "metrics": [ 57 | "abs_rel_sparse_metric", 58 | "sq_rel_sparse_metric", 59 | "rmse_sparse_metric", 60 | "rmse_log_sparse_metric", 61 | "a1_sparse_metric", 62 | "a2_sparse_metric", 63 | "a3_sparse_metric" 64 | ], 65 | "evaluater": { 66 | "save_dir": "./save_dir/", 67 | "max_distance": 80, 68 | 
"verbosity": 2, 69 | "log_step": 20, 70 | "tensorboard": false, 71 | "eval_mono": false 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /configs/train/train_mr_effinetb5.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "dy_multi_depth", 3 | "n_gpu": 8, 4 | "arch": { 5 | "type": "DyMultiDepthModel", 6 | "args": { 7 | "pretrain_mode": 1, 8 | "pretrain_dropout": 0.0, 9 | "augmentation": "depth", 10 | "use_mono": true, 11 | "use_stereo": false, 12 | "checkpoint_location": [], 13 | "fusion_type": "ccf_fusion", 14 | "input_size": [256, 512], 15 | "freeze_backbone": false, 16 | "backbone_type": "efficientnetb5" 17 | } 18 | }, 19 | "data_loader": { 20 | "type": "KittiOdometryDataloader", 21 | "args": { 22 | "dataset_dir": "./data/dataset/", 23 | "depth_folder": "image_depth_annotated", 24 | "batch_size": 8, 25 | "frame_count": 2, 26 | "shuffle": true, 27 | "validation_split": 0, 28 | "num_workers": 16, 29 | "sequences": [ 30 | "01", 31 | "02", 32 | "06", 33 | "08", 34 | "09", 35 | "10" 36 | ], 37 | "target_image_size": [ 38 | 256, 39 | 512 40 | ], 41 | "use_color": true, 42 | "use_color_augmentation": true, 43 | "use_dso_poses": true, 44 | "lidar_depth": true, 45 | "dso_depth": false, 46 | "return_stereo": false, 47 | "return_mvobj_mask": true 48 | } 49 | }, 50 | "val_data_loader": { 51 | "type": "KittiOdometryDataloader", 52 | "args": { 53 | "dataset_dir": "./data/dataset/", 54 | "depth_folder": "image_depth_annotated", 55 | "batch_size": 16, 56 | "frame_count": 2, 57 | "shuffle": false, 58 | "validation_split": 0, 59 | "num_workers": 2, 60 | "sequences": [ 61 | "00", 62 | "04", 63 | "05", 64 | "07" 65 | ], 66 | "target_image_size": [ 67 | 256, 68 | 512 69 | ], 70 | "max_length": 32, 71 | "use_color": true, 72 | "use_color_augmentation": true, 73 | "use_dso_poses": true, 74 | "lidar_depth": true, 75 | "dso_depth": false, 76 | "return_stereo": false, 77 | 78 | 79 | "return_mvobj_mask": true 80 | } 81 | }, 82 | "optimizer": { 83 | "type": "Adam", 84 | "args": { 85 | "lr": 1e-4, 86 | "weight_decay": 0, 87 | "amsgrad": true 88 | } 89 | }, 90 | 91 | 92 | 93 | "loss": "abs_silog_loss_virtualnormal", 94 | 95 | 96 | "metrics": [ 97 | "a1_sparse_metric", 98 | "abs_rel_sparse_metric", 99 | "rmse_sparse_metric" 100 | ], 101 | "lr_scheduler": { 102 | "type": "StepLR", 103 | "args": { 104 | "step_size": 65, 105 | "gamma": 0.1 106 | } 107 | }, 108 | "trainer": { 109 | "compute_mask": false, 110 | "compute_stereo_pred": false, 111 | "epochs": 80, 112 | "save_dir": "./saved_model/", 113 | "save_period": 1, 114 | "verbosity": 2, 115 | "log_step": 4800, 116 | "val_log_step": 40, 117 | "alpha": 0.5, 118 | "max_distance": 80, 119 | "monitor": "min abs_rel_sparse_metric", 120 | "timestamp_replacement": "00", 121 | "tensorboard": true 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /configs/train/train_mr_resnet18.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "dy_multi_depth", 3 | "n_gpu": 8, 4 | "arch": { 5 | "type": "DyMultiDepthModel", 6 | "args": { 7 | "pretrain_mode": 1, 8 | "pretrain_dropout": 0.0, 9 | "augmentation": "depth", 10 | "use_mono": true, 11 | "use_stereo": false, 12 | "checkpoint_location": [], 13 | "fusion_type": "ccf_fusion", 14 | "input_size": [256, 512] 15 | } 16 | }, 17 | "data_loader": { 18 | "type": "KittiOdometryDataloader", 19 | "args": { 20 | "dataset_dir": 
"./data/dataset/", 21 | "depth_folder": "image_depth_annotated", 22 | "batch_size": 8, 23 | "frame_count": 2, 24 | "shuffle": true, 25 | "validation_split": 0, 26 | "num_workers": 16, 27 | "sequences": [ 28 | "01", 29 | "02", 30 | "06", 31 | "08", 32 | "09", 33 | "10" 34 | ], 35 | "target_image_size": [ 36 | 256, 37 | 512 38 | ], 39 | "use_color": true, 40 | "use_color_augmentation": true, 41 | "use_dso_poses": true, 42 | "lidar_depth": true, 43 | "dso_depth": false, 44 | "return_stereo": false, 45 | 46 | 47 | "return_mvobj_mask": true 48 | } 49 | }, 50 | "val_data_loader": { 51 | "type": "KittiOdometryDataloader", 52 | "args": { 53 | "dataset_dir": "./data/dataset/", 54 | "depth_folder": "image_depth_annotated", 55 | "batch_size": 16, 56 | "frame_count": 2, 57 | "shuffle": false, 58 | "validation_split": 0, 59 | "num_workers": 2, 60 | "sequences": [ 61 | "00", 62 | "04", 63 | "05", 64 | "07" 65 | ], 66 | "target_image_size": [ 67 | 256, 68 | 512 69 | ], 70 | "max_length": 32, 71 | "use_color": true, 72 | "use_color_augmentation": true, 73 | "use_dso_poses": true, 74 | "lidar_depth": true, 75 | "dso_depth": false, 76 | "return_stereo": false, 77 | 78 | 79 | "return_mvobj_mask": true 80 | } 81 | }, 82 | "optimizer": { 83 | "type": "Adam", 84 | "args": { 85 | "lr": 1e-4, 86 | "weight_decay": 0, 87 | "amsgrad": true 88 | } 89 | }, 90 | 91 | 92 | 93 | "loss": "abs_silog_loss_virtualnormal", 94 | 95 | 96 | "metrics": [ 97 | "a1_sparse_metric", 98 | "abs_rel_sparse_metric", 99 | "rmse_sparse_metric" 100 | ], 101 | "lr_scheduler": { 102 | "type": "StepLR", 103 | "args": { 104 | "step_size": 65, 105 | "gamma": 0.1 106 | } 107 | }, 108 | "trainer": { 109 | "compute_mask": false, 110 | "compute_stereo_pred": false, 111 | "epochs": 80, 112 | "save_dir": "./saved_model/", 113 | "save_period": 1, 114 | "verbosity": 2, 115 | "log_step": 4800, 116 | "val_log_step": 40, 117 | "alpha": 0.5, 118 | "max_distance": 80, 119 | "monitor": "min abs_rel_sparse_metric", 120 | "timestamp_replacement": "00", 121 | "tensorboard": true 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /create_pointcloud.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | 8 | import data_loader.data_loaders as module_data 9 | import model.model as module_arch 10 | from utils.parse_config import ConfigParser 11 | from utils import to, PLYSaver, DS_Wrapper 12 | 13 | import torch.nn.functional as F 14 | 15 | 16 | def main(config): 17 | logger = config.get_logger('test') 18 | 19 | output_dir = Path(config.config.get("output_dir", "saved")) 20 | output_dir.mkdir(exist_ok=True, parents=True) 21 | file_name = config.config.get("file_name", "pc.ply") 22 | use_mask = config.config.get("use_mask", True) 23 | roi = config.config.get("roi", None) 24 | 25 | max_d = config.config.get("max_d", 30) 26 | min_d = config.config.get("min_d", 3) 27 | 28 | start = config.config.get("start", 0) 29 | end = config.config.get("end", -1) 30 | 31 | # setup data_loader instances 32 | data_loader = DataLoader(DS_Wrapper(config.initialize('data_set', module_data), start=start, end=end), batch_size=1, shuffle=False, num_workers=8) 33 | 34 | # build model architecture 35 | model = config.initialize('arch', module_arch) 36 | logger.info(model) 37 | 38 | if config['n_gpu'] > 1: 39 | model = torch.nn.DataParallel(model) 40 | 41 | # prepare model 
for testing 42 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 43 | model = model.to(device) 44 | model.eval() 45 | 46 | mask_fill = 32 47 | 48 | n = data_loader.batch_size 49 | 50 | target_image_size = data_loader.dataset.dataset.target_image_size 51 | 52 | plysaver = PLYSaver(target_image_size[0], target_image_size[1], min_d=min_d, max_d=max_d, batch_size=n, roi=roi, dropout=.75) 53 | plysaver.to(device) 54 | 55 | pose_buffer = [] 56 | intrinsics_buffer = [] 57 | mask_buffer = [] 58 | keyframe_buffer = [] 59 | depth_buffer = [] 60 | 61 | buffer_length = 5 62 | min_hits = 1 63 | key_index = buffer_length // 2 64 | 65 | with torch.no_grad(): 66 | for i, (data, target) in enumerate(tqdm(data_loader)): 67 | data = to(data, device) 68 | # if not torch.any(pose_distance_thresh(data, spatial_thresh=1)): 69 | # continue 70 | result = model(data) 71 | if not isinstance(result, dict): 72 | result = {"result": result[0]} 73 | output = result["result"] 74 | if "cv_mask" not in result: 75 | result["cv_mask"] = output.new_zeros(output.shape) 76 | # mask = ((result["cv_mask"] >= .1) & (output >= 1 / max_d)).to(dtype=torch.float32) 77 | mask = (result["cv_mask"] >= .1).to(dtype=torch.float32) 78 | mask = (F.conv2d(mask, mask.new_ones((1, 1, mask_fill+1, mask_fill+1)), padding=mask_fill // 2) < 1).to(dtype=torch.float32) 79 | 80 | pose_buffer += data["keyframe_pose"] 81 | intrinsics_buffer += [data["keyframe_intrinsics"]] 82 | mask_buffer += [mask] 83 | keyframe_buffer += [data["keyframe"]] 84 | depth_buffer += [output] 85 | 86 | if len(pose_buffer) >= buffer_length: 87 | pose = pose_buffer[key_index] 88 | intrinsics = intrinsics_buffer[key_index] 89 | keyframe = keyframe_buffer[key_index] 90 | depth = depth_buffer[key_index] 91 | 92 | mask = (torch.sum(torch.stack(mask_buffer), dim=0) > buffer_length - min_hits).to(dtype=torch.float32) 93 | if use_mask: 94 | depth *= mask 95 | 96 | plysaver.add_depthmap(depth, keyframe, intrinsics, pose) 97 | 98 | del pose_buffer[0] 99 | del intrinsics_buffer[0] 100 | del mask_buffer[0] 101 | del keyframe_buffer[0] 102 | del depth_buffer[0] 103 | 104 | with open(output_dir / file_name, "wb") as f: 105 | plysaver.save(f) 106 | 107 | 108 | if __name__ == '__main__': 109 | args = argparse.ArgumentParser(description='PyTorch Template') 110 | 111 | args.add_argument('-r', '--resume', default=None, type=str, 112 | help='path to latest checkpoint (default: None)') 113 | args.add_argument('-c', '--config', default=None, type=str, 114 | help='config file path (default: None)') 115 | args.add_argument('-d', '--device', default=None, type=str, 116 | help='indices of GPUs to enable (default: all)') 117 | 118 | config = ConfigParser(args) 119 | main(config) 120 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/data_loader/__init__.py -------------------------------------------------------------------------------- /data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | from base import BaseDataLoader 2 | 3 | from .kitti_odometry_dataset import * 4 | 5 | class KittiOdometryDataloader(BaseDataLoader): 6 | 7 | def __init__(self, batch_size=1, shuffle=True, validation_split=0.0, num_workers=4, **kwargs): 8 | self.dataset = KittiOdometryDataset(**kwargs) 9 | 
super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) -------------------------------------------------------------------------------- /data_loader/kitti_odometry_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pykitti 6 | import torch 7 | import torchvision 8 | from PIL import Image 9 | from scipy import sparse 10 | from skimage.transform import resize 11 | from torch.utils.data import Dataset 12 | 13 | from utils import map_fn 14 | import torchvision.transforms.functional as F 15 | 16 | 17 | class KittiOdometryDataset(Dataset): 18 | 19 | def __init__(self, dataset_dir, frame_count=2, sequences=None, depth_folder="image_depth_annotated", 20 | target_image_size=(256, 512), max_length=None, dilation=1, offset_d=0, use_color=True, use_dso_poses=False, use_color_augmentation=False, lidar_depth=False, dso_depth=True, annotated_lidar=True, return_stereo=False, return_mvobj_mask=False, use_index_mask=()): 21 | """ 22 | Dataset implementation for KITTI Odometry. 23 | :param dataset_dir: Top level folder for KITTI Odometry (should contain folders sequences, poses, poses_dvso (if available) 24 | :param frame_count: Number of frames used per sample (excluding the keyframe). By default, the keyframe is in the middle of those frames. (Default=2) 25 | :param sequences: Which sequences to use. Should be tuple of strings, e.g. ("00", "01", ...) 26 | :param depth_folder: The folder within the sequence folder that contains the depth information (e.g. sequences/00/{depth_folder}) 27 | :param target_image_size: Desired image size (correct processing of depths is only guaranteed for default value). (Default=(256, 512)) 28 | :param max_length: Maximum length per sequence. Useful for splitting up sequences and testing. (Default=None) 29 | :param dilation: Spacing between the frames (Default 1) 30 | :param offset_d: Index offset for frames (offset_d=0 means keyframe is centered). (Default=0) 31 | :param use_color: Use color (camera 2) or greyscale (camera 0) images (default=True) 32 | :param use_dso_poses: Use poses provided by d(v)so instead of KITTI poses. Requires poses_dvso folder. (Default=True) 33 | :param use_color_augmentation: Use color jitter augmentation. The same transformation is applied to all frames in a sample. (Default=False) 34 | :param lidar_depth: Use depth information from (annotated) velodyne data. (Default=False) 35 | :param dso_depth: Use depth information from d(v)so. (Default=True) 36 | :param annotated_lidar: If lidar_depth=True, then this determines whether to use annotated or non-annotated depth maps. (Default=True) 37 | :param return_mvobj_mask: Return additional moving object mask. 
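        :param return_stereo: Additionally return data of the corresponding stereo (right-camera) frame. (Default=False)
        :param use_index_mask: Names of JSON index masks inside each sequence folder used to filter out invalid frame indices. (Default=())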
38 | """ 39 | self.dataset_dir = Path(dataset_dir) 40 | self.frame_count = frame_count 41 | self.sequences = sequences 42 | self.depth_folder = depth_folder 43 | self.lidar_depth = lidar_depth 44 | self.annotated_lidar = annotated_lidar 45 | self.dso_depth = dso_depth 46 | self.target_image_size = target_image_size 47 | self.use_index_mask = use_index_mask 48 | self.offset_d = offset_d 49 | if self.sequences is None: 50 | self.sequences = [f"{i:02d}" for i in range(11)] 51 | self._datasets = [pykitti.odometry(dataset_dir, sequence) for sequence in self.sequences] 52 | self._offset = (frame_count // 2) * dilation 53 | extra_frames = frame_count * dilation 54 | if self.annotated_lidar and self.lidar_depth: 55 | extra_frames = max(extra_frames, 10) 56 | self._offset = max(self._offset, 5) 57 | self._dataset_sizes = [ 58 | len((dataset.cam0_files if not use_color else dataset.cam2_files)) - (extra_frames if self.use_index_mask is None else 0) for dataset in 59 | self._datasets] 60 | if self.use_index_mask is not None: 61 | index_masks = [] 62 | for sequence_length, sequence in zip(self._dataset_sizes, self.sequences): 63 | index_mask = {i:True for i in range(sequence_length)} 64 | for index_mask_name in self.use_index_mask: 65 | with open(self.dataset_dir / "sequences" / sequence / (index_mask_name + ".json")) as f: 66 | m = json.load(f) 67 | for k in list(index_mask.keys()): 68 | if not str(k) in m or not m[str(k)]: 69 | del index_mask[k] 70 | index_masks.append(index_mask) 71 | self._indices = [ 72 | list(sorted([int(k) for k in sorted(index_mask.keys()) if index_mask[k] and int(k) >= self._offset and int(k) < dataset_size + self._offset - extra_frames])) 73 | for index_mask, dataset_size in zip(index_masks, self._dataset_sizes) 74 | ] 75 | self._dataset_sizes = [len(indices) for indices in self._indices] 76 | if max_length is not None: 77 | self._dataset_sizes = [min(s, max_length) for s in self._dataset_sizes] 78 | self.length = sum(self._dataset_sizes) 79 | 80 | intrinsics_box = [self.compute_target_intrinsics(dataset, target_image_size, use_color) for dataset in 81 | self._datasets] 82 | self._crop_boxes = [b for _, b in intrinsics_box] 83 | if self.dso_depth: 84 | self.dso_depth_parameters = [self.get_dso_depth_parameters(dataset) for dataset in self._datasets] 85 | elif not self.lidar_depth: 86 | self._depth_crop_boxes = [ 87 | self.compute_depth_crop(self.dataset_dir / "sequences" / s / depth_folder) for s in 88 | self.sequences] 89 | self._intrinsics = [format_intrinsics(i, self.target_image_size) for i, _ in intrinsics_box] 90 | self.dilation = dilation 91 | self.use_color = use_color 92 | self.use_dso_poses = use_dso_poses 93 | self.use_color_augmentation = use_color_augmentation 94 | if self.use_dso_poses: 95 | for dataset in self._datasets: 96 | dataset.pose_path = self.dataset_dir / "poses_dvso" 97 | dataset._load_poses() 98 | if self.use_color_augmentation: 99 | self.color_transform = ColorJitterMulti(brightness=.2, contrast=.2, saturation=.2, hue=.1) 100 | self.return_stereo = return_stereo 101 | self.return_mvobj_mask = return_mvobj_mask 102 | 103 | def get_dataset_index(self, index: int): 104 | for dataset_index, dataset_size in enumerate(self._dataset_sizes): 105 | if index >= dataset_size: 106 | index = index - dataset_size 107 | else: 108 | return dataset_index, index 109 | return None, None 110 | 111 | def preprocess_image(self, img: Image.Image, crop_box=None): 112 | if crop_box: 113 | img = img.crop(crop_box) 114 | if self.target_image_size: 115 | img = 
img.resize((self.target_image_size[1], self.target_image_size[0]), resample=Image.BILINEAR) 116 | if self.use_color_augmentation: 117 | img = self.color_transform(img) 118 | image_tensor = torch.tensor(np.array(img).astype(np.float32)) 119 | image_tensor = image_tensor / 255 - .5 120 | if not self.use_color: 121 | image_tensor = torch.stack((image_tensor, image_tensor, image_tensor)) 122 | else: 123 | image_tensor = image_tensor.permute(2, 0, 1) 124 | del img 125 | return image_tensor 126 | 127 | def preprocess_depth(self, depth: np.ndarray, crop_box=None): 128 | if crop_box: 129 | if crop_box[1] >= 0 and crop_box[3] <= depth.shape[0]: 130 | depth = depth[int(crop_box[1]):int(crop_box[3]), :] 131 | else: 132 | depth_ = np.ones((crop_box[3] - crop_box[1], depth.shape[1])) 133 | depth_[-crop_box[1]:-crop_box[1]+depth.shape[0], :] = depth 134 | depth = depth_ 135 | if crop_box[0] >= 0 and crop_box[2] <= depth.shape[1]: 136 | depth = depth[:, int(crop_box[0]):int(crop_box[2])] 137 | else: 138 | depth_ = np.ones((depth.shape[0], crop_box[2] - crop_box[0])) 139 | depth_[:, -crop_box[0]:-crop_box[0]+depth.shape[1]] = depth 140 | depth = depth_ 141 | if self.target_image_size: 142 | depth = resize(depth, self.target_image_size, order=0) 143 | return torch.tensor(1 / depth) 144 | 145 | def preprocess_depth_dso(self, depth: Image.Image, dso_depth_parameters, crop_box=None): 146 | h, w, f_x = dso_depth_parameters 147 | depth = np.array(depth, dtype=np.float) 148 | indices = np.array(np.nonzero(depth), dtype=np.float) 149 | indices[0] = np.clip(indices[0] / depth.shape[0] * h, 0, h-1) 150 | indices[1] = np.clip(indices[1] / depth.shape[1] * w, 0, w-1) 151 | 152 | depth = depth[depth > 0] 153 | depth = (w * depth / (0.54 * f_x * 65535)) 154 | 155 | data = np.concatenate([indices, np.expand_dims(depth, axis=0)], axis=0) 156 | 157 | if crop_box: 158 | data = data[:, (crop_box[1] <= data[0, :]) & (data[0, :] < crop_box[3]) & (crop_box[0] <= data[1, :]) & (data[1, :] < crop_box[2])] 159 | data[0, :] -= crop_box[1] 160 | data[1, :] -= crop_box[0] 161 | crop_height = crop_box[3] - crop_box[1] 162 | crop_width = crop_box[2] - crop_box[0] 163 | else: 164 | crop_height = h 165 | crop_width = w 166 | 167 | data[0] = np.clip(data[0] / crop_height * self.target_image_size[0], 0, self.target_image_size[0]-1) 168 | data[1] = np.clip(data[1] / crop_width * self.target_image_size[1], 0, self.target_image_size[1]-1) 169 | 170 | depth = np.zeros(self.target_image_size) 171 | depth[np.around(data[0]).astype(np.int), np.around(data[1]).astype(np.int)] = data[2] 172 | 173 | return torch.tensor(depth, dtype=torch.float32) 174 | 175 | def preprocess_depth_annotated_lidar(self, depth: Image.Image, crop_box=None): 176 | depth = np.array(depth, dtype=np.float) 177 | h, w = depth.shape 178 | indices = np.array(np.nonzero(depth), dtype=np.float) 179 | 180 | depth = depth[depth > 0] 181 | depth = 256.0 / depth 182 | 183 | data = np.concatenate([indices, np.expand_dims(depth, axis=0)], axis=0) 184 | 185 | if crop_box: 186 | data = data[:, (crop_box[1] <= data[0, :]) & (data[0, :] < crop_box[3]) & (crop_box[0] <= data[1, :]) & ( 187 | data[1, :] < crop_box[2])] 188 | data[0, :] -= crop_box[1] 189 | data[1, :] -= crop_box[0] 190 | crop_height = crop_box[3] - crop_box[1] 191 | crop_width = crop_box[2] - crop_box[0] 192 | else: 193 | crop_height = h 194 | crop_width = w 195 | 196 | data[0] = np.clip(data[0] / crop_height * self.target_image_size[0], 0, self.target_image_size[0] - 1) 197 | data[1] = np.clip(data[1] / crop_width * 
self.target_image_size[1], 0, self.target_image_size[1] - 1) 198 | 199 | depth = np.zeros(self.target_image_size) 200 | depth[np.around(data[0]).astype(np.int), np.around(data[1]).astype(np.int)] = data[2] 201 | 202 | return torch.tensor(depth, dtype=torch.float32) 203 | 204 | def __getitem__(self, index: int): 205 | dataset_index, index = self.get_dataset_index(index) 206 | if dataset_index is None: 207 | raise IndexError() 208 | 209 | if self.use_index_mask is not None: 210 | index = self._indices[dataset_index][index] - self._offset 211 | 212 | sequence_folder = self.dataset_dir / "sequences" / self.sequences[dataset_index] 213 | depth_folder = sequence_folder / self.depth_folder 214 | 215 | if self.use_color_augmentation: 216 | self.color_transform.fix_transform() 217 | 218 | dataset = self._datasets[dataset_index] 219 | keyframe_intrinsics = self._intrinsics[dataset_index] 220 | if not (self.lidar_depth or self.dso_depth): 221 | keyframe_depth = self.preprocess_depth(np.load(depth_folder / f"{(index + self._offset):06d}.npy"), self._depth_crop_boxes[dataset_index]).type(torch.float32).unsqueeze(0) 222 | else: 223 | if self.lidar_depth: 224 | if not self.annotated_lidar: 225 | lidar_depth = 1 / torch.tensor(sparse.load_npz(depth_folder / f"{(index + self._offset):06d}.npz").todense()).type(torch.float32).unsqueeze(0) 226 | lidar_depth[torch.isinf(lidar_depth)] = 0 227 | keyframe_depth = lidar_depth 228 | else: 229 | keyframe_depth = self.preprocess_depth_annotated_lidar(Image.open(depth_folder / f"{(index + self._offset):06d}.png"), self._crop_boxes[dataset_index]).unsqueeze(0) 230 | else: 231 | keyframe_depth = torch.zeros(1, self.target_image_size[0], self.target_image_size[1], dtype=torch.float32) 232 | 233 | if self.dso_depth: 234 | dso_depth = self.preprocess_depth_dso(Image.open(depth_folder / f"{(index + self._offset):06d}.png"), self.dso_depth_parameters[dataset_index], self._crop_boxes[dataset_index]).unsqueeze(0) 235 | mask = dso_depth == 0 236 | dso_depth[mask] = keyframe_depth[mask] 237 | keyframe_depth = dso_depth 238 | 239 | keyframe = self.preprocess_image( 240 | (dataset.get_cam0 if not self.use_color else dataset.get_cam2)(index + self._offset), 241 | self._crop_boxes[dataset_index]) 242 | keyframe_pose = torch.tensor(dataset.poses[index + self._offset], dtype=torch.float32) 243 | 244 | frames = [self.preprocess_image((dataset.get_cam0 if not self.use_color else dataset.get_cam2)(index + self._offset + i + self.offset_d), 245 | self._crop_boxes[dataset_index]) for i in 246 | range(-(self.frame_count // 2) * self.dilation, ((self.frame_count + 1) // 2) * self.dilation + 1, self.dilation) if i != 0] 247 | intrinsics = [self._intrinsics[dataset_index] for _ in range(self.frame_count)] 248 | poses = [torch.tensor(dataset.poses[index + self._offset + i + self.offset_d], dtype=torch.float32) for i in 249 | range(-(self.frame_count // 2) * self.dilation, ((self.frame_count + 1) // 2) * self.dilation + 1, self.dilation) if i != 0] 250 | 251 | data = { 252 | "keyframe": keyframe, 253 | "keyframe_pose": keyframe_pose, 254 | "keyframe_intrinsics": keyframe_intrinsics, 255 | "frames": frames, 256 | "poses": poses, 257 | "intrinsics": intrinsics, 258 | "sequence": torch.tensor([int(self.sequences[dataset_index])], dtype=torch.int32), 259 | "image_id": torch.tensor([int(index + self._offset)], dtype=torch.int32) 260 | } 261 | 262 | if self.return_mvobj_mask > 0: 263 | mask = torch.tensor(np.load(sequence_folder / "mvobj_mask" / f"{index + self._offset:06d}.npy"), 
dtype=torch.float32).unsqueeze(0) 264 | data["mvobj_mask"] = mask 265 | 266 | return data, keyframe_depth 267 | 268 | def __len__(self) -> int: 269 | return self.length 270 | 271 | def compute_depth_crop(self, depth_folder): 272 | # This function is only used for dense gt depth maps. 273 | example_dm = np.load(depth_folder / "000000.npy") 274 | ry = example_dm.shape[0] / self.target_image_size[0] 275 | rx = example_dm.shape[1] / self.target_image_size[1] 276 | if ry < 1 or rx < 1: 277 | if ry >= rx: 278 | o_w = example_dm.shape[1] 279 | w = int(np.ceil(ry * self.target_image_size[1])) 280 | h = example_dm.shape[0] 281 | return ((o_w - w) // 2, 0, (o_w - w) // 2 + w, h) 282 | else: 283 | o_h = example_dm.shape[0] 284 | h = int(np.ceil(rx * self.target_image_size[0])) 285 | w = example_dm.shape[1] 286 | return (0, (o_h - h) // 2, w, (o_h - h) // 2 + h) 287 | if ry >= rx: 288 | o_h = example_dm.shape[0] 289 | h = rx * self.target_image_size[0] 290 | w = example_dm.shape[1] 291 | return (0, (o_h - h) // 2, w, (o_h - h) // 2 + h) 292 | else: 293 | o_w = example_dm.shape[1] 294 | w = ry * self.target_image_size[1] 295 | h = example_dm.shape[0] 296 | return ((o_w - w) // 2, 0, (o_w - w) // 2 + w, h) 297 | 298 | def compute_target_intrinsics(self, dataset, target_image_size, use_color): 299 | # Because of cropping and resizing of the frames, we need to recompute the intrinsics 300 | P_cam = dataset.calib.P_rect_00 if not use_color else dataset.calib.P_rect_20 301 | orig_size = tuple(reversed((dataset.cam0 if not use_color else dataset.cam2).__next__().size)) 302 | 303 | r_orig = orig_size[0] / orig_size[1] 304 | r_target = target_image_size[0] / target_image_size[1] 305 | 306 | if r_orig >= r_target: 307 | new_height = r_target * orig_size[1] 308 | box = (0, (orig_size[0] - new_height) // 2, orig_size[1], orig_size[0] - (orig_size[0] - new_height) // 2) 309 | 310 | c_x = P_cam[0, 2] / orig_size[1] 311 | c_y = (P_cam[1, 2] - (orig_size[0] - new_height) / 2) / new_height 312 | 313 | rescale = orig_size[1] / target_image_size[1] 314 | 315 | else: 316 | new_width = orig_size[0] / r_target 317 | box = ((orig_size[1] - new_width) // 2, 0, orig_size[1] - (orig_size[1] - new_width) // 2, orig_size[0]) 318 | 319 | c_x = (P_cam[0, 2] - (orig_size[1] - new_width) / 2) / new_width 320 | c_y = P_cam[1, 2] / orig_size[0] 321 | 322 | rescale = orig_size[0] / target_image_size[0] 323 | 324 | f_x = P_cam[0, 0] / target_image_size[1] / rescale 325 | f_y = P_cam[1, 1] / target_image_size[0] / rescale 326 | 327 | intrinsics = (f_x, f_y, c_x, c_y) 328 | 329 | return intrinsics, box 330 | 331 | def get_dso_depth_parameters(self, dataset): 332 | # Info required to process d(v)so depths 333 | P_cam = dataset.calib.P_rect_20 334 | orig_size = tuple(reversed(dataset.cam2.__next__().size)) 335 | return orig_size[0], orig_size[1], P_cam[0, 0] 336 | 337 | def get_index(self, sequence, index): 338 | for i in range(len(self.sequences)): 339 | if int(self.sequences[i]) != sequence: 340 | index += self._dataset_sizes[i] 341 | else: 342 | break 343 | return index 344 | 345 | 346 | def format_intrinsics(intrinsics, target_image_size): 347 | intrinsics_mat = torch.zeros(4, 4) 348 | intrinsics_mat[0, 0] = intrinsics[0] * target_image_size[1] 349 | intrinsics_mat[1, 1] = intrinsics[1] * target_image_size[0] 350 | intrinsics_mat[0, 2] = intrinsics[2] * target_image_size[1] 351 | intrinsics_mat[1, 2] = intrinsics[3] * target_image_size[0] 352 | intrinsics_mat[2, 2] = 1 353 | intrinsics_mat[3, 3] = 1 354 | return intrinsics_mat 355 | 356 
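# Illustrative helper (not used by the training/evaluation pipeline): prints how a single
# sample is structured. It assumes the KITTI Odometry data has been prepared as described
# in the README (sequences/, poses_dvso/ and image_depth_annotated/ under dataset_dir).
def _print_example_sample(dataset_dir="../data/dataset"):
    dataset = KittiOdometryDataset(
        dataset_dir=dataset_dir,
        sequences=("07",),
        depth_folder="image_depth_annotated",
        target_image_size=(256, 512),
        frame_count=2,
        use_dso_poses=True,
        lidar_depth=True,
        dso_depth=False,
        return_mvobj_mask=True,
    )
    data, keyframe_depth = dataset[0]
    # "frames", "poses" and "intrinsics" are lists with one entry per source frame.
    for key, value in data.items():
        shape = value.shape if torch.is_tensor(value) else [v.shape for v in value]
        print(key, shape)
    print("keyframe depth target:", keyframe_depth.shape)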
| 357 | 358 | # if torch.__version__ == "1.5.0": 359 | torchvision_version = torchvision.__version__.split(".")[1] 360 | 361 | if int(torchvision_version) < 9: 362 | class ColorJitterMulti(torchvision.transforms.ColorJitter): 363 | def fix_transform(self): 364 | self.transform = self.get_params(self.brightness, self.contrast, 365 | self.saturation, self.hue) 366 | 367 | def __call__(self, x): 368 | return map_fn(x, self.transform) 369 | else: 370 | class ColorJitterMulti(torchvision.transforms.ColorJitter): 371 | def fix_transform(self): 372 | self.params = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) 373 | 374 | def __call__(self, img): 375 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.params 376 | for fn_id in fn_idx: 377 | if fn_id == 0 and brightness_factor is not None: 378 | img = F.adjust_brightness(img, brightness_factor) 379 | elif fn_id == 1 and contrast_factor is not None: 380 | img = F.adjust_contrast(img, contrast_factor) 381 | elif fn_id == 2 and saturation_factor is not None: 382 | img = F.adjust_saturation(img, saturation_factor) 383 | elif fn_id == 3 and hue_factor is not None: 384 | img = F.adjust_hue(img, hue_factor) 385 | return img -------------------------------------------------------------------------------- /data_loader/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/data_loader/scripts/__init__.py -------------------------------------------------------------------------------- /data_loader/scripts/preprocess_kitti_transfer_gtdepth_to_odom.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from pathlib import Path 4 | from zipfile import ZipFile 5 | 6 | mapping = { 7 | "2011_10_03_drive_0027": "00", 8 | "2011_10_03_drive_0042": "01", 9 | "2011_10_03_drive_0034": "02", 10 | "2011_09_26_drive_0067": "03", 11 | "2011_09_30_drive_0016": "04", 12 | "2011_09_30_drive_0018": "05", 13 | "2011_09_30_drive_0020": "06", 14 | "2011_09_30_drive_0027": "07", 15 | "2011_09_30_drive_0028": "08", 16 | "2011_09_30_drive_0033": "09", 17 | "2011_09_30_drive_0034": "10" 18 | } 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser(description=''' 22 | This script creates depth images from annotated velodyne data. 
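    More precisely, it extracts the annotated camera-2 depth maps from the downloaded KITTI depth .zip and copies them into the corresponding KITTI Odometry sequence folders (sequences/{seq}/image_depth_annotated by default).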
23 | ''') 24 | parser.add_argument("--output", "-o", help="Path of KITTI odometry dataset", default="../../../data/dataset") 25 | parser.add_argument("--input", "-i", help="Path to KITTI depth dataset (zipped)", required=True) 26 | parser.add_argument("--depth_folder", "-d", help="Name of depth map folders for the respective sequences", default="image_depth_annotated") 27 | 28 | args = parser.parse_args() 29 | input = Path(args.input) 30 | output = Path(args.output) 31 | depth_folder = args.depth_folder 32 | 33 | drives = mapping.keys() 34 | 35 | print("Creating folder structure") 36 | for drive in drives: 37 | sequence = mapping[drive] 38 | folder = output/ "sequences" / sequence / depth_folder 39 | folder.mkdir(parents=True, exist_ok=True) 40 | print(folder) 41 | 42 | print("Extracting enhanced depth maps") 43 | 44 | with ZipFile(input) as depth_archive: 45 | for name in depth_archive.namelist(): 46 | if name[0] == "t": 47 | drive = name[6:27] 48 | else: 49 | drive = name[4:25] 50 | cam = name[-16] 51 | img = name[-10:] 52 | 53 | # the first frame of seq-08 corresponds to the 1100 frame of kitti raw 54 | if drive=='2011_09_30_drive_0028': 55 | raw_img_id = img.split('.')[0] 56 | raw_img_id = int(raw_img_id) 57 | if raw_img_id < 1100: 58 | continue 59 | else: 60 | img = "{:06d}.png".format(raw_img_id - 1100) 61 | 62 | if cam == '2' and drive in drives: 63 | to = output / "sequences" / mapping[drive] / depth_folder / img 64 | print(name, " -> ", to) 65 | with depth_archive.open(name) as i, open(to, 'wb') as o: 66 | shutil.copyfileobj(i, o) 67 | 68 | main() -------------------------------------------------------------------------------- /depth_proc_tools/plot_depth_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL.Image as pil 4 | import matplotlib as mpl 5 | import matplotlib.cm as cm 6 | import matplotlib.pyplot as plt 7 | import cv2 8 | from PIL import Image 9 | # import jax 10 | 11 | def save_pseudo_color_depth(depth:np.ndarray, colormap='rainbow', save_path="./"): 12 | return True 13 | 14 | def disp_to_depth_metric(disp, min_depth=0.1, max_depth=100.0): 15 | """Convert network's sigmoid output into depth prediction, ref: Monodepth2 16 | """ 17 | min_disp = 1 / max_depth 18 | max_disp = 1 / min_depth 19 | scaled_disp = min_disp + (max_disp - min_disp) * disp 20 | depth = 1 / scaled_disp 21 | return scaled_disp, depth 22 | 23 | 24 | 25 | def save_color_imgs(image:np.ndarray, save_id=None, post_fix="img", save_path="./"): 26 | ''' 27 | image: with shape (h,w,c=3) 28 | save_id = specify the name of the saved image 29 | ''' 30 | if not isinstance(image, np.ndarray): 31 | raise Exception("Input image is not a np.ndarray") 32 | if not len(image.shape) == 3: 33 | raise Exception("Wong input shape. 
It should be a 3-dim image vector of shape (h,w,c)") 34 | 35 | if save_id is None: 36 | dirnames = os.listdir(save_path) 37 | save_id = len(dirnames) 38 | 39 | save_name = os.path.join(save_path,"{}_{}.jpg".format(save_id,post_fix)) 40 | 41 | # for pytorch 42 | if image.shape[-1]==3 or image.shape[-1]==1: 43 | plt.imsave(save_name, image) 44 | else: 45 | raise Exception("invalid color channel of the last dim") 46 | 47 | print(f"successfully saved {save_name}!") 48 | 49 | 50 | def save_pseudo_color(input:np.ndarray, save_id=None, post_fix="error", pseudo_color="rainbow", save_path="./", vmax=None): 51 | ''' 52 | input: with shape (h,w,c=3) 53 | save_id = specify the name of the saved error map 54 | ''' 55 | if not isinstance(input, np.ndarray): 56 | raise Exception("Input input is not a np.ndarray") 57 | if not len(input.shape) == 3: 58 | raise Exception("Wong input shape. It should be a 3-dim input vector of shape (h,w,c)") 59 | if save_id is None: 60 | dirnames = os.listdir(save_path) 61 | save_id = len(dirnames) 62 | 63 | save_name = os.path.join(save_path,"{}_{}.jpg".format(save_id,post_fix)) 64 | 65 | # for pytorch 66 | if input.shape[-1]==1: 67 | disp_resized_np = input.squeeze(-1) 68 | # if "error_" in post_fix: 69 | # print("save/photomtric error {} map:{},max:{},min:{},mean:{}".format(post_fix, disp_resized_np[:20,0],disp_resized_np.max(),disp_resized_np.min(), 70 | # disp_resized_np.mean())) 71 | vmax = np.percentile(disp_resized_np, 95) if vmax is None else vmax 72 | normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax) 73 | mapper = cm.ScalarMappable(norm=normalizer, cmap=pseudo_color) 74 | colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8) 75 | im = pil.fromarray(colormapped_im) 76 | im.save(save_name) 77 | else: 78 | raise Exception("invalid color channel of the last dim") 79 | 80 | print(f"successfully saved {save_name}!") 81 | 82 | 83 | def numpy_intensitymap_to_pcolor(input, vmin=None, vmax=None, colormap='rainbow'): 84 | ''' 85 | input: h,w,1 86 | ''' 87 | if input.shape[-1]==1: 88 | colormapped_im = numpy_1d_to_coloruint8(input, vmin, vmax, colormap) 89 | im = pil.fromarray(colormapped_im.astype(np.uint8)) 90 | return im 91 | else: 92 | raise Exception("invalid color channel of the last dim") 93 | 94 | 95 | def numpy_1d_to_coloruint8(input, vmin=None, vmax=None, colormap='rainbow'): 96 | ''' 97 | input: h,w,1 98 | ''' 99 | if input.shape[-1]==1: 100 | input = input.squeeze(-1) 101 | invalid_mask = (input == 0).astype(float) 102 | vmax = np.percentile(input, 95) if vmax is None else vmax 103 | vmin = 1e-3 if vmin is None else vmin # vmin = input.min() if vmin is None else vmin 104 | normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax) 105 | mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap) 106 | colormapped_im = (mapper.to_rgba(input)[:, :, :3] * 255).astype(np.uint8) 107 | invalid_mask = np.expand_dims(invalid_mask,-1) 108 | colormapped_im = colormapped_im * (1-invalid_mask) + (invalid_mask * 255) 109 | return colormapped_im 110 | else: 111 | raise Exception("invalid color channel of the last dim") 112 | 113 | 114 | def numpy_distancemap_validmask_to_pil(input, vmax=None, colormap='rainbow'): 115 | if input.shape[-1]==1: 116 | # For PIL image, you [should] squeeze the addtional dim and set the resolution strictly to "h,w" 117 | input = input.squeeze(-1) 118 | validmask = (input != 0).astype(float) 119 | mask_img = (validmask*255.0).astype(np.uint8) 120 | im = pil.fromarray(mask_img) 121 | return im 122 | 
else: 123 | raise Exception("invalid color channel of the last dim") 124 | 125 | def numpy_rgb_to_pil(input): 126 | if input.shape[-1]==3: 127 | if input.max()<=1: 128 | colormapped_im = (input[:, :, :3] * 255).astype(np.uint8) 129 | im = pil.fromarray(colormapped_im) 130 | else: 131 | im = pil.fromarray(input) 132 | return im 133 | else: 134 | raise Exception("invalid color channel of the last dim") 135 | 136 | 137 | def get_error_map_value(pred_depth, gt_depth, grag_crop=True, median_scaling=True): 138 | ''' 139 | input shape: h,w,c 140 | ''' 141 | validmask = (gt_depth!=0).astype(bool) 142 | h,w,_ = gt_depth.shape 143 | 144 | if grag_crop: 145 | valid_area = (int(0.40810811 * h), int(0.99189189 * h), int(0.03594771 * w), int(0.96405229 * w)) 146 | area_mask = np.zeros(gt_depth.shape) 147 | area_mask[valid_area[0]:valid_area[1],valid_area[2]:valid_area[3],:] = 1.0 148 | validmask = (validmask * area_mask).astype(bool) 149 | 150 | if median_scaling: 151 | pred_median = np.median(pred_depth[validmask]) 152 | gt_median = np.median(gt_depth[validmask]) 153 | ratio = gt_median/pred_median 154 | pred_depth_rescale = pred_depth * ratio 155 | else: 156 | pred_depth_rescale = pred_depth 157 | 158 | absrel_map = np.zeros(gt_depth.shape) 159 | absrel_map[validmask] = np.abs(gt_depth[validmask]-pred_depth_rescale[validmask]) / gt_depth[validmask] 160 | 161 | absrel_val = absrel_map.sum() / validmask.sum() 162 | 163 | return absrel_map, absrel_val 164 | 165 | 166 | # def save_concat_res(out_path, pred_disp, img, gt_depth): 167 | # ''' 168 | # input shape: h,w,c 169 | # output concatenated img-gt-depth 170 | # ''' 171 | # h,w,_ = gt_depth.shape 172 | 173 | # if img.shape[1]!=gt_depth.shape[1]: 174 | # img = jax.image.resize(img, (gt_depth.shape[0],gt_depth.shape[1], 3),"bilinear") 175 | # pred_disp = jax.image.resize(pred_disp, (gt_depth.shape[0],gt_depth.shape[1], 1),"bilinear") 176 | 177 | # pred_disp = np.array(pred_disp) 178 | # img = np.array(img) 179 | # gt_depth = np.array(gt_depth).squeeze(-1) 180 | 181 | # kernel = np.ones((5, 5), np.uint8) 182 | # gt_depth = cv2.dilate(gt_depth, kernel, iterations=1) 183 | 184 | # gt_depth = np.expand_dims(gt_depth,-1) 185 | 186 | # # get pil outputs 187 | # _, pred_depth = disp_to_depth_metric(pred_disp,min_depth=0.001,max_depth=80.0) 188 | 189 | # error_map, error_val = get_error_map_value(pred_depth, gt_depth) 190 | 191 | # error_pil = numpy_intensitymap_to_pcolor(error_map, vmin=0, vmax=0.5,colormap='jet') 192 | # pred_pil = numpy_intensitymap_to_pcolor(pred_depth) 193 | # gt_pil = numpy_intensitymap_to_pcolor(gt_depth) 194 | # img_pil = numpy_rgb_to_pil(img) 195 | 196 | 197 | # save_id = len(os.listdir(out_path)) 198 | # save_name = os.path.join(out_path,"{}.png".format(save_id)) 199 | 200 | # dst = Image.new('RGB', (w, h*4)) 201 | # dst.paste(img_pil, (0, 0)) 202 | # dst.paste(pred_pil, (0, h)) 203 | # dst.paste(gt_pil, (0, 2*h)) 204 | # dst.paste(error_pil, (0, 3*h)) 205 | 206 | 207 | 208 | # dst.save(save_name) 209 | 210 | 211 | def directly_save_intensitymap(input, out_path, save_id=None, post_fix="error"): 212 | ''' 213 | input shape: h,w 214 | ''' 215 | im = Image.fromarray(input) 216 | im.save(os.path.join(out_path, "{:06}_{}.png".format(save_id,post_fix))) 217 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pathlib 4 | from pathlib import Path 5 | 6 | import numpy as np 7 
| 8 | import data_loader.data_loaders as module_data 9 | import model.loss as module_loss 10 | import model.metric as module_metric 11 | import model.model as module_arch 12 | from evaluater import Evaluater 13 | from utils.parse_config import ConfigParser 14 | import torch 15 | 16 | 17 | torch.backends.cuda.matmul.allow_tf32 = False 18 | torch.backends.cudnn.allow_tf32 = False 19 | torch.backends.cudnn.benchmark = True 20 | 21 | 22 | def main(config: ConfigParser): 23 | logger = config.get_logger('train') 24 | 25 | # setup data_loader instances 26 | data_loader = config.initialize('data_loader', module_data) 27 | 28 | # get function handles of loss and metrics 29 | loss = getattr(module_loss, config['loss']) 30 | metrics = [getattr(module_metric, met) for met in config['metrics']] 31 | 32 | # build model architecture, then print to console 33 | 34 | if "arch" in config.config: 35 | models = [config.initialize('arch', module_arch)] 36 | else: 37 | models = config.initialize_list("models", module_arch) 38 | 39 | results = [] 40 | 41 | for i, model in enumerate(models): 42 | model_dict = dict(model.__dict__) 43 | keys = list(model_dict.keys()) 44 | for k in keys: 45 | if k.startswith("_"): 46 | model_dict.__delitem__(k) 47 | elif type(model_dict[k]) == np.ndarray: 48 | model_dict[k] = list(model_dict[k]) 49 | 50 | 51 | dataset_dict = dict(data_loader.dataset.__dict__) 52 | keys = list(dataset_dict.keys()) 53 | for k in keys: 54 | if k.startswith("_"): 55 | dataset_dict.__delitem__(k) 56 | elif type(dataset_dict[k]) == np.ndarray: 57 | dataset_dict[k] = list(dataset_dict[k]) 58 | elif isinstance(dataset_dict[k], pathlib.PurePath): 59 | dataset_dict[k] = str(dataset_dict[k]) 60 | 61 | 62 | logger.info(model_dict) 63 | logger.info(dataset_dict) 64 | 65 | logger.info(f"{sum(p.numel() for p in model.parameters())} total parameters") 66 | 67 | evaluater = Evaluater(model, loss, metrics, config=config, data_loader=data_loader) 68 | result = evaluater.eval(i) 69 | result["metrics"] = result["metrics"] 70 | del model 71 | result["metrics_info"] = [metric.__name__ for metric in metrics] 72 | logger.info(result) 73 | results.append({ 74 | "model": model_dict, 75 | "dataset": dataset_dict, 76 | "result": result 77 | }) 78 | 79 | save_file = Path(config.log_dir) / "results.json" 80 | with open(save_file, "w") as f: 81 | json.dump(results, f, indent=4) 82 | logger.info("Finished") 83 | 84 | 85 | if __name__ == "__main__": 86 | args = argparse.ArgumentParser(description='Deeptam Evaluation') 87 | args.add_argument('-c', '--config', default=None, type=str, 88 | help='config file path (default: None)') 89 | args.add_argument('-d', '--device', default=None, type=str, 90 | help='indices of GPUs to enable (default: all)') 91 | args.add_argument('-r', '--resume', default=None, type=str, 92 | help='path to latest checkpoint (default: None)') 93 | config = ConfigParser(args) 94 | print(config.config) 95 | main(config) 96 | -------------------------------------------------------------------------------- /evaluater/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluater import * -------------------------------------------------------------------------------- /evaluater/evaluater.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from base import BaseTrainer 7 | from utils import operator_on_dict, median_scaling 8 | 9 | 10 | class Evaluater(BaseTrainer): 
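# Usage sketch (illustrative; it mirrors how evaluate.py above drives this class):
#     evaluater = Evaluater(model, loss, metrics, config=config, data_loader=data_loader)
#     log = evaluater.eval(model_index)
# where `log` is the dictionary assembled at the end of eval() below.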
11 | """ 12 | Trainer class 13 | 14 | Note: 15 | Inherited from BaseTrainer.trainer 16 | """ 17 | def __init__(self, model, loss, metrics, config, data_loader): 18 | super().__init__(model, loss, metrics, None, config) 19 | self.config = config 20 | self.data_loader = data_loader 21 | self.log_step = config["evaluater"].get("log_step", int(np.sqrt(data_loader.batch_size))) 22 | self.model = model 23 | self.loss = loss 24 | self.metrics = metrics 25 | self.len_data = len(self.data_loader) 26 | 27 | if isinstance(loss, torch.nn.Module): 28 | self.loss.to(self.device) 29 | if len(self.device_ids) > 1: 30 | self.loss = torch.nn.DataParallel(self.loss, self.device_ids) 31 | 32 | self.roi = config["evaluater"].get("roi", None) 33 | self.alpha = config["evaluater"].get("alpha", None) 34 | self.max_distance = config["evaluater"].get("max_distance", None) 35 | self.correct_length = config["evaluater"].get("correct_length", False) 36 | self.median_scaling = config["evaluater"].get("median_scaling", False) 37 | self.eval_mono = config["evaluater"].get("eval_mono", False) 38 | 39 | def _eval_metrics(self, data_dict): 40 | acc_metrics = np.zeros(len(self.metrics)) 41 | acc_metrics_mv = np.zeros(len(self.metrics)) 42 | for i, metric in enumerate(self.metrics): 43 | if self.median_scaling: 44 | data_dict = median_scaling(data_dict) 45 | acc_metrics[i] += metric(data_dict, self.roi, self.max_distance, eval_mono=self.eval_mono) 46 | acc_metrics_mv[i] += metric(data_dict, self.roi, self.max_distance, use_cvmask=True, eval_mono = self.eval_mono) 47 | #self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i]) 48 | if np.any(np.isnan(acc_metrics)): 49 | acc_metrics = np.zeros(len(self.metrics)) 50 | valid = np.zeros(len(self.metrics)) 51 | else: 52 | valid = np.ones(len(self.metrics)) 53 | 54 | if np.any(np.isnan(acc_metrics_mv)): 55 | acc_metrics_mv = np.zeros(len(self.metrics)) 56 | valid_mv = np.zeros(len(self.metrics)) 57 | else: 58 | valid_mv = np.ones(len(self.metrics)) 59 | 60 | return acc_metrics, valid, acc_metrics_mv, valid_mv 61 | 62 | def eval(self, model_index): 63 | """ 64 | Training logic for an epoch 65 | 66 | :param model_index: Current training epoch. 67 | :return: A log that contains all information you want to save. 68 | 69 | Note: 70 | If you have additional information to record, for example: 71 | > additional_log = {"x": x, "y": y} 72 | merge it with log before return. i.e. 73 | > log = {**log, **additional_log} 74 | > return log 75 | 76 | The metrics in log must have the key 'metrics'. 
77 | """ 78 | self.model.eval() 79 | 80 | total_loss = 0 81 | total_loss_dict = {} 82 | total_metrics = np.zeros(len(self.metrics)) 83 | total_metrics_valid = np.zeros(len(self.metrics)) 84 | 85 | total_metrics_mv = np.zeros(len(self.metrics)) 86 | total_metrics_valid_mv = np.zeros(len(self.metrics)) 87 | 88 | total_metrics_runningavg = np.zeros(len(self.metrics)) 89 | num_samples = 0 90 | 91 | for batch_idx, (data, target) in enumerate(self.data_loader): 92 | data, target = to(data, self.device), to(target, self.device) 93 | data["target"] = target 94 | 95 | with torch.no_grad(): 96 | data = self.model(data) 97 | loss_dict = {"loss": torch.tensor([0])} 98 | loss = loss_dict["loss"] 99 | 100 | output = data["result"] 101 | 102 | #self.writer.set_step((model_index - 1) * self.len_data + batch_idx) 103 | #self.writer.add_scalar('loss', loss.item()) 104 | total_loss += loss.item() 105 | total_loss_dict = operator_on_dict(total_loss_dict, loss_dict, operator.add) 106 | metrics, valid, metrics_mv, valid_mv = self._eval_metrics(data) 107 | total_metrics += metrics 108 | total_metrics_valid += valid 109 | 110 | total_metrics_mv += metrics_mv 111 | total_metrics_valid_mv += valid_mv 112 | 113 | batch_size = target.shape[0] 114 | if num_samples == 0: 115 | total_metrics_runningavg += metrics 116 | else: 117 | total_metrics_runningavg = total_metrics_runningavg * (num_samples / (num_samples + batch_size)) + \ 118 | metrics * (batch_size / (num_samples + batch_size)) 119 | num_samples += batch_size 120 | 121 | if batch_idx % self.log_step == 0: 122 | self.logger.debug(f'Evaluating {self._progress(batch_idx)} Loss: {loss.item() / (batch_idx + 1):.6f} Metrics: {list(total_metrics / (batch_idx + 1))}') 123 | #self.writer.add_image('input', make_grid(to(data["keyframe"], "cpu"), nrow=3, normalize=True)) 124 | #self.writer.add_image('output', make_grid(to(torch.clamp(1 / output, 0, 100), "cpu") , nrow=3, normalize=True)) 125 | #self.writer.add_image('ground_truth', make_grid(to(torch.clamp(1 / target, 0, 100), "cpu"), nrow=3, normalize=True)) 126 | 127 | if batch_idx == self.len_data: 128 | break 129 | 130 | log = { 131 | 'loss': total_loss / self.len_data, 132 | 'metrics': self.save_digits((total_metrics / total_metrics_valid).tolist()), 133 | 'metrics_mv': self.save_digits((total_metrics_mv / total_metrics_valid_mv).tolist()), 134 | 'metrics_correct': self.save_digits(total_metrics_runningavg.tolist()), 135 | 'valid_batches': total_metrics_valid[0], 136 | 'valid_batches_mv': total_metrics_valid_mv[0] 137 | } 138 | for loss_component, v in total_loss_dict.items(): 139 | log[f"loss_{loss_component}"] = v.item() / self.len_data 140 | 141 | return log 142 | 143 | def save_digits(self, input_list): 144 | return [float('{:.3f}'.format(i)) for i in input_list] 145 | 146 | def _progress(self, batch_idx): 147 | base = '[{}/{} ({:.0f}%)]' 148 | if hasattr(self.data_loader, 'n_samples'): 149 | current = batch_idx * self.data_loader.batch_size 150 | total = self.data_loader.n_samples 151 | else: 152 | current = batch_idx 153 | total = self.len_data 154 | return base.format(current, total, 100.0 * current / total) 155 | 156 | 157 | def to(data, device): 158 | if isinstance(data, dict): 159 | return {k: to(data[k], device) for k in data.keys()} 160 | elif isinstance(data, list): 161 | return [to(v, device) for v in data] 162 | else: 163 | return data.to(device) 164 | -------------------------------------------------------------------------------- /logger/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from utils import Timer 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \ 27 | "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \ 28 | "the 'config.json' file." 
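# Note (descriptive): when neither backend could be imported, self.writer stays None and the
# add_* methods returned by __getattr__ below degrade to no-ops, so calling code does not
# need to guard its logging calls.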
29 | logger.warning(message) 30 | 31 | self.step = 0 32 | self.mode = '' 33 | 34 | self.tb_writer_ftns = { 35 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 36 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 37 | } 38 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 39 | 40 | self.timer = Timer() 41 | 42 | def set_step(self, step, mode='train'): 43 | self.mode = mode 44 | self.step = step 45 | if step == 0: 46 | self.timer.reset() 47 | else: 48 | duration = self.timer.check() 49 | self.add_scalar('steps_per_sec', 1 / duration) 50 | 51 | def __getattr__(self, name): 52 | """ 53 | If visualization is configured to use: 54 | return add_data() methods of tensorboard with additional information (step, tag) added. 55 | Otherwise: 56 | return a blank function handle that does nothing 57 | """ 58 | if name in self.tb_writer_ftns: 59 | add_data = getattr(self.writer, name, None) 60 | 61 | def wrapper(tag, data, *args, **kwargs): 62 | if add_data is not None: 63 | # add mode(train/valid) tag 64 | if name not in self.tag_mode_exceptions: 65 | tag = '{}/{}'.format(tag, self.mode) 66 | add_data(tag, data, self.step, *args, **kwargs) 67 | return wrapper 68 | else: 69 | # default action for returning methods defined in this class, set_step() for instance. 70 | try: 71 | attr = object.__getattr__(name) 72 | except AttributeError: 73 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 74 | return attr 75 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/model/__init__.py -------------------------------------------------------------------------------- /model/dymultidepth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/model/dymultidepth/__init__.py -------------------------------------------------------------------------------- /model/dymultidepth/ccf_modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torchvision 5 | from torch import nn 6 | 7 | 8 | class CrossCueFusion(nn.Module): 9 | def __init__(self, cv_hypo_num=32, mid_dim=32, input_size=(256,512)): 10 | super().__init__() 11 | self.cv_hypo_num = cv_hypo_num 12 | self.mid_dim = mid_dim 13 | self.residual_connection =True 14 | self.is_reduce = True if input_size[1]>650 else False 15 | 16 | if not self.is_reduce: 17 | self.mono_expand = nn.Sequential( 18 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 19 | nn.BatchNorm2d(self.mid_dim), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 22 | nn.BatchNorm2d(self.mid_dim), 23 | nn.ReLU(inplace=True) 24 | ) 25 | self.multi_expand = nn.Sequential( 26 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 27 | nn.BatchNorm2d(self.mid_dim), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 30 | nn.BatchNorm2d(self.mid_dim), 31 | nn.ReLU(inplace=True) 32 | ) 33 | else: 34 | self.mono_expand = nn.Sequential( 35 | 
nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 36 | nn.BatchNorm2d(self.mid_dim), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 39 | nn.BatchNorm2d(self.mid_dim), 40 | nn.ReLU(inplace=True), 41 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 42 | nn.BatchNorm2d(self.mid_dim), 43 | nn.ReLU(inplace=True) 44 | ) 45 | 46 | self.multi_expand = nn.Sequential( 47 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 48 | nn.BatchNorm2d(self.mid_dim), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 51 | nn.BatchNorm2d(self.mid_dim), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 54 | nn.BatchNorm2d(self.mid_dim), 55 | nn.ReLU(inplace=True) 56 | ) 57 | self.kq_dim = self.mid_dim //4 if self.mid_dim>128 else self.mid_dim 58 | 59 | self.lin_mono_k = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1) 60 | self.lin_mono_q = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1) 61 | self.lin_mono_v = nn.Conv2d(self.mid_dim, self.mid_dim, kernel_size=1) 62 | 63 | self.lin_multi_k = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1) 64 | self.lin_multi_q = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1) 65 | self.lin_multi_v = nn.Conv2d(self.mid_dim, self.mid_dim, kernel_size=1) 66 | 67 | self.softmax = nn.Softmax(dim=-1) 68 | 69 | if self.residual_connection: 70 | self.mono_reg = nn.Sequential( 71 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1), 72 | nn.BatchNorm2d(self.mid_dim), 73 | nn.ReLU(inplace=True), 74 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1), 75 | nn.BatchNorm2d(self.mid_dim), 76 | nn.ReLU(inplace=True) 77 | ) 78 | self.multi_reg = nn.Sequential( 79 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=1, padding=0), 80 | nn.BatchNorm2d(self.mid_dim), 81 | nn.ReLU(inplace=True) 82 | ) 83 | self.gamma = nn.Parameter(torch.zeros(1)) 84 | 85 | def forward(self, mono_pseudo_cost, cost_volume): 86 | init_b, init_c, init_h, init_w = cost_volume.shape 87 | mono_feat = self.mono_expand(mono_pseudo_cost) 88 | multi_feat = self.multi_expand(cost_volume) 89 | b,c,h,w = multi_feat.shape 90 | 91 | # cross-cue attention 92 | mono_q = self.lin_mono_q(mono_feat).view(b,-1,h*w).permute(0,2,1) 93 | mono_k = self.lin_mono_k(mono_feat).view(b,-1,h*w) 94 | mono_score = torch.bmm(mono_q, mono_k) 95 | mono_atten = self.softmax(mono_score) 96 | 97 | multi_q = self.lin_multi_q(multi_feat).view(b,-1,h*w).permute(0,2,1) 98 | multi_k = self.lin_multi_k(multi_feat).view(b,-1,h*w) 99 | multi_score = torch.bmm(multi_q, multi_k) 100 | multi_atten = self.softmax(multi_score) 101 | 102 | mono_v = self.lin_mono_v(mono_feat).view(b,-1,h*w) 103 | mono_out = torch.bmm(mono_v, multi_atten.permute(0,2,1)) 104 | mono_out = mono_out.view(b,self.mid_dim, h,w) 105 | 106 | multi_v = self.lin_multi_v(multi_feat).view(b,-1,h*w) 107 | multi_out = torch.bmm(multi_v, mono_atten.permute(0,2,1)) 108 | multi_out = multi_out.view(b,self.mid_dim, h,w) 109 | 110 | 111 | # concatenate and upsample 112 | fused = torch.cat((multi_out,mono_out), dim=1) 113 | fused = torch.nn.functional.interpolate(fused, size=(init_h,init_w)) 114 | 115 | if self.residual_connection: 116 | mono_residual = self.mono_reg(mono_pseudo_cost) 117 | multi_residual = self.multi_reg(cost_volume) 118 | fused_cat = 
torch.cat((mono_residual,multi_residual), dim=1) 119 | fused = fused_cat + self.gamma * fused 120 | 121 | return fused 122 | 123 | 124 | 125 | class MultiGuideMono(nn.Module): 126 | def __init__(self, cv_hypo_num=32, mid_dim=32, input_size=(256,512)): 127 | super().__init__() 128 | self.cv_hypo_num = cv_hypo_num 129 | self.mid_dim = mid_dim 130 | self.residual_connection =True 131 | self.is_reduce = True if input_size[1]>650 else False 132 | 133 | if not self.is_reduce: 134 | self.mono_expand = nn.Sequential( 135 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 136 | nn.BatchNorm2d(self.mid_dim), 137 | nn.ReLU(inplace=True), 138 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 139 | nn.BatchNorm2d(self.mid_dim), 140 | nn.ReLU(inplace=True) 141 | ) 142 | self.multi_expand = nn.Sequential( 143 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 144 | nn.BatchNorm2d(self.mid_dim), 145 | nn.ReLU(inplace=True), 146 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 147 | nn.BatchNorm2d(self.mid_dim), 148 | nn.ReLU(inplace=True) 149 | ) 150 | else: 151 | self.mono_expand = nn.Sequential( 152 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 153 | nn.BatchNorm2d(self.mid_dim), 154 | nn.ReLU(inplace=True), 155 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 156 | nn.BatchNorm2d(self.mid_dim), 157 | nn.ReLU(inplace=True), 158 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 159 | nn.BatchNorm2d(self.mid_dim), 160 | nn.ReLU(inplace=True) 161 | ) 162 | 163 | self.multi_expand = nn.Sequential( 164 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 165 | nn.BatchNorm2d(self.mid_dim), 166 | nn.ReLU(inplace=True), 167 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 168 | nn.BatchNorm2d(self.mid_dim), 169 | nn.ReLU(inplace=True), 170 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 171 | nn.BatchNorm2d(self.mid_dim), 172 | nn.ReLU(inplace=True) 173 | ) 174 | self.kq_dim = self.mid_dim //4 if self.mid_dim>128 else self.mid_dim 175 | 176 | self.lin_mono_v = nn.Conv2d(self.mid_dim, self.mid_dim, kernel_size=1) 177 | 178 | self.lin_multi_k = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1) 179 | self.lin_multi_q = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1) 180 | 181 | self.softmax = nn.Softmax(dim=-1) 182 | 183 | if self.residual_connection: 184 | self.mono_reg = nn.Sequential( 185 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1), 186 | nn.BatchNorm2d(self.mid_dim), 187 | nn.ReLU(inplace=True), 188 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1), 189 | nn.BatchNorm2d(self.mid_dim), 190 | nn.ReLU(inplace=True) 191 | ) 192 | self.gamma = nn.Parameter(torch.zeros(1)) 193 | 194 | def forward(self, mono_pseudo_cost, cost_volume): 195 | init_b, init_c, init_h, init_w = cost_volume.shape 196 | mono_feat = self.mono_expand(mono_pseudo_cost) 197 | multi_feat = self.multi_expand(cost_volume) 198 | b,c,h,w = multi_feat.shape 199 | 200 | # multi attention 201 | 202 | multi_q = self.lin_multi_q(multi_feat).view(b,-1,h*w).permute(0,2,1) 203 | multi_k = self.lin_multi_k(multi_feat).view(b,-1,h*w) 204 | multi_score = torch.bmm(multi_q, multi_k) 205 | multi_atten = self.softmax(multi_score) 206 | 207 | mono_v = self.lin_mono_v(mono_feat).view(b,-1,h*w) 208 | 
mono_out = torch.bmm(mono_v, multi_atten.permute(0,2,1)) 209 | mono_out = mono_out.view(b,self.mid_dim, h,w) 210 | 211 | # upsample 212 | fused = torch.nn.functional.interpolate(mono_out, size=(init_h,init_w)) 213 | 214 | if self.residual_connection: 215 | mono_residual = self.mono_reg(mono_pseudo_cost) 216 | fused = mono_residual + self.gamma * fused 217 | 218 | return fused 219 | 220 | 221 | class MonoGuideMulti(nn.Module): 222 | def __init__(self, cv_hypo_num=32, mid_dim=32, input_size=(256,512)): 223 | super().__init__() 224 | self.cv_hypo_num = cv_hypo_num 225 | self.mid_dim = mid_dim 226 | self.residual_connection =True 227 | self.is_reduce = True if input_size[1]>650 else False 228 | 229 | if not self.is_reduce: 230 | self.mono_expand = nn.Sequential( 231 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 232 | nn.BatchNorm2d(self.mid_dim), 233 | nn.ReLU(inplace=True), 234 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 235 | nn.BatchNorm2d(self.mid_dim), 236 | nn.ReLU(inplace=True) 237 | ) 238 | self.multi_expand = nn.Sequential( 239 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 240 | nn.BatchNorm2d(self.mid_dim), 241 | nn.ReLU(inplace=True), 242 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 243 | nn.BatchNorm2d(self.mid_dim), 244 | nn.ReLU(inplace=True) 245 | ) 246 | else: 247 | self.mono_expand = nn.Sequential( 248 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 249 | nn.BatchNorm2d(self.mid_dim), 250 | nn.ReLU(inplace=True), 251 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 252 | nn.BatchNorm2d(self.mid_dim), 253 | nn.ReLU(inplace=True), 254 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 255 | nn.BatchNorm2d(self.mid_dim), 256 | nn.ReLU(inplace=True) 257 | ) 258 | 259 | self.multi_expand = nn.Sequential( 260 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 261 | nn.BatchNorm2d(self.mid_dim), 262 | nn.ReLU(inplace=True), 263 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 264 | nn.BatchNorm2d(self.mid_dim), 265 | nn.ReLU(inplace=True), 266 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=3, padding=1, stride=2), 267 | nn.BatchNorm2d(self.mid_dim), 268 | nn.ReLU(inplace=True) 269 | ) 270 | self.kq_dim = self.mid_dim //4 if self.mid_dim>128 else self.mid_dim 271 | 272 | self.lin_mono_k = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1) 273 | self.lin_mono_q = nn.Conv2d(self.mid_dim, self.kq_dim, kernel_size=1) 274 | 275 | self.lin_multi_v = nn.Conv2d(self.mid_dim, self.mid_dim, kernel_size=1) 276 | 277 | self.softmax = nn.Softmax(dim=-1) 278 | 279 | if self.residual_connection: 280 | self.multi_reg = nn.Sequential( 281 | nn.Conv2d(self.cv_hypo_num, self.mid_dim, kernel_size=1, padding=0), 282 | nn.BatchNorm2d(self.mid_dim), 283 | nn.ReLU(inplace=True) 284 | ) 285 | self.gamma = nn.Parameter(torch.zeros(1)) 286 | 287 | def forward(self, mono_pseudo_cost, cost_volume): 288 | init_b, init_c, init_h, init_w = cost_volume.shape 289 | mono_feat = self.mono_expand(mono_pseudo_cost) 290 | multi_feat = self.multi_expand(cost_volume) 291 | b,c,h,w = multi_feat.shape 292 | 293 | # mono attention 294 | mono_q = self.lin_mono_q(mono_feat).view(b,-1,h*w).permute(0,2,1) 295 | mono_k = self.lin_mono_k(mono_feat).view(b,-1,h*w) 296 | mono_score = torch.bmm(mono_q, mono_k) 297 | mono_atten = 
self.softmax(mono_score) 298 | 299 | multi_v = self.lin_multi_v(multi_feat).view(b,-1,h*w) 300 | multi_out = torch.bmm(multi_v, mono_atten.permute(0,2,1)) 301 | multi_out = multi_out.view(b,self.mid_dim, h,w) 302 | 303 | 304 | # upsample 305 | fused = torch.nn.functional.interpolate(multi_out, size=(init_h,init_w)) 306 | 307 | if self.residual_connection: 308 | multi_residual = self.multi_reg(cost_volume) 309 | fused = multi_residual + self.gamma * fused 310 | 311 | return fused 312 | -------------------------------------------------------------------------------- /model/dymultidepth/dymultidepth_model.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import kornia.augmentation as K 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision 8 | from torch import nn 9 | 10 | from model.layers import point_projection, PadSameConv2d, ConvReLU2, ConvReLU, Upconv, Refine, SSIM, Backprojection 11 | from utils import conditional_flip, filter_state_dict 12 | 13 | from utils import parse_config 14 | from .base_models import DepthAugmentation, MaskAugmentation, CostVolumeModule, DepthModule, ResnetEncoder,MonoDepthModule, EfficientNetEncoder 15 | from .ccf_modules import * 16 | 17 | class DyMultiDepthModel(nn.Module): 18 | def __init__(self, inv_depth_min_max=(0.33, 0.0025), cv_depth_steps=32, pretrain_mode=False, pretrain_dropout=0.0, pretrain_dropout_mode=0, 19 | augmentation=None, use_mono=True, use_stereo=False, use_ssim=True, sfcv_mult_mask=True, 20 | simple_mask=False, mask_use_cv=True, mask_use_feats=True, cv_patch_size=3, depth_large_model=False, no_cv=False, 21 | freeze_backbone=True, freeze_module=(), checkpoint_location=None, mask_cp_loc=None, depth_cp_loc=None, 22 | fusion_type = 'ccf_fusion', input_size=[256, 512], ccf_mid_dim=32, use_img_in_depthnet=True, 23 | backbone_type='resnet18'): 24 | """ 25 | :param inv_depth_min_max: Min / max (inverse) depth. (Default=(0.33, 0.0025)) 26 | :param cv_depth_steps: Number of depth steps for the cost volume. (Default=32) 27 | :param pretrain_mode: Which pretrain mode to use: 28 | 0 / False: Run full network. 29 | 1 / True: Only run depth module. In this mode, dropout can be activated to zero out patches from the 30 | unmasked cost volume. Dropout was not used for the paper. 31 | 2: Only run mask module. In this mode, the network will return the mask as the main result. 32 | 3: Only run depth module, but use the auxiliary masks to mask the cost volume. This mode was not used in 33 | the paper. (Default=0) 34 | :param pretrain_dropout: Dropout rate used in pretrain_mode=1. (Default=0) 35 | :param augmentation: Which augmentation module to use. "mask"=MaskAugmentation, "depth"=DepthAugmentation. The 36 | exact way to use this is very context dependent. Refer to the training scripts for more details. (Default="none") 37 | :param use_mono: Use monocular frames during the forward pass. (Default=True) 38 | :param use_stereo: Use stereo frame during the forward pass. (Default=False) 39 | :param use_ssim: Use SSIM during cost volume computation. (Default=True) 40 | :param sfcv_mult_mask: For the single frame cost volumes: If a pixel does not have a valid reprojection at any 41 | depth step, all depths get invalidated. (Default=True) 42 | :param simple_mask: Use the standard cost volume instead of multiple single frame cost volumes in the mask 43 | module. (Default=False) 44 | :param cv_patch_size: Patchsize, over which the ssim errors get averaged. 
(Default=3) 45 | :param freeze_module: Freeze given string list of modules. (Default=()) 46 | :param checkpoint_location: Load given list of checkpoints. (Default=None) 47 | :param mask_cp_loc: Load list of checkpoints for the mask module. (Default=None) 48 | :param depth_cp_loc: Load list of checkpoints for the depth module. (Default=None) 49 | """ 50 | super().__init__() 51 | self.inv_depth_min_max = inv_depth_min_max 52 | self.cv_depth_steps = cv_depth_steps 53 | self.use_mono = use_mono 54 | self.use_stereo = use_stereo 55 | self.use_ssim = use_ssim 56 | self.sfcv_mult_mask = sfcv_mult_mask 57 | self.pretrain_mode = int(pretrain_mode) 58 | self.pretrain_dropout = pretrain_dropout 59 | self.pretrain_dropout_mode = pretrain_dropout_mode 60 | self.augmentation = augmentation 61 | self.simple_mask = simple_mask 62 | self.mask_use_cv = mask_use_cv 63 | self.mask_use_feats = mask_use_feats 64 | self.cv_patch_size = cv_patch_size 65 | self.no_cv = no_cv 66 | self.depth_large_model = depth_large_model 67 | self.checkpoint_location = checkpoint_location 68 | self.mask_cp_loc = mask_cp_loc 69 | self.depth_cp_loc = depth_cp_loc 70 | self.freeze_module = freeze_module 71 | self.freeze_backbone = freeze_backbone 72 | 73 | self.fusion_type = fusion_type 74 | self.input_size = input_size 75 | self.ccf_mid_dim = ccf_mid_dim 76 | self.use_img_in_depthnet = use_img_in_depthnet 77 | self.backbone_type = backbone_type 78 | 79 | assert self.backbone_type in ["resnet18", "efficientnetb5"] 80 | 81 | self.depthmodule_in_chn = self.cv_depth_steps 82 | if fusion_type == 'ccf_fusion': 83 | self.extra_input_dim = 0 84 | self.fusion_module = CrossCueFusion(cv_hypo_num=self.cv_depth_steps, mid_dim=32, input_size=self.input_size) 85 | self.depthmodule_in_chn = self.ccf_mid_dim * 2 86 | elif fusion_type == 'mono_guide_multi': 87 | self.extra_input_dim = 0 88 | self.fusion_module = MonoGuideMulti(cv_hypo_num=self.cv_depth_steps, mid_dim=32, input_size=self.input_size) 89 | self.depthmodule_in_chn = self.ccf_mid_dim 90 | elif fusion_type == 'multi_guide_mono': 91 | self.extra_input_dim = 0 92 | self.fusion_module = MultiGuideMono(cv_hypo_num=self.cv_depth_steps, mid_dim=32, input_size=self.input_size) 93 | self.depthmodule_in_chn = self.ccf_mid_dim 94 | 95 | if self.backbone_type == 'resnet18': 96 | self._feature_extractor = ResnetEncoder(num_layers=18, pretrained=True) 97 | elif self.backbone_type == 'efficientnetb5': 98 | self._feature_extractor = EfficientNetEncoder(pretrained=True) 99 | 100 | if self.freeze_backbone: 101 | for p in self._feature_extractor.parameters(True): 102 | p.requires_grad_(False) 103 | 104 | self.cv_module = CostVolumeModule(use_mono=use_mono, use_stereo=use_stereo, use_ssim=use_ssim, sfcv_mult_mask=self.sfcv_mult_mask, patch_size=cv_patch_size) 105 | 106 | self.depth_module = DepthModule(self.depthmodule_in_chn, feature_channels=self._feature_extractor.num_ch_enc, 107 | large_model=self.depth_large_model, use_input_img=self.use_img_in_depthnet) 108 | self.mono_module = MonoDepthModule(extra_input_dim=self.extra_input_dim, 109 | feature_channels=self._feature_extractor.num_ch_enc, large_model=self.depth_large_model) 110 | 111 | if self.checkpoint_location is not None: 112 | if not isinstance(checkpoint_location, list): 113 | checkpoint_location = [checkpoint_location] 114 | for cp in checkpoint_location: 115 | checkpoint = torch.load(cp, map_location=torch.device("cpu")) 116 | checkpoint_state_dict = checkpoint["state_dict"] 117 | checkpoint_state_dict = 
filter_state_dict(checkpoint_state_dict, checkpoint["arch"] == "DataParallel") 118 | self.load_state_dict(checkpoint_state_dict, strict=True) 119 | 120 | 121 | for module_name in self.freeze_module: 122 | module = self.__getattr__(module_name + "_module") 123 | module.eval() 124 | for param in module.parameters(True): 125 | param.requires_grad_(False) 126 | 127 | if self.augmentation == "depth": 128 | self.augmenter = DepthAugmentation() 129 | elif self.augmentation == "mask": 130 | self.augmenter = MaskAugmentation() 131 | else: 132 | self.augmenter = None 133 | 134 | def forward(self, data_dict): 135 | keyframe = data_dict["keyframe"] 136 | 137 | data_dict["inv_depth_min"] = keyframe.new_tensor([self.inv_depth_min_max[0]]) 138 | data_dict["inv_depth_max"] = keyframe.new_tensor([self.inv_depth_min_max[1]]) 139 | data_dict["cv_depth_steps"] = keyframe.new_tensor([self.cv_depth_steps], dtype=torch.int32) 140 | 141 | with torch.no_grad(): 142 | data_dict = self.cv_module(data_dict) 143 | 144 | if self.augmenter is not None and self.training: 145 | self.augmenter(data_dict) 146 | 147 | # different with MonoRec: the input image should be the reverted 148 | data_dict["image_features"] = self._feature_extractor(data_dict["keyframe"] + .5) 149 | 150 | data_dict["cost_volume_init"] = data_dict["cost_volume"] 151 | 152 | data_dict = self.mono_module(data_dict) 153 | data_dict["predicted_inverse_depths_mono"] = [(1-pred) * self.inv_depth_min_max[1] + pred * self.inv_depth_min_max[0] 154 | for pred in data_dict["predicted_inverse_depths_mono"]] 155 | mono_depth_pred = torch.clamp(1.0 / data_dict["predicted_inverse_depths_mono"][0], min=1e-3, max=80.0).detach() 156 | 157 | b, c, h, w = keyframe.shape 158 | 159 | 160 | pseudo_mono_cost = self.pseudocost_from_mono(mono_depth_pred, 161 | depth_hypothesis = data_dict["cv_bin_steps"].view(1, -1, 1, 1).expand(b, -1, h, w).detach()).detach() 162 | 163 | 164 | if self.training: 165 | if self.pretrain_dropout_mode == 0: 166 | cv_mask = keyframe.new_ones(b, 1, h // 8, w // 8, requires_grad=False) 167 | F.dropout(cv_mask, p=1 - self.pretrain_dropout, training=self.training, inplace=True) 168 | cv_mask = (cv_mask!=0).float() 169 | cv_mask = F.upsample(cv_mask, (h, w)) 170 | else: 171 | cv_mask = keyframe.new_ones(b, 1, 1, 1, requires_grad=False) 172 | F.dropout(cv_mask, p = 1 - self.pretrain_dropout, training=self.training, inplace=True) 173 | cv_mask = cv_mask.expand(-1, -1, h, w) 174 | else: 175 | cv_mask = keyframe.new_zeros(b, 1, h, w, requires_grad=False) 176 | data_dict["cv_mask"] = cv_mask 177 | 178 | 179 | data_dict["cost_volume"] = (1 - data_dict["cv_mask"]) * self.fusion_module(pseudo_mono_cost, data_dict["cost_volume"]) 180 | 181 | data_dict = self.depth_module(data_dict) 182 | 183 | data_dict["predicted_inverse_depths"] = [(1-pred) * self.inv_depth_min_max[1] + pred * self.inv_depth_min_max[0] 184 | for pred in data_dict["predicted_inverse_depths"]] 185 | 186 | if self.augmenter is not None and self.training: 187 | self.augmenter.revert(data_dict) 188 | 189 | data_dict["result"] = data_dict["predicted_inverse_depths"][0] 190 | data_dict["result_mono"] = data_dict["predicted_inverse_depths_mono"][0] 191 | data_dict["mask"] = data_dict["cv_mask"] 192 | 193 | return data_dict 194 | 195 | 196 | def pseudocost_from_mono(self, monodepth, depth_hypothesis): 197 | abs_depth_diff = torch.abs(monodepth - depth_hypothesis) 198 | # find the closest depth bin that the monodepth correlate with 199 | min_diff_index = torch.argmin(abs_depth_diff, dim=1, keepdim=True) 
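# Illustrative example (hypothetical values): with depth hypotheses [1, 2, 4, 8] along dim=1
# and a monocular prediction of 3.2 at some pixel, the 4-bin has the smallest absolute
# difference, so the scatter_ below turns the all-zero volume into the one-hot pseudo cost
# [0, 0, 1, 0] at that pixel.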
200 | pseudo_cost = depth_hypothesis.new_zeros(depth_hypothesis.shape) 201 | ones = depth_hypothesis.new_ones(depth_hypothesis.shape) 202 | 203 | pseudo_cost.scatter_(dim = 1, index = min_diff_index, src = ones) 204 | 205 | return pseudo_cost 206 | 207 | def find_mincost_depth(self, cost_volume, depth_hypos): 208 | argmax = torch.argmax(cost_volume, dim=1, keepdim=True) 209 | mincost_depth = torch.gather(input=depth_hypos, dim=1, index=argmax) 210 | return mincost_depth -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F, Conv2d, LeakyReLU, Upsample, Sigmoid, ConvTranspose2d 9 | 10 | 11 | class ConvBlock(nn.Module): 12 | """Layer to perform a convolution followed by ELU 13 | """ 14 | def __init__(self, in_channels, out_channels): 15 | super(ConvBlock, self).__init__() 16 | 17 | self.conv = Conv3x3(in_channels, out_channels) 18 | self.nonlin = nn.ELU(inplace=True) 19 | 20 | def forward(self, x): 21 | out = self.conv(x) 22 | out = self.nonlin(out) 23 | return out 24 | 25 | 26 | class Conv3x3(nn.Module): 27 | """Layer to pad and convolve input 28 | """ 29 | def __init__(self, in_channels, out_channels, use_refl=True): 30 | super(Conv3x3, self).__init__() 31 | 32 | if use_refl: 33 | self.pad = nn.ReflectionPad2d(1) 34 | else: 35 | self.pad = nn.ZeroPad2d(1) 36 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 37 | 38 | def forward(self, x): 39 | out = self.pad(x) 40 | out = self.conv(out) 41 | return out 42 | 43 | class Backprojection(nn.Module): 44 | def __init__(self, batch_size, height, width): 45 | super(Backprojection, self).__init__() 46 | 47 | self.N, self.H, self.W = batch_size, height, width 48 | 49 | yy, xx = torch.meshgrid([torch.arange(0., float(self.H)), torch.arange(0., float(self.W))]) 50 | yy = yy.contiguous().view(-1) 51 | xx = xx.contiguous().view(-1) 52 | self.ones = nn.Parameter(torch.ones(self.N, 1, self.H * self.W), requires_grad=False) 53 | self.coord = torch.unsqueeze(torch.stack([xx, yy], 0), 0).repeat(self.N, 1, 1) 54 | self.coord = nn.Parameter(torch.cat([self.coord, self.ones], 1), requires_grad=False) 55 | 56 | def forward(self, depth, inv_K) : 57 | cam_p_norm = torch.matmul(inv_K[:, :3, :3], self.coord[:depth.shape[0], :, :]) 58 | cam_p_euc = depth.view(depth.shape[0], 1, -1) * cam_p_norm 59 | cam_p_h = torch.cat([cam_p_euc, self.ones[:depth.shape[0], :, :]], 1) 60 | 61 | return cam_p_h 62 | 63 | def point_projection(points3D, batch_size, height, width, K, T): 64 | N, H, W = batch_size, height, width 65 | cam_coord = torch.matmul(torch.matmul(K, T)[:, :3, :], points3D) 66 | img_coord = cam_coord[:, :2, :] / (cam_coord[:, 2:3, :] + 1e-7) 67 | img_coord[:, 0, :] /= W - 1 68 | img_coord[:, 1, :] /= H - 1 69 | img_coord = (img_coord - 0.5) * 2 70 | img_coord = img_coord.view(N, 2, H, W).permute(0, 2, 3, 1) 71 | return img_coord 72 | 73 | def upsample(x): 74 | """Upsample input tensor by a factor of 2 75 | """ 76 | return F.interpolate(x, scale_factor=2, mode="nearest") 77 | 78 | 79 | class GaussianAverage(nn.Module): 80 | def __init__(self) -> None: 81 | super().__init__() 82 | self.window = torch.Tensor([ 83 | [0.0947, 0.1183, 0.0947], 84 | [0.1183, 0.1478, 0.1183], 85 | [0.0947, 0.1183, 0.0947]]) 86 | 87 | def forward(self, x): 88 
| kernel = self.window.to(x.device).to(x.dtype).repeat(x.shape[1], 1, 1, 1) 89 | return F.conv2d(x, kernel, padding=0, groups=x.shape[1]) 90 | 91 | class SSIM(nn.Module): 92 | """Layer to compute the SSIM loss between a pair of images 93 | """ 94 | def __init__(self, pad_reflection=True, gaussian_average=False, comp_mode=False): 95 | super(SSIM, self).__init__() 96 | self.comp_mode = comp_mode 97 | 98 | if not gaussian_average: 99 | self.mu_x_pool = nn.AvgPool2d(3, 1) 100 | self.mu_y_pool = nn.AvgPool2d(3, 1) 101 | self.sig_x_pool = nn.AvgPool2d(3, 1) 102 | self.sig_y_pool = nn.AvgPool2d(3, 1) 103 | self.sig_xy_pool = nn.AvgPool2d(3, 1) 104 | else: 105 | self.mu_x_pool = GaussianAverage() 106 | self.mu_y_pool = GaussianAverage() 107 | self.sig_x_pool = GaussianAverage() 108 | self.sig_y_pool = GaussianAverage() 109 | self.sig_xy_pool = GaussianAverage() 110 | 111 | if pad_reflection: 112 | self.pad = nn.ReflectionPad2d(1) 113 | else: 114 | self.pad = nn.ZeroPad2d(1) 115 | 116 | self.C1 = 0.01 ** 2 117 | self.C2 = 0.03 ** 2 118 | 119 | def forward(self, x, y): 120 | x = self.pad(x) 121 | y = self.pad(y) 122 | 123 | mu_x = self.mu_x_pool(x) 124 | mu_y = self.mu_y_pool(y) 125 | mu_x_sq = mu_x ** 2 126 | mu_y_sq = mu_y ** 2 127 | mu_x_y = mu_x * mu_y 128 | 129 | sigma_x = self.sig_x_pool(x ** 2) - mu_x_sq 130 | sigma_y = self.sig_y_pool(y ** 2) - mu_y_sq 131 | sigma_xy = self.sig_xy_pool(x * y) - mu_x_y 132 | 133 | SSIM_n = (2 * mu_x_y + self.C1) * (2 * sigma_xy + self.C2) 134 | SSIM_d = (mu_x_sq + mu_y_sq + self.C1) * (sigma_x + sigma_y + self.C2) 135 | 136 | if not self.comp_mode: 137 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) 138 | else: 139 | return torch.clamp((1 - SSIM_n / SSIM_d), 0, 1) / 2 140 | 141 | 142 | def ssim(x, y, pad_reflection=True, gaussian_average=False, comp_mode=False): 143 | ssim_ = SSIM(pad_reflection, gaussian_average, comp_mode) 144 | return ssim_(x, y) 145 | 146 | 147 | class ResidualImage(nn.Module): 148 | def __init__(self): 149 | super().__init__() 150 | self.residual_image = ResidualImageModule() 151 | 152 | def forward(self, keyframe: torch.Tensor, keyframe_pose: torch.Tensor, keyframe_intrinsics: torch.Tensor, 153 | depths: torch.Tensor, frames: list, poses: list, intrinsics: list): 154 | data_dict = {"keyframe": keyframe, "keyframe_pose": keyframe_pose, "keyframe_intrinsics": keyframe_intrinsics, 155 | "predicted_inverse_depths": [depths], "frames": frames, "poses": poses, "list": list, 156 | "intrinsics": intrinsics, "inv_depth_max": 0, "inv_depth_min": 1} 157 | data_dict = self.residual_image(data_dict) 158 | return data_dict["residual_image"] 159 | 160 | 161 | class ResidualImageModule(nn.Module): 162 | def __init__(self, use_mono=True, use_stereo=False): 163 | super().__init__() 164 | self.use_mono = use_mono 165 | self.use_stereo = use_stereo 166 | self.ssim = SSIM() 167 | 168 | def forward(self, data_dict): 169 | keyframe = data_dict["keyframe"] 170 | keyframe_intrinsics = data_dict["keyframe_intrinsics"] 171 | keyframe_pose = data_dict["keyframe_pose"] 172 | depths = (1-data_dict["predicted_inverse_depths"][0]) * data_dict["inv_depth_max"] + data_dict["predicted_inverse_depths"][0] * data_dict["inv_depth_min"] 173 | 174 | frames = [] 175 | intrinsics = [] 176 | poses = [] 177 | 178 | if self.use_mono: 179 | frames += data_dict["frames"] 180 | intrinsics += data_dict["intrinsics"] 181 | poses += data_dict["poses"] 182 | if self.use_stereo: 183 | frames += [data_dict["stereoframe"]] 184 | intrinsics += [data_dict["stereoframe_intrinsics"]] 185 
| poses += [data_dict["stereoframe_pose"]] 186 | 187 | n, c, h, w = keyframe.shape 188 | 189 | backproject_depth = Backprojection(n, h, w) 190 | backproject_depth.to(keyframe.device) 191 | 192 | inv_k = torch.inverse(keyframe_intrinsics) 193 | cam_points = (inv_k[:, :3, :3] @ backproject_depth.pix_coords) 194 | cam_points = cam_points / depths.view(n, 1, -1) 195 | cam_points = torch.cat([cam_points, backproject_depth.ones], 1) 196 | 197 | masks = [] 198 | residuals = [] 199 | 200 | for i, image in enumerate(frames): 201 | t = torch.inverse(poses[i]) @ keyframe_pose 202 | pix_coords = point_projection(cam_points, n, h, w, intrinsics[i], t) 203 | warped_image = F.grid_sample(image + 1, pix_coords) 204 | mask = torch.any(warped_image == 0, dim=1, keepdim=True) 205 | warped_image -= .5 206 | residual = self.ssim(warped_image, keyframe + .5) 207 | masks.append(mask) 208 | residuals.append(residual) 209 | 210 | masks = torch.stack(masks, dim=1) 211 | residuals = torch.stack(residuals, dim=1) 212 | residuals[masks.expand(-1, -1 , c, -1, -1)] = float("inf") 213 | 214 | residual_image = torch.min(torch.mean(residuals, dim=2, keepdim=True), dim=1)[0] 215 | residual_image[torch.min(masks, dim=1)[0]] = 0 216 | data_dict["residual_image"] = residual_image 217 | return data_dict 218 | 219 | 220 | class PadSameConv2d(torch.nn.Module): 221 | def __init__(self, kernel_size, stride=1): 222 | """ 223 | Imitates padding_mode="same" from tensorflow. 224 | :param kernel_size: Kernelsize of the convolution, int or tuple/list 225 | :param stride: Stride of the convolution, int or tuple/list 226 | """ 227 | super().__init__() 228 | if isinstance(kernel_size, (tuple, list)): 229 | self.kernel_size_y = kernel_size[0] 230 | self.kernel_size_x = kernel_size[1] 231 | else: 232 | self.kernel_size_y = kernel_size 233 | self.kernel_size_x = kernel_size 234 | if isinstance(stride, (tuple, list)): 235 | self.stride_y = stride[0] 236 | self.stride_x = stride[1] 237 | else: 238 | self.stride_y = stride 239 | self.stride_x = stride 240 | 241 | def forward(self, x: torch.Tensor): 242 | _, _, height, width = x.shape 243 | 244 | # For the convolution we want to achieve a output size of (n_h, n_w) = (math.ceil(h / s_y), math.ceil(w / s_y)). 245 | # Therefore we need to apply n_h convolution kernels with stride s_y. We will have n_h - 1 offsets of size s_y. 246 | # Additionally, we need to add the size of our kernel. This is the height we require to get n_h. We need to pad 247 | # the read difference between this and the old height. We will pad math.floor(pad_y / 2) on the left and 248 | # math-ceil(pad_y / 2) on the right. Same for pad_x respectively. 249 | padding_y = (self.stride_y * (math.ceil(height / self.stride_y) - 1) + self.kernel_size_y - height) / 2 250 | padding_x = (self.stride_x * (math.ceil(width / self.stride_x) - 1) + self.kernel_size_x - width) / 2 251 | padding = [math.floor(padding_x), math.ceil(padding_x), math.floor(padding_y), math.ceil(padding_y)] 252 | return F.pad(input=x, pad=padding) 253 | 254 | 255 | class PadSameConv2dTransposed(torch.nn.Module): 256 | def __init__(self, stride): 257 | """ 258 | Imitates padding_mode="same" from tensorflow. 
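For example (illustrative case): with stride=2, forward() pads and/or crops the transposed-convolution output symmetrically so that the returned tensor has exactly twice the spatial size of the orig_shape it is given.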
259 | :param stride: Stride of the convolution_transposed, int or tuple/list 260 | """ 261 | super().__init__() 262 | if isinstance(stride, (tuple, list)): 263 | self.stride_y = stride[0] 264 | self.stride_x = stride[1] 265 | else: 266 | self.stride_y = stride 267 | self.stride_x = stride 268 | 269 | def forward(self, x: torch.Tensor, orig_shape: torch.Tensor): 270 | target_shape = x.new_tensor(list(orig_shape)) 271 | target_shape[-2] *= self.stride_y 272 | target_shape[-1] *= self.stride_x 273 | oversize = target_shape[-2:] - x.new_tensor(x.shape)[-2:] 274 | if oversize[0] > 0 and oversize[1] > 0: 275 | x = F.pad(x, [math.floor(oversize[1] / 2), math.ceil(oversize[1] / 2), math.floor(oversize[0] / 2), 276 | math.ceil(oversize[0] / 2)]) 277 | elif oversize[0] > 0 >= oversize[1]: 278 | x = F.pad(x, [0, 0, math.floor(oversize[0] / 2), math.ceil(oversize[0] / 2)]) 279 | x = x[:, :, :, math.floor(-oversize[1] / 2):-math.ceil(-oversize[1] / 2)] 280 | elif oversize[0] <= 0 < oversize[1]: 281 | x = F.pad(x, [math.floor(oversize[1] / 2), math.ceil(oversize[1] / 2)]) 282 | x = x[:, :, math.floor(-oversize[0] / 2):-math.ceil(-oversize[0] / 2), :] 283 | else: 284 | x = x[:, :, math.floor(-oversize[0] / 2):-math.ceil(-oversize[0] / 2), 285 | math.floor(-oversize[1] / 2):-math.ceil(-oversize[1] / 2)] 286 | return x 287 | 288 | 289 | class ConvReLU2(torch.nn.Module): 290 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, leaky_relu_neg_slope=0.1): 291 | """ 292 | Performs two convolutions and a leaky relu. The first operation only convolves in y direction, the second one 293 | only in x direction. 294 | :param in_channels: Number of input channels 295 | :param out_channels: Number of output channels 296 | :param kernel_size: Kernel size for the convolutions, first in y direction, then in x direction 297 | :param stride: Stride for the convolutions, first in y direction, then in x direction 298 | """ 299 | super().__init__() 300 | self.pad_0 = PadSameConv2d(kernel_size=(kernel_size, 1), stride=(stride, 1)) 301 | self.conv_y = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1), 302 | stride=(stride, 1)) 303 | self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope) 304 | self.pad_1 = PadSameConv2d(kernel_size=(1, kernel_size), stride=(1, stride)) 305 | self.conv_x = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, kernel_size), 306 | stride=(1, stride)) 307 | 308 | def forward(self, x: torch.Tensor): 309 | t = self.pad_0(x) 310 | t = self.conv_y(t) 311 | t = self.leaky_relu(t) 312 | t = self.pad_1(t) 313 | t = self.conv_x(t) 314 | return self.leaky_relu(t) 315 | 316 | 317 | class ConvReLU(torch.nn.Module): 318 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, leaky_relu_neg_slope=0.1): 319 | """ 320 | Performs two convolutions and a leaky relu. The first operation only convolves in y direction, the second one 321 | only in x direction. 
322 | :param in_channels: Number of input channels 323 | :param out_channels: Number of output channels 324 | :param kernel_size: Kernel size for the convolutions, first in y direction, then in x direction 325 | :param stride: Stride for the convolutions, first in y direction, then in x direction 326 | """ 327 | super().__init__() 328 | self.pad = PadSameConv2d(kernel_size=kernel_size, stride=stride) 329 | self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride) 330 | self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope) 331 | 332 | def forward(self, x: torch.Tensor): 333 | t = self.pad(x) 334 | t = self.conv(t) 335 | return self.leaky_relu(t) 336 | 337 | 338 | class Upconv(torch.nn.Module): 339 | def __init__(self, in_channels, out_channels): 340 | """ 341 | Performs two convolutions and a leaky relu. The first operation only convolves in y direction, the second one 342 | only in x direction. 343 | :param in_channels: Number of input channels 344 | :param out_channels: Number of output channels 345 | :param kernel_size: Kernel size for the convolutions, first in y direction, then in x direction 346 | :param stride: Stride for the convolutions, first in y direction, then in x direction 347 | """ 348 | super().__init__() 349 | self.upsample = Upsample(scale_factor=2) 350 | self.pad = PadSameConv2d(kernel_size=2) 351 | self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=1) 352 | 353 | def forward(self, x: torch.Tensor): 354 | t = self.upsample(x) 355 | t = self.pad(t) 356 | return self.conv(t) 357 | 358 | 359 | class ConvSig(torch.nn.Module): 360 | def __init__(self, in_channels, out_channels, kernel_size, stride=1): 361 | """ 362 | Performs two convolutions and a leaky relu. The first operation only convolves in y direction, the second one 363 | only in x direction. 364 | :param in_channels: Number of input channels 365 | :param out_channels: Number of output channels 366 | :param kernel_size: Kernel size for the convolutions, first in y direction, then in x direction 367 | :param stride: Stride for the convolutions, first in y direction, then in x direction 368 | """ 369 | super().__init__() 370 | self.pad = PadSameConv2d(kernel_size=kernel_size, stride=stride) 371 | self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride) 372 | self.sig = Sigmoid() 373 | 374 | def forward(self, x: torch.Tensor): 375 | t = self.pad(x) 376 | t = self.conv(t) 377 | return self.sig(t) 378 | 379 | 380 | class Refine(torch.nn.Module): 381 | def __init__(self, in_channels, out_channels, leaky_relu_neg_slope=0.1): 382 | """ 383 | Performs a transposed conv2d with padding that imitates tensorflow same behaviour. The transposed conv2d has 384 | parameters kernel_size=4 and stride=2. 
385 | :param in_channels: Channels that go into the conv2d_transposed 386 | :param out_channels: Channels that come out of the conv2d_transposed 387 | """ 388 | super().__init__() 389 | self.conv2d_t = ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2) 390 | self.pad = PadSameConv2dTransposed(stride=2) 391 | self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope) 392 | 393 | def forward(self, x: torch.Tensor, features_direct=None): 394 | orig_shape=x.shape 395 | x = self.conv2d_t(x) 396 | x = self.leaky_relu(x) 397 | x = self.pad(x, orig_shape) 398 | if features_direct is not None: 399 | x = torch.cat([x, features_direct], dim=1) 400 | return x 401 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | from .loss_functions.dymultidepth_loss import mask_loss, depth_loss, mask_refinement_loss, depth_refinement_loss, depth_aux_mask_loss, silog_loss, silog_vn_loss, silog_vn_loss_update, abs_silog_loss_virtualnormal -------------------------------------------------------------------------------- /model/loss_functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/model/loss_functions/__init__.py -------------------------------------------------------------------------------- /model/loss_functions/common_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from model.layers import Backprojection, point_projection, ssim 7 | from utils import create_mask, mask_mean 8 | 9 | 10 | def compute_errors(img0, img1, mask=None): 11 | errors = .85 * torch.mean(ssim(img0, img1, pad_reflection=False, gaussian_average=True, comp_mode=True), dim=1) + .15 * torch.mean(torch.abs(img0 - img1), dim=1) 12 | if mask is not None: return errors, mask 13 | else: return errors 14 | 15 | 16 | def reprojection_loss(depth_prediction: torch.Tensor, data_dict, automasking=False, 17 | error_function=compute_errors, error_function_weight=None, use_mono=True, use_stereo=False, 18 | reduce=True, combine_frames="min", mono_auto=False, border=0): 19 | keyframe = data_dict["keyframe"] 20 | keyframe_pose = data_dict["keyframe_pose"] 21 | keyframe_intrinsics = data_dict["keyframe_intrinsics"] 22 | 23 | frames = [] 24 | poses = [] 25 | intrinsics = [] 26 | 27 | if use_mono: 28 | frames += data_dict["frames"] 29 | poses += data_dict["poses"] 30 | intrinsics += data_dict["intrinsics"] 31 | if use_stereo: 32 | frames += [data_dict["stereoframe"]] 33 | poses += [data_dict["stereoframe_pose"]] 34 | intrinsics += [data_dict["stereoframe_intrinsics"]] 35 | 36 | batch_size, channels, height, width = keyframe.shape 37 | frame_count = len(frames) 38 | keyframe_extrinsics = torch.inverse(keyframe_pose) 39 | extrinsics = [torch.inverse(pose) for pose in poses] 40 | 41 | reprojections = [] 42 | if border > 0: 43 | masks = [create_mask(batch_size, height, width, border, keyframe.device) for _ in range(frame_count)] 44 | warped_masks = [] 45 | 46 | backproject_depth = Backprojection(batch_size, height, width) 47 | backproject_depth.to(keyframe.device) 48 | 49 | for i, (frame, extrinsic, intrinsic) in enumerate(zip(frames, extrinsics, intrinsics)): 50 | cam_points = 
backproject_depth(1 / depth_prediction, torch.inverse(keyframe_intrinsics)) 51 | pix_coords = point_projection(cam_points, batch_size, height, width, intrinsic, extrinsic @ keyframe_pose) 52 | reprojections.append(F.grid_sample(frame + 1.5, pix_coords, padding_mode="zeros")) 53 | if border > 0: 54 | warped_masks.append(F.grid_sample(masks[i], pix_coords, padding_mode="zeros")) 55 | 56 | reprojections = torch.stack(reprojections, dim=1).view(batch_size * frame_count, channels, height, width) 57 | mask = reprojections[:, 0, :, :] == 0 58 | reprojections -= 1.0 59 | 60 | if border > 0: 61 | mask = ~(torch.stack(warped_masks, dim=1).view(batch_size * frame_count, height, width) > .5) 62 | 63 | keyframe_expanded = (keyframe + .5).unsqueeze(1).expand(-1, frame_count, -1, -1, -1).reshape(batch_size * frame_count, channels, height, width) 64 | 65 | loss = 0 66 | 67 | if type(error_function) != list: 68 | error_function = [error_function] 69 | if error_function_weight is None: 70 | error_function_weight = [1 for i in range(len(error_function))] 71 | 72 | for ef, w in zip(error_function, error_function_weight): 73 | errors, n_mask = ef(reprojections, keyframe_expanded, mask) 74 | n_height, n_width = n_mask.shape[1:] 75 | errors = errors.view(batch_size, frame_count, n_height, n_width) 76 | 77 | n_mask = n_mask.view(batch_size, frame_count, n_height, n_width) 78 | errors[n_mask] = float("inf") 79 | 80 | if automasking: 81 | frames_stacked = torch.stack(frames, dim=1).view(batch_size * frame_count, channels, height, width) + .5 82 | errors_nowarp = ef(frames_stacked, keyframe_expanded).view(batch_size, frame_count, n_height, n_width) 83 | errors[errors_nowarp < errors] = float("inf") 84 | 85 | if mono_auto: 86 | keyframe_expanded_ = (keyframe + .5).unsqueeze(1).expand(-1, len(data_dict["frames"]), -1, -1, -1).reshape(batch_size * len(data_dict["frames"]), channels, height, width) 87 | frames_stacked = (torch.stack(data_dict["frames"], dim=1) + .5).view(batch_size * len(data_dict["frames"]), channels, height, width) 88 | errors_nowarp = ef(frames_stacked, keyframe_expanded_).view(batch_size, len(data_dict["frames"]), n_height, n_width) 89 | errors_nowarp = torch.mean(errors_nowarp, dim=1, keepdim=True) 90 | errors_nowarp[torch.all(n_mask, dim=1, keepdim=True)] = float("inf") 91 | errors = torch.min(errors, errors_nowarp.expand(-1, frame_count, -1, -1)) 92 | 93 | if combine_frames == "min": 94 | errors = torch.min(errors, dim=1)[0] 95 | n_mask = torch.isinf(errors) 96 | elif combine_frames == "avg": 97 | n_mask = torch.isinf(errors) 98 | hits = torch.sum((~n_mask).to(dtype=torch.float32), dim=1) 99 | errors[n_mask] = 0 100 | errors = torch.sum(errors, dim=1) / hits 101 | n_mask = hits == 0 102 | errors[n_mask] = float("inf") 103 | elif combine_frames == "rnd": 104 | index = torch.randint(frame_count, (batch_size, 1, 1, 1), device=keyframe.device).expand(-1, 1, n_height, n_width) 105 | errors = torch.gather(errors, dim=1, index=index).squeeze(1) 106 | n_mask = torch.gather(n_mask, dim=1, index=index).squeeze(1) 107 | else: 108 | raise ValueError("Combine frames must be \"min\", \"avg\" or \"rnd\".") 109 | 110 | if reduce: 111 | loss += w * mask_mean(errors, n_mask) 112 | else: 113 | loss += w * errors 114 | return loss 115 | 116 | 117 | def edge_aware_smoothness_loss(depth_prediction, input, reduce=True): 118 | keyframe = input["keyframe"] 119 | depth_prediction = depth_prediction / torch.mean(depth_prediction, dim=[2, 3], keepdim=True) 120 | 121 | d_dx = torch.abs(depth_prediction[:, :, :, :-1] - 
depth_prediction[:, :, :, 1:]) 122 | d_dy = torch.abs(depth_prediction[:, :, :-1, :] - depth_prediction[:, :, 1:, :]) 123 | 124 | k_dx = torch.mean(torch.abs(keyframe[:, :, :, :-1] - keyframe[:, :, :, 1:]), 1, keepdim=True) 125 | k_dy = torch.mean(torch.abs(keyframe[:, :, :-1, :] - keyframe[:, :, 1:, :]), 1, keepdim=True) 126 | 127 | d_dx *= torch.exp(-k_dx) 128 | d_dy *= torch.exp(-k_dy) 129 | 130 | if reduce: 131 | return d_dx.mean() + d_dy.mean() 132 | else: 133 | return F.pad(d_dx, pad=(0, 1), mode='constant', value=0) + F.pad(d_dy, pad=(0, 0, 0, 1), mode='constant', value=0) 134 | 135 | 136 | def sparse_depth_loss(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, l2=False, reduce=True): 137 | """ 138 | :param depth_prediction: 139 | :param depth_gt: (N, 1, H, W) 140 | :return: 141 | """ 142 | n, c, h, w = depth_prediction.shape 143 | mask = depth_gt == 0 144 | if not l2: 145 | errors = torch.abs(depth_prediction - depth_gt) 146 | else: 147 | errors = (depth_prediction - depth_gt) ** 2 148 | 149 | if reduce: 150 | loss = mask_mean(errors, mask) 151 | loss[torch.isnan(loss)] = 0 152 | return loss 153 | else: 154 | return errors, mask 155 | 156 | 157 | def selfsup_loss(depth_prediction: torch.Tensor, input=None, scale=0, automasking=True, error_function=None, error_function_weight=None, use_mono=True, use_stereo=False, reduce=True, combine_frames="min", mask_border=0): 158 | reprojection_l = reprojection_loss(depth_prediction, input, automasking=automasking, error_function=error_function, error_function_weight=error_function_weight, use_mono=use_mono, use_stereo=use_stereo, reduce=reduce, combine_frames=combine_frames, border=mask_border) 159 | reprojection_l[torch.isnan(reprojection_l)] = 0 160 | edge_aware_smoothness_l = edge_aware_smoothness_loss(depth_prediction, input) 161 | edge_aware_smoothness_l[torch.isnan(edge_aware_smoothness_l)] = 0 162 | loss = reprojection_l + edge_aware_smoothness_l * 1e-3 / (2 ** scale) 163 | return loss 164 | 165 | 166 | class PerceptualError(nn.Module): 167 | def __init__(self, small_features=False): 168 | super().__init__() 169 | self.small_features = small_features 170 | vgg16 = torchvision.models.vgg16(True, True) 171 | self.feature_extractor = nn.Sequential(*list(vgg16.features.children())[:4 if self.small_features else 9]) 172 | for p in self.feature_extractor.parameters(True): 173 | p.requires_grad_(False) 174 | self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False) 175 | self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False) 176 | 177 | def forward(self, img0: torch.Tensor, img1: torch.Tensor, mask=None): 178 | n, c, h, w = img0.shape 179 | c = torchvision.models.vgg.cfgs["D"][1 if self.small_features else 4] 180 | if not self.small_features: 181 | h //= 2 182 | w //= 2 183 | 184 | img0 = (img0 - self.mean) / self.std 185 | img1 = (img1 - self.mean) / self.std 186 | 187 | if mask is not None: 188 | img0[mask.unsqueeze(1).expand(-1, 3, -1, -1)] = 0 189 | img1[mask.unsqueeze(1).expand(-1, 3, -1, -1)] = 0 190 | 191 | input = torch.cat([img0, img1], dim=0) 192 | features = self.feature_extractor(input) 193 | features = features.view(2, n, c, h, w) 194 | errors = torch.mean((features[1] - features[0]) ** 2, dim=1) 195 | 196 | if mask is not None: 197 | if not self.small_features: 198 | mask = F.upsample(mask.to(dtype=torch.float).unsqueeze(1), (h, w), mode="bilinear").squeeze(1) 199 | mask = mask > 0 200 | return errors, mask 201 | else: 202 | return errors 
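Both sparse_depth_loss and edge_aware_smoothness_loss above only need a depth tensor plus (for the smoothness term) the keyframe image, so they can be exercised in isolation. Below is a minimal sketch; the shapes, the random inputs, and the assumption that the repository root is on PYTHONPATH are illustrative only, since in training these tensors come from the KITTI data loader and the depth network:

```python
import torch

# Illustrative only: assumes the repository root is on PYTHONPATH so that
# model/ and utils/ are importable, and that the requirements are installed.
from model.loss_functions.common_losses import sparse_depth_loss, edge_aware_smoothness_loss

b, h, w = 2, 64, 96
pred = torch.rand(b, 1, h, w) + 0.01                 # predicted (inverse) depth, strictly positive
gt = torch.rand(b, 1, h, w)
gt[:, :, ::4, :] = 0                                 # simulate sparse ground truth (0 = invalid pixel)
data = {"keyframe": torch.rand(b, 3, h, w) - 0.5}    # images in this codebase are roughly zero-centered

l_sparse = sparse_depth_loss(pred, gt)               # L1 error averaged over valid (non-zero) GT pixels
l_smooth = edge_aware_smoothness_loss(pred, data)    # depth gradients, damped at image edges
print(l_sparse, l_smooth)
# selfsup_loss above weights the smoothness term by 1e-3 / (2 ** scale) relative to the
# reprojection term; a weighting for the sparse term would be chosen analogously.
```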
203 | -------------------------------------------------------------------------------- /model/loss_functions/virtual_normal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | class VNL_Loss(nn.Module): 8 | """ 9 | Virtual Normal Loss Function. 10 | """ 11 | def __init__(self, focal_x, focal_y, input_size, 12 | delta_cos=0.867, delta_diff_x=0.01, 13 | delta_diff_y=0.01, delta_diff_z=0.01, 14 | delta_z=0.0001, sample_ratio=0.15): 15 | super(VNL_Loss, self).__init__() 16 | self.fx = torch.tensor([focal_x], dtype=torch.float32).cuda() 17 | self.fy = torch.tensor([focal_y], dtype=torch.float32).cuda() 18 | self.input_size = input_size 19 | self.u0 = torch.tensor(input_size[1] // 2, dtype=torch.float32).cuda() 20 | self.v0 = torch.tensor(input_size[0] // 2, dtype=torch.float32).cuda() 21 | self.init_image_coor() 22 | self.delta_cos = delta_cos 23 | self.delta_diff_x = delta_diff_x 24 | self.delta_diff_y = delta_diff_y 25 | self.delta_diff_z = delta_diff_z 26 | self.delta_z = delta_z 27 | self.sample_ratio = sample_ratio 28 | 29 | def init_image_coor(self): 30 | x_row = np.arange(0, self.input_size[1]) 31 | x = np.tile(x_row, (self.input_size[0], 1)) 32 | x = x[np.newaxis, :, :] 33 | x = x.astype(np.float32) 34 | x = torch.from_numpy(x.copy()).cuda() 35 | self.u_u0 = x - self.u0 36 | 37 | y_col = np.arange(0, self.input_size[0]) # y_col = np.arange(0, height) 38 | y = np.tile(y_col, (self.input_size[1], 1)).T 39 | y = y[np.newaxis, :, :] 40 | y = y.astype(np.float32) 41 | y = torch.from_numpy(y.copy()).cuda() 42 | self.v_v0 = y - self.v0 43 | 44 | def transfer_xyz(self, depth): 45 | x = self.u_u0 * torch.abs(depth) / self.fx 46 | y = self.v_v0 * torch.abs(depth) / self.fy 47 | z = depth 48 | pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c] 49 | return pw 50 | 51 | def select_index(self): 52 | valid_width = self.input_size[1] 53 | valid_height = self.input_size[0] 54 | num = valid_width * valid_height 55 | p1 = np.random.choice(num, int(num * self.sample_ratio), replace=True) 56 | np.random.shuffle(p1) 57 | p2 = np.random.choice(num, int(num * self.sample_ratio), replace=True) 58 | np.random.shuffle(p2) 59 | p3 = np.random.choice(num, int(num * self.sample_ratio), replace=True) 60 | np.random.shuffle(p3) 61 | 62 | p1_x = p1 % self.input_size[1] 63 | p1_y = (p1 / self.input_size[1]).astype(np.int) 64 | 65 | p2_x = p2 % self.input_size[1] 66 | p2_y = (p2 / self.input_size[1]).astype(np.int) 67 | 68 | p3_x = p3 % self.input_size[1] 69 | p3_y = (p3 / self.input_size[1]).astype(np.int) 70 | p123 = {'p1_x': p1_x, 'p1_y': p1_y, 'p2_x': p2_x, 'p2_y': p2_y, 'p3_x': p3_x, 'p3_y': p3_y} 71 | return p123 72 | 73 | def form_pw_groups(self, p123, pw): 74 | """ 75 | Form 3D points groups, with 3 points in each grouup. 
76 | :param p123: points index 77 | :param pw: 3D points 78 | :return: 79 | """ 80 | p1_x = p123['p1_x'] 81 | p1_y = p123['p1_y'] 82 | p2_x = p123['p2_x'] 83 | p2_y = p123['p2_y'] 84 | p3_x = p123['p3_x'] 85 | p3_y = p123['p3_y'] 86 | 87 | pw1 = pw[:, p1_y, p1_x, :] 88 | pw2 = pw[:, p2_y, p2_x, :] 89 | pw3 = pw[:, p3_y, p3_x, :] 90 | # [B, N, 3(x,y,z), 3(p1,p2,p3)] 91 | pw_groups = torch.cat([pw1[:, :, :, np.newaxis], pw2[:, :, :, np.newaxis], pw3[:, :, :, np.newaxis]], 3) 92 | return pw_groups 93 | 94 | def filter_mask(self, p123, gt_xyz, delta_cos=0.867, 95 | delta_diff_x=0.005, 96 | delta_diff_y=0.005, 97 | delta_diff_z=0.005): 98 | pw = self.form_pw_groups(p123, gt_xyz) 99 | pw12 = pw[:, :, :, 1] - pw[:, :, :, 0] 100 | pw13 = pw[:, :, :, 2] - pw[:, :, :, 0] 101 | pw23 = pw[:, :, :, 2] - pw[:, :, :, 1] 102 | ###ignore linear 103 | pw_diff = torch.cat([pw12[:, :, :, np.newaxis], pw13[:, :, :, np.newaxis], pw23[:, :, :, np.newaxis]], 104 | 3) # [b, n, 3, 3] 105 | m_batchsize, groups, coords, index = pw_diff.shape 106 | proj_query = pw_diff.view(m_batchsize * groups, -1, index).permute(0, 2, 1) # (B* X CX(3)) [bn, 3(p123), 3(xyz)] 107 | proj_key = pw_diff.view(m_batchsize * groups, -1, index) # B X (3)*C [bn, 3(xyz), 3(p123)] 108 | q_norm = proj_query.norm(2, dim=2) 109 | nm = torch.bmm(q_norm.view(m_batchsize * groups, index, 1), q_norm.view(m_batchsize * groups, 1, index)) #[] 110 | energy = torch.bmm(proj_query, proj_key) # transpose check [bn, 3(p123), 3(p123)] 111 | norm_energy = energy / (nm + 1e-8) 112 | norm_energy = norm_energy.view(m_batchsize * groups, -1) 113 | mask_cos = torch.sum((norm_energy > delta_cos) + (norm_energy < -delta_cos), 1) > 3 # igonre 114 | mask_cos = mask_cos.view(m_batchsize, groups) 115 | ##ignore padding and invilid depth 116 | mask_pad = torch.sum(pw[:, :, 2, :] > self.delta_z, 2) == 3 117 | 118 | ###ignore near 119 | mask_x = torch.sum(torch.abs(pw_diff[:, :, 0, :]) < delta_diff_x, 2) > 0 120 | mask_y = torch.sum(torch.abs(pw_diff[:, :, 1, :]) < delta_diff_y, 2) > 0 121 | mask_z = torch.sum(torch.abs(pw_diff[:, :, 2, :]) < delta_diff_z, 2) > 0 122 | 123 | mask_ignore = (mask_x & mask_y & mask_z) | mask_cos 124 | mask_near = ~mask_ignore 125 | mask = mask_pad & mask_near 126 | 127 | return mask, pw 128 | 129 | def select_points_groups(self, gt_depth, pred_depth): 130 | pw_gt = self.transfer_xyz(gt_depth) 131 | pw_pred = self.transfer_xyz(pred_depth) 132 | B, C, H, W = gt_depth.shape 133 | p123 = self.select_index() 134 | # mask:[b, n], pw_groups_gt: [b, n, 3(x,y,z), 3(p1,p2,p3)] 135 | mask, pw_groups_gt = self.filter_mask(p123, pw_gt, 136 | delta_cos=0.867, 137 | delta_diff_x=0.005, 138 | delta_diff_y=0.005, 139 | delta_diff_z=0.005) 140 | 141 | # [b, n, 3, 3] 142 | pw_groups_pred = self.form_pw_groups(p123, pw_pred) 143 | pw_groups_pred[pw_groups_pred[:, :, 2, :] == 0] = 0.0001 144 | mask_broadcast = mask.repeat(1, 9).reshape(B, 3, 3, -1).permute(0, 3, 1, 2) 145 | pw_groups_pred_not_ignore = pw_groups_pred[mask_broadcast].reshape(1, -1, 3, 3) 146 | pw_groups_gt_not_ignore = pw_groups_gt[mask_broadcast].reshape(1, -1, 3, 3) 147 | 148 | return pw_groups_gt_not_ignore, pw_groups_pred_not_ignore 149 | 150 | def forward(self, gt_depth, pred_depth, select=True): 151 | """ 152 | Virtual normal loss. 
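The loss samples random pixel triplets (select_index), lifts them to 3D using the camera intrinsics (transfer_xyz), discards degenerate triplets on the ground truth (near-colinear, too close together, or with invalid depth; see filter_mask), and penalises the difference between the unit normals spanned by the ground-truth and predicted triplets. Both depth maps are expected on the GPU, since the coordinate grids in init_image_coor are created with .cuda().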
153 | :param pred_depth: predicted depth map, [B,W,H,C] 154 | :param data: target label, ground truth depth, [B, W, H, C], padding region [padding_up, padding_down] 155 | :return: 156 | """ 157 | gt_points, dt_points = self.select_points_groups(gt_depth, pred_depth) 158 | 159 | gt_p12 = gt_points[:, :, :, 1] - gt_points[:, :, :, 0] 160 | gt_p13 = gt_points[:, :, :, 2] - gt_points[:, :, :, 0] 161 | dt_p12 = dt_points[:, :, :, 1] - dt_points[:, :, :, 0] 162 | dt_p13 = dt_points[:, :, :, 2] - dt_points[:, :, :, 0] 163 | 164 | gt_normal = torch.cross(gt_p12, gt_p13, dim=2) 165 | dt_normal = torch.cross(dt_p12, dt_p13, dim=2) 166 | dt_norm = torch.norm(dt_normal, 2, dim=2, keepdim=True) 167 | gt_norm = torch.norm(gt_normal, 2, dim=2, keepdim=True) 168 | dt_mask = dt_norm == 0.0 169 | gt_mask = gt_norm == 0.0 170 | dt_mask = dt_mask.to(torch.float32) 171 | gt_mask = gt_mask.to(torch.float32) 172 | dt_mask *= 0.01 173 | gt_mask *= 0.01 174 | gt_norm = gt_norm + gt_mask 175 | dt_norm = dt_norm + dt_mask 176 | gt_normal = gt_normal / gt_norm 177 | dt_normal = dt_normal / dt_norm 178 | loss = torch.abs(gt_normal - dt_normal) 179 | loss = torch.sum(torch.sum(loss, dim=2), dim=0) 180 | if select: 181 | loss, indices = torch.sort(loss, dim=0, descending=False) 182 | loss = loss[int(loss.size(0) * 0.25):] 183 | loss = torch.mean(loss) 184 | return loss -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | from .metric_functions.sparse_metrics import * 2 | from .metric_functions.dense_metrics import * 3 | from .metric_functions.completeness_metrics import * -------------------------------------------------------------------------------- /model/metric_functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/model/metric_functions/__init__.py -------------------------------------------------------------------------------- /model/metric_functions/completeness_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils import mask_mean 4 | 5 | 6 | def completeness_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None): 7 | return torch.mean((depth_prediction != 0).to(dtype=torch.float32)) 8 | 9 | 10 | def covered_gt_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None): 11 | gt_mask = depth_gt != 0 12 | return mask_mean(((depth_prediction != 0)).to(dtype=torch.float32), gt_mask) 13 | -------------------------------------------------------------------------------- /model/metric_functions/dense_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils import preprocess_roi, get_positive_depth, get_absolute_depth 4 | 5 | 6 | def sc_inv_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None): 7 | """ 8 | Computes scale inveriant metric described in (14) 9 | :param depth_prediction: Depth prediction computed by the network 10 | :param depth_gt: GT Depth 11 | :param roi: Specify a region of interest on which the metric should be computed 12 | :return: metric (mean over batch_size) 13 | """ 14 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 15 | 
depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 16 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 17 | 18 | n = depth_gt.shape[2] * depth_gt.shape[3] 19 | E = torch.log(depth_prediction) - torch.log(depth_gt) 20 | E[torch.isnan(E)] = 0 21 | batch_metric = torch.sqrt(1 / n * torch.sum(E**2, dim=[2, 3]) - 1 / (n**2) * (torch.sum(E, dim=[2, 3])**2)) 22 | batch_metric[torch.isnan(batch_metric)] = 0 23 | result = torch.mean(batch_metric) 24 | return result 25 | 26 | 27 | def l1_rel_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None): 28 | """ 29 | Computes the L1-rel metric described in (15) 30 | :param depth_prediction: Depth prediction computed by the network 31 | :param depth_gt: GT Depth 32 | :param roi: Specify a region of interest on which the metric should be computed 33 | :return: metric (mean over batch_size) 34 | """ 35 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 36 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 37 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 38 | 39 | return torch.mean(torch.abs(depth_prediction - depth_gt) / depth_gt) 40 | 41 | 42 | def l1_inv_metric(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, roi=None, max_distance=None): 43 | """ 44 | Computes the L1-inv metric described in (16) 45 | :param depth_prediction: Depth prediction computed by the network 46 | :param depth_gt: GT Depth 47 | :param roi: Specify a region of interest on which the metric should be computed 48 | :return: metric (mean over batch_size) 49 | """ 50 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 51 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 52 | 53 | 54 | return torch.mean(torch.abs(depth_prediction - depth_gt)) -------------------------------------------------------------------------------- /model/metric_functions/sparse_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils import preprocess_roi, get_positive_depth, get_absolute_depth, get_mask, mask_mean 4 | 5 | from depth_proc_tools.plot_depth_utils import * 6 | import cv2 7 | from PIL import Image 8 | import os 9 | 10 | 11 | def a1_metric(data_dict: dict, roi=None, max_distance=None): 12 | depth_prediction = data_dict["result"] 13 | depth_gt = data_dict["target"] 14 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 15 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 16 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 17 | 18 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)) 19 | return torch.mean((thresh < 1.25).type(torch.float)) 20 | 21 | 22 | def a2_metric(data_dict: dict, roi=None, max_distance=None): 23 | depth_prediction = data_dict["result"] 24 | depth_gt = data_dict["target"] 25 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 26 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 27 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 28 | 29 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)).type(torch.float) 30 | return torch.mean((thresh < 1.25 ** 2).type(torch.float)) 31 | 32 | 33 | def 
a3_metric(data_dict: dict, roi=None, max_distance=None): 34 | depth_prediction = data_dict["result"] 35 | depth_gt = data_dict["target"] 36 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 37 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 38 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 39 | 40 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)).type(torch.float) 41 | return torch.mean((thresh < 1.25 ** 3).type(torch.float)) 42 | 43 | 44 | def rmse_metric(data_dict: dict, roi=None, max_distance=None): 45 | depth_prediction = data_dict["result"] 46 | depth_gt = data_dict["target"] 47 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 48 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 49 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 50 | 51 | se = (depth_prediction - depth_gt) ** 2 52 | return torch.mean(torch.sqrt(torch.mean(se, dim=[1, 2, 3]))) 53 | 54 | 55 | def rmse_log_metric(data_dict: dict, roi=None, max_distance=None): 56 | depth_prediction = data_dict["result"] 57 | depth_gt = data_dict["target"] 58 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 59 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 60 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 61 | 62 | sle = (torch.log(depth_prediction) - torch.log(depth_gt)) ** 2 63 | return torch.mean(torch.sqrt(torch.mean(sle, dim=[1, 2, 3]))) 64 | 65 | 66 | def abs_rel_metric(data_dict: dict, roi=None, max_distance=None): 67 | depth_prediction = data_dict["result"] 68 | depth_gt = data_dict["target"] 69 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 70 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 71 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 72 | 73 | return torch.mean(torch.abs(depth_prediction - depth_gt) / depth_gt) 74 | 75 | 76 | def sq_rel_metric(data_dict: dict, roi=None, max_distance=None): 77 | depth_prediction = data_dict["result"] 78 | depth_gt = data_dict["target"] 79 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 80 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 81 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 82 | 83 | return torch.mean(((depth_prediction - depth_gt) ** 2) / depth_gt) 84 | 85 | 86 | def find_mincost_depth(cost_volume, depth_hypos): 87 | argmax = torch.argmax(cost_volume, dim=1, keepdim=True) 88 | mincost_depth = torch.gather(input=depth_hypos, dim=1, index=argmax) 89 | return mincost_depth 90 | 91 | def a1_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False): 92 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"] 93 | depth_gt = data_dict["target"] 94 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 95 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid) 96 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5) 97 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 98 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, 
max_distance) 99 | 100 | return a1_base(depth_prediction, depth_gt, mask) 101 | 102 | 103 | 104 | def a2_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False): 105 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"] 106 | depth_gt = data_dict["target"] 107 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 108 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid) 109 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5) 110 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 111 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 112 | return a2_base(depth_prediction, depth_gt, mask) 113 | 114 | 115 | def a3_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False): 116 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"] 117 | depth_gt = data_dict["target"] 118 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 119 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid) 120 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5) 121 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 122 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 123 | return a3_base(depth_prediction, depth_gt, mask) 124 | 125 | 126 | def rmse_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False): 127 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"] 128 | depth_gt = data_dict["target"] 129 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 130 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid) 131 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5) 132 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 133 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 134 | return rmse_base(depth_prediction, depth_gt, mask) 135 | 136 | 137 | def rmse_log_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False): 138 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"] 139 | depth_gt = data_dict["target"] 140 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 141 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid) 142 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5) 143 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 144 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 145 | return rmse_log_base(depth_prediction, depth_gt, mask) 146 | 147 | 148 | def abs_rel_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False): 149 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"] 150 | depth_gt = data_dict["target"] 151 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 152 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, 
pred_all_valid=pred_all_valid) 153 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5) 154 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 155 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 156 | return abs_rel_base(depth_prediction, depth_gt, mask) 157 | 158 | 159 | def sq_rel_sparse_metric(data_dict: dict, roi=None, max_distance=None, pred_all_valid=True, use_cvmask=False, eval_mono=False): 160 | depth_prediction = data_dict["result_mono"] if eval_mono else data_dict["result"] 161 | depth_gt = data_dict["target"] 162 | depth_prediction, depth_gt = preprocess_roi(depth_prediction, depth_gt, roi) 163 | mask = get_mask(depth_prediction, depth_gt, max_distance=max_distance, pred_all_valid=pred_all_valid) 164 | if use_cvmask: mask |= ~ (data_dict["mvobj_mask"] > .5) 165 | depth_prediction, depth_gt = get_positive_depth(depth_prediction, depth_gt) 166 | depth_prediction, depth_gt = get_absolute_depth(depth_prediction, depth_gt, max_distance) 167 | return sq_rel_base(depth_prediction, depth_gt, mask) 168 | 169 | 170 | 171 | def save_results(path, name, img, gt_depth, pred_depth, validmask, cv_mask, costvolume): 172 | savepath = os.path.join(path, name) 173 | device=img.device 174 | bs,_,h,w = img.shape 175 | img = img[0,...].permute(1,2,0).detach().cpu().numpy() + 0.5 176 | gt_depth = gt_depth[0,...].permute(1,2,0).detach().cpu().numpy() 177 | gt_depth[gt_depth==80] = 0 178 | pred_depth = pred_depth[0,...].permute(1,2,0).detach().cpu().numpy() 179 | validmask = validmask[0,0,...].detach().cpu().numpy() 180 | cv_mask = cv_mask[0,0,...].detach().cpu().numpy() 181 | 182 | img = img #* np.expand_dims(cv_mask, axis=-1).astype(float) 183 | 184 | error_map, _ = get_error_map_value(pred_depth,gt_depth, grag_crop=False, median_scaling=False) 185 | 186 | errorpil = numpy_intensitymap_to_pcolor(error_map,vmin=0,vmax=0.5,colormap='jet') 187 | pred_pil = numpy_intensitymap_to_pcolor(pred_depth) 188 | gt_pil = numpy_intensitymap_to_pcolor(gt_depth) 189 | img_pil = numpy_rgb_to_pil(img) 190 | 191 | # generate pil validmask 192 | validmask_pil = Image.fromarray((validmask * 255.0).astype(np.uint8)) 193 | cv_mask_pil = Image.fromarray((cv_mask * 255.0).astype(np.uint8)) 194 | 195 | # cost 196 | print(bs,h,w) 197 | # print(f"cost volume shape:{costvolume.shape}") 198 | depths = (1 / torch.linspace(0.0025, 0.33, 32,device=device)).cuda().view(1, -1, 1, 1).expand(bs, -1, h, w) 199 | cost_volume_depth = find_mincost_depth(costvolume, depths) 200 | # print(f"cost shape:{cost_volume_depth.shape}") 201 | cost_volume_depth = cost_volume_depth[0,...].permute(1,2,0).detach().cpu().numpy() 202 | 203 | cv_depth_pil = numpy_intensitymap_to_pcolor(cost_volume_depth) 204 | 205 | 206 | h,w,_ = gt_depth.shape 207 | dst = Image.new('RGB', (w, h*3)) 208 | dst.paste(img_pil, (0, 0)) 209 | dst.paste(pred_pil, (0, h)) 210 | dst.paste(gt_pil, (0, 2*h)) 211 | # dst.paste(errorpil, (0, 3*h)) 212 | # dst.paste(validmask_pil,(0,4*h)) 213 | # dst.paste(cv_mask_pil,(0,5*h)) 214 | # dst.paste(cv_depth_pil,(0,2*h)) 215 | 216 | dst.save(savepath) 217 | print(f"saved to {savepath}") 218 | 219 | 220 | 221 | def a1_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None): 222 | return a1_sparse_metric(data_dict, roi, max_distance, False) 223 | 224 | 225 | def a2_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None): 226 | return a2_sparse_metric(data_dict, roi, max_distance, False) 227 | 228 | 229 | def 
a3_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None): 230 | return a3_sparse_metric(data_dict, roi, max_distance, False) 231 | 232 | 233 | def rmse_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None): 234 | return rmse_sparse_metric(data_dict, roi, max_distance, False) 235 | 236 | 237 | def rmse_log_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None): 238 | return rmse_log_sparse_metric(data_dict, roi, max_distance, False) 239 | 240 | 241 | def abs_rel_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None): 242 | return abs_rel_sparse_metric(data_dict, roi, max_distance, False) 243 | 244 | 245 | def sq_rel_sparse_onlyvalid_metric(data_dict: dict, roi=None, max_distance=None): 246 | return sq_rel_sparse_metric(data_dict, roi, max_distance, False) 247 | 248 | 249 | def a1_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None): 250 | return a1_sparse_metric(data_dict, roi, max_distance, use_cvmask=True) 251 | 252 | 253 | def a2_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None): 254 | return a2_sparse_metric(data_dict, roi, max_distance, use_cvmask=True) 255 | 256 | 257 | def a3_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None): 258 | return a3_sparse_metric(data_dict, roi, max_distance, use_cvmask=True) 259 | 260 | 261 | def rmse_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None): 262 | return rmse_sparse_metric(data_dict, roi, max_distance, use_cvmask=True) 263 | 264 | 265 | def rmse_log_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None): 266 | return rmse_log_sparse_metric(data_dict, roi, max_distance, use_cvmask=True) 267 | 268 | 269 | def abs_rel_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None): 270 | return abs_rel_sparse_metric(data_dict, roi, max_distance, use_cvmask=True) 271 | 272 | 273 | def sq_rel_sparse_onlydynamic_metric(data_dict: dict, roi=None, max_distance=None): 274 | return sq_rel_sparse_metric(data_dict, roi, max_distance, use_cvmask=True) 275 | 276 | 277 | def a1_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask): 278 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)) 279 | return mask_mean((thresh < 1.25).type(torch.float), mask) 280 | 281 | 282 | def a2_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask): 283 | depth_gt[mask] = 1 284 | depth_prediction[mask] = 1 285 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)).type(torch.float) 286 | return mask_mean((thresh < 1.25 ** 2).type(torch.float), mask) 287 | 288 | 289 | def a3_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask): 290 | depth_gt[mask] = 1 291 | depth_prediction[mask] = 1 292 | thresh = torch.max((depth_gt / depth_prediction), (depth_prediction / depth_gt)).type(torch.float) 293 | return mask_mean((thresh < 1.25 ** 3).type(torch.float), mask) 294 | 295 | 296 | def rmse_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask): 297 | depth_gt[mask] = 1 298 | depth_prediction[mask] = 1 299 | se = (depth_prediction - depth_gt) ** 2 300 | return torch.mean(torch.sqrt(mask_mean(se, mask, dim=[1, 2, 3]))) 301 | 302 | 303 | def rmse_log_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask): 304 | depth_gt[mask] = 1 305 | depth_prediction[mask] = 1 306 | sle = (torch.log(depth_prediction) - torch.log(depth_gt)) ** 2 307 | return torch.mean(torch.sqrt(mask_mean(sle, mask, dim=[1, 2, 3]))) 308 
| 309 | 310 | def abs_rel_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask): 311 | return mask_mean(torch.abs(depth_prediction - depth_gt) / depth_gt, mask) 312 | 313 | 314 | def sq_rel_base(depth_prediction: torch.Tensor, depth_gt: torch.Tensor, mask): 315 | return mask_mean(((depth_prediction - depth_gt) ** 2) / depth_gt, mask) -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | from .dymultidepth.dymultidepth_model import DyMultiDepthModel -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | 10 | 11 | class ConfigParser: 12 | def __init__(self, args, options='', timestamp=True): 13 | # parse default and custom cli options 14 | for opt in options: 15 | args.add_argument(*opt.flags, default=None, type=opt.type) 16 | args = args.parse_args() 17 | self.args = args 18 | 19 | if args.device: 20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 21 | if args.resume is None: 22 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 23 | assert args.config is not None, msg_no_cfg 24 | self.cfg_fname = Path(args.config) 25 | config = read_json(self.cfg_fname) 26 | self.resume = None 27 | else: 28 | self.resume = Path(args.resume) 29 | resume_cfg_fname = self.resume.parent / 'config.json' 30 | config = read_json(resume_cfg_fname) 31 | if args.config is not None: 32 | config.update(read_json(Path(args.config))) 33 | 34 | # load config file and apply custom cli options 35 | self._config = _update_config(config, options, args) 36 | 37 | # set save_dir where trained model and log will be saved. 38 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 39 | 40 | if "trainer" in self.config: 41 | save_dir = Path(self.config['trainer']['save_dir']) 42 | if "timestamp_replacement" in self.config["trainer"]: 43 | timestamp = self.config["trainer"]["timestamp_replacement"] 44 | 45 | elif "evaluater" in self.config: 46 | save_dir = Path(self.config['evaluater']['save_dir']) 47 | if "timestamp_replacement" in self.config["evaluater"]: 48 | timestamp = self.config["evaluater"]["timestamp_replacement"] 49 | elif "save_dir" in self.config: 50 | save_dir = Path(self.config["save_dir"]) 51 | else: 52 | save_dir = Path("../saved") 53 | 54 | exper_name = self.config['name'] 55 | self._save_dir = save_dir / 'models' / exper_name / timestamp 56 | self._log_dir = save_dir / 'log' / exper_name / timestamp 57 | 58 | self.save_dir.mkdir(parents=True, exist_ok=True) 59 | self.log_dir.mkdir(parents=True, exist_ok=True) 60 | 61 | # save updated config file to the checkpoint dir 62 | write_json(self.config, self.save_dir / 'config.json') 63 | 64 | # configure logging module 65 | setup_logging(self.log_dir) 66 | self.log_levels = { 67 | 0: logging.WARNING, 68 | 1: logging.INFO, 69 | 2: logging.DEBUG 70 | } 71 | 72 | def initialize(self, name, module, *args, **kwargs): 73 | """ 74 | finds a function handle with the name given as 'type' in config, and returns the 75 | instance initialized with corresponding keyword args given as 'args'. 
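For example, config.initialize('optimizer', torch.optim, trainable_params), as used in train.py, builds the optimizer class named in config['optimizer']['type'] with the keyword arguments from config['optimizer']['args'].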
76 | """ 77 | module_name = self[name]['type'] 78 | module_args = dict(self[name]['args']) 79 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 80 | module_args.update(kwargs) 81 | return getattr(module, module_name)(*args, **module_args) 82 | 83 | def initialize_list(self, name, module, *args, **kwargs): 84 | l = self[name] 85 | for to_init in l: 86 | module_name = to_init["type"] 87 | module_args = dict(to_init["args"]) 88 | module_args.update(kwargs) 89 | yield getattr(module, module_name)(*args, **module_args) 90 | 91 | def __getitem__(self, name): 92 | return self.config[name] 93 | 94 | def get_logger(self, name, verbosity=2): 95 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys()) 96 | assert verbosity in self.log_levels, msg_verbosity 97 | logger = logging.getLogger(name) 98 | logger.setLevel(self.log_levels[verbosity]) 99 | return logger 100 | 101 | # setting read-only attributes 102 | @property 103 | def config(self): 104 | return self._config 105 | 106 | @property 107 | def save_dir(self): 108 | return self._save_dir 109 | 110 | @property 111 | def log_dir(self): 112 | return self._log_dir 113 | 114 | # helper functions used to update config dict with custom cli options 115 | def _update_config(config, options, args): 116 | for opt in options: 117 | value = getattr(args, _get_opt_name(opt.flags)) 118 | if value is not None: 119 | _set_by_path(config, opt.target, value) 120 | return config 121 | 122 | def _get_opt_name(flags): 123 | for flg in flags: 124 | if flg.startswith('--'): 125 | return flg.replace('--', '') 126 | return flags[0].replace('--', '') 127 | 128 | def _set_by_path(tree, keys, value): 129 | """Set a value in a nested object in tree by sequence of keys.""" 130 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 131 | 132 | def _get_by_path(tree, keys): 133 | """Access a nested object in tree by sequence of keys.""" 134 | return reduce(getitem, keys, tree) 135 | -------------------------------------------------------------------------------- /pictures/dynamic_depth_result.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/pictures/dynamic_depth_result.gif -------------------------------------------------------------------------------- /pictures/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/pictures/overview.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision==0.8.2 3 | kornia 4 | opencv-python 5 | scipy 6 | scikit-image 7 | tqdm 8 | tensorboard 9 | tensorboardx 10 | pykitti 11 | colour_demosaicing -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import torch 4 | import data_loader.data_loaders as module_data 5 | import model.loss as module_loss 6 | import model.metric as module_metric 7 | import model.model as module_arch 8 | from utils import seed_rng 9 | from utils.parse_config import ConfigParser 10 | from 
trainer.trainer import Trainer 11 | 12 | 13 | torch.backends.cuda.matmul.allow_tf32 = False 14 | torch.backends.cudnn.allow_tf32 = False 15 | torch.backends.cudnn.benchmark = True 16 | 17 | def main(config, options=[]): 18 | seed_rng(0) 19 | logger = config.get_logger('train') 20 | 21 | # setup data_loader instances 22 | data_loader = config.initialize('data_loader', module_data) 23 | if "val_data_loader" in config.config: 24 | valid_data_loader = config.initialize("val_data_loader", module_data) 25 | else: 26 | valid_data_loader = data_loader.split_validation() 27 | 28 | # build model architecture, then print to console 29 | model = config.initialize('arch', module_arch) 30 | logger.info(model) 31 | logger.info(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters") 32 | logger.info(f"{sum(p.numel() for p in model.parameters())} total parameters") 33 | 34 | # get function handles of loss and metrics 35 | if "loss_module" in config.config: 36 | loss = config.initialize("loss_module", module_loss) 37 | else: 38 | loss = getattr(module_loss, config['loss']) 39 | metrics = [getattr(module_metric, met) for met in config['metrics']] 40 | 41 | # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler 42 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 43 | optimizer = config.initialize('optimizer', torch.optim, trainable_params) 44 | 45 | lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer) 46 | 47 | trainer = Trainer(model, loss, metrics, optimizer, 48 | config=config, 49 | data_loader=data_loader, 50 | valid_data_loader=valid_data_loader, 51 | lr_scheduler=lr_scheduler, 52 | options=options) 53 | 54 | trainer.train() 55 | 56 | 57 | if __name__ == '__main__': 58 | args = argparse.ArgumentParser(description='PyTorch Template') 59 | args.add_argument('-c', '--config', default=None, type=str, 60 | help='config file path (default: None)') 61 | args.add_argument('-r', '--resume', default=None, type=str, 62 | help='path to latest checkpoint (default: None)') 63 | args.add_argument('-d', '--device', default=None, type=str, 64 | help='indices of GPUs to enable (default: all)') 65 | args.add_argument('-o', '--options', default=[], nargs='+') 66 | 67 | # custom cli options to modify configuration from default values given in json file. 
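# Each CustomArgs entry registers an extra command line flag whose value, if given,
# overwrites the config entry addressed by its 'target' key path (see _update_config /
# _set_by_path in utils/parse_config.py). A hypothetical epoch override could be added as
#   CustomArgs(['--ep', '--epochs'], type=int, target=('trainer', 'epochs'))
# provided that key path exists in the chosen config JSON.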
68 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 69 | options = [ 70 | CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')), 71 | CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')) 72 | ] 73 | config = ConfigParser(args, options) 74 | print(config.config) 75 | main(config, config.args.options) 76 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/dynamic-multiframe-depth/b3d3f4eb00052c9d1f42c5f0abfaf896fceeb39e/trainer/__init__.py -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | import numpy as np 4 | import torch 5 | from torchvision.utils import make_grid 6 | 7 | from base import BaseTrainer 8 | from utils import inf_loop, map_fn, operator_on_dict, LossWrapper, ValueFader 9 | 10 | import time 11 | 12 | class Trainer(BaseTrainer): 13 | def __init__(self, model, loss, metrics, optimizer, config, data_loader, 14 | valid_data_loader=None, lr_scheduler=None, options=[]): 15 | super().__init__(model, loss, metrics, optimizer, config) 16 | self.config = config 17 | self.data_loader = data_loader 18 | 19 | len_epoch = config["trainer"].get("len_epoch", None) 20 | 21 | if len_epoch is None: 22 | # epoch-based training 23 | self.len_epoch = len(self.data_loader) 24 | else: 25 | # iteration-based training 26 | self.data_loader = inf_loop(data_loader) 27 | self.len_epoch = len_epoch 28 | 29 | self.valid_data_loader = valid_data_loader 30 | self.do_validation = self.valid_data_loader is not None 31 | self.lr_scheduler = lr_scheduler 32 | self.log_step = config["trainer"].get("log_step", int(np.sqrt(data_loader.batch_size))) 33 | self.val_log_step = config["trainer"].get("val_step", 1) 34 | self.roi = config["trainer"].get("roi") 35 | self.roi_train = config["trainer"].get("roi_train", self.roi) 36 | self.alpha = config["trainer"].get("alpha", None) 37 | self.max_distance = config["trainer"].get("max_distance", None) 38 | self.val_avg = config["trainer"].get("val_avg", True) 39 | self.save_multiple = config["trainer"].get("save_multiple", False) 40 | self.invert_output_images = config["trainer"].get("invert_output_images", True) 41 | self.wrap_loss_in_module = config["trainer"].get("wrap_loss_in_module", False) 42 | self.value_faders = config["trainer"].get("value_faders", {}) 43 | self.options = options 44 | 45 | if self.wrap_loss_in_module: 46 | self.loss = LossWrapper(loss_function=self.loss, roi=self.roi, options=self.options) 47 | 48 | if isinstance(loss, torch.nn.Module) or self.wrap_loss_in_module: 49 | self.module_loss = True 50 | self.loss.to(self.device) 51 | if len(self.device_ids) > 1: 52 | self.loss.num_devices = len(self.device_ids) 53 | self.model = torch.nn.DataParallel(torch.nn.Sequential(self.model.module, self.loss), self.device_ids) 54 | else: 55 | self.model = torch.nn.Sequential(self.model, self.loss) 56 | else: 57 | self.module_loss = False 58 | 59 | self.value_faders = {k: ValueFader(v[0], v[1]) for k, v in self.value_faders.items()} 60 | 61 | def _eval_metrics(self, data_dict, training=False): 62 | acc_metrics = np.zeros(len(self.metrics)) 63 | for i, metric in enumerate(self.metrics): 64 | acc_metrics[i] += metric(data_dict, self.roi, 
self.max_distance) 65 | if (not self.val_avg) or training: 66 | self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i]) 67 | if np.any(np.isnan(acc_metrics)): 68 | acc_metrics = np.zeros(len(self.metrics)) 69 | valid = np.zeros(len(self.metrics)) 70 | else: 71 | valid = np.ones(len(self.metrics)) 72 | return acc_metrics, valid 73 | 74 | def _train_epoch(self, epoch): 75 | self.model.train() 76 | 77 | total_loss = 0 78 | total_loss_dict = {} 79 | total_metrics = np.zeros(len(self.metrics)) 80 | total_metrics_valid = np.zeros(len(self.metrics)) 81 | 82 | fade_values = {k: torch.tensor([fader.get_value(epoch)]) for k, fader in self.value_faders.items()} 83 | 84 | for batch_idx, (data, target) in enumerate(self.data_loader): 85 | data.update(fade_values) 86 | data, target = to(data, self.device), to(target, self.device) 87 | data["target"] = target 88 | # data["optimizer"] = self.optimizer 89 | 90 | start_time = time.time() 91 | 92 | self.optimizer.zero_grad() 93 | 94 | if not self.module_loss: 95 | data = self.model(data) 96 | loss_dict = self.loss(data, self.alpha, self.roi_train, options=self.options) 97 | else: 98 | data, loss_dict = self.model(data) 99 | 100 | loss_dict = map_fn(loss_dict, torch.sum) 101 | 102 | loss = loss_dict["loss"] 103 | if loss.requires_grad: 104 | loss.backward() 105 | 106 | self.optimizer.step() 107 | 108 | # print("Forward time: ", (time.time() - start_time)) 109 | 110 | loss_dict = map_fn(loss_dict, torch.detach) 111 | 112 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) 113 | 114 | self.writer.add_scalar('loss', loss.item()) 115 | for loss_component, v in loss_dict.items(): 116 | self.writer.add_scalar(f"loss_{loss_component}", v.item()) 117 | 118 | total_loss += loss.item() 119 | total_loss_dict = operator_on_dict(total_loss_dict, loss_dict, operator.add) 120 | metrics, valid = self._eval_metrics(data, True) 121 | total_metrics += metrics 122 | total_metrics_valid += valid 123 | 124 | if self.writer.step % self.log_step == 0: 125 | img_count = min(data["keyframe"].shape[0], 8) 126 | 127 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f} Loss_dict: {}'.format( 128 | epoch, 129 | self._progress(batch_idx), 130 | loss.item(), 131 | loss_dict)) 132 | 133 | if "mask" in data: 134 | if self.invert_output_images: 135 | result = torch.clamp(1 / data["result"][:img_count], 0, 100).cpu() 136 | result /= torch.max(result) * 2 / 3 137 | else: 138 | result = data["result"][:img_count].cpu() 139 | mask = data["mask"][:img_count].cpu() 140 | img = torch.cat([result, mask], dim=2) 141 | else: 142 | if self.invert_output_images: 143 | img = torch.clamp(1 / data["result"][:img_count], 0, 100).cpu() 144 | else: 145 | img = data["result"][:img_count].cpu() 146 | 147 | self.writer.add_image('input', make_grid(to(data["keyframe"][:img_count], "cpu"), nrow=2, normalize=True)) 148 | self.writer.add_image('output', make_grid(img , nrow=2, normalize=True)) 149 | self.writer.add_image('ground_truth', make_grid(to(torch.clamp(infnan_to_zero(1 / data["target"][:img_count]), 0, 100), "cpu"), nrow=2, normalize=True)) 150 | 151 | if batch_idx == self.len_epoch: 152 | break 153 | 154 | log = { 155 | 'loss': total_loss / self.len_epoch, 156 | 'metrics': (total_metrics / total_metrics_valid).tolist() 157 | } 158 | for loss_component, v in total_loss_dict.items(): 159 | log[f"loss_{loss_component}"] = v.item() / self.len_epoch 160 | 161 | if self.do_validation: 162 | val_log = self._valid_epoch(epoch) 163 | log.update(val_log) 164 | 165 | if self.lr_scheduler is 
not None: 166 | self.lr_scheduler.step() 167 | 168 | return log 169 | 170 | def _valid_epoch(self, epoch): 171 | self.model.eval() 172 | 173 | total_val_loss = 0 174 | total_val_loss_dict = {} 175 | total_val_metrics = np.zeros(len(self.metrics)) 176 | total_val_metrics_valid = np.zeros(len(self.metrics)) 177 | 178 | with torch.no_grad(): 179 | for batch_idx, (data, target) in enumerate(self.valid_data_loader): 180 | data, target = to(data, self.device), to(target, self.device) 181 | data["target"] = target 182 | 183 | if not self.module_loss: 184 | data = self.model(data) 185 | loss_dict = self.loss(data, self.alpha, self.roi_train, options=self.options) 186 | else: 187 | data, loss_dict = self.model(data) 188 | 189 | loss_dict = map_fn(loss_dict, torch.sum) 190 | loss = loss_dict["loss"] 191 | 192 | img_count = min(data["keyframe"].shape[0], 8) 193 | 194 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') 195 | if not self.val_avg: 196 | self.writer.add_scalar('loss', loss.item()) 197 | for loss_component, v in loss_dict.items(): 198 | self.writer.add_scalar(f"loss_{loss_component}", v.item()) 199 | total_val_loss += loss.item() 200 | total_val_loss_dict = operator_on_dict(total_val_loss_dict, loss_dict, operator.add) 201 | metrics, valid = self._eval_metrics(data) 202 | total_val_metrics += metrics 203 | total_val_metrics_valid += valid 204 | if batch_idx % self.val_log_step == 0: 205 | if "mask" in data: 206 | if self.invert_output_images: 207 | result = torch.clamp(1 / data["result"][:img_count], 0, 100).cpu() 208 | result /= torch.max(result) * 2 / 3 209 | else: 210 | result = data["result"][:img_count].cpu() 211 | mask = data["mask"][:img_count].cpu() 212 | img = torch.cat([result, mask], dim=2) 213 | else: 214 | if self.invert_output_images: 215 | img = torch.clamp(1 / data["result"][:img_count], 0, 100).cpu() 216 | else: 217 | img = data["result"][:img_count].cpu() 218 | 219 | self.writer.add_image('input', make_grid(to(data["keyframe"][:img_count], "cpu"), nrow=2, normalize=True)) 220 | self.writer.add_image('output', make_grid(img, nrow=2, normalize=True)) 221 | self.writer.add_image('ground_truth', make_grid(to(torch.clamp(infnan_to_zero(1 / data["target"][:img_count]), 0, 100), "cpu"), nrow=2, normalize=True)) 222 | 223 | if self.val_avg: 224 | len_val = len(self.valid_data_loader) 225 | self.writer.add_scalar('loss', total_val_loss / len_val) 226 | for i, metric in enumerate(self.metrics): 227 | self.writer.add_scalar('{}'.format(metric.__name__), total_val_metrics[i] / len_val) 228 | for loss_component, v in total_val_loss_dict.items(): 229 | self.writer.add_scalar(f"loss_{loss_component}", v.item() / len_val) 230 | 231 | result = { 232 | 'val_loss': total_val_loss / len(self.valid_data_loader), 233 | 'val_metrics': (total_val_metrics / total_val_metrics_valid).tolist() 234 | } 235 | 236 | for loss_component, v in total_val_loss_dict.items(): 237 | result[f"val_loss_{loss_component}"] = v.item() / len(self.valid_data_loader) 238 | 239 | return result 240 | 241 | def _progress(self, batch_idx): 242 | base = '[{}/{} ({:.0f}%)]' 243 | if hasattr(self.data_loader, 'n_samples'): 244 | current = batch_idx * self.data_loader.batch_size 245 | total = self.data_loader.n_samples 246 | else: 247 | current = batch_idx 248 | total = self.len_epoch 249 | return base.format(current, total, 100.0 * current / total) 250 | 251 | 252 | def to(data, device): 253 | if isinstance(data, dict): 254 | return {k: to(data[k], device) for k in data.keys()} 255 | elif 
isinstance(data, list): 256 | return [to(v, device) for v in data] 257 | else: 258 | return data.to(device) 259 | 260 | 261 | def infnan_to_zero(t: torch.Tensor): 262 | t[torch.isinf(t)] = 0 263 | t[torch.isnan(t)] = 0 264 | return t -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | from .ply_utils import * -------------------------------------------------------------------------------- /utils/parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | 10 | 11 | class ConfigParser: 12 | def __init__(self, args, options='', timestamp=True): 13 | # parse default and custom cli options 14 | for opt in options: 15 | args.add_argument(*opt.flags, default=None, type=opt.type) 16 | args = args.parse_args() 17 | self.args = args 18 | 19 | if args.device: 20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 21 | if args.resume is None: 22 | msg_no_cfg = "A configuration file needs to be specified. Add '-c config.json', for example." 23 | assert args.config is not None, msg_no_cfg 24 | self.cfg_fname = Path(args.config) 25 | config = read_json(self.cfg_fname) 26 | self.resume = None 27 | else: 28 | self.resume = Path(args.resume) 29 | resume_cfg_fname = self.resume.parent / 'config.json' 30 | config = read_json(resume_cfg_fname) 31 | if args.config is not None: 32 | config.update(read_json(Path(args.config))) 33 | 34 | # load config file and apply custom cli options 35 | self._config = _update_config(config, options, args) 36 | 37 | # set save_dir where trained models and logs will be saved. 38 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 39 | 40 | if "trainer" in self.config: 41 | save_dir = Path(self.config['trainer']['save_dir']) 42 | if "timestamp_replacement" in self.config["trainer"]: 43 | timestamp = self.config["trainer"]["timestamp_replacement"] 44 | 45 | elif "evaluater" in self.config: 46 | save_dir = Path(self.config['evaluater']['save_dir']) 47 | if "timestamp_replacement" in self.config["evaluater"]: 48 | timestamp = self.config["evaluater"]["timestamp_replacement"] 49 | elif "save_dir" in self.config: 50 | save_dir = Path(self.config["save_dir"]) 51 | else: 52 | save_dir = Path("../saved") 53 | 54 | exper_name = self.config['name'] 55 | self._save_dir = save_dir / 'models' / exper_name / timestamp 56 | self._log_dir = save_dir / 'log' / exper_name / timestamp 57 | 58 | self.save_dir.mkdir(parents=True, exist_ok=True) 59 | self.log_dir.mkdir(parents=True, exist_ok=True) 60 | 61 | # save updated config file to the checkpoint dir 62 | write_json(self.config, self.save_dir / 'config.json') 63 | 64 | # configure logging module 65 | setup_logging(self.log_dir) 66 | self.log_levels = { 67 | 0: logging.WARNING, 68 | 1: logging.INFO, 69 | 2: logging.DEBUG 70 | } 71 | 72 | def initialize(self, name, module, *args, **kwargs): 73 | """ 74 | Finds a function handle with the name given as 'type' in config, and returns the 75 | instance initialized with the corresponding keyword args given as 'args'.
76 | """ 77 | module_name = self[name]['type'] 78 | module_args = dict(self[name]['args']) 79 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 80 | module_args.update(kwargs) 81 | return getattr(module, module_name)(*args, **module_args) 82 | 83 | def initialize_list(self, name, module, *args, **kwargs): 84 | l = self[name] 85 | for to_init in l: 86 | module_name = to_init["type"] 87 | module_args = dict(to_init["args"]) 88 | module_args.update(kwargs) 89 | yield getattr(module, module_name)(*args, **module_args) 90 | 91 | def __getitem__(self, name): 92 | return self.config[name] 93 | 94 | def get_logger(self, name, verbosity=2): 95 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys()) 96 | assert verbosity in self.log_levels, msg_verbosity 97 | logger = logging.getLogger(name) 98 | logger.setLevel(self.log_levels[verbosity]) 99 | return logger 100 | 101 | # setting read-only attributes 102 | @property 103 | def config(self): 104 | return self._config 105 | 106 | @property 107 | def save_dir(self): 108 | return self._save_dir 109 | 110 | @property 111 | def log_dir(self): 112 | return self._log_dir 113 | 114 | # helper functions used to update config dict with custom cli options 115 | def _update_config(config, options, args): 116 | for opt in options: 117 | value = getattr(args, _get_opt_name(opt.flags)) 118 | if value is not None: 119 | _set_by_path(config, opt.target, value) 120 | return config 121 | 122 | def _get_opt_name(flags): 123 | for flg in flags: 124 | if flg.startswith('--'): 125 | return flg.replace('--', '') 126 | return flags[0].replace('--', '') 127 | 128 | def _set_by_path(tree, keys, value): 129 | """Set a value in a nested object in tree by sequence of keys.""" 130 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 131 | 132 | def _get_by_path(tree, keys): 133 | """Access a nested object in tree by sequence of keys.""" 134 | return reduce(getitem, keys, tree) 135 | -------------------------------------------------------------------------------- /utils/ply_utils.py: -------------------------------------------------------------------------------- 1 | from array import array 2 | 3 | import torch 4 | 5 | from model.layers import Backprojection 6 | 7 | 8 | class PLYSaver(torch.nn.Module): 9 | def __init__(self, height, width, min_d=3, max_d=400, batch_size=1, roi=None, dropout=0): 10 | super(PLYSaver, self).__init__() 11 | self.min_d = min_d 12 | self.max_d = max_d 13 | self.roi = roi 14 | self.dropout = dropout 15 | self.data = array('f') 16 | 17 | self.projector = Backprojection(batch_size, height, width) 18 | 19 | def save(self, file): 20 | length = len(self.data) // 6 21 | header = "ply\n" \ 22 | "format binary_little_endian 1.0\n" \ 23 | f"element vertex {length}\n" \ 24 | f"property float x\n" \ 25 | f"property float y\n" \ 26 | f"property float z\n" \ 27 | f"property float red\n" \ 28 | f"property float green\n" \ 29 | f"property float blue\n" \ 30 | f"end_header\n" 31 | file.write(header.encode(encoding="ascii")) 32 | self.data.tofile(file) 33 | 34 | def add_depthmap(self, depth: torch.Tensor, image: torch.Tensor, intrinsics: torch.Tensor, 35 | extrinsics: torch.Tensor): 36 | depth = 1 / depth 37 | image = (image + .5) * 255 38 | mask = (self.min_d <= depth) & (depth <= self.max_d) 39 | if self.roi is not None: 40 | mask[:, :, :self.roi[0], :] = False 41 | mask[:, :, self.roi[1]:, :] = False 42 | mask[:, :, :, :self.roi[2]] = False 43 | mask[:, 
:, :, self.roi[3]:] = False 44 | if self.dropout > 0: 45 | mask = mask & (torch.rand_like(depth) > self.dropout) 46 | 47 | coords = self.projector(depth, torch.inverse(intrinsics)) 48 | coords = extrinsics @ coords 49 | coords = coords[:, :3, :] 50 | data_batch = torch.cat([coords, image.view_as(coords)], dim=1).permute(0, 2, 1) 51 | data_batch = data_batch[mask.view(depth.shape[0], 1, -1).permute(0, 2, 1).expand(-1, -1, 6)] 52 | 53 | self.data.extend(data_batch.cpu().tolist()) 54 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | from pathlib import Path 4 | from datetime import datetime 5 | from itertools import repeat 6 | from collections import OrderedDict 7 | 8 | import torch 9 | from PIL import Image 10 | import numpy as np 11 | import torch.nn.functional as F 12 | from kornia.geometry.camera import pixel2cam 13 | from kornia.geometry.depth import DepthWarper 14 | 15 | def map_fn(batch, fn): 16 | if isinstance(batch, dict): 17 | for k in batch.keys(): 18 | batch[k] = map_fn(batch[k], fn) 19 | return batch 20 | elif isinstance(batch, list): 21 | return [map_fn(e, fn) for e in batch] 22 | else: 23 | return fn(batch) 24 | 25 | def to(data, device): 26 | if isinstance(data, dict): 27 | return {k: to(data[k], device) for k in data.keys()} 28 | elif isinstance(data, list): 29 | return [to(v, device) for v in data] 30 | else: 31 | return data.to(device) 32 | 33 | 34 | eps = 1e-6 35 | 36 | def preprocess_roi(depth_prediction, depth_gt: torch.Tensor, roi): 37 | if roi is not None: 38 | if isinstance(depth_prediction, list): 39 | depth_prediction = [dpr[:, :, roi[0]:roi[1], roi[2]:roi[3]] for dpr in depth_prediction] 40 | else: 41 | depth_prediction = depth_prediction[:, :, roi[0]:roi[1], roi[2]:roi[3]] 42 | depth_gt = depth_gt[:, :, roi[0]:roi[1], roi[2]:roi[3]] 43 | return depth_prediction, depth_gt 44 | 45 | 46 | def get_absolute_depth(depth_prediction, depth_gt: torch.Tensor, max_distance=None): 47 | if max_distance is not None: 48 | if isinstance(depth_prediction, list): 49 | depth_prediction = [torch.clamp_min(dpr, 1 / max_distance) for dpr in depth_prediction] 50 | else: 51 | depth_prediction = torch.clamp_min(depth_prediction, 1 / max_distance) 52 | depth_gt = torch.clamp_min(depth_gt, 1 / max_distance) 53 | if isinstance(depth_prediction, list): 54 | return [1 / dpr for dpr in depth_prediction], 1 / depth_gt 55 | else: 56 | return 1 / depth_prediction, 1 / depth_gt 57 | 58 | 59 | def get_positive_depth(depth_prediction: torch.Tensor, depth_gt: torch.Tensor): 60 | if isinstance(depth_prediction, list): 61 | depth_prediction = [torch.nn.functional.relu(dpr) for dpr in depth_prediction] 62 | else: 63 | depth_prediction = torch.nn.functional.relu(depth_prediction) 64 | depth_gt = torch.nn.functional.relu(depth_gt) 65 | return depth_prediction, depth_gt 66 | 67 | 68 | def depthmap_to_points(depth: torch.Tensor, intrinsics: torch.Tensor, flatten=False): 69 | n, c, h, w = depth.shape 70 | grid = DepthWarper._create_meshgrid(h, w).expand(n, -1, -1, -1).to(depth.device) 71 | points = pixel2cam(depth, torch.inverse(intrinsics), grid) 72 | if not flatten: 73 | return points 74 | else: 75 | return points.view(n, h * w, 3) 76 | 77 | 78 | def save_frame_for_tsdf(dir: Path, index, keyframe, depth, pose, crop=None, min_distance=None, max_distance=None): 79 | if crop is not None: 80 | keyframe = keyframe[:, 
crop[0]:crop[1], crop[2]:crop[3]] 81 | depth = depth[crop[0]:crop[1], crop[2]:crop[3]] 82 | keyframe = ((keyframe + .5) * 255).to(torch.uint8).permute(1, 2, 0) 83 | depth = (1 / depth * 100).to(torch.int16) 84 | depth[depth < 0] = 0 85 | if min_distance is not None: 86 | depth[depth < min_distance * 100] = 0 87 | if max_distance is not None: 88 | depth[depth > max_distance * 100] = 0 89 | Image.fromarray(keyframe.numpy()).save(dir / f"frame-{index:06d}.color.jpg") 90 | Image.fromarray(depth.numpy()).save(dir / f"frame-{index:06d}.depth.png") 91 | np.savetxt(dir / f"frame-{index:06d}.pose.txt", torch.inverse(pose).numpy()) 92 | 93 | 94 | def save_intrinsics_for_tsdf(dir: Path, intrinsics, crop=None): 95 | if crop is not None: 96 | intrinsics[0, 2] -= crop[2] 97 | intrinsics[1, 2] -= crop[0] 98 | np.savetxt(dir / f"camera-intrinsics.txt", intrinsics[:3, :3].numpy()) 99 | 100 | 101 | def get_mask(pred: torch.Tensor, gt: torch.Tensor, max_distance=None, pred_all_valid=True): 102 | mask = gt == 0 103 | if max_distance: 104 | mask |= (gt < 1 / max_distance) 105 | if not pred_all_valid: 106 | mask |= pred == 0 107 | return mask 108 | 109 | 110 | def mask_mean(t: torch.Tensor, m: torch.Tensor, dim=None): 111 | t = t.clone() 112 | t[m] = 0 113 | els = 1 114 | if dim is None: 115 | dim = list(range(len(t.shape))) 116 | for d in dim: 117 | els *= t.shape[d] 118 | return torch.sum(t, dim=dim) / (els - torch.sum(m.to(torch.float), dim=dim)) 119 | 120 | 121 | def conditional_flip(x, condition, inplace=True): 122 | if inplace: 123 | x[condition, :, :, :] = x[condition, :, :, :].flip(3) 124 | else: 125 | flipped_x = x.clone() 126 | flipped_x[condition, :, :, :] = x[condition, :, :, :].flip(3) 127 | return flipped_x 128 | 129 | 130 | def create_mask(c: int, height: int, width: int, border_radius: int, device): 131 | mask = torch.ones(c, 1, height - 2 * border_radius, width - 2 * border_radius, device=device) 132 | return torch.nn.functional.pad(mask, [border_radius, border_radius, border_radius, border_radius]) 133 | 134 | 135 | def median_scaling(data_dict): 136 | target = data_dict["target"] 137 | prediction = data_dict["result"] 138 | mask = target > 0 139 | ratios = mask.new_tensor([torch.median(target[i, mask[i]]) / torch.median(prediction[i, mask[i]]) for i in range(target.shape[0])], dtype=torch.float32) 140 | data_dict = dict(data_dict) 141 | data_dict["result"] = prediction * ratios.view(-1, 1, 1, 1) 142 | return data_dict 143 | 144 | 145 | unsqueezer = partial(torch.unsqueeze, dim=0) 146 | 147 | 148 | class DS_Wrapper(torch.utils.data.Dataset): 149 | def __init__(self, dataset, start=0, end=-1, every_nth=1): 150 | super().__init__() 151 | self.dataset = dataset 152 | self.start = start 153 | if end == -1: 154 | self.end = len(self.dataset) 155 | else: 156 | self.end = end 157 | self.every_nth = every_nth 158 | 159 | def __getitem__(self, index: int): 160 | return self.dataset.__getitem__(index * self.every_nth + self.start) 161 | 162 | def __len__(self): 163 | return (self.end - self.start) // self.every_nth + (1 if (self.end - self.start) % self.every_nth != 0 else 0) 164 | 165 | class DS_Merger(torch.utils.data.Dataset): 166 | def __init__(self, datasets): 167 | super().__init__() 168 | self.datasets = datasets 169 | 170 | def __getitem__(self, index: int): 171 | return tuple(ds.__getitem__(index) for ds in self.datasets) 172 | 173 | def __len__(self): 174 | return len(self.datasets[0]) 175 | 176 | 177 | class LossWrapper(torch.nn.Module): 178 | def __init__(self, loss_function,
**kwargs): 179 | super().__init__() 180 | self.kwargs = kwargs 181 | self.loss_function = loss_function 182 | self.num_devices = 1.0 183 | 184 | def forward(self, data): 185 | loss_dict = self.loss_function(data, **self.kwargs) 186 | loss_dict = map_fn(loss_dict, lambda x: (x / self.num_devices)) 187 | if loss_dict["loss"].requires_grad: 188 | loss_dict["loss"].backward() 189 | loss_dict["loss"].detach_() 190 | return data, loss_dict 191 | 192 | 193 | class ValueFader: 194 | def __init__(self, steps, values): 195 | self.steps = steps 196 | self.values = values 197 | self.epoch = 0 198 | 199 | def set_epoch(self, epoch): 200 | self.epoch = epoch 201 | 202 | def get_value(self, epoch=None): 203 | if epoch is None: 204 | epoch = self.epoch 205 | if epoch >= self.steps[-1]: 206 | return self.values[-1] 207 | 208 | step_index = 0 209 | 210 | while step_index < len(self.steps)-1 and epoch >= self.steps[step_index+1]: 211 | step_index += 1 212 | 213 | p = float(epoch - self.steps[step_index]) / float(self.steps[step_index+1] - self.steps[step_index]) 214 | return (1-p) * self.values[step_index] + p * self.values[step_index+1] 215 | 216 | 217 | def pose_distance_thresh(data_dict, spatial_thresh=.6, rotational_thresh=.05): 218 | poses = torch.stack([data_dict["keyframe_pose"]] + data_dict["poses"], dim=1) 219 | forward = poses.new_tensor([0, 0, 1], dtype=torch.float32) 220 | spatial_expanse = torch.norm(torch.max(poses[..., :3, 3], dim=1)[0] - torch.min(poses[..., :3, 3], dim=1)[0], dim=1) 221 | rotational_expanse = torch.norm(torch.max(poses[..., :3, :3] @ forward, dim=1)[0] - torch.min(poses[..., :3, :3] @ forward, dim=1)[0], dim=1) 222 | return (spatial_expanse > spatial_thresh) | (rotational_expanse > rotational_thresh) 223 | 224 | 225 | def dilate_mask(m: torch.Tensor, size: int = 15): 226 | k = m.new_ones((1, 1, size, size), dtype=torch.float32) 227 | dilated_mask = F.conv2d((m >= 0.5).to(dtype=torch.float32), k, padding=(size//2, size//2)) 228 | return dilated_mask > 0 229 | 230 | 231 | def operator_on_dict(dict_0: dict, dict_1: dict, operator, default=0): 232 | keys = set(dict_0.keys()).union(set(dict_1.keys())) 233 | results = {} 234 | for k in keys: 235 | v_0 = dict_0[k] if k in dict_0 else default 236 | v_1 = dict_1[k] if k in dict_1 else default 237 | results[k] = operator(v_0, v_1) 238 | return results 239 | 240 | 241 | numbers = [f"{i:d}" for i in range(1, 10, 1)] 242 | 243 | 244 | def filter_state_dict(state_dict, data_parallel=False): 245 | if data_parallel: 246 | state_dict = {k[7:]: state_dict[k] for k in state_dict} 247 | state_dict = {(k[2:] if k.startswith("0") else k): state_dict[k] for k in state_dict if not k[0] in numbers} 248 | return state_dict 249 | 250 | 251 | def seed_rng(seed): 252 | torch.manual_seed(seed) 253 | import random 254 | random.seed(seed) 255 | np.random.seed(seed) 256 | 257 | 258 | def ensure_dir(dirname): 259 | dirname = Path(dirname) 260 | if not dirname.is_dir(): 261 | dirname.mkdir(parents=True, exist_ok=False) 262 | 263 | def read_json(fname): 264 | with fname.open('rt') as handle: 265 | return json.load(handle, object_hook=OrderedDict) 266 | 267 | def write_json(content, fname): 268 | with fname.open('wt') as handle: 269 | json.dump(content, handle, indent=4, sort_keys=False) 270 | 271 | def inf_loop(data_loader): 272 | ''' wrapper function for endless data loader.
''' 273 | for loader in repeat(data_loader): 274 | yield from loader 275 | 276 | class Timer: 277 | def __init__(self): 278 | self.cache = datetime.now() 279 | 280 | def check(self): 281 | now = datetime.now() 282 | duration = now - self.cache 283 | self.cache = now 284 | return duration.total_seconds() 285 | 286 | def reset(self): 287 | self.cache = datetime.now() 288 | --------------------------------------------------------------------------------
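The helpers in `utils/util.py` are consumed directly by the trainer; the sketch below is a minimal, hypothetical usage example (not part of the repository) showing how `ValueFader` interpolates a per-epoch weight between its steps and how `operator_on_dict` accumulates loss dictionaries, the same pattern the validation loop in `trainer/trainer.py` uses. It assumes the package layout above (so both names are re-exported via `utils/__init__.py`); the step and value numbers are illustrative only, not taken from the configs.

```python
import operator

from utils import ValueFader, operator_on_dict

# Fade a loss weight from 1.0 down to 0.1 between epoch 0 and epoch 10,
# then hold it at 0.1 (illustrative numbers only).
fader = ValueFader(steps=[0, 10], values=[1.0, 0.1])
for epoch in range(12):
    fader.set_epoch(epoch)
    weight = fader.get_value()  # 1.0 at epoch 0, 0.55 at epoch 5, 0.1 from epoch 10 onwards

# Accumulate per-batch loss dictionaries: missing keys fall back to the default 0,
# matching keys are combined with the given operator.
total = {}
for batch_losses in [{"loss": 1.0, "depth_loss": 0.4}, {"loss": 2.0}]:
    total = operator_on_dict(total, batch_losses, operator.add)
# total == {"loss": 3.0, "depth_loss": 0.4}
```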