├── basicseg
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── registry.py
│   │   ├── dist_util.py
│   │   ├── path_utils.py
│   │   ├── yaml_options.py
│   │   ├── lr_scheduler.py
│   │   └── logger.py
│   ├── metric
│   │   ├── __init__.py
│   │   └── iou_fscore.py
│   ├── networks
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   ├── layernorm.py
│   │   │   ├── agpc
│   │   │   │   ├── fusion.py
│   │   │   │   ├── resnet.py
│   │   │   │   └── context.py
│   │   │   ├── upernet.py
│   │   │   ├── attention.py
│   │   │   ├── Dilatedconv.py
│   │   │   ├── resnet.py
│   │   │   └── conv.py
│   │   ├── __init__.py
│   │   └── HCFnet.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── boundary_loss.py
│   │   └── basic_loss.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── aug_fn.py
│   │   └── load_data.py
│   ├── test_model.py
│   ├── seg_model.py
│   ├── main_blocks.py
│   └── base_model.py
├── requirements.txt
├── options
│   ├── test.yaml
│   └── train.yaml
├── README_CN.md
├── README.md
├── test.py
├── train.py
└── LICENSE

/basicseg/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/basicseg/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/basicseg/metric/__init__.py:
--------------------------------------------------------------------------------
1 | from basicseg.metric.iou_fscore import Binary_metric
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | albumentations==0.5.2
 2 | fvcore==0.1.5.post20220512
 3 | numpy==1.21.2
 4 | opencv_python==4.5.3.56
 5 | ptflops==0.6.9
 6 | PyYAML==6.0
 7 | scikit_image==0.18.3
 8 | scipy==1.7.1
 9 | timm==0.4.12
10 | torch==1.9.1+cu111
11 | torchvision==0.10.1+cu111
--------------------------------------------------------------------------------
/options/test.yaml:
--------------------------------------------------------------------------------
 1 | exp:
 2 |     # save_dir:
 3 |     bs: 4
 4 |     device: 0
 5 | 
 6 | model:
 7 |     net:
 8 |         type: HCFnet
 9 |         gt_ds: False
10 | 
11 | 
12 | dataset:
13 |     test:
14 |         type: Dataset_test
15 |         data_root: /media/data2/zhengshuchen/code/SIRST/test
16 |         img_sz: 512
17 |         get_name: True
18 | 
19 | resume:
20 |     net_path: /media/data2/zhengshuchen/code/HCFNet/experiment/HCF_demo/20250327_112905/models/net_best_mean.pth
--------------------------------------------------------------------------------
/basicseg/networks/common/__init__.py:
--------------------------------------------------------------------------------
 1 | from basicseg.networks.common.resnet import resnet18, resnet18_d, resnet34, resnet34_d,\
 2 |     resnet50, resnet50_d, resnet101, resnet101_d
 3 | from basicseg.networks.common.attention import Double_attention, \
 4 |     Position_attention, Channel_attention
 5 | from basicseg.networks.common.layernorm import LayerNorm2d
 6 | from basicseg.networks.common.conv import CDC_conv, ASPP, GatedConv2dWithActivation, DeformConv2d
 7 | import torch.nn as nn
 8 | 
 9 | def main():
10 |     pass
11 | 
12 | if __name__ == '__main__':
13 |     main()
--------------------------------------------------------------------------------
/README_CN.md:
--------------------------------------------------------------------------------
 1 | 
 2 | # HCF-Net
 3 | ##
 4 | HCF-Net is a framework for infrared small object segmentation.
 5 | ## Dataset Structure
 6 | If you want to train on your own dataset, prepare the data in the following structure:
 7 | ```
 8 | |-SIRST
 9 |   |-trainval
10 |     |-images
11 |       |-xxx.png
12 |     |-masks
13 |       |-xxx.png
14 |   |-test
15 |     |-images
16 |       |-xxx.png
17 |     |-masks
18 |       |-xxx.png
19 | ```
20 | 
21 | ## Training
22 | 
23 | Run the following command to train:
24 | 
25 | ```train
26 | python train.py --opt ./options/train.yaml
27 | ```
28 | ## Evaluation
29 | 
30 | 
31 | Run the following command to predict:
32 | 
33 | ```eval
34 | python test.py --opt ./options/test.yaml
35 | ```
36 | ## Pre-trained Weights and Results
37 | You can download the pretrained weights (the full training logs are also provided):
38 | 
39 | - [HCF for SIRST](https://drive.google.com/drive/folders/1KljHLQjJVdMmaZXnkf1dtajtD8D28n7T?usp=drive_link)
40 | 
41 | | Model name | IoU   | nIoU  |
42 | |------------|-------|-------|
43 | | HCF-Net    | 80.09 | 78.31 |
--------------------------------------------------------------------------------
/basicseg/networks/__init__.py:
--------------------------------------------------------------------------------
 1 | import importlib
 2 | from os import path as osp
 3 | from copy import deepcopy
 4 | from basicseg.utils.path_utils import scandir
 5 | from basicseg.utils.registry import NET_REGISTRY
 6 | from basicseg.networks.common import CDC_conv
 7 | import torch.nn as nn
 8 | __all__ = ['build_network']
 9 | 
10 | arch_folder = osp.dirname(osp.abspath(__file__))
11 | arch_filenames = [
12 |     osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
13 |     if v.endswith('.py')
14 | ]
15 | # import all the arch modules so that their registry decorators run
16 | _arch_modules = [
17 |     importlib.import_module(f'basicseg.networks.{file_name}')
18 |     for file_name in arch_filenames
19 | ]
20 | 
21 | def build_network(opt):
22 |     opt = deepcopy(opt)
23 |     network_type = opt.pop('type')
24 |     net = NET_REGISTRY.get(network_type)(**opt)
25 |     # logger = get_root_logger()
26 |     # logger.info(f'Network [{net.__class__.__name__}] is created.')
27 |     return net
28 | 
29 | 
30 | def main():
31 |     opt = {'type':'Fpn_res18'}
32 |     net = build_network(opt)
33 |     print(net)
34 | 
35 | # if __name__ == '__main__':
36 | #     main()
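The `scandir`/`importlib` loop above is what makes the YAML `type:` fields work: every module in the package is imported once at load time, so each `@NET_REGISTRY.register()` decorator runs and the class becomes constructible by name. A minimal sketch of the round trip (`TinyNet` and its `width` argument are invented for illustration):

```python
import torch
import torch.nn as nn
from basicseg.utils.registry import NET_REGISTRY
from basicseg.networks import build_network

@NET_REGISTRY.register()
class TinyNet(nn.Module):          # hypothetical network, for illustration only
    def __init__(self, width=16):
        super().__init__()
        self.conv = nn.Conv2d(3, width, 3, padding=1)
    def forward(self, x):
        return self.conv(x)

# 'type' selects the registered class; the remaining keys become constructor
# kwargs, mirroring the `opt.pop('type')` logic inside build_network().
net = build_network({'type': 'TinyNet', 'width': 8})
print(net(torch.rand(1, 3, 32, 32)).shape)   # torch.Size([1, 8, 32, 32])
```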
--------------------------------------------------------------------------------
/basicseg/loss/__init__.py:
--------------------------------------------------------------------------------
 1 | import importlib
 2 | from os import path as osp
 3 | from copy import deepcopy
 4 | from basicseg.utils.path_utils import scandir
 5 | from basicseg.utils.registry import LOSS_REGISTRY
 6 | __all__ = ['build_loss']
 7 | 
 8 | arch_folder = osp.dirname(osp.abspath(__file__))
 9 | arch_filenames = [
10 |     osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
11 |     if v.endswith('.py')
12 | ]
13 | # import all the loss modules so that their registry decorators run
14 | _arch_modules = [
15 |     importlib.import_module(f'basicseg.loss.{file_name}')
16 |     for file_name in arch_filenames
17 | ]
18 | 
19 | def build_loss(opt):
20 |     opt = deepcopy(opt)
21 |     network_type = opt.pop('type')
22 |     opt.pop('weight', None)  # 'weight' is consumed by the trainer, not the loss itself
23 |     net = LOSS_REGISTRY.get(network_type)(**opt)
24 |     # logger = get_root_logger()
25 |     # logger.info(f'Network [{net.__class__.__name__}] is created.')
26 |     return net
27 | 
28 | def main():
29 |     opt = {'type':'Bce_loss'}
30 |     loss = build_loss(opt)
31 |     import torch
32 |     pred = torch.rand(2,1,512,512)
33 |     mask = torch.rand(2,1,512,512)
34 |     print(loss(pred, mask))
35 | 
36 | if __name__ == '__main__':
37 |     main()
--------------------------------------------------------------------------------
/basicseg/data/__init__.py:
--------------------------------------------------------------------------------
 1 | import importlib
 2 | from os import path as osp
 3 | from copy import deepcopy
 4 | from basicseg.utils.path_utils import scandir
 5 | from basicseg.utils.registry import DATASET_REGISTRY
 6 | __all__ = ['build_dataset']
 7 | 
 8 | arch_folder = osp.dirname(osp.abspath(__file__))
 9 | arch_filenames = [
10 |     osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
11 |     if v.endswith('.py')
12 | ]
13 | # import all the dataset modules so that their registry decorators run
14 | _arch_modules = [
15 |     importlib.import_module(f'basicseg.data.{file_name}')
16 |     for file_name in arch_filenames
17 | ]
18 | 
19 | def build_dataset(opt):
20 |     opt = deepcopy(opt)
21 |     network_type = opt.pop('type')
22 |     net = DATASET_REGISTRY.get(network_type)(opt)
23 |     # logger = get_root_logger()
24 |     # logger.info(f'Network [{net.__class__.__name__}] is created.')
25 |     return net
26 | 
27 | 
28 | 
29 | def main():
30 |     opt = {'data_root':"E:/graduate/dataset/Sirst/train", 'type':'Dataset_test', 'img_sz':512}
31 |     dataset = build_dataset(opt)
32 |     img, mask = (dataset.__getitem__(0))
33 |     print(img.shape, mask.shape)
34 | 
35 | if __name__ == '__main__':
36 |     main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | # HCF-Net
 3 | ##
 4 | HCF-Net is a framework for Infrared Small Object Segmentation
 5 | ## Dataset Structure
 6 | If you want to train on a custom dataset, prepare the data in the following structure:
 7 | ```
 8 | |-SIRST
 9 |   |-trainval
10 |     |-images
11 |       |-xxx.png
12 |     |-masks
13 |       |-xxx.png
14 |   |-test
15 |     |-images
16 |       |-xxx.png
17 |     |-masks
18 |       |-xxx.png
19 | ```
20 | ## Training
21 | 
22 | To train the model, run this command:
23 | 
24 | ```train
25 | python train.py --opt ./options/train.yaml
26 | ```
27 | ## Evaluation
28 | 
29 | 
30 | To evaluate the pretrained model, run:
31 | 
32 | ```eval
33 | python test.py --opt ./options/test.yaml
34 | ```
35 | ## Pre-trained Models and Results
36 | 
37 | You can download the pretrained models (the whole training logs are also provided) here:
38 | 
39 | - [HCF for SIRST](https://drive.google.com/drive/folders/1KljHLQjJVdMmaZXnkf1dtajtD8D28n7T?usp=drive_link)
40 | 
41 | 
42 | | Model name | IoU   | nIoU  |
43 | |------------|-------|-------|
44 | | HCF-Net    | 80.09 | 78.31 |
--------------------------------------------------------------------------------
/basicseg/loss/boundary_loss.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from scipy.ndimage import distance_transform_edt as eucl_distance
 3 | from basicseg.utils.registry import LOSS_REGISTRY
 4 | import torch.nn as nn
 5 | import torch
 6 | '''
 7 | boundary loss at: https://arxiv.org/abs/1812.07032
 8 | modified from: https://github.com/LIVIAETS/boundary-loss only for binary classification
 9 | '''
10 | def get_dist_map(mask):
11 |     # note: mask is a [h, w] map; it is converted to a numpy array below
12 |     # res = np.zeros_like(mask, dtype=np.float32)
13 |     mask = mask.clone().detach()
14 |     posmask = mask.numpy().astype(np.bool_)
15 |     resolution = [1, 1]
16 |     negmask = ~posmask
17 |     res = eucl_distance(negmask, sampling=resolution) * negmask - (
18 |         eucl_distance(posmask, sampling=resolution) - 1) * posmask
19 |     res = np.clip(res, a_min=0, a_max=None)
20 |     return res
21 | 
22 | @LOSS_REGISTRY.register()
23 | class BD_loss(nn.Module):
24 |     def __init__(self, reduction='mean'):
25 |         super(BD_loss, self).__init__()
26 |         self.reduction = reduction
27 |     def forward(self, pred, target):
28 |         pred = torch.sigmoid(pred)
29 |         bd_loss = pred * target
30 |         if self.reduction == 'mean':
31 |             return bd_loss.mean()
32 |         elif self.reduction == 'sum':
33 |             return bd_loss.sum()
34 |         else:
35 |             raise NotImplementedError('reduction type {} not implemented'.format(self.reduction))
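In the training pipeline, `get_dist_map` is applied to the augmented mask by the dataset (see `load_data.py`, which returns the distance map alongside the image and mask when `bd_loss` is enabled), and `BD_loss` then appears to consume that distance map as its `target`. A minimal sketch of the round trip, assuming a toy square mask invented for illustration:

```python
import torch
from basicseg.loss.boundary_loss import get_dist_map, BD_loss

mask = torch.zeros(64, 64)
mask[24:40, 24:40] = 1.                    # toy square foreground
# distance map: 0 on the foreground, growing with distance outside it
dist = torch.from_numpy(get_dist_map(mask)).float()

pred = torch.randn(1, 1, 64, 64)           # raw logits from the network
loss = BD_loss()(pred, dist[None, None])   # mean of sigmoid(pred) * dist
print(loss.item())
```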
--------------------------------------------------------------------------------
/options/train.yaml:
--------------------------------------------------------------------------------
 1 | exp:
 2 |     name: HCF_demo
 3 |     save_exp: True
 4 |     bs: 4
 5 |     total_epochs: 300
 6 |     log_interval: 1
 7 |     save_interval: 150
 8 |     test_interval: 1
 9 |     device: 0
10 | model:
11 |     net:
12 |         type: HCFnet
13 |         gt_ds: True
14 | 
15 | optim:
16 |     type: AdamW
17 |     # init_lr: !!float 1e-3
18 |     init_lr: !!float 5e-4
19 |     weight_decay: !!float 1e-4
20 |     betas: [0.9, 0.999]
21 | # Iou_loss, Bce_loss, Dice_loss ..
22 | loss:
23 |     loss_1:
24 |         type: Bce_loss
25 |         weight: 1
26 |     loss_2:
27 |         type: Iou_loss
28 |         weight: 1
29 |     # loss_3:
30 |     #     type: boundary
31 |     #     weight: 1
32 | # resume_train: ~
33 | lr:
34 |     warmup_iter: -1 # warmup to init_lr
35 |     # type: CosineAnnealingLR
36 |     scheduler:
37 |         # type: ~
38 |         type: CosineAnnealingLR
39 |         step_interval: iter # iter or epoch (update once every iteration or every epoch)
40 |         eta_min: !!float 1e-5
41 | 
42 | dataset:
43 |     name:
44 |     train:
45 |         type: Dataset_aug_bac
46 |         data_root: /media/data2/zhengshuchen/code/nudt/trainval
47 |         img_sz: 512
48 | 
49 |     test:
50 |         type: Dataset_test
51 |         data_root: /media/data2/zhengshuchen/code/nudt/test
52 |         img_sz: 512
53 | 
54 | resume:
55 |     net_path:
56 |     state_path:
--------------------------------------------------------------------------------
/basicseg/networks/common/layernorm.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | 
 4 | class LayerNormFunction(torch.autograd.Function):
 5 | 
 6 |     @staticmethod
 7 |     def forward(ctx, x, weight, bias, eps):
 8 |         ctx.eps = eps
 9 |         N, C, H, W = x.size()
10 |         mu = x.mean(1, keepdim=True)
11 |         var = (x - mu).pow(2).mean(1, keepdim=True)
12 |         y = (x - mu) / (var + eps).sqrt()
13 |         ctx.save_for_backward(y, var, weight)
14 |         y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
15 |         return y
16 | 
17 |     @staticmethod
18 |     def backward(ctx, grad_output):
19 |         eps = ctx.eps
20 | 
21 |         N, C, H, W = grad_output.size()
22 |         y, var, weight = ctx.saved_variables
23 |         g = grad_output * weight.view(1, C, 1, 1)
24 |         mean_g = g.mean(dim=1, keepdim=True)
25 | 
26 |         mean_gy = (g * y).mean(dim=1, keepdim=True)
27 |         gx = 1.
/ torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) 28 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( 29 | dim=0), None 30 | 31 | class LayerNorm2d(nn.Module): 32 | 33 | def __init__(self, channels, eps=1e-6): 34 | super(LayerNorm2d, self).__init__() 35 | self.register_parameter('weight', nn.Parameter(torch.ones(channels))) 36 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) 37 | self.eps = eps 38 | 39 | def forward(self, x): 40 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) -------------------------------------------------------------------------------- /basicseg/networks/common/agpc/fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | __all__ = ['AsymFusionModule'] 7 | 8 | 9 | class AsymFusionModule(nn.Module): 10 | def __init__(self, planes_high, planes_low, planes_out): 11 | super(AsymFusionModule, self).__init__() 12 | self.pa = nn.Sequential( 13 | nn.Conv2d(planes_low, planes_low//4, kernel_size=1), 14 | nn.BatchNorm2d(planes_low//4), 15 | nn.ReLU(True), 16 | 17 | nn.Conv2d(planes_low//4, planes_low, kernel_size=1), 18 | nn.BatchNorm2d(planes_low), 19 | nn.Sigmoid(), 20 | ) 21 | self.plus_conv = nn.Sequential( 22 | nn.Conv2d(planes_high, planes_low, kernel_size=1), 23 | nn.BatchNorm2d(planes_low), 24 | nn.ReLU(True) 25 | ) 26 | self.ca = nn.Sequential( 27 | nn.AdaptiveAvgPool2d(1), 28 | nn.Conv2d(planes_low, planes_low//4, kernel_size=1), 29 | nn.BatchNorm2d(planes_low//4), 30 | nn.ReLU(True), 31 | 32 | nn.Conv2d(planes_low//4, planes_low, kernel_size=1), 33 | nn.BatchNorm2d(planes_low), 34 | nn.Sigmoid(), 35 | ) 36 | self.end_conv = nn.Sequential( 37 | nn.Conv2d(planes_low, planes_out, 3, 1, 1), 38 | nn.BatchNorm2d(planes_out), 39 | nn.ReLU(True), 40 | ) 41 | 42 | def forward(self, x_high, x_low): 43 | x_high = self.plus_conv(x_high) 44 | pa = self.pa(x_low) 45 | ca = self.ca(x_high) 46 | 47 | feat = x_low + x_high 48 | feat = self.end_conv(feat) 49 | feat = feat * ca 50 | feat = feat * pa 51 | return feat -------------------------------------------------------------------------------- /basicseg/test_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from basicseg.base_model import Base_model 4 | import copy 5 | from collections import OrderedDict 6 | from basicseg.metric import Binary_metric 7 | import torch.nn.functional as F 8 | 9 | class Test_model(Base_model): 10 | def __init__(self, opt): 11 | self.opt = opt 12 | self.init_model() 13 | 14 | def init_model(self): 15 | self.setup_net() 16 | self.metric = Binary_metric() 17 | 18 | def get_mean_metric(self, dist=False, reduction='mean'): 19 | if dist: 20 | return self.reduce_dict(self.metric.get_mean_result(), reduction) 21 | else: 22 | return self.dict_wrapper(self.metric.get_mean_result()) 23 | 24 | def get_norm_metric(self, dist=False, reduction='mean'): 25 | if dist: 26 | return self.reduce_dict(self.metric.get_norm_result(), reduction) 27 | else: 28 | return self.dict_wrapper(self.metric.get_norm_result()) 29 | def test_one_iter(self, data): 30 | with torch.no_grad(): 31 | img, mask = data 32 | img, mask = img.to(self.device), mask.to(self.device) 33 | pred = self.net(img) 34 | if isinstance(pred, (list, tuple)): 35 | pred = pred[0] 36 | pred = F.interpolate(pred, mask.shape[2:], mode='bilinear', 
align_corners=False)
37 |             self.metric.update(pred=pred, target=mask)
38 |             return pred
39 |     def infer_one_iter(self, data):
40 |         with torch.no_grad():
41 |             img = data
42 |             img = img.to(self.device)
43 |             pred = self.net(img)
44 |             if isinstance(pred, (list, tuple)):
45 |                 pred = pred[0]
46 |             # for loss_type, loss_fn in self.loss_fn.items():
47 |             #     loss = loss_fn(pred, mask)
48 |             #     self.epoch_loss[loss_type] += loss.detach().clone()
49 |             #     self.batch_loss[loss_type] = loss.detach().clone()
50 |             pred = F.interpolate(pred, img.shape[2:], mode='bilinear', align_corners=False)
51 |             # self.metric.update(pred=pred, target=mask)
52 |             return pred
--------------------------------------------------------------------------------
/basicseg/loss/basic_loss.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | from basicseg.utils.registry import LOSS_REGISTRY
 5 | 
 6 | @LOSS_REGISTRY.register()
 7 | class Iou_loss(nn.Module):
 8 |     def __init__(self, reduction='mean'):
 9 |         super().__init__()
10 |         self.reduction = reduction
11 |         self.eps = 1e-6
12 |     def forward(self, pred, target):
13 |         pred = torch.sigmoid(pred)
14 |         intersection = pred * target
15 |         intersection_sum = torch.sum(intersection, dim=(1,2,3))
16 |         pred_sum = torch.sum(pred, dim=(1,2,3))
17 |         target_sum = torch.sum(target, dim=(1,2,3))
18 |         iou = intersection_sum / (pred_sum + target_sum - intersection_sum + self.eps)
19 |         if self.reduction == 'mean':
20 |             return 1 - iou.mean()
21 |         elif self.reduction == 'sum':
22 |             return 1 - iou.sum()  # was iou.mean(); 'sum' now mirrors Dice_loss below
23 |         else:
24 |             raise NotImplementedError('reduction type {} not implemented'.format(self.reduction))
25 | 
26 | @LOSS_REGISTRY.register()
27 | class Dice_loss(nn.Module):
28 |     def __init__(self, reduction='mean'):
29 |         super().__init__()
30 |         self.reduction = reduction
31 |         self.eps = 1e-6
32 |     def forward(self, pred, target):
33 |         pred = torch.sigmoid(pred)
34 |         intersection = torch.sum(pred * target, dim=(1,2,3))
35 |         total_sum = torch.sum((pred + target), dim=(1,2,3))
36 |         dice = 2 * intersection / (total_sum + self.eps)
37 |         if self.reduction == 'mean':
38 |             return 1 - dice.mean()
39 |         elif self.reduction == 'sum':
40 |             return 1 - dice.sum()
41 |         else:
42 |             raise NotImplementedError('reduction type {} not implemented'.format(self.reduction))
43 | 
44 | @LOSS_REGISTRY.register()
45 | class Bce_loss(nn.Module):
46 |     def __init__(self, reduction='mean'):
47 |         super().__init__()
48 |         self.reduction = reduction
49 |         self.eps = 1e-6
50 |     def forward(self, pred, target):
51 |         loss_fn = nn.BCEWithLogitsLoss(reduction=self.reduction)
52 |         return loss_fn(pred, target)
53 | 
54 | @LOSS_REGISTRY.register()
55 | class L1_loss(nn.Module):
56 |     def __init__(self, reduction='mean'):
57 |         super().__init__()
58 |         self.reduction = reduction
59 |         self.eps = 1e-6
60 |     def forward(self, pred, target):
61 |         loss_fn = nn.L1Loss(reduction=self.reduction)
62 |         return loss_fn(pred, target)
--------------------------------------------------------------------------------
/basicseg/utils/registry.py:
--------------------------------------------------------------------------------
 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
 2 | 
 3 | 
 4 | class Registry():
 5 |     """
 6 |     The registry that provides name -> object mapping, to support third-party
 7 |     users' custom modules.
 8 | 
 9 |     To create a registry (e.g. a backbone registry):
10 | 
11 |     ..
code-block:: python
12 | 
13 |         BACKBONE_REGISTRY = Registry('BACKBONE')
14 | 
15 |     To register an object:
16 | 
17 |     .. code-block:: python
18 | 
19 |         @BACKBONE_REGISTRY.register()
20 |         class MyBackbone():
21 |             ...
22 | 
23 |     Or:
24 | 
25 |     .. code-block:: python
26 | 
27 |         BACKBONE_REGISTRY.register(MyBackbone)
28 |     """
29 | 
30 |     def __init__(self, name):
31 |         """
32 |         Args:
33 |             name (str): the name of this registry
34 |         """
35 |         self._name = name
36 |         self._obj_map = {}
37 | 
38 |     def _do_register(self, name, obj, suffix=None):
39 |         if isinstance(suffix, str):
40 |             name = name + '_' + suffix
41 | 
42 |         assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
43 |                                              f"in '{self._name}' registry!")
44 |         self._obj_map[name] = obj
45 | 
46 |     def register(self, obj=None, suffix=None):
47 |         """
48 |         Register the given object under the name `obj.__name__`.
49 |         Can be used as either a decorator or not.
50 |         See docstring of this class for usage.
51 |         """
52 |         if obj is None:
53 |             # used as a decorator
54 |             def deco(func_or_class):
55 |                 name = func_or_class.__name__
56 |                 self._do_register(name, func_or_class, suffix)
57 |                 return func_or_class
58 | 
59 |             return deco
60 | 
61 |         # used as a function call
62 |         name = obj.__name__
63 |         self._do_register(name, obj, suffix)
64 | 
65 |     def get(self, name, suffix='basicseg'):
66 |         ret = self._obj_map.get(name)
67 |         if ret is None:
68 |             ret = self._obj_map.get(name + '_' + suffix)
69 |             print(f'Name {name} is not found, use name: {name}_{suffix}!')
70 |         if ret is None:
71 |             raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
72 |         return ret
73 | 
74 |     def __contains__(self, name):
75 |         return name in self._obj_map
76 | 
77 |     def __iter__(self):
78 |         return iter(self._obj_map.items())
79 | 
80 |     def keys(self):
81 |         return self._obj_map.keys()
82 | 
83 | 
84 | DATASET_REGISTRY = Registry('dataset')
85 | NET_REGISTRY = Registry('net')
86 | LOSS_REGISTRY = Registry('loss')
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.utils.data as Data
 4 | import cv2
 5 | import numpy as np
 6 | from tqdm import tqdm
 7 | from basicseg.test_model import Test_model
 8 | from basicseg.utils.yaml_options import parse_options, dict2str
 9 | from basicseg.utils.path_utils import *
10 | from basicseg.data import build_dataset
11 | 
12 | def init_dataset(opt):
13 |     test_opt = opt['dataset']['test']
14 |     testset = build_dataset(test_opt)
15 |     return testset
16 | 
17 | def init_dataloader(opt, testset):
18 |     test_loader = Data.DataLoader(dataset=testset, batch_size=opt['exp']['bs'],
19 |                                   sampler=None, num_workers=opt['exp'].get('nw', 8))
20 |     return test_loader
21 | 
22 | def tensor2img(inp):
23 |     # [b,1,h,w] -> [b,h,w] -> cpu -> numpy.array -> np.uint8
24 |     # we don't binarize here;
25 |     # if you want the output to contain only 0 and 255, modify the code here
26 |     inp = torch.sigmoid(inp) * 255.
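    # (added note, illustrative) to emit a hard 0/255 mask instead of soft
    # probabilities, threshold before scaling, e.g.:
    #   inp = (torch.sigmoid(inp) > 0.5).float() * 255.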
27 |     inp = inp.squeeze(1).cpu().numpy().astype(np.uint8)
28 |     return inp
29 | 
30 | def save_batch_img(imgs, img_names, dire):
31 |     for i in range(len(imgs)):
32 |         img = imgs[i]
33 |         img_name = img_names[i]
34 |         img_path = os.path.join(dire, img_name)
35 |         cv2.imwrite(img_path, img)
36 | 
37 | def main():
38 |     opt, args = parse_options()
39 |     os.environ['CUDA_VISIBLE_DEVICES'] = str(opt['exp']['device'])
40 |     # init dataset
41 |     testset = init_dataset(opt)
42 |     test_loader = init_dataloader(opt, testset)
43 |     # initialize parameters including network, optimizer, loss function, learning rate scheduler
44 |     model = Test_model(opt)
45 |     save_dir = opt['exp'].get('save_dir', False)
46 |     if save_dir and not os.path.exists(save_dir):
47 |         os.makedirs(save_dir)
48 |     # load model params
49 |     if opt.get('resume'):
50 |         if opt['resume'].get('net_path'):
51 |             model.load_network(model.net, opt['resume']['net_path'])
52 |             print(f'load pretrained network from: {opt["resume"]["net_path"]}')
53 | 
54 |     model.net.eval()
55 |     for idx, data in enumerate(tqdm(test_loader)):
56 |         img, label, img_name = data
57 |         with torch.no_grad():
58 |             pred = model.test_one_iter((img, label))
59 |         if save_dir:
60 |             img_np = tensor2img(pred)
61 |             save_batch_img(img_np, img_name, save_dir)
62 |     test_mean_metric = model.get_mean_metric()
63 |     test_norm_metric = model.get_norm_metric()
64 |     ########## testing done ##########
65 |     print(f"best_mean_metric: [miou: {test_mean_metric['iou']:.4f}] [mfscore: {test_mean_metric['fscore']:.4f}]")
66 |     print(f"best_norm_metric: [niou: {test_norm_metric['iou']:.4f}] [nfscore: {test_norm_metric['fscore']:.4f}]")
67 | 
68 | if __name__ == '__main__':
69 |     main()
--------------------------------------------------------------------------------
/basicseg/metric/iou_fscore.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | import torch.nn as nn
 4 | import torch.nn.functional as F
 5 | from copy import deepcopy
 6 | 
 7 | class Binary_metric():
 8 |     'calculate fscore and iou'
 9 |     def __init__(self, thr=0.5):
10 |         self.mean_reset()
11 |         self.norm_reset()
12 |         self.thr = thr
13 |         self.cnt = 0
14 | 
15 |     def mean_reset(self):
16 |         self.tp = 0
17 |         self.tn = 0
18 |         self.fp = 0
19 |         self.fn = 0
20 | 
21 |     def norm_reset(self):
22 |         self.norm_metric = {'precision':0., 'recall':0., 'fscore':0., 'iou':0.}
23 |         self.cnt = 0.
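    # (added note) Two flavours of every metric are tracked:
    #  - "mean": tp/fp/fn are pooled over the whole test set before computing
    #    precision/recall/fscore/iou in mean_compute() (dataset-level scores);
    #  - "norm": per-image scores from norm_compute() are accumulated in update()
    #    and divided by the image count in get_norm_result() (nIoU-style averages).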
24 | 25 | def update(self, pred, target): 26 | # for safety 27 | pred = pred.detach().clone() 28 | target = target.detach().clone() 29 | if torch.max(pred) > 1.1: 30 | pred = torch.sigmoid(pred) 31 | pred[pred >= self.thr] = 1 32 | pred[pred < self.thr] = 0 33 | self.cur_tp = torch.sum((pred == target) * target, dim=(1,2,3)) 34 | self.cur_tn = torch.sum((pred == target) * (1 - target), dim=(1,2,3)) 35 | self.cur_fp = torch.sum((pred != target) * pred, dim=(1,2,3)) 36 | self.cur_fn = torch.sum((pred != target) * (1 - pred), dim=(1,2,3)) 37 | self.tp += self.cur_tp.sum() 38 | self.tn += self.cur_tn.sum() 39 | self.fp += self.cur_fp.sum() 40 | self.fn += self.cur_fn.sum() 41 | norm_result = self.norm_compute() 42 | for k in self.norm_metric.keys(): 43 | self.norm_metric[k] += norm_result[k] 44 | self.cnt += pred.shape[0] 45 | 46 | def get_mean_result(self): 47 | mean_metric = self.mean_compute() 48 | self.mean_reset() 49 | return mean_metric 50 | 51 | def get_norm_result(self): 52 | for k,v in self.norm_metric.items(): 53 | self.norm_metric[k] /= self.cnt 54 | norm_metric = deepcopy(self.norm_metric) 55 | self.norm_reset() 56 | return norm_metric 57 | 58 | def norm_compute(self): 59 | eps = 1e-6 60 | precision = (self.cur_tp / (self.cur_tp + self.cur_fp + eps)).sum() 61 | recall = (self.cur_tp / (self.cur_tp + self.cur_fn + eps)).sum() 62 | fscore = (2 * precision * recall / (precision + recall + eps)).sum() 63 | iou = (self.cur_tp / (self.cur_tp + self.cur_fn + self.cur_fp + eps)).sum() 64 | return {"precision":precision, "recall":recall, "fscore":fscore, "iou":iou} 65 | 66 | def mean_compute(self): 67 | eps = 1e-6 68 | precision = self.tp / (self.tp + self.fp + eps) 69 | recall = self.tp / (self.tp + self.fn + eps) 70 | fscore = 2 * precision * recall / (precision + recall + eps) 71 | iou = self.tp / (self.tp + self.fn + self.fp + eps) 72 | return {"precision":precision, "recall":recall, "fscore":fscore, "iou":iou} -------------------------------------------------------------------------------- /basicseg/networks/common/upernet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Upernet_head(nn.Module): 6 | def __init__(self, in_planes=[64, 128, 256, 512]): 7 | super().__init__() 8 | scale = [1, 2, 3, 6] 9 | fpn_dim = 256 10 | self.ppm_module = nn.ModuleList() 11 | for i in range(len(scale)): 12 | self.ppm_module.append(nn.Sequential( 13 | nn.AdaptiveAvgPool2d(scale[i]), 14 | nn.Conv2d(in_planes[-1], 512, kernel_size=1, bias=False), 15 | nn.BatchNorm2d(512), 16 | nn.ReLU(inplace=True) 17 | )) 18 | self.ppm_fuse = nn.Sequential( 19 | nn.Conv2d(in_planes[-1] + len(scale) * 512, fpn_dim, kernel_size=1, bias=False), 20 | nn.BatchNorm2d(fpn_dim), 21 | nn.ReLU(inplace=True) 22 | ) 23 | self.fpn_module = nn.ModuleList() 24 | self.fpn_smooth = nn.ModuleList() 25 | for i in range(len(in_planes) - 1): 26 | self.fpn_module.append(nn.Sequential( 27 | nn.Conv2d(in_planes[i], fpn_dim, kernel_size=1, bias=False), 28 | nn.BatchNorm2d(fpn_dim), 29 | nn.ReLU(inplace=True) 30 | )) 31 | self.fpn_smooth.append(nn.Sequential( 32 | nn.Conv2d(fpn_dim, fpn_dim, kernel_size=1, bias=False), 33 | nn.BatchNorm2d(fpn_dim), 34 | nn.ReLU(inplace=True) 35 | )) 36 | 37 | self.fpn_fuse = nn.Sequential( 38 | nn.Conv2d(len(in_planes) * fpn_dim, fpn_dim, kernel_size=1, bias=False), 39 | nn.BatchNorm2d(fpn_dim), 40 | nn.ReLU(inplace=True) 41 | ) 42 | 43 | def forward(self, x): 44 | feat_maps = x 45 | 
feat_top = feat_maps[-1] 46 | ppm_size = feat_top.shape[-2:] 47 | ppm_out = [] 48 | ppm_out.append(feat_top) 49 | for i in range(len(self.ppm_module)): 50 | out = self.ppm_module[i](feat_top) 51 | ppm_out.append(F.interpolate(out, size=ppm_size, mode='bilinear', align_corners=False)) 52 | ppm_out = torch.cat(ppm_out, dim=1) 53 | ppm_out = self.ppm_fuse(ppm_out) 54 | 55 | fpn_out = [] 56 | fpn_out.append(ppm_out) 57 | f = ppm_out 58 | for i in reversed(range(len(self.fpn_module))): 59 | size = feat_maps[i].shape[-2:] 60 | out = self.fpn_module[i](feat_maps[i]) 61 | f = out + F.interpolate(f, size, mode='bilinear', align_corners=False) 62 | fpn_out.append(self.fpn_smooth[i](f)) 63 | fpn_out.reverse() 64 | fpn_fush = [] 65 | fpn_fush.append(fpn_out[0]) 66 | for i in range(1, len(fpn_out)): 67 | size = fpn_out[0].shape[-2:] 68 | fpn_fush.append(F.interpolate(fpn_out[i], size, mode='bilinear', align_corners=False)) 69 | out = torch.cat(fpn_fush, dim=1) 70 | out = self.fpn_fuse(out) 71 | return out -------------------------------------------------------------------------------- /basicseg/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | 8 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 9 | import functools 10 | import os 11 | import subprocess 12 | import torch 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | 16 | 17 | def init_dist(launcher, backend='nccl', **kwargs): 18 | if mp.get_start_method(allow_none=True) is None: 19 | mp.set_start_method('spawn') 20 | if launcher == 'pytorch': 21 | _init_dist_pytorch(backend, **kwargs) 22 | elif launcher == 'slurm': 23 | _init_dist_slurm(backend, **kwargs) 24 | else: 25 | raise ValueError(f'Invalid launcher type: {launcher}') 26 | 27 | 28 | def _init_dist_pytorch(backend, **kwargs): 29 | rank = int(os.environ['RANK']) 30 | num_gpus = torch.cuda.device_count() 31 | torch.cuda.set_device(rank % num_gpus) 32 | dist.init_process_group(backend=backend, **kwargs) 33 | 34 | 35 | def _init_dist_slurm(backend, port=None): 36 | """Initialize slurm distributed training environment. 37 | 38 | If argument ``port`` is not specified, then the master port will be system 39 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 40 | environment variable, then a default port ``29500`` will be used. 41 | 42 | Args: 43 | backend (str): Backend of torch.distributed. 44 | port (int, optional): Master port. Defaults to None. 
45 | """ 46 | proc_id = int(os.environ['SLURM_PROCID']) 47 | ntasks = int(os.environ['SLURM_NTASKS']) 48 | node_list = os.environ['SLURM_NODELIST'] 49 | num_gpus = torch.cuda.device_count() 50 | torch.cuda.set_device(proc_id % num_gpus) 51 | addr = subprocess.getoutput( 52 | f'scontrol show hostname {node_list} | head -n1') 53 | # specify master port 54 | if port is not None: 55 | os.environ['MASTER_PORT'] = str(port) 56 | elif 'MASTER_PORT' in os.environ: 57 | pass # use MASTER_PORT in the environment variable 58 | else: 59 | # 29500 is torch.distributed default port 60 | os.environ['MASTER_PORT'] = '29500' 61 | os.environ['MASTER_ADDR'] = addr 62 | os.environ['WORLD_SIZE'] = str(ntasks) 63 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 64 | os.environ['RANK'] = str(proc_id) 65 | dist.init_process_group(backend=backend) 66 | 67 | 68 | def get_dist_info(): 69 | if dist.is_available(): 70 | initialized = dist.is_initialized() 71 | else: 72 | initialized = False 73 | if initialized: 74 | rank = dist.get_rank() 75 | world_size = dist.get_world_size() 76 | else: 77 | rank = 0 78 | world_size = 1 79 | return rank, world_size 80 | 81 | 82 | def master_only(func): 83 | @functools.wraps(func) 84 | def wrapper(*args, **kwargs): 85 | rank, _ = get_dist_info() 86 | if rank == 0: 87 | return func(*args, **kwargs) 88 | 89 | return wrapper 90 | -------------------------------------------------------------------------------- /basicseg/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from basicseg.utils.dist_util import master_only 4 | import time 5 | 6 | @master_only 7 | def make_dir(root): 8 | if not os.path.exists(root): 9 | os.makedirs(root) 10 | 11 | def get_time_str(): 12 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 13 | 14 | def make_exp_root(root): 15 | exp_root = os.path.join(root, get_time_str()) 16 | make_dir(exp_root) 17 | return exp_root 18 | 19 | 20 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 21 | """Scan a directory to find the interested files. 22 | 23 | Args: 24 | dir_path (str): Path of the directory. 25 | suffix (str | tuple(str), optional): File suffix that we are 26 | interested in. Default: None. 27 | recursive (bool, optional): If set to True, recursively scan the 28 | directory. Default: False. 29 | full_path (bool, optional): If set to True, include the dir_path. 30 | Default: False. 31 | 32 | Returns: 33 | A generator for all the interested files with relative pathes. 34 | """ 35 | 36 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 37 | raise TypeError('"suffix" must be a string or tuple of strings') 38 | 39 | root = dir_path 40 | 41 | def _scandir(dir_path, suffix, recursive): 42 | for entry in os.scandir(dir_path): 43 | if not entry.name.startswith('.') and entry.is_file(): 44 | if full_path: 45 | return_path = entry.path 46 | else: 47 | return_path = os.path.relpath(entry.path, root) 48 | 49 | if suffix is None: 50 | yield return_path 51 | elif return_path.endswith(suffix): 52 | yield return_path 53 | else: 54 | if recursive: 55 | yield from _scandir( 56 | entry.path, suffix=suffix, recursive=recursive) 57 | else: 58 | continue 59 | 60 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 61 | 62 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False): 63 | """Scan a directory to find the interested files. 64 | 65 | Args: 66 | dir_path (str): Path of the directory. 
67 | keywords (str | tuple(str), optional): File keywords that we are 68 | interested in. Default: None. 69 | recursive (bool, optional): If set to True, recursively scan the 70 | directory. Default: False. 71 | full_path (bool, optional): If set to True, include the dir_path. 72 | Default: False. 73 | 74 | Returns: 75 | A generator for all the interested files with relative pathes. 76 | """ 77 | 78 | if (keywords is not None) and not isinstance(keywords, (str, tuple)): 79 | raise TypeError('"keywords" must be a string or tuple of strings') 80 | 81 | root = dir_path 82 | 83 | def _scandir(dir_path, keywords, recursive): 84 | for entry in os.scandir(dir_path): 85 | if not entry.name.startswith('.') and entry.is_file(): 86 | if full_path: 87 | return_path = entry.path 88 | else: 89 | return_path = os.path.relpath(entry.path, root) 90 | 91 | if keywords is None: 92 | yield return_path 93 | elif return_path.find(keywords) > 0: 94 | yield return_path 95 | else: 96 | if recursive: 97 | yield from _scandir( 98 | entry.path, keywords=keywords, recursive=recursive) 99 | else: 100 | continue 101 | 102 | return _scandir(dir_path, keywords=keywords, recursive=recursive) 103 | 104 | if __name__ == '__main__': 105 | time_ = get_time_str() 106 | print(time_) -------------------------------------------------------------------------------- /basicseg/utils/yaml_options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from collections import OrderedDict 3 | import argparse 4 | 5 | def ordered_yaml(): 6 | """Support OrderedDict for yaml. 7 | 8 | Returns: 9 | yaml Loader and Dumper. 10 | """ 11 | try: 12 | from yaml import CDumper as Dumper 13 | from yaml import CLoader as Loader 14 | except ImportError: 15 | from yaml import Dumper, Loader 16 | 17 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 18 | 19 | def dict_representer(dumper, data): 20 | return dumper.represent_dict(data.items()) 21 | 22 | def dict_constructor(loader, node): 23 | return OrderedDict(loader.construct_pairs(node)) 24 | 25 | Dumper.add_representer(OrderedDict, dict_representer) 26 | Loader.add_constructor(_mapping_tag, dict_constructor) 27 | return Loader, Dumper 28 | 29 | def dict2str(opt, indent_level=1): 30 | """dict to string for printing options. 31 | 32 | Args: 33 | opt (dict): Option dict. 34 | indent_level (int): Indent level. Default: 1. 35 | 36 | Return: 37 | (str): Option string for printing. 38 | """ 39 | msg = '\n' 40 | for k, v in opt.items(): 41 | if isinstance(v, dict): 42 | msg += ' ' * (indent_level * 2) + k + ':[' 43 | msg += dict2str(v, indent_level + 1) 44 | msg += ' ' * (indent_level * 2) + ']\n' 45 | else: 46 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 47 | return msg 48 | 49 | def parse_options(): 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('--opt', type=str, required = True, help='Path to option YAML file.') 52 | parser.add_argument('--local_rank', type=int, default=-1) 53 | parser.add_argument('--device', default='0') 54 | parser.add_argument( 55 | '--force_yml', nargs='+', default=None, help='Force to update yml files. 
Examples: train:ema_decay=0.999') 56 | args = parser.parse_args() 57 | 58 | # parse yml to dict 59 | with open(args.opt, mode='r') as f: 60 | opt = yaml.load(f, Loader=ordered_yaml()[0]) 61 | # opt['rank'], opt['world_size'] = get_dist_info() 62 | 63 | 64 | # force to update yml options 65 | # if args.force_yml is not None: 66 | # for entry in args.force_yml: 67 | # # now do not support creating new keys 68 | # keys, value = entry.split('=') 69 | # keys, value = keys.strip(), value.strip() 70 | # value = _postprocess_yml_value(value) 71 | # eval_str = 'opt' 72 | # for key in keys.split(':'): 73 | # eval_str += f'["{key}"]' 74 | # eval_str += '=value' 75 | # # using exec function 76 | # exec(eval_str) 77 | 78 | # datasets 79 | # for phase, dataset in opt['datasets'].items(): 80 | # # for multiple datasets, e.g., val_1, val_2; test_1, test_2 81 | # phase = phase.split('_')[0] 82 | # dataset['phase'] = phase 83 | # if 'scale' in opt: 84 | # dataset['scale'] = opt['scale'] 85 | # if dataset.get('dataroot_gt') is not None: 86 | # dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 87 | # if dataset.get('dataroot_lq') is not None: 88 | # dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 89 | 90 | # paths 91 | # for key, val in opt['path'].items(): 92 | # if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 93 | # opt['path'][key] = osp.expanduser(val) 94 | 95 | # if is_train: 96 | # experiments_root = osp.join(root_path, 'experiments', opt['name']) 97 | # opt['path']['experiments_root'] = experiments_root 98 | # opt['path']['models'] = osp.join(experiments_root, 'models') 99 | # opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 100 | # opt['path']['log'] = experiments_root 101 | # opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 102 | 103 | return opt, args -------------------------------------------------------------------------------- /basicseg/data/aug_fn.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | from albumentations.pytorch import ToTensorV2 3 | import cv2 4 | import random 5 | 6 | # using for testset, only resize the img 7 | 8 | def aug_transform_test(opt): 9 | transform_fn = [] 10 | transform_fn.append(A.Resize(opt['img_sz'], opt['img_sz'])) 11 | transform_fn.append(A.Normalize()) 12 | transform_fn.append(ToTensorV2()) 13 | transform_fn = A.Compose(transform_fn, is_check_shapes=False) 14 | return transform_fn 15 | 16 | def aug_transform_s(opt): 17 | transform_fn = [] 18 | transform_fn.append(A.Resize(opt['img_sz'], opt['img_sz'])) 19 | H_Flip = opt.get('H_Flip', False) 20 | if H_Flip: 21 | transform_fn.append(A.HorizontalFlip(p=H_Flip)) 22 | V_Flip = opt.get('V_Flip', False) 23 | if V_Flip: 24 | transform_fn.append(A.VerticalFlip(p=V_Flip)) 25 | transform_fn.append(A.Normalize()) 26 | transform_fn.append(ToTensorV2()) 27 | transform_fn = A.Compose(transform_fn, is_check_shapes=False) 28 | return transform_fn 29 | 30 | def aug_transform_bac(opt): 31 | img_sz = opt['img_sz'] 32 | random_sz = random.randint(int(img_sz * 0.5), int(img_sz * 1.5)) 33 | transform_fn = [] 34 | transform_fn.append(A.HorizontalFlip(p=0.5)) 35 | transform_fn.append((A.VerticalFlip(p=0.2))) 36 | transform_fn.append(A.LongestMaxSize(random_sz)) 37 | transform_fn.append(A.PadIfNeeded(img_sz, img_sz)) 38 | transform_fn.append(A.RandomCrop(img_sz, img_sz)) 39 | transform_fn.append(A.Normalize()) 40 | transform_fn.append(ToTensorV2()) 41 | 
transform_fn = A.Compose(transform_fn, is_check_shapes=False)
 42 |     return transform_fn
 43 | 
 44 | 
 45 | def aug_transform_m(opt):
 46 |     img_sz = opt['img_sz']
 47 |     transform = [
 48 |         A.Resize(img_sz, img_sz),
 49 |         A.PadIfNeeded(min_height=img_sz, min_width=img_sz),
 50 |         A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=0.5),
 51 |         A.RandomCrop(height=img_sz, width=img_sz),
 52 |         A.HorizontalFlip(p=0.5),
 53 |         A.VerticalFlip(p=0.1),
 54 |         A.Normalize(),
 55 |         ToTensorV2(),
 56 |     ]
 57 |     transform_fn = A.Compose(transform)
 58 |     return transform_fn
 59 | 
 60 | def aug_transform_l():
 61 |     pass
 62 | 
 63 | def aug_transform_train(opt):
 64 |     img_sz = opt['img_sz']  # make sure the final output size is opt['img_sz']
 65 |     transform_fn = []
 66 | 
 67 |     # geometric transforms
 68 |     transform_fn.append(A.HorizontalFlip(p=0.5))  # random horizontal flip
 69 |     transform_fn.append(A.VerticalFlip(p=0.1))  # random vertical flip
 70 |     transform_fn.append(A.ShiftScaleRotate(shift_limit=0.03, scale_limit=0.05, rotate_limit=10, p=0.5))  # shift, scale and rotate
 71 | 
 72 |     # resizing and cropping (fix the output size)
 73 |     transform_fn.append(A.Resize(height=img_sz, width=img_sz))  # force a fixed size
 74 |     transform_fn.append(A.PadIfNeeded(min_height=img_sz, min_width=img_sz, border_mode=0, value=0))  # pad with zeros up to the fixed size (if necessary)
 75 | 
 76 |     # color augmentation
 77 |     transform_fn.append(A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5))  # random brightness/contrast jitter
 78 | 
 79 |     # noise augmentation
 80 |     transform_fn.append(A.GaussNoise(var_limit=(5.0, 15.0), p=0.3))  # Gaussian noise
 81 |     transform_fn.append(A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.3))  # simulate infrared sensor noise
 82 | 
 83 |     # normalization and tensor conversion
 84 |     transform_fn.append(A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))  # normalization for RGB infrared images
 85 |     transform_fn.append(ToTensorV2())
 86 | 
 87 |     transform_fn = A.Compose(transform_fn, is_check_shapes=False)
 88 |     return transform_fn
 89 | 
 90 | def aug_transform_test_new(opt):
 91 |     img_sz = opt['img_sz']  # make sure the final output size is opt['img_sz']
 92 |     transform_fn = []
 93 | 
 94 |     # resizing
 95 |     transform_fn.append(A.Resize(height=img_sz, width=img_sz))  # force a fixed size
 96 |     transform_fn.append(A.PadIfNeeded(min_height=img_sz, min_width=img_sz, border_mode=0, value=0))  # pad with zeros up to the fixed size (if necessary)
 97 | 
 98 |     # normalization and tensor conversion
 99 |     transform_fn.append(A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))  # normalization for RGB infrared images
100 |     transform_fn.append(ToTensorV2())
101 | 
102 |     transform_fn = A.Compose(transform_fn, is_check_shapes=False)
103 |     return transform_fn
--------------------------------------------------------------------------------
/basicseg/networks/common/attention.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import math
 4 | import torch.nn.functional as F
 5 | 
 6 | class Position_attention(nn.Module):
 7 |     def __init__(self, in_c, mid_c=None):
 8 |         super().__init__()
 9 |         mid_c = mid_c or in_c // 8
10 |         self.q = nn.Conv2d(in_c, mid_c, kernel_size=1)
11 |         self.k = nn.Conv2d(in_c, mid_c, kernel_size=1)
12 |         self.v = nn.Conv2d(in_c, in_c, kernel_size=1)
13 |         self.gamma = nn.Parameter(torch.zeros(1))
14 |         self.softmax = nn.Softmax(dim=-1)
15 | 
16 |     def forward(self, x):
17 |         b, _, h, w = x.shape
18 |         q = self.q(x).view(b, -1, h * w).permute(0, 2, 1)  # bs, hw, c
19 |         k = self.k(x).view(b, -1, h * w)  # bs, c, hw
20 |         v = self.v(x).view(b, -1, h * w)  # bs, c, hw
21 |         # att = self.softmax(q @ k)
22 |         att = self.softmax(torch.bmm(q, k))
23 |         # out = (v @ att.permute(0, 2, 1)).view(b, -1, h, w)
24 |         out = torch.bmm(v, att.permute(0, 2, 1)).view(b, -1, h, w)
25 |         out = self.gamma * out + x
26 | 
27 |         return out
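# (added note) Position_attention is shape-preserving: for x of shape (B, C, H, W)
# it returns (B, C, H, W), attending over the H*W spatial positions; since gamma is
# initialized to zero, the block starts out as an identity mapping.
# Minimal check (illustrative):
#   pam = Position_attention(in_c=64)
#   assert pam(torch.rand(2, 64, 32, 32)).shape == (2, 64, 32, 32)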
28 | 29 | 30 | class Channel_attention(nn.Module): 31 | def __init__(self, in_c): 32 | super().__init__() 33 | self.in_c = in_c 34 | self.softmax = nn.Softmax(dim=-1) 35 | self.gamma = nn.Parameter(torch.zeros(1)) 36 | 37 | def forward(self, x): 38 | b, _, h, w = x.shape 39 | q = x.view(b, -1, h * w) # bs, c ,hw 40 | k = x.view(b, -1, h * w).permute(0, 2, 1) # bs, hw, c 41 | v = x.view(b, -1, h * w) # bs, c, hw 42 | att = self.softmax(q @ k) # b, c, c 43 | out = att @ v 44 | out = out.view(b, -1, h, w) 45 | out = self.gamma * out + x 46 | return out 47 | 48 | 49 | class Double_attention(nn.Module): 50 | def __init__(self, in_c, mid_c=None): 51 | super().__init__() 52 | self.pam = Position_attention(in_c, mid_c) 53 | self.cam = Channel_attention(in_c) 54 | self.relu = nn.ReLU() 55 | 56 | def forward(self, x): 57 | pam_out = self.pam(x) 58 | cam_out = self.cam(x) 59 | return pam_out + cam_out 60 | 61 | 62 | class External_attention(nn.Module): 63 | ''' 64 | Arguments: 65 | c (int): The input and output channel number. 66 | ''' 67 | 68 | def __init__(self, c): 69 | super(External_attention, self).__init__() 70 | 71 | self.conv1 = nn.Conv2d(c, c, 1) 72 | 73 | self.k = 64 74 | self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False) 75 | 76 | self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False) 77 | self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2) 78 | 79 | self.conv2 = nn.Sequential( 80 | nn.Conv2d(c, c, 1, bias=False), 81 | nn.BatchNorm2d(c)) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | m.weight.data.normal_(0, math.sqrt(2. / n)) 87 | elif isinstance(m, nn.Conv1d): 88 | n = m.kernel_size[0] * m.out_channels 89 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
 90 |             elif isinstance(m, nn.BatchNorm2d):
 91 |                 m.weight.data.fill_(1)
 92 |                 if m.bias is not None:
 93 |                     m.bias.data.zero_()
 94 | 
 95 |     def forward(self, x):
 96 |         idn = x
 97 |         x = self.conv1(x)
 98 | 
 99 |         b, c, h, w = x.size()
100 |         n = h * w
101 |         x = x.view(b, c, h * w)  # b * c * n
102 | 
103 |         attn = self.linear_0(x)  # b, k, n
104 |         attn = F.softmax(attn, dim=-1)  # b, k, n
105 | 
106 |         attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True))  # b, k, n
107 |         x = self.linear_1(attn)  # b, c, n
108 | 
109 |         x = x.view(b, c, h, w)
110 |         x = self.conv2(x)
111 |         x = x + idn
112 |         x = F.relu(x)
113 |         return x
114 | 
115 | def main():
116 |     x = torch.rand(3, 512, 64, 64)
117 |     EA = External_attention(512)  # was 256; the channel count must match the input
118 |     out = EA(x)
119 |     print(out.shape)
120 | 
121 | if __name__ == '__main__':
122 |     from fvcore.nn import FlopCountAnalysis, parameter_count_table
123 |     model = Double_attention(512)
124 |     x = torch.rand(1, 512, 32, 32)
125 |     flopts = FlopCountAnalysis(model, x)
126 |     print('FLOPS: ', flopts.total())
127 |     print('PARAMS: ', parameter_count_table(model))
128 |     import ptflops
129 |     GMacs, Params = ptflops.get_model_complexity_info(model, (512, 32, 32))
130 |     print(GMacs, Params)
--------------------------------------------------------------------------------
/basicseg/networks/common/Dilatedconv.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import random
 4 | 
 5 | class MultidilatedConv(nn.Module):
 6 |     def __init__(self, in_dim, out_dim, kernel_size=3, dilation_num=3, comb_mode='sum', equal_dim=True,
 7 |                  shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
 8 |         super().__init__()
 9 |         convs = []
10 |         self.equal_dim = equal_dim
11 |         assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
12 |         if comb_mode in ('cat_out', 'cat_both'):
13 |             self.cat_out = True
14 |             if equal_dim:
15 |                 assert out_dim % dilation_num == 0
16 |                 out_dims = [out_dim // dilation_num] * dilation_num
17 |                 self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
18 |             else:
19 |                 out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
20 |                 out_dims.append(out_dim - sum(out_dims))
21 |                 index = []
22 |                 starts = [0] + out_dims[:-1]
23 |                 lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
24 |                 for i in range(out_dims[-1]):
25 |                     for j in range(dilation_num):
26 |                         index += list(range(starts[j], starts[j] + lengths[j]))
27 |                         starts[j] += lengths[j]
28 |                 self.index = index
29 |                 assert(len(index) == out_dim)
30 |             self.out_dims = out_dims
31 |         else:
32 |             self.cat_out = False
33 |             self.out_dims = [out_dim] * dilation_num
34 | 
35 |         if comb_mode in ('cat_in', 'cat_both'):
36 |             if equal_dim:
37 |                 assert in_dim % dilation_num == 0
38 |                 in_dims = [in_dim // dilation_num] * dilation_num
39 |             else:
40 |                 in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
41 |                 in_dims.append(in_dim - sum(in_dims))
42 |             self.in_dims = in_dims
43 |             self.cat_in = True
44 |         else:
45 |             self.cat_in = False
46 |             self.in_dims = [in_dim] * dilation_num
47 | 
48 |         conv_type = nn.Conv2d
49 |         dilation = min_dilation
50 |         for i in range(dilation_num):
51 |             if isinstance(padding, int):
52 |                 cur_padding = padding * dilation
53 |             else:
54 |                 cur_padding = padding[i]
55 |             convs.append(conv_type(
56 |                 self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
57 |             ))
58 |             if i > 0 and shared_weights:
59 |
convs[-1].weight = convs[0].weight 60 | convs[-1].bias = convs[0].bias 61 | dilation *= 2 62 | self.convs = nn.ModuleList(convs) 63 | 64 | self.shuffle_in_channels = shuffle_in_channels 65 | if self.shuffle_in_channels: 66 | # shuffle list as shuffling of tensors is nondeterministic 67 | in_channels_permute = list(range(in_dim)) 68 | random.shuffle(in_channels_permute) 69 | # save as buffer so it is saved and loaded with checkpoint 70 | self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute)) 71 | 72 | def forward(self, x): 73 | if self.shuffle_in_channels: 74 | x = x[:, self.in_channels_permute] 75 | 76 | outs = [] 77 | if self.cat_in: 78 | if self.equal_dim: 79 | x = x.chunk(len(self.convs), dim=1) 80 | else: 81 | new_x = [] 82 | start = 0 83 | for dim in self.in_dims: 84 | new_x.append(x[:, start:start+dim]) 85 | start += dim 86 | x = new_x 87 | for i, conv in enumerate(self.convs): 88 | if self.cat_in: 89 | input = x[i] 90 | else: 91 | input = x 92 | outs.append(conv(input)) 93 | if self.cat_out: 94 | out = torch.cat(outs, dim=1)[:, self.index] 95 | else: 96 | out = sum(outs) 97 | return out 98 | 99 | def main(): 100 | x = torch.rand(1,3,512,512) 101 | net = MultidilatedConv(3,64) 102 | out = net(x) 103 | print(out.shape) 104 | import ptflops 105 | macs,params = ptflops.get_model_complexity_info(net, (3,512,512)) 106 | print(macs, params) 107 | 108 | if __name__ == "__main__": 109 | main() -------------------------------------------------------------------------------- /basicseg/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from collections import Counter 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | class PolyLR(_LRScheduler): 7 | def __init__(self, optimizer, T_max, power=1, eta_min=0, last_epoch=-1, verbose=False): 8 | self.T_max = T_max 9 | self.power = power 10 | self.eta_min = eta_min 11 | super(PolyLR, self).__init__(optimizer, last_epoch, verbose) 12 | def get_lr(self): 13 | return [max((base_lr * (1 - (self.last_epoch / self.T_max))**self.power), self.eta_min) \ 14 | for base_lr in self.base_lrs] 15 | 16 | class MultiStepRestartLR(_LRScheduler): 17 | """ MultiStep with restarts learning rate scheme. 18 | 19 | Args: 20 | optimizer (torch.nn.optimizer): Torch optimizer. 21 | milestones (list): Iterations that will decrease learning rate. 22 | gamma (float): Decrease ratio. Default: 0.1. 23 | restarts (list): Restart iterations. Default: [0]. 24 | restart_weights (list): Restart weights at each restart iteration. 25 | Default: [1]. 26 | last_epoch (int): Used in _LRScheduler. Default: -1. 27 | """ 28 | 29 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 30 | self.milestones = Counter(milestones) 31 | self.gamma = gamma 32 | self.restarts = restarts 33 | self.restart_weights = restart_weights 34 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 
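        # (added note) at every iteration listed in `restarts`, get_lr() resets the
        # learning rate to initial_lr scaled by the matching `restart_weights` entry.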
35 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 36 | 37 | def get_lr(self): 38 | if self.last_epoch in self.restarts: 39 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 40 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 41 | if self.last_epoch not in self.milestones: 42 | return [group['lr'] for group in self.optimizer.param_groups] 43 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 44 | 45 | 46 | def get_position_from_periods(iteration, cumulative_period): 47 | """Get the position from a period list. 48 | 49 | It will return the index of the right-closest number in the period list. 50 | For example, the cumulative_period = [100, 200, 300, 400], 51 | if iteration == 50, return 0; 52 | if iteration == 210, return 2; 53 | if iteration == 300, return 2. 54 | 55 | Args: 56 | iteration (int): Current iteration. 57 | cumulative_period (list[int]): Cumulative period list. 58 | 59 | Returns: 60 | int: The position of the right-closest number in the period list. 61 | """ 62 | for i, period in enumerate(cumulative_period): 63 | if iteration <= period: 64 | return i 65 | 66 | 67 | class CosineAnnealingRestartLR(_LRScheduler): 68 | """ Cosine annealing with restarts learning rate scheme. 69 | 70 | An example of config: 71 | periods = [10, 10, 10, 10] 72 | restart_weights = [1, 0.5, 0.5, 0.5] 73 | eta_min=1e-7 74 | 75 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 76 | scheduler will restart with the weights in restart_weights. 77 | 78 | Args: 79 | optimizer (torch.nn.optimizer): Torch optimizer. 80 | periods (list): Period for each cosine anneling cycle. 81 | restart_weights (list): Restart weights at each restart iteration. 82 | Default: [1]. 83 | eta_min (float): The minimum lr. Default: 0. 84 | last_epoch (int): Used in _LRScheduler. Default: -1. 85 | """ 86 | 87 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 88 | self.periods = periods 89 | self.restart_weights = restart_weights 90 | self.eta_min = eta_min 91 | assert (len(self.periods) == len( 92 | self.restart_weights)), 'periods and restart_weights should have the same length.' 
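        # (added note) cumulative_period holds the running end iteration of each
        # cycle, e.g. periods=[10, 10, 10, 10] -> [10, 20, 30, 40];
        # get_position_from_periods() maps the current iteration to its cycle index.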
 93 |         self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
 94 |         super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
 95 | 
 96 |     def get_lr(self):
 97 |         idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
 98 |         current_weight = self.restart_weights[idx]
 99 |         nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
100 |         current_period = self.periods[idx]
101 | 
102 |         return [
103 |             self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
104 |             (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
105 |             for base_lr in self.base_lrs
106 |         ]
--------------------------------------------------------------------------------
/basicseg/data/load_data.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import os
 3 | import torch.utils.data as Data
 4 | import numpy as np
 5 | import cv2
 6 | from basicseg.data.aug_fn import *
 7 | from basicseg.utils.registry import DATASET_REGISTRY
 8 | from basicseg.loss.boundary_loss import get_dist_map
 9 | """
10 | //train
11 |     //images
12 |     //masks
13 | //test
14 |     //images
15 |     //masks
16 | 
17 | """
18 | class Basedataset(Data.Dataset):
19 |     def __init__(self, opt):
20 |         super().__init__()
21 |         self.opt = opt
22 |         self.image_root = os.path.join(opt['data_root'], 'images')
23 |         self.mask_root = os.path.join(opt['data_root'], 'masks')
24 |         self.images = os.listdir(self.image_root)
25 |         self.get_name = opt.get('get_name', False)
26 |         self.bd_loss = opt.get('bd_loss', False)
27 |     def __len__(self):
28 |         return len(self.images)
29 |     def setup_transform_fn(self):
30 |         return None
31 |     def __getitem__(self, index):
32 |         image_path = os.path.join(self.image_root, self.images[index])
33 |         mask_path = image_path.replace(self.image_root, self.mask_root)
34 |         img = np.array(cv2.imread(image_path, cv2.IMREAD_COLOR))
35 |         mask = np.array(cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE))
36 |         mask = (mask / 255.).astype(np.float32)
37 |         # # set the folders to save to
38 |         # save_dir_image = "/media/data2/zhengshuchen/code/BasicISOS/test_image/images"
39 |         # save_dir_mask = "/media/data2/zhengshuchen/code/BasicISOS/test_image/masks"
40 |         # # assume self.images[index] is the image file name (with extension, e.g. "image1.jpg")
41 |         # filename = self.images[index]
42 |         # # build the save paths
43 |         # image_save_path = os.path.join(save_dir_image, filename)
44 |         # mask_save_path = os.path.join(save_dir_mask, filename)
45 |         # # set the target size
46 |         # new_width = 512
47 |         # new_height = 512
48 |         # # read the original image and mask
49 |         # image1 = cv2.imread(image_path, cv2.IMREAD_COLOR)
50 |         # mask1 = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
51 |         # # resize the image and mask
52 |         # resized_image = cv2.resize(image1, (new_width, new_height))
53 |         # resized_mask = cv2.resize(mask1, (new_width, new_height))
54 |         # # save the image and mask
55 |         # cv2.imwrite(image_save_path, resized_image)
56 |         # cv2.imwrite(mask_save_path, resized_mask)
57 | 
58 |         transform_fn = self.setup_transform_fn()
59 |         if transform_fn:
60 |             aug_inputs = transform_fn(image=img, mask=mask)
61 |             img_aug, mask_aug = aug_inputs['image'], aug_inputs['mask']
62 |         if self.bd_loss:
63 |             dist_map = get_dist_map(mask_aug)
64 |         if len(mask_aug.shape) == 2:
65 |             mask_aug.unsqueeze_(dim=0)
66 |         if self.get_name:
67 |             return img_aug, mask_aug, self.images[index]
68 |         elif self.bd_loss:
69 |             return img_aug, mask_aug, dist_map
70 |         else:
71 |             return img_aug, mask_aug
72 | 
73 | @DATASET_REGISTRY.register()
74 | class Dataset_test(Basedataset):
75 |     def __init__(self, opt):
76 |
super().__init__(opt) 77 | def setup_transform_fn(self): 78 | return aug_transform_test_new(self.opt) 79 | 80 | @DATASET_REGISTRY.register() 81 | class Dataset_aug_s(Basedataset): 82 | def __init__(self, opt): 83 | super().__init__(opt) 84 | def setup_transform_fn(self): 85 | return aug_transform_s(self.opt) 86 | 87 | @DATASET_REGISTRY.register() 88 | class Dataset_aug_m(Basedataset): 89 | def __init__(self, opt): 90 | super().__init__(opt) 91 | def setup_transform_fn(self): 92 | return aug_transform_m(self.opt) 93 | 94 | @DATASET_REGISTRY.register() 95 | class Dataset_aug_bac(Basedataset): 96 | def __init__(self, opt): 97 | super().__init__(opt) 98 | def setup_transform_fn(self): 99 | return aug_transform_train(self.opt) 100 | 101 | 102 | @DATASET_REGISTRY.register() 103 | class Dataset_infer(Data.Dataset): 104 | def __init__(self, opt): 105 | super().__init__() 106 | self.opt = opt 107 | self.image_root = opt['data_root'] 108 | self.images = os.listdir(self.image_root) 109 | self.get_name = opt.get('get_name', False) 110 | def __len__(self): 111 | return len(self.images) 112 | def setup_transform_fn(self): 113 | return aug_transform_test(opt=self.opt) 114 | def __getitem__(self, index): 115 | image_path = os.path.join(self.image_root, self.images[index]) 116 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) 117 | transform_fn = self.setup_transform_fn() 118 | if transform_fn: 119 | aug_inputs = transform_fn(image=img) 120 | img_aug = aug_inputs['image'] 121 | 122 | if self.get_name: 123 | return img_aug, self.images[index] 124 | else: 125 | return img_aug -------------------------------------------------------------------------------- /basicseg/seg_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from basicseg.base_model import Base_model 4 | import copy 5 | from collections import OrderedDict 6 | from basicseg.metric import Binary_metric 7 | import torch.nn.functional as F 8 | 9 | class Seg_model(Base_model): 10 | def __init__(self, opt): 11 | super().__init__() 12 | self.opt = opt 13 | self.init_model() 14 | 15 | def init_model(self): 16 | self.setup_net() 17 | self.setup_optimizer() 18 | self.setup_loss() 19 | self.setup_metric() 20 | self.setup_lr_schduler() 21 | 22 | def setup_metric(self): 23 | self.best_norm_metric = {'epoch':0., 'iou':0., 'net': None} 24 | self.best_mean_metric = {'epoch':0, 'iou':0., 'net': None} 25 | self.metric = Binary_metric() 26 | self.epoch_metric = {} 27 | self.batch_metric = {} 28 | 29 | def get_mean_metric(self, dist=False, reduction='mean'): 30 | if dist: 31 | return self.reduce_dict(self.metric.get_mean_result(), reduction) 32 | else: 33 | return self.dict_wrapper(self.metric.get_mean_result()) 34 | 35 | def get_norm_metric(self, dist=False, reduction='mean'): 36 | if dist: 37 | return self.reduce_dict(self.metric.get_norm_result(), reduction) 38 | else: 39 | return self.dict_wrapper(self.metric.get_norm_result()) 40 | 41 | def get_epoch_loss(self, dist=False, reduction='sum'): 42 | epoch_loss = copy.deepcopy(self.epoch_loss) 43 | self.reset_epoch_loss() 44 | if dist: 45 | return self.reduce_dict(epoch_loss, reduction) 46 | else: 47 | return self.dict_wrapper(epoch_loss) 48 | 49 | def get_batch_loss(self, dist=False, reduction='sum'): 50 | batch_loss = copy.deepcopy(self.batch_loss) 51 | self.reset_batch_loss() 52 | if dist: 53 | return self.reduce_dict(batch_loss, reduction) 54 | else: 55 | return self.dict_wrapper(batch_loss) 56 | 57 | def optimize_one_iter(self, 
data):
58 |         if self.bd_loss:
59 |             img, mask, dist_map = data
60 |             img, mask, dist_map = img.to(self.device), mask.to(self.device), dist_map.to(self.device)
61 |         else:
62 |             img, mask = data
63 |             img, mask = img.to(self.device), mask.to(self.device)
64 |         pred, pred_1, pred_2, pred_3, pred_4 = self.net(img)
65 |         cur_loss = 0.
66 |         if not isinstance(pred, (list, tuple)):
67 |             pred = [pred]
68 |         for idx, pred_ in enumerate(pred):
69 |             pred_ = F.interpolate(pred_, mask.shape[2:], mode='bilinear', align_corners=False)
70 |             if idx == 0:
71 |                 pred[0] = pred_
72 |             for loss_type, loss_criteria in self.loss_fn.items():
73 |                 if loss_type == 'BD_loss':
74 |                     loss = loss_criteria(pred_, dist_map) * self.loss_weight[loss_type][idx]
75 |                 else:
76 |                     loss = loss_criteria(pred_, mask) * self.loss_weight[loss_type][idx] + loss_criteria(pred_1, mask) * 0.5 + loss_criteria(pred_2, mask) * 0.25 + loss_criteria(pred_3, mask) * 0.125 + loss_criteria(pred_4, mask) * 0.0625
77 |                 self.epoch_loss[loss_type + '_' + str(idx)] += loss.detach().clone()
78 |                 self.batch_loss[loss_type + '_' + str(idx)] += loss.detach().clone()
79 |                 cur_loss += loss
80 |         self.optim.zero_grad()
81 |         cur_loss.backward()
82 |         self.optim.step()
83 |         with torch.no_grad():
84 |             self.metric.update(pred=pred[0], target=mask)
85 |         # return loss_result
86 | 
87 |     def test_one_iter(self, data):
88 |         with torch.no_grad():
89 |             if self.bd_loss:
90 |                 img, mask, dist_map = data
91 |                 img, mask, dist_map = img.to(self.device), mask.to(self.device), dist_map.to(self.device)
92 |             else:
93 |                 img, mask = data
94 |                 img, mask = img.to(self.device), mask.to(self.device)
95 |             pred, _, _, _, _ = self.net(img)
96 |             if not isinstance(pred, (list, tuple)):
97 |                 pred = [pred]
98 |             for idx, pred_ in enumerate(pred):
99 |                 pred_ = F.interpolate(pred_, mask.shape[2:], mode='bilinear', align_corners=False)
100 |                 if idx == 0:
101 |                     pred[0] = pred_
102 |                 for loss_type, loss_criteria in self.loss_fn.items():
103 |                     if loss_type == 'BD_loss':
104 |                         loss = loss_criteria(pred_, dist_map) * self.loss_weight[loss_type][idx]
105 |                     else:
106 |                         loss = loss_criteria(pred_, mask) * self.loss_weight[loss_type][idx]
107 |                     self.epoch_loss[loss_type + '_' + str(idx)] += loss.detach().clone()
108 |                     self.batch_loss[loss_type + '_' + str(idx)] += loss.detach().clone()
109 |             self.metric.update(pred=pred[0], target=mask)
110 | 
111 | 
--------------------------------------------------------------------------------
/basicseg/networks/common/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class Basicblock(nn.Module):
5 |     "block for resnet18 and resnet34, the same as the original one"
6 |     expansion = 1
7 |     def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
8 |         super().__init__()
9 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation)
10 |         self.bn1 = nn.BatchNorm2d(planes)
11 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
12 |         self.bn2 = nn.BatchNorm2d(planes)
13 |         self.relu = nn.ReLU(inplace=True)
14 |         self.downsample = downsample
15 | 
16 |     def forward(self, x):
17 |         residual = x
18 |         x = self.conv1(x)
19 |         x = self.bn1(x)
20 |         x = self.relu(x)
21 |         x = self.conv2(x)
22 |         x = self.bn2(x)
23 |         if self.downsample is not None:
24 |             residual = self.downsample(residual)
25 |         x += residual
26 |         out = self.relu(x)
27 |         return out
28 | 
29 | class Bottleneck(nn.Module):
30 |     "block for resnet50 and deeper, with the stride of conv1 and conv2 swapped relative to the original implementation"
31 |     expansion = 4
32 |     def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
33 |         super().__init__()
34 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
35 |         self.bn1 = nn.BatchNorm2d(planes)
36 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation)
37 |         self.bn2 = nn.BatchNorm2d(planes)
38 |         self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, stride=1, bias=False)
39 |         self.bn3 = nn.BatchNorm2d(planes*4)
40 |         self.relu = nn.ReLU(inplace=True)
41 |         self.downsample = downsample
42 | 
43 |     def forward(self, x):
44 |         residual = x
45 |         x = self.conv1(x)
46 |         x = self.bn1(x)
47 |         x = self.relu(x)
48 |         x = self.conv2(x)
49 |         x = self.bn2(x)
50 |         x = self.relu(x)
51 |         x = self.conv3(x)
52 |         x = self.bn3(x)
53 |         if self.downsample is not None:
54 |             residual = self.downsample(residual)
55 |         x += residual
56 |         out = self.relu(x)
57 |         return out
58 | 
59 | class Resnet(nn.Module):
60 |     def __init__(self, block, layers, basic_planes=64, dilations=[False, False, False]):
61 |         super().__init__()
62 |         "the conv7x7 stem is replaced with three conv3x3, following the UperNet implementation (self.in_planes is kept at 64 rather than raised to 128)"
63 |         "changing basic_planes changes the channels of the feat_maps; keep it consistent with the decoder in_planes"
64 |         self.in_planes = 64
65 |         self.dilation = 1
66 |         self.basic_planes = basic_planes  # width of the conv layers in each stage
67 |         self.relu = nn.ReLU(inplace=True)
68 |         self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=2, padding=1, bias=False)
69 |         self.bn1 = nn.BatchNorm2d(self.in_planes)
70 |         self.conv2 = nn.Conv2d(self.in_planes, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
71 |         self.bn2 = nn.BatchNorm2d(self.in_planes)
72 |         self.conv3 = nn.Conv2d(self.in_planes, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
73 |         self.bn3 = nn.BatchNorm2d(self.in_planes)
74 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
75 |         self.layer1 = self._make_layer(block, self.basic_planes, layers[0], stride=1)
76 |         self.layer2 = self._make_layer(block, self.basic_planes*2, layers[1], stride=2, dilation=dilations[0])
77 |         self.layer3 = self._make_layer(block, self.basic_planes*4, layers[2], stride=2, dilation=dilations[1])
78 |         self.layer4 = self._make_layer(block, self.basic_planes*8, layers[3], stride=2, dilation=dilations[2])
79 | 
80 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=False):
81 |         downsample = None
82 |         previous_dilation = self.dilation
83 |         if dilation:
84 |             self.dilation *= stride
85 |             stride = 1
86 |         if stride != 1 or planes * block.expansion != self.in_planes:
87 |             'decide whether the residual identity x needs its channels/stride adjusted'
88 |             downsample = nn.Sequential(
89 |                 nn.Conv2d(self.in_planes, planes*block.expansion, kernel_size=1, stride=stride),
90 |                 nn.BatchNorm2d(planes*block.expansion)
91 |             )
92 |         layers = []
93 |         layers.append(block(self.in_planes, planes, stride=stride, downsample=downsample, dilation=previous_dilation))
94 |         self.in_planes = planes * block.expansion
95 |         for i in range(1, blocks):
96 |             layers.append(block(self.in_planes, planes, stride=1, downsample=None, dilation=self.dilation))
97 |         return nn.Sequential(*layers)
98 |     def forward(self, x):
99 |         feat_maps = []
100 |         x = self.conv1(x)
101 |         x = self.bn1(x)
102 |         x = self.relu(x)
103 |         x = self.conv2(x)
104 |         x = self.bn2(x)
105
| x = self.relu(x) 106 | x = self.conv3(x) 107 | x = self.bn3(x) 108 | x = self.relu(x) 109 | x = self.maxpool(x) 110 | x = self.layer1(x) 111 | feat_maps.append(x) 112 | x = self.layer2(x) 113 | feat_maps.append(x) 114 | x = self.layer3(x) 115 | feat_maps.append(x) 116 | x = self.layer4(x) 117 | feat_maps.append(x) 118 | return feat_maps 119 | 120 | def resnet18_d(in_c=3, basic_planes=64): 121 | return Resnet(Basicblock, [2,2,2,2], basic_planes=basic_planes, dilations=[False, True, True]) 122 | 123 | def resnet34_d(in_c=3, basic_planes=64): 124 | return Resnet(Basicblock, [3,4,6,3], basic_planes=basic_planes, dilations=[False, True, True]) 125 | 126 | def resnet50_d(in_c=3, basic_planes=64): 127 | return Resnet(Bottleneck, [3,4,6,3], basic_planes=basic_planes, dilations=[False, True, True]) 128 | 129 | def resnet101_d(in_c=3, basic_planes=64): 130 | return Resnet(Bottleneck, [3,4,23,3], basic_planes=basic_planes, dilations=[False, True, True]) 131 | 132 | def resnet18(in_c=3, basic_planes=64): 133 | return Resnet(Basicblock, [2,2,2,2], basic_planes=basic_planes, dilations=[False, False, False]) 134 | 135 | def resnet34(in_c=3, basic_planes=64): 136 | return Resnet(Basicblock, [3,4,6,3], basic_planes=basic_planes, dilations=[False, False, False]) 137 | 138 | def resnet50(in_c=3, basic_planes=64): 139 | return Resnet(Bottleneck, [3,4,6,3], basic_planes=basic_planes, dilations=[False, False, False]) 140 | 141 | def resnet101(in_c=3, basic_planes=64): 142 | return Resnet(Bottleneck, [3,4,23,3], basic_planes=basic_planes, dilations=[False, False, False]) 143 | 144 | def main(): 145 | net_18 = resnet18_d(3, 64) 146 | net_34 = resnet34_d(3, 64) 147 | net_50 = resnet50_d(3, 64) 148 | net_50_ = resnet50() 149 | x = torch.rand(2,3,512,512) 150 | y_18 = net_18(x) 151 | y_34 = net_34(x) 152 | y_50 = net_50(x) 153 | y_50_ = net_50_(x) 154 | # print(*y_18) 155 | for i in y_18: 156 | print(i.shape) 157 | # for i in y_34: 158 | # print(i.shape) 159 | # for i in y_50: 160 | # print(i.shape) 161 | # for i in y_50_: 162 | # print(i.shape) 163 | if __name__ == '__main__': 164 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributed as dist 5 | import torch.utils.data as Data 6 | import os 7 | import time 8 | import logging 9 | import random 10 | import copy 11 | import numpy as np 12 | from basicseg.seg_model import Seg_model 13 | from basicseg.utils.yaml_options import parse_options, dict2str 14 | from basicseg.utils.path_utils import * 15 | from basicseg.utils.logger import get_root_logger, init_tb_logger, get_env_info, MessageLogger 16 | from basicseg.data import build_dataset 17 | 18 | def set_seed(seed, cuda_deterministic=False): 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | if cuda_deterministic: 24 | # slower, more reproducible 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | else: 28 | # faster 29 | torch.backends.cudnn.deterministic = False 30 | torch.backends.cudnn.benchmark = True 31 | 32 | def init_exp(opt, args): 33 | exp_name = opt['exp'].get('name') 34 | if not exp_name: 35 | exp_name = os.path.basename(args.opt[:-4]) 36 | opt['exp']['name'] = exp_name 37 | exp_root = make_exp_root(os.path.join('experiment', exp_name)) 38 | opt['exp']['exp_root'] = exp_root 39 | 
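    # Layout sketch (assumed from the calls in this file rather than a
    # guaranteed contract): experiment/<exp_name>/ collects the train log
    # written just below, the tb_log/ tensorboard directory, and the
    # checkpoints saved later via model.save_network() / model.save_training_state().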
log_file = os.path.join(exp_root, f'train_{exp_name}_{get_time_str()}.log')
40 |     logger = get_root_logger(logger_name='basicseg', log_level=logging.INFO, log_file=log_file)
41 |     logger.info(get_env_info())
42 |     logger.info(dict2str(opt))
43 |     tb_logger = init_tb_logger(log_dir=os.path.join(exp_root, 'tb_log'))
44 |     return logger, tb_logger
45 | 
46 | def init_model(opt):
47 | 
48 |     model = Seg_model(opt)
49 |     return model
50 | 
51 | def init_dataset(opt):
52 |     # trainset
53 |     train_opt = opt['dataset']['train']
54 |     trainset = build_dataset(train_opt)
55 |     test_opt = opt['dataset']['test']
56 |     testset = build_dataset(test_opt)
57 |     return trainset, testset
58 | 
59 | def init_dataloader(opt, trainset, testset):
60 |     if opt['exp']['dist']:
61 |         sampler = Data.DistributedSampler(trainset)
62 |     else:
63 |         sampler = None
64 |     train_loader = Data.DataLoader(dataset=trainset, batch_size=opt['exp']['bs'],\
65 |                                    sampler=sampler, num_workers=opt['exp'].get('nw', 16))
66 |     test_loader = Data.DataLoader(dataset=testset, batch_size=opt['exp']['bs'],\
67 |                                   sampler=None, num_workers=opt['exp'].get('nw', 16))
68 |     return train_loader, test_loader
69 | 
70 | def main():
71 |     opt, args = parse_options()
72 |     os.environ['CUDA_VISIBLE_DEVICES'] = str(opt['exp']['device'])  # note: only takes effect if set before CUDA is initialized
73 |     if isinstance(opt['exp']['device'], int):
74 |         opt['exp']['dist'] = False
75 |         cur_rank = 0
76 |         total_device = 1
77 |         opt['exp']['num_devices'] = total_device
78 |     elif isinstance(opt['exp']['device'], str):
79 |         opt['exp']['dist'] = True
80 |         dist.init_process_group(backend='nccl')
81 |         total_device = len(opt['exp']['device']) // 2 + 1
82 |         opt['exp']['num_devices'] = total_device
83 |         cur_rank = dist.get_rank()
84 | 
85 |     # init dataset
86 |     trainset, testset = init_dataset(opt)
87 |     train_loader, test_loader = init_dataloader(opt, trainset, testset)
88 | 
89 |     # init exp_root, logger, tb_logger
90 |     total_epochs = opt['exp']['total_epochs']
91 |     total_iters = total_epochs * (len(trainset) // opt['exp']['bs'] // total_device + 1)
92 |     opt['exp']['total_iters'] = total_iters
93 |     save_interval = opt['exp']['save_interval']
94 |     test_interval = opt['exp']['test_interval']
95 |     logger, tb_logger = init_exp(opt, args)
96 |     set_seed(cur_rank + 0)
97 |     # initialize parameters including network, optimizer, loss function, learning rate scheduler
98 |     model = init_model(opt)
99 |     cur_iter = 0
100 |     cur_epoch = 1
101 |     # train from checkpoint
102 |     if opt.get('resume'):
103 |         if opt['resume'].get('net_path'):
104 |             model.load_network(model.net, opt['resume']['net_path'])
105 |             logger.info(f'load pretrained network from: {opt["resume"]["net_path"]}')
106 |         else:
107 |             logger.info('training from a randomly initialized network')
108 |         if opt['resume'].get('state_path'):
109 |             cur_epoch = model.resume_training(opt['resume']['state_path'])
110 |             cur_iter = cur_epoch * (len(trainset) // opt['exp']['bs'] // total_device + 1)
111 |             logger.info(f'resume training from epoch: {cur_epoch}')
112 |         else:
113 |             logger.info('training from epoch: 1')
114 | 
115 |     msg_logger = MessageLogger(opt, start_epoch=cur_epoch, tb_logger=tb_logger)
116 |     for epoch in range(cur_epoch, total_epochs+1):
117 |         if opt['exp']['dist']:
118 |             train_loader.sampler.set_epoch(epoch)
119 |         epoch_st_time = time.time()
120 |         ########## training ##########
121 |         for idx, data in enumerate(train_loader):
122 |             cur_iter += 1
123 |             model.update_learning_rate(cur_iter, idx)
124 |             model.optimize_one_iter(data)
125 |         epoch_time = time.time() - epoch_st_time
126 |         log_vars = {'epoch': epoch}
127 |         log_vars.update({'lrs': model.get_current_learning_rate()})
128 |         log_vars.update({'time': epoch_time})
129 |         log_vars.update({'train_loss': model.get_epoch_loss(opt['exp']['dist'], 'sum')})
130 |         log_vars.update({'train_mean_metric': model.get_mean_metric(opt['exp']['dist'], 'mean')})
131 |         log_vars.update({'train_norm_metric': model.get_norm_metric(opt['exp']['dist'], 'mean')})
132 |         ########## testing ##########
133 |         if cur_rank == 0 and epoch % test_interval == 0:
134 |             # model.net.eval()
135 |             model.model_to_eval()
136 |             for idx, data in enumerate(test_loader):
137 |                 model.test_one_iter(data)
138 |             log_vars.update({'test_loss': model.get_epoch_loss()})
139 |             test_mean_metric = model.get_mean_metric()
140 |             test_norm_metric = model.get_norm_metric()
141 |             log_vars.update({'test_mean_metric': test_mean_metric})
142 |             log_vars.update({'test_norm_metric': test_norm_metric})
143 |             if test_mean_metric['iou'] > model.best_mean_metric['iou']:
144 |                 model.best_mean_metric['iou'] = test_mean_metric['iou']
145 |                 model.best_mean_metric['net'] = copy.deepcopy(model.net.state_dict())
146 |                 model.best_mean_metric['epoch'] = epoch
147 |             if test_norm_metric['iou'] > model.best_norm_metric['iou']:
148 |                 model.best_norm_metric['iou'] = test_norm_metric['iou']
149 |                 model.best_norm_metric['net'] = copy.deepcopy(model.net.state_dict())
150 |                 model.best_norm_metric['epoch'] = epoch
151 |             # model.net.train()
152 |             model.model_to_train()
153 |         ########## saving model ##########
154 |         if cur_rank == 0 and epoch % save_interval == 0:
155 |             model.save_network(opt, model.net, epoch)
156 |             model.save_training_state(opt, epoch)
157 | 
158 |         msg_logger(log_vars)
159 | 
160 |     ########## training done ##########
161 |     if cur_rank == 0:
162 |         model.save_network(opt, model.net, current_epoch='latest')
163 |         model.save_network(opt, model.best_mean_metric['net'], current_epoch='best_mean', net_dict=True)
164 |         model.save_network(opt, model.best_norm_metric['net'], current_epoch='best_norm', net_dict=True)
165 |         logger.info(f"best_mean_metric: [epoch: {model.best_mean_metric['epoch']}] [iou: {model.best_mean_metric['iou']:.4f}]")
166 |         logger.info(f"best_norm_metric: [epoch: {model.best_norm_metric['epoch']}] [iou: {model.best_norm_metric['iou']:.4f}]")
167 | 
168 | if __name__ == '__main__':
169 |     main()
170 | 
--------------------------------------------------------------------------------
/basicseg/networks/HCFnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | import time
5 | from thop import profile
6 | from thop import clever_format  # used to format the MACs and parameter counts for printing
7 | import torch.nn.functional as F
8 | # from ptflops import get_model_complexity_info
9 | from basicseg.main_blocks import PPA, DASI, MDCR
10 | from basicseg.utils.registry import NET_REGISTRY
11 | 
12 | 
13 | @NET_REGISTRY.register()
14 | class HCFnet(nn.Module):
15 |     def __init__(self,
16 |                  in_features=3,
17 |                  out_features=1,
18 |                  gt_ds=False,
19 |                  ) -> None:
20 |         super().__init__()
21 |         self.gt_ds = gt_ds
22 | 
23 |         self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
24 | 
25 |         self.p1 = PPA(in_features=in_features,
26 |                       filters=32)
27 | 
28 |         self.respath1 = DASI(in_features=32,
29 |                              out_features=32,
30 |                              )
31 |         self.p2 = PPA(in_features=32,
32 |                       filters=int(32 * 2))
33 |         self.respath2 = DASI(in_features=64,
34 |                              out_features=32 * 2,
35 |                              )
36 |         self.p3 = PPA(in_features=64,
37 |                       filters=int(32 * 4))
38 |         self.respath3 = DASI(in_features=128,
39 |                              out_features=32 * 4,
40 |                              )
41 |         self.p4 = PPA(in_features=128,
42 |                       filters=int(32 * 8))
43 |         self.respath4 = DASI(in_features=256,
44 |                              out_features=32 * 8,
45 |                              )
46 |         self.p5 = PPA(in_features=256,
47 |                       filters=int(32 * 16))
48 | 
49 |         self.mdcr = MDCR(in_features=int(512), out_features=int(512))
50 | 
51 | 
52 | 
53 |         self.up1 = nn.Sequential(nn.ConvTranspose2d(512,
54 |                                                     32*8,
55 |                                                     kernel_size=(2,2),
56 |                                                     stride=(2,2)),
57 |                                  nn.BatchNorm2d(32 * 8),
58 |                                  nn.ReLU())
59 | 
60 | 
61 |         self.p6 = PPA(in_features=32 * 8 * 2,
62 |                       filters=int(32 * 8))
63 |         self.up2 = nn.Sequential(nn.ConvTranspose2d(256,
64 |                                                     32*4,
65 |                                                     kernel_size=(2,2),
66 |                                                     stride=(2,2)),
67 |                                  nn.BatchNorm2d(32 * 4),
68 |                                  nn.ReLU())
69 | 
70 |         self.p7 = PPA(in_features=32 * 4 * 2,
71 |                       filters=int(32 * 4))
72 | 
73 |         self.up3 = nn.Sequential(nn.ConvTranspose2d(128,
74 |                                                     32*2,
75 |                                                     kernel_size=(2,2),
76 |                                                     stride=(2,2)),
77 |                                  nn.BatchNorm2d(32 * 2),
78 |                                  nn.ReLU())
79 | 
80 |         self.p8 = PPA(in_features=32 * 2 * 2,
81 |                       filters=int(32 * 2))
82 |         self.up4 = nn.Sequential(nn.ConvTranspose2d(64,
83 |                                                     32,
84 |                                                     kernel_size=(2,2),
85 |                                                     stride=(2,2)),
86 |                                  nn.BatchNorm2d(32),
87 |                                  nn.ReLU())
88 | 
89 |         self.p9 = PPA(in_features=32 * 2,
90 |                       filters=int(32))
91 | 
92 | 
93 |         self.out = nn.Conv2d(in_channels=32,
94 |                              out_channels=out_features,
95 |                              kernel_size=(1, 1),
96 |                              padding=(0, 0)
97 |                              )
98 |         self.out1 = nn.Conv2d(in_channels=512,
99 |                               out_channels=out_features,
100 |                               kernel_size=(1, 1),
101 |                               padding=(0, 0)
102 |                               )
103 |         self.out2 = nn.Conv2d(in_channels=256,
104 |                               out_channels=out_features,
105 |                               kernel_size=(1, 1),
106 |                               padding=(0, 0)
107 |                               )
108 |         self.out3 = nn.Conv2d(in_channels=128,
109 |                               out_channels=out_features,
110 |                               kernel_size=(1, 1),
111 |                               padding=(0, 0)
112 |                               )
113 |         self.out4 = nn.Conv2d(in_channels=64,
114 |                               out_channels=out_features,
115 |                               kernel_size=(1, 1),
116 |                               padding=(0, 0)
117 |                               )
118 | 
119 | 
120 | 
121 |     def forward(self, x):
122 |         # encoder
123 |         x1 = self.p1(x)
124 |         xp1 = self.maxpool(x1)
125 |         x2 = self.p2(xp1)
126 |         xp2 = self.maxpool(x2)
127 |         x3 = self.p3(xp2)
128 |         xp3 = self.maxpool(x3)
129 |         x4 = self.p4(xp3)
130 |         xp4 = self.maxpool(x4)
131 |         x = self.p5(xp4)
132 |         x = self.mdcr(x)
133 | 
134 | 
135 |         x1_res = self.respath1(x1, x2, None)  # 1 32 512 512
136 |         x2_res = self.respath2(x2, x3, x1)    # 1 64 256 256
137 |         x3_res = self.respath3(x3, x4, x2)    # 1 128 128 128
138 |         x4_res = self.respath4(x4, x, x3)     # 1 256 64 64
139 | 
140 | 
141 |         # decoder
142 |         out4 = F.interpolate(self.out1(x), scale_factor=16, mode='bilinear', align_corners=True)
143 |         x = self.up1(x)
144 |         x = torch.cat((x, x4_res), dim=1)
145 |         x = self.p6(x)
146 |         out3 = F.interpolate(self.out2(x), scale_factor=8, mode='bilinear', align_corners=True)
147 |         x = self.up2(x)
148 |         x = torch.cat((x, x3_res), dim=1)
149 |         x = self.p7(x)
150 |         out2 = F.interpolate(self.out3(x), scale_factor=4, mode='bilinear', align_corners=True)
151 |         x = self.up3(x)
152 |         x = torch.cat((x, x2_res), dim=1)
153 |         x = self.p8(x)
154 |         out1 = F.interpolate(self.out4(x), scale_factor=2, mode='bilinear', align_corners=True)
155 |         x = self.up4(x)
156 |         x = torch.cat((x, x1_res), dim=1)
157 |         x = self.p9(x)
158 |         out = self.out(x)
159 |         if self.gt_ds:
160 |             return out, out1, out2, out3, out4
161 |         else:
162 |             return out
163 | 
164 | 
165 | 
166 | if __name__ == '__main__':
167 |     # define the input tensor
168 |     input_tensor = torch.randn(1, 3, 1024, 1024)  # (batch_size=1, 3 channels, a 1024x1024 image)
169 | 
170 |     # instantiate the model (make sure the HCFnet class above is defined correctly)
171 |     net = HCFnet()
172 |     # print(net.encoder)
173 | 
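    # Sketch of the deep-supervision interface (follows the forward() above;
    # shapes depend on the input size):
    #   net_ds = HCFnet(gt_ds=True)
    #   out, out1, out2, out3, out4 = net_ds(torch.randn(1, 3, 512, 512))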
174 |     # check the current device and move the model onto it
175 |     device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
176 |     net.to(device)
177 |     input_tensor = input_tensor.to(device)
178 | 
179 |     # compute MACs and the parameter count with thop
180 |     flops, params = profile(net, inputs=(input_tensor,))
181 |     flops, params = clever_format([flops, params], "%.2f")
182 | 
183 |     # print the computational cost and parameter count
184 |     print(f"Computational cost (MACs): {flops}")
185 |     print(f"Number of parameters: {params}")
186 | 
187 |     # # measure inference time over 100 images
188 |     # total_time = 0
189 |     # num_images = 100
190 | 
191 |     # # make sure the model is in eval mode
192 |     # net.eval()
193 |     #
194 |     # with torch.no_grad():  # disable gradient computation to speed up inference
195 |     #     for _ in range(num_images):
196 |     #         torch.cuda.synchronize()  # synchronize GPU and CPU for accurate timing
197 |     #         start = time.time()
198 |     #         result = net(input_tensor)
199 |     #         torch.cuda.synchronize()
200 |     #         end = time.time()
201 |     #
202 |     #         infer_time = end - start
203 |     #         total_time += infer_time
204 |     #
205 |     #         # print(f'Single inference time: {infer_time:.6f} seconds')
206 |     #
207 |     #     # compute the average inference time and FPS
208 |     #     average_time = total_time / num_images
209 |     #     fps = 1 / average_time if average_time > 0 else float('inf')
210 |     #
211 |     #     print(f'Average inference time for 100 images: {average_time:.6f} seconds')
212 |     #     print(f'FPS: {fps:.2f}')
213 | 
214 | 
--------------------------------------------------------------------------------
/basicseg/utils/logger.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import datetime
8 | import logging
9 | import time
10 | 
11 | from .dist_util import get_dist_info, master_only
12 | 
13 | 
14 | class MessageLogger():
15 |     """Message logger for printing.
16 | 
17 |     Args:
18 |         opt (dict): Config. It contains the following keys:
19 |             name (str): Exp name.
20 |             logger (dict): Contains 'print_freq' (str) for logger interval.
21 |             train (dict): Contains 'total_iter' (int) for total iters.
22 |             use_tb_logger (bool): Use tensorboard logger.
23 |         start_epoch (int): Start epoch. Default: 1.
24 |         tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
25 |     """
26 | 
27 |     def __init__(self, opt, start_epoch=1, tb_logger=None):
28 |         self.exp_name = opt['exp']['name']
29 |         self.interval = opt['exp']['log_interval']
30 |         self.start_epoch = start_epoch
31 |         self.max_epochs = opt['exp']['total_epochs']
32 |         self.use_tb_logger = True
33 |         self.tb_logger = tb_logger
34 |         self.start_time = time.time()
35 |         self.logger = get_root_logger()
36 | 
37 |     @master_only
38 |     def __call__(self, log_vars):
39 |         """Format logging message.
40 |         Args:
41 |             log_vars (dict): It contains the following keys:
42 |                 epoch (int): Epoch number.
43 |                 iter (int): Current iter.
44 |                 lrs (list): List for learning rates.
45 | 
46 |                 time (float): Iter time.
47 |                 data_time (float): Data time for each iter.
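
                Remaining keys (e.g. train_loss, test_loss and the
                train/test mean/norm metric dicts) are written to
                tensorboard and appended to the printed message below.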
48 |         """
49 |         # epoch, iter, learning rates
50 |         current_epoch = log_vars.pop('epoch')
51 |         # current_iter = log_vars.pop('iter')
52 |         # total_iter = log_vars.pop('total_iter')
53 |         lrs = log_vars.pop('lrs')
54 | 
55 |         message = (f'[{self.exp_name}][epoch:{current_epoch:3d}, '
56 |                    f'lr:(')
57 |         message += f'{lrs:.3e},'
58 |         message += ')] '
59 | 
60 |         # time and estimated time
61 |         if 'time' in log_vars.keys():
62 |             epoch_time = log_vars.pop('time')
63 |             total_time = time.time() - self.start_time
64 |             time_sec_avg = total_time / (current_epoch - self.start_epoch + 1)
65 |             eta_sec = time_sec_avg * (self.max_epochs - current_epoch)
66 |             eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
67 |             message += f'[eta: {eta_str}, '
68 |             message += f'time (epoch): {epoch_time:.3f} ] '
69 | 
70 |         # other items, especially losses
71 |         for k, v in log_vars.items():
72 |             # message += f'{k}: {v:.4e} '
73 |             # tensorboard logger
74 |             if self.use_tb_logger and 'debug' not in self.exp_name:
75 |                 # normed_step = 10000 * (current_iter / total_iter)
76 |                 # normed_step = int(normed_step)
77 | 
78 |                 if k == 'train_loss':
79 |                     message += '\nTrainSet\n'
80 |                     for loss_type, loss_value in log_vars[k].items():
81 |                         self.tb_logger.add_scalar(f'train_losses/{loss_type}', loss_value, current_epoch)
82 |                         message += f'{loss_type}:{loss_value:.4e} '
83 |                     message += '\n'
84 |                 elif k == 'train_mean_metric':
85 |                     for metric_type, metric_value in log_vars[k].items():
86 |                         self.tb_logger.add_scalar(f'train_mean_metrics/{metric_type}', metric_value, current_epoch)
87 |                     message += f"m_fscore:{log_vars[k]['fscore']:.4f} m_iou:{log_vars[k]['iou']:.4f} "
88 |                 elif k == 'train_norm_metric':
89 |                     for metric_type, metric_value in log_vars[k].items():
90 |                         self.tb_logger.add_scalar(f'train_norm_metrics/{metric_type}', metric_value, current_epoch)
91 |                     message += f"n_fscore:{log_vars[k]['fscore']:.4f} n_iou:{log_vars[k]['iou']:.4f} "
92 |                 elif k == 'test_loss':
93 |                     message += '\nTestSet\n'
94 |                     for loss_type, loss_value in log_vars[k].items():
95 |                         self.tb_logger.add_scalar(f'test_losses/{loss_type}', loss_value, current_epoch)
96 |                         message += f'{loss_type}:{loss_value:.4e} '
97 |                     message += '\n'
98 |                 elif k == 'test_mean_metric':
99 |                     for metric_type, metric_value in log_vars[k].items():
100 |                         self.tb_logger.add_scalar(f'test_mean_metrics/{metric_type}', metric_value, current_epoch)
101 |                     message += f"m_fscore:{log_vars[k]['fscore']:.4f} m_iou:{log_vars[k]['iou']:.4f} "
102 |                 elif k == 'test_norm_metric':
103 |                     for metric_type, metric_value in log_vars[k].items():
104 |                         self.tb_logger.add_scalar(f'test_norm_metrics/{metric_type}', metric_value, current_epoch)
105 |                     message += f"n_fscore:{log_vars[k]['fscore']:.4f} n_iou:{log_vars[k]['iou']:.4f} "
106 |                 else:
107 |                     raise ValueError(f'unexpected log key: {k}')
108 |             # else:
109 |             #     self.tb_logger.add_scalar(k, v, current_iter)
110 |         message += '\n'
111 |         self.logger.info(message)
112 | 
113 | 
114 | @master_only
115 | def init_tb_logger(log_dir):
116 |     from torch.utils.tensorboard import SummaryWriter
117 |     tb_logger = SummaryWriter(log_dir=log_dir)
118 |     return tb_logger
119 | 
120 | 
121 | 
122 | 
123 | 
124 | def get_root_logger(logger_name='basicseg',
125 |                     log_level=logging.INFO,
126 |                     log_file=None):
127 |     """Get the root logger.
128 | 
129 |     The logger will be initialized if it has not been initialized. By default a
130 |     StreamHandler will be added. If `log_file` is specified, a FileHandler will
131 |     also be added.
132 | 
133 |     Args:
134 |         logger_name (str): root logger name. Default: 'basicseg'.
135 | log_file (str | None): The log filename. If specified, a FileHandler 136 | will be added to the root logger. 137 | log_level (int): The root logger level. Note that only the process of 138 | rank 0 is affected, while other processes will set the level to 139 | "Error" and be silent most of the time. 140 | 141 | Returns: 142 | logging.Logger: The root logger. 143 | """ 144 | logger = logging.getLogger(logger_name) 145 | # if the logger has been initialized, just return it 146 | if logger.hasHandlers(): 147 | return logger 148 | 149 | format_str = '%(asctime)s %(levelname)s: %(message)s' 150 | logging.basicConfig(format=format_str, level=log_level) 151 | rank, _ = get_dist_info() 152 | if rank != 0: 153 | logger.setLevel('ERROR') 154 | elif log_file is not None: 155 | file_handler = logging.FileHandler(log_file, 'w') 156 | file_handler.setFormatter(logging.Formatter(format_str)) 157 | file_handler.setLevel(log_level) 158 | logger.addHandler(file_handler) 159 | 160 | return logger 161 | 162 | 163 | def get_env_info(): 164 | """Get environment information. 165 | 166 | Currently, only log the software version. 167 | """ 168 | import torch 169 | import torchvision 170 | msg = ('\nVersion Information: ' 171 | f'\n\tPyTorch: {torch.__version__}' 172 | f'\n\tTorchVision: {torchvision.__version__}') 173 | return msg 174 | 175 | # @master_only 176 | # def init_wandb_logger(opt): 177 | # """We now only use wandb to sync tensorboard log.""" 178 | # import wandb 179 | # logger = logging.getLogger('basicsr') 180 | # 181 | # project = opt['logger']['wandb']['project'] 182 | # resume_id = opt['logger']['wandb'].get('resume_id') 183 | # if resume_id: 184 | # wandb_id = resume_id 185 | # resume = 'allow' 186 | # logger.warning(f'Resume wandb logger with id={wandb_id}.') 187 | # else: 188 | # wandb_id = wandb.util.generate_id() 189 | # resume = 'never' 190 | # 191 | # wandb.init( 192 | # id=wandb_id, 193 | # resume=resume, 194 | # name=opt['name'], 195 | # config=opt, 196 | # project=project, 197 | # sync_tensorboard=True) 198 | # 199 | # logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') -------------------------------------------------------------------------------- /basicseg/networks/common/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CDC_conv(nn.Module): 6 | def __init__(self, in_channels, out_channels, bias=True, kernel_size=3, stride=1, 7 | padding=1, dilation=1, theta=0.7, padding_mode='zeros'): 8 | super().__init__() 9 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, 10 | stride = stride, dilation=dilation, bias=bias, padding_mode=padding_mode) 11 | self.theta = theta 12 | 13 | def forward(self, x): 14 | norm_out = self.conv(x) 15 | if (self.theta - 0.0) < 1e-6: 16 | return norm_out 17 | else: 18 | # [c_out, c_in, kernel_size, kernel_size] = self.conv.weight.shape 19 | kernel_diff = self.conv.weight.sum(2).sum(2) 20 | kernel_diff = kernel_diff[:, :, None, None] 21 | diff_out = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, 22 | dilation=1, padding=0) 23 | out = norm_out - self.theta * diff_out 24 | return out 25 | 26 | class ASPPConv(nn.Sequential): 27 | def __init__(self, in_channels, out_channels, dilation): 28 | super(ASPPConv, self).__init__( 29 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 30 | 
nn.BatchNorm2d(out_channels), 31 | nn.ReLU() 32 | ) 33 | 34 | 35 | class ASPPPooling(nn.Sequential): 36 | def __init__(self, in_channels, out_channels): 37 | super(ASPPPooling, self).__init__( 38 | nn.AdaptiveAvgPool2d(1), 39 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 40 | nn.BatchNorm2d(out_channels), 41 | nn.ReLU() 42 | ) 43 | 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: 45 | size = x.shape[-2:] 46 | for mod in self: 47 | x = mod(x) 48 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 49 | 50 | class ASPP(nn.Module): 51 | def __init__(self, in_channels, atrous_rates, out_channels=256) -> None: 52 | super(ASPP, self).__init__() 53 | modules = [ 54 | nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), 55 | nn.BatchNorm2d(out_channels), 56 | nn.ReLU()) 57 | ] 58 | 59 | rates = tuple(atrous_rates) 60 | for rate in rates: 61 | modules.append(ASPPConv(in_channels, out_channels, rate)) 62 | 63 | modules.append(ASPPPooling(in_channels, out_channels)) 64 | 65 | self.convs = nn.ModuleList(modules) 66 | 67 | self.project = nn.Sequential( 68 | nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False), 69 | nn.BatchNorm2d(out_channels), 70 | nn.ReLU(), 71 | nn.Dropout(0.5) 72 | ) 73 | 74 | def forward(self, x): 75 | _res = [] 76 | for conv in self.convs: 77 | _res.append(conv(x)) 78 | res = torch.cat(_res, dim=1) 79 | return self.project(res) 80 | 81 | class DeformConv2d(nn.Module): 82 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=None, modulation=False): 83 | """ 84 | Args: 85 | modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2). 86 | """ 87 | super(DeformConv2d, self).__init__() 88 | self.device = None 89 | self.kernel_size = kernel_size 90 | self.padding = padding 91 | self.stride = stride 92 | self.zero_padding = nn.ZeroPad2d(padding) 93 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=kernel_size, bias=bias) 94 | 95 | self.p_conv = nn.Conv2d(in_channels, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) 96 | nn.init.constant_(self.p_conv.weight, 0) 97 | self.p_conv.register_backward_hook(self._set_lr) 98 | 99 | self.modulation = modulation 100 | if modulation: 101 | self.m_conv = nn.Conv2d(in_channels, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) 102 | nn.init.constant_(self.m_conv.weight, 0) 103 | self.m_conv.register_backward_hook(self._set_lr) 104 | 105 | @staticmethod 106 | def _set_lr(module, grad_input, grad_output): 107 | grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input))) 108 | grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output))) 109 | 110 | def forward(self, x): 111 | self.device = x.device 112 | offset = self.p_conv(x) 113 | if self.modulation: 114 | m = torch.sigmoid(self.m_conv(x)) 115 | 116 | dtype = offset.data.type() 117 | ks = self.kernel_size 118 | N = offset.size(1) // 2 119 | 120 | if self.padding: 121 | x = self.zero_padding(x) 122 | 123 | # (b, 2N, h, w) 124 | p = self._get_p(offset, dtype) 125 | 126 | # (b, h, w, 2N) 127 | p = p.contiguous().permute(0, 2, 3, 1) 128 | q_lt = p.detach().floor() 129 | q_rb = q_lt + 1 130 | 131 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long() 132 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long() 133 | q_lb = torch.cat([q_lt[..., :N], 
q_rb[..., N:]], dim=-1) 134 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) 135 | 136 | # clip p 137 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1) 138 | 139 | # bilinear kernel (b, h, w, N) 140 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) 141 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) 142 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) 143 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) 144 | 145 | # (b, c, h, w, N) 146 | x_q_lt = self._get_x_q(x, q_lt, N) 147 | x_q_rb = self._get_x_q(x, q_rb, N) 148 | x_q_lb = self._get_x_q(x, q_lb, N) 149 | x_q_rt = self._get_x_q(x, q_rt, N) 150 | 151 | # (b, c, h, w, N) 152 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ 153 | g_rb.unsqueeze(dim=1) * x_q_rb + \ 154 | g_lb.unsqueeze(dim=1) * x_q_lb + \ 155 | g_rt.unsqueeze(dim=1) * x_q_rt 156 | 157 | # modulation 158 | if self.modulation: 159 | m = m.contiguous().permute(0, 2, 3, 1) 160 | m = m.unsqueeze(dim=1) 161 | m = torch.cat([m for _ in range(x_offset.size(1))], dim=1) 162 | x_offset *= m 163 | 164 | x_offset = self._reshape_x_offset(x_offset, ks) 165 | out = self.conv(x_offset) 166 | 167 | return out 168 | 169 | def _get_p_n(self, N, dtype): 170 | p_n_x, p_n_y = torch.meshgrid( 171 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), 172 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) 173 | # (2N, 1) 174 | p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) 175 | p_n = p_n.view(1, 2*N, 1, 1).to(self.device).type(dtype) 176 | 177 | return p_n 178 | 179 | def _get_p_0(self, h, w, N, dtype): 180 | p_0_x, p_0_y = torch.meshgrid( 181 | torch.arange(1, h*self.stride+1, self.stride), 182 | torch.arange(1, w*self.stride+1, self.stride)) 183 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) 184 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) 185 | p_0 = torch.cat([p_0_x, p_0_y], 1).to(self.device).type(dtype) 186 | 187 | return p_0 188 | 189 | def _get_p(self, offset, dtype): 190 | N, h, w = offset.size(1)//2, offset.size(2), offset.size(3) 191 | 192 | # (1, 2N, 1, 1) 193 | p_n = self._get_p_n(N, dtype) 194 | # (1, 2N, h, w) 195 | p_0 = self._get_p_0(h, w, N, dtype) 196 | p = p_0 + p_n + offset 197 | return p 198 | 199 | def _get_x_q(self, x, q, N): 200 | b, h, w, _ = q.size() 201 | padded_w = x.size(3) 202 | c = x.size(1) 203 | # (b, c, h*w) 204 | x = x.contiguous().view(b, c, -1) 205 | 206 | # (b, h, w, N) 207 | index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y 208 | # (b, c, h*w*N) 209 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) 210 | 211 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) 212 | 213 | return x_offset 214 | 215 | @staticmethod 216 | def _reshape_x_offset(x_offset, ks): 217 | b, c, h, w, N = x_offset.size() 218 | x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1) 219 | x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks) 220 | 221 | return x_offset 222 | 223 | 224 | class GatedConv2dWithActivation(nn.Module): 225 | """ 226 | Gated Convlution layer with activation (default activation:LeakyReLU) 227 | Params: same as conv2d 228 | Input: The feature from 
last layer "I" 229 | Output:\phi(f(I))*\sigmoid(g(I)) 230 | """ 231 | 232 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, 233 | batch_norm=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)): 234 | super(GatedConv2dWithActivation, self).__init__() 235 | self.batch_norm = batch_norm 236 | self.activation = activation 237 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 238 | self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 239 | self.batch_norm2d = torch.nn.BatchNorm2d(out_channels) 240 | self.sigmoid = torch.nn.Sigmoid() 241 | 242 | for m in self.modules(): 243 | if isinstance(m, nn.Conv2d): 244 | nn.init.kaiming_normal_(m.weight) 245 | def gated(self, mask): 246 | #return torch.clamp(mask, -1, 1) 247 | return self.sigmoid(mask) 248 | def forward(self, input): 249 | x = self.conv2d(input) 250 | mask = self.mask_conv2d(input) 251 | if self.activation is not None: 252 | x = self.activation(x) * self.gated(mask) 253 | else: 254 | x = x * self.gated(mask) 255 | if self.batch_norm: 256 | return self.batch_norm2d(x) 257 | else: 258 | return x 259 | 260 | if __name__ == '__main__': 261 | conv = CDC_conv(3, 3, kernel_size=3, stride=2, padding=0) 262 | d_conv = GatedConv2dWithActivation(in_channels=3, out_channels=3) 263 | x = torch.rand(1,3,512,512) 264 | y = d_conv(x) 265 | stand_conv = nn.Conv2d(3,3,kernel_size=3,padding=1) 266 | print(y.shape) 267 | import ptflops 268 | mac,para = ptflops.get_model_complexity_info(d_conv, (3,512,512), print_per_layer_stat=False) 269 | print(mac,para) 270 | mac,para = ptflops.get_model_complexity_info(stand_conv, (3,512,521), print_per_layer_stat=False) 271 | print(mac, para) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /basicseg/networks/common/agpc/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | import math 7 | 8 | try: 9 | from torch.hub import load_state_dict_from_url 10 | except ImportError: 11 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 12 | 13 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 25 | } 26 | 27 | 28 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 29 | """3x3 convolution with padding""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=dilation, groups=groups, bias=False, dilation=dilation) 32 | 33 | 34 | def conv1x1(in_planes, out_planes, stride=1): 35 | """1x1 convolution""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 43 | base_width=64, dilation=1, norm_layer=None): 44 | super(BasicBlock, self).__init__() 45 | if norm_layer is None: 46 | norm_layer = nn.BatchNorm2d 47 | if groups != 1 or base_width != 64: 48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 49 | if dilation > 1: 50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = norm_layer(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = norm_layer(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | identity = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | identity = self.downsample(x) 72 | 73 | out += identity 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 81 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 82 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 83 | # This variant is also known as ResNet V1.5 and improves accuracy according to 84 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
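    # Illustrative shape walk-through (not from the original source): with
    # planes=64 and stride=2, conv1 is a 1x1 keeping 64 channels, conv2 is the
    # strided 3x3 that halves H and W, and conv3 expands to 64 * 4 = 256
    # channels (planes * expansion).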
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # NOTE: unlike torchvision's ResNet, conv1 uses stride=1 (instead of 2) and
        # the stem maxpool below is disabled, so every stage runs at 4x higher
        # resolution, presumably to avoid washing out small infrared targets.
        # num_classes is kept for interface compatibility but goes unused, since
        # the classification head (avgpool/fc) is also commented out below.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # x = self.maxpool(x)

        x = self.layer1(x)
        c1 = self.layer2(x)
        c2 = self.layer3(c1)
        c3 = self.layer4(c2)

        # x = self.avgpool(x)
        # x = torch.flatten(x, 1)
        # x = self.fc(x)

        # With the stride-1 stem and disabled maxpool, these features sit at 1/2,
        # 1/4 and 1/8 of the input resolution (vs. 1/8, 1/16, 1/32 in torchvision).
        return c1, c2, c3

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        # strict=False: the pretrained fc weights have no match in this headless backbone
        model.load_state_dict(state_dict, strict=False)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition"