├── .gitignore ├── README.md ├── framework.png └── src ├── augment.py ├── data_provider ├── __init__.py └── data_provider.py ├── evaluation.py ├── models ├── __init__.py └── cross_transformer.py ├── params ├── indian_diffusion.json ├── pavia_diffusion.json └── salinas_diffusion.json ├── plot.py ├── process.py ├── trainer.py ├── utils.py └── workflow.py /.gitignore: -------------------------------------------------------------------------------- 1 | src/__pycache__ 2 | src/data_provider/__pycache__/* 3 | res 4 | src/models/__pycache__ 5 | *.pyc 6 | */*.pyc 7 | */*/*.pyc 8 | src/jupyter/.* 9 | data/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpectralDiff: A Generative Framework for Hyperspectral Image Classification with Diffusion Models 2 | 3 | [Ning Chen](), [Jun Yue](), [Leyuan Fang](), [Shaobo Xia]() 4 | ___________ 5 | 6 | The code in this toolbox implements the ["SpectralDiff: A Generative Framework for Hyperspectral Image Classification with Diffusion Models"](https://ieeexplore.ieee.org/document/10234379). 7 | 8 | **The code for this research consists of two parts: the [spectral-spatial diffusion module](https://github.com/chenning0115/spectraldiff_diffusion/) and the [attention-based classification module](https://github.com/chenning0115/SpectralDiff#spectraldiff). This repository contains the attention-based classification module.** 9 | 10 | More specifically, the framework is detailed as follows. 11 | 12 | ![alt text](./framework.png) 13 | 14 | Citation 15 | --------------------- 16 | 17 | **Please cite the paper if this code is useful for your research.** 18 | 19 | ``` 20 | N. Chen, J. Yue, L. Fang and S. Xia, "SpectralDiff: A Generative Framework for Hyperspectral Image Classification with Diffusion Models," in IEEE Transactions on Geoscience and Remote Sensing, doi: 10.1109/TGRS.2023.3310023. 
21 | 22 | ``` 23 | 24 | ``` 25 | @ARTICLE{10234379, 26 | author={Chen, Ning and Yue, Jun and Fang, Leyuan and Xia, Shaobo}, 27 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 28 | title={SpectralDiff: A Generative Framework for Hyperspectral Image Classification with Diffusion Models}, 29 | year={2023}, 30 | volume={}, 31 | number={}, 32 | pages={1-1}, 33 | doi={10.1109/TGRS.2023.3310023}} 34 | 35 | ``` 36 | 37 | 38 | How to use it? 39 | --------------------- 40 | 1. Prepare raw data and diffusion features 41 | * Raw data is the original HSI data, such as the IP, PU, and SA datasets. You need to split it into training and test sets in advance, or you can download our split from [baiduyun](https://pan.baidu.com/s/19-YNNIjQxEOz-gl3vCLuDg); the extraction code is ```pabk```. 42 | * The classification module requires the features extracted by the diffusion module as input. We provide the diffusion features extracted in our experiments so that researchers can reproduce the results. For the convenience of testing, we provide all diffusion feature data before PCA. Please download the data from [baiduyun](https://pan.baidu.com/s/19-YNNIjQxEOz-gl3vCLuDg) (extraction code ```pabk```) or from [google_drive](https://drive.google.com/drive/folders/10n7MmQRbIh-fpmIIFVtQk17QlXq-kBAo?usp=drive_link). 43 | * To train a diffusion model by yourself, you can use this code repository [spectral-spatial diffusion module](https://github.com/chenning0115/spectraldiff_diffusion/). 44 | 2. Modify the path of diffusion features in the params to ensure that the code can read them. 45 | 3. Run the code 46 | ``` 47 | python workflow.py 48 | ``` 49 | 50 | Others 51 | ---------------------- 52 | If you want to run the code on your own data, you can change the input (e.g., data, labels) accordingly and tune the parameters. 53 | 54 | If you encounter any bugs while using this code, please do not hesitate to contact us. 
class Augment:
    """Common base for the augmentation strategies in this module.

    Callers invoke ``do``; concrete augmentations supply the actual work by
    overriding ``real_do``.
    """

    def __init__(self, params) -> None:
        # The config entry's 'type' doubles as the augmentation's name.
        self.name = params['type']

    def do(self, data):
        """Run the augmentation on ``data`` by delegating to ``real_do``."""
        result = self.real_do(data)
        return result

    def real_do(self, data) -> Tensor:
        """Hook for subclasses; the base implementation returns ``None``."""
        return None
class SpecFilterAugment(Augment):
    """Gaussian blur applied along the spectral (channel) dimension.

    JSON configuration::

        "type": "SpectralFilter",
        "kernel_size": 5,
        "sigma_sq": 2.25

    The kernel holds Gaussian density values
    ``exp(-d^2 / (2 * sigma_sq)) / sqrt(2 * pi * sigma_sq)`` where ``d`` is the
    offset from the kernel centre. It is not re-normalised to sum to one,
    matching the formula the original code was written around.
    """

    def __init__(self, params) -> None:
        super(SpecFilterAugment, self).__init__(params)
        self.kernel_size = int(params.get("kernel_size", 3))
        self.sigma_sq = float(params.get("sigma_sq", 2.25))
        # Integer half-width: it is used both as a padding size and an index.
        self.margin = self.kernel_size // 2
        offsets = torch.arange(self.kernel_size, dtype=torch.float32) - self.margin
        # Negative exponent: the weight must decay away from the centre.
        self.filter = torch.exp(-offsets * offsets / (2 * self.sigma_sq)) \
            / math.sqrt(2 * math.pi * self.sigma_sq)

    def real_do(self, data):
        """Blur every pixel's spectrum with the Gaussian kernel.

        Args:
            data: floating-point tensor laid out as
                (batch, channel, patch_size, patch_size).

        Returns:
            A tensor with the same shape, dtype and device as ``data``. The
            spectrum is zero-padded at both ends before filtering.
        """
        batch_size, channel_num, H, W = data.shape
        # Put the spectrum last and treat each pixel as one 1-D signal.
        spectra = data.permute(0, 2, 3, 1).reshape(-1, 1, channel_num)
        kernel = self.filter.to(device=data.device, dtype=data.dtype).view(1, 1, -1)
        # conv1d is a cross-correlation, i.e. the sliding dot product the
        # original loop computed. An even kernel yields one extra sample, so
        # keep only the first channel_num outputs.
        filtered = F.conv1d(spectra, kernel, padding=self.margin)[..., :channel_num]
        newdata = filtered.reshape(batch_size, H, W, channel_num).permute(0, 3, 1, 2)
        return newdata.contiguous()
def do_augment(params, data):
    """Build the augmentation selected by ``params['type']`` and apply it.

    Each augmentation reads its own settings from ``params`` (for example the
    shrink size, the mask ratio or the Gaussian kernel parameters).

    Args:
        params: augmentation config; must contain the key ``'type'``.
        data: the batch handed to the augmentation's ``do`` method.

    Returns:
        Whatever the selected augmentation's ``do`` returns.

    Raises:
        ValueError: if ``params['type']`` is not a known augmentation. This
            used to fall through and return ``None`` silently.
    """
    registry = {
        'shrink': ShrinkAugment,
        'Gauss': GaussAugment,
        'Flip': FlipAugment,
        'Rotate': RotateAugment,
        'DownSample': DownSampleAugment,
        'Same': SameAugment,
        'Mask': XMaskAugment,
        '3DMask': MaskAugment,
    }
    kind = params['type']
    if kind not in registry:
        raise ValueError(
            "unknown augment type %r, expected one of %s" % (kind, sorted(registry))
        )
    return registry[kind](params).do(data)
class DataSetIter(torch.utils.data.Dataset):
    """Training dataset that cuts patches out of one padded image on demand."""

    def __init__(self, _base_img, _base_labels, _index2pos, _margin, _patch_size, _append_dim) -> None:
        # Full image including the zero margin: (H + 2*margin, W + 2*margin, spe).
        self.base_img = _base_img
        # Full label map without margin: (H, W).
        self.base_labels = _base_labels
        # Sample index -> (x, y). Indexes the label map directly and is the
        # top-left corner of the patch in the padded image.
        self.index2pos = _index2pos
        self.size = len(_index2pos)

        self.margin = _margin
        self.patch_size = _patch_size
        self.append_dim = _append_dim

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        """Return ``(patch, label)`` for the sample at ``index``.

        ``patch`` is a float tensor shaped [spe, h, w], or [1, spe, h, w] when
        ``append_dim`` is set. ``label`` is the stored label minus one, as a
        0-d long tensor.
        """
        row, col = self.index2pos[index]
        window = 2 * self.margin + 1
        patch = self.base_img[row:row + window, col:col + window, :]
        # Move the spectral axis to the front: [spe, h, w].
        patch = patch.transpose((2, 0, 1))
        if self.append_dim:
            # Prepend a singleton channel axis: [c=1, spe, h, w].
            patch = np.expand_dims(patch, 0)
        label = self.base_labels[row, col] - 1
        patch_tensor = torch.FloatTensor(patch)
        label_tensor = torch.LongTensor(label.reshape(-1))[0]
        return patch_tensor, label_tensor
def load_data_from_diffusion(self, data_ori, labels):
    """Load the diffusion feature cube used in place of the raw HSI cube.

    The file is read with ``np.load`` from
    ``<data_path_prefix>/<diffusion_data_path>``.

    Args:
        data_ori: raw image array, (h, w, c); only its spatial size is used.
        labels: label map, returned unchanged.

    Returns:
        ``(data, labels)`` where ``data`` is the loaded feature cube.

    Raises:
        AssertionError: if the feature cube's height or width differs from
            the raw image's.
    """
    path = "%s/%s" % (self.data_path_prefix, self.diffusion_data_path)
    data = np.load(path)
    ori_h, ori_w, _ = data_ori.shape
    h, w, _ = data.shape
    # `assert a, b` treats b as the failure message, so the previous
    # `assert ori_h == h, ori_w == w` never checked the width.
    assert ori_h == h and ori_w == w, (
        "diffusion feature spatial size (%s, %s) does not match raw data (%s, %s)"
        % (h, w, ori_h, ori_w)
    )
    print("load diffusion data shape is ", data.shape)
    return data, labels
diffusion_labels = self.load_data_from_diffusion(ori_data, labels) 93 | return diffusion_data, diffusion_labels, TR, TE 94 | 95 | def _padding(self, X, margin=2): 96 | # pading with zeros 97 | w,h,c = X.shape 98 | new_x, new_h, new_c = w+margin*2, h+margin*2, c 99 | returnX = np.zeros((new_x, new_h, new_c)) 100 | start_x, start_y = margin, margin 101 | returnX[start_x:start_x+w, start_y:start_y+h,:] = X 102 | return returnX 103 | 104 | def get_valid_num(self, y): 105 | tempy = y.reshape(-1) 106 | validy = tempy[tempy > 0] 107 | print('valid y shape is ', validy.shape) 108 | return validy.shape[0] 109 | 110 | def get_train_test_num(self, TR, TE): 111 | train_num, test_num = TR[TR>0].reshape(-1).size, TE[TE>0].reshape(-1).size 112 | print("train_num=%s, test_num=%s" % (train_num, test_num)) 113 | return train_num, test_num 114 | 115 | 116 | def get_train_test_patches(self, X, y, TR, TE): 117 | h, w, c = X.shape 118 | # 给 X 做 padding 119 | windowSize = self.patch_size 120 | margin = int((windowSize - 1) / 2) 121 | zeroPaddedX = self._padding(X, margin=margin) 122 | 123 | # 确定train和test的数据量 124 | train_num, test_num = self.get_train_test_num(TR, TE) 125 | trainX_index2pos = {} 126 | testX_index2pos = {} 127 | all_index2pos = {} 128 | 129 | patchIndex = 0 130 | trainIndex = 0 131 | testIndex = 0 132 | for r in range(margin, zeroPaddedX.shape[0] - margin): 133 | for c in range(margin, zeroPaddedX.shape[1] - margin): 134 | start_x, start_y = r-margin, c-margin 135 | tempy = y[start_x, start_y] 136 | temp_tr = TR[start_x, start_y] 137 | temp_te = TE[start_x, start_y] 138 | if temp_tr > 0 and temp_te > 0: 139 | print("here", temp_tr, temp_te, r, c) 140 | raise Exception("data error, find sample in trainset as well as testset.") 141 | 142 | if temp_tr > 0: #train data 143 | trainX_index2pos[trainIndex] = [start_x, start_y] 144 | trainIndex += 1 145 | elif temp_te > 0: 146 | testX_index2pos[testIndex] = [start_x, start_y] 147 | testIndex += 1 148 | all_index2pos[patchIndex] 
def data_preprocessing(self, data):
    """Normalise the cube, optionally reduce it with PCA, then truncate bands.

    Steps:
        1. Normalisation chosen by ``self.norm_type``: ``'max_min'`` scales
           each band to [0, 1], ``'mean_var'`` delegates to
           ``self.mean_var_norm``, anything else leaves the data untouched.
        2. PCA via ``self.applyPCA`` when ``data_param['pca']`` is > 0.
        3. Keep only the first ``self.spectracl_size`` bands when it is > 0.

    Args:
        data: array shaped [h, w, spectral].

    Returns:
        The processed array, [h, w, bands].
    """
    if self.norm_type == 'max_min':
        band_min = np.min(data, axis=(0, 1))
        band_max = np.max(data, axis=(0, 1))
        band_range = band_max - band_min
        # A constant band has zero range. Dividing by it produced NaN, so
        # such bands now map to 0 instead.
        safe_range = np.where(band_range == 0, 1, band_range)
        norm_data = np.zeros(data.shape)
        norm_data[...] = (data - band_min) / safe_range
    elif self.norm_type == 'mean_var':
        norm_data = self.mean_var_norm(data)
    else:
        norm_data = data
    pca_num = self.data_param.get('pca', 0)
    if pca_num > 0:
        print('before pca')
        norm_data = self.applyPCA(norm_data, int(pca_num))
        print('after pca')
    if self.spectracl_size > 0:
        # Truncate to the configured spectral size.
        norm_data = norm_data[:, :, :self.spectracl_size]
    return norm_data
def reconstruct_pred(self, y_pred):
    """Fold a flat prediction vector back into the label map's 2-D layout.

    Args:
        y_pred: predictions shaped [h*w].

    Returns:
        ``y_pred`` reshaped to [h, w], where (h, w) is ``self.labels.shape``.
    """
    height, width = self.labels.shape
    target_shape = (height, width)
    return y_pred.reshape(target_shape)
def generate_torch_dataset(self):
    """Return the data in the form selected by ``self.if_numpy``.

    When ``if_numpy`` is set, the result of ``generate_numpy_dataset`` is
    returned as-is. Otherwise the four datasets from ``prepare_data`` are
    wrapped in DataLoaders and returned as
    ``(train_loader, unlabel_loader, test_loader, all_loader)``. Only the
    training loader shuffles. The unlabelled loader's batch size is
    ``batch_size * data_param['unlabelled_multiple']`` (default multiple 1).
    """
    # 0. Numpy datasets bypass the torch loaders entirely.
    if self.if_numpy:
        return self.generate_numpy_dataset()

    trainset, unlabelset, testset, allset = self.prepare_data()
    multi = self.data_param.get('unlabelled_multiple', 1)

    def ordered_loader(dataset, batch_size):
        # Deterministic, single-process loader shared by the non-training splits.
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            drop_last=False,
        )

    train_loader = torch.utils.data.DataLoader(
        dataset=trainset,
        batch_size=self.batch_size,
        shuffle=True,
        drop_last=False,
    )
    unlabel_loader = ordered_loader(unlabelset, int(self.batch_size * multi))
    test_loader = ordered_loader(testset, self.batch_size)
    all_loader = ordered_loader(allset, self.batch_size)

    return train_loader, unlabel_loader, test_loader, all_loader
class HSIEvaluation(object):
    """Compute classification metrics (OA, AA, kappa, per-class accuracy)."""

    def __init__(self, param) -> None:
        self.param = param
        # Class names for the report; stays None for data sets without a
        # name list here, in which case the report uses the raw label values.
        self.target_names = None
        data_sign = param['data']['data_sign']
        if data_sign == 'Indian':
            self.target_names = INDIAN_TARGET_NAMES
        elif data_sign == "Pavia":
            self.target_names = PAVIA_UNIVERSITY_NAMES

        self.res = {}

    def AA_andEachClassAccuracy(self, confusion_matrix):
        """Return ``(each_acc, average_acc)`` from a confusion matrix.

        Per-class accuracy is the diagonal divided by the row sum. Rows with
        no samples give NaN, which ``nan_to_num`` turns into 0.
        """
        list_diag = np.diag(confusion_matrix)
        list_raw_sum = np.sum(confusion_matrix, axis=1)
        each_acc = np.nan_to_num(truediv(list_diag, list_raw_sum))
        average_acc = np.mean(each_acc)
        return each_acc, average_acc

    def eval(self, y_test, y_pred_test):
        """Evaluate predictions and store the metrics in ``self.res``.

        Args:
            y_test: ground-truth class indices.
            y_pred_test: predicted class indices.

        Returns:
            ``self.res`` with the keys 'classification', 'oa', 'confusion',
            'each_acc', 'aa' and 'kappa'; the numeric ones are percentages.
        """
        # Labels are treated as 0-based, so the class count is max + 1.
        # range(max) left the highest class out of the report.
        # NOTE(review): assumes the caller passes 0-based labels, as
        # DataSetIter produces them; confirm against the trainer.
        class_num = int(np.max(y_test)) + 1
        classification = classification_report(y_test, y_pred_test,
                labels=list(range(class_num)), digits=4, target_names=self.target_names)
        oa = accuracy_score(y_test, y_pred_test)
        confusion = confusion_matrix(y_test, y_pred_test)
        each_acc, aa = self.AA_andEachClassAccuracy(confusion)
        kappa = cohen_kappa_score(y_test, y_pred_test)

        self.res['classification'] = str(classification)
        self.res['oa'] = oa * 100
        self.res['confusion'] = str(confusion)
        self.res['each_acc'] = str(each_acc * 100)
        self.res['aa'] = aa * 100
        self.res['kappa'] = kappa * 100
        return self.res
class CrossAttention(nn.Module):
    def __init__(self, q_dim, kv_dim, dim, heads=8, droupout=0.1) -> None:
        '''
        Given a query sequence and a key/value sequence, use multi-head
        attention to re-encode the queries.
        The query input is (batch, seq, q_dim); the output is (batch, seq, dim).
        '''
        super().__init__()
        self.heads = heads
        # NOTE(review): scaling uses kv_dim rather than the per-head width
        # (dim // heads). Left as-is so existing checkpoints keep their numerics.
        self.scale = kv_dim ** -0.5

        self.to_q = nn.Linear(q_dim, dim, bias=True)  # dim = heads * per_dim
        self.to_k = nn.Linear(kv_dim, dim, bias=True)
        self.to_v = nn.Linear(kv_dim, dim, bias=True)

        self.nn1 = nn.Linear(dim, dim)
        self.do1 = nn.Dropout(droupout)

    def forward(self, x, y, mask=None):
        '''
        x: queries, (batch, seq1, q_dim)
        y: keys/values, (batch, seq2, kv_dim)
        returns: (batch, seq1, dim)
        '''
        h = self.heads
        b, n, _ = x.shape
        # Read the batch size from y. It was previously taken from x, which
        # made the assertion below always pass.
        by, _, _ = y.shape
        assert b == by, 'x and y must have the same batch size'

        def split_heads(t):
            # (batch, seq, heads * d) -> (batch, heads, seq, d)
            return t.reshape(t.shape[0], t.shape[1], h, -1).transpose(1, 2)

        q = split_heads(self.to_q(x))
        k = split_heads(self.to_k(y))
        v = split_heads(self.to_v(y))
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            # NOTE(review): this branch is kept verbatim. It builds a square
            # mask with an extra leading slot, so it presumably only suits
            # seq1 == seq2; the visible caller always passes mask=None.
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = dots.softmax(dim=-1)  # attention weights over the key positions

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # Merge the heads back: (batch, heads, seq1, d) -> (batch, seq1, heads * d)
        out = out.transpose(1, 2).reshape(b, n, -1)
        out = self.nn1(out)
        out = self.do1(out)  # (batch, seq1, dim)
        return out
class Transformer(nn.Module):
    """A stack of ``depth`` pre-norm encoder layers.

    Each layer is a residual self-attention block followed by a residual MLP
    block.
    """

    def __init__(self, dim, depth, heads, dim_heads, mlp_dim, dropout):
        super().__init__()
        stacked = []
        for _ in range(depth):
            # Build attention before the MLP so parameters are created in the
            # same order as before.
            attention_block = Residual(LayerNormalize(
                dim, Attention(dim, heads=heads, dim_heads=dim_heads, dropout=dropout)))
            mlp_block = Residual(LayerNormalize(
                dim, MLP_Block(dim, mlp_dim, dropout=dropout)))
            stacked.append(nn.ModuleList([attention_block, mlp_block]))
        self.layers = nn.ModuleList(stacked)

    def forward(self, x, mask=None):
        """Pass ``x`` through every layer; ``mask`` goes to the attention blocks only."""
        for layer in self.layers:
            attention_block, mlp_block = layer
            x = mlp_block(attention_block(x, mask=mask))
        return x
class HSINet(nn.Module):
    """Pixel-token transformer classifier for hyperspectral image patches.

    A 3x3 convolution turns the spectral channels of a patch into feature
    channels, every spatial position becomes one token, a learnable class
    token is prepended, and a transformer encoder processes the sequence.
    The class-token output is classified by ``classifier_mlp``.

    Args:
        params: configuration dict with a ``'data'`` section
            (``num_classes``, ``patch_size``, ``spectral_size``) and a
            ``'net'`` section (``depth``, ``heads``, ``mlp_dim``,
            ``dropout``, ``dim``). Missing keys fall back to the defaults
            written in ``__init__``.
    """

    def __init__(self, params):
        super(HSINet, self).__init__()
        self.params = params
        net_params = params['net']
        data_params = params['data']

        num_classes = data_params.get("num_classes", 16)
        patch_size = data_params.get("patch_size", 13)
        self.spectral_size = data_params.get("spectral_size", 200)

        depth = net_params.get("depth", 1)
        heads = net_params.get("heads", 8)
        mlp_dim = net_params.get("mlp_dim", 8)
        dropout = net_params.get("dropout", 0)
        conv2d_out = 64
        dim = net_params.get("dim", 64)
        dim_heads = dim
        # NOTE: the patch embedding is not applied in encoder_block, so the
        # tokens keep `conv2d_out` channels while the class token and the
        # position embedding have `dim` channels; `dim` therefore has to
        # equal `conv2d_out` for the concatenation in encoder_block to work.

        image_size = patch_size * patch_size

        # Modules are created in a fixed order so that seeded initialisation
        # and saved state dicts stay compatible.
        self.pixel_patch_embedding = nn.Linear(conv2d_out, dim)

        self.local_trans_pixel = Transformer(dim=dim, depth=depth, heads=heads, dim_heads=dim_heads, mlp_dim=mlp_dim, dropout=dropout)
        self.new_image_size = image_size
        # One position per pixel token plus one for the class token.
        self.pixel_pos_embedding = nn.Parameter(torch.randn(1, self.new_image_size+1, dim))
        # Learnable scale that keeps the position signal small at start.
        self.pixel_pos_scale = nn.Parameter(torch.ones(1) * 0.01)

        self.conv2d_features = nn.Sequential(
            nn.Conv2d(in_channels=self.spectral_size, out_channels=conv2d_out, kernel_size=(3, 3), padding=(1,1)),
            nn.BatchNorm2d(conv2d_out),
            nn.ReLU(),
            # An extra channel-compressing conv (conv2d_out -> dim) could
            # follow here; it is currently disabled.
        )

        # Kept as an attribute for state-dict compatibility; its output is
        # not used by encoder_block.
        self.senet = SE(conv2d_out, 5)

        self.cls_token_pixel = nn.Parameter(torch.randn(1, 1, dim))
        self.to_latent_pixel = nn.Identity()

        self.mlp_head =nn.Linear(dim, num_classes)
        torch.nn.init.xavier_uniform_(self.mlp_head.weight)
        torch.nn.init.normal_(self.mlp_head.bias, std=1e-6)
        self.dropout = nn.Dropout(0.1)

        linear_dim = dim * 2
        self.classifier_mlp = nn.Sequential(
            nn.Linear(dim, linear_dim),
            nn.BatchNorm1d(linear_dim),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(linear_dim, num_classes),
        )

    def encoder_block(self, x):
        '''
        Encode one batch of patches.

        x: (batch, s, w, h), s=spectral, w=width, h=height

        Returns a pair ``(logit_x, reduce_x)``: the transformer output at the
        class-token position and the mean of the transformer output over all
        tokens.
        '''
        x_pixel = self.conv2d_features(x)

        # The SE channel-attention branch is intentionally not evaluated:
        # its scale was computed but never applied to the features, so
        # running it only cost time.

        # 1. one token per spatial position: (batch, w*h, channels)
        x_pixel = rearrange(x_pixel, 'b s w h-> b (w h) s')

        # 2. prepend the class token and add the scaled position embedding
        cls_tokens_pixel = self.cls_token_pixel.expand(x_pixel.shape[0], -1, -1)
        x_pixel = torch.cat((cls_tokens_pixel, x_pixel), dim = 1) # [b, w*h+1, dim]
        x_pixel = x_pixel + self.pixel_pos_embedding[:,:] * self.pixel_pos_scale

        # 3. local transformer -> (batch, w*h+1, dim)
        x_pixel = self.local_trans_pixel(x_pixel)

        logit_x = self.to_latent_pixel(x_pixel[:,0])
        reduce_x = torch.mean(x_pixel, dim=1)

        return logit_x, reduce_x

    def forward(self, x,left=None,right=None):
        '''
        x: (batch, s, w, h), s=spectral, w=width, h=height

        ``left`` and ``right`` are optional extra inputs of the same layout.
        When both are given, their token-mean encodings are returned next to
        the class logits; otherwise those two entries are None.

        Returns ``(logits, mean_left, mean_right)``.
        '''
        logit_x, _ = self.encoder_block(x)
        mean_left, mean_right = None, None
        if left is not None and right is not None:
            _, mean_left = self.encoder_block(left)
            _, mean_right = self.encoder_block(right)

        return self.classifier_mlp(logit_x), mean_left, mean_right
/src/params/pavia_diffusion.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "data_sign": "Pavia", 4 | "data_file": "Pavia_30", 5 | "diffusion_data_path": "../data/unet3d_pavia.pkl.npy", 6 | "patch_size": 13, 7 | "batch_size": 64, 8 | "num_classes": 9, 9 | "pca": 600, 10 | "dim_heads": 64, 11 | "spectral_size":600 12 | }, 13 | "net": { 14 | "trainer": "cross_trainer", 15 | "net_type": "just_pixel", 16 | "mlp_head_dim": 64, 17 | "depth": 2, 18 | "dim": 64, 19 | "heads": 20 20 | }, 21 | "train": { 22 | "epochs": 20, 23 | "lr": 0.001, 24 | "weight_decay": 0, 25 | "temp": 20 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/params/salinas_diffusion.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "data_sign": "Salinas", 4 | "data_file": "Salinas_30", 5 | "diffusion_data_path": "../data/unet3d_salinas.pkl.npy", 6 | "patch_size": 13, 7 | "batch_size": 64, 8 | "num_classes": 16, 9 | "pca": 900, 10 | "dim_heads": 64, 11 | "spectral_size": 900 12 | }, 13 | "net": { 14 | "trainer": "cross_trainer", 15 | "net_type": "just_pixel", 16 | "mlp_head_dim": 64, 17 | "depth": 2, 18 | "dim": 64, 19 | "heads": 20 20 | }, 21 | "train": { 22 | "epochs": 50, 23 | "lr": 0.001, 24 | "weight_decay": 0, 25 | "temp": 20 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | from sklearn.decomposition import PCA 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from operator import truediv 10 | import matplotlib.pyplot as plt 11 | import 
def data_to_colormap(data):
    """Convert a 2-D label map into a flat list of RGB colours.

    Args:
        data: 2-D numpy array of class labels. Labels 0..16 have a colour;
            any other value is rendered black.

    Returns:
        Float array of shape ``(data.size, 3)`` with channels in [0, 1],
        ordered like ``data.reshape(-1)``.
    """
    assert len(data.shape)==2
    # Index in this table == class label.
    palette = (
        (0, 0, 0),
        (147, 67, 46),
        (0, 0, 255),
        (255, 100, 0),
        (0, 255, 123),
        (164, 75, 155),
        (101, 174, 255),
        (118, 254, 172),
        (60, 91, 112),
        (255, 255, 0),
        (255, 255, 125),
        (255, 0, 255),
        (100, 0, 255),
        (0, 172, 254),
        (0, 255, 0),
        (171, 175, 80),
        (101, 193, 60),
    )
    x_list = data.reshape((-1,))
    y = np.zeros((x_list.shape[0], 3))
    # One vectorised assignment per class instead of a Python-level loop over
    # every pixel; labels without a palette entry keep the zero (black) row.
    for label, rgb in enumerate(palette):
        y[x_list == label] = np.array(rgb) / 255.
    return y
def simple_run_diffusion_t_layer():
    """Sweep diffusion timesteps and feature layers for the Indian config.

    For every (timestep, layer) pair the base config is loaded once, given a
    unique run name and a matching diffusion feature file name, written to
    the exchange JSON file, and trained through ``workflow.run_one``.
    """
    config_name = 'indian_diffusion.json'
    tlist = [5, 10, 50, 100, 200]
    layers = [0, 1, 2]

    path_param = '%s/%s' % (config_path_prefix, config_name)
    with open(path_param, 'r') as fin:
        params = json.load(fin)
    for t in tlist:
        for l in layers:
            uniq_name = "%s_%s_%s" % (config_name, t, l)
            params['uniq_name'] = uniq_name
            # NOTE(review): assumes the feature files are named
            # t<timestep>_<layer>_full.pkl.npy -- confirm against the data
            # directory and the key the data loader actually reads.
            params['data']['diffusion_data_sign'] = 't%s_%s_full.pkl.npy' % (t, l)
            # Keep a copy of the exact parameters of the current run on disk.
            with open(exchange_json_file,'w') as fout:
                json.dump(params,fout)
            print("schedule %s..." % uniq_name)
            # run_one trains with the params built above; run_all would
            # ignore them and train the configs listed in workflow itself.
            workflow.run_one(params)
            print("schedule done of %s..." % uniq_name)
57 | total_loss += loss.item() 58 | epoch_avg_loss.update(loss.item(), data.shape[0]) 59 | recorder.append_index_value("epoch_loss", epoch + 1, epoch_avg_loss.get_avg()) 60 | print('[Epoch: %d] [epoch_loss: %.5f] [all_epoch_loss: %.5f] [current_batch_loss: %.5f] [batch_num: %s]' % (epoch + 1, 61 | epoch_avg_loss.get_avg(), 62 | total_loss / (epoch + 1), 63 | loss.item(), epoch_avg_loss.get_num())) 64 | # 一定epoch下进行一次eval 65 | if test_loader and (epoch+1) % 10 == 0: 66 | y_pred_test, y_test = self.test(test_loader) 67 | temp_res = self.evalator.eval(y_test, y_pred_test) 68 | recorder.append_index_value("train_oa", epoch+1, temp_res['oa']) 69 | recorder.append_index_value("train_aa", epoch+1, temp_res['aa']) 70 | recorder.append_index_value("train_kappa", epoch+1, temp_res['kappa']) 71 | print('[--TEST--] [Epoch: %d] [oa: %.5f] [aa: %.5f] [kappa: %.5f] [num: %s]' % (epoch+1, temp_res['oa'], temp_res['aa'], temp_res['kappa'], str(y_test.shape))) 72 | print('Finished Training') 73 | return True 74 | 75 | def final_eval(self, test_loader): 76 | y_pred_test, y_test = self.test(test_loader) 77 | temp_res = self.evalator.eval(y_test, y_pred_test) 78 | return temp_res 79 | 80 | def get_logits(self, output): 81 | if type(output) == tuple: 82 | return output[0] 83 | return output 84 | 85 | def test(self, test_loader): 86 | """ 87 | provide test_loader, return test result(only net output) 88 | """ 89 | count = 0 90 | self.net.eval() 91 | y_pred_test = 0 92 | y_test = 0 93 | for inputs, labels in test_loader: 94 | inputs = inputs.to(self.device) 95 | outputs = self.get_logits(self.net(inputs)) 96 | outputs = np.argmax(outputs.detach().cpu().numpy(), axis=1) 97 | if count == 0: 98 | y_pred_test = outputs 99 | y_test = labels 100 | count = 1 101 | else: 102 | y_pred_test = np.concatenate((y_pred_test, outputs)) 103 | y_test = np.concatenate((y_test, labels)) 104 | return y_pred_test, y_test 105 | 106 | 107 | class CrossTransformerTrainer(BaseTrainer): 108 | def __init__(self, 
def get_trainer(params):
    """Create the trainer selected by ``params['net']['trainer']``.

    Args:
        params: configuration dict; ``params['net']['trainer']`` names the
            trainer to build.

    Returns:
        A trainer instance constructed from ``params``.

    Raises:
        NotImplementedError: if the requested trainer name is unknown.
    """
    trainer_type = params['net']['trainer']
    if trainer_type == "cross_trainer":
        return CrossTransformerTrainer(params)

    # Fail loudly here: silently returning None would only surface later as
    # an AttributeError when the caller tries to use the trainer.
    raise NotImplementedError("Trainer not implemented! (%s)" % trainer_type)
def train_by_param(param):
    """Train and evaluate one network configuration.

    Resets the global recorder, builds the data loaders, trains the selected
    trainer, evaluates it on the test loader, times a prediction pass over
    the full loader, and stores parameters, metrics, timing and the
    reconstructed prediction matrix in the recorder.

    Args:
        param: configuration dict passed to the data loader and the trainer.

    Returns:
        The global recorder holding the results of this run.
    """
    # 0. start from a clean recorder so earlier runs cannot leak in
    recorder.reset()

    # 1. build the datasets
    provider = HSIDataLoader(param)
    loaders = provider.generate_torch_dataset()
    train_set, unlabel_set, test_set, full_set = loaders

    # 2. train, then evaluate on the held-out split
    model_trainer = get_trainer(param)
    model_trainer.train(train_set, unlabel_set, test_set)
    final_metrics = model_trainer.final_eval(test_set)

    # time only the prediction pass over the full loader
    tic = time.time()
    full_pred, _ = model_trainer.test(full_set)
    elapsed = time.time() - tic
    print("eval time is %s" % elapsed)
    recorder.record_time(elapsed)
    label_map = provider.reconstruct_pred(full_pred)

    # 3. record all information
    recorder.record_param(param)
    recorder.record_eval(final_metrics)
    recorder.record_pred(label_map)

    return recorder
def run_all():
    """Train every configuration listed in ``include_path``.

    Each config file is read from ``config_path_prefix``, trained with the
    conventional or the network pipeline depending on
    ``check_convention(file name)``, and its recorder contents are written
    below ``DEFAULT_RES_SAVE_PATH_PREFIX`` under the run's unique name.
    """
    out_dir = DEFAULT_RES_SAVE_PATH_PREFIX
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    for config_file in include_path:
        is_conventional = check_convention(config_file)
        with open('%s/%s' % (config_path_prefix, config_file), 'r') as handle:
            param = json.load(handle)

        # Fall back to the file name when the config carries no run name.
        run_name = param.get('uniq_name', config_file)
        print('start to train %s...' % run_name)
        runner = train_convention_by_param if is_conventional else train_by_param
        runner(param)
        print('model eval done of %s...' % run_name)

        recorder.to_file('%s/%s' % (out_dir, run_name))