├── README.md
└── main.py

/README.md:
--------------------------------------------------------------------------------
# CSANet
GRSL 2022 CSANet: Cross-Temporal Interaction Symmetric Attention Network for Hyperspectral Image Change Detection

Source code available at https://github.com/SYFYN0317/CSANet
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np
import time
import collections
from torch import optim
import torch
from sklearn import metrics, preprocessing
import datetime
from CDCNN.generate_pic import ConstractiveLoss
from CDCNN import train
from CDCNN.DASNET import Finalmodel
import sys
from scipy.io import savemat
sys.path.append('../global_module/')

from CDCNN.generate_pic import aa_and_each_accuracy, sampling, load_dataset, generate_png, generate_iter, applyPCA
from global_module.Utils import record, extract_samll_cubic

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# seeds for the Monte Carlo runs
seeds = [1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1340, 1341]
ensemble = 1

day = datetime.datetime.now()
day_str = day.strftime('%m_%d_%H_%M')

print('-----Importing Dataset-----')

global Dataset  # UP, IN, KSC
Dataset = 'china'
T11, T22, gt_hsi, TOTAL_SIZE, TRAIN_SIZE, VALIDATION_SPLIT = load_dataset()
print(T11.shape)
print(T22.shape)

image_x, image_y, BAND1 = T11.shape

# flatten the two temporal cubes to (pixels, bands) and the ground truth to (pixels,)
data1 = T11.reshape(np.prod(T11.shape[:2]), np.prod(T11.shape[2:]))
data2 = T22.reshape(np.prod(T22.shape[:2]), np.prod(T22.shape[2:]))
gt = gt_hsi.reshape(np.prod(gt_hsi.shape[:2]),)
CLASSES_NUM = 2

print('-----Importing Setting Parameters-----')
ITER = 1
PATCH_LENGTH = 3
# training hyper-parameters: learning rate, number of epochs, batch size
lr, num_epochs, batch_size = 0.0005, 10, 32
loss = torch.nn.CrossEntropyLoss()

img_rows = 2 * PATCH_LENGTH + 1
img_cols = 2 * PATCH_LENGTH + 1
img_channels1 = T11.shape[2]
img_channels2 = T22.shape[2]
INPUT_DIMENSION1 = T11.shape[2]
INPUT_DIMENSION2 = T22.shape[2]
ALL_SIZE = T11.shape[0] * T22.shape[1]
VAL_SIZE = int(TRAIN_SIZE)
TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE

KAPPA = []
OA = []
AA = []
TRAINING_TIME = []
TESTING_TIME = []
ELEMENT_ACC = np.zeros((ITER, CLASSES_NUM))

# standardize each band to zero mean and unit variance
data1 = preprocessing.scale(data1)
data2 = preprocessing.scale(data2)

data1_ = data1.reshape(T11.shape[0], T11.shape[1], T11.shape[2])
whole_data1 = data1_
data2_ = data2.reshape(T22.shape[0], T22.shape[1], T22.shape[2])
whole_data2 = data2_

# zero-pad the spatial borders so patches can be extracted at the image edges
padded_data1 = np.lib.pad(whole_data1, ((PATCH_LENGTH, PATCH_LENGTH), (PATCH_LENGTH, PATCH_LENGTH), (0, 0)),
                          'constant', constant_values=0)
padded_data2 = np.lib.pad(whole_data2, ((PATCH_LENGTH, PATCH_LENGTH), (PATCH_LENGTH, PATCH_LENGTH), (0, 0)),
                          'constant', constant_values=0)

for index_iter in range(ITER):
    net = Finalmodel()
    optimizer = optim.Adam(net.parameters(), lr=lr)  # , weight_decay=0.0001)
    time_1 = int(time.time())
    np.random.seed(seeds[index_iter])
    train_indices, test_indices = sampling(VALIDATION_SPLIT, gt)
    _, total_indices = sampling(1, gt)

    TRAIN_SIZE = len(train_indices)
    print('Train size: ', TRAIN_SIZE)
    TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE
    print('Test size: ', TEST_SIZE)
    VAL_SIZE = int(TRAIN_SIZE)
    print('Validation size: ', VAL_SIZE)

    print('-----Selecting Small Pieces from the Original Cube Data-----')

    train_iter, valida_iter, all_iter = generate_iter(TRAIN_SIZE, train_indices, TOTAL_SIZE, total_indices, VAL_SIZE,
                                                      whole_data1, whole_data2, PATCH_LENGTH, padded_data1, padded_data2,
                                                      INPUT_DIMENSION1, INPUT_DIMENSION2, batch_size, gt)

    train.train(net, train_iter, valida_iter, loss, optimizer, device, epochs=num_epochs)

    def score_function(engine):
        val_loss = engine.state.metrics['nll']
        return -val_loss

    # inference: reload the best checkpoint and predict over the full scene
    pred_test_fdssc = []
    tic2 = time.time()  # time.clock() was removed in Python 3.8
    newnet = torch.load("UPbest.pth")
    with torch.no_grad():
        for X1, X2, y in all_iter:
            X1 = X1.to(device)
            X2 = X2.to(device)
            newnet.eval()  # evaluation mode; this disables dropout
            out1, out2, y_hat = newnet(X1, X2)
            # print(net(X))
            pred_test_fdssc.extend(np.array(y_hat.cpu().argmax(axis=1)))
    toc2 = time.time()

    collections.Counter(pred_test_fdssc)
    gt_test = gt[total_indices] - 1
    # gt_test = gt_test1[:-VAL_SIZE]
    # print('pre', np.unique(pred_test_fdssc))
    # print('gt_test', np.unique(gt_test[:-VAL_SIZE]))

    # overall accuracy, per-class / average accuracy, and kappa coefficient
    overall_acc_fdssc = metrics.accuracy_score(pred_test_fdssc, gt_test)
    confusion_matrix_fdssc = metrics.confusion_matrix(pred_test_fdssc, gt_test)
    # print(confusion_matrix_fdssc)
    each_acc_fdssc, average_acc_fdssc = aa_and_each_accuracy(confusion_matrix_fdssc)
    kappa = metrics.cohen_kappa_score(pred_test_fdssc, gt_test)
    print("testing time:", toc2 - tic2)
    savemat("river.mat", mdict={'result': pred_test_fdssc})
    KAPPA.append(kappa)
    OA.append(overall_acc_fdssc)
    AA.append(average_acc_fdssc)
    ELEMENT_ACC[index_iter, :] = each_acc_fdssc

print("-------- Training Finished -----------")
print('OA:', OA)
print('AA:', AA)
print('Kappa:', KAPPA)
print('OA_UN', each_acc_fdssc)

generate_png(pred_test_fdssc, gt_hsi - 1, 'China', total_indices)
--------------------------------------------------------------------------------