├── .gitignore
├── README.md
├── configs
│   ├── _base_
│   │   ├── datasets
│   │   │   ├── cityscapes_half_512x512.py
│   │   │   ├── gta_to_cityscapes_512x512.py
│   │   │   ├── uda_cityscapes_to_acdc_512x512.py
│   │   │   ├── uda_cityscapes_to_darkzurich_512x512.py
│   │   │   ├── uda_gta_to_cityscapes_512x512.py
│   │   │   └── uda_synthia_to_cityscapes_512x512.py
│   │   ├── default_runtime.py
│   │   ├── models
│   │   │   ├── daformer_aspp_mitb5.py
│   │   │   ├── daformer_conv1_mitb5.py
│   │   │   ├── daformer_isa_mitb5.py
│   │   │   ├── daformer_sepaspp_bottleneck_mitb5.py
│   │   │   ├── daformer_sepaspp_mitb5.py
│   │   │   ├── danet_r50-d8.py
│   │   │   ├── deeplabv2_r50-d8.py
│   │   │   ├── deeplabv2red_r50-d8.py
│   │   │   ├── deeplabv3plus_r50-d8.py
│   │   │   ├── isanet_r50-d8.py
│   │   │   ├── segformer.py
│   │   │   ├── segformer_b5.py
│   │   │   ├── segformer_r101.py
│   │   │   ├── upernet_ch256_mit.py
│   │   │   └── upernet_mit.py
│   │   ├── schedules
│   │   │   ├── adamw.py
│   │   │   ├── poly10.py
│   │   │   └── poly10warm.py
│   │   └── uda
│   │       ├── dacs.py
│   │       ├── dacs_a999_fdthings.py
│   │       ├── dacs_cda.py
│   │       ├── dacs_fd.py
│   │       └── dacs_fdthings.py
│   ├── cdac
│   │   ├── cs2acdc_uda_dacs_cda_mitb5_b2_s0.py
│   │   ├── gta2cs_uda_dacs_cda_mitb5_b2_s0.py
│   │   └── synthia2cs_uda_dacs_cda_mitb5_b2_s0.py
│   └── daformer
│       └── gta2cs_uda_warm_fdthings_rcs_croppl_a999_daformer_mitb5_s0.py
├── demo
│   ├── demo.png
│   └── image_demo.py
├── experiments.py
├── mmseg
│   ├── __init__.py
│   ├── apis
│   │   ├── __init__.py
│   │   ├── inference.py
│   │   ├── test.py
│   │   └── train.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── ddp_wrapper.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   ├── class_names.py
│   │   │   ├── eval_hooks.py
│   │   │   └── metrics.py
│   │   ├── seg
│   │   │   ├── __init__.py
│   │   │   ├── builder.py
│   │   │   └── sampler
│   │   │       ├── __init__.py
│   │   │       ├── base_pixel_sampler.py
│   │   │       └── ohem_pixel_sampler.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── misc.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── acdc.py
│   │   ├── builder.py
│   │   ├── cityscapes.py
│   │   ├── custom.py
│   │   ├── dark_zurich.py
│   │   ├── dataset_wrappers.py
│   │   ├── gta.py
│   │   ├── pipelines
│   │   │   ├── __init__.py
│   │   │   ├── compose.py
│   │   │   ├── formating.py
│   │   │   ├── loading.py
│   │   │   ├── test_time_aug.py
│   │   │   └── transforms.py
│   │   ├── synthia.py
│   │   └── uda_dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── mix_transformer.py
│   │   │   ├── resnest.py
│   │   │   ├── resnet.py
│   │   │   └── resnext.py
│   │   ├── builder.py
│   │   ├── decode_heads
│   │   │   ├── __init__.py
│   │   │   ├── aspp_head.py
│   │   │   ├── da_head.py
│   │   │   ├── daformer_head.py
│   │   │   ├── decode_head.py
│   │   │   ├── dlv2_head.py
│   │   │   ├── fcn_head.py
│   │   │   ├── isa_head.py
│   │   │   ├── psp_head.py
│   │   │   ├── segformer_head.py
│   │   │   ├── sep_aspp_head.py
│   │   │   └── uper_head.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── accuracy.py
│   │   │   ├── cross_entropy_loss.py
│   │   │   └── utils.py
│   │   ├── necks
│   │   │   ├── __init__.py
│   │   │   └── segformer_adapter.py
│   │   ├── segmentors
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── encoder_decoder.py
│   │   ├── uda
│   │   │   ├── __init__.py
│   │   │   ├── dacs.py
│   │   │   ├── dacs_cda.py
│   │   │   └── uda_decorator.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── ckpt_convert.py
│   │       ├── dacs_transforms.py
│   │       ├── make_divisible.py
│   │       ├── res_layer.py
│   │       ├── self_attention_block.py
│   │       ├── shape_convert.py
│   │       └── visualization.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── encoding.py
│   │   └── wrappers.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── collect_env.py
│   │   ├── logger.py
│   │   └── utils.py
│   └── version.py
├── requirements.txt
├── run_experiments.py
├── test.sh
└── tools
    ├── __init__.py
    ├── analyze_logs.py
    ├── convert_datasets
    │   ├── cityscapes.py
    │   ├── gta.py
    │   └── synthia.py
    ├── download_checkpoints.sh
    ├── get_param_count.py
    ├── print_config.py
    ├── publish_model.py
    ├── test.py
    └── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | work_dirs/
2 | jobs/
3 | pretrained/
4 | __pycache__/
5 | */__pycache__
6 | *.zip
7 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## CDAC: Cross-domain Attention Consistency in Transformer for Domain Adaptive Semantic Segmentation

Official release of the source code for [CDAC: Cross-domain Attention Consistency in Transformer for Domain Adaptive Semantic Segmentation](https://arxiv.org/abs/2211.14703) at ICCV 2023.

## Overview
We propose Cross-Domain Attention Consistency (CDAC) to perform adaptation on attention maps using cross-domain attention layers that share features between the source and target domains. Specifically, we impose consistency between the predictions of the cross-domain attention and self-attention modules to encourage similar distributions across domains in both the attention maps and the output of the model, i.e., attention-level and output-level alignment. We also enforce consistency in attention maps between different augmented views to further strengthen the attention-based alignment. Combining these two components, CDAC mitigates the discrepancy in attention maps across domains and further boosts the performance of the transformer under unsupervised domain adaptation settings.
Our method is evaluated on several widely used benchmarks and outperforms state-of-the-art baselines, e.g., on GTAV-to-Cityscapes by 1.3 and 1.5 percentage points (pp) and on Synthia-to-Cityscapes by 0.6 pp and 2.9 pp when combined with two competitive Transformer-based backbones, respectively.
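To make the attention-level alignment concrete, here is a minimal PyTorch sketch of the idea. It is an illustration under simplifying assumptions rather than the repository's implementation (that lives in `mmseg/models/uda/dacs_cda.py`): the function names, tensor shapes, and the KL-divergence loss form are assumptions of this sketch, while the query switching mirrors the `switch_element='q'` backbone option used in the configs below.

```python
# Minimal sketch of attention-level consistency (illustrative assumptions only;
# the actual implementation lives in mmseg/models/uda/dacs_cda.py).
import torch
import torch.nn.functional as F


def attn_map(q, k):
    # q, k: (B, heads, N, C) -> row-stochastic attention map (B, heads, N, N)
    scale = q.shape[-1]**-0.5
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)


def attn_consistency_loss(q_src, q_tgt, k_tgt):
    # Target self-attention, detached so that only the cross-domain branch
    # is pulled towards it (an assumption of this sketch; cf. the
    # detach=True backbone option in the configs).
    self_attn = attn_map(q_tgt, k_tgt).detach()
    # Cross-domain attention: the query is swapped between domains
    # (cf. switch_element='q'), so source queries attend to target keys.
    cross_attn = attn_map(q_src, k_tgt)
    # Penalize the divergence between the two attention distributions.
    return F.kl_div(cross_attn.clamp_min(1e-8).log(), self_attn,
                    reduction='batchmean')
```

In the full method, analogous consistency terms are applied in both source/target directions and between augmented views, weighted by the `cda_*_lambda` and `attn_*` options of `configs/_base_/uda/dacs_cda.py`.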
## Installation and Data Preparation

Since our model is primarily built on DAFormer, please refer to the `Setup Environment` and `Setup Datasets` sections in the [original repo](https://github.com/lhoyer/DAFormer/) for instructions on setting up the environment and preparing the datasets.

## Training

For training our model on GTAV->Cityscapes:
```shell
python run_experiments.py --config configs/cdac/gta2cs_uda_dacs_cda_mitb5_b2_s0.py
```

For training our model on Synthia->Cityscapes:
```shell
python run_experiments.py --config configs/cdac/synthia2cs_uda_dacs_cda_mitb5_b2_s0.py
```

For training our model on Cityscapes->ACDC:
```shell
python run_experiments.py --config configs/cdac/cs2acdc_uda_dacs_cda_mitb5_b2_s0.py
```

## Testing

Our models pretrained on the three benchmarks are available online. Please find them [here](https://www.dropbox.com/scl/fo/zshfbb85djhxuuu2qx32q/AN_oH5stBEqEE_CRcobFmMs?rlkey=pe0zqg3vf067ig8w9jbwpoiun&st=ecfe0mmh&dl=0). After downloading the files, please run the following command:

```shell
sh test.sh path/to/checkpoint_directory
```

## Acknowledgements

The code of this project is heavily borrowed from DAFormer and the repositories it depends on.
We thank their authors for making the source code publicly available.

* [MMSegmentation](https://github.com/open-mmlab/mmsegmentation)
* [SegFormer](https://github.com/NVlabs/SegFormer)
* [DACS](https://github.com/vikolss/DACS)
* [DAFormer](https://github.com/lhoyer/DAFormer)

--------------------------------------------------------------------------------
/configs/_base_/datasets/cityscapes_half_512x512.py:
--------------------------------------------------------------------------------
1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0
2 | # Modifications: Half image resolution
3 | 
4 | # dataset settings
5 | dataset_type = 'CityscapesDataset'
6 | data_root = 'data/cityscapes/'
7 | img_norm_cfg = dict(
8 |     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
9 | crop_size = (512, 512)
10 | train_pipeline = [
11 |     dict(type='LoadImageFromFile'),
12 |     dict(type='LoadAnnotations'),
13 |     dict(type='Resize', img_scale=(1024, 512)),
14 |     dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
15 |     dict(type='RandomFlip', prob=0.5),
16 |     dict(type='PhotoMetricDistortion'),
17 |     dict(type='Normalize', **img_norm_cfg),
18 |     dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
19 |     dict(type='DefaultFormatBundle'),
20 |     dict(type='Collect', keys=['img', 'gt_semantic_seg']),
21 | ]
22 | test_pipeline = [
23 |     dict(type='LoadImageFromFile'),
24 |     dict(
25 |         type='MultiScaleFlipAug',
26 |         img_scale=(1024, 512),
27 |         # MultiScaleFlipAug is disabled by not providing img_ratios and
28 |         # setting flip=False
29 |         # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
30 |         flip=False,
31 |         transforms=[
32 |             dict(type='Resize', keep_ratio=True),
33 |             dict(type='RandomFlip'),
34 |             dict(type='Normalize', **img_norm_cfg),
35 |             dict(type='ImageToTensor', keys=['img']),
36 |             dict(type='Collect', keys=['img']),
37 |         ])
38 | ]
39 | data = dict(
40 |     samples_per_gpu=2,
41 |     workers_per_gpu=4,
42 |     train=dict(
43 |         type=dataset_type,
44 |         data_root=data_root,
45 |         img_dir='leftImg8bit/train',
46 |         ann_dir='gtFine/train',
47 |         pipeline=train_pipeline),
48 |     val=dict(
49 |         type=dataset_type,
50 |         data_root=data_root,
51 |         img_dir='leftImg8bit/val',
52 |         ann_dir='gtFine/val',
53 |         pipeline=test_pipeline),
54 |     test=dict(
55 |         type=dataset_type,
56 |         data_root=data_root,
57 |         img_dir='leftImg8bit/val',
58 |         ann_dir='gtFine/val',
59 |         pipeline=test_pipeline))
60 | 
--------------------------------------------------------------------------------
/configs/_base_/datasets/gta_to_cityscapes_512x512.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved.
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # dataset settings 7 | dataset_type = 'CityscapesDataset' 8 | data_root = 'data/cityscapes/' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (512, 512) 12 | train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations'), 15 | dict(type='Resize', img_scale=(1280, 720)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | dict(type='PhotoMetricDistortion'), 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | test_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | dict( 27 | type='MultiScaleFlipAug', 28 | img_scale=(1024, 512), 29 | # MultiScaleFlipAug is disabled by not providing img_ratios and 30 | # setting flip=False 31 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 32 | flip=False, 33 | transforms=[ 34 | dict(type='Resize', keep_ratio=True), 35 | dict(type='RandomFlip'), 36 | dict(type='Normalize', **img_norm_cfg), 37 | dict(type='ImageToTensor', keys=['img']), 38 | dict(type='Collect', keys=['img']), 39 | ]) 40 | ] 41 | data = dict( 42 | samples_per_gpu=2, 43 | workers_per_gpu=4, 44 | train=dict( 45 | type='GTADataset', 46 | data_root='data/gta/', 47 | img_dir='images', 48 | ann_dir='labels', 49 | pipeline=train_pipeline), 50 | val=dict( 51 | type='CityscapesDataset', 52 | data_root='data/cityscapes/', 53 | img_dir='leftImg8bit/val', 54 | ann_dir='gtFine/val', 55 | pipeline=test_pipeline), 56 | test=dict( 57 | type='CityscapesDataset', 58 | data_root='data/cityscapes/', 59 | img_dir='leftImg8bit/val', 60 | ann_dir='gtFine/val', 61 | pipeline=test_pipeline)) 62 | -------------------------------------------------------------------------------- /configs/_base_/datasets/uda_cityscapes_to_acdc_512x512.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # dataset settings 7 | img_norm_cfg = dict( 8 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 9 | crop_size = (512, 512) 10 | cityscapes_train_pipeline = [ 11 | dict(type='LoadImageFromFile'), 12 | dict(type='LoadAnnotations'), 13 | dict(type='Resize', img_scale=(1024, 512)), 14 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 15 | dict(type='RandomFlip', prob=0.5), 16 | # dict(type='PhotoMetricDistortion'), # is applied later in dacs.py 17 | dict(type='Normalize', **img_norm_cfg), 18 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 19 | dict(type='DefaultFormatBundle'), 20 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 21 | ] 22 | acdc_train_pipeline = [ 23 | dict(type='LoadImageFromFile'), 24 | dict(type='Resize', img_scale=(960, 540)), # original 1920x1080 25 | dict(type='RandomCrop', crop_size=crop_size), 26 | dict(type='RandomFlip', prob=0.5), 27 | # dict(type='PhotoMetricDistortion'), # is applied later in dacs.py 28 | dict(type='Normalize', **img_norm_cfg), 29 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 30 | dict(type='DefaultFormatBundle'), 31 | dict(type='Collect', keys=['img']), 32 | ] 33 | test_pipeline = [ 34 | dict(type='LoadImageFromFile'), 35 | dict( 36 | type='MultiScaleFlipAug', 37 | img_scale=(960, 540), # original 1920x1080 38 | # MultiScaleFlipAug is disabled by not providing img_ratios and 39 | # setting flip=False 40 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 41 | flip=False, 42 | transforms=[ 43 | dict(type='Resize', keep_ratio=True), 44 | dict(type='RandomFlip'), 45 | dict(type='Normalize', **img_norm_cfg), 46 | dict(type='ImageToTensor', keys=['img']), 47 | dict(type='Collect', keys=['img']), 48 | ]) 49 | ] 50 | data = dict( 51 | samples_per_gpu=2, 52 | workers_per_gpu=4, 53 | train=dict( 54 | type='UDADataset', 55 | source=dict( 56 | type='CityscapesDataset', 57 | data_root='data/cityscapes/', 58 | img_dir='leftImg8bit/train', 59 | ann_dir='gtFine/train', 60 | pipeline=cityscapes_train_pipeline), 61 | target=dict( 62 | type='ACDCDataset', 63 | data_root='data/acdc/', 64 | img_dir='rgb_anon/train', 65 | ann_dir='gt/train', 66 | pipeline=acdc_train_pipeline)), 67 | val=dict( 68 | type='ACDCDataset', 69 | data_root='data/acdc/', 70 | img_dir='rgb_anon/val', 71 | ann_dir='gt/val', 72 | pipeline=test_pipeline), 73 | test=dict( 74 | type='ACDCDataset', 75 | data_root='data/acdc/', 76 | img_dir='rgb_anon/val', 77 | ann_dir='gt/val', 78 | pipeline=test_pipeline)) 79 | -------------------------------------------------------------------------------- /configs/_base_/datasets/uda_cityscapes_to_darkzurich_512x512.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # dataset settings 7 | img_norm_cfg = dict( 8 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 9 | crop_size = (512, 512) 10 | cityscapes_train_pipeline = [ 11 | dict(type='LoadImageFromFile'), 12 | dict(type='LoadAnnotations'), 13 | dict(type='Resize', img_scale=(1024, 512)), 14 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 15 | dict(type='RandomFlip', prob=0.5), 16 | # dict(type='PhotoMetricDistortion'), # is applied later in dacs.py 17 | dict(type='Normalize', **img_norm_cfg), 18 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 19 | dict(type='DefaultFormatBundle'), 20 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 21 | ] 22 | dark_zurich_train_pipeline = [ 23 | dict(type='LoadImageFromFile'), 24 | dict(type='Resize', img_scale=(960, 540)), # original 1920x1080 25 | dict(type='RandomCrop', crop_size=crop_size), 26 | dict(type='RandomFlip', prob=0.5), 27 | # dict(type='PhotoMetricDistortion'), # is applied later in dacs.py 28 | dict(type='Normalize', **img_norm_cfg), 29 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 30 | dict(type='DefaultFormatBundle'), 31 | dict(type='Collect', keys=['img']), 32 | ] 33 | test_pipeline = [ 34 | dict(type='LoadImageFromFile'), 35 | dict( 36 | type='MultiScaleFlipAug', 37 | img_scale=(960, 540), # original 1920x1080 38 | # MultiScaleFlipAug is disabled by not providing img_ratios and 39 | # setting flip=False 40 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 41 | flip=False, 42 | transforms=[ 43 | dict(type='Resize', keep_ratio=True), 44 | dict(type='RandomFlip'), 45 | dict(type='Normalize', **img_norm_cfg), 46 | dict(type='ImageToTensor', keys=['img']), 47 | dict(type='Collect', keys=['img']), 48 | ]) 49 | ] 50 | data = dict( 51 | samples_per_gpu=2, 52 | workers_per_gpu=4, 53 | train=dict( 54 | type='UDADataset', 55 | source=dict( 56 | type='CityscapesDataset', 57 | data_root='data/cityscapes/', 58 | img_dir='leftImg8bit/train', 59 | ann_dir='gtFine/train', 60 | pipeline=cityscapes_train_pipeline), 61 | target=dict( 62 | type='DarkZurichDataset', 63 | data_root='data/dark_zurich/', 64 | img_dir='rgb_anon/train/night/', 65 | ann_dir='gt/train/night/', 66 | pipeline=dark_zurich_train_pipeline)), 67 | val=dict( 68 | type='DarkZurichDataset', 69 | data_root='data/dark_zurich/', 70 | img_dir='rgb_anon/val', 71 | ann_dir='gt/val', 72 | pipeline=test_pipeline), 73 | test=dict( 74 | type='DarkZurichDataset', 75 | data_root='data/dark_zurich/', 76 | img_dir='rgb_anon/val', 77 | ann_dir='gt/val', 78 | pipeline=test_pipeline)) 79 | -------------------------------------------------------------------------------- /configs/_base_/datasets/uda_gta_to_cityscapes_512x512.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # dataset settings 7 | dataset_type = 'CityscapesDataset' 8 | data_root = 'data/cityscapes/' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (512, 512) 12 | gta_train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations'), 15 | dict(type='Resize', img_scale=(1280, 720)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | # dict(type='PhotoMetricDistortion'), # is applied later in dacs.py 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | cityscapes_train_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | dict(type='LoadAnnotations'), 27 | dict(type='Resize', img_scale=(1024, 512)), 28 | dict(type='RandomCrop', crop_size=crop_size), 29 | dict(type='RandomFlip', prob=0.5), 30 | # dict(type='PhotoMetricDistortion'), # is applied later in dacs.py 31 | dict(type='Normalize', **img_norm_cfg), 32 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 33 | dict(type='DefaultFormatBundle'), 34 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 35 | ] 36 | test_pipeline = [ 37 | dict(type='LoadImageFromFile'), 38 | dict( 39 | type='MultiScaleFlipAug', 40 | img_scale=(1024, 512), 41 | # MultiScaleFlipAug is disabled by not providing img_ratios and 42 | # setting flip=False 43 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 44 | flip=False, 45 | transforms=[ 46 | dict(type='Resize', keep_ratio=True), 47 | dict(type='RandomFlip'), 48 | dict(type='Normalize', **img_norm_cfg), 49 | dict(type='ImageToTensor', keys=['img']), 50 | dict(type='Collect', keys=['img']), 51 | ]) 52 | ] 53 | data = dict( 54 | samples_per_gpu=2, 55 | workers_per_gpu=4, 56 | train=dict( 57 | type='UDADataset', 58 | source=dict( 59 | type='GTADataset', 60 | data_root='data/gta/', 61 | img_dir='images', 62 | ann_dir='labels', 63 | pipeline=gta_train_pipeline), 64 | target=dict( 65 | type='CityscapesDataset', 66 | data_root='data/cityscapes/', 67 | img_dir='leftImg8bit/train', 68 | ann_dir='gtFine/train', 69 | pipeline=cityscapes_train_pipeline)), 70 | val=dict( 71 | type='CityscapesDataset', 72 | data_root='data/cityscapes/', 73 | img_dir='leftImg8bit/val', 74 | ann_dir='gtFine/val', 75 | pipeline=test_pipeline), 76 | test=dict( 77 | type='CityscapesDataset', 78 | data_root='data/cityscapes/', 79 | img_dir='leftImg8bit/val', 80 | ann_dir='gtFine/val', 81 | pipeline=test_pipeline)) 82 | -------------------------------------------------------------------------------- /configs/_base_/datasets/uda_synthia_to_cityscapes_512x512.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # dataset settings 7 | dataset_type = 'CityscapesDataset' 8 | data_root = 'data/cityscapes/' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (512, 512) 12 | synthia_train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations'), 15 | dict(type='Resize', img_scale=(1280, 760)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | # dict(type='PhotoMetricDistortion'), # is applied later in dacs.py 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | cityscapes_train_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | dict(type='LoadAnnotations'), 27 | dict(type='Resize', img_scale=(1024, 512)), 28 | dict(type='RandomCrop', crop_size=crop_size), 29 | dict(type='RandomFlip', prob=0.5), 30 | # dict(type='PhotoMetricDistortion'), # is applied later in dacs.py 31 | dict(type='Normalize', **img_norm_cfg), 32 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 33 | dict(type='DefaultFormatBundle'), 34 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 35 | ] 36 | test_pipeline = [ 37 | dict(type='LoadImageFromFile'), 38 | dict( 39 | type='MultiScaleFlipAug', 40 | img_scale=(1024, 512), 41 | # MultiScaleFlipAug is disabled by not providing img_ratios and 42 | # setting flip=False 43 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 44 | flip=False, 45 | transforms=[ 46 | dict(type='Resize', keep_ratio=True), 47 | dict(type='RandomFlip'), 48 | dict(type='Normalize', **img_norm_cfg), 49 | dict(type='ImageToTensor', keys=['img']), 50 | dict(type='Collect', keys=['img']), 51 | ]) 52 | ] 53 | data = dict( 54 | samples_per_gpu=2, 55 | workers_per_gpu=4, 56 | train=dict( 57 | type='UDADataset', 58 | source=dict( 59 | type='SynthiaDataset', 60 | data_root='data/synthia/', 61 | img_dir='RGB', 62 | ann_dir='GT/LABELS', 63 | pipeline=synthia_train_pipeline), 64 | target=dict( 65 | type='CityscapesDataset', 66 | data_root='data/cityscapes/', 67 | img_dir='leftImg8bit/train', 68 | ann_dir='gtFine/train', 69 | pipeline=cityscapes_train_pipeline)), 70 | val=dict( 71 | type='CityscapesDataset', 72 | data_root='data/cityscapes/', 73 | img_dir='leftImg8bit/val', 74 | ann_dir='gtFine/val', 75 | pipeline=test_pipeline), 76 | test=dict( 77 | type='CityscapesDataset', 78 | data_root='data/cityscapes/', 79 | img_dir='leftImg8bit/val', 80 | ann_dir='gtFine/val', 81 | pipeline=test_pipeline)) 82 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | # yapf:disable 4 | log_config = dict( 5 | interval=50, 6 | hooks=[ 7 | dict(type='TextLoggerHook', by_epoch=False), 8 | # dict(type='TensorboardLoggerHook') 9 | ]) 10 | # yapf:enable 11 | dist_params = dict(backend='nccl') 12 | log_level = 'INFO' 13 | load_from = None 14 | resume_from = None 15 | workflow = [('train', 1)] 16 | cudnn_benchmark = True 17 | -------------------------------------------------------------------------------- 
/configs/_base_/models/daformer_aspp_mitb5.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # DAFormer w/o DSC in Tab. 7 7 | 8 | _base_ = ['daformer_conv1_mitb5.py'] 9 | 10 | norm_cfg = dict(type='BN', requires_grad=True) 11 | model = dict( 12 | decode_head=dict( 13 | decoder_params=dict( 14 | fusion_cfg=dict( 15 | _delete_=True, 16 | type='aspp', 17 | sep=False, 18 | dilations=(1, 6, 12, 18), 19 | pool=False, 20 | act_cfg=dict(type='ReLU'), 21 | norm_cfg=norm_cfg)))) 22 | -------------------------------------------------------------------------------- /configs/_base_/models/daformer_conv1_mitb5.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # This is the same as SegFormer but with 256 embed_dims 7 | # SegF. with C_e=256 in Tab. 7 8 | 9 | # model settings 10 | norm_cfg = dict(type='BN', requires_grad=True) 11 | find_unused_parameters = True 12 | model = dict( 13 | type='EncoderDecoder', 14 | pretrained='pretrained/mit_b5.pth', 15 | backbone=dict(type='mit_b5', style='pytorch'), 16 | decode_head=dict( 17 | type='DAFormerHead', 18 | in_channels=[64, 128, 320, 512], 19 | in_index=[0, 1, 2, 3], 20 | channels=256, 21 | dropout_ratio=0.1, 22 | num_classes=19, 23 | norm_cfg=norm_cfg, 24 | align_corners=False, 25 | decoder_params=dict( 26 | embed_dims=256, 27 | embed_cfg=dict(type='mlp', act_cfg=None, norm_cfg=None), 28 | embed_neck_cfg=dict(type='mlp', act_cfg=None, norm_cfg=None), 29 | fusion_cfg=dict( 30 | type='conv', 31 | kernel_size=1, 32 | act_cfg=dict(type='ReLU'), 33 | norm_cfg=norm_cfg), 34 | ), 35 | loss_decode=dict( 36 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 37 | # model training and testing settings 38 | train_cfg=dict(), 39 | test_cfg=dict(mode='whole')) 40 | -------------------------------------------------------------------------------- /configs/_base_/models/daformer_isa_mitb5.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # ISA Fusion in Tab. 
7 7 | 8 | _base_ = ['daformer_conv1_mitb5.py'] 9 | 10 | norm_cfg = dict(type='BN', requires_grad=True) 11 | model = dict( 12 | decode_head=dict( 13 | decoder_params=dict( 14 | fusion_cfg=dict( 15 | _delete_=True, 16 | type='isa', 17 | isa_channels=256, 18 | key_query_num_convs=1, 19 | down_factor=(8, 8), 20 | act_cfg=dict(type='ReLU'), 21 | norm_cfg=norm_cfg)))) 22 | -------------------------------------------------------------------------------- /configs/_base_/models/daformer_sepaspp_bottleneck_mitb5.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # Context only at F4 in Tab. 7 7 | 8 | _base_ = ['daformer_conv1_mitb5.py'] 9 | 10 | norm_cfg = dict(type='BN', requires_grad=True) 11 | model = dict( 12 | neck=dict(type='SegFormerAdapter', scales=[8]), 13 | decode_head=dict( 14 | decoder_params=dict( 15 | embed_neck_cfg=dict( 16 | _delete_=True, 17 | type='rawconv_and_aspp', 18 | kernel_size=1, 19 | sep=True, 20 | dilations=(1, 6, 12, 18), 21 | pool=False, 22 | act_cfg=dict(type='ReLU'), 23 | norm_cfg=norm_cfg)))) 24 | -------------------------------------------------------------------------------- /configs/_base_/models/daformer_sepaspp_mitb5.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # DAFormer (with context-aware feature fusion) in Tab. 
7 7 | 8 | _base_ = ['daformer_conv1_mitb5.py'] 9 | 10 | norm_cfg = dict(type='BN', requires_grad=True) 11 | model = dict( 12 | decode_head=dict( 13 | decoder_params=dict( 14 | fusion_cfg=dict( 15 | _delete_=True, 16 | type='aspp', 17 | sep=True, 18 | dilations=(1, 6, 12, 18), 19 | pool=False, 20 | act_cfg=dict(type='ReLU'), 21 | norm_cfg=norm_cfg)))) 22 | -------------------------------------------------------------------------------- /configs/_base_/models/danet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - BN instead of SyncBN 4 | # - Removed auxiliary decoder 5 | 6 | # model settings 7 | norm_cfg = dict(type='BN', requires_grad=True) 8 | model = dict( 9 | type='EncoderDecoder', 10 | pretrained='open-mmlab://resnet50_v1c', 11 | backbone=dict( 12 | type='ResNetV1c', 13 | depth=50, 14 | num_stages=4, 15 | out_indices=(0, 1, 2, 3), 16 | dilations=(1, 1, 2, 4), 17 | strides=(1, 2, 1, 1), 18 | norm_cfg=norm_cfg, 19 | norm_eval=False, 20 | style='pytorch', 21 | contract_dilation=True), 22 | decode_head=dict( 23 | type='DAHead', 24 | in_channels=2048, 25 | in_index=3, 26 | channels=512, 27 | pam_channels=64, 28 | dropout_ratio=0.1, 29 | num_classes=19, 30 | norm_cfg=norm_cfg, 31 | align_corners=False, 32 | loss_decode=dict( 33 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 34 | # model training and testing settings 35 | train_cfg=dict(), 36 | test_cfg=dict(mode='whole')) 37 | -------------------------------------------------------------------------------- /configs/_base_/models/deeplabv2_r50-d8.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # model settings 7 | norm_cfg = dict(type='BN', requires_grad=True) 8 | model = dict( 9 | type='EncoderDecoder', 10 | pretrained='open-mmlab://resnet50_v1c', 11 | backbone=dict( 12 | type='ResNetV1c', 13 | depth=50, 14 | num_stages=4, 15 | out_indices=(0, 1, 2, 3), 16 | dilations=(1, 1, 2, 4), 17 | strides=(1, 2, 1, 1), 18 | norm_cfg=norm_cfg, 19 | norm_eval=False, 20 | style='pytorch', 21 | contract_dilation=True), 22 | decode_head=dict( 23 | type='DLV2Head', 24 | in_channels=2048, 25 | in_index=3, 26 | dilations=(6, 12, 18, 24), 27 | num_classes=19, 28 | align_corners=False, 29 | init_cfg=dict( 30 | type='Normal', std=0.01, override=dict(name='aspp_modules')), 31 | loss_decode=dict( 32 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 33 | # model training and testing settings 34 | train_cfg=dict(), 35 | test_cfg=dict(mode='whole')) 36 | -------------------------------------------------------------------------------- /configs/_base_/models/deeplabv2red_r50-d8.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | _base_ = ['deeplabv2_r50-d8.py'] 7 | # Previous UDA methods only use the dilation rates 6 and 12 for DeepLabV2. 
8 | # This might be a bit hidden as it is caused by a return statement WITHIN 9 | # a loop over the dilation rates: 10 | # https://github.com/wasidennis/AdaptSegNet/blob/fca9ff0f09dab45d44bf6d26091377ac66607028/model/deeplab.py#L116 11 | model = dict(decode_head=dict(dilations=(6, 12))) 12 | -------------------------------------------------------------------------------- /configs/_base_/models/deeplabv3plus_r50-d8.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - BN instead of SyncBN 4 | # - Removed auxiliary decoder 5 | 6 | # model settings 7 | norm_cfg = dict(type='BN', requires_grad=True) 8 | model = dict( 9 | type='EncoderDecoder', 10 | pretrained='open-mmlab://resnet50_v1c', 11 | backbone=dict( 12 | type='ResNetV1c', 13 | depth=50, 14 | num_stages=4, 15 | out_indices=(0, 1, 2, 3), 16 | dilations=(1, 1, 2, 4), 17 | strides=(1, 2, 1, 1), 18 | norm_cfg=norm_cfg, 19 | norm_eval=False, 20 | style='pytorch', 21 | contract_dilation=True), 22 | decode_head=dict( 23 | type='DepthwiseSeparableASPPHead', 24 | in_channels=2048, 25 | in_index=3, 26 | channels=512, 27 | dilations=(1, 12, 24, 36), 28 | c1_in_channels=256, 29 | c1_channels=48, 30 | dropout_ratio=0.1, 31 | num_classes=19, 32 | norm_cfg=norm_cfg, 33 | align_corners=False, 34 | loss_decode=dict( 35 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 36 | # model training and testing settings 37 | train_cfg=dict(), 38 | test_cfg=dict(mode='whole')) 39 | -------------------------------------------------------------------------------- /configs/_base_/models/isanet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - BN instead of SyncBN 4 | # - Removed auxiliary decoder 5 | 6 | # model settings 7 | norm_cfg = dict(type='BN', requires_grad=True) 8 | model = dict( 9 | type='EncoderDecoder', 10 | pretrained='open-mmlab://resnet50_v1c', 11 | backbone=dict( 12 | type='ResNetV1c', 13 | depth=50, 14 | num_stages=4, 15 | out_indices=(0, 1, 2, 3), 16 | dilations=(1, 1, 2, 4), 17 | strides=(1, 2, 1, 1), 18 | norm_cfg=norm_cfg, 19 | norm_eval=False, 20 | style='pytorch', 21 | contract_dilation=True), 22 | decode_head=dict( 23 | type='ISAHead', 24 | in_channels=2048, 25 | in_index=3, 26 | channels=512, 27 | isa_channels=256, 28 | down_factor=(8, 8), 29 | dropout_ratio=0.1, 30 | num_classes=19, 31 | norm_cfg=norm_cfg, 32 | align_corners=False, 33 | loss_decode=dict( 34 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 35 | # model training and testing settings 36 | train_cfg=dict(), 37 | test_cfg=dict(mode='whole')) 38 | -------------------------------------------------------------------------------- /configs/_base_/models/segformer.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/NVlabs/SegFormer 2 | # Modifications: BN instead of SyncBN 3 | # This work is licensed under the NVIDIA Source Code License 4 | # A copy of the license is available at resources/license_segformer 5 | 6 | # model settings 7 | norm_cfg = dict(type='BN', requires_grad=True) 8 | find_unused_parameters = True 9 | model = dict( 10 | type='EncoderDecoder', 11 | pretrained=None, 12 | backbone=dict(type='IMTRv21_5', style='pytorch'), 13 | decode_head=dict( 14 | type='SegFormerHead', 15 | 
in_channels=[64, 128, 320, 512], 16 | in_index=[0, 1, 2, 3], 17 | channels=128, 18 | dropout_ratio=0.1, 19 | num_classes=19, 20 | norm_cfg=norm_cfg, 21 | align_corners=False, 22 | decoder_params=dict(), 23 | loss_decode=dict( 24 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 25 | # model training and testing settings 26 | train_cfg=dict(), 27 | test_cfg=dict(mode='whole')) 28 | -------------------------------------------------------------------------------- /configs/_base_/models/segformer_b5.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/NVlabs/SegFormer 2 | # Modifications: BN instead of SyncBN 3 | # This work is licensed under the NVIDIA Source Code License 4 | # A copy of the license is available at resources/license_segformer 5 | 6 | _base_ = ['../../_base_/models/segformer.py'] 7 | 8 | # model settings 9 | norm_cfg = dict(type='BN', requires_grad=True) 10 | find_unused_parameters = True 11 | model = dict( 12 | type='EncoderDecoder', 13 | pretrained='pretrained/mit_b5.pth', 14 | backbone=dict(type='mit_b5', style='pytorch'), 15 | decode_head=dict( 16 | type='SegFormerHead', 17 | in_channels=[64, 128, 320, 512], 18 | in_index=[0, 1, 2, 3], 19 | channels=128, 20 | dropout_ratio=0.1, 21 | num_classes=19, 22 | norm_cfg=norm_cfg, 23 | align_corners=False, 24 | decoder_params=dict(embed_dim=768, conv_kernel_size=1), 25 | loss_decode=dict( 26 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 27 | # model training and testing settings 28 | train_cfg=dict(), 29 | test_cfg=dict(mode='whole')) 30 | -------------------------------------------------------------------------------- /configs/_base_/models/segformer_r101.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/NVlabs/SegFormer 2 | # Modifications: 3 | # - BN instead of SyncBN 4 | # - Replace MiT with ResNet backbone 5 | # This work is licensed under the NVIDIA Source Code License 6 | # A copy of the license is available at resources/license_segformer 7 | 8 | _base_ = ['../../_base_/models/segformer.py'] 9 | 10 | # model settings 11 | norm_cfg = dict(type='BN', requires_grad=True) 12 | find_unused_parameters = True 13 | model = dict( 14 | type='EncoderDecoder', 15 | pretrained='open-mmlab://resnet101_v1c', 16 | backbone=dict( 17 | type='ResNetV1c', 18 | depth=101, 19 | num_stages=4, 20 | out_indices=(0, 1, 2, 3), 21 | dilations=(1, 1, 2, 4), 22 | strides=(1, 2, 1, 1), 23 | norm_cfg=norm_cfg, 24 | norm_eval=False, 25 | style='pytorch', 26 | contract_dilation=True), 27 | decode_head=dict( 28 | type='SegFormerHead', 29 | in_channels=[256, 512, 1024, 2048], 30 | in_index=[0, 1, 2, 3], 31 | channels=128, 32 | dropout_ratio=0.1, 33 | num_classes=19, 34 | norm_cfg=norm_cfg, 35 | align_corners=False, 36 | decoder_params=dict(embed_dim=768, conv_kernel_size=1), 37 | loss_decode=dict( 38 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 39 | # model training and testing settings 40 | train_cfg=dict(), 41 | test_cfg=dict(mode='whole')) 42 | -------------------------------------------------------------------------------- /configs/_base_/models/upernet_ch256_mit.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | _base_ = ['upernet_mit.py'] 7 | 8 | model = dict(decode_head=dict(channels=256, )) 9 | -------------------------------------------------------------------------------- /configs/_base_/models/upernet_mit.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - BN instead of SyncBN 4 | # - Removed auxiliary decoder 5 | # - Replace ResNet encoder with MiT encoder 6 | 7 | # model settings 8 | norm_cfg = dict(type='BN', requires_grad=True) 9 | find_unused_parameters = True 10 | model = dict( 11 | type='EncoderDecoder', 12 | pretrained=None, 13 | backbone=dict(type='IMTRv21_5', style='pytorch'), 14 | decode_head=dict( 15 | type='UPerHead', 16 | in_channels=[64, 128, 320, 512], 17 | in_index=[0, 1, 2, 3], 18 | pool_scales=(1, 2, 3, 6), 19 | channels=512, 20 | dropout_ratio=0.1, 21 | num_classes=19, 22 | norm_cfg=norm_cfg, 23 | align_corners=False, 24 | loss_decode=dict( 25 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 26 | # model training and testing settings 27 | train_cfg=dict(), 28 | test_cfg=dict(mode='whole')) 29 | -------------------------------------------------------------------------------- /configs/_base_/schedules/adamw.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | # optimizer 4 | optimizer = dict( 5 | type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01) 6 | optimizer_config = dict() 7 | -------------------------------------------------------------------------------- /configs/_base_/schedules/poly10.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # learning policy 7 | lr_config = dict(policy='poly', power=1.0, min_lr=1e-4, by_epoch=False) 8 | -------------------------------------------------------------------------------- /configs/_base_/schedules/poly10warm.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # learning policy 7 | lr_config = dict( 8 | policy='poly', 9 | warmup='linear', 10 | warmup_iters=1500, 11 | warmup_ratio=1e-6, 12 | power=1.0, 13 | min_lr=0.0, 14 | by_epoch=False) 15 | -------------------------------------------------------------------------------- /configs/_base_/uda/dacs.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # Baseline UDA 7 | uda = dict( 8 | type='DACS', 9 | alpha=0.99, 10 | pseudo_threshold=0.968, 11 | pseudo_weight_ignore_top=0, 12 | pseudo_weight_ignore_bottom=0, 13 | imnet_feature_dist_lambda=0, 14 | imnet_feature_dist_classes=None, 15 | imnet_feature_dist_scale_min_ratio=None, 16 | mix='class', 17 | blur=True, 18 | color_jitter_strength=0.2, 19 | color_jitter_probability=0.2, 20 | debug_img_interval=1000, 21 | print_grad_magnitude=False, 22 | ) 23 | use_ddp_wrapper = True 24 | -------------------------------------------------------------------------------- /configs/_base_/uda/dacs_a999_fdthings.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # UDA with Thing-Class ImageNet Feature Distance + Increased Alpha 7 | _base_ = ['dacs.py'] 8 | uda = dict( 9 | alpha=0.999, 10 | imnet_feature_dist_lambda=0.005, 11 | imnet_feature_dist_classes=[6, 7, 11, 12, 13, 14, 15, 16, 17, 18], 12 | imnet_feature_dist_scale_min_ratio=0.75, 13 | ) 14 | -------------------------------------------------------------------------------- /configs/_base_/uda/dacs_cda.py: -------------------------------------------------------------------------------- 1 | # Baseline UDA 2 | uda = dict( 3 | type='DACS_CDA', 4 | alpha=0.99, 5 | pseudo_threshold=0.968, 6 | pseudo_weight_ignore_top=0, 7 | pseudo_weight_ignore_bottom=0, 8 | imnet_feature_dist_lambda=0, 9 | imnet_feature_dist_classes=None, 10 | imnet_feature_dist_scale_min_ratio=None, 11 | mix='class', 12 | blur=True, 13 | color_jitter_strength=0.2, 14 | color_jitter_probability=0.2, 15 | debug_img_interval=1000, 16 | print_grad_magnitude=False, 17 | cda_level=None, 18 | cda_tgt_lambda=1, 19 | cheat_level={'attn': True, 'output': True}, 20 | dacs_temp=-1, 21 | attn_sup_arch=None, 22 | attn_sup_patch_size=None, 23 | attn_sup_classes=None, 24 | valid_masking=False, 25 | attn_src_lambda="lambda x, y:0", 26 | attn_tgt_lambda="lambda x, y:0", 27 | attn_s2t_lambda="lambda x, y:0", 28 | attn_t2s_lambda="lambda x, y:0", 29 | attn_lambda=0, 30 | src_attn_tea=False, 31 | ) 32 | use_ddp_wrapper = True 33 | -------------------------------------------------------------------------------- /configs/_base_/uda/dacs_fd.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # UDA with ImageNet Feature Distance 7 | _base_ = ['dacs.py'] 8 | uda = dict(imnet_feature_dist_lambda=0.005, ) 9 | -------------------------------------------------------------------------------- /configs/_base_/uda/dacs_fdthings.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # UDA with Thing-Class ImageNet Feature Distance 7 | _base_ = ['dacs.py'] 8 | uda = dict( 9 | imnet_feature_dist_lambda=0.005, 10 | imnet_feature_dist_classes=[6, 7, 11, 12, 13, 14, 15, 16, 17, 18], 11 | imnet_feature_dist_scale_method='ratio', 12 | imnet_feature_dist_scale_min_ratio=0.75, 13 | ) 14 | -------------------------------------------------------------------------------- /configs/cdac/cs2acdc_uda_dacs_cda_mitb5_b2_s0.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/default_runtime.py', 3 | # DAFormer Network Architecture 4 | '../_base_/models/daformer_sepaspp_mitb5.py', 5 | # Cityscapes->ACDC Data Loading 6 | '../_base_/datasets/uda_cityscapes_to_acdc_512x512.py', 7 | # Basic UDA Self-Training 8 | '../_base_/uda/dacs_cda.py', 9 | # AdamW Optimizer 10 | '../_base_/schedules/adamw.py', 11 | # Linear Learning Rate Warmup with Subsequent Linear Decay 12 | '../_base_/schedules/poly10warm.py' 13 | ] 14 | # Random Seed 15 | seed = 0 16 | model = dict( 17 | backbone=dict( 18 | switch_element='q', 19 | detach=True, 20 | ) 21 | ) 22 | # Modifications to Basic UDA 23 | uda = dict( 24 | # Increased Alpha 25 | alpha=0.999, 26 | # Thing-Class Feature Distance 27 | imnet_feature_dist_lambda=0.005, 28 | imnet_feature_dist_classes=[6, 7, 11, 12, 13, 14, 15, 16, 17, 18], 29 | imnet_feature_dist_scale_min_ratio=0.75, 30 | cda_src_lambda="lambda x, y:0.5", 31 | cda_tgt_lambda="lambda x, y:0.5", 32 | cda_s2t_lambda="lambda x, y:0.5", 33 | cda_t2s_lambda="lambda x, y:0.5", 34 | attn_lambda=1, 35 | branch=4, 36 | src_attn_tea=True, 37 | cheat_level={'attn': False, 'output': False}, 38 | valid_masking=True, 39 | # Pseudo-Label Crop 40 | pseudo_weight_ignore_top=15, 41 | pseudo_weight_ignore_bottom=120) 42 | data = dict( 43 | train=dict( 44 | # Rare Class Sampling 45 | rare_class_sampling=dict( 46 | min_pixels=3000, class_temp=0.01, min_crop_ratio=0.5))) 47 | # Optimizer Hyperparameters 48 | optimizer_config = None 49 | optimizer = dict( 50 | lr=6e-05, 51 | paramwise_cfg=dict( 52 | custom_keys=dict( 53 | head=dict(lr_mult=10.0), 54 | pos_block=dict(decay_mult=0.0), 55 | norm=dict(decay_mult=0.0)))) 56 | n_gpus = 1 57 | runner = dict(type='IterBasedRunner', max_iters=40000) 58 | # Logging Configuration 59 | checkpoint_config = dict(by_epoch=False, interval=40000, max_keep_ckpts=1) 60 | evaluation = dict(interval=4000, metric='mIoU') 61 | # Meta Information for Result Analysis 62 | name = 'cs2acdc_uda_dacs_cda_mitb5_b2_s0' 63 | exp = 'basic' 64 | name_dataset = 'cityscapes2acdc' 65 | name_architecture = 'daformer_sepaspp_mitb5' 66 | name_encoder = 'mitb5' 67 | name_decoder = 'daformer_sepaspp' 68 | name_uda = 'dacs_cda_a999_feat_reg_0.01_cpl' 69 | name_opt = 'adamw_6e-05_pmTrue_poly10warm_1x2_40k' 70 | -------------------------------------------------------------------------------- /configs/cdac/gta2cs_uda_dacs_cda_mitb5_b2_s0.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/default_runtime.py', 3 | # DAFormer Network Architecture 4 | '../_base_/models/daformer_sepaspp_mitb5.py', 5 | # GTA->Cityscapes Data Loading 6 | '../_base_/datasets/uda_gta_to_cityscapes_512x512.py', 7 | # Basic UDA Self-Training 8 | '../_base_/uda/dacs_cda.py', 9 | # AdamW Optimizer 10 | '../_base_/schedules/adamw.py', 11 | # Linear Learning Rate Warmup with Subsequent Linear 
Decay 12 | '../_base_/schedules/poly10warm.py' 13 | ] 14 | # Random Seed 15 | seed = 0 16 | model = dict( 17 | backbone=dict( 18 | switch_element='q', 19 | detach=True, 20 | ) 21 | ) 22 | # Modifications to Basic UDA 23 | uda = dict( 24 | # Increased Alpha 25 | alpha=0.999, 26 | # Thing-Class Feature Distance 27 | imnet_feature_dist_lambda=0.005, 28 | imnet_feature_dist_classes=[6, 7, 11, 12, 13, 14, 15, 16, 17, 18], 29 | imnet_feature_dist_scale_min_ratio=0.75, 30 | cda_src_lambda="lambda x, y:0.5", 31 | cda_tgt_lambda="lambda x, y:0.5", 32 | cda_s2t_lambda="lambda x, y:0.5", 33 | cda_t2s_lambda="lambda x, y:0.5", 34 | attn_lambda=1, 35 | branch=4, 36 | src_attn_tea=True, 37 | cheat_level={'attn': False, 'output': False}, 38 | valid_masking=True, 39 | # Pseudo-Label Crop 40 | pseudo_weight_ignore_top=15, 41 | pseudo_weight_ignore_bottom=120) 42 | data = dict( 43 | train=dict( 44 | # Rare Class Sampling 45 | rare_class_sampling=dict( 46 | min_pixels=3000, class_temp=0.01, min_crop_ratio=0.5))) 47 | # Optimizer Hyperparameters 48 | optimizer_config = None 49 | optimizer = dict( 50 | lr=6e-05, 51 | paramwise_cfg=dict( 52 | custom_keys=dict( 53 | head=dict(lr_mult=10.0), 54 | pos_block=dict(decay_mult=0.0), 55 | norm=dict(decay_mult=0.0)))) 56 | n_gpus = 1 57 | runner = dict(type='IterBasedRunner', max_iters=40000) 58 | # Logging Configuration 59 | checkpoint_config = dict(by_epoch=False, interval=40000, max_keep_ckpts=1) 60 | evaluation = dict(interval=4000, metric='mIoU') 61 | # Meta Information for Result Analysis 62 | name = 'gta2cs_uda_dacs_cda_mitb5_b2_s0' 63 | exp = 'basic' 64 | name_dataset = 'gta2cityscapes' 65 | name_architecture = 'daformer_sepaspp_mitb5' 66 | name_encoder = 'mitb5' 67 | name_decoder = 'daformer_sepaspp' 68 | name_uda = 'dacs_cda_a999_feat_reg_0.01_cpl' 69 | name_opt = 'adamw_6e-05_pmTrue_poly10warm_1x2_40k' 70 | -------------------------------------------------------------------------------- /configs/cdac/synthia2cs_uda_dacs_cda_mitb5_b2_s0.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/default_runtime.py', 3 | # DAFormer Network Architecture 4 | '../_base_/models/daformer_sepaspp_mitb5.py', 5 | # synthia->Cityscapes Data Loading 6 | '../_base_/datasets/uda_synthia_to_cityscapes_512x512.py', 7 | # Basic UDA Self-Training 8 | '../_base_/uda/dacs_cda.py', 9 | # AdamW Optimizer 10 | '../_base_/schedules/adamw.py', 11 | # Linear Learning Rate Warmup with Subsequent Linear Decay 12 | '../_base_/schedules/poly10warm.py' 13 | ] 14 | # Random Seed 15 | seed = 0 16 | model = dict( 17 | backbone=dict( 18 | switch_element='q', 19 | detach=True, 20 | ) 21 | ) 22 | # Modifications to Basic UDA 23 | uda = dict( 24 | # Increased Alpha 25 | alpha=0.999, 26 | # Thing-Class Feature Distance 27 | imnet_feature_dist_lambda=0.005, 28 | imnet_feature_dist_classes=[6, 7, 11, 12, 13, 14, 15, 16, 17, 18], 29 | imnet_feature_dist_scale_min_ratio=0.75, 30 | cda_src_lambda="lambda x, y:0.5", 31 | cda_tgt_lambda="lambda x, y:0.5", 32 | cda_s2t_lambda="lambda x, y:0.5", 33 | cda_t2s_lambda="lambda x, y:0.5", 34 | attn_lambda=1, 35 | branch=4, 36 | src_attn_tea=True, 37 | cheat_level={'attn': False, 'output': False}, 38 | valid_masking=True, 39 | # Pseudo-Label Crop 40 | pseudo_weight_ignore_top=15, 41 | pseudo_weight_ignore_bottom=120) 42 | data = dict( 43 | train=dict( 44 | # Rare Class Sampling 45 | rare_class_sampling=dict( 46 | min_pixels=3000, class_temp=0.01, min_crop_ratio=0.5))) 47 | # Optimizer Hyperparameters 48 | 
optimizer_config = None 49 | optimizer = dict( 50 | lr=6e-05, 51 | paramwise_cfg=dict( 52 | custom_keys=dict( 53 | head=dict(lr_mult=10.0), 54 | pos_block=dict(decay_mult=0.0), 55 | norm=dict(decay_mult=0.0)))) 56 | n_gpus = 1 57 | runner = dict(type='IterBasedRunner', max_iters=40000) 58 | # Logging Configuration 59 | checkpoint_config = dict(by_epoch=False, interval=40000, max_keep_ckpts=1) 60 | evaluation = dict(interval=4000, metric='mIoU') 61 | # Meta Information for Result Analysis 62 | name = 'synthia2cs_uda_dacs_cda_mitb5_b2_s0' 63 | exp = 'basic' 64 | name_dataset = 'synthia2cityscapes' 65 | name_architecture = 'daformer_sepaspp_mitb5' 66 | name_encoder = 'mitb5' 67 | name_decoder = 'daformer_sepaspp' 68 | name_uda = 'dacs_cda_a999_feat_reg_0.01_cpl' 69 | name_opt = 'adamw_6e-05_pmTrue_poly10warm_1x2_40k' 70 | -------------------------------------------------------------------------------- /configs/daformer/gta2cs_uda_warm_fdthings_rcs_croppl_a999_daformer_mitb5_s0.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | _base_ = [ 7 | '../_base_/default_runtime.py', 8 | # DAFormer Network Architecture 9 | '../_base_/models/daformer_sepaspp_mitb5.py', 10 | # GTA->Cityscapes Data Loading 11 | '../_base_/datasets/uda_gta_to_cityscapes_512x512.py', 12 | # Basic UDA Self-Training 13 | '../_base_/uda/dacs.py', 14 | # AdamW Optimizer 15 | '../_base_/schedules/adamw.py', 16 | # Linear Learning Rate Warmup with Subsequent Linear Decay 17 | '../_base_/schedules/poly10warm.py' 18 | ] 19 | # Random Seed 20 | seed = 0 21 | # Modifications to Basic UDA 22 | uda = dict( 23 | # Increased Alpha 24 | alpha=0.999, 25 | # Thing-Class Feature Distance 26 | imnet_feature_dist_lambda=0.005, 27 | imnet_feature_dist_classes=[6, 7, 11, 12, 13, 14, 15, 16, 17, 18], 28 | imnet_feature_dist_scale_min_ratio=0.75, 29 | # Pseudo-Label Crop 30 | pseudo_weight_ignore_top=15, 31 | pseudo_weight_ignore_bottom=120) 32 | data = dict( 33 | train=dict( 34 | # Rare Class Sampling 35 | rare_class_sampling=dict( 36 | min_pixels=3000, class_temp=0.01, min_crop_ratio=0.5))) 37 | # Optimizer Hyperparameters 38 | optimizer_config = None 39 | optimizer = dict( 40 | lr=6e-05, 41 | paramwise_cfg=dict( 42 | custom_keys=dict( 43 | head=dict(lr_mult=10.0), 44 | pos_block=dict(decay_mult=0.0), 45 | norm=dict(decay_mult=0.0)))) 46 | n_gpus = 1 47 | runner = dict(type='IterBasedRunner', max_iters=40000) 48 | # Logging Configuration 49 | checkpoint_config = dict(by_epoch=False, interval=40000, max_keep_ckpts=1) 50 | evaluation = dict(interval=4000, metric='mIoU') 51 | # Meta Information for Result Analysis 52 | name = 'gta2cs_uda_warm_fdthings_rcs_croppl_a999_daformer_mitb5_s0' 53 | exp = 'basic' 54 | name_dataset = 'gta2cityscapes' 55 | name_architecture = 'daformer_sepaspp_mitb5' 56 | name_encoder = 'mitb5' 57 | name_decoder = 'daformer_sepaspp' 58 | name_uda = 'dacs_a999_fd_things_rcs0.01_cpl' 59 | name_opt = 'adamw_6e-05_pmTrue_poly10warm_1x2_40k' 60 | -------------------------------------------------------------------------------- /demo/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangkaihong/CDAC/160e3328cae8fb9a61b71529ea562251e120ae34/demo/demo.png 
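The demo script that follows (`demo/image_demo.py`) can be run on this image from the repository root. A usage sketch (the checkpoint path is a placeholder; `-m` is used so that the script's `tools.test` import resolves):

```shell
# Run the demo on the bundled image; the checkpoint path is a placeholder.
python -m demo.image_demo demo/demo.png \
    configs/cdac/gta2cs_uda_dacs_cda_mitb5_b2_s0.py \
    path/to/checkpoint.pth \
    --palette cityscapes --opacity 0.5
# The prediction is saved next to the input, e.g. demo/demo_pred.png.
```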
-------------------------------------------------------------------------------- /demo/image_demo.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - Config and checkpoint update 4 | # - Saving instead of showing prediction 5 | 6 | import os 7 | from argparse import ArgumentParser 8 | 9 | import mmcv 10 | from tools.test import update_legacy_cfg 11 | 12 | from mmseg.apis import inference_segmentor, init_segmentor 13 | from mmseg.core.evaluation import get_classes, get_palette 14 | 15 | 16 | def main(): 17 | parser = ArgumentParser() 18 | parser.add_argument('img', help='Image file') 19 | parser.add_argument('config', help='Config file') 20 | parser.add_argument('checkpoint', help='Checkpoint file') 21 | parser.add_argument( 22 | '--device', default='cuda:0', help='Device used for inference') 23 | parser.add_argument( 24 | '--palette', 25 | default='cityscapes', 26 | help='Color palette used for segmentation map') 27 | parser.add_argument( 28 | '--opacity', 29 | type=float, 30 | default=0.5, 31 | help='Opacity of painted segmentation map. In (0, 1] range.') 32 | args = parser.parse_args() 33 | 34 | # build the model from a config file and a checkpoint file 35 | cfg = mmcv.Config.fromfile(args.config) 36 | cfg = update_legacy_cfg(cfg) 37 | model = init_segmentor( 38 | cfg, 39 | args.checkpoint, 40 | device=args.device, 41 | classes=get_classes(args.palette), 42 | palette=get_palette(args.palette), 43 | revise_checkpoint=[(r'^module\.', ''), ('model.', '')]) 44 | # test a single image 45 | result = inference_segmentor(model, args.img) 46 | # show the results 47 | file, extension = os.path.splitext(args.img) 48 | pred_file = f'{file}_pred{extension}' 49 | assert pred_file != args.img 50 | model.show_result( 51 | args.img, 52 | result, 53 | palette=get_palette(args.palette), 54 | out_file=pred_file, 55 | show=False, 56 | opacity=args.opacity) 57 | print('Save prediction to', pred_file) 58 | 59 | 60 | if __name__ == '__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /mmseg/__init__.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | 3 | from .version import __version__, version_info 4 | 5 | MMCV_MIN = '1.3.7' 6 | MMCV_MAX = '1.4.0' 7 | 8 | 9 | def digit_version(version_str): 10 | digit_version = [] 11 | for x in version_str.split('.'): 12 | if x.isdigit(): 13 | digit_version.append(int(x)) 14 | elif x.find('rc') != -1: 15 | patch_version = x.split('rc') 16 | digit_version.append(int(patch_version[0]) - 1) 17 | digit_version.append(int(patch_version[1])) 18 | return digit_version 19 | 20 | 21 | mmcv_min_version = digit_version(MMCV_MIN) 22 | mmcv_max_version = digit_version(MMCV_MAX) 23 | mmcv_version = digit_version(mmcv.__version__) 24 | 25 | 26 | assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 27 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 28 | f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' 
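# Example of the version parsing above (values derived from the function
# itself): digit_version('1.3.7') == [1, 3, 7], while
# digit_version('1.4.0rc1') == [1, 4, -1, 1], so release candidates
# compare lower than the corresponding final release.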
29 | 30 | __all__ = ['__version__', 'version_info'] 31 | -------------------------------------------------------------------------------- /mmseg/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .inference import inference_segmentor, init_segmentor, show_result_pyplot 4 | from .test import multi_gpu_test, single_gpu_test 5 | from .train import get_root_logger, set_random_seed, train_segmentor 6 | 7 | __all__ = [ 8 | 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 9 | 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 10 | 'show_result_pyplot' 11 | ] 12 | -------------------------------------------------------------------------------- /mmseg/apis/inference.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - Override palette, classes, and state dict keys 4 | 5 | import matplotlib.pyplot as plt 6 | import mmcv 7 | import torch 8 | from mmcv.parallel import collate, scatter 9 | from mmcv.runner import load_checkpoint 10 | 11 | from mmseg.datasets.pipelines import Compose 12 | from mmseg.models import build_segmentor 13 | 14 | 15 | def init_segmentor(config, 16 | checkpoint=None, 17 | device='cuda:0', 18 | classes=None, 19 | palette=None, 20 | revise_checkpoint=[(r'^module\.', '')]): 21 | """Initialize a segmentor from config file. 22 | 23 | Args: 24 | config (str or :obj:`mmcv.Config`): Config file path or the config 25 | object. 26 | checkpoint (str, optional): Checkpoint path. If left as None, the model 27 | will not load any weights. 28 | device (str, optional): CPU/CUDA device option. Default 'cuda:0'. 29 | Use 'cpu' to load the model on CPU. 30 | Returns: 31 | nn.Module: The constructed segmentor. 32 | """ 33 | if isinstance(config, str): 34 | config = mmcv.Config.fromfile(config) 35 | elif not isinstance(config, mmcv.Config): 36 | raise TypeError('config must be a filename or Config object, ' 37 | 'but got {}'.format(type(config))) 38 | config.model.pretrained = None 39 | config.model.train_cfg = None 40 | model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) 41 | if checkpoint is not None: 42 | checkpoint = load_checkpoint( 43 | model, 44 | checkpoint, 45 | map_location='cpu', 46 | revise_keys=revise_checkpoint) 47 | model.CLASSES = checkpoint['meta']['CLASSES'] if classes is None \ 48 | else classes 49 | model.PALETTE = checkpoint['meta']['PALETTE'] if palette is None \ 50 | else palette 51 | model.cfg = config # save the config in the model for convenience 52 | model.to(device) 53 | model.eval() 54 | return model 55 | 56 | 57 | class LoadImage: 58 | """A simple pipeline to load image.""" 59 | 60 | def __call__(self, results): 61 | """Call function to load images into results. 62 | 63 | Args: 64 | results (dict): A result dict contains the file name 65 | of the image to be read. 66 | 67 | Returns: 68 | dict: ``results`` will be returned containing loaded image.
69 | """ 70 | 71 | if isinstance(results['img'], str): 72 | results['filename'] = results['img'] 73 | results['ori_filename'] = results['img'] 74 | else: 75 | results['filename'] = None 76 | results['ori_filename'] = None 77 | img = mmcv.imread(results['img']) 78 | results['img'] = img 79 | results['img_shape'] = img.shape 80 | results['ori_shape'] = img.shape 81 | return results 82 | 83 | 84 | def inference_segmentor(model, img, img2=None): 85 | """Inference image(s) with the segmentor. 86 | 87 | Args: 88 | model (nn.Module): The loaded segmentor. 89 | imgs (str/ndarray or list[str/ndarray]): Either image files or loaded 90 | images. 91 | 92 | Returns: 93 | (list[Tensor]): The segmentation result. 94 | """ 95 | cfg = model.cfg 96 | device = next(model.parameters()).device # model device 97 | # build the data pipeline 98 | test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] 99 | test_pipeline = Compose(test_pipeline) 100 | # prepare data 101 | data = dict(img=img) 102 | data = test_pipeline(data) 103 | data = collate([data], samples_per_gpu=1) 104 | if next(model.parameters()).is_cuda: 105 | # scatter to specified GPU 106 | data = scatter(data, [device])[0] 107 | else: 108 | data['img_metas'] = [i.data[0] for i in data['img_metas']] 109 | 110 | if img2 is not None: 111 | data_cd = dict(img=img2) 112 | data_cd = test_pipeline(data_cd) 113 | data_cd = collate([data_cd], samples_per_gpu=1) 114 | if next(model.parameters()).is_cuda: 115 | # scatter to specified GPU 116 | data_cd = scatter(data_cd, [device])[0] 117 | else: 118 | data_cd['img_metas'] = [i.data_cd[0] for i in data_cd['img_metas']] 119 | data['img_cd'] = data_cd['img'] 120 | data['img_cd_metas'] = data_cd['img_metas'] 121 | # forward the model 122 | with torch.no_grad(): 123 | result = model(return_loss=False, rescale=True, **data) 124 | return result 125 | 126 | 127 | def show_result_pyplot(model, 128 | img, 129 | result, 130 | palette=None, 131 | fig_size=(15, 10), 132 | opacity=0.5, 133 | title='', 134 | block=True): 135 | """Visualize the segmentation results on the image. 136 | 137 | Args: 138 | model (nn.Module): The loaded segmentor. 139 | img (str or np.ndarray): Image filename or loaded image. 140 | result (list): The segmentation result. 141 | palette (list[list[int]]] | None): The palette of segmentation 142 | map. If None is given, random palette will be generated. 143 | Default: None 144 | fig_size (tuple): Figure size of the pyplot figure. 145 | opacity(float): Opacity of painted segmentation map. 146 | Default 0.5. 147 | Must be in (0, 1] range. 148 | title (str): The title of pyplot figure. 149 | Default is ''. 150 | block (bool): Whether to block the pyplot figure. 151 | Default is True. 
152 | """ 153 | if hasattr(model, 'module'): 154 | model = model.module 155 | img = model.show_result( 156 | img, result, palette=palette, show=False, opacity=opacity) 157 | plt.figure(figsize=fig_size) 158 | plt.imshow(mmcv.bgr2rgb(img)) 159 | plt.title(title) 160 | plt.tight_layout() 161 | plt.show(block=block) 162 | -------------------------------------------------------------------------------- /mmseg/apis/test.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import os.path as osp 4 | import tempfile 5 | 6 | import mmcv 7 | import numpy as np 8 | import torch 9 | from mmcv.engine import collect_results_cpu, collect_results_gpu 10 | from mmcv.image import tensor2imgs 11 | from mmcv.runner import get_dist_info 12 | 13 | 14 | def np2tmp(array, temp_file_name=None, tmpdir=None): 15 | """Save ndarray to local numpy file. 16 | 17 | Args: 18 | array (ndarray): Ndarray to save. 19 | temp_file_name (str): Numpy file name. If 'temp_file_name=None', this 20 | function will generate a file name with tempfile.NamedTemporaryFile 21 | to save ndarray. Default: None. 22 | tmpdir (str): Temporary directory to save Ndarray files. Default: None. 23 | 24 | Returns: 25 | str: The numpy file name. 26 | """ 27 | 28 | if temp_file_name is None: 29 | temp_file_name = tempfile.NamedTemporaryFile( 30 | suffix='.npy', delete=False, dir=tmpdir).name 31 | np.save(temp_file_name, array) 32 | return temp_file_name 33 | 34 | 35 | def single_gpu_test(model, 36 | data_loader, 37 | show=False, 38 | out_dir=None, 39 | efficient_test=False, 40 | opacity=0.5): 41 | """Test with single GPU. 42 | 43 | Args: 44 | model (nn.Module): Model to be tested. 45 | data_loader (utils.data.Dataloader): Pytorch data loader. 46 | show (bool): Whether show results during inference. Default: False. 47 | out_dir (str, optional): If specified, the results will be dumped into 48 | the directory to save output results. 49 | efficient_test (bool): Whether save the results as local numpy files to 50 | save CPU memory during evaluation. Default: False. 51 | opacity(float): Opacity of painted segmentation map. 52 | Default 0.5. 53 | Must be in (0, 1] range. 54 | Returns: 55 | list: The prediction results. 
56 | """ 57 | 58 | model.eval() 59 | results = [] 60 | dataset = data_loader.dataset 61 | prog_bar = mmcv.ProgressBar(len(dataset)) 62 | if efficient_test: 63 | mmcv.mkdir_or_exist('.efficient_test') 64 | for i, data in enumerate(data_loader): 65 | with torch.no_grad(): 66 | result = model(return_loss=False, **data) 67 | 68 | if show or out_dir: 69 | img_tensor = data['img'][0] 70 | img_metas = data['img_metas'][0].data[0] 71 | imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) 72 | assert len(imgs) == len(img_metas) 73 | 74 | for img, img_meta in zip(imgs, img_metas): 75 | h, w, _ = img_meta['img_shape'] 76 | img_show = img[:h, :w, :] 77 | 78 | ori_h, ori_w = img_meta['ori_shape'][:-1] 79 | img_show = mmcv.imresize(img_show, (ori_w, ori_h)) 80 | 81 | if out_dir: 82 | out_file = osp.join(out_dir, img_meta['ori_filename']) 83 | else: 84 | out_file = None 85 | 86 | model.module.show_result( 87 | img_show, 88 | result, 89 | palette=dataset.PALETTE, 90 | show=show, 91 | out_file=out_file, 92 | opacity=opacity) 93 | 94 | if isinstance(result, list): 95 | if efficient_test: 96 | result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] 97 | results.extend(result) 98 | else: 99 | if efficient_test: 100 | result = np2tmp(result, tmpdir='.efficient_test') 101 | results.append(result) 102 | 103 | batch_size = len(result) 104 | for _ in range(batch_size): 105 | prog_bar.update() 106 | return results 107 | 108 | 109 | def multi_gpu_test(model, 110 | data_loader, 111 | tmpdir=None, 112 | gpu_collect=False, 113 | efficient_test=False): 114 | """Test model with multiple gpus. 115 | 116 | This method tests model with multiple gpus and collects the results 117 | under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' 118 | it encodes results to gpu tensors and use gpu communication for results 119 | collection. On cpu mode it saves the results on different gpus to 'tmpdir' 120 | and collects them by the rank 0 worker. 121 | 122 | Args: 123 | model (nn.Module): Model to be tested. 124 | data_loader (utils.data.Dataloader): Pytorch data loader. 125 | tmpdir (str): Path of directory to save the temporary results from 126 | different gpus under cpu mode. The same path is used for efficient 127 | test. 128 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 129 | efficient_test (bool): Whether save the results as local numpy files to 130 | save CPU memory during evaluation. Default: False. 131 | 132 | Returns: 133 | list: The prediction results. 
134 | """ 135 | 136 | model.eval() 137 | results = [] 138 | dataset = data_loader.dataset 139 | rank, world_size = get_dist_info() 140 | if rank == 0: 141 | prog_bar = mmcv.ProgressBar(len(dataset)) 142 | if efficient_test: 143 | mmcv.mkdir_or_exist('.efficient_test') 144 | for i, data in enumerate(data_loader): 145 | with torch.no_grad(): 146 | result = model(return_loss=False, rescale=True, **data) 147 | 148 | if isinstance(result, list): 149 | if efficient_test: 150 | result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] 151 | results.extend(result) 152 | else: 153 | if efficient_test: 154 | result = np2tmp(result, tmpdir='.efficient_test') 155 | results.append(result) 156 | 157 | if rank == 0: 158 | batch_size = len(result) 159 | for _ in range(batch_size * world_size): 160 | prog_bar.update() 161 | 162 | # collect results from all ranks 163 | if gpu_collect: 164 | results = collect_results_gpu(results, len(dataset)) 165 | else: 166 | results = collect_results_cpu(results, len(dataset), tmpdir) 167 | return results 168 | -------------------------------------------------------------------------------- /mmseg/apis/train.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - Add ddp_wrapper from mmgen 4 | 5 | import random 6 | import warnings 7 | 8 | import mmcv 9 | import numpy as np 10 | import torch 11 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 12 | from mmcv.runner import build_optimizer, build_runner 13 | 14 | from mmseg.core import DistEvalHook, EvalHook 15 | from mmseg.core.ddp_wrapper import DistributedDataParallelWrapper 16 | from mmseg.datasets import build_dataloader, build_dataset 17 | from mmseg.utils import get_root_logger 18 | 19 | 20 | def set_random_seed(seed, deterministic=False): 21 | """Set random seed. 22 | 23 | Args: 24 | seed (int): Seed to be used. 25 | deterministic (bool): Whether to set the deterministic option for 26 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 27 | to True and `torch.backends.cudnn.benchmark` to False. 28 | Default: False. 
29 | """ 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | if deterministic: 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | 38 | 39 | def train_segmentor(model, 40 | dataset, 41 | cfg, 42 | distributed=False, 43 | validate=False, 44 | timestamp=None, 45 | meta=None): 46 | """Launch segmentor training.""" 47 | logger = get_root_logger(cfg.log_level) 48 | 49 | # prepare data loaders 50 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 51 | data_loaders = [ 52 | build_dataloader( 53 | ds, 54 | cfg.data.samples_per_gpu, 55 | cfg.data.workers_per_gpu, 56 | # cfg.gpus will be ignored if distributed 57 | len(cfg.gpu_ids), 58 | dist=distributed, 59 | seed=cfg.seed, 60 | drop_last=True) for ds in dataset 61 | ] 62 | 63 | # put model on gpus 64 | if distributed: 65 | find_unused_parameters = cfg.get('find_unused_parameters', False) 66 | use_ddp_wrapper = cfg.get('use_ddp_wrapper', False) 67 | # Sets the `find_unused_parameters` parameter in 68 | # torch.nn.parallel.DistributedDataParallel 69 | if use_ddp_wrapper: 70 | mmcv.print_log('Use DDP Wrapper.', 'mmseg') 71 | model = DistributedDataParallelWrapper( 72 | model.cuda(), 73 | device_ids=[torch.cuda.current_device()], 74 | broadcast_buffers=False, 75 | find_unused_parameters=find_unused_parameters) 76 | else: 77 | model = MMDistributedDataParallel( 78 | model.cuda(), 79 | device_ids=[torch.cuda.current_device()], 80 | broadcast_buffers=False, 81 | find_unused_parameters=find_unused_parameters) 82 | else: 83 | model = MMDataParallel( 84 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 85 | 86 | # build runner 87 | optimizer = build_optimizer(model, cfg.optimizer) 88 | 89 | if cfg.get('runner') is None: 90 | cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} 91 | warnings.warn( 92 | 'config is now expected to have a `runner` section, ' 93 | 'please set `runner` in your config.', UserWarning) 94 | 95 | runner = build_runner( 96 | cfg.runner, 97 | default_args=dict( 98 | model=model, 99 | batch_processor=None, 100 | optimizer=optimizer, 101 | work_dir=cfg.work_dir, 102 | logger=logger, 103 | meta=meta)) 104 | 105 | # register hooks 106 | runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, 107 | cfg.checkpoint_config, cfg.log_config, 108 | cfg.get('momentum_config', None)) 109 | 110 | # an ugly walkaround to make the .log and .log.json filenames the same 111 | runner.timestamp = timestamp 112 | 113 | # register eval hooks 114 | if validate: 115 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 116 | val_dataloader = build_dataloader( 117 | val_dataset, 118 | samples_per_gpu=1, 119 | workers_per_gpu=cfg.data.workers_per_gpu, 120 | dist=distributed, 121 | shuffle=False) 122 | eval_cfg = cfg.get('evaluation', {}) 123 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' 124 | eval_hook = DistEvalHook if distributed else EvalHook 125 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 126 | 127 | if cfg.resume_from: 128 | runner.resume(cfg.resume_from) 129 | elif cfg.load_from: 130 | runner.load_checkpoint(cfg.load_from) 131 | runner.run(data_loaders, cfg.workflow) 132 | -------------------------------------------------------------------------------- /mmseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 
from .evaluation import * # noqa: F401, F403 4 | from .seg import * # noqa: F401, F403 5 | from .utils import * # noqa: F401, F403 6 | -------------------------------------------------------------------------------- /mmseg/core/ddp_wrapper.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmgeneration 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel 6 | from mmcv.parallel.scatter_gather import scatter_kwargs 7 | from torch.cuda._utils import _get_device_index 8 | 9 | 10 | @MODULE_WRAPPERS.register_module('mmseg.DDPWrapper') 11 | class DistributedDataParallelWrapper(nn.Module): 12 | """A DistributedDataParallel wrapper for models in MMGeneration. 13 | 14 | In MMEditing, there is a need to wrap different modules in the models 15 | with separate DistributedDataParallel. Otherwise, it will cause 16 | errors for GAN training. 17 | More specifically, the GAN model usually has two sub-modules: 18 | generator and discriminator. If we wrap both of them in one 19 | standard DistributedDataParallel, it will cause errors during training, 20 | because when we update the parameters of the generator (or discriminator), 21 | the parameters of the discriminator (or generator) are not updated, which is 22 | not allowed for DistributedDataParallel. 23 | So we design this wrapper to separately wrap DistributedDataParallel 24 | for the generator and discriminator. 25 | In this wrapper, we perform two operations: 26 | 1. Wrap the modules in the models with separate MMDistributedDataParallel. 27 | Note that only modules with parameters will be wrapped. 28 | 2. Do scatter operation for 'forward', 'train_step' and 'val_step'. 29 | Note that the arguments of this wrapper are the same as those in 30 | `torch.nn.parallel.distributed.DistributedDataParallel`. 31 | Args: 32 | module (nn.Module): Module that needs to be wrapped. 33 | device_ids (list[int | `torch.device`]): Same as that in 34 | `torch.nn.parallel.distributed.DistributedDataParallel`. 35 | dim (int, optional): Same as that in the official scatter function in 36 | pytorch. Defaults to 0. 37 | broadcast_buffers (bool): Same as that in 38 | `torch.nn.parallel.distributed.DistributedDataParallel`. 39 | Defaults to False. 40 | find_unused_parameters (bool, optional): Same as that in 41 | `torch.nn.parallel.distributed.DistributedDataParallel`. 42 | Traverse the autograd graph of all tensors contained in the returned 43 | value of the wrapped module’s forward function. Defaults to False. 44 | kwargs (dict): Other arguments used in 45 | `torch.nn.parallel.distributed.DistributedDataParallel`. 46 | """ 47 | 48 | def __init__(self, 49 | module, 50 | device_ids, 51 | dim=0, 52 | broadcast_buffers=False, 53 | find_unused_parameters=False, 54 | **kwargs): 55 | super().__init__() 56 | assert len(device_ids) == 1, ( 57 | 'Currently, DistributedDataParallelWrapper only supports one ' 58 | 'single CUDA device for each process.'
59 | f'The length of device_ids must be 1, but got {len(device_ids)}.') 60 | self.module = module 61 | self.dim = dim 62 | self.to_ddp( 63 | device_ids=device_ids, 64 | dim=dim, 65 | broadcast_buffers=broadcast_buffers, 66 | find_unused_parameters=find_unused_parameters, 67 | **kwargs) 68 | self.output_device = _get_device_index(device_ids[0], True) 69 | 70 | def to_ddp(self, device_ids, dim, broadcast_buffers, 71 | find_unused_parameters, **kwargs): 72 | """Wrap models with separate MMDistributedDataParallel. 73 | 74 | It only wraps the modules with parameters. 75 | """ 76 | for name, module in self.module._modules.items(): 77 | if next(module.parameters(), None) is None: 78 | module = module.cuda() 79 | elif all(not p.requires_grad for p in module.parameters()): 80 | module = module.cuda() 81 | else: 82 | module = MMDistributedDataParallel( 83 | module.cuda(), 84 | device_ids=device_ids, 85 | dim=dim, 86 | broadcast_buffers=broadcast_buffers, 87 | find_unused_parameters=find_unused_parameters, 88 | **kwargs) 89 | self.module._modules[name] = module 90 | 91 | def scatter(self, inputs, kwargs, device_ids): 92 | """Scatter function. 93 | 94 | Args: 95 | inputs (Tensor): Input Tensor. 96 | kwargs (dict): Args for 97 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 98 | device_ids (int): Device id. 99 | """ 100 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 101 | 102 | def forward(self, *inputs, **kwargs): 103 | """Forward function. 104 | 105 | Args: 106 | inputs (tuple): Input data. 107 | kwargs (dict): Args for 108 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 109 | """ 110 | inputs, kwargs = self.scatter(inputs, kwargs, 111 | [torch.cuda.current_device()]) 112 | return self.module(*inputs[0], **kwargs[0]) 113 | 114 | def train_step(self, *inputs, **kwargs): 115 | """Train step function. 116 | 117 | Args: 118 | inputs (Tensor): Input Tensor. 119 | kwargs (dict): Args for 120 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 121 | """ 122 | inputs, kwargs = self.scatter(inputs, kwargs, 123 | [torch.cuda.current_device()]) 124 | output = self.module.train_step(*inputs[0], **kwargs[0]) 125 | return output 126 | 127 | def val_step(self, *inputs, **kwargs): 128 | """Validation step function. 129 | 130 | Args: 131 | inputs (tuple): Input data. 132 | kwargs (dict): Args for ``scatter_kwargs``. 
133 | """ 134 | inputs, kwargs = self.scatter(inputs, kwargs, 135 | [torch.cuda.current_device()]) 136 | output = self.module.val_step(*inputs[0], **kwargs[0]) 137 | return output 138 | -------------------------------------------------------------------------------- /mmseg/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .class_names import get_classes, get_palette 4 | from .eval_hooks import DistEvalHook, EvalHook 5 | from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou 6 | 7 | __all__ = [ 8 | 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', 9 | 'eval_metrics', 'get_classes', 'get_palette' 10 | ] 11 | -------------------------------------------------------------------------------- /mmseg/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import os.path as osp 4 | 5 | import torch.distributed as dist 6 | from mmcv.runner import DistEvalHook as _DistEvalHook 7 | from mmcv.runner import EvalHook as _EvalHook 8 | from torch.nn.modules.batchnorm import _BatchNorm 9 | 10 | 11 | class EvalHook(_EvalHook): 12 | """Single GPU EvalHook, with efficient test support. 13 | 14 | Args: 15 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 16 | If set to True, it will perform by epoch. Otherwise, by iteration. 17 | Default: False. 18 | efficient_test (bool): Whether save the results as local numpy files to 19 | save CPU memory during evaluation. Default: False. 20 | Returns: 21 | list: The prediction results. 22 | """ 23 | 24 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 25 | 26 | def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): 27 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 28 | self.efficient_test = efficient_test 29 | 30 | def _do_evaluate(self, runner): 31 | """perform evaluation and save ckpt.""" 32 | if not self._should_evaluate(runner): 33 | return 34 | 35 | from mmseg.apis import single_gpu_test 36 | results = single_gpu_test( 37 | runner.model, 38 | self.dataloader, 39 | show=False, 40 | efficient_test=self.efficient_test) 41 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 42 | key_score = self.evaluate(runner, results) 43 | if self.save_best: 44 | self._save_ckpt(runner, key_score) 45 | 46 | 47 | class DistEvalHook(_DistEvalHook): 48 | """Distributed EvalHook, with efficient test support. 49 | 50 | Args: 51 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 52 | If set to True, it will perform by epoch. Otherwise, by iteration. 53 | Default: False. 54 | efficient_test (bool): Whether save the results as local numpy files to 55 | save CPU memory during evaluation. Default: False. 56 | Returns: 57 | list: The prediction results. 
58 | """ 59 | 60 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 61 | 62 | def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): 63 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 64 | self.efficient_test = efficient_test 65 | 66 | def _do_evaluate(self, runner): 67 | """perform evaluation and save ckpt.""" 68 | # Synchronization of BatchNorm's buffer (running_mean 69 | # and running_var) is not supported in the DDP of pytorch, 70 | # which may cause the inconsistent performance of models in 71 | # different ranks, so we broadcast BatchNorm's buffers 72 | # of rank 0 to other ranks to avoid this. 73 | if self.broadcast_bn_buffer: 74 | model = runner.model 75 | for name, module in model.named_modules(): 76 | if isinstance(module, 77 | _BatchNorm) and module.track_running_stats: 78 | dist.broadcast(module.running_var, 0) 79 | dist.broadcast(module.running_mean, 0) 80 | 81 | if not self._should_evaluate(runner): 82 | return 83 | 84 | tmpdir = self.tmpdir 85 | if tmpdir is None: 86 | tmpdir = osp.join(runner.work_dir, '.eval_hook') 87 | 88 | from mmseg.apis import multi_gpu_test 89 | results = multi_gpu_test( 90 | runner.model, 91 | self.dataloader, 92 | tmpdir=tmpdir, 93 | gpu_collect=self.gpu_collect, 94 | efficient_test=self.efficient_test) 95 | if runner.rank == 0: 96 | print('\n') 97 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 98 | key_score = self.evaluate(runner, results) 99 | 100 | if self.save_best: 101 | self._save_ckpt(runner, key_score) 102 | -------------------------------------------------------------------------------- /mmseg/core/seg/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .builder import build_pixel_sampler 4 | from .sampler import BasePixelSampler, OHEMPixelSampler 5 | 6 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] 7 | -------------------------------------------------------------------------------- /mmseg/core/seg/builder.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from mmcv.utils import Registry, build_from_cfg 4 | 5 | PIXEL_SAMPLERS = Registry('pixel sampler') 6 | 7 | 8 | def build_pixel_sampler(cfg, **default_args): 9 | """Build pixel sampler for segmentation map.""" 10 | return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) 11 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .base_pixel_sampler import BasePixelSampler 4 | from .ohem_pixel_sampler import OHEMPixelSampler 5 | 6 | __all__ = ['BasePixelSampler', 'OHEMPixelSampler'] 7 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | 6 | class BasePixelSampler(metaclass=ABCMeta): 7 | """Base class of pixel sampler.""" 8 | 9 | def __init__(self, **kwargs): 10 | pass 11 | 12 | @abstractmethod 13 | def sample(self, seg_logit, seg_label): 14 | 
"""Placeholder for sample function.""" 15 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/ohem_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from ..builder import PIXEL_SAMPLERS 7 | from .base_pixel_sampler import BasePixelSampler 8 | 9 | 10 | @PIXEL_SAMPLERS.register_module() 11 | class OHEMPixelSampler(BasePixelSampler): 12 | """Online Hard Example Mining Sampler for segmentation. 13 | 14 | Args: 15 | context (nn.Module): The context of sampler, subclass of 16 | :obj:`BaseDecodeHead`. 17 | thresh (float, optional): The threshold for hard example selection. 18 | Below which, are prediction with low confidence. If not 19 | specified, the hard examples will be pixels of top ``min_kept`` 20 | loss. Default: None. 21 | min_kept (int, optional): The minimum number of predictions to keep. 22 | Default: 100000. 23 | """ 24 | 25 | def __init__(self, context, thresh=None, min_kept=100000): 26 | super(OHEMPixelSampler, self).__init__() 27 | self.context = context 28 | assert min_kept > 1 29 | self.thresh = thresh 30 | self.min_kept = min_kept 31 | 32 | def sample(self, seg_logit, seg_label): 33 | """Sample pixels that have high loss or with low prediction confidence. 34 | 35 | Args: 36 | seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) 37 | seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) 38 | 39 | Returns: 40 | torch.Tensor: segmentation weight, shape (N, H, W) 41 | """ 42 | with torch.no_grad(): 43 | assert seg_logit.shape[2:] == seg_label.shape[2:] 44 | assert seg_label.shape[1] == 1 45 | seg_label = seg_label.squeeze(1).long() 46 | batch_kept = self.min_kept * seg_label.size(0) 47 | valid_mask = seg_label != self.context.ignore_index 48 | seg_weight = seg_logit.new_zeros(size=seg_label.size()) 49 | valid_seg_weight = seg_weight[valid_mask] 50 | if self.thresh is not None: 51 | seg_prob = F.softmax(seg_logit, dim=1) 52 | 53 | tmp_seg_label = seg_label.clone().unsqueeze(1) 54 | tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 55 | seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) 56 | sort_prob, sort_indices = seg_prob[valid_mask].sort() 57 | 58 | if sort_prob.numel() > 0: 59 | min_threshold = sort_prob[min(batch_kept, 60 | sort_prob.numel() - 1)] 61 | else: 62 | min_threshold = 0.0 63 | threshold = max(min_threshold, self.thresh) 64 | valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 65 | else: 66 | losses = self.context.loss_decode( 67 | seg_logit, 68 | seg_label, 69 | weight=None, 70 | ignore_index=self.context.ignore_index, 71 | reduction_override='none') 72 | # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa 73 | _, sort_indices = losses[valid_mask].sort(descending=True) 74 | valid_seg_weight[sort_indices[:batch_kept]] = 1. 
75 | 76 | seg_weight[valid_mask] = valid_seg_weight 77 | 78 | return seg_weight 79 | -------------------------------------------------------------------------------- /mmseg/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .misc import add_prefix 4 | 5 | __all__ = ['add_prefix'] 6 | -------------------------------------------------------------------------------- /mmseg/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 4 | def add_prefix(inputs, prefix): 5 | """Add prefix for dict. 6 | 7 | Args: 8 | inputs (dict): The input dict with str keys. 9 | prefix (str): The prefix to add. 10 | 11 | Returns: 12 | 13 | dict: The dict with keys updated with ``prefix``. 14 | """ 15 | 16 | outputs = dict() 17 | for name, value in inputs.items(): 18 | outputs[f'{prefix}.{name}'] = value 19 | 20 | return outputs 21 | -------------------------------------------------------------------------------- /mmseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add additional datasets 3 | 4 | from .acdc import ACDCDataset 5 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 6 | from .cityscapes import CityscapesDataset 7 | from .custom import CustomDataset 8 | from .dark_zurich import DarkZurichDataset 9 | from .dataset_wrappers import ConcatDataset, RepeatDataset 10 | from .gta import GTADataset 11 | from .synthia import SynthiaDataset 12 | from .uda_dataset import UDADataset 13 | 14 | __all__ = [ 15 | 'CustomDataset', 16 | 'build_dataloader', 17 | 'ConcatDataset', 18 | 'RepeatDataset', 19 | 'DATASETS', 20 | 'build_dataset', 21 | 'PIPELINES', 22 | 'CityscapesDataset', 23 | 'GTADataset', 24 | 'SynthiaDataset', 25 | 'UDADataset', 26 | 'ACDCDataset', 27 | 'DarkZurichDataset', 28 | ] 29 | -------------------------------------------------------------------------------- /mmseg/datasets/acdc.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | from .builder import DATASETS 7 | from .cityscapes import CityscapesDataset 8 | 9 | 10 | @DATASETS.register_module() 11 | class ACDCDataset(CityscapesDataset): 12 | 13 | def __init__(self, **kwargs): 14 | super(ACDCDataset, self).__init__( 15 | img_suffix='_rgb_anon.png', 16 | seg_map_suffix='_gt_labelTrainIds.png', 17 | **kwargs) 18 | -------------------------------------------------------------------------------- /mmseg/datasets/dark_zurich.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | from .builder import DATASETS 7 | from .cityscapes import CityscapesDataset 8 | 9 | 10 | @DATASETS.register_module() 11 | class DarkZurichDataset(CityscapesDataset): 12 | 13 | def __init__(self, **kwargs): 14 | super(DarkZurichDataset, self).__init__( 15 | img_suffix='_rgb_anon.png', 16 | seg_map_suffix='_gt_labelTrainIds.png', 17 | **kwargs) 18 | -------------------------------------------------------------------------------- /mmseg/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 4 | 5 | from .builder import DATASETS 6 | 7 | 8 | @DATASETS.register_module() 9 | class ConcatDataset(_ConcatDataset): 10 | """A wrapper of concatenated dataset. 11 | 12 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 13 | concat the group flag for image aspect ratio. 14 | 15 | Args: 16 | datasets (list[:obj:`Dataset`]): A list of datasets. 17 | """ 18 | 19 | def __init__(self, datasets): 20 | super(ConcatDataset, self).__init__(datasets) 21 | self.CLASSES = datasets[0].CLASSES 22 | self.PALETTE = datasets[0].PALETTE 23 | 24 | 25 | @DATASETS.register_module() 26 | class RepeatDataset(object): 27 | """A wrapper of repeated dataset. 28 | 29 | The length of repeated dataset will be `times` larger than the original 30 | dataset. This is useful when the data loading time is long but the dataset 31 | is small. Using RepeatDataset can reduce the data loading time between 32 | epochs. 33 | 34 | Args: 35 | dataset (:obj:`Dataset`): The dataset to be repeated. 36 | times (int): Repeat times. 37 | """ 38 | 39 | def __init__(self, dataset, times): 40 | self.dataset = dataset 41 | self.times = times 42 | self.CLASSES = dataset.CLASSES 43 | self.PALETTE = dataset.PALETTE 44 | self._ori_len = len(self.dataset) 45 | 46 | def __getitem__(self, idx): 47 | """Get item from original dataset.""" 48 | return self.dataset[idx % self._ori_len] 49 | 50 | def __len__(self): 51 | """The length is multiplied by ``times``""" 52 | return self.times * self._ori_len 53 | -------------------------------------------------------------------------------- /mmseg/datasets/gta.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | from . 
import CityscapesDataset 7 | from .builder import DATASETS 8 | from .custom import CustomDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class GTADataset(CustomDataset): 13 | CLASSES = CityscapesDataset.CLASSES 14 | PALETTE = CityscapesDataset.PALETTE 15 | 16 | def __init__(self, **kwargs): 17 | assert kwargs.get('split') in [None, 'train'] 18 | if 'split' in kwargs: 19 | kwargs.pop('split') 20 | super(GTADataset, self).__init__( 21 | img_suffix='.png', 22 | seg_map_suffix='_labelTrainIds.png', 23 | split=None, 24 | **kwargs) 25 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .compose import Compose 4 | from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, 5 | Transpose, to_tensor) 6 | from .loading import LoadAnnotations, LoadImageFromFile 7 | from .test_time_aug import MultiScaleFlipAug 8 | from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, 9 | PhotoMetricDistortion, RandomCrop, RandomFlip, 10 | RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) 11 | 12 | __all__ = [ 13 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 14 | 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 15 | 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 16 | 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 17 | 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' 18 | ] 19 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import collections 4 | 5 | from mmcv.utils import build_from_cfg 6 | 7 | from ..builder import PIPELINES 8 | 9 | 10 | @PIPELINES.register_module() 11 | class Compose(object): 12 | """Compose multiple transforms sequentially. 13 | 14 | Args: 15 | transforms (Sequence[dict | callable]): Sequence of transform object or 16 | config dict to be composed. 17 | """ 18 | 19 | def __init__(self, transforms): 20 | assert isinstance(transforms, collections.abc.Sequence) 21 | self.transforms = [] 22 | for transform in transforms: 23 | if isinstance(transform, dict): 24 | transform = build_from_cfg(transform, PIPELINES) 25 | self.transforms.append(transform) 26 | elif callable(transform): 27 | self.transforms.append(transform) 28 | else: 29 | raise TypeError('transform must be callable or a dict') 30 | 31 | def __call__(self, data): 32 | """Call function to apply transforms sequentially. 33 | 34 | Args: 35 | data (dict): A result dict contains the data to transform. 36 | 37 | Returns: 38 | dict: Transformed data. 
39 | """ 40 | 41 | for t in self.transforms: 42 | data = t(data) 43 | if data is None: 44 | return None 45 | return data 46 | 47 | def __repr__(self): 48 | format_string = self.__class__.__name__ + '(' 49 | for t in self.transforms: 50 | format_string += '\n' 51 | format_string += f' {t}' 52 | format_string += '\n)' 53 | return format_string 54 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/test_time_aug.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import warnings 4 | 5 | import mmcv 6 | 7 | from ..builder import PIPELINES 8 | from .compose import Compose 9 | 10 | 11 | @PIPELINES.register_module() 12 | class MultiScaleFlipAug(object): 13 | """Test-time augmentation with multiple scales and flipping. 14 | 15 | An example configuration is as followed: 16 | 17 | .. code-block:: 18 | 19 | img_scale=(2048, 1024), 20 | img_ratios=[0.5, 1.0], 21 | flip=True, 22 | transforms=[ 23 | dict(type='Resize', keep_ratio=True), 24 | dict(type='RandomFlip'), 25 | dict(type='Normalize', **img_norm_cfg), 26 | dict(type='Pad', size_divisor=32), 27 | dict(type='ImageToTensor', keys=['img']), 28 | dict(type='Collect', keys=['img']), 29 | ] 30 | 31 | After MultiScaleFLipAug with above configuration, the results are wrapped 32 | into lists of the same length as followed: 33 | 34 | .. code-block:: 35 | 36 | dict( 37 | img=[...], 38 | img_shape=[...], 39 | scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] 40 | flip=[False, True, False, True] 41 | ... 42 | ) 43 | 44 | Args: 45 | transforms (list[dict]): Transforms to apply in each augmentation. 46 | img_scale (None | tuple | list[tuple]): Images scales for resizing. 47 | img_ratios (float | list[float]): Image ratios for resizing 48 | flip (bool): Whether apply flip augmentation. Default: False. 49 | flip_direction (str | list[str]): Flip augmentation directions, 50 | options are "horizontal" and "vertical". If flip_direction is list, 51 | multiple flip augmentations will be applied. 52 | It has no effect when flip == False. Default: "horizontal". 
53 | """ 54 | 55 | def __init__(self, 56 | transforms, 57 | img_scale, 58 | img_ratios=None, 59 | flip=False, 60 | flip_direction='horizontal'): 61 | self.transforms = Compose(transforms) 62 | if img_ratios is not None: 63 | img_ratios = img_ratios if isinstance(img_ratios, 64 | list) else [img_ratios] 65 | assert mmcv.is_list_of(img_ratios, float) 66 | if img_scale is None: 67 | # mode 1: given img_scale=None and a range of image ratio 68 | self.img_scale = None 69 | assert mmcv.is_list_of(img_ratios, float) 70 | elif isinstance(img_scale, tuple) and mmcv.is_list_of( 71 | img_ratios, float): 72 | assert len(img_scale) == 2 73 | # mode 2: given a scale and a range of image ratio 74 | self.img_scale = [(int(img_scale[0] * ratio), 75 | int(img_scale[1] * ratio)) 76 | for ratio in img_ratios] 77 | else: 78 | # mode 3: given multiple scales 79 | self.img_scale = img_scale if isinstance(img_scale, 80 | list) else [img_scale] 81 | assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None 82 | self.flip = flip 83 | self.img_ratios = img_ratios 84 | self.flip_direction = flip_direction if isinstance( 85 | flip_direction, list) else [flip_direction] 86 | assert mmcv.is_list_of(self.flip_direction, str) 87 | if not self.flip and self.flip_direction != ['horizontal']: 88 | warnings.warn( 89 | 'flip_direction has no effect when flip is set to False') 90 | if (self.flip 91 | and not any([t['type'] == 'RandomFlip' for t in transforms])): 92 | warnings.warn( 93 | 'flip has no effect when RandomFlip is not in transforms') 94 | 95 | def __call__(self, results): 96 | """Call function to apply test time augment transforms on results. 97 | 98 | Args: 99 | results (dict): Result dict contains the data to transform. 100 | 101 | Returns: 102 | dict[str: list]: The augmented data, where each value is wrapped 103 | into a list. 104 | """ 105 | 106 | aug_data = [] 107 | if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): 108 | h, w = results['img'].shape[:2] 109 | img_scale = [(int(w * ratio), int(h * ratio)) 110 | for ratio in self.img_ratios] 111 | else: 112 | img_scale = self.img_scale 113 | flip_aug = [False, True] if self.flip else [False] 114 | for scale in img_scale: 115 | for flip in flip_aug: 116 | for direction in self.flip_direction: 117 | _results = results.copy() 118 | _results['scale'] = scale 119 | _results['flip'] = flip 120 | _results['flip_direction'] = direction 121 | data = self.transforms(_results) 122 | aug_data.append(data) 123 | # list of dict to dict of list 124 | aug_data_dict = {key: [] for key in aug_data[0]} 125 | for data in aug_data: 126 | for key, val in data.items(): 127 | aug_data_dict[key].append(val) 128 | return aug_data_dict 129 | 130 | def __repr__(self): 131 | repr_str = self.__class__.__name__ 132 | repr_str += f'(transforms={self.transforms}, ' 133 | repr_str += f'img_scale={self.img_scale}, flip={self.flip})' 134 | repr_str += f'flip_direction={self.flip_direction}' 135 | return repr_str 136 | -------------------------------------------------------------------------------- /mmseg/datasets/synthia.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | from . 
import CityscapesDataset 7 | from .builder import DATASETS 8 | from .custom import CustomDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class SynthiaDataset(CustomDataset): 13 | CLASSES = CityscapesDataset.CLASSES 14 | PALETTE = CityscapesDataset.PALETTE 15 | 16 | def __init__(self, **kwargs): 17 | assert kwargs.get('split') in [None, 'train'] 18 | if 'split' in kwargs: 19 | kwargs.pop('split') 20 | super(SynthiaDataset, self).__init__( 21 | img_suffix='.png', 22 | seg_map_suffix='_labelTrainIds.png', 23 | split=None, 24 | **kwargs) 25 | -------------------------------------------------------------------------------- /mmseg/datasets/uda_dataset.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | import json 7 | import os.path as osp 8 | 9 | import mmcv 10 | import numpy as np 11 | import torch 12 | 13 | from . import CityscapesDataset 14 | from .builder import DATASETS 15 | 16 | 17 | def get_rcs_class_probs(data_root, temperature): 18 | with open(osp.join(data_root, 'sample_class_stats.json'), 'r') as of: 19 | sample_class_stats = json.load(of) 20 | overall_class_stats = {} 21 | for s in sample_class_stats: 22 | s.pop('file') 23 | for c, n in s.items(): 24 | c = int(c) 25 | if c not in overall_class_stats: 26 | overall_class_stats[c] = n 27 | else: 28 | overall_class_stats[c] += n 29 | overall_class_stats = { 30 | k: v 31 | for k, v in sorted( 32 | overall_class_stats.items(), key=lambda item: item[1]) 33 | } 34 | freq = torch.tensor(list(overall_class_stats.values())) 35 | freq = freq / torch.sum(freq) 36 | freq = 1 - freq 37 | freq = torch.softmax(freq / temperature, dim=-1) 38 | 39 | return list(overall_class_stats.keys()), freq.numpy() 40 | 41 | 42 | @DATASETS.register_module() 43 | class UDADataset(object): 44 | 45 | def __init__(self, source, target, cfg): 46 | self.source = source 47 | self.target = target 48 | self.ignore_index = target.ignore_index 49 | self.CLASSES = target.CLASSES 50 | self.PALETTE = target.PALETTE 51 | assert target.ignore_index == source.ignore_index 52 | assert target.CLASSES == source.CLASSES 53 | assert target.PALETTE == source.PALETTE 54 | 55 | rcs_cfg = cfg.get('rare_class_sampling') 56 | self.rcs_enabled = rcs_cfg is not None 57 | if self.rcs_enabled: 58 | self.rcs_class_temp = rcs_cfg['class_temp'] 59 | self.rcs_min_crop_ratio = rcs_cfg['min_crop_ratio'] 60 | self.rcs_min_pixels = rcs_cfg['min_pixels'] 61 | 62 | self.rcs_classes, self.rcs_classprob = get_rcs_class_probs( 63 | cfg['source']['data_root'], self.rcs_class_temp) 64 | mmcv.print_log(f'RCS Classes: {self.rcs_classes}', 'mmseg') 65 | mmcv.print_log(f'RCS ClassProb: {self.rcs_classprob}', 'mmseg') 66 | 67 | with open( 68 | osp.join(cfg['source']['data_root'], 69 | 'samples_with_class.json'), 'r') as of: 70 | samples_with_class_and_n = json.load(of) 71 | samples_with_class_and_n = { 72 | int(k): v 73 | for k, v in samples_with_class_and_n.items() 74 | if int(k) in self.rcs_classes 75 | } 76 | self.samples_with_class = {} 77 | for c in self.rcs_classes: 78 | self.samples_with_class[c] = [] 79 | for file, pixels in samples_with_class_and_n[c]: 80 | if pixels > self.rcs_min_pixels: 81 | self.samples_with_class[c].append(file.split('/')[-1]) 82 | assert len(self.samples_with_class[c]) > 0 
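# Sketch of the class-sampling distribution computed in
# get_rcs_class_probs above (illustrative numbers, not from the paper):
# with two classes of pixel frequency f = [0.9, 0.1] and
# class_temp T = 0.01, softmax((1 - f) / T) is approximately [0, 1],
# i.e. the rarer class is drawn almost exclusively; a larger T
# flattens the distribution.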
83 | self.file_to_idx = {} 84 | for i, dic in enumerate(self.source.img_infos): 85 | file = dic['ann']['seg_map'] 86 | if isinstance(self.source, CityscapesDataset): 87 | file = file.split('/')[-1] 88 | self.file_to_idx[file] = i 89 | 90 | def get_rare_class_sample(self): 91 | c = np.random.choice(self.rcs_classes, p=self.rcs_classprob) 92 | f1 = np.random.choice(self.samples_with_class[c]) 93 | i1 = self.file_to_idx[f1] 94 | s1 = self.source[i1] 95 | if self.rcs_min_crop_ratio > 0: 96 | for j in range(10): 97 | n_class = torch.sum(s1['gt_semantic_seg'].data == c) 98 | # mmcv.print_log(f'{j}: {n_class}', 'mmseg') 99 | if n_class > self.rcs_min_pixels * self.rcs_min_crop_ratio: 100 | break 101 | # Sample a new random crop from source image i1. 102 | # Please note, that self.source.__getitem__(idx) applies the 103 | # preprocessing pipeline to the loaded image, which includes 104 | # RandomCrop, and results in a new crop of the image. 105 | s1 = self.source[i1] 106 | i2 = np.random.choice(range(len(self.target))) 107 | s2 = self.target[i2] 108 | 109 | return { 110 | **s1, 'target_img_metas': s2['img_metas'], 111 | 'target_img': s2['img'] 112 | } 113 | 114 | def __getitem__(self, idx): 115 | if self.rcs_enabled: 116 | return self.get_rare_class_sample() 117 | else: 118 | s1 = self.source[idx // len(self.target)] 119 | s2 = self.target[idx % len(self.target)] 120 | return { 121 | **s1, 'target_img_metas': s2['img_metas'], 122 | 'target_img': s2['img'] 123 | } 124 | 125 | def __len__(self): 126 | return len(self.source) * len(self.target) 127 | -------------------------------------------------------------------------------- /mmseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * # noqa: F401,F403 2 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, UDA, 3 | build_backbone, build_head, build_loss, build_segmentor) 4 | from .decode_heads import * # noqa: F401,F403 5 | from .losses import * # noqa: F401,F403 6 | from .necks import * # noqa: F401,F403 7 | from .segmentors import * # noqa: F401,F403 8 | from .uda import * # noqa: F401,F403 9 | 10 | __all__ = [ 11 | 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'UDA', 'build_backbone', 12 | 'build_head', 'build_loss', 'build_segmentor' 13 | ] 14 | -------------------------------------------------------------------------------- /mmseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add additional backbones 3 | 4 | from .mix_transformer import (MixVisionTransformer, mit_b0, mit_b1, mit_b2, 5 | mit_b3, mit_b4, mit_b5) 6 | from .resnest import ResNeSt 7 | from .resnet import ResNet, ResNetV1c, ResNetV1d 8 | from .resnext import ResNeXt 9 | 10 | __all__ = [ 11 | 'ResNet', 12 | 'ResNetV1c', 13 | 'ResNetV1d', 14 | 'ResNeXt', 15 | 'ResNeSt', 16 | 'MixVisionTransformer', 17 | 'mit_b0', 18 | 'mit_b1', 19 | 'mit_b2', 20 | 'mit_b3', 21 | 'mit_b4', 22 | 'mit_b5', 23 | ] 24 | -------------------------------------------------------------------------------- /mmseg/models/backbones/resnext.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import math 4 | 5 | from mmcv.cnn import build_conv_layer, build_norm_layer 6 | 7 | from ..builder import BACKBONES 8 | from ..utils import ResLayer 9 | from 
.resnet import Bottleneck as _Bottleneck 10 | from .resnet import ResNet 11 | 12 | 13 | class Bottleneck(_Bottleneck): 14 | """Bottleneck block for ResNeXt. 15 | 16 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is 17 | "caffe", the stride-two layer is the first 1x1 conv layer. 18 | """ 19 | 20 | def __init__(self, 21 | inplanes, 22 | planes, 23 | groups=1, 24 | base_width=4, 25 | base_channels=64, 26 | **kwargs): 27 | super(Bottleneck, self).__init__(inplanes, planes, **kwargs) 28 | 29 | if groups == 1: 30 | width = self.planes 31 | else: 32 | width = math.floor(self.planes * 33 | (base_width / base_channels)) * groups 34 | 35 | self.norm1_name, norm1 = build_norm_layer( 36 | self.norm_cfg, width, postfix=1) 37 | self.norm2_name, norm2 = build_norm_layer( 38 | self.norm_cfg, width, postfix=2) 39 | self.norm3_name, norm3 = build_norm_layer( 40 | self.norm_cfg, self.planes * self.expansion, postfix=3) 41 | 42 | self.conv1 = build_conv_layer( 43 | self.conv_cfg, 44 | self.inplanes, 45 | width, 46 | kernel_size=1, 47 | stride=self.conv1_stride, 48 | bias=False) 49 | self.add_module(self.norm1_name, norm1) 50 | fallback_on_stride = False 51 | self.with_modulated_dcn = False 52 | if self.with_dcn: 53 | fallback_on_stride = self.dcn.pop('fallback_on_stride', False) 54 | if not self.with_dcn or fallback_on_stride: 55 | self.conv2 = build_conv_layer( 56 | self.conv_cfg, 57 | width, 58 | width, 59 | kernel_size=3, 60 | stride=self.conv2_stride, 61 | padding=self.dilation, 62 | dilation=self.dilation, 63 | groups=groups, 64 | bias=False) 65 | else: 66 | assert self.conv_cfg is None, 'conv_cfg must be None for DCN' 67 | self.conv2 = build_conv_layer( 68 | self.dcn, 69 | width, 70 | width, 71 | kernel_size=3, 72 | stride=self.conv2_stride, 73 | padding=self.dilation, 74 | dilation=self.dilation, 75 | groups=groups, 76 | bias=False) 77 | 78 | self.add_module(self.norm2_name, norm2) 79 | self.conv3 = build_conv_layer( 80 | self.conv_cfg, 81 | width, 82 | self.planes * self.expansion, 83 | kernel_size=1, 84 | bias=False) 85 | self.add_module(self.norm3_name, norm3) 86 | 87 | 88 | @BACKBONES.register_module() 89 | class ResNeXt(ResNet): 90 | """ResNeXt backbone. 91 | 92 | Args: 93 | depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. 94 | in_channels (int): Number of input image channels. Normally 3. 95 | num_stages (int): Resnet stages, normally 4. 96 | groups (int): Group of resnext. 97 | base_width (int): Base width of resnext. 98 | strides (Sequence[int]): Strides of the first block of each stage. 99 | dilations (Sequence[int]): Dilation of each stage. 100 | out_indices (Sequence[int]): Output from which stages. 101 | style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two 102 | layer is the 3x3 conv layer, otherwise the stride-two layer is 103 | the first 1x1 conv layer. 104 | frozen_stages (int): Stages to be frozen (all param fixed). -1 means 105 | not freezing any parameters. 106 | norm_cfg (dict): dictionary to construct and config norm layer. 107 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 108 | freeze running stats (mean and var). Note: Effect on Batch Norm 109 | and its variants only. 110 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 111 | memory while slowing down the training speed. 112 | zero_init_residual (bool): whether to use zero init for last norm layer 113 | in resblocks to let them behave as identity. 
114 | 115 | Example: 116 | >>> from mmseg.models import ResNeXt 117 | >>> import torch 118 | >>> self = ResNeXt(depth=50) 119 | >>> self.eval() 120 | >>> inputs = torch.rand(1, 3, 32, 32) 121 | >>> level_outputs = self.forward(inputs) 122 | >>> for level_out in level_outputs: 123 | ... print(tuple(level_out.shape)) 124 | (1, 256, 8, 8) 125 | (1, 512, 4, 4) 126 | (1, 1024, 2, 2) 127 | (1, 2048, 1, 1) 128 | """ 129 | 130 | arch_settings = { 131 | 50: (Bottleneck, (3, 4, 6, 3)), 132 | 101: (Bottleneck, (3, 4, 23, 3)), 133 | 152: (Bottleneck, (3, 8, 36, 3)) 134 | } 135 | 136 | def __init__(self, groups=1, base_width=4, **kwargs): 137 | self.groups = groups 138 | self.base_width = base_width 139 | super(ResNeXt, self).__init__(**kwargs) 140 | 141 | def make_res_layer(self, **kwargs): 142 | """Pack all blocks in a stage into a ``ResLayer``""" 143 | return ResLayer( 144 | groups=self.groups, 145 | base_width=self.base_width, 146 | base_channels=self.base_channels, 147 | **kwargs) 148 | -------------------------------------------------------------------------------- /mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Support UDA models 3 | 4 | import warnings 5 | 6 | from mmcv.cnn import MODELS as MMCV_MODELS 7 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 8 | from mmcv.utils import Registry 9 | 10 | MODELS = Registry('models', parent=MMCV_MODELS) 11 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 12 | 13 | BACKBONES = MODELS 14 | NECKS = MODELS 15 | HEADS = MODELS 16 | LOSSES = MODELS 17 | SEGMENTORS = MODELS 18 | UDA = MODELS 19 | 20 | 21 | def build_backbone(cfg): 22 | """Build backbone.""" 23 | return BACKBONES.build(cfg) 24 | 25 | 26 | def build_neck(cfg): 27 | """Build neck.""" 28 | return NECKS.build(cfg) 29 | 30 | 31 | def build_head(cfg): 32 | """Build head.""" 33 | return HEADS.build(cfg) 34 | 35 | 36 | def build_loss(cfg): 37 | """Build loss.""" 38 | return LOSSES.build(cfg) 39 | 40 | 41 | def build_train_model(cfg, train_cfg=None, test_cfg=None): 42 | """Build model.""" 43 | if train_cfg is not None or test_cfg is not None: 44 | warnings.warn( 45 | 'train_cfg and test_cfg is deprecated, ' 46 | 'please specify them in model', UserWarning) 47 | assert cfg.model.get('train_cfg') is None or train_cfg is None, \ 48 | 'train_cfg specified in both outer field and model field ' 49 | assert cfg.model.get('test_cfg') is None or test_cfg is None, \ 50 | 'test_cfg specified in both outer field and model field ' 51 | if 'uda' in cfg: 52 | cfg.uda['model'] = cfg.model 53 | cfg.uda['max_iters'] = cfg.runner.max_iters 54 | return UDA.build( 55 | cfg.uda, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 56 | else: 57 | return SEGMENTORS.build( 58 | cfg.model, 59 | default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 60 | 61 | 62 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 63 | """Build segmentor.""" 64 | if train_cfg is not None or test_cfg is not None: 65 | warnings.warn( 66 | 'train_cfg and test_cfg is deprecated, ' 67 | 'please specify them in model', UserWarning) 68 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 69 | 'train_cfg specified in both outer field and model field ' 70 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 71 | 'test_cfg specified in both outer field and model field ' 72 | return SEGMENTORS.build( 73 | cfg, 
default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 74 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add additional decode_heads 3 | 4 | from .aspp_head import ASPPHead 5 | from .da_head import DAHead 6 | from .daformer_head import DAFormerHead 7 | from .dlv2_head import DLV2Head 8 | from .fcn_head import FCNHead 9 | from .isa_head import ISAHead 10 | from .psp_head import PSPHead 11 | from .segformer_head import SegFormerHead 12 | from .sep_aspp_head import DepthwiseSeparableASPPHead 13 | from .uper_head import UPerHead 14 | 15 | __all__ = [ 16 | 'FCNHead', 17 | 'PSPHead', 18 | 'ASPPHead', 19 | 'UPerHead', 20 | 'DepthwiseSeparableASPPHead', 21 | 'DAHead', 22 | 'DLV2Head', 23 | 'SegFormerHead', 24 | 'DAFormerHead', 25 | 'ISAHead', 26 | ] 27 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/aspp_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | class ASPPModule(nn.ModuleList): 13 | """Atrous Spatial Pyramid Pooling (ASPP) Module. 14 | 15 | Args: 16 | dilations (tuple[int]): Dilation rate of each layer. 17 | in_channels (int): Input channels. 18 | channels (int): Channels after modules, before conv_seg. 19 | conv_cfg (dict|None): Config of conv layers. 20 | norm_cfg (dict|None): Config of norm layers. 21 | act_cfg (dict): Config of activation layers. 22 | """ 23 | 24 | def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, 25 | act_cfg): 26 | super(ASPPModule, self).__init__() 27 | self.dilations = dilations 28 | self.in_channels = in_channels 29 | self.channels = channels 30 | self.conv_cfg = conv_cfg 31 | self.norm_cfg = norm_cfg 32 | self.act_cfg = act_cfg 33 | for dilation in dilations: 34 | self.append( 35 | ConvModule( 36 | self.in_channels, 37 | self.channels, 38 | 1 if dilation == 1 else 3, 39 | dilation=dilation, 40 | padding=0 if dilation == 1 else dilation, 41 | conv_cfg=self.conv_cfg, 42 | norm_cfg=self.norm_cfg, 43 | act_cfg=self.act_cfg)) 44 | 45 | def forward(self, x): 46 | """Forward function.""" 47 | aspp_outs = [] 48 | for aspp_module in self: 49 | aspp_outs.append(aspp_module(x)) 50 | 51 | return aspp_outs 52 | 53 | 54 | @HEADS.register_module() 55 | class ASPPHead(BaseDecodeHead): 56 | """Rethinking Atrous Convolution for Semantic Image Segmentation. 57 | 58 | This head is the implementation of `DeepLabV3 59 | `_. 60 | 61 | Args: 62 | dilations (tuple[int]): Dilation rates for ASPP module. 63 | Default: (1, 6, 12, 18). 
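
        A minimal usage sketch (illustrative only; assumes the default
        ``BaseDecodeHead`` arguments such as ``in_index=-1``, so the head
        reads the last feature map in the input list):
            >>> import torch
            >>> head = ASPPHead(in_channels=2048, channels=512, num_classes=19)
            >>> feats = [torch.rand(1, 2048, 32, 32)]
            >>> logits = head(feats)  # (1, 19, 32, 32): every ASPP branch
            >>> # preserves the spatial size, so the logits match the input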
64 | """ 65 | 66 | def __init__(self, dilations=(1, 6, 12, 18), **kwargs): 67 | super(ASPPHead, self).__init__(**kwargs) 68 | assert isinstance(dilations, (list, tuple)) 69 | self.dilations = dilations 70 | self.image_pool = nn.Sequential( 71 | nn.AdaptiveAvgPool2d(1), 72 | ConvModule( 73 | self.in_channels, 74 | self.channels, 75 | 1, 76 | conv_cfg=self.conv_cfg, 77 | norm_cfg=self.norm_cfg, 78 | act_cfg=self.act_cfg)) 79 | self.aspp_modules = ASPPModule( 80 | dilations, 81 | self.in_channels, 82 | self.channels, 83 | conv_cfg=self.conv_cfg, 84 | norm_cfg=self.norm_cfg, 85 | act_cfg=self.act_cfg) 86 | self.bottleneck = ConvModule( 87 | (len(dilations) + 1) * self.channels, 88 | self.channels, 89 | 3, 90 | padding=1, 91 | conv_cfg=self.conv_cfg, 92 | norm_cfg=self.norm_cfg, 93 | act_cfg=self.act_cfg) 94 | 95 | def forward(self, inputs): 96 | """Forward function.""" 97 | x = self._transform_inputs(inputs) 98 | aspp_outs = [ 99 | resize( 100 | self.image_pool(x), 101 | size=x.size()[2:], 102 | mode='bilinear', 103 | align_corners=self.align_corners) 104 | ] 105 | aspp_outs.extend(self.aspp_modules(x)) 106 | aspp_outs = torch.cat(aspp_outs, dim=1) 107 | output = self.bottleneck(aspp_outs) 108 | output = self.cls_seg(output) 109 | return output 110 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/dlv2_head.py: -------------------------------------------------------------------------------- 1 | from ..builder import HEADS 2 | from .aspp_head import ASPPModule 3 | from .decode_head import BaseDecodeHead 4 | 5 | 6 | @HEADS.register_module() 7 | class DLV2Head(BaseDecodeHead): 8 | 9 | def __init__(self, dilations=(6, 12, 18, 24), **kwargs): 10 | assert 'channels' not in kwargs 11 | assert 'dropout_ratio' not in kwargs 12 | assert 'norm_cfg' not in kwargs 13 | kwargs['channels'] = 1 14 | kwargs['dropout_ratio'] = 0 15 | kwargs['norm_cfg'] = None 16 | super(DLV2Head, self).__init__(**kwargs) 17 | del self.conv_seg 18 | assert isinstance(dilations, (list, tuple)) 19 | self.dilations = dilations 20 | self.aspp_modules = ASPPModule( 21 | dilations, 22 | self.in_channels, 23 | self.num_classes, 24 | conv_cfg=self.conv_cfg, 25 | norm_cfg=None, 26 | act_cfg=None) 27 | 28 | def forward(self, inputs): 29 | """Forward function.""" 30 | # for f in inputs: 31 | # mmcv.print_log(f'{f.shape}', 'mmseg') 32 | x = self._transform_inputs(inputs) 33 | aspp_outs = self.aspp_modules(x) 34 | out = aspp_outs[0] 35 | for i in range(len(aspp_outs) - 1): 36 | out += aspp_outs[i + 1] 37 | return out 38 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/fcn_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @HEADS.register_module() 12 | class FCNHead(BaseDecodeHead): 13 | """Fully Convolution Networks for Semantic Segmentation. 14 | 15 | This head is implemented of `FCNNet `_. 16 | 17 | Args: 18 | num_convs (int): Number of convs in the head. Default: 2. 19 | kernel_size (int): The kernel size for convs in the head. Default: 3. 20 | concat_input (bool): Whether concat the input and output of convs 21 | before classification layer. 22 | dilation (int): The dilation rate for convs in the head. Default: 1. 
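
        A minimal usage sketch (illustrative only; assumes the default
        ``BaseDecodeHead`` arguments):
            >>> import torch
            >>> head = FCNHead(in_channels=256, channels=128, num_classes=19)
            >>> logits = head([torch.rand(1, 256, 64, 64)])
            >>> # (1, 19, 64, 64): two padded 3x3 convs, concat with the
            >>> # input (concat_input=True), then the 1x1 classifier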
23 | """ 24 | 25 | def __init__(self, 26 | num_convs=2, 27 | kernel_size=3, 28 | concat_input=True, 29 | dilation=1, 30 | **kwargs): 31 | assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) 32 | self.num_convs = num_convs 33 | self.concat_input = concat_input 34 | self.kernel_size = kernel_size 35 | super(FCNHead, self).__init__(**kwargs) 36 | if num_convs == 0: 37 | assert self.in_channels == self.channels 38 | 39 | conv_padding = (kernel_size // 2) * dilation 40 | convs = [] 41 | convs.append( 42 | ConvModule( 43 | self.in_channels, 44 | self.channels, 45 | kernel_size=kernel_size, 46 | padding=conv_padding, 47 | dilation=dilation, 48 | conv_cfg=self.conv_cfg, 49 | norm_cfg=self.norm_cfg, 50 | act_cfg=self.act_cfg)) 51 | for i in range(num_convs - 1): 52 | convs.append( 53 | ConvModule( 54 | self.channels, 55 | self.channels, 56 | kernel_size=kernel_size, 57 | padding=conv_padding, 58 | dilation=dilation, 59 | conv_cfg=self.conv_cfg, 60 | norm_cfg=self.norm_cfg, 61 | act_cfg=self.act_cfg)) 62 | if num_convs == 0: 63 | self.convs = nn.Identity() 64 | else: 65 | self.convs = nn.Sequential(*convs) 66 | if self.concat_input: 67 | self.conv_cat = ConvModule( 68 | self.in_channels + self.channels, 69 | self.channels, 70 | kernel_size=kernel_size, 71 | padding=kernel_size // 2, 72 | conv_cfg=self.conv_cfg, 73 | norm_cfg=self.norm_cfg, 74 | act_cfg=self.act_cfg) 75 | 76 | def forward(self, inputs): 77 | """Forward function.""" 78 | x = self._transform_inputs(inputs) 79 | output = self.convs(x) 80 | if self.concat_input: 81 | output = self.conv_cat(torch.cat([x, output], dim=1)) 82 | output = self.cls_seg(output) 83 | return output 84 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/psp_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | class PPM(nn.ModuleList): 13 | """Pooling Pyramid Module used in PSPNet. 14 | 15 | Args: 16 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 17 | Module. 18 | in_channels (int): Input channels. 19 | channels (int): Channels after modules, before conv_seg. 20 | conv_cfg (dict|None): Config of conv layers. 21 | norm_cfg (dict|None): Config of norm layers. 22 | act_cfg (dict): Config of activation layers. 23 | align_corners (bool): align_corners argument of F.interpolate. 
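
        A minimal usage sketch (illustrative only; the config values below
        are placeholders, not defaults):
            >>> import torch
            >>> ppm = PPM((1, 2, 3, 6), in_channels=512, channels=128,
            ...           conv_cfg=None, norm_cfg=None,
            ...           act_cfg=dict(type='ReLU'), align_corners=False)
            >>> outs = ppm(torch.rand(1, 512, 32, 32))
            >>> # four tensors, one per pooling scale, each upsampled back
            >>> # to (1, 128, 32, 32)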
24 | """ 25 | 26 | def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, 27 | act_cfg, align_corners, **kwargs): 28 | super(PPM, self).__init__() 29 | self.pool_scales = pool_scales 30 | self.align_corners = align_corners 31 | self.in_channels = in_channels 32 | self.channels = channels 33 | self.conv_cfg = conv_cfg 34 | self.norm_cfg = norm_cfg 35 | self.act_cfg = act_cfg 36 | for pool_scale in pool_scales: 37 | self.append( 38 | nn.Sequential( 39 | nn.AdaptiveAvgPool2d(pool_scale), 40 | ConvModule( 41 | self.in_channels, 42 | self.channels, 43 | 1, 44 | conv_cfg=self.conv_cfg, 45 | norm_cfg=self.norm_cfg, 46 | act_cfg=self.act_cfg, 47 | **kwargs))) 48 | 49 | def forward(self, x): 50 | """Forward function.""" 51 | ppm_outs = [] 52 | for ppm in self: 53 | ppm_out = ppm(x) 54 | upsampled_ppm_out = resize( 55 | ppm_out, 56 | size=x.size()[2:], 57 | mode='bilinear', 58 | align_corners=self.align_corners) 59 | ppm_outs.append(upsampled_ppm_out) 60 | return ppm_outs 61 | 62 | 63 | @HEADS.register_module() 64 | class PSPHead(BaseDecodeHead): 65 | """Pyramid Scene Parsing Network. 66 | 67 | This head is the implementation of 68 | `PSPNet `_. 69 | 70 | Args: 71 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 72 | Module. Default: (1, 2, 3, 6). 73 | """ 74 | 75 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 76 | super(PSPHead, self).__init__(**kwargs) 77 | assert isinstance(pool_scales, (list, tuple)) 78 | self.pool_scales = pool_scales 79 | self.psp_modules = PPM( 80 | self.pool_scales, 81 | self.in_channels, 82 | self.channels, 83 | conv_cfg=self.conv_cfg, 84 | norm_cfg=self.norm_cfg, 85 | act_cfg=self.act_cfg, 86 | align_corners=self.align_corners) 87 | self.bottleneck = ConvModule( 88 | self.in_channels + len(pool_scales) * self.channels, 89 | self.channels, 90 | 3, 91 | padding=1, 92 | conv_cfg=self.conv_cfg, 93 | norm_cfg=self.norm_cfg, 94 | act_cfg=self.act_cfg) 95 | 96 | def forward(self, inputs): 97 | """Forward function.""" 98 | x = self._transform_inputs(inputs) 99 | psp_outs = [x] 100 | psp_outs.extend(self.psp_modules(x)) 101 | psp_outs = torch.cat(psp_outs, dim=1) 102 | output = self.bottleneck(psp_outs) 103 | output = self.cls_seg(output) 104 | return output 105 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/segformer_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/NVlabs/SegFormer 2 | # Modifications: Model construction with loop 3 | # --------------------------------------------------------------- 4 | # Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
5 | # 6 | # This work is licensed under the NVIDIA Source Code License 7 | # --------------------------------------------------------------- 8 | # A copy of the license is available at resources/license_segformer 9 | 10 | import torch 11 | import torch.nn as nn 12 | from mmcv.cnn import ConvModule 13 | 14 | from mmseg.ops import resize 15 | from ..builder import HEADS 16 | from .decode_head import BaseDecodeHead 17 | 18 | 19 | class MLP(nn.Module): 20 | """Linear Embedding.""" 21 | 22 | def __init__(self, input_dim=2048, embed_dim=768): 23 | super().__init__() 24 | self.proj = nn.Linear(input_dim, embed_dim) 25 | 26 | def forward(self, x): 27 | x = x.flatten(2).transpose(1, 2).contiguous() 28 | x = self.proj(x) 29 | return x 30 | 31 | 32 | @HEADS.register_module() 33 | class SegFormerHead(BaseDecodeHead): 34 | """ 35 | SegFormer: Simple and Efficient Design for Semantic Segmentation with 36 | Transformers 37 | """ 38 | 39 | def __init__(self, **kwargs): 40 | super(SegFormerHead, self).__init__( 41 | input_transform='multiple_select', **kwargs) 42 | 43 | decoder_params = kwargs['decoder_params'] 44 | embedding_dim = decoder_params['embed_dim'] 45 | conv_kernel_size = decoder_params['conv_kernel_size'] 46 | 47 | self.linear_c = {} 48 | for i, in_channels in zip(self.in_index, self.in_channels): 49 | self.linear_c[str(i)] = MLP( 50 | input_dim=in_channels, embed_dim=embedding_dim) 51 | self.linear_c = nn.ModuleDict(self.linear_c) 52 | 53 | self.linear_fuse = ConvModule( 54 | in_channels=embedding_dim * len(self.in_index), 55 | out_channels=embedding_dim, 56 | kernel_size=conv_kernel_size, 57 | padding=0 if conv_kernel_size == 1 else conv_kernel_size // 2, 58 | norm_cfg=kwargs['norm_cfg']) 59 | 60 | self.linear_pred = nn.Conv2d( 61 | embedding_dim, self.num_classes, kernel_size=1) 62 | 63 | def forward(self, inputs): 64 | x = inputs 65 | n, _, h, w = x[-1].shape 66 | # for f in x: 67 | # print(f.shape) 68 | 69 | _c = {} 70 | for i in self.in_index: 71 | # mmcv.print_log(f'{i}: {x[i].shape}, {self.linear_c[str(i)]}') 72 | _c[i] = self.linear_c[str(i)](x[i]).permute(0, 2, 1).contiguous() 73 | _c[i] = _c[i].reshape(n, -1, x[i].shape[2], x[i].shape[3]) 74 | if i != 0: 75 | _c[i] = resize( 76 | _c[i], 77 | size=x[0].size()[2:], 78 | mode='bilinear', 79 | align_corners=False) 80 | 81 | _c = self.linear_fuse(torch.cat(list(_c.values()), dim=1)) 82 | 83 | if self.dropout is not None: 84 | x = self.dropout(_c) 85 | else: 86 | x = _c 87 | x = self.linear_pred(x) 88 | 89 | return x 90 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/sep_aspp_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .aspp_head import ASPPHead, ASPPModule 10 | 11 | 12 | class DepthwiseSeparableASPPModule(ASPPModule): 13 | """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable 14 | conv.""" 15 | 16 | def __init__(self, **kwargs): 17 | super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) 18 | for i, dilation in enumerate(self.dilations): 19 | if dilation > 1: 20 | self[i] = DepthwiseSeparableConvModule( 21 | self.in_channels, 22 | self.channels, 23 | 3, 24 | dilation=dilation, 25 | padding=dilation, 26 | norm_cfg=self.norm_cfg, 27 | 
act_cfg=self.act_cfg) 28 | 29 | 30 | @HEADS.register_module() 31 | class DepthwiseSeparableASPPHead(ASPPHead): 32 | """Encoder-Decoder with Atrous Separable Convolution for Semantic Image 33 | Segmentation. 34 | 35 | This head is the implementation of `DeepLabV3+ 36 | `_. 37 | 38 | Args: 39 | c1_in_channels (int): The input channels of c1 decoder. If is 0, 40 | the no decoder will be used. 41 | c1_channels (int): The intermediate channels of c1 decoder. 42 | """ 43 | 44 | def __init__(self, c1_in_channels, c1_channels, **kwargs): 45 | super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) 46 | assert c1_in_channels >= 0 47 | self.aspp_modules = DepthwiseSeparableASPPModule( 48 | dilations=self.dilations, 49 | in_channels=self.in_channels, 50 | channels=self.channels, 51 | conv_cfg=self.conv_cfg, 52 | norm_cfg=self.norm_cfg, 53 | act_cfg=self.act_cfg) 54 | if c1_in_channels > 0: 55 | self.c1_bottleneck = ConvModule( 56 | c1_in_channels, 57 | c1_channels, 58 | 1, 59 | conv_cfg=self.conv_cfg, 60 | norm_cfg=self.norm_cfg, 61 | act_cfg=self.act_cfg) 62 | else: 63 | self.c1_bottleneck = None 64 | self.sep_bottleneck = nn.Sequential( 65 | DepthwiseSeparableConvModule( 66 | self.channels + c1_channels, 67 | self.channels, 68 | 3, 69 | padding=1, 70 | norm_cfg=self.norm_cfg, 71 | act_cfg=self.act_cfg), 72 | DepthwiseSeparableConvModule( 73 | self.channels, 74 | self.channels, 75 | 3, 76 | padding=1, 77 | norm_cfg=self.norm_cfg, 78 | act_cfg=self.act_cfg)) 79 | 80 | def forward(self, inputs): 81 | """Forward function.""" 82 | x = self._transform_inputs(inputs) 83 | aspp_outs = [ 84 | resize( 85 | self.image_pool(x), 86 | size=x.size()[2:], 87 | mode='bilinear', 88 | align_corners=self.align_corners) 89 | ] 90 | aspp_outs.extend(self.aspp_modules(x)) 91 | aspp_outs = torch.cat(aspp_outs, dim=1) 92 | output = self.bottleneck(aspp_outs) 93 | if self.c1_bottleneck is not None: 94 | c1_output = self.c1_bottleneck(inputs[0]) 95 | output = resize( 96 | input=output, 97 | size=c1_output.shape[2:], 98 | mode='bilinear', 99 | align_corners=self.align_corners) 100 | output = torch.cat([output, c1_output], dim=1) 101 | output = self.sep_bottleneck(output) 102 | output = self.cls_seg(output) 103 | return output 104 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/uper_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | from .psp_head import PPM 11 | 12 | 13 | @HEADS.register_module() 14 | class UPerHead(BaseDecodeHead): 15 | """Unified Perceptual Parsing for Scene Understanding. 16 | 17 | This head is the implementation of `UPerNet 18 | `_. 19 | 20 | Args: 21 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 22 | Module applied on the last feature. Default: (1, 2, 3, 6). 
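
        A minimal usage sketch (illustrative only; the channel counts mimic
        a four-stage backbone and are placeholders):
            >>> import torch
            >>> head = UPerHead(in_channels=[64, 128, 320, 512],
            ...                 in_index=[0, 1, 2, 3], channels=128,
            ...                 num_classes=19)
            >>> feats = [torch.rand(1, c, 64 // 2**i, 64 // 2**i)
            ...          for i, c in enumerate([64, 128, 320, 512])]
            >>> logits = head(feats)  # logits at the feats[0] resolution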
23 | """ 24 | 25 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 26 | super(UPerHead, self).__init__( 27 | input_transform='multiple_select', **kwargs) 28 | # PSP Module 29 | self.psp_modules = PPM( 30 | pool_scales, 31 | self.in_channels[-1], 32 | self.channels, 33 | conv_cfg=self.conv_cfg, 34 | norm_cfg=self.norm_cfg, 35 | act_cfg=self.act_cfg, 36 | align_corners=self.align_corners) 37 | self.bottleneck = ConvModule( 38 | self.in_channels[-1] + len(pool_scales) * self.channels, 39 | self.channels, 40 | 3, 41 | padding=1, 42 | conv_cfg=self.conv_cfg, 43 | norm_cfg=self.norm_cfg, 44 | act_cfg=self.act_cfg) 45 | # FPN Module 46 | self.lateral_convs = nn.ModuleList() 47 | self.fpn_convs = nn.ModuleList() 48 | for in_channels in self.in_channels[:-1]: # skip the top layer 49 | l_conv = ConvModule( 50 | in_channels, 51 | self.channels, 52 | 1, 53 | conv_cfg=self.conv_cfg, 54 | norm_cfg=self.norm_cfg, 55 | act_cfg=self.act_cfg, 56 | inplace=False) 57 | fpn_conv = ConvModule( 58 | self.channels, 59 | self.channels, 60 | 3, 61 | padding=1, 62 | conv_cfg=self.conv_cfg, 63 | norm_cfg=self.norm_cfg, 64 | act_cfg=self.act_cfg, 65 | inplace=False) 66 | self.lateral_convs.append(l_conv) 67 | self.fpn_convs.append(fpn_conv) 68 | 69 | self.fpn_bottleneck = ConvModule( 70 | len(self.in_channels) * self.channels, 71 | self.channels, 72 | 3, 73 | padding=1, 74 | conv_cfg=self.conv_cfg, 75 | norm_cfg=self.norm_cfg, 76 | act_cfg=self.act_cfg) 77 | 78 | def psp_forward(self, inputs): 79 | """Forward function of PSP module.""" 80 | x = inputs[-1] 81 | psp_outs = [x] 82 | psp_outs.extend(self.psp_modules(x)) 83 | psp_outs = torch.cat(psp_outs, dim=1) 84 | output = self.bottleneck(psp_outs) 85 | 86 | return output 87 | 88 | def forward(self, inputs): 89 | """Forward function.""" 90 | 91 | inputs = self._transform_inputs(inputs) 92 | 93 | # build laterals 94 | laterals = [ 95 | lateral_conv(inputs[i]) 96 | for i, lateral_conv in enumerate(self.lateral_convs) 97 | ] 98 | 99 | laterals.append(self.psp_forward(inputs)) 100 | 101 | # build top-down path 102 | used_backbone_levels = len(laterals) 103 | for i in range(used_backbone_levels - 1, 0, -1): 104 | prev_shape = laterals[i - 1].shape[2:] 105 | laterals[i - 1] += resize( 106 | laterals[i], 107 | size=prev_shape, 108 | mode='bilinear', 109 | align_corners=self.align_corners) 110 | 111 | # build outputs 112 | fpn_outs = [ 113 | self.fpn_convs[i](laterals[i]) 114 | for i in range(used_backbone_levels - 1) 115 | ] 116 | # append psp feature 117 | fpn_outs.append(laterals[-1]) 118 | 119 | for i in range(used_backbone_levels - 1, 0, -1): 120 | fpn_outs[i] = resize( 121 | fpn_outs[i], 122 | size=fpn_outs[0].shape[2:], 123 | mode='bilinear', 124 | align_corners=self.align_corners) 125 | fpn_outs = torch.cat(fpn_outs, dim=1) 126 | output = self.fpn_bottleneck(fpn_outs) 127 | output = self.cls_seg(output) 128 | return output 129 | -------------------------------------------------------------------------------- /mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import Accuracy, accuracy 2 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 3 | cross_entropy, mask_cross_entropy) 4 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 5 | 6 | __all__ = [ 7 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 8 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 9 | 'weight_reduce_loss', 'weighted_loss' 10 | ] 11 | 
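
The losses package exports the loss modules alongside the reduction helpers
that weight and average them. A minimal usage sketch (illustrative only;
shapes are placeholders and the constructor arguments follow
cross_entropy_loss.py in this package):

    >>> import torch
    >>> from mmseg.models.losses import CrossEntropyLoss, accuracy
    >>> criterion = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
    >>> seg_logit = torch.rand(2, 19, 8, 8)
    >>> seg_label = torch.randint(0, 19, (2, 8, 8))
    >>> loss = criterion(seg_logit, seg_label)
    >>> top1 = accuracy(seg_logit, seg_label)

Because ``accuracy`` consumes the same ``(N, C, H, W)`` logits as the loss,
both can be computed from a single forward pass.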
-------------------------------------------------------------------------------- /mmseg/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def accuracy(pred, target, topk=1, thresh=None): 7 | """Calculate accuracy according to the prediction and target. 8 | 9 | Args: 10 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 11 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 12 | topk (int | tuple[int], optional): If the predictions in ``topk`` 13 | matches the target, the predictions will be regarded as 14 | correct ones. Defaults to 1. 15 | thresh (float, optional): If not None, predictions with scores under 16 | this threshold are considered incorrect. Default to None. 17 | 18 | Returns: 19 | float | tuple[float]: If the input ``topk`` is a single integer, 20 | the function will return a single float as accuracy. If 21 | ``topk`` is a tuple containing multiple integers, the 22 | function will return a tuple containing accuracies of 23 | each ``topk`` number. 24 | """ 25 | assert isinstance(topk, (int, tuple)) 26 | if isinstance(topk, int): 27 | topk = (topk, ) 28 | return_single = True 29 | else: 30 | return_single = False 31 | 32 | maxk = max(topk) 33 | if pred.size(0) == 0: 34 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 35 | return accu[0] if return_single else accu 36 | assert pred.ndim == target.ndim + 1 37 | assert pred.size(0) == target.size(0) 38 | assert maxk <= pred.size(1), \ 39 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 40 | pred_value, pred_label = pred.topk(maxk, dim=1) 41 | # transpose to shape (maxk, N, ...) 42 | pred_label = pred_label.transpose(0, 1) 43 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 44 | if thresh is not None: 45 | # Only prediction values larger than thresh are counted as correct 46 | correct = correct & (pred_value > thresh).t() 47 | res = [] 48 | for k in topk: 49 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 50 | res.append(correct_k.mul_(100.0 / target.numel())) 51 | return res[0] if return_single else res 52 | 53 | 54 | class Accuracy(nn.Module): 55 | """Accuracy calculation module.""" 56 | 57 | def __init__(self, topk=(1, ), thresh=None): 58 | """Module to calculate the accuracy. 59 | 60 | Args: 61 | topk (tuple, optional): The criterion used to calculate the 62 | accuracy. Defaults to (1,). 63 | thresh (float, optional): If not None, predictions with scores 64 | under this threshold are considered incorrect. Default to None. 65 | """ 66 | super().__init__() 67 | self.topk = topk 68 | self.thresh = thresh 69 | 70 | def forward(self, pred, target): 71 | """Forward function to calculate accuracy. 72 | 73 | Args: 74 | pred (torch.Tensor): Prediction of models. 75 | target (torch.Tensor): Target for each prediction. 76 | 77 | Returns: 78 | tuple[float]: The accuracies under different topk criterions. 
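
        A worked sketch (illustrative only):
            >>> import torch
            >>> acc = Accuracy(topk=(1, ))
            >>> pred = torch.tensor([[0.2, 0.8], [0.9, 0.1]])
            >>> target = torch.tensor([1, 0])
            >>> acc(pred, target)  # both top-1 predictions are correct,
            >>> # so the result is [tensor([100.])]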
79 | """ 80 | return accuracy(pred, target, self.topk, self.thresh) 81 | -------------------------------------------------------------------------------- /mmseg/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import functools 4 | 5 | import mmcv 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | 10 | def get_class_weight(class_weight): 11 | """Get class weight for loss function. 12 | 13 | Args: 14 | class_weight (list[float] | str | None): If class_weight is a str, 15 | take it as a file name and read from it. 16 | """ 17 | if isinstance(class_weight, str): 18 | # take it as a file path 19 | if class_weight.endswith('.npy'): 20 | class_weight = np.load(class_weight) 21 | else: 22 | # pkl, json or yaml 23 | class_weight = mmcv.load(class_weight) 24 | 25 | return class_weight 26 | 27 | 28 | def reduce_loss(loss, reduction): 29 | """Reduce loss as specified. 30 | 31 | Args: 32 | loss (Tensor): Elementwise loss tensor. 33 | reduction (str): Options are "none", "mean" and "sum". 34 | 35 | Return: 36 | Tensor: Reduced loss tensor. 37 | """ 38 | reduction_enum = F._Reduction.get_enum(reduction) 39 | # none: 0, elementwise_mean:1, sum: 2 40 | if reduction_enum == 0: 41 | return loss 42 | elif reduction_enum == 1: 43 | return loss.mean() 44 | elif reduction_enum == 2: 45 | return loss.sum() 46 | 47 | 48 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 49 | """Apply element-wise weight and reduce loss. 50 | 51 | Args: 52 | loss (Tensor): Element-wise loss. 53 | weight (Tensor): Element-wise weights. 54 | reduction (str): Same as built-in losses of PyTorch. 55 | avg_factor (float): Avarage factor when computing the mean of losses. 56 | 57 | Returns: 58 | Tensor: Processed loss values. 59 | """ 60 | # if weight is specified, apply element-wise weight 61 | if weight is not None: 62 | assert weight.dim() == loss.dim() 63 | if weight.dim() > 1: 64 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 65 | loss = loss * weight 66 | 67 | # if avg_factor is not specified, just reduce the loss 68 | if avg_factor is None: 69 | loss = reduce_loss(loss, reduction) 70 | else: 71 | # if reduction is mean, then average the loss by avg_factor 72 | if reduction == 'mean': 73 | loss = loss.sum() / avg_factor 74 | # if reduction is 'none', then do nothing, otherwise raise an error 75 | elif reduction != 'none': 76 | raise ValueError('avg_factor can not be used with reduction="sum"') 77 | return loss 78 | 79 | 80 | def weighted_loss(loss_func): 81 | """Create a weighted version of a given loss function. 82 | 83 | To use this decorator, the loss function must have the signature like 84 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 85 | element-wise loss without any reduction. This decorator will add weight 86 | and reduction arguments to the function. The decorated function will have 87 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 88 | avg_factor=None, **kwargs)`. 89 | 90 | :Example: 91 | 92 | >>> import torch 93 | >>> @weighted_loss 94 | >>> def l1_loss(pred, target): 95 | >>> return (pred - target).abs() 96 | 97 | >>> pred = torch.Tensor([0, 2, 3]) 98 | >>> target = torch.Tensor([1, 1, 1]) 99 | >>> weight = torch.Tensor([1, 0, 1]) 100 | 101 | >>> l1_loss(pred, target) 102 | tensor(1.3333) 103 | >>> l1_loss(pred, target, weight) 104 | tensor(1.) 
105 | >>> l1_loss(pred, target, reduction='none') 106 | tensor([1., 1., 2.]) 107 | >>> l1_loss(pred, target, weight, avg_factor=2) 108 | tensor(1.5000) 109 | """ 110 | 111 | @functools.wraps(loss_func) 112 | def wrapper(pred, 113 | target, 114 | weight=None, 115 | reduction='mean', 116 | avg_factor=None, 117 | **kwargs): 118 | # get element-wise loss 119 | loss = loss_func(pred, target, **kwargs) 120 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 121 | return loss 122 | 123 | return wrapper 124 | -------------------------------------------------------------------------------- /mmseg/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .segformer_adapter import SegFormerAdapter 2 | 3 | __all__ = ['SegFormerAdapter'] 4 | -------------------------------------------------------------------------------- /mmseg/models/necks/segformer_adapter.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from mmseg.ops import resize 10 | from ..builder import NECKS 11 | 12 | 13 | @NECKS.register_module() 14 | class SegFormerAdapter(nn.Module): 15 | 16 | def __init__(self, out_layers=[3], scales=[4]): 17 | super(SegFormerAdapter, self).__init__() 18 | self.out_layers = out_layers 19 | self.scales = scales 20 | 21 | def forward(self, x): 22 | _c = {} 23 | for i, s in zip(self.out_layers, self.scales): 24 | if s == 1: 25 | _c[i] = x[i] 26 | else: 27 | _c[i] = resize( 28 | x[i], scale_factor=s, mode='bilinear', align_corners=False) 29 | # mmcv.print_log(f'{i}: {x[i].shape}, {_c[i].shape}', 'mmseg') 30 | 31 | x[-1] = torch.cat(list(_c.values()), dim=1) 32 | return x 33 | -------------------------------------------------------------------------------- /mmseg/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseSegmentor 2 | from .encoder_decoder import EncoderDecoder 3 | 4 | __all__ = ['BaseSegmentor', 'EncoderDecoder'] 5 | -------------------------------------------------------------------------------- /mmseg/models/uda/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | from mmseg.models.uda.dacs import DACS 7 | from mmseg.models.uda.dacs_cda import DACS_CDA 8 | 9 | __all__ = ['DACS', 'DACS_CDA'] 10 | -------------------------------------------------------------------------------- /mmseg/models/uda/uda_decorator.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | from copy import deepcopy 7 | 8 | from mmcv.parallel import MMDistributedDataParallel 9 | 10 | from mmseg.models import BaseSegmentor, build_segmentor 11 | 12 | 13 | def get_module(module): 14 | """Get `nn.ModuleDict` to fit the `MMDistributedDataParallel` interface. 15 | 16 | Args: 17 | module (MMDistributedDataParallel | nn.ModuleDict): The input 18 | module that needs processing. 19 | 20 | Returns: 21 | nn.ModuleDict: The ModuleDict of multiple networks. 22 | """ 23 | if isinstance(module, MMDistributedDataParallel): 24 | return module.module 25 | 26 | return module 27 | 28 | 29 | class UDADecorator(BaseSegmentor): 30 | 31 | def __init__(self, **cfg): 32 | super(BaseSegmentor, self).__init__() 33 | 34 | self.model = build_segmentor(deepcopy(cfg['model'])) 35 | self.train_cfg = cfg['model']['train_cfg'] 36 | self.test_cfg = cfg['model']['test_cfg'] 37 | self.num_classes = cfg['model']['decode_head']['num_classes'] 38 | 39 | def get_model(self): 40 | return get_module(self.model) 41 | 42 | def extract_feat(self, img): 43 | """Extract features from images.""" 44 | return self.get_model().extract_feat(img) 45 | 46 | def encode_decode(self, img, img_metas): 47 | """Encode images with backbone and decode into a semantic segmentation 48 | map of the same size as input.""" 49 | return self.get_model().encode_decode(img, img_metas) 50 | 51 | def forward_train(self, 52 | img, 53 | img_metas, 54 | gt_semantic_seg, 55 | target_img, 56 | target_img_metas, 57 | return_feat=False): 58 | """Forward function for training. 59 | 60 | Args: 61 | img (Tensor): Input images. 62 | img_metas (list[dict]): List of image info dict where each dict 63 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 64 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 65 | For details on the values of these keys see 66 | `mmseg/datasets/pipelines/formatting.py:Collect`. 67 | gt_semantic_seg (Tensor): Semantic segmentation masks 68 | used if the architecture supports semantic segmentation task. 69 | 70 | Returns: 71 | dict[str, Tensor]: a dictionary of loss components 72 | """ 73 | losses = self.get_model().forward_train( 74 | img, img_metas, gt_semantic_seg, return_feat=return_feat) 75 | return losses 76 | 77 | def inference(self, img, img_meta, rescale): 78 | """Inference with slide/whole style. 79 | 80 | Args: 81 | img (Tensor): The input image of shape (N, 3, H, W). 82 | img_meta (dict): Image info dict where each dict has: 'img_shape', 83 | 'scale_factor', 'flip', and may also contain 84 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 85 | For details on the values of these keys see 86 | `mmseg/datasets/pipelines/formatting.py:Collect`. 87 | rescale (bool): Whether rescale back to original shape. 88 | 89 | Returns: 90 | Tensor: The output segmentation map. 91 | """ 92 | return self.get_model().inference(img, img_meta, rescale) 93 | 94 | def simple_test(self, img, img_meta, rescale=True): 95 | """Simple test with single image.""" 96 | return self.get_model().simple_test(img, img_meta, rescale) 97 | 98 | def aug_test(self, imgs, img_metas, rescale=True): 99 | """Test with augmentations. 100 | 101 | Only rescale=True is supported. 
102 | """ 103 | return self.get_model().aug_test(imgs, img_metas, rescale) 104 | -------------------------------------------------------------------------------- /mmseg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .ckpt_convert import mit_convert 2 | from .make_divisible import make_divisible 3 | from .res_layer import ResLayer 4 | from .self_attention_block import SelfAttentionBlock 5 | from .shape_convert import nchw_to_nlc, nlc_to_nchw 6 | 7 | __all__ = [ 8 | 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'mit_convert', 9 | 'nchw_to_nlc', 'nlc_to_nchw' 10 | ] 11 | -------------------------------------------------------------------------------- /mmseg/models/utils/ckpt_convert.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | 8 | def mit_convert(ckpt): 9 | new_ckpt = OrderedDict() 10 | # Process the concat between q linear weights and kv linear weights 11 | for k, v in ckpt.items(): 12 | if k.startswith('head'): 13 | continue 14 | elif k.startswith('patch_embed'): 15 | stage_i = int(k.split('.')[0].replace('patch_embed', '')) 16 | new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0') 17 | new_v = v 18 | if 'proj.' in new_k: 19 | new_k = new_k.replace('proj.', 'projection.') 20 | elif k.startswith('block'): 21 | stage_i = int(k.split('.')[0].replace('block', '')) 22 | new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1') 23 | new_v = v 24 | if 'attn.q.' in new_k: 25 | sub_item_k = k.replace('q.', 'kv.') 26 | new_k = new_k.replace('q.', 'attn.in_proj_') 27 | new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) 28 | elif 'attn.kv.' in new_k: 29 | continue 30 | elif 'attn.proj.' in new_k: 31 | new_k = new_k.replace('proj.', 'attn.out_proj.') 32 | elif 'attn.sr.' in new_k: 33 | new_k = new_k.replace('sr.', 'sr.') 34 | elif 'mlp.' in new_k: 35 | string = f'{new_k}-' 36 | new_k = new_k.replace('mlp.', 'ffn.layers.') 37 | if 'fc1.weight' in new_k or 'fc2.weight' in new_k: 38 | new_v = v.reshape((*v.shape, 1, 1)) 39 | new_k = new_k.replace('fc1.', '0.') 40 | new_k = new_k.replace('dwconv.dwconv.', '1.') 41 | new_k = new_k.replace('fc2.', '4.') 42 | string += f'{new_k} {v.shape}-{new_v.shape}' 43 | # print(string) 44 | elif k.startswith('norm'): 45 | stage_i = int(k.split('.')[0].replace('norm', '')) 46 | new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') 47 | new_v = v 48 | else: 49 | new_k = k 50 | new_v = v 51 | new_ckpt[new_k] = new_v 52 | return new_ckpt 53 | -------------------------------------------------------------------------------- /mmseg/models/utils/dacs_transforms.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/vikolss/DACS 2 | # Copyright (c) 2020 vikolss. 
Licensed under the MIT License 3 | # A copy of the license is available at resources/license_dacs 4 | 5 | import kornia 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | def strong_transform(param, data=None, target=None): 12 | assert ((data is not None) or (target is not None)) 13 | data, target = one_mix(mask=param['mix'], data=data, target=target) 14 | data, target = color_jitter( 15 | color_jitter=param['color_jitter'], 16 | s=param['color_jitter_s'], 17 | p=param['color_jitter_p'], 18 | mean=param['mean'], 19 | std=param['std'], 20 | data=data, 21 | target=target) 22 | data, target = gaussian_blur(blur=param['blur'], data=data, target=target) 23 | return data, target 24 | 25 | 26 | def get_mean_std(img_metas, dev): 27 | mean = [ 28 | torch.as_tensor(img_metas[i]['img_norm_cfg']['mean'], device=dev) 29 | for i in range(len(img_metas)) 30 | ] 31 | mean = torch.stack(mean).view(-1, 3, 1, 1) 32 | std = [ 33 | torch.as_tensor(img_metas[i]['img_norm_cfg']['std'], device=dev) 34 | for i in range(len(img_metas)) 35 | ] 36 | std = torch.stack(std).view(-1, 3, 1, 1) 37 | return mean, std 38 | 39 | 40 | def denorm(img, mean, std): 41 | return img.mul(std).add(mean) / 255.0 42 | 43 | 44 | def denorm_(img, mean, std): 45 | img.mul_(std).add_(mean).div_(255.0) 46 | 47 | 48 | def renorm_(img, mean, std): 49 | img.mul_(255.0).sub_(mean).div_(std) 50 | 51 | 52 | def color_jitter(color_jitter, mean, std, data=None, target=None, s=.25, p=.2): 53 | # s is the strength of colorjitter 54 | if not (data is None): 55 | if data.shape[1] == 3: 56 | if color_jitter > p: 57 | if isinstance(s, dict): 58 | seq = nn.Sequential(kornia.augmentation.ColorJitter(**s)) 59 | else: 60 | seq = nn.Sequential( 61 | kornia.augmentation.ColorJitter( 62 | brightness=s, contrast=s, saturation=s, hue=s)) 63 | denorm_(data, mean, std) 64 | data = seq(data) 65 | renorm_(data, mean, std) 66 | return data, target 67 | 68 | 69 | def gaussian_blur(blur, data=None, target=None): 70 | if not (data is None): 71 | if data.shape[1] == 3: 72 | if blur > 0.5: 73 | sigma = np.random.uniform(0.15, 1.15) 74 | kernel_size_y = int( 75 | np.floor( 76 | np.ceil(0.1 * data.shape[2]) - 0.5 + 77 | np.ceil(0.1 * data.shape[2]) % 2)) 78 | kernel_size_x = int( 79 | np.floor( 80 | np.ceil(0.1 * data.shape[3]) - 0.5 + 81 | np.ceil(0.1 * data.shape[3]) % 2)) 82 | kernel_size = (kernel_size_y, kernel_size_x) 83 | seq = nn.Sequential( 84 | kornia.filters.GaussianBlur2d( 85 | kernel_size=kernel_size, sigma=(sigma, sigma))) 86 | data = seq(data) 87 | return data, target 88 | 89 | 90 | def get_class_masks(labels): 91 | class_masks = [] 92 | for label in labels: 93 | classes = torch.unique(labels) 94 | nclasses = classes.shape[0] 95 | class_choice = np.random.choice( 96 | nclasses, int((nclasses + nclasses % 2) / 2), replace=False) 97 | classes = classes[torch.Tensor(class_choice).long()] 98 | class_masks.append(generate_class_mask(label, classes).unsqueeze(0)) 99 | return class_masks 100 | 101 | 102 | def generate_class_mask(label, classes): 103 | label, classes = torch.broadcast_tensors(label, 104 | classes.unsqueeze(1).unsqueeze(2)) 105 | class_mask = label.eq(classes).sum(0, keepdims=True) 106 | return class_mask 107 | 108 | 109 | def one_mix(mask, data=None, target=None): 110 | if mask is None: 111 | return data, target 112 | if not (data is None): 113 | stackedMask0, _ = torch.broadcast_tensors(mask[0], data[0]) 114 | data = (stackedMask0 * data[0] + 115 | (1 - stackedMask0) * data[1]).unsqueeze(0) 116 | if not (target is None): 
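        # labels are mixed the same way as the images above: mask[0] is
        # broadcast over the pair and selects per-pixel between target[0]
        # and target[1]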
117 | stackedMask0, _ = torch.broadcast_tensors(mask[0], target[0]) 118 | target = (stackedMask0 * target[0] + 119 | (1 - stackedMask0) * target[1]).unsqueeze(0) 120 | return data, target 121 | -------------------------------------------------------------------------------- /mmseg/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 4 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 5 | """Make divisible function. 6 | 7 | This function rounds the channel number to the nearest value that can be 8 | divisible by the divisor. It is taken from the original tf repo. It ensures 9 | that all layers have a channel number that is divisible by divisor. It can 10 | be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 11 | 12 | Args: 13 | value (int): The original channel number. 14 | divisor (int): The divisor to fully divide the channel number. 15 | min_value (int): The minimum value of the output channel. 16 | Default: None, means that the minimum value equal to the divisor. 17 | min_ratio (float): The minimum ratio of the rounded channel number to 18 | the original channel number. Default: 0.9. 19 | 20 | Returns: 21 | int: The modified output channel number. 22 | """ 23 | 24 | if min_value is None: 25 | min_value = divisor 26 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than (1-min_ratio). 28 | if new_value < min_ratio * value: 29 | new_value += divisor 30 | return new_value 31 | -------------------------------------------------------------------------------- /mmseg/models/utils/res_layer.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from mmcv.cnn import build_conv_layer, build_norm_layer 4 | from mmcv.runner import Sequential 5 | from torch import nn as nn 6 | 7 | 8 | class ResLayer(Sequential): 9 | """ResLayer to build ResNet style backbone. 10 | 11 | Args: 12 | block (nn.Module): block used to build ResLayer. 13 | inplanes (int): inplanes of block. 14 | planes (int): planes of block. 15 | num_blocks (int): number of blocks. 16 | stride (int): stride of the first block. Default: 1 17 | avg_down (bool): Use AvgPool instead of stride conv when 18 | downsampling in the bottleneck. Default: False 19 | conv_cfg (dict): dictionary to construct and config conv layer. 20 | Default: None 21 | norm_cfg (dict): dictionary to construct and config norm layer. 22 | Default: dict(type='BN') 23 | multi_grid (int | None): Multi grid dilation rates of last 24 | stage. 
Default: None 25 | contract_dilation (bool): Whether contract first dilation of each layer 26 | Default: False 27 | """ 28 | 29 | def __init__(self, 30 | block, 31 | inplanes, 32 | planes, 33 | num_blocks, 34 | stride=1, 35 | dilation=1, 36 | avg_down=False, 37 | conv_cfg=None, 38 | norm_cfg=dict(type='BN'), 39 | multi_grid=None, 40 | contract_dilation=False, 41 | **kwargs): 42 | self.block = block 43 | 44 | downsample = None 45 | if stride != 1 or inplanes != planes * block.expansion: 46 | downsample = [] 47 | conv_stride = stride 48 | if avg_down: 49 | conv_stride = 1 50 | downsample.append( 51 | nn.AvgPool2d( 52 | kernel_size=stride, 53 | stride=stride, 54 | ceil_mode=True, 55 | count_include_pad=False)) 56 | downsample.extend([ 57 | build_conv_layer( 58 | conv_cfg, 59 | inplanes, 60 | planes * block.expansion, 61 | kernel_size=1, 62 | stride=conv_stride, 63 | bias=False), 64 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 65 | ]) 66 | downsample = nn.Sequential(*downsample) 67 | 68 | layers = [] 69 | if multi_grid is None: 70 | if dilation > 1 and contract_dilation: 71 | first_dilation = dilation // 2 72 | else: 73 | first_dilation = dilation 74 | else: 75 | first_dilation = multi_grid[0] 76 | layers.append( 77 | block( 78 | inplanes=inplanes, 79 | planes=planes, 80 | stride=stride, 81 | dilation=first_dilation, 82 | downsample=downsample, 83 | conv_cfg=conv_cfg, 84 | norm_cfg=norm_cfg, 85 | **kwargs)) 86 | inplanes = planes * block.expansion 87 | for i in range(1, num_blocks): 88 | layers.append( 89 | block( 90 | inplanes=inplanes, 91 | planes=planes, 92 | stride=1, 93 | dilation=dilation if multi_grid is None else multi_grid[i], 94 | conv_cfg=conv_cfg, 95 | norm_cfg=norm_cfg, 96 | **kwargs)) 97 | super(ResLayer, self).__init__(*layers) 98 | -------------------------------------------------------------------------------- /mmseg/models/utils/shape_convert.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 4 | def nlc_to_nchw(x, hw_shape): 5 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 6 | 7 | Args: 8 | x (Tensor): The input tensor of shape [N, L, C] before convertion. 9 | hw_shape (Sequence[int]): The height and width of output feature map. 10 | 11 | Returns: 12 | Tensor: The output tensor of shape [N, C, H, W] after convertion. 13 | """ 14 | H, W = hw_shape 15 | assert len(x.shape) == 3 16 | B, L, C = x.shape 17 | assert L == H * W, 'The seq_len doesn\'t match H, W' 18 | return x.transpose(1, 2).reshape(B, C, H, W) 19 | 20 | 21 | def nchw_to_nlc(x): 22 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 23 | 24 | Args: 25 | x (Tensor): The input tensor of shape [N, C, H, W] before convertion. 26 | 27 | Returns: 28 | Tensor: The output tensor of shape [N, L, C] after convertion. 
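
        A round-trip sketch (illustrative only):
            >>> import torch
            >>> x = torch.rand(2, 64, 16, 16)          # [N, C, H, W]
            >>> seq = nchw_to_nlc(x)                   # [2, 256, 64]
            >>> restored = nlc_to_nchw(seq, (16, 16))  # [2, 64, 16, 16] again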
29 | """ 30 | assert len(x.shape) == 4 31 | return x.flatten(2).transpose(1, 2).contiguous() 32 | -------------------------------------------------------------------------------- /mmseg/models/utils/visualization.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import numpy as np 4 | import torch 5 | from matplotlib import pyplot as plt 6 | from PIL import Image 7 | 8 | Cityscapes_palette = [ 9 | 128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 10 | 153, 153, 250, 170, 30, 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 11 | 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70, 0, 60, 100, 0, 80, 100, 12 | 0, 0, 230, 119, 11, 32, 128, 192, 0, 0, 64, 128, 128, 64, 128, 0, 192, 128, 13 | 128, 192, 128, 64, 64, 0, 192, 64, 0, 64, 192, 0, 192, 192, 0, 64, 64, 128, 14 | 192, 64, 128, 64, 192, 128, 192, 192, 128, 0, 0, 64, 128, 0, 64, 0, 128, 15 | 64, 128, 128, 64, 0, 0, 192, 128, 0, 192, 0, 128, 192, 128, 128, 192, 64, 16 | 0, 64, 192, 0, 64, 64, 128, 64, 192, 128, 64, 64, 0, 192, 192, 0, 192, 64, 17 | 128, 192, 192, 128, 192, 0, 64, 64, 128, 64, 64, 0, 192, 64, 128, 192, 64, 18 | 0, 64, 192, 128, 64, 192, 0, 192, 192, 128, 192, 192, 64, 64, 64, 192, 64, 19 | 64, 64, 192, 64, 192, 192, 64, 64, 64, 192, 192, 64, 192, 64, 192, 192, 20 | 192, 192, 192, 32, 0, 0, 160, 0, 0, 32, 128, 0, 160, 128, 0, 32, 0, 128, 21 | 160, 0, 128, 32, 128, 128, 160, 128, 128, 96, 0, 0, 224, 0, 0, 96, 128, 0, 22 | 224, 128, 0, 96, 0, 128, 224, 0, 128, 96, 128, 128, 224, 128, 128, 32, 64, 23 | 0, 160, 64, 0, 32, 192, 0, 160, 192, 0, 32, 64, 128, 160, 64, 128, 32, 192, 24 | 128, 160, 192, 128, 96, 64, 0, 224, 64, 0, 96, 192, 0, 224, 192, 0, 96, 64, 25 | 128, 224, 64, 128, 96, 192, 128, 224, 192, 128, 32, 0, 64, 160, 0, 64, 32, 26 | 128, 64, 160, 128, 64, 32, 0, 192, 160, 0, 192, 32, 128, 192, 160, 128, 27 | 192, 96, 0, 64, 224, 0, 64, 96, 128, 64, 224, 128, 64, 96, 0, 192, 224, 0, 28 | 192, 96, 128, 192, 224, 128, 192, 32, 64, 64, 160, 64, 64, 32, 192, 64, 29 | 160, 192, 64, 32, 64, 192, 160, 64, 192, 32, 192, 192, 160, 192, 192, 96, 30 | 64, 64, 224, 64, 64, 96, 192, 64, 224, 192, 64, 96, 64, 192, 224, 64, 192, 31 | 96, 192, 192, 224, 192, 192, 0, 32, 0, 128, 32, 0, 0, 160, 0, 128, 160, 0, 32 | 0, 32, 128, 128, 32, 128, 0, 160, 128, 128, 160, 128, 64, 32, 0, 192, 32, 33 | 0, 64, 160, 0, 192, 160, 0, 64, 32, 128, 192, 32, 128, 64, 160, 128, 192, 34 | 160, 128, 0, 96, 0, 128, 96, 0, 0, 224, 0, 128, 224, 0, 0, 96, 128, 128, 35 | 96, 128, 0, 224, 128, 128, 224, 128, 64, 96, 0, 192, 96, 0, 64, 224, 0, 36 | 192, 224, 0, 64, 96, 128, 192, 96, 128, 64, 224, 128, 192, 224, 128, 0, 32, 37 | 64, 128, 32, 64, 0, 160, 64, 128, 160, 64, 0, 32, 192, 128, 32, 192, 0, 38 | 160, 192, 128, 160, 192, 64, 32, 64, 192, 32, 64, 64, 160, 64, 192, 160, 39 | 64, 64, 32, 192, 192, 32, 192, 64, 160, 192, 192, 160, 192, 0, 96, 64, 128, 40 | 96, 64, 0, 224, 64, 128, 224, 64, 0, 96, 192, 128, 96, 192, 0, 224, 192, 41 | 128, 224, 192, 64, 96, 64, 192, 96, 64, 64, 224, 64, 192, 224, 64, 64, 96, 42 | 192, 192, 96, 192, 64, 224, 192, 192, 224, 192, 32, 32, 0, 160, 32, 0, 32, 43 | 160, 0, 160, 160, 0, 32, 32, 128, 160, 32, 128, 32, 160, 128, 160, 160, 44 | 128, 96, 32, 0, 224, 32, 0, 96, 160, 0, 224, 160, 0, 96, 32, 128, 224, 32, 45 | 128, 96, 160, 128, 224, 160, 128, 32, 96, 0, 160, 96, 0, 32, 224, 0, 160, 46 | 224, 0, 32, 96, 128, 160, 96, 128, 32, 224, 128, 160, 224, 128, 96, 96, 0, 47 | 224, 96, 0, 96, 
224, 0, 224, 224, 0, 96, 96, 128, 224, 96, 128, 96, 224, 48 | 128, 224, 224, 128, 32, 32, 64, 160, 32, 64, 32, 160, 64, 160, 160, 64, 32, 49 | 32, 192, 160, 32, 192, 32, 160, 192, 160, 160, 192, 96, 32, 64, 224, 32, 50 | 64, 96, 160, 64, 224, 160, 64, 96, 32, 192, 224, 32, 192, 96, 160, 192, 51 | 224, 160, 192, 32, 96, 64, 160, 96, 64, 32, 224, 64, 160, 224, 64, 32, 96, 52 | 192, 160, 96, 192, 32, 224, 192, 160, 224, 192, 96, 96, 64, 224, 96, 64, 53 | 96, 224, 64, 224, 224, 64, 96, 96, 192, 224, 96, 192, 96, 224, 192, 0, 0, 0 54 | ] 55 | 56 | 57 | def colorize_mask(mask, palette): 58 | zero_pad = 256 * 3 - len(palette) 59 | for i in range(zero_pad): 60 | palette.append(0) 61 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 62 | new_mask.putpalette(palette) 63 | return new_mask 64 | 65 | 66 | def _colorize(img, cmap, mask_zero=False): 67 | vmin = np.min(img) 68 | vmax = np.max(img) 69 | mask = (img <= 0).squeeze() 70 | cm = plt.get_cmap(cmap) 71 | colored_image = cm(np.clip(img.squeeze(), vmin, vmax) / vmax)[:, :, :3] 72 | # Use white if no depth is available (<= 0) 73 | if mask_zero: 74 | colored_image[mask, :] = [1, 1, 1] 75 | return colored_image 76 | 77 | 78 | def subplotimg(ax, 79 | img, 80 | title, 81 | range_in_title=False, 82 | palette=Cityscapes_palette, 83 | **kwargs): 84 | if img is None: 85 | return 86 | with torch.no_grad(): 87 | if torch.is_tensor(img): 88 | img = img.cpu() 89 | if len(img.shape) == 2: 90 | if torch.is_tensor(img): 91 | img = img.numpy() 92 | elif img.shape[0] == 1: 93 | if torch.is_tensor(img): 94 | img = img.numpy() 95 | img = img.squeeze(0) 96 | elif img.shape[0] == 3: 97 | img = img.permute(1, 2, 0) 98 | if not torch.is_tensor(img): 99 | img = img.numpy() 100 | if kwargs.get('cmap', '') == 'cityscapes': 101 | kwargs.pop('cmap') 102 | if torch.is_tensor(img): 103 | img = img.numpy() 104 | img = colorize_mask(img, palette) 105 | 106 | if range_in_title: 107 | vmin = np.min(img) 108 | vmax = np.max(img) 109 | title += f' {vmin:.3f}-{vmax:.3f}' 110 | 111 | ax.imshow(img, **kwargs) 112 | ax.set_title(title) 113 | -------------------------------------------------------------------------------- /mmseg/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoding import Encoding 2 | from .wrappers import Upsample, resize 3 | 4 | __all__ = ['Upsample', 'resize', 'Encoding'] 5 | -------------------------------------------------------------------------------- /mmseg/ops/encoding.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class Encoding(nn.Module): 9 | """Encoding Layer: a learnable residual encoder. 10 | 11 | Input is of shape (batch_size, channels, height, width). 12 | Output is of shape (batch_size, num_codes, channels). 13 | 14 | Args: 15 | channels: dimension of the features or feature channels 16 | num_codes: number of code words 17 | """ 18 | 19 | def __init__(self, channels, num_codes): 20 | super(Encoding, self).__init__() 21 | # init codewords and smoothing factor 22 | self.channels, self.num_codes = channels, num_codes 23 | std = 1. 
/ ((num_codes * channels)**0.5) 24 | # [num_codes, channels] 25 | self.codewords = nn.Parameter( 26 | torch.empty(num_codes, channels, 27 | dtype=torch.float).uniform_(-std, std), 28 | requires_grad=True) 29 | # [num_codes] 30 | self.scale = nn.Parameter( 31 | torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), 32 | requires_grad=True) 33 | 34 | @staticmethod 35 | def scaled_l2(x, codewords, scale): 36 | num_codes, channels = codewords.size() 37 | batch_size = x.size(0) 38 | reshaped_scale = scale.view((1, 1, num_codes)) 39 | expanded_x = x.unsqueeze(2).expand( 40 | (batch_size, x.size(1), num_codes, channels)) 41 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 42 | 43 | scaled_l2_norm = reshaped_scale * ( 44 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 45 | return scaled_l2_norm 46 | 47 | @staticmethod 48 | def aggregate(assignment_weights, x, codewords): 49 | num_codes, channels = codewords.size() 50 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 51 | batch_size = x.size(0) 52 | 53 | expanded_x = x.unsqueeze(2).expand( 54 | (batch_size, x.size(1), num_codes, channels)) 55 | encoded_feat = (assignment_weights.unsqueeze(3) * 56 | (expanded_x - reshaped_codewords)).sum(dim=1) 57 | return encoded_feat 58 | 59 | def forward(self, x): 60 | assert x.dim() == 4 and x.size(1) == self.channels 61 | # [batch_size, channels, height, width] 62 | batch_size = x.size(0) 63 | # [batch_size, height x width, channels] 64 | x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() 65 | # assignment_weights: [batch_size, height x width, num_codes] 66 | assignment_weights = F.softmax( 67 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 68 | # aggregate 69 | encoded_feat = self.aggregate(assignment_weights, x, self.codewords) 70 | return encoded_feat 71 | 72 | def __repr__(self): 73 | repr_str = self.__class__.__name__ 74 | repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ 75 | f'x{self.channels})' 76 | return repr_str 77 | -------------------------------------------------------------------------------- /mmseg/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import warnings 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def resize(input, 10 | size=None, 11 | scale_factor=None, 12 | mode='nearest', 13 | align_corners=None, 14 | warning=True): 15 | if warning: 16 | if size is not None and align_corners: 17 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 18 | output_h, output_w = tuple(int(x) for x in size) 19 | if output_h > input_h or output_w > input_w: 20 | if ((output_h > 1 and output_w > 1 and input_h > 1 21 | and input_w > 1) and (output_h - 1) % (input_h - 1) 22 | and (output_w - 1) % (input_w - 1)): 23 | warnings.warn( 24 | f'When align_corners={align_corners}, ' 25 | 'the output would be more aligned if ' 26 | f'input size {(input_h, input_w)} is `x+1` and ' 27 | f'out size {(output_h, output_w)} is `nx+1`') 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | 30 | 31 | class Upsample(nn.Module): 32 | 33 | def __init__(self, 34 | size=None, 35 | scale_factor=None, 36 | mode='nearest', 37 | align_corners=None): 38 | super(Upsample, self).__init__() 39 | self.size = size 40 | if isinstance(scale_factor, tuple): 41 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 42 | else: 43 | self.scale_factor = 
float(scale_factor) if scale_factor else None 44 | self.mode = mode 45 | self.align_corners = align_corners 46 | 47 | def forward(self, x): 48 | if not self.size: 49 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 50 | else: 51 | size = self.size 52 | return resize(x, size, None, self.mode, self.align_corners) 53 | -------------------------------------------------------------------------------- /mmseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .collect_env import collect_env 2 | from .logger import get_root_logger 3 | 4 | __all__ = ['get_root_logger', 'collect_env'] 5 | -------------------------------------------------------------------------------- /mmseg/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add code archive generation 3 | 4 | import os 5 | import tarfile 6 | 7 | from mmcv.utils import collect_env as collect_base_env 8 | from mmcv.utils import get_git_hash 9 | 10 | import mmseg 11 | 12 | 13 | def collect_env(): 14 | """Collect the information of the running environments.""" 15 | env_info = collect_base_env() 16 | env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' 17 | 18 | return env_info 19 | 20 | 21 | def is_source_file(x): 22 | if x.isdir() or x.name.endswith(('.py', '.sh', '.yml', '.json', '.txt')) \ 23 | and '.mim' not in x.name and 'jobs/' not in x.name: 24 | # print(x.name) 25 | return x 26 | else: 27 | return None 28 | 29 | 30 | def gen_code_archive(out_dir, file='code.tar.gz'): 31 | archive = os.path.join(out_dir, file) 32 | os.makedirs(os.path.dirname(archive), exist_ok=True) 33 | with tarfile.open(archive, mode='w:gz') as tar: 34 | tar.add('.', filter=is_source_file) 35 | return archive 36 | 37 | 38 | if __name__ == '__main__': 39 | for name, val in collect_env().items(): 40 | print('{}: {}'.format(name, val)) 41 | -------------------------------------------------------------------------------- /mmseg/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import logging 4 | 5 | from mmcv.utils import get_logger 6 | 7 | 8 | def get_root_logger(log_file=None, log_level=logging.INFO): 9 | """Get the root logger. 10 | 11 | The logger will be initialized if it has not been initialized. By default a 12 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 13 | also be added. The name of the root logger is the top-level package name, 14 | e.g., "mmseg". 15 | 16 | Args: 17 | log_file (str | None): The log filename. If specified, a FileHandler 18 | will be added to the root logger. 19 | log_level (int): The root logger level. Note that only the process of 20 | rank 0 is affected, while other processes will set the level to 21 | "Error" and be silent most of the time. 22 | 23 | Returns: 24 | logging.Logger: The root logger. 
25 | """ 26 | 27 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 28 | 29 | return logger 30 | -------------------------------------------------------------------------------- /mmseg/utils/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | @contextlib.contextmanager 9 | def np_local_seed(seed): 10 | state = np.random.get_state() 11 | np.random.seed(seed) 12 | try: 13 | yield 14 | finally: 15 | np.random.set_state(state) 16 | 17 | 18 | def downscale_label_ratio(gt, 19 | scale_factor, 20 | min_ratio, 21 | n_classes, 22 | ignore_index=255): 23 | assert scale_factor > 1 24 | bs, orig_c, orig_h, orig_w = gt.shape 25 | assert orig_c == 1 26 | trg_h, trg_w = orig_h // scale_factor, orig_w // scale_factor 27 | ignore_substitute = n_classes 28 | 29 | out = gt.clone() # otw. next line would modify original gt 30 | out[out == ignore_index] = ignore_substitute 31 | out = F.one_hot( 32 | out.squeeze(1), num_classes=n_classes + 1).permute(0, 3, 1, 2) 33 | assert list(out.shape) == [bs, n_classes + 1, orig_h, orig_w], out.shape 34 | out = F.avg_pool2d(out.float(), kernel_size=scale_factor) 35 | gt_ratio, out = torch.max(out, dim=1, keepdim=True) 36 | out[out == ignore_substitute] = ignore_index 37 | out[gt_ratio < min_ratio] = ignore_index 38 | assert list(out.shape) == [bs, 1, trg_h, trg_w], out.shape 39 | return out 40 | -------------------------------------------------------------------------------- /mmseg/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | __version__ = '0.16.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cityscapesscripts==2.2.0 2 | cycler==0.10.0 3 | gdown==4.2.0 4 | humanfriendly==9.2 5 | kiwisolver==1.2.0 6 | kornia==0.5.8 7 | matplotlib==3.4.2 8 | numpy==1.19.2 9 | opencv-python==4.4.0.46 10 | pandas==1.1.3 11 | Pillow==8.3.1 12 | prettytable==2.1.0 13 | pyparsing==2.4.7 14 | pytz==2020.1 15 | PyYAML==5.4.1 16 | scipy==1.6.3 17 | seaborn==0.11.1 18 | timm==0.3.2 19 | torch==1.7.1+cu110 20 | torchvision==0.8.2+cu110 21 | tqdm==4.48.2 22 | typing-extensions==3.7.4.3 23 | wcwidth==0.2.5 24 | yapf==0.31.0 25 | -------------------------------------------------------------------------------- /run_experiments.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | import argparse 7 | import json 8 | import os 9 | import subprocess 10 | import uuid 11 | from datetime import datetime 12 | 13 | import torch 14 | from experiments import generate_experiment_cfgs 15 | from mmcv import Config, get_git_hash 16 | from tools import train 17 | 18 | 19 | def run_command(command): 20 | p = subprocess.Popen( 21 | command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True) 22 | for line in iter(p.stdout.readline, b''): 23 | print(line.decode('utf-8'), end='') 24 | 25 | 26 | def rsync(src, dst): 27 | rsync_cmd = f'rsync -a {src} {dst}' 28 | print(rsync_cmd) 29 | run_command(rsync_cmd) 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | group = parser.add_mutually_exclusive_group(required=True) 35 | group.add_argument( 36 | '--exp', 37 | type=int, 38 | default=None, 39 | help='Experiment id as defined in experiments.py', 40 | ) 41 | group.add_argument( 42 | '--config', 43 | default=None, 44 | help='Path to config file', 45 | ) 46 | parser.add_argument( 47 | '--machine', type=str, choices=['local'], default='local') 48 | parser.add_argument('--debug', action='store_true') 49 | args = parser.parse_args() 50 | assert (args.config is None) != (args.exp is None), \ 51 | 'Either config or exp has to be defined.' 52 | 53 | GEN_CONFIG_DIR = 'configs/generated/' 54 | JOB_DIR = 'jobs' 55 | cfgs, config_files = [], [] 56 | 57 | # Training with Predefined Config 58 | if args.config is not None: 59 | cfg = Config.fromfile(args.config) 60 | # Specify Name and Work Directory 61 | exp_name = f'{args.machine}-{cfg["exp"]}' 62 | unique_name = f'{datetime.now().strftime("%y%m%d_%H%M")}_' \ 63 | f'{cfg["name"]}_{str(uuid.uuid4())[:5]}' 64 | child_cfg = { 65 | '_base_': args.config.replace('configs', '../..'), 66 | 'name': unique_name, 67 | 'work_dir': os.path.join('work_dirs', exp_name, unique_name), 68 | 'git_rev': get_git_hash() 69 | } 70 | cfg_out_file = f"{GEN_CONFIG_DIR}/{exp_name}/{child_cfg['name']}.json" 71 | os.makedirs(os.path.dirname(cfg_out_file), exist_ok=True) 72 | assert not os.path.isfile(cfg_out_file) 73 | with open(cfg_out_file, 'w') as of: 74 | json.dump(child_cfg, of, indent=4) 75 | config_files.append(cfg_out_file) 76 | cfgs.append(cfg) 77 | 78 | # Training with Generated Configs from experiments.py 79 | if args.exp is not None: 80 | exp_name = f'{args.machine}-exp{args.exp}' 81 | cfgs = generate_experiment_cfgs(args.exp) 82 | # Generate Configs 83 | for i, cfg in enumerate(cfgs): 84 | if args.debug: 85 | cfg.setdefault('log_config', {})['interval'] = 10 86 | cfg['evaluation'] = dict(interval=200, metric='mIoU') 87 | if 'dacs' in cfg['name']: 88 | cfg.setdefault('uda', {})['debug_img_interval'] = 10 89 | # cfg.setdefault('uda', {})['print_grad_magnitude'] = True 90 | # Generate Config File 91 | cfg['name'] = f'{datetime.now().strftime("%y%m%d_%H%M")}_' \ 92 | f'{cfg["name"]}_{str(uuid.uuid4())[:5]}' 93 | cfg['work_dir'] = os.path.join('work_dirs', exp_name, cfg['name']) 94 | cfg['git_rev'] = get_git_hash() 95 | cfg['_base_'] = ['../../' + e for e in cfg['_base_']] 96 | cfg_out_file = f"{GEN_CONFIG_DIR}/{exp_name}/{cfg['name']}.json" 97 | os.makedirs(os.path.dirname(cfg_out_file), exist_ok=True) 98 | assert not os.path.isfile(cfg_out_file) 99 | with open(cfg_out_file, 'w') as of: 100 | json.dump(cfg, of, indent=4) 101 | config_files.append(cfg_out_file) 102 | 103 | if args.machine == 'local': 
104 | for i, cfg in enumerate(cfgs): 105 | print('Run job {}'.format(cfg['name'])) 106 | train.main([config_files[i]]) 107 | torch.cuda.empty_cache() 108 | else: 109 | raise NotImplementedError(args.machine) 110 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # --------------------------------------------------------------- 3 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 5 | # --------------------------------------------------------------- 6 | 7 | 8 | TEST_ROOT=$1 9 | CONFIG_FILE="${TEST_ROOT}/*${TEST_ROOT: -1}.json" 10 | CHECKPOINT_FILE="${TEST_ROOT}/latest.pth" 11 | SHOW_DIR="${TEST_ROOT}/preds/" 12 | echo 'Config File:' $CONFIG_FILE 13 | echo 'Checkpoint File:' $CHECKPOINT_FILE 14 | echo 'Predictions Output Directory:' $SHOW_DIR 15 | python -m tools.test ${CONFIG_FILE} ${CHECKPOINT_FILE} --eval mIoU --show-dir ${SHOW_DIR} --opacity 1 16 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangkaihong/CDAC/160e3328cae8fb9a61b71529ea562251e120ae34/tools/__init__.py -------------------------------------------------------------------------------- /tools/analyze_logs.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | """Modified from https://github.com/open- 3 | mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" 4 | import argparse 5 | import json 6 | from collections import defaultdict 7 | 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | 11 | 12 | def plot_curve(log_dicts, args): 13 | if args.backend is not None: 14 | plt.switch_backend(args.backend) 15 | sns.set_style(args.style) 16 | # if legend is None, use {filename}_{key} as legend 17 | legend = args.legend 18 | if legend is None: 19 | legend = [] 20 | for json_log in args.json_logs: 21 | for metric in args.keys: 22 | legend.append(f'{json_log}_{metric}') 23 | assert len(legend) == (len(args.json_logs) * len(args.keys)) 24 | metrics = args.keys 25 | 26 | num_metrics = len(metrics) 27 | for i, log_dict in enumerate(log_dicts): 28 | epochs = list(log_dict.keys()) 29 | for j, metric in enumerate(metrics): 30 | print(f'plot curve of {args.json_logs[i]}, metric is {metric}') 31 | plot_epochs = [] 32 | plot_iters = [] 33 | plot_values = [] 34 | # In some log files, the iter number is not correct; `pre_iter` is 35 | # used to prevent generating wrong lines. 
36 | pre_iter = -1 37 | for epoch in epochs: 38 | epoch_logs = log_dict[epoch] 39 | if metric not in epoch_logs.keys(): 40 | continue 41 | if metric in ['mIoU', 'mAcc', 'aAcc']: 42 | plot_epochs.append(epoch) 43 | plot_values.append(epoch_logs[metric][0]) 44 | else: 45 | for idx in range(len(epoch_logs[metric])): 46 | if pre_iter > epoch_logs['iter'][idx]: 47 | continue 48 | pre_iter = epoch_logs['iter'][idx] 49 | plot_iters.append(epoch_logs['iter'][idx]) 50 | plot_values.append(epoch_logs[metric][idx]) 51 | ax = plt.gca() 52 | label = legend[i * num_metrics + j] 53 | if metric in ['mIoU', 'mAcc', 'aAcc']: 54 | ax.set_xticks(plot_epochs) 55 | plt.xlabel('epoch') 56 | plt.plot(plot_epochs, plot_values, label=label, marker='o') 57 | else: 58 | plt.xlabel('iter') 59 | plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) 60 | plt.legend() 61 | if args.title is not None: 62 | plt.title(args.title) 63 | if args.out is None: 64 | plt.show() 65 | else: 66 | print(f'save curve to: {args.out}') 67 | plt.savefig(args.out) 68 | plt.cla() 69 | 70 | 71 | def parse_args(): 72 | parser = argparse.ArgumentParser(description='Analyze Json Log') 73 | parser.add_argument( 74 | 'json_logs', 75 | type=str, 76 | nargs='+', 77 | help='path of train log in json format') 78 | parser.add_argument( 79 | '--keys', 80 | type=str, 81 | nargs='+', 82 | default=['mIoU'], 83 | help='the metric that you want to plot') 84 | parser.add_argument('--title', type=str, help='title of figure') 85 | parser.add_argument( 86 | '--legend', 87 | type=str, 88 | nargs='+', 89 | default=None, 90 | help='legend of each plot') 91 | parser.add_argument( 92 | '--backend', type=str, default=None, help='backend of plt') 93 | parser.add_argument( 94 | '--style', type=str, default='dark', help='style of plt') 95 | parser.add_argument('--out', type=str, default=None) 96 | args = parser.parse_args() 97 | return args 98 | 99 | 100 | def load_json_logs(json_logs): 101 | # load and convert json_logs to log_dict, key is epoch, value is a sub dict 102 | # keys of sub dict is different metrics 103 | # value of sub dict is a list of corresponding values of all iterations 104 | log_dicts = [dict() for _ in json_logs] 105 | for json_log, log_dict in zip(json_logs, log_dicts): 106 | with open(json_log, 'r') as log_file: 107 | for line in log_file: 108 | log = json.loads(line.strip()) 109 | # skip lines without `epoch` field 110 | if 'epoch' not in log: 111 | continue 112 | epoch = log.pop('epoch') 113 | if epoch not in log_dict: 114 | log_dict[epoch] = defaultdict(list) 115 | for k, v in log.items(): 116 | log_dict[epoch][k].append(v) 117 | return log_dicts 118 | 119 | 120 | def main(): 121 | args = parse_args() 122 | json_logs = args.json_logs 123 | for json_log in json_logs: 124 | assert json_log.endswith('.json') 125 | log_dicts = load_json_logs(json_logs) 126 | plot_curve(log_dicts, args) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /tools/convert_datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 7 | # Modifications: Add class stats computation 8 | 9 | import argparse 10 | import json 11 | import os.path as osp 12 | 13 | import mmcv 14 | import numpy as np 15 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 16 | from PIL import Image 17 | 18 | 19 | def convert_json_to_label(json_file): 20 | label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') 21 | json2labelImg(json_file, label_file, 'trainIds') 22 | 23 | if 'train/' in json_file: 24 | pil_label = Image.open(label_file) 25 | label = np.asarray(pil_label) 26 | sample_class_stats = {} 27 | for c in range(19): 28 | n = int(np.sum(label == c)) 29 | if n > 0: 30 | sample_class_stats[int(c)] = n 31 | sample_class_stats['file'] = label_file 32 | return sample_class_stats 33 | else: 34 | return None 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser( 39 | description='Convert Cityscapes annotations to TrainIds') 40 | parser.add_argument('cityscapes_path', help='cityscapes data path') 41 | parser.add_argument('--gt-dir', default='gtFine', type=str) 42 | parser.add_argument('-o', '--out-dir', help='output path') 43 | parser.add_argument( 44 | '--nproc', default=1, type=int, help='number of process') 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def save_class_stats(out_dir, sample_class_stats): 50 | sample_class_stats = [e for e in sample_class_stats if e is not None] 51 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'w') as of: 52 | json.dump(sample_class_stats, of, indent=2) 53 | 54 | sample_class_stats_dict = {} 55 | for stats in sample_class_stats: 56 | f = stats.pop('file') 57 | sample_class_stats_dict[f] = stats 58 | with open(osp.join(out_dir, 'sample_class_stats_dict.json'), 'w') as of: 59 | json.dump(sample_class_stats_dict, of, indent=2) 60 | 61 | samples_with_class = {} 62 | for file, stats in sample_class_stats_dict.items(): 63 | for c, n in stats.items(): 64 | if c not in samples_with_class: 65 | samples_with_class[c] = [(file, n)] 66 | else: 67 | samples_with_class[c].append((file, n)) 68 | with open(osp.join(out_dir, 'samples_with_class.json'), 'w') as of: 69 | json.dump(samples_with_class, of, indent=2) 70 | 71 | 72 | def main(): 73 | args = parse_args() 74 | cityscapes_path = args.cityscapes_path 75 | out_dir = args.out_dir if args.out_dir else cityscapes_path 76 | mmcv.mkdir_or_exist(out_dir) 77 | 78 | gt_dir = osp.join(cityscapes_path, args.gt_dir) 79 | 80 | poly_files = [] 81 | for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True): 82 | poly_file = osp.join(gt_dir, poly) 83 | poly_files.append(poly_file) 84 | 85 | only_postprocessing = False 86 | if not only_postprocessing: 87 | if args.nproc > 1: 88 | sample_class_stats = mmcv.track_parallel_progress( 89 | convert_json_to_label, poly_files, args.nproc) 90 | else: 91 | sample_class_stats = mmcv.track_progress(convert_json_to_label, 92 | poly_files) 93 | else: 94 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'r') as of: 95 | sample_class_stats = json.load(of) 96 | 97 | save_class_stats(out_dir, sample_class_stats) 98 | 99 | split_names = ['train', 'val', 'test'] 100 | 101 | for split in split_names: 102 | filenames = [] 103 | for poly in mmcv.scandir( 104 | osp.join(gt_dir, split), '_polygons.json', recursive=True): 105 | 
filenames.append(poly.replace('_gtFine_polygons.json', '')) 106 | with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: 107 | f.writelines(f + '\n' for f in filenames) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /tools/convert_datasets/gta.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | import argparse 7 | import json 8 | import os.path as osp 9 | 10 | import mmcv 11 | import numpy as np 12 | from PIL import Image 13 | 14 | 15 | def convert_to_train_id(file): 16 | # re-assign labels to match the format of Cityscapes 17 | pil_label = Image.open(file) 18 | label = np.asarray(pil_label) 19 | id_to_trainid = { 20 | 7: 0, 21 | 8: 1, 22 | 11: 2, 23 | 12: 3, 24 | 13: 4, 25 | 17: 5, 26 | 19: 6, 27 | 20: 7, 28 | 21: 8, 29 | 22: 9, 30 | 23: 10, 31 | 24: 11, 32 | 25: 12, 33 | 26: 13, 34 | 27: 14, 35 | 28: 15, 36 | 31: 16, 37 | 32: 17, 38 | 33: 18 39 | } 40 | label_copy = 255 * np.ones(label.shape, dtype=np.uint8) 41 | sample_class_stats = {} 42 | for k, v in id_to_trainid.items(): 43 | k_mask = label == k 44 | label_copy[k_mask] = v 45 | n = int(np.sum(k_mask)) 46 | if n > 0: 47 | sample_class_stats[v] = n 48 | new_file = file.replace('.png', '_labelTrainIds.png') 49 | assert file != new_file 50 | sample_class_stats['file'] = new_file 51 | Image.fromarray(label_copy, mode='L').save(new_file) 52 | return sample_class_stats 53 | 54 | 55 | def parse_args(): 56 | parser = argparse.ArgumentParser( 57 | description='Convert GTA annotations to TrainIds') 58 | parser.add_argument('gta_path', help='gta data path') 59 | parser.add_argument('--gt-dir', default='labels', type=str) 60 | parser.add_argument('-o', '--out-dir', help='output path') 61 | parser.add_argument( 62 | '--nproc', default=4, type=int, help='number of process') 63 | args = parser.parse_args() 64 | return args 65 | 66 | 67 | def save_class_stats(out_dir, sample_class_stats): 68 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'w') as of: 69 | json.dump(sample_class_stats, of, indent=2) 70 | 71 | sample_class_stats_dict = {} 72 | for stats in sample_class_stats: 73 | f = stats.pop('file') 74 | sample_class_stats_dict[f] = stats 75 | with open(osp.join(out_dir, 'sample_class_stats_dict.json'), 'w') as of: 76 | json.dump(sample_class_stats_dict, of, indent=2) 77 | 78 | samples_with_class = {} 79 | for file, stats in sample_class_stats_dict.items(): 80 | for c, n in stats.items(): 81 | if c not in samples_with_class: 82 | samples_with_class[c] = [(file, n)] 83 | else: 84 | samples_with_class[c].append((file, n)) 85 | with open(osp.join(out_dir, 'samples_with_class.json'), 'w') as of: 86 | json.dump(samples_with_class, of, indent=2) 87 | 88 | 89 | def main(): 90 | args = parse_args() 91 | gta_path = args.gta_path 92 | out_dir = args.out_dir if args.out_dir else gta_path 93 | mmcv.mkdir_or_exist(out_dir) 94 | 95 | gt_dir = osp.join(gta_path, args.gt_dir) 96 | 97 | poly_files = [] 98 | for poly in mmcv.scandir( 99 | gt_dir, suffix=tuple(f'{i}.png' for i in range(10)), 100 | recursive=True): 101 | poly_file = osp.join(gt_dir, poly) 102 | poly_files.append(poly_file) 103 | poly_files = sorted(poly_files) 104 | 105 | 
only_postprocessing = False 106 | if not only_postprocessing: 107 | if args.nproc > 1: 108 | sample_class_stats = mmcv.track_parallel_progress( 109 | convert_to_train_id, poly_files, args.nproc) 110 | else: 111 | sample_class_stats = mmcv.track_progress(convert_to_train_id, 112 | poly_files) 113 | else: 114 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'r') as of: 115 | sample_class_stats = json.load(of) 116 | 117 | save_class_stats(out_dir, sample_class_stats) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /tools/convert_datasets/synthia.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | import argparse 7 | import json 8 | import os.path as osp 9 | 10 | import cv2 11 | import mmcv 12 | import numpy as np 13 | from PIL import Image 14 | 15 | 16 | def convert_to_train_id(file): 17 | # re-assign labels to match the format of Cityscapes 18 | # PIL does not work with the image format, but cv2 does 19 | label = cv2.imread(file, cv2.IMREAD_UNCHANGED)[:, :, -1] 20 | # mapping based on README.txt from SYNTHIA_RAND_CITYSCAPES 21 | id_to_trainid = { 22 | 3: 0, 23 | 4: 1, 24 | 2: 2, 25 | 21: 3, 26 | 5: 4, 27 | 7: 5, 28 | 15: 6, 29 | 9: 7, 30 | 6: 8, 31 | 16: 9, # not present in synthia 32 | 1: 10, 33 | 10: 11, 34 | 17: 12, 35 | 8: 13, 36 | 18: 14, # not present in synthia 37 | 19: 15, 38 | 20: 16, # not present in synthia 39 | 12: 17, 40 | 11: 18 41 | } 42 | label_copy = 255 * np.ones(label.shape, dtype=np.uint8) 43 | sample_class_stats = {} 44 | for k, v in id_to_trainid.items(): 45 | k_mask = label == k 46 | label_copy[k_mask] = v 47 | n = int(np.sum(k_mask)) 48 | if n > 0: 49 | sample_class_stats[v] = n 50 | new_file = file.replace('.png', '_labelTrainIds.png') 51 | assert file != new_file 52 | sample_class_stats['file'] = new_file 53 | Image.fromarray(label_copy, mode='L').save(new_file) 54 | return sample_class_stats 55 | 56 | 57 | def parse_args(): 58 | parser = argparse.ArgumentParser( 59 | description='Convert SYNTHIA annotations to TrainIds') 60 | parser.add_argument('synthia_path', help='synthia data path') 61 | parser.add_argument('--gt-dir', default='GT/LABELS', type=str) 62 | parser.add_argument('-o', '--out-dir', help='output path') 63 | parser.add_argument( 64 | '--nproc', default=4, type=int, help='number of process') 65 | args = parser.parse_args() 66 | return args 67 | 68 | 69 | def save_class_stats(out_dir, sample_class_stats): 70 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'w') as of: 71 | json.dump(sample_class_stats, of, indent=2) 72 | 73 | sample_class_stats_dict = {} 74 | for stats in sample_class_stats: 75 | f = stats.pop('file') 76 | sample_class_stats_dict[f] = stats 77 | with open(osp.join(out_dir, 'sample_class_stats_dict.json'), 'w') as of: 78 | json.dump(sample_class_stats_dict, of, indent=2) 79 | 80 | samples_with_class = {} 81 | for file, stats in sample_class_stats_dict.items(): 82 | for c, n in stats.items(): 83 | if c not in samples_with_class: 84 | samples_with_class[c] = [(file, n)] 85 | else: 86 | samples_with_class[c].append((file, n)) 87 | with open(osp.join(out_dir, 'samples_with_class.json'), 'w') as of: 88 | 
json.dump(samples_with_class, of, indent=2) 89 | 90 | 91 | def main(): 92 | args = parse_args() 93 | synthia_path = args.synthia_path 94 | out_dir = args.out_dir if args.out_dir else synthia_path 95 | mmcv.mkdir_or_exist(out_dir) 96 | 97 | gt_dir = osp.join(synthia_path, args.gt_dir) 98 | 99 | poly_files = [] 100 | for poly in mmcv.scandir( 101 | gt_dir, suffix=tuple(f'{i}.png' for i in range(10)), 102 | recursive=True): 103 | poly_file = osp.join(gt_dir, poly) 104 | poly_files.append(poly_file) 105 | poly_files = sorted(poly_files) 106 | 107 | only_postprocessing = False 108 | if not only_postprocessing: 109 | if args.nproc > 1: 110 | sample_class_stats = mmcv.track_parallel_progress( 111 | convert_to_train_id, poly_files, args.nproc) 112 | else: 113 | sample_class_stats = mmcv.track_progress(convert_to_train_id, 114 | poly_files) 115 | else: 116 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'r') as of: 117 | sample_class_stats = json.load(of) 118 | 119 | save_class_stats(out_dir, sample_class_stats) 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /tools/download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # --------------------------------------------------------------- 3 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 5 | # --------------------------------------------------------------- 6 | 7 | 8 | # Instructions for Manual Download: 9 | # 10 | # Please download the [MiT weights](https://drive.google.com/drive/folders/1b7bwrInTW4VLEm27YawHOAMSMikga2Ia?usp=sharing) 11 | # pretrained on ImageNet-1K provided by the official 12 | # [SegFormer repository](https://github.com/NVlabs/SegFormer) and put them in a 13 | # folder `pretrained/` within this project. For most of the experiments, only 14 | # mit_b5.pth is necessary. 15 | # 16 | # Please download the checkpoint of DAFormer on GTA->Cityscapes from 17 | # [here](https://drive.google.com/file/d/1pG3kDClZDGwp1vSTEXmTchkGHmnLQNdP/view?usp=sharing) 
18 | # and extract it to `work_dirs/`. 19 | 20 | # Automatic Downloads: 21 | set -e # exit when any command fails 22 | mkdir -p pretrained/ 23 | cd pretrained/ 24 | gdown --id 1d3wU8KNjPL4EqMCIEO_rO-O3-REpG82T # MiT-B3 weights 25 | gdown --id 1BUtU42moYrOFbsMCE-LTTkUE-mrWnfG2 # MiT-B4 weights 26 | gdown --id 1d7I50jVjtCddnhpf-lqj8-f13UyCzoW1 # MiT-B5 weights 27 | cd ../ 28 | 29 | mkdir -p work_dirs/ 30 | cd work_dirs/ 31 | gdown --id 1pG3kDClZDGwp1vSTEXmTchkGHmnLQNdP # DAFormer on GTA->Cityscapes 32 | tar -xzf 211108_1622_gta2cs_daformer_s0_7f24c.tar.gz 33 | rm 211108_1622_gta2cs_daformer_s0_7f24c.tar.gz 34 | cd ../ 35 | -------------------------------------------------------------------------------- /tools/get_param_count.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import argparse 4 | import json 5 | import logging 6 | from copy import deepcopy 7 | 8 | from experiments import generate_experiment_cfgs 9 | from mmcv import Config, get_logger 10 | from prettytable import PrettyTable 11 | 12 | from mmseg.models import build_segmentor 13 | 14 | 15 | def human_format(num): 16 | magnitude = 0 17 | while abs(num) >= 1000: 18 | magnitude += 1 19 | num /= 1000.0 20 | # add more suffixes if you need them 21 | return '%.2f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) 22 | 23 | 24 | def count_parameters(model): 25 | table = PrettyTable(['Modules', 'Parameters']) 26 | total_params = 0 27 | for name, parameter in model.named_parameters(): 28 | if not parameter.requires_grad: 29 | continue 30 | param = parameter.numel() 31 | table.add_row([name, human_format(param)]) 32 | total_params += param 33 | # print(table) 34 | print(f'Total Trainable Params: {human_format(total_params)}') 35 | return total_params 36 | 37 | 38 | # Run: python -m tools.get_param_count 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument( 42 | '--exp', 43 | nargs='?', 44 | type=int, 45 | default=100, 46 | help='Experiment id as defined in experiments.py', 47 | ) 48 | args = parser.parse_args() 49 | get_logger('mmseg', log_level=logging.ERROR) 50 | cfgs = generate_experiment_cfgs(args.exp) 51 | for cfg in cfgs: 52 | with open('configs/tmp_param.json', 'w') as f: 53 | json.dump(cfg, f) 54 | cfg = Config.fromfile('configs/tmp_param.json') 55 | 56 | model = build_segmentor(deepcopy(cfg['model'])) 57 | # model.init_weights() 58 | # count_parameters(model) 59 | print(f'Encoder {cfg["name_encoder"]}:') 60 | count_parameters(model.backbone) 61 | print(f'Decoder {cfg["name_decoder"]}:') 62 | count_parameters(model.decode_head) 63 | -------------------------------------------------------------------------------- /tools/print_config.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import argparse 4 | 5 | from mmcv import Config, DictAction 6 | 7 | from mmseg.apis import init_segmentor 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Print the whole config') 12 | parser.add_argument('config', help='config file path') 13 | parser.add_argument( 14 | '--graph', action='store_true', help='print the models graph') 15 | parser.add_argument( 16 | '--options', nargs='+', action=DictAction, help='arguments in dict') 17 | args = parser.parse_args() 18 | 19 | return args 20 | 21 | 22 | def main(): 23 | args = parse_args() 24 | 25 | 
cfg = Config.fromfile(args.config) 26 | if args.options is not None: 27 | cfg.merge_from_dict(args.options) 28 | print(f'Config:\n{cfg.pretty_text}') 29 | # dump config 30 | cfg.dump('example.py') 31 | # dump models graph 32 | if args.graph: 33 | model = init_segmentor(args.config, device='cpu') 34 | print(f'Model graph:\n{str(model)}') 35 | with open('example-graph.txt', 'w') as f: 36 | f.writelines(str(model)) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /tools/publish_model.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Remove auxiliary models 3 | 4 | import argparse 5 | 6 | import torch 7 | 8 | # import subprocess 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description='Process a checkpoint to be published') 14 | parser.add_argument('in_file', help='input checkpoint filename') 15 | parser.add_argument('out_file', help='output checkpoint filename') 16 | args = parser.parse_args() 17 | return args 18 | 19 | 20 | def process_checkpoint(in_file, out_file): 21 | checkpoint = torch.load(in_file, map_location='cpu') 22 | # remove optimizer for smaller file size 23 | if 'optimizer' in checkpoint: 24 | del checkpoint['optimizer'] 25 | # remove auxiliary models 26 | for k in list(checkpoint['state_dict'].keys()): 27 | if 'imnet_model' in k or 'ema_model' in k: 28 | del checkpoint['state_dict'][k] 29 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 30 | # add the code here. 31 | if 'meta' in checkpoint: 32 | del checkpoint['meta'] 33 | # inspect checkpoint 34 | print('Checkpoint keys:', checkpoint.keys()) 35 | print('Checkpoint state_dict keys:', checkpoint['state_dict'].keys()) 36 | # save checkpoint 37 | torch.save(checkpoint, out_file) 38 | # sha = subprocess.check_output(['sha256sum', out_file]).decode() 39 | # final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8]) 40 | # subprocess.Popen(['mv', out_file, final_file]) 41 | 42 | 43 | def main(): 44 | args = parse_args() 45 | process_checkpoint(args.in_file, args.out_file) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | --------------------------------------------------------------------------------
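
For reference, a few usage sketches for the utilities above follow. First, a minimal sketch of the `Encoding` layer from `mmseg/ops/encoding.py`; the channel and codeword counts are hypothetical and chosen only for illustration, while the shapes follow the class docstring.

```python
import torch

from mmseg.ops import Encoding

# Hypothetical sizes for illustration only.
enc = Encoding(channels=64, num_codes=32)
x = torch.randn(2, 64, 16, 16)  # (batch_size, channels, height, width)
out = enc(x)
# The layer softly assigns every spatial position to each of the 32
# codewords and aggregates the residuals, so the spatial grid collapses:
assert out.shape == (2, 32, 64)  # (batch_size, num_codes, channels)
print(enc)  # Encoding(Nx64xHxW =>Nx32x64)
```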
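Next, a sketch of the `x+1`/`nx+1` rule that the `resize` wrapper in `mmseg/ops/wrappers.py` warns about when `align_corners=True`; the tensor sizes are made up for the example.

```python
import torch

from mmseg.ops import resize

x = torch.randn(1, 3, 33, 33)  # 33 = 32 + 1, i.e. an `x+1` input size

# 65 = 2 * 32 + 1 is an `nx+1` output size: (65 - 1) % (33 - 1) == 0,
# so the corner grid points line up and no warning is emitted.
y = resize(x, size=(65, 65), mode='bilinear', align_corners=True)

# 64 breaks the `nx+1` pattern: (64 - 1) % (33 - 1) != 0, so this
# call triggers the alignment warning.
z = resize(x, size=(64, 64), mode='bilinear', align_corners=True)
```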
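The `downscale_label_ratio` helper in `mmseg/utils/utils.py` reduces a label map by majority vote per pooling window and replaces windows whose dominant class falls below `min_ratio` with the ignore index. A minimal sketch, assuming the repository is on the Python path and using a made-up 8x8 label map:

```python
import torch

from mmseg.utils.utils import downscale_label_ratio

# A 1x1x8x8 label map: the left half is class 0, the right half class 1.
gt = torch.zeros(1, 1, 8, 8, dtype=torch.long)
gt[:, :, :, 4:] = 1

# Each 4x4 pooling window is pure, so its majority ratio is 1.0 >= 0.75
# and no pixel is replaced by the ignore index (255).
out = downscale_label_ratio(gt, scale_factor=4, min_ratio=0.75, n_classes=19)
print(out.squeeze())  # tensor([[0, 1], [0, 1]])
```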
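Finally, a sketch of `colorize_mask` from `mmseg/models/utils/visualization.py`. Note that the function pads the palette list in place, so passing a copy keeps the module-level `Cityscapes_palette` untouched; the 2x2 mask and the output filename are invented for the example.

```python
import numpy as np

from mmseg.models.utils.visualization import (Cityscapes_palette,
                                              colorize_mask)

# Cityscapes train IDs: road (0), building (2), sky (10), car (13).
mask = np.array([[0, 2], [10, 13]], dtype=np.uint8)
colored = colorize_mask(mask, Cityscapes_palette.copy())
colored.save('mask_preview.png')  # a palette-mode PNG in Cityscapes colors
```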