├── README.md ├── configs ├── _base_ │ ├── datasets │ │ ├── ddr_1024.py │ │ ├── hr_ddr_2048.py │ │ ├── hr_idrid_2880x1920-slide.py │ │ ├── hr_idrid_2880x1920.py │ │ ├── idrid_1440x960-slide.py │ │ └── idrid_1440x960.py │ ├── default_runtime.py │ ├── models │ │ ├── efficient-hrdecoder_fcn_hr18.py │ │ ├── efficient-hrdecoder_fcn_hr48.py │ │ ├── fcn_hr18.py │ │ ├── fcn_hr48.py │ │ ├── hrdecoder_fcn_hr18.py │ │ └── hrdecoder_fcn_hr48.py │ └── schedules │ │ ├── adamw.py │ │ ├── poly10.py │ │ ├── poly10warm.py │ │ └── sgd.py └── lesion │ ├── efficient-hrdecoder_fcn_hr48_ddr_2048.py │ ├── efficient-hrdecoder_fcn_hr48_idrid_2880x1920-slide.py │ ├── fcn_hr48_ddr_1024.py │ ├── fcn_hr48_idrid_1440x960.py │ ├── hrdecoder_fcn_hr48_ddr_2048.py │ └── hrdecoder_fcn_hr48_idrid_2880x1920-slide.py ├── mmseg ├── __init__.py ├── apis │ ├── __init__.py │ ├── inference.py │ ├── test.py │ └── train.py ├── core │ ├── __init__.py │ ├── ddp_wrapper.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── class_names.py │ │ ├── eval_hooks.py │ │ ├── lesion_metric.py │ │ └── metrics.py │ ├── seg │ │ ├── __init__.py │ │ ├── builder.py │ │ └── sampler │ │ │ ├── __init__.py │ │ │ ├── base_pixel_sampler.py │ │ │ └── ohem_pixel_sampler.py │ └── utils │ │ ├── __init__.py │ │ └── misc.py ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── custom.py │ ├── dataset_wrappers.py │ ├── lesion_dataset.py │ └── pipelines │ │ ├── __init__.py │ │ ├── compose.py │ │ ├── formating.py │ │ ├── loading.py │ │ ├── test_time_aug.py │ │ └── transforms.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── hrnet.py │ │ ├── mit.py │ │ ├── mix_transformer.py │ │ ├── resnest.py │ │ ├── resnet.py │ │ ├── resnext.py │ │ ├── swin.py │ │ ├── swin_unet.py │ │ ├── unet.py │ │ ├── vit_adapter.py │ │ └── vit_det.py │ ├── builder.py │ ├── decode_heads │ │ ├── __init__.py │ │ ├── aspp_head.py │ │ ├── da_head.py │ │ ├── daformer_head.py │ │ ├── decode_head.py │ │ ├── dlv2_head.py │ │ ├── fcn_head.py │ │ ├── isa_head.py │ │ ├── psp_head.py │ │ ├── segformer_head.py │ │ ├── sep_aspp_head.py │ │ └── uper_head.py │ ├── losses │ │ ├── __init__.py │ │ ├── accuracy.py │ │ ├── binary_loss.py │ │ ├── cross_entropy_loss.py │ │ ├── sam_loss.py │ │ └── utils.py │ ├── necks │ │ ├── __init__.py │ │ ├── sam_neck.py │ │ ├── segformer_adapter.py │ │ └── segformer_neck.py │ ├── segmentors │ │ ├── HRDecoder.py │ │ ├── __init__.py │ │ ├── base.py │ │ ├── encoder_decoder.py │ │ └── lesion_encoder_decoder.py │ └── utils │ │ ├── __init__.py │ │ ├── ckpt_convert.py │ │ ├── embed.py │ │ ├── make_divisible.py │ │ ├── res_layer.py │ │ ├── self_attention_block.py │ │ ├── shape_convert.py │ │ ├── up_conv_block.py │ │ └── wrappers.py ├── ops │ ├── __init__.py │ ├── encoding.py │ └── wrappers.py ├── utils │ ├── __init__.py │ ├── collect_env.py │ ├── logger.py │ ├── precision_logger.py │ └── utils.py └── version.py ├── requirements.txt └── tools ├── analyze_logs.py ├── convert_dataset ├── ddr.py └── idrid.py ├── dist_test.sh ├── dist_train.sh ├── get_flops.py ├── get_fps.py ├── print_config.py ├── test.py └── train.py /configs/_base_/datasets/ddr_1024.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | """ 3 | rgb mean: 4 | [81.20546605 50.63635725 21.21597278] 5 | rgb std: 6 | [76.25170836 48.79813652 21.62512444] 7 | """ 8 | dataset_type = 'LesionDataset' 9 | data_root = './data/DDR' 10 | img_norm_cfg = dict( 11 | mean=[81.205, 50.636, 21.216], std=[76.252, 48.798, 21.625], to_rgb=True) 12 | #idrid 13 | #img_norm_cfg = dict( 14 | # mean=[116.513, 56.437, 16.309], std=[80.206, 41.232, 13.293], to_rgb=True) 15 | image_scale = (1024, 1024) 16 | crop_size = (1024, 1024) 17 | palette = [ 18 | [0, 0, 0], 19 | [128, 0, 0], # EX: red 20 | [0, 128, 0], # HE: green 21 | [128, 128, 0], # SE: yellow 22 | [0, 0, 128] # MA: blue 23 | ] 24 | classes = ['bg', 'EX', 'HE', 'SE', 'MA'] 25 | train_pipeline = [ 26 | dict(type='LoadImageFromFile'), 27 | dict(type='LoadAnnotations'), 28 | dict(type='Resize', img_scale=image_scale, ratio_range=(0.5, 2.0)), 29 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 30 | dict(type='RandomFlip', flip_ratio=0.5), 31 | dict(type='RandomRotate', prob=1.0, pad_val=0, seg_pad_val=0, 32 | degree=(-45,45), auto_bound=False), 33 | #dict(type='PhotoMetricDistortion'), 34 | dict(type='Normalize', **img_norm_cfg), 35 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=0), 36 | dict(type='DefaultFormatBundle'), 37 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 38 | ] 39 | test_pipeline = [ 40 | dict(type='LoadImageFromFile'), 41 | dict( 42 | type='MultiScaleFlipAug', 43 | img_scale=image_scale, 44 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 45 | flip=False, 46 | transforms=[ 47 | dict(type='Resize', keep_ratio=False), 48 | #dict(type='RandomFlip'), 49 | dict(type='Normalize', **img_norm_cfg), 50 | dict(type='ImageToTensor', keys=['img']), 51 | dict(type='Collect', keys=['img']), 52 | ]) 53 | ] 54 | 55 | data = dict( 56 | samples_per_gpu=1, 57 | workers_per_gpu=1, 58 | train=dict( 59 | img_dir='images/train', 60 | ann_dir='labels/train', 61 | data_root=data_root, 62 | classes=classes, 63 | palette=palette, 64 | type=dataset_type, 65 | pipeline=train_pipeline), 66 | val=dict( 67 | img_dir='images/test', 68 | ann_dir='labels/test', 69 | data_root=data_root, 70 | classes=classes, 71 | palette=palette, 72 | type=dataset_type, 73 | pipeline=test_pipeline), 74 | test=dict( 75 | img_dir='images/test', 76 | ann_dir='labels/test', 77 | data_root=data_root, 78 | classes=classes, 79 | palette=palette, 80 | type=dataset_type, 81 | pipeline=test_pipeline)) 82 | -------------------------------------------------------------------------------- /configs/_base_/datasets/hr_ddr_2048.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | """ 3 | rgb mean: 4 | [81.20546605 50.63635725 21.21597278] 5 | rgb std: 6 | [76.25170836 48.79813652 21.62512444] 7 | """ 8 | dataset_type = 'LesionDataset' 9 | data_root = './data/DDR' 10 | img_norm_cfg = dict( 11 | mean=[81.205, 50.636, 21.216], std=[76.252, 48.798, 21.625], to_rgb=True) 12 | #idrid 13 | #img_norm_cfg = dict( 14 | # mean=[116.513, 56.437, 16.309], std=[80.206, 41.232, 13.293], to_rgb=True) 15 | image_scale = (2048, 2048) 16 | crop_size = (2048, 2048) 17 | palette = [ 18 | [0, 0, 0], 19 | [128, 0, 0], # EX: red 20 | [0, 128, 0], # HE: green 21 | [128, 128, 0], # SE: yellow 22 | [0, 0, 128] # MA: blue 23 | ] 24 | classes = ['bg', 'EX', 'HE', 'SE', 'MA'] 25 | train_pipeline = [ 26 | dict(type='LoadImageFromFile'), 27 | dict(type='LoadAnnotations'), 28 | dict(type='Resize', img_scale=image_scale, ratio_range=(0.5, 2.0)), 29 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 30 | dict(type='RandomFlip', flip_ratio=0.5), 31 | dict(type='RandomRotate', prob=1.0, pad_val=0, seg_pad_val=0, 32 | degree=(-45,45), auto_bound=False), 33 | #dict(type='PhotoMetricDistortion'), 34 | dict(type='Normalize', **img_norm_cfg), 35 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=0), 36 | dict(type='DefaultFormatBundle'), 37 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 38 | ] 39 | test_pipeline = [ 40 | dict(type='LoadImageFromFile'), 41 | dict( 42 | type='MultiScaleFlipAug', 43 | img_scale=image_scale, 44 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 45 | flip=False, 46 | transforms=[ 47 | dict(type='Resize', keep_ratio=False), 48 | #dict(type='RandomFlip'), 49 | dict(type='Normalize', **img_norm_cfg), 50 | dict(type='ImageToTensor', keys=['img']), 51 | dict(type='Collect', keys=['img']), 52 | ]) 53 | ] 54 | 55 | data = dict( 56 | samples_per_gpu=1, 57 | workers_per_gpu=1, 58 | train=dict( 59 | img_dir='images/train', 60 | ann_dir='labels/train', 61 | data_root=data_root, 62 | classes=classes, 63 | palette=palette, 64 | type=dataset_type, 65 | pipeline=train_pipeline), 66 | val=dict( 67 | img_dir='images/test', 68 | ann_dir='labels/test', 69 | data_root=data_root, 70 | classes=classes, 71 | palette=palette, 72 | type=dataset_type, 73 | pipeline=test_pipeline), 74 | test=dict( 75 | img_dir='images/test', 76 | ann_dir='labels/test', 77 | data_root=data_root, 78 | classes=classes, 79 | palette=palette, 80 | type=dataset_type, 81 | pipeline=test_pipeline)) 82 | -------------------------------------------------------------------------------- /configs/_base_/datasets/hr_idrid_2880x1920-slide.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | """ 3 | rgb mean: 4 | [116.51282647 56.43716432 16.30857136] 5 | rgb std: 6 | [80.20605713 41.23209693 13.29250962] 7 | """ 8 | dataset_type = 'LesionDataset' 9 | data_root = './data/IDRID' 10 | # idrid 11 | img_norm_cfg = dict( 12 | mean=[116.513, 56.437, 16.309], std=[80.206, 41.232, 13.293], to_rgb=True) 13 | #ddr 14 | #img_norm_cfg = dict( 15 | # mean=[81.205, 50.636, 21.216], std=[76.252, 48.798, 21.625], to_rgb=True) 16 | image_scale = (2880, 1920) 17 | crop_size = (1920, 1920) 18 | palette = [ 19 | [0, 0, 0], 20 | [128, 0, 0], # EX: red 21 | [0, 128, 0], # HE: green 22 | [128, 128, 0], # SE: yellow 23 | [0, 0, 128] # MA: blue 24 | ] 25 | classes = ['bg', 'EX', 'HE', 'SE', 'MA'] 26 | train_pipeline = [ 27 | dict(type='LoadImageFromFile'), 28 | dict(type='LoadAnnotations'), 29 | dict(type='Resize', img_scale=image_scale, ratio_range=(0.5,2.0)), 30 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 31 | dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), 32 | dict(type='RandomRotate', prob=1.0, pad_val=0, seg_pad_val=0, 33 | degree=(-45,45), auto_bound=False), 34 | #dict(type='PhotoMetricDistortion'), 35 | dict(type='Normalize', **img_norm_cfg), 36 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=0), 37 | dict(type='DefaultFormatBundle'), 38 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 39 | ] 40 | test_pipeline = [ 41 | dict(type='LoadImageFromFile'), 42 | dict( 43 | type='MultiScaleFlipAug', 44 | img_scale=image_scale, 45 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 46 | flip=False, 47 | transforms=[ 48 | dict(type='Resize', keep_ratio=False), 49 | # dict(type='RandomFlip'), 50 | dict(type='Normalize', **img_norm_cfg), 51 | dict(type='ImageToTensor', keys=['img']), 52 | dict(type='Collect', keys=['img']), 53 | ]) 54 | ] 55 | 56 | data = dict( 57 | samples_per_gpu=1, 58 | workers_per_gpu=1, 59 | train=dict( 60 | img_dir='images/train', 61 | ann_dir='labels/train', 62 | data_root=data_root, 63 | classes=classes, 64 | palette=palette, 65 | type=dataset_type, 66 | pipeline=train_pipeline), 67 | val=dict( 68 | img_dir='images/test', 69 | ann_dir='labels/test', 70 | data_root=data_root, 71 | classes=classes, 72 | palette=palette, 73 | type=dataset_type, 74 | pipeline=test_pipeline), 75 | test=dict( 76 | img_dir='images/test', 77 | ann_dir='labels/test', 78 | data_root=data_root, 79 | classes=classes, 80 | palette=palette, 81 | type=dataset_type, 82 | pipeline=test_pipeline)) 83 | -------------------------------------------------------------------------------- /configs/_base_/datasets/hr_idrid_2880x1920.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | """ 3 | rgb mean: 4 | [116.51282647 56.43716432 16.30857136] 5 | rgb std: 6 | [80.20605713 41.23209693 13.29250962] 7 | """ 8 | dataset_type = 'LesionDataset' 9 | data_root = './data/IDRID' 10 | # idrid 11 | img_norm_cfg = dict( 12 | mean=[116.513, 56.437, 16.309], std=[80.206, 41.232, 13.293], to_rgb=True) 13 | #ddr 14 | #img_norm_cfg = dict( 15 | # mean=[81.205, 50.636, 21.216], std=[76.252, 48.798, 21.625], to_rgb=True) 16 | image_scale = (2880, 1920) 17 | crop_size = (1920, 2880) 18 | palette = [ 19 | [0, 0, 0], 20 | [128, 0, 0], # EX: red 21 | [0, 128, 0], # HE: green 22 | [128, 128, 0], # SE: yellow 23 | [0, 0, 128] # MA: blue 24 | ] 25 | classes = ['bg', 'EX', 'HE', 'SE', 'MA'] 26 | train_pipeline = [ 27 | dict(type='LoadImageFromFile'), 28 | dict(type='LoadAnnotations'), 29 | dict(type='Resize', img_scale=image_scale, ratio_range=(0.5,2.0)), 30 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 31 | dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), 32 | dict(type='RandomRotate', prob=1.0, pad_val=0, seg_pad_val=0, 33 | degree=(-45,45), auto_bound=False), 34 | #dict(type='PhotoMetricDistortion'), 35 | dict(type='Normalize', **img_norm_cfg), 36 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=0), 37 | dict(type='DefaultFormatBundle'), 38 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 39 | ] 40 | test_pipeline = [ 41 | dict(type='LoadImageFromFile'), 42 | dict( 43 | type='MultiScaleFlipAug', 44 | img_scale=image_scale, 45 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 46 | flip=False, 47 | transforms=[ 48 | dict(type='Resize', keep_ratio=False), 49 | # dict(type='RandomFlip'), 50 | dict(type='Normalize', **img_norm_cfg), 51 | dict(type='ImageToTensor', keys=['img']), 52 | dict(type='Collect', keys=['img']), 53 | ]) 54 | ] 55 | 56 | data = dict( 57 | samples_per_gpu=1, 58 | workers_per_gpu=1, 59 | train=dict( 60 | img_dir='images/train', 61 | ann_dir='labels/train', 62 | data_root=data_root, 63 | classes=classes, 64 | palette=palette, 65 | type=dataset_type, 66 | pipeline=train_pipeline), 67 | val=dict( 68 | img_dir='images/test', 69 | ann_dir='labels/test', 70 | data_root=data_root, 71 | classes=classes, 72 | palette=palette, 73 | type=dataset_type, 74 | pipeline=test_pipeline), 75 | test=dict( 76 | img_dir='images/test', 77 | ann_dir='labels/test', 78 | data_root=data_root, 79 | classes=classes, 80 | palette=palette, 81 | type=dataset_type, 82 | pipeline=test_pipeline)) 83 | -------------------------------------------------------------------------------- /configs/_base_/datasets/idrid_1440x960-slide.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | """ 3 | rgb mean: 4 | [116.51282647 56.43716432 16.30857136] 5 | rgb std: 6 | [80.20605713 41.23209693 13.29250962] 7 | """ 8 | dataset_type = 'LesionDataset' 9 | data_root = './data/IDRID' 10 | # idrid 11 | img_norm_cfg = dict( 12 | mean=[116.513, 56.437, 16.309], std=[80.206, 41.232, 13.293], to_rgb=True) 13 | #ddr 14 | #img_norm_cfg = dict( 15 | # mean=[81.205, 50.636, 21.216], std=[76.252, 48.798, 21.625], to_rgb=True) 16 | image_scale = (1440, 960) 17 | crop_size = (960, 960) 18 | palette = [ 19 | [0, 0, 0], 20 | [128, 0, 0], # EX: red 21 | [0, 128, 0], # HE: green 22 | [128, 128, 0], # SE: yellow 23 | [0, 0, 128] # MA: blue 24 | ] 25 | classes = ['bg', 'EX', 'HE', 'SE', 'MA'] 26 | train_pipeline = [ 27 | dict(type='LoadImageFromFile'), 28 | dict(type='LoadAnnotations'), 29 | dict(type='Resize', img_scale=image_scale, ratio_range=(0.5,2.0)), 30 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 31 | dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), 32 | dict(type='RandomRotate', prob=1.0, pad_val=0, seg_pad_val=0, 33 | degree=(-45,45), auto_bound=False), 34 | #dict(type='PhotoMetricDistortion'), 35 | dict(type='Normalize', **img_norm_cfg), 36 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=0), 37 | dict(type='DefaultFormatBundle'), 38 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 39 | ] 40 | test_pipeline = [ 41 | dict(type='LoadImageFromFile'), 42 | dict( 43 | type='MultiScaleFlipAug', 44 | img_scale=image_scale, 45 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 46 | flip=False, 47 | transforms=[ 48 | dict(type='Resize', keep_ratio=False), 49 | # dict(type='RandomFlip'), 50 | dict(type='Normalize', **img_norm_cfg), 51 | dict(type='ImageToTensor', keys=['img']), 52 | dict(type='Collect', keys=['img']), 53 | ]) 54 | ] 55 | 56 | data = dict( 57 | samples_per_gpu=1, 58 | workers_per_gpu=1, 59 | train=dict( 60 | img_dir='images/train', 61 | ann_dir='labels/train', 62 | data_root=data_root, 63 | classes=classes, 64 | palette=palette, 65 | type=dataset_type, 66 | pipeline=train_pipeline), 67 | val=dict( 68 | img_dir='images/test', 69 | ann_dir='labels/test', 70 | data_root=data_root, 71 | classes=classes, 72 | palette=palette, 73 | type=dataset_type, 74 | pipeline=test_pipeline), 75 | test=dict( 76 | img_dir='images/test', 77 | ann_dir='labels/test', 78 | data_root=data_root, 79 | classes=classes, 80 | palette=palette, 81 | type=dataset_type, 82 | pipeline=test_pipeline)) 83 | -------------------------------------------------------------------------------- /configs/_base_/datasets/idrid_1440x960.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | """ 3 | rgb mean: 4 | [116.51282647 56.43716432 16.30857136] 5 | rgb std: 6 | [80.20605713 41.23209693 13.29250962] 7 | """ 8 | dataset_type = 'LesionDataset' 9 | data_root = './data/IDRID' 10 | # idrid 11 | img_norm_cfg = dict( 12 | mean=[116.513, 56.437, 16.309], std=[80.206, 41.232, 13.293], to_rgb=True) 13 | #ddr 14 | #img_norm_cfg = dict( 15 | # mean=[81.205, 50.636, 21.216], std=[76.252, 48.798, 21.625], to_rgb=True) 16 | image_scale = (1440, 960) 17 | crop_size = (960, 1440) 18 | palette = [ 19 | [0, 0, 0], 20 | [128, 0, 0], # EX: red 21 | [0, 128, 0], # HE: green 22 | [128, 128, 0], # SE: yellow 23 | [0, 0, 128] # MA: blue 24 | ] 25 | classes = ['bg', 'EX', 'HE', 'SE', 'MA'] 26 | train_pipeline = [ 27 | dict(type='LoadImageFromFile'), 28 | dict(type='LoadAnnotations'), 29 | dict(type='Resize', img_scale=image_scale, ratio_range=(0.5,2.0)), 30 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 31 | dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), 32 | dict(type='RandomRotate', prob=1.0, pad_val=0, seg_pad_val=0, 33 | degree=(-45,45), auto_bound=False), 34 | #dict(type='PhotoMetricDistortion'), 35 | dict(type='Normalize', **img_norm_cfg), 36 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=0), 37 | dict(type='DefaultFormatBundle'), 38 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 39 | ] 40 | test_pipeline = [ 41 | dict(type='LoadImageFromFile'), 42 | dict( 43 | type='MultiScaleFlipAug', 44 | img_scale=image_scale, 45 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 46 | flip=False, 47 | transforms=[ 48 | dict(type='Resize', keep_ratio=False), 49 | # dict(type='RandomFlip'), 50 | dict(type='Normalize', **img_norm_cfg), 51 | dict(type='ImageToTensor', keys=['img']), 52 | dict(type='Collect', keys=['img']), 53 | ]) 54 | ] 55 | 56 | data = dict( 57 | samples_per_gpu=1, 58 | workers_per_gpu=1, 59 | train=dict( 60 | img_dir='images/train', 61 | ann_dir='labels/train', 62 | data_root=data_root, 63 | classes=classes, 64 | palette=palette, 65 | type=dataset_type, 66 | pipeline=train_pipeline), 67 | val=dict( 68 | img_dir='images/test', 69 | ann_dir='labels/test', 70 | data_root=data_root, 71 | classes=classes, 72 | palette=palette, 73 | type=dataset_type, 74 | pipeline=test_pipeline), 75 | test=dict( 76 | img_dir='images/test', 77 | ann_dir='labels/test', 78 | data_root=data_root, 79 | classes=classes, 80 | palette=palette, 81 | type=dataset_type, 82 | pipeline=test_pipeline)) 83 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | # yapf:disable 4 | log_config = dict( 5 | interval=100, 6 | hooks=[ 7 | dict(type='TextLoggerHook', by_epoch=False), 8 | # dict(type='TensorboardLoggerHook') 9 | ]) 10 | # yapf:enable 11 | dist_params = dict(backend='nccl') 12 | log_level = 'INFO' 13 | load_from = None 14 | resume_from = None 15 | workflow = [('train', 1)] 16 | cudnn_benchmark = True 17 | -------------------------------------------------------------------------------- /configs/_base_/models/efficient-hrdecoder_fcn_hr18.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EfficientHRDecoder', 5 | use_sigmoid=True, 6 | pretrained='open-mmlab://msra/hrnetv2_w18', 7 | backbone=dict( 8 | type='HRNet', 9 | norm_cfg=norm_cfg, 10 | norm_eval=False, 11 | extra=dict( 12 | stage1=dict( 13 | num_modules=1, 14 | num_branches=1, 15 | block='BOTTLENECK', 16 | num_blocks=(4, ), 17 | num_channels=(64, )), 18 | stage2=dict( 19 | num_modules=1, 20 | num_branches=2, 21 | block='BASIC', 22 | num_blocks=(4, 4), 23 | num_channels=(18, 36)), 24 | stage3=dict( 25 | num_modules=4, 26 | num_branches=3, 27 | block='BASIC', 28 | num_blocks=(4, 4, 4), 29 | num_channels=(18, 36, 72)), 30 | stage4=dict( 31 | num_modules=3, 32 | num_branches=4, 33 | block='BASIC', 34 | num_blocks=(4, 4, 4, 4), 35 | num_channels=(18, 36, 72, 144)))), 36 | hr_settings=dict( 37 | in_channels=sum([18, 36, 72, 144]), 38 | visual_dim = 64, 39 | hr_scale=(1024,1024), 40 | scale_ratio=(0.75,1.25), 41 | divisible=8, 42 | lr_loss_weight=0, 43 | hr_loss_weight=0.1, 44 | fuse_mode = 'simple', 45 | crop_num = 2, 46 | ), 47 | decode_head=dict( 48 | type='FCNHead', 49 | in_channels=64, 50 | in_index=0, 51 | channels=64, 52 | kernel_size=7, 53 | num_convs=1, 54 | compress=False, 55 | concat_input=False, 56 | dropout_ratio=-1, 57 | num_classes=4, 58 | norm_cfg=norm_cfg, 59 | align_corners=False, 60 | loss_decode=dict( 61 | type='BinaryLoss', loss_type='dice', loss_weight=1.0, smooth=1e-5) 62 | ),# model training and testing settings 63 | train_cfg = dict(), 64 | test_cfg = dict(mode='whole',compute_aupr=True) 65 | ) 66 | -------------------------------------------------------------------------------- /configs/_base_/models/efficient-hrdecoder_fcn_hr48.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | 3 | _base_ = './efficient-hrdecoder_fcn_hr18.py' 4 | 5 | model = dict( 6 | pretrained='open-mmlab://msra/hrnetv2_w48', 7 | backbone=dict( 8 | extra=dict( 9 | stage2=dict(num_channels=(48, 96)), 10 | stage3=dict(num_channels=(48, 96, 192)), 11 | stage4=dict(num_channels=(48, 96, 192, 384)))), 12 | hr_settings=dict( 13 | in_channels=sum([48, 96, 192, 384]), 14 | ), 15 | ) -------------------------------------------------------------------------------- /configs/_base_/models/fcn_hr18.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | text_embed_dim=512 4 | visual_embed_dim=512 5 | model = dict( 6 | type='LesionEncoderDecoder', 7 | use_sigmoid=True, 8 | pretrained='open-mmlab://msra/hrnetv2_w18', 9 | backbone=dict( 10 | type='HRNet', 11 | norm_cfg=norm_cfg, 12 | norm_eval=False, 13 | extra=dict( 14 | stage1=dict( 15 | num_modules=1, 16 | num_branches=1, 17 | block='BOTTLENECK', 18 | num_blocks=(4, ), 19 | num_channels=(64, )), 20 | stage2=dict( 21 | num_modules=1, 22 | num_branches=2, 23 | block='BASIC', 24 | num_blocks=(4, 4), 25 | num_channels=(18, 36)), 26 | stage3=dict( 27 | num_modules=4, 28 | num_branches=3, 29 | block='BASIC', 30 | num_blocks=(4, 4, 4), 31 | num_channels=(18, 36, 72)), 32 | stage4=dict( 33 | num_modules=3, 34 | num_branches=4, 35 | block='BASIC', 36 | num_blocks=(4, 4, 4, 4), 37 | num_channels=(18, 36, 72, 144)))), 38 | decode_head=dict( 39 | type='FCNHead', 40 | in_channels=[18, 36, 72, 144], 41 | in_index=(0, 1, 2, 3), 42 | channels=visual_embed_dim, 43 | input_transform='resize_concat', 44 | kernel_size=1, 45 | num_convs=1, 46 | concat_input=False, 47 | dropout_ratio=-1, 48 | num_classes=4, 49 | norm_cfg=norm_cfg, 50 | align_corners=False, 51 | loss_decode=dict( 52 | type='BinaryLoss', loss_type='dice', loss_weight=1.0, smooth=1e-5) 53 | ),# model training and testing settings 54 | train_cfg = dict(), 55 | test_cfg = dict(mode='whole',compute_aupr=True), 56 | ) 57 | -------------------------------------------------------------------------------- /configs/_base_/models/fcn_hr48.py: -------------------------------------------------------------------------------- 1 | _base_='./fcn_hr18.py' 2 | 3 | norm_cfg = dict(type='SyncBN', requires_grad=True) 4 | model = dict( 5 | pretrained='open-mmlab://msra/hrnetv2_w48', 6 | backbone=dict( 7 | extra=dict( 8 | stage2=dict(num_channels=(48, 96)), 9 | stage3=dict(num_channels=(48, 96, 192)), 10 | stage4=dict(num_channels=(48, 96, 192, 384)))), 11 | decode_head=dict( 12 | in_channels=[48, 96, 192, 384], channels=sum([48, 96, 192, 384])), 13 | train_cfg = dict(), 14 | test_cfg = dict(mode='whole',compute_aupr=True)) -------------------------------------------------------------------------------- /configs/_base_/models/hrdecoder_fcn_hr18.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='HRDecoder', 5 | use_sigmoid=True, 6 | pretrained='open-mmlab://msra/hrnetv2_w18', 7 | backbone=dict( 8 | type='HRNet', 9 | norm_cfg=norm_cfg, 10 | norm_eval=False, 11 | extra=dict( 12 | stage1=dict( 13 | num_modules=1, 14 | num_branches=1, 15 | block='BOTTLENECK', 16 | num_blocks=(4, ), 17 | num_channels=(64, )), 18 | stage2=dict( 19 | num_modules=1, 20 | num_branches=2, 21 | block='BASIC', 22 | num_blocks=(4, 4), 23 | num_channels=(18, 36)), 24 | stage3=dict( 25 | num_modules=4, 26 | num_branches=3, 27 | block='BASIC', 28 | num_blocks=(4, 4, 4), 29 | num_channels=(18, 36, 72)), 30 | stage4=dict( 31 | num_modules=3, 32 | num_branches=4, 33 | block='BASIC', 34 | num_blocks=(4, 4, 4, 4), 35 | num_channels=(18, 36, 72, 144)))), 36 | hr_settings=dict( 37 | hr_scale=(1024,1024), 38 | scale_ratio=(0.75,1.25), 39 | divisible=8, 40 | lr_loss_weight=0, 41 | hr_loss_weight=0.1, 42 | fuse_mode = 'simple', 43 | crop_num = 2, 44 | ), 45 | decode_head=dict( 46 | type='FCNHead', 47 | in_channels=sum([18, 36, 72, 144]), 48 | in_index=0, 49 | channels=64, 50 | kernel_size=7, 51 | num_convs=1, 52 | compress=True, 53 | concat_input=False, 54 | dropout_ratio=-1, 55 | num_classes=4, 56 | norm_cfg=norm_cfg, 57 | align_corners=False, 58 | loss_decode=dict( 59 | type='BinaryLoss', loss_type='dice', loss_weight=1.0, smooth=1e-5) 60 | ),# model training and testing settings 61 | train_cfg = dict(), 62 | test_cfg = dict(mode='whole',compute_aupr=True) 63 | ) 64 | -------------------------------------------------------------------------------- /configs/_base_/models/hrdecoder_fcn_hr48.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | 3 | _base_ = './hrdecoder_fcn_hr18.py' 4 | 5 | model = dict( 6 | pretrained='open-mmlab://msra/hrnetv2_w48', 7 | backbone=dict( 8 | extra=dict( 9 | stage2=dict(num_channels=(48, 96)), 10 | stage3=dict(num_channels=(48, 96, 192)), 11 | stage4=dict(num_channels=(48, 96, 192, 384)))), 12 | decode_head=dict( 13 | in_channels=sum([48, 96, 192, 384]), 14 | ) 15 | ) -------------------------------------------------------------------------------- /configs/_base_/schedules/adamw.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | # optimizer 4 | optimizer = dict( 5 | type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01) 6 | optimizer_config = dict() 7 | -------------------------------------------------------------------------------- /configs/_base_/schedules/poly10.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/lhoyer/DAFormer 2 | # --------------------------------------------------------------- 3 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 5 | # --------------------------------------------------------------- 6 | 7 | # learning policy 8 | lr_config = dict(policy='poly', power=1.0, min_lr=1e-4, by_epoch=False) 9 | -------------------------------------------------------------------------------- /configs/_base_/schedules/poly10warm.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/lhoyer/DAFormer 2 | # --------------------------------------------------------------- 3 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 5 | # --------------------------------------------------------------- 6 | 7 | # learning policy 8 | lr_config = dict( 9 | policy='poly', 10 | warmup='linear', 11 | warmup_iters=1500, 12 | warmup_ratio=1e-6, 13 | power=1.0, 14 | min_lr=0.0, 15 | by_epoch=False) 16 | -------------------------------------------------------------------------------- /configs/_base_/schedules/sgd.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | -------------------------------------------------------------------------------- /configs/lesion/efficient-hrdecoder_fcn_hr48_ddr_2048.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/efficient-hrdecoder_fcn_hr48.py', 3 | '../_base_/datasets/hr_ddr_2048.py', 4 | '../_base_/default_runtime.py', 5 | '../_base_/schedules/sgd.py', 6 | '../_base_/schedules/poly10warm.py', 7 | ] 8 | 9 | model = dict( 10 | type='EfficientHRDecoder', 11 | use_sigmoid=True, 12 | hr_settings=dict( 13 | visual_dim = 256, 14 | hr_scale = (1024,1024), 15 | scale_ratio = (0.75, 1.25), 16 | divisible = 8, 17 | lr_loss_weight = 0, 18 | hr_loss_weight = 0.1, 19 | fuse_mode = 'simple', 20 | crop_num = 4, 21 | ), 22 | decode_head=dict( 23 | type='FCNHead', 24 | in_channels=256, 25 | in_index=0, 26 | channels=64, 27 | kernel_size=7, 28 | num_convs=1, 29 | compress=True, 30 | concat_input=False, 31 | num_classes=4, 32 | align_corners=False, 33 | loss_decode=dict( 34 | type='BinaryLoss', loss_type='dice', loss_weight=1.0, smooth=1e-5) 35 | ),# model training and testing settings 36 | ) 37 | 38 | data=dict( 39 | samples_per_gpu=1, 40 | workers_per_gpu=1, 41 | ) 42 | seed = 2 43 | runner = dict(type='IterBasedRunner', max_iters=40000) 44 | checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) 45 | evaluation = dict(interval=4000, 46 | metric='mIoU', 47 | priority='LOW', 48 | save_best='mIoU') 49 | -------------------------------------------------------------------------------- /configs/lesion/efficient-hrdecoder_fcn_hr48_idrid_2880x1920-slide.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/efficient-hrdecoder_fcn_hr48.py', 3 | '../_base_/datasets/hr_idrid_2880x1920-slide.py', 4 | '../_base_/default_runtime.py', 5 | '../_base_/schedules/sgd.py', 6 | '../_base_/schedules/poly10warm.py', 7 | ] 8 | 9 | model = dict( 10 | type='EfficientHRDecoder', 11 | use_sigmoid=True, 12 | hr_settings=dict( 13 | hr_scale = (960,960), 14 | scale_ratio = (0.75, 1.25), 15 | divisible = 8, 16 | lr_loss_weight = 0, 17 | hr_loss_weight = 0.1, 18 | fuse_mode = 'simple', 19 | crop_num = 2, 20 | ), 21 | test_cfg=dict( 22 | mode='slide', 23 | stride=(960,960), 24 | crop_size=(1920,1920), 25 | compute_aupr=True, 26 | ) 27 | ) 28 | 29 | data=dict( 30 | samples_per_gpu=1, 31 | workers_per_gpu=1, 32 | ) 33 | seed = 2 34 | runner = dict(type='IterBasedRunner', max_iters=20000) 35 | checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) 36 | evaluation = dict(interval=4000, 37 | metric='mIoU', 38 | priority='LOW', 39 | save_best='mIoU') 40 | -------------------------------------------------------------------------------- /configs/lesion/fcn_hr48_ddr_1024.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/fcn_hr48.py', 3 | '../_base_/datasets/hr_ddr_1024.py', 4 | '../_base_/default_runtime.py', 5 | '../_base_/schedules/sgd.py', 6 | '../_base_/schedules/poly10warm.py' 7 | ] 8 | model = dict( 9 | use_sigmoid=True, 10 | decode_head=dict( 11 | num_classes=4, 12 | loss_decode=dict( 13 | type='BinaryLoss', 14 | loss_type='dice', 15 | use_sigmoid=False, 16 | loss_weight=1.0, 17 | smooth=1e-5) 18 | ) 19 | ) 20 | data=dict( 21 | samples_per_gpu=1, 22 | workers_per_gpu=1, 23 | ) 24 | lr=0.01 25 | optimizer = dict(lr=lr, momentum=0.9, weight_decay=0.0005) 26 | lr_config = dict(power=0.9, min_lr=lr/100) 27 | runner = dict(type='IterBasedRunner', max_iters=20000) 28 | checkpoint_config = dict(by_epoch=False, interval=4000, max_keep_ckpts=1) 29 | evaluation = dict(interval=4000, 30 | metric='mIoU', 31 | priority='LOW', 32 | save_best='mIoU') 33 | -------------------------------------------------------------------------------- /configs/lesion/fcn_hr48_idrid_1440x960.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/fcn_hr48.py', 3 | '../_base_/datasets/idrid_1440x960.py', 4 | '../_base_/default_runtime.py', 5 | '../_base_/schedules/sgd.py', 6 | '../_base_/schedules/poly10warm.py', 7 | ] 8 | model = dict( 9 | decode_head=dict( 10 | num_classes=4, 11 | loss_decode=dict( 12 | type='BinaryLoss', 13 | loss_type='dice', 14 | use_sigmoid=False, 15 | loss_weight=1.0, 16 | smooth=1e-5) 17 | ) 18 | ) 19 | data=dict( 20 | samples_per_gpu=1, 21 | workers_per_gpu=1, 22 | ) 23 | runner = dict(type='IterBasedRunner', max_iters=20000) 24 | checkpoint_config = dict(by_epoch=False, interval=4000, max_keep_ckpts=1) 25 | evaluation = dict(interval=4000, 26 | metric='mIoU', 27 | priority='LOW', 28 | save_best='mIoU') 29 | -------------------------------------------------------------------------------- /configs/lesion/hrdecoder_fcn_hr48_ddr_2048.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/hrdecoder_fcn_hr48.py', 3 | '../_base_/datasets/hr_ddr_2048.py', 4 | '../_base_/default_runtime.py', 5 | '../_base_/schedules/sgd.py', 6 | '../_base_/schedules/poly10warm.py', 7 | ] 8 | 9 | model = dict( 10 | type= 'HRDecoder', 11 | hr_settings=dict( 12 | hr_scale = (1024,1024), 13 | scale_ratio = (0.75, 1.25), 14 | divisible = 8, 15 | lr_loss_weight = 0, 16 | hr_loss_weight = 0.1, 17 | fuse_mode = 'simple', 18 | crop_num = 4, 19 | ), 20 | ) 21 | 22 | data=dict( 23 | samples_per_gpu=1, 24 | workers_per_gpu=1, 25 | ) 26 | seed = 14 27 | runner = dict(type='IterBasedRunner', max_iters=40000) 28 | checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) 29 | evaluation = dict(interval=4000, 30 | metric='mIoU', 31 | priority='LOW', 32 | save_best='mIoU') 33 | -------------------------------------------------------------------------------- /configs/lesion/hrdecoder_fcn_hr48_idrid_2880x1920-slide.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/hrdecoder_fcn_hr48.py', 3 | '../_base_/datasets/hr_idrid_2880x1920-slide.py', 4 | '../_base_/default_runtime.py', 5 | '../_base_/schedules/sgd.py', 6 | '../_base_/schedules/poly10warm.py', 7 | ] 8 | 9 | model = dict( 10 | type= 'HRDecoder', 11 | hr_settings=dict( 12 | hr_scale = (960,960), 13 | scale_ratio = (0.75, 1.25), 14 | divisible = 8, 15 | lr_loss_weight = 0, 16 | hr_loss_weight = 0.1, 17 | fuse_mode = 'simple', 18 | crop_num = 2, 19 | ), 20 | test_cfg=dict( 21 | mode='slide', 22 | stride=(960,960), 23 | crop_size=(1920,1920), 24 | compute_aupr=True, 25 | ) 26 | ) 27 | 28 | data=dict( 29 | samples_per_gpu=1, 30 | workers_per_gpu=1, 31 | ) 32 | seed = 2 33 | runner = dict(type='IterBasedRunner', max_iters=20000) 34 | checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) 35 | evaluation = dict(interval=2000, 36 | metric='mIoU', 37 | priority='LOW', 38 | save_best='mIoU') 39 | -------------------------------------------------------------------------------- /mmseg/__init__.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | 3 | from .version import __version__, version_info 4 | 5 | MMCV_MIN = '1.3.7' 6 | MMCV_MAX = '1.7.1' 7 | 8 | 9 | def digit_version(version_str): 10 | digit_version = [] 11 | for x in version_str.split('.'): 12 | if x.isdigit(): 13 | digit_version.append(int(x)) 14 | elif x.find('rc') != -1: 15 | patch_version = x.split('rc') 16 | digit_version.append(int(patch_version[0]) - 1) 17 | digit_version.append(int(patch_version[1])) 18 | return digit_version 19 | 20 | 21 | mmcv_min_version = digit_version(MMCV_MIN) 22 | mmcv_max_version = digit_version(MMCV_MAX) 23 | mmcv_version = digit_version(mmcv.__version__) 24 | 25 | 26 | assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 27 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 28 | f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' 29 | 30 | __all__ = ['__version__', 'version_info'] 31 | -------------------------------------------------------------------------------- /mmseg/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .inference import inference_segmentor, init_segmentor, show_result_pyplot 4 | from .test import multi_gpu_test, single_gpu_test 5 | from .train import get_root_logger, set_random_seed, train_segmentor 6 | 7 | __all__ = [ 8 | 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 9 | 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 10 | 'show_result_pyplot', 11 | 12 | ] 13 | -------------------------------------------------------------------------------- /mmseg/apis/inference.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - Override palette, classes, and state dict keys 4 | 5 | import matplotlib.pyplot as plt 6 | import mmcv 7 | import torch 8 | from mmcv.parallel import collate, scatter 9 | from mmcv.runner import load_checkpoint 10 | 11 | from mmseg.datasets.pipelines import Compose 12 | from mmseg.models import build_segmentor 13 | 14 | 15 | def init_segmentor(config, 16 | checkpoint=None, 17 | device='cuda:0', 18 | classes=None, 19 | palette=None, 20 | revise_checkpoint=[(r'^module\.', '')]): 21 | """Initialize a segmentor from config file. 22 | 23 | Args: 24 | config (str or :obj:`mmcv.Config`): Config file path or the config 25 | object. 26 | checkpoint (str, optional): Checkpoint path. If left as None, the model 27 | will not load any weights. 28 | device (str, optional) CPU/CUDA device option. Default 'cuda:0'. 29 | Use 'cpu' for loading model on CPU. 30 | Returns: 31 | nn.Module: The constructed segmentor. 32 | """ 33 | if isinstance(config, str): 34 | config = mmcv.Config.fromfile(config) 35 | elif not isinstance(config, mmcv.Config): 36 | raise TypeError('config must be a filename or Config object, ' 37 | 'but got {}'.format(type(config))) 38 | config.model.pretrained = None 39 | config.model.train_cfg = None 40 | model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) 41 | if checkpoint is not None: 42 | checkpoint = load_checkpoint( 43 | model, 44 | checkpoint, 45 | map_location='cpu', 46 | revise_keys=revise_checkpoint) 47 | model.CLASSES = checkpoint['meta']['CLASSES'] if classes is None \ 48 | else classes 49 | model.PALETTE = checkpoint['meta']['PALETTE'] if palette is None \ 50 | else palette 51 | model.cfg = config # save the config in the model for convenience 52 | model.to(device) 53 | model.eval() 54 | return model 55 | 56 | 57 | class LoadImage: 58 | """A simple pipeline to load image.""" 59 | 60 | def __call__(self, results): 61 | """Call function to load images into results. 62 | 63 | Args: 64 | results (dict): A result dict contains the file name 65 | of the image to be read. 66 | 67 | Returns: 68 | dict: ``results`` will be returned containing loaded image. 69 | """ 70 | 71 | if isinstance(results['img'], str): 72 | results['filename'] = results['img'] 73 | results['ori_filename'] = results['img'] 74 | else: 75 | results['filename'] = None 76 | results['ori_filename'] = None 77 | img = mmcv.imread(results['img']) 78 | results['img'] = img 79 | results['img_shape'] = img.shape 80 | results['ori_shape'] = img.shape 81 | return results 82 | 83 | 84 | def inference_segmentor(model, img): 85 | """Inference image(s) with the segmentor. 86 | 87 | Args: 88 | model (nn.Module): The loaded segmentor. 89 | imgs (str/ndarray or list[str/ndarray]): Either image files or loaded 90 | images. 91 | 92 | Returns: 93 | (list[Tensor]): The segmentation result. 94 | """ 95 | cfg = model.cfg 96 | device = next(model.parameters()).device # model device 97 | # build the data pipeline 98 | test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] 99 | test_pipeline = Compose(test_pipeline) 100 | # prepare data 101 | data = dict(img=img) 102 | data = test_pipeline(data) 103 | data = collate([data], samples_per_gpu=1) 104 | if next(model.parameters()).is_cuda: 105 | # scatter to specified GPU 106 | data = scatter(data, [device])[0] 107 | else: 108 | data['img_metas'] = [i.data[0] for i in data['img_metas']] 109 | 110 | # forward the model 111 | with torch.no_grad(): 112 | result = model(return_loss=False, rescale=True, **data) 113 | return result 114 | 115 | 116 | def show_result_pyplot(model, 117 | img, 118 | result, 119 | palette=None, 120 | fig_size=(15, 10), 121 | opacity=0.5, 122 | title='', 123 | block=True): 124 | """Visualize the segmentation results on the image. 125 | 126 | Args: 127 | model (nn.Module): The loaded segmentor. 128 | img (str or np.ndarray): Image filename or loaded image. 129 | result (list): The segmentation result. 130 | palette (list[list[int]]] | None): The palette of segmentation 131 | map. If None is given, random palette will be generated. 132 | Default: None 133 | fig_size (tuple): Figure size of the pyplot figure. 134 | opacity(float): Opacity of painted segmentation map. 135 | Default 0.5. 136 | Must be in (0, 1] range. 137 | title (str): The title of pyplot figure. 138 | Default is ''. 139 | block (bool): Whether to block the pyplot figure. 140 | Default is True. 141 | """ 142 | if hasattr(model, 'module'): 143 | model = model.module 144 | img = model.show_result( 145 | img, result, palette=palette, show=False, opacity=opacity) 146 | plt.figure(figsize=fig_size) 147 | plt.imshow(mmcv.bgr2rgb(img)) 148 | plt.title(title) 149 | plt.tight_layout() 150 | plt.show(block=block) 151 | -------------------------------------------------------------------------------- /mmseg/apis/test.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Support debug_output_attention 3 | 4 | import os.path as osp 5 | import tempfile 6 | 7 | import mmcv 8 | import numpy as np 9 | import torch 10 | from mmcv.engine import collect_results_cpu, collect_results_gpu 11 | from mmcv.image import tensor2imgs 12 | from mmcv.runner import get_dist_info 13 | 14 | 15 | def np2tmp(array, temp_file_name=None, tmpdir=None): 16 | """Save ndarray to local numpy file. 17 | 18 | Args: 19 | array (ndarray): Ndarray to save. 20 | temp_file_name (str): Numpy file name. If 'temp_file_name=None', this 21 | function will generate a file name with tempfile.NamedTemporaryFile 22 | to save ndarray. Default: None. 23 | tmpdir (str): Temporary directory to save Ndarray files. Default: None. 24 | 25 | Returns: 26 | str: The numpy file name. 27 | """ 28 | 29 | if temp_file_name is None: 30 | temp_file_name = tempfile.NamedTemporaryFile( 31 | suffix='.npy', delete=False, dir=tmpdir).name 32 | np.save(temp_file_name, array) 33 | return temp_file_name 34 | 35 | 36 | def single_gpu_test(model, 37 | data_loader, 38 | show=False, 39 | out_dir=None, 40 | efficient_test=False, 41 | opacity=0.5): 42 | """Test with single GPU. 43 | 44 | Args: 45 | model (nn.Module): Model to be tested. 46 | data_loader (utils.data.Dataloader): Pytorch data loader. 47 | show (bool): Whether show results during inference. Default: False. 48 | out_dir (str, optional): If specified, the results will be dumped into 49 | the directory to save output results. 50 | efficient_test (bool): Whether save the results as local numpy files to 51 | save CPU memory during evaluation. Default: False. 52 | opacity(float): Opacity of painted segmentation map. 53 | Default 0.5. 54 | Must be in (0, 1] range. 55 | Returns: 56 | list: The prediction results. 57 | """ 58 | model.eval() 59 | results = [] 60 | dataset = data_loader.dataset 61 | prog_bar = mmcv.ProgressBar(len(dataset)) 62 | if efficient_test: 63 | mmcv.mkdir_or_exist('.efficient_test') 64 | for i, data in enumerate(data_loader): 65 | with torch.no_grad(): 66 | result = model(return_loss=False, **data) 67 | 68 | if show or out_dir: 69 | img_tensor = data['img'][0] 70 | img_metas = data['img_metas'][0].data[0] 71 | imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) 72 | assert len(imgs) == len(img_metas) 73 | 74 | for img, img_meta in zip(imgs, img_metas): 75 | h, w, _ = img_meta['img_shape'] 76 | img_show = img[:h, :w, :] 77 | 78 | ori_h, ori_w = img_meta['ori_shape'][:-1] 79 | img_show = mmcv.imresize(img_show, (ori_w, ori_h)) 80 | 81 | if out_dir: 82 | out_file = osp.join(out_dir, img_meta['ori_filename']) 83 | else: 84 | out_file = None 85 | 86 | if hasattr(model.module.decode_head, 87 | 'debug_output_attention') and \ 88 | model.module.decode_head.debug_output_attention: 89 | # Attention debug output 90 | mmcv.imwrite(result[0] * 255, out_file) 91 | else: 92 | model.module.show_result( 93 | img_show, 94 | result, 95 | palette=dataset.PALETTE, 96 | show=show, 97 | out_file=out_file, 98 | opacity=opacity) 99 | 100 | if isinstance(result, list): 101 | if efficient_test: 102 | result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] 103 | results.extend(result) 104 | else: 105 | if efficient_test: 106 | result = np2tmp(result, tmpdir='.efficient_test') 107 | results.append(result) 108 | 109 | batch_size = len(result) 110 | for _ in range(batch_size): 111 | prog_bar.update() 112 | return results 113 | 114 | 115 | def multi_gpu_test(model, 116 | data_loader, 117 | tmpdir=None, 118 | gpu_collect=False, 119 | efficient_test=False): 120 | """Test model with multiple gpus. 121 | 122 | This method tests model with multiple gpus and collects the results 123 | under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' 124 | it encodes results to gpu tensors and use gpu communication for results 125 | collection. On cpu mode it saves the results on different gpus to 'tmpdir' 126 | and collects them by the rank 0 worker. 127 | 128 | Args: 129 | model (nn.Module): Model to be tested. 130 | data_loader (utils.data.Dataloader): Pytorch data loader. 131 | tmpdir (str): Path of directory to save the temporary results from 132 | different gpus under cpu mode. The same path is used for efficient 133 | test. 134 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 135 | efficient_test (bool): Whether save the results as local numpy files to 136 | save CPU memory during evaluation. Default: False. 137 | 138 | Returns: 139 | list: The prediction results. 140 | """ 141 | 142 | model.eval() 143 | results = [] 144 | dataset = data_loader.dataset 145 | rank, world_size = get_dist_info() 146 | if rank == 0: 147 | prog_bar = mmcv.ProgressBar(len(dataset)) 148 | if efficient_test: 149 | mmcv.mkdir_or_exist('.efficient_test') 150 | for i, data in enumerate(data_loader): 151 | with torch.no_grad(): 152 | result = model(return_loss=False, rescale=True, **data) 153 | 154 | if isinstance(result, list): 155 | if efficient_test: 156 | result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] 157 | results.extend(result) 158 | else: 159 | if efficient_test: 160 | result = np2tmp(result, tmpdir='.efficient_test') 161 | results.append(result) 162 | 163 | if rank == 0: 164 | batch_size = len(result) 165 | for _ in range(batch_size * world_size): 166 | prog_bar.update() 167 | 168 | # collect results from all ranks 169 | if gpu_collect: 170 | results = collect_results_gpu(results, len(dataset)) 171 | else: 172 | results = collect_results_cpu(results, len(dataset), tmpdir) 173 | return results 174 | -------------------------------------------------------------------------------- /mmseg/apis/train.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: 3 | # - Add ddp_wrapper from mmgen 4 | 5 | import random 6 | import warnings 7 | 8 | import mmcv 9 | import numpy as np 10 | import torch 11 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 12 | from mmcv.runner import build_optimizer, build_runner 13 | 14 | from mmseg.core import DistEvalHook, EvalHook 15 | from mmseg.core.ddp_wrapper import DistributedDataParallelWrapper 16 | from mmseg.datasets import build_dataloader, build_dataset 17 | from mmseg.utils import get_root_logger 18 | 19 | 20 | def set_random_seed(seed, deterministic=False): 21 | """Set random seed. 22 | 23 | Args: 24 | seed (int): Seed to be used. 25 | deterministic (bool): Whether to set the deterministic option for 26 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 27 | to True and `torch.backends.cudnn.benchmark` to False. 28 | Default: False. 29 | """ 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | if deterministic: 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | 38 | 39 | def train_segmentor(model, 40 | dataset, 41 | cfg, 42 | distributed=False, 43 | validate=False, 44 | timestamp=None, 45 | meta=None): 46 | """Launch segmentor training.""" 47 | logger = get_root_logger(cfg.log_level) 48 | 49 | # prepare data loaders 50 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 51 | data_loaders = [ 52 | build_dataloader( 53 | ds, 54 | cfg.data.samples_per_gpu, 55 | cfg.data.workers_per_gpu, 56 | # cfg.gpus will be ignored if distributed 57 | len(cfg.gpu_ids), 58 | dist=distributed, 59 | seed=cfg.seed, 60 | drop_last=True) for ds in dataset 61 | ] 62 | 63 | # put model on gpus 64 | if distributed: 65 | find_unused_parameters = cfg.get('find_unused_parameters', False) 66 | use_ddp_wrapper = cfg.get('use_ddp_wrapper', False) 67 | # Sets the `find_unused_parameters` parameter in 68 | # torch.nn.parallel.DistributedDataParallel 69 | if use_ddp_wrapper: 70 | mmcv.print_log('Use DDP Wrapper.', 'mmseg') 71 | model = DistributedDataParallelWrapper( 72 | model.cuda(), 73 | device_ids=[torch.cuda.current_device()], 74 | broadcast_buffers=False, 75 | find_unused_parameters=find_unused_parameters) 76 | else: 77 | model = MMDistributedDataParallel( 78 | model.cuda(), 79 | device_ids=[torch.cuda.current_device()], 80 | broadcast_buffers=False, 81 | find_unused_parameters=find_unused_parameters) 82 | else: 83 | model = MMDataParallel( 84 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 85 | # build runner 86 | optimizer = build_optimizer(model, cfg.optimizer) 87 | 88 | if cfg.get('runner') is None: 89 | cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} 90 | warnings.warn( 91 | 'config is now expected to have a `runner` section, ' 92 | 'please set `runner` in your config.', UserWarning) 93 | 94 | runner = build_runner( 95 | cfg.runner, 96 | default_args=dict( 97 | model=model, 98 | batch_processor=None, 99 | optimizer=optimizer, 100 | work_dir=cfg.work_dir, 101 | logger=logger, 102 | meta=meta)) 103 | 104 | # register hooks 105 | runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, 106 | cfg.checkpoint_config, cfg.log_config, 107 | cfg.get('momentum_config', None)) 108 | 109 | # an ugly walkaround to make the .log and .log.json filenames the same 110 | runner.timestamp = timestamp 111 | 112 | # register eval hooks 113 | if validate: 114 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 115 | val_dataloader = build_dataloader( 116 | val_dataset, 117 | samples_per_gpu=1, 118 | workers_per_gpu=cfg.data.workers_per_gpu, 119 | dist=distributed, 120 | shuffle=False) 121 | eval_cfg = cfg.get('evaluation', {}) 122 | priority = eval_cfg.pop('priority','LOW') 123 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' 124 | eval_hook = DistEvalHook if distributed else EvalHook 125 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg),priority=priority) 126 | 127 | if cfg.resume_from: 128 | runner.resume(cfg.resume_from) 129 | elif cfg.load_from: 130 | runner.load_checkpoint(cfg.load_from) 131 | runner.run(data_loaders, cfg.workflow) 132 | -------------------------------------------------------------------------------- /mmseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .evaluation import * # noqa: F401, F403 4 | from .seg import * # noqa: F401, F403 5 | from .utils import * # noqa: F401, F403 6 | -------------------------------------------------------------------------------- /mmseg/core/ddp_wrapper.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmgeneration 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel 6 | from mmcv.parallel.scatter_gather import scatter_kwargs 7 | from torch.cuda._utils import _get_device_index 8 | 9 | 10 | @MODULE_WRAPPERS.register_module('mmseg.DDPWrapper') 11 | class DistributedDataParallelWrapper(nn.Module): 12 | """A DistributedDataParallel wrapper for models in MMGeneration. 13 | 14 | In MMedting, there is a need to wrap different modules in the models 15 | with separate DistributedDataParallel. Otherwise, it will cause 16 | errors for GAN training. 17 | More specific, the GAN model, usually has two sub-modules: 18 | generator and discriminator. If we wrap both of them in one 19 | standard DistributedDataParallel, it will cause errors during training, 20 | because when we update the parameters of the generator (or discriminator), 21 | the parameters of the discriminator (or generator) is not updated, which is 22 | not allowed for DistributedDataParallel. 23 | So we design this wrapper to separately wrap DistributedDataParallel 24 | for generator and discriminator. 25 | In this wrapper, we perform two operations: 26 | 1. Wrap the modules in the models with separate MMDistributedDataParallel. 27 | Note that only modules with parameters will be wrapped. 28 | 2. Do scatter operation for 'forward', 'train_step' and 'val_step'. 29 | Note that the arguments of this wrapper is the same as those in 30 | `torch.nn.parallel.distributed.DistributedDataParallel`. 31 | Args: 32 | module (nn.Module): Module that needs to be wrapped. 33 | device_ids (list[int | `torch.device`]): Same as that in 34 | `torch.nn.parallel.distributed.DistributedDataParallel`. 35 | dim (int, optional): Same as that in the official scatter function in 36 | pytorch. Defaults to 0. 37 | broadcast_buffers (bool): Same as that in 38 | `torch.nn.parallel.distributed.DistributedDataParallel`. 39 | Defaults to False. 40 | find_unused_parameters (bool, optional): Same as that in 41 | `torch.nn.parallel.distributed.DistributedDataParallel`. 42 | Traverse the autograd graph of all tensors contained in returned 43 | value of the wrapped module’s forward function. Defaults to False. 44 | kwargs (dict): Other arguments used in 45 | `torch.nn.parallel.distributed.DistributedDataParallel`. 46 | """ 47 | 48 | def __init__(self, 49 | module, 50 | device_ids, 51 | dim=0, 52 | broadcast_buffers=False, 53 | find_unused_parameters=False, 54 | **kwargs): 55 | super().__init__() 56 | assert len(device_ids) == 1, ( 57 | 'Currently, DistributedDataParallelWrapper only supports one' 58 | 'single CUDA device for each process.' 59 | f'The length of device_ids must be 1, but got {len(device_ids)}.') 60 | self.module = module 61 | self.dim = dim 62 | self.to_ddp( 63 | device_ids=device_ids, 64 | dim=dim, 65 | broadcast_buffers=broadcast_buffers, 66 | find_unused_parameters=find_unused_parameters, 67 | **kwargs) 68 | self.output_device = _get_device_index(device_ids[0], True) 69 | 70 | def to_ddp(self, device_ids, dim, broadcast_buffers, 71 | find_unused_parameters, **kwargs): 72 | """Wrap models with separate MMDistributedDataParallel. 73 | 74 | It only wraps the modules with parameters. 75 | """ 76 | for name, module in self.module._modules.items(): 77 | if next(module.parameters(), None) is None: 78 | module = module.cuda() 79 | elif all(not p.requires_grad for p in module.parameters()): 80 | module = module.cuda() 81 | else: 82 | module = MMDistributedDataParallel( 83 | module.cuda(), 84 | device_ids=device_ids, 85 | dim=dim, 86 | broadcast_buffers=broadcast_buffers, 87 | find_unused_parameters=find_unused_parameters, 88 | **kwargs) 89 | self.module._modules[name] = module 90 | 91 | def scatter(self, inputs, kwargs, device_ids): 92 | """Scatter function. 93 | 94 | Args: 95 | inputs (Tensor): Input Tensor. 96 | kwargs (dict): Args for 97 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 98 | device_ids (int): Device id. 99 | """ 100 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 101 | 102 | def forward(self, *inputs, **kwargs): 103 | """Forward function. 104 | 105 | Args: 106 | inputs (tuple): Input data. 107 | kwargs (dict): Args for 108 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 109 | """ 110 | inputs, kwargs = self.scatter(inputs, kwargs, 111 | [torch.cuda.current_device()]) 112 | return self.module(*inputs[0], **kwargs[0]) 113 | 114 | def train_step(self, *inputs, **kwargs): 115 | """Train step function. 116 | 117 | Args: 118 | inputs (Tensor): Input Tensor. 119 | kwargs (dict): Args for 120 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 121 | """ 122 | inputs, kwargs = self.scatter(inputs, kwargs, 123 | [torch.cuda.current_device()]) 124 | output = self.module.train_step(*inputs[0], **kwargs[0]) 125 | return output 126 | 127 | def val_step(self, *inputs, **kwargs): 128 | """Validation step function. 129 | 130 | Args: 131 | inputs (tuple): Input data. 132 | kwargs (dict): Args for ``scatter_kwargs``. 133 | """ 134 | inputs, kwargs = self.scatter(inputs, kwargs, 135 | [torch.cuda.current_device()]) 136 | output = self.module.val_step(*inputs[0], **kwargs[0]) 137 | return output 138 | -------------------------------------------------------------------------------- /mmseg/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .class_names import get_classes, get_palette 4 | from .eval_hooks import DistEvalHook, EvalHook 5 | from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou 6 | 7 | from .lesion_metric import lesion_metrics 8 | 9 | __all__ = [ 10 | 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', 11 | 'eval_metrics', 'get_classes', 'get_palette','lesion_metrics' 12 | ] 13 | -------------------------------------------------------------------------------- /mmseg/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import os.path as osp 4 | 5 | import torch.distributed as dist 6 | from mmcv.runner import DistEvalHook as _DistEvalHook 7 | from mmcv.runner import EvalHook as _EvalHook 8 | from torch.nn.modules.batchnorm import _BatchNorm 9 | 10 | 11 | class EvalHook(_EvalHook): 12 | """Single GPU EvalHook, with efficient test support. 13 | 14 | Args: 15 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 16 | If set to True, it will perform by epoch. Otherwise, by iteration. 17 | Default: False. 18 | efficient_test (bool): Whether save the results as local numpy files to 19 | save CPU memory during evaluation. Default: False. 20 | Returns: 21 | list: The prediction results. 22 | """ 23 | 24 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 25 | 26 | def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): 27 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 28 | self.efficient_test = efficient_test 29 | 30 | def _do_evaluate(self, runner): 31 | """perform evaluation and save ckpt.""" 32 | if not self._should_evaluate(runner): 33 | return 34 | 35 | from mmseg.apis import single_gpu_test 36 | results = single_gpu_test( 37 | runner.model, 38 | self.dataloader, 39 | show=False, 40 | efficient_test=self.efficient_test) 41 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 42 | key_score = self.evaluate(runner, results) 43 | if self.save_best: 44 | self._save_ckpt(runner, key_score) 45 | 46 | 47 | class DistEvalHook(_DistEvalHook): 48 | """Distributed EvalHook, with efficient test support. 49 | 50 | Args: 51 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 52 | If set to True, it will perform by epoch. Otherwise, by iteration. 53 | Default: False. 54 | efficient_test (bool): Whether save the results as local numpy files to 55 | save CPU memory during evaluation. Default: False. 56 | Returns: 57 | list: The prediction results. 58 | """ 59 | 60 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 61 | 62 | def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): 63 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 64 | self.efficient_test = efficient_test 65 | 66 | def _do_evaluate(self, runner): 67 | """perform evaluation and save ckpt.""" 68 | # Synchronization of BatchNorm's buffer (running_mean 69 | # and running_var) is not supported in the DDP of pytorch, 70 | # which may cause the inconsistent performance of models in 71 | # different ranks, so we broadcast BatchNorm's buffers 72 | # of rank 0 to other ranks to avoid this. 73 | if self.broadcast_bn_buffer: 74 | model = runner.model 75 | for name, module in model.named_modules(): 76 | if isinstance(module, 77 | _BatchNorm) and module.track_running_stats: 78 | dist.broadcast(module.running_var, 0) 79 | dist.broadcast(module.running_mean, 0) 80 | 81 | if not self._should_evaluate(runner): 82 | return 83 | 84 | tmpdir = self.tmpdir 85 | if tmpdir is None: 86 | tmpdir = osp.join(runner.work_dir, '.eval_hook') 87 | 88 | from mmseg.apis import multi_gpu_test 89 | results = multi_gpu_test( 90 | runner.model, 91 | self.dataloader, 92 | tmpdir=tmpdir, 93 | gpu_collect=self.gpu_collect, 94 | efficient_test=self.efficient_test) 95 | if runner.rank == 0: 96 | print('\n') 97 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 98 | key_score = self.evaluate(runner, results) 99 | 100 | if self.save_best: 101 | self._save_ckpt(runner, key_score) 102 | -------------------------------------------------------------------------------- /mmseg/core/seg/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .builder import build_pixel_sampler 4 | from .sampler import BasePixelSampler, OHEMPixelSampler 5 | 6 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] 7 | -------------------------------------------------------------------------------- /mmseg/core/seg/builder.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from mmcv.utils import Registry, build_from_cfg 4 | 5 | PIXEL_SAMPLERS = Registry('pixel sampler') 6 | 7 | 8 | def build_pixel_sampler(cfg, **default_args): 9 | """Build pixel sampler for segmentation map.""" 10 | return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) 11 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .base_pixel_sampler import BasePixelSampler 4 | from .ohem_pixel_sampler import OHEMPixelSampler 5 | 6 | __all__ = ['BasePixelSampler', 'OHEMPixelSampler'] 7 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | 6 | class BasePixelSampler(metaclass=ABCMeta): 7 | """Base class of pixel sampler.""" 8 | 9 | def __init__(self, **kwargs): 10 | pass 11 | 12 | @abstractmethod 13 | def sample(self, seg_logit, seg_label): 14 | """Placeholder for sample function.""" 15 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/ohem_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from ..builder import PIXEL_SAMPLERS 7 | from .base_pixel_sampler import BasePixelSampler 8 | 9 | 10 | @PIXEL_SAMPLERS.register_module() 11 | class OHEMPixelSampler(BasePixelSampler): 12 | """Online Hard Example Mining Sampler for segmentation. 13 | 14 | Args: 15 | context (nn.Module): The context of sampler, subclass of 16 | :obj:`BaseDecodeHead`. 17 | thresh (float, optional): The threshold for hard example selection. 18 | Below which, are prediction with low confidence. If not 19 | specified, the hard examples will be pixels of top ``min_kept`` 20 | loss. Default: None. 21 | min_kept (int, optional): The minimum number of predictions to keep. 22 | Default: 100000. 23 | """ 24 | 25 | def __init__(self, context, thresh=None, min_kept=100000): 26 | super(OHEMPixelSampler, self).__init__() 27 | self.context = context 28 | assert min_kept > 1 29 | self.thresh = thresh 30 | self.min_kept = min_kept 31 | 32 | def sample(self, seg_logit, seg_label): 33 | """Sample pixels that have high loss or with low prediction confidence. 34 | 35 | Args: 36 | seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) 37 | seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) 38 | 39 | Returns: 40 | torch.Tensor: segmentation weight, shape (N, H, W) 41 | """ 42 | with torch.no_grad(): 43 | assert seg_logit.shape[2:] == seg_label.shape[2:] 44 | assert seg_label.shape[1] == 1 45 | seg_label = seg_label.squeeze(1).long() 46 | batch_kept = self.min_kept * seg_label.size(0) 47 | valid_mask = seg_label != self.context.ignore_index 48 | seg_weight = seg_logit.new_zeros(size=seg_label.size()) 49 | valid_seg_weight = seg_weight[valid_mask] 50 | if self.thresh is not None: 51 | seg_prob = F.softmax(seg_logit, dim=1) 52 | 53 | tmp_seg_label = seg_label.clone().unsqueeze(1) 54 | tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 55 | seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) 56 | sort_prob, sort_indices = seg_prob[valid_mask].sort() 57 | 58 | if sort_prob.numel() > 0: 59 | min_threshold = sort_prob[min(batch_kept, 60 | sort_prob.numel() - 1)] 61 | else: 62 | min_threshold = 0.0 63 | threshold = max(min_threshold, self.thresh) 64 | valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 65 | else: 66 | losses = self.context.loss_decode( 67 | seg_logit, 68 | seg_label, 69 | weight=None, 70 | ignore_index=self.context.ignore_index, 71 | reduction_override='none') 72 | # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa 73 | _, sort_indices = losses[valid_mask].sort(descending=True) 74 | valid_seg_weight[sort_indices[:batch_kept]] = 1. 75 | 76 | seg_weight[valid_mask] = valid_seg_weight 77 | 78 | return seg_weight 79 | -------------------------------------------------------------------------------- /mmseg/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .misc import add_prefix 4 | 5 | __all__ = ['add_prefix'] 6 | -------------------------------------------------------------------------------- /mmseg/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 4 | def add_prefix(inputs, prefix): 5 | """Add prefix for dict. 6 | 7 | Args: 8 | inputs (dict): The input dict with str keys. 9 | prefix (str): The prefix to add. 10 | 11 | Returns: 12 | 13 | dict: The dict with keys updated with ``prefix``. 14 | """ 15 | 16 | outputs = dict() 17 | for name, value in inputs.items(): 18 | outputs[f'{prefix}.{name}'] = value 19 | 20 | return outputs 21 | -------------------------------------------------------------------------------- /mmseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add additional datasets 3 | 4 | 5 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 6 | from .custom import CustomDataset 7 | from .dataset_wrappers import ConcatDataset, RepeatDataset 8 | 9 | from .lesion_dataset import LesionDataset 10 | 11 | __all__ = [ 12 | 'CustomDataset', 13 | 'build_dataloader', 14 | 'ConcatDataset', 15 | 'RepeatDataset', 16 | 'DATASETS', 17 | 'build_dataset', 18 | 'PIPELINES', 19 | 'CityscapesDataset', 20 | 'GTADataset', 21 | 'SynthiaDataset', 22 | 'UDADataset', 23 | 'ACDCDataset', 24 | 'DarkZurichDataset', 25 | 'GTASplitDataset', 26 | 'SYNTHIASplitDataset', 27 | 'CityscapesSplitDataset', 28 | 29 | 'LesionDataset', 30 | ] 31 | -------------------------------------------------------------------------------- /mmseg/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 4 | 5 | from .builder import DATASETS 6 | 7 | 8 | @DATASETS.register_module() 9 | class ConcatDataset(_ConcatDataset): 10 | """A wrapper of concatenated dataset. 11 | 12 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 13 | concat the group flag for image aspect ratio. 14 | 15 | Args: 16 | datasets (list[:obj:`Dataset`]): A list of datasets. 17 | """ 18 | 19 | def __init__(self, datasets): 20 | super(ConcatDataset, self).__init__(datasets) 21 | self.CLASSES = datasets[0].CLASSES 22 | self.PALETTE = datasets[0].PALETTE 23 | 24 | 25 | @DATASETS.register_module() 26 | class RepeatDataset(object): 27 | """A wrapper of repeated dataset. 28 | 29 | The length of repeated dataset will be `times` larger than the original 30 | dataset. This is useful when the data loading time is long but the dataset 31 | is small. Using RepeatDataset can reduce the data loading time between 32 | epochs. 33 | 34 | Args: 35 | dataset (:obj:`Dataset`): The dataset to be repeated. 36 | times (int): Repeat times. 37 | """ 38 | 39 | def __init__(self, dataset, times): 40 | self.dataset = dataset 41 | self.times = times 42 | self.CLASSES = dataset.CLASSES 43 | self.PALETTE = dataset.PALETTE 44 | self._ori_len = len(self.dataset) 45 | 46 | def __getitem__(self, idx): 47 | """Get item from original dataset.""" 48 | return self.dataset[idx % self._ori_len] 49 | 50 | def __len__(self): 51 | """The length is multiplied by ``times``""" 52 | return self.times * self._ori_len 53 | -------------------------------------------------------------------------------- /mmseg/datasets/lesion_dataset.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import numpy as np 4 | import torch 5 | from mmcv.utils import print_log 6 | from mmseg.core import lesion_metrics 7 | 8 | from .builder import DATASETS 9 | from .custom import CustomDataset 10 | 11 | 12 | @DATASETS.register_module() 13 | class LesionDataset(CustomDataset): 14 | 15 | #CLASSES = ['bg', 'EX', 'HE', 'SE', 'MA', 'IRMA', 'NV'] 16 | CLASSES = ['bg', 'EX', 'HE', 'SE', 'MA'] 17 | ''' 18 | PALETTE = [ 19 | [0, 0, 0], 20 | [128, 0, 0], # EX: red 21 | [0, 128, 0], # HE: green 22 | [128, 128, 0], # SE: yellow 23 | [0, 0, 128], # MA: blue 24 | [128, 0, 128], #IRMA: purple 25 | [0, 128, 128], #NV 26 | ] 27 | ''' 28 | PALETTE = [ 29 | [0, 0, 0], 30 | [128, 0, 0], # EX: red 31 | [0, 128, 0], # HE: green 32 | [128, 128, 0], # SE: yellow 33 | [0, 0, 128], # MA: blue 34 | ] 35 | 36 | def __init__(self, **kwargs): 37 | super(LesionDataset, self).__init__(**kwargs) 38 | 39 | def evaluate(self, results, metric='mIoU', evaluate_per_image=False,logger=None, **kwargs): 40 | # return super(LesionDataset, self).evaluate(results, metric, logger, **kwargs) 41 | return self._evaluate(results, metric, evaluate_per_image, logger, **kwargs) 42 | 43 | def _evaluate(self, results, metric='mIoU', evaluate_per_image=False, logger=None, **kwargs): 44 | """Evaluate the dataset. 45 | 46 | Args: 47 | results (list): Testing results of the dataset. 48 | metric (str | list[str]): Metrics to be evaluated. 49 | logger (logging.Logger | None | str): Logger used for printing 50 | related information during evaluation. Default: None. 51 | 52 | Returns: 53 | dict[str, float]: Default metrics. 54 | """ 55 | if not isinstance(metric, str): 56 | assert len(metric) == 1 57 | metric = metric[0] 58 | allowed_metrics = ['mIoU', 'mIoU2'] 59 | if metric not in allowed_metrics: 60 | raise KeyError('metric {} is not supported'.format(metric)) 61 | 62 | eval_results = {} 63 | gt_seg_maps = self.get_gt_seg_maps() 64 | if self.CLASSES is None: 65 | num_classes = len( 66 | reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) 67 | else: 68 | num_classes = len(self.CLASSES) 69 | 70 | #======================================================== 71 | iou, f1, ppv, s, aupr , ae = lesion_metrics( 72 | results, gt_seg_maps, num_classes, ignore_index=self.ignore_index) # evaluate 73 | summary_str = '' 74 | summary_str += 'per class results:\n' 75 | 76 | line_format = '{:<15} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}\n' 77 | summary_str += line_format.format('Class', 'IoU', 'F1', 'PPV', 'S', 'AUPR','MAE(e-2)') 78 | #======================================================== 79 | if self.CLASSES is None: 80 | class_names = tuple(range(num_classes)) 81 | else: 82 | class_names = self.CLASSES 83 | for i in range(num_classes): 84 | ppv_str = '{:.2f}'.format(ppv[i] * 100) 85 | s_str = '{:.2f}'.format(s[i] * 100) 86 | f1_str = '{:.2f}'.format(f1[i] * 100) 87 | iou_str = '{:.2f}'.format(iou[i] * 100) 88 | aupr_str = '{:.2f}'.format(aupr[i] * 100) 89 | #======================================================== 90 | ae_str = '{:.2f}'.format(ae[i]*100) 91 | summary_str += line_format.format(class_names[i], iou_str, f1_str, ppv_str, s_str, aupr_str,ae_str) 92 | 93 | 94 | mIoU = np.nanmean(np.nan_to_num(iou[-4:], nan=0)) 95 | mF1 = np.nanmean(np.nan_to_num(f1[-4:], nan=0)) 96 | mPPV = np.nanmean(np.nan_to_num(ppv[-4:], nan=0)) 97 | mS = np.nanmean(np.nan_to_num(s[-4:], nan=0)) 98 | mAUPR = np.nanmean(np.nan_to_num(aupr[-4:], nan=0)) 99 | mMAE = np.nanmean(np.nan_to_num(ae[-4:],nan=0)) 100 | 101 | summary_str += 'Summary:\n' 102 | line_format = '{:<15} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}\n' 103 | summary_str += line_format.format('Scope', 'mIoU', 'mF1', 'mPPV', 'mS', 'mAUPR','mMAE(e-2)') 104 | 105 | iou_str = '{:.2f}'.format(mIoU * 100) 106 | f1_str = '{:.2f}'.format(mF1 * 100) 107 | ppv_str = '{:.2f}'.format(mPPV * 100) 108 | s_str = '{:.2f}'.format(mS * 100) 109 | aupr_str = '{:.2f}'.format(mAUPR * 100) 110 | ae_str = '{:.2f}'.format(mMAE*100) 111 | summary_str += line_format.format('global', iou_str, f1_str, ppv_str, s_str, aupr_str,ae_str) 112 | 113 | eval_results['mIoU'] = mIoU 114 | eval_results['mF1'] = mF1 115 | eval_results['mPPV'] = mPPV 116 | eval_results['mS'] = mS 117 | eval_results['mAUPR'] = mAUPR 118 | eval_results['mMAE'] = mMAE 119 | 120 | # NEW: for two classes metric 121 | if metric == 'mIoU2': 122 | summary_str += '\n' 123 | 124 | print_log(summary_str, logger) 125 | if hasattr(torch.cuda, 'empty_cache'): 126 | torch.cuda.empty_cache() 127 | 128 | return eval_results 129 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from .compose import Compose 4 | from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, 5 | Transpose, to_tensor) 6 | from .loading import LoadAnnotations, LoadImageFromFile 7 | from .test_time_aug import MultiScaleFlipAug 8 | from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, 9 | PhotoMetricDistortion, RandomCrop, RandomFlip, 10 | RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) 11 | 12 | __all__ = [ 13 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 14 | 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 15 | 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 16 | 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 17 | 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 18 | ] 19 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import collections 4 | 5 | from mmcv.utils import build_from_cfg 6 | 7 | from ..builder import PIPELINES 8 | 9 | 10 | @PIPELINES.register_module() 11 | class Compose(object): 12 | """Compose multiple transforms sequentially. 13 | 14 | Args: 15 | transforms (Sequence[dict | callable]): Sequence of transform object or 16 | config dict to be composed. 17 | """ 18 | 19 | def __init__(self, transforms): 20 | assert isinstance(transforms, collections.abc.Sequence) 21 | self.transforms = [] 22 | for transform in transforms: 23 | if isinstance(transform, dict): 24 | transform = build_from_cfg(transform, PIPELINES) 25 | self.transforms.append(transform) 26 | elif callable(transform): 27 | self.transforms.append(transform) 28 | else: 29 | raise TypeError('transform must be callable or a dict') 30 | 31 | def __call__(self, data): 32 | """Call function to apply transforms sequentially. 33 | 34 | Args: 35 | data (dict): A result dict contains the data to transform. 36 | 37 | Returns: 38 | dict: Transformed data. 39 | """ 40 | 41 | for t in self.transforms: 42 | data = t(data) 43 | if data is None: 44 | return None 45 | return data 46 | 47 | def __repr__(self): 48 | format_string = self.__class__.__name__ + '(' 49 | for t in self.transforms: 50 | format_string += '\n' 51 | format_string += f' {t}' 52 | format_string += '\n)' 53 | return format_string 54 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import os.path as osp 4 | 5 | import mmcv 6 | import numpy as np 7 | 8 | from ..builder import PIPELINES 9 | 10 | 11 | @PIPELINES.register_module() 12 | class LoadImageFromFile(object): 13 | """Load an image from file. 14 | 15 | Required keys are "img_prefix" and "img_info" (a dict that must contain the 16 | key "filename"). Added or updated keys are "filename", "img", "img_shape", 17 | "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), 18 | "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). 19 | 20 | Args: 21 | to_float32 (bool): Whether to convert the loaded image to a float32 22 | numpy array. If set to False, the loaded image is an uint8 array. 23 | Defaults to False. 24 | color_type (str): The flag argument for :func:`mmcv.imfrombytes`. 25 | Defaults to 'color'. 26 | file_client_args (dict): Arguments to instantiate a FileClient. 27 | See :class:`mmcv.fileio.FileClient` for details. 28 | Defaults to ``dict(backend='disk')``. 29 | imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 30 | 'cv2' 31 | """ 32 | 33 | def __init__(self, 34 | to_float32=False, 35 | color_type='color', 36 | file_client_args=dict(backend='disk'), 37 | imdecode_backend='cv2'): 38 | self.to_float32 = to_float32 39 | self.color_type = color_type 40 | self.file_client_args = file_client_args.copy() 41 | self.file_client = None 42 | self.imdecode_backend = imdecode_backend 43 | 44 | def __call__(self, results): 45 | """Call functions to load image and get image meta information. 46 | 47 | Args: 48 | results (dict): Result dict from :obj:`mmseg.CustomDataset`. 49 | 50 | Returns: 51 | dict: The dict contains loaded image and meta information. 52 | """ 53 | 54 | if self.file_client is None: 55 | self.file_client = mmcv.FileClient(**self.file_client_args) 56 | 57 | if results.get('img_prefix') is not None: 58 | filename = osp.join(results['img_prefix'], 59 | results['img_info']['filename']) 60 | else: 61 | filename = results['img_info']['filename'] 62 | 63 | img_bytes = self.file_client.get(filename) 64 | img = mmcv.imfrombytes( 65 | img_bytes, flag=self.color_type, backend=self.imdecode_backend) 66 | if self.to_float32: 67 | img = img.astype(np.float32) 68 | 69 | results['filename'] = filename 70 | results['ori_filename'] = results['img_info']['filename'] 71 | results['img'] = img 72 | results['img_shape'] = img.shape 73 | results['ori_shape'] = img.shape 74 | # Set initial values for default meta_keys 75 | results['pad_shape'] = img.shape 76 | results['scale_factor'] = 1.0 77 | num_channels = 1 if len(img.shape) < 3 else img.shape[2] 78 | results['img_norm_cfg'] = dict( 79 | mean=np.zeros(num_channels, dtype=np.float32), 80 | std=np.ones(num_channels, dtype=np.float32), 81 | to_rgb=False) 82 | 83 | return results 84 | 85 | def __repr__(self): 86 | repr_str = self.__class__.__name__ 87 | repr_str += f'(to_float32={self.to_float32},' 88 | repr_str += f"color_type='{self.color_type}'," 89 | repr_str += f"imdecode_backend='{self.imdecode_backend}')" 90 | return repr_str 91 | 92 | 93 | @PIPELINES.register_module() 94 | class LoadAnnotations(object): 95 | """Load annotations for semantic segmentation. 96 | 97 | Args: 98 | reduce_zero_label (bool): Whether reduce all label value by 1. 99 | Usually used for datasets where 0 is background label. 100 | Default: False. 101 | file_client_args (dict): Arguments to instantiate a FileClient. 102 | See :class:`mmcv.fileio.FileClient` for details. 103 | Defaults to ``dict(backend='disk')``. 104 | imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 105 | 'pillow' 106 | """ 107 | 108 | def __init__(self, 109 | reduce_zero_label=False, 110 | file_client_args=dict(backend='disk'), 111 | imdecode_backend='pillow'): 112 | self.reduce_zero_label = reduce_zero_label 113 | self.file_client_args = file_client_args.copy() 114 | self.file_client = None 115 | self.imdecode_backend = imdecode_backend 116 | 117 | def __call__(self, results): 118 | """Call function to load multiple types annotations. 119 | 120 | Args: 121 | results (dict): Result dict from :obj:`mmseg.CustomDataset`. 122 | 123 | Returns: 124 | dict: The dict contains loaded semantic segmentation annotations. 125 | """ 126 | if self.file_client is None: 127 | self.file_client = mmcv.FileClient(**self.file_client_args) 128 | 129 | if results.get('seg_prefix', None) is not None: 130 | filename = osp.join(results['seg_prefix'], 131 | results['ann_info']['seg_map']) 132 | else: 133 | filename = results['ann_info']['seg_map'] 134 | img_bytes = self.file_client.get(filename) 135 | gt_semantic_seg = mmcv.imfrombytes( 136 | img_bytes, flag='unchanged', 137 | backend=self.imdecode_backend).squeeze().astype(np.uint8) 138 | # modify if custom classes 139 | if results.get('label_map', None) is not None: 140 | for old_id, new_id in results['label_map'].items(): 141 | gt_semantic_seg[gt_semantic_seg == old_id] = new_id 142 | # reduce zero_label 143 | if self.reduce_zero_label: 144 | # avoid using underflow conversion 145 | gt_semantic_seg[gt_semantic_seg == 0] = 255 146 | gt_semantic_seg = gt_semantic_seg - 1 147 | gt_semantic_seg[gt_semantic_seg == 254] = 255 148 | results['gt_semantic_seg'] = gt_semantic_seg 149 | results['seg_fields'].append('gt_semantic_seg') 150 | return results 151 | 152 | def __repr__(self): 153 | repr_str = self.__class__.__name__ 154 | repr_str += f'(reduce_zero_label={self.reduce_zero_label},' 155 | repr_str += f"imdecode_backend='{self.imdecode_backend}')" 156 | return repr_str 157 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/test_time_aug.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import warnings 4 | 5 | import mmcv 6 | 7 | from ..builder import PIPELINES 8 | from .compose import Compose 9 | 10 | 11 | @PIPELINES.register_module() 12 | class MultiScaleFlipAug(object): 13 | """Test-time augmentation with multiple scales and flipping. 14 | 15 | An example configuration is as followed: 16 | 17 | .. code-block:: 18 | 19 | img_scale=(2048, 1024), 20 | img_ratios=[0.5, 1.0], 21 | flip=True, 22 | transforms=[ 23 | dict(type='Resize', keep_ratio=True), 24 | dict(type='RandomFlip'), 25 | dict(type='Normalize', **img_norm_cfg), 26 | dict(type='Pad', size_divisor=32), 27 | dict(type='ImageToTensor', keys=['img']), 28 | dict(type='Collect', keys=['img']), 29 | ] 30 | 31 | After MultiScaleFLipAug with above configuration, the results are wrapped 32 | into lists of the same length as followed: 33 | 34 | .. code-block:: 35 | 36 | dict( 37 | img=[...], 38 | img_shape=[...], 39 | scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] 40 | flip=[False, True, False, True] 41 | ... 42 | ) 43 | 44 | Args: 45 | transforms (list[dict]): Transforms to apply in each augmentation. 46 | img_scale (None | tuple | list[tuple]): Images scales for resizing. 47 | img_ratios (float | list[float]): Image ratios for resizing 48 | flip (bool): Whether apply flip augmentation. Default: False. 49 | flip_direction (str | list[str]): Flip augmentation directions, 50 | options are "horizontal" and "vertical". If flip_direction is list, 51 | multiple flip augmentations will be applied. 52 | It has no effect when flip == False. Default: "horizontal". 53 | """ 54 | 55 | def __init__(self, 56 | transforms, 57 | img_scale, 58 | img_ratios=None, 59 | flip=False, 60 | flip_direction='horizontal'): 61 | self.transforms = Compose(transforms) 62 | if img_ratios is not None: 63 | img_ratios = img_ratios if isinstance(img_ratios, 64 | list) else [img_ratios] 65 | assert mmcv.is_list_of(img_ratios, float) 66 | if img_scale is None: 67 | # mode 1: given img_scale=None and a range of image ratio 68 | self.img_scale = None 69 | assert mmcv.is_list_of(img_ratios, float) 70 | elif isinstance(img_scale, tuple) and mmcv.is_list_of( 71 | img_ratios, float): 72 | assert len(img_scale) == 2 73 | # mode 2: given a scale and a range of image ratio 74 | self.img_scale = [(int(img_scale[0] * ratio), 75 | int(img_scale[1] * ratio)) 76 | for ratio in img_ratios] 77 | else: 78 | # mode 3: given multiple scales 79 | self.img_scale = img_scale if isinstance(img_scale, 80 | list) else [img_scale] 81 | assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None 82 | self.flip = flip 83 | self.img_ratios = img_ratios 84 | self.flip_direction = flip_direction if isinstance( 85 | flip_direction, list) else [flip_direction] 86 | assert mmcv.is_list_of(self.flip_direction, str) 87 | if not self.flip and self.flip_direction != ['horizontal']: 88 | warnings.warn( 89 | 'flip_direction has no effect when flip is set to False') 90 | if (self.flip 91 | and not any([t['type'] == 'RandomFlip' for t in transforms])): 92 | warnings.warn( 93 | 'flip has no effect when RandomFlip is not in transforms') 94 | 95 | def __call__(self, results): 96 | """Call function to apply test time augment transforms on results. 97 | 98 | Args: 99 | results (dict): Result dict contains the data to transform. 100 | 101 | Returns: 102 | dict[str: list]: The augmented data, where each value is wrapped 103 | into a list. 104 | """ 105 | 106 | aug_data = [] 107 | if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): 108 | h, w = results['img'].shape[:2] 109 | img_scale = [(int(w * ratio), int(h * ratio)) 110 | for ratio in self.img_ratios] 111 | else: 112 | img_scale = self.img_scale 113 | flip_aug = [False, True] if self.flip else [False] 114 | for scale in img_scale: 115 | for flip in flip_aug: 116 | for direction in self.flip_direction: 117 | _results = results.copy() 118 | _results['scale'] = scale 119 | _results['flip'] = flip 120 | _results['flip_direction'] = direction 121 | data = self.transforms(_results) 122 | aug_data.append(data) 123 | # list of dict to dict of list 124 | aug_data_dict = {key: [] for key in aug_data[0]} 125 | for data in aug_data: 126 | for key, val in data.items(): 127 | aug_data_dict[key].append(val) 128 | return aug_data_dict 129 | 130 | def __repr__(self): 131 | repr_str = self.__class__.__name__ 132 | repr_str += f'(transforms={self.transforms}, ' 133 | repr_str += f'img_scale={self.img_scale}, flip={self.flip})' 134 | repr_str += f'flip_direction={self.flip_direction}' 135 | return repr_str 136 | -------------------------------------------------------------------------------- /mmseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * # noqa: F401,F403 2 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, 3 | build_backbone, build_head, build_loss, build_segmentor) 4 | from .decode_heads import * # noqa: F401,F403 5 | from .losses import * # noqa: F401,F403 6 | from .necks import * # noqa: F401,F403 7 | from .segmentors import * # noqa: F401,F403 8 | 9 | __all__ = [ 10 | 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', 11 | 'build_head', 'build_loss', 'build_segmentor' 12 | ] 13 | -------------------------------------------------------------------------------- /mmseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add additional backbones 3 | 4 | # from .mix_transformer import (MixVisionTransformer, mit_b0, mit_b1, mit_b2, 5 | # mit_b3, mit_b4, mit_b5) 6 | 7 | from .resnest import ResNeSt 8 | from .resnet import ResNet, ResNetV1c, ResNetV1d 9 | from .resnext import ResNeXt 10 | 11 | from .vit_det import SAMImageEncoderViT 12 | from .vit_adapter import ViTAdapter 13 | 14 | from .hrnet import HRNet 15 | from .unet import UNet 16 | from .mit import (MixVisionTransformer, mit_b0, mit_b1, 17 | mit_b2, mit_b3, mit_b4, mit_b5) 18 | from .swin import SwinTransformer 19 | 20 | __all__ = [ 21 | 'ResNet', 22 | 'ResNetV1c', 23 | 'ResNetV1d', 24 | 'ResNeXt', 25 | 'ResNeSt', 26 | 'MixVisionTransformer', 27 | 'mit_b0', 28 | 'mit_b1', 29 | 'mit_b2', 30 | 'mit_b3', 31 | 'mit_b4', 32 | 'mit_b5', 33 | 'SAMImageEncoderViT', 34 | 'ViTAdapter', 35 | 'HRNet','UNet', 36 | 'SwinTransformer' 37 | ] 38 | -------------------------------------------------------------------------------- /mmseg/models/backbones/resnext.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import math 4 | 5 | from mmcv.cnn import build_conv_layer, build_norm_layer 6 | 7 | from ..builder import BACKBONES 8 | from ..utils import ResLayer 9 | from .resnet import Bottleneck as _Bottleneck 10 | from .resnet import ResNet 11 | 12 | 13 | class Bottleneck(_Bottleneck): 14 | """Bottleneck block for ResNeXt. 15 | 16 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is 17 | "caffe", the stride-two layer is the first 1x1 conv layer. 18 | """ 19 | 20 | def __init__(self, 21 | inplanes, 22 | planes, 23 | groups=1, 24 | base_width=4, 25 | base_channels=64, 26 | **kwargs): 27 | super(Bottleneck, self).__init__(inplanes, planes, **kwargs) 28 | 29 | if groups == 1: 30 | width = self.planes 31 | else: 32 | width = math.floor(self.planes * 33 | (base_width / base_channels)) * groups 34 | 35 | self.norm1_name, norm1 = build_norm_layer( 36 | self.norm_cfg, width, postfix=1) 37 | self.norm2_name, norm2 = build_norm_layer( 38 | self.norm_cfg, width, postfix=2) 39 | self.norm3_name, norm3 = build_norm_layer( 40 | self.norm_cfg, self.planes * self.expansion, postfix=3) 41 | 42 | self.conv1 = build_conv_layer( 43 | self.conv_cfg, 44 | self.inplanes, 45 | width, 46 | kernel_size=1, 47 | stride=self.conv1_stride, 48 | bias=False) 49 | self.add_module(self.norm1_name, norm1) 50 | fallback_on_stride = False 51 | self.with_modulated_dcn = False 52 | if self.with_dcn: 53 | fallback_on_stride = self.dcn.pop('fallback_on_stride', False) 54 | if not self.with_dcn or fallback_on_stride: 55 | self.conv2 = build_conv_layer( 56 | self.conv_cfg, 57 | width, 58 | width, 59 | kernel_size=3, 60 | stride=self.conv2_stride, 61 | padding=self.dilation, 62 | dilation=self.dilation, 63 | groups=groups, 64 | bias=False) 65 | else: 66 | assert self.conv_cfg is None, 'conv_cfg must be None for DCN' 67 | self.conv2 = build_conv_layer( 68 | self.dcn, 69 | width, 70 | width, 71 | kernel_size=3, 72 | stride=self.conv2_stride, 73 | padding=self.dilation, 74 | dilation=self.dilation, 75 | groups=groups, 76 | bias=False) 77 | 78 | self.add_module(self.norm2_name, norm2) 79 | self.conv3 = build_conv_layer( 80 | self.conv_cfg, 81 | width, 82 | self.planes * self.expansion, 83 | kernel_size=1, 84 | bias=False) 85 | self.add_module(self.norm3_name, norm3) 86 | 87 | 88 | @BACKBONES.register_module() 89 | class ResNeXt(ResNet): 90 | """ResNeXt backbone. 91 | 92 | Args: 93 | depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. 94 | in_channels (int): Number of input image channels. Normally 3. 95 | num_stages (int): Resnet stages, normally 4. 96 | groups (int): Group of resnext. 97 | base_width (int): Base width of resnext. 98 | strides (Sequence[int]): Strides of the first block of each stage. 99 | dilations (Sequence[int]): Dilation of each stage. 100 | out_indices (Sequence[int]): Output from which stages. 101 | style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two 102 | layer is the 3x3 conv layer, otherwise the stride-two layer is 103 | the first 1x1 conv layer. 104 | frozen_stages (int): Stages to be frozen (all param fixed). -1 means 105 | not freezing any parameters. 106 | norm_cfg (dict): dictionary to construct and config norm layer. 107 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 108 | freeze running stats (mean and var). Note: Effect on Batch Norm 109 | and its variants only. 110 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 111 | memory while slowing down the training speed. 112 | zero_init_residual (bool): whether to use zero init for last norm layer 113 | in resblocks to let them behave as identity. 114 | 115 | Example: 116 | >>> from mmseg.models import ResNeXt 117 | >>> import torch 118 | >>> self = ResNeXt(depth=50) 119 | >>> self.eval() 120 | >>> inputs = torch.rand(1, 3, 32, 32) 121 | >>> level_outputs = self.forward(inputs) 122 | >>> for level_out in level_outputs: 123 | ... print(tuple(level_out.shape)) 124 | (1, 256, 8, 8) 125 | (1, 512, 4, 4) 126 | (1, 1024, 2, 2) 127 | (1, 2048, 1, 1) 128 | """ 129 | 130 | arch_settings = { 131 | 50: (Bottleneck, (3, 4, 6, 3)), 132 | 101: (Bottleneck, (3, 4, 23, 3)), 133 | 152: (Bottleneck, (3, 8, 36, 3)) 134 | } 135 | 136 | def __init__(self, groups=1, base_width=4, **kwargs): 137 | self.groups = groups 138 | self.base_width = base_width 139 | super(ResNeXt, self).__init__(**kwargs) 140 | 141 | def make_res_layer(self, **kwargs): 142 | """Pack all blocks in a stage into a ``ResLayer``""" 143 | return ResLayer( 144 | groups=self.groups, 145 | base_width=self.base_width, 146 | base_channels=self.base_channels, 147 | **kwargs) 148 | -------------------------------------------------------------------------------- /mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Support UDA models 3 | 4 | import warnings 5 | 6 | from mmcv.cnn import MODELS as MMCV_MODELS 7 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 8 | from mmcv.utils import Registry 9 | 10 | MODELS = Registry('models', parent=MMCV_MODELS) 11 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 12 | 13 | BACKBONES = MODELS 14 | NECKS = MODELS 15 | HEADS = MODELS 16 | LOSSES = MODELS 17 | SEGMENTORS = MODELS 18 | 19 | def build_backbone(cfg): 20 | """Build backbone.""" 21 | return BACKBONES.build(cfg) 22 | 23 | 24 | def build_neck(cfg): 25 | """Build neck.""" 26 | return NECKS.build(cfg) 27 | 28 | 29 | def build_head(cfg): 30 | """Build head.""" 31 | return HEADS.build(cfg) 32 | 33 | 34 | def build_loss(cfg): 35 | """Build loss.""" 36 | return LOSSES.build(cfg) 37 | 38 | 39 | def build_train_model(cfg, train_cfg=None, test_cfg=None): 40 | """Build model.""" 41 | if train_cfg is not None or test_cfg is not None: 42 | warnings.warn( 43 | 'train_cfg and test_cfg is deprecated, ' 44 | 'please specify them in model', UserWarning) 45 | assert cfg.model.get('train_cfg') is None or train_cfg is None, \ 46 | 'train_cfg specified in both outer field and model field ' 47 | assert cfg.model.get('test_cfg') is None or test_cfg is None, \ 48 | 'test_cfg specified in both outer field and model field ' 49 | 50 | return SEGMENTORS.build( 51 | cfg.model, 52 | default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 53 | 54 | 55 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 56 | """Build segmentor.""" 57 | if train_cfg is not None or test_cfg is not None: 58 | warnings.warn( 59 | 'train_cfg and test_cfg is deprecated, ' 60 | 'please specify them in model', UserWarning) 61 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 62 | 'train_cfg specified in both outer field and model field ' 63 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 64 | 'test_cfg specified in both outer field and model field ' 65 | return SEGMENTORS.build( 66 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 67 | 68 | def build_attention(cfg): 69 | return MODELS.build(cfg) -------------------------------------------------------------------------------- /mmseg/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add additional decode_heads 3 | 4 | from .aspp_head import ASPPHead 5 | from .da_head import DAHead 6 | from .daformer_head import DAFormerHead 7 | from .dlv2_head import DLV2Head 8 | from .fcn_head import FCNHead 9 | 10 | from .isa_head import ISAHead 11 | from .psp_head import PSPHead 12 | from .segformer_head import SegFormerHead 13 | from .sep_aspp_head import DepthwiseSeparableASPPHead 14 | from .uper_head import UPerHead 15 | 16 | 17 | __all__ = [ 18 | 'FCNHead','PSPHead','ASPPHead','UPerHead', 19 | 'DepthwiseSeparableASPPHead','DAHead','DLV2Head', 20 | 'SegFormerHead','DAFormerHead','ISAHead' 21 | ] 22 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/aspp_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | class ASPPModule(nn.ModuleList): 13 | """Atrous Spatial Pyramid Pooling (ASPP) Module. 14 | 15 | Args: 16 | dilations (tuple[int]): Dilation rate of each layer. 17 | in_channels (int): Input channels. 18 | channels (int): Channels after modules, before conv_seg. 19 | conv_cfg (dict|None): Config of conv layers. 20 | norm_cfg (dict|None): Config of norm layers. 21 | act_cfg (dict): Config of activation layers. 22 | """ 23 | 24 | def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, 25 | act_cfg): 26 | super(ASPPModule, self).__init__() 27 | self.dilations = dilations 28 | self.in_channels = in_channels 29 | self.channels = channels 30 | self.conv_cfg = conv_cfg 31 | self.norm_cfg = norm_cfg 32 | self.act_cfg = act_cfg 33 | for dilation in dilations: 34 | self.append( 35 | ConvModule( 36 | self.in_channels, 37 | self.channels, 38 | 1 if dilation == 1 else 3, 39 | dilation=dilation, 40 | padding=0 if dilation == 1 else dilation, 41 | conv_cfg=self.conv_cfg, 42 | norm_cfg=self.norm_cfg, 43 | act_cfg=self.act_cfg)) 44 | 45 | def forward(self, x): 46 | """Forward function.""" 47 | aspp_outs = [] 48 | for aspp_module in self: 49 | aspp_outs.append(aspp_module(x)) 50 | 51 | return aspp_outs 52 | 53 | 54 | @HEADS.register_module() 55 | class ASPPHead(BaseDecodeHead): 56 | """Rethinking Atrous Convolution for Semantic Image Segmentation. 57 | 58 | This head is the implementation of `DeepLabV3 59 | `_. 60 | 61 | Args: 62 | dilations (tuple[int]): Dilation rates for ASPP module. 63 | Default: (1, 6, 12, 18). 64 | """ 65 | 66 | def __init__(self, dilations=(1, 6, 12, 18), **kwargs): 67 | super(ASPPHead, self).__init__(**kwargs) 68 | assert isinstance(dilations, (list, tuple)) 69 | self.dilations = dilations 70 | self.image_pool = nn.Sequential( 71 | nn.AdaptiveAvgPool2d(1), 72 | ConvModule( 73 | self.in_channels, 74 | self.channels, 75 | 1, 76 | conv_cfg=self.conv_cfg, 77 | norm_cfg=self.norm_cfg, 78 | act_cfg=self.act_cfg)) 79 | self.aspp_modules = ASPPModule( 80 | dilations, 81 | self.in_channels, 82 | self.channels, 83 | conv_cfg=self.conv_cfg, 84 | norm_cfg=self.norm_cfg, 85 | act_cfg=self.act_cfg) 86 | self.bottleneck = ConvModule( 87 | (len(dilations) + 1) * self.channels, 88 | self.channels, 89 | 3, 90 | padding=1, 91 | conv_cfg=self.conv_cfg, 92 | norm_cfg=self.norm_cfg, 93 | act_cfg=self.act_cfg) 94 | 95 | def forward(self, inputs): 96 | """Forward function.""" 97 | x = self._transform_inputs(inputs) 98 | aspp_outs = [ 99 | resize( 100 | self.image_pool(x), 101 | size=x.size()[2:], 102 | mode='bilinear', 103 | align_corners=self.align_corners) 104 | ] 105 | aspp_outs.extend(self.aspp_modules(x)) 106 | aspp_outs = torch.cat(aspp_outs, dim=1) 107 | output = self.bottleneck(aspp_outs) 108 | output = self.cls_seg(output) 109 | return output 110 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/da_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Support for seg_weight 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from mmcv.cnn import ConvModule, Scale 7 | from torch import nn 8 | 9 | from mmseg.core import add_prefix 10 | from ..builder import HEADS 11 | from ..utils import SelfAttentionBlock as _SelfAttentionBlock 12 | from .decode_head import BaseDecodeHead 13 | 14 | 15 | class PAM(_SelfAttentionBlock): 16 | """Position Attention Module (PAM) 17 | 18 | Args: 19 | in_channels (int): Input channels of key/query feature. 20 | channels (int): Output channels of key/query transform. 21 | """ 22 | 23 | def __init__(self, in_channels, channels): 24 | super(PAM, self).__init__( 25 | key_in_channels=in_channels, 26 | query_in_channels=in_channels, 27 | channels=channels, 28 | out_channels=in_channels, 29 | share_key_query=False, 30 | query_downsample=None, 31 | key_downsample=None, 32 | key_query_num_convs=1, 33 | key_query_norm=False, 34 | value_out_num_convs=1, 35 | value_out_norm=False, 36 | matmul_norm=False, 37 | with_out=False, 38 | conv_cfg=None, 39 | norm_cfg=None, 40 | act_cfg=None) 41 | 42 | self.gamma = Scale(0) 43 | 44 | def forward(self, x): 45 | """Forward function.""" 46 | out = super(PAM, self).forward(x, x) 47 | 48 | out = self.gamma(out) + x 49 | return out 50 | 51 | 52 | class CAM(nn.Module): 53 | """Channel Attention Module (CAM)""" 54 | 55 | def __init__(self): 56 | super(CAM, self).__init__() 57 | self.gamma = Scale(0) 58 | 59 | def forward(self, x): 60 | """Forward function.""" 61 | batch_size, channels, height, width = x.size() 62 | proj_query = x.view(batch_size, channels, -1) 63 | proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) 64 | energy = torch.bmm(proj_query, proj_key) 65 | energy_new = torch.max( 66 | energy, -1, keepdim=True)[0].expand_as(energy) - energy 67 | attention = F.softmax(energy_new, dim=-1) 68 | proj_value = x.view(batch_size, channels, -1) 69 | 70 | out = torch.bmm(attention, proj_value) 71 | out = out.view(batch_size, channels, height, width) 72 | 73 | out = self.gamma(out) + x 74 | return out 75 | 76 | 77 | @HEADS.register_module() 78 | class DAHead(BaseDecodeHead): 79 | """Dual Attention Network for Scene Segmentation. 80 | 81 | This head is the implementation of `DANet 82 | `_. 83 | 84 | Args: 85 | pam_channels (int): The channels of Position Attention Module(PAM). 86 | """ 87 | 88 | def __init__(self, pam_channels, **kwargs): 89 | super(DAHead, self).__init__(**kwargs) 90 | self.pam_channels = pam_channels 91 | self.pam_in_conv = ConvModule( 92 | self.in_channels, 93 | self.channels, 94 | 3, 95 | padding=1, 96 | conv_cfg=self.conv_cfg, 97 | norm_cfg=self.norm_cfg, 98 | act_cfg=self.act_cfg) 99 | self.pam = PAM(self.channels, pam_channels) 100 | self.pam_out_conv = ConvModule( 101 | self.channels, 102 | self.channels, 103 | 3, 104 | padding=1, 105 | conv_cfg=self.conv_cfg, 106 | norm_cfg=self.norm_cfg, 107 | act_cfg=self.act_cfg) 108 | self.pam_conv_seg = nn.Conv2d( 109 | self.channels, self.num_classes, kernel_size=1) 110 | 111 | self.cam_in_conv = ConvModule( 112 | self.in_channels, 113 | self.channels, 114 | 3, 115 | padding=1, 116 | conv_cfg=self.conv_cfg, 117 | norm_cfg=self.norm_cfg, 118 | act_cfg=self.act_cfg) 119 | self.cam = CAM() 120 | self.cam_out_conv = ConvModule( 121 | self.channels, 122 | self.channels, 123 | 3, 124 | padding=1, 125 | conv_cfg=self.conv_cfg, 126 | norm_cfg=self.norm_cfg, 127 | act_cfg=self.act_cfg) 128 | self.cam_conv_seg = nn.Conv2d( 129 | self.channels, self.num_classes, kernel_size=1) 130 | 131 | def pam_cls_seg(self, feat): 132 | """PAM feature classification.""" 133 | if self.dropout is not None: 134 | feat = self.dropout(feat) 135 | output = self.pam_conv_seg(feat) 136 | return output 137 | 138 | def cam_cls_seg(self, feat): 139 | """CAM feature classification.""" 140 | if self.dropout is not None: 141 | feat = self.dropout(feat) 142 | output = self.cam_conv_seg(feat) 143 | return output 144 | 145 | def forward(self, inputs): 146 | """Forward function.""" 147 | x = self._transform_inputs(inputs) 148 | pam_feat = self.pam_in_conv(x) 149 | pam_feat = self.pam(pam_feat) 150 | pam_feat = self.pam_out_conv(pam_feat) 151 | pam_out = self.pam_cls_seg(pam_feat) 152 | 153 | cam_feat = self.cam_in_conv(x) 154 | cam_feat = self.cam(cam_feat) 155 | cam_feat = self.cam_out_conv(cam_feat) 156 | cam_out = self.cam_cls_seg(cam_feat) 157 | 158 | feat_sum = pam_feat + cam_feat 159 | pam_cam_out = self.cls_seg(feat_sum) 160 | 161 | return pam_cam_out, pam_out, cam_out 162 | 163 | def forward_test(self, inputs, img_metas, test_cfg): 164 | """Forward function for testing, only ``pam_cam`` is used.""" 165 | return self.forward(inputs)[0] 166 | 167 | def losses(self, seg_logit, seg_label, seg_weight=None): 168 | """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" 169 | pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit 170 | loss = dict() 171 | loss.update( 172 | add_prefix( 173 | super(DAHead, self).losses(pam_cam_seg_logit, seg_label, 174 | seg_weight), 'pam_cam')) 175 | loss.update( 176 | add_prefix( 177 | super(DAHead, self).losses(pam_seg_logit, seg_label, 178 | seg_weight), 'pam')) 179 | loss.update( 180 | add_prefix( 181 | super(DAHead, self).losses(cam_seg_logit, seg_label, 182 | seg_weight), 'cam')) 183 | return loss 184 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/dlv2_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/lhoyer/DAFormer 2 | # --------------------------------------------------------------- 3 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 5 | # --------------------------------------------------------------- 6 | 7 | from ..builder import HEADS 8 | from .aspp_head import ASPPModule 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | @HEADS.register_module() 13 | class DLV2Head(BaseDecodeHead): 14 | 15 | def __init__(self, dilations=(6, 12, 18, 24), **kwargs): 16 | assert 'channels' not in kwargs 17 | assert 'dropout_ratio' not in kwargs 18 | assert 'norm_cfg' not in kwargs 19 | kwargs['channels'] = 1 20 | kwargs['dropout_ratio'] = 0 21 | kwargs['norm_cfg'] = None 22 | super(DLV2Head, self).__init__(**kwargs) 23 | del self.conv_seg 24 | assert isinstance(dilations, (list, tuple)) 25 | self.dilations = dilations 26 | self.aspp_modules = ASPPModule( 27 | dilations, 28 | self.in_channels, 29 | self.num_classes, 30 | conv_cfg=self.conv_cfg, 31 | norm_cfg=None, 32 | act_cfg=None) 33 | 34 | def forward(self, inputs): 35 | """Forward function.""" 36 | # for f in inputs: 37 | # mmcv.print_log(f'{f.shape}', 'mmseg') 38 | x = self._transform_inputs(inputs) 39 | aspp_outs = self.aspp_modules(x) 40 | out = aspp_outs[0] 41 | for i in range(len(aspp_outs) - 1): 42 | out += aspp_outs[i + 1] 43 | return out 44 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/fcn_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @HEADS.register_module() 12 | class FCNHead(BaseDecodeHead): 13 | """Fully Convolution Networks for Semantic Segmentation. 14 | 15 | This head is implemented of `FCNNet `_. 16 | 17 | Args: 18 | num_convs (int): Number of convs in the head. Default: 2. 19 | kernel_size (int): The kernel size for convs in the head. Default: 3. 20 | concat_input (bool): Whether concat the input and output of convs 21 | before classification layer. 22 | dilation (int): The dilation rate for convs in the head. Default: 1. 23 | """ 24 | 25 | def __init__(self, 26 | num_convs=2, 27 | kernel_size=3, 28 | concat_input=True, 29 | dilation=1, 30 | compress=False, 31 | **kwargs): 32 | assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) 33 | self.num_convs = num_convs 34 | self.concat_input = concat_input 35 | self.kernel_size = kernel_size 36 | self.compress = compress 37 | super(FCNHead, self).__init__(**kwargs) 38 | if num_convs == 0: 39 | assert self.in_channels == self.channels 40 | 41 | conv_padding = (kernel_size // 2) * dilation 42 | convs = [] 43 | 44 | if self.compress: 45 | convs.append( 46 | ConvModule( 47 | self.in_channels, 48 | self.channels, 49 | kernel_size=1, 50 | conv_cfg=self.conv_cfg, 51 | norm_cfg=self.norm_cfg, 52 | act_cfg=self.act_cfg)) 53 | 54 | convs.append( 55 | ConvModule( 56 | self.channels, 57 | self.channels, 58 | kernel_size=kernel_size, 59 | padding=conv_padding, 60 | dilation=dilation, 61 | conv_cfg=self.conv_cfg, 62 | norm_cfg=self.norm_cfg, 63 | act_cfg=self.act_cfg)) 64 | else: 65 | convs.append( 66 | ConvModule( 67 | self.in_channels, 68 | self.channels, 69 | kernel_size=kernel_size, 70 | padding=conv_padding, 71 | dilation=dilation, 72 | conv_cfg=self.conv_cfg, 73 | norm_cfg=self.norm_cfg, 74 | act_cfg=self.act_cfg)) 75 | 76 | for i in range(num_convs - 1): 77 | convs.append( 78 | ConvModule( 79 | self.channels, 80 | self.channels, 81 | kernel_size=kernel_size, 82 | padding=conv_padding, 83 | dilation=dilation, 84 | conv_cfg=self.conv_cfg, 85 | norm_cfg=self.norm_cfg, 86 | act_cfg=self.act_cfg)) 87 | if num_convs == 0: 88 | self.convs = nn.Identity() 89 | else: 90 | self.convs = nn.Sequential(*convs) 91 | if self.concat_input: 92 | self.conv_cat = ConvModule( 93 | self.in_channels + self.channels, 94 | self.channels, 95 | kernel_size=kernel_size, 96 | padding=kernel_size // 2, 97 | conv_cfg=self.conv_cfg, 98 | norm_cfg=self.norm_cfg, 99 | act_cfg=self.act_cfg) 100 | 101 | def forward(self, inputs): 102 | """Forward function.""" 103 | x = self._transform_inputs(inputs) 104 | output = self.convs(x) 105 | if self.concat_input: 106 | output = self.conv_cat(torch.cat([x, output], dim=1)) 107 | output = self.cls_seg(output) 108 | return output 109 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/psp_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | class PPM(nn.ModuleList): 13 | """Pooling Pyramid Module used in PSPNet. 14 | 15 | Args: 16 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 17 | Module. 18 | in_channels (int): Input channels. 19 | channels (int): Channels after modules, before conv_seg. 20 | conv_cfg (dict|None): Config of conv layers. 21 | norm_cfg (dict|None): Config of norm layers. 22 | act_cfg (dict): Config of activation layers. 23 | align_corners (bool): align_corners argument of F.interpolate. 24 | """ 25 | 26 | def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, 27 | act_cfg, align_corners, **kwargs): 28 | super(PPM, self).__init__() 29 | self.pool_scales = pool_scales 30 | self.align_corners = align_corners 31 | self.in_channels = in_channels 32 | self.channels = channels 33 | self.conv_cfg = conv_cfg 34 | self.norm_cfg = norm_cfg 35 | self.act_cfg = act_cfg 36 | for pool_scale in pool_scales: 37 | self.append( 38 | nn.Sequential( 39 | nn.AdaptiveAvgPool2d(pool_scale), 40 | ConvModule( 41 | self.in_channels, 42 | self.channels, 43 | 1, 44 | conv_cfg=self.conv_cfg, 45 | norm_cfg=self.norm_cfg, 46 | act_cfg=self.act_cfg, 47 | **kwargs))) 48 | 49 | def forward(self, x): 50 | """Forward function.""" 51 | ppm_outs = [] 52 | for ppm in self: 53 | ppm_out = ppm(x) 54 | upsampled_ppm_out = resize( 55 | ppm_out, 56 | size=x.size()[2:], 57 | mode='bilinear', 58 | align_corners=self.align_corners) 59 | ppm_outs.append(upsampled_ppm_out) 60 | return ppm_outs 61 | 62 | 63 | @HEADS.register_module() 64 | class PSPHead(BaseDecodeHead): 65 | """Pyramid Scene Parsing Network. 66 | 67 | This head is the implementation of 68 | `PSPNet `_. 69 | 70 | Args: 71 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 72 | Module. Default: (1, 2, 3, 6). 73 | """ 74 | 75 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 76 | super(PSPHead, self).__init__(**kwargs) 77 | assert isinstance(pool_scales, (list, tuple)) 78 | self.pool_scales = pool_scales 79 | self.psp_modules = PPM( 80 | self.pool_scales, 81 | self.in_channels, 82 | self.channels, 83 | conv_cfg=self.conv_cfg, 84 | norm_cfg=self.norm_cfg, 85 | act_cfg=self.act_cfg, 86 | align_corners=self.align_corners) 87 | self.bottleneck = ConvModule( 88 | self.in_channels + len(pool_scales) * self.channels, 89 | self.channels, 90 | 3, 91 | padding=1, 92 | conv_cfg=self.conv_cfg, 93 | norm_cfg=self.norm_cfg, 94 | act_cfg=self.act_cfg) 95 | 96 | def forward(self, inputs): 97 | """Forward function.""" 98 | x = self._transform_inputs(inputs) 99 | psp_outs = [x] 100 | psp_outs.extend(self.psp_modules(x)) 101 | psp_outs = torch.cat(psp_outs, dim=1) 102 | output = self.bottleneck(psp_outs) 103 | output = self.cls_seg(output) 104 | return output 105 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/segformer_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/NVlabs/SegFormer 2 | # Modifications: Model construction with loop 3 | # --------------------------------------------------------------- 4 | # Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 5 | # 6 | # This work is licensed under the NVIDIA Source Code License 7 | # --------------------------------------------------------------- 8 | # A copy of the license is available at resources/license_segformer 9 | 10 | import torch 11 | import torch.nn as nn 12 | from mmcv.cnn import ConvModule 13 | 14 | from mmseg.ops import resize 15 | from ..builder import HEADS 16 | from .decode_head import BaseDecodeHead 17 | 18 | 19 | class MLP(nn.Module): 20 | """Linear Embedding.""" 21 | 22 | def __init__(self, input_dim=2048, embed_dim=768): 23 | super().__init__() 24 | self.proj = nn.Linear(input_dim, embed_dim) 25 | 26 | def forward(self, x): 27 | x = x.flatten(2).transpose(1, 2).contiguous() 28 | x = self.proj(x) 29 | return x 30 | 31 | 32 | @HEADS.register_module() 33 | class SegFormerHead(BaseDecodeHead): 34 | """ 35 | SegFormer: Simple and Efficient Design for Semantic Segmentation with 36 | Transformers 37 | """ 38 | 39 | def __init__(self, **kwargs): 40 | super(SegFormerHead, self).__init__( 41 | input_transform='multiple_select', **kwargs) 42 | 43 | ######################################## 44 | self.conv_seg=nn.Sequential() 45 | ######################################## 46 | 47 | decoder_params = kwargs['decoder_params'] 48 | embedding_dim = decoder_params['embed_dim'] 49 | conv_kernel_size = decoder_params['conv_kernel_size'] 50 | 51 | self.linear_c = {} 52 | for i, in_channels in zip(self.in_index, self.in_channels): 53 | self.linear_c[str(i)] = MLP( 54 | input_dim=in_channels, embed_dim=embedding_dim) 55 | self.linear_c = nn.ModuleDict(self.linear_c) 56 | 57 | self.linear_fuse = ConvModule( 58 | in_channels=embedding_dim * len(self.in_index), 59 | out_channels=embedding_dim, 60 | kernel_size=conv_kernel_size, 61 | padding=0 if conv_kernel_size == 1 else conv_kernel_size // 2, 62 | norm_cfg=kwargs['norm_cfg']) 63 | 64 | self.linear_pred = nn.Conv2d( 65 | embedding_dim, self.num_classes, kernel_size=1) 66 | 67 | def forward(self, inputs): 68 | x = inputs 69 | n, _, h, w = x[-1].shape 70 | # for f in x: 71 | # print(f.shape) 72 | 73 | _c = {} 74 | for i in self.in_index: 75 | # mmcv.print_log(f'{i}: {x[i].shape}, {self.linear_c[str(i)]}') 76 | _c[i] = self.linear_c[str(i)](x[i]).permute(0, 2, 1).contiguous() 77 | _c[i] = _c[i].reshape(n, -1, x[i].shape[2], x[i].shape[3]) 78 | if i != 0: 79 | _c[i] = resize( 80 | _c[i], 81 | size=x[0].size()[2:], 82 | mode='bilinear', 83 | align_corners=False) 84 | 85 | _c = self.linear_fuse(torch.cat(list(_c.values()), dim=1)) 86 | 87 | if self.dropout is not None: 88 | x = self.dropout(_c) 89 | else: 90 | x = _c 91 | x = self.linear_pred(x) 92 | 93 | return x 94 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/sep_aspp_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .aspp_head import ASPPHead, ASPPModule 10 | 11 | 12 | class DepthwiseSeparableASPPModule(ASPPModule): 13 | """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable 14 | conv.""" 15 | 16 | def __init__(self, **kwargs): 17 | super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) 18 | for i, dilation in enumerate(self.dilations): 19 | if dilation > 1: 20 | self[i] = DepthwiseSeparableConvModule( 21 | self.in_channels, 22 | self.channels, 23 | 3, 24 | dilation=dilation, 25 | padding=dilation, 26 | norm_cfg=self.norm_cfg, 27 | act_cfg=self.act_cfg) 28 | 29 | 30 | @HEADS.register_module() 31 | class DepthwiseSeparableASPPHead(ASPPHead): 32 | """Encoder-Decoder with Atrous Separable Convolution for Semantic Image 33 | Segmentation. 34 | 35 | This head is the implementation of `DeepLabV3+ 36 | `_. 37 | 38 | Args: 39 | c1_in_channels (int): The input channels of c1 decoder. If is 0, 40 | the no decoder will be used. 41 | c1_channels (int): The intermediate channels of c1 decoder. 42 | """ 43 | 44 | def __init__(self, c1_in_channels, c1_channels, **kwargs): 45 | super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) 46 | assert c1_in_channels >= 0 47 | self.aspp_modules = DepthwiseSeparableASPPModule( 48 | dilations=self.dilations, 49 | in_channels=self.in_channels, 50 | channels=self.channels, 51 | conv_cfg=self.conv_cfg, 52 | norm_cfg=self.norm_cfg, 53 | act_cfg=self.act_cfg) 54 | if c1_in_channels > 0: 55 | self.c1_bottleneck = ConvModule( 56 | c1_in_channels, 57 | c1_channels, 58 | 1, 59 | conv_cfg=self.conv_cfg, 60 | norm_cfg=self.norm_cfg, 61 | act_cfg=self.act_cfg) 62 | else: 63 | self.c1_bottleneck = None 64 | self.sep_bottleneck = nn.Sequential( 65 | DepthwiseSeparableConvModule( 66 | self.channels + c1_channels, 67 | self.channels, 68 | 3, 69 | padding=1, 70 | norm_cfg=self.norm_cfg, 71 | act_cfg=self.act_cfg), 72 | DepthwiseSeparableConvModule( 73 | self.channels, 74 | self.channels, 75 | 3, 76 | padding=1, 77 | norm_cfg=self.norm_cfg, 78 | act_cfg=self.act_cfg)) 79 | 80 | def forward(self, inputs): 81 | """Forward function.""" 82 | x = self._transform_inputs(inputs) 83 | aspp_outs = [ 84 | resize( 85 | self.image_pool(x), 86 | size=x.size()[2:], 87 | mode='bilinear', 88 | align_corners=self.align_corners) 89 | ] 90 | aspp_outs.extend(self.aspp_modules(x)) 91 | aspp_outs = torch.cat(aspp_outs, dim=1) 92 | output = self.bottleneck(aspp_outs) 93 | if self.c1_bottleneck is not None: 94 | c1_output = self.c1_bottleneck(inputs[0]) 95 | output = resize( 96 | input=output, 97 | size=c1_output.shape[2:], 98 | mode='bilinear', 99 | align_corners=self.align_corners) 100 | output = torch.cat([output, c1_output], dim=1) 101 | output = self.sep_bottleneck(output) 102 | output = self.cls_seg(output) 103 | return output 104 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/uper_head.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | from .psp_head import PPM 11 | 12 | 13 | @HEADS.register_module() 14 | class UPerHead(BaseDecodeHead): 15 | """Unified Perceptual Parsing for Scene Understanding. 16 | 17 | This head is the implementation of `UPerNet 18 | `_. 19 | 20 | Args: 21 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 22 | Module applied on the last feature. Default: (1, 2, 3, 6). 23 | """ 24 | 25 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 26 | super(UPerHead, self).__init__( 27 | input_transform='multiple_select', **kwargs) 28 | # PSP Module 29 | self.psp_modules = PPM( 30 | pool_scales, 31 | self.in_channels[-1], 32 | self.channels, 33 | conv_cfg=self.conv_cfg, 34 | norm_cfg=self.norm_cfg, 35 | act_cfg=self.act_cfg, 36 | align_corners=self.align_corners) 37 | self.bottleneck = ConvModule( 38 | self.in_channels[-1] + len(pool_scales) * self.channels, 39 | self.channels, 40 | 3, 41 | padding=1, 42 | conv_cfg=self.conv_cfg, 43 | norm_cfg=self.norm_cfg, 44 | act_cfg=self.act_cfg) 45 | # FPN Module 46 | self.lateral_convs = nn.ModuleList() 47 | self.fpn_convs = nn.ModuleList() 48 | for in_channels in self.in_channels[:-1]: # skip the top layer 49 | l_conv = ConvModule( 50 | in_channels, 51 | self.channels, 52 | 1, 53 | conv_cfg=self.conv_cfg, 54 | norm_cfg=self.norm_cfg, 55 | act_cfg=self.act_cfg, 56 | inplace=False) 57 | fpn_conv = ConvModule( 58 | self.channels, 59 | self.channels, 60 | 3, 61 | padding=1, 62 | conv_cfg=self.conv_cfg, 63 | norm_cfg=self.norm_cfg, 64 | act_cfg=self.act_cfg, 65 | inplace=False) 66 | self.lateral_convs.append(l_conv) 67 | self.fpn_convs.append(fpn_conv) 68 | 69 | self.fpn_bottleneck = ConvModule( 70 | len(self.in_channels) * self.channels, 71 | self.channels, 72 | 3, 73 | padding=1, 74 | conv_cfg=self.conv_cfg, 75 | norm_cfg=self.norm_cfg, 76 | act_cfg=self.act_cfg) 77 | 78 | def psp_forward(self, inputs): 79 | """Forward function of PSP module.""" 80 | x = inputs[-1] 81 | psp_outs = [x] 82 | psp_outs.extend(self.psp_modules(x)) 83 | psp_outs = torch.cat(psp_outs, dim=1) 84 | output = self.bottleneck(psp_outs) 85 | 86 | return output 87 | 88 | def forward(self, inputs): 89 | """Forward function.""" 90 | 91 | inputs = self._transform_inputs(inputs) 92 | 93 | # build laterals 94 | laterals = [ 95 | lateral_conv(inputs[i]) 96 | for i, lateral_conv in enumerate(self.lateral_convs) 97 | ] 98 | 99 | laterals.append(self.psp_forward(inputs)) 100 | 101 | # build top-down path 102 | used_backbone_levels = len(laterals) 103 | for i in range(used_backbone_levels - 1, 0, -1): 104 | prev_shape = laterals[i - 1].shape[2:] 105 | laterals[i - 1] += resize( 106 | laterals[i], 107 | size=prev_shape, 108 | mode='bilinear', 109 | align_corners=self.align_corners) 110 | 111 | # build outputs 112 | fpn_outs = [ 113 | self.fpn_convs[i](laterals[i]) 114 | for i in range(used_backbone_levels - 1) 115 | ] 116 | # append psp feature 117 | fpn_outs.append(laterals[-1]) 118 | 119 | for i in range(used_backbone_levels - 1, 0, -1): 120 | fpn_outs[i] = resize( 121 | fpn_outs[i], 122 | size=fpn_outs[0].shape[2:], 123 | mode='bilinear', 124 | align_corners=self.align_corners) 125 | fpn_outs = torch.cat(fpn_outs, dim=1) 126 | output = self.fpn_bottleneck(fpn_outs) 127 | output = self.cls_seg(output) 128 | return output 129 | -------------------------------------------------------------------------------- /mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import Accuracy, accuracy 2 | from .binary_loss import BinaryLoss 3 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 4 | cross_entropy, mask_cross_entropy) 5 | from .sam_loss import SAMLoss 6 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 7 | 8 | __all__ = [ 9 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 10 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 11 | 'weight_reduce_loss', 'weighted_loss' 12 | ,'SAMLoss','BinaryLoss' 13 | ] 14 | -------------------------------------------------------------------------------- /mmseg/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def accuracy(pred, target, topk=1, thresh=None): 7 | """Calculate accuracy according to the prediction and target. 8 | 9 | Args: 10 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 11 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 12 | topk (int | tuple[int], optional): If the predictions in ``topk`` 13 | matches the target, the predictions will be regarded as 14 | correct ones. Defaults to 1. 15 | thresh (float, optional): If not None, predictions with scores under 16 | this threshold are considered incorrect. Default to None. 17 | 18 | Returns: 19 | float | tuple[float]: If the input ``topk`` is a single integer, 20 | the function will return a single float as accuracy. If 21 | ``topk`` is a tuple containing multiple integers, the 22 | function will return a tuple containing accuracies of 23 | each ``topk`` number. 24 | """ 25 | assert isinstance(topk, (int, tuple)) 26 | if isinstance(topk, int): 27 | topk = (topk, ) 28 | return_single = True 29 | else: 30 | return_single = False 31 | 32 | maxk = max(topk) 33 | if pred.size(0) == 0: 34 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 35 | return accu[0] if return_single else accu 36 | assert pred.ndim == target.ndim + 1 37 | assert pred.size(0) == target.size(0) 38 | assert maxk <= pred.size(1), \ 39 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 40 | pred_value, pred_label = pred.topk(maxk, dim=1) 41 | # transpose to shape (maxk, N, ...) 42 | pred_label = pred_label.transpose(0, 1) 43 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 44 | if thresh is not None: 45 | # Only prediction values larger than thresh are counted as correct 46 | correct = correct & (pred_value > thresh).t() 47 | res = [] 48 | for k in topk: 49 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 50 | res.append(correct_k.mul_(100.0 / target.numel())) 51 | return res[0] if return_single else res 52 | 53 | 54 | class Accuracy(nn.Module): 55 | """Accuracy calculation module.""" 56 | 57 | def __init__(self, topk=(1, ), thresh=None): 58 | """Module to calculate the accuracy. 59 | 60 | Args: 61 | topk (tuple, optional): The criterion used to calculate the 62 | accuracy. Defaults to (1,). 63 | thresh (float, optional): If not None, predictions with scores 64 | under this threshold are considered incorrect. Default to None. 65 | """ 66 | super().__init__() 67 | self.topk = topk 68 | self.thresh = thresh 69 | 70 | def forward(self, pred, target): 71 | """Forward function to calculate accuracy. 72 | 73 | Args: 74 | pred (torch.Tensor): Prediction of models. 75 | target (torch.Tensor): Target for each prediction. 76 | 77 | Returns: 78 | tuple[float]: The accuracies under different topk criterions. 79 | """ 80 | return accuracy(pred, target, self.topk, self.thresh) 81 | -------------------------------------------------------------------------------- /mmseg/models/losses/sam_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..builder import LOSSES 6 | from .utils import weight_reduce_loss 7 | 8 | def cross_entropy(input, 9 | target, 10 | weight=None, 11 | class_weight=None, 12 | reduction='mean', 13 | avg_factor=None, 14 | ignore_index=-100): 15 | """The wrapper function for :func:`F.cross_entropy`""" 16 | # class_weight is a manual rescaling weight given to each class. 17 | # If given, has to be a Tensor of size C element-wise losses 18 | loss = F.cross_entropy( 19 | input, 20 | target, 21 | weight=class_weight, 22 | reduction='none', 23 | ignore_index=ignore_index) 24 | 25 | # apply weights and do the reduction 26 | if weight is not None: 27 | weight = weight.float() 28 | loss = weight_reduce_loss( 29 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor) 30 | 31 | return loss 32 | 33 | def softCrossEntropy(input, target, reduction='mean'): 34 | log_logit = -F.log_softmax(input, dim=1) 35 | batch = input.shape[0] 36 | if reduction == 'batchmean': 37 | loss = torch.mean(torch.mul(log_logit, target)) / batch 38 | elif reduction=='mean': 39 | loss = torch.mean(torch.mul(log_logit, target)) 40 | elif reduction=='batchsum': 41 | loss = torch.sum(torch.mul(log_logit, target)) / batch 42 | elif reduction=='sum': 43 | loss = torch.sum(torch.mul(log_logit, target)) 44 | else: 45 | loss = torch.mul(log_logit, target) 46 | return loss 47 | 48 | 49 | @LOSSES.register_module() 50 | class SAMLoss(nn.Module): 51 | def __init__(self, 52 | size_average=None, 53 | reduce=None, 54 | use_kl=True, 55 | one_hot=False, 56 | reduction='batchmean', 57 | log_target=False, 58 | loss_weight=1, 59 | ): 60 | super(SAMLoss,self).__init__() 61 | self.size_average=size_average 62 | self.reduce=reduce 63 | self.use_kl=use_kl 64 | self.one_hot=one_hot 65 | self.reduction=reduction 66 | self.log_target=log_target 67 | self.loss_weight=loss_weight 68 | if self.use_kl: 69 | self.loss_func = F.kl_div 70 | elif self.one_hot: 71 | self.loss_func = cross_entropy 72 | else: 73 | self.loss_func = softCrossEntropy 74 | 75 | 76 | def forward(self, 77 | seg_logit, 78 | gt_logit, 79 | seg_weight=None, 80 | ): 81 | seg_logit=torch.log_softmax(seg_logit,dim=1)#.cpu() 82 | #gt_logit=gt_logit.cpu() 83 | if self.use_kl: 84 | loss = self.loss_weight* self.loss_func( 85 | input=seg_logit, 86 | target=gt_logit, 87 | size_average=self.size_average, 88 | reduce=self.reduce, 89 | reduction=self.reduction, 90 | log_target=self.log_target, 91 | ) 92 | elif self.one_hot: 93 | gt_logit=gt_logit.squeeze(1) 94 | loss = self.loss_weight*self.loss_func( 95 | input=seg_logit, 96 | target=gt_logit, 97 | weight=seg_weight, 98 | ) 99 | else: 100 | loss = self.loss_weight* self.loss_func( 101 | input=seg_logit, 102 | target=gt_logit, 103 | reduction=self.reduction, 104 | ) 105 | #loss = loss.to(seg_logit.device) 106 | return loss 107 | -------------------------------------------------------------------------------- /mmseg/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import functools 4 | 5 | import mmcv 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | 10 | def get_class_weight(class_weight): 11 | """Get class weight for loss function. 12 | 13 | Args: 14 | class_weight (list[float] | str | None): If class_weight is a str, 15 | take it as a file name and read from it. 16 | """ 17 | if isinstance(class_weight, str): 18 | # take it as a file path 19 | if class_weight.endswith('.npy'): 20 | class_weight = np.load(class_weight) 21 | else: 22 | # pkl, json or yaml 23 | class_weight = mmcv.load(class_weight) 24 | 25 | return class_weight 26 | 27 | 28 | def reduce_loss(loss, reduction): 29 | """Reduce loss as specified. 30 | 31 | Args: 32 | loss (Tensor): Elementwise loss tensor. 33 | reduction (str): Options are "none", "mean" and "sum". 34 | 35 | Return: 36 | Tensor: Reduced loss tensor. 37 | """ 38 | reduction_enum = F._Reduction.get_enum(reduction) 39 | # none: 0, elementwise_mean:1, sum: 2 40 | if reduction_enum == 0: 41 | return loss 42 | elif reduction_enum == 1: 43 | return loss.mean() 44 | elif reduction_enum == 2: 45 | return loss.sum() 46 | 47 | 48 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 49 | """Apply element-wise weight and reduce loss. 50 | 51 | Args: 52 | loss (Tensor): Element-wise loss. 53 | weight (Tensor): Element-wise weights. 54 | reduction (str): Same as built-in losses of PyTorch. 55 | avg_factor (float): Avarage factor when computing the mean of losses. 56 | 57 | Returns: 58 | Tensor: Processed loss values. 59 | """ 60 | # if weight is specified, apply element-wise weight 61 | if weight is not None: 62 | assert weight.dim() == loss.dim() 63 | if weight.dim() > 1: 64 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 65 | loss = loss * weight 66 | 67 | # if avg_factor is not specified, just reduce the loss 68 | if avg_factor is None: 69 | loss = reduce_loss(loss, reduction) 70 | else: 71 | # if reduction is mean, then average the loss by avg_factor 72 | if reduction == 'mean': 73 | loss = loss.sum() / avg_factor 74 | # if reduction is 'none', then do nothing, otherwise raise an error 75 | elif reduction != 'none': 76 | raise ValueError('avg_factor can not be used with reduction="sum"') 77 | return loss 78 | 79 | 80 | def weighted_loss(loss_func): 81 | """Create a weighted version of a given loss function. 82 | 83 | To use this decorator, the loss function must have the signature like 84 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 85 | element-wise loss without any reduction. This decorator will add weight 86 | and reduction arguments to the function. The decorated function will have 87 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 88 | avg_factor=None, **kwargs)`. 89 | 90 | :Example: 91 | 92 | >>> import torch 93 | >>> @weighted_loss 94 | >>> def l1_loss(pred, target): 95 | >>> return (pred - target).abs() 96 | 97 | >>> pred = torch.Tensor([0, 2, 3]) 98 | >>> target = torch.Tensor([1, 1, 1]) 99 | >>> weight = torch.Tensor([1, 0, 1]) 100 | 101 | >>> l1_loss(pred, target) 102 | tensor(1.3333) 103 | >>> l1_loss(pred, target, weight) 104 | tensor(1.) 105 | >>> l1_loss(pred, target, reduction='none') 106 | tensor([1., 1., 2.]) 107 | >>> l1_loss(pred, target, weight, avg_factor=2) 108 | tensor(1.5000) 109 | """ 110 | 111 | @functools.wraps(loss_func) 112 | def wrapper(pred, 113 | target, 114 | weight=None, 115 | reduction='mean', 116 | avg_factor=None, 117 | **kwargs): 118 | # get element-wise loss 119 | loss = loss_func(pred, target, **kwargs) 120 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 121 | return loss 122 | 123 | return wrapper 124 | -------------------------------------------------------------------------------- /mmseg/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .segformer_adapter import SegFormerAdapter 2 | 3 | from .sam_neck import SAMNeck 4 | from .segformer_neck import SegFormerNeck 5 | 6 | 7 | __all__ =[ 8 | 'SegFormerAdapter', 9 | 'SAMNeck', 10 | 'SegFormerNeck', 11 | ] 12 | -------------------------------------------------------------------------------- /mmseg/models/necks/sam_neck.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from mmseg.ops import resize 5 | from ..builder import NECKS 6 | 7 | class LayerNorm(nn.Module): 8 | """ 9 | A LayerNorm variant, popularized by Transformers, that performs point-wise mean and 10 | variance normalization over the channel dimension for inputs that have shape 11 | (batch_size, channels, height, width). 12 | https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 13 | """ 14 | 15 | def __init__(self, normalized_shape, eps=1e-6): 16 | super().__init__() 17 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 18 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 19 | self.eps = eps 20 | self.normalized_shape = (normalized_shape,) 21 | 22 | def forward(self, x): 23 | u = x.mean(1, keepdim=True) 24 | s = (x - u).pow(2).mean(1, keepdim=True) 25 | x = (x - u) / torch.sqrt(s + self.eps) 26 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 27 | return x 28 | 29 | @NECKS.register_module() 30 | class SAMNeck(nn.Module): 31 | def __init__(self, 32 | dim=256, 33 | out_channels=[64,128,320,512], 34 | use_conv=True, 35 | scale_factors=[4,2,1,0.5], 36 | norm_layer=LayerNorm): 37 | super(SAMNeck,self).__init__() 38 | self.dim=dim 39 | self.out_channels=out_channels 40 | self.use_conv=use_conv 41 | self.scale_factors=scale_factors 42 | self.norm_layer = norm_layer 43 | 44 | self.stages = nn.ModuleList() 45 | for idx, scale in enumerate(scale_factors): 46 | out_dim = dim 47 | out_channel = self.out_channels[idx] 48 | if scale == 8.0: 49 | layers = [ 50 | nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), 51 | self.norm_layer(dim//2), 52 | nn.GELU(), 53 | nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), 54 | self.norm_layer(dim//4), 55 | nn.GELU(), 56 | nn.ConvTranspose2d(dim // 4, dim // 8, kernel_size=2, stride=2), 57 | ] 58 | out_dim = dim // 8 59 | elif scale == 4.0: 60 | layers = [ 61 | nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), 62 | self.norm_layer(dim//2), 63 | nn.GELU(), 64 | nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), 65 | ] 66 | out_dim = dim // 4 67 | elif scale == 2.0: 68 | layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)] 69 | out_dim = dim // 2 70 | elif scale == 1.0: 71 | layers = [] 72 | elif scale == 0.5: 73 | layers = [nn.MaxPool2d(kernel_size=2, stride=2)] 74 | else: 75 | raise NotImplementedError(f"scale_factor={scale} is not supported yet.") 76 | 77 | layers.extend( 78 | [ 79 | nn.Conv2d( 80 | out_dim, 81 | out_channel, 82 | kernel_size=1, 83 | bias=False, 84 | ), 85 | self.norm_layer(out_channel), 86 | nn.Conv2d( 87 | out_channel, 88 | out_channel, 89 | kernel_size=3, 90 | padding=1, 91 | bias=False 92 | ), 93 | self.norm_layer(out_channel), 94 | ] 95 | ) 96 | layers = nn.Sequential(*layers) 97 | self.stages.append(layers) 98 | 99 | def forward(self,x): 100 | feature=[] 101 | for stage in self.stages: 102 | feature.append(stage(x)) 103 | return feature 104 | -------------------------------------------------------------------------------- /mmseg/models/necks/segformer_adapter.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/lhoyer/DAFormer 2 | # --------------------------------------------------------------- 3 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 5 | # --------------------------------------------------------------- 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from mmseg.ops import resize 11 | from ..builder import NECKS 12 | 13 | 14 | @NECKS.register_module() 15 | class SegFormerAdapter(nn.Module): 16 | 17 | def __init__(self, out_layers=[3], scales=[4]): 18 | super(SegFormerAdapter, self).__init__() 19 | self.out_layers = out_layers 20 | self.scales = scales 21 | 22 | def forward(self, x): 23 | _c = {} 24 | for i, s in zip(self.out_layers, self.scales): 25 | if s == 1: 26 | _c[i] = x[i] 27 | else: 28 | _c[i] = resize( 29 | x[i], scale_factor=s, mode='bilinear', align_corners=False) 30 | # mmcv.print_log(f'{i}: {x[i].shape}, {_c[i].shape}', 'mmseg') 31 | 32 | x = torch.cat(list(_c.values()), dim=1) 33 | 34 | return x 35 | -------------------------------------------------------------------------------- /mmseg/models/necks/segformer_neck.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/lhoyer/DAFormer 2 | # --------------------------------------------------------------- 3 | # Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 5 | # --------------------------------------------------------------- 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from mmseg.ops import resize 11 | from ..builder import NECKS 12 | 13 | 14 | @NECKS.register_module() 15 | class SegFormerNeck(nn.Module): 16 | 17 | def __init__(self, 18 | in_channels=[512], 19 | out_channel=1024): 20 | super(SegFormerNeck, self).__init__() 21 | self.out_channel=out_channel 22 | self.in_channels=in_channels 23 | self.dim_cat=lambda x,shape:torch.cat( 24 | [ 25 | resize( 26 | input=each_x, 27 | size=shape, 28 | mode='bilinear', 29 | align_corners=False, 30 | ) 31 | for each_x in x 32 | ], 33 | dim=1 34 | ) 35 | self.dim_fuse=nn.Conv2d( 36 | in_channels=sum(self.in_channels), 37 | out_channels=self.out_channel, 38 | kernel_size=1, 39 | ) 40 | 41 | #self.test_init() 42 | 43 | def test_init(self): 44 | for name,param in self.named_parameters(): 45 | param.requires_grad=False 46 | 47 | def forward(self, x): 48 | size= x[-2].shape[-2:] # 因为倒数第二个的scale_factor是1 49 | x = self.dim_cat(x,size) 50 | x = self.dim_fuse(x) 51 | return x 52 | -------------------------------------------------------------------------------- /mmseg/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import BaseSegmentor 3 | from .encoder_decoder import EncoderDecoder 4 | 5 | from .lesion_encoder_decoder import LesionEncoderDecoder 6 | 7 | from .HRDecoder import HRDecoder,EfficientHRDecoder 8 | 9 | __all__ = ['BaseSegmentor', 'EncoderDecoder', 10 | 'LesionEncoderDecoder', 'HRDecoder', 'EfficientHRDecoder' 11 | ] 12 | -------------------------------------------------------------------------------- /mmseg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .ckpt_convert import mit_convert,vit_convert,swin_convert 2 | from .make_divisible import make_divisible 3 | from .self_attention_block import SelfAttentionBlock 4 | from .embed import PatchEmbed 5 | from .wrappers import Upsample,resize 6 | from .res_layer import ResLayer 7 | from .shape_convert import nchw_to_nlc,nlc_to_nchw 8 | from .up_conv_block import UpConvBlock 9 | 10 | __all__ = [ 11 | 'mit_convert','vit_convert','swin_convert', 12 | 'make_divisible', 13 | 'SelfAttentionBlock', 14 | 'PatchEmbed', 15 | 'Upsample','resize', 16 | 'ResLayer', 17 | 'nchw_to_nlc','nlc_to_nchw', 18 | 'UpConvBlock' 19 | ] 20 | -------------------------------------------------------------------------------- /mmseg/models/utils/ckpt_convert.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def swin_convert(ckpt): 7 | new_ckpt = OrderedDict() 8 | 9 | def correct_unfold_reduction_order(x): 10 | out_channel, in_channel = x.shape 11 | x = x.reshape(out_channel, 4, in_channel // 4) 12 | x = x[:, [0, 2, 1, 3], :].transpose(1, 13 | 2).reshape(out_channel, in_channel) 14 | return x 15 | 16 | def correct_unfold_norm_order(x): 17 | in_channel = x.shape[0] 18 | x = x.reshape(4, in_channel // 4) 19 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) 20 | return x 21 | 22 | for k, v in ckpt.items(): 23 | if k.startswith('head'): 24 | continue 25 | elif k.startswith('layers'): 26 | new_v = v 27 | if 'attn.' in k: 28 | new_k = k.replace('attn.', 'attn.w_msa.') 29 | elif 'mlp.' in k: 30 | if 'mlp.fc1.' in k: 31 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') 32 | elif 'mlp.fc2.' in k: 33 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') 34 | else: 35 | new_k = k.replace('mlp.', 'ffn.') 36 | elif 'downsample' in k: 37 | new_k = k 38 | if 'reduction.' in k: 39 | new_v = correct_unfold_reduction_order(v) 40 | elif 'norm.' in k: 41 | new_v = correct_unfold_norm_order(v) 42 | else: 43 | new_k = k 44 | new_k = new_k.replace('layers', 'stages', 1) 45 | elif k.startswith('patch_embed'): 46 | new_v = v 47 | if 'proj' in k: 48 | new_k = k.replace('proj', 'projection') 49 | else: 50 | new_k = k 51 | else: 52 | new_v = v 53 | new_k = k 54 | 55 | new_ckpt[new_k] = new_v 56 | 57 | return new_ckpt 58 | 59 | 60 | def vit_convert(ckpt): 61 | 62 | new_ckpt = OrderedDict() 63 | 64 | for k, v in ckpt.items(): 65 | if k.startswith('head'): 66 | continue 67 | if k.startswith('norm'): 68 | new_k = k.replace('norm.', 'ln1.') 69 | elif k.startswith('patch_embed'): 70 | if 'proj' in k: 71 | new_k = k.replace('proj', 'projection') 72 | else: 73 | new_k = k 74 | elif k.startswith('blocks'): 75 | if 'norm' in k: 76 | new_k = k.replace('norm', 'ln') 77 | elif 'mlp.fc1' in k: 78 | new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') 79 | elif 'mlp.fc2' in k: 80 | new_k = k.replace('mlp.fc2', 'ffn.layers.1') 81 | elif 'attn.qkv' in k: 82 | new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') 83 | elif 'attn.proj' in k: 84 | new_k = k.replace('attn.proj', 'attn.attn.out_proj') 85 | else: 86 | new_k = k 87 | new_k = new_k.replace('blocks.', 'layers.') 88 | else: 89 | new_k = k 90 | new_ckpt[new_k] = v 91 | 92 | return new_ckpt 93 | 94 | 95 | def mit_convert(ckpt): 96 | new_ckpt = OrderedDict() 97 | # Process the concat between q linear weights and kv linear weights 98 | for k, v in ckpt.items(): 99 | if k.startswith('head'): 100 | continue 101 | elif k.startswith('patch_embed'): 102 | stage_i = int(k.split('.')[0].replace('patch_embed', '')) 103 | new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0') 104 | new_v = v 105 | if 'proj.' in new_k: 106 | new_k = new_k.replace('proj.', 'projection.') 107 | elif k.startswith('block'): 108 | stage_i = int(k.split('.')[0].replace('block', '')) 109 | new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1') 110 | new_v = v 111 | if 'attn.q.' in new_k: 112 | sub_item_k = k.replace('q.', 'kv.') 113 | new_k = new_k.replace('q.', 'attn.in_proj_') 114 | new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) 115 | elif 'attn.kv.' in new_k: 116 | continue 117 | elif 'attn.proj.' in new_k: 118 | new_k = new_k.replace('proj.', 'attn.out_proj.') 119 | elif 'attn.sr.' in new_k: 120 | new_k = new_k.replace('sr.', 'sr.') 121 | elif 'mlp.' in new_k: 122 | string = f'{new_k}-' 123 | new_k = new_k.replace('mlp.', 'ffn.layers.') 124 | if 'fc1.weight' in new_k or 'fc2.weight' in new_k: 125 | new_v = v.reshape((*v.shape, 1, 1)) 126 | new_k = new_k.replace('fc1.', '0.') 127 | new_k = new_k.replace('dwconv.dwconv.', '1.') 128 | new_k = new_k.replace('fc2.', '4.') 129 | string += f'{new_k} {v.shape}-{new_v.shape}' 130 | # print(string) 131 | elif k.startswith('norm'): 132 | stage_i = int(k.split('.')[0].replace('norm', '')) 133 | new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') 134 | new_v = v 135 | else: 136 | new_k = k 137 | new_v = v 138 | new_ckpt[new_k] = new_v 139 | return new_ckpt 140 | -------------------------------------------------------------------------------- /mmseg/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 4 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 5 | """Make divisible function. 6 | 7 | This function rounds the channel number to the nearest value that can be 8 | divisible by the divisor. It is taken from the original tf repo. It ensures 9 | that all layers have a channel number that is divisible by divisor. It can 10 | be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 11 | 12 | Args: 13 | value (int): The original channel number. 14 | divisor (int): The divisor to fully divide the channel number. 15 | min_value (int): The minimum value of the output channel. 16 | Default: None, means that the minimum value equal to the divisor. 17 | min_ratio (float): The minimum ratio of the rounded channel number to 18 | the original channel number. Default: 0.9. 19 | 20 | Returns: 21 | int: The modified output channel number. 22 | """ 23 | 24 | if min_value is None: 25 | min_value = divisor 26 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than (1-min_ratio). 28 | if new_value < min_ratio * value: 29 | new_value += divisor 30 | return new_value 31 | -------------------------------------------------------------------------------- /mmseg/models/utils/res_layer.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | from mmcv.cnn import build_conv_layer, build_norm_layer 4 | from mmcv.runner import Sequential 5 | from torch import nn as nn 6 | 7 | 8 | class ResLayer(Sequential): 9 | """ResLayer to build ResNet style backbone. 10 | 11 | Args: 12 | block (nn.Module): block used to build ResLayer. 13 | inplanes (int): inplanes of block. 14 | planes (int): planes of block. 15 | num_blocks (int): number of blocks. 16 | stride (int): stride of the first block. Default: 1 17 | avg_down (bool): Use AvgPool instead of stride conv when 18 | downsampling in the bottleneck. Default: False 19 | conv_cfg (dict): dictionary to construct and config conv layer. 20 | Default: None 21 | norm_cfg (dict): dictionary to construct and config norm layer. 22 | Default: dict(type='BN') 23 | multi_grid (int | None): Multi grid dilation rates of last 24 | stage. Default: None 25 | contract_dilation (bool): Whether contract first dilation of each layer 26 | Default: False 27 | """ 28 | 29 | def __init__(self, 30 | block, 31 | inplanes, 32 | planes, 33 | num_blocks, 34 | stride=1, 35 | dilation=1, 36 | avg_down=False, 37 | conv_cfg=None, 38 | norm_cfg=dict(type='BN'), 39 | multi_grid=None, 40 | contract_dilation=False, 41 | **kwargs): 42 | self.block = block 43 | 44 | downsample = None 45 | if stride != 1 or inplanes != planes * block.expansion: 46 | downsample = [] 47 | conv_stride = stride 48 | if avg_down: 49 | conv_stride = 1 50 | downsample.append( 51 | nn.AvgPool2d( 52 | kernel_size=stride, 53 | stride=stride, 54 | ceil_mode=True, 55 | count_include_pad=False)) 56 | downsample.extend([ 57 | build_conv_layer( 58 | conv_cfg, 59 | inplanes, 60 | planes * block.expansion, 61 | kernel_size=1, 62 | stride=conv_stride, 63 | bias=False), 64 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 65 | ]) 66 | downsample = nn.Sequential(*downsample) 67 | 68 | layers = [] 69 | if multi_grid is None: 70 | if dilation > 1 and contract_dilation: 71 | first_dilation = dilation // 2 72 | else: 73 | first_dilation = dilation 74 | else: 75 | first_dilation = multi_grid[0] 76 | layers.append( 77 | block( 78 | inplanes=inplanes, 79 | planes=planes, 80 | stride=stride, 81 | dilation=first_dilation, 82 | downsample=downsample, 83 | conv_cfg=conv_cfg, 84 | norm_cfg=norm_cfg, 85 | **kwargs)) 86 | inplanes = planes * block.expansion 87 | for i in range(1, num_blocks): 88 | layers.append( 89 | block( 90 | inplanes=inplanes, 91 | planes=planes, 92 | stride=1, 93 | dilation=dilation if multi_grid is None else multi_grid[i], 94 | conv_cfg=conv_cfg, 95 | norm_cfg=norm_cfg, 96 | **kwargs)) 97 | super(ResLayer, self).__init__(*layers) 98 | -------------------------------------------------------------------------------- /mmseg/models/utils/self_attention_block.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | from mmcv.cnn import ConvModule, constant_init 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class SelfAttentionBlock(nn.Module): 10 | """General self-attention block/non-local block. 11 | 12 | Please refer to https://arxiv.org/abs/1706.03762 for details about key, 13 | query and value. 14 | 15 | Args: 16 | key_in_channels (int): Input channels of key feature. 17 | query_in_channels (int): Input channels of query feature. 18 | channels (int): Output channels of key/query transform. 19 | out_channels (int): Output channels. 20 | share_key_query (bool): Whether share projection weight between key 21 | and query projection. 22 | query_downsample (nn.Module): Query downsample module. 23 | key_downsample (nn.Module): Key downsample module. 24 | key_query_num_convs (int): Number of convs for key/query projection. 25 | value_num_convs (int): Number of convs for value projection. 26 | matmul_norm (bool): Whether normalize attention map with sqrt of 27 | channels 28 | with_out (bool): Whether use out projection. 29 | conv_cfg (dict|None): Config of conv layers. 30 | norm_cfg (dict|None): Config of norm layers. 31 | act_cfg (dict|None): Config of activation layers. 32 | """ 33 | 34 | def __init__(self, key_in_channels, query_in_channels, channels, 35 | out_channels, share_key_query, query_downsample, 36 | key_downsample, key_query_num_convs, value_out_num_convs, 37 | key_query_norm, value_out_norm, matmul_norm, with_out, 38 | conv_cfg, norm_cfg, act_cfg): 39 | super(SelfAttentionBlock, self).__init__() 40 | if share_key_query: 41 | assert key_in_channels == query_in_channels 42 | self.key_in_channels = key_in_channels 43 | self.query_in_channels = query_in_channels 44 | self.out_channels = out_channels 45 | self.channels = channels 46 | self.share_key_query = share_key_query 47 | self.conv_cfg = conv_cfg 48 | self.norm_cfg = norm_cfg 49 | self.act_cfg = act_cfg 50 | self.key_project = self.build_project( 51 | key_in_channels, 52 | channels, 53 | num_convs=key_query_num_convs, 54 | use_conv_module=key_query_norm, 55 | conv_cfg=conv_cfg, 56 | norm_cfg=norm_cfg, 57 | act_cfg=act_cfg) 58 | if share_key_query: 59 | self.query_project = self.key_project 60 | else: 61 | self.query_project = self.build_project( 62 | query_in_channels, 63 | channels, 64 | num_convs=key_query_num_convs, 65 | use_conv_module=key_query_norm, 66 | conv_cfg=conv_cfg, 67 | norm_cfg=norm_cfg, 68 | act_cfg=act_cfg) 69 | self.value_project = self.build_project( 70 | key_in_channels, 71 | channels if with_out else out_channels, 72 | num_convs=value_out_num_convs, 73 | use_conv_module=value_out_norm, 74 | conv_cfg=conv_cfg, 75 | norm_cfg=norm_cfg, 76 | act_cfg=act_cfg) 77 | if with_out: 78 | self.out_project = self.build_project( 79 | channels, 80 | out_channels, 81 | num_convs=value_out_num_convs, 82 | use_conv_module=value_out_norm, 83 | conv_cfg=conv_cfg, 84 | norm_cfg=norm_cfg, 85 | act_cfg=act_cfg) 86 | else: 87 | self.out_project = None 88 | 89 | self.query_downsample = query_downsample 90 | self.key_downsample = key_downsample 91 | self.matmul_norm = matmul_norm 92 | 93 | self.init_weights() 94 | 95 | def init_weights(self): 96 | """Initialize weight of later layer.""" 97 | if self.out_project is not None: 98 | if not isinstance(self.out_project, ConvModule): 99 | constant_init(self.out_project, 0) 100 | 101 | def build_project(self, in_channels, channels, num_convs, use_conv_module, 102 | conv_cfg, norm_cfg, act_cfg): 103 | """Build projection layer for key/query/value/out.""" 104 | if use_conv_module: 105 | convs = [ 106 | ConvModule( 107 | in_channels, 108 | channels, 109 | 1, 110 | conv_cfg=conv_cfg, 111 | norm_cfg=norm_cfg, 112 | act_cfg=act_cfg) 113 | ] 114 | for _ in range(num_convs - 1): 115 | convs.append( 116 | ConvModule( 117 | channels, 118 | channels, 119 | 1, 120 | conv_cfg=conv_cfg, 121 | norm_cfg=norm_cfg, 122 | act_cfg=act_cfg)) 123 | else: 124 | convs = [nn.Conv2d(in_channels, channels, 1)] 125 | for _ in range(num_convs - 1): 126 | convs.append(nn.Conv2d(channels, channels, 1)) 127 | if len(convs) > 1: 128 | convs = nn.Sequential(*convs) 129 | else: 130 | convs = convs[0] 131 | return convs 132 | 133 | def forward(self, query_feats, key_feats): 134 | """Forward function.""" 135 | batch_size = query_feats.size(0) 136 | query = self.query_project(query_feats) 137 | if self.query_downsample is not None: 138 | query = self.query_downsample(query) 139 | query = query.reshape(*query.shape[:2], -1) 140 | query = query.permute(0, 2, 1).contiguous() 141 | 142 | key = self.key_project(key_feats) 143 | value = self.value_project(key_feats) 144 | if self.key_downsample is not None: 145 | key = self.key_downsample(key) 146 | value = self.key_downsample(value) 147 | key = key.reshape(*key.shape[:2], -1) 148 | value = value.reshape(*value.shape[:2], -1) 149 | value = value.permute(0, 2, 1).contiguous() 150 | 151 | sim_map = torch.matmul(query, key) 152 | if self.matmul_norm: 153 | sim_map = (self.channels**-.5) * sim_map 154 | sim_map = F.softmax(sim_map, dim=-1) 155 | 156 | context = torch.matmul(sim_map, value) 157 | context = context.permute(0, 2, 1).contiguous() 158 | context = context.reshape(batch_size, -1, *query_feats.shape[2:]) 159 | if self.out_project is not None: 160 | context = self.out_project(context) 161 | return context 162 | -------------------------------------------------------------------------------- /mmseg/models/utils/shape_convert.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | 4 | def nlc_to_nchw(x, hw_shape): 5 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 6 | 7 | Args: 8 | x (Tensor): The input tensor of shape [N, L, C] before convertion. 9 | hw_shape (Sequence[int]): The height and width of output feature map. 10 | 11 | Returns: 12 | Tensor: The output tensor of shape [N, C, H, W] after convertion. 13 | """ 14 | H, W = hw_shape 15 | assert len(x.shape) == 3 16 | B, L, C = x.shape 17 | assert L == H * W, 'The seq_len doesn\'t match H, W' 18 | return x.transpose(1, 2).reshape(B, C, H, W) 19 | 20 | 21 | def nchw_to_nlc(x): 22 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 23 | 24 | Args: 25 | x (Tensor): The input tensor of shape [N, C, H, W] before convertion. 26 | 27 | Returns: 28 | Tensor: The output tensor of shape [N, L, C] after convertion. 29 | """ 30 | assert len(x.shape) == 4 31 | return x.flatten(2).transpose(1, 2).contiguous() 32 | -------------------------------------------------------------------------------- /mmseg/models/utils/up_conv_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule, build_upsample_layer 5 | 6 | 7 | class UpConvBlock(nn.Module): 8 | """Upsample convolution block in decoder for UNet. 9 | 10 | This upsample convolution block consists of one upsample module 11 | followed by one convolution block. The upsample module expands the 12 | high-level low-resolution feature map and the convolution block fuses 13 | the upsampled high-level low-resolution feature map and the low-level 14 | high-resolution feature map from encoder. 15 | 16 | Args: 17 | conv_block (nn.Sequential): Sequential of convolutional layers. 18 | in_channels (int): Number of input channels of the high-level 19 | skip_channels (int): Number of input channels of the low-level 20 | high-resolution feature map from encoder. 21 | out_channels (int): Number of output channels. 22 | num_convs (int): Number of convolutional layers in the conv_block. 23 | Default: 2. 24 | stride (int): Stride of convolutional layer in conv_block. Default: 1. 25 | dilation (int): Dilation rate of convolutional layer in conv_block. 26 | Default: 1. 27 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 28 | memory while slowing down the training speed. Default: False. 29 | conv_cfg (dict | None): Config dict for convolution layer. 30 | Default: None. 31 | norm_cfg (dict | None): Config dict for normalization layer. 32 | Default: dict(type='BN'). 33 | act_cfg (dict | None): Config dict for activation layer in ConvModule. 34 | Default: dict(type='ReLU'). 35 | upsample_cfg (dict): The upsample config of the upsample module in 36 | decoder. Default: dict(type='InterpConv'). If the size of 37 | high-level feature map is the same as that of skip feature map 38 | (low-level feature map from encoder), it does not need upsample the 39 | high-level feature map and the upsample_cfg is None. 40 | dcn (bool): Use deformable convolution in convolutional layer or not. 41 | Default: None. 42 | plugins (dict): plugins for convolutional layers. Default: None. 43 | """ 44 | 45 | def __init__(self, 46 | conv_block, 47 | in_channels, 48 | skip_channels, 49 | out_channels, 50 | num_convs=2, 51 | stride=1, 52 | dilation=1, 53 | with_cp=False, 54 | conv_cfg=None, 55 | norm_cfg=dict(type='BN'), 56 | act_cfg=dict(type='ReLU'), 57 | upsample_cfg=dict(type='InterpConv'), 58 | dcn=None, 59 | plugins=None): 60 | super().__init__() 61 | assert dcn is None, 'Not implemented yet.' 62 | assert plugins is None, 'Not implemented yet.' 63 | 64 | self.conv_block = conv_block( 65 | in_channels=2 * skip_channels, 66 | out_channels=out_channels, 67 | num_convs=num_convs, 68 | stride=stride, 69 | dilation=dilation, 70 | with_cp=with_cp, 71 | conv_cfg=conv_cfg, 72 | norm_cfg=norm_cfg, 73 | act_cfg=act_cfg, 74 | dcn=None, 75 | plugins=None) 76 | if upsample_cfg is not None: 77 | self.upsample = build_upsample_layer( 78 | cfg=upsample_cfg, 79 | in_channels=in_channels, 80 | out_channels=skip_channels, 81 | with_cp=with_cp, 82 | norm_cfg=norm_cfg, 83 | act_cfg=act_cfg) 84 | else: 85 | self.upsample = ConvModule( 86 | in_channels, 87 | skip_channels, 88 | kernel_size=1, 89 | stride=1, 90 | padding=0, 91 | conv_cfg=conv_cfg, 92 | norm_cfg=norm_cfg, 93 | act_cfg=act_cfg) 94 | 95 | def forward(self, skip, x): 96 | """Forward function.""" 97 | 98 | x = self.upsample(x) 99 | out = torch.cat([skip, x], dim=1) 100 | out = self.conv_block(out) 101 | 102 | return out -------------------------------------------------------------------------------- /mmseg/models/utils/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize(input, 9 | size=None, 10 | scale_factor=None, 11 | mode='nearest', 12 | align_corners=None, 13 | warning=True): 14 | if warning: 15 | if size is not None and align_corners: 16 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 17 | output_h, output_w = tuple(int(x) for x in size) 18 | if output_h > input_h or output_w > output_h: 19 | if ((output_h > 1 and output_w > 1 and input_h > 1 20 | and input_w > 1) and (output_h - 1) % (input_h - 1) 21 | and (output_w - 1) % (input_w - 1)): 22 | warnings.warn( 23 | f'When align_corners={align_corners}, ' 24 | 'the output would more aligned if ' 25 | f'input size {(input_h, input_w)} is `x+1` and ' 26 | f'out size {(output_h, output_w)} is `nx+1`') 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | 29 | 30 | class Upsample(nn.Module): 31 | 32 | def __init__(self, 33 | size=None, 34 | scale_factor=None, 35 | mode='nearest', 36 | align_corners=None): 37 | super().__init__() 38 | self.size = size 39 | if isinstance(scale_factor, tuple): 40 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 41 | else: 42 | self.scale_factor = float(scale_factor) if scale_factor else None 43 | self.mode = mode 44 | self.align_corners = align_corners 45 | 46 | def forward(self, x): 47 | if not self.size: 48 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 49 | else: 50 | size = self.size 51 | return resize(x, size, None, self.mode, self.align_corners) 52 | -------------------------------------------------------------------------------- /mmseg/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoding import Encoding 2 | from .wrappers import Upsample, resize 3 | 4 | __all__ = ['Upsample', 'resize', 'Encoding'] 5 | -------------------------------------------------------------------------------- /mmseg/ops/encoding.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class Encoding(nn.Module): 9 | """Encoding Layer: a learnable residual encoder. 10 | 11 | Input is of shape (batch_size, channels, height, width). 12 | Output is of shape (batch_size, num_codes, channels). 13 | 14 | Args: 15 | channels: dimension of the features or feature channels 16 | num_codes: number of code words 17 | """ 18 | 19 | def __init__(self, channels, num_codes): 20 | super(Encoding, self).__init__() 21 | # init codewords and smoothing factor 22 | self.channels, self.num_codes = channels, num_codes 23 | std = 1. / ((num_codes * channels)**0.5) 24 | # [num_codes, channels] 25 | self.codewords = nn.Parameter( 26 | torch.empty(num_codes, channels, 27 | dtype=torch.float).uniform_(-std, std), 28 | requires_grad=True) 29 | # [num_codes] 30 | self.scale = nn.Parameter( 31 | torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), 32 | requires_grad=True) 33 | 34 | @staticmethod 35 | def scaled_l2(x, codewords, scale): 36 | num_codes, channels = codewords.size() 37 | batch_size = x.size(0) 38 | reshaped_scale = scale.view((1, 1, num_codes)) 39 | expanded_x = x.unsqueeze(2).expand( 40 | (batch_size, x.size(1), num_codes, channels)) 41 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 42 | 43 | scaled_l2_norm = reshaped_scale * ( 44 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 45 | return scaled_l2_norm 46 | 47 | @staticmethod 48 | def aggregate(assignment_weights, x, codewords): 49 | num_codes, channels = codewords.size() 50 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 51 | batch_size = x.size(0) 52 | 53 | expanded_x = x.unsqueeze(2).expand( 54 | (batch_size, x.size(1), num_codes, channels)) 55 | encoded_feat = (assignment_weights.unsqueeze(3) * 56 | (expanded_x - reshaped_codewords)).sum(dim=1) 57 | return encoded_feat 58 | 59 | def forward(self, x): 60 | assert x.dim() == 4 and x.size(1) == self.channels 61 | # [batch_size, channels, height, width] 62 | batch_size = x.size(0) 63 | # [batch_size, height x width, channels] 64 | x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() 65 | # assignment_weights: [batch_size, channels, num_codes] 66 | assignment_weights = F.softmax( 67 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 68 | # aggregate 69 | encoded_feat = self.aggregate(assignment_weights, x, self.codewords) 70 | return encoded_feat 71 | 72 | def __repr__(self): 73 | repr_str = self.__class__.__name__ 74 | repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ 75 | f'x{self.channels})' 76 | return repr_str 77 | -------------------------------------------------------------------------------- /mmseg/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import warnings 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def resize(input, 10 | size=None, 11 | scale_factor=None, 12 | mode='nearest', 13 | align_corners=None, 14 | warning=True): 15 | if warning: 16 | if size is not None and align_corners: 17 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 18 | output_h, output_w = tuple(int(x) for x in size) 19 | if output_h > input_h or output_w > output_h: 20 | if ((output_h > 1 and output_w > 1 and input_h > 1 21 | and input_w > 1) and (output_h - 1) % (input_h - 1) 22 | and (output_w - 1) % (input_w - 1)): 23 | warnings.warn( 24 | f'When align_corners={align_corners}, ' 25 | 'the output would more aligned if ' 26 | f'input size {(input_h, input_w)} is `x+1` and ' 27 | f'out size {(output_h, output_w)} is `nx+1`') 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | 30 | 31 | class Upsample(nn.Module): 32 | 33 | def __init__(self, 34 | size=None, 35 | scale_factor=None, 36 | mode='nearest', 37 | align_corners=None): 38 | super(Upsample, self).__init__() 39 | self.size = size 40 | if isinstance(scale_factor, tuple): 41 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 42 | else: 43 | self.scale_factor = float(scale_factor) if scale_factor else None 44 | self.mode = mode 45 | self.align_corners = align_corners 46 | 47 | def forward(self, x): 48 | if not self.size: 49 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 50 | else: 51 | size = self.size 52 | return resize(x, size, None, self.mode, self.align_corners) 53 | -------------------------------------------------------------------------------- /mmseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .collect_env import collect_env 2 | from .logger import get_root_logger 3 | from .precision_logger import PrecisionLoggerHook 4 | 5 | __all__ = ['get_root_logger', 'collect_env', 'PrecisionLoggerHook'] 6 | -------------------------------------------------------------------------------- /mmseg/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Add code archive generation 3 | 4 | import os 5 | import tarfile 6 | 7 | from mmcv.utils import collect_env as collect_base_env 8 | from mmcv.utils import get_git_hash 9 | 10 | import mmseg 11 | 12 | 13 | def collect_env(): 14 | """Collect the information of the running environments.""" 15 | env_info = collect_base_env() 16 | env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' 17 | 18 | return env_info 19 | 20 | 21 | def is_source_file(x): 22 | if x.isdir() or x.name.endswith(('.py', '.sh', '.yml', '.json', '.txt')) \ 23 | and '.mim' not in x.name and 'jobs/' not in x.name: 24 | # print(x.name) 25 | return x 26 | else: 27 | return None 28 | 29 | 30 | def gen_code_archive(out_dir, file='code.tar.gz'): 31 | archive = os.path.join(out_dir, file) 32 | os.makedirs(os.path.dirname(archive), exist_ok=True) 33 | with tarfile.open(archive, mode='w:gz') as tar: 34 | tar.add('.', filter=is_source_file) 35 | return archive 36 | 37 | 38 | if __name__ == '__main__': 39 | for name, val in collect_env().items(): 40 | print('{}: {}'.format(name, val)) 41 | -------------------------------------------------------------------------------- /mmseg/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | 3 | import logging 4 | 5 | from mmcv.utils import get_logger 6 | 7 | 8 | def get_root_logger(log_file=None, log_level=logging.INFO): 9 | """Get the root logger. 10 | 11 | The logger will be initialized if it has not been initialized. By default a 12 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 13 | also be added. The name of the root logger is the top-level package name, 14 | e.g., "mmseg". 15 | 16 | Args: 17 | log_file (str | None): The log filename. If specified, a FileHandler 18 | will be added to the root logger. 19 | log_level (int): The root logger level. Note that only the process of 20 | rank 0 is affected, while other processes will set the level to 21 | "Error" and be silent most of the time. 22 | 23 | Returns: 24 | logging.Logger: The root logger. 25 | """ 26 | 27 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 28 | 29 | return logger 30 | -------------------------------------------------------------------------------- /mmseg/utils/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | @contextlib.contextmanager 9 | def np_local_seed(seed): 10 | state = np.random.get_state() 11 | np.random.seed(seed) 12 | try: 13 | yield 14 | finally: 15 | np.random.set_state(state) 16 | 17 | 18 | def downscale_label_ratio(gt, 19 | scale_factor, 20 | min_ratio, 21 | n_classes, 22 | ignore_index=255): 23 | assert scale_factor > 1 24 | bs, orig_c, orig_h, orig_w = gt.shape 25 | assert orig_c == 1 # 判断这个gt是不是单通道的(因为gt一般是最后argmax之后只剩一个通道了) 26 | trg_h, trg_w = orig_h // scale_factor, orig_w // scale_factor 27 | ignore_substitute = n_classes 28 | 29 | out = gt.clone() # otw. next line would modify original gt 30 | out[out == ignore_index] = ignore_substitute 31 | out = F.one_hot( 32 | out.squeeze(1), num_classes=n_classes + 1).permute(0, 3, 1, 2) 33 | assert list(out.shape) == [bs, n_classes + 1, orig_h, orig_w], out.shape 34 | out = F.avg_pool2d(out.float(), kernel_size=scale_factor) 35 | gt_ratio, out = torch.max(out, dim=1, keepdim=True) 36 | out[out == ignore_substitute] = ignore_index 37 | out[gt_ratio < min_ratio] = ignore_index 38 | assert list(out.shape) == [bs, 1, trg_h, trg_w], out.shape 39 | return out 40 | -------------------------------------------------------------------------------- /mmseg/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | __version__ = '0.16.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | aliyun-python-sdk-core==2.15.0 3 | aliyun-python-sdk-kms==2.16.2 4 | certifi==2024.2.2 5 | cffi==1.16.0 6 | charset-normalizer==3.3.2 7 | click==8.1.7 8 | cmake==3.28.3 9 | colorama==0.4.6 10 | contourpy==1.1.1 11 | crcmod==1.7 12 | cryptography==42.0.5 13 | cycler==0.12.1 14 | filelock==3.13.1 15 | fonttools==4.49.0 16 | idna==3.6 17 | importlib_metadata==7.0.2 18 | importlib_resources==6.3.0 19 | Jinja2==3.1.3 20 | jmespath==0.10.0 21 | joblib==1.3.2 22 | kiwisolver==1.4.5 23 | lit==18.1.1 24 | Markdown==3.6 25 | markdown-it-py==3.0.0 26 | MarkupSafe==2.1.5 27 | matplotlib==3.7.5 28 | mdurl==0.1.2 29 | mmcv==1.7.1 30 | mmengine==0.10.3 31 | model-index==0.1.11 32 | mpmath==1.3.0 33 | networkx==3.1 34 | numpy==1.24.4 35 | nvidia-cublas-cu11==11.10.3.66 36 | nvidia-cuda-cupti-cu11==11.7.101 37 | nvidia-cuda-nvrtc-cu11==11.7.99 38 | nvidia-cuda-runtime-cu11==11.7.99 39 | nvidia-cudnn-cu11==8.5.0.96 40 | nvidia-cufft-cu11==10.9.0.58 41 | nvidia-curand-cu11==10.2.10.91 42 | nvidia-cusolver-cu11==11.4.0.1 43 | nvidia-cusparse-cu11==11.7.4.91 44 | nvidia-nccl-cu11==2.14.3 45 | nvidia-nvtx-cu11==11.7.91 46 | opencv-python-headless==4.9.0.80 47 | opendatalab==0.0.10 48 | openmim==0.3.9 49 | openxlab==0.0.36 50 | ordered-set==4.1.0 51 | oss2==2.17.0 52 | packaging==24.0 53 | pandas==2.0.3 54 | pillow==10.2.0 55 | platformdirs==4.2.0 56 | prettytable==3.10.0 57 | pycparser==2.21 58 | pycryptodome==3.20.0 59 | Pygments==2.17.2 60 | pyparsing==3.1.2 61 | python-dateutil==2.9.0.post0 62 | pytz==2023.4 63 | PyYAML==6.0.1 64 | requests==2.28.2 65 | rich==13.4.2 66 | scikit-learn==1.3.2 67 | scipy==1.10.1 68 | six==1.16.0 69 | sympy==1.12 70 | tabulate==0.9.0 71 | termcolor==2.4.0 72 | threadpoolctl==3.3.0 73 | tomli==2.0.1 74 | torch==2.0.1 75 | torchaudio==2.0.2 76 | torchvision==0.15.2 77 | tqdm==4.65.2 78 | triton==2.0.0 79 | typing_extensions==4.10.0 80 | tzdata==2024.1 81 | urllib3==1.26.18 82 | wcwidth==0.2.13 83 | yapf==0.40.1 84 | zipp==3.18.1 85 | -------------------------------------------------------------------------------- /tools/analyze_logs.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | """Modified from https://github.com/open- 3 | mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" 4 | import argparse 5 | import json 6 | from collections import defaultdict 7 | 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | 11 | 12 | def plot_curve(log_dicts, args): 13 | if args.backend is not None: 14 | plt.switch_backend(args.backend) 15 | sns.set_style(args.style) 16 | # if legend is None, use {filename}_{key} as legend 17 | legend = args.legend 18 | if legend is None: 19 | legend = [] 20 | for json_log in args.json_logs: 21 | for metric in args.keys: 22 | legend.append(f'{json_log}_{metric}') 23 | assert len(legend) == (len(args.json_logs) * len(args.keys)) 24 | metrics = args.keys 25 | 26 | num_metrics = len(metrics) 27 | for i, log_dict in enumerate(log_dicts): 28 | epochs = list(log_dict.keys()) 29 | for j, metric in enumerate(metrics): 30 | print(f'plot curve of {args.json_logs[i]}, metric is {metric}') 31 | plot_epochs = [] 32 | plot_iters = [] 33 | plot_values = [] 34 | # In some log files, iters number is not correct, `pre_iter` is 35 | # used to prevent generate wrong lines. 36 | pre_iter = -1 37 | for epoch in epochs: 38 | epoch_logs = log_dict[epoch] 39 | if metric not in epoch_logs.keys(): 40 | continue 41 | if metric in ['mIoU', 'mAcc', 'aAcc']: 42 | plot_epochs.append(epoch) 43 | plot_values.append(epoch_logs[metric][0]) 44 | else: 45 | for idx in range(len(epoch_logs[metric])): 46 | if pre_iter > epoch_logs['iter'][idx]: 47 | continue 48 | pre_iter = epoch_logs['iter'][idx] 49 | plot_iters.append(epoch_logs['iter'][idx]) 50 | plot_values.append(epoch_logs[metric][idx]) 51 | ax = plt.gca() 52 | label = legend[i * num_metrics + j] 53 | if metric in ['mIoU', 'mAcc', 'aAcc']: 54 | ax.set_xticks(plot_epochs) 55 | plt.xlabel('epoch') 56 | plt.plot(plot_epochs, plot_values, label=label, marker='o') 57 | else: 58 | plt.xlabel('iter') 59 | plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) 60 | plt.legend() 61 | if args.title is not None: 62 | plt.title(args.title) 63 | if args.out is None: 64 | plt.show() 65 | else: 66 | print(f'save curve to: {args.out}') 67 | plt.savefig(args.out) 68 | plt.cla() 69 | 70 | 71 | def parse_args(): 72 | parser = argparse.ArgumentParser(description='Analyze Json Log') 73 | parser.add_argument( 74 | 'json_logs', 75 | type=str, 76 | nargs='+', 77 | help='path of train log in json format') 78 | parser.add_argument( 79 | '--keys', 80 | type=str, 81 | nargs='+', 82 | default=['mIoU'], 83 | help='the metric that you want to plot') 84 | parser.add_argument('--title', type=str, help='title of figure') 85 | parser.add_argument( 86 | '--legend', 87 | type=str, 88 | nargs='+', 89 | default=None, 90 | help='legend of each plot') 91 | parser.add_argument( 92 | '--backend', type=str, default=None, help='backend of plt') 93 | parser.add_argument( 94 | '--style', type=str, default='dark', help='style of plt') 95 | parser.add_argument('--out', type=str, default=None) 96 | args = parser.parse_args() 97 | return args 98 | 99 | 100 | def load_json_logs(json_logs): 101 | # load and convert json_logs to log_dict, key is epoch, value is a sub dict 102 | # keys of sub dict is different metrics 103 | # value of sub dict is a list of corresponding values of all iterations 104 | log_dicts = [dict() for _ in json_logs] 105 | for json_log, log_dict in zip(json_logs, log_dicts): 106 | with open(json_log, 'r') as log_file: 107 | for line in log_file: 108 | log = json.loads(line.strip()) 109 | # skip lines without `epoch` field 110 | if 'epoch' not in log: 111 | continue 112 | epoch = log.pop('epoch') 113 | if epoch not in log_dict: 114 | log_dict[epoch] = defaultdict(list) 115 | for k, v in log.items(): 116 | log_dict[epoch][k].append(v) 117 | return log_dicts 118 | 119 | 120 | def main(): 121 | args = parse_args() 122 | json_logs = args.json_logs 123 | for json_log in json_logs: 124 | assert json_log.endswith('.json') 125 | log_dicts = load_json_logs(json_logs) 126 | plot_curve(log_dicts, args) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /tools/convert_dataset/ddr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | curPath = os.path.abspath(os.path.dirname(__file__)) 4 | rootPath = os.path.split(curPath)[0] 5 | sys.path.append(rootPath) 6 | 7 | import argparse 8 | import cv2 9 | import numpy as np 10 | 11 | ''' 12 | Please, first rename the original directories of DDR dataset to: 13 | 14 | DDR 15 | |——images 16 | | |——test 17 | | |——val 18 | | |——train 19 | |——labels 20 | | |——test 21 | | | |——EX 22 | | | |——HE 23 | | | |——MA 24 | | | |——SE 25 | | |——val 26 | | | |——EX 27 | | | |——HE 28 | | | |——MA 29 | | | |——SE 30 | | |——train 31 | | | |——EX 32 | | | |——HE 33 | | | |——MA 34 | | | |——SE 35 | 36 | Run the following code to generate .png labels for training 37 | ''' 38 | 39 | CLASSES = dict( 40 | EX=1, 41 | HE=2, 42 | SE=3, 43 | MA=4, 44 | ) 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser(description='Reconstruct the DDR dataset') 48 | parser.add_argument('--root', default='./data/DDR') 49 | args = parser.parse_args() 50 | return args 51 | 52 | def main(): 53 | args = parse_args() 54 | root = args.root 55 | 56 | image_root = os.path.join(root,'images') 57 | label_root = os.path.join(root,'labels') 58 | for subset in os.listdir(image_root): 59 | for eachimage in os.listdir(os.path.join(image_root,subset)): 60 | 61 | img = cv2.imread(os.path.join(image_root,subset,eachimage)) 62 | fused_gt = np.zeros((img.shape[0],img.shape[1]),dtype=np.uint8) 63 | for each_class in CLASSES: 64 | each_class_gt_path = os.path.join(label_root,subset,each_class,eachimage.replace('.jpg','.tif')) 65 | each_class_gt = cv2.imread(each_class_gt_path,cv2.IMREAD_GRAYSCALE) 66 | fused_gt[each_class_gt==255] = CLASSES[each_class] 67 | 68 | label_path = os.path.join(label_root,subset,eachimage.replace('.jpg','.png')) 69 | cv2.imwrite(label_path,fused_gt) 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /tools/convert_dataset/idrid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | curPath = os.path.abspath(os.path.dirname(__file__)) 4 | rootPath = os.path.split(curPath)[0] 5 | sys.path.append(rootPath) 6 | 7 | import argparse 8 | import cv2 9 | import numpy as np 10 | 11 | ''' 12 | Please, first rename the original directories of IDRiD dataset to: 13 | 14 | IDRiD 15 | |——images 16 | | |——test 17 | | |——train 18 | |——labels 19 | | |——test 20 | | | |——EX 21 | | | |——HE 22 | | | |——MA 23 | | | |——SE 24 | | |——train 25 | | | |——EX 26 | | | |——HE 27 | | | |——MA 28 | | | |——SE 29 | 30 | Run the following code to generate .png labels for training 31 | ''' 32 | 33 | CLASSES = dict( 34 | EX=1, 35 | HE=2, 36 | SE=3, 37 | MA=4, 38 | ) 39 | 40 | def parse_args(): 41 | parser = argparse.ArgumentParser(description='Reconstruct the IDRiD dataset') 42 | parser.add_argument('--root', default='./data/IDRiD') 43 | args = parser.parse_args() 44 | return args 45 | 46 | def main(): 47 | args = parse_args() 48 | root = args.root 49 | 50 | image_root = os.path.join(root,'images') 51 | label_root = os.path.join(root,'labels') 52 | for subset in os.listdir(image_root): 53 | for eachimage in os.listdir(os.path.join(image_root,subset)): 54 | 55 | img = cv2.imread(os.path.join(image_root,subset,eachimage)) 56 | fused_gt = np.zeros((img.shape[0],img.shape[1]),dtype=np.uint8) 57 | for each_class in CLASSES: 58 | each_class_gt_path = os.path.join(label_root,subset,each_class,eachimage.replace('.jpg',f'_{each_class}.tif')) 59 | each_class_gt = cv2.imread(each_class_gt_path,cv2.IMREAD_GRAYSCALE) 60 | fused_gt[each_class_gt==76] = CLASSES[each_class] 61 | 62 | label_path = os.path.join(label_root,subset,eachimage.replace('.jpg','.png')) 63 | cv2.imwrite(label_path,fused_gt) 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 10 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-29500} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | torchrun --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /tools/get_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import sys 4 | curPath = os.path.abspath(os.path.dirname(__file__)) 5 | rootPath = os.path.split(curPath)[0] 6 | sys.path.append(rootPath) 7 | 8 | import argparse 9 | 10 | from mmcv import Config 11 | from mmcv.cnn import get_model_complexity_info 12 | 13 | from mmseg.models import build_segmentor 14 | 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Get the FLOPs of a segmentor') 20 | parser.add_argument('config', help='train config file path') 21 | parser.add_argument( 22 | '--shape', 23 | type=int, 24 | nargs='+', 25 | default=[2048, 2048], 26 | help='input image size') 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | def main(): 32 | 33 | args = parse_args() 34 | 35 | if len(args.shape) == 1: 36 | input_shape = (3, args.shape[0], args.shape[0]) 37 | elif len(args.shape) == 2: 38 | input_shape = (3, ) + tuple(args.shape) 39 | else: 40 | raise ValueError('invalid input shape') 41 | 42 | cfg = Config.fromfile(args.config) 43 | cfg.model.pretrained = None 44 | model = build_segmentor( 45 | cfg.model, 46 | train_cfg=cfg.get('train_cfg'), 47 | test_cfg=cfg.get('test_cfg')).cuda() 48 | model.eval() 49 | 50 | if hasattr(model, 'forward_dummy'): 51 | model.forward = model.forward_dummy 52 | else: 53 | raise NotImplementedError( 54 | 'FLOPs counter is currently not currently supported with {}'. 55 | format(model.__class__.__name__)) 56 | 57 | flops, params = get_model_complexity_info(model, input_shape) 58 | split_line = '=' * 30 59 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( 60 | split_line, input_shape, flops, params)) 61 | print('!!!Please be cautious if you use the results in papers. ' 62 | 'You may need to check if all ops are supported and verify that the ' 63 | 'flops computation is correct.') 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /tools/get_fps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | curPath = os.path.abspath(os.path.dirname(__file__)) 4 | rootPath = os.path.split(curPath)[0] 5 | sys.path.append(rootPath) 6 | 7 | import argparse 8 | import time 9 | 10 | import torch 11 | from mmcv import Config 12 | from mmcv.parallel import MMDataParallel 13 | from mmcv.runner import load_checkpoint 14 | 15 | from mmseg.datasets import build_dataloader, build_dataset 16 | from mmseg.models import build_segmentor 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='MMSeg benchmark a model') 21 | parser.add_argument('config', help='test config file path') 22 | parser.add_argument('--checkpoint', help='checkpoint file') 23 | parser.add_argument('--num', type=int, default=1000) 24 | parser.add_argument('--log-interval', type=int, default=50, help='interval of logging') 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def main(): 30 | args = parse_args() 31 | 32 | cfg = Config.fromfile(args.config) 33 | print(args.config) 34 | # set cudnn_benchmark 35 | torch.backends.cudnn.benchmark = False 36 | cfg.model.pretrained = None 37 | cfg.data.test.test_mode = True 38 | 39 | # build the dataloader 40 | dataset = build_dataset(cfg.data.test) 41 | data_loader = build_dataloader( 42 | dataset, 43 | samples_per_gpu=1, 44 | workers_per_gpu=cfg.data.workers_per_gpu, 45 | dist=False, 46 | shuffle=False) 47 | 48 | # build the model 49 | model = build_segmentor(cfg.model, train_cfg=None) 50 | 51 | # no checkpoint is also ok 52 | if args.checkpoint: 53 | load_checkpoint(model, args.checkpoint, map_location='cpu') 54 | model = MMDataParallel(model, device_ids=[0]) 55 | model.eval() 56 | 57 | # the first several iterations may be very slow so skip them 58 | num_warmup = 5 59 | pure_inf_time = 0 60 | total_iters = args.num 61 | 62 | # benchmark with 1000 image and take the average 63 | end_flag = False 64 | total_i = 0 65 | while True: 66 | for i, data in enumerate(data_loader): 67 | torch.cuda.synchronize() 68 | start_time = time.perf_counter() 69 | 70 | with torch.no_grad(): 71 | model(return_loss=False, rescale=True, **data) 72 | 73 | torch.cuda.synchronize() 74 | elapsed = time.perf_counter() - start_time 75 | 76 | if i >= num_warmup: 77 | pure_inf_time += elapsed 78 | total_i += 1 79 | if total_i % args.log_interval == 0: 80 | fps = 1. * total_i / pure_inf_time 81 | print(f'Done image [{total_i:<3}/ {total_iters}], ' 82 | f'fps: {fps:.2f} img / s') 83 | 84 | if total_i == total_iters: 85 | fps = 1. * total_i / pure_inf_time 86 | print(f'Overall fps: {fps:.2f} img / s') 87 | end_flag = True 88 | break 89 | 90 | if end_flag: 91 | break 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | -------------------------------------------------------------------------------- /tools/print_config.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0 2 | # Modifications: Also save config as json file 3 | 4 | import argparse 5 | 6 | from mmcv import Config, DictAction 7 | 8 | from mmseg.apis import init_segmentor 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description='Print the whole config') 13 | parser.add_argument('config', help='config file path') 14 | parser.add_argument( 15 | '--graph', action='store_true', help='print the models graph') 16 | parser.add_argument( 17 | '--options', nargs='+', action=DictAction, help='arguments in dict') 18 | args = parser.parse_args() 19 | 20 | return args 21 | 22 | 23 | def main(): 24 | args = parse_args() 25 | 26 | cfg = Config.fromfile(args.config) 27 | if args.options is not None: 28 | cfg.merge_from_dict(args.options) 29 | print(f'Config:\n{cfg.pretty_text}') 30 | # dump config 31 | cfg.dump('example.py') 32 | super(Config, cfg).__setattr__('_filename', 'example.json') 33 | cfg.dump('example.json') 34 | # dump models graph 35 | if args.graph: 36 | model = init_segmentor(args.config, device='cpu') 37 | print(f'Model graph:\n{str(model)}') 38 | with open('example-graph.txt', 'w') as f: 39 | f.writelines(str(model)) 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import copy 4 | import os 5 | import os.path as osp 6 | import sys 7 | import time 8 | 9 | curPath = os.path.abspath(os.path.dirname(__file__)) 10 | rootPath = os.path.split(curPath)[0] 11 | sys.path.append(rootPath) 12 | 13 | import mmcv 14 | import torch 15 | from mmcv.runner import init_dist 16 | from mmcv.utils import Config, DictAction, get_git_hash 17 | 18 | from mmseg import __version__ 19 | from mmseg.apis import set_random_seed, train_segmentor 20 | from mmseg.datasets import build_dataset 21 | from mmseg.models.builder import build_train_model 22 | from mmseg.utils import collect_env, get_root_logger 23 | from mmseg.utils.collect_env import gen_code_archive 24 | 25 | 26 | 27 | def parse_args(args): 28 | parser = argparse.ArgumentParser(description='Train a segmentor') 29 | parser.add_argument('config', help='train config file path') 30 | parser.add_argument('--work-dir', help='the dir to save logs and models') 31 | parser.add_argument( 32 | '--load-from', help='the checkpoint file to load weights from') 33 | parser.add_argument( 34 | '--resume-from', help='the checkpoint file to resume from') 35 | parser.add_argument( 36 | '--no-validate', 37 | action='store_true', 38 | help='whether not to evaluate the checkpoint during training') 39 | group_gpus = parser.add_mutually_exclusive_group() 40 | group_gpus.add_argument( 41 | '--gpus', 42 | type=int, 43 | help='number of gpus to use ' 44 | '(only applicable to non-distributed training)') 45 | group_gpus.add_argument( 46 | '--gpu-ids', 47 | type=int, 48 | nargs='+', 49 | help='ids of gpus to use ' 50 | '(only applicable to non-distributed training)') 51 | parser.add_argument('--seed', type=int, default=None, help='random seed') 52 | parser.add_argument( 53 | '--deterministic', 54 | action='store_true', 55 | help='whether to set deterministic options for CUDNN backend.') 56 | parser.add_argument( 57 | '--options', nargs='+', action=DictAction, help='custom options') 58 | parser.add_argument( 59 | '--launcher', 60 | choices=['none', 'pytorch', 'slurm', 'mpi'], 61 | default='none', 62 | help='job launcher') 63 | parser.add_argument('--local_rank', type=int, default=0) 64 | args = parser.parse_args(args) 65 | if 'LOCAL-RANK' not in os.environ: 66 | os.environ['LOCAL-RANK'] = str(args.local_rank) 67 | 68 | return args 69 | 70 | 71 | def main(args): 72 | args = parse_args(args) 73 | 74 | cfg = Config.fromfile(args.config) 75 | if args.options is not None: 76 | cfg.merge_from_dict(args.options) 77 | # set cudnn_benchmark 78 | if cfg.get('cudnn_benchmark', False): 79 | torch.backends.cudnn.benchmark = True 80 | 81 | # work_dir is determined in this priority: CLI > segment in file > filename 82 | if args.work_dir is not None: 83 | # update configs according to CLI args if args.work_dir is not None 84 | cfg.work_dir = args.work_dir 85 | elif cfg.get('work_dir', None) is None: 86 | # use config filename as default work_dir if cfg.work_dir is None 87 | cfg.work_dir = osp.join('./work_dirs', 88 | osp.splitext(osp.basename(args.config))[0]) 89 | cfg.model.train_cfg.work_dir = cfg.work_dir 90 | if args.load_from is not None: 91 | cfg.load_from = args.load_from 92 | if args.resume_from is not None: 93 | cfg.resume_from = args.resume_from 94 | if args.gpu_ids is not None: 95 | cfg.gpu_ids = args.gpu_ids 96 | else: 97 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 98 | 99 | # init distributed env first, since logger depends on the dist info. 100 | if args.launcher == 'none': 101 | distributed = False 102 | else: 103 | distributed = True 104 | init_dist(args.launcher, **cfg.dist_params) 105 | 106 | # create work_dir 107 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 108 | # dump config 109 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 110 | # snapshot source code 111 | #gen_code_archive(cfg.work_dir) 112 | # init the logger before other steps 113 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 114 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 115 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 116 | 117 | # init the meta dict to record some important information such as 118 | # environment info and seed, which will be logged 119 | meta = dict() 120 | 121 | # log env info 122 | env_info_dict = collect_env() 123 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 124 | dash_line = '-' * 60 + '\n' 125 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 126 | dash_line) 127 | meta['env_info'] = env_info 128 | 129 | # log some basic info 130 | logger.info(f'Distributed training: {distributed}') 131 | logger.info(f'Config:\n{cfg.pretty_text}') 132 | 133 | # set random seeds 134 | if args.seed is None and 'seed' in cfg: 135 | args.seed = cfg['seed'] 136 | if args.seed is not None: 137 | deterministic = args.deterministic or cfg.get('deterministic') 138 | logger.info(f'Set random seed to {args.seed}, deterministic: ' 139 | f'{deterministic}') 140 | set_random_seed(args.seed, deterministic=deterministic) 141 | cfg.seed = args.seed 142 | meta['seed'] = args.seed 143 | meta['exp_name'] = osp.splitext(osp.basename(args.config))[0] 144 | 145 | model = build_train_model( 146 | cfg, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')) 147 | model.init_weights() 148 | 149 | logger.info(model) 150 | 151 | datasets = [build_dataset(cfg.data.train)] 152 | if len(cfg.workflow) == 2: 153 | val_dataset = copy.deepcopy(cfg.data.val) 154 | val_dataset.pipeline = cfg.data.train.pipeline 155 | datasets.append(build_dataset(val_dataset)) 156 | if cfg.checkpoint_config is not None: 157 | # save mmseg version, config file content and class names in 158 | # checkpoints as meta data 159 | cfg.checkpoint_config.meta = dict( 160 | mmseg_version=f'{__version__}+{get_git_hash()[:7]}', 161 | config=cfg.pretty_text, 162 | CLASSES=datasets[0].CLASSES, 163 | PALETTE=datasets[0].PALETTE) 164 | # add an attribute for visualization convenience 165 | model.CLASSES = datasets[0].CLASSES 166 | # passing checkpoint meta for saving best checkpoint 167 | meta.update(cfg.checkpoint_config.meta) 168 | train_segmentor( 169 | model, 170 | datasets, 171 | cfg, 172 | distributed=distributed, 173 | validate=(not args.no_validate), 174 | timestamp=timestamp, 175 | meta=meta) 176 | 177 | 178 | if __name__ == '__main__': 179 | main(sys.argv[1:]) 180 | --------------------------------------------------------------------------------