├── __init__.py
├── util
│   ├── __init__.py
│   ├── util.py
│   └── metric_tool.py
├── SCD
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── transform.py
│   │   └── cd_dataset.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── block
│   │   │   ├── __init__.py
│   │   │   ├── dcnv2.py
│   │   │   ├── schedular.py
│   │   │   ├── vertical.py
│   │   │   ├── heads.py
│   │   │   ├── fpn.py
│   │   │   └── convs.py
│   │   ├── loss
│   │   │   ├── __init__.py
│   │   │   ├── focal.py
│   │   │   └── dice.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   └── mobilenetv2.py
│   │   ├── util.py
│   │   ├── create_model.py
│   │   └── network.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── palette.py
│   │   ├── util.py
│   │   ├── metric.py
│   │   └── metric_tool.py
│   ├── test.py
│   ├── option.py
│   └── trainval.py
├── model
│   ├── block
│   │   ├── __init__.py
│   │   ├── dcnv2.py
│   │   ├── schedular.py
│   │   ├── heads.py
│   │   ├── vertical.py
│   │   ├── fpn.py
│   │   └── convs.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── dice.py
│   │   └── focal.py
│   ├── backbone
│   │   ├── __init__.py
│   │   └── mobilenetv2.py
│   ├── __init__.py
│   ├── util.py
│   ├── network.py
│   └── create_model.py
├── data
│   ├── __init__.py
│   ├── transform.py
│   └── cd_dataset.py
├── test.py
├── README.md
├── option.py
└── trainval.py

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/SCD/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/model/block/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/loss/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/SCD/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/SCD/model/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/SCD/util/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/model/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/SCD/model/block/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/SCD/model/loss/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/SCD/model/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference:
3 | https://github.com/NVIDIA/pix2pixHD/
4 | """
5 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | References:
3 | https://github.com/NVIDIA/pix2pixHD
4 | https://github.com/VainF/DeepLabV3Plus-Pytorch
5 | """
--------------------------------------------------------------------------------
/SCD/util/palette.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def color_map(dataset):
4 |     if dataset == 'SECOND_SCD':
5 |         cmap = np.zeros((7, 3), dtype=np.uint8)
6 |         cmap[0] = np.array([255, 255, 255])
7 |         cmap[1] = np.array([0, 0, 255])
8 |         cmap[2] = np.array([128, 128, 128])
9 |         cmap[3] = np.array([0, 128, 0])
10 |         cmap[4] = np.array([0, 255, 0])
11 |         cmap[5] = np.array([128, 0, 0])
12 |         cmap[6] = np.array([255, 0, 0])
13 |     else:
14 |         raise NotImplementedError
15 | 
16 |     return cmap
17 | 
18 | 
19 | def Color2Index(dataset, ColorLabel):
20 |     if dataset == 'SECOND_SCD':
21 |         num_classes = 7
22 |         ST_COLORMAP = [[255, 255, 255], [0, 0, 255], [128, 128, 128], [0, 128, 0], [0, 255, 0], [128, 0, 0],
23 |                        [255, 0, 0]]
24 |         CLASSES = ['unchanged', 'water', 'ground', 'low vegetation', 'tree', 'building', 'sports field']
25 |     else:
26 |         raise NotImplementedError
27 | 
28 |     colormap2label = np.zeros(256 ** 3)
29 |     for i, cm in enumerate(ST_COLORMAP):
30 |         colormap2label[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
31 | 
32 |     data = ColorLabel.astype(np.int32)
33 |     idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
34 |     IndexMap = colormap2label[idx]
35 |     IndexMap = IndexMap * (IndexMap < num_classes)
36 | 
37 |     return IndexMap
38 | 
39 | 
40 | 
41 | 
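A small round-trip sketch of the palette utilities above (illustrative only; it assumes the script is run from the `SCD` directory, matching the `from util.palette import ...` style used by `SCD/test.py`):

```python
import numpy as np
from util.palette import color_map, Color2Index  # run from SCD/, as SCD/test.py does

# Build a tiny 2x2 RGB label: unchanged, water, tree, building.
cmap = color_map('SECOND_SCD')
rgb = np.stack([cmap[0], cmap[1], cmap[4], cmap[5]]).reshape(2, 2, 3)

idx = Color2Index('SECOND_SCD', rgb)   # float index map: [[0., 1.], [4., 5.]]
back = cmap[idx.astype(np.uint8)]      # indices -> RGB again via the palette
assert (back == rgb).all()
```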
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from util.metric_tool import ConfuseMatrixMeter
2 | import torch
3 | from option import Options
4 | from data.cd_dataset import DataLoader
5 | from model.create_model import create_model
6 | from tqdm import tqdm
7 | 
8 | if __name__ == '__main__':
9 |     opt = Options().parse()
10 |     opt.phase = 'test'
11 |     test_loader = DataLoader(opt)
12 |     test_data = test_loader.load_data()
13 |     test_size = len(test_loader)
14 |     print("#testing images = %d" % test_size)
15 | 
16 |     opt.load_pretrain = True
17 |     model = create_model(opt)
18 | 
19 |     tbar = tqdm(test_data, ncols=80)
20 |     total_iters = test_size
21 |     running_metric = ConfuseMatrixMeter(n_class=2)
22 |     running_metric.clear()
23 | 
24 |     model.eval()
25 |     with torch.no_grad():
26 |         for i, _data in enumerate(tbar):
27 |             val_pred = model.inference(_data['img1'].cuda(), _data['img2'].cuda())
28 |             # update metric
29 |             val_target = _data['cd_label'].detach()
30 |             val_pred = torch.argmax(val_pred.detach(), dim=1)
31 |             _ = running_metric.update_cm(pr=val_pred.cpu().numpy(), gt=val_target.cpu().numpy())
32 |     val_scores = running_metric.get_scores()
33 |     message = '(phase: %s) ' % (opt.phase)
34 |     for k, v in val_scores.items():
35 |         message += '%s: %.3f ' % (k, v * 100)
36 |     print(message)
37 | 
--------------------------------------------------------------------------------
/model/loss/dice.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.utils.data
4 | import torch.nn as nn
5 | from kornia.losses import dice_loss
6 | 
7 | class DICELoss(nn.Module):
8 |     def __init__(self):
9 |         super(DICELoss, self).__init__()
10 | 
11 |     def forward(self, input, target):
12 |         target = target.squeeze(1)
13 |         loss = dice_loss(input, target)  # kornia applies softmax to the logits internally
14 | 
15 |         return loss
16 | 
17 | ### version 2
18 | """
19 | class DICELoss(nn.Module):
20 |     def __init__(self, eps=1e-5):
21 |         super(DICELoss, self).__init__()
22 |         self.eps = eps
23 | 
24 |     def to_one_hot(self, target):
25 |         N, C, H, W = target.size()
26 |         assert C == 1
27 |         target = torch.zeros(N, 2, H, W).to(target.device).scatter_(1, target, 1)
28 |         return target
29 | 
30 |     def forward(self, input, target):
31 |         N, C, _, _ = input.size()
32 |         input = F.softmax(input, dim=1)
33 | 
34 |         #target = self.to_one_hot(target)
35 |         target = torch.eye(2)[target.squeeze(1)]
36 |         target = target.permute(0, 3, 1, 2).type_as(input)
37 | 
38 |         dims = tuple(range(1, target.ndimension()))
39 |         inter = torch.sum(input * target, dims)
40 |         cardinality = torch.sum(input + target, dims)
41 |         loss = ((2. * inter) / (cardinality + self.eps)).mean()
42 | 
43 |         return 1 - loss
44 | """
45 | 
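A quick shape check for the kornia-backed wrapper above (a minimal sketch, assuming kornia is installed; `DICELoss` takes raw two-class logits and a change mask carrying a singleton channel dimension):

```python
import torch
from model.loss.dice import DICELoss

criterion = DICELoss()
logits = torch.randn(4, 2, 64, 64, requires_grad=True)  # raw scores; dice_loss softmaxes internally
target = torch.randint(0, 2, (4, 1, 64, 64))            # (N, 1, H, W); squeeze(1) -> (N, H, W) inside
loss = criterion(logits, target)                        # scalar tensor
loss.backward()
```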
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code for the TCSVT article 'Cross-Level Attentive Feature Aggregation for Change Detection'.
2 | ---------------------------------------------
3 | Here I provide the PyTorch implementation for CLAFA.
4 | 
5 | 
6 | ## ENVIRONMENT
7 | >RTX 3090<br>
8 | >python 3.8.8<br>
9 | >PyTorch 1.11.0<br>
10 | >mmcv 1.6.0
11 | 
12 | ## Installation
13 | Clone this repo:
14 | 
15 | ```shell
16 | git clone https://github.com/xingronaldo/CLAFA.git
17 | cd CLAFA
18 | ```
19 | 
20 | * Install dependencies
21 | 
22 | All dependencies can be installed via `pip`.
23 | 
24 | ## Dataset Preparation
25 | Download data and add them to `./datasets`.
26 | 
27 | 
28 | ## Test
29 | Here I provide the trained models for the SV-CD dataset [Baidu Netdisk, code: CLAFA](https://pan.baidu.com/s/1nfqqXA3DsZtU4BtHY-3YOg).
30 | 
31 | Put them in `./checkpoints`.
32 | 
33 | 
34 | * Test on the SV-CD dataset with the MobileNetV2 backbone
35 | 
36 | ```shell
37 | python test.py --backbone mobilenetv2 --name SV_mobilenetv2 --gpu_ids 1
38 | ```
39 | 
40 | * Test on the SV-CD dataset with the ResNet18d backbone
41 | 
42 | ```shell
43 | python test.py --backbone resnet18d --name SV_resnet18d --gpu_ids 1
44 | ```
45 | 
46 | ## Train & Validation
47 | ```shell
48 | python trainval.py --gpu_ids 1
49 | ```
50 | All the hyperparameters can be adjusted in `option.py`.
51 | 
52 | 
53 | ## Contact
54 | Email: guangxingwang@mail.nwpu.edu.cn
55 | 
56 | 
57 | 
58 | 
--------------------------------------------------------------------------------
/model/loss/focal.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
3 | """
4 | 
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from einops import rearrange
11 | 
12 | 
13 | class FocalLoss(nn.Module):
14 |     def __init__(self, alpha=0.25, gamma=4.0):
15 |         super(FocalLoss, self).__init__()
16 |         self.alpha = alpha
17 |         self.gamma = gamma
18 |         if isinstance(alpha, (float, int)):
19 |             self.alpha = torch.as_tensor([alpha, 1 - alpha])
20 |         if isinstance(alpha, list):
21 |             self.alpha = torch.as_tensor(alpha)
22 | 
23 |     def forward(self, input, target):
24 |         N, C, H, W = input.size()
25 |         assert C == 2
26 |         # input = input.view(N, C, -1)
27 |         # input = input.transpose(1, 2)
28 |         # input = input.contiguous().view(-1, C)
29 |         input = rearrange(input, 'b c h w -> (b h w) c')
30 |         # input = input.contiguous().view(-1)
31 | 
32 |         target = target.view(-1, 1)
33 |         logpt = F.log_softmax(input, dim=1)
34 |         logpt = logpt.gather(1, target)
35 |         logpt = logpt.view(-1)
36 |         pt = Variable(logpt.data.exp())  # Variable is a no-op wrapper in PyTorch >= 0.4
37 | 
38 |         if self.alpha is not None:
39 |             if self.alpha.type() != input.data.type():
40 |                 self.alpha = self.alpha.type_as(input.data)
41 |             at = self.alpha.gather(0, target.data.view(-1))
42 |             logpt = logpt * Variable(at)
43 |         loss = -1 * (1-pt)**self.gamma * logpt
44 | 
45 |         return loss.mean()
46 | 
47 | 
48 | 
--------------------------------------------------------------------------------
/SCD/util/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied and modified from
3 | https://github.com/NVIDIA/pix2pixHD/tree/master/util
4 | """
5 | from __future__ import print_function
6 | import os
7 | import torch
8 | import numpy as np
9 | from PIL import Image
10 | from torchvision import utils
11 | 
12 | 
13 | def mkdirs(paths):
14 |     if isinstance(paths, list) and not isinstance(paths, str):
15 |         for path in paths:
16 |             mkdir(path)
17 |     else:
18 |         mkdir(paths)
19 | 
20 | 
21 | def mkdir(path):
22 |     if not os.path.exists(path):
23 |         os.makedirs(path)
24 | 
25 | 
26 | def save_image(image_numpy, image_path):
27 |     image_pil = 
Image.fromarray(np.array(image_numpy,dtype=np.uint8)) 28 | image_pil.save(image_path) 29 | 30 | 31 | def make_numpy_grid(tensor_data, pad_value=0,padding=0): 32 | tensor_data = tensor_data.detach() 33 | vis = utils.make_grid(tensor_data, pad_value=pad_value,padding=padding) 34 | vis = np.array(vis.cpu()).transpose((1,2,0)) 35 | if vis.shape[2] == 1: 36 | vis = np.stack([vis, vis, vis], axis=-1) 37 | return vis 38 | 39 | 40 | def de_norm(tensor_data): 41 | return tensor_data * 0.5 + 0.5 42 | 43 | 44 | def replace_batchnorm(net): 45 | for child_name, child in net.named_children(): 46 | if hasattr(child, 'fuse'): 47 | setattr(net, child_name, child.fuse()) 48 | elif isinstance(child, torch.nn.Conv2d): 49 | child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0))) 50 | elif isinstance(child, torch.nn.BatchNorm2d): 51 | setattr(net, child_name, torch.nn.Identity()) 52 | else: 53 | replace_batchnorm(child) -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied and modified from 3 | https://github.com/NVIDIA/pix2pixHD/tree/master/util 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import torch 8 | import numpy as np 9 | from PIL import Image 10 | from torchvision import utils 11 | 12 | 13 | def mkdirs(paths): 14 | if isinstance(paths, list) and not isinstance(paths, str): 15 | for path in paths: 16 | mkdir(path) 17 | else: 18 | mkdir(paths) 19 | 20 | 21 | def mkdir(path): 22 | if not os.path.exists(path): 23 | os.makedirs(path) 24 | 25 | 26 | def save_image(image_numpy, image_path): 27 | image_pil = Image.fromarray(np.array(image_numpy,dtype=np.uint8)) 28 | image_pil.save(image_path) 29 | 30 | 31 | def make_numpy_grid(tensor_data, pad_value=0,padding=0): 32 | tensor_data = tensor_data.detach() 33 | vis = utils.make_grid(tensor_data, pad_value=pad_value,padding=padding) 34 | vis = np.array(vis.cpu()).transpose((1,2,0)) 35 | if vis.shape[2] == 1: 36 | vis = np.stack([vis, vis, vis], axis=-1) 37 | return vis 38 | 39 | 40 | def de_norm(tensor_data): 41 | return tensor_data * 0.5 + 0.5 42 | 43 | 44 | def replace_batchnorm(net): 45 | for child_name, child in net.named_children(): 46 | if hasattr(child, 'fuse'): 47 | setattr(net, child_name, child.fuse()) 48 | elif isinstance(child, torch.nn.Conv2d): 49 | child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0))) 50 | elif isinstance(child, torch.nn.BatchNorm2d): 51 | setattr(net, child_name, torch.nn.Identity()) 52 | else: 53 | replace_batchnorm(child) -------------------------------------------------------------------------------- /model/block/dcnv2.py: -------------------------------------------------------------------------------- 1 | from mmcv.ops import ModulatedDeformConv2dPack, modulated_deform_conv2d 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class DCNv2(ModulatedDeformConv2dPack): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | out_channels = self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1] 10 | pyconv_kernels = [1, 3, 5] 11 | pyconv_groups = [1, self.deform_groups // 2, self.deform_groups] 12 | pyconv_levels = [] 13 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups): 14 | pyconv_levels.append(nn.Sequential(nn.Conv2d(self.in_channels, out_channels, kernel_size=pyconv_kernel, 15 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=False), 16 | nn.BatchNorm2d(out_channels), 17 | 
nn.ReLU(True))) 18 | self.pyconv_levels = nn.Sequential(*pyconv_levels) 19 | self.offset = nn.Conv2d(out_channels * 3, out_channels, 1, bias=True) 20 | self.init_weights() 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore 23 | out = [] 24 | for level in self.pyconv_levels: 25 | out.append(level(x)) 26 | out = torch.cat(out, dim=1) 27 | 28 | out = self.offset(out) 29 | o1, o2, mask = torch.chunk(out, 3, dim=1) 30 | offset = torch.cat((o1, o2), dim=1) 31 | mask = torch.sigmoid(mask) 32 | return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, 33 | self.stride, self.padding, 34 | self.dilation, self.groups, 35 | self.deform_groups) 36 | 37 | -------------------------------------------------------------------------------- /model/util.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def init_method(*nets, init_type='normal'): 5 | for net in nets: 6 | for module in net.modules(): 7 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.ConvTranspose2d): 8 | if init_type == 'normal': 9 | nn.init.normal_(module.weight.data, 0.0, 0.02) 10 | elif init_type == 'xavier': 11 | nn.init.xavier_normal_(module.weight.data, gain=1.0) 12 | elif init_type == 'kaiming_normal': 13 | nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_in', nonlinearity='relu') 14 | elif init_type == 'kaiming_normal_out': 15 | nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_out', nonlinearity='relu') 16 | elif init_type == 'kaiming_uniform': 17 | nn.init.kaiming_uniform_(module.weight.data, a=0, mode='fan_in', nonlinearity='relu') 18 | elif init_type == 'trunc_normal': 19 | nn.init.trunc_normal_(module.weight.data, mean=0.0, std=1.0, a=- 2.0, b=2.0) 20 | elif init_type == 'orthogonal': 21 | nn.init.orthogonal_(module.weight.data, gain=1.0) 22 | else: 23 | raise NotImplementedError("initialization method [%s] is not implemented" % init_type) 24 | if module.bias is not None: 25 | nn.init.constant_(module.bias.data, 0.0) 26 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.LayerNorm): 27 | nn.init.normal_(module.weight.data, 1.0, 0.02) 28 | nn.init.constant_(module.bias.data, 0.0) 29 | 30 | print("initialize \\backbone networks with [%s]" % init_type) 31 | 32 | -------------------------------------------------------------------------------- /SCD/model/util.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def init_method(*nets, init_type='normal'): 5 | for net in nets: 6 | for module in net.modules(): 7 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.ConvTranspose2d): 8 | if init_type == 'normal': 9 | nn.init.normal_(module.weight.data, 0.0, 0.02) 10 | elif init_type == 'xavier': 11 | nn.init.xavier_normal_(module.weight.data, gain=1.0) 12 | elif init_type == 'kaiming_normal': 13 | nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_in', nonlinearity='relu') 14 | elif init_type == 'kaiming_normal_out': 15 | nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_out', nonlinearity='relu') 16 | elif init_type == 'kaiming_uniform': 17 | nn.init.kaiming_uniform_(module.weight.data, a=0, mode='fan_in', nonlinearity='relu') 18 | elif init_type == 'trunc_normal': 19 | nn.init.trunc_normal_(module.weight.data, mean=0.0, std=1.0, a=- 2.0, b=2.0) 20 | elif init_type == 'orthogonal': 21 | 
nn.init.orthogonal_(module.weight.data, gain=1.0) 22 | else: 23 | raise NotImplementedError("initialization method [%s] is not implemented" % init_type) 24 | if module.bias is not None: 25 | nn.init.constant_(module.bias.data, 0.0) 26 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.LayerNorm): 27 | nn.init.normal_(module.weight.data, 1.0, 0.02) 28 | nn.init.constant_(module.bias.data, 0.0) 29 | 30 | print("initialize \\backbone networks with [%s]" % init_type) 31 | 32 | -------------------------------------------------------------------------------- /SCD/util/metric.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | 5 | def cal_kappa(hist): 6 | if hist.sum() == 0: 7 | kappa = 0 8 | else: 9 | po = np.diag(hist).sum() / hist.sum() 10 | pe = np.matmul(hist.sum(1), hist.sum(0).T) / hist.sum() ** 2 11 | if pe == 1: 12 | kappa = 0 13 | else: 14 | kappa = (po - pe) / (1 - pe) 15 | 16 | return kappa 17 | 18 | 19 | class IOUandSek: 20 | def __init__(self, num_classes): 21 | self.num_classes = num_classes 22 | self.hist = np.zeros((num_classes, num_classes)) 23 | 24 | def _fast_hist(self, label_pred, label_true): 25 | mask = (label_true >= 0) & (label_true < self.num_classes) 26 | hist = np.bincount( 27 | self.num_classes * label_true[mask].astype(int) + 28 | label_pred[mask], minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes) 29 | 30 | return hist 31 | 32 | def add_batch(self, predictions, gts): 33 | for lp, lt in zip(predictions, gts): 34 | self.hist += self._fast_hist(lp.flatten(), lt.flatten()) 35 | 36 | def evaluate(self): 37 | confusion_matrix = np.zeros((2, 2)) 38 | confusion_matrix[0][0] = self.hist[0][0] 39 | confusion_matrix[0][1] = self.hist.sum(1)[0] - self.hist[0][0] 40 | confusion_matrix[1][0] = self.hist.sum(0)[0] - self.hist[0][0] 41 | confusion_matrix[1][1] = self.hist[1:, 1:].sum() 42 | 43 | iou = np.diag(confusion_matrix) / (confusion_matrix.sum(0) + 44 | confusion_matrix.sum(1) - np.diag(confusion_matrix)) 45 | miou = np.mean(iou) 46 | 47 | hist = self.hist.copy() 48 | hist[0][0] = 0 49 | kappa = cal_kappa(hist) 50 | sek = kappa * math.exp(iou[1] - 1) 51 | 52 | score = 0.3 * miou + 0.7 * sek 53 | 54 | return score, miou, sek 55 | 56 | def miou(self): 57 | confusion_matrix = self.hist[1:, 1:] 58 | iou = np.diag(confusion_matrix) / (confusion_matrix.sum(0) + confusion_matrix.sum(1) - np.diag(confusion_matrix)) 59 | 60 | return iou, np.mean(iou) 61 | -------------------------------------------------------------------------------- /SCD/model/block/dcnv2.py: -------------------------------------------------------------------------------- 1 | from mmcv.ops import ModulatedDeformConv2dPack, modulated_deform_conv2d 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class DCNv2(ModulatedDeformConv2dPack): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | out_channels = self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1] 10 | pyconv_kernels = [1, 3, 5] 11 | pyconv_groups = [1, self.deform_groups // 2, self.deform_groups] 12 | pyconv_levels = [] 13 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups): 14 | pyconv_levels.append(nn.Sequential(nn.Conv2d(self.in_channels, out_channels, kernel_size=pyconv_kernel, 15 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=False), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU(True))) 18 | self.pyconv_levels = 
nn.Sequential(*pyconv_levels) 19 | self.offset = nn.Conv2d(out_channels * 3, out_channels, 1, bias=True) 20 | self.init_weights() 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore 23 | out = [] 24 | for level in self.pyconv_levels: 25 | out.append(level(x)) 26 | out = torch.cat(out, dim=1) 27 | 28 | out = self.offset(out) 29 | o1, o2, mask = torch.chunk(out, 3, dim=1) 30 | offset = torch.cat((o1, o2), dim=1) 31 | mask = torch.sigmoid(mask) 32 | return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, 33 | self.stride, self.padding, 34 | self.dilation, self.groups, 35 | self.deform_groups) 36 | 37 | 38 | from thop import profile 39 | 40 | 41 | # x = torch.randn(2, 512, 8, 8) 42 | # model = DCNv2(in_channels=512, out_channels=128, kernel_size=3, padding=1, deform_groups=4) 43 | # y = model(x) 44 | # print(y.shape) 45 | # flops, params = profile(model, inputs=(x,)) 46 | # print('FLOPs = ' + str(flops / 1000 ** 3) + 'G') 47 | # print('Params = ' + str(params / 1000 ** 2) + 'M') 48 | -------------------------------------------------------------------------------- /model/block/schedular.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/huggingface/transformers 3 | """ 4 | 5 | import math 6 | from functools import partial 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import LambdaLR 9 | 10 | 11 | def _get_cosine_schedule_with_warmup_lr_lambda( 12 | current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float 13 | ): 14 | if current_step < num_warmup_steps: 15 | return float(current_step) / float(max(1, num_warmup_steps)) 16 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 17 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 18 | 19 | 20 | def get_cosine_schedule_with_warmup( 21 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1): 22 | """ 23 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 24 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 25 | initial lr set in the optimizer. 26 | Args: 27 | optimizer ([`~torch.optim.Optimizer`]): 28 | The optimizer for which to schedule the learning rate. 29 | num_warmup_steps (`int`): 30 | The number of steps for the warmup phase. 31 | num_training_steps (`int`): 32 | The total number of training steps. 33 | num_cycles (`float`, *optional*, defaults to 0.5): 34 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 35 | following a half-cosine). 36 | last_epoch (`int`, *optional*, defaults to -1): 37 | The index of the last epoch when resuming training. 38 | Return: 39 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
40 | """ 41 | 42 | lr_lambda = partial( 43 | _get_cosine_schedule_with_warmup_lr_lambda, 44 | num_warmup_steps=num_warmup_steps, 45 | num_training_steps=num_training_steps, 46 | num_cycles=num_cycles, 47 | ) 48 | return LambdaLR(optimizer, lr_lambda, last_epoch) -------------------------------------------------------------------------------- /SCD/model/block/schedular.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/huggingface/transformers 3 | """ 4 | 5 | import math 6 | from functools import partial 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import LambdaLR 9 | 10 | 11 | def _get_cosine_schedule_with_warmup_lr_lambda( 12 | current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float 13 | ): 14 | if current_step < num_warmup_steps: 15 | return float(current_step) / float(max(1, num_warmup_steps)) 16 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 17 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 18 | 19 | 20 | def get_cosine_schedule_with_warmup( 21 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1): 22 | """ 23 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 24 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 25 | initial lr set in the optimizer. 26 | Args: 27 | optimizer ([`~torch.optim.Optimizer`]): 28 | The optimizer for which to schedule the learning rate. 29 | num_warmup_steps (`int`): 30 | The number of steps for the warmup phase. 31 | num_training_steps (`int`): 32 | The total number of training steps. 33 | num_cycles (`float`, *optional*, defaults to 0.5): 34 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 35 | following a half-cosine). 36 | last_epoch (`int`, *optional*, defaults to -1): 37 | The index of the last epoch when resuming training. 38 | Return: 39 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
40 | """ 41 | 42 | lr_lambda = partial( 43 | _get_cosine_schedule_with_warmup_lr_lambda, 44 | num_warmup_steps=num_warmup_steps, 45 | num_training_steps=num_training_steps, 46 | num_cycles=num_cycles, 47 | ) 48 | return LambdaLR(optimizer, lr_lambda, last_epoch) -------------------------------------------------------------------------------- /SCD/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from option import Options 3 | from data.cd_dataset import DataLoader 4 | from model.create_model import create_model 5 | from tqdm import tqdm 6 | import math 7 | from util.palette import color_map 8 | from util.metric import IOUandSek 9 | import os 10 | import numpy as np 11 | import random 12 | from PIL import Image 13 | 14 | if __name__ == '__main__': 15 | opt = Options().parse() 16 | #opt.batch_size = 1 17 | opt.phase = 'test' 18 | test_loader = DataLoader(opt) 19 | test_data = test_loader.load_data() 20 | test_size = len(test_loader) 21 | print("#testing images = %d" % test_size) 22 | 23 | opt.load_pretrain = True 24 | model = create_model(opt) 25 | 26 | tbar = tqdm(test_data) 27 | total_iters = test_size 28 | metric = IOUandSek(num_classes=7) 29 | 30 | model.eval() 31 | for i, _data in enumerate(tbar): 32 | cd_out, seg_out1, seg_out2 = model.inference(_data['img1'].cuda(), _data['img2'].cuda()) 33 | # update metric 34 | val_target = _data['cd_label'].detach() 35 | cd_out = torch.argmax(cd_out.detach(), dim=1) 36 | # val_pred = torch.where(val_pred > 0.5, torch.ones_like(val_pred), torch.zeros_like(val_pred)).long() 37 | seg_out1 = torch.argmax(seg_out1, dim=1).cpu().numpy() 38 | seg_out2 = torch.argmax(seg_out2, dim=1).cpu().numpy() 39 | cd_out = cd_out.cpu().numpy().astype(np.uint8) 40 | seg_out1[cd_out == 0] = 0 41 | seg_out2[cd_out == 0] = 0 42 | 43 | if opt.save_mask: 44 | cmap = color_map(opt.dataset) 45 | for i in range(seg_out1.shape[0]): 46 | mask = Image.fromarray(seg_out1[i].astype(np.uint8), mode="P") 47 | mask.putpalette(cmap) 48 | os.makedirs(os.path.join(opt.result_dir, 'test', 'im1'), exist_ok=True) 49 | mask.save(os.path.join(opt.result_dir, 'test', 'im1', _data['fname'][i])) 50 | 51 | mask = Image.fromarray(seg_out2[i].astype(np.uint8), mode="P") 52 | mask.putpalette(cmap) 53 | os.makedirs(os.path.join(opt.result_dir, 'test', 'im2'), exist_ok=True) 54 | mask.save(os.path.join(opt.result_dir, 'test', 'im2', _data['fname'][i])) 55 | 56 | metric.add_batch(seg_out1, _data['label1'].numpy()) 57 | metric.add_batch(seg_out2, _data['label2'].numpy()) 58 | score, miou, sek = metric.evaluate() 59 | tbar.set_description("Score: %.3f, IOU: %.3f, SeK: %.3f" % (score * 100.0, miou * 100.0, sek * 100.0)) 60 | 61 | -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | class Options(): 5 | def __init__(self): 6 | self.parser = argparse.ArgumentParser() 7 | 8 | def init(self): 9 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0. 
use -1 for CPU') 10 | self.parser.add_argument('--name', type=str, default='SV_mobilenetv2') 11 | self.parser.add_argument('--dataroot', type=str, default='./datasets') 12 | self.parser.add_argument('--dataset', type=str, default='SV') 13 | self.parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='models are saved here') 14 | 15 | self.parser.add_argument('--result_dir', type=str, default='./results', help='results are saved here') 16 | self.parser.add_argument('--load_pretrain', type=bool, default=False) 17 | 18 | self.parser.add_argument('--phase', type=str, default='train') 19 | self.parser.add_argument('--input_size', type=int, default=256) 20 | self.parser.add_argument('--backbone', type=str, default='mobilenetv2') 21 | self.parser.add_argument('--fpn', type=str, default='fpn') 22 | self.parser.add_argument('--fpn_channels', type=int, default=128) 23 | self.parser.add_argument('--deform_groups', type=int, default=4) 24 | self.parser.add_argument('--gamma_mode', type=str, default='SE') 25 | self.parser.add_argument('--beta_mode', type=str, default='contextgatedconv') 26 | self.parser.add_argument('--num_heads', type=int, default=1) 27 | self.parser.add_argument('--num_points', type=int, default=8) 28 | self.parser.add_argument('--kernel_layers', type=int, default=1) 29 | self.parser.add_argument('--init_type', type=str, default='kaiming_normal') 30 | self.parser.add_argument('--alpha', type=float, default=0.25) 31 | self.parser.add_argument('--gamma', type=int, default=4, help='gamma for Focal loss') 32 | self.parser.add_argument('--dropout_rate', type=float, default=0.1) 33 | 34 | self.parser.add_argument('--batch_size', type=int, default=16) 35 | self.parser.add_argument('--num_epochs', type=int, default=200) 36 | self.parser.add_argument('--warmup_epochs', type=int, default=20) 37 | self.parser.add_argument('--num_workers', type=int, default=4, help='#threads for loading data') 38 | self.parser.add_argument('--lr', type=float, default=5e-4) 39 | self.parser.add_argument('--weight_decay', type=float, default=5e-4) 40 | 41 | def parse(self): 42 | self.init() 43 | self.opt = self.parser.parse_args() 44 | 45 | str_ids = self.opt.gpu_ids.split(',') 46 | self.opt.gpu_ids = [] 47 | for str_id in str_ids: 48 | id = int(str_id) 49 | if id >= 0: 50 | self.opt.gpu_ids.append(id) 51 | 52 | # set gpu ids 53 | if len(self.opt.gpu_ids) > 0: 54 | torch.cuda.set_device(self.opt.gpu_ids[0]) 55 | 56 | args = vars(self.opt) 57 | 58 | print('------------ Options -------------') 59 | for k, v in sorted(args.items()): 60 | print('%s: %s' % (str(k), str(v))) 61 | print('-------------- End ----------------') 62 | 63 | return self.opt 64 | -------------------------------------------------------------------------------- /SCD/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | class Options(): 5 | def __init__(self): 6 | self.parser = argparse.ArgumentParser() 7 | 8 | def init(self): 9 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0. 
use -1 for CPU') 10 | self.parser.add_argument('--name', type=str, default='SECOND_SCD_mobilenetv2_5') 11 | self.parser.add_argument('--dataroot', type=str, default='../../SupervisedCD/datasets') 12 | self.parser.add_argument('--dataset', type=str, default='SECOND_SCD') 13 | self.parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='models are saved here') 14 | 15 | self.parser.add_argument('--result_dir', type=str, default='./results', help='results are saved here') 16 | self.parser.add_argument('--load_pretrain', type=bool, default=False) 17 | 18 | self.parser.add_argument('--phase', type=str, default='train') 19 | self.parser.add_argument('--input_size', type=int, default=256) 20 | self.parser.add_argument('--backbone', type=str, default='mobilenetv2') 21 | self.parser.add_argument('--fpn', type=str, default='fpn') 22 | self.parser.add_argument('--fpn_channels', type=int, default=128) 23 | self.parser.add_argument('--deform_groups', type=int, default=4) 24 | self.parser.add_argument('--gamma_mode', type=str, default='SE') 25 | self.parser.add_argument('--beta_mode', type=str, default='contextgatedconv') 26 | self.parser.add_argument('--num_heads', type=int, default=1) 27 | self.parser.add_argument('--num_points', type=int, default=8) 28 | self.parser.add_argument('--kernel_layers', type=int, default=1) 29 | self.parser.add_argument('--init_type', type=str, default='kaiming_normal') 30 | self.parser.add_argument('--alpha', type=float, default=0.25) 31 | self.parser.add_argument('--gamma', type=int, default=4, help='gamma for Focal loss') 32 | self.parser.add_argument('--dropout_rate', type=float, default=0.1) 33 | self.parser.add_argument('--save_mask', type=bool, default=True) 34 | 35 | self.parser.add_argument('--batch_size', type=int, default=16) 36 | self.parser.add_argument('--num_epochs', type=int, default=50) 37 | self.parser.add_argument('--warmup_epochs', type=int, default=5) 38 | self.parser.add_argument('--num_workers', type=int, default=4, help='#threads for loading data') 39 | self.parser.add_argument('--lr', type=float, default=5e-4) 40 | self.parser.add_argument('--weight_decay', type=float, default=5e-4) 41 | 42 | def parse(self): 43 | self.init() 44 | self.opt = self.parser.parse_args() 45 | 46 | str_ids = self.opt.gpu_ids.split(',') 47 | self.opt.gpu_ids = [] 48 | for str_id in str_ids: 49 | id = int(str_id) 50 | if id >= 0: 51 | self.opt.gpu_ids.append(id) 52 | 53 | # set gpu ids 54 | if len(self.opt.gpu_ids) > 0: 55 | torch.cuda.set_device(self.opt.gpu_ids[0]) 56 | 57 | args = vars(self.opt) 58 | 59 | print('------------ Options -------------') 60 | for k, v in sorted(args.items()): 61 | print('%s: %s' % (str(k), str(v))) 62 | print('-------------- End ----------------') 63 | 64 | return self.opt 65 | -------------------------------------------------------------------------------- /data/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms.functional as TF 3 | from torchvision import transforms 4 | from torchvision.transforms import InterpolationMode 5 | 6 | 7 | class Transforms(object): 8 | def __call__(self, _data): 9 | img1, img2, cd_label = _data['img1'], _data['img2'], _data['cd_label'] 10 | 11 | if random.random() < 0.5: 12 | img1_ = img1 13 | img1 = img2 14 | img2 = img1_ 15 | 16 | if random.random() < 0.5: 17 | img1 = TF.hflip(img1) 18 | img2 = TF.hflip(img2) 19 | cd_label = TF.hflip(cd_label) 20 | 21 | if random.random() < 0.5: 22 | img1 = 
TF.vflip(img1) 23 | img2 = TF.vflip(img2) 24 | cd_label = TF.vflip(cd_label) 25 | 26 | if random.random() < 0.5: 27 | angles = [90, 180, 270] 28 | angle = random.choice(angles) 29 | img1 = TF.rotate(img1, angle) 30 | img2 = TF.rotate(img2, angle) 31 | cd_label = TF.rotate(cd_label, angle) 32 | ### We didnt use colorjitters for the SV dataset. 33 | """ 34 | if random.random() < 0.5: 35 | colorjitters = [] 36 | brightness_factor = random.uniform(0.5, 1.5) 37 | colorjitters.append(Lambda(lambda img: TF.adjust_brightness(img, brightness_factor))) 38 | contrast_factor = random.uniform(0.5, 1.5) 39 | colorjitters.append(Lambda(lambda img: TF.adjust_contrast(img, contrast_factor))) 40 | saturation_factor = random.uniform(0.5, 1.5) 41 | colorjitters.append(Lambda(lambda img: TF.adjust_saturation(img, saturation_factor))) 42 | random.shuffle(colorjitters) 43 | colorjitter = Compose(colorjitters) 44 | img1 = colorjitter(img1) 45 | img2 = colorjitter(img2) 46 | """ 47 | if random.random() < 0.5: 48 | i, j, h, w = transforms.RandomResizedCrop(size=(256, 256)).get_params(img=img1, scale=[0.333, 1.0], 49 | ratio=[0.75, 1.333]) 50 | img1 = TF.resized_crop(img1, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.BILINEAR) 51 | img2 = TF.resized_crop(img2, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.BILINEAR) 52 | cd_label = TF.resized_crop(cd_label, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.NEAREST) 53 | 54 | return {'img1': img1, 'img2': img2, 'cd_label': cd_label} 55 | 56 | 57 | class Lambda(object): 58 | def __init__(self, lambd): 59 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" 60 | self.lambd = lambd 61 | 62 | def __call__(self, img): 63 | return self.lambd(img) 64 | 65 | def __repr__(self): 66 | return self.__class__.__name__ + '()' 67 | 68 | 69 | class Compose(object): 70 | def __init__(self, transforms): 71 | self.transforms = transforms 72 | 73 | def __call__(self, img): 74 | for t in self.transforms: 75 | img = t(img) 76 | return img 77 | 78 | def __repr__(self): 79 | format_string = self.__class__.__name__ + '(' 80 | for t in self.transforms: 81 | format_string += '\n' 82 | format_string += ' {0}'.format(t) 83 | format_string += '\n)' 84 | return format_string 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /data/cd_dataset.py: -------------------------------------------------------------------------------- 1 | from .transform import Transforms 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | 10 | def make_dataset(dir): 11 | img_paths = [] 12 | names = [] 13 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 14 | 15 | for root, _, fnames in sorted(os.walk(dir)): 16 | for fname in fnames: 17 | path = os.path.join(root, fname) 18 | img_paths.append(path) 19 | names.append(fname) 20 | 21 | return img_paths, names 22 | 23 | 24 | class Load_Dataset(Dataset): 25 | def __init__(self, opt): 26 | super(Load_Dataset, self).__init__() 27 | self.opt = opt 28 | 29 | self.dir1 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'A') 30 | self.t1_paths, self.fnames = sorted(make_dataset(self.dir1)) 31 | 32 | self.dir2 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'B') 33 | self.t2_paths, _ = sorted(make_dataset(self.dir2)) 34 | 35 | self.dir_label = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'label') 36 | 
self.label_paths, _ = sorted(make_dataset(self.dir_label)) 37 | 38 | self.dataset_size = len(self.t1_paths) 39 | 40 | self.normalize = transforms.Compose([transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 41 | self.transform = transforms.Compose([Transforms()]) 42 | self.to_tensor = transforms.Compose([transforms.ToTensor()]) 43 | 44 | def __len__(self): 45 | return self.dataset_size 46 | 47 | def __getitem__(self, index): 48 | t1_path = self.t1_paths[index] 49 | fname = self.fnames[index] 50 | img1 = Image.open(t1_path) 51 | 52 | t2_path = self.t2_paths[index] 53 | img2 = Image.open(t2_path) 54 | 55 | label_path = self.label_paths[index] 56 | label = np.array(Image.open(label_path)) // 255 57 | cd_label = Image.fromarray(label) 58 | 59 | if self.opt.phase == 'train': 60 | _data = self.transform({'img1': img1, 'img2': img2, 'cd_label': cd_label}) 61 | img1, img2, cd_label = _data['img1'], _data['img2'], _data['cd_label'] 62 | 63 | img1 = self.to_tensor(img1) 64 | img2 = self.to_tensor(img2) 65 | img1 = self.normalize(img1) 66 | img2 = self.normalize(img2) 67 | cd_label = torch.from_numpy(np.array(cd_label)) 68 | input_dict = {'img1': img1, 'img2': img2, 'cd_label': cd_label, 'fname': fname} 69 | 70 | return input_dict 71 | 72 | 73 | class DataLoader(torch.utils.data.Dataset): 74 | 75 | def __init__(self, opt): 76 | self.dataset = Load_Dataset(opt) 77 | self.dataloader = torch.utils.data.DataLoader(self.dataset, 78 | batch_size=opt.batch_size, 79 | shuffle=opt.phase=='train', 80 | pin_memory=True, 81 | drop_last=opt.phase=='train', 82 | num_workers=int(opt.num_workers) 83 | ) 84 | 85 | def load_data(self): 86 | return self.dataloader 87 | 88 | def __len__(self): 89 | return len(self.dataset) 90 | -------------------------------------------------------------------------------- /model/block/heads.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/swz30/MIRNet 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from .convs import GatedConv2d, ContextGatedConv2d 8 | 9 | 10 | class GatedResidualUp(nn.Module): 11 | def __init__(self, in_channels, up_mode='conv', gate_mode='gated'): 12 | super(GatedResidualUp, self).__init__() 13 | if up_mode == 'conv': 14 | self.residual_up = nn.Sequential(nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1, 15 | output_padding=1, bias=False), 16 | nn.BatchNorm2d(in_channels), 17 | nn.ReLU(True)) 18 | elif up_mode == 'bilinear': 19 | self.residual_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 20 | 21 | if gate_mode == 'gated': 22 | self.gate = GatedConv2d(in_channels, in_channels // 2) 23 | elif gate_mode == 'context_gated': 24 | self.gate = ContextGatedConv2d(in_channels, in_channels // 2) 25 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 26 | nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=True) 27 | ) 28 | self.relu = nn.ReLU(True) 29 | 30 | def forward(self, x): 31 | residual = self.residual_up(x) 32 | residual = self.gate(residual) 33 | up = self.up(x) 34 | out = self.relu(up + residual) 35 | return out 36 | 37 | 38 | class GatedResidualUpHead(nn.Module): 39 | def __init__(self, in_channels=128, num_classes=1, dropout_rate=0.15): 40 | super(GatedResidualUpHead, self).__init__() 41 | 42 | self.up = nn.Sequential(GatedResidualUp(in_channels), 43 | GatedResidualUp(in_channels // 2)) 44 | self.smooth = 
nn.Sequential(nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, stride=1, padding=1), 45 | nn.ReLU(True), 46 | nn.Dropout2d(dropout_rate)) 47 | self.final = nn.Conv2d(in_channels // 4, num_classes, 1) 48 | 49 | def forward(self, x): 50 | x = self.up(x) 51 | x = self.smooth(x) 52 | x = self.final(x) 53 | 54 | return x 55 | 56 | 57 | class FCNHead(nn.Module): 58 | def __init__(self, in_channels, num_classes, num_convs=1, dropout_rate=0.15): 59 | self.num_convs = num_convs 60 | super(FCNHead, self).__init__() 61 | inter_channels = in_channels // 4 62 | 63 | convs = [] 64 | convs.append(nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 65 | nn.BatchNorm2d(inter_channels), 66 | nn.ReLU(True))) 67 | for i in range(num_convs - 1): 68 | convs.append(nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 69 | nn.BatchNorm2d(inter_channels), 70 | nn.ReLU(True))) 71 | self.convs = nn.Sequential(*convs) 72 | self.final = nn.Conv2d(inter_channels, num_classes, 1) 73 | 74 | def forward(self, x): 75 | out = self.convs(x) 76 | out = self.final(out) 77 | 78 | return out 79 | 80 | -------------------------------------------------------------------------------- /SCD/data/transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import random 3 | import torchvision.transforms.functional as TF 4 | from torchvision import transforms 5 | from torchvision.transforms import InterpolationMode 6 | 7 | 8 | class Transforms(object): 9 | def __call__(self, _data): 10 | img1, img2, label1, label2, cd_label = _data['img1'], _data['img2'], _data['label1'], _data['label2'], _data['cd_label'] 11 | 12 | if random.random() < 0.5: 13 | img1_ = img1 14 | img1 = img2 15 | img2 = img1_ 16 | label1_ = label1 17 | label1 = label2 18 | label2 = label1_ 19 | 20 | if random.random() < 0.5: 21 | img1 = TF.hflip(img1) 22 | img2 = TF.hflip(img2) 23 | label1 = TF.hflip(label1) 24 | label2 = TF.hflip(label2) 25 | cd_label = TF.hflip(cd_label) 26 | 27 | if random.random() < 0.5: 28 | img1 = TF.vflip(img1) 29 | img2 = TF.vflip(img2) 30 | label1 = TF.vflip(label1) 31 | label2 = TF.vflip(label2) 32 | cd_label = TF.vflip(cd_label) 33 | 34 | if random.random() < 0.5: 35 | angles = [90, 180, 270] 36 | angle = random.choice(angles) 37 | img1 = TF.rotate(img1, angle) 38 | img2 = TF.rotate(img2, angle) 39 | label1 = TF.rotate(label1, angle) 40 | label2 = TF.rotate(label2, angle) 41 | cd_label = TF.rotate(cd_label, angle) 42 | 43 | if random.random() < 0.5: 44 | colorjitters = [] 45 | brightness_factor = random.uniform(0.5, 1.5) 46 | colorjitters.append(Lambda(lambda img: TF.adjust_brightness(img, brightness_factor))) 47 | contrast_factor = random.uniform(0.5, 1.5) 48 | colorjitters.append(Lambda(lambda img: TF.adjust_contrast(img, contrast_factor))) 49 | saturation_factor = random.uniform(0.5, 1.5) 50 | colorjitters.append(Lambda(lambda img: TF.adjust_saturation(img, saturation_factor))) 51 | random.shuffle(colorjitters) 52 | colorjitter = Compose(colorjitters) 53 | img1 = colorjitter(img1) 54 | img2 = colorjitter(img2) 55 | 56 | if random.random() < 0.5: 57 | i, j, h, w = transforms.RandomResizedCrop(size=(256, 256)).get_params(img=img1, scale=[0.333, 1.0], 58 | ratio=[0.75, 1.333]) 59 | img1 = TF.resized_crop(img1, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.BILINEAR) 60 | img2 = TF.resized_crop(img2, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.BILINEAR) 61 | label1 = 
TF.resized_crop(label1, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.NEAREST) 62 | label2 = TF.resized_crop(label2, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.NEAREST) 63 | cd_label = TF.resized_crop(cd_label, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.NEAREST) 64 | 65 | return {'img1': img1, 'img2': img2, 'label1': label1, 'label2': label2, 'cd_label': cd_label} 66 | 67 | 68 | class Lambda(object): 69 | def __init__(self, lambd): 70 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" 71 | self.lambd = lambd 72 | 73 | def __call__(self, img): 74 | return self.lambd(img) 75 | 76 | def __repr__(self): 77 | return self.__class__.__name__ + '()' 78 | 79 | 80 | class Compose(object): 81 | def __init__(self, transforms): 82 | self.transforms = transforms 83 | 84 | def __call__(self, img): 85 | for t in self.transforms: 86 | img = t(img) 87 | return img 88 | 89 | def __repr__(self): 90 | format_string = self.__class__.__name__ + '(' 91 | for t in self.transforms: 92 | format_string += '\n' 93 | format_string += ' {0}'.format(t) 94 | format_string += '\n)' 95 | return format_string 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /trainval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from option import Options 3 | from data.cd_dataset import DataLoader 4 | from model.create_model import create_model 5 | from tqdm import tqdm 6 | import math 7 | from util.metric_tool import ConfuseMatrixMeter 8 | import os 9 | import numpy as np 10 | import random 11 | 12 | 13 | def setup_seed(seed): 14 | os.environ['PYTHONHASHSEED'] = str(seed) 15 | np.random.seed(seed) 16 | random.seed(seed) 17 | 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | torch.backends.cudnn.deterministic = False #! 22 | torch.backends.cudnn.benchmark = True #! 23 | torch.backends.cudnn.enabled = True #! 
for accelerating training 24 | 25 | 26 | class Trainval(object): 27 | def __init__(self, opt): 28 | self.opt = opt 29 | 30 | train_loader = DataLoader(opt) 31 | self.train_data = train_loader.load_data() 32 | train_size = len(train_loader) 33 | print("#training images = %d" % train_size) 34 | opt.phase = 'val' 35 | val_loader = DataLoader(opt) 36 | self.val_data = val_loader.load_data() 37 | val_size = len(val_loader) 38 | print("#validation images = %d" % val_size) 39 | opt.phase = 'train' 40 | 41 | self.model = create_model(opt) 42 | self.optimizer = self.model.optimizer 43 | self.schedular = self.model.schedular 44 | 45 | self.iters = 0 46 | self.total_iters = math.ceil(train_size / opt.batch_size) * opt.num_epochs 47 | self.previous_best = 0.0 48 | self.running_metric = ConfuseMatrixMeter(n_class=2) 49 | 50 | def train(self): 51 | tbar = tqdm(self.train_data, ncols=80) 52 | opt.phase = 'train' 53 | _loss = 0.0 54 | _focal_loss = 0.0 55 | _dice_loss = 0.0 56 | 57 | for i, data in enumerate(tbar): 58 | self.model.detector.train() 59 | focal, dice, p2_loss, p3_loss, p4_loss, p5_loss = self.model(data['img1'].cuda(), data['img2'].cuda(), data['cd_label'].cuda()) 60 | loss = focal * 0.5 + dice + p3_loss + p4_loss + p5_loss 61 | self.optimizer.zero_grad() 62 | loss.backward() 63 | self.optimizer.step() 64 | self.schedular.step() 65 | _loss += loss.item() 66 | _focal_loss += focal.item() 67 | _dice_loss += dice.item() 68 | del loss 69 | 70 | tbar.set_description("Loss: %.3f, Focal: %.3f, Dice: %.3f, LR: %.6f" % 71 | (_loss / (i + 1), _focal_loss / (i + 1), _dice_loss / (i + 1), self.optimizer.param_groups[0]['lr'])) 72 | 73 | def val(self): 74 | tbar = tqdm(self.val_data, ncols=80) 75 | self.running_metric.clear() 76 | opt.phase = 'val' 77 | self.model.eval() 78 | 79 | with torch.no_grad(): 80 | for i, _data in enumerate(tbar): 81 | val_pred = self.model.inference(_data['img1'].cuda(), _data['img2'].cuda()) 82 | # update metric 83 | val_target = _data['cd_label'].detach() 84 | val_pred = torch.argmax(val_pred.detach(), dim=1) 85 | _ = self.running_metric.update_cm(pr=val_pred.cpu().numpy(), gt=val_target.cpu().numpy()) 86 | val_scores = self.running_metric.get_scores() 87 | message = '(phase: %s) ' % (self.opt.phase) 88 | for k, v in val_scores.items(): 89 | message += '%s: %.3f ' % (k, v * 100) 90 | print(message) 91 | 92 | if val_scores['mf1'] >= self.previous_best: 93 | self.model.save(self.opt.name, self.opt.backbone) 94 | self.previous_best = val_scores['mf1'] 95 | 96 | 97 | if __name__ == "__main__": 98 | opt = Options().parse() 99 | trainval = Trainval(opt) 100 | setup_seed(seed=1) 101 | 102 | for epoch in range(1, opt.num_epochs + 1): 103 | print("\n==> Name %s, Epoch %i, previous best = %.3f" % (opt.name, epoch, trainval.previous_best * 100)) 104 | trainval.train() 105 | trainval.val() 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /SCD/data/cd_dataset.py: -------------------------------------------------------------------------------- 1 | from .transform import Transforms 2 | from util.palette import Color2Index 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | import random 7 | import torch 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms 10 | 11 | 12 | def make_dataset(dir): 13 | img_paths = [] 14 | names = [] 15 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 16 | 17 | for root, _, fnames in sorted(os.walk(dir)): 18 | for fname in fnames: 19 | path = 
os.path.join(root, fname) 20 | img_paths.append(path) 21 | names.append(fname) 22 | 23 | return img_paths, names 24 | 25 | 26 | class Load_Dataset(Dataset): 27 | def __init__(self, opt): 28 | super(Load_Dataset, self).__init__() 29 | self.opt = opt 30 | 31 | self.dir1 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'im1') 32 | self.t1_paths, self.fnames = sorted(make_dataset(self.dir1)) 33 | 34 | self.dir2 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'im2') 35 | self.t2_paths, _ = sorted(make_dataset(self.dir2)) 36 | 37 | self.dir_label1 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'label1') 38 | self.label1_paths, _ = sorted(make_dataset(self.dir_label1)) 39 | 40 | self.dir_label2 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'label2') 41 | self.label2_paths, _ = sorted(make_dataset(self.dir_label2)) 42 | 43 | self.dataset_size = len(self.t1_paths) 44 | 45 | self.normalize = transforms.Compose([transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 46 | self.transform = transforms.Compose([Transforms()]) 47 | self.to_tensor = transforms.Compose([transforms.ToTensor()]) 48 | 49 | 50 | def __len__(self): 51 | return self.dataset_size 52 | 53 | def __getitem__(self, index): 54 | t1_path = self.t1_paths[index] 55 | fname = self.fnames[index] 56 | img1 = Image.open(t1_path) 57 | 58 | t2_path = self.t2_paths[index] 59 | img2 = Image.open(t2_path) 60 | 61 | label1_path = self.label1_paths[index] 62 | label1 = Image.open(label1_path) 63 | label1 = Image.fromarray(Color2Index(self.opt.dataset, np.array(label1))) 64 | 65 | label2_path = self.label2_paths[index] 66 | label2 = Image.open(label2_path) 67 | label2 = Image.fromarray(Color2Index(self.opt.dataset, np.array(label2))) 68 | 69 | mask = np.array(label1) 70 | cd_label = np.ones_like(mask) 71 | cd_label[mask == 0] = 0 72 | cd_label = Image.fromarray(cd_label) 73 | 74 | if self.opt.phase == 'train': 75 | _data = self.transform( 76 | {'img1': img1, 'img2': img2, 'label1': label1, 'label2': label2, 'cd_label': cd_label}) 77 | img1, img2, label1, label2, cd_label = _data['img1'], _data['img2'], _data['label1'], _data['label2'], \ 78 | _data['cd_label'] 79 | 80 | img1 = self.to_tensor(img1) 81 | img2 = self.to_tensor(img2) 82 | img1 = self.normalize(img1) 83 | img2 = self.normalize(img2) 84 | label1 = torch.from_numpy(np.array(label1)).long() 85 | label2 = torch.from_numpy(np.array(label2)).long() 86 | cd_label = torch.from_numpy(np.array(cd_label)).long() 87 | input_dict = {'img1': img1, 'img2': img2, 'label1': label1, 'label2': label2, 'cd_label': cd_label, 88 | 'fname': fname} 89 | 90 | return input_dict 91 | 92 | 93 | class DataLoader(torch.utils.data.Dataset): 94 | 95 | def __init__(self, opt): 96 | self.dataset = Load_Dataset(opt) 97 | self.dataloader = torch.utils.data.DataLoader(self.dataset, 98 | batch_size=opt.batch_size, 99 | shuffle=opt.phase=='train', 100 | pin_memory=True, 101 | drop_last=opt.phase=='train', 102 | num_workers=int(opt.num_workers) 103 | ) 104 | 105 | def load_data(self): 106 | return self.dataloader 107 | 108 | def __len__(self): 109 | return len(self.dataset) 110 | -------------------------------------------------------------------------------- /model/block/vertical.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .convs import ConvBnRelu 5 | from mmcv.ops import MultiScaleDeformableAttention 6 | from einops import rearrange 7 | from torch import einsum 8 
| 9 | 10 | class ScaledSinuEmbedding(nn.Module): 11 | def __init__(self, dim): 12 | super().__init__() 13 | self.scale = nn.Parameter(torch.ones(1,)) 14 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 15 | self.register_buffer('inv_freq', inv_freq) 16 | 17 | def forward(self, x): 18 | n, device = x.shape[1], x.device 19 | t = torch.arange(n, device=device).type_as(self.inv_freq) 20 | sinu = einsum('i, j -> i j', t, self.inv_freq) 21 | emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1) 22 | return emb * self.scale 23 | 24 | 25 | def get_reference_points(spatial_shapes, device): 26 | reference_points_list = [] 27 | for lvl, (H_, W_) in enumerate(spatial_shapes): 28 | ref_y, ref_x = torch.meshgrid( 29 | torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 30 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 31 | ref_y = ref_y.reshape(-1)[None] / H_ 32 | ref_x = ref_x.reshape(-1)[None] / W_ 33 | ref = torch.stack((ref_x, ref_y), -1) 34 | reference_points_list.append(ref) 35 | reference_points = torch.cat(reference_points_list, 1) 36 | reference_points = reference_points[:, :, None] 37 | return reference_points 38 | 39 | 40 | ### Attend-then-Filter 41 | class VerticalFusion(nn.Module): 42 | def __init__(self, channels, num_heads=4, num_points=4, kernel_layers=1, up_kernel_size=5, enc_kernel_size=3): 43 | super(VerticalFusion, self).__init__() 44 | self.norm1 = nn.LayerNorm(channels) 45 | self.norm2 = nn.LayerNorm(channels) 46 | self.pos = ScaledSinuEmbedding(channels) 47 | self.crossattn = MultiScaleDeformableAttention(embed_dims=channels, num_levels=1, num_heads=num_heads, 48 | num_points=num_points, batch_first=True, dropout=0) 49 | convs = [] 50 | convs.append(ConvBnRelu(in_channels=channels, out_channels=channels)) 51 | for _ in range(kernel_layers - 1): 52 | convs.append(ConvBnRelu(in_channels=channels, out_channels=channels)) 53 | self.convs = nn.Sequential(*convs) 54 | self.enc = ConvBnRelu(channels, up_kernel_size ** 2, kernel_size=enc_kernel_size, 55 | stride=1, padding=enc_kernel_size // 2, dilation=1) 56 | 57 | self.upsmp = nn.Upsample(scale_factor=2, mode='nearest') 58 | self.unfold = nn.Unfold(kernel_size=up_kernel_size, dilation=2, 59 | padding=up_kernel_size // 2 * 2) 60 | 61 | def get_deform_inputs(self, x1, x2): 62 | _, _, H1, W1 = x1.size() 63 | _, _, H2, W2 = x2.size() 64 | spatial_shapes = torch.as_tensor([(H2, W2)], dtype=torch.long, device=x2.device) 65 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) 66 | reference_points = get_reference_points([(H1, W1)], x1.device) 67 | 68 | return reference_points, spatial_shapes, level_start_index 69 | 70 | def forward(self, x1, x2): 71 | #Attend 72 | reference_points, spatial_shapes, level_start_index = self.get_deform_inputs(x1, x2) 73 | B, C, H, W = x1.size() 74 | _, _, H2, W2 = x2.size() 75 | x1_, x2_ = x1.clone(), x2.clone() 76 | x1 = rearrange(x1, 'b c h w -> b (h w) c') 77 | x2 = rearrange(x2, 'b c h w -> b (h w) c') 78 | x1, x2 = self.norm1(x1), self.norm2(x2) 79 | query_pos = self.pos(x1) 80 | x = self.crossattn(query=x1, value=x2, reference_points=reference_points, spatial_shapes=spatial_shapes, 81 | level_start_index=level_start_index, query_pos=query_pos) 82 | x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) 83 | 84 | #Filter 85 | kernel = self.convs(x2_) 86 | kernel = self.enc(kernel) 87 | kernel = F.softmax(kernel, dim=1) 88 | # x = self.upsmp(x) 89 | x = F.interpolate(x, size=(H2, W2), mode='nearest') 
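# Reading aid for the filtering step below (explanatory comments added to the original code): with up_kernel_size = 5, `kernel` has shape (B, 25, H2, W2) and is softmax-normalized over its 25 kernel positions.
# The attended map `x` is first resized to (H2, W2); nn.Unfold(kernel_size=5, dilation=2) then yields (B, C*25, H2*W2), which is reshaped to (B, C, 25, H2, W2).
# The einsum takes, at every output pixel, a convex combination of the 25 dilated neighbors weighted by the predicted kernel - content-aware upsampling in the spirit of CARAFE-style kernel prediction.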
90 | x = self.unfold(x) 91 | # x = x.view(B, C, -1, H * 2, W * 2) 92 | x = x.view(B, C, -1, H2, W2) 93 | fuse = torch.einsum('bkhw,bckhw->bchw', [kernel, x]) 94 | fuse += x2_ 95 | 96 | return fuse 97 | 98 | -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import timm 5 | from .backbone.mobilenetv2 import mobilenet_v2 6 | from .block.fpn import FPN 7 | from .block.vertical import VerticalFusion 8 | from .block.convs import ConvBnRelu, DsBnRelu 9 | from .util import init_method 10 | from .block.heads import FCNHead, GatedResidualUpHead 11 | 12 | 13 | def get_backbone(backbone_name): 14 | if backbone_name == 'mobilenetv2': 15 | backbone = mobilenet_v2(pretrained=True, progress=True) 16 | backbone.channels = [16, 24, 32, 96, 320] 17 | elif backbone_name == 'resnet18d': 18 | backbone = timm.create_model('resnet18d', pretrained=True, features_only=True) 19 | backbone.channels = [64, 64, 128, 256, 512] 20 | else: 21 | raise NotImplementedError("BACKBONE [%s] is not implemented!\n" % backbone_name) 22 | return backbone 23 | 24 | 25 | def get_fpn(fpn_name, in_channels, out_channels, deform_groups=4, gamma_mode='SE', beta_mode='contextgatedconv'): 26 | if fpn_name == 'fpn': 27 | fpn = FPN(in_channels, out_channels, deform_groups, gamma_mode, beta_mode) 28 | else: 29 | raise NotImplementedError("FPN [%s] is not implemented!\n" % fpn_name) 30 | return fpn 31 | 32 | 33 | class Detector(nn.Module): 34 | def __init__(self, backbone_name='mobilenetv2', fpn_name='fpn', fpn_channels=128, 35 | deform_groups=4, gamma_mode='SE', beta_mode='contextgatedconv', 36 | num_heads=1, num_points=8, kernel_layers=1, dropout_rate=0.1, init_type='kaiming_normal'): 37 | super().__init__() 38 | self.backbone = get_backbone(backbone_name) 39 | self.fpn = get_fpn(fpn_name, in_channels=self.backbone.channels[-4:], out_channels=fpn_channels, 40 | deform_groups=deform_groups, gamma_mode=gamma_mode, beta_mode=beta_mode) 41 | self.p5_to_p4 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=4, 42 | kernel_layers=kernel_layers) 43 | self.p4_to_p3 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=8, 44 | kernel_layers=kernel_layers) 45 | self.p3_to_p2 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=16, 46 | kernel_layers=kernel_layers) 47 | 48 | self.p5_head = nn.Conv2d(fpn_channels, 2, 1) 49 | self.p4_head = nn.Conv2d(fpn_channels, 2, 1) 50 | self.p3_head = nn.Conv2d(fpn_channels, 2, 1) 51 | self.p2_head = nn.Conv2d(fpn_channels, 2, 1) 52 | self.project = nn.Sequential(nn.Conv2d(fpn_channels*4, fpn_channels, 1, bias=False), 53 | nn.BatchNorm2d(fpn_channels), 54 | nn.ReLU(True) 55 | ) 56 | self.head = GatedResidualUpHead(fpn_channels, 2, dropout_rate=dropout_rate) 57 | # init_method(self.fpn, self.p5_to_p4, self.p4_to_p3, self.p3_to_p2, self.p5_head, self.p4_head, 58 | # self.p3_head, self.p2_head, init_type=init_type) 59 | 60 | def forward(self, x1, x2): 61 | ### Extract backbone features 62 | t1_c1, t1_c2, t1_c3, t1_c4, t1_c5 = self.backbone.forward(x1) 63 | t2_c1, t2_c2, t2_c3, t2_c4, t2_c5 = self.backbone.forward(x2) 64 | t1_p2, t1_p3, t1_p4, t1_p5 = self.fpn([t1_c2, t1_c3, t1_c4, t1_c5]) 65 | t2_p2, t2_p3, t2_p4, t2_p5 = self.fpn([t2_c2, t2_c3, t2_c4, t2_c5]) 66 | 67 | diff_p2 = torch.abs(t1_p2 - t2_p2) 68 | diff_p3 = torch.abs(t1_p3 - t2_p3) 69 | diff_p4 = torch.abs(t1_p4 - 
t2_p4) 70 | diff_p5 = torch.abs(t1_p5 - t2_p5) 71 | 72 | fea_p5 = diff_p5 73 | pred_p5 = self.p5_head(fea_p5) 74 | fea_p4 = self.p5_to_p4(fea_p5, diff_p4) 75 | pred_p4 = self.p4_head(fea_p4) 76 | fea_p3 = self.p4_to_p3(fea_p4, diff_p3) 77 | pred_p3 = self.p3_head(fea_p3) 78 | fea_p2 = self.p3_to_p2(fea_p3, diff_p2) 79 | pred_p2 = self.p2_head(fea_p2) 80 | pred = self.head(fea_p2) 81 | 82 | pred_p2 = F.interpolate(pred_p2, size=(256, 256), mode='bilinear', align_corners=False) 83 | pred_p3 = F.interpolate(pred_p3, size=(256, 256), mode='bilinear', align_corners=False) 84 | pred_p4 = F.interpolate(pred_p4, size=(256, 256), mode='bilinear', align_corners=False) 85 | pred_p5 = F.interpolate(pred_p5, size=(256, 256), mode='bilinear', align_corners=False) 86 | 87 | return pred, pred_p2, pred_p3, pred_p4, pred_p5 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /model/create_model.py: -------------------------------------------------------------------------------- 1 | from .network import Detector 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | import os 7 | import torch.optim as optim 8 | from .block.schedular import get_cosine_schedule_with_warmup 9 | from .loss.focal import FocalLoss 10 | from .loss.dice import DICELoss 11 | 12 | def get_model(backbone_name='mobilenetv2', fpn_name='fpn', fpn_channels=128, deform_groups=4, 13 | gamma_mode='SE', beta_mode='contextgatedconv', num_heads=1, num_points=8, kernel_layers=1, 14 | dropout_rate=0.1, init_type='kaiming_normal'): 15 | detector = Detector(backbone_name, fpn_name, fpn_channels, deform_groups, gamma_mode, beta_mode, 16 | num_heads, num_points, kernel_layers, dropout_rate, init_type) 17 | print(detector) 18 | 19 | return detector 20 | 21 | 22 | class Model(nn.Module): 23 | def __init__(self, opt): 24 | super(Model, self).__init__() 25 | self.device = torch.device("cuda:%s" % opt.gpu_ids[0] if torch.cuda.is_available() else "cpu") 26 | self.opt = opt 27 | self.base_lr = opt.lr 28 | self.save_dir = os.path.join(opt.checkpoint_dir, opt.name) 29 | os.makedirs(self.save_dir, exist_ok=True) 30 | 31 | self.detector = get_model(backbone_name=opt.backbone, fpn_name=opt.fpn, fpn_channels=opt.fpn_channels, 32 | deform_groups=opt.deform_groups, gamma_mode=opt.gamma_mode, beta_mode=opt.beta_mode, 33 | num_heads=opt.num_heads, num_points=opt.num_points, kernel_layers=opt.kernel_layers, 34 | dropout_rate=opt.dropout_rate, init_type=opt.init_type) 35 | self.focal = FocalLoss(alpha=opt.alpha, gamma=opt.gamma) 36 | self.dice = DICELoss() 37 | 38 | self.optimizer = optim.AdamW(self.detector.parameters(), lr=opt.lr, weight_decay=opt.weight_decay) 39 | # Here, 625 = #training images // batch_size. 
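# (Illustration, not a value from this repo's configs: a training set of 20,000 pairs at batch_size 32 gives 20000 // 32 = 625 steps per epoch, so warmup_epochs = 10 would warm up over 625 * 10 = 6250 optimizer steps.)
# A sturdier alternative is to derive it at runtime instead of hard-coding 625, e.g. steps_per_epoch = math.ceil(num_train_images / opt.batch_size), with num_train_images taken from the dataset (both names here are hypothetical).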
40 | self.schedular = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=625 * opt.warmup_epochs, 41 | num_training_steps=625 * opt.num_epochs) 42 | if opt.load_pretrain: 43 | self.load_ckpt(self.detector, self.optimizer, opt.name, opt.backbone) 44 | self.detector.cuda() 45 | 46 | print("---------- Networks initialized -------------") 47 | 48 | def forward(self, x1, x2, label): 49 | pred, pred_p2, pred_p3, pred_p4, pred_p5 = self.detector(x1, x2) 50 | label = label.long() 51 | focal = self.focal(pred, label) 52 | dice = self.dice(pred, label) 53 | p2_loss = self.focal(pred_p2, label) * 0.5 + self.dice(pred_p2, label) 54 | p3_loss = self.focal(pred_p3, label) * 0.5 + self.dice(pred_p3, label) 55 | p4_loss = self.focal(pred_p4, label) * 0.5 + self.dice(pred_p4, label) 56 | p5_loss = self.focal(pred_p5, label) * 0.5 + self.dice(pred_p5, label) 57 | 58 | return focal, dice, p2_loss, p3_loss, p4_loss, p5_loss 59 | 60 | def inference(self, x1, x2): 61 | with torch.no_grad(): 62 | pred, _, _, _, _ = self.detector(x1, x2) 63 | return pred 64 | 65 | def load_ckpt(self, network, optimizer, name, backbone): 66 | save_filename = '%s_%s_best.pth' % (name, backbone) 67 | save_path = os.path.join(self.save_dir, save_filename) 68 | if not os.path.isfile(save_path): 69 | print("%s does not exist yet!" % save_path) 70 | raise FileNotFoundError("%s must exist!" % save_filename) 71 | else: 72 | checkpoint = torch.load(save_path, map_location=self.device) 73 | network.load_state_dict(checkpoint['network'], False) 74 | 75 | def save_ckpt(self, network, optimizer, model_name, backbone): 76 | save_filename = '%s_%s_best.pth' % (model_name, backbone) 77 | save_path = os.path.join(self.save_dir, save_filename) 78 | if os.path.exists(save_path): 79 | os.remove(save_path) 80 | torch.save({'network': network.cpu().state_dict(), 81 | 'optimizer': optimizer.state_dict()}, 82 | save_path) 83 | if torch.cuda.is_available(): 84 | network.cuda() 85 | 86 | def save(self, model_name, backbone): 87 | self.save_ckpt(self.detector, self.optimizer, model_name, backbone) 88 | 89 | def name(self): 90 | return self.opt.name 91 | 92 | 93 | def create_model(opt): 94 | model = Model(opt) 95 | print("model [%s] was created" % model.name()) 96 | 97 | return model.cuda() 98 | 99 | -------------------------------------------------------------------------------- /SCD/model/loss/focal.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | from einops import rearrange 11 | from typing import Optional, List 12 | from functools import partial 13 | from torch.nn.modules.loss import _Loss 14 | import numpy as np 15 | 16 | 17 | class BinaryFocalLoss(nn.Module): 18 | def __init__(self, alpha=0.25, gamma=4.0): 19 | super(BinaryFocalLoss, self).__init__() 20 | self.alpha = alpha 21 | self.gamma = gamma 22 | if isinstance(alpha, (float, int)): 23 | self.alpha = torch.as_tensor([alpha, 1 - alpha]) 24 | if isinstance(alpha, list): 25 | self.alpha = torch.as_tensor(alpha) 26 | 27 | def forward(self, input, target): 28 | N, C, H, W = input.size() 29 | assert C == 2 30 | # input = input.view(N, C, -1) 31 | # input = input.transpose(1, 2) 32 | # input = input.contiguous().view(-1, C) 33 | input = rearrange(input, 'b c h w -> (b h w) c') 34 | # input = input.contiguous().view(-1) 35 | 36 | 
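# The following lines compute the standard binary focal loss, FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t): gather() picks each pixel's log-probability of its true class, pt = exp(logpt) recovers the probability, and the (1 - pt)**gamma factor down-weights easy, already well-classified pixels. (Explanatory comment; note that Variable is a legacy wrapper and a no-op in modern PyTorch.)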
target = target.view(-1, 1) 37 | logpt = F.log_softmax(input, dim=1) 38 | logpt = logpt.gather(1, target) 39 | logpt = logpt.view(-1) 40 | pt = Variable(logpt.data.exp()) 41 | 42 | if self.alpha is not None: 43 | if self.alpha.type() != input.data.type(): 44 | self.alpha = self.alpha.type_as(input.data) 45 | at = self.alpha.gather(0, target.data.view(-1)) 46 | logpt = logpt * Variable(at) 47 | loss = -1 * (1-pt)**self.gamma * logpt 48 | 49 | return loss.mean() 50 | 51 | 52 | class FocalLoss(_Loss): 53 | def __init__( 54 | self, 55 | alpha: Optional[float] = 0.25, 56 | gamma: Optional[float] = 2.0, 57 | ignore_index: Optional[int] = None, 58 | normalized: bool = False, 59 | reduced_threshold: Optional[float] = None, 60 | ): 61 | super(FocalLoss, self).__init__() 62 | self.ignore_index = ignore_index 63 | self.focal_loss_fn = partial( 64 | focal_loss_with_logits, 65 | alpha=alpha, 66 | gamma=gamma, 67 | reduced_threshold=reduced_threshold, 68 | reduction='mean', 69 | normalized=normalized 70 | ) 71 | 72 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 73 | num_classes = y_pred.size(1) 74 | loss = 0 75 | 76 | # Filter anchors with -1 label from loss computation 77 | if self.ignore_index is not None: 78 | not_ignored = y_true != self.ignore_index 79 | 80 | for cls in range(num_classes): 81 | cls_y_true = (y_true == cls).long() 82 | cls_y_pred = y_pred[:, cls, ...] 83 | 84 | if self.ignore_index is not None: 85 | cls_y_true = cls_y_true[not_ignored] 86 | cls_y_pred = cls_y_pred[not_ignored] 87 | 88 | loss += self.focal_loss_fn(cls_y_pred, cls_y_true) 89 | 90 | return loss 91 | 92 | 93 | def focal_loss_with_logits( 94 | output: torch.Tensor, 95 | target: torch.Tensor, 96 | gamma: float = 2.0, 97 | alpha: Optional[float] = 0.25, 98 | reduction: str = 'mean', 99 | normalized: bool = True, 100 | reduced_threshold: Optional[float] = None, 101 | eps: float = 1e-6 102 | ) -> torch.Tensor: 103 | 104 | target = target.type(output.type()) 105 | logpt = F.binary_cross_entropy_with_logits(output, target, reduction='none') 106 | pt = torch.exp(-logpt) 107 | 108 | # compute the loss 109 | if reduced_threshold is None: 110 | focal_term = (1.0 - pt).pow(gamma) 111 | else: 112 | focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) 113 | focal_term[pt < reduced_threshold] = 1 114 | 115 | loss = focal_term * logpt 116 | 117 | if alpha is not None: 118 | loss *= alpha * target + (1 - alpha) * (1 - target) 119 | 120 | if normalized: 121 | norm_factor = focal_term.sum().clamp_min(eps) 122 | loss /= norm_factor 123 | 124 | if reduction == 'mean': 125 | loss = loss.mean() 126 | if reduction == 'sum': 127 | loss = loss.sum() 128 | if reduction == 'batchwise_mean': 129 | loss = loss.sum(0) 130 | 131 | return loss 132 | 133 | 134 | 135 | if __name__ == '__main__': 136 | x = torch.randn(3, 7, 256, 256) 137 | y = torch.ones(3, 256, 256).long() 138 | model = FocalLoss(alpha=0.25, gamma=2) 139 | out = model(x, y) 140 | print(out) 141 | -------------------------------------------------------------------------------- /SCD/model/loss/dice.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.data 5 | from kornia.losses import dice_loss 6 | from torch.autograd import Variable 7 | from typing import Optional, List 8 | from functools import partial 9 | from torch.nn.modules.loss import _Loss 10 | import numpy as np 11 | 12 | class BinaryDICELoss(nn.Module): 13 | 
def __init__(self): 14 | super(BinaryDICELoss, self).__init__() 15 | 16 | def forward(self, input, target): 17 | target = target.squeeze(1) 18 | loss = dice_loss(input, target) 19 | 20 | return loss 21 | 22 | 23 | ### version 2 24 | """ 25 | class BinaryDICELoss(nn.Module): 26 | def __init__(self, eps=1e-5): 27 | super(BinaryDICELoss, self).__init__() 28 | self.eps = eps 29 | 30 | def to_one_hot(self, target): 31 | N, C, H, W = target.size() 32 | assert C == 1 33 | target = torch.zeros(N, 2, H, W).to(target.device).scatter_(1, target, 1) 34 | return target 35 | 36 | def forward(self, input, target): 37 | N, C, _, _ = input.size() 38 | input = F.softmax(input, dim=1) 39 | 40 | #target = self.to_one_hot(target) 41 | target = torch.eye(2)[target.squeeze(1)] 42 | target = target.permute(0, 3, 1, 2).type_as(input) 43 | 44 | dims = tuple(range(1, target.ndimension())) 45 | inter = torch.sum(input * target, dims) 46 | cardinality = torch.sum(input + target, dims) 47 | loss = ((2. * inter) / (cardinality + self.eps)).mean() 48 | 49 | return 1 - loss 50 | """ 51 | 52 | class DICELoss(_Loss): 53 | def __init__( 54 | self, 55 | smooth: float = 0.0, 56 | ignore_index: Optional[int] = None, 57 | eps: float = 1e-7, 58 | ): 59 | super(DICELoss, self).__init__() 60 | self.smooth = smooth 61 | self.eps = eps 62 | self.ignore_index = ignore_index 63 | 64 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 65 | 66 | assert y_true.size(0) == y_pred.size(0) 67 | y_pred = y_pred.log_softmax(dim=1).exp() 68 | 69 | bs = y_true.size(0) 70 | num_classes = y_pred.size(1) 71 | dims = (0, 2) 72 | 73 | y_true = y_true.view(bs, -1) 74 | y_pred = y_pred.view(bs, num_classes, -1) 75 | 76 | if self.ignore_index is not None: 77 | mask = y_true != self.ignore_index 78 | y_pred = y_pred * mask.unsqueeze(1) 79 | 80 | y_true = F.one_hot((y_true * mask).to(torch.long), num_classes) # N,H*W -> N,H*W, C 81 | y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W 82 | else: 83 | y_true = F.one_hot(y_true.to(torch.long), num_classes) # N,H*W -> N,H*W, C 84 | y_true = y_true.permute(0, 2, 1) # N, C, H*W 85 | 86 | scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims) 87 | loss = 1.0 - scores 88 | 89 | mask = y_true.sum(dims) > 0 90 | loss *= mask.to(loss.dtype) 91 | 92 | return self.aggregate_loss(loss) 93 | 94 | def aggregate_loss(self, loss): 95 | return loss.mean() 96 | 97 | def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor: 98 | return soft_dice_score(output, target, smooth, eps, dims) 99 | 100 | 101 | def to_tensor(x, dtype=None) -> torch.Tensor: 102 | if isinstance(x, torch.Tensor): 103 | if dtype is not None: 104 | x = x.type(dtype) 105 | return x 106 | if isinstance(x, np.ndarray): 107 | x = torch.from_numpy(x) 108 | if dtype is not None: 109 | x = x.type(dtype) 110 | return x 111 | if isinstance(x, (list, tuple)): 112 | x = np.array(x) 113 | x = torch.from_numpy(x) 114 | if dtype is not None: 115 | x = x.type(dtype) 116 | return x 117 | 118 | 119 | def soft_dice_score( 120 | output: torch.Tensor, 121 | target: torch.Tensor, 122 | smooth: float = 0.0, 123 | eps: float = 1e-7, 124 | dims=None, 125 | ) -> torch.Tensor: 126 | assert output.size() == target.size() 127 | if dims is not None: 128 | intersection = torch.sum(output * target, dim=dims) 129 | cardinality = torch.sum(output + target, dim=dims) 130 | else: 131 | intersection = torch.sum(output * target) 132 | cardinality = torch.sum(output + 
target) 133 | dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps) 134 | 135 | return dice_score 136 | 137 | 138 | if __name__ == '__main__': 139 | x = torch.randn(3, 7, 256, 256) 140 | y = torch.zeros(3, 256, 256).long() 141 | model = DICELoss() 142 | output = model(x, y) 143 | print(output) 144 | -------------------------------------------------------------------------------- /SCD/model/block/vertical.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .convs import ConvBnRelu, DsBnRelu, PyConv2d 5 | from mmcv.ops import MultiScaleDeformableAttention 6 | from einops import rearrange 7 | from torch import einsum 8 | 9 | 10 | class ScaledSinuEmbedding(nn.Module): 11 | def __init__(self, dim): 12 | super().__init__() 13 | self.scale = nn.Parameter(torch.ones(1,)) 14 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 15 | self.register_buffer('inv_freq', inv_freq) 16 | 17 | def forward(self, x): 18 | n, device = x.shape[1], x.device 19 | t = torch.arange(n, device=device).type_as(self.inv_freq) 20 | sinu = einsum('i, j -> i j', t, self.inv_freq) 21 | emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1) 22 | return emb * self.scale 23 | 24 | 25 | def get_reference_points(spatial_shapes, device): 26 | reference_points_list = [] 27 | for lvl, (H_, W_) in enumerate(spatial_shapes): 28 | ref_y, ref_x = torch.meshgrid( 29 | torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 30 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 31 | ref_y = ref_y.reshape(-1)[None] / H_ 32 | ref_x = ref_x.reshape(-1)[None] / W_ 33 | ref = torch.stack((ref_x, ref_y), -1) 34 | reference_points_list.append(ref) 35 | reference_points = torch.cat(reference_points_list, 1) 36 | reference_points = reference_points[:, :, None] 37 | return reference_points 38 | 39 | 40 | class VerticalFusion(nn.Module): 41 | def __init__(self, channels, num_heads=4, num_points=4, kernel_layers=1, up_kernel_size=5, enc_kernel_size=3): 42 | super(VerticalFusion, self).__init__() 43 | self.norm1 = nn.LayerNorm(channels) 44 | self.norm2 = nn.LayerNorm(channels) 45 | self.pos = ScaledSinuEmbedding(channels) 46 | self.crossattn = MultiScaleDeformableAttention(embed_dims=channels, num_levels=1, num_heads=num_heads, 47 | num_points=num_points, batch_first=True, dropout=0) 48 | convs = [] 49 | convs.append(ConvBnRelu(in_channels=channels, out_channels=channels)) 50 | for _ in range(kernel_layers - 1): 51 | convs.append(ConvBnRelu(in_channels=channels, out_channels=channels)) 52 | self.convs = nn.Sequential(*convs) 53 | self.enc = ConvBnRelu(channels, up_kernel_size ** 2, kernel_size=enc_kernel_size, 54 | stride=1, padding=enc_kernel_size // 2, dilation=1) 55 | 56 | self.upsmp = nn.Upsample(scale_factor=2, mode='nearest') 57 | self.unfold = nn.Unfold(kernel_size=up_kernel_size, dilation=2, 58 | padding=up_kernel_size // 2 * 2) 59 | 60 | def get_deform_inputs(self, x1, x2): 61 | _, _, H1, W1 = x1.size() 62 | _, _, H2, W2 = x2.size() 63 | spatial_shapes = torch.as_tensor([(H2, W2)], dtype=torch.long, device=x2.device) 64 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) 65 | reference_points = get_reference_points([(H1, W1)], x1.device) 66 | 67 | return reference_points, spatial_shapes, level_start_index 68 | 69 | def forward(self, x1, x2): 70 | reference_points, spatial_shapes, 
level_start_index = self.get_deform_inputs(x1, x2) 71 | B, C, H, W = x1.size() 72 | _, _, H2, W2 = x2.size() 73 | x1_, x2_ = x1.clone(), x2.clone() 74 | x1 = rearrange(x1, 'b c h w -> b (h w) c') 75 | x2 = rearrange(x2, 'b c h w -> b (h w) c') 76 | x1, x2 = self.norm1(x1), self.norm2(x2) 77 | query_pos = self.pos(x1) 78 | x = self.crossattn(query=x1, value=x2, reference_points=reference_points, spatial_shapes=spatial_shapes, 79 | level_start_index=level_start_index, query_pos=query_pos) 80 | x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) 81 | kernel = self.convs(x2_) 82 | kernel = self.enc(kernel) 83 | kernel = F.softmax(kernel, dim=1) 84 | # x = self.upsmp(x) 85 | x = F.interpolate(x, size=(H2, W2), mode='nearest') 86 | x = self.unfold(x) 87 | # x = x.view(B, C, -1, H * 2, W * 2) 88 | x = x.view(B, C, -1, H2, W2) 89 | fuse = torch.einsum('bkhw,bckhw->bchw', [kernel, x]) 90 | fuse += x2_ 91 | 92 | return fuse 93 | 94 | 95 | # from thop import profile 96 | # 97 | # 98 | # x1 = torch.randn(2, 128, 32, 32) 99 | # x2 = torch.randn(2, 128, 64, 64) 100 | # model = VerticalFusion(channels=128) 101 | # y = model(x1, x2) 102 | # print(y.shape) 103 | # flops, params = profile(model, inputs=([x1, x2])) 104 | # print('FLOPs = ' + str(flops / 1000 ** 3) + 'G') 105 | # print('Params = ' + str(params / 1000 ** 2) + 'M') 106 | -------------------------------------------------------------------------------- /model/backbone/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.hub import load_state_dict_from_url 3 | 4 | model_urls = { 5 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 6 | } 7 | 8 | 9 | class ConvBNReLU(nn.Sequential): 10 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1): 11 | padding = (kernel_size - 1) // 2 12 | if dilation != 1: 13 | padding = dilation 14 | super(ConvBNReLU, self).__init__( 15 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation, 16 | bias=False), 17 | nn.BatchNorm2d(out_planes), 18 | nn.ReLU6(inplace=True) 19 | ) 20 | 21 | 22 | class InvertedResidual(nn.Module): 23 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1): 24 | super(InvertedResidual, self).__init__() 25 | self.stride = stride 26 | assert stride in [1, 2] 27 | 28 | hidden_dim = int(round(inp * expand_ratio)) 29 | self.use_res_connect = self.stride == 1 and inp == oup 30 | 31 | layers = [] 32 | if expand_ratio != 1: 33 | # pw 34 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 35 | layers.extend([ 36 | # dw 37 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation), 38 | # pw-linear 39 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 40 | nn.BatchNorm2d(oup), 41 | ]) 42 | self.conv = nn.Sequential(*layers) 43 | 44 | def forward(self, x): 45 | if self.use_res_connect: 46 | return x + self.conv(x) 47 | else: 48 | return self.conv(x) 49 | 50 | 51 | class MobileNetV2(nn.Module): 52 | def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0, replace_stride_with_dilation=False): 53 | super(MobileNetV2, self).__init__() 54 | block = InvertedResidual 55 | input_channel = 32 56 | last_channel = 1280 57 | # inverted_residual_setting = [ 58 | # # t, c, n, s, d 59 | # [1, 16, 1, 1, 1], 60 | # [6, 24, 2, 2, 1], 61 | # [6, 32, 3, 2, 1], 62 | # [6, 64, 4, 2, 1], 63 | # [6, 96, 3, 1, 1], 64 | # [6, 160, 3, 2, 1], 65 | # [6, 320, 1, 1, 1], 66 | # ] 67 | 
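# Note on the setting below (comment added for clarity): when replace_stride_with_dilation is True, the two stride-2 stages (the 64- and 160-channel ones) become dilation-2 stages, so the backbone's output stride stays at 8 instead of 32 - the dilated-backbone trick familiar from DeepLab-style segmentation models.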
inverted_residual_setting = [ 68 | # t, c, n, s, d 69 | [1, 16, 1, 1, 1], 70 | [6, 24, 2, 2, 1], 71 | [6, 32, 3, 2, 1], 72 | [6, 64, 4, 1, 2] if replace_stride_with_dilation else [6, 64, 4, 2, 1], 73 | [6, 96, 3, 1, 1], 74 | [6, 160, 3, 1, 2] if replace_stride_with_dilation else [6, 160, 3, 2, 1], 75 | [6, 320, 1, 1, 1], 76 | ] 77 | self.channels = [16, 24, 32, 96, 320] 78 | 79 | # building first layer 80 | input_channel = int(input_channel * width_mult) 81 | self.last_channel = int(last_channel * max(1.0, width_mult)) 82 | features = [ConvBNReLU(3, input_channel, stride=2)] 83 | # building inverted residual blocks 84 | for t, c, n, s, d in inverted_residual_setting: 85 | output_channel = int(c * width_mult) 86 | for i in range(n): 87 | stride = s if i == 0 else 1 88 | dilation = d if i == 0 else 1 89 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d)) 90 | input_channel = output_channel 91 | # building last several layers 92 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 93 | # make it nn.Sequential 94 | self.features = nn.Sequential(*features) 95 | 96 | # weight initialization 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 100 | if m.bias is not None: 101 | nn.init.zeros_(m.bias) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | nn.init.ones_(m.weight) 104 | nn.init.zeros_(m.bias) 105 | elif isinstance(m, nn.Linear): 106 | nn.init.normal_(m.weight, 0, 0.01) 107 | nn.init.zeros_(m.bias) 108 | 109 | def forward(self, x): 110 | res = [] 111 | for idx, m in enumerate(self.features): 112 | x = m(x) 113 | if idx in [1, 3, 6, 13, 17]: 114 | res.append(x) 115 | return res 116 | 117 | 118 | def mobilenet_v2(pretrained=True, progress=True, replace_stride_with_dilation=False, **kwargs): 119 | model = MobileNetV2(replace_stride_with_dilation=replace_stride_with_dilation, **kwargs) 120 | if pretrained: 121 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 122 | progress=progress) 123 | print("loading imagenet pretrained mobilenetv2") 124 | model.load_state_dict(state_dict, strict=False) 125 | print("loaded imagenet pretrained mobilenetv2") 126 | return model 127 | -------------------------------------------------------------------------------- /model/block/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from .dcnv2 import DCNv2 6 | from .convs import GatedConv2d, ContextGatedConv2d 7 | 8 | 9 | class GenerateGamma(nn.Module): 10 | def __init__(self, channels=128, mode='SE'): 11 | super(GenerateGamma, self).__init__() 12 | self.mode = mode 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.max_pool = nn.AdaptiveMaxPool2d(1) 15 | 16 | self.fc = nn.Sequential(nn.Conv2d(channels, channels // 4, 1, bias=False), 17 | nn.ReLU(True), 18 | nn.Conv2d(channels // 4, channels, 1, bias=False)) 19 | self.sigmoid = nn.Sigmoid() 20 | 21 | def forward(self, x): 22 | avg_out = self.fc(self.avg_pool(x)) 23 | if self.mode == 'SE': 24 | return self.sigmoid(avg_out) 25 | elif self.mode == 'CBAM': 26 | max_out = self.fc(self.max_pool(x)) 27 | out = avg_out + max_out 28 | return self.sigmoid(out) 29 | else: 30 | raise NotImplementedError 31 | 32 | 33 | class GenerateBeta(nn.Module): 34 | def __init__(self, channels=128, mode='conv'): 35 | super(GenerateBeta, self).__init__() 36 | self.stem = 
nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1, bias=True), nn.ReLU(True)) 37 | if mode == 'conv': 38 | self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=True) 39 | elif mode == 'gatedconv': 40 | self.conv = GatedConv2d(channels, channels, 3, padding=1, bias=True) 41 | elif mode == 'contextgatedconv': 42 | self.conv = ContextGatedConv2d(channels, channels, 3, padding=1, bias=True) 43 | else: 44 | raise NotImplementedError 45 | 46 | def forward(self, x): 47 | x = self.stem(x) 48 | return self.conv(x) 49 | 50 | 51 | ### MoFPN 52 | class FPN(nn.Module): 53 | def __init__(self, in_channels, out_channels=128, deform_groups=4, gamma_mode='SE', beta_mode='contextgatedconv'): 54 | super(FPN, self).__init__() 55 | 56 | self.p2 = DCNv2(in_channels=in_channels[0], out_channels=out_channels, 57 | kernel_size=3, padding=1, deform_groups=deform_groups) 58 | self.p3 = DCNv2(in_channels=in_channels[1], out_channels=out_channels, 59 | kernel_size=3, padding=1, deform_groups=deform_groups) 60 | self.p4 = DCNv2(in_channels=in_channels[2], out_channels=out_channels, 61 | kernel_size=3, padding=1, deform_groups=deform_groups) 62 | self.p5 = DCNv2(in_channels=in_channels[3], out_channels=out_channels, 63 | kernel_size=3, padding=1, deform_groups=deform_groups) 64 | 65 | self.p5_bn = nn.BatchNorm2d(out_channels, affine=True) 66 | self.p4_bn = nn.BatchNorm2d(out_channels, affine=False) 67 | self.p3_bn = nn.BatchNorm2d(out_channels, affine=False) 68 | self.p2_bn = nn.BatchNorm2d(out_channels, affine=False) 69 | self.activation = nn.ReLU(True) 70 | 71 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 72 | self.p4_Gamma = GenerateGamma(out_channels, mode=gamma_mode) 73 | self.p4_beta = GenerateBeta(out_channels, mode=beta_mode) 74 | self.p3_Gamma = GenerateGamma(out_channels, mode=gamma_mode) 75 | self.p3_beta = GenerateBeta(out_channels, mode=beta_mode) 76 | self.p2_Gamma = GenerateGamma(out_channels, mode=gamma_mode) 77 | self.p2_beta = GenerateBeta(out_channels, mode=beta_mode) 78 | 79 | self.p5_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 80 | self.p4_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 81 | self.p3_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 82 | self.p2_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 83 | def forward(self, input): 84 | c2, c3, c4, c5 = input 85 | 86 | p5 = self.activation(self.p5_bn(self.p5(c5))) 87 | p5_up = F.interpolate(p5, size=c4.shape[-2:], mode='bilinear', align_corners=False) 88 | p4 = self.p4_bn(self.p4(c4)) 89 | p4_gamma, p4_beta = self.p4_Gamma(p5_up), self.p4_beta(p5_up) 90 | p4 = self.activation(p4 * (1 + p4_gamma) + p4_beta) 91 | p4_up = F.interpolate(p4, size=c3.shape[-2:], mode='bilinear', align_corners=False) 92 | p3 = self.p3_bn(self.p3(c3)) 93 | p3_gamma, p3_beta = self.p3_Gamma(p4_up), self.p3_beta(p4_up) 94 | p3 = self.activation(p3 * (1 + p3_gamma) + p3_beta) 95 | p3_up = F.interpolate(p3, size=c2.shape[-2:], mode='bilinear', align_corners=False) 96 | p2 = self.p2_bn(self.p2(c2)) 97 | p2_gamma, p2_beta = self.p2_Gamma(p3_up), self.p2_beta(p3_up) 98 | p2 = self.activation(p2 * (1 + p2_gamma) + p2_beta) 99 | 100 | p5 = self.p5_smooth(p5) 101 | p4 = self.p4_smooth(p4) 102 | p3 = self.p3_smooth(p3) 103 | p2 = self.p2_smooth(p2) 104 | 105 | return p2, p3, p4, p5 106 | 107 | 108 | -------------------------------------------------------------------------------- 
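# A quick sanity check for the FPN above (a hedged sketch, mirroring the commented examples used elsewhere in this repo; the channel list matches the mobilenetv2 configuration that network.py passes in):
# import torch
# fpn = FPN(in_channels=[24, 32, 96, 320], out_channels=128)
# feats = [torch.randn(2, c, s, s) for c, s in [(24, 64), (32, 32), (96, 16), (320, 8)]]
# p2, p3, p4, p5 = fpn(feats)
# for p in (p2, p3, p4, p5):
#     print(p.shape)  # each has 128 channels, at strides 4/8/16/32 of a 256x256 input
--------------------------------------------------------------------------------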
/SCD/model/backbone/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.hub import load_state_dict_from_url 3 | 4 | model_urls = { 5 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 6 | } 7 | 8 | 9 | class ConvBNReLU(nn.Sequential): 10 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1): 11 | padding = (kernel_size - 1) // 2 12 | if dilation != 1: 13 | padding = dilation 14 | super(ConvBNReLU, self).__init__( 15 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation, 16 | bias=False), 17 | nn.BatchNorm2d(out_planes), 18 | nn.ReLU6(inplace=True) 19 | ) 20 | 21 | 22 | class InvertedResidual(nn.Module): 23 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1): 24 | super(InvertedResidual, self).__init__() 25 | self.stride = stride 26 | assert stride in [1, 2] 27 | 28 | hidden_dim = int(round(inp * expand_ratio)) 29 | self.use_res_connect = self.stride == 1 and inp == oup 30 | 31 | layers = [] 32 | if expand_ratio != 1: 33 | # pw 34 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 35 | layers.extend([ 36 | # dw 37 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation), 38 | # pw-linear 39 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 40 | nn.BatchNorm2d(oup), 41 | ]) 42 | self.conv = nn.Sequential(*layers) 43 | 44 | def forward(self, x): 45 | if self.use_res_connect: 46 | return x + self.conv(x) 47 | else: 48 | return self.conv(x) 49 | 50 | 51 | class MobileNetV2(nn.Module): 52 | def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0, replace_stride_with_dilation=False): 53 | super(MobileNetV2, self).__init__() 54 | block = InvertedResidual 55 | input_channel = 32 56 | last_channel = 1280 57 | # inverted_residual_setting = [ 58 | # # t, c, n, s, d 59 | # [1, 16, 1, 1, 1], 60 | # [6, 24, 2, 2, 1], 61 | # [6, 32, 3, 2, 1], 62 | # [6, 64, 4, 2, 1], 63 | # [6, 96, 3, 1, 1], 64 | # [6, 160, 3, 2, 1], 65 | # [6, 320, 1, 1, 1], 66 | # ] 67 | inverted_residual_setting = [ 68 | # t, c, n, s, d 69 | [1, 16, 1, 1, 1], 70 | [6, 24, 2, 2, 1], 71 | [6, 32, 3, 2, 1], 72 | [6, 64, 4, 1, 2] if replace_stride_with_dilation else [6, 64, 4, 2, 1], 73 | [6, 96, 3, 1, 1], 74 | [6, 160, 3, 1, 2] if replace_stride_with_dilation else [6, 160, 3, 2, 1], 75 | [6, 320, 1, 1, 1], 76 | ] 77 | self.channels = [16, 24, 32, 96, 320] 78 | 79 | # building first layer 80 | input_channel = int(input_channel * width_mult) 81 | self.last_channel = int(last_channel * max(1.0, width_mult)) 82 | features = [ConvBNReLU(3, input_channel, stride=2)] 83 | # building inverted residual blocks 84 | for t, c, n, s, d in inverted_residual_setting: 85 | output_channel = int(c * width_mult) 86 | for i in range(n): 87 | stride = s if i == 0 else 1 88 | dilation = d if i == 0 else 1 89 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d)) 90 | input_channel = output_channel 91 | # building last several layers 92 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 93 | # make it nn.Sequential 94 | self.features = nn.Sequential(*features) 95 | 96 | # weight initialization 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 100 | if m.bias is not None: 101 | nn.init.zeros_(m.bias) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | 
nn.init.ones_(m.weight) 104 | nn.init.zeros_(m.bias) 105 | elif isinstance(m, nn.Linear): 106 | nn.init.normal_(m.weight, 0, 0.01) 107 | nn.init.zeros_(m.bias) 108 | 109 | def forward(self, x): 110 | res = [] 111 | for idx, m in enumerate(self.features): 112 | x = m(x) 113 | if idx in [1, 3, 6, 13, 17]: 114 | res.append(x) 115 | return res 116 | 117 | 118 | def mobilenet_v2(pretrained=True, progress=True, replace_stride_with_dilation=False, **kwargs): 119 | model = MobileNetV2(replace_stride_with_dilation=replace_stride_with_dilation, **kwargs) 120 | if pretrained: 121 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 122 | progress=progress) 123 | print("loading imagenet pretrained mobilenetv2") 124 | model.load_state_dict(state_dict, strict=False) 125 | print("loaded imagenet pretrained mobilenetv2") 126 | return model 127 | 128 | # import torch 129 | # model = MobileNetV2(pretrained=True) 130 | # input = torch.randn(2, 3, 256, 256) 131 | # output = model(input) 132 | # for i in output: 133 | # print(i.shape) -------------------------------------------------------------------------------- /SCD/model/block/heads.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/swz30/MIRNet 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from .convs import GatedConv2d, ContextGatedConv2d 8 | 9 | 10 | class ResidualUpSample(nn.Module): 11 | def __init__(self, in_channels): 12 | super(ResidualUpSample, self).__init__() 13 | 14 | self.top = nn.Sequential(nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1, output_padding=1, 15 | bias=False), 16 | nn.BatchNorm2d(in_channels), 17 | nn.ReLU(True), 18 | nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=True) 19 | ) 20 | self.bot = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 21 | nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=True) 22 | ) 23 | self.relu = nn.ReLU(True) 24 | 25 | def forward(self, x): 26 | top = self.top(x) 27 | bot = self.bot(x) 28 | out = self.relu(top + bot) 29 | return out 30 | 31 | 32 | class GatedResidualUp(nn.Module): 33 | def __init__(self, in_channels, up_mode='conv', gate_mode='gated'): 34 | super(GatedResidualUp, self).__init__() 35 | if up_mode == 'conv': 36 | self.residual_up = nn.Sequential(nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1, 37 | output_padding=1, bias=False), 38 | nn.BatchNorm2d(in_channels), 39 | nn.ReLU(True)) 40 | elif up_mode == 'bilinear': 41 | self.residual_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 42 | 43 | if gate_mode == 'gated': 44 | self.gate = GatedConv2d(in_channels, in_channels // 2) 45 | elif gate_mode == 'context_gated': 46 | self.gate = ContextGatedConv2d(in_channels, in_channels // 2) 47 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 48 | nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=True) 49 | ) 50 | self.relu = nn.ReLU(True) 51 | 52 | def forward(self, x): 53 | residual = self.residual_up(x) 54 | residual = self.gate(residual) 55 | up = self.up(x) 56 | out = self.relu(up + residual) 57 | return out 58 | 59 | 60 | class GatedResidualUpHead(nn.Module): 61 | def __init__(self, in_channels=128, num_classes=1, dropout_rate=0.15): 62 | super(GatedResidualUpHead, self).__init__() 63 | 64 | self.up = nn.Sequential(GatedResidualUp(in_channels), 65 | 
GatedResidualUp(in_channels // 2)) 66 | self.smooth = nn.Sequential(nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, stride=1, padding=1), 67 | nn.ReLU(True), 68 | nn.Dropout2d(dropout_rate)) 69 | self.final = nn.Conv2d(in_channels // 4, num_classes, 1) 70 | 71 | def forward(self, x): 72 | x = self.up(x) 73 | x = self.smooth(x) 74 | x = self.final(x) 75 | 76 | return x 77 | 78 | 79 | class ResidualUpHead(nn.Module): 80 | def __init__(self, in_channels=128, num_classes=1, dropout_rate=0.15): 81 | super(ResidualUpHead, self).__init__() 82 | 83 | self.up = nn.Sequential(ResidualUpSample(in_channels), 84 | ResidualUpSample(in_channels // 2)) 85 | self.smooth = nn.Sequential(nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, stride=1, padding=1), 86 | nn.ReLU(True), 87 | nn.Dropout2d(dropout_rate)) 88 | self.final = nn.Conv2d(in_channels // 4, num_classes, 1) 89 | 90 | def forward(self, x): 91 | x = self.up(x) 92 | x = self.smooth(x) 93 | x = self.final(x) 94 | 95 | return x 96 | 97 | 98 | class FCNHead(nn.Module): 99 | def __init__(self, in_channels, num_classes, num_convs=1, dropout_rate=0.15): 100 | self.num_convs = num_convs 101 | super(FCNHead, self).__init__() 102 | inter_channels = in_channels // 4 103 | 104 | convs = [] 105 | convs.append(nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 106 | nn.BatchNorm2d(inter_channels), 107 | nn.ReLU(True))) 108 | for i in range(num_convs - 1): 109 | convs.append(nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 110 | nn.BatchNorm2d(inter_channels), 111 | nn.ReLU(True))) 112 | self.convs = nn.Sequential(*convs) 113 | self.final = nn.Conv2d(inter_channels, num_classes, 1) 114 | 115 | def forward(self, x): 116 | out = self.convs(x) 117 | out = self.final(out) 118 | 119 | return out 120 | 121 | -------------------------------------------------------------------------------- /SCD/trainval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from option import Options 3 | from data.cd_dataset import DataLoader 4 | from model.create_model import create_model 5 | from tqdm import tqdm 6 | import math 7 | from util.palette import color_map 8 | from util.metric import IOUandSek 9 | import os 10 | import numpy as np 11 | import random 12 | from PIL import Image 13 | 14 | 15 | def setup_seed(seed): 16 | os.environ['PYTHONHASHSEED'] = str(seed) 17 | np.random.seed(seed) 18 | random.seed(seed) 19 | 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | torch.backends.cudnn.deterministic = False 24 | torch.backends.cudnn.benchmark = True 25 | torch.backends.cudnn.enabled = True 26 | 27 | 28 | class Trainval(object): 29 | def __init__(self, opt): 30 | self.opt = opt 31 | 32 | train_loader = DataLoader(opt) 33 | self.train_data = train_loader.load_data() 34 | train_size = len(train_loader) 35 | print("#training images = %d" % train_size) 36 | opt.phase = 'val' 37 | val_loader = DataLoader(opt) 38 | self.val_data = val_loader.load_data() 39 | val_size = len(val_loader) 40 | print("#validation images = %d" % val_size) 41 | opt.phase = 'train' 42 | 43 | self.model = create_model(opt) 44 | self.optimizer = self.model.optimizer 45 | self.schedular = self.model.schedular 46 | 47 | self.iters = 0 48 | self.total_iters = math.ceil(train_size / opt.batch_size) * opt.num_epochs 49 | self.previous_best = 0.0 50 | 51 | def train(self): 52 | tbar = tqdm(self.train_data) 53 | opt.phase = 
'train' 54 | _loss = 0.0 55 | _cd_loss = 0.0 56 | _seg_loss = 0.0 57 | 58 | for i, data in enumerate(tbar): 59 | self.model.detector.train() 60 | cd_loss, seg_loss = self.model(data['img1'].cuda(), data['img2'].cuda(), data['label1'].cuda(), 61 | data['label2'].cuda(), data['cd_label'].cuda()) 62 | loss = 2 * cd_loss + seg_loss 63 | self.optimizer.zero_grad() 64 | loss.backward() 65 | self.optimizer.step() 66 | self.schedular.step() 67 | _loss += loss.item() 68 | _cd_loss += cd_loss.item() 69 | _seg_loss += seg_loss.item() 70 | del loss 71 | 72 | #self.iters += 1 73 | #self.model.adjust_learning_rate(self.iters, self.total_iters) 74 | 75 | tbar.set_description("Loss: %.3f, CD: %.3f, Seg: %.3f, LR: %.6f" % 76 | (_loss / (i + 1), _cd_loss / (i + 1), _seg_loss / (i + 1), self.optimizer.param_groups[0]['lr'])) 77 | 78 | def val(self): 79 | tbar = tqdm(self.val_data) 80 | metric = IOUandSek(num_classes=7) 81 | opt.phase = 'val' 82 | self.model.eval() 83 | 84 | with torch.no_grad(): 85 | for i, _data in enumerate(tbar): 86 | cd_out, seg_out1, seg_out2 = self.model.inference(_data['img1'].cuda(), _data['img2'].cuda()) 87 | # update metric 88 | val_target = _data['cd_label'].detach() 89 | cd_out = torch.argmax(cd_out.detach(), dim=1) 90 | #val_pred = torch.where(val_pred > 0.5, torch.ones_like(val_pred), torch.zeros_like(val_pred)).long() 91 | seg_out1 = torch.argmax(seg_out1, dim=1).cpu().numpy() 92 | seg_out2 = torch.argmax(seg_out2, dim=1).cpu().numpy() 93 | cd_out = cd_out.cpu().numpy().astype(np.uint8) 94 | seg_out1[cd_out == 0] = 0 95 | seg_out2[cd_out == 0] = 0 96 | 97 | if self.opt.save_mask: 98 | cmap = color_map(self.opt.dataset) 99 | for j in range(seg_out1.shape[0]):  # j rather than i, to avoid shadowing the batch index of the enclosing loop 100 | mask = Image.fromarray(seg_out1[j].astype(np.uint8), mode="P") 101 | mask.putpalette(cmap) 102 | os.makedirs(os.path.join(self.opt.result_dir, 'val', 'im1'), exist_ok=True) 103 | mask.save(os.path.join(self.opt.result_dir, 'val', 'im1', _data['fname'][j])) 104 | 105 | mask = Image.fromarray(seg_out2[j].astype(np.uint8), mode="P") 106 | mask.putpalette(cmap) 107 | os.makedirs(os.path.join(self.opt.result_dir, 'val', 'im2'), exist_ok=True) 108 | mask.save(os.path.join(self.opt.result_dir, 'val', 'im2', _data['fname'][j])) 109 | 110 | metric.add_batch(seg_out1, _data['label1'].numpy()) 111 | metric.add_batch(seg_out2, _data['label2'].numpy()) 112 | score, miou, sek = metric.evaluate() 113 | tbar.set_description("Score: %.2f, IOU: %.2f, SeK: %.2f" % (score * 100.0, miou * 100.0, sek * 100.0)) 114 | 115 | if score >= self.previous_best: 116 | self.model.save(self.opt.name, self.opt.backbone) 117 | self.previous_best = score 118 | 119 | 120 | if __name__ == "__main__": 121 | opt = Options().parse() 122 | setup_seed(seed=1)  # seed before the datasets and model are built so runs are reproducible 123 | trainval = Trainval(opt) 124 | 125 | for epoch in range(1, opt.num_epochs + 1): 126 | print("\n==> Epoch %i, previous best = %.3f" % (epoch, trainval.previous_best * 100)) 127 | trainval.train() 128 | trainval.val() 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /SCD/model/block/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from .dcnv2 import DCNv2 6 | from .convs import GatedConv2d, ContextGatedConv2d 7 | 8 | 9 | class GenerateGamma(nn.Module): 10 | def __init__(self, channels=128, mode='SE'): 11 | super(GenerateGamma, self).__init__() 12 | self.mode = mode 13 | self.avg_pool = 
nn.AdaptiveAvgPool2d(1) 14 | self.max_pool = nn.AdaptiveMaxPool2d(1) 15 | 16 | self.fc = nn.Sequential(nn.Conv2d(channels, channels // 4, 1, bias=False), 17 | nn.ReLU(True), 18 | nn.Conv2d(channels // 4, channels, 1, bias=False)) 19 | self.sigmoid = nn.Sigmoid() 20 | 21 | def forward(self, x): 22 | avg_out = self.fc(self.avg_pool(x)) 23 | if self.mode == 'SE': 24 | return self.sigmoid(avg_out) 25 | elif self.mode == 'CBAM': 26 | max_out = self.fc(self.max_pool(x)) 27 | out = avg_out + max_out 28 | return self.sigmoid(out) 29 | else: 30 | raise NotImplementedError 31 | 32 | 33 | class GenerateBeta(nn.Module): 34 | def __init__(self, channels=128, mode='conv'): 35 | super(GenerateBeta, self).__init__() 36 | self.stem = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1, bias=True), nn.ReLU(True)) 37 | if mode == 'conv': 38 | self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=True) 39 | elif mode == 'gatedconv': 40 | self.conv = GatedConv2d(channels, channels, 3, padding=1, bias=True) 41 | elif mode == 'contextgatedconv': 42 | self.conv = ContextGatedConv2d(channels, channels, 3, padding=1, bias=True) 43 | else: 44 | raise NotImplementedError 45 | 46 | def forward(self, x): 47 | x = self.stem(x) 48 | return self.conv(x) 49 | 50 | 51 | class FPN(nn.Module): 52 | def __init__(self, in_channels, out_channels=128, deform_groups=4, gamma_mode='SE', beta_mode='gatedconv'): 53 | super(FPN, self).__init__() 54 | 55 | self.p2 = DCNv2(in_channels=in_channels[0], out_channels=out_channels, 56 | kernel_size=3, padding=1, deform_groups=deform_groups) 57 | self.p3 = DCNv2(in_channels=in_channels[1], out_channels=out_channels, 58 | kernel_size=3, padding=1, deform_groups=deform_groups) 59 | self.p4 = DCNv2(in_channels=in_channels[2], out_channels=out_channels, 60 | kernel_size=3, padding=1, deform_groups=deform_groups) 61 | self.p5 = DCNv2(in_channels=in_channels[3], out_channels=out_channels, 62 | kernel_size=3, padding=1, deform_groups=deform_groups) 63 | 64 | self.p5_bn = nn.BatchNorm2d(out_channels, affine=True) 65 | self.p4_bn = nn.BatchNorm2d(out_channels, affine=False) 66 | self.p3_bn = nn.BatchNorm2d(out_channels, affine=False) 67 | self.p2_bn = nn.BatchNorm2d(out_channels, affine=False) 68 | self.activation = nn.ReLU(True) 69 | 70 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 71 | self.p4_Gamma = GenerateGamma(out_channels, mode=gamma_mode) 72 | self.p4_beta = GenerateBeta(out_channels, mode=beta_mode) 73 | self.p3_Gamma = GenerateGamma(out_channels, mode=gamma_mode) 74 | self.p3_beta = GenerateBeta(out_channels, mode=beta_mode) 75 | self.p2_Gamma = GenerateGamma(out_channels, mode=gamma_mode) 76 | self.p2_beta = GenerateBeta(out_channels, mode=beta_mode) 77 | 78 | self.p5_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 79 | self.p4_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 80 | self.p3_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 81 | self.p2_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 82 | def forward(self, input): 83 | c2, c3, c4, c5 = input 84 | 85 | p5 = self.activation(self.p5_bn(self.p5(c5))) 86 | p5_up = F.interpolate(p5, size=c4.shape[-2:], mode='bilinear', align_corners=False) 87 | p4 = self.p4_bn(self.p4(c4)) 88 | p4_gamma, p4_beta = self.p4_Gamma(p5_up), self.p4_beta(p5_up) 89 | p4 = self.activation(p4 * (1 + p4_gamma) + p4_beta) 90 | p4_up = F.interpolate(p4, size=c3.shape[-2:], 
mode='bilinear', align_corners=False) 91 | p3 = self.p3_bn(self.p3(c3)) 92 | p3_gamma, p3_beta = self.p3_Gamma(p4_up), self.p3_beta(p4_up) 93 | p3 = self.activation(p3 * (1 + p3_gamma) + p3_beta) 94 | p3_up = F.interpolate(p3, size=c2.shape[-2:], mode='bilinear', align_corners=False) 95 | p2 = self.p2_bn(self.p2(c2)) 96 | p2_gamma, p2_beta = self.p2_Gamma(p3_up), self.p2_beta(p3_up) 97 | p2 = self.activation(p2 * (1 + p2_gamma) + p2_beta) 98 | 99 | p5 = self.p5_smooth(p5) 100 | p4 = self.p4_smooth(p4) 101 | p3 = self.p3_smooth(p3) 102 | p2 = self.p2_smooth(p2) 103 | 104 | return p2, p3, p4, p5 105 | 106 | 107 | # from thop import profile 108 | # 109 | # 110 | # x1 = torch.randn(2, 512, 8, 8) 111 | # x2 = torch.randn(2, 256, 16, 16) 112 | # x3 = torch.randn(2, 128, 32, 32) 113 | # x4 = torch.randn(2, 64, 64, 64) 114 | # model = FPN(in_channels=[64, 128, 256, 512]) 115 | # y = model([x4, x3, x2, x1]) 116 | # for i in y: 117 | # print(i.shape) 118 | # flops, params = profile(model, inputs=([x4, x3, x2, x1],)) 119 | # print('FLOPs = ' + str(flops / 1000 ** 3) + 'G') 120 | # print('Params = ' + str(params / 1000 ** 2) + 'M') 121 | -------------------------------------------------------------------------------- /SCD/model/create_model.py: -------------------------------------------------------------------------------- 1 | from .network import Detector 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | import os 7 | import torch.optim as optim 8 | from .block.schedular import get_cosine_schedule_with_warmup 9 | from .loss.bcedice import BCEDiceLoss 10 | from .loss.focal import BinaryFocalLoss, FocalLoss 11 | from .loss.dice import BinaryDICELoss, DICELoss 12 | from thop import profile 13 | 14 | 15 | def get_model(backbone_name='mobilenetv2', fpn_name='neighbor', fpn_channels=64, deform_groups=4, 16 | gamma_mode='SE', beta_mode='gatedconv', num_heads=1, num_points=8, kernel_layers=1, 17 | dropout_rate=0.1, init_type='kaiming_normal'): 18 | detector = Detector(backbone_name, fpn_name, fpn_channels, deform_groups, gamma_mode, beta_mode, 19 | num_heads, num_points, kernel_layers, dropout_rate, init_type) 20 | print(detector) 21 | input1 = torch.randn(1, 3, 256, 256) 22 | input2 = torch.randn(1, 3, 256, 256) 23 | flops, params = profile(detector, inputs=(input1, input2)) 24 | print('FLOPs = ' + str(flops / 1000 ** 3) + 'G') 25 | print('Params = ' + str(params / 1000 ** 2) + 'M') 26 | 27 | return detector 28 | 29 | 30 | class Model(nn.Module): 31 | def __init__(self, opt): 32 | super(Model, self).__init__() 33 | self.device = torch.device("cuda:%s" % opt.gpu_ids[0] if torch.cuda.is_available() else "cpu") 34 | self.opt = opt 35 | self.base_lr = opt.lr 36 | self.save_dir = os.path.join(opt.checkpoint_dir, opt.name) 37 | os.makedirs(self.save_dir, exist_ok=True) 38 | 39 | self.detector = get_model(backbone_name=opt.backbone, fpn_name=opt.fpn, fpn_channels=opt.fpn_channels, 40 | deform_groups=opt.deform_groups, gamma_mode=opt.gamma_mode, beta_mode=opt.beta_mode, 41 | num_heads=opt.num_heads, num_points=opt.num_points, kernel_layers=opt.kernel_layers, 42 | dropout_rate=opt.dropout_rate, init_type=opt.init_type) 43 | self.cd_focal = BinaryFocalLoss(alpha=opt.alpha, gamma=opt.gamma) 44 | self.cd_dice = BinaryDICELoss() 45 | self.scd_focal = FocalLoss(ignore_index=0, alpha=opt.alpha, gamma=opt.gamma) 46 | self.scd_dice = DICELoss(ignore_index=0) 47 | 48 | self.optimizer = optim.AdamW(self.detector.parameters(), lr=opt.lr, 
weight_decay=opt.weight_decay) 49 | self.schedular = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=445 * opt.warmup_epochs, 50 | num_training_steps=445 * opt.num_epochs) 51 | 52 | if opt.load_pretrain: 53 | self.load_ckpt(self.detector, self.optimizer, opt.name, opt.backbone) 54 | self.detector.cuda() 55 | # self.detector = nn.DataParallel(self.detector, device_ids=opt.gpu_ids) 56 | print("---------- Networks initialized -------------") 57 | 58 | def forward(self, x1, x2, label1, label2, cd_label): 59 | pred, pred_seg1, pred_seg2, pred_p2, pred_p3, pred_p4, pred_p5 = self.detector(x1, x2) 60 | # label = label.unsqueeze(1).float() 61 | cd_label = cd_label.long() 62 | cd_focal = self.cd_focal(pred, cd_label) 63 | cd_dice = self.cd_dice(pred, cd_label) 64 | p2_loss = self.cd_focal(pred_p2, cd_label) * 0.5 + self.cd_dice(pred_p2, cd_label) 65 | p3_loss = self.cd_focal(pred_p3, cd_label) * 0.5 + self.cd_dice(pred_p3, cd_label) 66 | p4_loss = self.cd_focal(pred_p4, cd_label) * 0.5 + self.cd_dice(pred_p4, cd_label) 67 | p5_loss = self.cd_focal(pred_p5, cd_label) * 0.5 + self.cd_dice(pred_p5, cd_label) 68 | cd_loss = cd_focal * 0.5 + cd_dice + p3_loss + p4_loss + p5_loss 69 | seg_loss = self.scd_focal(pred_seg1, label1) * 0.5 + self.scd_dice(pred_seg1, label1) \ 70 | + self.scd_focal(pred_seg2, label2) * 0.5 + self.scd_dice(pred_seg2, label2) 71 | 72 | return cd_loss, seg_loss 73 | 74 | def inference(self, x1, x2): 75 | with torch.no_grad(): 76 | pred, pred_seg1, pred_seg2, _, _, _, _ = self.detector(x1, x2) 77 | return pred, pred_seg1, pred_seg2 78 | 79 | def adjust_learning_rate(self, iter, total_iters, min_lr=1e-6, power=0.9): 80 | lr = (self.base_lr - min_lr) * (1 - iter / total_iters) ** power + min_lr 81 | for param_group in self.optimizer.param_groups: 82 | param_group['lr'] = lr 83 | 84 | def load_ckpt(self, network, optimizer, name, backbone): 85 | save_filename = '%s_%s_best.pth' % (name, backbone) 86 | save_path = os.path.join(self.save_dir, save_filename) 87 | if not os.path.isfile(save_path): 88 | print("%s does not exist yet!" % save_path) 89 | raise FileNotFoundError("%s must exist!" 
% save_filename) 90 | else: 91 | checkpoint = torch.load(save_path, map_location=self.device) 92 | network.load_state_dict(checkpoint['network'], False) 93 | 94 | def save_ckpt(self, network, optimizer, model_name, backbone): 95 | save_filename = '%s_%s_best.pth' % (model_name, backbone) 96 | save_path = os.path.join(self.save_dir, save_filename) 97 | if os.path.exists(save_path): 98 | os.remove(save_path) 99 | torch.save({'network': network.cpu().state_dict(), 100 | 'optimizer': optimizer.state_dict()}, 101 | save_path) 102 | if torch.cuda.is_available(): 103 | network.cuda() 104 | 105 | def save(self, model_name, backbone): 106 | self.save_ckpt(self.detector, self.optimizer, model_name, backbone) 107 | 108 | def name(self): 109 | return self.opt.name 110 | 111 | 112 | def create_model(opt): 113 | model = Model(opt) 114 | print("model [%s] was created" % model.name()) 115 | 116 | return model.cuda() 117 | 118 | -------------------------------------------------------------------------------- /util/metric_tool.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied and modified from 3 | https://github.com/justchenhao/BIT_CD 4 | """ 5 | import numpy as np 6 | 7 | 8 | ################### metrics ################### 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | def __init__(self): 12 | self.initialized = False 13 | self.val = None 14 | self.avg = None 15 | self.sum = None 16 | self.count = None 17 | 18 | def initialize(self, val, weight): 19 | self.val = val 20 | self.avg = val 21 | self.sum = val * weight 22 | self.count = weight 23 | self.initialized = True 24 | 25 | def update(self, val, weight=1): 26 | if not self.initialized: 27 | self.initialize(val, weight) 28 | else: 29 | self.add(val, weight) 30 | 31 | def add(self, val, weight): 32 | self.val = val 33 | self.sum += val * weight 34 | self.count += weight 35 | self.avg = self.sum / self.count 36 | 37 | def value(self): 38 | return self.val 39 | 40 | def average(self): 41 | return self.avg 42 | 43 | def get_scores(self): 44 | scores_dict = cm2score(self.sum) 45 | return scores_dict 46 | 47 | def clear(self): 48 | self.initialized = False 49 | 50 | 51 | ################### cm metrics ################### 52 | class ConfuseMatrixMeter(AverageMeter): 53 | """Computes and stores the average and current value""" 54 | def __init__(self, n_class): 55 | super(ConfuseMatrixMeter, self).__init__() 56 | self.n_class = n_class 57 | 58 | def update_cm(self, pr, gt, weight=1): 59 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr) 60 | self.update(val, weight) 61 | current_score = cm2F1(val) 62 | return current_score 63 | 64 | def get_scores(self): 65 | scores_dict = cm2score(self.sum) 66 | return scores_dict 67 | 68 | 69 | 70 | def harmonic_mean(xs): 71 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs) 72 | return harmonic_mean 73 | 74 | 75 | def cm2F1(confusion_matrix): 76 | hist = confusion_matrix 77 | n_class = hist.shape[0] 78 | tp = np.diag(hist) 79 | sum_a1 = hist.sum(axis=1) 80 | sum_a0 = hist.sum(axis=0) 81 | # ---------------------------------------------------------------------- # 82 | # 1. 
Accuracy & Class Accuracy 83 | # ---------------------------------------------------------------------- # 84 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 85 | 86 | # recall 87 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 88 | # acc_cls = np.nanmean(recall) 89 | 90 | # precision 91 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 92 | 93 | # F1 score 94 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps) 95 | mean_F1 = np.nanmean(F1) 96 | return mean_F1 97 | 98 | 99 | def cm2score(confusion_matrix): 100 | hist = confusion_matrix 101 | n_class = hist.shape[0] 102 | tp = np.diag(hist) 103 | sum_a1 = hist.sum(axis=1) 104 | sum_a0 = hist.sum(axis=0) 105 | # ---------------------------------------------------------------------- # 106 | # 1. Accuracy & Class Accuracy 107 | # ---------------------------------------------------------------------- # 108 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 109 | 110 | # recall 111 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 112 | # acc_cls = np.nanmean(recall) 113 | 114 | # precision 115 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 116 | 117 | # F1 score 118 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps) 119 | mean_F1 = np.nanmean(F1) 120 | # ---------------------------------------------------------------------- # 121 | # 2. Frequency weighted Accuracy & Mean IoU 122 | # ---------------------------------------------------------------------- # 123 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps) 124 | mean_iu = np.nanmean(iu) 125 | 126 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps) 127 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 128 | 129 | # 130 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu)) 131 | 132 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision)) 133 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall)) 134 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1)) 135 | 136 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1} 137 | score_dict.update(cls_iou) 138 | score_dict.update(cls_F1) 139 | score_dict.update(cls_precision) 140 | score_dict.update(cls_recall) 141 | return score_dict 142 | 143 | 144 | def get_confuse_matrix(num_classes, label_gts, label_preds): 145 | def __fast_hist(label_gt, label_pred): 146 | """ 147 | Collect values for Confusion Matrix 148 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix 149 | :param label_gt: ground-truth 150 | :param label_pred: prediction 151 | :return: values for confusion matrix 152 | """ 153 | mask = (label_gt >= 0) & (label_gt < num_classes) 154 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask], 155 | minlength=num_classes**2).reshape(num_classes, num_classes) 156 | return hist 157 | confusion_matrix = np.zeros((num_classes, num_classes)) 158 | for lt, lp in zip(label_gts, label_preds): 159 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten()) 160 | return confusion_matrix 161 | 162 | 163 | def get_mIoU(num_classes, label_gts, label_preds): 164 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds) 165 | score_dict = cm2score(confusion_matrix) 166 | return score_dict['miou'] 167 | -------------------------------------------------------------------------------- /SCD/util/metric_tool.py: -------------------------------------------------------------------------------- 1 | """ 2 
| Copied and modified from 3 | https://github.com/justchenhao/BIT_CD 4 | """ 5 | import numpy as np 6 | 7 | 8 | ################### metrics ################### 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | def __init__(self): 12 | self.initialized = False 13 | self.val = None 14 | self.avg = None 15 | self.sum = None 16 | self.count = None 17 | 18 | def initialize(self, val, weight): 19 | self.val = val 20 | self.avg = val 21 | self.sum = val * weight 22 | self.count = weight 23 | self.initialized = True 24 | 25 | def update(self, val, weight=1): 26 | if not self.initialized: 27 | self.initialize(val, weight) 28 | else: 29 | self.add(val, weight) 30 | 31 | def add(self, val, weight): 32 | self.val = val 33 | self.sum += val * weight 34 | self.count += weight 35 | self.avg = self.sum / self.count 36 | 37 | def value(self): 38 | return self.val 39 | 40 | def average(self): 41 | return self.avg 42 | 43 | def get_scores(self): 44 | scores_dict = cm2score(self.sum) 45 | return scores_dict 46 | 47 | def clear(self): 48 | self.initialized = False 49 | 50 | 51 | ################### cm metrics ################### 52 | class ConfuseMatrixMeter(AverageMeter): 53 | """Computes and stores the average and current value""" 54 | def __init__(self, n_class): 55 | super(ConfuseMatrixMeter, self).__init__() 56 | self.n_class = n_class 57 | 58 | def update_cm(self, pr, gt, weight=1): 59 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr) 60 | self.update(val, weight) 61 | current_score = cm2F1(val) 62 | return current_score 63 | 64 | def get_scores(self): 65 | scores_dict = cm2score(self.sum) 66 | return scores_dict 67 | 68 | 69 | 70 | def harmonic_mean(xs): 71 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs) 72 | return harmonic_mean 73 | 74 | 75 | def cm2F1(confusion_matrix): 76 | hist = confusion_matrix 77 | n_class = hist.shape[0] 78 | tp = np.diag(hist) 79 | sum_a1 = hist.sum(axis=1) 80 | sum_a0 = hist.sum(axis=0) 81 | # ---------------------------------------------------------------------- # 82 | # 1. Accuracy & Class Accuracy 83 | # ---------------------------------------------------------------------- # 84 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 85 | 86 | # recall 87 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 88 | # acc_cls = np.nanmean(recall) 89 | 90 | # precision 91 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 92 | 93 | # F1 score 94 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps) 95 | mean_F1 = np.nanmean(F1) 96 | return mean_F1 97 | 98 | 99 | def cm2score(confusion_matrix): 100 | hist = confusion_matrix 101 | n_class = hist.shape[0] 102 | tp = np.diag(hist) 103 | sum_a1 = hist.sum(axis=1) 104 | sum_a0 = hist.sum(axis=0) 105 | # ---------------------------------------------------------------------- # 106 | # 1. Accuracy & Class Accuracy 107 | # ---------------------------------------------------------------------- # 108 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 109 | 110 | # recall 111 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 112 | # acc_cls = np.nanmean(recall) 113 | 114 | # precision 115 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 116 | 117 | # F1 score 118 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps) 119 | mean_F1 = np.nanmean(F1) 120 | # ---------------------------------------------------------------------- # 121 | # 2. 
Frequency weighted Accuracy & Mean IoU 122 | # ---------------------------------------------------------------------- # 123 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps) 124 | mean_iu = np.nanmean(iu) 125 | 126 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps) 127 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 128 | 129 | # 130 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu)) 131 | 132 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision)) 133 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall)) 134 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1)) 135 | 136 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1} 137 | score_dict.update(cls_iou) 138 | score_dict.update(cls_F1) 139 | score_dict.update(cls_precision) 140 | score_dict.update(cls_recall) 141 | return score_dict 142 | 143 | 144 | def get_confuse_matrix(num_classes, label_gts, label_preds): 145 | def __fast_hist(label_gt, label_pred): 146 | """ 147 | Collect values for Confusion Matrix 148 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix 149 | :param label_gt: ground-truth 150 | :param label_pred: prediction 151 | :return: values for confusion matrix 152 | """ 153 | mask = (label_gt >= 0) & (label_gt < num_classes) 154 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask], 155 | minlength=num_classes**2).reshape(num_classes, num_classes) 156 | return hist 157 | confusion_matrix = np.zeros((num_classes, num_classes)) 158 | for lt, lp in zip(label_gts, label_preds): 159 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten()) 160 | return confusion_matrix 161 | 162 | 163 | def get_mIoU(num_classes, label_gts, label_preds): 164 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds) 165 | score_dict = cm2score(confusion_matrix) 166 | return score_dict['miou'] 167 | -------------------------------------------------------------------------------- /model/block/convs.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/iduta/pyconv 3 | https://github.com/XudongLinthu/context-gated-convolution 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | 12 | class ConvBnRelu(nn.Module): 13 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1): 14 | super(ConvBnRelu, self).__init__() 15 | self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 16 | padding=padding, dilation=dilation, bias=False), 17 | nn.BatchNorm2d(out_channels), 18 | nn.ReLU(inplace=True)) 19 | 20 | def forward(self, x): 21 | x = self.block(x) 22 | return x 23 | 24 | 25 | class DsBnRelu(nn.Module): 26 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1): 27 | super(DsBnRelu, self).__init__() 28 | self.kernel_size = kernel_size 29 | self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, 30 | dilation, groups=in_channels, bias=False) 31 | self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 32 | self.bn = nn.BatchNorm2d(out_channels) 33 | self.relu = nn.ReLU(True) 34 | 35 | def forward(self, x): 36 | if self.kernel_size != 1: 37 | x = self.depthwise(x) 38 | x = self.pointwise(x) 39 | x = self.bn(x) 40 | x = self.relu(x) 41 | return x 42 | 43 | 44 | class 
PyConv2d(nn.Module): 45 | def __init__(self, in_channels, out_channels, pyconv_kernels=[1, 3, 5, 7], pyconv_groups=[1, 2, 4, 8], bias=False): 46 | super(PyConv2d, self).__init__() 47 | 48 | pyconv_levels = [] 49 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups): 50 | pyconv_levels.append(nn.Conv2d(in_channels, out_channels // 2, kernel_size=pyconv_kernel, 51 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=bias)) 52 | self.pyconv_levels = nn.Sequential(*pyconv_levels) 53 | self.to_out = nn.Sequential(nn.Conv2d((out_channels // 2) * len(pyconv_kernels), out_channels, 1, bias=False), 54 | nn.BatchNorm2d(out_channels), 55 | nn.ReLU(True)) 56 | self.relu = nn.ReLU(True) 57 | 58 | def forward(self, x): 59 | out = [] 60 | for level in self.pyconv_levels: 61 | out.append(self.relu(level(x))) 62 | out = torch.cat(out, dim=1) 63 | out = self.to_out(out) 64 | 65 | return out 66 | 67 | 68 | class GatedConv2d(torch.nn.Module): 69 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): 70 | super(GatedConv2d, self).__init__() 71 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 72 | self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 73 | self.sigmoid = torch.nn.Sigmoid() 74 | 75 | def gated(self, mask): 76 | return self.sigmoid(mask) 77 | 78 | def forward(self, input): 79 | x = self.conv2d(input) 80 | mask = self.mask_conv2d(input) 81 | x = x * self.gated(mask) 82 | 83 | return x 84 | 85 | 86 | class ContextGatedConv2d(nn.Conv2d): 87 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 88 | padding=1, dilation=1, groups=1, bias=True): 89 | super(ContextGatedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, 90 | padding, dilation, groups, bias) 91 | # for convolutional layers with a kernel size of 1, just use traditional convolution 92 | if kernel_size == 1: 93 | self.ind = True 94 | else: 95 | self.ind = False 96 | self.oc = out_channels 97 | self.ks = kernel_size 98 | 99 | # the target spatial size of the pooling layer 100 | ws = kernel_size 101 | self.avg_pool = nn.AdaptiveAvgPool2d((ws, ws)) 102 | 103 | # the dimension of the latent representation 104 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1) 105 | 106 | # the context encoding module 107 | self.ce = nn.Linear(ws * ws, self.num_lat, False) 108 | self.ce_bn = nn.BatchNorm1d(in_channels) 109 | self.ci_bn2 = nn.BatchNorm1d(in_channels) 110 | 111 | # activation function is relu 112 | self.act = nn.ReLU(inplace=True) 113 | 114 | # the number of groups in the channel interacting module 115 | if in_channels // 16: 116 | self.g = 16 117 | else: 118 | self.g = in_channels 119 | # the channel interacting module 120 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False) 121 | self.ci_bn = nn.BatchNorm1d(out_channels) 122 | 123 | # the gate decoding module 124 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 125 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 126 | 127 | # used to prepare the input feature map into patches 128 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride) 129 | 130 | # sigmoid function 131 | self.sig = nn.Sigmoid() 132 | 133 | def forward(self, x): 134 | # for convolutional layers with a kernel size of 1, just use traditional convolution 135 | if self.ind: 136 | return F.conv2d(x,
self.weight, self.bias, self.stride, 137 | self.padding, self.dilation, self.groups) 138 | else: 139 | b, c, h, w = x.size() 140 | weight = self.weight 141 | # gather global information 142 | gl = self.avg_pool(x).view(b, c, -1) 143 | # context-encoding module 144 | out = self.ce(gl) 145 | # use different bn for the following two branches 146 | ce2 = out 147 | out = self.ce_bn(out) 148 | out = self.act(out) 149 | # gate decoding branch 1 150 | out = self.gd(out) 151 | # channel interacting module 152 | if self.g > 3: 153 | # grouped linear 154 | oc = self.ci(self.act(self.ci_bn2(ce2).view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2, 3).contiguous() 155 | else: 156 | # linear layer for resnet.conv1 157 | oc = self.ci(self.act(self.ci_bn2(ce2).transpose(2, 1))).transpose(2, 1).contiguous() 158 | oc = oc.view(b, self.oc, -1) 159 | oc = self.ci_bn(oc) 160 | oc = self.act(oc) 161 | # gate decoding branch 2 162 | oc = self.gd2(oc) 163 | # produce gate 164 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks)) 165 | # unfolding input feature map to patches 166 | x_un = self.unfold(x) 167 | b, _, l = x_un.size() 168 | # gating 169 | out = (out * weight.unsqueeze(0)).view(b, self.oc, -1) 170 | # currently only handles square input and output 171 | return torch.matmul(out, x_un).view(b, self.oc, int(np.sqrt(l)), int(np.sqrt(l))) 172 | 173 | 174 | -------------------------------------------------------------------------------- /SCD/model/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import timm 5 | from .backbone.mobilenetv2 import mobilenet_v2 6 | from .backbone.resnet import resnet18, resnet50 7 | from .block.bifpn import BiFPN 8 | from .block.bifpn_add import BiFPN_add 9 | from .block.neighbor import NeighborFeatureAggregation 10 | from .block.fpn import FPN 11 | from .block.fpn_plain import FPN_plain 12 | from .block.vertical import VerticalFusion 13 | from .block.convs import ConvBnRelu, DsBnRelu 14 | from .util import init_method 15 | from .block.heads import FCNHead, GatedResidualUpHead 16 | 17 | 18 | def get_backbone(backbone_name): 19 | if backbone_name == 'mobilenetv2': 20 | backbone = mobilenet_v2(pretrained=True, progress=True) 21 | backbone.channels = [16, 24, 32, 96, 320] 22 | elif backbone_name == 'mobilenetv3_small_075': 23 | backbone = timm.create_model('mobilenetv3_small_075', pretrained=True, features_only=True) 24 | backbone.channels = [16, 16, 24, 40, 432] 25 | elif backbone_name == 'mobilenetv3_small_100': 26 | backbone = timm.create_model('mobilenetv3_small_100', pretrained=True, features_only=True) 27 | backbone.channels = [16, 16, 24, 48, 576] 28 | elif backbone_name == 'resnet18': 29 | backbone = resnet18(pretrained=True, progress=True, replace_stride_with_dilation=[False, False, False]) 30 | backbone.channels = [64, 64, 128, 256, 512] 31 | elif backbone_name == 'resnet18d': 32 | backbone = timm.create_model('resnet18d', pretrained=True, features_only=True) 33 | backbone.channels = [64, 64, 128, 256, 512] 34 | elif backbone_name == 'resnet50': 35 | backbone = resnet50(pretrained=True, progress=True) 36 | backbone.channels = [64, 256, 512, 1024, 2048] 37 | elif backbone_name == 'hrnet_w18': 38 | backbone = timm.create_model('hrnet_w18', pretrained=True, features_only=True) 39 | backbone.channels = [64, 128, 256, 512, 1024] 40 | else: 41 | raise NotImplementedError("BACKBONE [%s] is not
implemented!\n" % backbone_name) 42 | return backbone 43 | 44 | 45 | def get_fpn(fpn_name, in_channels, out_channels, deform_groups=4, gamma_mode='SE', beta_mode='gatedconv'): 46 | if fpn_name == 'fpn': 47 | fpn = FPN(in_channels, out_channels, deform_groups, gamma_mode, beta_mode) 48 | elif fpn_name == 'fpn_plain': 49 | fpn = FPN_plain(in_channels, out_channels) 50 | elif fpn_name == 'bifpn': 51 | fpn = BiFPN(in_channels, out_channels) 52 | elif fpn_name == 'bifpn_add': 53 | fpn = BiFPN_add(in_channels, out_channels) 54 | elif fpn_name == 'neighbor': 55 | fpn = NeighborFeatureAggregation(in_channels, out_channels) 56 | else: 57 | raise NotImplementedError("FPN [%s] is not implemented!\n" % fpn_name) 58 | return fpn 59 | 60 | 61 | class Detector(nn.Module): 62 | def __init__(self, backbone_name='mobilenetv2', fpn_name='fpn', fpn_channels=64, 63 | deform_groups=4, gamma_mode='SE', beta_mode='contextgatedconv', 64 | num_heads=1, num_points=8, kernel_layers=1, dropout_rate=0.1, init_type='kaiming_normal'): 65 | super().__init__() 66 | self.backbone = get_backbone(backbone_name) 67 | self.fpn = get_fpn(fpn_name, in_channels=self.backbone.channels[-4:], out_channels=fpn_channels, 68 | deform_groups=deform_groups, gamma_mode=gamma_mode, beta_mode=beta_mode) 69 | self.p5_to_p4 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=4, 70 | kernel_layers=kernel_layers) 71 | self.p4_to_p3 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=8, 72 | kernel_layers=kernel_layers) 73 | self.p3_to_p2 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=16, 74 | kernel_layers=kernel_layers) 75 | 76 | self.p5_head = nn.Conv2d(fpn_channels, 2, 1) 77 | self.p4_head = nn.Conv2d(fpn_channels, 2, 1) 78 | self.p3_head = nn.Conv2d(fpn_channels, 2, 1) 79 | self.p2_head = nn.Conv2d(fpn_channels, 2, 1) 80 | self.project = nn.Sequential(nn.Conv2d(fpn_channels * 4, fpn_channels, 1, bias=False), 81 | nn.BatchNorm2d(fpn_channels), 82 | nn.ReLU(True) 83 | ) 84 | self.head = GatedResidualUpHead(fpn_channels, 2, dropout_rate=dropout_rate) 85 | self.scd_head = FCNHead(fpn_channels, 7, dropout_rate=dropout_rate) 86 | # init_method(self.fpn, self.p5_to_p4, self.p4_to_p3, self.p3_to_p2, self.p5_head, self.p4_head, 87 | # self.p3_head, self.p2_head, init_type=init_type) 88 | 89 | def forward(self, x1, x2): 90 | ### Extract backbone features 91 | t1_c1, t1_c2, t1_c3, t1_c4, t1_c5 = self.backbone.forward(x1) 92 | t2_c1, t2_c2, t2_c3, t2_c4, t2_c5 = self.backbone.forward(x2) 93 | t1_p2, t1_p3, t1_p4, t1_p5 = self.fpn([t1_c2, t1_c3, t1_c4, t1_c5]) 94 | t2_p2, t2_p3, t2_p4, t2_p5 = self.fpn([t2_c2, t2_c3, t2_c4, t2_c5]) 95 | 96 | diff_p2 = torch.abs(t1_p2 - t2_p2) 97 | diff_p3 = torch.abs(t1_p3 - t2_p3) 98 | diff_p4 = torch.abs(t1_p4 - t2_p4) 99 | diff_p5 = torch.abs(t1_p5 - t2_p5) 100 | """ 101 | pred_p5 = self.p5_head(diff_p5) 102 | pred_p4 = self.p4_head(diff_p4) 103 | pred_p3 = self.p3_head(diff_p3) 104 | pred_p2 = self.p2_head(diff_p2) 105 | 106 | diff_p3 = F.interpolate(diff_p3, size=(64, 64), mode='bilinear', align_corners=False) 107 | diff_p4 = F.interpolate(diff_p4, size=(64, 64), mode='bilinear', align_corners=False) 108 | diff_p5 = F.interpolate(diff_p5, size=(64, 64), mode='bilinear', align_corners=False) 109 | #diff = diff_p2 + diff_p3 + diff_p4 + diff_p5 110 | diff = torch.cat([diff_p2, diff_p3, diff_p4, diff_p5], dim=1) 111 | diff = self.project(diff) 112 | pred = self.head(diff) 113 | 114 | """ 115 | fea_p5 = diff_p5 116 | pred_p5 = self.p5_head(fea_p5) 117 | fea_p4 = self.p5_to_p4(fea_p5, 
diff_p4) 118 | pred_p4 = self.p4_head(fea_p4) 119 | fea_p3 = self.p4_to_p3(fea_p4, diff_p3) 120 | pred_p3 = self.p3_head(fea_p3) 121 | fea_p2 = self.p3_to_p2(fea_p3, diff_p2) 122 | pred_p2 = self.p2_head(fea_p2) 123 | #fea_p3 = F.interpolate(fea_p3, size=(64, 64), mode='bilinear', align_corners=False) 124 | #fea_p4 = F.interpolate(fea_p4, size=(64, 64), mode='bilinear', align_corners=False) 125 | #fea_p5 = F.interpolate(fea_p5, size=(64, 64), mode='bilinear', align_corners=False) 126 | #diff = diff_p2 + diff_p3 + diff_p4 + diff_p5 127 | #diff = torch.cat([fea_p2, fea_p3, fea_p4, fea_p5], dim=1) 128 | #diff = self.project(diff) 129 | pred = self.head(fea_p2) 130 | 131 | 132 | pred_p2 = F.interpolate(pred_p2, size=(256, 256), mode='bilinear', align_corners=False) 133 | pred_p3 = F.interpolate(pred_p3, size=(256, 256), mode='bilinear', align_corners=False) 134 | pred_p4 = F.interpolate(pred_p4, size=(256, 256), mode='bilinear', align_corners=False) 135 | pred_p5 = F.interpolate(pred_p5, size=(256, 256), mode='bilinear', align_corners=False) 136 | #pred = F.interpolate(pred, size=(256, 256), mode='bilinear', align_corners=False) 137 | 138 | t1_p3 = F.interpolate(t1_p3, size=(64, 64), mode='bilinear', align_corners=False) 139 | t1_p4 = F.interpolate(t1_p4, size=(64, 64), mode='bilinear', align_corners=False) 140 | t1_p5 = F.interpolate(t1_p5, size=(64, 64), mode='bilinear', align_corners=False) 141 | t1_fea = torch.cat([t1_p2, t1_p3, t1_p4, t1_p5], dim=1) 142 | t1_fea = self.project(t1_fea) 143 | pred_seg1 = self.scd_head(t1_fea) 144 | pred_seg1 = F.interpolate(pred_seg1, size=(256, 256), mode='bilinear', align_corners=False) 145 | t2_p3 = F.interpolate(t2_p3, size=(64, 64), mode='bilinear', align_corners=False) 146 | t2_p4 = F.interpolate(t2_p4, size=(64, 64), mode='bilinear', align_corners=False) 147 | t2_p5 = F.interpolate(t2_p5, size=(64, 64), mode='bilinear', align_corners=False) 148 | t2_fea = torch.cat([t2_p2, t2_p3, t2_p4, t2_p5], dim=1) 149 | t2_fea = self.project(t2_fea) 150 | pred_seg2 = self.scd_head(t2_fea) 151 | pred_seg2 = F.interpolate(pred_seg2, size=(256, 256), mode='bilinear', align_corners=False) 152 | 153 | return pred, pred_seg1, pred_seg2, pred_p2, pred_p3, pred_p4, pred_p5 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /SCD/model/block/convs.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/iduta/pyconv 3 | https://github.com/ZitongYu/CDCN/ 4 | https://github.com/XudongLinthu/context-gated-convolution 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from mmcv.ops import ModulatedDeformConv2dPack as DCNv2 11 | import numpy as np 12 | 13 | 14 | class ConvBnRelu(nn.Module): 15 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1): 16 | super(ConvBnRelu, self).__init__() 17 | self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 18 | padding=padding, dilation=dilation, bias=False), 19 | nn.BatchNorm2d(out_channels), 20 | nn.ReLU(inplace=True)) 21 | 22 | def forward(self, x): 23 | x = self.block(x) 24 | return x 25 | 26 | 27 | class DsBnRelu(nn.Module): 28 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1): 29 | super(DsBnRelu, self).__init__() 30 | self.kernel_size = kernel_size 31 | self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, 32 | 
dilation, groups=in_channels, bias=False) 33 | self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 34 | self.bn = nn.BatchNorm2d(out_channels) 35 | self.relu = nn.ReLU(True) 36 | 37 | def forward(self, x): 38 | if self.kernel_size != 1: 39 | x = self.depthwise(x) 40 | x = self.pointwise(x) 41 | x = self.bn(x) 42 | x = self.relu(x) 43 | return x 44 | 45 | 46 | class PyConv2d(nn.Module): 47 | def __init__(self, in_channels, out_channels, pyconv_kernels=[1, 3, 5, 7], pyconv_groups=[1, 2, 4, 8], bias=False): 48 | super(PyConv2d, self).__init__() 49 | 50 | pyconv_levels = [] 51 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups): 52 | pyconv_levels.append(nn.Conv2d(in_channels, out_channels // 2, kernel_size=pyconv_kernel, 53 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=bias)) 54 | self.pyconv_levels = nn.Sequential(*pyconv_levels) 55 | self.to_out = nn.Sequential(nn.Conv2d((out_channels // 2) * len(pyconv_kernels), out_channels, 1, bias=False), 56 | nn.BatchNorm2d(out_channels), 57 | nn.ReLU(True)) 58 | self.relu = nn.ReLU(True) 59 | 60 | def forward(self, x): 61 | out = [] 62 | for level in self.pyconv_levels: 63 | out.append(self.relu(level(x))) 64 | out = torch.cat(out, dim=1) 65 | out = self.to_out(out) 66 | 67 | return out 68 | 69 | 70 | class PyDCNv2(nn.Module): 71 | def __init__(self, in_channels, out_channels, pyconv_kernels=[1, 3, 5], pyconv_groups=[1, 4, 8], bias=False): 72 | super(PyDCNv2, self).__init__() 73 | 74 | pyconv_levels = [] 75 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups): 76 | if pyconv_kernel == 1: 77 | pyconv_levels.append(nn.Conv2d(in_channels, out_channels // 2, kernel_size=pyconv_kernel, 78 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=bias)) 79 | else: 80 | pyconv_levels.append(DCNv2(in_channels, out_channels // 2, kernel_size=pyconv_kernel, 81 | padding=pyconv_kernel // 2, deform_groups=pyconv_group)) 82 | self.pyconv_levels = nn.Sequential(*pyconv_levels) 83 | self.to_out = nn.Sequential(nn.Conv2d((out_channels // 2) * len(pyconv_kernels), out_channels, 1, bias=False), 84 | nn.BatchNorm2d(out_channels), 85 | nn.ReLU6(True)) 86 | self.relu = nn.ReLU6(True) 87 | 88 | def forward(self, x): 89 | out = [] 90 | for level in self.pyconv_levels: 91 | out.append(self.relu(level(x))) 92 | out = torch.cat(out, dim=1) 93 | out = self.to_out(out) 94 | 95 | return out 96 | 97 | 98 | class CDC(nn.Module): 99 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 100 | padding=1, dilation=1, groups=1, bias=False, theta=0.7): 101 | super(CDC, self).__init__() 102 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 103 | self.theta = theta 104 | 105 | def forward(self, x): 106 | out_normal = self.conv(x) 107 | #pdb.set_trace() 108 | [C_out, C_in, kernel_size, kernel_size] = self.conv.weight.shape 109 | kernel_diff = self.conv.weight.sum(2).sum(2) 110 | kernel_diff = kernel_diff[:, :, None, None] 111 | out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=0, groups=self.conv.groups) 112 | 113 | return out_normal - self.theta * out_diff 114 | 115 | 116 | class GatedConv2d(torch.nn.Module): 117 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): 118 | super(GatedConv2d, self).__init__() 119 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, 
kernel_size, stride, padding, dilation, groups, bias) 120 | self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 121 | self.sigmoid = torch.nn.Sigmoid() 122 | 123 | def gated(self, mask): 124 | return self.sigmoid(mask) 125 | 126 | def forward(self, input): 127 | x = self.conv2d(input) 128 | mask = self.mask_conv2d(input) 129 | x = x * self.gated(mask) 130 | 131 | return x 132 | 133 | 134 | class ContextGatedConv2d(nn.Conv2d): 135 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 136 | padding=1, dilation=1, groups=1, bias=True): 137 | super(ContextGatedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, 138 | padding, dilation, groups, bias) 139 | # for convolutional layers with a kernel size of 1, just use traditional convolution 140 | if kernel_size == 1: 141 | self.ind = True 142 | else: 143 | self.ind = False 144 | self.oc = out_channels 145 | self.ks = kernel_size 146 | 147 | # the target spatial size of the pooling layer 148 | ws = kernel_size 149 | self.avg_pool = nn.AdaptiveAvgPool2d((ws, ws)) 150 | 151 | # the dimension of the latent representation 152 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1) 153 | 154 | # the context encoding module 155 | self.ce = nn.Linear(ws * ws, self.num_lat, False) 156 | self.ce_bn = nn.BatchNorm1d(in_channels) 157 | self.ci_bn2 = nn.BatchNorm1d(in_channels) 158 | 159 | # activation function is relu 160 | self.act = nn.ReLU(inplace=True) 161 | 162 | # the number of groups in the channel interacting module 163 | if in_channels // 16: 164 | self.g = 16 165 | else: 166 | self.g = in_channels 167 | # the channel interacting module 168 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False) 169 | self.ci_bn = nn.BatchNorm1d(out_channels) 170 | 171 | # the gate decoding module 172 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 173 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 174 | 175 | # used to prepare the input feature map into patches 176 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride) 177 | 178 | # sigmoid function 179 | self.sig = nn.Sigmoid() 180 | 181 | def forward(self, x): 182 | # for convolutional layers with a kernel size of 1, just use traditional convolution 183 | if self.ind: 184 | return F.conv2d(x, self.weight, self.bias, self.stride, 185 | self.padding, self.dilation, self.groups) 186 | else: 187 | b, c, h, w = x.size() 188 | weight = self.weight 189 | # gather global information 190 | gl = self.avg_pool(x).view(b, c, -1) 191 | # context-encoding module 192 | out = self.ce(gl) 193 | # use different bn for the following two branches 194 | ce2 = out 195 | out = self.ce_bn(out) 196 | out = self.act(out) 197 | # gate decoding branch 1 198 | out = self.gd(out) 199 | # channel interacting module 200 | if self.g > 3: 201 | # grouped linear 202 | oc = self.ci(self.act(self.ci_bn2(ce2).view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2, 3).contiguous() 203 | else: 204 | # linear layer for resnet.conv1 205 | oc = self.ci(self.act(self.ci_bn2(ce2).transpose(2, 1))).transpose(2, 1).contiguous() 206 | oc = oc.view(b, self.oc, -1) 207 | oc = self.ci_bn(oc) 208 | oc = self.act(oc) 209 | # gate decoding branch 2 210 | oc = self.gd2(oc) 211 | # produce gate 212 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks)) 213 | # unfolding input feature map to patches 214 | x_un = self.unfold(x) 215 | b,
_, l = x_un.size() 216 | # gating 217 | out = (out * weight.unsqueeze(0)).view(b, self.oc, -1) 218 | # currently only handles square input and output 219 | return torch.matmul(out, x_un).view(b, self.oc, int(np.sqrt(l)), int(np.sqrt(l))) 220 | 221 | 222 | --------------------------------------------------------------------------------
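A hedged usage sketch to close: the pyramid and gating blocks defined in convs.py are easiest to sanity-check with a quick shape test. The block below is an editor's example, not a file from the original repository — the path example_usage.py is hypothetical, it assumes the repository root is on the Python path, and it imports the top-level model/block/convs.py (which, unlike the SCD copy, does not depend on mmcv). It follows the pattern of the commented-out smoke test at the bottom of model/block/fpn.py.
/example_usage.py (hypothetical editor's sketch): -------------------------------------------------------------------------------- 1 | import torch 2 | from model.block.convs import PyConv2d, GatedConv2d, ContextGatedConv2d 3 | 4 | if __name__ == '__main__': 5 | x = torch.randn(2, 64, 32, 32) 6 | # pyramidal convolution: parallel 1/3/5/7 kernels, concatenated, then fused back to out_channels by a 1x1 conv 7 | pyconv = PyConv2d(64, 64) 8 | # gated convolution: conv2d(x) modulated elementwise by sigmoid(mask_conv2d(x)) 9 | gated = GatedConv2d(64, 64, kernel_size=3, stride=1, padding=1) 10 | # context-gated convolution: a pooled context vector is decoded into per-sample gates on the kernel weights 11 | cgc = ContextGatedConv2d(64, 64, kernel_size=3, stride=1, padding=1) 12 | for block in (pyconv, gated, cgc): 13 | print(type(block).__name__, tuple(block(x).shape))  # each should preserve (2, 64, 32, 32) --------------------------------------------------------------------------------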