├── .idea
├── FAC0810.iml
├── inspectionProfiles
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
├── vcs.xml
└── workspace.xml
├── README.md
├── attention.py
├── cross.py
├── data_F.py
├── main_c.py
├── models_c.py
└── ops.py
/.idea/FAC0810.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 | 1641967545753
43 |
44 |
45 | 1641967545753
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | The code in this toolbox implements the "Multi-attentive hierarchical dense fusion net for fusion classification of hyperspectral and LiDAR data".
3 |
4 | Please cite the paper below if you find this code useful for your research.
5 |
6 | @article{WANG20221,
7 |
8 | title = {Multi-attentive hierarchical dense fusion net for fusion classification of hyperspectral and LiDAR data},
9 |
10 | journal = {Information Fusion},
11 |
12 | volume = {82},
13 |
14 | pages = {1-18},
15 |
16 | year = {2022},
17 |
18 | issn = {1566-2535},
19 |
20 | doi = {https://doi.org/10.1016/j.inffus.2021.12.008},
21 |
22 | url = {https://www.sciencedirect.com/science/article/pii/S156625352100258X},
23 |
24 | author = {Xianghai Wang and Yining Feng and Ruoxi Song and Zhenhua Mu and Chuanming Song},
25 |
26 | }
27 |
28 | System-specific notes
29 | The code was tested with Python 3.7 and Keras 2.3.1.
30 |
31 | How to use it?
32 | Run main_c.py directly to reproduce the results.
33 |
34 | To run the code on your own data, change the input settings (e.g., the data path, file names, and class/band counts at the top of data_F.py) and tune the parameters accordingly.
35 |
36 | If you encounter any bugs while using this code, please do not hesitate to contact us (SYFyn@outlook.com).
37 |
--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Activation, Conv2D
2 | import keras.backend as K1
3 | import tensorflow as tf
4 | from keras.layers import Layer
5 |
6 |
class PAM(Layer):
    """Position (spatial) attention module with a learnable residual weight.

    For an input feature map ``x`` of shape (batch, h, w, filters) the layer
    builds three 1x1 projections (two with ``filters // 8`` channels, one
    with ``filters`` channels). It forms a softmax-normalised
    (h*w) x (h*w) affinity from the first two, applies it to the third and
    returns ``gamma * context + x``. ``gamma`` is a single trainable scalar
    initialised to zero, so the layer starts out as an identity map.
    """

    def __init__(self, **kwargs):
        super(PAM, self).__init__(**kwargs)

    def build(self, input_shape):
        filters = int(input_shape[-1])
        # The projections are created once, here, and kept as attributes so
        # they are tracked as sub-layers. Creating them inside call() left
        # their kernels outside this layer's weights: they were never trained
        # and were not written by save_weights / restored by load_weights.
        self.conv_b = Conv2D(filters // 8, 1, use_bias=False,
                             kernel_initializer='he_normal')
        self.conv_c = Conv2D(filters // 8, 1, use_bias=False,
                             kernel_initializer='he_normal')
        self.conv_d = Conv2D(filters, 1, use_bias=False,
                             kernel_initializer='he_normal')
        for conv in (self.conv_b, self.conv_c, self.conv_d):
            conv.build(input_shape)

        self.gamma = self.add_weight(shape=(1, ),
                                     initializer='zeros',
                                     regularizer=None,
                                     constraint=None,
                                     name='gamma',
                                     trainable=True)
        self.built = True

    def compute_output_shape(self, input_shape):
        # Residual attention: output shape equals input shape.
        return input_shape

    def call(self, input):
        _, h, w, filters = input.get_shape().as_list()
        reduced = filters // 8

        # (batch, h*w, reduced) and its transpose (batch, reduced, h*w).
        vec_b = K1.reshape(self.conv_b(input), (-1, h * w, reduced))
        vec_cT = tf.transpose(
            K1.reshape(self.conv_c(input), (-1, h * w, reduced)), (0, 2, 1))
        # Pixel-to-pixel affinity, softmax over the last axis.
        softmax_bcT = Activation('softmax')(K1.batch_dot(vec_b, vec_cT))

        vec_d = K1.reshape(self.conv_d(input), (-1, h * w, filters))
        bcTd = K1.reshape(K1.batch_dot(softmax_bcT, vec_d), (-1, h, w, filters))

        return self.gamma * bcTd + input
48 |
49 |
50 |
class DPAM(Layer):
    """Dual position attention fusing two stacked feature maps.

    The input is a 5-D tensor whose last axis holds the two maps. Only
    indices 0 and 1 are read; each map has shape (batch, h, w, filters). One
    softmax affinity is built from each map via 1x1 projections
    (``filters // 8`` channels). Both affinities are applied to a 1x1
    projection of the second map, and the result is
    ``second + gamma * ctx1 + gamma * ctx2``. A single trainable scalar
    ``gamma``, initialised to zero, weights both terms.
    """

    def __init__(self, **kwargs):
        super(DPAM, self).__init__(**kwargs)

    def build(self, input_shape):
        # Shape of one of the two stacked maps: (batch, h, w, filters).
        map_shape = tuple(input_shape[:4])
        filters = int(map_shape[-1])
        # Projections are created here, not in call(), so that their kernels
        # are tracked as weights of this layer (trained and saved/restored).
        self.conv_b = Conv2D(filters // 8, 1, use_bias=False,
                             kernel_initializer='he_normal')
        self.conv_c = Conv2D(filters // 8, 1, use_bias=False,
                             kernel_initializer='he_normal')
        self.conv_b2 = Conv2D(filters // 8, 1, use_bias=False,
                              kernel_initializer='he_normal')
        self.conv_c2 = Conv2D(filters // 8, 1, use_bias=False,
                              kernel_initializer='he_normal')
        self.conv_d = Conv2D(filters, 1, use_bias=False,
                             kernel_initializer='he_normal')
        for conv in (self.conv_b, self.conv_c, self.conv_b2, self.conv_c2,
                     self.conv_d):
            conv.build(map_shape)

        self.gamma = self.add_weight(shape=(1, ),
                                     initializer='zeros',
                                     regularizer=None,
                                     constraint=None,
                                     name='gamma',
                                     trainable=True)
        self.built = True

    def compute_output_shape(self, input_shape):
        # The trailing "pair" axis is consumed; output matches one map.
        return (input_shape[0], input_shape[1], input_shape[2], input_shape[3])

    def call(self, input):
        input1 = input[:, :, :, :, 0]
        input2 = input[:, :, :, :, 1]
        _, h, w, filters = input1.get_shape().as_list()
        reduced = filters // 8

        def affinity(query_conv, key_conv, x):
            # Softmax-normalised (h*w) x (h*w) affinity of map x.
            query = K1.reshape(query_conv(x), (-1, h * w, reduced))
            key_t = tf.transpose(
                K1.reshape(key_conv(x), (-1, h * w, reduced)), (0, 2, 1))
            return Activation('softmax')(K1.batch_dot(query, key_t))

        softmax_bcT = affinity(self.conv_b, self.conv_c, input1)
        softmax_bcT2 = affinity(self.conv_b2, self.conv_c2, input2)

        # Both affinities attend over the projection of the second map.
        vec_d = K1.reshape(self.conv_d(input2), (-1, h * w, filters))
        bcTd = K1.reshape(K1.batch_dot(softmax_bcT, vec_d), (-1, h, w, filters))
        bcTd2 = K1.reshape(K1.batch_dot(softmax_bcT2, vec_d), (-1, h, w, filters))

        return input2 + self.gamma * bcTd + self.gamma * bcTd2
105 |
106 |
107 |
--------------------------------------------------------------------------------
/cross.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import keras.layers as L
4 |
5 |
def cross(args):
    """Stack two same-shaped tensors along a new trailing axis.

    Intended as the function of a ``Lambda`` layer. ``args`` is ``[a, b]``,
    and the result has shape ``a.shape + (2,)``, with ``a`` at index 0 and
    ``b`` at index 1 of the new axis.
    """
    a = args[0]
    b = args[1]
    # Equivalent to expand_dims(-1) on both inputs followed by a concatenation
    # on the last axis. This avoids building a nested Keras layer inside the
    # Lambda and drops the leftover debug print of the output shape.
    return tf.stack([a, b], axis=-1)
--------------------------------------------------------------------------------
/data_F.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | import tifffile as tiff
4 | import os
5 | import cv2
6 | import scipy.io as sio
7 | from keras.utils.np_utils import to_categorical
8 | from scipy.cluster.vq import whiten
9 |
10 |
# ---------------------------------------------------------------------------
# Dataset configuration (active). The commented blocks further down hold
# alternative settings for other datasets; swap them in by hand.
# ---------------------------------------------------------------------------
# Number of classes. Ground-truth maps use 1..NUM_CLASS for classes and 0 for
# unlabelled pixels (see creat_trainf / make_cTestf).
NUM_CLASS = 7
# Directory holding the input .mat files (presumably the Augsburg scene --
# TODO confirm).
PATH = './data/Aug'
# Output directory for the generated .npy patch/label/index arrays.
SAVA_PATH = './file/'
# Mini-batch size passed to model.fit in main_c.py.
BATCH_SIZE = 100
# Patch radius: every sample is a (2*r+1) x (2*r+1) window around a pixel.
r = 5

# File names and the variable names read from each .mat file (read_mat uses
# the second name as the key into the loadmat result).
HSName='data_HS_LR.mat'  # hyperspectral cube file
data_HS_LR ='data_HS_LR'
lidarName='data_DSM.mat'  # LiDAR raster file
data_DSM ='data_DSM'
gth_train = 'TrainImage.mat'  # training label map
TrainImage ='TrainImage'
gth_test = 'TestImage.mat'  # test label map
TestImage = 'TestImage'
# Channel counts of the LiDAR and hyperspectral inputs (used for the Input
# shapes in models_c.merge_branch).
lchn = 1
hchn = 180

# --- Alternative: Houston (.tif inputs; needs the read_image / tiff.imread
# --- code paths that are commented out in creat_trainf / make_cTestf).
# NUM_CLASS = 15
# PATH = './data/houston'
# SAVA_PATH = './file/'
# BATCH_SIZE = 100
# r = 5
#
#
# HSName='Houston_HS.tif'#height_test.tif'
# lidarName='Houston_Lidar.tif'#height_test.tif'
# gth_train = 'Houston_train.tif'
# gth_test = 'Houston_test.tif'
# lchn = 1
# hchn = 144

# --- Alternative: MUUFL (.tif inputs, two LiDAR channels).
# NUM_CLASS = 11
# PATH = './data/MUUFL'
# SAVA_PATH = './file/'
# BATCH_SIZE = 100
# r = 5
#
#
# HSName='MUUFL_HS.tif'#height_test.tif'
# lidarName='MUUFL_Lidar.tif'#height_test.tif'
# gth_train = 'MUUFL_train.tif'
# gth_test = 'MUUFL_test.tif'
# lchn = 2
# hchn = 64

# --- Alternative: Trento (.tif inputs).
# NUM_CLASS = 6
# PATH = './data/Trento'
# SAVA_PATH = './file/'
# BATCH_SIZE = 64
# r = 5
#
#
# HSName='Trento_HSI.tif'#height_test.tif'
# lidarName='Trento_Lidar.tif'#height_test.tif'
# gth_train = 'Trento_train.tif'
# gth_test = 'Trento_test.tif'
# lchn = 1
# hchn = 63



# Create the output directory at import time if it does not exist yet.
if not os.path.exists(SAVA_PATH):
    os.mkdir(SAVA_PATH)
74 |
75 |
def read_image(filename):
    """Read the TIFF file ``filename`` and return its pixels as float32."""
    return np.asarray(tiff.imread(filename), dtype=np.float32)
80 |
def read_mat(path, file_name, data_name):
    """Load one variable from a MATLAB file.

    Args:
        path: directory containing the file.
        file_name: name of the ``.mat`` file inside ``path``.
        data_name: key of the variable to extract.

    Returns:
        The variable, converted to a NumPy array.
    """
    contents = sio.loadmat(os.path.join(path, file_name))
    return np.array(contents[data_name])
85 |
86 |
def samele_wise_normalization(data):
    """Min-max scale ``data`` into the range [0, 1].

    A constant input (max == min) cannot be scaled. In that case a float32
    array of the same shape filled with 1e-6 is returned instead.
    """
    lowest = np.min(data)
    highest = np.max(data)
    if lowest != highest:
        return 1.0 * (data - lowest) / (highest - lowest)
    return np.ones_like(data, dtype=np.float32) * 1e-6
99 |
100 |
def sample_wise_standardization(data):
    """Standardise the whole array to zero mean and unit variance.

    The divisor is floored at 1/sqrt(N), where N is the number of elements,
    so a (near-)constant input does not cause a division by zero.
    """
    import math
    n_elements = np.size(data) * 1.0
    std_floor = 1.0 / math.sqrt(n_elements)
    centred = data - np.mean(data)
    return centred / max(np.std(data), std_floor)
108 |
109 |
def gth2mask(gth, num_class=None):
    """One-hot encode an integer label map.

    Args:
        gth: label map, e.g. of shape (h, w).
        num_class: number of one-hot channels; defaults to the module-level
            NUM_CLASS.

    Returns:
        int8 array of shape ``gth.shape + (num_class,)``. Channel ``c`` is 1
        exactly where ``gth == c``. Labels outside [0, num_class) give an
        all-zero vector.

    Note: channel 0 corresponds to label 0. The ground-truth maps elsewhere in
    this module use 0 for unlabelled pixels and 1..NUM_CLASS for classes, so
    shift the labels (e.g. ``gth - 1``) first if class 1 should map to
    channel 0.
    """
    if num_class is None:
        num_class = NUM_CLASS
    gth = np.asarray(gth)
    # Broadcast-compare every pixel against every class id at once.
    return (gth[..., np.newaxis] == np.arange(num_class)).astype(np.int8)
118 |
def down_sampling_hsi(hsi, scale=2):
    """Simulate a lower-resolution image at the original size.

    The image is Gaussian-blurred with a 3x3 kernel, shrunk by ``scale`` with
    bicubic interpolation, and then enlarged back to its original width and
    height (also bicubic).
    """
    smoothed = cv2.GaussianBlur(hsi, (3, 3), 0)
    height, width = smoothed.shape[0], smoothed.shape[1]
    shrunk = cv2.resize(smoothed,
                        (width // scale, height // scale),
                        interpolation=cv2.INTER_CUBIC)
    return cv2.resize(shrunk, (width, height), interpolation=cv2.INTER_CUBIC)
127 |
def creat_trainf(validation=False):
    """Build the augmented training (or validation) patch set and save it as .npy.

    Reads the HSI cube, the LiDAR raster and the training label map named in
    the module configuration, pads them by ``r`` pixels, standardises each
    modality, and cuts a (2*r+1) x (2*r+1) patch around every selected
    labelled pixel.

    Args:
        validation: if False, use the first ``per`` (89 %) of each class's
            labelled pixels and write train_Xh/train_Xl/train_Y.npy. If True,
            use the remaining pixels and write val_Xh/val_Xl/val_Y.npy.
            Files go to SAVA_PATH; nothing is returned.

    Each selected pixel yields four samples with the same label:
      1. the patch,
      2. a flip along axis 0,
      3. a flip along axis 1 of a noisy copy (Gaussian noise, sigma 0.01 for
         HSI and 0.03 for LiDAR),
      4. a rotation by a random multiple of 90 degrees (the same k is used
         for both modalities).
    Labels are stored 0-based (ground-truth value - 1). The validation split
    is augmented in the same way.
    """
    hsi = read_mat(PATH,HSName,data_HS_LR)
    lidar = read_mat(PATH,lidarName,data_DSM)
    gth = read_mat(PATH,gth_train,TrainImage)
    # .tif alternative:
    # hsi = read_image(os.path.join(PATH, HSName))
    # lidar = read_image(os.path.join(PATH, lidarName))
    # gth = tiff.imread(os.path.join(PATH, gth_train))

    # Pad so that patches centred on border pixels stay inside the arrays.
    # The label map is padded with 0 (unlabelled), so padded pixels are never
    # selected below.
    hsi = np.pad(hsi, ((r, r), (r, r), (0, 0)), 'symmetric')
    if len(lidar.shape) == 2:
        lidar = np.pad(lidar, ((r, r), (r, r)), 'symmetric')
    if len(lidar.shape) == 3:
        lidar = np.pad(lidar, ((r, r), (r, r), (0, 0)), 'symmetric')
    gth = np.pad(gth, ((r, r), (r, r)), 'constant', constant_values=(0, 0))
    # gth = np.pad(gth, ((r, r), (r, r)), 'constant', constant_values=(0, 0))
    # Fraction of each class's pixels (in np.where / raster order) used for
    # training; the rest form the validation split.
    # NOTE(review): this is an ordered (spatial), not a random, split --
    # confirm this is intended.
    per = 0.89


    # Standardise each modality over the whole (padded) array.
    # lidar = samele_wise_normalization(lidar)
    lidar = sample_wise_standardization(lidar)
    # hsi = samele_wise_normalization(hsi)
    hsi = sample_wise_standardization(hsi)
    # hsi=whiten(hsi)

    Xh = []
    Xl = []
    Y = []
    for c in range(1, NUM_CLASS + 1):
        idx, idy = np.where(gth == c)
        if not validation:
            idx = idx[:int(per * len(idx))]
            idy = idy[:int(per * len(idy))]
        else:
            idx = idx[int(per * len(idx)):]
            idy = idy[int(per * len(idy)):]
        # Reseeded for every class. This makes the shuffle and augmentation
        # reproducible, and also fixes the RNG state seen by the final
        # permutation below.
        np.random.seed(820)
        ID = np.random.permutation(len(idx))
        idx = idx[ID]
        idy = idy[ID]
        for i in range(len(idx)):
            # Patches centred on (idx[i], idy[i]) in padded coordinates.
            tmph = hsi[idx[i] - r:idx[i] + r + 1, idy[i] - r:idy[i] + r + 1, :]
            tmpl = lidar[idx[i] - r:idx[i] + r +
                         1, idy[i] - r:idy[i] + r + 1]
            tmpy = gth[idx[i], idy[i]] - 1  # 0-based label
            # HSI: original, flip axis 0, noisy flip axis 1, random rot90.
            # The order of RNG calls (normal, randint, normal) matters for
            # reproducibility.
            Xh.append(tmph)
            Xh.append(np.flip(tmph, axis=0))
            noise = np.random.normal(0.0, 0.01, size=tmph.shape)
            Xh.append(np.flip(tmph + noise, axis=1))
            k = np.random.randint(4)
            Xh.append(np.rot90(tmph, k=k))


            # LiDAR: the same four variants (larger noise, same rotation k).
            Xl.append(tmpl)
            Xl.append(np.flip(tmpl, axis=0))
            noise = np.random.normal(0.0, 0.03, size=tmpl.shape)
            Xl.append(np.flip(tmpl + noise, axis=1))
            Xl.append(np.rot90(tmpl, k=k))




            # One label per augmented variant.
            Y.append(tmpy)
            Y.append(tmpy)
            Y.append(tmpy)
            Y.append(tmpy)


    # Shuffle all samples (across classes) with one permutation shared by
    # Xh, Xl and Y.
    index = np.random.permutation(len(Xh))
    Xh = np.asarray(Xh, dtype=np.float32)
    Xl = np.asarray(Xl, dtype=np.float32)
    Y = np.asarray(Y, dtype=np.int8)
    Xh = Xh[index, ...]
    # 2-D LiDAR patches get an explicit trailing channel axis.
    if len(Xl.shape) == 3:
        Xl = Xl[index, ..., np.newaxis]
    elif len(Xl.shape) == 4:
        Xl = Xl[index, ...]
    Y = Y[index]
    print('train hsi data shape:{},train lidar data shape:{}'.format(
        Xh.shape, Xl.shape))
    if not validation:
        np.save(os.path.join(SAVA_PATH, 'train_Xh.npy'), Xh)
        np.save(os.path.join(SAVA_PATH, 'train_Xl.npy'), Xl)
        np.save(os.path.join(SAVA_PATH, 'train_Y.npy'), Y)
    else:
        np.save(os.path.join(SAVA_PATH, 'val_Xh.npy'), Xh)
        np.save(os.path.join(SAVA_PATH, 'val_Xl.npy'), Xl)
        np.save(os.path.join(SAVA_PATH, 'val_Y.npy'), Y)
214 |
def make_AL():
    """Extract HSI/LiDAR patch pairs for the pixels where gthTE - gthTR != 0.

    The pixels are presumably the candidate pool for active learning ("AL");
    TODO confirm. The pixels are visited in a random order. Each HSI and
    LiDAR patch of a pair is cut at the SAME pixel, and the unpadded
    (row, col) coordinates are saved in that order.

    Side effects: writes hsiAL.npy, lidarAL.npy and indexAL.npy to SAVA_PATH.

    Returns:
        (Xh, Xl): float32 patch arrays. Xl gets a trailing channel axis when
        the LiDAR patches are 2-D.
    """
    hsi = read_image(os.path.join(PATH, HSName))
    lidar = read_image(os.path.join(PATH, lidarName))
    gthTR = tiff.imread(os.path.join(PATH, gth_train))
    gthTE = tiff.imread(os.path.join(PATH, gth_test))

    # NOTE(review): if the label maps are unsigned integers, pixels labelled
    # only in the train map wrap around to large positive values here and are
    # selected below as well -- confirm this is intended.
    gth = gthTE - gthTR
    hsi = np.pad(hsi, ((r, r), (r, r), (0, 0)), 'symmetric')
    if len(lidar.shape) == 2:
        lidar = np.pad(lidar, ((r, r), (r, r)), 'symmetric')
    if len(lidar.shape) == 3:
        lidar = np.pad(lidar, ((r, r), (r, r), (0, 0)), 'symmetric')
    gth = np.pad(gth, ((r, r), (r, r)), 'constant', constant_values=(0, 0))

    lidar = sample_wise_standardization(lidar)
    hsi = sample_wise_standardization(hsi)

    idx, idy = np.where(gth != 0)
    ID = np.random.permutation(len(idx))
    # Shuffle the coordinates once and use them for BOTH modalities.
    # Previously the LiDAR patch used the unshuffled idx[i]/idy[i], so each
    # HSI patch was paired with a LiDAR patch from a different pixel.
    rows = idx[ID]
    cols = idy[ID]
    Xh = []
    Xl = []
    for row, col in zip(rows, cols):
        Xh.append(hsi[row - r:row + r + 1, col - r:col + r + 1, :])
        Xl.append(lidar[row - r:row + r + 1, col - r:col + r + 1])
    Xh = np.asarray(Xh, dtype=np.float32)
    Xl = np.asarray(Xl, dtype=np.float32)
    if len(Xl.shape) == 3:
        Xl = Xl[..., np.newaxis]
    np.save(os.path.join(SAVA_PATH, 'hsiAL.npy'), Xh)
    np.save(os.path.join(SAVA_PATH, 'lidarAL.npy'), Xl)
    # Coordinates in the unpadded image, in the same order as the patches.
    np.save(os.path.join(SAVA_PATH, 'indexAL.npy'), [rows - r, cols - r])
    return Xh, Xl
255 |
def make_cTestf():
    """Build the test set: one HSI/LiDAR patch pair per labelled test pixel.

    Every pixel with a non-zero value in the test label map is visited in a
    shuffled order (the global NumPy RNG is reseeded with 820 first).

    Side effects: writes hsi.npy, lidar.npy and index.npy to SAVA_PATH.
    lidar.npy is saved before the channel axis is added. index.npy holds the
    unpadded (row, col) coordinates in patch order; cvt_map reads it.

    Returns:
        (Xl, Xh): LiDAR patches first, then HSI patches, both float32. Xl gets
        a trailing channel axis when the LiDAR patches are 2-D.
    """
    HS = read_mat(PATH, HSName, data_HS_LR)
    lidar = read_mat(PATH, lidarName, data_DSM)
    gth = read_mat(PATH, gth_test, TestImage)
    # For .tif inputs use read_image(os.path.join(PATH, ...)) for HS / lidar
    # and tiff.imread(os.path.join(PATH, gth_test)) for the labels instead.

    HS = np.pad(HS, ((r, r), (r, r), (0, 0)), 'symmetric')
    if len(lidar.shape) == 2:
        lidar = np.pad(lidar, ((r, r), (r, r)), 'symmetric')
    if len(lidar.shape) == 3:
        lidar = np.pad(lidar, ((r, r), (r, r), (0, 0)), 'symmetric')
    gth = np.pad(gth, ((r, r), (r, r)), 'constant', constant_values=(0, 0))

    lidar = sample_wise_standardization(lidar)
    HS = sample_wise_standardization(HS)

    idx, idy = np.where(gth != 0)
    np.random.seed(820)
    ID = np.random.permutation(len(idx))
    rows = idx[ID]
    cols = idy[ID]
    Xh = np.asarray(
        [HS[row - r:row + r + 1, col - r:col + r + 1, :]
         for row, col in zip(rows, cols)],
        dtype=np.float32)
    Xl = np.asarray(
        [lidar[row - r:row + r + 1, col - r:col + r + 1]
         for row, col in zip(rows, cols)],
        dtype=np.float32)
    np.save(os.path.join(SAVA_PATH, 'hsi.npy'), Xh)
    np.save(os.path.join(SAVA_PATH, 'lidar.npy'), Xl)
    np.save(os.path.join(SAVA_PATH, 'index.npy'), [rows - r, cols - r])
    if len(Xl.shape) == 3:
        Xl = Xl[..., np.newaxis]
    return Xl, Xh
299 |
--------------------------------------------------------------------------------
/main_c.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import keras as K
3 | import keras.layers as L
4 | import numpy as np
5 | import os
6 | import random
7 | import time
8 | import h5py
9 | import argparse
10 |
11 | from data_F import *
12 | from models_c import *
13 | from ops import *
14 | from keras.callbacks import ModelCheckpoint
15 | from keras.callbacks import ReduceLROnPlateau
16 |
17 | from keras.callbacks import EarlyStopping
18 | from keras.callbacks import TensorBoard
19 | from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
20 | from keras.models import load_model
21 | from keras.applications.resnet50 import ResNet50
22 | import keras.backend as K1
23 | import math
24 | # save weights
25 |
_TFBooard = 'logs/events/'

parser = argparse.ArgumentParser()
parser.add_argument('--modelname', type=str,
                    default='my_model_weights.h5', help='final model save name')
parser.add_argument('--epochs', type=int,
                    default=50, help='number of epochs')
args = parser.parse_args()

# Weights file written by the checkpoint in train_merge() and read back in
# test(). It used to be hard-coded, which silently ignored --modelname; the
# default value is unchanged.
_weights_f = args.modelname

os.makedirs('logs/weights/', exist_ok=True)
os.makedirs(_TFBooard, exist_ok=True)
42 |
43 |
44 |
45 |
46 |
def train_merge(model):
    """Regenerate the train/validation patch files and fit ``model`` on them.

    The model receives three inputs, in order:
      - the HSI patches,
      - the spectrum of each patch's centre pixel, shaped (n, bands, 1),
      - the LiDAR patches.
    The epoch with the best ``val_loss`` is checkpointed to ``_weights_f``.
    """
    creat_trainf(validation=False)
    creat_trainf(validation=True)

    def _load(name):
        # creat_trainf writes its outputs into SAVA_PATH.
        return np.load(os.path.join(SAVA_PATH, name))

    Xh_train = _load('train_Xh.npy')
    Xh_val = _load('val_Xh.npy')
    Xl_train = _load('train_Xl.npy')
    Xl_val = _load('val_Xl.npy')

    # Pin the one-hot width to NUM_CLASS. Without num_classes, to_categorical
    # infers it from the largest label present, which mismatches the model
    # output if a split happens to lack the last class.
    Y_train = K.utils.np_utils.to_categorical(_load('train_Y.npy'),
                                              num_classes=NUM_CLASS)
    Y_val = K.utils.np_utils.to_categorical(_load('val_Y.npy'),
                                            num_classes=NUM_CLASS)

    print('Xl_train', Xl_train.shape)
    print('Xl_val', Xl_val.shape)
    print('Xh_train', Xh_train.shape)
    print('Xh_val', Xh_val.shape)
    print('Y_val', Y_val.shape)
    print('Y_train', Y_train.shape)

    model_ckt = ModelCheckpoint(filepath=_weights_f, monitor='val_loss',
                                verbose=1, save_best_only=True)

    model.fit([Xh_train, Xh_train[:, r, r, :, np.newaxis], Xl_train], Y_train,
              batch_size=BATCH_SIZE, epochs=args.epochs, callbacks=[model_ckt],
              validation_data=([Xh_val, Xh_val[:, r, r, :, np.newaxis], Xl_val],
                               Y_val))

    print(args.modelname)
    print(_weights_f)
88 |
89 |
def test(network, mode=None):
    """Evaluate the saved weights on the test set and print OA and kappa.

    Args:
        network: model family to evaluate; only 'merge' is supported.
        mode: unused; kept for backward compatibility.

    Returns:
        (acc, kappa): overall accuracy in percent and Cohen's kappa, as
        returned by cvt_map.

    Raises:
        ValueError: if ``network`` is not 'merge'. Previously this surfaced
            as an UnboundLocalError on ``model``.
    """
    if network != 'merge':
        raise ValueError(
            "unsupported network {!r}; expected 'merge'".format(network))
    model = merge_branch()
    model.load_weights(_weights_f)
    Xl, Xh = make_cTestf()
    pred = model.predict([Xh, Xh[:, r, r, :, np.newaxis], Xl])
    acc, kappa = cvt_map(pred, show=False)
    print('acc: {:.2f}% Kappa: {:.4f}'.format(acc, kappa))
    return acc, kappa
98 |
99 |
100 |
def main():
    """Build the fusion model, train it, evaluate it and report the wall time."""
    net = merge_branch()
    net.summary()
    t0 = time.time()
    train_merge(net)
    test('merge')
    elapsed = time.time() - t0
    print('elapsed time:{:.2f}s'.format(elapsed))


if __name__ == '__main__':
    main()
117 |
118 |
--------------------------------------------------------------------------------
/models_c.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import tensorflow as tf
4 | import keras as K
5 | import keras.layers as L
6 | import numpy as np
7 | import os
8 | import time
9 | import h5py
10 | import argparse
11 | import matplotlib.pyplot as plt
12 | from attention import *
13 | from data_F import *
14 | from keras.callbacks import ModelCheckpoint
15 | from keras.callbacks import EarlyStopping
16 | from keras.models import Model
17 | from keras.layers import Input, Activation, Conv2D, Dropout
18 | from keras.layers import MaxPooling2D, BatchNormalization
19 | from keras.layers import UpSampling2D
20 | from keras.layers import concatenate
21 | from keras.layers import add
22 | from attention import PAM, DPAM
23 | from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense,Activation,Layer
24 | import keras.backend as K1
25 | from cross import *
26 | from keras.optimizers import RMSprop
27 | # ===================cascade net=============
28 |
def attention_block_3(inputs, feature_cnt, dim):
    """Reweight the rows of a (feature_cnt, dim) feature map with learned attention.

    A softmax Dense layer over the flattened input produces feature_cnt*dim
    scores. These are reshaped to (feature_cnt, dim), summed over the last
    axis to one weight per row, repeated ``dim`` times, permuted back to
    (feature_cnt, dim) and multiplied element-wise into ``inputs``.

    The Lambda and Permute layers carry fixed names ('attention' and
    'attention_vec'), so this block can appear only once per model.
    """
    flat = Flatten()(inputs)
    scores = Dense(feature_cnt * dim, activation='softmax')(flat)
    scores = L.Reshape((feature_cnt, dim,))(scores)
    row_weights = L.Lambda(lambda t: K1.sum(t, axis=2), name='attention')(scores)
    tiled = L.RepeatVector(dim)(row_weights)
    a_probs = L.Permute((2, 1), name='attention_vec')(tiled)
    return L.Multiply()([inputs, a_probs])
38 |
def small_cnn_branch(input_tensor, small_mode=True):
    """HSI spatial branch: conv stack, position attention and a dense head.

    The layers, in order:
      - four Conv2D -> BatchNorm -> LeakyReLU(0.2) stages with 'same'
        padding (two 3x3 convolutions with 100 filters, then two 1x1 with
        200),
      - a PAM block,
      - 2x2 max-pooling,
      - two Dense -> ReLU -> Dropout(0.5) stages with 1024 and 512 units,
        applied along the last axis,
      - a final Flatten.

    ``small_mode`` is accepted for API compatibility but is not used.
    """
    widths = [32, 64, 100, 200, 256]
    conv_specs = [(widths[2], (3, 3)), (widths[2], (3, 3)),
                  (widths[3], (1, 1)), (widths[3], (1, 1))]
    x = input_tensor
    for n_filters, kernel in conv_specs:
        x = L.Conv2D(n_filters, kernel, padding='same')(x)
        x = L.BatchNormalization(axis=-1)(x)
        x = L.advanced_activations.LeakyReLU(alpha=0.2)(x)
    x = PAM()(x)
    x = L.MaxPool2D(pool_size=(2, 2), padding='same')(x)
    for units in (1024, 512):
        x = L.Dense(units)(x)
        x = L.Activation('relu')(x)
        x = L.Dropout(0.5)(x)
    return L.Flatten()(x)
63 |
64 |
65 |
66 |
def small_cnn_branch_front(input_tensor):
    """Convolutional front end; merge_branch calls it separately for HSI and LiDAR.

    Five Conv2D -> BatchNorm -> LeakyReLU(0.2) stages with 'same' padding:
    two 3x3 convolutions with 100 filters, then three 1x1 with 200. The
    spatial size is preserved and the output has 200 channels. Each call
    creates new layers, so weights are not shared between calls.
    """
    widths = [32, 64, 100, 200, 256]
    conv_specs = [(widths[2], (3, 3)), (widths[2], (3, 3)),
                  (widths[3], (1, 1)), (widths[3], (1, 1)),
                  (widths[3], (1, 1))]
    x = input_tensor
    for n_filters, kernel in conv_specs:
        x = L.Conv2D(n_filters, kernel, padding='same')(x)
        x = L.BatchNormalization(axis=-1)(x)
        x = L.advanced_activations.LeakyReLU(alpha=0.2)(x)
    return x
86 |
87 |
def small_cnn_branch_latter(input_tensor):
    """Head applied after DPAM, ending in a flat feature vector.

    The head applies 2x2 max-pooling, then two Dense -> ReLU -> Dropout(0.4)
    stages with 1024 and 512 units, then a Flatten.
    """
    x = L.MaxPool2D(pool_size=(2, 2), padding='same')(input_tensor)
    for units in (1024, 512):
        x = L.Dense(units)(x)
        x = L.Activation('relu')(x)
        x = L.Dropout(0.4)(x)
    return L.Flatten()(x)
98 |
99 |
100 |
101 |
102 |
def pixel_branch(input_tensor):
    """1-D spectral branch applied to the centre-pixel spectrum.

    ``input_tensor`` has shape (batch, bands, 1); merge_branch feeds it from
    Input((hchn, 1)). The branch applies:
      - a 64-filter Conv1D of width 11 ('valid' padding),
      - an attention-reweighted copy concatenated with it,
      - BatchNorm and LeakyReLU(0.2),
      - three Conv1D(128, 3) -> LeakyReLU(0.2) -> MaxPool1D(2) stages,
      - a final Flatten.
    """
    filters = [8, 16, 32, 64, 96, 128]
    conv0 = L.Conv1D(filters[3], 11, padding='valid')(input_tensor)
    # Take the attention grid from the actual tensor shape. The former
    # hard-coded 170 equals hchn - 10 only for hchn = 180, so the Reshape in
    # attention_block_3 failed for the other band counts configured
    # (commented out) in data_F.py.
    steps, channels = K1.int_shape(conv0)[1:]
    conv0_a = attention_block_3(conv0, steps, channels)
    x = L.concatenate([conv0, conv0_a])
    x = L.BatchNormalization(axis=-1)(x)
    x = L.advanced_activations.LeakyReLU(alpha=0.2)(x)
    for _ in range(3):
        x = L.Conv1D(filters[5], 3, padding='valid')(x)
        x = L.advanced_activations.LeakyReLU(alpha=0.2)(x)
        x = L.MaxPool1D(pool_size=2, padding='valid')(x)
    return L.Flatten()(x)
121 |
122 |
123 |
def merge_branch():
    """Build and compile the full HSI + LiDAR fusion classifier.

    Inputs, in order:
      - an HSI patch of shape (ksize, ksize, hchn),
      - the centre-pixel spectrum of shape (hchn, 1),
      - a LiDAR patch of shape (ksize, ksize, lchn),
    where ksize = 2*r + 1. The output is a softmax over NUM_CLASS classes.
    The model is compiled with Adam (lr=1e-4) and categorical cross-entropy,
    with accuracy as the metric.
    """
    ksize = 2 * r + 1
    hsi_in = L.Input((ksize, ksize, hchn))
    hsi_pxin = L.Input((hchn, 1))
    lidar_in = L.Input((ksize, ksize, lchn))

    # Branch 1: HSI spatial CNN with PAM, plus the 1-D spectral branch.
    h_simple = small_cnn_branch(hsi_in, small_mode=False)
    px_out = pixel_branch(hsi_pxin)

    # Branch 2: cross-modal attention. Each modality's front features are
    # stacked as [other, own] and fused by DPAM. DPAM's residual output is
    # built on the second map of the stack, i.e. the modality's own features.
    ha_simple = small_cnn_branch_front(hsi_in)
    la_simple = small_cnn_branch_front(lidar_in)

    ha_simple_c = L.Lambda(cross)([la_simple, ha_simple])
    la_simple_c = L.Lambda(cross)([ha_simple, la_simple])

    ha_simple = DPAM()(ha_simple_c)
    la_simple = DPAM()(la_simple_c)

    ha_simple = small_cnn_branch_latter(ha_simple)
    la_simple = small_cnn_branch_latter(la_simple)

    merge1 = L.concatenate([h_simple, px_out], axis=-1)
    merge1 = L.Dropout(0.5)(merge1)
    merge2 = L.concatenate([ha_simple, la_simple], axis=-1)
    merge2 = L.Dropout(0.5)(merge2)
    merge = L.concatenate([merge1, merge2], axis=-1)
    merge = L.Dropout(0.5)(merge)

    logits = L.Dense(NUM_CLASS, activation='softmax', name='logits_out')(merge)

    model = K.models.Model([hsi_in, hsi_pxin, lidar_in], logits)
    # The unused `filters` list and the unused SGD optimizer were removed;
    # Adam is the optimizer actually used.
    adam = K.optimizers.Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
    model.compile(optimizer=adam,
                  loss='categorical_crossentropy', metrics=['acc'])

    return model
165 |
166 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import time
4 | import h5py
5 | import keras as K
6 | import keras.layers as L
7 | from data_F import *
8 | from models_c import *
9 | import matplotlib.pyplot as plt
10 |
def read_image(filename):
    """Read a TIFF image from ``filename`` and return it as a float32 array.

    This duplicates data_F.read_image, which the star import above also
    brings in; this definition takes precedence inside ops.
    """
    pixels = tiff.imread(filename)
    return np.asarray(pixels, dtype=np.float32)
15 |
def countX(lst, x):
    """Return how many elements of ``lst`` compare equal (``==``) to ``x``."""
    return sum(1 for ele in lst if ele == x)
22 |
def cvt_map(pred, show=False):
    """Turn per-sample class scores into a label map and score it.

    Args:
        pred: (n_samples, n_classes) scores, in the same order as the
            coordinates stored in SAVA_PATH/index.npy by make_cTestf.
        show: if True, display the predicted and ground-truth maps with
            matplotlib.

    Returns:
        (acc, kappa): overall accuracy in percent (correct predictions over
        all non-zero test-map pixels) and Cohen's kappa.

    Side effects: prints the predictions and the confusion matrix, and writes
    predops.npy (to SAVA_PATH), results.tif and confusion.txt.
    """
    gth = read_mat(PATH, gth_test, TestImage)
    # 1-based predicted labels, matching the ground-truth label convention.
    labels = np.asarray(np.argmax(pred, axis=1), dtype=np.int8) + 1
    print(labels)
    np.save(os.path.join(SAVA_PATH, 'predops.npy'), labels)
    coords = np.load(os.path.join(SAVA_PATH, 'index.npy'))
    rows, cols = coords[0], coords[1]

    pred_map = np.zeros_like(gth)
    pred_map[rows, cols] = labels
    truth = np.asarray(gth[rows, cols], dtype=np.int8)

    if show:
        plt.imshow(pred_map)
        plt.figure()
        plt.imshow(gth)
        plt.show()
    # NOTE(review): the indentation of this call was lost in the source this
    # was reconstructed from; it is assumed to run unconditionally, not only
    # when show=True -- confirm.
    tiff.imsave('results.tif', pred_map)

    hits = np.sum(labels == truth)
    mx = confusion(labels - 1, truth - 1)
    print(mx)
    acc = 100.0 * hits / np.sum(gth != 0)
    kappa = compute_Kappa(mx)
    return acc, kappa
53 |
54 |
def confusion(pred, labels, num_class=None):
    """Build a confusion matrix and also save it to 'confusion.txt'.

    Args:
        pred: 0-based predicted labels of shape (n,), or (n, C) scores, in
            which case the argmax over axis 1 is taken.
        labels: 0-based true labels of shape (n,).
        num_class: matrix size; defaults to the module-level NUM_CLASS.

    Returns:
        (num_class, num_class) int64 array. ``mx[p, t]`` counts the samples
        predicted as ``p`` whose true label is ``t`` (rows are predictions,
        columns are ground truth). Only the first ``len(labels)`` predictions
        are used.
    """
    if num_class is None:
        num_class = NUM_CLASS
    pred = np.asarray(pred)
    labels = np.asarray(labels)
    if len(pred.shape) == 2:
        pred = np.asarray(np.argmax(pred, axis=1))

    # int64 counts: the previous int16 cast silently wrapped once any cell
    # exceeded 32767, corrupting the matrix (and kappa) on large test sets.
    mx = np.zeros((num_class, num_class), dtype=np.int64)
    # Unbuffered accumulation, so repeated (pred, label) pairs all count.
    np.add.at(mx, (pred[:labels.shape[0]], labels), 1)
    np.savetxt('confusion.txt', mx, delimiter=" ", fmt="%s")
    return mx
68 |
def compute_Kappa(confusion_matrix):
    """Return Cohen's kappa for a square confusion matrix.

    kappa = (p_o - p_e) / (1 - p_e), where:
      - p_o is the fraction of samples on the diagonal,
      - p_e is the chance agreement, i.e. the sum over classes of
        (column share * row share).

    Degenerate matrices (all zeros, or p_e == 1) are not special-cased; the
    NumPy division then yields nan or inf.
    """
    total = np.sum(confusion_matrix)
    p_observed = 1.0 * np.trace(confusion_matrix) / total
    col_share = 1.0 * np.sum(confusion_matrix, axis=0) / total
    row_share = 1.0 * np.sum(confusion_matrix, axis=1) / total
    p_expected = np.sum(np.multiply(col_share, row_share))
    return (p_observed - p_expected) / (1.0 - p_expected)
81 |
def eval(pred, gth, show=False, ksize=None):
    """Score a tiled per-patch prediction against a ground-truth map.

    Note: the name shadows the builtin ``eval`` inside modules that
    star-import ops.

    Args:
        pred: class scores of shape
            (n_col_tiles, n_row_tiles, ksize, ksize, n_classes). The argmax
            over the last axis gives each tile's label block, which is placed
            at tile (row=axis 1 index, col=axis 0 index).
        gth: 2-D ground-truth map of shape (h, w); 0 marks unlabelled pixels.
        show: if True, display the stitched map and ``gth`` with matplotlib.
        ksize: tile size. It defaults to 2*r + 1, the patch size used by
            merge_branch. The old code read an undefined global ``ksize``.

    Returns:
        Number of pixels where the stitched map equals ``gth``, divided by
        the number of labelled (non-zero) pixels in ``gth``.
        NOTE(review): the numerator also counts unlabelled pixels where the
        prediction is 0, so the value can exceed 1 -- confirm intended.
    """
    if ksize is None:
        ksize = 2 * r + 1
    labels = np.argmax(pred, 4)
    h, w = gth.shape
    # Round the canvas up to a whole number of tiles. Previously hm/wm were
    # only assigned when h/w were NOT multiples of ksize, so exact multiples
    # raised NameError.
    hm = -(-h // ksize) * ksize
    wm = -(-w // ksize) * ksize
    new_map = np.zeros(shape=(hm, wm))
    for i in range(labels.shape[1]):
        for j in range(labels.shape[0]):
            new_map[i * ksize:(i + 1) * ksize,
                    j * ksize:(j + 1) * ksize] = labels[j, i, :, :]
    new_map = np.asarray(new_map, dtype=np.int8)
    new_map = new_map[0:h, 0:w]
    cls_gth = np.zeros_like(gth)
    cls_map = np.zeros_like(new_map)
    cls_map[new_map != 0] = new_map[new_map != 0]
    cls_gth[gth != 0] = gth[gth != 0]
    count = np.sum(cls_gth == cls_map)
    acc = 1.0 * count / np.sum(gth != 0)
    if show:
        plt.imshow(new_map)
        plt.figure()
        plt.imshow(gth)
        plt.show()
    return acc
112 |
def visual_model(model,imgname):
    """Placeholder for rendering ``model`` to the image file ``imgname``.

    Currently a no-op apart from importing ``plot_model``: the plotting call
    is commented out. Uncomment it to write the diagram with layer shapes.
    """
    from keras.utils import plot_model
    # plot_model(model, to_file=imgname, show_shapes=True)
116 |
117 |
118 |
119 |
--------------------------------------------------------------------------------