├── .idea
│   ├── .gitignore
│   ├── 2018paper_test.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── remote-mappings.xml
├── configs
│   └── configs.py
├── data
│   ├── HSI_data.py
│   ├── __init__.py
│   ├── data_preprocess.py
│   ├── data_test.py
│   ├── get_train_test_set.py
│   └── normalizes.py
├── main.py
├── model
│   ├── LSTM.py
│   ├── MSCNN.py
│   └── SSUN.py
└── tool
    ├── assessment.py
    ├── show.py
    ├── test.py
    └── train.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 | 
--------------------------------------------------------------------------------
/configs/configs.py:
--------------------------------------------------------------------------------
1 | import scipy.io as scio
2 | import matplotlib.pyplot as plt
3 | from scipy.io import loadmat
4 | import spectral as spy
5 | 
6 | # Select the dataset to use ahead of time
7 | datasets_type = 'HSI Data Sets'
8 | 
9 | # data_root = '/home/students/master/2022/wangzy/PyCharm-Remote/datasets/PaviaU.mat'
10 | # image_name = 'paviaU'
11 | # gt_name = 'paviaU_gt'
12 | # paviaU ground truth contains 9 classes
13 | # torch([32, 103, 610, 340])
14 | 
15 | # data_root = '/home/students/master/2022/wangzy/PyCharm-Remote/datasets/Urban.mat'
16 | # image_name = 'Urban'
17 | # gt_name = 'R'
18 | # Urban ground truth contains 9 classes
19 | 
20 | data_root = '/home/students/master/2022/wangzy/PyCharm-Remote/datasets/Indian_pines.mat'
21 | image_name = 'indian_pines'
22 | gt_name = 'R'
23 | # indian_pines ground truth contains 16 classes
24 | # torch([32, 220, 145, 145])
25 | 
26 | # data_root = '/home/students/master/2022/wangzy/PyCharm-Remote/datasets/KSC.mat'
27 | # image_name = 'KSC'
28 | # gt_name = 'KSC_gt'
29 | # KSC ground truth contains 13 classes
30 | # torch([32, 176, 512, 614])
31 | 
32 | 
33 | # Remaining configuration parameters
34 | 
35 | phase = ['train', 'test', 'no_gt']
36 | pca_num = 5
37 | train_set_num = 64
38 | patch_size = 24
39 | 
40 | # Build the data dict so that all dataset settings are accessed from one place
41 | data = dict(
42 |     data_path=data_root,
43 |     image_name=image_name,
44 |     gt_name=gt_name,
45 |     train_set_num=train_set_num,
46 |     patch_size=patch_size,
47 |     pca=pca_num,
48 |     train_data=dict(
49 |         phase=phase[0]
50 |     ),
51 |     test_data=dict(
52 |         phase=phase[1]
53 |     ),
54 |     no_gt_data=dict(
55 |         phase=phase[2]
56 |     )
57 | )
58 | 
59 | # Pre-set parameters of the model
60 | 
61 | model = dict(
62 |     in_fea_num=1,
63 |     out_fea_num=9,
64 | )
65 | 
66 | # Pre-set parameters for training the model
67 | 
68 | lr = 1e-3
69 | 
70 | train = dict(
71 |     # optimizer settings
72 |     optimizer=dict(
73 |         typename='SGD',
74 |         lr=lr,
75 |         betas=(0.9, 0.999),
76 |         momentum=0.9,  # momentum
77 |         weight_decay=1e-4  # weight decay
78 |     ),
79 | 
80 |     train_model=dict(
81 |         gpu_train=True,
82 |         gpu_num=1,
83 |         workers_num=16,
84 |         epoch=500,
85 |         batch_size=64,
86 |         # learning-rate schedule parameters
87 |         lr=lr,
88 |         lr_adjust=True,
89 |         lr_gamma=0.1,
90 |         lr_step=460,
91 |         save_folder='./weights/',
92 |         save_name='model_CNN1D',
93 |         reuse_model=False,
94 |         reuse_file='./weights/model_CNN1D_Final.pth'
95 |     )
96 | )
97 | 
98 | test = dict(
99 |     batch_size=1000,
100 |     gpu_train=True,
101 |     gpu_num=1,
102 |     workers_num=16,
103 |     model_weights='./weights/model_CNN1D_Final.pth',
104 |     save_folder='./result'
105 | )
106 | 
107 | 
108 | def main():
109 |     data = loadmat(data_root)
110 |     print(data.keys())
111 | 
112 | 
113 | if __name__ == '__main__':
114 |     main()
115 | 
--------------------------------------------------------------------------------
/data/HSI_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | 
4 | 
5 | 
6 | class HSI_data(data.Dataset):
7 |     def __init__(self, data_sample, cfg):
8 |         self.phase = cfg['phase']
9 | 
10 |         # img: the padded image (pad_img)
11 |         # img_indices: the set of patch coordinates, one row per pixel
12 |         self.img = data_sample['pad_img']
13 |         self.img_indices = data_sample['pad_img_indices']
14 |         self.gt = data_sample['img_gt']
15 | 
16 |         self.pca = 'img_pca_pad' in data_sample  # True if 'img_pca_pad' is present in data_sample,
17 |         # False otherwise
18 | 
19 |         if self.pca:
20 |             self.img_pca = data_sample['img_pca_pad']
21 |             # data_indices: samples selected according to the img_gt label information
22 | 
23 |         if self.phase == 'train':
24 |             self.data_indices = data_sample['train_indices']
25 |         elif self.phase == 'test':
26 |             self.data_indices = data_sample['test_indices']
27 |         elif self.phase == 'no_gt':
28 |             self.data_indices = data_sample['no_gt_indices']
29 | 
30 |     def __len__(self):
31 |         return len(self.data_indices)
32 | 
33 |     # Fetch a sample by its position in data_indices
34 |     # Supports indices from 0 to len(self)
35 |     # data_indices=: torch.Size([270, 3])
36 |     # img_indices=: torch.Size([207400, 4])
37 |     def __getitem__(self, idx):
38 | 
39 |         index = self.data_indices[idx]
40 |         img_index = self.img_indices[index[0]]  # patch coordinates for this pixel
41 | 
42 |         # crop the sample patch from pad_img using these coordinates
43 |         img = self.img[:, img_index[0]:img_index[1], img_index[2]:img_index[3]]
44 |         label = self.gt[index[1], index[2]]
45 | 
46 |         if self.pca:
47 |             img_pca = self.img_pca[:, img_index[0]:img_index[1], img_index[2]:img_index[3]]
48 | 
49 |             # when PCA data is available, crop the corresponding patch as well
50 |             return img, label, index, img_pca
51 |         else:
52 |             return img, label, index
53 | 
54 | 
55 | # collate function that handles batches with or without the PCA branch
56 | def batch_collate(batch):
57 |     images = []
58 |     labels = []
59 |     indices = []
60 |     images_pca = []
61 | 
62 |     for sample in batch:
63 |         images.append(sample[0])
64 |         labels.append(sample[1])
65 |         indices.append(sample[2])
66 | 
67 |         if len(sample) > 3:
68 |             images_pca.append(sample[3])
69 | 
70 |     # torch.stack joins the per-sample tensors along a new batch dimension
71 |     if len(images_pca) > 0:
72 |         return torch.stack(images, 0), torch.stack(labels), \
73 |                torch.stack(indices), torch.stack(images_pca, 0)
74 |     else:
75 |         return torch.stack(images, 0), 
torch.stack(labels), \ 76 | torch.stack(indices) 77 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wzysaber/2018paper_code/ade9a5414ba3df1618dd42136a77763bb2d1f305/data/__init__.py -------------------------------------------------------------------------------- /data/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | 5 | 6 | # 将103通道的图降维成为3通道的图 7 | def extract_pc(image, pc=3): 8 | ''' 9 | :function:123 10 | :param image: 11 | :param pc: 12 | :return: 13 | ''' 14 | channel, height, width = image.shape 15 | data = image.contiguous().reshape(channel, height * width) # 存在contiguous函数,在改变data的值后 16 | data_c = data - data.mean(dim=1).unsqueeze(1) 17 | # 计算一个矩阵或一批矩阵 input 的奇异值分解 18 | u, s, v = torch.svd(data_c.matmul(data_c.T)) # data_c矩阵乘以data_c的转置 19 | sorted_data, indices = s.sort(descending=True) # 将s中的数按降序进行排列 20 | image_pc = u[:, indices[0:pc]].T.matmul(data) 21 | return image_pc.reshape(pc, height, width) 22 | 23 | 24 | # (x - mean(x))/std(x) normalize to mean: 0, std: 1 25 | # 将数据进行归一化处理 26 | 27 | def std_norm(image): 28 | image = image.permute(1, 2, 0).numpy() 29 | trans = transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize(torch.tensor(image).mean(dim=[0, 1]), torch.tensor(image).std(dim=[0, 1])) 32 | ]) 33 | 34 | return trans(image) 35 | 36 | 37 | # (x - min(x))/(max(x) - min(x)) normalize to (0, 1) for each channel 38 | def one_zero_norm(image): 39 | channel, height, width = image.shape 40 | data = image.view(channel, height * width) 41 | data_max = data.max(dim=1)[0] 42 | data_min = data.min(dim=1)[0] 43 | 44 | data = (data - data_min.unsqueeze(1)) / (data_max.unsqueeze(1) - data_min.unsqueeze(1)) # 在第二个维度上插入一个维度 45 | 46 | return data.view(channel, height, width) 47 | 48 | 49 | # input tensor image size with CxHxW 50 | # -1 + 2 * (x - min(x))/(max(x) - min(x)) normalize to (-1, 1) for each channel 51 | # 同样对数据进行归一化处理,范围在(-1,1) 52 | def pos_neg_norm(image): 53 | channel, height, width = image.shape 54 | data = image.view(channel, height * width) 55 | data_max = data.max(dim=1)[0] 56 | data_min = data.min(dim=1)[0] 57 | 58 | data = -1 + 2 * (data - data_min.unsqueeze(1)) / (data_max.unsqueeze(1) - data_min.unsqueeze(1)) 59 | 60 | return data.view(channel, height, width) 61 | 62 | 63 | # function:construct sample,切分得到patch,储存每个patch的坐标值 64 | # input: image:torch.size(103, 610, 340) 65 | # window_size:27 66 | # output:pad_image, batch_image_indices 67 | def construct_sample(image, window_size=27): 68 | # 先输入照片的通道数等数据的指标 69 | channel, height, width = image.shape 70 | 71 | half_window = int(window_size // 2) # 13 72 | # 使用输入边界的复制值来填充 73 | pad = nn.ReplicationPad2d(half_window) # 上下左右伸展13单位值,就是26 74 | # uses (padding_left, padding_right,padding_top, padding_bottom) 75 | 76 | pad_image = pad(image.unsqueeze(0)).squeeze(0) # torch.Size([103, 636, 366]) 77 | 78 | # 用数组存储切分得到的patch的坐标 79 | # torch.Size([207400, 4]) 80 | batch_image_indices = torch.zeros((height * width, 4), dtype=torch.long) 81 | 82 | t = 0 83 | for h in range(height): 84 | for w in range(width): 85 | batch_image_indices[t, :] = torch.tensor([h, h + window_size, w, w + window_size]) 86 | t += 1 87 | 88 | return pad_image, batch_image_indices 89 | 90 | 91 | # 这里的gt是相应的标签值,背景图是0,相应的标志物为1到9 92 | # 
将gt的标签进行变化,使背景为-1,相应的标志物为0到9 93 | def label_transform(gt): 94 | ''' 95 | function:tensor label to 0-n for training 96 | input: gt 97 | output:gt 98 | tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) 99 | -> tensor([-1., 0., 1., 2., 3., 4., 5., 6., 7., 8.]) 100 | ''' 101 | label = torch.unique(gt) # 返回标签label中的不同值 102 | gt_new = torch.zeros_like(gt) # zeros_like(a)的目的是构建一个与a同维度的数组,并初始化所有变量为零。 103 | # zeros,则需要代入参数 104 | 105 | for each in range(len(label)): # 长度为10 106 | indices = torch.where(gt == label[each]) 107 | 108 | if label[0] == 0: 109 | gt_new[indices] = each - 1 110 | else: 111 | gt_new[indices] = each 112 | 113 | # tensor([-1., 0., 1., 2., 3., 4., 5., 6., 7., 8.]) 114 | # labeL_new = torch.unique(gt_new) 115 | 116 | return gt_new 117 | 118 | 119 | # 将标签值进行还原 120 | def label_inverse_transform(predict_result, gt): 121 | label_origin = torch.unique(gt) 122 | label_predict = torch.unique(predict_result) 123 | 124 | predict_result_origin = torch.zeros_like(predict_result) 125 | for each in range(len(label_predict)): 126 | indices = torch.where(predict_result == label_predict[each]) # 此时他会返回等式相等时的坐标 127 | if len(label_predict) != len(label_origin): 128 | predict_result_origin[indices] = label_origin[each + 1] 129 | else: 130 | predict_result_origin[indices] = label_origin[each] 131 | 132 | return predict_result_origin 133 | 134 | 135 | def select_sample(gt, ntr): 136 | ''' 137 | function: 用img_gt的标签信息划分样本 138 | input: gt -> torch.Size(610, 340); ntr -> train_set_num 30 139 | output:data_sample = {'train_indices': train_indices, 'train_num': train_num, 140 | 'test_indices': test_indices, 'test_num': test_num, 141 | 'no_gt_indices': no_gt_indices, 'no_gt_num': no_gt_num.unsqueeze(0) } 142 | ''' 143 | gt_vector = gt.reshape(-1, 1).squeeze(1) # 使用reshape函数来对其进行重组, reshape(1,-1)转化成1行 144 | # torch.Size([207400]) 145 | 146 | label = torch.unique(gt) 147 | 148 | first_time = True 149 | 150 | for each in range(len(label)): # each 0~9 151 | indices_vector = torch.where(gt_vector == label[each]) # 返回1位的索引,也就标签值的具体位置 152 | # 将相应的标签进行遍历 153 | indices = torch.where(gt == label[each]) # 返回2维的索引,比如gt中-1的二维坐标 154 | 155 | # print(indices) 156 | indices_vector = indices_vector[0] 157 | indices_row = indices[0] 158 | indices_column = indices[1] 159 | 160 | # 相应的背景值为 -1 161 | if label[each] == -1: 162 | no_gt_indices = torch.cat([ 163 | indices_vector.unsqueeze(1), 164 | indices_row.unsqueeze(1), 165 | indices_column.unsqueeze(1)], 166 | dim=1 167 | ) 168 | no_gt_num = torch.tensor(len(indices_vector)) 169 | 170 | # 其他标签 0-8 171 | else: 172 | class_num = torch.tensor(len(indices_vector)) 173 | # each循环得到class_num:6631->18649->2099->3064->1345->5029->1330->3682->947 174 | # 在不同标签下得到的长度值 175 | 176 | # 得到选择的数量 ntr = train_set_num 30 177 | # if ntr < 1: # 表现为百分数 178 | # ntr0 = int(ntr * class_num) 179 | # else: 180 | # ntr0 = ntr 181 | # # 最小值也得选10 182 | # if ntr0 < 10: 183 | # sel_num = 10 184 | # elif ntr0 > class_num // 2: 185 | # sel_num = class_num // 2 186 | # else: 187 | # sel_num = ntr0 188 | 189 | train_num_array = [30, 150, 150, 100, 150, 150, 20, 150, 15, 150, 150, 150, 150, 150, 50, 50] 190 | sel_num = train_num_array[each-1] 191 | 192 | sel_num = torch.tensor(sel_num) # tensor(30) 193 | 194 | # 将标签进行打乱 195 | rand_indices0 = torch.randperm(class_num) # torch.randperm 给定参数n,返回一个从0到n-1的随机整数排列 196 | rand_indices = indices_vector[rand_indices0] 197 | 198 | # 划分训练集train,测试集test 199 | # 划分打乱后的随机整数排列 200 | tr_ind0 = rand_indices0[0:sel_num] # torch.Size([30]) 201 | te_ind0 = rand_indices0[sel_num:] # 
将剩下的数据用作测试集 202 | 203 | # 划分随机整数排列对应的gt 204 | tr_ind = rand_indices[0:sel_num] # torch.Size([30]) 205 | te_ind = rand_indices[sel_num:] 206 | 207 | # 训练集train: 索引+坐标 208 | sel_tr_ind = torch.cat([ 209 | tr_ind.unsqueeze(1), 210 | indices_row[tr_ind0].unsqueeze(1), 211 | indices_column[tr_ind0].unsqueeze(1)], 212 | dim=1 213 | ) # torch.Size([30, 3]) 214 | 215 | # 测试集test 216 | sel_te_ind = torch.cat([ 217 | te_ind.unsqueeze(1), 218 | indices_row[te_ind0].unsqueeze(1), 219 | indices_column[te_ind0].unsqueeze(1)], 220 | dim=1 221 | ) # torch.Size([6601, 3]) 222 | 223 | if first_time: 224 | first_time = False 225 | 226 | train_indices = sel_tr_ind 227 | train_num = sel_num.unsqueeze(0) 228 | 229 | test_indices = sel_te_ind 230 | test_num = (class_num - sel_num).unsqueeze(0) 231 | 232 | else: 233 | train_indices = torch.cat([train_indices, sel_tr_ind], dim=0) 234 | train_num = torch.cat([train_num, sel_num.unsqueeze(0)]) 235 | 236 | test_indices = torch.cat([test_indices, sel_te_ind], dim=0) 237 | test_num = torch.cat([test_num, (class_num - sel_num).unsqueeze(0)]) 238 | 239 | # 训练集 240 | rand_tr_ind = torch.randperm(train_num.sum()) 241 | train_indices = train_indices[rand_tr_ind,] 242 | # 测试集 243 | rand_te_ind = torch.randperm(test_num.sum()) # torch.Size([42506]) 244 | test_indices = test_indices[rand_te_ind,] # torch.Size([42506, 3]) 245 | # 背景 246 | rand_no_gt_ind = torch.randperm(no_gt_num.sum()) # torch.Size([164624]) 247 | no_gt_indices = no_gt_indices[rand_no_gt_ind,] # torch.Size([164624, 3]) 248 | 249 | # 将6种数据参数进行保存 250 | data_sample = {'train_indices': train_indices, 'train_num': train_num, 251 | 'test_indices': test_indices, 'test_num': test_num, 252 | 'no_gt_indices': no_gt_indices, 'no_gt_num': no_gt_num.unsqueeze(0) 253 | } 254 | 255 | return data_sample 256 | 257 | 258 | 259 | -------------------------------------------------------------------------------- /data/data_test.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import numpy as np 3 | import configs.configs as cfg 4 | 5 | def normalize(x, k): 6 | if k == 1: 7 | mu = np.mean(x, 0) 8 | x_norm = x - mu 9 | sigma = np.std(x_norm, 0) 10 | x_norm = x_norm / sigma 11 | return x_norm 12 | elif type == 2: 13 | minx = np.min(x, 0) 14 | maxx = np.max(x, 0) 15 | x_norm = x - minx 16 | x_norm = x_norm / (maxx - minx) 17 | return x_norm 18 | 19 | 20 | def HyperFunctions(timestep=4, s1s2=2): 21 | cfg_data = cfg.data 22 | data_path = cfg_data['data_path'] 23 | data = sio.loadmat(data_path) 24 | x = data['indian_pines'] 25 | y = data['R'] 26 | 27 | train_num_array = [30, 150, 150, 100, 150, 150, 20, 150, 15, 150, 150, 150, 150, 150, 50, 50] 28 | train_num_array = np.array(train_num_array).astype('int') 29 | [row, col, n_feature] = x.shape 30 | x = x.reshape(row * col, n_feature) 31 | y = y.reshape(row * col, 1) 32 | # 16 33 | n_class = y.max() 34 | # 55 35 | nb_features = int(n_feature / timestep) 36 | # 1765 37 | train_num_all = sum(train_num_array) 38 | # (21025, 4, 55) 39 | x = normalize(x, 1) 40 | 41 | x_reshape = np.zeros((x.shape[0], timestep, nb_features)) 42 | if s1s2 == 2: 43 | 44 | for j in range(0, timestep): 45 | x_reshape[:, j, :] = x[:, j:j + (nb_features - 1) * timestep + 1:timestep] 46 | else: 47 | for j in range(0, timestep): 48 | x_reshape[:, j, :] = x[:, j * nb_features:(j + 1) * nb_features] 49 | 50 | x_data_all = x_reshape 51 | 52 | randomarray = list() 53 | 54 | for i in range(1, n_class + 1): 55 | index = np.where(y == i)[0] 56 | n_data = 
index.shape[0] 57 | randomarray.append(np.random.permutation(n_data)) 58 | 59 | flag1 = 0 60 | flag2 = 0 61 | # (1765, 4, 55) 62 | x_train = np.zeros((train_num_all, timestep, nb_features)) 63 | 64 | x_test = np.zeros((sum(y > 0)[0] - train_num_all, timestep, nb_features)) 65 | 66 | for i in range(1, n_class + 1): 67 | index = np.where(y == i)[0] 68 | # 46 69 | n_data = index.shape[0] 70 | # 30 71 | train_num = train_num_array[i - 1] 72 | randomx = randomarray[i - 1] 73 | if s1s2 == 2: 74 | 75 | for j in range(0, timestep): 76 | x_train[flag1:flag1 + train_num, j, :] = x[index[randomx[0:train_num]], 77 | j:j + (nb_features - 1) * timestep + 1:timestep] 78 | x_test[flag2:flag2 + n_data - train_num, j, :] = x[index[randomx[train_num:n_data]], 79 | j:j + (nb_features - 1) * timestep + 1:timestep] 80 | 81 | else: 82 | for j in range(0, timestep): 83 | x_train[flag1:flag1 + train_num, j, :] = x[index[randomx[0:train_num]], 84 | j * nb_features:(j + 1) * nb_features] 85 | x_test[flag2:flag2 + n_data - train_num, j, :] = x[index[randomx[train_num:n_data]], 86 | j * nb_features:(j + 1) * nb_features] 87 | 88 | flag1 = flag1 + train_num 89 | flag2 = flag2 + n_data - train_num 90 | # (1765, 4, 55) (8484, 4, 55) 91 | return x_data_all.astype('float32') 92 | 93 | 94 | def main(): 95 | out1 = HyperFunctions() 96 | print(out1.shape) 97 | 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /data/get_train_test_set.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.io as io 3 | import data_preprocess as pre_fun 4 | 5 | 6 | def get_train_test_set(cfg): 7 | ''' 8 | function: (1)划分数据集train,test 9 | (2)加载数据集,转化为tensor,label transform, 10 | (3)切分patch,储存每个patch的坐标值, 11 | (4)由gt划分样本,最终得到data_sample 12 | 13 | input: cfg,也就是在config中编辑相应的数据 14 | 15 | output:输出由gt进行划分的data_sample 16 | 17 | # dict_keys(['train_indices', 'train_num', 'test_indices', 'test_num', 18 | # 'no_gt_indices', 'no_gt_num', 'pad_img', 'pad_img_indices', 'img_gt', 'ori_gt']) 19 | ''' 20 | 21 | # 从cfg中导入设定好的参数 22 | data_path = cfg['data_path'] # 导入存放的地址 23 | image_name = cfg['image_name'] # paviaU 24 | # 这个其实就是103通道的整个图像的信息 25 | 26 | gt_name = cfg['gt_name'] # 'paviaU_gt' 27 | # 这个是地图的特征信息,同时具有标签值 28 | # [0,0,1,1,1,4,4,4,]这种的标签图 29 | 30 | train_set_num = cfg['train_set_num'] # 30,每一次数据集训练的次数 31 | patch_size = cfg['patch_size'] # 27,用于切分图像的尺寸 32 | 33 | # 加载数据高光谱的数据集 34 | # 加载的数据变化形式,先进行.astype('float32')再进行.from_numpy(img),就转化为torch.Size的格式 35 | data = io.loadmat(data_path) # 从相应的文件夹导入 36 | 37 | img = data[image_name].astype('float32') # .astype转换数组的数据类型 (610, 340, 103) [w,h,c] 38 | gt = data[gt_name].astype('float32') # 转换成float32 (610, 340) ,这个数据从数据库中导入,只有一个数据 39 | 40 | img = torch.from_numpy(img) # 转tensor # torch.Size(610, 340, 103) 41 | gt = torch.from_numpy(gt) # torch.Size(610, 340) 42 | 43 | img = img.permute(2, 0, 1) # 变换tensor的维度,把channel放到第一维CxHxW # torch.Size(103, 610, 340) 44 | img = pre_fun.std_norm(img) # 归一化,torch.Size(103, 610, 340) 将数据分布在(0,1)之间 45 | 46 | # label transform 0~9 -> -1~8 47 | # 将标签值进行转换,应该在mat文件中,对不同的物体的label值就做好了定义 48 | img_gt = pre_fun.label_transform(gt) # torch.size(610, 340) 49 | 50 | # construct_sample:切分patch,储存每个patch的坐标值 51 | # img_pad的值为([103, 636, 366]), 52 | # img_pad_indices的值为([207400, 4]) 53 | img_pad, img_pad_indices = pre_fun.construct_sample(img, patch_size) 54 | 55 | # (1)select_sample:用img_gt的标签信息划分样本 56 | # (2)得到的data_sample = 
{'train_indices': train_indices, 'train_num': train_num, 57 | # 'test_indices': test_indices, 'test_num': test_num, 58 | # 'no_gt_indices': no_gt_indices, 'no_gt_num': no_gt_num.unsqueeze(0) 59 | # } 60 | data_sample = pre_fun.select_sample(img_gt, train_set_num) 61 | 62 | # data_sample再添加几项数据 63 | data_sample['pad_img'] = img_pad 64 | data_sample['pad_img_indices'] = img_pad_indices 65 | data_sample['img_gt'] = img_gt # 转化后的特征标签图的数据 66 | data_sample['ori_gt'] = gt # 原始特征标签图的数据 67 | 68 | # print('data_sample.keys()',data_sample.keys()) 69 | # dict_keys(['train_indices', 'train_num', 'test_indices', 'test_num', 70 | # 'no_gt_indices', 'no_gt_num', 'pad_img', 'pad_img_indices', 'img_gt', 'ori_gt']) 71 | 72 | # 在预处理中cfg['pca']=1,在我的理解里就是是否执行的标志位 73 | # 属于是buff叠满了 74 | # 将图像进行扩展,归0化和,标准差 75 | if cfg['pca'] > 0: 76 | img_pca = pre_fun.extract_pc(img, cfg['pca']) 77 | img_pca = pre_fun.one_zero_norm(img_pca) 78 | img_pca = pre_fun.std_norm(img_pca) 79 | 80 | img_pca_pad, _ = pre_fun.construct_sample(img_pca, patch_size) 81 | 82 | data_sample['img_pca_pad'] = img_pca_pad 83 | 84 | return data_sample 85 | -------------------------------------------------------------------------------- /data/normalizes.py: -------------------------------------------------------------------------------- 1 | # 纯数学操作来进行相应的变化 2 | 3 | import numpy as np 4 | 5 | 6 | def fun_norm(X, M='AllNorm'): 7 | if X.dtype != float: 8 | X = X.astype(np.float32) 9 | 10 | if type(M) != str: #如果不是字符型还要将其转化 11 | M = str(M) 12 | 13 | if M.lower() == 'allnorm': #字符串中的大写字母转化为小写字母 14 | X_min_value = X.min() 15 | X_max_value = X.max() 16 | if X_min_value == X_max_value: 17 | return X 18 | else: 19 | X0 = X - X_min_value 20 | Xd = X_max_value - X_min_value + 0.001 21 | 22 | return 0.1 + (1 - 0.9) * (X0 / Xd) 23 | 24 | elif M.lower() == 'rownorm': 25 | X_min_row = X.min(axis=1) #axis =0 沿行方向, axis =1 沿列的方向 26 | X_max_row = X.max(axis=1) 27 | X0 = X - np.tile(X_min_row, (X.shape[1], 1)).T # np.tile按照各个方向复制 28 | Xd = X_max_row - X_min_row + 0.001 29 | return 0.1 + (1 - 0.9) * X0 * np.tile(1./Xd, (X.shape[1], 1)).T 30 | 31 | elif M.lower() == 'colnorm': 32 | X_min_col = X.min(axis=0) 33 | X_max_col = X.max(axis=0) 34 | X0 = X - np.tile(X_min_col, (X.shape[0], 1)).T 35 | Xd = X_max_col - X_min_col + 0.001 36 | return 0.1 + (1 - 0.9) * X0 * np.tile(1. / Xd, (X.shape[0], 1)).T 37 | 38 | elif M.lower() == '2norm': 39 | X_2norm = np.linalg.norm(X, axis=1) # 求行向量的范数 40 | return X * (1. / np.tile(X_2norm, (X.shape[1], 1))).T 41 | 42 | elif M.lower() == 'stdnorm': 43 | X_mean = X.mean(axis=0) 44 | X0 = X - np.tile(X_mean, (X.shape[0], 1)) 45 | X_std = X.std(axis=0) + 0.001 # axis=0计算每一列的标准差 46 | return X * (1. 
/ np.tile(X_std, (X.shape[0], 1)))
47 | 
48 |     else:
49 |         return X
50 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | import os
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # select the physical GPU to use
6 | 
7 | import scipy.io as io
8 | import imageio
9 | 
10 | import configs.configs as cfg
11 | import torch.optim as optim
12 | 
13 | from data.HSI_data import HSI_data as fun_data
14 | from data.get_train_test_set import get_train_test_set as fun_get_set
15 | 
16 | from model.MSCNN import MSCNN
17 | from model.SSUN import SSUN
18 | from model.LSTM import lstm
19 | 
20 | from tool.train import train as fun_train
21 | from tool.test import test as fun_test
22 | from matplotlib import pyplot as plt
23 | 
24 | from tool import show  # show.py lives in the tool package
25 | import warnings
26 | 
27 | warnings.filterwarnings("ignore")
28 | 
29 | 
30 | def main():
31 |     # (1) basic parameters
32 |     cfg_data = cfg.data
33 |     cfg_model = cfg.model
34 |     cfg_train = cfg.train['train_model']
35 |     cfg_optim = cfg.train['optimizer']  # optimizer settings
36 |     cfg_test = cfg.test
37 | 
38 |     # (2) load the data
39 |     data_sets = fun_get_set(cfg_data)
40 | 
41 |     train_data = fun_data(data_sets, cfg_data['train_data'])
42 |     test_data = fun_data(data_sets, cfg_data['test_data'])
43 |     no_gt_data = fun_data(data_sets, cfg_data['no_gt_data'])
44 | 
45 |     # (3) training setup
46 |     device = torch.device("cuda:0")  # with CUDA_VISIBLE_DEVICES="2" set above, the single visible GPU is re-indexed as cuda:0
47 | 
48 |     # build the model
49 |     model = SSUN().to(device)
50 | 
51 |     # loss function
52 |     loss_fun = nn.CrossEntropyLoss()
53 | 
54 |     # optimizer
55 |     # optimizer = optim.SGD(model.parameters(), lr=cfg_optim['lr'],
56 |     #                       momentum=cfg_optim['momentum'], weight_decay=cfg_optim['weight_decay'])
57 |     # optimizer = optim.RMSprop(model.parameters(), lr=cfg_optim['lr'],
58 |     #                           momentum=cfg_optim['momentum'], weight_decay=cfg_optim['weight_decay'])
59 |     optimizer = optim.Adam(params=model.parameters(), lr=cfg_optim['lr'],
60 |                            betas=cfg_optim['betas'], eps=1e-8, weight_decay=cfg_optim['weight_decay'])
61 | 
62 |     # train
63 |     fun_train(train_data, model, loss_fun, optimizer, device, cfg_train)
64 | 
65 |     # test
66 |     pred_train_label = fun_test(train_data, data_sets['ori_gt'], model, device, cfg_test)
67 |     pred_test_label = fun_test(test_data, data_sets['ori_gt'], model, device, cfg_test)
68 |     pred_no_gt_label = fun_test(no_gt_data, data_sets['ori_gt'], model, device, cfg_test)
69 | 
70 |     predict_label = torch.cat([pred_train_label, pred_test_label], dim=0)
71 | 
72 |     # display the predicted label map
73 |     HSI = show.Predict_Label2Img(predict_label)
74 |     plt.imshow(HSI)
75 |     plt.show()
76 | 
77 |     save_folder = cfg_test['save_folder']
78 |     if not os.path.exists(save_folder):
79 |         os.mkdir(save_folder)  # create the output directory
80 | 
81 |     io.savemat(save_folder + '/classification_label.mat', {'predict_label_CNN1D': predict_label})  # save the predictions to a .mat file
82 | 
83 | 
84 | if __name__ == '__main__':
85 |     main()
86 | 
--------------------------------------------------------------------------------
/model/LSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | # LSTM branch: treats the centre pixel's spectrum as a sequence and models it with an LSTM
7 | 
8 | 
9 | class lstm(nn.Module):
10 |     def __init__(self, band_num=4, chose_model=1):
11 |         super(lstm, self).__init__()
12 |         self.band = band_num
13 |         self.chose_model = chose_model
14 |         self.lstm_model = nn.LSTM(
15 |             input_size=55,
16 |             hidden_size=128,
17 |             num_layers=1,
18 |             batch_first=True
19 |         )
20 |         # h0 = torch.randn(2, 3, 6)
21 |         # c0 = 
torch.randn(2, 3, 6) 在调用 22 | self.outlayer = nn.Sequential( 23 | nn.Linear(128, 50), 24 | nn.Linear(50, 16) 25 | ) 26 | 27 | def forward(self, x): 28 | b, c, h_size, w_size = x.shape 29 | input = x[:, :, h_size // 2, w_size // 2] 30 | input = input.reshape(b, c) 31 | 32 | nb_features = int(c // self.band) 33 | input_reshape = torch.zeros((x.shape[0], self.band, nb_features)).type_as(x) 34 | 35 | if self.chose_model == 1: 36 | for j in range(0, self.band): 37 | input_reshape[:, j, :] = input[:, j:j + (nb_features - 1) * self.band + 1:self.band] 38 | else: 39 | for j in range(0, self.band): 40 | input_reshape[:, j, :] = input[:, j * nb_features:(j + 1) * nb_features] 41 | 42 | out, (h0, c0) = self.lstm_model(input_reshape) 43 | out = out[:, -1, :] # torch.Size([32, 128]) 44 | 45 | out_org = out 46 | out_finnal = self.outlayer(out) 47 | return out_org, out_finnal 48 | 49 | 50 | def main(): 51 | net = lstm(4, 1) 52 | tmp = torch.randn(128, 220, 24, 24) 53 | 54 | out = net(tmp)[1] 55 | print(out.shape) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /model/MSCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # MSCNN,将patch的图片进行相应的带入,经过卷积池化来进行相应的各个部分的全连接 7 | # 图片是经过了PCA处理降维了,只具有5个通道数 8 | 9 | class MSCNN(nn.Module): 10 | def __init__(self): 11 | super(MSCNN, self).__init__() 12 | self.conv1 = nn.Conv2d(in_channels=5, out_channels=32, kernel_size=3, stride=1, padding=1) 13 | self.pool1 = nn.MaxPool2d(kernel_size=2) 14 | self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1) 15 | self.pool2 = nn.MaxPool2d(kernel_size=2) 16 | self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1) 17 | self.pool3 = nn.MaxPool2d(kernel_size=2) 18 | self.outlayer = nn.Linear(6048, 16) 19 | 20 | def forward(self, x): 21 | out1 = self.conv1(x) 22 | out1 = self.pool1(out1) 23 | 24 | out2 = self.conv2(out1) 25 | out2 = self.pool2(out2) 26 | 27 | out3 = self.conv3(out2) 28 | out3 = self.pool3(out3) 29 | 30 | batchsz1 = out1.size(0) 31 | batchsz2 = out2.size(0) 32 | batchsz3 = out3.size(0) 33 | out1 = out1.view(batchsz1, -1) 34 | out2 = out2.view(batchsz2, -1) 35 | out3 = out3.view(batchsz3, -1) 36 | 37 | out = torch.cat([ 38 | out1, out2, out3 39 | ], dim=1) 40 | 41 | # print(out1.shape, out2.shape, out3.shape) 42 | 43 | out_org = out 44 | out_finnal = self.outlayer(out) 45 | 46 | return out_org, out_finnal 47 | 48 | 49 | def main(): 50 | net = MSCNN() 51 | tmp = torch.randn(64, 5, 24, 24) 52 | out1 = net(tmp)[1] 53 | 54 | print(out1.shape) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /model/SSUN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model.LSTM import lstm 5 | from model.MSCNN import MSCNN 6 | 7 | 8 | class SSUN(nn.Module): 9 | def __init__(self): 10 | super(SSUN, self).__init__() 11 | self.spatial = lstm() 12 | self.spectal = MSCNN() 13 | 14 | self.outlayer = nn.Linear(6176, 16) 15 | 16 | def forward(self, img, img_gt): 17 | out_spatial = self.spatial(img)[0] 18 | out_spectal = self.spectal(img_gt)[0] 19 | 20 | out = torch.cat([ 21 | out_spatial, out_spectal 22 | ], dim=1) 23 | 24 | out1 = 
self.outlayer(out) 25 | out2 = self.spatial(img)[1] 26 | out3 = self.spectal(img_gt)[1] 27 | 28 | return out1, out2, out3 29 | 30 | 31 | def main(): 32 | net = SSUN() 33 | img = torch.randn(64, 220, 24, 24) 34 | img_gt = torch.randn(64, 5, 24, 24) 35 | 36 | out = net(img, img_gt)[1] 37 | print(out.shape) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /tool/assessment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio 3 | from sklearn.metrics import roc_curve 4 | from sklearn.metrics import confusion_matrix 5 | from sklearn.metrics import cohen_kappa_score 6 | from sklearn.metrics import accuracy_score 7 | 8 | 9 | def accuracy_assessment(img_gt, changed_map): 10 | """ 11 | assess accuracy of changed map based on ground truth 12 | """ 13 | 14 | cm = [] 15 | gt = [] 16 | TP, TN, FP, FN = 0, 0, 0, 0 17 | esp = 1e-6 18 | 19 | height, width = changed_map.shape 20 | changed_map_ = np.reshape(changed_map, (-1,)) 21 | img_gt_ = np.reshape(img_gt, (-1,)) 22 | 23 | cm = np.ones((height * width,)) 24 | cm[changed_map_ == 1] = 2 25 | cm[changed_map_ == 0] = 1 26 | 27 | gt = np.zeros((height * width,)) 28 | gt[img_gt_ == 1] = 2 29 | gt[img_gt_ == 0] = 1 30 | 31 | # scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口 32 | conf_mat = confusion_matrix(y_true=gt, y_pred=cm, labels=[1, 2]) 33 | kappa_co = cohen_kappa_score(y1=gt, y2=cm, labels=[1, 2]) 34 | 35 | oa = np.sum(conf_mat.diagonal()) / np.sum(conf_mat) 36 | 37 | return conf_mat, oa, kappa_co 38 | -------------------------------------------------------------------------------- /tool/show.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def Predict_Label2Img(predict_label): 6 | # predict_label torch.Size([207400, 4]) 7 | # predict_img (610, 340) 8 | # indian_pines=([145, 145]) 21025 9 | # KSC = ([512, 614]) 10 | predict_img = torch.zeros([145, 145]) 11 | num = predict_label.shape[0] # 207400 12 | 13 | for i in range(num): 14 | x = int(predict_label[i][1]) 15 | y = int(predict_label[i][2]) 16 | l = predict_label[i][3] 17 | predict_img[x][y] = l 18 | 19 | return predict_img 20 | 21 | 22 | if __name__ == '__main__': 23 | predict_label = torch.ones([21025, 4]) 24 | predict_img = Predict_Label2Img(predict_label) 25 | print('over') 26 | -------------------------------------------------------------------------------- /tool/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from data.HSI_data import batch_collate as collate_fn 4 | from torch.utils.data import DataLoader 5 | 6 | import data.data_preprocess as pre_fun 7 | 8 | 9 | def check_keys(model, pretrained_state_dict): 10 | ckpt_keys = set(pretrained_state_dict.keys()) # set() 函数创建一个无序不重复元素集,可进行关系测试,删除重复数据 11 | model_keys = set(model.state_dict().keys()) # state_dict其实就是一个字典,该自点中包含了模型各层和其参数tensor的对应关系。 12 | used_pretrained_keys = model_keys & ckpt_keys 13 | unused_pretrained_keys = ckpt_keys - model_keys 14 | missing_keys = model_keys - ckpt_keys 15 | 16 | print('Missing keys:{}'.format(len(missing_keys))) 17 | print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) 18 | print('Used keys:{}'.format(len(used_pretrained_keys))) 19 | 20 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' # 警告位,当有用值小于0则发出警告 21 | 22 | return True 23 | 24 | 25 
| def remove_prefix(state_dict, prefix): 26 | print('remove prefix \'{}\''.format(prefix)) 27 | 28 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x # 直接赋给一个变量,然后再像一般函数那样调用 29 | # eg: f(x) 30 | return {f(key): value for key, value in state_dict.items()} 31 | 32 | 33 | def load_model(model, pretrained_path, load_to_cpu): 34 | print('loading pretrained model from {}'.format(pretrained_path)) 35 | 36 | if load_to_cpu == torch.device('cpu'): 37 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)['model'] 38 | else: 39 | device = torch.cuda.current_device() 40 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))['model'] 41 | 42 | if "state_dict" in pretrained_dict.keys(): 43 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') 44 | else: 45 | pretrained_dict = remove_prefix(pretrained_dict, 'module') 46 | 47 | check_keys(model, pretrained_dict) 48 | model.load_state_dict(pretrained_dict, strict=False) 49 | 50 | return model 51 | 52 | 53 | def test(test_data, origin_gt, model, device, cfg): 54 | num_workers = cfg['workers_num'] 55 | gpu_num = cfg['gpu_num'] 56 | 57 | batch_size = cfg['batch_size'] 58 | 59 | model = load_model(model, cfg['model_weights'], device) # 加载模型,文件的格式是pth 60 | model.eval() 61 | model = model.to(device) 62 | 63 | # gpu_num 环境 64 | if gpu_num > 1 and cfg['gpu_train']: 65 | model = torch.nn.DataParallel(model).to(device) # # 采用多卡GPU服务器 66 | else: 67 | model = model.to(device) 68 | 69 | batch_data = DataLoader(test_data, batch_size, shuffle=True, num_workers=num_workers, 70 | collate_fn=collate_fn, pin_memory=True) 71 | 72 | # 数据预定义 73 | predict_correct = 0 74 | label_num = 0 75 | predict_label = [] 76 | 77 | for batch_idx, batch_sample in enumerate(batch_data): 78 | 79 | if len(batch_sample) > 3: 80 | img, target, indices, img_pca = batch_sample 81 | img_pca = img_pca.to(device) 82 | else: 83 | img, target, indices = batch_sample 84 | 85 | img = img.to(device) 86 | 87 | # 在该模块下,所有计算得出的tensor的requires_grad都自动设置为False 88 | # 反向传播时不会自动求导了,大大节约了显存 89 | 90 | with torch.no_grad(): 91 | prediction = model(img, img_pca)[0] 92 | 93 | label = prediction.softmax(dim=-1).cpu().argmax(dim=1, keepdim=True) 94 | 95 | if target.sum() > 0: 96 | predict_correct += label.eq(target.view_as(label)).sum().item() 97 | label_num += len(target) 98 | 99 | label = pre_fun.label_inverse_transform(label, origin_gt.long()) 100 | predict_label.append(torch.cat([indices, label], dim=1)) 101 | 102 | predict_label = torch.cat(predict_label, dim=0) 103 | 104 | if label_num > 0: 105 | print('OA {:.2f}%'.format(100 * predict_correct / label_num)) 106 | 107 | return predict_label 108 | -------------------------------------------------------------------------------- /tool/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import datetime 4 | import math 5 | import os 6 | 7 | from data.HSI_data import batch_collate as collate_fn 8 | from torch.utils.data import DataLoader 9 | 10 | 11 | # from model.LSTM import lstm 12 | # from model.MSCNN import MSCNN 13 | 14 | 15 | # 学习概调整策略 16 | def adjust_lr(lr_init, lr_gamma, optimizer, epoch, step_index): 17 | if epoch < 1: 18 | lr = 0.0001 * lr_init 19 | else: 20 | lr = lr_init * lr_gamma ** ((epoch - 1) // step_index) 21 | 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | 25 | return lr 26 | 27 | 28 | # 训练 29 | def train(train_data, 
model, loss_fun, optimizer, device, cfg):
30 |     '''
31 |     Called as: fun_train(train_data, model, loss_fun, optimizer, device, cfg_train)
32 |     batch_data = DataLoader(train_data, batch_size, shuffle=True, num_workers=num_workers,
33 |                             collate_fn=collate_fn, pin_memory=True)
34 | 
35 |     cfg_train = cfg.train['train_model']
36 |     train_data = fun_data(data_sets, cfg_data['train_data'])
37 |     '''
38 |     # (1) basic parameter setup
39 | 
40 |     num_workers = cfg['workers_num']  # number of data-loading workers
41 |     gpu_num = cfg['gpu_num']  # number of GPUs to use
42 | 
43 |     save_folder = cfg['save_folder']  # './weights/'
44 |     save_name = cfg['save_name']  # 'model_CNN1D'
45 | 
46 |     lr_init = cfg['lr']
47 |     lr_gamma = cfg['lr_gamma']
48 |     lr_step = cfg['lr_step']  # lr decay interval from the config
49 |     lr_adjust = cfg['lr_adjust']  # set to True in the config
50 | 
51 |     epoch_size = cfg['epoch']
52 |     batch_size = cfg['batch_size']
53 | 
54 |     if gpu_num > 1 and cfg['gpu_train']:
55 |         # multi-GPU server
56 |         model = torch.nn.DataParallel(model).to(device)
57 |         # usage example:
58 |         # model = model.cuda()
59 |         # device_ids = [0, 1, 2, 3]  # GPUs to use
60 |         # model = torch.nn.DataParallel(model, device_ids=device_ids)
61 |     else:
62 |         model = model.to(device)
63 | 
64 |     # (2) load the model and start training
65 | 
66 |     model.train()
67 | 
68 |     # optionally resume from a previously trained model
69 |     if cfg['reuse_model']:
70 | 
71 |         print('loading model')
72 | 
73 |         checkpoint = torch.load(cfg['reuse_file'], map_location=device)  # load a checkpoint written by torch.save()
74 |         start_epoch = checkpoint['epoch']
75 | 
76 |         model_dict = model.state_dict()  # state_dict is a dict mapping each layer to its parameter tensors
77 | 
78 |         pretrained_dict = {k: v for k, v in checkpoint['model'].items() if k in model_dict}  # keep only keys that exist in model_dict
79 |         model_dict.update(pretrained_dict)
80 | 
81 |         model.load_state_dict(model_dict)  # load the pretrained weights
82 |     else:
83 |         start_epoch = 0
84 | 
85 |     batch_num = math.ceil(len(train_data) / batch_size)  # number of iterations per epoch, rounded up
86 |     print('Start training!')
87 | 
88 |     for epoch in range(start_epoch + 1, epoch_size + 1):
89 | 
90 |         epoch_time0 = time.time()  # record the epoch start time
91 | 
92 |         batch_data = DataLoader(train_data, batch_size, shuffle=True, \
93 |                                 num_workers=num_workers, collate_fn=collate_fn, pin_memory=True)
94 | 
95 |         # decide whether to use the stepped learning-rate schedule
96 |         if lr_adjust:
97 |             lr = adjust_lr(lr_init, lr_gamma, optimizer, epoch, lr_step)
98 |         else:
99 |             lr = lr_init
100 | 
101 |         epoch_loss = 0
102 |         predict_correct = 0
103 |         label_num = 0
104 | 
105 |         for batch_idx, batch_sample in enumerate(batch_data):  # iterate over the loaded batches
106 | 
107 |             iteration = (epoch - 1) * batch_num + batch_idx + 1
108 |             batch_time0 = time.time()
109 | 
110 |             # (1) unpack the images and labels
111 |             if len(batch_sample) > 3:
112 |                 img, target, indices, img_pca = batch_sample
113 |                 img_pca = img_pca.to(device)
114 |             else:
115 |                 img, target, indices = batch_sample
116 | 
117 |             img = img.to(device)
118 |             target = target.to(device)
119 | 
120 |             # (2) forward pass (the model returns the fused SSUN, LSTM and MSCNN outputs)
121 |             prediction_SSUN = model(img, img_pca)[0]
122 |             prediction_lstm = model(img, img_pca)[1]
123 |             prediction_MSCNN = model(img, img_pca)[2]
124 | 
125 |             # (3) compute the losses (target holds the class labels)
126 |             loss1 = loss_fun(prediction_SSUN, target.long())
127 |             loss2 = loss_fun(prediction_lstm, target.long())
128 |             loss3 = loss_fun(prediction_MSCNN, target.long())
129 | 
130 |             loss = loss1 + loss2 + loss3
131 | 
132 |             # (4) optimizer step, backpropagation
133 |             optimizer.zero_grad()  # zero the gradients
134 |             loss.backward()  # backpropagate to get the gradient of every parameter
135 |             optimizer.step()  # one gradient-descent parameter update
136 | 
137 |             batch_time1 = time.time()
138 |             batch_time = batch_time1 - batch_time0  # time spent on this iteration
139 | 
140 |             # estimated time of arrival (ETA)
141 |             batch_eta = batch_time * (batch_num - batch_idx)
142 |             epoch_eta = int(batch_time * (epoch_size - epoch) * batch_num + batch_eta)
143 | 
144 |             epoch_loss += loss.item()  # .item() returns a plain Python float
145 | 
146 |             predict_label = prediction_SSUN.detach().softmax(dim=-1).argmax(dim=1,
147 |                                                                             keepdim=True)  # detach() returns a tensor cut off from the current graph;
148 |             # it is never updated by gradients:
149 |             # its requires_grad is False, so no gradient is ever computed or stored for it.
150 |             predict_correct += predict_label.eq(target.view_as(predict_label)).sum().item()  # running count of correct predictions
151 |             label_num += len(target)
152 | 
153 |         epoch_time1 = time.time()
154 |         epoch_time = epoch_time1 - epoch_time0  # time spent on this epoch
155 |         epoch_eta = int(epoch_time * (epoch_size - epoch))
156 | 
157 |         # print the epoch statistics
158 |         print('Epoch: {}/{} || lr: {} || loss: {} || Train acc: {:.2f}% || '
159 |               'Epoch time: {:.4f}s || Epoch ETA: {}'
160 |               .format(epoch, epoch_size, lr, epoch_loss / batch_num, 100 * predict_correct / label_num,
161 |                       epoch_time, str(datetime.timedelta(seconds=epoch_eta))
162 |                       )
163 |               )
164 | 
165 |     if not os.path.exists(save_folder):
166 |         os.makedirs(save_folder)  # create the directory recursively
167 | 
168 |     # save the final model
169 |     save_model = dict(
170 |         model=model.state_dict(),
171 |         epoch=epoch_size
172 |     )
173 |     torch.save(save_model, os.path.join(save_folder, save_name + '_Final.pth'))
174 | 
--------------------------------------------------------------------------------
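Note on accuracy assessment: tool/assessment.py only evaluates binary change maps (confusion matrix over labels [1, 2]), while the pipeline above predicts up to 16 land-cover classes and tool/test.py reports only overall accuracy. The sketch below is not part of the repository; it illustrates how overall accuracy, per-class accuracy, average accuracy, and the kappa coefficient could be computed for the multi-class result, assuming the (flat_index, row, col, predicted_label) layout that tool/test.py produces for predict_label and the original ground-truth map stored in data_sets['ori_gt'] (0 = background). The helper name multiclass_assessment is hypothetical.

import numpy as np
from sklearn.metrics import confusion_matrix, cohen_kappa_score


def multiclass_assessment(predict_label, ori_gt):
    """predict_label: (N, 4) tensor with rows (flat_index, row, col, predicted_label);
    ori_gt: (H, W) tensor where 0 is background and 1..C are class labels."""
    gt = ori_gt.cpu().numpy()
    pred = predict_label.cpu().numpy()

    y_true, y_pred = [], []
    for i in range(pred.shape[0]):
        r, c, lab = int(pred[i, 1]), int(pred[i, 2]), int(pred[i, 3])
        if gt[r, c] > 0:  # skip background pixels
            y_true.append(int(gt[r, c]))
            y_pred.append(lab)

    labels = sorted(set(y_true))
    conf = confusion_matrix(y_true, y_pred, labels=labels)
    per_class = conf.diagonal() / np.maximum(conf.sum(axis=1), 1)  # accuracy of each class

    oa = conf.diagonal().sum() / conf.sum()  # overall accuracy
    aa = per_class.mean()                    # average accuracy
    kappa = cohen_kappa_score(y_true, y_pred, labels=labels)
    return conf, per_class, oa, aa, kappa


# Example (after main.py has built predict_label and data_sets):
# conf, per_class, oa, aa, kappa = multiclass_assessment(predict_label, data_sets['ori_gt'])
# print('OA: {:.2f}%  AA: {:.2f}%  kappa: {:.4f}'.format(100 * oa, 100 * aa, kappa))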