├── LICENSE ├── MFPNet_code ├── eval.py ├── metadata.json ├── metadata_descripation.py ├── models │ ├── MFPNet_model.py │ ├── seresnet50.py │ └── vgg.py ├── train.py └── utils │ ├── dataloaders.py │ ├── helpers.py │ ├── hybridloss.py │ ├── metrics.py │ ├── parser.py │ └── transforms.py ├── README.md └── figure ├── AWF.png ├── MFP.png ├── MFPNet.png └── PSM.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jialang Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MFPNet_code/eval.py: -------------------------------------------------------------------------------- 1 | from shutil import copyfile 2 | import torch.utils.data 3 | from utils.parser import get_parser_with_args 4 | from utils.helpers import get_test_loaders 5 | from tqdm import tqdm 6 | from sklearn.metrics import confusion_matrix 7 | import numpy as np 8 | import torch.nn.functional as F 9 | import cv2 10 | import os 11 | from utils.helpers import load_model 12 | 13 | parser, metadata = get_parser_with_args(metadata_json_path='/home/aaa/xujialang/master_thesis/MFPNet/metadata.json') 14 | opt = parser.parse_args() 15 | dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 16 | 17 | test_loader = get_test_loaders(opt) 18 | 19 | weight_path = os.path.join(opt.weight_dir, 'model_weight.pt') # the path of the model weight 20 | model = load_model(opt, dev) 21 | model.load_state_dict(torch.load(weight_path)) 22 | """ 23 | Begin Test 24 | """ 25 | model.eval() 26 | with torch.no_grad(): 27 | c_matrix = {'tn': 0, 'fp': 0, 'fn': 0, 'tp': 0} 28 | test_metrics = { 29 | 'cd_precisions': [], 30 | 'cd_recalls': [], 31 | 'cd_f1scores': [], 32 | } 33 | 34 | for batch_img1, batch_img2, labels in test_loader: 35 | batch_img1 = batch_img1.float().to(dev) 36 | batch_img2 = batch_img2.float().to(dev) 37 | labels = labels.long().to(dev) 38 | cd_preds = model(batch_img1, batch_img2) 39 | cd_preds = torch.argmax(cd_preds, dim = 1) 40 | 41 | tp= (labels.cpu().numpy() * cd_preds.cpu().numpy()).sum() 42 | tn= ((1-labels.cpu().numpy()) * (1-cd_preds.cpu().numpy())).sum() 43 | fn= (labels.cpu().numpy() * (1-cd_preds.cpu().numpy())).sum() 44 | fp= ((1-labels.cpu().numpy()) * cd_preds.cpu().numpy()).sum() 45 | c_matrix['tn'] += tn 46 | c_matrix['fp'] += fp 47 | c_matrix['fn'] += fn 48 | c_matrix['tp'] += tp 49 | 50 | tn, fp, fn, tp = 
c_matrix['tn'], c_matrix['fp'], c_matrix['fn'], c_matrix['tp'] 51 | P = tp / (tp + fp) 52 | R = tp / (tp + fn) 53 | F1 = 2 * P * R / (R + P) 54 | IOU = tp/ (fn+tp+fp) 55 | 56 | ttt_test=tn+fp+fn+tp 57 | TA_test = (tp+tn) / ttt_test 58 | Pcp1_test = (tp + fn) / ttt_test 59 | Pcp2_test = (tp + fp) / ttt_test 60 | Pcn1_test = (fp + tn) / ttt_test 61 | Pcn2_test = (fn + tn) / ttt_test 62 | Pc_test = Pcp1_test*Pcp2_test + Pcn1_test*Pcn2_test 63 | kappa_test = (TA_test - Pc_test) / (1 - Pc_test) 64 | 65 | test_metrics['cd_f1scores'] = F1 66 | test_metrics['cd_precisions'] = P 67 | test_metrics['cd_recalls'] = R 68 | print("TEST METRICS. KAPPA: {}. IOU: {} ".format(kappa_test, IOU) + str(test_metrics)) -------------------------------------------------------------------------------- /MFPNet_code/metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "patch_size": 256, 3 | "augmentation": true, 4 | "num_gpus": 1, 5 | "num_workers": 4, 6 | "num_channel": 3, 7 | "epochs": 200, 8 | "batch_size": 4, 9 | "learning_rate": 1e-4, 10 | "loss_function": "hybrid", 11 | "dataset_dir": "/home/bigspace/xujialang/cd_dataset/Google/", 12 | "weight_dir": "/home/bigspace/xujialang/MFPNet_result/Google/", 13 | "resume": "None" 14 | } -------------------------------------------------------------------------------- /MFPNet_code/metadata_descripation.py: -------------------------------------------------------------------------------- 1 | 2 | # For Seasonvarying/LEVIR-CD/Google Dataset 3 | { 4 | "patch_size": 256, 5 | "augmentation": true, 6 | "num_gpus": 1, 7 | "num_workers": 4, 8 | "num_channel": 3, 9 | "epochs": 200, 10 | "batch_size": 4, 11 | "learning_rate": 1e-4, 12 | "loss_function": "hybrid", # ['hybird', 'bce', 'dice', 'jaccard'], 'hybrid' means Softmax PPCE + Perceputal Loss 13 | "dataset_dir": "/home/bigspace/xujialang/cd_dataset/Seasonvarying/", # change to your own path 14 | "weight_dir": "/home/bigspace/xujialang/MFPNet_result/Seasonvarying/", # change to your own path 15 | "resume": "None" # Change if you want to continue your training process 16 | } 17 | 18 | # For Zhang dataset 19 | { 20 | "patch_size": 512, 21 | "augmentation": true, 22 | "num_gpus": 1, 23 | "num_workers": 4, 24 | "num_channel": 3, 25 | "epochs": 200, 26 | "batch_size": 2, 27 | "learning_rate": 1e-4, 28 | "loss_function": "hybrid", 29 | "dataset_dir": "/home/bigspace/xujialang/cd_dataset/Zhang/" 30 | "weight_dir": "/home/bigspace/xujialang/MFPNet_result/Zhang/", 31 | "resume": "None" 32 | } 33 | -------------------------------------------------------------------------------- /MFPNet_code/models/MFPNet_model.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import os 7 | 8 | from .seresnet50 import se_resnet50 9 | 10 | class BasicConvBlock(nn.Module): 11 | def __init__(self, in_channels, out_channels=None): 12 | super(BasicConvBlock, self).__init__() 13 | 14 | if out_channels is None: 15 | out_channels = in_channels 16 | 17 | self.conv = nn.Sequential( 18 | nn.Conv2d(in_channels, in_channels, 3, padding=1), 19 | nn.BatchNorm2d(in_channels), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d( in_channels, out_channels, 1, bias=False), 22 | nn.BatchNorm2d(out_channels), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d( out_channels, out_channels, 3, padding=1), 25 | nn.BatchNorm2d(out_channels), 26 | nn.ReLU(inplace=True), 27 | ) 28 | 29 | def forward(self,x): 30 | 
x=self.conv(x) 31 | return x 32 | 33 | class Conv2dStaticSamePadding(nn.Module): 34 | """ 35 | created by Zylo117 36 | The real keras/tensorflow conv2d with same padding 37 | """ 38 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs): 39 | super().__init__() 40 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, 41 | bias=bias, groups=groups) 42 | self.stride = self.conv.stride 43 | self.kernel_size = self.conv.kernel_size 44 | self.dilation = self.conv.dilation 45 | 46 | if isinstance(self.stride, int): 47 | self.stride = [self.stride] * 2 48 | elif len(self.stride) == 1: 49 | self.stride = [self.stride[0]] * 2 50 | 51 | if isinstance(self.kernel_size, int): 52 | self.kernel_size = [self.kernel_size] * 2 53 | elif len(self.kernel_size) == 1: 54 | self.kernel_size = [self.kernel_size[0]] * 2 55 | 56 | def forward(self, x): 57 | h, w = x.shape[-2:] 58 | 59 | extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1] 60 | extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0] 61 | 62 | left = extra_h // 2 63 | right = extra_h - left 64 | top = extra_v // 2 65 | bottom = extra_v - top 66 | 67 | x = F.pad(x, [left, right, top, bottom]) 68 | 69 | x = self.conv(x) 70 | return x 71 | 72 | class MaxPool2dStaticSamePadding(nn.Module): 73 | """ 74 | created by Zylo117 75 | The real keras/tensorflow MaxPool2d with same padding 76 | """ 77 | def __init__(self, *args, **kwargs): 78 | super().__init__() 79 | self.pool = nn.MaxPool2d(*args, **kwargs) 80 | self.stride = self.pool.stride 81 | self.kernel_size = self.pool.kernel_size 82 | 83 | if isinstance(self.stride, int): 84 | self.stride = [self.stride] * 2 85 | elif len(self.stride) == 1: 86 | self.stride = [self.stride[0]] * 2 87 | 88 | if isinstance(self.kernel_size, int): 89 | self.kernel_size = [self.kernel_size] * 2 90 | elif len(self.kernel_size) == 1: 91 | self.kernel_size = [self.kernel_size[0]] * 2 92 | 93 | def forward(self, x): 94 | h, w = x.shape[-2:] 95 | 96 | extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1] 97 | extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0] 98 | 99 | left = extra_h // 2 100 | right = extra_h - left 101 | top = extra_v // 2 102 | bottom = extra_v - top 103 | 104 | x = F.pad(x, [left, right, top, bottom]) 105 | 106 | x = self.pool(x) 107 | return x 108 | 109 | # Channel Attention Algorithm (CAA) 110 | class CAA(nn.Module): 111 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg','max']): 112 | super(CAA, self).__init__() 113 | self.num=1 114 | self.gate_channels = gate_channels 115 | 116 | self.conv_fc1 = nn.Sequential( 117 | nn.Conv2d(in_channels=gate_channels, out_channels=gate_channels//reduction_ratio, kernel_size=1, bias=False), 118 | nn.ReLU(inplace=True), 119 | nn.Conv2d(in_channels=gate_channels//reduction_ratio, out_channels=gate_channels, kernel_size=1, bias=False), 120 | ) 121 | self.conv_fc2 = nn.Sequential( 122 | nn.Conv2d(in_channels=gate_channels, out_channels=gate_channels//reduction_ratio, kernel_size=1, bias=False), 123 | nn.ReLU(inplace=True), 124 | nn.Conv2d(in_channels=gate_channels//reduction_ratio, out_channels=gate_channels, kernel_size=1, bias=False), 125 | ) 126 | self.conv= nn.Sequential( 127 | nn.Conv2d(gate_channels,gate_channels,kernel_size=(2,1),bias=False), 128 | nn.Sigmoid() 129 | ) 130 | 131 | self.pool_types = pool_types 132 | 133 | def 
forward(self, x): 134 | channel_att_sum = None 135 | b,c,h,w=x.size() 136 | 137 | for pool_type in self.pool_types: 138 | if pool_type=='avg': 139 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 140 | channel_att_raw = self.conv_fc1(avg_pool).view(b,c,self.num,self.num) 141 | elif pool_type=='max': 142 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 143 | channel_att_raw = self.conv_fc2(max_pool).view(b,c,self.num,self.num) 144 | 145 | if channel_att_sum is None: 146 | channel_att_sum = channel_att_raw 147 | else: 148 | channel_att_sum = torch.cat([channel_att_sum, channel_att_raw],dim=2) 149 | 150 | channel_weight=self.conv(channel_att_sum) 151 | scale = nn.functional.upsample_bilinear(channel_weight, [h, w]) 152 | 153 | return x * scale 154 | 155 | # Multidirectional Adaptive Feature Fusion Module (MAFFM) 156 | class MAFFM(nn.Module): 157 | def __init__(self, num_channels, conv_channels): 158 | super(MAFFM, self).__init__() 159 | 160 | # Conv layers 161 | self.conv5 = BasicConvBlock(num_channels) 162 | self.conv4 = BasicConvBlock(num_channels) 163 | self.conv3 = BasicConvBlock(num_channels) 164 | self.conv2 = BasicConvBlock(num_channels) 165 | self.conv1 = BasicConvBlock(num_channels) 166 | 167 | self.conv5_1 = BasicConvBlock(num_channels) 168 | self.conv4_1 = BasicConvBlock(num_channels) 169 | self.conv3_1 = BasicConvBlock(num_channels) 170 | self.conv2_1 = BasicConvBlock(num_channels) 171 | self.conv1_1 = BasicConvBlock(num_channels) 172 | 173 | self.conv1_down = BasicConvBlock(num_channels) 174 | self.conv2_down = BasicConvBlock(num_channels) 175 | self.conv3_down = BasicConvBlock(num_channels) 176 | self.conv4_down = BasicConvBlock(num_channels) 177 | self.conv5_down = BasicConvBlock(num_channels) 178 | 179 | # Feature scaling layers 180 | self.p4_upsample_1 = nn.Upsample(scale_factor=2, mode='nearest') 181 | self.p3_upsample_1 = nn.Upsample(scale_factor=2, mode='nearest') 182 | self.p2_upsample_1 = nn.Upsample(scale_factor=2, mode='nearest') 183 | self.p1_upsample_1 = nn.Upsample(scale_factor=2, mode='nearest') 184 | 185 | self.p2_downsample = MaxPool2dStaticSamePadding(3, 2) 186 | self.p3_downsample = MaxPool2dStaticSamePadding(3, 2) 187 | self.p4_downsample = MaxPool2dStaticSamePadding(3, 2) 188 | self.p5_downsample = MaxPool2dStaticSamePadding(3, 2) 189 | 190 | self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest') 191 | self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest') 192 | self.p2_upsample = nn.Upsample(scale_factor=2, mode='nearest') 193 | self.p1_upsample = nn.Upsample(scale_factor=2, mode='nearest') 194 | 195 | # Channel compression layers 196 | self.p5_down_channel = nn.Sequential( 197 | Conv2dStaticSamePadding(conv_channels[4], num_channels, 1), 198 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 199 | nn.ReLU(inplace=True), 200 | ) 201 | self.p4_down_channel = nn.Sequential( 202 | Conv2dStaticSamePadding(conv_channels[3], num_channels, 1), 203 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 204 | nn.ReLU(inplace=True), 205 | ) 206 | self.p3_down_channel = nn.Sequential( 207 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), 208 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 209 | nn.ReLU(inplace=True), 210 | ) 211 | self.p2_down_channel = nn.Sequential( 212 | Conv2dStaticSamePadding(conv_channels[1], num_channels, 1), 213 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 214 | nn.ReLU(inplace=True), 215 | ) 216 | 
self.p1_down_channel = nn.Sequential( 217 | Conv2dStaticSamePadding(conv_channels[0], num_channels, 1), 218 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 219 | nn.ReLU(inplace=True), 220 | ) 221 | 222 | # CAA 223 | self.csac_p1_0=CAA(num_channels,reduction_ratio=1) 224 | self.csac_p2_0=CAA(num_channels,reduction_ratio=1) 225 | self.csac_p3_0=CAA(num_channels,reduction_ratio=1) 226 | self.csac_p4_0=CAA(num_channels,reduction_ratio=1) 227 | self.csac_p5_0=CAA(num_channels,reduction_ratio=1) 228 | 229 | self.csac_p1_1=CAA(num_channels,reduction_ratio=1) 230 | self.csac_p2_1=CAA(num_channels,reduction_ratio=1) 231 | self.csac_p3_1=CAA(num_channels,reduction_ratio=1) 232 | self.csac_p4_1=CAA(num_channels,reduction_ratio=1) 233 | self.csac_p5_1=CAA(num_channels,reduction_ratio=1) 234 | 235 | self.csac_p1_2=CAA(num_channels,reduction_ratio=1) 236 | self.csac_p2_2=CAA(num_channels,reduction_ratio=1) 237 | self.csac_p3_2=CAA(num_channels,reduction_ratio=1) 238 | self.csac_p4_2=CAA(num_channels,reduction_ratio=1) 239 | 240 | 241 | self.csac_p51_0=CAA(num_channels,reduction_ratio=1) 242 | self.csac_p41_0=CAA(num_channels,reduction_ratio=1) 243 | self.csac_p31_0=CAA(num_channels,reduction_ratio=1) 244 | self.csac_p21_0=CAA(num_channels,reduction_ratio=1) 245 | 246 | self.csac_p51_1=CAA(num_channels,reduction_ratio=1) 247 | self.csac_p41_1=CAA(num_channels,reduction_ratio=1) 248 | self.csac_p31_1=CAA(num_channels,reduction_ratio=1) 249 | self.csac_p21_1=CAA(num_channels,reduction_ratio=1) 250 | 251 | 252 | self.csac_p52_0=CAA(num_channels,reduction_ratio=1) 253 | self.csac_p42_0=CAA(num_channels,reduction_ratio=1) 254 | self.csac_p32_0=CAA(num_channels,reduction_ratio=1) 255 | self.csac_p22_0=CAA(num_channels,reduction_ratio=1) 256 | self.csac_p12_0=CAA(num_channels,reduction_ratio=1) 257 | 258 | self.csac_p52_1=CAA(num_channels,reduction_ratio=1) 259 | self.csac_p42_1=CAA(num_channels,reduction_ratio=1) 260 | self.csac_p32_1=CAA(num_channels,reduction_ratio=1) 261 | self.csac_p22_1=CAA(num_channels,reduction_ratio=1) 262 | self.csac_p12_1=CAA(num_channels,reduction_ratio=1) 263 | 264 | self.csac_p42_2=CAA(num_channels,reduction_ratio=1) 265 | self.csac_p32_2=CAA(num_channels,reduction_ratio=1) 266 | self.csac_p22_2=CAA(num_channels,reduction_ratio=1) 267 | self.csac_p12_2=CAA(num_channels,reduction_ratio=1) 268 | 269 | def forward(self, inputs): 270 | p1_pre, p2_pre, p3_pre, p4_pre, p5_pre, p1_now, p2_now, p3_now, p4_now, p5_now = inputs 271 | 272 | p1_in_pre = self.p1_down_channel(p1_pre) 273 | p1_in_now = self.p1_down_channel(p1_now) 274 | 275 | p2_in_pre = self.p2_down_channel(p2_pre) 276 | p2_in_now = self.p2_down_channel(p2_now) 277 | 278 | p3_in_pre = self.p3_down_channel(p3_pre) 279 | p3_in_now = self.p3_down_channel(p3_now) 280 | 281 | p4_in_pre = self.p4_down_channel(p4_pre) 282 | p4_in_now = self.p4_down_channel(p4_now) 283 | 284 | p5_in_pre = self.p5_down_channel(p5_pre) 285 | p5_in_now = self.p5_down_channel(p5_now) 286 | 287 | # Multidirectional Fusion Pathway (MFP) + Adaptive Weighted Fusion (AWF) 288 | # Up 289 | p5_in=self.conv5(self.csac_p5_0(p5_in_now)+self.csac_p5_1(p5_in_pre)) 290 | p4_in=self.conv4(self.csac_p4_0(p4_in_now)+self.csac_p4_1(p4_in_pre)+self.csac_p4_2(self.p4_upsample(p5_in))) 291 | p3_in=self.conv3(self.csac_p3_0(p3_in_now)+self.csac_p3_1(p3_in_pre)+self.csac_p3_2(self.p3_upsample(p4_in))) 292 | p2_in=self.conv2(self.csac_p2_0(p2_in_now)+self.csac_p2_1(p2_in_pre)+self.csac_p2_2(self.p2_upsample(p3_in))) 293 | 
p1_in=self.conv1(self.csac_p1_0(p1_in_now)+self.csac_p1_1(p1_in_pre)+self.csac_p1_2(self.p1_upsample(p2_in))) 294 | # Down 295 | p1_1 = self.conv1_down(p1_in) 296 | p2_1 = self.conv2_down(self.csac_p21_0(p2_in) + self.csac_p21_1(self.p2_downsample(p1_1))) 297 | p3_1 = self.conv3_down(self.csac_p31_0(p3_in) + self.csac_p31_1(self.p3_downsample(p2_1))) 298 | p4_1 = self.conv4_down(self.csac_p41_0(p4_in) + self.csac_p41_1(self.p4_downsample(p3_1))) 299 | p5_1 = self.conv5_down(self.csac_p51_0(p5_in) + self.csac_p51_1(self.p5_downsample(p4_1))) 300 | # Up 301 | p5_2 = self.conv5_1(self.csac_p52_0(p5_in) + self.csac_p52_1(p5_1)) 302 | p4_2 = self.conv4_1(self.csac_p42_0(p4_in) + self.csac_p42_1(p4_1)+self.csac_p42_2(self.p4_upsample_1(p5_2))) 303 | p3_2 = self.conv3_1(self.csac_p32_0(p3_in) + self.csac_p32_1(p3_1)+self.csac_p32_2(self.p3_upsample_1(p4_2))) 304 | p2_2 = self.conv2_1(self.csac_p22_0(p2_in) + self.csac_p22_1(p2_1)+self.csac_p22_2(self.p2_upsample_1(p3_2))) 305 | p1_2 = self.conv1_1(self.csac_p12_0(p1_in) + self.csac_p12_1(p1_1)+self.csac_p12_2(self.p1_upsample_1(p2_2))) 306 | 307 | return p1_2 308 | 309 | class DECODER(nn.Module): 310 | def __init__(self, in_ch, classes): 311 | super(DECODER, self).__init__() 312 | self.conv1 = nn.Conv2d( 313 | in_ch, in_ch//4, kernel_size=3, padding=1) 314 | self.conv2 = nn.Conv2d( 315 | in_ch//4, in_ch//8, kernel_size=3, padding=1) 316 | self.conv3 = nn.Conv2d( 317 | in_ch//8, classes*4, kernel_size=1) 318 | 319 | self.ps3 = nn.PixelShuffle(2) 320 | 321 | def forward(self, x): 322 | x = self.conv1(x) 323 | x = self.conv2(x) 324 | x = self.conv3(x) 325 | 326 | x = self.ps3(x) 327 | 328 | return x 329 | 330 | class MFPNET(nn.Module): 331 | def __init__(self, classes): 332 | super(MFPNET, self).__init__() 333 | 334 | self.se_resnet50 = se_resnet50(pretrained=True, strides = (1,2,2,2)) 335 | self.stage1 = nn.Sequential(self.se_resnet50.conv1, self.se_resnet50.bn1, self.se_resnet50.relu) 336 | self.stage2 = nn.Sequential(self.se_resnet50.maxpool, self.se_resnet50.layer1) 337 | self.stage3 = nn.Sequential(self.se_resnet50.layer2) 338 | self.stage4 = nn.Sequential(self.se_resnet50.layer3) 339 | self.stage5 = nn.Sequential(self.se_resnet50.layer4) 340 | 341 | self.maffm=MAFFM(256,[64,256,512,1024,2048]) 342 | self.dec = DECODER(256, classes) 343 | 344 | def encoder(self, x): 345 | x1 = self.stage1(x) 346 | x2 = self.stage2(x1) 347 | x3 = self.stage3(x2) 348 | x4 = self.stage4(x3) 349 | x5 = self.stage5(x4) 350 | 351 | return x1, x2, x3, x4, x5 352 | 353 | def forward(self, x_prev, x_now): 354 | p1_t1, p2_t1, p3_t1, p4_t1, p5_t1 = self.encoder(x_prev) 355 | p1_t2, p2_t2, p3_t2, p4_t2, p5_t2 = self.encoder(x_now) 356 | features_t1_t2 = (p1_t1, p2_t1, p3_t1, p4_t1, p5_t1, p1_t2, p2_t2, p3_t2, p4_t2, p5_t2) 357 | 358 | x_fuse=self.maffm(features_t1_t2) 359 | dis_map=self.dec(x_fuse) 360 | 361 | return dis_map 362 | 363 | if __name__ == "__main__": 364 | model = MFPNET(classes = 2) 365 | 366 | # # Example for using Perceptual Similarity Module 367 | # from vgg import Vgg19 368 | 369 | # criterion_perceptual = nn.MSELoss() 370 | # criterion_perceptual.cuda() 371 | # vgg= Vgg19().cuda() 372 | 373 | # for epoch in range(300): 374 | # for i, (data_prev, data_now, label) in enumerate(loader_train, 0): 375 | # model.train() 376 | # model.zero_grad() 377 | # optimizer.zero_grad() 378 | # img_prev_train, img_now_train, label_train = data_prev.cuda(), data_now.cuda(), label.cuda() 379 | 380 | # out_train1, _ = model(img_prev_train, img_now_train) 381 | 382 | # # 
Perceptual Similarity Module (PSM) 383 | # out_train_softmax2d = F.softmax(out_train1,dim=1) 384 | # an_change = out_train_softmax2d[:,1,:,:].unsqueeze(1).expand_as(img_prev_train) 385 | # an_unchange = out_train_softmax2d[:,0,:,:].unsqueeze(1).expand_as(img_prev_train) 386 | # label_change = label_train.expand_as(img_prev_train).type(torch.FloatTensor).cuda() 387 | # label_unchange = 1-label_change 388 | # an_change = an_change*label_change 389 | # an_unchange = an_unchange*(1-label_change) 390 | 391 | # an_change_feature = vgg(an_change) 392 | # gt_feature = vgg(label_change) 393 | # an_unchange_feature = vgg(an_unchange) 394 | # gt_feature_unchange = vgg(label_unchange) 395 | 396 | # perceptual_loss_change = criterion_perceptual(an_change_feature[0], gt_feature[0]) 397 | # perceptual_loss_unchange = criterion_perceptual(an_unchange_feature[0], gt_feature_unchange[0]) 398 | # perceptual_loss = perceptual_loss_change + perceptual_loss_unchange -------------------------------------------------------------------------------- /MFPNet_code/models/seresnet50.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | model_urls = { 6 | 'seresnet50': 'https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl' 7 | } 8 | 9 | class SELayer(nn.Module): 10 | def __init__(self, channel, reduction=16): 11 | super(SELayer, self).__init__() 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.fc = nn.Sequential( 14 | nn.Linear(channel, channel // reduction, bias=False), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(channel // reduction, channel, bias=False), 17 | nn.Sigmoid() 18 | ) 19 | 20 | def forward(self, x): 21 | b, c, _, _ = x.size() 22 | y = self.avg_pool(x).view(b, c) 23 | y = self.fc(y).view(b, c, 1, 1) 24 | return x * y.expand_as(x) 25 | 26 | class FixedBatchNorm(nn.BatchNorm2d): 27 | def forward(self, input): 28 | return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, 29 | training=False, eps=self.eps) 30 | 31 | 32 | class Bottleneck(nn.Module): 33 | expansion = 4 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, reduction=16): 36 | super(Bottleneck, self).__init__() 37 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 38 | self.bn1 = FixedBatchNorm(planes) 39 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 40 | padding=dilation, bias=False, dilation=dilation) 41 | self.bn2 = FixedBatchNorm(planes) 42 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 43 | self.bn3 = FixedBatchNorm(planes * 4) 44 | self.relu = nn.ReLU(inplace=True) 45 | # Squeeze-and-Excitation 46 | self.se = SELayer(planes * 4, reduction) 47 | # Downsample 48 | self.downsample = downsample 49 | self.stride = stride 50 | self.dilation = dilation 51 | 52 | def forward(self, x): 53 | residual = x 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv3(out) 64 | out = self.bn3(out) 65 | 66 | out = self.se(out) 67 | 68 | if self.downsample is not None: 69 | residual = self.downsample(x) 70 | 71 | out += residual 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | class Bottleneck_mdcn(nn.Module): 77 | expansion = 4 78 | 79 | def __init__(self, inplanes, planes, stride=1, 
downsample=None, dilation=1, reduction=16): 80 | super(Bottleneck_mdcn, self).__init__() 81 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 82 | self.bn1 = FixedBatchNorm(planes) 83 | self.conv2= ModulatedDeformConvPack(planes, planes, kernel_size=(3, 3), stride=stride, 84 | padding=dilation, dilation=dilation,bias=False,deformable_groups=2) 85 | self.bn2 = FixedBatchNorm(planes) 86 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 87 | self.bn3 = FixedBatchNorm(planes * 4) 88 | self.relu = nn.ReLU(inplace=True) 89 | # Squeeze-and-Excitation 90 | self.se = SELayer(planes * 4, reduction) 91 | # Downsample 92 | self.downsample = downsample 93 | self.stride = stride 94 | self.dilation = dilation 95 | 96 | def forward(self, x): 97 | residual = x 98 | 99 | out = self.conv1(x) 100 | out = self.bn1(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv2(out) 104 | out = self.bn2(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv3(out) 108 | out = self.bn3(out) 109 | 110 | out = self.se(out) 111 | 112 | if self.downsample is not None: 113 | residual = self.downsample(x) 114 | 115 | out += residual 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | 121 | class SEResNet(nn.Module): 122 | 123 | def __init__(self, block, layers, strides=(2, 2, 2, 2), dilations=(1, 1, 2, 4),zero_init_residual=True): 124 | super(SEResNet, self).__init__() 125 | self.inplanes = 64 126 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 127 | bias=False) 128 | self.bn1 = FixedBatchNorm(64) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1, dilation=dilations[0]) 132 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1]) 133 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2]) 134 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3]) 135 | self.inplanes = 1024 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=stride, bias=False), 143 | FixedBatchNorm(planes * block.expansion), 144 | ) 145 | 146 | layers = [block(self.inplanes, planes, stride, downsample, dilation=1)] 147 | self.inplanes = planes * block.expansion 148 | for i in range(1, blocks): 149 | layers.append(block(self.inplanes, planes, dilation=dilation)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | 154 | 155 | 156 | 157 | def forward(self, x): 158 | x = self.conv1(x) 159 | x = self.bn1(x) 160 | x = self.relu(x) 161 | x = self.maxpool(x) 162 | 163 | x = self.layer1(x) 164 | x = self.layer2(x) 165 | x = self.layer3(x) 166 | x = self.layer4(x) 167 | 168 | x = self.avgpool(x) 169 | x = x.view(x.size(0), -1) 170 | x = self.fc(x) 171 | 172 | return x 173 | 174 | 175 | def se_resnet50(pretrained=True, **kwargs): 176 | 177 | model = SEResNet(Bottleneck,layers=[3, 4, 6, 3], **kwargs) 178 | if pretrained: 179 | state_dict = model_zoo.load_url(model_urls['seresnet50']) 180 | model_dict = model.state_dict() 181 | 182 | state_dict.pop('fc.weight') 183 | state_dict.pop('fc.bias') 184 | 185 | # state_dict = {k: v for k, v in state_dict.items() if k in model_dict} 186 | # 
state_dict.update(state_dict) 187 | model.load_state_dict(state_dict) 188 | print("Success to load a pretrained weight") 189 | return model 190 | 191 | if __name__ == "__main__": 192 | model = nn.DataParallel(se_resnet50(pretrained=True, strides = (1, 2, 1, 2)), device_ids=1).cuda() -------------------------------------------------------------------------------- /MFPNet_code/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | class Vgg19(nn.Module): 6 | def __init__(self): 7 | super(Vgg19, self).__init__() 8 | features = models.vgg19(pretrained=True).features 9 | self.to_relu_1_2 = nn.Sequential() 10 | self.to_relu_2_2 = nn.Sequential() 11 | self.to_relu_3_4 = nn.Sequential() 12 | self.to_relu_4_4 = nn.Sequential() 13 | self.to_relu_5_4 = nn.Sequential() 14 | # conv -1 15 | for x in range(4): 16 | self.to_relu_1_2.add_module(str(x), features[x]) 17 | for x in range(4, 9): 18 | self.to_relu_2_2.add_module(str(x), features[x]) 19 | for x in range(9, 18): 20 | self.to_relu_3_4.add_module(str(x), features[x]) 21 | for x in range(18, 27): 22 | self.to_relu_4_4.add_module(str(x), features[x]) 23 | for x in range(27, 36): 24 | self.to_relu_5_4.add_module(str(x), features[x]) 25 | 26 | # don't need the gradients, just want the features 27 | for param in self.parameters(): 28 | param.requires_grad = False 29 | 30 | def forward(self, x): 31 | h = self.to_relu_1_2(x) 32 | h_relu_1_2 = h 33 | h = self.to_relu_2_2(h) 34 | h_relu_2_2 = h 35 | h = self.to_relu_3_4(h) 36 | h_relu_3_4 = h 37 | h = self.to_relu_4_4(h) 38 | h_relu_4_4 = h 39 | h = self.to_relu_5_4(h) 40 | h_relu_5_4 = h 41 | 42 | out = (h_relu_1_2, h_relu_2_2, h_relu_3_4, h_relu_4_4, h_relu_5_4) 43 | return out -------------------------------------------------------------------------------- /MFPNet_code/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils.parser import get_parser_with_args 6 | from utils.helpers import (get_loaders, get_criterion, 7 | load_model, initialize_metrics, get_mean_metrics, 8 | set_metrics) 9 | from sklearn.metrics import precision_recall_fscore_support as prfs 10 | import os 11 | import logging 12 | import json 13 | import random 14 | import numpy as np 15 | import re 16 | import warnings 17 | from models.vgg import Vgg19 18 | warnings.filterwarnings("ignore") 19 | 20 | """ 21 | Initialize Parser and define arguments 22 | """ 23 | parser, metadata = get_parser_with_args(metadata_json_path='/home/aaa/xujialang/master_thesis/MFPNet/metadata.json') 24 | opt = parser.parse_args() 25 | 26 | """ 27 | Initialize experiments log 28 | """ 29 | logging.basicConfig(level=logging.INFO) 30 | 31 | """ 32 | Set up environment: define paths, download data, and set device 33 | """ 34 | dev = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu') 35 | logging.info('GPU AVAILABLE? ' + str(torch.cuda.is_available())) 36 | 37 | def seed_torch(seed): 38 | random.seed(seed) 39 | os.environ['PYTHONHASHSEED'] = str(seed) 40 | np.random.seed(seed) 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed(seed) 43 | # torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
44 | torch.backends.cudnn.benchmark = False 45 | torch.backends.cudnn.deterministic = True 46 | seed_torch(seed=777) 47 | 48 | train_loader, val_loader = get_loaders(opt) 49 | print(opt.batch_size * len(train_loader)) 50 | print(opt.batch_size * len(val_loader)) 51 | 52 | """ 53 | Load Model then define other aspects of the model 54 | """ 55 | logging.info('LOADING Model') 56 | model = load_model(opt, dev) 57 | vgg=Vgg19().to(dev) 58 | """ 59 | Resume 60 | """ 61 | epoch_resume=0 62 | if opt.resume != "None": 63 | model.load_state_dict(torch.load(os.path.join(opt.resume))) 64 | epoch_resume=int(re.sub("\D","",opt.resume)) 65 | print('resume success: epoch {}'.format(epoch_resume)) 66 | 67 | criterion_ce = nn.CrossEntropyLoss().to(dev) 68 | criterion_perceptual = nn.MSELoss().to(dev) 69 | criterion = get_criterion(opt) 70 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate) # Be careful when you adjust learning rate, you can refer to the linear scaling rule 71 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, T_mult=2, eta_min=0, last_epoch=-1) 72 | 73 | """ 74 | Set starting values 75 | """ 76 | best_metrics = {'cd_f1scores': -1, 'cd_recalls': -1, 'cd_precisions': -1} 77 | logging.info('STARTING training') 78 | 79 | for epoch in range(opt.epochs): 80 | epoch= epoch + epoch_resume +1 81 | train_metrics = initialize_metrics() 82 | val_metrics = initialize_metrics() 83 | 84 | """ 85 | Begin Training 86 | """ 87 | model.train() 88 | logging.info('SET model mode to train!') 89 | 90 | for batch_img1, batch_img2, labels in train_loader: 91 | # Set variables for training 92 | batch_img1 = batch_img1.float().to(dev) 93 | batch_img2 = batch_img2.float().to(dev) 94 | labels = labels.long().to(dev) 95 | 96 | # Zero the gradient 97 | optimizer.zero_grad() 98 | 99 | # Get model predictions, calculate loss, backprop 100 | cd_preds= model(batch_img1, batch_img2) 101 | loss = criterion(criterion_ce, criterion_perceptual, cd_preds, labels, batch_img1, vgg, dev) 102 | 103 | loss.backward() 104 | optimizer.step() 105 | 106 | # Calculate and log other batch metrics 107 | cd_preds = torch.argmax(cd_preds, dim = 1) 108 | cd_corrects = (100 * 109 | (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() / 110 | (labels.size()[0] * (opt.patch_size**2))) 111 | cd_train_report = prfs(labels.data.cpu().numpy().flatten(), 112 | cd_preds.data.cpu().numpy().flatten(), 113 | average='binary', 114 | pos_label=1) 115 | train_metrics = set_metrics(train_metrics, 116 | loss, 117 | cd_corrects, 118 | cd_train_report, 119 | scheduler.get_last_lr()) 120 | 121 | # log the batch mean metrics 122 | mean_train_metrics = get_mean_metrics(train_metrics) 123 | 124 | # clear batch variables from memory 125 | del batch_img1, batch_img2, labels 126 | 127 | scheduler.step() 128 | logging.info("EPOCH {} TRAIN METRICS. 
".format(epoch) + str(mean_train_metrics)) 129 | 130 | 131 | """ 132 | Begin Validation 133 | """ 134 | model.eval() 135 | with torch.no_grad(): 136 | for batch_img1, batch_img2, labels in val_loader: 137 | # Set variables for training 138 | batch_img1 = batch_img1.float().to(dev) 139 | batch_img2 = batch_img2.float().to(dev) 140 | labels = labels.long().to(dev) 141 | 142 | # Get predictions and calculate loss 143 | cd_preds = model(batch_img1, batch_img2) 144 | val_loss = criterion(criterion_ce, criterion_perceptual, cd_preds, labels, batch_img1, vgg, dev) 145 | 146 | # Calculate and log other batch metrics 147 | cd_preds = torch.argmax(cd_preds, dim = 1) 148 | cd_corrects = (100 * 149 | (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() / 150 | (labels.size()[0] * (opt.patch_size**2))) 151 | cd_val_report = prfs(labels.data.cpu().numpy().flatten(), 152 | cd_preds.data.cpu().numpy().flatten(), 153 | average='binary', 154 | pos_label=1) 155 | val_metrics = set_metrics(val_metrics, 156 | val_loss, 157 | cd_corrects, 158 | cd_val_report, 159 | scheduler.get_lr()) 160 | 161 | # log the batch mean metrics 162 | mean_val_metrics = get_mean_metrics(val_metrics) 163 | 164 | # clear batch variables from memory 165 | del batch_img1, batch_img2, labels 166 | 167 | logging.info("EPOCH {} VALIDATION METRICS".format(epoch)+str(mean_val_metrics)) 168 | 169 | """ 170 | Store the weights of good epochs based on validation results 171 | """ 172 | if (mean_val_metrics['cd_f1scores'] > best_metrics['cd_f1scores']): 173 | # Insert training and epoch information to metadata dictionary 174 | logging.info('updata the model') 175 | metadata['val_metrics'] = mean_val_metrics 176 | 177 | # Save model and log 178 | if not os.path.exists(opt.weight_dir): 179 | os.mkdir(opt.weight_dir) 180 | with open(opt.weight_dir + 'metadata_val_epoch_' + str(epoch) + '.json', 'w') as fout: 181 | json.dump(metadata, fout) 182 | 183 | torch.save(model.state_dict(), opt.weight_dir + 'checkpoint_epoch_'+str(epoch)+'_f1_'+str(mean_val_metrics['cd_f1scores'])+'.pt') 184 | best_metrics = mean_val_metrics 185 | print('best val: ' + str(mean_val_metrics)) 186 | 187 | print('An epoch finished.') 188 | 189 | print('Done!') -------------------------------------------------------------------------------- /MFPNet_code/utils/dataloaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | from PIL import Image 4 | from utils import transforms as tr 5 | 6 | 7 | ''' 8 | Load all training and validation data paths 9 | ''' 10 | def full_path_loader(data_dir): 11 | train_data = [i for i in os.listdir(data_dir + 'train/A/') if not 12 | i.startswith('.')] 13 | train_data.sort() 14 | 15 | valid_data = [i for i in os.listdir(data_dir + 'val/A/') if not 16 | i.startswith('.')] 17 | valid_data.sort() 18 | 19 | train_label_paths = [] 20 | val_label_paths = [] 21 | if 'DSIFN' in data_dir: 22 | for img in train_data: 23 | train_label_paths.append(data_dir + 'train/label/' + img.split('.')[0] + '.png') 24 | for img in valid_data: 25 | val_label_paths.append(data_dir + 'val/label/' + img.split('.')[0] + '.png') 26 | else: 27 | for img in train_data: 28 | train_label_paths.append(data_dir + 'train/label/' + img) 29 | for img in valid_data: 30 | val_label_paths.append(data_dir + 'val/label/' + img) 31 | 32 | 33 | train_data_path = [] 34 | val_data_path = [] 35 | 36 | for img in train_data: 37 | train_data_path.append([data_dir + 'train/', img]) 38 | for img in 
valid_data: 39 | val_data_path.append([data_dir + 'val/', img]) 40 | 41 | train_dataset = {} 42 | val_dataset = {} 43 | for cp in range(len(train_data)): 44 | train_dataset[cp] = {'image': train_data_path[cp], 45 | 'label': train_label_paths[cp]} 46 | for cp in range(len(valid_data)): 47 | val_dataset[cp] = {'image': val_data_path[cp], 48 | 'label': val_label_paths[cp]} 49 | 50 | 51 | return train_dataset, val_dataset 52 | 53 | ''' 54 | Load all testing data paths 55 | ''' 56 | def full_test_loader(data_dir): 57 | 58 | test_data = [i for i in os.listdir(data_dir + 'test/A/') if not 59 | i.startswith('.')] 60 | test_data.sort() 61 | 62 | test_label_paths = [] 63 | if 'DSIFN' in data_dir: 64 | for img in test_data: 65 | test_label_paths.append(data_dir + 'test/label/' + img.split('.')[0] + '.tif') 66 | else: 67 | for img in test_data: 68 | test_label_paths.append(data_dir + 'test/label/' + img) 69 | 70 | test_data_path = [] 71 | for img in test_data: 72 | test_data_path.append([data_dir + 'test/', img]) 73 | 74 | test_dataset = {} 75 | for cp in range(len(test_data)): 76 | test_dataset[cp] = {'image': test_data_path[cp], 77 | 'label': test_label_paths[cp]} 78 | 79 | return test_dataset 80 | 81 | def cdd_loader(img_path, label_path, aug): 82 | dir = img_path[0] 83 | name = img_path[1] 84 | 85 | img1 = Image.open(dir + 'A/' + name) 86 | img2 = Image.open(dir + 'B/' + name) 87 | label = Image.open(label_path).convert('L') 88 | sample = {'image': (img1, img2), 'label': label} 89 | 90 | if aug: 91 | sample = tr.train_transforms(sample) 92 | else: 93 | sample = tr.test_transforms(sample) 94 | 95 | return sample['image'][0], sample['image'][1], sample['label'] 96 | 97 | 98 | class CDDloader(data.Dataset): 99 | 100 | def __init__(self, full_load, aug=False): 101 | 102 | self.full_load = full_load 103 | self.loader = cdd_loader 104 | self.aug = aug 105 | 106 | def __getitem__(self, index): 107 | 108 | img_path, label_path = self.full_load[index]['image'], self.full_load[index]['label'] 109 | 110 | return self.loader(img_path, 111 | label_path, 112 | self.aug) 113 | 114 | def __len__(self): 115 | return len(self.full_load) 116 | -------------------------------------------------------------------------------- /MFPNet_code/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.utils.data 4 | import torch.nn as nn 5 | import numpy as np 6 | from utils.dataloaders import (full_path_loader, full_test_loader, CDDloader) 7 | from utils.metrics import jaccard_loss, dice_loss 8 | from utils.hybridloss import hybrid_loss 9 | from models.MFPNet_model import MFPNET 10 | logging.basicConfig(level=logging.INFO) 11 | 12 | def initialize_metrics(): 13 | """Generates a dictionary of metrics with metrics as keys 14 | and empty lists as values 15 | 16 | Returns 17 | ------- 18 | dict 19 | a dictionary of metrics 20 | 21 | """ 22 | metrics = { 23 | 'cd_losses': [], 24 | 'cd_corrects': [], 25 | 'cd_precisions': [], 26 | 'cd_recalls': [], 27 | 'cd_f1scores': [], 28 | 'learning_rate': [], 29 | } 30 | 31 | return metrics 32 | 33 | 34 | def get_mean_metrics(metric_dict): 35 | """takes a dictionary of lists for metrics and returns dict of mean values 36 | 37 | Parameters 38 | ---------- 39 | metric_dict : dict 40 | A dictionary of metrics 41 | 42 | Returns 43 | ------- 44 | dict 45 | dict of floats that reflect mean metric value 46 | 47 | """ 48 | return {k: np.mean(v) for k, v in metric_dict.items()} 49 | 50 | 51 | 52 | def 
set_metrics(metric_dict, cd_loss, cd_corrects, cd_report, lr): 53 | """Updates metric dict with batch metrics 54 | 55 | Parameters 56 | ---------- 57 | metric_dict : dict 58 | dict of metrics 59 | cd_loss : dict(?) 60 | loss value 61 | cd_corrects : dict(?) 62 | number of correct results (to generate accuracy 63 | cd_report : list 64 | precision, recall, f1 values 65 | 66 | Returns 67 | ------- 68 | dict 69 | dict of updated metrics 70 | 71 | 72 | """ 73 | metric_dict['cd_losses'].append(cd_loss.item()) 74 | metric_dict['cd_corrects'].append(cd_corrects.item()) 75 | metric_dict['cd_precisions'].append(cd_report[0]) 76 | metric_dict['cd_recalls'].append(cd_report[1]) 77 | metric_dict['cd_f1scores'].append(cd_report[2]) 78 | metric_dict['learning_rate'].append(lr) 79 | 80 | return metric_dict 81 | 82 | def get_loaders(opt): 83 | 84 | 85 | logging.info('STARTING Dataset Creation') 86 | 87 | train_full_load, val_full_load = full_path_loader(opt.dataset_dir) 88 | 89 | 90 | train_dataset = CDDloader(train_full_load, aug=opt.augmentation) 91 | val_dataset = CDDloader(val_full_load, aug=False) 92 | 93 | logging.info('STARTING Dataloading') 94 | 95 | train_loader = torch.utils.data.DataLoader(train_dataset, 96 | batch_size=opt.batch_size, 97 | shuffle=True, 98 | num_workers=opt.num_workers) 99 | val_loader = torch.utils.data.DataLoader(val_dataset, 100 | batch_size=opt.batch_size, 101 | shuffle=False, 102 | num_workers=opt.num_workers) 103 | return train_loader, val_loader 104 | 105 | def get_test_loaders(opt): 106 | 107 | logging.info('STARTING Test Dataset Creation') 108 | 109 | test_full_load = full_test_loader(opt.dataset_dir) 110 | 111 | test_dataset = CDDloader(test_full_load, aug=False) 112 | 113 | logging.info('STARTING Test Dataloading') 114 | 115 | test_loader = torch.utils.data.DataLoader(test_dataset, 116 | batch_size=1, 117 | shuffle=False, 118 | num_workers=opt.num_workers) 119 | return test_loader 120 | 121 | 122 | def get_criterion(opt): 123 | """get the user selected loss function 124 | 125 | Parameters 126 | ---------- 127 | opt : dict 128 | Dictionary of options/flags 129 | 130 | Returns 131 | ------- 132 | method 133 | loss function 134 | 135 | """ 136 | if opt.loss_function == 'hybrid': 137 | criterion = hybrid_loss 138 | if opt.loss_function == 'bce': 139 | criterion = nn.CrossEntropyLoss() 140 | if opt.loss_function == 'dice': 141 | criterion = dice_loss 142 | if opt.loss_function == 'jaccard': 143 | criterion = jaccard_loss 144 | 145 | return criterion 146 | 147 | 148 | def load_model(opt, device): 149 | """Load the model 150 | 151 | Parameters 152 | ---------- 153 | opt : dict 154 | User specified flags/options 155 | device : string 156 | device on which to train model 157 | 158 | """ 159 | model = MFPNET(classes = 2).to(device) 160 | 161 | return model 162 | -------------------------------------------------------------------------------- /MFPNet_code/utils/hybridloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def hybrid_loss(criterion_ce, criterion_perceptual, prediction, target, img_prev_train, vgg, dev): 5 | """Calculating the loss""" 6 | loss = 0 7 | 8 | # Perceptual Similarity Module (PSM) 9 | out_train_softmax2d = F.softmax(prediction,dim=1) 10 | an_change = out_train_softmax2d[:,1,:,:].unsqueeze(1).expand_as(img_prev_train) 11 | an_unchange = out_train_softmax2d[:,0,:,:].unsqueeze(1).expand_as(img_prev_train) 12 | label_change = 
target.unsqueeze(1).expand_as(img_prev_train).type(torch.FloatTensor).to(dev) 13 | label_unchange = 1-label_change 14 | an_change = an_change * label_change 15 | an_unchange = an_unchange * label_unchange 16 | 17 | an_change_feature = vgg(an_change) 18 | gt_feature = vgg(label_change) 19 | an_unchange_feature = vgg(an_unchange) 20 | gt_feature_unchange = vgg(label_unchange) 21 | 22 | perceptual_loss_change = criterion_perceptual(an_change_feature[0], gt_feature[0]) 23 | perceptual_loss_unchange = criterion_perceptual(an_unchange_feature[0], gt_feature_unchange[0]) 24 | perceptual_loss = perceptual_loss_change + perceptual_loss_unchange 25 | 26 | loss = 0.0001*perceptual_loss + criterion_ce(prediction, target) 27 | 28 | return loss 29 | 30 | -------------------------------------------------------------------------------- /MFPNet_code/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | 8 | 9 | class FocalLoss(nn.Module): 10 | def __init__(self, gamma=0, alpha=None, size_average=True): 11 | super(FocalLoss, self).__init__() 12 | self.gamma = gamma 13 | self.alpha = alpha 14 | if isinstance(alpha, (float, int)): 15 | self.alpha = torch.Tensor([alpha, 1-alpha]) 16 | if isinstance(alpha, list): 17 | self.alpha = torch.Tensor(alpha) 18 | self.size_average = size_average 19 | 20 | def forward(self, input, target): 21 | if input.dim() > 2: 22 | # N,C,H,W => N,C,H*W 23 | input = input.view(input.size(0), input.size(1), -1) 24 | 25 | # N,C,H*W => N,H*W,C 26 | input = input.transpose(1, 2) 27 | 28 | # N,H*W,C => N*H*W,C 29 | input = input.contiguous().view(-1, input.size(2)) 30 | 31 | 32 | target = target.view(-1, 1) 33 | logpt = F.log_softmax(input) 34 | logpt = logpt.gather(1, target) 35 | logpt = logpt.view(-1) 36 | pt = Variable(logpt.data.exp()) 37 | 38 | if self.alpha is not None: 39 | if self.alpha.type() != input.data.type(): 40 | self.alpha = self.alpha.type_as(input.data) 41 | at = self.alpha.gather(0, target.data.view(-1)) 42 | logpt = logpt * Variable(at) 43 | 44 | loss = -1 * (1-pt)**self.gamma * logpt 45 | 46 | if self.size_average: 47 | return loss.mean() 48 | else: 49 | return loss.sum() 50 | 51 | def dice_loss(logits, true, eps=1e-7): 52 | """Computes the Sørensen–Dice loss. 53 | Note that PyTorch optimizers minimize a loss. In this 54 | case, we would like to maximize the dice loss so we 55 | return the negated dice loss. 56 | Args: 57 | true: a tensor of shape [B, 1, H, W]. 58 | logits: a tensor of shape [B, C, H, W]. Corresponds to 59 | the raw output or logits of the model. 60 | eps: added to the denominator for numerical stability. 61 | Returns: 62 | dice_loss: the Sørensen–Dice loss. 
63 | """ 64 | num_classes = logits.shape[1] 65 | if num_classes == 1: 66 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] 67 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 68 | true_1_hot_f = true_1_hot[:, 0:1, :, :] 69 | true_1_hot_s = true_1_hot[:, 1:2, :, :] 70 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) 71 | pos_prob = torch.sigmoid(logits) 72 | neg_prob = 1 - pos_prob 73 | probas = torch.cat([pos_prob, neg_prob], dim=1) 74 | else: 75 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)] 76 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 77 | probas = F.softmax(logits, dim=1) 78 | true_1_hot = true_1_hot.type(logits.type()) 79 | dims = (0,) + tuple(range(2, true.ndimension())) 80 | intersection = torch.sum(probas * true_1_hot, dims) 81 | cardinality = torch.sum(probas + true_1_hot, dims) 82 | dice_loss = (2. * intersection / (cardinality + eps)).mean() 83 | return (1 - dice_loss) 84 | 85 | 86 | def jaccard_loss(logits, true, eps=1e-7): 87 | """Computes the Jaccard loss, a.k.a the IoU loss. 88 | Note that PyTorch optimizers minimize a loss. In this 89 | case, we would like to maximize the jaccard loss so we 90 | return the negated jaccard loss. 91 | Args: 92 | true: a tensor of shape [B, H, W] or [B, 1, H, W]. 93 | logits: a tensor of shape [B, C, H, W]. Corresponds to 94 | the raw output or logits of the model. 95 | eps: added to the denominator for numerical stability. 96 | Returns: 97 | jacc_loss: the Jaccard loss. 98 | """ 99 | num_classes = logits.shape[1] 100 | if num_classes == 1: 101 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] 102 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 103 | true_1_hot_f = true_1_hot[:, 0:1, :, :] 104 | true_1_hot_s = true_1_hot[:, 1:2, :, :] 105 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) 106 | pos_prob = torch.sigmoid(logits) 107 | neg_prob = 1 - pos_prob 108 | probas = torch.cat([pos_prob, neg_prob], dim=1) 109 | else: 110 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)] 111 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 112 | probas = F.softmax(logits, dim=1) 113 | true_1_hot = true_1_hot.type(logits.type()) 114 | dims = (0,) + tuple(range(2, true.ndimension())) 115 | intersection = torch.sum(probas * true_1_hot, dims) 116 | cardinality = torch.sum(probas + true_1_hot, dims) 117 | union = cardinality - intersection 118 | jacc_loss = (intersection / (union + eps)).mean() 119 | return (1 - jacc_loss) 120 | 121 | 122 | class TverskyLoss(nn.Module): 123 | def __init__(self, alpha=0.5, beta=0.5, eps=1e-7, size_average=True): 124 | super(TverskyLoss, self).__init__() 125 | self.alpha = alpha 126 | self.beta = beta 127 | self.size_average = size_average 128 | self.eps = eps 129 | 130 | def forward(self, logits, true): 131 | """Computes the Tversky loss [1]. 132 | Args: 133 | true: a tensor of shape [B, H, W] or [B, 1, H, W]. 134 | logits: a tensor of shape [B, C, H, W]. Corresponds to 135 | the raw output or logits of the model. 136 | alpha: controls the penalty for false positives. 137 | beta: controls the penalty for false negatives. 138 | eps: added to the denominator for numerical stability. 139 | Returns: 140 | tversky_loss: the Tversky loss. 
141 | Notes: 142 | alpha = beta = 0.5 => dice coeff 143 | alpha = beta = 1 => tanimoto coeff 144 | alpha + beta = 1 => F beta coeff 145 | References: 146 | [1]: https://arxiv.org/abs/1706.05721 147 | """ 148 | num_classes = logits.shape[1] 149 | if num_classes == 1: 150 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] 151 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 152 | true_1_hot_f = true_1_hot[:, 0:1, :, :] 153 | true_1_hot_s = true_1_hot[:, 1:2, :, :] 154 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) 155 | pos_prob = torch.sigmoid(logits) 156 | neg_prob = 1 - pos_prob 157 | probas = torch.cat([pos_prob, neg_prob], dim=1) 158 | else: 159 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)] 160 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 161 | probas = F.softmax(logits, dim=1) 162 | 163 | true_1_hot = true_1_hot.type(logits.type()) 164 | dims = (0,) + tuple(range(2, true.ndimension())) 165 | intersection = torch.sum(probas * true_1_hot, dims) 166 | fps = torch.sum(probas * (1 - true_1_hot), dims) 167 | fns = torch.sum((1 - probas) * true_1_hot, dims) 168 | num = intersection 169 | denom = intersection + (self.alpha * fps) + (self.beta * fns) 170 | tversky_loss = (num / (denom + self.eps)).mean() 171 | return (1 - tversky_loss) 172 | -------------------------------------------------------------------------------- /MFPNet_code/utils/parser.py: -------------------------------------------------------------------------------- 1 | import argparse as ag 2 | import json 3 | 4 | def get_parser_with_args(metadata_json_path=None): 5 | parser = ag.ArgumentParser(description='Training change detection network') 6 | 7 | with open(metadata_json_path, 'r') as fin: 8 | metadata = json.load(fin) 9 | parser.set_defaults(**metadata) 10 | return parser, metadata 11 | 12 | return None 13 | -------------------------------------------------------------------------------- /MFPNet_code/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | import torchvision.transforms as transforms 7 | 8 | class Normalize(object): 9 | """Normalize a tensor image with mean and standard deviation. 10 | Args: 11 | mean (tuple): means for each channel. 12 | std (tuple): standard deviations for each channel. 
13 | """ 14 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def __call__(self, sample): 19 | img = sample['image'] 20 | mask = sample['label'] 21 | img = np.array(img).astype(np.float32) 22 | mask = np.array(mask).astype(np.float32) 23 | img /= 255.0 24 | img -= self.mean 25 | img /= self.std 26 | 27 | return {'image': img, 28 | 'label': mask} 29 | 30 | 31 | class ToTensor(object): 32 | """Convert ndarrays in sample to Tensors.""" 33 | 34 | def __call__(self, sample): 35 | # swap color axis because 36 | # numpy image: H x W x C 37 | # torch image: C X H X W 38 | img1 = sample['image'][0] 39 | img2 = sample['image'][1] 40 | mask = sample['label'] 41 | img1 = np.array(img1).astype(np.float32).transpose((2, 0, 1)) 42 | img2 = np.array(img2).astype(np.float32).transpose((2, 0, 1)) 43 | if np.unique(mask).sum() == 1: 44 | mask = np.array(mask).astype(np.float32) 45 | else: 46 | mask = np.array(mask).astype(np.float32) / 255.0 47 | 48 | img1 = torch.from_numpy(img1).float() 49 | img2 = torch.from_numpy(img2).float() 50 | mask = torch.from_numpy(mask).float() 51 | 52 | return {'image': (img1, img2), 53 | 'label': mask} 54 | 55 | 56 | class RandomHorizontalFlip(object): 57 | def __call__(self, sample): 58 | img1 = sample['image'][0] 59 | img2 = sample['image'][1] 60 | mask = sample['label'] 61 | if random.random() < 0.5: 62 | img1 = img1.transpose(Image.FLIP_LEFT_RIGHT) 63 | img2 = img2.transpose(Image.FLIP_LEFT_RIGHT) 64 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 65 | 66 | return {'image': (img1, img2), 67 | 'label': mask} 68 | 69 | class RandomVerticalFlip(object): 70 | def __call__(self, sample): 71 | img1 = sample['image'][0] 72 | img2 = sample['image'][1] 73 | mask = sample['label'] 74 | if random.random() < 0.5: 75 | img1 = img1.transpose(Image.FLIP_TOP_BOTTOM) 76 | img2 = img2.transpose(Image.FLIP_TOP_BOTTOM) 77 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM) 78 | 79 | return {'image': (img1, img2), 80 | 'label': mask} 81 | 82 | class RandomFixRotate(object): 83 | def __init__(self): 84 | self.degree = [Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270] 85 | 86 | def __call__(self, sample): 87 | img1 = sample['image'][0] 88 | img2 = sample['image'][1] 89 | mask = sample['label'] 90 | if random.random() < 0.75: 91 | rotate_degree = random.choice(self.degree) 92 | img1 = img1.transpose(rotate_degree) 93 | img2 = img2.transpose(rotate_degree) 94 | mask = mask.transpose(rotate_degree) 95 | 96 | return {'image': (img1, img2), 97 | 'label': mask} 98 | 99 | 100 | class RandomRotate(object): 101 | def __init__(self, degree): 102 | self.degree = degree 103 | 104 | def __call__(self, sample): 105 | img1 = sample['image'][0] 106 | img2 = sample['image'][1] 107 | mask = sample['label'] 108 | rotate_degree = random.uniform(-1*self.degree, self.degree) 109 | img1 = img1.rotate(rotate_degree, Image.BILINEAR) 110 | img2 = img2.rotate(rotate_degree, Image.BILINEAR) 111 | mask = mask.rotate(rotate_degree, Image.NEAREST) 112 | 113 | return {'image': (img1, img2), 114 | 'label': mask} 115 | 116 | 117 | class RandomGaussianBlur(object): 118 | def __call__(self, sample): 119 | img1 = sample['image'][0] 120 | img2 = sample['image'][1] 121 | mask = sample['label'] 122 | if random.random() < 0.5: 123 | img1 = img1.filter(ImageFilter.GaussianBlur( 124 | radius=random.random())) 125 | img2 = img2.filter(ImageFilter.GaussianBlur( 126 | radius=random.random())) 127 | 128 | return {'image': (img1, img2), 129 | 'label': mask} 130 | 131 | 132 | class 
RandomScaleCrop(object): 133 | def __init__(self, base_size, crop_size, fill=0): 134 | self.base_size = base_size 135 | self.crop_size = crop_size 136 | self.fill = fill 137 | 138 | def __call__(self, sample): 139 | img = sample['image'] 140 | mask = sample['label'] 141 | # random scale (short edge) 142 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 143 | w, h = img.size 144 | if h > w: 145 | ow = short_size 146 | oh = int(1.0 * h * ow / w) 147 | else: 148 | oh = short_size 149 | ow = int(1.0 * w * oh / h) 150 | img = img.resize((ow, oh), Image.BILINEAR) 151 | mask = mask.resize((ow, oh), Image.NEAREST) 152 | # pad crop 153 | if short_size < self.crop_size: 154 | padh = self.crop_size - oh if oh < self.crop_size else 0 155 | padw = self.crop_size - ow if ow < self.crop_size else 0 156 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 157 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 158 | # random crop crop_size 159 | w, h = img.size 160 | x1 = random.randint(0, w - self.crop_size) 161 | y1 = random.randint(0, h - self.crop_size) 162 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 163 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 164 | 165 | return {'image': img, 166 | 'label': mask} 167 | 168 | 169 | class FixScaleCrop(object): 170 | def __init__(self, crop_size): 171 | self.crop_size = crop_size 172 | 173 | def __call__(self, sample): 174 | img = sample['image'] 175 | mask = sample['label'] 176 | w, h = img.size 177 | if w > h: 178 | oh = self.crop_size 179 | ow = int(1.0 * w * oh / h) 180 | else: 181 | ow = self.crop_size 182 | oh = int(1.0 * h * ow / w) 183 | img = img.resize((ow, oh), Image.BILINEAR) 184 | mask = mask.resize((ow, oh), Image.NEAREST) 185 | # center crop 186 | w, h = img.size 187 | x1 = int(round((w - self.crop_size) / 2.)) 188 | y1 = int(round((h - self.crop_size) / 2.)) 189 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 190 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 191 | 192 | return {'image': img, 193 | 'label': mask} 194 | 195 | class FixedResize(object): 196 | def __init__(self, size): 197 | self.size = (size, size) # size: (h, w) 198 | 199 | def __call__(self, sample): 200 | img1 = sample['image'][0] 201 | img2 = sample['image'][1] 202 | mask = sample['label'] 203 | 204 | assert img1.size == mask.size and img2.size == mask.size 205 | 206 | img1 = img1.resize(self.size, Image.BILINEAR) 207 | img2 = img2.resize(self.size, Image.BILINEAR) 208 | mask = mask.resize(self.size, Image.NEAREST) 209 | 210 | return {'image': (img1, img2), 211 | 'label': mask} 212 | 213 | 214 | ''' 215 | We don't use Normalize here, because it will bring negative effects. 216 | the mask of ground truth is converted to [0,1] in ToTensor() function. 
217 | ''' 218 | train_transforms = transforms.Compose([ 219 | RandomHorizontalFlip(), 220 | RandomVerticalFlip(), 221 | RandomFixRotate(), 222 | # RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 223 | # RandomGaussianBlur(), 224 | # Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 225 | ToTensor()]) 226 | 227 | test_transforms = transforms.Compose([ 228 | # RandomHorizontalFlip(), 229 | # RandomVerticalFlip(), 230 | # RandomFixRotate(), 231 | # RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 232 | # RandomGaussianBlur(), 233 | # Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 234 | ToTensor()]) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Remote Sensing Change Detection Based on Multidirectional Adaptive Feature Fusion and Perceptual Similarity 2 | ![](https://img.shields.io/badge/author-Jialang,Xu-brightgreen)![](https://img.shields.io/badge/email-504006937@qq.com-blue) 3 | 4 | PyTorch implementation for "[Remote Sensing Change Detection Based on Multidirectional Adaptive Feature Fusion and Perceptual Similarity](https://www.mdpi.com/2072-4292/13/15/3053)" 5 | 6 | - [03 August 2021] Release the code of the MFPNet model. 7 | - [28 June 2022] Release the processed datasets and training/evaluation code. 8 | - [08 July 2022] Release model weights for the Season-varying/LEVIR-CD/Google datasets. 9 | 10 | ## Introduction 11 | Remote sensing change detection (RSCD) is an important yet challenging task in Earth observation. The booming development of convolutional neural networks (CNNs) in computer vision raises new possibilities for RSCD, and many recent RSCD methods have introduced CNNs to achieve promising improvements in performance. This paper proposes a novel multidirectional fusion and perception network for change detection in bi-temporal very-high-resolution remote sensing images. First, we propose an elaborate feature fusion module consisting of a multidirectional fusion pathway (MFP) and an adaptive weighted fusion (AWF) strategy for RSCD to boost the way information propagates in the network. The MFP enhances the flexibility and diversity of information paths by creating extra top-down and shortcut-connection paths. The AWF strategy conducts weight recalibration for every fusion node to highlight salient feature maps and overcome semantic gaps between different features. Second, a novel perceptual similarity module is designed to introduce perceptual loss into the RSCD task, adding perceptual information, such as structure and semantics, for high-quality change map generation. Extensive experiments on four challenging benchmark datasets demonstrate the superiority of the proposed network compared with eight state-of-the-art methods in terms of F1, Kappa, and visual quality. 12 | 13 | ## Content 14 | ### Architecture 15 | 16 | 17 | Fig.1 Overall architecture of the proposed multidirectional fusion and perception network (MFPNet).
18 | Note that the process indicated by the dashed line only participates in model training. 19 | 20 | ### Datasets 21 | The processed and original datasets can be downloaded from the table below; we recommend downloading the processed versions directly for a quick start with our code: 22 |
| Datasets | Processed Links | Original Links |
| :--- | :--- | :--- |
| Season-varying Dataset [1] | [Google Drive] [Baidu Drive] | [Original] |
| LEVIR-CD Dataset [2] | - | [Original] |
| Google Dataset [3] | - | [Original] |
| Zhang Dataset [4] | - | [Original] |
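After downloading and extracting a dataset, you can sanity-check that it matches the folder layout expected in the setup section below. This is a minimal, illustrative sketch; `/path/to/dataset1` is a placeholder for your own dataset root.

```python
# Illustrative layout check only; adjust the root path to your own dataset.
import os

root = '/path/to/dataset1'
for split in ('train', 'val', 'test'):
    for sub in ('A', 'B', 'label'):
        folder = os.path.join(root, split, sub)
        print(folder, 'ok' if os.path.isdir(folder) else 'MISSING')
```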
47 | 48 | ### Setup & Usage for the Code 49 | 50 | 1. Check the structure of data folders: 51 | ``` 52 | (root folder) 53 | ├── dataset1 54 | | ├── train 55 | | | ├── A 56 | | | ├── B 57 | | | ├── label 58 | | ├── val 59 | | | ├── A 60 | | | ├── B 61 | | | ├── label 62 | | ├── test 63 | | | ├── A 64 | | | ├── B 65 | | | ├── label 66 | ├── ... 67 | ``` 68 | 69 | 2. Check dependencies: 70 | ``` 71 | - Python 3.6+ 72 | - PyTorch 1.7.0+ 73 | - scikit-learn 74 | - cudatoolkit 75 | - cudnn 76 | - OpenCV-Python 77 | ``` 78 | 79 | 3. Change paths: 80 | ``` 81 | - Change the 'metadata_json_path' in 'train.py' to your 'metadata.json' path. 82 | - Change the 'dataset_dir' and 'weight_dir' in 'metadata.json' to your own path. 83 | ``` 84 | 85 | 4. Train the MFPNet: 86 | ``` 87 | python train.py 88 | ``` 89 | 90 | 5. Evaluate the MFPNet: 91 | ``` 92 | - Download model weights (optional). 93 | - Change the 'weight_path' in 'eval.py' to your model weight path. 94 | - python eval.py 95 | ``` 96 | 97 | ## Model Weights 98 | Model weights for Season-varying/LEVIR-CD/Google datasets are available via [Google Drive](https://drive.google.com/drive/folders/1-2njQ7Z3IIrjv6YGXoMD2CBZbc1nQuRu?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/141aQDQ_lMEi83O2t6AcLqg?pwd=1234). Note that the training/dataloader codes are rewritten and improved so the performance is a little different from the paper. 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 |
| Datasets | F1 (%) | Kappa (%) |
| :--- | :--- | :--- |
| Season-varying | 97.964 | 97.691 |
| LEVIR-CD | 91.568 | 91.120 |
| Google | 88.058 | 84.140 |
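If a downloaded checkpoint does not evaluate as expected, it can help to load the weight file on its own before running `eval.py`. This is a hedged sketch, not part of the released pipeline; the filename `model_weight.pt` is a placeholder for whichever checkpoint you downloaded.

```python
# Illustrative checkpoint inspection; replace the path with your downloaded weight file.
import torch

state_dict = torch.load('model_weight.pt', map_location='cpu')
print(len(state_dict), 'parameter tensors')
for name in list(state_dict)[:5]:  # peek at the first few layer names and shapes
    print(name, tuple(state_dict[name].shape))
```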
121 | 122 | 123 | ## Reference 124 | Appreciate the work from the following repositories: 125 | * [likyoo/Siam-NestedUNet](https://github.com/likyoo/Siam-NestedUNet) 126 | * [zylo117/Yet-Another-EfficientDet-Pytorch](https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch) 127 | 128 | ## Cite 129 | If this repository is useful for your research, please cite: 130 | ``` 131 | @Article{rs13153053, 132 | AUTHOR = {Xu, Jialang and Luo, Chunbo and Chen, Xinyue and Wei, Shicai and Luo, Yang}, 133 | TITLE = {Remote Sensing Change Detection Based on Multidirectional Adaptive Feature Fusion and Perceptual Similarity}, 134 | JOURNAL = {Remote Sensing}, 135 | VOLUME = {13}, 136 | YEAR = {2021}, 137 | NUMBER = {15}, 138 | ARTICLE-NUMBER = {3053}, 139 | URL = {https://www.mdpi.com/2072-4292/13/15/3053}, 140 | ISSN = {2072-4292}, 141 | DOI = {10.3390/rs13153053} 142 | } 143 | ``` 144 | -------------------------------------------------------------------------------- /figure/AWF.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzjialang/MFPNet/df2e15c29cb13f01a7150025f622ad69a0ed3d1b/figure/AWF.png -------------------------------------------------------------------------------- /figure/MFP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzjialang/MFPNet/df2e15c29cb13f01a7150025f622ad69a0ed3d1b/figure/MFP.png -------------------------------------------------------------------------------- /figure/MFPNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzjialang/MFPNet/df2e15c29cb13f01a7150025f622ad69a0ed3d1b/figure/MFPNet.png -------------------------------------------------------------------------------- /figure/PSM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzjialang/MFPNet/df2e15c29cb13f01a7150025f622ad69a0ed3d1b/figure/PSM.png --------------------------------------------------------------------------------