├── data
│   └── SG
│       ├── GT.png
│       ├── T1.png
│       ├── T2.png
│       └── data_description
├── figures
│   └── SRGCAE.jpg
├── result
│   ├── Adaptive_Fuse.png
│   ├── SRGCAE_VerConc.png
│   ├── Adaptive_Fuse_DI.png
│   ├── SRGCAE_EdgeConc.png
│   ├── Adaptive_Fuse_pps.png
│   ├── SRGCAE_EdgeConc_DI.png
│   └── SRGCAE_VerConc_DI.png
├── model_weight
│   ├── SRGCAE_ER_SG.pth
│   └── SRGCAE_VR_SG.pth
├── model
│   ├── __pycache__
│   │   └── SRGCAE.cpython-39.pyc
│   ├── SRGCAE.py
│   └── GraphConv.py
├── aux_func
│   ├── __pycache__
│   │   ├── acc_ass.cpython-39.pyc
│   │   ├── clustering.cpython-39.pyc
│   │   ├── graph_func.cpython-39.pyc
│   │   └── preprocess.cpython-39.pyc
│   ├── postprocess.py
│   ├── clustering.py
│   ├── graph_func.py
│   ├── acc_ass.py
│   └── preprocess.py
├── adaptive_fuse.py
├── README.md
├── train_SRGCAE_Local.py
└── train_SRGCAE_Nonlocal.py

/data/SG/GT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/data/SG/GT.png
--------------------------------------------------------------------------------
/data/SG/T1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/data/SG/T1.png
--------------------------------------------------------------------------------
/data/SG/T2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/data/SG/T2.png
--------------------------------------------------------------------------------
/figures/SRGCAE.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/figures/SRGCAE.jpg
--------------------------------------------------------------------------------
/result/Adaptive_Fuse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/result/Adaptive_Fuse.png
--------------------------------------------------------------------------------
/result/SRGCAE_VerConc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/result/SRGCAE_VerConc.png
--------------------------------------------------------------------------------
/result/Adaptive_Fuse_DI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/result/Adaptive_Fuse_DI.png
--------------------------------------------------------------------------------
/result/SRGCAE_EdgeConc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/result/SRGCAE_EdgeConc.png
--------------------------------------------------------------------------------
/model_weight/SRGCAE_ER_SG.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/model_weight/SRGCAE_ER_SG.pth
--------------------------------------------------------------------------------
/model_weight/SRGCAE_VR_SG.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/model_weight/SRGCAE_VR_SG.pth
--------------------------------------------------------------------------------
/result/Adaptive_Fuse_pps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/result/Adaptive_Fuse_pps.png
--------------------------------------------------------------------------------
/result/SRGCAE_EdgeConc_DI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/result/SRGCAE_EdgeConc_DI.png
--------------------------------------------------------------------------------
/result/SRGCAE_VerConc_DI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/result/SRGCAE_VerConc_DI.png
--------------------------------------------------------------------------------
/model/__pycache__/SRGCAE.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/model/__pycache__/SRGCAE.cpython-39.pyc
--------------------------------------------------------------------------------
/aux_func/__pycache__/acc_ass.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/aux_func/__pycache__/acc_ass.cpython-39.pyc
--------------------------------------------------------------------------------
/aux_func/__pycache__/clustering.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/aux_func/__pycache__/clustering.cpython-39.pyc
--------------------------------------------------------------------------------
/aux_func/__pycache__/graph_func.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/aux_func/__pycache__/graph_func.cpython-39.pyc
--------------------------------------------------------------------------------
/aux_func/__pycache__/preprocess.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SRGCAE/HEAD/aux_func/__pycache__/preprocess.cpython-39.pyc
--------------------------------------------------------------------------------
/data/SG/data_description:
--------------------------------------------------------------------------------
This multimodal dataset is from Prof. Maoguo Gong's group.
If you use this dataset, we suggest that you kindly consider citing their papers, such as [1] and [2].
[1] J. Liu, M. Gong, K. Qin, and P. Zhang, “A deep convolutional coupling network for change detection based on heterogeneous optical and radar images,” IEEE Trans. Neural Netw. Learn. Syst., vol. 29, no. 3, pp. 545–559, Mar. 2018.
[2] P. Zhang, M. Gong, L. Su, J. Liu, and Z. Li, “Change detection based on deep feature representation and mapping transformation for multi-spatial-resolution remote sensing images,” ISPRS J. Photogramm. Remote Sens., vol. 116, pp. 24–41, Jun. 2016.
--------------------------------------------------------------------------------
/aux_func/postprocess.py:
--------------------------------------------------------------------------------
import cv2 as cv
import imageio

from acc_ass import assess_accuracy


def postprocess(data):
    # elliptical structuring elements: a large kernel for closing holes in the
    # change map, a smaller one for opening away isolated noise
    kernel_1 = cv.getStructuringElement(cv.MORPH_ELLIPSE, (45, 45))
    kernel_2 = cv.getStructuringElement(cv.MORPH_ELLIPSE, (15, 15))

    # pass iterations explicitly; a bare positional argument here would be
    # interpreted as the dst parameter of cv.morphologyEx
    bcm_2 = cv.morphologyEx(data, cv.MORPH_CLOSE, kernel_1, iterations=1)
    bcm_2 = cv.morphologyEx(bcm_2, cv.MORPH_OPEN, kernel_2, iterations=1)

    imageio.imsave('../result/Adaptive_Fuse_pps.png', bcm_2)

    ground_truth_changed = imageio.imread('../data/SG/GT.png')
    ground_truth_unchanged = 255 - ground_truth_changed
    conf_mat, oa, f1, kappa_co = assess_accuracy(ground_truth_changed, ground_truth_unchanged, bcm_2)
    print(conf_mat)
    print(oa)
    print(f1)
    print(kappa_co)


if __name__ == '__main__':
    cm_path = '../result/Adaptive_Fuse.png'
    cm = imageio.imread(cm_path)
    postprocess(cm)

--------------------------------------------------------------------------------
/aux_func/clustering.py:
--------------------------------------------------------------------------------
import numpy as np


def otsu(data, num=1000):
    """Exhaustive Otsu thresholding: scan `num` candidate thresholds and keep
    the one that maximizes the inter-class variance."""
    max_value = np.max(data)
    min_value = np.min(data)

    total_num = data.shape[0]
    step_value = (max_value - min_value) / num
    value = min_value + step_value
    best_threshold = min_value
    best_inter_class_var = 0
    while value <= max_value:
        data_1 = data[data < value]
        data_2 = data[data >= value]
        if data_1.shape[0] == 0 or data_2.shape[0] == 0:
            # skip degenerate splits where one class is empty
            value += step_value
            continue
        w1 = data_1.shape[0] / total_num
        w2 = data_2.shape[0] / total_num

        mean_1 = data_1.mean()
        mean_2 = data_2.mean()

        inter_class_var = w1 * w2 * np.power((mean_1 - mean_2), 2)
        if best_inter_class_var < inter_class_var:
            best_inter_class_var = inter_class_var
            best_threshold = value
        value += step_value
    # bcm = np.zeros(data.shape).astype(np.uint8)
    # bcm[data <= best_threshold] = 0
    # bcm[data > best_threshold] = 255
    return best_threshold

--------------------------------------------------------------------------------
/aux_func/graph_func.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.metrics.pairwise import pairwise_distances


def gaussian_kernel_distance(vector, band_width):
    euc_dis = pairwise_distances(vector)
    gaus_dis = np.exp(- euc_dis * euc_dis / (band_width * band_width))
    return gaus_dis


def normalize_adj(adj):
    """Symmetrically normalize an adjacency matrix."""
    rowsum = np.array(adj.sum(1))  # degree vector D
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()  # D^-0.5
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = np.diag(d_inv_sqrt)  # D^-0.5 as a diagonal matrix
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt)  # D^-0.5 A D^-0.5


def construct_affinity_matrix(data, objects, band_width):
    am_set = []
    obj_nums = np.max(objects)
    for i in range(0, obj_nums + 1):
        sub_object = data[objects == i]
        adj_mat = gaussian_kernel_distance(sub_object, band_width=band_width)
        norm_adj_mat = normalize_adj(adj_mat)
        am_set.append([adj_mat, norm_adj_mat])
    return am_set
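

# A quick toy check (hypothetical 2-node graph, for illustration only): for a
# symmetric affinity matrix A, normalize_adj should return D^-0.5 A D^-0.5.
if __name__ == '__main__':
    toy_adj = np.array([[1.0, 0.5], [0.5, 1.0]])
    print(normalize_adj(toy_adj))  # expected: toy_adj / 1.5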
--------------------------------------------------------------------------------
/model/SRGCAE.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

# absolute import so that "from model.SRGCAE import ..." works from the repo root
from model.GraphConv import GraphConvolution


class GraphConvAutoEncoder_VertexRecon(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GraphConvAutoEncoder_VertexRecon, self).__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, 2 * nhid)
        self.dropout = nn.Dropout(p=dropout)
        self.gc3 = GraphConvolution(2 * nhid, nclass)

    def forward(self, x, adj):
        x = torch.sigmoid(self.gc1(x, adj))
        x = torch.sigmoid(self.gc2(x, adj))
        feat = x
        x = self.dropout(x)
        x = self.gc3(x, adj)
        return x, feat


class GraphConvAutoEncoder_EdgeRecon(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GraphConvAutoEncoder_EdgeRecon, self).__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, 2 * nhid)
        # nclass and dropout are kept for interface symmetry with the vertex
        # variant; the edge decoder is the inner product of the embeddings,
        # so the layers below are unused
        # self.dropout = nn.Dropout(p=dropout)
        # self.gc3 = GraphConvolution(2 * nhid, nclass)

    def forward(self, x, adj):
        x = torch.sigmoid(self.gc1(x, adj))
        x = torch.sigmoid(self.gc2(x, adj))
        # x = self.dropout(x)
        # x = self.gc3(x, adj)
        return x

--------------------------------------------------------------------------------
/adaptive_fuse.py:
--------------------------------------------------------------------------------
import imageio
import numpy as np

from aux_func.acc_ass import assess_accuracy
from aux_func.clustering import otsu


def fuse_DI():
    LocalGCN_DI = imageio.imread('./result/SRGCAE_VerConc_DI.png').astype(np.float32)
    NLocalGCN_DI = imageio.imread('./result/SRGCAE_EdgeConc_DI.png').astype(np.float32)
    height, width = LocalGCN_DI.shape
    # weight each difference image by its global variance
    alpha = np.var(LocalGCN_DI.reshape(-1))
    beta = np.var(NLocalGCN_DI.reshape(-1))
    fused_DI = (alpha * LocalGCN_DI + beta * NLocalGCN_DI) / (alpha + beta)
    fused_DI = np.reshape(fused_DI, (height * width, 1))
    threshold = otsu(fused_DI)
    fused_DI = np.reshape(fused_DI, (height, width))
    bcm = np.zeros((height, width)).astype(np.uint8)
    bcm[fused_DI > threshold] = 255
    bcm[fused_DI <= threshold] = 0
    imageio.imsave('./result/Adaptive_Fuse.png', bcm)

    fused_DI = 255 * ((fused_DI - np.min(fused_DI)) / (np.max(fused_DI) - np.min(fused_DI)))
    imageio.imsave('./result/Adaptive_Fuse_DI.png', fused_DI.astype(np.uint8))

    ground_truth_changed = imageio.imread('./data/SG/GT.png')
    ground_truth_unchanged = 255 - ground_truth_changed
    conf_mat, oa, f1, kappa_co = assess_accuracy(ground_truth_changed, ground_truth_unchanged, bcm)
    print(conf_mat)
    print(oa)
    print(f1)
    print(kappa_co)


if __name__ == '__main__':
    fuse_DI()
--------------------------------------------------------------------------------
/model/GraphConv.py:
--------------------------------------------------------------------------------
import math

import torch

from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module


class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x, adj):
        support = torch.mm(x, self.weight)
        output = torch.mm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

--------------------------------------------------------------------------------
/aux_func/acc_ass.py:
--------------------------------------------------------------------------------
import numpy as np


def assess_accuracy(gt_changed, gt_unchanged, changed_map):
    """
    Assess the accuracy of a binary change map against the ground truth.
    :param gt_changed: changed ground truth (255 marks changed pixels)
    :param gt_unchanged: unchanged ground truth (255 marks unchanged pixels)
    :param changed_map: binary change map (255 changed, 0 unchanged)
    :return: confusion matrix, overall accuracy, F1 score, and kappa coefficient
    """
    change_index = (gt_changed == 255)
    unchanged_index = (gt_unchanged == 255)
    n_cc = (changed_map[change_index] == 255).sum()  # changed-changed
    n_uu = (changed_map[unchanged_index] == 0).sum()  # unchanged-unchanged
    n_cu = change_index.sum() - n_cc  # changed-unchanged
    n_uc = unchanged_index.sum() - n_uu  # unchanged-changed

    conf_mat = np.array([[n_cc, n_cu], [n_uc, n_uu]])

    pre = n_cc / (n_cc + n_uc)
    rec = n_cc / (n_cc + n_cu)
    f1 = 2 * pre * rec / (pre + rec)

    over_acc = (conf_mat.diagonal().sum()) / (conf_mat.sum())
    pe = ((n_cc + n_cu) / conf_mat.sum() * (n_cc + n_uc) + (n_uu + n_uc) / conf_mat.sum() * (
            n_uu + n_cu)) / conf_mat.sum()
    kappa_co = (over_acc - pe) / (1 - pe)
    return conf_mat, over_acc, f1, kappa_co


def assess_accuracy_from_conf_mat(conf_mat):
    """
    Assess accuracy directly from a precomputed confusion matrix.
    :param conf_mat: 2x2 confusion matrix [[n_cc, n_cu], [n_uc, n_uu]]
    :return: confusion matrix, overall accuracy, F1 score, and kappa coefficient
    """
    n_cc = conf_mat[0, 0]
    n_cu = conf_mat[0, 1]
    n_uc = conf_mat[1, 0]
    n_uu = conf_mat[1, 1]

    pre = n_cc / (n_cc + n_uc)
    rec = n_cc / (n_cc + n_cu)
    f1 = 2 * pre * rec / (pre + rec)

    over_acc = (conf_mat.diagonal().sum()) / (conf_mat.sum())
    pe = ((n_cc + n_cu) / conf_mat.sum() * (n_cc + n_uc) + (n_uu + n_uc) / conf_mat.sum() * (
            n_uu + n_cu)) / conf_mat.sum()
    kappa_co = (over_acc - pe) / (1 - pe)
    return conf_mat, over_acc, f1, kappa_co
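

# A minimal sanity check on a toy confusion matrix (hypothetical counts,
# not results from any dataset in this repo):
if __name__ == '__main__':
    toy_conf_mat = np.array([[90, 10], [20, 880]])
    _, oa, f1, kappa = assess_accuracy_from_conf_mat(toy_conf_mat)
    print(oa, f1, kappa)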

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
<h1 align="center">Unsupervised Multimodal Change Detection Based on Structural Relationship Graph Representation Learning</h1>

<p align="center">Hongruixuan Chen, Naoto Yokoya, Chen Wu, and Bo Du</p>

This is the official implementation of the unsupervised multimodal change detection framework **SR-GCAE** from our IEEE TGRS 2022 paper: [Unsupervised Multimodal Change Detection Based on Structural Relationship Graph Representation Learning](https://ieeexplore.ieee.org/document/9984688).


## Abstract
> Unsupervised multimodal change detection is a practical and challenging topic that can play an important role in time-sensitive emergency applications. To address the challenge that multimodal remote sensing images cannot be directly compared due to their modal heterogeneity, we take advantage of two types of modality-independent structural relationships in multimodal images. In particular, we present a structural relationship graph representation learning framework for measuring the similarity of the two structural relationships. Firstly, structural graphs are generated from preprocessed multimodal image pairs by means of an object-based image analysis approach. Then, a structural relationship graph convolutional autoencoder (SR-GCAE) is proposed to learn robust and representative features from graphs. Two loss functions aiming at reconstructing vertex information and edge information are presented to make the learned representations applicable for structural relationship similarity measurement. Subsequently, the similarity levels of two structural relationships are calculated from learned graph representations and two difference images are generated based on the similarity levels. After obtaining the difference images, an adaptive fusion strategy is presented to fuse the two difference images. Finally, a morphological filtering-based postprocessing approach is employed to refine the detection results. Experimental results on six datasets with different modal combinations demonstrate the effectiveness of the proposed method.

## Network architecture
<img src="./figures/SRGCAE.jpg">

## Get started
### Requirements

```
python==3.9.7
pytorch==1.9.0
scikit-learn==0.18.3
imageio==2.9.0
numpy==1.20.3
gdal==3.0.2
opencv==4.5.5
```
### Dataset
This repo contains the Shuguang dataset. The homogeneous Hanyang dataset has also been open-sourced; you can download it [here](http://sigma.whu.edu.cn/resource.php). The Texas dataset can be downloaded from Prof. Michele Volpi's [webpage](https://sites.google.com/site/michelevolpiresearch/home).

### Usage
Perform edge information reconstruction and detect land-cover changes using the local structural relationship:
```
python train_SRGCAE_Local.py
```
Perform vertex information reconstruction and detect land-cover changes using the nonlocal structural relationship:
```
python train_SRGCAE_Nonlocal.py
```

Adaptively fuse the difference images:
```
python adaptive_fuse.py
```

Postprocess based on morphological filtering:
```
python aux_func/postprocess.py
```
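
The fusion step weights the two difference images (DIs) by their global variances before Otsu thresholding. Below is a minimal sketch of the rule implemented in `adaptive_fuse.py`, assuming both DIs are float arrays of the same shape (the function name `fuse_difference_images` is ours, for illustration only):
```
import numpy as np

def fuse_difference_images(di_local, di_nonlocal):
    # weight each DI by its global variance, so the DI with stronger
    # contrast contributes more to the fused result
    alpha, beta = np.var(di_local), np.var(di_nonlocal)
    return (alpha * di_local + beta * di_nonlocal) / (alpha + beta)
```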
## Citation
If this code or dataset contributes to your research, please consider citing our paper. We appreciate your support!🙂
```
@article{chen2022unsupervised,
  author={Chen, Hongruixuan and Yokoya, Naoto and Wu, Chen and Du, Bo},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  title={Unsupervised Multimodal Change Detection Based on Structural Relationship Graph Representation Learning},
  year={2022},
  volume={60},
  number={},
  pages={1-18},
  doi={10.1109/TGRS.2022.3229027}
}
```

## Acknowledgement
The Python code draws in part on the Matlab code of [NPSG](https://github.com/yulisun/NPSG) and [IRGMcS](https://github.com/yulisun/IRG-McS). Many thanks for these brilliant works!

## Q & A
**For any questions, please [contact us.](mailto:Qschrx@gmail.com)**

--------------------------------------------------------------------------------
/aux_func/preprocess.py:
--------------------------------------------------------------------------------
import numpy as np


def preprocess_img(data, d_type, norm_type):
    pps_data = np.array(data).astype(np.float32)
    if d_type == 'opt':
        if norm_type == 'stad':
            pps_data = stad_img(pps_data, channel_first=False)
        else:
            pps_data = norm_img(pps_data, channel_first=False)
    elif d_type == 'sar':
        # replace non-positive values with the smallest positive value
        # before the log transform
        pps_data[np.abs(pps_data) <= 0] = np.min(pps_data[np.abs(pps_data) > 0])
        pps_data = np.log(pps_data + 1.0)
        if norm_type == 'stad':
            pps_data = stad_img(pps_data, channel_first=False)
            # sigma = np.std(pps_data)
            # mean = np.mean(pps_data)
            # idx_min = pps_data < (mean - 4 * sigma)
            # idx_max = pps_data > (mean + 4 * sigma)
            # pps_data[idx_min] = np.min(pps_data[~idx_min])
            # pps_data[idx_max] = np.max(pps_data[~idx_max])
        else:
            pps_data = norm_img(pps_data, channel_first=False)
            # sigma = np.std(pps_data)
            # mean = np.mean(pps_data)
            # idx_min = pps_data < (mean - 4 * sigma)
            # idx_max = pps_data > (mean + 4 * sigma)
            # pps_data[idx_min] = np.min(pps_data[~idx_min])
            # pps_data[idx_max] = np.max(pps_data[~idx_max])
    return pps_data


def norm_img(img, channel_first):
    '''
    normalize values to [0, 1]
    '''
    if channel_first:
        channel, img_height, img_width = img.shape
        img = np.reshape(img, (channel, img_height * img_width))  # (channel, height * width)
        max_value = np.max(img, axis=1, keepdims=True)  # (channel, 1)
        min_value = np.min(img, axis=1, keepdims=True)  # (channel, 1)
        diff_value = max_value - min_value
        nm_img = (img - min_value) / diff_value
        nm_img = np.reshape(nm_img, (channel, img_height, img_width))
    else:
        img_height, img_width, channel = img.shape
        img = np.reshape(img, (img_height * img_width, channel))  # (height * width, channel)
        max_value = np.max(img, axis=0, keepdims=True)  # (1, channel)
        min_value = np.min(img, axis=0, keepdims=True)  # (1, channel)
        diff_value = max_value - min_value
        nm_img = (img - min_value) / diff_value
        nm_img = np.reshape(nm_img, (img_height, img_width, channel))
    return nm_img


def norm_img_2(img, channel_first):
    '''
    normalize values to [-1, 1]
    '''
    if channel_first:
        channel, img_height, img_width = img.shape
        img = np.reshape(img, (channel, img_height * img_width))  # (channel, height * width)
        max_value = np.max(img, axis=1, keepdims=True)  # (channel, 1)
        min_value = np.min(img, axis=1, keepdims=True)  # (channel, 1)
        diff_value = max_value - min_value
        nm_img = 2 * ((img - min_value) / diff_value - 0.5)
        nm_img = np.reshape(nm_img, (channel, img_height, img_width))
    else:
        img_height, img_width, channel = img.shape
        img = np.reshape(img, (img_height * img_width, channel))  # (height * width, channel)
        max_value = np.max(img, axis=0, keepdims=True)  # (1, channel)
        min_value = np.min(img, axis=0, keepdims=True)  # (1, channel)
        diff_value = max_value - min_value
        nm_img = 2 * ((img - min_value) / diff_value - 0.5)
        nm_img = np.reshape(nm_img, (img_height, img_width, channel))
    return nm_img


def stad_img(img, channel_first):
    """
    standardize each channel to zero mean and unit variance
    :param channel_first: True for (C, H, W) input, False for (H, W, C)
    :param img: image array
    :return: standardized image with the same shape as the input
    """
    if channel_first:
        channel, img_height, img_width = img.shape
        img = np.reshape(img, (channel, img_height * img_width))  # (channel, height * width)
        mean = np.mean(img, axis=1, keepdims=True)  # (channel, 1)
        center = img - mean  # (channel, height * width)
        var = np.sum(np.power(center, 2), axis=1, keepdims=True) / (img_height * img_width)  # (channel, 1)
        std = np.sqrt(var)  # (channel, 1)
        nm_img = center / std  # (channel, height * width)
        nm_img = np.reshape(nm_img, (channel, img_height, img_width))
    else:
        img_height, img_width, channel = img.shape
        img = np.reshape(img, (img_height * img_width, channel))  # (height * width, channel)
        mean = np.mean(img, axis=0, keepdims=True)  # (1, channel)
        center = img - mean  # (height * width, channel)
        var = np.sum(np.power(center, 2), axis=0, keepdims=True) / (img_height * img_width)  # (1, channel)
        std = np.sqrt(var)  # (1, channel)
        nm_img = center / std  # (height * width, channel)
        nm_img = np.reshape(nm_img, (img_height, img_width, channel))
    return nm_img

--------------------------------------------------------------------------------
/train_SRGCAE_Local.py:
--------------------------------------------------------------------------------
import argparse
import time

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from skimage.segmentation import slic

from aux_func.acc_ass import assess_accuracy
from aux_func.clustering import otsu
from aux_func.graph_func import construct_affinity_matrix
from aux_func.preprocess import preprocess_img
from model.SRGCAE import GraphConvAutoEncoder_EdgeRecon


def load_checkpoint_for_evaluation(model, checkpoint):
    saved_state_dict = torch.load(checkpoint, map_location='cuda:0')
    model.load_state_dict(saved_state_dict)
    model.cuda()
    model.eval()


def train_model(args):
    img_t1 = imageio.imread('./data/SG/T1.png')  # .astype(np.float32)
    img_t2 = imageio.imread('./data/SG/T2.png')  # .astype(np.float32)
    ground_truth_changed = imageio.imread('./data/SG/GT.png')
    ground_truth_unchanged = 255 - ground_truth_changed

    height, width, channel_t1 = img_t1.shape
    _, _, channel_t2 = img_t2.shape

    # In our paper, the object map is obtained with the FNEA algorithm in eCognition.
    # As noted in our response to the reviewers, SLIC is a Python-based alternative,
    # and the accuracy of the two segmentation methods differs very little.
    # Note that this SLIC call assumes a three-band image, so images with more than
    # three bands need extra handling (a hedged sketch follows).
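    # For images with more than three bands, a possible alternative (an assumption
    # to verify against your installed scikit-image version; channel_axis replaced
    # multichannel in scikit-image 0.19):
    # objects = slic(img_t2, n_segments=args.n_seg, compactness=args.cmp, channel_axis=-1)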
    objects = slic(img_t2, n_segments=args.n_seg, compactness=args.cmp)

    img_t1 = preprocess_img(img_t1, d_type='sar', norm_type='stad')
    img_t2 = preprocess_img(img_t2, d_type='opt', norm_type='stad')
    # objects = np.load('./object_idx.npy')
    obj_nums = np.max(objects) + 1

    node_set_t1 = []
    node_set_t2 = []
    for idx in range(obj_nums):
        obj_idx = objects == idx
        node_set_t1.append(img_t1[obj_idx])
        node_set_t2.append(img_t2[obj_idx])
    am_set_t1 = construct_affinity_matrix(img_t1, objects, args.band_width_t1)
    am_set_t2 = construct_affinity_matrix(img_t2, objects, args.band_width_t2)
    GCAE_model = GraphConvAutoEncoder_EdgeRecon(nfeat=3, nhid=16, nclass=3, dropout=0.5)
    optimizer = optim.AdamW(GCAE_model.parameters(), lr=1e-4, weight_decay=1e-6)
    GCAE_model.cuda()
    GCAE_model.train()

    # Edge information reconstruction
    for _epoch in range(args.epoch):
        for _iter in range(obj_nums):
            optimizer.zero_grad()
            node_t1 = node_set_t1[_iter]
            adj_t1, norm_adj_t1 = am_set_t1[_iter]
            node_t1 = torch.from_numpy(node_t1).cuda().float()
            adj_t1 = torch.from_numpy(adj_t1).cuda().float()
            norm_adj_t1 = torch.from_numpy(norm_adj_t1).cuda().float()

            node_t2 = node_set_t2[_iter]
            adj_t2, norm_adj_t2 = am_set_t2[_iter]
            node_t2 = torch.from_numpy(node_t2).cuda().float()
            adj_t2 = torch.from_numpy(adj_t2).cuda().float()
            norm_adj_t2 = torch.from_numpy(norm_adj_t2).cuda().float()

            feat_t1 = GCAE_model(node_t1, norm_adj_t1)
            feat_t2 = GCAE_model(node_t2, norm_adj_t2)

            # decode edges as the inner product of the learned embeddings
            recon_adj_t1 = torch.matmul(feat_t1, feat_t1.T)
            recon_adj_t2 = torch.matmul(feat_t2, feat_t2.T)

            cnstr_loss_t1 = F.mse_loss(input=recon_adj_t1, target=adj_t1) / adj_t1.size()[0]
            cnstr_loss_t2 = F.mse_loss(input=recon_adj_t2, target=adj_t2) / adj_t2.size()[0]
            ttl_loss = cnstr_loss_t2 + cnstr_loss_t1
            ttl_loss.backward()
            optimizer.step()
            if (_iter + 1) % 10 == 0:
                print(f'Epoch is {_epoch + 1}, iter is {_iter + 1}, mse loss is {ttl_loss.item()}')
    torch.save(GCAE_model.state_dict(), './model_weight/' + str(time.time()) + '.pth')

    # Extracting deep edge representations & change information mapping
    # To load the released pretrained weight instead of training from scratch:
    # restore_from = './model_weight/SRGCAE_ER_SG.pth'
    # load_checkpoint_for_evaluation(GCAE_model, restore_from)
    GCAE_model.eval()
    diff_set = []

    for _iter in range(obj_nums):
        node_t1 = node_set_t1[_iter]
        node_t2 = node_set_t2[_iter]
        adj_t1, norm_adj_t1 = am_set_t1[_iter]
        adj_t2, norm_adj_t2 = am_set_t2[_iter]

        node_t1 = torch.from_numpy(node_t1).cuda().float()
        node_t2 = torch.from_numpy(node_t2).cuda().float()
        norm_adj_t1 = torch.from_numpy(norm_adj_t1).cuda().float()
        norm_adj_t2 = torch.from_numpy(norm_adj_t2).cuda().float()

        feat_t1 = GCAE_model(node_t1, norm_adj_t1)
        feat_t2 = GCAE_model(node_t2, norm_adj_t2)

        diff = torch.mean(torch.abs(feat_t1 - feat_t2))
        # diff = torch.sqrt(torch.sum(torch.square(feat_t1 - feat_t2), dim=1)) / norm_adj_t2.size()[0]
        diff_set.append(diff.data.cpu().numpy())

    diff_map = np.zeros((height, width))
    for i in range(0, obj_nums):
        diff_map[objects == i] = diff_set[i]

    diff_map = np.reshape(diff_map, (height * width, 1))

    threshold = otsu(diff_map)
    diff_map = np.reshape(diff_map, (height, width))

    bcm = np.zeros((height, width)).astype(np.uint8)
    bcm[diff_map > threshold] = 255
    bcm[diff_map <= threshold] = 0

    conf_mat, oa, f1, kappa_co = assess_accuracy(ground_truth_changed, ground_truth_unchanged, bcm)

    imageio.imsave('./result/SRGCAE_EdgeConc_' + str(time.time()) + '.png', bcm)
    diff_map = 255 * (diff_map - np.min(diff_map)) / (np.max(diff_map) - np.min(diff_map))
    imageio.imsave('./result/SRGCAE_EdgeConc_' + str(time.time()) + '_DI.png', diff_map.astype(np.uint8))

    print(conf_mat)
    print(oa)
    print(f1)
    print(kappa_co)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Detecting land-cover changes on the SG dataset")
    parser.add_argument('--n_seg', type=int, default=1500,
                        help='Approximate number of objects obtained by the segmentation algorithm')
    parser.add_argument('--cmp', type=int, default=5, help='Compactness of the obtained objects')
    parser.add_argument('--band_width_t1', type=float, default=0.4,
                        help='The bandwidth of the Gaussian kernel when calculating the adjacency matrix')
    parser.add_argument('--band_width_t2', type=float, default=0.5,
                        help='The bandwidth of the Gaussian kernel when calculating the adjacency matrix')
    parser.add_argument('--epoch', type=int, default=10, help='Training epoch of SRGCAE')
    args = parser.parse_args()

    train_model(args)
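
# Example invocation with the SG defaults defined above (for reference only):
# python train_SRGCAE_Local.py --n_seg 1500 --cmp 5 --band_width_t1 0.4 --band_width_t2 0.5 --epoch 10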
--------------------------------------------------------------------------------
/train_SRGCAE_Nonlocal.py:
--------------------------------------------------------------------------------
import argparse
import time

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from skimage.segmentation import slic
from sklearn.metrics.pairwise import pairwise_distances

from aux_func.acc_ass import assess_accuracy
from aux_func.clustering import otsu
from aux_func.graph_func import construct_affinity_matrix
from aux_func.preprocess import preprocess_img
from model.SRGCAE import GraphConvAutoEncoder_VertexRecon


def load_checkpoint_for_evaluation(model, checkpoint):
    saved_state_dict = torch.load(checkpoint, map_location='cuda:0')
    model.load_state_dict(saved_state_dict)
    model.cuda()
    model.eval()


def cal_nonlocal_dist(vector, band_width):
    # Gaussian kernel similarity between object-level features (same form as
    # gaussian_kernel_distance in aux_func/graph_func.py)
    euc_dis = pairwise_distances(vector)
    gaus_dis = np.exp(- euc_dis * euc_dis / (band_width * band_width))
    return gaus_dis


def train_model(args):
    img_t1 = imageio.imread('./data/SG/T1.png')  # .astype(np.float32)
    img_t2 = imageio.imread('./data/SG/T2.png')  # .astype(np.float32)
    ground_truth_changed = imageio.imread('./data/SG/GT.png')
    ground_truth_unchanged = 255 - ground_truth_changed

    height, width, channel_t1 = img_t1.shape
    _, _, channel_t2 = img_t2.shape

    # In our paper, the object map is obtained with the FNEA algorithm in eCognition.
    # As noted in our response to the reviewers, SLIC is a Python-based alternative,
    # and the accuracy of the two segmentation methods differs very little.
    # Note that this SLIC call assumes a three-band image; see the multiband sketch
    # in train_SRGCAE_Local.py for images with more than three bands.
    objects = slic(img_t2, n_segments=args.n_seg, compactness=args.cmp)

    img_t1 = preprocess_img(img_t1, d_type='sar', norm_type='norm')
    img_t2 = preprocess_img(img_t2, d_type='opt', norm_type='norm')

    obj_nums = np.max(objects) + 1

    node_set_t1 = []
    node_set_t2 = []
    for obj_idx in range(obj_nums):
        node_set_t1.append(img_t1[objects == obj_idx])
        node_set_t2.append(img_t2[objects == obj_idx])
    am_set_t1 = construct_affinity_matrix(img_t1, objects, args.band_width_t1)
    am_set_t2 = construct_affinity_matrix(img_t2, objects, args.band_width_t2)

    GCAE_model = GraphConvAutoEncoder_VertexRecon(nfeat=3, nhid=16, nclass=3, dropout=0.5)
    optimizer = optim.Adam(GCAE_model.parameters(), lr=1e-4, weight_decay=1e-4)
    GCAE_model.cuda()
    GCAE_model.train()

    # Vertex information reconstruction
    for _epoch in range(args.epoch):
        for _iter in range(obj_nums):
            optimizer.zero_grad()
            node_t1 = node_set_t1[_iter]
            node_t2 = node_set_t2[_iter]
            _, norm_adj_t1 = am_set_t1[_iter]
            _, norm_adj_t2 = am_set_t2[_iter]

            node_t1 = torch.from_numpy(node_t1).cuda().float()
            node_t2 = torch.from_numpy(node_t2).cuda().float()

            norm_adj_t1 = torch.from_numpy(norm_adj_t1).cuda().float()
            norm_adj_t2 = torch.from_numpy(norm_adj_t2).cuda().float()

            cstr_node_t1, feat_t1 = GCAE_model(node_t1, norm_adj_t1)
            cstr_node_t2, feat_t2 = GCAE_model(node_t2, norm_adj_t2)

            cnstr_loss_t1 = F.mse_loss(input=cstr_node_t1, target=node_t1)
            cnstr_loss_t2 = F.mse_loss(input=cstr_node_t2, target=node_t2)
            ttl_loss = cnstr_loss_t2 + cnstr_loss_t1
            ttl_loss.backward()
            optimizer.step()
            if (_iter + 1) % 10 == 0:
                print(f'Epoch is {_epoch + 1}, iter is {_iter + 1}, mse loss is {ttl_loss.item()}')
    # torch.save(GCAE_model.state_dict(), './model_weight/' + str(time.time()) + '.pth')

    # Extracting deep vertex representations
    # To load the released pretrained weight instead of training from scratch:
    # restore_from = './model_weight/SRGCAE_VR_SG.pth'
    # load_checkpoint_for_evaluation(GCAE_model, restore_from)
    GCAE_model.eval()
    feat_set_t1 = []
    feat_set_t2 = []

    for _iter in range(obj_nums):
        node_t1 = node_set_t1[_iter]
        node_t2 = node_set_t2[_iter]
        _, norm_adj_t1 = am_set_t1[_iter]
        _, norm_adj_t2 = am_set_t2[_iter]

        node_t1 = torch.from_numpy(node_t1).cuda().float()
        node_t2 = torch.from_numpy(node_t2).cuda().float()
        norm_adj_t1 = torch.from_numpy(norm_adj_t1).cuda().float()
        norm_adj_t2 = torch.from_numpy(norm_adj_t2).cuda().float()

        _, feat_t1 = GCAE_model(node_t1, norm_adj_t1)
        _, feat_t2 = GCAE_model(node_t2, norm_adj_t2)

        # aggregate pixel-level features into one object-level feature
        feat_t1 = torch.mean(feat_t1, dim=0)
        feat_t2 = torch.mean(feat_t2, dim=0)
        feat_set_t1.append(feat_t1.data.cpu().numpy())
        feat_set_t2.append(feat_t2.data.cpu().numpy())

    feat_set_t1 = np.array(feat_set_t1)
    feat_set_t2 = np.array(feat_set_t2)

    dist_set_t1 = cal_nonlocal_dist(feat_set_t1, args.deep_band_width_t1)
    dist_set_t2 = cal_nonlocal_dist(feat_set_t2, args.deep_band_width_t2)

    # indices of the most similar objects, in descending order of similarity
    neigh_idx_t1 = np.argsort(-dist_set_t1, axis=1)
    neigh_idx_t2 = np.argsort(-dist_set_t2, axis=1)

    fx_node_dist = np.zeros((obj_nums, 1))
    fy_node_dist = np.zeros((obj_nums, 1))

    # Change information mapping
    for i in range(obj_nums):
        fx_node_dist[i] = np.mean(
            np.abs(dist_set_t1[i, neigh_idx_t1[i, 1:args.knn_num]] - dist_set_t1[i, neigh_idx_t2[i, 1:args.knn_num]]))
        fy_node_dist[i] = np.mean(
            np.abs(dist_set_t2[i, neigh_idx_t2[i, 1:args.knn_num]] - dist_set_t2[i, neigh_idx_t1[i, 1:args.knn_num]]))
    diff_map = np.zeros((height, width))

    for i in range(0, obj_nums):
        diff_map[objects == i] = fx_node_dist[i] + fy_node_dist[i]
    diff_map = np.reshape(diff_map, (height * width, 1))
    threshold = otsu(diff_map)
    diff_map = np.reshape(diff_map, (height, width))
    bcm = np.zeros((height, width)).astype(np.uint8)
    bcm[diff_map > threshold] = 255
    bcm[diff_map <= threshold] = 0

    conf_mat, oa, f1, kappa_co = assess_accuracy(ground_truth_changed, ground_truth_unchanged, bcm)

    imageio.imsave('./result/SRGCAE_VerConc_' + str(time.time()) + '.png', bcm)
    diff_map = 255 * (diff_map - np.min(diff_map)) / (np.max(diff_map) - np.min(diff_map))
    imageio.imsave('./result/SRGCAE_VerConc_' + str(time.time()) + '_DI.png', diff_map.astype(np.uint8))

    print(conf_mat)
    print(oa)
    print(f1)
    print(kappa_co)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Detecting land-cover changes on the SG dataset")
    parser.add_argument('--n_seg', type=int, default=5000,
                        help='Approximate number of objects obtained by the segmentation algorithm')
    parser.add_argument('--cmp', type=int, default=5, help='Compactness of the obtained objects')
    parser.add_argument('--band_width_t1', type=float, default=1,
                        help='The bandwidth of the Gaussian kernel when calculating the adjacency matrix')
    parser.add_argument('--band_width_t2', type=float, default=0.7,
                        help='The bandwidth of the Gaussian kernel when calculating the adjacency matrix')
    parser.add_argument('--deep_band_width_t1', type=float, default=0.15,
                        help='The bandwidth of the Gaussian kernel when calculating the adjacency matrix using deep vertex representations')
    parser.add_argument('--deep_band_width_t2', type=float, default=0.15,
                        help='The bandwidth of the Gaussian kernel when calculating the adjacency matrix using deep vertex representations')
    parser.add_argument('--knn_num', type=int, default=100,
                        help='The number of most similar objects for calculating the nonlocal structural relationship')
    parser.add_argument('--epoch', type=int, default=15, help='Training epoch of SRGCAE')
    args = parser.parse_args()

    train_model(args)
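
# Example invocation with the SG defaults defined above (for reference only):
# python train_SRGCAE_Nonlocal.py --n_seg 5000 --cmp 5 --band_width_t1 1 --band_width_t2 0.7 --deep_band_width_t1 0.15 --deep_band_width_t2 0.15 --knn_num 100 --epoch 15
--------------------------------------------------------------------------------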