├── pretrained └── README.md ├── TAPADL_FAN ├── models │ └── __init__.py ├── segmentation │ ├── mmseg │ │ ├── models │ │ │ ├── necks │ │ │ │ └── __init__.py │ │ │ ├── backbones │ │ │ │ └── __init__.py │ │ │ ├── segmentors │ │ │ │ ├── __init__.py │ │ │ │ └── cascade_encoder_decoder.py │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── make_divisible.py │ │ │ │ ├── se_layer.py │ │ │ │ ├── norm.py │ │ │ │ ├── res_layer.py │ │ │ │ └── up_conv_block.py │ │ │ ├── __init__.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── accuracy.py │ │ │ │ └── utils.py │ │ │ ├── decode_heads │ │ │ │ ├── __init__.py │ │ │ │ ├── cc_head.py │ │ │ │ ├── nl_head.py │ │ │ │ ├── gc_head.py │ │ │ │ ├── sep_fcn_head.py │ │ │ │ ├── cascade_decode_head.py │ │ │ │ ├── fpn_head.py │ │ │ │ ├── fcn_head.py │ │ │ │ ├── setr_up_head.py │ │ │ │ ├── segformer_head.py │ │ │ │ ├── lraspp_head.py │ │ │ │ ├── psp_head.py │ │ │ │ ├── aspp_head.py │ │ │ │ ├── sep_aspp_head.py │ │ │ │ ├── uper_head.py │ │ │ │ └── ocr_head.py │ │ │ └── builder.py │ │ ├── core │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ └── misc.py │ │ │ ├── __init__.py │ │ │ ├── seg │ │ │ │ ├── sampler │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base_pixel_sampler.py │ │ │ │ │ └── ohem_pixel_sampler.py │ │ │ │ ├── __init__.py │ │ │ │ └── builder.py │ │ │ └── evaluation │ │ │ │ ├── __init__.py │ │ │ │ └── eval_hooks.py │ │ ├── ops │ │ │ ├── __init__.py │ │ │ ├── wrappers.py │ │ │ └── encoding.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── collect_env.py │ │ │ └── logger.py │ │ ├── apis │ │ │ ├── __init__.py │ │ │ └── inference.py │ │ ├── version.py │ │ ├── datasets │ │ │ ├── hrf.py │ │ │ ├── stare.py │ │ │ ├── drive.py │ │ │ ├── pipelines │ │ │ │ ├── __init__.py │ │ │ │ └── compose.py │ │ │ ├── chase_db1.py │ │ │ ├── __init__.py │ │ │ ├── voc.py │ │ │ ├── dataset_wrappers.py │ │ │ ├── pascal_context.py │ │ │ └── mapillary.py │ │ └── __init__.py │ ├── local_configs │ │ ├── _base_ │ │ │ ├── default_runtime.py │ │ │ ├── schedules │ │ │ │ ├── schedule_20k.py │ │ │ │ ├── schedule_40k.py │ │ │ │ ├── schedule_40k_8gpu_adamw.py │ │ │ │ ├── schedule_80k.py │ │ │ │ ├── schedule_80k_8gpu_adamw.py │ │ │ │ ├── schedule_160k.py │ │ │ │ ├── schedule_160k_8gpu_adamw.py │ │ │ │ ├── schedule_40k_8gpu_sgd.py │ │ │ │ └── schedule_80k_8gpu_sgd.py │ │ │ ├── models │ │ │ │ ├── segformer.py │ │ │ │ ├── lraspp_m-v3-d8.py │ │ │ │ ├── fpn_r50.py │ │ │ │ ├── cgnet.py │ │ │ │ ├── ccnet_r50-d8.py │ │ │ │ ├── danet_r50-d8.py │ │ │ │ ├── pspnet_r50-d8.py │ │ │ │ ├── deeplabv3_r50-d8.py │ │ │ │ ├── fcn_r50-d8.py │ │ │ │ ├── upernet_r50.py │ │ │ │ ├── apcnet_r50-d8.py │ │ │ │ ├── dmnet_r50-d8.py │ │ │ │ ├── dnl_r50-d8.py │ │ │ │ ├── nonlocal_r50-d8.py │ │ │ │ ├── gcnet_r50-d8.py │ │ │ │ ├── emanet_r50-d8.py │ │ │ │ ├── ann_r50-d8.py │ │ │ │ ├── deeplabv3plus_r50-d8.py │ │ │ │ ├── ocrnet_r50-d8.py │ │ │ │ ├── psanet_r50-d8.py │ │ │ │ ├── encnet_r50-d8.py │ │ │ │ ├── pspnet_unet_s5-d16.py │ │ │ │ ├── deeplabv3_unet_s5-d16.py │ │ │ │ ├── fcn_unet_s5-d16.py │ │ │ │ ├── fcn_hr18.py │ │ │ │ ├── pointrend_r50.py │ │ │ │ ├── fast_scnn.py │ │ │ │ ├── ocrnet_hr18.py │ │ │ │ └── setr_pup.py │ │ │ └── datasets │ │ │ │ ├── cityscapes_1024x1024_repeat.py │ │ │ │ ├── cityscapes_1024x1024_repeat_cityc.py │ │ │ │ └── cityscapes_1024x1024_repeat_acdc.py │ │ └── fan │ │ │ └── fan_hybrid │ │ │ ├── tapfan_hybrid_base.1024x1024.city.160k.test.py │ │ │ ├── tapfan_hybrid_base.1024x1024.city.160k.test.acdc.py │ │ │ ├── tapfan_hybrid_base.1024x1024.city.160k.test.cityc.py │ │ │ └── 
tapadl_fan_hybrid_base.1024x1024.city.160k.py │ ├── tools │ │ └── gen_city_c.py │ └── README.md ├── utils │ ├── __init__.py │ └── scaler.py ├── README.md └── myutils.py ├── imgs └── motivation.jpg ├── requirements.txt ├── LICENSE ├── TAPADL_RVT ├── losses.py ├── README.md └── samplers.py └── README.md /pretrained/README.md: -------------------------------------------------------------------------------- 1 | Please put all the pretrained models here. -------------------------------------------------------------------------------- /TAPADL_FAN/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .fan import * 2 | from .tap_fan import * 3 | -------------------------------------------------------------------------------- /imgs/motivation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyongcs/TAPADL/HEAD/imgs/motivation.jpg -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpn import FPN 2 | 3 | __all__ = ['FPN'] 4 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import add_prefix 2 | 3 | __all__ = ['add_prefix'] 4 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .fan import * 2 | from .swin_utils import * 3 | from .tap_fan import * 4 | 5 | -------------------------------------------------------------------------------- /TAPADL_FAN/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import load_for_transfer_learning, load_for_probing 2 | from .scaler import ApexScaler_SAM 3 | from .mce_utils import * 4 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoding import Encoding 2 | from .wrappers import Upsample, resize 3 | 4 | __all__ = ['Upsample', 'resize', 'Encoding'] 5 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluation import * # noqa: F401, F403 2 | from .seg import * # noqa: F401, F403 3 | from .utils import * # noqa: F401, F403 4 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .collect_env import collect_env 2 | from .logger import get_root_logger, print_log 3 | 4 | __all__ = ['get_root_logger', 'collect_env', 'print_log'] 5 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/seg/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_pixel_sampler import BasePixelSampler 2 | from .ohem_pixel_sampler import 
OHEMPixelSampler 3 | 4 | __all__ = ['BasePixelSampler', 'OHEMPixelSampler'] 5 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | from .cascade_encoder_decoder import CascadeEncoderDecoder 2 | from .encoder_decoder import EncoderDecoder 3 | 4 | __all__ = ['EncoderDecoder', 'CascadeEncoderDecoder'] 5 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/seg/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_pixel_sampler 2 | from .sampler import BasePixelSampler, OHEMPixelSampler 3 | 4 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] 5 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/seg/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import Registry, build_from_cfg 2 | 3 | PIXEL_SAMPLERS = Registry('pixel sampler') 4 | 5 | 6 | def build_pixel_sampler(cfg, **default_args): 7 | """Build pixel sampler for segmentation map.""" 8 | return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) 9 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .class_names import get_classes, get_palette 2 | from .eval_hooks import DistEvalHook, EvalHook 3 | from .metrics import eval_metrics, mean_dice, mean_iou 4 | 5 | __all__ = [ 6 | 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics', 7 | 'get_classes', 'get_palette' 8 | ] 9 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/seg/sampler/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class BasePixelSampler(metaclass=ABCMeta): 5 | """Base class of pixel sampler.""" 6 | 7 | def __init__(self, **kwargs): 8 | pass 9 | 10 | @abstractmethod 11 | def sample(self, seg_logit, seg_label): 12 | """Placeholder for sample function.""" 13 | pass 14 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=1, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=False), 6 | # dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | cudnn_benchmark = True 15 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import inference_segmentor, init_segmentor, show_result_pyplot 2 | from .test import multi_gpu_test, single_gpu_test 3 | from .train import get_root_logger, set_random_seed, train_segmentor 4 | 5 | __all__ = [ 6 | 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 
7 | 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 8 | 'show_result_pyplot' 9 | ] 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .inverted_residual import InvertedResidual, InvertedResidualV3 2 | from .make_divisible import make_divisible 3 | from .res_layer import ResLayer 4 | from .self_attention_block import SelfAttentionBlock 5 | from .up_conv_block import UpConvBlock 6 | 7 | __all__ = [ 8 | 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 9 | 'UpConvBlock', 'InvertedResidualV3' 10 | ] 11 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/schedules/schedule_20k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=20000) 8 | checkpoint_config = dict(by_epoch=False, interval=2000) 9 | evaluation = dict(interval=2000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/schedules/schedule_40k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=40000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/schedules/schedule_40k_8gpu_adamw.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=0.02/100, weight_decay=0.0001) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=40000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=2000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/schedules/schedule_80k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=80000) 8 | checkpoint_config = dict(by_epoch=False, interval=8000) 9 | evaluation = dict(interval=8000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/schedules/schedule_80k_8gpu_adamw.py: 
-------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=0.02/100, weight_decay=0.0001) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=80000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=2000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/schedules/schedule_160k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=160000) 8 | checkpoint_config = dict(by_epoch=False, interval=16000) 9 | evaluation = dict(interval=16000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/schedules/schedule_160k_8gpu_adamw.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=0.02/100, weight_decay=0.0001) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=160000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=2000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/schedules/schedule_40k_8gpu_sgd.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=40000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/schedules/schedule_80k_8gpu_sgd.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=80000) 8 | checkpoint_config = dict(by_epoch=False, interval=8000) 9 | evaluation = dict(interval=8000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | def add_prefix(inputs, prefix): 2 | """Add prefix for dict. 3 | 4 | Args: 5 | inputs (dict): The input dict with str keys. 6 | prefix (str): The prefix to add. 
7 | 8 | Returns: 9 | 10 | dict: The dict with keys updated with ``prefix``. 11 | """ 12 | 13 | outputs = dict() 14 | for name, value in inputs.items(): 15 | outputs[f'{prefix}.{name}'] = value 16 | 17 | return outputs 18 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import collect_env as collect_base_env 2 | from mmcv.utils import get_git_hash 3 | 4 | import mmseg 5 | 6 | 7 | def collect_env(): 8 | """Collect the information of the running environments.""" 9 | env_info = collect_base_env() 10 | env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' 11 | 12 | return env_info 13 | 14 | 15 | if __name__ == '__main__': 16 | for name, val in collect_env().items(): 17 | print('{}: {}'.format(name, val)) 18 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * # noqa: F401,F403 2 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, 3 | build_head, build_loss, build_segmentor) 4 | from .decode_heads import * # noqa: F401,F403 5 | from .losses import * # noqa: F401,F403 6 | from .necks import * # noqa: F401,F403 7 | from .segmentors import * # noqa: F401,F403 8 | 9 | __all__ = [ 10 | 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', 11 | 'build_head', 'build_loss', 'build_segmentor' 12 | ] 13 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import Accuracy, accuracy 2 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 3 | cross_entropy, mask_cross_entropy) 4 | from .lovasz_loss import LovaszLoss 5 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 6 | 7 | __all__ = [ 8 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 9 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 10 | 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss' 11 | ] 12 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
2 | 3 | __version__ = '0.11.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.3.4 2 | matplotlib-base==3.3.4 3 | mccabe==0.6.1 4 | mmcv-full==1.7.0 5 | mmsegmentation==0.30.0 6 | mmdet==2.11.0 7 | mmpycocotools==12.0.3 8 | numpy==1.20.1 9 | onnx==1.11.0 10 | opencv-contrib-python-headless==4.5.4.58 11 | opencv-python==4.5.4.58 12 | opencv-python-headless==4.5.4.58 13 | pillow==8.2.0 14 | prettytable==3.6.0 15 | protobuf==3.17.2 16 | pytorch==1.9.0 17 | pyyaml==5.4.1 18 | scikit-image==0.18.1 19 | scikit-learn==0.24.1 20 | scipy==1.6.2 21 | six==1.15.0 22 | tensorboard==2.4.0 23 | tensorflow==2.4.1 24 | terminaltables==3.1.10 25 | threadpoolctl==2.1.0 26 | tifffile==2020.10.1 27 | timm==0.7.0.dev0 28 | torch==1.8.1 29 | torchaudio==0.9.0 30 | torchvision==0.9.1 31 | tqdm==4.59.0 32 | einops==0.3.2 33 | kornia==0.6.1 -------------------------------------------------------------------------------- /TAPADL_FAN/utils/scaler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.utils import ApexScaler, NativeScaler, dispatch_clip_grad 3 | try: 4 | from apex import amp 5 | has_apex = True 6 | except ImportError: 7 | amp = None 8 | has_apex = False 9 | 10 | class ApexScaler_SAM(ApexScaler): 11 | """Apex AMP loss scaler with a SAM-style two-step update: ``step=1`` is the ascent pass (gradients clipped to norm ``rho`` before stepping to the perturbed weights); ``step=0`` and ``step=2`` are the usual, optionally clipped, descent steps.""" 12 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step=0, rho=0.05): 13 | with amp.scale_loss(loss, optimizer) as scaled_loss: 14 | scaled_loss.backward(create_graph=create_graph) 15 | if step==0 or step==2: 16 | if clip_grad is not None: 17 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) 18 | optimizer.step() 19 | elif step==1: 20 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), rho, norm_type=2.0) 21 | optimizer.step() 22 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/hrf.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class HRFDataset(CustomDataset): 9 | """HRF dataset. 10 | 11 | In segmentation map annotation for HRF, 0 stands for background, which is 12 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 13 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '.png'. 
15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(HRFDataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert osp.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/stare.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class STAREDataset(CustomDataset): 9 | """STARE dataset. 10 | 11 | In segmentation map annotation for STARE, 0 stands for background, which is 12 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 13 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '.ah.png'. 15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(STAREDataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='.ah.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert osp.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/segformer.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | find_unused_parameters = True 4 | model = dict( 5 | type='EncoderDecoder', 6 | pretrained=None, 7 | backbone=dict( 8 | type='mit_b2', 9 | # type='IMTRv21_5', 10 | style='pytorch'), 11 | decode_head=dict( 12 | type='SegFormerHead', 13 | in_channels=[64, 128, 320, 512], 14 | in_index=[0, 1, 2, 3], 15 | feature_strides=[4, 8, 16, 32], 16 | channels=128, 17 | dropout_ratio=0.1, 18 | num_classes=19, 19 | norm_cfg=norm_cfg, 20 | align_corners=False, 21 | decoder_params=dict(), 22 | loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 23 | # model training and testing settings 24 | train_cfg=dict(), 25 | test_cfg=dict(mode='whole')) 26 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/drive.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class DRIVEDataset(CustomDataset): 9 | """DRIVE dataset. 10 | 11 | In segmentation map annotation for DRIVE, 0 stands for background, which is 12 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 13 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '_manual1.png'. 
15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(DRIVEDataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='_manual1.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert osp.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .compose import Compose 2 | from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, 3 | Transpose, to_tensor) 4 | from .loading import LoadAnnotations, LoadImageFromFile 5 | from .test_time_aug import MultiScaleFlipAug 6 | from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, 7 | PhotoMetricDistortion, RandomCrop, RandomFlip, 8 | RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) 9 | 10 | __all__ = [ 11 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 12 | 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 13 | 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 14 | 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 15 | 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' 16 | ] 17 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/lraspp_m-v3-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | backbone=dict( 6 | type='MobileNetV3', 7 | arch='large', 8 | out_indices=(1, 3, 16), 9 | norm_cfg=norm_cfg), 10 | decode_head=dict( 11 | type='LRASPPHead', 12 | in_channels=(16, 24, 960), 13 | in_index=(0, 1, 2), 14 | channels=128, 15 | input_transform='multiple_select', 16 | dropout_ratio=0.1, 17 | num_classes=19, 18 | norm_cfg=norm_cfg, 19 | act_cfg=dict(type='ReLU'), 20 | align_corners=False, 21 | loss_decode=dict( 22 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 23 | # model training and testing settings 24 | train_cfg=dict(), 25 | test_cfg=dict(mode='whole')) 26 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/chase_db1.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class ChaseDB1Dataset(CustomDataset): 9 | """Chase_db1 dataset. 10 | 11 | In segmentation map annotation for Chase_db1, 0 stands for background, 12 | which is included in 2 categories. ``reduce_zero_label`` is fixed to False. 13 | The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '_1stHO.png'. 
15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(ChaseDB1Dataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='_1stHO.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert osp.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/__init__.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | 3 | from .version import __version__, version_info 4 | 5 | MMCV_MIN = '1.1.4' 6 | MMCV_MAX = '1.7.0' 7 | 8 | 9 | def digit_version(version_str): 10 | digit_version = [] 11 | for x in version_str.split('.'): 12 | if x.isdigit(): 13 | digit_version.append(int(x)) 14 | elif x.find('rc') != -1: 15 | patch_version = x.split('rc') 16 | digit_version.append(int(patch_version[0]) - 1) 17 | digit_version.append(int(patch_version[1])) 18 | return digit_version 19 | 20 | 21 | mmcv_min_version = digit_version(MMCV_MIN) 22 | mmcv_max_version = digit_version(MMCV_MAX) 23 | mmcv_version = digit_version(mmcv.__version__) 24 | 25 | 26 | assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 27 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 28 | f'Please install mmcv>={MMCV_MIN}, <={MMCV_MAX}.' 29 | 30 | __all__ = ['__version__', 'version_info'] 31 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .ade import ADE20KDataset 2 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 3 | from .chase_db1 import ChaseDB1Dataset 4 | from .cityscapes import CityscapesDataset 5 | from .custom import CustomDataset 6 | from .dataset_wrappers import ConcatDataset, RepeatDataset 7 | from .drive import DRIVEDataset 8 | from .hrf import HRFDataset 9 | from .pascal_context import PascalContextDataset 10 | from .stare import STAREDataset 11 | from .voc import PascalVOCDataset 12 | from .mapillary import MapillaryDataset 13 | from .cocostuff import CocoStuff 14 | from .acdc import ACDCDataset 15 | 16 | __all__ = [ 17 | 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', 18 | 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', 19 | 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', 20 | 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'MapillaryDataset', 'CocoStuff', 'ACDCDataset' 21 | ] 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 guoyong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/fpn_r50.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 1, 1), 12 | strides=(1, 2, 2, 2), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | neck=dict( 18 | type='FPN', 19 | in_channels=[256, 512, 1024, 2048], 20 | out_channels=256, 21 | num_outs=4), 22 | decode_head=dict( 23 | type='FPNHead', 24 | in_channels=[256, 256, 256, 256], 25 | in_index=[0, 1, 2, 3], 26 | feature_strides=[4, 8, 16, 32], 27 | channels=128, 28 | dropout_ratio=0.1, 29 | num_classes=19, 30 | norm_cfg=norm_cfg, 31 | align_corners=False, 32 | loss_decode=dict( 33 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 34 | # model training and testing settings 35 | train_cfg=dict(), 36 | test_cfg=dict(mode='whole')) 37 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .ann_head import ANNHead 2 | from .apc_head import APCHead 3 | from .aspp_head import ASPPHead 4 | from .cc_head import CCHead 5 | from .da_head import DAHead 6 | from .dm_head import DMHead 7 | from .dnl_head import DNLHead 8 | from .ema_head import EMAHead 9 | from .enc_head import EncHead 10 | from .fcn_head import FCNHead 11 | from .fpn_head import FPNHead 12 | from .gc_head import GCHead 13 | from .lraspp_head import LRASPPHead 14 | from .nl_head import NLHead 15 | from .ocr_head import OCRHead 16 | from .point_head import PointHead 17 | from .psa_head import PSAHead 18 | from .psp_head import PSPHead 19 | from .sep_aspp_head import DepthwiseSeparableASPPHead 20 | from .sep_fcn_head import DepthwiseSeparableFCNHead 21 | from .uper_head import UPerHead 22 | 23 | 24 | from .segformer_head import SegFormerHead 25 | from .setr_up_head import SETRUPHead 26 | 27 | 28 | __all__ = [ 29 | 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', 30 | 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 31 | 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', 32 | 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 33 | 'SegFormerHead', 34 | ] 35 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/voc.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class 
PascalVOCDataset(CustomDataset): 9 | """Pascal VOC dataset. 10 | 11 | Args: 12 | split (str): Split txt file for Pascal VOC. 13 | """ 14 | 15 | CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 16 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 17 | 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 18 | 'train', 'tvmonitor') 19 | 20 | PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], 21 | [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], 22 | [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], 23 | [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], 24 | [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] 25 | 26 | def __init__(self, split, **kwargs): 27 | super(PascalVOCDataset, self).__init__( 28 | img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) 29 | assert osp.exists(self.img_dir) and self.split is not None 30 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/cgnet.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | backbone=dict( 6 | type='CGNet', 7 | norm_cfg=norm_cfg, 8 | in_channels=3, 9 | num_channels=(32, 64, 128), 10 | num_blocks=(3, 21), 11 | dilations=(2, 4), 12 | reductions=(8, 16)), 13 | decode_head=dict( 14 | type='FCNHead', 15 | in_channels=256, 16 | in_index=2, 17 | channels=256, 18 | num_convs=0, 19 | concat_input=False, 20 | dropout_ratio=0, 21 | num_classes=19, 22 | norm_cfg=norm_cfg, 23 | loss_decode=dict( 24 | type='CrossEntropyLoss', 25 | use_sigmoid=False, 26 | loss_weight=1.0, 27 | class_weight=[ 28 | 2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352, 29 | 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905, 30 | 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587, 31 | 10.396974, 10.055647 32 | ])), 33 | # model training and testing settings 34 | train_cfg=dict(sampler=None), 35 | test_cfg=dict(mode='whole')) 36 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 2 | """Make divisible function. 3 | 4 | This function rounds the channel number to the nearest value that can be 5 | divisible by the divisor. It is taken from the original tf repo. It ensures 6 | that all layers have a channel number that is divisible by divisor. It can 7 | be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 8 | 9 | Args: 10 | value (int): The original channel number. 11 | divisor (int): The divisor to fully divide the channel number. 12 | min_value (int): The minimum value of the output channel. 13 | Default: None, means that the minimum value equal to the divisor. 14 | min_ratio (float): The minimum ratio of the rounded channel number to 15 | the original channel number. Default: 0.9. 16 | 17 | Returns: 18 | int: The modified output channel number. 19 | """ 20 | 21 | if min_value is None: 22 | min_value = divisor 23 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 24 | # Make sure that round down does not go down by more than (1-min_ratio). 
25 | if new_value < min_ratio * value: 26 | new_value += divisor 27 | return new_value 28 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/ccnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='CCHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | recurrence=2, 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/danet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='DAHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | pam_channels=64, 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/cc_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..builder import HEADS 4 | from .fcn_head import FCNHead 5 | 6 | try: 7 | from mmcv.ops import CrissCrossAttention 8 | except ModuleNotFoundError: 9 | CrissCrossAttention = None 10 | 11 | 12 | @HEADS.register_module() 13 | class CCHead(FCNHead): 14 | """CCNet: Criss-Cross Attention for Semantic Segmentation. 15 | 16 | This head is the implementation of `CCNet 17 | <https://arxiv.org/abs/1811.11721>`_. 
18 | 19 | Args: 20 | recurrence (int): Number of recurrence of Criss Cross Attention 21 | module. Default: 2. 22 | """ 23 | 24 | def __init__(self, recurrence=2, **kwargs): 25 | if CrissCrossAttention is None: 26 | raise RuntimeError('Please install mmcv-full for ' 27 | 'CrissCrossAttention ops') 28 | super(CCHead, self).__init__(num_convs=2, **kwargs) 29 | self.recurrence = recurrence 30 | self.cca = CrissCrossAttention(self.channels) 31 | 32 | def forward(self, inputs): 33 | """Forward function.""" 34 | x = self._transform_inputs(inputs) 35 | output = self.convs[0](x) 36 | for _ in range(self.recurrence): 37 | output = self.cca(output) 38 | output = self.convs[1](output) 39 | if self.concat_input: 40 | output = self.conv_cat(torch.cat([x, output], dim=1)) 41 | output = self.cls_seg(output) 42 | return output 43 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/pspnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='PSPHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | pool_scales=(1, 2, 3, 6), 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/deeplabv3_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='ASPPHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | dilations=(1, 12, 24, 36), 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, 
loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/fcn_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='FCNHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | num_convs=2, 23 | concat_input=True, 24 | dropout_ratio=0.1, 25 | num_classes=19, 26 | norm_cfg=norm_cfg, 27 | align_corners=False, 28 | loss_decode=dict( 29 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 30 | auxiliary_head=dict( 31 | type='FCNHead', 32 | in_channels=1024, 33 | in_index=2, 34 | channels=256, 35 | num_convs=1, 36 | concat_input=False, 37 | dropout_ratio=0.1, 38 | num_classes=19, 39 | norm_cfg=norm_cfg, 40 | align_corners=False, 41 | loss_decode=dict( 42 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 43 | # model training and testing settings 44 | train_cfg=dict(), 45 | test_cfg=dict(mode='whole')) 46 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/upernet_r50.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 1, 1), 12 | strides=(1, 2, 2, 2), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='UPerHead', 19 | in_channels=[256, 512, 1024, 2048], 20 | in_index=[0, 1, 2, 3], 21 | pool_scales=(1, 2, 3, 6), 22 | channels=512, 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/apcnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | 
norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='APCHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | pool_scales=(1, 2, 3, 6), 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=dict(type='SyncBN', requires_grad=True), 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/dmnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='DMHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | filter_sizes=(1, 3, 5, 7), 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=dict(type='SyncBN', requires_grad=True), 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/dnl_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='DNLHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | dropout_ratio=0.1, 23 | reduction=2, 24 | use_scale=True, 25 | mode='embedded_gaussian', 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=1024, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | 
norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/nonlocal_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='NLHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | dropout_ratio=0.1, 23 | reduction=2, 24 | use_scale=True, 25 | mode='embedded_gaussian', 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=1024, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/gcnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='GCHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | ratio=1 / 4., 23 | pooling_type='att', 24 | fusion_types=('channel_add', ), 25 | dropout_ratio=0.1, 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=1024, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/emanet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | 
pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='EMAHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=256, 22 | ema_channels=512, 23 | num_bases=64, 24 | num_stages=3, 25 | momentum=0.1, 26 | dropout_ratio=0.1, 27 | num_classes=19, 28 | norm_cfg=norm_cfg, 29 | align_corners=False, 30 | loss_decode=dict( 31 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 32 | auxiliary_head=dict( 33 | type='FCNHead', 34 | in_channels=1024, 35 | in_index=2, 36 | channels=256, 37 | num_convs=1, 38 | concat_input=False, 39 | dropout_ratio=0.1, 40 | num_classes=19, 41 | norm_cfg=norm_cfg, 42 | align_corners=False, 43 | loss_decode=dict( 44 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 45 | # model training and testing settings 46 | train_cfg=dict(), 47 | test_cfg=dict(mode='whole')) 48 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/ann_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='ANNHead', 19 | in_channels=[1024, 2048], 20 | in_index=[2, 3], 21 | channels=512, 22 | project_channels=256, 23 | query_scales=(1, ), 24 | key_pool_scales=(1, 3, 6, 8), 25 | dropout_ratio=0.1, 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=1024, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/deeplabv3plus_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='DepthwiseSeparableASPPHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | dilations=(1, 12, 24, 36), 23 | c1_in_channels=256, 24 | c1_channels=48, 25 | dropout_ratio=0.1, 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | 
loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=1024, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/ocrnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='CascadeEncoderDecoder', 5 | num_stages=2, 6 | pretrained='open-mmlab://resnet50_v1c', 7 | backbone=dict( 8 | type='ResNetV1c', 9 | depth=50, 10 | num_stages=4, 11 | out_indices=(0, 1, 2, 3), 12 | dilations=(1, 1, 2, 4), 13 | strides=(1, 2, 1, 1), 14 | norm_cfg=norm_cfg, 15 | norm_eval=False, 16 | style='pytorch', 17 | contract_dilation=True), 18 | decode_head=[ 19 | dict( 20 | type='FCNHead', 21 | in_channels=1024, 22 | in_index=2, 23 | channels=256, 24 | num_convs=1, 25 | concat_input=False, 26 | dropout_ratio=0.1, 27 | num_classes=19, 28 | norm_cfg=norm_cfg, 29 | align_corners=False, 30 | loss_decode=dict( 31 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 32 | dict( 33 | type='OCRHead', 34 | in_channels=2048, 35 | in_index=3, 36 | channels=512, 37 | ocr_channels=256, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) 44 | ], 45 | # model training and testing settings 46 | train_cfg=dict(), 47 | test_cfg=dict(mode='whole')) 48 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/psanet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='PSAHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | mask_size=(97, 97), 23 | psa_type='bi-direction', 24 | compact=False, 25 | shrink_factor=2, 26 | normalization_factor=1.0, 27 | psa_softmax=True, 28 | dropout_ratio=0.1, 29 | num_classes=19, 30 | norm_cfg=norm_cfg, 31 | align_corners=False, 32 | loss_decode=dict( 33 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 34 | auxiliary_head=dict( 35 | type='FCNHead', 36 | in_channels=1024, 37 | in_index=2, 38 | channels=256, 39 | num_convs=1, 40 | concat_input=False, 41 | dropout_ratio=0.1, 42 | num_classes=19, 43 | norm_cfg=norm_cfg, 44 | align_corners=False, 45 | loss_decode=dict( 46 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 47 | # model training and testing settings 48 | train_cfg=dict(), 49 | test_cfg=dict(mode='whole')) 50 | 
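A note on how these _base_ model files are consumed: none of them is run directly. A downstream config lists them under _base_, mmcv merges the dicts, and the merged model dict is built through the registries defined in mmseg/models/builder.py further down in this dump. A minimal sketch of that flow, assuming mmcv and this repo's mmseg package are importable; the config path is illustrative:

from mmcv import Config
from mmseg.models import build_segmentor

# Load a model config (any of the _base_ model files here works the same way).
cfg = Config.fromfile('local_configs/_base_/models/psanet_r50-d8.py')
# Skip the 'open-mmlab://' pretrained-weight download for this sketch.
cfg.model.pretrained = None
# build_segmentor dispatches on model['type'] through the SEGMENTORS registry
# and recursively builds the backbone, decode_head and auxiliary_head configs.
model = build_segmentor(cfg.model)
model.eval()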
-------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/encnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='EncHead', 19 | in_channels=[512, 1024, 2048], 20 | in_index=(1, 2, 3), 21 | channels=512, 22 | num_codes=32, 23 | use_se_loss=True, 24 | add_lateral=False, 25 | dropout_ratio=0.1, 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 31 | loss_se_decode=dict( 32 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)), 33 | auxiliary_head=dict( 34 | type='FCNHead', 35 | in_channels=1024, 36 | in_index=2, 37 | channels=256, 38 | num_convs=1, 39 | concat_input=False, 40 | dropout_ratio=0.1, 41 | num_classes=19, 42 | norm_cfg=norm_cfg, 43 | align_corners=False, 44 | loss_decode=dict( 45 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 46 | # model training and testing settings 47 | train_cfg=dict(), 48 | test_cfg=dict(mode='whole')) 49 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | from mmcv.utils import build_from_cfg 4 | 5 | from ..builder import PIPELINES 6 | 7 | 8 | @PIPELINES.register_module() 9 | class Compose(object): 10 | """Compose multiple transforms sequentially. 11 | 12 | Args: 13 | transforms (Sequence[dict | callable]): Sequence of transform object or 14 | config dict to be composed. 15 | """ 16 | 17 | def __init__(self, transforms): 18 | assert isinstance(transforms, collections.abc.Sequence) 19 | self.transforms = [] 20 | for transform in transforms: 21 | if isinstance(transform, dict): 22 | transform = build_from_cfg(transform, PIPELINES) 23 | self.transforms.append(transform) 24 | elif callable(transform): 25 | self.transforms.append(transform) 26 | else: 27 | raise TypeError('transform must be callable or a dict') 28 | 29 | def __call__(self, data): 30 | """Call function to apply transforms sequentially. 31 | 32 | Args: 33 | data (dict): A result dict contains the data to transform. 34 | 35 | Returns: 36 | dict: Transformed data. 
37 | """ 38 | 39 | for t in self.transforms: 40 | data = t(data) 41 | if data is None: 42 | return None 43 | return data 44 | 45 | def __repr__(self): 46 | format_string = self.__class__.__name__ + '(' 47 | for t in self.transforms: 48 | format_string += '\n' 49 | format_string += f' {t}' 50 | format_string += '\n)' 51 | return format_string 52 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 2 | 3 | from .builder import DATASETS 4 | 5 | 6 | @DATASETS.register_module() 7 | class ConcatDataset(_ConcatDataset): 8 | """A wrapper of concatenated dataset. 9 | 10 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 11 | concat the group flag for image aspect ratio. 12 | 13 | Args: 14 | datasets (list[:obj:`Dataset`]): A list of datasets. 15 | """ 16 | 17 | def __init__(self, datasets): 18 | super(ConcatDataset, self).__init__(datasets) 19 | self.CLASSES = datasets[0].CLASSES 20 | self.PALETTE = datasets[0].PALETTE 21 | 22 | 23 | @DATASETS.register_module() 24 | class RepeatDataset(object): 25 | """A wrapper of repeated dataset. 26 | 27 | The length of repeated dataset will be `times` larger than the original 28 | dataset. This is useful when the data loading time is long but the dataset 29 | is small. Using RepeatDataset can reduce the data loading time between 30 | epochs. 31 | 32 | Args: 33 | dataset (:obj:`Dataset`): The dataset to be repeated. 34 | times (int): Repeat times. 35 | """ 36 | 37 | def __init__(self, dataset, times): 38 | self.dataset = dataset 39 | self.times = times 40 | self.CLASSES = dataset.CLASSES 41 | self.PALETTE = dataset.PALETTE 42 | self._ori_len = len(self.dataset) 43 | 44 | def __getitem__(self, idx): 45 | """Get item from original dataset.""" 46 | return self.dataset[idx % self._ori_len] 47 | 48 | def __len__(self): 49 | """The length is multiplied by ``times``""" 50 | return self.times * self._ori_len 51 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/pspnet_unet_s5-d16.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='UNet', 8 | in_channels=3, 9 | base_channels=64, 10 | num_stages=5, 11 | strides=(1, 1, 1, 1, 1), 12 | enc_num_convs=(2, 2, 2, 2, 2), 13 | dec_num_convs=(2, 2, 2, 2), 14 | downsamples=(True, True, True, True), 15 | enc_dilations=(1, 1, 1, 1, 1), 16 | dec_dilations=(1, 1, 1, 1), 17 | with_cp=False, 18 | conv_cfg=None, 19 | norm_cfg=norm_cfg, 20 | act_cfg=dict(type='ReLU'), 21 | upsample_cfg=dict(type='InterpConv'), 22 | norm_eval=False), 23 | decode_head=dict( 24 | type='PSPHead', 25 | in_channels=64, 26 | in_index=4, 27 | channels=16, 28 | pool_scales=(1, 2, 3, 6), 29 | dropout_ratio=0.1, 30 | num_classes=2, 31 | norm_cfg=norm_cfg, 32 | align_corners=False, 33 | loss_decode=dict( 34 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 35 | auxiliary_head=dict( 36 | type='FCNHead', 37 | in_channels=128, 38 | in_index=3, 39 | channels=64, 40 | num_convs=1, 41 | concat_input=False, 42 | dropout_ratio=0.1, 43 | num_classes=2, 44 | norm_cfg=norm_cfg, 45 | align_corners=False, 46 | loss_decode=dict( 47 | 
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 48 | # model training and testing settings 49 | train_cfg=dict(), 50 | test_cfg=dict(mode='slide', crop_size=256, stride=170)) 51 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/deeplabv3_unet_s5-d16.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='UNet', 8 | in_channels=3, 9 | base_channels=64, 10 | num_stages=5, 11 | strides=(1, 1, 1, 1, 1), 12 | enc_num_convs=(2, 2, 2, 2, 2), 13 | dec_num_convs=(2, 2, 2, 2), 14 | downsamples=(True, True, True, True), 15 | enc_dilations=(1, 1, 1, 1, 1), 16 | dec_dilations=(1, 1, 1, 1), 17 | with_cp=False, 18 | conv_cfg=None, 19 | norm_cfg=norm_cfg, 20 | act_cfg=dict(type='ReLU'), 21 | upsample_cfg=dict(type='InterpConv'), 22 | norm_eval=False), 23 | decode_head=dict( 24 | type='ASPPHead', 25 | in_channels=64, 26 | in_index=4, 27 | channels=16, 28 | dilations=(1, 12, 24, 36), 29 | dropout_ratio=0.1, 30 | num_classes=2, 31 | norm_cfg=norm_cfg, 32 | align_corners=False, 33 | loss_decode=dict( 34 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 35 | auxiliary_head=dict( 36 | type='FCNHead', 37 | in_channels=128, 38 | in_index=3, 39 | channels=64, 40 | num_convs=1, 41 | concat_input=False, 42 | dropout_ratio=0.1, 43 | num_classes=2, 44 | norm_cfg=norm_cfg, 45 | align_corners=False, 46 | loss_decode=dict( 47 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 48 | # model training and testing settings 49 | train_cfg=dict(), 50 | test_cfg=dict(mode='slide', crop_size=256, stride=170)) 51 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/fcn_unet_s5-d16.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='UNet', 8 | in_channels=3, 9 | base_channels=64, 10 | num_stages=5, 11 | strides=(1, 1, 1, 1, 1), 12 | enc_num_convs=(2, 2, 2, 2, 2), 13 | dec_num_convs=(2, 2, 2, 2), 14 | downsamples=(True, True, True, True), 15 | enc_dilations=(1, 1, 1, 1, 1), 16 | dec_dilations=(1, 1, 1, 1), 17 | with_cp=False, 18 | conv_cfg=None, 19 | norm_cfg=norm_cfg, 20 | act_cfg=dict(type='ReLU'), 21 | upsample_cfg=dict(type='InterpConv'), 22 | norm_eval=False), 23 | decode_head=dict( 24 | type='FCNHead', 25 | in_channels=64, 26 | in_index=4, 27 | channels=64, 28 | num_convs=1, 29 | concat_input=False, 30 | dropout_ratio=0.1, 31 | num_classes=2, 32 | norm_cfg=norm_cfg, 33 | align_corners=False, 34 | loss_decode=dict( 35 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 36 | auxiliary_head=dict( 37 | type='FCNHead', 38 | in_channels=128, 39 | in_index=3, 40 | channels=64, 41 | num_convs=1, 42 | concat_input=False, 43 | dropout_ratio=0.1, 44 | num_classes=2, 45 | norm_cfg=norm_cfg, 46 | align_corners=False, 47 | loss_decode=dict( 48 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 49 | # model training and testing settings 50 | train_cfg=dict(), 51 | test_cfg=dict(mode='slide', crop_size=256, stride=170)) 52 | -------------------------------------------------------------------------------- 
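The three UNet configs above all evaluate with test_cfg=dict(mode='slide', crop_size=256, stride=170): the image is tiled into 256x256 crops taken every 170 pixels, so neighbouring crops overlap by 86 pixels; per-crop logits are accumulated on a full-size canvas and each pixel is normalised by the number of crops that covered it. A rough sketch of that tiling arithmetic, not the repo's exact implementation (mmseg does this inside the segmentor's slide_inference):

def slide_logits(model, img, num_classes, crop=256, stride=170):
    # img: (1, C, H, W) tensor; model(patch) returns (1, num_classes, h, w) logits.
    _, _, h, w = img.shape
    logits = img.new_zeros((1, num_classes, h, w))
    count = img.new_zeros((1, 1, h, w))
    ys = list(range(0, max(h - crop, 0) + 1, stride))
    xs = list(range(0, max(w - crop, 0) + 1, stride))
    if ys[-1] + crop < h:  # make sure the bottom border is covered
        ys.append(h - crop)
    if xs[-1] + crop < w:  # ... and the right border
        xs.append(w - crop)
    for y in ys:
        for x in xs:
            patch = img[:, :, y:y + crop, x:x + crop]
            logits[:, :, y:y + crop, x:x + crop] += model(patch)
            count[:, :, y:y + crop, x:x + crop] += 1
    return logits / count  # the tiles cover every pixel, so count >= 1

The fan_hybrid test configs later in this dump use the same mode at a larger scale, crop_size=(1024,1024) with stride=(768,768), i.e. a 256-pixel overlap between neighbouring crops.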
/TAPADL_FAN/segmentation/mmseg/models/decode_heads/nl_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmcv.cnn import NonLocal2d 3 | 4 | from ..builder import HEADS 5 | from .fcn_head import FCNHead 6 | 7 | 8 | @HEADS.register_module() 9 | class NLHead(FCNHead): 10 | """Non-local Neural Networks. 11 | 12 | This head is the implementation of `NLNet 13 | <https://arxiv.org/abs/1711.07971>`_. 14 | 15 | Args: 16 | reduction (int): Reduction factor of projection transform. Default: 2. 17 | use_scale (bool): Whether to scale pairwise_weight by 18 | sqrt(1/inter_channels). Default: True. 19 | mode (str): The nonlocal mode. Options are 'embedded_gaussian', 20 | 'dot_product'. Default: 'embedded_gaussian'. 21 | """ 22 | 23 | def __init__(self, 24 | reduction=2, 25 | use_scale=True, 26 | mode='embedded_gaussian', 27 | **kwargs): 28 | super(NLHead, self).__init__(num_convs=2, **kwargs) 29 | self.reduction = reduction 30 | self.use_scale = use_scale 31 | self.mode = mode 32 | self.nl_block = NonLocal2d( 33 | in_channels=self.channels, 34 | reduction=self.reduction, 35 | use_scale=self.use_scale, 36 | conv_cfg=self.conv_cfg, 37 | norm_cfg=self.norm_cfg, 38 | mode=self.mode) 39 | 40 | def forward(self, inputs): 41 | """Forward function.""" 42 | x = self._transform_inputs(inputs) 43 | output = self.convs[0](x) 44 | output = self.nl_block(output) 45 | output = self.convs[1](output) 46 | if self.concat_input: 47 | output = self.conv_cat(torch.cat([x, output], dim=1)) 48 | output = self.cls_seg(output) 49 | return output 50 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/gc_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmcv.cnn import ContextBlock 3 | 4 | from ..builder import HEADS 5 | from .fcn_head import FCNHead 6 | 7 | 8 | @HEADS.register_module() 9 | class GCHead(FCNHead): 10 | """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. 11 | 12 | This head is the implementation of `GCNet 13 | <https://arxiv.org/abs/1904.11492>`_. 14 | 15 | Args: 16 | ratio (float): Multiplier of channels ratio. Default: 1/4. 17 | pooling_type (str): The pooling type of context aggregation. 18 | Options are 'att', 'avg'. Default: 'att'. 19 | fusion_types (tuple[str]): The fusion type for feature fusion. 20 | Options are 'channel_add', 'channel_mul'.
Default: ('channel_add',) 21 | """ 22 | 23 | def __init__(self, 24 | ratio=1 / 4., 25 | pooling_type='att', 26 | fusion_types=('channel_add', ), 27 | **kwargs): 28 | super(GCHead, self).__init__(num_convs=2, **kwargs) 29 | self.ratio = ratio 30 | self.pooling_type = pooling_type 31 | self.fusion_types = fusion_types 32 | self.gc_block = ContextBlock( 33 | in_channels=self.channels, 34 | ratio=self.ratio, 35 | pooling_type=self.pooling_type, 36 | fusion_types=self.fusion_types) 37 | 38 | def forward(self, inputs): 39 | """Forward function.""" 40 | x = self._transform_inputs(inputs) 41 | output = self.convs[0](x) 42 | output = self.gc_block(output) 43 | output = self.convs[1](output) 44 | if self.concat_input: 45 | output = self.conv_cat(torch.cat([x, output], dim=1)) 46 | output = self.cls_seg(output) 47 | return output 48 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/fcn_hr18.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://msra/hrnetv2_w18', 6 | backbone=dict( 7 | type='HRNet', 8 | norm_cfg=norm_cfg, 9 | norm_eval=False, 10 | extra=dict( 11 | stage1=dict( 12 | num_modules=1, 13 | num_branches=1, 14 | block='BOTTLENECK', 15 | num_blocks=(4, ), 16 | num_channels=(64, )), 17 | stage2=dict( 18 | num_modules=1, 19 | num_branches=2, 20 | block='BASIC', 21 | num_blocks=(4, 4), 22 | num_channels=(18, 36)), 23 | stage3=dict( 24 | num_modules=4, 25 | num_branches=3, 26 | block='BASIC', 27 | num_blocks=(4, 4, 4), 28 | num_channels=(18, 36, 72)), 29 | stage4=dict( 30 | num_modules=3, 31 | num_branches=4, 32 | block='BASIC', 33 | num_blocks=(4, 4, 4, 4), 34 | num_channels=(18, 36, 72, 144)))), 35 | decode_head=dict( 36 | type='FCNHead', 37 | in_channels=[18, 36, 72, 144], 38 | in_index=(0, 1, 2, 3), 39 | channels=sum([18, 36, 72, 144]), 40 | input_transform='resize_concat', 41 | kernel_size=1, 42 | num_convs=1, 43 | concat_input=False, 44 | dropout_ratio=-1, 45 | num_classes=19, 46 | norm_cfg=norm_cfg, 47 | align_corners=False, 48 | loss_decode=dict( 49 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 50 | # model training and testing settings 51 | train_cfg=dict(), 52 | test_cfg=dict(mode='whole')) 53 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/pointrend_r50.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='CascadeEncoderDecoder', 5 | num_stages=2, 6 | pretrained='open-mmlab://resnet50_v1c', 7 | backbone=dict( 8 | type='ResNetV1c', 9 | depth=50, 10 | num_stages=4, 11 | out_indices=(0, 1, 2, 3), 12 | dilations=(1, 1, 1, 1), 13 | strides=(1, 2, 2, 2), 14 | norm_cfg=norm_cfg, 15 | norm_eval=False, 16 | style='pytorch', 17 | contract_dilation=True), 18 | neck=dict( 19 | type='FPN', 20 | in_channels=[256, 512, 1024, 2048], 21 | out_channels=256, 22 | num_outs=4), 23 | decode_head=[ 24 | dict( 25 | type='FPNHead', 26 | in_channels=[256, 256, 256, 256], 27 | in_index=[0, 1, 2, 3], 28 | feature_strides=[4, 8, 16, 32], 29 | channels=128, 30 | dropout_ratio=-1, 31 | num_classes=19, 32 | norm_cfg=norm_cfg, 33 | align_corners=False, 34 | loss_decode=dict( 35 | type='CrossEntropyLoss',
use_sigmoid=False, loss_weight=1.0)), 36 | dict( 37 | type='PointHead', 38 | in_channels=[256], 39 | in_index=[0], 40 | channels=256, 41 | num_fcs=3, 42 | coarse_pred_each_layer=True, 43 | dropout_ratio=-1, 44 | num_classes=19, 45 | align_corners=False, 46 | loss_decode=dict( 47 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) 48 | ], 49 | # model training and testing settings 50 | train_cfg=dict( 51 | num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75), 52 | test_cfg=dict( 53 | mode='whole', 54 | subdivision_steps=2, 55 | subdivision_num_points=8196, 56 | scale_factor=2)) 57 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/tools/gen_city_c.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://github.com/NVlabs/FAN/blob/main/LICENSE 6 | 7 | # Copyright (c) Open-MMLab. All rights reserved. 8 | 9 | import numpy as np 10 | from imagecorruptions import corrupt 11 | import random 12 | import os 13 | import mmcv 14 | 15 | 16 | random.seed(8) # for reproducibility 17 | np.random.seed(8) 18 | 19 | 20 | corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 21 | 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 22 | 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression', 23 | 'speckle_noise', 'gaussian_blur', 'spatter', 'saturate'] 24 | 25 | def perturb(i, p, s): 26 | img = corrupt(i, corruption_name=p, severity=s) 27 | return img 28 | 29 | 30 | def convert_img_path(ori_path, suffix): 31 | new_path = ori_path.replace('clean', suffix) 32 | assert new_path != ori_path 33 | return new_path 34 | 35 | def main(): 36 | img_dir = '../ade20k_c/clean/' 37 | severity = [1, 2, 3, 4, 5] 38 | num_imgs = 5000 39 | for p in corruptions: 40 | print("\n ### gen corruption:{} ###".format(p)) 41 | prog_bar = mmcv.ProgressBar(num_imgs) 42 | for img_path in mmcv.scandir(img_dir, suffix='jpg', recursive=True): 43 | img_path = os.path.join(img_dir, img_path) 44 | img = mmcv.imread(img_path) 45 | prog_bar.update() 46 | for s in severity: 47 | perturbed_img = perturb(img, p, s) 48 | img_suffix = p+"/"+str(s) 49 | perturbed_img_path = convert_img_path(img_path, img_suffix) 50 | mmcv.imwrite(perturbed_img, perturbed_img_path, auto_mkdir=True) 51 | 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/fast_scnn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01) 3 | model = dict( 4 | type='EncoderDecoder', 5 | backbone=dict( 6 | type='FastSCNN', 7 | downsample_dw_channels=(32, 48), 8 | global_in_channels=64, 9 | global_block_channels=(64, 96, 128), 10 | global_block_strides=(2, 2, 1), 11 | global_out_channels=128, 12 | higher_in_channels=64, 13 | lower_in_channels=128, 14 | fusion_out_channels=128, 15 | out_indices=(0, 1, 2), 16 | norm_cfg=norm_cfg, 17 | align_corners=False), 18 | decode_head=dict( 19 | type='DepthwiseSeparableFCNHead', 20 | in_channels=128, 21 | channels=128, 22 | concat_input=False, 23 | num_classes=19, 24 | in_index=-1, 25 | norm_cfg=norm_cfg, 26 | 
align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), 29 | auxiliary_head=[ 30 | dict( 31 | type='FCNHead', 32 | in_channels=128, 33 | channels=32, 34 | num_convs=1, 35 | num_classes=19, 36 | in_index=-2, 37 | norm_cfg=norm_cfg, 38 | concat_input=False, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), 42 | dict( 43 | type='FCNHead', 44 | in_channels=64, 45 | channels=32, 46 | num_convs=1, 47 | num_classes=19, 48 | in_index=-3, 49 | norm_cfg=norm_cfg, 50 | concat_input=False, 51 | align_corners=False, 52 | loss_decode=dict( 53 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), 54 | ], 55 | # model training and testing settings 56 | train_cfg=dict(), 57 | test_cfg=dict(mode='whole')) 58 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from mmcv.utils import get_logger 4 | 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO): 7 | """Get the root logger. 8 | 9 | The logger will be initialized if it has not been initialized. By default a 10 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 11 | also be added. The name of the root logger is the top-level package name, 12 | e.g., "mmseg". 13 | 14 | Args: 15 | log_file (str | None): The log filename. If specified, a FileHandler 16 | will be added to the root logger. 17 | log_level (int): The root logger level. Note that only the process of 18 | rank 0 is affected, while other processes will set the level to 19 | "Error" and be silent most of the time. 20 | 21 | Returns: 22 | logging.Logger: The root logger. 23 | """ 24 | 25 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 26 | 27 | return logger 28 | 29 | def print_log(msg, logger=None, level=logging.INFO): 30 | """Print a log message. 31 | Args: 32 | msg (str): The message to be logged. 33 | logger (logging.Logger | str | None): The logger to be used. Some 34 | special loggers are: 35 | - "root": the root logger obtained with `get_root_logger()`. 36 | - "silent": no message will be printed. 37 | - None: The `print()` method will be used to print log messages. 38 | level (int): Logging level. Only available when `logger` is a Logger 39 | object or "root". 40 | """ 41 | if logger is None: 42 | print(msg) 43 | elif logger == 'root': 44 | _logger = get_root_logger() 45 | _logger.log(level, msg) 46 | elif isinstance(logger, logging.Logger): 47 | logger.log(level, msg) 48 | elif logger != 'silent': 49 | raise TypeError( 50 | 'logger should be either a logging.Logger object, "root", ' 51 | '"silent" or None, but got {}'.format(logger)) -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from mmcv.utils import Registry, build_from_cfg 4 | from torch import nn 5 | 6 | BACKBONES = Registry('backbone') 7 | NECKS = Registry('neck') 8 | HEADS = Registry('head') 9 | LOSSES = Registry('loss') 10 | SEGMENTORS = Registry('segmentor') 11 | 12 | 13 | def build(cfg, registry, default_args=None): 14 | """Build a module. 15 | 16 | Args: 17 | cfg (dict, list[dict]): The config of modules, it is either a dict 18 | or a list of configs.
19 | registry (:obj:`Registry`): A registry the module belongs to. 20 | default_args (dict, optional): Default arguments to build the module. 21 | Defaults to None. 22 | 23 | Returns: 24 | nn.Module: A built nn module. 25 | """ 26 | 27 | if isinstance(cfg, list): 28 | modules = [ 29 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 30 | ] 31 | return nn.Sequential(*modules) 32 | else: 33 | return build_from_cfg(cfg, registry, default_args) 34 | 35 | 36 | def build_backbone(cfg): 37 | """Build backbone.""" 38 | return build(cfg, BACKBONES) 39 | 40 | 41 | def build_neck(cfg): 42 | """Build neck.""" 43 | return build(cfg, NECKS) 44 | 45 | 46 | def build_head(cfg): 47 | """Build head.""" 48 | return build(cfg, HEADS) 49 | 50 | 51 | def build_loss(cfg): 52 | """Build loss.""" 53 | return build(cfg, LOSSES) 54 | 55 | 56 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 57 | """Build segmentor.""" 58 | if train_cfg is not None or test_cfg is not None: 59 | warnings.warn( 60 | 'train_cfg and test_cfg is deprecated, ' 61 | 'please specify them in model', UserWarning) 62 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 63 | 'train_cfg specified in both outer field and model field ' 64 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 65 | 'test_cfg specified in both outer field and model field ' 66 | return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) 67 | 68 | 69 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/fan/fan_hybrid/tapfan_hybrid_base.1024x1024.city.160k.test.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/segformer.py', 3 | '../../_base_/datasets/cityscapes_1024x1024_repeat.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k_8gpu_adamw.py' 6 | ] 7 | 8 | # model settings 9 | norm_cfg = dict(type='SyncBN', requires_grad=True) 10 | find_unused_parameters = True 11 | model = dict( 12 | type='EncoderDecoder', 13 | backbone=dict( 14 | type='tap_fan_base_16_p4_hybrid', 15 | style='pytorch'), 16 | decode_head=dict( 17 | type='SegFormerHead', 18 | in_channels=[128, 256, 448, 448], 19 | in_index=[0, 1, 2, 3], 20 | feature_strides=[4, 8, 16, 32], 21 | channels=256, 22 | dropout_ratio=0.1, 23 | num_classes=19, 24 | norm_cfg=norm_cfg, 25 | align_corners=False, 26 | decoder_params=dict(embed_dim=768), 27 | loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 28 | # model training and testing settings 29 | train_cfg=dict(adloss=1), 30 | test_cfg=dict(mode='slide', crop_size=(1024,1024), stride=(768,768))) 31 | 32 | # data 33 | data = dict(samples_per_gpu=1) 34 | evaluation = dict(interval=4000, metric='mIoU') 35 | 36 | # optimizer 37 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 38 | paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.), 39 | 'norm': dict(decay_mult=0.), 40 | 'head': dict(lr_mult=10.) 
41 | })) 42 | 43 | lr_config = dict(_delete_=True, policy='poly', 44 | warmup='linear', 45 | warmup_iters=1500, 46 | warmup_ratio=1e-6, 47 | power=1.0, min_lr=0.0, by_epoch=False) 48 | 49 | # uncomment to use fp16 training 50 | # optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic') 51 | # fp16 = dict() 52 | # fp16 = dict(loss_scale='dynamic') 53 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize(input, 9 | size=None, 10 | scale_factor=None, 11 | mode='nearest', 12 | align_corners=None, 13 | warning=True): 14 | if warning: 15 | if size is not None and align_corners: 16 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 17 | output_h, output_w = tuple(int(x) for x in size) 18 | if output_h > input_h or output_w > input_w: 19 | if ((output_h > 1 and output_w > 1 and input_h > 1 20 | and input_w > 1) and (output_h - 1) % (input_h - 1) 21 | and (output_w - 1) % (input_w - 1)): 22 | warnings.warn( 23 | f'When align_corners={align_corners}, ' 24 | 'the output would be more aligned if ' 25 | f'input size {(input_h, input_w)} is `x+1` and ' 26 | f'out size {(output_h, output_w)} is `nx+1`') 27 | if isinstance(size, torch.Size): 28 | size = tuple(int(x) for x in size) 29 | return F.interpolate(input, size, scale_factor, mode, align_corners) 30 | 31 | 32 | class Upsample(nn.Module): 33 | 34 | def __init__(self, 35 | size=None, 36 | scale_factor=None, 37 | mode='nearest', 38 | align_corners=None): 39 | super(Upsample, self).__init__() 40 | self.size = size 41 | if isinstance(scale_factor, tuple): 42 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 43 | else: 44 | self.scale_factor = float(scale_factor) if scale_factor else None 45 | self.mode = mode 46 | self.align_corners = align_corners 47 | 48 | def forward(self, x): 49 | if not self.size: 50 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 51 | else: 52 | size = self.size 53 | return resize(x, size, None, self.mode, self.align_corners) 54 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/fan/fan_hybrid/tapfan_hybrid_base.1024x1024.city.160k.test.acdc.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/segformer.py', 3 | '../../_base_/datasets/cityscapes_1024x1024_repeat_acdc.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k_8gpu_adamw.py' 6 | ] 7 | 8 | # model settings 9 | norm_cfg = dict(type='SyncBN', requires_grad=True) 10 | find_unused_parameters = True 11 | model = dict( 12 | type='EncoderDecoder', 13 | backbone=dict( 14 | type='tap_fan_base_16_p4_hybrid', 15 | style='pytorch'), 16 | decode_head=dict( 17 | type='SegFormerHead', 18 | in_channels=[128, 256, 448, 448], 19 | in_index=[0, 1, 2, 3], 20 | feature_strides=[4, 8, 16, 32], 21 | channels=256, 22 | dropout_ratio=0.1, 23 | num_classes=19, 24 | norm_cfg=norm_cfg, 25 | align_corners=False, 26 | decoder_params=dict(embed_dim=768), 27 | loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 28 | # model training and testing settings 29 | train_cfg=dict(adloss=1), 30 | test_cfg=dict(mode='slide', crop_size=(1024,1024), stride=(768,768))) 31 | 32 | #
data 33 | data = dict(samples_per_gpu=1) 34 | evaluation = dict(interval=4000, metric='mIoU') 35 | 36 | # optimizer 37 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 38 | paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.), 39 | 'norm': dict(decay_mult=0.), 40 | 'head': dict(lr_mult=10.) 41 | })) 42 | 43 | lr_config = dict(_delete_=True, policy='poly', 44 | warmup='linear', 45 | warmup_iters=1500, 46 | warmup_ratio=1e-6, 47 | power=1.0, min_lr=0.0, by_epoch=False) 48 | 49 | # uncomment to use fp16 training 50 | # optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic') 51 | # fp16 = dict() 52 | # fp16 = dict(loss_scale='dynamic') 53 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/fan/fan_hybrid/tapfan_hybrid_base.1024x1024.city.160k.test.cityc.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/segformer.py', 3 | '../../_base_/datasets/cityscapes_1024x1024_repeat_cityc.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k_8gpu_adamw.py' 6 | ] 7 | 8 | # model settings 9 | norm_cfg = dict(type='SyncBN', requires_grad=True) 10 | find_unused_parameters = True 11 | model = dict( 12 | type='EncoderDecoder', 13 | backbone=dict( 14 | type='tap_fan_base_16_p4_hybrid', 15 | style='pytorch'), 16 | decode_head=dict( 17 | type='SegFormerHead', 18 | in_channels=[128, 256, 448, 448], 19 | in_index=[0, 1, 2, 3], 20 | feature_strides=[4, 8, 16, 32], 21 | channels=256, 22 | dropout_ratio=0.1, 23 | num_classes=19, 24 | norm_cfg=norm_cfg, 25 | align_corners=False, 26 | decoder_params=dict(embed_dim=768), 27 | loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 28 | # model training and testing settings 29 | train_cfg=dict(adloss=1), 30 | test_cfg=dict(mode='slide', crop_size=(1024,1024), stride=(768,768))) 31 | 32 | # data 33 | data = dict(samples_per_gpu=1) 34 | evaluation = dict(interval=4000, metric='mIoU') 35 | 36 | # optimizer 37 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 38 | paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.), 39 | 'norm': dict(decay_mult=0.), 40 | 'head': dict(lr_mult=10.) 
41 | })) 42 | 43 | lr_config = dict(_delete_=True, policy='poly', 44 | warmup='linear', 45 | warmup_iters=1500, 46 | warmup_ratio=1e-6, 47 | power=1.0, min_lr=0.0, by_epoch=False) 48 | 49 | # uncomment to use fp16 training 50 | # optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic') 51 | # fp16 = dict() 52 | # fp16 = dict(loss_scale='dynamic') 53 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/datasets/cityscapes_1024x1024_repeat.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'CityscapesDataset' 3 | data_root = '/PATH/TO/CITYSCAPES' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size = (1024, 1024) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations'), 10 | dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 16 | dict(type='DefaultFormatBundle'), 17 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 18 | ] 19 | test_pipeline = [ 20 | dict(type='LoadImageFromFile'), 21 | dict( 22 | type='MultiScaleFlipAug', 23 | img_scale=(2048, 1024), 24 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict(type='Normalize', **img_norm_cfg), 30 | dict(type='ImageToTensor', keys=['img']), 31 | dict(type='Collect', keys=['img']), 32 | ]) 33 | ] 34 | data = dict( 35 | samples_per_gpu=4, 36 | workers_per_gpu=2, 37 | train=dict( 38 | type='RepeatDataset', 39 | times=500, 40 | dataset=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | img_dir='leftImg8bit/train', 44 | ann_dir='gtFine/train', 45 | pipeline=train_pipeline)), 46 | val=dict( 47 | type=dataset_type, 48 | data_root=data_root, 49 | img_dir='leftImg8bit/val', 50 | ann_dir='gtFine/val', 51 | pipeline=test_pipeline), 52 | test=dict( 53 | type=dataset_type, 54 | data_root=data_root, 55 | img_dir='leftImg8bit/val', 56 | ann_dir='gtFine/val', 57 | pipeline=test_pipeline)) 58 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/fan/fan_hybrid/tapadl_fan_hybrid_base.1024x1024.city.160k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/segformer.py', 3 | '../../_base_/datasets/cityscapes_1024x1024_repeat.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k_8gpu_adamw.py' 6 | ] 7 | 8 | # model settings 9 | norm_cfg = dict(type='SyncBN', requires_grad=True) 10 | find_unused_parameters = True 11 | model = dict( 12 | type='EncoderDecoder', 13 | pretrained='../../pretrained/tapadl_fan_base.pth.tar', 14 | backbone=dict( 15 | type='tap_fan_base_16_p4_hybrid', 16 | style='pytorch'), 17 | decode_head=dict( 18 | type='SegFormerHead', 19 | in_channels=[128, 256, 448, 448], 20 | in_index=[0, 1, 2, 3], 21 | feature_strides=[4, 8, 16, 32], 22 | channels=256, 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | decoder_params=dict(embed_dim=768), 28 | 
loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | # model training and testing settings 30 | train_cfg=dict(adloss=1), 31 | test_cfg=dict(mode='slide', crop_size=(1024,1024), stride=(768,768))) 32 | 33 | # data 34 | data = dict(samples_per_gpu=1) 35 | evaluation = dict(interval=4000, metric='mIoU') 36 | 37 | # optimizer 38 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 39 | paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.), 40 | 'norm': dict(decay_mult=0.), 41 | 'head': dict(lr_mult=10.) 42 | })) 43 | 44 | lr_config = dict(_delete_=True, policy='poly', 45 | warmup='linear', 46 | warmup_iters=1500, 47 | warmup_ratio=1e-6, 48 | power=1.0, min_lr=0.0, by_epoch=False) 49 | 50 | # uncomment to use fp16 training 51 | # optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic') 52 | # fp16 = dict() 53 | # fp16 = dict(loss_scale='dynamic') 54 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/sep_fcn_head.py: -------------------------------------------------------------------------------- 1 | from mmcv.cnn import DepthwiseSeparableConvModule 2 | 3 | from ..builder import HEADS 4 | from .fcn_head import FCNHead 5 | 6 | 7 | @HEADS.register_module() 8 | class DepthwiseSeparableFCNHead(FCNHead): 9 | """Depthwise-Separable Fully Convolutional Network for Semantic 10 | Segmentation. 11 | 12 | This head is implemented according to Fast-SCNN paper. 13 | Args: 14 | in_channels(int): Number of output channels of FFM. 15 | channels(int): Number of middle-stage channels in the decode head. 16 | concat_input(bool): Whether to concatenate original decode input into 17 | the result of several consecutive convolution layers. 18 | Default: True. 19 | num_classes(int): Used to determine the dimension of 20 | final prediction tensor. 21 | in_index(int): Correspond with 'out_indices' in FastSCNN backbone. 22 | norm_cfg (dict | None): Config of norm layers. 23 | align_corners (bool): align_corners argument of F.interpolate. 24 | Default: False. 25 | loss_decode(dict): Config of loss type and some 26 | relevant additional options. 
27 | """ 28 | 29 | def __init__(self, **kwargs): 30 | super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) 31 | self.convs[0] = DepthwiseSeparableConvModule( 32 | self.in_channels, 33 | self.channels, 34 | kernel_size=self.kernel_size, 35 | padding=self.kernel_size // 2, 36 | norm_cfg=self.norm_cfg) 37 | for i in range(1, self.num_convs): 38 | self.convs[i] = DepthwiseSeparableConvModule( 39 | self.channels, 40 | self.channels, 41 | kernel_size=self.kernel_size, 42 | padding=self.kernel_size // 2, 43 | norm_cfg=self.norm_cfg) 44 | 45 | if self.concat_input: 46 | self.conv_cat = DepthwiseSeparableConvModule( 47 | self.in_channels + self.channels, 48 | self.channels, 49 | kernel_size=self.kernel_size, 50 | padding=self.kernel_size // 2, 51 | norm_cfg=self.norm_cfg) 52 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/datasets/cityscapes_1024x1024_repeat_cityc.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'CityscapesDataset' 3 | data_root = '/PATH/TO/CITYSCAPES-C' 4 | 5 | img_norm_cfg = dict( 6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 7 | crop_size = (1024, 1024) 8 | train_pipeline = [ 9 | dict(type='LoadImageFromFile'), 10 | dict(type='LoadAnnotations'), 11 | dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), 12 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 13 | dict(type='RandomFlip', prob=0.5), 14 | dict(type='PhotoMetricDistortion'), 15 | dict(type='Normalize', **img_norm_cfg), 16 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 17 | dict(type='DefaultFormatBundle'), 18 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 19 | ] 20 | test_pipeline = [ 21 | dict(type='LoadImageFromFile'), 22 | dict( 23 | type='MultiScaleFlipAug', 24 | img_scale=(2048, 1024), 25 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 26 | flip=False, 27 | transforms=[ 28 | dict(type='Resize', keep_ratio=True), 29 | dict(type='RandomFlip'), 30 | dict(type='Normalize', **img_norm_cfg), 31 | dict(type='ImageToTensor', keys=['img']), 32 | dict(type='Collect', keys=['img']), 33 | ]) 34 | ] 35 | data = dict( 36 | samples_per_gpu=4, 37 | workers_per_gpu=2, 38 | train=dict( 39 | type='RepeatDataset', 40 | times=500, 41 | dataset=dict( 42 | type=dataset_type, 43 | data_root=data_root, 44 | img_dir='leftImg8bit/train', 45 | ann_dir='gtFine/train', 46 | pipeline=train_pipeline)), 47 | val=dict( 48 | type=dataset_type, 49 | data_root=data_root, 50 | img_dir='leftImg8bit/val', 51 | # img_dir='leftImg8bit/clean/val', 52 | ann_dir='gtFine/val', 53 | pipeline=test_pipeline), 54 | test=dict( 55 | type=dataset_type, 56 | data_root=data_root, 57 | img_dir='clean', 58 | # img_dir='leftImg8bit/clean/val', 59 | ann_dir='gtFine/val', 60 | pipeline=test_pipeline)) 61 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/datasets/cityscapes_1024x1024_repeat_acdc.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'ACDCDataset' 3 | data_root = '/PATH/TO/ACDC' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size = (1024, 1024) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations'), 10 | dict(type='Resize', img_scale=(2048, 1024), 
ratio_range=(0.5, 2.0)), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 16 | dict(type='DefaultFormatBundle'), 17 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 18 | ] 19 | test_pipeline = [ 20 | dict(type='LoadImageFromFile'), 21 | dict( 22 | type='MultiScaleFlipAug', 23 | img_scale=(2048, 1024), 24 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict(type='Normalize', **img_norm_cfg), 30 | dict(type='ImageToTensor', keys=['img']), 31 | dict(type='Collect', keys=['img']), 32 | ]) 33 | ] 34 | data = dict( 35 | samples_per_gpu=4, 36 | workers_per_gpu=2, 37 | train=dict( 38 | type='RepeatDataset', 39 | times=500, 40 | dataset=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | img_dir='leftImg8bit/train', 44 | ann_dir='gtFine/train', 45 | pipeline=train_pipeline)), 46 | val=dict( 47 | type=dataset_type, 48 | data_root=data_root, 49 | img_dir='leftImg8bit/val', 50 | # img_dir='leftImg8bit/clean/val', 51 | ann_dir='gtFine/val', 52 | pipeline=test_pipeline), 53 | test=dict( 54 | type=dataset_type, 55 | data_root=data_root, 56 | img_dir='rgb_anon_trainvaltest/rgb_anon', 57 | # img_dir='leftImg8bit/clean/val', 58 | ann_dir='gt_trainval/gt', 59 | pipeline=test_pipeline)) 60 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/utils/se_layer.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from .make_divisible import make_divisible 6 | 7 | 8 | class SELayer(nn.Module): 9 | """Squeeze-and-Excitation Module. 10 | 11 | Args: 12 | channels (int): The input (and output) channels of the SE layer. 13 | ratio (int): Squeeze ratio in SELayer, the intermediate channel will be 14 | ``int(channels/ratio)``. Default: 16. 15 | conv_cfg (None or dict): Config dict for convolution layer. 16 | Default: None, which means using conv2d. 17 | act_cfg (dict or Sequence[dict]): Config dict for activation layer. 18 | If act_cfg is a dict, two activation layers will be configured 19 | by this dict. If act_cfg is a sequence of dicts, the first 20 | activation layer will be configured by the first dict and the 21 | second activation layer will be configured by the second dict. 22 | Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, 23 | divisor=6.0)).
24 | """ 25 | 26 | def __init__(self, 27 | channels, 28 | ratio=16, 29 | conv_cfg=None, 30 | act_cfg=(dict(type='ReLU'), 31 | dict(type='HSigmoid', bias=3.0, divisor=6.0))): 32 | super(SELayer, self).__init__() 33 | if isinstance(act_cfg, dict): 34 | act_cfg = (act_cfg, act_cfg) 35 | assert len(act_cfg) == 2 36 | assert mmcv.is_tuple_of(act_cfg, dict) 37 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 38 | self.conv1 = ConvModule( 39 | in_channels=channels, 40 | out_channels=make_divisible(channels // ratio, 8), 41 | kernel_size=1, 42 | stride=1, 43 | conv_cfg=conv_cfg, 44 | act_cfg=act_cfg[0]) 45 | self.conv2 = ConvModule( 46 | in_channels=make_divisible(channels // ratio, 8), 47 | out_channels=channels, 48 | kernel_size=1, 49 | stride=1, 50 | conv_cfg=conv_cfg, 51 | act_cfg=act_cfg[1]) 52 | 53 | def forward(self, x): 54 | out = self.global_avgpool(x) 55 | out = self.conv1(out) 56 | out = self.conv2(out) 57 | return x * out 58 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/ocrnet_hr18.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='CascadeEncoderDecoder', 5 | num_stages=2, 6 | pretrained='open-mmlab://msra/hrnetv2_w18', 7 | backbone=dict( 8 | type='HRNet', 9 | norm_cfg=norm_cfg, 10 | norm_eval=False, 11 | extra=dict( 12 | stage1=dict( 13 | num_modules=1, 14 | num_branches=1, 15 | block='BOTTLENECK', 16 | num_blocks=(4, ), 17 | num_channels=(64, )), 18 | stage2=dict( 19 | num_modules=1, 20 | num_branches=2, 21 | block='BASIC', 22 | num_blocks=(4, 4), 23 | num_channels=(18, 36)), 24 | stage3=dict( 25 | num_modules=4, 26 | num_branches=3, 27 | block='BASIC', 28 | num_blocks=(4, 4, 4), 29 | num_channels=(18, 36, 72)), 30 | stage4=dict( 31 | num_modules=3, 32 | num_branches=4, 33 | block='BASIC', 34 | num_blocks=(4, 4, 4, 4), 35 | num_channels=(18, 36, 72, 144)))), 36 | decode_head=[ 37 | dict( 38 | type='FCNHead', 39 | in_channels=[18, 36, 72, 144], 40 | channels=sum([18, 36, 72, 144]), 41 | in_index=(0, 1, 2, 3), 42 | input_transform='resize_concat', 43 | kernel_size=1, 44 | num_convs=1, 45 | concat_input=False, 46 | dropout_ratio=-1, 47 | num_classes=19, 48 | norm_cfg=norm_cfg, 49 | align_corners=False, 50 | loss_decode=dict( 51 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 52 | dict( 53 | type='OCRHead', 54 | in_channels=[18, 36, 72, 144], 55 | in_index=(0, 1, 2, 3), 56 | input_transform='resize_concat', 57 | channels=512, 58 | ocr_channels=256, 59 | dropout_ratio=-1, 60 | num_classes=19, 61 | norm_cfg=norm_cfg, 62 | align_corners=False, 63 | loss_decode=dict( 64 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 65 | ], 66 | # model training and testing settings 67 | train_cfg=dict(), 68 | test_cfg=dict(mode='whole')) 69 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/utils/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | 6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 9 | def norm_cdf(x): 10 | # Computes standard normal cumulative distribution function 11 | return (1. 
+ math.erf(x / math.sqrt(2.))) / 2. 12 | 13 | if (mean < a - 2 * std) or (mean > b + 2 * std): 14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 15 | "The distribution of values may be incorrect.", 16 | stacklevel=2) 17 | 18 | with torch.no_grad(): 19 | # Values are generated by using a truncated uniform distribution and 20 | # then using the inverse CDF for the normal distribution. 21 | # Get upper and lower cdf values 22 | l = norm_cdf((a - mean) / std) 23 | u = norm_cdf((b - mean) / std) 24 | 25 | # Uniformly fill tensor with values from [l, u], then translate to 26 | # [2l-1, 2u-1]. 27 | tensor.uniform_(2 * l - 1, 2 * u - 1) 28 | 29 | # Use inverse cdf transform for normal distribution to get truncated 30 | # standard normal 31 | tensor.erfinv_() 32 | 33 | # Transform to proper mean, std 34 | tensor.mul_(std * math.sqrt(2.)) 35 | tensor.add_(mean) 36 | 37 | # Clamp to ensure it's in the proper range 38 | tensor.clamp_(min=a, max=b) 39 | return tensor 40 | 41 | 42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 43 | # type: (Tensor, float, float, float, float) -> Tensor 44 | r"""Fills the input Tensor with values drawn from a truncated 45 | normal distribution. The values are effectively drawn from the 46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 47 | with values outside :math:`[a, b]` redrawn until they are within 48 | the bounds. The method used for generating the random values works 49 | best when :math:`a \leq \text{mean} \leq b`. 50 | Args: 51 | tensor: an n-dimensional `torch.Tensor` 52 | mean: the mean of the normal distribution 53 | std: the standard deviation of the normal distribution 54 | a: the minimum cutoff value 55 | b: the maximum cutoff value 56 | Examples: 57 | >>> w = torch.empty(3, 5) 58 | >>> nn.init.trunc_normal_(w) 59 | """ 60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/cascade_decode_head.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | from .decode_head import BaseDecodeHead 4 | 5 | 6 | class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): 7 | """Base class for cascade decode head used in 8 | :class:`CascadeEncoderDecoder.""" 9 | 10 | def __init__(self, *args, **kwargs): 11 | super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs) 12 | 13 | @abstractmethod 14 | def forward(self, inputs, prev_output): 15 | """Placeholder of forward function.""" 16 | pass 17 | 18 | def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, 19 | train_cfg): 20 | """Forward function for training. 21 | Args: 22 | inputs (list[Tensor]): List of multi-level img features. 23 | prev_output (Tensor): The output of previous decode head. 24 | img_metas (list[dict]): List of image info dict where each dict 25 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 26 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 27 | For details on the values of these keys see 28 | `mmseg/datasets/pipelines/formatting.py:Collect`. 29 | gt_semantic_seg (Tensor): Semantic segmentation masks 30 | used if the architecture supports semantic segmentation task. 31 | train_cfg (dict): The training config. 
32 | 33 | Returns: 34 | dict[str, Tensor]: a dictionary of loss components 35 | """ 36 | seg_logits = self.forward(inputs, prev_output) 37 | losses = self.losses(seg_logits, gt_semantic_seg) 38 | 39 | return losses 40 | 41 | def forward_test(self, inputs, prev_output, img_metas, test_cfg): 42 | """Forward function for testing. 43 | 44 | Args: 45 | inputs (list[Tensor]): List of multi-level img features. 46 | prev_output (Tensor): The output of previous decode head. 47 | img_metas (list[dict]): List of image info dict where each dict 48 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 49 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 50 | For details on the values of these keys see 51 | `mmseg/datasets/pipelines/formatting.py:Collect`. 52 | test_cfg (dict): The testing config. 53 | 54 | Returns: 55 | Tensor: Output segmentation map. 56 | """ 57 | return self.forward(inputs, prev_output) 58 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/local_configs/_base_/models/setr_pup.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True) 3 | norm_cfg = dict(type='SyncBN', requires_grad=True) 4 | model = dict( 5 | type='EncoderDecoder', 6 | pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth', 7 | backbone=dict( 8 | type='VisionTransformer', 9 | img_size=(768, 768), 10 | patch_size=16, 11 | in_channels=3, 12 | embed_dims=1024, 13 | num_layers=24, 14 | num_heads=16, 15 | out_indices=(9, 14, 19, 23), 16 | drop_rate=0.1, 17 | norm_cfg=backbone_norm_cfg, 18 | with_cls_token=True, 19 | interpolate_mode='bilinear', 20 | ), 21 | decode_head=dict( 22 | type='SETRUPHead', 23 | in_channels=1024, 24 | channels=256, 25 | in_index=3, 26 | num_classes=19, 27 | dropout_ratio=0, 28 | norm_cfg=norm_cfg, 29 | num_convs=4, 30 | up_scale=2, 31 | kernel_size=3, 32 | align_corners=False, 33 | loss_decode=dict( 34 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 35 | auxiliary_head=[ 36 | dict( 37 | type='SETRUPHead', 38 | in_channels=1024, 39 | channels=256, 40 | in_index=0, 41 | num_classes=19, 42 | dropout_ratio=0, 43 | norm_cfg=norm_cfg, 44 | num_convs=1, 45 | up_scale=4, 46 | kernel_size=3, 47 | align_corners=False, 48 | loss_decode=dict( 49 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 50 | dict( 51 | type='SETRUPHead', 52 | in_channels=1024, 53 | channels=256, 54 | in_index=1, 55 | num_classes=19, 56 | dropout_ratio=0, 57 | norm_cfg=norm_cfg, 58 | num_convs=1, 59 | up_scale=4, 60 | kernel_size=3, 61 | align_corners=False, 62 | loss_decode=dict( 63 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 64 | dict( 65 | type='SETRUPHead', 66 | in_channels=1024, 67 | channels=256, 68 | in_index=2, 69 | num_classes=19, 70 | dropout_ratio=0, 71 | norm_cfg=norm_cfg, 72 | num_convs=1, 73 | up_scale=4, 74 | kernel_size=3, 75 | align_corners=False, 76 | loss_decode=dict( 77 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 78 | ], 79 | train_cfg=dict(), 80 | test_cfg=dict(mode='whole')) 81 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/fpn_head.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmseg.ops import resize 6 | from ..builder import 
HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | @HEADS.register_module() 11 | class FPNHead(BaseDecodeHead): 12 | """Panoptic Feature Pyramid Networks. 13 | 14 | This head is the implementation of `Semantic FPN 15 | <https://arxiv.org/abs/1901.02446>`_. 16 | 17 | Args: 18 | feature_strides (tuple[int]): The strides for input feature maps. 19 | All strides are supposed to be powers of 2. The first 20 | one is of the largest resolution. 21 | """ 22 | 23 | def __init__(self, feature_strides, **kwargs): 24 | super(FPNHead, self).__init__( 25 | input_transform='multiple_select', **kwargs) 26 | assert len(feature_strides) == len(self.in_channels) 27 | assert min(feature_strides) == feature_strides[0] 28 | self.feature_strides = feature_strides 29 | 30 | self.scale_heads = nn.ModuleList() 31 | for i in range(len(feature_strides)): 32 | head_length = max( 33 | 1, 34 | int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) 35 | scale_head = [] 36 | for k in range(head_length): 37 | scale_head.append( 38 | ConvModule( 39 | self.in_channels[i] if k == 0 else self.channels, 40 | self.channels, 41 | 3, 42 | padding=1, 43 | conv_cfg=self.conv_cfg, 44 | norm_cfg=self.norm_cfg, 45 | act_cfg=self.act_cfg)) 46 | if feature_strides[i] != feature_strides[0]: 47 | scale_head.append( 48 | nn.Upsample( 49 | scale_factor=2, 50 | mode='bilinear', 51 | align_corners=self.align_corners)) 52 | self.scale_heads.append(nn.Sequential(*scale_head)) 53 | 54 | def forward(self, inputs): 55 | 56 | x = self._transform_inputs(inputs) 57 | 58 | output = self.scale_heads[0](x[0]) 59 | for i in range(1, len(self.feature_strides)): 60 | # non-inplace addition 61 | output = output + resize( 62 | self.scale_heads[i](x[i]), 63 | size=output.shape[2:], 64 | mode='bilinear', 65 | align_corners=self.align_corners) 66 | 67 | output = self.cls_seg(output) 68 | 69 | return output 70 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/fcn_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from ..builder import HEADS 6 | from .decode_head import BaseDecodeHead 7 | 8 | 9 | @HEADS.register_module() 10 | class FCNHead(BaseDecodeHead): 11 | """Fully Convolution Networks for Semantic Segmentation. 12 | 13 | This head is the implementation of `FCN <https://arxiv.org/abs/1411.4038>`_. 14 | 15 | Args: 16 | num_convs (int): Number of convs in the head. Default: 2. 17 | kernel_size (int): The kernel size for convs in the head. Default: 3. 18 | concat_input (bool): Whether to concatenate the input and output of convs 19 | before the classification layer.
20 | """ 21 | 22 | def __init__(self, 23 | num_convs=2, 24 | kernel_size=3, 25 | concat_input=True, 26 | **kwargs): 27 | assert num_convs >= 0 28 | self.num_convs = num_convs 29 | self.concat_input = concat_input 30 | self.kernel_size = kernel_size 31 | super(FCNHead, self).__init__(**kwargs) 32 | if num_convs == 0: 33 | assert self.in_channels == self.channels 34 | 35 | convs = [] 36 | convs.append( 37 | ConvModule( 38 | self.in_channels, 39 | self.channels, 40 | kernel_size=kernel_size, 41 | padding=kernel_size // 2, 42 | conv_cfg=self.conv_cfg, 43 | norm_cfg=self.norm_cfg, 44 | act_cfg=self.act_cfg)) 45 | for i in range(num_convs - 1): 46 | convs.append( 47 | ConvModule( 48 | self.channels, 49 | self.channels, 50 | kernel_size=kernel_size, 51 | padding=kernel_size // 2, 52 | conv_cfg=self.conv_cfg, 53 | norm_cfg=self.norm_cfg, 54 | act_cfg=self.act_cfg)) 55 | if num_convs == 0: 56 | self.convs = nn.Identity() 57 | else: 58 | self.convs = nn.Sequential(*convs) 59 | if self.concat_input: 60 | self.conv_cat = ConvModule( 61 | self.in_channels + self.channels, 62 | self.channels, 63 | kernel_size=kernel_size, 64 | padding=kernel_size // 2, 65 | conv_cfg=self.conv_cfg, 66 | norm_cfg=self.norm_cfg, 67 | act_cfg=self.act_cfg) 68 | 69 | def forward(self, inputs): 70 | """Forward function.""" 71 | x = self._transform_inputs(inputs) 72 | output = self.convs(x) 73 | if self.concat_input: 74 | output = self.conv_cat(torch.cat([x, output], dim=1)) 75 | output = self.cls_seg(output) 76 | return output 77 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/pascal_context.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class PascalContextDataset(CustomDataset): 9 | """PascalContext dataset. 10 | 11 | In segmentation map annotation for PascalContext, 0 stands for background, 12 | which is included in 60 categories. ``reduce_zero_label`` is fixed to 13 | False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is 14 | fixed to '.png'. 15 | 16 | Args: 17 | split (str): Split txt file for PascalContext. 
18 | """ 19 | 20 | CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 21 | 'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse', 22 | 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 23 | 'tvmonitor', 'bag', 'bed', 'bench', 'book', 'building', 24 | 'cabinet', 'ceiling', 'cloth', 'computer', 'cup', 'door', 25 | 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 26 | 'keyboard', 'light', 'mountain', 'mouse', 'curtain', 'platform', 27 | 'sign', 'plate', 'road', 'rock', 'shelves', 'sidewalk', 'sky', 28 | 'snow', 'bedclothes', 'track', 'tree', 'truck', 'wall', 'water', 29 | 'window', 'wood') 30 | 31 | PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 32 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 33 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 34 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 35 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 36 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 37 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 38 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 39 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 40 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 41 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 42 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 43 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 44 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 45 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] 46 | 47 | def __init__(self, split, **kwargs): 48 | super(PascalContextDataset, self).__init__( 49 | img_suffix='.jpg', 50 | seg_map_suffix='.png', 51 | split=split, 52 | reduce_zero_label=False, 53 | **kwargs) 54 | assert osp.exists(self.img_dir) and self.split is not None 55 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/datasets/mapillary.py: -------------------------------------------------------------------------------- 1 | from .builder import DATASETS 2 | from .custom import CustomDataset 3 | from IPython import embed 4 | 5 | @DATASETS.register_module() 6 | class MapillaryDataset(CustomDataset): 7 | """Mapillary dataset. 
8 | """ 9 | CLASSES = ('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier', 10 | 'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking', 'Pedestrian Area', 11 | 'Rail Track', 'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building', 'Tunnel', 12 | 'Person', 'Bicyclist', 'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk', 13 | 'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 14 | 'Water', 'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera', 15 | 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole', 16 | 'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Light', 17 | 'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 18 | 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', 19 | 'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled') 20 | 21 | PALETTE = [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], 22 | [180, 165, 180], [90, 120, 150], [ 23 | 102, 102, 156], [128, 64, 255], 24 | [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96], 25 | [230, 150, 140], [128, 64, 128], [ 26 | 110, 110, 110], [244, 35, 232], 27 | [150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60], 28 | [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128], 29 | [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180], 30 | [190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30], 31 | [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220], 32 | [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40], 33 | [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150], 34 | [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80], 35 | [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20], 36 | [119, 11, 32], [150, 0, 255], [ 37 | 0, 60, 100], [0, 0, 142], [0, 0, 90], 38 | [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70], 39 | [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]] 40 | 41 | def __init__(self, **kwargs): 42 | super(MapillaryDataset, self).__init__( 43 | img_suffix='.jpg', 44 | seg_map_suffix='.png', 45 | reduce_zero_label=False, 46 | **kwargs) -------------------------------------------------------------------------------- /TAPADL_RVT/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the knowledge distillation loss 3 | """ 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | class DistillationLoss(torch.nn.Module): 8 | """ 9 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 10 | taking a teacher model prediction and using it as additional supervision. 11 | """ 12 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 13 | distillation_type: str, alpha: float, tau: float): 14 | super().__init__() 15 | self.base_criterion = base_criterion 16 | self.teacher_model = teacher_model 17 | assert distillation_type in ['none', 'soft', 'hard'] 18 | self.distillation_type = distillation_type 19 | self.alpha = alpha 20 | self.tau = tau 21 | 22 | def forward(self, inputs, outputs, labels): 23 | """ 24 | Args: 25 | inputs: The original inputs that are feed to the teacher model 26 | outputs: the outputs of the model to be trained. 
It is expected to be 27 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 28 | in the first position and the distillation predictions as the second output 29 | labels: the labels for the base criterion 30 | """ 31 | outputs_kd = None 32 | if not isinstance(outputs, torch.Tensor): 33 | # assume that the model outputs a tuple of [outputs, outputs_kd] 34 | outputs, outputs_kd = outputs 35 | base_loss = self.base_criterion(outputs, labels) 36 | if self.distillation_type == 'none': 37 | return base_loss 38 | 39 | if outputs_kd is None: 40 | raise ValueError("When knowledge distillation is enabled, the model is " 41 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 42 | "class_token and the dist_token") 43 | # do not backpropagate through the teacher 44 | with torch.no_grad(): 45 | teacher_outputs = self.teacher_model(inputs) 46 | 47 | if self.distillation_type == 'soft': 48 | T = self.tau 49 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 50 | # with slight modifications 51 | distillation_loss = F.kl_div( 52 | F.log_softmax(outputs_kd / T, dim=1), 53 | F.log_softmax(teacher_outputs / T, dim=1), 54 | reduction='sum', 55 | log_target=True 56 | ) * (T * T) / outputs_kd.numel() 57 | elif self.distillation_type == 'hard': 58 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 59 | 60 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 61 | return loss 62 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/setr_up_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, build_norm_layer 4 | 5 | from mmseg.ops import Upsample 6 | from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | @HEADS.register_module() 11 | class SETRUPHead(BaseDecodeHead): 12 | """Naive upsampling head and Progressive upsampling head of SETR. 13 | Naive or PUP head of `SETR <https://arxiv.org/abs/2012.15840>`_. 14 | Args: 15 | norm_layer (dict): Config dict for input normalization. 16 | Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). 17 | num_convs (int): Number of decoder convolutions. Default: 1. 18 | up_scale (int): The scale factor of interpolate. Default: 4. 19 | kernel_size (int): The kernel size of convolution when decoding 20 | feature information from backbone. Default: 3. 21 | init_cfg (dict | list[dict] | None): Initialization config dict. 22 | Default: dict( 23 | type='Constant', val=1.0, bias=0, layer='LayerNorm'). 24 | """ 25 | 26 | def __init__(self, 27 | norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), 28 | num_convs=1, 29 | up_scale=4, 30 | kernel_size=3, 31 | **kwargs): 32 | 33 | assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'
34 | 35 | super(SETRUPHead, self).__init__(**kwargs) 36 | 37 | assert isinstance(self.in_channels, int) 38 | 39 | _, self.norm = build_norm_layer(norm_layer, self.in_channels) 40 | 41 | self.up_convs = nn.ModuleList() 42 | in_channels = self.in_channels 43 | out_channels = self.channels 44 | for _ in range(num_convs): 45 | self.up_convs.append( 46 | nn.Sequential( 47 | ConvModule( 48 | in_channels=in_channels, 49 | out_channels=out_channels, 50 | kernel_size=kernel_size, 51 | stride=1, 52 | padding=int(kernel_size - 1) // 2, 53 | norm_cfg=self.norm_cfg, 54 | act_cfg=self.act_cfg), 55 | Upsample( 56 | scale_factor=up_scale, 57 | mode='bilinear', 58 | align_corners=self.align_corners))) 59 | in_channels = out_channels 60 | 61 | def forward(self, x): 62 | x = self._transform_inputs(x) 63 | 64 | n, c, h, w = x.shape 65 | x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() 66 | x = self.norm(x) 67 | x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() 68 | for up_conv in self.up_convs: 69 | x = up_conv(x) 70 | out = self.cls_seg(x) 71 | return out 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robustifying Token Attention for Vision Transformers 2 | [Yong Guo](http://www.guoyongcs.com/), [David Stutz](https://davidstutz.de/), and [Bernt Schiele](https://scholar.google.com/citations?user=z76PBfYAAAAJ&hl=en). ICCV 2023. 3 | 4 | ### [Paper](https://arxiv.org/pdf/2303.11126.pdf) | [Slides](https://www.guoyongcs.com/TAPADL-Materials/TAPADL.pdf) | [Poster](https://www.guoyongcs.com/TAPADL-Materials/TAPADL_Poster.pdf) 5 | 6 | 7 |

8 | ![Motivation](imgs/motivation.jpg) 10 |

11 | 12 | This repository contains the official PyTorch implementation and the pretrained models of [Robustifying Token Attention for Vision Transformers](https://arxiv.org/pdf/2303.11126.pdf). 13 | 14 | 15 | ## Catalog 16 | - [x] Pre-trained models for image classification 17 | - [x] Pre-trained models for semantic segmentation 18 | - [x] Evaluation and training code 19 | 20 | 21 | ## Dependencies 22 | Our code is built on PyTorch and the timm library. Please check the detailed dependencies in [requirements.txt](requirements.txt). 23 | 24 | 25 | ## Dataset Preparation 26 | 27 | - **Image Classification**: ImageNet and related robustness benchmarks 28 | 29 | Please download the clean [ImageNet](http://image-net.org/) dataset. We evaluate the models on various robustness benchmarks, including [ImageNet-C](https://zenodo.org/record/2235448), [ImageNet-A](https://github.com/hendrycks/natural-adv-examples), [ImageNet-P](https://zenodo.org/record/3565846), and [ImageNet-R](https://github.com/hendrycks/imagenet-r). 30 | 31 | - **Semantic Segmentation**: Cityscapes and related robustness benchmarks 32 | 33 | Please download the clean [Cityscapes](https://www.cityscapes-dataset.com/) dataset. We evaluate the models on various robustness benchmarks, including [Cityscapes-C](https://github.com/guoyongcs/TAPADL/blob/main/TAPADL_FAN/segmentation/) and [ACDC](https://acdc.vision.ee.ethz.ch) (test set). 34 | 35 | 36 | 37 | ## Training and Evaluation (using TAP and ADL) 38 | - Image Classification: 39 | 40 | Please see how to train/evaluate the FAN and RVT models in [TAPADL_FAN](TAPADL_FAN) and [TAPADL_RVT](TAPADL_RVT), respectively. 41 | 42 | 43 | 44 | 45 | 46 | - Semantic Segmentation: 47 | 48 | Please see how to train/evaluate our segmentation model in [TAPADL_FAN/segmentation](TAPADL_FAN/segmentation). 49 | 50 | 51 | 52 | 53 | ## Acknowledgement 54 | This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library and the [RVT](https://github.com/vtddggg/Robust-Vision-Transformer) and [FAN](https://github.com/NVlabs/FAN) repositories. 55 | 56 | 57 | 58 | 59 | ## Citation 60 | If you find this repository helpful, please consider citing: 61 | ``` 62 | @inproceedings{guo2023robustifying, 63 | title={Robustifying token attention for vision transformers}, 64 | author={Guo, Yong and Stutz, David and Schiele, Bernt}, 65 | booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, 66 | year={2023} 67 | } 68 | ``` 69 | 70 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/ops/encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class Encoding(nn.Module): 7 | """Encoding Layer: a learnable residual encoder. 8 | 9 | Input is of shape (batch_size, channels, height, width). 10 | Output is of shape (batch_size, num_codes, channels). 11 | 12 | Args: 13 | channels: dimension of the features or feature channels 14 | num_codes: number of code words 15 | """ 16 | 17 | def __init__(self, channels, num_codes): 18 | super(Encoding, self).__init__() 19 | # init codewords and smoothing factor 20 | self.channels, self.num_codes = channels, num_codes 21 | std = 1.
/ ((num_codes * channels)**0.5) 22 | # [num_codes, channels] 23 | self.codewords = nn.Parameter( 24 | torch.empty(num_codes, channels, 25 | dtype=torch.float).uniform_(-std, std), 26 | requires_grad=True) 27 | # [num_codes] 28 | self.scale = nn.Parameter( 29 | torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), 30 | requires_grad=True) 31 | 32 | @staticmethod 33 | def scaled_l2(x, codewords, scale): 34 | num_codes, channels = codewords.size() 35 | batch_size = x.size(0) 36 | reshaped_scale = scale.view((1, 1, num_codes)) 37 | expanded_x = x.unsqueeze(2).expand( 38 | (batch_size, x.size(1), num_codes, channels)) 39 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 40 | 41 | scaled_l2_norm = reshaped_scale * ( 42 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 43 | return scaled_l2_norm 44 | 45 | @staticmethod 46 | def aggregate(assigment_weights, x, codewords): 47 | num_codes, channels = codewords.size() 48 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 49 | batch_size = x.size(0) 50 | 51 | expanded_x = x.unsqueeze(2).expand( 52 | (batch_size, x.size(1), num_codes, channels)) 53 | encoded_feat = (assigment_weights.unsqueeze(3) * 54 | (expanded_x - reshaped_codewords)).sum(dim=1) 55 | return encoded_feat 56 | 57 | def forward(self, x): 58 | assert x.dim() == 4 and x.size(1) == self.channels 59 | # [batch_size, channels, height, width] 60 | batch_size = x.size(0) 61 | # [batch_size, height x width, channels] 62 | x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() 63 | # assignment_weights: [batch_size, channels, num_codes] 64 | assigment_weights = F.softmax( 65 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 66 | # aggregate 67 | encoded_feat = self.aggregate(assigment_weights, x, self.codewords) 68 | return encoded_feat 69 | 70 | def __repr__(self): 71 | repr_str = self.__class__.__name__ 72 | repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ 73 | f'x{self.channels})' 74 | return repr_str 75 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/segformer_head.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 5 | from collections import OrderedDict 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | from mmseg.models.utils import * 11 | import attr 12 | 13 | from IPython import embed 14 | 15 | 16 | class MLP(nn.Module): 17 | """ 18 | Linear Embedding 19 | """ 20 | def __init__(self, input_dim=2048, embed_dim=768): 21 | super().__init__() 22 | self.proj = nn.Linear(input_dim, embed_dim) 23 | 24 | def forward(self, x): 25 | x = x.flatten(2).transpose(1, 2) 26 | x = self.proj(x) 27 | return x 28 | 29 | 30 | @HEADS.register_module() 31 | class SegFormerHead(BaseDecodeHead): 32 | """ 33 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 34 | """ 35 | def __init__(self, feature_strides, **kwargs): 36 | super(SegFormerHead, self).__init__(input_transform='multiple_select', **kwargs) 37 | assert len(feature_strides) == len(self.in_channels) 38 | assert min(feature_strides) == feature_strides[0] 39 | self.feature_strides = feature_strides 40 | 41 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 42 | 43 | decoder_params = 
kwargs['decoder_params'] 44 | embedding_dim = decoder_params['embed_dim'] 45 | 46 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 47 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 48 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 49 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 50 | 51 | self.linear_fuse = ConvModule( 52 | in_channels=embedding_dim*4, 53 | out_channels=embedding_dim, 54 | kernel_size=1, 55 | norm_cfg=dict(type='SyncBN', requires_grad=True) 56 | ) 57 | 58 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 59 | 60 | def forward(self, inputs): 61 | x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32 62 | c1, c2, c3, c4 = x 63 | 64 | ############## MLP decoder on C1-C4 ########### 65 | n, _, h, w = c4.shape 66 | 67 | _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) 68 | _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) 69 | 70 | _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) 71 | _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) 72 | 73 | _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) 74 | _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) 75 | 76 | _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) 77 | 78 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 79 | x = self.dropout(_c) 80 | x = self.linear_pred(x) 81 | 82 | return x 83 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def accuracy(pred, target, topk=1, thresh=None): 5 | """Calculate accuracy according to the prediction and target. 6 | 7 | Args: 8 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 9 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 10 | topk (int | tuple[int], optional): If the predictions in ``topk`` 11 | matches the target, the predictions will be regarded as 12 | correct ones. Defaults to 1. 13 | thresh (float, optional): If not None, predictions with scores under 14 | this threshold are considered incorrect. Default to None. 15 | 16 | Returns: 17 | float | tuple[float]: If the input ``topk`` is a single integer, 18 | the function will return a single float as accuracy. If 19 | ``topk`` is a tuple containing multiple integers, the 20 | function will return a tuple containing accuracies of 21 | each ``topk`` number. 22 | """ 23 | assert isinstance(topk, (int, tuple)) 24 | if isinstance(topk, int): 25 | topk = (topk, ) 26 | return_single = True 27 | else: 28 | return_single = False 29 | 30 | maxk = max(topk) 31 | if pred.size(0) == 0: 32 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 33 | return accu[0] if return_single else accu 34 | assert pred.ndim == target.ndim + 1 35 | assert pred.size(0) == target.size(0) 36 | assert maxk <= pred.size(1), \ 37 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 38 | pred_value, pred_label = pred.topk(maxk, dim=1) 39 | # transpose to shape (maxk, N, ...) 
40 | pred_label = pred_label.transpose(0, 1) 41 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 42 | if thresh is not None: 43 | # Only prediction values larger than thresh are counted as correct 44 | correct = correct & (pred_value > thresh).t() 45 | res = [] 46 | for k in topk: 47 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 48 | res.append(correct_k.mul_(100.0 / target.numel())) 49 | return res[0] if return_single else res 50 | 51 | 52 | class Accuracy(nn.Module): 53 | """Accuracy calculation module.""" 54 | 55 | def __init__(self, topk=(1, ), thresh=None): 56 | """Module to calculate the accuracy. 57 | 58 | Args: 59 | topk (tuple, optional): The criterion used to calculate the 60 | accuracy. Defaults to (1,). 61 | thresh (float, optional): If not None, predictions with scores 62 | under this threshold are considered incorrect. Default to None. 63 | """ 64 | super().__init__() 65 | self.topk = topk 66 | self.thresh = thresh 67 | 68 | def forward(self, pred, target): 69 | """Forward function to calculate accuracy. 70 | 71 | Args: 72 | pred (torch.Tensor): Prediction of models. 73 | target (torch.Tensor): Target for each prediction. 74 | 75 | Returns: 76 | tuple[float]: The accuracies under different topk criterions. 77 | """ 78 | return accuracy(pred, target, self.topk, self.thresh) 79 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/lraspp_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv import is_tuple_of 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.ops import resize 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @HEADS.register_module() 12 | class LRASPPHead(BaseDecodeHead): 13 | """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. 14 | 15 | This head is the improved implementation of `Searching for MobileNetV3 16 | `_. 17 | 18 | Args: 19 | branch_channels (tuple[int]): The number of output channels in every 20 | each branch. Default: (32, 64). 21 | """ 22 | 23 | def __init__(self, branch_channels=(32, 64), **kwargs): 24 | super(LRASPPHead, self).__init__(**kwargs) 25 | if self.input_transform != 'multiple_select': 26 | raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' 27 | f'must be \'multiple_select\'. 
But received ' 28 | f'\'{self.input_transform}\'') 29 | assert is_tuple_of(branch_channels, int) 30 | assert len(branch_channels) == len(self.in_channels) - 1 31 | self.branch_channels = branch_channels 32 | 33 | self.convs = nn.Sequential() 34 | self.conv_ups = nn.Sequential() 35 | for i in range(len(branch_channels)): 36 | self.convs.add_module( 37 | f'conv{i}', 38 | nn.Conv2d( 39 | self.in_channels[i], branch_channels[i], 1, bias=False)) 40 | self.conv_ups.add_module( 41 | f'conv_up{i}', 42 | ConvModule( 43 | self.channels + branch_channels[i], 44 | self.channels, 45 | 1, 46 | norm_cfg=self.norm_cfg, 47 | act_cfg=self.act_cfg, 48 | bias=False)) 49 | 50 | self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) 51 | 52 | self.aspp_conv = ConvModule( 53 | self.in_channels[-1], 54 | self.channels, 55 | 1, 56 | norm_cfg=self.norm_cfg, 57 | act_cfg=self.act_cfg, 58 | bias=False) 59 | self.image_pool = nn.Sequential( 60 | nn.AvgPool2d(kernel_size=49, stride=(16, 20)), 61 | ConvModule( 62 | self.in_channels[2], 63 | self.channels, 64 | 1, 65 | act_cfg=dict(type='Sigmoid'), 66 | bias=False)) 67 | 68 | def forward(self, inputs): 69 | """Forward function.""" 70 | inputs = self._transform_inputs(inputs) 71 | 72 | x = inputs[-1] 73 | 74 | x = self.aspp_conv(x) * resize( 75 | self.image_pool(x), 76 | size=x.size()[2:], 77 | mode='bilinear', 78 | align_corners=self.align_corners) 79 | x = self.conv_up_input(x) 80 | 81 | for i in range(len(self.branch_channels) - 1, -1, -1): 82 | x = resize( 83 | x, 84 | size=inputs[i].size()[2:], 85 | mode='bilinear', 86 | align_corners=self.align_corners) 87 | x = torch.cat([x, self.convs[i](inputs[i])], 1) 88 | x = self.conv_ups[i](x) 89 | 90 | return self.cls_seg(x) 91 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/seg/sampler/ohem_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from ..builder import PIXEL_SAMPLERS 5 | from .base_pixel_sampler import BasePixelSampler 6 | 7 | 8 | @PIXEL_SAMPLERS.register_module() 9 | class OHEMPixelSampler(BasePixelSampler): 10 | """Online Hard Example Mining Sampler for segmentation. 11 | 12 | Args: 13 | context (nn.Module): The context of sampler, subclass of 14 | :obj:`BaseDecodeHead`. 15 | thresh (float, optional): The threshold for hard example selection. 16 | Below which, are prediction with low confidence. If not 17 | specified, the hard examples will be pixels of top ``min_kept`` 18 | loss. Default: None. 19 | min_kept (int, optional): The minimum number of predictions to keep. 20 | Default: 100000. 21 | """ 22 | 23 | def __init__(self, context, thresh=None, min_kept=100000): 24 | super(OHEMPixelSampler, self).__init__() 25 | self.context = context 26 | assert min_kept > 1 27 | self.thresh = thresh 28 | self.min_kept = min_kept 29 | 30 | def sample(self, seg_logit, seg_label): 31 | """Sample pixels that have high loss or with low prediction confidence. 
32 | 33 | Args: 34 | seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) 35 | seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) 36 | 37 | Returns: 38 | torch.Tensor: segmentation weight, shape (N, H, W) 39 | """ 40 | with torch.no_grad(): 41 | assert seg_logit.shape[2:] == seg_label.shape[2:] 42 | assert seg_label.shape[1] == 1 43 | seg_label = seg_label.squeeze(1).long() 44 | batch_kept = self.min_kept * seg_label.size(0) 45 | valid_mask = seg_label != self.context.ignore_index 46 | seg_weight = seg_logit.new_zeros(size=seg_label.size()) 47 | valid_seg_weight = seg_weight[valid_mask] 48 | if self.thresh is not None: 49 | seg_prob = F.softmax(seg_logit, dim=1) 50 | 51 | tmp_seg_label = seg_label.clone().unsqueeze(1) 52 | tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 53 | seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) 54 | sort_prob, sort_indices = seg_prob[valid_mask].sort() 55 | 56 | if sort_prob.numel() > 0: 57 | min_threshold = sort_prob[min(batch_kept, 58 | sort_prob.numel() - 1)] 59 | else: 60 | min_threshold = 0.0 61 | threshold = max(min_threshold, self.thresh) 62 | valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 63 | else: 64 | losses = self.context.loss_decode( 65 | seg_logit, 66 | seg_label, 67 | weight=None, 68 | ignore_index=self.context.ignore_index, 69 | reduction_override='none') 70 | # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa 71 | _, sort_indices = losses[valid_mask].sort(descending=True) 72 | valid_seg_weight[sort_indices[:batch_kept]] = 1. 73 | 74 | seg_weight[valid_mask] = valid_seg_weight 75 | 76 | return seg_weight 77 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn.functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are "none", "mean" and "sum". 12 | 13 | Return: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | elif reduction_enum == 2: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. 32 | reduction (str): Same as built-in losses of PyTorch. 33 | avg_factor (float): Avarage factor when computing the mean of losses. 34 | 35 | Returns: 36 | Tensor: Processed loss values. 
37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | assert weight.dim() == loss.dim() 41 | if weight.dim() > 1: 42 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 43 | loss = loss * weight 44 | 45 | # if avg_factor is not specified, just reduce the loss 46 | if avg_factor is None: 47 | loss = reduce_loss(loss, reduction) 48 | else: 49 | # if reduction is mean, then average the loss by avg_factor 50 | if reduction == 'mean': 51 | loss = loss.sum() / avg_factor 52 | # if reduction is 'none', then do nothing, otherwise raise an error 53 | elif reduction != 'none': 54 | raise ValueError('avg_factor can not be used with reduction="sum"') 55 | return loss 56 | 57 | 58 | def weighted_loss(loss_func): 59 | """Create a weighted version of a given loss function. 60 | 61 | To use this decorator, the loss function must have the signature like 62 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 63 | element-wise loss without any reduction. This decorator will add weight 64 | and reduction arguments to the function. The decorated function will have 65 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 66 | avg_factor=None, **kwargs)`. 67 | 68 | :Example: 69 | 70 | >>> import torch 71 | >>> @weighted_loss 72 | >>> def l1_loss(pred, target): 73 | >>> return (pred - target).abs() 74 | 75 | >>> pred = torch.Tensor([0, 2, 3]) 76 | >>> target = torch.Tensor([1, 1, 1]) 77 | >>> weight = torch.Tensor([1, 0, 1]) 78 | 79 | >>> l1_loss(pred, target) 80 | tensor(1.3333) 81 | >>> l1_loss(pred, target, weight) 82 | tensor(1.) 83 | >>> l1_loss(pred, target, reduction='none') 84 | tensor([1., 1., 2.]) 85 | >>> l1_loss(pred, target, weight, avg_factor=2) 86 | tensor(1.5000) 87 | """ 88 | 89 | @functools.wraps(loss_func) 90 | def wrapper(pred, 91 | target, 92 | weight=None, 93 | reduction='mean', 94 | avg_factor=None, 95 | **kwargs): 96 | # get element-wise loss 97 | loss = loss_func(pred, target, **kwargs) 98 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 99 | return loss 100 | 101 | return wrapper 102 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/utils/res_layer.py: -------------------------------------------------------------------------------- 1 | from mmcv.cnn import build_conv_layer, build_norm_layer 2 | from torch import nn as nn 3 | 4 | 5 | class ResLayer(nn.Sequential): 6 | """ResLayer to build ResNet style backbone. 7 | 8 | Args: 9 | block (nn.Module): block used to build ResLayer. 10 | inplanes (int): inplanes of block. 11 | planes (int): planes of block. 12 | num_blocks (int): number of blocks. 13 | stride (int): stride of the first block. Default: 1 14 | avg_down (bool): Use AvgPool instead of stride conv when 15 | downsampling in the bottleneck. Default: False 16 | conv_cfg (dict): dictionary to construct and config conv layer. 17 | Default: None 18 | norm_cfg (dict): dictionary to construct and config norm layer. 19 | Default: dict(type='BN') 20 | multi_grid (int | None): Multi grid dilation rates of last 21 | stage. 
Default: None 22 | contract_dilation (bool): Whether contract first dilation of each layer 23 | Default: False 24 | """ 25 | 26 | def __init__(self, 27 | block, 28 | inplanes, 29 | planes, 30 | num_blocks, 31 | stride=1, 32 | dilation=1, 33 | avg_down=False, 34 | conv_cfg=None, 35 | norm_cfg=dict(type='BN'), 36 | multi_grid=None, 37 | contract_dilation=False, 38 | **kwargs): 39 | self.block = block 40 | 41 | downsample = None 42 | if stride != 1 or inplanes != planes * block.expansion: 43 | downsample = [] 44 | conv_stride = stride 45 | if avg_down: 46 | conv_stride = 1 47 | downsample.append( 48 | nn.AvgPool2d( 49 | kernel_size=stride, 50 | stride=stride, 51 | ceil_mode=True, 52 | count_include_pad=False)) 53 | downsample.extend([ 54 | build_conv_layer( 55 | conv_cfg, 56 | inplanes, 57 | planes * block.expansion, 58 | kernel_size=1, 59 | stride=conv_stride, 60 | bias=False), 61 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 62 | ]) 63 | downsample = nn.Sequential(*downsample) 64 | 65 | layers = [] 66 | if multi_grid is None: 67 | if dilation > 1 and contract_dilation: 68 | first_dilation = dilation // 2 69 | else: 70 | first_dilation = dilation 71 | else: 72 | first_dilation = multi_grid[0] 73 | layers.append( 74 | block( 75 | inplanes=inplanes, 76 | planes=planes, 77 | stride=stride, 78 | dilation=first_dilation, 79 | downsample=downsample, 80 | conv_cfg=conv_cfg, 81 | norm_cfg=norm_cfg, 82 | **kwargs)) 83 | inplanes = planes * block.expansion 84 | for i in range(1, num_blocks): 85 | layers.append( 86 | block( 87 | inplanes=inplanes, 88 | planes=planes, 89 | stride=1, 90 | dilation=dilation if multi_grid is None else multi_grid[i], 91 | conv_cfg=conv_cfg, 92 | norm_cfg=norm_cfg, 93 | **kwargs)) 94 | super(ResLayer, self).__init__(*layers) 95 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/psp_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmseg.ops import resize 6 | from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | class PPM(nn.ModuleList): 11 | """Pooling Pyramid Module used in PSPNet. 12 | 13 | Args: 14 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 15 | Module. 16 | in_channels (int): Input channels. 17 | channels (int): Channels after modules, before conv_seg. 18 | conv_cfg (dict|None): Config of conv layers. 19 | norm_cfg (dict|None): Config of norm layers. 20 | act_cfg (dict): Config of activation layers. 21 | align_corners (bool): align_corners argument of F.interpolate. 
22 | """ 23 | 24 | def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, 25 | act_cfg, align_corners): 26 | super(PPM, self).__init__() 27 | self.pool_scales = pool_scales 28 | self.align_corners = align_corners 29 | self.in_channels = in_channels 30 | self.channels = channels 31 | self.conv_cfg = conv_cfg 32 | self.norm_cfg = norm_cfg 33 | self.act_cfg = act_cfg 34 | for pool_scale in pool_scales: 35 | self.append( 36 | nn.Sequential( 37 | nn.AdaptiveAvgPool2d(pool_scale), 38 | ConvModule( 39 | self.in_channels, 40 | self.channels, 41 | 1, 42 | conv_cfg=self.conv_cfg, 43 | norm_cfg=self.norm_cfg, 44 | act_cfg=self.act_cfg))) 45 | 46 | def forward(self, x): 47 | """Forward function.""" 48 | ppm_outs = [] 49 | for ppm in self: 50 | ppm_out = ppm(x) 51 | upsampled_ppm_out = resize( 52 | ppm_out, 53 | size=x.size()[2:], 54 | mode='bilinear', 55 | align_corners=self.align_corners) 56 | ppm_outs.append(upsampled_ppm_out) 57 | return ppm_outs 58 | 59 | 60 | @HEADS.register_module() 61 | class PSPHead(BaseDecodeHead): 62 | """Pyramid Scene Parsing Network. 63 | 64 | This head is the implementation of 65 | `PSPNet `_. 66 | 67 | Args: 68 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 69 | Module. Default: (1, 2, 3, 6). 70 | """ 71 | 72 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 73 | super(PSPHead, self).__init__(**kwargs) 74 | assert isinstance(pool_scales, (list, tuple)) 75 | self.pool_scales = pool_scales 76 | self.psp_modules = PPM( 77 | self.pool_scales, 78 | self.in_channels, 79 | self.channels, 80 | conv_cfg=self.conv_cfg, 81 | norm_cfg=self.norm_cfg, 82 | act_cfg=self.act_cfg, 83 | align_corners=self.align_corners) 84 | self.bottleneck = ConvModule( 85 | self.in_channels + len(pool_scales) * self.channels, 86 | self.channels, 87 | 3, 88 | padding=1, 89 | conv_cfg=self.conv_cfg, 90 | norm_cfg=self.norm_cfg, 91 | act_cfg=self.act_cfg) 92 | 93 | def forward(self, inputs): 94 | """Forward function.""" 95 | x = self._transform_inputs(inputs) 96 | psp_outs = [x] 97 | psp_outs.extend(self.psp_modules(x)) 98 | psp_outs = torch.cat(psp_outs, dim=1) 99 | output = self.bottleneck(psp_outs) 100 | output = self.cls_seg(output) 101 | return output 102 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/aspp_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmseg.ops import resize 6 | from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | class ASPPModule(nn.ModuleList): 11 | """Atrous Spatial Pyramid Pooling (ASPP) Module. 12 | 13 | Args: 14 | dilations (tuple[int]): Dilation rate of each layer. 15 | in_channels (int): Input channels. 16 | channels (int): Channels after modules, before conv_seg. 17 | conv_cfg (dict|None): Config of conv layers. 18 | norm_cfg (dict|None): Config of norm layers. 19 | act_cfg (dict): Config of activation layers. 
20 | """ 21 | 22 | def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, 23 | act_cfg): 24 | super(ASPPModule, self).__init__() 25 | self.dilations = dilations 26 | self.in_channels = in_channels 27 | self.channels = channels 28 | self.conv_cfg = conv_cfg 29 | self.norm_cfg = norm_cfg 30 | self.act_cfg = act_cfg 31 | for dilation in dilations: 32 | self.append( 33 | ConvModule( 34 | self.in_channels, 35 | self.channels, 36 | 1 if dilation == 1 else 3, 37 | dilation=dilation, 38 | padding=0 if dilation == 1 else dilation, 39 | conv_cfg=self.conv_cfg, 40 | norm_cfg=self.norm_cfg, 41 | act_cfg=self.act_cfg)) 42 | 43 | def forward(self, x): 44 | """Forward function.""" 45 | aspp_outs = [] 46 | for aspp_module in self: 47 | aspp_outs.append(aspp_module(x)) 48 | 49 | return aspp_outs 50 | 51 | 52 | @HEADS.register_module() 53 | class ASPPHead(BaseDecodeHead): 54 | """Rethinking Atrous Convolution for Semantic Image Segmentation. 55 | 56 | This head is the implementation of `DeepLabV3 57 | `_. 58 | 59 | Args: 60 | dilations (tuple[int]): Dilation rates for ASPP module. 61 | Default: (1, 6, 12, 18). 62 | """ 63 | 64 | def __init__(self, dilations=(1, 6, 12, 18), **kwargs): 65 | super(ASPPHead, self).__init__(**kwargs) 66 | assert isinstance(dilations, (list, tuple)) 67 | self.dilations = dilations 68 | self.image_pool = nn.Sequential( 69 | nn.AdaptiveAvgPool2d(1), 70 | ConvModule( 71 | self.in_channels, 72 | self.channels, 73 | 1, 74 | conv_cfg=self.conv_cfg, 75 | norm_cfg=self.norm_cfg, 76 | act_cfg=self.act_cfg)) 77 | self.aspp_modules = ASPPModule( 78 | dilations, 79 | self.in_channels, 80 | self.channels, 81 | conv_cfg=self.conv_cfg, 82 | norm_cfg=self.norm_cfg, 83 | act_cfg=self.act_cfg) 84 | self.bottleneck = ConvModule( 85 | (len(dilations) + 1) * self.channels, 86 | self.channels, 87 | 3, 88 | padding=1, 89 | conv_cfg=self.conv_cfg, 90 | norm_cfg=self.norm_cfg, 91 | act_cfg=self.act_cfg) 92 | 93 | def forward(self, inputs): 94 | """Forward function.""" 95 | x = self._transform_inputs(inputs) 96 | aspp_outs = [ 97 | resize( 98 | self.image_pool(x), 99 | size=x.size()[2:], 100 | mode='bilinear', 101 | align_corners=self.align_corners) 102 | ] 103 | aspp_outs.extend(self.aspp_modules(x)) 104 | aspp_outs = torch.cat(aspp_outs, dim=1) 105 | output = self.bottleneck(aspp_outs) 106 | output = self.cls_seg(output) 107 | return output 108 | -------------------------------------------------------------------------------- /TAPADL_RVT/README.md: -------------------------------------------------------------------------------- 1 | # Training and evaluation of TAPADL based on RVT 2 | [Robustifying Token Attention for Vision Transformers](https://arxiv.org/pdf/2303.11126.pdf), \ 3 | [Yong Guo](http://www.guoyongcs.com/), [David Stutz](https://davidstutz.de/), and [Bernt Schiele](https://scholar.google.com/citations?user=z76PBfYAAAAJ&hl=en). ICCV 2023. 4 | 5 | 6 | 7 | # Dependencies 8 | Our code is built based on pytorch and timm library. Please check the detailed dependencies in [requirements.txt](https://github.com/guoyongcs/TAPADL/blob/main/requirements.txt). 9 | 10 | # Dataset Preparation 11 | 12 | Please download the clean [ImageNet](http://image-net.org/) dataset. 13 | 14 | 15 | We use many robustness benchmarks for evaluation, including [ImageNet-A](https://github.com/hendrycks/natural-adv-examples), [ImageNet-C](https://zenodo.org/record/2235448), [ImageNet-P](https://zenodo.org/record/3565846) and [ImageNet-R](https://github.com/hendrycks/imagenet-r). 
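For orientation, ImageNet-C unpacks into one folder per corruption type and severity (e.g. ```gaussian_noise/3```), and its evaluation simply averages the top-1 error over all of them. Below is a minimal sketch of such a per-corruption loop; the ```CORRUPTIONS``` list follows the standard benchmark, while the helper names and normalization constants are illustrative assumptions rather than code from this repo (match the normalization to the model you evaluate).

```
import os
import torch
from torchvision import datasets, transforms

# The 15 standard ImageNet-C corruptions; expected layout: root/<corruption>/<severity>.
CORRUPTIONS = [
    'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
    'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
    'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression',
]

@torch.no_grad()
def top1_error(model, loader, device='cuda'):
    # Plain top-1 error; assumes `model` is already on `device`.
    model.eval()
    wrong = total = 0
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        wrong += (preds != labels).sum().item()
        total += labels.numel()
    return wrong / total

def imagenet_c_errors(model, root, batch_size=64):
    # ImageNet-C images are stored at 224x224 already, so only tensor
    # conversion and normalization are applied here.
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    errors = {}
    for corruption in CORRUPTIONS:
        per_severity = []
        for severity in range(1, 6):
            folder = os.path.join(root, corruption, str(severity))
            loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(folder, tf), batch_size=batch_size, num_workers=4)
            per_severity.append(top1_error(model, loader))
        errors[corruption] = sum(per_severity) / len(per_severity)
    return errors
```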
16 | 17 | 18 | ## Image Classification 19 | 20 | 21 | ### Pretrained Model 22 | 23 | | Model | #Params | IN-1K $\uparrow$ | IN-C $\downarrow$ | IN-A $\uparrow$ | IN-P $\downarrow$ | 24 | |:-----------------:|:----------------:|:-----------------:|:---------------:|:-----------------:|:-------:| 25 | | [RVT-B (TAP & ADL)](https://github.com/guoyongcs/TAPADL/releases/download/v1.0/tapadl_rvt_base.pth.tar) | 92.1M | **83.1** | **44.7** | **32.7** | **29.6** | 26 | 27 | Please download and put the pretrained model [tapadl_rvt_base.pth.tar](https://github.com/guoyongcs/TAPADL/releases/download/v1.0/tapadl_rvt_base.pth.tar) in ```../pretrained```. 28 | 29 | 30 | ### Evaluation 31 | - Evaluate the pretrained model on ImageNet: 32 | ``` 33 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 main.py \ 34 | --eval --model tap_rvt_base_plus --data-path /PATH/TO/IMAGENET \ 35 | --output_dir ./experiments/test_exp_tapadl_rvt_base_imagenet --dist-eval \ 36 | --pretrain_path ../pretrained/tapadl_rvt_base.pth.tar 37 | ``` 38 | 39 | - Evaluate the pretrained model on ImageNet-A/R: 40 | ``` 41 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 main.py \ 42 | --eval --model tap_rvt_base_plus --data-path /PATH/TO/IMAGENET \ 43 | --output_dir ./experiments/test_exp_tapadl_rvt_base_imagenet_a --dist-eval \ 44 | --pretrain_path ../pretrained/tapadl_rvt_base.pth.tar --ina_path /PATH/TO/IMAGENET-A 45 | ``` 46 | 47 | - Evaluate the pretrained model on ImageNet-C: 48 | ``` 49 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 main.py \ 50 | --eval --model tap_rvt_base_plus --data-path /PATH/TO/IMAGENET \ 51 | --output_dir ./experiments/test_exp_tapadl_rvt_base_imagenet_c --dist-eval \ 52 | --pretrain_path ../pretrained/tapadl_rvt_base.pth.tar --inc_path /PATH/TO/IMAGENET-C 53 | ``` 54 | 55 | - Evaluate the pretrained model on ImageNet-P 56 | 57 | Please refer to [test.sh](https://github.com/hendrycks/robustness/blob/master/ImageNet-P/test.sh) to see how to evaluate models on ImageNet-P. 
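Note that the IN-C number in the table above is a mean corruption error (mCE, lower is better): the top-1 error for each corruption type is first normalized by AlexNet's error on that corruption and then averaged, so AlexNet scores exactly 100. A minimal sketch of this aggregation is given below; the dictionary layout is an assumption of the sketch, not an interface of this repo.

```
def mean_corruption_error(model_err, alexnet_err):
    # Both arguments map corruption name -> list of top-1 errors at severities 1..5.
    # Each corruption's error is normalized by the AlexNet baseline before averaging.
    ces = [100.0 * sum(errs) / sum(alexnet_err[name])
           for name, errs in model_err.items()]
    return sum(ces) / len(ces)
```

The per-corruption error lists can come from any evaluation loop over the ImageNet-C folders, such as the commands above.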
58 | 59 | 60 | 61 | 62 | ### Training 63 | Train RVT-B with TAP and ADL on ImageNet (using 8 nodes, each with 4 GPUs): 64 | ``` 65 | python -m torch.distributed.launch --nproc_per_node=4 --nnodes=8 --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT main.py \ 66 | --model tap_rvt_base_plus --data-path /PATH/TO/IMAGENET --output_dir ./experiments/exp_tapadl_rvt_base_imagenet --dist-eval --use_patch_aug \ 67 | --batch-size 64 --aa rand-m9-mstd0.5-inc1 68 | ``` 69 | 70 | 71 | ## Citation 72 | If you find this repository helpful, please consider citing: 73 | ``` 74 | @inproceedings{guo2023robustifying, 75 | title={Robustifying token attention for vision transformers}, 76 | author={Guo, Yong and Stutz, David and Schiele, Bernt}, 77 | booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, 78 | year={2023} 79 | } 80 | ``` 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/sep_aspp_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 4 | 5 | from mmseg.ops import resize 6 | from ..builder import HEADS 7 | from .aspp_head import ASPPHead, ASPPModule 8 | 9 | 10 | class DepthwiseSeparableASPPModule(ASPPModule): 11 | """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable 12 | conv.""" 13 | 14 | def __init__(self, **kwargs): 15 | super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) 16 | for i, dilation in enumerate(self.dilations): 17 | if dilation > 1: 18 | self[i] = DepthwiseSeparableConvModule( 19 | self.in_channels, 20 | self.channels, 21 | 3, 22 | dilation=dilation, 23 | padding=dilation, 24 | norm_cfg=self.norm_cfg, 25 | act_cfg=self.act_cfg) 26 | 27 | 28 | @HEADS.register_module() 29 | class DepthwiseSeparableASPPHead(ASPPHead): 30 | """Encoder-Decoder with Atrous Separable Convolution for Semantic Image 31 | Segmentation. 32 | 33 | This head is the implementation of `DeepLabV3+ 34 | <https://arxiv.org/abs/1802.02611>`_. 35 | 36 | Args: 37 | c1_in_channels (int): The input channels of c1 decoder. If it is 0, 38 | no decoder will be used. 39 | c1_channels (int): The intermediate channels of c1 decoder.
40 | """ 41 | 42 | def __init__(self, c1_in_channels, c1_channels, **kwargs): 43 | super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) 44 | assert c1_in_channels >= 0 45 | self.aspp_modules = DepthwiseSeparableASPPModule( 46 | dilations=self.dilations, 47 | in_channels=self.in_channels, 48 | channels=self.channels, 49 | conv_cfg=self.conv_cfg, 50 | norm_cfg=self.norm_cfg, 51 | act_cfg=self.act_cfg) 52 | if c1_in_channels > 0: 53 | self.c1_bottleneck = ConvModule( 54 | c1_in_channels, 55 | c1_channels, 56 | 1, 57 | conv_cfg=self.conv_cfg, 58 | norm_cfg=self.norm_cfg, 59 | act_cfg=self.act_cfg) 60 | else: 61 | self.c1_bottleneck = None 62 | self.sep_bottleneck = nn.Sequential( 63 | DepthwiseSeparableConvModule( 64 | self.channels + c1_channels, 65 | self.channels, 66 | 3, 67 | padding=1, 68 | norm_cfg=self.norm_cfg, 69 | act_cfg=self.act_cfg), 70 | DepthwiseSeparableConvModule( 71 | self.channels, 72 | self.channels, 73 | 3, 74 | padding=1, 75 | norm_cfg=self.norm_cfg, 76 | act_cfg=self.act_cfg)) 77 | 78 | def forward(self, inputs): 79 | """Forward function.""" 80 | x = self._transform_inputs(inputs) 81 | aspp_outs = [ 82 | resize( 83 | self.image_pool(x), 84 | size=x.size()[2:], 85 | mode='bilinear', 86 | align_corners=self.align_corners) 87 | ] 88 | aspp_outs.extend(self.aspp_modules(x)) 89 | aspp_outs = torch.cat(aspp_outs, dim=1) 90 | output = self.bottleneck(aspp_outs) 91 | if self.c1_bottleneck is not None: 92 | c1_output = self.c1_bottleneck(inputs[0]) 93 | output = resize( 94 | input=output, 95 | size=c1_output.shape[2:], 96 | mode='bilinear', 97 | align_corners=self.align_corners) 98 | output = torch.cat([output, c1_output], dim=1) 99 | output = self.sep_bottleneck(output) 100 | output = self.cls_seg(output) 101 | return output 102 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/segmentors/cascade_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from mmseg.core import add_prefix 4 | from mmseg.ops import resize 5 | from .. import builder 6 | from ..builder import SEGMENTORS 7 | from .encoder_decoder import EncoderDecoder 8 | 9 | 10 | @SEGMENTORS.register_module() 11 | class CascadeEncoderDecoder(EncoderDecoder): 12 | """Cascade Encoder Decoder segmentors. 13 | 14 | CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of 15 | CascadeEncoderDecoder are cascaded. The output of previous decoder_head 16 | will be the input of next decoder_head. 
17 | """ 18 | 19 | def __init__(self, 20 | num_stages, 21 | backbone, 22 | decode_head, 23 | neck=None, 24 | auxiliary_head=None, 25 | train_cfg=None, 26 | test_cfg=None, 27 | pretrained=None): 28 | self.num_stages = num_stages 29 | super(CascadeEncoderDecoder, self).__init__( 30 | backbone=backbone, 31 | decode_head=decode_head, 32 | neck=neck, 33 | auxiliary_head=auxiliary_head, 34 | train_cfg=train_cfg, 35 | test_cfg=test_cfg, 36 | pretrained=pretrained) 37 | 38 | def _init_decode_head(self, decode_head): 39 | """Initialize ``decode_head``""" 40 | assert isinstance(decode_head, list) 41 | assert len(decode_head) == self.num_stages 42 | self.decode_head = nn.ModuleList() 43 | for i in range(self.num_stages): 44 | self.decode_head.append(builder.build_head(decode_head[i])) 45 | self.align_corners = self.decode_head[-1].align_corners 46 | self.num_classes = self.decode_head[-1].num_classes 47 | 48 | def init_weights(self, pretrained=None): 49 | """Initialize the weights in backbone and heads. 50 | 51 | Args: 52 | pretrained (str, optional): Path to pre-trained weights. 53 | Defaults to None. 54 | """ 55 | self.backbone.init_weights(pretrained=pretrained) 56 | for i in range(self.num_stages): 57 | self.decode_head[i].init_weights() 58 | if self.with_auxiliary_head: 59 | if isinstance(self.auxiliary_head, nn.ModuleList): 60 | for aux_head in self.auxiliary_head: 61 | aux_head.init_weights() 62 | else: 63 | self.auxiliary_head.init_weights() 64 | 65 | def encode_decode(self, img, img_metas): 66 | """Encode images with backbone and decode into a semantic segmentation 67 | map of the same size as input.""" 68 | x = self.extract_feat(img) 69 | out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg) 70 | for i in range(1, self.num_stages): 71 | out = self.decode_head[i].forward_test(x, out, img_metas, 72 | self.test_cfg) 73 | out = resize( 74 | input=out, 75 | size=img.shape[2:], 76 | mode='bilinear', 77 | align_corners=self.align_corners) 78 | return out 79 | 80 | def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): 81 | """Run forward function and calculate loss for decode head in 82 | training.""" 83 | losses = dict() 84 | 85 | loss_decode = self.decode_head[0].forward_train( 86 | x, img_metas, gt_semantic_seg, self.train_cfg) 87 | 88 | losses.update(add_prefix(loss_decode, 'decode_0')) 89 | 90 | for i in range(1, self.num_stages): 91 | # forward test again, maybe unnecessary for most methods. 92 | prev_outputs = self.decode_head[i - 1].forward_test( 93 | x, img_metas, self.test_cfg) 94 | loss_decode = self.decode_head[i].forward_train( 95 | x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) 96 | losses.update(add_prefix(loss_decode, f'decode_{i}')) 97 | 98 | return losses 99 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/README.md: -------------------------------------------------------------------------------- 1 | # Segmentation codebase for TAP and ADL based on FAN 2 | 3 | We follow [FAN](https://github.com/NVlabs/FAN/tree/master) to build our codebase which is developed on top of [MMSegmentation v0.13.0](https://github.com/open-mmlab/mmsegmentation/tree/v0.13.0). 4 | 5 | 6 | ## Dependencies 7 | 8 | Install according to the guidelines in [MMSegmentation v0.13.0](https://github.com/open-mmlab/mmsegmentation/tree/v0.13.0). 9 | 10 | Please refer to [requirements.txt](https://github.com/guoyongcs/TAPADL/blob/main/requirements.txt) for other dependencies. 
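After installing, a quick way to confirm the environment is consistent is the minimal sketch below (run it from this `segmentation` directory so the bundled `mmseg` package is importable; the exact version strings you see will depend on your install):

```
import torch
import mmcv
import mmseg

# MMSegmentation v0.13.0 needs a compatible mmcv-full build, so print the
# installed versions to verify the environment before training or testing.
print('PyTorch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('MMCV:', mmcv.__version__)
print('MMSegmentation:', mmseg.__version__)
```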
11 | 12 | 13 | ## Dataset Preparation 14 | 15 | - Prepare Cityscapes according to the guidelines in [MMSegmentation v0.13.0](https://github.com/open-mmlab/mmsegmentation/tree/v0.13.0). 16 | 17 | 18 | - To generate the Cityscapes-C dataset, first install the natural image corruption lib via: 19 | 20 | pip install imagecorruptions 21 | 22 | Then, run the following command: 23 | 24 | ``` 25 | python tools/gen_city_c.py 26 | ``` 27 | 28 | Please see more details on generating Cityscapes-C in the [guidelines](https://github.com/NVlabs/FAN/tree/master/segmentation). 29 | 30 | 31 | - As for ACDC, please refer to [acdc.vision.ee.ethz.ch](https://acdc.vision.ee.ethz.ch) to get the test set and submit the results. 32 | 33 | ## Pretrained Model 34 | 35 | | Model | Cityscapes | Cityscapes-C | ACDC | 36 | |:-----------------:|:----------------:|:-----------------:|:---------------:| 37 | | [FAN-B-ViT (TAP & ADL)](https://github.com/guoyongcs/TAPADL/releases/download/v1.0/tapadl_fan_base_segmentation.pth) | **82.9** | **69.7** | **63.6** | 38 | 39 | Please put the pretrained model [tapadl_fan_base_segmentation.pth](https://github.com/guoyongcs/TAPADL/releases/download/v1.0/tapadl_fan_base_segmentation.pth) in ```../../pretrained```. 40 | 41 | 42 | ## Evaluation 43 | 44 | - Evaluate the pretrained model on Cityscapes: 45 | 46 | 1. Please specify the data path "data_root" in [cityscapes_1024x1024_repeat.py](https://github.com/guoyongcs/TAPADL/blob/main/TAPADL_FAN/segmentation/local_configs/_base_/datasets/cityscapes_1024x1024_repeat.py). 47 | 48 | 49 | 2. Evaluate the model via 50 | ``` 51 | CUDA_VISIBLE_DEVICES=0 python test_cityscapes.py \ 52 | local_configs/fan/fan_hybrid/tapfan_hybrid_base.1024x1024.city.160k.test.py \ 53 | ../../pretrained/tapadl_fan_base_segmentation.pth \ 54 | --eval mIoU --results-file output/ 55 | ``` 56 | 57 | 58 | 59 | - Evaluate the pretrained model on Cityscapes-C: 60 | 61 | 1. Please specify the data path "data_root" in [cityscapes_1024x1024_repeat_cityc.py](https://github.com/guoyongcs/TAPADL/blob/main/TAPADL_FAN/segmentation/local_configs/_base_/datasets/cityscapes_1024x1024_repeat_cityc.py). 62 | 63 | 64 | 2. Evaluate the model via 65 | ``` 66 | CUDA_VISIBLE_DEVICES=0 python test_cityscapes_c.py \ 67 | local_configs/fan/fan_hybrid/tapfan_hybrid_base.1024x1024.city.160k.test.cityc.py \ 68 | ../../pretrained/tapadl_fan_base_segmentation.pth \ 69 | --eval mIoU --results-file output/ 70 | ``` 71 | 72 | 73 | - Evaluate the pretrained model on ACDC (test set): 74 | 75 | 1. Please specify the data path "data_root" in [cityscapes_1024x1024_repeat_acdc.py](https://github.com/guoyongcs/TAPADL/blob/main/TAPADL_FAN/segmentation/local_configs/_base_/datasets/cityscapes_1024x1024_repeat_acdc.py). 76 | 77 | 78 | 2. Evaluate the model via 79 | ``` 80 | CUDA_VISIBLE_DEVICES=0 python test_cityscapes_c.py \ 81 | local_configs/fan/fan_hybrid/tapfan_hybrid_base.1024x1024.city.160k.test.acdc.py \ 82 | ../../pretrained/tapadl_fan_base_segmentation.pth \ 83 | --results-file output/ --show-dir output/ 84 | ``` 85 | 86 | 3. 
Submit the results to obtain the mIoU score on [acdc.vision.ee.ethz.ch](https://acdc.vision.ee.ethz.ch). 87 | 88 | 89 | 90 | ## Training 91 | 92 | Train FAN-B-Hybrid with TAP and ADL on Cityscapes (using 4 GPUs) 93 | 94 | ``` 95 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 train.py \ 96 | local_configs/fan/fan_hybrid/tapadl_fan_hybrid_base.1024x1024.city.160k.py \ 97 | --launcher pytorch --work-dir ./exp_tapadl_fan_base_segmentation_cityscapes 98 | ``` 99 | 100 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from mmcv.runner import Hook 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class EvalHook(Hook): 8 | """Evaluation hook. 9 | 10 | Attributes: 11 | dataloader (DataLoader): A PyTorch dataloader. 12 | interval (int): Evaluation interval (by epochs if ``by_epoch`` is True, otherwise by iterations). Default: 1. 13 | """ 14 | 15 | def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs): 16 | if not isinstance(dataloader, DataLoader): 17 | raise TypeError('dataloader must be a pytorch DataLoader, but got ' 18 | f'{type(dataloader)}') 19 | self.dataloader = dataloader 20 | self.interval = interval 21 | self.by_epoch = by_epoch 22 | self.eval_kwargs = eval_kwargs 23 | 24 | def after_train_iter(self, runner): 25 | """After train iteration hook.""" 26 | if self.by_epoch or not self.every_n_iters(runner, self.interval): 27 | return 28 | from mmseg.apis import single_gpu_test 29 | runner.log_buffer.clear() 30 | results = single_gpu_test(runner.model, self.dataloader, show=False) 31 | self.evaluate(runner, results) 32 | 33 | def after_train_epoch(self, runner): 34 | """After train epoch hook.""" 35 | if not self.by_epoch or not self.every_n_epochs(runner, self.interval): 36 | return 37 | from mmseg.apis import single_gpu_test 38 | runner.log_buffer.clear() 39 | results = single_gpu_test(runner.model, self.dataloader, show=False) 40 | self.evaluate(runner, results) 41 | 42 | def evaluate(self, runner, results): 43 | """Call evaluate function of dataset.""" 44 | eval_res = self.dataloader.dataset.evaluate( 45 | results, logger=runner.logger, **self.eval_kwargs) 46 | for name, val in eval_res.items(): 47 | runner.log_buffer.output[name] = val 48 | runner.log_buffer.ready = True 49 | 50 | 51 | class DistEvalHook(EvalHook): 52 | """Distributed evaluation hook. 53 | 54 | Attributes: 55 | dataloader (DataLoader): A PyTorch dataloader. 56 | interval (int): Evaluation interval (by epochs if ``by_epoch`` is True, otherwise by iterations). Default: 1. 57 | tmpdir (str | None): Temporary directory to save the results of all 58 | processes. Default: None. 59 | gpu_collect (bool): Whether to use gpu or cpu to collect results. 60 | Default: False. 
61 | """ 62 | 63 | def __init__(self, 64 | dataloader, 65 | interval=1, 66 | gpu_collect=False, 67 | by_epoch=False, 68 | **eval_kwargs): 69 | if not isinstance(dataloader, DataLoader): 70 | raise TypeError( 71 | 'dataloader must be a pytorch DataLoader, but got {}'.format( 72 | type(dataloader))) 73 | self.dataloader = dataloader 74 | self.interval = interval 75 | self.gpu_collect = gpu_collect 76 | self.by_epoch = by_epoch 77 | self.eval_kwargs = eval_kwargs 78 | 79 | def after_train_iter(self, runner): 80 | """After train iteration hook.""" 81 | if self.by_epoch or not self.every_n_iters(runner, self.interval): 82 | return 83 | from mmseg.apis import multi_gpu_test 84 | runner.log_buffer.clear() 85 | results = multi_gpu_test( 86 | runner.model, 87 | self.dataloader, 88 | tmpdir=osp.join(runner.work_dir, '.eval_hook'), 89 | gpu_collect=self.gpu_collect) 90 | if runner.rank == 0: 91 | print('\n') 92 | self.evaluate(runner, results) 93 | 94 | def after_train_epoch(self, runner): 95 | """After train epoch hook.""" 96 | if not self.by_epoch or not self.every_n_epochs(runner, self.interval): 97 | return 98 | from mmseg.apis import multi_gpu_test 99 | runner.log_buffer.clear() 100 | results = multi_gpu_test( 101 | runner.model, 102 | self.dataloader, 103 | tmpdir=osp.join(runner.work_dir, '.eval_hook'), 104 | gpu_collect=self.gpu_collect) 105 | if runner.rank == 0: 106 | print('\n') 107 | self.evaluate(runner, results) 108 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/utils/up_conv_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, build_upsample_layer 4 | 5 | 6 | class UpConvBlock(nn.Module): 7 | """Upsample convolution block in decoder for UNet. 8 | 9 | This upsample convolution block consists of one upsample module 10 | followed by one convolution block. The upsample module expands the 11 | high-level low-resolution feature map and the convolution block fuses 12 | the upsampled high-level low-resolution feature map and the low-level 13 | high-resolution feature map from encoder. 14 | 15 | Args: 16 | conv_block (nn.Sequential): Sequential of convolutional layers. 17 | in_channels (int): Number of input channels of the high-level low-resolution feature map to be upsampled. 18 | skip_channels (int): Number of input channels of the low-level 19 | high-resolution feature map from encoder. 20 | out_channels (int): Number of output channels. 21 | num_convs (int): Number of convolutional layers in the conv_block. 22 | Default: 2. 23 | stride (int): Stride of convolutional layer in conv_block. Default: 1. 24 | dilation (int): Dilation rate of convolutional layer in conv_block. 25 | Default: 1. 26 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 27 | memory while slowing down the training speed. Default: False. 28 | conv_cfg (dict | None): Config dict for convolution layer. 29 | Default: None. 30 | norm_cfg (dict | None): Config dict for normalization layer. 31 | Default: dict(type='BN'). 32 | act_cfg (dict | None): Config dict for activation layer in ConvModule. 33 | Default: dict(type='ReLU'). 34 | upsample_cfg (dict): The upsample config of the upsample module in 35 | decoder. Default: dict(type='InterpConv'). If the size of 36 | high-level feature map is the same as that of skip feature map 37 | (low-level feature map from encoder), it does not need to upsample the 38 | high-level feature map and the upsample_cfg is None. 
39 | dcn (bool): Use deformable convolution in convolutional layer or not. 40 | Default: None. 41 | plugins (dict): plugins for convolutional layers. Default: None. 42 | """ 43 | 44 | def __init__(self, 45 | conv_block, 46 | in_channels, 47 | skip_channels, 48 | out_channels, 49 | num_convs=2, 50 | stride=1, 51 | dilation=1, 52 | with_cp=False, 53 | conv_cfg=None, 54 | norm_cfg=dict(type='BN'), 55 | act_cfg=dict(type='ReLU'), 56 | upsample_cfg=dict(type='InterpConv'), 57 | dcn=None, 58 | plugins=None): 59 | super(UpConvBlock, self).__init__() 60 | assert dcn is None, 'Not implemented yet.' 61 | assert plugins is None, 'Not implemented yet.' 62 | 63 | self.conv_block = conv_block( 64 | in_channels=2 * skip_channels, 65 | out_channels=out_channels, 66 | num_convs=num_convs, 67 | stride=stride, 68 | dilation=dilation, 69 | with_cp=with_cp, 70 | conv_cfg=conv_cfg, 71 | norm_cfg=norm_cfg, 72 | act_cfg=act_cfg, 73 | dcn=None, 74 | plugins=None) 75 | if upsample_cfg is not None: 76 | self.upsample = build_upsample_layer( 77 | cfg=upsample_cfg, 78 | in_channels=in_channels, 79 | out_channels=skip_channels, 80 | with_cp=with_cp, 81 | norm_cfg=norm_cfg, 82 | act_cfg=act_cfg) 83 | else: 84 | self.upsample = ConvModule( 85 | in_channels, 86 | skip_channels, 87 | kernel_size=1, 88 | stride=1, 89 | padding=0, 90 | conv_cfg=conv_cfg, 91 | norm_cfg=norm_cfg, 92 | act_cfg=act_cfg) 93 | 94 | def forward(self, skip, x): 95 | """Forward function.""" 96 | 97 | x = self.upsample(x) 98 | out = torch.cat([skip, x], dim=1) 99 | out = self.conv_block(out) 100 | 101 | return out 102 | -------------------------------------------------------------------------------- /TAPADL_FAN/README.md: -------------------------------------------------------------------------------- 1 | # Training and evaluation of TAPADL based on FAN for image classification and semantic segmentation 2 | [Robustifying Token Attention for Vision Transformers](https://arxiv.org/pdf/2303.11126.pdf), \ 3 | [Yong Guo](http://www.guoyongcs.com/), [David Stutz](https://davidstutz.de/), and [Bernt Schiele](https://scholar.google.com/citations?user=z76PBfYAAAAJ&hl=en). ICCV 2023. 4 | 5 | 6 | 7 | # Dependencies 8 | Our code is built on top of PyTorch and the timm library. Please check the detailed dependencies in [requirements.txt](https://github.com/guoyongcs/TAPADL/blob/main/requirements.txt). 9 | 10 | # Dataset Preparation 11 | 12 | Please download the clean [ImageNet](http://image-net.org/) dataset and [ImageNet-C](https://zenodo.org/record/2235448) dataset and structure the datasets as follows: 13 | 14 | ``` 15 | /PATH/TO/IMAGENET-C/ 16 | clean/ 17 | class1/ 18 | img3.jpeg 19 | class2/ 20 | img4.jpeg 21 | corruption1/ 22 | severity1/ 23 | class1/ 24 | img3.jpeg 25 | class2/ 26 | img4.jpeg 27 | severity2/ 28 | class1/ 29 | img3.jpeg 30 | class2/ 31 | img4.jpeg 32 | ``` 33 | 34 | We also use other robustness benchmarks for evaluation, including [ImageNet-A](https://github.com/hendrycks/natural-adv-examples), [ImageNet-P](https://zenodo.org/record/3565846) and [ImageNet-R](https://github.com/hendrycks/imagenet-r). 
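For reference, the sketch below shows one way to traverse the ImageNet-C layout above for evaluation (a minimal sketch: the directory iteration assumes the corruption/severity folder structure shown, and the transform is only a typical 224x224 evaluation pipeline, not necessarily the one used in our scripts):

```
import os
from torchvision import datasets, transforms

root = '/PATH/TO/IMAGENET-C'
eval_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Walk every corruption/severity split; the 'clean' folder holds the
# uncorrupted validation images and is skipped here.
for corruption in sorted(os.listdir(root)):
    if corruption == 'clean':
        continue
    for severity in sorted(os.listdir(os.path.join(root, corruption))):
        split = os.path.join(root, corruption, severity)
        dataset = datasets.ImageFolder(split, transform=eval_tf)
        # evaluate your model on `dataset` here; averaging the errors over
        # severities gives the per-corruption error used for IN-C numbers
        print(corruption, severity, len(dataset))
```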
35 | 36 | 37 | 38 | ## Image Classification 39 | 40 | 41 | 42 | ### Pretrained Model 43 | 44 | | Model | #Params | IN-1K $\uparrow$ | IN-C $\downarrow$ | IN-A $\uparrow$ | IN-P $\downarrow$ | 45 | |:-----------------:|:----------------:|:-----------------:|:---------------:|:-----------------:|:-------:| 46 | | [FAN-B-ViT (TAP & ADL)](https://github.com/guoyongcs/TAPADL/releases/download/v1.0/tapadl_fan_base.pth.tar) | 50.7M | **84.3** | **43.7** | **42.3** | **29.2** | 47 | 48 | Please download and put the pretrained model [tapadl_fan_base.pth.tar](https://github.com/guoyongcs/TAPADL/releases/download/v1.0/tapadl_fan_base.pth.tar) in ```../pretrained```. 49 | 50 | 51 | ### Evaluation 52 | - Evaluate the pretrained model on ImageNet: 53 | ``` 54 | CUDA_VISIBLE_DEVICES=0 python validate_ood.py /PATH/TO/IMAGENET --model tap_fan_base_16_p4_hybrid \ 55 | --checkpoint ../pretrained/tapadl_fan_base.pth.tar --num-gpu 1 --amp --num-scales 4 56 | ``` 57 | 58 | - Evaluate the pretrained model on ImageNet-A/R: 59 | ``` 60 | CUDA_VISIBLE_DEVICES=0 python validate_ood.py /PATH/TO/IMAGENET-A --model tap_fan_base_16_p4_hybrid \ 61 | --checkpoint ../pretrained/tapadl_fan_base.pth.tar --num-gpu 1 --amp --num-scales 4 --imagenet_a 62 | ``` 63 | 64 | - Evaluate the pretrained model on ImageNet-C: 65 | ``` 66 | CUDA_VISIBLE_DEVICES=0 python validate_ood.py /PATH/TO/IMAGENET --model tap_fan_base_16_p4_hybrid \ 67 | --checkpoint ../pretrained/tapadl_fan_base.pth.tar --num-gpu 1 --imagenet_c \ 68 | --inc_path /PATH/TO/IMAGENET-C --amp --num-scales 4 69 | ``` 70 | 71 | - Evaluate the pretrained model on ImageNet-P: 72 | 73 | Please refer to [test.sh](https://github.com/hendrycks/robustness/blob/master/ImageNet-P/test.sh) to see how to evaluate models on ImageNet-P. 74 | 75 | 76 | 77 | 78 | ### Training 79 | Train FAN-B-Hybrid with TAP and ADL on ImageNet (using 8 nodes, each with 8 GPUs) 80 | ``` 81 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --node_rank=$NODE_RANK \ 82 | --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT main.py /PATH/TO/IMAGENET \ 83 | --model tap_fan_base_16_p4_hybrid -b 64 --sched cosine --opt adamw -j 16 --warmup-lr 1e-6 \ 84 | --warmup-epochs 10 --aa rand-m9-mstd0.5-inc1 --remode pixel --reprob 0.3 --lr 40e-4 \ 85 | --min-lr 1e-6 --weight-decay .05 --drop 0.0 --drop-path .35 --img-size 224 --mixup 0.8 \ 86 | --cutmix 1.0 --smoothing 0.1 --output ./experiments/exp_tapadl_fan_base_imagenet \ 87 | --amp --model-ema 88 | ``` 89 | 90 | 91 | 92 | ## Semantic Segmentation 93 | 94 | Please see details in [segmentation](https://github.com/guoyongcs/TAPADL/blob/main/TAPADL_FAN/segmentation). 
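As a quick local check that the segmentation package and its configs load correctly, a minimal sketch (run from `TAPADL_FAN/segmentation`, assuming `mmcv` is installed; `cfg.model.type` and `cfg.data.train.type` are standard MMSegmentation config keys):

```
import mmcv

# Parse the training config used in the segmentation README. This only checks
# that the config and its _base_ files resolve; it does not build the model.
cfg = mmcv.Config.fromfile(
    'local_configs/fan/fan_hybrid/tapadl_fan_hybrid_base.1024x1024.city.160k.py')
print(cfg.model.type)
print(cfg.data.train.type)
```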
95 | 96 | 97 | ## Citation 98 | If you find this repository helpful, please consider citing: 99 | ``` 100 | @inproceedings{guo2023robustifying, 101 | title={Robustifying token attention for vision transformers}, 102 | author={Guo, Yong and Stutz, David and Schiele, Bernt}, 103 | booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, 104 | year={2023} 105 | } 106 | ``` 107 | 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/apis/inference.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import mmcv 3 | import torch 4 | from mmcv.parallel import collate, scatter 5 | from mmcv.runner import load_checkpoint 6 | 7 | from mmseg.datasets.pipelines import Compose 8 | from mmseg.models import build_segmentor 9 | 10 | 11 | def init_segmentor(config, checkpoint=None, device='cuda:0'): 12 | """Initialize a segmentor from config file. 13 | 14 | Args: 15 | config (str or :obj:`mmcv.Config`): Config file path or the config 16 | object. 17 | checkpoint (str, optional): Checkpoint path. If left as None, the model 18 | will not load any weights. 19 | device (str, optional): CPU/CUDA device option. Default: 'cuda:0'. 20 | Use 'cpu' for loading the model on CPU. 21 | Returns: 22 | nn.Module: The constructed segmentor. 23 | """ 24 | if isinstance(config, str): 25 | config = mmcv.Config.fromfile(config) 26 | elif not isinstance(config, mmcv.Config): 27 | raise TypeError('config must be a filename or Config object, ' 28 | 'but got {}'.format(type(config))) 29 | config.model.pretrained = None 30 | config.model.train_cfg = None 31 | model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) 32 | if checkpoint is not None: 33 | checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') 34 | model.CLASSES = checkpoint['meta']['CLASSES'] 35 | model.PALETTE = checkpoint['meta']['PALETTE'] 36 | model.cfg = config  # save the config in the model for convenience 37 | model.to(device) 38 | model.eval() 39 | return model 40 | 41 | 42 | class LoadImage: 43 | """A simple pipeline to load image.""" 44 | 45 | def __call__(self, results): 46 | """Call function to load images into results. 47 | 48 | Args: 49 | results (dict): A result dict contains the file name 50 | of the image to be read. 51 | 52 | Returns: 53 | dict: ``results`` will be returned containing loaded image. 54 | """ 55 | 56 | if isinstance(results['img'], str): 57 | results['filename'] = results['img'] 58 | results['ori_filename'] = results['img'] 59 | else: 60 | results['filename'] = None 61 | results['ori_filename'] = None 62 | img = mmcv.imread(results['img']) 63 | results['img'] = img 64 | results['img_shape'] = img.shape 65 | results['ori_shape'] = img.shape 66 | return results 67 | 68 | 69 | def inference_segmentor(model, img): 70 | """Inference image(s) with the segmentor. 71 | 72 | Args: 73 | model (nn.Module): The loaded segmentor. 74 | img (str/ndarray or list[str/ndarray]): Either image files or loaded 75 | images. 76 | 77 | Returns: 78 | (list[Tensor]): The segmentation result. 
79 | """ 80 | cfg = model.cfg 81 | device = next(model.parameters()).device  # model device 82 | # build the data pipeline 83 | test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] 84 | test_pipeline = Compose(test_pipeline) 85 | # prepare data 86 | data = dict(img=img) 87 | data = test_pipeline(data) 88 | data = collate([data], samples_per_gpu=1) 89 | if next(model.parameters()).is_cuda: 90 | # scatter to specified GPU 91 | data = scatter(data, [device])[0] 92 | else: 93 | data['img_metas'] = [i.data[0] for i in data['img_metas']] 94 | 95 | # forward the model 96 | with torch.no_grad(): 97 | result = model(return_loss=False, rescale=True, **data) 98 | return result 99 | 100 | 101 | def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)): 102 | """Visualize the segmentation results on the image. 103 | 104 | Args: 105 | model (nn.Module): The loaded segmentor. 106 | img (str or np.ndarray): Image filename or loaded image. 107 | result (list): The segmentation result. 108 | palette (list[list[int]] | None): The palette of segmentation 109 | map. If None is given, a random palette will be generated. 110 | Default: None. 111 | fig_size (tuple): Figure size of the pyplot figure. 112 | """ 113 | if hasattr(model, 'module'): 114 | model = model.module 115 | img = model.show_result(img, result, palette=palette, show=False) 116 | plt.figure(figsize=fig_size) 117 | plt.imshow(mmcv.bgr2rgb(img)) 118 | plt.show() 119 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/uper_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmseg.ops import resize 6 | from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | from .psp_head import PPM 9 | 10 | 11 | @HEADS.register_module() 12 | class UPerHead(BaseDecodeHead): 13 | """Unified Perceptual Parsing for Scene Understanding. 14 | 15 | This head is the implementation of `UPerNet 16 | <https://arxiv.org/abs/1807.10221>`_. 17 | 18 | Args: 19 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 20 | Module applied on the last feature. Default: (1, 2, 3, 6). 
21 | """ 22 | 23 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 24 | super(UPerHead, self).__init__( 25 | input_transform='multiple_select', **kwargs) 26 | # PSP Module 27 | self.psp_modules = PPM( 28 | pool_scales, 29 | self.in_channels[-1], 30 | self.channels, 31 | conv_cfg=self.conv_cfg, 32 | norm_cfg=self.norm_cfg, 33 | act_cfg=self.act_cfg, 34 | align_corners=self.align_corners) 35 | self.bottleneck = ConvModule( 36 | self.in_channels[-1] + len(pool_scales) * self.channels, 37 | self.channels, 38 | 3, 39 | padding=1, 40 | conv_cfg=self.conv_cfg, 41 | norm_cfg=self.norm_cfg, 42 | act_cfg=self.act_cfg) 43 | # FPN Module 44 | self.lateral_convs = nn.ModuleList() 45 | self.fpn_convs = nn.ModuleList() 46 | for in_channels in self.in_channels[:-1]: # skip the top layer 47 | l_conv = ConvModule( 48 | in_channels, 49 | self.channels, 50 | 1, 51 | conv_cfg=self.conv_cfg, 52 | norm_cfg=self.norm_cfg, 53 | act_cfg=self.act_cfg, 54 | inplace=False) 55 | fpn_conv = ConvModule( 56 | self.channels, 57 | self.channels, 58 | 3, 59 | padding=1, 60 | conv_cfg=self.conv_cfg, 61 | norm_cfg=self.norm_cfg, 62 | act_cfg=self.act_cfg, 63 | inplace=False) 64 | self.lateral_convs.append(l_conv) 65 | self.fpn_convs.append(fpn_conv) 66 | 67 | self.fpn_bottleneck = ConvModule( 68 | len(self.in_channels) * self.channels, 69 | self.channels, 70 | 3, 71 | padding=1, 72 | conv_cfg=self.conv_cfg, 73 | norm_cfg=self.norm_cfg, 74 | act_cfg=self.act_cfg) 75 | 76 | def psp_forward(self, inputs): 77 | """Forward function of PSP module.""" 78 | x = inputs[-1] 79 | psp_outs = [x] 80 | psp_outs.extend(self.psp_modules(x)) 81 | psp_outs = torch.cat(psp_outs, dim=1) 82 | output = self.bottleneck(psp_outs) 83 | 84 | return output 85 | 86 | def forward(self, inputs): 87 | """Forward function.""" 88 | 89 | inputs = self._transform_inputs(inputs) 90 | 91 | # build laterals 92 | laterals = [ 93 | lateral_conv(inputs[i]) 94 | for i, lateral_conv in enumerate(self.lateral_convs) 95 | ] 96 | 97 | laterals.append(self.psp_forward(inputs)) 98 | 99 | # build top-down path 100 | used_backbone_levels = len(laterals) 101 | for i in range(used_backbone_levels - 1, 0, -1): 102 | prev_shape = laterals[i - 1].shape[2:] 103 | laterals[i - 1] += resize( 104 | laterals[i], 105 | size=prev_shape, 106 | mode='bilinear', 107 | align_corners=self.align_corners) 108 | 109 | # build outputs 110 | fpn_outs = [ 111 | self.fpn_convs[i](laterals[i]) 112 | for i in range(used_backbone_levels - 1) 113 | ] 114 | # append psp feature 115 | fpn_outs.append(laterals[-1]) 116 | 117 | for i in range(used_backbone_levels - 1, 0, -1): 118 | fpn_outs[i] = resize( 119 | fpn_outs[i], 120 | size=fpn_outs[0].shape[2:], 121 | mode='bilinear', 122 | align_corners=self.align_corners) 123 | fpn_outs = torch.cat(fpn_outs, dim=1) 124 | output = self.fpn_bottleneck(fpn_outs) 125 | output = self.cls_seg(output) 126 | return output 127 | -------------------------------------------------------------------------------- /TAPADL_FAN/myutils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers. 3 | 4 | Mostly copy-paste from torchvision references. 
5 | """ 6 | import io 7 | import os 8 | import time 9 | from collections import defaultdict, deque 10 | import datetime 11 | import cv2 12 | from PIL import Image 13 | 14 | import torch 15 | import torch.distributed as dist 16 | import numpy as np 17 | from torchvision import datasets, transforms 18 | 19 | import torch.nn.functional as F 20 | from timm.utils import * 21 | 22 | 23 | class SmoothedValue(object): 24 | """Track a series of values and provide access to smoothed values over a 25 | window or the global series average. 26 | """ 27 | 28 | def __init__(self, window_size=20, fmt=None): 29 | if fmt is None: 30 | fmt = "{value:.2f} ({global_avg:.2f})" 31 | self.deque = deque(maxlen=window_size) 32 | self.total = 0.0 33 | self.count = 0 34 | self.fmt = fmt 35 | 36 | def update(self, value, n=1): 37 | self.deque.append(value) 38 | self.count += n 39 | self.total += value * n 40 | 41 | def synchronize_between_processes(self): 42 | """ 43 | Warning: does not synchronize the deque! 44 | """ 45 | if not is_dist_avail_and_initialized(): 46 | return 47 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 48 | dist.barrier() 49 | dist.all_reduce(t) 50 | t = t.tolist() 51 | self.count = int(t[0]) 52 | self.total = t[1] 53 | 54 | @property 55 | def median(self): 56 | d = torch.tensor(list(self.deque)) 57 | return d.median().item() 58 | 59 | @property 60 | def avg(self): 61 | d = torch.tensor(list(self.deque), dtype=torch.float32) 62 | return d.mean().item() 63 | 64 | @property 65 | def global_avg(self): 66 | return self.total / self.count 67 | 68 | @property 69 | def max(self): 70 | return max(self.deque) 71 | 72 | @property 73 | def value(self): 74 | return self.deque[-1] 75 | 76 | def __str__(self): 77 | return self.fmt.format( 78 | median=self.median, 79 | avg=self.avg, 80 | global_avg=self.global_avg, 81 | max=self.max, 82 | value=self.value) 83 | 84 | 85 | def is_dist_avail_and_initialized(): 86 | if not dist.is_available(): 87 | return False 88 | if not dist.is_initialized(): 89 | return False 90 | return True 91 | 92 | 93 | def get_world_size(): 94 | if not is_dist_avail_and_initialized(): 95 | return 1 96 | return dist.get_world_size() 97 | 98 | 99 | def get_rank(): 100 | if not is_dist_avail_and_initialized(): 101 | return 0 102 | return dist.get_rank() 103 | 104 | 105 | try: 106 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 107 | has_apex = True 108 | except ImportError: 109 | has_apex = False 110 | 111 | 112 | class MyCheckpointSaver(CheckpointSaver): 113 | def __init__( 114 | self, 115 | model, 116 | optimizer, 117 | args=None, 118 | model_ema=None, 119 | amp_scaler=None, 120 | checkpoint_prefix='checkpoint', 121 | recovery_prefix='recovery', 122 | checkpoint_dir='', 123 | recovery_dir='', 124 | decreasing=False, 125 | max_history=10, 126 | unwrap_fn=unwrap_model): 127 | 128 | super().__init__(model, 129 | optimizer, 130 | args, 131 | model_ema, 132 | amp_scaler, 133 | checkpoint_prefix, 134 | recovery_prefix, 135 | checkpoint_dir, 136 | recovery_dir, 137 | decreasing, 138 | max_history, 139 | unwrap_fn) 140 | 141 | def load_checkpoint_files(self): 142 | tmp_save_path = os.path.join(self.checkpoint_dir, 'checkpoint_files' + self.extension) 143 | file_to_load = torch.load(tmp_save_path) 144 | self.checkpoint_files = file_to_load['checkpoint_files'] 145 | self.best_metric = file_to_load['best_metric'] 146 | 147 | def save_checkpoint_files(self): 148 | tmp_save_path = os.path.join(self.checkpoint_dir, 'checkpoint_files' + 
self.extension) 149 | file_to_save = {'checkpoint_files': self.checkpoint_files, 'best_metric': self.best_metric} 150 | torch.save(file_to_save, tmp_save_path) 151 | 152 | 153 | -------------------------------------------------------------------------------- /TAPADL_RVT/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import math 4 | 5 | 6 | class RASampler(torch.utils.data.Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset for distributed, 8 | with repeated augmentation. 9 | It ensures that each augmented version of a sample will be visible to a 10 | different process (GPU). 11 | Heavily based on torch.utils.data.DistributedSampler. 12 | """ 13 | 14 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 15 | if num_replicas is None: 16 | if not dist.is_available(): 17 | raise RuntimeError("Requires distributed package to be available") 18 | num_replicas = dist.get_world_size() 19 | if rank is None: 20 | if not dist.is_available(): 21 | raise RuntimeError("Requires distributed package to be available") 22 | rank = dist.get_rank() 23 | self.dataset = dataset 24 | self.num_replicas = num_replicas 25 | self.rank = rank 26 | self.epoch = 0 27 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 28 | self.total_size = self.num_samples * self.num_replicas 29 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 30 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 31 | self.shuffle = shuffle 32 | 33 | def __iter__(self): 34 | # deterministically shuffle based on epoch 35 | g = torch.Generator() 36 | g.manual_seed(self.epoch) 37 | if self.shuffle: 38 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 39 | else: 40 | indices = list(range(len(self.dataset))) 41 | 42 | # add extra samples to make it evenly divisible 43 | indices = [ele for ele in indices for i in range(3)] 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | indices = indices[self.rank:self.total_size:self.num_replicas] 49 | assert len(indices) == self.num_samples 50 | 51 | return iter(indices[:self.num_selected_samples]) 52 | 53 | def __len__(self): 54 | return self.num_selected_samples 55 | 56 | def set_epoch(self, epoch): 57 | self.epoch = epoch 58 | 59 | 60 | class RepRASampler(torch.utils.data.Sampler): 61 | """Sampler that restricts data loading to a subset of the dataset for distributed, 62 | with repeated augmentation. 
63 | It ensures that each augmented version of a sample will be visible to a 64 | different process (GPU). 65 | Heavily based on torch.utils.data.DistributedSampler. 66 | """ 67 | 68 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 69 | if num_replicas is None: 70 | if not dist.is_available(): 71 | raise RuntimeError("Requires distributed package to be available") 72 | num_replicas = dist.get_world_size() 73 | if rank is None: 74 | if not dist.is_available(): 75 | raise RuntimeError("Requires distributed package to be available") 76 | rank = dist.get_rank() 77 | self.dataset = dataset 78 | self.num_replicas = num_replicas 79 | self.rank = rank 80 | self.epoch = 0 81 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 82 | self.total_size = self.num_samples * self.num_replicas 83 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 84 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 85 | self.shuffle = shuffle 86 | 87 | def __iter__(self): 88 | # deterministically shuffle based on epoch 89 | g = torch.Generator() 90 | g.manual_seed(self.epoch) 91 | if self.shuffle: 92 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 93 | else: 94 | indices = list(range(len(self.dataset))) 95 | 96 | # add extra samples to make it evenly divisible 97 | indices = [ele for ele in indices for i in range(3)] 98 | indices += indices[:(self.total_size - len(indices))] 99 | assert len(indices) == self.total_size 100 | 101 | # subsample 102 | indices = indices[self.rank:self.total_size:self.num_replicas] 103 | assert len(indices) == self.num_samples 104 | 105 | return iter(indices) 106 | 107 | def __len__(self): 108 | return self.num_samples 109 | 110 | def set_epoch(self, epoch): 111 | self.epoch = epoch 112 | 113 | 114 | -------------------------------------------------------------------------------- /TAPADL_FAN/segmentation/mmseg/models/decode_heads/ocr_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.ops import resize 7 | from ..builder import HEADS 8 | from ..utils import SelfAttentionBlock as _SelfAttentionBlock 9 | from .cascade_decode_head import BaseCascadeDecodeHead 10 | 11 | 12 | class SpatialGatherModule(nn.Module): 13 | """Aggregate the context features according to the initial predicted 14 | probability distribution. 15 | 16 | Employ the soft-weighted method to aggregate the context. 
17 | """ 18 | 19 | def __init__(self, scale): 20 | super(SpatialGatherModule, self).__init__() 21 | self.scale = scale 22 | 23 | def forward(self, feats, probs): 24 | """Forward function.""" 25 | batch_size, num_classes, height, width = probs.size() 26 | channels = feats.size(1) 27 | probs = probs.view(batch_size, num_classes, -1) 28 | feats = feats.view(batch_size, channels, -1) 29 | # [batch_size, height*width, channels] 30 | feats = feats.permute(0, 2, 1) 31 | # [batch_size, num_classes, height*width] 32 | probs = F.softmax(self.scale * probs, dim=2) 33 | # [batch_size, num_classes, channels] 34 | ocr_context = torch.matmul(probs, feats) 35 | ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3) 36 | return ocr_context 37 | 38 | 39 | class ObjectAttentionBlock(_SelfAttentionBlock): 40 | """Make a SelfAttentionBlock used in OCR.""" 41 | 42 | def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, 43 | act_cfg): 44 | if scale > 1: 45 | query_downsample = nn.MaxPool2d(kernel_size=scale) 46 | else: 47 | query_downsample = None 48 | super(ObjectAttentionBlock, self).__init__( 49 | key_in_channels=in_channels, 50 | query_in_channels=in_channels, 51 | channels=channels, 52 | out_channels=in_channels, 53 | share_key_query=False, 54 | query_downsample=query_downsample, 55 | key_downsample=None, 56 | key_query_num_convs=2, 57 | key_query_norm=True, 58 | value_out_num_convs=1, 59 | value_out_norm=True, 60 | matmul_norm=True, 61 | with_out=True, 62 | conv_cfg=conv_cfg, 63 | norm_cfg=norm_cfg, 64 | act_cfg=act_cfg) 65 | self.bottleneck = ConvModule( 66 | in_channels * 2, 67 | in_channels, 68 | 1, 69 | conv_cfg=self.conv_cfg, 70 | norm_cfg=self.norm_cfg, 71 | act_cfg=self.act_cfg) 72 | 73 | def forward(self, query_feats, key_feats): 74 | """Forward function.""" 75 | context = super(ObjectAttentionBlock, 76 | self).forward(query_feats, key_feats) 77 | output = self.bottleneck(torch.cat([context, query_feats], dim=1)) 78 | if self.query_downsample is not None: 79 | output = resize(query_feats)  # only reached when scale > 1 80 | 81 | return output 82 | 83 | 84 | @HEADS.register_module() 85 | class OCRHead(BaseCascadeDecodeHead): 86 | """Object-Contextual Representations for Semantic Segmentation. 87 | 88 | This head is the implementation of `OCRNet 89 | <https://arxiv.org/abs/1909.11065>`_. 90 | 91 | Args: 92 | ocr_channels (int): The intermediate channels of OCR block. 93 | scale (int): The scale of the probability map in SpatialGatherModule. 94 | Default: 1. 95 | """ 96 | 97 | def __init__(self, ocr_channels, scale=1, **kwargs): 98 | super(OCRHead, self).__init__(**kwargs) 99 | self.ocr_channels = ocr_channels 100 | self.scale = scale 101 | self.object_context_block = ObjectAttentionBlock( 102 | self.channels, 103 | self.ocr_channels, 104 | self.scale, 105 | conv_cfg=self.conv_cfg, 106 | norm_cfg=self.norm_cfg, 107 | act_cfg=self.act_cfg) 108 | self.spatial_gather_module = SpatialGatherModule(self.scale) 109 | 110 | self.bottleneck = ConvModule( 111 | self.in_channels, 112 | self.channels, 113 | 3, 114 | padding=1, 115 | conv_cfg=self.conv_cfg, 116 | norm_cfg=self.norm_cfg, 117 | act_cfg=self.act_cfg) 118 | 119 | def forward(self, inputs, prev_output): 120 | """Forward function.""" 121 | x = self._transform_inputs(inputs) 122 | feats = self.bottleneck(x) 123 | context = self.spatial_gather_module(feats, prev_output) 124 | object_context = self.object_context_block(feats, context) 125 | output = self.cls_seg(object_context) 126 | 127 | return output 128 | --------------------------------------------------------------------------------