├── .gitignore ├── LICENSE ├── README.md ├── configs └── _base_ │ ├── datasets │ ├── uda_cityscapes_to_dark_zurich_640x640_fix_crop.py │ ├── uda_cityscapes_to_dark_zurich_640x640_no_crop.py │ ├── uda_gta_to_cityscapes_640x640_fix_crop.py │ └── uda_gta_to_cityscapes_640x640_no_crop.py │ ├── default_runtime.py │ ├── models │ ├── daformer_sepaspp_proj_mitb5.py │ └── deeplabv2_proj_r50-d8.py │ ├── schedules │ ├── adamw.py │ ├── poly10.py │ └── poly10warm.py │ └── uda │ ├── sepico.py │ └── sepico_dark.py ├── experiments.py ├── mmseg ├── __init__.py ├── apis │ ├── __init__.py │ ├── inference.py │ ├── test.py │ └── train.py ├── core │ ├── __init__.py │ ├── ddp_wrapper.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── class_names.py │ │ ├── eval_hooks.py │ │ └── metrics.py │ ├── seg │ │ ├── __init__.py │ │ ├── builder.py │ │ └── sampler │ │ │ ├── __init__.py │ │ │ ├── base_pixel_sampler.py │ │ │ └── ohem_pixel_sampler.py │ └── utils │ │ ├── __init__.py │ │ └── misc.py ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── cityscapes.py │ ├── custom.py │ ├── dark_zurich.py │ ├── dataset_wrappers.py │ ├── gta.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── compose.py │ │ ├── formating.py │ │ ├── loading.py │ │ ├── test_time_aug.py │ │ └── transforms.py │ └── uda_dataset.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── mix_transformer.py │ │ └── resnet.py │ ├── builder.py │ ├── decode_heads │ │ ├── __init__.py │ │ ├── aspp_head.py │ │ ├── da_head.py │ │ ├── daformer_head.py │ │ ├── decode_head.py │ │ ├── decode_head_decorator.py │ │ ├── dlv2_head.py │ │ ├── fcn_head.py │ │ ├── isa_head.py │ │ ├── proj_head.py │ │ ├── segformer_head.py │ │ └── sep_aspp_head.py │ ├── losses │ │ ├── __init__.py │ │ ├── accuracy.py │ │ ├── contrastive_loss.py │ │ ├── cross_entropy_loss.py │ │ └── utils.py │ ├── necks │ │ └── __init__.py │ ├── segmentors │ │ ├── __init__.py │ │ ├── base.py │ │ ├── encoder_decoder.py │ │ └── encoder_decoder_projector.py │ ├── uda │ │ ├── __init__.py │ │ ├── sepico.py │ │ ├── sepico_dark.py │ │ └── uda_decorator.py │ └── utils │ │ ├── __init__.py │ │ ├── ckpt_convert.py │ │ ├── dacs_transforms.py │ │ ├── make_divisible.py │ │ ├── ours_transforms.py │ │ ├── proto_estimator.py │ │ ├── res_layer.py │ │ ├── self_attention_block.py │ │ ├── shape_convert.py │ │ └── visualization.py ├── ops │ ├── __init__.py │ ├── encoding.py │ └── wrappers.py ├── utils │ ├── __init__.py │ ├── collect_env.py │ ├── logger.py │ └── utils.py └── version.py ├── requirements.txt ├── resources ├── cs2dz_generalization_per_class_results │ ├── bdd100k_night_test.txt │ ├── dark_zurich_test.txt │ └── night_driving_test.txt ├── esi_highly_cited.png ├── license_dacs ├── license_daformer ├── license_dannet ├── license_mmseg ├── license_segformer └── uda_results.png ├── run_experiments.py ├── setup.cfg └── tools ├── analyze_logs.py ├── convert_datasets ├── cityscapes.py └── gta.py ├── download_checkpoints.sh ├── get_param_count.py ├── print_config.py ├── publish_model.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually 
these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | data 107 | .vscode 108 | .idea 109 | 110 | # custom 111 | *.pkl 112 | *.pkl.json 113 | *.log.json 114 | work_dirs/ 115 | mmseg/.mim 116 | 117 | # Pytorch 118 | *.pth 119 | 120 | euler_log.txt 121 | jobs/ 122 | configs/tmp_param.json 123 | configs/generated/ 124 | *.pdf 125 | *.pgf 126 | -------------------------------------------------------------------------------- /configs/_base_/datasets/uda_cityscapes_to_dark_zurich_640x640_fix_crop.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # dataset settings 7 | dataset_type = 'DarkZurichDataset' 8 | data_root = 'data/dark_zurich/' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (640, 640) 12 | cityscapes_train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations'), 15 | dict(type='Resize', img_scale=(1280, 640)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | # dict(type='PhotoMetricDistortion'), # is applied later in sepico_dark.py 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | dark_zurich_train_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | # dict(type='LoadAnnotations'), 27 | dict(type='Resize', img_scale=(1280, 720)), 28 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=1.), 29 | dict(type='RandomFlip', prob=0.5), 30 | # dict(type='PhotoMetricDistortion'), # is applied later in sepico_dark.py 31 | dict(type='Normalize', **img_norm_cfg), 32 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 33 | dict(type='DefaultFormatBundle'), 34 | dict(type='Collect', keys=['img']), 35 | ] 36 | test_pipeline = [ 37 | dict(type='LoadImageFromFile'), 38 | dict( 39 | type='MultiScaleFlipAug', 40 | img_scale=(1280, 720), 41 | # MultiScaleFlipAug is disabled by not providing img_ratios and 42 | # setting flip=False 43 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 44 | flip=False, 45 | transforms=[ 46 | dict(type='Resize', keep_ratio=True), 47 | dict(type='RandomFlip'), 48 | dict(type='Normalize', **img_norm_cfg), 49 | dict(type='ImageToTensor', keys=['img']), 50 | dict(type='Collect', keys=['img']), 51 | ]) 52 | ] 53 | data = dict( 54 | samples_per_gpu=2, 55 | workers_per_gpu=4, 56 | train=dict( 57 | type='UDADataset', 58 | source=dict( 59 | type='CityscapesDataset', 60 | data_root='data/cityscapes/', 61 | img_dir='leftImg8bit/train', 62 | ann_dir='gtFine/train', 63 | pipeline=cityscapes_train_pipeline), 64 | target=dict( 65 | type='DarkZurichDataset', 66 | data_root='data/dark_zurich/', 67 | img_dir='rgb_anon/train/night', 68 | test_mode=True, 69 | pipeline=dark_zurich_train_pipeline)), 70 | val=dict( 71 | type='DarkZurichDataset', 72 | data_root='data/dark_zurich/', 73 | img_dir='rgb_anon/val/night', 74 | ann_dir='gt/val/night', 75 | pipeline=test_pipeline), 76 | test=dict( 77 | type='DarkZurichDataset', 78 | data_root='data/dark_zurich/', 79 | img_dir='rgb_anon/test/night', 80 | test_mode=True, 81 | pipeline=test_pipeline)) 82 | -------------------------------------------------------------------------------- /configs/_base_/datasets/uda_cityscapes_to_dark_zurich_640x640_no_crop.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # dataset settings 7 | dataset_type = 'DarkZurichDataset' 8 | data_root = 'data/dark_zurich/' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (640, 640) 12 | cityscapes_train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations'), 15 | dict(type='Resize', img_scale=(1280, 640)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | # dict(type='PhotoMetricDistortion'), # is applied later in sepico_dark.py 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | dark_zurich_train_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | # dict(type='LoadAnnotations'), 27 | dict(type='Resize', img_scale=(1280, 720)), 28 | # dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=1.), # is applied later in sepico_dark.py 29 | dict(type='RandomFlip', prob=0.5), 30 | # dict(type='PhotoMetricDistortion'), # is applied later in sepico_dark.py 31 | dict(type='Normalize', **img_norm_cfg), 32 | # dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), # is applied later in sepico_dark.py 33 | dict(type='DefaultFormatBundle'), 34 | dict(type='Collect', keys=['img']), 35 | ] 36 | test_pipeline = [ 37 | dict(type='LoadImageFromFile'), 38 | dict( 39 | type='MultiScaleFlipAug', 40 | img_scale=(1280, 720), 41 | # MultiScaleFlipAug is disabled by not providing img_ratios and 42 | # setting flip=False 43 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 44 | flip=False, 45 | transforms=[ 46 | dict(type='Resize', keep_ratio=True), 47 | dict(type='RandomFlip'), 48 | dict(type='Normalize', **img_norm_cfg), 49 | dict(type='ImageToTensor', keys=['img']), 50 | dict(type='Collect', keys=['img']), 51 | ]) 52 | ] 53 | data = dict( 54 | samples_per_gpu=2, 55 | workers_per_gpu=4, 56 | train=dict( 57 | type='UDADataset', 58 | source=dict( 59 | type='CityscapesDataset', 60 | data_root='data/cityscapes/', 61 | img_dir='leftImg8bit/train', 62 | ann_dir='gtFine/train', 63 | pipeline=cityscapes_train_pipeline), 64 | target=dict( 65 | type='DarkZurichDataset', 66 | data_root='data/dark_zurich/', 67 | img_dir='rgb_anon/train/night', 68 | test_mode=True, 69 | pipeline=dark_zurich_train_pipeline)), 70 | val=dict( 71 | type='DarkZurichDataset', 72 | data_root='data/dark_zurich/', 73 | img_dir='rgb_anon/val/night', 74 | ann_dir='gt/val/night', 75 | pipeline=test_pipeline), 76 | test=dict( 77 | type='DarkZurichDataset', 78 | data_root='data/dark_zurich/', 79 | img_dir='rgb_anon/test/night', 80 | test_mode=True, 81 | pipeline=test_pipeline)) 82 | -------------------------------------------------------------------------------- /configs/_base_/datasets/uda_gta_to_cityscapes_640x640_fix_crop.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # dataset settings 7 | dataset_type = 'CityscapesDataset' 8 | data_root = 'data/cityscapes/' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (640, 640) 12 | gta_train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations'), 15 | dict(type='Resize', img_scale=(1280, 720)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | # dict(type='PhotoMetricDistortion'), # is applied later in sepico.py 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | cityscapes_train_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | dict(type='LoadAnnotations'), 27 | dict(type='Resize', img_scale=(1280, 640)), 28 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=1.), 29 | dict(type='RandomFlip', prob=0.5), 30 | # dict(type='PhotoMetricDistortion'), # is applied later in sepico.py 31 | dict(type='Normalize', **img_norm_cfg), 32 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 33 | dict(type='DefaultFormatBundle'), 34 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 35 | ] 36 | test_pipeline = [ 37 | dict(type='LoadImageFromFile'), 38 | dict( 39 | type='MultiScaleFlipAug', 40 | img_scale=(1280, 640), 41 | # MultiScaleFlipAug is disabled by not providing img_ratios and 42 | # setting flip=False 43 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 44 | flip=False, 45 | transforms=[ 46 | dict(type='Resize', keep_ratio=True), 47 | dict(type='RandomFlip'), 48 | dict(type='Normalize', **img_norm_cfg), 49 | dict(type='ImageToTensor', keys=['img']), 50 | dict(type='Collect', keys=['img']), 51 | ]) 52 | ] 53 | data = dict( 54 | samples_per_gpu=2, 55 | workers_per_gpu=4, 56 | train=dict( 57 | type='UDADataset', 58 | source=dict( 59 | type='GTADataset', 60 | data_root='data/gta/', 61 | img_dir='images', 62 | ann_dir='labels', 63 | pipeline=gta_train_pipeline), 64 | target=dict( 65 | type='CityscapesDataset', 66 | data_root='data/cityscapes/', 67 | img_dir='leftImg8bit/train', 68 | ann_dir='gtFine/train', 69 | pipeline=cityscapes_train_pipeline)), 70 | val=dict( 71 | type='CityscapesDataset', 72 | data_root='data/cityscapes/', 73 | img_dir='leftImg8bit/val', 74 | ann_dir='gtFine/val', 75 | pipeline=test_pipeline), 76 | test=dict( 77 | type='CityscapesDataset', 78 | data_root='data/cityscapes/', 79 | img_dir='leftImg8bit/val', 80 | ann_dir='gtFine/val', 81 | pipeline=test_pipeline)) 82 | -------------------------------------------------------------------------------- /configs/_base_/datasets/uda_gta_to_cityscapes_640x640_no_crop.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # dataset settings 7 | dataset_type = 'CityscapesDataset' 8 | data_root = 'data/cityscapes/' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (640, 640) 12 | gta_train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations'), 15 | dict(type='Resize', img_scale=(1280, 720)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | # dict(type='PhotoMetricDistortion'), # is applied later in sepico.py 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | cityscapes_train_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | dict(type='LoadAnnotations'), 27 | dict(type='Resize', img_scale=(1280, 640)), 28 | # dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=1.), # is applied later in sepico.py 29 | dict(type='RandomFlip', prob=0.5), 30 | # dict(type='PhotoMetricDistortion'), # is applied later in sepico.py 31 | dict(type='Normalize', **img_norm_cfg), 32 | # dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), # is applied later in sepico.py 33 | dict(type='DefaultFormatBundle'), 34 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 35 | ] 36 | test_pipeline = [ 37 | dict(type='LoadImageFromFile'), 38 | dict( 39 | type='MultiScaleFlipAug', 40 | img_scale=(1280, 640), 41 | # MultiScaleFlipAug is disabled by not providing img_ratios and 42 | # setting flip=False 43 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 44 | flip=False, 45 | transforms=[ 46 | dict(type='Resize', keep_ratio=True), 47 | dict(type='RandomFlip'), 48 | dict(type='Normalize', **img_norm_cfg), 49 | dict(type='ImageToTensor', keys=['img']), 50 | dict(type='Collect', keys=['img']), 51 | ]) 52 | ] 53 | data = dict( 54 | samples_per_gpu=2, 55 | workers_per_gpu=4, 56 | train=dict( 57 | type='UDADataset', 58 | source=dict( 59 | type='GTADataset', 60 | data_root='data/gta/', 61 | img_dir='images', 62 | ann_dir='labels', 63 | pipeline=gta_train_pipeline), 64 | target=dict( 65 | type='CityscapesDataset', 66 | data_root='data/cityscapes/', 67 | img_dir='leftImg8bit/train', 68 | ann_dir='gtFine/train', 69 | pipeline=cityscapes_train_pipeline)), 70 | val=dict( 71 | type='CityscapesDataset', 72 | data_root='data/cityscapes/', 73 | img_dir='leftImg8bit/val', 74 | ann_dir='gtFine/val', 75 | pipeline=test_pipeline), 76 | test=dict( 77 | type='CityscapesDataset', 78 | data_root='data/cityscapes/', 79 | img_dir='leftImg8bit/val', 80 | ann_dir='gtFine/val', 81 | pipeline=test_pipeline)) 82 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | # yapf:disable 4 | log_config = dict( 5 | interval=50, 6 | hooks=[ 7 | dict(type='TextLoggerHook', by_epoch=False), 8 | dict(type='TensorboardLoggerHook', by_epoch=False) 9 | ]) 10 | # yapf:enable 11 | dist_params = dict(backend='nccl') 12 | log_level = 'INFO' 13 | load_from = None 14 | resume_from = None 15 | workflow = [('train', 1)] 16 | cudnn_benchmark = True 17 | 
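# Editor's note (illustrative sketch, not a file shipped with this repo): the `_base_`
# fragments under configs/_base_/ are composed into a full experiment config through
# mmcv's `_base_` inheritance. The file name and the exact fragment combination below
# are assumptions for illustration; the repository appears to generate its real
# experiment configs via run_experiments.py (configs/generated/ is git-ignored above).
#
# hypothetical file: configs/sepico_example.py
_base_ = [
    '_base_/default_runtime.py',
    '_base_/models/daformer_sepaspp_proj_mitb5.py',
    '_base_/datasets/uda_gta_to_cityscapes_640x640_fix_crop.py',
    '_base_/uda/sepico.py',
    '_base_/schedules/adamw.py',
    '_base_/schedules/poly10warm.py',
]
# A config composed this way would typically be launched with something like:
#   python tools/train.py configs/sepico_example.py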
-------------------------------------------------------------------------------- /configs/_base_/models/daformer_sepaspp_proj_mitb5.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | # Adapted from: https://github.com/lhoyer/DAFormer 6 | 7 | # model settings 8 | norm_cfg = dict(type='BN', requires_grad=True) 9 | find_unused_parameters = True 10 | model = dict( 11 | type='EncoderDecoderProjector', 12 | pretrained='pretrained/mit_b5.pth', 13 | backbone=dict(type='mit_b5', style='pytorch'), 14 | decode_head=dict( 15 | type='DAFormerHead', 16 | in_channels=[64, 128, 320, 512], 17 | in_index=[0, 1, 2, 3], 18 | channels=256, 19 | dropout_ratio=0.1, 20 | num_classes=19, 21 | norm_cfg=norm_cfg, 22 | align_corners=False, 23 | decoder_params=dict( 24 | embed_dims=256, 25 | embed_cfg=dict(type='mlp', act_cfg=None, norm_cfg=None), 26 | embed_neck_cfg=dict(type='mlp', act_cfg=None, norm_cfg=None), 27 | fusion_cfg=dict( 28 | type='aspp', 29 | sep=True, 30 | dilations=(1, 6, 12, 18), 31 | pool=False, 32 | act_cfg=dict(type='ReLU'), 33 | norm_cfg=norm_cfg)), 34 | loss_decode=dict( 35 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 36 | auxiliary_head=dict( 37 | type='ProjHead', 38 | in_channels=[64, 128, 320, 512], 39 | in_index=[0, 1, 2, 3], # int or list, depending on value of input_transform 40 | input_transform='resize_concat', # optional(None, 'resize_concat', 'multiple_select') 41 | channels=256, 42 | num_convs=2, 43 | dropout_ratio=0.1, 44 | num_classes=19, 45 | norm_cfg=norm_cfg, 46 | align_corners=False, 47 | loss_decode=dict( 48 | type='ContrastiveLoss', use_dist=False, use_bank=False, use_reg=False, 49 | use_avg_pool=True, scale_min_ratio=0.75, num_classes=19, 50 | contrast_temp=100., loss_weight=1.0, reg_relative_weight=0.01)), 51 | # model training and testing settings 52 | train_cfg=dict(), 53 | test_cfg=dict(mode='whole')) 54 | -------------------------------------------------------------------------------- /configs/_base_/models/deeplabv2_proj_r50-d8.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | # Adapted from: https://github.com/lhoyer/DAFormer 6 | 7 | # model settings 8 | norm_cfg = dict(type='BN', requires_grad=True) 9 | model = dict( 10 | type='EncoderDecoderProjector', 11 | pretrained='open-mmlab://resnet50_v1c', 12 | backbone=dict( 13 | type='ResNetV1c', 14 | depth=50, 15 | num_stages=4, 16 | out_indices=(0, 1, 2, 3), 17 | dilations=(1, 1, 2, 4), 18 | strides=(1, 2, 1, 1), 19 | norm_cfg=norm_cfg, 20 | norm_eval=False, 21 | style='pytorch', 22 | contract_dilation=True), 23 | decode_head=dict( 24 | type='DLV2Head', 25 | in_channels=2048, 26 | in_index=3, 27 | dilations=(6, 12, 18, 24), 28 | num_classes=19, 29 | align_corners=False, 30 | init_cfg=dict( 31 | type='Normal', std=0.01, override=dict(name='aspp_modules')), 32 | loss_decode=dict( 33 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 34 | auxiliary_head=dict( 35 | type='ProjHead', 36 | in_channels=2048, 37 | in_index=3, # int or list, depending on value of input_transform 38 | input_transform=None, # optional(None, 'resize_concat', 'multiple_select') 39 | channels=512, 40 | num_convs=2, 41 | dropout_ratio=0.1, 42 | num_classes=19, 43 | norm_cfg=norm_cfg, 44 | align_corners=False, 45 | loss_decode=dict( 46 | type='ContrastiveLoss', use_dist=False, use_bank=False, use_reg=False, 47 | use_avg_pool=True, scale_min_ratio=0.75, num_classes=19, 48 | contrast_temp=100., loss_weight=1.0, reg_relative_weight=0.01)), 49 | # model training and testing settings 50 | train_cfg=dict(), 51 | test_cfg=dict(mode='whole')) 52 | -------------------------------------------------------------------------------- /configs/_base_/schedules/adamw.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | # optimizer 4 | optimizer = dict( 5 | type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01) 6 | optimizer_config = dict() 7 | -------------------------------------------------------------------------------- /configs/_base_/schedules/poly10.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # learning policy 7 | lr_config = dict(policy='poly', power=1.0, min_lr=1e-4, by_epoch=False) 8 | -------------------------------------------------------------------------------- /configs/_base_/schedules/poly10warm.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # learning policy 7 | lr_config = dict( 8 | policy='poly', 9 | warmup='linear', 10 | warmup_iters=1500, 11 | warmup_ratio=1e-6, 12 | power=1.0, 13 | min_lr=0.0, 14 | by_epoch=False) 15 | -------------------------------------------------------------------------------- /configs/_base_/uda/sepico.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # SePiCo 7 | uda = dict( 8 | type='SePiCo', 9 | alpha=0.999, 10 | pseudo_threshold=0.968, 11 | pseudo_weight_ignore_top=0, 12 | pseudo_weight_ignore_bottom=0, 13 | enable_self_training=False, 14 | enable_strong_aug=True, 15 | start_distribution_iter=0, 16 | mix='class', 17 | blur=True, 18 | color_jitter_strength=0.2, 19 | color_jitter_probability=0.2, 20 | debug_img_interval=1000, 21 | ) 22 | use_ddp_wrapper = True 23 | -------------------------------------------------------------------------------- /configs/_base_/uda/sepico_dark.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # SePiCo Dark 7 | _base_ = ['sepico.py'] 8 | uda = dict( 9 | type='SePiCoDark', 10 | ) 11 | -------------------------------------------------------------------------------- /mmseg/__init__.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | 3 | from .version import __version__, version_info 4 | 5 | MMCV_MIN = '1.3.7' 6 | MMCV_MAX = '1.4.0' 7 | 8 | 9 | def digit_version(version_str): 10 | digit_version = [] 11 | for x in version_str.split('.'): 12 | if x.isdigit(): 13 | digit_version.append(int(x)) 14 | elif x.find('rc') != -1: 15 | patch_version = x.split('rc') 16 | digit_version.append(int(patch_version[0]) - 1) 17 | digit_version.append(int(patch_version[1])) 18 | return digit_version 19 | 20 | 21 | mmcv_min_version = digit_version(MMCV_MIN) 22 | mmcv_max_version = digit_version(MMCV_MAX) 23 | mmcv_version = digit_version(mmcv.__version__) 24 | 25 | 26 | assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 27 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 28 | f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' 
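# Editor's note (illustrative, not part of the original file): digit_version turns a
# version string into a plain list so ordinary list comparison orders releases correctly,
# mapping a release candidate below its final release, e.g.
#   digit_version('1.3.7')    -> [1, 3, 7]
#   digit_version('1.4.0rc1') -> [1, 4, -1, 1]   # compares as less than digit_version('1.4.0') == [1, 4, 0]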
29 | 30 | __all__ = ['__version__', 'version_info'] 31 | -------------------------------------------------------------------------------- /mmseg/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .inference import inference_segmentor, init_segmentor, show_result_pyplot 4 | from .test import multi_gpu_test, single_gpu_test 5 | from .train import get_root_logger, set_random_seed, train_segmentor 6 | 7 | __all__ = [ 8 | 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 9 | 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 10 | 'show_result_pyplot' 11 | ] 12 | -------------------------------------------------------------------------------- /mmseg/apis/inference.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - Override palette, classes, and state dict keys 4 | 5 | import matplotlib.pyplot as plt 6 | import mmcv 7 | import torch 8 | from mmcv.parallel import collate, scatter 9 | from mmcv.runner import load_checkpoint 10 | 11 | from mmseg.datasets.pipelines import Compose 12 | from mmseg.models import build_segmentor 13 | 14 | 15 | def init_segmentor(config, 16 | checkpoint=None, 17 | device='cuda:0', 18 | classes=None, 19 | palette=None, 20 | revise_checkpoint=[(r'^module\.', '')]): 21 | """Initialize a segmentor from config file. 22 | 23 | Args: 24 | config (str or :obj:`mmcv.Config`): Config file path or the config 25 | object. 26 | checkpoint (str, optional): Checkpoint path. If left as None, the model 27 | will not load any weights. 28 | device (str, optional) CPU/CUDA device option. Default 'cuda:0'. 29 | Use 'cpu' for loading model on CPU. 30 | Returns: 31 | nn.Module: The constructed segmentor. 32 | """ 33 | if isinstance(config, str): 34 | config = mmcv.Config.fromfile(config) 35 | elif not isinstance(config, mmcv.Config): 36 | raise TypeError('config must be a filename or Config object, ' 37 | 'but got {}'.format(type(config))) 38 | config.model.pretrained = None 39 | config.model.train_cfg = None 40 | model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) 41 | if checkpoint is not None: 42 | checkpoint = load_checkpoint( 43 | model, 44 | checkpoint, 45 | map_location='cpu', 46 | revise_keys=revise_checkpoint) 47 | model.CLASSES = checkpoint['meta']['CLASSES'] if classes is None \ 48 | else classes 49 | model.PALETTE = checkpoint['meta']['PALETTE'] if palette is None \ 50 | else palette 51 | model.cfg = config # save the config in the model for convenience 52 | model.to(device) 53 | model.eval() 54 | return model 55 | 56 | 57 | class LoadImage: 58 | """A simple pipeline to load image.""" 59 | 60 | def __call__(self, results): 61 | """Call function to load images into results. 62 | 63 | Args: 64 | results (dict): A result dict contains the file name 65 | of the image to be read. 66 | 67 | Returns: 68 | dict: ``results`` will be returned containing loaded image. 
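        Example (editor-added illustration; ``demo.png`` is a placeholder path to any
        readable image)::

            >>> results = LoadImage()(dict(img='demo.png'))
            >>> sorted(results.keys())
            ['filename', 'img', 'img_shape', 'ori_filename', 'ori_shape']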
69 | """ 70 | 71 | if isinstance(results['img'], str): 72 | results['filename'] = results['img'] 73 | results['ori_filename'] = results['img'] 74 | else: 75 | results['filename'] = None 76 | results['ori_filename'] = None 77 | img = mmcv.imread(results['img']) 78 | results['img'] = img 79 | results['img_shape'] = img.shape 80 | results['ori_shape'] = img.shape 81 | return results 82 | 83 | 84 | def inference_segmentor(model, img): 85 | """Inference image(s) with the segmentor. 86 | 87 | Args: 88 | model (nn.Module): The loaded segmentor. 89 | imgs (str/ndarray or list[str/ndarray]): Either image files or loaded 90 | images. 91 | 92 | Returns: 93 | (list[Tensor]): The segmentation result. 94 | """ 95 | cfg = model.cfg 96 | device = next(model.parameters()).device # model device 97 | # build the data pipeline 98 | test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] 99 | test_pipeline = Compose(test_pipeline) 100 | # prepare data 101 | data = dict(img=img) 102 | data = test_pipeline(data) 103 | data = collate([data], samples_per_gpu=1) 104 | if next(model.parameters()).is_cuda: 105 | # scatter to specified GPU 106 | data = scatter(data, [device])[0] 107 | else: 108 | data['img_metas'] = [i.data[0] for i in data['img_metas']] 109 | 110 | # forward the model 111 | with torch.no_grad(): 112 | result = model(return_loss=False, rescale=True, **data) 113 | return result 114 | 115 | 116 | def show_result_pyplot(model, 117 | img, 118 | result, 119 | palette=None, 120 | fig_size=(15, 10), 121 | opacity=0.5, 122 | title='', 123 | block=True): 124 | """Visualize the segmentation results on the image. 125 | 126 | Args: 127 | model (nn.Module): The loaded segmentor. 128 | img (str or np.ndarray): Image filename or loaded image. 129 | result (list): The segmentation result. 130 | palette (list[list[int]]] | None): The palette of segmentation 131 | map. If None is given, random palette will be generated. 132 | Default: None 133 | fig_size (tuple): Figure size of the pyplot figure. 134 | opacity(float): Opacity of painted segmentation map. 135 | Default 0.5. 136 | Must be in (0, 1] range. 137 | title (str): The title of pyplot figure. 138 | Default is ''. 139 | block (bool): Whether to block the pyplot figure. 140 | Default is True. 141 | """ 142 | if hasattr(model, 'module'): 143 | model = model.module 144 | img = model.show_result( 145 | img, result, palette=palette, show=False, opacity=opacity) 146 | plt.figure(figsize=fig_size) 147 | plt.imshow(mmcv.bgr2rgb(img)) 148 | plt.title(title) 149 | plt.tight_layout() 150 | plt.show(block=block) 151 | -------------------------------------------------------------------------------- /mmseg/apis/test.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import os.path as osp 4 | import tempfile 5 | 6 | import mmcv 7 | import numpy as np 8 | import torch 9 | from mmcv.engine import collect_results_cpu, collect_results_gpu 10 | from mmcv.image import tensor2imgs 11 | from mmcv.runner import get_dist_info 12 | 13 | 14 | def np2tmp(array, temp_file_name=None, tmpdir=None): 15 | """Save ndarray to local numpy file. 16 | 17 | Args: 18 | array (ndarray): Ndarray to save. 19 | temp_file_name (str): Numpy file name. If 'temp_file_name=None', this 20 | function will generate a file name with tempfile.NamedTemporaryFile 21 | to save ndarray. Default: None. 22 | tmpdir (str): Temporary directory to save Ndarray files. Default: None. 
23 | 24 | Returns: 25 | str: The numpy file name. 26 | """ 27 | 28 | if temp_file_name is None: 29 | temp_file_name = tempfile.NamedTemporaryFile( 30 | suffix='.npy', delete=False, dir=tmpdir).name 31 | np.save(temp_file_name, array) 32 | return temp_file_name 33 | 34 | 35 | def single_gpu_test(model, 36 | data_loader, 37 | show=False, 38 | out_dir=None, 39 | efficient_test=False, 40 | opacity=0.5): 41 | """Test with single GPU. 42 | 43 | Args: 44 | model (nn.Module): Model to be tested. 45 | data_loader (utils.data.Dataloader): Pytorch data loader. 46 | show (bool): Whether show results during inference. Default: False. 47 | out_dir (str, optional): If specified, the results will be dumped into 48 | the directory to save output results. 49 | efficient_test (bool): Whether save the results as local numpy files to 50 | save CPU memory during evaluation. Default: False. 51 | opacity(float): Opacity of painted segmentation map. 52 | Default 0.5. 53 | Must be in (0, 1] range. 54 | Returns: 55 | list: The prediction results. 56 | """ 57 | 58 | model.eval() 59 | results = [] 60 | dataset = data_loader.dataset 61 | prog_bar = mmcv.ProgressBar(len(dataset)) 62 | if efficient_test: 63 | mmcv.mkdir_or_exist('.efficient_test') 64 | for i, data in enumerate(data_loader): 65 | with torch.no_grad(): 66 | result = model(return_loss=False, **data) 67 | 68 | if show or out_dir: 69 | img_tensor = data['img'][0] 70 | img_metas = data['img_metas'][0].data[0] 71 | imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) 72 | assert len(imgs) == len(img_metas) 73 | 74 | for img, img_meta in zip(imgs, img_metas): 75 | h, w, _ = img_meta['img_shape'] 76 | img_show = img[:h, :w, :] 77 | 78 | ori_h, ori_w = img_meta['ori_shape'][:-1] 79 | img_show = mmcv.imresize(img_show, (ori_w, ori_h)) 80 | 81 | if out_dir: 82 | out_file = osp.join(out_dir, img_meta['ori_filename']) 83 | else: 84 | out_file = None 85 | 86 | model.module.show_result( 87 | img_show, 88 | result, 89 | palette=dataset.PALETTE, 90 | show=show, 91 | out_file=out_file, 92 | opacity=opacity) 93 | 94 | if isinstance(result, list): 95 | if efficient_test: 96 | result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] 97 | results.extend(result) 98 | else: 99 | if efficient_test: 100 | result = np2tmp(result, tmpdir='.efficient_test') 101 | results.append(result) 102 | 103 | batch_size = len(result) 104 | for _ in range(batch_size): 105 | prog_bar.update() 106 | return results 107 | 108 | 109 | def multi_gpu_test(model, 110 | data_loader, 111 | tmpdir=None, 112 | gpu_collect=False, 113 | efficient_test=False): 114 | """Test model with multiple gpus. 115 | 116 | This method tests model with multiple gpus and collects the results 117 | under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' 118 | it encodes results to gpu tensors and use gpu communication for results 119 | collection. On cpu mode it saves the results on different gpus to 'tmpdir' 120 | and collects them by the rank 0 worker. 121 | 122 | Args: 123 | model (nn.Module): Model to be tested. 124 | data_loader (utils.data.Dataloader): Pytorch data loader. 125 | tmpdir (str): Path of directory to save the temporary results from 126 | different gpus under cpu mode. The same path is used for efficient 127 | test. 128 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 129 | efficient_test (bool): Whether save the results as local numpy files to 130 | save CPU memory during evaluation. Default: False. 
131 | 132 | Returns: 133 | list: The prediction results. 134 | """ 135 | 136 | model.eval() 137 | results = [] 138 | dataset = data_loader.dataset 139 | rank, world_size = get_dist_info() 140 | if rank == 0: 141 | prog_bar = mmcv.ProgressBar(len(dataset)) 142 | if efficient_test: 143 | mmcv.mkdir_or_exist('.efficient_test') 144 | for i, data in enumerate(data_loader): 145 | with torch.no_grad(): 146 | result = model(return_loss=False, rescale=True, **data) 147 | 148 | if isinstance(result, list): 149 | if efficient_test: 150 | result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] 151 | results.extend(result) 152 | else: 153 | if efficient_test: 154 | result = np2tmp(result, tmpdir='.efficient_test') 155 | results.append(result) 156 | 157 | if rank == 0: 158 | batch_size = len(result) 159 | for _ in range(batch_size * world_size): 160 | prog_bar.update() 161 | 162 | # collect results from all ranks 163 | if gpu_collect: 164 | results = collect_results_gpu(results, len(dataset)) 165 | else: 166 | results = collect_results_cpu(results, len(dataset), tmpdir) 167 | return results 168 | -------------------------------------------------------------------------------- /mmseg/apis/train.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - Add ddp_wrapper from mmgen 4 | 5 | import random 6 | import warnings 7 | 8 | import mmcv 9 | import numpy as np 10 | import torch 11 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 12 | from mmcv.runner import build_optimizer, build_runner 13 | 14 | from mmseg.core import DistEvalHook, EvalHook 15 | from mmseg.core.ddp_wrapper import DistributedDataParallelWrapper 16 | from mmseg.datasets import build_dataloader, build_dataset 17 | from mmseg.utils import get_root_logger 18 | 19 | 20 | def set_random_seed(seed, deterministic=False): 21 | """Set random seed. 22 | 23 | Args: 24 | seed (int): Seed to be used. 25 | deterministic (bool): Whether to set the deterministic option for 26 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 27 | to True and `torch.backends.cudnn.benchmark` to False. 28 | Default: False. 
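    Example (editor-added illustration)::

        >>> set_random_seed(2022)  # seeds python, numpy and torch (cuDNN left as-is)
        >>> set_random_seed(2022, deterministic=True)  # also pins cuDNN; usually slower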
29 | """ 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | if deterministic: 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | 38 | 39 | def train_segmentor(model, 40 | dataset, 41 | cfg, 42 | distributed=False, 43 | validate=False, 44 | timestamp=None, 45 | meta=None): 46 | """Launch segmentor training.""" 47 | logger = get_root_logger(cfg.log_level) 48 | 49 | # prepare data loaders 50 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 51 | data_loaders = [ 52 | build_dataloader( 53 | ds, 54 | cfg.data.samples_per_gpu, 55 | cfg.data.workers_per_gpu, 56 | # cfg.gpus will be ignored if distributed 57 | len(cfg.gpu_ids), 58 | dist=distributed, 59 | seed=cfg.seed, 60 | drop_last=True) for ds in dataset 61 | ] 62 | 63 | # put model on gpus 64 | if distributed: 65 | find_unused_parameters = cfg.get('find_unused_parameters', False) 66 | use_ddp_wrapper = cfg.get('use_ddp_wrapper', False) 67 | # Sets the `find_unused_parameters` parameter in 68 | # torch.nn.parallel.DistributedDataParallel 69 | if use_ddp_wrapper: 70 | mmcv.print_log('Use DDP Wrapper.', 'mmseg') 71 | model = DistributedDataParallelWrapper( 72 | model.cuda(), 73 | device_ids=[torch.cuda.current_device()], 74 | broadcast_buffers=False, 75 | find_unused_parameters=find_unused_parameters) 76 | else: 77 | model = MMDistributedDataParallel( 78 | model.cuda(), 79 | device_ids=[torch.cuda.current_device()], 80 | broadcast_buffers=False, 81 | find_unused_parameters=find_unused_parameters) 82 | else: 83 | model = MMDataParallel( 84 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 85 | 86 | # build runner 87 | optimizer = build_optimizer(model, cfg.optimizer) 88 | 89 | if cfg.get('runner') is None: 90 | cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} 91 | warnings.warn( 92 | 'config is now expected to have a `runner` section, ' 93 | 'please set `runner` in your config.', UserWarning) 94 | 95 | runner = build_runner( 96 | cfg.runner, 97 | default_args=dict( 98 | model=model, 99 | batch_processor=None, 100 | optimizer=optimizer, 101 | work_dir=cfg.work_dir, 102 | logger=logger, 103 | meta=meta)) 104 | 105 | # register hooks 106 | runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, 107 | cfg.checkpoint_config, cfg.log_config, 108 | cfg.get('momentum_config', None)) 109 | 110 | # an ugly walkaround to make the .log and .log.json filenames the same 111 | runner.timestamp = timestamp 112 | 113 | # register eval hooks 114 | if validate: 115 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 116 | val_dataloader = build_dataloader( 117 | val_dataset, 118 | samples_per_gpu=1, 119 | workers_per_gpu=cfg.data.workers_per_gpu, 120 | dist=distributed, 121 | shuffle=False) 122 | eval_cfg = cfg.get('evaluation', {}) 123 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' 124 | eval_hook = DistEvalHook if distributed else EvalHook 125 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 126 | 127 | if cfg.resume_from: 128 | runner.resume(cfg.resume_from) 129 | elif cfg.load_from: 130 | runner.load_checkpoint(cfg.load_from) 131 | runner.run(data_loaders, cfg.workflow) 132 | -------------------------------------------------------------------------------- /mmseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluation import * # noqa: F401, F403 2 | from .seg import * # noqa: F401, 
F403 3 | from .utils import * # noqa: F401, F403 4 | -------------------------------------------------------------------------------- /mmseg/core/ddp_wrapper.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmgeneration 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel 6 | from mmcv.parallel.scatter_gather import scatter_kwargs 7 | from torch.cuda._utils import _get_device_index 8 | 9 | 10 | @MODULE_WRAPPERS.register_module('mmseg.DDPWrapper') 11 | class DistributedDataParallelWrapper(nn.Module): 12 | """A DistributedDataParallel wrapper for models in MMGeneration. 13 | 14 | In MMedting, there is a need to wrap different modules in the models 15 | with separate DistributedDataParallel. Otherwise, it will cause 16 | errors for GAN training. 17 | More specific, the GAN model, usually has two sub-modules: 18 | generator and discriminator. If we wrap both of them in one 19 | standard DistributedDataParallel, it will cause errors during training, 20 | because when we update the parameters of the generator (or discriminator), 21 | the parameters of the discriminator (or generator) is not updated, which is 22 | not allowed for DistributedDataParallel. 23 | So we design this wrapper to separately wrap DistributedDataParallel 24 | for generator and discriminator. 25 | In this wrapper, we perform two operations: 26 | 1. Wrap the modules in the models with separate MMDistributedDataParallel. 27 | Note that only modules with parameters will be wrapped. 28 | 2. Do scatter operation for 'forward', 'train_step' and 'val_step'. 29 | Note that the arguments of this wrapper is the same as those in 30 | `torch.nn.parallel.distributed.DistributedDataParallel`. 31 | Args: 32 | module (nn.Module): Module that needs to be wrapped. 33 | device_ids (list[int | `torch.device`]): Same as that in 34 | `torch.nn.parallel.distributed.DistributedDataParallel`. 35 | dim (int, optional): Same as that in the official scatter function in 36 | pytorch. Defaults to 0. 37 | broadcast_buffers (bool): Same as that in 38 | `torch.nn.parallel.distributed.DistributedDataParallel`. 39 | Defaults to False. 40 | find_unused_parameters (bool, optional): Same as that in 41 | `torch.nn.parallel.distributed.DistributedDataParallel`. 42 | Traverse the autograd graph of all tensors contained in returned 43 | value of the wrapped module’s forward function. Defaults to False. 44 | kwargs (dict): Other arguments used in 45 | `torch.nn.parallel.distributed.DistributedDataParallel`. 46 | """ 47 | 48 | def __init__(self, 49 | module, 50 | device_ids, 51 | dim=0, 52 | broadcast_buffers=False, 53 | find_unused_parameters=False, 54 | **kwargs): 55 | super().__init__() 56 | assert len(device_ids) == 1, ( 57 | 'Currently, DistributedDataParallelWrapper only supports one' 58 | 'single CUDA device for each process.' 59 | f'The length of device_ids must be 1, but got {len(device_ids)}.') 60 | self.module = module 61 | self.dim = dim 62 | self.to_ddp( 63 | device_ids=device_ids, 64 | dim=dim, 65 | broadcast_buffers=broadcast_buffers, 66 | find_unused_parameters=find_unused_parameters, 67 | **kwargs) 68 | self.output_device = _get_device_index(device_ids[0], True) 69 | 70 | def to_ddp(self, device_ids, dim, broadcast_buffers, 71 | find_unused_parameters, **kwargs): 72 | """Wrap models with separate MMDistributedDataParallel. 
73 | 74 | It only wraps the modules with parameters. 75 | """ 76 | for name, module in self.module._modules.items(): 77 | if next(module.parameters(), None) is None: 78 | module = module.cuda() 79 | elif all(not p.requires_grad for p in module.parameters()): 80 | module = module.cuda() 81 | else: 82 | module = MMDistributedDataParallel( 83 | module.cuda(), 84 | device_ids=device_ids, 85 | dim=dim, 86 | broadcast_buffers=broadcast_buffers, 87 | find_unused_parameters=find_unused_parameters, 88 | **kwargs) 89 | self.module._modules[name] = module 90 | 91 | def scatter(self, inputs, kwargs, device_ids): 92 | """Scatter function. 93 | 94 | Args: 95 | inputs (Tensor): Input Tensor. 96 | kwargs (dict): Args for 97 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 98 | device_ids (int): Device id. 99 | """ 100 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 101 | 102 | def forward(self, *inputs, **kwargs): 103 | """Forward function. 104 | 105 | Args: 106 | inputs (tuple): Input data. 107 | kwargs (dict): Args for 108 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 109 | """ 110 | inputs, kwargs = self.scatter(inputs, kwargs, 111 | [torch.cuda.current_device()]) 112 | return self.module(*inputs[0], **kwargs[0]) 113 | 114 | def train_step(self, *inputs, **kwargs): 115 | """Train step function. 116 | 117 | Args: 118 | inputs (Tensor): Input Tensor. 119 | kwargs (dict): Args for 120 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 121 | """ 122 | inputs, kwargs = self.scatter(inputs, kwargs, 123 | [torch.cuda.current_device()]) 124 | output = self.module.train_step(*inputs[0], **kwargs[0]) 125 | return output 126 | 127 | def val_step(self, *inputs, **kwargs): 128 | """Validation step function. 129 | 130 | Args: 131 | inputs (tuple): Input data. 132 | kwargs (dict): Args for ``scatter_kwargs``. 133 | """ 134 | inputs, kwargs = self.scatter(inputs, kwargs, 135 | [torch.cuda.current_device()]) 136 | output = self.module.val_step(*inputs[0], **kwargs[0]) 137 | return output 138 | -------------------------------------------------------------------------------- /mmseg/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .class_names import get_classes, get_palette 4 | from .eval_hooks import DistEvalHook, EvalHook 5 | from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou 6 | 7 | __all__ = [ 8 | 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', 9 | 'eval_metrics', 'get_classes', 'get_palette' 10 | ] 11 | -------------------------------------------------------------------------------- /mmseg/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import os.path as osp 4 | 5 | import torch.distributed as dist 6 | from mmcv.runner import DistEvalHook as _DistEvalHook 7 | from mmcv.runner import EvalHook as _EvalHook 8 | from torch.nn.modules.batchnorm import _BatchNorm 9 | 10 | 11 | class EvalHook(_EvalHook): 12 | """Single GPU EvalHook, with efficient test support. 13 | 14 | Args: 15 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 16 | If set to True, it will perform by epoch. Otherwise, by iteration. 17 | Default: False. 18 | efficient_test (bool): Whether save the results as local numpy files to 19 | save CPU memory during evaluation. 
Default: False. 20 | Returns: 21 | list: The prediction results. 22 | """ 23 | 24 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 25 | 26 | def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): 27 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 28 | self.efficient_test = efficient_test 29 | 30 | def _do_evaluate(self, runner): 31 | """perform evaluation and save ckpt.""" 32 | if not self._should_evaluate(runner): 33 | return 34 | 35 | from mmseg.apis import single_gpu_test 36 | results = single_gpu_test( 37 | runner.model, 38 | self.dataloader, 39 | show=False, 40 | efficient_test=self.efficient_test) 41 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 42 | key_score = self.evaluate(runner, results) 43 | if self.save_best: 44 | self._save_ckpt(runner, key_score) 45 | 46 | 47 | class DistEvalHook(_DistEvalHook): 48 | """Distributed EvalHook, with efficient test support. 49 | 50 | Args: 51 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 52 | If set to True, it will perform by epoch. Otherwise, by iteration. 53 | Default: False. 54 | efficient_test (bool): Whether save the results as local numpy files to 55 | save CPU memory during evaluation. Default: False. 56 | Returns: 57 | list: The prediction results. 58 | """ 59 | 60 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 61 | 62 | def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): 63 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 64 | self.efficient_test = efficient_test 65 | 66 | def _do_evaluate(self, runner): 67 | """perform evaluation and save ckpt.""" 68 | # Synchronization of BatchNorm's buffer (running_mean 69 | # and running_var) is not supported in the DDP of pytorch, 70 | # which may cause the inconsistent performance of models in 71 | # different ranks, so we broadcast BatchNorm's buffers 72 | # of rank 0 to other ranks to avoid this. 
73 | if self.broadcast_bn_buffer: 74 | model = runner.model 75 | for name, module in model.named_modules(): 76 | if isinstance(module, 77 | _BatchNorm) and module.track_running_stats: 78 | dist.broadcast(module.running_var, 0) 79 | dist.broadcast(module.running_mean, 0) 80 | 81 | if not self._should_evaluate(runner): 82 | return 83 | 84 | tmpdir = self.tmpdir 85 | if tmpdir is None: 86 | tmpdir = osp.join(runner.work_dir, '.eval_hook') 87 | 88 | from mmseg.apis import multi_gpu_test 89 | results = multi_gpu_test( 90 | runner.model, 91 | self.dataloader, 92 | tmpdir=tmpdir, 93 | gpu_collect=self.gpu_collect, 94 | efficient_test=self.efficient_test) 95 | if runner.rank == 0: 96 | print('\n') 97 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 98 | key_score = self.evaluate(runner, results) 99 | 100 | if self.save_best: 101 | self._save_ckpt(runner, key_score) 102 | -------------------------------------------------------------------------------- /mmseg/core/seg/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .builder import build_pixel_sampler 4 | from .sampler import BasePixelSampler, OHEMPixelSampler 5 | 6 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] 7 | -------------------------------------------------------------------------------- /mmseg/core/seg/builder.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from mmcv.utils import Registry, build_from_cfg 4 | 5 | PIXEL_SAMPLERS = Registry('pixel sampler') 6 | 7 | 8 | def build_pixel_sampler(cfg, **default_args): 9 | """Build pixel sampler for segmentation map.""" 10 | return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) 11 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .base_pixel_sampler import BasePixelSampler 4 | from .ohem_pixel_sampler import OHEMPixelSampler 5 | 6 | __all__ = ['BasePixelSampler', 'OHEMPixelSampler'] 7 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | 6 | class BasePixelSampler(metaclass=ABCMeta): 7 | """Base class of pixel sampler.""" 8 | 9 | def __init__(self, **kwargs): 10 | pass 11 | 12 | @abstractmethod 13 | def sample(self, seg_logit, seg_label): 14 | """Placeholder for sample function.""" 15 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/ohem_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from ..builder import PIXEL_SAMPLERS 7 | from .base_pixel_sampler import BasePixelSampler 8 | 9 | 10 | @PIXEL_SAMPLERS.register_module() 11 | class OHEMPixelSampler(BasePixelSampler): 12 | """Online Hard Example Mining Sampler for 
segmentation. 13 | 14 | Args: 15 | context (nn.Module): The context of sampler, subclass of 16 | :obj:`BaseDecodeHead`. 17 | thresh (float, optional): The threshold for hard example selection. 18 | Below which, are prediction with low confidence. If not 19 | specified, the hard examples will be pixels of top ``min_kept`` 20 | loss. Default: None. 21 | min_kept (int, optional): The minimum number of predictions to keep. 22 | Default: 100000. 23 | """ 24 | 25 | def __init__(self, context, thresh=None, min_kept=100000): 26 | super(OHEMPixelSampler, self).__init__() 27 | self.context = context 28 | assert min_kept > 1 29 | self.thresh = thresh 30 | self.min_kept = min_kept 31 | 32 | def sample(self, seg_logit, seg_label): 33 | """Sample pixels that have high loss or with low prediction confidence. 34 | 35 | Args: 36 | seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) 37 | seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) 38 | 39 | Returns: 40 | torch.Tensor: segmentation weight, shape (N, H, W) 41 | """ 42 | with torch.no_grad(): 43 | assert seg_logit.shape[2:] == seg_label.shape[2:] 44 | assert seg_label.shape[1] == 1 45 | seg_label = seg_label.squeeze(1).long() 46 | batch_kept = self.min_kept * seg_label.size(0) 47 | valid_mask = seg_label != self.context.ignore_index 48 | seg_weight = seg_logit.new_zeros(size=seg_label.size()) 49 | valid_seg_weight = seg_weight[valid_mask] 50 | if self.thresh is not None: 51 | seg_prob = F.softmax(seg_logit, dim=1) 52 | 53 | tmp_seg_label = seg_label.clone().unsqueeze(1) 54 | tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 55 | seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) 56 | sort_prob, sort_indices = seg_prob[valid_mask].sort() 57 | 58 | if sort_prob.numel() > 0: 59 | min_threshold = sort_prob[min(batch_kept, 60 | sort_prob.numel() - 1)] 61 | else: 62 | min_threshold = 0.0 63 | threshold = max(min_threshold, self.thresh) 64 | valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 65 | else: 66 | losses = self.context.loss_decode( 67 | seg_logit, 68 | seg_label, 69 | weight=None, 70 | ignore_index=self.context.ignore_index, 71 | reduction_override='none') 72 | # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa 73 | _, sort_indices = losses[valid_mask].sort(descending=True) 74 | valid_seg_weight[sort_indices[:batch_kept]] = 1. 75 | 76 | seg_weight[valid_mask] = valid_seg_weight 77 | 78 | return seg_weight 79 | -------------------------------------------------------------------------------- /mmseg/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .misc import add_prefix 4 | 5 | __all__ = ['add_prefix'] 6 | -------------------------------------------------------------------------------- /mmseg/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 4 | def add_prefix(inputs, prefix): 5 | """Add prefix for dict. 6 | 7 | Args: 8 | inputs (dict): The input dict with str keys. 9 | prefix (str): The prefix to add. 10 | 11 | Returns: 12 | 13 | dict: The dict with keys updated with ``prefix``. 
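    Example (editor-added illustration)::

        >>> add_prefix({'loss_ce': 0.4, 'acc': 0.9}, 'decode')
        {'decode.loss_ce': 0.4, 'decode.acc': 0.9}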
14 | """ 15 | 16 | outputs = dict() 17 | for name, value in inputs.items(): 18 | outputs[f'{prefix}.{name}'] = value 19 | 20 | return outputs 21 | -------------------------------------------------------------------------------- /mmseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 4 | from .cityscapes import CityscapesDataset 5 | from .custom import CustomDataset 6 | from .dataset_wrappers import ConcatDataset, RepeatDataset 7 | from .dark_zurich import DarkZurichDataset 8 | from .gta import GTADataset 9 | from .uda_dataset import UDADataset 10 | 11 | __all__ = [ 12 | 'CustomDataset', 13 | 'build_dataloader', 14 | 'ConcatDataset', 15 | 'RepeatDataset', 16 | 'DATASETS', 17 | 'build_dataset', 18 | 'PIPELINES', 19 | 'CityscapesDataset', 20 | 'DarkZurichDataset', 21 | 'GTADataset', 22 | 'UDADataset', 23 | ] 24 | -------------------------------------------------------------------------------- /mmseg/datasets/dark_zurich.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.17.0 2 | 3 | from .builder import DATASETS 4 | from .cityscapes import CityscapesDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class DarkZurichDataset(CityscapesDataset): 9 | """DarkZurichDataset dataset.""" 10 | 11 | def __init__(self, **kwargs): 12 | super(DarkZurichDataset, self).__init__( 13 | img_suffix='_rgb_anon.png', 14 | seg_map_suffix='_gt_labelTrainIds.png', 15 | **kwargs) 16 | -------------------------------------------------------------------------------- /mmseg/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 4 | 5 | from .builder import DATASETS 6 | 7 | 8 | @DATASETS.register_module() 9 | class ConcatDataset(_ConcatDataset): 10 | """A wrapper of concatenated dataset. 11 | 12 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 13 | concat the group flag for image aspect ratio. 14 | 15 | Args: 16 | datasets (list[:obj:`Dataset`]): A list of datasets. 17 | """ 18 | 19 | def __init__(self, datasets): 20 | super(ConcatDataset, self).__init__(datasets) 21 | self.CLASSES = datasets[0].CLASSES 22 | self.PALETTE = datasets[0].PALETTE 23 | 24 | 25 | @DATASETS.register_module() 26 | class RepeatDataset(object): 27 | """A wrapper of repeated dataset. 28 | 29 | The length of repeated dataset will be `times` larger than the original 30 | dataset. This is useful when the data loading time is long but the dataset 31 | is small. Using RepeatDataset can reduce the data loading time between 32 | epochs. 33 | 34 | Args: 35 | dataset (:obj:`Dataset`): The dataset to be repeated. 36 | times (int): Repeat times. 
37 | """ 38 | 39 | def __init__(self, dataset, times): 40 | self.dataset = dataset 41 | self.times = times 42 | self.CLASSES = dataset.CLASSES 43 | self.PALETTE = dataset.PALETTE 44 | self._ori_len = len(self.dataset) 45 | 46 | def __getitem__(self, idx): 47 | """Get item from original dataset.""" 48 | return self.dataset[idx % self._ori_len] 49 | 50 | def __len__(self): 51 | """The length is multiplied by ``times``""" 52 | return self.times * self._ori_len 53 | -------------------------------------------------------------------------------- /mmseg/datasets/gta.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | from . import CityscapesDataset 7 | from .builder import DATASETS 8 | from .custom import CustomDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class GTADataset(CustomDataset): 13 | CLASSES = CityscapesDataset.CLASSES 14 | PALETTE = CityscapesDataset.PALETTE 15 | 16 | def __init__(self, **kwargs): 17 | assert kwargs.get('split') in [None, 'train'] 18 | if 'split' in kwargs: 19 | kwargs.pop('split') 20 | super(GTADataset, self).__init__( 21 | img_suffix='.png', 22 | seg_map_suffix='_labelTrainIds.png', 23 | split=None, 24 | **kwargs) 25 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .compose import Compose 4 | from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, 5 | Transpose, to_tensor) 6 | from .loading import LoadAnnotations, LoadImageFromFile 7 | from .test_time_aug import MultiScaleFlipAug 8 | from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, 9 | PhotoMetricDistortion, RandomCrop, RandomFlip, 10 | RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) 11 | 12 | __all__ = [ 13 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 14 | 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 15 | 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 16 | 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 17 | 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' 18 | ] 19 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import collections 4 | 5 | from mmcv.utils import build_from_cfg 6 | 7 | from ..builder import PIPELINES 8 | 9 | 10 | @PIPELINES.register_module() 11 | class Compose(object): 12 | """Compose multiple transforms sequentially. 13 | 14 | Args: 15 | transforms (Sequence[dict | callable]): Sequence of transform object or 16 | config dict to be composed. 
17 | """ 18 | 19 | def __init__(self, transforms): 20 | assert isinstance(transforms, collections.abc.Sequence) 21 | self.transforms = [] 22 | for transform in transforms: 23 | if isinstance(transform, dict): 24 | transform = build_from_cfg(transform, PIPELINES) 25 | self.transforms.append(transform) 26 | elif callable(transform): 27 | self.transforms.append(transform) 28 | else: 29 | raise TypeError('transform must be callable or a dict') 30 | 31 | def __call__(self, data): 32 | """Call function to apply transforms sequentially. 33 | 34 | Args: 35 | data (dict): A result dict contains the data to transform. 36 | 37 | Returns: 38 | dict: Transformed data. 39 | """ 40 | 41 | for t in self.transforms: 42 | data = t(data) 43 | if data is None: 44 | return None 45 | return data 46 | 47 | def __repr__(self): 48 | format_string = self.__class__.__name__ + '(' 49 | for t in self.transforms: 50 | format_string += '\n' 51 | format_string += f' {t}' 52 | format_string += '\n)' 53 | return format_string 54 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import os.path as osp 4 | 5 | import mmcv 6 | import numpy as np 7 | 8 | from ..builder import PIPELINES 9 | 10 | 11 | @PIPELINES.register_module() 12 | class LoadImageFromFile(object): 13 | """Load an image from file. 14 | 15 | Required keys are "img_prefix" and "img_info" (a dict that must contain the 16 | key "filename"). Added or updated keys are "filename", "img", "img_shape", 17 | "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), 18 | "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). 19 | 20 | Args: 21 | to_float32 (bool): Whether to convert the loaded image to a float32 22 | numpy array. If set to False, the loaded image is an uint8 array. 23 | Defaults to False. 24 | color_type (str): The flag argument for :func:`mmcv.imfrombytes`. 25 | Defaults to 'color'. 26 | file_client_args (dict): Arguments to instantiate a FileClient. 27 | See :class:`mmcv.fileio.FileClient` for details. 28 | Defaults to ``dict(backend='disk')``. 29 | imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 30 | 'cv2' 31 | """ 32 | 33 | def __init__(self, 34 | to_float32=False, 35 | color_type='color', 36 | file_client_args=dict(backend='disk'), 37 | imdecode_backend='cv2'): 38 | self.to_float32 = to_float32 39 | self.color_type = color_type 40 | self.file_client_args = file_client_args.copy() 41 | self.file_client = None 42 | self.imdecode_backend = imdecode_backend 43 | 44 | def __call__(self, results): 45 | """Call functions to load image and get image meta information. 46 | 47 | Args: 48 | results (dict): Result dict from :obj:`mmseg.CustomDataset`. 49 | 50 | Returns: 51 | dict: The dict contains loaded image and meta information. 
52 | """ 53 | 54 | if self.file_client is None: 55 | self.file_client = mmcv.FileClient(**self.file_client_args) 56 | 57 | if results.get('img_prefix') is not None: 58 | filename = osp.join(results['img_prefix'], 59 | results['img_info']['filename']) 60 | else: 61 | filename = results['img_info']['filename'] 62 | img_bytes = self.file_client.get(filename) 63 | img = mmcv.imfrombytes( 64 | img_bytes, flag=self.color_type, backend=self.imdecode_backend) 65 | if self.to_float32: 66 | img = img.astype(np.float32) 67 | #################################### 68 | # Added by lmj for debug 69 | if img is None: 70 | with open('bug_img.log', 'w') as f: 71 | f.write(filename + '\n') 72 | #################################### 73 | 74 | results['filename'] = filename 75 | results['ori_filename'] = results['img_info']['filename'] 76 | results['img'] = img 77 | results['img_shape'] = img.shape 78 | results['ori_shape'] = img.shape 79 | # Set initial values for default meta_keys 80 | results['pad_shape'] = img.shape 81 | results['scale_factor'] = 1.0 82 | num_channels = 1 if len(img.shape) < 3 else img.shape[2] 83 | results['img_norm_cfg'] = dict( 84 | mean=np.zeros(num_channels, dtype=np.float32), 85 | std=np.ones(num_channels, dtype=np.float32), 86 | to_rgb=False) 87 | return results 88 | 89 | def __repr__(self): 90 | repr_str = self.__class__.__name__ 91 | repr_str += f'(to_float32={self.to_float32},' 92 | repr_str += f"color_type='{self.color_type}'," 93 | repr_str += f"imdecode_backend='{self.imdecode_backend}')" 94 | return repr_str 95 | 96 | 97 | @PIPELINES.register_module() 98 | class LoadAnnotations(object): 99 | """Load annotations for semantic segmentation. 100 | 101 | Args: 102 | reduce_zero_label (bool): Whether reduce all label value by 1. 103 | Usually used for datasets where 0 is background label. 104 | Default: False. 105 | file_client_args (dict): Arguments to instantiate a FileClient. 106 | See :class:`mmcv.fileio.FileClient` for details. 107 | Defaults to ``dict(backend='disk')``. 108 | imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 109 | 'pillow' 110 | """ 111 | 112 | def __init__(self, 113 | reduce_zero_label=False, 114 | file_client_args=dict(backend='disk'), 115 | imdecode_backend='pillow'): 116 | self.reduce_zero_label = reduce_zero_label 117 | self.file_client_args = file_client_args.copy() 118 | self.file_client = None 119 | self.imdecode_backend = imdecode_backend 120 | 121 | def __call__(self, results): 122 | """Call function to load multiple types annotations. 123 | 124 | Args: 125 | results (dict): Result dict from :obj:`mmseg.CustomDataset`. 126 | 127 | Returns: 128 | dict: The dict contains loaded semantic segmentation annotations. 
129 | """ 130 | 131 | if self.file_client is None: 132 | self.file_client = mmcv.FileClient(**self.file_client_args) 133 | 134 | if results.get('seg_prefix', None) is not None: 135 | filename = osp.join(results['seg_prefix'], 136 | results['ann_info']['seg_map']) 137 | else: 138 | filename = results['ann_info']['seg_map'] 139 | img_bytes = self.file_client.get(filename) 140 | gt_semantic_seg = mmcv.imfrombytes( 141 | img_bytes, flag='unchanged', 142 | backend=self.imdecode_backend).squeeze().astype(np.uint8) 143 | # modify if custom classes 144 | if results.get('label_map', None) is not None: 145 | for old_id, new_id in results['label_map'].items(): 146 | gt_semantic_seg[gt_semantic_seg == old_id] = new_id 147 | # reduce zero_label 148 | if self.reduce_zero_label: 149 | # avoid using underflow conversion 150 | gt_semantic_seg[gt_semantic_seg == 0] = 255 151 | gt_semantic_seg = gt_semantic_seg - 1 152 | gt_semantic_seg[gt_semantic_seg == 254] = 255 153 | results['gt_semantic_seg'] = gt_semantic_seg 154 | results['seg_fields'].append('gt_semantic_seg') 155 | return results 156 | 157 | def __repr__(self): 158 | repr_str = self.__class__.__name__ 159 | repr_str += f'(reduce_zero_label={self.reduce_zero_label},' 160 | repr_str += f"imdecode_backend='{self.imdecode_backend}')" 161 | return repr_str 162 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/test_time_aug.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import warnings 4 | 5 | import mmcv 6 | 7 | from ..builder import PIPELINES 8 | from .compose import Compose 9 | 10 | 11 | @PIPELINES.register_module() 12 | class MultiScaleFlipAug(object): 13 | """Test-time augmentation with multiple scales and flipping. 14 | 15 | An example configuration is as followed: 16 | 17 | .. code-block:: 18 | 19 | img_scale=(2048, 1024), 20 | img_ratios=[0.5, 1.0], 21 | flip=True, 22 | transforms=[ 23 | dict(type='Resize', keep_ratio=True), 24 | dict(type='RandomFlip'), 25 | dict(type='Normalize', **img_norm_cfg), 26 | dict(type='Pad', size_divisor=32), 27 | dict(type='ImageToTensor', keys=['img']), 28 | dict(type='Collect', keys=['img']), 29 | ] 30 | 31 | After MultiScaleFLipAug with above configuration, the results are wrapped 32 | into lists of the same length as followed: 33 | 34 | .. code-block:: 35 | 36 | dict( 37 | img=[...], 38 | img_shape=[...], 39 | scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] 40 | flip=[False, True, False, True] 41 | ... 42 | ) 43 | 44 | Args: 45 | transforms (list[dict]): Transforms to apply in each augmentation. 46 | img_scale (None | tuple | list[tuple]): Images scales for resizing. 47 | img_ratios (float | list[float]): Image ratios for resizing 48 | flip (bool): Whether apply flip augmentation. Default: False. 49 | flip_direction (str | list[str]): Flip augmentation directions, 50 | options are "horizontal" and "vertical". If flip_direction is list, 51 | multiple flip augmentations will be applied. 52 | It has no effect when flip == False. Default: "horizontal". 
53 | """ 54 | 55 | def __init__(self, 56 | transforms, 57 | img_scale, 58 | img_ratios=None, 59 | flip=False, 60 | flip_direction='horizontal'): 61 | self.transforms = Compose(transforms) 62 | if img_ratios is not None: 63 | img_ratios = img_ratios if isinstance(img_ratios, 64 | list) else [img_ratios] 65 | assert mmcv.is_list_of(img_ratios, float) 66 | if img_scale is None: 67 | # mode 1: given img_scale=None and a range of image ratio 68 | self.img_scale = None 69 | assert mmcv.is_list_of(img_ratios, float) 70 | elif isinstance(img_scale, tuple) and mmcv.is_list_of( 71 | img_ratios, float): 72 | assert len(img_scale) == 2 73 | # mode 2: given a scale and a range of image ratio 74 | self.img_scale = [(int(img_scale[0] * ratio), 75 | int(img_scale[1] * ratio)) 76 | for ratio in img_ratios] 77 | else: 78 | # mode 3: given multiple scales 79 | self.img_scale = img_scale if isinstance(img_scale, 80 | list) else [img_scale] 81 | assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None 82 | self.flip = flip 83 | self.img_ratios = img_ratios 84 | self.flip_direction = flip_direction if isinstance( 85 | flip_direction, list) else [flip_direction] 86 | assert mmcv.is_list_of(self.flip_direction, str) 87 | if not self.flip and self.flip_direction != ['horizontal']: 88 | warnings.warn( 89 | 'flip_direction has no effect when flip is set to False') 90 | if (self.flip 91 | and not any([t['type'] == 'RandomFlip' for t in transforms])): 92 | warnings.warn( 93 | 'flip has no effect when RandomFlip is not in transforms') 94 | 95 | def __call__(self, results): 96 | """Call function to apply test time augment transforms on results. 97 | 98 | Args: 99 | results (dict): Result dict contains the data to transform. 100 | 101 | Returns: 102 | dict[str: list]: The augmented data, where each value is wrapped 103 | into a list. 104 | """ 105 | 106 | aug_data = [] 107 | if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): 108 | h, w = results['img'].shape[:2] 109 | img_scale = [(int(w * ratio), int(h * ratio)) 110 | for ratio in self.img_ratios] 111 | else: 112 | img_scale = self.img_scale 113 | flip_aug = [False, True] if self.flip else [False] 114 | for scale in img_scale: 115 | for flip in flip_aug: 116 | for direction in self.flip_direction: 117 | _results = results.copy() 118 | _results['scale'] = scale 119 | _results['flip'] = flip 120 | _results['flip_direction'] = direction 121 | data = self.transforms(_results) 122 | aug_data.append(data) 123 | # list of dict to dict of list 124 | aug_data_dict = {key: [] for key in aug_data[0]} 125 | for data in aug_data: 126 | for key, val in data.items(): 127 | aug_data_dict[key].append(val) 128 | return aug_data_dict 129 | 130 | def __repr__(self): 131 | repr_str = self.__class__.__name__ 132 | repr_str += f'(transforms={self.transforms}, ' 133 | repr_str += f'img_scale={self.img_scale}, flip={self.flip})' 134 | repr_str += f'flip_direction={self.flip_direction}' 135 | return repr_str 136 | -------------------------------------------------------------------------------- /mmseg/datasets/uda_dataset.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/lhoyer/DAFormer 2 | import json 3 | import os.path as osp 4 | 5 | import mmcv 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import CityscapesDataset 10 | from .builder import DATASETS 11 | 12 | 13 | def get_rcs_class_probs(data_root, temperature): 14 | with open(osp.join(data_root, 'sample_class_stats.json'), 'r') as of: 15 | sample_class_stats = json.load(of) 16 | overall_class_stats = {} 17 | for s in sample_class_stats: 18 | s.pop('file') 19 | for c, n in s.items(): 20 | c = int(c) 21 | if c not in overall_class_stats: 22 | overall_class_stats[c] = n 23 | else: 24 | overall_class_stats[c] += n 25 | overall_class_stats = { 26 | k: v 27 | for k, v in sorted( 28 | overall_class_stats.items(), key=lambda item: item[1]) 29 | } 30 | freq = torch.tensor(list(overall_class_stats.values())) 31 | freq = freq / torch.sum(freq) 32 | freq = 1 - freq 33 | freq = torch.softmax(freq / temperature, dim=-1) 34 | 35 | return list(overall_class_stats.keys()), freq.numpy() 36 | 37 | 38 | @DATASETS.register_module() 39 | class UDADataset(object): 40 | 41 | def __init__(self, source, target, cfg): 42 | self.source = source 43 | self.target = target 44 | self.ignore_index = target.ignore_index 45 | self.CLASSES = target.CLASSES 46 | self.PALETTE = target.PALETTE 47 | assert target.ignore_index == source.ignore_index 48 | assert target.CLASSES == source.CLASSES 49 | assert target.PALETTE == source.PALETTE 50 | 51 | rcs_cfg = cfg.get('rare_class_sampling') 52 | self.rcs_enabled = rcs_cfg is not None 53 | if self.rcs_enabled: 54 | self.rcs_class_temp = rcs_cfg['class_temp'] 55 | self.rcs_min_crop_ratio = rcs_cfg['min_crop_ratio'] 56 | self.rcs_min_pixels = rcs_cfg['min_pixels'] 57 | 58 | self.rcs_classes, self.rcs_classprob = get_rcs_class_probs( 59 | cfg['source']['data_root'], self.rcs_class_temp) 60 | mmcv.print_log(f'RCS Classes: {self.rcs_classes}', 'mmseg') 61 | mmcv.print_log(f'RCS ClassProb: {self.rcs_classprob}', 'mmseg') 62 | 63 | with open( 64 | osp.join(cfg['source']['data_root'], 65 | 'samples_with_class.json'), 'r') as of: 66 | samples_with_class_and_n = json.load(of) 67 | samples_with_class_and_n = { 68 | int(k): v 69 | for k, v in samples_with_class_and_n.items() 70 | if int(k) in self.rcs_classes 71 | } 72 | self.samples_with_class = {} 73 | for c in self.rcs_classes: 74 | self.samples_with_class[c] = [] 75 | for file, pixels in samples_with_class_and_n[c]: 76 | if pixels > self.rcs_min_pixels: 77 | self.samples_with_class[c].append(file.split('/')[-1]) 78 | assert len(self.samples_with_class[c]) > 0 79 | self.file_to_idx = {} 80 | for i, dic in enumerate(self.source.img_infos): 81 | file = dic['ann']['seg_map'] 82 | if isinstance(self.source, CityscapesDataset): 83 | file = file.split('/')[-1] 84 | self.file_to_idx[file] = i 85 | 86 | def get_rare_class_sample(self): 87 | c = np.random.choice(self.rcs_classes, p=self.rcs_classprob) 88 | f1 = np.random.choice(self.samples_with_class[c]) 89 | i1 = self.file_to_idx[f1] 90 | s1 = self.source[i1] 91 | if self.rcs_min_crop_ratio > 0: 92 | for j in range(10): 93 | n_class = torch.sum(s1['gt_semantic_seg'].data == c) 94 | # mmcv.print_log(f'{j}: {n_class}', 'mmseg') 95 | if n_class > self.rcs_min_pixels * self.rcs_min_crop_ratio: 96 | break 97 | s1 = self.source[i1] 98 | i2 = np.random.choice(range(len(self.target))) 99 | s2 = self.target[i2] 100 | 101 | return { 102 | **s1, 'target_img_metas': s2['img_metas'], 103 | 'target_img': s2['img'], 'target_gt_semantic_seg': s2.get('gt_semantic_seg', -1) 104 | } 105 | 106 | def __getitem__(self, idx): 107 | if self.rcs_enabled: 108 | return self.get_rare_class_sample() 109 | else: 110 | s1 = self.source[idx // len(self.target)] 111 
| s2 = self.target[idx % len(self.target)] 112 | return { 113 | **s1, 'target_img_metas': s2['img_metas'], 114 | 'target_img': s2['img'], 'target_gt_semantic_seg': s2.get('gt_semantic_seg', -1) 115 | } 116 | 117 | def __len__(self): 118 | return len(self.source) * len(self.target) 119 | -------------------------------------------------------------------------------- /mmseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * # noqa: F401,F403 2 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, UDA, 3 | build_backbone, build_head, build_loss, build_segmentor) 4 | from .decode_heads import * # noqa: F401,F403 5 | from .losses import * # noqa: F401,F403 6 | from .necks import * # noqa: F401,F403 7 | from .segmentors import * # noqa: F401,F403 8 | from .uda import * # noqa: F401,F403 9 | 10 | __all__ = [ 11 | 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'UDA', 'build_backbone', 12 | 'build_head', 'build_loss', 'build_segmentor' 13 | ] 14 | -------------------------------------------------------------------------------- /mmseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add additional backbones 3 | 4 | from .mix_transformer import (MixVisionTransformer, mit_b0, mit_b1, mit_b2, 5 | mit_b3, mit_b4, mit_b5) 6 | from .resnet import ResNet, ResNetV1c, ResNetV1d 7 | 8 | __all__ = [ 9 | 'ResNet', 10 | 'ResNetV1c', 11 | 'ResNetV1d', 12 | 'MixVisionTransformer', 13 | 'mit_b0', 14 | 'mit_b1', 15 | 'mit_b2', 16 | 'mit_b3', 17 | 'mit_b4', 18 | 'mit_b5', 19 | ] 20 | -------------------------------------------------------------------------------- /mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Support UDA models 3 | 4 | import warnings 5 | 6 | from mmcv.cnn import MODELS as MMCV_MODELS 7 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 8 | from mmcv.utils import Registry 9 | 10 | MODELS = Registry('models', parent=MMCV_MODELS) 11 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 12 | 13 | BACKBONES = MODELS 14 | NECKS = MODELS 15 | HEADS = MODELS 16 | LOSSES = MODELS 17 | SEGMENTORS = MODELS 18 | UDA = MODELS 19 | 20 | 21 | def build_backbone(cfg): 22 | """Build backbone.""" 23 | return BACKBONES.build(cfg) 24 | 25 | 26 | def build_neck(cfg): 27 | """Build neck.""" 28 | return NECKS.build(cfg) 29 | 30 | 31 | def build_head(cfg): 32 | """Build head.""" 33 | return HEADS.build(cfg) 34 | 35 | 36 | def build_loss(cfg): 37 | """Build loss.""" 38 | return LOSSES.build(cfg) 39 | 40 | 41 | def build_train_model(cfg, train_cfg=None, test_cfg=None): 42 | """Build model.""" 43 | if train_cfg is not None or test_cfg is not None: 44 | warnings.warn( 45 | 'train_cfg and test_cfg is deprecated, ' 46 | 'please specify them in model', UserWarning) 47 | assert cfg.model.get('train_cfg') is None or train_cfg is None, \ 48 | 'train_cfg specified in both outer field and model field ' 49 | assert cfg.model.get('test_cfg') is None or test_cfg is None, \ 50 | 'test_cfg specified in both outer field and model field ' 51 | if 'uda' in cfg: 52 | cfg.uda['model'] = cfg.model 53 | cfg.uda['max_iters'] = cfg.runner.max_iters 54 | cfg.uda['pipeline'] = cfg.data.train.target.pipeline 55 | return 
UDA.build( 56 | cfg.uda, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 57 | else: 58 | return SEGMENTORS.build( 59 | cfg.model, 60 | default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 61 | 62 | 63 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 64 | """Build segmentor.""" 65 | if train_cfg is not None or test_cfg is not None: 66 | warnings.warn( 67 | 'train_cfg and test_cfg is deprecated, ' 68 | 'please specify them in model', UserWarning) 69 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 70 | 'train_cfg specified in both outer field and model field ' 71 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 72 | 'test_cfg specified in both outer field and model field ' 73 | return SEGMENTORS.build( 74 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 75 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add additional decode_heads 3 | 4 | from .aspp_head import ASPPHead 5 | from .da_head import DAHead 6 | from .daformer_head import DAFormerHead 7 | from .dlv2_head import DLV2Head 8 | from .fcn_head import FCNHead 9 | from .isa_head import ISAHead 10 | from .segformer_head import SegFormerHead 11 | from .sep_aspp_head import DepthwiseSeparableASPPHead 12 | from .proj_head import ProjHead 13 | 14 | __all__ = [ 15 | 'FCNHead', 16 | 'ASPPHead', 17 | 'DepthwiseSeparableASPPHead', 18 | 'DAHead', 19 | 'DLV2Head', 20 | 'SegFormerHead', 21 | 'DAFormerHead', 22 | 'ISAHead', 23 | 'ProjHead', 24 | ] 25 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/aspp_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | class ASPPModule(nn.ModuleList): 13 | """Atrous Spatial Pyramid Pooling (ASPP) Module. 14 | 15 | Args: 16 | dilations (tuple[int]): Dilation rate of each layer. 17 | in_channels (int): Input channels. 18 | channels (int): Channels after modules, before conv_seg. 19 | conv_cfg (dict|None): Config of conv layers. 20 | norm_cfg (dict|None): Config of norm layers. 21 | act_cfg (dict): Config of activation layers. 
22 | """ 23 | 24 | def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, 25 | act_cfg): 26 | super(ASPPModule, self).__init__() 27 | self.dilations = dilations 28 | self.in_channels = in_channels 29 | self.channels = channels 30 | self.conv_cfg = conv_cfg 31 | self.norm_cfg = norm_cfg 32 | self.act_cfg = act_cfg 33 | for dilation in dilations: 34 | self.append( 35 | ConvModule( 36 | self.in_channels, 37 | self.channels, 38 | 1 if dilation == 1 else 3, 39 | dilation=dilation, 40 | padding=0 if dilation == 1 else dilation, 41 | conv_cfg=self.conv_cfg, 42 | norm_cfg=self.norm_cfg, 43 | act_cfg=self.act_cfg)) 44 | 45 | def forward(self, x): 46 | """Forward function.""" 47 | aspp_outs = [] 48 | for aspp_module in self: 49 | aspp_outs.append(aspp_module(x)) 50 | 51 | return aspp_outs 52 | 53 | 54 | @HEADS.register_module() 55 | class ASPPHead(BaseDecodeHead): 56 | """Rethinking Atrous Convolution for Semantic Image Segmentation. 57 | 58 | This head is the implementation of `DeepLabV3 59 | `_. 60 | 61 | Args: 62 | dilations (tuple[int]): Dilation rates for ASPP module. 63 | Default: (1, 6, 12, 18). 64 | """ 65 | 66 | def __init__(self, dilations=(1, 6, 12, 18), **kwargs): 67 | super(ASPPHead, self).__init__(**kwargs) 68 | assert isinstance(dilations, (list, tuple)) 69 | self.dilations = dilations 70 | self.image_pool = nn.Sequential( 71 | nn.AdaptiveAvgPool2d(1), 72 | ConvModule( 73 | self.in_channels, 74 | self.channels, 75 | 1, 76 | conv_cfg=self.conv_cfg, 77 | norm_cfg=self.norm_cfg, 78 | act_cfg=self.act_cfg)) 79 | self.aspp_modules = ASPPModule( 80 | dilations, 81 | self.in_channels, 82 | self.channels, 83 | conv_cfg=self.conv_cfg, 84 | norm_cfg=self.norm_cfg, 85 | act_cfg=self.act_cfg) 86 | self.bottleneck = ConvModule( 87 | (len(dilations) + 1) * self.channels, 88 | self.channels, 89 | 3, 90 | padding=1, 91 | conv_cfg=self.conv_cfg, 92 | norm_cfg=self.norm_cfg, 93 | act_cfg=self.act_cfg) 94 | 95 | def forward(self, inputs): 96 | """Forward function.""" 97 | x = self._transform_inputs(inputs) 98 | aspp_outs = [ 99 | resize( 100 | self.image_pool(x), 101 | size=x.size()[2:], 102 | mode='bilinear', 103 | align_corners=self.align_corners) 104 | ] 105 | aspp_outs.extend(self.aspp_modules(x)) 106 | aspp_outs = torch.cat(aspp_outs, dim=1) 107 | output = self.bottleneck(aspp_outs) 108 | output = self.cls_seg(output) 109 | return output 110 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/da_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Support for seg_weight 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from mmcv.cnn import ConvModule, Scale 7 | from torch import nn 8 | 9 | from mmseg.core import add_prefix 10 | from ..builder import HEADS 11 | from ..utils import SelfAttentionBlock as _SelfAttentionBlock 12 | from .decode_head import BaseDecodeHead 13 | 14 | 15 | class PAM(_SelfAttentionBlock): 16 | """Position Attention Module (PAM) 17 | 18 | Args: 19 | in_channels (int): Input channels of key/query feature. 20 | channels (int): Output channels of key/query transform. 
21 | """ 22 | 23 | def __init__(self, in_channels, channels): 24 | super(PAM, self).__init__( 25 | key_in_channels=in_channels, 26 | query_in_channels=in_channels, 27 | channels=channels, 28 | out_channels=in_channels, 29 | share_key_query=False, 30 | query_downsample=None, 31 | key_downsample=None, 32 | key_query_num_convs=1, 33 | key_query_norm=False, 34 | value_out_num_convs=1, 35 | value_out_norm=False, 36 | matmul_norm=False, 37 | with_out=False, 38 | conv_cfg=None, 39 | norm_cfg=None, 40 | act_cfg=None) 41 | 42 | self.gamma = Scale(0) 43 | 44 | def forward(self, x): 45 | """Forward function.""" 46 | out = super(PAM, self).forward(x, x) 47 | 48 | out = self.gamma(out) + x 49 | return out 50 | 51 | 52 | class CAM(nn.Module): 53 | """Channel Attention Module (CAM)""" 54 | 55 | def __init__(self): 56 | super(CAM, self).__init__() 57 | self.gamma = Scale(0) 58 | 59 | def forward(self, x): 60 | """Forward function.""" 61 | batch_size, channels, height, width = x.size() 62 | proj_query = x.view(batch_size, channels, -1) 63 | proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) 64 | energy = torch.bmm(proj_query, proj_key) 65 | energy_new = torch.max( 66 | energy, -1, keepdim=True)[0].expand_as(energy) - energy 67 | attention = F.softmax(energy_new, dim=-1) 68 | proj_value = x.view(batch_size, channels, -1) 69 | 70 | out = torch.bmm(attention, proj_value) 71 | out = out.view(batch_size, channels, height, width) 72 | 73 | out = self.gamma(out) + x 74 | return out 75 | 76 | 77 | @HEADS.register_module() 78 | class DAHead(BaseDecodeHead): 79 | """Dual Attention Network for Scene Segmentation. 80 | 81 | This head is the implementation of `DANet 82 | `_. 83 | 84 | Args: 85 | pam_channels (int): The channels of Position Attention Module(PAM). 86 | """ 87 | 88 | def __init__(self, pam_channels, **kwargs): 89 | super(DAHead, self).__init__(**kwargs) 90 | self.pam_channels = pam_channels 91 | self.pam_in_conv = ConvModule( 92 | self.in_channels, 93 | self.channels, 94 | 3, 95 | padding=1, 96 | conv_cfg=self.conv_cfg, 97 | norm_cfg=self.norm_cfg, 98 | act_cfg=self.act_cfg) 99 | self.pam = PAM(self.channels, pam_channels) 100 | self.pam_out_conv = ConvModule( 101 | self.channels, 102 | self.channels, 103 | 3, 104 | padding=1, 105 | conv_cfg=self.conv_cfg, 106 | norm_cfg=self.norm_cfg, 107 | act_cfg=self.act_cfg) 108 | self.pam_conv_seg = nn.Conv2d( 109 | self.channels, self.num_classes, kernel_size=1) 110 | 111 | self.cam_in_conv = ConvModule( 112 | self.in_channels, 113 | self.channels, 114 | 3, 115 | padding=1, 116 | conv_cfg=self.conv_cfg, 117 | norm_cfg=self.norm_cfg, 118 | act_cfg=self.act_cfg) 119 | self.cam = CAM() 120 | self.cam_out_conv = ConvModule( 121 | self.channels, 122 | self.channels, 123 | 3, 124 | padding=1, 125 | conv_cfg=self.conv_cfg, 126 | norm_cfg=self.norm_cfg, 127 | act_cfg=self.act_cfg) 128 | self.cam_conv_seg = nn.Conv2d( 129 | self.channels, self.num_classes, kernel_size=1) 130 | 131 | def pam_cls_seg(self, feat): 132 | """PAM feature classification.""" 133 | if self.dropout is not None: 134 | feat = self.dropout(feat) 135 | output = self.pam_conv_seg(feat) 136 | return output 137 | 138 | def cam_cls_seg(self, feat): 139 | """CAM feature classification.""" 140 | if self.dropout is not None: 141 | feat = self.dropout(feat) 142 | output = self.cam_conv_seg(feat) 143 | return output 144 | 145 | def forward(self, inputs): 146 | """Forward function.""" 147 | x = self._transform_inputs(inputs) 148 | pam_feat = self.pam_in_conv(x) 149 | pam_feat = self.pam(pam_feat) 150 | 
pam_feat = self.pam_out_conv(pam_feat) 151 | pam_out = self.pam_cls_seg(pam_feat) 152 | 153 | cam_feat = self.cam_in_conv(x) 154 | cam_feat = self.cam(cam_feat) 155 | cam_feat = self.cam_out_conv(cam_feat) 156 | cam_out = self.cam_cls_seg(cam_feat) 157 | 158 | feat_sum = pam_feat + cam_feat 159 | pam_cam_out = self.cls_seg(feat_sum) 160 | 161 | return pam_cam_out, pam_out, cam_out 162 | 163 | def forward_test(self, inputs, img_metas, test_cfg): 164 | """Forward function for testing, only ``pam_cam`` is used.""" 165 | return self.forward(inputs)[0] 166 | 167 | def losses(self, seg_logit, seg_label, seg_weight=None): 168 | """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" 169 | pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit 170 | loss = dict() 171 | loss.update( 172 | add_prefix( 173 | super(DAHead, self).losses(pam_cam_seg_logit, seg_label, 174 | seg_weight), 'pam_cam')) 175 | loss.update( 176 | add_prefix( 177 | super(DAHead, self).losses(pam_seg_logit, seg_label, 178 | seg_weight), 'pam')) 179 | loss.update( 180 | add_prefix( 181 | super(DAHead, self).losses(cam_seg_logit, seg_label, 182 | seg_weight), 'cam')) 183 | return loss 184 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/daformer_head.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | import torch 7 | import torch.nn as nn 8 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 9 | 10 | from mmseg.models.decode_heads.isa_head import ISALayer 11 | from mmseg.ops import resize 12 | from ..builder import HEADS 13 | from .aspp_head import ASPPModule 14 | from .decode_head import BaseDecodeHead 15 | from .segformer_head import MLP 16 | from .sep_aspp_head import DepthwiseSeparableASPPModule 17 | 18 | 19 | class ASPPWrapper(nn.Module): 20 | 21 | def __init__(self, 22 | in_channels, 23 | channels, 24 | sep, 25 | dilations, 26 | pool, 27 | norm_cfg, 28 | act_cfg, 29 | align_corners, 30 | context_cfg=None): 31 | super(ASPPWrapper, self).__init__() 32 | assert isinstance(dilations, (list, tuple)) 33 | self.dilations = dilations 34 | self.align_corners = align_corners 35 | if pool: 36 | self.image_pool = nn.Sequential( 37 | nn.AdaptiveAvgPool2d(1), 38 | ConvModule( 39 | in_channels, 40 | channels, 41 | 1, 42 | norm_cfg=norm_cfg, 43 | act_cfg=act_cfg)) 44 | else: 45 | self.image_pool = None 46 | if context_cfg is not None: 47 | self.context_layer = build_layer(in_channels, channels, 48 | **context_cfg) 49 | else: 50 | self.context_layer = None 51 | ASPP = {True: DepthwiseSeparableASPPModule, False: ASPPModule}[sep] 52 | self.aspp_modules = ASPP( 53 | dilations=dilations, 54 | in_channels=in_channels, 55 | channels=channels, 56 | norm_cfg=norm_cfg, 57 | conv_cfg=None, 58 | act_cfg=act_cfg) 59 | self.bottleneck = ConvModule( 60 | (len(dilations) + int(pool) + int(bool(context_cfg))) * channels, 61 | channels, 62 | kernel_size=3, 63 | padding=1, 64 | norm_cfg=norm_cfg, 65 | act_cfg=act_cfg) 66 | 67 | def forward(self, x): 68 | """Forward function.""" 69 | aspp_outs = [] 70 | if self.image_pool is not None: 71 | aspp_outs.append( 72 | resize( 73 | self.image_pool(x), 74 | size=x.size()[2:], 75 | mode='bilinear', 76 | align_corners=self.align_corners)) 77 | 
if self.context_layer is not None: 78 | aspp_outs.append(self.context_layer(x)) 79 | aspp_outs.extend(self.aspp_modules(x)) 80 | aspp_outs = torch.cat(aspp_outs, dim=1) 81 | 82 | output = self.bottleneck(aspp_outs) 83 | return output 84 | 85 | 86 | def build_layer(in_channels, out_channels, type, **kwargs): 87 | if type == 'id': 88 | return nn.Identity() 89 | elif type == 'mlp': 90 | return MLP(input_dim=in_channels, embed_dim=out_channels) 91 | elif type == 'sep_conv': 92 | return DepthwiseSeparableConvModule( 93 | in_channels=in_channels, 94 | out_channels=out_channels, 95 | padding=kwargs['kernel_size'] // 2, 96 | **kwargs) 97 | elif type == 'conv': 98 | return ConvModule( 99 | in_channels=in_channels, 100 | out_channels=out_channels, 101 | padding=kwargs['kernel_size'] // 2, 102 | **kwargs) 103 | elif type == 'aspp': 104 | return ASPPWrapper( 105 | in_channels=in_channels, channels=out_channels, **kwargs) 106 | elif type == 'rawconv_and_aspp': 107 | kernel_size = kwargs.pop('kernel_size') 108 | return nn.Sequential( 109 | nn.Conv2d( 110 | in_channels=in_channels, 111 | out_channels=out_channels, 112 | kernel_size=kernel_size, 113 | padding=kernel_size // 2), 114 | ASPPWrapper( 115 | in_channels=out_channels, channels=out_channels, **kwargs)) 116 | elif type == 'isa': 117 | return ISALayer( 118 | in_channels=in_channels, channels=out_channels, **kwargs) 119 | else: 120 | raise NotImplementedError(type) 121 | 122 | 123 | @HEADS.register_module() 124 | class DAFormerHead(BaseDecodeHead): 125 | 126 | def __init__(self, **kwargs): 127 | super(DAFormerHead, self).__init__( 128 | input_transform='multiple_select', **kwargs) 129 | 130 | assert not self.align_corners 131 | decoder_params = kwargs['decoder_params'] 132 | embed_dims = decoder_params['embed_dims'] 133 | if isinstance(embed_dims, int): 134 | embed_dims = [embed_dims] * len(self.in_index) 135 | embed_cfg = decoder_params['embed_cfg'] 136 | embed_neck_cfg = decoder_params['embed_neck_cfg'] 137 | if embed_neck_cfg == 'same_as_embed_cfg': 138 | embed_neck_cfg = embed_cfg 139 | fusion_cfg = decoder_params['fusion_cfg'] 140 | for cfg in [embed_cfg, embed_neck_cfg, fusion_cfg]: 141 | if cfg is not None and 'aspp' in cfg['type']: 142 | cfg['align_corners'] = self.align_corners 143 | 144 | self.embed_layers = {} 145 | for i, in_channels, embed_dim in zip(self.in_index, self.in_channels, 146 | embed_dims): 147 | if i == self.in_index[-1]: 148 | self.embed_layers[str(i)] = build_layer( 149 | in_channels, embed_dim, **embed_neck_cfg) 150 | else: 151 | self.embed_layers[str(i)] = build_layer( 152 | in_channels, embed_dim, **embed_cfg) 153 | self.embed_layers = nn.ModuleDict(self.embed_layers) 154 | 155 | self.fuse_layer = build_layer( 156 | sum(embed_dims), self.channels, **fusion_cfg) 157 | 158 | def forward(self, inputs): 159 | x = inputs 160 | n, _, h, w = x[-1].shape 161 | # for f in x: 162 | # mmcv.print_log(f'{f.shape}', 'mmseg') 163 | 164 | os_size = x[0].size()[2:] 165 | _c = {} 166 | for i in self.in_index: 167 | # mmcv.print_log(f'{i}: {x[i].shape}', 'mmseg') 168 | _c[i] = self.embed_layers[str(i)](x[i]) 169 | if _c[i].dim() == 3: 170 | _c[i] = _c[i].permute(0, 2, 1).contiguous()\ 171 | .reshape(n, -1, x[i].shape[2], x[i].shape[3]) 172 | # mmcv.print_log(f'_c{i}: {_c[i].shape}', 'mmseg') 173 | if _c[i].size()[2:] != os_size: 174 | # mmcv.print_log(f'resize {i}', 'mmseg') 175 | _c[i] = resize( 176 | _c[i], 177 | size=os_size, 178 | mode='bilinear', 179 | align_corners=self.align_corners) 180 | 181 | x = 
self.fuse_layer(torch.cat(list(_c.values()), dim=1)) 182 | x = self.cls_seg(x) 183 | 184 | return x 185 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/dlv2_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/lhoyer/DAFormer 2 | 3 | from ..builder import HEADS 4 | from .aspp_head import ASPPModule 5 | from .decode_head import BaseDecodeHead 6 | 7 | 8 | @HEADS.register_module() 9 | class DLV2Head(BaseDecodeHead): 10 | 11 | def __init__(self, dilations=(6, 12, 18, 24), **kwargs): 12 | assert 'channels' not in kwargs 13 | assert 'dropout_ratio' not in kwargs 14 | assert 'norm_cfg' not in kwargs 15 | kwargs['channels'] = 1 16 | kwargs['dropout_ratio'] = 0 17 | kwargs['norm_cfg'] = None 18 | super(DLV2Head, self).__init__(**kwargs) 19 | del self.conv_seg 20 | assert isinstance(dilations, (list, tuple)) 21 | self.dilations = dilations 22 | self.aspp_modules = ASPPModule( 23 | dilations, 24 | self.in_channels, 25 | self.num_classes, 26 | conv_cfg=self.conv_cfg, 27 | norm_cfg=None, 28 | act_cfg=None) 29 | 30 | def forward(self, inputs): 31 | """Forward function.""" 32 | # for f in inputs: 33 | # mmcv.print_log(f'{f.shape}', 'mmseg') 34 | x = self._transform_inputs(inputs) 35 | aspp_outs = self.aspp_modules(x) 36 | out = aspp_outs[0] 37 | for i in range(len(aspp_outs) - 1): 38 | out += aspp_outs[i + 1] 39 | return out 40 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/fcn_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @HEADS.register_module() 12 | class FCNHead(BaseDecodeHead): 13 | """Fully Convolution Networks for Semantic Segmentation. 14 | 15 | This head is implemented of `FCNNet `_. 16 | 17 | Args: 18 | num_convs (int): Number of convs in the head. Default: 2. 19 | kernel_size (int): The kernel size for convs in the head. Default: 3. 20 | concat_input (bool): Whether concat the input and output of convs 21 | before classification layer. 22 | dilation (int): The dilation rate for convs in the head. Default: 1. 
23 | """ 24 | 25 | def __init__(self, 26 | num_convs=2, 27 | kernel_size=3, 28 | concat_input=True, 29 | dilation=1, 30 | **kwargs): 31 | assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) 32 | self.num_convs = num_convs 33 | self.concat_input = concat_input 34 | self.kernel_size = kernel_size 35 | super(FCNHead, self).__init__(**kwargs) 36 | if num_convs == 0: 37 | assert self.in_channels == self.channels 38 | 39 | conv_padding = (kernel_size // 2) * dilation 40 | convs = [] 41 | convs.append( 42 | ConvModule( 43 | self.in_channels, 44 | self.channels, 45 | kernel_size=kernel_size, 46 | padding=conv_padding, 47 | dilation=dilation, 48 | conv_cfg=self.conv_cfg, 49 | norm_cfg=self.norm_cfg, 50 | act_cfg=self.act_cfg)) 51 | for i in range(num_convs - 1): 52 | convs.append( 53 | ConvModule( 54 | self.channels, 55 | self.channels, 56 | kernel_size=kernel_size, 57 | padding=conv_padding, 58 | dilation=dilation, 59 | conv_cfg=self.conv_cfg, 60 | norm_cfg=self.norm_cfg, 61 | act_cfg=self.act_cfg)) 62 | if num_convs == 0: 63 | self.convs = nn.Identity() 64 | else: 65 | self.convs = nn.Sequential(*convs) 66 | if self.concat_input: 67 | self.conv_cat = ConvModule( 68 | self.in_channels + self.channels, 69 | self.channels, 70 | kernel_size=kernel_size, 71 | padding=kernel_size // 2, 72 | conv_cfg=self.conv_cfg, 73 | norm_cfg=self.norm_cfg, 74 | act_cfg=self.act_cfg) 75 | 76 | def forward(self, inputs): 77 | """Forward function.""" 78 | x = self._transform_inputs(inputs) 79 | output = self.convs(x) 80 | if self.concat_input: 81 | output = self.conv_cat(torch.cat([x, output], dim=1)) 82 | output = self.cls_seg(output) 83 | return output 84 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/proj_head.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from mmcv.cnn import ConvModule 10 | 11 | from ..builder import HEADS 12 | from .decode_head_decorator import BaseDecodeHeadDecorator 13 | 14 | 15 | @HEADS.register_module() 16 | class ProjHead(BaseDecodeHeadDecorator): 17 | """Projection Head for feature dimension reduction in contrastive loss. 18 | 19 | Args: 20 | num_convs (int): Number of convs in the head. Default: 2. 21 | kernel_size (int): The kernel size for convs in the head. Default: 3. 22 | concat_input (bool): Whether concat the input and output of convs 23 | before classification layer. 24 | dilation (int): The dilation rate for convs in the head. Default: 1. 
25 | """ 26 | def __init__(self, 27 | num_convs=2, 28 | kernel_size=1, 29 | dilation=1, 30 | **kwargs): 31 | assert num_convs in (0, 1, 2) and dilation > 0 and isinstance(dilation, int) 32 | self.num_convs = num_convs 33 | self.kernel_size = kernel_size 34 | super(ProjHead, self).__init__(**kwargs) 35 | if num_convs == 0: 36 | assert self.in_channels == self.channels 37 | 38 | conv_padding = (kernel_size // 2) * dilation 39 | if self.input_transform == 'multiple_select': 40 | convs = [[] for _ in range(len(self.in_channels))] 41 | for i in range(len(self.in_channels)): 42 | if num_convs > 1: 43 | convs[i].append( 44 | ConvModule( 45 | self.in_channels[i], 46 | self.in_channels[i], 47 | kernel_size=kernel_size, 48 | padding=conv_padding, 49 | dilation=dilation, 50 | conv_cfg=self.conv_cfg, 51 | norm_cfg=self.norm_cfg, 52 | act_cfg=self.act_cfg)) 53 | convs[i].append( 54 | ConvModule( 55 | self.in_channels[i], 56 | self.channels, 57 | kernel_size=kernel_size, 58 | padding=conv_padding, 59 | dilation=dilation, 60 | conv_cfg=self.conv_cfg, 61 | norm_cfg=self.norm_cfg, 62 | act_cfg=self.act_cfg)) 63 | if num_convs == 0: 64 | self.convs = nn.ModuleList([nn.Identity() for _ in range(len(self.in_channels))]) 65 | else: 66 | self.convs = nn.ModuleList([nn.Sequential(*convs[i]) for i in range(len(self.in_channels))]) 67 | 68 | else: 69 | if self.input_transform == 'resize_concat': 70 | self.mid_channels = self.in_channels // len(self.in_index) 71 | else: 72 | self.mid_channels = self.in_channels 73 | convs = [] 74 | if num_convs > 1: 75 | convs.append( 76 | ConvModule( 77 | self.in_channels, 78 | self.mid_channels, 79 | kernel_size=kernel_size, 80 | padding=conv_padding, 81 | dilation=dilation, 82 | conv_cfg=self.conv_cfg, 83 | norm_cfg=self.norm_cfg, 84 | act_cfg=self.act_cfg)) 85 | convs.append( 86 | ConvModule( 87 | self.mid_channels, 88 | self.channels, 89 | kernel_size=kernel_size, 90 | padding=conv_padding, 91 | dilation=dilation, 92 | conv_cfg=self.conv_cfg, 93 | norm_cfg=self.norm_cfg, 94 | act_cfg=self.act_cfg)) 95 | if num_convs == 0: 96 | self.convs = nn.Identity() 97 | else: 98 | self.convs = nn.Sequential(*convs) 99 | 100 | def forward(self, inputs): 101 | """Forward function.""" 102 | x = self._transform_inputs(inputs) 103 | if isinstance(x, list): 104 | # multiple_select 105 | output = [F.normalize(self.convs[i](x[i]), p=2, dim=1) for i in range(len(x))] 106 | else: 107 | # resize_concat or single_select 108 | output = F.normalize(self.convs(x), p=2, dim=1) 109 | return output 110 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/segformer_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/NVlabs/SegFormer 2 | # Modifications: Model construction with loop 3 | # --------------------------------------------------------------- 4 | # Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
5 | # 6 | # This work is licensed under the NVIDIA Source Code License 7 | # --------------------------------------------------------------- 8 | 9 | import torch 10 | import torch.nn as nn 11 | from mmcv.cnn import ConvModule 12 | 13 | from mmseg.ops import resize 14 | from ..builder import HEADS 15 | from .decode_head import BaseDecodeHead 16 | 17 | 18 | class MLP(nn.Module): 19 | """Linear Embedding.""" 20 | 21 | def __init__(self, input_dim=2048, embed_dim=768): 22 | super().__init__() 23 | self.proj = nn.Linear(input_dim, embed_dim) 24 | 25 | def forward(self, x): 26 | x = x.flatten(2).transpose(1, 2).contiguous() 27 | x = self.proj(x) 28 | return x 29 | 30 | 31 | @HEADS.register_module() 32 | class SegFormerHead(BaseDecodeHead): 33 | """ 34 | SegFormer: Simple and Efficient Design for Semantic Segmentation with 35 | Transformers 36 | """ 37 | 38 | def __init__(self, **kwargs): 39 | super(SegFormerHead, self).__init__( 40 | input_transform='multiple_select', **kwargs) 41 | 42 | decoder_params = kwargs['decoder_params'] 43 | embedding_dim = decoder_params['embed_dim'] 44 | conv_kernel_size = decoder_params['conv_kernel_size'] 45 | 46 | self.linear_c = {} 47 | for i, in_channels in zip(self.in_index, self.in_channels): 48 | self.linear_c[str(i)] = MLP( 49 | input_dim=in_channels, embed_dim=embedding_dim) 50 | self.linear_c = nn.ModuleDict(self.linear_c) 51 | 52 | self.linear_fuse = ConvModule( 53 | in_channels=embedding_dim * len(self.in_index), 54 | out_channels=embedding_dim, 55 | kernel_size=conv_kernel_size, 56 | padding=0 if conv_kernel_size == 1 else conv_kernel_size // 2, 57 | norm_cfg=kwargs['norm_cfg']) 58 | 59 | self.linear_pred = nn.Conv2d( 60 | embedding_dim, self.num_classes, kernel_size=1) 61 | 62 | def forward(self, inputs): 63 | x = inputs 64 | n, _, h, w = x[-1].shape 65 | # for f in x: 66 | # print(f.shape) 67 | 68 | _c = {} 69 | for i in self.in_index: 70 | # mmcv.print_log(f'{i}: {x[i].shape}, {self.linear_c[str(i)]}') 71 | _c[i] = self.linear_c[str(i)](x[i]).permute(0, 2, 1).contiguous() 72 | _c[i] = _c[i].reshape(n, -1, x[i].shape[2], x[i].shape[3]) 73 | if i != 0: 74 | _c[i] = resize( 75 | _c[i], 76 | size=x[0].size()[2:], 77 | mode='bilinear', 78 | align_corners=False) 79 | 80 | _c = self.linear_fuse(torch.cat(list(_c.values()), dim=1)) 81 | 82 | if self.dropout is not None: 83 | x = self.dropout(_c) 84 | else: 85 | x = _c 86 | x = self.linear_pred(x) 87 | 88 | return x 89 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/sep_aspp_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .aspp_head import ASPPHead, ASPPModule 10 | 11 | 12 | class DepthwiseSeparableASPPModule(ASPPModule): 13 | """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable 14 | conv.""" 15 | 16 | def __init__(self, **kwargs): 17 | super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) 18 | for i, dilation in enumerate(self.dilations): 19 | if dilation > 1: 20 | self[i] = DepthwiseSeparableConvModule( 21 | self.in_channels, 22 | self.channels, 23 | 3, 24 | dilation=dilation, 25 | padding=dilation, 26 | norm_cfg=self.norm_cfg, 27 | act_cfg=self.act_cfg) 28 | 29 | 30 | @HEADS.register_module() 31 | class 
DepthwiseSeparableASPPHead(ASPPHead): 32 | """Encoder-Decoder with Atrous Separable Convolution for Semantic Image 33 | Segmentation. 34 | 35 | This head is the implementation of `DeepLabV3+ 36 | `_. 37 | 38 | Args: 39 | c1_in_channels (int): The input channels of c1 decoder. If is 0, 40 | the no decoder will be used. 41 | c1_channels (int): The intermediate channels of c1 decoder. 42 | """ 43 | 44 | def __init__(self, c1_in_channels, c1_channels, **kwargs): 45 | super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) 46 | assert c1_in_channels >= 0 47 | self.aspp_modules = DepthwiseSeparableASPPModule( 48 | dilations=self.dilations, 49 | in_channels=self.in_channels, 50 | channels=self.channels, 51 | conv_cfg=self.conv_cfg, 52 | norm_cfg=self.norm_cfg, 53 | act_cfg=self.act_cfg) 54 | if c1_in_channels > 0: 55 | self.c1_bottleneck = ConvModule( 56 | c1_in_channels, 57 | c1_channels, 58 | 1, 59 | conv_cfg=self.conv_cfg, 60 | norm_cfg=self.norm_cfg, 61 | act_cfg=self.act_cfg) 62 | else: 63 | self.c1_bottleneck = None 64 | self.sep_bottleneck = nn.Sequential( 65 | DepthwiseSeparableConvModule( 66 | self.channels + c1_channels, 67 | self.channels, 68 | 3, 69 | padding=1, 70 | norm_cfg=self.norm_cfg, 71 | act_cfg=self.act_cfg), 72 | DepthwiseSeparableConvModule( 73 | self.channels, 74 | self.channels, 75 | 3, 76 | padding=1, 77 | norm_cfg=self.norm_cfg, 78 | act_cfg=self.act_cfg)) 79 | 80 | def forward(self, inputs): 81 | """Forward function.""" 82 | x = self._transform_inputs(inputs) 83 | aspp_outs = [ 84 | resize( 85 | self.image_pool(x), 86 | size=x.size()[2:], 87 | mode='bilinear', 88 | align_corners=self.align_corners) 89 | ] 90 | aspp_outs.extend(self.aspp_modules(x)) 91 | aspp_outs = torch.cat(aspp_outs, dim=1) 92 | output = self.bottleneck(aspp_outs) 93 | if self.c1_bottleneck is not None: 94 | c1_output = self.c1_bottleneck(inputs[0]) 95 | output = resize( 96 | input=output, 97 | size=c1_output.shape[2:], 98 | mode='bilinear', 99 | align_corners=self.align_corners) 100 | output = torch.cat([output, c1_output], dim=1) 101 | output = self.sep_bottleneck(output) 102 | output = self.cls_seg(output) 103 | return output 104 | -------------------------------------------------------------------------------- /mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import Accuracy, accuracy 2 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 3 | cross_entropy, mask_cross_entropy) 4 | from .contrastive_loss import ContrastiveLoss 5 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 6 | 7 | __all__ = [ 8 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 9 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 10 | 'weight_reduce_loss', 'weighted_loss', 'ContrastiveLoss' 11 | ] 12 | -------------------------------------------------------------------------------- /mmseg/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def accuracy(pred, target, topk=1, thresh=None): 7 | """Calculate accuracy according to the prediction and target. 8 | 9 | Args: 10 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 11 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 
12 | topk (int | tuple[int], optional): If the predictions in ``topk`` 13 | matches the target, the predictions will be regarded as 14 | correct ones. Defaults to 1. 15 | thresh (float, optional): If not None, predictions with scores under 16 | this threshold are considered incorrect. Default to None. 17 | 18 | Returns: 19 | float | tuple[float]: If the input ``topk`` is a single integer, 20 | the function will return a single float as accuracy. If 21 | ``topk`` is a tuple containing multiple integers, the 22 | function will return a tuple containing accuracies of 23 | each ``topk`` number. 24 | """ 25 | assert isinstance(topk, (int, tuple)) 26 | if isinstance(topk, int): 27 | topk = (topk, ) 28 | return_single = True 29 | else: 30 | return_single = False 31 | 32 | maxk = max(topk) 33 | if pred.size(0) == 0: 34 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 35 | return accu[0] if return_single else accu 36 | assert pred.ndim == target.ndim + 1 37 | assert pred.size(0) == target.size(0) 38 | assert maxk <= pred.size(1), \ 39 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 40 | pred_value, pred_label = pred.topk(maxk, dim=1) 41 | # transpose to shape (maxk, N, ...) 42 | pred_label = pred_label.transpose(0, 1) 43 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 44 | if thresh is not None: 45 | # Only prediction values larger than thresh are counted as correct 46 | correct = correct & (pred_value > thresh).t() 47 | res = [] 48 | for k in topk: 49 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 50 | res.append(correct_k.mul_(100.0 / target.numel())) 51 | return res[0] if return_single else res 52 | 53 | 54 | class Accuracy(nn.Module): 55 | """Accuracy calculation module.""" 56 | 57 | def __init__(self, topk=(1, ), thresh=None): 58 | """Module to calculate the accuracy. 59 | 60 | Args: 61 | topk (tuple, optional): The criterion used to calculate the 62 | accuracy. Defaults to (1,). 63 | thresh (float, optional): If not None, predictions with scores 64 | under this threshold are considered incorrect. Default to None. 65 | """ 66 | super().__init__() 67 | self.topk = topk 68 | self.thresh = thresh 69 | 70 | def forward(self, pred, target): 71 | """Forward function to calculate accuracy. 72 | 73 | Args: 74 | pred (torch.Tensor): Prediction of models. 75 | target (torch.Tensor): Target for each prediction. 76 | 77 | Returns: 78 | tuple[float]: The accuracies under different topk criterions. 79 | """ 80 | return accuracy(pred, target, self.topk, self.thresh) 81 | -------------------------------------------------------------------------------- /mmseg/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import functools 4 | 5 | import mmcv 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | 10 | def get_class_weight(class_weight): 11 | """Get class weight for loss function. 12 | 13 | Args: 14 | class_weight (list[float] | str | None): If class_weight is a str, 15 | take it as a file name and read from it. 16 | """ 17 | if isinstance(class_weight, str): 18 | # take it as a file path 19 | if class_weight.endswith('.npy'): 20 | class_weight = np.load(class_weight) 21 | else: 22 | # pkl, json or yaml 23 | class_weight = mmcv.load(class_weight) 24 | 25 | return class_weight 26 | 27 | 28 | def reduce_loss(loss, reduction): 29 | """Reduce loss as specified. 
30 | 31 | Args: 32 | loss (Tensor): Elementwise loss tensor. 33 | reduction (str): Options are "none", "mean" and "sum". 34 | 35 | Return: 36 | Tensor: Reduced loss tensor. 37 | """ 38 | reduction_enum = F._Reduction.get_enum(reduction) 39 | # none: 0, elementwise_mean:1, sum: 2 40 | if reduction_enum == 0: 41 | return loss 42 | elif reduction_enum == 1: 43 | return loss.mean() 44 | elif reduction_enum == 2: 45 | return loss.sum() 46 | 47 | 48 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 49 | """Apply element-wise weight and reduce loss. 50 | 51 | Args: 52 | loss (Tensor): Element-wise loss. 53 | weight (Tensor): Element-wise weights. 54 | reduction (str): Same as built-in losses of PyTorch. 55 | avg_factor (float): Avarage factor when computing the mean of losses. 56 | 57 | Returns: 58 | Tensor: Processed loss values. 59 | """ 60 | # if weight is specified, apply element-wise weight 61 | if weight is not None: 62 | assert weight.dim() == loss.dim() 63 | if weight.dim() > 1: 64 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 65 | loss = loss * weight 66 | 67 | # if avg_factor is not specified, just reduce the loss 68 | if avg_factor is None: 69 | loss = reduce_loss(loss, reduction) 70 | else: 71 | # if reduction is mean, then average the loss by avg_factor 72 | if reduction == 'mean': 73 | loss = loss.sum() / avg_factor 74 | # if reduction is 'none', then do nothing, otherwise raise an error 75 | elif reduction != 'none': 76 | raise ValueError('avg_factor can not be used with reduction="sum"') 77 | return loss 78 | 79 | 80 | def weighted_loss(loss_func): 81 | """Create a weighted version of a given loss function. 82 | 83 | To use this decorator, the loss function must have the signature like 84 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 85 | element-wise loss without any reduction. This decorator will add weight 86 | and reduction arguments to the function. The decorated function will have 87 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 88 | avg_factor=None, **kwargs)`. 89 | 90 | :Example: 91 | 92 | >>> import torch 93 | >>> @weighted_loss 94 | >>> def l1_loss(pred, target): 95 | >>> return (pred - target).abs() 96 | 97 | >>> pred = torch.Tensor([0, 2, 3]) 98 | >>> target = torch.Tensor([1, 1, 1]) 99 | >>> weight = torch.Tensor([1, 0, 1]) 100 | 101 | >>> l1_loss(pred, target) 102 | tensor(1.3333) 103 | >>> l1_loss(pred, target, weight) 104 | tensor(1.) 
105 | >>> l1_loss(pred, target, reduction='none') 106 | tensor([1., 1., 2.]) 107 | >>> l1_loss(pred, target, weight, avg_factor=2) 108 | tensor(1.5000) 109 | """ 110 | 111 | @functools.wraps(loss_func) 112 | def wrapper(pred, 113 | target, 114 | weight=None, 115 | reduction='mean', 116 | avg_factor=None, 117 | **kwargs): 118 | # get element-wise loss 119 | loss = loss_func(pred, target, **kwargs) 120 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 121 | return loss 122 | 123 | return wrapper 124 | -------------------------------------------------------------------------------- /mmseg/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [] 2 | -------------------------------------------------------------------------------- /mmseg/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add additional segmentors 3 | 4 | from .base import BaseSegmentor 5 | from .encoder_decoder import EncoderDecoder 6 | from .encoder_decoder_projector import EncoderDecoderProjector 7 | 8 | __all__ = ['BaseSegmentor', 'EncoderDecoder', 'EncoderDecoderProjector'] 9 | -------------------------------------------------------------------------------- /mmseg/models/uda/__init__.py: -------------------------------------------------------------------------------- 1 | from mmseg.models.uda.sepico import SePiCo 2 | from mmseg.models.uda.sepico_dark import SePiCoDark 3 | 4 | __all__ = ['SePiCo', 'SePiCoDark'] 5 | -------------------------------------------------------------------------------- /mmseg/models/uda/uda_decorator.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | # Modifications: Add target gt for visualization purpose only 6 | 7 | from copy import deepcopy 8 | 9 | from mmcv.parallel import MMDistributedDataParallel 10 | 11 | from mmseg.models import BaseSegmentor, build_segmentor 12 | 13 | 14 | def get_module(module): 15 | """Get `nn.ModuleDict` to fit the `MMDistributedDataParallel` interface. 16 | 17 | Args: 18 | module (MMDistributedDataParallel | nn.ModuleDict): The input 19 | module that needs processing. 20 | 21 | Returns: 22 | nn.ModuleDict: The ModuleDict of multiple networks. 
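    Example (illustrative sketch, not part of the original file): ``model``
    stands for any plain ``nn.Module``, and a distributed process group is
    assumed to be initialized before wrapping.
        >>> wrapped = MMDistributedDataParallel(model.cuda(), device_ids=[0])
        >>> get_module(wrapped) is model
        True
        >>> get_module(model) is model  # non-wrapped modules pass through
        True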
23 | """ 24 | if isinstance(module, MMDistributedDataParallel): 25 | return module.module 26 | 27 | return module 28 | 29 | 30 | class UDADecorator(BaseSegmentor): 31 | 32 | def __init__(self, **cfg): 33 | super(BaseSegmentor, self).__init__() 34 | 35 | self.model = build_segmentor(deepcopy(cfg['model'])) 36 | self.train_cfg = cfg['model']['train_cfg'] 37 | self.test_cfg = cfg['model']['test_cfg'] 38 | self.num_classes = cfg['model']['decode_head']['num_classes'] 39 | 40 | def get_model(self): 41 | return get_module(self.model) 42 | 43 | def extract_feat(self, img): 44 | """Extract features from images.""" 45 | return self.get_model().extract_feat(img) 46 | 47 | def encode_decode(self, img, img_metas): 48 | """Encode images with backbone and decode into a semantic segmentation 49 | map of the same size as input.""" 50 | return self.get_model().encode_decode(img, img_metas) 51 | 52 | def forward_train(self, img, img_metas, gt_semantic_seg, target_img, target_img_metas, target_gt_semantic_seg, return_feat=False): 53 | """Forward function for training. 54 | 55 | Args: 56 | img (Tensor): Input images. 57 | img_metas (list[dict]): List of image info dict where each dict 58 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 59 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 60 | For details on the values of these keys see 61 | `mmseg/datasets/pipelines/formatting.py:Collect`. 62 | gt_semantic_seg (Tensor): Semantic segmentation masks 63 | used if the architecture supports semantic segmentation task. 64 | 65 | Returns: 66 | dict[str, Tensor]: a dictionary of loss components 67 | """ 68 | losses = self.get_model().forward_train( 69 | img, img_metas, gt_semantic_seg, return_feat=return_feat) 70 | return losses 71 | 72 | def inference(self, img, img_meta, rescale): 73 | """Inference with slide/whole style. 74 | 75 | Args: 76 | img (Tensor): The input image of shape (N, 3, H, W). 77 | img_meta (dict): Image info dict where each dict has: 'img_shape', 78 | 'scale_factor', 'flip', and may also contain 79 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 80 | For details on the values of these keys see 81 | `mmseg/datasets/pipelines/formatting.py:Collect`. 82 | rescale (bool): Whether rescale back to original shape. 83 | 84 | Returns: 85 | Tensor: The output segmentation map. 86 | """ 87 | return self.get_model().inference(img, img_meta, rescale) 88 | 89 | def simple_test(self, img, img_meta, rescale=True): 90 | """Simple test with single image.""" 91 | return self.get_model().simple_test(img, img_meta, rescale) 92 | 93 | def aug_test(self, imgs, img_metas, rescale=True): 94 | """Test with augmentations. 95 | 96 | Only rescale=True is supported. 
97 | """ 98 | return self.get_model().aug_test(imgs, img_metas, rescale) 99 | -------------------------------------------------------------------------------- /mmseg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .ckpt_convert import mit_convert 2 | from .make_divisible import make_divisible 3 | from .res_layer import ResLayer 4 | from .self_attention_block import SelfAttentionBlock 5 | from .shape_convert import nchw_to_nlc, nlc_to_nchw 6 | 7 | __all__ = [ 8 | 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'mit_convert', 9 | 'nchw_to_nlc', 'nlc_to_nchw' 10 | ] 11 | -------------------------------------------------------------------------------- /mmseg/models/utils/ckpt_convert.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | 8 | def mit_convert(ckpt): 9 | new_ckpt = OrderedDict() 10 | # Process the concat between q linear weights and kv linear weights 11 | for k, v in ckpt.items(): 12 | if k.startswith('head'): 13 | continue 14 | elif k.startswith('patch_embed'): 15 | stage_i = int(k.split('.')[0].replace('patch_embed', '')) 16 | new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0') 17 | new_v = v 18 | if 'proj.' in new_k: 19 | new_k = new_k.replace('proj.', 'projection.') 20 | elif k.startswith('block'): 21 | stage_i = int(k.split('.')[0].replace('block', '')) 22 | new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1') 23 | new_v = v 24 | if 'attn.q.' in new_k: 25 | sub_item_k = k.replace('q.', 'kv.') 26 | new_k = new_k.replace('q.', 'attn.in_proj_') 27 | new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) 28 | elif 'attn.kv.' in new_k: 29 | continue 30 | elif 'attn.proj.' in new_k: 31 | new_k = new_k.replace('proj.', 'attn.out_proj.') 32 | elif 'attn.sr.' in new_k: 33 | new_k = new_k.replace('sr.', 'sr.') 34 | elif 'mlp.' in new_k: 35 | string = f'{new_k}-' 36 | new_k = new_k.replace('mlp.', 'ffn.layers.') 37 | if 'fc1.weight' in new_k or 'fc2.weight' in new_k: 38 | new_v = v.reshape((*v.shape, 1, 1)) 39 | new_k = new_k.replace('fc1.', '0.') 40 | new_k = new_k.replace('dwconv.dwconv.', '1.') 41 | new_k = new_k.replace('fc2.', '4.') 42 | string += f'{new_k} {v.shape}-{new_v.shape}' 43 | # print(string) 44 | elif k.startswith('norm'): 45 | stage_i = int(k.split('.')[0].replace('norm', '')) 46 | new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') 47 | new_v = v 48 | else: 49 | new_k = k 50 | new_v = v 51 | new_ckpt[new_k] = new_v 52 | return new_ckpt 53 | -------------------------------------------------------------------------------- /mmseg/models/utils/dacs_transforms.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/vikolss/DACS 2 | # Copyright (c) 2020 vikolss. 
Licensed under the MIT License 3 | # A copy of the license is available at resources/license_dacs 4 | 5 | import kornia 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | def strong_transform(param, data=None, target=None): 12 | assert ((data is not None) or (target is not None)) 13 | data, target = one_mix(mask=param['mix'], data=data, target=target) 14 | data, target = color_jitter( 15 | color_jitter=param['color_jitter'], 16 | s=param['color_jitter_s'], 17 | p=param['color_jitter_p'], 18 | mean=param['mean'], 19 | std=param['std'], 20 | data=data, 21 | target=target) 22 | data, target = gaussian_blur(blur=param['blur'], data=data, target=target) 23 | return data, target 24 | 25 | 26 | def get_mean_std(img_metas, dev): 27 | mean = [ 28 | torch.as_tensor(img_metas[i]['img_norm_cfg']['mean'], device=dev) 29 | for i in range(len(img_metas)) 30 | ] 31 | mean = torch.stack(mean).view(-1, 3, 1, 1) 32 | std = [ 33 | torch.as_tensor(img_metas[i]['img_norm_cfg']['std'], device=dev) 34 | for i in range(len(img_metas)) 35 | ] 36 | std = torch.stack(std).view(-1, 3, 1, 1) 37 | return mean, std 38 | 39 | 40 | def denorm(img, mean, std): 41 | return img.mul(std).add(mean) / 255.0 42 | 43 | 44 | def denorm_(img, mean, std): 45 | img.mul_(std).add_(mean).div_(255.0) 46 | 47 | 48 | def renorm_(img, mean, std): 49 | img.mul_(255.0).sub_(mean).div_(std) 50 | 51 | 52 | def color_jitter(color_jitter, mean, std, data=None, target=None, s=.25, p=.2): 53 | # s is the strength of colorjitter 54 | if not (data is None): 55 | if data.shape[1] == 3: 56 | if color_jitter > p: 57 | if isinstance(s, dict): 58 | seq = nn.Sequential(kornia.augmentation.ColorJitter(**s)) 59 | else: 60 | seq = nn.Sequential( 61 | kornia.augmentation.ColorJitter( 62 | brightness=s, contrast=s, saturation=s, hue=s)) 63 | denorm_(data, mean, std) 64 | data = seq(data) 65 | renorm_(data, mean, std) 66 | return data, target 67 | 68 | 69 | def gaussian_blur(blur, data=None, target=None): 70 | if not (data is None): 71 | if data.shape[1] == 3: 72 | if blur > 0.5: 73 | sigma = np.random.uniform(0.15, 1.15) 74 | kernel_size_y = int( 75 | np.floor( 76 | np.ceil(0.1 * data.shape[2]) - 0.5 + 77 | np.ceil(0.1 * data.shape[2]) % 2)) 78 | kernel_size_x = int( 79 | np.floor( 80 | np.ceil(0.1 * data.shape[3]) - 0.5 + 81 | np.ceil(0.1 * data.shape[3]) % 2)) 82 | kernel_size = (kernel_size_y, kernel_size_x) 83 | seq = nn.Sequential( 84 | kornia.filters.GaussianBlur2d( 85 | kernel_size=kernel_size, sigma=(sigma, sigma))) 86 | data = seq(data) 87 | return data, target 88 | 89 | 90 | def get_class_masks(labels): 91 | class_masks = [] 92 | for label in labels: 93 | classes = torch.unique(labels) 94 | nclasses = classes.shape[0] 95 | class_choice = np.random.choice( 96 | nclasses, int((nclasses + nclasses % 2) / 2), replace=False) 97 | classes = classes[torch.Tensor(class_choice).long()] 98 | class_masks.append(generate_class_mask(label, classes).unsqueeze(0)) 99 | return class_masks 100 | 101 | 102 | def generate_class_mask(label, classes): 103 | label, classes = torch.broadcast_tensors(label, 104 | classes.unsqueeze(1).unsqueeze(2)) 105 | class_mask = label.eq(classes).sum(0, keepdims=True) 106 | return class_mask 107 | 108 | 109 | def one_mix(mask, data=None, target=None): 110 | if mask is None: 111 | return data, target 112 | if not (data is None): 113 | stackedMask0, _ = torch.broadcast_tensors(mask[0], data[0]) 114 | data = (stackedMask0 * data[0] + 115 | (1 - stackedMask0) * data[1]).unsqueeze(0) 116 | if not (target is None): 
117 | stackedMask0, _ = torch.broadcast_tensors(mask[0], target[0]) 118 | target = (stackedMask0 * target[0] + 119 | (1 - stackedMask0) * target[1]).unsqueeze(0) 120 | return data, target 121 | -------------------------------------------------------------------------------- /mmseg/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 4 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 5 | """Make divisible function. 6 | 7 | This function rounds the channel number to the nearest value that can be 8 | divisible by the divisor. It is taken from the original tf repo. It ensures 9 | that all layers have a channel number that is divisible by divisor. It can 10 | be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 11 | 12 | Args: 13 | value (int): The original channel number. 14 | divisor (int): The divisor to fully divide the channel number. 15 | min_value (int): The minimum value of the output channel. 16 | Default: None, means that the minimum value equal to the divisor. 17 | min_ratio (float): The minimum ratio of the rounded channel number to 18 | the original channel number. Default: 0.9. 19 | 20 | Returns: 21 | int: The modified output channel number. 22 | """ 23 | 24 | if min_value is None: 25 | min_value = divisor 26 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than (1-min_ratio). 28 | if new_value < min_ratio * value: 29 | new_value += divisor 30 | return new_value 31 | -------------------------------------------------------------------------------- /mmseg/models/utils/ours_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from numpy import random 4 | 5 | 6 | class RandomCrop(object): 7 | """Random crop the image & seg. 8 | 9 | Args: 10 | crop_size (tuple): Expected size after cropping, (h, w). 11 | cat_max_ratio (float): The maximum ratio that single category could 12 | occupy. 13 | """ 14 | 15 | def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255): 16 | assert crop_size[0] > 0 and crop_size[1] > 0 17 | self.crop_size = crop_size 18 | self.cat_max_ratio = cat_max_ratio 19 | self.ignore_index = ignore_index 20 | 21 | def get_crop_bbox(self, img): 22 | """Randomly get a crop bounding box.""" 23 | margin_h = max(img.shape[0] - self.crop_size[0], 0) 24 | margin_w = max(img.shape[1] - self.crop_size[1], 0) 25 | offset_h = np.random.randint(0, margin_h + 1) 26 | offset_w = np.random.randint(0, margin_w + 1) 27 | crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] 28 | crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] 29 | 30 | return crop_y1, crop_y2, crop_x1, crop_x2 31 | 32 | def crop(self, img, crop_bbox): 33 | """Crop from ``img``""" 34 | crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox 35 | img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] 36 | return img 37 | 38 | def __call__(self, results): 39 | """Call function to randomly crop images, semantic segmentation maps. 40 | 41 | Args: 42 | results (dict): Result dict from loading pipeline. 43 | 44 | Returns: 45 | dict: Randomly cropped results, 'img_shape' key in result dict is 46 | updated according to crop size. 
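        Example (illustrative sketch, not part of the original file; assumes
        ``img`` is an HxWxC image array and ``gt`` an HxW ``torch.Tensor``
        label map):
            >>> crop = RandomCrop(crop_size=(640, 640), cat_max_ratio=0.75)
            >>> results = dict(img=img, gt_semantic_seg=gt,
            ...                seg_fields=['gt_semantic_seg'])
            >>> results = crop(results)  # adds 'crop_bbox', updates 'img_shape'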
47 | """ 48 | 49 | img = results['img'] 50 | if 'crop_bbox' in results: 51 | crop_bbox = results['crop_bbox'] 52 | else: 53 | crop_bbox = self.get_crop_bbox(img) 54 | 55 | best_score = -1 56 | best_crop_bbox = None 57 | # Repeat 10 times 58 | for _ in range(10): 59 | if best_score >= 0: 60 | crop_bbox = self.get_crop_bbox(img) 61 | seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox) 62 | labels, cnt = torch.unique(seg_temp, return_counts=True) 63 | cnt = cnt[labels != self.ignore_index] 64 | score = 0 65 | if len(cnt) > 1 and torch.max(cnt).item() / torch.sum(cnt).item() < self.cat_max_ratio: 66 | cnt_valid = cnt[cnt > 1] 67 | score = cnt_valid.float().log().sum().item() 68 | if score > best_score: 69 | best_score = score 70 | best_crop_bbox = crop_bbox 71 | crop_bbox = best_crop_bbox 72 | 73 | # crop the image 74 | img = self.crop(img, crop_bbox) 75 | img_shape = img.shape 76 | results['img'] = img 77 | results['img_shape'] = img_shape 78 | results['crop_bbox'] = crop_bbox 79 | 80 | # crop semantic seg 81 | for key in results.get('seg_fields', []): 82 | results[key] = self.crop(results[key], crop_bbox) 83 | 84 | return results 85 | 86 | def __repr__(self): 87 | return self.__class__.__name__ + f'(crop_size={self.crop_size})' 88 | 89 | 90 | class RandomCropNoProd(RandomCrop): 91 | """Random crop the image & seg. 92 | 93 | Args: 94 | crop_size (tuple): Expected size after cropping, (h, w). 95 | cat_max_ratio (float): The maximum ratio that single category could 96 | occupy. 97 | """ 98 | 99 | def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255): 100 | super().__init__(crop_size, cat_max_ratio, ignore_index) 101 | 102 | def __call__(self, results): 103 | """Call function to randomly crop images, semantic segmentation maps. 104 | Args: 105 | results (dict): Result dict from loading pipeline. 106 | Returns: 107 | dict: Randomly cropped results, 'img_shape' key in result dict is 108 | updated according to crop size. 109 | """ 110 | 111 | img = results['img'] 112 | if 'crop_bbox' in results: 113 | crop_bbox = results['crop_bbox'] 114 | else: 115 | crop_bbox = self.get_crop_bbox(img) 116 | if self.cat_max_ratio < 1.: 117 | # Repeat 10 times 118 | for _ in range(10): 119 | seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox) 120 | labels, cnt = torch.unique(seg_temp, return_counts=True) 121 | cnt = cnt[labels != self.ignore_index] 122 | if len(cnt) > 1 and torch.max(cnt).item() / torch.sum( 123 | cnt).item() < self.cat_max_ratio: 124 | break 125 | crop_bbox = self.get_crop_bbox(img) 126 | 127 | # crop the image 128 | img = self.crop(img, crop_bbox) 129 | img_shape = img.shape 130 | results['img'] = img 131 | results['img_shape'] = img_shape 132 | results['crop_bbox'] = crop_bbox 133 | 134 | # crop semantic seg 135 | for key in results.get('seg_fields', []): 136 | results[key] = self.crop(results[key], crop_bbox) 137 | 138 | return results 139 | -------------------------------------------------------------------------------- /mmseg/models/utils/proto_estimator.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022 BIT-DA. All rights reserved. 
3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | import torch 7 | import torch.utils.data 8 | import torch.distributed 9 | import torch.backends.cudnn 10 | from collections import deque 11 | 12 | 13 | class ProtoEstimator: 14 | def __init__(self, dim, class_num, memory_length=100, resume=""): 15 | super(ProtoEstimator, self).__init__() 16 | self.dim = dim 17 | self.class_num = class_num 18 | 19 | # init mean and covariance 20 | if resume: 21 | print("Loading checkpoint from {}".format(resume)) 22 | checkpoint = torch.load(resume, map_location=torch.device('cpu')) 23 | self.CoVariance = checkpoint['CoVariance'].cuda() 24 | self.Ave = checkpoint['Ave'].cuda() 25 | self.Amount = checkpoint['Amount'].cuda() 26 | if 'MemoryBank' in checkpoint: 27 | self.MemoryBank = checkpoint['MemoryBank'].cuda() 28 | else: 29 | self.CoVariance = torch.zeros(self.class_num, self.dim).cuda() 30 | self.Ave = torch.zeros(self.class_num, self.dim).cuda() 31 | self.Amount = torch.zeros(self.class_num).cuda() 32 | self.MemoryBank = [deque([self.Ave[cls].unsqueeze(0).detach()], maxlen=memory_length) 33 | for cls in range(self.class_num)] 34 | 35 | def update_proto(self, features, labels): 36 | """Update variance and mean 37 | 38 | Args: 39 | features (Tensor): feature map, shape [B, A, H, W] N = B*H*W 40 | labels (Tensor): shape [B, 1, H, W] 41 | """ 42 | 43 | N, A = features.size() 44 | C = self.class_num 45 | 46 | NxCxA_Features = features.view( 47 | N, 1, A 48 | ).expand( 49 | N, C, A 50 | ) 51 | 52 | onehot = torch.zeros(N, C).cuda() 53 | onehot.scatter_(1, labels.view(-1, 1), 1) 54 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 55 | 56 | features_by_sort = NxCxA_Features.mul(NxCxA_onehot) 57 | 58 | Amount_CxA = NxCxA_onehot.sum(0) 59 | Amount_CxA[Amount_CxA == 0] = 1 60 | 61 | ave_CxA = features_by_sort.sum(0) / Amount_CxA 62 | 63 | # update memory bank 64 | for cls in torch.unique(labels): 65 | self.MemoryBank[cls].append(ave_CxA[cls].unsqueeze(0).detach()) 66 | 67 | var_temp = features_by_sort - ave_CxA.expand(N, C, A).mul(NxCxA_onehot) 68 | 69 | var_temp = var_temp.pow(2).sum(0).div(Amount_CxA) 70 | 71 | sum_weight_CV = onehot.sum(0).view(C, 1).expand(C, A) 72 | 73 | weight_CV = sum_weight_CV.div( 74 | sum_weight_CV + self.Amount.view(C, 1).expand(C, A) 75 | ) 76 | 77 | weight_CV[weight_CV != weight_CV] = 0 78 | 79 | additional_CV = weight_CV.mul(1 - weight_CV).mul((self.Ave - ave_CxA).pow(2)) 80 | 81 | self.CoVariance = (self.CoVariance.mul(1 - weight_CV) + var_temp.mul( 82 | weight_CV)).detach() + additional_CV.detach() 83 | 84 | self.Ave = (self.Ave.mul(1 - weight_CV) + ave_CxA.mul(weight_CV)).detach() 85 | 86 | self.Amount = self.Amount + onehot.sum(0) 87 | 88 | def save_proto(self, path): 89 | torch.save({'CoVariance': self.CoVariance.cpu(), 90 | 'Ave': self.Ave.cpu(), 91 | 'Amount': self.Amount.cpu() 92 | }, path) 93 | -------------------------------------------------------------------------------- /mmseg/models/utils/res_layer.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from mmcv.cnn import build_conv_layer, build_norm_layer 4 | from mmcv.runner import Sequential 5 | from torch import nn as nn 6 | 7 | 8 | class ResLayer(Sequential): 9 | """ResLayer to build ResNet style backbone. 10 | 11 | Args: 12 | block (nn.Module): block used to build ResLayer. 13 | inplanes (int): inplanes of block. 
14 | planes (int): planes of block. 15 | num_blocks (int): number of blocks. 16 | stride (int): stride of the first block. Default: 1 17 | avg_down (bool): Use AvgPool instead of stride conv when 18 | downsampling in the bottleneck. Default: False 19 | conv_cfg (dict): dictionary to construct and config conv layer. 20 | Default: None 21 | norm_cfg (dict): dictionary to construct and config norm layer. 22 | Default: dict(type='BN') 23 | multi_grid (int | None): Multi grid dilation rates of last 24 | stage. Default: None 25 | contract_dilation (bool): Whether contract first dilation of each layer 26 | Default: False 27 | """ 28 | 29 | def __init__(self, 30 | block, 31 | inplanes, 32 | planes, 33 | num_blocks, 34 | stride=1, 35 | dilation=1, 36 | avg_down=False, 37 | conv_cfg=None, 38 | norm_cfg=dict(type='BN'), 39 | multi_grid=None, 40 | contract_dilation=False, 41 | **kwargs): 42 | self.block = block 43 | 44 | downsample = None 45 | if stride != 1 or inplanes != planes * block.expansion: 46 | downsample = [] 47 | conv_stride = stride 48 | if avg_down: 49 | conv_stride = 1 50 | downsample.append( 51 | nn.AvgPool2d( 52 | kernel_size=stride, 53 | stride=stride, 54 | ceil_mode=True, 55 | count_include_pad=False)) 56 | downsample.extend([ 57 | build_conv_layer( 58 | conv_cfg, 59 | inplanes, 60 | planes * block.expansion, 61 | kernel_size=1, 62 | stride=conv_stride, 63 | bias=False), 64 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 65 | ]) 66 | downsample = nn.Sequential(*downsample) 67 | 68 | layers = [] 69 | if multi_grid is None: 70 | if dilation > 1 and contract_dilation: 71 | first_dilation = dilation // 2 72 | else: 73 | first_dilation = dilation 74 | else: 75 | first_dilation = multi_grid[0] 76 | layers.append( 77 | block( 78 | inplanes=inplanes, 79 | planes=planes, 80 | stride=stride, 81 | dilation=first_dilation, 82 | downsample=downsample, 83 | conv_cfg=conv_cfg, 84 | norm_cfg=norm_cfg, 85 | **kwargs)) 86 | inplanes = planes * block.expansion 87 | for i in range(1, num_blocks): 88 | layers.append( 89 | block( 90 | inplanes=inplanes, 91 | planes=planes, 92 | stride=1, 93 | dilation=dilation if multi_grid is None else multi_grid[i], 94 | conv_cfg=conv_cfg, 95 | norm_cfg=norm_cfg, 96 | **kwargs)) 97 | super(ResLayer, self).__init__(*layers) 98 | -------------------------------------------------------------------------------- /mmseg/models/utils/self_attention_block.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | from mmcv.cnn import ConvModule, constant_init 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class SelfAttentionBlock(nn.Module): 10 | """General self-attention block/non-local block. 11 | 12 | Please refer to https://arxiv.org/abs/1706.03762 for details about key, 13 | query and value. 14 | 15 | Args: 16 | key_in_channels (int): Input channels of key feature. 17 | query_in_channels (int): Input channels of query feature. 18 | channels (int): Output channels of key/query transform. 19 | out_channels (int): Output channels. 20 | share_key_query (bool): Whether share projection weight between key 21 | and query projection. 22 | query_downsample (nn.Module): Query downsample module. 23 | key_downsample (nn.Module): Key downsample module. 24 | key_query_num_convs (int): Number of convs for key/query projection. 25 | value_num_convs (int): Number of convs for value projection. 
26 | matmul_norm (bool): Whether normalize attention map with sqrt of 27 | channels 28 | with_out (bool): Whether use out projection. 29 | conv_cfg (dict|None): Config of conv layers. 30 | norm_cfg (dict|None): Config of norm layers. 31 | act_cfg (dict|None): Config of activation layers. 32 | """ 33 | 34 | def __init__(self, key_in_channels, query_in_channels, channels, 35 | out_channels, share_key_query, query_downsample, 36 | key_downsample, key_query_num_convs, value_out_num_convs, 37 | key_query_norm, value_out_norm, matmul_norm, with_out, 38 | conv_cfg, norm_cfg, act_cfg): 39 | super(SelfAttentionBlock, self).__init__() 40 | if share_key_query: 41 | assert key_in_channels == query_in_channels 42 | self.key_in_channels = key_in_channels 43 | self.query_in_channels = query_in_channels 44 | self.out_channels = out_channels 45 | self.channels = channels 46 | self.share_key_query = share_key_query 47 | self.conv_cfg = conv_cfg 48 | self.norm_cfg = norm_cfg 49 | self.act_cfg = act_cfg 50 | self.key_project = self.build_project( 51 | key_in_channels, 52 | channels, 53 | num_convs=key_query_num_convs, 54 | use_conv_module=key_query_norm, 55 | conv_cfg=conv_cfg, 56 | norm_cfg=norm_cfg, 57 | act_cfg=act_cfg) 58 | if share_key_query: 59 | self.query_project = self.key_project 60 | else: 61 | self.query_project = self.build_project( 62 | query_in_channels, 63 | channels, 64 | num_convs=key_query_num_convs, 65 | use_conv_module=key_query_norm, 66 | conv_cfg=conv_cfg, 67 | norm_cfg=norm_cfg, 68 | act_cfg=act_cfg) 69 | self.value_project = self.build_project( 70 | key_in_channels, 71 | channels if with_out else out_channels, 72 | num_convs=value_out_num_convs, 73 | use_conv_module=value_out_norm, 74 | conv_cfg=conv_cfg, 75 | norm_cfg=norm_cfg, 76 | act_cfg=act_cfg) 77 | if with_out: 78 | self.out_project = self.build_project( 79 | channels, 80 | out_channels, 81 | num_convs=value_out_num_convs, 82 | use_conv_module=value_out_norm, 83 | conv_cfg=conv_cfg, 84 | norm_cfg=norm_cfg, 85 | act_cfg=act_cfg) 86 | else: 87 | self.out_project = None 88 | 89 | self.query_downsample = query_downsample 90 | self.key_downsample = key_downsample 91 | self.matmul_norm = matmul_norm 92 | 93 | self.init_weights() 94 | 95 | def init_weights(self): 96 | """Initialize weight of later layer.""" 97 | if self.out_project is not None: 98 | if not isinstance(self.out_project, ConvModule): 99 | constant_init(self.out_project, 0) 100 | 101 | def build_project(self, in_channels, channels, num_convs, use_conv_module, 102 | conv_cfg, norm_cfg, act_cfg): 103 | """Build projection layer for key/query/value/out.""" 104 | if use_conv_module: 105 | convs = [ 106 | ConvModule( 107 | in_channels, 108 | channels, 109 | 1, 110 | conv_cfg=conv_cfg, 111 | norm_cfg=norm_cfg, 112 | act_cfg=act_cfg) 113 | ] 114 | for _ in range(num_convs - 1): 115 | convs.append( 116 | ConvModule( 117 | channels, 118 | channels, 119 | 1, 120 | conv_cfg=conv_cfg, 121 | norm_cfg=norm_cfg, 122 | act_cfg=act_cfg)) 123 | else: 124 | convs = [nn.Conv2d(in_channels, channels, 1)] 125 | for _ in range(num_convs - 1): 126 | convs.append(nn.Conv2d(channels, channels, 1)) 127 | if len(convs) > 1: 128 | convs = nn.Sequential(*convs) 129 | else: 130 | convs = convs[0] 131 | return convs 132 | 133 | def forward(self, query_feats, key_feats): 134 | """Forward function.""" 135 | batch_size = query_feats.size(0) 136 | query = self.query_project(query_feats) 137 | if self.query_downsample is not None: 138 | query = self.query_downsample(query) 139 | query = 
query.reshape(*query.shape[:2], -1) 140 | query = query.permute(0, 2, 1).contiguous() 141 | 142 | key = self.key_project(key_feats) 143 | value = self.value_project(key_feats) 144 | if self.key_downsample is not None: 145 | key = self.key_downsample(key) 146 | value = self.key_downsample(value) 147 | key = key.reshape(*key.shape[:2], -1) 148 | value = value.reshape(*value.shape[:2], -1) 149 | value = value.permute(0, 2, 1).contiguous() 150 | 151 | sim_map = torch.matmul(query, key) 152 | if self.matmul_norm: 153 | sim_map = (self.channels**-.5) * sim_map 154 | sim_map = F.softmax(sim_map, dim=-1) 155 | 156 | context = torch.matmul(sim_map, value) 157 | context = context.permute(0, 2, 1).contiguous() 158 | context = context.reshape(batch_size, -1, *query_feats.shape[2:]) 159 | if self.out_project is not None: 160 | context = self.out_project(context) 161 | return context 162 | -------------------------------------------------------------------------------- /mmseg/models/utils/shape_convert.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 4 | def nlc_to_nchw(x, hw_shape): 5 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 6 | 7 | Args: 8 | x (Tensor): The input tensor of shape [N, L, C] before convertion. 9 | hw_shape (Sequence[int]): The height and width of output feature map. 10 | 11 | Returns: 12 | Tensor: The output tensor of shape [N, C, H, W] after convertion. 13 | """ 14 | H, W = hw_shape 15 | assert len(x.shape) == 3 16 | B, L, C = x.shape 17 | assert L == H * W, 'The seq_len doesn\'t match H, W' 18 | return x.transpose(1, 2).reshape(B, C, H, W) 19 | 20 | 21 | def nchw_to_nlc(x): 22 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 23 | 24 | Args: 25 | x (Tensor): The input tensor of shape [N, C, H, W] before convertion. 26 | 27 | Returns: 28 | Tensor: The output tensor of shape [N, L, C] after convertion. 
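    Example (illustrative round trip, not part of the original file):
        >>> import torch
        >>> x = torch.randn(2, 16, 4, 6)   # [N, C, H, W]
        >>> seq = nchw_to_nlc(x)           # shape [2, 24, 16], i.e. [N, L, C]
        >>> nlc_to_nchw(seq, (4, 6)).shape
        torch.Size([2, 16, 4, 6])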
29 | """ 30 | assert len(x.shape) == 4 31 | return x.flatten(2).transpose(1, 2).contiguous() 32 | -------------------------------------------------------------------------------- /mmseg/models/utils/visualization.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add other palettes 3 | 4 | import numpy as np 5 | import torch 6 | from matplotlib import pyplot as plt 7 | from PIL import Image 8 | 9 | Cityscapes_palette = [ 10 | 128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 11 | 153, 153, 250, 170, 30, 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 12 | 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70, 0, 60, 100, 0, 80, 100, 13 | 0, 0, 230, 119, 11, 32, 128, 192, 0, 0, 64, 128, 128, 64, 128, 0, 192, 128, 14 | 128, 192, 128, 64, 64, 0, 192, 64, 0, 64, 192, 0, 192, 192, 0, 64, 64, 128, 15 | 192, 64, 128, 64, 192, 128, 192, 192, 128, 0, 0, 64, 128, 0, 64, 0, 128, 16 | 64, 128, 128, 64, 0, 0, 192, 128, 0, 192, 0, 128, 192, 128, 128, 192, 64, 17 | 0, 64, 192, 0, 64, 64, 128, 64, 192, 128, 64, 64, 0, 192, 192, 0, 192, 64, 18 | 128, 192, 192, 128, 192, 0, 64, 64, 128, 64, 64, 0, 192, 64, 128, 192, 64, 19 | 0, 64, 192, 128, 64, 192, 0, 192, 192, 128, 192, 192, 64, 64, 64, 192, 64, 20 | 64, 64, 192, 64, 192, 192, 64, 64, 64, 192, 192, 64, 192, 64, 192, 192, 21 | 192, 192, 192, 32, 0, 0, 160, 0, 0, 32, 128, 0, 160, 128, 0, 32, 0, 128, 22 | 160, 0, 128, 32, 128, 128, 160, 128, 128, 96, 0, 0, 224, 0, 0, 96, 128, 0, 23 | 224, 128, 0, 96, 0, 128, 224, 0, 128, 96, 128, 128, 224, 128, 128, 32, 64, 24 | 0, 160, 64, 0, 32, 192, 0, 160, 192, 0, 32, 64, 128, 160, 64, 128, 32, 192, 25 | 128, 160, 192, 128, 96, 64, 0, 224, 64, 0, 96, 192, 0, 224, 192, 0, 96, 64, 26 | 128, 224, 64, 128, 96, 192, 128, 224, 192, 128, 32, 0, 64, 160, 0, 64, 32, 27 | 128, 64, 160, 128, 64, 32, 0, 192, 160, 0, 192, 32, 128, 192, 160, 128, 28 | 192, 96, 0, 64, 224, 0, 64, 96, 128, 64, 224, 128, 64, 96, 0, 192, 224, 0, 29 | 192, 96, 128, 192, 224, 128, 192, 32, 64, 64, 160, 64, 64, 32, 192, 64, 30 | 160, 192, 64, 32, 64, 192, 160, 64, 192, 32, 192, 192, 160, 192, 192, 96, 31 | 64, 64, 224, 64, 64, 96, 192, 64, 224, 192, 64, 96, 64, 192, 224, 64, 192, 32 | 96, 192, 192, 224, 192, 192, 0, 32, 0, 128, 32, 0, 0, 160, 0, 128, 160, 0, 33 | 0, 32, 128, 128, 32, 128, 0, 160, 128, 128, 160, 128, 64, 32, 0, 192, 32, 34 | 0, 64, 160, 0, 192, 160, 0, 64, 32, 128, 192, 32, 128, 64, 160, 128, 192, 35 | 160, 128, 0, 96, 0, 128, 96, 0, 0, 224, 0, 128, 224, 0, 0, 96, 128, 128, 36 | 96, 128, 0, 224, 128, 128, 224, 128, 64, 96, 0, 192, 96, 0, 64, 224, 0, 37 | 192, 224, 0, 64, 96, 128, 192, 96, 128, 64, 224, 128, 192, 224, 128, 0, 32, 38 | 64, 128, 32, 64, 0, 160, 64, 128, 160, 64, 0, 32, 192, 128, 32, 192, 0, 39 | 160, 192, 128, 160, 192, 64, 32, 64, 192, 32, 64, 64, 160, 64, 192, 160, 40 | 64, 64, 32, 192, 192, 32, 192, 64, 160, 192, 192, 160, 192, 0, 96, 64, 128, 41 | 96, 64, 0, 224, 64, 128, 224, 64, 0, 96, 192, 128, 96, 192, 0, 224, 192, 42 | 128, 224, 192, 64, 96, 64, 192, 96, 64, 64, 224, 64, 192, 224, 64, 64, 96, 43 | 192, 192, 96, 192, 64, 224, 192, 192, 224, 192, 32, 32, 0, 160, 32, 0, 32, 44 | 160, 0, 160, 160, 0, 32, 32, 128, 160, 32, 128, 32, 160, 128, 160, 160, 45 | 128, 96, 32, 0, 224, 32, 0, 96, 160, 0, 224, 160, 0, 96, 32, 128, 224, 32, 46 | 128, 96, 160, 128, 224, 160, 128, 32, 96, 0, 160, 96, 0, 32, 224, 0, 160, 47 | 224, 0, 32, 96, 128, 160, 96, 128, 32, 224, 128, 160, 224, 
128, 96, 96, 0, 48 | 224, 96, 0, 96, 224, 0, 224, 224, 0, 96, 96, 128, 224, 96, 128, 96, 224, 49 | 128, 224, 224, 128, 32, 32, 64, 160, 32, 64, 32, 160, 64, 160, 160, 64, 32, 50 | 32, 192, 160, 32, 192, 32, 160, 192, 160, 160, 192, 96, 32, 64, 224, 32, 51 | 64, 96, 160, 64, 224, 160, 64, 96, 32, 192, 224, 32, 192, 96, 160, 192, 52 | 224, 160, 192, 32, 96, 64, 160, 96, 64, 32, 224, 64, 160, 224, 64, 32, 96, 53 | 192, 160, 96, 192, 32, 224, 192, 160, 224, 192, 96, 96, 64, 224, 96, 64, 54 | 96, 224, 64, 224, 224, 64, 96, 96, 192, 224, 96, 192, 96, 224, 192, 0, 0, 0 55 | ] 56 | 57 | Cityscapes_palette_16 = Cityscapes_palette[0: 9 * 3] + Cityscapes_palette[10 * 3: 14 * 3] + \ 58 | Cityscapes_palette[15 * 3: 16 * 3] + Cityscapes_palette[17 * 3:] 59 | 60 | Cityscapes_palette_13 = Cityscapes_palette[0: 3 * 3] + Cityscapes_palette[6 * 3: 9 * 3] + \ 61 | Cityscapes_palette[10 * 3: 14 * 3] + Cityscapes_palette[15 * 3: 16 * 3] + \ 62 | Cityscapes_palette[17 * 3:] 63 | 64 | 65 | def colorize_mask(mask, palette): 66 | zero_pad = 256 * 3 - len(palette) 67 | for i in range(zero_pad): 68 | palette.append(0) 69 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 70 | new_mask.putpalette(palette) 71 | return new_mask 72 | 73 | 74 | def _colorize(img, cmap, mask_zero=False): 75 | vmin = np.min(img) 76 | vmax = np.max(img) 77 | mask = (img <= 0).squeeze() 78 | cm = plt.get_cmap(cmap) 79 | colored_image = cm(np.clip(img.squeeze(), vmin, vmax) / vmax)[:, :, :3] 80 | # Use white if no depth is available (<= 0) 81 | if mask_zero: 82 | colored_image[mask, :] = [1, 1, 1] 83 | return colored_image 84 | 85 | 86 | def subplotimg(ax, 87 | img, 88 | title, 89 | range_in_title=False, 90 | palette=Cityscapes_palette, 91 | nc=19, # num_classes 92 | **kwargs): 93 | if img is None: 94 | return 95 | if nc == 16: 96 | palette = Cityscapes_palette_16 97 | elif nc == 13: 98 | palette = Cityscapes_palette_13 99 | with torch.no_grad(): 100 | if torch.is_tensor(img): 101 | img = img.cpu() 102 | if len(img.shape) == 2: 103 | if torch.is_tensor(img): 104 | img = img.numpy() 105 | elif img.shape[0] == 1: 106 | if torch.is_tensor(img): 107 | img = img.numpy() 108 | img = img.squeeze(0) 109 | elif img.shape[0] == 3: 110 | img = img.permute(1, 2, 0) 111 | if not torch.is_tensor(img): 112 | img = img.numpy() 113 | if kwargs.get('cmap', '') == 'cityscapes': 114 | kwargs.pop('cmap') 115 | if torch.is_tensor(img): 116 | img = img.numpy() 117 | img = colorize_mask(img, palette) 118 | 119 | if range_in_title: 120 | vmin = np.min(img) 121 | vmax = np.max(img) 122 | title += f' {vmin:.3f}-{vmax:.3f}' 123 | 124 | ax.imshow(img, **kwargs) 125 | ax.set_title(title) 126 | -------------------------------------------------------------------------------- /mmseg/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoding import Encoding 2 | from .wrappers import Upsample, resize 3 | 4 | __all__ = ['Upsample', 'resize', 'Encoding'] 5 | -------------------------------------------------------------------------------- /mmseg/ops/encoding.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class Encoding(nn.Module): 9 | """Encoding Layer: a learnable residual encoder. 10 | 11 | Input is of shape (batch_size, channels, height, width). 
12 | Output is of shape (batch_size, num_codes, channels). 13 | 14 | Args: 15 | channels: dimension of the features or feature channels 16 | num_codes: number of code words 17 | """ 18 | 19 | def __init__(self, channels, num_codes): 20 | super(Encoding, self).__init__() 21 | # init codewords and smoothing factor 22 | self.channels, self.num_codes = channels, num_codes 23 | std = 1. / ((num_codes * channels)**0.5) 24 | # [num_codes, channels] 25 | self.codewords = nn.Parameter( 26 | torch.empty(num_codes, channels, 27 | dtype=torch.float).uniform_(-std, std), 28 | requires_grad=True) 29 | # [num_codes] 30 | self.scale = nn.Parameter( 31 | torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), 32 | requires_grad=True) 33 | 34 | @staticmethod 35 | def scaled_l2(x, codewords, scale): 36 | num_codes, channels = codewords.size() 37 | batch_size = x.size(0) 38 | reshaped_scale = scale.view((1, 1, num_codes)) 39 | expanded_x = x.unsqueeze(2).expand( 40 | (batch_size, x.size(1), num_codes, channels)) 41 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 42 | 43 | scaled_l2_norm = reshaped_scale * ( 44 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 45 | return scaled_l2_norm 46 | 47 | @staticmethod 48 | def aggregate(assignment_weights, x, codewords): 49 | num_codes, channels = codewords.size() 50 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 51 | batch_size = x.size(0) 52 | 53 | expanded_x = x.unsqueeze(2).expand( 54 | (batch_size, x.size(1), num_codes, channels)) 55 | encoded_feat = (assignment_weights.unsqueeze(3) * 56 | (expanded_x - reshaped_codewords)).sum(dim=1) 57 | return encoded_feat 58 | 59 | def forward(self, x): 60 | assert x.dim() == 4 and x.size(1) == self.channels 61 | # [batch_size, channels, height, width] 62 | batch_size = x.size(0) 63 | # [batch_size, height x width, channels] 64 | x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() 65 | # assignment_weights: [batch_size, channels, num_codes] 66 | assignment_weights = F.softmax( 67 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 68 | # aggregate 69 | encoded_feat = self.aggregate(assignment_weights, x, self.codewords) 70 | return encoded_feat 71 | 72 | def __repr__(self): 73 | repr_str = self.__class__.__name__ 74 | repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ 75 | f'x{self.channels})' 76 | return repr_str 77 | -------------------------------------------------------------------------------- /mmseg/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import warnings 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def resize(input, 10 | size=None, 11 | scale_factor=None, 12 | mode='nearest', 13 | align_corners=None, 14 | warning=True): 15 | if warning: 16 | if size is not None and align_corners: 17 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 18 | output_h, output_w = tuple(int(x) for x in size) 19 | if output_h > input_h or output_w > output_h: 20 | if ((output_h > 1 and output_w > 1 and input_h > 1 21 | and input_w > 1) and (output_h - 1) % (input_h - 1) 22 | and (output_w - 1) % (input_w - 1)): 23 | warnings.warn( 24 | f'When align_corners={align_corners}, ' 25 | 'the output would more aligned if ' 26 | f'input size {(input_h, input_w)} is `x+1` and ' 27 | f'out size {(output_h, output_w)} is `nx+1`') 28 | return F.interpolate(input, size, scale_factor, mode, 
align_corners) 29 | 30 | 31 | class Upsample(nn.Module): 32 | 33 | def __init__(self, 34 | size=None, 35 | scale_factor=None, 36 | mode='nearest', 37 | align_corners=None): 38 | super(Upsample, self).__init__() 39 | self.size = size 40 | if isinstance(scale_factor, tuple): 41 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 42 | else: 43 | self.scale_factor = float(scale_factor) if scale_factor else None 44 | self.mode = mode 45 | self.align_corners = align_corners 46 | 47 | def forward(self, x): 48 | if not self.size: 49 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 50 | else: 51 | size = self.size 52 | return resize(x, size, None, self.mode, self.align_corners) 53 | -------------------------------------------------------------------------------- /mmseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .collect_env import collect_env 2 | from .logger import get_root_logger 3 | 4 | __all__ = ['get_root_logger', 'collect_env'] 5 | -------------------------------------------------------------------------------- /mmseg/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add code archive generation 3 | 4 | import os 5 | import tarfile 6 | 7 | from mmcv.utils import collect_env as collect_base_env 8 | from mmcv.utils import get_git_hash 9 | 10 | import mmseg 11 | 12 | 13 | def collect_env(): 14 | """Collect the information of the running environments.""" 15 | env_info = collect_base_env() 16 | env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' 17 | 18 | return env_info 19 | 20 | 21 | def is_source_file(x): 22 | if x.isdir() or x.name.endswith(('.py', '.sh', '.yml', '.json', '.txt')) \ 23 | and '.mim' not in x.name and 'jobs/' not in x.name: 24 | # print(x.name) 25 | return x 26 | else: 27 | return None 28 | 29 | 30 | def gen_code_archive(out_dir, file='code.tar.gz'): 31 | archive = os.path.join(out_dir, file) 32 | os.makedirs(os.path.dirname(archive), exist_ok=True) 33 | with tarfile.open(archive, mode='w:gz') as tar: 34 | tar.add('.', filter=is_source_file) 35 | return archive 36 | 37 | 38 | if __name__ == '__main__': 39 | for name, val in collect_env().items(): 40 | print('{}: {}'.format(name, val)) 41 | -------------------------------------------------------------------------------- /mmseg/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import logging 4 | 5 | from mmcv.utils import get_logger 6 | 7 | 8 | def get_root_logger(log_file=None, log_level=logging.INFO): 9 | """Get the root logger. 10 | 11 | The logger will be initialized if it has not been initialized. By default a 12 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 13 | also be added. The name of the root logger is the top-level package name, 14 | e.g., "mmseg". 15 | 16 | Args: 17 | log_file (str | None): The log filename. If specified, a FileHandler 18 | will be added to the root logger. 19 | log_level (int): The root logger level. Note that only the process of 20 | rank 0 is affected, while other processes will set the level to 21 | "Error" and be silent most of the time. 22 | 23 | Returns: 24 | logging.Logger: The root logger. 
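    Example (illustrative sketch, not part of the original file; the log path
    is a hypothetical placeholder):
        >>> logger = get_root_logger(log_file='work_dirs/example/run.log',
        ...                          log_level=logging.INFO)
        >>> logger.info('training started')  # echoed to console and the file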
25 | """ 26 | 27 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 28 | 29 | return logger 30 | -------------------------------------------------------------------------------- /mmseg/utils/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | @contextlib.contextmanager 9 | def np_local_seed(seed): 10 | state = np.random.get_state() 11 | np.random.seed(seed) 12 | try: 13 | yield 14 | finally: 15 | np.random.set_state(state) 16 | -------------------------------------------------------------------------------- /mmseg/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | __version__ = '0.16.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cityscapesScripts==2.2.0 2 | kornia==0.5.8 3 | matplotlib==3.4.2 4 | mmcv_full==1.3.7 5 | numpy==1.19.2 6 | Pillow==9.2.0 7 | prettytable==2.1.0 8 | seaborn==0.11.1 9 | setproctitle==1.2.3 10 | timm==0.3.2 11 | torch==1.7.1+cu110 12 | torchvision==0.8.2+cu110 -------------------------------------------------------------------------------- /resources/cs2dz_generalization_per_class_results/bdd100k_night_test.txt: -------------------------------------------------------------------------------- 1 | per class results: 2 | 3 | +---------------+-------+-------+ 4 | | Class | IoU | Acc | 5 | +---------------+-------+-------+ 6 | | road | 87.26 | 91.58 | 7 | | sidewalk | 48.28 | 82.67 | 8 | | building | 80.23 | 92.64 | 9 | | wall | 3.28 | 4.55 | 10 | | fence | 12.23 | 34.08 | 11 | | pole | 37.86 | 46.74 | 12 | | traffic light | 20.11 | 32.42 | 13 | | traffic sign | 51.41 | 64.53 | 14 | | vegetation | 47.57 | 72.16 | 15 | | terrain | 20.48 | 30.68 | 16 | | sky | 65.49 | 71.41 | 17 | | person | 67.56 | 81.32 | 18 | | rider | 67.05 | 94.47 | 19 | | car | 83.74 | 88.56 | 20 | | truck | 29.93 | 32.56 | 21 | | bus | 46.33 | 77.0 | 22 | | train | 0.0 | nan | 23 | | motorcycle | 0.0 | 0.0 | 24 | | bicycle | 1.92 | 70.44 | 25 | +---------------+-------+-------+ 26 | Summary: 27 | 28 | +------+-------+-------+ 29 | | aAcc | mIoU | mAcc | 30 | +------+-------+-------+ 31 | | 85.1 | 40.56 | 59.32 | 32 | +------+-------+-------+ -------------------------------------------------------------------------------- /resources/cs2dz_generalization_per_class_results/dark_zurich_test.txt: -------------------------------------------------------------------------------- 1 | IoU road : 93.23 2 | IoU sidewalk : 68.13 3 | IoU building : 73.71 4 | IoU wall : 32.82 5 | IoU fence : 16.27 6 | IoU pole : 54.58 7 | IoU traffic light : 49.49 8 | IoU traffic sign : 48.10 9 | IoU vegetation : 74.18 10 | IoU terrain : 31.00 11 | IoU sky : 86.27 12 | IoU person : 57.92 13 | IoU rider : 50.93 14 | IoU car : 82.38 15 | IoU truck : 52.23 16 | IoU bus : 1.34 17 | IoU train : 83.77 18 | 
IoU motorcycle : 43.90 19 | IoU bicycle : 29.77 20 | ------------------------------------ 21 | Mean IoU over 19 classes: 54.21 -------------------------------------------------------------------------------- /resources/cs2dz_generalization_per_class_results/night_driving_test.txt: -------------------------------------------------------------------------------- 1 | per class results: 2 | 3 | +---------------+-------+-------+ 4 | | Class | IoU | Acc | 5 | +---------------+-------+-------+ 6 | | road | 93.02 | 98.23 | 7 | | sidewalk | 73.66 | 79.11 | 8 | | building | 90.45 | 91.57 | 9 | | wall | 54.75 | 74.78 | 10 | | fence | 0.05 | 0.64 | 11 | | pole | 67.51 | 78.89 | 12 | | traffic light | 80.49 | 85.46 | 13 | | traffic sign | 82.18 | 87.77 | 14 | | vegetation | 67.24 | 84.3 | 15 | | terrain | 0.0 | nan | 16 | | sky | 58.5 | 98.4 | 17 | | person | 62.2 | 66.17 | 18 | | rider | 35.98 | 68.4 | 19 | | car | 74.95 | 87.77 | 20 | | truck | 18.25 | 18.25 | 21 | | bus | 94.61 | 95.56 | 22 | | train | 92.83 | 96.27 | 23 | | motorcycle | 0.0 | nan | 24 | | bicycle | 34.49 | 69.82 | 25 | +---------------+-------+-------+ 26 | Summary: 27 | 28 | +-------+------+-------+ 29 | | aAcc | mIoU | mAcc | 30 | +-------+------+-------+ 31 | | 92.26 | 56.9 | 75.38 | 32 | +-------+------+-------+ -------------------------------------------------------------------------------- /resources/esi_highly_cited.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/SePiCo/bce38edd0425287bff215b3ee479497745a8bbcc/resources/esi_highly_cited.png -------------------------------------------------------------------------------- /resources/license_dacs: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 vikolss 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /resources/license_segformer: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code License for SegFormer 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Software” means the original work of authorship made available under this License. 8 | 9 | “Work” means the Software and any additions to or derivative works of the Software that are made available under 10 | this License. 
11 | 12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under 13 | U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include 14 | works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 15 | 16 | Works, including the Software, are “made available” under this License by including in or with the Work either 17 | (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 18 | 19 | 2. License Grant 20 | 21 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, 22 | worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly 23 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 24 | 25 | 3. Limitations 26 | 27 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you 28 | include a complete copy of this License with your distribution, and (c) you retain without modification any 29 | copyright, patent, trademark, or attribution notices that are present in the Work. 30 | 31 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and 32 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use 33 | limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works 34 | that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution 35 | requirements in Section 3.1) will continue to apply to the Work itself. 36 | 37 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use 38 | non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative 39 | works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 40 | 41 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, 42 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then 43 | your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 44 | 45 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, 46 | or trademarks, except as necessary to reproduce the notices described in this License. 47 | 48 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the 49 | grant in Section 2.1) will terminate immediately. 50 | 51 | 4. Disclaimer of Warranty. 52 | 53 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING 54 | WARRANTIES OR CONDITIONS OF M ERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU 55 | BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 56 | 57 | 5. Limitation of Liability. 
58 | 59 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING 60 | NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 61 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR 62 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR 63 | DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMM ERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN 64 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 65 | -------------------------------------------------------------------------------- /resources/uda_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/SePiCo/bce38edd0425287bff215b3ee479497745a8bbcc/resources/uda_results.png -------------------------------------------------------------------------------- /run_experiments.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | # Modifications: Change process name and other fixes 6 | 7 | import argparse 8 | import json 9 | import os 10 | import subprocess 11 | import uuid 12 | from datetime import datetime 13 | 14 | import torch 15 | from experiments import generate_experiment_cfgs 16 | from mmcv import Config, get_git_hash 17 | from tools import train 18 | 19 | import matplotlib 20 | import warnings 21 | import setproctitle 22 | 23 | matplotlib.use('Agg') 24 | warnings.filterwarnings('ignore') 25 | 26 | 27 | def run_command(command): 28 | p = subprocess.Popen( 29 | command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True) 30 | for line in iter(p.stdout.readline, b''): 31 | print(line.decode('utf-8'), end='') 32 | 33 | 34 | def rsync(src, dst): 35 | rsync_cmd = f'rsync -a {src} {dst}' 36 | print(rsync_cmd) 37 | run_command(rsync_cmd) 38 | 39 | 40 | if __name__ == '__main__': 41 | parser = argparse.ArgumentParser() 42 | group = parser.add_mutually_exclusive_group(required=True) 43 | group.add_argument( 44 | '--exp', 45 | type=int, 46 | default=None, 47 | help='Experiment id as defined in experiment.py', 48 | ) 49 | group.add_argument( 50 | '--config', 51 | default=None, 52 | help='Path to config file', 53 | ) 54 | parser.add_argument( 55 | '--machine', type=str, choices=['local'], default='local') 56 | parser.add_argument('--debug', action='store_true') 57 | args = parser.parse_args() 58 | assert (args.config is None) != (args.exp is None), \ 59 | 'Either config or exp has to be defined.' 
60 | 61 | GEN_CONFIG_DIR = 'configs/generated/' 62 | JOB_DIR = 'jobs' 63 | cfgs, config_files = [], [] 64 | 65 | # Training with Predefined Config 66 | if args.config is not None: 67 | setproctitle.setproctitle(f'CFG: {args.config}') 68 | cfg = Config.fromfile(args.config) 69 | # Specify Name and Work Directory 70 | exp_name = f'{args.machine}-{cfg["exp"]}' 71 | unique_name = f'{datetime.now().strftime("%y%m%d_%H%M")}_' \ 72 | f'{cfg["name"]}_{str(uuid.uuid4())[:5]}' 73 | child_cfg = { 74 | '_base_': args.config.replace('configs', '../..'), 75 | 'name': unique_name, 76 | 'work_dir': os.path.join('work_dirs', exp_name, unique_name), 77 | 'git_rev': get_git_hash() 78 | } 79 | cfg_out_file = f"{GEN_CONFIG_DIR}/{exp_name}/{child_cfg['name']}.json" 80 | os.makedirs(os.path.dirname(cfg_out_file), exist_ok=True) 81 | assert not os.path.isfile(cfg_out_file) 82 | with open(cfg_out_file, 'w') as of: 83 | json.dump(child_cfg, of, indent=4) 84 | config_files.append(cfg_out_file) 85 | cfgs.append(cfg) 86 | 87 | # Training with Generated Configs from experiments.py 88 | if args.exp is not None: 89 | setproctitle.setproctitle(f'SePiCo EXP: {args.exp}') 90 | exp_name = f'{args.machine}-exp{args.exp}' 91 | cfgs = generate_experiment_cfgs(args.exp) 92 | # Generate Configs 93 | for i, cfg in enumerate(cfgs): 94 | if args.debug: 95 | cfg.setdefault('log_config', {})['interval'] = 10 96 | cfg['evaluation'] = dict(interval=200, metric='mIoU') 97 | if 'dacs' in cfg['name']: 98 | cfg.setdefault('uda', {})['debug_img_interval'] = 10 99 | # cfg.setdefault('uda', {})['print_grad_magnitude'] = True 100 | # Generate Config File 101 | cfg['name'] = f'{datetime.now().strftime("%y%m%d_%H%M")}_' \ 102 | f'{cfg["name"]}_{str(uuid.uuid4())[:5]}' 103 | cfg['work_dir'] = os.path.join('work_dirs', exp_name, cfg['name']) 104 | cfg['git_rev'] = get_git_hash() 105 | cfg['_base_'] = ['../../' + e for e in cfg['_base_']] 106 | cfg_out_file = f"{GEN_CONFIG_DIR}/{exp_name}/{cfg['name']}.json" 107 | os.makedirs(os.path.dirname(cfg_out_file), exist_ok=True) 108 | assert not os.path.isfile(cfg_out_file) 109 | with open(cfg_out_file, 'w') as of: 110 | json.dump(cfg, of, indent=4) 111 | config_files.append(cfg_out_file) 112 | 113 | if args.machine == 'local': 114 | for i, cfg in enumerate(cfgs): 115 | print('Run job {}'.format(cfg['name'])) 116 | train.main([config_files[i]]) 117 | torch.cuda.empty_cache() 118 | else: 119 | raise NotImplementedError(args.machine) 120 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Update known_third_party 3 | 4 | [yapf] 5 | based_on_style = pep8 6 | blank_line_before_nested_class_or_def = true 7 | split_before_expression_after_opening_paren = true 8 | 9 | [isort] 10 | line_length = 79 11 | multi_line_output = 0 12 | known_standard_library = setuptools 13 | known_first_party = mmseg 14 | known_third_party = PIL,cityscapesscripts,cv2,kornia,matplotlib,mmcv,numpy,prettytable,seaborn,timm,torch,tqdm 15 | no_lines_before = STDLIB,LOCALFOLDER 16 | default_section = THIRDPARTY 17 | -------------------------------------------------------------------------------- /tools/analyze_logs.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | """Modified from https://github.com/open- 3 
| mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" 4 | import argparse 5 | import json 6 | from collections import defaultdict 7 | 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | 11 | 12 | def plot_curve(log_dicts, args): 13 | if args.backend is not None: 14 | plt.switch_backend(args.backend) 15 | sns.set_style(args.style) 16 | # if legend is None, use {filename}_{key} as legend 17 | legend = args.legend 18 | if legend is None: 19 | legend = [] 20 | for json_log in args.json_logs: 21 | for metric in args.keys: 22 | legend.append(f'{json_log}_{metric}') 23 | assert len(legend) == (len(args.json_logs) * len(args.keys)) 24 | metrics = args.keys 25 | 26 | num_metrics = len(metrics) 27 | for i, log_dict in enumerate(log_dicts): 28 | epochs = list(log_dict.keys()) 29 | for j, metric in enumerate(metrics): 30 | print(f'plot curve of {args.json_logs[i]}, metric is {metric}') 31 | plot_epochs = [] 32 | plot_iters = [] 33 | plot_values = [] 34 | # In some log files, iters number is not correct, `pre_iter` is 35 | # used to prevent generate wrong lines. 36 | pre_iter = -1 37 | for epoch in epochs: 38 | epoch_logs = log_dict[epoch] 39 | if metric not in epoch_logs.keys(): 40 | continue 41 | if metric in ['mIoU', 'mAcc', 'aAcc']: 42 | plot_epochs.append(epoch) 43 | plot_values.append(epoch_logs[metric][0]) 44 | else: 45 | for idx in range(len(epoch_logs[metric])): 46 | if pre_iter > epoch_logs['iter'][idx]: 47 | continue 48 | pre_iter = epoch_logs['iter'][idx] 49 | plot_iters.append(epoch_logs['iter'][idx]) 50 | plot_values.append(epoch_logs[metric][idx]) 51 | ax = plt.gca() 52 | label = legend[i * num_metrics + j] 53 | if metric in ['mIoU', 'mAcc', 'aAcc']: 54 | ax.set_xticks(plot_epochs) 55 | plt.xlabel('epoch') 56 | plt.plot(plot_epochs, plot_values, label=label, marker='o') 57 | else: 58 | plt.xlabel('iter') 59 | plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) 60 | plt.legend() 61 | if args.title is not None: 62 | plt.title(args.title) 63 | if args.out is None: 64 | plt.show() 65 | else: 66 | print(f'save curve to: {args.out}') 67 | plt.savefig(args.out) 68 | plt.cla() 69 | 70 | 71 | def parse_args(): 72 | parser = argparse.ArgumentParser(description='Analyze Json Log') 73 | parser.add_argument( 74 | 'json_logs', 75 | type=str, 76 | nargs='+', 77 | help='path of train log in json format') 78 | parser.add_argument( 79 | '--keys', 80 | type=str, 81 | nargs='+', 82 | default=['mIoU'], 83 | help='the metric that you want to plot') 84 | parser.add_argument('--title', type=str, help='title of figure') 85 | parser.add_argument( 86 | '--legend', 87 | type=str, 88 | nargs='+', 89 | default=None, 90 | help='legend of each plot') 91 | parser.add_argument( 92 | '--backend', type=str, default=None, help='backend of plt') 93 | parser.add_argument( 94 | '--style', type=str, default='dark', help='style of plt') 95 | parser.add_argument('--out', type=str, default=None) 96 | args = parser.parse_args() 97 | return args 98 | 99 | 100 | def load_json_logs(json_logs): 101 | # load and convert json_logs to log_dict, key is epoch, value is a sub dict 102 | # keys of sub dict is different metrics 103 | # value of sub dict is a list of corresponding values of all iterations 104 | log_dicts = [dict() for _ in json_logs] 105 | for json_log, log_dict in zip(json_logs, log_dicts): 106 | with open(json_log, 'r') as log_file: 107 | for line in log_file: 108 | log = json.loads(line.strip()) 109 | # skip lines without `epoch` field 110 | if 'epoch' not in log: 111 | continue 
112 | epoch = log.pop('epoch') 113 | if epoch not in log_dict: 114 | log_dict[epoch] = defaultdict(list) 115 | for k, v in log.items(): 116 | log_dict[epoch][k].append(v) 117 | return log_dicts 118 | 119 | 120 | def main(): 121 | args = parse_args() 122 | json_logs = args.json_logs 123 | for json_log in json_logs: 124 | assert json_log.endswith('.json') 125 | log_dicts = load_json_logs(json_logs) 126 | plot_curve(log_dicts, args) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /tools/convert_datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 7 | # Modifications: Add class stats computation 8 | 9 | import argparse 10 | import json 11 | import os.path as osp 12 | 13 | import mmcv 14 | import numpy as np 15 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 16 | from PIL import Image 17 | 18 | 19 | def convert_json_to_label(json_file): 20 | label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') 21 | json2labelImg(json_file, label_file, 'trainIds') 22 | 23 | if 'train/' in json_file: 24 | pil_label = Image.open(label_file) 25 | label = np.asarray(pil_label) 26 | sample_class_stats = {} 27 | for c in range(19): 28 | n = int(np.sum(label == c)) 29 | if n > 0: 30 | sample_class_stats[int(c)] = n 31 | sample_class_stats['file'] = label_file 32 | return sample_class_stats 33 | else: 34 | return None 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser( 39 | description='Convert Cityscapes annotations to TrainIds') 40 | parser.add_argument('cityscapes_path', help='cityscapes data path') 41 | parser.add_argument('--gt-dir', default='gtFine', type=str) 42 | parser.add_argument('-o', '--out-dir', help='output path') 43 | parser.add_argument( 44 | '--nproc', default=1, type=int, help='number of process') 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def save_class_stats(out_dir, sample_class_stats): 50 | sample_class_stats = [e for e in sample_class_stats if e is not None] 51 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'w') as of: 52 | json.dump(sample_class_stats, of, indent=2) 53 | 54 | sample_class_stats_dict = {} 55 | for stats in sample_class_stats: 56 | f = stats.pop('file') 57 | sample_class_stats_dict[f] = stats 58 | with open(osp.join(out_dir, 'sample_class_stats_dict.json'), 'w') as of: 59 | json.dump(sample_class_stats_dict, of, indent=2) 60 | 61 | samples_with_class = {} 62 | for file, stats in sample_class_stats_dict.items(): 63 | for c, n in stats.items(): 64 | if c not in samples_with_class: 65 | samples_with_class[c] = [(file, n)] 66 | else: 67 | samples_with_class[c].append((file, n)) 68 | with open(osp.join(out_dir, 'samples_with_class.json'), 'w') as of: 69 | json.dump(samples_with_class, of, indent=2) 70 | 71 | 72 | def main(): 73 | args = parse_args() 74 | cityscapes_path = args.cityscapes_path 75 | out_dir = args.out_dir if args.out_dir else cityscapes_path 76 | mmcv.mkdir_or_exist(out_dir) 77 | 78 | gt_dir = osp.join(cityscapes_path, args.gt_dir) 79 | 80 | poly_files = [] 81 | for poly in 
mmcv.scandir(gt_dir, '_polygons.json', recursive=True): 82 | poly_file = osp.join(gt_dir, poly) 83 | poly_files.append(poly_file) 84 | 85 | only_postprocessing = False 86 | if not only_postprocessing: 87 | if args.nproc > 1: 88 | sample_class_stats = mmcv.track_parallel_progress( 89 | convert_json_to_label, poly_files, args.nproc) 90 | else: 91 | sample_class_stats = mmcv.track_progress(convert_json_to_label, 92 | poly_files) 93 | else: 94 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'r') as of: 95 | sample_class_stats = json.load(of) 96 | 97 | save_class_stats(out_dir, sample_class_stats) 98 | 99 | split_names = ['train', 'val', 'test'] 100 | 101 | for split in split_names: 102 | filenames = [] 103 | for poly in mmcv.scandir( 104 | osp.join(gt_dir, split), '_polygons.json', recursive=True): 105 | filenames.append(poly.replace('_gtFine_polygons.json', '')) 106 | with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: 107 | f.writelines(f + '\n' for f in filenames) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /tools/convert_datasets/gta.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 4 | # --------------------------------------------------------------- 5 | 6 | import argparse 7 | import json 8 | import os.path as osp 9 | 10 | import mmcv 11 | import numpy as np 12 | from PIL import Image 13 | 14 | 15 | def convert_to_train_id(file): 16 | # re-assign labels to match the format of Cityscapes 17 | pil_label = Image.open(file) 18 | label = np.asarray(pil_label) 19 | id_to_trainid = { 20 | 7: 0, 21 | 8: 1, 22 | 11: 2, 23 | 12: 3, 24 | 13: 4, 25 | 17: 5, 26 | 19: 6, 27 | 20: 7, 28 | 21: 8, 29 | 22: 9, 30 | 23: 10, 31 | 24: 11, 32 | 25: 12, 33 | 26: 13, 34 | 27: 14, 35 | 28: 15, 36 | 31: 16, 37 | 32: 17, 38 | 33: 18 39 | } 40 | label_copy = 255 * np.ones(label.shape, dtype=np.uint8) 41 | sample_class_stats = {} 42 | for k, v in id_to_trainid.items(): 43 | k_mask = label == k 44 | label_copy[k_mask] = v 45 | n = int(np.sum(k_mask)) 46 | if n > 0: 47 | sample_class_stats[v] = n 48 | new_file = file.replace('.png', '_labelTrainIds.png') 49 | assert file != new_file 50 | sample_class_stats['file'] = new_file 51 | Image.fromarray(label_copy, mode='L').save(new_file) 52 | return sample_class_stats 53 | 54 | 55 | def parse_args(): 56 | parser = argparse.ArgumentParser( 57 | description='Convert GTA annotations to TrainIds') 58 | parser.add_argument('gta_path', help='gta data path') 59 | parser.add_argument('--gt-dir', default='labels', type=str) 60 | parser.add_argument('-o', '--out-dir', help='output path') 61 | parser.add_argument( 62 | '--nproc', default=4, type=int, help='number of process') 63 | args = parser.parse_args() 64 | return args 65 | 66 | 67 | def save_class_stats(out_dir, sample_class_stats): 68 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'w') as of: 69 | json.dump(sample_class_stats, of, indent=2) 70 | 71 | sample_class_stats_dict = {} 72 | for stats in sample_class_stats: 73 | f = stats.pop('file') 74 | sample_class_stats_dict[f] = stats 75 | with open(osp.join(out_dir, 'sample_class_stats_dict.json'), 'w') as of: 76 | json.dump(sample_class_stats_dict, of, indent=2) 77 | 78 | samples_with_class = {} 79 | for file, stats in 
sample_class_stats_dict.items(): 80 | for c, n in stats.items(): 81 | if c not in samples_with_class: 82 | samples_with_class[c] = [(file, n)] 83 | else: 84 | samples_with_class[c].append((file, n)) 85 | with open(osp.join(out_dir, 'samples_with_class.json'), 'w') as of: 86 | json.dump(samples_with_class, of, indent=2) 87 | 88 | 89 | def main(): 90 | args = parse_args() 91 | gta_path = args.gta_path 92 | out_dir = args.out_dir if args.out_dir else gta_path 93 | mmcv.mkdir_or_exist(out_dir) 94 | 95 | gt_dir = osp.join(gta_path, args.gt_dir) 96 | 97 | poly_files = [] 98 | for poly in mmcv.scandir( 99 | gt_dir, suffix=tuple(f'{i}.png' for i in range(10)), 100 | recursive=True): 101 | poly_file = osp.join(gt_dir, poly) 102 | poly_files.append(poly_file) 103 | poly_files = sorted(poly_files) 104 | 105 | only_postprocessing = False 106 | if not only_postprocessing: 107 | if args.nproc > 1: 108 | sample_class_stats = mmcv.track_parallel_progress( 109 | convert_to_train_id, poly_files, args.nproc) 110 | else: 111 | sample_class_stats = mmcv.track_progress(convert_to_train_id, 112 | poly_files) 113 | else: 114 | with open(osp.join(out_dir, 'sample_class_stats.json'), 'r') as of: 115 | sample_class_stats = json.load(of) 116 | 117 | save_class_stats(out_dir, sample_class_stats) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /tools/download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Instructions for Manual Download: 4 | # 5 | # Please, download the [MiT weights](https://drive.google.com/drive/folders/1b7bwrInTW4VLEm27YawHOAMSMikga2Ia?usp=sharing) 6 | # pretrained on ImageNet-1K provided by the official 7 | # [SegFormer repository](https://github.com/NVlabs/SegFormer) and put them in a 8 | # folder `pretrained/` within this project. For most of the experiments, only 9 | # mit_b5.pth is necessary. 10 | # 11 | # Please, download the checkpoint of DAFormer on GTA->Cityscapes from 12 | # [here](https://drive.google.com/file/d/1pG3kDClZDGwp1vSTEXmTchkGHmnLQNdP/view?usp=sharing). 
13 | # and extract it to `work_dirs/` 14 | 15 | # Automatic Downloads: 16 | set -e # exit when any command fails 17 | mkdir -p pretrained/ 18 | cd pretrained/ 19 | gdown --id 1d3wU8KNjPL4EqMCIEO_rO-O3-REpG82T # MiT-B3 weights 20 | gdown --id 1BUtU42moYrOFbsMCE-LTTkUE-mrWnfG2 # MiT-B4 weights 21 | gdown --id 1d7I50jVjtCddnhpf-lqj8-f13UyCzoW1 # MiT-B5 weights 22 | cd ../ 23 | 24 | mkdir -p work_dirs/ 25 | cd work_dirs/ 26 | gdown --id 1pG3kDClZDGwp1vSTEXmTchkGHmnLQNdP # DAFormer on GTA->Cityscapes 27 | tar -xzf 211108_1622_gta2cs_daformer_s0_7f24c.tar.gz 28 | rm 211108_1622_gta2cs_daformer_s0_7f24c.tar.gz 29 | cd ../ 30 | -------------------------------------------------------------------------------- /tools/get_param_count.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import argparse 4 | import json 5 | import logging 6 | from copy import deepcopy 7 | 8 | from experiments import generate_experiment_cfgs 9 | from mmcv import Config, get_logger 10 | from prettytable import PrettyTable 11 | 12 | from mmseg.models import build_segmentor 13 | 14 | 15 | def human_format(num): 16 | magnitude = 0 17 | while abs(num) >= 1000: 18 | magnitude += 1 19 | num /= 1000.0 20 | # add more suffixes if you need them 21 | return '%.2f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) 22 | 23 | 24 | def count_parameters(model): 25 | table = PrettyTable(['Modules', 'Parameters']) 26 | total_params = 0 27 | for name, parameter in model.named_parameters(): 28 | if not parameter.requires_grad: 29 | continue 30 | param = parameter.numel() 31 | table.add_row([name, human_format(param)]) 32 | total_params += param 33 | # print(table) 34 | print(f'Total Trainable Params: {human_format(total_params)}') 35 | return total_params 36 | 37 | 38 | # Run: python -m tools.param_count 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument( 42 | '--exp', 43 | nargs='?', 44 | type=int, 45 | default=100, 46 | help='Experiment id as defined in experiment.py', 47 | ) 48 | args = parser.parse_args() 49 | get_logger('mmseg', log_level=logging.ERROR) 50 | cfgs = generate_experiment_cfgs(args.exp) 51 | for cfg in cfgs: 52 | with open('configs/tmp_param.json', 'w') as f: 53 | json.dump(cfg, f) 54 | cfg = Config.fromfile('configs/tmp_param.json') 55 | 56 | model = build_segmentor(deepcopy(cfg['model'])) 57 | # model.init_weights() 58 | # count_parameters(model) 59 | print(f'Encoder {cfg["name_encoder"]}:') 60 | count_parameters(model.backbone) 61 | print(f'Decoder {cfg["name_decoder"]}:') 62 | count_parameters(model.decode_head) 63 | -------------------------------------------------------------------------------- /tools/print_config.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import argparse 4 | 5 | from mmcv import Config, DictAction 6 | 7 | from mmseg.apis import init_segmentor 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Print the whole config') 12 | parser.add_argument('config', help='config file path') 13 | parser.add_argument( 14 | '--graph', action='store_true', help='print the models graph') 15 | parser.add_argument( 16 | '--options', nargs='+', action=DictAction, help='arguments in dict') 17 | args = parser.parse_args() 18 | 19 | return args 20 | 21 | 22 | def main(): 23 | args = parse_args() 24 | 25 | 
cfg = Config.fromfile(args.config) 26 | if args.options is not None: 27 | cfg.merge_from_dict(args.options) 28 | print(f'Config:\n{cfg.pretty_text}') 29 | # dump config 30 | cfg.dump('example.py') 31 | # dump models graph 32 | if args.graph: 33 | model = init_segmentor(args.config, device='cpu') 34 | print(f'Model graph:\n{str(model)}') 35 | with open('example-graph.txt', 'w') as f: 36 | f.writelines(str(model)) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /tools/publish_model.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Remove auxiliary models 3 | 4 | import argparse 5 | 6 | import torch 7 | 8 | # import subprocess 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description='Process a checkpoint to be published') 14 | parser.add_argument('in_file', help='input checkpoint filename') 15 | parser.add_argument('out_file', help='output checkpoint filename') 16 | args = parser.parse_args() 17 | return args 18 | 19 | 20 | def process_checkpoint(in_file, out_file): 21 | checkpoint = torch.load(in_file, map_location='cpu') 22 | # remove optimizer for smaller file size 23 | if 'optimizer' in checkpoint: 24 | del checkpoint['optimizer'] 25 | # remove auxiliary models 26 | for k in list(checkpoint['state_dict'].keys()): 27 | if 'imnet_model' in k or 'ema_model' in k: 28 | del checkpoint['state_dict'][k] 29 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 30 | # add the code here. 31 | if 'meta' in checkpoint: 32 | del checkpoint['meta'] 33 | # inspect checkpoint 34 | print('Checkpoint keys:', checkpoint.keys()) 35 | print('Checkpoint state_dict keys:', checkpoint['state_dict'].keys()) 36 | # save checkpoint 37 | torch.save(checkpoint, out_file) 38 | # sha = subprocess.check_output(['sha256sum', out_file]).decode() 39 | # final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8]) 40 | # subprocess.Popen(['mv', out_file, final_file]) 41 | 42 | 43 | def main(): 44 | args = parse_args() 45 | process_checkpoint(args.in_file, args.out_file) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - Provide args as argument to main() 4 | # - Snapshot source code 5 | # - Build UDA model instead of regular one 6 | 7 | import argparse 8 | import copy 9 | import os 10 | import os.path as osp 11 | import sys 12 | import time 13 | 14 | import mmcv 15 | import torch 16 | from mmcv.runner import init_dist 17 | from mmcv.utils import Config, DictAction, get_git_hash 18 | 19 | from mmseg import __version__ 20 | from mmseg.apis import set_random_seed, train_segmentor 21 | from mmseg.datasets import build_dataset 22 | from mmseg.models.builder import build_train_model 23 | from mmseg.utils import collect_env, get_root_logger 24 | from mmseg.utils.collect_env import gen_code_archive 25 | 26 | 27 | def parse_args(args): 28 | parser = argparse.ArgumentParser(description='Train a segmentor') 29 | parser.add_argument('config', help='train config file path') 30 | parser.add_argument('--work-dir', help='the dir to save logs and models') 31 | 
parser.add_argument( 32 | '--load-from', help='the checkpoint file to load weights from') 33 | parser.add_argument( 34 | '--resume-from', help='the checkpoint file to resume from') 35 | parser.add_argument( 36 | '--no-validate', 37 | action='store_true', 38 | help='whether not to evaluate the checkpoint during training') 39 | group_gpus = parser.add_mutually_exclusive_group() 40 | group_gpus.add_argument( 41 | '--gpus', 42 | type=int, 43 | help='number of gpus to use ' 44 | '(only applicable to non-distributed training)') 45 | group_gpus.add_argument( 46 | '--gpu-ids', 47 | type=int, 48 | nargs='+', 49 | help='ids of gpus to use ' 50 | '(only applicable to non-distributed training)') 51 | parser.add_argument('--seed', type=int, default=None, help='random seed') 52 | parser.add_argument( 53 | '--deterministic', 54 | action='store_true', 55 | help='whether to set deterministic options for CUDNN backend.') 56 | parser.add_argument( 57 | '--options', nargs='+', action=DictAction, help='custom options') 58 | parser.add_argument( 59 | '--launcher', 60 | choices=['none', 'pytorch', 'slurm', 'mpi'], 61 | default='none', 62 | help='job launcher') 63 | parser.add_argument('--local_rank', type=int, default=0) 64 | args = parser.parse_args(args) 65 | if 'LOCAL_RANK' not in os.environ: 66 | os.environ['LOCAL_RANK'] = str(args.local_rank) 67 | 68 | return args 69 | 70 | 71 | def main(args): 72 | args = parse_args(args) 73 | 74 | cfg = Config.fromfile(args.config) 75 | if args.options is not None: 76 | cfg.merge_from_dict(args.options) 77 | # set cudnn_benchmark 78 | if cfg.get('cudnn_benchmark', False): 79 | torch.backends.cudnn.benchmark = True 80 | 81 | # work_dir is determined in this priority: CLI > segment in file > filename 82 | if args.work_dir is not None: 83 | # update configs according to CLI args if args.work_dir is not None 84 | cfg.work_dir = args.work_dir 85 | elif cfg.get('work_dir', None) is None: 86 | # use config filename as default work_dir if cfg.work_dir is None 87 | cfg.work_dir = osp.join('./work_dirs', 88 | osp.splitext(osp.basename(args.config))[0]) 89 | cfg.model.train_cfg.work_dir = cfg.work_dir 90 | if args.load_from is not None: 91 | cfg.load_from = args.load_from 92 | if args.resume_from is not None: 93 | cfg.resume_from = args.resume_from 94 | if args.gpu_ids is not None: 95 | cfg.gpu_ids = args.gpu_ids 96 | else: 97 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 98 | 99 | # init distributed env first, since logger depends on the dist info. 
100 | if args.launcher == 'none': 101 | distributed = False 102 | else: 103 | distributed = True 104 | init_dist(args.launcher, **cfg.dist_params) 105 | 106 | # create work_dir 107 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 108 | # dump config 109 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 110 | # snapshot source code 111 | gen_code_archive(cfg.work_dir) 112 | # init the logger before other steps 113 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 114 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 115 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 116 | 117 | # init the meta dict to record some important information such as 118 | # environment info and seed, which will be logged 119 | meta = dict() 120 | # log env info 121 | env_info_dict = collect_env() 122 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 123 | dash_line = '-' * 60 + '\n' 124 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 125 | dash_line) 126 | meta['env_info'] = env_info 127 | 128 | # log some basic info 129 | logger.info(f'Distributed training: {distributed}') 130 | logger.info(f'Config:\n{cfg.pretty_text}') 131 | 132 | # set random seeds 133 | if args.seed is None and 'seed' in cfg: 134 | args.seed = cfg['seed'] 135 | if args.seed is not None: 136 | logger.info(f'Set random seed to {args.seed}, deterministic: ' 137 | f'{args.deterministic}') 138 | set_random_seed(args.seed, deterministic=args.deterministic) 139 | cfg.seed = args.seed 140 | meta['seed'] = args.seed 141 | meta['exp_name'] = osp.splitext(osp.basename(args.config))[0] 142 | 143 | model = build_train_model( 144 | cfg, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')) 145 | model.init_weights() 146 | 147 | logger.info(model) 148 | 149 | datasets = [build_dataset(cfg.data.train)] 150 | if len(cfg.workflow) == 2: 151 | val_dataset = copy.deepcopy(cfg.data.val) 152 | val_dataset.pipeline = cfg.data.train.pipeline 153 | datasets.append(build_dataset(val_dataset)) 154 | if cfg.checkpoint_config is not None: 155 | # save mmseg version, config file content and class names in 156 | # checkpoints as meta data 157 | cfg.checkpoint_config.meta = dict( 158 | mmseg_version=f'{__version__}+{get_git_hash()[:7]}', 159 | config=cfg.pretty_text, 160 | CLASSES=datasets[0].CLASSES, 161 | PALETTE=datasets[0].PALETTE) 162 | # add an attribute for visualization convenience 163 | model.CLASSES = datasets[0].CLASSES 164 | # passing checkpoint meta for saving best checkpoint 165 | meta.update(cfg.checkpoint_config.meta) 166 | train_segmentor( 167 | model, 168 | datasets, 169 | cfg, 170 | distributed=distributed, 171 | validate=(not args.no_validate), 172 | timestamp=timestamp, 173 | meta=meta) 174 | 175 | 176 | if __name__ == '__main__': 177 | main(sys.argv[1:]) 178 | --------------------------------------------------------------------------------
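Usage note: a minimal sketch of how the training entry points above can be invoked, assuming the datasets and ImageNet-pretrained MiT weights have already been prepared as described in tools/download_checkpoints.sh; the experiment id and config path below are placeholders, not values taken from this repository.

# fetch the pretrained MiT weights and the DAFormer checkpoint (requires gdown)
bash tools/download_checkpoints.sh

# train from an experiment definition generated by experiments.py (placeholder id)
python run_experiments.py --exp 1

# or train from an explicit config file (placeholder path)
python run_experiments.py --config path/to/experiment_config.py

In both cases run_experiments.py writes the resolved config to configs/generated/ and then calls tools/train.py on it; options such as the random seed and work_dir are therefore taken from the config rather than passed on the command line.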