├── output
│   ├── results
│   │   └── readme.md
│   └── cls_maps
│       └── readme.md
├── data
│   ├── IP
│   │   └── readme.md
│   ├── UP
│   │   └── readme.md
│   └── HU13_tif
│       └── readme.md
├── src
│   └── framework.png
├── c_model
│   ├── MCM_CNN.py
│   ├── CNN1D.py
│   ├── ASPN.py
│   ├── SSFTT.py
│   ├── SSAtt.py
│   ├── SSSAN.py
│   ├── SSTN.py
│   └── A2S2KResNet.py
├── model
│   ├── module
│   │   ├── AMIPS.py
│   │   ├── EucProject.py
│   │   ├── DCR.py
│   │   ├── manifold_learning.py
│   │   └── MPA_Lya.py
│   └── AMS_M2ESL.py
├── utils
│   ├── evaluation.py
│   ├── data_load_operate_c_model_m_scale.py
│   └── data_load_operate.py
├── visual
│   └── cls_visual.py
├── process_cls_disjoint_c_model.py
├── process_cls_c_model.py
├── README.md
├── process_dl_disjoint_c_model_m_scale.py
├── process_dl_c_model_m_scale.py
├── process_dl_disjoint.py
└── process_dl.py
-------------------------------------------------------------------------------- /output/results/readme.md: --------------------------------------------------------------------------------
1 | Method output files are saved here.
2 |
-------------------------------------------------------------------------------- /output/cls_maps/readme.md: --------------------------------------------------------------------------------
1 | Visual classification maps of the corresponding method outputs.
2 |
-------------------------------------------------------------------------------- /data/IP/readme.md: --------------------------------------------------------------------------------
1 | The corresponding data set files are put here in .mat data format.
2 |
-------------------------------------------------------------------------------- /data/UP/readme.md: --------------------------------------------------------------------------------
1 | The corresponding data set files are put here in .mat data format.
2 |
-------------------------------------------------------------------------------- /src/framework.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/lms-07/AMS-M2ESL/HEAD/src/framework.png
-------------------------------------------------------------------------------- /data/HU13_tif/readme.md: --------------------------------------------------------------------------------
1 | The corresponding data set files are put here in .mat data format.
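As a quick pointer for the data folders above, a minimal sketch of loading one of these .mat cubes (the file and key names below are the standard public ones for Indian Pines and are assumptions here; the project's actual loading logic lives in utils/data_load_operate.py):

import os
import scipy.io as sio

def load_mat_cube(data_root):
    # illustrative only: common distribution names for the Indian Pines scene
    data = sio.loadmat(os.path.join(data_root, 'IP', 'Indian_pines_corrected.mat'))['indian_pines_corrected']
    gt = sio.loadmat(os.path.join(data_root, 'IP', 'Indian_pines_gt.mat'))['indian_pines_gt']
    return data, gt  # (145, 145, 200) HSI cube and (145, 145) label map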
2 |
-------------------------------------------------------------------------------- /c_model/MCM_CNN.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Mingsong Li (lms-07)
3 | # @Time : 2023-Apr
4 | # @Address : Time Lab @ SDU
5 | # @FileName : MCM_CNN.py
6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS
7 |
8 |
9 | # unofficial implementation based on the official Matlab version
10 | # https://github.com/henanjun/demo_MCMs
11 | # Feature Extraction With Multiscale Covariance Maps for Hyperspectral Image Classification, TGRS, 2018
12 |
13 |
14 | import torch.nn as nn
15 |
16 |
17 | class MCM_CNN(nn.Module):
18 |     def __init__(self, scales, class_count, ds):
19 |         super(MCM_CNN, self).__init__()
20 |         self.channels = scales
21 |         self.class_count = class_count
22 |         if ds == 'IP':
23 |             self.channels_fc_1 = 576
24 |             self.channels_fc_2 = 128
25 |         elif ds == 'UP':
26 |             self.channels_fc_1 = 576
27 |             self.channels_fc_2 = 512
28 |         elif ds == 'UH_tif':
29 |             self.channels_fc_1 = 576
30 |             self.channels_fc_2 = 128
31 |
32 |         self.relu = nn.ReLU(inplace=True)
33 |         self.pooling = nn.AvgPool2d((2, 2), stride=2)
34 |
35 |         self.conv_1 = nn.Conv2d(self.channels, 128, kernel_size=3, stride=1)
36 |         self.conv_2 = nn.Conv2d(128, 64, kernel_size=3, stride=1)
37 |
38 |         self.flatten = nn.Flatten(1)
39 |         self.fc2 = nn.Linear(self.channels_fc_1, self.channels_fc_2)
40 |         self.fc1 = nn.Linear(self.channels_fc_2, self.channels_fc_2)
41 |         self.fc0 = nn.Linear(self.channels_fc_2, self.class_count)
42 |
43 |     def forward(self, x):
44 |
45 |         x = self.conv_1(x)
46 |         x = self.relu(self.pooling(x))
47 |         x = self.conv_2(x)
48 |         x = self.relu(self.pooling(x))
49 |
50 |         x = self.flatten(x)
51 |         x = self.relu(self.fc2(x))
52 |         x = self.relu(self.fc1(x))
53 |         out = self.fc0(x)
54 |         return out
55 |
56 |     def _init_weight(self):
57 |         for m in self.modules():
58 |             # guard: only conv/linear layers here carry weight and bias
59 |             if isinstance(m, (nn.Conv2d, nn.Linear)):
60 |                 nn.init.xavier_normal_(m.weight)
61 |                 nn.init.constant_(m.bias, 0)
62 |
63 |
64 | def MCM_CNN_(channels, num_classes, ds):
65 |     model = MCM_CNN(channels, num_classes, ds)
66 |     return model
67 |
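A quick smoke test may help here: with two valid 3x3 convolutions and two 2x2 average poolings, the hard-coded channels_fc_1 = 576 = 64 * 3 * 3 implies an 18x18 input patch. A minimal, hedged usage sketch (the scale count of 20 is an illustrative assumption, not a value taken from the paper):

import torch
from c_model.MCM_CNN import MCM_CNN_

scales = 20                         # number of multiscale covariance maps (assumed)
net = MCM_CNN_(scales, 16, 'IP')    # 16 classes for Indian Pines
x = torch.randn(4, scales, 18, 18)  # 18x18 follows from 576 = 64 * 3 * 3
print(net(x).shape)                 # torch.Size([4, 16])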
-------------------------------------------------------------------------------- /model/module/AMIPS.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Mingsong Li (lms-07)
3 | # @Time : 2023-Apr
4 | # @Address : Time Lab @ SDU
5 | # @FileName : AMIPS.py
6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch import cosine_similarity
11 |
12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13 |
14 |
15 | class AM_IPS(nn.Module):
16 |     def __init__(self, ds):
17 |         super(AM_IPS, self).__init__()
18 |         self.ds = ds
19 |
20 |     def forward(self, raw_patch):
21 |         bt, c, h, w = raw_patch.size()
22 |
23 |         patch_2d = raw_patch.view(bt, -1, h * w).permute(0, 2, 1)
24 |
25 |         # central spectral vector sampling
26 |         cent_spec_vec = patch_2d[:, int((h * w - 1) / 2)]
27 |         cent_spec_vec = torch.unsqueeze(cent_spec_vec, dim=1)
28 |
29 |         # central spectral vector oriented similarity
30 |         sim_mat = self._sim_euc(cent_spec_vec, patch_2d)
31 |         # sim_mat = self._sim_mat_mul(cent_spec_vec, patch_2d)
32 |         # sim_mat = self._sim_cos(cent_spec_vec, patch_2d)
33 |
34 |         if self.ds == 'UH_tif':
35 |             threshold_sampling = torch.mean(sim_mat, dim=1) - 0.25 * torch.std(sim_mat, dim=1)
36 |         else:
37 |             threshold_sampling = torch.mean(sim_mat, dim=1) - 0.2 * torch.std(sim_mat, dim=1)
38 |
39 |         # sampling
40 |         threshold_mat = torch.unsqueeze(threshold_sampling, dim=1) * torch.ones_like(sim_mat)
41 |         threshold_mask = sim_mat - threshold_mat
42 |
43 |         index_mask = torch.where(threshold_mask >= 0, 1, 0)
44 |         index_mask = torch.unsqueeze(index_mask, -1)
45 |         x_sampling = index_mask * patch_2d
46 |
47 |         return x_sampling.contiguous().view(bt, h, w, c).permute(0, 3, 1, 2)
48 |
49 |     def _sim_mat_mul(self, central_vector, x_2d):
50 |         sim_M = torch.bmm(x_2d, central_vector.permute(0, 2, 1))
51 |         return torch.squeeze(sim_M, dim=-1)
52 |
53 |     def _sim_euc(self, central_vector, x_2d):
54 |         bt, h_w, c = x_2d.size()
55 |         cen_vec_mat = central_vector.expand(bt, h_w, c)
56 |         euc_dist = torch.norm(cen_vec_mat - x_2d, dim=2, p=2)
57 |         sim_M = 1 / (1 + euc_dist)
58 |         return sim_M
59 |
60 |     def _sim_cos(self, central_vector, x_2d):
61 |         sim_M = cosine_similarity(central_vector, x_2d, dim=2)
62 |         return sim_M
63 |
-------------------------------------------------------------------------------- /model/module/EucProject.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Mingsong Li (lms-07)
3 | # @Time : 2023-Apr
4 | # @Address : Time Lab @ SDU
5 | # @FileName : EucProject.py
6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS
7 |
8 | '''
9 | the ASQRT layer is implemented based on the source code of COSONet, i.e.,
10 | COSONet: Compact Second-Order Network for Video Face Recognition, ACCV 2018,
11 | https://github.com/YirongMao/COSONet/blob/master/layer_utils.py
12 |
13 | the earliest version is surely based on the excellent work iSQRT-COV, i.e.,
14 | Towards Faster Training of Global Covariance Pooling Networks by Iterative Matrix Square Root Normalization, CVPR 2018,
15 | https://github.com/jiangtaoxie/fast-MPN-COV/blob/master/src/representation/MPNCOV.py
16 | '''
17 |
18 | import torch
19 | import torch.nn as nn
20 | from torch.autograd import Variable
21 |
22 |
23 | # ASQRT for multi-channel input via autograd
24 | class ASQRT_autograd_mc(nn.Module):
25 |
26 |     def __init__(self, norm_type, num_iter):
27 |         super(ASQRT_autograd_mc, self).__init__()
28 |         self.norm_type = norm_type
29 |         self.num_iter = num_iter
30 |
31 |     def forward(self, A):
32 |         b_s, c, n_c, n_c = A.size()
33 |         A = A.view(b_s * c, n_c, n_c)
34 |         b_s_c = A.shape[0]
35 |
36 |         dtype = A.dtype
37 |         device = A.device
38 |         # pre-normalization
39 |         if self.norm_type == 'Frob_n':
40 |             normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
41 |         elif self.norm_type == 'Trace_n':
42 |             I_bs_mat = torch.eye(n_c, n_c, device=A.device).view(1, n_c, n_c).expand_as(A).type(dtype)
43 |             normA = A.mul(I_bs_mat).sum(dim=1).sum(dim=1)
44 |         else:
45 |             raise NameError('invalid normalize type {}'.format(self.norm_type))
46 |
47 |         Y = A.div(normA.view(b_s_c, 1, 1).expand_as(A))
48 |         # Iteration
49 |         I = Variable(torch.eye(n_c, n_c).view(1, n_c, n_c).
50 |                      repeat(b_s_c, 1, 1).type(dtype).to(device), requires_grad=False)
51 |         Z = Variable(torch.eye(n_c, n_c).view(1, n_c, n_c).
52 | repeat(b_s_c, 1, 1).type(dtype).to(device), requires_grad=False) 53 | 54 | for i in range(self.num_iter): 55 | T = 0.5 * (3.0 * I - Z.bmm(Y)) 56 | Y = Y.bmm(T) 57 | Z = T.bmm(Z) 58 | 59 | # post compensation 60 | sA = Y * torch.sqrt(normA).view(b_s_c, 1, 1).expand_as(A) 61 | 62 | sA = sA.view(b_s, c, n_c, n_c) 63 | del I, Z 64 | return sA 65 | -------------------------------------------------------------------------------- /c_model/CNN1D.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : CNN1D.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | # https://github.com/eecn/Hyperspectral-Classification 9 | # Deep Convolutional Neural Networks for Hyperspectral Image Classification, Journal of Sensors, 2015 10 | 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn import init 15 | 16 | 17 | class CNN1D(nn.Module): 18 | """ 19 | Deep Convolutional Neural Networks for Hyperspectral Image Classification 20 | Wei Hu, Yangyu Huang, Li Wei, Fan Zhang and Hengchao Li 21 | Journal of Sensors, Volume 2015 (2015) 22 | https://www.hindawi.com/journals/js/2015/258619/ 23 | """ 24 | 25 | @staticmethod 26 | def weight_init(m): 27 | # [All the trainable parameters in our CNN should be initialized to 28 | # be a random value between −0.05 and 0.05.] 29 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d): 30 | init.uniform_(m.weight, -0.05, 0.05) 31 | init.zeros_(m.bias) 32 | 33 | def _get_final_flattened_size(self): 34 | with torch.no_grad(): 35 | x = torch.zeros(1, 1, self.input_channels) 36 | x = self.pool(self.conv(x)) 37 | return x.numel() 38 | 39 | def __init__(self, input_channels, n_classes, kernel_size=None, pool_size=None): 40 | super(CNN1D, self).__init__() 41 | if kernel_size is None: 42 | # [In our experiments, k1 is better to be [ceil](n1/9)] 43 | kernel_size = math.ceil(input_channels / 9) 44 | if pool_size is None: 45 | # The authors recommand that k2's value is chosen so that the pooled features have 30~40 values 46 | # ceil(kernel_size/5) gives the same values as in the paper so let's assume it's okay 47 | pool_size = math.ceil(kernel_size / 5) 48 | self.input_channels = input_channels 49 | 50 | # [The first hidden convolution layer C1 filters the n1 x 1 input data with 20 kernels of size k1 x 1] 51 | self.conv = nn.Conv1d(1, 20, kernel_size) 52 | self.pool = nn.MaxPool1d(pool_size) 53 | self.features_size = self._get_final_flattened_size() 54 | # [n4 is set to be 100] 55 | self.fc1 = nn.Linear(self.features_size, 100) 56 | self.fc2 = nn.Linear(100, n_classes) 57 | self.apply(self.weight_init) 58 | 59 | def forward(self, x): 60 | # [In our design architecture, we choose the hyperbolic tangent function tanh(u)] 61 | 62 | x = x.squeeze(dim=-1).squeeze(dim=-1) 63 | x = x.unsqueeze(1) 64 | x = self.conv(x) 65 | x = torch.tanh(self.pool(x)) 66 | x = x.view(-1, self.features_size) 67 | x = torch.tanh(self.fc1(x)) 68 | x = self.fc2(x) 69 | return x 70 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : evaluation.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | import torch 9 | import numpy as np 10 | 11 | from 
sklearn import metrics
12 | from operator import truediv
13 |
14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15 |
16 |
17 | def evaluate_OA(data_iter, net, loss, device, model_type_flag):
18 |     acc_sum, loss_sum, samples_counter = 0, 0, 0  # accumulated over all batches
19 |
20 |     with torch.no_grad():
21 |         net.eval()
22 |         if model_type_flag == 1:  # data for single spatial net
23 |             for X_spa, y in data_iter:
24 |                 # loss_sum is initialized once above, not reset per batch
25 |                 X_spa, y = X_spa.to(device), y.to(device)
26 |                 y_pred = net(X_spa)
27 |
28 |                 ls = loss(y_pred, y.long())
29 |
30 |                 acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item()
31 |                 loss_sum += ls
32 |
33 |                 samples_counter += y.shape[0]
34 |         elif model_type_flag == 2:  # data for single spectral net
35 |             for X_spe, y in data_iter:
36 |
37 |                 X_spe, y = X_spe.to(device), y.to(device)
38 |                 y_pred = net(X_spe)
39 |
40 |                 ls = loss(y_pred, y.long())
41 |
42 |                 acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item()
43 |                 loss_sum += ls
44 |
45 |                 samples_counter += y.shape[0]
46 |         elif model_type_flag == 3:  # data for spectral-spatial net
47 |             for X_spa, X_spe, y in data_iter:
48 |
49 |                 X_spa, X_spe, y = X_spa.to(device), X_spe.to(device), y.to(device)
50 |                 y_pred = net(X_spa, X_spe)
51 |
52 |                 ls = loss(y_pred, y.long())
53 |
54 |                 acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item()
55 |                 loss_sum += ls
56 |
57 |                 samples_counter += y.shape[0]
58 |
59 |     return [acc_sum / samples_counter, loss_sum]
60 |
61 |
62 | def AA_ECA(confusion_matrix):
63 |     # get diagonal elements
64 |     diag_list = np.diag(confusion_matrix)
65 |     row_sum_list = np.sum(confusion_matrix, axis=1)
66 |     each_per_acc = np.nan_to_num(truediv(diag_list, row_sum_list))
67 |     avg_acc = np.mean(each_per_acc)
68 |
69 |     return each_per_acc, avg_acc
70 |
71 |
72 | def claification_report(label, pred, name):
73 |     if name == 'IP':
74 |         target_names = ['Alfalfa', 'Corn-notill', 'Corn-mintill', 'Corn',
75 |                         'Grass-pasture', 'Grass-trees', 'Grass-pasture-mowed',
76 |                         'Hay-windrowed', 'Oats', 'Soybean-notill', 'Soybean-mintill',
77 |                         'Soybean-clean', 'Wheat', 'Woods', 'Buildings-Grass-Trees-Drives',
78 |                         'Stone-Steel-Towers']
79 |     elif name == 'UP':
80 |         target_names = ['Asphalt', 'Meadows', 'Gravel', 'Trees', 'Painted metal sheets', 'Bare Soil', 'Bitumen',
81 |                         'Self-Blocking Bricks', 'Shadows']
82 |
83 |     elif name == 'UH_tif':
84 |         target_names = ['Grass_healthy', 'Grass_stressed', 'Grass_synthetic', 'Tree', 'Soil', 'Water', 'Residential',
85 |                         'Commercial', 'Road', 'Highway', 'Railway', 'Parking_lot1', 'Parking_lot2', 'Tennis_court',
86 |                         'Running_track']
87 |
88 |     classification_report = metrics.classification_report(label, pred, target_names=target_names)
89 |     return classification_report
90 |
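A small usage sketch of the helpers above, mirroring the call pattern of the process_* scripts (the toy labels are purely illustrative):

import numpy as np
from sklearn import metrics
from utils.evaluation import AA_ECA

y_gt = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([0, 1, 1, 1, 2, 2])
cm = metrics.confusion_matrix(pred, y_gt)      # argument order as used in the scripts
each_acc, aa = AA_ECA(cm)                      # per-class accuracies and their mean (AA)
kappa = metrics.cohen_kappa_score(pred, y_gt)
print(each_acc, aa, kappa)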
-------------------------------------------------------------------------------- /visual/cls_visual.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Mingsong Li (lms-07)
3 | # @Time : 2023-Apr
4 | # @Address : Time Lab @ SDU
5 | # @FileName : cls_visual.py
6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS
7 |
8 | import torch
9 | import numpy as np
10 | import spectral as spy
11 | from spectral import spy_colors
12 |
13 |
14 |
15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 |
17 |
18 | def gt_cls_map(gt_hsi, path):
19 |     spy.save_rgb(path + "_gt.png", gt_hsi, colors=spy_colors)
20 |     print('------Get ground truth classification map successful-------')
21 |
22 |
23 | def pred_cls_map_dl(sample_list, net, gt_hsi, path, model_type_flag):
24 |     pred_sample = []
25 |     pred_label = []
26 |
27 |     net.eval()
28 |     if len(sample_list) == 1:
29 |         iter = sample_list[0]
30 |         if model_type_flag == 1:  # data for single spatial net
31 |             for X_spa, y in iter:
32 |                 X_spa = X_spa.to(device)
33 |                 pre_y = net(X_spa).cpu().argmax(axis=1).detach().numpy()
34 |                 pred_sample.extend(pre_y + 1)
35 |         elif model_type_flag == 2:  # data for single spectral net
36 |             for X_spe, y in iter:
37 |                 X_spe = X_spe.to(device)
38 |                 pre_y = net(X_spe).cpu().argmax(axis=1).detach().numpy()
39 |                 pred_sample.extend(pre_y + 1)
40 |         elif model_type_flag == 3:
41 |             for X_spa, X_spe, y in iter:
42 |                 X_spa, X_spe, y = X_spa.to(device), X_spe.to(device), y.to(device)
43 |                 pre_y = net(X_spa, X_spe).cpu().argmax(axis=1).detach().numpy()
44 |                 pred_sample.extend(pre_y + 1)
45 |     elif len(sample_list) == 2:
46 |         iter, index = sample_list[0], sample_list[1]
47 |         if model_type_flag == 1:  # data for single spatial net
48 |             for X_spa, y in iter:
49 |                 X_spa = X_spa.to(device)
50 |                 pre_y = net(X_spa).cpu().argmax(axis=1).detach().numpy()
51 |                 pred_label.extend(pre_y + 1)
52 |         elif model_type_flag == 2:  # data for single spectral net
53 |             for X_spe, y in iter:
54 |                 X_spe = X_spe.to(device)
55 |                 pre_y = net(X_spe).cpu().argmax(axis=1).detach().numpy()
56 |                 pred_label.extend(pre_y + 1)
57 |         elif model_type_flag == 3:
58 |             for X_spa, X_spe, y in iter:
59 |                 X_spa, X_spe, y = X_spa.to(device), X_spe.to(device), y.to(device)
60 |                 pre_y = net(X_spa, X_spe).cpu().argmax(axis=1).detach().numpy()
61 |                 pred_label.extend(pre_y + 1)
62 |
63 |         gt = np.ravel(gt_hsi)
64 |         pred_sample = np.zeros(gt.shape)
65 |         pred_sample[index] = pred_label
66 |
67 |     pred_hsi = np.reshape(pred_sample, (gt_hsi.shape[0], gt_hsi.shape[1]))
68 |     spy.save_rgb(path + '_' + str(len(sample_list)) + '_pre.png', pred_hsi, colors=spy_colors)  # dpi has not been set yet
69 |     print('------Get pred classification maps successful-------')
70 |
71 |
72 | def pred_cls_map_cls(sample_list, gt_hsi, path):
73 |     if len(sample_list) == 1:
74 |         pred_sample = sample_list[0]
75 |
76 |     elif len(sample_list) == 2:
77 |         pred_label, index = sample_list[0], sample_list[1]
78 |         gt = np.ravel(gt_hsi)
79 |         pred_sample = np.zeros(gt.shape)
80 |         pred_sample[index] = pred_label
81 |
82 |     pred_hsi = np.reshape(pred_sample, (gt_hsi.shape[0], gt_hsi.shape[1]))
83 |     spy.save_rgb(path + '_' + str(len(sample_list)) + '_pre.png', pred_hsi, colors=spy_colors)  # dpi has not been set yet
84 |     print('------Get pred classification maps successful-------')
85 |
-------------------------------------------------------------------------------- /c_model/ASPN.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Mingsong Li (lms-07)
3 | # @Time : 2023-Apr
4 | # @Address : Time Lab @ SDU
5 | # @FileName : ASPN.py
6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS
7 |
8 | # unofficial implementation based on the official Keras version
9 | # https://github.com/mengxue-rs/a-spn
10 | # Attention-Based Second-Order Pooling Network for Hyperspectral Image Classification, TGRS 2021
11 |
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17 |
18 |
19 | class ASOP(nn.Module):
20 |     def __init__(self, bs, hw):
21 |         super(ASOP, self).__init__()
22 |         self.bs = bs
23 |         self.hw = hw
24 |
25 |         self.kernel = nn.Parameter(torch.ones(self.bs, self.hw, 1, device=device), requires_grad=True)
26 |         self.bias = 
nn.Parameter(torch.zeros(self.bs, self.hw, device=device), requires_grad=True) 27 | self.softmax = nn.Softmax(dim=-1) 28 | 29 | def forward(self, x): 30 | bs, c, hw = x.size() 31 | 32 | Xmm = self._second_order_pooling(x) 33 | 34 | norm = torch.norm(Xmm, p=2, dim=-1, keepdim=True) 35 | out = Xmm.div(norm) 36 | 37 | central_vector = out[:, int((hw - 1) / 2)] 38 | central_vector = torch.unsqueeze(central_vector, dim=1) 39 | 40 | cos = torch.mul(central_vector, self.kernel) 41 | 42 | out = torch.bmm(out, cos) + torch.unsqueeze(self.bias, dim=-1) 43 | att = self.softmax(out) 44 | out = torch.bmm(x, att) 45 | 46 | return out 47 | 48 | def _second_order_pooling(self, x): 49 | x1 = x.permute(0, 2, 1) 50 | out = torch.bmm(x1, x) 51 | 52 | return out 53 | 54 | 55 | class A_SPN(nn.Module): 56 | def __init__(self, bs, height, weight, in_channels, class_count): 57 | super(A_SPN, self).__init__() 58 | self.bs = bs 59 | self.h = height 60 | self.w = weight 61 | self.in_channels = in_channels 62 | self.class_count = class_count 63 | 64 | self.bn = nn.BatchNorm2d(self.in_channels) 65 | self.dropout = nn.Dropout(p=0.5) 66 | self.ASOP = ASOP(self.bs, self.h * self.w) 67 | self.flatten = nn.Flatten(1) 68 | 69 | self.fc = nn.Linear(self.in_channels * self.in_channels, class_count) 70 | 71 | def forward(self, x): 72 | b, h, w, c = x.size() 73 | x = x.permute(0, 3, 1, 2) 74 | 75 | x = self.bn(x) 76 | x = x.view(b, -1, h * w) 77 | 78 | out = self.dropout(x) 79 | 80 | norm = torch.norm(out, p=2, dim=-1, keepdim=True) 81 | out = out.div(norm) 82 | 83 | out = self.ASOP(out) 84 | out = self._second_order_pooling(out) 85 | 86 | norm = torch.norm(out, p=2, dim=-1, keepdim=True) 87 | out = out.div(norm) 88 | 89 | norm = torch.norm(out, p='fro', dim=-1, keepdim=True) 90 | out = out.div(norm) 91 | 92 | out = self.flatten(out) 93 | 94 | out = self.fc(out) 95 | 96 | return out 97 | 98 | def _init_weight(self): 99 | for m in self.modules(): 100 | if isinstance(m, nn.Linear): 101 | nn.init.trunc_normal_(m.weight, mean=0, std=1e-4) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | nn.init.constant_(m.weight, 1) 104 | nn.init.constant_(m.bias, 0) 105 | 106 | def _second_order_pooling(self, x): 107 | x1 = x.permute(0, 2, 1) 108 | out = torch.bmm(x, x1) 109 | 110 | return out 111 | 112 | 113 | def ASPN_(bs, height, weight, in_channels, num_classes): 114 | model = A_SPN(bs, height, weight, in_channels, num_classes) 115 | return model 116 | -------------------------------------------------------------------------------- /model/module/DCR.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : DCR.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | 13 | ''' 14 | for the two implementations of distance covariance represntation (DCR), 15 | the implementation1 (_DCR_1) is based on DeepBDC, i.e., 16 | Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification, CVPR 2022 17 | https://github.com/Fei-Long121/DeepBDC/blob/main/methods/bdc_module.py, 18 | and _DCR_2 is our original implementation based on Brownian distance covariance, https://doi.org/10.1214/09-AOAS312. 19 | the two implementations achieve similar performance in our AMS-M2ESL framework, we randomly chose _DCR_1 as the final version. 
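for reference, both implementations reduce to double centering of the pairwise band-distance matrix,
A_jk = a_jk - (row mean)_j - (column mean)_k + (grand mean),
which _DCR_1 realizes via products with the all-ones matrix I_M and _DCR_2 via explicit row/column means;
both then flip the sign of the centered matrix.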
20 | ''' 21 | 22 | 23 | class Spectral_corr_mining(nn.Module): 24 | def __init__(self, in_channels): 25 | super(Spectral_corr_mining, self).__init__() 26 | self.temperature = nn.Parameter( 27 | torch.log((3.2 / (in_channels * in_channels)) * torch.ones(1, 1, device=device)), requires_grad=True) 28 | 29 | def forward(self, x): 30 | x_corr = self._DCR_1(x, self.temperature) 31 | # x_corr=self._DCR_2(x) 32 | 33 | # for abla of DCR 34 | # x_corr=self._CR(x) 35 | 36 | return x_corr 37 | 38 | def _DCR_1(self, x, t): 39 | len_x = len(x.size()) 40 | 41 | if len_x == 3: 42 | # spatial 43 | batchSize, c, h_w = x.size() 44 | x = x.permute(0, 2, 1) 45 | c = h_w 46 | elif len_x == 4: 47 | # spectral channel 48 | batchSize, c, h, w = x.size() 49 | h_w = h * w 50 | x = x.reshape(batchSize, c, h_w) 51 | 52 | I = torch.eye(c, c, device=x.device).view(1, c, c).repeat(batchSize, 1, 1).type(x.dtype) 53 | I_M = torch.ones(batchSize, c, c, device=x.device).type(x.dtype) 54 | x_pow2 = x.bmm(x.transpose(1, 2)) 55 | dcov = I_M.bmm(x_pow2 * I) + (x_pow2 * I).bmm(I_M) - 2 * x_pow2 56 | 57 | dcov = torch.clamp(dcov, min=0.0) 58 | dcov = torch.exp(t) * dcov 59 | dcov = torch.sqrt(dcov + 1e-5) 60 | 61 | out = dcov - 1. / c * dcov.bmm(I_M) - 1. / c * I_M.bmm(dcov) + 1. / (c * c) * I_M.bmm(dcov).bmm(I_M) 62 | 63 | return out * (-1) 64 | 65 | def _DCR_2(self, x): 66 | batch_size, c, h, w = x.size() 67 | 68 | x = x.view(batch_size, -1, h * w).permute(0, 2, 1) 69 | 70 | x = x.permute(0, 2, 1) 71 | x1, x2 = x[:, :, None], x[:, None] 72 | x3 = x1 - x2 73 | band_l2_mat = torch.norm(x3, dim=3, p=2) 74 | 75 | bem_mean_row, becm_mean_col = torch.mean(band_l2_mat, dim=1, keepdim=True), torch.mean(band_l2_mat, dim=2, 76 | keepdim=True) 77 | bem_mean_row_expand, becm_mean_col_expand = bem_mean_row.expand(band_l2_mat.shape), becm_mean_col.expand( 78 | band_l2_mat.shape) 79 | bem_mean_plus_row_col = bem_mean_row_expand + becm_mean_col_expand 80 | bem_mean_all = torch.mean(bem_mean_row, dim=2) 81 | becm = band_l2_mat - bem_mean_plus_row_col + torch.unsqueeze(bem_mean_all, dim=-1) 82 | 83 | return becm * (-1) 84 | 85 | def _CR(self, x): 86 | batch_size, c, h, w = x.size() 87 | 88 | x = x.view(batch_size, -1, h * w).permute(0, 2, 1) 89 | mean_pixel = torch.mean(x, dim=1, keepdim=True) 90 | mean_pixel_expand = mean_pixel.expand(x.shape) 91 | 92 | x_cr = x - mean_pixel_expand 93 | CR = torch.bmm(x_cr.permute(0, 2, 1), x_cr) 94 | CR = torch.div(CR, h * w - 1) 95 | 96 | return CR 97 | -------------------------------------------------------------------------------- /model/module/manifold_learning.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : manifold_learning.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | # based on the source code of RBN, i.e., A Riemannian Network for SPD Matrix Learning, NeurIPS 2019 9 | # https://proceedings.neurips.cc/paper/2019/hash/6e69ebbfad976d4637bb4b39de261bf7-Abstract.html 10 | 11 | import torch 12 | import torch.nn as nn 13 | import model.module.manifold_learning_fun as m_fun 14 | 15 | dtype = torch.float64 16 | device = torch.device('cuda') 17 | 18 | 19 | class BiMap(nn.Module): 20 | """ 21 | Input X: (batch_size,hi) SPD matrices of size (ni,ni) 22 | Output P: (batch_size,ho) of bilinearly mapped matrices of size (no,no) 23 | Stiefel parameter of size (ho,hi,ni,no) 24 | """ 25 | 26 | def __init__(self, ho, hi, ni, no): 27 | 
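        # shape reference, per the class docstring: X of shape (batch_size, hi, ni, ni) is mapped to
        # P of shape (batch_size, ho, no, no) through the Stiefel parameter W of shape (ho, hi, ni, no);
        # the exact per-channel bilinear map lives in m_fun.bimap_channels (not shown here).
        # AMS_M2ESL instantiates this layer as BiMap(2, 2, C, C): two C x C SPD channels in and out.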
super(BiMap, self).__init__() 28 | self._W = m_fun.StiefelParameter( 29 | torch.empty(ho, hi, ni, no, dtype=dtype, device=device)) 30 | self._ho = ho 31 | self._hi = hi 32 | self._ni = ni 33 | self._no = no 34 | m_fun.init_bimap_parameter(self._W) 35 | # self._no 36 | 37 | def forward(self, X): 38 | return m_fun.bimap_channels(X, self._W) 39 | 40 | 41 | class ReEig(nn.Module): 42 | """ 43 | Input P: (batch_size,h) SPD matrices of size (n,n) 44 | Output X: (batch_size,h) of rectified eigenvalues matrices of size (n,n) 45 | """ 46 | 47 | def forward(self, P): 48 | return m_fun.ReEig.apply(P.cpu()) 49 | # return m_fun.ReEig.apply(P) 50 | 51 | 52 | class LogEig(nn.Module): 53 | """ 54 | Input P: (batch_size,h) SPD matrices of size (n,n) 55 | Output X: (batch_size,h) of log eigenvalues matrices of size (n,n) 56 | """ 57 | 58 | def forward(self, P): 59 | return m_fun.LogEig.apply(P) 60 | 61 | 62 | class SqmEig(nn.Module): 63 | """ 64 | Input P: (batch_size,h) SPD matrices of size (n,n) 65 | Output X: (batch_size,h) of sqrt eigenvalues matrices of size (n,n) 66 | """ 67 | 68 | def forward(self, P): 69 | return m_fun.SqmEig.apply(P) 70 | 71 | 72 | class BatchNormSPD(nn.Module): 73 | """ 74 | Input X: (N,h) SPD matrices of size (n,n) with h channels and batch size N 75 | Output P: (N,h) batch-normalized matrices 76 | SPD parameter of size (n,n) 77 | """ 78 | 79 | def __init__(self, n): 80 | super(__class__, self).__init__() 81 | self.momentum = 0.1 82 | self.running_mean = torch.eye( 83 | n, dtype=dtype, device=device) ################################ 84 | # self.running_mean=nn.Parameter(th.eye(n,dtype=dtype),requires_grad=False) 85 | self.weight = m_fun.SPDParameter(torch.eye(n, dtype=dtype, device=device)) 86 | 87 | def forward(self, X): 88 | N, h, n, n = X.shape 89 | X_batched = X.permute(2, 3, 0, 90 | 1).contiguous().view(n, n, N * h, 91 | 1).permute(2, 3, 0, 92 | 1).contiguous() 93 | if (self.training): 94 | mean = m_fun.BaryGeom(X_batched) 95 | with torch.no_grad(): 96 | self.running_mean.data = m_fun.geodesic( 97 | self.running_mean, mean, self.momentum) 98 | X_centered = m_fun.CongrG(X_batched, mean, 'neg') 99 | else: 100 | X_centered = m_fun.CongrG(X_batched, self.running_mean, 101 | 'neg') # subtract mean 102 | X_normalized = m_fun.CongrG(X_centered, self.weight, 103 | 'pos') # add bias 104 | return X_normalized.permute(2, 3, 0, 105 | 1).contiguous().view(n, n, N, h).permute( 106 | 2, 3, 0, 1).contiguous() 107 | -------------------------------------------------------------------------------- /model/AMS_M2ESL.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : AMS_M2ESL.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | import torch 9 | import torch.nn as nn 10 | import model.module.AMIPS as AMIPS 11 | import model.module.DCR as DCR 12 | import model.module.manifold_learning as SPD_net 13 | import model.module.EucProject as EP 14 | 15 | # import model.module.MPA_Lya as MPA 16 | 17 | 18 | class AMS_M2ESL(nn.Module): 19 | def __init__(self, in_channels, patch_size, class_count, ds_name): 20 | super(AMS_M2ESL, self).__init__() 21 | self.patch_size = patch_size 22 | self.in_channels = in_channels 23 | self.class_count = class_count 24 | self.ds = ds_name 25 | 26 | self.channels_1 = 2 27 | self.inter_num = 2 28 | 29 | # for AMIPS 30 | self.am_ip_sampling = AMIPS.AM_IPS(self.ds) 31 | 32 | # for DC-DCR 33 | 
self.spe_spa_corr_mine = DCR.Spectral_corr_mining(self.in_channels)
34 |         self.dw_deconv_5 = nn.ConvTranspose2d(self.in_channels, self.in_channels, kernel_size=5, stride=1,
35 |                                               padding=5 // 2, groups=self.in_channels)
36 |         self.dw_conv_5 = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=5, stride=1, padding=5 // 2,
37 |                                    groups=self.in_channels)
38 |
39 |         # for SPD manifold subspace learning
40 |         '''
41 |         the implementation of deep manifold learning is mainly based on the source code of RBN,
42 |         i.e., A Riemannian Network for SPD Matrix Learning, NeurIPS 2019
43 |         '''
44 |         self.bit_map = SPD_net.BiMap(self.channels_1, self.channels_1, self.in_channels, self.in_channels)
45 |         self.re_eig = SPD_net.ReEig()
46 |
47 |         # BN test
48 |         self.bn_spd = SPD_net.BatchNormSPD(self.in_channels)
49 |
50 |         # for Euclidean projection
51 |         self.app_mat_sqrt = EP.ASQRT_autograd_mc(norm_type='Frob_n', num_iter=2)
52 |
53 |         # ASQRT test
54 |         # self.log_eig=SPD_net.LogEig()
55 |         # self.sqrt_eig=SPD_net.SqmEig()
56 |         # self.sqrt_MPA_Lya = MPA.MPA_Lya.apply
57 |
58 |         # for Euclidean subspace learning
59 |         self.bn = nn.BatchNorm2d(self.channels_1)
60 |         self.flatten = nn.Flatten(1)
61 |
62 |         self.sigmoid = nn.Sigmoid()
63 |         self.dropout = nn.Dropout(p=0.2)
64 |
65 |         if self.ds == 'IP':
66 |             self.fc_0 = nn.Linear(7200, 512)
67 |         elif self.ds == 'UP':
68 |             self.fc_0 = nn.Linear(1922, 512)
69 |         elif self.ds == 'UH_tif':
70 |             self.fc_0 = nn.Linear(3872, 512)
71 |
72 |         self.fc_1 = nn.Linear(512, 128)
73 |         self.fc_2 = nn.Linear(128, class_count)
74 |
75 |     def forward(self, x):
76 |         x = x.permute(0, 3, 1, 2)
77 |
78 |         # AMIPS
79 |         x_sampled = self.am_ip_sampling(x)
80 |
81 |         # DC-DCR
82 |         x_channel_1 = self.spe_spa_corr_mine(x_sampled)
83 |
84 |         x_deC = self.dw_deconv_5(x_sampled)
85 |         x_deC_C = self.dw_conv_5(x_deC)
86 |         x_channel_2 = self.spe_spa_corr_mine(x_deC_C)
87 |
88 |         x_channel_1 = torch.unsqueeze(x_channel_1, dim=1)
89 |         x_channel_2 = torch.unsqueeze(x_channel_2, dim=1)
90 |         a_0 = torch.cat((x_channel_1, x_channel_2), dim=1)
91 |
92 |         # M2ESL
93 |         a_1 = self.bit_map(a_0)
94 |         # a_1=self.bn_spd(a_1)
95 |         a_2 = self.re_eig(a_1.cpu())
96 |
97 |         a_2_proj = self.app_mat_sqrt(a_2.cuda())
98 |
99 |         # a_2_proj=self.log_eig(a_2.cpu())
100 |         # a_2_proj=self.sqrt_eig(a_2.cpu())
101 |         # a_2_proj = self._sqrt_mpa_c2(a_2.cuda())
102 |
103 |         a_2_2 = self.bn(a_2_proj)
104 |         a_3 = self.flatten(a_2_2)
105 |
106 |         a_3_2 = self.fc_0(a_3)
107 |         a_4 = self.fc_1(a_3_2)
108 |         a_4_2 = self.sigmoid(a_4)
109 |         a_4_3 = self.dropout(a_4_2)
110 |
111 |         out = self.fc_2(a_4_3)
112 |         return out
113 |
114 |     def _sqrt_mpa_c2(self, x):
115 |         x_channel_0, x_channel_1 = x[:, 0], x[:, 1]
116 |         x_channel_0_sqrt, x_channel_1_sqrt = self.sqrt_MPA_Lya(x_channel_0), self.sqrt_MPA_Lya(x_channel_1)
117 |         x_channel_0_sqrt, x_channel_1_sqrt = torch.unsqueeze(x_channel_0_sqrt, dim=1), torch.unsqueeze(x_channel_1_sqrt,
118 |                                                                                                        dim=1)
119 |         out = torch.cat((x_channel_0_sqrt, x_channel_1_sqrt), dim=1)
120 |         return out
121 |
122 |
123 | def AMS_M2ESL_(in_channels, patch_size, num_classes, ds):
124 |     model = AMS_M2ESL(in_channels, patch_size, num_classes, ds)
125 |     return model
126 |
-------------------------------------------------------------------------------- /model/module/MPA_Lya.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Mingsong Li (lms-07)
3 | # @Time : 2023-Apr
4 | # @Address : Time Lab @ SDU
5 | # @FileName : MPA_Lya.py
6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS
7 |
8 | '''
9 | MPA-Lya, an efficient method for differentiable matrix square root,
10 | Fast Differentiable Matrix Square Root, ICLR 2022
11 | https://github.com/KingJamesSong/FastDifferentiableMatSqrt/blob/main/torch_utils.py
12 | '''
13 |
14 | from mpmath import *
15 | import numpy as np
16 | import torch
17 |
18 | mp.dps = 32
19 | one = mpf(1)
20 | mp.pretty = True
21 |
22 |
23 | def f(x):
24 |     return sqrt(one - x)
25 |
26 |
27 | # Derive the Taylor and Padé coefficients for MTP, MPA
28 | a = taylor(f, 0, 10)
29 | pade_p, pade_q = pade(a, 5, 5)
30 | a = torch.from_numpy(np.array(a).astype(float))
31 | pade_p = torch.from_numpy(np.array(pade_p).astype(float))
32 | pade_q = torch.from_numpy(np.array(pade_q).astype(float))
33 |
34 |
35 | def matrix_taylor_polynomial(p, I):
36 |     p_sqrt = I
37 |     p_app = I - p
38 |     p_hat = p_app
39 |     for i in range(10):
40 |         p_sqrt += a[i + 1] * p_hat
41 |         p_hat = p_hat.bmm(p_app)
42 |     return p_sqrt
43 |
44 |
45 | def matrix_pade_approximant(p, I):
46 |     p_sqrt = pade_p[0] * I
47 |     q_sqrt = pade_q[0] * I
48 |     p_app = I - p
49 |     p_hat = p_app
50 |     for i in range(5):
51 |         p_sqrt += pade_p[i + 1] * p_hat
52 |         q_sqrt += pade_q[i + 1] * p_hat
53 |         p_hat = p_hat.bmm(p_app)
54 |     # There are 4 options to compute the MPA: compute a Matrix Inverse or solve a Matrix Linear System, on CPU or GPU;
55 |     # It seems that a single matrix is faster on CPU and batched matrices are faster on GPU
56 |     # Please check which one is faster before running the code;
57 |     # out=torch.linalg.solve(q_sqrt.cpu(), p_sqrt.cpu()).cuda()
58 |     # return out
59 |
60 |     # return torch.linalg.solve(q_sqrt, p_sqrt)
61 |     return torch.linalg.solve(q_sqrt.cpu(), p_sqrt.cpu()).cuda()  # result cpu to cuda
62 |     # return torch.linalg.inv(q_sqrt).bmm(p_sqrt)
63 |     # return torch.linalg.inv(q_sqrt.cpu()).cuda().bmm(p_sqrt)
64 |
65 |
66 | def matrix_pade_approximant_inverse(p, I):
67 |     p_sqrt = pade_p[0] * I
68 |     q_sqrt = pade_q[0] * I
69 |     p_app = I - p
70 |     p_hat = p_app
71 |     for i in range(5):
72 |         p_sqrt += pade_p[i + 1] * p_hat
73 |         q_sqrt += pade_q[i + 1] * p_hat
74 |         p_hat = p_hat.bmm(p_app)
75 |     # There are 4 options to compute the MPA_inverse: compute a Matrix Inverse or solve a Matrix Linear System, on CPU or GPU;
76 |     # It seems that a single matrix is faster on CPU and batched matrices are faster on GPU
77 |     # Please check which one is faster before running the code;
78 |     # return torch.linalg.solve(p_sqrt, q_sqrt)
79 |     return torch.linalg.solve(p_sqrt.cpu(), q_sqrt.cpu()).cuda()
80 |     # return torch.linalg.inv(p_sqrt).mm(q_sqrt)
81 |     # return torch.linalg.inv(p_sqrt.cpu()).cuda().bmm(q_sqrt)
82 |
83 |
84 | # Differentiable Matrix Square Root by MPA_Lya
85 | class MPA_Lya(torch.autograd.Function):
86 |     @staticmethod
87 |     def forward(ctx, M):
88 |         normM = torch.norm(M, dim=[1, 2]).reshape(M.size(0), 1, 1)
89 |         I = torch.eye(M.size(1), requires_grad=False, device=M.device).reshape(1, M.size(1), M.size(1)).repeat(
90 |             M.size(0), 1, 1)
91 |         # M_sqrt = matrix_taylor_polynomial(M/normM,I)
92 |         M_sqrt = matrix_pade_approximant(M / normM, I)
93 |         M_sqrt = M_sqrt * torch.sqrt(normM)
94 |         ctx.save_for_backward(M, M_sqrt, normM, I)
95 |         return M_sqrt
96 |
97 |     @staticmethod
98 |     def backward(ctx, grad_output):
99 |         M, M_sqrt, normM, I = ctx.saved_tensors
100 |         b = M_sqrt / torch.sqrt(normM)
101 |         c = grad_output / torch.sqrt(normM)
102 |         for i in range(8):
103 |             # In case you might terminate the iteration by checking convergence
104 |             # if torch.norm(b-I)<1e-4:
105 |             #     break
106 |             b_2 = b.bmm(b)
107 |             c = 0.5 * (c.bmm(3.0 * I - b_2) - b_2.bmm(c) + b.bmm(c).bmm(b))
108 |             b = 0.5 * b.bmm(3.0 * I - b_2)
109 |         grad_input = 0.5 * c
110 |         return grad_input
111 |
112 |
113 | # Differentiable Inverse Square Root by MPA_Lya_Inv
114 | class MPA_Lya_Inv(torch.autograd.Function):
115 |     @staticmethod
116 |     def forward(ctx, M):
117 |         normM = torch.norm(M, dim=[1, 2]).reshape(M.size(0), 1, 1)
118 |         I = torch.eye(M.size(1), requires_grad=False, device=M.device).reshape(1, M.size(1), M.size(1)).repeat(
119 |             M.size(0), 1, 1)
120 |         # M_sqrt = matrix_taylor_polynomial(M/normM,I)
121 |         M_sqrt_inv = matrix_pade_approximant_inverse(M / normM, I)
122 |         M_sqrt_inv = M_sqrt_inv / torch.sqrt(normM)
123 |         ctx.save_for_backward(M, M_sqrt_inv, I)
124 |         return M_sqrt_inv
125 |
126 |     @staticmethod
127 |     def backward(ctx, grad_output):
128 |         M, M_sqrt_inv, I = ctx.saved_tensors
129 |         M_inv = M_sqrt_inv.bmm(M_sqrt_inv)
130 |         grad_lya = - M_inv.bmm(grad_output).bmm(M_inv)
131 |         norm_sqrt_inv = torch.norm(M_sqrt_inv)
132 |         b = M_sqrt_inv / norm_sqrt_inv
133 |         c = grad_lya / norm_sqrt_inv
134 |         for i in range(8):
135 |             # In case you might terminate the iteration by checking convergence
136 |             # if torch.norm(b-I)<1e-4:
137 |             #     break
138 |             b_2 = b.bmm(b)
139 |             c = 0.5 * (c.bmm(3.0 * I - b_2) - b_2.bmm(c) + b.bmm(c).bmm(b))
140 |             b = 0.5 * b.bmm(3.0 * I - b_2)
141 |         grad_input = 0.5 * c
142 |         return grad_input
143 |
-------------------------------------------------------------------------------- /c_model/SSFTT.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Mingsong Li (lms-07)
3 | # @Time : 2023-Apr
4 | # @Address : Time Lab @ SDU
5 | # @FileName : SSFTT.py
6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS
7 |
8 | # https://github.com/zgr6010/HSI_SSFTT/blob/main/cls_SSFTT_IP/SSFTTnet.py
9 | # Spectral–Spatial Feature Tokenization Transformer for Hyperspectral Image Classification, TGRS 2022
10 |
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from einops import rearrange
15 | from torch import nn
16 | import torch.nn.init as init
17 |
18 |
19 |
20 | def _weights_init(m):
21 |     classname = m.__class__.__name__
22 |     # print(classname)
23 |     if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d):
24 |         init.kaiming_normal_(m.weight)
25 |
26 | class Residual(nn.Module):
27 |     def __init__(self, fn):
28 |         super().__init__()
29 |         self.fn = fn
30 |
31 |     def forward(self, x, **kwargs):
32 |         return self.fn(x, **kwargs) + x
33 |
34 | # equivalent to PreNorm
35 | class LayerNormalize(nn.Module):
36 |     def __init__(self, dim, fn):
37 |         super().__init__()
38 |         self.norm = nn.LayerNorm(dim)
39 |         self.fn = fn
40 |
41 |     def forward(self, x, **kwargs):
42 |         return self.fn(self.norm(x), **kwargs)
43 |
44 | # equivalent to FeedForward
45 | class MLP_Block(nn.Module):
46 |     def __init__(self, dim, hidden_dim, dropout=0.1):
47 |         super().__init__()
48 |         self.net = nn.Sequential(
49 |             nn.Linear(dim, hidden_dim),
50 |             nn.GELU(),
51 |             nn.Dropout(dropout),
52 |             nn.Linear(hidden_dim, dim),
53 |             nn.Dropout(dropout)
54 |         )
55 |
56 |     def forward(self, x):
57 |         return self.net(x)
58 |
59 |
60 | class Attention(nn.Module):
61 |
62 |     def __init__(self, dim, heads=8, dropout=0.1):
63 |         super().__init__()
64 |         self.heads = heads
65 |         self.scale = dim ** -0.5  # 1/sqrt(dim)
66 |
67 |         self.to_qkv = nn.Linear(dim, dim * 3, bias=True)  # Wq,Wk,Wv for each vector, that's why *3
68 |         # torch.nn.init.xavier_uniform_(self.to_qkv.weight)
69 |         # torch.nn.init.zeros_(self.to_qkv.bias)
70 |
71 |         self.nn1 = nn.Linear(dim, dim)
72 |         # 
torch.nn.init.xavier_uniform_(self.nn1.weight) 73 | # torch.nn.init.zeros_(self.nn1.bias) 74 | self.do1 = nn.Dropout(dropout) 75 | 76 | def forward(self, x, mask=None): 77 | 78 | b, n, _, h = *x.shape, self.heads 79 | qkv = self.to_qkv(x).chunk(3, dim = -1) # gets q = Q = Wq matmul x1, k = Wk mm x2, v = Wv mm x3 80 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) # split into multi head attentions 81 | 82 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 83 | mask_value = -torch.finfo(dots.dtype).max 84 | 85 | if mask is not None: 86 | mask = F.pad(mask.flatten(1), (1, 0), value=True) 87 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 88 | mask = mask[:, None, :] * mask[:, :, None] 89 | dots.masked_fill_(~mask, float('-inf')) 90 | del mask 91 | 92 | attn = dots.softmax(dim=-1) # follow the softmax,q,d,v equation in the paper 93 | 94 | out = torch.einsum('bhij,bhjd->bhid', attn, v) # product of v times whatever inside softmax 95 | out = rearrange(out, 'b h n d -> b n (h d)') # concat heads into one matrix, ready for next encoder block 96 | out = self.nn1(out) 97 | out = self.do1(out) 98 | return out 99 | 100 | 101 | class Transformer(nn.Module): 102 | def __init__(self, dim, depth, heads, mlp_dim, dropout): 103 | super().__init__() 104 | self.layers = nn.ModuleList([]) 105 | for _ in range(depth): 106 | self.layers.append(nn.ModuleList([ 107 | Residual(LayerNormalize(dim, Attention(dim, heads=heads, dropout=dropout))), 108 | Residual(LayerNormalize(dim, MLP_Block(dim, mlp_dim, dropout=dropout))) 109 | ])) 110 | 111 | def forward(self, x, mask=None): 112 | for attention, mlp in self.layers: 113 | x = attention(x, mask=mask) # go to attention 114 | x = mlp(x) # go to MLP_Block 115 | return x 116 | 117 | # NUM_CLASS = 16 118 | 119 | class SSFTTnet(nn.Module): 120 | def __init__(self, in_channels=1, num_classes=16, num_tokens=4, dim=64, depth=1, heads=8, mlp_dim=8, dropout=0.1, emb_dropout=0.1): 121 | super(SSFTTnet, self).__init__() 122 | self.class_count=num_classes 123 | 124 | self.L = num_tokens 125 | self.cT = dim 126 | self.conv3d_features = nn.Sequential( 127 | nn.Conv3d(in_channels, out_channels=8, kernel_size=(3, 3, 3)), 128 | nn.BatchNorm3d(8), 129 | nn.ReLU(), 130 | ) 131 | 132 | self.conv2d_features = nn.Sequential( 133 | nn.Conv2d(in_channels=8*28, out_channels=64, kernel_size=(3, 3)), 134 | nn.BatchNorm2d(64), 135 | nn.ReLU(), 136 | ) 137 | 138 | # Tokenization 139 | self.token_wA = nn.Parameter(torch.empty(1, self.L, 64), 140 | requires_grad=True) # Tokenization parameters 141 | torch.nn.init.xavier_normal_(self.token_wA) 142 | self.token_wV = nn.Parameter(torch.empty(1, 64, self.cT), 143 | requires_grad=True) # Tokenization parameters 144 | torch.nn.init.xavier_normal_(self.token_wV) 145 | 146 | self.pos_embedding = nn.Parameter(torch.empty(1, (num_tokens + 1), dim)) 147 | torch.nn.init.normal_(self.pos_embedding, std=.02) 148 | 149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, dim)) 150 | self.dropout = nn.Dropout(emb_dropout) 151 | 152 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout) 153 | 154 | self.to_cls_token = nn.Identity() 155 | 156 | self.nn1 = nn.Linear(dim, num_classes) 157 | torch.nn.init.xavier_uniform_(self.nn1.weight) 158 | torch.nn.init.normal_(self.nn1.bias, std=1e-6) 159 | 160 | def forward(self, x, mask=None): 161 | 162 | x=x.permute(0,1,4,3,2) 163 | 164 | x = self.conv3d_features(x) 165 | x = rearrange(x, 'b c h w y -> b (c h) w y') 166 | x = self.conv2d_features(x) 167 | x 
= rearrange(x,'b c h w -> b (h w) c') 168 | 169 | wa = rearrange(self.token_wA, 'b h w -> b w h') # Transpose 170 | A = torch.einsum('bij,bjk->bik', x, wa) 171 | A = rearrange(A, 'b h w -> b w h') # Transpose 172 | A = A.softmax(dim=-1) 173 | 174 | VV = torch.einsum('bij,bjk->bik', x, self.token_wV) 175 | T = torch.einsum('bij,bjk->bik', A, VV) 176 | 177 | cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 178 | x = torch.cat((cls_tokens, T), dim=1) 179 | x += self.pos_embedding 180 | x = self.dropout(x) 181 | x = self.transformer(x, mask) # main game 182 | x = self.to_cls_token(x[:, 0]) 183 | x = self.nn1(x) 184 | 185 | return x -------------------------------------------------------------------------------- /process_cls_disjoint_c_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : process_cls_disjoint_c_model.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | # for the UH data set, main processing file for SVM-RBF, 9 | 10 | import os 11 | import time 12 | import torch 13 | import random 14 | import numpy as np 15 | 16 | from sklearn import metrics 17 | from sklearn.svm import SVC 18 | from sklearn.ensemble import RandomForestClassifier 19 | 20 | import utils.evaluation as evaluation 21 | import utils.data_load_operate as data_load_operate 22 | import visual.cls_visual as cls_visual 23 | 24 | # import model.CNN1D as CNN1D 25 | 26 | time_current = time.strftime("%y-%m-%d-%H.%M", time.localtime()) 27 | 28 | # random seed setting 29 | seed = 20 30 | 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | np.random.seed(seed) # Numpy module. 35 | random.seed(seed) # Python random module. 
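# note: the torch.manual_seed(seed) call below repeats the one a few lines above; for stricter
# determinism, recent PyTorch versions additionally provide torch.use_deterministic_algorithms(True)
# on top of the cudnn flags set below.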
36 | torch.manual_seed(seed) 37 | torch.backends.cudnn.benchmark = False 38 | torch.backends.cudnn.deterministic = True 39 | 40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 41 | 42 | ######## 0 # 43 | model_list = ['SVM'] 44 | model_flag = 0 45 | 46 | data_set_name_list = ['UH_tif'] 47 | data_set_name = data_set_name_list[0] 48 | 49 | data_set_path = os.path.join(os.getcwd(), 'data') 50 | 51 | # seed_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 52 | # seed_list=[0,1,2,3,4] 53 | # seed_list=[0,1,2] 54 | # seed_list=[0,1] 55 | seed_list = [0] 56 | 57 | ratio = "hu13" 58 | 59 | results_save_path = \ 60 | os.path.join(os.path.join(os.getcwd(), 'output/results'), model_list[model_flag] + str("_") + 61 | data_set_name + str("_") + str(time_current) + str("_seed") + str(seed)) + str("_ratio") + str(ratio) 62 | cls_map_save_path = \ 63 | os.path.join(os.path.join(os.getcwd(), 'output/cls_maps'), model_list[model_flag] + str("_") + 64 | data_set_name + str("_") + str(time_current) + str("_seed") + str(seed)) + str("_ratio") + str(ratio) 65 | 66 | if __name__ == '__main__': 67 | 68 | data, gt_train, gt_test = data_load_operate.load_HU_data(data_set_path) 69 | data = data_load_operate.standardization(data) 70 | 71 | gt_train_re = gt_train.reshape(-1) 72 | gt_test_re = gt_test.reshape(-1) 73 | height, width, channels = data.shape 74 | class_count = max(np.unique(gt_train_re)) 75 | 76 | OA_ALL = [] 77 | AA_ALL = [] 78 | KPP_ALL = [] 79 | EACH_ACC_ALL = [] 80 | Train_Time_ALL = [] 81 | Test_Time_ALL = [] 82 | CLASS_ACC = np.zeros([len(seed_list), class_count]) 83 | 84 | data_reshape = data.reshape(data.shape[0] * data.shape[1], -1) 85 | for curr_seed in seed_list: 86 | tic1 = time.perf_counter() 87 | train_data_index, test_data_index, all_data_index = data_load_operate.sampling_disjoint(gt_train_re, 88 | gt_test_re, 89 | class_count) 90 | index = (train_data_index, test_data_index, all_data_index) 91 | x_train, y_train, x_test, y_gt = data_load_operate.generate_data_set_disjoint(data_reshape, gt_train_re, 92 | gt_test_re, index) 93 | 94 | if model_flag == 0: 95 | clf = SVC(kernel='rbf', gamma='scale', C=20, tol=1e-5, random_state=10).fit(x_train, y_train) 96 | 97 | toc1 = time.perf_counter() 98 | training_time = toc1 - tic1 99 | Train_Time_ALL.append(training_time) 100 | 101 | tic2 = time.perf_counter() 102 | pred_test = clf.predict(x_test) 103 | toc2 = time.perf_counter() 104 | 105 | testing_time = toc2 - tic2 106 | Test_Time_ALL.append(testing_time) 107 | 108 | y_gt = gt_test_re[test_data_index] - 1 109 | OA = metrics.accuracy_score(y_gt, pred_test) 110 | confusion_matrix = metrics.confusion_matrix(pred_test, y_gt) 111 | print("confusion_matrix\n{}".format(confusion_matrix)) 112 | ECA, AA = evaluation.AA_ECA(confusion_matrix) 113 | kappa = metrics.cohen_kappa_score(pred_test, y_gt) 114 | cls_report = evaluation.claification_report(y_gt, pred_test, data_set_name) 115 | print("classification_report\n{}".format(cls_report)) 116 | 117 | # Visualization for all the labeled samples and total the samples 118 | # total_pred=clf.predict(data_reshape) 119 | # sample_list1=[total_pred+1] 120 | 121 | # all_pred=clf.predict(x_all) 122 | # sample_list2=[all_pred+1,all_data_index] 123 | 124 | # cls_visual.gt_cls_map(gt,cls_map_save_path) 125 | # cls_visual.pred_cls_map_cls(sample_list1,gt,cls_map_save_path) 126 | # cls_visual.pred_cls_map_cls(sample_list2,gt,cls_map_save_path) 127 | 128 | # Output infors 129 | f = open(results_save_path + '_results.txt', 'a+') 130 | str_results = 
'\n======================' \ 131 | + "\nOA=" + str(OA) \ 132 | + "\nAA=" + str(AA) \ 133 | + '\nkpp=' + str(kappa) \ 134 | + '\nacc per class:' + str(ECA) \ 135 | + "\ntrain time:" + str(training_time) \ 136 | + "\ntest time:" + str(testing_time) + "\n" 137 | 138 | f.write(str_results) 139 | f.write('{}'.format(confusion_matrix)) 140 | f.write('\n\n') 141 | f.write('{}'.format(cls_report)) 142 | f.close() 143 | 144 | OA_ALL.append(OA) 145 | AA_ALL.append(AA) 146 | KPP_ALL.append(kappa) 147 | EACH_ACC_ALL.append(ECA) 148 | 149 | OA_ALL = np.array(OA_ALL) 150 | AA_ALL = np.array(AA_ALL) 151 | KPP_ALL = np.array(KPP_ALL) 152 | EACH_ACC_ALL = np.array(EACH_ACC_ALL) 153 | Train_Time_ALL = np.array(Train_Time_ALL) 154 | Test_Time_ALL = np.array(Test_Time_ALL) 155 | 156 | np.set_printoptions(precision=4) 157 | print("\n====================Mean result of {} times runs ==========================".format(len(seed_list))) 158 | print('List of OA:', list(OA_ALL)) 159 | print('List of AA:', list(AA_ALL)) 160 | print('List of KPP:', list(KPP_ALL)) 161 | print('OA=', round(np.mean(OA_ALL) * 100, 2), '+-', round(np.std(OA_ALL) * 100, 2)) 162 | print('AA=', round(np.mean(AA_ALL) * 100, 2), '+-', round(np.std(AA_ALL) * 100, 2)) 163 | print('Kpp=', round(np.mean(KPP_ALL) * 100, 2), '+-', round(np.std(KPP_ALL) * 100, 2)) 164 | print('Acc per class=', np.mean(EACH_ACC_ALL, 0), '+-', np.std(EACH_ACC_ALL, 0)) 165 | 166 | print("Average training time=", round(np.mean(Train_Time_ALL), 2), '+-', round(np.std(Train_Time_ALL), 3)) 167 | print("Average testing time=", round(np.mean(Test_Time_ALL), 5), '+-', round(np.std(Test_Time_ALL), 5)) 168 | 169 | # Output infors 170 | f = open(results_save_path + '_results.txt', 'a+') 171 | str_results = '\n\n***************Mean result of ' + str(len(seed_list)) + 'times runs ********************' \ 172 | + '\nList of OA:' + str(list(OA_ALL)) \ 173 | + '\nList of AA:' + str(list(AA_ALL)) \ 174 | + '\nList of KPP:' + str(list(KPP_ALL)) \ 175 | + '\nOA=' + str(round(np.mean(OA_ALL) * 100, 2)) + '+-' + str(round(np.std(OA_ALL) * 100, 2)) \ 176 | + '\nAA=' + str(round(np.mean(AA_ALL) * 100, 2)) + '+-' + str(round(np.std(AA_ALL) * 100, 2)) \ 177 | + '\nKpp=' + str(round(np.mean(KPP_ALL) * 100, 2)) + '+-' + str(round(np.std(KPP_ALL) * 100, 2)) \ 178 | + '\nAcc per class=' + str(np.mean(EACH_ACC_ALL, 0)) + '+-' + str(np.std(EACH_ACC_ALL, 0)) \ 179 | + "\nAverage training time=" + str(round(np.mean(Train_Time_ALL), 2)) + '+-' + str( 180 | round(np.std(Train_Time_ALL), 3)) \ 181 | + "\nAverage testing time=" + str(round(np.mean(Test_Time_ALL), 5)) + '+-' + str( 182 | round(np.std(Test_Time_ALL), 5)) 183 | f.write(str_results) 184 | f.close() 185 | -------------------------------------------------------------------------------- /process_cls_c_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : process_cls_c_model.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | # # for IP and UP data sets, main processing file for SVM-RBF 9 | 10 | import os 11 | import time 12 | import torch 13 | import random 14 | import numpy as np 15 | 16 | from sklearn import metrics 17 | from sklearn.svm import SVC 18 | from sklearn.ensemble import RandomForestClassifier 19 | 20 | import utils.evaluation as evaluation 21 | import utils.data_load_operate as data_load_operate 22 | import visual.cls_visual as cls_visual 23 | 24 | 
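# pipeline sketch (as implemented below): load the .mat cube -> per-feature standardization ->
# per-class random sampling by ratio -> fit SVM-RBF -> predict on the test split ->
# OA / AA / kappa / per-class accuracies -> append a text report under output/results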
time_current = time.strftime("%y-%m-%d-%H.%M", time.localtime()) 25 | 26 | # random seed setting 27 | seed = 20 28 | 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | np.random.seed(seed) # Numpy module. 33 | random.seed(seed) # Python random module. 34 | torch.manual_seed(seed) 35 | torch.backends.cudnn.benchmark = False 36 | torch.backends.cudnn.deterministic = True 37 | 38 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 39 | 40 | ######## 0 41 | model_list = ['SVM'] 42 | model_flag = 0 43 | 44 | data_set_name_list = ['IP', 'UP', 'KSC', 'HU_tif'] 45 | data_set_name = data_set_name_list[1] 46 | 47 | data_set_path = os.path.join(os.getcwd(), 'data') 48 | 49 | # seed_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 50 | # seed_list=[0,1,2,3,4] 51 | # seed_list=[0,1,2] 52 | # seed_list=[0,1] 53 | seed_list = [0] 54 | 55 | # data set split 56 | flag_list = [0, 1] # ratio or num 57 | 58 | if data_set_name == 'IP': 59 | ratio_list = [0.05, 0.005] 60 | ratio = 5.0 61 | elif data_set_name == 'UP': 62 | ratio_list = [0.01, 0.001] 63 | ratio = 1.0 64 | 65 | num_list = [50, 0] # [train_num,val_num] 66 | 67 | results_save_path = \ 68 | os.path.join(os.path.join(os.getcwd(), 'output/results'), model_list[model_flag] + str("_") + 69 | data_set_name + str("_") + str(time_current) + str("_seed_") + str(seed) + str("_ratio_") + str(ratio)) 70 | cls_map_save_path = \ 71 | os.path.join(os.path.join(os.getcwd(), 'output/cls_maps'), model_list[model_flag] + str("_") + 72 | data_set_name + str("_") + str(time_current) + str("_seed_") + str(seed) + str("_ratio_") + str(ratio)) 73 | 74 | if __name__ == '__main__': 75 | 76 | data, gt = data_load_operate.load_data(data_set_name, data_set_path) 77 | data = data_load_operate.standardization(data) 78 | 79 | gt_reshape = gt.reshape(-1) 80 | height, width, channels = data.shape 81 | class_count = max(np.unique(gt)) 82 | 83 | OA_ALL = [] 84 | AA_ALL = [] 85 | KPP_ALL = [] 86 | EACH_ACC_ALL = [] 87 | Train_Time_ALL = [] 88 | Test_Time_ALL = [] 89 | CLASS_ACC = np.zeros([len(seed_list), class_count]) 90 | 91 | data_reshape = data.reshape(data.shape[0] * data.shape[1], -1) 92 | for curr_seed in seed_list: 93 | tic1 = time.perf_counter() 94 | 95 | train_data_index, test_data_index, all_data_index = data_load_operate.sampling(ratio_list, num_list, 96 | gt_reshape, 97 | class_count, flag_list[0]) 98 | index = (train_data_index, test_data_index, all_data_index) 99 | x_train, y_train, x_test, y_gt, x_all, y_all = data_load_operate.generate_data_set(data_reshape, gt_reshape, 100 | index) 101 | 102 | if model_flag == 0: 103 | clf = SVC(kernel='rbf', gamma='scale', C=20, tol=1e-5, random_state=10).fit(x_train, y_train) 104 | 105 | toc1 = time.perf_counter() 106 | training_time = toc1 - tic1 107 | Train_Time_ALL.append(training_time) 108 | 109 | tic2 = time.perf_counter() 110 | pred_test = clf.predict(x_test) 111 | toc2 = time.perf_counter() 112 | 113 | testing_time = toc2 - tic2 114 | Test_Time_ALL.append(testing_time) 115 | 116 | y_gt = gt_reshape[test_data_index] - 1 117 | OA = metrics.accuracy_score(y_gt, pred_test) 118 | confusion_matrix = metrics.confusion_matrix(pred_test, y_gt) 119 | print("confusion_matrix\n{}".format(confusion_matrix)) 120 | ECA, AA = evaluation.AA_ECA(confusion_matrix) 121 | kappa = metrics.cohen_kappa_score(pred_test, y_gt) 122 | cls_report = evaluation.claification_report(y_gt, pred_test, data_set_name) 123 | print("classification_report\n{}".format(cls_report)) 124 | 125 | # 
Visualization for all the labeled samples and total the samples 126 | # total_pred = clf.predict(data_reshape) 127 | # sample_list1 = [total_pred + 1] 128 | 129 | # all_pred=clf.predict(x_all) 130 | # sample_list2=[all_pred+1,all_data_index] 131 | 132 | # cls_visual.gt_cls_map(gt,cls_map_save_path) 133 | # cls_visual.pred_cls_map_cls(sample_list1,gt,cls_map_save_path) 134 | # cls_visual.pred_cls_map_cls(sample_list2,gt,cls_map_save_path) 135 | 136 | # Output infors 137 | f = open(results_save_path + '_results.txt', 'a+') 138 | str_results = '\n======================' \ 139 | + "\nOA=" + str(OA) \ 140 | + "\nAA=" + str(AA) \ 141 | + '\nkpp=' + str(kappa) \ 142 | + '\nacc per class:' + str(ECA) \ 143 | + "\ntrain time:" + str(training_time) \ 144 | + "\ntest time:" + str(testing_time) + "\n" 145 | 146 | f.write(str_results) 147 | f.write('{}'.format(confusion_matrix)) 148 | f.write('\n\n') 149 | f.write('{}'.format(cls_report)) 150 | f.close() 151 | 152 | OA_ALL.append(OA) 153 | AA_ALL.append(AA) 154 | KPP_ALL.append(kappa) 155 | EACH_ACC_ALL.append(ECA) 156 | 157 | OA_ALL = np.array(OA_ALL) 158 | AA_ALL = np.array(AA_ALL) 159 | KPP_ALL = np.array(KPP_ALL) 160 | EACH_ACC_ALL = np.array(EACH_ACC_ALL) 161 | Train_Time_ALL = np.array(Train_Time_ALL) 162 | Test_Time_ALL = np.array(Test_Time_ALL) 163 | 164 | np.set_printoptions(precision=4) 165 | print("\n====================Mean result of {} times runs ==========================".format(len(seed_list))) 166 | print('List of OA:', list(OA_ALL)) 167 | print('List of AA:', list(AA_ALL)) 168 | print('List of KPP:', list(KPP_ALL)) 169 | print('OA=', round(np.mean(OA_ALL) * 100, 2), '+-', round(np.std(OA_ALL) * 100, 2)) 170 | print('AA=', round(np.mean(AA_ALL) * 100, 2), '+-', round(np.std(AA_ALL) * 100, 2)) 171 | print('Kpp=', round(np.mean(KPP_ALL) * 100, 2), '+-', round(np.std(KPP_ALL) * 100, 2)) 172 | print('Acc per class=', np.mean(EACH_ACC_ALL, 0), '+-', np.std(EACH_ACC_ALL, 0)) 173 | 174 | print("Average training time=", round(np.mean(Train_Time_ALL), 2), '+-', round(np.std(Train_Time_ALL), 3)) 175 | print("Average testing time=", round(np.mean(Test_Time_ALL), 5), '+-', round(np.std(Test_Time_ALL), 5)) 176 | 177 | # Output infors 178 | f = open(results_save_path + '_results.txt', 'a+') 179 | str_results = '\n\n***************Mean result of ' + str(len(seed_list)) + 'times runs ********************' \ 180 | + '\nList of OA:' + str(list(OA_ALL)) \ 181 | + '\nList of AA:' + str(list(AA_ALL)) \ 182 | + '\nList of KPP:' + str(list(KPP_ALL)) \ 183 | + '\nOA=' + str(round(np.mean(OA_ALL) * 100, 2)) + '+-' + str(round(np.std(OA_ALL) * 100, 2)) \ 184 | + '\nAA=' + str(round(np.mean(AA_ALL) * 100, 2)) + '+-' + str(round(np.std(AA_ALL) * 100, 2)) \ 185 | + '\nKpp=' + str(round(np.mean(KPP_ALL) * 100, 2)) + '+-' + str(round(np.std(KPP_ALL) * 100, 2)) \ 186 | + '\nAcc per class=' + str(np.mean(EACH_ACC_ALL, 0)) + '+-' + str(np.std(EACH_ACC_ALL, 0)) \ 187 | + "\nAverage training time=" + str(round(np.mean(Train_Time_ALL), 2)) + '+-' + str( 188 | round(np.std(Train_Time_ALL), 3)) \ 189 | + "\nAverage testing time=" + str(round(np.mean(Test_Time_ALL), 5)) + '+-' + str( 190 | round(np.std(Test_Time_ALL), 5)) 191 | f.write(str_results) 192 | f.close() 193 | -------------------------------------------------------------------------------- /c_model/SSAtt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # 
@FileName : SSAtt.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | 9 | # https://github.com/weecology/DeepTreeAttention/blob/main/src/models/Hang2020.py 10 | # Hyperspectral Image Classification With Attention-Aided CNNs, TGRS 2020 11 | 12 | 13 | from torch.nn import Module 14 | from torch.nn import functional as F 15 | from torch import nn 16 | import torch 17 | 18 | 19 | def global_spectral_pool(x): 20 | """Helper function to keep the same dimensions after pooling to avoid resizing each time""" 21 | global_pool = torch.mean(x, dim=(2, 3)) 22 | global_pool = global_pool.unsqueeze(-1) 23 | 24 | return global_pool 25 | 26 | 27 | class conv_module(Module): 28 | def __init__(self, in_channels, filters, maxpool_kernel=None): 29 | """Define a simple conv block with batchnorm and optional max pooling""" 30 | super(conv_module, self).__init__() 31 | self.conv_layer = nn.Conv2d(in_channels, out_channels=filters, kernel_size=(3, 3), padding=1) 32 | self.bn1 = nn.BatchNorm2d(filters) 33 | self.maxpool_kernal = maxpool_kernel 34 | if maxpool_kernel: 35 | self.max_pool = nn.MaxPool2d(maxpool_kernel) 36 | 37 | def forward(self, x, pool=False): 38 | # x = x.permute(0,3,1,2) # batch_size num_channels D H W 39 | x = self.conv_layer(x) 40 | x = self.bn1(x) 41 | x = F.relu(x) 42 | if pool: 43 | x = self.max_pool(x) 44 | 45 | return x 46 | 47 | 48 | class vanilla_CNN(Module): 49 | """ 50 | A baseline model without spectral convolutions or spatial/spectral attention 51 | """ 52 | 53 | def __init__(self, bands, classes): 54 | super(vanilla_CNN, self).__init__() 55 | self.conv1 = conv_module(in_channels=bands, filters=32) 56 | self.conv2 = conv_module(in_channels=32, filters=64, maxpool_kernel=(2, 2)) 57 | self.conv3 = conv_module(in_channels=64, filters=128, maxpool_kernel=(2, 2)) 58 | # The size of the fully connected layer Assumes a certain band convo, TODO make this flexible by band number. 
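# Note: conv_module keeps the spatial size (3x3 convs, padding=1) and each
# (2, 2) max-pool floors it, so assuming the 9x9 patches used by this repo's
# main scripts, the feature maps shrink 9 -> 4 -> 2 and the 128 filters
# flatten to 128 * 2 * 2 = 512, which is where in_features=512 below comes from.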
59 | self.fc1 = nn.Linear(in_features=512, out_features=classes) 60 | 61 | def forward(self, x): 62 | """Take an input image and run the conv blocks, flatten the output and return features""" 63 | x = self.conv1(x) 64 | x = self.conv2(x, pool=True) 65 | x = self.conv3(x, pool=True) 66 | x = torch.flatten(x, start_dim=1) 67 | x = self.fc1(x) 68 | 69 | return x 70 | 71 | 72 | class spatial_attention(Module): 73 | """ 74 | Learn cross band spatial features with a set of convolutions and spectral pooling attention layers 75 | """ 76 | 77 | def __init__(self, filters, classes): 78 | super(spatial_attention, self).__init__() 79 | self.channel_pool = nn.Conv2d(in_channels=filters, out_channels=1, kernel_size=1) 80 | 81 | # Weak Attention with adaptive kernel size based on size of incoming feature map 82 | if filters == 32: 83 | kernel_size = 7 84 | pad = 3 85 | elif filters == 64: 86 | kernel_size = 5 87 | pad = 2 88 | elif filters == 128: 89 | kernel_size = 3 90 | pad = 1 91 | else: 92 | raise ValueError( 93 | "Unknown incoming kernel size {} for attention layers") 94 | 95 | self.attention_conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel_size, padding=pad) 96 | self.attention_conv2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel_size, padding=pad) 97 | 98 | # Add a classfication branch with max pool based on size of the layer 99 | if filters == 32: 100 | pool_size = (4, 4) 101 | in_features = 128 102 | elif filters == 64: 103 | in_features = 256 104 | pool_size = (2, 2) 105 | elif filters == 128: 106 | in_features = 512 107 | pool_size = (1, 1) 108 | else: 109 | raise ValueError("Unknown filter size for max pooling") 110 | 111 | self.class_pool = nn.MaxPool2d(pool_size) 112 | self.fc1 = nn.Linear(in_features=in_features, out_features=classes) 113 | 114 | def forward(self, x): 115 | """Calculate attention and class scores for batch""" 116 | # Global pooling and add dimensions to keep the same shape 117 | pooled_features = self.channel_pool(x) 118 | pooled_features = F.relu(pooled_features) 119 | 120 | # Attention layers 121 | attention = self.attention_conv1(pooled_features) 122 | attention = torch.relu(attention) 123 | attention = self.attention_conv2(attention) 124 | attention = torch.sigmoid(attention) 125 | 126 | # Add dummy dimension to make the shapes the same 127 | attention = torch.mul(x, attention) 128 | 129 | # Classification Head 130 | pooled_attention_features = self.class_pool(attention) 131 | pooled_attention_features = torch.flatten(pooled_attention_features, start_dim=1) 132 | class_features = self.fc1(pooled_attention_features) 133 | 134 | return attention, class_features 135 | 136 | 137 | class spectral_attention(Module): 138 | """ 139 | Learn cross band spectral features with a set of convolutions and spectral pooling attention layers 140 | The feature maps should be pooled to remove spatial dimensions before reading in the module 141 | Args: 142 | in_channels: number of feature maps of the current image 143 | """ 144 | 145 | def __init__(self, filters, classes): 146 | super(spectral_attention, self).__init__() 147 | # Weak Attention with adaptive kernel size based on size of incoming feature map 148 | if filters == 32: 149 | kernel_size = 3 150 | pad = 1 151 | elif filters == 64: 152 | kernel_size = 5 153 | pad = 2 154 | elif filters == 128: 155 | kernel_size = 7 156 | pad = 3 157 | else: 158 | raise ValueError( 159 | "Unknown incoming kernel size {} for attention layers") 160 | 161 | self.attention_conv1 = nn.Conv1d(in_channels=filters, 
out_channels=filters, kernel_size=kernel_size, 162 | padding=pad) 163 | self.attention_conv2 = nn.Conv1d(in_channels=filters, out_channels=filters, kernel_size=kernel_size, 164 | padding=pad) 165 | 166 | # TODO Does this pool size change base on in_features? 167 | self.fc1 = nn.Linear(in_features=filters, out_features=classes) 168 | 169 | def forward(self, x): 170 | """Calculate attention and class scores for batch""" 171 | # Global pooling and add dimensions to keep the same shape 172 | pooled_features = global_spectral_pool(x) 173 | 174 | # Attention layers 175 | attention = self.attention_conv1(pooled_features) 176 | attention = torch.relu(attention) 177 | attention = self.attention_conv2(attention) 178 | attention = torch.sigmoid(attention) 179 | 180 | # Add dummy dimension to make the shapes the same 181 | attention = attention.unsqueeze(-1) 182 | attention = torch.mul(x, attention) 183 | 184 | # Classification Head 185 | pooled_attention_features = global_spectral_pool(attention) 186 | pooled_attention_features = torch.flatten(pooled_attention_features, start_dim=1) 187 | class_features = self.fc1(pooled_attention_features) 188 | 189 | return attention, class_features 190 | 191 | 192 | class spatial_network(Module): 193 | """ 194 | Learn spatial features with alternating convolutional and attention pooling layers 195 | """ 196 | 197 | def __init__(self, bands, classes): 198 | super(spatial_network, self).__init__() 199 | 200 | # First submodel is 32 filters 201 | self.conv1 = conv_module(in_channels=bands, filters=32) 202 | self.attention_1 = spatial_attention(filters=32, classes=classes) 203 | 204 | self.conv2 = conv_module(in_channels=32, filters=64, maxpool_kernel=(2, 2)) 205 | self.attention_2 = spatial_attention(filters=64, classes=classes) 206 | 207 | self.conv3 = conv_module(in_channels=64, filters=128, maxpool_kernel=(2, 2)) 208 | self.attention_3 = spatial_attention(filters=128, classes=classes) 209 | 210 | def forward(self, x): 211 | """The forward method is written for training the joint scores of the three attention layers""" 212 | x = self.conv1(x) 213 | x, scores1 = self.attention_1(x) 214 | x = self.conv2(x, pool=True) 215 | x, scores2 = self.attention_2(x) 216 | x = self.conv3(x, pool=True) 217 | x, scores3 = self.attention_3(x) 218 | 219 | return [scores1, scores2, scores3] 220 | 221 | 222 | class spectral_network(Module): 223 | """ 224 | Learn spectral features with alternating convolutional and attention pooling layers 225 | """ 226 | 227 | def __init__(self, bands, classes): 228 | super(spectral_network, self).__init__() 229 | 230 | # First submodel is 32 filters 231 | self.conv1 = conv_module(in_channels=bands, filters=32) 232 | self.attention_1 = spectral_attention(filters=32, classes=classes) 233 | 234 | self.conv2 = conv_module(in_channels=32, filters=64, maxpool_kernel=(2, 2)) 235 | self.attention_2 = spectral_attention(filters=64, classes=classes) 236 | 237 | self.conv3 = conv_module(in_channels=64, filters=128, maxpool_kernel=(2, 2)) 238 | self.attention_3 = spectral_attention(filters=128, classes=classes) 239 | 240 | def forward(self, x): 241 | """The forward method is written for training the joint scores of the three attention layers""" 242 | x = self.conv1(x) 243 | x, scores1 = self.attention_1(x) 244 | x = self.conv2(x, pool=True) 245 | x, scores2 = self.attention_2(x) 246 | x = self.conv3(x, pool=True) 247 | x, scores3 = self.attention_3(x) 248 | 249 | return [scores1, scores2, scores3] 250 | 251 | 252 | class Hang2020(Module): 253 | def 
__init__(self, bands, classes): 254 | super(Hang2020, self).__init__() 255 | self.spectral_network = spectral_network(bands, classes) 256 | self.spatial_network = spatial_network(bands, classes) 257 | 258 | # Learnable weight 259 | self.alpha = nn.Parameter(torch.tensor(0.5, dtype=float), requires_grad=True) 260 | 261 | def forward(self, x): 262 | x = x.type(torch.float) 263 | x = x.permute(0, 3, 1, 2) 264 | spectral_scores = self.spectral_network(x) 265 | spatial_scores = self.spatial_network(x) 266 | 267 | # Take the final attention scores 268 | spectral_classes = spectral_scores[-1] 269 | spatial_classes = spatial_scores[-1] 270 | 271 | # Weighted average 272 | self.weighted_average = torch.sigmoid(self.alpha) 273 | joint_score = spectral_classes * self.weighted_average + spatial_classes * (1 - self.weighted_average) 274 | 275 | return joint_score 276 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/adaptive-mask-sampling-and-manifold-to/hyperspectral-image-classification-on-indian)](https://paperswithcode.com/sota/hyperspectral-image-classification-on-indian?p=adaptive-mask-sampling-and-manifold-to) 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/adaptive-mask-sampling-and-manifold-to/hyperspectral-image-classification-on-pavia)](https://paperswithcode.com/sota/hyperspectral-image-classification-on-pavia?p=adaptive-mask-sampling-and-manifold-to) 4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/adaptive-mask-sampling-and-manifold-to/hyperspectral-image-classification-on-casi)](https://paperswithcode.com/sota/hyperspectral-image-classification-on-casi?p=adaptive-mask-sampling-and-manifold-to) 6 | 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/adaptive-mask-sampling-and-manifold-to/hyperspectral-image-classification-on-houston)](https://paperswithcode.com/sota/hyperspectral-image-classification-on-houston?p=adaptive-mask-sampling-and-manifold-to) 8 | 9 | 10 | 11 | # [TGRS 2023] Adaptive Mask Sampling and Manifold to Euclidean Subspace Learning with Distance Covariance Representation for Hyperspectral Image Classification 12 | 13 | [Mingsong Li](https://lms-07.github.io/), [Wei Li](https://fdss.bit.edu.cn/yjdw/js/b153191.htm), Yikun Liu, [Yuwen Huang](https://jsj.hezeu.edu.cn/info/1302/6525.htm), and [Gongping Yang](https://faculty.sdu.edu.cn/gpyang) 14 | 15 | [Time Lab](https://time.sdu.edu.cn/), [SDU](https://www.sdu.edu.cn/) ; [BIT](https://www.bit.edu.cn/) 16 | 17 | ----------- 18 | This repository is the official implementation of our paper: 19 | [Adaptive Mask Sampling and Manifold to Euclidean Subspace Learning with Distance Covariance Representation for Hyperspectral Image Classification](https://doi.org/10.1109/TGRS.2023.3265388), IEEE Transactions on Geoscience and Remote Sensing (TGRS) 2023. 20 | 21 | ## Contents 22 | 1. [Brief Introduction](#Brief-Introduction) 23 | 1. [Environment](#Environment) 24 | 1. [Datasets and File Hierarchy](#Datasets-and-File-Hierarchy) 25 | 1. [Implementations of Compared Methods](#Implementations-of-Compared-Methods) 26 | 1. [Citation](#Citation) 27 | 1. [License and Acknowledgement](#License-and-Acknowledgement) 28 | 29 | ## Brief Introduction 30 | >

For the abundant spectral and spatial information recorded in hyperspectral images (HSIs), fully exploring spectral-spatial relationships has attracted widespread attention in the hyperspectral image classification (HSIC) community. However, some intractable obstacles remain. For one thing, in the patch-based processing pattern, some spatial neighbor pixels are often inconsistent with the central pixel in land-cover class. For another, linear and nonlinear correlations between different spectral bands are vital yet difficult to represent and exploit. To overcome these issues, an adaptive mask sampling and manifold to Euclidean subspace learning (AMS-M2ESL) framework is proposed for HSIC. Specifically, an adaptive mask based intra-patch sampling (AMIPS) module is first formulated for intra-patch sampling in an adaptive mask manner based on central spectral vector oriented spatial relationships. Subsequently, based on the distance covariance descriptor, a dual channel distance covariance representation (DC-DCR) module is proposed for modeling unified spectral-spatial feature representations and exploring spectral-spatial relationships, especially linear and nonlinear interdependence in the spectral domain. Furthermore, considering that the distance covariance matrix lies on the symmetric positive definite (SPD) manifold, we implement a manifold to Euclidean subspace learning (M2ESL) module respecting the Riemannian geometry of the SPD manifold for high-level spectral-spatial feature learning. Additionally, we introduce an approximate matrix square-root (ASQRT) layer for efficient Euclidean subspace projection. Extensive experimental results on three popular HSI datasets with limited training samples demonstrate the superior performance of the proposed method compared with other state-of-the-art methods. The source code is available at https://github.com/lms-07/AMS-M2ESL.
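
The core statistic behind the DC-DCR module is distance covariance. As a minimal, self-contained sketch of that descriptor (the function name and layout below are illustrative assumptions, not the repository's `model/module/DCR.py`), the squared sample distance covariance between every pair of spectral bands of an n-pixel patch can be computed as:

```python
import torch

def distance_covariance_matrix(x: torch.Tensor) -> torch.Tensor:
    """x: (n, c) matrix of n pixel spectra over c bands; returns the (c, c)
    matrix of squared sample distance covariances between band pairs
    (Szekely and Rizzo, 2009)."""
    n, c = x.shape
    # per-band pairwise absolute differences: (c, n, n)
    d = (x.t().unsqueeze(2) - x.t().unsqueeze(1)).abs()
    # double-center each band's distance matrix
    d = (d - d.mean(dim=1, keepdim=True) - d.mean(dim=2, keepdim=True)
         + d.mean(dim=(1, 2), keepdim=True))
    # dCov^2(i, j) = mean of A_i * A_j over all n^2 pixel pairs
    return torch.einsum('ipq,jpq->ij', d, d) / (n * n)
```

Unlike a plain covariance matrix, every entry of this descriptor is non-negative and the population statistic vanishes only under independence, which is what lets the module capture nonlinear band interdependence.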

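The ASQRT layer's role can likewise be illustrated with the classical coupled Newton-Schulz iteration for an SPD matrix square root, in the spirit of the iSQRT-COV paper cited under Acknowledgement; this is an illustrative sketch under that assumption, not the repository's `model/module/MPA_Lya.py`:

```python
import torch

def newton_schulz_sqrt(a: torch.Tensor, num_iters: int = 5) -> torch.Tensor:
    """Approximate the square root of an SPD matrix `a` using only matrix
    products, avoiding an explicit (GPU-unfriendly) eigendecomposition."""
    identity = torch.eye(a.shape[-1], dtype=a.dtype, device=a.device)
    norm = a.norm(p='fro')   # pre-normalization keeps the iteration convergent
    y, z = a / norm, identity.clone()
    for _ in range(num_iters):
        t = 0.5 * (3.0 * identity - z @ y)
        y, z = y @ t, t @ z  # y -> sqrt(a/norm), z -> its inverse
    return y * norm.sqrt()   # post-compensation restores the original scale
```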
31 | 32 | **AMS-M2ESL Framework** 33 | 34 | ![framework](src/framework.png) 35 | 36 | ## Environment 37 | - The software environment is Ubuntu 18.04.5 LTS 64 bit. 38 | - This project is running on a single Nvidia GeForce RTX 3090 GPU based on Cuda 11.0. 39 | - We adopt Python 3.8.5, PyTorch 1.10.0+cu111. 40 | - The py+torch combination may not be limited by our adopted one. 41 | 42 | 43 | ## Datasets and File Hierarchy 44 | 45 | Three representative HSI datasets are adopted in our experiments, i.e., Indian Pines (IP), University of Pavia (UP), and University of Houston 13 (UH). 46 | The first two datasets could be accessed through [link1](http://www.ehu.eus/ccwintco/index.php?title=Hyperspectral_Remote_Sensing_Scenes##anomaly_detection), 47 | and the UH dataset through [link2](https://hyperspectral.ee.uh.edu/?page_id=459). 48 | Our project is organized as follows: 49 | 50 | ```text 51 | AMS-M2ESL 52 | |-- process_xxx // main files 1) dl for the proposed model 2) cls_c_model 53 | | for the classic compared model, SVM 3) dl_c_model for eight 54 | | dl based compared methods 4) disjoint for the 55 | | disjoint dataset (UH) 5) m_scale for the multiscale model, MCM-CNN 56 | |-- c_model // eight deep learning based compared methods 57 | |-- data 58 | | |-- IP 59 | | | |-- Indian_pines_corrected.mat 60 | | | |-- Indian_pines_gt.mat 61 | | |-- UP 62 | | | |-- PaviaU.mat 63 | | | |-- PaviaU_gt.mat 64 | | |-- HU13_tif 65 | | | |--Houston13_data.mat 66 | | | |--Houston13_gt_train.mat 67 | | | |--Houston13_gt_test.mat 68 | |-- model // the proposed method 69 | |-- output 70 | | |-- cls_maps // classification map visualizations 71 | | |-- results // classification result files 72 | |-- src // source files 73 | |-- utils // data loading, processing, and evaluating 74 | |-- visual // cls maps visual 75 | ``` 76 | ## Implementations of Compared Methods 77 | For comparisons, our codebase also includes related compared methods. 78 | - SVM, PyTorch version, sklearn-based 79 | - J-Play, [Joint & Progressive Learning from High-Dimensional Data for Multi-Label Classification](https://openaccess.thecvf.com/content_ECCV_2018/html/Danfeng_Hong_Joint__Progressive_ECCV_2018_paper.html) ECCV 2018, referring to official Matlab version, https://github.com/danfenghong/ECCV2018_J-Play 80 | - 1D-CNN, [Deep Convolutional Neural Networks for Hyperspectral Image Classification](https://www.hindawi.com/journals/js/2015/258619/) Journal of Sensors 2015, from an HSIC Tool Codebase, https://github.com/eecn/Hyperspectral-Classification 81 | - MCM-CNN, [Feature Extraction With Multiscale Covariance Maps for Hyperspectral Image Classification](https://ieeexplore.ieee.org/document/9565208) TGRS 2018, ***our unofficial PyTorch implementation*** based on official Matlab version, https://github.com/henanjun/demo_MCMs 82 | - SSTN, [Spectral-Spatial Transformer Network for Hyperspectral Image Classification: A Factorized Architecture Search Framework](https://ieeexplore.ieee.org/document/9565208) TGRS 2021, from official PyTorch version, https://github.com/zilongzhong/SSTN/blob/main/NetworksBlocks.py 83 | - SSSAN, [Spectral–Spatial Self-Attention Networks for Hyperspectral Image Classification](https://ieeexplore.ieee.org/document/9508777) TGRS 2021, ***our unofficial PyTorch implementation*** based on the part of source Keras code from the author Dr. 
Xuming Zhang 84 | - SSAtt, [Hyperspectral Image Classification With Attention-Aided CNNs](https://ieeexplore.ieee.org/abstract/document/9142417) TGRS 2020, from a PyTorch implementation, https://github.com/weecology/DeepTreeAttention/blob/main/src/models/Hang2020.py 85 | - A2S2K-ResNet, [Attention-Based Adaptive Spectral-Spatial Kernel ResNet for Hyperspectral Image Classification](https://ieeexplore.ieee.org/document/9306920) TGRS 2020, from official PyTorch version, https://github.com/suvojit-0x55aa/A2S2K-ResNet/blob/master/A2S2KResNet/A2S2KResNet.py 86 | - SSFTT, [Spectral–Spatial Feature Tokenization Transformer for Hyperspectral Image Classification](https://ieeexplore.ieee.org/document/9684381) TGRS 2022, from official PyTorch version, https://github.com/zgr6010/HSI_SSFTT/blob/main/cls_SSFTT_IP/SSFTTnet.py 87 | - ASPN, [Attention-Based Second-Order Pooling Network for Hyperspectral Image Classification](https://ieeexplore.ieee.org/document/9325094) TGRS 2021, ***our unofficial PyTorch implementation*** based on official Keras version, https://github.com/mengxue-rs/a-spn 88 | 89 | ## Citation 90 | 91 | Please kindly cite our work if this work is helpful for your research. 92 | 93 | [1] M. Li, W. Li, Y. Liu, Y. Huang and G. Yang, "Adaptive Mask Sampling and Manifold to Euclidean Subspace Learning With Distance Covariance Representation for Hyperspectral Image Classification," in IEEE Transactions on Geoscience and Remote Sensing, vol. 61, pp. 1-18, 2023, Art no. 5508518. 94 | 95 | BibTex entry: 96 | ```text 97 | @article{li2023adaptive, 98 | title={Adaptive Mask Sampling and Manifold to Euclidean Subspace Learning with Distance Covariance Representation for Hyperspectral Image Classification}, 99 | author={Li, Mingsong and Li, Wei and Liu, Yikun and Huang, Yuwen and Yang, Gongping}, 100 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 101 | year={2023}, 102 | volume={61}, 103 | number={}, 104 | pages={1-18}, 105 | publisher={IEEE}, 106 | } 107 | ``` 108 | 109 | ## Contact information 110 | 111 | If you have any problem, please do not hesitate to contact us `msli@mail.sdu.edu.cn`. 112 | 113 | ## License and Acknowledgement 114 | 115 | - This project is released under [GPLv3](http://www.gnu.org/licenses/) license. 116 | - We would like to thank the Hyperspectral Image Analysis group and the NSF Funded Center for 117 | Airborne Laser Mapping (NCALM) at the University of Houston for providing the UH dataset used in this work. 118 | - Our HSIC framework is implemented based on our prior work [CVSSN](https://github.com/lms-07/CVSSN). 119 | - Our proposed AMS-M2ESL framework is inspired by the following awesome works: 120 | - [Brownian distance covariance](https://projecteuclid.org/journals/annals-of-applied-statistics/volume-3/issue-4/Brownian-distance-covariance/10.1214/09-AOAS312.full), Ann. Appl. Stat. 
2009 121 | - [Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification](https://openaccess.thecvf.com/content/CVPR2022/html/Xie_Joint_Distribution_Matters_Deep_Brownian_Distance_Covariance_for_Few-Shot_Classification_CVPR_2022_paper.html), CVPR 2022 122 | - [Superpixel-Based Brownian Descriptor for Hyperspectral Image Classification](https://ieeexplore.ieee.org/document/9645390?arnumber=9645390), TGRS 2021 123 | - [A Riemannian Network for SPD Matrix Learning](https://ojs.aaai.org/index.php/AAAI/article/view/10866), AAAI 2017 124 | - [Riemannian batch normalization for SPD neural networks](https://proceedings.neurips.cc/paper/2019/hash/6e69ebbfad976d4637bb4b39de261bf7-Abstract.html), NeurIPS 2019 125 | - [Towards Faster Training of Global Covariance Pooling Networks by Iterative Matrix Square Root Normalization](https://openaccess.thecvf.com/content_cvpr_2018/html/Li_Towards_Faster_Training_CVPR_2018_paper.html), CVPR 2018 126 | - [COSONet: Compact Second-Order Network for Video Face Recognition](https://link.springer.com/chapter/10.1007/978-3-030-20893-6_4), ACCV 2018 127 | 128 | -------------------------------------------------------------------------------- /utils/data_load_operate_c_model_m_scale.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : data_load_operate_c_model_m_scale.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | import os 9 | import math 10 | import torch 11 | import numpy as np 12 | import spectral as spy 13 | import scipy.io as sio 14 | import torch.utils.data as Data 15 | import matplotlib.pyplot as plt 16 | from sklearn.decomposition import PCA 17 | from sklearn import preprocessing 18 | 19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | def load_data(data_set_name, data_path): 23 | if data_set_name == 'IP': 24 | data = sio.loadmat(os.path.join(data_path, 'IP', 'Indian_pines_corrected.mat'))['indian_pines_corrected'] 25 | labels = sio.loadmat(os.path.join(data_path, 'IP', 'Indian_pines_gt.mat'))['indian_pines_gt'] 26 | elif data_set_name == 'UP': 27 | data = sio.loadmat(os.path.join(data_path, 'UP', 'PaviaU.mat'))['paviaU'] 28 | labels = sio.loadmat(os.path.join(data_path, 'UP', 'PaviaU_gt.mat'))['paviaU_gt'] 29 | 30 | return data, labels 31 | 32 | 33 | def load_HU_data(data_path): 34 | data = sio.loadmat(os.path.join(data_path, 'HU13_tif', "Houston13_data.mat"))['Houston13_data'] 35 | labels_train = sio.loadmat(os.path.join(data_path, 'HU13_tif', "Houston13_gt_train.mat"))['Houston13_gt_train'] 36 | labels_test = sio.loadmat(os.path.join(data_path, 'HU13_tif', "Houston13_gt_test.mat"))['Houston13_gt_test'] 37 | 38 | return data, labels_train, labels_test 39 | 40 | 41 | def standardization(data): 42 | height, width, bands = data.shape 43 | data = np.reshape(data, [height * width, bands]) 44 | # data=preprocessing.scale(data) # 45 | # data = preprocessing.MinMaxScaler().fit_transform(data) 46 | data = preprocessing.StandardScaler().fit_transform(data) # 47 | 48 | data = np.reshape(data, [height, width, bands]) 49 | return data 50 | 51 | 52 | def sampling(ratio_list, num_list, gt_reshape, class_count, Flag): 53 | all_label_index_dict, train_label_index_dict, test_label_index_dict = {}, {}, {} 54 | all_label_index_list, train_label_index_list, test_label_index_list = [], [], [], 55 | 56 | for cls in range(class_count): # 
[0-15] 57 | cls_index = np.where(gt_reshape == cls + 1)[0] 58 | all_label_index_dict[cls] = list(cls_index) 59 | 60 | np.random.shuffle(cls_index) 61 | 62 | if Flag == 0: # Fixed proportion for each category 63 | train_index_flag = max(int(ratio_list[0] * len(cls_index)), 3) # at least 3 samples per class] 64 | # Split by num per class 65 | elif Flag == 1: # Fixed quantity per category 66 | if len(cls_index) > num_list[0]: 67 | train_index_flag = num_list[0] 68 | else: 69 | train_index_flag = 15 70 | 71 | train_label_index_dict[cls] = list(cls_index[:train_index_flag]) 72 | test_label_index_dict[cls] = list(cls_index[train_index_flag:]) 73 | 74 | train_label_index_list += train_label_index_dict[cls] 75 | test_label_index_list += test_label_index_dict[cls] 76 | all_label_index_list += all_label_index_dict[cls] 77 | 78 | return train_label_index_list, test_label_index_list, all_label_index_list 79 | 80 | 81 | def sampling_disjoint(gt_train_re, gt_test_re, class_count): 82 | all_label_index_dict, train_label_index_dict, test_label_index_dict = {}, {}, {} 83 | all_label_index_list, train_label_index_list, test_label_index_list = [], [], [] 84 | 85 | for cls in range(class_count): 86 | cls_index_train = np.where(gt_train_re == cls + 1)[0] 87 | cls_index_test = np.where(gt_test_re == cls + 1)[0] 88 | 89 | train_label_index_dict[cls] = list(cls_index_train) 90 | test_label_index_dict[cls] = list(cls_index_test) 91 | 92 | train_label_index_list += train_label_index_dict[cls] 93 | test_label_index_list += test_label_index_dict[cls] 94 | all_label_index_list += (train_label_index_dict[cls] + test_label_index_dict[cls]) 95 | 96 | return train_label_index_list, test_label_index_list, all_label_index_list 97 | 98 | 99 | def applyPCA(X, numComponents=75): 100 | newX = np.reshape(X, (-1, X.shape[2])) 101 | pca = PCA(n_components=numComponents, whiten=True) 102 | newX = pca.fit_transform(newX) 103 | newX = np.reshape(newX, (X.shape[0], X.shape[1], numComponents)) 104 | return newX 105 | 106 | 107 | def HSI_MNF(X, MNF_ratio): 108 | denoised_bands = math.ceil(MNF_ratio * X.shape[-1]) 109 | mnfr = spy.mnf(spy.calc_stats(X), spy.noise_from_diffs(X)) 110 | denoised_data = mnfr.reduce(X, num=denoised_bands) 111 | 112 | return denoised_data 113 | 114 | 115 | def data_pad_zero(data, patch_length): 116 | data_padded = np.lib.pad(data, ((patch_length, patch_length), (patch_length, patch_length), (0, 0)), 'constant', 117 | constant_values=0) 118 | return data_padded 119 | 120 | 121 | def img_show(x): 122 | spy.imshow(x) 123 | plt.show() 124 | 125 | 126 | def index_assignment(index, row, col, pad_length): 127 | new_assign = {} # dictionary. 
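# maps each flat ground-truth index to its [row, col] position in the
# zero-padded cube, offsetting both coordinates by pad_length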
128 | for counter, value in enumerate(index): 129 | assign_0 = value // col + pad_length 130 | assign_1 = value % col + pad_length 131 | new_assign[counter] = [assign_0, assign_1] 132 | return new_assign 133 | 134 | 135 | def select_patch(data_padded, pos_x, pos_y, patch_length): 136 | selected_patch = data_padded[pos_x - patch_length:pos_x + patch_length + 1, 137 | pos_y - patch_length:pos_y + patch_length + 1] 138 | return selected_patch 139 | 140 | 141 | def select_vector(data_padded, pos_x, pos_y): 142 | select_vector = data_padded[pos_x, pos_y] 143 | return select_vector 144 | 145 | 146 | def HSI_create_pathes(data_padded, hsi_h, hsi_w, data_indexes, patch_length, flag): 147 | h_p, w_p, c = data_padded.shape 148 | 149 | data_size = len(data_indexes) 150 | patch_size = patch_length * 2 + 1 151 | 152 | data_assign = index_assignment(data_indexes, hsi_h, hsi_w, patch_length) 153 | if flag == 1: 154 | # for spatial net data, HSI patch 155 | unit_data = np.zeros((data_size, patch_size, patch_size, c)) 156 | unit_data_torch = torch.from_numpy(unit_data).type(torch.FloatTensor).to(device) 157 | for i in range(len(data_assign)): 158 | unit_data_torch[i] = select_patch(data_padded, data_assign[i][0], data_assign[i][1], patch_length) 159 | if flag == 2: 160 | # for spectral net data, HSI vector 161 | unit_data = np.zeros((data_size, c)) 162 | unit_data_torch = torch.from_numpy(unit_data).type(torch.FloatTensor).to(device) 163 | for i in range(len(data_assign)): 164 | unit_data_torch[i] = select_vector(data_padded, data_assign[i][0], data_assign[i][1]) 165 | 166 | return unit_data_torch 167 | 168 | 169 | def HSI_create_pathes_spatial_multiscale(data, data_indexes, scales): 170 | h, w, c = data.shape 171 | data_size = len(data_indexes) 172 | 173 | CR_data = np.zeros((data_size, scales, c, c)) 174 | CR_data_torch = torch.from_numpy(CR_data).type(torch.FloatTensor).to(device) 175 | 176 | for j in range(scales): 177 | patch_length = j + 1 178 | patch_size = 2 * patch_length + 1 179 | 180 | data_padded = data_pad_zero(data, patch_length) 181 | data_assign = index_assignment(data_indexes, h, w, patch_length) 182 | 183 | unit_data = np.zeros((data_size, patch_size, patch_size, c)) 184 | 185 | for i in range(data_size): 186 | unit_data[i] = select_patch(data_padded, data_assign[i][0], data_assign[i][1], patch_length) 187 | 188 | CR_j = Covar_cor_mat(unit_data) 189 | CR_j = torch.unsqueeze(CR_j, dim=1) 190 | CR_data_torch[:, j:, ] = CR_j 191 | 192 | return CR_data_torch 193 | 194 | 195 | def Covar_cor_mat(x): 196 | x_t = torch.from_numpy(x).type(torch.FloatTensor).to(device) 197 | batch_size, h, w, c = x_t.size() 198 | 199 | x_t = x_t.view(batch_size, h * w, c) 200 | mean_pixel = torch.mean(x_t, dim=1, keepdims=True) 201 | mean_pixel_expand = mean_pixel.expand(x_t.shape) 202 | 203 | x_cr = x_t - mean_pixel_expand 204 | CR = torch.bmm(x_cr.permute(0, 2, 1), x_cr) 205 | CR = torch.div(CR, h * w - 1) 206 | # CR = torch.div(CR, h*w) 207 | del mean_pixel, mean_pixel_expand 208 | 209 | return CR 210 | 211 | 212 | # generating HSI patches using GPU directly. 
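# Usage sketch (shapes assumed from this file, nothing here is executed):
# for an (H, W, C) cube and a list of flat pixel indexes,
# HSI_create_pathes_spatial_multiscale(data, indexes, scales) yields a
# (len(indexes), scales, C, C) tensor with one C x C covariance map per
# neighborhood scale (3x3, 5x5, ..., (2*scales+1) x (2*scales+1)); the
# generate_iter_* functions below wrap these cubes, with labels shifted to
# start from 0, into shuffled train / ordered test DataLoaders.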
213 | def generate_iter_ms(data, label_reshape, index, batch_size, 214 | model_3D_spa_flag, scales): 215 | # flag for single spatial net or single spectral net or spectral-spatial net 216 | # data_torch = torch.from_numpy(data).type(torch.FloatTensor).to(device) 217 | 218 | # for data label 219 | train_labels = label_reshape[index[0]] - 1 220 | test_labels = label_reshape[index[1]] - 1 221 | 222 | y_tensor_train = torch.from_numpy(train_labels).type(torch.FloatTensor) 223 | y_tensor_test = torch.from_numpy(test_labels).type(torch.FloatTensor) 224 | 225 | # for data 226 | # data for single spatial net 227 | spa_train_samples = HSI_create_pathes_spatial_multiscale(data, index[0], scales) 228 | spa_test_samples = HSI_create_pathes_spatial_multiscale(data, index[1], scales) 229 | 230 | if model_3D_spa_flag == 1: # spatial 3D patch 231 | spa_train_samples = spa_train_samples.unsqueeze(1) 232 | spa_test_samples = spa_test_samples.unsqueeze(1) 233 | 234 | torch_dataset_train = Data.TensorDataset(spa_train_samples, y_tensor_train) 235 | torch_dataset_test = Data.TensorDataset(spa_test_samples, y_tensor_test) 236 | 237 | train_iter = Data.DataLoader(dataset=torch_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0) 238 | test_iter = Data.DataLoader(dataset=torch_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0) 239 | 240 | del torch_dataset_train, torch_dataset_test, spa_train_samples, spa_test_samples 241 | 242 | return train_iter, test_iter 243 | 244 | 245 | def generate_iter_disjoint_ms(data, gt_train_re, gt_test_re, index, batch_size, 246 | model_3D_spa_flag, scales): 247 | # data_padded_torch = torch.from_numpy(data_padded).type(torch.FloatTensor).to(device) 248 | 249 | train_labels = gt_train_re[index[0]] - 1 250 | test_labels = gt_test_re[index[1]] - 1 251 | 252 | y_tensor_train = torch.from_numpy(train_labels).type(torch.FloatTensor) 253 | y_tensor_test = torch.from_numpy(test_labels).type(torch.FloatTensor) 254 | 255 | # for data 256 | # data for single spatial net 257 | spa_train_samples = HSI_create_pathes_spatial_multiscale(data, index[0], scales) 258 | spa_test_samples = HSI_create_pathes_spatial_multiscale(data, index[1], scales) 259 | 260 | if model_3D_spa_flag == 1: # spatial 3D patch 261 | spa_train_samples = spa_train_samples.unsqueeze(1) 262 | spa_test_samples = spa_test_samples.unsqueeze(1) 263 | 264 | torch_dataset_train = Data.TensorDataset(spa_train_samples, y_tensor_train) 265 | torch_dataset_test = Data.TensorDataset(spa_test_samples, y_tensor_test) 266 | 267 | train_iter = Data.DataLoader(dataset=torch_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0) 268 | test_iter = Data.DataLoader(dataset=torch_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0) 269 | 270 | del torch_dataset_train, torch_dataset_test, spa_train_samples, spa_test_samples 271 | 272 | return train_iter, test_iter 273 | -------------------------------------------------------------------------------- /c_model/SSSAN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : SSSAN.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | # unofficial implementation based on part of offical Keras version 9 | # dense net backbone from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/densenet.py 10 | # Spectral–Spatial Self-Attention Networks for Hyperspectral Image 
Classification, TGRS 2021 11 | 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import functional as F 16 | from torch import cosine_similarity 17 | 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | class spa_self_atten(nn.Module): 22 | def __init__(self, in_channels): 23 | super(spa_self_atten, self).__init__() 24 | self.to_ab = nn.Conv2d(in_channels, in_channels * 2, 1, bias=False) 25 | self.softmax = nn.Softmax(dim=-1) 26 | 27 | def forward(self, x): 28 | batch_size, c, h, w = x.size() 29 | 30 | a, b = self.to_ab(x).chunk(2, dim=1) 31 | a = a.view(batch_size, -1, h * w).permute(0, 2, 1) 32 | b = b.view(batch_size, -1, h * w).permute(0, 2, 1) 33 | cent_spec_vector = a[:, int((h * w - 1) / 2)] 34 | cent_spec_vector = torch.unsqueeze(cent_spec_vector, 1) 35 | 36 | sim_cosine = cosine_similarity(cent_spec_vector, b, dim=2) # cos 37 | sim_cosine_2 = torch.pow(sim_cosine, 2) # cos^2 38 | 39 | atten_s = self.softmax(sim_cosine_2) 40 | atten_s = torch.unsqueeze(atten_s, 2) 41 | 42 | out = torch.mul(atten_s, b).contiguous().view(batch_size, -1, h, w) + x 43 | 44 | return out 45 | 46 | 47 | class spe_self_atten(nn.Module): 48 | def __init__(self): 49 | super(spe_self_atten, self).__init__() 50 | self.softmax = nn.Softmax(dim=-1) 51 | 52 | def forward(self, x): 53 | batch_size, c, l = x.size() 54 | x = x.to(device) 55 | 56 | sim_cosine_mat = torch.zeros(batch_size, c, c).to(device) 57 | for i in range(c): 58 | target_vector = x[:, i] 59 | target_vector = torch.unsqueeze(target_vector, 1) 60 | sim_cosine_mat[:, i] = cosine_similarity(target_vector, x, dim=2) 61 | atten_s = self.softmax(sim_cosine_mat) 62 | 63 | out = torch.bmm(atten_s, x).contiguous() + x 64 | 65 | return out 66 | 67 | 68 | class transition_2D(nn.Module): 69 | def __init__(self, inplanes, outplanes): 70 | super(transition_2D, self).__init__() 71 | self.bn1 = nn.BatchNorm2d(inplanes) 72 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 73 | bias=False) 74 | self.relu = nn.ReLU(inplace=True) 75 | 76 | def forward(self, x): 77 | out = self.bn1(x) 78 | out = self.relu(out) 79 | out = self.conv1(out) 80 | out = F.avg_pool2d(out, 2) 81 | return out 82 | 83 | def _init_weight(self): 84 | for m in self.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | nn.init.kaiming_normal_(m.weight, a=self.conv_init_a, mode=self.conv_init_mode, 87 | nonlinearity='leaky_relu') 88 | 89 | 90 | class transition_1D(nn.Module): 91 | def __init__(self, inplanes, outplanes): 92 | super(transition_1D, self).__init__() 93 | self.bn1 = nn.BatchNorm1d(inplanes) 94 | self.conv1 = nn.Conv1d(inplanes, outplanes, kernel_size=1, 95 | bias=False) 96 | self.relu = nn.ReLU(inplace=True) 97 | 98 | def forward(self, x): 99 | out = self.bn1(x) 100 | out = self.relu(out) 101 | out = self.conv1(out) 102 | out = F.avg_pool1d(out, 2) 103 | return out 104 | 105 | def _init_weight(self): 106 | for m in self.modules(): 107 | if isinstance(m, nn.Conv1d): 108 | nn.init.kaiming_normal_(m.weight, a=self.conv_init_a, mode=self.conv_init_mode, 109 | nonlinearity='leaky_relu') 110 | 111 | 112 | class spe_dense_bottleneck_1D(nn.Module): 113 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 114 | super(spe_dense_bottleneck_1D, self).__init__() 115 | planes = expansion * growthRate 116 | self.bn1 = nn.BatchNorm1d(inplanes) 117 | self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False) 118 | self.bn2 = nn.BatchNorm1d(planes) 119 | self.conv2 = nn.Conv1d(planes, growthRate, 
kernel_size=3, 120 | padding=1, bias=False) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.dropRate = dropRate 123 | 124 | self.spe_atten = spe_self_atten() 125 | 126 | def forward(self, x): 127 | x = x.to(device) 128 | out = self.bn1(x) 129 | out = self.relu(out) 130 | out = self.conv1(out) 131 | out = self.bn2(out) 132 | out = self.relu(out) 133 | out = self.conv2(out) 134 | if self.dropRate > 0: 135 | out = F.dropout(out, p=self.dropRate, training=self.training) 136 | 137 | out = self.spe_atten(out) 138 | out = torch.cat((x, out), 1) 139 | 140 | return out 141 | 142 | def _init_weight(self): 143 | for m in self.modules(): 144 | if isinstance(m, nn.Conv1d): 145 | nn.init.kaiming_normal_(m.weight, a=self.conv_init_a, mode=self.conv_init_mode, 146 | nonlinearity='leaky_relu') 147 | 148 | 149 | class spa_dense_bottleneck_2D(nn.Module): 150 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 151 | super(spa_dense_bottleneck_2D, self).__init__() 152 | planes = expansion * growthRate 153 | self.bn1 = nn.BatchNorm2d(inplanes) 154 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 155 | self.bn2 = nn.BatchNorm2d(planes) 156 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 157 | padding=1, bias=False) 158 | self.relu = nn.ReLU(inplace=True) 159 | self.dropRate = dropRate 160 | 161 | self.spa_artten = spa_self_atten(growthRate) 162 | 163 | def forward(self, x): 164 | out = self.bn1(x) 165 | out = self.relu(out) 166 | out = self.conv1(out) 167 | out = self.bn2(out) 168 | out = self.relu(out) 169 | out = self.conv2(out) 170 | if self.dropRate > 0: 171 | out = F.dropout(out, p=self.dropRate, training=self.training) 172 | 173 | out = self.spa_artten(out) 174 | 175 | out = torch.cat((x, out), 1) 176 | 177 | return out 178 | 179 | def _init_weight(self): 180 | for m in self.modules(): 181 | if isinstance(m, nn.Conv2d): 182 | nn.init.kaiming_normal_(m.weight, a=self.conv_init_a, mode=self.conv_init_mode, 183 | nonlinearity='leaky_relu') 184 | 185 | 186 | class SpaNet(nn.Module): 187 | def __init__(self, in_channels, depth=16, dropRate=0, growthRate=22, compressionRate=2): 188 | super(SpaNet, self).__init__() 189 | 190 | self.growthRate = growthRate 191 | self.dropRate = dropRate 192 | 193 | self.in_channels = in_channels 194 | self.inplanes = growthRate * 2 195 | 196 | n = (depth - 4) // 6 197 | self.conv1 = nn.Conv2d(self.in_channels, self.inplanes, kernel_size=3, padding=1, 198 | bias=False) 199 | 200 | self.spa_dense_1 = self._make_denseblock(spa_dense_bottleneck_2D, n) 201 | self.trans_2d_1 = self._make_transition(compressionRate) 202 | self.spa_dense_2 = self._make_denseblock(spa_dense_bottleneck_2D, n) 203 | self.trans_2d_2 = self._make_transition(compressionRate) 204 | self.spa_dense_3 = self._make_denseblock(spa_dense_bottleneck_2D, n) 205 | self.trans_2d_3 = self._make_transition(compressionRate) 206 | 207 | self.bn = nn.BatchNorm2d(self.inplanes) 208 | self.relu = nn.ReLU(inplace=True) 209 | 210 | self.avgpool = nn.AdaptiveAvgPool2d(1) 211 | self.flatten = nn.Flatten(1) 212 | self.fc = nn.Linear(self.inplanes, 32) 213 | 214 | def _make_denseblock(self, block, blocks): 215 | layers = [] 216 | for i in range(blocks): 217 | # Currently we fix the expansion ratio as the default value 218 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 219 | self.inplanes += self.growthRate 220 | 221 | return nn.Sequential(*layers) 222 | 223 | def _make_transition(self, compressionRate): 224 | inplanes = self.inplanes 225 
| outplanes = int(math.floor(self.inplanes // compressionRate)) 226 | self.inplanes = outplanes 227 | return transition_2D(inplanes, outplanes) 228 | 229 | def forward(self, x_2d): 230 | x_2d = x_2d.permute(0, 3, 1, 2) 231 | 232 | x = self.conv1(x_2d) 233 | 234 | x = self.trans_2d_1(self.spa_dense_1(x)) 235 | x = self.trans_2d_2(self.spa_dense_2(x)) 236 | x = self.trans_2d_3(self.spa_dense_3(x)) 237 | 238 | x = self.bn(x) 239 | x = self.relu(x) 240 | x = self.avgpool(x) 241 | x = self.flatten(x) 242 | x = self.fc(x) 243 | 244 | return x 245 | 246 | 247 | class SpeNet(nn.Module): 248 | def __init__(self, depth=16, dropRate=0, growthRate=22, compressionRate=2): 249 | super(SpeNet, self).__init__() 250 | self.growthRate = growthRate 251 | self.dropRate = dropRate 252 | 253 | self.in_channels = 1 254 | self.inplanes = growthRate * 2 255 | 256 | n = (depth - 4) // 6 257 | self.conv1 = nn.Conv1d(self.in_channels, self.inplanes, kernel_size=3, padding=1, 258 | bias=False) 259 | 260 | self.spe_dense_1 = self._make_denseblock(spe_dense_bottleneck_1D, n) 261 | self.trans_1d_1 = self._make_transition(compressionRate) 262 | self.spe_dense_2 = self._make_denseblock(spe_dense_bottleneck_1D, n) 263 | self.trans_1d_2 = self._make_transition(compressionRate) 264 | self.spe_dense_3 = self._make_denseblock(spe_dense_bottleneck_1D, n) 265 | self.trans_1d_3 = self._make_transition(compressionRate) 266 | 267 | self.bn = nn.BatchNorm1d(self.inplanes) 268 | self.relu = nn.ReLU(inplace=True) 269 | 270 | self.avgpool = nn.AdaptiveAvgPool1d(1) 271 | self.flatten = nn.Flatten(1) 272 | self.fc = nn.Linear(self.inplanes, 32) 273 | 274 | def _make_denseblock(self, block, blocks): 275 | layers = [] 276 | for i in range(blocks): 277 | # Currently we fix the expansion ratio as the default value 278 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 279 | self.inplanes += self.growthRate 280 | 281 | return nn.Sequential(*layers) 282 | 283 | def _make_transition(self, compressionRate): 284 | inplanes = self.inplanes 285 | outplanes = int(math.floor(self.inplanes // compressionRate)) 286 | self.inplanes = outplanes 287 | return transition_1D(inplanes, outplanes) 288 | 289 | def forward(self, x_spe): 290 | x_spe = torch.unsqueeze(x_spe, 1) 291 | 292 | x = self.conv1(x_spe) 293 | 294 | x = self.trans_1d_1(self.spe_dense_1(x)) 295 | x = self.trans_1d_2(self.spe_dense_2(x)) 296 | x = self.trans_1d_3(self.spe_dense_3(x)) 297 | 298 | x = self.bn(x) 299 | x = self.relu(x) 300 | x = self.avgpool(x) 301 | x = self.flatten(x) 302 | x = self.fc(x) 303 | 304 | return x 305 | 306 | 307 | class SSSAN(nn.Module): 308 | def __init__(self, in_channels, dr_channels, class_count): 309 | super(SSSAN, self).__init__() 310 | self.in_channels = in_channels 311 | self.dr_channels = dr_channels 312 | self.class_count = class_count 313 | 314 | self.spa_net = SpaNet(self.dr_channels) 315 | self.spe_net = SpeNet() 316 | self.lamuda = nn.Parameter(torch.tensor(0.5, dtype=float), requires_grad=True) 317 | self.fc = nn.Linear(32, self.class_count) 318 | 319 | def forward(self, X_spa, X_spe): 320 | # X_spa=X_spa.type(torch.float) 321 | # X_spe=X_spe.type(torch.float) 322 | out_spa = self.spa_net(X_spa) 323 | out_spe = self.spe_net(X_spe) 324 | 325 | lmd = torch.sigmoid(self.lamuda) 326 | out = lmd * out_spa + (1 - lmd) * out_spe 327 | out = self.fc(out) 328 | 329 | return out 330 | -------------------------------------------------------------------------------- /c_model/SSTN.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : SSTN.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | # https://github.com/zilongzhong/SSTN/blob/main/NetworksBlocks.py 9 | # we adopt the final structure, i.e., AEAE in the original paper 10 | # Spectral-Spatial Transformer Network for Hyperspectral Image Classification: A Factorized Architecture Search Framework, TGRS 2021 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | class SpatAttn(nn.Module): 17 | """ Position attention module""" 18 | #Ref from SAGAN 19 | def __init__(self, in_dim, ratio=8): 20 | super(SpatAttn, self).__init__() 21 | self.chanel_in = in_dim 22 | 23 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) 24 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) 25 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 26 | self.gamma = nn.Parameter(torch.zeros(1)) 27 | 28 | self.softmax = nn.Softmax(dim=-1) 29 | 30 | def forward(self, x): 31 | """ 32 | inputs : 33 | x : input feature maps( B X C X H X W) 34 | returns : 35 | out : attention value + input feature 36 | attention: B X (HxW) X (HxW) 37 | """ 38 | m_batchsize, C, height, width = x.size() # BxCxHxW 39 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) # BxHWxC 40 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) # BxCxHW 41 | energy = torch.bmm(proj_query, proj_key) # BxHWxHW, attention maps 42 | attention = self.softmax(energy) # BxHWxHW, normalized attn maps 43 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) # BxCxHW 44 | 45 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) # BxCxHW 46 | out = out.view(m_batchsize, C, height, width) # BxCxHxW 47 | 48 | out = self.gamma*out + x 49 | return out 50 | 51 | class SpatAttn_(nn.Module): 52 | """ Position attention module""" 53 | #Ref from SAGAN 54 | def __init__(self, in_dim, ratio=8): 55 | super(SpatAttn_, self).__init__() 56 | self.chanel_in = in_dim 57 | 58 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) 59 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) 60 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 61 | 62 | self.gamma = nn.Parameter(torch.zeros(1)) 63 | 64 | self.softmax = nn.Softmax(dim=-1) 65 | self.bn = nn.Sequential(nn.ReLU(), 66 | nn.BatchNorm2d(in_dim)) 67 | 68 | def forward(self, x): 69 | """ 70 | inputs : 71 | x : input feature maps( B X C X H X W) 72 | returns : 73 | out : attention value + input feature 74 | attention: B X (HxW) X (HxW) 75 | """ 76 | m_batchsize, C, height, width = x.size() # BxCxHxW 77 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) # BxHWxC 78 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) # BxCxHW 79 | energy = torch.bmm(proj_query, proj_key) # BxHWxHW, attention maps 80 | attention = self.softmax(energy) # BxHWxHW, normalized attn maps 81 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) # BxCxHW 82 | 83 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) # BxCxHW 84 | out = out.view(m_batchsize, C, height, width) # BxCxHxW 85 | 86 | out = self.gamma*out #+ x 87 | 
return self.bn(out) 88 | 89 | class SARes(nn.Module): 90 | def __init__(self, in_dim, ratio=8, resin=False): 91 | super(SARes, self).__init__() 92 | 93 | if resin: 94 | self.sa1 = SpatAttn(in_dim, ratio) 95 | self.sa2 = SpatAttn(in_dim, ratio) 96 | else: 97 | self.sa1 = SpatAttn_(in_dim, ratio) 98 | self.sa2 = SpatAttn_(in_dim, ratio) 99 | 100 | def forward(self, x): 101 | identity = x 102 | x = self.sa1(x) 103 | x = self.sa2(x) 104 | 105 | return F.relu(x + identity) 106 | 107 | 108 | class SPC3(nn.Module): 109 | def __init__(self, msize=24, outplane=49, kernel_size=[7,1,1], stride=[1,1,1], padding=[3,0,0], spa_size=9, bias=True): 110 | super(SPC3, self).__init__() 111 | 112 | self.convm0 = nn.Conv3d(1, msize, kernel_size=kernel_size, padding=padding) # generate mask0 113 | self.convm1 = nn.Conv3d(1, msize, kernel_size=kernel_size, padding=padding) # generate mask1 114 | 115 | self.bn2 = nn.BatchNorm2d(outplane) 116 | 117 | def forward(self, x): 118 | 119 | identity = x # NCHW 120 | #n,c,h,w = identity.size() 121 | 122 | mask0 = self.convm0(x.unsqueeze(1)).squeeze(2) # NCHW ==> NDHW 123 | n,_,h,w = mask0.size() 124 | 125 | mask0 = torch.softmax(mask0.view(n,-1,h*w), -1) 126 | mask0 = mask0.view(n,-1,h,w) 127 | _,d,_,_ = mask0.size() 128 | 129 | mask1 = self.convm0(x.unsqueeze(1)).squeeze(2) # NDHW 130 | mask1 = torch.softmax(mask1.view(n,-1,h*w), -1) 131 | mask1 = mask1.view(n,-1,h,w) 132 | #print(mask1.size()) 133 | 134 | fk = torch.einsum('ndhw,nchw->ncd', mask0, x) # NCD 135 | 136 | out = torch.einsum('ncd,ndhw->ncdhw', fk, mask1) # NCDHW 137 | 138 | out = F.leaky_relu(out) 139 | out = out.sum(2) 140 | 141 | out = out + identity 142 | 143 | out = self.bn2(out.view(n,-1,h,w)) 144 | 145 | return out # NCHW 146 | 147 | class SPC32(nn.Module): 148 | def __init__(self, msize=24, outplane=49, kernel_size=[7,1,1], stride=[1,1,1], padding=[3,0,0], spa_size=9, bias=True): 149 | super(SPC32, self).__init__() 150 | 151 | self.convm0 = nn.Conv3d(1, msize, kernel_size=kernel_size, padding=padding) # generate mask0 152 | self.bn1 = nn.BatchNorm2d(outplane) 153 | 154 | self.convm2 = nn.Conv3d(1, msize, kernel_size=kernel_size, padding=padding) # generate mask2 155 | self.bn2 = nn.BatchNorm2d(outplane) 156 | 157 | 158 | def forward(self, x, identity=None): 159 | 160 | if identity is None: 161 | identity = x # NCHW 162 | n,c,h,w = identity.size() 163 | 164 | mask0 = self.convm0(x.unsqueeze(1)).squeeze(2) # NCHW ==> NDHW 165 | mask0 = torch.softmax(mask0.view(n,-1,h*w), -1) 166 | mask0 = mask0.view(n,-1,h,w) 167 | _,d,_,_ = mask0.size() 168 | 169 | fk = torch.einsum('ndhw,nchw->ncd', mask0, x) # NCD 170 | 171 | out = torch.einsum('ncd,ndhw->ncdhw', fk, mask0) # NCDHW 172 | 173 | out = F.leaky_relu(out) 174 | out = out.sum(2) 175 | 176 | out = out #+ identity 177 | 178 | out0 = self.bn1(out.view(n,-1,h,w)) 179 | 180 | mask2 = self.convm2(out0.unsqueeze(1)).squeeze(2) # NCHW ==> NDHW 181 | mask2 = torch.softmax(mask2.view(n,-1,h*w), -1) 182 | mask2 = mask2.view(n,-1,h,w) 183 | 184 | fk = torch.einsum('ndhw,nchw->ncd', mask2, x) # NCD 185 | 186 | out = torch.einsum('ncd,ndhw->ncdhw', fk, mask2) # NCDHW 187 | 188 | out = F.leaky_relu(out) 189 | out = out.sum(2) 190 | 191 | out = out + identity 192 | 193 | out = self.bn2(out.view(n,-1,h,w)) 194 | 195 | return out # NCHW 196 | 197 | class SPCModule(nn.Module): 198 | def __init__(self, in_channels, out_channels, bias=True): 199 | super(SPCModule, self).__init__() 200 | 201 | self.s1 = nn.Conv3d(in_channels, out_channels, kernel_size=(7,1,1), padding=(3,0,0), 
bias=False) 202 | #self.bn = nn.BatchNorm3d(out_channels) 203 | 204 | def forward(self, input): 205 | 206 | out = self.s1(input) 207 | 208 | return out 209 | 210 | class SPCModuleIN(nn.Module): 211 | def __init__(self, in_channels, out_channels, bias=True): 212 | super(SPCModuleIN, self).__init__() 213 | 214 | self.s1 = nn.Conv3d(in_channels, out_channels, kernel_size=(7,1,1), stride=(2,1,1), bias=False) 215 | #self.bn = nn.BatchNorm3d(out_channels) 216 | 217 | def forward(self, input): 218 | 219 | input = input.unsqueeze(1) 220 | 221 | out = self.s1(input) 222 | 223 | return out.squeeze(1) 224 | 225 | class SPAModuleIN(nn.Module): 226 | def __init__(self, in_channels, out_channels, k=49, bias=True): 227 | super(SPAModuleIN, self).__init__() 228 | 229 | # print('k=',k) 230 | self.s1 = nn.Conv3d(in_channels, out_channels, kernel_size=(k,3,3), bias=False) 231 | #self.bn = nn.BatchNorm2d(out_channels) 232 | 233 | def forward(self, input): 234 | 235 | # print(input.size()) 236 | out = self.s1(input) 237 | out = out.squeeze(2) 238 | # print(out.size) 239 | 240 | return out 241 | 242 | class ResSPC(nn.Module): 243 | def __init__(self, in_channels, out_channels, bias=True): 244 | super(ResSPC, self).__init__() 245 | 246 | self.spc1 = nn.Sequential(nn.Conv3d(in_channels, in_channels, kernel_size=(7,1,1), padding=(3,0,0), bias=False), 247 | nn.LeakyReLU(inplace=True), 248 | nn.BatchNorm3d(in_channels),) 249 | 250 | self.spc2 = nn.Sequential(nn.Conv3d(in_channels, in_channels, kernel_size=(7,1,1), padding=(3,0,0), bias=False), 251 | nn.LeakyReLU(inplace=True),) 252 | 253 | self.bn2 = nn.BatchNorm3d(out_channels) 254 | 255 | def forward(self, input): 256 | 257 | out = self.spc1(input) 258 | out = self.bn2(self.spc2(out)) 259 | 260 | return F.leaky_relu(out + input) 261 | 262 | class ResSPA(nn.Module): 263 | def __init__(self, in_channels, out_channels, bias=True): 264 | super(ResSPA, self).__init__() 265 | 266 | self.spa1 = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), 267 | nn.LeakyReLU(inplace=True), 268 | nn.BatchNorm2d(in_channels),) 269 | 270 | self.spa2 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 271 | nn.LeakyReLU(inplace=True),) 272 | self.bn2 = nn.BatchNorm2d(out_channels) 273 | 274 | def forward(self, input): 275 | 276 | out = self.spa1(input) 277 | out = self.bn2(self.spa2(out)) 278 | 279 | return F.leaky_relu(out + input) 280 | 281 | 282 | 283 | class SSNet_AEAE_IN(nn.Module): 284 | def __init__(self, inshape,num_classes=16, msize=16, inter_size=49): 285 | super(SSNet_AEAE_IN, self).__init__() 286 | self.channels=inshape[0] 287 | 288 | # self.layer1 = SPCModuleIN_(1, 1, inter_size=inter_size) 289 | # self.bn1 = nn.BatchNorm2d(inter_size) 290 | self.layer1 = nn.Sequential(nn.Conv2d(self.channels, inter_size, 1), 291 | #nn.LeakyReLU(), 292 | nn.BatchNorm2d(inter_size),) 293 | #nn.LeakyReLU()) 294 | # self.layer1 = SPC1d(stride=[2,1], padding=[0,0]) 295 | 296 | self.layer2 = SARes(inter_size, ratio=8) #ResSPA(inter_size, inter_size) 297 | self.layer3 = SPC32(msize, outplane=inter_size, kernel_size=[inter_size,1,1], padding=[0,0,0]) 298 | 299 | self.layer4 = nn.Conv2d(inter_size, msize, kernel_size=1) 300 | self.bn4 = nn.BatchNorm2d(msize) 301 | 302 | self.layer5 = SARes(msize, ratio=8) #ResSPA(msize, msize) 303 | self.layer6 = SPC32(msize, outplane=msize, kernel_size=[msize,1,1], padding=[0,0,0]) 304 | 305 | self.fc = nn.Linear(msize, num_classes) 306 | 307 | def forward(self, x): 308 | # n,c,h,w = x.size() 309 | x = 
x.permute(0,3,1,2) 310 | 311 | x = self.layer1(x) 312 | 313 | x = self.layer2(x) 314 | x = self.layer3(x) 315 | #x = self.layer31(x) 316 | 317 | # x = x.contiguous() 318 | # x = x.reshape(n,-1,h,w) 319 | 320 | x = self.bn4(F.leaky_relu(self.layer4(x))) 321 | x = self.layer5(x) 322 | x = self.layer6(x) 323 | 324 | x = F.avg_pool2d(x, x.size()[-1]) 325 | x = self.fc(x.squeeze(dim=-1).squeeze(dim=-1)) 326 | 327 | return x 328 | 329 | 330 | def SSTN_AEAE(in_shape,num_classes): 331 | 332 | model=SSNet_AEAE_IN(in_shape,num_classes) 333 | 334 | return model -------------------------------------------------------------------------------- /process_dl_disjoint_c_model_m_scale.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : process_dl_disjoint_c_model_m_scale.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | # considering the multiscle feature representation of MCM-CNN 9 | # the customized main processing file for this compared method on the UH data sets 10 | 11 | import os 12 | import time 13 | import torch 14 | import random 15 | import numpy as np 16 | from sklearn import metrics 17 | 18 | import utils.evaluation as evaluation 19 | import utils.data_load_operate_c_model_m_scale as data_load_operate 20 | import c_model.MCM_CNN as MCM_CNN 21 | 22 | time_current = time.strftime("%y-%m-%d-%H.%M", time.localtime()) 23 | 24 | # random seed setting 25 | seed = 20 26 | 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | np.random.seed(seed) # Numpy module. 31 | random.seed(seed) # Python random module. 32 | torch.manual_seed(seed) 33 | torch.backends.cudnn.benchmark = False 34 | torch.backends.cudnn.deterministic = True 35 | 36 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 37 | 38 | ### 0 ## 39 | model_list = ['MCM-CNN'] 40 | model_flag = 0 41 | model_spa_set = {0} 42 | model_spe_set = {} 43 | model_spa_spe_set = {} 44 | model_3D_spa_set = {} 45 | model_3D_spa_flag = 0 46 | 47 | last_batch_flag = 0 48 | 49 | if model_flag in model_spa_set: 50 | model_type_flag = 1 51 | if model_flag in model_3D_spa_set: 52 | model_3D_spa_flag = 1 53 | elif model_flag in model_spe_set: 54 | model_type_flag = 2 55 | elif model_flag in model_spa_spe_set: 56 | model_type_flag = 3 57 | 58 | data_set_name_list = ['UH_tif'] 59 | data_set_name = data_set_name_list[0] 60 | 61 | data_set_path = os.path.join(os.getcwd(), 'data') 62 | 63 | # control running times 64 | # seed_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 65 | # seed_list=[0,1,2,3,4] 66 | # seed_list=[0,1,2] 67 | # seed_list=[0,1] 68 | seed_list = [0] 69 | 70 | ratio = "hu13" 71 | 72 | patch_size = 9 73 | patch_length = 4 74 | 75 | results_save_path = \ 76 | os.path.join(os.getcwd(), 'output/results', model_list[model_flag] + str("_") + 77 | data_set_name + str("_") + str(time_current) + str("_seed") + str(seed) + str("_ratio_") + str( 78 | ratio) + str("_patch_size") + str(patch_size)) 79 | cls_map_save_path = \ 80 | os.path.join(os.path.join(os.getcwd(), 'output/cls_maps'), model_list[model_flag] + str("_") + 81 | data_set_name + str("_") + str(time_current) + str("_seed_") + str(seed) + str("_ratio_") + str(ratio)) 82 | 83 | if __name__ == '__main__': 84 | data, gt_train, gt_test = data_load_operate.load_HU_data(data_set_path) 85 | data = data_load_operate.standardization(data) 86 | 87 | ratio = round(20 / 
data.shape[-1], 2) 88 | data = data_load_operate.HSI_MNF(data, MNF_ratio=ratio) 89 | 90 | gt_train_re = gt_train.reshape(-1) 91 | gt_test_re = gt_test.reshape(-1) 92 | height, width, channels = data.shape 93 | class_count = max(np.unique(gt_train_re)) 94 | 95 | batch_size = 100 96 | learning_rate = 1e-3 97 | scales = 15 98 | max_epoch = 40 99 | loss = torch.nn.CrossEntropyLoss() 100 | 101 | OA_ALL = [] 102 | AA_ALL = [] 103 | KPP_ALL = [] 104 | EACH_ACC_ALL = [] 105 | Train_Time_ALL = [] 106 | Test_Time_ALL = [] 107 | CLASS_ACC = np.zeros([len(seed_list), class_count]) 108 | 109 | for curr_seed in seed_list: 110 | tic1 = time.perf_counter() 111 | train_data_index, test_data_index, all_data_index = data_load_operate.sampling_disjoint(gt_train_re, 112 | gt_test_re, 113 | class_count) 114 | index = (train_data_index, test_data_index) 115 | 116 | train_iter, test_iter = data_load_operate.generate_iter_disjoint_ms(data, gt_train_re, 117 | gt_test_re, index, 118 | batch_size, 119 | model_3D_spa_flag, scales) 120 | 121 | if model_flag == 0: 122 | net = MCM_CNN.MCM_CNN_(scales, class_count, data_set_name) 123 | 124 | net.to(device) 125 | 126 | train_loss_list = [100] 127 | train_acc_list = [0] 128 | 129 | optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=5e-4) 130 | 131 | for epoch in range(max_epoch): 132 | train_acc_sum, trained_samples_counter = 0.0, 0 133 | batch_counter, train_loss_sum = 0, 0 134 | time_epoch = time.time() 135 | 136 | if model_type_flag == 1: # data for single spatial net 137 | for X_spa, y in train_iter: 138 | X_spa, y = X_spa.to(device), y.to(device) 139 | y_pred = net(X_spa) 140 | 141 | ls = loss(y_pred, y.long()) 142 | 143 | optimizer.zero_grad() 144 | ls.backward() 145 | optimizer.step() 146 | 147 | train_loss_sum += ls.cpu().item() 148 | train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item() 149 | trained_samples_counter += y.shape[0] 150 | batch_counter += 1 151 | epoch_first_iter = 0 152 | elif model_type_flag == 2: # data for single spectral net 153 | for X_spe, y in train_iter: 154 | X_spe, y = X_spe.to(device), y.to(device) 155 | y_pred = net(X_spe) 156 | 157 | ls = loss(y_pred, y.long()) 158 | 159 | optimizer.zero_grad() 160 | ls.backward() 161 | optimizer.step() 162 | 163 | train_loss_sum += ls.cpu().item() 164 | train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item() 165 | trained_samples_counter += y.shape[0] 166 | batch_counter += 1 167 | epoch_first_iter = 0 168 | elif model_type_flag == 3: # data for spectral-spatial net 169 | for X_spa, X_spe, y in train_iter: 170 | X_spa, X_spe, y = X_spa.to(device), X_spe.to(device), y.to(device) 171 | y_pred = net(X_spa, X_spe) 172 | 173 | ls = loss(y_pred, y.long()) 174 | 175 | optimizer.zero_grad() 176 | ls.backward() 177 | optimizer.step() 178 | 179 | train_loss_sum += ls.cpu().item() 180 | train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item() 181 | trained_samples_counter += y.shape[0] 182 | batch_counter += 1 183 | epoch_first_iter = 0 184 | 185 | torch.cuda.empty_cache() 186 | 187 | train_loss_list.append(train_loss_sum) 188 | train_acc_list.append(train_acc_sum / trained_samples_counter) 189 | 190 | print('epoch: %d, training_sampler_num: %d, batch_count: %.2f, train loss: %.6f, train loss sum: %.6f, ' 191 | 'train acc: %.3f, train_acc_sum: %.1f, time: %.1f sec' % 192 | (epoch + 1, trained_samples_counter, batch_counter, train_loss_sum / batch_counter, train_loss_sum, 193 | train_acc_sum / trained_samples_counter, train_acc_sum, time.time() - time_epoch)) 194 
| 195 | toc1 = time.perf_counter() 196 | print('Training stage finished:\n epoch %d, loss %.4f, train acc %.3f, training time %.2f s' 197 | % (epoch + 1, train_loss_sum / batch_counter, train_acc_sum / trained_samples_counter, toc1 - tic1)) 198 | training_time = toc1 - tic1 199 | Train_Time_ALL.append(training_time) 200 | 201 | print("\n\n====================Starting evaluation for testing set.========================\n") 202 | 203 | pred_test = [] 204 | y_gt = [] 205 | # torch.cuda.empty_cache() 206 | with torch.no_grad(): 207 | # net.load_state_dict(torch.load(model_save_path+"_best_model.pt")) 208 | net.eval() 209 | train_acc_sum, samples_num_counter = 0.0, 0 210 | if model_type_flag == 1: # data for single spatial net 211 | for X_spa, y in test_iter: 212 | X_spa = X_spa.to(device) 213 | 214 | tic2 = time.perf_counter() 215 | y_pred = net(X_spa) 216 | toc2 = time.perf_counter() 217 | 218 | pred_test.extend(np.array(y_pred.cpu().argmax(axis=1))) 219 | y_gt.extend(y) 220 | elif model_type_flag == 2: # data for single spectral net 221 | for X_spe, y in test_iter: 222 | X_spe = X_spe.to(device) 223 | 224 | tic2 = time.perf_counter() 225 | y_pred = net(X_spe) 226 | toc2 = time.perf_counter() 227 | 228 | pred_test.extend(np.array(y_pred.cpu().argmax(axis=1))) 229 | y_gt.extend(y) 230 | elif model_type_flag == 3: # data for spectral-spatial net 231 | for X_spa, X_spe, y in test_iter: 232 | X_spa = X_spa.to(device) 233 | X_spe = X_spe.to(device) 234 | 235 | tic2 = time.perf_counter() 236 | y_pred = net(X_spa, X_spe) 237 | toc2 = time.perf_counter() 238 | 239 | pred_test.extend(np.array(y_pred.cpu().argmax(axis=1))) 240 | y_gt.extend(y) 241 | 242 | OA = metrics.accuracy_score(y_gt, pred_test) 243 | confusion_matrix = metrics.confusion_matrix(pred_test, y_gt) 244 | print("confusion_matrix\n{}".format(confusion_matrix)) 245 | ECA, AA = evaluation.AA_ECA(confusion_matrix) 246 | kappa = metrics.cohen_kappa_score(pred_test, y_gt) 247 | cls_report = evaluation.claification_report(y_gt, pred_test, data_set_name) 248 | print("classification_report\n{}".format(cls_report)) 249 | 250 | # Visualization for all the labeled samples and total the samples 251 | # sample_list1 = [total_iter] 252 | # sample_list2 = [all_iter, all_data_index] 253 | 254 | # Visualization.gt_cls_map(gt,cls_map_save_path) 255 | # cls_visual.pred_cls_map_dl(sample_list1,net,gt,cls_map_save_path,model_type_flag) 256 | # cls_visual.pred_cls_map_dl(sample_list2,net,gt,cls_map_save_path) 257 | 258 | testing_time = toc2 - tic2 259 | Test_Time_ALL.append(testing_time) 260 | 261 | # Output infors 262 | f = open(results_save_path + '_results.txt', 'a+') 263 | str_results = '\n======================' \ 264 | + " learning rate=" + str(learning_rate) \ 265 | + " epochs=" + str(max_epoch) \ 266 | + " ======================" \ 267 | + "\nOA=" + str(OA) \ 268 | + "\nAA=" + str(AA) \ 269 | + '\nkpp=' + str(kappa) \ 270 | + '\nacc per class:' + str(ECA) \ 271 | + "\ntrain time:" + str(training_time) \ 272 | + "\ntest time:" + str(testing_time) + "\n" 273 | 274 | f.write(str_results) 275 | f.write('{}'.format(confusion_matrix)) 276 | f.write('\n\n') 277 | f.write('{}'.format(cls_report)) 278 | f.close() 279 | 280 | OA_ALL.append(OA) 281 | AA_ALL.append(AA) 282 | KPP_ALL.append(kappa) 283 | EACH_ACC_ALL.append(ECA) 284 | 285 | torch.cuda.empty_cache() 286 | del net, train_iter, test_iter 287 | 288 | OA_ALL = np.array(OA_ALL) 289 | AA_ALL = np.array(AA_ALL) 290 | KPP_ALL = np.array(KPP_ALL) 291 | EACH_ACC_ALL = np.array(EACH_ACC_ALL) 292 | 
Train_Time_ALL = np.array(Train_Time_ALL) 293 | Test_Time_ALL = np.array(Test_Time_ALL) 294 | 295 | np.set_printoptions(precision=4) 296 | print("\n====================Mean result of {} times runs =========================".format(len(seed_list))) 297 | print('List of OA:', list(OA_ALL)) 298 | print('List of AA:', list(AA_ALL)) 299 | print('List of KPP:', list(KPP_ALL)) 300 | print('OA=', round(np.mean(OA_ALL) * 100, 2), '+-', round(np.std(OA_ALL) * 100, 2)) 301 | print('AA=', round(np.mean(AA_ALL) * 100, 2), '+-', round(np.std(AA_ALL) * 100, 2)) 302 | print('Kpp=', round(np.mean(KPP_ALL) * 100, 2), '+-', round(np.std(KPP_ALL) * 100, 2)) 303 | print('Acc per class=', np.mean(EACH_ACC_ALL, 0), '+-', np.std(EACH_ACC_ALL, 0)) 304 | 305 | print("Average training time=", round(np.mean(Train_Time_ALL), 2), '+-', round(np.std(Train_Time_ALL), 3)) 306 | print("Average testing time=", round(np.mean(Test_Time_ALL), 5), '+-', round(np.std(Test_Time_ALL), 5)) 307 | 308 | # Output infors 309 | f = open(results_save_path + '_results.txt', 'a+') 310 | str_results = '\n\n***************Mean result of ' + str(len(seed_list)) + ' times runs ********************' \ 311 | + '\nList of OA:' + str(list(OA_ALL)) \ 312 | + '\nList of AA:' + str(list(AA_ALL)) \ 313 | + '\nList of KPP:' + str(list(KPP_ALL)) \ 314 | + '\nOA=' + str(round(np.mean(OA_ALL) * 100, 2)) + '+-' + str(round(np.std(OA_ALL) * 100, 2)) \ 315 | + '\nAA=' + str(round(np.mean(AA_ALL) * 100, 2)) + '+-' + str(round(np.std(AA_ALL) * 100, 2)) \ 316 | + '\nKpp=' + str(round(np.mean(KPP_ALL) * 100, 2)) + '+-' + str(round(np.std(KPP_ALL) * 100, 2)) \ 317 | + '\nAcc per class=\n' + str(np.mean(EACH_ACC_ALL, 0)) + '+-' + str(np.std(EACH_ACC_ALL, 0)) \ 318 | + "\nAverage training time=" + str(round(np.mean(Train_Time_ALL), 2)) + '+-' + str( 319 | round(np.std(Train_Time_ALL), 3)) \ 320 | + "\nAverage testing time=" + str(round(np.mean(Test_Time_ALL), 5)) + '+-' + str( 321 | round(np.std(Test_Time_ALL), 5)) 322 | f.write(str_results) 323 | f.close() 324 | -------------------------------------------------------------------------------- /process_dl_c_model_m_scale.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : process_dl_c_model_m_scale.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | # considering the multiscle feature representation of MCM-CNN 9 | # the customized main processing file for this compared method on IP and UP data sets 10 | 11 | import os 12 | import time 13 | import torch 14 | import random 15 | import numpy as np 16 | from sklearn import metrics 17 | 18 | import utils.evaluation as evaluation 19 | import utils.data_load_operate_c_model_m_scale as data_load_operate 20 | import visual.cls_visual as cls_visual 21 | 22 | import c_model.MCM_CNN as MCM_CNN 23 | 24 | time_current = time.strftime("%y-%m-%d-%H.%M", time.localtime()) 25 | 26 | # random seed setting 27 | seed = 20 28 | 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | np.random.seed(seed) # Numpy module. 33 | random.seed(seed) # Python random module. 
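# Note: the seeding calls above and below together pin every RNG source this pipeline
# touches (torch CPU/CUDA, NumPy, Python's random) and force deterministic cuDNN
# kernels, so repeated runs with the same seed remain comparable.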
34 | torch.manual_seed(seed) 35 | torch.backends.cudnn.benchmark = False 36 | torch.backends.cudnn.deterministic = True 37 | 38 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 39 | 40 | ### 0 ## 41 | model_list = ['MCM-CNN'] 42 | model_flag = 0 43 | model_spa_set = {0} 44 | model_spe_set = {} 45 | model_spa_spe_set = {} 46 | model_3D_spa_set = {} 47 | model_3D_spa_flag = 0 48 | 49 | last_batch_flag = 0 50 | 51 | if model_flag in model_spa_set: 52 | model_type_flag = 1 53 | if model_flag in model_3D_spa_set: 54 | model_3D_spa_flag = 1 55 | elif model_flag in model_spe_set: 56 | model_type_flag = 2 57 | elif model_flag in model_spa_spe_set: 58 | model_type_flag = 3 59 | 60 | data_set_name_list = ['IP', 'UP'] 61 | data_set_name = data_set_name_list[1] 62 | 63 | data_set_path = os.path.join(os.getcwd(), 'data') 64 | 65 | # control running times 66 | # seed_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 67 | # seed_list=[0,1,2,3,4] 68 | # seed_list=[0,1,2] 69 | # seed_list=[0,1] 70 | seed_list = [0] 71 | 72 | # data set split 73 | flag_list = [0, 1] # ratio or num 74 | 75 | if data_set_name == 'IP': 76 | ratio_list = [0.05, 0.005] 77 | ratio = 5.0 78 | elif data_set_name == 'UP': 79 | ratio_list = [0.01, 0.001] 80 | ratio = 1.0 81 | 82 | num_list = [50, 0] # [train_num,val_num] 83 | 84 | patch_size = 9 85 | patch_length = 4 86 | 87 | results_save_path = \ 88 | os.path.join(os.getcwd(), 'output/results', model_list[model_flag] + str("_") + 89 | data_set_name + str("_") + str(time_current) + str("_seed") + str(seed) + str("_ratio_") + str( 90 | ratio) + str("_patch_size") + str(patch_size)) 91 | cls_map_save_path = \ 92 | os.path.join(os.path.join(os.getcwd(), 'output/cls_maps'), model_list[model_flag] + str("_") + 93 | data_set_name + str("_") + str(time_current) + str("_seed_") + str(seed) + str("_ratio_") + str(ratio)) 94 | 95 | if __name__ == '__main__': 96 | 97 | data, gt = data_load_operate.load_data(data_set_name, data_set_path) 98 | data = data_load_operate.standardization(data) 99 | 100 | if model_flag == 0: 101 | ratio = round(20 / data.shape[-1], 2) 102 | data = data_load_operate.HSI_MNF(data, MNF_ratio=ratio) 103 | 104 | gt_reshape = gt.reshape(-1) 105 | height, width, channels = data.shape 106 | class_count = max(np.unique(gt)) 107 | 108 | if model_flag == 0: 109 | batch_size = 100 110 | learning_rate = 1e-3 111 | 112 | scales = 15 113 | max_epoch = 100 114 | loss = torch.nn.CrossEntropyLoss() 115 | 116 | OA_ALL = [] 117 | AA_ALL = [] 118 | KPP_ALL = [] 119 | EACH_ACC_ALL = [] 120 | Train_Time_ALL = [] 121 | Test_Time_ALL = [] 122 | CLASS_ACC = np.zeros([len(seed_list), class_count]) 123 | 124 | for curr_seed in seed_list: 125 | tic1 = time.perf_counter() 126 | train_data_index, test_data_index, all_data_index = data_load_operate.sampling(ratio_list, 127 | num_list, 128 | gt_reshape, 129 | class_count, 130 | flag_list[0]) 131 | index = (train_data_index, test_data_index) 132 | 133 | train_iter, test_iter = data_load_operate.generate_iter_ms(data, gt_reshape, index, 134 | batch_size, 135 | model_3D_spa_flag, scales) 136 | 137 | if model_flag == 0: 138 | net = MCM_CNN.MCM_CNN_(scales, class_count, data_set_name) 139 | 140 | net.to(device) 141 | 142 | train_loss_list = [100] 143 | train_acc_list = [0] 144 | 145 | if model_flag == 0: 146 | optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=5e-4) 147 | 148 | for epoch in range(max_epoch): 149 | train_acc_sum, trained_samples_counter = 0.0, 0 150 | batch_counter, train_loss_sum = 0, 0 151 | 
time_epoch = time.time() 152 | 153 | if model_type_flag == 1: # data for single spatial net 154 | for X_spa, y in train_iter: 155 | X_spa, y = X_spa.to(device), y.to(device) 156 | y_pred = net(X_spa) 157 | 158 | ls = loss(y_pred, y.long()) 159 | 160 | optimizer.zero_grad() 161 | ls.backward() 162 | optimizer.step() 163 | 164 | train_loss_sum += ls.cpu().item() 165 | train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item() 166 | trained_samples_counter += y.shape[0] 167 | batch_counter += 1 168 | epoch_first_iter = 0 169 | elif model_type_flag == 2: # data for single spectral net 170 | for X_spe, y in train_iter: 171 | X_spe, y = X_spe.to(device), y.to(device) 172 | y_pred = net(X_spe) 173 | 174 | ls = loss(y_pred, y.long()) 175 | 176 | optimizer.zero_grad() 177 | ls.backward() 178 | optimizer.step() 179 | 180 | train_loss_sum += ls.cpu().item() 181 | train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item() 182 | trained_samples_counter += y.shape[0] 183 | batch_counter += 1 184 | epoch_first_iter = 0 185 | elif model_type_flag == 3: # data for spectral-spatial net 186 | for X_spa, X_spe, y in train_iter: 187 | X_spa, X_spe, y = X_spa.to(device), X_spe.to(device), y.to(device) 188 | y_pred = net(X_spa, X_spe) 189 | 190 | ls = loss(y_pred, y.long()) 191 | 192 | optimizer.zero_grad() 193 | ls.backward() 194 | optimizer.step() 195 | 196 | train_loss_sum += ls.cpu().item() 197 | train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item() 198 | trained_samples_counter += y.shape[0] 199 | batch_counter += 1 200 | epoch_first_iter = 0 201 | 202 | torch.cuda.empty_cache() 203 | 204 | train_loss_list.append(train_loss_sum) 205 | train_acc_list.append(train_acc_sum / trained_samples_counter) 206 | 207 | print('epoch: %d, training_sampler_num: %d, batch_count: %.2f, train loss: %.6f, train loss sum: %.6f, ' 208 | 'train acc: %.3f, train_acc_sum: %.1f, time: %.1f sec' % 209 | (epoch + 1, trained_samples_counter, batch_counter, train_loss_sum / batch_counter, train_loss_sum, 210 | train_acc_sum / trained_samples_counter, train_acc_sum, time.time() - time_epoch)) 211 | 212 | toc1 = time.perf_counter() 213 | print('Training stage finished:\n epoch %d, loss %.4f, train acc %.3f, training time %.2f s' 214 | % (epoch + 1, train_loss_sum / batch_counter, train_acc_sum / trained_samples_counter, toc1 - tic1)) 215 | training_time = toc1 - tic1 216 | Train_Time_ALL.append(training_time) 217 | 218 | print("\n\n====================Starting evaluation for testing set.========================\n") 219 | 220 | pred_test = [] 221 | y_gt = [] 222 | # torch.cuda.empty_cache() 223 | with torch.no_grad(): 224 | # net.load_state_dict(torch.load(model_save_path+"_best_model.pt")) 225 | net.eval() 226 | train_acc_sum, samples_num_counter = 0.0, 0 227 | if model_type_flag == 1: # data for single spatial net 228 | for X_spa, y in test_iter: 229 | X_spa = X_spa.to(device) 230 | 231 | tic2 = time.perf_counter() 232 | y_pred = net(X_spa) 233 | toc2 = time.perf_counter() 234 | 235 | pred_test.extend(np.array(y_pred.cpu().argmax(axis=1))) 236 | y_gt.extend(y) 237 | elif model_type_flag == 2: # data for single spectral net 238 | for X_spe, y in test_iter: 239 | X_spe = X_spe.to(device) 240 | 241 | tic2 = time.perf_counter() 242 | y_pred = net(X_spe) 243 | toc2 = time.perf_counter() 244 | 245 | pred_test.extend(np.array(y_pred.cpu().argmax(axis=1))) 246 | y_gt.extend(y) 247 | elif model_type_flag == 3: # data for spectral-spatial net 248 | for X_spa, X_spe, y in test_iter: 249 | X_spa = X_spa.to(device) 250 | X_spe = 
X_spe.to(device) 251 | 252 | tic2 = time.perf_counter() 253 | y_pred = net(X_spa, X_spe) 254 | toc2 = time.perf_counter() 255 | 256 | pred_test.extend(np.array(y_pred.cpu().argmax(axis=1))) 257 | y_gt.extend(y) 258 | 259 | OA = metrics.accuracy_score(y_gt, pred_test) 260 | confusion_matrix = metrics.confusion_matrix(pred_test, y_gt) 261 | print("confusion_matrix\n{}".format(confusion_matrix)) 262 | ECA, AA = evaluation.AA_ECA(confusion_matrix) 263 | kappa = metrics.cohen_kappa_score(pred_test, y_gt) 264 | cls_report = evaluation.claification_report(y_gt, pred_test, data_set_name) 265 | print("classification_report\n{}".format(cls_report)) 266 | 267 | # Visualization for all the labeled samples and total the samples 268 | # sample_list1 = [total_iter] 269 | # sample_list2 = [all_iter, all_data_index] 270 | 271 | # Visualization.gt_cls_map(gt,cls_map_save_path) 272 | # cls_visual.pred_cls_map_dl(sample_list1,net,gt,cls_map_save_path,model_type_flag) 273 | # cls_visual.pred_cls_map_dl(sample_list2,net,gt,cls_map_save_path) 274 | 275 | testing_time = toc2 - tic2 276 | Test_Time_ALL.append(testing_time) 277 | 278 | # Output infors 279 | f = open(results_save_path + '_results.txt', 'a+') 280 | str_results = '\n======================' \ 281 | + " learning rate=" + str(learning_rate) \ 282 | + " epochs=" + str(max_epoch) \ 283 | + " train ratio=" + str(ratio_list[0]) \ 284 | + " val ratio=" + str(ratio_list[1]) \ 285 | + " ======================" \ 286 | + "\nOA=" + str(OA) \ 287 | + "\nAA=" + str(AA) \ 288 | + '\nkpp=' + str(kappa) \ 289 | + '\nacc per class:' + str(ECA) \ 290 | + "\ntrain time:" + str(training_time) \ 291 | + "\ntest time:" + str(testing_time) + "\n" 292 | 293 | f.write(str_results) 294 | f.write('{}'.format(confusion_matrix)) 295 | f.write('\n\n') 296 | f.write('{}'.format(cls_report)) 297 | f.close() 298 | 299 | OA_ALL.append(OA) 300 | AA_ALL.append(AA) 301 | KPP_ALL.append(kappa) 302 | EACH_ACC_ALL.append(ECA) 303 | 304 | torch.cuda.empty_cache() 305 | del net, train_iter, test_iter 306 | # del net, train_iter, test_iter, val_iter 307 | # del net, train_iter, test_iter, val_iter, all_iter 308 | # del net 309 | 310 | OA_ALL = np.array(OA_ALL) 311 | AA_ALL = np.array(AA_ALL) 312 | KPP_ALL = np.array(KPP_ALL) 313 | EACH_ACC_ALL = np.array(EACH_ACC_ALL) 314 | Train_Time_ALL = np.array(Train_Time_ALL) 315 | Test_Time_ALL = np.array(Test_Time_ALL) 316 | 317 | np.set_printoptions(precision=4) 318 | print("\n====================Mean result of {} times runs =========================".format(len(seed_list))) 319 | print('List of OA:', list(OA_ALL)) 320 | print('List of AA:', list(AA_ALL)) 321 | print('List of KPP:', list(KPP_ALL)) 322 | print('OA=', round(np.mean(OA_ALL) * 100, 2), '+-', round(np.std(OA_ALL) * 100, 2)) 323 | print('AA=', round(np.mean(AA_ALL) * 100, 2), '+-', round(np.std(AA_ALL) * 100, 2)) 324 | print('Kpp=', round(np.mean(KPP_ALL) * 100, 2), '+-', round(np.std(KPP_ALL) * 100, 2)) 325 | print('Acc per class=', np.mean(EACH_ACC_ALL, 0), '+-', np.std(EACH_ACC_ALL, 0)) 326 | 327 | print("Average training time=", round(np.mean(Train_Time_ALL), 2), '+-', round(np.std(Train_Time_ALL), 3)) 328 | print("Average testing time=", round(np.mean(Test_Time_ALL), 5), '+-', round(np.std(Test_Time_ALL), 5)) 329 | 330 | # Output infors 331 | f = open(results_save_path + '_results.txt', 'a+') 332 | str_results = '\n\n***************Mean result of ' + str(len(seed_list)) + ' times runs ********************' \ 333 | + '\nList of OA:' + str(list(OA_ALL)) \ 334 | + '\nList of AA:' + 
str(list(AA_ALL)) \ 335 | + '\nList of KPP:' + str(list(KPP_ALL)) \ 336 | + '\nOA=' + str(round(np.mean(OA_ALL) * 100, 2)) + '+-' + str(round(np.std(OA_ALL) * 100, 2)) \ 337 | + '\nAA=' + str(round(np.mean(AA_ALL) * 100, 2)) + '+-' + str(round(np.std(AA_ALL) * 100, 2)) \ 338 | + '\nKpp=' + str(round(np.mean(KPP_ALL) * 100, 2)) + '+-' + str(round(np.std(KPP_ALL) * 100, 2)) \ 339 | + '\nAcc per class=\n' + str(np.mean(EACH_ACC_ALL, 0)) + '+-' + str(np.std(EACH_ACC_ALL, 0)) \ 340 | + "\nAverage training time=" + str(round(np.mean(Train_Time_ALL), 2)) + '+-' + str( 341 | round(np.std(Train_Time_ALL), 3)) \ 342 | + "\nAverage testing time=" + str(round(np.mean(Test_Time_ALL), 5)) + '+-' + str( 343 | round(np.std(Test_Time_ALL), 5)) 344 | f.write(str_results) 345 | f.close() 346 | -------------------------------------------------------------------------------- /c_model/A2S2KResNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : A2S2KResNet.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | # https://github.com/suvojit-0x55aa/A2S2K-ResNet/blob/master/A2S2KResNet/A2S2KResNet.py 9 | # Attention-Based Adaptive Spectral-Spatial Kernel ResNet for Hyperspectral Image Classification, TGRS 2020 10 | 11 | import math 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | 16 | PARAM_KERNEL_SIZE = 24 17 | 18 | 19 | class ChannelSELayer3D(nn.Module): 20 | """ 21 | 3D extension of Squeeze-and-Excitation (SE) block described in: 22 | *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507* 23 | *Zhu et al., AnatomyNet, arXiv:arXiv:1808.05238* 24 | """ 25 | 26 | def __init__(self, num_channels, reduction_ratio=2): 27 | """ 28 | :param num_channels: No of input channels 29 | :param reduction_ratio: By how much should the num_channels should be reduced 30 | """ 31 | super(ChannelSELayer3D, self).__init__() 32 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 33 | num_channels_reduced = num_channels // reduction_ratio 34 | self.reduction_ratio = reduction_ratio 35 | self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True) 36 | self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True) 37 | self.relu = nn.ReLU() 38 | self.sigmoid = nn.Sigmoid() 39 | 40 | def forward(self, input_tensor): 41 | """ 42 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W) 43 | :return: output tensor 44 | """ 45 | batch_size, num_channels, D, H, W = input_tensor.size() 46 | # Average along each channel 47 | squeeze_tensor = self.avg_pool(input_tensor) 48 | 49 | # channel excitation 50 | fc_out_1 = self.relu( 51 | self.fc1(squeeze_tensor.view(batch_size, num_channels))) 52 | fc_out_2 = self.sigmoid(self.fc2(fc_out_1)) 53 | 54 | output_tensor = torch.mul( 55 | input_tensor, fc_out_2.view(batch_size, num_channels, 1, 1, 1)) 56 | 57 | return output_tensor 58 | 59 | 60 | class SpatialSELayer3D(nn.Module): 61 | """ 62 | 3D extension of SE block -- squeezing spatially and exciting channel-wise described in: 63 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018* 64 | """ 65 | 66 | def __init__(self, num_channels): 67 | """ 68 | :param num_channels: No of input channels 69 | """ 70 | super(SpatialSELayer3D, self).__init__() 71 | self.conv = nn.Conv3d(num_channels, 1, 1) 72 | self.sigmoid = nn.Sigmoid() 73 | 74 | def forward(self, input_tensor, 
weights=None): 75 | """ 76 | :param weights: weights for few shot learning 77 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W) 78 | :return: output_tensor 79 | """ 80 | # channel squeeze 81 | batch_size, channel, D, H, W = input_tensor.size() 82 | 83 | if weights is not None: # explicit None check; truth value of a tensor is ambiguous 84 | weights = weights.view(1, channel, 1, 1, 1) 85 | out = F.conv3d(input_tensor, weights) 86 | else: 87 | out = self.conv(input_tensor) 88 | 89 | squeeze_tensor = self.sigmoid(out) 90 | 91 | # spatial excitation 92 | output_tensor = torch.mul(input_tensor, 93 | squeeze_tensor.view(batch_size, 1, D, H, W)) 94 | 95 | return output_tensor 96 | 97 | 98 | class ChannelSpatialSELayer3D(nn.Module): 99 | """ 100 | 3D extension of concurrent spatial and channel squeeze & excitation: 101 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579* 102 | """ 103 | 104 | def __init__(self, num_channels, reduction_ratio=2): 105 | """ 106 | :param num_channels: No of input channels 107 | :param reduction_ratio: By how much the num_channels should be reduced 108 | """ 109 | super(ChannelSpatialSELayer3D, self).__init__() 110 | self.cSE = ChannelSELayer3D(num_channels, reduction_ratio) 111 | self.sSE = SpatialSELayer3D(num_channels) 112 | 113 | def forward(self, input_tensor): 114 | """ 115 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W) 116 | :return: output_tensor 117 | """ 118 | output_tensor = torch.max( 119 | self.cSE(input_tensor), self.sSE(input_tensor)) 120 | return output_tensor 121 | 122 | 123 | class ProjectExciteLayer(nn.Module): 124 | """ 125 | Project & Excite Module, specifically designed for 3D inputs 126 | *quote* 127 | """ 128 | 129 | def __init__(self, num_channels, reduction_ratio=2): 130 | """ 131 | :param num_channels: No of input channels 132 | :param reduction_ratio: By how much the num_channels should be reduced 133 | """ 134 | super(ProjectExciteLayer, self).__init__() 135 | num_channels_reduced = num_channels // reduction_ratio 136 | self.reduction_ratio = reduction_ratio 137 | self.relu = nn.ReLU() 138 | self.conv_c = nn.Conv3d( 139 | in_channels=num_channels, 140 | out_channels=num_channels_reduced, 141 | kernel_size=1, 142 | stride=1) 143 | self.conv_cT = nn.Conv3d( 144 | in_channels=num_channels_reduced, 145 | out_channels=num_channels, 146 | kernel_size=1, 147 | stride=1) 148 | self.sigmoid = nn.Sigmoid() 149 | 150 | def forward(self, input_tensor): 151 | """ 152 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W) 153 | :return: output tensor 154 | """ 155 | batch_size, num_channels, D, H, W = input_tensor.size() 156 | 157 | # Average along channels and different axes 158 | squeeze_tensor_w = F.adaptive_avg_pool3d(input_tensor, (1, 1, W)) 159 | 160 | squeeze_tensor_h = F.adaptive_avg_pool3d(input_tensor, (1, H, 1)) 161 | 162 | squeeze_tensor_d = F.adaptive_avg_pool3d(input_tensor, (D, 1, 1)) 163 | 164 | # tile tensors to original size and add: 165 | final_squeeze_tensor = sum([ 166 | squeeze_tensor_w.view(batch_size, num_channels, 1, 1, W), 167 | squeeze_tensor_h.view(batch_size, num_channels, 1, H, 1), 168 | squeeze_tensor_d.view(batch_size, num_channels, D, 1, 1) 169 | ]) 170 | 171 | # Excitation: 172 | final_squeeze_tensor = self.sigmoid( 173 | self.conv_cT(self.relu(self.conv_c(final_squeeze_tensor)))) 174 | output_tensor = torch.mul(input_tensor, final_squeeze_tensor) 175 | 176 | return output_tensor 177 | 178 | 179 | class eca_layer(nn.Module): 180 | """Constructs an ECA module. 
181 | Args: 182 | channel: Number of channels of the input feature map 183 | k_size: Adaptive selection of kernel size 184 | """ 185 | 186 | def __init__(self, channel, k_size=3): 187 | super(eca_layer, self).__init__() 188 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 189 | self.conv = nn.Conv2d( 190 | 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 191 | self.sigmoid = nn.Sigmoid() 192 | 193 | def forward(self, x): 194 | # x: input features with shape [b, c, h, w] 195 | b, c, h, w, t = x.size() 196 | 197 | # feature descriptor on the global spatial information 198 | # 24, 1, 1, 1 199 | y = self.avg_pool(x) 200 | 201 | # Two different branches of ECA module 202 | y = self.conv(y.squeeze(-1).transpose(-1, -3)).transpose( 203 | -1, -3).unsqueeze(-1) 204 | 205 | # Multi-scale information fusion 206 | y = self.sigmoid(y) 207 | 208 | return x * y.expand_as(x) 209 | 210 | 211 | class Residual(nn.Module): # pytorch 212 | def __init__( 213 | self, 214 | in_channels, 215 | out_channels, 216 | kernel_size, 217 | padding, 218 | use_1x1conv=False, 219 | stride=1, 220 | start_block=False, 221 | end_block=False, 222 | ): 223 | super(Residual, self).__init__() 224 | self.conv1 = nn.Sequential( 225 | nn.Conv3d( 226 | in_channels, 227 | out_channels, 228 | kernel_size=kernel_size, 229 | padding=padding, 230 | stride=stride), nn.ReLU()) 231 | self.conv2 = nn.Conv3d( 232 | out_channels, 233 | out_channels, 234 | kernel_size=kernel_size, 235 | padding=padding, 236 | stride=stride) 237 | if use_1x1conv: 238 | self.conv3 = nn.Conv3d( 239 | in_channels, out_channels, kernel_size=1, stride=stride) 240 | else: 241 | self.conv3 = None 242 | 243 | if not start_block: 244 | self.bn0 = nn.BatchNorm3d(in_channels) 245 | 246 | self.bn1 = nn.BatchNorm3d(out_channels) 247 | self.bn2 = nn.BatchNorm3d(out_channels) 248 | 249 | if start_block: 250 | self.bn2 = nn.BatchNorm3d(out_channels) 251 | 252 | if end_block: 253 | self.bn2 = nn.BatchNorm3d(out_channels) 254 | 255 | # ECA Attention Layer 256 | self.ecalayer = eca_layer(out_channels) 257 | 258 | # start and end block initialization 259 | self.start_block = start_block 260 | self.end_block = end_block 261 | 262 | def forward(self, X): 263 | identity = X 264 | 265 | if self.start_block: 266 | out = self.conv1(X) 267 | else: 268 | out = self.bn0(X) 269 | out = F.relu(out) 270 | out = self.conv1(out) 271 | 272 | out = self.bn1(out) 273 | out = F.relu(out) 274 | 275 | out = self.conv2(out) 276 | 277 | if self.start_block: 278 | out = self.bn2(out) 279 | 280 | out = self.ecalayer(out) # EFR 281 | 282 | out += identity 283 | 284 | if self.end_block: 285 | out = self.bn2(out) 286 | out = F.relu(out) 287 | 288 | return out 289 | 290 | 291 | class S3KAIResNet(nn.Module): 292 | def __init__(self, band, classes, reduction): 293 | super(S3KAIResNet, self).__init__() 294 | self.name = 'SSRN' 295 | self.conv1x1 = nn.Conv3d( 296 | in_channels=1, 297 | out_channels=PARAM_KERNEL_SIZE, 298 | kernel_size=(1, 1, 7), 299 | stride=(1, 1, 2), 300 | padding=0) 301 | 302 | self.conv3x3 = nn.Conv3d( 303 | in_channels=1, 304 | out_channels=PARAM_KERNEL_SIZE, 305 | kernel_size=(3, 3, 7), 306 | stride=(1, 1, 2), 307 | padding=(1, 1, 0)) 308 | 309 | self.batch_norm1x1 = nn.Sequential( 310 | nn.BatchNorm3d( 311 | PARAM_KERNEL_SIZE, eps=0.001, momentum=0.1, 312 | affine=True), # 0.1 313 | nn.ReLU(inplace=True)) 314 | self.batch_norm3x3 = nn.Sequential( 315 | nn.BatchNorm3d( 316 | PARAM_KERNEL_SIZE, eps=0.001, momentum=0.1, 317 | affine=True), # 0.1 318 | nn.ReLU(inplace=True)) 319 | 320 | 
self.pool = nn.AdaptiveAvgPool3d(1) 321 | self.conv_se = nn.Sequential( 322 | nn.Conv3d( 323 | PARAM_KERNEL_SIZE, band // reduction, 1, padding=0, bias=True), 324 | nn.ReLU(inplace=True)) 325 | self.conv_ex = nn.Conv3d( 326 | band // reduction, PARAM_KERNEL_SIZE, 1, padding=0, bias=True) 327 | self.softmax = nn.Softmax(dim=1) 328 | 329 | self.res_net1 = Residual( 330 | PARAM_KERNEL_SIZE, 331 | PARAM_KERNEL_SIZE, (1, 1, 7), (0, 0, 3), 332 | start_block=True) 333 | self.res_net2 = Residual(PARAM_KERNEL_SIZE, PARAM_KERNEL_SIZE, 334 | (1, 1, 7), (0, 0, 3)) 335 | self.res_net3 = Residual(PARAM_KERNEL_SIZE, PARAM_KERNEL_SIZE, 336 | (3, 3, 1), (1, 1, 0)) 337 | self.res_net4 = Residual( 338 | PARAM_KERNEL_SIZE, 339 | PARAM_KERNEL_SIZE, (3, 3, 1), (1, 1, 0), 340 | end_block=True) 341 | 342 | kernel_3d = math.ceil((band - 6) / 2) # 97 343 | # print(kernel_3d) 344 | 345 | self.conv2 = nn.Conv3d( 346 | in_channels=PARAM_KERNEL_SIZE, 347 | out_channels=128, 348 | kernel_size=(1, 1, kernel_3d), 349 | stride=(1, 1, 1), 350 | padding=(0, 0, 0)) 351 | 352 | self.batch_norm2 = nn.Sequential( 353 | nn.BatchNorm3d(128, eps=0.001, momentum=0.1, affine=True), # 0.1 354 | nn.ReLU(inplace=True)) 355 | self.conv3 = nn.Conv3d( 356 | in_channels=1, 357 | out_channels=PARAM_KERNEL_SIZE, 358 | kernel_size=(3, 3, 128), 359 | stride=(1, 1, 1), 360 | padding=(0, 0, 0) 361 | ) 362 | self.batch_norm3 = nn.Sequential( 363 | nn.BatchNorm3d( 364 | PARAM_KERNEL_SIZE, eps=0.001, momentum=0.1, 365 | affine=True), # 0.1 366 | nn.ReLU(inplace=True)) 367 | 368 | self.avg_pooling = nn.AvgPool3d(kernel_size=(5, 5, 1)) # kernel_size=stride 369 | self.full_connection = nn.Sequential( 370 | nn.Linear(PARAM_KERNEL_SIZE, classes) 371 | # nn.Softmax() 372 | ) 373 | 374 | def forward(self, X): 375 | # X = X.permute(0,1,4,3,2) 376 | x_1x1 = self.conv1x1(X) # 16X1X9X9X200-->[16,24,9,9,97] # Dimensionality Reduction 377 | x_1x1 = self.batch_norm1x1(x_1x1).unsqueeze(dim=1) # [16,1,24,9,9,97] 378 | x_3x3 = self.conv3x3(X) # [16,24,9,9,97] 379 | x_3x3 = self.batch_norm3x3(x_3x3).unsqueeze(dim=1) # [16,1,24,9,9,97] 380 | 381 | x1 = torch.cat([x_3x3, x_1x1], dim=1) # [16,2,24,9,9,97] 382 | U = torch.sum(x1, dim=1) # [2,24,9,9,97] # Dimensionality 1 will be reduced 383 | S = self.pool(U) # [16,24,1,1,1] 384 | Z = self.conv_se(S) # [16,100,1,1,1] 385 | attention_vector = torch.cat( 386 | [ 387 | self.conv_ex(Z).unsqueeze(dim=1), 388 | self.conv_ex(Z).unsqueeze(dim=1) 389 | ], 390 | dim=1) # [16,2,24,1,1,1] 391 | attention_vector = self.softmax(attention_vector) # [16,2,24,1,1,1] 392 | V = (x1 * attention_vector).sum(dim=1) # [16,24,9,9,97] 393 | 394 | ######################################## 395 | x2 = self.res_net1(V) # [16,24,9,9,97] 396 | x2 = self.res_net2(x2) # [16,24,9,9,97] 397 | x2 = self.batch_norm2(self.conv2(x2)) # [16,128,9,9,1] 398 | x2 = x2.permute(0, 4, 2, 3, 1) # [16,1,9,9,128] 399 | x2 = self.batch_norm3(self.conv3(x2)) # [16,24,7,7,1] 400 | 401 | x3 = self.res_net3(x2) # [16,24,7,7,1] 402 | x3 = self.res_net4(x3) # [16,24,7,7,1] 403 | x4 = self.avg_pooling(x3) # [16,24,1,1,1] 404 | x4 = x4.view(x4.size(0), -1) # [16,24] 405 | return self.full_connection(x4) 406 | -------------------------------------------------------------------------------- /process_dl_disjoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : process_dl_disjoint.py 6 | # @Project : AMS-M2ESL 
(HSIC), IEEE TGRS 7 | 8 | # for the UH data set, main processing file for the proposed AMS-M2ESL model 9 | 10 | import os 11 | import time 12 | import torch 13 | import random 14 | import numpy as np 15 | from sklearn import metrics 16 | from ptflops import get_model_complexity_info 17 | 18 | import utils.evaluation as evaluation 19 | import utils.data_load_operate as data_load_operate 20 | import visual.cls_visual as cls_visual 21 | import model.AMS_M2ESL as AMS_M2ESL 22 | 23 | time_current = time.strftime("%y-%m-%d-%H.%M", time.localtime()) 24 | 25 | # random seed setting 26 | seed = 20 27 | 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed(seed) 30 | torch.cuda.manual_seed_all(seed) 31 | np.random.seed(seed) # Numpy module. 32 | random.seed(seed) # Python random module. 33 | torch.manual_seed(seed) 34 | torch.backends.cudnn.benchmark = False 35 | torch.backends.cudnn.deterministic = True 36 | 37 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 38 | 39 | ### 0 ### 40 | model_list = ['AMS-M2ESL'] 41 | model_flag = 0 42 | model_spa_set = {0} 43 | model_spe_set = {} 44 | model_spa_spe_set = {} 45 | model_3D_spa_set = {} 46 | model_3D_spa_flag = 0 47 | 48 | if model_flag in model_spa_set: 49 | model_type_flag = 1 50 | if model_flag in model_3D_spa_set: 51 | model_3D_spa_flag = 1 52 | elif model_flag in model_spe_set: 53 | model_type_flag = 2 54 | elif model_flag in model_spa_spe_set: 55 | model_type_flag = 3 56 | 57 | # 0 58 | data_set_name_list = ['UH_tif'] 59 | data_set_name = data_set_name_list[0] 60 | 61 | data_set_path = os.path.join(os.getcwd(), 'data') 62 | 63 | # control running times 64 | # seed_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 65 | # seed_list=[0,1,2,3,4] 66 | # seed_list=[0,1,2] 67 | # seed_list=[0,1] 68 | seed_list = [0] 69 | 70 | ratio = "hu13" 71 | 72 | patch_size = 9 73 | patch_length = 4 74 | 75 | results_save_path = \ 76 | os.path.join(os.path.join(os.getcwd(), 'output/results'), model_list[model_flag] + str("_") + 77 | data_set_name + str("_") + str(time_current) + str("_seed_") + str(seed) + str("_ratio_") + str( 78 | ratio) + str("_patch_size_") + str(patch_size)) 79 | cls_map_save_path = \ 80 | os.path.join(os.path.join(os.getcwd(), 'output/cls_maps'), model_list[model_flag] + str("_") + 81 | data_set_name + str("_") + str(time_current) + str("_seed_") + str(seed) + str("_ratio_") + str(ratio)) 82 | 83 | if __name__ == '__main__': 84 | 85 | data, gt_train, gt_test = data_load_operate.load_HU_data(data_set_path) 86 | data = data_load_operate.standardization(data) 87 | 88 | # dr=math.ceil(data.shape[-1]*0.3) 89 | # data=data_load_operate.applyPCA(data,dr) # for abla of MNF 90 | data = data_load_operate.HSI_MNF(data, MNF_ratio=0.3) 91 | 92 | gt_train_re = gt_train.reshape(-1) 93 | gt_test_re = gt_test.reshape(-1) 94 | height, width, channels = data.shape 95 | class_count = max(np.unique(gt_train_re)) 96 | 97 | batch_size = 256 98 | max_epoch = 40 99 | learning_rate = 0.001 100 | loss = torch.nn.CrossEntropyLoss() 101 | 102 | OA_ALL = [] 103 | AA_ALL = [] 104 | KPP_ALL = [] 105 | EACH_ACC_ALL = [] 106 | Train_Time_ALL = [] 107 | Test_Time_ALL = [] 108 | CLASS_ACC = np.zeros([len(seed_list), class_count]) 109 | 110 | # data pad zero 111 | # data:[h,w,c]->data_padded:[h+2l,w+2l,c] 112 | data_padded = data_load_operate.data_pad_zero(data, patch_length) 113 | height_patched, width_patched = data_padded.shape[0], data_padded.shape[1] 114 | 115 | # data_total_index = np.arange(data.shape[0] * data.shape[1]) # For total sample cls_map. 
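# Each entry in seed_list drives one full train/evaluate run below; per-run OA/AA/kappa,
# per-class accuracy, and timing are appended to the *_ALL lists and summarized after the loop.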
116 | 117 | for curr_seed in seed_list: 118 | tic1 = time.perf_counter() 119 | 120 | train_data_index, test_data_index, all_data_index = data_load_operate.sampling_disjoint(gt_train_re, 121 | gt_test_re, 122 | class_count) 123 | index = (train_data_index, test_data_index) 124 | train_iter, test_iter = data_load_operate.generate_iter_disjoint(data_padded, height, width, gt_train_re, 125 | gt_test_re, index, patch_length, 126 | batch_size, model_type_flag, 127 | model_3D_spa_flag) 128 | # load data for the cls map of the total samples 129 | # total_iter = data_load_operate.generate_iter_2(data_padded, height, width, gt_train_re, data_total_index, patch_length, 130 | # 256, model_type_flag, model_3D_spa_flag) 131 | 132 | if model_flag == 0: 133 | net = AMS_M2ESL.AMS_M2ESL_(in_channels=channels, patch_size=patch_size, num_classes=class_count, 134 | ds=data_set_name) 135 | 136 | net.to(device) 137 | 138 | # efficiency test, model complexity and computational cost 139 | # flops,para=get_model_complexity_info(net,(channels,1,1),as_strings=False,print_per_layer_stat=True, verbose=True) 140 | # flops,para=get_model_complexity_info(net,(patch_size,patch_size,channels),as_strings=False,print_per_layer_stat=True, verbose=True) 141 | # # # flops,para=get_model_complexity_info(net,(1,1,patch_size,patch_size,channels),as_strings=False,print_per_layer_stat=True, verbose=True) 142 | # print("para(M):{:.3f},\n flops(M):{:.3f}".format(para/(1000**2),flops/(1000**2),)) 143 | 144 | train_loss_list = [100] 145 | train_acc_list = [0] 146 | optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate) 147 | 148 | for epoch in range(max_epoch): 149 | train_acc_sum, trained_samples_counter = 0.0, 0 150 | batch_counter, train_loss_sum = 0, 0 151 | time_epoch = time.time() 152 | 153 | if model_type_flag == 1: # data for single spatial net 154 | for X_spa, y in train_iter: 155 | X_spa, y = X_spa.to(device), y.to(device) 156 | y_pred = net(X_spa) 157 | 158 | ls = loss(y_pred, y.long()) 159 | 160 | optimizer.zero_grad() 161 | ls.backward() 162 | optimizer.step() 163 | 164 | train_loss_sum += ls.cpu().item() 165 | train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item() 166 | trained_samples_counter += y.shape[0] 167 | batch_counter += 1 168 | epoch_first_iter = 0 169 | elif model_type_flag == 2: # data for single spectral net 170 | for X_spe, y in train_iter: 171 | X_spe, y = X_spe.to(device), y.to(device) 172 | y_pred = net(X_spe) 173 | 174 | ls = loss(y_pred, y.long()) 175 | 176 | optimizer.zero_grad() 177 | ls.backward() 178 | optimizer.step() 179 | 180 | train_loss_sum += ls.cpu().item() 181 | train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item() 182 | trained_samples_counter += y.shape[0] 183 | batch_counter += 1 184 | epoch_first_iter = 0 185 | elif model_type_flag == 3: # data for spectral-spatial net 186 | for X_spa, X_spe, y in train_iter: 187 | X_spa, X_spe, y = X_spa.to(device), X_spe.to(device), y.to(device) 188 | y_pred = net(X_spa, X_spe) 189 | 190 | ls = loss(y_pred, y.long()) 191 | 192 | optimizer.zero_grad() 193 | ls.backward() 194 | optimizer.step() 195 | 196 | train_loss_sum += ls.cpu().item() 197 | train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item() 198 | trained_samples_counter += y.shape[0] 199 | batch_counter += 1 200 | epoch_first_iter = 0 201 | 202 | torch.cuda.empty_cache() 203 | 204 | train_loss_list.append(train_loss_sum) 205 | train_acc_list.append(train_acc_sum / trained_samples_counter) 206 | 207 | print('epoch: %d, training_sampler_num: %d, batch_count: 
%.2f, train loss: %.6f, train loss sum: %.6f, ' 208 | 'train acc: %.3f, train_acc_sum: %.1f, time: %.1f sec' % 209 | (epoch + 1, trained_samples_counter, batch_counter, train_loss_sum / batch_counter, train_loss_sum, 210 | train_acc_sum / trained_samples_counter, train_acc_sum, time.time() - time_epoch)) 211 | 212 | toc1 = time.perf_counter() 213 | print('Training stage finished:\n epoch %d, loss %.4f, train acc %.3f, training time %.2f s' 214 | % (epoch + 1, train_loss_sum / batch_counter, train_acc_sum / trained_samples_counter, toc1 - tic1)) 215 | training_time = toc1 - tic1 216 | Train_Time_ALL.append(training_time) 217 | 218 | print("\n\n====================Starting evaluation for testing set.========================\n") 219 | 220 | pred_test = [] 221 | # torch.cuda.empty_cache() 222 | with torch.no_grad(): 223 | # net.load_state_dict(torch.load(model_save_path+"_best_model.pt")) 224 | net.eval() 225 | train_acc_sum, samples_num_counter = 0.0, 0 226 | if model_type_flag == 1: # data for single spatial net 227 | for X_spa, y in test_iter: 228 | X_spa = X_spa.to(device) 229 | y = y.to(device) 230 | 231 | tic2 = time.perf_counter() 232 | y_pred = net(X_spa) 233 | toc2 = time.perf_counter() 234 | 235 | pred_test.extend(np.array(y_pred.cpu().argmax(axis=1))) 236 | elif model_type_flag == 2: # data for single spectral net 237 | for X_spe, y in test_iter: 238 | X_spe = X_spe.to(device) 239 | y = y.to(device) 240 | 241 | tic2 = time.perf_counter() 242 | y_pred = net(X_spe) 243 | toc2 = time.perf_counter() 244 | 245 | pred_test.extend(np.array(y_pred.cpu().argmax(axis=1))) 246 | elif model_type_flag == 3: # data for spectral-spatial net 247 | for X_spa, X_spe, y in test_iter: 248 | X_spa = X_spa.to(device) 249 | X_spe = X_spe.to(device) 250 | y = y.to(device) 251 | 252 | tic2 = time.perf_counter() 253 | y_pred = net(X_spa, X_spe) 254 | toc2 = time.perf_counter() 255 | 256 | pred_test.extend(np.array(y_pred.cpu().argmax(axis=1))) 257 | 258 | y_gt = gt_test_re[test_data_index] - 1 259 | OA = metrics.accuracy_score(y_gt, pred_test) 260 | confusion_matrix = metrics.confusion_matrix(pred_test, y_gt) 261 | print("confusion_matrix\n{}".format(confusion_matrix)) 262 | ECA, AA = evaluation.AA_ECA(confusion_matrix) 263 | kappa = metrics.cohen_kappa_score(pred_test, y_gt) 264 | cls_report = evaluation.claification_report(y_gt, pred_test, data_set_name) 265 | print("classification_report\n{}".format(cls_report)) 266 | 267 | # Visualization for all the labeled samples and total the samples 268 | # sample_list1 = [total_iter] 269 | # sample_list2 = [all_iter, all_data_index] 270 | 271 | # Visualization.gt_cls_map(gt,cls_map_save_path) 272 | # cls_visual.pred_cls_map_dl(sample_list1,net,gt_train,cls_map_save_path,model_type_flag) 273 | # cls_visual.pred_cls_map_dl(sample_list2,net,gt,cls_map_save_path) 274 | 275 | testing_time = toc2 - tic2 276 | Test_Time_ALL.append(testing_time) 277 | 278 | # Output infors 279 | f = open(results_save_path + '_results.txt', 'a+') 280 | str_results = '\n======================' \ 281 | + " learning rate=" + str(learning_rate) \ 282 | + " epochs=" + str(max_epoch) \ 283 | + " ======================" \ 284 | + "\nOA=" + str(OA) \ 285 | + "\nAA=" + str(AA) \ 286 | + '\nkpp=' + str(kappa) \ 287 | + '\nacc per class:' + str(ECA) \ 288 | + "\ntrain time:" + str(training_time) \ 289 | + "\ntest time:" + str(testing_time) + "\n" 290 | 291 | f.write(str_results) 292 | f.write('{}'.format(confusion_matrix)) 293 | f.write('\n\n') 294 | f.write('{}'.format(cls_report)) 295 | f.close() 
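# keep this run's scores for the multi-run mean/std summary computed after the seed loop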
296 | 297 | OA_ALL.append(OA) 298 | AA_ALL.append(AA) 299 | KPP_ALL.append(kappa) 300 | EACH_ACC_ALL.append(ECA) 301 | 302 | torch.cuda.empty_cache() 303 | del net, train_iter, test_iter 304 | 305 | OA_ALL = np.array(OA_ALL) 306 | AA_ALL = np.array(AA_ALL) 307 | KPP_ALL = np.array(KPP_ALL) 308 | EACH_ACC_ALL = np.array(EACH_ACC_ALL) 309 | Train_Time_ALL = np.array(Train_Time_ALL) 310 | Test_Time_ALL = np.array(Test_Time_ALL) 311 | 312 | np.set_printoptions(precision=4) 313 | print("\n====================Mean result of {} times runs =========================".format(len(seed_list))) 314 | print('List of OA:', list(OA_ALL)) 315 | print('List of AA:', list(AA_ALL)) 316 | print('List of KPP:', list(KPP_ALL)) 317 | print('OA=', round(np.mean(OA_ALL) * 100, 2), '+-', round(np.std(OA_ALL) * 100, 2)) 318 | print('AA=', round(np.mean(AA_ALL) * 100, 2), '+-', round(np.std(AA_ALL) * 100, 2)) 319 | print('Kpp=', round(np.mean(KPP_ALL) * 100, 2), '+-', round(np.std(KPP_ALL) * 100, 2)) 320 | print('Acc per class=', np.mean(EACH_ACC_ALL, 0), '+-', np.std(EACH_ACC_ALL, 0)) 321 | 322 | print("Average training time=", round(np.mean(Train_Time_ALL), 2), '+-', round(np.std(Train_Time_ALL), 3)) 323 | print("Average testing time=", round(np.mean(Test_Time_ALL), 5), '+-', round(np.std(Test_Time_ALL), 5)) 324 | 325 | # Output infors 326 | f = open(results_save_path + '_results.txt', 'a+') 327 | str_results = '\n\n***************Mean result of ' + str(len(seed_list)) + ' times runs ********************' \ 328 | + '\nList of OA:' + str(list(OA_ALL)) \ 329 | + '\nList of AA:' + str(list(AA_ALL)) \ 330 | + '\nList of KPP:' + str(list(KPP_ALL)) \ 331 | + '\nOA=' + str(round(np.mean(OA_ALL) * 100, 2)) + '+-' + str(round(np.std(OA_ALL) * 100, 2)) \ 332 | + '\nAA=' + str(round(np.mean(AA_ALL) * 100, 2)) + '+-' + str(round(np.std(AA_ALL) * 100, 2)) \ 333 | + '\nKpp=' + str(round(np.mean(KPP_ALL) * 100, 2)) + '+-' + str(round(np.std(KPP_ALL) * 100, 2)) \ 334 | + '\nAcc per class=\n' + str(np.mean(EACH_ACC_ALL, 0)) + '+-' + str(np.std(EACH_ACC_ALL, 0)) \ 335 | + "\nAverage training time=" + str(round(np.mean(Train_Time_ALL), 2)) + '+-' + str( 336 | round(np.std(Train_Time_ALL), 3)) \ 337 | + "\nAverage testing time=" + str(round(np.mean(Test_Time_ALL), 5)) + '+-' + str( 338 | round(np.std(Test_Time_ALL), 5)) 339 | f.write(str_results) 340 | f.close() 341 | -------------------------------------------------------------------------------- /utils/data_load_operate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Auther : Mingsong Li (lms-07) 3 | # @Time : 2023-Apr 4 | # @Address : Time Lab @ SDU 5 | # @FileName : data_load_operate.py 6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS 7 | 8 | import os 9 | import math 10 | import torch 11 | import numpy as np 12 | import spectral as spy 13 | import scipy.io as sio 14 | import torch.utils.data as Data 15 | import matplotlib.pyplot as plt 16 | from sklearn.decomposition import PCA 17 | from sklearn import preprocessing 18 | 19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | def load_data(data_set_name, data_path): 23 | if data_set_name == 'IP': 24 | data = sio.loadmat(os.path.join(data_path, 'IP', 'Indian_pines_corrected.mat'))['indian_pines_corrected'] 25 | labels = sio.loadmat(os.path.join(data_path, 'IP', 'Indian_pines_gt.mat'))['indian_pines_gt'] 26 | elif data_set_name == 'UP': 27 | data = sio.loadmat(os.path.join(data_path, 'UP', 'PaviaU.mat'))['paviaU'] 28 | labels = 
sio.loadmat(os.path.join(data_path, 'UP', 'PaviaU_gt.mat'))['paviaU_gt'] 29 | 30 | return data, labels 31 | 32 | 33 | def load_HU_data(data_path): 34 | data = sio.loadmat(os.path.join(data_path, 'HU13_tif', "Houston13_data.mat"))['Houston13_data'] 35 | labels_train = sio.loadmat(os.path.join(data_path, 'HU13_tif', "Houston13_gt_train.mat"))['Houston13_gt_train'] 36 | labels_test = sio.loadmat(os.path.join(data_path, 'HU13_tif', "Houston13_gt_test.mat"))['Houston13_gt_test'] 37 | 38 | return data, labels_train, labels_test 39 | 40 | 41 | def standardization(data): 42 | height, width, bands = data.shape 43 | data = np.reshape(data, [height * width, bands]) 44 | # data=preprocessing.scale(data) # 45 | # data = preprocessing.MinMaxScaler().fit_transform(data) 46 | data = preprocessing.StandardScaler().fit_transform(data) # 47 | 48 | data = np.reshape(data, [height, width, bands]) 49 | return data 50 | 51 | 52 | def sampling(ratio_list, num_list, gt_reshape, class_count, Flag): 53 | all_label_index_dict, train_label_index_dict, test_label_index_dict = {}, {}, {} 54 | all_label_index_list, train_label_index_list, test_label_index_list = [], [], [] 55 | 56 | for cls in range(class_count): # [0-15] 57 | cls_index = np.where(gt_reshape == cls + 1)[0] 58 | all_label_index_dict[cls] = list(cls_index) 59 | 60 | np.random.shuffle(cls_index) 61 | 62 | if Flag == 0: # Fixed proportion for each category 63 | train_index_flag = max(int(ratio_list[0] * len(cls_index)), 3) # at least 3 samples per class 64 | # Split by num per class 65 | elif Flag == 1: # Fixed quantity per category 66 | if len(cls_index) > num_list[0]: 67 | train_index_flag = num_list[0] 68 | else: 69 | train_index_flag = 15 70 | 71 | train_label_index_dict[cls] = list(cls_index[:train_index_flag]) 72 | test_label_index_dict[cls] = list(cls_index[train_index_flag:]) 73 | 74 | train_label_index_list += train_label_index_dict[cls] 75 | test_label_index_list += test_label_index_dict[cls] 76 | all_label_index_list += all_label_index_dict[cls] 77 | 78 | return train_label_index_list, test_label_index_list, all_label_index_list 79 | 80 | 81 | def sampling_disjoint(gt_train_re, gt_test_re, class_count): 82 | all_label_index_dict, train_label_index_dict, test_label_index_dict = {}, {}, {} 83 | all_label_index_list, train_label_index_list, test_label_index_list = [], [], [] 84 | 85 | for cls in range(class_count): 86 | cls_index_train = np.where(gt_train_re == cls + 1)[0] 87 | cls_index_test = np.where(gt_test_re == cls + 1)[0] 88 | 89 | train_label_index_dict[cls] = list(cls_index_train) 90 | test_label_index_dict[cls] = list(cls_index_test) 91 | 92 | train_label_index_list += train_label_index_dict[cls] 93 | test_label_index_list += test_label_index_dict[cls] 94 | all_label_index_list += (train_label_index_dict[cls] + test_label_index_dict[cls]) 95 | 96 | return train_label_index_list, test_label_index_list, all_label_index_list 97 | 98 | 99 | def applyPCA(X, numComponents=75): 100 | newX = np.reshape(X, (-1, X.shape[2])) 101 | pca = PCA(n_components=numComponents, whiten=True) 102 | newX = pca.fit_transform(newX) 103 | newX = np.reshape(newX, (X.shape[0], X.shape[1], numComponents)) 104 | return newX 105 | 106 | 107 | def HSI_MNF(X, MNF_ratio): 108 | denoised_bands = math.ceil(MNF_ratio * X.shape[-1]) 109 | mnfr = spy.mnf(spy.calc_stats(X), spy.noise_from_diffs(X)) 110 | denoised_data = mnfr.reduce(X, num=denoised_bands) 111 | 112 | return denoised_data 113 | 114 | 115 | def data_pad_zero(data, patch_length): 116 | data_padded = 
np.lib.pad(data, ((patch_length, patch_length), (patch_length, patch_length), (0, 0)), 'constant', 117 | constant_values=0) 118 | return data_padded 119 | 120 | def img_show(x): 121 | spy.imshow(x) 122 | plt.show() 123 | 124 | 125 | def index_assignment(index, row, col, pad_length): 126 | new_assign = {} # dictionary. 127 | for counter, value in enumerate(index): 128 | assign_0 = value // col + pad_length 129 | assign_1 = value % col + pad_length 130 | new_assign[counter] = [assign_0, assign_1] 131 | return new_assign 132 | 133 | 134 | def select_patch(data_padded, pos_x, pos_y, patch_length): 135 | selected_patch = data_padded[pos_x - patch_length:pos_x + patch_length + 1, 136 | pos_y - patch_length:pos_y + patch_length + 1] 137 | return selected_patch 138 | 139 | 140 | def select_vector(data_padded, pos_x, pos_y): 141 | select_vector = data_padded[pos_x, pos_y] 142 | return select_vector 143 | 144 | 145 | def HSI_create_pathes(data_padded, hsi_h, hsi_w, data_indexes, patch_length, flag): 146 | h_p, w_p, c = data_padded.shape 147 | 148 | data_size = len(data_indexes) 149 | patch_size = patch_length * 2 + 1 150 | 151 | data_assign = index_assignment(data_indexes, hsi_h, hsi_w, patch_length) 152 | if flag == 1: 153 | # for spatial net data, HSI patch 154 | unit_data = np.zeros((data_size, patch_size, patch_size, c)) 155 | unit_data_torch = torch.from_numpy(unit_data).type(torch.FloatTensor).to(device) 156 | for i in range(len(data_assign)): 157 | unit_data_torch[i] = select_patch(data_padded, data_assign[i][0], data_assign[i][1], patch_length) 158 | 159 | if flag == 2: 160 | # for spectral net data, HSI vector 161 | unit_data = np.zeros((data_size, c)) 162 | unit_data_torch = torch.from_numpy(unit_data).type(torch.FloatTensor).to(device) 163 | for i in range(len(data_assign)): 164 | unit_data_torch[i] = select_vector(data_padded, data_assign[i][0], data_assign[i][1]) 165 | 166 | return unit_data_torch 167 | 168 | 169 | def generate_data_set(data_reshape, label, index): 170 | train_data_index, test_data_index, all_data_index = index 171 | x_train_set = data_reshape[train_data_index] 172 | y_train_set = label[train_data_index] - 1 173 | 174 | x_test_set = data_reshape[test_data_index] 175 | y_test_set = label[test_data_index] - 1 176 | 177 | x_all_set = data_reshape[all_data_index] 178 | y_all_set = label[all_data_index] - 1 179 | 180 | return x_train_set, y_train_set, x_test_set, y_test_set, x_all_set, y_all_set 181 | 182 | 183 | def generate_data_set_disjoint(data_reshape, label_train, label_test, index): 184 | train_data_index, test_data_index, all_data_index = index 185 | x_train_set = data_reshape[train_data_index] 186 | y_train_set = label_train[train_data_index] - 1 187 | 188 | x_test_set = data_reshape[test_data_index] 189 | y_test_set = label_test[test_data_index] - 1 190 | 191 | # x_all_set = data_reshape[all_data_index] 192 | # y_all_set = label[all_data_index] - 1 193 | 194 | return x_train_set, y_train_set, x_test_set, y_test_set 195 | 196 | 197 | # generating HSI patches using GPU directly. 
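# Sketch of the index-to-patch mapping used below (following index_assignment/select_patch):
# a flat index i over the original h*w grid maps to padded coordinates
# (i // w + patch_length, i % w + patch_length), and the extracted window is the
# (2*patch_length+1) x (2*patch_length+1) neighborhood around that position,
# e.g. a 9x9 patch for patch_length=4, matching patch_size in the processing scripts.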
198 | def generate_iter(data_padded, hsi_h, hsi_w, label_reshape, index, patch_length, batch_size, 199 | model_type_flag, 200 | model_3D_spa_flag, last_batch_flag): 201 | # flag for single spatial net or single spectral net or spectral-spatial net 202 | data_padded_torch = torch.from_numpy(data_padded).type(torch.FloatTensor).to(device) 203 | 204 | # for data label 205 | train_labels = label_reshape[index[0]] - 1 206 | test_labels = label_reshape[index[1]] - 1 207 | 208 | y_tensor_train = torch.from_numpy(train_labels).type(torch.FloatTensor) 209 | y_tensor_test = torch.from_numpy(test_labels).type(torch.FloatTensor) 210 | 211 | # for data 212 | if model_type_flag == 1: # data for single spatial net 213 | spa_train_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[0], patch_length, 1) 214 | spa_test_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[1], patch_length, 1) 215 | 216 | if model_3D_spa_flag == 1: # spatial 3D patch 217 | spa_train_samples = spa_train_samples.unsqueeze(1) 218 | spa_test_samples = spa_test_samples.unsqueeze(1) 219 | 220 | torch_dataset_train = Data.TensorDataset(spa_train_samples, y_tensor_train) 221 | torch_dataset_test = Data.TensorDataset(spa_test_samples, y_tensor_test) 222 | 223 | elif model_type_flag == 2: # data for single spectral net 224 | spe_train_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[0], patch_length, 2) 225 | spe_test_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[1], patch_length, 2) 226 | 227 | torch_dataset_train = Data.TensorDataset(spe_train_samples, y_tensor_train) 228 | torch_dataset_test = Data.TensorDataset(spe_test_samples, y_tensor_test) 229 | 230 | elif model_type_flag == 3: # data for spectral-spatial net 231 | # spatial data 232 | spa_train_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[0], patch_length, 1) 233 | spa_test_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[1], patch_length, 1) 234 | 235 | # spectral data 236 | spe_train_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[0], patch_length, 2) 237 | spe_test_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[1], patch_length, 2) 238 | 239 | torch_dataset_train = Data.TensorDataset(spa_train_samples, spe_train_samples, y_tensor_train) 240 | torch_dataset_test = Data.TensorDataset(spa_test_samples, spe_test_samples, y_tensor_test) 241 | 242 | if last_batch_flag == 0: 243 | train_iter = Data.DataLoader(dataset=torch_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0) 244 | test_iter = Data.DataLoader(dataset=torch_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0) 245 | elif last_batch_flag == 1: 246 | train_iter = Data.DataLoader(dataset=torch_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0, 247 | drop_last=True) 248 | test_iter = Data.DataLoader(dataset=torch_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0, 249 | drop_last=True) 250 | # train_iter = Data.DataLoader(dataset=torch_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0) 251 | # test_iter = Data.DataLoader(dataset=torch_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0) 252 | 253 | return train_iter, test_iter 254 | 255 | 256 | def generate_iter_disjoint(data_padded, hsi_h, hsi_w, gt_train_re, gt_test_re, index, patch_length, batch_size, 257 | model_type_flag, model_3D_spa_flag): 258 | data_padded_torch = torch.from_numpy(data_padded).type(torch.FloatTensor).to(device) 259
260 |     train_labels = gt_train_re[index[0]] - 1
261 |     test_labels = gt_test_re[index[1]] - 1
262 | 
263 |     y_tensor_train = torch.from_numpy(train_labels).type(torch.FloatTensor)
264 |     y_tensor_test = torch.from_numpy(test_labels).type(torch.FloatTensor)
265 | 
266 |     # for data
267 |     if model_type_flag == 1:  # data for single spatial net
268 |         spa_train_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[0], patch_length, 1)
269 |         spa_test_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[1], patch_length, 1)
270 | 
271 |         if model_3D_spa_flag == 1:  # spatial 3D patch
272 |             spa_train_samples = spa_train_samples.unsqueeze(1)
273 |             spa_test_samples = spa_test_samples.unsqueeze(1)
274 | 
275 |         torch_dataset_train = Data.TensorDataset(spa_train_samples, y_tensor_train)
276 |         torch_dataset_test = Data.TensorDataset(spa_test_samples, y_tensor_test)
277 | 
278 |     elif model_type_flag == 2:  # data for single spectral net
279 |         spe_train_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[0], patch_length, 2)
280 |         spe_test_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[1], patch_length, 2)
281 | 
282 |         torch_dataset_train = Data.TensorDataset(spe_train_samples, y_tensor_train)
283 |         torch_dataset_test = Data.TensorDataset(spe_test_samples, y_tensor_test)
284 | 
285 |     elif model_type_flag == 3:  # data for spectral-spatial net
286 |         # spatial data
287 |         spa_train_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[0], patch_length, 1)
288 |         spa_test_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[1], patch_length, 1)
289 | 
290 |         # spectral data
291 |         spe_train_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[0], patch_length, 2)
292 |         spe_test_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index[1], patch_length, 2)
293 | 
294 |         torch_dataset_train = Data.TensorDataset(spa_train_samples, spe_train_samples, y_tensor_train)
295 |         torch_dataset_test = Data.TensorDataset(spa_test_samples, spe_test_samples, y_tensor_test)
296 | 
297 |     train_iter = Data.DataLoader(dataset=torch_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
298 |     test_iter = Data.DataLoader(dataset=torch_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)
299 | 
300 |     return train_iter, test_iter
301 | 
302 | 
303 | # all) generates HSI patches for the visualization of all the labeled samples of the data set
304 | # total) generates HSI patches for the visualization of all the samples (labeled and unlabeled) of the data set
305 | # in addition, all) and total) both use the GPU directly
306 | def generate_iter_total(data_padded, hsi_h, hsi_w, label_reshape, index, patch_length, batch_size, model_type_flag,
307 |                         model_3D_spa_flag):
308 |     data_padded_torch = torch.from_numpy(data_padded).type(torch.FloatTensor).to(device)
309 | 
310 |     if len(index) < label_reshape.shape[0]:
311 |         total_labels = label_reshape[index] - 1
312 |     else:
313 |         total_labels = np.zeros(label_reshape.shape)
314 | 
315 |     y_tensor_total = torch.from_numpy(total_labels).type(torch.FloatTensor)
316 | 
317 |     if model_type_flag == 1:
318 |         total_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index, patch_length, 1)
319 |         if model_3D_spa_flag == 1:  # spatial 3D patch
320 |             total_samples = total_samples.unsqueeze(1)
321 |         torch_dataset_total = Data.TensorDataset(total_samples, y_tensor_total)
322 | 
323 |     elif model_type_flag == 2:
324 |         total_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index, patch_length, 2)
325 |         torch_dataset_total = Data.TensorDataset(total_samples, y_tensor_total)
326 |     elif model_type_flag == 3:
327 |         spa_total_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index, patch_length, 1)
328 |         spe_total_samples = HSI_create_pathes(data_padded_torch, hsi_h, hsi_w, index, patch_length, 2)
329 |         torch_dataset_total = Data.TensorDataset(spa_total_samples, spe_total_samples, y_tensor_total)
330 | 
331 |     total_iter = Data.DataLoader(dataset=torch_dataset_total, batch_size=batch_size, shuffle=False, num_workers=0)
332 | 
333 |     return total_iter
334 | 
--------------------------------------------------------------------------------
/process_dl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Mingsong Li (lms-07)
3 | # @Time : 2023-Apr
4 | # @Address : Time Lab @ SDU
5 | # @FileName : process_dl.py
6 | # @Project : AMS-M2ESL (HSIC), IEEE TGRS
7 | 
8 | # for the IP and UP data sets, the main processing file for the proposed AMS-M2ESL model
9 | 
10 | import os
11 | import time
12 | import torch
13 | import random
14 | import numpy as np
15 | from sklearn import metrics
16 | from ptflops import get_model_complexity_info
17 | 
18 | import utils.evaluation as evaluation
19 | import utils.data_load_operate as data_load_operate
20 | import visual.cls_visual as cls_visual
21 | import model.AMS_M2ESL as AMS_M2ESL
22 | 
23 | # import utils.data_load_operate_AIPS as data_load_operate
24 | 
25 | 
26 | time_current = time.strftime("%y-%m-%d-%H.%M", time.localtime())
27 | 
28 | # random seed setting
29 | seed = 20
30 | 
31 | torch.manual_seed(seed)
32 | torch.cuda.manual_seed(seed)
33 | torch.cuda.manual_seed_all(seed)
34 | np.random.seed(seed)  # Numpy module.
35 | random.seed(seed)  # Python random module.
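# Note: the two cudnn flags below trade speed for reproducibility; benchmark=False stops
# cuDNN from auto-tuning (and thereby varying) its convolution kernels, while
# deterministic=True forces deterministic ones. On newer PyTorch versions (>= 1.8, an
# assumption), torch.use_deterministic_algorithms(True) is a stricter, framework-wide
# alternative.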
36 | # torch.manual_seed(seed)  # duplicate of the call above
37 | torch.backends.cudnn.benchmark = False
38 | torch.backends.cudnn.deterministic = True
39 | 
40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41 | 
42 | ### 0 ###
43 | model_list = ['AMS-M2ESL']
44 | model_flag = 0
45 | model_spa_set = {0}
46 | model_spe_set = set()
47 | model_spa_spe_set = set()
48 | model_3D_spa_set = set()
49 | model_3D_spa_flag = 0
50 | 
51 | last_batch_flag = 0
52 | 
53 | if model_flag in model_spa_set:
54 |     model_type_flag = 1
55 |     if model_flag in model_3D_spa_set:
56 |         model_3D_spa_flag = 1
57 | elif model_flag in model_spe_set:
58 |     model_type_flag = 2
59 | elif model_flag in model_spa_spe_set:
60 |     model_type_flag = 3
61 | 
62 | # 0-1
63 | data_set_name_list = ['IP', 'UP']
64 | data_set_name = data_set_name_list[1]
65 | 
66 | data_set_path = os.path.join(os.getcwd(), 'data')
67 | 
68 | # control running times
69 | # seed_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
70 | # seed_list = [0, 1, 2, 3, 4]
71 | # seed_list = [0, 1, 2]
72 | # seed_list = [0, 1]
73 | seed_list = [0]
74 | 
75 | # data set split
76 | flag_list = [0, 1]  # ratio or num
77 | 
78 | if data_set_name == 'IP':
79 |     ratio_list = [0.05, 0.005]
80 |     ratio = 5.0
81 | elif data_set_name == 'UP':
82 |     ratio_list = [0.01, 0.001]
83 |     ratio = 1.0
84 | 
85 | num_list = [50, 0]  # [train_num, val_num]
86 | 
87 | patch_size = 9
88 | patch_length = 4
89 | 
90 | results_save_path = \
91 |     os.path.join(os.getcwd(), 'output/results', model_list[model_flag] + str("_") +
92 |                  data_set_name + str("_") + str(time_current) + str("_seed_") + str(seed) + str("_ratio_") + str(
93 |         ratio) + str("_patch_size_") + str(patch_size))
94 | cls_map_save_path = \
95 |     os.path.join(os.path.join(os.getcwd(), 'output/cls_maps'), model_list[model_flag] + str("_") +
96 |                  data_set_name + str("_") + str(time_current) + str("_seed_") + str(seed) + str("_ratio_") + str(ratio))
97 | 
98 | if __name__ == '__main__':
99 | 
100 |     data, gt = data_load_operate.load_data(data_set_name, data_set_path)
101 |     data = data_load_operate.standardization(data)
102 | 
103 |     # dr = math.ceil(data.shape[-1] * 0.3)
104 |     # data = data_load_operate.applyPCA(data, dr)  # for the ablation of MNF
105 |     data = data_load_operate.HSI_MNF(data, MNF_ratio=0.3)
106 | 
107 |     gt_reshape = gt.reshape(-1)
108 |     height, width, channels = data.shape
109 |     class_count = max(np.unique(gt))
110 | 
111 |     batch_size = 256
112 |     max_epoch = 100
113 |     learning_rate = 0.001
114 |     loss = torch.nn.CrossEntropyLoss()
115 | 
116 |     OA_ALL = []
117 |     AA_ALL = []
118 |     KPP_ALL = []
119 |     EACH_ACC_ALL = []
120 |     Train_Time_ALL = []
121 |     Test_Time_ALL = []
122 |     CLASS_ACC = np.zeros([len(seed_list), class_count])
123 | 
124 |     # data pad zero
125 |     # data: [h, w, c] -> data_padded: [h+2l, w+2l, c]
126 |     data_padded = data_load_operate.data_pad_zero(data, patch_length)
127 |     height_patched, width_patched = data_padded.shape[0], data_padded.shape[1]
128 | 
129 |     # data_total_index = np.arange(data.shape[0] * data.shape[1])  # For total sample cls_map.
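    # Note: each run below re-draws a class-wise train/test split, rebuilds the patch
    # iterators, trains from scratch, and logs OA/AA/kappa. With flag_list[0] == 0 the
    # split is ratio-based; a minimal sketch of the assumed behaviour of
    # data_load_operate.sampling() (hypothetical code, for illustration only):
    #
    #   rng = np.random.default_rng(curr_seed)
    #   train_idx = []
    #   for c in range(1, class_count + 1):
    #       pos = np.where(gt_reshape == c)[0]              # pixels of class c
    #       n_train = max(1, int(len(pos) * ratio_list[0]))
    #       train_idx.extend(rng.permutation(pos)[:n_train])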
130 | 
131 |     for curr_seed in seed_list:
132 |         tic1 = time.perf_counter()
133 |         train_data_index, test_data_index, all_data_index = data_load_operate.sampling(ratio_list,
134 |                                                                                        num_list,
135 |                                                                                        gt_reshape,
136 |                                                                                        class_count,
137 |                                                                                        flag_list[0])
138 |         index = (train_data_index, test_data_index)
139 | 
140 |         train_iter, test_iter = data_load_operate.generate_iter(data_padded, height, width,
141 |                                                                 gt_reshape, index, patch_length,
142 |                                                                 batch_size,
143 |                                                                 model_type_flag,
144 |                                                                 model_3D_spa_flag,
145 |                                                                 last_batch_flag)
146 | 
147 |         # load data for the cls map of the total samples
148 |         # total_iter = data_load_operate.generate_iter_total(data_padded, height, width, gt_reshape, data_total_index, patch_length,
149 |         #                                                    batch_size, model_type_flag, model_3D_spa_flag)
150 | 
151 |         if model_flag == 0:
152 |             net = AMS_M2ESL.AMS_M2ESL_(in_channels=channels, patch_size=patch_size, num_classes=class_count,
153 |                                        ds=data_set_name)
154 | 
155 |         net.to(device)
156 | 
157 |         # efficiency test, model complexity and computational cost
158 |         # flops, para = get_model_complexity_info(net, (channels, 1, 1), as_strings=False, print_per_layer_stat=True, verbose=True)
159 |         # flops, para = get_model_complexity_info(net, (patch_size, patch_size, channels), as_strings=False, print_per_layer_stat=True, verbose=True)
160 |         # flops, para = get_model_complexity_info(net, (1, 1, patch_size, patch_size, channels), as_strings=False, print_per_layer_stat=True, verbose=True)
161 |         # print("para(M):{:.3f},\n flops(M):{:.3f}".format(para / (1000 ** 2), flops / (1000 ** 2)))
162 | 
163 |         train_loss_list = [100]
164 |         train_acc_list = [0]
165 | 
166 |         optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
167 | 
168 |         for epoch in range(max_epoch):
169 |             train_acc_sum, trained_samples_counter = 0.0, 0
170 |             batch_counter, train_loss_sum = 0, 0
171 |             time_epoch = time.time()
172 | 
173 |             if model_type_flag == 1:  # data for single spatial net
174 |                 for X_spa, y in train_iter:
175 |                     X_spa, y = X_spa.to(device), y.to(device)
176 |                     y_pred = net(X_spa)
177 | 
178 |                     ls = loss(y_pred, y.long())
179 | 
180 |                     optimizer.zero_grad()
181 |                     ls.backward()
182 |                     optimizer.step()
183 | 
184 |                     train_loss_sum += ls.cpu().item()
185 |                     train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item()
186 |                     trained_samples_counter += y.shape[0]
187 |                     batch_counter += 1
188 |                     epoch_first_iter = 0
189 |             elif model_type_flag == 2:  # data for single spectral net
190 |                 for X_spe, y in train_iter:
191 |                     X_spe, y = X_spe.to(device), y.to(device)
192 |                     y_pred = net(X_spe)
193 | 
194 |                     ls = loss(y_pred, y.long())
195 | 
196 |                     optimizer.zero_grad()
197 |                     ls.backward()
198 |                     optimizer.step()
199 | 
200 |                     train_loss_sum += ls.cpu().item()
201 |                     train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item()
202 |                     trained_samples_counter += y.shape[0]
203 |                     batch_counter += 1
204 |                     epoch_first_iter = 0
205 |             elif model_type_flag == 3:  # data for spectral-spatial net
206 |                 for X_spa, X_spe, y in train_iter:
207 |                     X_spa, X_spe, y = X_spa.to(device), X_spe.to(device), y.to(device)
208 |                     y_pred = net(X_spa, X_spe)
209 | 
210 |                     ls = loss(y_pred, y.long())
211 | 
212 |                     optimizer.zero_grad()
213 |                     ls.backward()
214 |                     optimizer.step()
215 | 
216 |                     train_loss_sum += ls.cpu().item()
217 |                     train_acc_sum += (y_pred.argmax(dim=1) == y).sum().cpu().item()
218 |                     trained_samples_counter += y.shape[0]
219 |                     batch_counter += 1
220 |                     epoch_first_iter = 0
221 | 
222 |             torch.cuda.empty_cache()
223 | 
224 |             train_loss_list.append(train_loss_sum)
225 |             train_acc_list.append(train_acc_sum / trained_samples_counter)
226 | 
227 |             print('epoch: %d, training_samples_num: %d, batch_count: %d, train loss: %.6f, train loss sum: %.6f, '
228 |                   'train acc: %.3f, train_acc_sum: %.1f, time: %.1f sec' %
229 |                   (epoch + 1, trained_samples_counter, batch_counter, train_loss_sum / batch_counter, train_loss_sum,
230 |                    train_acc_sum / trained_samples_counter, train_acc_sum, time.time() - time_epoch))
231 | 
232 |         toc1 = time.perf_counter()
233 |         print('Training stage finished:\n epoch %d, loss %.4f, train acc %.3f, training time %.2f s'
234 |               % (epoch + 1, train_loss_sum / batch_counter, train_acc_sum / trained_samples_counter, toc1 - tic1))
235 |         training_time = toc1 - tic1
236 |         Train_Time_ALL.append(training_time)
237 | 
238 |         print("\n\n====================Starting evaluation on the testing set.========================\n")
239 | 
240 |         pred_test = []
241 |         # torch.cuda.empty_cache()
242 |         with torch.no_grad():
243 |             # net.load_state_dict(torch.load(model_save_path + "_best_model.pt"))
244 |             net.eval()
245 |             train_acc_sum, samples_num_counter = 0.0, 0
246 |             if model_type_flag == 1:  # data for single spatial net
247 |                 for X_spa, y in test_iter:
248 |                     X_spa = X_spa.to(device)
249 | 
250 |                     tic2 = time.perf_counter()
251 |                     y_pred = net(X_spa)
252 |                     toc2 = time.perf_counter()
253 | 
254 |                     pred_test.extend(np.array(y_pred.cpu().argmax(axis=1)))
255 |             elif model_type_flag == 2:  # data for single spectral net
256 |                 for X_spe, y in test_iter:
257 |                     X_spe = X_spe.to(device)
258 | 
259 |                     tic2 = time.perf_counter()
260 |                     y_pred = net(X_spe)
261 |                     toc2 = time.perf_counter()
262 | 
263 |                     pred_test.extend(np.array(y_pred.cpu().argmax(axis=1)))
264 |             elif model_type_flag == 3:  # data for spectral-spatial net
265 |                 for X_spa, X_spe, y in test_iter:
266 |                     X_spa = X_spa.to(device)
267 |                     X_spe = X_spe.to(device)
268 | 
269 |                     tic2 = time.perf_counter()
270 |                     y_pred = net(X_spa, X_spe)
271 |                     toc2 = time.perf_counter()
272 | 
273 |                     pred_test.extend(np.array(y_pred.cpu().argmax(axis=1)))
274 | 
275 |             y_gt = gt_reshape[test_data_index] - 1
276 |             OA = metrics.accuracy_score(y_gt, pred_test)
277 |             confusion_matrix = metrics.confusion_matrix(pred_test, y_gt)
278 |             print("confusion_matrix\n{}".format(confusion_matrix))
279 |             ECA, AA = evaluation.AA_ECA(confusion_matrix)
280 |             kappa = metrics.cohen_kappa_score(pred_test, y_gt)
281 |             cls_report = evaluation.claification_report(y_gt, pred_test, data_set_name)
282 |             print("classification_report\n{}".format(cls_report))
283 | 
284 |         # Visualization for all the labeled samples and the total samples
285 |         # sample_list1 = [total_iter]
286 |         # sample_list2 = [all_iter, all_data_index]
287 | 
288 |         # Visualization.gt_cls_map(gt, cls_map_save_path)
289 |         # cls_visual.pred_cls_map_dl(sample_list1, net, gt, cls_map_save_path, model_type_flag)
290 |         # cls_visual.pred_cls_map_dl(sample_list2, net, gt, cls_map_save_path)
291 | 
292 |         testing_time = toc2 - tic2
293 |         Test_Time_ALL.append(testing_time)
294 | 
295 |         # Output info
296 |         f = open(results_save_path + '_results.txt', 'a+')
297 |         str_results = '\n======================' \
298 |                       + " learning rate=" + str(learning_rate) \
299 |                       + " epochs=" + str(max_epoch) \
300 |                       + " train ratio=" + str(ratio_list[0]) \
301 |                       + " val ratio=" + str(ratio_list[1]) \
302 |                       + " ======================" \
303 |                       + "\nOA=" + str(OA) \
304 |                       + "\nAA=" + str(AA) \
305 |                       + '\nkpp=' + str(kappa) \
306 |                       + '\nacc per class:' + str(ECA) \
307 |                       + "\ntrain time:" + str(training_time) \
308 |                       + "\ntest time:" + str(testing_time) + "\n"
309 | 
310 |         f.write(str_results)
311 |         f.write('{}'.format(confusion_matrix))
312 |         f.write('\n\n')
313 |         f.write('{}'.format(cls_report))
314 |         f.close()
315 | 
316 |         OA_ALL.append(OA)
317 |         AA_ALL.append(AA)
318 |         KPP_ALL.append(kappa)
319 |         EACH_ACC_ALL.append(ECA)
320 | 
321 |         torch.cuda.empty_cache()
322 |         del net, train_iter, test_iter
323 | 
324 |     OA_ALL = np.array(OA_ALL)
325 |     AA_ALL = np.array(AA_ALL)
326 |     KPP_ALL = np.array(KPP_ALL)
327 |     EACH_ACC_ALL = np.array(EACH_ACC_ALL)
328 |     Train_Time_ALL = np.array(Train_Time_ALL)
329 |     Test_Time_ALL = np.array(Test_Time_ALL)
330 | 
331 |     np.set_printoptions(precision=4)
332 |     print("\n====================Mean result of {} runs =========================".format(len(seed_list)))
333 |     print('List of OA:', list(OA_ALL))
334 |     print('List of AA:', list(AA_ALL))
335 |     print('List of KPP:', list(KPP_ALL))
336 |     print('OA=', round(np.mean(OA_ALL) * 100, 2), '+-', round(np.std(OA_ALL) * 100, 2))
337 |     print('AA=', round(np.mean(AA_ALL) * 100, 2), '+-', round(np.std(AA_ALL) * 100, 2))
338 |     print('Kpp=', round(np.mean(KPP_ALL) * 100, 2), '+-', round(np.std(KPP_ALL) * 100, 2))
339 |     print('Acc per class=', np.mean(EACH_ACC_ALL, 0), '+-', np.std(EACH_ACC_ALL, 0))
340 | 
341 |     print("Average training time=", round(np.mean(Train_Time_ALL), 2), '+-', round(np.std(Train_Time_ALL), 3))
342 |     print("Average testing time=", round(np.mean(Test_Time_ALL), 5), '+-', round(np.std(Test_Time_ALL), 5))
343 | 
344 |     # Output info
345 |     f = open(results_save_path + '_results.txt', 'a+')
346 |     str_results = '\n\n***************Mean result of ' + str(len(seed_list)) + ' runs ********************' \
347 |                   + '\nList of OA:' + str(list(OA_ALL)) \
348 |                   + '\nList of AA:' + str(list(AA_ALL)) \
349 |                   + '\nList of KPP:' + str(list(KPP_ALL)) \
350 |                   + '\nOA=' + str(round(np.mean(OA_ALL) * 100, 2)) + '+-' + str(round(np.std(OA_ALL) * 100, 2)) \
351 |                   + '\nAA=' + str(round(np.mean(AA_ALL) * 100, 2)) + '+-' + str(round(np.std(AA_ALL) * 100, 2)) \
352 |                   + '\nKpp=' + str(round(np.mean(KPP_ALL) * 100, 2)) + '+-' + str(round(np.std(KPP_ALL) * 100, 2)) \
353 |                   + '\nAcc per class=\n' + str(np.mean(EACH_ACC_ALL, 0)) + '+-' + str(np.std(EACH_ACC_ALL, 0)) \
354 |                   + "\nAverage training time=" + str(round(np.mean(Train_Time_ALL), 2)) + '+-' + str(
355 |         round(np.std(Train_Time_ALL), 3)) \
356 |                   + "\nAverage testing time=" + str(round(np.mean(Test_Time_ALL), 5)) + '+-' + str(
357 |         round(np.std(Test_Time_ALL), 5))
358 |     f.write(str_results)
359 |     f.close()
360 | 
--------------------------------------------------------------------------------
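To close, a minimal, self-contained sketch of the padded-cube -> patch -> DataLoader pipeline that generate_iter implements above, run on synthetic data. The helper extract_patches and all shapes here are illustrative assumptions standing in for HSI_create_pathes (defined elsewhere in utils/data_load_operate.py); only the zero-padding width and the 1-based label shift mirror the code above.

import numpy as np
import torch
import torch.utils.data as Data

h, w, c, patch_length = 20, 20, 30, 4                 # tiny synthetic cube
data = np.random.rand(h, w, c).astype(np.float32)
gt = np.random.randint(0, 5, size=(h, w))             # 0 = unlabeled, 1..4 = classes

# zero-pad patch_length pixels on each spatial side, as data_pad_zero is assumed to do
data_padded = np.pad(data, ((patch_length, patch_length),
                            (patch_length, patch_length), (0, 0)))

gt_reshape = gt.reshape(-1)
labeled_idx = np.where(gt_reshape > 0)[0]             # flat indices of labeled pixels

def extract_patches(cube, flat_idx, width, pl):
    # cut a (2*pl+1) x (2*pl+1) window around each pixel; in padded coordinates the
    # window of the pixel at original (r, col) spans rows r..r+2*pl, cols col..col+2*pl
    patches = [cube[r:r + 2 * pl + 1, col:col + 2 * pl + 1, :]
               for r, col in (divmod(int(i), width) for i in flat_idx)]
    return torch.from_numpy(np.stack(patches))

x = extract_patches(data_padded, labeled_idx, w, patch_length)   # (N, 9, 9, 30)
y = torch.from_numpy(gt_reshape[labeled_idx] - 1).long()         # classes 1..4 -> 0..3

loader = Data.DataLoader(Data.TensorDataset(x, y), batch_size=16, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)   # e.g. torch.Size([16, 9, 9, 30]) torch.Size([16])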