├── Keras version ├── DSIFN.py ├── attention_module.py └── loss.py ├── README.md ├── _config.yml ├── dataset ├── README.md └── imgs │ ├── 1.png │ ├── 1.txt │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ └── 6.png ├── imgs ├── 1.png └── 1.test ├── pytorch version ├── DSIFN.py └── loss.py └── results ├── README.md └── imgs ├── 00000.jpg ├── 00001.jpg ├── 00002.jpg ├── 19.jpg ├── 4.jpg └── 41.jpg /Keras version/DSIFN.py: -------------------------------------------------------------------------------- 1 | 2 | # credits: https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images 3 | 4 | from keras import applications 5 | from keras.models import Model 6 | from keras.layers import Input, Dense, Dropout, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, concatenate, \ 7 | Activation, ZeroPadding2D,Conv2DTranspose,Subtract,multiply,add,UpSampling2D,PReLU 8 | from keras import layers 9 | from keras.optimizers import Adam 10 | from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping, TensorBoard, CSVLogger 11 | from attention_module import channel_attention, spatial_attention,get_spatial_attention_map 12 | 13 | def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same',with_activation = False): 14 | x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides)(x) 15 | x=PReLU()(x) 16 | x=BatchNormalization(axis=3)(x) 17 | x=Dropout(rate = 0.6)(x) 18 | if with_activation == True: 19 | x = Activation('relu')(x) 20 | return x 21 | 22 | def vgg16(): 23 | vgg_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(512,512,3)) 24 | model = Model(inputs=vgg_model.input, outputs = vgg_model.get_layer('block5_conv3').output) 25 | model.trainable=False 26 | return model 27 | 28 | def DSIFN(): 29 | #DFEN accepts inputs in size of 512*512*3 30 | vgg_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(512,512,3)) 31 | 32 | b5c3_model = Model(inputs=vgg_model.input, outputs = vgg_model.get_layer('block5_conv3').output) 33 | b5c3_model.trainable=False 34 | 35 | b4c3_model = Model(inputs=vgg_model.input, outputs = vgg_model.get_layer('block4_conv3').output) 36 | b4c3_model.trainable=False 37 | 38 | b3c3_model = Model(inputs=vgg_model.input, outputs = vgg_model.get_layer('block3_conv3').output) 39 | b3c3_model.trainable=False 40 | 41 | b2c2_model = Model(inputs=vgg_model.input, outputs = vgg_model.get_layer('block2_conv2').output) 42 | b2c2_model.trainable=False 43 | 44 | b1c2_model = Model(inputs=vgg_model.input, outputs = vgg_model.get_layer('block1_conv2').output) 45 | b1c2_model.trainable=False 46 | 47 | input_t1 = layers.Input((512,512,3), name='Input_t1') 48 | input_t2 = layers.Input((512,512,3), name='Input_t2') 49 | 50 | t1_b5c3 = b5c3_model(input_t1) 51 | t2_b5c3 = b5c3_model(input_t2) 52 | 53 | t1_b4c3 = b4c3_model(input_t1) 54 | t2_b4c3 = b4c3_model(input_t2) 55 | 56 | t1_b3c3 = b3c3_model(input_t1) 57 | t2_b3c3 = b3c3_model(input_t2) 58 | 59 | t1_b2c2 = b2c2_model(input_t1) 60 | t2_b2c2 = b2c2_model(input_t2) 61 | 62 | t1_b1c2 = b1c2_model(input_t1) 63 | t2_b1c2 = b1c2_model(input_t2) 64 | 65 | concat_b5c3 = concatenate([t1_b5c3, t2_b5c3], axis=3) #channel 1024 66 | x = Conv2d_BN(concat_b5c3,512, 3) 67 | x = Conv2d_BN(x,512,3) 68 | attention_map_1 = get_spatial_attention_map(x) 69 | x = multiply([x, attention_map_1]) 70 | x = BatchNormalization(axis=3)(x) 71 | 72 | #branche1 73 | branch_1 =Conv2D(1, kernel_size=1, activation='sigmoid', 
padding='same',name='output_32')(x) 74 | 75 | x = Conv2DTranspose(512, kernel_size=2, strides=2, kernel_initializer="he_normal", padding='same')(x) 76 | x = concatenate([x,t1_b4c3,t2_b4c3],axis=3) 77 | x = channel_attention(x) 78 | x = Conv2d_BN(x,512,3) 79 | x = Conv2d_BN(x,256,3) 80 | x = Conv2d_BN(x,256,3) 81 | attention_map_2 = get_spatial_attention_map(x) 82 | x = multiply([x, attention_map_2]) 83 | x = BatchNormalization(axis=3)(x) 84 | 85 | #branche2 86 | branch_2 =Conv2D(1, kernel_size=1, activation='sigmoid', padding='same',name='output_64')(x) 87 | 88 | x = Conv2DTranspose(256, kernel_size=2, strides=2, kernel_initializer="he_normal", padding='same')(x) 89 | x = concatenate([x,t1_b3c3,t2_b3c3],axis=3) 90 | x = channel_attention(x) 91 | x = Conv2d_BN(x,256,3) 92 | x = Conv2d_BN(x,128,3) 93 | x = Conv2d_BN(x, 128, 3) 94 | attention_map_3 = get_spatial_attention_map(x) 95 | x = multiply([x, attention_map_3]) 96 | x = BatchNormalization(axis=3)(x) 97 | 98 | #branche3 99 | branch_3 =Conv2D(1, kernel_size=1, activation='sigmoid', padding='same',name='output_128')(x) 100 | 101 | x = Conv2DTranspose(128, kernel_size=2, strides=2, kernel_initializer="he_normal", padding='same')(x) 102 | x = concatenate([x,t1_b2c2,t2_b2c2],axis=3) 103 | x = channel_attention(x) 104 | x = Conv2d_BN(x,128,3) 105 | x = Conv2d_BN(x,64,3) 106 | x = Conv2d_BN(x, 64, 3) 107 | attention_map_4 = get_spatial_attention_map(x) 108 | x = multiply([x, attention_map_4]) 109 | x = BatchNormalization(axis=3)(x) 110 | 111 | #branche4 112 | branch_4 =Conv2D(1, kernel_size=1, activation='sigmoid', padding='same',name='output_256')(x) 113 | 114 | x = Conv2DTranspose(64, kernel_size=2, strides=2, kernel_initializer="he_normal", padding='same')(x) 115 | x = concatenate([x,t1_b1c2,t2_b1c2],axis=3) 116 | x = channel_attention(x) 117 | x = Conv2d_BN(x,64,3) 118 | x = Conv2d_BN(x,32,3) 119 | x = Conv2d_BN(x, 16, 3) 120 | attention_map_5 = get_spatial_attention_map(x) 121 | x = multiply([x, attention_map_5]) 122 | 123 | # branche5 124 | branch_5 =Conv2D(1, kernel_size=1, activation='sigmoid', padding='same',name='output_512')(x) 125 | 126 | DSIFN = Model(inputs=[input_t1,input_t2], outputs=[branch_1,branch_2,branch_3,branch_4,branch_5]) 127 | 128 | return DSIFN 129 | -------------------------------------------------------------------------------- /Keras version/attention_module.py: -------------------------------------------------------------------------------- 1 | from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, multiply, Permute, Concatenate, \ 2 | Conv2D, Add, Activation, Lambda 3 | from keras import backend as K 4 | from keras.activations import sigmoid 5 | 6 | def attach_attention_module(net, attention_module): 7 | if attention_module == 'se_block': # SE_block 8 | net = se_block(net) 9 | elif attention_module == 'cbam_block': # CBAM_block 10 | net = cbam_block(net) 11 | else: 12 | raise Exception("'{}' is not supported attention module!".format(attention_module)) 13 | 14 | return net 15 | 16 | 17 | def se_block(input_feature, ratio=8): 18 | """Contains the implementation of Squeeze-and-Excitation(SE) block. 19 | As described in https://arxiv.org/abs/1709.01507. 
20 | """ 21 | 22 | channel_axis = 1 if K.image_data_format() == "channels_first" else -1 23 | channel = input_feature._keras_shape[channel_axis] 24 | 25 | se_feature = GlobalAveragePooling2D()(input_feature) 26 | se_feature = Reshape((1, 1, channel))(se_feature) 27 | assert se_feature._keras_shape[1:] == (1, 1, channel) 28 | se_feature = Dense(channel // ratio, 29 | activation='relu', 30 | kernel_initializer='he_normal', 31 | use_bias=True, 32 | bias_initializer='zeros')(se_feature) 33 | assert se_feature._keras_shape[1:] == (1, 1, channel // ratio) 34 | se_feature = Dense(channel, 35 | activation='sigmoid', 36 | kernel_initializer='he_normal', 37 | use_bias=True, 38 | bias_initializer='zeros')(se_feature) 39 | assert se_feature._keras_shape[1:] == (1, 1, channel) 40 | if K.image_data_format() == 'channels_first': 41 | se_feature = Permute((3, 1, 2))(se_feature) 42 | 43 | se_feature = multiply([input_feature, se_feature]) 44 | return se_feature 45 | 46 | 47 | def cbam_block(cbam_feature, ratio=8): 48 | """Contains the implementation of Convolutional Block Attention Module(CBAM) block. 49 | As described in https://arxiv.org/abs/1807.06521. 50 | """ 51 | 52 | cbam_feature = channel_attention(cbam_feature, ratio) 53 | cbam_feature = spatial_attention(cbam_feature) 54 | return cbam_feature 55 | 56 | 57 | def channel_attention(input_feature, ratio=8): 58 | channel_axis = 1 if K.image_data_format() == "channels_first" else -1 59 | channel = input_feature._keras_shape[channel_axis] 60 | 61 | shared_layer_one = Dense(channel // ratio, 62 | activation='relu', 63 | kernel_initializer='he_normal', 64 | use_bias=True, 65 | bias_initializer='zeros') 66 | shared_layer_two = Dense(channel, 67 | kernel_initializer='he_normal', 68 | use_bias=True, 69 | bias_initializer='zeros') 70 | 71 | avg_pool = GlobalAveragePooling2D()(input_feature) 72 | avg_pool = Reshape((1, 1, channel))(avg_pool) 73 | assert avg_pool._keras_shape[1:] == (1, 1, channel) 74 | avg_pool = shared_layer_one(avg_pool) 75 | assert avg_pool._keras_shape[1:] == (1, 1, channel // ratio) 76 | avg_pool = shared_layer_two(avg_pool) 77 | assert avg_pool._keras_shape[1:] == (1, 1, channel) 78 | 79 | max_pool = GlobalMaxPooling2D()(input_feature) 80 | max_pool = Reshape((1, 1, channel))(max_pool) 81 | assert max_pool._keras_shape[1:] == (1, 1, channel) 82 | max_pool = shared_layer_one(max_pool) 83 | assert max_pool._keras_shape[1:] == (1, 1, channel // ratio) 84 | max_pool = shared_layer_two(max_pool) 85 | assert max_pool._keras_shape[1:] == (1, 1, channel) 86 | 87 | cbam_feature = Add()([avg_pool, max_pool]) 88 | cbam_feature = Activation('sigmoid')(cbam_feature) 89 | 90 | if K.image_data_format() == "channels_first": 91 | cbam_feature = Permute((3, 1, 2))(cbam_feature) 92 | 93 | return multiply([input_feature, cbam_feature]) 94 | 95 | 96 | def spatial_attention(input_feature): 97 | kernel_size = 7 98 | 99 | if K.image_data_format() == "channels_first": 100 | channel = input_feature._keras_shape[1] 101 | cbam_feature = Permute((2, 3, 1))(input_feature) 102 | else: 103 | channel = input_feature._keras_shape[-1] 104 | cbam_feature = input_feature 105 | 106 | avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature) 107 | assert avg_pool._keras_shape[-1] == 1 108 | max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature) 109 | assert max_pool._keras_shape[-1] == 1 110 | concat = Concatenate(axis=3)([avg_pool, max_pool]) 111 | assert concat._keras_shape[-1] == 2 112 | cbam_feature = Conv2D(filters=1, 113 | 
kernel_size=kernel_size, 114 | strides=1, 115 | padding='same', 116 | activation='sigmoid', 117 | kernel_initializer='he_normal', 118 | use_bias=False)(concat) 119 | assert cbam_feature._keras_shape[-1] == 1 120 | 121 | if K.image_data_format() == "channels_first": 122 | cbam_feature = Permute((3, 1, 2))(cbam_feature) 123 | 124 | return multiply([input_feature, cbam_feature]) 125 | 126 | 127 | def get_spatial_attention_map(input_feature): 128 | kernel_size = 7 129 | cbam_feature = input_feature 130 | 131 | avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature) 132 | assert avg_pool._keras_shape[-1] == 1 133 | max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature) 134 | assert max_pool._keras_shape[-1] == 1 135 | concat = Concatenate(axis=3)([avg_pool, max_pool]) 136 | assert concat._keras_shape[-1] == 2 137 | cbam_feature = Conv2D(filters=1, 138 | kernel_size=kernel_size, 139 | strides=1, 140 | padding='same', 141 | activation='sigmoid', 142 | kernel_initializer='he_normal', 143 | use_bias=False)(concat) 144 | assert cbam_feature._keras_shape[-1] == 1 145 | 146 | if K.image_data_format() == "channels_first": 147 | cbam_feature = Permute((3, 1, 2))(cbam_feature) 148 | 149 | return cbam_feature 150 | -------------------------------------------------------------------------------- /Keras version/loss.py: -------------------------------------------------------------------------------- 1 | from keras.losses import binary_crossentropy 2 | import keras.backend as K 3 | 4 | def dice_coeff(y_true, y_pred): 5 | smooth = 1. 6 | y_true_f = K.flatten(y_true) 7 | y_pred_f = K.flatten(y_pred) 8 | intersection = K.sum(y_true_f * y_pred_f) 9 | score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) 10 | return score 11 | 12 | def bce_dice_loss(y_true, y_pred): 13 | # the DSIFN branch outputs are already sigmoid-activated and y_true is a binary mask, so neither needs an extra sigmoid before the BCE term 14 | dice_loss = 1 - dice_coeff(y_true, y_pred) 15 | return binary_crossentropy(y_true, y_pred) + dice_loss 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A deeply supervised image fusion network for change detection in high resolution bi-temporal remote sensing images 2 | # 深度监督影像融合网络DSIFN用于高分辨率双时相遥感影像变化检测 3 | 4 | Official implementation of the paper: A deeply supervised image fusion network for change detection in high resolution bi-temporal remote sensing images. If you find this work helpful in your research, please consider citing: 5 | 6 | 论文《A deeply supervised image fusion network for change detection in high resolution bi-temporal remote sensing images》的官方模型代码。如果该代码对你的研究有所帮助,烦请引用: 7 | 8 | > [Zhang, C., Yue, P., Tapete, D., Jiang, L., Shangguan, B., Huang, L., & Liu, G. (2020). A deeply supervised image fusion network for change detection in high resolution bi-temporal remote sensing images. ISPRS Journal of Photogrammetry and Remote Sensing, 166, 183-200.](https://www.sciencedirect.com/science/article/abs/pii/S0924271620301532) 9 | 10 | 11 | ## Introduction 12 | This repository contains the PyTorch and Keras implementations of DSIFN, together with the datasets used in the paper. 13 | 14 | 该库包含了DSIFN网络的pytorch和keras版本的代码实现以及论文中使用的数据 15 | 16 | ## Model Structure 17 | Overview of the deeply supervised image fusion network (DSIFN). The network has two sub-networks: DFEN, which uses a pre-trained VGG16 backbone for deep feature extraction, and DDN, which uses deep feature fusion modules and deep supervision branches for change map reconstruction. 18 | 19 | 深度监督影像融合网络框架。该网络包含两个子网络:DFEN(深度特征提取网络)以VGG16为网络基底实现深度特征提取;DDN(差异判别网络)由深度特征融合模块和深度监督分支搭建实现影像变化图重建。 20 | 21 | ![1](imgs/1.png) 22 |
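A minimal usage sketch of the Keras implementation (not part of the original repo): it assumes `Keras version/DSIFN.py` and `Keras version/loss.py` are importable as `DSIFN` and `loss`, and the learning rate is only illustrative.

```python
import numpy as np
from keras.optimizers import Adam
from DSIFN import DSIFN          # Keras version/DSIFN.py
from loss import bce_dice_loss   # Keras version/loss.py

# build the two-input, five-output model (downloads the ImageNet VGG16 weights on first use)
model = DSIFN()

# the same loss is applied to every deep-supervision branch output; lr is illustrative
model.compile(optimizer=Adam(lr=1e-4), loss=bce_dice_loss)

# a bi-temporal pair of 512x512 RGB images (random here, real data in practice)
t1 = np.random.rand(1, 512, 512, 3).astype('float32')
t2 = np.random.rand(1, 512, 512, 3).astype('float32')

# predictions are returned from coarse to fine: output_32, ..., output_512
preds = model.predict([t1, t2])
print([p.shape for p in preds])  # [(1, 32, 32, 1), ..., (1, 512, 512, 1)]
```

For training, `model.fit` expects five target masks per sample, one per deep-supervision branch, resized to the corresponding output resolution (32×32 up to 512×512).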
23 | ## Pytorch version requirements 24 | - Python 3.7 25 | - PyTorch 1.6.0 26 | - torchvision 0.7.0 27 | 28 | ## Keras version requirements 29 | - Python 3.6 30 | - Tensorflow-gpu 1.13.1 31 | - Keras 2.2.4 32 | 33 | ## Reference 34 | > [Zhang, C., Yue, P., Tapete, D., Jiang, L., Shangguan, B., Huang, L., & Liu, G. (2020). A deeply supervised image fusion network for change detection in high resolution bi-temporal remote sensing images. ISPRS Journal of Photogrammetry and Remote Sensing, 166, 183-200.](https://www.sciencedirect.com/science/article/abs/pii/S0924271620301532) 35 | 36 | 37 | ## License 38 | Code and datasets are released for non-commercial and research purposes **only**. For commercial purposes, please contact the authors. 39 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | # DSIFN Dataset 2 | DSIFN 数据集 3 | 4 | ## Introduction 5 | 6 | The dataset is available for scientific research. If you find the dataset helpful in your research, please consider citing: 7 | 8 | > [Zhang, C., Yue, P., Tapete, D., Jiang, L., Shangguan, B., Huang, L., & Liu, G. (2020). A deeply supervised image fusion network for change detection in high resolution bi-temporal remote sensing images. ISPRS Journal of Photogrammetry and Remote Sensing, 166, 183-200.](https://www.sciencedirect.com/science/article/abs/pii/S0924271620301532) 9 | 10 | The dataset was manually collected from Google Earth. It consists of six large bi-temporal high-resolution image pairs covering six cities in China (Beijing, Chengdu, Shenzhen, Chongqing, Wuhan, and Xian). The five image pairs for Beijing, Chengdu, Shenzhen, Chongqing, and Wuhan are clipped into 394 sub-image pairs of size 512×512; after data augmentation, a collection of 3940 bi-temporal image pairs is obtained. The Xian image pair is clipped into 48 image pairs for model testing. In total there are 3600 image pairs in the training set, 340 image pairs in the validation set, and 48 image pairs in the test set. 11 | 12 | The links below provide the training, validation, and test datasets, as well as the model weight file and the predicted results on the test dataset.
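Since the dataset is organized as bi-temporal 512×512 tiles with matching change masks, a small PyTorch `Dataset` can pair them for training. This is only a sketch: the folder names `t1`, `t2`, and `mask` are assumptions, so adjust them to the layout of the archive downloaded from the links below.

```python
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class BiTemporalPairs(Dataset):
    """Pairs pre-/post-change 512x512 tiles with their binary change masks."""
    def __init__(self, root):
        self.root = root
        # assumes identical file names across the t1/, t2/ and mask/ folders
        self.names = sorted(os.listdir(os.path.join(root, 't1')))

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        t1 = np.asarray(Image.open(os.path.join(self.root, 't1', name)).convert('RGB'), dtype=np.float32) / 255.0
        t2 = np.asarray(Image.open(os.path.join(self.root, 't2', name)).convert('RGB'), dtype=np.float32) / 255.0
        mask = np.asarray(Image.open(os.path.join(self.root, 'mask', name)).convert('L'), dtype=np.float32) / 255.0
        # HWC -> CHW for the two RGB inputs; add a channel axis to the mask
        return (torch.from_numpy(t1).permute(2, 0, 1),
                torch.from_numpy(t2).permute(2, 0, 1),
                torch.from_numpy(mask).unsqueeze(0))
```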
13 | 14 | ## Dataset link 15 | 16 | -Link链接: https://pan.baidu.com/s/1EsVDb4FqJKLm23gHGvA77Q 17 | 18 | -Password提取码: ru08 19 | 20 | Or: 21 | 22 | -Google drive link: https://drive.google.com/drive/folders/1yutLU4WI7eeeGbuxilsq2OtGQr9EHDLy?usp=sharing 23 | 24 | ## Raw images 25 | ![1](imgs/1.png) 26 | 27 | 28 | ![2](imgs/2.png) 29 | 30 | 31 | ![3](imgs/3.png) 32 | 33 | 34 | ![4](imgs/4.png) 35 | 36 | 37 | ![5](imgs/5.png) 38 | 39 | 40 | ![6](imgs/6.png) 41 | -------------------------------------------------------------------------------- /dataset/imgs/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/dataset/imgs/1.png -------------------------------------------------------------------------------- /dataset/imgs/1.txt: -------------------------------------------------------------------------------- 1 | 1 2 | -------------------------------------------------------------------------------- /dataset/imgs/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/dataset/imgs/2.png -------------------------------------------------------------------------------- /dataset/imgs/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/dataset/imgs/3.png -------------------------------------------------------------------------------- /dataset/imgs/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/dataset/imgs/4.png -------------------------------------------------------------------------------- /dataset/imgs/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/dataset/imgs/5.png -------------------------------------------------------------------------------- /dataset/imgs/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/dataset/imgs/6.png -------------------------------------------------------------------------------- /imgs/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/imgs/1.png -------------------------------------------------------------------------------- /imgs/1.test: -------------------------------------------------------------------------------- 1 | a 2 | -------------------------------------------------------------------------------- /pytorch version/DSIFN.py: 
-------------------------------------------------------------------------------- 1 | 2 | # credits: https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.models import vgg16 7 | import numpy as np 8 | 9 | class vgg16_base(nn.Module): 10 | def __init__(self): 11 | super(vgg16_base,self).__init__() 12 | features = list(vgg16(pretrained=True).features)[:30] 13 | self.features = nn.ModuleList(features).eval() 14 | 15 | def forward(self,x): 16 | results = [] 17 | for ii, model in enumerate(self.features): 18 | x = model(x) 19 | if ii in {3,8,15,22,29}: 20 | results.append(x) 21 | return results 22 | 23 | class ChannelAttention(nn.Module): 24 | def __init__(self, in_channels, ratio = 8): 25 | super(ChannelAttention, self).__init__() 26 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 27 | self.max_pool = nn.AdaptiveMaxPool2d(1) 28 | self.fc1 = nn.Conv2d(in_channels,in_channels//ratio,1,bias=False) 29 | self.relu1 = nn.ReLU() 30 | self.fc2 = nn.Conv2d(in_channels//ratio, in_channels,1,bias=False) 31 | self.sigmod = nn.Sigmoid() 32 | def forward(self,x): 33 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 34 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 35 | out = avg_out + max_out 36 | return self.sigmod(out) 37 | 38 | class SpatialAttention(nn.Module): 39 | def __init__(self): 40 | super(SpatialAttention,self).__init__() 41 | self.conv1 = nn.Conv2d(2,1,7,padding=3,bias=False) 42 | self.sigmoid = nn.Sigmoid() 43 | def forward(self, x): 44 | avg_out = torch.mean(x,dim=1,keepdim=True) 45 | max_out = torch.max(x,dim=1,keepdim=True,out=None)[0] 46 | 47 | x = torch.cat([avg_out,max_out],dim=1) 48 | x = self.conv1(x) 49 | return self.sigmoid(x) 50 | 51 | def conv2d_bn(in_channels, out_channels): 52 | return nn.Sequential( 53 | nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1), 54 | nn.PReLU(), 55 | nn.BatchNorm2d(out_channels), 56 | nn.Dropout(p=0.6), 57 | ) 58 | 59 | class DSIFN(nn.Module): 60 | def __init__(self, model_A, model_B): 61 | super().__init__() 62 | self.t1_base = model_A 63 | self.t2_base = model_B 64 | self.sa1 = SpatialAttention() 65 | self.sa2= SpatialAttention() 66 | self.sa3 = SpatialAttention() 67 | self.sa4 = SpatialAttention() 68 | self.sa5 = SpatialAttention() 69 | 70 | self.sigmoid = nn.Sigmoid() 71 | 72 | # branch1 73 | self.ca1 = ChannelAttention(in_channels=1024) 74 | self.bn_ca1 = nn.BatchNorm2d(1024) 75 | self.o1_conv1 = conv2d_bn(1024, 512) 76 | self.o1_conv2 = conv2d_bn(512, 512) 77 | self.bn_sa1 = nn.BatchNorm2d(512) 78 | self.o1_conv3 = nn.Conv2d(512, 1, 1) 79 | self.trans_conv1 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2) 80 | 81 | # branch 2 82 | self.ca2 = ChannelAttention(in_channels=1536) 83 | self.bn_ca2 = nn.BatchNorm2d(1536) 84 | self.o2_conv1 = conv2d_bn(1536, 512) 85 | self.o2_conv2 = conv2d_bn(512, 256) 86 | self.o2_conv3 = conv2d_bn(256, 256) 87 | self.bn_sa2 = nn.BatchNorm2d(256) 88 | self.o2_conv4 = nn.Conv2d(256, 1, 1) 89 | self.trans_conv2 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2) 90 | 91 | # branch 3 92 | self.ca3 = ChannelAttention(in_channels=768) 93 | self.o3_conv1 = conv2d_bn(768, 256) 94 | self.o3_conv2 = conv2d_bn(256, 128) 95 | self.o3_conv3 = conv2d_bn(128, 128) 96 | self.bn_sa3 = nn.BatchNorm2d(128) 97 | self.o3_conv4 = nn.Conv2d(128, 1, 1) 98 | self.trans_conv3 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2) 99 | 100 | # branch 4 101 | self.ca4 = 
ChannelAttention(in_channels=384) 102 | self.o4_conv1 = conv2d_bn(384, 128) 103 | self.o4_conv2 = conv2d_bn(128, 64) 104 | self.o4_conv3 = conv2d_bn(64, 64) 105 | self.bn_sa4 = nn.BatchNorm2d(64) 106 | self.o4_conv4 = nn.Conv2d(64, 1, 1) 107 | self.trans_conv4 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2) 108 | 109 | # branch 5 110 | self.ca5 = ChannelAttention(in_channels=192) 111 | self.o5_conv1 = conv2d_bn(192, 64) 112 | self.o5_conv2 = conv2d_bn(64, 32) 113 | self.o5_conv3 = conv2d_bn(32, 16) 114 | self.bn_sa5 = nn.BatchNorm2d(16) 115 | self.o5_conv4 = nn.Conv2d(16, 1, 1) 116 | 117 | def forward(self,t1_input,t2_input): 118 | t1_list = self.t1_base(t1_input) 119 | t2_list = self.t2_base(t2_input) 120 | 121 | t1_f_l3,t1_f_l8,t1_f_l15,t1_f_l22,t1_f_l29 = t1_list[0],t1_list[1],t1_list[2],t1_list[3],t1_list[4] 122 | t2_f_l3,t2_f_l8,t2_f_l15,t2_f_l22,t2_f_l29 = t2_list[0],t2_list[1],t2_list[2],t2_list[3],t2_list[4] 123 | 124 | x = torch.cat((t1_f_l29,t2_f_l29),dim=1) 125 | # optionally apply the channel attention module to the first combined feature 126 | #在第一个深度特征叠加层之后可以选择使用或者不使用通道注意力模块 127 | # x = self.ca1(x) * x 128 | x = self.o1_conv1(x) 129 | x = self.o1_conv2(x) 130 | x = self.sa1(x) * x 131 | x = self.bn_sa1(x) 132 | 133 | branch_1_out = self.sigmoid(self.o1_conv3(x)) 134 | 135 | x = self.trans_conv1(x) 136 | x = torch.cat((x,t1_f_l22,t2_f_l22),dim=1) 137 | x = self.ca2(x) * x 138 | # depending on the amount of training data, reduce the number of conv layers as appropriate to prevent overfitting 139 | #根据训练数据的大小,适当减少conv层的使用来防止过拟合 140 | x = self.o2_conv1(x) 141 | x = self.o2_conv2(x) 142 | x = self.o2_conv3(x) 143 | x = self.sa2(x) * x 144 | x = self.bn_sa2(x) 145 | 146 | branch_2_out = self.sigmoid(self.o2_conv4(x)) 147 | 148 | x = self.trans_conv2(x) 149 | x = torch.cat((x,t1_f_l15,t2_f_l15),dim=1) 150 | x = self.ca3(x) * x 151 | x = self.o3_conv1(x) 152 | x = self.o3_conv2(x) 153 | x = self.o3_conv3(x) 154 | x = self.sa3(x) * x 155 | x = self.bn_sa3(x) 156 | 157 | branch_3_out = self.sigmoid(self.o3_conv4(x)) 158 | 159 | x = self.trans_conv3(x) 160 | x = torch.cat((x,t1_f_l8,t2_f_l8),dim=1) 161 | x = self.ca4(x) * x 162 | x = self.o4_conv1(x) 163 | x = self.o4_conv2(x) 164 | x = self.o4_conv3(x) 165 | x = self.sa4(x) * x 166 | x = self.bn_sa4(x) 167 | 168 | branch_4_out = self.sigmoid(self.o4_conv4(x)) 169 | 170 | x = self.trans_conv4(x) 171 | x = torch.cat((x,t1_f_l3,t2_f_l3),dim=1) 172 | x = self.ca5(x) * x 173 | x = self.o5_conv1(x) 174 | x = self.o5_conv2(x) 175 | x = self.o5_conv3(x) 176 | x = self.sa5(x) * x 177 | x = self.bn_sa5(x) 178 | 179 | branch_5_out = self.sigmoid(self.o5_conv4(x)) 180 | 181 | return branch_5_out,branch_4_out,branch_3_out,branch_2_out,branch_1_out 182 | -------------------------------------------------------------------------------- /pytorch version/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def cd_loss(input,target): 4 | # DSIFN's forward() already applies a sigmoid to every branch output, 5 | # so the prediction is passed to BCELoss directly rather than through another sigmoid 6 | bce_loss = nn.BCELoss()(input, target) 7 | smooth = 1. 8 | iflat = input.view(-1) 9 | tflat = target.view(-1) 10 | intersection = (iflat * tflat).sum() 11 | dic_loss = 1 - ((2.
* intersection + smooth)/(iflat.sum() + tflat.sum() + smooth)) 12 | 13 | return dic_loss + bce_loss 14 | -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | # DSIFN predicted results and model weights of the two datasets 2 | DSIFN 预测结果和模型权重文件 3 | 4 | ## Introduction 5 | 6 | The predicted results & model weights files are provided in the below link. Note that the model weights files are in '.h5' format and are available only on the Keras version. 7 | 8 | To load the model weights file: 9 | 10 | -- from keras.models import load_model 11 | 12 | -- model = load_model('best_model.h5', custom_objects={'bce_dice_loss': bce_dice_loss,'f1':f1}) 13 | 14 | -- results = model.predict(input_data) 15 | 16 | 17 | ## Download links 18 | 19 | -Link: https://pan.baidu.com/s/1Ae6natlx2pA3ULfC56uMkg 20 | 21 | -Password: nl1h 22 | 23 | ## Example of predicted images 24 | 25 | ![1](imgs/00000.jpg) 26 | 27 | 28 | ![2](imgs/00001.jpg) 29 | 30 | 31 | ![3](imgs/00002.jpg) 32 | 33 | 34 | ![4](imgs/41.jpg) 35 | 36 | 37 | ![5](imgs/19.jpg) 38 | 39 | 40 | ![6](imgs/4.jpg) 41 | -------------------------------------------------------------------------------- /results/imgs/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/results/imgs/00000.jpg -------------------------------------------------------------------------------- /results/imgs/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/results/imgs/00001.jpg -------------------------------------------------------------------------------- /results/imgs/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/results/imgs/00002.jpg -------------------------------------------------------------------------------- /results/imgs/19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/results/imgs/19.jpg -------------------------------------------------------------------------------- /results/imgs/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/results/imgs/4.jpg -------------------------------------------------------------------------------- /results/imgs/41.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/0ddad933493ab85875d09d6b2c92e2bd37b3b68e/results/imgs/41.jpg --------------------------------------------------------------------------------
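Usage sketch for the released Keras weights, expanding the `load_model` recipe in `results/README.md`. The weight file name comes from that README; the body of the `f1` metric is an assumption (the metric used during training is not included in the repo), and the inputs here are placeholders for a real pre-/post-change image pair.

```python
import numpy as np
import keras.backend as K
from keras.models import load_model
from loss import bce_dice_loss   # Keras version/loss.py

def f1(y_true, y_pred):
    # placeholder F1 metric so the custom-object name in the checkpoint resolves;
    # substitute the exact metric used during training if it differs
    y_pred_bin = K.round(y_pred)
    tp = K.sum(y_true * y_pred_bin)
    precision = tp / (K.sum(y_pred_bin) + K.epsilon())
    recall = tp / (K.sum(y_true) + K.epsilon())
    return 2 * precision * recall / (precision + recall + K.epsilon())

model = load_model('best_model.h5', custom_objects={'bce_dice_loss': bce_dice_loss, 'f1': f1})

t1 = np.random.rand(1, 512, 512, 3).astype('float32')  # replace with a real bi-temporal pair
t2 = np.random.rand(1, 512, 512, 3).astype('float32')
change_map = model.predict([t1, t2])[-1]                # last output is the full-resolution 512x512 branch
binary_map = (change_map[0, :, :, 0] > 0.5).astype(np.uint8)
```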