├── tools
│   ├── __init__.py
│   ├── potsdam_preprocess.sh
│   ├── vaihingen_preprocess.sh
│   ├── throughput_count.py
│   ├── tsne.py
│   ├── latency_count.py
│   ├── cam.py
│   ├── mask_convert.py
│   └── dataset_patch_split.py
├── cap.jpg
├── rsseg
│   ├── losses
│   │   ├── __init__.py
│   │   ├── ce_loss.py
│   │   ├── dice_loss.py
│   │   ├── build_loss.py
│   │   ├── focal_loss.py
│   │   └── l1l2_loss.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── basemodules
│   │   │   ├── __init__.py
│   │   │   ├── Interpolate.py
│   │   │   └── dysample.py
│   │   ├── classifiers
│   │   │   ├── __init__.py
│   │   │   └── base_classifier.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── resnet.py
│   │   │   ├── tinyvim.py
│   │   │   └── repvit.py
│   │   ├── segheads
│   │   │   ├── __init__.py
│   │   │   ├── logcan_head.py
│   │   │   ├── doc_head.py
│   │   │   └── logcanplus_head.py
│   │   └── build_model.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── potsdam_dataset.py
│   │   ├── vaihingen_dataset.py
│   │   ├── build_dataset.py
│   │   └── loveda_dataset.py
│   └── optimizers
│       ├── __init__.py
│       └── build_optimizer.py
├── requirements.txt
├── utils
│   ├── util.py
│   ├── build.py
│   └── registry.py
├── configs
│   ├── _base_
│   │   ├── uavid_config.py
│   │   ├── potsdam_config.py
│   │   ├── vaihingen_config.py
│   │   └── loveda_config.py
│   ├── loveda
│   │   ├── sacanet.py
│   │   ├── docnet.py
│   │   ├── logcan.py
│   │   └── logcanplus.py
│   ├── potsdam
│   │   ├── sacanet.py
│   │   ├── docnet.py
│   │   ├── logcan.py
│   │   └── logcanplus.py
│   └── vaihingen
│       ├── sacanet.py
│       ├── docnet.py
│       ├── logcan.py
│       └── logcanplus.py
├── online_test.py
├── test.py
├── train.py
└── README.md /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cap.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwmaxwma/rssegmentation/HEAD/cap.jpg -------------------------------------------------------------------------------- /rsseg/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from rsseg.losses.build_loss import build_loss -------------------------------------------------------------------------------- /rsseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | #from rsseg.models.build_model import build_model --------------------------------------------------------------------------------
/rsseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from rsseg.datasets.build_dataset import build_dataloader -------------------------------------------------------------------------------- /rsseg/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from rsseg.optimizers.build_optimizer import build_optimizer -------------------------------------------------------------------------------- /rsseg/models/basemodules/__init__.py: -------------------------------------------------------------------------------- 1 | from rsseg.models.basemodules.Interpolate import Interpolate -------------------------------------------------------------------------------- /rsseg/models/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | from rsseg.models.classifiers.base_classifier import Base_Classifier --------------------------------------------------------------------------------
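The __init__.py re-exports above are the public entry points of each subpackage. A minimal sketch of how they compose in a training script, following the build_* functions shown later in this dump; the exact build_dataloader signature is an assumption (its definition in build_dataset.py is truncated at the end of this dump), and the config path is purely illustrative:

import pytorch_lightning  # listed in requirements.txt; train.py presumably builds a Lightning loop
from utils.config import Config
from rsseg.datasets import build_dataloader
from rsseg.losses import build_loss
from rsseg.optimizers import build_optimizer
from rsseg.models.build_model import build_model

cfg = Config.fromfile('configs/vaihingen/logcan.py')   # illustrative config choice
model = build_model(cfg.model_config)                  # backbone -> seghead -> classifier -> upsample
criterion = build_loss(cfg.loss_config)                # myLoss: one weighted CELoss per prediction
optimizer, scheduler = build_optimizer(cfg.optimizer_config, model)
train_loader = build_dataloader(cfg.dataset_config, mode='train')  # signature assumed, not shown above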
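The model modules listed next share one forward contract, wired together by rsseg/models/build_model.py further down: the seghead returns a list (main feature map first, auxiliary predictions after), Base_Classifier applies a 1x1 conv to the first entry only, and Interpolate rescales each entry by its own factor. A runnable sketch with stand-in tensors; the channel count (128), class count (6), and scale factors [4, 32] are taken from the logcan configs below, while the tensors themselves are made up:

import torch
from rsseg.models.classifiers.base_classifier import Base_Classifier
from rsseg.models.basemodules.Interpolate import Interpolate

# Stand-ins for a seghead's output: main features at 1/4 resolution,
# auxiliary class logits at 1/32 resolution (values are random).
main_feat = torch.randn(2, 128, 128, 128)
aux_pred = torch.randn(2, 6, 16, 16)

classifier = Base_Classifier(transform_channel=128, num_class=6)
upsample = Interpolate(scale=[4, 32], mode='bilinear')

out = upsample(classifier([main_feat, aux_pred]))
print([o.shape for o in out])  # both entries come out at (2, 6, 512, 512)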
/rsseg/models/segheads/__init__.py: -------------------------------------------------------------------------------- 1 | from rsseg.models.segheads.logcan_head import LoGCAN_Head 2 | from rsseg.models.segheads.doc_head import DOC_Head 3 | from rsseg.models.segheads.logcanplus_head import LoGCANPlus_Head -------------------------------------------------------------------------------- /rsseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from rsseg.models.backbones.resnet import get_resnet50_OS32, get_resnet50_OS8 2 | from rsseg.models.backbones.hrnet import get_hrnetv2_w32 3 | # from rsseg.models.backbones.tinyvim import TinyViM_S,TinyViM_B,TinyViM_L 4 | from rsseg.models.backbones.repvit import repvit_m2_3 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchmetrics==0.11.4 2 | pytorch-lightning==2.0.6 3 | scikit-image==0.21.0 4 | 5 | catalyst==20.9 6 | albumentations==1.3.1 7 | ttach==0.0.3 8 | einops==0.6.1 9 | timm==0.6.7 10 | addict==2.4.0 11 | soundfile==0.12.1 12 | prettytable==3.8.0 13 | mmcv-full==1.7.1 14 | mmsegmentation==0.30.0 15 | grad-cam==1.5.4 -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): 5 | if not osp.isfile(filename): 6 | raise FileNotFoundError(msg_tmpl.format(filename)) 7 | 8 | 9 | def mkdir_or_exist(dir_name, mode=0o777): 10 | if dir_name == '': 11 | return 12 | dir_name = osp.expanduser(dir_name) 13 | os.makedirs(dir_name, mode=mode, exist_ok=True) -------------------------------------------------------------------------------- /rsseg/models/classifiers/base_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Base_Classifier(nn.Module): 6 | def __init__(self, transform_channel, num_class): 7 | super(Base_Classifier, self).__init__() 8 | self.classifier = nn.Conv2d(transform_channel, num_class, kernel_size=1, stride=1) 9 | 10 | def forward(self, out): 11 | pred = self.classifier(out[0]) 12 | return [pred] + out[1:] -------------------------------------------------------------------------------- /rsseg/models/basemodules/Interpolate.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class
Interpolate(nn.Module): 5 | def __init__(self, scale=8, mode='bilinear'): 6 | super().__init__() 7 | # accept a single scale factor or one factor per list entry 8 | self.scale_list = scale if isinstance(scale, (list, tuple)) else [scale] 9 | self.mode = mode 10 | 11 | def forward(self, x_list): 12 | for i in range(len(self.scale_list)): 13 | x_list[i] = F.interpolate(x_list[i], scale_factor=self.scale_list[i], mode=self.mode) 14 | return x_list -------------------------------------------------------------------------------- /rsseg/losses/ce_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class CELoss(nn.Module): 5 | def __init__(self, ignore_index=255, reduction='mean'): 6 | super(CELoss, self).__init__() 7 | 8 | self.ignore_index = ignore_index 9 | self.criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction=reduction) 10 | if reduction == 'none': 11 | print("reduction is disabled.") 12 | 13 | def forward(self, pred, target): 14 | loss = self.criterion(pred, target) 15 | return loss 16 | -------------------------------------------------------------------------------- /utils/build.py: -------------------------------------------------------------------------------- 1 | from utils.config import Config, ConfigDict 2 | from rsseg.models.backbones import * 3 | from rsseg.models.segheads import * 4 | from rsseg.models.classifiers import * 5 | from rsseg.models.basemodules import * 6 | from rsseg.losses import * 7 | 8 | def build_from_cfg(cfg): 9 | if not isinstance(cfg, (dict, ConfigDict, Config)): 10 | raise TypeError( 11 | f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}') 12 | if 'type' not in cfg: 13 | raise KeyError( 14 | '`cfg` must contain the key "type", ' 15 | f'but got {cfg}') 16 | obj_type = cfg.pop('type') 17 | obj_cls = eval(obj_type) # resolve the class named in the config from the namespaces imported above 18 | obj = obj_cls(**cfg) 19 | return obj -------------------------------------------------------------------------------- /rsseg/losses/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class Dice_Loss(nn.Module): 6 | def __init__(self, exp=2, smooth=1, ignore_index=None, reduction='mean'): 7 | super(Dice_Loss, self).__init__() 8 | self.exp = exp 9 | self.smooth = smooth 10 | self.ignore_index = ignore_index 11 | self.reduction = reduction 12 | 13 | # pred (B K H W) target (B H W) 14 | def forward(self, pred, target): 15 | pred = F.softmax(pred, dim=1) 16 | num_classes = pred.shape[1] 17 | valid = torch.ones_like(target, dtype=torch.bool) if self.ignore_index is None else (target != self.ignore_index) 18 | safe_target = torch.where(valid, target, torch.zeros_like(target)) 19 | one_hot_target = F.one_hot(safe_target, num_classes).permute(0, 3, 1, 2).float() # (B K H W) 20 | mask = valid.unsqueeze(1).float() 21 | pred = (pred * mask).reshape(pred.shape[0], -1) # (B *) 22 | one_hot_target = (one_hot_target * mask).reshape(one_hot_target.shape[0], -1) # (B *) 23 | 24 | num = 2 * torch.sum(torch.mul(pred, one_hot_target), dim=1) + self.smooth 25 | den = torch.sum(pred.pow(self.exp) + one_hot_target.pow(self.exp), dim=1) + self.smooth 26 | loss = 1 - num / den 27 | if self.reduction == 'mean': 28 | loss = loss.mean() 29 | elif self.reduction == 'sum': 30 | loss = loss.sum() 31 | return loss -------------------------------------------------------------------------------- /tools/potsdam_preprocess.sh: -------------------------------------------------------------------------------- 1 | python tools/dataset_patch_split.py \ 2 | --dataset-type "potsdam" \ 3 | --img-dir "data/potsdam/2_Ortho_RGB" \ 4 | --mask-dir "data/potsdam/5_Labels_all" \ 5 | --output-img-dir "data/potsdam/train/images_1024" \ 6 | --output-mask-dir "data/potsdam/train/masks_1024" \ 7 | --split-size 1024 \ 8 | --stride 512 \ 9 | --mode "train" 10 | 11 | python tools/dataset_patch_split.py \ 12 | --dataset-type "potsdam" \ 13 | --img-dir "data/potsdam/2_Ortho_RGB" \ 14 | --mask-dir "data/potsdam/5_Labels_all_noBoundary" \ 15 | --output-img-dir "data/potsdam/test/images_1024" \ 16 | --output-mask-dir
"data/potsdam/test/masks_1024" \ 17 | --split-size 1024 \ 18 | --stride 1024 \ 19 | --mode "test" 20 | 21 | python tools/dataset_patch_split.py \ 22 | --dataset-type "potsdam" \ 23 | --img-dir "data/potsdam/2_Ortho_RGB" \ 24 | --mask-dir "data/potsdam/5_Labels_all" \ 25 | --output-img-dir "data/potsdam/test/images_1024" \ 26 | --output-mask-dir "data/potsdam/test/masks_1024_RGB" \ 27 | --split-size 1024 \ 28 | --stride 1024 \ 29 | --mode "test" -------------------------------------------------------------------------------- /rsseg/losses/build_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from rsseg.losses.ce_loss import CELoss 4 | class myLoss(nn.Module): 5 | def __init__(self, loss_name=['CELoss'], loss_weight=[1.0], ignore_index=255, reduction='mean', **kwargs): 6 | super(myLoss, self).__init__() 7 | self.loss_weight = loss_weight 8 | self.ignore_index = ignore_index 9 | self.loss_name = loss_name 10 | self.loss = list() 11 | for _loss in loss_name: 12 | self.loss.append(eval(_loss)(ignore_index,**kwargs)) 13 | 14 | def forward(self, preds, target): 15 | #loss = self.loss[0](preds[0], target) * self.loss_weight[0] 16 | all_loss = dict() 17 | all_loss['total_loss'] = 0 18 | for i in range(0, len(self.loss)): 19 | loss = self.loss[i](preds[i], target) * self.loss_weight[i] 20 | if self.loss_name[i] in all_loss: 21 | all_loss[self.loss_name[i]] += loss 22 | else: 23 | all_loss[self.loss_name[i]] = loss 24 | all_loss['total_loss'] += loss 25 | return all_loss 26 | 27 | def build_loss(cfg): 28 | loss_type = cfg.pop('type') 29 | obj_cls = eval(loss_type) 30 | obj = obj_cls(**cfg) 31 | return obj -------------------------------------------------------------------------------- /tools/vaihingen_preprocess.sh: -------------------------------------------------------------------------------- 1 | python tools/dataset_patch_split.py \ 2 | --dataset-type "vaihingen" \ 3 | --img-dir "data/vaihingen/ISPRS_semantic_labeling_Vaihingen/top" \ 4 | --mask-dir "data/vaihingen/ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE" \ 5 | --output-img-dir "data/vaihingen/train/images_1024" \ 6 | --output-mask-dir "data/vaihingen/train/masks_1024" \ 7 | --split-size 1024 \ 8 | --stride 512 \ 9 | --mode "train" 10 | 11 | python tools/dataset_patch_split.py \ 12 | --dataset-type "vaihingen" \ 13 | --img-dir "data/vaihingen/ISPRS_semantic_labeling_Vaihingen/top" \ 14 | --mask-dir "data/vaihingen/ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE" \ 15 | --output-img-dir "data/vaihingen/test/images_1024" \ 16 | --output-mask-dir "data/vaihingen/test/masks_1024" \ 17 | --split-size 1024 \ 18 | --stride 1024 \ 19 | --mode "test" 20 | 21 | python tools/dataset_patch_split.py \ 22 | --dataset-type "vaihingen" \ 23 | --img-dir "data/vaihingen/ISPRS_semantic_labeling_Vaihingen/top" \ 24 | --mask-dir "data/vaihingen/ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE" \ 25 | --output-img-dir "data/vaihingen/test/images_1024" \ 26 | --output-mask-dir "data/vaihingen/test/masks_1024_RGB" \ 27 | --split-size 1024 \ 28 | --stride 1024 \ 29 | --mode "test" -------------------------------------------------------------------------------- /configs/_base_/uavid_config.py: -------------------------------------------------------------------------------- 1 | dataset = 'uavid' 2 | dataset_config = dict( 3 | type='Uavid', 4 | data_root='data/uavid/tmp', 5 | train_mode=dict( 6 | transform=dict( 7 | RandomSizeAndCrop={"size": 512, 
"crop_nopad": False}, 8 | RandomHorizontallyFlip=None, 9 | RandomVerticalFlip=None, 10 | RandomRotate={"degree": 0.2}, 11 | ), 12 | loader=dict( 13 | batch_size=4, 14 | num_workers=4, 15 | pin_memory=True, 16 | shuffle=True, 17 | drop_last=True 18 | ), 19 | ), 20 | 21 | val_mode=dict( 22 | transform=dict( 23 | Resize={'size': 512} 24 | ), 25 | loader=dict( 26 | batch_size=4, 27 | num_workers=4, 28 | pin_memory=True, 29 | shuffle=False, 30 | drop_last=False 31 | ) 32 | ), 33 | ) 34 | 35 | metric_cfg1 = dict( 36 | task = 'multiclass', 37 | average='micro', 38 | num_classes = 9, 39 | ignore_index = 8 40 | ) 41 | 42 | metric_cfg2 = dict( 43 | task = 'multiclass', 44 | average='none', 45 | num_classes = 9, 46 | ignore_index = 8 47 | ) 48 | -------------------------------------------------------------------------------- /configs/_base_/potsdam_config.py: -------------------------------------------------------------------------------- 1 | dataset = 'potsdam' 2 | 3 | dataset_config = dict( 4 | type='Potsdam', 5 | data_root='data/potsdam', 6 | train_mode=dict( 7 | transform=dict( 8 | RandomSizeAndCrop = {"size": 512, "crop_nopad": False}, 9 | RandomHorizontallyFlip = None, 10 | RandomVerticalFlip = None, 11 | RandomRotate = {"degree": 0.2}, 12 | RandomGaussianBlur = None 13 | ), 14 | loader=dict( 15 | batch_size=4, 16 | num_workers=4, 17 | pin_memory=True, 18 | shuffle=True, 19 | drop_last=True 20 | ), 21 | ), 22 | 23 | val_mode=dict( 24 | transform=dict(), 25 | loader=dict( 26 | batch_size=4, 27 | num_workers=4, 28 | pin_memory=True, 29 | shuffle=False, 30 | drop_last=False 31 | ) 32 | ), 33 | ) 34 | 35 | metric_cfg1 = dict( 36 | task = 'multiclass', 37 | average='micro', 38 | num_classes = 7, 39 | ignore_index = 6 40 | ) 41 | 42 | metric_cfg2 = dict( 43 | task = 'multiclass', 44 | average='none', 45 | num_classes = 7, 46 | ignore_index = 6 47 | ) 48 | 49 | eval_label_id_left = 0 50 | eval_label_id_right = 5 51 | class_name = ['ImSurf', 'Building', 'LowVeg', 'Tree', 'Car', 'Clutter', 'Boundary'] -------------------------------------------------------------------------------- /rsseg/models/build_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.build import build_from_cfg 4 | class myModel(nn.Module): 5 | def __init__(self, cfg): 6 | super(myModel, self).__init__() 7 | self.backbone = build_from_cfg(cfg.backbone) 8 | self.seghead = build_from_cfg(cfg.seghead) 9 | self.classifier = build_from_cfg(cfg.classifier) 10 | self.upsample = build_from_cfg(cfg.upsample) 11 | 12 | 13 | def forward(self, x): 14 | backbone_outputs = self.backbone(x) 15 | x_list = self.seghead(backbone_outputs) # 考虑到辅助损失 16 | x_list = self.classifier(x_list) 17 | x_list = self.upsample(x_list) 18 | 19 | return x_list 20 | 21 | """ 22 | 对于不满足该范式的模型可在backbone部分进行定义, 并在此处导入 23 | """ 24 | 25 | # model_config 26 | def build_model(cfg): 27 | c = myModel(cfg) 28 | return c 29 | 30 | 31 | if __name__ == "__main__": 32 | x = torch.randn(2, 3, 512, 512) 33 | target = torch.randint(low=0,high=6,size=[2, 512, 512]) 34 | file_path = "/home/xwma/rssegmentation/configs/docnet.py" 35 | import sys 36 | sys.path.append('/home/xwma/rssegmentation') 37 | sys.path.append('/home/xwma/rssegmentation/rsseg') 38 | from utils.config import Config 39 | from rsseg.losses import build_loss 40 | 41 | 42 | cfg = Config.fromfile(file_path) 43 | net = build_model(cfg.model_config) 44 | res = net(x) 45 | loss = build_loss(cfg.loss_config) 46 | 47 | compute = 
loss(res, target) 48 | print(compute) -------------------------------------------------------------------------------- /configs/_base_/vaihingen_config.py: -------------------------------------------------------------------------------- 1 | dataset = 'vaihingen' 2 | dataset_config = dict( 3 | type = 'Vaihingen', 4 | data_root = 'data/vaihingen', 5 | train_mode = dict( 6 | transform = dict( 7 | # RandomSizeAndCrop = {"size": 512, "crop_nopad": False}, 8 | RandomScale = {'scale_list':[0.5, 0.75, 1.0, 1.25, 1.5], 'mode':'value'}, 9 | SmartCropV1 = {'crop_size':512, 'max_ratio':0.75, 10 | 'ignore_index':6, 'nopad':False}, 11 | RandomHorizontallyFlip = None, 12 | RandomVerticalFlip = None, 13 | RandomRotate = {"degree": 0.2}, 14 | RandomGaussianBlur = None 15 | ), 16 | loader = dict( 17 | batch_size = 4, 18 | num_workers = 4, 19 | pin_memory=True, 20 | shuffle = True, 21 | drop_last = True 22 | ), 23 | ), 24 | 25 | val_mode = dict( 26 | transform = dict(), 27 | loader = dict( 28 | batch_size = 4, 29 | num_workers = 4, 30 | pin_memory=True, 31 | shuffle = False, 32 | drop_last = False 33 | ) 34 | ), 35 | ) 36 | metric_cfg1 = dict( 37 | task = 'multiclass', 38 | average='micro', 39 | num_classes = 7, 40 | ignore_index = 6 41 | ) 42 | 43 | metric_cfg2 = dict( 44 | task = 'multiclass', 45 | average='none', 46 | num_classes = 7, 47 | ignore_index = 6 48 | ) 49 | class_name = ['ImSurf', 'Building', 'LowVeg', 'Tree', 'Car', 'Clutter', 'Boundary'] 50 | eval_label_id_left = 0 51 | eval_label_id_right = 5 52 | -------------------------------------------------------------------------------- /configs/loveda/sacanet.py: -------------------------------------------------------------------------------- 1 | 2 | ######################## base_config ######################### 3 | gpus = [1] 4 | save_top_k = 1 5 | save_last = True 6 | check_val_every_n_epoch = 1 7 | logging_interval = 'epoch' 8 | resume_ckpt_path = None 9 | pretrained_ckpt_path = None 10 | monitor = 'val_miou' 11 | 12 | test_ckpt_path = None 13 | 14 | ######################## dataset_config ###################### 15 | exp_name = "work_dirs/sacanet_loveda" 16 | _base_ = '../_base_/loveda_config.py' 17 | epoch = 50 18 | num_class = 7 19 | ignore_index = 7 20 | 21 | ######################### model_config ######################### 22 | model_config = dict( 23 | backbone = dict( 24 | type = 'get_hrnetv2_w32', 25 | ), 26 | seghead = dict( 27 | type = 'Sacanet', 28 | num_class = num_class, 29 | patch_size=(16, 16), 30 | ), 31 | classifier = dict( 32 | type = 'Base_Classifier', 33 | transform_channel = 128, 34 | num_class = num_class, 35 | ), 36 | upsample=dict( 37 | type='Interpolate', 38 | mode='bilinear', 39 | scale=[4,4], 40 | ) 41 | ) 42 | loss_config = dict( 43 | type = 'myLoss', 44 | loss_name = ['CELoss', 'CELoss'], 45 | loss_weight = [1, 0.8], 46 | ignore_index = ignore_index 47 | ) 48 | 49 | ######################## optimizer_config ###################### 50 | optimizer_config = dict( 51 | optimizer = dict( 52 | type = 'SGD', 53 | backbone_lr = 0.001, 54 | backbone_weight_decay = 1e-4, 55 | lr = 0.01, 56 | weight_decay = 1e-4, 57 | momentum = 0.9, 58 | lr_mode = "single" 59 | ), 60 | scheduler = dict( 61 | type = 'Poly', 62 | poly_exp = 0.9, 63 | max_epoch = epoch 64 | ) 65 | ) 66 | -------------------------------------------------------------------------------- /configs/potsdam/sacanet.py: -------------------------------------------------------------------------------- 1 | 2 | ######################## base_config ######################### 3 |
gpus = [1] 4 | save_top_k = 1 5 | save_last = True 6 | check_val_every_n_epoch = 1 7 | logging_interval = 'epoch' 8 | resume_ckpt_path = None 9 | pretrained_ckpt_path = None 10 | monitor = 'val_miou' 11 | 12 | test_ckpt_path = None 13 | 14 | ######################## dataset_config ###################### 15 | exp_name = "work_dirs/sacanet_potsdam" 16 | _base_ = '../_base_/potsdam_config.py' 17 | epoch = 80 18 | num_class = 6 19 | ignore_index = 6 20 | 21 | ######################### model_config ######################### 22 | model_config = dict( 23 | backbone = dict( 24 | type = 'get_hrnetv2_w32', 25 | ), 26 | seghead = dict( 27 | type = 'Sacanet', 28 | num_class = num_class, 29 | patch_size=(16, 16), 30 | ), 31 | classifier = dict( 32 | type = 'Base_Classifier', 33 | transform_channel = 128, 34 | num_class = num_class, 35 | ), 36 | upsample=dict( 37 | type='Interpolate', 38 | mode='bilinear', 39 | scale=[4,4], 40 | ) 41 | ) 42 | loss_config = dict( 43 | type = 'myLoss', 44 | loss_name = ['CELoss', 'CELoss'], 45 | loss_weight = [1, 0.8], 46 | ignore_index = ignore_index 47 | ) 48 | 49 | ######################## optimizer_config ###################### 50 | optimizer_config = dict( 51 | optimizer = dict( 52 | type = 'SGD', 53 | backbone_lr = 0.001, 54 | backbone_weight_decay = 1e-4, 55 | lr = 0.01, 56 | weight_decay = 1e-4, 57 | momentum = 0.9, 58 | lr_mode = "single" 59 | ), 60 | scheduler = dict( 61 | type = 'Poly', 62 | poly_exp = 0.9, 63 | max_epoch = epoch 64 | ) 65 | ) 66 | -------------------------------------------------------------------------------- /configs/vaihingen/sacanet.py: -------------------------------------------------------------------------------- 1 | 2 | ######################## base_config ######################### 3 | gpus = [1] 4 | save_top_k = 1 5 | save_last = True 6 | check_val_every_n_epoch = 1 7 | logging_interval = 'epoch' 8 | resume_ckpt_path = None 9 | pretrained_ckpt_path = None 10 | monitor = 'val_miou' 11 | 12 | test_ckpt_path = None 13 | 14 | ######################## dataset_config ###################### 15 | exp_name = "work_dirs/sacanet_vaihingen" 16 | _base_ = '../_base_/vaihingen_config.py' 17 | epoch = 150 18 | num_class = 6 19 | ignore_index = 6 20 | 21 | ######################### model_config ######################### 22 | model_config = dict( 23 | backbone = dict( 24 | type = 'get_hrnetv2_w32', 25 | ), 26 | seghead = dict( 27 | type = 'Sacanet', 28 | num_class = num_class, 29 | patch_size=(16, 16), 30 | ), 31 | classifier = dict( 32 | type = 'Base_Classifier', 33 | transform_channel = 128, 34 | num_class = num_class, 35 | ), 36 | upsample=dict( 37 | type='Interpolate', 38 | mode='bilinear', 39 | scale=[4,4], 40 | ) 41 | ) 42 | loss_config = dict( 43 | type = 'myLoss', 44 | loss_name = ['CELoss', 'CELoss'], 45 | loss_weight = [1, 0.8], 46 | ignore_index = ignore_index 47 | ) 48 | 49 | ######################## optimizer_config ###################### 50 | optimizer_config = dict( 51 | optimizer = dict( 52 | type = 'SGD', 53 | backbone_lr = 0.001, 54 | backbone_weight_decay = 1e-4, 55 | lr = 0.01, 56 | weight_decay = 1e-4, 57 | momentum = 0.9, 58 | lr_mode = "single" 59 | ), 60 | scheduler = dict( 61 | type = 'Poly', 62 | poly_exp = 0.9, 63 | max_epoch = epoch 64 | ) 65 | ) 66 | -------------------------------------------------------------------------------- /configs/loveda/docnet.py: -------------------------------------------------------------------------------- 1 | ######################## base_config ######################### 2 | gpus = [1] 3 | 
save_top_k = 1 4 | save_last = True 5 | check_val_every_n_epoch = 1 6 | logging_interval = 'epoch' 7 | resume_ckpt_path = None 8 | pretrained_ckpt_path = None 9 | monitor = 'val_miou' 10 | 11 | test_ckpt_path = None 12 | 13 | ######################## dataset_config ###################### 14 | exp_name = "work_dirs/docnet_loveda" 15 | _base_ = '../_base_/loveda_config.py' 16 | epoch = 50 17 | num_class = 7 18 | ignore_index = 7 19 | 20 | ######################### model_config ######################### 21 | model_config = dict( 22 | backbone = dict( 23 | type = 'get_hrnetv2_w32' 24 | ), 25 | seghead = dict( 26 | type = 'DOC_Head', 27 | num_class = num_class 28 | ), 29 | classifier = dict( 30 | type = 'Base_Classifier', 31 | transform_channel = 512, 32 | num_class = num_class, 33 | ), 34 | upsample=dict( 35 | type='Interpolate', 36 | mode='bilinear', 37 | scale=[4, 4], 38 | ) 39 | ) 40 | loss_config = dict( 41 | type = 'myLoss', 42 | loss_name = ['CELoss', 'CELoss'], 43 | loss_weight = [1, 0.4], 44 | ignore_index = ignore_index 45 | ) 46 | 47 | 48 | ######################## optimizer_config ###################### 49 | optimizer_config = dict( 50 | optimizer = dict( 51 | type = 'AdamW', 52 | backbone_lr = 0.0001, 53 | backbone_weight_decay = 0.05, 54 | lr = 0.0001, 55 | weight_decay = 1e-4, 56 | momentum = 0.9, 57 | lr_mode = "single" 58 | ), 59 | scheduler = dict( 60 | type = 'Poly', 61 | poly_exp = 0.9, 62 | max_epoch = epoch 63 | ) 64 | ) 65 | -------------------------------------------------------------------------------- /configs/potsdam/docnet.py: -------------------------------------------------------------------------------- 1 | ######################## base_config ######################### 2 | gpus = [1] 3 | save_top_k = 1 4 | save_last = True 5 | check_val_every_n_epoch = 1 6 | logging_interval = 'epoch' 7 | resume_ckpt_path = None 8 | pretrained_ckpt_path = None 9 | monitor = 'val_miou' 10 | 11 | test_ckpt_path = None 12 | 13 | ######################## dataset_config ###################### 14 | exp_name = "work_dirs/docnet_potsdam" 15 | _base_ = '../_base_/potsdam_config.py' 16 | epoch = 80 17 | num_class = 6 18 | ignore_index = 6 19 | 20 | ######################### model_config ######################### 21 | model_config = dict( 22 | backbone = dict( 23 | type = 'get_hrnetv2_w32' 24 | ), 25 | seghead = dict( 26 | type = 'DOC_Head', 27 | num_class = num_class 28 | ), 29 | classifier = dict( 30 | type = 'Base_Classifier', 31 | transform_channel = 512, 32 | num_class = num_class, 33 | ), 34 | upsample=dict( 35 | type='Interpolate', 36 | mode='bilinear', 37 | scale=[4, 4], 38 | ) 39 | ) 40 | loss_config = dict( 41 | type = 'myLoss', 42 | loss_name = ['CELoss', 'CELoss'], 43 | loss_weight = [1, 0.4], 44 | ignore_index = ignore_index 45 | ) 46 | 47 | 48 | ######################## optimizer_config ###################### 49 | optimizer_config = dict( 50 | optimizer = dict( 51 | type = 'AdamW', 52 | backbone_lr = 0.0001, 53 | backbone_weight_decay = 0.05, 54 | lr = 0.0001, 55 | weight_decay = 1e-4, 56 | momentum = 0.9, 57 | lr_mode = "single" 58 | ), 59 | scheduler = dict( 60 | type = 'Poly', 61 | poly_exp = 0.9, 62 | max_epoch = epoch 63 | ) 64 | ) 65 | -------------------------------------------------------------------------------- /configs/vaihingen/docnet.py: -------------------------------------------------------------------------------- 1 | ######################## base_config ######################### 2 | gpus = [1] 3 | save_top_k = 1 4 | save_last = True 5 | check_val_every_n_epoch 
= 1 6 | logging_interval = 'epoch' 7 | resume_ckpt_path = None 8 | pretrained_ckpt_path = None 9 | monitor = 'val_miou' 10 | 11 | test_ckpt_path = None 12 | 13 | ######################## dataset_config ###################### 14 | exp_name = "work_dirs/docnet_vaihingen" 15 | _base_ = '../_base_/vaihingen_config.py' 16 | epoch = 150 17 | num_class = 6 18 | ignore_index = 6 19 | 20 | ######################### model_config ######################### 21 | model_config = dict( 22 | backbone = dict( 23 | type = 'get_hrnetv2_w32' 24 | ), 25 | seghead = dict( 26 | type = 'DOC_Head', 27 | num_class = num_class 28 | ), 29 | classifier = dict( 30 | type = 'Base_Classifier', 31 | transform_channel = 512, 32 | num_class = num_class, 33 | ), 34 | upsample=dict( 35 | type='Interpolate', 36 | mode='bilinear', 37 | scale=[4, 4], 38 | ) 39 | ) 40 | loss_config = dict( 41 | type = 'myLoss', 42 | loss_name = ['CELoss', 'CELoss'], 43 | loss_weight = [1, 0.4], 44 | ignore_index = ignore_index 45 | ) 46 | 47 | 48 | ######################## optimizer_config ###################### 49 | optimizer_config = dict( 50 | optimizer = dict( 51 | type = 'AdamW', 52 | backbone_lr = 0.0001, 53 | backbone_weight_decay = 0.05, 54 | lr = 0.0001, 55 | weight_decay = 1e-4, 56 | momentum = 0.9, 57 | lr_mode = "single" 58 | ), 59 | scheduler = dict( 60 | type = 'Poly', 61 | poly_exp = 0.9, 62 | max_epoch = epoch 63 | ) 64 | ) 65 | -------------------------------------------------------------------------------- /configs/_base_/loveda_config.py: -------------------------------------------------------------------------------- 1 | dataset = 'loveda' 2 | 3 | dataset_config = dict( 4 | type='LoveDA', 5 | data_root='data/2021LoveDA', 6 | train_mode=dict( 7 | transform=dict( 8 | RandomSizeAndCrop = {"size": 512, "crop_nopad": False}, 9 | RandomHorizontallyFlip = None, 10 | RandomVerticalFlip = None, 11 | RandomRotate = {"degree": 0.2}, 12 | RandomGaussianBlur = None 13 | ), 14 | loader=dict( 15 | batch_size=4, 16 | num_workers=4, 17 | pin_memory=True, 18 | shuffle=True, 19 | drop_last=True 20 | ), 21 | ), 22 | val_mode=dict( 23 | transform=dict( 24 | #Resize={'size': 512} 25 | ), 26 | loader=dict( 27 | batch_size=4, 28 | num_workers=4, 29 | pin_memory=True, 30 | shuffle=False, 31 | drop_last=False 32 | ) 33 | ), 34 | test_mode=dict( 35 | transform=dict( 36 | # Resize={'size': 512} 37 | ), 38 | loader=dict( 39 | batch_size=4, 40 | num_workers=4, 41 | pin_memory=True, 42 | shuffle=False, 43 | drop_last=False 44 | ) 45 | ) 46 | ) 47 | 48 | metric_cfg1 = dict( 49 | task = 'multiclass', 50 | average='micro', 51 | num_classes = 8, 52 | ignore_index = 7 53 | ) 54 | 55 | metric_cfg2 = dict( 56 | task = 'multiclass', 57 | average='none', 58 | num_classes = 8, 59 | ignore_index = 7 60 | ) 61 | 62 | eval_label_id_left = 0 63 | eval_label_id_right = 6 64 | 65 | class_name = ['building', 'road', 'water', 'barren', 'forest', 'agricultural', 'background'] -------------------------------------------------------------------------------- /rsseg/losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class sigmoid_focal_loss(nn.Module): 6 | def __init__(self, alpha=-1, gamma=0, reduction='mean', ignore_index=None): 7 | super(sigmoid_focal_loss, self).__init__() 8 | self.alpha = alpha 9 | self.gamma = gamma 10 | self.reduction = reduction 11 | self.ignore_index = ignore_index 12 | 13 | # pred (B K H W) target (B H W) 14 | 
def forward(self, pred, target): 15 | pred = pred.permute(0, 2, 3, 1).float() # (B H W K) 16 | if self.ignore_index is not None: 17 | mask = (target != self.ignore_index) 18 | pred = pred[mask] # (n, k) 19 | target = target[mask] # (n) 20 | 21 | num_classes = pred.size(-1) 22 | one_hot_target = F.one_hot(target, num_classes) # (n k) 23 | one_hot_target = one_hot_target.float() 24 | 25 | p = torch.sigmoid(pred) 26 | ce_loss = F.binary_cross_entropy_with_logits(pred, one_hot_target, reduction="none") 27 | 28 | 29 | p_t = p * one_hot_target + (1 - p) * (1 - one_hot_target) 30 | loss = ce_loss * ((1 - p_t) ** self.gamma) 31 | 32 | if self.alpha >= 0: 33 | alpha_t = self.alpha * one_hot_target + (1 - self.alpha) * (1 - one_hot_target) 34 | loss = alpha_t * loss 35 | 36 | if self.reduction == "mean": 37 | loss = loss.mean() 38 | elif self.reduction == "sum": 39 | loss = loss.sum() 40 | 41 | return loss 42 | 43 | if __name__ == "__main__": 44 | lossmodel = sigmoid_focal_loss(ignore_index=6) 45 | pred = torch.randn(4, 6, 64, 64) 46 | target = torch.randint(low=0, high=7, size=(4, 64, 64)) 47 | loss = lossmodel(pred, target) 48 | print(loss) 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /configs/loveda/logcan.py: -------------------------------------------------------------------------------- 1 | ######################## base_config ######################### 2 | gpus = [1] 3 | save_top_k = 1 4 | save_last = True 5 | check_val_every_n_epoch = 1 6 | logging_interval = 'epoch' 7 | resume_ckpt_path = None 8 | pretrained_ckpt_path = None 9 | monitor = 'val_miou' 10 | 11 | test_ckpt_path = None 12 | 13 | ######################## dataset_config ###################### 14 | exp_name = "work_dirs/logcan_loveda" 15 | _base_ = '../_base_/loveda_config.py' 16 | epoch = 50 17 | num_class = 7 18 | ignore_index = 7 19 | 20 | ######################### model_config ######################### 21 | model_config = dict( 22 | transform_channel = 128, 23 | num_class = num_class, 24 | backbone = dict( 25 | type = 'get_resnet50_OS32', 26 | pretrained = True 27 | ), 28 | seghead = dict( 29 | type = 'LoGCAN_Head', 30 | in_channel = [256, 512, 1024, 2048], 31 | transform_channel = 128, 32 | num_class = num_class, 33 | ), 34 | classifier = dict( 35 | type = 'Base_Classifier', 36 | transform_channel = 128, 37 | num_class = num_class, 38 | ), 39 | upsample=dict( 40 | type='Interpolate', 41 | mode='bilinear', 42 | scale=[4, 32], 43 | ) 44 | ) 45 | loss_config = dict( 46 | type = 'myLoss', 47 | loss_name = ['CELoss', 'CELoss'], 48 | loss_weight = [1, 0.8], 49 | ignore_index = ignore_index 50 | ) 51 | 52 | ######################## optimizer_config ###################### 53 | optimizer_config = dict( 54 | optimizer = dict( 55 | type = 'SGD', 56 | backbone_lr = 0.001, 57 | backbone_weight_decay = 1e-4, 58 | lr = 0.01, 59 | weight_decay = 1e-4, 60 | momentum = 0.9, 61 | lr_mode = "single" 62 | ), 63 | scheduler = dict( 64 | type = 'Poly', 65 | poly_exp = 0.9, 66 | max_epoch = epoch 67 | ) 68 | ) 69 | -------------------------------------------------------------------------------- /configs/potsdam/logcan.py: -------------------------------------------------------------------------------- 1 | ######################## base_config ######################### 2 | gpus = [1] 3 | save_top_k = 1 4 | save_last = True 5 | check_val_every_n_epoch = 1 6 | logging_interval = 'epoch' 7 | resume_ckpt_path = None 8 | pretrained_ckpt_path = None 9 | monitor = 'val_miou' 10 | 11 | test_ckpt_path = None 12 |
13 | ######################## dataset_config ###################### 14 | exp_name = "work_dirs/logcan_potsdam" 15 | _base_ = '../_base_/potsdam_config.py' 16 | epoch = 80 17 | num_class = 6 18 | ignore_index = 6 19 | 20 | ######################### model_config ######################### 21 | model_config = dict( 22 | transform_channel = 128, 23 | num_class = num_class, 24 | backbone = dict( 25 | type = 'get_resnet50_OS32', 26 | pretrained = True 27 | ), 28 | seghead = dict( 29 | type = 'LoGCAN_Head', 30 | in_channel = [256, 512, 1024, 2048], 31 | transform_channel = 128, 32 | num_class = num_class, 33 | ), 34 | classifier = dict( 35 | type = 'Base_Classifier', 36 | transform_channel = 128, 37 | num_class = num_class, 38 | ), 39 | upsample=dict( 40 | type='Interpolate', 41 | mode='bilinear', 42 | scale=[4, 32], 43 | ) 44 | ) 45 | loss_config = dict( 46 | type = 'myLoss', 47 | loss_name = ['CELoss', 'CELoss'], 48 | loss_weight = [1, 0.8], 49 | ignore_index = ignore_index 50 | ) 51 | 52 | ######################## optimizer_config ###################### 53 | optimizer_config = dict( 54 | optimizer = dict( 55 | type = 'SGD', 56 | backbone_lr = 0.001, 57 | backbone_weight_decay = 1e-4, 58 | lr = 0.01, 59 | weight_decay = 1e-4, 60 | momentum = 0.9, 61 | lr_mode = "single" 62 | ), 63 | scheduler = dict( 64 | type = 'Poly', 65 | poly_exp = 0.9, 66 | max_epoch = epoch 67 | ) 68 | ) 69 | -------------------------------------------------------------------------------- /configs/vaihingen/logcan.py: -------------------------------------------------------------------------------- 1 | ######################## base_config ######################### 2 | gpus = [1] 3 | save_top_k = 1 4 | save_last = True 5 | check_val_every_n_epoch = 1 6 | logging_interval = 'epoch' 7 | resume_ckpt_path = None 8 | pretrained_ckpt_path = None 9 | monitor = 'val_miou' 10 | 11 | test_ckpt_path = None 12 | 13 | ######################## dataset_config ###################### 14 | exp_name = "work_dirs/logcan_vaihingen" 15 | _base_ = '../_base_/vaihingen_config.py' 16 | epoch = 150 17 | num_class = 6 18 | ignore_index = 6 19 | 20 | ######################### model_config ######################### 21 | model_config = dict( 22 | transform_channel = 128, 23 | num_class = num_class, 24 | backbone = dict( 25 | type = 'get_resnet50_OS32', 26 | pretrained = True 27 | ), 28 | seghead = dict( 29 | type = 'LoGCAN_Head', 30 | in_channel = [256, 512, 1024, 2048], 31 | transform_channel = 128, 32 | num_class = num_class, 33 | ), 34 | classifier = dict( 35 | type = 'Base_Classifier', 36 | transform_channel = 128, 37 | num_class = num_class, 38 | ), 39 | upsample=dict( 40 | type='Interpolate', 41 | mode='bilinear', 42 | scale=[4, 32], 43 | ) 44 | ) 45 | loss_config = dict( 46 | type = 'myLoss', 47 | loss_name = ['CELoss', 'CELoss'], 48 | loss_weight = [1, 0.8], 49 | ignore_index = ignore_index 50 | ) 51 | 52 | ######################## optimizer_config ###################### 53 | optimizer_config = dict( 54 | optimizer = dict( 55 | type = 'SGD', 56 | backbone_lr = 0.001, 57 | backbone_weight_decay = 1e-4, 58 | lr = 0.01, 59 | weight_decay = 1e-4, 60 | momentum = 0.9, 61 | lr_mode = "single" 62 | ), 63 | scheduler = dict( 64 | type = 'Poly', 65 | poly_exp = 0.9, 66 | max_epoch = epoch 67 | ) 68 | ) 69 | -------------------------------------------------------------------------------- /configs/loveda/logcanplus.py: -------------------------------------------------------------------------------- 1 | 2 | ######################## base_config 
######################### 3 | gpus = [1] 4 | save_top_k = 1 5 | save_last = True 6 | check_val_every_n_epoch = 1 7 | logging_interval = 'epoch' 8 | resume_ckpt_path = None 9 | pretrained_ckpt_path = None 10 | monitor = 'val_miou' 11 | 12 | test_ckpt_path = None 13 | 14 | ######################## dataset_config ###################### 15 | exp_name = "work_dirs/logcanplus_loveda" 16 | _base_ = '../_base_/loveda_config.py' 17 | epoch = 50 18 | num_class = 7 19 | ignore_index = 7 20 | 21 | ######################### model_config ######################### 22 | model_config = dict( 23 | num_class = num_class, 24 | backbone = dict( 25 | type = 'repvit_m2_3', 26 | init_cfg=dict( 27 | type='Pretrained', 28 | checkpoint='pretrain/repvit_m2_3_distill_450e.pth', 29 | ), 30 | out_indices=[7, 15, 51, 54] 31 | ), 32 | seghead = dict( 33 | type = 'LoGCANPlus_Head', 34 | in_channel = [80, 160, 320, 640], 35 | transform_channel = 96, 36 | num_class = num_class, 37 | num_heads = 8, 38 | patch_size = (4,4) 39 | ), 40 | classifier = dict( 41 | type = 'Base_Classifier', 42 | transform_channel = 96, 43 | num_class = num_class, 44 | ), 45 | upsample=dict( 46 | type='Interpolate', 47 | mode='bilinear', 48 | scale=[4, 32], 49 | ) 50 | ) 51 | loss_config = dict( 52 | type = 'myLoss', 53 | loss_name = ['CELoss', 'CELoss'], 54 | loss_weight = [1, 0.8], 55 | ignore_index = ignore_index 56 | ) 57 | 58 | ######################## optimizer_config ###################### 59 | optimizer_config = dict( 60 | optimizer = dict( 61 | type = 'AdamW', 62 | lr = 1e-4, 63 | weight_decay = 1e-4, 64 | momentum = 0.9, 65 | lr_mode = "single" 66 | ), 67 | scheduler = dict( 68 | type = 'Poly', 69 | poly_exp = 0.9, 70 | max_epoch = epoch 71 | ) 72 | ) 73 | -------------------------------------------------------------------------------- /configs/potsdam/logcanplus.py: -------------------------------------------------------------------------------- 1 | 2 | ######################## base_config ######################### 3 | gpus = [0] 4 | save_top_k = 1 5 | save_last = True 6 | check_val_every_n_epoch = 1 7 | logging_interval = 'epoch' 8 | resume_ckpt_path = None 9 | pretrained_ckpt_path = None 10 | monitor = 'val_miou' 11 | 12 | test_ckpt_path = None 13 | 14 | ######################## dataset_config ###################### 15 | exp_name = "work_dirs/logcanplus_potsdam" 16 | _base_ = '../_base_/potsdam_config.py' 17 | epoch = 80 18 | num_class = 6 19 | ignore_index = 6 20 | 21 | ######################### model_config ######################### 22 | model_config = dict( 23 | num_class = num_class, 24 | backbone = dict( 25 | type = 'repvit_m2_3', 26 | init_cfg=dict( 27 | type='Pretrained', 28 | checkpoint='pretrain/repvit_m2_3_distill_450e.pth', 29 | ), 30 | out_indices=[7, 15, 51, 54] 31 | ), 32 | seghead = dict( 33 | type = 'LoGCANPlus_Head', 34 | in_channel = [80, 160, 320, 640], 35 | transform_channel = 96, 36 | num_class = num_class, 37 | num_heads = 8, 38 | patch_size = (4,4) 39 | ), 40 | classifier = dict( 41 | type = 'Base_Classifier', 42 | transform_channel = 96, 43 | num_class = num_class, 44 | ), 45 | upsample=dict( 46 | type='Interpolate', 47 | mode='bilinear', 48 | scale=[4, 32], 49 | ) 50 | ) 51 | loss_config = dict( 52 | type = 'myLoss', 53 | loss_name = ['CELoss', 'CELoss'], 54 | loss_weight = [1, 0.8], 55 | ignore_index = ignore_index 56 | ) 57 | 58 | ######################## optimizer_config ###################### 59 | optimizer_config = dict( 60 | optimizer = dict( 61 | type = 'AdamW', 62 | lr = 1e-4, 63 | weight_decay = 1e-4, 64 
| momentum = 0.9, 65 | lr_mode = "single" 66 | ), 67 | scheduler = dict( 68 | type = 'Poly', 69 | poly_exp = 0.9, 70 | max_epoch = epoch 71 | ) 72 | ) 73 | -------------------------------------------------------------------------------- /configs/vaihingen/logcanplus.py: -------------------------------------------------------------------------------- 1 | 2 | ######################## base_config ######################### 3 | gpus = [1] 4 | save_top_k = 1 5 | save_last = True 6 | check_val_every_n_epoch = 1 7 | logging_interval = 'epoch' 8 | resume_ckpt_path = None 9 | pretrained_ckpt_path = None 10 | monitor = 'val_miou' 11 | 12 | test_ckpt_path = None 13 | 14 | ######################## dataset_config ###################### 15 | exp_name = "work_dirs/logcanplus_vaihingen" 16 | _base_ = '../_base_/vaihingen_config.py' 17 | epoch = 150 18 | num_class = 6 19 | ignore_index = 6 20 | 21 | ######################### model_config ######################### 22 | model_config = dict( 23 | num_class = num_class, 24 | backbone = dict( 25 | type = 'repvit_m2_3', 26 | init_cfg=dict( 27 | type='Pretrained', 28 | checkpoint='pretrain/repvit_m2_3_distill_450e.pth', 29 | ), 30 | out_indices=[7, 15, 51, 54] 31 | ), 32 | seghead = dict( 33 | type = 'LoGCANPlus_Head', 34 | in_channel = [80, 160, 320, 640], 35 | transform_channel = 96, 36 | num_class = num_class, 37 | num_heads = 8, 38 | patch_size = (4,4) 39 | ), 40 | classifier = dict( 41 | type = 'Base_Classifier', 42 | transform_channel = 96, 43 | num_class = num_class, 44 | ), 45 | upsample=dict( 46 | type='Interpolate', 47 | mode='bilinear', 48 | scale=[4, 32], 49 | ) 50 | ) 51 | loss_config = dict( 52 | type = 'myLoss', 53 | loss_name = ['CELoss', 'CELoss'], 54 | loss_weight = [1, 0.8], 55 | ignore_index = ignore_index 56 | ) 57 | 58 | ######################## optimizer_config ###################### 59 | optimizer_config = dict( 60 | optimizer = dict( 61 | type = 'AdamW', 62 | lr = 1e-4, 63 | weight_decay = 1e-4, 64 | momentum = 0.9, 65 | lr_mode = "single" 66 | ), 67 | scheduler = dict( 68 | type = 'Poly', 69 | poly_exp = 0.9, 70 | max_epoch = epoch 71 | ) 72 | ) 73 | -------------------------------------------------------------------------------- /rsseg/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from .transform import * 3 | import albumentations as albu 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | import torch 8 | class BaseDataset(Dataset): 9 | def __init__(self, transform=None, mode="train"): 10 | self.mode = mode 11 | aug_list = [] 12 | for k, v in transform.items(): 13 | if v is not None: 14 | aug_list.append(eval(k)(**v)) 15 | else: aug_list.append(eval(k)()) 16 | 17 | self.transform = Compose(aug_list) 18 | 19 | self.normalize = albu.Compose([ 20 | albu.Normalize()]) 21 | 22 | 23 | def __len__(self): 24 | return len(self.file_paths) 25 | 26 | def __getitem__(self, index): 27 | img, mask, img_id = self.load_img_and_mask(index) 28 | 29 | if len(self.transform.transforms) != 0: 30 | img, mask = self.transform(img, mask) 31 | img, mask = np.array(img), np.array(mask) 32 | 33 | aug = self.normalize(image=img.copy(), mask=mask.copy()) 34 | img, mask = aug['image'], aug['mask'] 35 | 36 | img = torch.from_numpy(img).permute(2, 0, 1).float() 37 | mask = torch.from_numpy(mask).long() 38 | return [img, mask, img_id] 39 | 40 | def get_path(self, data_root, img_dir, mask_dir): 41 | img_filename_list
= os.listdir(os.path.join(data_root, img_dir)) 42 | mask_filename_list = os.listdir(os.path.join(data_root, mask_dir)) 43 | assert len(img_filename_list) == len(mask_filename_list) 44 | img_ids = [str(id.split('.')[0]) for id in mask_filename_list] 45 | return img_ids 46 | 47 | def load_img_and_mask(self, index): 48 | img_id = self.file_paths[index] 49 | img_name = os.path.join(self.data_root, self.img_dir, img_id + self.img_suffix) 50 | mask_name = os.path.join(self.data_root, self.mask_dir, img_id + self.mask_suffix) 51 | img = Image.open(img_name).convert('RGB') 52 | mask_rgb = Image.open(mask_name).convert('RGB') 53 | mask = self.rgb2label(mask_rgb) 54 | return [img, mask, img_id] 55 | 56 | -------------------------------------------------------------------------------- /rsseg/datasets/potsdam_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | from PIL import Image 3 | import numpy as np 4 | class Potsdam(BaseDataset): 5 | def __init__(self, data_root='data/potsdam', mode='train', transform=None, img_dir='images_1024', mask_dir='masks_1024', img_suffix='.tif', mask_suffix='.png', **kwargs): 6 | super(Potsdam, self).__init__(transform, mode) 7 | 8 | self.img_dir = img_dir 9 | self.img_suffix = img_suffix 10 | self.mask_dir = mask_dir 11 | self.mask_suffix = mask_suffix 12 | 13 | self.data_root = data_root + "/train" if mode == "train" else data_root + "/test" 14 | self.file_paths = self.get_path(self.data_root, img_dir, mask_dir) 15 | 16 | #RGB 17 | self.color_map = { 18 | 'ImSurf' : np.array([255, 255, 255]), # label 0 19 | 'Building' : np.array([0, 0, 255]), # label 1 20 | 'LowVeg' : np.array([0, 255, 255]), # label 2 21 | 'Tree' : np.array([0, 255, 0]), # label 3 22 | "Car" : np.array([255, 255, 0]), # label 4 23 | 'Clutter' : np.array([255, 0, 0]), # label 5 24 | 'Boundary' : np.array([0, 0, 0]), # label 6 25 | } 26 | 27 | self.num_classes = 6 28 | 29 | def rgb2label(self,mask_rgb): 30 | mask_rgb = np.array(mask_rgb) 31 | _mask_rgb = mask_rgb.transpose(2, 0, 1) 32 | label_seg = np.zeros(_mask_rgb.shape[1:], dtype=np.uint8) 33 | label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['ImSurf'], axis=-1)] = 0 34 | label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Building'], axis=-1)] = 1 35 | label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['LowVeg'], axis=-1)] = 2 36 | label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Tree'], axis=-1)] = 3 37 | label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Car'], axis=-1)] = 4 38 | label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Clutter'], axis=-1)] = 5 39 | label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Boundary'], axis=-1)] = 6 40 | 41 | _label_seg = Image.fromarray(label_seg).convert('L') 42 | return _label_seg -------------------------------------------------------------------------------- /rsseg/datasets/vaihingen_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from .base_dataset import BaseDataset 3 | import numpy as np 4 | class Vaihingen(BaseDataset): 5 | def __init__(self, data_root='data/vaihingen', mode='train', transform=None, img_dir='images_1024', mask_dir='masks_1024', img_suffix='.tif', mask_suffix='.png', **kwargs): 6 | super(Vaihingen, self).__init__(transform, mode) 7 | 8 | self.img_dir = img_dir 9 | self.img_suffix = img_suffix 10 | self.mask_dir = mask_dir
11 |         self.mask_suffix = mask_suffix
12 | 
13 |         self.data_root = data_root + "/train" if mode == "train" else data_root + "/test"
14 |         self.file_paths = self.get_path(self.data_root, img_dir, mask_dir)
15 | 
16 |         # RGB color map (ISPRS Vaihingen)
17 |         self.color_map = {
18 |             'ImSurf' : np.array([255, 255, 255]), # label 0
19 |             'Building' : np.array([0, 0, 255]), # label 1
20 |             'LowVeg' : np.array([0, 255, 255]), # label 2
21 |             'Tree' : np.array([0, 255, 0]), # label 3
22 |             'Car' : np.array([255, 255, 0]), # label 4
23 |             'Clutter' : np.array([255, 0, 0]), # label 5
24 |             'Boundary' : np.array([0, 0, 0]), # label 6
25 |         }
26 | 
27 |         self.num_classes = 6  # 'Boundary' (label 6) is excluded from the evaluated classes
28 | 
29 |     def rgb2label(self, mask_rgb):
30 |         mask_rgb = np.array(mask_rgb)
31 |         _mask_rgb = mask_rgb.transpose(2, 0, 1)
32 |         label_seg = np.zeros(_mask_rgb.shape[1:], dtype=np.uint8)
33 |         label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['ImSurf'], axis=-1)] = 0
34 |         label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Building'], axis=-1)] = 1
35 |         label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['LowVeg'], axis=-1)] = 2
36 |         label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Tree'], axis=-1)] = 3
37 |         label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Car'], axis=-1)] = 4
38 |         label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Clutter'], axis=-1)] = 5
39 |         label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Boundary'], axis=-1)] = 6
40 | 
41 |         _label_seg = Image.fromarray(label_seg).convert('L')
42 |         return _label_seg
43 | 
44 | 
45 | 
--------------------------------------------------------------------------------
/rsseg/optimizers/build_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from catalyst.contrib.nn import Lookahead
4 | from catalyst import utils
5 | import math
6 | 
7 | class lambdax:
8 |     def __init__(self, cfg):
9 |         self.cfg = cfg
10 | 
11 |     def lambda_epoch(self, epoch):  # poly decay; not referenced here, get_scheduler builds its own lambda
12 |         return math.pow(1 - epoch / self.cfg.max_epoch, self.cfg.poly_exp)
13 | 
14 | 
15 | def get_optimizer(cfg, net):
16 |     if cfg.lr_mode == 'multi':
17 |         layerwise_params = {"backbone.*": dict(lr=cfg.backbone_lr, weight_decay=cfg.backbone_weight_decay)}
18 |         net_params = utils.process_model_params(net, layerwise_params=layerwise_params)
19 |     else:
20 |         net_params = net.parameters()
21 | 
22 |     if cfg.type == "AdamW":
23 |         optimizer = optim.AdamW(net_params, lr=cfg.lr, weight_decay=cfg.weight_decay)
24 |         # optimizer = Lookahead(optimizer)
25 | 
26 |     elif cfg.type == "SGD":
27 |         optimizer = optim.SGD(net_params, lr=cfg.lr, weight_decay=cfg.weight_decay, momentum=cfg.momentum,
28 |                               nesterov=False)
29 |         optimizer = Lookahead(optimizer)
30 |     else:
31 |         raise KeyError("The optimizer type ( %s ) doesn't exist!" % cfg.type)
32 | 
33 |     return optimizer
34 | 
35 | 
36 | def get_scheduler(cfg, optimizer):
37 |     if cfg.type == 'Poly':
38 |         lambda1 = lambda epoch: math.pow(1 - epoch / cfg.max_epoch, cfg.poly_exp)
39 |         scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
40 |     elif cfg.type == 'CosineAnnealingLR':
41 |         scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.max_epoch, eta_min=1e-6)
42 |     else:
43 |         raise KeyError("The scheduler type ( %s ) doesn't exist!"
% cfg.type) 44 | 45 | return scheduler 46 | 47 | def build_optimizer(cfg, net): 48 | optimizer = get_optimizer(cfg.optimizer, net) 49 | scheduler = get_scheduler(cfg.scheduler, optimizer) 50 | # if cfg.type == 'Poly': 51 | # lambda1 = lambda epoch: math.pow(1 - epoch / cfg.max_epoch, cfg.poly_exp) 52 | # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 53 | # elif cfg.type == 'CosineAnnealingLR': 54 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.max_epoch, eta_min=1e-6) 55 | # else: 56 | # raise KeyError("The scheduler type ( %s ) doesn't exist!!!" % cfg.type) 57 | 58 | return optimizer, scheduler 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /tools/throughput_count.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import time 4 | from timm import create_model 5 | from mmcv import Config 6 | 7 | from mmseg.datasets import build_dataloader, build_dataset 8 | from mmcv.parallel import MMDataParallel 9 | from mmseg.models import build_segmentor 10 | 11 | torch.autograd.set_grad_enabled(False) 12 | 13 | import sys 14 | sys.path.append('.') 15 | from train import * 16 | 17 | def replace_batchnorm(net): 18 | for child_name, child in net.named_children(): 19 | if hasattr(child, 'fuse'): 20 | fused = child.fuse() 21 | setattr(net, child_name, fused) 22 | replace_batchnorm(fused) 23 | elif isinstance(child, torch.nn.BatchNorm2d): 24 | setattr(net, child_name, torch.nn.Identity()) 25 | else: 26 | replace_batchnorm(child) 27 | 28 | T0 = 5 29 | T1 = 10 30 | 31 | def throughput(model, device, batch_size, resolution=224): 32 | inputs = torch.randn(batch_size, 3, resolution, resolution, device=device) 33 | torch.cuda.empty_cache() 34 | torch.cuda.synchronize() 35 | start = time.time() 36 | while time.time() - start < T0: 37 | model(inputs) 38 | timing = [] 39 | torch.cuda.synchronize() 40 | while sum(timing) < T1: 41 | start = time.time() 42 | model(inputs) 43 | torch.cuda.synchronize() 44 | timing.append(time.time() - start) 45 | timing = torch.as_tensor(timing, dtype=torch.float32) 46 | print(device, batch_size / timing.mean().item(), 47 | 'images/s @ batch size', batch_size) 48 | 49 | device = "cuda:0" 50 | 51 | from argparse import ArgumentParser 52 | 53 | parser = ArgumentParser() 54 | 55 | def parse_args(): 56 | parser = argparse.ArgumentParser(description='rsseg: Benchmark a model') 57 | parser.add_argument("-c", "--config", type=str, default="configs/logcan.py") 58 | parser.add_argument('--resolution', default=512, type=int) 59 | parser.add_argument('--batch-size', default=32, type=int) 60 | args = parser.parse_args() 61 | return args 62 | 63 | if __name__ == "__main__": 64 | args = parse_args() 65 | 66 | cfg = Config.fromfile(args.config) 67 | model = myTrain(cfg) 68 | 69 | batch_size = args.batch_size 70 | resolution = args.resolution 71 | torch.cuda.empty_cache() 72 | inputs = torch.randn(batch_size, 3, resolution, 73 | resolution, device=device) 74 | replace_batchnorm(model) 75 | 76 | model.to(device) 77 | model.eval() 78 | throughput(model, device, batch_size, resolution=resolution) 79 | -------------------------------------------------------------------------------- /rsseg/datasets/build_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Iterable, Optional, Sequence, Union 3 | from torch.utils.data import DataLoader, Dataset, 
Sampler
4 | from torch.utils.data.dataloader import _collate_fn_t, _worker_init_fn_t
5 | from rsseg.datasets.vaihingen_dataset import *
6 | from rsseg.datasets.potsdam_dataset import *
7 | from rsseg.datasets.base_dataset import *
8 | from rsseg.datasets.loveda_dataset import *
9 | 
10 | def get_loader(dataset, cfg):  # unused helper; build_dataloader below constructs its DataLoader inline
11 |     loader = DataLoader(
12 |         dataset=dataset,
13 |         batch_size=cfg.batch_size,
14 |         num_workers=cfg.num_workers,
15 |         pin_memory=cfg.pin_memory,
16 |         shuffle=cfg.shuffle,
17 |         drop_last=cfg.drop_last
18 |     )
19 |     return loader
20 | 
21 | # dataset_config
22 | def build_dataloader(cfg, mode='train'):  # get dataloader
23 |     dataset_type = cfg.type
24 |     data_root = cfg.data_root
25 |     if dataset_type == 'LoveDA' and mode == 'test':
26 |         dataset = eval(dataset_type)(data_root, mode, **cfg.test_mode)
27 |         loader_cfg = cfg.test_mode.loader
28 |     elif mode == 'train':
29 |         dataset = eval(dataset_type)(data_root, mode, **cfg.train_mode)
30 |         loader_cfg = cfg.train_mode.loader
31 |     elif mode == 'val':
32 |         dataset = eval(dataset_type)(data_root, mode, **cfg.val_mode)
33 |         loader_cfg = cfg.val_mode.loader
34 |     else:
35 |         mode = 'val'  # any other mode falls back to the validation split
36 |         dataset = eval(dataset_type)(data_root, mode, **cfg.val_mode)
37 |         loader_cfg = cfg.val_mode.loader
38 | 
39 |     data_loader = DataLoader(
40 |         dataset = dataset,
41 |         batch_size = loader_cfg.batch_size,
42 |         num_workers = loader_cfg.num_workers,
43 |         pin_memory = loader_cfg.pin_memory,
44 |         shuffle = loader_cfg.shuffle,
45 |         drop_last = loader_cfg.drop_last
46 |     )
47 | 
48 |     return data_loader
49 | 
50 | if __name__ == '__main__':  # you can test the dataloader from here
51 |     file_path = "/home/xwma/lrr/rssegmentation/configs/ssnet.py"
52 | 
53 |     print(file_path)
54 | 
55 |     from utils.config import Config
56 | 
57 |     cfg = Config.fromfile(file_path)
58 |     print(cfg)
59 |     train_loader = build_dataloader(cfg.dataset_config)
60 |     cnt = 0
61 |     for i, (img, tar, img_id) in enumerate(train_loader):  # each batch is (image, mask, img_id)
62 |         print(img.shape)
63 |         cnt += 1
64 |         if cnt > 10:
65 |             break
66 | 
67 |     # print("start print val_loader #####################")
68 |     #
69 |     # for i, (img, tar, img_id) in enumerate(val_loader):
70 |     #     print(img.shape)
71 |     #     cnt += 1
72 |     #     if cnt > 20:
73 |     #         break
74 | 
--------------------------------------------------------------------------------
/online_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.utils.data import DataLoader
4 | from tqdm import tqdm
5 | import ttach as tta
6 | import time
7 | import os
8 | import multiprocessing.pool as mpp
9 | import multiprocessing as mp
10 | 
11 | from train import *
12 | 
13 | import argparse
14 | from utils.config import Config
15 | from tools.mask_convert import mask_save
16 | 
17 | def get_args():
18 |     parser = argparse.ArgumentParser(description='online test')
19 |     parser.add_argument("-c", "--config", type=str, default="configs/logcan.py")
20 |     parser.add_argument("--ckpt", type=str, default="work_dirs/LoGCAN_ResNet50_Loveda/epoch=45.ckpt")
21 |     parser.add_argument("--tta", type=str, default="d4")
22 |     parser.add_argument("--masks_output_dir", default=None)
23 |     return parser.parse_args()
24 | 
25 | 
26 | if __name__ == "__main__":
27 |     args = get_args()
28 |     cfg = Config.fromfile(args.config)
29 | 
30 |     if args.masks_output_dir is not None:
31 |         masks_output_dir = args.masks_output_dir
32 |     else:
33 |         masks_output_dir = cfg.exp_name + '/online_figs'
34 | 
35 |     model = myTrain.load_from_checkpoint(args.ckpt, cfg = cfg)
36 |     model = model.to('cuda')
37 | 
38 |     model.eval()
39 | 
40 |     if args.tta == "lr":
41 |         transforms = tta.Compose(
42 |             [
43 |                 tta.HorizontalFlip(),
44 |                 tta.VerticalFlip()
45 |             ]
46 |         )
47 |         model = tta.SegmentationTTAWrapper(model, transforms)
48 |     elif args.tta == "d4":
49 |         transforms = tta.Compose(
50 |             [
51 |                 tta.HorizontalFlip(),
52 |                 tta.VerticalFlip(),
53 |                 tta.Rotate90(angles=[90]),
54 |                 tta.Scale(scales=[0.5, 0.75, 1.0, 1.25, 1.5], interpolation='bicubic', align_corners=False)
55 |             ]
56 |         )
57 |         model = tta.SegmentationTTAWrapper(model, transforms)
58 | 
59 |     results = []
60 |     mask2RGB = False
61 |     with torch.no_grad():
62 |         test_loader = build_dataloader(cfg.dataset_config, mode='test')
63 |         print(len(test_loader))
64 |         for input in tqdm(test_loader):
65 |             raw_predictions, img_id = model(input[0].cuda(), True), input[2]
66 |             pred = raw_predictions.argmax(dim=1)
67 | 
68 |             for i in range(raw_predictions.shape[0]):
69 |                 mask_pred = pred[i].cpu().numpy()
70 |                 mask_name = str(img_id[i])
71 |                 results.append((mask2RGB, mask_pred, cfg.dataset, masks_output_dir, mask_name))
72 | 
73 |     if not os.path.exists(masks_output_dir):
74 |         os.makedirs(masks_output_dir)
75 |     print("masks_save_dir: ", masks_output_dir)
76 | 
77 |     t0 = time.time()
78 |     mpp.Pool(processes=mp.cpu_count()).map(mask_save, results)
79 |     t1 = time.time()
80 |     img_write_time = t1 - t0
81 |     print('writing images took: {} s'.format(img_write_time))
--------------------------------------------------------------------------------
/tools/tsne.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import argparse
4 | import torch
5 | from sklearn.manifold import TSNE
6 | import matplotlib.pyplot as plt
7 | import torch.nn.functional as F
8 | import time
9 | from tqdm import tqdm
10 | import os
11 | import sys
12 | sys.path.append('.')
13 | from train import *
14 | from utils.config import Config
15 | 
16 | def parse_args():
17 |     parser = argparse.ArgumentParser(description='rsseg: tsne map')
18 |     parser.add_argument("-c", "--config", type=str, default="configs/vaihingen/logcanplus.py")
19 |     parser.add_argument("--ckpt", type=str, default="work_dirs/logcanplus_vaihingen/epoch=64.ckpt")
20 |     parser.add_argument("--tar_size", type=tuple, default=(64, 64))
21 |     parser.add_argument("--n_components", type=int, default=2)
22 |     parser.add_argument("--random_state", type=int, default=45)
23 |     parser.add_argument("--tsne_output_dir", default=None)
24 |     args = parser.parse_args()
25 |     return args
26 | 
27 | def color_trans(tsne_color):
28 |     if (tsne_color == np.array([255, 255, 255])).all():  # elementwise compare; white points would vanish on a white canvas
29 |         tsne_color = np.array([50, 100, 150])
30 |     return tsne_color / 255.0
31 | 
32 | def main():
33 |     args = parse_args()
34 |     cfg = Config.fromfile(args.config)
35 |     cfg.dataset_config.val_mode.loader.batch_size = 1
36 |     model = myTrain.load_from_checkpoint(args.ckpt, cfg = cfg)
37 |     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
38 |     model.to(device)
39 | 
40 |     if args.tsne_output_dir is not None:
41 |         tsne_output_dir = args.tsne_output_dir
42 |     else:
43 |         tsne_output_dir = cfg.exp_name + '/tsne_figs/'
44 |     if not os.path.exists(tsne_output_dir):
45 |         os.makedirs(tsne_output_dir)
46 | 
47 |     test_loader = build_dataloader(cfg.dataset_config, mode='test')
48 |     model.eval()
49 | 
50 |     mytsne = TSNE(n_components=args.n_components, random_state=args.random_state)
51 | 
52 |     for input in tqdm(test_loader):
53 |         images, gts, img_ids = input[0].to(device), input[1].to(device), input[2]
54 |         img_id = img_ids[0]
55 |         features = model.net.backbone(images)
56 |         features = model.net.seghead(features)[0]
57 | 
58 |         features = F.interpolate(features, args.tar_size, mode='nearest')
59 |         gts = F.interpolate(gts.float().unsqueeze(1), size=args.tar_size, mode='nearest').squeeze(1)
60 | 
61 |         features = features.flatten(2).transpose(1, 2)
62 |         features = features.cpu().detach().numpy()[0]
63 |         gts = gts.cpu().numpy()[0].reshape(-1)
64 |         class_name = cfg.class_name[cfg.eval_label_id_left:cfg.eval_label_id_right]
65 | 
66 |         tsne = mytsne.fit_transform(features)
67 |         color_map = test_loader.dataset.color_map
68 | 
69 |         for j in range(0, len(class_name)):
70 |             plt.scatter(tsne[:, 0][gts == j], tsne[:, 1][gts == j], c=color_trans(color_map[class_name[j]]).reshape(1, -1),
71 |                         s=1, label=class_name[j])
72 |         plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.13), ncol=6, borderpad=0.15, markerscale=5)
73 |         plt.savefig(tsne_output_dir + '{}'.format(img_id) + '.pdf')
74 |         plt.close()
75 | 
76 | if __name__ == "__main__":
77 |     main()
78 | 
--------------------------------------------------------------------------------
/rsseg/datasets/loveda_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseDataset
2 | from PIL import Image
3 | import os
4 | import numpy as np
5 | class LoveDA(BaseDataset):
6 |     def __init__(self, data_root='data/loveda', mode='train', transform=None, img_dir='images_png', mask_dir='masks_png', img_suffix='.png', mask_suffix='.png', **kwargs):
7 |         super(LoveDA, self).__init__(transform)
8 | 
9 |         self.img_dir = img_dir
10 |         self.img_suffix = img_suffix
11 |         self.mask_dir = mask_dir
12 |         self.mask_suffix = mask_suffix
13 |         self.mode = mode
14 | 
15 |         if mode == "train":
16 |             self.data_root = data_root + "/Train"
17 |         elif mode == "val":
18 |             self.data_root = data_root + "/Val"
19 |         elif mode == "test":
20 |             self.data_root = data_root + "/Test"
21 | 
22 |         self.file_paths = self.get_path(self.data_root, img_dir, mask_dir)
23 | 
24 | 
25 |         # RGB color map (LoveDA)
26 |         self.color_map = {
27 |             'building' : np.array([255, 0, 0]), # label 0
28 |             'road' : np.array([255, 255, 0]), # label 1
29 |             'water' : np.array([0, 0, 255]), # label 2
30 |             'barren' : np.array([159, 129, 183]), # label 3
31 |             'forest' : np.array([0, 255, 0]), # label 4
32 |             'agricultural' : np.array([255, 195, 128]), # label 5
33 |             'background' : np.array([255, 255, 255]), # label 6
34 |         }
35 | 
36 |         self.num_classes = 7
37 | 
38 |     def get_path(self, data_root, img_dir, mask_dir):
39 |         urban_img_filename_list = os.listdir(os.path.join(data_root, 'Urban', img_dir))
40 |         if self.mode != 'test':
41 |             urban_mask_filename_list = os.listdir(os.path.join(data_root, 'Urban', mask_dir))
42 |             assert len(urban_img_filename_list) == len(urban_mask_filename_list)
43 | 
44 |         urban_img_ids = [(str(fname.split('.')[0]), 'Urban') for fname in urban_img_filename_list]
45 | 
46 |         rural_img_filename_list = os.listdir(os.path.join(data_root, 'Rural', img_dir))
47 |         if self.mode != 'test':
48 |             rural_mask_filename_list = os.listdir(os.path.join(data_root, 'Rural', mask_dir))
49 |             assert len(rural_img_filename_list) == len(rural_mask_filename_list)
50 |         rural_img_ids = [(str(fname.split('.')[0]), 'Rural') for fname in rural_img_filename_list]
51 |         img_ids = urban_img_ids + rural_img_ids
52 |         return img_ids
53 | 
54 | 
55 |     def load_img_and_mask(self, index):
56 |         img_id, img_type = self.file_paths[index]
57 |         img_name = os.path.join(self.data_root, img_type, self.img_dir, img_id + self.img_suffix)
58 |         img = 
Image.open(img_name).convert('RGB') 59 | 60 | if self.mode == 'test': 61 | mask = np.array(img) 62 | mask = Image.fromarray(mask).convert('L') 63 | return [img, mask, img_id] 64 | 65 | mask_name = os.path.join(self.data_root, img_type, self.mask_dir, img_id + self.mask_suffix) 66 | mask = Image.open(mask_name).convert('L') 67 | 68 | np_mask = np.array(mask) 69 | np_mask[np_mask == 0] = 8 70 | np_mask -= 1 71 | mask = Image.fromarray(np_mask).convert('L') 72 | return [img, mask, img_id] 73 | -------------------------------------------------------------------------------- /rsseg/losses/l1l2_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class L1_Loss(nn.Module): 6 | def __init__(self, ignore_index = None, reduction = 'mean'): 7 | super(L1_Loss, self).__init__() 8 | self.ignore_index = ignore_index 9 | self.reduction = reduction 10 | self.loss = nn.L1Loss(reduction=reduction) 11 | 12 | # pred (B K H W) target (B H W) 13 | def forward(self, pred, target): 14 | pred = pred.permute(0, 2, 3, 1) # (B H W K) 15 | if self.ignore_index is not None: 16 | mask = (target != self.ignore_index) 17 | pred = pred[mask] # (n, k) 18 | target = target[mask] # (n) 19 | num_classes = pred.size(-1) 20 | one_hot_target = F.one_hot(target, num_classes) # (n k) 21 | 22 | l1_loss = self.loss(pred, one_hot_target) 23 | return l1_loss 24 | 25 | class L2_Loss(nn.Module): 26 | def __init__(self, ignore_index = None, reduction = 'mean'): 27 | super(L2_Loss, self).__init__() 28 | self.ignore_index = ignore_index 29 | self.reduction = reduction 30 | self.loss = nn.MSELoss(reduction=reduction) 31 | 32 | # pred (B K H W) target (B H W) 33 | def forward(self, pred, target): 34 | pred = pred.permute(0, 2, 3, 1) # (B H W K) 35 | if self.ignore_index is not None: 36 | mask = (target != self.ignore_index) 37 | pred = pred[mask] # (n, k) 38 | target = target[mask] # (n) 39 | num_classes = pred.size(-1) 40 | one_hot_target = F.one_hot(target, num_classes) # (n k) 41 | 42 | l2_loss = self.loss(pred, one_hot_target) 43 | return l2_loss 44 | 45 | class Smooth_L1_Loss(nn.Module): 46 | def __init__(self, beta = 1, ignore_index = None, reduction = 'mean'): 47 | super(Smooth_L1_Loss, self).__init__() 48 | 49 | self.beta = beta 50 | self.reduction = reduction 51 | self.ignore_index = ignore_index 52 | 53 | def forward(self, pred, target): 54 | pred = pred.permute(0, 2, 3, 1) # (B H W K) 55 | if self.ignore_index is not None: 56 | mask = (target != self.ignore_index) 57 | pred = pred[mask] # (n, k) 58 | target = target[mask] # (n) 59 | num_classes = pred.size(-1) 60 | one_hot_target = F.one_hot(target, num_classes) # (n k) 61 | 62 | if self.beta < 1e-5: 63 | loss = torch.abs(pred - one_hot_target) 64 | else: 65 | n = torch.abs(pred - one_hot_target) 66 | cond = n < self.beta 67 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. 
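            # (added explanatory note) smooth L1 / Huber with threshold beta, where
            # n = |pred - one_hot_target| elementwise:
            #   0.5 * n**2 / beta   if n <  beta   (quadratic near zero)
            #   n - 0.5 * beta      otherwise      (linear in the tails)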
68 | loss = torch.where(cond, 0.5 * n**2 / self.beta, n - 0.5 * self.beta) 69 | 70 | if self.reduction == "mean": 71 | loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() 72 | elif self.reduction == "sum": 73 | loss = loss.sum() 74 | return loss 75 | 76 | if __name__ == "__main__": 77 | # lossmodel = Smooth_L1_Loss() 78 | lossmodel = Smooth_L1_Loss(ignore_index=6) 79 | pred = torch.randn(4, 6, 64, 64) 80 | target = torch.randint(low=0, high=7, size=(4, 64, 64)) 81 | loss = lossmodel(pred, target) 82 | print(loss) 83 | 84 | -------------------------------------------------------------------------------- /rsseg/models/basemodules/dysample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def normal_init(module, mean=0, std=1, bias=0): 7 | if hasattr(module, 'weight') and module.weight is not None: 8 | nn.init.normal_(module.weight, mean, std) 9 | if hasattr(module, 'bias') and module.bias is not None: 10 | nn.init.constant_(module.bias, bias) 11 | 12 | 13 | def constant_init(module, val, bias=0): 14 | if hasattr(module, 'weight') and module.weight is not None: 15 | nn.init.constant_(module.weight, val) 16 | if hasattr(module, 'bias') and module.bias is not None: 17 | nn.init.constant_(module.bias, bias) 18 | 19 | 20 | class DySample(nn.Module): 21 | def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False): 22 | super().__init__() 23 | self.scale = scale 24 | self.style = style 25 | self.groups = groups 26 | assert style in ['lp', 'pl'] 27 | if style == 'pl': 28 | assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0 29 | assert in_channels >= groups and in_channels % groups == 0 30 | 31 | if style == 'pl': 32 | in_channels = in_channels // scale ** 2 33 | out_channels = 2 * groups 34 | else: 35 | out_channels = 2 * groups * scale ** 2 36 | 37 | self.offset = nn.Conv2d(in_channels, out_channels, 1) 38 | normal_init(self.offset, std=0.001) 39 | if dyscope: 40 | self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False) 41 | constant_init(self.scope, val=0.) 
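        # (added explanatory note) init_pos caches the static sub-pixel grid for a
        # scale-s upsample: offsets span [-(s-1)/(2s), +(s-1)/(2s)] along each axis,
        # tiled per group into a (1, 2 * groups * s**2, 1, 1) buffer that the
        # predicted dynamic offsets are added to.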
42 | 43 | self.register_buffer('init_pos', self._init_pos()) 44 | 45 | def _init_pos(self): 46 | h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale 47 | return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1) 48 | 49 | def sample(self, x, offset): 50 | B, _, H, W = offset.shape 51 | offset = offset.view(B, 2, -1, H, W) 52 | coords_h = torch.arange(H) + 0.5 53 | coords_w = torch.arange(W) + 0.5 54 | coords = torch.stack(torch.meshgrid([coords_w, coords_h]) 55 | ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device) 56 | normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1) 57 | coords = 2 * (coords + offset) / normalizer - 1 58 | coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view( 59 | B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1) 60 | return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear', 61 | align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W) 62 | 63 | def forward_lp(self, x): 64 | if hasattr(self, 'scope'): 65 | offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos 66 | else: 67 | offset = self.offset(x) * 0.25 + self.init_pos 68 | return self.sample(x, offset) 69 | 70 | def forward_pl(self, x): 71 | x_ = F.pixel_shuffle(x, self.scale) 72 | if hasattr(self, 'scope'): 73 | offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos 74 | else: 75 | offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos 76 | return self.sample(x, offset) 77 | 78 | def forward(self, x): 79 | if self.style == 'pl': 80 | return self.forward_pl(x) 81 | return self.forward_lp(x) 82 | 83 | 84 | if __name__ == '__main__': 85 | x = torch.rand(2, 64, 4, 7) 86 | dys = DySample(64) 87 | print(dys(x).shape) -------------------------------------------------------------------------------- /tools/latency_count.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import sys 9 | sys.path.append('.') 10 | from utils.config import Config 11 | import json 12 | 13 | from train import * 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description='rsseg: Benchmark a model') 18 | parser.add_argument("-c", "--config", type=str, default="configs/logcan.py") 19 | parser.add_argument("--ckpt", type=str, default="work_dirs/LoGCAN_ResNet50_Loveda/epoch=45.ckpt") 20 | parser.add_argument( 21 | '--log-interval', type=int, default=50, help='interval of logging') 22 | parser.add_argument( 23 | '--work-dir', 24 | help=('if specified, the results will be dumped ' 25 | 'into the directory as json'), default=None) 26 | parser.add_argument('--repeat-times', type=int, default=3) 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | def main(): 32 | args = parse_args() 33 | cfg = Config.fromfile(args.config) 34 | 35 | repeat_times = args.repeat_times 36 | # set cudnn_benchmark 37 | torch.backends.cudnn.benchmark = False 38 | 39 | benchmark_dict = dict(config=args.config, dataset=cfg.dataset, unit='img / s') 40 | overall_fps_list = [] 41 | cfg.dataset_config.val_mode.loader.batch_size = 1 42 | 43 | # build the model and load checkpoint 44 | if osp.exists(args.ckpt): 45 | model = myTrain.load_from_checkpoint(args.ckpt, cfg = 
cfg) 46 | else: 47 | model = myTrain(cfg) 48 | model = model.cuda() 49 | model.eval() 50 | 51 | for time_index in range(repeat_times): 52 | print(f'Run {time_index + 1}:') 53 | # build the dataloader 54 | data_loader = build_dataloader(cfg.dataset_config, mode='val') 55 | 56 | # the first several iterations may be very slow so skip them 57 | num_warmup = 5 58 | pure_inf_time = 0 59 | total_iters = 100 60 | 61 | # benchmark with 200 batches and take the average 62 | for i, input in enumerate(data_loader): 63 | if torch.cuda.is_available(): 64 | torch.cuda.synchronize() 65 | start_time = time.perf_counter() 66 | 67 | with torch.no_grad(): 68 | model(input[0].cuda(), True) 69 | 70 | if torch.cuda.is_available(): 71 | torch.cuda.synchronize() 72 | elapsed = time.perf_counter() - start_time 73 | 74 | if i >= num_warmup: 75 | pure_inf_time += elapsed 76 | if (i + 1) % args.log_interval == 0: 77 | fps = (i + 1 - num_warmup) / pure_inf_time 78 | print(f'Done image [{i + 1:<3}/ {total_iters}], ' 79 | f'fps: {fps:.2f} img / s') 80 | 81 | if (i + 1) == total_iters: 82 | fps = (i + 1 - num_warmup) / pure_inf_time 83 | print(f'Overall fps: {fps:.2f} img / s\n') 84 | benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2) 85 | overall_fps_list.append(fps) 86 | break 87 | print(overall_fps_list) 88 | benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2) 89 | benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4) 90 | print(f'Average fps of {repeat_times} evaluations: ' 91 | f'{benchmark_dict["average_fps"]}') 92 | print(f'The variance of {repeat_times} evaluations: ' 93 | f'{benchmark_dict["fps_variance"]}') 94 | 95 | json_str = json.dumps(benchmark_dict, indent=0) 96 | file_name = cfg.exp_name + "/eval_metric.txt" 97 | with open(file_name, 'a') as f: 98 | f.write(json_str+'\n') 99 | 100 | if __name__ == '__main__': 101 | main() -------------------------------------------------------------------------------- /tools/cam.py: -------------------------------------------------------------------------------- 1 | import ttach as tta 2 | import argparse 3 | from pathlib import Path 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | from PIL import Image 11 | from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image 12 | from pytorch_grad_cam import GradCAM 13 | import random 14 | import os 15 | import torch.nn.functional as F 16 | 17 | import sys 18 | sys.path.append('.') 19 | from train import * 20 | from utils.config import Config 21 | 22 | class SemanticSegmentationTarget: 23 | """wrap the model. 24 | 25 | requirement: pip install grad-cam 26 | 27 | Args: 28 | category (int): Visualization class. 29 | mask (ndarray): Mask of class. 30 | size (tuple): Image size. 
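
    Example (illustrative; mirrors the usage in main() below):
        mask_float = np.float32(mask == category)
        target = SemanticSegmentationTarget(category, mask_float, (height, width))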
31 | """ 32 | 33 | def __init__(self, category, mask, size): 34 | self.category = category 35 | self.mask = torch.from_numpy(mask) 36 | self.size = size 37 | if torch.cuda.is_available(): 38 | self.mask = self.mask.to("cuda:0") 39 | 40 | def __call__(self, model_output): 41 | model_output = F.interpolate( 42 | model_output, size=self.size, mode='bilinear') 43 | model_output = torch.squeeze(model_output, dim=0) 44 | 45 | return (model_output[self.category, :, :] * self.mask).sum() 46 | 47 | def parse_args(): 48 | parser = argparse.ArgumentParser(description='rsseg: cam map') 49 | parser.add_argument("-c", "--config", type=str, default="configs/vaihingen/logcanplus.py") 50 | parser.add_argument("--ckpt", type=str, default="work_dirs/logcanplus_vaihingen/epoch=64.ckpt") 51 | parser.add_argument("--tar_layer", type=str, default="model.net.seghead.catconv2[-2]") 52 | parser.add_argument("--tar_category", type=int, default=1) 53 | parser.add_argument("--cam_output_dir", default=None) 54 | args = parser.parse_args() 55 | return args 56 | 57 | def main(): 58 | args = parse_args() 59 | cfg = Config.fromfile(args.config) 60 | cfg.dataset_config.val_mode.loader.batch_size = 1 61 | model = myTrain.load_from_checkpoint(args.ckpt, cfg = cfg) 62 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 63 | model.to(device) 64 | 65 | if args.cam_output_dir is not None: 66 | cam_output_dir = args.cam_output_dir 67 | else: 68 | cam_output_dir = cfg.exp_name + '/cam_figs/' 69 | if not os.path.exists(cam_output_dir): 70 | os.makedirs(cam_output_dir) 71 | 72 | class_name = cfg.class_name 73 | category = args.tar_category 74 | test_loader = build_dataloader(cfg.dataset_config, mode='test') 75 | model.eval() 76 | 77 | for input in tqdm(test_loader): 78 | masks, gts, img_ids = model(input[0].to(device)), input[1].to(device), input[2] 79 | masks = nn.Softmax(dim=1)(masks[0]) 80 | masks = masks.argmax(dim=1) 81 | for i in range(masks.shape[0]): 82 | mask = masks[i].cpu().numpy() 83 | gt = gts[i].cpu().numpy() 84 | mask_name = img_ids[i] 85 | 86 | tar_layer = [eval(args.tar_layer)] 87 | 88 | height, width = gt.shape[-2:] 89 | mask_float = np.float32(mask == category) 90 | 91 | targets = [ 92 | SemanticSegmentationTarget(category, mask_float, (height, width)) 93 | ] 94 | 95 | test_dataset = test_loader.dataset 96 | img_path = os.path.join(test_dataset.data_root, test_dataset.img_dir, mask_name + test_dataset.img_suffix) 97 | 98 | ori_img = cv2.imread(img_path) 99 | rgb_img = ori_img.astype(np.float32) / 255.0 100 | 101 | cam = GradCAM(model=model, target_layers=tar_layer) 102 | grayscale_cam = cam(input_tensor=input[0][i].unsqueeze(0).to(device), targets=targets)[0, :] 103 | 104 | cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True) 105 | Image.fromarray(cam_image).save(cam_output_dir + mask_name + '_{}'.format(class_name[category]) + '.png') 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | import ttach as tta 6 | import prettytable 7 | import time 8 | import glob 9 | import os 10 | import os.path as osp 11 | import multiprocessing.pool as mpp 12 | import multiprocessing as mp 13 | 14 | from train import * 15 | 16 | import argparse 17 | from utils.config import Config 18 | from 
tools.mask_convert import mask_save 19 | 20 | def get_args(): 21 | parser = argparse.ArgumentParser(description='rsseg: test model') 22 | parser.add_argument("-c", "--config", type=str, default="configs/logcan.py") 23 | parser.add_argument("--ckpt", type=str, default="work_dirs/LoGCAN_ResNet50_Loveda/epoch=45.ckpt") 24 | parser.add_argument("--tta", type=str, default="d4") 25 | parser.add_argument("--masks_output_dir", default=None) 26 | return parser.parse_args() 27 | 28 | if __name__ == "__main__": 29 | args = get_args() 30 | cfg = Config.fromfile(args.config) 31 | 32 | if args.masks_output_dir is not None: 33 | masks_output_dir = args.masks_output_dir 34 | else: 35 | masks_output_dir = cfg.exp_name + '/figs' 36 | 37 | model = myTrain.load_from_checkpoint(args.ckpt, cfg = cfg) 38 | model = model.to('cuda') 39 | 40 | model.eval() 41 | 42 | if args.tta == "lr": 43 | transforms = tta.Compose( 44 | [ 45 | tta.HorizontalFlip(), 46 | tta.VerticalFlip() 47 | ] 48 | ) 49 | model = tta.SegmentationTTAWrapper(model, transforms) 50 | elif args.tta == "d4": 51 | transforms = tta.Compose( 52 | [ 53 | tta.HorizontalFlip(), 54 | tta.VerticalFlip(), 55 | tta.Rotate90(angles=[90]), 56 | tta.Scale(scales=[0.5, 0.75, 1.0, 1.25, 1.5], interpolation='bicubic', align_corners=False) 57 | ] 58 | ) 59 | model = tta.SegmentationTTAWrapper(model, transforms) 60 | 61 | metric_cfg1 = cfg.metric_cfg1 62 | metric_cfg2 = cfg.metric_cfg2 63 | 64 | test_oa=torchmetrics.Accuracy(**metric_cfg1).to('cuda') 65 | test_prec = torchmetrics.Precision(**metric_cfg2).to('cuda') 66 | test_recall = torchmetrics.Recall(**metric_cfg2).to('cuda') 67 | test_f1 = torchmetrics.F1Score(**metric_cfg2).to('cuda') 68 | test_iou=torchmetrics.JaccardIndex(**metric_cfg2).to('cuda') 69 | 70 | results = [] 71 | mask2RGB = True 72 | with torch.no_grad(): 73 | test_loader = build_dataloader(cfg.dataset_config, mode='val') 74 | for input in tqdm(test_loader): 75 | raw_predictions, mask, img_id = model(input[0].cuda(), True), input[1].cuda(), input[2] 76 | pred = raw_predictions.argmax(dim=1) 77 | 78 | test_oa(pred, mask) 79 | test_iou(pred, mask) 80 | test_prec(pred, mask) 81 | test_f1(pred, mask) 82 | test_recall(pred, mask) 83 | 84 | for i in range(raw_predictions.shape[0]): 85 | mask_pred = pred[i].cpu().numpy() 86 | mask_name = str(img_id[i]) 87 | results.append((mask2RGB, mask_pred, cfg.dataset, masks_output_dir, mask_name)) 88 | 89 | metrics = [test_prec.compute(), 90 | test_recall.compute(), 91 | test_f1.compute(), 92 | test_iou.compute()] 93 | 94 | total_metrics = [test_oa.compute().cpu().numpy(), 95 | np.mean([item.cpu() for item in metrics[0][cfg.eval_label_id_left: cfg.eval_label_id_right] if item > 0]), 96 | np.mean([item.cpu() for item in metrics[1][cfg.eval_label_id_left: cfg.eval_label_id_right] if item > 0]), 97 | np.mean([item.cpu() for item in metrics[2][cfg.eval_label_id_left: cfg.eval_label_id_right] if item > 0]), 98 | np.mean([item.cpu() for item in metrics[3][cfg.eval_label_id_left: cfg.eval_label_id_right] if item > 0])] 99 | 100 | result_table = prettytable.PrettyTable() 101 | result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU'] 102 | 103 | for i in range(len(metrics[0])): 104 | item = [i, '--'] 105 | for j in range(len(metrics)): 106 | item.append(np.round(metrics[j][i].cpu().numpy(), 4)) 107 | result_table.add_row(item) 108 | 109 | total = [np.round(v, 4) for v in total_metrics] 110 | total.insert(0, 'Total') 111 | result_table.add_row(total) 112 | 113 | print(result_table) 114 | 115 | 
file_name = cfg.exp_name + "/eval_metric.txt" 116 | f = open(file_name,"a") 117 | current_time = time.strftime('%Y_%m_%d %H:%M:%S',time.localtime(time.time())) 118 | f.write(current_time+' test\n') 119 | f.write(str(result_table)+'\n') 120 | 121 | if not os.path.exists(masks_output_dir): 122 | os.makedirs(masks_output_dir) 123 | print("masks_save_dir: ", masks_output_dir) 124 | 125 | t0 = time.time() 126 | mpp.Pool(processes=mp.cpu_count()).map(mask_save, results) 127 | t1 = time.time() 128 | img_write_time = t1 - t0 129 | print('images writing spends: {} s'.format(img_write_time)) -------------------------------------------------------------------------------- /tools/mask_convert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import glob 4 | import os 5 | import sys 6 | import torch 7 | import cv2 8 | import random 9 | import time 10 | import multiprocessing.pool as mpp 11 | import multiprocessing as mp 12 | SEED = 66 13 | 14 | def seed_everything(seed): 15 | random.seed(seed) 16 | os.environ['PYTHONHASHSEED'] = str(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = True 22 | def vaihingen_label2rgb(mask): 23 | h, w = mask.shape[0], mask.shape[1] 24 | mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8) 25 | mask_convert = mask[np.newaxis, :, :] 26 | mask_rgb[np.all(mask_convert == 0, axis=0)] = [255, 255, 255] 27 | mask_rgb[np.all(mask_convert == 1, axis=0)] = [0, 0, 255] 28 | mask_rgb[np.all(mask_convert == 2, axis=0)] = [0, 255, 255] 29 | mask_rgb[np.all(mask_convert == 3, axis=0)] = [0, 255, 0] 30 | mask_rgb[np.all(mask_convert == 4, axis=0)] = [255, 255, 0] 31 | mask_rgb[np.all(mask_convert == 5, axis=0)] = [255, 0, 0] 32 | return mask_rgb 33 | 34 | def loveda_label2rgb(mask): 35 | h, w = mask.shape[0], mask.shape[1] 36 | mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8) 37 | mask_convert = mask[np.newaxis, :, :] 38 | mask_rgb[np.all(mask_convert == 0, axis=0)] = [255, 0, 0] 39 | mask_rgb[np.all(mask_convert == 1, axis=0)] = [255, 255, 0] 40 | mask_rgb[np.all(mask_convert == 2, axis=0)] = [0, 0, 255] 41 | mask_rgb[np.all(mask_convert == 3, axis=0)] = [159, 129, 183] 42 | mask_rgb[np.all(mask_convert == 4, axis=0)] = [0, 255, 0] 43 | mask_rgb[np.all(mask_convert == 5, axis=0)] = [255, 195, 128] 44 | mask_rgb[np.all(mask_convert == 6, axis=0)] = [255, 255, 255] 45 | return mask_rgb 46 | 47 | def uavid_label2rgb(mask): 48 | h, w = mask.shape[0], mask.shape[1] 49 | mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8) 50 | mask_convert = mask[np.newaxis, :, :] 51 | mask_rgb[np.all(mask_convert == 0, axis=0)] = [128, 0, 0] 52 | mask_rgb[np.all(mask_convert == 1, axis=0)] = [128, 64, 128] 53 | mask_rgb[np.all(mask_convert == 2, axis=0)] = [0, 128, 0] 54 | mask_rgb[np.all(mask_convert == 3, axis=0)] = [128, 128, 0] 55 | mask_rgb[np.all(mask_convert == 4, axis=0)] = [64, 0, 128] 56 | mask_rgb[np.all(mask_convert == 5, axis=0)] = [192, 0, 192] 57 | mask_rgb[np.all(mask_convert == 6, axis=0)] = [64, 64, 0] 58 | mask_rgb[np.all(mask_convert == 7, axis=0)] = [0, 0, 0] 59 | return mask_rgb 60 | 61 | def potsdam_label2rgb(mask): 62 | h, w = mask.shape[0], mask.shape[1] 63 | mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8) 64 | mask_convert = mask[np.newaxis, :, :] 65 | 66 | mask_rgb[np.all(mask_convert == 0, axis=0)] = [255, 255, 255] 67 | mask_rgb[np.all(mask_convert == 1, 
axis=0)] = [0, 0, 255]
68 |     mask_rgb[np.all(mask_convert == 2, axis=0)] = [0, 255, 255]  # LowVeg, matching the dataset color map
69 |     mask_rgb[np.all(mask_convert == 3, axis=0)] = [0, 255, 0]  # Tree
70 |     mask_rgb[np.all(mask_convert == 4, axis=0)] = [255, 255, 0]
71 |     mask_rgb[np.all(mask_convert == 5, axis=0)] = [255, 0, 0]
72 |     return mask_rgb
73 | 
74 | def parse_args():
75 |     parser = argparse.ArgumentParser()
76 |     parser.add_argument("--dataset", default="Vaihingen")
77 |     parser.add_argument("--mask-dir", default="data/Test/masks")
78 |     parser.add_argument("--output-mask-dir", default="data/Test/masks_rgb")
79 |     return parser.parse_args()
80 | 
81 | def mask_save(inp):
82 |     (mask2RGB, mask, dataset_type, masks_output_dir, file_name) = inp
83 |     out_mask_path = os.path.join(masks_output_dir, "{}.png".format(file_name))
84 |     if mask2RGB:
85 |         if dataset_type == "loveda":
86 |             label = loveda_label2rgb(mask.copy())
87 |         elif dataset_type == "vaihingen":
88 |             label = vaihingen_label2rgb(mask.copy())
89 |         elif dataset_type == "potsdam":
90 |             label = potsdam_label2rgb(mask.copy())
91 |         elif dataset_type == "uavid":
92 |             label = uavid_label2rgb(mask.copy())
93 |         else: raise AttributeError(f"dataset type {dataset_type} does not exist")
94 | 
95 |         rgb_label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB)  # swap to BGR channel order for cv2.imwrite
96 |         cv2.imwrite(out_mask_path, rgb_label)
97 |     else:
98 |         cv2.imwrite(out_mask_path, mask)
99 | 
100 | def get_rgb(inp):
101 |     (mask_path, masks_output_dir, dataset) = inp
102 |     mask_filename = os.path.splitext(os.path.basename(mask_path))[0]
103 |     mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
104 |     # label maps are stored single-channel, so no BGR->RGB conversion is needed here
105 |     if dataset == "LoveDA":
106 |         rgb_label = loveda_label2rgb(mask.copy())
107 |     elif dataset == "Vaihingen":
108 |         rgb_label = vaihingen_label2rgb(mask.copy())
109 |     elif dataset == "Potsdam":
110 |         rgb_label = potsdam_label2rgb(mask.copy())
111 |     elif dataset == "uavid":
112 |         rgb_label = uavid_label2rgb(mask.copy())
113 |     else: return
114 |     # rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_RGB2BGR)
115 | 
116 |     out_mask_path_rgb = os.path.join(masks_output_dir, "{}.png".format(mask_filename))
117 |     rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_BGR2RGB)  # swap to BGR for cv2.imwrite
118 |     cv2.imwrite(out_mask_path_rgb, rgb_label)
119 | 
120 | if __name__ == '__main__':
121 |     base_path = "/home/xwma/lrr/rssegmentation/"
122 |     args = parse_args()
123 |     dataset = args.dataset
124 | 
125 |     seed_everything(SEED)
126 |     masks_dir = args.mask_dir
127 |     masks_output_dir = args.output_mask_dir
128 |     masks_dir = base_path + masks_dir
129 |     masks_output_dir = base_path + masks_output_dir
130 | 
131 |     mask_paths = glob.glob(os.path.join(masks_dir, "*.png"))
132 |     inp = [(mask_path, masks_output_dir, dataset) for mask_path in mask_paths]
133 |     if not os.path.exists(masks_output_dir):
134 |         os.makedirs(masks_output_dir)
135 | 
136 |     t0 = time.time()
137 |     mpp.Pool(processes=mp.cpu_count()).map(get_rgb, inp)
138 |     t1 = time.time()
139 |     split_time = t1 - t0
140 |     print('mask conversion took: {} s'.format(split_time))
141 | 
142 | 
--------------------------------------------------------------------------------
/rsseg/models/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | import os
5 | 
6 | model_urls = {
7 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
8 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
9 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
10 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
11 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
12 |     'resnet18stem': 'https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth',
13 |     'resnet50stem': 'https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth',
14 |     'resnet101stem': 'https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth',
15 | }
16 | 
17 | def conv3x3(in_planes, outplanes, stride=1):
18 |     # 3x3 convolution with padding
19 |     return nn.Conv2d(in_planes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
20 | 
21 | class Bottleneck(nn.Module):
22 |     expansion = 4
23 |     def __init__(self, in_planes, planes, stride=1, dilation=1, downsample=None):
24 |         super(Bottleneck, self).__init__()
25 |         self.conv1 = nn.Conv2d(in_planes, planes, 1, bias=False)
26 |         self.bn1 = nn.BatchNorm2d(planes)
27 |         self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride, padding=dilation,
28 |                                dilation=dilation, bias=False)
29 |         self.bn2 = nn.BatchNorm2d(planes)
30 |         self.conv3 = nn.Conv2d(planes, planes*self.expansion, 1, bias=False)
31 |         self.bn3 = nn.BatchNorm2d(planes*self.expansion)
32 | 
33 |         self.relu = nn.ReLU(inplace=False)
34 |         self.relu_inplace = nn.ReLU(inplace=True)
35 | 
36 |         self.downsample = downsample
37 |         self.dilation = dilation
38 |         self.stride = stride
39 | 
40 |     def forward(self, x):
41 |         residual = x
42 | 
43 |         out = self.relu(self.bn1(self.conv1(x)))
44 |         out = self.relu(self.bn2(self.conv2(out)))
45 |         out = self.bn3(self.conv3(out))
46 | 
47 |         if self.downsample is not None:
48 |             residual = self.downsample(x)
49 | 
50 |         out = out + residual
51 |         out = self.relu_inplace(out)
52 | 
53 |         return out
54 | 
55 | class Resnet(nn.Module):
56 |     def __init__(self, block, layers, out_stride=8, use_stem=False, stem_channels=64, in_channels=3):
57 |         self.inplanes = 64
58 |         super(Resnet, self).__init__()
59 |         outstride_to_strides_and_dilations = {
60 |             8: ((1, 2, 1, 1), (1, 1, 2, 4)),
61 |             16: ((1, 2, 2, 1), (1, 1, 1, 2)),
62 |             32: ((1, 2, 2, 2), (1, 1, 1, 1)),
63 |         }
64 |         stride_list, dilation_list = outstride_to_strides_and_dilations[out_stride]
65 | 
66 |         self.use_stem = use_stem
67 |         if use_stem:
68 |             self.stem = nn.Sequential(
69 |                 conv3x3(in_channels, stem_channels//2, stride=2),
70 |                 nn.BatchNorm2d(stem_channels//2),
71 |                 nn.ReLU(inplace=False),
72 | 
73 |                 conv3x3(stem_channels//2, stem_channels//2),
74 |                 nn.BatchNorm2d(stem_channels//2),
75 |                 nn.ReLU(inplace=False),
76 | 
77 |                 conv3x3(stem_channels//2, stem_channels),
78 |                 nn.BatchNorm2d(stem_channels),
79 |                 nn.ReLU(inplace=False)
80 |             )
81 |         else:
82 |             self.conv1 = nn.Conv2d(in_channels, stem_channels, kernel_size=7, stride=2, padding=3, bias=False)
83 |             self.bn1 = nn.BatchNorm2d(stem_channels)
84 |             self.relu = nn.ReLU(inplace=False)
85 | 
86 |         # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
87 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
88 | 
89 |         self.layer1 = self._make_layer(block, 64, blocks=layers[0], stride=stride_list[0], dilation=dilation_list[0])
90 |         self.layer2 = self._make_layer(block, 128, blocks=layers[1], stride=stride_list[1], dilation=dilation_list[1])
91 |         self.layer3 = self._make_layer(block, 256, blocks=layers[2], stride=stride_list[2], dilation=dilation_list[2])
92 |         self.layer4 = self._make_layer(block, 512, blocks=layers[3], stride=stride_list[3], dilation=dilation_list[3])
93 | 
94 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 
contract_dilation=True): 95 | downsample = None 96 | dilations = [dilation] * blocks 97 | 98 | if contract_dilation and dilation > 1: dilations[0] = dilation // 2 99 | if stride != 1 or self.inplanes != planes * block.expansion: 100 | downsample = nn.Sequential( 101 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False), 102 | nn.BatchNorm2d(planes*block.expansion) 103 | ) 104 | 105 | layers = [] 106 | layers.append(block(self.inplanes, planes, stride, dilation=dilations[0], downsample=downsample)) 107 | self.inplanes = planes * block.expansion 108 | 109 | for i in range(1, blocks): 110 | layers.append(block(self.inplanes, planes, dilation=dilations[i])) 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | if self.use_stem: 115 | x = self.stem(x) 116 | else: 117 | x = self.relu(self.bn1(self.conv1(x))) 118 | 119 | x = self.maxpool(x) 120 | 121 | x1 = self.layer1(x) 122 | x2 = self.layer2(x1) 123 | x3 = self.layer3(x2) 124 | x4 = self.layer4(x3) 125 | 126 | outs = [x1, x2, x3, x4] 127 | 128 | return tuple(outs) 129 | 130 | def get_resnet50_OS8(pretrained=True): 131 | model = Resnet(Bottleneck, [3, 4, 6, 3], out_stride=8, use_stem=True) 132 | if pretrained: 133 | checkpoint = model_zoo.load_url(model_urls['resnet50stem']) 134 | if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] 135 | else: state_dict = checkpoint 136 | model.load_state_dict(state_dict, strict=False) 137 | return model 138 | 139 | def get_resnet50_OS32(pretrained=True): 140 | model = Resnet(Bottleneck, [3, 4, 6, 3], out_stride=32, use_stem=False) 141 | if pretrained: 142 | checkpoint = model_zoo.load_url(model_urls['resnet50']) 143 | if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] 144 | else: state_dict = checkpoint 145 | model.load_state_dict(state_dict, strict=False) 146 | return model 147 | 148 | if __name__ == "__main__": 149 | model = get_resnet50_OS32() 150 | x = torch.randn(2, 3, 64, 64) 151 | x = model(x)[-1] 152 | print(x.shape) 153 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 5 | from pytorch_lightning import LightningModule, Trainer, seed_everything 6 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, TQDMProgressBar 7 | from pytorch_lightning.loggers import TensorBoardLogger 8 | import torchmetrics 9 | import prettytable 10 | import numpy as np 11 | 12 | import argparse 13 | from rsseg.models.build_model import build_model 14 | from rsseg.datasets import build_dataloader 15 | from rsseg.optimizers import build_optimizer 16 | from rsseg.losses import build_loss 17 | from utils.config import Config 18 | 19 | seed_everything(2025, workers=True) 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser('rsseg: train model') 23 | parser.add_argument("-c", "--config", type=str, default="configs/docnet.py") 24 | return parser.parse_args() 25 | 26 | class myTrain(LightningModule): 27 | def __init__(self, cfg): 28 | super(myTrain, self).__init__() 29 | 30 | self.cfg = cfg 31 | self.net = build_model(cfg.model_config) 32 | self.loss = build_loss(cfg.loss_config) 33 | 34 | self.loss.to("cuda") 35 | self.eval_label_id_left = cfg.eval_label_id_left 36 | self.eval_label_id_right = cfg.eval_label_id_right 37 | 38 | 
        metric_cfg1 = cfg.metric_cfg1
39 |         metric_cfg2 = cfg.metric_cfg2
40 | 
41 | 
42 |         self.tr_oa = torchmetrics.Accuracy(**metric_cfg1)
43 |         self.tr_prec = torchmetrics.Precision(**metric_cfg2)
44 |         self.tr_recall = torchmetrics.Recall(**metric_cfg2)
45 |         self.tr_f1 = torchmetrics.F1Score(**metric_cfg2)
46 |         self.tr_iou = torchmetrics.JaccardIndex(**metric_cfg2)
47 | 
48 |         self.val_oa = torchmetrics.Accuracy(**metric_cfg1)
49 |         self.val_prec = torchmetrics.Precision(**metric_cfg2)
50 |         self.val_recall = torchmetrics.Recall(**metric_cfg2)
51 |         self.val_f1 = torchmetrics.F1Score(**metric_cfg2)
52 |         self.val_iou = torchmetrics.JaccardIndex(**metric_cfg2)
53 | 
54 |     def forward(self, x, test=False):
55 |         pred = self.net(x)
56 |         if test:
57 |             return pred[0]
58 |         return pred
59 | 
60 |     def configure_optimizers(self):
61 |         optimizer, scheduler = build_optimizer(self.cfg.optimizer_config, self.net)
62 |         return {'optimizer': optimizer, 'lr_scheduler': scheduler}
63 | 
64 |     def train_dataloader(self):
65 |         loader = build_dataloader(self.cfg.dataset_config, mode='train')
66 |         return loader
67 | 
68 |     def val_dataloader(self):
69 |         loader = build_dataloader(self.cfg.dataset_config, mode='val')
70 |         return loader
71 | 
72 |     def output(self, metrics, total_metrics, mode):
73 |         result_table = prettytable.PrettyTable()
74 |         result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']
75 | 
76 |         for i in range(len(metrics[0])):
77 |             item = [self.cfg.class_name[i], '--']
78 |             for j in range(len(metrics)):
79 |                 item.append(np.round(metrics[j][i].cpu().numpy(), 4))
80 |             result_table.add_row(item)
81 | 
82 |         total = list(total_metrics.values())
83 |         total = [np.round(v, 4) for v in total]
84 |         total.insert(0, 'total')
85 |         result_table.add_row(total)
86 | 
87 |         print(result_table)
88 | 
89 |         file_name = self.cfg.exp_name + "/train_metric.txt"
90 |         f = open(file_name, "a")
91 |         f.write('epoch:{}/{} {}\n'.format(self.current_epoch, self.cfg.epoch, mode))
92 |         f.write(str(result_table) + '\n')
93 |         f.close()
94 | 
95 |     def training_step(self, batch, batch_idx):
96 |         image, mask = batch[0], batch[1]
97 |         preds = self(image)
98 |         all_loss = self.loss(preds, mask)
99 | 
100 |         pred = preds[0].argmax(dim=1)
101 | 
102 |         self.tr_oa(pred, mask)
103 |         self.tr_prec(pred, mask)
104 |         self.tr_recall(pred, mask)
105 |         self.tr_f1(pred, mask)
106 |         self.tr_iou(pred, mask)
107 | 
108 |         for loss_name in all_loss:
109 |             self.log(loss_name, all_loss[loss_name], on_step=False, on_epoch=True, prog_bar=True)
110 |         return all_loss['total_loss']
111 | 
112 |     def on_train_epoch_end(self):
113 |         metrics = [self.tr_prec.compute(),
114 |                    self.tr_recall.compute(),
115 |                    self.tr_f1.compute(),
116 |                    self.tr_iou.compute()]
117 | 
118 |         log = {'tr_oa': float(self.tr_oa.compute().cpu()),
119 |                'tr_prec': np.mean([item.cpu() for item in metrics[0][self.eval_label_id_left: self.eval_label_id_right] if item > 0]),
120 |                'tr_recall': np.mean([item.cpu() for item in metrics[1][self.eval_label_id_left: self.eval_label_id_right] if item > 0]),
121 |                'tr_f1': np.mean([item.cpu() for item in metrics[2][self.eval_label_id_left: self.eval_label_id_right] if item > 0]),
122 |                'tr_miou': np.mean([item.cpu() for item in metrics[3][self.eval_label_id_left: self.eval_label_id_right] if item > 0])}
123 | 
124 |         # self.output(metrics, log, 'train')
125 | 
126 |         for key, value in zip(log.keys(), log.values()):
127 |             self.log(key, value, on_step=False, on_epoch=True, prog_bar=False)
128 | 
129 |         self.tr_oa.reset()
130 |         self.tr_prec.reset()
131 |         self.tr_recall.reset()
132 | 
self.tr_f1.reset() 133 | self.tr_iou.reset() 134 | 135 | def validation_step(self, batch, batch_idx): 136 | image, mask = batch[0], batch[1] 137 | preds = self(image) 138 | all_loss = self.loss(preds, mask) 139 | 140 | pred = preds[0].argmax(dim=1) 141 | 142 | self.val_oa(pred, mask) 143 | self.val_prec(pred, mask) 144 | self.val_recall(pred, mask) 145 | self.val_f1(pred, mask) 146 | self.val_iou(pred, mask) 147 | 148 | for loss_name in all_loss: 149 | self.log(loss_name, all_loss[loss_name], on_step=False,on_epoch=True,prog_bar=True) 150 | return all_loss['total_loss'] 151 | 152 | def on_validation_epoch_end(self): 153 | metrics = [self.val_prec.compute(), 154 | self.val_recall.compute(), 155 | self.val_f1.compute(), 156 | self.val_iou.compute()] 157 | 158 | log = {'val_oa': float(self.val_oa.compute().cpu()), 159 | 'val_prec': np.mean([item.cpu() for item in metrics[0][self.eval_label_id_left: self.eval_label_id_right] if item > 0]), 160 | 'val_recall': np.mean([item.cpu() for item in metrics[1][self.eval_label_id_left: self.eval_label_id_right] if item > 0]), 161 | 'val_f1': np.mean([item.cpu() for item in metrics[2][self.eval_label_id_left: self.eval_label_id_right] if item > 0]), 162 | 'val_miou': np.mean([item.cpu() for item in metrics[3][self.eval_label_id_left: self.eval_label_id_right] if item > 0])} 163 | 164 | self.output(metrics, log, 'val') 165 | 166 | for key, value in zip(log.keys(), log.values()): 167 | self.log(key, value, on_step=False, on_epoch=True, prog_bar=False) 168 | 169 | self.val_oa.reset() 170 | self.val_prec.reset() 171 | self.val_recall.reset() 172 | self.val_f1.reset() 173 | self.val_iou.reset() 174 | 175 | if __name__ == "__main__": 176 | args = get_args() 177 | cfg = Config.fromfile(args.config) 178 | print(cfg) 179 | model = myTrain(cfg) 180 | 181 | 182 | lr_monitor=LearningRateMonitor(logging_interval = cfg.logging_interval) 183 | 184 | ckpt_cb = ModelCheckpoint(dirpath = cfg.exp_name, 185 | filename = '{epoch:d}', 186 | monitor = cfg.monitor, 187 | mode = 'max', 188 | save_top_k = cfg.save_top_k) 189 | 190 | pbar = TQDMProgressBar(refresh_rate=1) 191 | 192 | callbacks = [ckpt_cb, pbar, lr_monitor] 193 | 194 | logger = TensorBoardLogger(save_dir = "", 195 | name = cfg.exp_name, 196 | default_hp_metric = False) 197 | 198 | 199 | trainer = Trainer(max_epochs = cfg.epoch, 200 | # precision='16-mixed', 201 | callbacks = callbacks, 202 | logger = logger, 203 | enable_model_summary = True, 204 | accelerator = 'auto', 205 | devices = cfg.gpus, 206 | num_sanity_val_steps = 2, 207 | benchmark = True) 208 | 209 | trainer.fit(model, ckpt_path=cfg.resume_ckpt_path) -------------------------------------------------------------------------------- /utils/registry.py: -------------------------------------------------------------------------------- 1 | # Code simplified from https://github.com/open-mmlab/mmengine/blob/main/mmengine/registry/registry.py 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | 4 | from collections.abc import Callable 5 | from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union 6 | from rich.table import Table 7 | from rich.console import Console 8 | 9 | 10 | class Registry: 11 | """A registry to map strings to classes or functions. 12 | 13 | Registered object could be built from registry. Meanwhile, registered 14 | functions could be called from registry. 
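
    Example (illustrative sketch, verifiable against the methods below):
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        ... class ResNet:
        ...     pass
        >>> 'ResNet' in MODELS.module_dict
        True
        >>> len(MODELS)
        1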
    """

    def __init__(self,
                 name: str,
                 build_func: Optional[Callable] = None):
        self._name = name
        self._module_dict: Dict[str, Type] = dict()
        self._imported = False

        self.build_func: Callable
        if build_func is None:
            self.build_func = build_from_cfg
        else:
            self.build_func = build_func

    def __len__(self):
        return len(self._module_dict)

    def __repr__(self):
        table = Table(title=f'Registry of {self._name}')
        table.add_column('Names', justify='left', style='cyan')
        table.add_column('Objects', justify='left', style='green')

        for name, obj in sorted(self._module_dict.items()):
            table.add_row(name, str(obj))

        console = Console()
        with console.capture() as capture:
            console.print(table, end='')

        return capture.get()

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def build(self, cfg: dict, *args, **kwargs) -> Any:
        """Build an instance by calling :attr:`build_func`."""
        return self.build_func(cfg, *args, **kwargs, registry=self)

    def _register_module(self,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None) -> None:
        """Register a module.

        Args:
            module (type): Module to be registered. Typically a class or a
                function, but any ``Callable`` is acceptable.
            module_name (str or list of str, optional): The name(s) under
                which the module is registered. If not specified, the class
                name is used. Defaults to None.
        """
        if not callable(module):
            raise TypeError(f'module must be Callable, but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if name in self._module_dict:
                existed_module = self.module_dict[name]
                raise KeyError(f'{name} is already registered in {self.name} '
                               f'at {existed_module.__module__}')
            self._module_dict[name] = module

    def register_module(
            self,
            name: Optional[Union[str, List[str]]] = None,
            module: Optional[Type] = None) -> Union[type, Callable]:
        """Register a module.

        A record is added to ``self._module_dict`` whose key is the class
        name (or the specified name) and whose value is the class itself.
        It can be used as a decorator or as a normal function.
        """

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or
                (isinstance(name, list) and all(isinstance(n, str) for n in name))):
            raise TypeError(
                'name must be None, a str or a list of str, '
                f'but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name)
            return module

        return _register


def build_from_cfg(
        cfg: dict,
        registry: Registry,
        default_args: Optional[dict] = None) -> Any:
    """Build a module from a config dict when it is a class configuration,
    or call a function from a config dict when it is a function
    configuration.

    Unlike the mmengine original, this simplified version only accepts
    plain dicts and has no scope switching.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg should be a dict, but got {type(cfg)}')

    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')

    if not isinstance(registry, Registry):
        raise TypeError('registry must be a Registry object, '
                        f'but got {type(registry)}')

    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args should be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        # this simplified Registry has no scopes, so look the name up
        # directly in the module dict
        obj_cls = registry.module_dict.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry. '
                f'Please check whether the value of `{obj_type}` is '
                'correct or it was registered as expected.')
    # this will include classes, functions, partial functions and more
    elif callable(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        obj = obj_cls(**args)
        logging.debug(
            'An instance is built from registry, and its constructor is %s',
            obj_cls)
        return obj
    except Exception as e:
        # A plain TypeError does not mention the class name, so re-raise
        # with the offending class and its location attached.
        cls_location = '/'.join(obj_cls.__module__.split('.'))
        raise type(e)(
            f'class `{obj_cls.__name__}` in {cls_location}.py: {e}')
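
# A minimal usage sketch of build_from_cfg (the names here are hypothetical,
# not part of this repository):
#
#   LOSSES = Registry('losses')
#
#   @LOSSES.register_module()
#   class DiceLoss:
#       def __init__(self, smooth=1.0):
#           self.smooth = smooth
#
#   # `default_args` fills in keys that are missing from cfg:
#   loss = build_from_cfg(dict(type='DiceLoss'), LOSSES,
#                         default_args=dict(smooth=0.5))
#   assert loss.smooth == 0.5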
--------------------------------------------------------------------------------
/rsseg/models/segheads/logcan_head.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
import torch.nn.functional as F


def conv_3x3(in_channel, out_channel):
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True)
    )


def patch_split(input, patch_size):
    """
    input: (B, C, H, W)
    output: (B*num_h*num_w, C, patch_h, patch_w)
    """
    B, C, H, W = input.size()
    num_h, num_w = patch_size
    patch_h, patch_w = H // num_h, W // num_w
    out = input.view(B, C, num_h, patch_h, num_w, patch_w)
    # (B*num_h*num_w, C, patch_h, patch_w)
    out = out.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, patch_h, patch_w)
    return out


def patch_recover(input, patch_size):
    """
    input: (B*num_h*num_w, C, patch_h, patch_w)
    output: (B, C, H, W)
    """
    N, C, patch_h, patch_w = input.size()
    num_h, num_w = patch_size
    H, W = num_h * patch_h, num_w * patch_w
    B = N // (num_h * num_w)

    out = input.view(B, num_h, num_w, C, patch_h, patch_w)
    out = out.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, C, H, W)
    return out
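
# A quick shape sketch of the two helpers above (sizes are illustrative):
#
#   x = torch.randn(2, 128, 32, 32)   # (B, C, H, W)
#   p = patch_split(x, (4, 4))        # -> (32, 128, 8, 8)
#   y = patch_recover(p, (4, 4))      # -> (2, 128, 32, 32)
#   assert torch.equal(x, y)          # the two ops are exact inverses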

class SelfAttentionBlock(nn.Module):
    """
    query_feats: (B*num_h*num_w, C, patch_h, patch_w)
    key_feats: (B*num_h*num_w, C, K, 1)
    value_feats: (B*num_h*num_w, C, K, 1)

    output: (B*num_h*num_w, C, patch_h, patch_w)
    """

    def __init__(self, key_in_channels, query_in_channels, transform_channels, out_channels,
                 key_query_num_convs, value_out_num_convs):
        super(SelfAttentionBlock, self).__init__()
        self.key_project = self.buildproject(
            in_channels=key_in_channels,
            out_channels=transform_channels,
            num_convs=key_query_num_convs,
        )
        self.query_project = self.buildproject(
            in_channels=query_in_channels,
            out_channels=transform_channels,
            num_convs=key_query_num_convs
        )
        self.value_project = self.buildproject(
            in_channels=key_in_channels,
            out_channels=transform_channels,
            num_convs=value_out_num_convs
        )
        self.out_project = self.buildproject(
            in_channels=transform_channels,
            out_channels=out_channels,
            num_convs=value_out_num_convs
        )
        self.transform_channels = transform_channels

    def forward(self, query_feats, key_feats, value_feats):
        batch_size = query_feats.size(0)

        query = self.query_project(query_feats)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()  # (B*num_h*num_w, patch_h*patch_w, C)

        key = self.key_project(key_feats)
        key = key.reshape(*key.shape[:2], -1)  # (B*num_h*num_w, C, K)

        value = self.value_project(value_feats)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()  # (B*num_h*num_w, K, C)

        # scaled dot-product attention over the K class centers
        sim_map = torch.matmul(query, key)
        sim_map = (self.transform_channels ** -0.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)  # (B*num_h*num_w, patch_h*patch_w, K)

        context = torch.matmul(sim_map, value)  # (B*num_h*num_w, patch_h*patch_w, C)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])  # (B*num_h*num_w, C, patch_h, patch_w)

        context = self.out_project(context)  # (B*num_h*num_w, C, patch_h, patch_w)
        return context

    def buildproject(self, in_channels, out_channels, num_convs):
        # note: nn.Sequential.append requires a reasonably recent PyTorch
        convs = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        for _ in range(num_convs - 1):
            convs.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )
        if len(convs) > 1:
            return nn.Sequential(*convs)
        return convs[0]


class SpatialGatherModule(nn.Module):
    def __init__(self, scale=1):
        super(SpatialGatherModule, self).__init__()
        self.scale = scale

    def forward(self, features, probs):
        batch_size, num_classes, h, w = probs.size()
        probs = probs.view(batch_size, num_classes, -1)  # batch * k * hw
        probs = F.softmax(self.scale * probs, dim=2)

        features = features.view(batch_size, features.size(1), -1)
        features = features.permute(0, 2, 1)  # batch * hw * c

        ocr_context = torch.matmul(probs, features)  # (B, k, c)
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(-1)  # (B, C, K, 1)
        return ocr_context
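
# Shape sketch for SpatialGatherModule (sizes are illustrative):
#
#   feats = torch.randn(2, 128, 64, 64)            # (B, C, H, W)
#   probs = torch.randn(2, 6, 64, 64)              # (B, K, H, W) class logits
#   centers = SpatialGatherModule()(feats, probs)  # -> (2, 128, 6, 1)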

class MRAM(nn.Module):
    """
    feat: (B, C, H, W)
    global_center: (B, C, K, 1)
    """

    def __init__(self, in_channels, inner_channels, num_class, patch_size=(4, 4)):
        super(MRAM, self).__init__()
        self.patch_size = patch_size
        self.feat_decoder = nn.Conv2d(in_channels, num_class, kernel_size=1)

        self.correlate_net = SelfAttentionBlock(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            transform_channels=inner_channels,
            out_channels=in_channels,
            key_query_num_convs=2,
            value_out_num_convs=1
        )

        self.get_center = SpatialGatherModule()

        self.cat_conv = nn.Sequential(
            conv_3x3(in_channels * 2, in_channels),
            nn.Dropout2d(0.1),
            conv_3x3(in_channels, in_channels),
            nn.Dropout2d(0.1)
        )

    def forward(self, feat, global_center):
        pred = self.feat_decoder(feat)
        patch_feat = patch_split(feat, self.patch_size)  # (B*num_h*num_w, C, patch_h, patch_w)
        patch_pred = patch_split(pred, self.patch_size)  # (B*num_h*num_w, K, patch_h, patch_w)
        local_center = self.get_center(patch_feat, patch_pred)  # (B*num_h*num_w, C, K, 1)
        num_h, num_w = self.patch_size
        global_center = global_center.repeat(num_h * num_w, 1, 1, 1)

        new_feat = self.correlate_net(patch_feat, local_center, global_center)  # (B*num_h*num_w, C, patch_h, patch_w)
        new_feat = patch_recover(new_feat, self.patch_size)  # (B, C, H, W)
        out = self.cat_conv(torch.cat([feat, new_feat], dim=1))

        return out


def upsample_add(x_small, x_big):
    # despite the name, this fuses by 2x upsampling followed by channel
    # concatenation, so the output has twice the channels of x_small
    x_small = F.interpolate(x_small, scale_factor=2, mode="bilinear", align_corners=False)
    return torch.cat([x_small, x_big], dim=1)


class LoGCAN_Head(nn.Module):
    def __init__(self, in_channel=[256, 512, 1024, 2048], transform_channel=128, num_class=6):
        super(LoGCAN_Head, self).__init__()
        self.bottleneck1 = conv_3x3(in_channel[0], transform_channel)
        self.bottleneck2 = conv_3x3(in_channel[1], transform_channel)
        self.bottleneck3 = conv_3x3(in_channel[2], transform_channel)
        self.bottleneck4 = conv_3x3(in_channel[3], transform_channel)

        self.decoder_stage1 = nn.Conv2d(transform_channel, num_class, kernel_size=1)
        self.global_gather = SpatialGatherModule()

        self.center1 = MRAM(transform_channel, transform_channel // 2, num_class)
        self.center2 = MRAM(transform_channel, transform_channel // 2, num_class)
        self.center3 = MRAM(transform_channel, transform_channel // 2, num_class)
        self.center4 = MRAM(transform_channel, transform_channel // 2, num_class)

        self.catconv1 = conv_3x3(transform_channel * 2, transform_channel)
        self.catconv2 = conv_3x3(transform_channel * 2, transform_channel)
        self.catconv3 = conv_3x3(transform_channel * 2, transform_channel)

        self.catconv = conv_3x3(transform_channel, transform_channel)

    def forward(self, x_list):
        feat1, feat2, feat3, feat4 = self.bottleneck1(x_list[0]), self.bottleneck2(x_list[1]), self.bottleneck3(x_list[2]), self.bottleneck4(x_list[3])
        pred1 = self.decoder_stage1(feat4)

        global_center = self.global_gather(feat4, pred1)

        new_feat4 = self.center4(feat4, global_center)

        feat3 = self.catconv1(upsample_add(new_feat4, feat3))
        new_feat3 = self.center3(feat3, global_center)

        feat2 = self.catconv2(upsample_add(new_feat3, feat2))
        new_feat2 = self.center2(feat2, global_center)

        feat1 = self.catconv3(upsample_add(new_feat2, feat1))
        new_feat1 = self.center1(feat1, global_center)

        new_feat4 = F.interpolate(new_feat4, scale_factor=8, mode="bilinear", align_corners=False)
        new_feat3 = F.interpolate(new_feat3, scale_factor=4, mode="bilinear", align_corners=False)
        new_feat2 = F.interpolate(new_feat2, scale_factor=2, mode="bilinear", align_corners=False)

        out = self.catconv(new_feat1 + new_feat2 + new_feat3 + new_feat4)

        return [out, pred1]
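
# Rough usage sketch with a ResNet-50-style feature pyramid (channel widths
# follow the defaults above; batch and spatial sizes are illustrative):
#
#   head = LoGCAN_Head(in_channel=[256, 512, 1024, 2048], transform_channel=128, num_class=6)
#   feats = [torch.randn(2, 256, 64, 64),    # stride 4
#            torch.randn(2, 512, 32, 32),    # stride 8
#            torch.randn(2, 1024, 16, 16),   # stride 16
#            torch.randn(2, 2048, 8, 8)]     # stride 32
#   out, aux = head(feats)   # out: (2, 128, 64, 64), aux: (2, 6, 8, 8)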
--------------------------------------------------------------------------------
/rsseg/models/backbones/tinyvim.py:
--------------------------------------------------------------------------------
import os
import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple
from rsseg.models.backbones.tvimblock import TViMBlock, Conv2d_BN, RepDW, FFN
from mmseg.utils import get_root_logger
from mmcv.runner import _load_checkpoint

TinyViM_width = {
    'S': [48, 64, 168, 224],
    'B': [48, 96, 192, 384],
    'L': [64, 128, 384, 512],
}

TinyViM_depth = {
    'S': [3, 3, 9, 6],
    'B': [4, 3, 10, 5],
    'L': [4, 4, 12, 6],
}


def stem(in_chs, out_chs):
    """
    Stem layer implemented by two conv layers.
    Output: sequence of layers with final shape of [B, C, H/4, W/4]
    """
    return nn.Sequential(
        Conv2d_BN(in_chs, out_chs // 2, 3, 2, 1),
        nn.GELU(),
        Conv2d_BN(out_chs // 2, out_chs, 3, 2, 1),
        nn.GELU(),)


class Embedding(nn.Module):
    """
    Patch embedding implemented by a single conv layer.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H/stride, W/stride]
    """

    def __init__(self, patch_size=16, stride=2, padding=0,
                 in_chans=3, embed_dim=48):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.proj = Conv2d_BN(in_chans, embed_dim, patch_size, stride, padding)

    def forward(self, x):
        x = self.proj(x)
        return x


class LocalBlock(nn.Module):
    """
    Implementation of a ConvEncoder with 3x3 and 1x1 convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(self, dim, hidden_dim=64, drop_path=0., use_layer_scale=True):
        super().__init__()
        self.dwconv = RepDW(dim)
        self.mlp = FFN(dim, hidden_dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.mlp(x)
        if self.use_layer_scale:
            x = input + self.drop_path(self.layer_scale * x)
        else:
            x = input + self.drop_path(x)
        return x


def Stage(dim, index, layers, mlp_ratio=4.,
          ssm_d_state=8, ssm_ratio=1.0, ssm_num=1,
          use_layer_scale=True, layer_scale_init_value=1e-5):
    """
    Implementation of a single TinyViM stage.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """
    blocks = []

    for block_idx in range(layers[index]):
        # the last `ssm_num` blocks of each stage are Mamba-style TViMBlocks;
        # the remaining blocks are convolutional LocalBlocks (plus one extra
        # TViMBlock in the middle of stage 2)
        if layers[index] - block_idx <= ssm_num:
            blocks.append(TViMBlock(dim, ssm_d_state=ssm_d_state, ssm_ratio=ssm_ratio, ssm_conv_bias=False, index=index))
        else:
            if index == 2 and block_idx == layers[index] // 2:
                blocks.append(TViMBlock(dim, ssm_d_state=8, ssm_ratio=1.0, ssm_conv_bias=False, index=index))
            else:
                blocks.append(LocalBlock(dim=dim, hidden_dim=int(mlp_ratio * dim)))

    blocks = nn.Sequential(*blocks)
    return blocks

class TinyViM(nn.Module):

    def __init__(self, layers, embed_dims=None,
                 mlp_ratios=4, downsamples=None,
                 num_classes=1000,
                 down_patch_size=3, down_stride=2, down_pad=1,
                 use_layer_scale=True, layer_scale_init_value=1e-5,
                 fork_feat=True,
                 init_cfg=None,
                 pretrained=None,
                 ssm_num=1,
                 distillation=True,
                 **kwargs):
        super().__init__()

        if not fork_feat:
            self.num_classes = num_classes
        self.fork_feat = fork_feat

        self.patch_embed = stem(3, embed_dims[0])

        network = []
        for i in range(len(layers)):
            stage = Stage(embed_dims[i], i, layers, mlp_ratio=mlp_ratios,
                          use_layer_scale=use_layer_scale,
                          layer_scale_init_value=layer_scale_init_value,
                          ssm_num=ssm_num)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    Embedding(
                        patch_size=down_patch_size, stride=down_stride,
                        padding=down_pad,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
                    )
                )

        self.network = nn.ModuleList(network)

        if self.fork_feat:
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    layer = nn.Identity()
                else:
                    layer = nn.BatchNorm2d(embed_dims[i_emb])
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            # Classifier head
            self.norm = nn.BatchNorm2d(embed_dims[-1])
            self.head = nn.Linear(
                embed_dims[-1], num_classes) if num_classes > 0 \
                else nn.Identity()
            self.dist = distillation
            if self.dist:
                self.dist_head = nn.Linear(
                    embed_dims[-1], num_classes) if num_classes > 0 \
                    else nn.Identity()

        # self.apply(self.cls_init_weights)
        self.apply(self._init_weights)

        self.init_cfg = copy.deepcopy(init_cfg)
        # load pre-trained model
        if self.fork_feat and (
                self.init_cfg is not None or pretrained is not None):
            self.init_weights()

    # init for mmdetection or mmsegmentation by loading
    # imagenet pre-trained weights
    def init_weights(self, pretrained=None):
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training start from scratch')
        else:
            # resolve the checkpoint path before asserting on init_cfg, so a
            # plain `pretrained` path also works when init_cfg is None
            if self.init_cfg is not None:
                assert 'checkpoint' in self.init_cfg, \
                    f'Only support specifying `Pretrained` in `init_cfg` ' \
                    f'in {self.__class__.__name__}'
                ckpt_path = self.init_cfg['checkpoint']
            else:
                ckpt_path = pretrained

            ckpt = _load_checkpoint(
                ckpt_path, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            missing_keys, unexpected_keys = \
                self.load_state_dict(state_dict, False)
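
    # Sketch of building the backbone with ImageNet weights for dense
    # prediction (the checkpoint path is a placeholder, not a file shipped
    # with this repo):
    #
    #   backbone = TinyViM_S(fork_feat=True,
    #                        init_cfg=dict(checkpoint='pretrain/tinyvim_s.pth'))
    #   feats = backbone(torch.randn(1, 3, 512, 512))   # list of 4 feature maps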

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        outs = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)
        if self.fork_feat:
            return outs
        return x

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.forward_tokens(x)
        if self.fork_feat:
            # Output features of four stages for dense prediction
            return x

        x = self.norm(x)
        if self.dist:
            cls_out = self.head(x.flatten(2).mean(-1)), self.dist_head(x.flatten(2).mean(-1))
            if not self.training:
                cls_out = (cls_out[0] + cls_out[1]) / 2
        else:
            cls_out = self.head(x.flatten(2).mean(-1))
        # For image classification
        return cls_out


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head',
        **kwargs
    }


@register_model
def TinyViM_S(pretrained=None, **kwargs):
    model = TinyViM(
        layers=TinyViM_depth['S'],
        embed_dims=TinyViM_width['S'],
        downsamples=[True, True, True, True],
        ssm_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def TinyViM_B(pretrained=None, **kwargs):
    model = TinyViM(
        layers=TinyViM_depth['B'],
        embed_dims=TinyViM_width['B'],
        downsamples=[True, True, True, True],
        ssm_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def TinyViM_L(pretrained=None, **kwargs):
    model = TinyViM(
        layers=TinyViM_depth['L'],
        embed_dims=TinyViM_width['L'],
        downsamples=[True, True, True, True],
        ssm_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model
--------------------------------------------------------------------------------
/tools/dataset_patch_split.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import argparse
import torch
from PIL import Image
import random
import glob
import time
import multiprocessing.pool as mpp
import multiprocessing as mp
import cv2
import re
import albumentations as albu


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-type", type=str, default="potsdam")
    parser.add_argument("--img-dir", default="data/potsdam/train_images")
    parser.add_argument("--mask-dir", default="data/potsdam/train_masks")
    parser.add_argument("--output-img-dir", default="data/potsdam/train/images_1024")
    parser.add_argument("--output-mask-dir", default="data/potsdam/train/masks_1024")
    parser.add_argument("--mode", type=str, default='train')

    parser.add_argument("--split-size", type=int, default=1024)
    parser.add_argument("--stride", type=int, default=512)
    return parser.parse_args()
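
# Example invocation (paths mirror the argparse defaults above; adjust them
# to your local data layout):
#
#   python tools/dataset_patch_split.py \
#       --dataset-type vaihingen \
#       --img-dir data/vaihingen/train_images \
#       --mask-dir data/vaihingen/train_masks \
#       --output-img-dir data/vaihingen/train/images_1024 \
#       --output-mask-dir data/vaihingen/train/masks_1024 \
#       --mode train --split-size 1024 --stride 512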

def randomsizedcrop(image, mask):
    # assert image.shape[:2] == mask.shape
    h, w = image.shape[0], image.shape[1]
    # height/width follow the (rows, cols) order of the input image
    crop = albu.RandomSizedCrop(min_max_height=(int(3 * h // 8), int(h // 2)), height=h, width=w)(image=image.copy(), mask=mask.copy())
    img_crop, mask_crop = crop['image'], crop['mask']
    return img_crop, mask_crop


def aug(image, mask, img_filename, mask_filename, k):
    # relies on module-level globals (imgs_output_dir, masks_output_dir)
    # set in __main__; this works with fork-based multiprocessing
    h, w = image.shape[:2]
    car_mask = np.zeros((h, w), dtype=np.uint8)
    veg_mask = np.zeros((h, w), dtype=np.uint8)
    mask = np.array(mask)
    car_mask[np.all(mask == np.array([0, 255, 255]), axis=-1)] = 1
    veg_mask[np.all(mask == np.array([255, 255, 0]), axis=-1)] = 1
    count_car = np.count_nonzero(car_mask)
    count_veg = np.count_nonzero(veg_mask)

    # only the car ratio currently gates the extra augmentation;
    # count_veg is computed but unused
    if count_car / (h * w) < 0.08:
        return

    v_flip = albu.VerticalFlip(p=1.0)(image=image.copy(), mask=mask.copy())
    h_flip = albu.HorizontalFlip(p=1.0)(image=image.copy(), mask=mask.copy())
    rotate_90 = albu.RandomRotate90(p=1.0)(image=image.copy(), mask=mask.copy())

    image_vflip, mask_vflip = v_flip['image'], v_flip['mask']
    image_hflip, mask_hflip = h_flip['image'], h_flip['mask']
    image_rotate, mask_rotate = rotate_90['image'], rotate_90['mask']

    image_list = [image, image_vflip, image_hflip, image_rotate]
    mask_list = [mask, mask_vflip, mask_hflip, mask_rotate]

    for i in range(len(image_list)):
        image, mask = image_list[i], mask_list[i]
        out_img_path = os.path.join(imgs_output_dir, "{}_{}_{}.tif".format(img_filename, k, i))
        cv2.imwrite(out_img_path, image)
        out_mask_path = os.path.join(masks_output_dir, "{}_{}_{}.png".format(mask_filename, k, i))
        cv2.imwrite(out_mask_path, mask)
    return


def split(img, mask_rgb, img_filename, mask_filename):
    assert img.shape[0] == mask_rgb.shape[0] \
        and img.shape[1] == mask_rgb.shape[1]
    img_W = img.shape[1]
    img_H = img.shape[0]
    k = 0

    v_flip = albu.VerticalFlip(p=1.0)(image=img.copy(), mask=mask_rgb.copy())
    h_flip = albu.HorizontalFlip(p=1.0)(image=img.copy(), mask=mask_rgb.copy())
    img_mask = [(img, mask_rgb)]
    if dataset_type == 'vaihingen' and mode == "train":
        img_mask += [(v_flip['image'], v_flip['mask']), (h_flip['image'], h_flip['mask'])]

    for i in range(len(img_mask)):
        img, mask_rgb = img_mask[i]

        for y in range(0, img.shape[0], stride):
            for x in range(0, img.shape[1], stride):
                x_str = x
                x_end = x + split_size
                y_str = y
                y_end = y + split_size

                # clamp windows that run past the border back inside the image
                if x_end > img_W:
                    diff_x = x_end - img_W
                    x_str -= diff_x
                    x_end = img_W

                if y_end > img_H:
                    diff_y = y_end - img_H
                    y_str -= diff_y
                    y_end = img_H

                img_tile = img[y_str:y_end, x_str:x_end]
                mask_rgb_tile = mask_rgb[y_str:y_end, x_str:x_end]

                if img_tile.shape[0] == split_size and img_tile.shape[1] == split_size \
                        and mask_rgb_tile.shape[0] == split_size and mask_rgb_tile.shape[1] == split_size:

                    out_img_path = os.path.join(imgs_output_dir, "{}_{}.tif".format(img_filename, k))
                    cv2.imwrite(out_img_path, img_tile)

                    out_mask_path = os.path.join(masks_output_dir, "{}_{}.png".format(mask_filename, k))
                    cv2.imwrite(out_mask_path, mask_rgb_tile)

                    if dataset_type == "vaihingen" and mode == "train":
                        img_crop, mask_crop = randomsizedcrop(img_tile, mask_rgb_tile)
                        aug(img_crop, mask_crop, img_filename, mask_filename, k)

                k += 1

def vaihingen_split(inp):
    (img_path, mask_path, imgs_output_dir, masks_output_dir, split_size, stride) = inp
    img_filename = os.path.splitext(os.path.basename(img_path))[0]
    mask_filename = img_filename

    image = Image.open(img_path).convert('RGB')
    mask = Image.open(mask_path).convert('RGB')

    # PIL's Image.size is (width, height)
    image_width, image_height = image.size[0], image.size[1]
    mask_width, mask_height = mask.size[0], mask.size[1]
    assert image_height == mask_height and image_width == mask_width

    img = cv2.cvtColor(np.array(image.copy()), cv2.COLOR_RGB2BGR)
    mask_rgb = cv2.cvtColor(np.array(mask.copy()), cv2.COLOR_RGB2BGR)
    split(img, mask_rgb, img_filename, mask_filename)


def get_vaihingen_file(imgs_dir, masks_dir, imgs_output_dir, masks_output_dir, split_size, stride, mode):
    train_seq_list = [1, 3, 5, 7, 11, 13, 15, 17, 21, 23, 26, 28, 30, 32, 34, 37]
    test_seq_list = [2, 4, 6, 8, 10, 12, 14, 16, 20, 22, 24, 27, 29, 31, 33, 35, 38]

    seq_list = train_seq_list if mode == 'train' else test_seq_list

    img_paths_ori = glob.glob(os.path.join(imgs_dir, "*.tif"))
    mask_paths_ori = glob.glob(os.path.join(masks_dir, "*.tif"))

    img_num = []
    img_paths = []
    mask_paths = []
    mask_num = []

    # pick the trailing area number out of each filename
    for file_name in img_paths_ori:
        match = re.search(r'\d+', file_name[::-1])
        if match:
            number = int(match.group()[::-1])
            if number in seq_list:
                img_paths.append(file_name)
                img_num.append(number)
    for file_name in mask_paths_ori:
        match = re.search(r'\d+', file_name[::-1])
        if match:
            number = int(match.group()[::-1])
            if number in seq_list:
                mask_paths.append(file_name)
                mask_num.append(number)

    img_paths = sorted(img_paths, key=lambda x: img_num[img_paths.index(x)])
    mask_paths = sorted(mask_paths, key=lambda x: mask_num[mask_paths.index(x)])

    if not os.path.exists(imgs_output_dir):
        os.makedirs(imgs_output_dir)
    if not os.path.exists(masks_output_dir):
        os.makedirs(masks_output_dir)

    inp = [(img_path, mask_path, imgs_output_dir, masks_output_dir, split_size, stride)
           for img_path, mask_path in zip(img_paths, mask_paths)]

    t0 = time.time()
    mpp.Pool(processes=mp.cpu_count()).map(vaihingen_split, inp)
    t1 = time.time()
    split_time = t1 - t0
    print('image splitting took: {} s'.format(split_time))


def potsdam_split(inp):
    (img_path, mask_path, imgs_output_dir, masks_output_dir, split_size, stride) = inp

    img_filename = os.path.splitext(os.path.basename(img_path))[0]
    mask_filename = img_filename

    image = Image.open(img_path).convert('RGB')
    mask = Image.open(mask_path).convert('RGB')

    # PIL's Image.size is (width, height)
    image_width, image_height = image.size[0], image.size[1]
    mask_width, mask_height = mask.size[0], mask.size[1]
    assert image_height == mask_height and image_width == mask_width

    img = cv2.cvtColor(np.array(image.copy()), cv2.COLOR_RGB2BGR)
    mask_rgb = cv2.cvtColor(np.array(mask.copy()), cv2.COLOR_RGB2BGR)
    split(img, mask_rgb, img_filename, mask_filename)

def get_potsdam_file(imgs_dir, masks_dir, imgs_output_dir, masks_output_dir, split_size, stride, mode):
    train_seq_list = ['2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11', '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7', '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9']
    test_seq_list = ['2_13', '2_14', '3_13', '3_14', '4_13', '4_14', '4_15', '5_13', '5_14', '5_15', '6_13', '6_14', '6_15', '7_13']

    seq_list = train_seq_list if mode == 'train' else test_seq_list

    img_paths_ori = glob.glob(os.path.join(imgs_dir, "*.tif"))
    mask_paths_ori = glob.glob(os.path.join(masks_dir, "*.tif"))

    img_paths = []
    mask_paths = []

    for file_name in img_paths_ori:
        match = re.search(r'(\d+_\d+)', file_name)
        if match:
            number = match.group()
            if number in seq_list:
                img_paths.append(file_name)
    for file_name in mask_paths_ori:
        match = re.search(r'(\d+_\d+)', file_name)
        if match:
            number = match.group()
            if number in seq_list:
                mask_paths.append(file_name)

    img_paths.sort()
    mask_paths.sort()

    if not os.path.exists(imgs_output_dir):
        os.makedirs(imgs_output_dir)
    if not os.path.exists(masks_output_dir):
        os.makedirs(masks_output_dir)

    inp = [(img_path, mask_path, imgs_output_dir, masks_output_dir,
            split_size, stride)
           for img_path, mask_path in zip(img_paths, mask_paths)]

    t0 = time.time()
    mpp.Pool(processes=mp.cpu_count()).map(potsdam_split, inp)
    t1 = time.time()
    split_time = t1 - t0
    print('image splitting took: {} s'.format(split_time))


SEED = 66


def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark must stay off when cudnn is asked to be deterministic
    torch.backends.cudnn.benchmark = False


if __name__ == '__main__':
    seed_everything(SEED)
    args = parse_args()
    dataset_type = args.dataset_type
    imgs_dir = args.img_dir
    masks_dir = args.mask_dir
    imgs_output_dir = args.output_img_dir
    masks_output_dir = args.output_mask_dir
    mode = args.mode

    split_size = args.split_size
    stride = args.stride

    if dataset_type == "vaihingen":
        get_vaihingen_file(imgs_dir, masks_dir, imgs_output_dir, masks_output_dir, split_size, stride, mode)
    elif dataset_type == "potsdam":
        get_potsdam_file(imgs_dir, masks_dir, imgs_output_dir, masks_output_dir, split_size, stride, mode)
    else:
        print("unknown --dataset-type: expected one of [vaihingen, potsdam]")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# 🔥 News
- `2025/7/3`: [TinyViM](https://arxiv.org/abs/2411.17473) has been accepted by ICCV 2025! It is an efficient and powerful backbone that performs well on remote sensing image segmentation tasks. We have included the TinyViM code in this repository.
- `2025/3/18`: We add the **TSNE map** tool and **links to the preprocessed datasets and checkpoints**.
- `2025/3/14`: We fix some bugs and add **CAM** and **throughput** tools.
- `2025/2/11`: [LOGCAN++](https://arxiv.org/abs/2406.16502) has been accepted by TGRS 2025!
- `2025/1/24`: [SCSM](https://arxiv.org/abs/2501.13130) has been accepted by ISPRS 2025!
- `2024/10/11`: [SSA-Seg](https://arxiv.org/abs/2405.06525) has been accepted by NeurIPS 2024! It is an effective and powerful classifier for semantic segmentation. We encourage interested researchers to adapt it to semantic segmentation in remote sensing, which is a promising direction.

# 📷 Introduction

**rssegmentation** is an open-source semantic segmentation toolbox dedicated to reproducing and developing advanced methods for semantic segmentation of remote sensing images.

<table>
<tr>
  <td><b>Methods</b></td>
  <td><b>Datasets</b></td>
  <td><b>Tools</b></td>
</tr>
<tr>
  <td>
    <ul>
      <li>LoGCAN</li>
      <li>SACANet</li>
      <li>DOCNet</li>
      <li>LOGCAN++</li>
    </ul>
  </td>
  <td>
    <ul>
      <li>ISPRS Vaihingen</li>
      <li>ISPRS Potsdam</li>
      <li>LoveDA</li>
    </ul>
  </td>
  <td>
    <ul>
      <li>Training and testing</li>
      <li>Throughput and latency count</li>
      <li>Class activation map (CAM)</li>
      <li>TSNE map</li>
      <li>Dataset preprocessing</li>
    </ul>
  </td>
</tr>
</table>