├── .idea
│   ├── .gitignore
│   ├── .name
│   ├── AMTNet.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   └── modules.xml
├── README.md
├── __init__.py
├── __pycache__
│   ├── data_utils.cpython-37.pyc
│   ├── data_utils.cpython-39.pyc
│   ├── train_options.cpython-39.pyc
│   ├── train_options_CLCD.cpython-39.pyc
│   ├── train_options_HRSCD.cpython-39.pyc
│   ├── train_options_LEVIR.cpython-39.pyc
│   └── train_options_WHU.cpython-39.pyc
├── data_utils.py
├── loss
│   ├── __pycache__
│   │   ├── losses.cpython-37.pyc
│   │   └── losses.cpython-39.pyc
│   └── losses.py
├── model
│   ├── MSCANet.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── backbone.cpython-39.pyc
│   │   ├── modules.cpython-39.pyc
│   │   └── network.cpython-39.pyc
│   ├── backbone.py
│   ├── dtcdscn.py
│   ├── modules.py
│   ├── network.py
│   ├── siamunet_conc.py
│   ├── siamunet_diff.py
│   └── unet.py
├── test.py
├── train.py
├── train_options.py
├── utils.py
└── visualize_results.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/.name:
--------------------------------------------------------------------------------
1 | losses.py
--------------------------------------------------------------------------------
/.idea/AMTNet.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AMTNet
2 | The PyTorch implementation of AMTNet from the paper ["An attention-based multiscale transformer network for remote sensing image
3 | change detection"](https://www.sciencedirect.com/science/article/abs/pii/S092427162300182X?CMX_ID=&SIS_ID=&dgcid=STMJ_AUTH_SERV_PUBLISHED&utm_acid=276849605&utm_campaign=STMJ_AUTH_SERV_PUBLISHED&utm_in=DM391842&utm_medium=email&utm_source=AC_), published in the [ISPRS Journal of Photogrammetry and Remote Sensing](https://www.sciencedirect.com/journal/isprs-journal-of-photogrammetry-and-remote-sensing).
4 | # Requirements
5 | * Python 3.9
6 | * PyTorch 1.12
7 | # Dataset
8 | * Download the [CLCD Dataset](https://pan.baidu.com/share/init?surl=Un-bVxUm1N9IHiDOXLLHlg&pwd=miu2)
9 | * Download the [HRSCD Dataset](https://ieee-dataport.org/open-access/hrscd-high-resolution-semantic-change-detection-dataset)
10 | * Download the [WHU-CD Dataset](http://gpcv.whu.edu.cn/data/building_dataset.html)
11 | * Download the [LEVIR-CD Dataset](http://chenhao.in/LEVIR/)
12 | ```
13 | Prepare the datasets in the following structure (time1/time2 hold the bi-temporal images, label the binary change masks) and set their paths in train_options.py
14 | ├─Train
15 | │  ├─time1
16 | │  ├─time2
17 | │  └─label
18 | ├─Test
19 | │  ├─time1
20 | │  ├─time2
21 | │  └─label
22 | ```
23 | # Train
24 | ```
25 | python train.py
26 | ```
27 | All the hyperparameters can be adjusted in train_options.py
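The repository's actual train_options.py is not reproduced in this dump; as a rough, non-authoritative sketch, an argparse-style options file for this layout might look like the following (all option names and default values are assumptions, not the repo's real settings):
```
# Hypothetical sketch of a train_options.py-style file; option names and defaults are illustrative only.
import argparse

parser = argparse.ArgumentParser(description='AMTNet training options (illustrative)')
# Dataset locations, matching the Train/Test -> time1/time2/label layout above
parser.add_argument('--hr1_train', type=str, default='data/LEVIR-CD/Train/time1')
parser.add_argument('--hr2_train', type=str, default='data/LEVIR-CD/Train/time2')
parser.add_argument('--lab_train', type=str, default='data/LEVIR-CD/Train/label')
parser.add_argument('--hr1_val', type=str, default='data/LEVIR-CD/Test/time1')
parser.add_argument('--hr2_val', type=str, default='data/LEVIR-CD/Test/time2')
parser.add_argument('--lab_val', type=str, default='data/LEVIR-CD/Test/label')
parser.add_argument('--suffix', default=['.png', '.jpg', '.tif'], help='image extensions accepted by data_utils.py')
# Common training hyperparameters
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--batchsize', type=int, default=8)
parser.add_argument('--lr', type=float, default=1e-4)
args = parser.parse_args()
```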
28 | # Model Zoo
29 | The trained models and their evaluation scores can be downloaded from [Baidu Cloud]().
30 | # Acknowledgments
31 | This code borrows heavily from [MSCANet](https://github.com/liumency/CropLand-CD) and [Changer](https://github.com/likyoo/open-cd/tree/main).
32 | # Citation
33 | If you find this repo useful for your research, please consider citing the paper as follows:
34 | ```
35 | @article{liu2023attention,
36 | title={An attention-based multiscale transformer network for remote sensing image change detection},
37 | author={Liu, Wei and Lin, Yiyuan and Liu, Weijia and Yu, Yongtao and Li, Jonathan},
38 | journal={ISPRS Journal of Photogrammetry and Remote Sensing},
39 | volume={202},
40 | pages={599--609},
41 | year={2023},
42 | publisher={Elsevier}
43 | }
44 | ```
45 |
46 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__init__.py
--------------------------------------------------------------------------------
/__pycache__/data_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/data_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/data_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/data_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/train_options.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/train_options_CLCD.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options_CLCD.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/train_options_HRSCD.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options_HRSCD.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/train_options_LEVIR.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options_LEVIR.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/train_options_WHU.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options_WHU.cpython-39.pyc
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | from os.path import join
3 | import torch
4 | from PIL import Image, ImageEnhance
5 | from torch.utils.data.dataset import Dataset
6 | from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
7 | import numpy as np
8 | import torchvision.transforms as transforms
9 | import os
10 | import imageio
11 |
12 |
13 | def is_image_file(filename):
14 | return any(filename.endswith(extension) for extension in ['.tif','.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])
15 |
16 | def calMetric_iou(predict, label):
17 | tp = np.sum(np.logical_and(predict == 1, label == 1))  # intersection
18 | fp = np.sum(predict == 1)  # all predicted positives
19 | fn = np.sum(label == 1)  # all ground-truth positives
20 | return tp, fp + fn - tp  # (intersection, union) for IoU
21 |
22 |
23 | def getDataList(img_path):
24 | dataline = open(img_path, 'r').readlines()
25 | datalist =[]
26 | for line in dataline:
27 | temp = line.strip('\n')
28 | datalist.append(temp)
29 | return datalist
30 |
31 |
32 | def make_one_hot(input, num_classes):
33 | """Convert class index tensor to one hot encoding tensor.
34 |
35 | Args:
36 | input: A tensor of shape [N, 1, *]
37 | num_classes: An int of number of class
38 | Returns:
39 | A tensor of shape [N, num_classes, *]
40 | """
41 | shape = np.array(input.shape)
42 | shape[1] = num_classes
43 | shape = tuple(shape)
44 | result = torch.zeros(shape)
45 | result = result.scatter_(1, input.cpu(), 1)
46 | return result
47 |
48 |
49 | def get_transform(convert=True, normalize=False):
50 | transform_list = []
51 | if convert:
52 | transform_list += [transforms.ToTensor()]
53 | if normalize:
54 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
55 | (0.5, 0.5, 0.5))]
56 | return transforms.Compose(transform_list)
57 |
58 |
59 |
60 | class LoadDatasetFromFolder(Dataset):
61 | def __init__(self, args, hr1_path, hr2_path, lab_path):
62 | super(LoadDatasetFromFolder, self).__init__()
63 | # Build the list of image files
64 | datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if
65 | os.path.splitext(name)[1] == item]
66 |
67 | self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)]
68 | self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)]
69 | self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)]
70 |
71 | self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1]
72 | self.label_transform = get_transform() # only convert to tensor
73 |
74 | def __getitem__(self, index):
75 | hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB'))
76 | # lr2_img = self.transform(Image.open(self.lr2_filenames[index]).convert('RGB'))
77 | hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB'))
78 |
79 | label = self.label_transform(Image.open(self.lab_filenames[index]))
80 | label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
81 |
82 | return hr1_img, hr2_img, label
83 |
84 | def __len__(self):
85 | return len(self.hr1_filenames)
86 |
87 |
88 | class TestDatasetFromFolder(Dataset):
89 | def __init__(self, args, Time1_dir, Time2_dir, Label_dir):
90 | super(TestDatasetFromFolder, self).__init__()
91 |
92 | datalist = [name for name in os.listdir(Time1_dir) for item in args.suffix if
93 | os.path.splitext(name)[1] == item]
94 |
95 | self.image1_filenames = [join(Time1_dir, x) for x in datalist if is_image_file(x)]
96 | self.image2_filenames = [join(Time2_dir, x) for x in datalist if is_image_file(x)]
97 | self.image3_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
98 |
99 | self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1]
100 | self.label_transform = get_transform()
101 |
102 | def __getitem__(self, index):
103 | image1 = self.transform(Image.open(self.image1_filenames[index]).convert('RGB'))
104 | image2 = self.transform(Image.open(self.image2_filenames[index]).convert('RGB'))
105 |
106 | label = self.label_transform(Image.open(self.image3_filenames[index]))
107 | label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
108 |
109 | image_name = self.image1_filenames[index].split('/', -1)
110 | image_name = image_name[len(image_name)-1]
111 |
112 | return image1, image2, label, image_name
113 |
114 | def __len__(self):
115 | return len(self.image1_filenames)
116 |
117 |
118 | class trainImageAug(object):
119 | def __init__(self, crop = True, augment = True, angle = 30):
120 | self.crop =crop
121 | self.augment = augment
122 | self.angle = angle
123 |
124 | def __call__(self, image1, image2, mask):
125 | if self.crop:
126 | w = np.random.randint(0,256)
127 | h = np.random.randint(0,256)
128 | box = (w, h, w+256, h+256)
129 | image1 = image1.crop(box)
130 | image2 = image2.crop(box)
131 | mask = mask.crop(box)
132 | if self.augment:
133 | prop = np.random.uniform(0, 1)
134 | if prop < 0.15:
135 | image1 = image1.transpose(Image.FLIP_LEFT_RIGHT)
136 | image2 = image2.transpose(Image.FLIP_LEFT_RIGHT)
137 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
138 | elif prop < 0.3:
139 | image1 = image1.transpose(Image.FLIP_TOP_BOTTOM)
140 | image2 = image2.transpose(Image.FLIP_TOP_BOTTOM)
141 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
142 | elif prop < 0.5:
143 | angle = transforms.RandomRotation.get_params([-self.angle, self.angle])  # one shared angle keeps both images and the mask aligned
144 | image1, image2 = image1.rotate(angle), image2.rotate(angle)
145 | mask = mask.rotate(angle)
146 |
147 | return image1, image2, mask
148 |
149 | def get_transform(convert=True, normalize=False):
150 | transform_list = []
151 | if convert:
152 | transform_list += [
153 | transforms.ToTensor(),
154 | ]
155 | if normalize:
156 | transform_list += [
157 | # transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
158 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]
159 | return transforms.Compose(transform_list)
160 |
161 |
162 | class DA_DatasetFromFolder(Dataset):
163 | def __init__(self, Image_dir1, Image_dir2, Label_dir, crop=True, augment = True, angle = 30):
164 | super(DA_DatasetFromFolder, self).__init__()
165 | # Build the list of image files
166 | datalist = os.listdir(Image_dir1)
167 | self.image_filenames1 = [join(Image_dir1, x) for x in datalist if is_image_file(x)]
168 | self.image_filenames2 = [join(Image_dir2, x) for x in datalist if is_image_file(x)]
169 | self.label_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
170 | self.data_augment = trainImageAug(crop=crop, augment = augment, angle=angle)
171 | self.img_transform = get_transform(convert=True, normalize=True)
172 | self.lab_transform = get_transform()
173 |
174 | def __getitem__(self, index):
175 | image1 = Image.open(self.image_filenames1[index]).convert('RGB')
176 | image2 = Image.open(self.image_filenames2[index]).convert('RGB')
177 | label = Image.open(self.label_filenames[index])
178 | image1, image2, label = self.data_augment(image1, image2, label)
179 | image1, image2 = self.img_transform(image1), self.img_transform(image2)
180 | label = self.lab_transform(label)
181 | label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
182 | return image1, image2, label
183 |
184 | def __len__(self):
185 | return len(self.image_filenames1)
--------------------------------------------------------------------------------
/loss/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/loss/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/losses.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/loss/__pycache__/losses.cpython-39.pyc
--------------------------------------------------------------------------------
/loss/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | class cross_entropy(nn.Module):
6 | def __init__(self, weight=None, reduction='mean',ignore_index=256):
7 | super(cross_entropy, self).__init__()
8 | self.weight = weight
9 | self.ignore_index =ignore_index
10 | self.reduction = reduction
11 |
12 |
13 | def forward(self,input, target):
14 | target = target.long()
15 | if target.dim() == 4:
16 | target = torch.squeeze(target, dim=1)
17 | if input.shape[-1] != target.shape[-1]:
18 | input = F.interpolate(input, size=target.shape[1:], mode='bilinear',align_corners=True)
19 |
20 | return F.cross_entropy(input=input, target=target, weight=self.weight,
21 | ignore_index=self.ignore_index, reduction=self.reduction)
22 |
--------------------------------------------------------------------------------
/model/MSCANet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange
4 | from torch import nn
5 |
6 |
7 | class Residual(nn.Module):
8 | def __init__(self, fn):
9 | super().__init__()
10 | self.fn = fn
11 | def forward(self, x, **kwargs):
12 | return self.fn(x, **kwargs) + x
13 |
14 |
15 | class Residual2(nn.Module):
16 | def __init__(self, fn):
17 | super().__init__()
18 | self.fn = fn
19 | def forward(self, x, x2, **kwargs):
20 | return self.fn(x, x2, **kwargs) + x
21 |
22 |
23 | class PreNorm(nn.Module):
24 | def __init__(self, dim, fn):
25 | super().__init__()
26 | self.norm = nn.LayerNorm(dim)
27 | self.fn = fn
28 | def forward(self, x, **kwargs):
29 | return self.fn(self.norm(x), **kwargs)
30 |
31 | class FeedForward(nn.Module):
32 | def __init__(self, dim, hidden_dim, dropout = 0.):
33 | super().__init__()
34 | self.net = nn.Sequential(
35 | nn.Linear(dim, hidden_dim),
36 | nn.GELU(),
37 | nn.Dropout(dropout),
38 | nn.Linear(hidden_dim, dim),
39 | nn.Dropout(dropout)
40 | )
41 | def forward(self, x):
42 | return self.net(x)
43 |
44 | class Attention(nn.Module):
45 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
46 | super().__init__()
47 | inner_dim = dim_head * heads
48 | project_out = not (heads == 1 and dim_head == dim)
49 |
50 | self.heads = heads
51 | self.scale = dim_head ** -0.5
52 |
53 | self.attend = nn.Softmax(dim = -1)
54 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
55 |
56 | self.to_out = nn.Sequential(
57 | nn.Linear(inner_dim, dim),
58 | nn.Dropout(dropout)
59 | ) if project_out else nn.Identity()
60 |
61 | def forward(self, x):
62 | qkv = self.to_qkv(x).chunk(3, dim = -1)
63 |
64 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
65 |
66 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
67 |
68 | attn = self.attend(dots)
69 |
70 | out = torch.matmul(attn, v)
71 |
72 | out = rearrange(out, 'b h n d -> b n (h d)')
73 | out = self.to_out(out)
74 |
75 | return out
76 |
77 | class Transformer(nn.Module):
78 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
79 | super().__init__()
80 | self.layers = nn.ModuleList([])
81 | for _ in range(depth):
82 | self.layers.append(nn.ModuleList([
83 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
84 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
85 | ]))
86 | def forward(self, x):
87 | for attn, ff in self.layers:
88 | x = attn(x) + x
89 | x = ff(x) + x
90 | return x
91 |
92 |
93 |
94 | class Cross_Attention(nn.Module):
95 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True):
96 | super().__init__()
97 | inner_dim = dim_head * heads
98 | self.heads = heads
99 | self.scale = dim ** -0.5
100 |
101 | self.softmax = softmax
102 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
103 | self.to_k = nn.Linear(dim, inner_dim, bias=False)
104 | self.to_v = nn.Linear(dim, inner_dim, bias=False)
105 |
106 | self.to_out = nn.Sequential(
107 | nn.Linear(inner_dim, dim),
108 | nn.Dropout(dropout)
109 | )
110 |
111 | def forward(self, x, m, mask = None):
112 |
113 | b, n, _, h = *x.shape, self.heads
114 | q = self.to_q(x)
115 | k = self.to_k(m)
116 | v = self.to_v(m)
117 |
118 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v])
119 |
120 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
121 | mask_value = -torch.finfo(dots.dtype).max
122 |
123 | if mask is not None:
124 | mask = F.pad(mask.flatten(1), (1, 0), value = True)
125 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
126 | mask = mask[:, None, :] * mask[:, :, None]
127 | dots.masked_fill_(~mask, mask_value)
128 | del mask
129 |
130 | if self.softmax:
131 | attn = dots.softmax(dim=-1)
132 | else:
133 | attn = dots
134 |
135 | out = torch.einsum('bhij,bhjd->bhid', attn, v)
136 | out = rearrange(out, 'b h n d -> b n (h d)')
137 | out = self.to_out(out)
138 |
139 | return out
140 |
141 | class PreNorm2(nn.Module):
142 | def __init__(self, dim, fn):
143 | super().__init__()
144 | self.norm = nn.LayerNorm(dim)
145 | self.fn = fn
146 | def forward(self, x, x2, **kwargs):
147 | return self.fn(self.norm(x), self.norm(x2), **kwargs)
148 |
149 |
150 | class TransformerDecoder(nn.Module):
151 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True):
152 | super().__init__()
153 | self.layers = nn.ModuleList([])
154 | for _ in range(depth):
155 | self.layers.append(nn.ModuleList([
156 | Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads,
157 | dim_head = dim_head, dropout = dropout,
158 | softmax=softmax))),
159 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
160 | ]))
161 | def forward(self, x, m, mask = None):
162 | """target(query), memory"""
163 | for attn, ff in self.layers:
164 | x = attn(x, m, mask = mask)
165 | x = ff(x)
166 | return x
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 | import math
7 | from torch.autograd import Variable
8 |
9 | class Model(nn.Module):
10 | def __init__(self, args, ckp):
11 | super(Model, self).__init__()
12 | print('Making model...')
13 |
14 | self.scale = args.scale
15 | self.idx_scale = 0
16 | self.self_ensemble = args.self_ensemble
17 | self.chop = args.chop
18 | self.precision = args.precision
19 | self.cpu = args.cpu
20 | self.device = torch.device('cpu' if args.cpu else 'cuda')
21 | self.n_GPUs = args.n_GPUs
22 | self.save_models = args.save_models
23 |
24 | module = import_module('model.' + args.model.lower())
25 | self.model = module.make_model(args).to(self.device)
26 | if args.precision == 'half': self.model.half()
27 |
28 | if not args.cpu and args.n_GPUs > 1:
29 | self.model = nn.DataParallel(self.model, range(args.n_GPUs))
30 |
31 | self.load(
32 | ckp.dir,
33 | pre_train=args.pre_train,
34 | resume=args.resume,
35 | cpu=args.cpu
36 | )
37 | print(self.model, file=ckp.log_file)
38 |
39 | def forward(self, x, scale, pos_mat):
40 | self.scale = scale
41 | target = self.get_model()
42 | if hasattr(target, 'set_scale'):
43 | target.set_scale(scale)
44 |
45 | if self.self_ensemble and not self.training:
46 | if self.chop:
47 | forward_function = self.forward_chop
48 | else:
49 | forward_function = self.model.forward
50 |
51 | return self.forward_x8(x, forward_function)
52 | elif self.chop and not self.training:
53 | return self.forward_chop(x,pos_mat)
54 | else:
55 | return self.model(x,pos_mat)
56 |
57 | def get_model(self):
58 | if self.n_GPUs <= 1 or self.cpu:
59 | return self.model
60 | else:
61 | return self.model.module
62 |
63 | def state_dict(self, **kwargs):
64 | target = self.get_model()
65 | return target.state_dict(**kwargs)
66 |
67 | def save(self, apath, epoch, is_best=False):
68 | target = self.get_model()
69 | torch.save(
70 | target.state_dict(),
71 | os.path.join(apath, 'model', 'model_latest.pt')
72 | )
73 | if is_best:
74 | torch.save(
75 | target.state_dict(),
76 | os.path.join(apath, 'model', 'model_best.pt')
77 | )
78 |
79 | if self.save_models:
80 | torch.save(
81 | target.state_dict(),
82 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
83 | )
84 |
85 | def load(self, apath, pre_train='.', resume=-1, cpu=False):
86 | if cpu:
87 | kwargs = {'map_location': lambda storage, loc: storage}
88 | else:
89 | kwargs = {}
90 |
91 | if resume == -1:
92 | self.get_model().load_state_dict(
93 | torch.load(
94 | os.path.join(apath, 'model', 'model_latest.pt'),
95 | **kwargs
96 | ),
97 | strict=False
98 | )
99 | elif resume == 0:
100 | if pre_train != '.':
101 | print('Loading model from {}'.format(pre_train))
102 | self.get_model().load_state_dict(
103 | torch.load(pre_train, **kwargs),
104 | strict=False
105 | )
106 | print('load_model_mode=1')
107 | else:
108 | self.get_model().load_state_dict(
109 | torch.load(
110 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
111 | **kwargs
112 | ),
113 | strict=False
114 | )
115 | print('load_model_mode=2')
116 |
117 | def forward_chop(self, x, pos_mat, shave=10, min_size=160000):
118 | scale = self.scale[self.idx_scale]
119 | n_GPUs = min(self.n_GPUs, 4)
120 | b, c, h, w = x.size()
121 | h_half, w_half = h // 2, w // 2
122 | h_size, w_size = h_half + shave, w_half + shave
123 | lr_list = [
124 | x[:, :, 0:h_size, 0:w_size],
125 | x[:, :, 0:h_size, (w - w_size):w],
126 | x[:, :, (h - h_size):h, 0:w_size],
127 | x[:, :, (h - h_size):h, (w - w_size):w]]
128 |
129 | if w_size * h_size < min_size:
130 | sr_list = []
131 | for i in range(0, 4, n_GPUs):
132 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
133 | sr_batch = self.model(lr_batch, pos_mat)
134 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
135 | else:
136 | sr_list = [
137 | self.forward_chop(patch, pos_mat, shave=shave, min_size=min_size) \
138 | for patch in lr_list
139 | ]
140 | scale = math.ceil(scale)
141 | h, w = scale * h, scale * w
142 | h_half, w_half = scale * h_half, scale * w_half
143 | h_size, w_size = scale * h_size, scale * w_size
144 | shave *= scale
145 |
146 | output = x.new(b, c, h, w)
147 | output[:, :, 0:h_half, 0:w_half] \
148 | = sr_list[0][:, :, 0:h_half, 0:w_half]
149 | output[:, :, 0:h_half, w_half:w] \
150 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
151 | output[:, :, h_half:h, 0:w_half] \
152 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
153 | output[:, :, h_half:h, w_half:w] \
154 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
155 |
156 | return output
157 |
158 | def forward_x8(self, x, forward_function):
159 | def _transform(v, op):
160 | if self.precision != 'single': v = v.float()
161 |
162 | v2np = v.data.cpu().numpy()
163 | if op == 'v':
164 | tfnp = v2np[:, :, :, ::-1].copy()
165 | elif op == 'h':
166 | tfnp = v2np[:, :, ::-1, :].copy()
167 | elif op == 't':
168 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
169 |
170 | ret = torch.Tensor(tfnp).to(self.device)
171 | if self.precision == 'half': ret = ret.half()
172 |
173 | return ret
174 |
175 | lr_list = [x]
176 | for tf in 'v', 'h', 't':
177 | lr_list.extend([_transform(t, tf) for t in lr_list])
178 |
179 | sr_list = [forward_function(aug) for aug in lr_list]
180 | for i in range(len(sr_list)):
181 | if i > 3:
182 | sr_list[i] = _transform(sr_list[i], 't')
183 | if i % 4 > 1:
184 | sr_list[i] = _transform(sr_list[i], 'h')
185 | if (i % 4) % 2 == 1:
186 | sr_list[i] = _transform(sr_list[i], 'v')
187 |
188 | output_cat = torch.cat(sr_list, dim=0)
189 | output = output_cat.mean(dim=0, keepdim=True)
190 |
191 | return output
192 |
193 |
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/backbone.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/backbone.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/modules.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/modules.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/network.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/network.cpython-39.pyc
--------------------------------------------------------------------------------
/model/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.model_zoo as model_zoo
5 | import math
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
9 | """3x3 convolution with padding"""
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=dilation, groups=groups, bias=False, dilation=dilation)
12 |
13 | model_urls = {
14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
19 | }
20 |
21 | def ResNet34(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
22 | """
23 | output, low_level_feat:
24 | 512, 64
25 | """
26 | print(in_c)
27 | model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
28 | if in_c != 3:
29 | pretrained = False
30 | if pretrained:
31 | model._load_pretrained_model(model_urls['resnet34'])
32 | return model
33 |
34 |
35 | def ResNet18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
36 | """
37 | output, low_level_feat:
38 | 512, 256, 128, 64, 64
39 | """
40 | model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c)
41 | if in_c !=3:
42 | pretrained=False
43 | if pretrained:
44 | print(122222222)
45 | model._load_pretrained_model(model_urls['resnet18'])
46 | return model
47 |
48 |
49 | def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
50 | """
51 | output, low_level_feat:
52 | 2048, 256
53 | """
54 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
55 | if in_c !=3:
56 | pretrained=False
57 | if pretrained:
58 | model._load_pretrained_model(model_urls['resnet50'])
59 | return model
60 |
61 |
62 | class BasicBlock(nn.Module):
63 | expansion = 1
64 |
65 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
66 | super(BasicBlock, self).__init__()
67 |
68 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
69 | dilation=dilation, padding=dilation, bias=False)
70 | self.bn1 = BatchNorm(planes)
71 | self.relu = nn.ReLU(inplace=True)
72 | # self.do1 = nn.Dropout2d(p=0.2)
73 |
74 | self.conv2 = conv3x3(planes, planes)
75 | self.bn2 = BatchNorm(planes)
76 | self.downsample = downsample
77 | self.stride = stride
78 |
79 | def forward(self, x):
80 | identity = x
81 |
82 | out = self.conv1(x)
83 | out = self.bn1(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv2(out)
87 | out = self.bn2(out)
88 |
89 | if self.downsample is not None:
90 | identity = self.downsample(x)
91 |
92 | out += identity
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 | class SELayer(nn.Module):
98 | def __init__(self, channel, reduction=16):
99 | super(SELayer, self).__init__()
100 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
101 | self.fc = nn.Sequential(
102 | nn.Linear(channel, channel // reduction, bias=False),
103 | nn.ReLU(inplace=True),
104 | nn.Linear(channel // reduction, channel, bias=False),
105 | nn.Sigmoid()
106 | )
107 |
108 | def forward(self, x):
109 | b, c, _, _ = x.size()
110 | y = self.avg_pool(x).view(b, c)
111 | y = self.fc(y).view(b, c, 1, 1)
112 | return x * y.expand_as(x)
113 |
114 | class Bottleneck(nn.Module):
115 | expansion = 4
116 |
117 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
118 | super(Bottleneck, self).__init__()
119 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
120 | self.bn1 = BatchNorm(planes)
121 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
122 | dilation=dilation, padding=dilation, bias=False)
123 | self.bn2 = BatchNorm(planes)
124 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
125 | self.bn3 = BatchNorm(planes * 4)
126 | self.relu = nn.ReLU()
127 | self.downsample = downsample
128 | self.stride = stride
129 | self.dilation = dilation
130 |
131 |
132 |
133 | def forward(self, x):
134 | residual = x
135 |
136 | out = self.conv1(x)
137 | out = self.bn1(out)
138 | out = self.relu(out)
139 |
140 | out = self.conv2(out)
141 | out = self.bn2(out)
142 | out = self.relu(out)
143 |
144 | out = self.conv3(out)
145 | out = self.bn3(out)
146 |
147 | if self.downsample is not None:
148 | residual = self.downsample(x)
149 |
150 | out += residual
151 | out = self.relu(out)
152 |
153 | return out
154 |
155 | class PA(nn.Module):
156 | def __init__(self, inchan = 512, out_chan = 32):
157 | super().__init__()
158 | self.conv = nn.Conv2d(inchan, out_chan, kernel_size=3, padding=1, bias=False)
159 | self.bn = nn.BatchNorm2d(out_chan)
160 | self.re = nn.ReLU()
161 | self.do = nn.Dropout2d(0.2)
162 |
163 | self.pa_conv = nn.Conv2d(out_chan, out_chan, kernel_size=1, padding=0, groups=out_chan)
164 | self.sigmoid = nn.Sigmoid()
165 |
166 | def forward(self, x):
167 | x0 = self.conv(x)
168 | x = self.do(self.re(self.bn(x0)))
169 | return x0 *self.sigmoid(self.pa_conv(x))
170 |
171 |
172 | class ResNet(nn.Module):
173 |
174 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, in_c=3):
175 |
176 | self.inplanes = 64
177 | self.in_c = in_c
178 | print('in_c: ',self.in_c)
179 | super(ResNet, self).__init__()
180 | blocks = [1, 2, 4]
181 | if output_stride == 32:
182 | strides = [1, 2, 2, 2]
183 | dilations = [1, 1, 1, 1]
184 | elif output_stride == 16:
185 | strides = [1, 2, 2, 1]
186 | dilations = [1, 1, 1, 2]
187 | elif output_stride == 8:
188 | strides = [1, 2, 1, 1]
189 | dilations = [1, 1, 2, 4]
190 | elif output_stride == 4:
191 | strides = [1, 1, 1, 1]
192 | dilations = [1, 2, 4, 8]
193 | else:
194 | raise NotImplementedError
195 |
196 | # Modules
197 | self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=3,
198 | bias=False)
199 | self.bn1 = BatchNorm(64)
200 | self.relu = nn.ReLU(inplace=True)
201 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
202 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
203 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
204 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
205 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
206 | self._init_weight()
207 |
208 | self.pos_s16 = PA(2048, 32)
209 | self.pos_s8 = PA(512, 32)
210 | self.pos_s4 = PA(256, 32)
211 |
212 | # self.pos_s16 = PA(512, 32)
213 | # self.pos_s8 = PA(128, 32)
214 | # self.pos_s4 = PA(64, 32)
215 |
216 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
217 | downsample = None
218 | if stride != 1 or self.inplanes != planes * block.expansion:
219 | downsample = nn.Sequential(
220 | nn.Conv2d(self.inplanes, planes * block.expansion,
221 | kernel_size=1, stride=stride, bias=False),
222 | BatchNorm(planes * block.expansion),
223 | )
224 |
225 | layers = []
226 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
227 | self.inplanes = planes * block.expansion
228 | for i in range(1, blocks):
229 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
230 |
231 | return nn.Sequential(*layers)
232 |
233 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
234 | downsample = None
235 | if stride != 1 or self.inplanes != planes * block.expansion:
236 | downsample = nn.Sequential(
237 | nn.Conv2d(self.inplanes, planes * block.expansion,
238 | kernel_size=1, stride=stride, bias=False),
239 | BatchNorm(planes * block.expansion),
240 | )
241 |
242 | layers = []
243 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
244 | downsample=downsample, BatchNorm=BatchNorm))
245 | self.inplanes = planes * block.expansion
246 | for i in range(1, len(blocks)):
247 | layers.append(block(self.inplanes, planes, stride=1,
248 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
249 |
250 | return nn.Sequential(*layers)
251 |
252 | def forward(self, input):
253 | x = self.conv1(input)
254 | x = self.bn1(x)
255 | x = self.relu(x)
256 | x = self.maxpool(x) # | 4
257 |
258 |
259 | x = self.layer1(x) # | 4
260 | low_level_feat2 = x
261 |
262 | x = self.layer2(x) # | 8
263 | low_level_feat3 = x
264 |
265 | x = self.layer3(x) # | 16
266 | x = self.layer4(x) # | 16
267 |
268 | out_s16, out_s8, out_s4 = self.pos_s16(x), self.pos_s8(low_level_feat3), self.pos_s4(low_level_feat2)
269 | return out_s16, out_s8, out_s4
270 |
271 |
272 | def _init_weight(self):
273 | for m in self.modules():
274 | if isinstance(m, nn.Conv2d):
275 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
276 | m.weight.data.normal_(0, math.sqrt(2. / n))
277 | elif isinstance(m, nn.BatchNorm2d):
278 | m.weight.data.fill_(1)
279 | m.bias.data.zero_()
280 |
281 | def _load_pretrained_model(self, model_path):
282 | pretrain_dict = model_zoo.load_url(model_path)
283 | model_dict = {}
284 | state_dict = self.state_dict()
285 | for k, v in pretrain_dict.items():
286 | if k in state_dict:
287 | model_dict[k] = v
288 | state_dict.update(model_dict)
289 | self.load_state_dict(state_dict)
290 |
291 | def build_backbone(backbone, output_stride, BatchNorm, in_c=3):
292 | if backbone == 'resnet50':
293 | return ResNet50(output_stride, BatchNorm, in_c=in_c)
294 | elif backbone == 'resnet34':
295 | return ResNet34(output_stride, BatchNorm, in_c=in_c)
296 | elif backbone == 'resnet18':
297 | return ResNet18(output_stride, BatchNorm, in_c=in_c)
298 | else:
299 | raise NotImplementedError
300 |
301 |
--------------------------------------------------------------------------------
/model/dtcdscn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torchvision.models import ResNet
5 | import torch.nn.functional as F
6 | from functools import partial
7 |
8 |
9 | nonlinearity = partial(F.relu,inplace=True)
10 |
11 | class SELayer(nn.Module):
12 | def __init__(self, channel, reduction=16):
13 | super(SELayer, self).__init__()
14 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
15 | self.fc = nn.Sequential(
16 | nn.Linear(channel, channel // reduction, bias=False),
17 | nn.ReLU(inplace=True),
18 | nn.Linear(channel // reduction, channel, bias=False),
19 | nn.Sigmoid()
20 | )
21 |
22 | def forward(self, x):
23 | b, c, _, _ = x.size()
24 | y = self.avg_pool(x).view(b, c)
25 | y = self.fc(y).view(b, c, 1, 1)
26 | return x * y.expand_as(x)
27 |
28 | class Dblock_more_dilate(nn.Module):
29 | def __init__(self, channel):
30 | super(Dblock_more_dilate, self).__init__()
31 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
32 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
33 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
34 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
35 | self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
36 | for m in self.modules():
37 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
38 | if m.bias is not None:
39 | m.bias.data.zero_()
40 |
41 | def forward(self, x):
42 | dilate1_out = nonlinearity(self.dilate1(x))
43 | dilate2_out = nonlinearity(self.dilate2(dilate1_out))
44 | dilate3_out = nonlinearity(self.dilate3(dilate2_out))
45 | dilate4_out = nonlinearity(self.dilate4(dilate3_out))
46 | dilate5_out = nonlinearity(self.dilate5(dilate4_out))
47 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out
48 | return out
49 | class Dblock(nn.Module):
50 | def __init__(self, channel):
51 | super(Dblock, self).__init__()
52 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
53 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
54 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
55 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
56 | # self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
57 | for m in self.modules():
58 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
59 | if m.bias is not None:
60 | m.bias.data.zero_()
61 |
62 | def forward(self, x):
63 | dilate1_out = nonlinearity(self.dilate1(x))
64 | dilate2_out = nonlinearity(self.dilate2(dilate1_out))
65 | dilate3_out = nonlinearity(self.dilate3(dilate2_out))
66 | dilate4_out = nonlinearity(self.dilate4(dilate3_out))
67 | # dilate5_out = nonlinearity(self.dilate5(dilate4_out))
68 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out # + dilate5_out
69 | return out
70 |
71 | def conv3x3(in_planes, out_planes, stride=1):
72 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
73 |
74 | class SEBasicBlock(nn.Module):
75 | expansion = 1
76 |
77 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
78 | super(SEBasicBlock, self).__init__()
79 | self.conv1 = conv3x3(inplanes, planes, stride)
80 | self.bn1 = nn.BatchNorm2d(planes)
81 | self.relu = nn.ReLU(inplace=True)
82 | self.conv2 = conv3x3(planes, planes, 1)
83 | self.bn2 = nn.BatchNorm2d(planes)
84 | self.se = SELayer(planes, reduction)
85 | self.downsample = downsample
86 | self.stride = stride
87 |
88 | def forward(self, x):
89 | residual = x
90 | out = self.conv1(x)
91 | out = self.bn1(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv2(out)
95 | out = self.bn2(out)
96 | out = self.se(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 | class DecoderBlock(nn.Module):
107 | def __init__(self, in_channels, n_filters):
108 | super(DecoderBlock,self).__init__()
109 |
110 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
111 | self.norm1 = nn.BatchNorm2d(in_channels // 4)
112 | self.relu1 = nonlinearity
113 | self.scse = SCSEBlock(in_channels // 4)
114 |
115 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1)
116 | self.norm2 = nn.BatchNorm2d(in_channels // 4)
117 | self.relu2 = nonlinearity
118 |
119 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
120 | self.norm3 = nn.BatchNorm2d(n_filters)
121 | self.relu3 = nonlinearity
122 |
123 | def forward(self, x):
124 | x = self.conv1(x)
125 | x = self.norm1(x)
126 | x = self.relu1(x)
127 | y = self.scse(x)
128 | x = x + y
129 | x = self.deconv2(x)
130 | x = self.norm2(x)
131 | x = self.relu2(x)
132 | x = self.conv3(x)
133 | x = self.norm3(x)
134 | x = self.relu3(x)
135 | return x
136 |
137 | class SCSEBlock(nn.Module):
138 | def __init__(self, channel, reduction=16):
139 | super(SCSEBlock, self).__init__()
140 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
141 |
142 | '''self.channel_excitation = nn.Sequential(nn.(channel, int(channel//reduction)),
143 | nn.ReLU(inplace=True),
144 | nn.Linear(int(channel//reduction), channel),
145 | nn.Sigmoid())'''
146 | self.channel_excitation = nn.Sequential(nn.Conv2d(channel, int(channel//reduction), kernel_size=1,
147 | stride=1, padding=0, bias=False),
148 | nn.ReLU(inplace=True),
149 | nn.Conv2d(int(channel // reduction), channel,kernel_size=1,
150 | stride=1, padding=0, bias=False),
151 | nn.Sigmoid())
152 |
153 | self.spatial_se = nn.Sequential(nn.Conv2d(channel, 1, kernel_size=1,
154 | stride=1, padding=0, bias=False),
155 | nn.Sigmoid())
156 |
157 | def forward(self, x):
158 | bahs, chs, _, _ = x.size()
159 |
160 | # Returns a new tensor with the same data as the self tensor but of a different size.
161 | chn_se = self.avg_pool(x)
162 | chn_se = self.channel_excitation(chn_se)
163 | chn_se = torch.mul(x, chn_se)
164 | spa_se = self.spatial_se(x)
165 | spa_se = torch.mul(x, spa_se)
166 | return torch.add(chn_se, spa_se)  # sum of the channel- and spatial-attention branches
167 |
168 | class CDNet_model(nn.Module):
169 | def __init__(self, in_channels=3, block=SEBasicBlock, layers=[3, 4, 6, 3], num_classes=7):
170 | super(CDNet_model, self).__init__()
171 |
172 | filters = [64, 128, 256, 512]
173 | self.inplanes = 64
174 | self.firstconv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,
175 | bias=False)
176 | self.firstbn = nn.BatchNorm2d(64)
177 | self.firstrelu = nn.ReLU(inplace=True)
178 | self.firstmaxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
179 | self.encoder1 = self._make_layer(block, 64, layers[0])
180 | self.encoder2 = self._make_layer(block, 128, layers[1], stride=2)
181 | self.encoder3 = self._make_layer(block, 256, layers[2], stride=2)
182 | self.encoder4 = self._make_layer(block, 512, layers[3], stride=2)
183 |
184 | self.decoder4 = DecoderBlock(filters[3], filters[2])
185 | self.decoder3 = DecoderBlock(filters[2], filters[1])
186 | self.decoder2 = DecoderBlock(filters[1], filters[0])
187 | self.decoder1 = DecoderBlock(filters[0], filters[0])
188 |
189 | self.dblock_master = Dblock(512)
190 | self.dblock = Dblock(512)
191 |
192 | self.decoder4_master = DecoderBlock(filters[3], filters[2])
193 | self.decoder3_master = DecoderBlock(filters[2], filters[1])
194 | self.decoder2_master = DecoderBlock(filters[1], filters[0])
195 | self.decoder1_master = DecoderBlock(filters[0], filters[0])
196 |
197 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
198 | self.finalrelu1_master = nonlinearity
199 | self.finalconv2_master = nn.Conv2d(32, 32, 3, padding=1)
200 | self.finalrelu2_master = nonlinearity
201 | self.finalconv3_master = nn.Conv2d(32, 2, 3, padding=1)
202 |
203 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
204 | self.finalrelu1 = nonlinearity
205 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
206 | self.finalrelu2 = nonlinearity
207 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
208 |
209 | for m in self.modules():
210 | if isinstance(m, nn.Conv2d):
211 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
212 | m.weight.data.normal_(0, math.sqrt(2. / n))
213 | elif isinstance(m, nn.BatchNorm2d):
214 | m.weight.data.fill_(1)
215 | m.bias.data.zero_()
216 |
217 | def _make_layer(self, block, planes, blocks, stride=1):
218 | downsample = None
219 | if stride != 1 or self.inplanes != planes * block.expansion:
220 | downsample = nn.Sequential(
221 | nn.Conv2d(self.inplanes, planes * block.expansion,
222 | kernel_size=1, stride=stride, bias=False),
223 | nn.BatchNorm2d(planes * block.expansion),
224 | )
225 |
226 | layers = []
227 | layers.append(block(self.inplanes, planes, stride, downsample))
228 | self.inplanes = planes * block.expansion
229 | for i in range(1, blocks):
230 | layers.append(block(self.inplanes, planes))
231 |
232 | return nn.Sequential(*layers)
233 |
234 | def forward(self, x, y):
235 | # Encoder_1
236 | x = self.firstconv(x)
237 | x = self.firstbn(x)
238 | x = self.firstrelu(x)
239 | x = self.firstmaxpool(x)
240 |
241 | e1_x = self.encoder1(x)
242 | e2_x = self.encoder2(e1_x)
243 | e3_x = self.encoder3(e2_x)
244 | e4_x = self.encoder4(e3_x)
245 |
246 | # # Center_1
247 | e4_x_center = self.dblock(e4_x)
248 |
249 | # Decoder_1
250 | d4_x = self.decoder4(e4_x_center) + e3_x
251 | d3_x = self.decoder3(d4_x) + e2_x
252 | d2_x = self.decoder2(d3_x) + e1_x
253 | d1_x = self.decoder1(d2_x)
254 |
255 | out1 = self.finaldeconv1(d1_x)
256 | out1 = self.finalrelu1(out1)
257 | out1 = self.finalconv2(out1)
258 | out1 = self.finalrelu2(out1)
259 | out1 = self.finalconv3(out1)
260 |
261 | # Encoder_2
262 | y = self.firstconv(y)
263 | y = self.firstbn(y)
264 | y = self.firstrelu(y)
265 | y = self.firstmaxpool(y)
266 |
267 | e1_y = self.encoder1(y)
268 | e2_y = self.encoder2(e1_y)
269 | e3_y = self.encoder3(e2_y)
270 | e4_y = self.encoder4(e3_y)
271 |
272 | # # Center_2
273 | e4_y_center = self.dblock(e4_y)
274 |
275 | # Decoder_2
276 | d4_y = self.decoder4(e4_y_center) + e3_y
277 | d3_y = self.decoder3(d4_y) + e2_y
278 | d2_y = self.decoder2(d3_y) + e1_y
279 | d1_y = self.decoder1(d2_y)
280 |
281 | out2 = self.finaldeconv1(d1_y)
282 | out2 = self.finalrelu1(out2)
283 | out2 = self.finalconv2(out2)
284 | out2 = self.finalrelu2(out2)
285 | out2 = self.finalconv3(out2)
286 |
287 | # center_master
288 | e4 = self.dblock_master(e4_x - e4_y)
289 | # decoder_master
290 | d4 = self.decoder4_master(e4) + e3_x - e3_y
291 | d3 = self.decoder3_master(d4) + e2_x - e2_y
292 | d2 = self.decoder2_master(d3) + e1_x - e1_y
293 | d1 = self.decoder1_master(d2)
294 |
295 | out = self.finaldeconv1_master(d1)
296 | out = self.finalrelu1_master(out)
297 | out = self.finalconv2_master(out)
298 | out = self.finalrelu2_master(out)
299 | out = self.finalconv3_master(out)
300 |
301 | return out,out1,out2
302 |
303 | # return F.sigmoid(out),F.sigmoid(out1),F.sigmoid(out2)
304 |
305 |
306 |
307 | def CDNet34(in_channels, **kwargs):
308 |
309 | model = CDNet_model(in_channels, SEBasicBlock, [3, 4, 6, 3], **kwargs)
310 |
311 | return model
--------------------------------------------------------------------------------
/model/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from einops import rearrange
3 | from torch import nn
4 | import torch
5 |
6 | class Residual(nn.Module):
7 | def __init__(self, fn):
8 | super().__init__()
9 | self.fn = fn
10 | def forward(self, x, **kwargs):
11 | return self.fn(x, **kwargs) + x
12 |
13 |
14 | class Residual2(nn.Module):
15 | def __init__(self, fn):
16 | super().__init__()
17 | self.fn = fn
18 | def forward(self, x, x2, **kwargs):
19 | return self.fn(x, x2, **kwargs) + x
20 |
21 |
22 | class PreNorm(nn.Module):
23 | def __init__(self, dim, fn):
24 | super().__init__()
25 | self.norm = nn.LayerNorm(dim)
26 | self.fn = fn
27 | def forward(self, x, **kwargs):
28 | return self.fn(self.norm(x), **kwargs)
29 |
30 | class FeedForward(nn.Module):
31 | def __init__(self, dim, hidden_dim, dropout = 0.):
32 | super().__init__()
33 | self.net = nn.Sequential(
34 | nn.Linear(dim, hidden_dim),
35 | nn.GELU(),
36 | nn.Dropout(dropout),
37 | nn.Linear(hidden_dim, dim),
38 | nn.Dropout(dropout)
39 | )
40 | def forward(self, x):
41 | return self.net(x)
42 |
43 | class Attention(nn.Module):
44 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
45 | super().__init__()
46 | inner_dim = dim_head * heads
47 | project_out = not (heads == 1 and dim_head == dim)
48 |
49 | self.heads = heads
50 | self.scale = dim_head ** -0.5
51 |
52 | self.attend = nn.Softmax(dim = -1)
53 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
54 |
55 | self.to_out = nn.Sequential(
56 | nn.Linear(inner_dim, dim),
57 | nn.Dropout(dropout)
58 | ) if project_out else nn.Identity()
59 |
60 | def forward(self, x):
61 | qkv = self.to_qkv(x).chunk(3, dim = -1)
62 |
63 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
64 |
65 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
66 |
67 | attn = self.attend(dots)
68 |
69 | out = torch.matmul(attn, v)
70 |
71 | out = rearrange(out, 'b h n d -> b n (h d)')
72 | out = self.to_out(out)
73 |
74 | return out
75 |
76 | class Transformer(nn.Module):
77 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
78 | super().__init__()
79 | self.layers = nn.ModuleList([])
80 | for _ in range(depth):
81 | self.layers.append(nn.ModuleList([
82 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
83 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
84 | ]))
85 | def forward(self, x):
86 | for attn, ff in self.layers:
87 | x = attn(x) + x
88 | x = ff(x) + x
89 | return x
90 |
91 |
92 |
93 | class Cross_Attention(nn.Module):
94 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True):
95 | super().__init__()
96 | inner_dim = dim_head * heads
97 | self.heads = heads
98 | self.scale = dim ** -0.5
99 |
100 | self.softmax = softmax
101 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
102 | self.to_k = nn.Linear(dim, inner_dim, bias=False)
103 | self.to_v = nn.Linear(dim, inner_dim, bias=False)
104 |
105 | self.to_out = nn.Sequential(
106 | nn.Linear(inner_dim, dim),
107 | nn.Dropout(dropout)
108 | )
109 |
110 | def forward(self, x, m, mask = None):
111 |
112 | b, n, _, h = *x.shape, self.heads
113 | q = self.to_q(x)
114 | k = self.to_k(m)
115 | v = self.to_v(m)
116 |
117 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v])
118 |
119 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
120 | mask_value = -torch.finfo(dots.dtype).max
121 |
122 | if mask is not None:
123 | mask = F.pad(mask.flatten(1), (1, 0), value = True)
124 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
125 | mask = mask[:, None, :] * mask[:, :, None]
126 | dots.masked_fill_(~mask, mask_value)
127 | del mask
128 |
129 | if self.softmax:
130 | attn = dots.softmax(dim=-1)
131 | else:
132 | attn = dots
133 |
134 | out = torch.einsum('bhij,bhjd->bhid', attn, v)
135 | out = rearrange(out, 'b h n d -> b n (h d)')
136 | out = self.to_out(out)
137 |
138 | return out
139 |
140 | class PreNorm2(nn.Module):
141 | def __init__(self, dim, fn):
142 | super().__init__()
143 | self.norm = nn.LayerNorm(dim)
144 | self.fn = fn
145 | def forward(self, x, x2, **kwargs):
146 | return self.fn(self.norm(x), self.norm(x2), **kwargs)
147 |
148 |
149 | class TransformerDecoder(nn.Module):
150 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True):
151 | super().__init__()
152 | self.layers = nn.ModuleList([])
153 | for _ in range(depth):
154 | self.layers.append(nn.ModuleList([
155 | Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads,
156 | dim_head = dim_head, dropout = dropout,
157 | softmax=softmax))),
158 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
159 | ]))
160 | def forward(self, x, m, mask = None):
161 | """target(query), memory"""
162 | for attn, ff in self.layers:
163 | x = attn(x, m, mask = mask)
164 | x = ff(x)
165 | return x
--------------------------------------------------------------------------------
/model/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from mmcv.runner import BaseModule
5 | from torch import Tensor
6 | from torch import Tensor, reshape, stack
7 | from .backbone import build_backbone
8 | from .modules import TransformerDecoder, Transformer
9 | from einops import rearrange
10 | from torch.nn import Upsample
11 |
12 |
13 | class SpatialAttention(nn.Module):
14 | def __init__(self, kernel_size=7):
15 | super(SpatialAttention, self).__init__()
16 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
17 | padding = 3 if kernel_size == 7 else 1
18 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) # 7,3 3,1
19 | self.sigmoid = nn.Sigmoid()
20 | def forward(self, x):
21 | avg_out = torch.mean(x, dim=1, keepdim=True)
22 | max_out, _ = torch.max(x, dim=1, keepdim=True)
23 | x = torch.cat([avg_out, max_out], dim=1)
24 | x = self.conv1(x)
25 | return self.sigmoid(x)
26 |
27 | class ChannelExchange(BaseModule):
28 | """Swap every p-th channel between the two temporal feature maps."""
29 |
30 | def __init__(self, p=2):
31 | super(ChannelExchange, self).__init__()
32 | self.p = p
33 | self.sam = SpatialAttention()
34 | def forward(self, x1, x2):
35 | N, c, h, w = x1.shape
36 |
37 | exchange_map = torch.arange(c) % self.p == 0
38 | exchange_mask = exchange_map.unsqueeze(0).expand((N, -1))
39 |
40 | out_x1, out_x2 = torch.zeros_like(x1), torch.zeros_like(x2)
41 | out_x1[~exchange_mask, ...] = x1[~exchange_mask, ...]
42 | out_x2[~exchange_mask, ...] = x2[~exchange_mask, ...]
43 | out_x1[exchange_mask, ...] = x2[exchange_mask, ...]
44 | out_x2[exchange_mask, ...] = x1[exchange_mask, ...]
45 |
46 | return out_x1, out_x2
47 |
48 |
49 | class SpatialExchange(BaseModule):
50 | """Swap every p-th column between the two temporal feature maps."""
51 |
52 | def __init__(self, p=2):
53 | super(SpatialExchange, self).__init__()
54 | self.p = p
55 | self.sam = SpatialAttention()
56 | def forward(self, x1, x2):
57 | N, c, h, w = x1.shape
58 | exchange_mask = torch.arange(w) % self.p == 0
59 |
60 | out_x1, out_x2 = torch.zeros_like(x1), torch.zeros_like(x2)
61 | out_x1[..., ~exchange_mask] = x1[..., ~exchange_mask]
62 | out_x2[..., ~exchange_mask] = x2[..., ~exchange_mask]
63 | out_x1[..., exchange_mask] = x2[..., exchange_mask]
64 | out_x2[..., exchange_mask] = x1[..., exchange_mask]
65 |
66 | return out_x1, out_x2
67 |
68 | class token_encoder(nn.Module):
69 | def __init__(self, in_chan = 32, token_len = 4, heads = 8):
70 | super(token_encoder, self).__init__()
71 | self.token_len = token_len
72 | self.conv_a = nn.Conv2d(in_chan, token_len, kernel_size=1, padding=0)
73 | self.pos_embedding = nn.Parameter(torch.randn(1, token_len, in_chan))
74 | self.transformer = Transformer(dim=in_chan, depth=1, heads=heads, dim_head=64, mlp_dim=64, dropout=0)
75 | def forward(self, x):
76 | b, c, h, w = x.shape
77 |
78 | spatial_attention = self.conv_a(x)
79 | spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous()
80 | spatial_attention = torch.softmax(spatial_attention, dim=-1)
81 | x = x.view([b, c, -1]).contiguous()
82 |
83 | tokens = torch.einsum('bln, bcn->blc', spatial_attention, x)
84 |
85 | tokens += self.pos_embedding
86 | x = self.transformer(tokens)
87 | return x
88 |
89 | class token_decoder(nn.Module):
90 | def __init__(self, in_chan = 32, size = 32, heads = 8):
91 | super(token_decoder, self).__init__()
92 | self.pos_embedding_decoder = nn.Parameter(torch.randn(1, in_chan, size, size))
93 | self.transformer_decoder = TransformerDecoder(dim=in_chan, depth=1, heads=heads, dim_head=in_chan, mlp_dim=in_chan*2, dropout=0, softmax=True)
94 |
95 | def forward(self, x, m):
96 | b, c, h, w = x.shape
97 | x = x + self.pos_embedding_decoder
98 | x = rearrange(x, 'b c h w -> b (h w) c')
99 | x = self.transformer_decoder(x, m)
100 | x = rearrange(x, 'b (h w) c -> b c h w', h=h)
101 | return x
102 |
103 |
104 | class context_aggregator(nn.Module):
105 | def __init__(self, in_chan=32, size=32):
106 | super(context_aggregator, self).__init__()
107 | self.token_encoder = token_encoder(in_chan=in_chan, token_len=4)
108 | self.token_decoder = token_decoder(in_chan = in_chan, size = size, heads = 8)
109 | def forward(self, feature):
110 | token = self.token_encoder(feature)
111 | out = self.token_decoder(feature, token)
112 | return out
113 |
114 | class Classifier(nn.Module):
115 | def __init__(self, in_chan=32, n_class=2):
116 | super(Classifier, self).__init__()
117 | self.head = nn.Sequential(
118 | nn.Conv2d(in_chan * 2, in_chan, kernel_size=3, padding=1, stride=1, bias=False),
119 | nn.BatchNorm2d(in_chan),
120 | nn.ReLU(),
121 | nn.Conv2d(in_chan, n_class, kernel_size=3, padding=1, stride=1))
122 | def forward(self, x):
123 | x = self.head(x)
124 | return x
125 |
126 | class Conv_stride(nn.Module):
127 | def __init__(self, in_channels=32):
128 | super(Conv_stride, self).__init__()
129 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1, bias=False)
130 | self.batchNorm = nn.BatchNorm2d(in_channels, momentum=0.1)
131 | self.relu = nn.ReLU(inplace=True)
132 | def forward(self, x):
133 | x = self.conv(x)
134 | x = self.batchNorm(x)
135 | x = self.relu(x)
136 | return x
137 |
138 | class Upsample(nn.Module):
139 | def __init__(self,scale_factor=2):
140 | super(Upsample, self).__init__()
141 | self.conv = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
142 | self.batchNorm = nn.BatchNorm2d(32, momentum=0.1)
143 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode="bicubic", align_corners=True)
144 | def forward(self, x):
145 | x = self.conv(x)
146 | x = self.batchNorm(x)
147 | x = self.upsample(x)
148 | return x
149 |
150 | class CDNet(nn.Module):
151 | def __init__(self, backbone='resnet50', output_stride=16, img_size = 512, img_chan=3, chan_num = 32, n_class =2):
152 | super(CDNet, self).__init__()
153 | BatchNorm = nn.BatchNorm2d
154 |
155 | self.backbone = build_backbone(backbone, output_stride, BatchNorm, img_chan)
156 |
157 | self.CA_s16 = context_aggregator(in_chan=chan_num, size=img_size//16)
158 | self.CA_s8 = context_aggregator(in_chan=chan_num, size=img_size//8)
159 | self.CA_s4 = context_aggregator(in_chan=chan_num, size=img_size//4)
160 |
161 | self.conv_s8 = nn.Conv2d(chan_num*2, chan_num, kernel_size=3, padding=1)
162 | self.conv_s4 = nn.Conv2d(chan_num*3, chan_num, kernel_size=3, padding=1)
163 |
164 | self.upsamplex2 = nn.Upsample(scale_factor=2, mode="bicubic", align_corners=True)
165 | self.upsamplex4 = nn.Upsample(scale_factor=4, mode="bicubic", align_corners=True)
166 |
167 | self.EX = ChannelExchange()
168 | self.SE = SpatialExchange()
169 |
170 | self.liner1 = nn.Linear(1, 1)
171 | self.liner2 = nn.Linear(1, 1)
172 | self.relu = nn.ReLU(inplace=True)
173 | self.sigmoid = nn.Sigmoid()
174 |
175 | self.classifier1 = Classifier(n_class = n_class)
176 | self.classifier2 = Classifier(n_class = n_class)
177 | self.classifier3 = Classifier(n_class = n_class)
178 |
179 | self.sam = SpatialAttention()
180 |
181 | self.conv_stride_1_1 = Conv_stride()
182 | self.conv_stride_1_2 = Conv_stride()
183 | self.conv_stride_1_3 = Conv_stride()
184 | self.conv_stride_2_1 = Conv_stride()
185 | self.conv_stride_2_2 = Conv_stride()
186 | self.conv_stride_2_3 = Conv_stride()
187 |
188 | self.upsample_1_1 = Upsample()
189 | self.upsample_1_2 = Upsample()
190 | self.upsample_1_3 = Upsample(scale_factor=4)
191 | self.upsample_2_1 = Upsample()
192 | self.upsample_2_2 = Upsample()
193 | self.upsample_2_3 = Upsample(scale_factor=4)
194 |
195 | def forward(self, img1, img2):
196 | # CNN backbone, feature extractor
197 | out1_s16, out1_s8, out1_s4 = self.backbone(img1)
198 | out2_s16, out2_s8, out2_s4 = self.backbone(img2)
199 |
200 | out1_s16, out2_s16 = self.EX(out1_s16, out2_s16)
201 | out1_s8, out2_s8 = self.EX(out1_s8, out2_s8)
202 | out1_s4, out2_s4 = self.SE(out1_s4, out2_s4)
203 |
204 | out1_s16 = out1_s16*self.sam(out1_s16)
205 | out2_s16 = out2_s16 * self.sam(out2_s16)
206 |
207 | out1_s4 = out1_s4*self.sam(out1_s4)
208 | out2_s4 = out2_s4 * self.sam(out2_s4)
209 |
210 | out2_s8 = out2_s8*self.sam(out2_s8)
211 | out1_s8 = out1_s8 * self.sam(out1_s8)
212 |
213 | out1_s4_down_1 = self.conv_stride_1_1(out1_s4)
214 | out1_s4_down_2 = self.conv_stride_1_2(out1_s4_down_1)
215 |
216 | out1_s8_up = self.upsample_1_1(out1_s8)
217 | out1_s8_down = self.conv_stride_1_3(out1_s8)
218 |
219 | out1_s16_up_1 = self.upsample_1_2(out1_s16)
220 | out1_s16_up_2 = self.upsample_1_3(out1_s16)
221 |
222 | out1_s16 = out1_s16 + out1_s4_down_2 + out1_s8_down
223 | out1_s8 = out1_s8 + out1_s4_down_1 + out1_s16_up_1
224 | out1_s4 = out1_s4 + out1_s16_up_2 + out1_s8_up
225 |
226 | out2_s4_down_1 = self.conv_stride_2_1(out2_s4)
227 | out2_s4_down_2 = self.conv_stride_2_2(out2_s4_down_1)
228 |
229 | out2_s8_up = self.upsample_2_1(out2_s8)
230 | out2_s8_down = self.conv_stride_2_3(out2_s8)
231 |
232 | out2_s16_up_1 = self.upsample_2_2(out2_s16)
233 | out2_s16_up_2 = self.upsample_2_3(out2_s16)
234 |
235 | out2_s16 = out2_s16 + out2_s4_down_2 + out2_s8_down
236 | out2_s8 = out2_s8 + out2_s4_down_1 + out2_s16_up_1
237 | out2_s4 = out2_s4 + out2_s16_up_2 + out2_s8_up
238 |
239 | x1_s16 = self.CA_s16(out1_s16)
240 | x2_s16 = self.CA_s16(out2_s16) # [8,32,32,32]
241 |
242 | x1_s8 = self.CA_s8(out1_s8)
243 | x2_s8 = self.CA_s8(out2_s8) # [8,32,64,64]
244 |
245 | x1 = self.CA_s4(out1_s4) # [8,32,128,128]
246 | x2 = self.CA_s4(out2_s4)
247 |
248 | x_s16 = x1_s16 + x2_s16
249 | weight_s16 = F.adaptive_avg_pool2d(x_s16, (1, 1))
250 |
251 | x_s8 = x1_s8 + x2_s8
252 | weight_s8 = F.adaptive_avg_pool2d(x_s8, (1, 1))
253 |
254 | x_s = x1 + x2
255 | weight_sx = F.adaptive_avg_pool2d(x_s, (1, 1))
256 |
257 | weight = weight_sx + weight_s8 + weight_s16
258 |
259 | weight = self.liner1(weight)
260 | weight = self.relu(weight)
261 | weight = self.liner2(weight)
262 | weight = self.sigmoid(weight)
263 |
264 | x1_s16 = weight * x1_s16
265 | x2_s16 = weight * x2_s16
266 |
267 | x1_s8 = weight * x1_s8
268 | x2_s8 = weight * x2_s8
269 |
270 | x1 = weight * x1
271 | x2 = weight * x2
272 |
273 | x16 = torch.cat([x1_s16, x2_s16], dim=1)
274 | x8 = torch.cat([x1_s8, x2_s8], dim=1)
275 | x = torch.cat([x1, x2], dim=1)
276 |
277 | x16 = F.interpolate(x16, size=img1.shape[2:], mode='bicubic', align_corners=True)
278 | x8 = F.interpolate(x8, size=img1.shape[2:], mode='bicubic', align_corners=True)
279 | x = F.interpolate(x, size=img1.shape[2:], mode='bicubic', align_corners=True)
280 |
281 | x = self.classifier3(x)
282 | x8 = self.classifier2(x8)
283 | x16 = self.classifier1(x16)
284 |
285 | return x, x8, x16
286 |
287 | def freeze_bn(self):
288 | for m in self.modules():
289 | if isinstance(m, nn.BatchNorm2d):
290 | m.eval()
291 |
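
An illustrative end-to-end pass through CDNet. This is a sketch only: it assumes `mmcv` is installed and that `model/backbone.py` builds a 'resnet50' backbone returning 32-channel feature maps at strides 16, 8 and 4, as the forward pass above expects; the input size must match `img_size`.

```
import torch
from model.network import CDNet

net = CDNet(backbone='resnet50', output_stride=16, img_size=256).eval()
img1 = torch.randn(1, 3, 256, 256)
img2 = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    out_s4, out_s8, out_s16 = net(img1, img2)
# all three prediction heads are interpolated back to the input resolution:
print(out_s4.shape, out_s8.shape, out_s16.shape)   # each torch.Size([1, 2, 256, 256])
```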
--------------------------------------------------------------------------------
/model/siamunet_conc.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 |
10 | class SiamUnet_conc(nn.Module):
11 | """SiamUnet_conc segmentation network."""
12 |
13 | def __init__(self, input_nbr, label_nbr):
14 | super(SiamUnet_conc, self).__init__()
15 |
16 | self.input_nbr = input_nbr
17 |
18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
19 | self.bn11 = nn.BatchNorm2d(16)
20 | self.do11 = nn.Dropout2d(p=0.2)
21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
22 | self.bn12 = nn.BatchNorm2d(16)
23 | self.do12 = nn.Dropout2d(p=0.2)
24 |
25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
26 | self.bn21 = nn.BatchNorm2d(32)
27 | self.do21 = nn.Dropout2d(p=0.2)
28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
29 | self.bn22 = nn.BatchNorm2d(32)
30 | self.do22 = nn.Dropout2d(p=0.2)
31 |
32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
33 | self.bn31 = nn.BatchNorm2d(64)
34 | self.do31 = nn.Dropout2d(p=0.2)
35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
36 | self.bn32 = nn.BatchNorm2d(64)
37 | self.do32 = nn.Dropout2d(p=0.2)
38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
39 | self.bn33 = nn.BatchNorm2d(64)
40 | self.do33 = nn.Dropout2d(p=0.2)
41 |
42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
43 | self.bn41 = nn.BatchNorm2d(128)
44 | self.do41 = nn.Dropout2d(p=0.2)
45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
46 | self.bn42 = nn.BatchNorm2d(128)
47 | self.do42 = nn.Dropout2d(p=0.2)
48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
49 | self.bn43 = nn.BatchNorm2d(128)
50 | self.do43 = nn.Dropout2d(p=0.2)
51 |
52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
53 |
54 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1)
55 | self.bn43d = nn.BatchNorm2d(128)
56 | self.do43d = nn.Dropout2d(p=0.2)
57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
58 | self.bn42d = nn.BatchNorm2d(128)
59 | self.do42d = nn.Dropout2d(p=0.2)
60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
61 | self.bn41d = nn.BatchNorm2d(64)
62 | self.do41d = nn.Dropout2d(p=0.2)
63 |
64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
65 |
66 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1)
67 | self.bn33d = nn.BatchNorm2d(64)
68 | self.do33d = nn.Dropout2d(p=0.2)
69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
70 | self.bn32d = nn.BatchNorm2d(64)
71 | self.do32d = nn.Dropout2d(p=0.2)
72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
73 | self.bn31d = nn.BatchNorm2d(32)
74 | self.do31d = nn.Dropout2d(p=0.2)
75 |
76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
77 |
78 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1)
79 | self.bn22d = nn.BatchNorm2d(32)
80 | self.do22d = nn.Dropout2d(p=0.2)
81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
82 | self.bn21d = nn.BatchNorm2d(16)
83 | self.do21d = nn.Dropout2d(p=0.2)
84 |
85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
86 |
87 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1)
88 | self.bn12d = nn.BatchNorm2d(16)
89 | self.do12d = nn.Dropout2d(p=0.2)
90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
91 |
92 | self.sm = nn.LogSoftmax(dim=1)
93 |
94 | def forward(self, x1, x2):
95 |
96 | """Forward method."""
97 | # Stage 1
98 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1))))
99 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
100 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)
101 |
102 |
103 | # Stage 2
104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
105 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
106 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)
107 |
108 | # Stage 3
109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
111 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
112 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)
113 |
114 | # Stage 4
115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
117 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
118 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)
119 |
120 |
121 | ####################################################
122 | # Stage 1
123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2))))
124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)
126 |
127 | # Stage 2
128 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
129 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
130 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)
131 |
132 | # Stage 3
133 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
134 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
135 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
136 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)
137 |
138 | # Stage 4
139 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
140 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
141 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
142 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)
143 |
144 |
145 | ####################################################
146 | # Stage 4d
147 | x4d = self.upconv4(x4p)
148 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
149 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1)
150 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
151 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
152 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
153 |
154 | # Stage 3d
155 | x3d = self.upconv3(x41d)
156 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
157 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1)
158 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
159 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
160 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
161 |
162 | # Stage 2d
163 | x2d = self.upconv2(x31d)
164 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
165 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1)
166 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
167 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
168 |
169 | # Stage 1d
170 | x1d = self.upconv1(x21d)
171 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
172 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1)
173 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
174 | x11d = self.conv11d(x12d)
175 |
176 | return self.sm(x11d)
177 |
178 |
179 |
--------------------------------------------------------------------------------
/model/siamunet_diff.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 |
10 | class SiamUnet_diff(nn.Module):
11 | """SiamUnet_diff segmentation network."""
12 |
13 | def __init__(self, input_nbr, label_nbr):
14 | super(SiamUnet_diff, self).__init__()
15 |
16 | self.input_nbr = input_nbr
17 |
18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
19 | self.bn11 = nn.BatchNorm2d(16)
20 | self.do11 = nn.Dropout2d(p=0.2)
21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
22 | self.bn12 = nn.BatchNorm2d(16)
23 | self.do12 = nn.Dropout2d(p=0.2)
24 |
25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
26 | self.bn21 = nn.BatchNorm2d(32)
27 | self.do21 = nn.Dropout2d(p=0.2)
28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
29 | self.bn22 = nn.BatchNorm2d(32)
30 | self.do22 = nn.Dropout2d(p=0.2)
31 |
32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
33 | self.bn31 = nn.BatchNorm2d(64)
34 | self.do31 = nn.Dropout2d(p=0.2)
35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
36 | self.bn32 = nn.BatchNorm2d(64)
37 | self.do32 = nn.Dropout2d(p=0.2)
38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
39 | self.bn33 = nn.BatchNorm2d(64)
40 | self.do33 = nn.Dropout2d(p=0.2)
41 |
42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
43 | self.bn41 = nn.BatchNorm2d(128)
44 | self.do41 = nn.Dropout2d(p=0.2)
45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
46 | self.bn42 = nn.BatchNorm2d(128)
47 | self.do42 = nn.Dropout2d(p=0.2)
48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
49 | self.bn43 = nn.BatchNorm2d(128)
50 | self.do43 = nn.Dropout2d(p=0.2)
51 |
52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
53 |
54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
55 | self.bn43d = nn.BatchNorm2d(128)
56 | self.do43d = nn.Dropout2d(p=0.2)
57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
58 | self.bn42d = nn.BatchNorm2d(128)
59 | self.do42d = nn.Dropout2d(p=0.2)
60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
61 | self.bn41d = nn.BatchNorm2d(64)
62 | self.do41d = nn.Dropout2d(p=0.2)
63 |
64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
65 |
66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
67 | self.bn33d = nn.BatchNorm2d(64)
68 | self.do33d = nn.Dropout2d(p=0.2)
69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
70 | self.bn32d = nn.BatchNorm2d(64)
71 | self.do32d = nn.Dropout2d(p=0.2)
72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
73 | self.bn31d = nn.BatchNorm2d(32)
74 | self.do31d = nn.Dropout2d(p=0.2)
75 |
76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
77 |
78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
79 | self.bn22d = nn.BatchNorm2d(32)
80 | self.do22d = nn.Dropout2d(p=0.2)
81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
82 | self.bn21d = nn.BatchNorm2d(16)
83 | self.do21d = nn.Dropout2d(p=0.2)
84 |
85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
86 |
87 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
88 | self.bn12d = nn.BatchNorm2d(16)
89 | self.do12d = nn.Dropout2d(p=0.2)
90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
91 |
92 | self.sm = nn.LogSoftmax(dim=1)
93 |
94 | def forward(self, x1, x2):
95 |
96 |
97 | """Forward method."""
98 | # Stage 1
99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1))))
100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)
102 |
103 |
104 | # Stage 2
105 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
106 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
107 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)
108 |
109 | # Stage 3
110 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
111 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
112 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
113 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)
114 |
115 | # Stage 4
116 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
117 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
118 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
119 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)
120 |
121 | ####################################################
122 | # Stage 1
123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2))))
124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)
126 |
127 |
128 | # Stage 2
129 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
130 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
131 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)
132 |
133 | # Stage 3
134 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
135 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
136 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
137 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)
138 |
139 | # Stage 4
140 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
141 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
142 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
143 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)
144 |
145 |
146 |
147 | # Stage 4d
148 | x4d = self.upconv4(x4p)
149 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
150 | x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1)
151 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
152 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
153 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
154 |
155 | # Stage 3d
156 | x3d = self.upconv3(x41d)
157 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
158 | x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1)
159 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
160 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
161 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
162 |
163 | # Stage 2d
164 | x2d = self.upconv2(x31d)
165 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
166 | x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1)
167 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
168 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
169 |
170 | # Stage 1d
171 | x1d = self.upconv1(x21d)
172 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
173 | x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1)
174 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
175 | x11d = self.conv11d(x12d)
176 |
177 | return self.sm(x11d)
178 |
179 |
180 |
--------------------------------------------------------------------------------
/model/unet.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 |
10 | class Unet(nn.Module):
11 | """EF segmentation network."""
12 |
13 | def __init__(self, input_nbr, label_nbr):
14 | super(Unet, self).__init__()
15 |
16 | self.input_nbr = input_nbr
17 |
18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
19 | self.bn11 = nn.BatchNorm2d(16)
20 | self.do11 = nn.Dropout2d(p=0.2)
21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
22 | self.bn12 = nn.BatchNorm2d(16)
23 | self.do12 = nn.Dropout2d(p=0.2)
24 |
25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
26 | self.bn21 = nn.BatchNorm2d(32)
27 | self.do21 = nn.Dropout2d(p=0.2)
28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
29 | self.bn22 = nn.BatchNorm2d(32)
30 | self.do22 = nn.Dropout2d(p=0.2)
31 |
32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
33 | self.bn31 = nn.BatchNorm2d(64)
34 | self.do31 = nn.Dropout2d(p=0.2)
35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
36 | self.bn32 = nn.BatchNorm2d(64)
37 | self.do32 = nn.Dropout2d(p=0.2)
38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
39 | self.bn33 = nn.BatchNorm2d(64)
40 | self.do33 = nn.Dropout2d(p=0.2)
41 |
42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
43 | self.bn41 = nn.BatchNorm2d(128)
44 | self.do41 = nn.Dropout2d(p=0.2)
45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
46 | self.bn42 = nn.BatchNorm2d(128)
47 | self.do42 = nn.Dropout2d(p=0.2)
48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
49 | self.bn43 = nn.BatchNorm2d(128)
50 | self.do43 = nn.Dropout2d(p=0.2)
51 |
52 |
53 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
54 |
55 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
56 | self.bn43d = nn.BatchNorm2d(128)
57 | self.do43d = nn.Dropout2d(p=0.2)
58 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
59 | self.bn42d = nn.BatchNorm2d(128)
60 | self.do42d = nn.Dropout2d(p=0.2)
61 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
62 | self.bn41d = nn.BatchNorm2d(64)
63 | self.do41d = nn.Dropout2d(p=0.2)
64 |
65 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
66 |
67 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
68 | self.bn33d = nn.BatchNorm2d(64)
69 | self.do33d = nn.Dropout2d(p=0.2)
70 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
71 | self.bn32d = nn.BatchNorm2d(64)
72 | self.do32d = nn.Dropout2d(p=0.2)
73 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
74 | self.bn31d = nn.BatchNorm2d(32)
75 | self.do31d = nn.Dropout2d(p=0.2)
76 |
77 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
78 |
79 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
80 | self.bn22d = nn.BatchNorm2d(32)
81 | self.do22d = nn.Dropout2d(p=0.2)
82 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
83 | self.bn21d = nn.BatchNorm2d(16)
84 | self.do21d = nn.Dropout2d(p=0.2)
85 |
86 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
87 |
88 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
89 | self.bn12d = nn.BatchNorm2d(16)
90 | self.do12d = nn.Dropout2d(p=0.2)
91 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
92 |
93 | self.sm = nn.LogSoftmax(dim=1)
94 |
95 | def forward(self, x1, x2):
96 |
97 | x = torch.cat((x1, x2), 1)
98 |
99 | """Forward method."""
100 | # Stage 1
101 | x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
102 | x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
103 | x1p = F.max_pool2d(x12, kernel_size=2, stride=2)
104 |
105 | # Stage 2
106 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
107 | x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
108 | x2p = F.max_pool2d(x22, kernel_size=2, stride=2)
109 |
110 | # Stage 3
111 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
112 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
113 | x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
114 | x3p = F.max_pool2d(x33, kernel_size=2, stride=2)
115 |
116 | # Stage 4
117 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
118 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
119 | x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
120 | x4p = F.max_pool2d(x43, kernel_size=2, stride=2)
121 |
122 |
123 | # Stage 4d
124 | x4d = self.upconv4(x4p)
125 | pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2)))
126 | x4d = torch.cat((pad4(x4d), x43), 1)
127 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
128 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
129 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
130 |
131 | # Stage 3d
132 | x3d = self.upconv3(x41d)
133 | pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2)))
134 | x3d = torch.cat((pad3(x3d), x33), 1)
135 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
136 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
137 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
138 |
139 | # Stage 2d
140 | x2d = self.upconv2(x31d)
141 | pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2)))
142 | x2d = torch.cat((pad2(x2d), x22), 1)
143 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
144 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
145 |
146 | # Stage 1d
147 | x1d = self.upconv1(x21d)
148 | pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2)))
149 | x1d = torch.cat((pad1(x1d), x12), 1)
150 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
151 | x11d = self.conv11d(x12d)
152 |
153 | return self.sm(x11d)
154 |
155 |
156 |
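
The three baselines above (Unet, SiamUnet_conc, SiamUnet_diff) share the same bi-temporal interface: two 3-channel images in, per-pixel log-probabilities out. A quick, illustrative smoke test with random inputs:

```
import torch
from model.unet import Unet
from model.siamunet_conc import SiamUnet_conc
from model.siamunet_diff import SiamUnet_diff

t1 = torch.randn(1, 3, 256, 256)
t2 = torch.randn(1, 3, 256, 256)
print(Unet(6, 2)(t1, t2).shape)            # early fusion: input_nbr covers both images (2 x 3 channels)
print(SiamUnet_conc(3, 2)(t1, t2).shape)   # Siamese encoder, skip features concatenated
print(SiamUnet_diff(3, 2)(t1, t2).shape)   # Siamese encoder, skip features differenced
# each prints torch.Size([1, 2, 256, 256])
```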
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | import torch.optim as optim
4 | import torch.utils.data
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 | from data_utils import LoadDatasetFromFolder, DA_DatasetFromFolder, calMetric_iou,TestDatasetFromFolder
8 | import numpy as np
9 | import random
10 | from model.network import CDNet
11 | from train_options import parser
12 | import itertools
13 | from loss.losses import cross_entropy
14 | from ever import opt
15 | import time
16 | from collections import OrderedDict
17 | import ever as er
18 | import cv2 as cv
19 |
20 | import logging
21 | from PIL import Image
22 | import matplotlib.pyplot as plt
23 |
24 | args = parser.parse_args()
25 | os.environ["CUDA_VISIBLE_DEVICES"] = str(0)
26 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
27 |
28 | def seed_torch(seed=2022):
29 | random.seed(seed)
30 | os.environ['PYTHONHASHSEED'] = str(seed)
31 | np.random.seed(seed)
32 | torch.manual_seed(seed)
33 | torch.cuda.manual_seed(seed)
34 | seed_torch(2022)
35 |
36 | COLOR_MAP = OrderedDict(
37 | Background=(255, 255, 255),
38 | Building=(255, 0, 0),
39 | )
40 |
41 | def tes(mloss):
42 | CDNet.eval()
43 | with torch.no_grad():
44 | test_bar = tqdm(test_loader)
45 | testing_results = {'batch_sizes': 0, 'IoU': 0}
46 |
47 | metric_op = er.metric.PixelMetric(2, logdir=None, logger=None)
48 | for hr_img1, hr_img2, label, name in test_bar:
49 | testing_results['batch_sizes'] += args.val_batchsize
50 |
51 | hr_img1 = hr_img1.to(device, dtype=torch.float)
52 | hr_img2 = hr_img2.to(device, dtype=torch.float)
53 | label = label.to(device, dtype=torch.float)
54 |
55 | label = torch.argmax(label, 1).unsqueeze(1).float()
56 |
57 | cd_map, _, _ = CDNet(hr_img1, hr_img2)
58 | cd_map = torch.argmax(cd_map, 1).unsqueeze(1).float()
59 |
60 |
61 | gt_value = (label > 0).float()
62 | prob = (cd_map > 0).float()
63 | prob = prob.cpu().detach().numpy()
64 |
65 | gt_value = gt_value.cpu().detach().numpy()
66 | gt_value = np.squeeze(gt_value)
67 | result = np.squeeze(prob)
68 | metric_op.forward(gt_value, result)
69 | re = metric_op.summary_all()
70 |
71 | CDNet.train()
72 | return mloss
73 |
74 |
75 | if __name__ == '__main__':
76 | mloss = 0
77 |
78 | test_set = TestDatasetFromFolder(args, args.hr1_test, args.hr2_test, args.lab_test)
79 |
80 | test_loader = DataLoader(dataset=test_set, num_workers=args.num_workers, batch_size=args.val_batchsize, shuffle=True)
81 |
82 | # define model
83 | CDNet = CDNet(img_size = args.img_size).to(device, dtype=torch.float)
84 |
85 | CDNet.load_state_dict(torch.load('/./best.pth'),strict=False)
86 | mloss = tes(mloss)
87 |
88 |
89 |
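
Evaluation reads its options from the same `train_options.py`: once `--hr1_test`, `--hr2_test` and `--lab_test` point at the prepared Test split and the checkpoint path in the `load_state_dict` call above is set, running for example `python test.py --val_batchsize 8 --img_size 256` builds CDNet, restores the weights and reports pixel-level metrics through `ever`'s `PixelMetric`.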
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | import torch.optim as optim
4 | import torch.utils.data
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 | from data_utils import DA_DatasetFromFolder, TestDatasetFromFolder, LoadDatasetFromFolder
8 | import numpy as np
9 | import random
10 | from model.network import CDNet
11 | from train_options import parser
12 | import itertools
13 | from loss.losses import cross_entropy
14 | import time
15 | from collections import OrderedDict
16 | import ever as er
17 |
18 | args = parser.parse_args()
19 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21 |
22 | def seed_torch(seed=2022):
23 | random.seed(seed)
24 | os.environ['PYTHONHASHSEED'] = str(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | seed_torch(2022)
29 |
30 | COLOR_MAP = OrderedDict(
31 | Background=(255, 255, 255),
32 | Building=(255, 0, 0),
33 | )
34 |
35 | def val(mloss):
36 | CDNet.eval()
37 | with torch.no_grad():
38 | val_bar = tqdm(test_loader)
39 | valing_results = {'batch_sizes': 0, 'IoU': 0}
40 |
41 | metric_op = er.metric.PixelMetric(2, logdir=None, logger=None)
42 | for hr_img1, hr_img2, label, name in val_bar:
43 |
44 | valing_results['batch_sizes'] += args.val_batchsize
45 | hr_img1 = hr_img1.to(device, dtype=torch.float)
46 | hr_img2 = hr_img2.to(device, dtype=torch.float)
47 | label = label.to(device, dtype=torch.float)
48 |
49 | label = torch.argmax(label, 1).unsqueeze(1).float()
50 |
51 | cd_map, _, _ = CDNet(hr_img1, hr_img2)
52 |
53 | cd_map = torch.argmax(cd_map, 1).unsqueeze(1).float()
54 |
55 | gt_value = (label > 0).float()
56 | prob = (cd_map > 0).float()
57 | prob = prob.cpu().detach().numpy()
58 |
59 | gt_value = gt_value.cpu().detach().numpy()
60 | gt_value = np.squeeze(gt_value)
61 | result = np.squeeze(prob)
62 | metric_op.forward(gt_value, result)
63 | re = metric_op.summary_all()
64 |
65 | test_loss = re.rows[1][1]
66 | if test_loss > mloss or epoch == 1:
67 | torch.save(CDNet.state_dict(), args.model_dir +'_'+ str(test_loss) +'_best.pth')
68 | CDNet.train()
69 | return max(mloss, test_loss)  # keep the best validation score seen so far
70 |
71 | def train_epoch():
72 | CDNet.train()
73 | for hr_img1, hr_img2, label in train_bar:
74 | running_results['batch_sizes'] += args.batchsize
75 |
76 | hr_img1 = hr_img1.to(device, dtype=torch.float)
77 | hr_img2 = hr_img2.to(device, dtype=torch.float)
78 | label = label.to(device, dtype=torch.float)
79 | label = torch.argmax(label, 1).unsqueeze(1).float()
80 |
81 | result1, result2, result3 = CDNet(hr_img1, hr_img2)
82 |
83 | CD_loss = CDcriterionCD(result1, label) + CDcriterionCD(result2, label) + CDcriterionCD(result3, label)
84 |
85 | CDNet.zero_grad()
86 | CD_loss.backward()
87 | optimizer.step()
88 |
89 | running_results['CD_loss'] += CD_loss.item() * args.batchsize
90 |
91 | train_bar.set_description(
92 | desc='[%d/%d] loss: %.4f' % (
93 | epoch, args.num_epochs,
94 | running_results['CD_loss'] / running_results['batch_sizes'],))
95 |
96 |
97 | if __name__ == '__main__':
98 | mloss = 0
99 |
100 | # load data
101 | train_set = DA_DatasetFromFolder(args.hr1_train, args.hr2_train, args.lab_train, crop=False)
102 | test_set = TestDatasetFromFolder(args, args.hr1_test, args.hr2_test, args.lab_test)
103 |
104 | train_loader = DataLoader(dataset=train_set, num_workers=args.num_workers, batch_size=args.batchsize, shuffle=True)
105 | test_loader = DataLoader(dataset=test_set, num_workers=args.num_workers, batch_size=args.val_batchsize, shuffle=True)
106 |
107 | # define model
108 | CDNet = CDNet(img_size = args.img_size).to(device, dtype=torch.float)
109 |
110 | # set optimization
111 | optimizer = optim.AdamW(itertools.chain(CDNet.parameters()), lr= args.lr, betas=(0.9, 0.999),weight_decay=0.01)
112 | CDcriterionCD = cross_entropy().to(device, dtype=torch.float)
113 |
114 | # training
115 | for epoch in range(1, args.num_epochs + 1):
116 | train_bar = tqdm(train_loader)
117 | running_results = {'batch_sizes': 0, 'SR_loss':0, 'CD_loss':0, 'loss': 0 }
118 | train_epoch()
119 | mloss = val(mloss)
120 |
121 |
122 |
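
With the six dataset paths at the top of `train_options.py` pointing at the prepared Train/Test splits, training can be launched with, for example, `python train.py --gpu_id 0 --batchsize 8 --img_size 256 --lr 0.0001 --num_epochs 100`; whenever the validation metric improves, the checkpoint is saved using the `--model_dir` value as its filename prefix.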
--------------------------------------------------------------------------------
/train_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | file_train_path_time1 = '/./train/time1'
3 | file_train_path_time2 = '/./train/time2'
4 | file_train_path_label = '/./train/label'
5 |
6 | file_test_path_time1 = '/./test/time1'
7 | file_test_path_time2 = '/./test/time2'
8 | file_test_path_label = '/./test/label'
9 |
10 | #training options
11 | parser = argparse.ArgumentParser(description='Training Change Detection Network')
12 |
13 | # training parameters
14 | parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')
15 | parser.add_argument('--batchsize', default=8, type=int, help='batchsize')
16 | parser.add_argument('--val_batchsize', default=8, type=int, help='batchsize for validation')
17 | parser.add_argument('--num_workers', default=24, type=int, help='num of workers')
18 | parser.add_argument('--n_class', default=2, type=int, help='number of class')
19 | parser.add_argument('--gpu_id', default="0", type=str, help='which gpu to run.')
20 | parser.add_argument('--suffix', default=['.png','.jpg','.tif'], type=list, help='the suffix of the image files.')
21 | parser.add_argument('--img_size', default=256, type=int, help='imagesize')
22 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
23 |
24 |
25 | # path for loading data from folder
26 | parser.add_argument('--hr1_train', default= file_train_path_time1, type=str, help='image at t1 in train set')
27 | parser.add_argument('--hr2_train', default= file_train_path_time2, type=str, help='image at t2 in train set')
28 | parser.add_argument('--lab_train', default= file_train_path_label, type=str, help='label image in train set')
29 |
30 |
31 | parser.add_argument('--hr1_test', default= file_test_path_time1, type=str, help='image at t1 in test set')
32 | parser.add_argument('--hr2_test', default= file_test_path_time2, type=str, help='image at t2 in test set')
33 | parser.add_argument('--lab_test', default= file_test_path_label, type=str, help='label image in test set')
34 |
35 | # network saving
36 | parser.add_argument('--model_dir', default='/./', type=str, help='model save path')
37 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Misc functions, including distributed helpers.
5 |
6 | Mostly copy-paste from torchvision references.
7 | """
8 | import io
9 | import os
10 | import time
11 | from collections import defaultdict, deque
12 | import datetime
13 |
14 | import torch
15 | import torch.distributed as dist
16 |
17 |
18 | class SmoothedValue(object):
19 | """Track a series of values and provide access to smoothed values over a
20 | window or the global series average.
21 | """
22 |
23 | def __init__(self, window_size=20, fmt=None):
24 | if fmt is None:
25 | fmt = "{median:.4f} ({global_avg:.4f})"
26 | self.deque = deque(maxlen=window_size)
27 | self.total = 0.0
28 | self.count = 0
29 | self.fmt = fmt
30 |
31 | def update(self, value, n=1):
32 | self.deque.append(value)
33 | self.count += n
34 | self.total += value * n
35 |
36 | def synchronize_between_processes(self):
37 | """
38 | Warning: does not synchronize the deque!
39 | """
40 | if not is_dist_avail_and_initialized():
41 | return
42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
43 | dist.barrier()
44 | dist.all_reduce(t)
45 | t = t.tolist()
46 | self.count = int(t[0])
47 | self.total = t[1]
48 |
49 | @property
50 | def median(self):
51 | d = torch.tensor(list(self.deque))
52 | return d.median().item()
53 |
54 | @property
55 | def avg(self):
56 | d = torch.tensor(list(self.deque), dtype=torch.float32)
57 | return d.mean().item()
58 |
59 | @property
60 | def global_avg(self):
61 | return self.total / self.count
62 |
63 | @property
64 | def max(self):
65 | return max(self.deque)
66 |
67 | @property
68 | def value(self):
69 | return self.deque[-1]
70 |
71 | def __str__(self):
72 | return self.fmt.format(
73 | median=self.median,
74 | avg=self.avg,
75 | global_avg=self.global_avg,
76 | max=self.max,
77 | value=self.value)
78 |
79 |
80 | class MetricLogger(object):
81 | def __init__(self, delimiter="\t"):
82 | self.meters = defaultdict(SmoothedValue)
83 | self.delimiter = delimiter
84 |
85 | def update(self, **kwargs):
86 | for k, v in kwargs.items():
87 | if isinstance(v, torch.Tensor):
88 | v = v.item()
89 | assert isinstance(v, (float, int))
90 | self.meters[k].update(v)
91 |
92 | def __getattr__(self, attr):
93 | if attr in self.meters:
94 | return self.meters[attr]
95 | if attr in self.__dict__:
96 | return self.__dict__[attr]
97 | raise AttributeError("'{}' object has no attribute '{}'".format(
98 | type(self).__name__, attr))
99 |
100 | def __str__(self):
101 | loss_str = []
102 | for name, meter in self.meters.items():
103 | loss_str.append(
104 | "{}: {}".format(name, str(meter))
105 | )
106 | return self.delimiter.join(loss_str)
107 |
108 | def synchronize_between_processes(self):
109 | for meter in self.meters.values():
110 | meter.synchronize_between_processes()
111 |
112 | def add_meter(self, name, meter):
113 | self.meters[name] = meter
114 |
115 | def log_every(self, iterable, print_freq, header=None):
116 | i = 0
117 | if not header:
118 | header = ''
119 | start_time = time.time()
120 | end = time.time()
121 | iter_time = SmoothedValue(fmt='{avg:.4f}')
122 | data_time = SmoothedValue(fmt='{avg:.4f}')
123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
124 | log_msg = [
125 | header,
126 | '[{0' + space_fmt + '}/{1}]',
127 | 'eta: {eta}',
128 | '{meters}',
129 | 'time: {time}',
130 | 'data: {data}'
131 | ]
132 | if torch.cuda.is_available():
133 | log_msg.append('max mem: {memory:.0f}')
134 | log_msg = self.delimiter.join(log_msg)
135 | MB = 1024.0 * 1024.0
136 | for obj in iterable:
137 | data_time.update(time.time() - end)
138 | yield obj
139 | iter_time.update(time.time() - end)
140 | if i % print_freq == 0 or i == len(iterable) - 1:
141 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
143 | if torch.cuda.is_available():
144 | print(log_msg.format(
145 | i, len(iterable), eta=eta_string,
146 | meters=str(self),
147 | time=str(iter_time), data=str(data_time),
148 | memory=torch.cuda.max_memory_allocated() / MB))
149 | else:
150 | print(log_msg.format(
151 | i, len(iterable), eta=eta_string,
152 | meters=str(self),
153 | time=str(iter_time), data=str(data_time)))
154 | i += 1
155 | end = time.time()
156 | total_time = time.time() - start_time
157 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
158 | print('{} Total time: {} ({:.4f} s / it)'.format(
159 | header, total_time_str, total_time / len(iterable)))
160 |
161 |
162 | def _load_checkpoint_for_ema(model_ema, checkpoint):
163 | """
164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
165 | """
166 | mem_file = io.BytesIO()
167 | torch.save(checkpoint, mem_file)
168 | mem_file.seek(0)
169 | model_ema._load_checkpoint(mem_file)
170 |
171 |
172 | def setup_for_distributed(is_master):
173 | """
174 | This function disables printing when not in master process
175 | """
176 | import builtins as __builtin__
177 | builtin_print = __builtin__.print
178 |
179 | def print(*args, **kwargs):
180 | force = kwargs.pop('force', False)
181 | if is_master or force:
182 | builtin_print(*args, **kwargs)
183 |
184 | __builtin__.print = print
185 |
186 |
187 | def is_dist_avail_and_initialized():
188 | if not dist.is_available():
189 | return False
190 | if not dist.is_initialized():
191 | return False
192 | return True
193 |
194 |
195 | def get_world_size():
196 | if not is_dist_avail_and_initialized():
197 | return 1
198 | return dist.get_world_size()
199 |
200 |
201 | def get_rank():
202 | if not is_dist_avail_and_initialized():
203 | return 0
204 | return dist.get_rank()
205 |
206 |
207 | def is_main_process():
208 | return get_rank() == 0
209 |
210 |
211 | def save_on_master(*args, **kwargs):
212 | if is_main_process():
213 | torch.save(*args, **kwargs)
214 |
215 |
216 | def init_distributed_mode(args):
217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
218 | args.rank = int(os.environ["RANK"])
219 | args.world_size = int(os.environ['WORLD_SIZE'])
220 | args.gpu = int(os.environ['LOCAL_RANK'])
221 | elif 'SLURM_PROCID' in os.environ:
222 | args.rank = int(os.environ['SLURM_PROCID'])
223 | args.gpu = args.rank % torch.cuda.device_count()
224 | else:
225 | print('Not using distributed mode')
226 | args.distributed = False
227 | return
228 |
229 | args.distributed = True
230 |
231 | torch.cuda.set_device(args.gpu)
232 | args.dist_backend = 'nccl'
233 | print('| distributed init (rank {}): {}'.format(
234 | args.rank, args.dist_url), flush=True)
235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
236 | world_size=args.world_size, rank=args.rank)
237 | torch.distributed.barrier()
238 | setup_for_distributed(args.rank == 0)
239 |
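
Illustrative use of the logging helpers above in a single-process run (no distributed initialisation required); the batches and values are dummy placeholders.

```
import torch
from utils import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter="  ")
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
dummy_batches = [torch.rand(2, 3) for _ in range(10)]
for batch in logger.log_every(dummy_batches, print_freq=5, header='Epoch: [0]'):
    logger.update(loss=batch.mean().item(), lr=1e-4)   # smoothed and printed every print_freq steps
print('Averaged stats:', logger)
```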
--------------------------------------------------------------------------------
/visualize_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import shutil
4 | import sys
5 | import warnings
6 | from pathlib import Path
7 |
8 | import cv2
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 |
13 | class CDVisualization(object):
14 | def __init__(self, policy=['compare_pixel', 'pixel']):
15 | """Change Detection Visualization
16 |
17 | Args:
18 | policy (list, optional): _description_. Defaults to ['compare_pixel', 'pixel'].
19 | """
20 | super().__init__()
21 | assert isinstance(policy, list)
22 | self.policy = policy
23 | self.num_classes = 2
24 | self.COLOR_MAP = {'0': (0, 0, 0), # black is TN
25 | '1': (0, 255, 0), # green is FP (false positive)
26 | '2': (255, 0, 0), # red is FN (false negative)
27 | '3': (255, 255, 255)} # white is TP
28 |
29 | def read_and_check_label(self, file_name):
30 | img = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
31 | if np.max(img) >= self.num_classes:
32 | warnings.warn('Please check the range of the pixel values ' \
33 | 'in your pred or gt.')
34 | img[img < 128] = 0
35 | img[img >= 128] = 1
36 | return img
37 |
38 | def read_img(self, file_name):
39 | img = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
40 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
41 | return img
42 |
43 | def save_img(self, file_name, vis_res, imgs=None):
44 | dir_name = os.path.dirname(file_name)
45 | base_name, suffix = os.path.basename(file_name).split('.')
46 | os.makedirs(dir_name, exist_ok=True)
47 |
48 | # consistent with the original image
49 | vis_res = cv2.cvtColor(vis_res, cv2.COLOR_RGB2BGR)
50 | if imgs is not None:
51 | assert isinstance(imgs, list), '`imgs` must be a list.'
52 | if vis_res.shape != imgs[0].shape:
53 | vis_res = cv2.cvtColor(vis_res, cv2.COLOR_GRAY2RGB)
54 | for idx, img in enumerate(imgs):
55 | assert img.shape == vis_res.shape, '`img` and `vis_res` must be ' \
56 | 'of the same shape.'
57 | cv2.imwrite(osp.join(dir_name, base_name + '_' + str(idx) + '.' + suffix),
58 | cv2.addWeighted(img, 1, vis_res, 0.25, 0.0))
59 | else:
60 | cv2.imwrite(file_name, vis_res)
61 |
62 | def trainIdToColor(self, trainId):
63 | """convert label id to color
64 |
65 | Args:
66 | trainId (int): _description_
67 |
68 | Returns:
69 | color (tuple)
70 | """
71 | color = self.COLOR_MAP[str(trainId)]
72 | return color
73 |
74 | def gray2color(self, grayImage: np.ndarray, num_class: list):
75 | """convert label to color image
76 |
77 | Args:
78 | grayImage (np.ndarray): _description_
79 | num_class (list): _description_
80 |
81 | Returns:
82 | _type_: _description_
83 | """
84 | rgbImage = np.zeros((grayImage.shape[0], grayImage.shape[1], 3), dtype='uint8')
85 | for cls in num_class:
86 | row, col = np.where(grayImage == cls)
87 | if (len(row) == 0):
88 | continue
89 | color = self.trainIdToColor(cls)
90 | rgbImage[row, col] = color
91 | return rgbImage
92 |
93 | def res_pixel_visual(self, label):
94 | assert np.max(label) < self.num_classes, '`label` contains values ' \
95 | 'greater than or equal to `num_classes`'
96 | label_rgb = self.gray2color(label, num_class=list(range(self.num_classes)))
97 | return label_rgb
98 |
99 | def res_compare_pixel_visual(self, pred, gt):
100 | """visualize according to confusion matrix.
101 |
102 | Args:
103 | pred (_type_): _description_
104 | gt (_type_): _description_
105 | """
106 | assert np.max(pred) < self.num_classes, '`pred` contains values ' \
107 | 'greater than or equal to `num_classes`'
108 | assert np.max(gt) < self.num_classes, '`gt` contains values ' \
109 | 'greater than or equal to `num_classes`'
110 |
111 | visual_ = self.num_classes * gt.astype(int) + pred.astype(int)
112 | visual_rgb = self.gray2color(visual_, num_class=list(range(self.num_classes ** 2)))
113 | return visual_rgb
114 |
115 | def res_compare_boundary_visual(self):
116 | pass
117 |
118 | def __call__(self, pred_path, gt_path, dst_path, imgs=None):
119 | # dst_prefix, dst_suffix = osp.abspath(dst_path).split('.')
120 | # dst_path = dst_prefix + '_{}.' + dst_suffix
121 | file_name = osp.basename(dst_path)
122 | dst_path = osp.dirname(dst_path)
123 |
124 | pred = self.read_and_check_label(pred_path)
125 | gt = self.read_and_check_label(gt_path)
126 | if imgs is not None:
127 | assert isinstance(imgs, list), '`imgs` must be a list.'
128 | imgs = [self.read_img(p) for p in imgs]
129 |
130 | for pol in self.policy:
131 | dst_path_pol = osp.join(dst_path, pol)
132 | os.makedirs(dst_path_pol, exist_ok=True)
133 | dst_file = osp.join(dst_path_pol, file_name)
134 | if pol == 'compare_pixel':
135 | visual_map = self.res_compare_pixel_visual(pred, gt)
136 | self.save_img(dst_file, visual_map, None)
137 | elif pol == 'pixel':
138 | visual_map = self.res_pixel_visual(pred)
139 | self.save_img(dst_file, visual_map, imgs)
140 | else:
141 | raise ValueError(f'Invalid policy {pol}')
142 |
143 |
144 | if __name__ == '__main__':
145 | gt_dir = '/./label'
146 | pred_dir = '/./pre'
147 | dst_dir = '/./save'
148 | img_dir = ['/./test/time1',
149 | '/./time2']
150 |
151 |
152 | file_name_list = os.listdir(gt_dir)
153 |
154 | for file_name in file_name_list:
155 |
156 | CDVisual = CDVisualization(policy=['compare_pixel'])
157 | CDVisual(osp.join(pred_dir, file_name),
158 | osp.join(gt_dir, file_name),
159 | osp.join(dst_dir, file_name),
160 | [osp.join(p, file_name) for p in img_dir])
161 |
162 |
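
`res_compare_pixel_visual` encodes the confusion matrix as `value = 2*gt + pred`, so 0 = TN (black), 1 = FP (green), 2 = FN (red), 3 = TP (white), matching COLOR_MAP above. A tiny illustrative check (assumes the module's dependencies such as cv2 are installed):

```
import numpy as np
from visualize_results import CDVisualization

gt   = np.array([[0, 0], [1, 1]])
pred = np.array([[0, 1], [0, 1]])
vis = CDVisualization(policy=['compare_pixel'])
rgb = vis.res_compare_pixel_visual(pred, gt)
print(rgb[..., 0])   # red channel: 0 in the TN/FP row, 255 in the FN/TP row
```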
--------------------------------------------------------------------------------