├── LICENSE
├── MFPNet_code
│   ├── eval.py
│   ├── metadata.json
│   ├── metadata_descripation.py
│   ├── models
│   │   ├── MFPNet_model.py
│   │   ├── seresnet50.py
│   │   └── vgg.py
│   ├── train.py
│   └── utils
│       ├── dataloaders.py
│       ├── helpers.py
│       ├── hybridloss.py
│       ├── metrics.py
│       ├── parser.py
│       └── transforms.py
├── README.md
└── figure
    ├── AWF.png
    ├── MFP.png
    ├── MFPNet.png
    └── PSM.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jialang Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MFPNet_code/eval.py:
--------------------------------------------------------------------------------
1 | from shutil import copyfile
2 | import torch.utils.data
3 | from utils.parser import get_parser_with_args
4 | from utils.helpers import get_test_loaders
5 | from tqdm import tqdm
6 | from sklearn.metrics import confusion_matrix
7 | import numpy as np
8 | import torch.nn.functional as F
9 | import cv2
10 | import os
11 | from utils.helpers import load_model
12 |
13 | parser, metadata = get_parser_with_args(metadata_json_path='/home/aaa/xujialang/master_thesis/MFPNet/metadata.json')
14 | opt = parser.parse_args()
15 | dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
16 |
17 | test_loader = get_test_loaders(opt)
18 |
19 | weight_path = os.path.join(opt.weight_dir, 'model_weight.pt') # the path of the model weight
20 | model = load_model(opt, dev)
21 | model.load_state_dict(torch.load(weight_path))
22 | """
23 | Begin Test
24 | """
25 | model.eval()
26 | with torch.no_grad():
27 | c_matrix = {'tn': 0, 'fp': 0, 'fn': 0, 'tp': 0}
28 | test_metrics = {
29 | 'cd_precisions': [],
30 | 'cd_recalls': [],
31 | 'cd_f1scores': [],
32 | }
33 |
34 | for batch_img1, batch_img2, labels in test_loader:
35 | batch_img1 = batch_img1.float().to(dev)
36 | batch_img2 = batch_img2.float().to(dev)
37 | labels = labels.long().to(dev)
38 | cd_preds = model(batch_img1, batch_img2)
39 | cd_preds = torch.argmax(cd_preds, dim = 1)
40 |
41 | tp= (labels.cpu().numpy() * cd_preds.cpu().numpy()).sum()
42 | tn= ((1-labels.cpu().numpy()) * (1-cd_preds.cpu().numpy())).sum()
43 | fn= (labels.cpu().numpy() * (1-cd_preds.cpu().numpy())).sum()
44 | fp= ((1-labels.cpu().numpy()) * cd_preds.cpu().numpy()).sum()
45 | c_matrix['tn'] += tn
46 | c_matrix['fp'] += fp
47 | c_matrix['fn'] += fn
48 | c_matrix['tp'] += tp
49 |
50 | tn, fp, fn, tp = c_matrix['tn'], c_matrix['fp'], c_matrix['fn'], c_matrix['tp']
51 | P = tp / (tp + fp)
52 | R = tp / (tp + fn)
53 | F1 = 2 * P * R / (R + P)
54 | IOU = tp/ (fn+tp+fp)
55 |
56 | ttt_test=tn+fp+fn+tp
57 | TA_test = (tp+tn) / ttt_test
58 | Pcp1_test = (tp + fn) / ttt_test
59 | Pcp2_test = (tp + fp) / ttt_test
60 | Pcn1_test = (fp + tn) / ttt_test
61 | Pcn2_test = (fn + tn) / ttt_test
62 | Pc_test = Pcp1_test*Pcp2_test + Pcn1_test*Pcn2_test
63 | kappa_test = (TA_test - Pc_test) / (1 - Pc_test)
64 |
65 | test_metrics['cd_f1scores'] = F1
66 | test_metrics['cd_precisions'] = P
67 | test_metrics['cd_recalls'] = R
68 | print("TEST METRICS. KAPPA: {}. IOU: {} ".format(kappa_test, IOU) + str(test_metrics))
--------------------------------------------------------------------------------
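
Note: the metric block at the end of eval.py is plain confusion-matrix arithmetic; the kappa term is Cohen's kappa, (observed agreement - chance agreement) / (1 - chance agreement). A minimal, self-contained sketch with made-up counts (illustrative only, not results):

    # toy confusion-matrix entries (illustrative only)
    tn, fp, fn, tp = 9000, 100, 200, 700
    P = tp / (tp + fp)                     # precision
    R = tp / (tp + fn)                     # recall
    F1 = 2 * P * R / (P + R)
    IOU = tp / (tp + fp + fn)
    n = tn + fp + fn + tp
    p_o = (tp + tn) / n                    # observed agreement (overall accuracy)
    p_e = ((tp + fn) * (tp + fp) + (fp + tn) * (fn + tn)) / n ** 2   # chance agreement
    kappa = (p_o - p_e) / (1 - p_e)
    print(P, R, F1, IOU, kappa)
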
/MFPNet_code/metadata.json:
--------------------------------------------------------------------------------
1 | {
2 | "patch_size": 256,
3 | "augmentation": true,
4 | "num_gpus": 1,
5 | "num_workers": 4,
6 | "num_channel": 3,
7 | "epochs": 200,
8 | "batch_size": 4,
9 | "learning_rate": 1e-4,
10 | "loss_function": "hybrid",
11 | "dataset_dir": "/home/bigspace/xujialang/cd_dataset/Google/",
12 | "weight_dir": "/home/bigspace/xujialang/MFPNet_result/Google/",
13 | "resume": "None"
14 | }
--------------------------------------------------------------------------------
/MFPNet_code/metadata_descripation.py:
--------------------------------------------------------------------------------
1 |
2 | # For Seasonvarying/LEVIR-CD/Google Dataset
3 | {
4 | "patch_size": 256,
5 | "augmentation": true,
6 | "num_gpus": 1,
7 | "num_workers": 4,
8 | "num_channel": 3,
9 | "epochs": 200,
10 | "batch_size": 4,
11 | "learning_rate": 1e-4,
12 |     "loss_function": "hybrid", # one of ['hybrid', 'bce', 'dice', 'jaccard']; 'hybrid' means Softmax PPCE + Perceptual Loss
13 | "dataset_dir": "/home/bigspace/xujialang/cd_dataset/Seasonvarying/", # change to your own path
14 | "weight_dir": "/home/bigspace/xujialang/MFPNet_result/Seasonvarying/", # change to your own path
15 | "resume": "None" # Change if you want to continue your training process
16 | }
17 |
18 | # For Zhang dataset
19 | {
20 | "patch_size": 512,
21 | "augmentation": true,
22 | "num_gpus": 1,
23 | "num_workers": 4,
24 | "num_channel": 3,
25 | "epochs": 200,
26 | "batch_size": 2,
27 | "learning_rate": 1e-4,
28 | "loss_function": "hybrid",
29 |     "dataset_dir": "/home/bigspace/xujialang/cd_dataset/Zhang/",
30 | "weight_dir": "/home/bigspace/xujialang/MFPNet_result/Zhang/",
31 | "resume": "None"
32 | }
33 |
--------------------------------------------------------------------------------
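
Note: a minimal sketch of how these JSON templates are consumed (mirroring train.py/eval.py and utils/parser.py; it assumes the snippet is run from the MFPNet_code directory with a metadata.json beside it):

    from utils.parser import get_parser_with_args

    parser, metadata = get_parser_with_args(metadata_json_path='metadata.json')
    opt = parser.parse_args([])   # no CLI flags needed: every key above becomes an argparse default
    print(opt.patch_size, opt.batch_size, opt.loss_function, opt.dataset_dir)
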
/MFPNet_code/models/MFPNet_model.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import os
7 |
8 | from .seresnet50 import se_resnet50
9 |
10 | class BasicConvBlock(nn.Module):
11 | def __init__(self, in_channels, out_channels=None):
12 | super(BasicConvBlock, self).__init__()
13 |
14 | if out_channels is None:
15 | out_channels = in_channels
16 |
17 | self.conv = nn.Sequential(
18 | nn.Conv2d(in_channels, in_channels, 3, padding=1),
19 | nn.BatchNorm2d(in_channels),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d( in_channels, out_channels, 1, bias=False),
22 | nn.BatchNorm2d(out_channels),
23 | nn.ReLU(inplace=True),
24 | nn.Conv2d( out_channels, out_channels, 3, padding=1),
25 | nn.BatchNorm2d(out_channels),
26 | nn.ReLU(inplace=True),
27 | )
28 |
29 | def forward(self,x):
30 | x=self.conv(x)
31 | return x
32 |
33 | class Conv2dStaticSamePadding(nn.Module):
34 | """
35 | created by Zylo117
36 | The real keras/tensorflow conv2d with same padding
37 | """
38 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
39 | super().__init__()
40 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
41 | bias=bias, groups=groups)
42 | self.stride = self.conv.stride
43 | self.kernel_size = self.conv.kernel_size
44 | self.dilation = self.conv.dilation
45 |
46 | if isinstance(self.stride, int):
47 | self.stride = [self.stride] * 2
48 | elif len(self.stride) == 1:
49 | self.stride = [self.stride[0]] * 2
50 |
51 | if isinstance(self.kernel_size, int):
52 | self.kernel_size = [self.kernel_size] * 2
53 | elif len(self.kernel_size) == 1:
54 | self.kernel_size = [self.kernel_size[0]] * 2
55 |
56 | def forward(self, x):
57 | h, w = x.shape[-2:]
58 |
59 | extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1]
60 | extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0]
61 |
62 | left = extra_h // 2
63 | right = extra_h - left
64 | top = extra_v // 2
65 | bottom = extra_v - top
66 |
67 | x = F.pad(x, [left, right, top, bottom])
68 |
69 | x = self.conv(x)
70 | return x
71 |
72 | class MaxPool2dStaticSamePadding(nn.Module):
73 | """
74 | created by Zylo117
75 | The real keras/tensorflow MaxPool2d with same padding
76 | """
77 | def __init__(self, *args, **kwargs):
78 | super().__init__()
79 | self.pool = nn.MaxPool2d(*args, **kwargs)
80 | self.stride = self.pool.stride
81 | self.kernel_size = self.pool.kernel_size
82 |
83 | if isinstance(self.stride, int):
84 | self.stride = [self.stride] * 2
85 | elif len(self.stride) == 1:
86 | self.stride = [self.stride[0]] * 2
87 |
88 | if isinstance(self.kernel_size, int):
89 | self.kernel_size = [self.kernel_size] * 2
90 | elif len(self.kernel_size) == 1:
91 | self.kernel_size = [self.kernel_size[0]] * 2
92 |
93 | def forward(self, x):
94 | h, w = x.shape[-2:]
95 |
96 | extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1]
97 | extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0]
98 |
99 | left = extra_h // 2
100 | right = extra_h - left
101 | top = extra_v // 2
102 | bottom = extra_v - top
103 |
104 | x = F.pad(x, [left, right, top, bottom])
105 |
106 | x = self.pool(x)
107 | return x
108 |
109 | # Channel Attention Algorithm (CAA)
110 | class CAA(nn.Module):
111 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg','max']):
112 | super(CAA, self).__init__()
113 | self.num=1
114 | self.gate_channels = gate_channels
115 |
116 | self.conv_fc1 = nn.Sequential(
117 | nn.Conv2d(in_channels=gate_channels, out_channels=gate_channels//reduction_ratio, kernel_size=1, bias=False),
118 | nn.ReLU(inplace=True),
119 | nn.Conv2d(in_channels=gate_channels//reduction_ratio, out_channels=gate_channels, kernel_size=1, bias=False),
120 | )
121 | self.conv_fc2 = nn.Sequential(
122 | nn.Conv2d(in_channels=gate_channels, out_channels=gate_channels//reduction_ratio, kernel_size=1, bias=False),
123 | nn.ReLU(inplace=True),
124 | nn.Conv2d(in_channels=gate_channels//reduction_ratio, out_channels=gate_channels, kernel_size=1, bias=False),
125 | )
126 | self.conv= nn.Sequential(
127 | nn.Conv2d(gate_channels,gate_channels,kernel_size=(2,1),bias=False),
128 | nn.Sigmoid()
129 | )
130 |
131 | self.pool_types = pool_types
132 |
133 | def forward(self, x):
134 | channel_att_sum = None
135 | b,c,h,w=x.size()
136 |
137 | for pool_type in self.pool_types:
138 | if pool_type=='avg':
139 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
140 | channel_att_raw = self.conv_fc1(avg_pool).view(b,c,self.num,self.num)
141 | elif pool_type=='max':
142 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
143 | channel_att_raw = self.conv_fc2(max_pool).view(b,c,self.num,self.num)
144 |
145 | if channel_att_sum is None:
146 | channel_att_sum = channel_att_raw
147 | else:
148 | channel_att_sum = torch.cat([channel_att_sum, channel_att_raw],dim=2)
149 |
150 | channel_weight=self.conv(channel_att_sum)
151 |         scale = F.interpolate(channel_weight, [h, w], mode='bilinear', align_corners=True)  # equivalent to the deprecated upsample_bilinear
152 |
153 | return x * scale
154 |
155 | # Multidirectional Adaptive Feature Fusion Module (MAFFM)
156 | class MAFFM(nn.Module):
157 | def __init__(self, num_channels, conv_channels):
158 | super(MAFFM, self).__init__()
159 |
160 | # Conv layers
161 | self.conv5 = BasicConvBlock(num_channels)
162 | self.conv4 = BasicConvBlock(num_channels)
163 | self.conv3 = BasicConvBlock(num_channels)
164 | self.conv2 = BasicConvBlock(num_channels)
165 | self.conv1 = BasicConvBlock(num_channels)
166 |
167 | self.conv5_1 = BasicConvBlock(num_channels)
168 | self.conv4_1 = BasicConvBlock(num_channels)
169 | self.conv3_1 = BasicConvBlock(num_channels)
170 | self.conv2_1 = BasicConvBlock(num_channels)
171 | self.conv1_1 = BasicConvBlock(num_channels)
172 |
173 | self.conv1_down = BasicConvBlock(num_channels)
174 | self.conv2_down = BasicConvBlock(num_channels)
175 | self.conv3_down = BasicConvBlock(num_channels)
176 | self.conv4_down = BasicConvBlock(num_channels)
177 | self.conv5_down = BasicConvBlock(num_channels)
178 |
179 | # Feature scaling layers
180 | self.p4_upsample_1 = nn.Upsample(scale_factor=2, mode='nearest')
181 | self.p3_upsample_1 = nn.Upsample(scale_factor=2, mode='nearest')
182 | self.p2_upsample_1 = nn.Upsample(scale_factor=2, mode='nearest')
183 | self.p1_upsample_1 = nn.Upsample(scale_factor=2, mode='nearest')
184 |
185 | self.p2_downsample = MaxPool2dStaticSamePadding(3, 2)
186 | self.p3_downsample = MaxPool2dStaticSamePadding(3, 2)
187 | self.p4_downsample = MaxPool2dStaticSamePadding(3, 2)
188 | self.p5_downsample = MaxPool2dStaticSamePadding(3, 2)
189 |
190 | self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
191 | self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')
192 | self.p2_upsample = nn.Upsample(scale_factor=2, mode='nearest')
193 | self.p1_upsample = nn.Upsample(scale_factor=2, mode='nearest')
194 |
195 | # Channel compression layers
196 | self.p5_down_channel = nn.Sequential(
197 | Conv2dStaticSamePadding(conv_channels[4], num_channels, 1),
198 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
199 | nn.ReLU(inplace=True),
200 | )
201 | self.p4_down_channel = nn.Sequential(
202 | Conv2dStaticSamePadding(conv_channels[3], num_channels, 1),
203 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
204 | nn.ReLU(inplace=True),
205 | )
206 | self.p3_down_channel = nn.Sequential(
207 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
208 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
209 | nn.ReLU(inplace=True),
210 | )
211 | self.p2_down_channel = nn.Sequential(
212 | Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
213 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
214 | nn.ReLU(inplace=True),
215 | )
216 | self.p1_down_channel = nn.Sequential(
217 | Conv2dStaticSamePadding(conv_channels[0], num_channels, 1),
218 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
219 | nn.ReLU(inplace=True),
220 | )
221 |
222 | # CAA
223 | self.csac_p1_0=CAA(num_channels,reduction_ratio=1)
224 | self.csac_p2_0=CAA(num_channels,reduction_ratio=1)
225 | self.csac_p3_0=CAA(num_channels,reduction_ratio=1)
226 | self.csac_p4_0=CAA(num_channels,reduction_ratio=1)
227 | self.csac_p5_0=CAA(num_channels,reduction_ratio=1)
228 |
229 | self.csac_p1_1=CAA(num_channels,reduction_ratio=1)
230 | self.csac_p2_1=CAA(num_channels,reduction_ratio=1)
231 | self.csac_p3_1=CAA(num_channels,reduction_ratio=1)
232 | self.csac_p4_1=CAA(num_channels,reduction_ratio=1)
233 | self.csac_p5_1=CAA(num_channels,reduction_ratio=1)
234 |
235 | self.csac_p1_2=CAA(num_channels,reduction_ratio=1)
236 | self.csac_p2_2=CAA(num_channels,reduction_ratio=1)
237 | self.csac_p3_2=CAA(num_channels,reduction_ratio=1)
238 | self.csac_p4_2=CAA(num_channels,reduction_ratio=1)
239 |
240 |
241 | self.csac_p51_0=CAA(num_channels,reduction_ratio=1)
242 | self.csac_p41_0=CAA(num_channels,reduction_ratio=1)
243 | self.csac_p31_0=CAA(num_channels,reduction_ratio=1)
244 | self.csac_p21_0=CAA(num_channels,reduction_ratio=1)
245 |
246 | self.csac_p51_1=CAA(num_channels,reduction_ratio=1)
247 | self.csac_p41_1=CAA(num_channels,reduction_ratio=1)
248 | self.csac_p31_1=CAA(num_channels,reduction_ratio=1)
249 | self.csac_p21_1=CAA(num_channels,reduction_ratio=1)
250 |
251 |
252 | self.csac_p52_0=CAA(num_channels,reduction_ratio=1)
253 | self.csac_p42_0=CAA(num_channels,reduction_ratio=1)
254 | self.csac_p32_0=CAA(num_channels,reduction_ratio=1)
255 | self.csac_p22_0=CAA(num_channels,reduction_ratio=1)
256 | self.csac_p12_0=CAA(num_channels,reduction_ratio=1)
257 |
258 | self.csac_p52_1=CAA(num_channels,reduction_ratio=1)
259 | self.csac_p42_1=CAA(num_channels,reduction_ratio=1)
260 | self.csac_p32_1=CAA(num_channels,reduction_ratio=1)
261 | self.csac_p22_1=CAA(num_channels,reduction_ratio=1)
262 | self.csac_p12_1=CAA(num_channels,reduction_ratio=1)
263 |
264 | self.csac_p42_2=CAA(num_channels,reduction_ratio=1)
265 | self.csac_p32_2=CAA(num_channels,reduction_ratio=1)
266 | self.csac_p22_2=CAA(num_channels,reduction_ratio=1)
267 | self.csac_p12_2=CAA(num_channels,reduction_ratio=1)
268 |
269 | def forward(self, inputs):
270 | p1_pre, p2_pre, p3_pre, p4_pre, p5_pre, p1_now, p2_now, p3_now, p4_now, p5_now = inputs
271 |
272 | p1_in_pre = self.p1_down_channel(p1_pre)
273 | p1_in_now = self.p1_down_channel(p1_now)
274 |
275 | p2_in_pre = self.p2_down_channel(p2_pre)
276 | p2_in_now = self.p2_down_channel(p2_now)
277 |
278 | p3_in_pre = self.p3_down_channel(p3_pre)
279 | p3_in_now = self.p3_down_channel(p3_now)
280 |
281 | p4_in_pre = self.p4_down_channel(p4_pre)
282 | p4_in_now = self.p4_down_channel(p4_now)
283 |
284 | p5_in_pre = self.p5_down_channel(p5_pre)
285 | p5_in_now = self.p5_down_channel(p5_now)
286 |
287 | # Multidirectional Fusion Pathway (MFP) + Adaptive Weighted Fusion (AWF)
288 | # Up
289 | p5_in=self.conv5(self.csac_p5_0(p5_in_now)+self.csac_p5_1(p5_in_pre))
290 | p4_in=self.conv4(self.csac_p4_0(p4_in_now)+self.csac_p4_1(p4_in_pre)+self.csac_p4_2(self.p4_upsample(p5_in)))
291 | p3_in=self.conv3(self.csac_p3_0(p3_in_now)+self.csac_p3_1(p3_in_pre)+self.csac_p3_2(self.p3_upsample(p4_in)))
292 | p2_in=self.conv2(self.csac_p2_0(p2_in_now)+self.csac_p2_1(p2_in_pre)+self.csac_p2_2(self.p2_upsample(p3_in)))
293 | p1_in=self.conv1(self.csac_p1_0(p1_in_now)+self.csac_p1_1(p1_in_pre)+self.csac_p1_2(self.p1_upsample(p2_in)))
294 | # Down
295 | p1_1 = self.conv1_down(p1_in)
296 | p2_1 = self.conv2_down(self.csac_p21_0(p2_in) + self.csac_p21_1(self.p2_downsample(p1_1)))
297 | p3_1 = self.conv3_down(self.csac_p31_0(p3_in) + self.csac_p31_1(self.p3_downsample(p2_1)))
298 | p4_1 = self.conv4_down(self.csac_p41_0(p4_in) + self.csac_p41_1(self.p4_downsample(p3_1)))
299 | p5_1 = self.conv5_down(self.csac_p51_0(p5_in) + self.csac_p51_1(self.p5_downsample(p4_1)))
300 | # Up
301 | p5_2 = self.conv5_1(self.csac_p52_0(p5_in) + self.csac_p52_1(p5_1))
302 | p4_2 = self.conv4_1(self.csac_p42_0(p4_in) + self.csac_p42_1(p4_1)+self.csac_p42_2(self.p4_upsample_1(p5_2)))
303 | p3_2 = self.conv3_1(self.csac_p32_0(p3_in) + self.csac_p32_1(p3_1)+self.csac_p32_2(self.p3_upsample_1(p4_2)))
304 | p2_2 = self.conv2_1(self.csac_p22_0(p2_in) + self.csac_p22_1(p2_1)+self.csac_p22_2(self.p2_upsample_1(p3_2)))
305 | p1_2 = self.conv1_1(self.csac_p12_0(p1_in) + self.csac_p12_1(p1_1)+self.csac_p12_2(self.p1_upsample_1(p2_2)))
306 |
307 | return p1_2
308 |
309 | class DECODER(nn.Module):
310 | def __init__(self, in_ch, classes):
311 | super(DECODER, self).__init__()
312 | self.conv1 = nn.Conv2d(
313 | in_ch, in_ch//4, kernel_size=3, padding=1)
314 | self.conv2 = nn.Conv2d(
315 | in_ch//4, in_ch//8, kernel_size=3, padding=1)
316 | self.conv3 = nn.Conv2d(
317 | in_ch//8, classes*4, kernel_size=1)
318 |
319 | self.ps3 = nn.PixelShuffle(2)
320 |
321 | def forward(self, x):
322 | x = self.conv1(x)
323 | x = self.conv2(x)
324 | x = self.conv3(x)
325 |
326 | x = self.ps3(x)
327 |
328 | return x
329 |
330 | class MFPNET(nn.Module):
331 | def __init__(self, classes):
332 | super(MFPNET, self).__init__()
333 |
334 | self.se_resnet50 = se_resnet50(pretrained=True, strides = (1,2,2,2))
335 | self.stage1 = nn.Sequential(self.se_resnet50.conv1, self.se_resnet50.bn1, self.se_resnet50.relu)
336 | self.stage2 = nn.Sequential(self.se_resnet50.maxpool, self.se_resnet50.layer1)
337 | self.stage3 = nn.Sequential(self.se_resnet50.layer2)
338 | self.stage4 = nn.Sequential(self.se_resnet50.layer3)
339 | self.stage5 = nn.Sequential(self.se_resnet50.layer4)
340 |
341 | self.maffm=MAFFM(256,[64,256,512,1024,2048])
342 | self.dec = DECODER(256, classes)
343 |
344 | def encoder(self, x):
345 | x1 = self.stage1(x)
346 | x2 = self.stage2(x1)
347 | x3 = self.stage3(x2)
348 | x4 = self.stage4(x3)
349 | x5 = self.stage5(x4)
350 |
351 | return x1, x2, x3, x4, x5
352 |
353 | def forward(self, x_prev, x_now):
354 | p1_t1, p2_t1, p3_t1, p4_t1, p5_t1 = self.encoder(x_prev)
355 | p1_t2, p2_t2, p3_t2, p4_t2, p5_t2 = self.encoder(x_now)
356 | features_t1_t2 = (p1_t1, p2_t1, p3_t1, p4_t1, p5_t1, p1_t2, p2_t2, p3_t2, p4_t2, p5_t2)
357 |
358 | x_fuse=self.maffm(features_t1_t2)
359 | dis_map=self.dec(x_fuse)
360 |
361 | return dis_map
362 |
363 | if __name__ == "__main__":
364 | model = MFPNET(classes = 2)
365 |
366 | # # Example for using Perceptual Similarity Module
367 | # from vgg import Vgg19
368 |
369 | # criterion_perceptual = nn.MSELoss()
370 | # criterion_perceptual.cuda()
371 | # vgg= Vgg19().cuda()
372 |
373 | # for epoch in range(300):
374 | # for i, (data_prev, data_now, label) in enumerate(loader_train, 0):
375 | # model.train()
376 | # model.zero_grad()
377 | # optimizer.zero_grad()
378 | # img_prev_train, img_now_train, label_train = data_prev.cuda(), data_now.cuda(), label.cuda()
379 |
380 | #         out_train1 = model(img_prev_train, img_now_train)
381 |
382 | # # Perceptual Similarity Module (PSM)
383 | # out_train_softmax2d = F.softmax(out_train1,dim=1)
384 | # an_change = out_train_softmax2d[:,1,:,:].unsqueeze(1).expand_as(img_prev_train)
385 | # an_unchange = out_train_softmax2d[:,0,:,:].unsqueeze(1).expand_as(img_prev_train)
386 | # label_change = label_train.expand_as(img_prev_train).type(torch.FloatTensor).cuda()
387 | # label_unchange = 1-label_change
388 | # an_change = an_change*label_change
389 | # an_unchange = an_unchange*(1-label_change)
390 |
391 | # an_change_feature = vgg(an_change)
392 | # gt_feature = vgg(label_change)
393 | # an_unchange_feature = vgg(an_unchange)
394 | # gt_feature_unchange = vgg(label_unchange)
395 |
396 | # perceptual_loss_change = criterion_perceptual(an_change_feature[0], gt_feature[0])
397 | # perceptual_loss_unchange = criterion_perceptual(an_unchange_feature[0], gt_feature_unchange[0])
398 | # perceptual_loss = perceptual_loss_change + perceptual_loss_unchange
--------------------------------------------------------------------------------
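
Note: a minimal forward-pass sanity check for MFPNET (not part of the repository; it assumes it is run from the MFPNet_code directory and that the pretrained SE-ResNet-50 weights referenced in seresnet50.py can be downloaded):

    import torch
    from models.MFPNet_model import MFPNET

    model = MFPNET(classes=2).eval()
    t1 = torch.randn(1, 3, 256, 256)   # pre-change image
    t2 = torch.randn(1, 3, 256, 256)   # post-change image
    with torch.no_grad():
        dis_map = model(t1, t2)
    print(dis_map.shape)               # expected: torch.Size([1, 2, 256, 256])
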
/MFPNet_code/models/seresnet50.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | model_urls = {
6 | 'seresnet50': 'https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl'
7 | }
8 |
9 | class SELayer(nn.Module):
10 | def __init__(self, channel, reduction=16):
11 | super(SELayer, self).__init__()
12 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
13 | self.fc = nn.Sequential(
14 | nn.Linear(channel, channel // reduction, bias=False),
15 | nn.ReLU(inplace=True),
16 | nn.Linear(channel // reduction, channel, bias=False),
17 | nn.Sigmoid()
18 | )
19 |
20 | def forward(self, x):
21 | b, c, _, _ = x.size()
22 | y = self.avg_pool(x).view(b, c)
23 | y = self.fc(y).view(b, c, 1, 1)
24 | return x * y.expand_as(x)
25 |
26 | class FixedBatchNorm(nn.BatchNorm2d):
27 | def forward(self, input):
28 | return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias,
29 | training=False, eps=self.eps)
30 |
31 |
32 | class Bottleneck(nn.Module):
33 | expansion = 4
34 |
35 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, reduction=16):
36 | super(Bottleneck, self).__init__()
37 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
38 | self.bn1 = FixedBatchNorm(planes)
39 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
40 | padding=dilation, bias=False, dilation=dilation)
41 | self.bn2 = FixedBatchNorm(planes)
42 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
43 | self.bn3 = FixedBatchNorm(planes * 4)
44 | self.relu = nn.ReLU(inplace=True)
45 | # Squeeze-and-Excitation
46 | self.se = SELayer(planes * 4, reduction)
47 | # Downsample
48 | self.downsample = downsample
49 | self.stride = stride
50 | self.dilation = dilation
51 |
52 | def forward(self, x):
53 | residual = x
54 |
55 | out = self.conv1(x)
56 | out = self.bn1(out)
57 | out = self.relu(out)
58 |
59 | out = self.conv2(out)
60 | out = self.bn2(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv3(out)
64 | out = self.bn3(out)
65 |
66 | out = self.se(out)
67 |
68 | if self.downsample is not None:
69 | residual = self.downsample(x)
70 |
71 | out += residual
72 | out = self.relu(out)
73 |
74 | return out
75 |
76 | class Bottleneck_mdcn(nn.Module):
77 | expansion = 4
78 |
79 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, reduction=16):
80 | super(Bottleneck_mdcn, self).__init__()
81 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
82 | self.bn1 = FixedBatchNorm(planes)
83 | self.conv2= ModulatedDeformConvPack(planes, planes, kernel_size=(3, 3), stride=stride,
84 |                                padding=dilation, dilation=dilation, bias=False, deformable_groups=2)  # NOTE: ModulatedDeformConvPack is not imported in this file; Bottleneck_mdcn needs an external deformable-conv implementation
85 | self.bn2 = FixedBatchNorm(planes)
86 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
87 | self.bn3 = FixedBatchNorm(planes * 4)
88 | self.relu = nn.ReLU(inplace=True)
89 | # Squeeze-and-Excitation
90 | self.se = SELayer(planes * 4, reduction)
91 | # Downsample
92 | self.downsample = downsample
93 | self.stride = stride
94 | self.dilation = dilation
95 |
96 | def forward(self, x):
97 | residual = x
98 |
99 | out = self.conv1(x)
100 | out = self.bn1(out)
101 | out = self.relu(out)
102 |
103 | out = self.conv2(out)
104 | out = self.bn2(out)
105 | out = self.relu(out)
106 |
107 | out = self.conv3(out)
108 | out = self.bn3(out)
109 |
110 | out = self.se(out)
111 |
112 | if self.downsample is not None:
113 | residual = self.downsample(x)
114 |
115 | out += residual
116 | out = self.relu(out)
117 |
118 | return out
119 |
120 |
121 | class SEResNet(nn.Module):
122 |
123 | def __init__(self, block, layers, strides=(2, 2, 2, 2), dilations=(1, 1, 2, 4),zero_init_residual=True):
124 | super(SEResNet, self).__init__()
125 | self.inplanes = 64
126 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
127 | bias=False)
128 | self.bn1 = FixedBatchNorm(64)
129 | self.relu = nn.ReLU(inplace=True)
130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
131 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1, dilation=dilations[0])
132 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
133 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
134 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3])
135 | self.inplanes = 1024
136 |
137 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
138 | downsample = None
139 | if stride != 1 or self.inplanes != planes * block.expansion:
140 | downsample = nn.Sequential(
141 | nn.Conv2d(self.inplanes, planes * block.expansion,
142 | kernel_size=1, stride=stride, bias=False),
143 | FixedBatchNorm(planes * block.expansion),
144 | )
145 |
146 | layers = [block(self.inplanes, planes, stride, downsample, dilation=1)]
147 | self.inplanes = planes * block.expansion
148 | for i in range(1, blocks):
149 | layers.append(block(self.inplanes, planes, dilation=dilation))
150 |
151 | return nn.Sequential(*layers)
152 |
153 |
154 |
155 |
156 |
157 | def forward(self, x):
158 | x = self.conv1(x)
159 | x = self.bn1(x)
160 | x = self.relu(x)
161 | x = self.maxpool(x)
162 |
163 | x = self.layer1(x)
164 | x = self.layer2(x)
165 | x = self.layer3(x)
166 | x = self.layer4(x)
167 |
168 |         x = self.avgpool(x)  # NOTE: avgpool and fc are never defined in __init__; this classifier forward() is unused by MFPNet, which consumes the stages directly
169 | x = x.view(x.size(0), -1)
170 | x = self.fc(x)
171 |
172 | return x
173 |
174 |
175 | def se_resnet50(pretrained=True, **kwargs):
176 |
177 | model = SEResNet(Bottleneck,layers=[3, 4, 6, 3], **kwargs)
178 | if pretrained:
179 | state_dict = model_zoo.load_url(model_urls['seresnet50'])
180 | model_dict = model.state_dict()
181 |
182 | state_dict.pop('fc.weight')
183 | state_dict.pop('fc.bias')
184 |
185 | # state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
186 | # state_dict.update(state_dict)
187 | model.load_state_dict(state_dict)
188 |         print("Successfully loaded the pretrained weights")
189 | return model
190 |
191 | if __name__ == "__main__":
192 |     model = nn.DataParallel(se_resnet50(pretrained=True, strides = (1, 2, 1, 2)), device_ids=[0]).cuda()
--------------------------------------------------------------------------------
/MFPNet_code/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models
4 |
5 | class Vgg19(nn.Module):
6 | def __init__(self):
7 | super(Vgg19, self).__init__()
8 | features = models.vgg19(pretrained=True).features
9 | self.to_relu_1_2 = nn.Sequential()
10 | self.to_relu_2_2 = nn.Sequential()
11 | self.to_relu_3_4 = nn.Sequential()
12 | self.to_relu_4_4 = nn.Sequential()
13 | self.to_relu_5_4 = nn.Sequential()
14 | # conv -1
15 | for x in range(4):
16 | self.to_relu_1_2.add_module(str(x), features[x])
17 | for x in range(4, 9):
18 | self.to_relu_2_2.add_module(str(x), features[x])
19 | for x in range(9, 18):
20 | self.to_relu_3_4.add_module(str(x), features[x])
21 | for x in range(18, 27):
22 | self.to_relu_4_4.add_module(str(x), features[x])
23 | for x in range(27, 36):
24 | self.to_relu_5_4.add_module(str(x), features[x])
25 |
26 | # don't need the gradients, just want the features
27 | for param in self.parameters():
28 | param.requires_grad = False
29 |
30 | def forward(self, x):
31 | h = self.to_relu_1_2(x)
32 | h_relu_1_2 = h
33 | h = self.to_relu_2_2(h)
34 | h_relu_2_2 = h
35 | h = self.to_relu_3_4(h)
36 | h_relu_3_4 = h
37 | h = self.to_relu_4_4(h)
38 | h_relu_4_4 = h
39 | h = self.to_relu_5_4(h)
40 | h_relu_5_4 = h
41 |
42 | out = (h_relu_1_2, h_relu_2_2, h_relu_3_4, h_relu_4_4, h_relu_5_4)
43 | return out
--------------------------------------------------------------------------------
/MFPNet_code/train.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from utils.parser import get_parser_with_args
6 | from utils.helpers import (get_loaders, get_criterion,
7 | load_model, initialize_metrics, get_mean_metrics,
8 | set_metrics)
9 | from sklearn.metrics import precision_recall_fscore_support as prfs
10 | import os
11 | import logging
12 | import json
13 | import random
14 | import numpy as np
15 | import re
16 | import warnings
17 | from models.vgg import Vgg19
18 | warnings.filterwarnings("ignore")
19 |
20 | """
21 | Initialize Parser and define arguments
22 | """
23 | parser, metadata = get_parser_with_args(metadata_json_path='/home/aaa/xujialang/master_thesis/MFPNet/metadata.json')
24 | opt = parser.parse_args()
25 |
26 | """
27 | Initialize experiments log
28 | """
29 | logging.basicConfig(level=logging.INFO)
30 |
31 | """
32 | Set up environment: define paths, download data, and set device
33 | """
34 | dev = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
35 | logging.info('GPU AVAILABLE? ' + str(torch.cuda.is_available()))
36 |
37 | def seed_torch(seed):
38 | random.seed(seed)
39 | os.environ['PYTHONHASHSEED'] = str(seed)
40 | np.random.seed(seed)
41 | torch.manual_seed(seed)
42 | torch.cuda.manual_seed(seed)
43 | # torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
44 | torch.backends.cudnn.benchmark = False
45 | torch.backends.cudnn.deterministic = True
46 | seed_torch(seed=777)
47 |
48 | train_loader, val_loader = get_loaders(opt)
49 | print(opt.batch_size * len(train_loader))
50 | print(opt.batch_size * len(val_loader))
51 |
52 | """
53 | Load Model then define other aspects of the model
54 | """
55 | logging.info('LOADING Model')
56 | model = load_model(opt, dev)
57 | vgg=Vgg19().to(dev)
58 | """
59 | Resume
60 | """
61 | epoch_resume=0
62 | if opt.resume != "None":
63 | model.load_state_dict(torch.load(os.path.join(opt.resume)))
64 |     epoch_resume = int(re.sub(r"\D", "", opt.resume))
65 | print('resume success: epoch {}'.format(epoch_resume))
66 |
67 | criterion_ce = nn.CrossEntropyLoss().to(dev)
68 | criterion_perceptual = nn.MSELoss().to(dev)
69 | criterion = get_criterion(opt)
70 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate) # Be careful when you adjust learning rate, you can refer to the linear scaling rule
71 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, T_mult=2, eta_min=0, last_epoch=-1)
72 |
73 | """
74 | Set starting values
75 | """
76 | best_metrics = {'cd_f1scores': -1, 'cd_recalls': -1, 'cd_precisions': -1}
77 | logging.info('STARTING training')
78 |
79 | for epoch in range(opt.epochs):
80 | epoch= epoch + epoch_resume +1
81 | train_metrics = initialize_metrics()
82 | val_metrics = initialize_metrics()
83 |
84 | """
85 | Begin Training
86 | """
87 | model.train()
88 | logging.info('SET model mode to train!')
89 |
90 | for batch_img1, batch_img2, labels in train_loader:
91 | # Set variables for training
92 | batch_img1 = batch_img1.float().to(dev)
93 | batch_img2 = batch_img2.float().to(dev)
94 | labels = labels.long().to(dev)
95 |
96 | # Zero the gradient
97 | optimizer.zero_grad()
98 |
99 | # Get model predictions, calculate loss, backprop
100 | cd_preds= model(batch_img1, batch_img2)
101 | loss = criterion(criterion_ce, criterion_perceptual, cd_preds, labels, batch_img1, vgg, dev)
102 |
103 | loss.backward()
104 | optimizer.step()
105 |
106 | # Calculate and log other batch metrics
107 | cd_preds = torch.argmax(cd_preds, dim = 1)
108 | cd_corrects = (100 *
109 | (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
110 | (labels.size()[0] * (opt.patch_size**2)))
111 | cd_train_report = prfs(labels.data.cpu().numpy().flatten(),
112 | cd_preds.data.cpu().numpy().flatten(),
113 | average='binary',
114 | pos_label=1)
115 | train_metrics = set_metrics(train_metrics,
116 | loss,
117 | cd_corrects,
118 | cd_train_report,
119 | scheduler.get_last_lr())
120 |
121 | # log the batch mean metrics
122 | mean_train_metrics = get_mean_metrics(train_metrics)
123 |
124 | # clear batch variables from memory
125 | del batch_img1, batch_img2, labels
126 |
127 | scheduler.step()
128 | logging.info("EPOCH {} TRAIN METRICS. ".format(epoch) + str(mean_train_metrics))
129 |
130 |
131 | """
132 | Begin Validation
133 | """
134 | model.eval()
135 | with torch.no_grad():
136 | for batch_img1, batch_img2, labels in val_loader:
137 | # Set variables for training
138 | batch_img1 = batch_img1.float().to(dev)
139 | batch_img2 = batch_img2.float().to(dev)
140 | labels = labels.long().to(dev)
141 |
142 | # Get predictions and calculate loss
143 | cd_preds = model(batch_img1, batch_img2)
144 | val_loss = criterion(criterion_ce, criterion_perceptual, cd_preds, labels, batch_img1, vgg, dev)
145 |
146 | # Calculate and log other batch metrics
147 | cd_preds = torch.argmax(cd_preds, dim = 1)
148 | cd_corrects = (100 *
149 | (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
150 | (labels.size()[0] * (opt.patch_size**2)))
151 | cd_val_report = prfs(labels.data.cpu().numpy().flatten(),
152 | cd_preds.data.cpu().numpy().flatten(),
153 | average='binary',
154 | pos_label=1)
155 | val_metrics = set_metrics(val_metrics,
156 | val_loss,
157 | cd_corrects,
158 | cd_val_report,
159 |                                         scheduler.get_last_lr())
160 |
161 | # log the batch mean metrics
162 | mean_val_metrics = get_mean_metrics(val_metrics)
163 |
164 | # clear batch variables from memory
165 | del batch_img1, batch_img2, labels
166 |
167 | logging.info("EPOCH {} VALIDATION METRICS".format(epoch)+str(mean_val_metrics))
168 |
169 | """
170 | Store the weights of good epochs based on validation results
171 | """
172 | if (mean_val_metrics['cd_f1scores'] > best_metrics['cd_f1scores']):
173 | # Insert training and epoch information to metadata dictionary
174 |             logging.info('update the model')
175 | metadata['val_metrics'] = mean_val_metrics
176 |
177 | # Save model and log
178 | if not os.path.exists(opt.weight_dir):
179 | os.mkdir(opt.weight_dir)
180 | with open(opt.weight_dir + 'metadata_val_epoch_' + str(epoch) + '.json', 'w') as fout:
181 | json.dump(metadata, fout)
182 |
183 | torch.save(model.state_dict(), opt.weight_dir + 'checkpoint_epoch_'+str(epoch)+'_f1_'+str(mean_val_metrics['cd_f1scores'])+'.pt')
184 | best_metrics = mean_val_metrics
185 | print('best val: ' + str(mean_val_metrics))
186 |
187 | print('An epoch finished.')
188 |
189 | print('Done!')
--------------------------------------------------------------------------------
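
Note: CosineAnnealingWarmRestarts(optimizer, 10, T_mult=2) as used above has cycle lengths of 10, 20, 40, 80, ... epochs, so with scheduler.step() called once per epoch the learning rate restarts at epochs 10, 30, 70 and 150. A small standalone sketch (illustrative only) to inspect the curve:

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, T_mult=2, eta_min=0)
    for epoch in range(1, 201):
        scheduler.step()
        if epoch in (10, 30, 70, 150):
            print(epoch, scheduler.get_last_lr())   # lr jumps back to ~1e-4 right after each restart
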
/MFPNet_code/utils/dataloaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | from PIL import Image
4 | from utils import transforms as tr
5 |
6 |
7 | '''
8 | Load all training and validation data paths
9 | '''
10 | def full_path_loader(data_dir):
11 | train_data = [i for i in os.listdir(data_dir + 'train/A/') if not
12 | i.startswith('.')]
13 | train_data.sort()
14 |
15 | valid_data = [i for i in os.listdir(data_dir + 'val/A/') if not
16 | i.startswith('.')]
17 | valid_data.sort()
18 |
19 | train_label_paths = []
20 | val_label_paths = []
21 | if 'DSIFN' in data_dir:
22 | for img in train_data:
23 | train_label_paths.append(data_dir + 'train/label/' + img.split('.')[0] + '.png')
24 | for img in valid_data:
25 | val_label_paths.append(data_dir + 'val/label/' + img.split('.')[0] + '.png')
26 | else:
27 | for img in train_data:
28 | train_label_paths.append(data_dir + 'train/label/' + img)
29 | for img in valid_data:
30 | val_label_paths.append(data_dir + 'val/label/' + img)
31 |
32 |
33 | train_data_path = []
34 | val_data_path = []
35 |
36 | for img in train_data:
37 | train_data_path.append([data_dir + 'train/', img])
38 | for img in valid_data:
39 | val_data_path.append([data_dir + 'val/', img])
40 |
41 | train_dataset = {}
42 | val_dataset = {}
43 | for cp in range(len(train_data)):
44 | train_dataset[cp] = {'image': train_data_path[cp],
45 | 'label': train_label_paths[cp]}
46 | for cp in range(len(valid_data)):
47 | val_dataset[cp] = {'image': val_data_path[cp],
48 | 'label': val_label_paths[cp]}
49 |
50 |
51 | return train_dataset, val_dataset
52 |
53 | '''
54 | Load all testing data paths
55 | '''
56 | def full_test_loader(data_dir):
57 |
58 | test_data = [i for i in os.listdir(data_dir + 'test/A/') if not
59 | i.startswith('.')]
60 | test_data.sort()
61 |
62 | test_label_paths = []
63 | if 'DSIFN' in data_dir:
64 | for img in test_data:
65 | test_label_paths.append(data_dir + 'test/label/' + img.split('.')[0] + '.tif')
66 | else:
67 | for img in test_data:
68 | test_label_paths.append(data_dir + 'test/label/' + img)
69 |
70 | test_data_path = []
71 | for img in test_data:
72 | test_data_path.append([data_dir + 'test/', img])
73 |
74 | test_dataset = {}
75 | for cp in range(len(test_data)):
76 | test_dataset[cp] = {'image': test_data_path[cp],
77 | 'label': test_label_paths[cp]}
78 |
79 | return test_dataset
80 |
81 | def cdd_loader(img_path, label_path, aug):
82 | dir = img_path[0]
83 | name = img_path[1]
84 |
85 | img1 = Image.open(dir + 'A/' + name)
86 | img2 = Image.open(dir + 'B/' + name)
87 | label = Image.open(label_path).convert('L')
88 | sample = {'image': (img1, img2), 'label': label}
89 |
90 | if aug:
91 | sample = tr.train_transforms(sample)
92 | else:
93 | sample = tr.test_transforms(sample)
94 |
95 | return sample['image'][0], sample['image'][1], sample['label']
96 |
97 |
98 | class CDDloader(data.Dataset):
99 |
100 | def __init__(self, full_load, aug=False):
101 |
102 | self.full_load = full_load
103 | self.loader = cdd_loader
104 | self.aug = aug
105 |
106 | def __getitem__(self, index):
107 |
108 | img_path, label_path = self.full_load[index]['image'], self.full_load[index]['label']
109 |
110 | return self.loader(img_path,
111 | label_path,
112 | self.aug)
113 |
114 | def __len__(self):
115 | return len(self.full_load)
116 |
--------------------------------------------------------------------------------
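
Note: the loaders above assume a fixed layout under dataset_dir: co-registered image pairs with identical filenames in A/ and B/, plus a label image of the same name under label/ (DSIFN is special-cased to .png/.tif label extensions). A sketch of the expected structure (directory names come from the code; contents are illustrative):

    dataset_dir/
    ├── train
    │   ├── A/        # pre-change images
    │   ├── B/        # post-change images
    │   └── label/    # binary change masks, same filenames as A/ and B/
    ├── val
    │   ├── A/
    │   ├── B/
    │   └── label/
    └── test
        ├── A/
        ├── B/
        └── label/
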
/MFPNet_code/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import torch.utils.data
4 | import torch.nn as nn
5 | import numpy as np
6 | from utils.dataloaders import (full_path_loader, full_test_loader, CDDloader)
7 | from utils.metrics import jaccard_loss, dice_loss
8 | from utils.hybridloss import hybrid_loss
9 | from models.MFPNet_model import MFPNET
10 | logging.basicConfig(level=logging.INFO)
11 |
12 | def initialize_metrics():
13 | """Generates a dictionary of metrics with metrics as keys
14 | and empty lists as values
15 |
16 | Returns
17 | -------
18 | dict
19 | a dictionary of metrics
20 |
21 | """
22 | metrics = {
23 | 'cd_losses': [],
24 | 'cd_corrects': [],
25 | 'cd_precisions': [],
26 | 'cd_recalls': [],
27 | 'cd_f1scores': [],
28 | 'learning_rate': [],
29 | }
30 |
31 | return metrics
32 |
33 |
34 | def get_mean_metrics(metric_dict):
35 | """takes a dictionary of lists for metrics and returns dict of mean values
36 |
37 | Parameters
38 | ----------
39 | metric_dict : dict
40 | A dictionary of metrics
41 |
42 | Returns
43 | -------
44 | dict
45 | dict of floats that reflect mean metric value
46 |
47 | """
48 | return {k: np.mean(v) for k, v in metric_dict.items()}
49 |
50 |
51 |
52 | def set_metrics(metric_dict, cd_loss, cd_corrects, cd_report, lr):
53 | """Updates metric dict with batch metrics
54 |
55 | Parameters
56 | ----------
57 | metric_dict : dict
58 | dict of metrics
59 |     cd_loss : torch.Tensor
60 |         batch loss value
61 |     cd_corrects : torch.Tensor
62 |         percentage of correctly classified pixels (used as an accuracy measure)
63 |     cd_report : tuple
64 |         precision, recall and f1 values from sklearn's precision_recall_fscore_support
65 |
66 | Returns
67 | -------
68 | dict
69 | dict of updated metrics
70 |
71 |
72 | """
73 | metric_dict['cd_losses'].append(cd_loss.item())
74 | metric_dict['cd_corrects'].append(cd_corrects.item())
75 | metric_dict['cd_precisions'].append(cd_report[0])
76 | metric_dict['cd_recalls'].append(cd_report[1])
77 | metric_dict['cd_f1scores'].append(cd_report[2])
78 | metric_dict['learning_rate'].append(lr)
79 |
80 | return metric_dict
81 |
82 | def get_loaders(opt):
83 |
84 |
85 | logging.info('STARTING Dataset Creation')
86 |
87 | train_full_load, val_full_load = full_path_loader(opt.dataset_dir)
88 |
89 |
90 | train_dataset = CDDloader(train_full_load, aug=opt.augmentation)
91 | val_dataset = CDDloader(val_full_load, aug=False)
92 |
93 | logging.info('STARTING Dataloading')
94 |
95 | train_loader = torch.utils.data.DataLoader(train_dataset,
96 | batch_size=opt.batch_size,
97 | shuffle=True,
98 | num_workers=opt.num_workers)
99 | val_loader = torch.utils.data.DataLoader(val_dataset,
100 | batch_size=opt.batch_size,
101 | shuffle=False,
102 | num_workers=opt.num_workers)
103 | return train_loader, val_loader
104 |
105 | def get_test_loaders(opt):
106 |
107 | logging.info('STARTING Test Dataset Creation')
108 |
109 | test_full_load = full_test_loader(opt.dataset_dir)
110 |
111 | test_dataset = CDDloader(test_full_load, aug=False)
112 |
113 | logging.info('STARTING Test Dataloading')
114 |
115 | test_loader = torch.utils.data.DataLoader(test_dataset,
116 | batch_size=1,
117 | shuffle=False,
118 | num_workers=opt.num_workers)
119 | return test_loader
120 |
121 |
122 | def get_criterion(opt):
123 | """get the user selected loss function
124 |
125 | Parameters
126 | ----------
127 | opt : dict
128 | Dictionary of options/flags
129 |
130 | Returns
131 | -------
132 | method
133 | loss function
134 |
135 | """
136 | if opt.loss_function == 'hybrid':
137 | criterion = hybrid_loss
138 | if opt.loss_function == 'bce':
139 | criterion = nn.CrossEntropyLoss()
140 | if opt.loss_function == 'dice':
141 | criterion = dice_loss
142 | if opt.loss_function == 'jaccard':
143 | criterion = jaccard_loss
144 |
145 | return criterion
146 |
147 |
148 | def load_model(opt, device):
149 | """Load the model
150 |
151 | Parameters
152 | ----------
153 | opt : dict
154 | User specified flags/options
155 | device : string
156 | device on which to train model
157 |
158 | """
159 | model = MFPNET(classes = 2).to(device)
160 |
161 | return model
162 |
--------------------------------------------------------------------------------
/MFPNet_code/utils/hybridloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def hybrid_loss(criterion_ce, criterion_perceptual, prediction, target, img_prev_train, vgg, dev):
5 | """Calculating the loss"""
6 | loss = 0
7 |
8 | # Perceptual Similarity Module (PSM)
9 | out_train_softmax2d = F.softmax(prediction,dim=1)
10 | an_change = out_train_softmax2d[:,1,:,:].unsqueeze(1).expand_as(img_prev_train)
11 | an_unchange = out_train_softmax2d[:,0,:,:].unsqueeze(1).expand_as(img_prev_train)
12 | label_change = target.unsqueeze(1).expand_as(img_prev_train).type(torch.FloatTensor).to(dev)
13 | label_unchange = 1-label_change
14 | an_change = an_change * label_change
15 | an_unchange = an_unchange * label_unchange
16 |
17 | an_change_feature = vgg(an_change)
18 | gt_feature = vgg(label_change)
19 | an_unchange_feature = vgg(an_unchange)
20 | gt_feature_unchange = vgg(label_unchange)
21 |
22 | perceptual_loss_change = criterion_perceptual(an_change_feature[0], gt_feature[0])
23 | perceptual_loss_unchange = criterion_perceptual(an_unchange_feature[0], gt_feature_unchange[0])
24 | perceptual_loss = perceptual_loss_change + perceptual_loss_unchange
25 |
26 | loss = 0.0001*perceptual_loss + criterion_ce(prediction, target)
27 |
28 | return loss
29 |
30 |
--------------------------------------------------------------------------------
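
Note: a minimal sketch of how hybrid_loss is wired up (mirroring train.py; it assumes it is run from the MFPNet_code directory and that the pretrained VGG-19 weights used by models/vgg.py can be downloaded):

    import torch
    import torch.nn as nn
    from models.vgg import Vgg19
    from utils.hybridloss import hybrid_loss

    dev = torch.device('cpu')
    vgg = Vgg19().to(dev)                        # frozen VGG-19 features for the perceptual (PSM) term
    criterion_ce = nn.CrossEntropyLoss().to(dev)
    criterion_perceptual = nn.MSELoss().to(dev)

    prediction = torch.randn(2, 2, 256, 256)     # raw logits [B, classes, H, W]
    target = torch.randint(0, 2, (2, 256, 256))  # change mask [B, H, W], values in {0, 1}
    img_prev = torch.randn(2, 3, 256, 256)       # pre-change image, used only to expand shapes
    loss = hybrid_loss(criterion_ce, criterion_perceptual, prediction, target, img_prev, vgg, dev)
    print(loss.item())
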
/MFPNet_code/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 |
8 |
9 | class FocalLoss(nn.Module):
10 | def __init__(self, gamma=0, alpha=None, size_average=True):
11 | super(FocalLoss, self).__init__()
12 | self.gamma = gamma
13 | self.alpha = alpha
14 | if isinstance(alpha, (float, int)):
15 | self.alpha = torch.Tensor([alpha, 1-alpha])
16 | if isinstance(alpha, list):
17 | self.alpha = torch.Tensor(alpha)
18 | self.size_average = size_average
19 |
20 | def forward(self, input, target):
21 | if input.dim() > 2:
22 | # N,C,H,W => N,C,H*W
23 | input = input.view(input.size(0), input.size(1), -1)
24 |
25 | # N,C,H*W => N,H*W,C
26 | input = input.transpose(1, 2)
27 |
28 | # N,H*W,C => N*H*W,C
29 | input = input.contiguous().view(-1, input.size(2))
30 |
31 |
32 | target = target.view(-1, 1)
33 |         logpt = F.log_softmax(input, dim=1)
34 | logpt = logpt.gather(1, target)
35 | logpt = logpt.view(-1)
36 | pt = Variable(logpt.data.exp())
37 |
38 | if self.alpha is not None:
39 | if self.alpha.type() != input.data.type():
40 | self.alpha = self.alpha.type_as(input.data)
41 | at = self.alpha.gather(0, target.data.view(-1))
42 | logpt = logpt * Variable(at)
43 |
44 | loss = -1 * (1-pt)**self.gamma * logpt
45 |
46 | if self.size_average:
47 | return loss.mean()
48 | else:
49 | return loss.sum()
50 |
51 | def dice_loss(logits, true, eps=1e-7):
52 | """Computes the Sørensen–Dice loss.
53 |     Note that PyTorch optimizers minimize a loss. In this
54 |     case, we would like to maximize the Dice coefficient, so we
55 |     return one minus the mean Dice coefficient.
56 | Args:
57 | true: a tensor of shape [B, 1, H, W].
58 | logits: a tensor of shape [B, C, H, W]. Corresponds to
59 | the raw output or logits of the model.
60 | eps: added to the denominator for numerical stability.
61 | Returns:
62 | dice_loss: the Sørensen–Dice loss.
63 | """
64 | num_classes = logits.shape[1]
65 | if num_classes == 1:
66 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
67 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
68 | true_1_hot_f = true_1_hot[:, 0:1, :, :]
69 | true_1_hot_s = true_1_hot[:, 1:2, :, :]
70 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
71 | pos_prob = torch.sigmoid(logits)
72 | neg_prob = 1 - pos_prob
73 | probas = torch.cat([pos_prob, neg_prob], dim=1)
74 | else:
75 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
76 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
77 | probas = F.softmax(logits, dim=1)
78 | true_1_hot = true_1_hot.type(logits.type())
79 | dims = (0,) + tuple(range(2, true.ndimension()))
80 | intersection = torch.sum(probas * true_1_hot, dims)
81 | cardinality = torch.sum(probas + true_1_hot, dims)
82 | dice_loss = (2. * intersection / (cardinality + eps)).mean()
83 | return (1 - dice_loss)
84 |
85 |
86 | def jaccard_loss(logits, true, eps=1e-7):
87 | """Computes the Jaccard loss, a.k.a the IoU loss.
88 |     Note that PyTorch optimizers minimize a loss. In this
89 |     case, we would like to maximize the Jaccard index (IoU), so we
90 |     return one minus the mean Jaccard index.
91 | Args:
92 | true: a tensor of shape [B, H, W] or [B, 1, H, W].
93 | logits: a tensor of shape [B, C, H, W]. Corresponds to
94 | the raw output or logits of the model.
95 | eps: added to the denominator for numerical stability.
96 | Returns:
97 | jacc_loss: the Jaccard loss.
98 | """
99 | num_classes = logits.shape[1]
100 | if num_classes == 1:
101 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
102 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
103 | true_1_hot_f = true_1_hot[:, 0:1, :, :]
104 | true_1_hot_s = true_1_hot[:, 1:2, :, :]
105 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
106 | pos_prob = torch.sigmoid(logits)
107 | neg_prob = 1 - pos_prob
108 | probas = torch.cat([pos_prob, neg_prob], dim=1)
109 | else:
110 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
111 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
112 | probas = F.softmax(logits, dim=1)
113 | true_1_hot = true_1_hot.type(logits.type())
114 | dims = (0,) + tuple(range(2, true.ndimension()))
115 | intersection = torch.sum(probas * true_1_hot, dims)
116 | cardinality = torch.sum(probas + true_1_hot, dims)
117 | union = cardinality - intersection
118 | jacc_loss = (intersection / (union + eps)).mean()
119 | return (1 - jacc_loss)
120 |
121 |
122 | class TverskyLoss(nn.Module):
123 | def __init__(self, alpha=0.5, beta=0.5, eps=1e-7, size_average=True):
124 | super(TverskyLoss, self).__init__()
125 | self.alpha = alpha
126 | self.beta = beta
127 | self.size_average = size_average
128 | self.eps = eps
129 |
130 | def forward(self, logits, true):
131 | """Computes the Tversky loss [1].
132 | Args:
133 | true: a tensor of shape [B, H, W] or [B, 1, H, W].
134 | logits: a tensor of shape [B, C, H, W]. Corresponds to
135 | the raw output or logits of the model.
136 | alpha: controls the penalty for false positives.
137 | beta: controls the penalty for false negatives.
138 | eps: added to the denominator for numerical stability.
139 | Returns:
140 | tversky_loss: the Tversky loss.
141 | Notes:
142 | alpha = beta = 0.5 => dice coeff
143 | alpha = beta = 1 => tanimoto coeff
144 | alpha + beta = 1 => F beta coeff
145 | References:
146 | [1]: https://arxiv.org/abs/1706.05721
147 | """
148 | num_classes = logits.shape[1]
149 | if num_classes == 1:
150 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
151 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
152 | true_1_hot_f = true_1_hot[:, 0:1, :, :]
153 | true_1_hot_s = true_1_hot[:, 1:2, :, :]
154 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
155 | pos_prob = torch.sigmoid(logits)
156 | neg_prob = 1 - pos_prob
157 | probas = torch.cat([pos_prob, neg_prob], dim=1)
158 | else:
159 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
160 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
161 | probas = F.softmax(logits, dim=1)
162 |
163 | true_1_hot = true_1_hot.type(logits.type())
164 | dims = (0,) + tuple(range(2, true.ndimension()))
165 | intersection = torch.sum(probas * true_1_hot, dims)
166 | fps = torch.sum(probas * (1 - true_1_hot), dims)
167 | fns = torch.sum((1 - probas) * true_1_hot, dims)
168 | num = intersection
169 | denom = intersection + (self.alpha * fps) + (self.beta * fns)
170 | tversky_loss = (num / (denom + self.eps)).mean()
171 | return (1 - tversky_loss)
172 |
--------------------------------------------------------------------------------
/MFPNet_code/utils/parser.py:
--------------------------------------------------------------------------------
1 | import argparse as ag
2 | import json
3 |
4 | def get_parser_with_args(metadata_json_path=None):
5 | parser = ag.ArgumentParser(description='Training change detection network')
6 |
7 | with open(metadata_json_path, 'r') as fin:
8 | metadata = json.load(fin)
9 | parser.set_defaults(**metadata)
10 |         return parser, metadata
11 | 
--------------------------------------------------------------------------------
/MFPNet_code/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 |
5 | from PIL import Image, ImageOps, ImageFilter
6 | import torchvision.transforms as transforms
7 |
8 | class Normalize(object):
9 | """Normalize a tensor image with mean and standard deviation.
10 | Args:
11 | mean (tuple): means for each channel.
12 | std (tuple): standard deviations for each channel.
13 | """
14 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
15 | self.mean = mean
16 | self.std = std
17 |
18 | def __call__(self, sample):
19 | img = sample['image']
20 | mask = sample['label']
21 | img = np.array(img).astype(np.float32)
22 | mask = np.array(mask).astype(np.float32)
23 | img /= 255.0
24 | img -= self.mean
25 | img /= self.std
26 |
27 | return {'image': img,
28 | 'label': mask}
29 |
30 |
31 | class ToTensor(object):
32 | """Convert ndarrays in sample to Tensors."""
33 |
34 | def __call__(self, sample):
35 | # swap color axis because
36 | # numpy image: H x W x C
37 | # torch image: C X H X W
38 | img1 = sample['image'][0]
39 | img2 = sample['image'][1]
40 | mask = sample['label']
41 | img1 = np.array(img1).astype(np.float32).transpose((2, 0, 1))
42 | img2 = np.array(img2).astype(np.float32).transpose((2, 0, 1))
43 | if np.unique(mask).sum() == 1:
44 | mask = np.array(mask).astype(np.float32)
45 | else:
46 | mask = np.array(mask).astype(np.float32) / 255.0
47 |
48 | img1 = torch.from_numpy(img1).float()
49 | img2 = torch.from_numpy(img2).float()
50 | mask = torch.from_numpy(mask).float()
51 |
52 | return {'image': (img1, img2),
53 | 'label': mask}
54 |
55 |
56 | class RandomHorizontalFlip(object):
57 | def __call__(self, sample):
58 | img1 = sample['image'][0]
59 | img2 = sample['image'][1]
60 | mask = sample['label']
61 | if random.random() < 0.5:
62 | img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)
63 | img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
64 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
65 |
66 | return {'image': (img1, img2),
67 | 'label': mask}
68 |
69 | class RandomVerticalFlip(object):
70 | def __call__(self, sample):
71 | img1 = sample['image'][0]
72 | img2 = sample['image'][1]
73 | mask = sample['label']
74 | if random.random() < 0.5:
75 | img1 = img1.transpose(Image.FLIP_TOP_BOTTOM)
76 | img2 = img2.transpose(Image.FLIP_TOP_BOTTOM)
77 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
78 |
79 | return {'image': (img1, img2),
80 | 'label': mask}
81 |
82 | class RandomFixRotate(object):
83 | def __init__(self):
84 | self.degree = [Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270]
85 |
86 | def __call__(self, sample):
87 | img1 = sample['image'][0]
88 | img2 = sample['image'][1]
89 | mask = sample['label']
90 | if random.random() < 0.75:
91 | rotate_degree = random.choice(self.degree)
92 | img1 = img1.transpose(rotate_degree)
93 | img2 = img2.transpose(rotate_degree)
94 | mask = mask.transpose(rotate_degree)
95 |
96 | return {'image': (img1, img2),
97 | 'label': mask}
98 |
99 |
100 | class RandomRotate(object):
101 | def __init__(self, degree):
102 | self.degree = degree
103 |
104 | def __call__(self, sample):
105 | img1 = sample['image'][0]
106 | img2 = sample['image'][1]
107 | mask = sample['label']
108 | rotate_degree = random.uniform(-1*self.degree, self.degree)
109 | img1 = img1.rotate(rotate_degree, Image.BILINEAR)
110 | img2 = img2.rotate(rotate_degree, Image.BILINEAR)
111 | mask = mask.rotate(rotate_degree, Image.NEAREST)
112 |
113 | return {'image': (img1, img2),
114 | 'label': mask}
115 |
116 |
117 | class RandomGaussianBlur(object):
118 | def __call__(self, sample):
119 | img1 = sample['image'][0]
120 | img2 = sample['image'][1]
121 | mask = sample['label']
122 | if random.random() < 0.5:
123 | img1 = img1.filter(ImageFilter.GaussianBlur(
124 | radius=random.random()))
125 | img2 = img2.filter(ImageFilter.GaussianBlur(
126 | radius=random.random()))
127 |
128 | return {'image': (img1, img2),
129 | 'label': mask}
130 |
131 |
132 | class RandomScaleCrop(object):
133 | def __init__(self, base_size, crop_size, fill=0):
134 | self.base_size = base_size
135 | self.crop_size = crop_size
136 | self.fill = fill
137 |
138 | def __call__(self, sample):
139 | img = sample['image']
140 | mask = sample['label']
141 | # random scale (short edge)
142 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
143 | w, h = img.size
144 | if h > w:
145 | ow = short_size
146 | oh = int(1.0 * h * ow / w)
147 | else:
148 | oh = short_size
149 | ow = int(1.0 * w * oh / h)
150 | img = img.resize((ow, oh), Image.BILINEAR)
151 | mask = mask.resize((ow, oh), Image.NEAREST)
152 | # pad crop
153 | if short_size < self.crop_size:
154 | padh = self.crop_size - oh if oh < self.crop_size else 0
155 | padw = self.crop_size - ow if ow < self.crop_size else 0
156 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
157 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
158 | # random crop crop_size
159 | w, h = img.size
160 | x1 = random.randint(0, w - self.crop_size)
161 | y1 = random.randint(0, h - self.crop_size)
162 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
163 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
164 |
165 | return {'image': img,
166 | 'label': mask}
167 |
168 |
169 | class FixScaleCrop(object):
170 | def __init__(self, crop_size):
171 | self.crop_size = crop_size
172 |
173 | def __call__(self, sample):
174 | img = sample['image']
175 | mask = sample['label']
176 | w, h = img.size
177 | if w > h:
178 | oh = self.crop_size
179 | ow = int(1.0 * w * oh / h)
180 | else:
181 | ow = self.crop_size
182 | oh = int(1.0 * h * ow / w)
183 | img = img.resize((ow, oh), Image.BILINEAR)
184 | mask = mask.resize((ow, oh), Image.NEAREST)
185 | # center crop
186 | w, h = img.size
187 | x1 = int(round((w - self.crop_size) / 2.))
188 | y1 = int(round((h - self.crop_size) / 2.))
189 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
190 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
191 |
192 | return {'image': img,
193 | 'label': mask}
194 |
195 | class FixedResize(object):
196 | def __init__(self, size):
197 | self.size = (size, size) # size: (h, w)
198 |
199 | def __call__(self, sample):
200 | img1 = sample['image'][0]
201 | img2 = sample['image'][1]
202 | mask = sample['label']
203 |
204 | assert img1.size == mask.size and img2.size == mask.size
205 |
206 | img1 = img1.resize(self.size, Image.BILINEAR)
207 | img2 = img2.resize(self.size, Image.BILINEAR)
208 | mask = mask.resize(self.size, Image.NEAREST)
209 |
210 | return {'image': (img1, img2),
211 | 'label': mask}
212 |
213 |
214 | '''
215 | We don't use Normalize here, because it will bring negative effects.
216 | the mask of ground truth is converted to [0,1] in ToTensor() function.
217 | '''
218 | train_transforms = transforms.Compose([
219 | RandomHorizontalFlip(),
220 | RandomVerticalFlip(),
221 | RandomFixRotate(),
222 | # RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
223 | # RandomGaussianBlur(),
224 | # Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
225 | ToTensor()])
226 |
227 | test_transforms = transforms.Compose([
228 | # RandomHorizontalFlip(),
229 | # RandomVerticalFlip(),
230 | # RandomFixRotate(),
231 | # RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
232 | # RandomGaussianBlur(),
233 | # Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
234 | ToTensor()])
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Remote Sensing Change Detection Based on Multidirectional Adaptive Feature Fusion and Perceptual Similarity
2 | 
3 |
4 | PyTorch implementation for "[Remote Sensing Change Detection Based on Multidirectional Adaptive Feature Fusion and Perceptual Similarity](https://www.mdpi.com/2072-4292/13/15/3053)"
5 |
6 | - [03 August 2021] Release the code of MFPNet model.
7 | - [28 June 2022] Release the processed datasets, training/evaluation codes.
8 | - [08 JUly 2022] Release model weights for Season-varying/LEVIR-CD/Google datasets.
9 |
10 | ## Introduction
11 | Remote sensing change detection (RSCD) is an important yet challenging task in Earth observation. The booming development of convolutional neural networks (CNNs) in computer vision raises new possibilities for RSCD, and many recent RSCD methods have introduced CNNs to achieve promising improvements in performance. This paper proposes a novel multidirectional fusion and perception network for change detection in bi-temporal very-high-resolution remote sensing images. First, we propose an elaborate feature fusion module consisting of a multidirectional fusion pathway (MFP) and an adaptive weighted fusion (AWF) strategy for RSCD to boost the way that information propagates in the network. The MFP enhances the flexibility and diversity of information paths by creating extra top-down and shortcut-connection paths. The AWF strategy conducts weight recalibration for every fusion node to highlight salient feature maps and overcome semantic gaps between different features. Second, a novel perceptual similarity module is designed to introduce perceptual loss into the RSCD task, which adds the perceptual information, such as structure and semantic, for high-quality change maps generation. Extensive experiments on four challenging benchmark datasets demonstrate the superiority of the proposed network comparing with eight state-of-the-art methods in terms of F1, Kappa, and visual qualities.
12 |
13 | ## Content
14 | ### Architecture
15 |
16 |
17 | Fig.1 Overall architecture of the proposed multidirectional fusion and perception network (MFPNet).
18 | Note that the process with the dashed line only participates in model training.
19 |
20 | ### Datasets
21 | The processed and original datasets can be downloaded from the table below, we recommended downloading the processed one directly to get a quick start on our codes:
22 |
23 |
24 |
25 | Datasets |
26 | Processed Links |
27 | Original Links |
28 |
29 |
30 | Season-varying Dataset [1] |
31 | [Google Drive]
32 | [Baidu Drive]
33 | | [Original] |
34 |
35 | LEVIR-CD Dataset [2] |
36 | [Original] |
37 |
38 |
39 | Google Dataset [3] |
40 | [Original] |
41 |
42 |
43 | Zhange Dataset [4] |
44 | [Original] |
45 |
46 |
47 |
48 | ### Setup & Usage for the Code
49 |
50 | 1. Check the structure of data folders:
51 | ```
52 | (root folder)
53 | ├── dataset1
54 | | ├── train
55 | | | ├── A
56 | | | ├── B
57 | | | ├── label
58 | | ├── val
59 | | | ├── A
60 | | | ├── B
61 | | | ├── label
62 | | ├── test
63 | | | ├── A
64 | | | ├── B
65 | | | ├── label
66 | ├── ...
67 | ```
68 |
69 | 2. Check dependencies:
70 | ```
71 | - Python 3.6+
72 | - PyTorch 1.7.0+
73 | - scikit-learn
74 | - cudatoolkit
75 | - cudnn
76 | - OpenCV-Python
77 | ```
78 |
79 | 3. Change paths:
80 | ```
81 | - Change the 'metadata_json_path' in 'train.py' to your 'metadata.json' path.
82 | - Change the 'dataset_dir' and 'weight_dir' in 'metadata.json' to your own path.
83 | ```
84 |
85 | 4. Train the MFPNet:
86 | ```
87 | python train.py
88 | ```
89 |
90 | 5. Evaluate the MFPNet:
91 | ```
92 | - Download model weights (optional).
93 | - Change the 'weight_path' in 'eval.py' to your model weight path.
94 | - python eval.py
95 | ```
96 |
97 | ## Model Weights
98 | Model weights for Season-varying/LEVIR-CD/Google datasets are available via [Google Drive](https://drive.google.com/drive/folders/1-2njQ7Z3IIrjv6YGXoMD2CBZbc1nQuRu?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/141aQDQ_lMEi83O2t6AcLqg?pwd=1234). Note that the training/dataloader codes are rewritten and improved so the performance is a little different from the paper.
99 |
100 |
101 |
102 | Datasets |
103 | F1 (%) |
104 | Kappa (%) |
105 |
106 |
107 | Season-varying |
108 | 97.964 |
109 | 97.691 |
110 |
111 | LEVIR-CD |
112 | 91.568 |
113 | 91.120 |
114 |
115 |
116 | Google |
117 | 88.058 |
118 | 84.140 |
119 |
120 |
121 |
122 |
123 | ## Reference
124 | Appreciate the work from the following repositories:
125 | * [likyoo/Siam-NestedUNet](https://github.com/likyoo/Siam-NestedUNet)
126 | * [zylo117/Yet-Another-EfficientDet-Pytorch](https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch)
127 |
128 | ## Cite
129 | If this repository is useful for your research, please cite:
130 | ```
131 | @Article{rs13153053,
132 | AUTHOR = {Xu, Jialang and Luo, Chunbo and Chen, Xinyue and Wei, Shicai and Luo, Yang},
133 | TITLE = {Remote Sensing Change Detection Based on Multidirectional Adaptive Feature Fusion and Perceptual Similarity},
134 | JOURNAL = {Remote Sensing},
135 | VOLUME = {13},
136 | YEAR = {2021},
137 | NUMBER = {15},
138 | ARTICLE-NUMBER = {3053},
139 | URL = {https://www.mdpi.com/2072-4292/13/15/3053},
140 | ISSN = {2072-4292},
141 | DOI = {10.3390/rs13153053}
142 | }
143 | ```
144 |
--------------------------------------------------------------------------------
/figure/AWF.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzjialang/MFPNet/df2e15c29cb13f01a7150025f622ad69a0ed3d1b/figure/AWF.png
--------------------------------------------------------------------------------
/figure/MFP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzjialang/MFPNet/df2e15c29cb13f01a7150025f622ad69a0ed3d1b/figure/MFP.png
--------------------------------------------------------------------------------
/figure/MFPNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzjialang/MFPNet/df2e15c29cb13f01a7150025f622ad69a0ed3d1b/figure/MFPNet.png
--------------------------------------------------------------------------------
/figure/PSM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzjialang/MFPNet/df2e15c29cb13f01a7150025f622ad69a0ed3d1b/figure/PSM.png
--------------------------------------------------------------------------------