├── CrossViT_module.py ├── README.md ├── Utils.py ├── classification_maps ├── IP.jpg ├── KSC.jpg └── PU.jpg ├── convert_report_to_csv.ipynb ├── demo.ipynb ├── geniter.py ├── record.py └── vit_pytorch.py /CrossViT_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | class Residual(nn.Module): 9 | def __init__(self, fn): 10 | super().__init__() 11 | self.fn = fn 12 | def forward(self, x, **kwargs): 13 | return self.fn(x, **kwargs) + x 14 | 15 | class PreNorm(nn.Module): 16 | def __init__(self, dim, fn): 17 | super().__init__() 18 | self.norm = nn.LayerNorm(dim) 19 | self.fn = fn 20 | def forward(self, x, **kwargs): 21 | return self.fn(self.norm(x), **kwargs) 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.Linear(dim, hidden_dim), 28 | nn.GELU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_dim, dim), 31 | nn.Dropout(dropout) 32 | ) 33 | def forward(self, x): 34 | return self.net(x) 35 | 36 | class Attention(nn.Module): 37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 38 | super().__init__() 39 | inner_dim = dim_head * heads 40 | project_out = not (heads == 1 and dim_head == dim) 41 | 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | 45 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 46 | 47 | self.to_out = nn.Sequential( 48 | nn.Linear(inner_dim, dim), 49 | nn.Dropout(dropout) 50 | ) if project_out else nn.Identity() 51 | 52 | def forward(self, x): 53 | b, n, _, h = *x.shape, self.heads 54 | qkv = self.to_qkv(x).chunk(3, dim = -1) 55 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 56 | 57 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 58 | 59 | attn = dots.softmax(dim=-1) 60 | 61 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 62 | out = rearrange(out, 'b h n d -> b n (h d)') 63 | out = self.to_out(out) 64 | return out 65 | 66 | class CrossAttention(nn.Module): 67 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 68 | super().__init__() 69 | inner_dim = dim_head * heads 70 | project_out = not (heads == 1 and dim_head == dim) 71 | 72 | self.heads = heads 73 | self.scale = dim_head ** -0.5 74 | 75 | self.to_k = nn.Linear(dim, inner_dim , bias=False) 76 | self.to_v = nn.Linear(dim, inner_dim , bias = False) 77 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 78 | 79 | self.to_out = nn.Sequential( 80 | nn.Linear(inner_dim, dim), 81 | nn.Dropout(dropout) 82 | ) if project_out else nn.Identity() 83 | 84 | def forward(self, x_qkv): 85 | b, n, _, h = *x_qkv.shape, self.heads 86 | 87 | k = self.to_k(x_qkv) 88 | k = rearrange(k, 'b n (h d) -> b h n d', h = h) 89 | 90 | v = self.to_v(x_qkv) 91 | v = rearrange(v, 'b n (h d) -> b h n d', h = h) 92 | 93 | q = self.to_q(x_qkv[:, 0].unsqueeze(1)) 94 | q = rearrange(q, 'b n (h d) -> b h n d', h = h) 95 | 96 | 97 | 98 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 99 | 100 | attn = dots.softmax(dim=-1) 101 | 102 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 103 | out = rearrange(out, 'b h n d -> b n (h d)') 104 | out = self.to_out(out) 105 | return out 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /README.md: 
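A minimal usage sketch for the Attention and CrossAttention modules defined in CrossViT_module.py above (the batch size, token count, and hyperparameters here are illustrative assumptions, not values taken from the paper): CrossAttention differs from Attention in that only the class token at index 0 forms the query.

import torch
from CrossViT_module import Attention, CrossAttention

tokens = torch.randn(2, 17, 64)                    # (batch, 1 cls token + 16 patch tokens, dim)
self_attn = Attention(dim=64, heads=8, dim_head=8)
print(self_attn(tokens).shape)                     # torch.Size([2, 17, 64]): every token attends to every token
cross_attn = CrossAttention(dim=64, heads=8, dim_head=8)
print(cross_attn(tokens).shape)                    # torch.Size([2, 1, 64]): only the cls token queries the sequence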
-------------------------------------------------------------------------------- 1 | # CS2DT 2 | ## Paper 3 | This is the code for the hyperspectral image (HSI) classification network CS2DT: Cross Spatial–Spectral Dense Transformer for Hyperspectral Image Classification (https://ieeexplore.ieee.org/abstract/document/10268928). 4 | 5 | The datasets required by the code are available at http://www.ehu.eus/ccwintco/index.php/Hyperspectral_Remote_Sensing_Scenes or http://dase.grss-ieee.org/. 6 | 7 | ## Cite 8 | If you use CS2DT in your work, please cite our paper: 9 | 10 | Xu H, Zeng Z, Yao W, et al. CS2DT: Cross Spatial–Spectral Dense Transformer for Hyperspectral Image Classification[J]. IEEE Geoscience and Remote Sensing Letters, 2023. 11 | -------------------------------------------------------------------------------- /Utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import metrics, preprocessing 3 | from sklearn.preprocessing import MinMaxScaler 4 | from sklearn.decomposition import PCA 5 | from sklearn.model_selection import train_test_split 6 | from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score 7 | from operator import truediv 8 | import matplotlib.pyplot as plt 9 | import scipy.io as sio 10 | import os 11 | import spectral 12 | import torch 13 | import cv2 14 | from IPython import display 15 | 16 | 17 | def sampling(proportion, ground_truth): 18 | train = {} 19 | test = {} 20 | labels_loc = {} 21 | m = max(ground_truth) # number of classes (labels are 1..m; 0 is background) 22 | for i in range(m): 23 | indexes = [ 24 | j for j, x in enumerate(ground_truth.ravel().tolist()) 25 | if x == i + 1 26 | ] 27 | np.random.shuffle(indexes) 28 | labels_loc[i] = indexes 29 | if proportion != 1: 30 | nb_val = max(int((1 - proportion) * len(indexes)), 3) # training samples for this class (at least 3) 31 | else: 32 | nb_val = 0 33 | train[i] = indexes[:nb_val] 34 | test[i] = indexes[nb_val:] 35 | train_indexes = [] 36 | test_indexes = [] 37 | for i in range(m): 38 | train_indexes += train[i] 39 | test_indexes += test[i] 40 | np.random.shuffle(train_indexes) 41 | np.random.shuffle(test_indexes) 42 | return train_indexes, test_indexes 43 | 44 | 45 | def set_figsize(figsize=(3.5, 2.5)): 46 | display.set_matplotlib_formats('svg') 47 | plt.rcParams['figure.figsize'] = figsize 48 | 49 | 50 | def classification_map(map, ground_truth, dpi, save_path): 51 | fig = plt.figure(frameon=False) 52 | fig.set_size_inches(ground_truth.shape[1] * 2.0 / dpi, 53 | ground_truth.shape[0] * 2.0 / dpi) 54 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 55 | ax.set_axis_off() 56 | ax.xaxis.set_visible(False) 57 | ax.yaxis.set_visible(False) 58 | fig.add_axes(ax) 59 | ax.imshow(map) 60 | fig.savefig(save_path, dpi=dpi) 61 | return 0 62 | 63 | 64 | def list_to_colormap(x_list): 65 | y = np.zeros((x_list.shape[0], 3)) 66 | for index, item in enumerate(x_list): 67 | if item == 0: 68 | y[index] = np.array([255, 0, 0]) / 255. 69 | if item == 1: 70 | y[index] = np.array([0, 255, 0]) / 255. 71 | if item == 2: 72 | y[index] = np.array([0, 0, 255]) / 255. 73 | if item == 3: 74 | y[index] = np.array([255, 255, 0]) / 255. 75 | if item == 4: 76 | y[index] = np.array([0, 255, 255]) / 255. 77 | if item == 5: 78 | y[index] = np.array([255, 0, 255]) / 255. 79 | if item == 6: 80 | y[index] = np.array([192, 192, 192]) / 255. 81 | if item == 7: 82 | y[index] = np.array([128, 128, 128]) / 255. 83 | if item == 8: 84 | y[index] = np.array([128, 0, 0]) / 255. 85 | if item == 9: 86 | y[index] = np.array([128, 128, 0]) / 255. 
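# (the colour table continues for classes 10-18 below; index 16 and the -1 sentinel both map to black, which generate_png uses for unlabeled background pixels)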
87 | if item == 10: 88 | y[index] = np.array([0, 128, 0]) / 255. 89 | if item == 11: 90 | y[index] = np.array([128, 0, 128]) / 255. 91 | if item == 12: 92 | y[index] = np.array([0, 128, 128]) / 255. 93 | if item == 13: 94 | y[index] = np.array([0, 0, 128]) / 255. 95 | if item == 14: 96 | y[index] = np.array([255, 165, 0]) / 255. 97 | if item == 15: 98 | y[index] = np.array([255, 215, 0]) / 255. 99 | if item == 16: 100 | y[index] = np.array([0, 0, 0]) / 255. 101 | if item == 17: 102 | y[index] = np.array([215, 255, 0]) / 255. 103 | if item == 18: 104 | y[index] = np.array([0, 255, 215]) / 255. 105 | if item == -1: 106 | y[index] = np.array([0, 0, 0]) / 255. 107 | return y 108 | 109 | 110 | def generate_png(all_iter, net, gt_hsi, Dataset, device, total_indices, path): 111 | pred_test = [] 112 | for X, y in all_iter: 113 | #X = X.permute(0, 3, 1, 2) 114 | X = X.to(device) 115 | net.eval() 116 | pred_test.extend(net(X).cpu().argmax(axis=1).detach().numpy()) 117 | gt = gt_hsi.flatten() 118 | x_label = np.zeros(gt.shape) 119 | for i in range(len(gt)): 120 | if gt[i] == 0: 121 | gt[i] = 17 122 | x_label[i] = 16 123 | gt = gt[:] - 1 124 | x_label[total_indices] = pred_test 125 | x = np.ravel(x_label) 126 | y_list = list_to_colormap(x) 127 | y_gt = list_to_colormap(gt) 128 | y_re = np.reshape(y_list, (gt_hsi.shape[0], gt_hsi.shape[1], 3)) 129 | gt_re = np.reshape(y_gt, (gt_hsi.shape[0], gt_hsi.shape[1], 3)) 130 | classification_map(y_re, gt_hsi, 300, 131 | path + '.png') 132 | classification_map(gt_re, gt_hsi, 300, 133 | path + '_gt.png') 134 | print('------Get classification maps successful-------') 135 | -------------------------------------------------------------------------------- /classification_maps/IP.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shouhengx/CS2DT/a94b1c00265d3e0801396adfc347b03a35b93779/classification_maps/IP.jpg -------------------------------------------------------------------------------- /classification_maps/KSC.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shouhengx/CS2DT/a94b1c00265d3e0801396adfc347b03a35b93779/classification_maps/KSC.jpg -------------------------------------------------------------------------------- /classification_maps/PU.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shouhengx/CS2DT/a94b1c00265d3e0801396adfc347b03a35b93779/classification_maps/PU.jpg -------------------------------------------------------------------------------- /convert_report_to_csv.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 19, 6 | "id": "8d6c430c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pandas as pd\n", 11 | "import re\n", 12 | "from glob import glob\n", 13 | "import argparse\n", 14 | "import os" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 20, 20 | "id": "ed00177e", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def get_data(report):\n", 25 | " class_wise_acc_regex = r'\\[[\\d.\\se\\-\\+(\\\\n)]*\\]'\n", 26 | " oa_aa_kappa_regex = r'([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?\\s±\\s[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)'\n", 27 | "\n", 28 | " result = {}\n", 29 | "\n", 30 | " x = re.findall(oa_aa_kappa_regex, report)\n", 31 | " result['oa'] = x[0][0]\n", 32 | " 
result['aa'] = x[1][0]\n", 33 | " result['kappa'] = x[2][0]\n", 34 | "\n", 35 | " result['oa'] = \"{:.2f}\".format(\n", 36 | " float(result['oa'].split(' ± ')[0]) * 100) + ' ± ' + \"{:.3f}\".format(\n", 37 | " float(result['oa'].split(' ± ')[1]))\n", 38 | " result['aa'] = \"{:.2f}\".format(\n", 39 | " float(result['aa'].split(' ± ')[0]) * 100) + ' ± ' + \"{:.3f}\".format(\n", 40 | " float(result['aa'].split(' ± ')[1]))\n", 41 | " result['kappa'] = \"{:.4f}\".format(float(\n", 42 | " result['kappa'].split(' ± ')[0])) + ' ± ' + \"{:.3f}\".format(\n", 43 | " float(result['kappa'].split(' ± ')[1]))\n", 44 | "\n", 45 | " x = re.findall(class_wise_acc_regex, report)\n", 46 | " result['class_mean'] = x[0][1:-1].split()\n", 47 | " result['class_std'] = x[1][1:-1].split()\n", 48 | " result['class_wise'] = [\n", 49 | " \"{:.2f}\".format(float(m) * 100) + ' ± ' + \"{:.3f}\".format(float(n))\n", 50 | " for m, n in zip(result['class_mean'], result['class_std'])\n", 51 | " ]\n", 52 | "\n", 53 | " return result\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 21, 59 | "id": "f95fbd99", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "search_path='./report_ok'" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 22, 69 | "id": "49f02ca4", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "dataset='UP' # UP,IN,SV, KSC" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 23, 79 | "id": "d0693097", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "output_file=None" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 24, 89 | "id": "b49f509f", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "all_reports = glob(search_path + '/*' + dataset + '*.txt')" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "c18ab267", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "all_reports" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 26, 109 | "id": "da686816", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "no_of_labels = 0\n", 114 | "dataframe_dict = {}" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "id": "37954ec0", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "for report in all_reports:\n", 125 | " print('Processing...', report)\n", 126 | " column_name = os.path.basename(report)[:-4]\n", 127 | " with open(report) as f:\n", 128 | " report_content = f.read()\n", 129 | "\n", 130 | " result = get_data(report_content)\n", 131 | "\n", 132 | " dataframe_dict[column_name] = result['class_wise'] + [result['oa']] + [\n", 133 | " result['aa']\n", 134 | " ] + [result['kappa']]\n", 135 | " no_of_labels = len(result['class_wise'])\n", 136 | "\n", 137 | "label_list = [str(i)\n", 138 | " for i in range(1, no_of_labels + 1)] + ['oa', 'aa', 'kappa']\n", 139 | "df = pd.DataFrame(dataframe_dict)\n", 140 | "df = df.reindex(sorted(df.columns), axis=1)\n", 141 | "df.insert(0, 'label', label_list)\n", 142 | "\n", 143 | "print('Saving...', dataset, 'report.')\n", 144 | "if output_file is not None:\n", 145 | " df.to_csv(output_file, index=False)\n", 146 | "else:\n", 147 | " if not os.path.exists('csv_reports'):\n", 148 | " os.makedirs('csv_reports')\n", 149 | " df.to_csv(\n", 150 | " os.path.join('csv_reports', dataset + '_DC-DenseFormer_report.csv'), index=False)\n" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | 
"execution_count": null, 156 | "id": "2d65c010", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [] 160 | } 161 | ], 162 | "metadata": { 163 | "kernelspec": { 164 | "display_name": "pyHSI", 165 | "language": "python", 166 | "name": "python3" 167 | }, 168 | "language_info": { 169 | "codemirror_mode": { 170 | "name": "ipython", 171 | "version": 3 172 | }, 173 | "file_extension": ".py", 174 | "mimetype": "text/x-python", 175 | "name": "python", 176 | "nbconvert_exporter": "python", 177 | "pygments_lexer": "ipython3", 178 | "version": "3.8.13" 179 | }, 180 | "vscode": { 181 | "interpreter": { 182 | "hash": "4f0e2fe85115fd2386557c42e0d23c6fb95dd61fe2d1e647d69acf95e1a29035" 183 | } 184 | } 185 | }, 186 | "nbformat": 4, 187 | "nbformat_minor": 5 188 | } 189 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "3628f23e", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import argparse\n", 12 | "import torch.nn as nn\n", 13 | "import torch.utils.data as Data\n", 14 | "import torch.backends.cudnn as cudnn\n", 15 | "import scipy.io as sio\n", 16 | "from scipy.io import savemat\n", 17 | "from torch import optim\n", 18 | "from torch.autograd import Variable\n", 19 | "from vit_pytorch import ViT\n", 20 | "from sklearn.metrics import confusion_matrix\n", 21 | "from sklearn import metrics, preprocessing\n", 22 | "from sklearn.decomposition import PCA\n", 23 | "\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from matplotlib import colors\n", 26 | "import numpy as np\n", 27 | "import time\n", 28 | "import os" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "1098f8cd", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import argparse\n", 39 | "import collections\n", 40 | "import math\n", 41 | "import time\n", 42 | "\n", 43 | "import torch.nn.functional as F\n", 44 | "import torch.optim as optim\n", 45 | "\n", 46 | "from sklearn.metrics import confusion_matrix\n", 47 | "from torchsummary import summary\n", 48 | "import geniter\n", 49 | "import record\n", 50 | "import Utils\n", 51 | "import gc\n", 52 | "\n", 53 | "from thop import profile" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "id": "b052bc94", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "PARAM_DATASET = 'IN' # UP,IN,SV, KSC\n", 64 | "PARAM_EPOCH = 100\n", 65 | "PARAM_ITER = 3\n", 66 | "PATCH_SIZE = 4\n", 67 | "PARAM_VAL = 0.95\n", 68 | "mode='DEN-1'\n", 69 | "\n", 70 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", 71 | "cross_attn_depth=1\n", 72 | "ssf_enc_depth=1\n", 73 | "\n", 74 | "\n", 75 | "P_dim = 128\n", 76 | "P_dim_head=64\n", 77 | "P_mlp_dim = 64\n", 78 | "P_depth = 4" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "id": "e099ada4", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "id": "419dc7d0", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "seeds = [1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1340, 1341]\n", 99 | "dataset = PARAM_DATASET \n", 100 | "Dataset = dataset.upper()" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 
106 | "id": "0ea5db66", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "def load_dataset(Dataset, split=0.9):\n", 111 | " data_path = './../data/'\n", 112 | " if Dataset == 'IN':\n", 113 | " mat_data = sio.loadmat(data_path + 'Indian_pines_corrected.mat')\n", 114 | " mat_gt = sio.loadmat(data_path + 'Indian_pines_gt.mat')\n", 115 | " data_hsi = mat_data['indian_pines_corrected']\n", 116 | " gt_hsi = mat_gt['indian_pines_gt']\n", 117 | " K = 200\n", 118 | " TOTAL_SIZE = 10249\n", 119 | " VALIDATION_SPLIT = split\n", 120 | " TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT)\n", 121 | "\n", 122 | " if Dataset == 'UP':\n", 123 | " uPavia = sio.loadmat(data_path + 'PaviaU.mat')\n", 124 | " gt_uPavia = sio.loadmat(data_path + 'PaviaU_gt.mat')\n", 125 | " data_hsi = uPavia['paviaU']\n", 126 | " gt_hsi = gt_uPavia['paviaU_gt']\n", 127 | " K = 103\n", 128 | " TOTAL_SIZE = 42776\n", 129 | " VALIDATION_SPLIT = split\n", 130 | " TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT)\n", 131 | "\n", 132 | " if Dataset == 'SV':\n", 133 | " SV = sio.loadmat(data_path + 'Salinas_corrected.mat')\n", 134 | " gt_SV = sio.loadmat(data_path + 'Salinas_gt.mat')\n", 135 | " data_hsi = SV['salinas_corrected']\n", 136 | " gt_hsi = gt_SV['salinas_gt']\n", 137 | " K = data_hsi.shape[2]\n", 138 | " TOTAL_SIZE = 54129\n", 139 | " VALIDATION_SPLIT = split\n", 140 | " TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT)\n", 141 | "\n", 142 | " if Dataset == 'KSC':\n", 143 | " KSV = sio.loadmat(data_path + 'KSC.mat')\n", 144 | " gt_KSV = sio.loadmat(data_path + 'KSC_gt.mat')\n", 145 | " data_hsi = KSV['KSC']\n", 146 | " gt_hsi = gt_KSV['KSC_gt']\n", 147 | " K = data_hsi.shape[2]\n", 148 | " TOTAL_SIZE = 5211\n", 149 | " VALIDATION_SPLIT = split\n", 150 | " TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT)\n", 151 | "\n", 152 | " \n", 153 | " if Dataset == 'BO':\n", 154 | " BO = sio.loadmat(data_path + 'Botswana.mat')\n", 155 | " gt_BO = sio.loadmat(data_path + 'Botswana_gt.mat')\n", 156 | " data_hsi = BO['Botswana']\n", 157 | " gt_hsi = gt_BO['Botswana_gt']\n", 158 | " K = data_hsi.shape[2]\n", 159 | " TOTAL_SIZE = 3248\n", 160 | " VALIDATION_SPLIT = split\n", 161 | " TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT)\n", 162 | "\n", 163 | "\n", 164 | " if Dataset == 'UH':\n", 165 | " data_hsi = sio.loadmat(data_path + 'houston.mat')['houston']\n", 166 | " gt_hsi = sio.loadmat(data_path + 'houston_gt.mat')['houston_gt_tr']\n", 167 | " gt_hsi += sio.loadmat(data_path + 'houston_gt.mat')['houston_gt_te']\n", 168 | " K = data_hsi.shape[2]\n", 169 | " TOTAL_SIZE = 15029\n", 170 | " VALIDATION_SPLIT = split\n", 171 | " TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT)\n", 172 | "\n", 173 | "\n", 174 | " shapeor = data_hsi.shape\n", 175 | " data_hsi = data_hsi.reshape(-1, data_hsi.shape[-1])\n", 176 | " data_hsi = PCA(n_components=K).fit_transform(data_hsi)\n", 177 | " shapeor = np.array(shapeor)\n", 178 | " shapeor[-1] = K\n", 179 | " data_hsi = data_hsi.reshape(shapeor)\n", 180 | "\n", 181 | " return data_hsi, gt_hsi, TOTAL_SIZE, TRAIN_SIZE, VALIDATION_SPLIT" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 7, 187 | "id": "0e231d61", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "data_hsi, gt_hsi, TOTAL_SIZE, TRAIN_SIZE, VALIDATION_SPLIT = load_dataset(\n", 192 | " Dataset, PARAM_VAL)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 8, 198 | "id": "f9eb21ae", 199 | "metadata": {}, 200 | "outputs": [], 201 | 
"source": [ 202 | "image_x, image_y, BAND = data_hsi.shape" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 9, 208 | "id": "dad1d1e9", 209 | "metadata": {}, 210 | "outputs": [ 211 | { 212 | "data": { 213 | "text/plain": [ 214 | "(207400, 103)" 215 | ] 216 | }, 217 | "execution_count": 9, 218 | "metadata": {}, 219 | "output_type": "execute_result" 220 | } 221 | ], 222 | "source": [ 223 | "data = data_hsi.reshape(\n", 224 | " np.prod(data_hsi.shape[:2]), np.prod(data_hsi.shape[2:]))\n", 225 | "data.shape" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 10, 231 | "id": "99d8eb6b", 232 | "metadata": {}, 233 | "outputs": [ 234 | { 235 | "data": { 236 | "text/plain": [ 237 | "(610, 340)" 238 | ] 239 | }, 240 | "execution_count": 10, 241 | "metadata": {}, 242 | "output_type": "execute_result" 243 | } 244 | ], 245 | "source": [ 246 | "gt_hsi.shape" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 11, 252 | "id": "42811dd1", 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "gt = gt_hsi.reshape(np.prod(gt_hsi.shape[:2]), )\n", 257 | "gt.shape\n", 258 | "CLASSES_NUM = max(gt)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 12, 264 | "id": "8dde8570", 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "PATCH_LENGTH = PATCH_SIZE" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 13, 274 | "id": "6e4a161d", 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "img_rows = 2 * PATCH_LENGTH + 1\n", 279 | "img_cols = 2 * PATCH_LENGTH + 1\n", 280 | "img_channels = data_hsi.shape[2]\n", 281 | "INPUT_DIMENSION = data_hsi.shape[2]\n", 282 | "ALL_SIZE = data_hsi.shape[0] * data_hsi.shape[1]\n", 283 | "VAL_SIZE = int(TRAIN_SIZE)\n", 284 | "TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 14, 290 | "id": "bcdc17fd", 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "data = preprocessing.scale(data)\n", 295 | "data_ = data.reshape(data_hsi.shape[0], data_hsi.shape[1], data_hsi.shape[2])\n", 296 | "whole_data = data_\n", 297 | "padded_data = np.lib.pad(\n", 298 | " whole_data, ((PATCH_LENGTH, PATCH_LENGTH), (PATCH_LENGTH, PATCH_LENGTH),\n", 299 | " (0, 0)),\n", 300 | " 'constant',\n", 301 | " constant_values=0)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 15, 307 | "id": "18f119f5", 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "def sampling(proportion, ground_truth):\n", 312 | " train = {}\n", 313 | " test = {}\n", 314 | " labels_loc = {}\n", 315 | " m = max(ground_truth)\n", 316 | " for i in range(m):\n", 317 | " indexes = [\n", 318 | " j for j, x in enumerate(ground_truth.ravel().tolist())\n", 319 | " if x == i + 1\n", 320 | " ]\n", 321 | " np.random.shuffle(indexes)\n", 322 | " labels_loc[i] = indexes\n", 323 | " if proportion != 1:\n", 324 | " nb_val = max(int((1 - proportion) * len(indexes)), 3)\n", 325 | " else:\n", 326 | " nb_val = 0\n", 327 | " train[i] = indexes[:nb_val]\n", 328 | " test[i] = indexes[nb_val:]\n", 329 | " train_indexes = []\n", 330 | " test_indexes = []\n", 331 | " for i in range(m):\n", 332 | " train_indexes += train[i]\n", 333 | " test_indexes += test[i]\n", 334 | " np.random.shuffle(train_indexes)\n", 335 | " np.random.shuffle(test_indexes)\n", 336 | " return train_indexes, test_indexes" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 16, 
342 | "id": "17663dc9", 343 | "metadata": {}, 344 | "outputs": [ 345 | { 346 | "name": "stdout", 347 | "output_type": "stream", 348 | "text": [ 349 | "-----Selecting Small Pieces from the Original Cube Data-----\n", 350 | "Train size: (2135, 15, 15, 103)\n" 351 | ] 352 | } 353 | ], 354 | "source": [ 355 | "index_iter=0\n", 356 | "np.random.seed(seeds[index_iter])\n", 357 | "train_indices, test_indices = sampling(VALIDATION_SPLIT, gt)\n", 358 | "_, total_indices = sampling(1, gt)\n", 359 | "\n", 360 | "TRAIN_SIZE = len(train_indices)\n", 361 | "TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE\n", 362 | "VAL_SIZE = int(TRAIN_SIZE)\n", 363 | "\n", 364 | "print('-----Selecting Small Pieces from the Original Cube Data-----')\n", 365 | "x_train,y_train, x_val,y_val, x_test,y_test, all_data, gt_all = geniter.generate_iter(\n", 366 | " TRAIN_SIZE, train_indices, TEST_SIZE, test_indices, TOTAL_SIZE,\n", 367 | " total_indices, VAL_SIZE, whole_data, PATCH_LENGTH, padded_data,\n", 368 | " INPUT_DIMENSION, 64, gt) #batchsize in 1\n" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 17, 374 | "id": "924cba18", 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "#band_patch=1\n", 379 | "band=x_train.shape[-1]\n", 380 | "patch=x_train.shape[-2]" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 18, 386 | "id": "9a6ed803", 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "model = ViT(\n", 391 | " image_size = patch,\n", 392 | " num_patches = band,\n", 393 | " num_classes = CLASSES_NUM,\n", 394 | " dim = P_dim,\n", 395 | " dim_head=P_dim_head,\n", 396 | " mlp_dim = P_mlp_dim,\n", 397 | " depth = P_depth,\n", 398 | " heads = 4,\n", 399 | " dropout = 0.1,\n", 400 | " emb_dropout = 0.1,\n", 401 | " mode = mode,\n", 402 | " cross_attn_depth = cross_attn_depth, \n", 403 | " ssf_enc_depth = ssf_enc_depth\n", 404 | ")" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "id": "999b2e67", 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "dummy_input = torch.randn(1,patch,patch,band)\n", 415 | "flops, params = profile(model, (dummy_input,))\n", 416 | "print('flops: ', flops, 'params: ', params)\n", 417 | "print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 20, 423 | "id": "f124df3e", 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "def train(net,\n", 428 | " train_iter,\n", 429 | " valida_iter,\n", 430 | " loss,\n", 431 | " optimizer,\n", 432 | " device,\n", 433 | " epochs,\n", 434 | " early_stopping=True,\n", 435 | " early_num=20):\n", 436 | " loss_list = [100]\n", 437 | " early_epoch = 0\n", 438 | "\n", 439 | " net = net.to(device)\n", 440 | " print(\"training on \", device)\n", 441 | " start = time.time()\n", 442 | " train_loss_list = []\n", 443 | " valida_loss_list = []\n", 444 | " train_acc_list = []\n", 445 | " valida_acc_list = []\n", 446 | " for epoch in range(epochs):\n", 447 | " train_acc_sum, n = 0.0, 0\n", 448 | " time_epoch = time.time()\n", 449 | " lr_adjust = torch.optim.lr_scheduler.StepLR(\n", 450 | " optimizer, step_size=PARAM_EPOCH//10, gamma=0.9)\n", 451 | " for X, y in train_iter:\n", 452 | "\n", 453 | " batch_count, train_l_sum = 0, 0\n", 454 | " X = X.to(device)\n", 455 | " y = y.to(device)\n", 456 | " y_hat = net(X)\n", 457 | " l = loss(y_hat, y.long())\n", 458 | "\n", 459 | " optimizer.zero_grad()\n", 460 | " l.backward()\n", 461 | " 
optimizer.step()\n", 462 | " train_l_sum += l.cpu().item()\n", 463 | " train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()\n", 464 | " n += y.shape[0]\n", 465 | " batch_count += 1\n", 466 | " lr_adjust.step()\n", 467 | " valida_acc, valida_loss = record.evaluate_accuracy(\n", 468 | " valida_iter, net, loss, device)\n", 469 | " loss_list.append(valida_loss)\n", 470 | "\n", 471 | " train_loss_list.append(train_l_sum) # / batch_count)\n", 472 | " train_acc_list.append(train_acc_sum / n)\n", 473 | " valida_loss_list.append(valida_loss)\n", 474 | " valida_acc_list.append(valida_acc)\n", 475 | "\n", 476 | " print(\n", 477 | " 'epoch %d, train loss %.6f, train acc %.3f, valida loss %.6f, valida acc %.3f, time %.1f sec'\n", 478 | " % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n,\n", 479 | " valida_loss, valida_acc, time.time() - time_epoch))\n", 480 | "\n", 481 | " PATH = \"./net_DBA.pt\"\n", 482 | "\n", 483 | " if early_stopping and loss_list[-2] < loss_list[-1]:\n", 484 | " if early_epoch == 0:\n", 485 | " torch.save(net.state_dict(), PATH)\n", 486 | " early_epoch += 1\n", 487 | " loss_list[-1] = loss_list[-2]\n", 488 | " if early_epoch == early_num:\n", 489 | " net.load_state_dict(torch.load(PATH))\n", 490 | " break\n", 491 | " else:\n", 492 | " early_epoch = 0\n", 493 | "\n", 494 | " print('epoch %d, loss %.4f, train acc %.3f, time %.1f sec'\n", 495 | " % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n,\n", 496 | " time.time() - start))" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 21, 502 | "id": "aa0e9d49", 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "loss = torch.nn.CrossEntropyLoss()" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 22, 512 | "id": "99d02140", 513 | "metadata": {}, 514 | "outputs": [], 515 | "source": [ 516 | "ITER = PARAM_ITER\n", 517 | "KAPPA = []\n", 518 | "OA = []\n", 519 | "AA = []\n", 520 | "TRAINING_TIME = []\n", 521 | "TESTING_TIME = []\n", 522 | "ELEMENT_ACC = np.zeros((ITER, CLASSES_NUM))" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": null, 528 | "id": "36ca3f42", 529 | "metadata": {}, 530 | "outputs": [], 531 | "source": [ 532 | "del data_hsi,data,data_\n", 533 | "gc.collect()" 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": null, 539 | "id": "31a49a3f", 540 | "metadata": { 541 | "scrolled": true 542 | }, 543 | "outputs": [], 544 | "source": [ 545 | "for index_iter in range(ITER):\n", 546 | " print('iter:', index_iter)\n", 547 | " \n", 548 | " np.random.seed(seeds[index_iter])\n", 549 | " train_indices, test_indices = sampling(VALIDATION_SPLIT, gt)\n", 550 | " _, total_indices = sampling(1, gt)\n", 551 | "\n", 552 | " TRAIN_SIZE = len(train_indices)\n", 553 | " TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE\n", 554 | " VAL_SIZE = int(TRAIN_SIZE)\n", 555 | "\n", 556 | " print('-----Selecting Small Pieces from the Original Cube Data-----')\n", 557 | " x_train,y_train, x_val,y_val, x_test,y_test, all_data, gt_all = geniter.generate_iter(\n", 558 | " TRAIN_SIZE, train_indices, TEST_SIZE, test_indices, TOTAL_SIZE,\n", 559 | " total_indices, VAL_SIZE, whole_data, PATCH_LENGTH, padded_data,\n", 560 | " INPUT_DIMENSION, 64, gt) \n", 561 | "\n", 562 | " del all_data,gt_all\n", 563 | " gc.collect()\n", 564 | " \n", 565 | " band=x_train.shape[-1]\n", 566 | " patch=x_train.shape[-2]\n", 567 | " \n", 568 | " \n", 569 | " x_train=torch.from_numpy(x_train).type(torch.FloatTensor) \n", 570 | " 
y_train=torch.from_numpy(y_train).type(torch.FloatTensor) \n", 571 | " Label_train=Data.TensorDataset(x_train,y_train)\n", 572 | " \n", 573 | " del x_train,y_train\n", 574 | " gc.collect()\n", 575 | " \n", 576 | " x_test=torch.from_numpy(x_test).type(torch.FloatTensor)\n", 577 | " y_test=torch.from_numpy(y_test).type(torch.FloatTensor) \n", 578 | " Label_test=Data.TensorDataset(x_test,y_test)\n", 579 | " \n", 580 | " del x_test,y_test\n", 581 | " gc.collect()\n", 582 | " \n", 583 | " x_val=torch.from_numpy(x_val).type(torch.FloatTensor)\n", 584 | " y_val=torch.from_numpy(y_val).type(torch.FloatTensor)\n", 585 | " Label_val=Data.TensorDataset(x_val,y_val)\n", 586 | " \n", 587 | " del x_val,y_val\n", 588 | " gc.collect()\n", 589 | " \n", 590 | "\n", 591 | " \n", 592 | " \n", 593 | " label_train_loader=Data.DataLoader(Label_train,batch_size=64,shuffle=True)\n", 594 | " label_test_loader=Data.DataLoader(Label_test,batch_size=64,shuffle=False)\n", 595 | " label_val_loader=Data.DataLoader(Label_val,batch_size=64,shuffle=False)\n", 596 | " \n", 597 | " del Label_train,Label_test,Label_val\n", 598 | " gc.collect()\n", 599 | " model = ViT(\n", 600 | " image_size = patch,\n", 601 | " num_patches = band,\n", 602 | " num_classes = CLASSES_NUM,\n", 603 | " dim = P_dim,\n", 604 | " dim_head=P_dim_head,\n", 605 | " mlp_dim = P_mlp_dim,\n", 606 | " depth = P_depth,\n", 607 | " heads = 4,\n", 608 | " dropout = 0.1,\n", 609 | " emb_dropout = 0.1,\n", 610 | " mode = mode,\n", 611 | " cross_attn_depth = cross_attn_depth, \n", 612 | " ssf_enc_depth = ssf_enc_depth\n", 613 | ")\n", 614 | " model = model.cuda()\n", 615 | " model.train()\n", 616 | " \n", 617 | " optimizer = optim.Adam(\n", 618 | " model.parameters(),\n", 619 | " lr=5e-4,\n", 620 | " weight_decay=0)\n", 621 | "\n", 622 | " \n", 623 | " tic1 = time.time()\n", 624 | " train(\n", 625 | " model,\n", 626 | " label_train_loader,\n", 627 | " label_val_loader,\n", 628 | " loss,\n", 629 | " optimizer,\n", 630 | " device,\n", 631 | " epochs=PARAM_EPOCH)\n", 632 | " toc1 = time.time()\n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " pred_test = []\n", 637 | " tic2 = time.time()\n", 638 | " with torch.no_grad():\n", 639 | " for X, y in label_test_loader:\n", 640 | " X = X.to(device)\n", 641 | " model.eval()\n", 642 | " y_hat = model(X)\n", 643 | " pred_test.extend(np.array(model(X).cpu().argmax(axis=1)))\n", 644 | " toc2 = time.time()\n", 645 | " collections.Counter(pred_test)\n", 646 | " gt_test = gt[test_indices] - 1\n", 647 | "\n", 648 | " overall_acc = metrics.accuracy_score(pred_test, gt_test[:-VAL_SIZE])\n", 649 | " confusion_matrix = metrics.confusion_matrix(pred_test, gt_test[:-VAL_SIZE])\n", 650 | " each_acc, average_acc = record.aa_and_each_accuracy(confusion_matrix)\n", 651 | " kappa = metrics.cohen_kappa_score(pred_test, gt_test[:-VAL_SIZE])\n", 652 | " \n", 653 | " KAPPA.append(kappa)\n", 654 | " OA.append(overall_acc)\n", 655 | " AA.append(average_acc)\n", 656 | " TRAINING_TIME.append(toc1 - tic1)\n", 657 | " TESTING_TIME.append(toc2 - tic2)\n", 658 | " ELEMENT_ACC[index_iter, :] = each_acc\n", 659 | " \n", 660 | " del label_train_loader,label_test_loader,label_val_loader\n", 661 | " gc.collect()\n", 662 | " \n", 663 | "print(\"--------\" + \" Training Finished-----------\")\n", 664 | "record.record_output(\n", 665 | " OA, AA, KAPPA, ELEMENT_ACC, TRAINING_TIME, TESTING_TIME, flops, params,\n", 666 | " './report/' + 'DC-DenseFormer_'+ Dataset + '_' 
+str(mode)+'_dep_'+str(P_depth)+'_SSF_'+str(ssf_enc_depth)+'_Cro_'+str(cross_attn_depth)+'_Patch_'+ str(img_rows) + '_' +'spl'\n", 667 | " + str(VALIDATION_SPLIT) +'.txt')\n" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": null, 673 | "id": "1ea0d4c8", 674 | "metadata": {}, 675 | "outputs": [], 676 | "source": [ 677 | "x_train,y_train, x_val,y_val, x_test,y_test, all_data, gt_all = geniter.generate_iter(\n", 678 | " TRAIN_SIZE, train_indices, TEST_SIZE, test_indices, TOTAL_SIZE,\n", 679 | " total_indices, VAL_SIZE, whole_data, PATCH_LENGTH, padded_data,\n", 680 | " INPUT_DIMENSION, 16, gt) " 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": null, 686 | "id": "c309635a", 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [ 690 | "del whole_data,padded_data,x_train,y_train, x_val,y_val, x_test,y_test\n", 691 | "gc.collect()" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": 27, 697 | "id": "695f5b79", 698 | "metadata": {}, 699 | "outputs": [], 700 | "source": [ 701 | "x_all=torch.from_numpy(all_data).type(torch.FloatTensor)\n", 702 | "y_all=torch.from_numpy(gt_all).type(torch.FloatTensor)\n", 703 | "Label_all=Data.TensorDataset(x_all,y_all)" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": null, 709 | "id": "9f5940ef", 710 | "metadata": {}, 711 | "outputs": [], 712 | "source": [ 713 | "del x_all,y_all\n", 714 | "gc.collect()" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": 29, 720 | "id": "91e662ef", 721 | "metadata": {}, 722 | "outputs": [], 723 | "source": [ 724 | "label_all_loader=Data.DataLoader(Label_all,batch_size=64,shuffle=False)" 725 | ] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "execution_count": null, 730 | "id": "691525cd", 731 | "metadata": {}, 732 | "outputs": [], 733 | "source": [ 734 | "Utils.generate_png(\n", 735 | " label_all_loader, model, gt_hsi, Dataset, device, total_indices,\n", 736 | " './classification_maps/' + 'DC-DenseFormer_'+ Dataset + '_' +str(mode)+'_dep_'+str(P_depth)+'_SSF_'+str(ssf_enc_depth)+'_Cro_'+str(cross_attn_depth)+'_Patch_'+ str(img_rows) + '_' +'spl'\n", 737 | " + str(VALIDATION_SPLIT))" 738 | ] 739 | } 740 | ], 741 | "metadata": { 742 | "kernelspec": { 743 | "display_name": "pyHSI", 744 | "language": "python", 745 | "name": "python3" 746 | }, 747 | "language_info": { 748 | "codemirror_mode": { 749 | "name": "ipython", 750 | "version": 3 751 | }, 752 | "file_extension": ".py", 753 | "mimetype": "text/x-python", 754 | "name": "python", 755 | "nbconvert_exporter": "python", 756 | "pygments_lexer": "ipython3", 757 | "version": "3.8.13" 758 | }, 759 | "vscode": { 760 | "interpreter": { 761 | "hash": "4f0e2fe85115fd2386557c42e0d23c6fb95dd61fe2d1e647d69acf95e1a29035" 762 | } 763 | } 764 | }, 765 | "nbformat": 4, 766 | "nbformat_minor": 5 767 | } 768 | -------------------------------------------------------------------------------- /geniter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.utils.data as Data 4 | 5 | def index_assignment(index, row, col, pad_length): 6 | new_assign = {} 7 | for counter, value in enumerate(index): 8 | assign_0 = value // col + pad_length 9 | assign_1 = value % col + pad_length 10 | new_assign[counter] = [assign_0, assign_1] 11 | return new_assign 12 | 13 | def select_patch(matrix, pos_row, pos_col, ex_len): 14 | selected_rows = matrix[range(pos_row-ex_len, pos_row+ex_len+1)] 15 | 
selected_patch = selected_rows[:, range(pos_col-ex_len, pos_col+ex_len+1)] 16 | return selected_patch 17 | 18 | 19 | def select_small_cubic(data_size, data_indices, whole_data, patch_length, padded_data, dimension): 20 | small_cubic_data = np.zeros((data_size, 2 * patch_length + 1, 2 * patch_length + 1, dimension)) 21 | data_assign = index_assignment(data_indices, whole_data.shape[0], whole_data.shape[1], patch_length) 22 | for i in range(len(data_assign)): 23 | small_cubic_data[i] = select_patch(padded_data, data_assign[i][0], data_assign[i][1], patch_length) 24 | return small_cubic_data 25 | 26 | 27 | def generate_iter(TRAIN_SIZE, train_indices, TEST_SIZE, test_indices, TOTAL_SIZE, total_indices, VAL_SIZE, 28 | whole_data, PATCH_LENGTH, padded_data, INPUT_DIMENSION, batch_size, gt): 29 | gt_all = gt[total_indices] - 1 30 | y_train = gt[train_indices] - 1 31 | y_test = gt[test_indices] - 1 32 | 33 | all_data = select_small_cubic(TOTAL_SIZE, total_indices, whole_data, 34 | PATCH_LENGTH, padded_data, INPUT_DIMENSION) 35 | 36 | train_data = select_small_cubic(TRAIN_SIZE, train_indices, whole_data, 37 | PATCH_LENGTH, padded_data, INPUT_DIMENSION) 38 | 39 | print('Train size: ',train_data.shape) 40 | test_data = select_small_cubic(TEST_SIZE, test_indices, whole_data, 41 | 42 | PATCH_LENGTH, padded_data, INPUT_DIMENSION) 43 | x_train = train_data.reshape(train_data.shape[0], train_data.shape[1], train_data.shape[2], INPUT_DIMENSION) 44 | x_test_all = test_data.reshape(test_data.shape[0], test_data.shape[1], test_data.shape[2], INPUT_DIMENSION) 45 | 46 | x_val = x_test_all[-VAL_SIZE:] 47 | y_val = y_test[-VAL_SIZE:] 48 | 49 | x_test = x_test_all[:-VAL_SIZE] 50 | y_test = y_test[:-VAL_SIZE] 51 | 52 | 53 | all_data = all_data.reshape(all_data.shape[0], all_data.shape[1], all_data.shape[2], INPUT_DIMENSION) # reshape is not in-place, so the result must be assigned 54 | 55 | return x_train,y_train, x_val,y_val, x_test,y_test, all_data, gt_all # numpy arrays; the caller builds the DataLoaders -------------------------------------------------------------------------------- /record.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from operator import truediv 4 | 5 | def evaluate_accuracy(data_iter, net, loss, device): 6 | acc_sum, test_l_sum, test_num, n = 0.0, 0.0, 0, 0 7 | with torch.no_grad(): 8 | for X, y in data_iter: 9 | # loss and counters accumulate over all batches (initialised once above) 10 | #X = X.permute(0, 3, 1, 2) 11 | X = X.to(device) 12 | y = y.to(device) 13 | net.eval() 14 | y_hat = net(X) 15 | l = loss(y_hat, y.long()) 16 | acc_sum += (y_hat.argmax(dim=1) == y).float().sum().cpu().item() 17 | test_l_sum += l.cpu().item() 18 | test_num += 1 19 | net.train() 20 | n += y.shape[0] 21 | return [acc_sum / n, test_l_sum / test_num] # mean accuracy and mean loss over the loader 22 | 23 | 24 | def aa_and_each_accuracy(confusion_matrix): 25 | list_diag = np.diag(confusion_matrix) 26 | list_row_sum = np.sum(confusion_matrix, axis=1) 27 | each_acc = np.nan_to_num(truediv(list_diag, list_row_sum)) 28 | average_acc = np.mean(each_acc) 29 | return each_acc, average_acc 30 | 31 | 32 | 33 | def record_output(oa_ae, aa_ae, kappa_ae, element_acc_ae, training_time_ae, testing_time_ae, flops, params, path): 34 | f = open(path, 'a') 35 | sentence0 = '\n'+ '\n'+ '\n'+'OAs for each iteration are:' + str(oa_ae) + '\n' 36 | f.write(sentence0) 37 | sentence1 = 'AAs for each iteration are:' + str(aa_ae) + '\n' 38 | f.write(sentence1) 39 | sentence2 = 'KAPPAs for each iteration are:' + str(kappa_ae) + '\n' + '\n' 40 | f.write(sentence2) 41 | sentence3 = 'mean_OA ± std_OA is: ' + str(np.mean(oa_ae)) + ' ± ' + str(np.std(oa_ae)) + '\n' 42 | f.write(sentence3)
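# note: convert_report_to_csv.ipynb parses the "mean ± std" lines and the bracketed numpy arrays written below, so their formatting should stay stable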
sentence4 = 'mean_AA ± std_AA is: ' + str(np.mean(aa_ae)) + ' ± ' + str(np.std(aa_ae)) + '\n' 44 | f.write(sentence4) 45 | sentence5 = 'mean_KAPPA ± std_KAPPA is: ' + str(np.mean(kappa_ae)) + ' ± ' + str(np.std(kappa_ae)) + '\n' + '\n' 46 | f.write(sentence5) 47 | sentence6 = 'Total training time is: ' + str(np.sum(training_time_ae)) + '\n' 48 | f.write(sentence6) 49 | sentence7 = 'Total testing time is: ' + str(np.sum(testing_time_ae)) + '\n' + '\n' 50 | f.write(sentence7) 51 | element_mean = np.mean(element_acc_ae, axis=0) 52 | element_std = np.std(element_acc_ae, axis=0) 53 | sentence8 = "Mean of per-class accuracies: " + str(element_mean) + '\n' 54 | f.write(sentence8) 55 | sentence9 = "Standard deviation of per-class accuracies: " + str(element_std) + '\n'+ '\n' 56 | f.write(sentence9) 57 | sentence10 = "flops: "+ str(flops / 1000000.0)+"M params: "+ str(params / 1000000.0)+'M' + '\n'+ '\n'+ '\n' 58 | f.write(sentence10) 59 | f.close() -------------------------------------------------------------------------------- /vit_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from einops import rearrange, repeat 5 | from CrossViT_module import CrossAttention 6 | 7 | class Residual(nn.Module): 8 | def __init__(self, fn): 9 | super().__init__() 10 | self.fn = fn 11 | def forward(self, x, **kwargs): 12 | return self.fn(x, **kwargs) + x 13 | 14 | class PreNorm(nn.Module): 15 | def __init__(self, dim, fn): 16 | super().__init__() 17 | self.norm = nn.LayerNorm(dim) 18 | self.fn = fn 19 | def forward(self, x, **kwargs): 20 | return self.fn(self.norm(x), **kwargs) 21 | 22 | class FeedForward(nn.Module): 23 | def __init__(self, dim, hidden_dim, dropout = 0.): 24 | super().__init__() 25 | self.net = nn.Sequential( 26 | nn.Linear(dim, hidden_dim), 27 | nn.GELU(), 28 | nn.Dropout(dropout), 29 | nn.Linear(hidden_dim, dim), 30 | nn.Dropout(dropout) 31 | ) 32 | def forward(self, x): 33 | return self.net(x) 34 | 35 | class Attention(nn.Module): 36 | def __init__(self, dim, heads, dim_head, dropout): 37 | super().__init__() 38 | inner_dim = dim_head * heads 39 | self.heads = heads 40 | self.scale = dim_head ** -0.5 41 | 42 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 43 | self.to_out = nn.Sequential( 44 | nn.Linear(inner_dim, dim), 45 | nn.Dropout(dropout) 46 | ) 47 | def forward(self, x, mask = None): 48 | # x:[b,n,dim] 49 | b, n, _, h = *x.shape, self.heads 50 | 51 | # get qkv tuple:([b,n,head_num*head_dim],[...],[...]) 52 | qkv = self.to_qkv(x).chunk(3, dim = -1) 53 | # split q,k,v from [b,n,head_num*head_dim] -> [b,head_num,n,head_dim] 54 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 55 | 56 | # transpose(k) * q / sqrt(head_dim) -> [b,head_num,n,n] 57 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 58 | mask_value = -torch.finfo(dots.dtype).max 59 | 60 | # mask value: -inf 61 | if mask is not None: 62 | mask = nn.functional.pad(mask.flatten(1), (1, 0), value = True) 63 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 64 | mask = mask[:, None, :] * mask[:, :, None] 65 | dots.masked_fill_(~mask, mask_value) 66 | del mask 67 | 68 | # softmax normalization -> attention matrix 69 | attn = dots.softmax(dim=-1) 70 | # value * attention matrix -> output 71 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 72 | # cat all output -> [b, n, head_num*head_dim] 73 | out = rearrange(out, 'b h n d -> b 
n (h d)') 74 | out = self.to_out(out) 75 | return out 76 | 77 | class Transformer(nn.Module): 78 | def __init__(self, dim, depth, heads, dim_head, mlp_head, dropout, num_channel, mode): 79 | super().__init__() 80 | 81 | self.layers = nn.ModuleList([]) 82 | self.depth = depth 83 | for _ in range(depth): 84 | self.layers.append(nn.ModuleList([ 85 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), 86 | Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout))) 87 | ])) 88 | 89 | self.mode = mode 90 | self.skipcat = nn.ModuleList([]) 91 | if self.mode == 'DEN-1': 92 | for i in range(depth-1): 93 | self.skipcat.append(nn.Conv2d(num_channel+1, num_channel+1, [1, 2+i], 1, 0)) 94 | 95 | 96 | 97 | def forward(self, x, mask = None): 98 | #print(x.shape) 99 | if self.mode == 'ViT': 100 | for attn, ff in self.layers: 101 | x = attn(x, mask = mask) 102 | x = ff(x) 103 | elif self.mode == 'DEN-1': 104 | last_output = [] 105 | nl = 0 106 | for attn, ff in self.layers: 107 | last_output.append(x) 108 | if nl > 0: 109 | x=x.unsqueeze(3) 110 | for j in range(nl): 111 | x=(torch.cat([x, last_output[nl-1-j].unsqueeze(3)], dim=3)) 112 | x = self.skipcat[nl-1](x).squeeze(3) 113 | x = attn(x, mask = mask) 114 | x = ff(x) 115 | nl += 1 116 | 117 | 118 | return x 119 | 120 | 121 | class MultiScaleTransformerEncoder(nn.Module): 122 | 123 | def __init__(self,Snum_patches,Lnum_patches, mode, small_dim = 96, small_depth = 4, small_heads =3, small_dim_head = 32, small_mlp_dim = 384, 124 | large_dim = 192, large_depth = 1, large_heads = 3, large_dim_head = 64, large_mlp_dim = 768, 125 | cross_attn_depth = 1, cross_attn_heads = 3, dropout = 0.): 126 | super().__init__() 127 | self.transformer_enc_small = Transformer(small_dim, small_depth, small_heads, small_dim_head, small_mlp_dim,dropout, Snum_patches, mode) 128 | self.transformer_enc_large = Transformer(large_dim, large_depth, large_heads, large_dim_head, large_mlp_dim,dropout, Lnum_patches, mode) 129 | 130 | self.cross_attn_layers = nn.ModuleList([]) 131 | for _ in range(cross_attn_depth): 132 | self.cross_attn_layers.append(nn.ModuleList([ 133 | nn.Linear(small_dim, large_dim), 134 | nn.Linear(large_dim, small_dim), 135 | PreNorm(large_dim, CrossAttention(large_dim, heads = cross_attn_heads, dim_head = large_dim_head, dropout = dropout)), 136 | nn.Linear(large_dim, small_dim), 137 | nn.Linear(small_dim, large_dim), 138 | PreNorm(small_dim, CrossAttention(small_dim, heads = cross_attn_heads, dim_head = small_dim_head, dropout = dropout)), 139 | ])) 140 | 141 | 142 | def forward(self, xs, xl,mask = None): 143 | xs = self.transformer_enc_small(xs,mask) 144 | xl = self.transformer_enc_large(xl,mask) 145 | for f_sl, g_ls, cross_attn_s, f_ls, g_sl, cross_attn_l in self.cross_attn_layers: 146 | small_class = xs[:, 0] 147 | x_small = xs[:, 1:] 148 | large_class = xl[:, 0] 149 | x_large = xl[:, 1:] 150 | 151 | # Cross Attn for Large Patch 152 | 153 | cal_q = f_ls(large_class.unsqueeze(1)) 154 | cal_qkv = torch.cat((cal_q, x_small), dim=1) 155 | cal_out = cal_q + cross_attn_l(cal_qkv) 156 | cal_out = g_sl(cal_out) 157 | xl = torch.cat((cal_out, x_large), dim=1) 158 | 159 | # Cross Attn for Smaller Patch 160 | cal_q = f_sl(small_class.unsqueeze(1)) 161 | cal_qkv = torch.cat((cal_q, x_large), dim=1) 162 | cal_out = cal_q + cross_attn_s(cal_qkv) 163 | cal_out = g_ls(cal_out) 164 | xs = torch.cat((cal_out, x_small), dim=1) 165 | 166 | return xs, xl 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | class ViT(nn.Module): 
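"""
Dual-branch transformer: the input cube (b, w, h, c) is flattened to (b, w*h, c);
branch 1 transposes it so the c spectral bands become tokens of dimension w*h,
while branch 2 keeps the w*h pixels as tokens of dimension c. With ssf_enc_depth > 0
the two branches exchange class tokens through MultiScaleTransformerEncoder
(the cross spatial-spectral attention of CS2DT); otherwise each branch runs its own
Transformer. Either way, the two cls tokens are concatenated and classified by mlp_head3.
"""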
176 | def __init__(self, image_size, num_patches, num_classes, dim, depth, heads, mlp_dim, pool='cls', dim_head = 16, dropout=0., emb_dropout=0., mode='ViT',cross_attn_depth = 1, ssf_enc_depth = 0): 177 | super().__init__() 178 | 179 | patch_dim = image_size ** 2 180 | self.pos_embedding1 = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 181 | self.patch_to_embedding1 = nn.Linear(patch_dim, dim) 182 | self.cls_token1 = nn.Parameter(torch.randn(1, 1, dim)) 183 | self.dropout1 = nn.Dropout(emb_dropout) 184 | self.pool1 = pool 185 | self.to_latent1 = nn.Identity() 186 | 187 | self.pos_embedding2 = nn.Parameter(torch.randn(1, patch_dim + 1, dim)) 188 | self.patch_to_embedding2 = nn.Linear(num_patches, dim) 189 | 190 | 191 | self.cls_token2 = nn.Parameter(torch.randn(1, 1, dim)) 192 | self.dropout2 = nn.Dropout(emb_dropout) 193 | 194 | self.pool2 = pool 195 | self.to_latent2 = nn.Identity() 196 | # classification heads: defined unconditionally because forward() uses mlp_head3 on both code paths 197 | self.mlp_head1 = nn.Sequential( 198 | nn.LayerNorm(dim), 199 | nn.Linear(dim, num_classes) 200 | ) 201 | 202 | self.mlp_head2 = nn.Sequential( 203 | nn.LayerNorm(dim), 204 | nn.Linear(dim, num_classes) 205 | ) 206 | 207 | self.mlp_head3 = nn.Sequential( 208 | nn.LayerNorm(dim*2), 209 | nn.Linear(dim*2, num_classes) 210 | ) 211 | 212 | self.ssf_enc_depth=ssf_enc_depth 213 | if ssf_enc_depth > 0: 214 | self.ssf_transformers = nn.ModuleList([]) 215 | for _ in range(ssf_enc_depth): 216 | self.ssf_transformers.append(MultiScaleTransformerEncoder(Snum_patches=num_patches, Lnum_patches=patch_dim,mode=mode, 217 | small_dim=dim, small_depth=depth, 218 | small_heads=heads, small_dim_head=dim_head, 219 | small_mlp_dim=mlp_dim, 220 | large_dim=dim, large_depth=depth, 221 | large_heads=heads, large_dim_head=dim_head, 222 | large_mlp_dim=mlp_dim, 223 | cross_attn_depth=cross_attn_depth, cross_attn_heads=heads, 224 | dropout=dropout)) 225 | 226 | else: 227 | self.transformer1 = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_patches, mode) 228 | self.transformer2 = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_dim, mode) 229 | 230 | 231 | 232 | 233 | def forward(self, x, mask = None): 234 | 235 | x=rearrange(x,'b w h c -> b (w h) c') 236 | x2=x 237 | x1=x.transpose(-1, -2) 238 | 239 | 240 | x1 = self.patch_to_embedding1(x1) #[b,n,dim] 241 | b, n, _ = x1.shape 242 | # add position embedding 243 | cls_tokens1 = repeat(self.cls_token1, '() n d -> b n d', b = b) #[b,1,dim] 244 | x1 = torch.cat((cls_tokens1, x1), dim = 1) #[b,n+1,dim] 245 | x1 += self.pos_embedding1[:, :(n + 1)] 246 | x1 = self.dropout1(x1) 247 | 248 | 249 | x2 = self.patch_to_embedding2(x2) #[b,n,dim] 250 | b, n, _ = x2.shape 251 | # add position embedding 252 | cls_tokens2 = repeat(self.cls_token2, '() n d -> b n d', b = b) #[b,1,dim] 253 | x2 = torch.cat((cls_tokens2, x2), dim = 1) #[b,n+1,dim] 254 | x2 += self.pos_embedding2[:, :(n + 1)] 255 | x2 = self.dropout2(x2) 256 | 257 | if self.ssf_enc_depth >0: 258 | xs=x1 259 | xl=x2 260 | for ssf_transformer in self.ssf_transformers: 261 | xs, xl = ssf_transformer(xs, xl) 262 | xs = xs.mean(dim = 1) if self.pool1 == 'mean' else xs[:, 0] 263 | xl = xl.mean(dim = 1) if self.pool2 == 'mean' else xl[:, 0] 264 | 265 | 266 | x3=torch.cat((xs, xl), dim = 1) 267 | x3=self.mlp_head3(x3) 268 | else: 269 | x1 = self.transformer1(x1, mask) 270 | x1 = self.to_latent1(x1[:,0]) 271 | x2 = self.transformer2(x2, mask) 272 | 273 | x2 = self.to_latent2(x2[:,0]) 274 | x3=torch.cat((x1, x2), dim = 1) 275 | x3=self.mlp_head3(x3) 276 | 277 | return x3 278 | 
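A minimal smoke-test sketch for the ViT above, mirroring the construction in demo.ipynb (the hyperparameter values are the demo's defaults; the 9x9 patch, 103 bands and 9 classes correspond to PATCH_SIZE = 4 and the Pavia University run captured in the notebook outputs):

import torch
from vit_pytorch import ViT

patch, band, n_classes = 9, 103, 9
model = ViT(image_size=patch, num_patches=band, num_classes=n_classes,
            dim=128, dim_head=64, mlp_dim=64, depth=4, heads=4,
            dropout=0.1, emb_dropout=0.1, mode='DEN-1',
            cross_attn_depth=1, ssf_enc_depth=1)
x = torch.randn(2, patch, patch, band)   # (batch, height, width, spectral bands)
logits = model(x)                        # -> torch.Size([2, 9])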
--------------------------------------------------------------------------------
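Environment note: the imports across the files above imply roughly the following packages (an assumption reconstructed from the code, not an official requirements list; the notebook metadata records Python 3.8.13, and version pins are left to the reader):

pip install torch einops numpy scipy scikit-learn matplotlib spectral opencv-python pandas thop torchsummary ipython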