├── LICENSE ├── README.md ├── TrapAttention.png ├── configs ├── _base_ │ ├── datasets │ │ ├── kitti.py │ │ ├── kitti_benchmark.py │ │ ├── nyu.py │ │ └── sun_rgbd.py │ ├── default_runtime.py │ ├── models │ │ ├── trap.py │ │ ├── trap_swin_l_kitti.py │ │ ├── trap_swin_l_nyu.py │ │ ├── trap_xcit_m_24_kitti.py │ │ ├── trap_xcit_m_24_nyu.py │ │ ├── trap_xcit_s_12_kitti.py │ │ └── trap_xcit_s_12_nyu.py │ └── schedules │ │ └── schedule_24x.py └── trap │ ├── trap_swin_l_kitti.py │ ├── trap_swin_l_kitti_benchmark.py │ ├── trap_swin_l_win7_nyu.py │ ├── trap_xcit_m_24_kitti.py │ ├── trap_xcit_m_24_nyu.py │ ├── trap_xcit_s_12_kitti.py │ └── trap_xcit_s_12_nyu.py ├── depth ├── __init__.py ├── apis │ ├── __init__.py │ ├── inference.py │ ├── test.py │ └── train.py ├── core │ ├── __init__.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── eval_hooks.py │ │ └── metrics.py │ └── utils │ │ ├── __init__.py │ │ └── misc.py ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── cityscapes.py │ ├── custom.py │ ├── dataset_wrappers.py │ ├── kitti.py │ ├── nyu.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── compose.py │ │ ├── formating.py │ │ ├── loading.py │ │ ├── test_time_aug.py │ │ └── transforms.py │ └── sunrgbd.py ├── mmcv_custom │ ├── __init__.py │ ├── apex_runner │ │ ├── __init__.py │ │ ├── apex_iter_based_runner.py │ │ ├── checkpoint.py │ │ └── optimizer.py │ ├── checkpoint.py │ ├── customized_text.py │ ├── layer_decay_optimizer_constructor.py │ ├── resize_transform.py │ └── train_api.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── swin_transformer.py │ │ └── xcit.py │ ├── builder.py │ ├── decode_heads │ │ ├── __init__.py │ │ ├── decode_head.py │ │ └── trap_head.py │ ├── depther │ │ ├── __init__.py │ │ ├── base.py │ │ └── encoder_decoder.py │ ├── losses │ │ ├── __init__.py │ │ └── sigloss.py │ ├── necks │ │ ├── __init__.py │ │ └── block_selection.py │ ├── trap.py │ └── utils │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── ckpt_convert.py │ │ ├── embed.py │ │ ├── hooks │ │ ├── __init__.py │ │ └── tensorboard_hook.py │ │ ├── inverted_residual.py │ │ ├── logger.py │ │ ├── make_divisible.py │ │ ├── res_layer.py │ │ ├── se_layer.py │ │ ├── self_attention_block.py │ │ ├── shape_convert.py │ │ └── up_conv_block.py ├── ops │ ├── __init__.py │ ├── encoding.py │ └── wrappers.py ├── utils │ ├── __init__.py │ ├── collect_env.py │ ├── color_depth.py │ ├── logger.py │ └── position_encoding.py └── version.py ├── splits ├── SUNRGBD_val_splits.txt ├── kitti_eigen_test.txt ├── kitti_eigen_train.txt ├── nyu_test.txt └── nyu_train.txt └── tools ├── benchmark.py ├── dist_test.sh ├── dist_train.sh ├── ensemble.py ├── misc └── visualize_point-cloud.py ├── print_config.py ├── slurm_test.sh ├── slurm_train.sh ├── test.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ICSResearch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Trap Attention 2 | The implementation of *Trap Attention: Monocular Depth Estimation with Manual Traps* (CVPR 2023). 3 | 4 | ![](TrapAttention.png) 5 | 6 | ## Environment 7 | - Python 3.8 8 | - PyTorch 1.7.1 (>=1.2.0) 9 | 10 | ## Checkpoint 11 | [Google Drive](https://drive.google.com/drive/folders/1kIXg9UP0cVWUq_7Pq20JT9_RyR-PjvkS?usp=sharing) 12 | 13 | ## Citation 14 | If you use this code in your research, please cite: 15 | ``` 16 | @inproceedings{ning2023trap, 17 | title={Trap Attention: Monocular Depth Estimation with Manual Traps}, 18 | author={Chao Ning and Hongping Gan}, 19 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 20 | pages={}, 21 | year={2023} 22 | } 23 | ``` 24 | 25 | ## Acknowledgement 26 | This repo benefits from the awesome works of 27 | [Timm](https://github.com/rwightman/pytorch-image-models), 28 | [Monocular-Depth-Estimation-Toolbox](https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox), 29 | and 30 | [MMSeg](https://github.com/open-mmlab/mmsegmentation).
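## Inference example

A minimal single-image inference sketch using the helpers in `depth/apis/inference.py` (`init_depther` and `inference_depther` are real functions in this repo). The config path below exists in this repo; the checkpoint filename and image path are placeholders for whatever you download from the Google Drive folder above.

```python
from depth.apis import init_depther, inference_depther

config = 'configs/trap/trap_xcit_s_12_nyu.py'
checkpoint = 'trap_xcit_s_12_nyu.pth'  # placeholder: name of your downloaded checkpoint

# build the model, load the weights, and run on one image
model = init_depther(config, checkpoint, device='cuda:0')
result = inference_depther(model, 'demo.jpg')  # list containing the predicted depth map
```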
31 | 32 | 33 | -------------------------------------------------------------------------------- /TrapAttention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TrapAttention/7d16a09579055098f01483b244ba37f762fe15e5/TrapAttention.png -------------------------------------------------------------------------------- /configs/_base_/datasets/kitti.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'KITTIDataset' 2 | data_root = r'' 3 | img_norm_cfg = dict( 4 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='DepthLoadAnnotations'), 8 | dict(type='LoadKITTICamIntrinsic'), 9 | dict(type='KBCrop', depth=True), 10 | dict(type='RandomRotate', prob=0.5, degree=2.5), 11 | dict(type='RandomFlip', prob=0.5), 12 | dict(type='RandomCrop', crop_size=(352, 1120)), 13 | dict(type='ColorAug', prob=0.5, gamma_range=[0.9, 1.1], brightness_range=[0.9, 1.1], color_range=[0.9, 1.1]), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='DefaultFormatBundle'), 16 | dict(type='Collect', 17 | keys=['img', 'depth_gt'], 18 | meta_keys=('filename', 'ori_filename', 'ori_shape', 19 | 'img_shape', 'pad_shape', 'scale_factor', 20 | 'flip', 'flip_direction', 'img_norm_cfg', 21 | 'cam_intrinsic')), 22 | ] 23 | test_pipeline = [ 24 | dict(type='LoadImageFromFile'), 25 | dict(type='LoadKITTICamIntrinsic'), 26 | dict(type='KBCrop', depth=False), 27 | dict( 28 | type='MultiScaleFlipAug', 29 | img_scale=(1216, 352), 30 | flip=True, 31 | flip_direction='horizontal', 32 | transforms=[ 33 | dict(type='RandomFlip', direction='horizontal'), 34 | dict(type='Normalize', **img_norm_cfg), 35 | dict(type='ImageToTensor', keys=['img']), 36 | dict(type='Collect', 37 | keys=['img'], 38 | meta_keys=('filename', 'ori_filename', 'ori_shape', 39 | 'img_shape', 'pad_shape', 'scale_factor', 40 | 'flip', 'flip_direction', 'img_norm_cfg', 41 | 'cam_intrinsic')), 42 | ]) 43 | ] 44 | data = dict( 45 | samples_per_gpu=8, 46 | workers_per_gpu=8, 47 | train=dict( 48 | type=dataset_type, 49 | data_root=data_root, 50 | img_dir='input/', 51 | ann_dir='gt_depth/', 52 | depth_scale=256, 53 | split='kitti_eigen_train.txt', 54 | pipeline=train_pipeline, 55 | garg_crop=True, 56 | eigen_crop=False, 57 | min_depth=1e-3, 58 | max_depth=80), 59 | val=dict( 60 | type=dataset_type, 61 | data_root=data_root, 62 | img_dir='input/', 63 | ann_dir='gt_depth/', 64 | depth_scale=256, 65 | split='kitti_eigen_test.txt', 66 | pipeline=test_pipeline, 67 | garg_crop=True, 68 | eigen_crop=False, 69 | min_depth=1e-3, 70 | max_depth=80), 71 | test=dict( 72 | type=dataset_type, 73 | data_root=data_root, 74 | img_dir='input/', 75 | ann_dir='gt_depth/', 76 | depth_scale=256, 77 | split='kitti_eigen_test.txt', 78 | pipeline=test_pipeline, 79 | garg_crop=True, 80 | eigen_crop=False, 81 | min_depth=1e-3, 82 | max_depth=80)) 83 | 84 | -------------------------------------------------------------------------------- /configs/_base_/datasets/kitti_benchmark.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'KITTIDataset' 3 | data_root = r'' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size= (352, 704) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='DepthLoadAnnotations'), 10 | 
dict(type='LoadKITTICamIntrinsic'), 11 | dict(type='KBCrop', depth=True), 12 | dict(type='RandomRotate', prob=0.5, degree=2.5), 13 | dict(type='RandomFlip', prob=0.5), 14 | dict(type='RandomCrop', crop_size=(352, 1120)), 15 | dict(type='ColorAug', prob=1, gamma_range=[0.9, 1.1], brightness_range=[0.9, 1.1], color_range=[0.9, 1.1]), 16 | dict(type='Normalize', **img_norm_cfg), 17 | dict(type='DefaultFormatBundle'), 18 | dict(type='Collect', 19 | keys=['img', 'depth_gt'], 20 | meta_keys=('filename', 'ori_filename', 'ori_shape', 21 | 'img_shape', 'pad_shape', 'scale_factor', 22 | 'flip', 'flip_direction', 'img_norm_cfg', 23 | 'cam_intrinsic') 24 | ), 25 | ] 26 | test_pipeline = [ 27 | dict(type='LoadImageFromFile'), 28 | dict(type='LoadKITTICamIntrinsic'), 29 | dict(type='KBCrop', depth=False), 30 | dict( 31 | type='MultiScaleFlipAug', 32 | img_scale=(352, 1216), 33 | flip=True, 34 | flip_direction='horizontal', 35 | transforms=[ 36 | dict(type='RandomFlip', direction='horizontal'), 37 | dict(type='Normalize', **img_norm_cfg), 38 | dict(type='ImageToTensor', keys=['img']), 39 | dict(type='Collect', 40 | keys=['img'], 41 | meta_keys=('filename', 'ori_filename', 'ori_shape', 42 | 'img_shape', 'pad_shape', 'scale_factor', 43 | 'flip', 'flip_direction', 'img_norm_cfg', 44 | 'cam_intrinsic')), 45 | ]) 46 | ] 47 | data = dict( 48 | samples_per_gpu=16, 49 | workers_per_gpu=8, 50 | train=dict( 51 | type=dataset_type, 52 | data_root=data_root, 53 | img_dir='input/', 54 | ann_dir='gt_depth/', 55 | depth_scale=256, 56 | split='kitti_benchmark_train.txt', 57 | pipeline=train_pipeline, 58 | garg_crop=True, 59 | eigen_crop=False, 60 | min_depth=1e-3, 61 | max_depth=88), 62 | val=dict( 63 | type=dataset_type, 64 | data_root=data_root, 65 | img_dir='official/', 66 | ann_dir='official/', 67 | depth_scale=256, 68 | split='benchmark_val.txt', 69 | pipeline=test_pipeline, 70 | garg_crop=True, 71 | eigen_crop=False, 72 | min_depth=1e-3, 73 | max_depth=88), 74 | test=dict( 75 | type=dataset_type, 76 | data_root=data_root, 77 | img_dir='official/', 78 | # ann_dir='gt_depth', 79 | depth_scale=256, 80 | split='benchmark_test.txt', 81 | pipeline=test_pipeline, 82 | garg_crop=True, 83 | eigen_crop=False, 84 | min_depth=1e-3, 85 | max_depth=88) 86 | ) 87 | 88 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nyu.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'NYUDataset' 3 | data_root = 'D:/dataset/Monocular/data/nyu' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='DepthLoadAnnotations'), 10 | dict(type='NYUMask'), 11 | dict(type='RandomRotate', prob=0.5, degree=2.5), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='ColorAug', prob=0.5, gamma_range=[0.9, 1.1], brightness_range=[0.75, 1.25], color_range=[0.9, 1.1]), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='DefaultFormatBundle'), 16 | dict(type='Collect', 17 | keys=['img', 'depth_gt'], 18 | meta_keys=('filename', 'ori_filename', 'ori_shape', 19 | 'img_shape', 'pad_shape', 'scale_factor', 20 | 'flip', 'flip_direction', 'img_norm_cfg', 21 | 'cam_intrinsic')), 22 | ] 23 | test_pipeline = [ 24 | dict(type='LoadImageFromFile'), 25 | dict( 26 | type='MultiScaleFlipAug', 27 | img_scale=(480, 640), 28 | flip=True, 29 | flip_direction='horizontal', 30 | transforms=[ 31 | 
dict(type='RandomFlip', direction='horizontal'), 32 | dict(type='Normalize', **img_norm_cfg), 33 | dict(type='ImageToTensor', keys=['img']), 34 | dict(type='Collect', 35 | keys=['img'], 36 | meta_keys=('filename', 'ori_filename', 'ori_shape', 37 | 'img_shape', 'pad_shape', 'scale_factor', 38 | 'flip', 'flip_direction', 'img_norm_cfg', 39 | 'cam_intrinsic')), 40 | ]) 41 | ] 42 | 43 | # for visualization of pc 44 | eval_pipeline = [ 45 | dict(type='LoadImageFromFile'), 46 | dict(type='RandomFlip', prob=0.0), # set to zero 47 | dict(type='Normalize', **img_norm_cfg), 48 | dict(type='ImageToTensor', keys=['img']), 49 | dict(type='Collect', 50 | keys=['img'], 51 | meta_keys=('filename', 'ori_filename', 'ori_shape', 52 | 'img_shape', 'pad_shape', 'scale_factor', 53 | 'flip', 'flip_direction', 'img_norm_cfg', 54 | 'cam_intrinsic')), 55 | ] 56 | 57 | data = dict( 58 | samples_per_gpu=4, 59 | workers_per_gpu=8, 60 | train=dict( 61 | type=dataset_type, 62 | data_root=data_root, 63 | depth_scale=1000, 64 | split='nyu_train.txt', 65 | pipeline=train_pipeline, 66 | garg_crop=False, 67 | eigen_crop=True, 68 | min_depth=1e-3, 69 | max_depth=10), 70 | val=dict( 71 | type=dataset_type, 72 | data_root=data_root, 73 | depth_scale=1000, 74 | split='nyu_test.txt', 75 | pipeline=test_pipeline, 76 | garg_crop=False, 77 | eigen_crop=True, 78 | min_depth=1e-3, 79 | max_depth=10), 80 | test=dict( 81 | type=dataset_type, 82 | data_root=data_root, 83 | depth_scale=1000, 84 | split='nyu_test.txt', 85 | pipeline=test_pipeline, 86 | garg_crop=False, 87 | eigen_crop=True, 88 | min_depth=1e-3, 89 | max_depth=10)) 90 | 91 | -------------------------------------------------------------------------------- /configs/_base_/datasets/sun_rgbd.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'SUNRGBDDataset' 3 | data_root = '' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | 7 | 8 | train_pipeline = [ 9 | ] 10 | 11 | test_pipeline = [ 12 | dict(type='LoadImageFromFile'), 13 | dict(type='SUNCrop', depth=False, height=480, width=640), 14 | dict( 15 | type='MultiScaleFlipAug', 16 | img_scale=(480, 640), 17 | flip=False, 18 | flip_direction='horizontal', 19 | transforms=[ 20 | # dict(type='RandomFlip', direction='horizontal'), 21 | dict(type='Normalize', **img_norm_cfg), 22 | # dict(type='Resize', img_scale=(480, 640), keep_ratio=True), 23 | dict(type='ImageToTensor', keys=['img']), 24 | # dict(type='DefaultFormatBundle'), 25 | dict(type='Collect', 26 | keys=['img'], 27 | # keys=['img', 'depth_gt'], 28 | # meta_keys=('cam_intrinsic', 'ori_shape', 'img_shape') 29 | meta_keys=('filename', 'ori_filename', 'ori_shape', 30 | 'img_shape', 'pad_shape', 'scale_factor', 31 | 'flip', 'flip_direction', 'img_norm_cfg', 32 | 'cam_intrinsic') 33 | ), 34 | ]) 35 | ] 36 | eval_pipeline = [ 37 | dict(type='LoadImageFromFile'), 38 | dict(type='SUNCrop', depth=False, height=480, width=640), 39 | dict(type='RandomFlip', prob=0.0), # set to zero 40 | dict(type='Normalize', **img_norm_cfg), 41 | dict(type='ImageToTensor', keys=['img']), 42 | dict(type='Collect', 43 | keys=['img'], 44 | meta_keys=('filename', 'ori_filename', 'ori_shape', 45 | 'img_shape', 'pad_shape', 'scale_factor', 46 | 'flip', 'flip_direction', 'img_norm_cfg', 47 | 'cam_intrinsic')), 48 | ] 49 | data = dict( 50 | samples_per_gpu=16, 51 | workers_per_gpu=8, 52 | train=dict( 53 | type=dataset_type, 54 | data_root=data_root, 55 | depth_scale=1000, 56 | 
split='SUNRGBD_val_splits.txt', 57 | pipeline=test_pipeline, 58 | garg_crop=False, 59 | eigen_crop=True, 60 | min_depth=1e-3, 61 | max_depth=10), 62 | test=dict( 63 | type=dataset_type, 64 | data_root=data_root, 65 | depth_scale=8000, 66 | split='SUNRGBD_val_splits.txt', 67 | pipeline=test_pipeline, 68 | min_depth=1e-3, 69 | max_depth=10)) 70 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=50, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=True), 6 | # dict(type='TensorboardImageLoggerHook', by_epoch=True), 7 | ]) 8 | # yapf:enable 9 | # dist_params = dict(backend='nccl') 10 | dist_params = dict(backend='gloo') 11 | log_level = 'INFO' 12 | load_from = None 13 | resume_from = None 14 | workflow = [('train', 1)] 15 | cudnn_benchmark = True 16 | -------------------------------------------------------------------------------- /configs/_base_/models/trap.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='DepthEncoderDecoder', 4 | backbone=dict( 5 | type='XCiT', 6 | pretrained=r'', 7 | patch_size=8, 8 | embed_dim=384, 9 | depth=12, 10 | num_heads=8, 11 | mlp_ratio=4, 12 | qkv_bias=True, 13 | eta=1.0, 14 | drop_path_rate=0.05, 15 | out_indices=range(12) 16 | ), 17 | neck=dict( 18 | type='BlockSelectionNeck', 19 | in_channels=[384] * 5, 20 | out_channels=[64, 96, 192, 384, 768], 21 | start=[2, 4, 6, 8, 10], 22 | end=[4, 6, 8, 10, 12], 23 | scales=[4, 2, 1, .5, .25]), 24 | decode_head=dict( 25 | type='TrappedHead', 26 | in_channels=[64, 96, 192, 384, 768], 27 | post_process_channels=[64, 96, 192, 384, 768], 28 | # up_sample_channels=[128, 256, 512, 1024, 2048], 29 | channels=32, # last one 30 | final_norm=False, 31 | scale_up=True, 32 | align_corners=False, # for upsample 33 | loss_decode=dict( 34 | type='SigLoss', valid_mask=True, loss_weight=10)), 35 | # model training and testing settings 36 | train_cfg=dict(), 37 | test_cfg=dict(mode='whole')) 38 | -------------------------------------------------------------------------------- /configs/_base_/models/trap_swin_l_kitti.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../datasets/kitti.py', 3 | '../default_runtime.py' 4 | ] 5 | 6 | norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | type='DepthEncoderDecoder', 10 | backbone=dict( 11 | type='Swin', 12 | pretrained= 13 | 'swin_large_patch4_window7_224_22k.pth', 14 | embed_dims=192, 15 | patch_size=4, 16 | window_size=7, 17 | mlp_ratio=4, 18 | depths=[2, 2, 18, 2], 19 | num_heads=[6, 12, 24, 48], 20 | strides=(4, 2, 2, 2), 21 | out_indices=range(0, 24), 22 | qkv_bias=True, 23 | qk_scale=None, 24 | patch_norm=True, 25 | drop_rate=0.0, 26 | attn_drop_rate=0.0, 27 | drop_path_rate=0.3, 28 | use_abs_pos_embed=False, 29 | act_cfg=dict(type='GELU'), 30 | norm_cfg=dict(type='LN', requires_grad=True), 31 | pretrain_style='official'), 32 | neck=dict( 33 | type='BlockSelectionNeck', 34 | in_channels=[192, 384, 768, 768, 1536], 35 | out_channels=[192, 288, 576, 1152, 2304], 36 | start=[0, 2, 4, 10, 22], 37 | end=[2, 4, 16, 22, 24], 38 | scales=[2, 2, 2, 1.0, 1.0]), 39 | decode_head=dict( 40 | type='TrappedHead', 41 | in_channels=[192, 288, 576, 1152, 2304], 42 | post_process_channels=[192, 288, 576, 1152, 2304], 43 | channels=96, 44 | 
final_norm=False, 45 | scale_up=True, 46 | # drop_path_rate=0.3, 47 | align_corners=False, 48 | min_depth=0.001, 49 | max_depth=80, 50 | loss_decode=dict(type='SigLoss', valid_mask=True, loss_weight=10)), 51 | train_cfg=dict(), 52 | test_cfg=dict(mode='whole')) 53 | 54 | find_unused_parameters=True 55 | SyncBN = True 56 | 57 | # batch size 58 | data = dict( 59 | samples_per_gpu=2, 60 | workers_per_gpu=8, 61 | ) 62 | 63 | # schedules 64 | # optimizer 65 | max_lr = 0.0001 66 | optimizer = dict( 67 | type='AdamW', 68 | lr=max_lr, 69 | betas=(0.9, 0.999), 70 | weight_decay=0.01, 71 | paramwise_cfg=dict( 72 | custom_keys={ 73 | 'absolute_pos_embed': dict(decay_mult=0.), 74 | 'relative_position_bias_table': dict(decay_mult=0.), 75 | 'norm': dict(decay_mult=0.), 76 | })) 77 | 78 | lr_config = dict(policy='poly', 79 | warmup='linear', 80 | warmup_iters=3200, 81 | warmup_ratio=1e-6, 82 | power=1.0, min_lr=0.0, by_epoch=False) 83 | 84 | 85 | optimizer_config = dict() 86 | runner = dict(type='IterBasedRunner', max_iters=320000) 87 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 88 | evaluation = dict(by_epoch=False, 89 | start=0, 90 | interval=1600, 91 | pre_eval=True, 92 | rule='less', 93 | save_best='abs_rel', 94 | greater_keys=("a1", "a2", "a3"), 95 | less_keys=("abs_rel", "rmse")) 96 | 97 | # iter runtime 98 | log_config = dict( 99 | _delete_=True, 100 | interval=50, 101 | hooks=[ 102 | dict(type='TextLoggerHook', by_epoch=False), 103 | ]) -------------------------------------------------------------------------------- /configs/_base_/models/trap_swin_l_nyu.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../datasets/nyu.py', 3 | '../default_runtime.py' 4 | ] 5 | 6 | backbone_norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | type='DepthEncoderDecoder', 10 | backbone=dict( 11 | type='Swin', 12 | pretrained=r'', 13 | embed_dims=192, 14 | patch_size=4, 15 | window_size=7, 16 | mlp_ratio=4, 17 | depths=[2, 2, 18, 2], 18 | num_heads=[6, 12, 24, 48], 19 | strides=(4, 2, 2, 2), 20 | out_indices=range(24), 21 | qkv_bias=True, 22 | qk_scale=None, 23 | patch_norm=True, 24 | drop_rate=0., 25 | attn_drop_rate=0., 26 | drop_path_rate=0.5, 27 | use_abs_pos_embed=False, 28 | act_cfg=dict(type='GELU'), 29 | norm_cfg=backbone_norm_cfg, 30 | pretrain_style='official', 31 | ), 32 | neck=dict( 33 | type='BlockSelectionNeck', 34 | in_channels=[192, 384, 768, 768, 1536], 35 | out_channels=[192, 288, 576, 1152, 2304], 36 | start=[0, 2, 4, 10, 22], 37 | end=[2, 4, 16, 22, 24], 38 | scales=[2, 2, 2, 1., 1.]), 39 | decode_head=dict( 40 | type='TrappedHead', 41 | in_channels=[192, 288, 576, 1152, 2304], 42 | post_process_channels=[192, 288, 576, 1152, 2304], 43 | channels=96, # last one 44 | final_norm=False, 45 | scale_up=True, 46 | align_corners=False, # for upsample 47 | min_depth=1e-3, 48 | max_depth=10, 49 | loss_decode=dict( 50 | type='SigLoss', valid_mask=True, loss_weight=10)), 51 | # model training and testing settings 52 | train_cfg=dict(), 53 | test_cfg=dict(mode='whole') 54 | ) 55 | 56 | find_unused_parameters = True 57 | SyncBN = True 58 | 59 | # batch size 60 | data = dict( 61 | samples_per_gpu=2, 62 | workers_per_gpu=8, 63 | ) 64 | 65 | # schedules 66 | # optimizer 67 | max_lr = 0.0001 68 | 69 | 70 | optimizer = dict( 71 | type='AdamW', 72 | lr=max_lr, 73 | betas=(0.9, 0.999), 74 | weight_decay=0.01, 75 | paramwise_cfg=dict( 76 | custom_keys={ 77 | 'absolute_pos_embed': dict(decay_mult=0.), 78 | 
'relative_position_bias_table': dict(decay_mult=0.), 79 | 'norm': dict(decay_mult=0.), 80 | })) 81 | # learning policy 82 | lr_config = dict(policy='poly', 83 | warmup='linear', 84 | warmup_iters=3200, 85 | warmup_ratio=1e-6, 86 | power=1.0, min_lr=0.0, by_epoch=False) 87 | 88 | optimizer_config = dict() 89 | runner = dict(type='IterBasedRunner', max_iters=320000) 90 | 91 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 92 | evaluation = dict(by_epoch=False, 93 | start=0, 94 | interval=1600, 95 | pre_eval=True, 96 | rule='less', 97 | save_best='abs_rel', 98 | greater_keys=("a1", "a2", "a3"), 99 | less_keys=("abs_rel", "rmse")) 100 | 101 | # iter runtime 102 | log_config = dict( 103 | _delete_=True, 104 | interval=50, 105 | hooks=[ 106 | dict(type='TextLoggerHook', by_epoch=False), 107 | ]) -------------------------------------------------------------------------------- /configs/_base_/models/trap_xcit_m_24_kitti.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../datasets/kitti.py', # paths resolve relative to configs/_base_/models/ 3 | '../default_runtime.py' 4 | ] 5 | 6 | norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | type='DepthEncoderDecoder', 10 | backbone=dict( 11 | type='XCiT', 12 | pretrained=r'', 13 | patch_size=8, 14 | embed_dim=512, 15 | depth=24, 16 | num_heads=8, 17 | mlp_ratio=4, 18 | qkv_bias=True, 19 | eta=1e-5, 20 | out_indices=range(24), 21 | ), 22 | neck=dict( 23 | type='BlockSelectionNeck', 24 | in_channels=[512] * 5, 25 | out_channels=[128, 192, 384, 768, 1536], 26 | start=[0, 4, 8, 12, 16], 27 | end=[8, 12, 16, 20, 24], 28 | scales=[4, 2, 1, .5, .25]), 29 | decode_head=dict( 30 | type='TrappedHead', 31 | in_channels=[128, 192, 384, 768, 1536], 32 | post_process_channels=[128, 192, 384, 768, 1536], 33 | channels=64, # last one 34 | final_norm=False, 35 | scale_up=True, 36 | align_corners=False, # for upsample 37 | min_depth=1e-3, 38 | max_depth=80, 39 | loss_decode=dict( 40 | type='SigLoss', valid_mask=True, loss_weight=10)), 41 | # model training and testing settings 42 | train_cfg=dict(), 43 | test_cfg=dict(mode='whole') 44 | ) 45 | 46 | find_unused_parameters = True 47 | SyncBN = True 48 | 49 | # batch size 50 | data = dict( 51 | samples_per_gpu=2, 52 | workers_per_gpu=8, 53 | ) 54 | 55 | # schedules 56 | # optimizer 57 | max_lr = 0.0001 58 | optimizer = dict( 59 | type='AdamW', 60 | lr=max_lr, 61 | betas=(0.9, 0.999), 62 | weight_decay=0.01, 63 | paramwise_cfg=dict( 64 | custom_keys={ 65 | 'absolute_pos_embed': dict(decay_mult=0.), 66 | 'relative_position_bias_table': dict(decay_mult=0.), 67 | 'norm': dict(decay_mult=0.), 68 | })) 69 | # learning policy 70 | lr_config = dict(policy='poly', 71 | warmup='linear', 72 | warmup_iters=3200, 73 | warmup_ratio=1e-6, 74 | power=1.0, min_lr=0.0, by_epoch=False) 75 | 76 | optimizer_config = dict() 77 | # runtime settings 78 | runner = dict(type='IterBasedRunner', max_iters=320000) 79 | 80 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 81 | evaluation = dict(by_epoch=False, 82 | start=0, 83 | interval=1600, 84 | pre_eval=True, 85 | rule='less', 86 | save_best='abs_rel', 87 | greater_keys=("a1", "a2", "a3"), 88 | less_keys=("abs_rel", "rmse")) 89 | 90 | # iter runtime 91 | log_config = dict( 92 | _delete_=True, 93 | interval=50, 94 | hooks=[ 95 | dict(type='TextLoggerHook', by_epoch=False), 96 | ]) -------------------------------------------------------------------------------- /configs/_base_/models/trap_xcit_m_24_nyu.py:
-------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../datasets/nyu.py', # paths resolve relative to configs/_base_/models/ 3 | '../default_runtime.py' 4 | ] 5 | 6 | norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | type='DepthEncoderDecoder', 10 | backbone=dict( 11 | type='XCiT', 12 | pretrained=r'', 13 | patch_size=8, 14 | embed_dim=512, 15 | depth=24, 16 | num_heads=8, 17 | mlp_ratio=4, 18 | qkv_bias=True, 19 | eta=1e-5, 20 | drop_path_rate=0., 21 | out_indices=range(24), 22 | ), 23 | neck=dict( 24 | type='BlockSelectionNeck', 25 | in_channels=[512] * 5, 26 | out_channels=[128, 192, 384, 768, 1536], 27 | start=[0, 4, 8, 12, 16], 28 | end=[8, 12, 16, 20, 24], 29 | scales=[4, 2, 1, .5, .25]), 30 | decode_head=dict( 31 | type='TrappedHead', 32 | in_channels=[128, 192, 384, 768, 1536], 33 | post_process_channels=[128, 192, 384, 768, 1536], 34 | channels=64, # last one 35 | final_norm=False, 36 | scale_up=True, 37 | align_corners=False, # for upsample 38 | min_depth=1e-3, 39 | max_depth=10, 40 | loss_decode=dict( 41 | type='SigLoss', valid_mask=True, loss_weight=10)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole') 45 | ) 46 | 47 | find_unused_parameters = True 48 | SyncBN = True 49 | 50 | # batch size 51 | data = dict( 52 | samples_per_gpu=2, 53 | workers_per_gpu=8, 54 | ) 55 | 56 | # schedules 57 | # optimizer 58 | max_lr = 0.0001 59 | optimizer = dict( 60 | type='AdamW', 61 | lr=max_lr, 62 | betas=(0.9, 0.999), 63 | weight_decay=0.01, 64 | paramwise_cfg=dict( 65 | custom_keys={ 66 | 'absolute_pos_embed': dict(decay_mult=0.), 67 | 'relative_position_bias_table': dict(decay_mult=0.), 68 | 'norm': dict(decay_mult=0.), 69 | })) 70 | 71 | # learning policy 72 | lr_config = dict(policy='poly', 73 | warmup='linear', 74 | warmup_iters=3200, 75 | warmup_ratio=1e-6, 76 | power=1.0, min_lr=0.0, by_epoch=False) 77 | 78 | optimizer_config = dict() 79 | # runtime settings 80 | runner = dict(type='IterBasedRunner', max_iters=320000) 81 | 82 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 83 | evaluation = dict(by_epoch=False, 84 | start=0, 85 | interval=1600, 86 | pre_eval=True, 87 | rule='less', 88 | save_best='abs_rel', 89 | greater_keys=("a1", "a2", "a3"), 90 | less_keys=("abs_rel", "rmse")) 91 | 92 | # iter runtime 93 | log_config = dict( 94 | _delete_=True, 95 | interval=50, 96 | hooks=[ 97 | dict(type='TextLoggerHook', by_epoch=False), 98 | ]) -------------------------------------------------------------------------------- /configs/_base_/models/trap_xcit_s_12_kitti.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../models/trap.py', '../datasets/kitti.py', 3 | '../default_runtime.py' 4 | ] 5 | 6 | # norm_cfg = dict(type='BN', requires_grad=True) 7 | norm_cfg = dict(type='LN', requires_grad=True) 8 | 9 | model = dict( 10 | decode_head=dict( 11 | min_depth=1e-3, 12 | max_depth=80, 13 | norm_cfg=norm_cfg, 14 | ), 15 | ) 16 | 17 | find_unused_parameters=True 18 | SyncBN = True 19 | 20 | # batch size 21 | data = dict( 22 | samples_per_gpu=2, 23 | workers_per_gpu=8, 24 | ) 25 | 26 | # schedules 27 | # optimizer 28 | max_lr = 0.0001 29 | optimizer = dict( 30 | type='AdamW', 31 | lr=max_lr, 32 | betas=(0.9, 0.999), 33 | weight_decay=0.01, 34 | paramwise_cfg=dict( 35 | custom_keys={ 36 | 'absolute_pos_embed': dict(decay_mult=0.), 37 | 'relative_position_bias_table': dict(decay_mult=0.), 38 | 'norm': dict(decay_mult=0.), 39 | })) 40 | 41 |
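# --- Editor's note: illustrative sketch, not part of this config -----------
# The lr_config below requests mmcv's 'poly' policy with linear warmup, as do
# all training configs in this repo. For intuition, the schedule it produces
# looks roughly like this (mirroring mmcv's LrUpdaterHook semantics; the
# helper name is ours, written here purely for illustration):
#
# def poly_warmup_lr(it, max_lr=1e-4, max_iters=320000, warmup_iters=3200,
#                    warmup_ratio=1e-6, power=1.0, min_lr=0.0):
#     coeff = (1 - it / max_iters) ** power      # polynomial decay term
#     lr = (max_lr - min_lr) * coeff + min_lr    # regular (post-warmup) LR
#     if it < warmup_iters:                      # linear warmup phase
#         k = (1 - it / warmup_iters) * (1 - warmup_ratio)
#         lr *= (1 - k)                          # ramps from ~0 up to the regular LR
#     return lr
# ----------------------------------------------------------------------------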
lr_config = dict(policy='poly', 42 | warmup='linear', 43 | warmup_iters=3200, 44 | warmup_ratio=1e-6, 45 | power=1.0, min_lr=0.0, by_epoch=False) 46 | 47 | 48 | optimizer_config = dict() 49 | 50 | runner = dict(type='IterBasedRunner', max_iters=320000) 51 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 52 | evaluation = dict(by_epoch=False, 53 | start=0, 54 | interval=1600, 55 | pre_eval=True, 56 | rule='less', 57 | save_best='abs_rel', 58 | greater_keys=("a1", "a2", "a3"), 59 | less_keys=("abs_rel", "rmse")) 60 | 61 | # iter runtime 62 | log_config = dict( 63 | _delete_=True, 64 | interval=50, 65 | hooks=[ 66 | dict(type='TextLoggerHook', by_epoch=False), 67 | ]) -------------------------------------------------------------------------------- /configs/_base_/models/trap_xcit_s_12_nyu.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../models/trap.py', '../datasets/nyu.py', 3 | '../default_runtime.py' 4 | ] 5 | 6 | norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | decode_head=dict( 10 | min_depth=1e-3, 11 | max_depth=10, 12 | norm_cfg=norm_cfg, 13 | ), 14 | ) 15 | 16 | find_unused_parameters=True 17 | SyncBN = True 18 | 19 | # batch size 20 | data = dict( 21 | samples_per_gpu=2, 22 | workers_per_gpu=8, 23 | ) 24 | 25 | # schedules 26 | # optimizer 27 | max_lr = 1e-4 28 | 29 | optimizer = dict( 30 | type='AdamW', 31 | lr=max_lr, 32 | betas=(0.9, 0.999), 33 | weight_decay=0.01, 34 | paramwise_cfg=dict( 35 | custom_keys={ 36 | 'absolute_pos_embed': dict(decay_mult=0.), 37 | 'relative_position_bias_table': dict(decay_mult=0.), 38 | 'norm': dict(decay_mult=0.), 39 | })) 40 | 41 | lr_config = dict(policy='poly', 42 | warmup='linear', 43 | warmup_iters=3200, 44 | warmup_ratio=1e-6, 45 | power=1.0, min_lr=0.0, by_epoch=False) 46 | 47 | 48 | optimizer_config = dict() 49 | 50 | runner = dict(type='IterBasedRunner', max_iters=320000) 51 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 52 | evaluation = dict(by_epoch=False, 53 | start=0, 54 | interval=1600, 55 | pre_eval=True, 56 | rule='less', 57 | save_best='abs_rel', 58 | greater_keys=("a1", "a2", "a3"), 59 | less_keys=("abs_rel", "rmse")) 60 | 61 | # iter runtime 62 | log_config = dict( 63 | _delete_=True, 64 | interval=50, 65 | hooks=[ 66 | dict(type='TextLoggerHook', by_epoch=False), 67 | ]) -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_24x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | max_lr=1e-4 3 | optimizer = dict(type='AdamW', lr=max_lr, betas=(0.95, 0.99), weight_decay=0.01,) 4 | # learning policy 5 | lr_config = dict( 6 | policy='OneCycle', 7 | max_lr=max_lr, 8 | div_factor=25, 9 | final_div_factor=100, 10 | by_epoch=False, 11 | ) 12 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) 13 | # runtime settings 14 | runner = dict(type='EpochBasedRunner', max_epochs=24) 15 | checkpoint_config = dict(by_epoch=True, max_keep_ckpts=2, interval=1600) 16 | evaluation = dict(by_epoch=True, interval=6, pre_eval=True) -------------------------------------------------------------------------------- /configs/trap/trap_swin_l_kitti.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/kitti.py', 3 | '../_base_/default_runtime.py' 4 | ] 5 | 6 | norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 
9 | type='DepthEncoderDecoder', 10 | backbone=dict( 11 | type='Swin', 12 | pretrained= 13 | 'swin_large_patch4_window7_224_22k.pth', 14 | embed_dims=192, 15 | patch_size=4, 16 | window_size=7, 17 | mlp_ratio=4, 18 | depths=[2, 2, 18, 2], 19 | num_heads=[6, 12, 24, 48], 20 | strides=(4, 2, 2, 2), 21 | out_indices=range(0, 24), 22 | qkv_bias=True, 23 | qk_scale=None, 24 | patch_norm=True, 25 | drop_rate=0.0, 26 | attn_drop_rate=0.0, 27 | drop_path_rate=0.3, 28 | use_abs_pos_embed=False, 29 | act_cfg=dict(type='GELU'), 30 | norm_cfg=dict(type='LN', requires_grad=True), 31 | pretrain_style='official'), 32 | neck=dict( 33 | type='BlockSelectionNeck', 34 | in_channels=[192, 384, 768, 768, 1536], 35 | out_channels=[192, 288, 576, 1152, 2304], 36 | start=[0, 2, 4, 10, 22], 37 | end=[2, 4, 16, 22, 24], 38 | scales=[2, 2, 2, 1.0, 1.0]), 39 | decode_head=dict( 40 | type='TrappedHead', 41 | in_channels=[192, 288, 576, 1152, 2304], 42 | post_process_channels=[192, 288, 576, 1152, 2304], 43 | channels=96, 44 | final_norm=False, 45 | scale_up=True, 46 | # drop_path_rate=0.3, 47 | align_corners=False, 48 | min_depth=0.001, 49 | max_depth=80, 50 | loss_decode=dict(type='SigLoss', valid_mask=True, loss_weight=10)), 51 | train_cfg=dict(), 52 | test_cfg=dict(mode='whole')) 53 | 54 | find_unused_parameters=True 55 | SyncBN = True 56 | 57 | # batch size 58 | data = dict( 59 | samples_per_gpu=2, 60 | workers_per_gpu=8, 61 | ) 62 | 63 | # schedules 64 | # optimizer 65 | max_lr = 0.0001 66 | optimizer = dict( 67 | type='AdamW', 68 | lr=max_lr, 69 | betas=(0.9, 0.999), 70 | weight_decay=0.01, 71 | paramwise_cfg=dict( 72 | custom_keys={ 73 | 'absolute_pos_embed': dict(decay_mult=0.), 74 | 'relative_position_bias_table': dict(decay_mult=0.), 75 | 'norm': dict(decay_mult=0.), 76 | })) 77 | 78 | lr_config = dict(policy='poly', 79 | warmup='linear', 80 | warmup_iters=3200, 81 | warmup_ratio=1e-6, 82 | power=1.0, min_lr=0.0, by_epoch=False) 83 | 84 | 85 | optimizer_config = dict() 86 | # runner = dict(type='IterBasedRunnerAmp', max_iters=320000) 87 | runner = dict(type='IterBasedRunner', max_iters=320000) 88 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 89 | evaluation = dict(by_epoch=False, 90 | start=0, 91 | interval=1600, 92 | pre_eval=True, 93 | rule='less', 94 | save_best='abs_rel', 95 | greater_keys=("a1", "a2", "a3"), 96 | less_keys=("abs_rel", "rmse")) 97 | 98 | # iter runtime 99 | log_config = dict( 100 | _delete_=True, 101 | interval=50, 102 | hooks=[ 103 | dict(type='TextLoggerHook', by_epoch=False), 104 | ]) -------------------------------------------------------------------------------- /configs/trap/trap_swin_l_kitti_benchmark.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/kitti.py', 3 | '../_base_/default_runtime.py' 4 | ] 5 | 6 | norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | type='DepthEncoderDecoder', 10 | backbone=dict( 11 | type='Swin', 12 | pretrained= 13 | 'swin_large_patch4_window7_224_22k.pth', 14 | embed_dims=192, 15 | patch_size=4, 16 | window_size=7, 17 | mlp_ratio=4, 18 | depths=[2, 2, 18, 2], 19 | num_heads=[6, 12, 24, 48], 20 | strides=(4, 2, 2, 2), 21 | out_indices=range(0, 24), 22 | qkv_bias=True, 23 | qk_scale=None, 24 | patch_norm=True, 25 | drop_rate=0.0, 26 | attn_drop_rate=0.0, 27 | drop_path_rate=0.3, 28 | use_abs_pos_embed=False, 29 | act_cfg=dict(type='GELU'), 30 | norm_cfg=dict(type='LN', requires_grad=True), 31 | pretrain_style='official'), 32 | 
neck=dict( 33 | type='BlockSelectionNeck', 34 | in_channels=[192, 384, 768, 768, 1536], 35 | out_channels=[192, 288, 576, 1152, 2304], 36 | start=[0, 2, 4, 10, 22], 37 | end=[2, 4, 16, 22, 24], 38 | scales=[2, 2, 2, 1.0, 1.0]), 39 | decode_head=dict( 40 | type='TrappedHead', 41 | in_channels=[192, 288, 576, 1152, 2304], 42 | post_process_channels=[192, 288, 576, 1152, 2304], 43 | channels=96, 44 | final_norm=False, 45 | scale_up=True, 46 | # drop_path_rate=0.3, 47 | align_corners=False, 48 | min_depth=0.001, 49 | max_depth=88, 50 | loss_decode=dict(type='SigLoss', valid_mask=True, loss_weight=10)), 51 | train_cfg=dict(), 52 | test_cfg=dict(mode='whole')) 53 | 54 | find_unused_parameters=True 55 | SyncBN = True 56 | 57 | # batch size 58 | data = dict( 59 | samples_per_gpu=2, 60 | workers_per_gpu=8, 61 | ) 62 | 63 | # schedules 64 | # optimizer 65 | max_lr = 0.0001 66 | optimizer = dict( 67 | type='AdamW', 68 | lr=max_lr, 69 | betas=(0.9, 0.999), 70 | weight_decay=0.01, 71 | paramwise_cfg=dict( 72 | custom_keys={ 73 | 'absolute_pos_embed': dict(decay_mult=0.), 74 | 'relative_position_bias_table': dict(decay_mult=0.), 75 | 'norm': dict(decay_mult=0.), 76 | })) 77 | 78 | lr_config = dict(policy='poly', 79 | warmup='linear', 80 | warmup_iters=3200, 81 | warmup_ratio=1e-6, 82 | power=1.0, min_lr=0.0, by_epoch=False) 83 | 84 | 85 | optimizer_config = dict() 86 | # runner = dict(type='IterBasedRunnerAmp', max_iters=320000) 87 | runner = dict(type='IterBasedRunner', max_iters=320000) 88 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 89 | evaluation = dict(by_epoch=False, 90 | start=0, 91 | interval=1600, 92 | pre_eval=True, 93 | rule='less', 94 | save_best='abs_rel', 95 | greater_keys=("a1", "a2", "a3"), 96 | less_keys=("abs_rel", "rmse")) 97 | 98 | # iter runtime 99 | log_config = dict( 100 | _delete_=True, 101 | interval=50, 102 | hooks=[ 103 | dict(type='TextLoggerHook', by_epoch=False), 104 | ]) -------------------------------------------------------------------------------- /configs/trap/trap_swin_l_win7_nyu.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nyu.py', 3 | '../_base_/default_runtime.py' 4 | ] 5 | 6 | backbone_norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | type='DepthEncoderDecoder', 10 | backbone=dict( 11 | type='Swin', 12 | pretrained=r'', 13 | embed_dims=192, 14 | patch_size=4, 15 | window_size=7, 16 | mlp_ratio=4, 17 | depths=[2, 2, 18, 2], 18 | num_heads=[6, 12, 24, 48], 19 | strides=(4, 2, 2, 2), 20 | out_indices=range(24), 21 | qkv_bias=True, 22 | qk_scale=None, 23 | patch_norm=True, 24 | drop_rate=0., 25 | attn_drop_rate=0., 26 | drop_path_rate=0.5, 27 | use_abs_pos_embed=False, 28 | act_cfg=dict(type='GELU'), 29 | norm_cfg=backbone_norm_cfg, 30 | pretrain_style='official', 31 | ), 32 | neck=dict( 33 | type='BlockSelectionNeck', 34 | in_channels=[192, 384, 768, 768, 1536], 35 | out_channels=[192, 288, 576, 1152, 2304], 36 | start=[0, 2, 4, 10, 22], 37 | end=[2, 4, 16, 22, 24], 38 | scales=[2, 2, 2, 1., 1.]), 39 | decode_head=dict( 40 | type='TrappedHead', 41 | in_channels=[192, 288, 576, 1152, 2304], 42 | post_process_channels=[192, 288, 576, 1152, 2304], 43 | channels=96, # last one 44 | final_norm=False, 45 | scale_up=True, 46 | align_corners=False, # for upsample 47 | min_depth=1e-3, 48 | max_depth=10, 49 | loss_decode=dict( 50 | type='SigLoss', valid_mask=True, loss_weight=10)), 51 | # model training and testing settings 52 | 
train_cfg=dict(), 53 | test_cfg=dict(mode='whole') 54 | ) 55 | 56 | find_unused_parameters = True 57 | SyncBN = True 58 | 59 | # batch size 60 | data = dict( 61 | samples_per_gpu=2, 62 | workers_per_gpu=8, 63 | ) 64 | 65 | # schedules 66 | # optimizer 67 | max_lr = 0.0001 68 | 69 | 70 | optimizer = dict( 71 | type='AdamW', 72 | lr=max_lr, 73 | betas=(0.9, 0.999), 74 | weight_decay=0.01, 75 | paramwise_cfg=dict( 76 | custom_keys={ 77 | 'absolute_pos_embed': dict(decay_mult=0.), 78 | 'relative_position_bias_table': dict(decay_mult=0.), 79 | 'norm': dict(decay_mult=0.), 80 | })) 81 | # learning policy 82 | lr_config = dict(policy='poly', 83 | warmup='linear', 84 | warmup_iters=3200, 85 | warmup_ratio=1e-6, 86 | power=1.0, min_lr=0.0, by_epoch=False) 87 | 88 | optimizer_config = dict() 89 | # runner = dict(type='IterBasedRunnerAmp', max_iters=320000) 90 | runner = dict(type='IterBasedRunner', max_iters=320000) 91 | 92 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 93 | evaluation = dict(by_epoch=False, 94 | start=0, 95 | interval=1600, 96 | pre_eval=True, 97 | rule='less', 98 | save_best='abs_rel', 99 | greater_keys=("a1", "a2", "a3"), 100 | less_keys=("abs_rel", "rmse")) 101 | 102 | # iter runtime 103 | log_config = dict( 104 | _delete_=True, 105 | interval=50, 106 | hooks=[ 107 | dict(type='TextLoggerHook', by_epoch=False), 108 | ]) -------------------------------------------------------------------------------- /configs/trap/trap_xcit_m_24_kitti.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/kitti.py', 3 | '../_base_/default_runtime.py' 4 | ] 5 | 6 | norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | type='DepthEncoderDecoder', 10 | backbone=dict( 11 | type='XCiT', 12 | pretrained=r'', 13 | patch_size=8, 14 | embed_dim=512, 15 | depth=24, 16 | num_heads=8, 17 | mlp_ratio=4, 18 | qkv_bias=True, 19 | eta=1e-5, 20 | out_indices=range(24), 21 | ), 22 | neck=dict( 23 | type='BlockSelectionNeck', 24 | in_channels=[512] * 5, 25 | out_channels=[128, 192, 384, 768, 1536], 26 | start=[0, 4, 8, 12, 16], 27 | end=[8, 12, 16, 20, 24], 28 | scales=[4, 2, 1, .5, .25]), 29 | decode_head=dict( 30 | type='TrappedHead', 31 | in_channels=[128, 192, 384, 768, 1536], 32 | post_process_channels=[128, 192, 384, 768, 1536], 33 | channels=64, # last one 34 | final_norm=False, 35 | scale_up=True, 36 | align_corners=False, # for upsample 37 | min_depth=1e-3, 38 | max_depth=80, 39 | loss_decode=dict( 40 | type='SigLoss', valid_mask=True, loss_weight=10)), 41 | # model training and testing settings 42 | train_cfg=dict(), 43 | test_cfg=dict(mode='whole') 44 | ) 45 | 46 | find_unused_parameters = True 47 | SyncBN = True 48 | 49 | # batch size 50 | data = dict( 51 | samples_per_gpu=2, 52 | workers_per_gpu=8, 53 | ) 54 | 55 | # schedules 56 | # optimizer 57 | max_lr = 0.00001 58 | optimizer = dict( 59 | type='AdamW', 60 | lr=max_lr, 61 | betas=(0.9, 0.999), 62 | weight_decay=0.01, 63 | paramwise_cfg=dict( 64 | custom_keys={ 65 | 'absolute_pos_embed': dict(decay_mult=0.), 66 | 'relative_position_bias_table': dict(decay_mult=0.), 67 | 'norm': dict(decay_mult=0.), 68 | })) 69 | # learning policy 70 | lr_config = dict(policy='poly', 71 | warmup='linear', 72 | warmup_iters=3200, 73 | warmup_ratio=1e-6, 74 | power=1.0, min_lr=0.0, by_epoch=False) 75 | 76 | optimizer_config = dict() 77 | # runtime settings 78 | # runner = dict(type='IterBasedRunnerAmp', max_iters=320000) 79 | runner = 
dict(type='IterBasedRunner', max_iters=320000) 80 | 81 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 82 | evaluation = dict(by_epoch=False, 83 | start=0, 84 | interval=1600, 85 | pre_eval=True, 86 | rule='less', 87 | save_best='abs_rel', 88 | greater_keys=("a1", "a2", "a3"), 89 | less_keys=("abs_rel", "rmse")) 90 | 91 | # iter runtime 92 | log_config = dict( 93 | _delete_=True, 94 | interval=50, 95 | hooks=[ 96 | dict(type='TextLoggerHook', by_epoch=False), 97 | ]) -------------------------------------------------------------------------------- /configs/trap/trap_xcit_m_24_nyu.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nyu.py', 3 | '../_base_/default_runtime.py' 4 | ] 5 | 6 | norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | type='DepthEncoderDecoder', 10 | backbone=dict( 11 | type='XCiT', 12 | pretrained=r'', 13 | patch_size=8, 14 | embed_dim=512, 15 | depth=24, 16 | num_heads=8, 17 | mlp_ratio=4, 18 | qkv_bias=True, 19 | eta=1e-5, 20 | drop_path_rate=0., 21 | out_indices=range(24), 22 | ), 23 | neck=dict( 24 | type='BlockSelectionNeck', 25 | in_channels=[512] * 5, 26 | out_channels=[128, 192, 384, 768, 1536], 27 | start=[0, 4, 8, 12, 16], 28 | end=[8, 12, 16, 20, 24], 29 | scales=[4, 2, 1, .5, .25]), 30 | decode_head=dict( 31 | type='TrappedHead', 32 | in_channels=[128, 192, 384, 768, 1536], 33 | post_process_channels=[128, 192, 384, 768, 1536], 34 | channels=64, # last one 35 | final_norm=False, 36 | scale_up=True, 37 | align_corners=False, # for upsample 38 | min_depth=1e-3, 39 | max_depth=10, 40 | loss_decode=dict( 41 | type='SigLoss', valid_mask=True, loss_weight=10)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole') 45 | ) 46 | 47 | find_unused_parameters = True 48 | SyncBN = True 49 | 50 | # batch size 51 | data = dict( 52 | samples_per_gpu=2, 53 | workers_per_gpu=8, 54 | ) 55 | 56 | # schedules 57 | # optimizer 58 | max_lr = 0.0001 59 | optimizer = dict( 60 | type='AdamW', 61 | lr=max_lr, 62 | betas=(0.9, 0.999), 63 | weight_decay=0.01, 64 | paramwise_cfg=dict( 65 | custom_keys={ 66 | 'absolute_pos_embed': dict(decay_mult=0.), 67 | 'relative_position_bias_table': dict(decay_mult=0.), 68 | 'norm': dict(decay_mult=0.), 69 | })) 70 | 71 | # learning policy 72 | lr_config = dict(policy='poly', 73 | warmup='linear', 74 | warmup_iters=3200, 75 | warmup_ratio=1e-6, 76 | power=1.0, min_lr=0.0, by_epoch=False) 77 | 78 | optimizer_config = dict() 79 | # runtime settings 80 | # runner = dict(type='IterBasedRunnerAmp', max_iters=320000) 81 | runner = dict(type='IterBasedRunner', max_iters=320000) 82 | 83 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 84 | evaluation = dict(by_epoch=False, 85 | start=0, 86 | interval=1600, 87 | pre_eval=True, 88 | rule='less', 89 | save_best='abs_rel', 90 | greater_keys=("a1", "a2", "a3"), 91 | less_keys=("abs_rel", "rmse")) 92 | 93 | # iter runtime 94 | log_config = dict( 95 | _delete_=True, 96 | interval=50, 97 | hooks=[ 98 | dict(type='TextLoggerHook', by_epoch=False), 99 | ]) -------------------------------------------------------------------------------- /configs/trap/trap_xcit_s_12_kitti.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/trap.py', '../_base_/datasets/kitti.py', 3 | '../_base_/default_runtime.py' 4 | ] 5 | 6 | # norm_cfg = dict(type='BN', requires_grad=True) 7 
| norm_cfg = dict(type='LN', requires_grad=True) 8 | 9 | model = dict( 10 | decode_head=dict( 11 | min_depth=1e-3, 12 | max_depth=80, 13 | norm_cfg=norm_cfg, 14 | ), 15 | ) 16 | 17 | find_unused_parameters=True 18 | SyncBN = True 19 | 20 | # batch size 21 | data = dict( 22 | samples_per_gpu=2, 23 | workers_per_gpu=8, 24 | ) 25 | 26 | # schedules 27 | # optimizer 28 | max_lr = 0.0001 29 | optimizer = dict( 30 | type='AdamW', 31 | lr=max_lr, 32 | betas=(0.9, 0.999), 33 | weight_decay=0.01, 34 | paramwise_cfg=dict( 35 | custom_keys={ 36 | 'absolute_pos_embed': dict(decay_mult=0.), 37 | 'relative_position_bias_table': dict(decay_mult=0.), 38 | 'norm': dict(decay_mult=0.), 39 | })) 40 | 41 | lr_config = dict(policy='poly', 42 | warmup='linear', 43 | warmup_iters=3200, 44 | warmup_ratio=1e-6, 45 | power=1.0, min_lr=0.0, by_epoch=False) 46 | 47 | 48 | optimizer_config = dict() 49 | 50 | # runner = dict(type='IterBasedRunnerAmp', max_iters=320000) 51 | runner = dict(type='IterBasedRunner', max_iters=320000) 52 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 53 | evaluation = dict(by_epoch=False, 54 | start=0, 55 | interval=1600, 56 | pre_eval=True, 57 | rule='less', 58 | save_best='abs_rel', 59 | greater_keys=("a1", "a2", "a3"), 60 | less_keys=("abs_rel", "rmse")) 61 | 62 | # iter runtime 63 | log_config = dict( 64 | _delete_=True, 65 | interval=50, 66 | hooks=[ 67 | dict(type='TextLoggerHook', by_epoch=False), 68 | ]) -------------------------------------------------------------------------------- /configs/trap/trap_xcit_s_12_nyu.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/trap.py', '../_base_/datasets/nyu.py', 3 | '../_base_/default_runtime.py' 4 | ] 5 | 6 | norm_cfg = dict(type='LN', requires_grad=True) 7 | 8 | model = dict( 9 | decode_head=dict( 10 | min_depth=1e-3, 11 | max_depth=10, 12 | norm_cfg=norm_cfg, 13 | ), 14 | ) 15 | 16 | find_unused_parameters=True 17 | SyncBN = True 18 | 19 | # batch size 20 | data = dict( 21 | samples_per_gpu=2, 22 | workers_per_gpu=8, 23 | ) 24 | 25 | # schedules 26 | # optimizer 27 | max_lr = 1e-4 28 | 29 | optimizer = dict( 30 | type='AdamW', 31 | lr=max_lr, 32 | betas=(0.9, 0.999), 33 | weight_decay=0.01, 34 | paramwise_cfg=dict( 35 | custom_keys={ 36 | 'absolute_pos_embed': dict(decay_mult=0.), 37 | 'relative_position_bias_table': dict(decay_mult=0.), 38 | 'norm': dict(decay_mult=0.), 39 | })) 40 | 41 | lr_config = dict(policy='poly', 42 | warmup='linear', 43 | warmup_iters=3200, 44 | warmup_ratio=1e-6, 45 | power=1.0, min_lr=0.0, by_epoch=False) 46 | 47 | 48 | optimizer_config = dict() 49 | 50 | # runner = dict(type='IterBasedRunnerAmp', max_iters=320000) 51 | runner = dict(type='IterBasedRunner', max_iters=320000) 52 | checkpoint_config = dict(by_epoch=False, max_keep_ckpts=2, interval=1600) 53 | evaluation = dict(by_epoch=False, 54 | start=0, 55 | interval=1600, 56 | pre_eval=True, 57 | rule='less', 58 | save_best='abs_rel', 59 | greater_keys=("a1", "a2", "a3"), 60 | less_keys=("abs_rel", "rmse")) 61 | 62 | # iter runtime 63 | log_config = dict( 64 | _delete_=True, 65 | interval=50, 66 | hooks=[ 67 | dict(type='TextLoggerHook', by_epoch=False), 68 | ]) -------------------------------------------------------------------------------- /depth/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import warnings 3 | 4 | import mmcv 5 | from packaging.version import parse 6 | 7 | from .version import __version__, version_info 8 | 9 | MMCV_MIN = '1.3.7' 10 | # MMCV_MAX = '1.4.0' 11 | MMCV_MAX = '1.5.0' # have tested on mmcv==1.5.0 12 | 13 | 14 | def digit_version(version_str: str, length: int = 4): 15 | """Convert a version string into a tuple of integers. 16 | 17 | This method is usually used for comparing two versions. For pre-release 18 | versions: alpha < beta < rc. 19 | 20 | Args: 21 | version_str (str): The version string. 22 | length (int): The maximum number of version levels. Default: 4. 23 | 24 | Returns: 25 | tuple[int]: The version info in digits (integers). 26 | """ 27 | version = parse(version_str) 28 | assert version.release, f'failed to parse version {version_str}' 29 | release = list(version.release) 30 | release = release[:length] 31 | if len(release) < length: 32 | release = release + [0] * (length - len(release)) 33 | if version.is_prerelease: 34 | mapping = {'a': -3, 'b': -2, 'rc': -1} 35 | val = -4 36 | # version.pre can be None 37 | if version.pre: 38 | if version.pre[0] not in mapping: 39 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 40 | 'version checking may go wrong') 41 | else: 42 | val = mapping[version.pre[0]] 43 | release.extend([val, version.pre[-1]]) 44 | else: 45 | release.extend([val, 0]) 46 | 47 | elif version.is_postrelease: 48 | release.extend([1, version.post]) 49 | else: 50 | release.extend([0, 0]) 51 | return tuple(release) 52 | 53 | 54 | mmcv_min_version = digit_version(MMCV_MIN) 55 | mmcv_max_version = digit_version(MMCV_MAX) 56 | mmcv_version = digit_version(mmcv.__version__) 57 | 58 | 59 | # assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 60 | # f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 61 | # f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' 62 | 63 | __all__ = ['__version__', 'version_info', 'digit_version'] 64 | -------------------------------------------------------------------------------- /depth/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .inference import inference_depther, init_depther 3 | from .test import multi_gpu_test, single_gpu_test 4 | from .train import get_root_logger, set_random_seed, train_depther 5 | 6 | __all__ = [ 7 | 'get_root_logger', 'set_random_seed', 'train_depther', 'init_depther', 8 | 'inference_depther', 'multi_gpu_test', 'single_gpu_test', 9 | ] 10 | -------------------------------------------------------------------------------- /depth/apis/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import matplotlib.pyplot as plt 3 | import mmcv 4 | import torch 5 | from mmcv.parallel import collate, scatter 6 | from mmcv.runner import load_checkpoint 7 | 8 | from depth.datasets.pipelines import Compose 9 | from depth.models import build_depther 10 | 11 | 12 | def init_depther(config, checkpoint=None, device='cuda:0'): 13 | """Initialize a depther from config file. 14 | 15 | Args: 16 | config (str or :obj:`mmcv.Config`): Config file path or the config 17 | object. 18 | checkpoint (str, optional): Checkpoint path. If left as None, the model 19 | will not load any weights. 20 | device (str, optional) CPU/CUDA device option. Default 'cuda:0'. 21 | Use 'cpu' for loading model on CPU. 22 | Returns: 23 | nn.Module: The constructed depther. 
24 | """ 25 | if isinstance(config, str): 26 | config = mmcv.Config.fromfile(config) 27 | elif not isinstance(config, mmcv.Config): 28 | raise TypeError('config must be a filename or Config object, ' 29 | 'but got {}'.format(type(config))) 30 | config.model.pretrained = None 31 | config.model.train_cfg = None 32 | model = build_depther(config.model, test_cfg=config.get('test_cfg')) 33 | if checkpoint is not None: 34 | checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') 35 | model.CLASSES = checkpoint['meta']['CLASSES'] 36 | model.PALETTE = checkpoint['meta']['PALETTE'] 37 | model.cfg = config # save the config in the model for convenience 38 | model.to(device) 39 | model.eval() 40 | return model 41 | 42 | 43 | class LoadImage: 44 | """A simple pipeline to load image.""" 45 | 46 | def __call__(self, results): 47 | """Call function to load images into results. 48 | 49 | Args: 50 | results (dict): A result dict contains the file name 51 | of the image to be read. 52 | 53 | Returns: 54 | dict: ``results`` will be returned containing loaded image. 55 | """ 56 | 57 | if isinstance(results['img'], str): 58 | results['filename'] = results['img'] 59 | results['ori_filename'] = results['img'] 60 | else: 61 | results['filename'] = None 62 | results['ori_filename'] = None 63 | img = mmcv.imread(results['img']) 64 | results['img'] = img 65 | results['img_shape'] = img.shape 66 | results['ori_shape'] = img.shape 67 | return results 68 | 69 | 70 | def inference_depther(model, img): 71 | """Inference image(s) with the depther. 72 | 73 | Args: 74 | model (nn.Module): The loaded depther. 75 | imgs (str/ndarray or list[str/ndarray]): Either image files or loaded 76 | images. 77 | 78 | Returns: 79 | (list[Tensor]): The depth estimation result. 80 | """ 81 | cfg = model.cfg 82 | device = next(model.parameters()).device # model device 83 | # build the data pipeline 84 | test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] 85 | test_pipeline = Compose(test_pipeline) 86 | # prepare data 87 | data = dict(img=img) 88 | data = test_pipeline(data) 89 | data = collate([data], samples_per_gpu=1) 90 | if next(model.parameters()).is_cuda: 91 | # scatter to specified GPU 92 | data = scatter(data, [device])[0] 93 | else: 94 | data['img_metas'] = [i.data[0] for i in data['img_metas']] 95 | 96 | # forward the model 97 | with torch.no_grad(): 98 | result = model(return_loss=False, rescale=True, **data) 99 | return result 100 | 101 | -------------------------------------------------------------------------------- /depth/apis/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import random 3 | import warnings 4 | 5 | import numpy as np 6 | import torch 7 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 8 | from mmcv.runner import build_optimizer, build_runner 9 | 10 | from depth.core import DistEvalHook, EvalHook 11 | from depth.datasets import build_dataloader, build_dataset 12 | from depth.utils import get_root_logger 13 | 14 | def set_random_seed(seed, deterministic=False): 15 | """Set random seed. 16 | 17 | Args: 18 | seed (int): Seed to be used. 19 | deterministic (bool): Whether to set the deterministic option for 20 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 21 | to True and `torch.backends.cudnn.benchmark` to False. 22 | Default: False. 
23 | """ 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | if deterministic: 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | 32 | 33 | def train_depther(model, 34 | dataset, 35 | cfg, 36 | distributed=False, 37 | validate=False, 38 | timestamp=None, 39 | meta=None): 40 | """Launch depther training.""" 41 | logger = get_root_logger(cfg.log_level) 42 | 43 | # prepare data loaders 44 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 45 | data_loaders = [ 46 | build_dataloader( 47 | ds, 48 | cfg.data.samples_per_gpu, 49 | cfg.data.workers_per_gpu, 50 | # cfg.gpus will be ignored if distributed 51 | len(cfg.gpu_ids), 52 | dist=distributed, 53 | seed=cfg.seed, 54 | drop_last=True) for ds in dataset 55 | ] 56 | 57 | # put model on gpus 58 | if distributed: 59 | find_unused_parameters = cfg.get('find_unused_parameters', False) 60 | # Sets the `find_unused_parameters` parameter in 61 | # torch.nn.parallel.DistributedDataParallel 62 | model = MMDistributedDataParallel( 63 | model.cuda(), 64 | device_ids=[torch.cuda.current_device()], 65 | broadcast_buffers=False, 66 | find_unused_parameters=find_unused_parameters) 67 | else: 68 | model = MMDataParallel( 69 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 70 | 71 | # build runner 72 | optimizer = build_optimizer(model, cfg.optimizer) 73 | 74 | if cfg.get('runner') is None: 75 | cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} 76 | warnings.warn( 77 | 'config is now expected to have a `runner` section, ' 78 | 'please set `runner` in your config.', UserWarning) 79 | 80 | runner = build_runner( 81 | cfg.runner, 82 | default_args=dict( 83 | model=model, 84 | batch_processor=None, 85 | optimizer=optimizer, 86 | work_dir=cfg.work_dir, 87 | logger=logger, 88 | meta=meta)) 89 | 90 | # register hooks 91 | runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, 92 | cfg.checkpoint_config, cfg.log_config, 93 | cfg.get('momentum_config', None), 94 | custom_hooks_config=cfg.get('custom_hooks', None)) 95 | 96 | # an ugly walkaround to make the .log and .log.json filenames the same 97 | runner.timestamp = timestamp 98 | 99 | # register eval hooks 100 | if validate: 101 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 102 | val_dataloader = build_dataloader( 103 | val_dataset, 104 | samples_per_gpu=1, 105 | workers_per_gpu=cfg.data.workers_per_gpu, 106 | dist=distributed, 107 | shuffle=False) 108 | eval_cfg = cfg.get('evaluation', {}) 109 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' 110 | eval_hook = DistEvalHook if distributed else EvalHook 111 | # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the 112 | # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'. 113 | runner.register_hook( 114 | eval_hook(val_dataloader, **eval_cfg), priority='LOW') 115 | 116 | if cfg.resume_from: 117 | runner.resume(cfg.resume_from) 118 | elif cfg.load_from: 119 | runner.load_checkpoint(cfg.load_from) 120 | runner.run(data_loaders, cfg.workflow) 121 | -------------------------------------------------------------------------------- /depth/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .evaluation import * # noqa: F401, F403 3 | from .utils import * # noqa: F401, F403 4 | -------------------------------------------------------------------------------- /depth/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .metrics import metrics, eval_metrics, pre_eval_to_metrics 3 | from .eval_hooks import EvalHook, DistEvalHook -------------------------------------------------------------------------------- /depth/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | import warnings 4 | 5 | import torch.distributed as dist 6 | from mmcv.runner import DistEvalHook as _DistEvalHook 7 | from mmcv.runner import EvalHook as _EvalHook 8 | from torch.nn.modules.batchnorm import _BatchNorm 9 | 10 | 11 | class EvalHook(_EvalHook): 12 | """Single GPU EvalHook, with efficient test support. 13 | 14 | Args: 15 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 16 | If set to True, it will perform by epoch. Otherwise, by iteration. 17 | Default: False. 18 | pre_eval (bool): Whether to use progressive mode to evaluate model. 19 | Default: False. 20 | Returns: 21 | list: The prediction results. 22 | """ 23 | 24 | metric = ["a1", "a2", "a3", "abs_rel", "rmse", "log_10", "rmse_log", "silog", "sq_rel"] 25 | # greater_keys = ['mIoU', 'mAcc', 'aAcc'] 26 | 27 | def __init__(self, 28 | *args, 29 | by_epoch=False, 30 | pre_eval=False, 31 | **kwargs): 32 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 33 | self.pre_eval = pre_eval 34 | 35 | def _do_evaluate(self, runner): 36 | """perform evaluation and save ckpt.""" 37 | if not self._should_evaluate(runner): 38 | return 39 | 40 | from depth.apis import single_gpu_test 41 | results = single_gpu_test( 42 | runner.model, self.dataloader, show=False, pre_eval=self.pre_eval) 43 | runner.log_buffer.clear() 44 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 45 | key_score = self.evaluate(runner, results) 46 | if self.save_best: 47 | self._save_ckpt(runner, key_score) 48 | 49 | 50 | class DistEvalHook(_DistEvalHook): 51 | """Distributed EvalHook, with efficient test support. 52 | 53 | Args: 54 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 55 | If set to True, it will perform by epoch. Otherwise, by iteration. 56 | Default: False. 57 | pre_eval (bool): Whether to use progressive mode to evaluate model. 58 | Default: False. 59 | Returns: 60 | list: The prediction results. 61 | """ 62 | 63 | metric = ["a1", "a2", "a3", "abs_rel", "rmse", "log_10", "rmse_log", "silog", "sq_rel"] 64 | # greater_keys = ['mIoU', 'mAcc', 'aAcc'] 65 | 66 | def __init__(self, 67 | *args, 68 | by_epoch=False, 69 | pre_eval=False, 70 | **kwargs): 71 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 72 | self.pre_eval = pre_eval 73 | 74 | def _do_evaluate(self, runner): 75 | """perform evaluation and save ckpt.""" 76 | # Synchronization of BatchNorm's buffer (running_mean 77 | # and running_var) is not supported in the DDP of pytorch, 78 | # which may cause the inconsistent performance of models in 79 | # different ranks, so we broadcast BatchNorm's buffers 80 | # of rank 0 to other ranks to avoid this. 
81 | if self.broadcast_bn_buffer: 82 | model = runner.model 83 | for name, module in model.named_modules(): 84 | if isinstance(module, 85 | _BatchNorm) and module.track_running_stats: 86 | dist.broadcast(module.running_var, 0) 87 | dist.broadcast(module.running_mean, 0) 88 | 89 | if not self._should_evaluate(runner): 90 | return 91 | 92 | tmpdir = self.tmpdir 93 | if tmpdir is None: 94 | tmpdir = osp.join(runner.work_dir, '.eval_hook') 95 | 96 | from depth.apis import multi_gpu_test 97 | results = multi_gpu_test( 98 | runner.model, 99 | self.dataloader, 100 | tmpdir=tmpdir, 101 | gpu_collect=self.gpu_collect, 102 | pre_eval=self.pre_eval) 103 | 104 | runner.log_buffer.clear() 105 | 106 | if runner.rank == 0: 107 | print('\n') 108 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 109 | key_score = self.evaluate(runner, results) 110 | 111 | if self.save_best: 112 | self._save_ckpt(runner, key_score) 113 | -------------------------------------------------------------------------------- /depth/core/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import mmcv 4 | import numpy as np 5 | import torch 6 | 7 | def calculate(gt, pred): 8 | if gt.shape[0] == 0: 9 | return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan 10 | 11 | thresh = np.maximum((gt / pred), (pred / gt)) 12 | a1 = (thresh < 1.25).mean() 13 | a2 = (thresh < 1.25 ** 2).mean() 14 | a3 = (thresh < 1.25 ** 3).mean() 15 | 16 | abs_rel = np.mean(np.abs(gt - pred) / gt) 17 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 18 | 19 | rmse = (gt - pred) ** 2 20 | rmse = np.sqrt(rmse.mean()) 21 | 22 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 23 | rmse_log = np.sqrt(rmse_log.mean()) 24 | 25 | err = np.log(pred) - np.log(gt) 26 | 27 | silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 28 | if np.isnan(silog): 29 | silog = 0 30 | 31 | log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean() 32 | return a1, a2, a3, abs_rel, rmse, log_10, rmse_log, silog, sq_rel 33 | 34 | def metrics(gt, pred, min_depth=1e-3, max_depth=80): 35 | mask_1 = gt > min_depth 36 | mask_2 = gt < max_depth 37 | mask = np.logical_and(mask_1, mask_2) 38 | 39 | gt = gt[mask] 40 | pred = pred[mask] 41 | 42 | a1, a2, a3, abs_rel, rmse, log_10, rmse_log, silog, sq_rel = calculate(gt, pred) 43 | 44 | return a1, a2, a3, abs_rel, rmse, log_10, rmse_log, silog, sq_rel 45 | 46 | def eval_metrics(gt, pred, min_depth=1e-3, max_depth=80): 47 | mask_1 = gt > min_depth 48 | mask_2 = gt < max_depth 49 | mask = np.logical_and(mask_1, mask_2) 50 | 51 | gt = gt[mask] 52 | pred = pred[mask] 53 | 54 | thresh = np.maximum((gt / pred), (pred / gt)) 55 | a1 = (thresh < 1.25).mean() 56 | a2 = (thresh < 1.25 ** 2).mean() 57 | a3 = (thresh < 1.25 ** 3).mean() 58 | 59 | abs_rel = np.mean(np.abs(gt - pred) / gt) 60 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 61 | 62 | rmse = (gt - pred) ** 2 63 | rmse = np.sqrt(rmse.mean()) 64 | 65 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 66 | rmse_log = np.sqrt(rmse_log.mean()) 67 | 68 | err = np.log(pred) - np.log(gt) 69 | silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 70 | 71 | log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean() 72 | return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log, 73 | silog=silog, sq_rel=sq_rel) 74 | 75 | 76 | def pre_eval_to_metrics(pre_eval_results): 77 | 78 | # convert list of tuples to tuple of lists, e.g. 
79 | # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to 80 | # ([A_1, ..., A_n], ..., [D_1, ..., D_n]) 81 | pre_eval_results = tuple(zip(*pre_eval_results)) 82 | ret_metrics = OrderedDict({}) 83 | 84 | ret_metrics['a1'] = np.nanmean(pre_eval_results[0]) 85 | ret_metrics['a2'] = np.nanmean(pre_eval_results[1]) 86 | ret_metrics['a3'] = np.nanmean(pre_eval_results[2]) 87 | ret_metrics['abs_rel'] = np.nanmean(pre_eval_results[3]) 88 | ret_metrics['rmse'] = np.nanmean(pre_eval_results[4]) 89 | ret_metrics['log_10'] = np.nanmean(pre_eval_results[5]) 90 | ret_metrics['rmse_log'] = np.nanmean(pre_eval_results[6]) 91 | ret_metrics['silog'] = np.nanmean(pre_eval_results[7]) 92 | ret_metrics['sq_rel'] = np.nanmean(pre_eval_results[8]) 93 | 94 | ret_metrics = { 95 | metric: value 96 | for metric, value in ret_metrics.items() 97 | } 98 | 99 | return ret_metrics 100 | 101 | -------------------------------------------------------------------------------- /depth/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .misc import add_prefix 3 | 4 | __all__ = ['add_prefix'] 5 | -------------------------------------------------------------------------------- /depth/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def add_prefix(inputs, prefix): 3 | """Add prefix for dict. 4 | 5 | Args: 6 | inputs (dict): The input dict with str keys. 7 | prefix (str): The prefix to add. 8 | 9 | Returns: 10 | 11 | dict: The dict with keys updated with ``prefix``. 12 | """ 13 | 14 | outputs = dict() 15 | for name, value in inputs.items(): 16 | outputs[f'{prefix}.{name}'] = value 17 | 18 | return outputs 19 | -------------------------------------------------------------------------------- /depth/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .kitti import KITTIDataset 3 | from .nyu import NYUDataset 4 | from .sunrgbd import SUNRGBDDataset 5 | from .custom import CustomDepthDataset 6 | from .cityscapes import CSDataset 7 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 8 | 9 | __all__ = [ 10 | 'KITTIDataset', 'NYUDataset', 'SUNRGBDDataset', 'CustomDepthDataset', 'CSDataset', 11 | ] -------------------------------------------------------------------------------- /depth/datasets/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import copy 3 | import platform 4 | import random 5 | from functools import partial 6 | 7 | import numpy as np 8 | import torch 9 | from mmcv.parallel import collate 10 | from mmcv.runner import get_dist_info 11 | from mmcv.utils import Registry, build_from_cfg, digit_version 12 | from torch.utils.data import DataLoader, DistributedSampler 13 | 14 | if platform.system() != 'Windows': 15 | # https://github.com/pytorch/pytorch/issues/973 16 | import resource 17 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 18 | base_soft_limit = rlimit[0] 19 | hard_limit = rlimit[1] 20 | soft_limit = min(max(4096, base_soft_limit), hard_limit) 21 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) 22 | 23 | DATASETS = Registry('dataset') 24 | PIPELINES = Registry('pipeline') 25 | 26 | 27 | def _concat_dataset(cfg, default_args=None): 28 | """Build :obj:`ConcatDataset by.""" 29 | from .dataset_wrappers import ConcatDataset 30 | img_dir = cfg['img_dir'] 31 | ann_dir = cfg.get('ann_dir', None) 32 | split = cfg.get('split', None) 33 | num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 34 | if ann_dir is not None: 35 | num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 36 | else: 37 | num_ann_dir = 0 38 | if split is not None: 39 | num_split = len(split) if isinstance(split, (list, tuple)) else 1 40 | else: 41 | num_split = 0 42 | if num_img_dir > 1: 43 | assert num_img_dir == num_ann_dir or num_ann_dir == 0 44 | assert num_img_dir == num_split or num_split == 0 45 | else: 46 | assert num_split == num_ann_dir or num_ann_dir <= 1 47 | num_dset = max(num_split, num_img_dir) 48 | 49 | datasets = [] 50 | for i in range(num_dset): 51 | data_cfg = copy.deepcopy(cfg) 52 | if isinstance(img_dir, (list, tuple)): 53 | data_cfg['img_dir'] = img_dir[i] 54 | if isinstance(ann_dir, (list, tuple)): 55 | data_cfg['ann_dir'] = ann_dir[i] 56 | if isinstance(split, (list, tuple)): 57 | data_cfg['split'] = split[i] 58 | datasets.append(build_dataset(data_cfg, default_args)) 59 | 60 | return ConcatDataset(datasets) 61 | 62 | 63 | def build_dataset(cfg, default_args=None): 64 | """Build datasets.""" 65 | from .dataset_wrappers import ConcatDataset, RepeatDataset 66 | if isinstance(cfg, (list, tuple)): 67 | dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) 68 | elif cfg['type'] == 'RepeatDataset': 69 | dataset = RepeatDataset( 70 | build_dataset(cfg['dataset'], default_args), cfg['times']) 71 | elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( 72 | cfg.get('split', None), (list, tuple)): 73 | dataset = _concat_dataset(cfg, default_args) 74 | else: 75 | dataset = build_from_cfg(cfg, DATASETS, default_args) 76 | 77 | return dataset 78 | 79 | 80 | def build_dataloader(dataset, 81 | samples_per_gpu, 82 | workers_per_gpu, 83 | num_gpus=1, 84 | dist=True, 85 | shuffle=True, 86 | seed=None, 87 | drop_last=False, 88 | pin_memory=True, 89 | persistent_workers=True, 90 | **kwargs): 91 | """Build PyTorch DataLoader. 92 | 93 | In distributed training, each GPU/process has a dataloader. 94 | In non-distributed training, there is only one dataloader for all GPUs. 95 | 96 | Args: 97 | dataset (Dataset): A PyTorch dataset. 98 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 99 | batch size of each GPU. 100 | workers_per_gpu (int): How many subprocesses to use for data loading 101 | for each GPU. 102 | num_gpus (int): Number of GPUs. Only used in non-distributed training. 
103 | dist (bool): Distributed training/test or not. Default: True. 104 | shuffle (bool): Whether to shuffle the data at every epoch. 105 | Default: True. 106 | seed (int | None): Seed to be used. Default: None. 107 | drop_last (bool): Whether to drop the last incomplete batch in epoch. 108 | Default: False 109 | pin_memory (bool): Whether to use pin_memory in DataLoader. 110 | Default: True 111 | persistent_workers (bool): If True, the data loader will not shutdown 112 | the worker processes after a dataset has been consumed once. 113 | This allows to maintain the workers Dataset instances alive. 114 | The argument also has effect in PyTorch>=1.7.0. 115 | Default: True 116 | kwargs: any keyword argument to be used to initialize DataLoader 117 | 118 | Returns: 119 | DataLoader: A PyTorch dataloader. 120 | """ 121 | rank, world_size = get_dist_info() 122 | if dist: 123 | sampler = DistributedSampler( 124 | dataset, world_size, rank, shuffle=shuffle) 125 | shuffle = False 126 | batch_size = samples_per_gpu 127 | num_workers = workers_per_gpu 128 | else: 129 | sampler = None 130 | batch_size = num_gpus * samples_per_gpu 131 | num_workers = num_gpus * workers_per_gpu 132 | 133 | init_fn = partial( 134 | worker_init_fn, num_workers=num_workers, rank=rank, 135 | seed=seed) if seed is not None else None 136 | 137 | if digit_version(torch.__version__) >= digit_version('1.8.0'): 138 | data_loader = DataLoader( 139 | dataset, 140 | batch_size=batch_size, 141 | sampler=sampler, 142 | num_workers=num_workers, 143 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), 144 | pin_memory=pin_memory, 145 | shuffle=shuffle, 146 | worker_init_fn=init_fn, 147 | drop_last=drop_last, 148 | persistent_workers=persistent_workers, 149 | **kwargs) 150 | else: 151 | data_loader = DataLoader( 152 | dataset, 153 | batch_size=batch_size, 154 | sampler=sampler, 155 | num_workers=num_workers, 156 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), 157 | pin_memory=pin_memory, 158 | shuffle=shuffle, 159 | worker_init_fn=init_fn, 160 | drop_last=drop_last, 161 | **kwargs) 162 | 163 | return data_loader 164 | 165 | 166 | def worker_init_fn(worker_id, num_workers, rank, seed): 167 | """Worker init func for dataloader. 168 | 169 | The seed of each worker equals to num_worker * rank + worker_id + user_seed 170 | 171 | Args: 172 | worker_id (int): Worker id. 173 | num_workers (int): Number of workers. 174 | rank (int): The rank of current process. 175 | seed (int): The random seed to use. 
176 | """ 177 | 178 | worker_seed = num_workers * rank + worker_id + seed 179 | np.random.seed(worker_seed) 180 | random.seed(worker_seed) 181 | -------------------------------------------------------------------------------- /depth/datasets/custom.py: -------------------------------------------------------------------------------- 1 | from logging import raiseExceptions 2 | import os.path as osp 3 | import warnings 4 | from collections import OrderedDict 5 | from functools import reduce 6 | 7 | import mmcv 8 | import numpy as np 9 | from mmcv.utils import print_log 10 | from prettytable import PrettyTable 11 | from torch.utils.data import Dataset 12 | 13 | from depth.core import pre_eval_to_metrics, metrics, eval_metrics 14 | from depth.utils import get_root_logger 15 | from depth.datasets.builder import DATASETS 16 | from depth.datasets.pipelines import Compose 17 | 18 | from depth.ops import resize 19 | 20 | from PIL import Image 21 | 22 | import torch 23 | import os 24 | 25 | 26 | @DATASETS.register_module() 27 | class CustomDepthDataset(Dataset): 28 | """Custom dataset for supervised monocular depth esitmation. 29 | An example of file structure. is as followed. 30 | .. code-block:: none 31 | ├── data 32 | │ ├── custom 33 | │ │ ├── train 34 | │ │ │ ├── rgb 35 | │ │ │ │ ├── 0.xxx 36 | │ │ │ │ ├── 1.xxx 37 | │ │ │ │ ├── 2.xxx 38 | │ │ │ ├── depth 39 | │ │ │ │ ├── 0.xxx 40 | │ │ │ │ ├── 1.xxx 41 | │ │ │ │ ├── 2.xxx 42 | │ │ ├── val 43 | │ │ │ ... 44 | │ │ │ ... 45 | 46 | Args: 47 | pipeline (list[dict]): Processing pipeline 48 | img_dir (str): Path to image directory 49 | data_root (str, optional): Data root for img_dir. 50 | test_mode (bool): test_mode=True 51 | min_depth=1e-3: Default min depth value. 52 | max_depth=10: Default max depth value. 53 | """ 54 | 55 | def __init__(self, 56 | pipeline, 57 | data_root, 58 | test_mode=True, 59 | min_depth=1e-3, 60 | max_depth=10, 61 | depth_scale=1): 62 | 63 | self.pipeline = Compose(pipeline) 64 | self.img_path = os.path.join(data_root, 'rgb') 65 | self.depth_path = os.path.join(data_root, 'depth') 66 | self.test_mode = test_mode 67 | self.min_depth = min_depth 68 | self.max_depth = max_depth 69 | self.depth_scale = depth_scale 70 | 71 | # load annotations 72 | self.img_infos = self.load_annotations(self.img_path, self.depth_path) 73 | 74 | 75 | def __len__(self): 76 | """Total number of samples of data.""" 77 | return len(self.img_infos) 78 | 79 | def load_annotations(self, img_dir, depth_dir): 80 | """Load annotation from directory. 81 | Args: 82 | img_dir (str): Path to image directory. Load all the images under the root. 83 | Returns: 84 | list[dict]: All image info of dataset. 
85 | """ 86 | 87 | img_infos = [] 88 | 89 | imgs = os.listdir(img_dir) 90 | imgs.sort() 91 | 92 | if self.test_mode is not True: 93 | depths = os.listdir(depth_dir) 94 | depths.sort() 95 | 96 | for img, depth in zip(imgs, depths): 97 | img_info = dict() 98 | img_info['filename'] = img 99 | img_info['ann'] = dict(depth_map=depth) 100 | img_infos.append(img_info) 101 | 102 | else: 103 | 104 | for img in imgs: 105 | img_info = dict() 106 | img_info['filename'] = img 107 | img_infos.append(img_info) 108 | 109 | # github issue:: make sure the same order 110 | img_infos = sorted(img_infos, key=lambda x: x['filename']) 111 | print_log(f'Loaded {len(img_infos)} images.', logger=get_root_logger()) 112 | 113 | return img_infos 114 | 115 | def pre_pipeline(self, results): 116 | """Prepare results dict for pipeline.""" 117 | results['depth_fields'] = [] 118 | results['img_prefix'] = self.img_path 119 | results['depth_prefix'] = self.depth_path 120 | results['depth_scale'] = self.depth_scale 121 | 122 | def __getitem__(self, idx): 123 | """Get training/test data after pipeline. 124 | Args: 125 | idx (int): Index of data. 126 | Returns: 127 | dict: Training/test data (with annotation if `test_mode` is set 128 | False). 129 | """ 130 | if self.test_mode: 131 | return self.prepare_test_img(idx) 132 | else: 133 | return self.prepare_train_img(idx) 134 | 135 | def prepare_train_img(self, idx): 136 | """Get training data and annotations after pipeline. 137 | Args: 138 | idx (int): Index of data. 139 | Returns: 140 | dict: Training data and annotation after pipeline with new keys 141 | introduced by pipeline. 142 | """ 143 | 144 | img_info = self.img_infos[idx] 145 | ann_info = self.get_ann_info(idx) 146 | results = dict(img_info=img_info, ann_info=ann_info) 147 | self.pre_pipeline(results) 148 | return self.pipeline(results) 149 | 150 | def prepare_test_img(self, idx): 151 | """Get testing data after pipeline. 152 | Args: 153 | idx (int): Index of data. 154 | Returns: 155 | dict: Testing data after pipeline with new keys introduced by 156 | pipeline. 157 | """ 158 | 159 | img_info = self.img_infos[idx] 160 | results = dict(img_info=img_info) 161 | self.pre_pipeline(results) 162 | return self.pipeline(results) 163 | 164 | def get_ann_info(self, idx): 165 | """Get annotation by index. 166 | Args: 167 | idx (int): Index of data. 168 | Returns: 169 | dict: Annotation info of specified index. 170 | """ 171 | 172 | return self.img_infos[idx]['ann'] 173 | 174 | # waiting to be done 175 | def format_results(self, results, imgfile_prefix=None, indices=None, **kwargs): 176 | """Place holder to format result to dataset specific output.""" 177 | results[0] = (results[0] * self.depth_scale) # Do not convert to np.uint16 for ensembling. # .astype(np.uint16) 178 | return results 179 | 180 | # design your own evaluation pipeline 181 | def pre_eval(self, preds, indices): 182 | pass 183 | 184 | def evaluate(self, results, metric='eigen', logger=None, **kwargs): 185 | pass -------------------------------------------------------------------------------- /depth/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 3 | 4 | from .builder import DATASETS 5 | 6 | # TODO: may need change 7 | @DATASETS.register_module() 8 | class ConcatDataset(_ConcatDataset): 9 | """A wrapper of concatenated dataset. 
10 | 11 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 12 | concat the group flag for image aspect ratio. 13 | 14 | Args: 15 | datasets (list[:obj:`Dataset`]): A list of datasets. 16 | """ 17 | 18 | def __init__(self, datasets): 19 | super(ConcatDataset, self).__init__(datasets) 20 | self.CLASSES = datasets[0].CLASSES 21 | 22 | 23 | @DATASETS.register_module() 24 | class RepeatDataset(object): 25 | """A wrapper of repeated dataset. 26 | 27 | The length of repeated dataset will be `times` larger than the original 28 | dataset. This is useful when the data loading time is long but the dataset 29 | is small. Using RepeatDataset can reduce the data loading time between 30 | epochs. 31 | 32 | Args: 33 | dataset (:obj:`Dataset`): The dataset to be repeated. 34 | times (int): Repeat times. 35 | """ 36 | 37 | def __init__(self, dataset, times): 38 | self.dataset = dataset 39 | self.times = times 40 | self.CLASSES = dataset.CLASSES 41 | self._ori_len = len(self.dataset) 42 | 43 | def __getitem__(self, idx): 44 | """Get item from original dataset.""" 45 | return self.dataset[idx % self._ori_len] 46 | 47 | def __len__(self): 48 | """The length is multiplied by ``times``""" 49 | return self.times * self._ori_len 50 | -------------------------------------------------------------------------------- /depth/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .compose import Compose 3 | from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, 4 | Transpose, to_tensor) 5 | from .test_time_aug import MultiScaleFlipAug 6 | 7 | from .loading import DepthLoadAnnotations, DisparityLoadAnnotations, LoadImageFromFile, LoadKITTICamIntrinsic 8 | from .transforms import KBCrop, RandomRotate, RandomFlip, RandomCrop, NYUCrop, NYUMask, Resize, Normalize,\ 9 | RandomChannelSwap, RandomMask 10 | from .formating import DefaultFormatBundle 11 | 12 | __all__ = [ 13 | 'Compose', 'Collect', 'ImageToTensor', 'ToDataContainer', 'ToTensor', 14 | 'Transpose', 'to_tensor', 'MultiScaleFlipAug', 15 | 16 | 'DepthLoadAnnotations', 'KBCrop', 'RandomRotate', 'RandomFlip', 'RandomCrop', 'DefaultFormatBundle', 17 | 'NYUCrop', 'NYUMask', 'DisparityLoadAnnotations', 'Resize', 'LoadImageFromFile', 'Normalize', 'LoadKITTICamIntrinsic', 18 | 'RandomChannelSwap' 19 | ] -------------------------------------------------------------------------------- /depth/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import collections 3 | from ..builder import PIPELINES 4 | from mmcv.utils import build_from_cfg 5 | 6 | 7 | @PIPELINES.register_module() 8 | class Compose(object): 9 | """Compose multiple transforms sequentially. 10 | 11 | Args: 12 | transforms (Sequence[dict | callable]): Sequence of transform object or 13 | config dict to be composed. 14 | """ 15 | def __init__(self, transforms): 16 | assert isinstance(transforms, collections.abc.Sequence) 17 | self.transforms = [] 18 | for transform in transforms: 19 | if isinstance(transform, dict): 20 | transform = build_from_cfg(transform, PIPELINES) 21 | self.transforms.append(transform) 22 | elif callable(transform): 23 | self.transforms.append(transform) 24 | else: 25 | raise TypeError('transform must be callable or a dict') 26 | 27 | def __call__(self, data): 28 | """Call function to apply transforms sequentially. 
29 | 
30 | Args:
31 | data (dict): A result dict that contains the data to transform.
32 | 
33 | Returns:
34 | dict: Transformed data.
35 | """
36 | 
37 | for t in self.transforms:
38 | data = t(data)
39 | if data is None:
40 | return None
41 | return data
42 | 
43 | def __repr__(self):
44 | format_string = self.__class__.__name__ + '('
45 | for t in self.transforms:
46 | format_string += '\n'
47 | format_string += f' {t}'
48 | format_string += '\n)'
49 | return format_string
50 | 
-------------------------------------------------------------------------------- /depth/datasets/pipelines/test_time_aug.py: --------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcv
3 | import warnings
4 | from .compose import Compose
5 | from ..builder import PIPELINES
6 | 
7 | 
8 | @PIPELINES.register_module()
9 | class MultiScaleFlipAug(object):
10 | """Test-time augmentation with multiple scales and flipping.
11 | 
12 | An example configuration is as follows:
13 | 
14 | .. code-block::
15 | 
16 | img_scale=(2048, 1024),
17 | img_ratios=[0.5, 1.0],
18 | flip=True,
19 | transforms=[
20 | dict(type='Resize', keep_ratio=True),
21 | dict(type='RandomFlip'),
22 | dict(type='Normalize', **img_norm_cfg),
23 | dict(type='Pad', size_divisor=32),
24 | dict(type='ImageToTensor', keys=['img']),
25 | dict(type='Collect', keys=['img']),
26 | ]
27 | 
28 | After MultiScaleFlipAug with the above configuration, the results are wrapped
29 | into lists of the same length, as follows:
30 | 
31 | .. code-block::
32 | 
33 | dict(
34 | img=[...],
35 | img_shape=[...],
36 | scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)],
37 | flip=[False, True, False, True],
38 | ...
39 | )
40 | 
41 | Args:
42 | transforms (list[dict]): Transforms to apply in each augmentation.
43 | img_scale (None | tuple | list[tuple]): Image scales for resizing.
44 | img_ratios (float | list[float]): Image ratios for resizing.
45 | flip (bool): Whether to apply flip augmentation. Default: False.
46 | flip_direction (str | list[str]): Flip augmentation directions,
47 | options are "horizontal" and "vertical". If flip_direction is a list,
48 | multiple flip augmentations will be applied.
49 | It has no effect when flip == False. Default: "horizontal".
50 | """ 51 | def __init__(self, 52 | transforms, 53 | img_scale, 54 | img_ratios=None, 55 | flip=False, 56 | flip_direction='horizontal'): 57 | self.transforms = Compose(transforms) 58 | if img_ratios is not None: 59 | img_ratios = img_ratios if isinstance(img_ratios, 60 | list) else [img_ratios] 61 | assert mmcv.is_list_of(img_ratios, float) 62 | if img_scale is None: 63 | # mode 1: given img_scale=None and a range of image ratio 64 | self.img_scale = None 65 | assert mmcv.is_list_of(img_ratios, float) 66 | elif isinstance(img_scale, tuple) and mmcv.is_list_of( 67 | img_ratios, float): 68 | assert len(img_scale) == 2 69 | # mode 2: given a scale and a range of image ratio 70 | self.img_scale = [(int(img_scale[0] * ratio), 71 | int(img_scale[1] * ratio)) 72 | for ratio in img_ratios] 73 | else: 74 | # mode 3: given multiple scales 75 | self.img_scale = img_scale if isinstance(img_scale, 76 | list) else [img_scale] 77 | assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None 78 | self.flip = flip 79 | self.img_ratios = img_ratios 80 | self.flip_direction = flip_direction if isinstance( 81 | flip_direction, list) else [flip_direction] 82 | assert mmcv.is_list_of(self.flip_direction, str) 83 | if not self.flip and self.flip_direction != ['horizontal']: 84 | warnings.warn( 85 | 'flip_direction has no effect when flip is set to False') 86 | if (self.flip 87 | and not any([t['type'] == 'RandomFlip' for t in transforms])): 88 | warnings.warn( 89 | 'flip has no effect when RandomFlip is not in transforms') 90 | 91 | def __call__(self, results): 92 | """Call function to apply test time augment transforms on results. 93 | 94 | Args: 95 | results (dict): Result dict contains the data to transform. 96 | 97 | Returns: 98 | dict[str: list]: The augmented data, where each value is wrapped 99 | into a list. 
100 | """ 101 | 102 | aug_data = [] 103 | if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): 104 | h, w = results['img'].shape[:2] 105 | img_scale = [(int(w * ratio), int(h * ratio)) 106 | for ratio in self.img_ratios] 107 | else: 108 | img_scale = self.img_scale 109 | flip_aug = [False, True] if self.flip else [False] 110 | for scale in img_scale: 111 | for flip in flip_aug: 112 | for direction in self.flip_direction: 113 | _results = results.copy() 114 | _results['scale'] = scale 115 | _results['flip'] = flip 116 | _results['flip_direction'] = direction 117 | data = self.transforms(_results) 118 | aug_data.append(data) 119 | # list of dict to dict of list 120 | aug_data_dict = {key: [] for key in aug_data[0]} 121 | for data in aug_data: 122 | for key, val in data.items(): 123 | aug_data_dict[key].append(val) 124 | 125 | # print(aug_data_dict['img'][0].shape) 126 | 127 | return aug_data_dict 128 | 129 | def __repr__(self): 130 | repr_str = self.__class__.__name__ 131 | repr_str += f'(transforms={self.transforms}, ' 132 | repr_str += f'img_scale={self.img_scale}, flip={self.flip})' 133 | repr_str += f'flip_direction={self.flip_direction}' 134 | return repr_str 135 | -------------------------------------------------------------------------------- /depth/mmcv_custom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import load_checkpoint 4 | from .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor 5 | from .resize_transform import SETR_Resize 6 | from .apex_runner.optimizer import DistOptimizerHook 7 | from .train_api import train_segmentor 8 | 9 | __all__ = ['load_checkpoint', 'LayerDecayOptimizerConstructor', 'SETR_Resize', 'DistOptimizerHook', 'train_segmentor'] 10 | -------------------------------------------------------------------------------- /depth/mmcv_custom/apex_runner/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | from .checkpoint import save_checkpoint 3 | from .apex_iter_based_runner import IterBasedRunnerAmp 4 | 5 | 6 | __all__ = [ 7 | 'save_checkpoint', 'IterBasedRunnerAmp', 8 | ] 9 | -------------------------------------------------------------------------------- /depth/mmcv_custom/apex_runner/apex_iter_based_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | import os.path as osp 3 | import platform 4 | import shutil 5 | 6 | import torch 7 | from torch.optim import Optimizer 8 | 9 | import mmcv 10 | from mmcv.runner import RUNNERS, IterBasedRunner 11 | from .checkpoint import save_checkpoint 12 | 13 | # try: 14 | # import apex 15 | # except: 16 | # print('apex is not installed') 17 | 18 | 19 | @RUNNERS.register_module() 20 | class IterBasedRunnerAmp(IterBasedRunner): 21 | """Iteration-based Runner with AMP support. 22 | 23 | This runner train models iteration by iteration. 24 | """ 25 | 26 | def save_checkpoint(self, 27 | out_dir, 28 | filename_tmpl='iter_{}.pth', 29 | meta=None, 30 | save_optimizer=True, 31 | create_symlink=False): 32 | """Save checkpoint to file. 33 | 34 | Args: 35 | out_dir (str): Directory to save checkpoint files. 36 | filename_tmpl (str, optional): Checkpoint file template. 37 | Defaults to 'iter_{}.pth'. 38 | meta (dict, optional): Metadata to be saved in checkpoint. 39 | Defaults to None. 
40 | save_optimizer (bool, optional): Whether to save the optimizer.
41 | Defaults to True.
42 | create_symlink (bool, optional): Whether to create a symlink to the
43 | latest checkpoint file. Defaults to False.
44 | """
45 | if meta is None:
46 | meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
47 | elif isinstance(meta, dict):
48 | meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
49 | else:
50 | raise TypeError(
51 | f'meta should be a dict or None, but got {type(meta)}')
52 | if self.meta is not None:
53 | meta.update(self.meta)
54 | 
55 | filename = filename_tmpl.format(self.iter + 1)
56 | filepath = osp.join(out_dir, filename)
57 | optimizer = self.optimizer if save_optimizer else None
58 | save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
59 | # in some environments, `os.symlink` is not supported, you may need to
60 | # set `create_symlink` to False
61 | # if create_symlink:
62 | # dst_file = osp.join(out_dir, 'latest.pth')
63 | # if platform.system() != 'Windows':
64 | # mmcv.symlink(filename, dst_file)
65 | # else:
66 | # shutil.copy(filepath, dst_file)
67 | 
68 | def resume(self,
69 | checkpoint,
70 | resume_optimizer=True,
71 | map_location='default'):
72 | if map_location == 'default':
73 | if torch.cuda.is_available():
74 | device_id = torch.cuda.current_device()
75 | checkpoint = self.load_checkpoint(
76 | checkpoint,
77 | map_location=lambda storage, loc: storage.cuda(device_id))
78 | else:
79 | checkpoint = self.load_checkpoint(checkpoint)
80 | else:
81 | checkpoint = self.load_checkpoint(
82 | checkpoint, map_location=map_location)
83 | 
84 | self._epoch = checkpoint['meta']['epoch']
85 | self._iter = checkpoint['meta']['iter']
86 | self._inner_iter = checkpoint['meta']['iter']
87 | if 'optimizer' in checkpoint and resume_optimizer:
88 | if isinstance(self.optimizer, Optimizer):
89 | self.optimizer.load_state_dict(checkpoint['optimizer'])
90 | elif isinstance(self.optimizer, dict):
91 | for k in self.optimizer.keys():
92 | self.optimizer[k].load_state_dict(
93 | checkpoint['optimizer'][k])
94 | else:
95 | raise TypeError(
96 | 'Optimizer should be dict or torch.optim.Optimizer '
97 | f'but got {type(self.optimizer)}')
98 | 
99 | # if 'amp' in checkpoint:
100 | # apex.amp.load_state_dict(checkpoint['amp'])
101 | # self.logger.info('load amp state dict')
102 | 
103 | self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
104 | 
-------------------------------------------------------------------------------- /depth/mmcv_custom/apex_runner/checkpoint.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | import os.path as osp
3 | import time
4 | from tempfile import TemporaryDirectory
5 | 
6 | import torch
7 | from torch.optim import Optimizer
8 | 
9 | import mmcv
10 | from mmcv.parallel import is_module_wrapper
11 | from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
12 | 
13 | # try:
14 | # import apex
15 | # except:
16 | # print('apex is not installed')
17 | 
18 | 
19 | def save_checkpoint(model, filename, optimizer=None, meta=None):
20 | """Save checkpoint to file.
21 | 
22 | The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
23 | ``optimizer`` (saving the ``amp`` state is disabled in this fork).
24 | By default ``meta`` will contain version and time info.
25 | 
26 | Args:
27 | model (Module): Module whose params are to be saved.
28 | filename (str): Checkpoint filename.
29 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
30 | meta (dict, optional): Metadata to be saved in checkpoint. 31 | """ 32 | if meta is None: 33 | meta = {} 34 | elif not isinstance(meta, dict): 35 | raise TypeError(f'meta must be a dict or None, but got {type(meta)}') 36 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) 37 | 38 | if is_module_wrapper(model): 39 | model = model.module 40 | 41 | if hasattr(model, 'CLASSES') and model.CLASSES is not None: 42 | # save class name to the meta 43 | meta.update(CLASSES=model.CLASSES) 44 | 45 | checkpoint = { 46 | 'meta': meta, 47 | 'state_dict': weights_to_cpu(get_state_dict(model)) 48 | } 49 | # save optimizer state dict in the checkpoint 50 | if isinstance(optimizer, Optimizer): 51 | checkpoint['optimizer'] = optimizer.state_dict() 52 | elif isinstance(optimizer, dict): 53 | checkpoint['optimizer'] = {} 54 | for name, optim in optimizer.items(): 55 | checkpoint['optimizer'][name] = optim.state_dict() 56 | 57 | # save amp state dict in the checkpoint 58 | # checkpoint['amp'] = apex.amp.state_dict() 59 | 60 | if filename.startswith('pavi://'): 61 | try: 62 | from pavi import modelcloud 63 | from pavi.exception import NodeNotFoundError 64 | except ImportError: 65 | raise ImportError( 66 | 'Please install pavi to load checkpoint from modelcloud.') 67 | model_path = filename[7:] 68 | root = modelcloud.Folder() 69 | model_dir, model_name = osp.split(model_path) 70 | try: 71 | model = modelcloud.get(model_dir) 72 | except NodeNotFoundError: 73 | model = root.create_training_model(model_dir) 74 | with TemporaryDirectory() as tmp_dir: 75 | checkpoint_file = osp.join(tmp_dir, model_name) 76 | with open(checkpoint_file, 'wb') as f: 77 | torch.save(checkpoint, f) 78 | f.flush() 79 | model.create_file(checkpoint_file, name=model_name) 80 | else: 81 | mmcv.mkdir_or_exist(osp.dirname(filename)) 82 | # immediately flush buffer 83 | with open(filename, 'wb') as f: 84 | torch.save(checkpoint, f) 85 | f.flush() 86 | -------------------------------------------------------------------------------- /depth/mmcv_custom/apex_runner/optimizer.py: -------------------------------------------------------------------------------- 1 | from mmcv.runner import OptimizerHook, HOOKS 2 | # try: 3 | # import apex 4 | # except: 5 | # print('apex is not installed') 6 | 7 | 8 | @HOOKS.register_module() 9 | class DistOptimizerHook(OptimizerHook): 10 | """Optimizer hook for distributed training.""" 11 | 12 | def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False): 13 | self.grad_clip = grad_clip 14 | self.coalesce = coalesce 15 | self.bucket_size_mb = bucket_size_mb 16 | self.update_interval = update_interval 17 | self.use_fp16 = use_fp16 18 | 19 | def before_run(self, runner): 20 | runner.optimizer.zero_grad() 21 | 22 | def after_train_iter(self, runner): 23 | runner.outputs['loss'] /= self.update_interval 24 | # if self.use_fp16: 25 | # with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss: 26 | # scaled_loss.backward() 27 | # else: 28 | # runner.outputs['loss'].backward() 29 | runner.outputs['loss'].backward() 30 | if self.every_n_iters(runner, self.update_interval): 31 | if self.grad_clip is not None: 32 | self.clip_grads(runner.model.parameters()) 33 | runner.optimizer.step() 34 | runner.optimizer.zero_grad() 35 | -------------------------------------------------------------------------------- /depth/mmcv_custom/customized_text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta 
Platforms, Inc. and affiliates.
2 | 
3 | # All rights reserved.
4 | 
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | 
9 | import datetime
10 | from collections import OrderedDict
11 | 
12 | import torch
13 | 
14 | import mmcv
15 | from mmcv.runner import HOOKS
16 | from mmcv.runner import TextLoggerHook
17 | 
18 | 
19 | @HOOKS.register_module()
20 | class CustomizedTextLoggerHook(TextLoggerHook):
21 | """Customized Text Logger hook.
22 | 
23 | This logger prints out both lr and layer_0_lr.
24 | 
25 | """
26 | 
27 | def _log_info(self, log_dict, runner):
28 | # print exp name for users to distinguish experiments
29 | # at every ``interval_exp_name`` iterations and the end of each epoch
30 | if runner.meta is not None and 'exp_name' in runner.meta:
31 | if (self.every_n_iters(runner, self.interval_exp_name)) or (
32 | self.by_epoch and self.end_of_epoch(runner)):
33 | exp_info = f'Exp name: {runner.meta["exp_name"]}'
34 | runner.logger.info(exp_info)
35 | 
36 | if log_dict['mode'] == 'train':
37 | lr_str = {}
38 | for lr_type in ['lr', 'layer_0_lr']:
39 | if isinstance(log_dict[lr_type], dict):
40 | lr_str[lr_type] = []
41 | for k, val in log_dict[lr_type].items():
42 | lr_str[lr_type].append(f'{lr_type}_{k}: {val:.3e}')
43 | lr_str[lr_type] = ' '.join(lr_str[lr_type])
44 | else:
45 | lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}'
46 | 
47 | # by epoch: Epoch [4][100/1000]
48 | # by iter: Iter [100/100000]
49 | if self.by_epoch:
50 | log_str = f'Epoch [{log_dict["epoch"]}]' \
51 | f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
52 | else:
53 | log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
54 | log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, '
55 | 
56 | if 'time' in log_dict.keys():
57 | self.time_sec_tot += (log_dict['time'] * self.interval)
58 | time_sec_avg = self.time_sec_tot / (
59 | runner.iter - self.start_iter + 1)
60 | eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
61 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
62 | log_str += f'eta: {eta_str}, '
63 | log_str += f'time: {log_dict["time"]:.3f}, ' \
64 | f'data_time: {log_dict["data_time"]:.3f}, '
65 | # statistic memory
66 | if torch.cuda.is_available():
67 | log_str += f'memory: {log_dict["memory"]}, '
68 | else:
69 | # val/test time
70 | # here 1000 is the length of the val dataloader
71 | # by epoch: Epoch[val] [4][1000]
72 | # by iter: Iter[val] [1000]
73 | if self.by_epoch:
74 | log_str = f'Epoch({log_dict["mode"]}) ' \
75 | f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
76 | else:
77 | log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
78 | 
79 | log_items = []
80 | for name, val in log_dict.items():
81 | # TODO: resolve this hack
82 | # these items have been in log_str
83 | if name in [
84 | 'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time', 'data_time',
85 | 'memory', 'epoch'
86 | ]:
87 | continue
88 | if isinstance(val, float):
89 | val = f'{val:.4f}'
90 | log_items.append(f'{name}: {val}')
91 | log_str += ', '.join(log_items)
92 | 
93 | runner.logger.info(log_str)
94 | 
95 | 
96 | def log(self, runner):
97 | if 'eval_iter_num' in runner.log_buffer.output:
98 | # this doesn't modify runner.iter and works regardless of by_epoch
99 | cur_iter = runner.log_buffer.output.pop('eval_iter_num')
100 | else:
101 | cur_iter = self.get_iter(runner, inner_iter=True)
102 | 
103 | log_dict = OrderedDict(
104 | mode=self.get_mode(runner),
105 | epoch=self.get_epoch(runner),
106 | iter=cur_iter)
107 | 
108 | # record lr
and layer_0_lr 109 | cur_lr = runner.current_lr() 110 | if isinstance(cur_lr, list): 111 | log_dict['layer_0_lr'] = min(cur_lr) 112 | log_dict['lr'] = max(cur_lr) 113 | else: 114 | assert isinstance(cur_lr, dict) 115 | log_dict['lr'], log_dict['layer_0_lr'] = {}, {} 116 | for k, lr_ in cur_lr.items(): 117 | assert isinstance(lr_, list) 118 | log_dict['layer_0_lr'].update({k: min(lr_)}) 119 | log_dict['lr'].update({k: max(lr_)}) 120 | 121 | if 'time' in runner.log_buffer.output: 122 | # statistic memory 123 | if torch.cuda.is_available(): 124 | log_dict['memory'] = self._get_max_memory(runner) 125 | 126 | log_dict = dict(log_dict, **runner.log_buffer.output) 127 | 128 | self._log_info(log_dict, runner) 129 | self._dump_log(log_dict, runner) 130 | return log_dict 131 | -------------------------------------------------------------------------------- /depth/mmcv_custom/layer_decay_optimizer_constructor.py: -------------------------------------------------------------------------------- 1 | import json 2 | from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor 3 | from mmcv.runner import get_dist_info 4 | 5 | 6 | def get_num_layer_for_vit(var_name, num_max_layer): 7 | if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"): 8 | return 0 9 | elif var_name.startswith("backbone.patch_embed"): 10 | return 0 11 | elif var_name.startswith("backbone.blocks"): 12 | layer_id = int(var_name.split('.')[2]) 13 | return layer_id + 1 14 | else: 15 | return num_max_layer - 1 16 | 17 | 18 | @OPTIMIZER_BUILDERS.register_module() 19 | class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor): 20 | def add_params(self, params, module, prefix='', is_dcn_module=None): 21 | """Add all parameters of module to the params list. 22 | The parameters of the given module will be added to the list of param 23 | groups, with specific rules defined by paramwise_cfg. 24 | Args: 25 | params (list[dict]): A list of param groups, it will be modified 26 | in place. 27 | module (nn.Module): The module to be added. 28 | prefix (str): The prefix of the module 29 | is_dcn_module (int|float|None): If the current module is a 30 | submodule of DCN, `is_dcn_module` will be passed to 31 | control conv_offset layer's learning rate. Defaults to None. 32 | """ 33 | parameter_groups = {} 34 | print(self.paramwise_cfg) 35 | num_layers = self.paramwise_cfg.get('num_layers') + 2 36 | layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate') 37 | print("Build LayerDecayOptimizerConstructor %f - %d" % (layer_decay_rate, num_layers)) 38 | weight_decay = self.base_wd 39 | 40 | for name, param in module.named_parameters(): 41 | if not param.requires_grad: 42 | continue # frozen weights 43 | if len(param.shape) == 1 or name.endswith(".bias") or name in ('pos_embed', 'cls_token'): 44 | group_name = "no_decay" 45 | this_weight_decay = 0. 
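# NOTE: the branch above exempts 1-D parameters (norm scales and biases)
# and the embedding tokens from weight decay, following common practice
# when fine-tuning transformer backbones; only matrix-shaped weights fall
# through to the "decay" group below.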
46 | else: 47 | group_name = "decay" 48 | this_weight_decay = weight_decay 49 | 50 | layer_id = get_num_layer_for_vit(name, num_layers) 51 | group_name = "layer_%d_%s" % (layer_id, group_name) 52 | 53 | if group_name not in parameter_groups: 54 | scale = layer_decay_rate ** (num_layers - layer_id - 1) 55 | 56 | parameter_groups[group_name] = { 57 | "weight_decay": this_weight_decay, 58 | "params": [], 59 | "param_names": [], 60 | "lr_scale": scale, 61 | "group_name": group_name, 62 | "lr": scale * self.base_lr, 63 | } 64 | 65 | parameter_groups[group_name]["params"].append(param) 66 | parameter_groups[group_name]["param_names"].append(name) 67 | rank, _ = get_dist_info() 68 | if rank == 0: 69 | to_display = {} 70 | for key in parameter_groups: 71 | to_display[key] = { 72 | "param_names": parameter_groups[key]["param_names"], 73 | "lr_scale": parameter_groups[key]["lr_scale"], 74 | "lr": parameter_groups[key]["lr"], 75 | "weight_decay": parameter_groups[key]["weight_decay"], 76 | } 77 | print("Param groups = %s" % json.dumps(to_display, indent=2)) 78 | 79 | # state_dict = module.state_dict() 80 | # for group_name in parameter_groups: 81 | # group = parameter_groups[group_name] 82 | # for name in group["param_names"]: 83 | # group["params"].append(state_dict[name]) 84 | params.extend(parameter_groups.values()) 85 | -------------------------------------------------------------------------------- /depth/mmcv_custom/train_api.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 7 | from mmcv.runner import build_optimizer, build_runner 8 | 9 | from mmseg.core import DistEvalHook, EvalHook 10 | from mmseg.datasets import build_dataloader, build_dataset 11 | from mmseg.utils import get_root_logger 12 | try: 13 | import apex 14 | except: 15 | print('apex is not installed') 16 | 17 | 18 | def set_random_seed(seed, deterministic=False): 19 | """Set random seed. 20 | 21 | Args: 22 | seed (int): Seed to be used. 23 | deterministic (bool): Whether to set the deterministic option for 24 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 25 | to True and `torch.backends.cudnn.benchmark` to False. 26 | Default: False. 
27 | """ 28 | random.seed(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | if deterministic: 33 | torch.backends.cudnn.deterministic = True 34 | torch.backends.cudnn.benchmark = False 35 | 36 | 37 | def train_segmentor(model, 38 | dataset, 39 | cfg, 40 | distributed=False, 41 | validate=False, 42 | timestamp=None, 43 | meta=None): 44 | """Launch segmentor training.""" 45 | logger = get_root_logger(cfg.log_level) 46 | 47 | # prepare data loaders 48 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 49 | data_loaders = [ 50 | build_dataloader( 51 | ds, 52 | cfg.data.samples_per_gpu, 53 | cfg.data.workers_per_gpu, 54 | # cfg.gpus will be ignored if distributed 55 | len(cfg.gpu_ids), 56 | dist=distributed, 57 | seed=cfg.seed, 58 | drop_last=True) for ds in dataset 59 | ] 60 | 61 | # build optimizer 62 | optimizer = build_optimizer(model, cfg.optimizer) 63 | 64 | # use apex fp16 optimizer 65 | if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook": 66 | if cfg.optimizer_config.get("use_fp16", False): 67 | model, optimizer = apex.amp.initialize( 68 | model.cuda(), optimizer, opt_level="O1") 69 | for m in model.modules(): 70 | if hasattr(m, "fp16_enabled"): 71 | m.fp16_enabled = True 72 | 73 | # put model on gpus 74 | if distributed: 75 | find_unused_parameters = cfg.get('find_unused_parameters', False) 76 | # Sets the `find_unused_parameters` parameter in 77 | # torch.nn.parallel.DistributedDataParallel 78 | model = MMDistributedDataParallel( 79 | model.cuda(), 80 | device_ids=[torch.cuda.current_device()], 81 | broadcast_buffers=False, 82 | find_unused_parameters=find_unused_parameters) 83 | else: 84 | model = MMDataParallel( 85 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 86 | 87 | if cfg.get('runner') is None: 88 | cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} 89 | warnings.warn( 90 | 'config is now expected to have a `runner` section, ' 91 | 'please set `runner` in your config.', UserWarning) 92 | 93 | runner = build_runner( 94 | cfg.runner, 95 | default_args=dict( 96 | model=model, 97 | batch_processor=None, 98 | optimizer=optimizer, 99 | work_dir=cfg.work_dir, 100 | logger=logger, 101 | meta=meta)) 102 | 103 | # register hooks 104 | runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, 105 | cfg.checkpoint_config, cfg.log_config, 106 | cfg.get('momentum_config', None)) 107 | 108 | # an ugly walkaround to make the .log and .log.json filenames the same 109 | runner.timestamp = timestamp 110 | 111 | # register eval hooks 112 | if validate: 113 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 114 | val_dataloader = build_dataloader( 115 | val_dataset, 116 | samples_per_gpu=1, 117 | workers_per_gpu=cfg.data.workers_per_gpu, 118 | dist=distributed, 119 | shuffle=False) 120 | eval_cfg = cfg.get('evaluation', {}) 121 | eval_cfg['by_epoch'] = 'IterBasedRunner' not in cfg.runner['type'] 122 | eval_hook = DistEvalHook if distributed else EvalHook 123 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 124 | 125 | if cfg.resume_from: 126 | runner.resume(cfg.resume_from) 127 | elif cfg.load_from: 128 | runner.load_checkpoint(cfg.load_from) 129 | runner.run(data_loaders, cfg.workflow) 130 | -------------------------------------------------------------------------------- /depth/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. 
All rights reserved. 2 | from .backbones import * # noqa: F401,F403 3 | from .builder import (BACKBONES, HEADS, LOSSES, DEPTHER, build_backbone, 4 | build_head, build_loss, build_depther) 5 | from .decode_heads import * # noqa: F401,F403 6 | from .losses import * # noqa: F401,F403 7 | from .necks import * # noqa: F401,F403 8 | from .depther import * # noqa: F401,F403 9 | 10 | __all__ = [ 11 | 'BACKBONES', 'HEADS', 'LOSSES', 'build_backbone', 12 | 'build_head', 'build_loss', 'DEPTHER', 'build_depther' 13 | ] 14 | -------------------------------------------------------------------------------- /depth/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .swin_transformer import Swin 2 | from .xcit import XCiT -------------------------------------------------------------------------------- /depth/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | from mmcv.cnn import MODELS as MMCV_MODELS 5 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 6 | from mmcv.utils import Registry 7 | 8 | MODELS = Registry('models', parent=MMCV_MODELS) 9 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 10 | 11 | BACKBONES = MODELS 12 | NECKS = MODELS 13 | HEADS = MODELS 14 | LOSSES = MODELS 15 | DEPTHER = MODELS 16 | 17 | def build_backbone(cfg): 18 | """Build backbone.""" 19 | return BACKBONES.build(cfg) 20 | 21 | def build_neck(cfg): 22 | """Build neck.""" 23 | return NECKS.build(cfg) 24 | 25 | def build_head(cfg): 26 | """Build head.""" 27 | return HEADS.build(cfg) 28 | 29 | def build_loss(cfg): 30 | """Build loss.""" 31 | return LOSSES.build(cfg) 32 | 33 | def build_depther(cfg, train_cfg=None, test_cfg=None): 34 | """Build depther.""" 35 | if train_cfg is not None or test_cfg is not None: 36 | warnings.warn( 37 | 'train_cfg and test_cfg is deprecated, ' 38 | 'please specify them in model', UserWarning) 39 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 40 | 'train_cfg specified in both outer field and model field ' 41 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 42 | 'test_cfg specified in both outer field and model field ' 43 | return DEPTHER.build( 44 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 45 | 46 | 47 | -------------------------------------------------------------------------------- /depth/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import update_wrapper 2 | from .trap_head import TrappedHead 3 | from .trap1 import TrappedHead1 4 | from .trap2 import TrappedHead2 -------------------------------------------------------------------------------- /depth/models/depther/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDepther 2 | from .encoder_decoder import DepthEncoderDecoder -------------------------------------------------------------------------------- /depth/models/depther/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | from depth.models import depther 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from depth.core import add_prefix 7 | from depth.ops import resize 8 | from depth.models import builder 9 | from depth.models.builder import DEPTHER 10 | from .base import BaseDepther 11 
| 12 | # for model size 13 | import numpy as np 14 | 15 | @DEPTHER.register_module() 16 | class DepthEncoderDecoder(BaseDepther): 17 | """Encoder Decoder depther. 18 | 19 | EncoderDecoder typically consists of backbone, (neck) and decode_head. 20 | """ 21 | 22 | def __init__(self, 23 | backbone, 24 | decode_head, 25 | neck=None, 26 | train_cfg=None, 27 | test_cfg=None, 28 | pretrained=None, 29 | init_cfg=None): 30 | super(DepthEncoderDecoder, self).__init__(init_cfg) 31 | if pretrained is not None: 32 | assert backbone.get('pretrained') is None, \ 33 | 'both backbone and depther set pretrained weight' 34 | backbone.pretrained = pretrained 35 | self.backbone = builder.build_backbone(backbone) 36 | self._init_decode_head(decode_head) 37 | 38 | if neck is not None: 39 | self.neck = builder.build_neck(neck) 40 | 41 | self.train_cfg = train_cfg 42 | self.test_cfg = test_cfg 43 | 44 | assert self.with_decode_head 45 | 46 | def _init_decode_head(self, decode_head): 47 | """Initialize ``decode_head``""" 48 | self.decode_head = builder.build_head(decode_head) 49 | self.align_corners = self.decode_head.align_corners 50 | 51 | def extract_feat(self, img): 52 | """Extract features from images.""" 53 | x = self.backbone(img) 54 | if self.with_neck: 55 | x = self.neck(x) 56 | return x 57 | 58 | def encode_decode(self, img, img_metas, rescale=True): 59 | """Encode images with backbone and decode into a depth estimation 60 | map of the same size as input.""" 61 | 62 | x = self.extract_feat(img) 63 | out = self._decode_head_forward_test(x, img_metas) 64 | # crop the pred depth to the certain range. 65 | out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) 66 | if rescale: 67 | out = resize( 68 | input=out, 69 | size=img.shape[2:], 70 | mode='bilinear', 71 | align_corners=self.align_corners) 72 | return out 73 | 74 | def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): 75 | """Run forward function and calculate loss for decode head in 76 | training.""" 77 | losses = dict() 78 | loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs) 79 | losses.update(add_prefix(loss_decode, 'decode')) 80 | return losses 81 | 82 | def _decode_head_forward_test(self, x, img_metas): 83 | """Run forward function and calculate loss for decode head in 84 | inference.""" 85 | depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg) 86 | return depth_pred 87 | 88 | def forward_dummy(self, img): 89 | """Dummy forward function.""" 90 | depth = self.encode_decode(img, None) 91 | 92 | return depth 93 | 94 | def forward_train(self, img, img_metas, depth_gt, **kwargs): 95 | """Forward function for training. 96 | 97 | Args: 98 | img (Tensor): Input images. 99 | img_metas (list[dict]): List of image info dict where each dict 100 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 101 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 102 | For details on the values of these keys see 103 | `depth/datasets/pipelines/formatting.py:Collect`. 104 | depth_gt (Tensor): Depth gt 105 | used if the architecture supports depth estimation task. 
106 | 107 | Returns: 108 | dict[str, Tensor]: a dictionary of loss components 109 | """ 110 | 111 | x = self.extract_feat(img) 112 | 113 | losses = dict() 114 | 115 | # the last of x saves the info from neck 116 | loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) 117 | 118 | losses.update(loss_decode) 119 | 120 | return losses 121 | 122 | def whole_inference(self, img, img_meta, rescale): 123 | """Inference with full image.""" 124 | 125 | depth_pred = self.encode_decode(img, img_meta, rescale) 126 | 127 | return depth_pred 128 | 129 | def inference(self, img, img_meta, rescale): 130 | """Inference with slide/whole style. 131 | 132 | Args: 133 | img (Tensor): The input image of shape (N, 3, H, W). 134 | img_meta (dict): Image info dict where each dict has: 'img_shape', 135 | 'scale_factor', 'flip', and may also contain 136 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 137 | For details on the values of these keys see 138 | `depth/datasets/pipelines/formatting.py:Collect`. 139 | rescale (bool): Whether rescale back to original shape. 140 | 141 | Returns: 142 | Tensor: The output depth map. 143 | """ 144 | 145 | assert self.test_cfg.mode in ['slide', 'whole'] 146 | ori_shape = img_meta[0]['ori_shape'] 147 | assert all(_['ori_shape'] == ori_shape for _ in img_meta) 148 | if self.test_cfg.mode == 'slide': 149 | raise NotImplementedError 150 | else: 151 | depth_pred = self.whole_inference(img, img_meta, rescale) 152 | output = depth_pred 153 | flip = img_meta[0]['flip'] 154 | if flip: 155 | flip_direction = img_meta[0]['flip_direction'] 156 | assert flip_direction in ['horizontal', 'vertical'] 157 | if flip_direction == 'horizontal': 158 | output = output.flip(dims=(3, )) 159 | elif flip_direction == 'vertical': 160 | output = output.flip(dims=(2, )) 161 | 162 | return output 163 | 164 | def simple_test(self, img, img_meta, rescale=True): 165 | """Simple test with single image.""" 166 | depth_pred = self.inference(img, img_meta, rescale) 167 | if torch.onnx.is_in_onnx_export(): 168 | # our inference backend only support 4D output 169 | depth_pred = depth_pred.unsqueeze(0) 170 | return depth_pred 171 | depth_pred = depth_pred.cpu().numpy() 172 | # unravel batch dim 173 | depth_pred = list(depth_pred) 174 | return depth_pred 175 | 176 | def aug_test(self, imgs, img_metas, rescale=True): 177 | """Test with augmentations. 178 | 179 | Only rescale=True is supported. 180 | """ 181 | # aug_test rescale all imgs back to ori_shape for now 182 | assert rescale 183 | # to save memory, we get augmented depth logit inplace 184 | depth_pred = self.inference(imgs[0], img_metas[0], rescale) 185 | for i in range(1, len(imgs)): 186 | cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale) 187 | depth_pred += cur_depth_pred 188 | depth_pred /= len(imgs) 189 | depth_pred = depth_pred.cpu().numpy() 190 | # unravel batch dim 191 | depth_pred = list(depth_pred) 192 | return depth_pred 193 | -------------------------------------------------------------------------------- /depth/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .sigloss import SigLoss -------------------------------------------------------------------------------- /depth/models/losses/sigloss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
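The flip handling in `DepthEncoderDecoder.inference` and the view averaging in `aug_test` above reduce to a few tensor operations. A toy sketch with random tensors, purely for illustration:

```python
# Toy sketch of the test-time flip augmentation handled by `inference`
# above: a prediction made on a horizontally flipped image is flipped
# back along the width axis (dim 3 of an NCHW tensor), then `aug_test`
# averages the per-view predictions.
import torch

pred = torch.rand(1, 1, 4, 6)               # prediction on the original image
pred_on_flipped = pred.flip(dims=(3,))      # stand-in for the flipped-view prediction
restored = pred_on_flipped.flip(dims=(3,))  # undo the flip, as inference() does
merged = (pred + restored) / 2              # average over views, as aug_test() does
assert torch.allclose(merged, pred)
```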
2 | import torch 3 | import torch.nn as nn 4 | 5 | from depth.models.builder import LOSSES 6 | 7 | @LOSSES.register_module() 8 | class SigLoss(nn.Module): 9 | """SigLoss (scale-invariant log loss). 10 | 11 | Args: 12 | valid_mask (bool, optional): Whether to filter out invalid gt (non-positive depth). Defaults to True. 13 | loss_weight (float, optional): Weight of the loss. Defaults to 1.0. max_depth (float, optional): Upper bound of valid gt depth. Defaults to None. warm_up (bool, optional): Use a damped warm-up variant of the loss for the first ``warm_iter`` iterations. Defaults to False. warm_iter (int, optional): Number of warm-up iterations. Defaults to 100. 14 | """ 15 | 16 | def __init__(self, 17 | valid_mask=True, 18 | loss_weight=1.0, 19 | max_depth=None, 20 | warm_up=False, 21 | warm_iter=100): 22 | super(SigLoss, self).__init__() 23 | self.valid_mask = valid_mask 24 | self.loss_weight = loss_weight 25 | self.max_depth = max_depth 26 | 27 | self.eps = 0.1 # avoid gradient explosion 28 | 29 | # HACK: a hacky implementation of warm-up sigloss 30 | self.warm_up = warm_up 31 | self.warm_iter = warm_iter 32 | self.warm_up_counter = 0 33 | self.rms = torch.nn.MSELoss() 34 | 35 | def sigloss(self, input, target): 36 | if self.valid_mask: 37 | valid_mask = target > 0 38 | if self.max_depth is not None: 39 | valid_mask = torch.logical_and(target > 0, target <= self.max_depth) 40 | input = input[valid_mask] 41 | target = target[valid_mask] 42 | 43 | if self.warm_up: 44 | if self.warm_up_counter < self.warm_iter: 45 | g = torch.log(input + self.eps) - torch.log(target + self.eps) 46 | g = 0.15 * torch.pow(torch.mean(g), 2) 47 | self.warm_up_counter += 1 48 | return torch.sqrt(g) 49 | 50 | g = torch.log(input) - torch.log(target) 51 | Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2) 52 | return torch.sqrt(Dg) 53 | 54 | def rmseloss(self, input, target): 55 | if self.valid_mask: 56 | valid_mask = target > 0 57 | if self.max_depth is not None: 58 | valid_mask = torch.logical_and(target > 0, target <= self.max_depth) 59 | input = input[valid_mask] 60 | target = target[valid_mask] 61 | 62 | # return torch.log(torch.sqrt(self.rms(input, target))) 63 | return torch.sqrt(self.rms(input, target)) 64 | 65 | def sqrelloss(self, input, target): 66 | if self.valid_mask: 67 | valid_mask = target > 0 68 | if self.max_depth is not None: 69 | valid_mask = torch.logical_and(target > 0, target <= self.max_depth) 70 | input = input[valid_mask] 71 | target = target[valid_mask] 72 | 73 | return torch.mean(torch.pow(input - target, 2) / target) 74 | 75 | def forward(self, 76 | depth_pred, 77 | depth_gt, 78 | **kwargs): 79 | """Forward function.""" 80 | 81 | loss_depth = self.loss_weight * self.sigloss( 82 | depth_pred, 83 | depth_gt, 84 | ) 85 | 86 | # loss_depth = self.rmseloss(depth_pred, depth_gt) 87 | 88 | # loss_depth = self.loss_weight * self.sigloss( 89 | # depth_pred, 90 | # depth_gt, 91 | # ) + self.rmseloss(depth_pred, depth_gt) 92 | 93 | # loss_depth = self.sigloss( 94 | # depth_pred, 95 | # depth_gt, 96 | # ) + 2 * self.rmseloss(depth_pred, depth_gt) + 4 * self.sqrelloss(depth_pred, depth_gt) 97 | return loss_depth 98 | -------------------------------------------------------------------------------- /depth/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .block_selection import BlockSelectionNeck -------------------------------------------------------------------------------- /depth/models/necks/block_selection.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
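The core of `SigLoss.sigloss` above (outside the warm-up branch) is the scale-invariant log (SILog) error. A standalone sketch with the same 0.15 variance-focus weight, on toy tensors:

```python
# Standalone sketch of the SILog term computed by SigLoss.sigloss above.
import torch

def silog(pred, gt, max_depth=10.0, lam=0.15):
    valid = (gt > 0) & (gt <= max_depth)   # mirrors the valid_mask logic
    g = torch.log(pred[valid]) - torch.log(gt[valid])
    return torch.sqrt(torch.var(g) + lam * torch.mean(g) ** 2)

pred = torch.rand(1, 1, 8, 8) * 9 + 0.5
assert silog(pred, pred.clone()) < 1e-6    # identical inputs give ~zero loss
```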
2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, xavier_init 4 | 5 | from depth.ops import resize 6 | from depth.models.builder import NECKS 7 | 8 | import math 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | from mmcv.runner import BaseModule, auto_fp16 14 | 15 | 16 | class BlocksChosen(nn.Module): 17 | def __init__(self): 18 | super(BlocksChosen, self).__init__() 19 | 20 | def forward(self, xs): 21 | x_max = xs[0] 22 | for x_b in xs[1:]: 23 | x_max = torch.maximum(x_max, x_b) 24 | 25 | return x_max 26 | 27 | def _is_contiguous(tensor: torch.Tensor) -> bool: 28 | # jit is oh so lovely :/ 29 | # if torch.jit.is_tracing(): 30 | # return True 31 | if torch.jit.is_scripting(): 32 | return tensor.is_contiguous() 33 | else: 34 | return tensor.is_contiguous(memory_format=torch.contiguous_format) 35 | 36 | 37 | class LayerNorm2d(nn.LayerNorm): 38 | def __init__(self, normalized_shape, eps=1e-6): 39 | super().__init__(normalized_shape, eps=eps) 40 | 41 | def forward(self, x) -> torch.Tensor: 42 | if _is_contiguous(x): 43 | return F.layer_norm( 44 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) 45 | else: 46 | s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) 47 | x = (x - u) * torch.rsqrt(s + self.eps) 48 | x = x * self.weight[:, None, None] + self.bias[:, None, None] 49 | return x 50 | 51 | @NECKS.register_module() 52 | class BlockSelectionNeck(BaseModule): 53 | def __init__(self, 54 | in_channels=[384]*5, 55 | out_channels=[48, 96, 192, 384, 768], 56 | start=[3, 5, 7, 9, 11], 57 | end=[5, 7, 9, 11, 12], 58 | scales=[4, 2, 1, .5, .25]): 59 | super(BlockSelectionNeck, self).__init__() 60 | assert isinstance(in_channels, list) 61 | self.in_channels = in_channels 62 | self.out_channels = out_channels 63 | self.scales = scales 64 | self.start = start 65 | self.end = end 66 | self.num_outs = len(scales) 67 | 68 | self.blocks_selection = BlocksChosen() 69 | self.trans_proj = nn.ModuleList() 70 | for i in range(len(scales)): 71 | if scales[i] > 1: 72 | self.trans_proj.append(nn.Sequential( 73 | nn.ConvTranspose2d(in_channels[i], in_channels[i], kernel_size=scales[i], stride=scales[i], 74 | groups=in_channels[i]), 75 | LayerNorm2d(in_channels[i]), 76 | nn.Conv2d(in_channels[i], out_channels[i], kernel_size=(1, 1), stride=(1, 1)), 77 | )) 78 | elif scales[i] == 1: 79 | if in_channels[i] == out_channels[i]: 80 | self.trans_proj.append(nn.Identity()) 81 | else: 82 | self.trans_proj.append(nn.Sequential( 83 | LayerNorm2d(in_channels[i]), 84 | nn.Conv2d(in_channels[i], out_channels[i], kernel_size=(1, 1), stride=(1, 1)) 85 | )) 86 | elif scales[i] < 1: 87 | self.trans_proj.append(nn.Sequential( 88 | nn.Conv2d(in_channels[i], in_channels[i], kernel_size=int(1/scales[i]), stride=int(1/scales[i])), 89 | LayerNorm2d(in_channels[i]), 90 | nn.Conv2d(in_channels[i], out_channels[i], kernel_size=(1, 1), stride=(1, 1)), 91 | )) 92 | 93 | 94 | 95 | 96 | # init weight 97 | def init_weights(self): 98 | for p in self.parameters(): 99 | if p.dim() > 1: 100 | nn.init.xavier_uniform_(p) 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | # xavier_init(m, distribution='uniform') 105 | nn.init.trunc_normal_(m.weight, std=0.02) 106 | nn.init.zeros_(m.bias) 107 | elif isinstance(m, LayerNorm2d): 108 | nn.init.constant_(m.bias, 0) 109 | nn.init.constant_(m.weight, 1.0) 110 | 111 | def forward(self, inputs): 112 | assert len(inputs) >= self.end[-1] - 1 113 | outs = [] 114 | # for indices_start, indices_end 
in zip(self.start, self.end): 115 | for i in range(len(self.start)): 116 | feature = self.blocks_selection(inputs[self.start[i]:self.end[i]]) 117 | outs.append(self.trans_proj[i](feature)) 118 | 119 | return tuple(outs) 120 | 121 | -------------------------------------------------------------------------------- /depth/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .ckpt_convert import swin_convert, vit_convert 3 | from .embed import PatchEmbed, PatchEmbedSwin 4 | from .inverted_residual import InvertedResidual, InvertedResidualV3 5 | from .make_divisible import make_divisible 6 | from .res_layer import ResLayer 7 | from .se_layer import SELayer 8 | from .self_attention_block import SelfAttentionBlock 9 | from .shape_convert import nchw_to_nlc, nlc_to_nchw 10 | from .up_conv_block import UpConvBlock, BasicConvBlock 11 | from .logger import get_root_logger 12 | # from .hooks import TensorboardImageLoggerHook 13 | 14 | __all__ = [ 15 | 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 16 | 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert', 17 | 'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'BasicConvBlock', 18 | 'get_root_logger' 19 | ] -------------------------------------------------------------------------------- /depth/models/utils/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.utils import Registry, build_from_cfg 4 | 5 | TRANSFORMER = Registry('Transformer') 6 | 7 | def build_transformer(cfg, default_args=None): 8 | """Builder for Transformer.""" 9 | return build_from_cfg(cfg, TRANSFORMER, default_args) 10 | -------------------------------------------------------------------------------- /depth/models/utils/ckpt_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections import OrderedDict 3 | 4 | 5 | def swin_convert(ckpt): 6 | new_ckpt = OrderedDict() 7 | 8 | def correct_unfold_reduction_order(x): 9 | out_channel, in_channel = x.shape 10 | x = x.reshape(out_channel, 4, in_channel // 4) 11 | x = x[:, [0, 2, 1, 3], :].transpose(1, 12 | 2).reshape(out_channel, in_channel) 13 | return x 14 | 15 | def correct_unfold_norm_order(x): 16 | in_channel = x.shape[0] 17 | x = x.reshape(4, in_channel // 4) 18 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) 19 | return x 20 | 21 | for k, v in ckpt.items(): 22 | if k.startswith('head'): 23 | continue 24 | elif k.startswith('layers'): 25 | new_v = v 26 | if 'attn.' in k: 27 | new_k = k.replace('attn.', 'attn.w_msa.') 28 | elif 'mlp.' in k: 29 | if 'mlp.fc1.' in k: 30 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') 31 | elif 'mlp.fc2.' in k: 32 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') 33 | else: 34 | new_k = k.replace('mlp.', 'ffn.') 35 | elif 'downsample' in k: 36 | new_k = k 37 | if 'reduction.' in k: 38 | new_v = correct_unfold_reduction_order(v) 39 | elif 'norm.' 
in k: 40 | new_v = correct_unfold_norm_order(v) 41 | else: 42 | new_k = k 43 | new_k = new_k.replace('layers', 'stages', 1) 44 | elif k.startswith('patch_embed'): 45 | new_v = v 46 | if 'proj' in k: 47 | new_k = k.replace('proj', 'projection') 48 | else: 49 | new_k = k 50 | else: 51 | new_v = v 52 | new_k = k 53 | 54 | new_ckpt[new_k] = new_v 55 | 56 | return new_ckpt 57 | 58 | 59 | def vit_convert(ckpt): 60 | 61 | new_ckpt = OrderedDict() 62 | 63 | for k, v in ckpt.items(): 64 | if k.startswith('head'): 65 | continue 66 | if k.startswith('norm'): 67 | new_k = k.replace('norm.', 'ln1.') 68 | elif k.startswith('patch_embed'): 69 | if 'proj' in k: 70 | new_k = k.replace('proj', 'projection') 71 | else: 72 | new_k = k 73 | elif k.startswith('blocks'): 74 | if 'norm' in k: 75 | new_k = k.replace('norm', 'ln') 76 | elif 'mlp.fc1' in k: 77 | new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') 78 | elif 'mlp.fc2' in k: 79 | new_k = k.replace('mlp.fc2', 'ffn.layers.1') 80 | elif 'attn.qkv' in k: 81 | new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') 82 | elif 'attn.proj' in k: 83 | new_k = k.replace('attn.proj', 'attn.attn.out_proj') 84 | else: 85 | new_k = k 86 | new_k = new_k.replace('blocks.', 'layers.') 87 | else: 88 | new_k = k 89 | new_ckpt[new_k] = v 90 | 91 | return new_ckpt 92 | -------------------------------------------------------------------------------- /depth/models/utils/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensorboard_hook import TensorboardImageLoggerHook -------------------------------------------------------------------------------- /depth/models/utils/hooks/tensorboard_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
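The key remapping performed by `swin_convert` above is easiest to see on a single toy checkpoint entry (the key below is illustrative, not from a real checkpoint):

```python
# Toy demonstration of the key remapping in swin_convert above: official
# Swin checkpoints use `layers.*` / `mlp.fc1`, while the mmcv-style
# backbone expects `stages.*` / `ffn.layers.0.0`.
from collections import OrderedDict

ckpt = OrderedDict({'layers.0.blocks.0.mlp.fc1.weight': 'w'})
new_ckpt = OrderedDict()
for k, v in ckpt.items():
    k = k.replace('mlp.fc1.', 'ffn.layers.0.0.').replace('layers', 'stages', 1)
    new_ckpt[k] = v

assert 'stages.0.blocks.0.ffn.layers.0.0.weight' in new_ckpt
```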
2 | import os.path as osp 3 | 4 | from mmcv.utils import TORCH_VERSION, digit_version 5 | from mmcv.runner.dist_utils import master_only 6 | from mmcv.runner.hooks import HOOKS 7 | from mmcv.runner.hooks.logger.base import LoggerHook 8 | 9 | 10 | @HOOKS.register_module() 11 | class TensorboardImageLoggerHook(LoggerHook): 12 | 13 | def __init__(self, 14 | log_dir=None, 15 | interval=10, 16 | ignore_last=True, 17 | reset_flag=False, 18 | by_epoch=True): 19 | super(TensorboardImageLoggerHook, self).__init__(interval, ignore_last, 20 | reset_flag, by_epoch) 21 | self.log_dir = log_dir 22 | 23 | @master_only 24 | def before_run(self, runner): 25 | super(TensorboardImageLoggerHook, self).before_run(runner) 26 | if (TORCH_VERSION == 'parrots' 27 | or digit_version(TORCH_VERSION) < digit_version('1.1')): 28 | try: 29 | from tensorboardX import SummaryWriter 30 | except ImportError: 31 | raise ImportError('Please install tensorboardX to use ' 32 | 'TensorboardImageLoggerHook.') 33 | else: 34 | try: 35 | from torch.utils.tensorboard import SummaryWriter 36 | except ImportError: 37 | raise ImportError( 38 | 'Please run "pip install future tensorboard" to install ' 39 | 'the dependencies to use torch.utils.tensorboard ' 40 | '(applicable to PyTorch 1.1 or higher)') 41 | 42 | if self.log_dir is None: 43 | self.log_dir = osp.join(runner.work_dir, 'tf_logs') 44 | self.writer = SummaryWriter(self.log_dir) 45 | 46 | @master_only 47 | def log(self, runner): 48 | if self.get_mode(runner) == 'train': 49 | log_images = runner.outputs.get('log_imgs') 50 | if log_images is not None: 51 | for tag, val in log_images.items(): 52 | self.writer.add_image(f'{self.get_mode(runner)}/{tag}', val, self.get_iter(runner)) 53 | 54 | tags = self.get_loggable_tags(runner, allow_text=True) 55 | for tag, val in tags.items(): 56 | if isinstance(val, str): 57 | self.writer.add_text(tag, val, self.get_iter(runner)) 58 | else: 59 | self.writer.add_scalar(tag, val, self.get_iter(runner)) 60 | 61 | @master_only 62 | def after_run(self, runner): 63 | self.writer.close() 64 | -------------------------------------------------------------------------------- /depth/models/utils/inverted_residual.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import ConvModule 3 | from torch import nn 4 | from torch.utils import checkpoint as cp 5 | 6 | from .se_layer import SELayer 7 | 8 | 9 | class InvertedResidual(nn.Module): 10 | """InvertedResidual block for MobileNetV2. 11 | 12 | Args: 13 | in_channels (int): The input channels of the InvertedResidual block. 14 | out_channels (int): The output channels of the InvertedResidual block. 15 | stride (int): Stride of the middle (first) 3x3 convolution. 16 | expand_ratio (int): Adjusts number of channels of the hidden layer 17 | in InvertedResidual by this amount. 18 | dilation (int): Dilation rate of depthwise conv. Default: 1 19 | conv_cfg (dict): Config dict for convolution layer. 20 | Default: None, which means using conv2d. 21 | norm_cfg (dict): Config dict for normalization layer. 22 | Default: dict(type='BN'). 23 | act_cfg (dict): Config dict for activation layer. 24 | Default: dict(type='ReLU6'). 25 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 26 | memory while slowing down the training speed. Default: False. 27 | 28 | Returns: 29 | Tensor: The output tensor. 
30 | """ 31 | 32 | def __init__(self, 33 | in_channels, 34 | out_channels, 35 | stride, 36 | expand_ratio, 37 | dilation=1, 38 | conv_cfg=None, 39 | norm_cfg=dict(type='BN'), 40 | act_cfg=dict(type='ReLU6'), 41 | with_cp=False, 42 | **kwargs): 43 | super(InvertedResidual, self).__init__() 44 | self.stride = stride 45 | assert stride in [1, 2], f'stride must in [1, 2]. ' \ 46 | f'But received {stride}.' 47 | self.with_cp = with_cp 48 | self.use_res_connect = self.stride == 1 and in_channels == out_channels 49 | hidden_dim = int(round(in_channels * expand_ratio)) 50 | 51 | layers = [] 52 | if expand_ratio != 1: 53 | layers.append( 54 | ConvModule( 55 | in_channels=in_channels, 56 | out_channels=hidden_dim, 57 | kernel_size=1, 58 | conv_cfg=conv_cfg, 59 | norm_cfg=norm_cfg, 60 | act_cfg=act_cfg, 61 | **kwargs)) 62 | layers.extend([ 63 | ConvModule( 64 | in_channels=hidden_dim, 65 | out_channels=hidden_dim, 66 | kernel_size=3, 67 | stride=stride, 68 | padding=dilation, 69 | dilation=dilation, 70 | groups=hidden_dim, 71 | conv_cfg=conv_cfg, 72 | norm_cfg=norm_cfg, 73 | act_cfg=act_cfg, 74 | **kwargs), 75 | ConvModule( 76 | in_channels=hidden_dim, 77 | out_channels=out_channels, 78 | kernel_size=1, 79 | conv_cfg=conv_cfg, 80 | norm_cfg=norm_cfg, 81 | act_cfg=None, 82 | **kwargs) 83 | ]) 84 | self.conv = nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | 88 | def _inner_forward(x): 89 | if self.use_res_connect: 90 | return x + self.conv(x) 91 | else: 92 | return self.conv(x) 93 | 94 | if self.with_cp and x.requires_grad: 95 | out = cp.checkpoint(_inner_forward, x) 96 | else: 97 | out = _inner_forward(x) 98 | 99 | return out 100 | 101 | 102 | class InvertedResidualV3(nn.Module): 103 | """Inverted Residual Block for MobileNetV3. 104 | 105 | Args: 106 | in_channels (int): The input channels of this Module. 107 | out_channels (int): The output channels of this Module. 108 | mid_channels (int): The input channels of the depthwise convolution. 109 | kernel_size (int): The kernel size of the depthwise convolution. 110 | Default: 3. 111 | stride (int): The stride of the depthwise convolution. Default: 1. 112 | se_cfg (dict): Config dict for se layer. Default: None, which means no 113 | se layer. 114 | with_expand_conv (bool): Use expand conv or not. If set False, 115 | mid_channels must be the same with in_channels. Default: True. 116 | conv_cfg (dict): Config dict for convolution layer. Default: None, 117 | which means using conv2d. 118 | norm_cfg (dict): Config dict for normalization layer. 119 | Default: dict(type='BN'). 120 | act_cfg (dict): Config dict for activation layer. 121 | Default: dict(type='ReLU'). 122 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 123 | memory while slowing down the training speed. Default: False. 124 | 125 | Returns: 126 | Tensor: The output tensor. 
127 | """ 128 | 129 | def __init__(self, 130 | in_channels, 131 | out_channels, 132 | mid_channels, 133 | kernel_size=3, 134 | stride=1, 135 | se_cfg=None, 136 | with_expand_conv=True, 137 | conv_cfg=None, 138 | norm_cfg=dict(type='BN'), 139 | act_cfg=dict(type='ReLU'), 140 | with_cp=False): 141 | super(InvertedResidualV3, self).__init__() 142 | self.with_res_shortcut = (stride == 1 and in_channels == out_channels) 143 | assert stride in [1, 2] 144 | self.with_cp = with_cp 145 | self.with_se = se_cfg is not None 146 | self.with_expand_conv = with_expand_conv 147 | 148 | if self.with_se: 149 | assert isinstance(se_cfg, dict) 150 | if not self.with_expand_conv: 151 | assert mid_channels == in_channels 152 | 153 | if self.with_expand_conv: 154 | self.expand_conv = ConvModule( 155 | in_channels=in_channels, 156 | out_channels=mid_channels, 157 | kernel_size=1, 158 | stride=1, 159 | padding=0, 160 | conv_cfg=conv_cfg, 161 | norm_cfg=norm_cfg, 162 | act_cfg=act_cfg) 163 | self.depthwise_conv = ConvModule( 164 | in_channels=mid_channels, 165 | out_channels=mid_channels, 166 | kernel_size=kernel_size, 167 | stride=stride, 168 | padding=kernel_size // 2, 169 | groups=mid_channels, 170 | conv_cfg=dict( 171 | type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg, 172 | norm_cfg=norm_cfg, 173 | act_cfg=act_cfg) 174 | 175 | if self.with_se: 176 | self.se = SELayer(**se_cfg) 177 | 178 | self.linear_conv = ConvModule( 179 | in_channels=mid_channels, 180 | out_channels=out_channels, 181 | kernel_size=1, 182 | stride=1, 183 | padding=0, 184 | conv_cfg=conv_cfg, 185 | norm_cfg=norm_cfg, 186 | act_cfg=None) 187 | 188 | def forward(self, x): 189 | 190 | def _inner_forward(x): 191 | out = x 192 | 193 | if self.with_expand_conv: 194 | out = self.expand_conv(out) 195 | 196 | out = self.depthwise_conv(out) 197 | 198 | if self.with_se: 199 | out = self.se(out) 200 | 201 | out = self.linear_conv(out) 202 | 203 | if self.with_res_shortcut: 204 | return x + out 205 | else: 206 | return out 207 | 208 | if self.with_cp and x.requires_grad: 209 | out = cp.checkpoint(_inner_forward, x) 210 | else: 211 | out = _inner_forward(x) 212 | 213 | return out 214 | -------------------------------------------------------------------------------- /depth/models/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO, name='mmcv'): 8 | """Get root logger and add a keyword filter to it. 9 | 10 | The logger will be initialized if it has not been initialized. By default a 11 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 12 | also be added. The name of the root logger is the top-level package name, 13 | e.g., "mmdet3d". 14 | 15 | Args: 16 | log_file (str, optional): File path of log. Defaults to None. 17 | log_level (int, optional): The level of logger. 18 | Defaults to logging.INFO. 19 | name (str, optional): The name of the root logger, also used as a 20 | filter keyword. Defaults to 'mmdet3d'. 
21 | 22 | Returns: 23 | :obj:`logging.Logger`: The obtained logger 24 | """ 25 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 26 | 27 | # add a logging filter 28 | logging_filter = logging.Filter(name) 29 | logging_filter.filter = lambda record: record.find(name) != -1 30 | 31 | return logger -------------------------------------------------------------------------------- /depth/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 3 | """Make divisible function. 4 | 5 | This function rounds the channel number to the nearest value that can be 6 | divisible by the divisor. It is taken from the original tf repo. It ensures 7 | that all layers have a channel number that is divisible by divisor. It can 8 | be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 9 | 10 | Args: 11 | value (int): The original channel number. 12 | divisor (int): The divisor to fully divide the channel number. 13 | min_value (int): The minimum value of the output channel. 14 | Default: None, means that the minimum value equal to the divisor. 15 | min_ratio (float): The minimum ratio of the rounded channel number to 16 | the original channel number. Default: 0.9. 17 | 18 | Returns: 19 | int: The modified output channel number. 20 | """ 21 | 22 | if min_value is None: 23 | min_value = divisor 24 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 25 | # Make sure that round down does not go down by more than (1-min_ratio). 26 | if new_value < min_ratio * value: 27 | new_value += divisor 28 | return new_value 29 | -------------------------------------------------------------------------------- /depth/models/utils/res_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import build_conv_layer, build_norm_layer 3 | from mmcv.runner import Sequential 4 | from torch import nn as nn 5 | 6 | 7 | class ResLayer(Sequential): 8 | """ResLayer to build ResNet style backbone. 9 | 10 | Args: 11 | block (nn.Module): block used to build ResLayer. 12 | inplanes (int): inplanes of block. 13 | planes (int): planes of block. 14 | num_blocks (int): number of blocks. 15 | stride (int): stride of the first block. Default: 1 16 | avg_down (bool): Use AvgPool instead of stride conv when 17 | downsampling in the bottleneck. Default: False 18 | conv_cfg (dict): dictionary to construct and config conv layer. 19 | Default: None 20 | norm_cfg (dict): dictionary to construct and config norm layer. 21 | Default: dict(type='BN') 22 | multi_grid (int | None): Multi grid dilation rates of last 23 | stage. 
Default: None 24 | contract_dilation (bool): Whether contract first dilation of each layer 25 | Default: False 26 | """ 27 | 28 | def __init__(self, 29 | block, 30 | inplanes, 31 | planes, 32 | num_blocks, 33 | stride=1, 34 | dilation=1, 35 | avg_down=False, 36 | conv_cfg=None, 37 | norm_cfg=dict(type='BN'), 38 | multi_grid=None, 39 | contract_dilation=False, 40 | **kwargs): 41 | self.block = block 42 | 43 | downsample = None 44 | if stride != 1 or inplanes != planes * block.expansion: 45 | downsample = [] 46 | conv_stride = stride 47 | if avg_down: 48 | conv_stride = 1 49 | downsample.append( 50 | nn.AvgPool2d( 51 | kernel_size=stride, 52 | stride=stride, 53 | ceil_mode=True, 54 | count_include_pad=False)) 55 | downsample.extend([ 56 | build_conv_layer( 57 | conv_cfg, 58 | inplanes, 59 | planes * block.expansion, 60 | kernel_size=1, 61 | stride=conv_stride, 62 | bias=False), 63 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 64 | ]) 65 | downsample = nn.Sequential(*downsample) 66 | 67 | layers = [] 68 | if multi_grid is None: 69 | if dilation > 1 and contract_dilation: 70 | first_dilation = dilation // 2 71 | else: 72 | first_dilation = dilation 73 | else: 74 | first_dilation = multi_grid[0] 75 | layers.append( 76 | block( 77 | inplanes=inplanes, 78 | planes=planes, 79 | stride=stride, 80 | dilation=first_dilation, 81 | downsample=downsample, 82 | conv_cfg=conv_cfg, 83 | norm_cfg=norm_cfg, 84 | **kwargs)) 85 | inplanes = planes * block.expansion 86 | for i in range(1, num_blocks): 87 | layers.append( 88 | block( 89 | inplanes=inplanes, 90 | planes=planes, 91 | stride=1, 92 | dilation=dilation if multi_grid is None else multi_grid[i], 93 | conv_cfg=conv_cfg, 94 | norm_cfg=norm_cfg, 95 | **kwargs)) 96 | super(ResLayer, self).__init__(*layers) 97 | -------------------------------------------------------------------------------- /depth/models/utils/se_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from .make_divisible import make_divisible 7 | 8 | 9 | class SELayer(nn.Module): 10 | """Squeeze-and-Excitation Module. 11 | 12 | Args: 13 | channels (int): The input (and output) channels of the SE layer. 14 | ratio (int): Squeeze ratio in SELayer, the intermediate channel will be 15 | ``int(channels/ratio)``. Default: 16. 16 | conv_cfg (None or dict): Config dict for convolution layer. 17 | Default: None, which means using conv2d. 18 | act_cfg (dict or Sequence[dict]): Config dict for activation layer. 19 | If act_cfg is a dict, two activation layers will be configured 20 | by this dict. If act_cfg is a sequence of dicts, the first 21 | activation layer will be configured by the first dict and the 22 | second activation layer will be configured by the second dict. 23 | Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, 24 | divisor=6.0)). 
25 | """ 26 | 27 | def __init__(self, 28 | channels, 29 | ratio=16, 30 | conv_cfg=None, 31 | act_cfg=(dict(type='ReLU'), 32 | dict(type='HSigmoid', bias=3.0, divisor=6.0))): 33 | super(SELayer, self).__init__() 34 | if isinstance(act_cfg, dict): 35 | act_cfg = (act_cfg, act_cfg) 36 | assert len(act_cfg) == 2 37 | assert mmcv.is_tuple_of(act_cfg, dict) 38 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 39 | self.conv1 = ConvModule( 40 | in_channels=channels, 41 | out_channels=make_divisible(channels // ratio, 8), 42 | kernel_size=1, 43 | stride=1, 44 | conv_cfg=conv_cfg, 45 | act_cfg=act_cfg[0]) 46 | self.conv2 = ConvModule( 47 | in_channels=make_divisible(channels // ratio, 8), 48 | out_channels=channels, 49 | kernel_size=1, 50 | stride=1, 51 | conv_cfg=conv_cfg, 52 | act_cfg=act_cfg[1]) 53 | 54 | def forward(self, x): 55 | out = self.global_avgpool(x) 56 | out = self.conv1(out) 57 | out = self.conv2(out) 58 | return x * out 59 | -------------------------------------------------------------------------------- /depth/models/utils/self_attention_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.cnn import ConvModule, constant_init 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class SelfAttentionBlock(nn.Module): 9 | """General self-attention block/non-local block. 10 | 11 | Please refer to https://arxiv.org/abs/1706.03762 for details about key, 12 | query and value. 13 | 14 | Args: 15 | key_in_channels (int): Input channels of key feature. 16 | query_in_channels (int): Input channels of query feature. 17 | channels (int): Output channels of key/query transform. 18 | out_channels (int): Output channels. 19 | share_key_query (bool): Whether share projection weight between key 20 | and query projection. 21 | query_downsample (nn.Module): Query downsample module. 22 | key_downsample (nn.Module): Key downsample module. 23 | key_query_num_convs (int): Number of convs for key/query projection. 24 | value_num_convs (int): Number of convs for value projection. 25 | matmul_norm (bool): Whether normalize attention map with sqrt of 26 | channels 27 | with_out (bool): Whether use out projection. 28 | conv_cfg (dict|None): Config of conv layers. 29 | norm_cfg (dict|None): Config of norm layers. 30 | act_cfg (dict|None): Config of activation layers. 
31 | """ 32 | 33 | def __init__(self, key_in_channels, query_in_channels, channels, 34 | out_channels, share_key_query, query_downsample, 35 | key_downsample, key_query_num_convs, value_out_num_convs, 36 | key_query_norm, value_out_norm, matmul_norm, with_out, 37 | conv_cfg, norm_cfg, act_cfg): 38 | super(SelfAttentionBlock, self).__init__() 39 | if share_key_query: 40 | assert key_in_channels == query_in_channels 41 | self.key_in_channels = key_in_channels 42 | self.query_in_channels = query_in_channels 43 | self.out_channels = out_channels 44 | self.channels = channels 45 | self.share_key_query = share_key_query 46 | self.conv_cfg = conv_cfg 47 | self.norm_cfg = norm_cfg 48 | self.act_cfg = act_cfg 49 | self.key_project = self.build_project( 50 | key_in_channels, 51 | channels, 52 | num_convs=key_query_num_convs, 53 | use_conv_module=key_query_norm, 54 | conv_cfg=conv_cfg, 55 | norm_cfg=norm_cfg, 56 | act_cfg=act_cfg) 57 | if share_key_query: 58 | self.query_project = self.key_project 59 | else: 60 | self.query_project = self.build_project( 61 | query_in_channels, 62 | channels, 63 | num_convs=key_query_num_convs, 64 | use_conv_module=key_query_norm, 65 | conv_cfg=conv_cfg, 66 | norm_cfg=norm_cfg, 67 | act_cfg=act_cfg) 68 | self.value_project = self.build_project( 69 | key_in_channels, 70 | channels if with_out else out_channels, 71 | num_convs=value_out_num_convs, 72 | use_conv_module=value_out_norm, 73 | conv_cfg=conv_cfg, 74 | norm_cfg=norm_cfg, 75 | act_cfg=act_cfg) 76 | if with_out: 77 | self.out_project = self.build_project( 78 | channels, 79 | out_channels, 80 | num_convs=value_out_num_convs, 81 | use_conv_module=value_out_norm, 82 | conv_cfg=conv_cfg, 83 | norm_cfg=norm_cfg, 84 | act_cfg=act_cfg) 85 | else: 86 | self.out_project = None 87 | 88 | self.query_downsample = query_downsample 89 | self.key_downsample = key_downsample 90 | self.matmul_norm = matmul_norm 91 | 92 | self.init_weights() 93 | 94 | def init_weights(self): 95 | """Initialize weight of later layer.""" 96 | if self.out_project is not None: 97 | if not isinstance(self.out_project, ConvModule): 98 | constant_init(self.out_project, 0) 99 | 100 | def build_project(self, in_channels, channels, num_convs, use_conv_module, 101 | conv_cfg, norm_cfg, act_cfg): 102 | """Build projection layer for key/query/value/out.""" 103 | if use_conv_module: 104 | convs = [ 105 | ConvModule( 106 | in_channels, 107 | channels, 108 | 1, 109 | conv_cfg=conv_cfg, 110 | norm_cfg=norm_cfg, 111 | act_cfg=act_cfg) 112 | ] 113 | for _ in range(num_convs - 1): 114 | convs.append( 115 | ConvModule( 116 | channels, 117 | channels, 118 | 1, 119 | conv_cfg=conv_cfg, 120 | norm_cfg=norm_cfg, 121 | act_cfg=act_cfg)) 122 | else: 123 | convs = [nn.Conv2d(in_channels, channels, 1)] 124 | for _ in range(num_convs - 1): 125 | convs.append(nn.Conv2d(channels, channels, 1)) 126 | if len(convs) > 1: 127 | convs = nn.Sequential(*convs) 128 | else: 129 | convs = convs[0] 130 | return convs 131 | 132 | def forward(self, query_feats, key_feats): 133 | """Forward function.""" 134 | batch_size = query_feats.size(0) 135 | query = self.query_project(query_feats) 136 | if self.query_downsample is not None: 137 | query = self.query_downsample(query) 138 | query = query.reshape(*query.shape[:2], -1) 139 | query = query.permute(0, 2, 1).contiguous() 140 | 141 | key = self.key_project(key_feats) 142 | value = self.value_project(key_feats) 143 | if self.key_downsample is not None: 144 | key = self.key_downsample(key) 145 | value = self.key_downsample(value) 146 | key = 
key.reshape(*key.shape[:2], -1) 147 | value = value.reshape(*value.shape[:2], -1) 148 | value = value.permute(0, 2, 1).contiguous() 149 | 150 | sim_map = torch.matmul(query, key) 151 | if self.matmul_norm: 152 | sim_map = (self.channels**-.5) * sim_map 153 | sim_map = F.softmax(sim_map, dim=-1) 154 | 155 | context = torch.matmul(sim_map, value) 156 | context = context.permute(0, 2, 1).contiguous() 157 | context = context.reshape(batch_size, -1, *query_feats.shape[2:]) 158 | if self.out_project is not None: 159 | context = self.out_project(context) 160 | return context 161 | -------------------------------------------------------------------------------- /depth/models/utils/shape_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def nlc_to_nchw(x, hw_shape): 3 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 4 | 5 | Args: 6 | x (Tensor): The input tensor of shape [N, L, C] before conversion. 7 | hw_shape (Sequence[int]): The height and width of output feature map. 8 | 9 | Returns: 10 | Tensor: The output tensor of shape [N, C, H, W] after conversion. 11 | """ 12 | H, W = hw_shape 13 | assert len(x.shape) == 3 14 | B, L, C = x.shape 15 | assert L == H * W, 'The seq_len doesn\'t match H, W' 16 | return x.transpose(1, 2).reshape(B, C, H, W) 17 | 18 | 19 | def nchw_to_nlc(x): 20 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 21 | 22 | Args: 23 | x (Tensor): The input tensor of shape [N, C, H, W] before conversion. 24 | 25 | Returns: 26 | Tensor: The output tensor of shape [N, L, C] after conversion. 27 | """ 28 | assert len(x.shape) == 4 29 | return x.flatten(2).transpose(1, 2).contiguous() 30 | -------------------------------------------------------------------------------- /depth/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .encoding import Encoding 3 | from .wrappers import Upsample, resize 4 | 5 | __all__ = ['Upsample', 'resize', 'Encoding'] 6 | -------------------------------------------------------------------------------- /depth/ops/encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class Encoding(nn.Module): 8 | """Encoding Layer: a learnable residual encoder. 9 | 10 | Input is of shape (batch_size, channels, height, width). 11 | Output is of shape (batch_size, num_codes, channels). 12 | 13 | Args: 14 | channels: dimension of the features or feature channels 15 | num_codes: number of code words 16 | """ 17 | 18 | def __init__(self, channels, num_codes): 19 | super(Encoding, self).__init__() 20 | # init codewords and smoothing factor 21 | self.channels, self.num_codes = channels, num_codes 22 | std = 1.
/ ((num_codes * channels)**0.5) 23 | # [num_codes, channels] 24 | self.codewords = nn.Parameter( 25 | torch.empty(num_codes, channels, 26 | dtype=torch.float).uniform_(-std, std), 27 | requires_grad=True) 28 | # [num_codes] 29 | self.scale = nn.Parameter( 30 | torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), 31 | requires_grad=True) 32 | 33 | @staticmethod 34 | def scaled_l2(x, codewords, scale): 35 | num_codes, channels = codewords.size() 36 | batch_size = x.size(0) 37 | reshaped_scale = scale.view((1, 1, num_codes)) 38 | expanded_x = x.unsqueeze(2).expand( 39 | (batch_size, x.size(1), num_codes, channels)) 40 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 41 | 42 | scaled_l2_norm = reshaped_scale * ( 43 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 44 | return scaled_l2_norm 45 | 46 | @staticmethod 47 | def aggregate(assignment_weights, x, codewords): 48 | num_codes, channels = codewords.size() 49 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 50 | batch_size = x.size(0) 51 | 52 | expanded_x = x.unsqueeze(2).expand( 53 | (batch_size, x.size(1), num_codes, channels)) 54 | encoded_feat = (assignment_weights.unsqueeze(3) * 55 | (expanded_x - reshaped_codewords)).sum(dim=1) 56 | return encoded_feat 57 | 58 | def forward(self, x): 59 | assert x.dim() == 4 and x.size(1) == self.channels 60 | # [batch_size, channels, height, width] 61 | batch_size = x.size(0) 62 | # [batch_size, height x width, channels] 63 | x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() 64 | # assignment_weights: [batch_size, channels, num_codes] 65 | assignment_weights = F.softmax( 66 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 67 | # aggregate 68 | encoded_feat = self.aggregate(assignment_weights, x, self.codewords) 69 | return encoded_feat 70 | 71 | def __repr__(self): 72 | repr_str = self.__class__.__name__ 73 | repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ 74 | f'x{self.channels})' 75 | return repr_str 76 | -------------------------------------------------------------------------------- /depth/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
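Before the wrappers, a shape sketch for the `Encoding` layer just shown: each spatial feature is softly assigned to every codeword, and the residuals are aggregated per code. This assumes the `depth` package is importable:

```python
# Shape sketch for the Encoding layer above: (N, C, H, W) features are
# softly assigned to num_codes codewords, producing (N, num_codes, C).
# Assumes the `depth` package is on PYTHONPATH.
import torch
from depth.ops import Encoding

enc = Encoding(channels=32, num_codes=8)
out = enc(torch.rand(2, 32, 16, 16))
assert out.shape == (2, 8, 32)
```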
2 | import warnings 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize(input, 9 | size=None, 10 | scale_factor=None, 11 | mode='nearest', 12 | align_corners=None, 13 | warning=False): 14 | if warning: 15 | if size is not None and align_corners: 16 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 17 | output_h, output_w = tuple(int(x) for x in size) 18 | if output_h > input_h or output_w > input_w: 19 | if ((output_h > 1 and output_w > 1 and input_h > 1 20 | and input_w > 1) and (output_h - 1) % (input_h - 1) 21 | and (output_w - 1) % (input_w - 1)): 22 | warnings.warn( 23 | f'When align_corners={align_corners}, ' 24 | 'the output would be more aligned if ' 25 | f'input size {(input_h, input_w)} is `x+1` and ' 26 | f'out size {(output_h, output_w)} is `nx+1`') 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | 29 | 30 | class Upsample(nn.Module): 31 | 32 | def __init__(self, 33 | size=None, 34 | scale_factor=None, 35 | mode='nearest', 36 | align_corners=None): 37 | super(Upsample, self).__init__() 38 | self.size = size 39 | if isinstance(scale_factor, tuple): 40 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 41 | else: 42 | self.scale_factor = float(scale_factor) if scale_factor else None 43 | self.mode = mode 44 | self.align_corners = align_corners 45 | 46 | def forward(self, x): 47 | if not self.size: 48 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 49 | else: 50 | size = self.size 51 | return resize(x, size, None, self.mode, self.align_corners) 52 | -------------------------------------------------------------------------------- /depth/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .collect_env import collect_env 3 | from .logger import get_root_logger 4 | from .position_encoding import SinePositionalEncoding, LearnedPositionalEncoding 5 | from .color_depth import colorize 6 | 7 | __all__ = ['get_root_logger', 'collect_env', 'SinePositionalEncoding', 'LearnedPositionalEncoding', 'colorize'] 8 | -------------------------------------------------------------------------------- /depth/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
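A quick sanity check of the `resize` and `Upsample` wrappers above; `resize` is a thin pass-through to `F.interpolate`. Again this assumes the `depth` package is on `PYTHONPATH`:

```python
# Sketch: bilinear upsampling by 2x doubles H and W through either wrapper.
import torch
from depth.ops import Upsample, resize

x = torch.rand(1, 3, 12, 16)
y = resize(x, scale_factor=2, mode='bilinear', align_corners=False)
z = Upsample(scale_factor=2, mode='bilinear', align_corners=False)(x)
assert y.shape == z.shape == (1, 3, 24, 32)
```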
2 | from mmcv.utils import collect_env as collect_base_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import depth 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['Depth'] = f'{depth.__version__}+{get_git_hash()[:7]}' 12 | 13 | return env_info 14 | 15 | 16 | if __name__ == '__main__': 17 | for name, val in collect_env().items(): 18 | print('{}: {}'.format(name, val)) 19 | -------------------------------------------------------------------------------- /depth/utils/color_depth.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | # color the depth, kitti magma_r, nyu jet 4 | import numpy as np 5 | 6 | 7 | # def colorize(value, cmap='magma_r', vmin=None, vmax=None): 8 | def colorize(value, cmap='plasma', vmin=None, vmax=None): 9 | # def colorize(value, cmap='jet', vmin=None, vmax=None): 10 | 11 | # for abs 12 | # vmin=1e-3 13 | # vmax=80 14 | 15 | # for relative 16 | # value[value<=vmin]=vmin 17 | 18 | # vmin=None 19 | # vmax=None 20 | 21 | # normalize 22 | vmin = value.min() if vmin is None else vmin 23 | vmax = value.max() if vmax is None else vmax 24 | vmax = -1e-3 25 | vmin = -10 26 | 27 | if vmin != vmax: 28 | value = (value - vmin) / (vmax - vmin) # vmin..vmax 29 | # value = (vmax - value) / vmax # vmin..vmax 30 | else: 31 | # Avoid 0-division 32 | value = value * 0. 33 | 34 | cmapper = matplotlib.cm.get_cmap(cmap) 35 | value = cmapper(value, bytes=True) # ((1)xhxwx4) 36 | 37 | value = value[:, :, :, :3] # bgr -> rgb 38 | rgb_value = value[..., ::-1] 39 | # rgb_value = value[..., [1, 0, 2]] 40 | # rgb_value = value 41 | 42 | return rgb_value -------------------------------------------------------------------------------- /depth/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | """Get the root logger. 9 | 10 | The logger will be initialized if it has not been initialized. By default a 11 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 12 | also be added. The name of the root logger is the top-level package name, 13 | e.g., "depth". 14 | 15 | Args: 16 | log_file (str | None): The log filename. If specified, a FileHandler 17 | will be added to the root logger. 18 | log_level (int): The root logger level. Note that only the process of 19 | rank 0 is affected, while other processes will set the level to 20 | "Error" and be silent most of the time. 21 | 22 | Returns: 23 | logging.Logger: The root logger. 24 | """ 25 | 26 | logger = get_logger(name='depth', log_file=log_file, log_level=log_level) 27 | 28 | return logger 29 | -------------------------------------------------------------------------------- /depth/utils/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING 7 | from mmcv.runner import BaseModule 8 | 9 | 10 | @POSITIONAL_ENCODING.register_module() 11 | class SinePositionalEncoding(BaseModule): 12 | """Position encoding with sine and cosine functions. 13 | See `End-to-End Object Detection with Transformers 14 | `_ for details. 
15 | Args: 16 | num_feats (int): The feature dimension for each position 17 | along x-axis or y-axis. Note the final returned dimension 18 | for each position is 2 times of this value. 19 | temperature (int, optional): The temperature used for scaling 20 | the position embedding. Defaults to 10000. 21 | normalize (bool, optional): Whether to normalize the position 22 | embedding. Defaults to False. 23 | scale (float, optional): A scale factor that scales the position 24 | embedding. The scale will be used only when `normalize` is True. 25 | Defaults to 2*pi. 26 | eps (float, optional): A value added to the denominator for 27 | numerical stability. Defaults to 1e-6. 28 | offset (float): offset add to embed when do the normalization. 29 | Defaults to 0. 30 | init_cfg (dict or list[dict], optional): Initialization config dict. 31 | Default: None 32 | """ 33 | 34 | def __init__(self, 35 | num_feats, 36 | temperature=10000, 37 | normalize=False, 38 | scale=2 * math.pi, 39 | eps=1e-6, 40 | offset=0., 41 | init_cfg=None): 42 | super(SinePositionalEncoding, self).__init__(init_cfg) 43 | if normalize: 44 | assert isinstance(scale, (float, int)), 'when normalize is set,' \ 45 | 'scale should be provided and in float or int type, ' \ 46 | f'found {type(scale)}' 47 | self.num_feats = num_feats 48 | self.temperature = temperature 49 | self.normalize = normalize 50 | self.scale = scale 51 | self.eps = eps 52 | self.offset = offset 53 | 54 | def forward(self, mask): 55 | """Forward function for `SinePositionalEncoding`. 56 | Args: 57 | mask (Tensor): ByteTensor mask. Non-zero values representing 58 | ignored positions, while zero values means valid positions 59 | for this image. Shape [bs, h, w]. 60 | Returns: 61 | pos (Tensor): Returned position embedding with shape 62 | [bs, num_feats*2, h, w]. 63 | """ 64 | # For convenience of exporting to ONNX, it's required to convert 65 | # `masks` from bool to int. 66 | mask = mask.to(torch.int) 67 | not_mask = 1 - mask # logical_not 68 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 69 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 70 | if self.normalize: 71 | y_embed = (y_embed + self.offset) / \ 72 | (y_embed[:, -1:, :] + self.eps) * self.scale 73 | x_embed = (x_embed + self.offset) / \ 74 | (x_embed[:, :, -1:] + self.eps) * self.scale 75 | dim_t = torch.arange( 76 | self.num_feats, dtype=torch.float32, device=mask.device) 77 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats) 78 | pos_x = x_embed[:, :, :, None] / dim_t 79 | pos_y = y_embed[:, :, :, None] / dim_t 80 | # use `view` instead of `flatten` for dynamically exporting to ONNX 81 | B, H, W = mask.size() 82 | pos_x = torch.stack( 83 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), 84 | dim=4).view(B, H, W, -1) 85 | pos_y = torch.stack( 86 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), 87 | dim=4).view(B, H, W, -1) 88 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 89 | return pos 90 | 91 | def __repr__(self): 92 | """str: a string that describes the module""" 93 | repr_str = self.__class__.__name__ 94 | repr_str += f'(num_feats={self.num_feats}, ' 95 | repr_str += f'temperature={self.temperature}, ' 96 | repr_str += f'normalize={self.normalize}, ' 97 | repr_str += f'scale={self.scale}, ' 98 | repr_str += f'eps={self.eps})' 99 | return repr_str 100 | 101 | 102 | @POSITIONAL_ENCODING.register_module() 103 | class LearnedPositionalEncoding(BaseModule): 104 | """Position embedding with learnable embedding weights. 
105 | Args: 106 | num_feats (int): The feature dimension for each position 107 | along x-axis or y-axis. The final returned dimension for 108 | each position is 2 times of this value. 109 | row_num_embed (int, optional): The dictionary size of row embeddings. 110 | Default 50. 111 | col_num_embed (int, optional): The dictionary size of col embeddings. 112 | Default 50. 113 | init_cfg (dict or list[dict], optional): Initialization config dict. 114 | """ 115 | 116 | def __init__(self, 117 | num_feats, 118 | row_num_embed=50, 119 | col_num_embed=50, 120 | init_cfg=dict(type='Uniform', layer='Embedding')): 121 | super(LearnedPositionalEncoding, self).__init__(init_cfg) 122 | self.row_embed = nn.Embedding(row_num_embed, num_feats) 123 | self.col_embed = nn.Embedding(col_num_embed, num_feats) 124 | self.num_feats = num_feats 125 | self.row_num_embed = row_num_embed 126 | self.col_num_embed = col_num_embed 127 | 128 | def forward(self, mask): 129 | """Forward function for `LearnedPositionalEncoding`. 130 | Args: 131 | mask (Tensor): ByteTensor mask. Non-zero values representing 132 | ignored positions, while zero values means valid positions 133 | for this image. Shape [bs, h, w]. 134 | Returns: 135 | pos (Tensor): Returned position embedding with shape 136 | [bs, num_feats*2, h, w]. 137 | """ 138 | h, w = mask.shape[-2:] 139 | x = torch.arange(w, device=mask.device) 140 | y = torch.arange(h, device=mask.device) 141 | x_embed = self.col_embed(x) 142 | y_embed = self.row_embed(y) 143 | pos = torch.cat( 144 | (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat( 145 | 1, w, 1)), 146 | dim=-1).permute(2, 0, 147 | 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) 148 | return pos 149 | 150 | def __repr__(self): 151 | """str: a string that describes the module""" 152 | repr_str = self.__class__.__name__ 153 | repr_str += f'(num_feats={self.num_feats}, ' 154 | repr_str += f'row_num_embed={self.row_num_embed}, ' 155 | repr_str += f'col_num_embed={self.col_num_embed})' 156 | return repr_str -------------------------------------------------------------------------------- /depth/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | __version__ = '0.1.1' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /tools/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
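`parse_version_info` in `depth/version.py` above turns a version string into a tuple that compares naturally; release-candidate suffixes stay attached as a trailing string element. A usage sketch:

```python
# Sketch of parse_version_info behaviour, per the implementation above.
from depth.version import parse_version_info

assert parse_version_info('0.1.1') == (0, 1, 1)
assert parse_version_info('1.0.0rc2') == (1, 0, 0, 'rc2')
```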
2 | import argparse
3 | import time
4 | 
5 | import torch
6 | from mmcv import Config
7 | from mmcv.parallel import MMDataParallel
8 | from mmcv.runner import load_checkpoint, wrap_fp16_model
9 | 
10 | from depth.datasets import build_dataloader, build_dataset
11 | from depth.models import build_depther
12 | 
13 | 
14 | def parse_args():
15 |     parser = argparse.ArgumentParser(
16 |         description='Benchmark the inference speed of a depth model')
17 |     parser.add_argument('config', help='test config file path')
18 |     parser.add_argument('checkpoint', help='checkpoint file')
19 |     parser.add_argument(
20 |         '--log-interval', type=int, default=50, help='interval of logging')
21 |     args = parser.parse_args()
22 |     return args
23 | 
24 | 
25 | def main():
26 |     args = parse_args()
27 | 
28 |     cfg = Config.fromfile(args.config)
29 |     # disable cudnn benchmark so that the timing is deterministic
30 |     torch.backends.cudnn.benchmark = False
31 |     cfg.model.pretrained = None
32 |     cfg.data.test.test_mode = True
33 | 
34 |     # build the dataloader
35 |     # TODO: support multiple images per gpu (only minor changes are needed)
36 |     dataset = build_dataset(cfg.data.test)
37 |     data_loader = build_dataloader(
38 |         dataset,
39 |         samples_per_gpu=1,
40 |         workers_per_gpu=cfg.data.workers_per_gpu,
41 |         dist=False,
42 |         shuffle=False)
43 | 
44 |     # build the model and load checkpoint
45 |     cfg.model.train_cfg = None
46 |     model = build_depther(cfg.model, test_cfg=cfg.get('test_cfg'))
47 |     fp16_cfg = cfg.get('fp16', None)
48 |     if fp16_cfg is not None:
49 |         wrap_fp16_model(model)
50 |     load_checkpoint(model, args.checkpoint, map_location='cpu')
51 | 
52 |     model = MMDataParallel(model, device_ids=[0])
53 | 
54 |     model.eval()
55 | 
56 |     # the first several iterations may be very slow, so skip them
57 |     num_warmup = 5
58 |     pure_inf_time = 0
59 |     total_iters = 200
60 | 
61 |     # benchmark with 200 images and take the average
62 |     for i, data in enumerate(data_loader):
63 | 
64 |         torch.cuda.synchronize()
65 |         start_time = time.perf_counter()
66 | 
67 |         with torch.no_grad():
68 |             model(return_loss=False, rescale=True, **data)
69 | 
70 |         torch.cuda.synchronize()
71 |         elapsed = time.perf_counter() - start_time
72 | 
73 |         if i >= num_warmup:
74 |             pure_inf_time += elapsed
75 |             if (i + 1) % args.log_interval == 0:
76 |                 fps = (i + 1 - num_warmup) / pure_inf_time
77 |                 print(f'Done image [{i + 1:<3}/ {total_iters}], '
78 |                       f'fps: {fps:.2f} img / s')
79 | 
80 |         if (i + 1) == total_iters:
81 |             fps = (i + 1 - num_warmup) / pure_inf_time
82 |             print(f'Overall fps: {fps:.2f} img / s')
83 |             break
84 | 
85 | 
86 | if __name__ == '__main__':
87 |     main()
88 | 
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | PORT=${PORT:-29547}
7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
10 | 
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | CONFIG=$1
4 | GPUS=$2
5 | PORT=${PORT:-38423}
6 | 
7 | # NOTE: train.py in this repo takes the config path via --config
8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
10 | $(dirname "$0")/train.py --config $CONFIG --launcher pytorch ${@:3}
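11 | 
12 | # Example invocation (illustrative; any config under configs/trap works, and
13 | # PORT can be overridden through the environment):
14 | #   PORT=29500 bash tools/dist_train.sh configs/trap/trap_swin_l_win7_nyu.py 4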
--------------------------------------------------------------------------------
/tools/ensemble.py:
--------------------------------------------------------------------------------
1 | import mmcv
2 | import os
3 | import numpy as np
4 | 
5 | res_path = 'nfs/results/test-ensemble-res'
6 | 
7 | path_names = ['nfs/saves/test-ensemble']
8 | 
9 | weights = [1]
10 | 
11 | # normalize the weights so that they sum to 1
12 | reweights = [w / sum(weights) for w in weights]
13 | 
14 | file_names = os.listdir(path_names[0])
15 | 
16 | for name in file_names:
17 |     for idx, (path_name, w) in enumerate(zip(path_names, reweights)):
18 |         file_path = os.path.join(path_name, name)
19 |         if idx == 0:
20 |             temp_res = w * np.load(file_path)
21 |         else:
22 |             temp_res += w * np.load(file_path)
23 |     # reweights already sum to 1, so temp_res is the weighted average
24 |     ensemble_res = temp_res[0].astype(np.uint16)
25 |     filename = name[:-4] + '.png'
26 |     mmcv.imwrite(ensemble_res, os.path.join(res_path, filename))
27 | 
--------------------------------------------------------------------------------
/tools/misc/visualize_point-cloud.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import mmcv
3 | import numpy as np
4 | import warnings
5 | import torch
6 | import os
7 | from mmcv import Config, DictAction, mkdir_or_exist
8 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
9 |                          wrap_fp16_model)
10 | 
11 | from os import path as osp
12 | from pathlib import Path
13 | 
14 | from depth.datasets import build_dataloader, build_dataset
15 | from depth.models import build_depther
16 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
17 | 
18 | import matplotlib.pyplot as plt
19 | import matplotlib
20 | import torch.nn.functional as F
21 | 
22 | import random
23 | 
24 | def parse_args():
25 |     parser = argparse.ArgumentParser(
26 |         description='Visualize predicted depth as a point cloud')
27 |     parser.add_argument('config', help='train config file path')
28 |     parser.add_argument('checkpoint', help='checkpoint file')
29 |     parser.add_argument(
30 |         '--output-dir',
31 |         default=None,
32 |         type=str,
33 |         help='directory where the generated .ply point clouds will be saved')
34 |     parser.add_argument(
35 |         '--cfg-options',
36 |         nargs='+',
37 |         action=DictAction,
38 |         help='override some settings in the used config, the key-value pair '
39 |         'in xxx=yyy format will be merged into config file. If the value to '
40 |         'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
41 |         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
42 |         'Note that the quotation marks are necessary and that no white space '
43 |         'is allowed.')
44 | 
45 |     args = parser.parse_args()
46 |     return args
47 | 
48 | def build_data_cfg(cfg, cfg_options):
49 |     """Build data config for loading visualization data."""
50 | 
51 |     if cfg_options is not None:
52 |         cfg.merge_from_dict(cfg_options)
53 |     # extract inner dataset of `RepeatDataset` as `cfg.data.train`
54 |     # so we don't need to worry about it later
55 |     if cfg.data.train['type'] == 'RepeatDataset':
56 |         cfg.data.train = cfg.data.train.dataset
57 |     # use only first dataset for `ConcatDataset`
58 |     if cfg.data.train['type'] == 'ConcatDataset':
59 |         cfg.data.train = cfg.data.train.datasets[0]
60 |     # train_data_cfg = cfg.data.train
61 |     # show_pipeline = cfg.eval_pipeline
62 |     # train_data_cfg['pipeline'] = show_pipeline
63 |     test_data_cfg = cfg.data.test
64 |     show_pipeline = cfg.eval_pipeline
65 |     test_data_cfg['pipeline'] = show_pipeline
66 | 
67 |     return cfg
68 | 
69 | def generate_pointcloud_ply(xyz, color, pc_file):
70 |     # generate a point cloud .ply file from xyz coordinates and colors
71 |     # xyz: ndarray, 3xN, float
72 |     # color: ndarray, 3xN, uint8
73 |     df = np.zeros((6, xyz.shape[1]))
74 |     df[0] = xyz[0]
75 |     df[1] = xyz[1]
76 |     df[2] = xyz[2]
77 |     df[3] = color[0]
78 |     df[4] = color[1]
79 |     df[5] = color[2]
80 |     float_formatter = lambda x: "%.4f" % x
81 |     points = []
82 |     for i in df.T:
83 |         points.append("{} {} {} {} {} {} 0\n".format
84 |                       (float_formatter(i[0]), float_formatter(i[1]), float_formatter(i[2]),
85 |                        int(i[3]), int(i[4]), int(i[5])))
86 |     file = open(pc_file, "w")
87 |     file.write('''ply
88 | format ascii 1.0
89 | element vertex %d
90 | property float x
91 | property float y
92 | property float z
93 | property uchar red
94 | property uchar green
95 | property uchar blue
96 | property uchar alpha
97 | end_header
98 | %s
99 | ''' % (len(points), "".join(points)))
100 |     file.close()
101 | 
102 | def main():
103 |     args = parse_args()
104 | 
105 |     if args.output_dir is not None:
106 |         mkdir_or_exist(args.output_dir)
107 | 
108 |     cfg = mmcv.Config.fromfile(args.config)
109 |     cfg = build_data_cfg(cfg, args.cfg_options)
110 |     dataset = build_dataset(cfg.data.test)
111 |     data_loader = build_dataloader(
112 |         dataset,
113 |         samples_per_gpu=1,
114 |         workers_per_gpu=cfg.data.workers_per_gpu,
115 |         dist=False,
116 |         shuffle=False)
117 |     model = build_depther(cfg.model, test_cfg=cfg.get('test_cfg'))
118 |     model.eval()
119 | 
120 |     # for other models
121 |     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
122 | 
123 |     model = MMDataParallel(model, device_ids=[0])
124 | 
125 |     progress_bar = mmcv.ProgressBar(len(dataset))
126 | 
127 |     if args.output_dir is not None:
128 |         mkdir_or_exist(args.output_dir)
129 | 
130 |     for idx, input in enumerate(data_loader):
131 | 
132 |         with torch.no_grad():
133 |             aug_data_dict = {key: [] for key in input}
134 |             for data in [input]:
135 |                 for key, val in data.items():
136 |                     aug_data_dict[key].append(val)
137 | 
138 |             img_file = aug_data_dict['img_metas'][0]._data[0][0]['filename']
139 |             img = mmcv.imread(img_file)
140 |             name = osp.splitext(img_file)[0].split('/')[-2] + '_' + osp.splitext(img_file)[0].split('/')[-1]
141 |             output = model(return_loss=False, **aug_data_dict)
142 | 
143 |             depth = torch.tensor(output[0], dtype=torch.float32)
144 |             # https://github.com/SJTU-ViSYS/StructDepth/blob/17478278c228662248772c9a0c94d553d20078c5/datasets/nyu_dataset.py#L345
145 | 
146 |             # y, x (image size after cropping the NYU border)
147 |             h, w = 480-88, 640-80
148 |             fx = 5.1885790117450188e+02
149 |             fy = 5.1946961112127485e+02
150 |             cx = (3.2558244941119034e+02 - 40)
151 |             cy = (2.5373616633400465e+02 - 44)
152 | 
153 |             intrinsics = [
154 |                 [fx, 0., cx, 0.], [0., fy, cy, 0.],
155 |                 [0., 0., 1., 0.], [0., 0., 0., 1.]
156 |             ]
157 | 
158 |             depth = depth[0, 44:480-44, 40:640-40].contiguous()
159 | 
160 |             meshgrid = np.meshgrid(range(w), range(h), indexing='xy')
161 |             id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
162 |             id_coords = depth.new_tensor(id_coords)
163 |             pix_coords = torch.cat([id_coords[0].view(-1).unsqueeze(dim=0), id_coords[1].view(-1).unsqueeze(dim=0)], 0)
164 |             ones = torch.ones(1, w * h)
165 |             pix_coords = torch.cat([pix_coords, ones], dim=0)  # 3xHW homogeneous pixel coordinates
166 | 
167 |             inv_K = np.array(np.matrix(intrinsics).I)  # 4x4 inverse intrinsics
168 |             inv_K = pix_coords.new_tensor(inv_K)
169 |             cam_points = torch.matmul(inv_K[:3, :3], pix_coords)
170 | 
171 |             depth_flatten = depth.view(-1)
172 | 
173 |             # back-project: scale each unit ray by its predicted depth
174 |             cam_points = torch.einsum('cn,n->cn', cam_points, depth_flatten)
175 | 
176 |             img_tensor = torch.tensor(img[44:480-44, 40:640-40, :], dtype=torch.uint8)
177 |             img_tensor = img_tensor[:, :, [2, 1, 0]]  # BGR -> RGB
178 | 
179 |             img_tensor_flatten = img_tensor.permute(2, 0, 1).flatten(start_dim=1)
180 | 
181 |             generate_pointcloud_ply(cam_points, img_tensor_flatten.numpy(), os.path.join(args.output_dir, name+'.ply'))
182 |         progress_bar.update()
183 | 
184 | 
185 | if __name__ == '__main__':
186 |     main()
187 | 
--------------------------------------------------------------------------------
/tools/print_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | 
4 | from mmcv import Config, DictAction
5 | 
6 | from depth.apis import init_depther
7 | 
8 | 
9 | def parse_args():
10 |     parser = argparse.ArgumentParser(description='Print the whole config')
11 |     parser.add_argument('config', help='config file path')
12 |     parser.add_argument(
13 |         '--graph', action='store_true', help='print the model graph')
14 |     parser.add_argument(
15 |         '--options', nargs='+', action=DictAction, help='arguments in dict')
16 |     args = parser.parse_args()
17 | 
18 |     return args
19 | 
20 | 
21 | def main():
22 |     args = parse_args()
23 | 
24 |     cfg = Config.fromfile(args.config)
25 |     if args.options is not None:
26 |         cfg.merge_from_dict(args.options)
27 |     print(f'Config:\n{cfg.pretty_text}')
28 |     # dump config
29 |     cfg.dump('example.py')
30 |     # dump the model graph
31 |     if args.graph:
32 |         model = init_depther(args.config, device='cpu')
33 |         print(f'Model graph:\n{str(model)}')
34 |         with open('example-graph.txt', 'w') as f:
35 |             f.writelines(str(model))
36 | 
37 | 
38 | if __name__ == '__main__':
39 |     main()
40 | 
--------------------------------------------------------------------------------
/tools/slurm_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | set -x
4 | 
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | CHECKPOINT=$4
9 | GPUS=${GPUS:-4}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-4}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | PY_ARGS=${@:5}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 | 
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 |     --job-name=${JOB_NAME} \
18 |     --gres=gpu:${GPUS_PER_NODE} \
19 |     --ntasks=${GPUS} \
20 |     --ntasks-per-node=${GPUS_PER_NODE} \
21 |     --cpus-per-task=${CPUS_PER_TASK} \
22 |     --kill-on-bad-exit=1 \
23 |     ${SRUN_ARGS} \
24 |     python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
25 | 
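26 | # Example invocation (illustrative; the partition, job name and checkpoint
27 | # path are placeholders for your own setup):
28 | #   GPUS=4 bash tools/slurm_test.sh my_partition trap_eval \
29 | #       configs/trap/trap_swin_l_win7_nyu.py work_dirs/trap/latest.pth --show-dir vis/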
--------------------------------------------------------------------------------
/tools/slurm_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | set -x
4 | 
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | GPUS=${GPUS:-4}
9 | GPUS_PER_NODE=${GPUS_PER_NODE:-4}
10 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
11 | SRUN_ARGS=${SRUN_ARGS:-""}
12 | PY_ARGS=${@:4}
13 | 
14 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
15 | srun -p ${PARTITION} \
16 |     --job-name=${JOB_NAME} \
17 |     --gres=gpu:${GPUS_PER_NODE} \
18 |     --ntasks=${GPUS} \
19 |     --ntasks-per-node=${GPUS_PER_NODE} \
20 |     --cpus-per-task=${CPUS_PER_TASK} \
21 |     --kill-on-bad-exit=1 \
22 |     ${SRUN_ARGS} \
23 |     python -u tools/train.py --config ${CONFIG} --launcher="slurm" ${PY_ARGS}
24 | 
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import shutil
5 | import warnings
6 | 
7 | import mmcv
8 | import torch
9 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
10 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
11 |                          wrap_fp16_model)
12 | from mmcv.utils import DictAction
13 | 
14 | from depth.apis import multi_gpu_test, single_gpu_test
15 | from depth.datasets import build_dataloader, build_dataset
16 | from depth.models import build_depther
17 | 
18 | import numpy as np
19 | 
20 | def parse_args():
21 |     parser = argparse.ArgumentParser(
22 |         description='Test (and optionally evaluate) a depth model')
23 |     parser.add_argument('config', help='test config file path')
24 |     parser.add_argument('checkpoint', help='checkpoint file')
25 |     parser.add_argument(
26 |         '--aug-test', action='store_true', help='use flip and multi-scale augmentations')
27 |     parser.add_argument('--out', help='output result file in pickle format')
28 |     parser.add_argument(
29 |         '--format-only',
30 |         action='store_true',
31 |         help='Format the output results without performing evaluation. It is '
32 |         'useful when you want to format the result to a specific format and '
33 |         'submit it to the test server')
34 |     parser.add_argument(
35 |         '--eval',
36 |         type=str,
37 |         nargs='+',
38 |         help='evaluation metrics, which depend on the dataset, e.g., "mIoU"'
39 |         ' for generic datasets, and "cityscapes" for Cityscapes')
40 |     parser.add_argument('--show', action='store_true', help='show results')
41 |     parser.add_argument(
42 |         '--show-dir', help='directory where painted images will be saved')
43 |     parser.add_argument(
44 |         '--gpu-collect',
45 |         action='store_true',
46 |         help='whether to use gpu to collect results.')
47 |     parser.add_argument(
48 |         '--tmpdir',
49 |         help='tmp directory used for collecting results from multiple '
50 |         'workers, available when gpu_collect is not specified')
51 |     parser.add_argument(
52 |         '--options', nargs='+', action=DictAction, help='custom options')
53 |     parser.add_argument(
54 |         '--eval-options',
55 |         nargs='+',
56 |         action=DictAction,
57 |         help='custom options for evaluation')
58 |     parser.add_argument(
59 |         '--launcher',
60 |         choices=['none', 'pytorch', 'slurm', 'mpi'],
61 |         default='none',
62 |         help='job launcher')
63 |     parser.add_argument('--local_rank', type=int, default=0)
64 |     args = parser.parse_args()
65 |     if 'LOCAL_RANK' not in os.environ:
66 |         os.environ['LOCAL_RANK'] = str(args.local_rank)
67 |     return args
68 | 
69 | 
70 | def main():
71 |     args = parse_args()
72 | 
73 |     assert args.out or args.eval or args.format_only or args.show \
74 |         or args.show_dir, \
75 |         ('Please specify at least one operation (save/eval/format/show the '
76 |          'results) with the argument "--out", "--eval"'
77 |          ', "--format-only", "--show" or "--show-dir"')
78 | 
79 |     if args.eval and args.format_only:
80 |         raise ValueError('--eval and --format_only cannot be both specified')
81 | 
82 |     if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
83 |         raise ValueError('The output file must be a pkl file.')
84 | 
85 |     if args.out:
86 |         print(os.path.dirname(args.out))
87 |         mmcv.mkdir_or_exist(os.path.dirname(args.out))
88 | 
89 |     cfg = mmcv.Config.fromfile(args.config)
90 |     if args.options is not None:
91 |         cfg.merge_from_dict(args.options)
92 |     # set cudnn_benchmark
93 |     if cfg.get('cudnn_benchmark', False):
94 |         torch.backends.cudnn.benchmark = True
95 |     if args.aug_test:
96 |         # hard code index
97 |         cfg.data.test.pipeline[1].img_ratios = [
98 |             0.5, 0.75, 1.0, 1.25, 1.5, 1.75
99 |         ]
100 |         cfg.data.test.pipeline[1].flip = True
101 |     cfg.model.pretrained = None
102 |     cfg.data.test.test_mode = True
103 | 
104 |     # init distributed env first, since logger depends on the dist info.
105 |     if args.launcher == 'none':
106 |         distributed = False
107 |     else:
108 |         distributed = True
109 |         init_dist(args.launcher, **cfg.dist_params)
110 | 
111 |     # build the dataloader
112 |     # TODO: support multiple images per gpu (only minor changes are needed)
113 |     dataset = build_dataset(cfg.data.test)
114 |     data_loader = build_dataloader(
115 |         dataset,
116 |         samples_per_gpu=1,
117 |         workers_per_gpu=cfg.data.workers_per_gpu,
118 |         dist=distributed,
119 |         shuffle=False)
120 | 
121 |     # build the model and load checkpoint
122 |     cfg.model.train_cfg = None
123 | 
124 |     model = build_depther(
125 |         cfg.model,
126 |         test_cfg=cfg.get('test_cfg'))
127 | 
128 |     fp16_cfg = cfg.get('fp16', None)
129 |     if fp16_cfg is not None:
130 |         wrap_fp16_model(model)
131 | 
132 |     # for other models
133 |     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
134 | 
135 |     # clean gpu memory when starting a new evaluation.
136 |     torch.cuda.empty_cache()
137 |     eval_kwargs = {} if args.eval_options is None else args.eval_options
138 | 
139 |     eval_on_format_results = (
140 |         args.eval is not None and 'cityscapes' in args.eval)
141 |     if eval_on_format_results:
142 |         assert len(args.eval) == 1, 'eval on format results is not ' \
143 |                                     'applicable for metrics other than ' \
144 |                                     'cityscapes'
145 |     if args.format_only or eval_on_format_results:
146 |         if 'imgfile_prefix' in eval_kwargs:
147 |             tmpdir = eval_kwargs['imgfile_prefix']
148 |         else:
149 |             tmpdir = '.format_cityscapes'
150 |             eval_kwargs.setdefault('imgfile_prefix', tmpdir)
151 |         mmcv.mkdir_or_exist(tmpdir)
152 |     else:
153 |         tmpdir = None
154 | 
155 |     if not distributed:
156 |         model = MMDataParallel(model, device_ids=[0])
157 |         results = single_gpu_test(
158 |             model,
159 |             data_loader,
160 |             args.show,
161 |             args.show_dir,
162 |             pre_eval=args.eval is not None and not eval_on_format_results,
163 |             format_only=args.format_only or eval_on_format_results,
164 |             format_args=eval_kwargs)
165 |     else:
166 |         model = MMDistributedDataParallel(
167 |             model.cuda(),
168 |             device_ids=[torch.cuda.current_device()],
169 |             broadcast_buffers=False)
170 |         results = multi_gpu_test(
171 |             model,
172 |             data_loader,
173 |             args.tmpdir,
174 |             args.gpu_collect,
175 |             pre_eval=args.eval is not None and not eval_on_format_results,
176 |             format_only=args.format_only or eval_on_format_results,
177 |             format_args=eval_kwargs)
178 | 
179 |     rank, _ = get_dist_info()
180 |     if rank == 0:
181 |         if args.out:
182 |             warnings.warn(
183 |                 'The pickled outputs could be depth maps as np.ndarray, '
184 |                 'pre-eval results, or file paths for '
185 |                 '``dataset.format_results()``.')
186 |             print(f'\nwriting results to {args.out}')
187 |             mmcv.dump(results, args.out)
188 |         if args.eval:
189 |             dataset.evaluate(results, args.eval, **eval_kwargs)
190 |         if tmpdir is not None and eval_on_format_results:
191 |             # remove the tmp dir used for cityscapes evaluation
192 |             shutil.rmtree(tmpdir)
193 | 
194 | 
195 | if __name__ == '__main__':
196 |     main()
197 | 
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import copy
4 | import os
5 | import os.path as osp
6 | import time
7 | 
8 | import mmcv
9 | import torch
10 | from mmcv.runner import init_dist
11 | from mmcv.utils import Config, DictAction, get_git_hash
12 | 
13 | 
14 | from depth import __version__
15 | from depth.apis import set_random_seed, train_depther
16 | from depth.datasets import build_dataset
17 | from depth.models import build_depther
18 | from depth.utils import collect_env, get_root_logger
19 | 
20 | 
21 | def parse_args():
22 |     parser = argparse.ArgumentParser(description='Train a depther')
23 |     parser.add_argument('--config',
24 |                         default=None,
25 |                         help='train config file path')
26 |     parser.add_argument('--work-dir', default=None, help='the dir to save logs and models')
27 |     parser.add_argument(
28 |         '--load-from', help='the checkpoint file to load weights from',
29 |         default=None
30 |     )
31 |     parser.add_argument(
32 |         '--resume-from', help='the checkpoint file to resume from')
33 |     parser.add_argument(
34 |         '--no-validate',
35 |         action='store_true',
36 |         help='whether not to evaluate the checkpoint during training')
37 |     group_gpus = parser.add_mutually_exclusive_group()
38 |     group_gpus.add_argument(
39 |         '--gpus',
40 |         type=int,
41 |         help='number of gpus to use '
42 |         '(only applicable to non-distributed training)')
43 |     group_gpus.add_argument(
44 |         '--gpu-ids',
45 |         type=int,
46 |         nargs='+',
47 |         help='ids of gpus to use '
48 |         '(only applicable to non-distributed training)')
49 |     parser.add_argument('--seed', type=int, default=None, help='random seed')
50 |     parser.add_argument(
51 |         '--deterministic',
52 |         action='store_true',
53 |         help='whether to set deterministic options for CUDNN backend.')
54 |     parser.add_argument(
55 |         '--options', nargs='+', action=DictAction, help='custom options')
56 |     parser.add_argument(
57 |         '--launcher',
58 |         choices=['none', 'pytorch', 'slurm', 'mpi'],
59 |         default='none',
60 |         help='job launcher')
61 |     parser.add_argument('--local_rank', type=int, default=0)
62 |     args = parser.parse_args()
63 |     if 'LOCAL_RANK' not in os.environ:
64 |         os.environ['LOCAL_RANK'] = str(args.local_rank)
65 | 
66 |     return args
67 | 
68 | 
69 | def main():
70 |     args = parse_args()
71 | 
72 |     cfg = Config.fromfile(args.config)
73 |     if args.options is not None:
74 |         cfg.merge_from_dict(args.options)
75 |     # set cudnn_benchmark
76 |     if cfg.get('cudnn_benchmark', False):
77 |         torch.backends.cudnn.benchmark = True
78 | 
79 |     # work_dir is determined in this priority: CLI > segment in file > filename
80 |     if args.work_dir is not None:
81 |         # update configs according to CLI args if args.work_dir is not None
82 |         cfg.work_dir = args.work_dir
83 |     elif cfg.get('work_dir', None) is None:
84 |         # use config filename as default work_dir if cfg.work_dir is None
85 |         cfg.work_dir = osp.join('./work_dirs',
86 |                                 osp.splitext(osp.basename(args.config))[0])
87 |     if args.load_from is not None:
88 |         cfg.load_from = args.load_from
89 |     if args.resume_from is not None:
90 |         cfg.resume_from = args.resume_from
91 |     if args.gpu_ids is not None:
92 |         cfg.gpu_ids = args.gpu_ids
93 |     else:
94 |         cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
95 | 
96 |     # init distributed env first, since logger depends on the dist info.
97 |     if args.launcher == 'none':
98 |         distributed = False
99 |     else:
100 |         distributed = True
101 |         init_dist(args.launcher, **cfg.dist_params)
102 | 
103 |     # create work_dir
104 |     mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
105 |     # dump config
106 |     cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
107 |     # init the logger before other steps
108 |     timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
109 |     log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
110 |     logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
111 | 
112 |     # init the meta dict to record some important information such as
113 |     # environment info and seed, which will be logged
114 |     meta = dict()
115 |     # log env info
116 |     env_info_dict = collect_env()
117 |     env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
118 |     dash_line = '-' * 60 + '\n'
119 |     logger.info('Environment info:\n' + dash_line + env_info + '\n' +
120 |                 dash_line)
121 |     meta['env_info'] = env_info
122 | 
123 |     # log some basic info
124 |     logger.info(f'Distributed training: {distributed}')
125 |     logger.info(f'Config:\n{cfg.pretty_text}')
126 | 
127 |     # set random seeds
128 |     if args.seed is not None:
129 |         logger.info(f'Set random seed to {args.seed}, deterministic: '
130 |                     f'{args.deterministic}')
131 |         set_random_seed(args.seed, deterministic=args.deterministic)
132 |     cfg.seed = args.seed
133 |     meta['seed'] = args.seed
134 |     meta['exp_name'] = osp.basename(args.config)
135 | 
136 |     model = build_depther(
137 |         cfg.model,
138 |         train_cfg=cfg.get('train_cfg'),
139 |         test_cfg=cfg.get('test_cfg'))
140 |     model.init_weights()
141 | 
142 |     # NOTE: set all the bn to syncbn
143 |     import torch.nn as nn
144 |     if cfg.get('SyncBN', False):
145 |         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
146 | 
147 |     logger.info(model)
148 | 
149 |     datasets = [build_dataset(cfg.data.train)]
150 |     if len(cfg.workflow) == 2:
151 |         val_dataset = copy.deepcopy(cfg.data.val)
152 |         val_dataset.pipeline = cfg.data.train.pipeline
153 |         datasets.append(build_dataset(val_dataset))
154 |     if cfg.checkpoint_config is not None:
155 |         # save depth version, config file content and class names in
156 |         # checkpoints as meta data
157 |         cfg.checkpoint_config.meta = dict(
158 |             depth_version=f'{__version__}+{get_git_hash()[:7]}',
159 |             config=cfg.pretty_text)
160 |         # passing checkpoint meta for saving best checkpoint
161 |         meta.update(cfg.checkpoint_config.meta)
162 |     train_depther(
163 |         model,
164 |         datasets,
165 |         cfg,
166 |         distributed=distributed,
167 |         validate=(not args.no_validate),
168 |         timestamp=timestamp,
169 |         meta=meta)
170 | 
171 | 
172 | if __name__ == '__main__':
173 |     main()
174 | 
--------------------------------------------------------------------------------