├── docs
│   ├── Readme.md
│   └── figure1.png
├── mmseg
│   ├── models
│   │   ├── Readme.md
│   │   ├── necks
│   │   │   └── __init__.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── accuracy.py
│   │   │   ├── utils.py
│   │   │   ├── dice_loss.py
│   │   │   └── cross_entropy_loss.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── res_layer.py
│   │   ├── backbones
│   │   │   └── __init__.py
│   │   ├── segmentors
│   │   │   ├── __init__.py
│   │   │   └── cascade_encoder_decoder.py
│   │   ├── decode_heads
│   │   │   ├── __init__.py
│   │   │   └── aff_head.py
│   │   ├── __init__.py
│   │   └── builder.py
│   ├── apis
│   │   ├── __init__.py
│   │   ├── inference.py
│   │   └── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── collect_env.py
│   │   ├── logger.py
│   │   ├── misc.py
│   │   └── set_env.py
│   ├── core
│   │   ├── seg
│   │   │   ├── sampler
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_pixel_sampler.py
│   │   │   │   └── ohem_pixel_sampler.py
│   │   │   ├── __init__.py
│   │   │   └── builder.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── misc.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   └── eval_hooks.py
│   │   └── __init__.py
│   ├── datasets
│   │   ├── pipelines
│   │   │   ├── formating.py
│   │   │   ├── __init__.py
│   │   │   ├── compose.py
│   │   │   ├── test_time_aug.py
│   │   │   └── loading.py
│   │   ├── __init__.py
│   │   ├── coco_stuff.py
│   │   ├── builder.py
│   │   ├── ade.py
│   │   └── cityscapes.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── wrappers.py
│   │   └── encoding.py
│   ├── version.py
│   └── __init__.py
├── tools
│   ├── dist_train.sh
│   ├── dist_test.sh
│   ├── cmd.sh
│   ├── get_flops_fps.py
│   └── train.py
├── configs
│   ├── _base_
│   │   ├── default_runtime.py
│   │   ├── schedules
│   │   │   ├── schedule_20k.py
│   │   │   ├── schedule_40k.py
│   │   │   ├── schedule_80k.py
│   │   │   └── schedule_160k.py
│   │   ├── models
│   │   │   └── afformer.py
│   │   └── datasets
│   │       ├── cityscapes_640x1024.py
│   │       ├── cityscapes_768x1024.py
│   │       ├── cityscapes_1024x1024.py
│   │       ├── cityscapes.py
│   │       ├── coco_stuff10k.py
│   │       ├── ade20k_std.py
│   │       └── ade20k.py
│   └── AFFormer
│       ├── AFFormer_base_ade20k.py
│       ├── AFFormer_tiny_ade20k.py
│       ├── AFFormer_small_ade20k.py
│       ├── AFFormer_base_cityscapes.py
│       ├── AFFormer_small_cityscapes.py
│       └── AFFormer_tiny_cityscapes.py
└── README.md

/docs/Readme.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mmseg/models/Readme.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/docs/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongbo811/AFFormer/HEAD/docs/figure1.png
--------------------------------------------------------------------------------
/mmseg/models/utils/__init__.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2022 Alibaba Group Holding Limited.
'''
from .res_layer import ResLayer


__all__ = [
    'ResLayer'
]
--------------------------------------------------------------------------------
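Usage sketch (editor's addition, not a repository file): ResLayer chains
num_blocks residual blocks into one nn.Sequential, adding a conv + norm
downsample branch when the stride or channel width changes. TinyBlock below is
a made-up stand-in for a real residual block; its keyword interface mirrors
upstream mmseg and is an assumption here.

    import torch
    import torch.nn as nn
    from mmseg.models.utils import ResLayer

    class TinyBlock(nn.Module):
        # Hypothetical minimal block; the (inplanes, planes, stride, dilation,
        # downsample, **kwargs) interface is assumed from upstream mmseg.
        expansion = 1

        def __init__(self, inplanes, planes, stride=1, dilation=1,
                     downsample=None, **kwargs):
            super().__init__()
            self.conv = nn.Conv2d(inplanes, planes, 3, stride=stride,
                                  padding=dilation, dilation=dilation)
            self.downsample = downsample

        def forward(self, x):
            identity = x if self.downsample is None else self.downsample(x)
            return torch.relu(self.conv(x) + identity)

    # Two blocks, 64 -> 64 channels; stride 2 on the first block halves the
    # spatial size, and ResLayer builds the matching downsample branch.
    layer = ResLayer(TinyBlock, inplanes=64, planes=64, num_blocks=2, stride=2)
    out = layer(torch.randn(1, 64, 56, 56))  # -> torch.Size([1, 64, 28, 28])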
/mmseg/core/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import *  # noqa: F401, F403
from .seg import *  # noqa: F401, F403
from .utils import *  # noqa: F401, F403
--------------------------------------------------------------------------------
/mmseg/models/decode_heads/__init__.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2021 Alibaba Group Holding Limited.
'''
from .aff_head import CLS
from .decode_head import BaseDecodeHead

__all__ = ['CLS', 'BaseDecodeHead']
--------------------------------------------------------------------------------
/mmseg/models/necks/__init__.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2022 Alibaba Group Holding Limited.
This file is modified from:
https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/necks/__init__.py
'''
from .fpn import FPN

__all__ = ['FPN']
--------------------------------------------------------------------------------
/mmseg/core/utils/__init__.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2022 Alibaba Group Holding Limited.
This file is modified from:
https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/utils/__init__.py
'''
from .misc import add_prefix

__all__ = ['add_prefix']
--------------------------------------------------------------------------------
/mmseg/models/backbones/__init__.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2022 Alibaba Group Holding Limited.
This file is modified from:
https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/__init__.py
'''

from .resnet import ResNet

__all__ = [
    'ResNet'
]
--------------------------------------------------------------------------------
/mmseg/ops/__init__.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2022 Alibaba Group Holding Limited.
This file is modified from:
https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/ops/__init__.py
'''
from .encoding import Encoding
from .wrappers import Upsample, resize

__all__ = ['Upsample', 'resize', 'Encoding']
--------------------------------------------------------------------------------
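Usage sketch (editor's addition, not a repository file): resize is mmseg's
thin wrapper around torch.nn.functional.interpolate that additionally warns
about risky align_corners choices. A minimal call, assuming the upstream
signature resize(input, size=None, scale_factor=None, mode='nearest',
align_corners=None):

    import torch
    from mmseg.ops import resize

    logits = torch.randn(1, 19, 64, 128)  # per-class logits at 1/8 resolution
    # Upsample back to the full input size before taking the argmax.
    full = resize(logits, size=(512, 1024), mode='bilinear',
                  align_corners=False)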
/mmseg/core/seg/sampler/__init__.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2022 Alibaba Group Holding Limited.
This file is modified from:
https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/seg/sampler/__init__.py
'''
from .base_pixel_sampler import BasePixelSampler
from .ohem_pixel_sampler import OHEMPixelSampler

__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
--------------------------------------------------------------------------------
/mmseg/core/seg/__init__.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2022 Alibaba Group Holding Limited.
This file is modified from:
https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/seg/__init__.py
'''
from .builder import build_pixel_sampler
from .sampler import BasePixelSampler, OHEMPixelSampler

__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
--------------------------------------------------------------------------------
/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_20k.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=20000)
checkpoint_config = dict(by_epoch=False, interval=2000)
evaluation = dict(interval=2000, metric='mIoU')
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_40k.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=40000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric='mIoU')
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_80k.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=80000)
checkpoint_config = dict(by_epoch=False, interval=8000)
evaluation = dict(interval=8000, metric='mIoU')
--------------------------------------------------------------------------------
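The leaf AFFormer configs listed in the tree are not shown in this dump; the
sketch below is a hypothetical reconstruction of how an mmseg leaf config
typically composes these _base_ fragments (the file names come from the tree,
the contents are assumed):

    # Hypothetical sketch of configs/AFFormer/AFFormer_tiny_ade20k.py.
    _base_ = [
        '../_base_/models/afformer.py',
        '../_base_/datasets/ade20k.py',
        '../_base_/default_runtime.py',
        '../_base_/schedules/schedule_160k.py',
    ]
    # A leaf config then only overrides what differs from the bases, e.g.
    # model = dict(decode_head=dict(num_classes=150)).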
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/segmentors/__init__.py
5 | '''
6 | from .base import BaseSegmentor
7 | from .cascade_encoder_decoder import CascadeEncoderDecoder
8 | from .encoder_decoder import EncoderDecoder
9 |
10 | __all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
11 |
-------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_160k.py: --------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
3 | optimizer_config = dict()
4 | # learning policy
5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
6 | # runtime settings
7 | runner = dict(type='IterBasedRunner', max_iters=160000)
8 | checkpoint_config = dict(by_epoch=False, interval=16000)
9 | evaluation = dict(interval=5000, metric='mIoU', pre_eval=True)
10 |
-------------------------------------------------------------------------------- /mmseg/core/seg/builder.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/seg/builder.py
5 | '''
6 | from mmcv.utils import Registry, build_from_cfg
7 |
8 | PIXEL_SAMPLERS = Registry('pixel sampler')
9 |
10 |
11 | def build_pixel_sampler(cfg, **default_args):
12 |     """Build pixel sampler for segmentation map."""
13 |     return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
14 |
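15 | # Usage sketch (added for illustration; not part of the original file). The
16 | # config mirrors the standard mmseg OHEM setting, and `decode_head` stands in
17 | # for whatever module is passed as the sampler's `context`:
18 | #   sampler_cfg = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)
19 | #   sampler = build_pixel_sampler(sampler_cfg, context=decode_head)
20 |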
-------------------------------------------------------------------------------- /mmseg/datasets/pipelines/formating.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/formating.py
5 | '''
6 | import warnings
7 |
8 | from .formatting import *
9 |
10 | warnings.warn('DeprecationWarning: mmseg.datasets.pipelines.formating will be '
11 |               'deprecated in 2021, please replace it with '
12 |               'mmseg.datasets.pipelines.formatting.')
13 |
-------------------------------------------------------------------------------- /mmseg/utils/__init__.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/utils/__init__.py
5 | '''
6 | from .collect_env import collect_env
7 | from .logger import get_root_logger
8 | from .misc import find_latest_checkpoint
9 | from .set_env import setup_multi_processes
10 |
11 | __all__ = [
12 |     'get_root_logger', 'collect_env', 'find_latest_checkpoint',
13 |     'setup_multi_processes'
14 | ]
15 |
-------------------------------------------------------------------------------- /tools/dist_train.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CONFIG=$1
3 | GPUS=$2
4 | NNODES=${NNODES:-1}
5 | NODE_RANK=${NODE_RANK:-0}
6 | PORT=${PORT:-29500}
7 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
8 |
9 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
10 | python -m torch.distributed.launch \
11 |     --nnodes=$NNODES \
12 |     --node_rank=$NODE_RANK \
13 |     --master_addr=$MASTER_ADDR \
14 |     --nproc_per_node=$GPUS \
15 |     --master_port=$PORT \
16 |     $(dirname "$0")/train.py \
17 |     $CONFIG \
18 |     --seed 0 \
19 |     --launcher pytorch ${@:3}
20 |
-------------------------------------------------------------------------------- /mmseg/apis/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .inference import inference_segmentor, init_segmentor, show_result_pyplot
3 | from .test import multi_gpu_test, single_gpu_test
4 | from .train import (get_root_logger, init_random_seed, set_random_seed,
5 |                     train_segmentor)
6 |
7 | __all__ = [
8 |     'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
9 |     'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
10 |     'show_result_pyplot', 'init_random_seed'
11 | ]
12 |
-------------------------------------------------------------------------------- /tools/dist_test.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CONFIG=$1
3 | CHECKPOINT=$2
4 | GPUS=$3
5 | NNODES=${NNODES:-1}
6 | NODE_RANK=${NODE_RANK:-0}
7 | PORT=${PORT:-29500}
8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
9 |
10 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
11 | python -m torch.distributed.launch \
12 |     --nnodes=$NNODES \
13 |     --node_rank=$NODE_RANK \
14 |     --master_addr=$MASTER_ADDR \
15 |     --nproc_per_node=$GPUS \
16 |     --master_port=$PORT \
17 |     $(dirname "$0")/test.py \
18 |     $CONFIG \
19 |     $CHECKPOINT \
20 |     --launcher pytorch \
21 |     ${@:4} -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/base_pixel_sampler.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/seg/sampler/base_pixel_sampler.py
5 | '''
6 | from abc import ABCMeta, abstractmethod
7 |
8 |
9 | class BasePixelSampler(metaclass=ABCMeta):
10 |     """Base class of pixel sampler."""
11 |
12 |     def __init__(self, **kwargs):
13 |         pass
14 |
15 |     @abstractmethod
16 |     def sample(self, seg_logit, seg_label):
17 |         """Placeholder for sample function."""
18 |
-------------------------------------------------------------------------------- /mmseg/models/__init__.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | '''
4 | from .backbones import *  # noqa: F401,F403
5 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
6 |                       build_head, build_loss, build_segmentor)
7 | from .decode_heads import *  # noqa: F401,F403
8 | from .losses import *  # noqa: F401,F403
9 | from .necks import *  # noqa: F401,F403
10 | from .segmentors import *  # noqa: F401,F403
11 |
12 | __all__ = [
13 |     'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
14 |     'build_head', 'build_loss', 'build_segmentor'
15 | ]
16 |
-------------------------------------------------------------------------------- /mmseg/core/utils/misc.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/utils/misc.py
5 | '''
6 | def add_prefix(inputs, prefix):
7 |     """Add prefix for dict.
8 |
9 |     Args:
10 |         inputs (dict): The input dict with str keys.
11 |         prefix (str): The prefix to add.
12 |
13 |     Returns:
14 |
15 |         dict: The dict with keys updated with ``prefix``.
16 |     """
17 |
18 |     outputs = dict()
19 |     for name, value in inputs.items():
20 |         outputs[f'{prefix}.{name}'] = value
21 |
22 |     return outputs
23 |
-------------------------------------------------------------------------------- /mmseg/core/evaluation/__init__.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/evaluation/__init__.py
5 | '''
6 | from .class_names import get_classes, get_palette
7 | from .eval_hooks import DistEvalHook, EvalHook
8 | from .metrics import (eval_metrics, intersect_and_union, mean_dice,
9 |                       mean_fscore, mean_iou, pre_eval_to_metrics)
10 |
11 | __all__ = [
12 |     'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
13 |     'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics',
14 |     'intersect_and_union'
15 | ]
16 |
-------------------------------------------------------------------------------- /mmseg/utils/collect_env.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/utils/collect_env.py
5 | '''
6 | from mmcv.utils import collect_env as collect_base_env
7 | from mmcv.utils import get_git_hash
8 |
9 | import mmseg
10 |
11 |
12 | def collect_env():
13 |     """Collect the information of the running environments."""
14 |     env_info = collect_base_env()
15 |     env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
16 |
17 |     return env_info
18 |
19 |
20 | if __name__ == '__main__':
21 |     for name, val in collect_env().items():
22 |         print('{}: {}'.format(name, val))
23 |
-------------------------------------------------------------------------------- /mmseg/version.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/version.py
5 | '''
6 |
7 | __version__ = '0.21.1'
8 |
9 |
10 | def parse_version_info(version_str):
11 |     version_info = []
12 |     for x in version_str.split('.'):
13 |         if x.isdigit():
14 |             version_info.append(int(x))
15 |         elif x.find('rc') != -1:
16 |             patch_version = x.split('rc')
17 |             version_info.append(int(patch_version[0]))
18 |             version_info.append(f'rc{patch_version[1]}')
19 |     return tuple(version_info)
20 |
21 |
22 | version_info = parse_version_info(__version__)
23 |
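24 | # Doctest-style sketch (added for illustration):
25 | #   parse_version_info('0.21.1')  ->  (0, 21, 1)
26 | #   parse_version_info('1.0rc2')  ->  (1, 0, 'rc2')
27 |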
-------------------------------------------------------------------------------- /mmseg/models/losses/__init__.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/losses/__init__.py
5 | '''
6 | from .accuracy import Accuracy, accuracy
7 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
8 |                                  cross_entropy, mask_cross_entropy)
9 | from .dice_loss import DiceLoss
10 | from .focal_loss import FocalLoss
11 | from .lovasz_loss import LovaszLoss
12 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss
13 |
14 | __all__ = [
15 |     'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
16 |     'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
17 |     'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
18 |     'FocalLoss'
19 | ]
20 |
-------------------------------------------------------------------------------- /configs/_base_/models/afformer.py: --------------------------------------------------------------------------------
1 | # model settings
2 | norm_cfg = dict(type='SyncBN', requires_grad=True)
3 | ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
4 | model = dict(
5 |     type='EncoderDecoder',
6 |     pretrained=None,
7 |     backbone=dict(
8 |         type='afformer_base',
9 |         strides=[4, 2, 2, 2]),
10 |     decode_head=dict(
11 |         type='CLS',
12 |         in_channels=256,
13 |         in_index=[0, 1, 2, 3],
14 |         channels=512,
15 |         aff_channels=512,
16 |         dropout_ratio=0.1,
17 |         num_classes=150,
18 |         norm_cfg=ham_norm_cfg,
19 |         align_corners=False,
20 |         loss_decode=dict(
21 |             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
22 |     # model training and testing settings
23 |     train_cfg=dict(),
24 |     test_cfg=dict(mode='whole'))
25 |
26 |
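27 | # Added note: the leaf configs under configs/AFFormer/ inherit this file via
28 | # `_base_` and override the head per variant, e.g.
29 | #   decode_head=dict(in_channels=[216], in_index=[3], channels=256, aff_channels=256)
30 |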
-------------------------------------------------------------------------------- /mmseg/datasets/__init__.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/__init__.py
5 | '''
6 | from .ade import ADE20KDataset
7 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
8 | from .cityscapes import CityscapesDataset
9 | from .coco_stuff import COCOStuffDataset
10 | from .custom import CustomDataset
11 | from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
12 |                                RepeatDataset)
13 |
14 |
15 | # Keep __all__ in sync with the imports above so that
16 | # `from mmseg.datasets import *` only exposes names defined in this package.
17 | __all__ = [
18 |     'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
19 |     'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
20 |     'ADE20KDataset', 'COCOStuffDataset', 'MultiImageMixDataset'
21 | ]
22 |
-------------------------------------------------------------------------------- /configs/AFFormer/AFFormer_base_ade20k.py: --------------------------------------------------------------------------------
1 | _base_ = [
2 |     '../_base_/models/afformer.py', '../_base_/datasets/ade20k.py',
3 |     '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
4 | ]
5 | model = dict(
6 |     pretrained='./pretained_weight/AFFormer_base_ImageNet1k.pth',
7 |     backbone=dict(
8 |         type='afformer_base',
9 |         strides=[4, 2, 2, 2]),
10 |     decode_head=dict(
11 |         in_channels=[216],
12 |         in_index=[3],
13 |         channels=256,
14 |         aff_channels=256,
15 |         aff_kwargs=dict(MD_R=16),
16 |         num_classes=150
17 |     )
18 | )
19 |
20 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone
21 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0003, betas=(0.9, 0.999), weight_decay=0.01)
22 |
23 | lr_config = dict(_delete_=True, policy='poly',
24 |                  warmup='linear',
25 |                  warmup_iters=1500,
26 |                  warmup_ratio=1e-6,
27 |                  power=1.0, min_lr=0.0, by_epoch=False)
28 |
29 | # By default, models are trained on 2 GPUs with 8 images per GPU
30 | data=dict(samples_per_gpu=8, workers_per_gpu=8)
31 | find_unused_parameters=True
32 |
-------------------------------------------------------------------------------- /configs/AFFormer/AFFormer_tiny_ade20k.py: --------------------------------------------------------------------------------
1 | _base_ = [
2 |     '../_base_/models/afformer.py', '../_base_/datasets/ade20k.py',
3 |     '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
4 | ]
5 | model = dict(
6 |     pretrained='./pretained_weight/AFFormer_tiny_ImageNet1k.pth',
7 |     backbone=dict(
8 |         type='afformer_tiny',
9 |         strides=[4, 2, 2, 2]),
10 |     decode_head=dict(
11 |         in_channels=[216],
12 |         in_index=[3],
13 |         channels=256,
14 |         aff_channels=256,
15 |         aff_kwargs=dict(MD_R=16),
16 |         num_classes=150
17 |     )
18 | )
19 |
20 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone
21 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0003, betas=(0.9, 0.999), weight_decay=0.01)
22 |
23 | lr_config = dict(_delete_=True, policy='poly',
24 |                  warmup='linear',
25 |                  warmup_iters=1500,
26 |                  warmup_ratio=1e-6,
27 |                  power=1.0, min_lr=0.0, by_epoch=False)
28 |
29 | # By default, models are trained on 2 GPUs with 8 images per GPU
30 | data=dict(samples_per_gpu=8, workers_per_gpu=8)
31 | find_unused_parameters=True
32 |
--------------------------------------------------------------------------------
/configs/AFFormer/AFFormer_small_ade20k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/afformer.py', '../_base_/datasets/ade20k.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 4 | ] 5 | model = dict( 6 | pretrained='./pretained_weight/AFFormer_small_ImageNet1k.pth', 7 | backbone=dict( 8 | type='afformer_small', 9 | strides=[4, 2, 2, 2]), 10 | decode_head=dict( 11 | in_channels=[216], 12 | in_index=[3], 13 | channels=256, 14 | aff_channels=256, 15 | aff_kwargs=dict(MD_R=16), 16 | num_classes=150 17 | ) 18 | ) 19 | 20 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 21 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0003, betas=(0.9, 0.999), weight_decay=0.01) 22 | 23 | lr_config = dict(_delete_=True, policy='poly', 24 | warmup='linear', 25 | warmup_iters=1500, 26 | warmup_ratio=1e-6, 27 | power=1.0, min_lr=0.0, by_epoch=False) 28 | 29 | # By default, models are trained on 2 GPUs with 8 images per GPU 30 | data=dict(samples_per_gpu=8, workers_per_gpu=8) 31 | find_unused_parameters=True 32 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/__init__.py 5 | ''' 6 | from .compose import Compose 7 | from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor, 8 | Transpose, to_tensor) 9 | from .loading import LoadAnnotations, LoadImageFromFile 10 | from .test_time_aug import MultiScaleFlipAug 11 | from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, 12 | PhotoMetricDistortion, RandomCrop, RandomCutOut, 13 | RandomFlip, RandomMosaic, RandomRotate, Rerange, 14 | Resize, RGB2Gray, SegRescale) 15 | 16 | __all__ = [ 17 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 18 | 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 19 | 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 20 | 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 21 | 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut', 22 | 'RandomMosaic' 23 | ] 24 | -------------------------------------------------------------------------------- /configs/AFFormer/AFFormer_base_cityscapes.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/afformer.py', '../_base_/datasets/cityscapes_1024x1024.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 4 | ] 5 | model = dict( 6 | pretrained='./pretained_weight/AFFormer_base_ImageNet1k.pth', 7 | backbone=dict( 8 | type='afformer_base', 9 | strides=[4, 2, 2, 2]), 10 | decode_head=dict( 11 | in_channels=[216], 12 | in_index=[3], 13 | channels=256, 14 | aff_channels=256, 15 | aff_kwargs=dict(MD_R=16), 16 | num_classes=19 17 | ) 18 | ) 19 | 20 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 21 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0004, betas=(0.9, 0.999), weight_decay=0.01) 22 | 23 | lr_config = dict(_delete_=True, policy='poly', 24 | warmup='linear', 25 | warmup_iters=1500, 26 | warmup_ratio=1e-6, 27 | power=1.0, min_lr=0.0, by_epoch=False) 28 | 29 | # By default, models are trained on 
2 GPUs with 4 images per GPU 30 | data=dict(samples_per_gpu=4, workers_per_gpu=4) 31 | find_unused_parameters=True 32 | -------------------------------------------------------------------------------- /configs/AFFormer/AFFormer_small_cityscapes.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/afformer.py', '../_base_/datasets/cityscapes_1024x1024.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 4 | ] 5 | model = dict( 6 | pretrained='./pretained_weight/AFFormer_small_ImageNet1k.pth', 7 | backbone=dict( 8 | type='afformer_small', 9 | strides=[4, 2, 2, 2]), 10 | decode_head=dict( 11 | in_channels=[216], 12 | in_index=[3], 13 | channels=256, 14 | aff_channels=256, 15 | aff_kwargs=dict(MD_R=16), 16 | num_classes=19 17 | ) 18 | ) 19 | 20 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 21 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0004, betas=(0.9, 0.999), weight_decay=0.01) 22 | 23 | lr_config = dict(_delete_=True, policy='poly', 24 | warmup='linear', 25 | warmup_iters=1500, 26 | warmup_ratio=1e-6, 27 | power=1.0, min_lr=0.0, by_epoch=False) 28 | 29 | # By default, models are trained on 2 GPUs with 4 images per GPU 30 | data=dict(samples_per_gpu=4, workers_per_gpu=4) 31 | find_unused_parameters=True 32 | -------------------------------------------------------------------------------- /configs/AFFormer/AFFormer_tiny_cityscapes.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/afformer.py', '../_base_/datasets/cityscapes_1024x1024.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 4 | ] 5 | model = dict( 6 | pretrained='./pretained_weight/AFFormer_tiny_ImageNet1k.pth', 7 | backbone=dict( 8 | type='afformer_tiny', 9 | strides=[4, 2, 2, 2]), 10 | decode_head=dict( 11 | in_channels=[216], 12 | in_index=[3], 13 | channels=256, 14 | aff_channels=256, 15 | aff_kwargs=dict(MD_R=16), 16 | num_classes=19 17 | ) 18 | ) 19 | 20 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 21 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0004, betas=(0.9, 0.999), weight_decay=0.01) 22 | 23 | lr_config = dict(_delete_=True, policy='poly', 24 | warmup='linear', 25 | warmup_iters=1500, 26 | warmup_ratio=1e-6, 27 | power=1.0, min_lr=0.0, by_epoch=False) 28 | 29 | # By default, models are trained on 2 GPUs with 4 images per GPU 30 | data=dict(samples_per_gpu=4, workers_per_gpu=4) 31 | find_unused_parameters=True 32 | -------------------------------------------------------------------------------- /mmseg/utils/logger.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/utils/logger.py 5 | ''' 6 | import logging 7 | 8 | from mmcv.utils import get_logger 9 | 10 | 11 | def get_root_logger(log_file=None, log_level=logging.INFO): 12 | """Get the root logger. 13 | 14 | The logger will be initialized if it has not been initialized. By default a 15 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 16 | also be added. The name of the root logger is the top-level package name, 17 | e.g., "mmseg". 18 | 19 | Args: 20 | log_file (str | None): The log filename. If specified, a FileHandler 21 | will be added to the root logger. 
22 |         log_level (int): The root logger level. Note that only the process of
23 |             rank 0 is affected, while other processes will set the level to
24 |             "Error" and be silent most of the time.
25 |
26 |     Returns:
27 |         logging.Logger: The root logger.
28 |     """
29 |
30 |     logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)
31 |
32 |     return logger
33 |
-------------------------------------------------------------------------------- /tools/cmd.sh: --------------------------------------------------------------------------------
1 | # base
2 | bash dist_train.sh ./configs/AFFormer/AFFormer_base_ade20k.py 4
3 | bash dist_train.sh ./configs/AFFormer/AFFormer_base_cityscapes.py 2
4 |
5 | bash tools/dist_test.sh ./configs/AFFormer/AFFormer_base_ade20k.py ./pretained_weight/AFFormer_base_ade20k.pth 8 --eval mIoU
6 | bash tools/dist_test.sh ./configs/AFFormer/AFFormer_base_cityscapes.py ./pretained_weight/AFFormer_base_cityscapes.pth 8 --eval mIoU
7 |
8 | # -----------------------------------------------------------------
9 | # small
10 | bash dist_train.sh ./configs/AFFormer/AFFormer_small_ade20k.py 4
11 | bash dist_train.sh ./configs/AFFormer/AFFormer_small_cityscapes.py 2
12 |
13 | bash tools/dist_test.sh ./configs/AFFormer/AFFormer_small_ade20k.py ./pretained_weight/AFFormer_small_ade20k.pth 8 --eval mIoU
14 | bash tools/dist_test.sh ./configs/AFFormer/AFFormer_small_cityscapes.py ./pretained_weight/AFFormer_small_cityscapes.pth 8 --eval mIoU
15 |
16 |
17 | # -----------------------------------------------------------------
18 | # tiny
19 | bash dist_train.sh ./configs/AFFormer/AFFormer_tiny_ade20k.py 4
20 | bash dist_train.sh ./configs/AFFormer/AFFormer_tiny_cityscapes.py 2
21 |
22 | bash tools/dist_test.sh ./configs/AFFormer/AFFormer_tiny_ade20k.py ./pretained_weight/AFFormer_tiny_ade20k.pth 8 --eval mIoU
23 | bash tools/dist_test.sh ./configs/AFFormer/AFFormer_tiny_cityscapes.py ./pretained_weight/AFFormer_tiny_cityscapes.pth 8 --eval mIoU
24 |
25 |
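26 | # Single-GPU debugging sketch (added for illustration; tools/train.py takes the
27 | # config positionally, as in dist_train.sh above; running it without a launcher
28 | # is an assumption about its CLI):
29 | #   python tools/train.py ./configs/AFFormer/AFFormer_tiny_ade20k.py
30 |
-------------------------------------------------------------------------------- /mmseg/models/decode_heads/aff_head.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.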
3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from mmcv.cnn import ConvModule 8 | 9 | from mmseg.ops import resize 10 | from ..builder import HEADS 11 | from .decode_head import BaseDecodeHead 12 | 13 | 14 | @HEADS.register_module() 15 | class CLS(BaseDecodeHead): 16 | def __init__(self, 17 | aff_channels=512, 18 | aff_kwargs=dict(), 19 | **kwargs): 20 | super(CLS, self).__init__( 21 | input_transform='multiple_select', **kwargs) 22 | self.aff_channels = aff_channels 23 | 24 | self.squeeze = ConvModule( 25 | sum(self.in_channels), 26 | self.channels, 27 | 1, 28 | conv_cfg=self.conv_cfg, 29 | norm_cfg=self.norm_cfg, 30 | act_cfg=self.act_cfg) 31 | 32 | 33 | self.align = ConvModule( 34 | self.aff_channels, 35 | self.channels, 36 | 1, 37 | conv_cfg=self.conv_cfg, 38 | norm_cfg=self.norm_cfg, 39 | act_cfg=self.act_cfg) 40 | 41 | def forward(self, inputs): 42 | """Forward function.""" 43 | inputs = self._transform_inputs(inputs)[0] 44 | 45 | x = self.squeeze(inputs) 46 | 47 | output = self.cls_seg(x) 48 | return output 49 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cityscapes_640x1024.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | _base_ = './cityscapes.py' 5 | img_norm_cfg = dict( 6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 7 | crop_size = (640, 1280) 8 | train_pipeline = [ 9 | dict(type='LoadImageFromFile'), 10 | dict(type='LoadAnnotations'), 11 | dict(type='Resize', img_scale=(1280, 640), ratio_range=(0.5, 2.0)), 12 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 13 | dict(type='RandomFlip', prob=0.5), 14 | dict(type='PhotoMetricDistortion'), 15 | dict(type='Normalize', **img_norm_cfg), 16 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 17 | dict(type='DefaultFormatBundle'), 18 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 19 | ] 20 | test_pipeline = [ 21 | dict(type='LoadImageFromFile'), 22 | dict( 23 | type='MultiScaleFlipAug', 24 | img_scale=(1280, 640), 25 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 26 | flip=False, 27 | transforms=[ 28 | dict(type='Resize', keep_ratio=True), 29 | dict(type='RandomFlip'), 30 | dict(type='Normalize', **img_norm_cfg), 31 | dict(type='ImageToTensor', keys=['img']), 32 | dict(type='Collect', keys=['img']), 33 | ]) 34 | ] 35 | data = dict( 36 | train=dict(pipeline=train_pipeline), 37 | val=dict(pipeline=test_pipeline), 38 | test=dict(pipeline=test_pipeline)) 39 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cityscapes_768x1024.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
3 | '''
4 | _base_ = './cityscapes.py'
5 | img_norm_cfg = dict(
6 |     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
7 | crop_size = (768, 768)
8 | train_pipeline = [
9 |     dict(type='LoadImageFromFile'),
10 |     dict(type='LoadAnnotations'),
11 |     dict(type='Resize', img_scale=(1536, 768), ratio_range=(0.5, 2.0)),
12 |     dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
13 |     dict(type='RandomFlip', prob=0.5),
14 |     dict(type='PhotoMetricDistortion'),
15 |     dict(type='Normalize', **img_norm_cfg),
16 |     dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
17 |     dict(type='DefaultFormatBundle'),
18 |     dict(type='Collect', keys=['img', 'gt_semantic_seg']),
19 | ]
20 | test_pipeline = [
21 |     dict(type='LoadImageFromFile'),
22 |     dict(
23 |         type='MultiScaleFlipAug',
24 |         img_scale=(1536, 768),
25 |         # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
26 |         flip=False,
27 |         transforms=[
28 |             dict(type='Resize', keep_ratio=True),
29 |             dict(type='RandomFlip'),
30 |             dict(type='Normalize', **img_norm_cfg),
31 |             dict(type='ImageToTensor', keys=['img']),
32 |             dict(type='Collect', keys=['img']),
33 |         ])
34 | ]
35 | data = dict(
36 |     train=dict(pipeline=train_pipeline),
37 |     val=dict(pipeline=test_pipeline),
38 |     test=dict(pipeline=test_pipeline))
39 |
-------------------------------------------------------------------------------- /mmseg/models/builder.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | '''
4 | import warnings
5 |
6 | from mmcv.cnn import MODELS as MMCV_MODELS
7 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
8 | from mmcv.utils import Registry
9 |
10 | MODELS = Registry('models', parent=MMCV_MODELS)
11 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION)
12 |
13 | BACKBONES = MODELS
14 | NECKS = MODELS
15 | HEADS = MODELS
16 | LOSSES = MODELS
17 | SEGMENTORS = MODELS
18 |
19 |
20 | def build_backbone(cfg):
21 |     """Build backbone."""
22 |     return BACKBONES.build(cfg)
23 |
24 |
25 | def build_neck(cfg):
26 |     """Build neck."""
27 |     return NECKS.build(cfg)
28 |
29 |
30 | def build_head(cfg):
31 |     """Build head."""
32 |     return HEADS.build(cfg)
33 |
34 |
35 | def build_loss(cfg):
36 |     """Build loss."""
37 |     return LOSSES.build(cfg)
38 |
39 |
40 | def build_segmentor(cfg, train_cfg=None, test_cfg=None):
41 |     """Build segmentor."""
42 |     if train_cfg is not None or test_cfg is not None:
43 |         warnings.warn(
44 |             'train_cfg and test_cfg are deprecated, '
45 |             'please specify them in model', UserWarning)
46 |     assert cfg.get('train_cfg') is None or train_cfg is None, \
47 |         'train_cfg specified in both outer field and model field '
48 |     assert cfg.get('test_cfg') is None or test_cfg is None, \
49 |         'test_cfg specified in both outer field and model field '
50 |     return SEGMENTORS.build(
51 |         cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
52 |
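53 | # Usage sketch (added for illustration), mirroring how mmseg's tools/train.py
54 | # typically builds a model; `Config.fromfile` is mmcv's config loader:
55 | #   from mmcv import Config
56 | #   cfg = Config.fromfile('configs/AFFormer/AFFormer_tiny_ade20k.py')
57 | #   model = build_segmentor(cfg.model, train_cfg=cfg.get('train_cfg'),
58 | #                           test_cfg=cfg.get('test_cfg'))
59 |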
-------------------------------------------------------------------------------- /configs/_base_/datasets/cityscapes_1024x1024.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/datasets/cityscapes_1024x1024.py
5 | '''
6 | _base_ = './cityscapes.py'
7 | img_norm_cfg = dict(
8 |     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
9 | crop_size = (1024, 1024)
10 | train_pipeline = [
11 |     dict(type='LoadImageFromFile'),
12 |     dict(type='LoadAnnotations'),
13 |     dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
14 |     dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
15 |     dict(type='RandomFlip', prob=0.5),
16 |     dict(type='PhotoMetricDistortion'),
17 |     dict(type='Normalize', **img_norm_cfg),
18 |     dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
19 |     dict(type='DefaultFormatBundle'),
20 |     dict(type='Collect', keys=['img', 'gt_semantic_seg']),
21 | ]
22 | test_pipeline = [
23 |     dict(type='LoadImageFromFile'),
24 |     dict(
25 |         type='MultiScaleFlipAug',
26 |         img_scale=(2048, 1024),
27 |         # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
28 |         flip=False,
29 |         transforms=[
30 |             dict(type='Resize', keep_ratio=True),
31 |             dict(type='RandomFlip'),
32 |             dict(type='Normalize', **img_norm_cfg),
33 |             dict(type='ImageToTensor', keys=['img']),
34 |             dict(type='Collect', keys=['img']),
35 |         ])
36 | ]
37 | data = dict(
38 |     train=dict(pipeline=train_pipeline),
39 |     val=dict(pipeline=test_pipeline),
40 |     test=dict(pipeline=test_pipeline))
41 |
-------------------------------------------------------------------------------- /mmseg/utils/misc.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/utils/misc.py
5 | '''
6 | import glob
7 | import os.path as osp
8 | import warnings
9 |
10 |
11 | def find_latest_checkpoint(path, suffix='pth'):
12 |     """This function is for finding the latest checkpoint.
13 |
14 |     It will be used when resuming automatically, modified from
15 |     https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py
16 |
17 |     Args:
18 |         path (str): The path to find checkpoints.
19 |         suffix (str): File extension for the checkpoint. Defaults to pth.
20 |
21 |     Returns:
22 |         latest_path(str | None): File path of the latest checkpoint.
23 |     """
24 |     if not osp.exists(path):
25 |         warnings.warn("The path of the checkpoints doesn't exist.")
26 |         return None
27 |     if osp.exists(osp.join(path, f'latest.{suffix}')):
28 |         return osp.join(path, f'latest.{suffix}')
29 |
30 |     checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
31 |     if len(checkpoints) == 0:
32 |         warnings.warn('There are no checkpoints in the path')
33 |         return None
34 |     latest = -1
35 |     latest_path = ''
36 |     for checkpoint in checkpoints:
37 |         if len(checkpoint) < len(latest_path):
38 |             continue
39 |         # `count` is iteration number, as checkpoints are saved as
40 |         # 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number.
41 |         count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
42 |         if count > latest:
43 |             latest = count
44 |             latest_path = checkpoint
45 |     return latest_path
46 |
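47 | # Usage sketch (added for illustration; the work-dir path is hypothetical):
48 | #   ckpt = find_latest_checkpoint('./work_dirs/AFFormer_tiny_ade20k')
49 | #   if ckpt is not None:
50 | #       cfg.resume_from = ckpt
51 |
-------------------------------------------------------------------------------- /mmseg/datasets/pipelines/compose.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.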
3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/compose.py 5 | ''' 6 | import collections 7 | 8 | from mmcv.utils import build_from_cfg 9 | 10 | from ..builder import PIPELINES 11 | 12 | 13 | @PIPELINES.register_module() 14 | class Compose(object): 15 | """Compose multiple transforms sequentially. 16 | 17 | Args: 18 | transforms (Sequence[dict | callable]): Sequence of transform object or 19 | config dict to be composed. 20 | """ 21 | 22 | def __init__(self, transforms): 23 | assert isinstance(transforms, collections.abc.Sequence) 24 | self.transforms = [] 25 | for transform in transforms: 26 | if isinstance(transform, dict): 27 | transform = build_from_cfg(transform, PIPELINES) 28 | self.transforms.append(transform) 29 | elif callable(transform): 30 | self.transforms.append(transform) 31 | else: 32 | raise TypeError('transform must be callable or a dict') 33 | 34 | def __call__(self, data): 35 | """Call function to apply transforms sequentially. 36 | 37 | Args: 38 | data (dict): A result dict contains the data to transform. 39 | 40 | Returns: 41 | dict: Transformed data. 42 | """ 43 | 44 | for t in self.transforms: 45 | data = t(data) 46 | if data is None: 47 | return None 48 | return data 49 | 50 | def __repr__(self): 51 | format_string = self.__class__.__name__ + '(' 52 | for t in self.transforms: 53 | format_string += '\n' 54 | format_string += f' {t}' 55 | format_string += '\n)' 56 | return format_string 57 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/datasets/cityscapes.py 5 | ''' 6 | # dataset settings 7 | dataset_type = 'CityscapesDataset' 8 | data_root = 'data/cityscapes/' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (512, 1024) 12 | train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations'), 15 | dict(type='Resize', img_scale=(1024, 512), ratio_range=(0.5, 2.0)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | dict(type='PhotoMetricDistortion'), 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | test_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | dict( 27 | type='MultiScaleFlipAug', 28 | img_scale=(1024, 512), 29 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 30 | flip=False, 31 | transforms=[ 32 | dict(type='Resize', keep_ratio=True), 33 | dict(type='RandomFlip'), 34 | dict(type='Normalize', **img_norm_cfg), 35 | dict(type='ImageToTensor', keys=['img']), 36 | dict(type='Collect', keys=['img']), 37 | ]) 38 | ] 39 | data = dict( 40 | samples_per_gpu=2, 41 | workers_per_gpu=2, 42 | train=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | img_dir='leftImg8bit/train', 46 | ann_dir='gtFine/train', 47 | pipeline=train_pipeline), 48 | val=dict( 49 | type=dataset_type, 50 | data_root=data_root, 51 | img_dir='leftImg8bit/val', 52 | ann_dir='gtFine/val', 53 | pipeline=test_pipeline), 54 | test=dict( 55 | type=dataset_type, 56 | 
data_root=data_root,
57 |         img_dir='leftImg8bit/val',
58 |         ann_dir='gtFine/val',
59 |         pipeline=test_pipeline))
60 |
-------------------------------------------------------------------------------- /mmseg/ops/wrappers.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/ops/wrappers.py
5 | '''
6 | import warnings
7 |
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | def resize(input,
13 |            size=None,
14 |            scale_factor=None,
15 |            mode='nearest',
16 |            align_corners=None,
17 |            warning=True):
18 |     if warning:
19 |         if size is not None and align_corners:
20 |             input_h, input_w = tuple(int(x) for x in input.shape[2:])
21 |             output_h, output_w = tuple(int(x) for x in size)
22 |             if output_h > input_h or output_w > input_w:
23 |                 if ((output_h > 1 and output_w > 1 and input_h > 1
24 |                      and input_w > 1) and (output_h - 1) % (input_h - 1)
25 |                         and (output_w - 1) % (input_w - 1)):
26 |                     warnings.warn(
27 |                         f'When align_corners={align_corners}, '
28 |                         'the output would be more aligned if '
29 |                         f'input size {(input_h, input_w)} is `x+1` and '
30 |                         f'out size {(output_h, output_w)} is `nx+1`')
31 |     return F.interpolate(input, size, scale_factor, mode, align_corners)
32 |
33 |
34 | class Upsample(nn.Module):
35 |
36 |     def __init__(self,
37 |                  size=None,
38 |                  scale_factor=None,
39 |                  mode='nearest',
40 |                  align_corners=None):
41 |         super(Upsample, self).__init__()
42 |         self.size = size
43 |         if isinstance(scale_factor, tuple):
44 |             self.scale_factor = tuple(float(factor) for factor in scale_factor)
45 |         else:
46 |             self.scale_factor = float(scale_factor) if scale_factor else None
47 |         self.mode = mode
48 |         self.align_corners = align_corners
49 |
50 |     def forward(self, x):
51 |         if not self.size:
52 |             size = [int(t * self.scale_factor) for t in x.shape[-2:]]
53 |         else:
54 |             size = self.size
55 |         return resize(x, size, None, self.mode, self.align_corners)
56 |
-------------------------------------------------------------------------------- /mmseg/__init__.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/__init__.py
5 | '''
6 | import warnings
7 |
8 | import mmcv
9 | from packaging.version import parse
10 |
11 | from .version import __version__, version_info
12 |
13 | MMCV_MIN = '1.3.13'
14 | MMCV_MAX = '1.5.0'
15 |
16 |
17 | def digit_version(version_str: str, length: int = 4):
18 |     """Convert a version string into a tuple of integers.
19 |
20 |     This method is usually used for comparing two versions. For pre-release
21 |     versions: alpha < beta < rc.
22 |
23 |     Args:
24 |         version_str (str): The version string.
25 |         length (int): The maximum number of version levels. Default: 4.
26 |
27 |     Returns:
28 |         tuple[int]: The version info in digits (integers).
29 | """ 30 | version = parse(version_str) 31 | assert version.release, f'failed to parse version {version_str}' 32 | release = list(version.release) 33 | release = release[:length] 34 | if len(release) < length: 35 | release = release + [0] * (length - len(release)) 36 | if version.is_prerelease: 37 | mapping = {'a': -3, 'b': -2, 'rc': -1} 38 | val = -4 39 | # version.pre can be None 40 | if version.pre: 41 | if version.pre[0] not in mapping: 42 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 43 | 'version checking may go wrong') 44 | else: 45 | val = mapping[version.pre[0]] 46 | release.extend([val, version.pre[-1]]) 47 | else: 48 | release.extend([val, 0]) 49 | 50 | elif version.is_postrelease: 51 | release.extend([1, version.post]) 52 | else: 53 | release.extend([0, 0]) 54 | return tuple(release) 55 | 56 | 57 | mmcv_min_version = digit_version(MMCV_MIN) 58 | mmcv_max_version = digit_version(MMCV_MAX) 59 | mmcv_version = digit_version(mmcv.__version__) 60 | 61 | 62 | assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 63 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 64 | f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' 65 | 66 | __all__ = ['__version__', 'version_info', 'digit_version'] 67 | -------------------------------------------------------------------------------- /configs/_base_/datasets/coco_stuff10k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | ''' 3 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 4 | This file is modified from: 5 | https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/datasets/coco_stuff10k.py 6 | ''' 7 | dataset_type = 'COCOStuffDataset' 8 | data_root = 'data/coco_stuff10k' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (512, 512) 12 | train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations', reduce_zero_label=True), 15 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | dict(type='PhotoMetricDistortion'), 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | test_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | dict( 27 | type='MultiScaleFlipAug', 28 | img_scale=(2048, 512), 29 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 30 | flip=False, 31 | transforms=[ 32 | dict(type='Resize', keep_ratio=True), 33 | dict(type='RandomFlip'), 34 | dict(type='Normalize', **img_norm_cfg), 35 | dict(type='ImageToTensor', keys=['img']), 36 | dict(type='Collect', keys=['img']), 37 | ]) 38 | ] 39 | data = dict( 40 | samples_per_gpu=4, 41 | workers_per_gpu=4, 42 | train=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | reduce_zero_label=True, 46 | img_dir='images/train2014', 47 | ann_dir='annotations/train2014', 48 | pipeline=train_pipeline), 49 | val=dict( 50 | type=dataset_type, 51 | data_root=data_root, 52 | reduce_zero_label=True, 53 | img_dir='images/test2014', 54 | ann_dir='annotations/test2014', 55 | pipeline=test_pipeline), 56 | test=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | reduce_zero_label=True, 60 | img_dir='images/test2014', 61 | ann_dir='annotations/test2014', 62 | 
pipeline=test_pipeline))
-------------------------------------------------------------------------------- /mmseg/utils/set_env.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/utils/set_env.py
5 | '''
6 | import os
7 | import platform
8 |
9 | import cv2
10 | import torch.multiprocessing as mp
11 |
12 | from ..utils import get_root_logger
13 |
14 |
15 | def setup_multi_processes(cfg):
16 |     """Setup multi-processing environment variables."""
17 |     logger = get_root_logger()
18 |
19 |     # set multi-process start method
20 |     if platform.system() != 'Windows':
21 |         mp_start_method = cfg.get('mp_start_method', None)
22 |         current_method = mp.get_start_method(allow_none=True)
23 |         if mp_start_method in ('fork', 'spawn', 'forkserver'):
24 |             logger.info(
25 |                 f'Multi-processing start method `{mp_start_method}` is '
26 |                 f'different from the previous setting `{current_method}`. '
27 |                 f'It will be force set to `{mp_start_method}`.')
28 |             mp.set_start_method(mp_start_method, force=True)
29 |         else:
30 |             logger.info(
31 |                 f'Multi-processing start method is `{mp_start_method}`')
32 |
33 |     # disable opencv multithreading to avoid system being overloaded
34 |     opencv_num_threads = cfg.get('opencv_num_threads', None)
35 |     if isinstance(opencv_num_threads, int):
36 |         logger.info(f'OpenCV num_threads is `{opencv_num_threads}`')
37 |         cv2.setNumThreads(opencv_num_threads)
38 |     else:
39 |         logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}`')
40 |
41 |     if cfg.data.workers_per_gpu > 1:
42 |         # setup OMP threads
43 |         # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
44 |         omp_num_threads = cfg.get('omp_num_threads', None)
45 |         if 'OMP_NUM_THREADS' not in os.environ:
46 |             if isinstance(omp_num_threads, int):
47 |                 logger.info(f'OMP num threads is {omp_num_threads}')
48 |                 os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
49 |         else:
50 |             logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"]}')
51 |
52 |         # setup MKL threads
53 |         if 'MKL_NUM_THREADS' not in os.environ:
54 |             mkl_num_threads = cfg.get('mkl_num_threads', None)
55 |             if isinstance(mkl_num_threads, int):
56 |                 logger.info(f'MKL num threads is {mkl_num_threads}')
57 |                 os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
58 |         else:
59 |             logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}')
60 |
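61 | # Usage sketch (added for illustration): typically invoked once at startup,
62 | # right after the config is loaded in tools/train.py:
63 | #   setup_multi_processes(cfg)
64 |
-------------------------------------------------------------------------------- /configs/_base_/datasets/ade20k_std.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.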
3 | This file is modified from: 4 | https://github.com/Gsunshine/Enjoy-Hamburger/blob/main/seg_light_ham/configs/_base_/datasets/ade20k_std.py 5 | ''' 6 | # dataset settings 7 | dataset_type = 'ADE20KDataset' 8 | data_root = 'data/ade/ADEChallengeData2016' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (512, 512) 12 | train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations', reduce_zero_label=True), 15 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | dict(type='PhotoMetricDistortion'), 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | val_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | dict( 27 | type='MultiScaleFlipAug', 28 | img_scale=(2048, 512), 29 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 30 | flip=False, 31 | transforms=[ 32 | dict(type='Resize', keep_ratio=True), 33 | dict(type='RandomFlip'), 34 | dict(type='Normalize', **img_norm_cfg), 35 | dict(type='ImageToTensor', keys=['img']), 36 | dict(type='Collect', keys=['img']), 37 | ]) 38 | ] 39 | test_pipeline = [ 40 | dict(type='LoadImageFromFile'), 41 | dict( 42 | type='MultiScaleFlipAug', 43 | img_scale=(2048, 512), 44 | img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 45 | flip=False, 46 | transforms=[ 47 | dict(type='Resize', keep_ratio=True), 48 | dict(type='RandomFlip'), 49 | dict(type='Normalize', **img_norm_cfg), 50 | dict(type='ImageToTensor', keys=['img']), 51 | dict(type='Collect', keys=['img']), 52 | ]) 53 | ] 54 | data = dict( 55 | samples_per_gpu=4, 56 | workers_per_gpu=4, 57 | train=dict( 58 | type='RepeatDataset', 59 | times=50, 60 | dataset=dict( 61 | type=dataset_type, 62 | data_root=data_root, 63 | img_dir='images/training', 64 | ann_dir='annotations/training', 65 | pipeline=train_pipeline)), 66 | val=dict( 67 | type=dataset_type, 68 | data_root=data_root, 69 | img_dir='images/validation', 70 | ann_dir='annotations/validation', 71 | pipeline=val_pipeline), 72 | test=dict( 73 | type=dataset_type, 74 | data_root=data_root, 75 | img_dir='images/validation', 76 | ann_dir='annotations/validation', 77 | pipeline=test_pipeline)) 78 | -------------------------------------------------------------------------------- /configs/_base_/datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | ''' 3 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 
4 | This file is modified from: 5 | https://github.com/Gsunshine/Enjoy-Hamburger/blob/main/seg_light_ham/configs/_base_/datasets/ade20k.py 6 | ''' 7 | dataset_type = 'ADE20KDataset' 8 | data_root = 'data/ade/ADEChallengeData2016' 9 | img_norm_cfg = dict( 10 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 11 | crop_size = (512, 512) 12 | train_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='LoadAnnotations', reduce_zero_label=True), 15 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | dict(type='PhotoMetricDistortion'), 19 | dict(type='Normalize', **img_norm_cfg), 20 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | val_pipeline = [ 25 | dict(type='LoadImageFromFile'), 26 | dict( 27 | type='MultiScaleFlipAug', 28 | img_scale=(2048, 512), 29 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 30 | flip=False, 31 | transforms=[ 32 | dict(type='Resize', keep_ratio=True), 33 | dict(type='RandomFlip'), 34 | dict(type='Normalize', **img_norm_cfg), 35 | dict(type='ImageToTensor', keys=['img']), 36 | dict(type='Collect', keys=['img']), 37 | ]) 38 | ] 39 | test_pipeline = [ 40 | dict(type='LoadImageFromFile'), 41 | dict( 42 | type='MultiScaleFlipAug', 43 | img_scale=(2048, 512), 44 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 45 | flip=False, 46 | transforms=[ 47 | # dict(type='AlignResize', keep_ratio=True, size_divisor=32), 48 | dict(type='Resize', keep_ratio=True), 49 | dict(type='ResizeToMultiple', size_divisor=32), 50 | dict(type='RandomFlip'), 51 | dict(type='Normalize', **img_norm_cfg), 52 | dict(type='ImageToTensor', keys=['img']), 53 | dict(type='Collect', keys=['img']), 54 | ]) 55 | ] 56 | data = dict( 57 | samples_per_gpu=4, 58 | workers_per_gpu=4, 59 | train=dict( 60 | type='RepeatDataset', 61 | times=50, 62 | dataset=dict( 63 | type=dataset_type, 64 | data_root=data_root, 65 | img_dir='images/training', 66 | ann_dir='annotations/training', 67 | pipeline=train_pipeline)), 68 | val=dict( 69 | type=dataset_type, 70 | data_root=data_root, 71 | img_dir='images/validation', 72 | ann_dir='annotations/validation', 73 | pipeline=val_pipeline), 74 | test=dict( 75 | type=dataset_type, 76 | data_root=data_root, 77 | img_dir='images/validation', 78 | ann_dir='annotations/validation', 79 | pipeline=test_pipeline)) 80 | -------------------------------------------------------------------------------- /mmseg/ops/encoding.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/ops/encoding.py 5 | ''' 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | 11 | class Encoding(nn.Module): 12 | """Encoding Layer: a learnable residual encoder. 13 | 14 | Input is of shape (batch_size, channels, height, width). 15 | Output is of shape (batch_size, num_codes, channels). 
16 |
17 |     Args:
18 |         channels: dimension of the features or feature channels
19 |         num_codes: number of code words
20 |     """
21 |
22 |     def __init__(self, channels, num_codes):
23 |         super(Encoding, self).__init__()
24 |         # init codewords and smoothing factor
25 |         self.channels, self.num_codes = channels, num_codes
26 |         std = 1. / ((num_codes * channels)**0.5)
27 |         # [num_codes, channels]
28 |         self.codewords = nn.Parameter(
29 |             torch.empty(num_codes, channels,
30 |                         dtype=torch.float).uniform_(-std, std),
31 |             requires_grad=True)
32 |         # [num_codes]
33 |         self.scale = nn.Parameter(
34 |             torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
35 |             requires_grad=True)
36 |
37 |     @staticmethod
38 |     def scaled_l2(x, codewords, scale):
39 |         num_codes, channels = codewords.size()
40 |         batch_size = x.size(0)
41 |         reshaped_scale = scale.view((1, 1, num_codes))
42 |         expanded_x = x.unsqueeze(2).expand(
43 |             (batch_size, x.size(1), num_codes, channels))
44 |         reshaped_codewords = codewords.view((1, 1, num_codes, channels))
45 |
46 |         scaled_l2_norm = reshaped_scale * (
47 |             expanded_x - reshaped_codewords).pow(2).sum(dim=3)
48 |         return scaled_l2_norm
49 |
50 |     @staticmethod
51 |     def aggregate(assignment_weights, x, codewords):
52 |         num_codes, channels = codewords.size()
53 |         reshaped_codewords = codewords.view((1, 1, num_codes, channels))
54 |         batch_size = x.size(0)
55 |
56 |         expanded_x = x.unsqueeze(2).expand(
57 |             (batch_size, x.size(1), num_codes, channels))
58 |         encoded_feat = (assignment_weights.unsqueeze(3) *
59 |                         (expanded_x - reshaped_codewords)).sum(dim=1)
60 |         return encoded_feat
61 |
62 |     def forward(self, x):
63 |         assert x.dim() == 4 and x.size(1) == self.channels
64 |         # [batch_size, channels, height, width]
65 |         batch_size = x.size(0)
66 |         # [batch_size, height x width, channels]
67 |         x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
68 |         # assignment_weights: [batch_size, height x width, num_codes]
69 |         assignment_weights = F.softmax(
70 |             self.scaled_l2(x, self.codewords, self.scale), dim=2)
71 |         # aggregate
72 |         encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
73 |         return encoded_feat
74 |
75 |     def __repr__(self):
76 |         repr_str = self.__class__.__name__
77 |         repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
78 |                     f'x{self.channels})'
79 |         return repr_str
80 |
-------------------------------------------------------------------------------- /mmseg/models/segmentors/cascade_encoder_decoder.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/segmentors/cascade_encoder_decoder.py
5 | '''
6 | from torch import nn
7 |
8 | from mmseg.core import add_prefix
9 | from mmseg.ops import resize
10 | from .. import builder
11 | from ..builder import SEGMENTORS
12 | from .encoder_decoder import EncoderDecoder
13 |
14 |
15 | @SEGMENTORS.register_module()
16 | class CascadeEncoderDecoder(EncoderDecoder):
17 |     """Cascade Encoder Decoder segmentors.
18 |
19 |     CascadeEncoderDecoder is almost the same as EncoderDecoder, except that its
20 |     decode heads are cascaded: the output of the previous decode_head
21 |     is the input of the next decode_head.
22 | """ 23 | 24 | def __init__(self, 25 | num_stages, 26 | backbone, 27 | decode_head, 28 | neck=None, 29 | auxiliary_head=None, 30 | train_cfg=None, 31 | test_cfg=None, 32 | pretrained=None, 33 | init_cfg=None): 34 | self.num_stages = num_stages 35 | super(CascadeEncoderDecoder, self).__init__( 36 | backbone=backbone, 37 | decode_head=decode_head, 38 | neck=neck, 39 | auxiliary_head=auxiliary_head, 40 | train_cfg=train_cfg, 41 | test_cfg=test_cfg, 42 | pretrained=pretrained, 43 | init_cfg=init_cfg) 44 | 45 | def _init_decode_head(self, decode_head): 46 | """Initialize ``decode_head``""" 47 | assert isinstance(decode_head, list) 48 | assert len(decode_head) == self.num_stages 49 | self.decode_head = nn.ModuleList() 50 | for i in range(self.num_stages): 51 | self.decode_head.append(builder.build_head(decode_head[i])) 52 | self.align_corners = self.decode_head[-1].align_corners 53 | self.num_classes = self.decode_head[-1].num_classes 54 | 55 | def encode_decode(self, img, img_metas): 56 | """Encode images with backbone and decode into a semantic segmentation 57 | map of the same size as input.""" 58 | x = self.extract_feat(img) 59 | out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg) 60 | for i in range(1, self.num_stages): 61 | out = self.decode_head[i].forward_test(x, out, img_metas, 62 | self.test_cfg) 63 | out = resize( 64 | input=out, 65 | size=img.shape[2:], 66 | mode='bilinear', 67 | align_corners=self.align_corners) 68 | return out 69 | 70 | def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): 71 | """Run forward function and calculate loss for decode head in 72 | training.""" 73 | losses = dict() 74 | 75 | loss_decode = self.decode_head[0].forward_train( 76 | x, img_metas, gt_semantic_seg, self.train_cfg) 77 | 78 | losses.update(add_prefix(loss_decode, 'decode_0')) 79 | 80 | for i in range(1, self.num_stages): 81 | # forward test again, maybe unnecessary for most methods. 82 | prev_outputs = self.decode_head[i - 1].forward_test( 83 | x, img_metas, self.test_cfg) 84 | loss_decode = self.decode_head[i].forward_train( 85 | x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) 86 | losses.update(add_prefix(loss_decode, f'decode_{i}')) 87 | 88 | return losses 89 | -------------------------------------------------------------------------------- /mmseg/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/model/losses/accuracy.py 5 | ''' 6 | import torch.nn as nn 7 | 8 | 9 | def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): 10 | """Calculate accuracy according to the prediction and target. 11 | 12 | Args: 13 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 14 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 15 | ignore_index (int | None): The label index to be ignored. Default: None 16 | topk (int | tuple[int], optional): If the predictions in ``topk`` 17 | matches the target, the predictions will be regarded as 18 | correct ones. Defaults to 1. 19 | thresh (float, optional): If not None, predictions with scores under 20 | this threshold are considered incorrect. Default to None. 21 | 22 | Returns: 23 | float | tuple[float]: If the input ``topk`` is a single integer, 24 | the function will return a single float as accuracy. 
28 |     """
29 |     assert isinstance(topk, (int, tuple))
30 |     if isinstance(topk, int):
31 |         topk = (topk, )
32 |         return_single = True
33 |     else:
34 |         return_single = False
35 |
36 |     maxk = max(topk)
37 |     if pred.size(0) == 0:
38 |         accu = [pred.new_tensor(0.) for i in range(len(topk))]
39 |         return accu[0] if return_single else accu
40 |     assert pred.ndim == target.ndim + 1
41 |     assert pred.size(0) == target.size(0)
42 |     assert maxk <= pred.size(1), \
43 |         f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
44 |     pred_value, pred_label = pred.topk(maxk, dim=1)
45 |     # transpose to shape (maxk, N, ...)
46 |     pred_label = pred_label.transpose(0, 1)
47 |     correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
48 |     if thresh is not None:
49 |         # Only prediction values larger than thresh are counted as correct
50 |         correct = correct & (pred_value > thresh).t()
51 |     # guard against the default ignore_index=None, for which boolean
52 |     # indexing with ``target != None`` would be ill-defined
53 |     if ignore_index is not None:
54 |         correct = correct[:, target != ignore_index]
55 |         total_num = target[target != ignore_index].numel()
56 |     else:
57 |         total_num = target.numel()
58 |     res = []
59 |     for k in topk:
60 |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
61 |         res.append(correct_k.mul_(100.0 / total_num))
62 |     return res[0] if return_single else res
63 |
64 |
65 | class Accuracy(nn.Module):
66 |     """Accuracy calculation module."""
67 |
68 |     def __init__(self, topk=(1, ), thresh=None, ignore_index=None):
69 |         """Module to calculate the accuracy.
70 |
71 |         Args:
72 |             topk (tuple, optional): The criterion used to calculate the
73 |                 accuracy. Defaults to (1,).
74 |             thresh (float, optional): If not None, predictions with scores
75 |                 under this threshold are considered incorrect. Defaults to
76 |                 None.
77 |             ignore_index (int | None): The label index to be ignored.
78 |                 Defaults to None.
79 |         """
80 |         super().__init__()
81 |         self.topk = topk
82 |         self.thresh = thresh
83 |         self.ignore_index = ignore_index
84 |
85 |     def forward(self, pred, target):
86 |         """Forward function to calculate accuracy.
87 |
88 |         Args:
89 |             pred (torch.Tensor): Prediction of models.
90 |             target (torch.Tensor): Target for each prediction.
91 |
92 |         Returns:
93 |             tuple[float]: The accuracies under different topk criterions.
94 |         """
95 |         return accuracy(pred, target, self.topk, self.thresh,
96 |                         self.ignore_index)
97 |
-------------------------------------------------------------------------------- /mmseg/models/utils/res_layer.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | '''
4 | from mmcv.cnn import build_conv_layer, build_norm_layer
5 | from mmcv.runner import Sequential
6 | from torch import nn as nn
7 |
8 |
9 | class ResLayer(Sequential):
10 |     """ResLayer to build ResNet style backbone.
11 |
12 |     Args:
13 |         block (nn.Module): block used to build ResLayer.
14 |         inplanes (int): inplanes of block.
15 |         planes (int): planes of block.
16 |         num_blocks (int): number of blocks.
17 |         stride (int): stride of the first block. Default: 1
18 |         avg_down (bool): Use AvgPool instead of stride conv when
19 |             downsampling in the bottleneck. Default: False
20 |         conv_cfg (dict): dictionary to construct and config conv layer.
21 |             Default: None
22 |         norm_cfg (dict): dictionary to construct and config norm layer.
23 |             Default: dict(type='BN')
24 |         multi_grid (int | None): Multi grid dilation rates of last
25 |             stage.
Default: None 26 | contract_dilation (bool): Whether contract first dilation of each layer 27 | Default: False 28 | """ 29 | 30 | def __init__(self, 31 | block, 32 | inplanes, 33 | planes, 34 | num_blocks, 35 | stride=1, 36 | dilation=1, 37 | avg_down=False, 38 | conv_cfg=None, 39 | norm_cfg=dict(type='BN'), 40 | multi_grid=None, 41 | contract_dilation=False, 42 | **kwargs): 43 | self.block = block 44 | 45 | downsample = None 46 | if stride != 1 or inplanes != planes * block.expansion: 47 | downsample = [] 48 | conv_stride = stride 49 | if avg_down: 50 | conv_stride = 1 51 | downsample.append( 52 | nn.AvgPool2d( 53 | kernel_size=stride, 54 | stride=stride, 55 | ceil_mode=True, 56 | count_include_pad=False)) 57 | downsample.extend([ 58 | build_conv_layer( 59 | conv_cfg, 60 | inplanes, 61 | planes * block.expansion, 62 | kernel_size=1, 63 | stride=conv_stride, 64 | bias=False), 65 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 66 | ]) 67 | downsample = nn.Sequential(*downsample) 68 | 69 | layers = [] 70 | if multi_grid is None: 71 | if dilation > 1 and contract_dilation: 72 | first_dilation = dilation // 2 73 | else: 74 | first_dilation = dilation 75 | else: 76 | first_dilation = multi_grid[0] 77 | layers.append( 78 | block( 79 | inplanes=inplanes, 80 | planes=planes, 81 | stride=stride, 82 | dilation=first_dilation, 83 | downsample=downsample, 84 | conv_cfg=conv_cfg, 85 | norm_cfg=norm_cfg, 86 | **kwargs)) 87 | inplanes = planes * block.expansion 88 | for i in range(1, num_blocks): 89 | layers.append( 90 | block( 91 | inplanes=inplanes, 92 | planes=planes, 93 | stride=1, 94 | dilation=dilation if multi_grid is None else multi_grid[i], 95 | conv_cfg=conv_cfg, 96 | norm_cfg=norm_cfg, 97 | **kwargs)) 98 | super(ResLayer, self).__init__(*layers) 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Head-Free Lightweight Semantic Segmentation with Linear Transformer 2 | 3 | This repository contains the official Pytorch implementation of training & evaluation code and the pretrained models for [AFFormer](https://arxiv.org/pdf/2301.04648.pdf).🔥🔥 4 | 5 | 6 | 7 |
8 | ![Performance of AFFormer](docs/figure1.png)
9 |
10 |
11 | Figure 1: Performance of AFFormer.
12 |
13 |
14 | AFFormer is a head-free, lightweight and powerful semantic segmentation method, as shown in Figure 1.
15 |
16 | We use [MMSegmentation v0.21.1](https://github.com/open-mmlab/mmsegmentation/tree/v0.21.1) as the codebase.
17 |
18 |
19 |
20 | ## Installation
21 |
22 | For installation and data preparation, please refer to the guidelines in [MMSegmentation v0.21.1](https://github.com/open-mmlab/mmsegmentation/tree/v0.21.1).
23 |
24 | An example environment that works: ```CUDA 11.3``` and ```PyTorch 1.10.1```
25 |
26 | ```
27 | pip install mmcv-full==1.5.0
28 | pip install torchvision
29 | pip install timm
30 | pip install opencv-python
31 | pip install einops
32 | ```
33 |
34 | ## Evaluation
35 |
36 | Download `weights`
37 | (
38 | [google drive](https://drive.google.com/drive/folders/1Mru24qPdta9o8aLn1RwT8EapiQCih1Sw?usp=share_link) |
39 | [alidrive](https://www.aliyundrive.com/s/Ha2xMsG9ufy)
40 | )
41 |
42 | Example: evaluate ```AFFormer-base``` on ```ADE20K```:
43 |
44 | ```
45 | # Single-gpu testing
46 | bash tools/dist_test.sh ./configs/AFFormer/AFFormer_base_ade20k.py /path/to/checkpoint_file.pth 1 --eval mIoU
47 |
48 | # Multi-gpu testing
49 | bash tools/dist_test.sh ./configs/AFFormer/AFFormer_base_ade20k.py /path/to/checkpoint_file.pth <GPU_NUM> --eval mIoU
50 |
51 | # Multi-gpu, multi-scale testing
52 | bash tools/dist_test.sh ./configs/AFFormer/AFFormer_base_ade20k.py /path/to/checkpoint_file.pth <GPU_NUM> --eval mIoU --aug-test
53 | ```
54 |
55 | ## Training
56 |
57 | Download `weights`
58 | (
59 | [google drive](https://drive.google.com/drive/folders/1Mru24qPdta9o8aLn1RwT8EapiQCih1Sw?usp=share_link) |
60 | [alidrive](https://www.aliyundrive.com/s/Ha2xMsG9ufy)
61 | )
62 | pretrained on ImageNet-1K (refer to [deit](https://github.com/facebookresearch/deit)), and put them in a folder ```pretrained/```.
63 |
64 | Example: train ```AFFormer-base``` on ```ADE20K```:
65 |
66 | ```
67 | # Single-gpu training
68 | bash tools/dist_train.sh ./configs/AFFormer/AFFormer_base_ade20k.py 1
69 |
70 | # Multi-gpu training
71 | bash tools/dist_train.sh ./configs/AFFormer/AFFormer_base_ade20k.py <GPU_NUM>
72 | ```
73 |
74 | ## Visualize
75 |
76 | Here is a demo script to test a single image. For more details, refer to [MMSegmentation's Doc](https://mmsegmentation.readthedocs.io/en/latest/get_started.html).
77 |
78 | ```shell
79 | python demo/image_demo.py ${IMAGE_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${DEVICE_NAME}] [--palette ${PALETTE}]
80 | ```
81 |
82 | Example: visualize ```AFFormer-base``` on ```Cityscapes```:
83 |
84 | ```shell
85 | python demo/image_demo.py demo/demo.png ./configs/AFFormer/AFFormer_base_cityscapes.py \
86 |     /path/to/checkpoint_file --device cuda:0 --palette cityscapes
87 | ```
88 |
89 | ## License
90 |
91 | The code is released under the MIT license.
92 |
93 | ## Copyright
94 |
95 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
96 |
97 | ## Citation
98 |
99 | If you find this work helpful to your research, please consider citing the paper:
100 |
101 | ```bibtex
102 | @inproceedings{dong2023afformer,
103 |   title={AFFormer: Head-Free Lightweight Semantic Segmentation with Linear Transformer},
104 |   author={Dong, Bo and Wang, Pichao and Wang, Fan},
105 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
106 |   pages={},
107 |   year={2023}
108 | }
109 | ```
110 |
-------------------------------------------------------------------------------- /mmseg/core/seg/sampler/ohem_pixel_sampler.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/seg/sampler/ohem_pixel_sampler.py
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from ..builder import PIXEL_SAMPLERS
11 | from .base_pixel_sampler import BasePixelSampler
12 |
13 |
14 | @PIXEL_SAMPLERS.register_module()
15 | class OHEMPixelSampler(BasePixelSampler):
16 |     """Online Hard Example Mining Sampler for segmentation.
17 |
18 |     Args:
19 |         context (nn.Module): The context of sampler, subclass of
20 |             :obj:`BaseDecodeHead`.
21 |         thresh (float, optional): The threshold for hard example selection.
22 |             Predictions below it are regarded as low-confidence. If not
23 |             specified, the hard examples are the pixels with the top
24 |             ``min_kept`` losses. Default: None.
25 |         min_kept (int, optional): The minimum number of predictions to keep.
26 |             Default: 100000.
27 |     """
28 |
29 |     def __init__(self, context, thresh=None, min_kept=100000):
30 |         super(OHEMPixelSampler, self).__init__()
31 |         self.context = context
32 |         assert min_kept > 1
33 |         self.thresh = thresh
34 |         self.min_kept = min_kept
35 |
36 |     def sample(self, seg_logit, seg_label):
37 |         """Sample pixels that have high loss or with low prediction confidence.
38 |
39 |         Args:
40 |             seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
41 |             seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
42 |
43 |         Returns:
44 |             torch.Tensor: segmentation weight, shape (N, H, W)
45 |         """
46 |         with torch.no_grad():
47 |             assert seg_logit.shape[2:] == seg_label.shape[2:]
48 |             assert seg_label.shape[1] == 1
49 |             seg_label = seg_label.squeeze(1).long()
50 |             batch_kept = self.min_kept * seg_label.size(0)
51 |             valid_mask = seg_label != self.context.ignore_index
52 |             seg_weight = seg_logit.new_zeros(size=seg_label.size())
53 |             valid_seg_weight = seg_weight[valid_mask]
54 |             if self.thresh is not None:
55 |                 seg_prob = F.softmax(seg_logit, dim=1)
56 |
57 |                 tmp_seg_label = seg_label.clone().unsqueeze(1)
58 |                 tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
59 |                 seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
60 |                 sort_prob, sort_indices = seg_prob[valid_mask].sort()
61 |
62 |                 if sort_prob.numel() > 0:
63 |                     min_threshold = sort_prob[min(batch_kept,
64 |                                                   sort_prob.numel() - 1)]
65 |                 else:
66 |                     min_threshold = 0.0
67 |                 threshold = max(min_threshold, self.thresh)
68 |                 valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
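                # Taking max(min_threshold, self.thresh) above guarantees that
                # roughly ``batch_kept`` low-confidence pixels are selected
                # even when ``self.thresh`` alone would select fewer.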
69 | else: 70 | if not isinstance(self.context.loss_decode, nn.ModuleList): 71 | losses_decode = [self.context.loss_decode] 72 | else: 73 | losses_decode = self.context.loss_decode 74 | losses = 0.0 75 | for loss_module in losses_decode: 76 | losses += loss_module( 77 | seg_logit, 78 | seg_label, 79 | weight=None, 80 | ignore_index=self.context.ignore_index, 81 | reduction_override='none') 82 | 83 | # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa 84 | _, sort_indices = losses[valid_mask].sort(descending=True) 85 | valid_seg_weight[sort_indices[:batch_kept]] = 1. 86 | 87 | seg_weight[valid_mask] = valid_seg_weight 88 | 89 | return seg_weight 90 | -------------------------------------------------------------------------------- /tools/get_flops_fps.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | ''' 4 | import argparse 5 | 6 | from mmcv import Config 7 | from mmcv.cnn import get_model_complexity_info 8 | 9 | from mmseg.models import build_segmentor 10 | 11 | import math 12 | import torch 13 | import numpy as np 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.nn import init 17 | from collections import OrderedDict 18 | import warnings 19 | import afformer 20 | warnings.filterwarnings('ignore') 21 | 22 | def fps_params_flops(model, size): 23 | import time 24 | device = torch.device('cuda') 25 | model.eval() 26 | model.to(device) 27 | iterations = None 28 | 29 | input = torch.randn(size).to(device) 30 | with torch.no_grad(): 31 | for _ in range(10): 32 | model(input) 33 | 34 | if iterations is None: 35 | elapsed_time = 0 36 | iterations = 100 37 | while elapsed_time < 1: 38 | torch.cuda.synchronize() 39 | torch.cuda.synchronize() 40 | t_start = time.time() 41 | for _ in range(iterations): 42 | model(input) 43 | torch.cuda.synchronize() 44 | torch.cuda.synchronize() 45 | elapsed_time = time.time() - t_start 46 | iterations *= 2 47 | FPS = iterations / elapsed_time 48 | iterations = int(FPS * 6) 49 | 50 | print('=========Speed Testing=========') 51 | torch.cuda.synchronize() 52 | t_start = time.time() 53 | for _ in range(iterations): 54 | model(input) 55 | torch.cuda.synchronize() 56 | elapsed_time = time.time() - t_start 57 | latency = elapsed_time / iterations * 1000 58 | torch.cuda.empty_cache() 59 | FPS = 1000 / latency 60 | print(FPS, ">>>res. 
", size) 61 | 62 | from fvcore.nn import FlopCountAnalysis, ActivationCountAnalysis 63 | 64 | flops = FlopCountAnalysis(model, input) 65 | param = sum(p.numel() for p in model.parameters() if p.requires_grad) 66 | acts = ActivationCountAnalysis(model, input) 67 | 68 | print(f"total flops : {flops.total()}") 69 | print(f"total activations: {acts.total()}") 70 | print(f"number of parameter: {param}") 71 | 72 | def parse_args(): 73 | parser = argparse.ArgumentParser(description='Train a segmentor') 74 | parser.add_argument('config', help='train config file path') 75 | parser.add_argument( 76 | '--shape', 77 | type=int, 78 | nargs='+', 79 | default=[512, 512], 80 | help='input image size') 81 | args = parser.parse_args() 82 | return args 83 | 84 | 85 | def main(): 86 | 87 | args = parse_args() 88 | 89 | if len(args.shape) == 1: 90 | input_shape = (3, args.shape[0], args.shape[0]) 91 | elif len(args.shape) == 2: 92 | input_shape = (3, ) + tuple(args.shape) 93 | else: 94 | raise ValueError('invalid input shape') 95 | 96 | cfg = Config.fromfile(args.config) 97 | cfg.model.pretrained = None 98 | model = build_segmentor( 99 | cfg.model, 100 | train_cfg=cfg.get('train_cfg'), 101 | test_cfg=cfg.get('test_cfg')).cuda() 102 | model.eval() 103 | 104 | if hasattr(model, 'forward_dummy'): 105 | model.forward = model.forward_dummy 106 | else: 107 | raise NotImplementedError( 108 | 'FLOPs counter is currently not currently supported with {}'. 109 | format(model.__class__.__name__)) 110 | 111 | flops, params = get_model_complexity_info(model, input_shape) 112 | split_line = '=' * 30 113 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( 114 | split_line, input_shape, flops, params)) 115 | print('!!!Please be cautious if you use the results in papers. ' 116 | 'You may need to check if all ops are supported and verify that the ' 117 | 'flops computation is correct.') 118 | 119 | 120 | 121 | sizes = [[1,3,1024, 2048]] 122 | for size in sizes: 123 | fps_params_flops(model, size) 124 | 125 | if __name__ == '__main__': 126 | main() 127 | -------------------------------------------------------------------------------- /mmseg/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/model/losses/utils.py 5 | ''' 6 | import functools 7 | 8 | import mmcv 9 | import numpy as np 10 | import torch.nn.functional as F 11 | 12 | 13 | def get_class_weight(class_weight): 14 | """Get class weight for loss function. 15 | 16 | Args: 17 | class_weight (list[float] | str | None): If class_weight is a str, 18 | take it as a file name and read from it. 19 | """ 20 | if isinstance(class_weight, str): 21 | # take it as a file path 22 | if class_weight.endswith('.npy'): 23 | class_weight = np.load(class_weight) 24 | else: 25 | # pkl, json or yaml 26 | class_weight = mmcv.load(class_weight) 27 | 28 | return class_weight 29 | 30 | 31 | def reduce_loss(loss, reduction): 32 | """Reduce loss as specified. 33 | 34 | Args: 35 | loss (Tensor): Elementwise loss tensor. 36 | reduction (str): Options are "none", "mean" and "sum". 37 | 38 | Return: 39 | Tensor: Reduced loss tensor. 
40 | """ 41 | reduction_enum = F._Reduction.get_enum(reduction) 42 | # none: 0, elementwise_mean:1, sum: 2 43 | if reduction_enum == 0: 44 | return loss 45 | elif reduction_enum == 1: 46 | return loss.mean() 47 | elif reduction_enum == 2: 48 | return loss.sum() 49 | 50 | 51 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 52 | """Apply element-wise weight and reduce loss. 53 | 54 | Args: 55 | loss (Tensor): Element-wise loss. 56 | weight (Tensor): Element-wise weights. 57 | reduction (str): Same as built-in losses of PyTorch. 58 | avg_factor (float): Average factor when computing the mean of losses. 59 | 60 | Returns: 61 | Tensor: Processed loss values. 62 | """ 63 | # if weight is specified, apply element-wise weight 64 | if weight is not None: 65 | assert weight.dim() == loss.dim() 66 | if weight.dim() > 1: 67 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 68 | loss = loss * weight 69 | 70 | # if avg_factor is not specified, just reduce the loss 71 | if avg_factor is None: 72 | loss = reduce_loss(loss, reduction) 73 | else: 74 | # if reduction is mean, then average the loss by avg_factor 75 | if reduction == 'mean': 76 | loss = loss.sum() / avg_factor 77 | # if reduction is 'none', then do nothing, otherwise raise an error 78 | elif reduction != 'none': 79 | raise ValueError('avg_factor can not be used with reduction="sum"') 80 | return loss 81 | 82 | 83 | def weighted_loss(loss_func): 84 | """Create a weighted version of a given loss function. 85 | 86 | To use this decorator, the loss function must have the signature like 87 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 88 | element-wise loss without any reduction. This decorator will add weight 89 | and reduction arguments to the function. The decorated function will have 90 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 91 | avg_factor=None, **kwargs)`. 92 | 93 | :Example: 94 | 95 | >>> import torch 96 | >>> @weighted_loss 97 | >>> def l1_loss(pred, target): 98 | >>> return (pred - target).abs() 99 | 100 | >>> pred = torch.Tensor([0, 2, 3]) 101 | >>> target = torch.Tensor([1, 1, 1]) 102 | >>> weight = torch.Tensor([1, 0, 1]) 103 | 104 | >>> l1_loss(pred, target) 105 | tensor(1.3333) 106 | >>> l1_loss(pred, target, weight) 107 | tensor(1.) 108 | >>> l1_loss(pred, target, reduction='none') 109 | tensor([1., 1., 2.]) 110 | >>> l1_loss(pred, target, weight, avg_factor=2) 111 | tensor(1.5000) 112 | """ 113 | 114 | @functools.wraps(loss_func) 115 | def wrapper(pred, 116 | target, 117 | weight=None, 118 | reduction='mean', 119 | avg_factor=None, 120 | **kwargs): 121 | # get element-wise loss 122 | loss = loss_func(pred, target, **kwargs) 123 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 124 | return loss 125 | 126 | return wrapper 127 | -------------------------------------------------------------------------------- /mmseg/apis/inference.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 
3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/apis/inference.py 5 | ''' 6 | import matplotlib.pyplot as plt 7 | import mmcv 8 | import torch 9 | from mmcv.parallel import collate, scatter 10 | from mmcv.runner import load_checkpoint 11 | 12 | from mmseg.datasets.pipelines import Compose 13 | from mmseg.models import build_segmentor 14 | 15 | 16 | def init_segmentor(config, checkpoint=None, device='cuda:0'): 17 | """Initialize a segmentor from config file. 18 | 19 | Args: 20 | config (str or :obj:`mmcv.Config`): Config file path or the config 21 | object. 22 | checkpoint (str, optional): Checkpoint path. If left as None, the model 23 | will not load any weights. 24 | device (str, optional) CPU/CUDA device option. Default 'cuda:0'. 25 | Use 'cpu' for loading model on CPU. 26 | Returns: 27 | nn.Module: The constructed segmentor. 28 | """ 29 | if isinstance(config, str): 30 | config = mmcv.Config.fromfile(config) 31 | elif not isinstance(config, mmcv.Config): 32 | raise TypeError('config must be a filename or Config object, ' 33 | 'but got {}'.format(type(config))) 34 | config.model.pretrained = None 35 | config.model.train_cfg = None 36 | model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) 37 | if checkpoint is not None: 38 | checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') 39 | model.CLASSES = checkpoint['meta']['CLASSES'] 40 | model.PALETTE = checkpoint['meta']['PALETTE'] 41 | model.cfg = config # save the config in the model for convenience 42 | model.to(device) 43 | model.eval() 44 | return model 45 | 46 | 47 | class LoadImage: 48 | """A simple pipeline to load image.""" 49 | 50 | def __call__(self, results): 51 | """Call function to load images into results. 52 | 53 | Args: 54 | results (dict): A result dict contains the file name 55 | of the image to be read. 56 | 57 | Returns: 58 | dict: ``results`` will be returned containing loaded image. 59 | """ 60 | 61 | if isinstance(results['img'], str): 62 | results['filename'] = results['img'] 63 | results['ori_filename'] = results['img'] 64 | else: 65 | results['filename'] = None 66 | results['ori_filename'] = None 67 | img = mmcv.imread(results['img']) 68 | results['img'] = img 69 | results['img_shape'] = img.shape 70 | results['ori_shape'] = img.shape 71 | return results 72 | 73 | 74 | def inference_segmentor(model, img): 75 | """Inference image(s) with the segmentor. 76 | 77 | Args: 78 | model (nn.Module): The loaded segmentor. 79 | imgs (str/ndarray or list[str/ndarray]): Either image files or loaded 80 | images. 81 | 82 | Returns: 83 | (list[Tensor]): The segmentation result. 
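    A usage sketch (the config, checkpoint and image paths are placeholders):

        >>> model = init_segmentor('config.py', 'checkpoint.pth',
        ...                        device='cuda:0')
        >>> result = inference_segmentor(model, 'demo.png')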
84 | """ 85 | cfg = model.cfg 86 | device = next(model.parameters()).device # model device 87 | # build the data pipeline 88 | test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] 89 | test_pipeline = Compose(test_pipeline) 90 | # prepare data 91 | data = dict(img=img) 92 | data = test_pipeline(data) 93 | data = collate([data], samples_per_gpu=1) 94 | if next(model.parameters()).is_cuda: 95 | # scatter to specified GPU 96 | data = scatter(data, [device])[0] 97 | else: 98 | data['img_metas'] = [i.data[0] for i in data['img_metas']] 99 | 100 | # forward the model 101 | with torch.no_grad(): 102 | result = model(return_loss=False, rescale=True, **data) 103 | return result 104 | 105 | 106 | def show_result_pyplot(model, 107 | img, 108 | result, 109 | palette=None, 110 | fig_size=(15, 10), 111 | opacity=0.5, 112 | title='', 113 | block=True): 114 | """Visualize the segmentation results on the image. 115 | 116 | Args: 117 | model (nn.Module): The loaded segmentor. 118 | img (str or np.ndarray): Image filename or loaded image. 119 | result (list): The segmentation result. 120 | palette (list[list[int]]] | None): The palette of segmentation 121 | map. If None is given, random palette will be generated. 122 | Default: None 123 | fig_size (tuple): Figure size of the pyplot figure. 124 | opacity(float): Opacity of painted segmentation map. 125 | Default 0.5. 126 | Must be in (0, 1] range. 127 | title (str): The title of pyplot figure. 128 | Default is ''. 129 | block (bool): Whether to block the pyplot figure. 130 | Default is True. 131 | """ 132 | if hasattr(model, 'module'): 133 | model = model.module 134 | img = model.show_result( 135 | img, result, palette=palette, show=False, opacity=opacity) 136 | plt.figure(figsize=fig_size) 137 | plt.imshow(mmcv.bgr2rgb(img)) 138 | plt.title(title) 139 | plt.tight_layout() 140 | plt.show(block=block) 141 | -------------------------------------------------------------------------------- /mmseg/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/evaluation/eval_hooks.py 5 | ''' 6 | import os.path as osp 7 | import warnings 8 | 9 | import torch.distributed as dist 10 | from mmcv.runner import DistEvalHook as _DistEvalHook 11 | from mmcv.runner import EvalHook as _EvalHook 12 | from torch.nn.modules.batchnorm import _BatchNorm 13 | 14 | 15 | class EvalHook(_EvalHook): 16 | """Single GPU EvalHook, with efficient test support. 17 | 18 | Args: 19 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 20 | If set to True, it will perform by epoch. Otherwise, by iteration. 21 | Default: False. 22 | efficient_test (bool): Whether save the results as local numpy files to 23 | save CPU memory during evaluation. Default: False. 24 | pre_eval (bool): Whether to use progressive mode to evaluate model. 25 | Default: False. 26 | Returns: 27 | list: The prediction results. 
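    A config snippet that enables this hook (the interval is illustrative):

        evaluation = dict(interval=4000, metric='mIoU', pre_eval=True)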
28 | """ 29 | 30 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 31 | 32 | def __init__(self, 33 | *args, 34 | by_epoch=False, 35 | efficient_test=False, 36 | pre_eval=False, 37 | **kwargs): 38 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 39 | self.pre_eval = pre_eval 40 | if efficient_test: 41 | warnings.warn( 42 | 'DeprecationWarning: ``efficient_test`` for evaluation hook ' 43 | 'is deprecated, the evaluation hook is CPU memory friendly ' 44 | 'with ``pre_eval=True`` as argument for ``single_gpu_test()`` ' 45 | 'function') 46 | 47 | def _do_evaluate(self, runner): 48 | """perform evaluation and save ckpt.""" 49 | if not self._should_evaluate(runner): 50 | return 51 | 52 | from mmseg.apis import single_gpu_test 53 | results = single_gpu_test( 54 | runner.model, self.dataloader, show=False, pre_eval=self.pre_eval) 55 | runner.log_buffer.clear() 56 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 57 | key_score = self.evaluate(runner, results) 58 | if self.save_best: 59 | self._save_ckpt(runner, key_score) 60 | 61 | 62 | class DistEvalHook(_DistEvalHook): 63 | """Distributed EvalHook, with efficient test support. 64 | 65 | Args: 66 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 67 | If set to True, it will perform by epoch. Otherwise, by iteration. 68 | Default: False. 69 | efficient_test (bool): Whether save the results as local numpy files to 70 | save CPU memory during evaluation. Default: False. 71 | pre_eval (bool): Whether to use progressive mode to evaluate model. 72 | Default: False. 73 | Returns: 74 | list: The prediction results. 75 | """ 76 | 77 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 78 | 79 | def __init__(self, 80 | *args, 81 | by_epoch=False, 82 | efficient_test=False, 83 | pre_eval=False, 84 | **kwargs): 85 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 86 | self.pre_eval = pre_eval 87 | if efficient_test: 88 | warnings.warn( 89 | 'DeprecationWarning: ``efficient_test`` for evaluation hook ' 90 | 'is deprecated, the evaluation hook is CPU memory friendly ' 91 | 'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` ' 92 | 'function') 93 | 94 | def _do_evaluate(self, runner): 95 | """perform evaluation and save ckpt.""" 96 | # Synchronization of BatchNorm's buffer (running_mean 97 | # and running_var) is not supported in the DDP of pytorch, 98 | # which may cause the inconsistent performance of models in 99 | # different ranks, so we broadcast BatchNorm's buffers 100 | # of rank 0 to other ranks to avoid this. 
101 |         if self.broadcast_bn_buffer:
102 |             model = runner.model
103 |             for name, module in model.named_modules():
104 |                 if isinstance(module,
105 |                               _BatchNorm) and module.track_running_stats:
106 |                     dist.broadcast(module.running_var, 0)
107 |                     dist.broadcast(module.running_mean, 0)
108 |
109 |         if not self._should_evaluate(runner):
110 |             return
111 |
112 |         tmpdir = self.tmpdir
113 |         if tmpdir is None:
114 |             tmpdir = osp.join(runner.work_dir, '.eval_hook')
115 |
116 |         from mmseg.apis import multi_gpu_test
117 |         results = multi_gpu_test(
118 |             runner.model,
119 |             self.dataloader,
120 |             tmpdir=tmpdir,
121 |             gpu_collect=self.gpu_collect,
122 |             pre_eval=self.pre_eval)
123 |
124 |         runner.log_buffer.clear()
125 |
126 |         if runner.rank == 0:
127 |             print('\n')
128 |             runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
129 |             key_score = self.evaluate(runner, results)
130 |
131 |             if self.save_best:
132 |                 self._save_ckpt(runner, key_score)
133 |
-------------------------------------------------------------------------------- /mmseg/models/losses/dice_loss.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited.
3 | This file is modified from:
4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/losses/dice_loss.py
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from ..builder import LOSSES
11 | from .utils import get_class_weight, weighted_loss
12 |
13 |
14 | @weighted_loss
15 | def dice_loss(pred,
16 |               target,
17 |               valid_mask,
18 |               smooth=1,
19 |               exponent=2,
20 |               class_weight=None,
21 |               ignore_index=255):
22 |     assert pred.shape[0] == target.shape[0]
23 |     total_loss = 0
24 |     num_classes = pred.shape[1]
25 |     for i in range(num_classes):
26 |         if i != ignore_index:
27 |             dice_loss = binary_dice_loss(
28 |                 pred[:, i],
29 |                 target[..., i],
30 |                 valid_mask=valid_mask,
31 |                 smooth=smooth,
32 |                 exponent=exponent)
33 |             if class_weight is not None:
34 |                 dice_loss *= class_weight[i]
35 |             total_loss += dice_loss
36 |     return total_loss / num_classes
37 |
38 |
39 | @weighted_loss
40 | def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
41 |     assert pred.shape[0] == target.shape[0]
42 |     pred = pred.reshape(pred.shape[0], -1)
43 |     target = target.reshape(target.shape[0], -1)
44 |     valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
45 |
46 |     num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
47 |     den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
48 |
49 |     return 1 - num / den
50 |
51 |
52 | @LOSSES.register_module()
53 | class DiceLoss(nn.Module):
54 |     """DiceLoss.
55 |
56 |     This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
57 |     Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
58 |
59 |     Args:
60 |         smooth (float): A float number to smooth loss, and avoid NaN error.
61 |             Default: 1
62 |         exponent (float): A float number to calculate denominator
63 |             value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
64 |         reduction (str, optional): The method used to reduce the loss. Options
65 |             are "none", "mean" and "sum". This parameter only works when
66 |             per_image is True. Default: 'mean'.
67 |         class_weight (list[float] | str, optional): Weight of each class. If in
68 |             str format, read them from a file. Defaults to None.
69 |         loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
70 |         ignore_index (int | None): The label index to be ignored. Default: 255.
71 | loss_name (str, optional): Name of the loss item. If you want this loss 72 | item to be included into the backward graph, `loss_` must be the 73 | prefix of the name. Defaults to 'loss_dice'. 74 | """ 75 | 76 | def __init__(self, 77 | smooth=1, 78 | exponent=2, 79 | reduction='mean', 80 | class_weight=None, 81 | loss_weight=1.0, 82 | ignore_index=255, 83 | loss_name='loss_dice', 84 | **kwards): 85 | super(DiceLoss, self).__init__() 86 | self.smooth = smooth 87 | self.exponent = exponent 88 | self.reduction = reduction 89 | self.class_weight = get_class_weight(class_weight) 90 | self.loss_weight = loss_weight 91 | self.ignore_index = ignore_index 92 | self._loss_name = loss_name 93 | 94 | def forward(self, 95 | pred, 96 | target, 97 | avg_factor=None, 98 | reduction_override=None, 99 | **kwards): 100 | assert reduction_override in (None, 'none', 'mean', 'sum') 101 | reduction = ( 102 | reduction_override if reduction_override else self.reduction) 103 | if self.class_weight is not None: 104 | class_weight = pred.new_tensor(self.class_weight) 105 | else: 106 | class_weight = None 107 | 108 | pred = F.softmax(pred, dim=1) 109 | num_classes = pred.shape[1] 110 | one_hot_target = F.one_hot( 111 | torch.clamp(target.long(), 0, num_classes - 1), 112 | num_classes=num_classes) 113 | valid_mask = (target != self.ignore_index).long() 114 | 115 | loss = self.loss_weight * dice_loss( 116 | pred, 117 | one_hot_target, 118 | valid_mask=valid_mask, 119 | reduction=reduction, 120 | avg_factor=avg_factor, 121 | smooth=self.smooth, 122 | exponent=self.exponent, 123 | class_weight=class_weight, 124 | ignore_index=self.ignore_index) 125 | return loss 126 | 127 | @property 128 | def loss_name(self): 129 | """Loss Name. 130 | 131 | This function must be implemented and will return the name of this 132 | loss function. This name will be used to combine different loss items 133 | by simple sum operation. In addition, if you want this loss item to be 134 | included into the backward graph, `loss_` must be the prefix of the 135 | name. 136 | Returns: 137 | str: The name of this loss item. 138 | """ 139 | return self._loss_name 140 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/test_time_aug.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/test_time_aug.py 5 | ''' 6 | import warnings 7 | 8 | import mmcv 9 | 10 | from ..builder import PIPELINES 11 | from .compose import Compose 12 | 13 | 14 | @PIPELINES.register_module() 15 | class MultiScaleFlipAug(object): 16 | """Test-time augmentation with multiple scales and flipping. 17 | 18 | An example configuration is as followed: 19 | 20 | .. code-block:: 21 | 22 | img_scale=(2048, 1024), 23 | img_ratios=[0.5, 1.0], 24 | flip=True, 25 | transforms=[ 26 | dict(type='Resize', keep_ratio=True), 27 | dict(type='RandomFlip'), 28 | dict(type='Normalize', **img_norm_cfg), 29 | dict(type='Pad', size_divisor=32), 30 | dict(type='ImageToTensor', keys=['img']), 31 | dict(type='Collect', keys=['img']), 32 | ] 33 | 34 | After MultiScaleFLipAug with above configuration, the results are wrapped 35 | into lists of the same length as followed: 36 | 37 | .. 
code-block:: 38 | 39 | dict( 40 | img=[...], 41 | img_shape=[...], 42 | scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] 43 | flip=[False, True, False, True] 44 | ... 45 | ) 46 | 47 | Args: 48 | transforms (list[dict]): Transforms to apply in each augmentation. 49 | img_scale (None | tuple | list[tuple]): Images scales for resizing. 50 | img_ratios (float | list[float]): Image ratios for resizing 51 | flip (bool): Whether apply flip augmentation. Default: False. 52 | flip_direction (str | list[str]): Flip augmentation directions, 53 | options are "horizontal" and "vertical". If flip_direction is list, 54 | multiple flip augmentations will be applied. 55 | It has no effect when flip == False. Default: "horizontal". 56 | """ 57 | 58 | def __init__(self, 59 | transforms, 60 | img_scale, 61 | img_ratios=None, 62 | flip=False, 63 | flip_direction='horizontal'): 64 | self.transforms = Compose(transforms) 65 | if img_ratios is not None: 66 | img_ratios = img_ratios if isinstance(img_ratios, 67 | list) else [img_ratios] 68 | assert mmcv.is_list_of(img_ratios, float) 69 | if img_scale is None: 70 | # mode 1: given img_scale=None and a range of image ratio 71 | self.img_scale = None 72 | assert mmcv.is_list_of(img_ratios, float) 73 | elif isinstance(img_scale, tuple) and mmcv.is_list_of( 74 | img_ratios, float): 75 | assert len(img_scale) == 2 76 | # mode 2: given a scale and a range of image ratio 77 | self.img_scale = [(int(img_scale[0] * ratio), 78 | int(img_scale[1] * ratio)) 79 | for ratio in img_ratios] 80 | else: 81 | # mode 3: given multiple scales 82 | self.img_scale = img_scale if isinstance(img_scale, 83 | list) else [img_scale] 84 | assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None 85 | self.flip = flip 86 | self.img_ratios = img_ratios 87 | self.flip_direction = flip_direction if isinstance( 88 | flip_direction, list) else [flip_direction] 89 | assert mmcv.is_list_of(self.flip_direction, str) 90 | if not self.flip and self.flip_direction != ['horizontal']: 91 | warnings.warn( 92 | 'flip_direction has no effect when flip is set to False') 93 | if (self.flip 94 | and not any([t['type'] == 'RandomFlip' for t in transforms])): 95 | warnings.warn( 96 | 'flip has no effect when RandomFlip is not in transforms') 97 | 98 | def __call__(self, results): 99 | """Call function to apply test time augment transforms on results. 100 | 101 | Args: 102 | results (dict): Result dict contains the data to transform. 103 | 104 | Returns: 105 | dict[str: list]: The augmented data, where each value is wrapped 106 | into a list. 
107 | """ 108 | 109 | aug_data = [] 110 | if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): 111 | h, w = results['img'].shape[:2] 112 | img_scale = [(int(w * ratio), int(h * ratio)) 113 | for ratio in self.img_ratios] 114 | else: 115 | img_scale = self.img_scale 116 | flip_aug = [False, True] if self.flip else [False] 117 | for scale in img_scale: 118 | for flip in flip_aug: 119 | for direction in self.flip_direction: 120 | _results = results.copy() 121 | _results['scale'] = scale 122 | _results['flip'] = flip 123 | _results['flip_direction'] = direction 124 | data = self.transforms(_results) 125 | aug_data.append(data) 126 | # list of dict to dict of list 127 | aug_data_dict = {key: [] for key in aug_data[0]} 128 | for data in aug_data: 129 | for key, val in data.items(): 130 | aug_data_dict[key].append(val) 131 | return aug_data_dict 132 | 133 | def __repr__(self): 134 | repr_str = self.__class__.__name__ 135 | repr_str += f'(transforms={self.transforms}, ' 136 | repr_str += f'img_scale={self.img_scale}, flip={self.flip})' 137 | repr_str += f'flip_direction={self.flip_direction}' 138 | return repr_str 139 | -------------------------------------------------------------------------------- /mmseg/datasets/coco_stuff.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/coco_stuff.py 5 | ''' 6 | from .builder import DATASETS 7 | from .custom import CustomDataset 8 | 9 | 10 | @DATASETS.register_module() 11 | class COCOStuffDataset(CustomDataset): 12 | """COCO-Stuff dataset. 13 | 14 | In segmentation map annotation for COCO-Stuff, Train-IDs of the 10k version 15 | are from 1 to 171, where 0 is the ignore index, and Train-ID of COCO Stuff 16 | 164k is from 0 to 170, where 255 is the ignore index. So, they are all 171 17 | semantic categories. ``reduce_zero_label`` is set to True and False for the 18 | 10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg', 19 | and ``seg_map_suffix`` is fixed to '.png'. 
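    A config sketch for the 10k version (``data_root`` and the directory
    names are placeholders):

        dict(type='COCOStuffDataset', data_root='data/coco_stuff10k',
             img_dir='images', ann_dir='annotations',
             reduce_zero_label=True, pipeline=train_pipeline)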
20 | """ 21 | CLASSES = ( 22 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 23 | 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 24 | 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 25 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 26 | 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 27 | 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 28 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 29 | 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 30 | 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 31 | 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 32 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 33 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 34 | 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 35 | 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', 36 | 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', 37 | 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', 38 | 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', 39 | 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 40 | 'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 41 | 'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 42 | 'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 43 | 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', 44 | 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', 45 | 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', 46 | 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', 47 | 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', 48 | 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', 49 | 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', 50 | 'window-blind', 'window-other', 'wood') 51 | 52 | PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], 53 | [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], 54 | [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], 55 | [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], 56 | [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], 57 | [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], 58 | [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], 59 | [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], 60 | [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128], 61 | [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], 62 | [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128], 63 | [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192], 64 | [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], 65 | [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0], 66 | [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192], 67 | [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], 68 | [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128], 69 | [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], 70 | [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224], 71 | [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0], 72 | [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128], 73 | [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224], 74 | [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 
128], 75 | [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192], 76 | [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224], 77 | [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0], 78 | [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192], 79 | [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224], 80 | [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128], 81 | [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128], 82 | [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160], 83 | [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64], 84 | [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128], 85 | [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160], 86 | [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192], 87 | [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192], 88 | [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160], 89 | [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64], 90 | [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192], 91 | [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160], 92 | [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192], 93 | [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128], 94 | [64, 192, 96], [64, 160, 64], [64, 64, 0]] 95 | 96 | def __init__(self, **kwargs): 97 | super(COCOStuffDataset, self).__init__( 98 | img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs) 99 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/loading.py 5 | ''' 6 | import os.path as osp 7 | 8 | import mmcv 9 | import numpy as np 10 | 11 | from ..builder import PIPELINES 12 | 13 | 14 | @PIPELINES.register_module() 15 | class LoadImageFromFile(object): 16 | """Load an image from file. 17 | 18 | Required keys are "img_prefix" and "img_info" (a dict that must contain the 19 | key "filename"). Added or updated keys are "filename", "img", "img_shape", 20 | "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), 21 | "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). 22 | 23 | Args: 24 | to_float32 (bool): Whether to convert the loaded image to a float32 25 | numpy array. If set to False, the loaded image is an uint8 array. 26 | Defaults to False. 27 | color_type (str): The flag argument for :func:`mmcv.imfrombytes`. 28 | Defaults to 'color'. 29 | file_client_args (dict): Arguments to instantiate a FileClient. 30 | See :class:`mmcv.fileio.FileClient` for details. 31 | Defaults to ``dict(backend='disk')``. 32 | imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 33 | 'cv2' 34 | """ 35 | 36 | def __init__(self, 37 | to_float32=False, 38 | color_type='color', 39 | file_client_args=dict(backend='disk'), 40 | imdecode_backend='cv2'): 41 | self.to_float32 = to_float32 42 | self.color_type = color_type 43 | self.file_client_args = file_client_args.copy() 44 | self.file_client = None 45 | self.imdecode_backend = imdecode_backend 46 | 47 | def __call__(self, results): 48 | """Call functions to load image and get image meta information. 49 | 50 | Args: 51 | results (dict): Result dict from :obj:`mmseg.CustomDataset`. 52 | 53 | Returns: 54 | dict: The dict contains loaded image and meta information. 
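        A minimal sketch (the image path is a placeholder and must exist
        on disk):

            >>> load = LoadImageFromFile()
            >>> results = dict(img_prefix=None,
            ...                img_info=dict(filename='demo.jpg'))
            >>> results = load(results)  # adds 'img', 'img_shape', etc.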
55 | """ 56 | 57 | if self.file_client is None: 58 | self.file_client = mmcv.FileClient(**self.file_client_args) 59 | 60 | if results.get('img_prefix') is not None: 61 | filename = osp.join(results['img_prefix'], 62 | results['img_info']['filename']) 63 | else: 64 | filename = results['img_info']['filename'] 65 | img_bytes = self.file_client.get(filename) 66 | img = mmcv.imfrombytes( 67 | img_bytes, flag=self.color_type, backend=self.imdecode_backend) 68 | if self.to_float32: 69 | img = img.astype(np.float32) 70 | 71 | results['filename'] = filename 72 | results['ori_filename'] = results['img_info']['filename'] 73 | results['img'] = img 74 | results['img_shape'] = img.shape 75 | results['ori_shape'] = img.shape 76 | # Set initial values for default meta_keys 77 | results['pad_shape'] = img.shape 78 | results['scale_factor'] = 1.0 79 | num_channels = 1 if len(img.shape) < 3 else img.shape[2] 80 | results['img_norm_cfg'] = dict( 81 | mean=np.zeros(num_channels, dtype=np.float32), 82 | std=np.ones(num_channels, dtype=np.float32), 83 | to_rgb=False) 84 | return results 85 | 86 | def __repr__(self): 87 | repr_str = self.__class__.__name__ 88 | repr_str += f'(to_float32={self.to_float32},' 89 | repr_str += f"color_type='{self.color_type}'," 90 | repr_str += f"imdecode_backend='{self.imdecode_backend}')" 91 | return repr_str 92 | 93 | 94 | @PIPELINES.register_module() 95 | class LoadAnnotations(object): 96 | """Load annotations for semantic segmentation. 97 | 98 | Args: 99 | reduce_zero_label (bool): Whether reduce all label value by 1. 100 | Usually used for datasets where 0 is background label. 101 | Default: False. 102 | file_client_args (dict): Arguments to instantiate a FileClient. 103 | See :class:`mmcv.fileio.FileClient` for details. 104 | Defaults to ``dict(backend='disk')``. 105 | imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 106 | 'pillow' 107 | """ 108 | 109 | def __init__(self, 110 | reduce_zero_label=False, 111 | file_client_args=dict(backend='disk'), 112 | imdecode_backend='pillow'): 113 | self.reduce_zero_label = reduce_zero_label 114 | self.file_client_args = file_client_args.copy() 115 | self.file_client = None 116 | self.imdecode_backend = imdecode_backend 117 | 118 | def __call__(self, results): 119 | """Call function to load multiple types annotations. 120 | 121 | Args: 122 | results (dict): Result dict from :obj:`mmseg.CustomDataset`. 123 | 124 | Returns: 125 | dict: The dict contains loaded semantic segmentation annotations. 
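        Note: with ``reduce_zero_label=True``, label 0 becomes 255 (the
        ignore index) and labels 1..N shift down to 0..N-1, as implemented
        below.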
126 | """ 127 | 128 | if self.file_client is None: 129 | self.file_client = mmcv.FileClient(**self.file_client_args) 130 | 131 | if results.get('seg_prefix', None) is not None: 132 | filename = osp.join(results['seg_prefix'], 133 | results['ann_info']['seg_map']) 134 | else: 135 | filename = results['ann_info']['seg_map'] 136 | img_bytes = self.file_client.get(filename) 137 | gt_semantic_seg = mmcv.imfrombytes( 138 | img_bytes, flag='unchanged', 139 | backend=self.imdecode_backend).squeeze().astype(np.uint8) 140 | # modify if custom classes 141 | if results.get('label_map', None) is not None: 142 | for old_id, new_id in results['label_map'].items(): 143 | gt_semantic_seg[gt_semantic_seg == old_id] = new_id 144 | # reduce zero_label 145 | if self.reduce_zero_label: 146 | # avoid using underflow conversion 147 | gt_semantic_seg[gt_semantic_seg == 0] = 255 148 | gt_semantic_seg = gt_semantic_seg - 1 149 | gt_semantic_seg[gt_semantic_seg == 254] = 255 150 | results['gt_semantic_seg'] = gt_semantic_seg 151 | results['seg_fields'].append('gt_semantic_seg') 152 | return results 153 | 154 | def __repr__(self): 155 | repr_str = self.__class__.__name__ 156 | repr_str += f'(reduce_zero_label={self.reduce_zero_label},' 157 | repr_str += f"imdecode_backend='{self.imdecode_backend}')" 158 | return repr_str 159 | -------------------------------------------------------------------------------- /mmseg/apis/train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/apis/train.py 5 | ''' 6 | import random 7 | import warnings 8 | 9 | import mmcv 10 | import numpy as np 11 | import torch 12 | import torch.distributed as dist 13 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 14 | from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info 15 | from mmcv.utils import build_from_cfg 16 | 17 | from mmseg import digit_version 18 | from mmseg.core import DistEvalHook, EvalHook 19 | from mmseg.datasets import build_dataloader, build_dataset 20 | from mmseg.utils import find_latest_checkpoint, get_root_logger 21 | 22 | 23 | def init_random_seed(seed=None, device='cuda'): 24 | """Initialize random seed. 25 | 26 | If the seed is not set, the seed will be automatically randomized, 27 | and then broadcast to all processes to prevent some potential bugs. 28 | Args: 29 | seed (int, Optional): The seed. Default to None. 30 | device (str): The device where the seed will be put on. 31 | Default to 'cuda'. 32 | Returns: 33 | int: Seed to be used. 34 | """ 35 | if seed is not None: 36 | return seed 37 | 38 | # Make sure all ranks share the same random seed to prevent 39 | # some potential bugs. Please refer to 40 | # https://github.com/open-mmlab/mmdetection/issues/6339 41 | rank, world_size = get_dist_info() 42 | seed = np.random.randint(2**31) 43 | if world_size == 1: 44 | return seed 45 | 46 | if rank == 0: 47 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 48 | else: 49 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 50 | dist.broadcast(random_num, src=0) 51 | return random_num.item() 52 | 53 | 54 | def set_random_seed(seed, deterministic=False): 55 | """Set random seed. 56 | 57 | Args: 58 | seed (int): Seed to be used. 
59 | deterministic (bool): Whether to set the deterministic option for 60 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 61 | to True and `torch.backends.cudnn.benchmark` to False. 62 | Default: False. 63 | """ 64 | random.seed(seed) 65 | np.random.seed(seed) 66 | torch.manual_seed(seed) 67 | torch.cuda.manual_seed_all(seed) 68 | if deterministic: 69 | torch.backends.cudnn.deterministic = True 70 | torch.backends.cudnn.benchmark = False 71 | 72 | 73 | def train_segmentor(model, 74 | dataset, 75 | cfg, 76 | distributed=False, 77 | validate=False, 78 | timestamp=None, 79 | meta=None): 80 | """Launch segmentor training.""" 81 | logger = get_root_logger(cfg.log_level) 82 | 83 | # prepare data loaders 84 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 85 | data_loaders = [ 86 | build_dataloader( 87 | ds, 88 | cfg.data.samples_per_gpu, 89 | cfg.data.workers_per_gpu, 90 | # cfg.gpus will be ignored if distributed 91 | len(cfg.gpu_ids), 92 | dist=distributed, 93 | seed=cfg.seed, 94 | drop_last=True) for ds in dataset 95 | ] 96 | 97 | # put model on gpus 98 | if distributed: 99 | find_unused_parameters = cfg.get('find_unused_parameters', False) 100 | # Sets the `find_unused_parameters` parameter in 101 | # torch.nn.parallel.DistributedDataParallel 102 | model = MMDistributedDataParallel( 103 | model.cuda(), 104 | device_ids=[torch.cuda.current_device()], 105 | broadcast_buffers=False, 106 | find_unused_parameters=find_unused_parameters) 107 | else: 108 | if not torch.cuda.is_available(): 109 | assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \ 110 | 'Please use MMCV >= 1.4.4 for CPU training!' 111 | model = MMDataParallel(model, device_ids=cfg.gpu_ids) 112 | # build runner 113 | optimizer = build_optimizer(model, cfg.optimizer) 114 | 115 | if cfg.get('runner') is None: 116 | cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} 117 | warnings.warn( 118 | 'config is now expected to have a `runner` section, ' 119 | 'please set `runner` in your config.', UserWarning) 120 | 121 | runner = build_runner( 122 | cfg.runner, 123 | default_args=dict( 124 | model=model, 125 | batch_processor=None, 126 | optimizer=optimizer, 127 | work_dir=cfg.work_dir, 128 | logger=logger, 129 | meta=meta)) 130 | 131 | # register hooks 132 | runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, 133 | cfg.checkpoint_config, cfg.log_config, 134 | cfg.get('momentum_config', None)) 135 | 136 | # an ugly workaround to make the .log and .log.json filenames the same 137 | runner.timestamp = timestamp 138 | 139 | # register eval hooks 140 | if validate: 141 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 142 | val_dataloader = build_dataloader( 143 | val_dataset, 144 | samples_per_gpu=1, 145 | workers_per_gpu=cfg.data.workers_per_gpu, 146 | dist=distributed, 147 | shuffle=False) 148 | eval_cfg = cfg.get('evaluation', {}) 149 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' 150 | eval_hook = DistEvalHook if distributed else EvalHook 151 | # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the 152 | # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
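# The eval hook below is therefore also registered with priority='LOW';
# mmcv runs hooks of equal priority in registration order, so iteration
# timing is recorded before validation starts.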
153 | runner.register_hook( 154 | eval_hook(val_dataloader, **eval_cfg), priority='LOW') 155 | 156 | # user-defined hooks 157 | if cfg.get('custom_hooks', None): 158 | custom_hooks = cfg.custom_hooks 159 | assert isinstance(custom_hooks, list), \ 160 | f'custom_hooks expects list type, but got {type(custom_hooks)}' 161 | for hook_cfg in cfg.custom_hooks: 162 | assert isinstance(hook_cfg, dict), \ 163 | 'Each item in custom_hooks expects dict type, but got ' \ 164 | f'{type(hook_cfg)}' 165 | hook_cfg = hook_cfg.copy() 166 | priority = hook_cfg.pop('priority', 'NORMAL') 167 | hook = build_from_cfg(hook_cfg, HOOKS) 168 | runner.register_hook(hook, priority=priority) 169 | 170 | if cfg.resume_from is None and cfg.get('auto_resume'): 171 | resume_from = find_latest_checkpoint(cfg.work_dir) 172 | if resume_from is not None: 173 | cfg.resume_from = resume_from 174 | if cfg.resume_from: 175 | runner.resume(cfg.resume_from) 176 | elif cfg.load_from: 177 | runner.load_checkpoint(cfg.load_from) 178 | runner.run(data_loaders, cfg.workflow) 179 | -------------------------------------------------------------------------------- /mmseg/datasets/builder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/builder.py 5 | ''' 6 | import copy 7 | import platform 8 | import random 9 | from functools import partial 10 | 11 | import numpy as np 12 | import torch 13 | from mmcv.parallel import collate 14 | from mmcv.runner import get_dist_info 15 | from mmcv.utils import Registry, build_from_cfg, digit_version 16 | from torch.utils.data import DataLoader, DistributedSampler 17 | 18 | if platform.system() != 'Windows': 19 | # https://github.com/pytorch/pytorch/issues/973 20 | import resource 21 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 22 | base_soft_limit = rlimit[0] 23 | hard_limit = rlimit[1] 24 | soft_limit = min(max(4096, base_soft_limit), hard_limit) 25 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) 26 | 27 | DATASETS = Registry('dataset') 28 | PIPELINES = Registry('pipeline') 29 | 30 | 31 | def _concat_dataset(cfg, default_args=None): 32 | """Build a :obj:`ConcatDataset` from a config with list-valued fields.""" 33 | from .dataset_wrappers import ConcatDataset 34 | img_dir = cfg['img_dir'] 35 | ann_dir = cfg.get('ann_dir', None) 36 | split = cfg.get('split', None) 37 | # pop 'separate_eval' since it is not a valid key for common datasets.
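# For example (hypothetical values), a cfg with
#   img_dir=['images/train_a', 'images/train_b'] and a matching ann_dir list
# is split into one per-directory data_cfg below, and the resulting datasets
# are wrapped into a single ConcatDataset.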
38 | separate_eval = cfg.pop('separate_eval', True) 39 | num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 40 | if ann_dir is not None: 41 | num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 42 | else: 43 | num_ann_dir = 0 44 | if split is not None: 45 | num_split = len(split) if isinstance(split, (list, tuple)) else 1 46 | else: 47 | num_split = 0 48 | if num_img_dir > 1: 49 | assert num_img_dir == num_ann_dir or num_ann_dir == 0 50 | assert num_img_dir == num_split or num_split == 0 51 | else: 52 | assert num_split == num_ann_dir or num_ann_dir <= 1 53 | num_dset = max(num_split, num_img_dir) 54 | 55 | datasets = [] 56 | for i in range(num_dset): 57 | data_cfg = copy.deepcopy(cfg) 58 | if isinstance(img_dir, (list, tuple)): 59 | data_cfg['img_dir'] = img_dir[i] 60 | if isinstance(ann_dir, (list, tuple)): 61 | data_cfg['ann_dir'] = ann_dir[i] 62 | if isinstance(split, (list, tuple)): 63 | data_cfg['split'] = split[i] 64 | datasets.append(build_dataset(data_cfg, default_args)) 65 | 66 | return ConcatDataset(datasets, separate_eval) 67 | 68 | 69 | def build_dataset(cfg, default_args=None): 70 | """Build datasets.""" 71 | from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset, 72 | RepeatDataset) 73 | if isinstance(cfg, (list, tuple)): 74 | dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) 75 | elif cfg['type'] == 'RepeatDataset': 76 | dataset = RepeatDataset( 77 | build_dataset(cfg['dataset'], default_args), cfg['times']) 78 | elif cfg['type'] == 'MultiImageMixDataset': 79 | cp_cfg = copy.deepcopy(cfg) 80 | cp_cfg['dataset'] = build_dataset(cp_cfg['dataset']) 81 | cp_cfg.pop('type') 82 | dataset = MultiImageMixDataset(**cp_cfg) 83 | elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( 84 | cfg.get('split', None), (list, tuple)): 85 | dataset = _concat_dataset(cfg, default_args) 86 | else: 87 | dataset = build_from_cfg(cfg, DATASETS, default_args) 88 | 89 | return dataset 90 | 91 | 92 | def build_dataloader(dataset, 93 | samples_per_gpu, 94 | workers_per_gpu, 95 | num_gpus=1, 96 | dist=True, 97 | shuffle=True, 98 | seed=None, 99 | drop_last=False, 100 | pin_memory=True, 101 | persistent_workers=True, 102 | **kwargs): 103 | """Build PyTorch DataLoader. 104 | 105 | In distributed training, each GPU/process has a dataloader. 106 | In non-distributed training, there is only one dataloader for all GPUs. 107 | 108 | Args: 109 | dataset (Dataset): A PyTorch dataset. 110 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 111 | batch size of each GPU. 112 | workers_per_gpu (int): How many subprocesses to use for data loading 113 | for each GPU. 114 | num_gpus (int): Number of GPUs. Only used in non-distributed training. 115 | dist (bool): Distributed training/test or not. Default: True. 116 | shuffle (bool): Whether to shuffle the data at every epoch. 117 | Default: True. 118 | seed (int | None): Seed to be used. Default: None. 119 | drop_last (bool): Whether to drop the last incomplete batch in epoch. 120 | Default: False 121 | pin_memory (bool): Whether to use pin_memory in DataLoader. 122 | Default: True 123 | persistent_workers (bool): If True, the data loader will not shut down 124 | the worker processes after a dataset has been consumed once. 125 | This keeps the worker Dataset instances alive. 126 | The argument only takes effect with PyTorch>=1.7.0.
127 | Default: True 128 | kwargs: Any keyword arguments used to initialize the DataLoader. 129 | 130 | Returns: 131 | DataLoader: A PyTorch dataloader. 132 | """ 133 | rank, world_size = get_dist_info() 134 | if dist: 135 | sampler = DistributedSampler( 136 | dataset, world_size, rank, shuffle=shuffle) 137 | shuffle = False 138 | batch_size = samples_per_gpu 139 | num_workers = workers_per_gpu 140 | else: 141 | sampler = None 142 | batch_size = num_gpus * samples_per_gpu 143 | num_workers = num_gpus * workers_per_gpu 144 | 145 | init_fn = partial( 146 | worker_init_fn, num_workers=num_workers, rank=rank, 147 | seed=seed) if seed is not None else None 148 | 149 | if digit_version(torch.__version__) >= digit_version('1.8.0'): 150 | data_loader = DataLoader( 151 | dataset, 152 | batch_size=batch_size, 153 | sampler=sampler, 154 | num_workers=num_workers, 155 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), 156 | pin_memory=pin_memory, 157 | shuffle=shuffle, 158 | worker_init_fn=init_fn, 159 | drop_last=drop_last, 160 | persistent_workers=persistent_workers, 161 | **kwargs) 162 | else: 163 | data_loader = DataLoader( 164 | dataset, 165 | batch_size=batch_size, 166 | sampler=sampler, 167 | num_workers=num_workers, 168 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), 169 | pin_memory=pin_memory, 170 | shuffle=shuffle, 171 | worker_init_fn=init_fn, 172 | drop_last=drop_last, 173 | **kwargs) 174 | 175 | return data_loader 176 | 177 | 178 | def worker_init_fn(worker_id, num_workers, rank, seed): 179 | """Worker init func for dataloader. 180 | 181 | The seed of each worker equals num_workers * rank + worker_id + user_seed. 182 | 183 | Args: 184 | worker_id (int): Worker id. 185 | num_workers (int): Number of workers. 186 | rank (int): The rank of the current process. 187 | seed (int): The random seed to use. 188 | """ 189 | 190 | worker_seed = num_workers * rank + worker_id + seed 191 | np.random.seed(worker_seed) 192 | random.seed(worker_seed) 193 | -------------------------------------------------------------------------------- /mmseg/datasets/ade.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | import mmcv 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from .builder import DATASETS 9 | from .custom import CustomDataset 10 | 11 | 12 | @DATASETS.register_module() 13 | class ADE20KDataset(CustomDataset): 14 | """ADE20K dataset. 15 | 16 | In segmentation map annotation for ADE20K, 0 stands for background, which 17 | is not included in the 150 categories. ``reduce_zero_label`` is fixed to 18 | True. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed 19 | to '.png'.
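A minimal config snippet using this dataset might look as follows
(directory names are illustrative assumptions, not taken from this
repository):

    dataset = dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline)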
20 | """ 21 | CLASSES = ( 22 | 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', 23 | 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 24 | 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 25 | 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 26 | 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 27 | 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', 28 | 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 29 | 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 30 | 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 31 | 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 32 | 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 33 | 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 34 | 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 35 | 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', 36 | 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 37 | 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 38 | 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 39 | 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', 40 | 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 41 | 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 42 | 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 43 | 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 44 | 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 45 | 'clock', 'flag') 46 | 47 | PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 48 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 49 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 50 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 51 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 52 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 53 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 54 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 55 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 56 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 57 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 58 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 59 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 60 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 61 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], 62 | [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 63 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], 64 | [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 65 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], 66 | [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 67 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], 68 | [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 69 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], 70 | [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 71 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], 72 | [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 73 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], 74 | [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 75 | [92, 255, 0], [0, 
224, 255], [112, 224, 255], [70, 184, 160], 76 | [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 77 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], 78 | [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 79 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], 80 | [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 81 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], 82 | [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 83 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], 84 | [102, 255, 0], [92, 0, 255]] 85 | 86 | def __init__(self, **kwargs): 87 | super(ADE20KDataset, self).__init__( 88 | img_suffix='.jpg', 89 | seg_map_suffix='.png', 90 | reduce_zero_label=True, 91 | **kwargs) 92 | 93 | def results2img(self, results, imgfile_prefix, to_label_id, indices=None): 94 | """Write the segmentation results to images. 95 | 96 | Args: 97 | results (list[ndarray]): Testing results of the 98 | dataset. 99 | imgfile_prefix (str): The filename prefix of the png files. 100 | If the prefix is "somepath/xxx", 101 | the png files will be named "somepath/xxx.png". 102 | to_label_id (bool): whether to convert output to label_id for 103 | submission. 104 | indices (list[int], optional): Indices of input results. If not 105 | set, all the indices of the dataset will be used. 106 | Default: None. 107 | 108 | Returns: 109 | list[str]: Result png files which contain the corresponding 110 | semantic segmentation images. 111 | """ 112 | if indices is None: 113 | indices = list(range(len(self))) 114 | 115 | mmcv.mkdir_or_exist(imgfile_prefix) 116 | result_files = [] 117 | for result, idx in zip(results, indices): 118 | 119 | filename = self.img_infos[idx]['filename'] 120 | basename = osp.splitext(osp.basename(filename))[0] 121 | 122 | png_filename = osp.join(imgfile_prefix, f'{basename}.png') 123 | 124 | # The official submission format expects indices in the range 0 to 150, 125 | # but the index range of the output is 0 to 149 126 | # because we set reduce_zero_label=True. 127 | result = result + 1 128 | 129 | output = Image.fromarray(result.astype(np.uint8)) 130 | output.save(png_filename) 131 | result_files.append(png_filename) 132 | 133 | return result_files 134 | 135 | def format_results(self, 136 | results, 137 | imgfile_prefix, 138 | to_label_id=True, 139 | indices=None): 140 | """Format the results into dir (standard format for ade20k evaluation). 141 | 142 | Args: 143 | results (list): Testing results of the dataset. 144 | imgfile_prefix (str | None): The prefix of image files. It 145 | includes the file path and the prefix of filename, e.g., 146 | "a/b/prefix". 147 | to_label_id (bool): whether to convert output to label_id for 148 | submission. Default: True 149 | indices (list[int], optional): Indices of input results. If not 150 | set, all the indices of the dataset will be used. 151 | Default: None. 152 | 153 | Returns: 154 | tuple: (result_files, tmp_dir), result_files is a list containing 155 | the image paths, tmp_dir is the temporary directory created 156 | for saving json/png files when img_prefix is not specified. 157 | """ 158 | 159 | if indices is None: 160 | indices = list(range(len(self))) 161 | 162 | assert isinstance(results, list), 'results must be a list.' 163 | assert isinstance(indices, list), 'indices must be a list.'
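# results2img (defined above) saves one '<basename>.png' per prediction,
# adding 1 to each label to restore the official 0-150 indexing
# (reduce_zero_label removed the background class), and returns the
# list of written file paths.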
164 | 165 | result_files = self.results2img(results, imgfile_prefix, to_label_id, 166 | indices) 167 | return result_files 168 | -------------------------------------------------------------------------------- /mmseg/models/losses/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/losses/cross_entropy_loss.py 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from ..builder import LOSSES 11 | from .utils import get_class_weight, weight_reduce_loss 12 | 13 | 14 | def cross_entropy(pred, 15 | label, 16 | weight=None, 17 | class_weight=None, 18 | reduction='mean', 19 | avg_factor=None, 20 | ignore_index=-100): 21 | """The wrapper function for :func:`F.cross_entropy`.""" 22 | # class_weight is a manual rescaling weight given to each class. 23 | # If given, it has to be a Tensor of size C. Compute element-wise losses. 24 | loss = F.cross_entropy( 25 | pred, 26 | label, 27 | weight=class_weight, 28 | reduction='none', 29 | ignore_index=ignore_index) 30 | 31 | # apply weights and do the reduction 32 | if weight is not None: 33 | weight = weight.float() 34 | loss = weight_reduce_loss( 35 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor) 36 | 37 | return loss 38 | 39 | 40 | def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): 41 | """Expand onehot labels to match the size of prediction.""" 42 | bin_labels = labels.new_zeros(target_shape) 43 | valid_mask = (labels >= 0) & (labels != ignore_index) 44 | inds = torch.nonzero(valid_mask, as_tuple=True) 45 | 46 | if inds[0].numel() > 0: 47 | if labels.dim() == 3: 48 | bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 49 | else: 50 | bin_labels[inds[0], labels[valid_mask]] = 1 51 | 52 | valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() 53 | if label_weights is None: 54 | bin_label_weights = valid_mask 55 | else: 56 | bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) 57 | bin_label_weights *= valid_mask 58 | 59 | return bin_labels, bin_label_weights 60 | 61 | 62 | def binary_cross_entropy(pred, 63 | label, 64 | weight=None, 65 | reduction='mean', 66 | avg_factor=None, 67 | class_weight=None, 68 | ignore_index=255): 69 | """Calculate the binary CrossEntropy loss. 70 | 71 | Args: 72 | pred (torch.Tensor): The prediction with shape (N, 1). 73 | label (torch.Tensor): The learning label of the prediction. 74 | weight (torch.Tensor, optional): Sample-wise loss weight. 75 | reduction (str, optional): The method used to reduce the loss. 76 | Options are "none", "mean" and "sum". 77 | avg_factor (int, optional): Average factor that is used to average 78 | the loss. Defaults to None. 79 | class_weight (list[float], optional): The weight for each class. 80 | ignore_index (int | None): The label index to be ignored.
Default: 255 81 | 82 | Returns: 83 | torch.Tensor: The calculated loss 84 | """ 85 | if pred.dim() != label.dim(): 86 | assert (pred.dim() == 2 and label.dim() == 1) or ( 87 | pred.dim() == 4 and label.dim() == 3), \ 88 | 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ 89 | 'H, W], label shape [N, H, W] are supported' 90 | label, weight = _expand_onehot_labels(label, weight, pred.shape, 91 | ignore_index) 92 | 93 | # weighted element-wise losses 94 | if weight is not None: 95 | weight = weight.float() 96 | loss = F.binary_cross_entropy_with_logits( 97 | pred, label.float(), pos_weight=class_weight, reduction='none') 98 | # do the reduction for the weighted loss 99 | loss = weight_reduce_loss( 100 | loss, weight, reduction=reduction, avg_factor=avg_factor) 101 | 102 | return loss 103 | 104 | 105 | def mask_cross_entropy(pred, 106 | target, 107 | label, 108 | reduction='mean', 109 | avg_factor=None, 110 | class_weight=None, 111 | ignore_index=None): 112 | """Calculate the CrossEntropy loss for masks. 113 | 114 | Args: 115 | pred (torch.Tensor): The prediction with shape (N, C), C is the number 116 | of classes. 117 | target (torch.Tensor): The learning label of the prediction. 118 | label (torch.Tensor): ``label`` indicates the class label of the mask's 119 | corresponding object. This will be used to select the mask of 120 | the class which the object belongs to when the mask prediction 121 | is not class-agnostic. 122 | reduction (str, optional): The method used to reduce the loss. 123 | Options are "none", "mean" and "sum". 124 | avg_factor (int, optional): Average factor that is used to average 125 | the loss. Defaults to None. 126 | class_weight (list[float], optional): The weight for each class. 127 | ignore_index (None): Placeholder, to be consistent with other losses. 128 | Default: None. 129 | 130 | Returns: 131 | torch.Tensor: The calculated loss 132 | """ 133 | assert ignore_index is None, 'BCE loss does not support ignore_index' 134 | # TODO: handle these two reserved arguments 135 | assert reduction == 'mean' and avg_factor is None 136 | num_rois = pred.size()[0] 137 | inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) 138 | pred_slice = pred[inds, label].squeeze(1) 139 | return F.binary_cross_entropy_with_logits( 140 | pred_slice, target, weight=class_weight, reduction='mean')[None] 141 | 142 | 143 | @LOSSES.register_module() 144 | class CrossEntropyLoss(nn.Module): 145 | """CrossEntropyLoss. 146 | 147 | Args: 148 | use_sigmoid (bool, optional): Whether the prediction uses sigmoid 149 | instead of softmax. Defaults to False. 150 | use_mask (bool, optional): Whether to use mask cross entropy loss. 151 | Defaults to False. 152 | reduction (str, optional): The method used to reduce the loss. 153 | Options are "none", "mean" and "sum". Defaults to 'mean'. 154 | class_weight (list[float] | str, optional): Weight of each class. If in 155 | str format, read them from a file. Defaults to None. 156 | loss_weight (float, optional): Weight of the loss. Defaults to 1.0. 157 | loss_name (str, optional): Name of the loss item. If you want this loss 158 | item to be included into the backward graph, `loss_` must be the 159 | prefix of the name. Defaults to 'loss_ce'.
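A short usage sketch (shapes and class count are illustrative
assumptions):

    >>> import torch
    >>> criterion = CrossEntropyLoss(loss_weight=1.0)
    >>> cls_score = torch.randn(2, 19, 64, 64)     # (N, C, H, W) logits
    >>> label = torch.randint(0, 19, (2, 64, 64))  # (N, H, W) class ids
    >>> loss = criterion(cls_score, label)         # scalar, 'mean' reduced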
160 | """ 161 | 162 | def __init__(self, 163 | use_sigmoid=False, 164 | use_mask=False, 165 | reduction='mean', 166 | class_weight=None, 167 | loss_weight=1.0, 168 | loss_name='loss_ce'): 169 | super(CrossEntropyLoss, self).__init__() 170 | assert (use_sigmoid is False) or (use_mask is False) 171 | self.use_sigmoid = use_sigmoid 172 | self.use_mask = use_mask 173 | self.reduction = reduction 174 | self.loss_weight = loss_weight 175 | self.class_weight = get_class_weight(class_weight) 176 | 177 | if self.use_sigmoid: 178 | self.cls_criterion = binary_cross_entropy 179 | elif self.use_mask: 180 | self.cls_criterion = mask_cross_entropy 181 | else: 182 | self.cls_criterion = cross_entropy 183 | self._loss_name = loss_name 184 | 185 | def forward(self, 186 | cls_score, 187 | label, 188 | weight=None, 189 | avg_factor=None, 190 | reduction_override=None, 191 | **kwargs): 192 | """Forward function.""" 193 | assert reduction_override in (None, 'none', 'mean', 'sum') 194 | reduction = ( 195 | reduction_override if reduction_override else self.reduction) 196 | if self.class_weight is not None: 197 | class_weight = cls_score.new_tensor(self.class_weight) 198 | else: 199 | class_weight = None 200 | loss_cls = self.loss_weight * self.cls_criterion( 201 | cls_score, 202 | label, 203 | weight, 204 | class_weight=class_weight, 205 | reduction=reduction, 206 | avg_factor=avg_factor, 207 | **kwargs) 208 | return loss_cls 209 | 210 | @property 211 | def loss_name(self): 212 | """Loss Name. 213 | 214 | This function must be implemented and will return the name of this 215 | loss function. This name will be used to combine different loss items 216 | by simple sum operation. In addition, if you want this loss item to be 217 | included into the backward graph, `loss_` must be the prefix of the 218 | name. 219 | Returns: 220 | str: The name of this loss item. 221 | """ 222 | return self._loss_name 223 | -------------------------------------------------------------------------------- /mmseg/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/cityscapes.py 5 | ''' 6 | import os.path as osp 7 | 8 | import mmcv 9 | import numpy as np 10 | from mmcv.utils import print_log 11 | from PIL import Image 12 | 13 | from .builder import DATASETS 14 | from .custom import CustomDataset 15 | 16 | 17 | @DATASETS.register_module() 18 | class CityscapesDataset(CustomDataset): 19 | """Cityscapes dataset. 20 | 21 | The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is 22 | fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. 
23 | """ 24 | 25 | CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 26 | 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 27 | 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 28 | 'bicycle') 29 | 30 | PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], 31 | [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], 32 | [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], 33 | [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], 34 | [0, 80, 100], [0, 0, 230], [119, 11, 32]] 35 | 36 | def __init__(self, 37 | img_suffix='_leftImg8bit.png', 38 | seg_map_suffix='_gtFine_labelTrainIds.png', 39 | **kwargs): 40 | super(CityscapesDataset, self).__init__( 41 | img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) 42 | 43 | @staticmethod 44 | def _convert_to_label_id(result): 45 | """Convert trainId to id for cityscapes.""" 46 | if isinstance(result, str): 47 | result = np.load(result) 48 | import cityscapesscripts.helpers.labels as CSLabels 49 | result_copy = result.copy() 50 | for trainId, label in CSLabels.trainId2label.items(): 51 | result_copy[result == trainId] = label.id 52 | 53 | return result_copy 54 | 55 | def results2img(self, results, imgfile_prefix, to_label_id, indices=None): 56 | """Write the segmentation results to images. 57 | 58 | Args: 59 | results (list[ndarray]): Testing results of the 60 | dataset. 61 | imgfile_prefix (str): The filename prefix of the png files. 62 | If the prefix is "somepath/xxx", 63 | the png files will be named "somepath/xxx.png". 64 | to_label_id (bool): whether convert output to label_id for 65 | submission. 66 | indices (list[int], optional): Indices of input results, 67 | if not set, all the indices of the dataset will be used. 68 | Default: None. 69 | 70 | Returns: 71 | list[str: str]: result txt files which contains corresponding 72 | semantic segmentation images. 73 | """ 74 | if indices is None: 75 | indices = list(range(len(self))) 76 | 77 | mmcv.mkdir_or_exist(imgfile_prefix) 78 | result_files = [] 79 | for result, idx in zip(results, indices): 80 | if to_label_id: 81 | result = self._convert_to_label_id(result) 82 | filename = self.img_infos[idx]['filename'] 83 | basename = osp.splitext(osp.basename(filename))[0] 84 | 85 | png_filename = osp.join(imgfile_prefix, f'{basename}.png') 86 | 87 | output = Image.fromarray(result.astype(np.uint8)).convert('P') 88 | import cityscapesscripts.helpers.labels as CSLabels 89 | palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) 90 | for label_id, label in CSLabels.id2label.items(): 91 | palette[label_id] = label.color 92 | 93 | output.putpalette(palette) 94 | output.save(png_filename) 95 | result_files.append(png_filename) 96 | 97 | return result_files 98 | 99 | def format_results(self, 100 | results, 101 | imgfile_prefix, 102 | to_label_id=True, 103 | indices=None): 104 | """Format the results into dir (standard format for Cityscapes 105 | evaluation). 106 | 107 | Args: 108 | results (list): Testing results of the dataset. 109 | imgfile_prefix (str): The prefix of images files. It 110 | includes the file path and the prefix of filename, e.g., 111 | "a/b/prefix". 112 | to_label_id (bool): whether convert output to label_id for 113 | submission. Default: False 114 | indices (list[int], optional): Indices of input results, 115 | if not set, all the indices of the dataset will be used. 116 | Default: None. 
117 | 118 | Returns: 119 | tuple: (result_files, tmp_dir), result_files is a list containing 120 | the image paths, tmp_dir is the temporary directory created 121 | for saving json/png files when img_prefix is not specified. 122 | """ 123 | if indices is None: 124 | indices = list(range(len(self))) 125 | 126 | assert isinstance(results, list), 'results must be a list.' 127 | assert isinstance(indices, list), 'indices must be a list.' 128 | 129 | result_files = self.results2img(results, imgfile_prefix, to_label_id, 130 | indices) 131 | 132 | return result_files 133 | 134 | def evaluate(self, 135 | results, 136 | metric='mIoU', 137 | logger=None, 138 | imgfile_prefix=None): 139 | """Evaluation in Cityscapes/default protocol. 140 | 141 | Args: 142 | results (list): Testing results of the dataset. 143 | metric (str | list[str]): Metrics to be evaluated. 144 | logger (logging.Logger | None | str): Logger used for printing 145 | related information during evaluation. Default: None. 146 | imgfile_prefix (str | None): The prefix of output image file, 147 | for cityscapes evaluation only. It includes the file path and 148 | the prefix of filename, e.g., "a/b/prefix". 149 | If results are evaluated with cityscapes protocol, it would be 150 | the prefix of output png files. The output files would be 151 | png images named "a/b/prefix/xxx.png", where "xxx" is 152 | the image name of cityscapes. If not specified, a temp file 153 | will be created for evaluation. 154 | Default: None. 155 | 156 | Returns: 157 | dict[str, float]: Cityscapes/default metrics. 158 | """ 159 | 160 | eval_results = dict() 161 | metrics = metric.copy() if isinstance(metric, list) else [metric] 162 | if 'cityscapes' in metrics: 163 | eval_results.update( 164 | self._evaluate_cityscapes(results, logger, imgfile_prefix)) 165 | metrics.remove('cityscapes') 166 | if len(metrics) > 0: 167 | eval_results.update( 168 | super(CityscapesDataset, 169 | self).evaluate(results, metrics, logger)) 170 | 171 | return eval_results 172 | 173 | def _evaluate_cityscapes(self, results, logger, imgfile_prefix): 174 | """Evaluation in Cityscapes protocol. 175 | 176 | Args: 177 | results (list): Testing results of the dataset. 178 | logger (logging.Logger | str | None): Logger used for printing 179 | related information during evaluation. Default: None. 180 | imgfile_prefix (str | None): The prefix of output image files. 181 | 182 | Returns: 183 | dict[str, float]: Cityscapes evaluation results.
184 | """ 185 | try: 186 | import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa 187 | except ImportError: 188 | raise ImportError('Please run "pip install cityscapesscripts" to ' 189 | 'install cityscapesscripts first.') 190 | msg = 'Evaluating in Cityscapes style' 191 | if logger is None: 192 | msg = '\n' + msg 193 | print_log(msg, logger=logger) 194 | 195 | result_dir = imgfile_prefix 196 | 197 | eval_results = dict() 198 | print_log(f'Evaluating results under {result_dir} ...', logger=logger) 199 | 200 | CSEval.args.evalInstLevelScore = True 201 | CSEval.args.predictionPath = osp.abspath(result_dir) 202 | CSEval.args.evalPixelAccuracy = True 203 | CSEval.args.JSONOutput = False 204 | 205 | seg_map_list = [] 206 | pred_list = [] 207 | 208 | # when evaluating with official cityscapesscripts, 209 | # **_gtFine_labelIds.png is used 210 | for seg_map in mmcv.scandir( 211 | self.ann_dir, 'gtFine_labelIds.png', recursive=True): 212 | seg_map_list.append(osp.join(self.ann_dir, seg_map)) 213 | pred_list.append(CSEval.getPrediction(CSEval.args, seg_map)) 214 | 215 | eval_results.update( 216 | CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) 217 | 218 | return eval_results 219 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2022 Alibaba Group Holding Limited. 3 | This file is modified from: 4 | https://github.com/open-mmlab/mmsegmentation/tree/master/tools/train.py 5 | ''' 6 | import argparse 7 | import copy 8 | import os 9 | import os.path as osp 10 | import time 11 | import warnings 12 | 13 | import mmcv 14 | import torch 15 | from torch import nn 16 | from mmcv.cnn.utils import revert_sync_batchnorm 17 | from mmcv.runner import get_dist_info, init_dist 18 | from mmcv.utils import Config, DictAction, get_git_hash 19 | 20 | from mmseg import __version__ 21 | from mmseg.apis import init_random_seed, set_random_seed, train_segmentor 22 | from mmseg.datasets import build_dataset 23 | from mmseg.models import build_segmentor 24 | from mmseg.utils import collect_env, get_root_logger, setup_multi_processes 25 | 26 | import afformer 27 | 28 | 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser(description='Train a segmentor') 32 | parser.add_argument('config', help='train config file path') 33 | parser.add_argument('--work-dir', help='the dir to save logs and models') 34 | parser.add_argument( 35 | '--load-from', help='the checkpoint file to load weights from') 36 | parser.add_argument( 37 | '--resume-from', help='the checkpoint file to resume from') 38 | parser.add_argument( 39 | '--no-validate', 40 | action='store_true', 41 | help='whether not to evaluate the checkpoint during training') 42 | group_gpus = parser.add_mutually_exclusive_group() 43 | group_gpus.add_argument( 44 | '--gpus', 45 | type=int, 46 | help='(Deprecated, please use --gpu-id) number of gpus to use ' 47 | '(only applicable to non-distributed training)') 48 | group_gpus.add_argument( 49 | '--gpu-ids', 50 | type=int, 51 | nargs='+', 52 | help='(Deprecated, please use --gpu-id) ids of gpus to use ' 53 | '(only applicable to non-distributed training)') 54 | group_gpus.add_argument( 55 | '--gpu-id', 56 | type=int, 57 | default=0, 58 | help='id of gpu to use ' 59 | '(only applicable to non-distributed training)') 60 | parser.add_argument('--seed', type=int, default=None, help='random seed') 61 | parser.add_argument( 62 
| '--deterministic', 63 | action='store_true', 64 | help='whether to set deterministic options for CUDNN backend.') 65 | parser.add_argument( 66 | '--options', 67 | nargs='+', 68 | action=DictAction, 69 | help="--options is deprecated in favor of --cfg-options and it will " 70 | 'not be supported in version v0.22.0. Override some settings in the ' 71 | 'used config, the key-value pair in xxx=yyy format will be merged ' 72 | 'into config file. If the value to be overwritten is a list, it ' 73 | 'should be like key="[a,b]" or key=a,b It also allows nested ' 74 | 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' 75 | 'marks are necessary and that no white space is allowed.') 76 | parser.add_argument( 77 | '--cfg-options', 78 | nargs='+', 79 | action=DictAction, 80 | help='override some settings in the used config, the key-value pair ' 81 | 'in xxx=yyy format will be merged into config file. If the value to ' 82 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 83 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 84 | 'Note that the quotation marks are necessary and that no white space ' 85 | 'is allowed.') 86 | parser.add_argument( 87 | '--launcher', 88 | choices=['none', 'pytorch', 'slurm', 'mpi'], 89 | default='none', 90 | help='job launcher') 91 | parser.add_argument('--local_rank', type=int, default=0) 92 | parser.add_argument( 93 | '--auto-resume', 94 | action='store_true', 95 | help='resume from the latest checkpoint automatically.') 96 | args = parser.parse_args() 97 | if 'LOCAL_RANK' not in os.environ: 98 | os.environ['LOCAL_RANK'] = str(args.local_rank) 99 | 100 | if args.options and args.cfg_options: 101 | raise ValueError( 102 | '--options and --cfg-options cannot be both ' 103 | 'specified, --options is deprecated in favor of --cfg-options. ' 104 | '--options will not be supported in version v0.22.0.') 105 | if args.options: 106 | warnings.warn('--options is deprecated in favor of --cfg-options. ' 107 | '--options will not be supported in version v0.22.0.') 108 | args.cfg_options = args.options 109 | 110 | return args 111 | 112 | 113 | def main(): 114 | args = parse_args() 115 | 116 | cfg = Config.fromfile(args.config) 117 | if args.cfg_options is not None: 118 | cfg.merge_from_dict(args.cfg_options) 119 | 120 | # set cudnn_benchmark 121 | if cfg.get('cudnn_benchmark', False): 122 | torch.backends.cudnn.benchmark = True 123 | 124 | # work_dir is determined in this priority: CLI > 'work_dir' field in the config file > config filename 125 | if args.work_dir is not None: 126 | # update configs according to CLI args if args.work_dir is not None 127 | cfg.work_dir = args.work_dir 128 | elif cfg.get('work_dir', None) is None: 129 | # use config filename as default work_dir if cfg.work_dir is None 130 | cfg.work_dir = osp.join('./work_dirs', 131 | osp.splitext(osp.basename(args.config))[0]) 132 | if args.load_from is not None: 133 | cfg.load_from = args.load_from 134 | if args.resume_from is not None: 135 | cfg.resume_from = args.resume_from 136 | if args.gpus is not None: 137 | cfg.gpu_ids = range(1) 138 | warnings.warn('`--gpus` is deprecated because we only support ' 139 | 'single GPU mode in non-distributed training. ' 140 | 'Use `gpus=1` now.') 141 | if args.gpu_ids is not None: 142 | cfg.gpu_ids = args.gpu_ids[0:1] 143 | warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. ' 144 | 'Because we only support single GPU mode in ' 145 | 'non-distributed training.
Use the first GPU ' 146 | 'in `gpu_ids` now.') 147 | if args.gpus is None and args.gpu_ids is None: 148 | cfg.gpu_ids = [args.gpu_id] 149 | 150 | cfg.auto_resume = args.auto_resume 151 | 152 | # init distributed env first, since logger depends on the dist info. 153 | if args.launcher == 'none': 154 | distributed = False 155 | else: 156 | distributed = True 157 | init_dist(args.launcher, **cfg.dist_params) 158 | # gpu_ids is used to calculate iter when resuming checkpoint 159 | _, world_size = get_dist_info() 160 | cfg.gpu_ids = range(world_size) 161 | 162 | # create work_dir 163 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 164 | # dump config 165 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 166 | # init the logger before other steps 167 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 168 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 169 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 170 | 171 | # set multi-process settings 172 | setup_multi_processes(cfg) 173 | 174 | # init the meta dict to record some important information such as 175 | # environment info and seed, which will be logged 176 | meta = dict() 177 | # log env info 178 | env_info_dict = collect_env() 179 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 180 | dash_line = '-' * 60 + '\n' 181 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 182 | dash_line) 183 | meta['env_info'] = env_info 184 | 185 | # log some basic info 186 | logger.info(f'Distributed training: {distributed}') 187 | logger.info(f'Config:\n{cfg.pretty_text}') 188 | 189 | # set random seeds 190 | seed = init_random_seed(args.seed) 191 | logger.info(f'Set random seed to {seed}, ' 192 | f'deterministic: {args.deterministic}') 193 | set_random_seed(seed, deterministic=args.deterministic) 194 | cfg.seed = seed 195 | meta['seed'] = seed 196 | meta['exp_name'] = osp.basename(args.config) 197 | 198 | model = build_segmentor( 199 | cfg.model, 200 | train_cfg=cfg.get('train_cfg'), 201 | test_cfg=cfg.get('test_cfg')) 202 | model.init_weights() 203 | 204 | # SyncBN is not supported for DP 205 | if not distributed: 206 | warnings.warn( 207 | 'SyncBN is only supported with DDP. To be compatible with DP, ' 208 | 'we convert SyncBN to BN.
Please use dist_train.sh which can ' 209 | 'avoid this error.') 210 | model = revert_sync_batchnorm(model) 211 | # convert BN to SyncBN only when distributed, else the revert above is kept 212 | if distributed: model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 213 | logger.info(model) 214 | 215 | datasets = [build_dataset(cfg.data.train)] 216 | if len(cfg.workflow) == 2: 217 | val_dataset = copy.deepcopy(cfg.data.val) 218 | val_dataset.pipeline = cfg.data.train.pipeline 219 | datasets.append(build_dataset(val_dataset)) 220 | if cfg.checkpoint_config is not None: 221 | # save mmseg version, config file content and class names in 222 | # checkpoints as meta data 223 | cfg.checkpoint_config.meta = dict( 224 | mmseg_version=f'{__version__}+{get_git_hash()[:7]}', 225 | config=cfg.pretty_text, 226 | CLASSES=datasets[0].CLASSES, 227 | PALETTE=datasets[0].PALETTE) 228 | # add an attribute for visualization convenience 229 | model.CLASSES = datasets[0].CLASSES 230 | # passing checkpoint meta for saving best checkpoint 231 | meta.update(cfg.checkpoint_config.meta) 232 | train_segmentor( 233 | model, 234 | datasets, 235 | cfg, 236 | distributed=distributed, 237 | validate=(not args.no_validate), 238 | timestamp=timestamp, 239 | meta=meta) 240 | 241 | 242 | if __name__ == '__main__': 243 | main() 244 | --------------------------------------------------------------------------------
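For reference, a sketch of how this training entry point is typically launched; the config path is one of the files under configs/AFFormer, while the GPU count and the dist_train.sh argument order are assumptions following common mmsegmentation conventions, not taken from this repository's scripts:

    # single-GPU, non-distributed training
    python tools/train.py configs/AFFormer/AFFormer_tiny_ade20k.py --gpu-id 0

    # distributed training, e.g. 8 GPUs (assuming the usual
    # `dist_train.sh <config> <num_gpus>` signature)
    bash tools/dist_train.sh configs/AFFormer/AFFormer_tiny_ade20k.py 8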