├── AgriFM
│   ├── losses
│   │   ├── __init__.py
│   │   └── loss.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── mapping_dataset.py
│   ├── __init__.py
│   ├── evaluation
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── encoders.py
│   │   ├── heads.py
│   │   ├── multi_unified_model.py
│   │   └── neck.py
│   └── utils
│       └── path_utils.py
├── resources
│   └── AgriFM.png
├── mmseg
│   ├── utils
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── collect_env.py
│   │   ├── typing_utils.py
│   │   ├── io.py
│   │   ├── set_env.py
│   │   ├── __init__.py
│   │   └── get_templates.py
│   ├── engine
│   │   ├── schedulers
│   │   │   ├── __init__.py
│   │   │   └── poly_ratio_scheduler.py
│   │   ├── hooks
│   │   │   ├── __init__.py
│   │   │   └── visualization_hook.py
│   │   ├── optimizers
│   │   │   └── __init__.py
│   │   └── __init__.py
│   ├── models
│   │   ├── text_encoder
│   │   │   └── __init__.py
│   │   ├── necks
│   │   │   ├── __init__.py
│   │   │   ├── featurepyramid.py
│   │   │   ├── multilevel_neck.py
│   │   │   └── mla_neck.py
│   │   ├── assigners
│   │   │   ├── __init__.py
│   │   │   ├── base_assigner.py
│   │   │   └── hungarian_assigner.py
│   │   ├── segmentors
│   │   │   ├── __init__.py
│   │   │   └── seg_tta.py
│   │   ├── __init__.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── boundary_loss.py
│   │   │   ├── ohem_cross_entropy_loss.py
│   │   │   ├── accuracy.py
│   │   │   ├── kldiv_loss.py
│   │   │   ├── silog_loss.py
│   │   │   └── utils.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── make_divisible.py
│   │   │   ├── wrappers.py
│   │   │   ├── se_layer.py
│   │   │   ├── encoding.py
│   │   │   ├── res_layer.py
│   │   │   ├── shape_convert.py
│   │   │   ├── point_sample.py
│   │   │   └── up_conv_block.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   └── timm_backbone.py
│   │   ├── decode_heads
│   │   │   ├── cc_head.py
│   │   │   ├── nl_head.py
│   │   │   ├── gc_head.py
│   │   │   ├── __init__.py
│   │   │   ├── segformer_head.py
│   │   │   ├── setr_mla_head.py
│   │   │   ├── cascade_decode_head.py
│   │   │   ├── sep_fcn_head.py
│   │   │   ├── fpn_head.py
│   │   │   ├── setr_up_head.py
│   │   │   ├── lraspp_head.py
│   │   │   ├── fcn_head.py
│   │   │   ├── sep_aspp_head.py
│   │   │   ├── stdc_head.py
│   │   │   ├── psp_head.py
│   │   │   ├── aspp_head.py
│   │   │   └── ddr_head.py
│   │   └── builder.py
│   ├── visualization
│   │   └── __init__.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   └── metrics
│   │       └── __init__.py
│   ├── structures
│   │   ├── sampler
│   │   │   ├── __init__.py
│   │   │   ├── base_pixel_sampler.py
│   │   │   ├── builder.py
│   │   │   └── ohem_pixel_sampler.py
│   │   ├── __init__.py
│   │   └── seg_data_sample.py
│   ├── apis
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── datasets
│   │   ├── dark_zurich.py
│   │   ├── night_driving.py
│   │   ├── refuge.py
│   │   ├── levir.py
│   │   ├── isprs.py
│   │   ├── potsdam.py
│   │   ├── loveda.py
│   │   ├── hrf.py
│   │   ├── stare.py
│   │   ├── drive.py
│   │   ├── chase_db1.py
│   │   ├── synapse.py
│   │   ├── bdd100k.py
│   │   ├── cityscapes.py
│   │   ├── lip.py
│   │   ├── voc.py
│   │   ├── isaid.py
│   │   ├── transforms
│   │   │   ├── __init__.py
│   │   │   └── formatting.py
│   │   ├── __init__.py
│   │   ├── decathlon.py
│   │   └── dsdl.py
│   ├── version.py
│   ├── registry
│   │   └── __init__.py
│   └── __init__.py
├── inference.py
└── test.py

/AgriFM/losses/__init__.py:
--------------------------------------------------------------------------------
from .loss import CropCEloss
__all__=['CropCEloss']
--------------------------------------------------------------------------------
/resources/AgriFM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flyakon/AgriFM/HEAD/resources/AgriFM.png
--------------------------------------------------------------------------------
/AgriFM/datasets/__init__.py:
--------------------------------------------------------------------------------
from .mapping_dataset import MappingDataset

__all__ = ['MappingDataset']
--------------------------------------------------------------------------------
/AgriFM/__init__.py:
--------------------------------------------------------------------------------
from . import datasets,evaluation,models,losses
__all__ = ['datasets', 'evaluation', 'models','losses']
--------------------------------------------------------------------------------
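
Importing the ``AgriFM`` package is what actually registers its datasets, models
and losses: the ``@...register_module()`` decorators in the submodules run as an
import side effect. A sketch of how a training config would trigger that import
(the ``custom_imports`` field is standard mmengine; this fragment is
illustrative, not a config shipped in the repository):

# Illustrative config fragment, not a file from this repository.
custom_imports = dict(imports=['AgriFM'], allow_failed_imports=False)
model = dict(type='MultiUnifiedModel')  # resolvable only after AgriFM is imported
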
/mmseg/utils/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flyakon/AgriFM/HEAD/mmseg/utils/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/AgriFM/evaluation/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .metric.iou_metric import CropIoUMetric

__all__ = ['CropIoUMetric',]
--------------------------------------------------------------------------------
/mmseg/engine/schedulers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .poly_ratio_scheduler import PolyLRRatio

__all__ = ['PolyLRRatio']
--------------------------------------------------------------------------------
/mmseg/engine/hooks/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .visualization_hook import SegVisualizationHook

__all__ = ['SegVisualizationHook']
--------------------------------------------------------------------------------
/mmseg/models/text_encoder/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .clip_text_encoder import CLIPTextEncoder

__all__ = ['CLIPTextEncoder']
--------------------------------------------------------------------------------
/mmseg/visualization/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .local_visualizer import SegLocalVisualizer,CommonVisual

__all__ = ['SegLocalVisualizer','CommonVisual']
--------------------------------------------------------------------------------
/mmseg/evaluation/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .metrics import CityscapesMetric, DepthMetric, IoUMetric

__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
--------------------------------------------------------------------------------
/mmseg/evaluation/metrics/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .citys_metric import CityscapesMetric
from .depth_metric import DepthMetric
from .iou_metric import IoUMetric

__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
--------------------------------------------------------------------------------
/mmseg/structures/sampler/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base_pixel_sampler import BasePixelSampler
from .builder import build_pixel_sampler
from .ohem_pixel_sampler import OHEMPixelSampler

__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
--------------------------------------------------------------------------------
/mmseg/structures/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler
from .seg_data_sample import SegDataSample

__all__ = [
    'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler',
    'build_pixel_sampler'
]
--------------------------------------------------------------------------------
/mmseg/models/necks/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .featurepyramid import Feature2Pyramid
from .fpn import FPN
from .ic_neck import ICNeck
from .jpu import JPU
from .mla_neck import MLANeck
from .multilevel_neck import MultiLevelNeck

__all__ = [
    'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid'
]
--------------------------------------------------------------------------------
/mmseg/apis/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model, show_result_pyplot
from .mmseg_inferencer import MMSegInferencer
from .remote_sense_inferencer import RSImage, RSInferencer

__all__ = [
    'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
    'RSInferencer', 'RSImage'
]
--------------------------------------------------------------------------------
/mmseg/models/assigners/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base_assigner import BaseAssigner
from .hungarian_assigner import HungarianAssigner
from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost

__all__ = [
    'BaseAssigner',
    'HungarianAssigner',
    'ClassificationCost',
    'CrossEntropyLossCost',
    'DiceCost',
]
--------------------------------------------------------------------------------
/mmseg/structures/sampler/base_pixel_sampler.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod


class BasePixelSampler(metaclass=ABCMeta):
    """Base class of pixel sampler."""

    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function."""
--------------------------------------------------------------------------------
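
``BasePixelSampler`` only pins down the ``sample`` contract. A hypothetical
subclass, purely to illustrate the expected shapes; ``OHEMPixelSampler``
(listed in the tree) is the real in-tree implementation:

import torch

from mmseg.structures.sampler import BasePixelSampler


class RandomPixelSampler(BasePixelSampler):
    """Illustrative only: keep each pixel with probability ``p``."""

    def __init__(self, p=0.5, **kwargs):
        super().__init__(**kwargs)
        self.p = p

    def sample(self, seg_logit, seg_label):
        # seg_logit: (N, C, H, W) logits, seg_label: (N, 1, H, W) labels;
        # returns per-pixel weights that the loss then applies.
        return (torch.rand_like(seg_label.float()) < self.p).float()
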
/AgriFM/models/__init__.py:
--------------------------------------------------------------------------------

from .video_swin_transformer import PretrainingSwinTransformer3DEncoder,SwinPatchEmbed3D
from .encoders import MultiModalEncoder
from .neck import MultiFusionNeck
from .heads import CropFCNHead
from .multi_unified_model import MultiUnifiedModel
__all__=['CropFCNHead','MultiFusionNeck','MultiModalEncoder',
         'PretrainingSwinTransformer3DEncoder','SwinPatchEmbed3D',
         'MultiUnifiedModel',]
--------------------------------------------------------------------------------
/mmseg/engine/optimizers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .force_default_constructor import ForceDefaultOptimWrapperConstructor
from .layer_decay_optimizer_constructor import (
    LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)

__all__ = [
    'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
    'ForceDefaultOptimWrapperConstructor'
]
--------------------------------------------------------------------------------
/mmseg/structures/sampler/builder.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmseg.registry import TASK_UTILS

PIXEL_SAMPLERS = TASK_UTILS


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map."""
    warnings.warn(
        '``build_pixel_sampler`` would be deprecated soon, please use '
        '``mmseg.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)
--------------------------------------------------------------------------------
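
``build_pixel_sampler`` is a thin deprecation shim over the ``TASK_UTILS``
registry; decode heads pass themselves in as the ``context`` default argument.
A hedged sketch (cfg values illustrative; ``head`` stands for the calling
decode head):

from mmseg.registry import TASK_UTILS
from mmseg.structures import build_pixel_sampler

cfg = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)
sampler = build_pixel_sampler(cfg, context=head)  # legacy path, emits a warning
sampler = TASK_UTILS.build(cfg, default_args=dict(context=head))  # preferred
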
/mmseg/models/segmentors/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseSegmentor
from .cascade_encoder_decoder import CascadeEncoderDecoder
from .depth_estimator import DepthEstimator
from .encoder_decoder import EncoderDecoder
from .multimodal_encoder_decoder import MultimodalEncoderDecoder
from .seg_tta import SegTTAModel

__all__ = [
    'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel',
    'MultimodalEncoderDecoder', 'DepthEstimator'
]
--------------------------------------------------------------------------------
/mmseg/engine/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import SegVisualizationHook
from .optimizers import (ForceDefaultOptimWrapperConstructor,
                         LayerDecayOptimizerConstructor,
                         LearningRateDecayOptimizerConstructor)
from .schedulers import PolyLRRatio

__all__ = [
    'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
    'SegVisualizationHook', 'PolyLRRatio',
    'ForceDefaultOptimWrapperConstructor'
]
--------------------------------------------------------------------------------
/AgriFM/losses/loss.py:
--------------------------------------------------------------------------------

import torch
from mmseg.models.builder import LOSSES
import torch.nn as nn
import numpy as np



@LOSSES.register_module()
class CropCEloss(torch.nn.Module):
    def __init__(self,ignore_index=-1):
        super().__init__()
        self.criterion=torch.nn.CrossEntropyLoss(reduction='none',ignore_index=ignore_index)

    def forward(self,pred,label):
        loss=self.criterion(pred,label)
        return {'crop_ce_loss': loss.mean() if isinstance(loss, torch.Tensor) else np.mean(loss)}
--------------------------------------------------------------------------------
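
A minimal sketch of ``CropCEloss`` used standalone (shapes illustrative). One
behavioural detail worth knowing: because ``reduction='none'`` is followed by a
plain ``.mean()``, pixels equal to ``ignore_index`` contribute zeros to the
numerator but still count in the denominator:

import torch

from AgriFM.losses import CropCEloss

loss_fn = CropCEloss(ignore_index=-1)
logits = torch.randn(2, 4, 64, 64)         # (N, num_classes, H, W)
labels = torch.randint(0, 4, (2, 64, 64))  # (N, H, W)
print(loss_fn(logits, labels)['crop_ce_loss'])  # scalar tensor
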
/mmseg/datasets/dark_zurich.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .cityscapes import CityscapesDataset


@DATASETS.register_module()
class DarkZurichDataset(CityscapesDataset):
    """DarkZurichDataset dataset."""

    def __init__(self,
                 img_suffix='_rgb_anon.png',
                 seg_map_suffix='_gt_labelTrainIds.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
--------------------------------------------------------------------------------
/mmseg/utils/collect_env.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env as collect_base_env

import mmseg


def collect_env():
    """Collect the information of the running environments."""
    env_info = collect_base_env()
    env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'

    return env_info


if __name__ == '__main__':
    for name, val in collect_env().items():
        print(f'{name}: {val}')
--------------------------------------------------------------------------------
/mmseg/version.py:
--------------------------------------------------------------------------------
# Copyright (c) Open-MMLab. All rights reserved.

__version__ = '1.2.0'


def parse_version_info(version_str):
    version_info = []
    for x in version_str.split('.'):
        if x.isdigit():
            version_info.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            version_info.append(int(patch_version[0]))
            version_info.append(f'rc{patch_version[1]}')
    return tuple(version_info)


version_info = parse_version_info(__version__)
--------------------------------------------------------------------------------
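
``parse_version_info`` splits on dots and peels a release-candidate suffix into
its own tuple entries; two worked examples, runnable as-is:

from mmseg.version import parse_version_info

assert parse_version_info('1.2.0') == (1, 2, 0)
# '0rc4' is not all digits but contains 'rc', so it becomes 0 followed by 'rc4':
assert parse_version_info('1.2.0rc4') == (1, 2, 0, 'rc4')
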
2 | """Collecting some commonly used type hint in mmflow.""" 3 | from typing import Dict, List, Optional, Sequence, Tuple, Union 4 | 5 | import torch 6 | from mmengine.config import ConfigDict 7 | 8 | from mmseg.structures import SegDataSample 9 | 10 | # Type hint of config data 11 | ConfigType = Union[ConfigDict, dict] 12 | OptConfigType = Optional[ConfigType] 13 | # Type hint of one or more config data 14 | MultiConfig = Union[ConfigType, Sequence[ConfigType]] 15 | OptMultiConfig = Optional[MultiConfig] 16 | 17 | SampleList = Sequence[SegDataSample] 18 | OptSampleList = Optional[SampleList] 19 | 20 | # Type hint of Tensor 21 | TensorDict = Dict[str, torch.Tensor] 22 | TensorList = Sequence[torch.Tensor] 23 | 24 | ForwardResults = Union[Dict[str, torch.Tensor], List[SegDataSample], 25 | Tuple[torch.Tensor], torch.Tensor] 26 | -------------------------------------------------------------------------------- /AgriFM/models/encoders.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from typing import List, Tuple, Optional, Dict, Union 5 | from mmseg.registry import MODELS 6 | from mmengine.model import BaseModule,BaseModel 7 | 8 | 9 | @MODELS.register_module() 10 | class MultiModalEncoder(BaseModel): 11 | def __init__(self,encoders_cfg): 12 | super().__init__() 13 | self.encoders=nn.ModuleDict() 14 | for name,cfg in encoders_cfg.items(): 15 | self.encoders[name]=MODELS.build(cfg) 16 | def forward(self, 17 | inputs:dict, 18 | data_samples: Optional[list] = None, 19 | mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]: 20 | outputs={} 21 | for name,encoder in self.encoders.items(): 22 | 23 | inputs_data=inputs[name] 24 | outputs[name]=encoder(inputs_data,data_samples,mode) 25 | if name in inputs.keys(): 26 | inputs.pop(name) 27 | outputs.update(inputs) 28 | return outputs -------------------------------------------------------------------------------- /mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .accuracy import Accuracy, accuracy 3 | from .boundary_loss import BoundaryLoss 4 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 5 | cross_entropy, mask_cross_entropy) 6 | from .dice_loss import DiceLoss 7 | from .focal_loss import FocalLoss 8 | from .huasdorff_distance_loss import HuasdorffDisstanceLoss 9 | from .lovasz_loss import LovaszLoss 10 | from .ohem_cross_entropy_loss import OhemCrossEntropy 11 | from .silog_loss import SiLogLoss 12 | from .tversky_loss import TverskyLoss 13 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 14 | 15 | __all__ = [ 16 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 17 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 18 | 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', 19 | 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss', 20 | 'HuasdorffDisstanceLoss', 'SiLogLoss' 21 | ] 22 | -------------------------------------------------------------------------------- /mmseg/datasets/refuge.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
/mmseg/models/losses/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .accuracy import Accuracy, accuracy
from .boundary_loss import BoundaryLoss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                 cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss
from .huasdorff_distance_loss import HuasdorffDisstanceLoss
from .lovasz_loss import LovaszLoss
from .ohem_cross_entropy_loss import OhemCrossEntropy
from .silog_loss import SiLogLoss
from .tversky_loss import TverskyLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss

__all__ = [
    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
    'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
    'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
    'HuasdorffDisstanceLoss', 'SiLogLoss'
]
--------------------------------------------------------------------------------
/mmseg/datasets/refuge.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class REFUGEDataset(BaseSegDataset):
    """REFUGE dataset.

    In segmentation map annotation for REFUGE, 0 stands for background, which
    is included in 3 categories. ``reduce_zero_label`` is fixed to False.
    The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
    '.png'.
    """
    METAINFO = dict(
        classes=('background', 'Optic Cup', 'Optic Disc'),
        palette=[[120, 120, 120], [6, 230, 230], [56, 59, 120]])

    def __init__(self, **kwargs) -> None:
        super().__init__(
            img_suffix='.png',
            seg_map_suffix='.png',
            reduce_zero_label=False,
            **kwargs)
        assert fileio.exists(
            self.data_prefix['img_path'], backend_args=self.backend_args)
--------------------------------------------------------------------------------
/mmseg/datasets/levir.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.

from mmseg.registry import DATASETS
from .basesegdataset import BaseCDDataset


@DATASETS.register_module()
class LEVIRCDDataset(BaseCDDataset):
    """LEVIR-CD dataset.

    In change detection annotation for LEVIR-CD, 0 stands for the unchanged
    background. ``reduce_zero_label`` is fixed to False. The ``img_suffix``
    and ``seg_map_suffix`` are both fixed to '.png'.
    """

    METAINFO = dict(
        classes=('background', 'changed'),
        palette=[[0, 0, 0], [255, 255, 255]])

    def __init__(self,
                 img_suffix='.png',
                 img_suffix2='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            img_suffix2=img_suffix2,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
--------------------------------------------------------------------------------
/mmseg/datasets/isprs.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class ISPRSDataset(BaseSegDataset):
    """ISPRS dataset.

    In segmentation map annotation for ISPRS, 0 is the ignore index.
    ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
    ``seg_map_suffix`` are both fixed to '.png'.
    """
    METAINFO = dict(
        classes=('impervious_surface', 'building', 'low_vegetation', 'tree',
                 'car', 'clutter'),
        palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
                 [255, 255, 0], [255, 0, 0]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=True,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
--------------------------------------------------------------------------------
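
All of these dataset classes are instantiated from configs through the
``DATASETS`` registry; a minimal sketch for ISPRS (paths illustrative):

from mmseg.registry import DATASETS

dataset = DATASETS.build(
    dict(
        type='ISPRSDataset',
        data_root='data/vaihingen',
        data_prefix=dict(img_path='img_dir/train', seg_map_path='ann_dir/train'),
        pipeline=[]))  # a real config would list transforms here
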
/mmseg/datasets/potsdam.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class PotsdamDataset(BaseSegDataset):
    """ISPRS Potsdam dataset.

    In segmentation map annotation for Potsdam dataset, 0 is the ignore index.
    ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
    ``seg_map_suffix`` are both fixed to '.png'.
    """
    METAINFO = dict(
        classes=('impervious_surface', 'building', 'low_vegetation', 'tree',
                 'car', 'clutter'),
        palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
                 [255, 255, 0], [255, 0, 0]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=True,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
--------------------------------------------------------------------------------
/mmseg/datasets/loveda.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class LoveDADataset(BaseSegDataset):
    """LoveDA dataset.

    In segmentation map annotation for LoveDA, 0 is the ignore index.
    ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
    ``seg_map_suffix`` are both fixed to '.png'.
    """
    METAINFO = dict(
        classes=('background', 'building', 'road', 'water', 'barren', 'forest',
                 'agricultural'),
        palette=[[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
                 [159, 129, 183], [0, 255, 0], [255, 195, 128]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=True,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
--------------------------------------------------------------------------------
/mmseg/datasets/hrf.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class HRFDataset(BaseSegDataset):
    """HRF dataset.

    In segmentation map annotation for HRF, 0 stands for background, which is
    included in 2 categories. ``reduce_zero_label`` is fixed to False. The
    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
    '.png'.
    """
    METAINFO = dict(
        classes=('background', 'vessel'),
        palette=[[120, 120, 120], [6, 230, 230]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        assert fileio.exists(
            self.data_prefix['img_path'], backend_args=self.backend_args)
--------------------------------------------------------------------------------
/mmseg/datasets/stare.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class STAREDataset(BaseSegDataset):
    """STARE dataset.

    In segmentation map annotation for STARE, 0 stands for background, which is
    included in 2 categories. ``reduce_zero_label`` is fixed to False. The
    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
    '.ah.png'.
    """
    METAINFO = dict(
        classes=('background', 'vessel'),
        palette=[[120, 120, 120], [6, 230, 230]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.ah.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        assert fileio.exists(
            self.data_prefix['img_path'], backend_args=self.backend_args)
--------------------------------------------------------------------------------
/mmseg/datasets/drive.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class DRIVEDataset(BaseSegDataset):
    """DRIVE dataset.

    In segmentation map annotation for DRIVE, 0 stands for background, which is
    included in 2 categories. ``reduce_zero_label`` is fixed to False. The
    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
    '_manual1.png'.
    """
    METAINFO = dict(
        classes=('background', 'vessel'),
        palette=[[120, 120, 120], [6, 230, 230]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='_manual1.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        assert fileio.exists(
            self.data_prefix['img_path'], backend_args=self.backend_args)
--------------------------------------------------------------------------------
16 | """ 17 | METAINFO = dict( 18 | classes=('background', 'vessel'), 19 | palette=[[120, 120, 120], [6, 230, 230]]) 20 | 21 | def __init__(self, 22 | img_suffix='.png', 23 | seg_map_suffix='_1stHO.png', 24 | reduce_zero_label=False, 25 | **kwargs) -> None: 26 | super().__init__( 27 | img_suffix=img_suffix, 28 | seg_map_suffix=seg_map_suffix, 29 | reduce_zero_label=reduce_zero_label, 30 | **kwargs) 31 | assert fileio.exists( 32 | self.data_prefix['img_path'], backend_args=self.backend_args) 33 | -------------------------------------------------------------------------------- /AgriFM/models/heads.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional, Union, Dict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from mmseg.models.builder import MODELS,LOSSES 8 | 9 | from mmengine.model import BaseModule,BaseModel 10 | @MODELS.register_module() 11 | class CropFCNHead(BaseModel): 12 | def __init__(self,embed_dim,num_classes,loss_model): 13 | super().__init__() 14 | self.embed_dim=embed_dim 15 | self.num_classes=num_classes 16 | self.loss_model=LOSSES.build(loss_model) 17 | self.head=nn.Sequential( 18 | nn.Conv2d(self.embed_dim,self.embed_dim//2,kernel_size=3,stride=1,padding=1), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(self.embed_dim//2,self.num_classes,kernel_size=1,stride=1,padding=0) 21 | ) 22 | 23 | def forward(self, 24 | inputs: torch.Tensor, 25 | data_samples: Optional[list] = None, 26 | mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]: 27 | logits=self.head(inputs) 28 | if mode=='loss': 29 | loss=self.loss_model(logits,data_samples) 30 | return logits,loss 31 | else: 32 | return logits 33 | 34 | 35 | -------------------------------------------------------------------------------- /AgriFM/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | def get_filename(file_path,is_suffix=True)->str: 3 | file_name=file_path.replace('/','\\') 4 | file_name=file_name.split('\\')[-1] 5 | if is_suffix: 6 | return file_name 7 | else: 8 | index=file_name.rfind('.') 9 | if index>0: 10 | return file_name[0:index] 11 | else: 12 | return file_name 13 | 14 | def get_parent_folder(file_path,with_root=False): 15 | 16 | file_path=file_path.replace('\\','/') 17 | 18 | 19 | index = file_path.rfind('/') 20 | parent_folder=file_path[0:index] 21 | if not with_root: 22 | return get_filename(parent_folder) 23 | return parent_folder 24 | 25 | 26 | def split_filename(file_path:str,split_str:str)->(str,str): 27 | ''' 28 | 根据split_str将文件分为两部分,split_str为后半部分 29 | :param file_path: 30 | :param split_str: 31 | :return: 32 | ''' 33 | index=file_path.index(split_str) 34 | return file_path[0:index],file_path[index:] 35 | 36 | def get_root_path(data_file): 37 | file_path = data_file.replace('\\', '/') 38 | 39 | index = file_path.find('/') 40 | parent_folder = file_path[0:index] 41 | return parent_folder 42 | 43 | if __name__=='__main__': 44 | file_name='GF1_WFV1_E80.0_N29.6_20200920_L1A0005075711_GEO_4488_816.tif' 45 | print(split_filename(file_name,'_GEO')) -------------------------------------------------------------------------------- /mmseg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
/AgriFM/utils/path_utils.py:
--------------------------------------------------------------------------------
import os
def get_filename(file_path,is_suffix=True)->str:
    file_name=file_path.replace('/','\\')
    file_name=file_name.split('\\')[-1]
    if is_suffix:
        return file_name
    else:
        index=file_name.rfind('.')
        if index>0:
            return file_name[0:index]
        else:
            return file_name

def get_parent_folder(file_path,with_root=False):

    file_path=file_path.replace('\\','/')


    index = file_path.rfind('/')
    parent_folder=file_path[0:index]
    if not with_root:
        return get_filename(parent_folder)
    return parent_folder


def split_filename(file_path:str,split_str:str)->(str,str):
    '''
    Split the file name into two parts at ``split_str``; ``split_str`` begins
    the second part.
    :param file_path:
    :param split_str:
    :return:
    '''
    index=file_path.index(split_str)
    return file_path[0:index],file_path[index:]

def get_root_path(data_file):
    file_path = data_file.replace('\\', '/')

    index = file_path.find('/')
    parent_folder = file_path[0:index]
    return parent_folder

if __name__=='__main__':
    file_name='GF1_WFV1_E80.0_N29.6_20200920_L1A0005075711_GEO_4488_816.tif'
    print(split_filename(file_name,'_GEO'))
--------------------------------------------------------------------------------
/mmseg/models/utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .basic_block import BasicBlock, Bottleneck
from .embed import PatchEmbed
from .encoding import Encoding
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .point_sample import get_uncertain_point_coords_with_randomness
from .ppm import DAPPM, PAPPM
from .res_layer import ResLayer
from .se_layer import SELayer
from .self_attention_block import SelfAttentionBlock
from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
                            nlc_to_nchw)
from .up_conv_block import UpConvBlock

# isort: off
from .wrappers import Upsample, resize
from .san_layers import MLP, LayerNorm2d, cross_attn_layer

__all__ = [
    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
    'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding',
    'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck',
    'cross_attn_layer', 'LayerNorm2d', 'MLP',
    'get_uncertain_point_coords_with_randomness'
]
--------------------------------------------------------------------------------
/mmseg/datasets/synapse.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class SynapseDataset(BaseSegDataset):
    """Synapse dataset.

    Before dataset preprocessing of Synapse, there are 13 foreground
    categories in total, excluding background. After preprocessing, 8
    foreground categories are kept while the other 5 foreground categories are
    handled as background. The ``img_suffix`` is fixed to '.jpg' and
    ``seg_map_suffix`` is fixed to '.png'.
    """
    METAINFO = dict(
        classes=('background', 'aorta', 'gallbladder', 'left_kidney',
                 'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach'),
        palette=[[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0],
                 [0, 255, 255], [255, 0, 255], [255, 255, 0], [60, 255, 255],
                 [240, 240, 240]])

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
--------------------------------------------------------------------------------
/mmseg/apis/utils.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Sequence, Union

import numpy as np
from mmengine.dataset import Compose
from mmengine.model import BaseModel

ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def _preprare_data(imgs: ImageType, model: BaseModel):

    cfg = model.cfg
    for t in cfg.test_pipeline:
        if t.get('type') == 'LoadAnnotations':
            cfg.test_pipeline.remove(t)

    is_batch = True
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
        is_batch = False

    if isinstance(imgs[0], np.ndarray):
        cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'

    # TODO: Consider using the singleton pattern to avoid building
    # a pipeline for each inference
    pipeline = Compose(cfg.test_pipeline)

    data = defaultdict(list)
    for img in imgs:
        if isinstance(img, np.ndarray):
            data_ = dict(img=img)
        else:
            data_ = dict(img_path=img)
        data_ = pipeline(data_)
        data['inputs'].append(data_['inputs'])
        data['data_samples'].append(data_['data_samples'])

    return data, is_batch
--------------------------------------------------------------------------------
/AgriFM/models/multi_unified_model.py:
--------------------------------------------------------------------------------

from typing import Optional, Union, Dict
import torch
from mmseg.models.builder import MODELS
from mmengine.model import BaseModel
from mmengine.runner import load_state_dict,load_checkpoint

@MODELS.register_module()
class MultiUnifiedModel(BaseModel):
    def __init__(self,encoders,head,neck=None,load_from=None):
        super().__init__()
        self.encoders=MODELS.build(encoders)
        if neck is not None:
            self.neck=MODELS.build(neck)
        else:
            self.neck=None
        self.heads=MODELS.build(head)

        if load_from is not None:
            load_checkpoint(self,load_from,strict=False)

    def forward(self,
                inputs:dict,
                data_samples: Optional[list] = None,
                mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
        outputs=self.encoders(inputs,data_samples,mode)
        if self.neck is not None:
            outputs=self.neck(outputs)
        if mode=='tensor' or mode=='predict':
            outputs=self.heads(outputs,mode=mode)
        else:
            logits,outputs=self.heads(outputs,data_samples,mode)
            self.result_list=logits
        return outputs
--------------------------------------------------------------------------------
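
How the pieces compose: ``MultiUnifiedModel`` builds the multi-modal encoders,
an optional fusion neck and the head from sub-configs, then optionally
warm-starts weights via ``load_from``. A schematic config (field values
illustrative; the real per-modality encoder and neck arguments live in the
repository's configs):

model = dict(
    type='MultiUnifiedModel',
    encoders=dict(type='MultiModalEncoder', encoders_cfg=dict()),  # per-modality cfgs go here
    neck=dict(type='MultiFusionNeck'),  # hypothetical: real arguments omitted
    head=dict(
        type='CropFCNHead',
        embed_dim=128,
        num_classes=4,
        loss_model=dict(type='CropCEloss', ignore_index=-1)),
    load_from=None)  # or a checkpoint path to warm-start
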
/mmseg/models/backbones/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .beit import BEiT
from .bisenetv1 import BiSeNetV1
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
from .ddrnet import DDRNet
from .erfnet import ERFNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
from .icnet import ICNet
from .mae import MAE
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mscan import MSCAN
from .pidnet import PIDNet
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
from .stdc import STDCContextPathNet, STDCNet
from .swin import SwinTransformer
from .timm_backbone import TIMMBackbone
from .twins import PCPVT, SVT
from .unet import UNet
from .vit import VisionTransformer
from .vpd import VPD

__all__ = [
    'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
    'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
    'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
    'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
    'DDRNet', 'VPD'
]
--------------------------------------------------------------------------------
/mmseg/datasets/bdd100k.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.

from mmseg.datasets.basesegdataset import BaseSegDataset
from mmseg.registry import DATASETS


@DATASETS.register_module()
class BDD100KDataset(BaseSegDataset):
    METAINFO = dict(
        classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
                 'traffic light', 'traffic sign', 'vegetation', 'terrain',
                 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                 'motorcycle', 'bicycle'),
        palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
                 [190, 153, 153], [153, 153, 153], [250, 170, 30],
                 [220, 220, 0], [107, 142, 35], [152, 251, 152],
                 [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142],
                 [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230],
                 [119, 11, 32]])

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
--------------------------------------------------------------------------------
/mmseg/models/utils/make_divisible.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Make divisible function.

    This function rounds the channel number to the nearest value that can be
    divisible by the divisor. It is taken from the original tf repo. It ensures
    that all layers have a channel number that is divisible by divisor. It can
    be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  # noqa

    Args:
        value (int): The original channel number.
        divisor (int): The divisor to fully divide the channel number.
        min_value (int): The minimum value of the output channel.
            Default: None, which means the minimum value equals the divisor.
        min_ratio (float): The minimum ratio of the rounded channel number to
            the original channel number. Default: 0.9.

    Returns:
        int: The modified output channel number.
    """

    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than (1-min_ratio).
    if new_value < min_ratio * value:
        new_value += divisor
    return new_value
--------------------------------------------------------------------------------
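
A quick sanity check of the rounding rule, runnable as-is: round to the nearest
multiple of ``divisor``, but never drop below ``min_ratio`` of the requested
value:

from mmseg.models.utils import make_divisible

assert make_divisible(30, 8) == 32    # nearest multiple of 8
assert make_divisible(25, 8) == 24    # rounding down is allowed (24 >= 0.9 * 25)
assert make_divisible(90, 64) == 128  # 64 < 0.9 * 90, so bump up by one divisor
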
23 | """ 24 | 25 | def __init__(self, recurrence=2, **kwargs): 26 | if CrissCrossAttention is None: 27 | raise RuntimeError('Please install mmcv-full for ' 28 | 'CrissCrossAttention ops') 29 | super().__init__(num_convs=2, **kwargs) 30 | self.recurrence = recurrence 31 | self.cca = CrissCrossAttention(self.channels) 32 | 33 | def forward(self, inputs): 34 | """Forward function.""" 35 | x = self._transform_inputs(inputs) 36 | output = self.convs[0](x) 37 | for _ in range(self.recurrence): 38 | output = self.cca(output) 39 | output = self.convs[1](output) 40 | if self.concat_input: 41 | output = self.conv_cat(torch.cat([x, output], dim=1)) 42 | output = self.cls_seg(output) 43 | return output 44 | -------------------------------------------------------------------------------- /mmseg/datasets/lip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmseg.registry import DATASETS 3 | from .basesegdataset import BaseSegDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class LIPDataset(BaseSegDataset): 8 | """LIP dataset. 9 | 10 | The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to 11 | '.png'. 12 | """ 13 | METAINFO = dict( 14 | classes=('Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 15 | 'UpperClothes', 'Dress', 'Coat', 'Socks', 'Pants', 16 | 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 17 | 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', 18 | 'Right-shoe'), 19 | palette=( 20 | [0, 0, 0], 21 | [128, 0, 0], 22 | [255, 0, 0], 23 | [0, 85, 0], 24 | [170, 0, 51], 25 | [255, 85, 0], 26 | [0, 0, 85], 27 | [0, 119, 221], 28 | [85, 85, 0], 29 | [0, 85, 85], 30 | [85, 51, 0], 31 | [52, 86, 128], 32 | [0, 128, 0], 33 | [0, 0, 255], 34 | [51, 170, 221], 35 | [0, 255, 255], 36 | [85, 255, 170], 37 | [170, 255, 85], 38 | [255, 255, 0], 39 | [255, 170, 0], 40 | )) 41 | 42 | def __init__(self, 43 | img_suffix='.jpg', 44 | seg_map_suffix='.png', 45 | **kwargs) -> None: 46 | super().__init__( 47 | img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) 48 | -------------------------------------------------------------------------------- /mmseg/utils/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import gzip 3 | import io 4 | import pickle 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | 10 | def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray: 11 | """Data decoding from bytes. 12 | 13 | Args: 14 | content (bytes): The data bytes got from files or other streams. 15 | backend (str): The data decoding backend type. Options are 'numpy', 16 | 'nifti', 'cv2' and 'pickle'. Defaults to 'numpy'. 17 | 18 | Returns: 19 | numpy.ndarray: Loaded data array. 
20 | """ 21 | if backend == 'pickle': 22 | data = pickle.loads(content) 23 | else: 24 | with io.BytesIO(content) as f: 25 | if backend == 'nifti': 26 | f = gzip.open(f) 27 | try: 28 | from nibabel import FileHolder, Nifti1Image 29 | except ImportError: 30 | print('nifti files io depends on nibabel, please run' 31 | '`pip install nibabel` to install it') 32 | fh = FileHolder(fileobj=f) 33 | data = Nifti1Image.from_file_map({'header': fh, 'image': fh}) 34 | data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata() 35 | elif backend == 'numpy': 36 | data = np.load(f) 37 | elif backend == 'cv2': 38 | data = np.frombuffer(f.read(), dtype=np.uint8) 39 | data = cv2.imdecode(data, cv2.IMREAD_UNCHANGED) 40 | else: 41 | raise ValueError 42 | return data 43 | -------------------------------------------------------------------------------- /mmseg/datasets/voc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | import mmengine.fileio as fileio 5 | 6 | from mmseg.registry import DATASETS 7 | from .basesegdataset import BaseSegDataset 8 | 9 | 10 | @DATASETS.register_module() 11 | class PascalVOCDataset(BaseSegDataset): 12 | """Pascal VOC dataset. 13 | 14 | Args: 15 | split (str): Split txt file for Pascal VOC. 16 | """ 17 | METAINFO = dict( 18 | classes=('background', 'aeroplane', 'bicycle', 'bird', 'boat', 19 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 20 | 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 21 | 'sofa', 'train', 'tvmonitor'), 22 | palette=[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 23 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 24 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 25 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 26 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 27 | [0, 64, 128]]) 28 | 29 | def __init__(self, 30 | ann_file, 31 | img_suffix='.jpg', 32 | seg_map_suffix='.png', 33 | **kwargs) -> None: 34 | super().__init__( 35 | img_suffix=img_suffix, 36 | seg_map_suffix=seg_map_suffix, 37 | ann_file=ann_file, 38 | **kwargs) 39 | assert fileio.exists(self.data_prefix['img_path'], 40 | self.backend_args) and osp.isfile(self.ann_file) 41 | -------------------------------------------------------------------------------- /mmseg/datasets/isaid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmengine.fileio as fileio 3 | 4 | from mmseg.registry import DATASETS 5 | from .basesegdataset import BaseSegDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class iSAIDDataset(BaseSegDataset): 10 | """ iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images 11 | In segmentation map annotation for iSAID dataset, which is included 12 | in 16 categories. ``reduce_zero_label`` is fixed to False. The 13 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '_manual1.png'. 
15 | """ 16 | 17 | METAINFO = dict( 18 | classes=('background', 'ship', 'store_tank', 'baseball_diamond', 19 | 'tennis_court', 'basketball_court', 'Ground_Track_Field', 20 | 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', 21 | 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', 22 | 'Harbor'), 23 | palette=[[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], 24 | [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], 25 | [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127], 26 | [0, 127, 191], [0, 127, 255], [0, 100, 155]]) 27 | 28 | def __init__(self, 29 | img_suffix='.png', 30 | seg_map_suffix='_instance_color_RGB.png', 31 | ignore_index=255, 32 | **kwargs) -> None: 33 | super().__init__( 34 | img_suffix=img_suffix, 35 | seg_map_suffix=seg_map_suffix, 36 | ignore_index=ignore_index, 37 | **kwargs) 38 | assert fileio.exists( 39 | self.data_prefix['img_path'], backend_args=self.backend_args) 40 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/nl_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.cnn import NonLocal2d 4 | 5 | from mmseg.registry import MODELS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @MODELS.register_module() 10 | class NLHead(FCNHead): 11 | """Non-local Neural Networks. 12 | 13 | This head is the implementation of `NLNet 14 | `_. 15 | 16 | Args: 17 | reduction (int): Reduction factor of projection transform. Default: 2. 18 | use_scale (bool): Whether to scale pairwise_weight by 19 | sqrt(1/inter_channels). Default: True. 20 | mode (str): The nonlocal mode. Options are 'embedded_gaussian', 21 | 'dot_product'. Default: 'embedded_gaussian.'. 22 | """ 23 | 24 | def __init__(self, 25 | reduction=2, 26 | use_scale=True, 27 | mode='embedded_gaussian', 28 | **kwargs): 29 | super().__init__(num_convs=2, **kwargs) 30 | self.reduction = reduction 31 | self.use_scale = use_scale 32 | self.mode = mode 33 | self.nl_block = NonLocal2d( 34 | in_channels=self.channels, 35 | reduction=self.reduction, 36 | use_scale=self.use_scale, 37 | conv_cfg=self.conv_cfg, 38 | norm_cfg=self.norm_cfg, 39 | mode=self.mode) 40 | 41 | def forward(self, inputs): 42 | """Forward function.""" 43 | x = self._transform_inputs(inputs) 44 | output = self.convs[0](x) 45 | output = self.nl_block(output) 46 | output = self.convs[1](output) 47 | if self.concat_input: 48 | output = self.conv_cat(torch.cat([x, output], dim=1)) 49 | output = self.cls_seg(output) 50 | return output 51 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/gc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.cnn import ContextBlock 4 | 5 | from mmseg.registry import MODELS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @MODELS.register_module() 10 | class GCHead(FCNHead): 11 | """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. 12 | 13 | This head is the implementation of `GCNet 14 | `_. 15 | 16 | Args: 17 | ratio (float): Multiplier of channels ratio. Default: 1/4. 18 | pooling_type (str): The pooling type of context aggregation. 19 | Options are 'att', 'avg'. Default: 'avg'. 20 | fusion_types (tuple[str]): The fusion type for feature fusion. 21 | Options are 'channel_add', 'channel_mul'. 
Default: ('channel_add',) 22 | """ 23 | 24 | def __init__(self, 25 | ratio=1 / 4., 26 | pooling_type='att', 27 | fusion_types=('channel_add', ), 28 | **kwargs): 29 | super().__init__(num_convs=2, **kwargs) 30 | self.ratio = ratio 31 | self.pooling_type = pooling_type 32 | self.fusion_types = fusion_types 33 | self.gc_block = ContextBlock( 34 | in_channels=self.channels, 35 | ratio=self.ratio, 36 | pooling_type=self.pooling_type, 37 | fusion_types=self.fusion_types) 38 | 39 | def forward(self, inputs): 40 | """Forward function.""" 41 | x = self._transform_inputs(inputs) 42 | output = self.convs[0](x) 43 | output = self.gc_block(output) 44 | output = self.convs[1](output) 45 | if self.concat_input: 46 | output = self.conv_cat(torch.cat([x, output], dim=1)) 47 | output = self.cls_seg(output) 48 | return output 49 | -------------------------------------------------------------------------------- /mmseg/datasets/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .formatting import PackSegInputs 3 | from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, 4 | LoadBiomedicalData, LoadBiomedicalImageFromFile, 5 | LoadDepthAnnotation, LoadImageFromNDArray, 6 | LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile) 7 | # yapf: disable 8 | from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad, 9 | BioMedical3DRandomCrop, BioMedical3DRandomFlip, 10 | BioMedicalGaussianBlur, BioMedicalGaussianNoise, 11 | BioMedicalRandomGamma, ConcatCDInput, GenerateEdge, 12 | PhotoMetricDistortion, RandomCrop, RandomCutOut, 13 | RandomDepthMix, RandomFlip, RandomMosaic, 14 | RandomRotate, RandomRotFlip, Rerange, Resize, 15 | ResizeShortestEdge, ResizeToMultiple, RGB2Gray, 16 | SegRescale) 17 | 18 | # yapf: enable 19 | __all__ = [ 20 | 'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale', 21 | 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 22 | 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 23 | 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 24 | 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 25 | 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', 26 | 'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad', 27 | 'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput', 28 | 'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix', 29 | 'RandomFlip', 'Resize' 30 | ] 31 | -------------------------------------------------------------------------------- /mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import warnings 3 | 4 | from mmseg.registry import MODELS 5 | 6 | BACKBONES = MODELS 7 | NECKS = MODELS 8 | HEADS = MODELS 9 | LOSSES = MODELS 10 | SEGMENTORS = MODELS 11 | 12 | 13 | def build_backbone(cfg): 14 | """Build backbone.""" 15 | warnings.warn('``build_backbone`` will be deprecated soon, please use ' 16 | '``mmseg.registry.MODELS.build()`` ') 17 | return BACKBONES.build(cfg) 18 | 19 | 20 | def build_neck(cfg): 21 | """Build neck.""" 22 | warnings.warn('``build_neck`` will be deprecated soon, please use ' 23 | '``mmseg.registry.MODELS.build()`` ') 24 | return NECKS.build(cfg) 25 | 26 | 27 | def build_head(cfg): 28 | """Build head.""" 29 | warnings.warn('``build_head`` will be deprecated soon, please use ' 30 | '``mmseg.registry.MODELS.build()`` ') 31 | return HEADS.build(cfg) 32 | 33 | 34 | def build_loss(cfg): 35 | """Build loss.""" 36 | warnings.warn('``build_loss`` will be deprecated soon, please use ' 37 | '``mmseg.registry.MODELS.build()`` ') 38 | return LOSSES.build(cfg) 39 | 40 | 41 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 42 | """Build segmentor.""" 43 | if train_cfg is not None or test_cfg is not None: 44 | warnings.warn( 45 | 'train_cfg and test_cfg are deprecated, ' 46 | 'please specify them in model', UserWarning) 47 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 48 | 'train_cfg specified in both outer field and model field ' 49 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 50 | 'test_cfg specified in both outer field and model field ' 51 | return SEGMENTORS.build( 52 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 53 | -------------------------------------------------------------------------------- /mmseg/models/segmentors/seg_tta.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List 3 | 4 | import torch 5 | from mmengine.model import BaseTTAModel 6 | from mmengine.structures import PixelData 7 | 8 | from mmseg.registry import MODELS 9 | from mmseg.utils import SampleList 10 | 11 | 12 | @MODELS.register_module() 13 | class SegTTAModel(BaseTTAModel): 14 | 15 | def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList: 16 | """Merge predictions of enhanced data to one prediction. 17 | 18 | Args: 19 | data_samples_list (List[SampleList]): List of predictions 20 | of all enhanced data. 21 | 22 | Returns: 23 | SampleList: Merged prediction.
24 | """ 25 | predictions = [] 26 | for data_samples in data_samples_list: 27 | seg_logits = data_samples[0].seg_logits.data 28 | logits = torch.zeros(seg_logits.shape).to(seg_logits) 29 | for data_sample in data_samples: 30 | seg_logit = data_sample.seg_logits.data 31 | if self.module.out_channels > 1: 32 | logits += seg_logit.softmax(dim=0) 33 | else: 34 | logits += seg_logit.sigmoid() 35 | logits /= len(data_samples) 36 | if self.module.out_channels == 1: 37 | seg_pred = (logits > self.module.decode_head.threshold 38 | ).to(logits).squeeze(1) 39 | else: 40 | seg_pred = logits.argmax(dim=0) 41 | data_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)}) 42 | if hasattr(data_samples[0], 'gt_sem_seg'): 43 | data_sample.set_data( 44 | {'gt_sem_seg': data_samples[0].gt_sem_seg}) 45 | data_sample.set_metainfo({'img_path': data_samples[0].img_path}) 46 | predictions.append(data_sample) 47 | return predictions 48 | -------------------------------------------------------------------------------- /mmseg/utils/set_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import datetime 3 | import warnings 4 | 5 | from mmengine import DefaultScope 6 | 7 | 8 | def register_all_modules(init_default_scope: bool = True) -> None: 9 | """Register all modules in mmseg into the registries. 10 | 11 | Args: 12 | init_default_scope (bool): Whether initialize the mmseg default scope. 13 | When `init_default_scope=True`, the global default scope will be 14 | set to `mmseg`, and all registries will build modules from mmseg's 15 | registry node. To understand more about the registry, please refer 16 | to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md 17 | Defaults to True. 18 | """ # noqa 19 | import mmseg.datasets # noqa: F401,F403 20 | import mmseg.engine # noqa: F401,F403 21 | import mmseg.evaluation # noqa: F401,F403 22 | import mmseg.models # noqa: F401,F403 23 | import mmseg.structures # noqa: F401,F403 24 | 25 | if init_default_scope: 26 | never_created = DefaultScope.get_current_instance() is None \ 27 | or not DefaultScope.check_instance_created('mmseg') 28 | if never_created: 29 | DefaultScope.get_instance('mmseg', scope_name='mmseg') 30 | return 31 | current_scope = DefaultScope.get_current_instance() 32 | if current_scope.scope_name != 'mmseg': 33 | warnings.warn('The current default scope ' 34 | f'"{current_scope.scope_name}" is not "mmseg", ' 35 | '`register_all_modules` will force the current' 36 | 'default scope to be "mmseg". If this is not ' 37 | 'expected, please set `init_default_scope=False`.') 38 | # avoid name conflict 39 | new_instance_name = f'mmseg-{datetime.datetime.now()}' 40 | DefaultScope.get_instance(new_instance_name, scope_name='mmseg') 41 | -------------------------------------------------------------------------------- /mmseg/models/utils/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import warnings 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize(input, 9 | size=None, 10 | scale_factor=None, 11 | mode='nearest', 12 | align_corners=None, 13 | warning=True): 14 | if warning: 15 | if size is not None and align_corners: 16 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 17 | output_h, output_w = tuple(int(x) for x in size) 18 | if output_h > input_h or output_w > input_w: 19 | if ((output_h > 1 and output_w > 1 and input_h > 1 20 | and input_w > 1) and (output_h - 1) % (input_h - 1) 21 | and (output_w - 1) % (input_w - 1)): 22 | warnings.warn( 23 | f'When align_corners={align_corners}, ' 24 | 'the output would be more aligned if ' 25 | f'input size {(input_h, input_w)} is `x+1` and ' 26 | f'out size {(output_h, output_w)} is `nx+1`') 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | 29 | 30 | class Upsample(nn.Module): 31 | 32 | def __init__(self, 33 | size=None, 34 | scale_factor=None, 35 | mode='nearest', 36 | align_corners=None): 37 | super().__init__() 38 | self.size = size 39 | if isinstance(scale_factor, tuple): 40 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 41 | else: 42 | self.scale_factor = float(scale_factor) if scale_factor else None 43 | self.mode = mode 44 | self.align_corners = align_corners 45 | 46 | def forward(self, x): 47 | if not self.size: 48 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 49 | else: 50 | size = self.size 51 | return resize(x, size, None, self.mode, self.align_corners) 52 | -------------------------------------------------------------------------------- /mmseg/models/losses/boundary_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from mmseg.registry import MODELS 8 | 9 | 10 | @MODELS.register_module() 11 | class BoundaryLoss(nn.Module): 12 | """Boundary loss. 13 | 14 | This function is modified from 15 | `PIDNet <https://github.com/XuJiacong/PIDNet>`_. # noqa 16 | Licensed under the MIT License. 17 | 18 | 19 | Args: 20 | loss_weight (float): Weight of the loss. Defaults to 1.0. 21 | loss_name (str): Name of the loss item. If you want this loss 22 | item to be included into the backward graph, `loss_` must be the 23 | prefix of the name. Defaults to 'loss_boundary'. 24 | """ 25 | 26 | def __init__(self, 27 | loss_weight: float = 1.0, 28 | loss_name: str = 'loss_boundary'): 29 | super().__init__() 30 | self.loss_weight = loss_weight 31 | self.loss_name_ = loss_name 32 | 33 | def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor: 34 | """Forward function. 35 | Args: 36 | bd_pre (Tensor): Predictions of the boundary head. 37 | bd_gt (Tensor): Ground truth of the boundary. 38 | 39 | Returns: 40 | Tensor: Loss tensor.
41 | """ 42 | log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1) 43 | target_t = bd_gt.view(1, -1).float() 44 | 45 | pos_index = (target_t == 1) 46 | neg_index = (target_t == 0) 47 | 48 | weight = torch.zeros_like(log_p) 49 | pos_num = pos_index.sum() 50 | neg_num = neg_index.sum() 51 | sum_num = pos_num + neg_num 52 | weight[pos_index] = neg_num * 1.0 / sum_num 53 | weight[neg_index] = pos_num * 1.0 / sum_num 54 | 55 | loss = F.binary_cross_entropy_with_logits( 56 | log_p, target_t, weight, reduction='mean') 57 | 58 | return self.loss_weight * loss 59 | 60 | @property 61 | def loss_name(self): 62 | return self.loss_name_ 63 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .ann_head import ANNHead 3 | from .apc_head import APCHead 4 | from .aspp_head import ASPPHead 5 | from .cc_head import CCHead 6 | from .da_head import DAHead 7 | from .ddr_head import DDRHead 8 | from .dm_head import DMHead 9 | from .dnl_head import DNLHead 10 | from .dpt_head import DPTHead 11 | from .ema_head import EMAHead 12 | from .enc_head import EncHead 13 | from .fcn_head import FCNHead 14 | from .fpn_head import FPNHead 15 | from .gc_head import GCHead 16 | from .ham_head import LightHamHead 17 | from .isa_head import ISAHead 18 | from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator 19 | from .lraspp_head import LRASPPHead 20 | from .mask2former_head import Mask2FormerHead 21 | from .maskformer_head import MaskFormerHead 22 | from .nl_head import NLHead 23 | from .ocr_head import OCRHead 24 | from .pid_head import PIDHead 25 | from .point_head import PointHead 26 | from .psa_head import PSAHead 27 | from .psp_head import PSPHead 28 | from .san_head import SideAdapterCLIPHead 29 | from .segformer_head import SegformerHead 30 | from .segmenter_mask_head import SegmenterMaskTransformerHead 31 | from .sep_aspp_head import DepthwiseSeparableASPPHead 32 | from .sep_fcn_head import DepthwiseSeparableFCNHead 33 | from .setr_mla_head import SETRMLAHead 34 | from .setr_up_head import SETRUPHead 35 | from .stdc_head import STDCHead 36 | from .uper_head import UPerHead 37 | from .vpd_depth_head import VPDDepthHead 38 | 39 | __all__ = [ 40 | 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', 41 | 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 42 | 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', 43 | 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 44 | 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', 45 | 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', 46 | 'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead', 47 | 'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead' 48 | ] 49 | -------------------------------------------------------------------------------- /mmseg/models/backbones/timm_backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | try: 3 | import timm 4 | except ImportError: 5 | timm = None 6 | 7 | from mmengine.model import BaseModule 8 | from mmengine.registry import MODELS as MMENGINE_MODELS 9 | 10 | from mmseg.registry import MODELS 11 | 12 | 13 | @MODELS.register_module() 14 | class TIMMBackbone(BaseModule): 15 | """Wrapper to use backbones from timm library. More details can be found in 16 | `timm <https://github.com/rwightman/pytorch-image-models>`_. 17 | 18 | Args: 19 | model_name (str): Name of timm model to instantiate. 20 | pretrained (bool): Load pretrained weights if True. 21 | checkpoint_path (str): Path of checkpoint to load after 22 | model is initialized. 23 | in_channels (int): Number of input image channels. Default: 3. 24 | init_cfg (dict, optional): Initialization config dict. 25 | **kwargs: Other timm & model specific arguments. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | model_name, 31 | features_only=True, 32 | pretrained=True, 33 | checkpoint_path='', 34 | in_channels=3, 35 | init_cfg=None, 36 | **kwargs, 37 | ): 38 | if timm is None: 39 | raise RuntimeError('timm is not installed') 40 | super().__init__(init_cfg) 41 | if 'norm_layer' in kwargs: 42 | kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer']) 43 | self.timm_model = timm.create_model( 44 | model_name=model_name, 45 | features_only=features_only, 46 | pretrained=pretrained, 47 | in_chans=in_channels, 48 | checkpoint_path=checkpoint_path, 49 | **kwargs, 50 | ) 51 | 52 | # Make unused parameters None 53 | self.timm_model.global_pool = None 54 | self.timm_model.fc = None 55 | self.timm_model.classifier = None 56 | 57 | # Hack to use pretrained weights from timm 58 | if pretrained or checkpoint_path: 59 | self._is_init = True 60 | 61 | def forward(self, x): 62 | features = self.timm_model(x) 63 | return features 64 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/segformer_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.models.decode_heads.decode_head import BaseDecodeHead 7 | from mmseg.registry import MODELS 8 | from ..utils import resize 9 | 10 | 11 | @MODELS.register_module() 12 | class SegformerHead(BaseDecodeHead): 13 | """The all-MLP head of SegFormer. 14 | 15 | This head is the implementation of 16 | `SegFormer <https://arxiv.org/abs/2105.15203>`_. 17 | 18 | Args: 19 | interpolate_mode: The interpolate mode of MLP head upsample operation. 20 | Default: 'bilinear'.
21 | """ 22 | 23 | def __init__(self, interpolate_mode='bilinear', **kwargs): 24 | super().__init__(input_transform='multiple_select', **kwargs) 25 | 26 | self.interpolate_mode = interpolate_mode 27 | num_inputs = len(self.in_channels) 28 | 29 | assert num_inputs == len(self.in_index) 30 | 31 | self.convs = nn.ModuleList() 32 | for i in range(num_inputs): 33 | self.convs.append( 34 | ConvModule( 35 | in_channels=self.in_channels[i], 36 | out_channels=self.channels, 37 | kernel_size=1, 38 | stride=1, 39 | norm_cfg=self.norm_cfg, 40 | act_cfg=self.act_cfg)) 41 | 42 | self.fusion_conv = ConvModule( 43 | in_channels=self.channels * num_inputs, 44 | out_channels=self.channels, 45 | kernel_size=1, 46 | norm_cfg=self.norm_cfg) 47 | 48 | def forward(self, inputs): 49 | # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 50 | inputs = self._transform_inputs(inputs) 51 | outs = [] 52 | for idx in range(len(inputs)): 53 | x = inputs[idx] 54 | conv = self.convs[idx] 55 | outs.append( 56 | resize( 57 | input=conv(x), 58 | size=inputs[0].shape[2:], 59 | mode=self.interpolate_mode, 60 | align_corners=self.align_corners)) 61 | 62 | out = self.fusion_conv(torch.cat(outs, dim=1)) 63 | 64 | out = self.cls_seg(out) 65 | 66 | return out 67 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 10/12/2024 6:32 pm 3 | # @Author : Wenyuan Li 4 | # @File : inference_vis.py 5 | # @Description : 6 | import os 7 | import numpy as np 8 | import torch 9 | import cv2 10 | from mmseg.registry import MODELS,DATASETS 11 | from AgriFM.utils import path_utils 12 | import argparse 13 | from mmengine.config import Config, DictAction 14 | import torch.utils.data as data_utils 15 | #import load_checkpoint in mmseg 16 | from mmengine.runner import load_state_dict,load_checkpoint 17 | import tqdm 18 | from skimage import io 19 | import copy 20 | def parse_args(): 21 | parser = argparse.ArgumentParser( 22 | description='MMSeg test (and eval) a model') 23 | parser.add_argument('config', help='train config file path') 24 | parser.add_argument('checkpoint', help='checkpoint file') 25 | parser.add_argument('result_path', help='path to save the inference results') 26 | args=parser.parse_args() 27 | return args 28 | 29 | if __name__=='__main__': 30 | args=parse_args() 31 | config_file=args.config 32 | checkpoint_file=args.checkpoint 33 | 34 | options=["TILED=TRUE","COMPRESS=DEFLATE","NUM_THREADS=4","ZLEVEL=9"] 35 | cfg=Config.fromfile(config_file) 36 | model=MODELS.build(cfg.model) 37 | dataset_cfg=cfg.test_dataloader.dataset 38 | dataset=DATASETS.build(dataset_cfg) 39 | dataloader=data_utils.DataLoader(dataset,batch_size=4,shuffle=False,num_workers=4) 40 | resutl_path=cfg.result_path 41 | if not os.path.exists(resutl_path): 42 | os.makedirs(resutl_path) 43 | #resume from a checkpoint 44 | load_checkpoint(model,checkpoint_file,strict=True) 45 | model.eval() 46 | model.to('cuda') 47 | for data,label in tqdm.tqdm(dataloader): 48 | out_data=copy.deepcopy(data) 49 | file_path = data.pop('file_name') 50 | data=model.data_preprocessor(data) 51 | with torch.no_grad(): 52 | logits=model(data,mode='tensor') 53 | for i in range(len(file_path)): 54 | cls_pred=torch.argmax(logits[i],dim=0) 55 | cls_pred=cls_pred.cpu().numpy() 56 | tile_name=file_path[i] 57 | pred_mask_path=os.path.join(resutl_path,'%s_pred.png'%tile_name) 58 | 
io.imsave(pred_mask_path,cls_pred.astype(np.uint8),check_contrast=False) 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /mmseg/models/utils/se_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | from mmengine.utils import is_tuple_of 5 | 6 | from .make_divisible import make_divisible 7 | 8 | 9 | class SELayer(nn.Module): 10 | """Squeeze-and-Excitation Module. 11 | 12 | Args: 13 | channels (int): The input (and output) channels of the SE layer. 14 | ratio (int): Squeeze ratio in SELayer, the intermediate channel will be 15 | ``int(channels/ratio)``. Default: 16. 16 | conv_cfg (None or dict): Config dict for convolution layer. 17 | Default: None, which means using conv2d. 18 | act_cfg (dict or Sequence[dict]): Config dict for activation layer. 19 | If act_cfg is a dict, two activation layers will be configured 20 | by this dict. If act_cfg is a sequence of dicts, the first 21 | activation layer will be configured by the first dict and the 22 | second activation layer will be configured by the second dict. 23 | Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, 24 | divisor=6.0)). 25 | """ 26 | 27 | def __init__(self, 28 | channels, 29 | ratio=16, 30 | conv_cfg=None, 31 | act_cfg=(dict(type='ReLU'), 32 | dict(type='HSigmoid', bias=3.0, divisor=6.0))): 33 | super().__init__() 34 | if isinstance(act_cfg, dict): 35 | act_cfg = (act_cfg, act_cfg) 36 | assert len(act_cfg) == 2 37 | assert is_tuple_of(act_cfg, dict) 38 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 39 | self.conv1 = ConvModule( 40 | in_channels=channels, 41 | out_channels=make_divisible(channels // ratio, 8), 42 | kernel_size=1, 43 | stride=1, 44 | conv_cfg=conv_cfg, 45 | act_cfg=act_cfg[0]) 46 | self.conv2 = ConvModule( 47 | in_channels=make_divisible(channels // ratio, 8), 48 | out_channels=channels, 49 | kernel_size=1, 50 | stride=1, 51 | conv_cfg=conv_cfg, 52 | act_cfg=act_cfg[1]) 53 | 54 | def forward(self, x): 55 | out = self.global_avgpool(x) 56 | out = self.conv1(out) 57 | out = self.conv2(out) 58 | return x * out 59 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/setr_mla_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.registry import MODELS 7 | from ..utils import Upsample 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @MODELS.register_module() 12 | class SETRMLAHead(BaseDecodeHead): 13 | """Multi-level feature aggregation head of SETR. 14 | 15 | MLA head of `SETR <https://arxiv.org/abs/2012.15840>`_. 16 | 17 | Args: 18 | mla_channels (int): Channels of conv-conv-4x of multi-level feature 19 | aggregation. Default: 128. 20 | up_scale (int): The scale factor of interpolate. Default: 4.
21 | """ 22 | 23 | def __init__(self, mla_channels=128, up_scale=4, **kwargs): 24 | super().__init__(input_transform='multiple_select', **kwargs) 25 | self.mla_channels = mla_channels 26 | 27 | num_inputs = len(self.in_channels) 28 | 29 | # Refer to self.cls_seg settings of BaseDecodeHead 30 | assert self.channels == num_inputs * mla_channels 31 | 32 | self.up_convs = nn.ModuleList() 33 | for i in range(num_inputs): 34 | self.up_convs.append( 35 | nn.Sequential( 36 | ConvModule( 37 | in_channels=self.in_channels[i], 38 | out_channels=mla_channels, 39 | kernel_size=3, 40 | padding=1, 41 | norm_cfg=self.norm_cfg, 42 | act_cfg=self.act_cfg), 43 | ConvModule( 44 | in_channels=mla_channels, 45 | out_channels=mla_channels, 46 | kernel_size=3, 47 | padding=1, 48 | norm_cfg=self.norm_cfg, 49 | act_cfg=self.act_cfg), 50 | Upsample( 51 | scale_factor=up_scale, 52 | mode='bilinear', 53 | align_corners=self.align_corners))) 54 | 55 | def forward(self, inputs): 56 | inputs = self._transform_inputs(inputs) 57 | outs = [] 58 | for x, up_conv in zip(inputs, self.up_convs): 59 | outs.append(up_conv(x)) 60 | out = torch.cat(outs, dim=1) 61 | out = self.cls_seg(out) 62 | return out 63 | -------------------------------------------------------------------------------- /mmseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # yapf: disable 3 | from .class_names import (ade_classes, ade_palette, bdd100k_classes, 4 | bdd100k_palette, cityscapes_classes, 5 | cityscapes_palette, cocostuff_classes, 6 | cocostuff_palette, dataset_aliases, get_classes, 7 | get_palette, isaid_classes, isaid_palette, 8 | loveda_classes, loveda_palette, potsdam_classes, 9 | potsdam_palette, stare_classes, stare_palette, 10 | synapse_classes, synapse_palette, vaihingen_classes, 11 | vaihingen_palette, voc_classes, voc_palette) 12 | # yapf: enable 13 | from .collect_env import collect_env 14 | from .get_templates import get_predefined_templates 15 | from .io import datafrombytes 16 | from .misc import add_prefix, stack_batch 17 | from .set_env import register_all_modules 18 | from .tokenizer import tokenize 19 | from .typing_utils import (ConfigType, ForwardResults, MultiConfig, 20 | OptConfigType, OptMultiConfig, OptSampleList, 21 | SampleList, TensorDict, TensorList) 22 | 23 | # isort: off 24 | from .mask_classification import MatchMasks, seg_data_to_instance_data 25 | 26 | __all__ = [ 27 | 'collect_env', 28 | 'register_all_modules', 29 | 'stack_batch', 30 | 'add_prefix', 31 | 'ConfigType', 32 | 'OptConfigType', 33 | 'MultiConfig', 34 | 'OptMultiConfig', 35 | 'SampleList', 36 | 'OptSampleList', 37 | 'TensorDict', 38 | 'TensorList', 39 | 'ForwardResults', 40 | 'cityscapes_classes', 41 | 'ade_classes', 42 | 'voc_classes', 43 | 'cocostuff_classes', 44 | 'loveda_classes', 45 | 'potsdam_classes', 46 | 'vaihingen_classes', 47 | 'isaid_classes', 48 | 'stare_classes', 49 | 'cityscapes_palette', 50 | 'ade_palette', 51 | 'voc_palette', 52 | 'cocostuff_palette', 53 | 'loveda_palette', 54 | 'potsdam_palette', 55 | 'vaihingen_palette', 56 | 'isaid_palette', 57 | 'stare_palette', 58 | 'dataset_aliases', 59 | 'get_classes', 60 | 'get_palette', 61 | 'datafrombytes', 62 | 'synapse_palette', 63 | 'synapse_classes', 64 | 'get_predefined_templates', 65 | 'tokenize', 66 | 'seg_data_to_instance_data', 67 | 'MatchMasks', 68 | 'bdd100k_classes', 69 | 'bdd100k_palette', 70 | ] 71 | 
-------------------------------------------------------------------------------- /mmseg/models/decode_heads/cascade_decode_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | from typing import List 4 | 5 | from torch import Tensor 6 | 7 | from mmseg.utils import ConfigType 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): 12 | """Base class for cascade decode head used in 13 | :class:`CascadeEncoderDecoder`.""" 14 | 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | 18 | @abstractmethod 19 | def forward(self, inputs, prev_output): 20 | """Placeholder of forward function.""" 21 | pass 22 | 23 | def loss(self, inputs: List[Tensor], prev_output: Tensor, 24 | batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor: 25 | """Forward function for training. 26 | 27 | Args: 28 | inputs (List[Tensor]): List of multi-level img features. 29 | prev_output (Tensor): The output of previous decode head. 30 | batch_data_samples (List[:obj:`SegDataSample`]): The seg 31 | data samples. It usually includes information such 32 | as `metainfo` and `gt_sem_seg`. 33 | train_cfg (dict): The training config. 34 | 35 | Returns: 36 | dict[str, Tensor]: A dictionary of loss components. 37 | """ 38 | seg_logits = self.forward(inputs, prev_output) 39 | losses = self.loss_by_feat(seg_logits, batch_data_samples) 40 | 41 | return losses 42 | 43 | def predict(self, inputs: List[Tensor], prev_output: Tensor, 44 | batch_img_metas: List[dict], test_cfg: ConfigType): 45 | """Forward function for testing. 46 | 47 | Args: 48 | inputs (List[Tensor]): List of multi-level img features. 49 | prev_output (Tensor): The output of previous decode head. 50 | batch_img_metas (List[dict]): Image info where each dict may also 51 | contain: 'img_shape', 'scale_factor', 'flip', 'img_path', 52 | 'ori_shape', and 'pad_shape'. 53 | For details on the values of these keys see 54 | `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. 55 | test_cfg (dict): The testing config. 56 | 57 | Returns: 58 | Tensor: Output segmentation map. 59 | """ 60 | seg_logits = self.forward(inputs, prev_output) 61 | 62 | return self.predict_by_feat(seg_logits, batch_img_metas) 63 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/sep_fcn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import DepthwiseSeparableConvModule 3 | 4 | from mmseg.registry import MODELS 5 | from .fcn_head import FCNHead 6 | 7 | 8 | @MODELS.register_module() 9 | class DepthwiseSeparableFCNHead(FCNHead): 10 | """Depthwise-Separable Fully Convolutional Network for Semantic 11 | Segmentation. 12 | 13 | This head is implemented according to `Fast-SCNN: Fast Semantic 14 | Segmentation Network <https://arxiv.org/abs/1902.04502>`_. 15 | 16 | Args: 17 | in_channels (int): Number of output channels of FFM. 18 | channels (int): Number of middle-stage channels in the decode head. 19 | concat_input (bool): Whether to concatenate original decode input into 20 | the result of several consecutive convolution layers. 21 | Default: True. 22 | num_classes (int): Used to determine the dimension of 23 | final prediction tensor. 24 | in_index (int): Corresponds to 'out_indices' in the FastSCNN backbone.
25 | norm_cfg (dict | None): Config of norm layers. 26 | align_corners (bool): align_corners argument of F.interpolate. 27 | Default: False. 28 | loss_decode (dict): Config of loss type and some 29 | relevant additional options. 30 | dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is 31 | 'default', it will be the same as `act_cfg`. Default: None. 32 | """ 33 | 34 | def __init__(self, dw_act_cfg=None, **kwargs): 35 | super().__init__(**kwargs) 36 | self.convs[0] = DepthwiseSeparableConvModule( 37 | self.in_channels, 38 | self.channels, 39 | kernel_size=self.kernel_size, 40 | padding=self.kernel_size // 2, 41 | norm_cfg=self.norm_cfg, 42 | dw_act_cfg=dw_act_cfg) 43 | 44 | for i in range(1, self.num_convs): 45 | self.convs[i] = DepthwiseSeparableConvModule( 46 | self.channels, 47 | self.channels, 48 | kernel_size=self.kernel_size, 49 | padding=self.kernel_size // 2, 50 | norm_cfg=self.norm_cfg, 51 | dw_act_cfg=dw_act_cfg) 52 | 53 | if self.concat_input: 54 | self.conv_cat = DepthwiseSeparableConvModule( 55 | self.in_channels + self.channels, 56 | self.channels, 57 | kernel_size=self.kernel_size, 58 | padding=self.kernel_size // 2, 59 | norm_cfg=self.norm_cfg, 60 | dw_act_cfg=dw_act_cfg) 61 | -------------------------------------------------------------------------------- /mmseg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import mmcv 5 | import mmengine 6 | from packaging.version import parse 7 | 8 | from .version import __version__, version_info 9 | 10 | MMCV_MIN = '2.0.0rc4' 11 | MMCV_MAX = '2.2.0' 12 | MMENGINE_MIN = '0.5.0' 13 | MMENGINE_MAX = '1.0.0' 14 | 15 | 16 | def digit_version(version_str: str, length: int = 4): 17 | """Convert a version string into a tuple of integers. 18 | 19 | This method is usually used for comparing two versions. For pre-release 20 | versions: alpha < beta < rc. 21 | 22 | Args: 23 | version_str (str): The version string. 24 | length (int): The maximum number of version levels. Default: 4. 25 | 26 | Returns: 27 | tuple[int]: The version info in digits (integers). 28 | """ 29 | version = parse(version_str) 30 | assert version.release, f'failed to parse version {version_str}' 31 | release = list(version.release) 32 | release = release[:length] 33 | if len(release) < length: 34 | release = release + [0] * (length - len(release)) 35 | if version.is_prerelease: 36 | mapping = {'a': -3, 'b': -2, 'rc': -1} 37 | val = -4 38 | # version.pre can be None 39 | if version.pre: 40 | if version.pre[0] not in mapping: 41 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 42 | 'version checking may go wrong') 43 | else: 44 | val = mapping[version.pre[0]] 45 | release.extend([val, version.pre[-1]]) 46 | else: 47 | release.extend([val, 0]) 48 | 49 | elif version.is_postrelease: 50 | release.extend([1, version.post]) 51 | else: 52 | release.extend([0, 0]) 53 | return tuple(release) 54 | 55 | 56 | mmcv_min_version = digit_version(MMCV_MIN) 57 | mmcv_max_version = digit_version(MMCV_MAX) 58 | mmcv_version = digit_version(mmcv.__version__) 59 | 60 | 61 | assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ 62 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 63 | f'Please install mmcv>={MMCV_MIN}, <{MMCV_MAX}.'
64 | 65 | mmengine_min_version = digit_version(MMENGINE_MIN) 66 | mmengine_max_version = digit_version(MMENGINE_MAX) 67 | mmengine_version = digit_version(mmengine.__version__) 68 | 69 | assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \ 70 | f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ 71 | f'Please install mmengine>={mmengine_min_version}, '\ 72 | f'<{mmengine_max_version}.' 73 | 74 | __all__ = ['__version__', 'version_info', 'digit_version'] 75 | -------------------------------------------------------------------------------- /mmseg/models/necks/featurepyramid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import build_norm_layer 4 | 5 | from mmseg.registry import MODELS 6 | 7 | 8 | @MODELS.register_module() 9 | class Feature2Pyramid(nn.Module): 10 | """Feature2Pyramid. 11 | 12 | A neck structure connecting the ViT backbone and decode heads. 13 | 14 | Args: 15 | embed_dim (int): Embedding dimension. 16 | rescales (list[float]): Different sampling multiples were 17 | used to obtain pyramid features. Default: [4, 2, 1, 0.5]. 18 | norm_cfg (dict): Config dict for normalization layer. 19 | Default: dict(type='SyncBN', requires_grad=True). 20 | """ 21 | 22 | def __init__(self, 23 | embed_dim, 24 | rescales=[4, 2, 1, 0.5], 25 | norm_cfg=dict(type='SyncBN', requires_grad=True)): 26 | super().__init__() 27 | self.rescales = rescales 28 | self.upsample_4x = None 29 | for k in self.rescales: 30 | if k == 4: 31 | self.upsample_4x = nn.Sequential( 32 | nn.ConvTranspose2d( 33 | embed_dim, embed_dim, kernel_size=2, stride=2), 34 | build_norm_layer(norm_cfg, embed_dim)[1], 35 | nn.GELU(), 36 | nn.ConvTranspose2d( 37 | embed_dim, embed_dim, kernel_size=2, stride=2), 38 | ) 39 | elif k == 2: 40 | self.upsample_2x = nn.Sequential( 41 | nn.ConvTranspose2d( 42 | embed_dim, embed_dim, kernel_size=2, stride=2)) 43 | elif k == 1: 44 | self.identity = nn.Identity() 45 | elif k == 0.5: 46 | self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2) 47 | elif k == 0.25: 48 | self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) 49 | else: 50 | raise KeyError(f'invalid {k} for feature2pyramid') 51 | 52 | def forward(self, inputs): 53 | assert len(inputs) == len(self.rescales) 54 | outputs = [] 55 | if self.upsample_4x is not None: 56 | ops = [ 57 | self.upsample_4x, self.upsample_2x, self.identity, 58 | self.downsample_2x 59 | ] 60 | else: 61 | ops = [ 62 | self.upsample_2x, self.identity, self.downsample_2x, 63 | self.downsample_4x 64 | ] 65 | for i in range(len(inputs)): 66 | outputs.append(ops[i](inputs[i])) 67 | return tuple(outputs) 68 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/fpn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.registry import MODELS 7 | from ..utils import Upsample, resize 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @MODELS.register_module() 12 | class FPNHead(BaseDecodeHead): 13 | """Panoptic Feature Pyramid Networks. 14 | 15 | This head is the implementation of `Semantic FPN 16 | <https://arxiv.org/abs/1901.02446>`_. 17 | 18 | Args: 19 | feature_strides (tuple[int]): The strides for input feature maps. 20 | All strides are supposed to be powers of 2.
The first 21 | one has the largest resolution. 22 | """ 23 | 24 | def __init__(self, feature_strides, **kwargs): 25 | super().__init__(input_transform='multiple_select', **kwargs) 26 | assert len(feature_strides) == len(self.in_channels) 27 | assert min(feature_strides) == feature_strides[0] 28 | self.feature_strides = feature_strides 29 | 30 | self.scale_heads = nn.ModuleList() 31 | for i in range(len(feature_strides)): 32 | head_length = max( 33 | 1, 34 | int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) 35 | scale_head = [] 36 | for k in range(head_length): 37 | scale_head.append( 38 | ConvModule( 39 | self.in_channels[i] if k == 0 else self.channels, 40 | self.channels, 41 | 3, 42 | padding=1, 43 | conv_cfg=self.conv_cfg, 44 | norm_cfg=self.norm_cfg, 45 | act_cfg=self.act_cfg)) 46 | if feature_strides[i] != feature_strides[0]: 47 | scale_head.append( 48 | Upsample( 49 | scale_factor=2, 50 | mode='bilinear', 51 | align_corners=self.align_corners)) 52 | self.scale_heads.append(nn.Sequential(*scale_head)) 53 | 54 | def forward(self, inputs): 55 | 56 | x = self._transform_inputs(inputs) 57 | 58 | output = self.scale_heads[0](x[0]) 59 | for i in range(1, len(self.feature_strides)): 60 | # non inplace 61 | output = output + resize( 62 | self.scale_heads[i](x[i]), 63 | size=output.shape[2:], 64 | mode='bilinear', 65 | align_corners=self.align_corners) 66 | 67 | output = self.cls_seg(output) 68 | return output 69 | -------------------------------------------------------------------------------- /mmseg/engine/schedulers/poly_ratio_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Optional 3 | 4 | from mmengine.optim.scheduler import PolyLR 5 | 6 | from mmseg.registry import PARAM_SCHEDULERS 7 | 8 | 9 | @PARAM_SCHEDULERS.register_module() 10 | class PolyLRRatio(PolyLR): 11 | """Implements polynomial learning rate decay with ratio. 12 | 13 | This scheduler adjusts the learning rate of each parameter group 14 | following a polynomial decay equation. The decay can occur in 15 | conjunction with external parameter adjustments made outside this 16 | scheduler. 17 | 18 | Args: 19 | optimizer (Optimizer or OptimWrapper): Wrapped optimizer. 20 | eta_min (float): Minimum learning rate at the end of scheduling. 21 | Defaults to 0. 22 | eta_min_ratio (float, optional): The ratio of the minimum parameter 23 | value to the base parameter value. Either `eta_min` or 24 | `eta_min_ratio` should be specified. Defaults to None. 25 | power (float): The power of the polynomial. Defaults to 1.0. 26 | begin (int): Step at which to start updating the parameters. 27 | Defaults to 0. 28 | end (int): Step at which to stop updating the parameters. 29 | Defaults to INF. 30 | last_step (int): The index of last step. Used for resume without 31 | state dict. Defaults to -1. 32 | by_epoch (bool): Whether the scheduled parameters are updated by 33 | epochs. Defaults to True. 34 | verbose (bool): Whether to print the value for each update. 35 | Defaults to False.
36 | """ 37 | 38 | def __init__(self, eta_min_ratio: Optional[int] = None, *args, **kwargs): 39 | super().__init__(*args, **kwargs) 40 | 41 | self.eta_min_ratio = eta_min_ratio 42 | 43 | def _get_value(self): 44 | """Compute value using chainable form of the scheduler.""" 45 | 46 | if self.last_step == 0: 47 | return [ 48 | group[self.param_name] for group in self.optimizer.param_groups 49 | ] 50 | 51 | param_groups_value = [] 52 | for base_value, param_group in zip(self.base_values, 53 | self.optimizer.param_groups): 54 | eta_min = self.eta_min if self.eta_min_ratio is None else \ 55 | base_value * self.eta_min_ratio 56 | step_ratio = (1 - 1 / 57 | (self.total_iters - self.last_step + 1))**self.power 58 | step_value = (param_group[self.param_name] - 59 | eta_min) * step_ratio + eta_min 60 | param_groups_value.append(step_value) 61 | 62 | return param_groups_value 63 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | 6 | from mmengine.config import Config, DictAction 7 | from mmengine.runner import Runner 8 | 9 | 10 | # TODO: support fuse_conv_bn, visualization, and format_only 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description='MMSeg test (and eval) a model') 14 | parser.add_argument('config', help='train config file path') 15 | parser.add_argument('checkpoint', help='checkpoint file') 16 | parser.add_argument( 17 | '--work-dir', 18 | help=('if specified, the evaluation metric results will be dumped' 19 | 'into the directory as json')) 20 | # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` 21 | # will pass the `--local-rank` parameter to `tools/train.py` instead 22 | # of `--local_rank`. 23 | parser.add_argument('--local_rank', '--local-rank', type=int, default=0) 24 | args = parser.parse_args() 25 | if 'LOCAL_RANK' not in os.environ: 26 | os.environ['LOCAL_RANK'] = str(args.local_rank) 27 | 28 | return args 29 | 30 | 31 | def trigger_visualization_hook(cfg, args): 32 | default_hooks = cfg.default_hooks 33 | if 'visualization' in default_hooks: 34 | visualization_hook = default_hooks['visualization'] 35 | # Turn on visualization 36 | visualization_hook['draw'] = True 37 | if args.show: 38 | visualization_hook['show'] = True 39 | visualization_hook['wait_time'] = args.wait_time 40 | if args.show_dir: 41 | visualizer = cfg.visualizer 42 | visualizer['save_dir'] = args.show_dir 43 | else: 44 | raise RuntimeError( 45 | 'VisualizationHook must be included in default_hooks.' 
46 | 'refer to usage ' 47 | '"visualization=dict(type=\'VisualizationHook\')"') 48 | 49 | return cfg 50 | 51 | 52 | def main(): 53 | args = parse_args() 54 | 55 | # load config 56 | cfg = Config.fromfile(args.config) 57 | 58 | 59 | # work_dir is determined in this priority: CLI > segment in file > filename 60 | if args.work_dir is not None: 61 | # update configs according to CLI args if args.work_dir is not None 62 | cfg.work_dir = args.work_dir 63 | elif cfg.get('work_dir', None) is None: 64 | # use config filename as default work_dir if cfg.work_dir is None 65 | cfg.work_dir = osp.join('./work_dirs', 66 | osp.splitext(osp.basename(args.config))[0]) 67 | 68 | cfg.load_from = args.checkpoint 69 | 70 | # build the runner from config 71 | runner = Runner.from_cfg(cfg) 72 | 73 | # start testing 74 | runner.test() 75 | 76 | 77 | if __name__ == '__main__': 78 | main() -------------------------------------------------------------------------------- /mmseg/models/necks/multilevel_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | from mmengine.model.weight_init import xavier_init 5 | 6 | from mmseg.registry import MODELS 7 | from ..utils import resize 8 | 9 | 10 | @MODELS.register_module() 11 | class MultiLevelNeck(nn.Module): 12 | """MultiLevelNeck. 13 | 14 | A neck structure connecting the ViT backbone and decode heads. 15 | 16 | Args: 17 | in_channels (List[int]): Number of input channels per scale. 18 | out_channels (int): Number of output channels (used at each scale). 19 | scales (List[float]): Scale factors for each input feature map. 20 | Default: [0.5, 1, 2, 4] 21 | norm_cfg (dict): Config dict for normalization layer. Default: None. 22 | act_cfg (dict): Config dict for activation layer in ConvModule. 23 | Default: None.
24 | """ 25 | 26 | def __init__(self, 27 | in_channels, 28 | out_channels, 29 | scales=[0.5, 1, 2, 4], 30 | norm_cfg=None, 31 | act_cfg=None): 32 | super().__init__() 33 | assert isinstance(in_channels, list) 34 | self.in_channels = in_channels 35 | self.out_channels = out_channels 36 | self.scales = scales 37 | self.num_outs = len(scales) 38 | self.lateral_convs = nn.ModuleList() 39 | self.convs = nn.ModuleList() 40 | for in_channel in in_channels: 41 | self.lateral_convs.append( 42 | ConvModule( 43 | in_channel, 44 | out_channels, 45 | kernel_size=1, 46 | norm_cfg=norm_cfg, 47 | act_cfg=act_cfg)) 48 | for _ in range(self.num_outs): 49 | self.convs.append( 50 | ConvModule( 51 | out_channels, 52 | out_channels, 53 | kernel_size=3, 54 | padding=1, 55 | stride=1, 56 | norm_cfg=norm_cfg, 57 | act_cfg=act_cfg)) 58 | 59 | # default init_weights for conv(msra) and norm in ConvModule 60 | def init_weights(self): 61 | for m in self.modules(): 62 | if isinstance(m, nn.Conv2d): 63 | xavier_init(m, distribution='uniform') 64 | 65 | def forward(self, inputs): 66 | assert len(inputs) == len(self.in_channels) 67 | inputs = [ 68 | lateral_conv(inputs[i]) 69 | for i, lateral_conv in enumerate(self.lateral_convs) 70 | ] 71 | # for len(inputs) not equal to self.num_outs 72 | if len(inputs) == 1: 73 | inputs = [inputs[0] for _ in range(self.num_outs)] 74 | outs = [] 75 | for i in range(self.num_outs): 76 | x_resize = resize( 77 | inputs[i], scale_factor=self.scales[i], mode='bilinear') 78 | outs.append(self.convs[i](x_resize)) 79 | return tuple(outs) 80 | -------------------------------------------------------------------------------- /mmseg/models/utils/encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class Encoding(nn.Module): 8 | """Encoding Layer: a learnable residual encoder. 9 | 10 | Input is of shape (batch_size, channels, height, width). 11 | Output is of shape (batch_size, num_codes, channels). 12 | 13 | Args: 14 | channels: dimension of the features or feature channels 15 | num_codes: number of code words 16 | """ 17 | 18 | def __init__(self, channels, num_codes): 19 | super().__init__() 20 | # init codewords and smoothing factor 21 | self.channels, self.num_codes = channels, num_codes 22 | std = 1. 
/ ((num_codes * channels)**0.5) 23 | # [num_codes, channels] 24 | self.codewords = nn.Parameter( 25 | torch.empty(num_codes, channels, 26 | dtype=torch.float).uniform_(-std, std), 27 | requires_grad=True) 28 | # [num_codes] 29 | self.scale = nn.Parameter( 30 | torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), 31 | requires_grad=True) 32 | 33 | @staticmethod 34 | def scaled_l2(x, codewords, scale): 35 | num_codes, channels = codewords.size() 36 | batch_size = x.size(0) 37 | reshaped_scale = scale.view((1, 1, num_codes)) 38 | expanded_x = x.unsqueeze(2).expand( 39 | (batch_size, x.size(1), num_codes, channels)) 40 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 41 | 42 | scaled_l2_norm = reshaped_scale * ( 43 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 44 | return scaled_l2_norm 45 | 46 | @staticmethod 47 | def aggregate(assignment_weights, x, codewords): 48 | num_codes, channels = codewords.size() 49 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 50 | batch_size = x.size(0) 51 | 52 | expanded_x = x.unsqueeze(2).expand( 53 | (batch_size, x.size(1), num_codes, channels)) 54 | encoded_feat = (assignment_weights.unsqueeze(3) * 55 | (expanded_x - reshaped_codewords)).sum(dim=1) 56 | return encoded_feat 57 | 58 | def forward(self, x): 59 | assert x.dim() == 4 and x.size(1) == self.channels 60 | # [batch_size, channels, height, width] 61 | batch_size = x.size(0) 62 | # [batch_size, height x width, channels] 63 | x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() 64 | # assignment_weights: [batch_size, height x width, num_codes] 65 | assignment_weights = F.softmax( 66 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 67 | # aggregate 68 | encoded_feat = self.aggregate(assignment_weights, x, self.codewords) 69 | return encoded_feat 70 | 71 | def __repr__(self): 72 | repr_str = self.__class__.__name__ 73 | repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ 74 | f'x{self.channels})' 75 | return repr_str 76 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/setr_up_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, build_norm_layer 4 | 5 | from mmseg.registry import MODELS 6 | from ..utils import Upsample 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | @MODELS.register_module() 11 | class SETRUPHead(BaseDecodeHead): 12 | """Naive upsampling head and Progressive upsampling head of SETR. 13 | 14 | Naive or PUP head of `SETR <https://arxiv.org/abs/2012.15840>`_. 15 | 16 | Args: 17 | norm_layer (dict): Config dict for input normalization. 18 | Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). 19 | num_convs (int): Number of decoder convolutions. Default: 1. 20 | up_scale (int): The scale factor of interpolate. Default: 4. 21 | kernel_size (int): The kernel size of convolution when decoding 22 | feature information from backbone. Default: 3. 23 | init_cfg (dict | list[dict] | None): Initialization config dict. 24 | Default: dict( 25 | type='Constant', val=1.0, bias=0, layer='LayerNorm').
26 | """ 27 | 28 | def __init__(self, 29 | norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), 30 | num_convs=1, 31 | up_scale=4, 32 | kernel_size=3, 33 | init_cfg=[ 34 | dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'), 35 | dict( 36 | type='Normal', 37 | std=0.01, 38 | override=dict(name='conv_seg')) 39 | ], 40 | **kwargs): 41 | 42 | assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' 43 | 44 | super().__init__(init_cfg=init_cfg, **kwargs) 45 | 46 | assert isinstance(self.in_channels, int) 47 | 48 | _, self.norm = build_norm_layer(norm_layer, self.in_channels) 49 | 50 | self.up_convs = nn.ModuleList() 51 | in_channels = self.in_channels 52 | out_channels = self.channels 53 | for _ in range(num_convs): 54 | self.up_convs.append( 55 | nn.Sequential( 56 | ConvModule( 57 | in_channels=in_channels, 58 | out_channels=out_channels, 59 | kernel_size=kernel_size, 60 | stride=1, 61 | padding=int(kernel_size - 1) // 2, 62 | norm_cfg=self.norm_cfg, 63 | act_cfg=self.act_cfg), 64 | Upsample( 65 | scale_factor=up_scale, 66 | mode='bilinear', 67 | align_corners=self.align_corners))) 68 | in_channels = out_channels 69 | 70 | def forward(self, x): 71 | x = self._transform_inputs(x) 72 | 73 | n, c, h, w = x.shape 74 | x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() 75 | x = self.norm(x) 76 | x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() 77 | 78 | for up_conv in self.up_convs: 79 | x = up_conv(x) 80 | out = self.cls_seg(x) 81 | return out 82 | -------------------------------------------------------------------------------- /mmseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # yapf: disable 3 | from .ade import ADE20KDataset 4 | from .basesegdataset import BaseCDDataset, BaseSegDataset 5 | from .bdd100k import BDD100KDataset 6 | from .chase_db1 import ChaseDB1Dataset 7 | from .cityscapes import CityscapesDataset 8 | from .coco_stuff import COCOStuffDataset 9 | from .dark_zurich import DarkZurichDataset 10 | from .dataset_wrappers import MultiImageMixDataset 11 | from .decathlon import DecathlonDataset 12 | from .drive import DRIVEDataset 13 | from .dsdl import DSDLSegDataset 14 | from .hrf import HRFDataset 15 | from .isaid import iSAIDDataset 16 | from .isprs import ISPRSDataset 17 | from .levir import LEVIRCDDataset 18 | from .lip import LIPDataset 19 | from .loveda import LoveDADataset 20 | from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2 21 | from .night_driving import NightDrivingDataset 22 | from .nyu import NYUDataset 23 | from .pascal_context import PascalContextDataset, PascalContextDataset59 24 | from .potsdam import PotsdamDataset 25 | from .refuge import REFUGEDataset 26 | from .stare import STAREDataset 27 | from .synapse import SynapseDataset 28 | # yapf: disable 29 | from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad, 30 | BioMedical3DRandomCrop, BioMedical3DRandomFlip, 31 | BioMedicalGaussianBlur, BioMedicalGaussianNoise, 32 | BioMedicalRandomGamma, ConcatCDInput, GenerateEdge, 33 | LoadAnnotations, LoadBiomedicalAnnotation, 34 | LoadBiomedicalData, LoadBiomedicalImageFromFile, 35 | LoadImageFromNDArray, LoadMultipleRSImageFromFile, 36 | LoadSingleRSImageFromFile, PackSegInputs, 37 | PhotoMetricDistortion, RandomCrop, RandomCutOut, 38 | RandomMosaic, RandomRotate, RandomRotFlip, Rerange, 39 | ResizeShortestEdge, ResizeToMultiple, RGB2Gray, 40 | SegRescale) 41 | from .voc import PascalVOCDataset 42 | 43 | # 
yapf: enable 44 | __all__ = [ 45 | 'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip', 46 | 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset', 47 | 'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset', 48 | 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset', 49 | 'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset', 50 | 'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 51 | 'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', 52 | 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 53 | 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', 54 | 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 55 | 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 56 | 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge', 57 | 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', 58 | 'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip', 59 | 'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1', 60 | 'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset', 61 | 'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile', 62 | 'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset', 63 | 'NYUDataset' 64 | ] 65 | -------------------------------------------------------------------------------- /mmseg/structures/seg_data_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmengine.structures import BaseDataElement, PixelData 3 | 4 | 5 | class SegDataSample(BaseDataElement): 6 | """A data structure interface of MMSegmentation. They are used as 7 | interfaces between different components. 8 | 9 | The attributes in ``SegDataSample`` are divided into several parts: 10 | 11 | - ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation. 12 | - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. 13 | - ``seg_logits``(PixelData): Predicted logits of semantic segmentation. 14 | 15 | Examples: 16 | >>> import torch 17 | >>> import numpy as np 18 | >>> from mmengine.structures import PixelData 19 | >>> from mmseg.structures import SegDataSample 20 | 21 | >>> data_sample = SegDataSample() 22 | >>> img_meta = dict(img_shape=(4, 4, 3), 23 | ... 
pad_shape=(4, 4, 3)) 24 | >>> gt_segmentations = PixelData(metainfo=img_meta) 25 | >>> gt_segmentations.data = torch.randint(0, 2, (1, 4, 4)) 26 | >>> data_sample.gt_sem_seg = gt_segmentations 27 | >>> assert 'img_shape' in data_sample.gt_sem_seg.metainfo_keys() 28 | >>> data_sample.gt_sem_seg.shape 29 | (4, 4) 30 | >>> print(data_sample) 31 | <SegDataSample(... 48 | ) at 0x1c2aae44d60> 49 | 50 | >>> data_sample = SegDataSample() 51 | >>> gt_sem_seg_data = dict(sem_seg=torch.rand(1, 4, 4)) 52 | >>> gt_sem_seg = PixelData(**gt_sem_seg_data) 53 | >>> data_sample.gt_sem_seg = gt_sem_seg 54 | >>> assert 'gt_sem_seg' in data_sample 55 | >>> assert 'sem_seg' in data_sample.gt_sem_seg 56 | """ 57 | 58 | @property 59 | def gt_sem_seg(self) -> PixelData: 60 | return self._gt_sem_seg 61 | 62 | @gt_sem_seg.setter 63 | def gt_sem_seg(self, value: PixelData) -> None: 64 | self.set_field(value, '_gt_sem_seg', dtype=PixelData) 65 | 66 | @gt_sem_seg.deleter 67 | def gt_sem_seg(self) -> None: 68 | del self._gt_sem_seg 69 | 70 | @property 71 | def pred_sem_seg(self) -> PixelData: 72 | return self._pred_sem_seg 73 | 74 | @pred_sem_seg.setter 75 | def pred_sem_seg(self, value: PixelData) -> None: 76 | self.set_field(value, '_pred_sem_seg', dtype=PixelData) 77 | 78 | @pred_sem_seg.deleter 79 | def pred_sem_seg(self) -> None: 80 | del self._pred_sem_seg 81 | 82 | @property 83 | def seg_logits(self) -> PixelData: 84 | return self._seg_logits 85 | 86 | @seg_logits.setter 87 | def seg_logits(self, value: PixelData) -> None: 88 | self.set_field(value, '_seg_logits', dtype=PixelData) 89 | 90 | @seg_logits.deleter 91 | def seg_logits(self) -> None: 92 | del self._seg_logits 93 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/lraspp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | from mmengine.utils import is_tuple_of 6 | 7 | from mmseg.registry import MODELS 8 | from ..utils import resize 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | @MODELS.register_module() 13 | class LRASPPHead(BaseDecodeHead): 14 | """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. 15 | 16 | This head is the improved implementation of `Searching for MobileNetV3 17 | <https://arxiv.org/abs/1905.02244>`_. 18 | 19 | Args: 20 | branch_channels (tuple[int]): The number of output channels in 21 | each branch. Default: (32, 64). 22 | """ 23 | 24 | def __init__(self, branch_channels=(32, 64), **kwargs): 25 | super().__init__(**kwargs) 26 | if self.input_transform != 'multiple_select': 27 | raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' 28 | f'must be \'multiple_select\'. 
But received ' 29 | f'\'{self.input_transform}\'') 30 | assert is_tuple_of(branch_channels, int) 31 | assert len(branch_channels) == len(self.in_channels) - 1 32 | self.branch_channels = branch_channels 33 | 34 | self.convs = nn.Sequential() 35 | self.conv_ups = nn.Sequential() 36 | for i in range(len(branch_channels)): 37 | self.convs.add_module( 38 | f'conv{i}', 39 | nn.Conv2d( 40 | self.in_channels[i], branch_channels[i], 1, bias=False)) 41 | self.conv_ups.add_module( 42 | f'conv_up{i}', 43 | ConvModule( 44 | self.channels + branch_channels[i], 45 | self.channels, 46 | 1, 47 | norm_cfg=self.norm_cfg, 48 | act_cfg=self.act_cfg, 49 | bias=False)) 50 | 51 | self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) 52 | 53 | self.aspp_conv = ConvModule( 54 | self.in_channels[-1], 55 | self.channels, 56 | 1, 57 | norm_cfg=self.norm_cfg, 58 | act_cfg=self.act_cfg, 59 | bias=False) 60 | self.image_pool = nn.Sequential( 61 | nn.AvgPool2d(kernel_size=49, stride=(16, 20)), 62 | ConvModule( 63 | self.in_channels[2], 64 | self.channels, 65 | 1, 66 | act_cfg=dict(type='Sigmoid'), 67 | bias=False)) 68 | 69 | def forward(self, inputs): 70 | """Forward function.""" 71 | inputs = self._transform_inputs(inputs) 72 | 73 | x = inputs[-1] 74 | 75 | x = self.aspp_conv(x) * resize( 76 | self.image_pool(x), 77 | size=x.size()[2:], 78 | mode='bilinear', 79 | align_corners=self.align_corners) 80 | x = self.conv_up_input(x) 81 | 82 | for i in range(len(self.branch_channels) - 1, -1, -1): 83 | x = resize( 84 | x, 85 | size=inputs[i].size()[2:], 86 | mode='bilinear', 87 | align_corners=self.align_corners) 88 | x = torch.cat([x, self.convs[i](inputs[i])], 1) 89 | x = self.conv_ups[i](x) 90 | 91 | return self.cls_seg(x) 92 | -------------------------------------------------------------------------------- /AgriFM/datasets/mapping_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from mmseg.registry.registry import DATASETS 4 | import numpy as np 5 | import h5py 6 | from AgriFM.datasets.transform import MapCompose 7 | import torch.utils.data as data_utils 8 | 9 | @DATASETS.register_module() 10 | class MappingDataset(data_utils.Dataset): 11 | def __init__(self, data_root_path, data_list_file, 12 | data_pipelines, data_keys=('S2',), label_key='label'): 13 | ''' 14 | Provide general data loading and preprocessing for mapping datasets 15 | stored in HDF5 (h5) format. 16 | :param data_root_path: directory containing the h5 files. 17 | :param data_list_file: text file listing the h5 files to load. 18 | :param data_pipelines: 19 | data pipelines for preprocessing, given as a list of dicts. 20 | Each dict must contain the key 'type' plus the keyword arguments 21 | of the specific transform; 'type' is the transform class name. 22 | :param data_keys: 23 | keys of the data to be loaded from the h5 files, default ('S2',). 24 | Multiple keys are supported, such as ('S2', 'Modis', 'Landsat'). 25 | Every key must exist in the h5 files; 26 | the data is read from the h5 file under these keys. 27 | :param label_key: 28 | key of the label in the h5 files, default 'label'. 29 | The label is read from the h5 file under this key. 30 | ''' 
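For orientation, a minimal usage sketch of this dataset. The transform names inside data_pipelines and the file paths below are hypothetical placeholders, not values taken from this repo's configs:

    # Hypothetical instantiation; 'RandomFlip'/'Normalize' are placeholder
    # transform names, not necessarily registered in AgriFM.
    dataset = MappingDataset(
        data_root_path='data/train_h5',
        data_list_file='data/train_list.txt',
        data_pipelines=[dict(type='RandomFlip'), dict(type='Normalize')],
        data_keys=('S2', 'Landsat'),
        label_key='label')
    data_dict, label = dataset[0]
    # data_dict['S2'] is a (T, C, H, W) tensor; label is an (H, W) long tensor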
31 | self.data_root_path = data_root_path 32 | self.data_list_file = data_list_file 33 | self.data_pipelines = data_pipelines 34 | self.data_keys = data_keys 35 | self.label_key = label_key 36 | self.data_list = np.loadtxt(data_list_file, dtype=str).tolist() 37 | self.data_pipelines = MapCompose(self.data_pipelines) 38 | 39 | 40 | def __len__(self): 41 | return len(self.data_list) 42 | 43 | def __getitem__(self, item): 44 | ''' 45 | Load data from the h5 file according to the item index. 46 | :param item: index of the sample in the data list. 47 | :return: 48 | A dict containing multi-source data. 49 | The data is a dict with keys as self.data_keys. 50 | Each value in the dict is a tensor with shape (T, C, H, W), 51 | where T is the number of time steps (1 for a single image), 52 | C is the number of channels, H is the height, 53 | and W is the width. 54 | The label is a tensor with key self.label_key. 55 | ''' 56 | file_name = self.data_list[item] 57 | data_file = os.path.join(self.data_root_path, f'{file_name}.h5') 58 | data_dict = {} 59 | with h5py.File(data_file, 'r') as f: 60 | for key in self.data_keys: 61 | if key in f.keys(): 62 | data_dict[key] = torch.from_numpy(f[key][:]) 63 | else: 64 | raise KeyError(f'{key} not found in {data_file}') 65 | if self.label_key in f.keys(): 66 | label = torch.from_numpy(f[self.label_key][:]) 67 | else: 68 | raise KeyError(f'{self.label_key} not found in {data_file}') 69 | label = torch.unsqueeze(label, dim=0) 70 | data_dict, label = self.data_pipelines(data_dict, label) 71 | label = label.squeeze(dim=0) 72 | data_dict['file_name'] = file_name 73 | return data_dict, label.long() 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /mmseg/datasets/decathlon.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | import os.path as osp 4 | from typing import List 5 | 6 | from mmengine.fileio import load 7 | 8 | from mmseg.registry import DATASETS 9 | from .basesegdataset import BaseSegDataset 10 | 11 | 12 | @DATASETS.register_module() 13 | class DecathlonDataset(BaseSegDataset): 14 | """Dataset for the Medical Segmentation Decathlon dataset. 15 | 16 | The dataset.json format is shown as follows 17 | 18 | .. code-block:: none 19 | 20 | { 21 | "name": "BRATS", 22 | "tensorImageSize": "4D", 23 | "modality": 24 | { 25 | "0": "FLAIR", 26 | "1": "T1w", 27 | "2": "t1gd", 28 | "3": "T2w" 29 | }, 30 | "labels": { 31 | "0": "background", 32 | "1": "edema", 33 | "2": "non-enhancing tumor", 34 | "3": "enhancing tumour" 35 | }, 36 | "numTraining": 484, 37 | "numTest": 266, 38 | "training": 39 | [ 40 | { 41 | "image": "./imagesTr/BRATS_306.nii.gz", 42 | "label": "./labelsTr/BRATS_306.nii.gz" 43 | ... 44 | } 45 | ], 46 | "test": 47 | [ 48 | "./imagesTs/BRATS_557.nii.gz" 49 | ... 50 | ] 51 | } 52 | """ 53 | 54 | def load_data_list(self) -> List[dict]: 55 | """Load annotation from directory or annotation file. 56 | 57 | Returns: 58 | list[dict]: All data info of dataset. 59 | """ 60 | # `self.ann_file` denotes the absolute annotation file path if 61 | # `self.data_root=None` or a relative path if `self.data_root=/path/to/data/`. 
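To make the './'-stripping below concrete, a small sketch of the path handling (paths are illustrative):

    # './imagesTr/BRATS_306.nii.gz'[2:] -> 'imagesTr/BRATS_306.nii.gz', so
    # osp.join(self.data_root, ...) yields a clean path without './', which
    # would otherwise break loading from cloud storage.
    # Note: the string (test-list) branch further down appears to apply [2:]
    # to the already-joined path rather than to raw_data_info; stripping
    # first, as in the dict branch, matches the stated intent.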
62 | annotations = load(self.ann_file) 63 | if not isinstance(annotations, dict): 64 | raise TypeError(f'The annotations loaded from annotation file ' 65 | f'should be a dict, but got {type(annotations)}!') 66 | raw_data_list = annotations[ 67 | 'training'] if not self.test_mode else annotations['test'] 68 | data_list = [] 69 | for raw_data_info in raw_data_list: 70 | # `2:` works for removing './' in file path, which will break 71 | # loading from cloud storage. 72 | if isinstance(raw_data_info, dict): 73 | data_info = dict( 74 | img_path=osp.join(self.data_root, raw_data_info['image'] 75 | [2:])) 76 | data_info['seg_map_path'] = osp.join( 77 | self.data_root, raw_data_info['label'][2:]) 78 | else: 79 | data_info = dict( 80 | img_path=osp.join(self.data_root, raw_data_info)[2:]) 81 | data_info['label_map'] = self.label_map 82 | data_info['reduce_zero_label'] = self.reduce_zero_label 83 | data_info['seg_fields'] = [] 84 | data_list.append(data_info) 85 | annotations.pop('training') 86 | annotations.pop('test') 87 | 88 | metainfo = copy.deepcopy(annotations) 89 | metainfo['classes'] = [*metainfo['labels'].values()] 90 | # Meta information load from annotation file will not influence the 91 | # existed meta information load from `BaseDataset.METAINFO` and 92 | # `metainfo` arguments defined in constructor. 93 | for k, v in metainfo.items(): 94 | self._metainfo.setdefault(k, v) 95 | 96 | return data_list 97 | -------------------------------------------------------------------------------- /AgriFM/models/neck.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | from mmseg.models.builder import NECKS,MODELS 6 | from mmengine.runner import load_checkpoint 7 | 8 | 9 | @NECKS.register_module() 10 | class MultiFusionNeck(nn.Module): 11 | def __init__(self,embed_dim,in_feature_key=('S2',), 12 | feature_size=(16,16),out_size=(256,256), 13 | in_fusion_key_list=({'S2':512,'HLS':512}, 14 | {'S2':256}, 15 | {'S2':128,}, 16 | ) 17 | ): 18 | super(MultiFusionNeck, self).__init__() 19 | self.embed_dim=embed_dim 20 | self.fusion_list=nn.ModuleList() 21 | self.in_feature_key=in_feature_key 22 | self.feature_size=feature_size 23 | self.out_size=out_size 24 | self.in_fusion_key_list=in_fusion_key_list 25 | 26 | if len(in_feature_key)==1: 27 | self.in_conv=nn.Identity() 28 | else: 29 | self.in_conv=nn.Sequential( 30 | nn.Conv2d(len(in_feature_key)*self.embed_dim,self.embed_dim,3,1,1), 31 | nn.BatchNorm2d(self.embed_dim), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(self.embed_dim,self.embed_dim,3,1,1), 34 | ) 35 | embed_dim=self.embed_dim 36 | pre_embed = embed_dim 37 | for fusion_keys in in_fusion_key_list: 38 | in_embed=sum(fusion_keys.values()) 39 | fusion=nn.Sequential( 40 | nn.Conv2d(in_embed+pre_embed,pre_embed,3,1,1), 41 | nn.BatchNorm2d(pre_embed), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(pre_embed,embed_dim,3,1,1), 44 | ) 45 | self.fusion_list.append(fusion) 46 | pre_embed=embed_dim 47 | 48 | 49 | self.out_conv=nn.Sequential( 50 | nn.Conv2d(pre_embed,pre_embed,3,1,1), 51 | nn.BatchNorm2d(pre_embed), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(pre_embed,pre_embed,3,1,1), 54 | ) 55 | 56 | def forward(self,inputs): 57 | in_features=[] 58 | for key in self.in_feature_key: 59 | features=inputs[key]['encoder_features'] 60 | features=torch.nn.functional.interpolate(features,self.feature_size,mode='bilinear',align_corners=False) 61 | in_features.append(features) 62 | in_features=torch.cat(in_features,dim=1) 63 | 
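# Shape sketch for this point in MultiFusionNeck.forward, under assumed
# sizes (batch 2, embed_dim 512, two sources in in_feature_key,
# feature_size (16, 16)); these numbers are illustrative, not repo defaults:
# in_features is (2, 2 * 512, 16, 16) after the concat above, and
# self.in_conv reduces it to (2, 512, 16, 16). Each fusion stage below then
# upsamples x2 and concatenates the matching 'features_list' level of every
# source listed in in_fusion_key_list, so with three stages the spatial size
# grows 16 -> 32 -> 64 -> 128 before the final resize to out_size.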
in_features=self.in_conv(in_features) 64 | 65 | for i,fusion_keys in enumerate(self.in_fusion_key_list): 66 | in_features=torch.nn.functional.interpolate(in_features,scale_factor=2,mode='bilinear',align_corners=False) 67 | in_features_h, in_features_w=in_features.shape[-2:] 68 | in_features_idx=len(self.in_fusion_key_list)-i-1 69 | fusion_features=[] 70 | for key in fusion_keys: 71 | features=inputs[key]['features_list'][in_features_idx] 72 | features=torch.nn.functional.interpolate(features,(in_features_h,in_features_w),mode='bilinear',align_corners=False) 73 | fusion_features.append(features) 74 | fusion_features=torch.cat(fusion_features,dim=1) 75 | in_features=torch.cat([in_features,fusion_features],dim=1) 76 | in_features=self.fusion_list[i](in_features) 77 | out_features=self.out_conv(in_features) 78 | out_features=torch.nn.functional.interpolate(out_features,self.out_size,mode='bilinear',align_corners=False) 79 | return out_features 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/fcn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.registry import MODELS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | @MODELS.register_module() 11 | class FCNHead(BaseDecodeHead): 12 | """Fully Convolution Networks for Semantic Segmentation. 13 | 14 | This head is implemented of `FCNNet `_. 15 | 16 | Args: 17 | num_convs (int): Number of convs in the head. Default: 2. 18 | kernel_size (int): The kernel size for convs in the head. Default: 3. 19 | concat_input (bool): Whether concat the input and output of convs 20 | before classification layer. 21 | dilation (int): The dilation rate for convs in the head. Default: 1. 22 | """ 23 | 24 | def __init__(self, 25 | num_convs=2, 26 | kernel_size=3, 27 | concat_input=True, 28 | dilation=1, 29 | **kwargs): 30 | assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) 31 | self.num_convs = num_convs 32 | self.concat_input = concat_input 33 | self.kernel_size = kernel_size 34 | super().__init__(**kwargs) 35 | if num_convs == 0: 36 | assert self.in_channels == self.channels 37 | 38 | conv_padding = (kernel_size // 2) * dilation 39 | convs = [] 40 | convs.append( 41 | ConvModule( 42 | self.in_channels, 43 | self.channels, 44 | kernel_size=kernel_size, 45 | padding=conv_padding, 46 | dilation=dilation, 47 | conv_cfg=self.conv_cfg, 48 | norm_cfg=self.norm_cfg, 49 | act_cfg=self.act_cfg)) 50 | for i in range(num_convs - 1): 51 | convs.append( 52 | ConvModule( 53 | self.channels, 54 | self.channels, 55 | kernel_size=kernel_size, 56 | padding=conv_padding, 57 | dilation=dilation, 58 | conv_cfg=self.conv_cfg, 59 | norm_cfg=self.norm_cfg, 60 | act_cfg=self.act_cfg)) 61 | if num_convs == 0: 62 | self.convs = nn.Identity() 63 | else: 64 | self.convs = nn.Sequential(*convs) 65 | if self.concat_input: 66 | self.conv_cat = ConvModule( 67 | self.in_channels + self.channels, 68 | self.channels, 69 | kernel_size=kernel_size, 70 | padding=kernel_size // 2, 71 | conv_cfg=self.conv_cfg, 72 | norm_cfg=self.norm_cfg, 73 | act_cfg=self.act_cfg) 74 | 75 | def _forward_feature(self, inputs): 76 | """Forward function for feature maps before classifying each pixel with 77 | ``self.cls_seg`` fc. 78 | 79 | Args: 80 | inputs (list[Tensor]): List of multi-level img features. 
81 | 82 | Returns: 83 | feats (Tensor): A tensor of shape (batch_size, self.channels, 84 | H, W) which is feature map for last layer of decoder head. 85 | """ 86 | x = self._transform_inputs(inputs) 87 | feats = self.convs(x) 88 | if self.concat_input: 89 | feats = self.conv_cat(torch.cat([x, feats], dim=1)) 90 | return feats 91 | 92 | def forward(self, inputs): 93 | """Forward function.""" 94 | output = self._forward_feature(inputs) 95 | output = self.cls_seg(output) 96 | return output 97 | -------------------------------------------------------------------------------- /mmseg/models/utils/res_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import build_conv_layer, build_norm_layer 3 | from mmengine.model import Sequential 4 | from torch import nn as nn 5 | 6 | 7 | class ResLayer(Sequential): 8 | """ResLayer to build ResNet style backbone. 9 | 10 | Args: 11 | block (nn.Module): block used to build ResLayer. 12 | inplanes (int): inplanes of block. 13 | planes (int): planes of block. 14 | num_blocks (int): number of blocks. 15 | stride (int): stride of the first block. Default: 1 16 | avg_down (bool): Use AvgPool instead of stride conv when 17 | downsampling in the bottleneck. Default: False 18 | conv_cfg (dict): dictionary to construct and config conv layer. 19 | Default: None 20 | norm_cfg (dict): dictionary to construct and config norm layer. 21 | Default: dict(type='BN') 22 | multi_grid (int | None): Multi grid dilation rates of last 23 | stage. Default: None 24 | contract_dilation (bool): Whether contract first dilation of each layer 25 | Default: False 26 | """ 27 | 28 | def __init__(self, 29 | block, 30 | inplanes, 31 | planes, 32 | num_blocks, 33 | stride=1, 34 | dilation=1, 35 | avg_down=False, 36 | conv_cfg=None, 37 | norm_cfg=dict(type='BN'), 38 | multi_grid=None, 39 | contract_dilation=False, 40 | **kwargs): 41 | self.block = block 42 | 43 | downsample = None 44 | if stride != 1 or inplanes != planes * block.expansion: 45 | downsample = [] 46 | conv_stride = stride 47 | if avg_down: 48 | conv_stride = 1 49 | downsample.append( 50 | nn.AvgPool2d( 51 | kernel_size=stride, 52 | stride=stride, 53 | ceil_mode=True, 54 | count_include_pad=False)) 55 | downsample.extend([ 56 | build_conv_layer( 57 | conv_cfg, 58 | inplanes, 59 | planes * block.expansion, 60 | kernel_size=1, 61 | stride=conv_stride, 62 | bias=False), 63 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 64 | ]) 65 | downsample = nn.Sequential(*downsample) 66 | 67 | layers = [] 68 | if multi_grid is None: 69 | if dilation > 1 and contract_dilation: 70 | first_dilation = dilation // 2 71 | else: 72 | first_dilation = dilation 73 | else: 74 | first_dilation = multi_grid[0] 75 | layers.append( 76 | block( 77 | inplanes=inplanes, 78 | planes=planes, 79 | stride=stride, 80 | dilation=first_dilation, 81 | downsample=downsample, 82 | conv_cfg=conv_cfg, 83 | norm_cfg=norm_cfg, 84 | **kwargs)) 85 | inplanes = planes * block.expansion 86 | for i in range(1, num_blocks): 87 | layers.append( 88 | block( 89 | inplanes=inplanes, 90 | planes=planes, 91 | stride=1, 92 | dilation=dilation if multi_grid is None else multi_grid[i], 93 | conv_cfg=conv_cfg, 94 | norm_cfg=norm_cfg, 95 | **kwargs)) 96 | super().__init__(*layers) 97 | -------------------------------------------------------------------------------- /mmseg/models/assigners/hungarian_assigner.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List, Union 3 | 4 | import torch 5 | from mmengine import ConfigDict 6 | from mmengine.structures import InstanceData 7 | from scipy.optimize import linear_sum_assignment 8 | from torch.cuda.amp import autocast 9 | 10 | from mmseg.registry import TASK_UTILS 11 | from .base_assigner import BaseAssigner 12 | 13 | 14 | @TASK_UTILS.register_module() 15 | class HungarianAssigner(BaseAssigner): 16 | """Computes one-to-one matching between prediction masks and ground truth. 17 | 18 | This class uses bipartite matching-based assignment to computes an 19 | assignment between the prediction masks and the ground truth. The 20 | assignment result is based on the weighted sum of match costs. The 21 | Hungarian algorithm is used to calculate the best matching with the 22 | minimum cost. The prediction masks that are not matched are classified 23 | as background. 24 | 25 | Args: 26 | match_costs (ConfigDict|List[ConfigDict]): Match cost configs. 27 | """ 28 | 29 | def __init__( 30 | self, match_costs: Union[List[Union[dict, ConfigDict]], dict, 31 | ConfigDict] 32 | ) -> None: 33 | 34 | if isinstance(match_costs, dict): 35 | match_costs = [match_costs] 36 | elif isinstance(match_costs, list): 37 | assert len(match_costs) > 0, \ 38 | 'match_costs must not be a empty list.' 39 | 40 | self.match_costs = [ 41 | TASK_UTILS.build(match_cost) for match_cost in match_costs 42 | ] 43 | 44 | def assign(self, pred_instances: InstanceData, gt_instances: InstanceData, 45 | **kwargs): 46 | """Computes one-to-one matching based on the weighted costs. 47 | 48 | This method assign each query prediction to a ground truth or 49 | background. The assignment first calculates the cost for each 50 | category assigned to each query mask, and then uses the 51 | Hungarian algorithm to calculate the minimum cost as the best 52 | match. 53 | 54 | Args: 55 | pred_instances (InstanceData): Instances of model 56 | predictions. It includes "masks", with shape 57 | (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1) 58 | gt_instances (InstanceData): Ground truth of instance 59 | annotations. It includes "labels", with shape (k, ), 60 | and "masks", with shape (k, h, w) or (k, l). 61 | 62 | Returns: 63 | matched_quiery_inds (Tensor): The indexes of matched quieres. 64 | matched_label_inds (Tensor): The indexes of matched labels. 
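A toy view of the matching this method delegates to scipy (illustrative numbers, not from the codebase):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    cost = np.array([[4., 1.],
                     [2., 0.]])  # rows: queries, cols: ground truths
    rows, cols = linear_sum_assignment(cost)
    # rows -> [0, 1], cols -> [1, 0]: query 0 takes gt 1 and query 1 takes
    # gt 0, for the minimum total cost 1 + 2 = 3.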
65 | """ 66 | # compute weighted cost 67 | cost_list = [] 68 | with autocast(enabled=False): 69 | for match_cost in self.match_costs: 70 | cost = match_cost( 71 | pred_instances=pred_instances, gt_instances=gt_instances) 72 | cost_list.append(cost) 73 | cost = torch.stack(cost_list).sum(dim=0) 74 | 75 | device = cost.device 76 | # do Hungarian matching on CPU using linear_sum_assignment 77 | cost = cost.detach().cpu() 78 | if linear_sum_assignment is None: 79 | raise ImportError('Please run "pip install scipy" ' 80 | 'to install scipy first.') 81 | 82 | matched_quiery_inds, matched_label_inds = linear_sum_assignment(cost) 83 | matched_quiery_inds = torch.from_numpy(matched_quiery_inds).to(device) 84 | matched_label_inds = torch.from_numpy(matched_label_inds).to(device) 85 | 86 | return matched_quiery_inds, matched_label_inds 87 | -------------------------------------------------------------------------------- /mmseg/structures/sampler/ohem_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .base_pixel_sampler import BasePixelSampler 7 | from .builder import PIXEL_SAMPLERS 8 | 9 | 10 | @PIXEL_SAMPLERS.register_module() 11 | class OHEMPixelSampler(BasePixelSampler): 12 | """Online Hard Example Mining Sampler for segmentation. 13 | 14 | Args: 15 | context (nn.Module): The context of sampler, subclass of 16 | :obj:`BaseDecodeHead`. 17 | thresh (float, optional): The threshold for hard example selection. 18 | Below which, are prediction with low confidence. If not 19 | specified, the hard examples will be pixels of top ``min_kept`` 20 | loss. Default: None. 21 | min_kept (int, optional): The minimum number of predictions to keep. 22 | Default: 100000. 23 | """ 24 | 25 | def __init__(self, context, thresh=None, min_kept=100000): 26 | super().__init__() 27 | self.context = context 28 | assert min_kept > 1 29 | self.thresh = thresh 30 | self.min_kept = min_kept 31 | 32 | def sample(self, seg_logit, seg_label): 33 | """Sample pixels that have high loss or with low prediction confidence. 34 | 35 | Args: 36 | seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) 37 | seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) 38 | 39 | Returns: 40 | torch.Tensor: segmentation weight, shape (N, H, W) 41 | """ 42 | with torch.no_grad(): 43 | assert seg_logit.shape[2:] == seg_label.shape[2:] 44 | assert seg_label.shape[1] == 1 45 | seg_label = seg_label.squeeze(1).long() 46 | batch_kept = self.min_kept * seg_label.size(0) 47 | valid_mask = seg_label != self.context.ignore_index 48 | seg_weight = seg_logit.new_zeros(size=seg_label.size()) 49 | valid_seg_weight = seg_weight[valid_mask] 50 | if self.thresh is not None: 51 | seg_prob = F.softmax(seg_logit, dim=1) 52 | 53 | tmp_seg_label = seg_label.clone().unsqueeze(1) 54 | tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 55 | seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) 56 | sort_prob, sort_indices = seg_prob[valid_mask].sort() 57 | 58 | if sort_prob.numel() > 0: 59 | min_threshold = sort_prob[min(batch_kept, 60 | sort_prob.numel() - 1)] 61 | else: 62 | min_threshold = 0.0 63 | threshold = max(min_threshold, self.thresh) 64 | valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 
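# Worked example of the threshold branch above (assumed numbers, not from
# the code): with thresh=0.7, batch_kept=2 and sorted GT-class probabilities
# sort_prob = [0.1, 0.4, 0.8, 0.9], min_threshold = sort_prob[2] = 0.8 and
# threshold = max(0.8, 0.7) = 0.8, so only the two low-confidence pixels
# (0.1 and 0.4) receive weight 1 and contribute to the loss.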
65 | else: 66 | if not isinstance(self.context.loss_decode, nn.ModuleList): 67 | losses_decode = [self.context.loss_decode] 68 | else: 69 | losses_decode = self.context.loss_decode 70 | losses = 0.0 71 | for loss_module in losses_decode: 72 | losses += loss_module( 73 | seg_logit, 74 | seg_label, 75 | weight=None, 76 | ignore_index=self.context.ignore_index, 77 | reduction_override='none') 78 | 79 | # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa 80 | _, sort_indices = losses[valid_mask].sort(descending=True) 81 | valid_seg_weight[sort_indices[:batch_kept]] = 1. 82 | 83 | seg_weight[valid_mask] = valid_seg_weight 84 | 85 | return seg_weight 86 | -------------------------------------------------------------------------------- /mmseg/models/losses/ohem_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List, Optional, Union 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from mmseg.registry import MODELS 9 | 10 | 11 | @MODELS.register_module() 12 | class OhemCrossEntropy(nn.Module): 13 | """OhemCrossEntropy loss. 14 | 15 | This func is modified from 16 | `PIDNet `_. # noqa 17 | 18 | Licensed under the MIT License. 19 | 20 | Args: 21 | ignore_label (int): Labels to ignore when computing the loss. 22 | Default: 255 23 | thresh (float, optional): The threshold for hard example selection. 24 | Below which, are prediction with low confidence. If not 25 | specified, the hard examples will be pixels of top ``min_kept`` 26 | loss. Default: 0.7. 27 | min_kept (int, optional): The minimum number of predictions to keep. 28 | Default: 100000. 29 | loss_weight (float): Weight of the loss. Defaults to 1.0. 30 | class_weight (list[float] | str, optional): Weight of each class. If in 31 | str format, read them from a file. Defaults to None. 32 | loss_name (str): Name of the loss item. If you want this loss 33 | item to be included into the backward graph, `loss_` must be the 34 | prefix of the name. Defaults to 'loss_boundary'. 35 | """ 36 | 37 | def __init__(self, 38 | ignore_label: int = 255, 39 | thres: float = 0.7, 40 | min_kept: int = 100000, 41 | loss_weight: float = 1.0, 42 | class_weight: Optional[Union[List[float], str]] = None, 43 | loss_name: str = 'loss_ohem'): 44 | super().__init__() 45 | self.thresh = thres 46 | self.min_kept = max(1, min_kept) 47 | self.ignore_label = ignore_label 48 | self.loss_weight = loss_weight 49 | self.loss_name_ = loss_name 50 | self.class_weight = class_weight 51 | 52 | def forward(self, score: Tensor, target: Tensor) -> Tensor: 53 | """Forward function. 54 | Args: 55 | score (Tensor): Predictions of the segmentation head. 56 | target (Tensor): Ground truth of the image. 57 | 58 | Returns: 59 | Tensor: Loss tensor. 
60 | """ 61 | # score: (N, C, H, W) 62 | pred = F.softmax(score, dim=1) 63 | if self.class_weight is not None: 64 | class_weight = score.new_tensor(self.class_weight) 65 | else: 66 | class_weight = None 67 | 68 | pixel_losses = F.cross_entropy( 69 | score, 70 | target, 71 | weight=class_weight, 72 | ignore_index=self.ignore_label, 73 | reduction='none').contiguous().view(-1) # (N*H*W) 74 | mask = target.contiguous().view(-1) != self.ignore_label # (N*H*W) 75 | 76 | tmp_target = target.clone() # (N, H, W) 77 | tmp_target[tmp_target == self.ignore_label] = 0 78 | # pred: (N, C, H, W) -> (N*H*W, C) 79 | pred = pred.gather(1, tmp_target.unsqueeze(1)) 80 | # pred: (N*H*W, C) -> (N*H*W), ind: (N*H*W) 81 | pred, ind = pred.contiguous().view(-1, )[mask].contiguous().sort() 82 | if pred.numel() > 0: 83 | min_value = pred[min(self.min_kept, pred.numel() - 1)] 84 | else: 85 | return score.new_tensor(0.0) 86 | threshold = max(min_value, self.thresh) 87 | 88 | pixel_losses = pixel_losses[mask][ind] 89 | pixel_losses = pixel_losses[pred < threshold] 90 | return self.loss_weight * pixel_losses.mean() 91 | 92 | @property 93 | def loss_name(self): 94 | return self.loss_name_ 95 | -------------------------------------------------------------------------------- /mmseg/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): 7 | """Calculate accuracy according to the prediction and target. 8 | 9 | Args: 10 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 11 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 12 | ignore_index (int | None): The label index to be ignored. Default: None 13 | topk (int | tuple[int], optional): If the predictions in ``topk`` 14 | matches the target, the predictions will be regarded as 15 | correct ones. Defaults to 1. 16 | thresh (float, optional): If not None, predictions with scores under 17 | this threshold are considered incorrect. Default to None. 18 | 19 | Returns: 20 | float | tuple[float]: If the input ``topk`` is a single integer, 21 | the function will return a single float as accuracy. If 22 | ``topk`` is a tuple containing multiple integers, the 23 | function will return a tuple containing accuracies of 24 | each ``topk`` number. 25 | """ 26 | assert isinstance(topk, (int, tuple)) 27 | if isinstance(topk, int): 28 | topk = (topk, ) 29 | return_single = True 30 | else: 31 | return_single = False 32 | 33 | maxk = max(topk) 34 | if pred.size(0) == 0: 35 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 36 | return accu[0] if return_single else accu 37 | assert pred.ndim == target.ndim + 1 38 | assert pred.size(0) == target.size(0) 39 | assert maxk <= pred.size(1), \ 40 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 41 | pred_value, pred_label = pred.topk(maxk, dim=1) 42 | # transpose to shape (maxk, N, ...) 
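# Shape sketch with assumed sizes (N=2, C=5, H=W=4, maxk=3), illustrative
# only: pred (2, 5, 4, 4) -> topk over dim=1 -> pred_label (2, 3, 4, 4);
# the transpose below gives (3, 2, 4, 4) so that correct[:k] can slice the
# top-k dimension first, while target.unsqueeze(0) broadcasts to match.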
43 | pred_label = pred_label.transpose(0, 1) 44 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 45 | if thresh is not None: 46 | # Only prediction values larger than thresh are counted as correct 47 | correct = correct & (pred_value > thresh).t() 48 | if ignore_index is not None: 49 | correct = correct[:, target != ignore_index] 50 | res = [] 51 | eps = torch.finfo(torch.float32).eps 52 | for k in topk: 53 | # Avoid causing ZeroDivisionError when all pixels 54 | # of an image are ignored 55 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps 56 | if ignore_index is not None: 57 | total_num = target[target != ignore_index].numel() + eps 58 | else: 59 | total_num = target.numel() + eps 60 | res.append(correct_k.mul_(100.0 / total_num)) 61 | return res[0] if return_single else res 62 | 63 | 64 | class Accuracy(nn.Module): 65 | """Accuracy calculation module.""" 66 | 67 | def __init__(self, topk=(1, ), thresh=None, ignore_index=None): 68 | """Module to calculate the accuracy. 69 | 70 | Args: 71 | topk (tuple, optional): The criterion used to calculate the 72 | accuracy. Defaults to (1,). 73 | thresh (float, optional): If not None, predictions with scores 74 | under this threshold are considered incorrect. Default to None. 75 | """ 76 | super().__init__() 77 | self.topk = topk 78 | self.thresh = thresh 79 | self.ignore_index = ignore_index 80 | 81 | def forward(self, pred, target): 82 | """Forward function to calculate accuracy. 83 | 84 | Args: 85 | pred (torch.Tensor): Prediction of models. 86 | target (torch.Tensor): Target for each prediction. 87 | 88 | Returns: 89 | tuple[float]: The accuracies under different topk criterions. 90 | """ 91 | return accuracy(pred, target, self.topk, self.thresh, 92 | self.ignore_index) 93 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/sep_aspp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 5 | 6 | from mmseg.registry import MODELS 7 | from ..utils import resize 8 | from .aspp_head import ASPPHead, ASPPModule 9 | 10 | 11 | class DepthwiseSeparableASPPModule(ASPPModule): 12 | """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable 13 | conv.""" 14 | 15 | def __init__(self, **kwargs): 16 | super().__init__(**kwargs) 17 | for i, dilation in enumerate(self.dilations): 18 | if dilation > 1: 19 | self[i] = DepthwiseSeparableConvModule( 20 | self.in_channels, 21 | self.channels, 22 | 3, 23 | dilation=dilation, 24 | padding=dilation, 25 | norm_cfg=self.norm_cfg, 26 | act_cfg=self.act_cfg) 27 | 28 | 29 | @MODELS.register_module() 30 | class DepthwiseSeparableASPPHead(ASPPHead): 31 | """Encoder-Decoder with Atrous Separable Convolution for Semantic Image 32 | Segmentation. 33 | 34 | This head is the implementation of `DeepLabV3+ 35 | `_. 36 | 37 | Args: 38 | c1_in_channels (int): The input channels of c1 decoder. If is 0, 39 | the no decoder will be used. 40 | c1_channels (int): The intermediate channels of c1 decoder. 
41 | """ 42 | 43 | def __init__(self, c1_in_channels, c1_channels, **kwargs): 44 | super().__init__(**kwargs) 45 | assert c1_in_channels >= 0 46 | self.aspp_modules = DepthwiseSeparableASPPModule( 47 | dilations=self.dilations, 48 | in_channels=self.in_channels, 49 | channels=self.channels, 50 | conv_cfg=self.conv_cfg, 51 | norm_cfg=self.norm_cfg, 52 | act_cfg=self.act_cfg) 53 | if c1_in_channels > 0: 54 | self.c1_bottleneck = ConvModule( 55 | c1_in_channels, 56 | c1_channels, 57 | 1, 58 | conv_cfg=self.conv_cfg, 59 | norm_cfg=self.norm_cfg, 60 | act_cfg=self.act_cfg) 61 | else: 62 | self.c1_bottleneck = None 63 | self.sep_bottleneck = nn.Sequential( 64 | DepthwiseSeparableConvModule( 65 | self.channels + c1_channels, 66 | self.channels, 67 | 3, 68 | padding=1, 69 | norm_cfg=self.norm_cfg, 70 | act_cfg=self.act_cfg), 71 | DepthwiseSeparableConvModule( 72 | self.channels, 73 | self.channels, 74 | 3, 75 | padding=1, 76 | norm_cfg=self.norm_cfg, 77 | act_cfg=self.act_cfg)) 78 | 79 | def forward(self, inputs): 80 | """Forward function.""" 81 | x = self._transform_inputs(inputs) 82 | aspp_outs = [ 83 | resize( 84 | self.image_pool(x), 85 | size=x.size()[2:], 86 | mode='bilinear', 87 | align_corners=self.align_corners) 88 | ] 89 | aspp_outs.extend(self.aspp_modules(x)) 90 | aspp_outs = torch.cat(aspp_outs, dim=1) 91 | output = self.bottleneck(aspp_outs) 92 | if self.c1_bottleneck is not None: 93 | c1_output = self.c1_bottleneck(inputs[0]) 94 | output = resize( 95 | input=output, 96 | size=c1_output.shape[2:], 97 | mode='bilinear', 98 | align_corners=self.align_corners) 99 | output = torch.cat([output, c1_output], dim=1) 100 | output = self.sep_bottleneck(output) 101 | output = self.cls_seg(output) 102 | return output 103 | -------------------------------------------------------------------------------- /mmseg/models/losses/kldiv_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from mmseg.registry import MODELS 7 | 8 | 9 | @MODELS.register_module() 10 | class KLDivLoss(nn.Module): 11 | 12 | def __init__(self, 13 | temperature: float = 1.0, 14 | reduction: str = 'mean', 15 | loss_name: str = 'loss_kld'): 16 | """Kullback-Leibler divergence Loss. 17 | 18 | 19 | 20 | Args: 21 | temperature (float, optional): Temperature param 22 | reduction (str, optional): The method to reduce the loss into a 23 | scalar. Default is "mean". Options are "none", "sum", 24 | and "mean" 25 | """ 26 | 27 | assert isinstance(temperature, (float, int)), \ 28 | 'Expected temperature to be' \ 29 | f'float or int, but got {temperature.__class__.__name__} instead' 30 | assert temperature != 0., 'Temperature must not be zero' 31 | 32 | assert reduction in ['mean', 'none', 'sum'], \ 33 | 'Reduction must be one of the options ("mean", ' \ 34 | f'"sum", "none"), but got {reduction}' 35 | 36 | super().__init__() 37 | self.temperature = temperature 38 | self.reduction = reduction 39 | self._loss_name = loss_name 40 | 41 | def forward(self, input: torch.Tensor, target: torch.Tensor): 42 | """Forward function. Calculate KL divergence Loss. 43 | 44 | Args: 45 | input (Tensor): Logit tensor, 46 | the data type is float32 or float64. 47 | The shape is (N, C) where N is batchsize and C is number of 48 | channels. 49 | If there more than 2 dimensions, shape is (N, C, D1, D2, ... 
50 | Dk), k>= 1 51 | target (Tensor): Logit tensor, 52 | the data type is float32 or float64. 53 | input and target must be with the same shape. 54 | 55 | Returns: 56 | (Tensor): Reduced loss. 57 | """ 58 | assert isinstance(input, torch.Tensor), 'Expected input to' \ 59 | f'be Tensor, but got {input.__class__.__name__} instead' 60 | assert isinstance(target, torch.Tensor), 'Expected target to' \ 61 | f'be Tensor, but got {target.__class__.__name__} instead' 62 | 63 | assert input.shape == target.shape, 'Input and target ' \ 64 | 'must have same shape,' \ 65 | f'but got shapes {input.shape} and {target.shape}' 66 | 67 | input = F.softmax(input / self.temperature, dim=1) 68 | target = F.softmax(target / self.temperature, dim=1) 69 | 70 | loss = F.kl_div(input, target, reduction='none', log_target=False) 71 | loss = loss * self.temperature**2 72 | 73 | batch_size = input.shape[0] 74 | 75 | if self.reduction == 'sum': 76 | # Change view to calculate instance-wise sum 77 | loss = loss.view(batch_size, -1) 78 | return torch.sum(loss, dim=1) 79 | 80 | elif self.reduction == 'mean': 81 | # Change view to calculate instance-wise mean 82 | loss = loss.view(batch_size, -1) 83 | return torch.mean(loss, dim=1) 84 | 85 | return loss 86 | 87 | @property 88 | def loss_name(self): 89 | """Loss Name. 90 | 91 | This function must be implemented and will return the name of this 92 | loss function. This name will be used to combine different loss items 93 | by simple sum operation. In addition, if you want this loss item to be 94 | included into the backward graph, `loss_` must be the prefix of the 95 | name. 96 | Returns: 97 | str: The name of this loss item. 98 | """ 99 | return self._loss_name 100 | -------------------------------------------------------------------------------- /mmseg/utils/get_templates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
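A brief usage sketch for the helper defined at the bottom of this file; the class name 'corn field' is an arbitrary example:

    templates = get_predefined_templates('vild')
    prompts = [t.format('corn field') for t in templates]
    # prompts[0] == 'a photo of a corn field.'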
2 | from typing import List 3 | 4 | PREDEFINED_TEMPLATES = { 5 | 'imagenet': [ 6 | 'a bad photo of a {}.', 7 | 'a photo of many {}.', 8 | 'a sculpture of a {}.', 9 | 'a photo of the hard to see {}.', 10 | 'a low resolution photo of the {}.', 11 | 'a rendering of a {}.', 12 | 'graffiti of a {}.', 13 | 'a bad photo of the {}.', 14 | 'a cropped photo of the {}.', 15 | 'a tattoo of a {}.', 16 | 'the embroidered {}.', 17 | 'a photo of a hard to see {}.', 18 | 'a bright photo of a {}.', 19 | 'a photo of a clean {}.', 20 | 'a photo of a dirty {}.', 21 | 'a dark photo of the {}.', 22 | 'a drawing of a {}.', 23 | 'a photo of my {}.', 24 | 'the plastic {}.', 25 | 'a photo of the cool {}.', 26 | 'a close-up photo of a {}.', 27 | 'a black and white photo of the {}.', 28 | 'a painting of the {}.', 29 | 'a painting of a {}.', 30 | 'a pixelated photo of the {}.', 31 | 'a sculpture of the {}.', 32 | 'a bright photo of the {}.', 33 | 'a cropped photo of a {}.', 34 | 'a plastic {}.', 35 | 'a photo of the dirty {}.', 36 | 'a jpeg corrupted photo of a {}.', 37 | 'a blurry photo of the {}.', 38 | 'a photo of the {}.', 39 | 'a good photo of the {}.', 40 | 'a rendering of the {}.', 41 | 'a {} in a video game.', 42 | 'a photo of one {}.', 43 | 'a doodle of a {}.', 44 | 'a close-up photo of the {}.', 45 | 'a photo of a {}.', 46 | 'the origami {}.', 47 | 'the {} in a video game.', 48 | 'a sketch of a {}.', 49 | 'a doodle of the {}.', 50 | 'a origami {}.', 51 | 'a low resolution photo of a {}.', 52 | 'the toy {}.', 53 | 'a rendition of the {}.', 54 | 'a photo of the clean {}.', 55 | 'a photo of a large {}.', 56 | 'a rendition of a {}.', 57 | 'a photo of a nice {}.', 58 | 'a photo of a weird {}.', 59 | 'a blurry photo of a {}.', 60 | 'a cartoon {}.', 61 | 'art of a {}.', 62 | 'a sketch of the {}.', 63 | 'a embroidered {}.', 64 | 'a pixelated photo of a {}.', 65 | 'itap of the {}.', 66 | 'a jpeg corrupted photo of the {}.', 67 | 'a good photo of a {}.', 68 | 'a plushie {}.', 69 | 'a photo of the nice {}.', 70 | 'a photo of the small {}.', 71 | 'a photo of the weird {}.', 72 | 'the cartoon {}.', 73 | 'art of the {}.', 74 | 'a drawing of the {}.', 75 | 'a photo of the large {}.', 76 | 'a black and white photo of a {}.', 77 | 'the plushie {}.', 78 | 'a dark photo of a {}.', 79 | 'itap of a {}.', 80 | 'graffiti of the {}.', 81 | 'a toy {}.', 82 | 'itap of my {}.', 83 | 'a photo of a cool {}.', 84 | 'a photo of a small {}.', 85 | 'a tattoo of the {}.', 86 | ], 87 | 'vild': [ 88 | 'a photo of a {}.', 89 | 'This is a photo of a {}', 90 | 'There is a {} in the scene', 91 | 'There is the {} in the scene', 92 | 'a photo of a {} in the scene', 93 | 'a photo of a small {}.', 94 | 'a photo of a medium {}.', 95 | 'a photo of a large {}.', 96 | 'This is a photo of a small {}.', 97 | 'This is a photo of a medium {}.', 98 | 'This is a photo of a large {}.', 99 | 'There is a small {} in the scene.', 100 | 'There is a medium {} in the scene.', 101 | 'There is a large {} in the scene.', 102 | ], 103 | } 104 | 105 | 106 | def get_predefined_templates(template_set_name: str) -> List[str]: 107 | if template_set_name not in PREDEFINED_TEMPLATES: 108 | raise ValueError(f'Template set {template_set_name} not found') 109 | return PREDEFINED_TEMPLATES[template_set_name] 110 | -------------------------------------------------------------------------------- /mmseg/models/utils/shape_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
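A minimal round-trip sketch of the two converters defined below:

    import torch

    x = torch.rand(2, 64, 16, 16)   # [N, C, H, W]
    seq = nchw_to_nlc(x)            # [2, 256, 64], i.e. [N, L, C]
    assert torch.equal(nlc_to_nchw(seq, (16, 16)), x)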
2 | def nlc_to_nchw(x, hw_shape): 3 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 4 | 5 | Args: 6 | x (Tensor): The input tensor of shape [N, L, C] before conversion. 7 | hw_shape (Sequence[int]): The height and width of output feature map. 8 | 9 | Returns: 10 | Tensor: The output tensor of shape [N, C, H, W] after conversion. 11 | """ 12 | H, W = hw_shape 13 | assert len(x.shape) == 3 14 | B, L, C = x.shape 15 | assert L == H * W, 'The seq_len doesn\'t match H, W' 16 | return x.transpose(1, 2).reshape(B, C, H, W) 17 | 18 | 19 | def nchw_to_nlc(x): 20 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 21 | 22 | Args: 23 | x (Tensor): The input tensor of shape [N, C, H, W] before conversion. 24 | 25 | Returns: 26 | Tensor: The output tensor of shape [N, L, C] after conversion. 27 | """ 28 | assert len(x.shape) == 4 29 | return x.flatten(2).transpose(1, 2).contiguous() 30 | 31 | 32 | def nchw2nlc2nchw(module, x, contiguous=False, **kwargs): 33 | """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the 34 | reshaped tensor as the input of `module`, and the convert the output of 35 | `module`, whose shape is. 36 | 37 | [N, L, C], to [N, C, H, W]. 38 | 39 | Args: 40 | module (Callable): A callable object the takes a tensor 41 | with shape [N, L, C] as input. 42 | x (Tensor): The input tensor of shape [N, C, H, W]. 43 | contiguous: 44 | contiguous (Bool): Whether to make the tensor contiguous 45 | after each shape transform. 46 | 47 | Returns: 48 | Tensor: The output tensor of shape [N, C, H, W]. 49 | 50 | Example: 51 | >>> import torch 52 | >>> import torch.nn as nn 53 | >>> norm = nn.LayerNorm(4) 54 | >>> feature_map = torch.rand(4, 4, 5, 5) 55 | >>> output = nchw2nlc2nchw(norm, feature_map) 56 | """ 57 | B, C, H, W = x.shape 58 | if not contiguous: 59 | x = x.flatten(2).transpose(1, 2) 60 | x = module(x, **kwargs) 61 | x = x.transpose(1, 2).reshape(B, C, H, W) 62 | else: 63 | x = x.flatten(2).transpose(1, 2).contiguous() 64 | x = module(x, **kwargs) 65 | x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() 66 | return x 67 | 68 | 69 | def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs): 70 | """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the 71 | reshaped tensor as the input of `module`, and convert the output of 72 | `module`, whose shape is. 73 | 74 | [N, C, H, W], to [N, L, C]. 75 | 76 | Args: 77 | module (Callable): A callable object the takes a tensor 78 | with shape [N, C, H, W] as input. 79 | x (Tensor): The input tensor of shape [N, L, C]. 80 | hw_shape: (Sequence[int]): The height and width of the 81 | feature map with shape [N, C, H, W]. 82 | contiguous (Bool): Whether to make the tensor contiguous 83 | after each shape transform. 84 | 85 | Returns: 86 | Tensor: The output tensor of shape [N, L, C]. 
87 | 88 | Example: 89 | >>> import torch 90 | >>> import torch.nn as nn 91 | >>> conv = nn.Conv2d(16, 16, 3, 1, 1) 92 | >>> feature_map = torch.rand(4, 25, 16) 93 | >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5)) 94 | """ 95 | H, W = hw_shape 96 | assert len(x.shape) == 3 97 | B, L, C = x.shape 98 | assert L == H * W, 'The seq_len doesn\'t match H, W' 99 | if not contiguous: 100 | x = x.transpose(1, 2).reshape(B, C, H, W) 101 | x = module(x, **kwargs) 102 | x = x.flatten(2).transpose(1, 2) 103 | else: 104 | x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() 105 | x = module(x, **kwargs) 106 | x = x.flatten(2).transpose(1, 2).contiguous() 107 | return x 108 | -------------------------------------------------------------------------------- /mmseg/models/utils/point_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.ops import point_sample 4 | from torch import Tensor 5 | 6 | 7 | def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor: 8 | """Estimate uncertainty based on pred logits. 9 | 10 | We estimate uncertainty as L1 distance between 0.0 and the logits 11 | prediction in 'mask_preds' for the foreground class in `classes`. 12 | 13 | Args: 14 | mask_preds (Tensor): mask predication logits, shape (num_rois, 15 | num_classes, mask_height, mask_width). 16 | 17 | labels (Tensor): Either predicted or ground truth label for 18 | each predicted mask, of length num_rois. 19 | 20 | Returns: 21 | scores (Tensor): Uncertainty scores with the most uncertain 22 | locations having the highest uncertainty score, 23 | shape (num_rois, 1, mask_height, mask_width) 24 | """ 25 | if mask_preds.shape[1] == 1: 26 | gt_class_logits = mask_preds.clone() 27 | else: 28 | inds = torch.arange(mask_preds.shape[0], device=mask_preds.device) 29 | gt_class_logits = mask_preds[inds, labels].unsqueeze(1) 30 | return -torch.abs(gt_class_logits) 31 | 32 | 33 | def get_uncertain_point_coords_with_randomness( 34 | mask_preds: Tensor, labels: Tensor, num_points: int, 35 | oversample_ratio: float, importance_sample_ratio: float) -> Tensor: 36 | """Get ``num_points`` most uncertain points with random points during 37 | train. 38 | 39 | Sample points in [0, 1] x [0, 1] coordinate space based on their 40 | uncertainty. The uncertainties are calculated for each point using 41 | 'get_uncertainty()' function that takes point's logit prediction as 42 | input. 43 | 44 | Args: 45 | mask_preds (Tensor): A tensor of shape (num_rois, num_classes, 46 | mask_height, mask_width) for class-specific or class-agnostic 47 | prediction. 48 | labels (Tensor): The ground truth class for each instance. 49 | num_points (int): The number of points to sample. 50 | oversample_ratio (float): Oversampling parameter. 51 | importance_sample_ratio (float): Ratio of points that are sampled 52 | via importnace sampling. 53 | 54 | Returns: 55 | point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) 56 | that contains the coordinates sampled points. 57 | """ 58 | assert oversample_ratio >= 1 59 | assert 0 <= importance_sample_ratio <= 1 60 | batch_size = mask_preds.shape[0] 61 | num_sampled = int(num_points * oversample_ratio) 62 | point_coords = torch.rand( 63 | batch_size, num_sampled, 2, device=mask_preds.device) 64 | point_logits = point_sample(mask_preds, point_coords) 65 | # It is crucial to calculate uncertainty based on the sampled 66 | # prediction value for the points. 
Calculating uncertainties of the 67 | # coarse predictions first and sampling them for points leads to 68 | # incorrect results. To illustrate this: assume uncertainty func( 69 | # logits)=-abs(logits), a sampled point between two coarse 70 | # predictions with -1 and 1 logits has 0 logits, and therefore 0 71 | # uncertainty value. However, if we calculate uncertainties for the 72 | # coarse predictions first, both will have -1 uncertainty, 73 | # and sampled point will get -1 uncertainty. 74 | point_uncertainties = get_uncertainty(point_logits, labels) 75 | num_uncertain_points = int(importance_sample_ratio * num_points) 76 | num_random_points = num_points - num_uncertain_points 77 | idx = torch.topk( 78 | point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 79 | shift = num_sampled * torch.arange( 80 | batch_size, dtype=torch.long, device=mask_preds.device) 81 | idx += shift[:, None] 82 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( 83 | batch_size, num_uncertain_points, 2) 84 | if num_random_points > 0: 85 | rand_roi_coords = torch.rand( 86 | batch_size, num_random_points, 2, device=mask_preds.device) 87 | point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) 88 | return point_coords 89 | -------------------------------------------------------------------------------- /mmseg/engine/hooks/visualization_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | import warnings 4 | from typing import Optional, Sequence 5 | 6 | import mmcv 7 | import mmengine.fileio as fileio 8 | from mmengine.hooks import Hook 9 | from mmengine.runner import Runner 10 | 11 | from mmseg.registry import HOOKS 12 | from mmseg.structures import SegDataSample 13 | from mmseg.visualization import SegLocalVisualizer 14 | 15 | 16 | @HOOKS.register_module() 17 | class SegVisualizationHook(Hook): 18 | """Segmentation Visualization Hook. Used to visualize validation and 19 | testing process prediction results. 20 | 21 | In the testing phase: 22 | 23 | 1. If ``show`` is True, it means that only the prediction results are 24 | visualized without storing data, so ``vis_backends`` needs to 25 | be excluded. 26 | 27 | Args: 28 | draw (bool): whether to draw prediction results. If it is False, 29 | it means that no drawing will be done. Defaults to False. 30 | interval (int): The interval of hooks. Defaults to 50. 31 | show (bool): Whether to display the drawn image. Default to False. 32 | wait_time (float): The interval of show (s). Defaults to 0. 33 | backend_args (dict, Optional): Arguments to instantiate a file backend. 34 | See https://mmengine.readthedocs.io/en/latest/api/fileio.htm 35 | for details. Defaults to None. 36 | Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. 37 | """ 38 | 39 | def __init__(self, 40 | draw: bool = False, 41 | interval: int = 50, 42 | show: bool = False, 43 | wait_time: float = 0., 44 | backend_args: Optional[dict] = None): 45 | self._visualizer: SegLocalVisualizer = \ 46 | SegLocalVisualizer.get_current_instance() 47 | self.interval = interval 48 | self.show = show 49 | if self.show: 50 | # No need to think about vis backends. 
51 | self._visualizer._vis_backends = {} 52 | warnings.warn('The show is True, it means that only ' 53 | 'the prediction results are visualized ' 54 | 'without storing data, so vis_backends ' 55 | 'needs to be excluded.') 56 | 57 | self.wait_time = wait_time 58 | self.backend_args = backend_args.copy() if backend_args else None 59 | self.draw = draw 60 | if not self.draw: 61 | warnings.warn('The draw is False, it means that the ' 62 | 'hook for hooks will not take ' 63 | 'effect. The results will NOT be ' 64 | 'visualized or stored.') 65 | 66 | def _after_iter(self, 67 | runner: Runner, 68 | batch_idx: int, 69 | data_batch: dict, 70 | outputs: Sequence[SegDataSample], 71 | mode: str = 'val') -> None: 72 | """Run after every ``self.interval`` validation iterations. 73 | 74 | Args: 75 | runner (:obj:`Runner`): The runner of the validation process. 76 | batch_idx (int): The index of the current batch in the val loop. 77 | data_batch (dict): Data from dataloader. 78 | outputs (Sequence[:obj:`SegDataSample`]): Outputs from model. 79 | mode (str): mode (str): Current mode of runner. Defaults to 'val'. 80 | """ 81 | if self.draw is False or mode == 'train': 82 | return 83 | 84 | if self.every_n_inner_iters(batch_idx, self.interval): 85 | for output in outputs: 86 | img_path = output.img_path 87 | img_bytes = fileio.get( 88 | img_path, backend_args=self.backend_args) 89 | img = mmcv.imfrombytes(img_bytes, channel_order='rgb') 90 | window_name = f'{mode}_{osp.basename(img_path)}' 91 | 92 | self._visualizer.add_datasample( 93 | window_name, 94 | img, 95 | data_sample=output, 96 | show=self.show, 97 | wait_time=self.wait_time, 98 | step=runner.iter) 99 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/stdc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | from mmengine.structures import PixelData 5 | from torch import Tensor 6 | 7 | from mmseg.registry import MODELS 8 | from mmseg.structures import SegDataSample 9 | from mmseg.utils import SampleList 10 | from .fcn_head import FCNHead 11 | 12 | 13 | @MODELS.register_module() 14 | class STDCHead(FCNHead): 15 | """This head is the implementation of `Rethinking BiSeNet For Real-time 16 | Semantic Segmentation `_. 17 | 18 | Args: 19 | boundary_threshold (float): The threshold of calculating boundary. 20 | Default: 0.1. 21 | """ 22 | 23 | def __init__(self, boundary_threshold=0.1, **kwargs): 24 | super().__init__(**kwargs) 25 | self.boundary_threshold = boundary_threshold 26 | # Using register buffer to make laplacian kernel on the same 27 | # device of `seg_label`. 28 | self.register_buffer( 29 | 'laplacian_kernel', 30 | torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], 31 | dtype=torch.float32, 32 | requires_grad=False).reshape((1, 1, 3, 3))) 33 | self.fusion_kernel = torch.nn.Parameter( 34 | torch.tensor([[6. / 10], [3. / 10], [1. / 10]], 35 | dtype=torch.float32).reshape(1, 3, 1, 1), 36 | requires_grad=False) 37 | 38 | def loss_by_feat(self, seg_logits: Tensor, 39 | batch_data_samples: SampleList) -> dict: 40 | """Compute Detail Aggregation Loss.""" 41 | # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv 42 | # parameters. However, it is a constant in original repo and other 43 | # codebase because it would not be added into computation graph 44 | # after threshold operation. 
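# Orientation comments for the steps below (shapes inferred from the code):
# seg_label (N, 1, H, W) is convolved with the 3x3 Laplacian kernel to get
# soft boundary maps at strides 1, 2 and 4; each map is binarized at
# boundary_threshold, upsampled back to full resolution, stacked into
# (N, 3, H, W) and blended with the fixed 0.6/0.3/0.1 fusion_kernel before
# the final binarization into the detail-aggregation target.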
45 | seg_label = self._stack_batch_gt(batch_data_samples).to( 46 | self.laplacian_kernel) 47 | boundary_targets = F.conv2d( 48 | seg_label, self.laplacian_kernel, padding=1) 49 | boundary_targets = boundary_targets.clamp(min=0) 50 | boundary_targets[boundary_targets > self.boundary_threshold] = 1 51 | boundary_targets[boundary_targets <= self.boundary_threshold] = 0 52 | 53 | boundary_targets_x2 = F.conv2d( 54 | seg_label, self.laplacian_kernel, stride=2, padding=1) 55 | boundary_targets_x2 = boundary_targets_x2.clamp(min=0) 56 | 57 | boundary_targets_x4 = F.conv2d( 58 | seg_label, self.laplacian_kernel, stride=4, padding=1) 59 | boundary_targets_x4 = boundary_targets_x4.clamp(min=0) 60 | 61 | boundary_targets_x4_up = F.interpolate( 62 | boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') 63 | boundary_targets_x2_up = F.interpolate( 64 | boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') 65 | 66 | boundary_targets_x2_up[ 67 | boundary_targets_x2_up > self.boundary_threshold] = 1 68 | boundary_targets_x2_up[ 69 | boundary_targets_x2_up <= self.boundary_threshold] = 0 70 | 71 | boundary_targets_x4_up[ 72 | boundary_targets_x4_up > self.boundary_threshold] = 1 73 | boundary_targets_x4_up[ 74 | boundary_targets_x4_up <= self.boundary_threshold] = 0 75 | 76 | boundary_targets_pyramids = torch.stack( 77 | (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), 78 | dim=1) 79 | 80 | boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2) 81 | boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids, 82 | self.fusion_kernel) 83 | 84 | boudary_targets_pyramid[ 85 | boudary_targets_pyramid > self.boundary_threshold] = 1 86 | boudary_targets_pyramid[ 87 | boudary_targets_pyramid <= self.boundary_threshold] = 0 88 | 89 | seg_labels = boudary_targets_pyramid.long() 90 | batch_sample_list = [] 91 | for label in seg_labels: 92 | seg_data_sample = SegDataSample() 93 | seg_data_sample.gt_sem_seg = PixelData(data=label) 94 | batch_sample_list.append(seg_data_sample) 95 | 96 | loss = super().loss_by_feat(seg_logits, batch_sample_list) 97 | return loss 98 | -------------------------------------------------------------------------------- /mmseg/models/necks/mla_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
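Before the implementation, a small shape sketch of what this neck computes; the sizes are illustrative assumptions, not repo defaults:

    # neck = MLANeck(in_channels=[1024, 1024, 1024, 1024], out_channels=256)
    # feats = [torch.rand(2, 1024, 32, 32) for _ in range(4)]
    # outs = neck(feats)  # tuple of 4 tensors, each (2, 256, 32, 32)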
2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, build_norm_layer 4 | 5 | from mmseg.registry import MODELS 6 | 7 | 8 | class MLAModule(nn.Module): 9 | 10 | def __init__(self, 11 | in_channels=[1024, 1024, 1024, 1024], 12 | out_channels=256, 13 | norm_cfg=None, 14 | act_cfg=None): 15 | super().__init__() 16 | self.channel_proj = nn.ModuleList() 17 | for i in range(len(in_channels)): 18 | self.channel_proj.append( 19 | ConvModule( 20 | in_channels=in_channels[i], 21 | out_channels=out_channels, 22 | kernel_size=1, 23 | norm_cfg=norm_cfg, 24 | act_cfg=act_cfg)) 25 | self.feat_extract = nn.ModuleList() 26 | for i in range(len(in_channels)): 27 | self.feat_extract.append( 28 | ConvModule( 29 | in_channels=out_channels, 30 | out_channels=out_channels, 31 | kernel_size=3, 32 | padding=1, 33 | norm_cfg=norm_cfg, 34 | act_cfg=act_cfg)) 35 | 36 | def forward(self, inputs): 37 | 38 | # feat_list -> [p2, p3, p4, p5] 39 | feat_list = [] 40 | for x, conv in zip(inputs, self.channel_proj): 41 | feat_list.append(conv(x)) 42 | 43 | # feat_list -> [p5, p4, p3, p2] 44 | # mid_list -> [m5, m4, m3, m2] 45 | feat_list = feat_list[::-1] 46 | mid_list = [] 47 | for feat in feat_list: 48 | if len(mid_list) == 0: 49 | mid_list.append(feat) 50 | else: 51 | mid_list.append(mid_list[-1] + feat) 52 | 53 | # mid_list -> [m5, m4, m3, m2] 54 | # out_list -> [o2, o3, o4, o5] 55 | out_list = [] 56 | for mid, conv in zip(mid_list, self.feat_extract): 57 | out_list.append(conv(mid)) 58 | 59 | return tuple(out_list) 60 | 61 | 62 | @MODELS.register_module() 63 | class MLANeck(nn.Module): 64 | """Multi-level Feature Aggregation. 65 | 66 | This neck is `The Multi-level Feature Aggregation construction of 67 | SETR `_. 68 | 69 | 70 | Args: 71 | in_channels (List[int]): Number of input channels per scale. 72 | out_channels (int): Number of output channels (used at each scale). 73 | norm_layer (dict): Config dict for input normalization. 74 | Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). 75 | norm_cfg (dict): Config dict for normalization layer. Default: None. 76 | act_cfg (dict): Config dict for activation layer in ConvModule. 77 | Default: None. 78 | """ 79 | 80 | def __init__(self, 81 | in_channels, 82 | out_channels, 83 | norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), 84 | norm_cfg=None, 85 | act_cfg=None): 86 | super().__init__() 87 | assert isinstance(in_channels, list) 88 | self.in_channels = in_channels 89 | self.out_channels = out_channels 90 | 91 | # In order to build general vision transformer backbone, we have to 92 | # move MLA to neck. 
93 |         self.norm = nn.ModuleList([
94 |             build_norm_layer(norm_layer, in_channels[i])[1]
95 |             for i in range(len(in_channels))
96 |         ])
97 | 
98 |         self.mla = MLAModule(
99 |             in_channels=in_channels,
100 |             out_channels=out_channels,
101 |             norm_cfg=norm_cfg,
102 |             act_cfg=act_cfg)
103 | 
104 |     def forward(self, inputs):
105 |         assert len(inputs) == len(self.in_channels)
106 | 
107 |         # Convert from nchw to nlc
108 |         outs = []
109 |         for i in range(len(inputs)):
110 |             x = inputs[i]
111 |             n, c, h, w = x.shape
112 |             x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
113 |             x = self.norm[i](x)
114 |             x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
115 |             outs.append(x)
116 | 
117 |         outs = self.mla(outs)
118 |         return tuple(outs)
119 | 
--------------------------------------------------------------------------------
/mmseg/models/decode_heads/psp_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | from mmcv.cnn import ConvModule
5 | 
6 | from mmseg.registry import MODELS
7 | from ..utils import resize
8 | from .decode_head import BaseDecodeHead
9 | 
10 | 
11 | class PPM(nn.ModuleList):
12 |     """Pooling Pyramid Module used in PSPNet.
13 | 
14 |     Args:
15 |         pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
16 |             Module.
17 |         in_channels (int): Input channels.
18 |         channels (int): Channels after modules, before conv_seg.
19 |         conv_cfg (dict|None): Config of conv layers.
20 |         norm_cfg (dict|None): Config of norm layers.
21 |         act_cfg (dict): Config of activation layers.
22 |         align_corners (bool): align_corners argument of F.interpolate.
23 |     """
24 | 
25 |     def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
26 |                  act_cfg, align_corners, **kwargs):
27 |         super().__init__()
28 |         self.pool_scales = pool_scales
29 |         self.align_corners = align_corners
30 |         self.in_channels = in_channels
31 |         self.channels = channels
32 |         self.conv_cfg = conv_cfg
33 |         self.norm_cfg = norm_cfg
34 |         self.act_cfg = act_cfg
35 |         for pool_scale in pool_scales:
36 |             self.append(
37 |                 nn.Sequential(
38 |                     nn.AdaptiveAvgPool2d(pool_scale),
39 |                     ConvModule(
40 |                         self.in_channels,
41 |                         self.channels,
42 |                         1,
43 |                         conv_cfg=self.conv_cfg,
44 |                         norm_cfg=self.norm_cfg,
45 |                         act_cfg=self.act_cfg,
46 |                         **kwargs)))
47 | 
48 |     def forward(self, x):
49 |         """Forward function."""
50 |         ppm_outs = []
51 |         for ppm in self:
52 |             ppm_out = ppm(x)
53 |             upsampled_ppm_out = resize(
54 |                 ppm_out,
55 |                 size=x.size()[2:],
56 |                 mode='bilinear',
57 |                 align_corners=self.align_corners)
58 |             ppm_outs.append(upsampled_ppm_out)
59 |         return ppm_outs
60 | 
61 | 
62 | @MODELS.register_module()
63 | class PSPHead(BaseDecodeHead):
64 |     """Pyramid Scene Parsing Network.
65 | 
66 |     This head is the implementation of
67 |     `PSPNet <https://arxiv.org/abs/1612.01105>`_.
68 | 
69 |     Args:
70 |         pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
71 |             Module. Default: (1, 2, 3, 6).
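    Example (editor's sketch with hypothetical sizes, assuming the standard
        BaseDecodeHead defaults):
        >>> import torch
        >>> head = PSPHead(in_channels=64, channels=16, num_classes=19)
        >>> feats = [torch.rand(1, 64, 32, 32)]
        >>> head(feats).shape
        torch.Size([1, 19, 32, 32])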
72 | """ 73 | 74 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 75 | super().__init__(**kwargs) 76 | assert isinstance(pool_scales, (list, tuple)) 77 | self.pool_scales = pool_scales 78 | self.psp_modules = PPM( 79 | self.pool_scales, 80 | self.in_channels, 81 | self.channels, 82 | conv_cfg=self.conv_cfg, 83 | norm_cfg=self.norm_cfg, 84 | act_cfg=self.act_cfg, 85 | align_corners=self.align_corners) 86 | self.bottleneck = ConvModule( 87 | self.in_channels + len(pool_scales) * self.channels, 88 | self.channels, 89 | 3, 90 | padding=1, 91 | conv_cfg=self.conv_cfg, 92 | norm_cfg=self.norm_cfg, 93 | act_cfg=self.act_cfg) 94 | 95 | def _forward_feature(self, inputs): 96 | """Forward function for feature maps before classifying each pixel with 97 | ``self.cls_seg`` fc. 98 | 99 | Args: 100 | inputs (list[Tensor]): List of multi-level img features. 101 | 102 | Returns: 103 | feats (Tensor): A tensor of shape (batch_size, self.channels, 104 | H, W) which is feature map for last layer of decoder head. 105 | """ 106 | x = self._transform_inputs(inputs) 107 | psp_outs = [x] 108 | psp_outs.extend(self.psp_modules(x)) 109 | psp_outs = torch.cat(psp_outs, dim=1) 110 | feats = self.bottleneck(psp_outs) 111 | return feats 112 | 113 | def forward(self, inputs): 114 | """Forward function.""" 115 | output = self._forward_feature(inputs) 116 | output = self.cls_seg(output) 117 | return output 118 | -------------------------------------------------------------------------------- /mmseg/models/utils/up_conv_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule, build_upsample_layer 5 | 6 | 7 | class UpConvBlock(nn.Module): 8 | """Upsample convolution block in decoder for UNet. 9 | 10 | This upsample convolution block consists of one upsample module 11 | followed by one convolution block. The upsample module expands the 12 | high-level low-resolution feature map and the convolution block fuses 13 | the upsampled high-level low-resolution feature map and the low-level 14 | high-resolution feature map from encoder. 15 | 16 | Args: 17 | conv_block (nn.Sequential): Sequential of convolutional layers. 18 | in_channels (int): Number of input channels of the high-level 19 | skip_channels (int): Number of input channels of the low-level 20 | high-resolution feature map from encoder. 21 | out_channels (int): Number of output channels. 22 | num_convs (int): Number of convolutional layers in the conv_block. 23 | Default: 2. 24 | stride (int): Stride of convolutional layer in conv_block. Default: 1. 25 | dilation (int): Dilation rate of convolutional layer in conv_block. 26 | Default: 1. 27 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 28 | memory while slowing down the training speed. Default: False. 29 | conv_cfg (dict | None): Config dict for convolution layer. 30 | Default: None. 31 | norm_cfg (dict | None): Config dict for normalization layer. 32 | Default: dict(type='BN'). 33 | act_cfg (dict | None): Config dict for activation layer in ConvModule. 34 | Default: dict(type='ReLU'). 35 | upsample_cfg (dict): The upsample config of the upsample module in 36 | decoder. Default: dict(type='InterpConv'). If the size of 37 | high-level feature map is the same as that of skip feature map 38 | (low-level feature map from encoder), it does not need upsample the 39 | high-level feature map and the upsample_cfg is None. 
40 | dcn (bool): Use deformable convolution in convolutional layer or not. 41 | Default: None. 42 | plugins (dict): plugins for convolutional layers. Default: None. 43 | """ 44 | 45 | def __init__(self, 46 | conv_block, 47 | in_channels, 48 | skip_channels, 49 | out_channels, 50 | num_convs=2, 51 | stride=1, 52 | dilation=1, 53 | with_cp=False, 54 | conv_cfg=None, 55 | norm_cfg=dict(type='BN'), 56 | act_cfg=dict(type='ReLU'), 57 | upsample_cfg=dict(type='InterpConv'), 58 | dcn=None, 59 | plugins=None): 60 | super().__init__() 61 | assert dcn is None, 'Not implemented yet.' 62 | assert plugins is None, 'Not implemented yet.' 63 | 64 | self.conv_block = conv_block( 65 | in_channels=2 * skip_channels, 66 | out_channels=out_channels, 67 | num_convs=num_convs, 68 | stride=stride, 69 | dilation=dilation, 70 | with_cp=with_cp, 71 | conv_cfg=conv_cfg, 72 | norm_cfg=norm_cfg, 73 | act_cfg=act_cfg, 74 | dcn=None, 75 | plugins=None) 76 | if upsample_cfg is not None: 77 | self.upsample = build_upsample_layer( 78 | cfg=upsample_cfg, 79 | in_channels=in_channels, 80 | out_channels=skip_channels, 81 | with_cp=with_cp, 82 | norm_cfg=norm_cfg, 83 | act_cfg=act_cfg) 84 | else: 85 | self.upsample = ConvModule( 86 | in_channels, 87 | skip_channels, 88 | kernel_size=1, 89 | stride=1, 90 | padding=0, 91 | conv_cfg=conv_cfg, 92 | norm_cfg=norm_cfg, 93 | act_cfg=act_cfg) 94 | 95 | def forward(self, skip, x): 96 | """Forward function.""" 97 | 98 | x = self.upsample(x) 99 | out = torch.cat([skip, x], dim=1) 100 | out = self.conv_block(out) 101 | 102 | return out 103 | -------------------------------------------------------------------------------- /mmseg/models/losses/silog_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Optional, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from mmseg.registry import MODELS 9 | from .utils import weight_reduce_loss 10 | 11 | 12 | def silog_loss(pred: Tensor, 13 | target: Tensor, 14 | weight: Optional[Tensor] = None, 15 | eps: float = 1e-4, 16 | reduction: Union[str, None] = 'mean', 17 | avg_factor: Optional[int] = None) -> Tensor: 18 | """Computes the Scale-Invariant Logarithmic (SI-Log) loss between 19 | prediction and target. 20 | 21 | Args: 22 | pred (Tensor): Predicted output. 23 | target (Tensor): Ground truth. 24 | weight (Optional[Tensor]): Optional weight to apply on the loss. 25 | eps (float): Epsilon value to avoid division and log(0). 26 | reduction (Union[str, None]): Specifies the reduction to apply to the 27 | output: 'mean', 'sum' or None. 28 | avg_factor (Optional[int]): Optional average factor for the loss. 29 | 30 | Returns: 31 | Tensor: The calculated SI-Log loss. 
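    Note (editor's addition): with d_i = log(t_i) - log(p_i) computed over
    the valid pixels of each sample, the implementation below evaluates
    sqrt(mean(d_i^2) - 0.5 * mean(d_i)^2) per sample before reduction.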
32 | """ 33 | pred, target = pred.flatten(1), target.flatten(1) 34 | valid_mask = (target > eps).detach().float() 35 | 36 | diff_log = torch.log(target.clamp(min=eps)) - torch.log( 37 | pred.clamp(min=eps)) 38 | 39 | valid_mask = (target > eps).detach() & (~torch.isnan(diff_log)) 40 | diff_log[~valid_mask] = 0.0 41 | valid_mask = valid_mask.float() 42 | 43 | diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum( 44 | dim=1) / valid_mask.sum(dim=1).clamp(min=eps) 45 | diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum( 46 | dim=1).clamp(min=eps) 47 | 48 | loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2)) 49 | 50 | if weight is not None: 51 | weight = weight.float() 52 | 53 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 54 | return loss 55 | 56 | 57 | @MODELS.register_module() 58 | class SiLogLoss(nn.Module): 59 | """Compute SiLog loss. 60 | 61 | Args: 62 | reduction (str, optional): The method used 63 | to reduce the loss. Options are "none", 64 | "mean" and "sum". Defaults to 'mean'. 65 | loss_weight (float, optional): Weight of loss. Defaults to 1.0. 66 | eps (float): Avoid dividing by zero. Defaults to 1e-3. 67 | loss_name (str, optional): Name of the loss item. If you want this 68 | loss item to be included into the backward graph, `loss_` must 69 | be the prefix of the name. Defaults to 'loss_silog'. 70 | """ 71 | 72 | def __init__(self, 73 | reduction='mean', 74 | loss_weight=1.0, 75 | eps=1e-6, 76 | loss_name='loss_silog'): 77 | super().__init__() 78 | self.reduction = reduction 79 | self.loss_weight = loss_weight 80 | self.eps = eps 81 | self._loss_name = loss_name 82 | 83 | def forward( 84 | self, 85 | pred, 86 | target, 87 | weight=None, 88 | avg_factor=None, 89 | reduction_override=None, 90 | ): 91 | 92 | assert pred.shape == target.shape, 'the shapes of pred ' \ 93 | f'({pred.shape}) and target ({target.shape}) are mismatch' 94 | 95 | assert reduction_override in (None, 'none', 'mean', 'sum') 96 | reduction = ( 97 | reduction_override if reduction_override else self.reduction) 98 | 99 | loss = self.loss_weight * silog_loss( 100 | pred, 101 | target, 102 | weight, 103 | eps=self.eps, 104 | reduction=reduction, 105 | avg_factor=avg_factor, 106 | ) 107 | 108 | return loss 109 | 110 | @property 111 | def loss_name(self): 112 | """Loss Name. 113 | 114 | This function must be implemented and will return the name of this 115 | loss function. This name will be used to combine different loss items 116 | by simple sum operation. In addition, if you want this loss item to be 117 | included into the backward graph, `loss_` must be the prefix of the 118 | name. 119 | Returns: 120 | str: The name of this loss item. 121 | """ 122 | return self._loss_name 123 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/aspp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.registry import MODELS 7 | from ..utils import resize 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | class ASPPModule(nn.ModuleList): 12 | """Atrous Spatial Pyramid Pooling (ASPP) Module. 13 | 14 | Args: 15 | dilations (tuple[int]): Dilation rate of each layer. 16 | in_channels (int): Input channels. 17 | channels (int): Channels after modules, before conv_seg. 18 | conv_cfg (dict|None): Config of conv layers. 
19 |         norm_cfg (dict|None): Config of norm layers.
20 |         act_cfg (dict): Config of activation layers.
21 |     """
22 | 
23 |     def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
24 |                  act_cfg):
25 |         super().__init__()
26 |         self.dilations = dilations
27 |         self.in_channels = in_channels
28 |         self.channels = channels
29 |         self.conv_cfg = conv_cfg
30 |         self.norm_cfg = norm_cfg
31 |         self.act_cfg = act_cfg
32 |         for dilation in dilations:
33 |             self.append(
34 |                 ConvModule(
35 |                     self.in_channels,
36 |                     self.channels,
37 |                     1 if dilation == 1 else 3,
38 |                     dilation=dilation,
39 |                     padding=0 if dilation == 1 else dilation,
40 |                     conv_cfg=self.conv_cfg,
41 |                     norm_cfg=self.norm_cfg,
42 |                     act_cfg=self.act_cfg))
43 | 
44 |     def forward(self, x):
45 |         """Forward function."""
46 |         aspp_outs = []
47 |         for aspp_module in self:
48 |             aspp_outs.append(aspp_module(x))
49 | 
50 |         return aspp_outs
51 | 
52 | 
53 | @MODELS.register_module()
54 | class ASPPHead(BaseDecodeHead):
55 |     """Rethinking Atrous Convolution for Semantic Image Segmentation.
56 | 
57 |     This head is the implementation of `DeepLabV3
58 |     <https://arxiv.org/abs/1706.05587>`_.
59 | 
60 |     Args:
61 |         dilations (tuple[int]): Dilation rates for ASPP module.
62 |             Default: (1, 6, 12, 18).
63 |     """
64 | 
65 |     def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
66 |         super().__init__(**kwargs)
67 |         assert isinstance(dilations, (list, tuple))
68 |         self.dilations = dilations
69 |         self.image_pool = nn.Sequential(
70 |             nn.AdaptiveAvgPool2d(1),
71 |             ConvModule(
72 |                 self.in_channels,
73 |                 self.channels,
74 |                 1,
75 |                 conv_cfg=self.conv_cfg,
76 |                 norm_cfg=self.norm_cfg,
77 |                 act_cfg=self.act_cfg))
78 |         self.aspp_modules = ASPPModule(
79 |             dilations,
80 |             self.in_channels,
81 |             self.channels,
82 |             conv_cfg=self.conv_cfg,
83 |             norm_cfg=self.norm_cfg,
84 |             act_cfg=self.act_cfg)
85 |         self.bottleneck = ConvModule(
86 |             (len(dilations) + 1) * self.channels,
87 |             self.channels,
88 |             3,
89 |             padding=1,
90 |             conv_cfg=self.conv_cfg,
91 |             norm_cfg=self.norm_cfg,
92 |             act_cfg=self.act_cfg)
93 | 
94 |     def _forward_feature(self, inputs):
95 |         """Forward function for feature maps before classifying each pixel with
96 |         ``self.cls_seg`` fc.
97 | 
98 |         Args:
99 |             inputs (list[Tensor]): List of multi-level img features.
100 | 
101 |         Returns:
102 |             feats (Tensor): A tensor of shape (batch_size, self.channels,
103 |                 H, W) which is the feature map of the decoder's last layer.
104 |         """
105 |         x = self._transform_inputs(inputs)
106 |         aspp_outs = [
107 |             resize(
108 |                 self.image_pool(x),
109 |                 size=x.size()[2:],
110 |                 mode='bilinear',
111 |                 align_corners=self.align_corners)
112 |         ]
113 |         aspp_outs.extend(self.aspp_modules(x))
114 |         aspp_outs = torch.cat(aspp_outs, dim=1)
115 |         feats = self.bottleneck(aspp_outs)
116 |         return feats
117 | 
118 |     def forward(self, inputs):
119 |         """Forward function."""
120 |         output = self._forward_feature(inputs)
121 |         output = self.cls_seg(output)
122 |         return output
123 | 
--------------------------------------------------------------------------------
/mmseg/models/decode_heads/ddr_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
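# Editor's note: DDRHead.forward below expects a (c3, c5) feature tuple in
# training mode (returning context and spatial logits for the two entries of
# loss_decode) and a single c5 tensor at inference; a usage sketch is
# appended at the end of this file.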
2 | from typing import Tuple, Union 3 | 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer 6 | from torch import Tensor 7 | 8 | from mmseg.models.decode_heads.decode_head import BaseDecodeHead 9 | from mmseg.models.losses import accuracy 10 | from mmseg.models.utils import resize 11 | from mmseg.registry import MODELS 12 | from mmseg.utils import OptConfigType, SampleList 13 | 14 | 15 | @MODELS.register_module() 16 | class DDRHead(BaseDecodeHead): 17 | """Decode head for DDRNet. 18 | 19 | Args: 20 | in_channels (int): Number of input channels. 21 | channels (int): Number of output channels. 22 | num_classes (int): Number of classes. 23 | norm_cfg (dict, optional): Config dict for normalization layer. 24 | Default: dict(type='BN'). 25 | act_cfg (dict, optional): Config dict for activation layer. 26 | Default: dict(type='ReLU', inplace=True). 27 | """ 28 | 29 | def __init__(self, 30 | in_channels: int, 31 | channels: int, 32 | num_classes: int, 33 | norm_cfg: OptConfigType = dict(type='BN'), 34 | act_cfg: OptConfigType = dict(type='ReLU', inplace=True), 35 | **kwargs): 36 | super().__init__( 37 | in_channels, 38 | channels, 39 | num_classes=num_classes, 40 | norm_cfg=norm_cfg, 41 | act_cfg=act_cfg, 42 | **kwargs) 43 | 44 | self.head = self._make_base_head(self.in_channels, self.channels) 45 | self.aux_head = self._make_base_head(self.in_channels // 2, 46 | self.channels) 47 | self.aux_cls_seg = nn.Conv2d( 48 | self.channels, self.out_channels, kernel_size=1) 49 | 50 | def init_weights(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | nn.init.kaiming_normal_( 54 | m.weight, mode='fan_out', nonlinearity='relu') 55 | elif isinstance(m, nn.BatchNorm2d): 56 | nn.init.constant_(m.weight, 1) 57 | nn.init.constant_(m.bias, 0) 58 | 59 | def forward( 60 | self, 61 | inputs: Union[Tensor, 62 | Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]: 63 | if self.training: 64 | c3_feat, c5_feat = inputs 65 | x_c = self.head(c5_feat) 66 | x_c = self.cls_seg(x_c) 67 | x_s = self.aux_head(c3_feat) 68 | x_s = self.aux_cls_seg(x_s) 69 | 70 | return x_c, x_s 71 | else: 72 | x_c = self.head(inputs) 73 | x_c = self.cls_seg(x_c) 74 | return x_c 75 | 76 | def _make_base_head(self, in_channels: int, 77 | channels: int) -> nn.Sequential: 78 | layers = [ 79 | ConvModule( 80 | in_channels, 81 | channels, 82 | kernel_size=3, 83 | padding=1, 84 | norm_cfg=self.norm_cfg, 85 | act_cfg=self.act_cfg, 86 | order=('norm', 'act', 'conv')), 87 | build_norm_layer(self.norm_cfg, channels)[1], 88 | build_activation_layer(self.act_cfg), 89 | ] 90 | 91 | return nn.Sequential(*layers) 92 | 93 | def loss_by_feat(self, seg_logits: Tuple[Tensor], 94 | batch_data_samples: SampleList) -> dict: 95 | loss = dict() 96 | context_logit, spatial_logit = seg_logits 97 | seg_label = self._stack_batch_gt(batch_data_samples) 98 | 99 | context_logit = resize( 100 | context_logit, 101 | size=seg_label.shape[2:], 102 | mode='bilinear', 103 | align_corners=self.align_corners) 104 | spatial_logit = resize( 105 | spatial_logit, 106 | size=seg_label.shape[2:], 107 | mode='bilinear', 108 | align_corners=self.align_corners) 109 | seg_label = seg_label.squeeze(1) 110 | 111 | loss['loss_context'] = self.loss_decode[0](context_logit, seg_label) 112 | loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label) 113 | loss['acc_seg'] = accuracy( 114 | context_logit, seg_label, ignore_index=self.ignore_index) 115 | 116 | return loss 117 | 
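
# --- Editor's usage sketch (appended; not part of the repository file) ---
# A minimal eval-path smoke test with hypothetical channel sizes; assumes
# mmseg and its default CrossEntropyLoss config are importable.
if __name__ == '__main__':
    import torch

    head = DDRHead(in_channels=64, channels=32, num_classes=19)
    head.eval()  # inference path: a single tensor in, context logits out
    with torch.no_grad():
        logits = head(torch.rand(1, 64, 32, 32))
    print(logits.shape)  # expected: torch.Size([1, 19, 32, 32])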
-------------------------------------------------------------------------------- /mmseg/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import functools 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from mmengine.fileio import load 8 | 9 | 10 | def get_class_weight(class_weight): 11 | """Get class weight for loss function. 12 | 13 | Args: 14 | class_weight (list[float] | str | None): If class_weight is a str, 15 | take it as a file name and read from it. 16 | """ 17 | if isinstance(class_weight, str): 18 | # take it as a file path 19 | if class_weight.endswith('.npy'): 20 | class_weight = np.load(class_weight) 21 | else: 22 | # pkl, json or yaml 23 | class_weight = load(class_weight) 24 | 25 | return class_weight 26 | 27 | 28 | def reduce_loss(loss, reduction) -> torch.Tensor: 29 | """Reduce loss as specified. 30 | 31 | Args: 32 | loss (Tensor): Elementwise loss tensor. 33 | reduction (str): Options are "none", "mean" and "sum". 34 | 35 | Return: 36 | Tensor: Reduced loss tensor. 37 | """ 38 | reduction_enum = F._Reduction.get_enum(reduction) 39 | # none: 0, elementwise_mean:1, sum: 2 40 | if reduction_enum == 0: 41 | return loss 42 | elif reduction_enum == 1: 43 | return loss.mean() 44 | elif reduction_enum == 2: 45 | return loss.sum() 46 | 47 | 48 | def weight_reduce_loss(loss, 49 | weight=None, 50 | reduction='mean', 51 | avg_factor=None) -> torch.Tensor: 52 | """Apply element-wise weight and reduce loss. 53 | 54 | Args: 55 | loss (Tensor): Element-wise loss. 56 | weight (Tensor): Element-wise weights. 57 | reduction (str): Same as built-in losses of PyTorch. 58 | avg_factor (float): Average factor when computing the mean of losses. 59 | 60 | Returns: 61 | Tensor: Processed loss values. 62 | """ 63 | # if weight is specified, apply element-wise weight 64 | if weight is not None: 65 | assert weight.dim() == loss.dim() 66 | if weight.dim() > 1: 67 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 68 | loss = loss * weight 69 | 70 | # if avg_factor is not specified, just reduce the loss 71 | if avg_factor is None: 72 | loss = reduce_loss(loss, reduction) 73 | else: 74 | # if reduction is mean, then average the loss by avg_factor 75 | if reduction == 'mean': 76 | # Avoid causing ZeroDivisionError when avg_factor is 0.0, 77 | # i.e., all labels of an image belong to ignore index. 78 | eps = torch.finfo(torch.float32).eps 79 | loss = loss.sum() / (avg_factor + eps) 80 | # if reduction is 'none', then do nothing, otherwise raise an error 81 | elif reduction != 'none': 82 | raise ValueError('avg_factor can not be used with reduction="sum"') 83 | return loss 84 | 85 | 86 | def weighted_loss(loss_func): 87 | """Create a weighted version of a given loss function. 88 | 89 | To use this decorator, the loss function must have the signature like 90 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 91 | element-wise loss without any reduction. This decorator will add weight 92 | and reduction arguments to the function. The decorated function will have 93 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 94 | avg_factor=None, **kwargs)`. 
95 | 
96 |     :Example:
97 | 
98 |     >>> import torch
99 |     >>> @weighted_loss
100 |     >>> def l1_loss(pred, target):
101 |     >>>     return (pred - target).abs()
102 | 
103 |     >>> pred = torch.Tensor([0, 2, 3])
104 |     >>> target = torch.Tensor([1, 1, 1])
105 |     >>> weight = torch.Tensor([1, 0, 1])
106 | 
107 |     >>> l1_loss(pred, target)
108 |     tensor(1.3333)
109 |     >>> l1_loss(pred, target, weight)
110 |     tensor(1.)
111 |     >>> l1_loss(pred, target, reduction='none')
112 |     tensor([1., 1., 2.])
113 |     >>> l1_loss(pred, target, weight, avg_factor=2)
114 |     tensor(1.5000)
115 |     """
116 | 
117 |     @functools.wraps(loss_func)
118 |     def wrapper(pred,
119 |                 target,
120 |                 weight=None,
121 |                 reduction='mean',
122 |                 avg_factor=None,
123 |                 **kwargs):
124 |         # get element-wise loss
125 |         loss = loss_func(pred, target, **kwargs)
126 |         loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
127 |         return loss
128 | 
129 |     return wrapper
--------------------------------------------------------------------------------
/mmseg/datasets/transforms/formatting.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 | 
4 | import numpy as np
5 | from mmcv.transforms import to_tensor
6 | from mmcv.transforms.base import BaseTransform
7 | from mmengine.structures import PixelData
8 | 
9 | from mmseg.registry import TRANSFORMS
10 | from mmseg.structures import SegDataSample
11 | 
12 | 
13 | @TRANSFORMS.register_module()
14 | class PackSegInputs(BaseTransform):
15 |     """Pack the inputs data for the semantic segmentation.
16 | 
17 |     The ``img_meta`` item is always populated. The contents of the
18 |     ``img_meta`` dictionary depend on ``meta_keys``. By default this includes:
19 | 
20 |     - ``img_path``: filename of the image
21 | 
22 |     - ``ori_shape``: original shape of the image as a tuple (h, w, c)
23 | 
24 |     - ``img_shape``: shape of the image input to the network as a tuple \
25 |         (h, w, c). Note that images may be zero padded on the \
26 |         bottom/right if the batch tensor is larger than this shape.
27 | 
28 |     - ``pad_shape``: shape of padded images
29 | 
30 |     - ``scale_factor``: a float indicating the preprocessing scale
31 | 
32 |     - ``flip``: a boolean indicating if image flip transform was used
33 | 
34 |     - ``flip_direction``: the flipping direction
35 | 
36 |     Args:
37 |         meta_keys (Sequence[str], optional): Meta keys to be packed from
38 |             ``SegDataSample`` and collected in ``data[img_metas]``.
39 |             Default: ``('img_path', 'seg_map_path', 'ori_shape',
40 |             'img_shape', 'pad_shape', 'scale_factor', 'flip',
41 |             'flip_direction', 'reduce_zero_label')``
42 |     """
43 | 
44 |     def __init__(self,
45 |                  meta_keys=('img_path', 'seg_map_path', 'ori_shape',
46 |                             'img_shape', 'pad_shape', 'scale_factor', 'flip',
47 |                             'flip_direction', 'reduce_zero_label')):
48 |         self.meta_keys = meta_keys
49 | 
50 |     def transform(self, results: dict) -> dict:
51 |         """Method to pack the input data.
52 | 
53 |         Args:
54 |             results (dict): Result dict from the data pipeline.
55 | 
56 |         Returns:
57 |             dict:
58 | 
59 |             - 'inputs' (obj:`torch.Tensor`): The forward data of models.
60 |             - 'data_sample' (obj:`SegDataSample`): The annotation info of the
61 |                 sample.
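        Example (editor's sketch with a dummy 8x8 RGB image):
            >>> import numpy as np
            >>> pack = PackSegInputs()
            >>> packed = pack(dict(img=np.zeros((8, 8, 3), dtype=np.uint8)))
            >>> sorted(packed)
            ['data_samples', 'inputs']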
62 | """ 63 | packed_results = dict() 64 | if 'img' in results: 65 | img = results['img'] 66 | if len(img.shape) < 3: 67 | img = np.expand_dims(img, -1) 68 | if not img.flags.c_contiguous: 69 | img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) 70 | else: 71 | img = img.transpose(2, 0, 1) 72 | img = to_tensor(img).contiguous() 73 | packed_results['inputs'] = img 74 | 75 | data_sample = SegDataSample() 76 | if 'gt_seg_map' in results: 77 | if len(results['gt_seg_map'].shape) == 2: 78 | data = to_tensor(results['gt_seg_map'][None, 79 | ...].astype(np.int64)) 80 | else: 81 | warnings.warn('Please pay attention your ground truth ' 82 | 'segmentation map, usually the segmentation ' 83 | 'map is 2D, but got ' 84 | f'{results["gt_seg_map"].shape}') 85 | data = to_tensor(results['gt_seg_map'].astype(np.int64)) 86 | gt_sem_seg_data = dict(data=data) 87 | data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) 88 | 89 | if 'gt_edge_map' in results: 90 | gt_edge_data = dict( 91 | data=to_tensor(results['gt_edge_map'][None, 92 | ...].astype(np.int64))) 93 | data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data))) 94 | 95 | if 'gt_depth_map' in results: 96 | gt_depth_data = dict( 97 | data=to_tensor(results['gt_depth_map'][None, ...])) 98 | data_sample.set_data(dict(gt_depth_map=PixelData(**gt_depth_data))) 99 | 100 | img_meta = {} 101 | for key in self.meta_keys: 102 | if key in results: 103 | img_meta[key] = results[key] 104 | data_sample.set_metainfo(img_meta) 105 | packed_results['data_samples'] = data_sample 106 | 107 | return packed_results 108 | 109 | def __repr__(self) -> str: 110 | repr_str = self.__class__.__name__ 111 | repr_str += f'(meta_keys={self.meta_keys})' 112 | return repr_str 113 | -------------------------------------------------------------------------------- /mmseg/datasets/dsdl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | from typing import Dict, List, Optional, Sequence, Union 4 | 5 | from mmseg.registry import DATASETS 6 | from .basesegdataset import BaseSegDataset 7 | 8 | try: 9 | from dsdl.dataset import DSDLDataset 10 | except ImportError: 11 | DSDLDataset = None 12 | 13 | 14 | @DATASETS.register_module() 15 | class DSDLSegDataset(BaseSegDataset): 16 | """Dataset for dsdl segmentation. 17 | 18 | Args: 19 | specific_key_path(dict): Path of specific key which can not 20 | be loaded by it's field name. 21 | pre_transform(dict): pre-transform functions before loading. 22 | used_labels(sequence): list of actual used classes in train steps, 23 | this must be subset of class domain. 24 | """ 25 | 26 | METAINFO = {} 27 | 28 | def __init__(self, 29 | specific_key_path: Dict = {}, 30 | pre_transform: Dict = {}, 31 | used_labels: Optional[Sequence] = None, 32 | **kwargs) -> None: 33 | 34 | if DSDLDataset is None: 35 | raise RuntimeError( 36 | 'Package dsdl is not installed. Please run "pip install dsdl".' 
37 |             )
38 |         self.used_labels = used_labels
39 | 
40 |         loc_config = dict(type='LocalFileReader', working_dir='')
41 |         if kwargs.get('data_root'):
42 |             kwargs['ann_file'] = os.path.join(kwargs['data_root'],
43 |                                               kwargs['ann_file'])
44 |         required_fields = ['Image', 'LabelMap']
45 | 
46 |         self.dsdldataset = DSDLDataset(
47 |             dsdl_yaml=kwargs['ann_file'],
48 |             location_config=loc_config,
49 |             required_fields=required_fields,
50 |             specific_key_path=specific_key_path,
51 |             transform=pre_transform,
52 |         )
53 |         BaseSegDataset.__init__(self, **kwargs)
54 | 
55 |     def load_data_list(self) -> List[Dict]:
56 |         """Load data info from a dsdl yaml file named as ``self.ann_file``
57 | 
58 |         Returns:
59 |             List[dict]: A list of data info dicts.
60 |         """
61 | 
62 |         if self.used_labels:
63 |             self._metainfo['classes'] = tuple(self.used_labels)
64 |             self.label_map = self.get_label_map(self.used_labels)
65 |         else:
66 |             self._metainfo['classes'] = tuple(['background'] +
67 |                                               self.dsdldataset.class_names)
68 |         data_list = []
69 | 
70 |         for i, data in enumerate(self.dsdldataset):
71 |             datainfo = dict(
72 |                 img_path=os.path.join(self.data_prefix['img_path'],
73 |                                       data['Image'][0].location),
74 |                 seg_map_path=os.path.join(self.data_prefix['seg_map_path'],
75 |                                           data['LabelMap'][0].location),
76 |                 label_map=self.label_map,
77 |                 reduce_zero_label=self.reduce_zero_label,
78 |                 seg_fields=[],
79 |             )
80 |             data_list.append(datainfo)
81 | 
82 |         return data_list
83 | 
84 |     def get_label_map(self,
85 |                       new_classes: Optional[Sequence] = None
86 |                       ) -> Union[Dict, None]:
87 |         """Get the label mapping.
88 | 
89 |         The ``label_map`` is a dictionary, its keys are the old label ids and
90 |         its values are the new label ids, and is used for changing pixel
91 |         labels in load_annotations. ``label_map`` is not None if and only if
92 |         the old classes in class_dom differ from the new classes in args and
93 |         neither of them is None.
94 |         Args:
95 |             new_classes (list, tuple, optional): The new class names from
96 |                 metainfo. Defaults to None.
97 |         Returns:
98 |             dict, optional: The mapping from old classes to new classes.
99 |         """
100 |         old_classes = ['background'] + self.dsdldataset.class_names
101 |         if (new_classes is not None and old_classes is not None
102 |                 and list(new_classes) != list(old_classes)):
103 | 
104 |             label_map = {}
105 |             if not set(new_classes).issubset(old_classes):
106 |                 raise ValueError(
107 |                     f'new classes {new_classes} is not a '
108 |                     f'subset of classes {old_classes} in class_dom.')
109 |             for i, c in enumerate(old_classes):
110 |                 if c not in new_classes:
111 |                     label_map[i] = 255
112 |                 else:
113 |                     label_map[i] = new_classes.index(c)
114 |             return label_map
115 |         else:
116 |             return None
117 | 
--------------------------------------------------------------------------------