├── __init__.py
├── util
│   ├── __init__.py
│   ├── util.py
│   └── metric_tool.py
├── SCD
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── transform.py
│   │   └── cd_dataset.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── block
│   │   │   ├── __init__.py
│   │   │   ├── dcnv2.py
│   │   │   ├── schedular.py
│   │   │   ├── vertical.py
│   │   │   ├── heads.py
│   │   │   ├── fpn.py
│   │   │   └── convs.py
│   │   ├── loss
│   │   │   ├── __init__.py
│   │   │   ├── focal.py
│   │   │   └── dice.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   └── mobilenetv2.py
│   │   ├── util.py
│   │   ├── create_model.py
│   │   └── network.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── palette.py
│   │   ├── util.py
│   │   ├── metric.py
│   │   └── metric_tool.py
│   ├── test.py
│   ├── option.py
│   └── trainval.py
├── model
│   ├── block
│   │   ├── __init__.py
│   │   ├── dcnv2.py
│   │   ├── schedular.py
│   │   ├── heads.py
│   │   ├── vertical.py
│   │   ├── fpn.py
│   │   └── convs.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── dice.py
│   │   └── focal.py
│   ├── backbone
│   │   ├── __init__.py
│   │   └── mobilenetv2.py
│   ├── __init__.py
│   ├── util.py
│   ├── network.py
│   └── create_model.py
├── data
│   ├── __init__.py
│   ├── transform.py
│   └── cd_dataset.py
├── test.py
├── README.md
├── option.py
└── trainval.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SCD/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/block/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/loss/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SCD/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SCD/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SCD/util/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/backbone/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SCD/model/block/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SCD/model/loss/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SCD/model/backbone/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference:
3 | https://github.com/NVIDIA/pix2pixHD/
4 | """
5 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | References:
3 | https://github.com/NVIDIA/pix2pixHD
4 | https://github.com/VainF/DeepLabV3Plus-Pytorch
5 | """
--------------------------------------------------------------------------------
/SCD/util/palette.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def color_map(dataset):
4 | if dataset == 'SECOND_SCD':
5 | cmap = np.zeros((7, 3), dtype=np.uint8)
6 | cmap[0] = np.array([255, 255, 255])
7 | cmap[1] = np.array([0, 0, 255])
8 | cmap[2] = np.array([128, 128, 128])
9 | cmap[3] = np.array([0, 128, 0])
10 | cmap[4] = np.array([0, 255, 0])
11 | cmap[5] = np.array([128, 0, 0])
12 | cmap[6] = np.array([255, 0, 0])
13 | else:
14 | raise NotImplementedError
15 |
16 | return cmap
17 |
18 |
19 | def Color2Index(dataset, ColorLabel):
20 | if dataset == 'SECOND_SCD':
21 | num_classes = 7
22 | ST_COLORMAP = [[255, 255, 255], [0, 0, 255], [128, 128, 128], [0, 128, 0], [0, 255, 0], [128, 0, 0],
23 | [255, 0, 0]]
24 | CLASSES = ['unchanged', 'water', 'ground', 'low vegetation', 'tree', 'building', 'sports field']
25 | else:
26 | raise NotImplementedError
27 |
28 | colormap2label = np.zeros(256 ** 3)
29 | for i, cm in enumerate(ST_COLORMAP):
30 | colormap2label[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
31 |
32 | data = ColorLabel.astype(np.int32)
33 | idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
34 | IndexMap = colormap2label[idx]
35 | IndexMap = IndexMap * (IndexMap < num_classes)
36 |
37 | return IndexMap
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
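A quick check of the color packing used by `Color2Index` above: each RGB triple is flattened to the integer `(r * 256 + g) * 256 + b`, which indexes a 256^3 lookup table built from `ST_COLORMAP`; colors not in the map fall through to index 0. A minimal round-trip sketch (the 2x2 label is illustrative, and the import path assumes the repo root is on `PYTHONPATH`):

```python
import numpy as np
from SCD.util.palette import color_map, Color2Index

# 2x2 color label: white (unchanged), blue (water), gray (ground), red (sports field)
color_label = np.array([[[255, 255, 255], [0, 0, 255]],
                        [[128, 128, 128], [255, 0, 0]]], dtype=np.uint8)

index_map = Color2Index('SECOND_SCD', color_label)
print(index_map)  # [[0. 1.] [2. 6.]]

# inverse mapping: class index -> palette color
cmap = color_map('SECOND_SCD')
print(cmap[index_map.astype(np.int64)])  # recovers the original colors
```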
/test.py:
--------------------------------------------------------------------------------
1 | from util.metric_tool import ConfuseMatrixMeter
2 | import torch
3 | from option import Options
4 | from data.cd_dataset import DataLoader
5 | from model.create_model import create_model
6 | from tqdm import tqdm
7 |
8 | if __name__ == '__main__':
9 | opt = Options().parse()
10 | opt.phase = 'test'
11 | test_loader = DataLoader(opt)
12 | test_data = test_loader.load_data()
13 | test_size = len(test_loader)
14 | print("#testing images = %d" % test_size)
15 |
16 | opt.load_pretrain = True
17 | model = create_model(opt)
18 |
19 | tbar = tqdm(test_data, ncols=80)
20 | total_iters = test_size
21 | running_metric = ConfuseMatrixMeter(n_class=2)
22 | running_metric.clear()
23 |
24 | model.eval()
25 | with torch.no_grad():
26 | for i, _data in enumerate(tbar):
27 | val_pred = model.inference(_data['img1'].cuda(), _data['img2'].cuda())
28 | # update metric
29 | val_target = _data['cd_label'].detach()
30 | val_pred = torch.argmax(val_pred.detach(), dim=1)
31 | _ = running_metric.update_cm(pr=val_pred.cpu().numpy(), gt=val_target.cpu().numpy())
32 | val_scores = running_metric.get_scores()
33 | message = '(phase: %s) ' % (opt.phase)
34 | for k, v in val_scores.items():
35 | message += '%s: %.3f ' % (k, v * 100)
36 | print(message)
37 |
--------------------------------------------------------------------------------
/model/loss/dice.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.utils.data
4 | import torch.nn as nn
5 | from kornia.losses import dice_loss
6 |
7 | class DICELoss(nn.Module):
8 | def __init__(self):
9 | super(DICELoss, self).__init__()
10 |
11 | def forward(self, input, target):
12 | target = target.squeeze(1)
13 | loss = dice_loss(input, target)
14 |
15 | return loss
16 |
17 | ### version 2
18 | """
19 | class DICELoss(nn.Module):
20 | def __init__(self, eps=1e-5):
21 | super(DICELoss, self).__init__()
22 | self.eps = eps
23 |
24 | def to_one_hot(self, target):
25 | N, C, H, W = target.size()
26 | assert C == 1
27 | target = torch.zeros(N, 2, H, W).to(target.device).scatter_(1, target, 1)
28 | return target
29 |
30 | def forward(self, input, target):
31 | N, C, _, _ = input.size()
32 | input = F.softmax(input, dim=1)
33 |
34 | #target = self.to_one_hot(target)
35 | target = torch.eye(2)[target.squeeze(1)]
36 | target = target.permute(0, 3, 1, 2).type_as(input)
37 |
38 | dims = tuple(range(1, target.ndimension()))
39 | inter = torch.sum(input * target, dims)
40 | cardinality = torch.sum(input + target, dims)
41 | loss = ((2. * inter) / (cardinality + self.eps)).mean()
42 |
43 | return 1 - loss
44 | """
45 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code for the TCSVT article 'Cross-Level Attentive Feature Aggregation for Change Detection'.
2 | ---------------------------------------------
3 | Here I provide the PyTorch implementation for CLAFA.
4 |
5 |
6 | ## ENVIRONMENT
7 | * RTX 3090
8 | * Python 3.8.8
9 | * PyTorch 1.11.0
10 | * mmcv 1.6.0
11 |
12 | ## Installation
13 | Clone this repo:
14 |
15 | ```shell
16 | git clone https://github.com/xingronaldo/CLAFA.git
17 | cd CLAFA
18 | ```
19 |
20 | * Install dependencies
21 |
22 | All dependencies can be installed via `pip`.
23 |
24 | ## Dataset Preparation
25 | Download the datasets and place them in `./datasets`.
26 |
27 |
28 | ## Test
29 | Here I provide the trained models for the SV-CD dataset [Baidu Netdisk, code: CLAF](https://pan.baidu.com/s/1nfqqXA3DsZtU4BtHY-3YOg).
30 |
31 | Put them in `./checkpoints`.
32 |
33 |
34 | * Test on the SV-CD dataset with the MobileNetV2 backbone
35 |
36 | ```shell
37 | python test.py --backbone mobilenetv2 --name SV_mobilenetv2 --gpu_ids 1
38 | ```
39 |
40 | * Test on the SV-CD dataset with the ResNet18d backbone
41 |
42 | ```shell
43 | python test.py --backbone resnet18d --name SV_resnet18d --gpu_ids 1
44 | ```
45 |
46 | ## Train & Validation
47 | ```shell
48 | python trainval.py --gpu_ids 1
49 | ```
50 | All the hyperparameters can be adjusted in `option.py`.
51 |
52 |
53 | ## Contact
54 | Email: guangxingwang@mail.nwpu.edu.cn
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
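A note on the Installation step above: the imports across this repo indicate the main third-party dependencies, so a `pip` line along these lines should cover them (versions are not pinned beyond the ENVIRONMENT list; mmcv must include its CUDA ops for `ModulatedDeformConv2d` and `MultiScaleDeformableAttention`, i.e. the `mmcv-full` package in the 1.x series):

```shell
pip install torch torchvision mmcv-full kornia einops thop tqdm numpy Pillow
```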
/model/loss/focal.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from einops import rearrange
11 |
12 |
13 | class FocalLoss(nn.Module):
14 | def __init__(self, alpha=0.25, gamma=4.0):
15 | super(FocalLoss, self).__init__()
16 | self.alpha = alpha
17 | self.gamma = gamma
18 | if isinstance(alpha, (float, int)):
19 | self.alpha = torch.as_tensor([alpha, 1 - alpha])
20 | if isinstance(alpha, list):
21 | self.alpha = torch.as_tensor(alpha)
22 |
23 | def forward(self, input, target):
24 | N, C, H, W = input.size()
25 | assert C == 2
26 | # input = input.view(N, C, -1)
27 | # input = input.transpose(1, 2)
28 | # input = input.contiguous().view(-1, C)
29 | input = rearrange(input, 'b c h w -> (b h w) c')
30 | # input = input.contiguous().view(-1)
31 |
32 | target = target.view(-1, 1)
33 | logpt = F.log_softmax(input, dim=1)
34 | logpt = logpt.gather(1, target)
35 | logpt = logpt.view(-1)
36 | pt = Variable(logpt.data.exp())
37 |
38 | if self.alpha is not None:
39 | if self.alpha.type() != input.data.type():
40 | self.alpha = self.alpha.type_as(input.data)
41 | at = self.alpha.gather(0, target.data.view(-1))
42 | logpt = logpt * Variable(at)
43 | loss = -1 * (1-pt)**self.gamma * logpt
44 |
45 | return loss.mean()
46 |
47 |
48 |
--------------------------------------------------------------------------------
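For reference, the forward pass above evaluates the standard focal loss of Lin et al., applied per pixel after flattening the (B, 2, H, W) logits:

    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

where p_t is the softmax probability of the target class and alpha_t is gathered per pixel from the two-element `alpha` buffer; p_t is detached so the gradient flows only through log(p_t), matching the referenced implementation. Note the default `gamma=4.0` here is sharper than the 2.0 of the original paper; `option.py` exposes it as `--gamma`.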
/SCD/util/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied and modified from
3 | https://github.com/NVIDIA/pix2pixHD/tree/master/util
4 | """
5 | from __future__ import print_function
6 | import os
7 | import torch
8 | import numpy as np
9 | from PIL import Image
10 | from torchvision import utils
11 |
12 |
13 | def mkdirs(paths):
14 | if isinstance(paths, list) and not isinstance(paths, str):
15 | for path in paths:
16 | mkdir(path)
17 | else:
18 | mkdir(paths)
19 |
20 |
21 | def mkdir(path):
22 | if not os.path.exists(path):
23 | os.makedirs(path)
24 |
25 |
26 | def save_image(image_numpy, image_path):
27 | image_pil = Image.fromarray(np.array(image_numpy,dtype=np.uint8))
28 | image_pil.save(image_path)
29 |
30 |
31 | def make_numpy_grid(tensor_data, pad_value=0,padding=0):
32 | tensor_data = tensor_data.detach()
33 | vis = utils.make_grid(tensor_data, pad_value=pad_value,padding=padding)
34 | vis = np.array(vis.cpu()).transpose((1,2,0))
35 | if vis.shape[2] == 1:
36 |         vis = np.concatenate([vis, vis, vis], axis=2)
37 | return vis
38 |
39 |
40 | def de_norm(tensor_data):
41 | return tensor_data * 0.5 + 0.5
42 |
43 |
44 | def replace_batchnorm(net):
45 | for child_name, child in net.named_children():
46 | if hasattr(child, 'fuse'):
47 | setattr(net, child_name, child.fuse())
48 | elif isinstance(child, torch.nn.Conv2d):
49 | child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0)))
50 | elif isinstance(child, torch.nn.BatchNorm2d):
51 | setattr(net, child_name, torch.nn.Identity())
52 | else:
53 | replace_batchnorm(child)
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied and modified from
3 | https://github.com/NVIDIA/pix2pixHD/tree/master/util
4 | """
5 | from __future__ import print_function
6 | import os
7 | import torch
8 | import numpy as np
9 | from PIL import Image
10 | from torchvision import utils
11 |
12 |
13 | def mkdirs(paths):
14 | if isinstance(paths, list) and not isinstance(paths, str):
15 | for path in paths:
16 | mkdir(path)
17 | else:
18 | mkdir(paths)
19 |
20 |
21 | def mkdir(path):
22 | if not os.path.exists(path):
23 | os.makedirs(path)
24 |
25 |
26 | def save_image(image_numpy, image_path):
27 | image_pil = Image.fromarray(np.array(image_numpy,dtype=np.uint8))
28 | image_pil.save(image_path)
29 |
30 |
31 | def make_numpy_grid(tensor_data, pad_value=0,padding=0):
32 | tensor_data = tensor_data.detach()
33 | vis = utils.make_grid(tensor_data, pad_value=pad_value,padding=padding)
34 | vis = np.array(vis.cpu()).transpose((1,2,0))
35 | if vis.shape[2] == 1:
36 |         vis = np.concatenate([vis, vis, vis], axis=2)
37 | return vis
38 |
39 |
40 | def de_norm(tensor_data):
41 | return tensor_data * 0.5 + 0.5
42 |
43 |
44 | def replace_batchnorm(net):
45 | for child_name, child in net.named_children():
46 | if hasattr(child, 'fuse'):
47 | setattr(net, child_name, child.fuse())
48 | elif isinstance(child, torch.nn.Conv2d):
49 | child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0)))
50 | elif isinstance(child, torch.nn.BatchNorm2d):
51 | setattr(net, child_name, torch.nn.Identity())
52 | else:
53 | replace_batchnorm(child)
--------------------------------------------------------------------------------
/model/block/dcnv2.py:
--------------------------------------------------------------------------------
1 | from mmcv.ops import ModulatedDeformConv2dPack, modulated_deform_conv2d
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class DCNv2(ModulatedDeformConv2dPack):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 | out_channels = self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
10 | pyconv_kernels = [1, 3, 5]
11 | pyconv_groups = [1, self.deform_groups // 2, self.deform_groups]
12 | pyconv_levels = []
13 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups):
14 | pyconv_levels.append(nn.Sequential(nn.Conv2d(self.in_channels, out_channels, kernel_size=pyconv_kernel,
15 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=False),
16 | nn.BatchNorm2d(out_channels),
17 | nn.ReLU(True)))
18 | self.pyconv_levels = nn.Sequential(*pyconv_levels)
19 | self.offset = nn.Conv2d(out_channels * 3, out_channels, 1, bias=True)
20 | self.init_weights()
21 |
22 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
23 | out = []
24 | for level in self.pyconv_levels:
25 | out.append(level(x))
26 | out = torch.cat(out, dim=1)
27 |
28 | out = self.offset(out)
29 | o1, o2, mask = torch.chunk(out, 3, dim=1)
30 | offset = torch.cat((o1, o2), dim=1)
31 | mask = torch.sigmoid(mask)
32 | return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
33 | self.stride, self.padding,
34 | self.dilation, self.groups,
35 | self.deform_groups)
36 |
37 |
--------------------------------------------------------------------------------
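The block above differs from the stock `ModulatedDeformConv2dPack` in how offsets are predicted: a three-branch 1x1/3x3/5x5 pyramid feeds a 1x1 fusion conv, whose output is chunked into x/y offsets and a sigmoid modulation mask. A minimal smoke test, mirroring the commented demo in the SCD copy of this file (requires mmcv compiled with CUDA ops; shapes are illustrative and the import assumes the repo root is on `PYTHONPATH`):

```python
import torch
from model.block.dcnv2 import DCNv2

# deform_groups=4 keeps the pyramid group counts [1, 2, 4] valid
dcn = DCNv2(in_channels=512, out_channels=128, kernel_size=3,
            padding=1, deform_groups=4).cuda()
x = torch.randn(2, 512, 8, 8, device='cuda')
y = dcn(x)
print(y.shape)  # torch.Size([2, 128, 8, 8])
```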
/model/util.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def init_method(*nets, init_type='normal'):
5 | for net in nets:
6 | for module in net.modules():
7 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.ConvTranspose2d):
8 | if init_type == 'normal':
9 | nn.init.normal_(module.weight.data, 0.0, 0.02)
10 | elif init_type == 'xavier':
11 | nn.init.xavier_normal_(module.weight.data, gain=1.0)
12 | elif init_type == 'kaiming_normal':
13 | nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_in', nonlinearity='relu')
14 | elif init_type == 'kaiming_normal_out':
15 | nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_out', nonlinearity='relu')
16 | elif init_type == 'kaiming_uniform':
17 | nn.init.kaiming_uniform_(module.weight.data, a=0, mode='fan_in', nonlinearity='relu')
18 | elif init_type == 'trunc_normal':
19 |                     nn.init.trunc_normal_(module.weight.data, mean=0.0, std=1.0, a=-2.0, b=2.0)
20 | elif init_type == 'orthogonal':
21 | nn.init.orthogonal_(module.weight.data, gain=1.0)
22 | else:
23 | raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
24 | if module.bias is not None:
25 | nn.init.constant_(module.bias.data, 0.0)
26 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.LayerNorm):
27 | nn.init.normal_(module.weight.data, 1.0, 0.02)
28 | nn.init.constant_(module.bias.data, 0.0)
29 |
30 |     print("initialize backbone networks with [%s]" % init_type)
31 |
32 |
--------------------------------------------------------------------------------
/SCD/model/util.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def init_method(*nets, init_type='normal'):
5 | for net in nets:
6 | for module in net.modules():
7 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.ConvTranspose2d):
8 | if init_type == 'normal':
9 | nn.init.normal_(module.weight.data, 0.0, 0.02)
10 | elif init_type == 'xavier':
11 | nn.init.xavier_normal_(module.weight.data, gain=1.0)
12 | elif init_type == 'kaiming_normal':
13 | nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_in', nonlinearity='relu')
14 | elif init_type == 'kaiming_normal_out':
15 | nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_out', nonlinearity='relu')
16 | elif init_type == 'kaiming_uniform':
17 | nn.init.kaiming_uniform_(module.weight.data, a=0, mode='fan_in', nonlinearity='relu')
18 | elif init_type == 'trunc_normal':
19 |                     nn.init.trunc_normal_(module.weight.data, mean=0.0, std=1.0, a=-2.0, b=2.0)
20 | elif init_type == 'orthogonal':
21 | nn.init.orthogonal_(module.weight.data, gain=1.0)
22 | else:
23 | raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
24 | if module.bias is not None:
25 | nn.init.constant_(module.bias.data, 0.0)
26 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.LayerNorm):
27 | nn.init.normal_(module.weight.data, 1.0, 0.02)
28 | nn.init.constant_(module.bias.data, 0.0)
29 |
30 |     print("initialize backbone networks with [%s]" % init_type)
31 |
32 |
--------------------------------------------------------------------------------
/SCD/util/metric.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 |
5 | def cal_kappa(hist):
6 | if hist.sum() == 0:
7 | kappa = 0
8 | else:
9 | po = np.diag(hist).sum() / hist.sum()
10 | pe = np.matmul(hist.sum(1), hist.sum(0).T) / hist.sum() ** 2
11 | if pe == 1:
12 | kappa = 0
13 | else:
14 | kappa = (po - pe) / (1 - pe)
15 |
16 | return kappa
17 |
18 |
19 | class IOUandSek:
20 | def __init__(self, num_classes):
21 | self.num_classes = num_classes
22 | self.hist = np.zeros((num_classes, num_classes))
23 |
24 | def _fast_hist(self, label_pred, label_true):
25 | mask = (label_true >= 0) & (label_true < self.num_classes)
26 | hist = np.bincount(
27 | self.num_classes * label_true[mask].astype(int) +
28 | label_pred[mask], minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
29 |
30 | return hist
31 |
32 | def add_batch(self, predictions, gts):
33 | for lp, lt in zip(predictions, gts):
34 | self.hist += self._fast_hist(lp.flatten(), lt.flatten())
35 |
36 | def evaluate(self):
37 | confusion_matrix = np.zeros((2, 2))
38 | confusion_matrix[0][0] = self.hist[0][0]
39 | confusion_matrix[0][1] = self.hist.sum(1)[0] - self.hist[0][0]
40 | confusion_matrix[1][0] = self.hist.sum(0)[0] - self.hist[0][0]
41 | confusion_matrix[1][1] = self.hist[1:, 1:].sum()
42 |
43 | iou = np.diag(confusion_matrix) / (confusion_matrix.sum(0) +
44 | confusion_matrix.sum(1) - np.diag(confusion_matrix))
45 | miou = np.mean(iou)
46 |
47 | hist = self.hist.copy()
48 | hist[0][0] = 0
49 | kappa = cal_kappa(hist)
50 | sek = kappa * math.exp(iou[1] - 1)
51 |
52 | score = 0.3 * miou + 0.7 * sek
53 |
54 | return score, miou, sek
55 |
56 | def miou(self):
57 | confusion_matrix = self.hist[1:, 1:]
58 | iou = np.diag(confusion_matrix) / (confusion_matrix.sum(0) + confusion_matrix.sum(1) - np.diag(confusion_matrix))
59 |
60 | return iou, np.mean(iou)
61 |
--------------------------------------------------------------------------------
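For context, `evaluate()` above computes the SECOND benchmark score, 0.3 * mIoU + 0.7 * SeK: the 7x7 histogram is collapsed into a binary change/no-change confusion matrix for the IoUs, and SeK is Cohen's kappa over the histogram with the (0, 0) cell zeroed, scaled by exp(IoU_1 - 1) where IoU_1 is the changed-class IoU. A toy usage sketch (random predictions, purely illustrative):

```python
import numpy as np
from SCD.util.metric import IOUandSek

metric = IOUandSek(num_classes=7)
pred = np.random.randint(0, 7, size=(4, 256, 256))
gt = np.random.randint(0, 7, size=(4, 256, 256))
metric.add_batch(pred, gt)

score, miou, sek = metric.evaluate()
print("Score: %.3f, mIoU: %.3f, SeK: %.3f" % (score, miou, sek))
```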
/SCD/model/block/dcnv2.py:
--------------------------------------------------------------------------------
1 | from mmcv.ops import ModulatedDeformConv2dPack, modulated_deform_conv2d
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class DCNv2(ModulatedDeformConv2dPack):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 | out_channels = self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
10 | pyconv_kernels = [1, 3, 5]
11 | pyconv_groups = [1, self.deform_groups // 2, self.deform_groups]
12 | pyconv_levels = []
13 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups):
14 | pyconv_levels.append(nn.Sequential(nn.Conv2d(self.in_channels, out_channels, kernel_size=pyconv_kernel,
15 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=False),
16 | nn.BatchNorm2d(out_channels),
17 | nn.ReLU(True)))
18 | self.pyconv_levels = nn.Sequential(*pyconv_levels)
19 | self.offset = nn.Conv2d(out_channels * 3, out_channels, 1, bias=True)
20 | self.init_weights()
21 |
22 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
23 | out = []
24 | for level in self.pyconv_levels:
25 | out.append(level(x))
26 | out = torch.cat(out, dim=1)
27 |
28 | out = self.offset(out)
29 | o1, o2, mask = torch.chunk(out, 3, dim=1)
30 | offset = torch.cat((o1, o2), dim=1)
31 | mask = torch.sigmoid(mask)
32 | return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
33 | self.stride, self.padding,
34 | self.dilation, self.groups,
35 | self.deform_groups)
36 |
37 |
38 | # from thop import profile  # only needed for the commented profiling demo below
39 |
40 |
41 | # x = torch.randn(2, 512, 8, 8)
42 | # model = DCNv2(in_channels=512, out_channels=128, kernel_size=3, padding=1, deform_groups=4)
43 | # y = model(x)
44 | # print(y.shape)
45 | # flops, params = profile(model, inputs=(x,))
46 | # print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
47 | # print('Params = ' + str(params / 1000 ** 2) + 'M')
48 |
--------------------------------------------------------------------------------
/model/block/schedular.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/huggingface/transformers
3 | """
4 |
5 | import math
6 | from functools import partial
7 | from torch.optim import Optimizer
8 | from torch.optim.lr_scheduler import LambdaLR
9 |
10 |
11 | def _get_cosine_schedule_with_warmup_lr_lambda(
12 | current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
13 | ):
14 | if current_step < num_warmup_steps:
15 | return float(current_step) / float(max(1, num_warmup_steps))
16 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
17 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
18 |
19 |
20 | def get_cosine_schedule_with_warmup(
21 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1):
22 | """
23 |     Create a schedule with a learning rate that decreases following the values of the cosine function from the
24 |     initial lr set in the optimizer down to 0, after a warmup period during which it increases linearly from 0 to
25 |     the initial lr set in the optimizer.
26 | Args:
27 | optimizer ([`~torch.optim.Optimizer`]):
28 | The optimizer for which to schedule the learning rate.
29 | num_warmup_steps (`int`):
30 | The number of steps for the warmup phase.
31 | num_training_steps (`int`):
32 | The total number of training steps.
33 | num_cycles (`float`, *optional*, defaults to 0.5):
34 |             The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
35 | following a half-cosine).
36 | last_epoch (`int`, *optional*, defaults to -1):
37 | The index of the last epoch when resuming training.
38 | Return:
39 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
40 | """
41 |
42 | lr_lambda = partial(
43 | _get_cosine_schedule_with_warmup_lr_lambda,
44 | num_warmup_steps=num_warmup_steps,
45 | num_training_steps=num_training_steps,
46 | num_cycles=num_cycles,
47 | )
48 | return LambdaLR(optimizer, lr_lambda, last_epoch)
--------------------------------------------------------------------------------
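A minimal sketch of wiring this schedule up, stepped once per iteration as `trainval.py` does (the optimizer choice and step counts here are illustrative; `lr` and `weight_decay` follow the defaults in `option.py`):

```python
import torch
from model.block.schedular import get_cosine_schedule_with_warmup

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model.parameters()
optimizer = torch.optim.AdamW(params, lr=5e-4, weight_decay=5e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000)

for step in range(1000):
    optimizer.step()   # one training iteration
    scheduler.step()   # lr: linear warmup, then half-cosine decay to 0
```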
/SCD/model/block/schedular.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/huggingface/transformers
3 | """
4 |
5 | import math
6 | from functools import partial
7 | from torch.optim import Optimizer
8 | from torch.optim.lr_scheduler import LambdaLR
9 |
10 |
11 | def _get_cosine_schedule_with_warmup_lr_lambda(
12 | current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
13 | ):
14 | if current_step < num_warmup_steps:
15 | return float(current_step) / float(max(1, num_warmup_steps))
16 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
17 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
18 |
19 |
20 | def get_cosine_schedule_with_warmup(
21 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1):
22 | """
23 |     Create a schedule with a learning rate that decreases following the values of the cosine function from the
24 |     initial lr set in the optimizer down to 0, after a warmup period during which it increases linearly from 0 to
25 |     the initial lr set in the optimizer.
26 | Args:
27 | optimizer ([`~torch.optim.Optimizer`]):
28 | The optimizer for which to schedule the learning rate.
29 | num_warmup_steps (`int`):
30 | The number of steps for the warmup phase.
31 | num_training_steps (`int`):
32 | The total number of training steps.
33 | num_cycles (`float`, *optional*, defaults to 0.5):
34 |             The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
35 | following a half-cosine).
36 | last_epoch (`int`, *optional*, defaults to -1):
37 | The index of the last epoch when resuming training.
38 | Return:
39 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
40 | """
41 |
42 | lr_lambda = partial(
43 | _get_cosine_schedule_with_warmup_lr_lambda,
44 | num_warmup_steps=num_warmup_steps,
45 | num_training_steps=num_training_steps,
46 | num_cycles=num_cycles,
47 | )
48 | return LambdaLR(optimizer, lr_lambda, last_epoch)
--------------------------------------------------------------------------------
/SCD/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from option import Options
3 | from data.cd_dataset import DataLoader
4 | from model.create_model import create_model
5 | from tqdm import tqdm
6 | import math
7 | from util.palette import color_map
8 | from util.metric import IOUandSek
9 | import os
10 | import numpy as np
11 | import random
12 | from PIL import Image
13 |
14 | if __name__ == '__main__':
15 | opt = Options().parse()
16 | #opt.batch_size = 1
17 | opt.phase = 'test'
18 | test_loader = DataLoader(opt)
19 | test_data = test_loader.load_data()
20 | test_size = len(test_loader)
21 | print("#testing images = %d" % test_size)
22 |
23 | opt.load_pretrain = True
24 | model = create_model(opt)
25 |
26 | tbar = tqdm(test_data)
27 | total_iters = test_size
28 | metric = IOUandSek(num_classes=7)
29 |
30 |     model.eval()
31 |     with torch.no_grad():
32 |         for i, _data in enumerate(tbar):
33 |             cd_out, seg_out1, seg_out2 = model.inference(_data['img1'].cuda(), _data['img2'].cuda())
34 |             # update metric
35 |             val_target = _data['cd_label'].detach()
36 |             cd_out = torch.argmax(cd_out.detach(), dim=1)
37 |             seg_out1 = torch.argmax(seg_out1, dim=1).cpu().numpy()
38 |             seg_out2 = torch.argmax(seg_out2, dim=1).cpu().numpy()
39 |             cd_out = cd_out.cpu().numpy().astype(np.uint8)
40 |             seg_out1[cd_out == 0] = 0
41 |             seg_out2[cd_out == 0] = 0
42 |
43 |             if opt.save_mask:
44 |                 cmap = color_map(opt.dataset)
45 |                 # use j for the in-batch index: i is the outer batch counter
46 |                 for j in range(seg_out1.shape[0]):
47 |                     mask = Image.fromarray(seg_out1[j].astype(np.uint8), mode="P")
48 |                     mask.putpalette(cmap)
49 |                     os.makedirs(os.path.join(opt.result_dir, 'test', 'im1'), exist_ok=True)
50 |                     mask.save(os.path.join(opt.result_dir, 'test', 'im1', _data['fname'][j]))
51 |
52 |                     mask = Image.fromarray(seg_out2[j].astype(np.uint8), mode="P")
53 |                     mask.putpalette(cmap)
54 |                     os.makedirs(os.path.join(opt.result_dir, 'test', 'im2'), exist_ok=True)
55 |                     mask.save(os.path.join(opt.result_dir, 'test', 'im2', _data['fname'][j]))
56 |
57 |             metric.add_batch(seg_out1, _data['label1'].numpy())
58 |             metric.add_batch(seg_out2, _data['label2'].numpy())
59 |             score, miou, sek = metric.evaluate()
60 |             tbar.set_description("Score: %.3f, IOU: %.3f, SeK: %.3f" % (score * 100.0, miou * 100.0, sek * 100.0))
61 |
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | class Options():
5 | def __init__(self):
6 | self.parser = argparse.ArgumentParser()
7 |
8 | def init(self):
9 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0. use -1 for CPU')
10 | self.parser.add_argument('--name', type=str, default='SV_mobilenetv2')
11 | self.parser.add_argument('--dataroot', type=str, default='./datasets')
12 | self.parser.add_argument('--dataset', type=str, default='SV')
13 | self.parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='models are saved here')
14 |
15 | self.parser.add_argument('--result_dir', type=str, default='./results', help='results are saved here')
16 | self.parser.add_argument('--load_pretrain', type=bool, default=False)
17 |
18 | self.parser.add_argument('--phase', type=str, default='train')
19 | self.parser.add_argument('--input_size', type=int, default=256)
20 | self.parser.add_argument('--backbone', type=str, default='mobilenetv2')
21 | self.parser.add_argument('--fpn', type=str, default='fpn')
22 | self.parser.add_argument('--fpn_channels', type=int, default=128)
23 | self.parser.add_argument('--deform_groups', type=int, default=4)
24 | self.parser.add_argument('--gamma_mode', type=str, default='SE')
25 | self.parser.add_argument('--beta_mode', type=str, default='contextgatedconv')
26 | self.parser.add_argument('--num_heads', type=int, default=1)
27 | self.parser.add_argument('--num_points', type=int, default=8)
28 | self.parser.add_argument('--kernel_layers', type=int, default=1)
29 | self.parser.add_argument('--init_type', type=str, default='kaiming_normal')
30 | self.parser.add_argument('--alpha', type=float, default=0.25)
31 | self.parser.add_argument('--gamma', type=int, default=4, help='gamma for Focal loss')
32 | self.parser.add_argument('--dropout_rate', type=float, default=0.1)
33 |
34 | self.parser.add_argument('--batch_size', type=int, default=16)
35 | self.parser.add_argument('--num_epochs', type=int, default=200)
36 | self.parser.add_argument('--warmup_epochs', type=int, default=20)
37 | self.parser.add_argument('--num_workers', type=int, default=4, help='#threads for loading data')
38 | self.parser.add_argument('--lr', type=float, default=5e-4)
39 | self.parser.add_argument('--weight_decay', type=float, default=5e-4)
40 |
41 | def parse(self):
42 | self.init()
43 | self.opt = self.parser.parse_args()
44 |
45 | str_ids = self.opt.gpu_ids.split(',')
46 | self.opt.gpu_ids = []
47 | for str_id in str_ids:
48 | id = int(str_id)
49 | if id >= 0:
50 | self.opt.gpu_ids.append(id)
51 |
52 | # set gpu ids
53 | if len(self.opt.gpu_ids) > 0:
54 | torch.cuda.set_device(self.opt.gpu_ids[0])
55 |
56 | args = vars(self.opt)
57 |
58 | print('------------ Options -------------')
59 | for k, v in sorted(args.items()):
60 | print('%s: %s' % (str(k), str(v)))
61 | print('-------------- End ----------------')
62 |
63 | return self.opt
64 |
--------------------------------------------------------------------------------
/SCD/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | class Options():
5 | def __init__(self):
6 | self.parser = argparse.ArgumentParser()
7 |
8 | def init(self):
9 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0. use -1 for CPU')
10 | self.parser.add_argument('--name', type=str, default='SECOND_SCD_mobilenetv2_5')
11 | self.parser.add_argument('--dataroot', type=str, default='../../SupervisedCD/datasets')
12 | self.parser.add_argument('--dataset', type=str, default='SECOND_SCD')
13 | self.parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='models are saved here')
14 |
15 | self.parser.add_argument('--result_dir', type=str, default='./results', help='results are saved here')
16 | self.parser.add_argument('--load_pretrain', type=bool, default=False)
17 |
18 | self.parser.add_argument('--phase', type=str, default='train')
19 | self.parser.add_argument('--input_size', type=int, default=256)
20 | self.parser.add_argument('--backbone', type=str, default='mobilenetv2')
21 | self.parser.add_argument('--fpn', type=str, default='fpn')
22 | self.parser.add_argument('--fpn_channels', type=int, default=128)
23 | self.parser.add_argument('--deform_groups', type=int, default=4)
24 | self.parser.add_argument('--gamma_mode', type=str, default='SE')
25 | self.parser.add_argument('--beta_mode', type=str, default='contextgatedconv')
26 | self.parser.add_argument('--num_heads', type=int, default=1)
27 | self.parser.add_argument('--num_points', type=int, default=8)
28 | self.parser.add_argument('--kernel_layers', type=int, default=1)
29 | self.parser.add_argument('--init_type', type=str, default='kaiming_normal')
30 | self.parser.add_argument('--alpha', type=float, default=0.25)
31 | self.parser.add_argument('--gamma', type=int, default=4, help='gamma for Focal loss')
32 | self.parser.add_argument('--dropout_rate', type=float, default=0.1)
33 | self.parser.add_argument('--save_mask', type=bool, default=True)
34 |
35 | self.parser.add_argument('--batch_size', type=int, default=16)
36 | self.parser.add_argument('--num_epochs', type=int, default=50)
37 | self.parser.add_argument('--warmup_epochs', type=int, default=5)
38 | self.parser.add_argument('--num_workers', type=int, default=4, help='#threads for loading data')
39 | self.parser.add_argument('--lr', type=float, default=5e-4)
40 | self.parser.add_argument('--weight_decay', type=float, default=5e-4)
41 |
42 | def parse(self):
43 | self.init()
44 | self.opt = self.parser.parse_args()
45 |
46 | str_ids = self.opt.gpu_ids.split(',')
47 | self.opt.gpu_ids = []
48 | for str_id in str_ids:
49 | id = int(str_id)
50 | if id >= 0:
51 | self.opt.gpu_ids.append(id)
52 |
53 | # set gpu ids
54 | if len(self.opt.gpu_ids) > 0:
55 | torch.cuda.set_device(self.opt.gpu_ids[0])
56 |
57 | args = vars(self.opt)
58 |
59 | print('------------ Options -------------')
60 | for k, v in sorted(args.items()):
61 | print('%s: %s' % (str(k), str(v)))
62 | print('-------------- End ----------------')
63 |
64 | return self.opt
65 |
--------------------------------------------------------------------------------
/data/transform.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms.functional as TF
3 | from torchvision import transforms
4 | from torchvision.transforms import InterpolationMode
5 |
6 |
7 | class Transforms(object):
8 | def __call__(self, _data):
9 | img1, img2, cd_label = _data['img1'], _data['img2'], _data['cd_label']
10 |
11 | if random.random() < 0.5:
12 | img1_ = img1
13 | img1 = img2
14 | img2 = img1_
15 |
16 | if random.random() < 0.5:
17 | img1 = TF.hflip(img1)
18 | img2 = TF.hflip(img2)
19 | cd_label = TF.hflip(cd_label)
20 |
21 | if random.random() < 0.5:
22 | img1 = TF.vflip(img1)
23 | img2 = TF.vflip(img2)
24 | cd_label = TF.vflip(cd_label)
25 |
26 | if random.random() < 0.5:
27 | angles = [90, 180, 270]
28 | angle = random.choice(angles)
29 | img1 = TF.rotate(img1, angle)
30 | img2 = TF.rotate(img2, angle)
31 | cd_label = TF.rotate(cd_label, angle)
32 |         ### We didn't use color jitter for the SV dataset.
33 | """
34 | if random.random() < 0.5:
35 | colorjitters = []
36 | brightness_factor = random.uniform(0.5, 1.5)
37 | colorjitters.append(Lambda(lambda img: TF.adjust_brightness(img, brightness_factor)))
38 | contrast_factor = random.uniform(0.5, 1.5)
39 | colorjitters.append(Lambda(lambda img: TF.adjust_contrast(img, contrast_factor)))
40 | saturation_factor = random.uniform(0.5, 1.5)
41 | colorjitters.append(Lambda(lambda img: TF.adjust_saturation(img, saturation_factor)))
42 | random.shuffle(colorjitters)
43 | colorjitter = Compose(colorjitters)
44 | img1 = colorjitter(img1)
45 | img2 = colorjitter(img2)
46 | """
47 | if random.random() < 0.5:
48 | i, j, h, w = transforms.RandomResizedCrop(size=(256, 256)).get_params(img=img1, scale=[0.333, 1.0],
49 | ratio=[0.75, 1.333])
50 | img1 = TF.resized_crop(img1, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.BILINEAR)
51 | img2 = TF.resized_crop(img2, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.BILINEAR)
52 | cd_label = TF.resized_crop(cd_label, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.NEAREST)
53 |
54 | return {'img1': img1, 'img2': img2, 'cd_label': cd_label}
55 |
56 |
57 | class Lambda(object):
58 | def __init__(self, lambd):
59 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
60 | self.lambd = lambd
61 |
62 | def __call__(self, img):
63 | return self.lambd(img)
64 |
65 | def __repr__(self):
66 | return self.__class__.__name__ + '()'
67 |
68 |
69 | class Compose(object):
70 | def __init__(self, transforms):
71 | self.transforms = transforms
72 |
73 | def __call__(self, img):
74 | for t in self.transforms:
75 | img = t(img)
76 | return img
77 |
78 | def __repr__(self):
79 | format_string = self.__class__.__name__ + '('
80 | for t in self.transforms:
81 | format_string += '\n'
82 | format_string += ' {0}'.format(t)
83 | format_string += '\n)'
84 | return format_string
85 |
86 |
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/data/cd_dataset.py:
--------------------------------------------------------------------------------
1 | from .transform import Transforms
2 | import numpy as np
3 | import os
4 | from PIL import Image
5 | import torch
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms
8 |
9 |
10 | def make_dataset(dir):
11 | img_paths = []
12 | names = []
13 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
14 |
15 | for root, _, fnames in sorted(os.walk(dir)):
16 |             for fname in sorted(fnames):  # sort so A/B/label listings stay aligned
17 | path = os.path.join(root, fname)
18 | img_paths.append(path)
19 | names.append(fname)
20 |
21 | return img_paths, names
22 |
23 |
24 | class Load_Dataset(Dataset):
25 | def __init__(self, opt):
26 | super(Load_Dataset, self).__init__()
27 | self.opt = opt
28 |
29 | self.dir1 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'A')
30 |         self.t1_paths, self.fnames = make_dataset(self.dir1)  # make_dataset returns (paths, names) in sorted order
31 |
32 |         self.dir2 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'B')
33 |         self.t2_paths, _ = make_dataset(self.dir2)
34 |
35 |         self.dir_label = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'label')
36 |         self.label_paths, _ = make_dataset(self.dir_label)
37 |
38 | self.dataset_size = len(self.t1_paths)
39 |
40 | self.normalize = transforms.Compose([transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
41 | self.transform = transforms.Compose([Transforms()])
42 | self.to_tensor = transforms.Compose([transforms.ToTensor()])
43 |
44 | def __len__(self):
45 | return self.dataset_size
46 |
47 | def __getitem__(self, index):
48 | t1_path = self.t1_paths[index]
49 | fname = self.fnames[index]
50 | img1 = Image.open(t1_path)
51 |
52 | t2_path = self.t2_paths[index]
53 | img2 = Image.open(t2_path)
54 |
55 | label_path = self.label_paths[index]
56 | label = np.array(Image.open(label_path)) // 255
57 | cd_label = Image.fromarray(label)
58 |
59 | if self.opt.phase == 'train':
60 | _data = self.transform({'img1': img1, 'img2': img2, 'cd_label': cd_label})
61 | img1, img2, cd_label = _data['img1'], _data['img2'], _data['cd_label']
62 |
63 | img1 = self.to_tensor(img1)
64 | img2 = self.to_tensor(img2)
65 | img1 = self.normalize(img1)
66 | img2 = self.normalize(img2)
67 | cd_label = torch.from_numpy(np.array(cd_label))
68 | input_dict = {'img1': img1, 'img2': img2, 'cd_label': cd_label, 'fname': fname}
69 |
70 | return input_dict
71 |
72 |
73 | class DataLoader(torch.utils.data.Dataset):
74 |
75 | def __init__(self, opt):
76 | self.dataset = Load_Dataset(opt)
77 | self.dataloader = torch.utils.data.DataLoader(self.dataset,
78 | batch_size=opt.batch_size,
79 | shuffle=opt.phase=='train',
80 | pin_memory=True,
81 | drop_last=opt.phase=='train',
82 | num_workers=int(opt.num_workers)
83 | )
84 |
85 | def load_data(self):
86 | return self.dataloader
87 |
88 | def __len__(self):
89 | return len(self.dataset)
90 |
--------------------------------------------------------------------------------
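For reference, `Load_Dataset` above expects the binary change-detection layout below under `--dataroot` (directory names come from the `os.path.join` calls in `__init__`; labels are 0/255 masks, divided by 255 on load):

```
datasets/
└── SV/
    ├── train/
    │   ├── A/        # time-1 images
    │   ├── B/        # time-2 images
    │   └── label/    # binary change masks (0/255)
    ├── val/
    └── test/
```

Matching files in `A`, `B`, and `label` must share the same file name, since the three listings are aligned by sorted order.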
/model/block/heads.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/swz30/MIRNet
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from .convs import GatedConv2d, ContextGatedConv2d
8 |
9 |
10 | class GatedResidualUp(nn.Module):
11 | def __init__(self, in_channels, up_mode='conv', gate_mode='gated'):
12 | super(GatedResidualUp, self).__init__()
13 | if up_mode == 'conv':
14 | self.residual_up = nn.Sequential(nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1,
15 | output_padding=1, bias=False),
16 | nn.BatchNorm2d(in_channels),
17 | nn.ReLU(True))
18 | elif up_mode == 'bilinear':
19 | self.residual_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
20 |
21 | if gate_mode == 'gated':
22 | self.gate = GatedConv2d(in_channels, in_channels // 2)
23 | elif gate_mode == 'context_gated':
24 | self.gate = ContextGatedConv2d(in_channels, in_channels // 2)
25 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
26 | nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=True)
27 | )
28 | self.relu = nn.ReLU(True)
29 |
30 | def forward(self, x):
31 | residual = self.residual_up(x)
32 | residual = self.gate(residual)
33 | up = self.up(x)
34 | out = self.relu(up + residual)
35 | return out
36 |
37 |
38 | class GatedResidualUpHead(nn.Module):
39 | def __init__(self, in_channels=128, num_classes=1, dropout_rate=0.15):
40 | super(GatedResidualUpHead, self).__init__()
41 |
42 | self.up = nn.Sequential(GatedResidualUp(in_channels),
43 | GatedResidualUp(in_channels // 2))
44 | self.smooth = nn.Sequential(nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, stride=1, padding=1),
45 | nn.ReLU(True),
46 | nn.Dropout2d(dropout_rate))
47 | self.final = nn.Conv2d(in_channels // 4, num_classes, 1)
48 |
49 | def forward(self, x):
50 | x = self.up(x)
51 | x = self.smooth(x)
52 | x = self.final(x)
53 |
54 | return x
55 |
56 |
57 | class FCNHead(nn.Module):
58 | def __init__(self, in_channels, num_classes, num_convs=1, dropout_rate=0.15):
59 |         super(FCNHead, self).__init__()
60 |         self.num_convs = num_convs
61 | inter_channels = in_channels // 4
62 |
63 | convs = []
64 | convs.append(nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
65 | nn.BatchNorm2d(inter_channels),
66 | nn.ReLU(True)))
67 | for i in range(num_convs - 1):
68 | convs.append(nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
69 | nn.BatchNorm2d(inter_channels),
70 | nn.ReLU(True)))
71 | self.convs = nn.Sequential(*convs)
72 | self.final = nn.Conv2d(inter_channels, num_classes, 1)
73 |
74 | def forward(self, x):
75 | out = self.convs(x)
76 | out = self.final(out)
77 |
78 | return out
79 |
80 |
--------------------------------------------------------------------------------
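A shape sketch of `GatedResidualUpHead` above: each `GatedResidualUp` doubles the spatial size and halves the channels (the transposed-conv branch is gated and fused residually with the bilinear path), so the head maps a 128-channel stride-4 feature to full-resolution logits. A minimal check, assuming the gated convs in `convs.py` preserve spatial size as their use here implies (the input size is illustrative):

```python
import torch
from model.block.heads import GatedResidualUpHead

head = GatedResidualUpHead(in_channels=128, num_classes=2, dropout_rate=0.1)
x = torch.randn(1, 128, 64, 64)  # e.g. a 256x256 input downsampled 4x
print(head(x).shape)             # torch.Size([1, 2, 256, 256])
```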
/SCD/data/transform.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import random
3 | import torchvision.transforms.functional as TF
4 | from torchvision import transforms
5 | from torchvision.transforms import InterpolationMode
6 |
7 |
8 | class Transforms(object):
9 | def __call__(self, _data):
10 | img1, img2, label1, label2, cd_label = _data['img1'], _data['img2'], _data['label1'], _data['label2'], _data['cd_label']
11 |
12 | if random.random() < 0.5:
13 | img1_ = img1
14 | img1 = img2
15 | img2 = img1_
16 | label1_ = label1
17 | label1 = label2
18 | label2 = label1_
19 |
20 | if random.random() < 0.5:
21 | img1 = TF.hflip(img1)
22 | img2 = TF.hflip(img2)
23 | label1 = TF.hflip(label1)
24 | label2 = TF.hflip(label2)
25 | cd_label = TF.hflip(cd_label)
26 |
27 | if random.random() < 0.5:
28 | img1 = TF.vflip(img1)
29 | img2 = TF.vflip(img2)
30 | label1 = TF.vflip(label1)
31 | label2 = TF.vflip(label2)
32 | cd_label = TF.vflip(cd_label)
33 |
34 | if random.random() < 0.5:
35 | angles = [90, 180, 270]
36 | angle = random.choice(angles)
37 | img1 = TF.rotate(img1, angle)
38 | img2 = TF.rotate(img2, angle)
39 | label1 = TF.rotate(label1, angle)
40 | label2 = TF.rotate(label2, angle)
41 | cd_label = TF.rotate(cd_label, angle)
42 |
43 | if random.random() < 0.5:
44 | colorjitters = []
45 | brightness_factor = random.uniform(0.5, 1.5)
46 | colorjitters.append(Lambda(lambda img: TF.adjust_brightness(img, brightness_factor)))
47 | contrast_factor = random.uniform(0.5, 1.5)
48 | colorjitters.append(Lambda(lambda img: TF.adjust_contrast(img, contrast_factor)))
49 | saturation_factor = random.uniform(0.5, 1.5)
50 | colorjitters.append(Lambda(lambda img: TF.adjust_saturation(img, saturation_factor)))
51 | random.shuffle(colorjitters)
52 | colorjitter = Compose(colorjitters)
53 | img1 = colorjitter(img1)
54 | img2 = colorjitter(img2)
55 |
56 | if random.random() < 0.5:
57 | i, j, h, w = transforms.RandomResizedCrop(size=(256, 256)).get_params(img=img1, scale=[0.333, 1.0],
58 | ratio=[0.75, 1.333])
59 | img1 = TF.resized_crop(img1, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.BILINEAR)
60 | img2 = TF.resized_crop(img2, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.BILINEAR)
61 | label1 = TF.resized_crop(label1, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.NEAREST)
62 | label2 = TF.resized_crop(label2, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.NEAREST)
63 | cd_label = TF.resized_crop(cd_label, i, j, h, w, size=(256, 256), interpolation=InterpolationMode.NEAREST)
64 |
65 | return {'img1': img1, 'img2': img2, 'label1': label1, 'label2': label2, 'cd_label': cd_label}
66 |
67 |
68 | class Lambda(object):
69 | def __init__(self, lambd):
70 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
71 | self.lambd = lambd
72 |
73 | def __call__(self, img):
74 | return self.lambd(img)
75 |
76 | def __repr__(self):
77 | return self.__class__.__name__ + '()'
78 |
79 |
80 | class Compose(object):
81 | def __init__(self, transforms):
82 | self.transforms = transforms
83 |
84 | def __call__(self, img):
85 | for t in self.transforms:
86 | img = t(img)
87 | return img
88 |
89 | def __repr__(self):
90 | format_string = self.__class__.__name__ + '('
91 | for t in self.transforms:
92 | format_string += '\n'
93 | format_string += ' {0}'.format(t)
94 | format_string += '\n)'
95 | return format_string
96 |
97 |
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
/trainval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from option import Options
3 | from data.cd_dataset import DataLoader
4 | from model.create_model import create_model
5 | from tqdm import tqdm
6 | import math
7 | from util.metric_tool import ConfuseMatrixMeter
8 | import os
9 | import numpy as np
10 | import random
11 |
12 |
13 | def setup_seed(seed):
14 | os.environ['PYTHONHASHSEED'] = str(seed)
15 | np.random.seed(seed)
16 | random.seed(seed)
17 |
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | torch.backends.cudnn.deterministic = False #!
22 | torch.backends.cudnn.benchmark = True #!
23 | torch.backends.cudnn.enabled = True #! for accelerating training
24 |
25 |
26 | class Trainval(object):
27 | def __init__(self, opt):
28 | self.opt = opt
29 |
30 | train_loader = DataLoader(opt)
31 | self.train_data = train_loader.load_data()
32 | train_size = len(train_loader)
33 | print("#training images = %d" % train_size)
34 | opt.phase = 'val'
35 | val_loader = DataLoader(opt)
36 | self.val_data = val_loader.load_data()
37 | val_size = len(val_loader)
38 | print("#validation images = %d" % val_size)
39 | opt.phase = 'train'
40 |
41 | self.model = create_model(opt)
42 | self.optimizer = self.model.optimizer
43 | self.schedular = self.model.schedular
44 |
45 | self.iters = 0
46 | self.total_iters = math.ceil(train_size / opt.batch_size) * opt.num_epochs
47 | self.previous_best = 0.0
48 | self.running_metric = ConfuseMatrixMeter(n_class=2)
49 |
50 | def train(self):
51 | tbar = tqdm(self.train_data, ncols=80)
52 | opt.phase = 'train'
53 | _loss = 0.0
54 | _focal_loss = 0.0
55 | _dice_loss = 0.0
56 |
57 | for i, data in enumerate(tbar):
58 | self.model.detector.train()
59 | focal, dice, p2_loss, p3_loss, p4_loss, p5_loss = self.model(data['img1'].cuda(), data['img2'].cuda(), data['cd_label'].cuda())
60 | loss = focal * 0.5 + dice + p3_loss + p4_loss + p5_loss
61 | self.optimizer.zero_grad()
62 | loss.backward()
63 | self.optimizer.step()
64 | self.schedular.step()
65 | _loss += loss.item()
66 | _focal_loss += focal.item()
67 | _dice_loss += dice.item()
68 | del loss
69 |
70 | tbar.set_description("Loss: %.3f, Focal: %.3f, Dice: %.3f, LR: %.6f" %
71 | (_loss / (i + 1), _focal_loss / (i + 1), _dice_loss / (i + 1), self.optimizer.param_groups[0]['lr']))
72 |
73 | def val(self):
74 | tbar = tqdm(self.val_data, ncols=80)
75 | self.running_metric.clear()
76 | opt.phase = 'val'
77 | self.model.eval()
78 |
79 | with torch.no_grad():
80 | for i, _data in enumerate(tbar):
81 | val_pred = self.model.inference(_data['img1'].cuda(), _data['img2'].cuda())
82 | # update metric
83 | val_target = _data['cd_label'].detach()
84 | val_pred = torch.argmax(val_pred.detach(), dim=1)
85 | _ = self.running_metric.update_cm(pr=val_pred.cpu().numpy(), gt=val_target.cpu().numpy())
86 | val_scores = self.running_metric.get_scores()
87 | message = '(phase: %s) ' % (self.opt.phase)
88 | for k, v in val_scores.items():
89 | message += '%s: %.3f ' % (k, v * 100)
90 | print(message)
91 |
92 | if val_scores['mf1'] >= self.previous_best:
93 | self.model.save(self.opt.name, self.opt.backbone)
94 | self.previous_best = val_scores['mf1']
95 |
96 |
97 | if __name__ == "__main__":
98 | opt = Options().parse()
99 |     setup_seed(seed=1)
100 |     trainval = Trainval(opt)
101 |
102 | for epoch in range(1, opt.num_epochs + 1):
103 | print("\n==> Name %s, Epoch %i, previous best = %.3f" % (opt.name, epoch, trainval.previous_best * 100))
104 | trainval.train()
105 | trainval.val()
106 |
107 |
108 |
109 |
--------------------------------------------------------------------------------
/SCD/data/cd_dataset.py:
--------------------------------------------------------------------------------
1 | from .transform import Transforms
2 | from util.palette import Color2Index
3 | import numpy as np
4 | import os
5 | from PIL import Image
6 | import random
7 | import torch
8 | from torch.utils.data import Dataset
9 | from torchvision import transforms
10 |
11 |
12 | def make_dataset(dir):
13 | img_paths = []
14 | names = []
15 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
16 |
17 | for root, _, fnames in sorted(os.walk(dir)):
18 |         for fname in sorted(fnames):  # sort so im1/im2/label1/label2 listings stay aligned
19 | path = os.path.join(root, fname)
20 | img_paths.append(path)
21 | names.append(fname)
22 |
23 | return img_paths, names
24 |
25 |
26 | class Load_Dataset(Dataset):
27 | def __init__(self, opt):
28 | super(Load_Dataset, self).__init__()
29 | self.opt = opt
30 |
31 | self.dir1 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'im1')
32 |         self.t1_paths, self.fnames = make_dataset(self.dir1)  # make_dataset returns (paths, names) in sorted order
33 |
34 |         self.dir2 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'im2')
35 |         self.t2_paths, _ = make_dataset(self.dir2)
36 |
37 |         self.dir_label1 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'label1')
38 |         self.label1_paths, _ = make_dataset(self.dir_label1)
39 |
40 |         self.dir_label2 = os.path.join(opt.dataroot, opt.dataset, opt.phase, 'label2')
41 |         self.label2_paths, _ = make_dataset(self.dir_label2)
42 |
43 | self.dataset_size = len(self.t1_paths)
44 |
45 | self.normalize = transforms.Compose([transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
46 | self.transform = transforms.Compose([Transforms()])
47 | self.to_tensor = transforms.Compose([transforms.ToTensor()])
48 |
49 |
50 | def __len__(self):
51 | return self.dataset_size
52 |
53 | def __getitem__(self, index):
54 | t1_path = self.t1_paths[index]
55 | fname = self.fnames[index]
56 | img1 = Image.open(t1_path)
57 |
58 | t2_path = self.t2_paths[index]
59 | img2 = Image.open(t2_path)
60 |
61 | label1_path = self.label1_paths[index]
62 | label1 = Image.open(label1_path)
63 | label1 = Image.fromarray(Color2Index(self.opt.dataset, np.array(label1)))
64 |
65 | label2_path = self.label2_paths[index]
66 | label2 = Image.open(label2_path)
67 | label2 = Image.fromarray(Color2Index(self.opt.dataset, np.array(label2)))
68 |
69 | mask = np.array(label1)
70 | cd_label = np.ones_like(mask)
71 | cd_label[mask == 0] = 0
72 | cd_label = Image.fromarray(cd_label)
73 |
74 | if self.opt.phase == 'train':
75 | _data = self.transform(
76 | {'img1': img1, 'img2': img2, 'label1': label1, 'label2': label2, 'cd_label': cd_label})
77 | img1, img2, label1, label2, cd_label = _data['img1'], _data['img2'], _data['label1'], _data['label2'], \
78 | _data['cd_label']
79 |
80 | img1 = self.to_tensor(img1)
81 | img2 = self.to_tensor(img2)
82 | img1 = self.normalize(img1)
83 | img2 = self.normalize(img2)
84 | label1 = torch.from_numpy(np.array(label1)).long()
85 | label2 = torch.from_numpy(np.array(label2)).long()
86 | cd_label = torch.from_numpy(np.array(cd_label)).long()
87 | input_dict = {'img1': img1, 'img2': img2, 'label1': label1, 'label2': label2, 'cd_label': cd_label,
88 | 'fname': fname}
89 |
90 | return input_dict
91 |
92 |
93 | class DataLoader(object):
94 |     """Thin wrapper that builds a torch DataLoader around Load_Dataset; not itself a Dataset."""
95 |     def __init__(self, opt):
96 |         self.dataset = Load_Dataset(opt)
97 |         self.dataloader = torch.utils.data.DataLoader(self.dataset,
98 |                                                        batch_size=opt.batch_size,
99 |                                                        shuffle=(opt.phase == 'train'),
100 |                                                        pin_memory=True,
101 |                                                        drop_last=(opt.phase == 'train'),
102 |                                                        num_workers=int(opt.num_workers)
103 |                                                        )
104 |
105 | def load_data(self):
106 | return self.dataloader
107 |
108 | def __len__(self):
109 | return len(self.dataset)
110 |
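111 | 
112 | if __name__ == '__main__':
113 |     # Smoke-test sketch (not from the original repo): every opt field below is one this
114 |     # module actually reads; the paths and values are hypothetical and assume a
115 |     # SECOND-style im1/im2/label1/label2 layout already exists under dataroot.
116 |     from argparse import Namespace
117 |     opt = Namespace(dataroot='./datasets', dataset='SECOND', phase='train',
118 |                     batch_size=8, num_workers=4)
119 |     loader = DataLoader(opt).load_data()
120 |     batch = next(iter(loader))
121 |     print(batch['img1'].shape, batch['label1'].shape, batch['cd_label'].shape)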
--------------------------------------------------------------------------------
/model/block/vertical.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .convs import ConvBnRelu
5 | from mmcv.ops import MultiScaleDeformableAttention
6 | from einops import rearrange
7 | from torch import einsum
8 |
9 |
10 | class ScaledSinuEmbedding(nn.Module):
11 | def __init__(self, dim):
12 | super().__init__()
13 | self.scale = nn.Parameter(torch.ones(1,))
14 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
15 | self.register_buffer('inv_freq', inv_freq)
16 |
17 | def forward(self, x):
18 | n, device = x.shape[1], x.device
19 | t = torch.arange(n, device=device).type_as(self.inv_freq)
20 | sinu = einsum('i, j -> i j', t, self.inv_freq)
21 | emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
22 | return emb * self.scale
23 |
24 |
25 | def get_reference_points(spatial_shapes, device):
26 | reference_points_list = []
27 | for lvl, (H_, W_) in enumerate(spatial_shapes):
28 | ref_y, ref_x = torch.meshgrid(
29 | torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
30 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
31 | ref_y = ref_y.reshape(-1)[None] / H_
32 | ref_x = ref_x.reshape(-1)[None] / W_
33 | ref = torch.stack((ref_x, ref_y), -1)
34 | reference_points_list.append(ref)
35 | reference_points = torch.cat(reference_points_list, 1)
36 | reference_points = reference_points[:, :, None]
37 | return reference_points
38 |
39 |
40 | ### Attend-then-Filter
41 | class VerticalFusion(nn.Module):
42 | def __init__(self, channels, num_heads=4, num_points=4, kernel_layers=1, up_kernel_size=5, enc_kernel_size=3):
43 | super(VerticalFusion, self).__init__()
44 | self.norm1 = nn.LayerNorm(channels)
45 | self.norm2 = nn.LayerNorm(channels)
46 | self.pos = ScaledSinuEmbedding(channels)
47 | self.crossattn = MultiScaleDeformableAttention(embed_dims=channels, num_levels=1, num_heads=num_heads,
48 | num_points=num_points, batch_first=True, dropout=0)
49 | convs = []
50 | convs.append(ConvBnRelu(in_channels=channels, out_channels=channels))
51 | for _ in range(kernel_layers - 1):
52 | convs.append(ConvBnRelu(in_channels=channels, out_channels=channels))
53 | self.convs = nn.Sequential(*convs)
54 | self.enc = ConvBnRelu(channels, up_kernel_size ** 2, kernel_size=enc_kernel_size,
55 | stride=1, padding=enc_kernel_size // 2, dilation=1)
56 |
57 | self.upsmp = nn.Upsample(scale_factor=2, mode='nearest')
58 | self.unfold = nn.Unfold(kernel_size=up_kernel_size, dilation=2,
59 | padding=up_kernel_size // 2 * 2)
60 |
61 | def get_deform_inputs(self, x1, x2):
62 | _, _, H1, W1 = x1.size()
63 | _, _, H2, W2 = x2.size()
64 | spatial_shapes = torch.as_tensor([(H2, W2)], dtype=torch.long, device=x2.device)
65 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
66 | reference_points = get_reference_points([(H1, W1)], x1.device)
67 |
68 | return reference_points, spatial_shapes, level_start_index
69 |
70 | def forward(self, x1, x2):
71 |         # Attend
72 | reference_points, spatial_shapes, level_start_index = self.get_deform_inputs(x1, x2)
73 | B, C, H, W = x1.size()
74 | _, _, H2, W2 = x2.size()
75 | x1_, x2_ = x1.clone(), x2.clone()
76 | x1 = rearrange(x1, 'b c h w -> b (h w) c')
77 | x2 = rearrange(x2, 'b c h w -> b (h w) c')
78 | x1, x2 = self.norm1(x1), self.norm2(x2)
79 | query_pos = self.pos(x1)
80 | x = self.crossattn(query=x1, value=x2, reference_points=reference_points, spatial_shapes=spatial_shapes,
81 | level_start_index=level_start_index, query_pos=query_pos)
82 | x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
83 |
84 |         # Filter
85 | kernel = self.convs(x2_)
86 | kernel = self.enc(kernel)
87 | kernel = F.softmax(kernel, dim=1)
88 | # x = self.upsmp(x)
89 | x = F.interpolate(x, size=(H2, W2), mode='nearest')
90 | x = self.unfold(x)
91 | # x = x.view(B, C, -1, H * 2, W * 2)
92 | x = x.view(B, C, -1, H2, W2)
93 |         fuse = torch.einsum('bkhw,bckhw->bchw', kernel, x)
94 | fuse += x2_
95 |
96 | return fuse
97 |
98 |
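99 | if __name__ == '__main__':
100 |     # Shape sanity check (a sketch, not repo code; assumes mmcv is installed): attends a
101 |     # coarse 16x16 map into a finer 32x32 one, as the detector does along P5 -> P4 -> P3 -> P2.
102 |     coarse = torch.randn(2, 128, 16, 16)
103 |     fine = torch.randn(2, 128, 32, 32)
104 |     fuse = VerticalFusion(channels=128, num_heads=4, num_points=4)(coarse, fine)
105 |     print(fuse.shape)  # expected: torch.Size([2, 128, 32, 32])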
--------------------------------------------------------------------------------
/model/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import timm
5 | from .backbone.mobilenetv2 import mobilenet_v2
6 | from .block.fpn import FPN
7 | from .block.vertical import VerticalFusion
8 | from .block.convs import ConvBnRelu, DsBnRelu
9 | from .util import init_method
10 | from .block.heads import FCNHead, GatedResidualUpHead
11 |
12 |
13 | def get_backbone(backbone_name):
14 | if backbone_name == 'mobilenetv2':
15 | backbone = mobilenet_v2(pretrained=True, progress=True)
16 | backbone.channels = [16, 24, 32, 96, 320]
17 | elif backbone_name == 'resnet18d':
18 | backbone = timm.create_model('resnet18d', pretrained=True, features_only=True)
19 | backbone.channels = [64, 64, 128, 256, 512]
20 | else:
21 | raise NotImplementedError("BACKBONE [%s] is not implemented!\n" % backbone_name)
22 | return backbone
23 |
24 |
25 | def get_fpn(fpn_name, in_channels, out_channels, deform_groups=4, gamma_mode='SE', beta_mode='contextgatedconv'):
26 | if fpn_name == 'fpn':
27 | fpn = FPN(in_channels, out_channels, deform_groups, gamma_mode, beta_mode)
28 | else:
29 | raise NotImplementedError("FPN [%s] is not implemented!\n" % fpn_name)
30 | return fpn
31 |
32 |
33 | class Detector(nn.Module):
34 | def __init__(self, backbone_name='mobilenetv2', fpn_name='fpn', fpn_channels=128,
35 | deform_groups=4, gamma_mode='SE', beta_mode='contextgatedconv',
36 | num_heads=1, num_points=8, kernel_layers=1, dropout_rate=0.1, init_type='kaiming_normal'):
37 | super().__init__()
38 | self.backbone = get_backbone(backbone_name)
39 | self.fpn = get_fpn(fpn_name, in_channels=self.backbone.channels[-4:], out_channels=fpn_channels,
40 | deform_groups=deform_groups, gamma_mode=gamma_mode, beta_mode=beta_mode)
41 | self.p5_to_p4 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=4,
42 | kernel_layers=kernel_layers)
43 | self.p4_to_p3 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=8,
44 | kernel_layers=kernel_layers)
45 | self.p3_to_p2 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=16,
46 | kernel_layers=kernel_layers)
47 |
48 | self.p5_head = nn.Conv2d(fpn_channels, 2, 1)
49 | self.p4_head = nn.Conv2d(fpn_channels, 2, 1)
50 | self.p3_head = nn.Conv2d(fpn_channels, 2, 1)
51 | self.p2_head = nn.Conv2d(fpn_channels, 2, 1)
52 |         self.project = nn.Sequential(nn.Conv2d(fpn_channels * 4, fpn_channels, 1, bias=False),  # unused in forward
53 | nn.BatchNorm2d(fpn_channels),
54 | nn.ReLU(True)
55 | )
56 | self.head = GatedResidualUpHead(fpn_channels, 2, dropout_rate=dropout_rate)
57 | # init_method(self.fpn, self.p5_to_p4, self.p4_to_p3, self.p3_to_p2, self.p5_head, self.p4_head,
58 | # self.p3_head, self.p2_head, init_type=init_type)
59 |
60 | def forward(self, x1, x2):
61 | ### Extract backbone features
62 |         t1_c1, t1_c2, t1_c3, t1_c4, t1_c5 = self.backbone(x1)
63 |         t2_c1, t2_c2, t2_c3, t2_c4, t2_c5 = self.backbone(x2)
64 | t1_p2, t1_p3, t1_p4, t1_p5 = self.fpn([t1_c2, t1_c3, t1_c4, t1_c5])
65 | t2_p2, t2_p3, t2_p4, t2_p5 = self.fpn([t2_c2, t2_c3, t2_c4, t2_c5])
66 |
67 | diff_p2 = torch.abs(t1_p2 - t2_p2)
68 | diff_p3 = torch.abs(t1_p3 - t2_p3)
69 | diff_p4 = torch.abs(t1_p4 - t2_p4)
70 | diff_p5 = torch.abs(t1_p5 - t2_p5)
71 |
72 | fea_p5 = diff_p5
73 | pred_p5 = self.p5_head(fea_p5)
74 | fea_p4 = self.p5_to_p4(fea_p5, diff_p4)
75 | pred_p4 = self.p4_head(fea_p4)
76 | fea_p3 = self.p4_to_p3(fea_p4, diff_p3)
77 | pred_p3 = self.p3_head(fea_p3)
78 | fea_p2 = self.p3_to_p2(fea_p3, diff_p2)
79 | pred_p2 = self.p2_head(fea_p2)
80 | pred = self.head(fea_p2)
81 |
82 |         pred_p2 = F.interpolate(pred_p2, size=x1.shape[-2:], mode='bilinear', align_corners=False)  # upsample aux heads to the input resolution (256x256 here)
83 |         pred_p3 = F.interpolate(pred_p3, size=x1.shape[-2:], mode='bilinear', align_corners=False)
84 |         pred_p4 = F.interpolate(pred_p4, size=x1.shape[-2:], mode='bilinear', align_corners=False)
85 |         pred_p5 = F.interpolate(pred_p5, size=x1.shape[-2:], mode='bilinear', align_corners=False)
86 |
87 | return pred, pred_p2, pred_p3, pred_p4, pred_p5
88 |
89 |
90 |
91 |
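92 | if __name__ == '__main__':
93 |     # Forward-pass sketch (an illustration, not repo code): assumes mmcv is installed
94 |     # and downloads the ImageNet mobilenetv2 weights on first run.
95 |     x1 = torch.randn(1, 3, 256, 256)
96 |     x2 = torch.randn(1, 3, 256, 256)
97 |     model = Detector(backbone_name='mobilenetv2').eval()
98 |     with torch.no_grad():
99 |         pred, p2, p3, p4, p5 = model(x1, x2)
100 |     print(pred.shape, p2.shape)  # both should be (1, 2, 256, 256) change logits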
--------------------------------------------------------------------------------
/model/create_model.py:
--------------------------------------------------------------------------------
1 | from .network import Detector
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | import os
7 | import torch.optim as optim
8 | from .block.schedular import get_cosine_schedule_with_warmup
9 | from .loss.focal import FocalLoss
10 | from .loss.dice import DICELoss
11 |
12 | def get_model(backbone_name='mobilenetv2', fpn_name='fpn', fpn_channels=128, deform_groups=4,
13 | gamma_mode='SE', beta_mode='contextgatedconv', num_heads=1, num_points=8, kernel_layers=1,
14 | dropout_rate=0.1, init_type='kaiming_normal'):
15 | detector = Detector(backbone_name, fpn_name, fpn_channels, deform_groups, gamma_mode, beta_mode,
16 | num_heads, num_points, kernel_layers, dropout_rate, init_type)
17 | print(detector)
18 |
19 | return detector
20 |
21 |
22 | class Model(nn.Module):
23 | def __init__(self, opt):
24 | super(Model, self).__init__()
25 | self.device = torch.device("cuda:%s" % opt.gpu_ids[0] if torch.cuda.is_available() else "cpu")
26 | self.opt = opt
27 | self.base_lr = opt.lr
28 | self.save_dir = os.path.join(opt.checkpoint_dir, opt.name)
29 | os.makedirs(self.save_dir, exist_ok=True)
30 |
31 | self.detector = get_model(backbone_name=opt.backbone, fpn_name=opt.fpn, fpn_channels=opt.fpn_channels,
32 | deform_groups=opt.deform_groups, gamma_mode=opt.gamma_mode, beta_mode=opt.beta_mode,
33 | num_heads=opt.num_heads, num_points=opt.num_points, kernel_layers=opt.kernel_layers,
34 | dropout_rate=opt.dropout_rate, init_type=opt.init_type)
35 | self.focal = FocalLoss(alpha=opt.alpha, gamma=opt.gamma)
36 | self.dice = DICELoss()
37 |
38 | self.optimizer = optim.AdamW(self.detector.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
39 |         # Here, 625 = #training images // batch_size, i.e. optimizer steps per epoch; adjust for other datasets.
40 | self.schedular = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=625 * opt.warmup_epochs,
41 | num_training_steps=625 * opt.num_epochs)
42 | if opt.load_pretrain:
43 | self.load_ckpt(self.detector, self.optimizer, opt.name, opt.backbone)
44 |         self.detector.to(self.device)
45 |
46 | print("---------- Networks initialized -------------")
47 |
48 | def forward(self, x1, x2, label):
49 | pred, pred_p2, pred_p3, pred_p4, pred_p5 = self.detector(x1, x2)
50 | label = label.long()
51 | focal = self.focal(pred, label)
52 | dice = self.dice(pred, label)
53 | p2_loss = self.focal(pred_p2, label) * 0.5 + self.dice(pred_p2, label)
54 | p3_loss = self.focal(pred_p3, label) * 0.5 + self.dice(pred_p3, label)
55 | p4_loss = self.focal(pred_p4, label) * 0.5 + self.dice(pred_p4, label)
56 | p5_loss = self.focal(pred_p5, label) * 0.5 + self.dice(pred_p5, label)
57 |
58 | return focal, dice, p2_loss, p3_loss, p4_loss, p5_loss
59 |
60 | def inference(self, x1, x2):
61 | with torch.no_grad():
62 | pred, _, _, _, _ = self.detector(x1, x2)
63 | return pred
64 |
65 | def load_ckpt(self, network, optimizer, name, backbone):
66 | save_filename = '%s_%s_best.pth' % (name, backbone)
67 | save_path = os.path.join(self.save_dir, save_filename)
68 | if not os.path.isfile(save_path):
69 | print("%s not exists yet!" % save_path)
70 | raise ("%s must exist!" % save_filename)
71 | else:
72 | checkpoint = torch.load(save_path, map_location=self.device)
73 |             network.load_state_dict(checkpoint['network'], strict=False)
74 |
75 | def save_ckpt(self, network, optimizer, model_name, backbone):
76 | save_filename = '%s_%s_best.pth' % (model_name, backbone)
77 | save_path = os.path.join(self.save_dir, save_filename)
78 | if os.path.exists(save_path):
79 | os.remove(save_path)
80 | torch.save({'network': network.cpu().state_dict(),
81 | 'optimizer': optimizer.state_dict()},
82 | save_path)
83 | if torch.cuda.is_available():
84 | network.cuda()
85 |
86 | def save(self, model_name, backbone):
87 | self.save_ckpt(self.detector, self.optimizer, model_name, backbone)
88 |
89 | def name(self):
90 | return self.opt.name
91 |
92 |
93 | def create_model(opt):
94 | model = Model(opt)
95 | print("model [%s] was created" % model.name())
96 |
97 |     return model.to(model.device)
98 |
99 |
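100 | if __name__ == '__main__':
101 |     # Construction sketch with hypothetical values: each field below is one that
102 |     # Model.__init__ actually reads from opt; see option.py for the real defaults.
103 |     from argparse import Namespace
104 |     opt = Namespace(gpu_ids=[0], lr=5e-4, weight_decay=5e-3, checkpoint_dir='./checkpoints',
105 |                     name='demo', backbone='mobilenetv2', fpn='fpn', fpn_channels=128,
106 |                     deform_groups=4, gamma_mode='SE', beta_mode='contextgatedconv',
107 |                     num_heads=1, num_points=8, kernel_layers=1, dropout_rate=0.1,
108 |                     init_type='kaiming_normal', alpha=0.25, gamma=2.0,
109 |                     warmup_epochs=1, num_epochs=50, load_pretrain=False)
110 |     model = create_model(opt)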
--------------------------------------------------------------------------------
/SCD/model/loss/focal.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from einops import rearrange
11 | from typing import Optional, List
12 | from functools import partial
13 | from torch.nn.modules.loss import _Loss
14 | import numpy as np
15 |
16 |
17 | class BinaryFocalLoss(nn.Module):
18 | def __init__(self, alpha=0.25, gamma=4.0):
19 | super(BinaryFocalLoss, self).__init__()
20 | self.alpha = alpha
21 | self.gamma = gamma
22 | if isinstance(alpha, (float, int)):
23 | self.alpha = torch.as_tensor([alpha, 1 - alpha])
24 | if isinstance(alpha, list):
25 | self.alpha = torch.as_tensor(alpha)
26 |
27 | def forward(self, input, target):
28 | N, C, H, W = input.size()
29 | assert C == 2
30 | # input = input.view(N, C, -1)
31 | # input = input.transpose(1, 2)
32 | # input = input.contiguous().view(-1, C)
33 | input = rearrange(input, 'b c h w -> (b h w) c')
34 | # input = input.contiguous().view(-1)
35 |
36 | target = target.view(-1, 1)
37 | logpt = F.log_softmax(input, dim=1)
38 | logpt = logpt.gather(1, target)
39 | logpt = logpt.view(-1)
40 |         pt = logpt.detach().exp()
41 |
42 |         if self.alpha is not None:
43 |             if self.alpha.device != input.device or self.alpha.dtype != input.dtype:
44 |                 self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
45 |             at = self.alpha.gather(0, target.view(-1))
46 |             logpt = logpt * at
47 | loss = -1 * (1-pt)**self.gamma * logpt
48 |
49 | return loss.mean()
50 |
51 |
52 | class FocalLoss(_Loss):
53 | def __init__(
54 | self,
55 | alpha: Optional[float] = 0.25,
56 | gamma: Optional[float] = 2.0,
57 | ignore_index: Optional[int] = None,
58 | normalized: bool = False,
59 | reduced_threshold: Optional[float] = None,
60 | ):
61 | super(FocalLoss, self).__init__()
62 | self.ignore_index = ignore_index
63 | self.focal_loss_fn = partial(
64 | focal_loss_with_logits,
65 | alpha=alpha,
66 | gamma=gamma,
67 | reduced_threshold=reduced_threshold,
68 | reduction='mean',
69 | normalized=normalized
70 | )
71 |
72 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
73 | num_classes = y_pred.size(1)
74 | loss = 0
75 |
76 | # Filter anchors with -1 label from loss computation
77 | if self.ignore_index is not None:
78 | not_ignored = y_true != self.ignore_index
79 |
80 | for cls in range(num_classes):
81 | cls_y_true = (y_true == cls).long()
82 | cls_y_pred = y_pred[:, cls, ...]
83 |
84 | if self.ignore_index is not None:
85 | cls_y_true = cls_y_true[not_ignored]
86 | cls_y_pred = cls_y_pred[not_ignored]
87 |
88 | loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
89 |
90 | return loss
91 |
92 |
93 | def focal_loss_with_logits(
94 | output: torch.Tensor,
95 | target: torch.Tensor,
96 | gamma: float = 2.0,
97 | alpha: Optional[float] = 0.25,
98 | reduction: str = 'mean',
99 | normalized: bool = True,
100 | reduced_threshold: Optional[float] = None,
101 | eps: float = 1e-6
102 | ) -> torch.Tensor:
103 |
104 |     target = target.to(dtype=output.dtype, device=output.device)
105 | logpt = F.binary_cross_entropy_with_logits(output, target, reduction='none')
106 | pt = torch.exp(-logpt)
107 |
108 | # compute the loss
109 | if reduced_threshold is None:
110 | focal_term = (1.0 - pt).pow(gamma)
111 | else:
112 | focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
113 | focal_term[pt < reduced_threshold] = 1
114 |
115 | loss = focal_term * logpt
116 |
117 | if alpha is not None:
118 | loss *= alpha * target + (1 - alpha) * (1 - target)
119 |
120 | if normalized:
121 | norm_factor = focal_term.sum().clamp_min(eps)
122 | loss /= norm_factor
123 |
124 | if reduction == 'mean':
125 | loss = loss.mean()
126 | if reduction == 'sum':
127 | loss = loss.sum()
128 | if reduction == 'batchwise_mean':
129 | loss = loss.sum(0)
130 |
131 | return loss
132 |
133 |
134 |
135 | if __name__ == '__main__':
136 | x = torch.randn(3, 7, 256, 256)
137 | y = torch.ones(3, 256, 256).long()
138 | model = FocalLoss(alpha=0.25, gamma=2)
139 | out = model(x, y)
140 | print(out)
141 |
--------------------------------------------------------------------------------
/SCD/model/loss/dice.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.data
5 | from kornia.losses import dice_loss
6 | from torch.autograd import Variable
7 | from typing import Optional, List
8 | from functools import partial
9 | from torch.nn.modules.loss import _Loss
10 | import numpy as np
11 |
12 | class BinaryDICELoss(nn.Module):
13 | def __init__(self):
14 | super(BinaryDICELoss, self).__init__()
15 |
16 | def forward(self, input, target):
17 | target = target.squeeze(1)
18 | loss = dice_loss(input, target)
19 |
20 | return loss
21 |
22 |
23 | ### version 2
24 | """
25 | class BinaryDICELoss(nn.Module):
26 | def __init__(self, eps=1e-5):
27 | super(BinaryDICELoss, self).__init__()
28 | self.eps = eps
29 |
30 | def to_one_hot(self, target):
31 | N, C, H, W = target.size()
32 | assert C == 1
33 | target = torch.zeros(N, 2, H, W).to(target.device).scatter_(1, target, 1)
34 | return target
35 |
36 | def forward(self, input, target):
37 | N, C, _, _ = input.size()
38 | input = F.softmax(input, dim=1)
39 |
40 | #target = self.to_one_hot(target)
41 | target = torch.eye(2)[target.squeeze(1)]
42 | target = target.permute(0, 3, 1, 2).type_as(input)
43 |
44 | dims = tuple(range(1, target.ndimension()))
45 | inter = torch.sum(input * target, dims)
46 | cardinality = torch.sum(input + target, dims)
47 | loss = ((2. * inter) / (cardinality + self.eps)).mean()
48 |
49 | return 1 - loss
50 | """
51 |
52 | class DICELoss(_Loss):
53 | def __init__(
54 | self,
55 | smooth: float = 0.0,
56 | ignore_index: Optional[int] = None,
57 | eps: float = 1e-7,
58 | ):
59 | super(DICELoss, self).__init__()
60 | self.smooth = smooth
61 | self.eps = eps
62 | self.ignore_index = ignore_index
63 |
64 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
65 |
66 | assert y_true.size(0) == y_pred.size(0)
67 | y_pred = y_pred.log_softmax(dim=1).exp()
68 |
69 | bs = y_true.size(0)
70 | num_classes = y_pred.size(1)
71 | dims = (0, 2)
72 |
73 | y_true = y_true.view(bs, -1)
74 | y_pred = y_pred.view(bs, num_classes, -1)
75 |
76 | if self.ignore_index is not None:
77 | mask = y_true != self.ignore_index
78 | y_pred = y_pred * mask.unsqueeze(1)
79 |
80 | y_true = F.one_hot((y_true * mask).to(torch.long), num_classes) # N,H*W -> N,H*W, C
81 | y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W
82 | else:
83 | y_true = F.one_hot(y_true.to(torch.long), num_classes) # N,H*W -> N,H*W, C
84 | y_true = y_true.permute(0, 2, 1) # N, C, H*W
85 |
86 | scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)
87 | loss = 1.0 - scores
88 |
89 | mask = y_true.sum(dims) > 0
90 | loss *= mask.to(loss.dtype)
91 |
92 | return self.aggregate_loss(loss)
93 |
94 | def aggregate_loss(self, loss):
95 | return loss.mean()
96 |
97 | def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
98 | return soft_dice_score(output, target, smooth, eps, dims)
99 |
100 |
101 | def to_tensor(x, dtype=None) -> torch.Tensor:
102 | if isinstance(x, torch.Tensor):
103 | if dtype is not None:
104 | x = x.type(dtype)
105 | return x
106 | if isinstance(x, np.ndarray):
107 | x = torch.from_numpy(x)
108 | if dtype is not None:
109 | x = x.type(dtype)
110 | return x
111 | if isinstance(x, (list, tuple)):
112 | x = np.array(x)
113 | x = torch.from_numpy(x)
114 | if dtype is not None:
115 | x = x.type(dtype)
116 | return x
117 |
118 |
119 | def soft_dice_score(
120 | output: torch.Tensor,
121 | target: torch.Tensor,
122 | smooth: float = 0.0,
123 | eps: float = 1e-7,
124 | dims=None,
125 | ) -> torch.Tensor:
126 | assert output.size() == target.size()
127 | if dims is not None:
128 | intersection = torch.sum(output * target, dim=dims)
129 | cardinality = torch.sum(output + target, dim=dims)
130 | else:
131 | intersection = torch.sum(output * target)
132 | cardinality = torch.sum(output + target)
133 | dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
134 |
135 | return dice_score
136 |
137 |
138 | if __name__ == '__main__':
139 | x = torch.randn(3, 7, 256, 256)
140 | y = torch.zeros(3, 256, 256).long()
141 | model = DICELoss()
142 | output = model(x, y)
143 | print(output)
144 |
--------------------------------------------------------------------------------
/SCD/model/block/vertical.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .convs import ConvBnRelu, DsBnRelu, PyConv2d
5 | from mmcv.ops import MultiScaleDeformableAttention
6 | from einops import rearrange
7 | from torch import einsum
8 |
9 |
10 | class ScaledSinuEmbedding(nn.Module):
11 | def __init__(self, dim):
12 | super().__init__()
13 | self.scale = nn.Parameter(torch.ones(1,))
14 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
15 | self.register_buffer('inv_freq', inv_freq)
16 |
17 | def forward(self, x):
18 | n, device = x.shape[1], x.device
19 | t = torch.arange(n, device=device).type_as(self.inv_freq)
20 | sinu = einsum('i, j -> i j', t, self.inv_freq)
21 | emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
22 | return emb * self.scale
23 |
24 |
25 | def get_reference_points(spatial_shapes, device):
26 | reference_points_list = []
27 | for lvl, (H_, W_) in enumerate(spatial_shapes):
28 | ref_y, ref_x = torch.meshgrid(
29 | torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
30 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
31 | ref_y = ref_y.reshape(-1)[None] / H_
32 | ref_x = ref_x.reshape(-1)[None] / W_
33 | ref = torch.stack((ref_x, ref_y), -1)
34 | reference_points_list.append(ref)
35 | reference_points = torch.cat(reference_points_list, 1)
36 | reference_points = reference_points[:, :, None]
37 | return reference_points
38 |
39 |
40 | class VerticalFusion(nn.Module):
41 | def __init__(self, channels, num_heads=4, num_points=4, kernel_layers=1, up_kernel_size=5, enc_kernel_size=3):
42 | super(VerticalFusion, self).__init__()
43 | self.norm1 = nn.LayerNorm(channels)
44 | self.norm2 = nn.LayerNorm(channels)
45 | self.pos = ScaledSinuEmbedding(channels)
46 | self.crossattn = MultiScaleDeformableAttention(embed_dims=channels, num_levels=1, num_heads=num_heads,
47 | num_points=num_points, batch_first=True, dropout=0)
48 | convs = []
49 | convs.append(ConvBnRelu(in_channels=channels, out_channels=channels))
50 | for _ in range(kernel_layers - 1):
51 | convs.append(ConvBnRelu(in_channels=channels, out_channels=channels))
52 | self.convs = nn.Sequential(*convs)
53 | self.enc = ConvBnRelu(channels, up_kernel_size ** 2, kernel_size=enc_kernel_size,
54 | stride=1, padding=enc_kernel_size // 2, dilation=1)
55 |
56 | self.upsmp = nn.Upsample(scale_factor=2, mode='nearest')
57 | self.unfold = nn.Unfold(kernel_size=up_kernel_size, dilation=2,
58 | padding=up_kernel_size // 2 * 2)
59 |
60 | def get_deform_inputs(self, x1, x2):
61 | _, _, H1, W1 = x1.size()
62 | _, _, H2, W2 = x2.size()
63 | spatial_shapes = torch.as_tensor([(H2, W2)], dtype=torch.long, device=x2.device)
64 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
65 | reference_points = get_reference_points([(H1, W1)], x1.device)
66 |
67 | return reference_points, spatial_shapes, level_start_index
68 |
69 | def forward(self, x1, x2):
70 | reference_points, spatial_shapes, level_start_index = self.get_deform_inputs(x1, x2)
71 | B, C, H, W = x1.size()
72 | _, _, H2, W2 = x2.size()
73 | x1_, x2_ = x1.clone(), x2.clone()
74 | x1 = rearrange(x1, 'b c h w -> b (h w) c')
75 | x2 = rearrange(x2, 'b c h w -> b (h w) c')
76 | x1, x2 = self.norm1(x1), self.norm2(x2)
77 | query_pos = self.pos(x1)
78 | x = self.crossattn(query=x1, value=x2, reference_points=reference_points, spatial_shapes=spatial_shapes,
79 | level_start_index=level_start_index, query_pos=query_pos)
80 | x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
81 | kernel = self.convs(x2_)
82 | kernel = self.enc(kernel)
83 | kernel = F.softmax(kernel, dim=1)
84 | # x = self.upsmp(x)
85 | x = F.interpolate(x, size=(H2, W2), mode='nearest')
86 | x = self.unfold(x)
87 | # x = x.view(B, C, -1, H * 2, W * 2)
88 | x = x.view(B, C, -1, H2, W2)
89 |         fuse = torch.einsum('bkhw,bckhw->bchw', kernel, x)
90 | fuse += x2_
91 |
92 | return fuse
93 |
94 |
95 | # from thop import profile
96 | #
97 | #
98 | # x1 = torch.randn(2, 128, 32, 32)
99 | # x2 = torch.randn(2, 128, 64, 64)
100 | # model = VerticalFusion(channels=128)
101 | # y = model(x1, x2)
102 | # print(y.shape)
103 | # flops, params = profile(model, inputs=([x1, x2]))
104 | # print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
105 | # print('Params = ' + str(params / 1000 ** 2) + 'M')
106 |
--------------------------------------------------------------------------------
/model/backbone/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.hub import load_state_dict_from_url
3 |
4 | model_urls = {
5 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
6 | }
7 |
8 |
9 | class ConvBNReLU(nn.Sequential):
10 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1):
11 | padding = (kernel_size - 1) // 2
12 | if dilation != 1:
13 | padding = dilation
14 | super(ConvBNReLU, self).__init__(
15 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation,
16 | bias=False),
17 | nn.BatchNorm2d(out_planes),
18 | nn.ReLU6(inplace=True)
19 | )
20 |
21 |
22 | class InvertedResidual(nn.Module):
23 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
24 | super(InvertedResidual, self).__init__()
25 | self.stride = stride
26 | assert stride in [1, 2]
27 |
28 | hidden_dim = int(round(inp * expand_ratio))
29 | self.use_res_connect = self.stride == 1 and inp == oup
30 |
31 | layers = []
32 | if expand_ratio != 1:
33 | # pw
34 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
35 | layers.extend([
36 | # dw
37 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation),
38 | # pw-linear
39 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
40 | nn.BatchNorm2d(oup),
41 | ])
42 | self.conv = nn.Sequential(*layers)
43 |
44 | def forward(self, x):
45 | if self.use_res_connect:
46 | return x + self.conv(x)
47 | else:
48 | return self.conv(x)
49 |
50 |
51 | class MobileNetV2(nn.Module):
52 | def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0, replace_stride_with_dilation=False):
53 | super(MobileNetV2, self).__init__()
54 | block = InvertedResidual
55 | input_channel = 32
56 | last_channel = 1280
57 | # inverted_residual_setting = [
58 | # # t, c, n, s, d
59 | # [1, 16, 1, 1, 1],
60 | # [6, 24, 2, 2, 1],
61 | # [6, 32, 3, 2, 1],
62 | # [6, 64, 4, 2, 1],
63 | # [6, 96, 3, 1, 1],
64 | # [6, 160, 3, 2, 1],
65 | # [6, 320, 1, 1, 1],
66 | # ]
67 | inverted_residual_setting = [
68 | # t, c, n, s, d
69 | [1, 16, 1, 1, 1],
70 | [6, 24, 2, 2, 1],
71 | [6, 32, 3, 2, 1],
72 | [6, 64, 4, 1, 2] if replace_stride_with_dilation else [6, 64, 4, 2, 1],
73 | [6, 96, 3, 1, 1],
74 | [6, 160, 3, 1, 2] if replace_stride_with_dilation else [6, 160, 3, 2, 1],
75 | [6, 320, 1, 1, 1],
76 | ]
77 | self.channels = [16, 24, 32, 96, 320]
78 |
79 | # building first layer
80 | input_channel = int(input_channel * width_mult)
81 | self.last_channel = int(last_channel * max(1.0, width_mult))
82 | features = [ConvBNReLU(3, input_channel, stride=2)]
83 | # building inverted residual blocks
84 | for t, c, n, s, d in inverted_residual_setting:
85 | output_channel = int(c * width_mult)
86 | for i in range(n):
87 | stride = s if i == 0 else 1
88 |                 dilation = d if i == 0 else 1  # note: currently unused; every block below receives the stage dilation d
89 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d))
90 | input_channel = output_channel
91 | # building last several layers
92 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
93 | # make it nn.Sequential
94 | self.features = nn.Sequential(*features)
95 |
96 | # weight initialization
97 | for m in self.modules():
98 | if isinstance(m, nn.Conv2d):
99 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
100 | if m.bias is not None:
101 | nn.init.zeros_(m.bias)
102 | elif isinstance(m, nn.BatchNorm2d):
103 | nn.init.ones_(m.weight)
104 | nn.init.zeros_(m.bias)
105 | elif isinstance(m, nn.Linear):
106 | nn.init.normal_(m.weight, 0, 0.01)
107 | nn.init.zeros_(m.bias)
108 |
109 | def forward(self, x):
110 | res = []
111 | for idx, m in enumerate(self.features):
112 | x = m(x)
113 | if idx in [1, 3, 6, 13, 17]:
114 | res.append(x)
115 | return res
116 |
117 |
118 | def mobilenet_v2(pretrained=True, progress=True, replace_stride_with_dilation=False, **kwargs):
119 | model = MobileNetV2(replace_stride_with_dilation=replace_stride_with_dilation, **kwargs)
120 | if pretrained:
121 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
122 | progress=progress)
123 | print("loading imagenet pretrained mobilenetv2")
124 | model.load_state_dict(state_dict, strict=False)
125 | print("loaded imagenet pretrained mobilenetv2")
126 | return model
127 |
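128 | 
129 | if __name__ == '__main__':
130 |     # Feature-shape check; a runnable version of the commented example in
131 |     # SCD/model/backbone/mobilenetv2.py (pretrained=False avoids the weight download).
132 |     import torch
133 |     model = mobilenet_v2(pretrained=False)
134 |     for feat in model(torch.randn(2, 3, 256, 256)):
135 |         print(feat.shape)  # 16/24/32/96/320 channels at strides 2/4/8/16/32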
--------------------------------------------------------------------------------
/model/block/fpn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from .dcnv2 import DCNv2
6 | from .convs import GatedConv2d, ContextGatedConv2d
7 |
8 |
9 | class GenerateGamma(nn.Module):
10 | def __init__(self, channels=128, mode='SE'):
11 | super(GenerateGamma, self).__init__()
12 | self.mode = mode
13 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
14 | self.max_pool = nn.AdaptiveMaxPool2d(1)
15 |
16 | self.fc = nn.Sequential(nn.Conv2d(channels, channels // 4, 1, bias=False),
17 | nn.ReLU(True),
18 | nn.Conv2d(channels // 4, channels, 1, bias=False))
19 | self.sigmoid = nn.Sigmoid()
20 |
21 | def forward(self, x):
22 | avg_out = self.fc(self.avg_pool(x))
23 | if self.mode == 'SE':
24 | return self.sigmoid(avg_out)
25 | elif self.mode == 'CBAM':
26 | max_out = self.fc(self.max_pool(x))
27 | out = avg_out + max_out
28 | return self.sigmoid(out)
29 | else:
30 | raise NotImplementedError
31 |
32 |
33 | class GenerateBeta(nn.Module):
34 | def __init__(self, channels=128, mode='conv'):
35 | super(GenerateBeta, self).__init__()
36 | self.stem = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1, bias=True), nn.ReLU(True))
37 | if mode == 'conv':
38 | self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
39 | elif mode == 'gatedconv':
40 | self.conv = GatedConv2d(channels, channels, 3, padding=1, bias=True)
41 | elif mode == 'contextgatedconv':
42 | self.conv = ContextGatedConv2d(channels, channels, 3, padding=1, bias=True)
43 | else:
44 | raise NotImplementedError
45 |
46 | def forward(self, x):
47 | x = self.stem(x)
48 | return self.conv(x)
49 |
50 |
51 | ### MoFPN
52 | class FPN(nn.Module):
53 | def __init__(self, in_channels, out_channels=128, deform_groups=4, gamma_mode='SE', beta_mode='contextgatedconv'):
54 | super(FPN, self).__init__()
55 |
56 | self.p2 = DCNv2(in_channels=in_channels[0], out_channels=out_channels,
57 | kernel_size=3, padding=1, deform_groups=deform_groups)
58 | self.p3 = DCNv2(in_channels=in_channels[1], out_channels=out_channels,
59 | kernel_size=3, padding=1, deform_groups=deform_groups)
60 | self.p4 = DCNv2(in_channels=in_channels[2], out_channels=out_channels,
61 | kernel_size=3, padding=1, deform_groups=deform_groups)
62 | self.p5 = DCNv2(in_channels=in_channels[3], out_channels=out_channels,
63 | kernel_size=3, padding=1, deform_groups=deform_groups)
64 |
65 | self.p5_bn = nn.BatchNorm2d(out_channels, affine=True)
66 |         self.p4_bn = nn.BatchNorm2d(out_channels, affine=False)  # affine off: scale/shift come from the gamma/beta predicted off the coarser level
67 | self.p3_bn = nn.BatchNorm2d(out_channels, affine=False)
68 | self.p2_bn = nn.BatchNorm2d(out_channels, affine=False)
69 | self.activation = nn.ReLU(True)
70 |
71 |         self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # unused; forward interpolates to explicit sizes
72 | self.p4_Gamma = GenerateGamma(out_channels, mode=gamma_mode)
73 | self.p4_beta = GenerateBeta(out_channels, mode=beta_mode)
74 | self.p3_Gamma = GenerateGamma(out_channels, mode=gamma_mode)
75 | self.p3_beta = GenerateBeta(out_channels, mode=beta_mode)
76 | self.p2_Gamma = GenerateGamma(out_channels, mode=gamma_mode)
77 | self.p2_beta = GenerateBeta(out_channels, mode=beta_mode)
78 |
79 | self.p5_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
80 | self.p4_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
81 | self.p3_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
82 | self.p2_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
83 |     def forward(self, inputs):
84 |         c2, c3, c4, c5 = inputs
85 |
86 | p5 = self.activation(self.p5_bn(self.p5(c5)))
87 | p5_up = F.interpolate(p5, size=c4.shape[-2:], mode='bilinear', align_corners=False)
88 | p4 = self.p4_bn(self.p4(c4))
89 | p4_gamma, p4_beta = self.p4_Gamma(p5_up), self.p4_beta(p5_up)
90 | p4 = self.activation(p4 * (1 + p4_gamma) + p4_beta)
91 | p4_up = F.interpolate(p4, size=c3.shape[-2:], mode='bilinear', align_corners=False)
92 | p3 = self.p3_bn(self.p3(c3))
93 | p3_gamma, p3_beta = self.p3_Gamma(p4_up), self.p3_beta(p4_up)
94 | p3 = self.activation(p3 * (1 + p3_gamma) + p3_beta)
95 | p3_up = F.interpolate(p3, size=c2.shape[-2:], mode='bilinear', align_corners=False)
96 | p2 = self.p2_bn(self.p2(c2))
97 | p2_gamma, p2_beta = self.p2_Gamma(p3_up), self.p2_beta(p3_up)
98 | p2 = self.activation(p2 * (1 + p2_gamma) + p2_beta)
99 |
100 | p5 = self.p5_smooth(p5)
101 | p4 = self.p4_smooth(p4)
102 | p3 = self.p3_smooth(p3)
103 | p2 = self.p2_smooth(p2)
104 |
105 | return p2, p3, p4, p5
106 |
107 |
108 |
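109 | if __name__ == '__main__':
110 |     # Shape sketch (not repo code): channel sizes mirror mobilenet_v2.channels[-4:]
111 |     # as wired in model/network.py; DCNv2 may require a CUDA-enabled mmcv build.
112 |     c2, c3 = torch.randn(2, 24, 64, 64), torch.randn(2, 32, 32, 32)
113 |     c4, c5 = torch.randn(2, 96, 16, 16), torch.randn(2, 320, 8, 8)
114 |     fpn = FPN(in_channels=[24, 32, 96, 320], out_channels=128)
115 |     for p in fpn([c2, c3, c4, c5]):
116 |         print(p.shape)  # 128 channels at 64/32/16/8 pixels for p2/p3/p4/p5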
--------------------------------------------------------------------------------
/SCD/model/backbone/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.hub import load_state_dict_from_url
3 |
4 | model_urls = {
5 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
6 | }
7 |
8 |
9 | class ConvBNReLU(nn.Sequential):
10 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1):
11 | padding = (kernel_size - 1) // 2
12 | if dilation != 1:
13 | padding = dilation
14 | super(ConvBNReLU, self).__init__(
15 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation,
16 | bias=False),
17 | nn.BatchNorm2d(out_planes),
18 | nn.ReLU6(inplace=True)
19 | )
20 |
21 |
22 | class InvertedResidual(nn.Module):
23 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
24 | super(InvertedResidual, self).__init__()
25 | self.stride = stride
26 | assert stride in [1, 2]
27 |
28 | hidden_dim = int(round(inp * expand_ratio))
29 | self.use_res_connect = self.stride == 1 and inp == oup
30 |
31 | layers = []
32 | if expand_ratio != 1:
33 | # pw
34 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
35 | layers.extend([
36 | # dw
37 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation),
38 | # pw-linear
39 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
40 | nn.BatchNorm2d(oup),
41 | ])
42 | self.conv = nn.Sequential(*layers)
43 |
44 | def forward(self, x):
45 | if self.use_res_connect:
46 | return x + self.conv(x)
47 | else:
48 | return self.conv(x)
49 |
50 |
51 | class MobileNetV2(nn.Module):
52 | def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0, replace_stride_with_dilation=False):
53 | super(MobileNetV2, self).__init__()
54 | block = InvertedResidual
55 | input_channel = 32
56 | last_channel = 1280
57 | # inverted_residual_setting = [
58 | # # t, c, n, s, d
59 | # [1, 16, 1, 1, 1],
60 | # [6, 24, 2, 2, 1],
61 | # [6, 32, 3, 2, 1],
62 | # [6, 64, 4, 2, 1],
63 | # [6, 96, 3, 1, 1],
64 | # [6, 160, 3, 2, 1],
65 | # [6, 320, 1, 1, 1],
66 | # ]
67 | inverted_residual_setting = [
68 | # t, c, n, s, d
69 | [1, 16, 1, 1, 1],
70 | [6, 24, 2, 2, 1],
71 | [6, 32, 3, 2, 1],
72 | [6, 64, 4, 1, 2] if replace_stride_with_dilation else [6, 64, 4, 2, 1],
73 | [6, 96, 3, 1, 1],
74 | [6, 160, 3, 1, 2] if replace_stride_with_dilation else [6, 160, 3, 2, 1],
75 | [6, 320, 1, 1, 1],
76 | ]
77 | self.channels = [16, 24, 32, 96, 320]
78 |
79 | # building first layer
80 | input_channel = int(input_channel * width_mult)
81 | self.last_channel = int(last_channel * max(1.0, width_mult))
82 | features = [ConvBNReLU(3, input_channel, stride=2)]
83 | # building inverted residual blocks
84 | for t, c, n, s, d in inverted_residual_setting:
85 | output_channel = int(c * width_mult)
86 | for i in range(n):
87 | stride = s if i == 0 else 1
88 |                 dilation = d if i == 0 else 1  # note: currently unused; every block below receives the stage dilation d
89 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d))
90 | input_channel = output_channel
91 | # building last several layers
92 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
93 | # make it nn.Sequential
94 | self.features = nn.Sequential(*features)
95 |
96 | # weight initialization
97 | for m in self.modules():
98 | if isinstance(m, nn.Conv2d):
99 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
100 | if m.bias is not None:
101 | nn.init.zeros_(m.bias)
102 | elif isinstance(m, nn.BatchNorm2d):
103 | nn.init.ones_(m.weight)
104 | nn.init.zeros_(m.bias)
105 | elif isinstance(m, nn.Linear):
106 | nn.init.normal_(m.weight, 0, 0.01)
107 | nn.init.zeros_(m.bias)
108 |
109 | def forward(self, x):
110 | res = []
111 | for idx, m in enumerate(self.features):
112 | x = m(x)
113 | if idx in [1, 3, 6, 13, 17]:
114 | res.append(x)
115 | return res
116 |
117 |
118 | def mobilenet_v2(pretrained=True, progress=True, replace_stride_with_dilation=False, **kwargs):
119 | model = MobileNetV2(replace_stride_with_dilation=replace_stride_with_dilation, **kwargs)
120 | if pretrained:
121 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
122 | progress=progress)
123 | print("loading imagenet pretrained mobilenetv2")
124 | model.load_state_dict(state_dict, strict=False)
125 | print("loaded imagenet pretrained mobilenetv2")
126 | return model
127 |
128 | # import torch
129 | # model = MobileNetV2(pretrained=True)
130 | # input = torch.randn(2, 3, 256, 256)
131 | # output = model(input)
132 | # for i in output:
133 | # print(i.shape)
--------------------------------------------------------------------------------
/SCD/model/block/heads.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/swz30/MIRNet
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from .convs import GatedConv2d, ContextGatedConv2d
8 |
9 |
10 | class ResidualUpSample(nn.Module):
11 | def __init__(self, in_channels):
12 | super(ResidualUpSample, self).__init__()
13 |
14 | self.top = nn.Sequential(nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1, output_padding=1,
15 | bias=False),
16 | nn.BatchNorm2d(in_channels),
17 | nn.ReLU(True),
18 | nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=True)
19 | )
20 | self.bot = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
21 | nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=True)
22 | )
23 | self.relu = nn.ReLU(True)
24 |
25 | def forward(self, x):
26 | top = self.top(x)
27 | bot = self.bot(x)
28 | out = self.relu(top + bot)
29 | return out
30 |
31 |
32 | class GatedResidualUp(nn.Module):
33 | def __init__(self, in_channels, up_mode='conv', gate_mode='gated'):
34 | super(GatedResidualUp, self).__init__()
35 | if up_mode == 'conv':
36 | self.residual_up = nn.Sequential(nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1,
37 | output_padding=1, bias=False),
38 | nn.BatchNorm2d(in_channels),
39 | nn.ReLU(True))
40 | elif up_mode == 'bilinear':
41 | self.residual_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
42 |
43 | if gate_mode == 'gated':
44 | self.gate = GatedConv2d(in_channels, in_channels // 2)
45 | elif gate_mode == 'context_gated':
46 | self.gate = ContextGatedConv2d(in_channels, in_channels // 2)
47 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
48 | nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=True)
49 | )
50 | self.relu = nn.ReLU(True)
51 |
52 | def forward(self, x):
53 | residual = self.residual_up(x)
54 | residual = self.gate(residual)
55 | up = self.up(x)
56 | out = self.relu(up + residual)
57 | return out
58 |
59 |
60 | class GatedResidualUpHead(nn.Module):
61 | def __init__(self, in_channels=128, num_classes=1, dropout_rate=0.15):
62 | super(GatedResidualUpHead, self).__init__()
63 |
64 | self.up = nn.Sequential(GatedResidualUp(in_channels),
65 | GatedResidualUp(in_channels // 2))
66 | self.smooth = nn.Sequential(nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, stride=1, padding=1),
67 | nn.ReLU(True),
68 | nn.Dropout2d(dropout_rate))
69 | self.final = nn.Conv2d(in_channels // 4, num_classes, 1)
70 |
71 | def forward(self, x):
72 | x = self.up(x)
73 | x = self.smooth(x)
74 | x = self.final(x)
75 |
76 | return x
77 |
78 |
79 | class ResidualUpHead(nn.Module):
80 | def __init__(self, in_channels=128, num_classes=1, dropout_rate=0.15):
81 | super(ResidualUpHead, self).__init__()
82 |
83 | self.up = nn.Sequential(ResidualUpSample(in_channels),
84 | ResidualUpSample(in_channels // 2))
85 | self.smooth = nn.Sequential(nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, stride=1, padding=1),
86 | nn.ReLU(True),
87 | nn.Dropout2d(dropout_rate))
88 | self.final = nn.Conv2d(in_channels // 4, num_classes, 1)
89 |
90 | def forward(self, x):
91 | x = self.up(x)
92 | x = self.smooth(x)
93 | x = self.final(x)
94 |
95 | return x
96 |
97 |
98 | class FCNHead(nn.Module):
99 | def __init__(self, in_channels, num_classes, num_convs=1, dropout_rate=0.15):
100 |         super(FCNHead, self).__init__()
101 |         self.num_convs = num_convs
102 | inter_channels = in_channels // 4
103 |
104 | convs = []
105 | convs.append(nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
106 | nn.BatchNorm2d(inter_channels),
107 | nn.ReLU(True)))
108 |         for _ in range(num_convs - 1):
109 | convs.append(nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
110 | nn.BatchNorm2d(inter_channels),
111 | nn.ReLU(True)))
112 | self.convs = nn.Sequential(*convs)
113 | self.final = nn.Conv2d(inter_channels, num_classes, 1)
114 |
115 | def forward(self, x):
116 | out = self.convs(x)
117 | out = self.final(out)
118 |
119 | return out
120 |
121 |
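122 | if __name__ == '__main__':
123 |     # Shape sketch (a sanity check, not repo code): the head upsamples 4x through two
124 |     # gated residual-up stages, so a 64x64 P2 map becomes a 256x256 prediction.
125 |     x = torch.randn(2, 128, 64, 64)
126 |     head = GatedResidualUpHead(in_channels=128, num_classes=2)
127 |     print(head(x).shape)  # expected: torch.Size([2, 2, 256, 256])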
--------------------------------------------------------------------------------
/SCD/trainval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from option import Options
3 | from data.cd_dataset import DataLoader
4 | from model.create_model import create_model
5 | from tqdm import tqdm
6 | import math
7 | from util.palette import color_map
8 | from util.metric import IOUandSek
9 | import os
10 | import numpy as np
11 | import random
12 | from PIL import Image
13 |
14 |
15 | def setup_seed(seed):
16 | os.environ['PYTHONHASHSEED'] = str(seed)
17 | np.random.seed(seed)
18 | random.seed(seed)
19 |
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 |     torch.backends.cudnn.deterministic = False  # favors speed; set True (and benchmark False) for exact reproducibility
24 | torch.backends.cudnn.benchmark = True
25 | torch.backends.cudnn.enabled = True
26 |
27 |
28 | class Trainval(object):
29 | def __init__(self, opt):
30 | self.opt = opt
31 |
32 | train_loader = DataLoader(opt)
33 | self.train_data = train_loader.load_data()
34 | train_size = len(train_loader)
35 | print("#training images = %d" % train_size)
36 | opt.phase = 'val'
37 | val_loader = DataLoader(opt)
38 | self.val_data = val_loader.load_data()
39 | val_size = len(val_loader)
40 | print("#validation images = %d" % val_size)
41 | opt.phase = 'train'
42 |
43 | self.model = create_model(opt)
44 | self.optimizer = self.model.optimizer
45 | self.schedular = self.model.schedular
46 |
47 | self.iters = 0
48 | self.total_iters = math.ceil(train_size / opt.batch_size) * opt.num_epochs
49 | self.previous_best = 0.0
50 |
51 | def train(self):
52 | tbar = tqdm(self.train_data)
53 |         self.opt.phase = 'train'  # the datasets share this opt object, so this re-enables train-time augmentation
54 | _loss = 0.0
55 | _cd_loss = 0.0
56 | _seg_loss = 0.0
57 |
58 | for i, data in enumerate(tbar):
59 | self.model.detector.train()
60 | cd_loss, seg_loss = self.model(data['img1'].cuda(), data['img2'].cuda(), data['label1'].cuda(),
61 | data['label2'].cuda(), data['cd_label'].cuda())
62 | loss = 2 * cd_loss + seg_loss
63 | self.optimizer.zero_grad()
64 | loss.backward()
65 | self.optimizer.step()
66 | self.schedular.step()
67 | _loss += loss.item()
68 | _cd_loss += cd_loss.item()
69 | _seg_loss += seg_loss.item()
70 | del loss
71 |
72 | #self.iters += 1
73 | #self.model.adjust_learning_rate(self.iters, self.total_iters)
74 |
75 | tbar.set_description("Loss: %.3f, CD: %.3f, Seg: %.3f, LR: %.6f" %
76 | (_loss / (i + 1), _cd_loss / (i + 1), _seg_loss / (i + 1), self.optimizer.param_groups[0]['lr']))
77 |
78 | def val(self):
79 | tbar = tqdm(self.val_data)
80 | metric = IOUandSek(num_classes=7)
81 |             self.opt.phase = 'val'  # disables augmentation in the shared opt
82 | self.model.eval()
83 |
84 | with torch.no_grad():
85 | for i, _data in enumerate(tbar):
86 | cd_out, seg_out1, seg_out2 = self.model.inference(_data['img1'].cuda(), _data['img2'].cuda())
87 | # update metric
88 | val_target = _data['cd_label'].detach()
89 | cd_out = torch.argmax(cd_out.detach(), dim=1)
90 | #val_pred = torch.where(val_pred > 0.5, torch.ones_like(val_pred), torch.zeros_like(val_pred)).long()
91 | seg_out1 = torch.argmax(seg_out1, dim=1).cpu().numpy()
92 | seg_out2 = torch.argmax(seg_out2, dim=1).cpu().numpy()
93 | cd_out = cd_out.cpu().numpy().astype(np.uint8)
94 | seg_out1[cd_out == 0] = 0
95 | seg_out2[cd_out == 0] = 0
96 |
97 |                 if self.opt.save_mask:
98 |                     cmap = color_map(self.opt.dataset)
99 |                     for j in range(seg_out1.shape[0]):  # j avoids shadowing the batch index i
100 |                         mask = Image.fromarray(seg_out1[j].astype(np.uint8), mode="P")
101 |                         mask.putpalette(cmap)
102 |                         os.makedirs(os.path.join(self.opt.result_dir, 'val', 'im1'), exist_ok=True)
103 |                         mask.save(os.path.join(self.opt.result_dir, 'val', 'im1', _data['fname'][j]))
104 | 
105 |                         mask = Image.fromarray(seg_out2[j].astype(np.uint8), mode="P")
106 |                         mask.putpalette(cmap)
107 |                         os.makedirs(os.path.join(self.opt.result_dir, 'val', 'im2'), exist_ok=True)
108 |                         mask.save(os.path.join(self.opt.result_dir, 'val', 'im2', _data['fname'][j]))
109 |
110 | metric.add_batch(seg_out1, _data['label1'].numpy())
111 | metric.add_batch(seg_out2, _data['label2'].numpy())
112 | score, miou, sek = metric.evaluate()
113 | tbar.set_description("Score: %.2f, IOU: %.2f, SeK: %.2f" % (score * 100.0, miou * 100.0, sek * 100.0))
114 |
115 | if score >= self.previous_best:
116 | self.model.save(self.opt.name, self.opt.backbone)
117 | self.previous_best = score
118 |
119 |
120 | if __name__ == "__main__":
121 |     opt = Options().parse()
122 |     setup_seed(seed=1)  # seed before building the loaders and model so initialization is reproducible
123 |     trainval = Trainval(opt)
124 |
125 | for epoch in range(1, opt.num_epochs + 1):
126 | print("\n==> Epoch %i, previous best = %.3f" % (epoch, trainval.previous_best * 100))
127 | trainval.train()
128 | trainval.val()
129 |
130 |
131 |
132 |
--------------------------------------------------------------------------------
/SCD/model/block/fpn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from .dcnv2 import DCNv2
6 | from .convs import GatedConv2d, ContextGatedConv2d
7 |
8 |
9 | class GenerateGamma(nn.Module):
10 | def __init__(self, channels=128, mode='SE'):
11 | super(GenerateGamma, self).__init__()
12 | self.mode = mode
13 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
14 | self.max_pool = nn.AdaptiveMaxPool2d(1)
15 |
16 | self.fc = nn.Sequential(nn.Conv2d(channels, channels // 4, 1, bias=False),
17 | nn.ReLU(True),
18 | nn.Conv2d(channels // 4, channels, 1, bias=False))
19 | self.sigmoid = nn.Sigmoid()
20 |
21 | def forward(self, x):
22 | avg_out = self.fc(self.avg_pool(x))
23 | if self.mode == 'SE':
24 | return self.sigmoid(avg_out)
25 | elif self.mode == 'CBAM':
26 | max_out = self.fc(self.max_pool(x))
27 | out = avg_out + max_out
28 | return self.sigmoid(out)
29 | else:
30 | raise NotImplementedError
31 |
32 |
33 | class GenerateBeta(nn.Module):
34 | def __init__(self, channels=128, mode='conv'):
35 | super(GenerateBeta, self).__init__()
36 | self.stem = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1, bias=True), nn.ReLU(True))
37 | if mode == 'conv':
38 | self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
39 | elif mode == 'gatedconv':
40 | self.conv = GatedConv2d(channels, channels, 3, padding=1, bias=True)
41 | elif mode == 'contextgatedconv':
42 | self.conv = ContextGatedConv2d(channels, channels, 3, padding=1, bias=True)
43 | else:
44 | raise NotImplementedError
45 |
46 | def forward(self, x):
47 | x = self.stem(x)
48 | return self.conv(x)
49 |
50 |
51 | class FPN(nn.Module):
52 | def __init__(self, in_channels, out_channels=128, deform_groups=4, gamma_mode='SE', beta_mode='gatedconv'):
53 | super(FPN, self).__init__()
54 |
55 | self.p2 = DCNv2(in_channels=in_channels[0], out_channels=out_channels,
56 | kernel_size=3, padding=1, deform_groups=deform_groups)
57 | self.p3 = DCNv2(in_channels=in_channels[1], out_channels=out_channels,
58 | kernel_size=3, padding=1, deform_groups=deform_groups)
59 | self.p4 = DCNv2(in_channels=in_channels[2], out_channels=out_channels,
60 | kernel_size=3, padding=1, deform_groups=deform_groups)
61 | self.p5 = DCNv2(in_channels=in_channels[3], out_channels=out_channels,
62 | kernel_size=3, padding=1, deform_groups=deform_groups)
63 |
64 | self.p5_bn = nn.BatchNorm2d(out_channels, affine=True)
66 |         self.p4_bn = nn.BatchNorm2d(out_channels, affine=False)  # affine off: scale/shift come from the gamma/beta predicted off the coarser level
66 | self.p3_bn = nn.BatchNorm2d(out_channels, affine=False)
67 | self.p2_bn = nn.BatchNorm2d(out_channels, affine=False)
68 | self.activation = nn.ReLU(True)
69 |
71 |         self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # unused; forward interpolates to explicit sizes
71 | self.p4_Gamma = GenerateGamma(out_channels, mode=gamma_mode)
72 | self.p4_beta = GenerateBeta(out_channels, mode=beta_mode)
73 | self.p3_Gamma = GenerateGamma(out_channels, mode=gamma_mode)
74 | self.p3_beta = GenerateBeta(out_channels, mode=beta_mode)
75 | self.p2_Gamma = GenerateGamma(out_channels, mode=gamma_mode)
76 | self.p2_beta = GenerateBeta(out_channels, mode=beta_mode)
77 |
78 | self.p5_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
79 | self.p4_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
80 | self.p3_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
81 | self.p2_smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
82 |     def forward(self, inputs):
83 |         c2, c3, c4, c5 = inputs
84 |
85 | p5 = self.activation(self.p5_bn(self.p5(c5)))
86 | p5_up = F.interpolate(p5, size=c4.shape[-2:], mode='bilinear', align_corners=False)
87 | p4 = self.p4_bn(self.p4(c4))
88 | p4_gamma, p4_beta = self.p4_Gamma(p5_up), self.p4_beta(p5_up)
89 | p4 = self.activation(p4 * (1 + p4_gamma) + p4_beta)
90 | p4_up = F.interpolate(p4, size=c3.shape[-2:], mode='bilinear', align_corners=False)
91 | p3 = self.p3_bn(self.p3(c3))
92 | p3_gamma, p3_beta = self.p3_Gamma(p4_up), self.p3_beta(p4_up)
93 | p3 = self.activation(p3 * (1 + p3_gamma) + p3_beta)
94 | p3_up = F.interpolate(p3, size=c2.shape[-2:], mode='bilinear', align_corners=False)
95 | p2 = self.p2_bn(self.p2(c2))
96 | p2_gamma, p2_beta = self.p2_Gamma(p3_up), self.p2_beta(p3_up)
97 | p2 = self.activation(p2 * (1 + p2_gamma) + p2_beta)
98 |
99 | p5 = self.p5_smooth(p5)
100 | p4 = self.p4_smooth(p4)
101 | p3 = self.p3_smooth(p3)
102 | p2 = self.p2_smooth(p2)
103 |
104 | return p2, p3, p4, p5
105 |
106 |
107 | # from thop import profile
108 | #
109 | #
110 | # x1 = torch.randn(2, 512, 8, 8)
111 | # x2 = torch.randn(2, 256, 16, 16)
112 | # x3 = torch.randn(2, 128, 32, 32)
113 | # x4 = torch.randn(2, 64, 64, 64)
114 | # model = FPN(in_channels=[64, 128, 256, 512])
115 | # y = model([x4, x3, x2, x1])
116 | # for i in y:
117 | # print(i.shape)
118 | # flops, params = profile(model, inputs=([x4, x3, x2, x1],))
119 | # print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
120 | # print('Params = ' + str(params / 1000 ** 2) + 'M')
121 |
--------------------------------------------------------------------------------
/SCD/model/create_model.py:
--------------------------------------------------------------------------------
1 | from .network import Detector
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | import os
7 | import torch.optim as optim
8 | from .block.schedular import get_cosine_schedule_with_warmup
9 | # from .loss.bcedice import BCEDiceLoss  # unused, and loss/bcedice.py is not part of this tree
10 | from .loss.focal import BinaryFocalLoss, FocalLoss
11 | from .loss.dice import BinaryDICELoss, DICELoss
12 | from thop import profile
13 |
14 |
15 | def get_model(backbone_name='mobilenetv2', fpn_name='neighbor', fpn_channels=64, deform_groups=4,
16 | gamma_mode='SE', beta_mode='gatedconv', num_heads=1, num_points=8, kernel_layers=1,
17 | dropout_rate=0.1, init_type='kaiming_normal'):
18 | detector = Detector(backbone_name, fpn_name, fpn_channels, deform_groups, gamma_mode, beta_mode,
19 | num_heads, num_points, kernel_layers, dropout_rate, init_type)
20 | print(detector)
21 | input1 = torch.randn(1, 3, 256, 256)
22 | input2 = torch.randn(1, 3, 256, 256)
23 | flops, params = profile(detector, inputs=(input1, input2))
24 | print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
25 | print('Params = ' + str(params / 1000 ** 2) + 'M')
26 |
27 | return detector
28 |
29 |
30 | class Model(nn.Module):
31 | def __init__(self, opt):
32 | super(Model, self).__init__()
33 | self.device = torch.device("cuda:%s" % opt.gpu_ids[0] if torch.cuda.is_available() else "cpu")
34 | self.opt = opt
35 | self.base_lr = opt.lr
36 | self.save_dir = os.path.join(opt.checkpoint_dir, opt.name)
37 | os.makedirs(self.save_dir, exist_ok=True)
38 |
39 | self.detector = get_model(backbone_name=opt.backbone, fpn_name=opt.fpn, fpn_channels=opt.fpn_channels,
40 | deform_groups=opt.deform_groups, gamma_mode=opt.gamma_mode, beta_mode=opt.beta_mode,
41 | num_heads=opt.num_heads, num_points=opt.num_points, kernel_layers=opt.kernel_layers,
42 | dropout_rate=opt.dropout_rate, init_type=opt.init_type)
43 | self.cd_focal = BinaryFocalLoss(alpha=opt.alpha, gamma=opt.gamma)
44 | self.cd_dice = BinaryDICELoss()
45 | self.scd_focal = FocalLoss(ignore_index=0, alpha=opt.alpha, gamma=opt.gamma)
46 | self.scd_dice = DICELoss(ignore_index=0)
47 |
48 | self.optimizer = optim.AdamW(self.detector.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
49 | self.schedular = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=445 * opt.warmup_epochs,
50 |                                                          num_training_steps=445 * opt.num_epochs)  # 445 = #training images // batch_size
51 |
52 | if opt.load_pretrain:
53 | self.load_ckpt(self.detector, self.optimizer, opt.name, opt.backbone)
54 | self.detector.to(self.device)
55 | # self.detector = nn.DataParallel(self.detector, device_ids=opt.gpu_ids)
56 | print("---------- Networks initialized -------------")
57 |
58 | def forward(self, x1, x2, label1, label2, cd_label):
59 | pred, pred_seg1, pred_seg2, pred_p2, pred_p3, pred_p4, pred_p5 = self.detector(x1, x2)
60 | # label = label.unsqueeze(1).float()
61 | cd_label = cd_label.long()
62 | cd_focal = self.cd_focal(pred, cd_label)
63 | cd_dice = self.cd_dice(pred, cd_label)
64 | p2_loss = self.cd_focal(pred_p2, cd_label) * 0.5 + self.cd_dice(pred_p2, cd_label)
65 | p3_loss = self.cd_focal(pred_p3, cd_label) * 0.5 + self.cd_dice(pred_p3, cd_label)
66 | p4_loss = self.cd_focal(pred_p4, cd_label) * 0.5 + self.cd_dice(pred_p4, cd_label)
67 | p5_loss = self.cd_focal(pred_p5, cd_label) * 0.5 + self.cd_dice(pred_p5, cd_label)
68 | cd_loss = cd_focal * 0.5 + cd_dice + p3_loss + p4_loss + p5_loss  # note: p2_loss is computed above but not added to the total
69 | seg_loss = self.scd_focal(pred_seg1, label1) * 0.5 + self.scd_dice(pred_seg1, label1) \
70 | + self.scd_focal(pred_seg2, label2) * 0.5 + self.scd_dice(pred_seg2, label2)
71 |
72 | return cd_loss, seg_loss
73 |
74 | def inference(self, x1, x2):
75 | with torch.no_grad():
76 | pred, pred_seg1, pred_seg2, _, _, _, _ = self.detector(x1, x2)
77 | return pred, pred_seg1, pred_seg2
78 |
79 | def adjust_learning_rate(self, iter, total_iters, min_lr=1e-6, power=0.9):
80 | lr = (self.base_lr - min_lr) * (1 - iter / total_iters) ** power + min_lr  # polynomial (poly) decay toward min_lr
81 | for param_group in self.optimizer.param_groups:
82 | param_group['lr'] = lr
83 |
84 | def load_ckpt(self, network, optimizer, name, backbone):
85 | save_filename = '%s_%s_best.pth' % (name, backbone)
86 | save_path = os.path.join(self.save_dir, save_filename)
87 | if not os.path.isfile(save_path):
88 | print("%s not exists yet!" % save_path)
89 | raise ("%s must exist!" % save_filename)
90 | else:
91 | checkpoint = torch.load(save_path, map_location=self.device)
92 | network.load_state_dict(checkpoint['network'], strict=False)
93 |
94 | def save_ckpt(self, network, optimizer, model_name, backbone):
95 | save_filename = '%s_%s_best.pth' % (model_name, backbone)
96 | save_path = os.path.join(self.save_dir, save_filename)
97 | if os.path.exists(save_path):
98 | os.remove(save_path)
99 | torch.save({'network': network.cpu().state_dict(),
100 | 'optimizer': optimizer.state_dict()},
101 | save_path)
102 | if torch.cuda.is_available():
103 | network.cuda()
104 |
105 | def save(self, model_name, backbone):
106 | self.save_ckpt(self.detector, self.optimizer, model_name, backbone)
107 |
108 | def name(self):
109 | return self.opt.name
110 |
111 |
112 | def create_model(opt):
113 | model = Model(opt)
114 | print("model [%s] was created" % model.name())
115 |
116 | return model.to(model.device)
117 |
118 |
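119 | # Illustrative usage sketch: assumes an `opt` namespace produced by option.py with the
120 | # fields referenced above (lr, gpu_ids, backbone, fpn, checkpoint_dir, name, ...).
121 | #   from SCD.model.create_model import create_model
122 | #   model = create_model(opt)
123 | #   cd_loss, seg_loss = model(x1, x2, label1, label2, cd_label)
124 | #   (cd_loss + seg_loss).backward()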
--------------------------------------------------------------------------------
/util/metric_tool.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied and modified from
3 | https://github.com/justchenhao/BIT_CD
4 | """
5 | import numpy as np
6 |
7 |
8 | ################### metrics ###################
9 | class AverageMeter(object):
10 | """Computes and stores the average and current value"""
11 | def __init__(self):
12 | self.initialized = False
13 | self.val = None
14 | self.avg = None
15 | self.sum = None
16 | self.count = None
17 |
18 | def initialize(self, val, weight):
19 | self.val = val
20 | self.avg = val
21 | self.sum = val * weight
22 | self.count = weight
23 | self.initialized = True
24 |
25 | def update(self, val, weight=1):
26 | if not self.initialized:
27 | self.initialize(val, weight)
28 | else:
29 | self.add(val, weight)
30 |
31 | def add(self, val, weight):
32 | self.val = val
33 | self.sum += val * weight
34 | self.count += weight
35 | self.avg = self.sum / self.count
36 |
37 | def value(self):
38 | return self.val
39 |
40 | def average(self):
41 | return self.avg
42 |
43 | def get_scores(self):
44 | scores_dict = cm2score(self.sum)
45 | return scores_dict
46 |
47 | def clear(self):
48 | self.initialized = False
49 |
50 |
51 | ################### cm metrics ###################
52 | class ConfuseMatrixMeter(AverageMeter):
53 | """Computes and stores the average and current value"""
54 | def __init__(self, n_class):
55 | super(ConfuseMatrixMeter, self).__init__()
56 | self.n_class = n_class
57 |
58 | def update_cm(self, pr, gt, weight=1):
59 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr)
60 | self.update(val, weight)
61 | current_score = cm2F1(val)
62 | return current_score
63 |
64 | def get_scores(self):
65 | scores_dict = cm2score(self.sum)
66 | return scores_dict
67 |
68 |
69 |
70 | def harmonic_mean(xs):
71 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs)
72 | return harmonic_mean
73 |
74 |
75 | def cm2F1(confusion_matrix):
76 | hist = confusion_matrix
77 | n_class = hist.shape[0]
78 | tp = np.diag(hist)
79 | sum_a1 = hist.sum(axis=1)
80 | sum_a0 = hist.sum(axis=0)
81 | # ---------------------------------------------------------------------- #
82 | # 1. Accuracy & Class Accuracy
83 | # ---------------------------------------------------------------------- #
84 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
85 |
86 | # recall
87 | recall = tp / (sum_a1 + np.finfo(np.float32).eps)
88 | # acc_cls = np.nanmean(recall)
89 |
90 | # precision
91 | precision = tp / (sum_a0 + np.finfo(np.float32).eps)
92 |
93 | # F1 score
94 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps)
95 | mean_F1 = np.nanmean(F1)
96 | return mean_F1
97 |
98 |
99 | def cm2score(confusion_matrix):
100 | hist = confusion_matrix
101 | n_class = hist.shape[0]
102 | tp = np.diag(hist)
103 | sum_a1 = hist.sum(axis=1)
104 | sum_a0 = hist.sum(axis=0)
105 | # ---------------------------------------------------------------------- #
106 | # 1. Accuracy & Class Accuracy
107 | # ---------------------------------------------------------------------- #
108 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
109 |
110 | # recall
111 | recall = tp / (sum_a1 + np.finfo(np.float32).eps)
112 | # acc_cls = np.nanmean(recall)
113 |
114 | # precision
115 | precision = tp / (sum_a0 + np.finfo(np.float32).eps)
116 |
117 | # F1 score
118 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps)
119 | mean_F1 = np.nanmean(F1)
120 | # ---------------------------------------------------------------------- #
121 | # 2. Frequency weighted Accuracy & Mean IoU
122 | # ---------------------------------------------------------------------- #
123 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps)
124 | mean_iu = np.nanmean(iu)
125 |
126 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps)
127 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
128 |
129 | #
130 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu))
131 |
132 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision))
133 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall))
134 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1))
135 |
136 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1}
137 | score_dict.update(cls_iou)
138 | score_dict.update(cls_F1)
139 | score_dict.update(cls_precision)
140 | score_dict.update(cls_recall)
141 | return score_dict
142 |
143 |
144 | def get_confuse_matrix(num_classes, label_gts, label_preds):
145 | def __fast_hist(label_gt, label_pred):
146 | """
147 | Collect values for Confusion Matrix
148 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
149 | :param label_gt: ground-truth
150 | :param label_pred: prediction
151 | :return: values for confusion matrix
152 | """
153 | mask = (label_gt >= 0) & (label_gt < num_classes)
154 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask],
155 | minlength=num_classes**2).reshape(num_classes, num_classes)
156 | return hist
157 | confusion_matrix = np.zeros((num_classes, num_classes))
158 | for lt, lp in zip(label_gts, label_preds):
159 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten())
160 | return confusion_matrix
161 |
162 |
163 | def get_mIoU(num_classes, label_gts, label_preds):
164 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds)
165 | score_dict = cm2score(confusion_matrix)
166 | return score_dict['miou']
167 |
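168 | # Illustrative usage sketch: `preds` and `gts` are assumed to be integer label maps of
169 | # shape (N, H, W) with values in [0, n_class).
170 | #   meter = ConfuseMatrixMeter(n_class=2)
171 | #   batch_f1 = meter.update_cm(pr=preds, gt=gts)   # accumulates the confusion matrix, returns batch mF1
172 | #   scores = meter.get_scores()                    # {'acc', 'miou', 'mf1', per-class iou/F1/precision/recall}
173 | #   meter.clear()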
--------------------------------------------------------------------------------
/SCD/util/metric_tool.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied and modified from
3 | https://github.com/justchenhao/BIT_CD
4 | """
5 | import numpy as np
6 |
7 |
8 | ################### metrics ###################
9 | class AverageMeter(object):
10 | """Computes and stores the average and current value"""
11 | def __init__(self):
12 | self.initialized = False
13 | self.val = None
14 | self.avg = None
15 | self.sum = None
16 | self.count = None
17 |
18 | def initialize(self, val, weight):
19 | self.val = val
20 | self.avg = val
21 | self.sum = val * weight
22 | self.count = weight
23 | self.initialized = True
24 |
25 | def update(self, val, weight=1):
26 | if not self.initialized:
27 | self.initialize(val, weight)
28 | else:
29 | self.add(val, weight)
30 |
31 | def add(self, val, weight):
32 | self.val = val
33 | self.sum += val * weight
34 | self.count += weight
35 | self.avg = self.sum / self.count
36 |
37 | def value(self):
38 | return self.val
39 |
40 | def average(self):
41 | return self.avg
42 |
43 | def get_scores(self):
44 | scores_dict = cm2score(self.sum)
45 | return scores_dict
46 |
47 | def clear(self):
48 | self.initialized = False
49 |
50 |
51 | ################### cm metrics ###################
52 | class ConfuseMatrixMeter(AverageMeter):
53 | """Computes and stores the average and current value"""
54 | def __init__(self, n_class):
55 | super(ConfuseMatrixMeter, self).__init__()
56 | self.n_class = n_class
57 |
58 | def update_cm(self, pr, gt, weight=1):
59 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr)
60 | self.update(val, weight)
61 | current_score = cm2F1(val)
62 | return current_score
63 |
64 | def get_scores(self):
65 | scores_dict = cm2score(self.sum)
66 | return scores_dict
67 |
68 |
69 |
70 | def harmonic_mean(xs):
71 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs)
72 | return harmonic_mean
73 |
74 |
75 | def cm2F1(confusion_matrix):
76 | hist = confusion_matrix
77 | n_class = hist.shape[0]
78 | tp = np.diag(hist)
79 | sum_a1 = hist.sum(axis=1)
80 | sum_a0 = hist.sum(axis=0)
81 | # ---------------------------------------------------------------------- #
82 | # 1. Accuracy & Class Accuracy
83 | # ---------------------------------------------------------------------- #
84 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
85 |
86 | # recall
87 | recall = tp / (sum_a1 + np.finfo(np.float32).eps)
88 | # acc_cls = np.nanmean(recall)
89 |
90 | # precision
91 | precision = tp / (sum_a0 + np.finfo(np.float32).eps)
92 |
93 | # F1 score
94 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps)
95 | mean_F1 = np.nanmean(F1)
96 | return mean_F1
97 |
98 |
99 | def cm2score(confusion_matrix):
100 | hist = confusion_matrix
101 | n_class = hist.shape[0]
102 | tp = np.diag(hist)
103 | sum_a1 = hist.sum(axis=1)
104 | sum_a0 = hist.sum(axis=0)
105 | # ---------------------------------------------------------------------- #
106 | # 1. Accuracy & Class Accuracy
107 | # ---------------------------------------------------------------------- #
108 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
109 |
110 | # recall
111 | recall = tp / (sum_a1 + np.finfo(np.float32).eps)
112 | # acc_cls = np.nanmean(recall)
113 |
114 | # precision
115 | precision = tp / (sum_a0 + np.finfo(np.float32).eps)
116 |
117 | # F1 score
118 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps)
119 | mean_F1 = np.nanmean(F1)
120 | # ---------------------------------------------------------------------- #
121 | # 2. Frequency weighted Accuracy & Mean IoU
122 | # ---------------------------------------------------------------------- #
123 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps)
124 | mean_iu = np.nanmean(iu)
125 |
126 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps)
127 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
128 |
129 | #
130 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu))
131 |
132 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision))
133 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall))
134 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1))
135 |
136 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1}
137 | score_dict.update(cls_iou)
138 | score_dict.update(cls_F1)
139 | score_dict.update(cls_precision)
140 | score_dict.update(cls_recall)
141 | return score_dict
142 |
143 |
144 | def get_confuse_matrix(num_classes, label_gts, label_preds):
145 | def __fast_hist(label_gt, label_pred):
146 | """
147 | Collect values for Confusion Matrix
148 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
149 | :param label_gt: ground-truth
150 | :param label_pred: prediction
151 | :return: values for confusion matrix
152 | """
153 | mask = (label_gt >= 0) & (label_gt < num_classes)
154 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask],
155 | minlength=num_classes**2).reshape(num_classes, num_classes)
156 | return hist
157 | confusion_matrix = np.zeros((num_classes, num_classes))
158 | for lt, lp in zip(label_gts, label_preds):
159 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten())
160 | return confusion_matrix
161 |
162 |
163 | def get_mIoU(num_classes, label_gts, label_preds):
164 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds)
165 | score_dict = cm2score(confusion_matrix)
166 | return score_dict['miou']
167 |
--------------------------------------------------------------------------------
/model/block/convs.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/iduta/pyconv
3 | https://github.com/XudongLinthu/context-gated-convolution
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import numpy as np
10 |
11 |
12 | class ConvBnRelu(nn.Module):
13 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
14 | super(ConvBnRelu, self).__init__()
15 | self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
16 | padding=padding, dilation=dilation, bias=False),
17 | nn.BatchNorm2d(out_channels),
18 | nn.ReLU(inplace=True))
19 |
20 | def forward(self, x):
21 | x = self.block(x)
22 | return x
23 |
24 |
25 | class DsBnRelu(nn.Module):
26 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
27 | super(DsBnRelu, self).__init__()
28 | self.kernel_size = kernel_size
29 | self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding,
30 | dilation, groups=in_channels, bias=False)
31 | self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
32 | self.bn = nn.BatchNorm2d(out_channels)
33 | self.relu = nn.ReLU(True)
34 |
35 | def forward(self, x):
36 | if self.kernel_size != 1:
37 | x = self.depthwise(x)
38 | x = self.pointwise(x)
39 | x = self.bn(x)
40 | x = self.relu(x)
41 | return x
42 |
43 |
44 | class PyConv2d(nn.Module):
45 | def __init__(self, in_channels, out_channels, pyconv_kernels=[1, 3, 5, 7], pyconv_groups=[1, 2, 4, 8], bias=False):
46 | super(PyConv2d, self).__init__()
47 |
48 | pyconv_levels = []
49 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups):
50 | pyconv_levels.append(nn.Conv2d(in_channels, out_channels // 2, kernel_size=pyconv_kernel,
51 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=bias))
52 | self.pyconv_levels = nn.Sequential(*pyconv_levels)
53 | self.to_out = nn.Sequential(nn.Conv2d((out_channels // 2) * len(pyconv_kernels), out_channels, 1, bias=False),
54 | nn.BatchNorm2d(out_channels),
55 | nn.ReLU(True))
56 | self.relu = nn.ReLU(True)
57 |
58 | def forward(self, x):
59 | out = []
60 | for level in self.pyconv_levels:
61 | out.append(self.relu(level(x)))
62 | out = torch.cat(out, dim=1)
63 | out = self.to_out(out)
64 |
65 | return out
66 |
67 |
68 | class GatedConv2d(torch.nn.Module):
69 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
70 | super(GatedConv2d, self).__init__()
71 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
72 | self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
73 | self.sigmoid = torch.nn.Sigmoid()
74 |
75 | def gated(self, mask):
76 | return self.sigmoid(mask)
77 |
78 | def forward(self, input):
79 | x = self.conv2d(input)
80 | mask = self.mask_conv2d(input)
81 | x = x * self.gated(mask)
82 |
83 | return x
84 |
85 |
86 | class ContextGatedConv2d(nn.Conv2d):
87 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
88 | padding=1, dilation=1, groups=1, bias=True):
89 | super(ContextGatedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
90 | padding, dilation, groups, bias)
91 | # for convolutional layers with a kernel size of 1, just use traditional convolution
92 | if kernel_size == 1:
93 | self.ind = True
94 | else:
95 | self.ind = False
96 | self.oc = out_channels
97 | self.ks = kernel_size
98 |
99 | # the target spatial size of the pooling layer
100 | ws = kernel_size
101 | self.avg_pool = nn.AdaptiveAvgPool2d((ws, ws))
102 |
103 | # the dimension of the latent representation
104 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1)
105 |
106 | # the context encoding module
107 | self.ce = nn.Linear(ws * ws, self.num_lat, False)
108 | self.ce_bn = nn.BatchNorm1d(in_channels)
109 | self.ci_bn2 = nn.BatchNorm1d(in_channels)
110 |
111 | # activation function is relu
112 | self.act = nn.ReLU(inplace=True)
113 |
114 | # the number of groups in the channel interacting module
115 | if in_channels >= 16:
116 | self.g = 16
117 | else:
118 | self.g = in_channels
119 | # the channel interacting module
120 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False)
121 | self.ci_bn = nn.BatchNorm1d(out_channels)
122 |
123 | # the gate decoding module
124 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
125 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
126 |
127 | # used to unfold the input feature map into patches
128 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride)
129 |
130 | # sigmoid function
131 | self.sig = nn.Sigmoid()
132 |
133 | def forward(self, x):
134 | # for convolutional layers with a kernel size of 1, just use traditional convolution
135 | if self.ind:
136 | return F.conv2d(x, self.weight, self.bias, self.stride,
137 | self.padding, self.dilation, self.groups)
138 | else:
139 | b, c, h, w = x.size()
140 | weight = self.weight
141 | # gather global information via adaptive average pooling
142 | gl = self.avg_pool(x).view(b, c, -1)
143 | # context-encoding module
144 | out = self.ce(gl)
145 | # use different bn for the following two branches
146 | ce2 = out
147 | out = self.ce_bn(out)
148 | out = self.act(out)
149 | # gate decoding branch 1
150 | out = self.gd(out)
151 | # channel interacting module
152 | if self.g > 3:
153 | # grouped linear
154 | oc = self.ci(self.act(self.ci_bn2(ce2).view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2, 3).contiguous()
155 | else:
156 | # linear layer for resnet.conv1
157 | oc = self.ci(self.act(self.ci_bn2(ce2).transpose(2, 1))).transpose(2, 1).contiguous()
158 | oc = oc.view(b, self.oc, -1)
159 | oc = self.ci_bn(oc)
160 | oc = self.act(oc)
161 | # gate decoding branch 2
162 | oc = self.gd2(oc)
163 | # produce gate
164 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks))
165 | # unfolding input feature map to patches
166 | x_un = self.unfold(x)
167 | b, _, l = x_un.size()
168 | # gating
169 | out = (out * weight.unsqueeze(0)).view(b, self.oc, -1)
170 | # currently only handle square input and output
171 | return torch.matmul(out, x_un).view(b, self.oc, int(np.sqrt(l)), int(np.sqrt(l)))
172 |
173 |
174 |
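175 | # Illustrative shape check for the gated variants (both preserve spatial size with the
176 | # default kernel_size=3, stride=1, padding=1):
177 | #   x = torch.randn(2, 64, 32, 32)
178 | #   GatedConv2d(64, 64)(x).shape          # torch.Size([2, 64, 32, 32])
179 | #   ContextGatedConv2d(64, 64)(x).shape   # torch.Size([2, 64, 32, 32])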
--------------------------------------------------------------------------------
/SCD/model/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import timm
5 | from .backbone.mobilenetv2 import mobilenet_v2
6 | # from .backbone.resnet import resnet18, resnet50  # resnet.py is not included in this repo
7 | # from .block.bifpn import BiFPN  # bifpn.py is not included in this repo
8 | # from .block.bifpn_add import BiFPN_add  # bifpn_add.py is not included in this repo
9 | # from .block.neighbor import NeighborFeatureAggregation  # neighbor.py is not included in this repo
10 | from .block.fpn import FPN
11 | # from .block.fpn_plain import FPN_plain  # fpn_plain.py is not included in this repo
12 | from .block.vertical import VerticalFusion
13 | from .block.convs import ConvBnRelu, DsBnRelu
14 | from .util import init_method
15 | from .block.heads import FCNHead, GatedResidualUpHead
16 |
17 |
18 | def get_backbone(backbone_name):
19 | if backbone_name == 'mobilenetv2':
20 | backbone = mobilenet_v2(pretrained=True, progress=True)
21 | backbone.channels = [16, 24, 32, 96, 320]
22 | elif backbone_name == 'mobilenetv3_small_075':
23 | backbone = timm.create_model('mobilenetv3_small_075', pretrained=True, features_only=True)
24 | backbone.channels = [16, 16, 24, 40, 432]
25 | elif backbone_name == 'mobilenetv3_small_100':
26 | backbone = timm.create_model('mobilenetv3_small_100', pretrained=True, features_only=True)
27 | backbone.channels = [16, 16, 24, 48, 576]
28 | elif backbone_name == 'resnet18':
29 | backbone = resnet18(pretrained=True, progress=True, replace_stride_with_dilation=[False, False, False])
30 | backbone.channels = [64, 64, 128, 256, 512]
31 | elif backbone_name == 'resnet18d':
32 | backbone = timm.create_model('resnet18d', pretrained=True, features_only=True)
33 | backbone.channels = [64, 64, 128, 256, 512]
34 | elif backbone_name == 'resnet50':
35 | backbone = resnet50(pretrained=True, progress=True)
36 | backbone.channels = [64, 256, 512, 1024, 2048]
37 | elif backbone_name == 'hrnet_w18':
38 | backbone = timm.create_model('hrnet_w18', pretrained=True, features_only=True)
39 | backbone.channels = [64, 128, 256, 512, 1024]
40 | else:
41 | raise NotImplementedError("BACKBONE [%s] is not implemented!\n" % backbone_name)
42 | return backbone
43 |
44 |
45 | def get_fpn(fpn_name, in_channels, out_channels, deform_groups=4, gamma_mode='SE', beta_mode='gatedconv'):
46 | if fpn_name == 'fpn':
47 | fpn = FPN(in_channels, out_channels, deform_groups, gamma_mode, beta_mode)
48 | elif fpn_name == 'fpn_plain':
49 | fpn = FPN_plain(in_channels, out_channels)
50 | elif fpn_name == 'bifpn':
51 | fpn = BiFPN(in_channels, out_channels)
52 | elif fpn_name == 'bifpn_add':
53 | fpn = BiFPN_add(in_channels, out_channels)
54 | elif fpn_name == 'neighbor':
55 | fpn = NeighborFeatureAggregation(in_channels, out_channels)
56 | else:
57 | raise NotImplementedError("FPN [%s] is not implemented!\n" % fpn_name)
58 | return fpn
59 |
60 |
61 | class Detector(nn.Module):
62 | def __init__(self, backbone_name='mobilenetv2', fpn_name='fpn', fpn_channels=64,
63 | deform_groups=4, gamma_mode='SE', beta_mode='contextgatedconv',
64 | num_heads=1, num_points=8, kernel_layers=1, dropout_rate=0.1, init_type='kaiming_normal'):
65 | super().__init__()
66 | self.backbone = get_backbone(backbone_name)
67 | self.fpn = get_fpn(fpn_name, in_channels=self.backbone.channels[-4:], out_channels=fpn_channels,
68 | deform_groups=deform_groups, gamma_mode=gamma_mode, beta_mode=beta_mode)
69 | self.p5_to_p4 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=4,
70 | kernel_layers=kernel_layers)
71 | self.p4_to_p3 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=8,
72 | kernel_layers=kernel_layers)
73 | self.p3_to_p2 = VerticalFusion(fpn_channels, num_heads=num_heads, num_points=16,
74 | kernel_layers=kernel_layers)
75 |
76 | self.p5_head = nn.Conv2d(fpn_channels, 2, 1)
77 | self.p4_head = nn.Conv2d(fpn_channels, 2, 1)
78 | self.p3_head = nn.Conv2d(fpn_channels, 2, 1)
79 | self.p2_head = nn.Conv2d(fpn_channels, 2, 1)
80 | self.project = nn.Sequential(nn.Conv2d(fpn_channels * 4, fpn_channels, 1, bias=False),
81 | nn.BatchNorm2d(fpn_channels),
82 | nn.ReLU(True)
83 | )
84 | self.head = GatedResidualUpHead(fpn_channels, 2, dropout_rate=dropout_rate)
85 | self.scd_head = FCNHead(fpn_channels, 7, dropout_rate=dropout_rate)
86 | # init_method(self.fpn, self.p5_to_p4, self.p4_to_p3, self.p3_to_p2, self.p5_head, self.p4_head,
87 | # self.p3_head, self.p2_head, init_type=init_type)
88 |
89 | def forward(self, x1, x2):
90 | ### Extract backbone features
91 | t1_c1, t1_c2, t1_c3, t1_c4, t1_c5 = self.backbone.forward(x1)
92 | t2_c1, t2_c2, t2_c3, t2_c4, t2_c5 = self.backbone.forward(x2)
93 | t1_p2, t1_p3, t1_p4, t1_p5 = self.fpn([t1_c2, t1_c3, t1_c4, t1_c5])
94 | t2_p2, t2_p3, t2_p4, t2_p5 = self.fpn([t2_c2, t2_c3, t2_c4, t2_c5])
95 |
96 | diff_p2 = torch.abs(t1_p2 - t2_p2)
97 | diff_p3 = torch.abs(t1_p3 - t2_p3)
98 | diff_p4 = torch.abs(t1_p4 - t2_p4)
99 | diff_p5 = torch.abs(t1_p5 - t2_p5)
100 | """
101 | pred_p5 = self.p5_head(diff_p5)
102 | pred_p4 = self.p4_head(diff_p4)
103 | pred_p3 = self.p3_head(diff_p3)
104 | pred_p2 = self.p2_head(diff_p2)
105 |
106 | diff_p3 = F.interpolate(diff_p3, size=(64, 64), mode='bilinear', align_corners=False)
107 | diff_p4 = F.interpolate(diff_p4, size=(64, 64), mode='bilinear', align_corners=False)
108 | diff_p5 = F.interpolate(diff_p5, size=(64, 64), mode='bilinear', align_corners=False)
109 | #diff = diff_p2 + diff_p3 + diff_p4 + diff_p5
110 | diff = torch.cat([diff_p2, diff_p3, diff_p4, diff_p5], dim=1)
111 | diff = self.project(diff)
112 | pred = self.head(diff)
113 |
114 | """
115 | fea_p5 = diff_p5
116 | pred_p5 = self.p5_head(fea_p5)
117 | fea_p4 = self.p5_to_p4(fea_p5, diff_p4)
118 | pred_p4 = self.p4_head(fea_p4)
119 | fea_p3 = self.p4_to_p3(fea_p4, diff_p3)
120 | pred_p3 = self.p3_head(fea_p3)
121 | fea_p2 = self.p3_to_p2(fea_p3, diff_p2)
122 | pred_p2 = self.p2_head(fea_p2)
123 | #fea_p3 = F.interpolate(fea_p3, size=(64, 64), mode='bilinear', align_corners=False)
124 | #fea_p4 = F.interpolate(fea_p4, size=(64, 64), mode='bilinear', align_corners=False)
125 | #fea_p5 = F.interpolate(fea_p5, size=(64, 64), mode='bilinear', align_corners=False)
126 | #diff = diff_p2 + diff_p3 + diff_p4 + diff_p5
127 | #diff = torch.cat([fea_p2, fea_p3, fea_p4, fea_p5], dim=1)
128 | #diff = self.project(diff)
129 | pred = self.head(fea_p2)
130 |
131 |
132 | pred_p2 = F.interpolate(pred_p2, size=(256, 256), mode='bilinear', align_corners=False)  # auxiliary predictions are upsampled to the fixed 256x256 input resolution
133 | pred_p3 = F.interpolate(pred_p3, size=(256, 256), mode='bilinear', align_corners=False)
134 | pred_p4 = F.interpolate(pred_p4, size=(256, 256), mode='bilinear', align_corners=False)
135 | pred_p5 = F.interpolate(pred_p5, size=(256, 256), mode='bilinear', align_corners=False)
136 | #pred = F.interpolate(pred, size=(256, 256), mode='bilinear', align_corners=False)
137 |
138 | t1_p3 = F.interpolate(t1_p3, size=(64, 64), mode='bilinear', align_corners=False)
139 | t1_p4 = F.interpolate(t1_p4, size=(64, 64), mode='bilinear', align_corners=False)
140 | t1_p5 = F.interpolate(t1_p5, size=(64, 64), mode='bilinear', align_corners=False)
141 | t1_fea = torch.cat([t1_p2, t1_p3, t1_p4, t1_p5], dim=1)
142 | t1_fea = self.project(t1_fea)
143 | pred_seg1 = self.scd_head(t1_fea)
144 | pred_seg1 = F.interpolate(pred_seg1, size=(256, 256), mode='bilinear', align_corners=False)
145 | t2_p3 = F.interpolate(t2_p3, size=(64, 64), mode='bilinear', align_corners=False)
146 | t2_p4 = F.interpolate(t2_p4, size=(64, 64), mode='bilinear', align_corners=False)
147 | t2_p5 = F.interpolate(t2_p5, size=(64, 64), mode='bilinear', align_corners=False)
148 | t2_fea = torch.cat([t2_p2, t2_p3, t2_p4, t2_p5], dim=1)
149 | t2_fea = self.project(t2_fea)
150 | pred_seg2 = self.scd_head(t2_fea)
151 | pred_seg2 = F.interpolate(pred_seg2, size=(256, 256), mode='bilinear', align_corners=False)
152 |
153 | return pred, pred_seg1, pred_seg2, pred_p2, pred_p3, pred_p4, pred_p5
154 |
155 |
156 |
157 |
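158 | # Illustrative smoke test (assumes pretrained backbone weights can be downloaded and,
159 | # for fpn_name='fpn', that mmcv's deformable conv ops are installed):
160 | #   net = Detector(backbone_name='mobilenetv2', fpn_name='fpn')
161 | #   outs = net(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
162 | #   # outs: change map, two semantic maps, and four pyramid-level change predictions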
--------------------------------------------------------------------------------
/SCD/model/block/convs.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/iduta/pyconv
3 | https://github.com/ZitongYu/CDCN/
4 | https://github.com/XudongLinthu/context-gated-convolution
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from mmcv.ops import ModulatedDeformConv2dPack as DCNv2
11 | import numpy as np
12 |
13 |
14 | class ConvBnRelu(nn.Module):
15 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
16 | super(ConvBnRelu, self).__init__()
17 | self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
18 | padding=padding, dilation=dilation, bias=False),
19 | nn.BatchNorm2d(out_channels),
20 | nn.ReLU(inplace=True))
21 |
22 | def forward(self, x):
23 | x = self.block(x)
24 | return x
25 |
26 |
27 | class DsBnRelu(nn.Module):
28 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
29 | super(DsBnRelu, self).__init__()
30 | self.kernel_size = kernel_size
31 | self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding,
32 | dilation, groups=in_channels, bias=False)
33 | self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
34 | self.bn = nn.BatchNorm2d(out_channels)
35 | self.relu = nn.ReLU(True)
36 |
37 | def forward(self, x):
38 | if self.kernel_size != 1:
39 | x = self.depthwise(x)
40 | x = self.pointwise(x)
41 | x = self.bn(x)
42 | x = self.relu(x)
43 | return x
44 |
45 |
46 | class PyConv2d(nn.Module):
47 | def __init__(self, in_channels, out_channels, pyconv_kernels=[1, 3, 5, 7], pyconv_groups=[1, 2, 4, 8], bias=False):
48 | super(PyConv2d, self).__init__()
49 |
50 | pyconv_levels = []
51 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups):
52 | pyconv_levels.append(nn.Conv2d(in_channels, out_channels // 2, kernel_size=pyconv_kernel,
53 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=bias))
54 | self.pyconv_levels = nn.Sequential(*pyconv_levels)
55 | self.to_out = nn.Sequential(nn.Conv2d((out_channels // 2) * len(pyconv_kernels), out_channels, 1, bias=False),
56 | nn.BatchNorm2d(out_channels),
57 | nn.ReLU(True))
58 | self.relu = nn.ReLU(True)
59 |
60 | def forward(self, x):
61 | out = []
62 | for level in self.pyconv_levels:
63 | out.append(self.relu(level(x)))
64 | out = torch.cat(out, dim=1)
65 | out = self.to_out(out)
66 |
67 | return out
68 |
69 |
70 | class PyDCNv2(nn.Module):
71 | def __init__(self, in_channels, out_channels, pyconv_kernels=[1, 3, 5], pyconv_groups=[1, 4, 8], bias=False):
72 | super(PyDCNv2, self).__init__()
73 |
74 | pyconv_levels = []
75 | for pyconv_kernel, pyconv_group in zip(pyconv_kernels, pyconv_groups):
76 | if pyconv_kernel == 1:
77 | pyconv_levels.append(nn.Conv2d(in_channels, out_channels // 2, kernel_size=pyconv_kernel,
78 | padding=pyconv_kernel // 2, groups=pyconv_group, bias=bias))
79 | else:
80 | pyconv_levels.append(DCNv2(in_channels, out_channels // 2, kernel_size=pyconv_kernel,
81 | padding=pyconv_kernel // 2, deform_groups=pyconv_group))
82 | self.pyconv_levels = nn.Sequential(*pyconv_levels)
83 | self.to_out = nn.Sequential(nn.Conv2d((out_channels // 2) * len(pyconv_kernels), out_channels, 1, bias=False),
84 | nn.BatchNorm2d(out_channels),
85 | nn.ReLU6(True))
86 | self.relu = nn.ReLU6(True)
87 |
88 | def forward(self, x):
89 | out = []
90 | for level in self.pyconv_levels:
91 | out.append(self.relu(level(x)))
92 | out = torch.cat(out, dim=1)
93 | out = self.to_out(out)
94 |
95 | return out
96 |
97 |
98 | class CDC(nn.Module):
99 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
100 | padding=1, dilation=1, groups=1, bias=False, theta=0.7):
101 | super(CDC, self).__init__()
102 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
103 | self.theta = theta
104 |
105 | def forward(self, x):
106 | out_normal = self.conv(x)
107 | # central difference term: each kernel is collapsed to its spatial sum and applied
108 | # as a 1x1 conv; theta balances the vanilla convolution against this term
109 | kernel_diff = self.conv.weight.sum(2).sum(2)
110 | kernel_diff = kernel_diff[:, :, None, None]
111 | out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=0, groups=self.conv.groups)
112 |
113 | return out_normal - self.theta * out_diff
114 |
115 |
116 | class GatedConv2d(torch.nn.Module):
117 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
118 | super(GatedConv2d, self).__init__()
119 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
120 | self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
121 | self.sigmoid = torch.nn.Sigmoid()
122 |
123 | def gated(self, mask):
124 | return self.sigmoid(mask)
125 |
126 | def forward(self, input):
127 | x = self.conv2d(input)
128 | mask = self.mask_conv2d(input)
129 | x = x * self.gated(mask)
130 |
131 | return x
132 |
133 |
134 | class ContextGatedConv2d(nn.Conv2d):
135 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
136 | padding=1, dilation=1, groups=1, bias=True):
137 | super(ContextGatedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
138 | padding, dilation, groups, bias)
139 | # for convolutional layers with a kernel size of 1, just use traditional convolution
140 | if kernel_size == 1:
141 | self.ind = True
142 | else:
143 | self.ind = False
144 | self.oc = out_channels
145 | self.ks = kernel_size
146 |
147 | # the target spatial size of the pooling layer
148 | ws = kernel_size
149 | self.avg_pool = nn.AdaptiveAvgPool2d((ws, ws))
150 |
152 | # the dimension of the latent representation
152 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1)
153 |
154 | # the context encoding module
155 | self.ce = nn.Linear(ws * ws, self.num_lat, False)
156 | self.ce_bn = nn.BatchNorm1d(in_channels)
157 | self.ci_bn2 = nn.BatchNorm1d(in_channels)
158 |
159 | # activation function is relu
160 | self.act = nn.ReLU(inplace=True)
161 |
162 | # the number of groups in the channel interacting module
163 | if in_channels >= 16:
164 | self.g = 16
165 | else:
166 | self.g = in_channels
167 | # the channel interacting module
168 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False)
169 | self.ci_bn = nn.BatchNorm1d(out_channels)
170 |
171 | # the gate decoding module
172 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
173 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
174 |
175 | # used to unfold the input feature map into patches
176 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride)
177 |
178 | # sigmoid function
179 | self.sig = nn.Sigmoid()
180 |
181 | def forward(self, x):
182 | # for convolutional layers with a kernel size of 1, just use traditional convolution
183 | if self.ind:
184 | return F.conv2d(x, self.weight, self.bias, self.stride,
185 | self.padding, self.dilation, self.groups)
186 | else:
187 | b, c, h, w = x.size()
188 | weight = self.weight
189 | # gather global information via adaptive average pooling
190 | gl = self.avg_pool(x).view(b, c, -1)
191 | # context-encoding module
192 | out = self.ce(gl)
193 | # use different bn for the following two branches
194 | ce2 = out
195 | out = self.ce_bn(out)
196 | out = self.act(out)
197 | # gate decoding branch 1
198 | out = self.gd(out)
199 | # channel interacting module
200 | if self.g > 3:
201 | # grouped linear
202 | oc = self.ci(self.act(self.ci_bn2(ce2).view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2, 3).contiguous()
203 | else:
204 | # linear layer for resnet.conv1
205 | oc = self.ci(self.act(self.ci_bn2(ce2).transpose(2, 1))).transpose(2, 1).contiguous()
206 | oc = oc.view(b, self.oc, -1)
207 | oc = self.ci_bn(oc)
208 | oc = self.act(oc)
209 | # gate decoding branch 2
210 | oc = self.gd2(oc)
211 | # produce gate
212 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks))
213 | # unfolding input feature map to patches
214 | x_un = self.unfold(x)
215 | b, _, l = x_un.size()
216 | # gating
217 | out = (out * weight.unsqueeze(0)).view(b, self.oc, -1)
218 | # currently only handle square input and output
219 | return torch.matmul(out, x_un).view(b, self.oc, int(np.sqrt(l)), int(np.sqrt(l)))
220 |
221 |
222 |
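223 | # Illustrative check for CDC (central difference convolution): with theta=0 the layer
224 | # reduces to a plain Conv2d, and the default settings preserve spatial size.
225 | #   layer = CDC(16, 16, theta=0.7)
226 | #   layer(torch.randn(2, 16, 32, 32)).shape   # torch.Size([2, 16, 32, 32])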
--------------------------------------------------------------------------------