├── .idea
│   ├── .gitignore
│   ├── 2018paper_test.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── remote-mappings.xml
├── configs
│   └── configs.py
├── data
│   ├── HSI_data.py
│   ├── __init__.py
│   ├── data_preprocess.py
│   ├── data_test.py
│   ├── get_train_test_set.py
│   └── normalizes.py
├── main.py
├── model
│   ├── LSTM.py
│   ├── MSCNN.py
│   └── SSUN.py
└── tool
    ├── assessment.py
    ├── show.py
    ├── test.py
    └── train.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/2018paper_test.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/remote-mappings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/configs/configs.py:
--------------------------------------------------------------------------------
1 | import scipy.io as scio
2 | import matplotlib.pyplot as plt
3 | from scipy.io import loadmat
4 | import spectral as spy
5 |
6 | # Choose which dataset to load ahead of time
7 | datasets_type = 'HSI Data Sets'
8 |
9 | # data_root = '/home/students/master/2022/wangzy/PyCharm-Remote/datasets/PaviaU.mat'
10 | # image_name = 'paviaU'
11 | # gt_name = 'paviaU_gt'
12 | # the paviaU ground truth has 9 classes
13 | # torch([32, 103, 610, 340])
14 |
15 | # data_root = '/home/students/master/2022/wangzy/PyCharm-Remote/datasets/Urban.mat'
16 | # image_name = 'Urban'
17 | # gt_name = 'R'
18 | # the Urban ground truth has 9 classes
19 |
20 | data_root = '/home/students/master/2022/wangzy/PyCharm-Remote/datasets/Indian_pines.mat'
21 | image_name = 'indian_pines'
22 | gt_name = 'R'
23 | # the indian_pines ground truth has 16 classes
24 | # torch([32, 220, 145, 145])
25 |
26 | # data_root = '/home/students/master/2022/wangzy/PyCharm-Remote/datasets/KSC.mat'
27 | # image_name = 'KSC'
28 | # gt_name = 'KSC_gt'
29 | # the KSC ground truth has 13 classes
30 | # torch([32, 176, 512, 614])
31 |
32 |
33 | # Other runtime parameters
34 |
35 | phase = ['train', 'test', 'no_gt']
36 | pca_num = 5
37 | train_set_num = 64
38 | patch_size = 24
39 |
40 | # Build the data dict; everything data-related is accessed through it
41 | data = dict(
42 | data_path=data_root,
43 | image_name=image_name,
44 | gt_name=gt_name,
45 | train_set_num=train_set_num,
46 | patch_size=patch_size,
47 | pca=pca_num,
48 | train_data=dict(
49 | phase=phase[0]
50 | ),
51 | test_data=dict(
52 | phase=phase[1]
53 | ),
54 | no_gt_data=dict(
55 | phase=phase[2]
56 | )
57 | )
58 |
59 | # Model hyper-parameters
60 |
61 | model = dict(
62 | in_fea_num=1,
63 | out_fea_num=9,
64 | )
65 |
66 | # Training hyper-parameters
67 |
68 | lr = 1e-3
69 |
70 | train = dict(
71 | # optimizer settings
72 | optimizer=dict(
73 | typename='SGD',
74 | lr=lr,
75 | betas=(0.9, 0.999),
76 | momentum=0.9, # momentum
77 | weight_decay=1e-4 # weight decay
78 | ),
79 |
80 | train_model=dict(
81 | gpu_train=True,
82 | gpu_num=1,
83 | workers_num=16,
84 | epoch=500,
85 | batch_size=64,
86 | # learning-rate schedule parameters
87 | lr=lr,
88 | lr_adjust=True,
89 | lr_gamma=0.1,
90 | lr_step=460,
91 | save_folder='./weights/',
92 | save_name='model_CNN1D',
93 | reuse_model=False,
94 | reuse_file='./weights/model_CNN1D_Final.pth'
95 | )
96 | )
97 |
98 | test = dict(
99 | batch_size=1000,
100 | gpu_train=True,
101 | gpu_num=1,
102 | workers_num=16,
103 | model_weights='./weights/model_CNN1D_Final.pth',
104 | save_folder='./result'
105 | )
106 |
107 |
108 | def main():
109 | data = loadmat(data_root)
110 | print(data.keys())
111 |
112 |
113 | if __name__ == '__main__':
114 | main()
115 |
--------------------------------------------------------------------------------
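Note: configs.py is a plain module of Python dicts rather than a config framework; the rest of the code simply imports it and indexes into the dicts. A minimal sketch of how it is consumed (mirroring main.py; the print is only illustrative):

    import configs.configs as cfg

    cfg_data = cfg.data                   # dataset path, patch size, pca, phases
    cfg_train = cfg.train['train_model']  # epochs, batch size, lr schedule, save paths
    cfg_optim = cfg.train['optimizer']    # lr, betas, momentum, weight decay
    cfg_test = cfg.test                   # test batch size, weight file, result folder

    print(cfg_data['data_path'], cfg_train['epoch'], cfg_optim['lr'], cfg_test['save_folder'])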
/data/HSI_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 |
4 |
5 |
6 | class HSI_data(data.Dataset):
7 | def __init__(self, data_sample, cfg):
8 | self.phase = cfg['phase']
9 |
10 | # img: the padded image (pad_img)
11 | # img_indices: the coordinates of every patch
12 | self.img = data_sample['pad_img']
13 | self.img_indices = data_sample['pad_img_indices']
14 | self.gt = data_sample['img_gt']
15 |
16 | self.pca = 'img_pca_pad' in data_sample # check whether 'img_pca_pad' exists in data_sample
17 | # True if it does, False otherwise
18 |
19 | if self.pca:
20 | self.img_pca = data_sample['img_pca_pad']
21 | # data_indices: samples selected according to the labels in img_gt
22 |
23 | if self.phase == 'train':
24 | self.data_indices = data_sample['train_indices']
25 | elif self.phase == 'test':
26 | self.data_indices = data_sample['test_indices']
27 | elif self.phase == 'no_gt':
28 | self.data_indices = data_sample['no_gt_indices']
29 |
30 | def __len__(self):
31 | return len(self.data_indices)
32 |
33 | # fetch one sample by its index
34 | # supports indices from 0 to len(self) - 1
35 | # data_indices=: torch.Size([270, 3])
36 | # img_indices=: torch.Size([207400, 4])
37 | def __getitem__(self, idx):
38 |
39 | index = self.data_indices[idx]
40 | img_index = self.img_indices[index[0]] # img_index: patch coordinates [h0, h1, w0, w1]
41 |
42 | # crop the patch from the padded image
43 | img = self.img[:, img_index[0]:img_index[1], img_index[2]:img_index[3]]
44 | label = self.gt[index[1], index[2]]
45 |
46 | if self.pca:
47 | img_pca = self.img_pca[:, img_index[0]:img_index[1], img_index[2]:img_index[3]]
48 |
49 | # when PCA exists, the patch cropped from the PCA image is returned too
50 | return img, label, index, img_pca
51 | else:
52 | return img, label, index
53 |
54 |
55 | # collate function that handles samples with or without the PCA patch
56 | def batch_collate(batch):
57 | images = []
58 | labels = []
59 | indices = []
60 | images_pca = []
61 |
62 | for sample in batch:
63 | images.append(sample[0])
64 | labels.append(sample[1])
65 | indices.append(sample[2])
66 |
67 | if len(sample) > 3:
68 | images_pca.append(sample[3])
69 |
70 | # torch.stack joins the per-sample tensors into batch tensors
71 | if len(images_pca) > 0:
72 | return torch.stack(images, 0), torch.stack(labels), \
73 | torch.stack(indices), torch.stack(images_pca, 0)
74 | else:
75 | return torch.stack(images, 0), torch.stack(labels), \
76 | torch.stack(indices)
77 |
--------------------------------------------------------------------------------
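Note: a minimal sketch of how HSI_data and batch_collate are wired into a DataLoader, following what main.py and tool/train.py do (the batch size is illustrative; the 4-tuple is returned because cfg.data['pca'] > 0):

    from torch.utils.data import DataLoader

    import configs.configs as cfg
    from data.get_train_test_set import get_train_test_set
    from data.HSI_data import HSI_data, batch_collate

    data_sets = get_train_test_set(cfg.data)
    train_data = HSI_data(data_sets, cfg.data['train_data'])

    loader = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=batch_collate)
    img, target, indices, img_pca = next(iter(loader))
    print(img.shape, target.shape, img_pca.shape)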
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wzysaber/2018paper_code/ade9a5414ba3df1618dd42136a77763bb2d1f305/data/__init__.py
--------------------------------------------------------------------------------
/data/data_preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import transforms
4 |
5 |
6 | # Reduce the 103-band image to pc (default 3) principal-component channels
7 | def extract_pc(image, pc=3):
8 | '''
9 | function: extract the first principal components via SVD
10 | :param image: tensor of shape (C, H, W)
11 | :param pc: number of principal components to keep
12 | :return: tensor of shape (pc, H, W)
13 | '''
14 | channel, height, width = image.shape
15 | data = image.contiguous().reshape(channel, height * width) # contiguous() makes the reshape safe after permute
16 | data_c = data - data.mean(dim=1).unsqueeze(1)
17 | # singular value decomposition of the channel covariance matrix
18 | u, s, v = torch.svd(data_c.matmul(data_c.T)) # data_c times its own transpose
19 | sorted_data, indices = s.sort(descending=True) # sort the singular values in descending order
20 | image_pc = u[:, indices[0:pc]].T.matmul(data)
21 | return image_pc.reshape(pc, height, width)
22 |
23 |
24 | # (x - mean(x))/std(x) normalize to mean: 0, std: 1
25 | # standardize the data accordingly
26 |
27 | def std_norm(image):
28 | image = image.permute(1, 2, 0).numpy()
29 | trans = transforms.Compose([
30 | transforms.ToTensor(),
31 | transforms.Normalize(torch.tensor(image).mean(dim=[0, 1]), torch.tensor(image).std(dim=[0, 1]))
32 | ])
33 |
34 | return trans(image)
35 |
36 |
37 | # (x - min(x))/(max(x) - min(x)) normalize to (0, 1) for each channel
38 | def one_zero_norm(image):
39 | channel, height, width = image.shape
40 | data = image.view(channel, height * width)
41 | data_max = data.max(dim=1)[0]
42 | data_min = data.min(dim=1)[0]
43 |
44 | data = (data - data_min.unsqueeze(1)) / (data_max.unsqueeze(1) - data_min.unsqueeze(1)) # unsqueeze(1) inserts a dimension for broadcasting
45 |
46 | return data.view(channel, height, width)
47 |
48 |
49 | # input tensor image size with CxHxW
50 | # -1 + 2 * (x - min(x))/(max(x) - min(x)) normalize to (-1, 1) for each channel
51 | # likewise normalize the data, this time to the range (-1, 1)
52 | def pos_neg_norm(image):
53 | channel, height, width = image.shape
54 | data = image.view(channel, height * width)
55 | data_max = data.max(dim=1)[0]
56 | data_min = data.min(dim=1)[0]
57 |
58 | data = -1 + 2 * (data - data_min.unsqueeze(1)) / (data_max.unsqueeze(1) - data_min.unsqueeze(1))
59 |
60 | return data.view(channel, height, width)
61 |
62 |
63 | # function: construct samples - pad the image and store every patch's coordinates
64 | # input: image:torch.size(103, 610, 340)
65 | # window_size:27
66 | # output:pad_image, batch_image_indices
67 | def construct_sample(image, window_size=27):
68 | # read the image dimensions (channels, height, width)
69 | channel, height, width = image.shape
70 |
71 | half_window = int(window_size // 2) # 13
72 | # pad by replicating the border values
73 | pad = nn.ReplicationPad2d(half_window) # each side grows by half_window, i.e. 26 extra pixels in total for a window of 27
74 | # uses (padding_left, padding_right, padding_top, padding_bottom)
75 |
76 | pad_image = pad(image.unsqueeze(0)).squeeze(0) # torch.Size([103, 636, 366])
77 |
78 | # store the coordinates of every patch in one tensor
79 | # torch.Size([207400, 4])
80 | batch_image_indices = torch.zeros((height * width, 4), dtype=torch.long)
81 |
82 | t = 0
83 | for h in range(height):
84 | for w in range(width):
85 | batch_image_indices[t, :] = torch.tensor([h, h + window_size, w, w + window_size])
86 | t += 1
87 |
88 | return pad_image, batch_image_indices
89 |
90 |
91 | # gt holds the label values: the background is 0, the objects are 1 to 9
92 | # shift the labels so the background becomes -1 and the objects 0 to 8
93 | def label_transform(gt):
94 | '''
95 | function:tensor label to 0-n for training
96 | input: gt
97 | output:gt
98 | tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
99 | -> tensor([-1., 0., 1., 2., 3., 4., 5., 6., 7., 8.])
100 | '''
101 | label = torch.unique(gt) # the distinct values appearing in the labels
102 | gt_new = torch.zeros_like(gt) # zeros_like(a) builds an all-zero tensor with the same shape as a
103 | # (torch.zeros would need an explicit shape instead)
104 |
105 | for each in range(len(label)): # e.g. 10 distinct labels for PaviaU
106 | indices = torch.where(gt == label[each])
107 |
108 | if label[0] == 0:
109 | gt_new[indices] = each - 1
110 | else:
111 | gt_new[indices] = each
112 |
113 | # tensor([-1., 0., 1., 2., 3., 4., 5., 6., 7., 8.])
114 | # labeL_new = torch.unique(gt_new)
115 |
116 | return gt_new
117 |
118 |
119 | # map the predicted labels back to the original label values
120 | def label_inverse_transform(predict_result, gt):
121 | label_origin = torch.unique(gt)
122 | label_predict = torch.unique(predict_result)
123 |
124 | predict_result_origin = torch.zeros_like(predict_result)
125 | for each in range(len(label_predict)):
126 | indices = torch.where(predict_result == label_predict[each]) # coordinates where the prediction equals this label
127 | if len(label_predict) != len(label_origin):
128 | predict_result_origin[indices] = label_origin[each + 1]
129 | else:
130 | predict_result_origin[indices] = label_origin[each]
131 |
132 | return predict_result_origin
133 |
134 |
135 | def select_sample(gt, ntr):
136 | '''
137 | function: split the samples using the label information in img_gt
138 | input: gt -> torch.Size(610, 340); ntr -> train_set_num 30
139 | output:data_sample = {'train_indices': train_indices, 'train_num': train_num,
140 | 'test_indices': test_indices, 'test_num': test_num,
141 | 'no_gt_indices': no_gt_indices, 'no_gt_num': no_gt_num.unsqueeze(0) }
142 | '''
143 | gt_vector = gt.reshape(-1, 1).squeeze(1) # flatten gt into a 1-D vector
144 | # torch.Size([207400])
145 |
146 | label = torch.unique(gt)
147 |
148 | first_time = True
149 |
150 | for each in range(len(label)): # each 0~9
151 | indices_vector = torch.where(gt_vector == label[each]) # 1-D indices, i.e. the flat positions of this label
152 | # iterate over every label value
153 | indices = torch.where(gt == label[each]) # 2-D indices, e.g. the (row, col) positions of -1 in gt
154 |
155 | # print(indices)
156 | indices_vector = indices_vector[0]
157 | indices_row = indices[0]
158 | indices_column = indices[1]
159 |
160 | # the background label is -1
161 | if label[each] == -1:
162 | no_gt_indices = torch.cat([
163 | indices_vector.unsqueeze(1),
164 | indices_row.unsqueeze(1),
165 | indices_column.unsqueeze(1)],
166 | dim=1
167 | )
168 | no_gt_num = torch.tensor(len(indices_vector))
169 |
170 | # the remaining labels 0 to n-1
171 | else:
172 | class_num = torch.tensor(len(indices_vector))
173 | # class_num per iteration, e.g. for PaviaU: 6631->18649->2099->3064->1345->5029->1330->3682->947
174 | # (the number of samples carrying each label)
175 |
176 | # number of samples to select; ntr = train_set_num
177 | # if ntr < 1: # interpreted as a fraction of the class
178 | # ntr0 = int(ntr * class_num)
179 | # else:
180 | # ntr0 = ntr
181 | # # select at least 10 samples
182 | # if ntr0 < 10:
183 | # sel_num = 10
184 | # elif ntr0 > class_num // 2:
185 | # sel_num = class_num // 2
186 | # else:
187 | # sel_num = ntr0
188 |
189 | train_num_array = [30, 150, 150, 100, 150, 150, 20, 150, 15, 150, 150, 150, 150, 150, 50, 50]
190 | sel_num = train_num_array[each-1]
191 |
192 | sel_num = torch.tensor(sel_num) # tensor(30)
193 |
194 | # shuffle the samples of this class
195 | rand_indices0 = torch.randperm(class_num) # torch.randperm(n) returns a random permutation of 0..n-1
196 | rand_indices = indices_vector[rand_indices0]
197 |
198 | # split into training set and test set
199 | # split the shuffled index permutation
200 | tr_ind0 = rand_indices0[0:sel_num] # torch.Size([30])
201 | te_ind0 = rand_indices0[sel_num:] # the remaining samples become the test set
202 |
203 | # split the corresponding flattened-gt indices
204 | tr_ind = rand_indices[0:sel_num] # torch.Size([30])
205 | te_ind = rand_indices[sel_num:]
206 |
207 | # training set: flat index + (row, column) coordinates
208 | sel_tr_ind = torch.cat([
209 | tr_ind.unsqueeze(1),
210 | indices_row[tr_ind0].unsqueeze(1),
211 | indices_column[tr_ind0].unsqueeze(1)],
212 | dim=1
213 | ) # torch.Size([30, 3])
214 |
215 | # test set
216 | sel_te_ind = torch.cat([
217 | te_ind.unsqueeze(1),
218 | indices_row[te_ind0].unsqueeze(1),
219 | indices_column[te_ind0].unsqueeze(1)],
220 | dim=1
221 | ) # torch.Size([6601, 3])
222 |
223 | if first_time:
224 | first_time = False
225 |
226 | train_indices = sel_tr_ind
227 | train_num = sel_num.unsqueeze(0)
228 |
229 | test_indices = sel_te_ind
230 | test_num = (class_num - sel_num).unsqueeze(0)
231 |
232 | else:
233 | train_indices = torch.cat([train_indices, sel_tr_ind], dim=0)
234 | train_num = torch.cat([train_num, sel_num.unsqueeze(0)])
235 |
236 | test_indices = torch.cat([test_indices, sel_te_ind], dim=0)
237 | test_num = torch.cat([test_num, (class_num - sel_num).unsqueeze(0)])
238 |
239 | # shuffle the training set
240 | rand_tr_ind = torch.randperm(train_num.sum())
241 | train_indices = train_indices[rand_tr_ind,]
242 | # shuffle the test set
243 | rand_te_ind = torch.randperm(test_num.sum()) # torch.Size([42506])
244 | test_indices = test_indices[rand_te_ind,] # torch.Size([42506, 3])
245 | # shuffle the background (no ground truth) pixels
246 | rand_no_gt_ind = torch.randperm(no_gt_num.sum()) # torch.Size([164624])
247 | no_gt_indices = no_gt_indices[rand_no_gt_ind,] # torch.Size([164624, 3])
248 |
249 | # save the six outputs in a dict
250 | data_sample = {'train_indices': train_indices, 'train_num': train_num,
251 | 'test_indices': test_indices, 'test_num': test_num,
252 | 'no_gt_indices': no_gt_indices, 'no_gt_num': no_gt_num.unsqueeze(0)
253 | }
254 |
255 | return data_sample
256 |
257 |
258 |
259 |
--------------------------------------------------------------------------------
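Note: a small worked example of the patch geometry in construct_sample, using the Indian Pines numbers from configs.py (145x145 scene, patch_size=24); the tensor is random, only the shapes matter:

    import torch
    from data.data_preprocess import construct_sample

    img = torch.randn(220, 145, 145)                 # (C, H, W)
    pad_img, indices = construct_sample(img, window_size=24)

    # half_window = 24 // 2 = 12, so each spatial side grows by 2 * 12: 145 + 24 = 169
    print(pad_img.shape)    # torch.Size([220, 169, 169])
    print(indices.shape)    # torch.Size([21025, 4]), one [h0, h1, w0, w1] row per pixel
    print(indices[0])       # tensor([ 0, 24,  0, 24]), the patch centred on pixel (0, 0)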
/data/data_test.py:
--------------------------------------------------------------------------------
1 | import scipy.io as sio
2 | import numpy as np
3 | import configs.configs as cfg
4 |
5 | def normalize(x, k):
6 | if k == 1:
7 | mu = np.mean(x, 0)
8 | x_norm = x - mu
9 | sigma = np.std(x_norm, 0)
10 | x_norm = x_norm / sigma
11 | return x_norm
12 | elif k == 2:
13 | minx = np.min(x, 0)
14 | maxx = np.max(x, 0)
15 | x_norm = x - minx
16 | x_norm = x_norm / (maxx - minx)
17 | return x_norm
18 |
19 |
20 | def HyperFunctions(timestep=4, s1s2=2):
21 | cfg_data = cfg.data
22 | data_path = cfg_data['data_path']
23 | data = sio.loadmat(data_path)
24 | x = data['indian_pines']
25 | y = data['R']
26 |
27 | train_num_array = [30, 150, 150, 100, 150, 150, 20, 150, 15, 150, 150, 150, 150, 150, 50, 50]
28 | train_num_array = np.array(train_num_array).astype('int')
29 | [row, col, n_feature] = x.shape
30 | x = x.reshape(row * col, n_feature)
31 | y = y.reshape(row * col, 1)
32 | # 16
33 | n_class = y.max()
34 | # 55
35 | nb_features = int(n_feature / timestep)
36 | # 1765
37 | train_num_all = sum(train_num_array)
38 | # (21025, 4, 55)
39 | x = normalize(x, 1)
40 |
41 | x_reshape = np.zeros((x.shape[0], timestep, nb_features))
42 | if s1s2 == 2:
43 |
44 | for j in range(0, timestep):
45 | x_reshape[:, j, :] = x[:, j:j + (nb_features - 1) * timestep + 1:timestep]
46 | else:
47 | for j in range(0, timestep):
48 | x_reshape[:, j, :] = x[:, j * nb_features:(j + 1) * nb_features]
49 |
50 | x_data_all = x_reshape
51 |
52 | randomarray = list()
53 |
54 | for i in range(1, n_class + 1):
55 | index = np.where(y == i)[0]
56 | n_data = index.shape[0]
57 | randomarray.append(np.random.permutation(n_data))
58 |
59 | flag1 = 0
60 | flag2 = 0
61 | # (1765, 4, 55)
62 | x_train = np.zeros((train_num_all, timestep, nb_features))
63 |
64 | x_test = np.zeros((sum(y > 0)[0] - train_num_all, timestep, nb_features))
65 |
66 | for i in range(1, n_class + 1):
67 | index = np.where(y == i)[0]
68 | # 46
69 | n_data = index.shape[0]
70 | # 30
71 | train_num = train_num_array[i - 1]
72 | randomx = randomarray[i - 1]
73 | if s1s2 == 2:
74 |
75 | for j in range(0, timestep):
76 | x_train[flag1:flag1 + train_num, j, :] = x[index[randomx[0:train_num]],
77 | j:j + (nb_features - 1) * timestep + 1:timestep]
78 | x_test[flag2:flag2 + n_data - train_num, j, :] = x[index[randomx[train_num:n_data]],
79 | j:j + (nb_features - 1) * timestep + 1:timestep]
80 |
81 | else:
82 | for j in range(0, timestep):
83 | x_train[flag1:flag1 + train_num, j, :] = x[index[randomx[0:train_num]],
84 | j * nb_features:(j + 1) * nb_features]
85 | x_test[flag2:flag2 + n_data - train_num, j, :] = x[index[randomx[train_num:n_data]],
86 | j * nb_features:(j + 1) * nb_features]
87 |
88 | flag1 = flag1 + train_num
89 | flag2 = flag2 + n_data - train_num
90 | # (1765, 4, 55) (8484, 4, 55)
91 | return x_data_all.astype('float32')
92 |
93 |
94 | def main():
95 | out1 = HyperFunctions()
96 | print(out1.shape)
97 |
98 |
99 |
100 | if __name__ == '__main__':
101 | main()
102 |
--------------------------------------------------------------------------------
/data/get_train_test_set.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import scipy.io as io
3 | import data.data_preprocess as pre_fun
4 |
5 |
6 | def get_train_test_set(cfg):
7 | '''
8 | function: (1) split the dataset into train and test
9 | (2) load the data, convert it to tensors, transform the labels
10 | (3) cut patches and store each patch's coordinates
11 | (4) select samples from gt, producing the final data_sample
12 |
13 | input: cfg, i.e. the data dict edited in configs.py
14 |
15 | output: the data_sample dict produced from gt
16 |
17 | # dict_keys(['train_indices', 'train_num', 'test_indices', 'test_num',
18 | # 'no_gt_indices', 'no_gt_num', 'pad_img', 'pad_img_indices', 'img_gt', 'ori_gt'])
19 | '''
20 |
21 | # read the parameters defined in cfg
22 | data_path = cfg['data_path'] # path to the stored .mat file
23 | image_name = cfg['image_name'] # e.g. paviaU
24 | # the key of the full hyperspectral cube (all 103 bands)
25 |
26 | gt_name = cfg['gt_name'] # e.g. 'paviaU_gt'
27 | # the ground-truth map of the scene, carrying the class labels
28 | # a label map of the form [0,0,1,1,1,4,4,4,...]
29 |
30 | train_set_num = cfg['train_set_num'] # e.g. 30 training samples per class
31 | patch_size = cfg['patch_size'] # e.g. 27, the spatial size used to cut the patches
32 |
33 | # load the hyperspectral dataset
34 | # conversion chain: .astype('float32') first, then torch.from_numpy(img), giving a torch tensor
35 | data = io.loadmat(data_path) # read the .mat file
36 |
37 | img = data[image_name].astype('float32') # cast the array dtype, (610, 340, 103) [H, W, C]
38 | gt = data[gt_name].astype('float32') # cast to float32, (610, 340); a single label map
39 |
40 | img = torch.from_numpy(img) # to tensor # torch.Size(610, 340, 103)
41 | gt = torch.from_numpy(gt) # torch.Size(610, 340)
42 |
43 | img = img.permute(2, 0, 1) # move the channel to the first dimension, CxHxW # torch.Size(103, 610, 340)
44 | img = pre_fun.std_norm(img) # standardize, torch.Size(103, 610, 340), zero mean and unit std per channel
45 |
46 | # label transform 0~9 -> -1~8
47 | # transform the label values; the per-class labels are already defined in the .mat file
48 | img_gt = pre_fun.label_transform(gt) # torch.size(610, 340)
49 |
50 | # construct_sample: cut patches and store each patch's coordinates
51 | # img_pad has shape [103, 636, 366]
52 | # img_pad_indices has shape [207400, 4]
53 | img_pad, img_pad_indices = pre_fun.construct_sample(img, patch_size)
54 |
55 | # (1) select_sample: split the samples using the labels in img_gt
56 | # (2) the resulting data_sample = {'train_indices': train_indices, 'train_num': train_num,
57 | # 'test_indices': test_indices, 'test_num': test_num,
58 | # 'no_gt_indices': no_gt_indices, 'no_gt_num': no_gt_num.unsqueeze(0)
59 | # }
60 | data_sample = pre_fun.select_sample(img_gt, train_set_num)
61 |
62 | # add a few more entries to data_sample
63 | data_sample['pad_img'] = img_pad
64 | data_sample['pad_img_indices'] = img_pad_indices
65 | data_sample['img_gt'] = img_gt # the transformed label map
66 | data_sample['ori_gt'] = gt # the original label map
67 |
68 | # print('data_sample.keys()',data_sample.keys())
69 | # dict_keys(['train_indices', 'train_num', 'test_indices', 'test_num',
70 | # 'no_gt_indices', 'no_gt_num', 'pad_img', 'pad_img_indices', 'img_gt', 'ori_gt'])
71 |
72 | # cfg['pca'] > 0 acts as a flag for whether the PCA branch is built
73 | # (the PCA image then gets every normalization stacked on top)
74 | # extract the principal components, then min-max normalize and standardize
75 | if cfg['pca'] > 0:
76 | img_pca = pre_fun.extract_pc(img, cfg['pca'])
77 | img_pca = pre_fun.one_zero_norm(img_pca)
78 | img_pca = pre_fun.std_norm(img_pca)
79 |
80 | img_pca_pad, _ = pre_fun.construct_sample(img_pca, patch_size)
81 |
82 | data_sample['img_pca_pad'] = img_pca_pad
83 |
84 | return data_sample
85 |
--------------------------------------------------------------------------------
/data/normalizes.py:
--------------------------------------------------------------------------------
1 | # Pure numpy normalization helpers
2 |
3 | import numpy as np
4 |
5 |
6 | def fun_norm(X, M='AllNorm'):
7 | if X.dtype != float:
8 | X = X.astype(np.float32)
9 |
10 | if type(M) != str: # convert M to a string if it is not one
11 | M = str(M)
12 |
13 | if M.lower() == 'allnorm': # lower-case the string for case-insensitive matching
14 | X_min_value = X.min()
15 | X_max_value = X.max()
16 | if X_min_value == X_max_value:
17 | return X
18 | else:
19 | X0 = X - X_min_value
20 | Xd = X_max_value - X_min_value + 0.001
21 |
22 | return 0.1 + (1 - 0.9) * (X0 / Xd)
23 |
24 | elif M.lower() == 'rownorm':
25 | X_min_row = X.min(axis=1) # axis=0: along rows (per column); axis=1: along columns (per row)
26 | X_max_row = X.max(axis=1)
27 | X0 = X - np.tile(X_min_row, (X.shape[1], 1)).T # np.tile repeats the vector to match the shape of X
28 | Xd = X_max_row - X_min_row + 0.001
29 | return 0.1 + (1 - 0.9) * X0 * np.tile(1./Xd, (X.shape[1], 1)).T
30 |
31 | elif M.lower() == 'colnorm':
32 | X_min_col = X.min(axis=0)
33 | X_max_col = X.max(axis=0)
34 | X0 = X - np.tile(X_min_col, (X.shape[0], 1)) # tiling over rows already matches X, no transpose needed
35 | Xd = X_max_col - X_min_col + 0.001
36 | return 0.1 + (1 - 0.9) * X0 * np.tile(1. / Xd, (X.shape[0], 1))
37 |
38 | elif M.lower() == '2norm':
39 | X_2norm = np.linalg.norm(X, axis=1) # L2 norm of each row
40 | return X * (1. / np.tile(X_2norm, (X.shape[1], 1))).T
41 |
42 | elif M.lower() == 'stdnorm':
43 | X_mean = X.mean(axis=0)
44 | X0 = X - np.tile(X_mean, (X.shape[0], 1))
45 | X_std = X.std(axis=0) + 0.001 # axis=0: standard deviation of each column
46 | return X0 * (1. / np.tile(X_std, (X.shape[0], 1))) # scale the centred data by the per-column std
47 |
48 | else:
49 | return X
50 |
--------------------------------------------------------------------------------
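Note: fun_norm is a standalone numpy helper that is not imported by the other files shown here; a quick sketch of calling it on a small matrix:

    import numpy as np
    from data.normalizes import fun_norm

    X = np.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

    print(fun_norm(X, 'AllNorm'))   # global min-max scaling
    print(fun_norm(X, 'RowNorm'))   # min-max scaling per row
    print(fun_norm(X, 'StdNorm'))   # zero-mean scaling per column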
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import os
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" # select which physical GPU is visible
6 |
7 | import scipy.io as io
8 | import imageio
9 |
10 | import configs.configs as cfg
11 | import torch.optim as optim
12 |
13 | from data.HSI_data import HSI_data as fun_data
14 | from data.get_train_test_set import get_train_test_set as fun_get_set
15 |
16 | from model.MSCNN import MSCNN
17 | from model.SSUN import SSUN
18 | from model.LSTM import lstm
19 |
20 | from tool.train import train as fun_train
21 | from tool.test import test as fun_test
22 | from matplotlib import pyplot as plt
23 |
24 | from tool import show
25 | import warnings
26 |
27 | warnings.filterwarnings("ignore")
28 |
29 |
30 | def main():
31 | # (1) basic parameters
32 | cfg_data = cfg.data
33 | cfg_model = cfg.model
34 | cfg_train = cfg.train['train_model']
35 | cfg_optim = cfg.train['optimizer'] # optimizer settings
36 | cfg_test = cfg.test
37 |
38 | # (2) load the data
39 | data_sets = fun_get_set(cfg_data)
40 |
41 | train_data = fun_data(data_sets, cfg_data['train_data'])
42 | test_data = fun_data(data_sets, cfg_data['test_data'])
43 | no_gt_data = fun_data(data_sets, cfg_data['no_gt_data'])
44 |
45 | # (3) training setup
46 | device = torch.device("cuda:0") # with CUDA_VISIBLE_DEVICES = "2", the selected GPU is visible as cuda:0
47 |
48 | # build the model
49 | model = SSUN().to(device)
50 |
51 | # loss function
52 | loss_fun = nn.CrossEntropyLoss()
53 |
54 | # optimizer
55 | # optimizer = optim.SGD(model.parameters(), lr=cfg_optim['lr'],
56 | # momentum=cfg_optim['momentum'], weight_decay=cfg_optim['weight_decay'])
57 | # optimizer = optim.RMSprop(model.parameters(), lr=cfg_optim['lr'],
58 | # momentum=cfg_optim['momentum'], weight_decay=cfg_optim['weight_decay'])
59 | optimizer = optim.Adam(params=model.parameters(), lr=cfg_optim['lr'],
60 | betas=cfg_optim['betas'], eps=1e-8, weight_decay=cfg_optim['weight_decay'])
61 |
62 | # training
63 | fun_train(train_data, model, loss_fun, optimizer, device, cfg_train)
64 |
65 | # testing
66 | pred_train_label = fun_test(train_data, data_sets['ori_gt'], model, device, cfg_test)
67 | pred_test_label = fun_test(test_data, data_sets['ori_gt'], model, device, cfg_test)
68 | pred_no_gt_label = fun_test(no_gt_data, data_sets['ori_gt'], model, device, cfg_test)
69 |
70 | predict_label = torch.cat([pred_train_label, pred_test_label], dim=0)
71 |
72 | # display the resulting classification map
73 | HSI = show.Predict_Label2Img(predict_label)
74 | plt.imshow(HSI)
75 | plt.show()
76 |
77 | save_folder = cfg_test['save_folder']
78 | if not os.path.exists(save_folder):
79 | os.mkdir(save_folder) # create the output directory
80 |
81 | io.savemat(save_folder + '/classification_label.mat', {'predict_label_CNN1D': predict_label}) # save the predictions to a .mat file
82 |
83 |
84 | if __name__ == '__main__':
85 | main()
86 |
--------------------------------------------------------------------------------
/model/LSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # LSTM branch: learns the spectral band sequence of the centre pixel
7 |
8 |
9 | class lstm(nn.Module):
10 | def __init__(self, band_num=4, chose_model=1):
11 | super(lstm, self).__init__()
12 | self.band = band_num
13 | self.chose_model = chose_model
14 | self.lstm_model = nn.LSTM(
15 | input_size=55,
16 | hidden_size=128,
17 | num_layers=1,
18 | batch_first=True
19 | )
20 | # h0 = torch.randn(2, 3, 6)
21 | # c0 = torch.randn(2, 3, 6)  (initial states would be passed explicitly when calling the LSTM)
22 | self.outlayer = nn.Sequential(
23 | nn.Linear(128, 50),
24 | nn.Linear(50, 16)
25 | )
26 |
27 | def forward(self, x):
28 | b, c, h_size, w_size = x.shape
29 | input = x[:, :, h_size // 2, w_size // 2]
30 | input = input.reshape(b, c)
31 |
32 | nb_features = int(c // self.band)
33 | input_reshape = torch.zeros((x.shape[0], self.band, nb_features)).type_as(x)
34 |
35 | if self.chose_model == 1:
36 | for j in range(0, self.band):
37 | input_reshape[:, j, :] = input[:, j:j + (nb_features - 1) * self.band + 1:self.band]
38 | else:
39 | for j in range(0, self.band):
40 | input_reshape[:, j, :] = input[:, j * nb_features:(j + 1) * nb_features]
41 |
42 | out, (h0, c0) = self.lstm_model(input_reshape)
43 | out = out[:, -1, :] # torch.Size([32, 128])
44 |
45 | out_org = out
46 | out_finnal = self.outlayer(out)
47 | return out_org, out_finnal
48 |
49 |
50 | def main():
51 | net = lstm(4, 1)
52 | tmp = torch.randn(128, 220, 24, 24)
53 |
54 | out = net(tmp)[1]
55 | print(out.shape)
56 |
57 |
58 | if __name__ == '__main__':
59 | main()
60 |
--------------------------------------------------------------------------------
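Note: a small numeric illustration of the band-to-sequence reshape in lstm.forward with chose_model=1. The interleaved slice j:j+(nb_features-1)*band_num+1:band_num assigns band c to timestep c % band_num, so timestep j sees bands j, j+4, j+8, ...; an 8-band toy spectrum makes this visible:

    import torch

    band_num, c = 4, 8
    nb_features = c // band_num                       # 2 features per timestep
    spectrum = torch.arange(c).float().unsqueeze(0)   # bands 0..7 of a single pixel

    seq = torch.zeros(1, band_num, nb_features)
    for j in range(band_num):
        seq[:, j, :] = spectrum[:, j:j + (nb_features - 1) * band_num + 1:band_num]

    print(seq.squeeze(0))
    # tensor([[0., 4.],
    #         [1., 5.],
    #         [2., 6.],
    #         [3., 7.]])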
/model/MSCNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # MSCNN: run the patch through conv + pooling stages and concatenate the flattened features of every stage
7 | # the input patch has already been reduced to 5 channels by PCA
8 |
9 | class MSCNN(nn.Module):
10 | def __init__(self):
11 | super(MSCNN, self).__init__()
12 | self.conv1 = nn.Conv2d(in_channels=5, out_channels=32, kernel_size=3, stride=1, padding=1)
13 | self.pool1 = nn.MaxPool2d(kernel_size=2)
14 | self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
15 | self.pool2 = nn.MaxPool2d(kernel_size=2)
16 | self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
17 | self.pool3 = nn.MaxPool2d(kernel_size=2)
18 | self.outlayer = nn.Linear(6048, 16)
19 |
20 | def forward(self, x):
21 | out1 = self.conv1(x)
22 | out1 = self.pool1(out1)
23 |
24 | out2 = self.conv2(out1)
25 | out2 = self.pool2(out2)
26 |
27 | out3 = self.conv3(out2)
28 | out3 = self.pool3(out3)
29 |
30 | batchsz1 = out1.size(0)
31 | batchsz2 = out2.size(0)
32 | batchsz3 = out3.size(0)
33 | out1 = out1.view(batchsz1, -1)
34 | out2 = out2.view(batchsz2, -1)
35 | out3 = out3.view(batchsz3, -1)
36 |
37 | out = torch.cat([
38 | out1, out2, out3
39 | ], dim=1)
40 |
41 | # print(out1.shape, out2.shape, out3.shape)
42 |
43 | out_org = out
44 | out_finnal = self.outlayer(out)
45 |
46 | return out_org, out_finnal
47 |
48 |
49 | def main():
50 | net = MSCNN()
51 | tmp = torch.randn(64, 5, 24, 24)
52 | out1 = net(tmp)[1]
53 |
54 | print(out1.shape)
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
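Note: the hard-coded 6048 in nn.Linear(6048, 16) is the sum of the flattened outputs of the three pooling stages for a 24x24 input patch; a short check of the arithmetic:

    # the 3x3 convolutions use stride 1 and padding 1, so they preserve the spatial size;
    # each MaxPool2d(2) halves it: 24 -> 12 -> 6 -> 3
    out1 = 32 * 12 * 12   # 4608 features after pool1
    out2 = 32 * 6 * 6     # 1152 features after pool2
    out3 = 32 * 3 * 3     #  288 features after pool3
    print(out1 + out2 + out3)   # 6048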
/model/SSUN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from model.LSTM import lstm
5 | from model.MSCNN import MSCNN
6 |
7 |
8 | class SSUN(nn.Module):
9 | def __init__(self):
10 | super(SSUN, self).__init__()
11 | self.spatial = lstm()
12 | self.spectal = MSCNN()
13 |
14 | self.outlayer = nn.Linear(6176, 16)
15 |
16 | def forward(self, img, img_gt):
17 | out_spatial = self.spatial(img)[0]
18 | out_spectal = self.spectal(img_gt)[0]
19 |
20 | out = torch.cat([
21 | out_spatial, out_spectal
22 | ], dim=1)
23 |
24 | out1 = self.outlayer(out)
25 | out2 = self.spatial(img)[1]
26 | out3 = self.spectal(img_gt)[1]
27 |
28 | return out1, out2, out3
29 |
30 |
31 | def main():
32 | net = SSUN()
33 | img = torch.randn(64, 220, 24, 24)
34 | img_gt = torch.randn(64, 5, 24, 24)
35 |
36 | out = net(img, img_gt)[1]
37 | print(out.shape)
38 |
39 |
40 | if __name__ == '__main__':
41 | main()
42 |
--------------------------------------------------------------------------------
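Note: the fused layer nn.Linear(6176, 16) matches the concatenation in forward: 128 features from the LSTM branch (its hidden size) plus the 6048 multi-scale features from MSCNN, 128 + 6048 = 6176. The same shape check as main() above, but printing all three heads (the batch size of 2 is arbitrary):

    import torch
    from model.SSUN import SSUN

    net = SSUN()
    img = torch.randn(2, 220, 24, 24)     # raw bands for the LSTM branch
    img_pca = torch.randn(2, 5, 24, 24)   # 5-channel PCA patch for the MSCNN branch

    out_fused, out_lstm, out_mscnn = net(img, img_pca)
    print(out_fused.shape, out_lstm.shape, out_mscnn.shape)   # each torch.Size([2, 16])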
/tool/assessment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | from sklearn.metrics import roc_curve
4 | from sklearn.metrics import confusion_matrix
5 | from sklearn.metrics import cohen_kappa_score
6 | from sklearn.metrics import accuracy_score
7 |
8 |
9 | def accuracy_assessment(img_gt, changed_map):
10 | """
11 | assess accuracy of changed map based on ground truth
12 | """
13 |
14 | cm = []
15 | gt = []
16 | TP, TN, FP, FN = 0, 0, 0, 0
17 | esp = 1e-6
18 |
19 | height, width = changed_map.shape
20 | changed_map_ = np.reshape(changed_map, (-1,))
21 | img_gt_ = np.reshape(img_gt, (-1,))
22 |
23 | cm = np.ones((height * width,))
24 | cm[changed_map_ == 1] = 2
25 | cm[changed_map_ == 0] = 1
26 |
27 | gt = np.zeros((height * width,))
28 | gt[img_gt_ == 1] = 2
29 | gt[img_gt_ == 0] = 1
30 |
31 | # scikit-learn confusion-matrix API: sklearn.metrics.confusion_matrix
32 | conf_mat = confusion_matrix(y_true=gt, y_pred=cm, labels=[1, 2])
33 | kappa_co = cohen_kappa_score(y1=gt, y2=cm, labels=[1, 2])
34 |
35 | oa = np.sum(conf_mat.diagonal()) / np.sum(conf_mat)
36 |
37 | return conf_mat, oa, kappa_co
38 |
--------------------------------------------------------------------------------
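Note: accuracy_assessment expects two binary (0/1) maps of the same shape; it is not called from main.py in this snapshot, so here is a tiny self-contained sketch with made-up maps:

    import numpy as np
    from tool.assessment import accuracy_assessment

    gt = np.array([[0, 0, 1],
                   [1, 1, 0]])
    pred = np.array([[0, 1, 1],
                     [1, 1, 0]])

    conf_mat, oa, kappa = accuracy_assessment(gt, pred)
    print(conf_mat)     # 2x2 confusion matrix over labels {1: unchanged, 2: changed}
    print(oa, kappa)    # overall accuracy (5/6 here) and Cohen's kappa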
/tool/show.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def Predict_Label2Img(predict_label):
6 | # predict_label torch.Size([207400, 4])
7 | # predict_img (610, 340)
8 | # indian_pines=([145, 145]) 21025
9 | # KSC = ([512, 614])
10 | predict_img = torch.zeros([145, 145])
11 | num = predict_label.shape[0] # 207400
12 |
13 | for i in range(num):
14 | x = int(predict_label[i][1])
15 | y = int(predict_label[i][2])
16 | l = predict_label[i][3]
17 | predict_img[x][y] = l
18 |
19 | return predict_img
20 |
21 |
22 | if __name__ == '__main__':
23 | predict_label = torch.ones([21025, 4])
24 | predict_img = Predict_Label2Img(predict_label)
25 | print('over')
26 |
--------------------------------------------------------------------------------
/tool/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from data.HSI_data import batch_collate as collate_fn
4 | from torch.utils.data import DataLoader
5 |
6 | import data.data_preprocess as pre_fun
7 |
8 |
9 | def check_keys(model, pretrained_state_dict):
10 | ckpt_keys = set(pretrained_state_dict.keys()) # set() builds an unordered collection of unique keys, handy for set arithmetic
11 | model_keys = set(model.state_dict().keys()) # state_dict is a dict mapping each layer to its parameter tensors
12 | used_pretrained_keys = model_keys & ckpt_keys
13 | unused_pretrained_keys = ckpt_keys - model_keys
14 | missing_keys = model_keys - ckpt_keys
15 |
16 | print('Missing keys:{}'.format(len(missing_keys)))
17 | print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
18 | print('Used keys:{}'.format(len(used_pretrained_keys)))
19 |
20 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' # fail if no checkpoint key matches the model
21 |
22 | return True
23 |
24 |
25 | def remove_prefix(state_dict, prefix):
26 | print('remove prefix \'{}\''.format(prefix))
27 |
28 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x # strip the prefix if present; f is then called like an ordinary function
29 | # e.g. f(x)
30 | return {f(key): value for key, value in state_dict.items()}
31 |
32 |
33 | def load_model(model, pretrained_path, load_to_cpu):
34 | print('loading pretrained model from {}'.format(pretrained_path))
35 |
36 | if load_to_cpu == torch.device('cpu'):
37 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)['model']
38 | else:
39 | device = torch.cuda.current_device()
40 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))['model']
41 |
42 | if "state_dict" in pretrained_dict.keys():
43 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
44 | else:
45 | pretrained_dict = remove_prefix(pretrained_dict, 'module')
46 |
47 | check_keys(model, pretrained_dict)
48 | model.load_state_dict(pretrained_dict, strict=False)
49 |
50 | return model
51 |
52 |
53 | def test(test_data, origin_gt, model, device, cfg):
54 | num_workers = cfg['workers_num']
55 | gpu_num = cfg['gpu_num']
56 |
57 | batch_size = cfg['batch_size']
58 |
59 | model = load_model(model, cfg['model_weights'], device) # load the trained weights (.pth file)
60 | model.eval()
61 | model = model.to(device)
62 |
63 | # multi-GPU setup
64 | if gpu_num > 1 and cfg['gpu_train']:
65 | model = torch.nn.DataParallel(model).to(device) # use several GPUs via DataParallel
66 | else:
67 | model = model.to(device)
68 |
69 | batch_data = DataLoader(test_data, batch_size, shuffle=True, num_workers=num_workers,
70 | collate_fn=collate_fn, pin_memory=True)
71 |
72 | # accumulators
73 | predict_correct = 0
74 | label_num = 0
75 | predict_label = []
76 |
77 | for batch_idx, batch_sample in enumerate(batch_data):
78 |
79 | if len(batch_sample) > 3:
80 | img, target, indices, img_pca = batch_sample
81 | img_pca = img_pca.to(device)
82 | else:
83 | img, target, indices = batch_sample
84 |
85 | img = img.to(device)
86 |
87 | # inside torch.no_grad(), every tensor produced has requires_grad=False
88 | # no autograd graph is built for the forward pass, which saves a lot of GPU memory
89 |
90 | with torch.no_grad():
91 | prediction = model(img, img_pca)[0]
92 |
93 | label = prediction.softmax(dim=-1).cpu().argmax(dim=1, keepdim=True)
94 |
95 | if target.sum() > 0:
96 | predict_correct += label.eq(target.view_as(label)).sum().item()
97 | label_num += len(target)
98 |
99 | label = pre_fun.label_inverse_transform(label, origin_gt.long())
100 | predict_label.append(torch.cat([indices, label], dim=1))
101 |
102 | predict_label = torch.cat(predict_label, dim=0)
103 |
104 | if label_num > 0:
105 | print('OA {:.2f}%'.format(100 * predict_correct / label_num))
106 |
107 | return predict_label
108 |
--------------------------------------------------------------------------------
/tool/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import datetime
4 | import math
5 | import os
6 |
7 | from data.HSI_data import batch_collate as collate_fn
8 | from torch.utils.data import DataLoader
9 |
10 |
11 | # from model.LSTM import lstm
12 | # from model.MSCNN import MSCNN
13 |
14 |
15 | # learning-rate adjustment strategy
16 | def adjust_lr(lr_init, lr_gamma, optimizer, epoch, step_index):
17 | if epoch < 1:
18 | lr = 0.0001 * lr_init
19 | else:
20 | lr = lr_init * lr_gamma ** ((epoch - 1) // step_index)
21 |
22 | for param_group in optimizer.param_groups:
23 | param_group['lr'] = lr
24 |
25 | return lr
26 |
27 |
28 | # training
29 | def train(train_data, model, loss_fun, optimizer, device, cfg):
30 | '''
31 | called as: fun_train(train_data, model, loss_fun, optimizer, device, cfg_train)
32 | batch_data = DataLoader(train_data, batch_size, shuffle=True, num_workers=num_workers,
33 | collate_fn=collate_fn, pin_memory=True)
34 |
35 | cfg_train = cfg.train['train_model']
36 | train_data = fun_data(data_sets, cfg_data['train_data'])
37 | '''
38 | # (1) basic configuration
39 |
40 | num_workers = cfg['workers_num'] # number of DataLoader worker processes
41 | gpu_num = cfg['gpu_num'] # number of GPUs to use
42 |
43 | save_folder = cfg['save_folder'] # './weights/'
44 | save_name = cfg['save_name'] # 'model_CNN1D'
45 |
46 | lr_init = cfg['lr']
47 | lr_gamma = cfg['lr_gamma']
48 | lr_step = cfg['lr_step'] # decay step, in epochs
49 | lr_adjust = cfg['lr_adjust'] # set to True to use the step schedule
50 |
51 | epoch_size = cfg['epoch']
52 | batch_size = cfg['batch_size']
53 |
54 | if gpu_num > 1 and cfg['gpu_train']:
55 | # multi-GPU training
56 | model = torch.nn.DataParallel(model).to(device)
57 | # usage example:
58 | # model = model.cuda()
59 | # device_ids = [0, 1, 2, 3] # ids of the GPUs to use
60 | # model = torch.nn.DataParallel(model, device_ids=device_ids)
61 | else:
62 | model = model.to(device)
63 |
64 | # (2) load the model and start training
65 |
66 | model.train()
67 |
68 | # whether to resume from a previously trained model
69 | if cfg['reuse_model']:
70 |
71 | print('loading model')
72 |
73 | checkpoint = torch.load(cfg['reuse_file'], map_location=device) # load a file saved with torch.save()
74 | start_epoch = checkpoint['epoch']
75 |
76 | model_dict = model.state_dict() # state_dict is a dict mapping each layer to its parameter tensors
77 |
78 | pretrained_dict = {k: v for k, v in checkpoint['model'].items() if k in model_dict} # keep only the keys present in the current model
79 | model_dict.update(pretrained_dict)
80 |
81 | model.load_state_dict(model_dict) # load the merged (pretrained) weights
82 | else:
83 | start_epoch = 0
84 |
85 | batch_num = math.ceil(len(train_data) / batch_size) # number of batches per epoch, rounded up
86 | print('Start training!')
87 |
88 | for epoch in range(start_epoch + 1, epoch_size + 1):
89 |
90 | epoch_time0 = time.time() # epoch start time
91 |
92 | batch_data = DataLoader(train_data, batch_size, shuffle=True, \
93 | num_workers=num_workers, collate_fn=collate_fn, pin_memory=True)
94 |
95 | # whether the learning rate follows the step schedule
96 | if lr_adjust:
97 | lr = adjust_lr(lr_init, lr_gamma, optimizer, epoch, lr_step)
98 | else:
99 | lr = lr_init
100 |
101 | epoch_loss = 0
102 | predict_correct = 0
103 | label_num = 0
104 |
105 | for batch_idx, batch_sample in enumerate(batch_data): # iterate over the loaded batches
106 |
107 | iteration = (epoch - 1) * batch_num + batch_idx + 1
108 | batch_time0 = time.time()
109 |
110 | # (1) fetch images and labels
111 | if len(batch_sample) > 3:
112 | img, target, indices, img_pca = batch_sample
113 | img_pca = img_pca.to(device)
114 | else:
115 | img, target, indices = batch_sample
116 |
117 | img = img.to(device)
118 | target = target.to(device)
119 |
120 | # (2) forward pass
121 | # a single call returns the fused SSUN, LSTM and MSCNN predictions
122 | prediction_SSUN, prediction_lstm, prediction_MSCNN = model(img, img_pca)
123 |
124 |
125 | # (3) compute the losses
126 | loss1 = loss_fun(prediction_SSUN, target.long()) # target holds the class labels
127 | loss2 = loss_fun(prediction_lstm, target.long())
128 | loss3 = loss_fun(prediction_MSCNN, target.long())
129 |
130 | loss = loss1 + loss2 + loss3
131 |
132 | # (4) optimizer step / back-propagation
133 | optimizer.zero_grad() # reset the gradients
134 | loss.backward() # back-propagate to compute every parameter's gradient
135 | optimizer.step() # one gradient-descent update
136 |
137 | batch_time1 = time.time()
138 | batch_time = batch_time1 - batch_time0 # time spent on one iteration
139 |
140 | # estimated time of Arrival
141 | batch_eta = batch_time * (batch_num - batch_idx)
142 | epoch_eta = int(batch_time * (epoch_size - epoch) * batch_num + batch_eta)
143 |
144 | epoch_loss += loss.item() # item() returns a Python float
145 |
146 | predict_label = prediction_SSUN.detach().softmax(dim=-1).argmax(dim=1,
147 | keepdim=True) # detach() returns a tensor cut off from the autograd graph
148 | # the detached tensor receives no gradient updates;
149 | # it only differs in having requires_grad=False, so no gradient is ever computed for it
150 | predict_correct += predict_label.eq(target.view_as(predict_label)).sum().item() # running count of correct predictions
151 | label_num += len(target)
152 |
153 | epoch_time1 = time.time()
154 | epoch_time = epoch_time1 - epoch_time0 # time spent on one epoch
155 | epoch_eta = int(epoch_time * (epoch_size - epoch))
156 |
157 | # print the training statistics
158 | print('Epoch: {}/{} || lr: {} || loss: {} || Train acc: {:.2f}% || '
159 | 'Epoch time: {:.4f}s || Epoch ETA: {}'
160 | .format(epoch, epoch_size, lr, epoch_loss / batch_num, 100 * predict_correct / label_num,
161 | epoch_time, str(datetime.timedelta(seconds=epoch_eta))
162 | )
163 | )
164 |
165 | if not os.path.exists(save_folder):
166 | os.makedirs(save_folder) # create the directory recursively
167 |
168 | # save the final model
169 | save_model = dict(
170 | model=model.state_dict(),
171 | epoch=epoch_size
172 | )
173 | torch.save(save_model, os.path.join(save_folder, save_name + '_Final.pth'))
174 |
--------------------------------------------------------------------------------