├── data
│   └── empty
├── logs
│   └── empty
├── marinext
│   ├── trained_models
│   │   └── empty
│   ├── mmseg
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── wrappers.py
│   │   │   └── encoding.py
│   │   ├── core
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   ├── misc.py
│   │   │   │   └── dist_util.py
│   │   │   ├── seg
│   │   │   │   ├── sampler
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── base_pixel_sampler.py
│   │   │   │   │   └── ohem_pixel_sampler.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── builder.py
│   │   │   ├── optimizers
│   │   │   │   └── __init__.py
│   │   │   ├── __init__.py
│   │   │   ├── evaluation
│   │   │   │   ├── __init__.py
│   │   │   │   └── eval_hooks.py
│   │   │   └── builder.py
│   │   ├── models
│   │   │   ├── segmentors
│   │   │   │   ├── __init__.py
│   │   │   │   └── cascade_encoder_decoder.py
│   │   │   ├── necks
│   │   │   │   ├── __init__.py
│   │   │   │   ├── featurepyramid.py
│   │   │   │   ├── multilevel_neck.py
│   │   │   │   ├── mla_neck.py
│   │   │   │   ├── jpu.py
│   │   │   │   └── ic_neck.py
│   │   │   ├── __init__.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── accuracy.py
│   │   │   │   ├── utils.py
│   │   │   │   └── dice_loss.py
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   ├── make_divisible.py
│   │   │   │   ├── se_layer.py
│   │   │   │   ├── res_layer.py
│   │   │   │   ├── shape_convert.py
│   │   │   │   └── up_conv_block.py
│   │   │   ├── backbones
│   │   │   │   ├── __init__.py
│   │   │   │   ├── timm_backbone.py
│   │   │   │   └── resnext.py
│   │   │   ├── decode_heads
│   │   │   │   ├── cc_head.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── nl_head.py
│   │   │   │   ├── gc_head.py
│   │   │   │   ├── segformer_head.py
│   │   │   │   ├── setr_mla_head.py
│   │   │   │   ├── cascade_decode_head.py
│   │   │   │   ├── sep_fcn_head.py
│   │   │   │   ├── fpn_head.py
│   │   │   │   ├── setr_up_head.py
│   │   │   │   ├── lraspp_head.py
│   │   │   │   ├── fcn_head.py
│   │   │   │   ├── stdc_head.py
│   │   │   │   ├── sep_aspp_head.py
│   │   │   │   ├── psp_head.py
│   │   │   │   ├── aspp_head.py
│   │   │   │   ├── ocr_head.py
│   │   │   │   ├── uper_head.py
│   │   │   │   ├── dnl_head.py
│   │   │   │   ├── segmenter_mask_head.py
│   │   │   │   ├── isa_head.py
│   │   │   │   ├── dm_head.py
│   │   │   │   ├── apc_head.py
│   │   │   │   └── da_head.py
│   │   │   └── builder.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── collect_env.py
│   │   │   ├── logger.py
│   │   │   ├── misc.py
│   │   │   ├── set_env.py
│   │   │   └── util_distribution.py
│   │   ├── version.py
│   │   └── __init__.py
│   ├── configs
│   │   └── marinext.tiny.240x240.mados.py
│   └── marinext_wrapper.py
├── .github
│   └── MADOS_LOGO_text.png
├── requirements.txt
├── LICENSE
├── utils
│   ├── vscp.py
│   ├── test_time_aug.py
│   ├── assets.py
│   ├── stack_patches.py
│   └── spectral_extraction.py
├── .gitignore
└── README.md

/data/empty:
--------------------------------------------------------------------------------
1 | #folder structure
2 |
--------------------------------------------------------------------------------
/logs/empty:
--------------------------------------------------------------------------------
1 | #folder structure
2 |
--------------------------------------------------------------------------------
/marinext/trained_models/empty:
--------------------------------------------------------------------------------
1 | #folder structure
2 |
--------------------------------------------------------------------------------
/.github/MADOS_LOGO_text.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gkakogeorgiou/mados/HEAD/.github/MADOS_LOGO_text.png
--------------------------------------------------------------------------------
/marinext/mmseg/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .encoding import Encoding 3 | from .wrappers import Upsample, resize 4 | 5 | __all__ = ['Upsample', 'resize', 'Encoding'] 6 | -------------------------------------------------------------------------------- /marinext/mmseg/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dist_util import check_dist_init, sync_random_seed 3 | from .misc import add_prefix 4 | 5 | __all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed'] 6 | -------------------------------------------------------------------------------- /marinext/mmseg/core/seg/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_pixel_sampler import BasePixelSampler 3 | from .ohem_pixel_sampler import OHEMPixelSampler 4 | 5 | __all__ = ['BasePixelSampler', 'OHEMPixelSampler'] 6 | -------------------------------------------------------------------------------- /marinext/mmseg/core/seg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import build_pixel_sampler 3 | from .sampler import BasePixelSampler, OHEMPixelSampler 4 | 5 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] 6 | -------------------------------------------------------------------------------- /marinext/mmseg/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base import BaseSegmentor 3 | from .cascade_encoder_decoder import CascadeEncoderDecoder 4 | from .encoder_decoder import EncoderDecoder 5 | 6 | __all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder'] 7 | -------------------------------------------------------------------------------- /marinext/mmseg/core/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .layer_decay_optimizer_constructor import ( 3 | LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) 4 | 5 | __all__ = [ 6 | 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' 7 | ] 8 | -------------------------------------------------------------------------------- /marinext/mmseg/core/seg/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry, build_from_cfg 3 | 4 | PIXEL_SAMPLERS = Registry('pixel sampler') 5 | 6 | 7 | def build_pixel_sampler(cfg, **default_args): 8 | """Build pixel sampler for segmentation map.""" 9 | return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) 10 | -------------------------------------------------------------------------------- /marinext/mmseg/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .featurepyramid import Feature2Pyramid 3 | from .fpn import FPN 4 | from .ic_neck import ICNeck 5 | from .jpu import JPU 6 | from .mla_neck import MLANeck 7 | from .multilevel_neck import MultiLevelNeck 8 | 9 | __all__ = [ 10 | 'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid' 11 | ] 12 | -------------------------------------------------------------------------------- /marinext/mmseg/core/seg/sampler/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BasePixelSampler(metaclass=ABCMeta): 6 | """Base class of pixel sampler.""" 7 | 8 | def __init__(self, **kwargs): 9 | pass 10 | 11 | @abstractmethod 12 | def sample(self, seg_logit, seg_label): 13 | """Placeholder for sample function.""" 14 | -------------------------------------------------------------------------------- /marinext/mmseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .collect_env import collect_env 3 | from .logger import get_root_logger 4 | from .misc import find_latest_checkpoint 5 | from .set_env import setup_multi_processes 6 | from .util_distribution import build_ddp, build_dp, get_device 7 | 8 | __all__ = [ 9 | 'get_root_logger', 'collect_env', 'find_latest_checkpoint', 10 | 'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device' 11 | ] 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.1 2 | networkx==3.1 3 | torch==1.11.0+cu113 4 | torchvision==0.12.0+cu113 5 | mmcv-full==1.6.0 6 | markdown-it-py==2.2.0 7 | timm==0.4.12 8 | opencv-python==4.7.0.72 9 | utm==0.7.0 10 | setuptools==69.1.0 11 | pyproj==3.3.0 12 | pandas==1.4.3 13 | rasterio==1.3a3 14 | glob2==0.7 15 | matplotlib==3.5.2 16 | scikit-image == 0.19.0 17 | scikit-learn == 1.0.1 18 | scipy==1.8.1 19 | pyresample==1.24.1 20 | h5py==3.7.0 21 | tqdm==4.66.2 22 | tensorboard==2.14.0 23 | -------------------------------------------------------------------------------- /marinext/mmseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import (OPTIMIZER_BUILDERS, build_optimizer, 3 | build_optimizer_constructor) 4 | from .evaluation import * # noqa: F401, F403 5 | from .optimizers import * # noqa: F401, F403 6 | from .seg import * # noqa: F401, F403 7 | from .utils import * # noqa: F401, F403 8 | 9 | __all__ = [ 10 | 'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor' 11 | ] 12 | -------------------------------------------------------------------------------- /marinext/mmseg/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def add_prefix(inputs, prefix): 3 | """Add prefix for dict. 4 | 5 | Args: 6 | inputs (dict): The input dict with str keys. 7 | prefix (str): The prefix to add. 8 | 9 | Returns: 10 | 11 | dict: The dict with keys updated with ``prefix``. 
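
    Example (illustrative loss dict; the values are made up):

        >>> add_prefix({'loss_seg': 0.5, 'acc_seg': 0.9}, 'decode')
        {'decode.loss_seg': 0.5, 'decode.acc_seg': 0.9}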
12 | """ 13 | 14 | outputs = dict() 15 | for name, value in inputs.items(): 16 | outputs[f'{prefix}.{name}'] = value 17 | 18 | return outputs 19 | -------------------------------------------------------------------------------- /marinext/mmseg/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .class_names import get_classes, get_palette 3 | from .eval_hooks import DistEvalHook, EvalHook 4 | from .metrics import (eval_metrics, intersect_and_union, mean_dice, 5 | mean_fscore, mean_iou, pre_eval_to_metrics) 6 | 7 | __all__ = [ 8 | 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', 9 | 'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics', 10 | 'intersect_and_union' 11 | ] 12 | -------------------------------------------------------------------------------- /marinext/mmseg/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import collect_env as collect_base_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import mmseg 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' 12 | 13 | return env_info 14 | 15 | 16 | if __name__ == '__main__': 17 | for name, val in collect_env().items(): 18 | print('{}: {}'.format(name, val)) 19 | -------------------------------------------------------------------------------- /marinext/mmseg/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | __version__ = '0.24.1' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /marinext/mmseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backbones import * # noqa: F401,F403 3 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, 4 | build_head, build_loss, build_segmentor) 5 | from .decode_heads import * # noqa: F401,F403 6 | from .losses import * # noqa: F401,F403 7 | from .necks import * # noqa: F401,F403 8 | from .segmentors import * # noqa: F401,F403 9 | 10 | __all__ = [ 11 | 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', 12 | 'build_head', 'build_loss', 'build_segmentor' 13 | ] 14 | -------------------------------------------------------------------------------- /marinext/mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .accuracy import Accuracy, accuracy 3 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 4 | cross_entropy, mask_cross_entropy) 5 | from .dice_loss import DiceLoss 6 | from .focal_loss import FocalLoss 7 | from .lovasz_loss import LovaszLoss 8 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 9 | 10 | __all__ = [ 11 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 12 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 13 | 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', 14 | 'FocalLoss' 15 | ] 16 | -------------------------------------------------------------------------------- /marinext/mmseg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .embed import PatchEmbed 3 | from .inverted_residual import InvertedResidual, InvertedResidualV3 4 | from .make_divisible import make_divisible 5 | from .res_layer import ResLayer 6 | from .se_layer import SELayer 7 | from .self_attention_block import SelfAttentionBlock 8 | from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, 9 | nlc_to_nchw) 10 | from .up_conv_block import UpConvBlock 11 | 12 | __all__ = [ 13 | 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 14 | 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', 15 | 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc' 16 | ] 17 | -------------------------------------------------------------------------------- /marinext/configs/marinext.tiny.240x240.mados.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='EncoderDecoder', 4 | backbone=dict( 5 | type="MSCAN", 6 | in_chans=11, 7 | embed_dims=[32, 64, 160, 256], 8 | mlp_ratios=[8, 8, 4, 4], 9 | drop_rate=0.0, 10 | drop_path_rate=0.1, 11 | depths=[3, 3, 5, 2], 12 | norm_cfg=dict(type="SyncBN", requires_grad=True)), 13 | decode_head=dict( 14 | type='LightHamHead', 15 | in_channels=[32, 64, 160, 256], 16 | in_index=[0, 1, 2, 3], 17 | channels=256, 18 | ham_channels=256, 19 | ham_kwargs=dict(MD_R=16), 20 | dropout_ratio=0.1, 21 | num_classes=15, 22 | norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), 23 | align_corners=False)) 24 | -------------------------------------------------------------------------------- /marinext/mmseg/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | """Get the root logger. 9 | 10 | The logger will be initialized if it has not been initialized. By default a 11 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 12 | also be added. The name of the root logger is the top-level package name, 13 | e.g., "mmseg". 14 | 15 | Args: 16 | log_file (str | None): The log filename. If specified, a FileHandler 17 | will be added to the root logger. 18 | log_level (int): The root logger level. Note that only the process of 19 | rank 0 is affected, while other processes will set the level to 20 | "Error" and be silent most of the time. 21 | 22 | Returns: 23 | logging.Logger: The root logger. 
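
    Example (hypothetical log file path):

        >>> logger = get_root_logger(log_file='logs/train.log')
        >>> logger.info('training started')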
24 | """ 25 | 26 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 27 | 28 | return logger 29 | -------------------------------------------------------------------------------- /marinext/marinext_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from mmseg.models import build_segmentor 3 | from mmcv.utils import Config 4 | from mmcv.utils import get_logger 5 | from mmcv.cnn.utils import revert_sync_batchnorm 6 | from torch import nn 7 | from os.path import dirname as up 8 | 9 | logger = get_logger('mmdet') 10 | logger.setLevel('WARNING') 11 | 12 | configs_path = os.path.join(up(__file__), 'configs') 13 | 14 | class MariNext(nn.Module): 15 | 16 | def __init__(self, in_chans, num_classes): 17 | super(MariNext, self).__init__() 18 | conf_file = os.path.join(configs_path,'marinext.tiny.240x240.mados.py') 19 | cfg = Config.fromfile(conf_file) 20 | cfg.model.backbone.in_chans = in_chans 21 | cfg.model.decode_head.num_classes = num_classes 22 | model = build_segmentor(cfg.model) 23 | model.init_weights() 24 | model = revert_sync_batchnorm(model) 25 | 26 | self.backbone = model.backbone 27 | self.decode_head = model.decode_head 28 | 29 | def forward(self, x): 30 | return self.decode_head(self.backbone(x)) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ioannis Kakogeorgiou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/vscp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Author: Ioannis Kakogeorgiou 4 | Email: gkakogeorgiou@gmail.com 5 | Python Version: 3.7.10 6 | Description: vscp.py includes the very simple copy-paste 7 | augmentation strategy. 
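
Usage sketch (shapes inferred from the function body; names are
illustrative): image is a (N, C, H, W) batch and target a (N, H, W) mask
batch where -1 appears to mark pixels without annotation. Pairs of samples
are fused by pasting the annotated pixels of one sample onto another:

    image_aug, target_aug = VSCP(image, target)  # returns N//2 augmented samples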
8 | ''' 9 | 10 | import numpy as np 11 | 12 | def VSCP(image, target): 13 | 14 | n_augmented = image.shape[0]//2 15 | 16 | image_temp = image[:n_augmented*2,:,:,:].copy() 17 | target_temp = target[:n_augmented*2,:,:].copy() 18 | 19 | image_augmented = [] 20 | target_augmented = [] 21 | for i in range(n_augmented): 22 | 23 | image_temp[i,:,target_temp[i+n_augmented,:,:]!=-1] = image_temp[i+n_augmented,:,target_temp[i+n_augmented,:,:]!=-1] 24 | image_augmented.append(image_temp[i,:,:].copy()) 25 | 26 | target_temp[i,target_temp[i+n_augmented,:,:]!=-1] = target_temp[i+n_augmented,target_temp[i+n_augmented,:,:]!=-1] 27 | target_augmented.append(target_temp[i,:,:].copy()) 28 | 29 | image_augmented = np.stack(image_augmented) 30 | target_augmented = np.stack(target_augmented) 31 | 32 | return image_augmented, target_augmented -------------------------------------------------------------------------------- /marinext/mmseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .beit import BEiT 3 | from .bisenetv1 import BiSeNetV1 4 | from .bisenetv2 import BiSeNetV2 5 | from .cgnet import CGNet 6 | from .erfnet import ERFNet 7 | from .fast_scnn import FastSCNN 8 | from .hrnet import HRNet 9 | from .icnet import ICNet 10 | from .mae import MAE 11 | from .mit import MixVisionTransformer 12 | from .mobilenet_v2 import MobileNetV2 13 | from .mobilenet_v3 import MobileNetV3 14 | from .resnest import ResNeSt 15 | from .resnet import ResNet, ResNetV1c, ResNetV1d 16 | from .resnext import ResNeXt 17 | from .stdc import STDCContextPathNet, STDCNet 18 | from .swin import SwinTransformer 19 | from .timm_backbone import TIMMBackbone 20 | from .twins import PCPVT, SVT 21 | from .unet import UNet 22 | from .vit import VisionTransformer 23 | from .mscan import MSCAN 24 | __all__ = [ 25 | 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 26 | 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', 27 | 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 28 | 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', 29 | 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 30 | 'MSCAN' 31 | ] 32 | -------------------------------------------------------------------------------- /utils/test_time_aug.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Author: Ioannis Kakogeorgiou 4 | Email: gkakogeorgiou@gmail.com 5 | Python Version: 3.7.10 6 | Description: test_time_aug.py includes the test-time eight augmentations 7 | with random rotations and horizontal flips. 
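
Usage sketch (shapes inferred from the function body): a batch is first
expanded into its 8 rotated/flipped views, and per-pixel class predictions
are later fused back by a majority vote (shown here for a single image,
B=1, since the reverse pass indexes the 8 views along the first axis):

    views = TTA(img)                              # (8*B, C, H, W)
    pred = TTA(preds, reverse_aggregation=True)   # preds: (8, H, W) class maps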
8 | ''' 9 | 10 | 11 | import torch 12 | from torchvision.transforms.functional import hflip 13 | 14 | def TTA(img, reverse_aggregation = False): 15 | im_list = [] 16 | 17 | if not reverse_aggregation: 18 | for k in [0,1,2,3]: 19 | im = torch.rot90(img, k=k, dims=[-2, -1]) 20 | im_list.append(im) 21 | 22 | im = hflip(im) 23 | im_list.append(im) 24 | 25 | img = torch.cat(im_list) 26 | 27 | else: 28 | 29 | for k in [3,2,1,0]: 30 | im = hflip(img[k*2 + 1,:,:]) 31 | im = torch.rot90(im, k=-k, dims=[-2, -1]) 32 | im_list.append(im) 33 | 34 | im = torch.rot90(img[k*2,:,:], k=-k, dims=[-2, -1]) 35 | im_list.append(im) 36 | 37 | img = torch.stack(im_list) 38 | 39 | img = torch.mode(img, dim=0, keepdim=True)[0] 40 | 41 | return img -------------------------------------------------------------------------------- /marinext/mmseg/core/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | 4 | from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS 5 | from mmcv.utils import Registry, build_from_cfg 6 | 7 | OPTIMIZER_BUILDERS = Registry( 8 | 'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS) 9 | 10 | 11 | def build_optimizer_constructor(cfg): 12 | constructor_type = cfg.get('type') 13 | if constructor_type in OPTIMIZER_BUILDERS: 14 | return build_from_cfg(cfg, OPTIMIZER_BUILDERS) 15 | elif constructor_type in MMCV_OPTIMIZER_BUILDERS: 16 | return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS) 17 | else: 18 | raise KeyError(f'{constructor_type} is not registered ' 19 | 'in the optimizer builder registry.') 20 | 21 | 22 | def build_optimizer(model, cfg): 23 | optimizer_cfg = copy.deepcopy(cfg) 24 | constructor_type = optimizer_cfg.pop('constructor', 25 | 'DefaultOptimizerConstructor') 26 | paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) 27 | optim_constructor = build_optimizer_constructor( 28 | dict( 29 | type=constructor_type, 30 | optimizer_cfg=optimizer_cfg, 31 | paramwise_cfg=paramwise_cfg)) 32 | optimizer = optim_constructor(model) 33 | return optimizer 34 | -------------------------------------------------------------------------------- /marinext/mmseg/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 3 | """Make divisible function. 4 | 5 | This function rounds the channel number to the nearest value that can be 6 | divisible by the divisor. It is taken from the original tf repo. It ensures 7 | that all layers have a channel number that is divisible by divisor. It can 8 | be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 9 | 10 | Args: 11 | value (int): The original channel number. 12 | divisor (int): The divisor to fully divide the channel number. 13 | min_value (int): The minimum value of the output channel. 14 | Default: None, means that the minimum value equal to the divisor. 15 | min_ratio (float): The minimum ratio of the rounded channel number to 16 | the original channel number. Default: 0.9. 17 | 18 | Returns: 19 | int: The modified output channel number. 20 | """ 21 | 22 | if min_value is None: 23 | min_value = divisor 24 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 25 | # Make sure that round down does not go down by more than (1-min_ratio). 
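    # Worked example (hypothetical values): make_divisible(27, 8) first rounds
    # 27 to 24, but 24 < 0.9 * 27 = 24.3, so the divisor is added back and 32
    # is returned.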
26 | if new_value < min_ratio * value: 27 | new_value += divisor 28 | return new_value 29 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/cc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from ..builder import HEADS 5 | from .fcn_head import FCNHead 6 | 7 | try: 8 | from mmcv.ops import CrissCrossAttention 9 | except ModuleNotFoundError: 10 | CrissCrossAttention = None 11 | 12 | 13 | @HEADS.register_module() 14 | class CCHead(FCNHead): 15 | """CCNet: Criss-Cross Attention for Semantic Segmentation. 16 | 17 | This head is the implementation of `CCNet 18 | `_. 19 | 20 | Args: 21 | recurrence (int): Number of recurrence of Criss Cross Attention 22 | module. Default: 2. 23 | """ 24 | 25 | def __init__(self, recurrence=2, **kwargs): 26 | if CrissCrossAttention is None: 27 | raise RuntimeError('Please install mmcv-full for ' 28 | 'CrissCrossAttention ops') 29 | super(CCHead, self).__init__(num_convs=2, **kwargs) 30 | self.recurrence = recurrence 31 | self.cca = CrissCrossAttention(self.channels) 32 | 33 | def forward(self, inputs): 34 | """Forward function.""" 35 | x = self._transform_inputs(inputs) 36 | output = self.convs[0](x) 37 | for _ in range(self.recurrence): 38 | output = self.cca(output) 39 | output = self.convs[1](output) 40 | if self.concat_input: 41 | output = self.conv_cat(torch.cat([x, output], dim=1)) 42 | output = self.cls_seg(output) 43 | return output 44 | -------------------------------------------------------------------------------- /marinext/mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | from mmcv.cnn import MODELS as MMCV_MODELS 5 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 6 | from mmcv.utils import Registry 7 | 8 | MODELS = Registry('models', parent=MMCV_MODELS) 9 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 10 | 11 | BACKBONES = MODELS 12 | NECKS = MODELS 13 | HEADS = MODELS 14 | LOSSES = MODELS 15 | SEGMENTORS = MODELS 16 | 17 | 18 | def build_backbone(cfg): 19 | """Build backbone.""" 20 | return BACKBONES.build(cfg) 21 | 22 | 23 | def build_neck(cfg): 24 | """Build neck.""" 25 | return NECKS.build(cfg) 26 | 27 | 28 | def build_head(cfg): 29 | """Build head.""" 30 | return HEADS.build(cfg) 31 | 32 | 33 | def build_loss(cfg): 34 | """Build loss.""" 35 | return LOSSES.build(cfg) 36 | 37 | 38 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 39 | """Build segmentor.""" 40 | if train_cfg is not None or test_cfg is not None: 41 | warnings.warn( 42 | 'train_cfg and test_cfg is deprecated, ' 43 | 'please specify them in model', UserWarning) 44 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 45 | 'train_cfg specified in both outer field and model field ' 46 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 47 | 'test_cfg specified in both outer field and model field ' 48 | return SEGMENTORS.build( 49 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 50 | -------------------------------------------------------------------------------- /marinext/mmseg/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import glob
3 | import os.path as osp
4 | import warnings
5 |
6 |
7 | def find_latest_checkpoint(path, suffix='pth'):
8 |     """This function is for finding the latest checkpoint.
9 |
10 |     It is used when resuming automatically, modified from
11 |     https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py
12 |
13 |     Args:
14 |         path (str): The path to find checkpoints.
15 |         suffix (str): File extension for the checkpoint. Defaults to pth.
16 |
17 |     Returns:
18 |         latest_path(str | None): File path of the latest checkpoint.
19 |     """
20 |     if not osp.exists(path):
21 |         warnings.warn("The path of the checkpoints doesn't exist.")
22 |         return None
23 |     if osp.exists(osp.join(path, f'latest.{suffix}')):
24 |         return osp.join(path, f'latest.{suffix}')
25 |
26 |     checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
27 |     if len(checkpoints) == 0:
28 |         warnings.warn('There are no checkpoints in the path')
29 |         return None
30 |     latest = -1
31 |     latest_path = ''
32 |     for checkpoint in checkpoints:
33 |         if len(checkpoint) < len(latest_path):
34 |             continue
35 |         # `count` is iteration number, as checkpoints are saved as
36 |         # 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number.
37 |         count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
38 |         if count > latest:
39 |             latest = count
40 |             latest_path = checkpoint
41 |     return latest_path
--------------------------------------------------------------------------------
/marinext/mmseg/core/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | import torch.distributed as dist
5 | from mmcv.runner import get_dist_info
6 |
7 |
8 | def check_dist_init():
9 |     return dist.is_available() and dist.is_initialized()
10 |
11 |
12 | def sync_random_seed(seed=None, device='cuda'):
13 |     """Make sure different ranks share the same seed. All workers must call
14 |     this function, otherwise it will deadlock. This method is generally used in
15 |     `DistributedSampler`, because the seed should be identical across all
16 |     processes in the distributed group.
17 |
18 |     In distributed sampling, different ranks should sample non-overlapped
19 |     data in the dataset. Therefore, this function is used to make sure that
20 |     each rank shuffles the data indices in the same order based
21 |     on the same seed. Then different ranks could use different indices
22 |     to select non-overlapped data from the same data list.
23 |
24 |     Args:
25 |         seed (int, Optional): The seed. Default to None.
26 |         device (str): The device where the seed will be put on.
27 |             Default to 'cuda'.
28 |     Returns:
29 |         int: Seed to be used.
30 |     """
31 |
32 |     if seed is None:
33 |         seed = np.random.randint(2**31)
34 |     assert isinstance(seed, int)
35 |
36 |     rank, world_size = get_dist_info()
37 |
38 |     if world_size == 1:
39 |         return seed
40 |
41 |     if rank == 0:
42 |         random_num = torch.tensor(seed, dtype=torch.int32, device=device)
43 |     else:
44 |         random_num = torch.tensor(0, dtype=torch.int32, device=device)
45 |     dist.broadcast(random_num, src=0)
46 |     return random_num.item()
--------------------------------------------------------------------------------
/marinext/mmseg/models/decode_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .ann_head import ANNHead
3 | from .apc_head import APCHead
4 | from .aspp_head import ASPPHead
5 | from .cc_head import CCHead
6 | from .da_head import DAHead
7 | from .dm_head import DMHead
8 | from .dnl_head import DNLHead
9 | from .dpt_head import DPTHead
10 | from .ema_head import EMAHead
11 | from .enc_head import EncHead
12 | from .fcn_head import FCNHead
13 | from .fpn_head import FPNHead
14 | from .gc_head import GCHead
15 | from .isa_head import ISAHead
16 | from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
17 | from .lraspp_head import LRASPPHead
18 | from .nl_head import NLHead
19 | from .ocr_head import OCRHead
20 | from .point_head import PointHead
21 | from .psa_head import PSAHead
22 | from .psp_head import PSPHead
23 | from .segformer_head import SegformerHead
24 | from .segmenter_mask_head import SegmenterMaskTransformerHead
25 | from .sep_aspp_head import DepthwiseSeparableASPPHead
26 | from .sep_fcn_head import DepthwiseSeparableFCNHead
27 | from .setr_mla_head import SETRMLAHead
28 | from .setr_up_head import SETRUPHead
29 | from .stdc_head import STDCHead
30 | from .uper_head import UPerHead
31 | from .ham_head import LightHamHead
32 |
33 | __all__ = [
34 |     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
35 |     'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
36 |     'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
37 |     'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
38 |     'SETRMLAHead', 'DPTHead', 'SegmenterMaskTransformerHead',
39 |     'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
40 |     'KernelUpdateHead', 'KernelUpdator', 'LightHamHead'
41 | ]
42 |
--------------------------------------------------------------------------------
/marinext/mmseg/models/decode_heads/nl_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from mmcv.cnn import NonLocal2d
4 |
5 | from ..builder import HEADS
6 | from .fcn_head import FCNHead
7 |
8 |
9 | @HEADS.register_module()
10 | class NLHead(FCNHead):
11 |     """Non-local Neural Networks.
12 |
13 |     This head is the implementation of `NLNet
14 |     <https://arxiv.org/abs/1711.07971>`_.
15 |
16 |     Args:
17 |         reduction (int): Reduction factor of projection transform. Default: 2.
18 |         use_scale (bool): Whether to scale pairwise_weight by
19 |             sqrt(1/inter_channels). Default: True.
20 |         mode (str): The nonlocal mode. Options are 'embedded_gaussian',
21 |             'dot_product'. Default: 'embedded_gaussian'.
22 | """ 23 | 24 | def __init__(self, 25 | reduction=2, 26 | use_scale=True, 27 | mode='embedded_gaussian', 28 | **kwargs): 29 | super(NLHead, self).__init__(num_convs=2, **kwargs) 30 | self.reduction = reduction 31 | self.use_scale = use_scale 32 | self.mode = mode 33 | self.nl_block = NonLocal2d( 34 | in_channels=self.channels, 35 | reduction=self.reduction, 36 | use_scale=self.use_scale, 37 | conv_cfg=self.conv_cfg, 38 | norm_cfg=self.norm_cfg, 39 | mode=self.mode) 40 | 41 | def forward(self, inputs): 42 | """Forward function.""" 43 | x = self._transform_inputs(inputs) 44 | output = self.convs[0](x) 45 | output = self.nl_block(output) 46 | output = self.convs[1](output) 47 | if self.concat_input: 48 | output = self.conv_cat(torch.cat([x, output], dim=1)) 49 | output = self.cls_seg(output) 50 | return output 51 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/gc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.cnn import ContextBlock 4 | 5 | from ..builder import HEADS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @HEADS.register_module() 10 | class GCHead(FCNHead): 11 | """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. 12 | 13 | This head is the implementation of `GCNet 14 | `_. 15 | 16 | Args: 17 | ratio (float): Multiplier of channels ratio. Default: 1/4. 18 | pooling_type (str): The pooling type of context aggregation. 19 | Options are 'att', 'avg'. Default: 'avg'. 20 | fusion_types (tuple[str]): The fusion type for feature fusion. 21 | Options are 'channel_add', 'channel_mul'. Default: ('channel_add',) 22 | """ 23 | 24 | def __init__(self, 25 | ratio=1 / 4., 26 | pooling_type='att', 27 | fusion_types=('channel_add', ), 28 | **kwargs): 29 | super(GCHead, self).__init__(num_convs=2, **kwargs) 30 | self.ratio = ratio 31 | self.pooling_type = pooling_type 32 | self.fusion_types = fusion_types 33 | self.gc_block = ContextBlock( 34 | in_channels=self.channels, 35 | ratio=self.ratio, 36 | pooling_type=self.pooling_type, 37 | fusion_types=self.fusion_types) 38 | 39 | def forward(self, inputs): 40 | """Forward function.""" 41 | x = self._transform_inputs(inputs) 42 | output = self.convs[0](x) 43 | output = self.gc_block(output) 44 | output = self.convs[1](output) 45 | if self.concat_input: 46 | output = self.conv_cat(torch.cat([x, output], dim=1)) 47 | output = self.cls_seg(output) 48 | return output 49 | -------------------------------------------------------------------------------- /marinext/mmseg/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import warnings
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def resize(input,
9 |            size=None,
10 |            scale_factor=None,
11 |            mode='nearest',
12 |            align_corners=None,
13 |            warning=True):
14 |     if warning:
15 |         if size is not None and align_corners:
16 |             input_h, input_w = tuple(int(x) for x in input.shape[2:])
17 |             output_h, output_w = tuple(int(x) for x in size)
18 |             if output_h > input_h or output_w > input_w:
19 |                 if ((output_h > 1 and output_w > 1 and input_h > 1
20 |                      and input_w > 1) and (output_h - 1) % (input_h - 1)
21 |                         and (output_w - 1) % (input_w - 1)):
22 |                     warnings.warn(
23 |                         f'When align_corners={align_corners}, '
24 |                         'the output would be more aligned if '
25 |                         f'input size {(input_h, input_w)} is `x+1` and '
26 |                         f'out size {(output_h, output_w)} is `nx+1`')
27 |     return F.interpolate(input, size, scale_factor, mode, align_corners)
28 |
29 |
30 | class Upsample(nn.Module):
31 |
32 |     def __init__(self,
33 |                  size=None,
34 |                  scale_factor=None,
35 |                  mode='nearest',
36 |                  align_corners=None):
37 |         super(Upsample, self).__init__()
38 |         self.size = size
39 |         if isinstance(scale_factor, tuple):
40 |             self.scale_factor = tuple(float(factor) for factor in scale_factor)
41 |         else:
42 |             self.scale_factor = float(scale_factor) if scale_factor else None
43 |         self.mode = mode
44 |         self.align_corners = align_corners
45 |
46 |     def forward(self, x):
47 |         if not self.size:
48 |             size = [int(t * self.scale_factor) for t in x.shape[-2:]]
49 |         else:
50 |             size = self.size
51 |         return resize(x, size, None, self.mode, self.align_corners)
52 |
--------------------------------------------------------------------------------
/marinext/mmseg/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | import mmcv
5 | from packaging.version import parse
6 |
7 | from .version import __version__, version_info
8 |
9 | MMCV_MIN = '1.3.13'
10 | MMCV_MAX = '1.6.0'
11 |
12 |
13 | def digit_version(version_str: str, length: int = 4):
14 |     """Convert a version string into a tuple of integers.
15 |
16 |     This method is usually used for comparing two versions. For pre-release
17 |     versions: alpha < beta < rc.
18 |
19 |     Args:
20 |         version_str (str): The version string.
21 |         length (int): The maximum number of version levels. Default: 4.
22 |
23 |     Returns:
24 |         tuple[int]: The version info in digits (integers).
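
    Example (values worked out from the rules above):

        >>> digit_version('1.3.13')
        (1, 3, 13, 0, 0, 0)
        >>> digit_version('1.6.0rc1')
        (1, 6, 0, 0, -1, 1)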
25 | """ 26 | version = parse(version_str) 27 | assert version.release, f'failed to parse version {version_str}' 28 | release = list(version.release) 29 | release = release[:length] 30 | if len(release) < length: 31 | release = release + [0] * (length - len(release)) 32 | if version.is_prerelease: 33 | mapping = {'a': -3, 'b': -2, 'rc': -1} 34 | val = -4 35 | # version.pre can be None 36 | if version.pre: 37 | if version.pre[0] not in mapping: 38 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 39 | 'version checking may go wrong') 40 | else: 41 | val = mapping[version.pre[0]] 42 | release.extend([val, version.pre[-1]]) 43 | else: 44 | release.extend([val, 0]) 45 | 46 | elif version.is_postrelease: 47 | release.extend([1, version.post]) 48 | else: 49 | release.extend([0, 0]) 50 | return tuple(release) 51 | 52 | 53 | mmcv_min_version = digit_version(MMCV_MIN) 54 | mmcv_max_version = digit_version(MMCV_MAX) 55 | mmcv_version = digit_version(mmcv.__version__) 56 | 57 | 58 | assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 59 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 60 | f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' 61 | 62 | __all__ = ['__version__', 'version_info', 'digit_version'] 63 | -------------------------------------------------------------------------------- /marinext/mmseg/models/backbones/timm_backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | try: 3 | import timm 4 | except ImportError: 5 | timm = None 6 | 7 | from mmcv.cnn.bricks.registry import NORM_LAYERS 8 | from mmcv.runner import BaseModule 9 | 10 | from ..builder import BACKBONES 11 | 12 | 13 | @BACKBONES.register_module() 14 | class TIMMBackbone(BaseModule): 15 | """Wrapper to use backbones from timm library. More details can be found in 16 | `timm `_ . 17 | 18 | Args: 19 | model_name (str): Name of timm model to instantiate. 20 | pretrained (bool): Load pretrained weights if True. 21 | checkpoint_path (str): Path of checkpoint to load after 22 | model is initialized. 23 | in_channels (int): Number of input image channels. Default: 3. 24 | init_cfg (dict, optional): Initialization config dict 25 | **kwargs: Other timm & model specific arguments. 
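
    Example (hypothetical model name; requires the ``timm`` package):

        >>> import torch
        >>> backbone = TIMMBackbone(model_name='resnet18', pretrained=False)
        >>> feats = backbone(torch.rand(1, 3, 224, 224))  # multi-scale features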
26 | """ 27 | 28 | def __init__( 29 | self, 30 | model_name, 31 | features_only=True, 32 | pretrained=True, 33 | checkpoint_path='', 34 | in_channels=3, 35 | init_cfg=None, 36 | **kwargs, 37 | ): 38 | if timm is None: 39 | raise RuntimeError('timm is not installed') 40 | super(TIMMBackbone, self).__init__(init_cfg) 41 | if 'norm_layer' in kwargs: 42 | kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer']) 43 | self.timm_model = timm.create_model( 44 | model_name=model_name, 45 | features_only=features_only, 46 | pretrained=pretrained, 47 | in_chans=in_channels, 48 | checkpoint_path=checkpoint_path, 49 | **kwargs, 50 | ) 51 | 52 | # Make unused parameters None 53 | self.timm_model.global_pool = None 54 | self.timm_model.fc = None 55 | self.timm_model.classifier = None 56 | 57 | # Hack to use pretrained weights from timm 58 | if pretrained or checkpoint_path: 59 | self._is_init = True 60 | 61 | def forward(self, x): 62 | features = self.timm_model(x) 63 | return features 64 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/segformer_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.models.builder import HEADS 7 | from mmseg.models.decode_heads.decode_head import BaseDecodeHead 8 | from mmseg.ops import resize 9 | 10 | 11 | @HEADS.register_module() 12 | class SegformerHead(BaseDecodeHead): 13 | """The all mlp Head of segformer. 14 | 15 | This head is the implementation of 16 | `Segformer ` _. 17 | 18 | Args: 19 | interpolate_mode: The interpolate mode of MLP head upsample operation. 20 | Default: 'bilinear'. 21 | """ 22 | 23 | def __init__(self, interpolate_mode='bilinear', **kwargs): 24 | super().__init__(input_transform='multiple_select', **kwargs) 25 | 26 | self.interpolate_mode = interpolate_mode 27 | num_inputs = len(self.in_channels) 28 | 29 | assert num_inputs == len(self.in_index) 30 | 31 | self.convs = nn.ModuleList() 32 | for i in range(num_inputs): 33 | self.convs.append( 34 | ConvModule( 35 | in_channels=self.in_channels[i], 36 | out_channels=self.channels, 37 | kernel_size=1, 38 | stride=1, 39 | norm_cfg=self.norm_cfg, 40 | act_cfg=self.act_cfg)) 41 | 42 | self.fusion_conv = ConvModule( 43 | in_channels=self.channels * num_inputs, 44 | out_channels=self.channels, 45 | kernel_size=1, 46 | norm_cfg=self.norm_cfg) 47 | 48 | def forward(self, inputs): 49 | # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 50 | inputs = self._transform_inputs(inputs) 51 | outs = [] 52 | for idx in range(len(inputs)): 53 | x = inputs[idx] 54 | conv = self.convs[idx] 55 | outs.append( 56 | resize( 57 | input=conv(x), 58 | size=inputs[0].shape[2:], 59 | mode=self.interpolate_mode, 60 | align_corners=self.align_corners)) 61 | 62 | out = self.fusion_conv(torch.cat(outs, dim=1)) 63 | 64 | out = self.cls_seg(out) 65 | 66 | return out 67 | -------------------------------------------------------------------------------- /marinext/mmseg/models/utils/se_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from .make_divisible import make_divisible 7 | 8 | 9 | class SELayer(nn.Module): 10 | """Squeeze-and-Excitation Module. 
11 |
12 |     Args:
13 |         channels (int): The input (and output) channels of the SE layer.
14 |         ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
15 |             ``int(channels/ratio)``. Default: 16.
16 |         conv_cfg (None or dict): Config dict for convolution layer.
17 |             Default: None, which means using conv2d.
18 |         act_cfg (dict or Sequence[dict]): Config dict for activation layer.
19 |             If act_cfg is a dict, two activation layers will be configured
20 |             by this dict. If act_cfg is a sequence of dicts, the first
21 |             activation layer will be configured by the first dict and the
22 |             second activation layer will be configured by the second dict.
23 |             Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
24 |             divisor=6.0)).
25 |     """
26 |
27 |     def __init__(self,
28 |                  channels,
29 |                  ratio=16,
30 |                  conv_cfg=None,
31 |                  act_cfg=(dict(type='ReLU'),
32 |                           dict(type='HSigmoid', bias=3.0, divisor=6.0))):
33 |         super(SELayer, self).__init__()
34 |         if isinstance(act_cfg, dict):
35 |             act_cfg = (act_cfg, act_cfg)
36 |         assert len(act_cfg) == 2
37 |         assert mmcv.is_tuple_of(act_cfg, dict)
38 |         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
39 |         self.conv1 = ConvModule(
40 |             in_channels=channels,
41 |             out_channels=make_divisible(channels // ratio, 8),
42 |             kernel_size=1,
43 |             stride=1,
44 |             conv_cfg=conv_cfg,
45 |             act_cfg=act_cfg[0])
46 |         self.conv2 = ConvModule(
47 |             in_channels=make_divisible(channels // ratio, 8),
48 |             out_channels=channels,
49 |             kernel_size=1,
50 |             stride=1,
51 |             conv_cfg=conv_cfg,
52 |             act_cfg=act_cfg[1])
53 |
54 |     def forward(self, x):
55 |         out = self.global_avgpool(x)
56 |         out = self.conv1(out)
57 |         out = self.conv2(out)
58 |         return x * out
59 |
--------------------------------------------------------------------------------
/marinext/mmseg/models/decode_heads/setr_mla_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | from mmcv.cnn import ConvModule
5 |
6 | from mmseg.ops import Upsample
7 | from ..builder import HEADS
8 | from .decode_head import BaseDecodeHead
9 |
10 |
11 | @HEADS.register_module()
12 | class SETRMLAHead(BaseDecodeHead):
13 |     """Multi-level feature aggregation head of SETR.
14 |
15 |     MLA head of `SETR <https://arxiv.org/abs/2012.15840>`_.
16 |
17 |     Args:
18 |         mla_channels (int): Channels of conv-conv-4x of multi-level feature
19 |             aggregation. Default: 128.
20 |         up_scale (int): The scale factor of interpolate. Default: 4.
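
    Example (hypothetical channel configuration; ``channels`` must equal
    ``len(in_channels) * mla_channels``, see the assert in ``__init__``):

        >>> head = SETRMLAHead(in_channels=(256, 256, 256, 256),
        ...                    in_index=(0, 1, 2, 3), channels=512,
        ...                    mla_channels=128, num_classes=19)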
21 | """ 22 | 23 | def __init__(self, mla_channels=128, up_scale=4, **kwargs): 24 | super(SETRMLAHead, self).__init__( 25 | input_transform='multiple_select', **kwargs) 26 | self.mla_channels = mla_channels 27 | 28 | num_inputs = len(self.in_channels) 29 | 30 | # Refer to self.cls_seg settings of BaseDecodeHead 31 | assert self.channels == num_inputs * mla_channels 32 | 33 | self.up_convs = nn.ModuleList() 34 | for i in range(num_inputs): 35 | self.up_convs.append( 36 | nn.Sequential( 37 | ConvModule( 38 | in_channels=self.in_channels[i], 39 | out_channels=mla_channels, 40 | kernel_size=3, 41 | padding=1, 42 | norm_cfg=self.norm_cfg, 43 | act_cfg=self.act_cfg), 44 | ConvModule( 45 | in_channels=mla_channels, 46 | out_channels=mla_channels, 47 | kernel_size=3, 48 | padding=1, 49 | norm_cfg=self.norm_cfg, 50 | act_cfg=self.act_cfg), 51 | Upsample( 52 | scale_factor=up_scale, 53 | mode='bilinear', 54 | align_corners=self.align_corners))) 55 | 56 | def forward(self, inputs): 57 | inputs = self._transform_inputs(inputs) 58 | outs = [] 59 | for x, up_conv in zip(inputs, self.up_convs): 60 | outs.append(up_conv(x)) 61 | out = torch.cat(outs, dim=1) 62 | out = self.cls_seg(out) 63 | return out 64 | -------------------------------------------------------------------------------- /marinext/mmseg/utils/set_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import platform 4 | 5 | import cv2 6 | import torch.multiprocessing as mp 7 | 8 | from ..utils import get_root_logger 9 | 10 | 11 | def setup_multi_processes(cfg): 12 | """Setup multi-processing environment variables.""" 13 | logger = get_root_logger() 14 | 15 | # set multi-process start method 16 | if platform.system() != 'Windows': 17 | mp_start_method = cfg.get('mp_start_method', None) 18 | current_method = mp.get_start_method(allow_none=True) 19 | if mp_start_method in ('fork', 'spawn', 'forkserver'): 20 | logger.info( 21 | f'Multi-processing start method `{mp_start_method}` is ' 22 | f'different from the previous setting `{current_method}`.' 
                f' It will be force set to `{mp_start_method}`.')
24 |             mp.set_start_method(mp_start_method, force=True)
25 |         else:
26 |             logger.info(
27 |                 f'Multi-processing start method is `{mp_start_method}`')
28 |
29 |     # disable opencv multithreading to avoid system being overloaded
30 |     opencv_num_threads = cfg.get('opencv_num_threads', None)
31 |     if isinstance(opencv_num_threads, int):
32 |         logger.info(f'OpenCV num_threads is `{opencv_num_threads}`')
33 |         cv2.setNumThreads(opencv_num_threads)
34 |     else:
35 |         logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}`')
36 |
37 |     if cfg.data.workers_per_gpu > 1:
38 |         # setup OMP threads
39 |         # This code is adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
40 |         omp_num_threads = cfg.get('omp_num_threads', None)
41 |         if 'OMP_NUM_THREADS' not in os.environ:
42 |             if isinstance(omp_num_threads, int):
43 |                 logger.info(f'OMP num threads is {omp_num_threads}')
44 |                 os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
45 |         else:
46 |             logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"]}')
47 |
48 |         # setup MKL threads
49 |         if 'MKL_NUM_THREADS' not in os.environ:
50 |             mkl_num_threads = cfg.get('mkl_num_threads', None)
51 |             if isinstance(mkl_num_threads, int):
52 |                 logger.info(f'MKL num threads is {mkl_num_threads}')
53 |                 os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
54 |         else:
55 |             logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}')
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | data/MADOS 132 | random_forest/rf_classifier.joblib 133 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/cascade_decode_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from .decode_head import BaseDecodeHead 5 | 6 | 7 | class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): 8 | """Base class for cascade decode head used in 9 | :class:`CascadeEncoderDecoder.""" 10 | 11 | def __init__(self, *args, **kwargs): 12 | super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs) 13 | 14 | @abstractmethod 15 | def forward(self, inputs, prev_output): 16 | """Placeholder of forward function.""" 17 | pass 18 | 19 | def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, 20 | train_cfg): 21 | """Forward function for training. 22 | Args: 23 | inputs (list[Tensor]): List of multi-level img features. 24 | prev_output (Tensor): The output of previous decode head. 25 | img_metas (list[dict]): List of image info dict where each dict 26 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 27 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 28 | For details on the values of these keys see 29 | `mmseg/datasets/pipelines/formatting.py:Collect`. 30 | gt_semantic_seg (Tensor): Semantic segmentation masks 31 | used if the architecture supports semantic segmentation task. 32 | train_cfg (dict): The training config. 33 | 34 | Returns: 35 | dict[str, Tensor]: a dictionary of loss components 36 | """ 37 | seg_logits = self.forward(inputs, prev_output) 38 | losses = self.losses(seg_logits, gt_semantic_seg) 39 | 40 | return losses 41 | 42 | def forward_test(self, inputs, prev_output, img_metas, test_cfg): 43 | """Forward function for testing. 44 | 45 | Args: 46 | inputs (list[Tensor]): List of multi-level img features. 47 | prev_output (Tensor): The output of previous decode head. 48 | img_metas (list[dict]): List of image info dict where each dict 49 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 50 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 51 | For details on the values of these keys see 52 | `mmseg/datasets/pipelines/formatting.py:Collect`. 53 | test_cfg (dict): The testing config. 54 | 55 | Returns: 56 | Tensor: Output segmentation map. 57 | """ 58 | return self.forward(inputs, prev_output) 59 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/sep_fcn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from mmcv.cnn import DepthwiseSeparableConvModule 3 | 4 | from ..builder import HEADS 5 | from .fcn_head import FCNHead 6 | 7 | 8 | @HEADS.register_module() 9 | class DepthwiseSeparableFCNHead(FCNHead): 10 | """Depthwise-Separable Fully Convolutional Network for Semantic 11 | Segmentation. 12 | 13 | This head is implemented according to `Fast-SCNN: Fast Semantic 14 | Segmentation Network `_. 15 | 16 | Args: 17 | in_channels(int): Number of output channels of FFM. 18 | channels(int): Number of middle-stage channels in the decode head. 19 | concat_input(bool): Whether to concatenate original decode input into 20 | the result of several consecutive convolution layers. 21 | Default: True. 22 | num_classes(int): Used to determine the dimension of 23 | final prediction tensor. 24 | in_index(int): Correspond with 'out_indices' in FastSCNN backbone. 25 | norm_cfg (dict | None): Config of norm layers. 26 | align_corners (bool): align_corners argument of F.interpolate. 27 | Default: False. 28 | loss_decode(dict): Config of loss type and some 29 | relevant additional options. 30 | dw_act_cfg (dict):Activation config of depthwise ConvModule. If it is 31 | 'default', it will be the same as `act_cfg`. Default: None. 32 | """ 33 | 34 | def __init__(self, dw_act_cfg=None, **kwargs): 35 | super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) 36 | self.convs[0] = DepthwiseSeparableConvModule( 37 | self.in_channels, 38 | self.channels, 39 | kernel_size=self.kernel_size, 40 | padding=self.kernel_size // 2, 41 | norm_cfg=self.norm_cfg, 42 | dw_act_cfg=dw_act_cfg) 43 | 44 | for i in range(1, self.num_convs): 45 | self.convs[i] = DepthwiseSeparableConvModule( 46 | self.channels, 47 | self.channels, 48 | kernel_size=self.kernel_size, 49 | padding=self.kernel_size // 2, 50 | norm_cfg=self.norm_cfg, 51 | dw_act_cfg=dw_act_cfg) 52 | 53 | if self.concat_input: 54 | self.conv_cat = DepthwiseSeparableConvModule( 55 | self.in_channels + self.channels, 56 | self.channels, 57 | kernel_size=self.kernel_size, 58 | padding=self.kernel_size // 2, 59 | norm_cfg=self.norm_cfg, 60 | dw_act_cfg=dw_act_cfg) 61 | -------------------------------------------------------------------------------- /marinext/mmseg/models/necks/featurepyramid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import build_norm_layer 4 | 5 | from ..builder import NECKS 6 | 7 | 8 | @NECKS.register_module() 9 | class Feature2Pyramid(nn.Module): 10 | """Feature2Pyramid. 11 | 12 | A neck structure connect ViT backbone and decoder_heads. 13 | 14 | Args: 15 | embed_dims (int): Embedding dimension. 16 | rescales (list[float]): Different sampling multiples were 17 | used to obtain pyramid features. Default: [4, 2, 1, 0.5]. 18 | norm_cfg (dict): Config dict for normalization layer. 19 | Default: dict(type='SyncBN', requires_grad=True). 
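An aside on the depthwise-separable trick that `DepthwiseSeparableFCNHead` above relies on: each 3x3 conv is factored into a per-channel 3x3 conv plus a 1x1 channel mixer. A minimal plain-PyTorch sketch (layer sizes are illustrative, not the head's actual configuration):

```python
import torch
import torch.nn as nn

in_ch, out_ch = 128, 128
# Depthwise: one 3x3 filter per channel (groups=in_ch), no channel mixing.
depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, groups=in_ch)
# Pointwise: a 1x1 conv that mixes channels.
pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)
separable = nn.Sequential(depthwise, pointwise)

standard = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

count = lambda m: sum(p.numel() for p in m.parameters())
x = torch.randn(1, in_ch, 60, 60)
assert separable(x).shape == standard(x).shape  # same output shape
print(count(separable), 'vs', count(standard))  # 17792 vs 147584 parameters
```

For 3x3 kernels the factorization cuts parameters and FLOPs roughly 8-9x at the same input/output shape, which is why Fast-SCNN-style heads favor it.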
20 | """ 21 | 22 | def __init__(self, 23 | embed_dim, 24 | rescales=[4, 2, 1, 0.5], 25 | norm_cfg=dict(type='SyncBN', requires_grad=True)): 26 | super(Feature2Pyramid, self).__init__() 27 | self.rescales = rescales 28 | self.upsample_4x = None 29 | for k in self.rescales: 30 | if k == 4: 31 | self.upsample_4x = nn.Sequential( 32 | nn.ConvTranspose2d( 33 | embed_dim, embed_dim, kernel_size=2, stride=2), 34 | build_norm_layer(norm_cfg, embed_dim)[1], 35 | nn.GELU(), 36 | nn.ConvTranspose2d( 37 | embed_dim, embed_dim, kernel_size=2, stride=2), 38 | ) 39 | elif k == 2: 40 | self.upsample_2x = nn.Sequential( 41 | nn.ConvTranspose2d( 42 | embed_dim, embed_dim, kernel_size=2, stride=2)) 43 | elif k == 1: 44 | self.identity = nn.Identity() 45 | elif k == 0.5: 46 | self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2) 47 | elif k == 0.25: 48 | self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) 49 | else: 50 | raise KeyError(f'invalid {k} for feature2pyramid') 51 | 52 | def forward(self, inputs): 53 | assert len(inputs) == len(self.rescales) 54 | outputs = [] 55 | if self.upsample_4x is not None: 56 | ops = [ 57 | self.upsample_4x, self.upsample_2x, self.identity, 58 | self.downsample_2x 59 | ] 60 | else: 61 | ops = [ 62 | self.upsample_2x, self.identity, self.downsample_2x, 63 | self.downsample_4x 64 | ] 65 | for i in range(len(inputs)): 66 | outputs.append(ops[i](inputs[i])) 67 | return tuple(outputs) 68 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/fpn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.ops import Upsample, resize 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @HEADS.register_module() 12 | class FPNHead(BaseDecodeHead): 13 | """Panoptic Feature Pyramid Networks. 14 | 15 | This head is the implementation of `Semantic FPN 16 | `_. 17 | 18 | Args: 19 | feature_strides (tuple[int]): The strides for input feature maps. 20 | stack_lateral. All strides suppose to be power of 2. The first 21 | one is of largest resolution. 
22 | """ 23 | 24 | def __init__(self, feature_strides, **kwargs): 25 | super(FPNHead, self).__init__( 26 | input_transform='multiple_select', **kwargs) 27 | assert len(feature_strides) == len(self.in_channels) 28 | assert min(feature_strides) == feature_strides[0] 29 | self.feature_strides = feature_strides 30 | 31 | self.scale_heads = nn.ModuleList() 32 | for i in range(len(feature_strides)): 33 | head_length = max( 34 | 1, 35 | int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) 36 | scale_head = [] 37 | for k in range(head_length): 38 | scale_head.append( 39 | ConvModule( 40 | self.in_channels[i] if k == 0 else self.channels, 41 | self.channels, 42 | 3, 43 | padding=1, 44 | conv_cfg=self.conv_cfg, 45 | norm_cfg=self.norm_cfg, 46 | act_cfg=self.act_cfg)) 47 | if feature_strides[i] != feature_strides[0]: 48 | scale_head.append( 49 | Upsample( 50 | scale_factor=2, 51 | mode='bilinear', 52 | align_corners=self.align_corners)) 53 | self.scale_heads.append(nn.Sequential(*scale_head)) 54 | 55 | def forward(self, inputs): 56 | 57 | x = self._transform_inputs(inputs) 58 | 59 | output = self.scale_heads[0](x[0]) 60 | for i in range(1, len(self.feature_strides)): 61 | # non inplace 62 | output = output + resize( 63 | self.scale_heads[i](x[i]), 64 | size=output.shape[2:], 65 | mode='bilinear', 66 | align_corners=self.align_corners) 67 | 68 | output = self.cls_seg(output) 69 | return output 70 | -------------------------------------------------------------------------------- /marinext/mmseg/models/necks/multilevel_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, xavier_init 4 | 5 | from mmseg.ops import resize 6 | from ..builder import NECKS 7 | 8 | 9 | @NECKS.register_module() 10 | class MultiLevelNeck(nn.Module): 11 | """MultiLevelNeck. 12 | 13 | A neck structure connect vit backbone and decoder_heads. 14 | 15 | Args: 16 | in_channels (List[int]): Number of input channels per scale. 17 | out_channels (int): Number of output channels (used at each scale). 18 | scales (List[float]): Scale factors for each input feature map. 19 | Default: [0.5, 1, 2, 4] 20 | norm_cfg (dict): Config dict for normalization layer. Default: None. 21 | act_cfg (dict): Config dict for activation layer in ConvModule. 22 | Default: None. 
23 | """ 24 | 25 | def __init__(self, 26 | in_channels, 27 | out_channels, 28 | scales=[0.5, 1, 2, 4], 29 | norm_cfg=None, 30 | act_cfg=None): 31 | super(MultiLevelNeck, self).__init__() 32 | assert isinstance(in_channels, list) 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.scales = scales 36 | self.num_outs = len(scales) 37 | self.lateral_convs = nn.ModuleList() 38 | self.convs = nn.ModuleList() 39 | for in_channel in in_channels: 40 | self.lateral_convs.append( 41 | ConvModule( 42 | in_channel, 43 | out_channels, 44 | kernel_size=1, 45 | norm_cfg=norm_cfg, 46 | act_cfg=act_cfg)) 47 | for _ in range(self.num_outs): 48 | self.convs.append( 49 | ConvModule( 50 | out_channels, 51 | out_channels, 52 | kernel_size=3, 53 | padding=1, 54 | stride=1, 55 | norm_cfg=norm_cfg, 56 | act_cfg=act_cfg)) 57 | 58 | # default init_weights for conv(msra) and norm in ConvModule 59 | def init_weights(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | xavier_init(m, distribution='uniform') 63 | 64 | def forward(self, inputs): 65 | assert len(inputs) == len(self.in_channels) 66 | inputs = [ 67 | lateral_conv(inputs[i]) 68 | for i, lateral_conv in enumerate(self.lateral_convs) 69 | ] 70 | # for len(inputs) not equal to self.num_outs 71 | if len(inputs) == 1: 72 | inputs = [inputs[0] for _ in range(self.num_outs)] 73 | outs = [] 74 | for i in range(self.num_outs): 75 | x_resize = resize( 76 | inputs[i], scale_factor=self.scales[i], mode='bilinear') 77 | outs.append(self.convs[i](x_resize)) 78 | return tuple(outs) 79 | -------------------------------------------------------------------------------- /marinext/mmseg/utils/util_distribution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | import torch 4 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 5 | 6 | from mmseg import digit_version 7 | 8 | dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel} 9 | 10 | ddp_factory = {'cuda': MMDistributedDataParallel} 11 | 12 | 13 | def build_dp(model, device='cuda', dim=0, *args, **kwargs): 14 | """build DataParallel module by device type. 15 | 16 | if device is cuda, return a MMDataParallel module; if device is mlu, 17 | return a MLUDataParallel module. 18 | 19 | Args: 20 | model (:class:`nn.Module`): module to be parallelized. 21 | device (str): device type, cuda, cpu or mlu. Defaults to cuda. 22 | dim (int): Dimension used to scatter the data. Defaults to 0. 23 | 24 | Returns: 25 | :class:`nn.Module`: parallelized module. 26 | """ 27 | if device == 'cuda': 28 | model = model.cuda() 29 | elif device == 'mlu': 30 | assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ 31 | 'Please use MMCV >= 1.5.0 for MLU training!' 32 | from mmcv.device.mlu import MLUDataParallel 33 | dp_factory['mlu'] = MLUDataParallel 34 | model = model.mlu() 35 | 36 | return dp_factory[device](model, dim=dim, *args, **kwargs) 37 | 38 | 39 | def build_ddp(model, device='cuda', *args, **kwargs): 40 | """Build DistributedDataParallel module by device type. 41 | 42 | If device is cuda, return a MMDistributedDataParallel module; 43 | if device is mlu, return a MLUDistributedDataParallel module. 44 | 45 | Args: 46 | model (:class:`nn.Module`): module to be parallelized. 47 | device (str): device type, mlu or cuda. 48 | 49 | Returns: 50 | :class:`nn.Module`: parallelized module. 51 | 52 | References: 53 | .. 
[1] https://pytorch.org/docs/stable/generated/torch.nn.parallel. 54 | DistributedDataParallel.html 55 | """ 56 | assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.' 57 | if device == 'cuda': 58 | model = model.cuda() 59 | elif device == 'mlu': 60 | assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ 61 | 'Please use MMCV >= 1.5.0 for MLU training!' 62 | from mmcv.device.mlu import MLUDistributedDataParallel 63 | ddp_factory['mlu'] = MLUDistributedDataParallel 64 | model = model.mlu() 65 | 66 | return ddp_factory[device](model, *args, **kwargs) 67 | 68 | 69 | def is_mlu_available(): 70 | """Returns a bool indicating if MLU is currently available.""" 71 | return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() 72 | 73 | 74 | def get_device(): 75 | """Returns an available device, cpu, cuda or mlu.""" 76 | is_device_available = { 77 | 'cuda': torch.cuda.is_available(), 78 | 'mlu': is_mlu_available() 79 | } 80 | device_list = [k for k, v in is_device_available.items() if v] 81 | return device_list[0] if len(device_list) == 1 else 'cpu' 82 | -------------------------------------------------------------------------------- /marinext/mmseg/ops/encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class Encoding(nn.Module): 8 | """Encoding Layer: a learnable residual encoder. 9 | 10 | Input is of shape (batch_size, channels, height, width). 11 | Output is of shape (batch_size, num_codes, channels). 12 | 13 | Args: 14 | channels: dimension of the features or feature channels 15 | num_codes: number of code words 16 | """ 17 | 18 | def __init__(self, channels, num_codes): 19 | super(Encoding, self).__init__() 20 | # init codewords and smoothing factor 21 | self.channels, self.num_codes = channels, num_codes 22 | std = 1. 
/ ((num_codes * channels)**0.5) 23 | # [num_codes, channels] 24 | self.codewords = nn.Parameter( 25 | torch.empty(num_codes, channels, 26 | dtype=torch.float).uniform_(-std, std), 27 | requires_grad=True) 28 | # [num_codes] 29 | self.scale = nn.Parameter( 30 | torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), 31 | requires_grad=True) 32 | 33 | @staticmethod 34 | def scaled_l2(x, codewords, scale): 35 | num_codes, channels = codewords.size() 36 | batch_size = x.size(0) 37 | reshaped_scale = scale.view((1, 1, num_codes)) 38 | expanded_x = x.unsqueeze(2).expand( 39 | (batch_size, x.size(1), num_codes, channels)) 40 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 41 | 42 | scaled_l2_norm = reshaped_scale * ( 43 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 44 | return scaled_l2_norm 45 | 46 | @staticmethod 47 | def aggregate(assignment_weights, x, codewords): 48 | num_codes, channels = codewords.size() 49 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 50 | batch_size = x.size(0) 51 | 52 | expanded_x = x.unsqueeze(2).expand( 53 | (batch_size, x.size(1), num_codes, channels)) 54 | encoded_feat = (assignment_weights.unsqueeze(3) * 55 | (expanded_x - reshaped_codewords)).sum(dim=1) 56 | return encoded_feat 57 | 58 | def forward(self, x): 59 | assert x.dim() == 4 and x.size(1) == self.channels 60 | # [batch_size, channels, height, width] 61 | batch_size = x.size(0) 62 | # [batch_size, height x width, channels] 63 | x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() 64 | # assignment_weights: [batch_size, channels, num_codes] 65 | assignment_weights = F.softmax( 66 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 67 | # aggregate 68 | encoded_feat = self.aggregate(assignment_weights, x, self.codewords) 69 | return encoded_feat 70 | 71 | def __repr__(self): 72 | repr_str = self.__class__.__name__ 73 | repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ 74 | f'x{self.channels})' 75 | return repr_str 76 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/setr_up_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, build_norm_layer 4 | 5 | from mmseg.ops import Upsample 6 | from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | @HEADS.register_module() 11 | class SETRUPHead(BaseDecodeHead): 12 | """Naive upsampling head and Progressive upsampling head of SETR. 13 | 14 | Naive or PUP head of `SETR `_. 15 | 16 | Args: 17 | norm_layer (dict): Config dict for input normalization. 18 | Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). 19 | num_convs (int): Number of decoder convolutions. Default: 1. 20 | up_scale (int): The scale factor of interpolate. Default:4. 21 | kernel_size (int): The kernel size of convolution when decoding 22 | feature information from backbone. Default: 3. 23 | init_cfg (dict | list[dict] | None): Initialization config dict. 24 | Default: dict( 25 | type='Constant', val=1.0, bias=0, layer='LayerNorm'). 
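The math inside `Encoding` above, replayed step by step on random tensors so the shapes stay visible (all sizes arbitrary):

```python
import torch
import torch.nn.functional as F

B, C, K, HW = 2, 8, 4, 25          # batch, channels, codewords, pixels
x = torch.randn(B, HW, C)          # features flattened to [N, HxW, C]
codewords = torch.randn(K, C)      # learnable in the real layer
scale = torch.empty(K).uniform_(-1, 0)

residual = x.unsqueeze(2) - codewords.view(1, 1, K, C)   # [B, HW, K, C]
logits = scale.view(1, 1, K) * residual.pow(2).sum(-1)   # scaled L2: [B, HW, K]
assign = F.softmax(logits, dim=2)                        # soft assignment
encoded = (assign.unsqueeze(3) * residual).sum(1)        # aggregate residuals
print(encoded.shape)  # torch.Size([2, 4, 8]) = [N, num_codes, channels]
```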
26 | """ 27 | 28 | def __init__(self, 29 | norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), 30 | num_convs=1, 31 | up_scale=4, 32 | kernel_size=3, 33 | init_cfg=[ 34 | dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'), 35 | dict( 36 | type='Normal', 37 | std=0.01, 38 | override=dict(name='conv_seg')) 39 | ], 40 | **kwargs): 41 | 42 | assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' 43 | 44 | super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs) 45 | 46 | assert isinstance(self.in_channels, int) 47 | 48 | _, self.norm = build_norm_layer(norm_layer, self.in_channels) 49 | 50 | self.up_convs = nn.ModuleList() 51 | in_channels = self.in_channels 52 | out_channels = self.channels 53 | for _ in range(num_convs): 54 | self.up_convs.append( 55 | nn.Sequential( 56 | ConvModule( 57 | in_channels=in_channels, 58 | out_channels=out_channels, 59 | kernel_size=kernel_size, 60 | stride=1, 61 | padding=int(kernel_size - 1) // 2, 62 | norm_cfg=self.norm_cfg, 63 | act_cfg=self.act_cfg), 64 | Upsample( 65 | scale_factor=up_scale, 66 | mode='bilinear', 67 | align_corners=self.align_corners))) 68 | in_channels = out_channels 69 | 70 | def forward(self, x): 71 | x = self._transform_inputs(x) 72 | 73 | n, c, h, w = x.shape 74 | x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() 75 | x = self.norm(x) 76 | x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() 77 | 78 | for up_conv in self.up_convs: 79 | x = up_conv(x) 80 | out = self.cls_seg(x) 81 | return out 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![MADOS Logo](./.github/MADOS_LOGO_text.png) 2 | 3 | 4 | [[`paper`](https://www.sciencedirect.com/science/article/pii/S0924271624000625)][[`project page`](https://marine-pollution.github.io/)][[`dataset`](https://zenodo.org/records/10664073)] 5 | 6 | Marine Debris and Oil Spill (MADOS) is a marine pollution dataset based on Sentinel-2 remote sensing data, focusing on marine litter and oil spills. Other sea surface features that coexist with or have been suggested to be spectrally similar to them have also been considered. MADOS formulates a challenging semantic segmentation task using sparse annotations. 7 | 8 | In order to download MADOS go to https://doi.org/10.5281/zenodo.10664073. 
9 | 10 | ## Installation 11 | 12 | ```bash 13 | conda create -n mados python=3.8.12 14 | 15 | conda activate mados 16 | 17 | conda install -c conda-forge gdal==3.3.2 18 | 19 | pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11/index.html 20 | 21 | conda install pytables==3.7.0 22 | ``` 23 | 24 | ## Evaluate MariNeXt 25 | 26 | To evaluate MariNeXt, place the MADOS dataset under the `data` folder, download the pretrained models (5 different runs) from [here](https://drive.google.com/drive/folders/1VwkFp47TEvRVXHNbucBmmylfZwIUmCWx?usp=drive_link), place them under the `marinext/trained_models` folder, and then run the following: 27 | 28 | ```bash 29 | python marinext/evaluation.py --path ./data/MADOS --model_path marinext/trained_models/1 30 | ``` 31 | 32 | ## Train MariNeXt 33 | 34 | To train MariNeXt from scratch, run the following: 35 | 36 | 37 | ```bash 38 | python marinext/train.py --path ./data/MADOS 39 | ``` 40 | 41 | ## Stack Patches 42 | 43 | To stack the image patches to form multispectral images, run the following: 44 | 45 | 46 | ```bash 47 | python utils/stack_patches.py --path ./data/MADOS 48 | ``` 49 | 50 | ## Spectral Signatures Extraction 51 | To extract the spectral signatures of the MADOS dataset (after stacking) and store them in an HDF5 Table file (DataFrame-like), run the following: 52 | 53 | ```bash 54 | python utils/spectral_extraction.py --path ./data/MADOS_nearest 55 | ``` 56 | 57 | Alternatively, you can download the `dataset.h5` file from [here](https://drive.google.com/file/d/1BUIxcm1SLU9sqr8NE2FKJvJJPv2RLyk-/view?usp=sharing). 58 | 59 | To load the `dataset.h5`, run the following in a Python cell: 60 | 61 | ```python 62 | import pandas as pd 63 | 64 | hdf = pd.HDFStore('./data/dataset.h5', mode='r') 65 | 66 | df_train = hdf.select('Train') 67 | df_val = hdf.select('Validation') 68 | df_test = hdf.select('Test') 69 | 70 | hdf.close() 71 | ``` 72 | 73 | 74 | ## Acknowledgment 75 | 76 | This implementation is mainly based on [MARIDA](https://github.com/marine-debris/marine-debris.github.io), [SegNeXt](https://github.com/Visual-Attention-Network/SegNeXt), [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/v0.24.1), [Segformer](https://github.com/NVlabs/SegFormer) and [Enjoy-Hamburger](https://github.com/Gsunshine/Enjoy-Hamburger). 77 | 78 | 79 | 80 | If you find this repository useful, please consider giving a star :star: and citation: 81 | > Kikaki K., Kakogeorgiou I., Hoteit I., Karantzalos K. Detecting Marine Pollutants and Sea Surface Features with Deep Learning in Sentinel-2 Imagery. ISPRS Journal of Photogrammetry and Remote Sensing, 2024. 82 | 83 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/lraspp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv import is_tuple_of 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | @HEADS.register_module() 13 | class LRASPPHead(BaseDecodeHead): 14 | """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. 15 | 16 | This head is the improved implementation of `Searching for MobileNetV3 17 | <https://arxiv.org/abs/1905.02244>`_.
18 | 19 | Args: 20 | branch_channels (tuple[int]): The number of output channels in every 21 | each branch. Default: (32, 64). 22 | """ 23 | 24 | def __init__(self, branch_channels=(32, 64), **kwargs): 25 | super(LRASPPHead, self).__init__(**kwargs) 26 | if self.input_transform != 'multiple_select': 27 | raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' 28 | f'must be \'multiple_select\'. But received ' 29 | f'\'{self.input_transform}\'') 30 | assert is_tuple_of(branch_channels, int) 31 | assert len(branch_channels) == len(self.in_channels) - 1 32 | self.branch_channels = branch_channels 33 | 34 | self.convs = nn.Sequential() 35 | self.conv_ups = nn.Sequential() 36 | for i in range(len(branch_channels)): 37 | self.convs.add_module( 38 | f'conv{i}', 39 | nn.Conv2d( 40 | self.in_channels[i], branch_channels[i], 1, bias=False)) 41 | self.conv_ups.add_module( 42 | f'conv_up{i}', 43 | ConvModule( 44 | self.channels + branch_channels[i], 45 | self.channels, 46 | 1, 47 | norm_cfg=self.norm_cfg, 48 | act_cfg=self.act_cfg, 49 | bias=False)) 50 | 51 | self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) 52 | 53 | self.aspp_conv = ConvModule( 54 | self.in_channels[-1], 55 | self.channels, 56 | 1, 57 | norm_cfg=self.norm_cfg, 58 | act_cfg=self.act_cfg, 59 | bias=False) 60 | self.image_pool = nn.Sequential( 61 | nn.AvgPool2d(kernel_size=49, stride=(16, 20)), 62 | ConvModule( 63 | self.in_channels[2], 64 | self.channels, 65 | 1, 66 | act_cfg=dict(type='Sigmoid'), 67 | bias=False)) 68 | 69 | def forward(self, inputs): 70 | """Forward function.""" 71 | inputs = self._transform_inputs(inputs) 72 | 73 | x = inputs[-1] 74 | 75 | x = self.aspp_conv(x) * resize( 76 | self.image_pool(x), 77 | size=x.size()[2:], 78 | mode='bilinear', 79 | align_corners=self.align_corners) 80 | x = self.conv_up_input(x) 81 | 82 | for i in range(len(self.branch_channels) - 1, -1, -1): 83 | x = resize( 84 | x, 85 | size=inputs[i].size()[2:], 86 | mode='bilinear', 87 | align_corners=self.align_corners) 88 | x = torch.cat([x, self.convs[i](inputs[i])], 1) 89 | x = self.conv_ups[i](x) 90 | 91 | return self.cls_seg(x) 92 | -------------------------------------------------------------------------------- /utils/assets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Author: Ioannis Kakogeorgiou 4 | Email: gkakogeorgiou@gmail.com 5 | Python Version: 3.7.10 6 | Description: assets.py includes the appropriate mappings. 
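The heart of `LRASPPHead` above is the sigmoid-gated multiply in `forward`: a pooled, sigmoid-activated branch rescales the 1x1-projected last feature map. A stripped-down sketch under simplified assumptions (global pooling instead of the head's `AvgPool2d(49, stride=(16, 20))`, and made-up channel sizes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 960, 30, 40)               # last backbone feature map
project = nn.Conv2d(960, 128, 1, bias=False)  # the 1x1 aspp_conv projection
gate = nn.Sequential(                         # pooled, sigmoid-activated branch
    nn.AvgPool2d(kernel_size=(30, 40)),
    nn.Conv2d(960, 128, 1, bias=False),
    nn.Sigmoid())

g = F.interpolate(gate(x), size=x.shape[2:], mode='bilinear',
                  align_corners=False)        # broadcast gate back to full size
out = project(x) * g                          # attention-weighted features
print(out.shape)  # torch.Size([1, 128, 30, 40])
```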
7 | ''' 8 | import argparse 9 | import numpy as np 10 | mados_cat_mapping = {'Marine Debris': 1, 11 | 'Dense Sargassum': 2, 12 | 'Sparse Floating Algae': 3, 13 | 'Natural Organic Material': 4, 14 | 'Ship': 5, 15 | 'Oil Spill': 6, 16 | 'Marine Water': 7, 17 | 'Sediment-Laden Water': 8, 18 | 'Foam': 9, 19 | 'Turbid Water': 10, 20 | 'Shallow Water': 11, 21 | 'Waves & Wakes': 12, 22 | 'Oil Platform': 13, 23 | 'Jellyfish': 14, 24 | 'Sea snot': 15} 25 | 26 | mados_color_mapping = { 'Marine Debris': 'red', 27 | 'Dense Sargassum': 'green', 28 | 'Sparse Floating Algae': 'limegreen', 29 | 'Natural Organic Material': 'brown', 30 | 'Ship': 'orange', 31 | 'Oil Spill': 'thistle', 32 | 'Marine Water': 'navy', 33 | 'Sediment-Laden Water': 'gold', 34 | 'Foam': 'purple', 35 | 'Turbid Water': 'darkkhaki', 36 | 'Shallow Water': 'darkturquoise', 37 | 'Waves & Wakes': 'bisque', 38 | 'Oil Platform': 'dimgrey', 39 | 'Jellyfish': 'hotpink', 40 | 'Sea snot': 'yellow'} 41 | 42 | labels = ['Marine Debris', 'Dense Sargassum', 'Sparse Floating Algae', 'Natural Organic Material', 43 | 'Ship', 'Oil Spill', 'Marine Water', 'Sediment-Laden Water', 'Foam', 44 | 'Turbid Water', 'Shallow Water', 'Waves & Wakes', 'Oil Platform', 'Jellyfish', 'Sea snot'] 45 | 46 | s2_mapping = {'nm440': 0, 47 | 'nm490': 1, 48 | 'nm560': 2, 49 | 'nm665': 3, 50 | 'nm705': 4, 51 | 'nm740': 5, 52 | 'nm783': 6, 53 | 'nm842': 7, 54 | 'nm865': 8, 55 | 'nm1600': 9, 56 | 'nm2200': 10, 57 | 'Class': 11, 58 | 'Confidence': 12, 59 | 'Report': 13} 60 | 61 | conf_mapping = {'High': 1, 62 | 'Moderate': 2, 63 | 'Low': 3} 64 | 65 | report_mapping = {'Very close': 1, 66 | 'Away': 2, 67 | 'No': 3} 68 | 69 | def cat_map(x): 70 | return mados_cat_mapping[x] 71 | 72 | cat_mapping_vec = np.vectorize(cat_map) 73 | 74 | def bool_flag(s): 75 | """ 76 | Parse boolean arguments from the command line. 77 | """ 78 | FALSY_STRINGS = {"off", "false", "0"} 79 | TRUTHY_STRINGS = {"on", "true", "1"} 80 | if s.lower() in FALSY_STRINGS: 81 | return False 82 | elif s.lower() in TRUTHY_STRINGS: 83 | return True 84 | else: 85 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 86 | 87 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 88 | warmup_schedule = np.array([]) 89 | warmup_iters = warmup_epochs * niter_per_ep 90 | if warmup_epochs > 0: 91 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 92 | 93 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 94 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 95 | 96 | schedule = np.concatenate((warmup_schedule, schedule)) 97 | assert len(schedule) == epochs * niter_per_ep 98 | return schedule -------------------------------------------------------------------------------- /marinext/mmseg/models/segmentors/cascade_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from torch import nn 3 | 4 | from mmseg.core import add_prefix 5 | from mmseg.ops import resize 6 | from .. import builder 7 | from ..builder import SEGMENTORS 8 | from .encoder_decoder import EncoderDecoder 9 | 10 | 11 | @SEGMENTORS.register_module() 12 | class CascadeEncoderDecoder(EncoderDecoder): 13 | """Cascade Encoder Decoder segmentors. 14 | 15 | CascadeEncoderDecoder is almost the same as EncoderDecoder, while decoders of 16 | CascadeEncoderDecoder are cascaded.
The output of previous decoder_head 17 | will be the input of next decoder_head. 18 | """ 19 | 20 | def __init__(self, 21 | num_stages, 22 | backbone, 23 | decode_head, 24 | neck=None, 25 | auxiliary_head=None, 26 | train_cfg=None, 27 | test_cfg=None, 28 | pretrained=None, 29 | init_cfg=None): 30 | self.num_stages = num_stages 31 | super(CascadeEncoderDecoder, self).__init__( 32 | backbone=backbone, 33 | decode_head=decode_head, 34 | neck=neck, 35 | auxiliary_head=auxiliary_head, 36 | train_cfg=train_cfg, 37 | test_cfg=test_cfg, 38 | pretrained=pretrained, 39 | init_cfg=init_cfg) 40 | 41 | def _init_decode_head(self, decode_head): 42 | """Initialize ``decode_head``""" 43 | assert isinstance(decode_head, list) 44 | assert len(decode_head) == self.num_stages 45 | self.decode_head = nn.ModuleList() 46 | for i in range(self.num_stages): 47 | self.decode_head.append(builder.build_head(decode_head[i])) 48 | self.align_corners = self.decode_head[-1].align_corners 49 | self.num_classes = self.decode_head[-1].num_classes 50 | 51 | def encode_decode(self, img, img_metas): 52 | """Encode images with backbone and decode into a semantic segmentation 53 | map of the same size as input.""" 54 | x = self.extract_feat(img) 55 | out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg) 56 | for i in range(1, self.num_stages): 57 | out = self.decode_head[i].forward_test(x, out, img_metas, 58 | self.test_cfg) 59 | out = resize( 60 | input=out, 61 | size=img.shape[2:], 62 | mode='bilinear', 63 | align_corners=self.align_corners) 64 | return out 65 | 66 | def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): 67 | """Run forward function and calculate loss for decode head in 68 | training.""" 69 | losses = dict() 70 | 71 | loss_decode = self.decode_head[0].forward_train( 72 | x, img_metas, gt_semantic_seg, self.train_cfg) 73 | 74 | losses.update(add_prefix(loss_decode, 'decode_0')) 75 | 76 | for i in range(1, self.num_stages): 77 | # forward test again, maybe unnecessary for most methods. 78 | if i == 1: 79 | prev_outputs = self.decode_head[0].forward_test( 80 | x, img_metas, self.test_cfg) 81 | else: 82 | prev_outputs = self.decode_head[i - 1].forward_test( 83 | x, prev_outputs, img_metas, self.test_cfg) 84 | loss_decode = self.decode_head[i].forward_train( 85 | x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) 86 | losses.update(add_prefix(loss_decode, f'decode_{i}')) 87 | 88 | return losses 89 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/fcn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | @HEADS.register_module() 11 | class FCNHead(BaseDecodeHead): 12 | """Fully Convolution Networks for Semantic Segmentation. 13 | 14 | This head is implemented of `FCNNet `_. 15 | 16 | Args: 17 | num_convs (int): Number of convs in the head. Default: 2. 18 | kernel_size (int): The kernel size for convs in the head. Default: 3. 19 | concat_input (bool): Whether concat the input and output of convs 20 | before classification layer. 21 | dilation (int): The dilation rate for convs in the head. Default: 1. 
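How the cascade in `CascadeEncoderDecoder.encode_decode` threads one head's logits into the next, shown with two hypothetical toy heads (stand-ins for illustration only, not the repo's actual decode heads):

```python
import torch
import torch.nn as nn

class ToyRefineHead(nn.Module):
    """Takes backbone features plus the previous head's logits."""
    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_ch + num_classes, num_classes, 3, padding=1)
    def forward(self, feats, prev_logits):
        return self.conv(torch.cat([feats, prev_logits], dim=1))

feats = torch.randn(1, 32, 60, 60)   # shared backbone features
first = nn.Conv2d(32, 5, 1)          # stage-0 head: features -> logits
refine = ToyRefineHead(32, 5)        # stage-1 head refines stage-0 output

out = first(feats)
out = refine(feats, out)             # cascaded refinement, as in encode_decode
print(out.shape)  # torch.Size([1, 5, 60, 60])
```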
22 | """ 23 | 24 | def __init__(self, 25 | num_convs=2, 26 | kernel_size=3, 27 | concat_input=True, 28 | dilation=1, 29 | **kwargs): 30 | assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) 31 | self.num_convs = num_convs 32 | self.concat_input = concat_input 33 | self.kernel_size = kernel_size 34 | super(FCNHead, self).__init__(**kwargs) 35 | if num_convs == 0: 36 | assert self.in_channels == self.channels 37 | 38 | conv_padding = (kernel_size // 2) * dilation 39 | convs = [] 40 | convs.append( 41 | ConvModule( 42 | self.in_channels, 43 | self.channels, 44 | kernel_size=kernel_size, 45 | padding=conv_padding, 46 | dilation=dilation, 47 | conv_cfg=self.conv_cfg, 48 | norm_cfg=self.norm_cfg, 49 | act_cfg=self.act_cfg)) 50 | for i in range(num_convs - 1): 51 | convs.append( 52 | ConvModule( 53 | self.channels, 54 | self.channels, 55 | kernel_size=kernel_size, 56 | padding=conv_padding, 57 | dilation=dilation, 58 | conv_cfg=self.conv_cfg, 59 | norm_cfg=self.norm_cfg, 60 | act_cfg=self.act_cfg)) 61 | if num_convs == 0: 62 | self.convs = nn.Identity() 63 | else: 64 | self.convs = nn.Sequential(*convs) 65 | if self.concat_input: 66 | self.conv_cat = ConvModule( 67 | self.in_channels + self.channels, 68 | self.channels, 69 | kernel_size=kernel_size, 70 | padding=kernel_size // 2, 71 | conv_cfg=self.conv_cfg, 72 | norm_cfg=self.norm_cfg, 73 | act_cfg=self.act_cfg) 74 | 75 | def _forward_feature(self, inputs): 76 | """Forward function for feature maps before classifying each pixel with 77 | ``self.cls_seg`` fc. 78 | 79 | Args: 80 | inputs (list[Tensor]): List of multi-level img features. 81 | 82 | Returns: 83 | feats (Tensor): A tensor of shape (batch_size, self.channels, 84 | H, W) which is feature map for last layer of decoder head. 85 | """ 86 | x = self._transform_inputs(inputs) 87 | feats = self.convs(x) 88 | if self.concat_input: 89 | feats = self.conv_cat(torch.cat([x, feats], dim=1)) 90 | return feats 91 | 92 | def forward(self, inputs): 93 | """Forward function.""" 94 | output = self._forward_feature(inputs) 95 | output = self.cls_seg(output) 96 | return output 97 | -------------------------------------------------------------------------------- /marinext/mmseg/models/utils/res_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import build_conv_layer, build_norm_layer 3 | from mmcv.runner import Sequential 4 | from torch import nn as nn 5 | 6 | 7 | class ResLayer(Sequential): 8 | """ResLayer to build ResNet style backbone. 9 | 10 | Args: 11 | block (nn.Module): block used to build ResLayer. 12 | inplanes (int): inplanes of block. 13 | planes (int): planes of block. 14 | num_blocks (int): number of blocks. 15 | stride (int): stride of the first block. Default: 1 16 | avg_down (bool): Use AvgPool instead of stride conv when 17 | downsampling in the bottleneck. Default: False 18 | conv_cfg (dict): dictionary to construct and config conv layer. 19 | Default: None 20 | norm_cfg (dict): dictionary to construct and config norm layer. 21 | Default: dict(type='BN') 22 | multi_grid (int | None): Multi grid dilation rates of last 23 | stage. 
Default: None 24 | contract_dilation (bool): Whether contract first dilation of each layer 25 | Default: False 26 | """ 27 | 28 | def __init__(self, 29 | block, 30 | inplanes, 31 | planes, 32 | num_blocks, 33 | stride=1, 34 | dilation=1, 35 | avg_down=False, 36 | conv_cfg=None, 37 | norm_cfg=dict(type='BN'), 38 | multi_grid=None, 39 | contract_dilation=False, 40 | **kwargs): 41 | self.block = block 42 | 43 | downsample = None 44 | if stride != 1 or inplanes != planes * block.expansion: 45 | downsample = [] 46 | conv_stride = stride 47 | if avg_down: 48 | conv_stride = 1 49 | downsample.append( 50 | nn.AvgPool2d( 51 | kernel_size=stride, 52 | stride=stride, 53 | ceil_mode=True, 54 | count_include_pad=False)) 55 | downsample.extend([ 56 | build_conv_layer( 57 | conv_cfg, 58 | inplanes, 59 | planes * block.expansion, 60 | kernel_size=1, 61 | stride=conv_stride, 62 | bias=False), 63 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 64 | ]) 65 | downsample = nn.Sequential(*downsample) 66 | 67 | layers = [] 68 | if multi_grid is None: 69 | if dilation > 1 and contract_dilation: 70 | first_dilation = dilation // 2 71 | else: 72 | first_dilation = dilation 73 | else: 74 | first_dilation = multi_grid[0] 75 | layers.append( 76 | block( 77 | inplanes=inplanes, 78 | planes=planes, 79 | stride=stride, 80 | dilation=first_dilation, 81 | downsample=downsample, 82 | conv_cfg=conv_cfg, 83 | norm_cfg=norm_cfg, 84 | **kwargs)) 85 | inplanes = planes * block.expansion 86 | for i in range(1, num_blocks): 87 | layers.append( 88 | block( 89 | inplanes=inplanes, 90 | planes=planes, 91 | stride=1, 92 | dilation=dilation if multi_grid is None else multi_grid[i], 93 | conv_cfg=conv_cfg, 94 | norm_cfg=norm_cfg, 95 | **kwargs)) 96 | super(ResLayer, self).__init__(*layers) 97 | -------------------------------------------------------------------------------- /marinext/mmseg/core/seg/sampler/ohem_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ..builder import PIXEL_SAMPLERS 7 | from .base_pixel_sampler import BasePixelSampler 8 | 9 | 10 | @PIXEL_SAMPLERS.register_module() 11 | class OHEMPixelSampler(BasePixelSampler): 12 | """Online Hard Example Mining Sampler for segmentation. 13 | 14 | Args: 15 | context (nn.Module): The context of sampler, subclass of 16 | :obj:`BaseDecodeHead`. 17 | thresh (float, optional): The threshold for hard example selection. 18 | Below which, are prediction with low confidence. If not 19 | specified, the hard examples will be pixels of top ``min_kept`` 20 | loss. Default: None. 21 | min_kept (int, optional): The minimum number of predictions to keep. 22 | Default: 100000. 23 | """ 24 | 25 | def __init__(self, context, thresh=None, min_kept=100000): 26 | super(OHEMPixelSampler, self).__init__() 27 | self.context = context 28 | assert min_kept > 1 29 | self.thresh = thresh 30 | self.min_kept = min_kept 31 | 32 | def sample(self, seg_logit, seg_label): 33 | """Sample pixels that have high loss or with low prediction confidence. 
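A self-contained replay of `OHEMPixelSampler`'s threshold branch: a pixel is kept when its predicted probability for the ground-truth class falls below max(thresh, the min_kept-th smallest such probability), so at least min_kept hard pixels survive. Toy sizes and an arbitrary seed:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 4, 8, 8)           # [N, C, H, W] segmentation logits
label = torch.randint(0, 4, (1, 8, 8))     # [N, H, W] ground-truth classes
thresh, min_kept = 0.7, 16

# probability assigned to the true class of each pixel
prob = F.softmax(logits, dim=1).gather(1, label.unsqueeze(1)).squeeze(1)
sort_prob, _ = prob.flatten().sort()
min_threshold = sort_prob[min(min_kept, sort_prob.numel() - 1)]
threshold = max(min_threshold, thresh)

seg_weight = (prob < threshold).float()    # 1.0 = hard pixel, kept in the loss
print(int(seg_weight.sum()), 'of', seg_weight.numel(), 'pixels kept')
```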
34 | 35 | Args: 36 | seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) 37 | seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) 38 | 39 | Returns: 40 | torch.Tensor: segmentation weight, shape (N, H, W) 41 | """ 42 | with torch.no_grad(): 43 | assert seg_logit.shape[2:] == seg_label.shape[2:] 44 | assert seg_label.shape[1] == 1 45 | seg_label = seg_label.squeeze(1).long() 46 | batch_kept = self.min_kept * seg_label.size(0) 47 | valid_mask = seg_label != self.context.ignore_index 48 | seg_weight = seg_logit.new_zeros(size=seg_label.size()) 49 | valid_seg_weight = seg_weight[valid_mask] 50 | if self.thresh is not None: 51 | seg_prob = F.softmax(seg_logit, dim=1) 52 | 53 | tmp_seg_label = seg_label.clone().unsqueeze(1) 54 | tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 55 | seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) 56 | sort_prob, sort_indices = seg_prob[valid_mask].sort() 57 | 58 | if sort_prob.numel() > 0: 59 | min_threshold = sort_prob[min(batch_kept, 60 | sort_prob.numel() - 1)] 61 | else: 62 | min_threshold = 0.0 63 | threshold = max(min_threshold, self.thresh) 64 | valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 65 | else: 66 | if not isinstance(self.context.loss_decode, nn.ModuleList): 67 | losses_decode = [self.context.loss_decode] 68 | else: 69 | losses_decode = self.context.loss_decode 70 | losses = 0.0 71 | for loss_module in losses_decode: 72 | losses += loss_module( 73 | seg_logit, 74 | seg_label, 75 | weight=None, 76 | ignore_index=self.context.ignore_index, 77 | reduction_override='none') 78 | 79 | # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa 80 | _, sort_indices = losses[valid_mask].sort(descending=True) 81 | valid_seg_weight[sort_indices[:batch_kept]] = 1. 82 | 83 | seg_weight[valid_mask] = valid_seg_weight 84 | 85 | return seg_weight 86 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/stdc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from ..builder import HEADS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @HEADS.register_module() 10 | class STDCHead(FCNHead): 11 | """This head is the implementation of `Rethinking BiSeNet For Real-time 12 | Semantic Segmentation `_. 13 | 14 | Args: 15 | boundary_threshold (float): The threshold of calculating boundary. 16 | Default: 0.1. 17 | """ 18 | 19 | def __init__(self, boundary_threshold=0.1, **kwargs): 20 | super(STDCHead, self).__init__(**kwargs) 21 | self.boundary_threshold = boundary_threshold 22 | # Using register buffer to make laplacian kernel on the same 23 | # device of `seg_label`. 24 | self.register_buffer( 25 | 'laplacian_kernel', 26 | torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], 27 | dtype=torch.float32, 28 | requires_grad=False).reshape((1, 1, 3, 3))) 29 | self.fusion_kernel = torch.nn.Parameter( 30 | torch.tensor([[6. / 10], [3. / 10], [1. / 10]], 31 | dtype=torch.float32).reshape(1, 3, 1, 1), 32 | requires_grad=False) 33 | 34 | def losses(self, seg_logit, seg_label): 35 | """Compute Detail Aggregation Loss.""" 36 | # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv 37 | # parameters. However, it is a constant in original repo and other 38 | # codebase because it would not be added into computation graph 39 | # after threshold operation. 
40 | seg_label = seg_label.to(self.laplacian_kernel) 41 | boundary_targets = F.conv2d( 42 | seg_label, self.laplacian_kernel, padding=1) 43 | boundary_targets = boundary_targets.clamp(min=0) 44 | boundary_targets[boundary_targets > self.boundary_threshold] = 1 45 | boundary_targets[boundary_targets <= self.boundary_threshold] = 0 46 | 47 | boundary_targets_x2 = F.conv2d( 48 | seg_label, self.laplacian_kernel, stride=2, padding=1) 49 | boundary_targets_x2 = boundary_targets_x2.clamp(min=0) 50 | 51 | boundary_targets_x4 = F.conv2d( 52 | seg_label, self.laplacian_kernel, stride=4, padding=1) 53 | boundary_targets_x4 = boundary_targets_x4.clamp(min=0) 54 | 55 | boundary_targets_x4_up = F.interpolate( 56 | boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') 57 | boundary_targets_x2_up = F.interpolate( 58 | boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') 59 | 60 | boundary_targets_x2_up[ 61 | boundary_targets_x2_up > self.boundary_threshold] = 1 62 | boundary_targets_x2_up[ 63 | boundary_targets_x2_up <= self.boundary_threshold] = 0 64 | 65 | boundary_targets_x4_up[ 66 | boundary_targets_x4_up > self.boundary_threshold] = 1 67 | boundary_targets_x4_up[ 68 | boundary_targets_x4_up <= self.boundary_threshold] = 0 69 | 70 | boudary_targets_pyramids = torch.stack( 71 | (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), 72 | dim=1) 73 | 74 | boudary_targets_pyramids = boudary_targets_pyramids.squeeze(2) 75 | boudary_targets_pyramid = F.conv2d(boudary_targets_pyramids, 76 | self.fusion_kernel) 77 | 78 | boudary_targets_pyramid[ 79 | boudary_targets_pyramid > self.boundary_threshold] = 1 80 | boudary_targets_pyramid[ 81 | boudary_targets_pyramid <= self.boundary_threshold] = 0 82 | 83 | loss = super(STDCHead, self).losses(seg_logit, 84 | boudary_targets_pyramid.long()) 85 | return loss 86 | -------------------------------------------------------------------------------- /marinext/mmseg/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): 7 | """Calculate accuracy according to the prediction and target. 8 | 9 | Args: 10 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 11 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 12 | ignore_index (int | None): The label index to be ignored. Default: None 13 | topk (int | tuple[int], optional): If the predictions in ``topk`` 14 | matches the target, the predictions will be regarded as 15 | correct ones. Defaults to 1. 16 | thresh (float, optional): If not None, predictions with scores under 17 | this threshold are considered incorrect. Default to None. 18 | 19 | Returns: 20 | float | tuple[float]: If the input ``topk`` is a single integer, 21 | the function will return a single float as accuracy. If 22 | ``topk`` is a tuple containing multiple integers, the 23 | function will return a tuple containing accuracies of 24 | each ``topk`` number. 25 | """ 26 | assert isinstance(topk, (int, tuple)) 27 | if isinstance(topk, int): 28 | topk = (topk, ) 29 | return_single = True 30 | else: 31 | return_single = False 32 | 33 | maxk = max(topk) 34 | if pred.size(0) == 0: 35 | accu = [pred.new_tensor(0.) 
for i in range(len(topk))] 36 | return accu[0] if return_single else accu 37 | assert pred.ndim == target.ndim + 1 38 | assert pred.size(0) == target.size(0) 39 | assert maxk <= pred.size(1), \ 40 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 41 | pred_value, pred_label = pred.topk(maxk, dim=1) 42 | # transpose to shape (maxk, N, ...) 43 | pred_label = pred_label.transpose(0, 1) 44 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 45 | if thresh is not None: 46 | # Only prediction values larger than thresh are counted as correct 47 | correct = correct & (pred_value > thresh).t() 48 | if ignore_index is not None: 49 | correct = correct[:, target != ignore_index] 50 | res = [] 51 | eps = torch.finfo(torch.float32).eps 52 | for k in topk: 53 | # Avoid causing ZeroDivisionError when all pixels 54 | # of an image are ignored 55 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps 56 | if ignore_index is not None: 57 | total_num = target[target != ignore_index].numel() + eps 58 | else: 59 | total_num = target.numel() + eps 60 | res.append(correct_k.mul_(100.0 / total_num)) 61 | return res[0] if return_single else res 62 | 63 | 64 | class Accuracy(nn.Module): 65 | """Accuracy calculation module.""" 66 | 67 | def __init__(self, topk=(1, ), thresh=None, ignore_index=None): 68 | """Module to calculate the accuracy. 69 | 70 | Args: 71 | topk (tuple, optional): The criterion used to calculate the 72 | accuracy. Defaults to (1,). 73 | thresh (float, optional): If not None, predictions with scores 74 | under this threshold are considered incorrect. Defaults to None. 75 | """ 76 | super().__init__() 77 | self.topk = topk 78 | self.thresh = thresh 79 | self.ignore_index = ignore_index 80 | 81 | def forward(self, pred, target): 82 | """Forward function to calculate accuracy. 83 | 84 | Args: 85 | pred (torch.Tensor): Prediction of models. 86 | target (torch.Tensor): Target for each prediction. 87 | 88 | Returns: 89 | tuple[float]: The accuracies under different topk criterions. 90 | """ 91 | return accuracy(pred, target, self.topk, self.thresh, 92 | self.ignore_index) 93 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/sep_aspp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 5 | 6 | from mmseg.ops import resize 7 | from ..builder import HEADS 8 | from .aspp_head import ASPPHead, ASPPModule 9 | 10 | 11 | class DepthwiseSeparableASPPModule(ASPPModule): 12 | """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable 13 | conv.""" 14 | 15 | def __init__(self, **kwargs): 16 | super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) 17 | for i, dilation in enumerate(self.dilations): 18 | if dilation > 1: 19 | self[i] = DepthwiseSeparableConvModule( 20 | self.in_channels, 21 | self.channels, 22 | 3, 23 | dilation=dilation, 24 | padding=dilation, 25 | norm_cfg=self.norm_cfg, 26 | act_cfg=self.act_cfg) 27 | 28 | 29 | @HEADS.register_module() 30 | class DepthwiseSeparableASPPHead(ASPPHead): 31 | """Encoder-Decoder with Atrous Separable Convolution for Semantic Image 32 | Segmentation. 33 | 34 | This head is the implementation of `DeepLabV3+ 35 | <https://arxiv.org/abs/1802.02611>`_. 36 | 37 | Args: 38 | c1_in_channels (int): The input channels of c1 decoder. If it is 0, 39 | no decoder will be used.
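A tiny hand-checkable demo of the top-k logic that the `accuracy` helper above implements, redone with plain tensor ops (values hand-picked so the answer is easy to verify):

```python
import torch

pred = torch.tensor([[0.6, 0.3, 0.1],   # top-2 classes: {0, 1}
                     [0.2, 0.5, 0.3],   # top-2 classes: {1, 2}
                     [0.1, 0.2, 0.7]])  # top-2 classes: {2, 1}
target = torch.tensor([0, 2, 1])

for k in (1, 2):
    topk = pred.topk(k, dim=1).indices                  # [N, k] class indices
    correct = (topk == target.unsqueeze(1)).any(dim=1)  # hit anywhere in top-k
    print(f'top-{k} accuracy: {100.0 * correct.float().mean():.1f}%')
# top-1: 33.3% (only the first sample is right), top-2: 100.0%
```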
40 | c1_channels (int): The intermediate channels of c1 decoder. 41 | """ 42 | 43 | def __init__(self, c1_in_channels, c1_channels, **kwargs): 44 | super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) 45 | assert c1_in_channels >= 0 46 | self.aspp_modules = DepthwiseSeparableASPPModule( 47 | dilations=self.dilations, 48 | in_channels=self.in_channels, 49 | channels=self.channels, 50 | conv_cfg=self.conv_cfg, 51 | norm_cfg=self.norm_cfg, 52 | act_cfg=self.act_cfg) 53 | if c1_in_channels > 0: 54 | self.c1_bottleneck = ConvModule( 55 | c1_in_channels, 56 | c1_channels, 57 | 1, 58 | conv_cfg=self.conv_cfg, 59 | norm_cfg=self.norm_cfg, 60 | act_cfg=self.act_cfg) 61 | else: 62 | self.c1_bottleneck = None 63 | self.sep_bottleneck = nn.Sequential( 64 | DepthwiseSeparableConvModule( 65 | self.channels + c1_channels, 66 | self.channels, 67 | 3, 68 | padding=1, 69 | norm_cfg=self.norm_cfg, 70 | act_cfg=self.act_cfg), 71 | DepthwiseSeparableConvModule( 72 | self.channels, 73 | self.channels, 74 | 3, 75 | padding=1, 76 | norm_cfg=self.norm_cfg, 77 | act_cfg=self.act_cfg)) 78 | 79 | def forward(self, inputs): 80 | """Forward function.""" 81 | x = self._transform_inputs(inputs) 82 | aspp_outs = [ 83 | resize( 84 | self.image_pool(x), 85 | size=x.size()[2:], 86 | mode='bilinear', 87 | align_corners=self.align_corners) 88 | ] 89 | aspp_outs.extend(self.aspp_modules(x)) 90 | aspp_outs = torch.cat(aspp_outs, dim=1) 91 | output = self.bottleneck(aspp_outs) 92 | if self.c1_bottleneck is not None: 93 | c1_output = self.c1_bottleneck(inputs[0]) 94 | output = resize( 95 | input=output, 96 | size=c1_output.shape[2:], 97 | mode='bilinear', 98 | align_corners=self.align_corners) 99 | output = torch.cat([output, c1_output], dim=1) 100 | output = self.sep_bottleneck(output) 101 | output = self.cls_seg(output) 102 | return output 103 | -------------------------------------------------------------------------------- /marinext/mmseg/models/utils/shape_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def nlc_to_nchw(x, hw_shape): 3 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 4 | 5 | Args: 6 | x (Tensor): The input tensor of shape [N, L, C] before conversion. 7 | hw_shape (Sequence[int]): The height and width of output feature map. 8 | 9 | Returns: 10 | Tensor: The output tensor of shape [N, C, H, W] after conversion. 11 | """ 12 | H, W = hw_shape 13 | assert len(x.shape) == 3 14 | B, L, C = x.shape 15 | assert L == H * W, 'The seq_len doesn\'t match H, W' 16 | return x.transpose(1, 2).reshape(B, C, H, W) 17 | 18 | 19 | def nchw_to_nlc(x): 20 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 21 | 22 | Args: 23 | x (Tensor): The input tensor of shape [N, C, H, W] before conversion. 24 | 25 | Returns: 26 | Tensor: The output tensor of shape [N, L, C] after conversion. 27 | """ 28 | assert len(x.shape) == 4 29 | return x.flatten(2).transpose(1, 2).contiguous() 30 | 31 | 32 | def nchw2nlc2nchw(module, x, contiguous=False, **kwargs): 33 | """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the 34 | reshaped tensor as the input of `module`, and then convert the output of 35 | `module`, whose shape is 36 | 37 | [N, L, C], to [N, C, H, W]. 38 | 39 | Args: 40 | module (Callable): A callable object that takes a tensor 41 | with shape [N, L, C] as input. 42 | x (Tensor): The input tensor of shape [N, C, H, W].
43 | 44 | contiguous (Bool): Whether to make the tensor contiguous 45 | after each shape transform. 46 | 47 | Returns: 48 | Tensor: The output tensor of shape [N, C, H, W]. 49 | 50 | Example: 51 | >>> import torch 52 | >>> import torch.nn as nn 53 | >>> norm = nn.LayerNorm(4) 54 | >>> feature_map = torch.rand(4, 4, 5, 5) 55 | >>> output = nchw2nlc2nchw(norm, feature_map) 56 | """ 57 | B, C, H, W = x.shape 58 | if not contiguous: 59 | x = x.flatten(2).transpose(1, 2) 60 | x = module(x, **kwargs) 61 | x = x.transpose(1, 2).reshape(B, C, H, W) 62 | else: 63 | x = x.flatten(2).transpose(1, 2).contiguous() 64 | x = module(x, **kwargs) 65 | x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() 66 | return x 67 | 68 | 69 | def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs): 70 | """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the 71 | reshaped tensor as the input of `module`, and convert the output of 72 | `module`, whose shape is 73 | 74 | [N, C, H, W], to [N, L, C]. 75 | 76 | Args: 77 | module (Callable): A callable object that takes a tensor 78 | with shape [N, C, H, W] as input. 79 | x (Tensor): The input tensor of shape [N, L, C]. 80 | hw_shape (Sequence[int]): The height and width of the 81 | feature map with shape [N, C, H, W]. 82 | contiguous (Bool): Whether to make the tensor contiguous 83 | after each shape transform. 84 | 85 | Returns: 86 | Tensor: The output tensor of shape [N, L, C]. 87 | 88 | Example: 89 | >>> import torch 90 | >>> import torch.nn as nn 91 | >>> conv = nn.Conv2d(16, 16, 3, 1, 1) 92 | >>> feature_map = torch.rand(4, 25, 16) 93 | >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5)) 94 | """ 95 | H, W = hw_shape 96 | assert len(x.shape) == 3 97 | B, L, C = x.shape 98 | assert L == H * W, 'The seq_len doesn\'t match H, W' 99 | if not contiguous: 100 | x = x.transpose(1, 2).reshape(B, C, H, W) 101 | x = module(x, **kwargs) 102 | x = x.flatten(2).transpose(1, 2) 103 | else: 104 | x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() 105 | x = module(x, **kwargs) 106 | x = x.flatten(2).transpose(1, 2).contiguous() 107 | return x 108 | -------------------------------------------------------------------------------- /marinext/mmseg/models/necks/mla_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
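What the `shape_convert` helpers above do, as a round-trip demo echoing their docstring examples: tensors move between the [N, C, H, W] conv layout and the [N, L, C] token layout so LayerNorm-style sequence modules can be dropped into conv pipelines:

```python
import torch
import torch.nn as nn

x = torch.rand(4, 16, 5, 5)                      # [N, C, H, W]

tokens = x.flatten(2).transpose(1, 2)            # nchw_to_nlc: [4, 25, 16]
tokens = nn.LayerNorm(16)(tokens)                # a sequence-style module
y = tokens.transpose(1, 2).reshape(4, 16, 5, 5)  # nlc_to_nchw

assert y.shape == x.shape                        # lossless round trip
print(tokens.shape, '->', y.shape)
```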
2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, build_norm_layer 4 | 5 | from ..builder import NECKS 6 | 7 | 8 | class MLAModule(nn.Module): 9 | 10 | def __init__(self, 11 | in_channels=[1024, 1024, 1024, 1024], 12 | out_channels=256, 13 | norm_cfg=None, 14 | act_cfg=None): 15 | super(MLAModule, self).__init__() 16 | self.channel_proj = nn.ModuleList() 17 | for i in range(len(in_channels)): 18 | self.channel_proj.append( 19 | ConvModule( 20 | in_channels=in_channels[i], 21 | out_channels=out_channels, 22 | kernel_size=1, 23 | norm_cfg=norm_cfg, 24 | act_cfg=act_cfg)) 25 | self.feat_extract = nn.ModuleList() 26 | for i in range(len(in_channels)): 27 | self.feat_extract.append( 28 | ConvModule( 29 | in_channels=out_channels, 30 | out_channels=out_channels, 31 | kernel_size=3, 32 | padding=1, 33 | norm_cfg=norm_cfg, 34 | act_cfg=act_cfg)) 35 | 36 | def forward(self, inputs): 37 | 38 | # feat_list -> [p2, p3, p4, p5] 39 | feat_list = [] 40 | for x, conv in zip(inputs, self.channel_proj): 41 | feat_list.append(conv(x)) 42 | 43 | # feat_list -> [p5, p4, p3, p2] 44 | # mid_list -> [m5, m4, m3, m2] 45 | feat_list = feat_list[::-1] 46 | mid_list = [] 47 | for feat in feat_list: 48 | if len(mid_list) == 0: 49 | mid_list.append(feat) 50 | else: 51 | mid_list.append(mid_list[-1] + feat) 52 | 53 | # mid_list -> [m5, m4, m3, m2] 54 | # out_list -> [o2, o3, o4, o5] 55 | out_list = [] 56 | for mid, conv in zip(mid_list, self.feat_extract): 57 | out_list.append(conv(mid)) 58 | 59 | return tuple(out_list) 60 | 61 | 62 | @NECKS.register_module() 63 | class MLANeck(nn.Module): 64 | """Multi-level Feature Aggregation. 65 | 66 | This neck is `The Multi-level Feature Aggregation construction of 67 | SETR `_. 68 | 69 | 70 | Args: 71 | in_channels (List[int]): Number of input channels per scale. 72 | out_channels (int): Number of output channels (used at each scale). 73 | norm_layer (dict): Config dict for input normalization. 74 | Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). 75 | norm_cfg (dict): Config dict for normalization layer. Default: None. 76 | act_cfg (dict): Config dict for activation layer in ConvModule. 77 | Default: None. 78 | """ 79 | 80 | def __init__(self, 81 | in_channels, 82 | out_channels, 83 | norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), 84 | norm_cfg=None, 85 | act_cfg=None): 86 | super(MLANeck, self).__init__() 87 | assert isinstance(in_channels, list) 88 | self.in_channels = in_channels 89 | self.out_channels = out_channels 90 | 91 | # In order to build general vision transformer backbone, we have to 92 | # move MLA to neck. 
93 | self.norm = nn.ModuleList([ 94 | build_norm_layer(norm_layer, in_channels[i])[1] 95 | for i in range(len(in_channels)) 96 | ]) 97 | 98 | self.mla = MLAModule( 99 | in_channels=in_channels, 100 | out_channels=out_channels, 101 | norm_cfg=norm_cfg, 102 | act_cfg=act_cfg) 103 | 104 | def forward(self, inputs): 105 | assert len(inputs) == len(self.in_channels) 106 | 107 | # Convert from nchw to nlc 108 | outs = [] 109 | for i in range(len(inputs)): 110 | x = inputs[i] 111 | n, c, h, w = x.shape 112 | x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() 113 | x = self.norm[i](x) 114 | x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() 115 | outs.append(x) 116 | 117 | outs = self.mla(outs) 118 | return tuple(outs) 119 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/psp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.ops import resize 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | class PPM(nn.ModuleList): 12 | """Pooling Pyramid Module used in PSPNet. 13 | 14 | Args: 15 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 16 | Module. 17 | in_channels (int): Input channels. 18 | channels (int): Channels after modules, before conv_seg. 19 | conv_cfg (dict|None): Config of conv layers. 20 | norm_cfg (dict|None): Config of norm layers. 21 | act_cfg (dict): Config of activation layers. 22 | align_corners (bool): align_corners argument of F.interpolate. 23 | """ 24 | 25 | def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, 26 | act_cfg, align_corners, **kwargs): 27 | super(PPM, self).__init__() 28 | self.pool_scales = pool_scales 29 | self.align_corners = align_corners 30 | self.in_channels = in_channels 31 | self.channels = channels 32 | self.conv_cfg = conv_cfg 33 | self.norm_cfg = norm_cfg 34 | self.act_cfg = act_cfg 35 | for pool_scale in pool_scales: 36 | self.append( 37 | nn.Sequential( 38 | nn.AdaptiveAvgPool2d(pool_scale), 39 | ConvModule( 40 | self.in_channels, 41 | self.channels, 42 | 1, 43 | conv_cfg=self.conv_cfg, 44 | norm_cfg=self.norm_cfg, 45 | act_cfg=self.act_cfg, 46 | **kwargs))) 47 | 48 | def forward(self, x): 49 | """Forward function.""" 50 | ppm_outs = [] 51 | for ppm in self: 52 | ppm_out = ppm(x) 53 | upsampled_ppm_out = resize( 54 | ppm_out, 55 | size=x.size()[2:], 56 | mode='bilinear', 57 | align_corners=self.align_corners) 58 | ppm_outs.append(upsampled_ppm_out) 59 | return ppm_outs 60 | 61 | 62 | @HEADS.register_module() 63 | class PSPHead(BaseDecodeHead): 64 | """Pyramid Scene Parsing Network. 65 | 66 | This head is the implementation of 67 | `PSPNet `_. 68 | 69 | Args: 70 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 71 | Module. Default: (1, 2, 3, 6). 
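
    Example (an illustrative sketch; the channel sizes and input shapes
    below are arbitrary, not values used elsewhere in this repository):
        >>> import torch
        >>> head = PSPHead(in_channels=64, channels=32, num_classes=4)
        >>> inputs = [torch.rand(1, 64, 23, 23)]
        >>> head(inputs).shape
        torch.Size([1, 4, 23, 23])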
72 | """ 73 | 74 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 75 | super(PSPHead, self).__init__(**kwargs) 76 | assert isinstance(pool_scales, (list, tuple)) 77 | self.pool_scales = pool_scales 78 | self.psp_modules = PPM( 79 | self.pool_scales, 80 | self.in_channels, 81 | self.channels, 82 | conv_cfg=self.conv_cfg, 83 | norm_cfg=self.norm_cfg, 84 | act_cfg=self.act_cfg, 85 | align_corners=self.align_corners) 86 | self.bottleneck = ConvModule( 87 | self.in_channels + len(pool_scales) * self.channels, 88 | self.channels, 89 | 3, 90 | padding=1, 91 | conv_cfg=self.conv_cfg, 92 | norm_cfg=self.norm_cfg, 93 | act_cfg=self.act_cfg) 94 | 95 | def _forward_feature(self, inputs): 96 | """Forward function for feature maps before classifying each pixel with 97 | ``self.cls_seg`` fc. 98 | 99 | Args: 100 | inputs (list[Tensor]): List of multi-level img features. 101 | 102 | Returns: 103 | feats (Tensor): A tensor of shape (batch_size, self.channels, 104 | H, W) which is feature map for last layer of decoder head. 105 | """ 106 | x = self._transform_inputs(inputs) 107 | psp_outs = [x] 108 | psp_outs.extend(self.psp_modules(x)) 109 | psp_outs = torch.cat(psp_outs, dim=1) 110 | feats = self.bottleneck(psp_outs) 111 | return feats 112 | 113 | def forward(self, inputs): 114 | """Forward function.""" 115 | output = self._forward_feature(inputs) 116 | output = self.cls_seg(output) 117 | return output 118 | -------------------------------------------------------------------------------- /marinext/mmseg/models/utils/up_conv_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule, build_upsample_layer 5 | 6 | 7 | class UpConvBlock(nn.Module): 8 | """Upsample convolution block in decoder for UNet. 9 | 10 | This upsample convolution block consists of one upsample module 11 | followed by one convolution block. The upsample module expands the 12 | high-level low-resolution feature map and the convolution block fuses 13 | the upsampled high-level low-resolution feature map and the low-level 14 | high-resolution feature map from encoder. 15 | 16 | Args: 17 | conv_block (nn.Sequential): Sequential of convolutional layers. 18 | in_channels (int): Number of input channels of the high-level 19 | skip_channels (int): Number of input channels of the low-level 20 | high-resolution feature map from encoder. 21 | out_channels (int): Number of output channels. 22 | num_convs (int): Number of convolutional layers in the conv_block. 23 | Default: 2. 24 | stride (int): Stride of convolutional layer in conv_block. Default: 1. 25 | dilation (int): Dilation rate of convolutional layer in conv_block. 26 | Default: 1. 27 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 28 | memory while slowing down the training speed. Default: False. 29 | conv_cfg (dict | None): Config dict for convolution layer. 30 | Default: None. 31 | norm_cfg (dict | None): Config dict for normalization layer. 32 | Default: dict(type='BN'). 33 | act_cfg (dict | None): Config dict for activation layer in ConvModule. 34 | Default: dict(type='ReLU'). 35 | upsample_cfg (dict): The upsample config of the upsample module in 36 | decoder. Default: dict(type='InterpConv'). 
If the size of 37 | high-level feature map is the same as that of skip feature map 38 | (low-level feature map from encoder), it does not need upsample the 39 | high-level feature map and the upsample_cfg is None. 40 | dcn (bool): Use deformable convolution in convolutional layer or not. 41 | Default: None. 42 | plugins (dict): plugins for convolutional layers. Default: None. 43 | """ 44 | 45 | def __init__(self, 46 | conv_block, 47 | in_channels, 48 | skip_channels, 49 | out_channels, 50 | num_convs=2, 51 | stride=1, 52 | dilation=1, 53 | with_cp=False, 54 | conv_cfg=None, 55 | norm_cfg=dict(type='BN'), 56 | act_cfg=dict(type='ReLU'), 57 | upsample_cfg=dict(type='InterpConv'), 58 | dcn=None, 59 | plugins=None): 60 | super(UpConvBlock, self).__init__() 61 | assert dcn is None, 'Not implemented yet.' 62 | assert plugins is None, 'Not implemented yet.' 63 | 64 | self.conv_block = conv_block( 65 | in_channels=2 * skip_channels, 66 | out_channels=out_channels, 67 | num_convs=num_convs, 68 | stride=stride, 69 | dilation=dilation, 70 | with_cp=with_cp, 71 | conv_cfg=conv_cfg, 72 | norm_cfg=norm_cfg, 73 | act_cfg=act_cfg, 74 | dcn=None, 75 | plugins=None) 76 | if upsample_cfg is not None: 77 | self.upsample = build_upsample_layer( 78 | cfg=upsample_cfg, 79 | in_channels=in_channels, 80 | out_channels=skip_channels, 81 | with_cp=with_cp, 82 | norm_cfg=norm_cfg, 83 | act_cfg=act_cfg) 84 | else: 85 | self.upsample = ConvModule( 86 | in_channels, 87 | skip_channels, 88 | kernel_size=1, 89 | stride=1, 90 | padding=0, 91 | conv_cfg=conv_cfg, 92 | norm_cfg=norm_cfg, 93 | act_cfg=act_cfg) 94 | 95 | def forward(self, skip, x): 96 | """Forward function.""" 97 | 98 | x = self.upsample(x) 99 | out = torch.cat([skip, x], dim=1) 100 | out = self.conv_block(out) 101 | 102 | return out 103 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/aspp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.ops import resize 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | class ASPPModule(nn.ModuleList): 12 | """Atrous Spatial Pyramid Pooling (ASPP) Module. 13 | 14 | Args: 15 | dilations (tuple[int]): Dilation rate of each layer. 16 | in_channels (int): Input channels. 17 | channels (int): Channels after modules, before conv_seg. 18 | conv_cfg (dict|None): Config of conv layers. 19 | norm_cfg (dict|None): Config of norm layers. 20 | act_cfg (dict): Config of activation layers. 
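
    Example (a minimal sketch; the channel counts and input size are
    illustrative only):
        >>> import torch
        >>> aspp = ASPPModule((1, 6, 12, 18), 64, 16, None, None,
        ...                   dict(type='ReLU'))
        >>> outs = aspp(torch.rand(2, 64, 32, 32))
        >>> [tuple(o.shape) for o in outs]
        [(2, 16, 32, 32), (2, 16, 32, 32), (2, 16, 32, 32), (2, 16, 32, 32)]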
21 | """ 22 | 23 | def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, 24 | act_cfg): 25 | super(ASPPModule, self).__init__() 26 | self.dilations = dilations 27 | self.in_channels = in_channels 28 | self.channels = channels 29 | self.conv_cfg = conv_cfg 30 | self.norm_cfg = norm_cfg 31 | self.act_cfg = act_cfg 32 | for dilation in dilations: 33 | self.append( 34 | ConvModule( 35 | self.in_channels, 36 | self.channels, 37 | 1 if dilation == 1 else 3, 38 | dilation=dilation, 39 | padding=0 if dilation == 1 else dilation, 40 | conv_cfg=self.conv_cfg, 41 | norm_cfg=self.norm_cfg, 42 | act_cfg=self.act_cfg)) 43 | 44 | def forward(self, x): 45 | """Forward function.""" 46 | aspp_outs = [] 47 | for aspp_module in self: 48 | aspp_outs.append(aspp_module(x)) 49 | 50 | return aspp_outs 51 | 52 | 53 | @HEADS.register_module() 54 | class ASPPHead(BaseDecodeHead): 55 | """Rethinking Atrous Convolution for Semantic Image Segmentation. 56 | 57 | This head is the implementation of `DeepLabV3 58 | `_. 59 | 60 | Args: 61 | dilations (tuple[int]): Dilation rates for ASPP module. 62 | Default: (1, 6, 12, 18). 63 | """ 64 | 65 | def __init__(self, dilations=(1, 6, 12, 18), **kwargs): 66 | super(ASPPHead, self).__init__(**kwargs) 67 | assert isinstance(dilations, (list, tuple)) 68 | self.dilations = dilations 69 | self.image_pool = nn.Sequential( 70 | nn.AdaptiveAvgPool2d(1), 71 | ConvModule( 72 | self.in_channels, 73 | self.channels, 74 | 1, 75 | conv_cfg=self.conv_cfg, 76 | norm_cfg=self.norm_cfg, 77 | act_cfg=self.act_cfg)) 78 | self.aspp_modules = ASPPModule( 79 | dilations, 80 | self.in_channels, 81 | self.channels, 82 | conv_cfg=self.conv_cfg, 83 | norm_cfg=self.norm_cfg, 84 | act_cfg=self.act_cfg) 85 | self.bottleneck = ConvModule( 86 | (len(dilations) + 1) * self.channels, 87 | self.channels, 88 | 3, 89 | padding=1, 90 | conv_cfg=self.conv_cfg, 91 | norm_cfg=self.norm_cfg, 92 | act_cfg=self.act_cfg) 93 | 94 | def _forward_feature(self, inputs): 95 | """Forward function for feature maps before classifying each pixel with 96 | ``self.cls_seg`` fc. 97 | 98 | Args: 99 | inputs (list[Tensor]): List of multi-level img features. 100 | 101 | Returns: 102 | feats (Tensor): A tensor of shape (batch_size, self.channels, 103 | H, W) which is feature map for last layer of decoder head. 104 | """ 105 | x = self._transform_inputs(inputs) 106 | aspp_outs = [ 107 | resize( 108 | self.image_pool(x), 109 | size=x.size()[2:], 110 | mode='bilinear', 111 | align_corners=self.align_corners) 112 | ] 113 | aspp_outs.extend(self.aspp_modules(x)) 114 | aspp_outs = torch.cat(aspp_outs, dim=1) 115 | feats = self.bottleneck(aspp_outs) 116 | return feats 117 | 118 | def forward(self, inputs): 119 | """Forward function.""" 120 | output = self._forward_feature(inputs) 121 | output = self.cls_seg(output) 122 | return output 123 | -------------------------------------------------------------------------------- /marinext/mmseg/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import functools 3 | 4 | import mmcv 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | def get_class_weight(class_weight): 11 | """Get class weight for loss function. 12 | 13 | Args: 14 | class_weight (list[float] | str | None): If class_weight is a str, 15 | take it as a file name and read from it. 
16 | """ 17 | if isinstance(class_weight, str): 18 | # take it as a file path 19 | if class_weight.endswith('.npy'): 20 | class_weight = np.load(class_weight) 21 | else: 22 | # pkl, json or yaml 23 | class_weight = mmcv.load(class_weight) 24 | 25 | return class_weight 26 | 27 | 28 | def reduce_loss(loss, reduction): 29 | """Reduce loss as specified. 30 | 31 | Args: 32 | loss (Tensor): Elementwise loss tensor. 33 | reduction (str): Options are "none", "mean" and "sum". 34 | 35 | Return: 36 | Tensor: Reduced loss tensor. 37 | """ 38 | reduction_enum = F._Reduction.get_enum(reduction) 39 | # none: 0, elementwise_mean:1, sum: 2 40 | if reduction_enum == 0: 41 | return loss 42 | elif reduction_enum == 1: 43 | return loss.mean() 44 | elif reduction_enum == 2: 45 | return loss.sum() 46 | 47 | 48 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 49 | """Apply element-wise weight and reduce loss. 50 | 51 | Args: 52 | loss (Tensor): Element-wise loss. 53 | weight (Tensor): Element-wise weights. 54 | reduction (str): Same as built-in losses of PyTorch. 55 | avg_factor (float): Average factor when computing the mean of losses. 56 | 57 | Returns: 58 | Tensor: Processed loss values. 59 | """ 60 | # if weight is specified, apply element-wise weight 61 | if weight is not None: 62 | assert weight.dim() == loss.dim() 63 | if weight.dim() > 1: 64 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 65 | loss = loss * weight 66 | 67 | # if avg_factor is not specified, just reduce the loss 68 | if avg_factor is None: 69 | loss = reduce_loss(loss, reduction) 70 | else: 71 | # if reduction is mean, then average the loss by avg_factor 72 | if reduction == 'mean': 73 | # Avoid causing ZeroDivisionError when avg_factor is 0.0, 74 | # i.e., all labels of an image belong to ignore index. 75 | eps = torch.finfo(torch.float32).eps 76 | loss = loss.sum() / (avg_factor + eps) 77 | # if reduction is 'none', then do nothing, otherwise raise an error 78 | elif reduction != 'none': 79 | raise ValueError('avg_factor can not be used with reduction="sum"') 80 | return loss 81 | 82 | 83 | def weighted_loss(loss_func): 84 | """Create a weighted version of a given loss function. 85 | 86 | To use this decorator, the loss function must have the signature like 87 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 88 | element-wise loss without any reduction. This decorator will add weight 89 | and reduction arguments to the function. The decorated function will have 90 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 91 | avg_factor=None, **kwargs)`. 92 | 93 | :Example: 94 | 95 | >>> import torch 96 | >>> @weighted_loss 97 | >>> def l1_loss(pred, target): 98 | >>> return (pred - target).abs() 99 | 100 | >>> pred = torch.Tensor([0, 2, 3]) 101 | >>> target = torch.Tensor([1, 1, 1]) 102 | >>> weight = torch.Tensor([1, 0, 1]) 103 | 104 | >>> l1_loss(pred, target) 105 | tensor(1.3333) 106 | >>> l1_loss(pred, target, weight) 107 | tensor(1.) 
108 | >>> l1_loss(pred, target, reduction='none') 109 | tensor([1., 1., 2.]) 110 | >>> l1_loss(pred, target, weight, avg_factor=2) 111 | tensor(1.5000) 112 | """ 113 | 114 | @functools.wraps(loss_func) 115 | def wrapper(pred, 116 | target, 117 | weight=None, 118 | reduction='mean', 119 | avg_factor=None, 120 | **kwargs): 121 | # get element-wise loss 122 | loss = loss_func(pred, target, **kwargs) 123 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 124 | return loss 125 | 126 | return wrapper 127 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/ocr_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from ..utils import SelfAttentionBlock as _SelfAttentionBlock 10 | from .cascade_decode_head import BaseCascadeDecodeHead 11 | 12 | 13 | class SpatialGatherModule(nn.Module): 14 | """Aggregate the context features according to the initial predicted 15 | probability distribution. 16 | 17 | Employ the soft-weighted method to aggregate the context. 18 | """ 19 | 20 | def __init__(self, scale): 21 | super(SpatialGatherModule, self).__init__() 22 | self.scale = scale 23 | 24 | def forward(self, feats, probs): 25 | """Forward function.""" 26 | batch_size, num_classes, height, width = probs.size() 27 | channels = feats.size(1) 28 | probs = probs.view(batch_size, num_classes, -1) 29 | feats = feats.view(batch_size, channels, -1) 30 | # [batch_size, height*width, num_classes] 31 | feats = feats.permute(0, 2, 1) 32 | # [batch_size, channels, height*width] 33 | probs = F.softmax(self.scale * probs, dim=2) 34 | # [batch_size, channels, num_classes] 35 | ocr_context = torch.matmul(probs, feats) 36 | ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3) 37 | return ocr_context 38 | 39 | 40 | class ObjectAttentionBlock(_SelfAttentionBlock): 41 | """Make a OCR used SelfAttentionBlock.""" 42 | 43 | def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, 44 | act_cfg): 45 | if scale > 1: 46 | query_downsample = nn.MaxPool2d(kernel_size=scale) 47 | else: 48 | query_downsample = None 49 | super(ObjectAttentionBlock, self).__init__( 50 | key_in_channels=in_channels, 51 | query_in_channels=in_channels, 52 | channels=channels, 53 | out_channels=in_channels, 54 | share_key_query=False, 55 | query_downsample=query_downsample, 56 | key_downsample=None, 57 | key_query_num_convs=2, 58 | key_query_norm=True, 59 | value_out_num_convs=1, 60 | value_out_norm=True, 61 | matmul_norm=True, 62 | with_out=True, 63 | conv_cfg=conv_cfg, 64 | norm_cfg=norm_cfg, 65 | act_cfg=act_cfg) 66 | self.bottleneck = ConvModule( 67 | in_channels * 2, 68 | in_channels, 69 | 1, 70 | conv_cfg=self.conv_cfg, 71 | norm_cfg=self.norm_cfg, 72 | act_cfg=self.act_cfg) 73 | 74 | def forward(self, query_feats, key_feats): 75 | """Forward function.""" 76 | context = super(ObjectAttentionBlock, 77 | self).forward(query_feats, key_feats) 78 | output = self.bottleneck(torch.cat([context, query_feats], dim=1)) 79 | if self.query_downsample is not None: 80 | output = resize(query_feats) 81 | 82 | return output 83 | 84 | 85 | @HEADS.register_module() 86 | class OCRHead(BaseCascadeDecodeHead): 87 | """Object-Contextual Representations for Semantic Segmentation. 
88 | 89 | This head is the implementation of `OCRNet 90 | `_. 91 | 92 | Args: 93 | ocr_channels (int): The intermediate channels of OCR block. 94 | scale (int): The scale of probability map in SpatialGatherModule in 95 | Default: 1. 96 | """ 97 | 98 | def __init__(self, ocr_channels, scale=1, **kwargs): 99 | super(OCRHead, self).__init__(**kwargs) 100 | self.ocr_channels = ocr_channels 101 | self.scale = scale 102 | self.object_context_block = ObjectAttentionBlock( 103 | self.channels, 104 | self.ocr_channels, 105 | self.scale, 106 | conv_cfg=self.conv_cfg, 107 | norm_cfg=self.norm_cfg, 108 | act_cfg=self.act_cfg) 109 | self.spatial_gather_module = SpatialGatherModule(self.scale) 110 | 111 | self.bottleneck = ConvModule( 112 | self.in_channels, 113 | self.channels, 114 | 3, 115 | padding=1, 116 | conv_cfg=self.conv_cfg, 117 | norm_cfg=self.norm_cfg, 118 | act_cfg=self.act_cfg) 119 | 120 | def forward(self, inputs, prev_output): 121 | """Forward function.""" 122 | x = self._transform_inputs(inputs) 123 | feats = self.bottleneck(x) 124 | context = self.spatial_gather_module(feats, prev_output) 125 | object_context = self.object_context_block(feats, context) 126 | output = self.cls_seg(object_context) 127 | 128 | return output 129 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/uper_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.ops import resize 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | from .psp_head import PPM 10 | 11 | 12 | @HEADS.register_module() 13 | class UPerHead(BaseDecodeHead): 14 | """Unified Perceptual Parsing for Scene Understanding. 15 | 16 | This head is the implementation of `UPerNet 17 | `_. 18 | 19 | Args: 20 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 21 | Module applied on the last feature. Default: (1, 2, 3, 6). 
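
    Example (an illustrative sketch; channel counts and shapes are
    arbitrary, chosen only to show the expected input/output layout):
        >>> import torch
        >>> head = UPerHead(in_channels=[16, 32, 64], channels=8,
        ...                 in_index=[0, 1, 2], num_classes=4)
        >>> inputs = [torch.rand(1, 16, 32, 32), torch.rand(1, 32, 16, 16),
        ...           torch.rand(1, 64, 8, 8)]
        >>> head(inputs).shape
        torch.Size([1, 4, 32, 32])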
22 | """ 23 | 24 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 25 | super(UPerHead, self).__init__( 26 | input_transform='multiple_select', **kwargs) 27 | # PSP Module 28 | self.psp_modules = PPM( 29 | pool_scales, 30 | self.in_channels[-1], 31 | self.channels, 32 | conv_cfg=self.conv_cfg, 33 | norm_cfg=self.norm_cfg, 34 | act_cfg=self.act_cfg, 35 | align_corners=self.align_corners) 36 | self.bottleneck = ConvModule( 37 | self.in_channels[-1] + len(pool_scales) * self.channels, 38 | self.channels, 39 | 3, 40 | padding=1, 41 | conv_cfg=self.conv_cfg, 42 | norm_cfg=self.norm_cfg, 43 | act_cfg=self.act_cfg) 44 | # FPN Module 45 | self.lateral_convs = nn.ModuleList() 46 | self.fpn_convs = nn.ModuleList() 47 | for in_channels in self.in_channels[:-1]: # skip the top layer 48 | l_conv = ConvModule( 49 | in_channels, 50 | self.channels, 51 | 1, 52 | conv_cfg=self.conv_cfg, 53 | norm_cfg=self.norm_cfg, 54 | act_cfg=self.act_cfg, 55 | inplace=False) 56 | fpn_conv = ConvModule( 57 | self.channels, 58 | self.channels, 59 | 3, 60 | padding=1, 61 | conv_cfg=self.conv_cfg, 62 | norm_cfg=self.norm_cfg, 63 | act_cfg=self.act_cfg, 64 | inplace=False) 65 | self.lateral_convs.append(l_conv) 66 | self.fpn_convs.append(fpn_conv) 67 | 68 | self.fpn_bottleneck = ConvModule( 69 | len(self.in_channels) * self.channels, 70 | self.channels, 71 | 3, 72 | padding=1, 73 | conv_cfg=self.conv_cfg, 74 | norm_cfg=self.norm_cfg, 75 | act_cfg=self.act_cfg) 76 | 77 | def psp_forward(self, inputs): 78 | """Forward function of PSP module.""" 79 | x = inputs[-1] 80 | psp_outs = [x] 81 | psp_outs.extend(self.psp_modules(x)) 82 | psp_outs = torch.cat(psp_outs, dim=1) 83 | output = self.bottleneck(psp_outs) 84 | 85 | return output 86 | 87 | def _forward_feature(self, inputs): 88 | """Forward function for feature maps before classifying each pixel with 89 | ``self.cls_seg`` fc. 90 | 91 | Args: 92 | inputs (list[Tensor]): List of multi-level img features. 93 | 94 | Returns: 95 | feats (Tensor): A tensor of shape (batch_size, self.channels, 96 | H, W) which is feature map for last layer of decoder head. 
97 | """ 98 | inputs = self._transform_inputs(inputs) 99 | 100 | # build laterals 101 | laterals = [ 102 | lateral_conv(inputs[i]) 103 | for i, lateral_conv in enumerate(self.lateral_convs) 104 | ] 105 | 106 | laterals.append(self.psp_forward(inputs)) 107 | 108 | # build top-down path 109 | used_backbone_levels = len(laterals) 110 | for i in range(used_backbone_levels - 1, 0, -1): 111 | prev_shape = laterals[i - 1].shape[2:] 112 | laterals[i - 1] = laterals[i - 1] + resize( 113 | laterals[i], 114 | size=prev_shape, 115 | mode='bilinear', 116 | align_corners=self.align_corners) 117 | 118 | # build outputs 119 | fpn_outs = [ 120 | self.fpn_convs[i](laterals[i]) 121 | for i in range(used_backbone_levels - 1) 122 | ] 123 | # append psp feature 124 | fpn_outs.append(laterals[-1]) 125 | 126 | for i in range(used_backbone_levels - 1, 0, -1): 127 | fpn_outs[i] = resize( 128 | fpn_outs[i], 129 | size=fpn_outs[0].shape[2:], 130 | mode='bilinear', 131 | align_corners=self.align_corners) 132 | fpn_outs = torch.cat(fpn_outs, dim=1) 133 | feats = self.fpn_bottleneck(fpn_outs) 134 | return feats 135 | 136 | def forward(self, inputs): 137 | """Forward function.""" 138 | output = self._forward_feature(inputs) 139 | output = self.cls_seg(output) 140 | return output 141 | -------------------------------------------------------------------------------- /utils/stack_patches.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Ioannis Kakogeorgiou 3 | Email: gkakogeorgiou@gmail.com 4 | Python Version: 3.7.10 5 | Description: stack_patches.py production of upsampling and stacking the patches. 6 | ''' 7 | 8 | import os 9 | import argparse 10 | import shutil 11 | import rasterio 12 | from tqdm import tqdm 13 | from glob import glob 14 | from rasterio.enums import Resampling 15 | 16 | def get_band(path, crop): 17 | return int(path.split('_'+crop)[0].split('_')[-1]) 18 | 19 | def main(options): 20 | 21 | main_output_folder = options['path'] + '_' + options['resampling'] 22 | all_tiles = glob(os.path.join(options['path'],'Scene_*')) 23 | if options['resampling'] == 'nearest': 24 | resampling_method = Resampling.nearest 25 | elif options['resampling'] == 'bilinear': 26 | resampling_method = Resampling.bilinear 27 | else: 28 | raise 29 | 30 | # Copy split folder 31 | split_output_folder = os.path.join(main_output_folder, 'splits') 32 | os.makedirs(split_output_folder, exist_ok=True) 33 | 34 | split_files = glob(os.path.join(options['path'],'splits','*')) 35 | for f in split_files: 36 | new_split_file = os.path.join(split_output_folder, os.path.basename(f)) 37 | shutil.copy(f, new_split_file) 38 | 39 | for tile in tqdm(all_tiles): 40 | 41 | # Create the output folder 42 | current_output_folder = os.path.join(main_output_folder, os.path.basename(tile)) 43 | os.makedirs(current_output_folder, exist_ok=True) 44 | 45 | # Copy gt files 46 | gt_files = glob(os.path.join(tile, '10', '*_cl_*')) + glob(os.path.join(tile, '10', '*_conf_*')) + glob(os.path.join(tile, '10', '*_rep_*')) 47 | for f in gt_files: 48 | new_gt_file = os.path.join(current_output_folder, os.path.basename(f)) 49 | shutil.copy(f, new_gt_file) 50 | 51 | # Get the number of different crops for the specific tile 52 | splits = [f.split('_cl_')[-1] for f in glob(os.path.join(tile, '10', '*_cl_*'))] 53 | 54 | for crop in splits: 55 | 56 | # Get the bands for the specific crop 57 | all_bands = glob(os.path.join(tile, '*', '*L2R_rhorc*_'+crop)) 58 | all_bands = sorted(all_bands, key=lambda patch: get_band(patch, 
crop)) 59 | 60 | ################################ 61 | # Stack and Upsample the bands # 62 | ################################ 63 | 64 | # Get metadata from the second 10m band 65 | with rasterio.open(all_bands[1], mode ='r') as src: 66 | tags = src.tags().copy() 67 | meta = src.meta 68 | image = src.read(1) 69 | shape = image.shape 70 | dtype = image.dtype 71 | 72 | # Update meta to reflect the number of layers 73 | meta.update(count = len(all_bands)) 74 | 75 | # Construct the filename 76 | output_file = os.path.basename(all_bands[1]).replace(str(get_band(all_bands[1], crop))+'_', '') 77 | output_file = os.path.join(current_output_folder, output_file) 78 | 79 | # Write it to stack 80 | with rasterio.open(output_file, 'w', driver='GTiff', 81 | height=shape[-2], 82 | width=shape[-1], 83 | count=len(all_bands), 84 | dtype=dtype, 85 | crs='+proj=latlong') as dst: # non-georeferenced (just to be recognised from gis) 86 | 87 | for c, band in enumerate(all_bands, 1): 88 | upscale_factor = int(os.path.basename(os.path.dirname(band)))//10 89 | 90 | with rasterio.open(band, mode ='r') as src: 91 | dst.write_band(c, src.read(1, 92 | out_shape=( 93 | int(src.height * upscale_factor), 94 | int(src.width * upscale_factor) 95 | ), 96 | resampling=resampling_method 97 | ).astype(dtype).copy() 98 | ) 99 | dst.update_tags(**tags) 100 | 101 | if __name__ == "__main__": 102 | 103 | parser = argparse.ArgumentParser() 104 | 105 | # Options 106 | parser.add_argument('--path', help='Path to dataset') 107 | parser.add_argument('--resampling', default='nearest', type=str, help='Type of resampling before stacking (nearest or bilinear)') 108 | 109 | args = parser.parse_args() 110 | options = vars(args) # convert to ordinary dict 111 | 112 | main(options) 113 | -------------------------------------------------------------------------------- /marinext/mmseg/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | import warnings 4 | 5 | import torch.distributed as dist 6 | from mmcv.runner import DistEvalHook as _DistEvalHook 7 | from mmcv.runner import EvalHook as _EvalHook 8 | from torch.nn.modules.batchnorm import _BatchNorm 9 | 10 | 11 | class EvalHook(_EvalHook): 12 | """Single GPU EvalHook, with efficient test support. 13 | 14 | Args: 15 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 16 | If set to True, it will perform by epoch. Otherwise, by iteration. 17 | Default: False. 18 | efficient_test (bool): Whether save the results as local numpy files to 19 | save CPU memory during evaluation. Default: False. 20 | pre_eval (bool): Whether to use progressive mode to evaluate model. 21 | Default: False. 22 | Returns: 23 | list: The prediction results. 
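
    Example (a hedged sketch; ``val_loader`` is assumed to be an existing
    ``DataLoader`` over the validation split):
        >>> eval_hook = EvalHook(val_loader, interval=1000, by_epoch=False,
        ...                      pre_eval=True)
        >>> # runner.register_hook(eval_hook) would then evaluate every
        >>> # 1000 iterations using the memory-friendly pre_eval path.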
24 | """ 25 | 26 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 27 | 28 | def __init__(self, 29 | *args, 30 | by_epoch=False, 31 | efficient_test=False, 32 | pre_eval=False, 33 | **kwargs): 34 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 35 | self.pre_eval = pre_eval 36 | if efficient_test: 37 | warnings.warn( 38 | 'DeprecationWarning: ``efficient_test`` for evaluation hook ' 39 | 'is deprecated, the evaluation hook is CPU memory friendly ' 40 | 'with ``pre_eval=True`` as argument for ``single_gpu_test()`` ' 41 | 'function') 42 | 43 | def _do_evaluate(self, runner): 44 | """perform evaluation and save ckpt.""" 45 | if not self._should_evaluate(runner): 46 | return 47 | 48 | from mmseg.apis import single_gpu_test 49 | results = single_gpu_test( 50 | runner.model, self.dataloader, show=False, pre_eval=self.pre_eval) 51 | runner.log_buffer.clear() 52 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 53 | key_score = self.evaluate(runner, results) 54 | if self.save_best: 55 | self._save_ckpt(runner, key_score) 56 | 57 | 58 | class DistEvalHook(_DistEvalHook): 59 | """Distributed EvalHook, with efficient test support. 60 | 61 | Args: 62 | by_epoch (bool): Determine perform evaluation by epoch or by iteration. 63 | If set to True, it will perform by epoch. Otherwise, by iteration. 64 | Default: False. 65 | efficient_test (bool): Whether save the results as local numpy files to 66 | save CPU memory during evaluation. Default: False. 67 | pre_eval (bool): Whether to use progressive mode to evaluate model. 68 | Default: False. 69 | Returns: 70 | list: The prediction results. 71 | """ 72 | 73 | greater_keys = ['mIoU', 'mAcc', 'aAcc'] 74 | 75 | def __init__(self, 76 | *args, 77 | by_epoch=False, 78 | efficient_test=False, 79 | pre_eval=False, 80 | **kwargs): 81 | super().__init__(*args, by_epoch=by_epoch, **kwargs) 82 | self.pre_eval = pre_eval 83 | if efficient_test: 84 | warnings.warn( 85 | 'DeprecationWarning: ``efficient_test`` for evaluation hook ' 86 | 'is deprecated, the evaluation hook is CPU memory friendly ' 87 | 'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` ' 88 | 'function') 89 | 90 | def _do_evaluate(self, runner): 91 | """perform evaluation and save ckpt.""" 92 | # Synchronization of BatchNorm's buffer (running_mean 93 | # and running_var) is not supported in the DDP of pytorch, 94 | # which may cause the inconsistent performance of models in 95 | # different ranks, so we broadcast BatchNorm's buffers 96 | # of rank 0 to other ranks to avoid this. 
97 | if self.broadcast_bn_buffer: 98 | model = runner.model 99 | for name, module in model.named_modules(): 100 | if isinstance(module, 101 | _BatchNorm) and module.track_running_stats: 102 | dist.broadcast(module.running_var, 0) 103 | dist.broadcast(module.running_mean, 0) 104 | 105 | if not self._should_evaluate(runner): 106 | return 107 | 108 | tmpdir = self.tmpdir 109 | if tmpdir is None: 110 | tmpdir = osp.join(runner.work_dir, '.eval_hook') 111 | 112 | from mmseg.apis import multi_gpu_test 113 | results = multi_gpu_test( 114 | runner.model, 115 | self.dataloader, 116 | tmpdir=tmpdir, 117 | gpu_collect=self.gpu_collect, 118 | pre_eval=self.pre_eval) 119 | 120 | runner.log_buffer.clear() 121 | 122 | if runner.rank == 0: 123 | print('\n') 124 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 125 | key_score = self.evaluate(runner, results) 126 | 127 | if self.save_best: 128 | self._save_ckpt(runner, key_score) 129 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/dnl_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.cnn import NonLocal2d 4 | from torch import nn 5 | 6 | from ..builder import HEADS 7 | from .fcn_head import FCNHead 8 | 9 | 10 | class DisentangledNonLocal2d(NonLocal2d): 11 | """Disentangled Non-Local Blocks. 12 | 13 | Args: 14 | temperature (float): Temperature to adjust attention. Default: 0.05 15 | """ 16 | 17 | def __init__(self, *arg, temperature, **kwargs): 18 | super().__init__(*arg, **kwargs) 19 | self.temperature = temperature 20 | self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1) 21 | 22 | def embedded_gaussian(self, theta_x, phi_x): 23 | """Embedded gaussian with temperature.""" 24 | 25 | # NonLocal2d pairwise_weight: [N, HxW, HxW] 26 | pairwise_weight = torch.matmul(theta_x, phi_x) 27 | if self.use_scale: 28 | # theta_x.shape[-1] is `self.inter_channels` 29 | pairwise_weight /= torch.tensor( 30 | theta_x.shape[-1], 31 | dtype=torch.float, 32 | device=pairwise_weight.device)**torch.tensor( 33 | 0.5, device=pairwise_weight.device) 34 | pairwise_weight /= torch.tensor( 35 | self.temperature, device=pairwise_weight.device) 36 | pairwise_weight = pairwise_weight.softmax(dim=-1) 37 | return pairwise_weight 38 | 39 | def forward(self, x): 40 | # x: [N, C, H, W] 41 | n = x.size(0) 42 | 43 | # g_x: [N, HxW, C] 44 | g_x = self.g(x).view(n, self.inter_channels, -1) 45 | g_x = g_x.permute(0, 2, 1) 46 | 47 | # theta_x: [N, HxW, C], phi_x: [N, C, HxW] 48 | if self.mode == 'gaussian': 49 | theta_x = x.view(n, self.in_channels, -1) 50 | theta_x = theta_x.permute(0, 2, 1) 51 | if self.sub_sample: 52 | phi_x = self.phi(x).view(n, self.in_channels, -1) 53 | else: 54 | phi_x = x.view(n, self.in_channels, -1) 55 | elif self.mode == 'concatenation': 56 | theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) 57 | phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) 58 | else: 59 | theta_x = self.theta(x).view(n, self.inter_channels, -1) 60 | theta_x = theta_x.permute(0, 2, 1) 61 | phi_x = self.phi(x).view(n, self.inter_channels, -1) 62 | 63 | # subtract mean 64 | theta_x -= theta_x.mean(dim=-2, keepdim=True) 65 | phi_x -= phi_x.mean(dim=-1, keepdim=True) 66 | 67 | pairwise_func = getattr(self, self.mode) 68 | # pairwise_weight: [N, HxW, HxW] 69 | pairwise_weight = pairwise_func(theta_x, phi_x) 70 | 71 | # y: [N, HxW, C] 72 | y = 
torch.matmul(pairwise_weight, g_x) 73 | # y: [N, C, H, W] 74 | y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, 75 | *x.size()[2:]) 76 | 77 | # unary_mask: [N, 1, HxW] 78 | unary_mask = self.conv_mask(x) 79 | unary_mask = unary_mask.view(n, 1, -1) 80 | unary_mask = unary_mask.softmax(dim=-1) 81 | # unary_x: [N, 1, C] 82 | unary_x = torch.matmul(unary_mask, g_x) 83 | # unary_x: [N, C, 1, 1] 84 | unary_x = unary_x.permute(0, 2, 1).contiguous().reshape( 85 | n, self.inter_channels, 1, 1) 86 | 87 | output = x + self.conv_out(y + unary_x) 88 | 89 | return output 90 | 91 | 92 | @HEADS.register_module() 93 | class DNLHead(FCNHead): 94 | """Disentangled Non-Local Neural Networks. 95 | 96 | This head is the implementation of `DNLNet 97 | `_. 98 | 99 | Args: 100 | reduction (int): Reduction factor of projection transform. Default: 2. 101 | use_scale (bool): Whether to scale pairwise_weight by 102 | sqrt(1/inter_channels). Default: False. 103 | mode (str): The nonlocal mode. Options are 'embedded_gaussian', 104 | 'dot_product'. Default: 'embedded_gaussian.'. 105 | temperature (float): Temperature to adjust attention. Default: 0.05 106 | """ 107 | 108 | def __init__(self, 109 | reduction=2, 110 | use_scale=True, 111 | mode='embedded_gaussian', 112 | temperature=0.05, 113 | **kwargs): 114 | super(DNLHead, self).__init__(num_convs=2, **kwargs) 115 | self.reduction = reduction 116 | self.use_scale = use_scale 117 | self.mode = mode 118 | self.temperature = temperature 119 | self.dnl_block = DisentangledNonLocal2d( 120 | in_channels=self.channels, 121 | reduction=self.reduction, 122 | use_scale=self.use_scale, 123 | conv_cfg=self.conv_cfg, 124 | norm_cfg=self.norm_cfg, 125 | mode=self.mode, 126 | temperature=self.temperature) 127 | 128 | def forward(self, inputs): 129 | """Forward function.""" 130 | x = self._transform_inputs(inputs) 131 | output = self.convs[0](x) 132 | output = self.dnl_block(output) 133 | output = self.convs[1](output) 134 | if self.concat_input: 135 | output = self.conv_cat(torch.cat([x, output], dim=1)) 136 | output = self.cls_seg(output) 137 | return output 138 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/segmenter_mask_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from mmcv.cnn import build_norm_layer 6 | from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_, 7 | trunc_normal_init) 8 | from mmcv.runner import ModuleList 9 | 10 | from mmseg.models.backbones.vit import TransformerEncoderLayer 11 | from ..builder import HEADS 12 | from .decode_head import BaseDecodeHead 13 | 14 | 15 | @HEADS.register_module() 16 | class SegmenterMaskTransformerHead(BaseDecodeHead): 17 | """Segmenter: Transformer for Semantic Segmentation. 18 | 19 | This head is the implementation of 20 | `Segmenter: `_. 21 | 22 | Args: 23 | backbone_cfg:(dict): Config of backbone of 24 | Context Path. 25 | in_channels (int): The number of channels of input image. 26 | num_layers (int): The depth of transformer. 27 | num_heads (int): The number of attention heads. 28 | embed_dims (int): The number of embedding dimension. 29 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim. 30 | Default: 4. 31 | drop_path_rate (float): stochastic depth rate. Default 0.1. 32 | drop_rate (float): Probability of an element to be zeroed. 
33 | Default 0.0 34 | attn_drop_rate (float): The drop out rate for attention layer. 35 | Default 0.0 36 | num_fcs (int): The number of fully-connected layers for FFNs. 37 | Default: 2. 38 | qkv_bias (bool): Enable bias for qkv if True. Default: True. 39 | act_cfg (dict): The activation config for FFNs. 40 | Default: dict(type='GELU'). 41 | norm_cfg (dict): Config dict for normalization layer. 42 | Default: dict(type='LN') 43 | init_std (float): The value of std in weight initialization. 44 | Default: 0.02. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | in_channels, 50 | num_layers, 51 | num_heads, 52 | embed_dims, 53 | mlp_ratio=4, 54 | drop_path_rate=0.1, 55 | drop_rate=0.0, 56 | attn_drop_rate=0.0, 57 | num_fcs=2, 58 | qkv_bias=True, 59 | act_cfg=dict(type='GELU'), 60 | norm_cfg=dict(type='LN'), 61 | init_std=0.02, 62 | **kwargs, 63 | ): 64 | super(SegmenterMaskTransformerHead, self).__init__( 65 | in_channels=in_channels, **kwargs) 66 | 67 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] 68 | self.layers = ModuleList() 69 | for i in range(num_layers): 70 | self.layers.append( 71 | TransformerEncoderLayer( 72 | embed_dims=embed_dims, 73 | num_heads=num_heads, 74 | feedforward_channels=mlp_ratio * embed_dims, 75 | attn_drop_rate=attn_drop_rate, 76 | drop_rate=drop_rate, 77 | drop_path_rate=dpr[i], 78 | num_fcs=num_fcs, 79 | qkv_bias=qkv_bias, 80 | act_cfg=act_cfg, 81 | norm_cfg=norm_cfg, 82 | batch_first=True, 83 | )) 84 | 85 | self.dec_proj = nn.Linear(in_channels, embed_dims) 86 | 87 | self.cls_emb = nn.Parameter( 88 | torch.randn(1, self.num_classes, embed_dims)) 89 | self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False) 90 | self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False) 91 | 92 | self.decoder_norm = build_norm_layer( 93 | norm_cfg, embed_dims, postfix=1)[1] 94 | self.mask_norm = build_norm_layer( 95 | norm_cfg, self.num_classes, postfix=2)[1] 96 | 97 | self.init_std = init_std 98 | 99 | delattr(self, 'conv_seg') 100 | 101 | def init_weights(self): 102 | trunc_normal_(self.cls_emb, std=self.init_std) 103 | trunc_normal_init(self.patch_proj, std=self.init_std) 104 | trunc_normal_init(self.classes_proj, std=self.init_std) 105 | for n, m in self.named_modules(): 106 | if isinstance(m, nn.Linear): 107 | trunc_normal_init(m, std=self.init_std, bias=0) 108 | elif isinstance(m, nn.LayerNorm): 109 | constant_init(m, val=1.0, bias=0.0) 110 | 111 | def forward(self, inputs): 112 | x = self._transform_inputs(inputs) 113 | b, c, h, w = x.shape 114 | x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c) 115 | 116 | x = self.dec_proj(x) 117 | cls_emb = self.cls_emb.expand(x.size(0), -1, -1) 118 | x = torch.cat((x, cls_emb), 1) 119 | for layer in self.layers: 120 | x = layer(x) 121 | x = self.decoder_norm(x) 122 | 123 | patches = self.patch_proj(x[:, :-self.num_classes]) 124 | cls_seg_feat = self.classes_proj(x[:, -self.num_classes:]) 125 | 126 | patches = F.normalize(patches, dim=2, p=2) 127 | cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2) 128 | 129 | masks = patches @ cls_seg_feat.transpose(1, 2) 130 | masks = self.mask_norm(masks) 131 | masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w) 132 | 133 | return masks 134 | -------------------------------------------------------------------------------- /marinext/mmseg/models/losses/dice_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | """Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/ 3 | segmentron/solver/loss.py (Apache-2.0 License)""" 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ..builder import LOSSES 9 | from .utils import get_class_weight, weighted_loss 10 | 11 | 12 | @weighted_loss 13 | def dice_loss(pred, 14 | target, 15 | valid_mask, 16 | smooth=1, 17 | exponent=2, 18 | class_weight=None, 19 | ignore_index=255): 20 | assert pred.shape[0] == target.shape[0] 21 | total_loss = 0 22 | num_classes = pred.shape[1] 23 | for i in range(num_classes): 24 | if i != ignore_index: 25 | dice_loss = binary_dice_loss( 26 | pred[:, i], 27 | target[..., i], 28 | valid_mask=valid_mask, 29 | smooth=smooth, 30 | exponent=exponent) 31 | if class_weight is not None: 32 | dice_loss *= class_weight[i] 33 | total_loss += dice_loss 34 | return total_loss / num_classes 35 | 36 | 37 | @weighted_loss 38 | def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards): 39 | assert pred.shape[0] == target.shape[0] 40 | pred = pred.reshape(pred.shape[0], -1) 41 | target = target.reshape(target.shape[0], -1) 42 | valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) 43 | 44 | num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth 45 | den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth 46 | 47 | return 1 - num / den 48 | 49 | 50 | @LOSSES.register_module() 51 | class DiceLoss(nn.Module): 52 | """DiceLoss. 53 | 54 | This loss is proposed in `V-Net: Fully Convolutional Neural Networks for 55 | Volumetric Medical Image Segmentation `_. 56 | 57 | Args: 58 | smooth (float): A float number to smooth loss, and avoid NaN error. 59 | Default: 1 60 | exponent (float): An float number to calculate denominator 61 | value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2. 62 | reduction (str, optional): The method used to reduce the loss. Options 63 | are "none", "mean" and "sum". This parameter only works when 64 | per_image is True. Default: 'mean'. 65 | class_weight (list[float] | str, optional): Weight of each class. If in 66 | str format, read them from a file. Defaults to None. 67 | loss_weight (float, optional): Weight of the loss. Default to 1.0. 68 | ignore_index (int | None): The label index to be ignored. Default: 255. 69 | loss_name (str, optional): Name of the loss item. If you want this loss 70 | item to be included into the backward graph, `loss_` must be the 71 | prefix of the name. Defaults to 'loss_dice'. 
72 | """ 73 | 74 | def __init__(self, 75 | smooth=1, 76 | exponent=2, 77 | reduction='mean', 78 | class_weight=None, 79 | loss_weight=1.0, 80 | ignore_index=255, 81 | loss_name='loss_dice', 82 | **kwards): 83 | super(DiceLoss, self).__init__() 84 | self.smooth = smooth 85 | self.exponent = exponent 86 | self.reduction = reduction 87 | self.class_weight = get_class_weight(class_weight) 88 | self.loss_weight = loss_weight 89 | self.ignore_index = ignore_index 90 | self._loss_name = loss_name 91 | 92 | def forward(self, 93 | pred, 94 | target, 95 | avg_factor=None, 96 | reduction_override=None, 97 | **kwards): 98 | assert reduction_override in (None, 'none', 'mean', 'sum') 99 | reduction = ( 100 | reduction_override if reduction_override else self.reduction) 101 | if self.class_weight is not None: 102 | class_weight = pred.new_tensor(self.class_weight) 103 | else: 104 | class_weight = None 105 | 106 | pred = F.softmax(pred, dim=1) 107 | num_classes = pred.shape[1] 108 | one_hot_target = F.one_hot( 109 | torch.clamp(target.long(), 0, num_classes - 1), 110 | num_classes=num_classes) 111 | valid_mask = (target != self.ignore_index).long() 112 | 113 | loss = self.loss_weight * dice_loss( 114 | pred, 115 | one_hot_target, 116 | valid_mask=valid_mask, 117 | reduction=reduction, 118 | avg_factor=avg_factor, 119 | smooth=self.smooth, 120 | exponent=self.exponent, 121 | class_weight=class_weight, 122 | ignore_index=self.ignore_index) 123 | return loss 124 | 125 | @property 126 | def loss_name(self): 127 | """Loss Name. 128 | 129 | This function must be implemented and will return the name of this 130 | loss function. This name will be used to combine different loss items 131 | by simple sum operation. In addition, if you want this loss item to be 132 | included into the backward graph, `loss_` must be the prefix of the 133 | name. 134 | Returns: 135 | str: The name of this loss item. 136 | """ 137 | return self._loss_name 138 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/isa_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from mmcv.cnn import ConvModule 7 | 8 | from ..builder import HEADS 9 | from ..utils import SelfAttentionBlock as _SelfAttentionBlock 10 | from .decode_head import BaseDecodeHead 11 | 12 | 13 | class SelfAttentionBlock(_SelfAttentionBlock): 14 | """Self-Attention Module. 15 | 16 | Args: 17 | in_channels (int): Input channels of key/query feature. 18 | channels (int): Output channels of key/query transform. 19 | conv_cfg (dict | None): Config of conv layers. 20 | norm_cfg (dict | None): Config of norm layers. 21 | act_cfg (dict | None): Config of activation layers. 
22 | """ 23 | 24 | def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg): 25 | super(SelfAttentionBlock, self).__init__( 26 | key_in_channels=in_channels, 27 | query_in_channels=in_channels, 28 | channels=channels, 29 | out_channels=in_channels, 30 | share_key_query=False, 31 | query_downsample=None, 32 | key_downsample=None, 33 | key_query_num_convs=2, 34 | key_query_norm=True, 35 | value_out_num_convs=1, 36 | value_out_norm=False, 37 | matmul_norm=True, 38 | with_out=False, 39 | conv_cfg=conv_cfg, 40 | norm_cfg=norm_cfg, 41 | act_cfg=act_cfg) 42 | 43 | self.output_project = self.build_project( 44 | in_channels, 45 | in_channels, 46 | num_convs=1, 47 | use_conv_module=True, 48 | conv_cfg=conv_cfg, 49 | norm_cfg=norm_cfg, 50 | act_cfg=act_cfg) 51 | 52 | def forward(self, x): 53 | """Forward function.""" 54 | context = super(SelfAttentionBlock, self).forward(x, x) 55 | return self.output_project(context) 56 | 57 | 58 | @HEADS.register_module() 59 | class ISAHead(BaseDecodeHead): 60 | """Interlaced Sparse Self-Attention for Semantic Segmentation. 61 | 62 | This head is the implementation of `ISA 63 | `_. 64 | 65 | Args: 66 | isa_channels (int): The channels of ISA Module. 67 | down_factor (tuple[int]): The local group size of ISA. 68 | """ 69 | 70 | def __init__(self, isa_channels, down_factor=(8, 8), **kwargs): 71 | super(ISAHead, self).__init__(**kwargs) 72 | self.down_factor = down_factor 73 | 74 | self.in_conv = ConvModule( 75 | self.in_channels, 76 | self.channels, 77 | 3, 78 | padding=1, 79 | conv_cfg=self.conv_cfg, 80 | norm_cfg=self.norm_cfg, 81 | act_cfg=self.act_cfg) 82 | self.global_relation = SelfAttentionBlock( 83 | self.channels, 84 | isa_channels, 85 | conv_cfg=self.conv_cfg, 86 | norm_cfg=self.norm_cfg, 87 | act_cfg=self.act_cfg) 88 | self.local_relation = SelfAttentionBlock( 89 | self.channels, 90 | isa_channels, 91 | conv_cfg=self.conv_cfg, 92 | norm_cfg=self.norm_cfg, 93 | act_cfg=self.act_cfg) 94 | self.out_conv = ConvModule( 95 | self.channels * 2, 96 | self.channels, 97 | 1, 98 | conv_cfg=self.conv_cfg, 99 | norm_cfg=self.norm_cfg, 100 | act_cfg=self.act_cfg) 101 | 102 | def forward(self, inputs): 103 | """Forward function.""" 104 | x_ = self._transform_inputs(inputs) 105 | x = self.in_conv(x_) 106 | residual = x 107 | 108 | n, c, h, w = x.size() 109 | loc_h, loc_w = self.down_factor # size of local group in H- and W-axes 110 | glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w) 111 | pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w 112 | if pad_h > 0 or pad_w > 0: # pad if the size is not divisible 113 | padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, 114 | pad_h - pad_h // 2) 115 | x = F.pad(x, padding) 116 | 117 | # global relation 118 | x = x.view(n, c, glb_h, loc_h, glb_w, loc_w) 119 | # do permutation to gather global group 120 | x = x.permute(0, 3, 5, 1, 2, 4) # (n, loc_h, loc_w, c, glb_h, glb_w) 121 | x = x.reshape(-1, c, glb_h, glb_w) 122 | # apply attention within each global group 123 | x = self.global_relation(x) # (n * loc_h * loc_w, c, glb_h, glb_w) 124 | 125 | # local relation 126 | x = x.view(n, loc_h, loc_w, c, glb_h, glb_w) 127 | # do permutation to gather local group 128 | x = x.permute(0, 4, 5, 3, 1, 2) # (n, glb_h, glb_w, c, loc_h, loc_w) 129 | x = x.reshape(-1, c, loc_h, loc_w) 130 | # apply attention within each local group 131 | x = self.local_relation(x) # (n * glb_h * glb_w, c, loc_h, loc_w) 132 | 133 | # permute each pixel back to its original position 134 | x = x.view(n, glb_h, glb_w, c, loc_h, loc_w) 135 | x 
= x.permute(0, 3, 1, 4, 2, 5) # (n, c, glb_h, loc_h, glb_w, loc_w) 136 | x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w) 137 | if pad_h > 0 or pad_w > 0: # remove padding 138 | x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w] 139 | 140 | x = self.out_conv(torch.cat([x, residual], dim=1)) 141 | out = self.cls_seg(x) 142 | 143 | return out 144 | -------------------------------------------------------------------------------- /marinext/mmseg/models/necks/jpu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 5 | from mmcv.runner import BaseModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import NECKS 9 | 10 | 11 | @NECKS.register_module() 12 | class JPU(BaseModule): 13 | """FastFCN: Rethinking Dilated Convolution in the Backbone 14 | for Semantic Segmentation. 15 | 16 | This Joint Pyramid Upsampling (JPU) neck is the implementation of 17 | `FastFCN `_. 18 | 19 | Args: 20 | in_channels (Tuple[int], optional): The number of input channels 21 | for each convolution operations before upsampling. 22 | Default: (512, 1024, 2048). 23 | mid_channels (int): The number of output channels of JPU. 24 | Default: 512. 25 | start_level (int): Index of the start input backbone level used to 26 | build the feature pyramid. Default: 0. 27 | end_level (int): Index of the end input backbone level (exclusive) to 28 | build the feature pyramid. Default: -1, which means the last level. 29 | dilations (tuple[int]): Dilation rate of each Depthwise 30 | Separable ConvModule. Default: (1, 2, 4, 8). 31 | align_corners (bool, optional): The align_corners argument of 32 | resize operation. Default: False. 33 | conv_cfg (dict | None): Config of conv layers. 34 | Default: None. 35 | norm_cfg (dict | None): Config of norm layers. 36 | Default: dict(type='BN'). 37 | act_cfg (dict): Config of activation layers. 38 | Default: dict(type='ReLU'). 39 | init_cfg (dict or list[dict], optional): Initialization config dict. 40 | Default: None. 
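
    Example (an illustrative sketch with small, arbitrary channel counts;
    the last output is the JPU feature, the earlier ones pass through):
        >>> import torch
        >>> neck = JPU(in_channels=(64, 128), mid_channels=32,
        ...            dilations=(1, 2))
        >>> inputs = [torch.rand(1, 64, 16, 16), torch.rand(1, 128, 8, 8)]
        >>> [tuple(o.shape) for o in neck(inputs)]
        [(1, 64, 16, 16), (1, 64, 16, 16)]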
41 | """ 42 | 43 | def __init__(self, 44 | in_channels=(512, 1024, 2048), 45 | mid_channels=512, 46 | start_level=0, 47 | end_level=-1, 48 | dilations=(1, 2, 4, 8), 49 | align_corners=False, 50 | conv_cfg=None, 51 | norm_cfg=dict(type='BN'), 52 | act_cfg=dict(type='ReLU'), 53 | init_cfg=None): 54 | super(JPU, self).__init__(init_cfg=init_cfg) 55 | assert isinstance(in_channels, tuple) 56 | assert isinstance(dilations, tuple) 57 | self.in_channels = in_channels 58 | self.mid_channels = mid_channels 59 | self.start_level = start_level 60 | self.num_ins = len(in_channels) 61 | if end_level == -1: 62 | self.backbone_end_level = self.num_ins 63 | else: 64 | self.backbone_end_level = end_level 65 | assert end_level <= len(in_channels) 66 | 67 | self.dilations = dilations 68 | self.align_corners = align_corners 69 | 70 | self.conv_layers = nn.ModuleList() 71 | self.dilation_layers = nn.ModuleList() 72 | for i in range(self.start_level, self.backbone_end_level): 73 | conv_layer = nn.Sequential( 74 | ConvModule( 75 | self.in_channels[i], 76 | self.mid_channels, 77 | kernel_size=3, 78 | padding=1, 79 | conv_cfg=conv_cfg, 80 | norm_cfg=norm_cfg, 81 | act_cfg=act_cfg)) 82 | self.conv_layers.append(conv_layer) 83 | for i in range(len(dilations)): 84 | dilation_layer = nn.Sequential( 85 | DepthwiseSeparableConvModule( 86 | in_channels=(self.backbone_end_level - self.start_level) * 87 | self.mid_channels, 88 | out_channels=self.mid_channels, 89 | kernel_size=3, 90 | stride=1, 91 | padding=dilations[i], 92 | dilation=dilations[i], 93 | dw_norm_cfg=norm_cfg, 94 | dw_act_cfg=None, 95 | pw_norm_cfg=norm_cfg, 96 | pw_act_cfg=act_cfg)) 97 | self.dilation_layers.append(dilation_layer) 98 | 99 | def forward(self, inputs): 100 | """Forward function.""" 101 | assert len(inputs) == len(self.in_channels), 'Length of inputs must \ 102 | be the same with self.in_channels!' 103 | 104 | feats = [ 105 | self.conv_layers[i - self.start_level](inputs[i]) 106 | for i in range(self.start_level, self.backbone_end_level) 107 | ] 108 | 109 | h, w = feats[0].shape[2:] 110 | for i in range(1, len(feats)): 111 | feats[i] = resize( 112 | feats[i], 113 | size=(h, w), 114 | mode='bilinear', 115 | align_corners=self.align_corners) 116 | 117 | feat = torch.cat(feats, dim=1) 118 | concat_feat = torch.cat([ 119 | self.dilation_layers[i](feat) for i in range(len(self.dilations)) 120 | ], 121 | dim=1) 122 | 123 | outs = [] 124 | 125 | # Default: outs[2] is the output of JPU for decoder head, outs[1] is 126 | # the feature map from backbone for auxiliary head. Additionally, 127 | # outs[0] can also be used for auxiliary head. 128 | for i in range(self.start_level, self.backbone_end_level - 1): 129 | outs.append(inputs[i]) 130 | outs.append(concat_feat) 131 | return tuple(outs) 132 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/dm_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer 6 | 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | class DCM(nn.Module): 12 | """Dynamic Convolutional Module used in DMNet. 13 | 14 | Args: 15 | filter_size (int): The filter size of generated convolution kernel 16 | used in Dynamic Convolutional Module. 
17 | fusion (bool): Add one conv to fuse DCM output feature. 18 | in_channels (int): Input channels. 19 | channels (int): Channels after modules, before conv_seg. 20 | conv_cfg (dict | None): Config of conv layers. 21 | norm_cfg (dict | None): Config of norm layers. 22 | act_cfg (dict): Config of activation layers. 23 | """ 24 | 25 | def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, 26 | norm_cfg, act_cfg): 27 | super(DCM, self).__init__() 28 | self.filter_size = filter_size 29 | self.fusion = fusion 30 | self.in_channels = in_channels 31 | self.channels = channels 32 | self.conv_cfg = conv_cfg 33 | self.norm_cfg = norm_cfg 34 | self.act_cfg = act_cfg 35 | self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, 36 | 0) 37 | 38 | self.input_redu_conv = ConvModule( 39 | self.in_channels, 40 | self.channels, 41 | 1, 42 | conv_cfg=self.conv_cfg, 43 | norm_cfg=self.norm_cfg, 44 | act_cfg=self.act_cfg) 45 | 46 | if self.norm_cfg is not None: 47 | self.norm = build_norm_layer(self.norm_cfg, self.channels)[1] 48 | else: 49 | self.norm = None 50 | self.activate = build_activation_layer(self.act_cfg) 51 | 52 | if self.fusion: 53 | self.fusion_conv = ConvModule( 54 | self.channels, 55 | self.channels, 56 | 1, 57 | conv_cfg=self.conv_cfg, 58 | norm_cfg=self.norm_cfg, 59 | act_cfg=self.act_cfg) 60 | 61 | def forward(self, x): 62 | """Forward function.""" 63 | generated_filter = self.filter_gen_conv( 64 | F.adaptive_avg_pool2d(x, self.filter_size)) 65 | x = self.input_redu_conv(x) 66 | b, c, h, w = x.shape 67 | # [1, b * c, h, w], c = self.channels 68 | x = x.view(1, b * c, h, w) 69 | # [b * c, 1, filter_size, filter_size] 70 | generated_filter = generated_filter.view(b * c, 1, self.filter_size, 71 | self.filter_size) 72 | pad = (self.filter_size - 1) // 2 73 | if (self.filter_size - 1) % 2 == 0: 74 | p2d = (pad, pad, pad, pad) 75 | else: 76 | p2d = (pad + 1, pad, pad + 1, pad) 77 | x = F.pad(input=x, pad=p2d, mode='constant', value=0) 78 | # [1, b * c, h, w] 79 | output = F.conv2d(input=x, weight=generated_filter, groups=b * c) 80 | # [b, c, h, w] 81 | output = output.view(b, c, h, w) 82 | if self.norm is not None: 83 | output = self.norm(output) 84 | output = self.activate(output) 85 | 86 | if self.fusion: 87 | output = self.fusion_conv(output) 88 | 89 | return output 90 | 91 | 92 | @HEADS.register_module() 93 | class DMHead(BaseDecodeHead): 94 | """Dynamic Multi-scale Filters for Semantic Segmentation. 95 | 96 | This head is the implementation of 97 | `DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_ICCV_2019_paper.pdf>`_. 100 | 101 | Args: 102 | filter_sizes (tuple[int]): The size of generated convolutional filters 103 | used in Dynamic Convolutional Module. Default: (1, 3, 5, 7). 104 | fusion (bool): Add one conv to fuse DCM output feature.
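    Example (an illustrative sketch, not part of the original file; assumes the standard BaseDecodeHead keyword arguments):
        >>> import torch
        >>> from mmseg.models.decode_heads import DMHead
        >>> head = DMHead(filter_sizes=(1, 3), fusion=False,
        ...               in_channels=64, channels=32, num_classes=5)
        >>> output = head([torch.rand(1, 64, 45, 45)])
        >>> tuple(output.shape)
        (1, 5, 45, 45)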
105 | """ 106 | 107 | def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): 108 | super(DMHead, self).__init__(**kwargs) 109 | assert isinstance(filter_sizes, (list, tuple)) 110 | self.filter_sizes = filter_sizes 111 | self.fusion = fusion 112 | dcm_modules = [] 113 | for filter_size in self.filter_sizes: 114 | dcm_modules.append( 115 | DCM(filter_size, 116 | self.fusion, 117 | self.in_channels, 118 | self.channels, 119 | conv_cfg=self.conv_cfg, 120 | norm_cfg=self.norm_cfg, 121 | act_cfg=self.act_cfg)) 122 | self.dcm_modules = nn.ModuleList(dcm_modules) 123 | self.bottleneck = ConvModule( 124 | self.in_channels + len(filter_sizes) * self.channels, 125 | self.channels, 126 | 3, 127 | padding=1, 128 | conv_cfg=self.conv_cfg, 129 | norm_cfg=self.norm_cfg, 130 | act_cfg=self.act_cfg) 131 | 132 | def forward(self, inputs): 133 | """Forward function.""" 134 | x = self._transform_inputs(inputs) 135 | dcm_outs = [x] 136 | for dcm_module in self.dcm_modules: 137 | dcm_outs.append(dcm_module(x)) 138 | dcm_outs = torch.cat(dcm_outs, dim=1) 139 | output = self.bottleneck(dcm_outs) 140 | output = self.cls_seg(output) 141 | return output 142 | -------------------------------------------------------------------------------- /marinext/mmseg/models/backbones/resnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | from mmcv.cnn import build_conv_layer, build_norm_layer 5 | 6 | from ..builder import BACKBONES 7 | from ..utils import ResLayer 8 | from .resnet import Bottleneck as _Bottleneck 9 | from .resnet import ResNet 10 | 11 | 12 | class Bottleneck(_Bottleneck): 13 | """Bottleneck block for ResNeXt. 14 | 15 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is 16 | "caffe", the stride-two layer is the first 1x1 conv layer. 
17 | """ 18 | 19 | def __init__(self, 20 | inplanes, 21 | planes, 22 | groups=1, 23 | base_width=4, 24 | base_channels=64, 25 | **kwargs): 26 | super(Bottleneck, self).__init__(inplanes, planes, **kwargs) 27 | 28 | if groups == 1: 29 | width = self.planes 30 | else: 31 | width = math.floor(self.planes * 32 | (base_width / base_channels)) * groups 33 | 34 | self.norm1_name, norm1 = build_norm_layer( 35 | self.norm_cfg, width, postfix=1) 36 | self.norm2_name, norm2 = build_norm_layer( 37 | self.norm_cfg, width, postfix=2) 38 | self.norm3_name, norm3 = build_norm_layer( 39 | self.norm_cfg, self.planes * self.expansion, postfix=3) 40 | 41 | self.conv1 = build_conv_layer( 42 | self.conv_cfg, 43 | self.inplanes, 44 | width, 45 | kernel_size=1, 46 | stride=self.conv1_stride, 47 | bias=False) 48 | self.add_module(self.norm1_name, norm1) 49 | fallback_on_stride = False 50 | self.with_modulated_dcn = False 51 | if self.with_dcn: 52 | fallback_on_stride = self.dcn.pop('fallback_on_stride', False) 53 | if not self.with_dcn or fallback_on_stride: 54 | self.conv2 = build_conv_layer( 55 | self.conv_cfg, 56 | width, 57 | width, 58 | kernel_size=3, 59 | stride=self.conv2_stride, 60 | padding=self.dilation, 61 | dilation=self.dilation, 62 | groups=groups, 63 | bias=False) 64 | else: 65 | assert self.conv_cfg is None, 'conv_cfg must be None for DCN' 66 | self.conv2 = build_conv_layer( 67 | self.dcn, 68 | width, 69 | width, 70 | kernel_size=3, 71 | stride=self.conv2_stride, 72 | padding=self.dilation, 73 | dilation=self.dilation, 74 | groups=groups, 75 | bias=False) 76 | 77 | self.add_module(self.norm2_name, norm2) 78 | self.conv3 = build_conv_layer( 79 | self.conv_cfg, 80 | width, 81 | self.planes * self.expansion, 82 | kernel_size=1, 83 | bias=False) 84 | self.add_module(self.norm3_name, norm3) 85 | 86 | 87 | @BACKBONES.register_module() 88 | class ResNeXt(ResNet): 89 | """ResNeXt backbone. 90 | 91 | This backbone is the implementation of `Aggregated 92 | Residual Transformations for Deep Neural 93 | Networks `_. 94 | 95 | Args: 96 | depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. 97 | in_channels (int): Number of input image channels. Normally 3. 98 | num_stages (int): Resnet stages, normally 4. 99 | groups (int): Group of resnext. 100 | base_width (int): Base width of resnext. 101 | strides (Sequence[int]): Strides of the first block of each stage. 102 | dilations (Sequence[int]): Dilation of each stage. 103 | out_indices (Sequence[int]): Output from which stages. 104 | style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two 105 | layer is the 3x3 conv layer, otherwise the stride-two layer is 106 | the first 1x1 conv layer. 107 | frozen_stages (int): Stages to be frozen (all param fixed). -1 means 108 | not freezing any parameters. 109 | norm_cfg (dict): dictionary to construct and config norm layer. 110 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 111 | freeze running stats (mean and var). Note: Effect on Batch Norm 112 | and its variants only. 113 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 114 | memory while slowing down the training speed. 115 | zero_init_residual (bool): whether to use zero init for last norm layer 116 | in resblocks to let them behave as identity. 
117 | 118 | Example: 119 | >>> from mmseg.models import ResNeXt 120 | >>> import torch 121 | >>> self = ResNeXt(depth=50) 122 | >>> self.eval() 123 | >>> inputs = torch.rand(1, 3, 32, 32) 124 | >>> level_outputs = self.forward(inputs) 125 | >>> for level_out in level_outputs: 126 | ... print(tuple(level_out.shape)) 127 | (1, 256, 8, 8) 128 | (1, 512, 4, 4) 129 | (1, 1024, 2, 2) 130 | (1, 2048, 1, 1) 131 | """ 132 | 133 | arch_settings = { 134 | 50: (Bottleneck, (3, 4, 6, 3)), 135 | 101: (Bottleneck, (3, 4, 23, 3)), 136 | 152: (Bottleneck, (3, 8, 36, 3)) 137 | } 138 | 139 | def __init__(self, groups=1, base_width=4, **kwargs): 140 | self.groups = groups 141 | self.base_width = base_width 142 | super(ResNeXt, self).__init__(**kwargs) 143 | 144 | def make_res_layer(self, **kwargs): 145 | """Pack all blocks in a stage into a ``ResLayer``""" 146 | return ResLayer( 147 | groups=self.groups, 148 | base_width=self.base_width, 149 | base_channels=self.base_channels, 150 | **kwargs) 151 | -------------------------------------------------------------------------------- /marinext/mmseg/models/necks/ic_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn.functional as F 3 | from mmcv.cnn import ConvModule 4 | from mmcv.runner import BaseModule 5 | 6 | from mmseg.ops import resize 7 | from ..builder import NECKS 8 | 9 | 10 | class CascadeFeatureFusion(BaseModule): 11 | """Cascade Feature Fusion Unit in ICNet. 12 | 13 | Args: 14 | low_channels (int): The number of input channels for 15 | low resolution feature map. 16 | high_channels (int): The number of input channels for 17 | high resolution feature map. 18 | out_channels (int): The number of output channels. 19 | conv_cfg (dict): Dictionary to construct and config conv layer. 20 | Default: None. 21 | norm_cfg (dict): Dictionary to construct and config norm layer. 22 | Default: dict(type='BN'). 23 | act_cfg (dict): Dictionary to construct and config act layer. 24 | Default: dict(type='ReLU'). 25 | align_corners (bool): align_corners argument of F.interpolate. 26 | Default: False. 27 | init_cfg (dict or list[dict], optional): Initialization config dict. 28 | Default: None. 29 | 30 | Returns: 31 | x (Tensor): The output tensor of shape (N, out_channels, H, W). 32 | x_low (Tensor): The output tensor of shape (N, out_channels, H, W) 33 | for Cascade Label Guidance in auxiliary heads. 34 | """ 35 | 36 | def __init__(self, 37 | low_channels, 38 | high_channels, 39 | out_channels, 40 | conv_cfg=None, 41 | norm_cfg=dict(type='BN'), 42 | act_cfg=dict(type='ReLU'), 43 | align_corners=False, 44 | init_cfg=None): 45 | super(CascadeFeatureFusion, self).__init__(init_cfg=init_cfg) 46 | self.align_corners = align_corners 47 | self.conv_low = ConvModule( 48 | low_channels, 49 | out_channels, 50 | 3, 51 | padding=2, 52 | dilation=2, 53 | conv_cfg=conv_cfg, 54 | norm_cfg=norm_cfg, 55 | act_cfg=act_cfg) 56 | self.conv_high = ConvModule( 57 | high_channels, 58 | out_channels, 59 | 1, 60 | conv_cfg=conv_cfg, 61 | norm_cfg=norm_cfg, 62 | act_cfg=act_cfg) 63 | 64 | def forward(self, x_low, x_high): 65 | x_low = resize( 66 | x_low, 67 | size=x_high.size()[2:], 68 | mode='bilinear', 69 | align_corners=self.align_corners) 70 | # Note: Different from the original paper, `x_low` is passed through 71 | # `self.conv_low` rather than another 1x1 conv classifier 72 | # before being used for auxiliary head.
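        # After `conv_low`, `x_low` also serves as the Cascade Label Guidance
        # feature: it is returned alongside the fused map below so that an
        # auxiliary head can supervise it directly.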
x_low = self.conv_low(x_low) 74 | x_high = self.conv_high(x_high) 75 | x = x_low + x_high 76 | x = F.relu(x, inplace=True) 77 | return x, x_low 78 | 79 | 80 | @NECKS.register_module() 81 | class ICNeck(BaseModule): 82 | """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. 83 | 84 | This head is the implementation of `ICHead 85 | <https://arxiv.org/abs/1704.08545>`_. 86 | 87 | Args: 88 | in_channels (tuple[int]): The number of input channels for the three 89 | feature maps. Default: (64, 256, 256). 89 | out_channels (int): The number of output feature channels. 90 | Default: 128. 91 | conv_cfg (dict): Dictionary to construct and config conv layer. 92 | Default: None. 93 | norm_cfg (dict): Dictionary to construct and config norm layer. 94 | Default: dict(type='BN'). 95 | act_cfg (dict): Dictionary to construct and config act layer. 96 | Default: dict(type='ReLU'). 97 | align_corners (bool): align_corners argument of F.interpolate. 98 | Default: False. 99 | init_cfg (dict or list[dict], optional): Initialization config dict. 100 | Default: None. 101 | """ 102 | 103 | def __init__(self, 104 | in_channels=(64, 256, 256), 105 | out_channels=128, 106 | conv_cfg=None, 107 | norm_cfg=dict(type='BN'), 108 | act_cfg=dict(type='ReLU'), 109 | align_corners=False, 110 | init_cfg=None): 111 | super(ICNeck, self).__init__(init_cfg=init_cfg) 112 | assert len(in_channels) == 3, 'Length of input channels \ 113 | must be 3!' 114 | 115 | self.in_channels = in_channels 116 | self.out_channels = out_channels 117 | self.conv_cfg = conv_cfg 118 | self.norm_cfg = norm_cfg 119 | self.act_cfg = act_cfg 120 | self.align_corners = align_corners 121 | self.cff_24 = CascadeFeatureFusion( 122 | self.in_channels[2], 123 | self.in_channels[1], 124 | self.out_channels, 125 | conv_cfg=self.conv_cfg, 126 | norm_cfg=self.norm_cfg, 127 | act_cfg=self.act_cfg, 128 | align_corners=self.align_corners) 129 | 130 | self.cff_12 = CascadeFeatureFusion( 131 | self.out_channels, 132 | self.in_channels[0], 133 | self.out_channels, 134 | conv_cfg=self.conv_cfg, 135 | norm_cfg=self.norm_cfg, 136 | act_cfg=self.act_cfg, 137 | align_corners=self.align_corners) 138 | 139 | def forward(self, inputs): 140 | assert len(inputs) == 3, 'Length of input feature \ 141 | maps must be 3!' 142 | 143 | x_sub1, x_sub2, x_sub4 = inputs 144 | x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2) 145 | x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1) 146 | # Note: `x_cff_12` is used for decode_head, 147 | # `x_24` and `x_12` are used for auxiliary head. 148 | return x_24, x_12, x_cff_12 149 | -------------------------------------------------------------------------------- /utils/spectral_extraction.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Author: Ioannis Kakogeorgiou 4 | Email: gkakogeorgiou@gmail.com 5 | Python Version: 3.7.10 6 | Description: spectral_extraction.py extracts the spectral signatures, indices or texture features 7 | into an HDF5 table for analysis and for pixel-level semantic segmentation with a 8 | random forest classifier.
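Usage (illustrative; the path is a placeholder for the MADOS patches folder):
    python spectral_extraction.py --path /path/to/MADOS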
9 | ''' 10 | 11 | import os 12 | import sys 13 | import argparse 14 | import numpy as np 15 | import pandas as pd 16 | from glob import glob 17 | from tqdm import tqdm 18 | from osgeo import gdal 19 | from os.path import dirname as up 20 | from assets import s2_mapping, mados_cat_mapping, conf_mapping 21 | 22 | rev_cat_mapping = {v:k for k,v in mados_cat_mapping.items()} 23 | rev_conf_mapping = {v:k for k,v in conf_mapping.items()} 24 | 25 | def ImageToDataframe(RefImage, cols_mapping = {}, keep_annotated = True, prefix = '_rhorc_'): 26 | # This function transforms an image with the associated class and 27 | # confidence tif files (_cl.tif and _conf.tif) into a dataframe 28 | 29 | # Read patch 30 | ds = gdal.Open(RefImage) 31 | IM = np.copy(ds.ReadAsArray()) 32 | 33 | # Read associated class patch 34 | ds_cl = gdal.Open(RefImage.replace(prefix , '_cl_')) 35 | IM_cl = np.copy(ds_cl.ReadAsArray())[np.newaxis, :, :] 36 | 37 | # Read associated confidence level patch 38 | ds_conf = gdal.Open(RefImage.replace(prefix , '_conf_')) 39 | IM_conf = np.copy(ds_conf.ReadAsArray())[np.newaxis, :, :] 40 | 41 | # Read associated _rep_ patch 42 | ds_rep = gdal.Open(RefImage.replace(prefix , '_rep_')) 43 | IM_rep = np.copy(ds_rep.ReadAsArray())[np.newaxis, :, :] 44 | 45 | # Stack all these together 46 | IM_T = np.moveaxis(np.concatenate([IM, IM_cl, IM_conf, IM_rep], axis = 0), 0, -1) 47 | 48 | bands = IM_T.shape[-1] 49 | IM_VECT = IM_T.reshape([-1,bands]) 50 | 51 | IM_VECT = IM_VECT[IM_VECT[:,-3] > 0] # Keep only annotated pixels (non-zero class) 52 | 53 | if cols_mapping: 54 | IM_df = pd.DataFrame({k:IM_VECT[:,v] for k, v in cols_mapping.items()}) 55 | else: 56 | IM_df = pd.DataFrame(IM_VECT) 57 | 58 | ds = None 59 | ds_conf = None 60 | ds_cl = None 61 | ds_rep = None 62 | 63 | return IM_df 64 | 65 | def main(options): 66 | 67 | mapping = s2_mapping 68 | h5_prefix = 'dataset' 69 | prefix = '_rhorc_' 70 | 71 | # Get patch files, excluding the associated _cl, _conf and _rep files 72 | patches = glob(os.path.join(options['path'],'*','*.tif')) 73 | 74 | patches = [p for p in patches if ('_cl_' not in p) and ('_conf_' not in p) and ('_rep_' not in p)] 75 | 76 | root_path = os.path.dirname(options['path']) 77 | 78 | # Read splits 79 | X_train = np.genfromtxt(os.path.join(options['path'], 'splits','train_X.txt'),dtype='str') 80 | 81 | X_val = np.genfromtxt(os.path.join(options['path'], 'splits','val_X.txt'),dtype='str') 82 | 83 | X_test = np.genfromtxt(os.path.join(options['path'], 'splits','test_X.txt'),dtype='str') 84 | 85 | dataset_name = os.path.join(root_path, h5_prefix + '_nonindex.h5') 86 | hdf = pd.HDFStore(dataset_name, mode = 'w') 87 | 88 | # For each patch extract the spectral signatures and store them 89 | for im_name in tqdm(patches): 90 | 91 | # Get date_tile_image info 92 | 93 | splited_name = os.path.basename(im_name).split('.tif')[0].split('_') 94 | img_name = '_'.join(splited_name[:-3]) + '_' + splited_name[-1] 95 | 96 | # Generate Dataframe from Image 97 | if img_name in X_train: 98 | split = 'Train' 99 | temp = ImageToDataframe(im_name, mapping, prefix = prefix) 100 | elif img_name in X_val: 101 | split = 'Validation' 102 | temp = ImageToDataframe(im_name, mapping, prefix = prefix) 103 | elif img_name in X_test: 104 | split = 'Test' 105 | temp = ImageToDataframe(im_name, mapping, prefix = prefix) 106 | else: 107 | raise AssertionError("Image not in train,val,test splits") 108 | 109 | temp['Scene'] = os.path.basename(im_name).split('_')[1] 110 | temp['Crop'] =
os.path.basename(im_name).split('_')[-1].replace('.tif','') 111 | 112 | # Store data 113 | hdf.append(split, temp, format='table', data_columns=True, min_itemsize={'Class':27, 114 | 'Confidence':8, 115 | 'Crop':3, 116 | 'Scene':5}) 117 | 118 | hdf.close() 119 | 120 | # Read the stored file and fix an indexing problem (indexes were not incremental and unique) 121 | hdf_old = pd.HDFStore(dataset_name, mode = 'r') 122 | 123 | df_train = hdf_old['Train'].copy(deep=True) 124 | df_val = hdf_old['Validation'].copy(deep=True) 125 | df_test = hdf_old['Test'].copy(deep=True) 126 | 127 | df_train.reset_index(drop = True, inplace = True) 128 | df_val.reset_index(drop = True, inplace = True) 129 | df_test.reset_index(drop = True, inplace = True) 130 | 131 | hdf_old.close() 132 | 133 | # Store the fixed table to a new dataset file 134 | dataset_name_fixed = os.path.join(root_path, h5_prefix+'.h5') 135 | 136 | df_train.to_hdf(dataset_name_fixed, key='Train', mode='a', format='table', data_columns=True) 137 | df_val.to_hdf(dataset_name_fixed, key='Validation', mode='a', format='table', data_columns=True) 138 | df_test.to_hdf(dataset_name_fixed, key='Test', mode='a', format='table', data_columns=True) 139 | 140 | os.remove(dataset_name) 141 | 142 | if __name__ == "__main__": 143 | 144 | parser = argparse.ArgumentParser() 145 | 146 | # Options 147 | parser.add_argument('--path', default='', help='Path to Images') 148 | 149 | args = parser.parse_args() 150 | options = vars(args) # convert to ordinary dict 151 | 152 | main(options) 153 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/apc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | class ACM(nn.Module): 13 | """Adaptive Context Module used in APCNet. 14 | 15 | Args: 16 | pool_scale (int): Pooling scale used in Adaptive Context 17 | Module to extract region features. 18 | fusion (bool): Add one conv to fuse residual feature. 19 | in_channels (int): Input channels. 20 | channels (int): Channels after modules, before conv_seg. 21 | conv_cfg (dict | None): Config of conv layers. 22 | norm_cfg (dict | None): Config of norm layers. 23 | act_cfg (dict): Config of activation layers. 
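    Example (an illustrative sketch, not part of the original file):
        >>> import torch
        >>> from mmseg.models.decode_heads.apc_head import ACM
        >>> acm = ACM(pool_scale=2, fusion=True, in_channels=64, channels=32,
        ...           conv_cfg=None, norm_cfg=dict(type='BN'),
        ...           act_cfg=dict(type='ReLU'))
        >>> out = acm(torch.rand(2, 64, 40, 40))
        >>> tuple(out.shape)
        (2, 32, 40, 40)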
24 | """ 25 | 26 | def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, 27 | norm_cfg, act_cfg): 28 | super(ACM, self).__init__() 29 | self.pool_scale = pool_scale 30 | self.fusion = fusion 31 | self.in_channels = in_channels 32 | self.channels = channels 33 | self.conv_cfg = conv_cfg 34 | self.norm_cfg = norm_cfg 35 | self.act_cfg = act_cfg 36 | self.pooled_redu_conv = ConvModule( 37 | self.in_channels, 38 | self.channels, 39 | 1, 40 | conv_cfg=self.conv_cfg, 41 | norm_cfg=self.norm_cfg, 42 | act_cfg=self.act_cfg) 43 | 44 | self.input_redu_conv = ConvModule( 45 | self.in_channels, 46 | self.channels, 47 | 1, 48 | conv_cfg=self.conv_cfg, 49 | norm_cfg=self.norm_cfg, 50 | act_cfg=self.act_cfg) 51 | 52 | self.global_info = ConvModule( 53 | self.channels, 54 | self.channels, 55 | 1, 56 | conv_cfg=self.conv_cfg, 57 | norm_cfg=self.norm_cfg, 58 | act_cfg=self.act_cfg) 59 | 60 | self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) 61 | 62 | self.residual_conv = ConvModule( 63 | self.channels, 64 | self.channels, 65 | 1, 66 | conv_cfg=self.conv_cfg, 67 | norm_cfg=self.norm_cfg, 68 | act_cfg=self.act_cfg) 69 | 70 | if self.fusion: 71 | self.fusion_conv = ConvModule( 72 | self.channels, 73 | self.channels, 74 | 1, 75 | conv_cfg=self.conv_cfg, 76 | norm_cfg=self.norm_cfg, 77 | act_cfg=self.act_cfg) 78 | 79 | def forward(self, x): 80 | """Forward function.""" 81 | pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) 82 | # [batch_size, channels, h, w] 83 | x = self.input_redu_conv(x) 84 | # [batch_size, channels, pool_scale, pool_scale] 85 | pooled_x = self.pooled_redu_conv(pooled_x) 86 | batch_size = x.size(0) 87 | # [batch_size, pool_scale * pool_scale, channels] 88 | pooled_x = pooled_x.view(batch_size, self.channels, 89 | -1).permute(0, 2, 1).contiguous() 90 | # [batch_size, h * w, pool_scale * pool_scale] 91 | affinity_matrix = self.gla(x + resize( 92 | self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) 93 | ).permute(0, 2, 3, 1).reshape( 94 | batch_size, -1, self.pool_scale**2) 95 | affinity_matrix = F.sigmoid(affinity_matrix) 96 | # [batch_size, h * w, channels] 97 | z_out = torch.matmul(affinity_matrix, pooled_x) 98 | # [batch_size, channels, h * w] 99 | z_out = z_out.permute(0, 2, 1).contiguous() 100 | # [batch_size, channels, h, w] 101 | z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) 102 | z_out = self.residual_conv(z_out) 103 | z_out = F.relu(z_out + x) 104 | if self.fusion: 105 | z_out = self.fusion_conv(z_out) 106 | 107 | return z_out 108 | 109 | 110 | @HEADS.register_module() 111 | class APCHead(BaseDecodeHead): 112 | """Adaptive Pyramid Context Network for Semantic Segmentation. 113 | 114 | This head is the implementation of 115 | `APCNet `_. 118 | 119 | Args: 120 | pool_scales (tuple[int]): Pooling scales used in Adaptive Context 121 | Module. Default: (1, 2, 3, 6). 122 | fusion (bool): Add one conv to fuse residual feature. 
123 | """ 124 | 125 | def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): 126 | super(APCHead, self).__init__(**kwargs) 127 | assert isinstance(pool_scales, (list, tuple)) 128 | self.pool_scales = pool_scales 129 | self.fusion = fusion 130 | acm_modules = [] 131 | for pool_scale in self.pool_scales: 132 | acm_modules.append( 133 | ACM(pool_scale, 134 | self.fusion, 135 | self.in_channels, 136 | self.channels, 137 | conv_cfg=self.conv_cfg, 138 | norm_cfg=self.norm_cfg, 139 | act_cfg=self.act_cfg)) 140 | self.acm_modules = nn.ModuleList(acm_modules) 141 | self.bottleneck = ConvModule( 142 | self.in_channels + len(pool_scales) * self.channels, 143 | self.channels, 144 | 3, 145 | padding=1, 146 | conv_cfg=self.conv_cfg, 147 | norm_cfg=self.norm_cfg, 148 | act_cfg=self.act_cfg) 149 | 150 | def forward(self, inputs): 151 | """Forward function.""" 152 | x = self._transform_inputs(inputs) 153 | acm_outs = [x] 154 | for acm_module in self.acm_modules: 155 | acm_outs.append(acm_module(x)) 156 | acm_outs = torch.cat(acm_outs, dim=1) 157 | output = self.bottleneck(acm_outs) 158 | output = self.cls_seg(output) 159 | return output 160 | -------------------------------------------------------------------------------- /marinext/mmseg/models/decode_heads/da_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | from mmcv.cnn import ConvModule, Scale 5 | from torch import nn 6 | 7 | from mmseg.core import add_prefix 8 | from ..builder import HEADS 9 | from ..utils import SelfAttentionBlock as _SelfAttentionBlock 10 | from .decode_head import BaseDecodeHead 11 | 12 | 13 | class PAM(_SelfAttentionBlock): 14 | """Position Attention Module (PAM) 15 | 16 | Args: 17 | in_channels (int): Input channels of key/query feature. 18 | channels (int): Output channels of key/query transform. 
19 | """ 20 | 21 | def __init__(self, in_channels, channels): 22 | super(PAM, self).__init__( 23 | key_in_channels=in_channels, 24 | query_in_channels=in_channels, 25 | channels=channels, 26 | out_channels=in_channels, 27 | share_key_query=False, 28 | query_downsample=None, 29 | key_downsample=None, 30 | key_query_num_convs=1, 31 | key_query_norm=False, 32 | value_out_num_convs=1, 33 | value_out_norm=False, 34 | matmul_norm=False, 35 | with_out=False, 36 | conv_cfg=None, 37 | norm_cfg=None, 38 | act_cfg=None) 39 | 40 | self.gamma = Scale(0) 41 | 42 | def forward(self, x): 43 | """Forward function.""" 44 | out = super(PAM, self).forward(x, x) 45 | 46 | out = self.gamma(out) + x 47 | return out 48 | 49 | 50 | class CAM(nn.Module): 51 | """Channel Attention Module (CAM)""" 52 | 53 | def __init__(self): 54 | super(CAM, self).__init__() 55 | self.gamma = Scale(0) 56 | 57 | def forward(self, x): 58 | """Forward function.""" 59 | batch_size, channels, height, width = x.size() 60 | proj_query = x.view(batch_size, channels, -1) 61 | proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) 62 | energy = torch.bmm(proj_query, proj_key) 63 | energy_new = torch.max( 64 | energy, -1, keepdim=True)[0].expand_as(energy) - energy 65 | attention = F.softmax(energy_new, dim=-1) 66 | proj_value = x.view(batch_size, channels, -1) 67 | 68 | out = torch.bmm(attention, proj_value) 69 | out = out.view(batch_size, channels, height, width) 70 | 71 | out = self.gamma(out) + x 72 | return out 73 | 74 | 75 | @HEADS.register_module() 76 | class DAHead(BaseDecodeHead): 77 | """Dual Attention Network for Scene Segmentation. 78 | 79 | This head is the implementation of `DANet 80 | `_. 81 | 82 | Args: 83 | pam_channels (int): The channels of Position Attention Module(PAM). 84 | """ 85 | 86 | def __init__(self, pam_channels, **kwargs): 87 | super(DAHead, self).__init__(**kwargs) 88 | self.pam_channels = pam_channels 89 | self.pam_in_conv = ConvModule( 90 | self.in_channels, 91 | self.channels, 92 | 3, 93 | padding=1, 94 | conv_cfg=self.conv_cfg, 95 | norm_cfg=self.norm_cfg, 96 | act_cfg=self.act_cfg) 97 | self.pam = PAM(self.channels, pam_channels) 98 | self.pam_out_conv = ConvModule( 99 | self.channels, 100 | self.channels, 101 | 3, 102 | padding=1, 103 | conv_cfg=self.conv_cfg, 104 | norm_cfg=self.norm_cfg, 105 | act_cfg=self.act_cfg) 106 | self.pam_conv_seg = nn.Conv2d( 107 | self.channels, self.num_classes, kernel_size=1) 108 | 109 | self.cam_in_conv = ConvModule( 110 | self.in_channels, 111 | self.channels, 112 | 3, 113 | padding=1, 114 | conv_cfg=self.conv_cfg, 115 | norm_cfg=self.norm_cfg, 116 | act_cfg=self.act_cfg) 117 | self.cam = CAM() 118 | self.cam_out_conv = ConvModule( 119 | self.channels, 120 | self.channels, 121 | 3, 122 | padding=1, 123 | conv_cfg=self.conv_cfg, 124 | norm_cfg=self.norm_cfg, 125 | act_cfg=self.act_cfg) 126 | self.cam_conv_seg = nn.Conv2d( 127 | self.channels, self.num_classes, kernel_size=1) 128 | 129 | def pam_cls_seg(self, feat): 130 | """PAM feature classification.""" 131 | if self.dropout is not None: 132 | feat = self.dropout(feat) 133 | output = self.pam_conv_seg(feat) 134 | return output 135 | 136 | def cam_cls_seg(self, feat): 137 | """CAM feature classification.""" 138 | if self.dropout is not None: 139 | feat = self.dropout(feat) 140 | output = self.cam_conv_seg(feat) 141 | return output 142 | 143 | def forward(self, inputs): 144 | """Forward function.""" 145 | x = self._transform_inputs(inputs) 146 | pam_feat = self.pam_in_conv(x) 147 | pam_feat = self.pam(pam_feat) 148 | 
pam_feat = self.pam_out_conv(pam_feat) 149 | pam_out = self.pam_cls_seg(pam_feat) 150 | 151 | cam_feat = self.cam_in_conv(x) 152 | cam_feat = self.cam(cam_feat) 153 | cam_feat = self.cam_out_conv(cam_feat) 154 | cam_out = self.cam_cls_seg(cam_feat) 155 | 156 | feat_sum = pam_feat + cam_feat 157 | pam_cam_out = self.cls_seg(feat_sum) 158 | 159 | return pam_cam_out, pam_out, cam_out 160 | 161 | def forward_test(self, inputs, img_metas, test_cfg): 162 | """Forward function for testing, only ``pam_cam`` is used.""" 163 | return self.forward(inputs)[0] 164 | 165 | def losses(self, seg_logit, seg_label): 166 | """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" 167 | pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit 168 | loss = dict() 169 | loss.update( 170 | add_prefix( 171 | super(DAHead, self).losses(pam_cam_seg_logit, seg_label), 172 | 'pam_cam')) 173 | loss.update( 174 | add_prefix( 175 | super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam')) 176 | loss.update( 177 | add_prefix( 178 | super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam')) 179 | return loss 180 | --------------------------------------------------------------------------------
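A closing usage sketch for DAHead (illustrative, not part of the original files; assumes the standard BaseDecodeHead keyword arguments). The forward pass returns three logit maps (fused pam_cam, PAM-only and CAM-only); only the first is used at test time, as ``forward_test`` above shows, while ``losses`` supervises all three:

    >>> import torch
    >>> from mmseg.models.decode_heads import DAHead
    >>> head = DAHead(pam_channels=16, in_channels=64, channels=32,
    ...               num_classes=5)
    >>> pam_cam_out, pam_out, cam_out = head([torch.rand(1, 64, 40, 40)])
    >>> tuple(pam_cam_out.shape)
    (1, 5, 40, 40)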