├── LICENSE ├── README.md ├── demo_keras.py ├── demo_keras_loadsamples.py ├── demo_keras_predict.py ├── demo_keras_tif.py ├── demo_keras_train.py ├── demo_pytorch.py ├── networks.py ├── pytorch ├── demo_pytorch_v1.py ├── readme.md ├── rscls.py └── torchnet.py └── rscls.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Shengjie Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Remote sensing image classification 2 | This project focuses on remote sensing image classification using deep learning. 3 | 4 | The current implementations are based on PyTorch and Keras with TensorFlow backend. 
5 | 6 | Feel free to contact me if you need any further information: liushengjie0756 AT gmail.com 7 | 8 | ## Overview 9 | In the script, we first conduct image segmentation and divide the image into several objects. 10 | Then, we generate training samples and train a network. The network is used to predict the whole image. 11 | Finally, the object-based post-classification refinement strategy is utilized to refine the classification maps. 12 | 13 | ### Networks 14 | - Wide Contextual Residual Network - WCRN [2] 15 | - Double Branch Multi Attention Mechanism Network - DBMA [3] 16 | - Residual Network with Average Pooling - ResNet99_avg 17 | - Residual Network - ResNet99 [4] 18 | - Deep Contextual CNN - DCCNN [5] 19 | 20 | ### Requirements 21 | pytorch==1.1.0 # for PyTorch implementation 22 | skimage==0.15.0 23 | scipy==1.0.0 24 | sklearn==0.19.1 25 | keras==2.1.6 # for Keras implementation 26 | tensorflow==1.9.0 # for Keras implementation 27 | 28 | ### Data sets 29 | You can download the hyperspectral data sets in MATLAB format at: http://www.ehu.eus/ccwintco/index.php/Hyperspectral_Remote_Sensing_Scenes 30 | 31 | Then, you can convert the data sets to NumPy arrays. 32 | 33 | ## How to use 34 | ### Run demo_pytorch.py 35 | You will see two predicted maps under the current directory when finished. 36 | One is the raw classification, and the other is the result after object-based post-classification refinement (superpixel-based regularization). 37 | 38 | This implementation is based on PyTorch using the Wide Contextual Residual Network [2]. 39 | 40 | ### Run demo_keras.py 41 | This implementation is based on Keras with TensorFlow backend. 42 | 43 | For this demo, the default network is DBMA. Changing the parameter patch, which controls the window size of each sample, selects a different network. 44 | 45 | ### Separate training and testing 46 | Some imagery may be too large to be loaded in memory at once. 
For this scenario, we use subsets of the imagery, and separate the training and testing parts so that all the samples can be used for training. To do so, you need to decide how to clip the imagery and fill in the arguments in <*demo_keras_loadsamples.py*>. The workflow of separate training and testing goes as follows. 47 | 48 | - First, run <*demo_keras_loadsamples.py*> to generate training samples and save them under current dir. 49 | - Then, run <*demo_keras_train.py*> to train the model, and the model will be saved under current dir. 50 | - Finally, run <*demo_keras_predict.py*> to predict the whole image. 51 | 52 | ## Patch and the corresponding network 53 | - patch==5: WCRN 54 | - patch==7: DBMA 55 | - patch==9: ResNet99_avg 56 | 57 | ## Networks' performance 58 | ### Keras 59 | Network | WCRN | DBMA | ResNet99 | ResNet99_avg | DCCNN 60 | :-: | :-: | :-: | :-: | :-: | :-: 61 | train time (s) | 18 | 222 | 21 | 20 | 41 | 62 | test time (s) | 12| 199 | 22 | 21 | 18 | 63 | OA (%) | 83.00 | 86.86 | 72.34 | 86.68 | 77.54 | 64 | 65 | The experiments are based on Keras with TensorFlow backend using **10 samples per class with augmentation**, conducted on a machine equipped with Intel i5-8400, GTX1050Ti 4G and 8G RAM. The OA is of raw classification averaged from 10 Monte Carlo runs. 66 | 67 | Network | WCRN | DBMA | ResNet99 | ResNet99_avg | DCCNN 68 | :-: | :-: | :-: | :-: | :-: | :-: 69 | train time (s) | 9 | 77 | 11 | 10 | 18 | 70 | test time (s) | 13| 133 | 22 | 19 | 16 | 71 | OA (%) | 72.77 | 74.93 | 62.47 | 74.50 | 65.51 | 72 | 73 | The experiments are based on Keras with TensorFlow backend using **5 samples per class with augmentation**, conducted on a machine equipped with Intel i5-8500, GTX1060 5G and 32G RAM. The OA is of raw classification averaged from 10 Monte Carlo runs. 
74 | 75 | 76 | Network | WCRN | DBMA | ResNet99 | ResNet99_avg | DCCNN 77 | :-: | :-: | :-: | :-: | :-: | :-: 78 | train time (s) | 91 | 755 | 98 | 88 | 132 | 79 | test time (s) | 14 | 132 | 22 | 20 | 17 | 80 | OA (%) | 77.91 | 81.14 | 78.22 | 79.56 | 77.05 | 81 | 82 | The experiments are based on Keras with TensorFlow backend using **5 samples per class with augmentation and pseudo samples**, conducted on a machine equipped with Intel i5-8500, GTX1060 5G and 32G RAM. The OA is that of the raw classification, averaged over 10 Monte Carlo runs. 83 | 84 | ##### Baseline for data sets 85 | Data | WCRN | DBMA | ResNet99 | ResNet99_avg | DCCNN 86 | :-: | :-: | :-: | :-: | :-: | :-: 87 | Pavia University | - | - | - | - | - | 88 | Pavia Center | - | - | - | - | - | 89 | Indian Pines | - | - | - | 80.21 (2.49) | - | 90 | Salinas Valley | - | - | - | - | - | 91 | KSC | - | - | - | 95.08 (0.96) | - | 92 | University of Houston | - | - | - | - | - | 93 | Flevoland | 77.25 (1.84) | 77.29 (2.55) | - | 81.66 (1.01) | - | 94 | Foulum | 95.87 (1.06) | 97.99 (1.20) | - | 98.26 (1.11) | - | 95 | 96 | The experiments are based on Keras with TensorFlow backend using **10 samples per class**, conducted on a machine equipped with Intel i5-8500, GTX1060 5G and 32G RAM. The OA is that of the raw classification, averaged over 10 Monte Carlo runs. 97 | 98 | ### PyTorch 99 | Network | WCRN | WCRN-group | WCRN-normal | WCRN-bn-normal | WCRN-bn-default | resnet99-torch 100 | :-: | :-: | :-: | :-: | :-: | :-: | :-: 101 | train time (s) | 17 | 49 | 17 | | | | 102 | test time (s) | | 13 | 20 | 13 | | | 103 | OA (%) | 79.88 | 82.16 | 78.83 | 80.36 | | 85.25 | 104 | 105 | The experiments are based on the **PyTorch** backend using **10 samples per class with augmentation**, conducted on a machine equipped with Intel i7-8700 and 32G RAM (no CUDA). The OA is that of the raw classification, averaged over 10 Monte Carlo runs. 
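As a rough illustration of how OA figures like those in the tables above can be summarised, the sketch below computes the overall accuracy from a confusion matrix and averages it over several runs. It assumes each run yields a plain square matrix of counts (rows: reference, columns: prediction). The helper names `overall_accuracy` and `summarize_runs` are hypothetical and not part of `rscls.py`, whose `gtcfm` output may be laid out differently, and the two toy matrices are made up for the example.

```python
import numpy as np


def overall_accuracy(cfm):
    """Overall accuracy (%) from a square confusion matrix of raw counts."""
    cfm = np.asarray(cfm, dtype=np.float64)
    # correctly classified samples sit on the diagonal
    return 100.0 * np.trace(cfm) / cfm.sum()


def summarize_runs(cfms):
    """Mean and (population) standard deviation of OA over several runs."""
    oas = np.array([overall_accuracy(c) for c in cfms])
    return oas.mean(), oas.std()


# two toy 3-class confusion matrices standing in for two random seeds
runs = [
    np.array([[8, 1, 1], [0, 9, 1], [2, 0, 8]]),
    np.array([[9, 0, 1], [1, 8, 1], [0, 1, 9]]),
]

mean_oa, std_oa = summarize_runs(runs)
print(f"OA: {mean_oa:.2f} ({std_oa:.2f})")  # prints "OA: 85.00 (1.67)"
```

In the demo scripts, one confusion matrix per seed is stored in the `saved` dictionary, so a similar summary could be computed after looping over all entries of `seedx`, provided the stored matrices are first reduced to plain count matrices.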
106 | 107 | 108 | 109 | ## To do 110 | - Add PyTorch implementation of DBMA and ResNet99_avg 111 | - Active learning 112 | - Multitask deep learning 113 | 114 | ## References 115 | [1] [Liu, S., Qi, Z., Li, X. and Yeh, A.G.O., 2019. Integration of Convolutional Neural Networks and Object-Based Post-Classification 116 | Refinement for Land Use and Land Cover Mapping with Optical and SAR Data. Remote Sens., 11(6), p.690.](https://doi.org/10.3390/rs11060690) 117 | 118 | [2] [Liu, S., Luo, H., Tu, Y., He, Z. and Li, J., 2018, July. Wide Contextual Residual Network with Active Learning for Remote 119 | Sensing Image Classification. In IGARSS 2018, pp. 7145-7148.](https://doi.org/10.1109/IGARSS.2018.8517855) 120 | 121 | [3] [Ma, W.; Yang, Q.; Wu, Y.; Zhao, W.; Zhang, X. Double-Branch Multi-Attention Mechanism Network for Hyperspectral Image Classification. Remote Sens. 2019, 11, 1307.](https://doi.org/10.3390/rs11111307) 122 | 123 | [4] [Liu, S., and Shi, Q., 2019. Multitask Deep Learning with Spectral Knowledge for Hyperspectral Image Classification. arXiv preprint arXiv:1905.04535.](https://arxiv.org/abs/1905.04535) 124 | 125 | [5] [Lee H. Lee and H. Kwon, "Going Deeper With Contextual CNN for Hyperspectral Image Classification," in IEEE Transactions on Image Processing, vol. 26, no. 10, pp. 4843-4855, Oct. 2017.](https://doi.org/10.1109/TIP.2017.2725580) 126 | 127 | [6] [Liu, S., Shi, Q. and Zhang, L., 2020. Few-shot hyperspectral image classification with unknown classes using multitask deep learning. 
IEEE Transactions on Geoscience and Remote Sensing, 59(6), pp.5085-5102.](https://ieeexplore.ieee.org/abstract/document/9186822) 128 | -------------------------------------------------------------------------------- /demo_keras.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jul 6 01:36:12 2019 4 | Last updated on Aug 7 2019 5 | 6 | @author: Shengjie Liu 7 | @Email: liushengjie0756@gmail.com 8 | """ 9 | 10 | import numpy as np 11 | from scipy import stats 12 | import rscls 13 | import matplotlib.pyplot as plt 14 | import time 15 | import networks as nw 16 | from keras.utils import to_categorical 17 | import keras 18 | from keras.losses import categorical_crossentropy 19 | from keras.optimizers import Adadelta 20 | from keras import backend as K 21 | import argparse 22 | from sklearn.decomposition import PCA,IncrementalPCA 23 | #from sklearn.covariance import GraphicalLassoCV 24 | from keras.callbacks import EarlyStopping 25 | from scipy.io import loadmat 26 | from skimage.segmentation import felzenszwalb 27 | 28 | ## data location 29 | im1_file = 'PaviaU.mat' 30 | gt1_file = 'PaviaU_gt.mat' 31 | 32 | ## number of training samples per class 33 | parser = argparse.ArgumentParser(description='manual to this script') 34 | parser.add_argument('--nos', type=int, default = 10) 35 | args = parser.parse_args() 36 | num_per_cls1 = args.nos 37 | 38 | ## network configuration 39 | patch = 7 # if patch==5, WCRN. if patch==7, DBMA. if patch==9, ResNet-avg. 
40 | vbs = 0 # if vbs==0, training in silent mode 41 | bsz1 = 20 # batch size 42 | ensemble = 1 # if ensemble>1, snapshot ensemble activated 43 | # if loss not decrease for 5 epoches, stop training 44 | early_stopping = EarlyStopping(monitor='loss', patience=5, verbose=2) 45 | 46 | 47 | # for Monte Carlo runs 48 | seedx = [0,1,2,3,4,5,6,7,8,9] 49 | seedi = 0 # default seed = 0 50 | 51 | saved = {} # save confusion matrix for raw classification 52 | saved_p = {} # save confusion matrix after object-based refinement 53 | 54 | if True: 55 | for seedi in range(0,1): 56 | time1 = int(time.time()) 57 | K.clear_session() # clear session before next loop 58 | print('seed'+str(seedi)+','+str(seedx[seedi])) 59 | 60 | gt = loadmat(gt1_file)['paviaU_gt'] 61 | im = loadmat(im1_file)['paviaU'] 62 | cls1 = gt.max() 63 | 64 | im1x,im1y,im1z = im.shape 65 | im = np.float32(im) 66 | 67 | # segmentation on top-3 PCs 68 | estimator = IncrementalPCA(n_components=3) 69 | estimator = PCA(n_components=3) 70 | im = im.reshape(im1x*im1y,-1) 71 | im2 = estimator.fit_transform(im) 72 | im2 = im2.reshape(im1x,im1y,3) 73 | seg = felzenszwalb(im2, scale=0.5, sigma=0.8, min_size=5, multichannel=True) 74 | im = im.reshape(im1x,im1y,im1z) 75 | 76 | # kind of normalization, the max DN of the original image is 8000 77 | im = im/5000.0 78 | 79 | # initilize controller 80 | c1 = rscls.rscls(im,gt,cls=cls1) 81 | c1.padding(patch) 82 | c1.locate_obj(seg) # locate superpixels 83 | 84 | # random seed for Monte Carlo runs 85 | np.random.seed(seedx[seedi]) 86 | x1_train,y1_train = c1.train_sample(num_per_cls1) # load train samples 87 | x1_train,y1_train = rscls.make_sample(x1_train,y1_train) # augmentation 88 | y1_train = to_categorical(y1_train) # to one-hot labels 89 | 90 | 91 | ''' training part ''' 92 | im1z = im.shape[2] 93 | if patch == 7: 94 | model1 = nw.DBMA(im1z,cls1) # 3D CNN, samples are 5-dimensional 95 | x1_train = x1_train.reshape(x1_train.shape[0],patch,patch,im1z,-1) 96 | elif patch == 5: 97 
| model1 = nw.wcrn(im1z,cls1) # WCRN 98 | # model1 = nw.DCCNN(im1z,patch,cls1) # DCCNN 99 | elif patch == 9: 100 | model1 = nw.resnet99_avg(im1z,patch,cls1,l=1) 101 | else: 102 | # print('using resnet_avg') 103 | model1 = nw.resnet99_avg(im1z,patch,cls1,l=1) 104 | time2 = int(time.time()) 105 | 106 | # first train the model with lr=1.0 107 | model1.compile(loss=categorical_crossentropy,optimizer=Adadelta(lr=1.0),metrics=['accuracy']) 108 | model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=170,verbose=vbs,shuffle=True) 109 | 110 | # then train the model with lr=0.1 111 | model1.compile(loss=categorical_crossentropy,optimizer=Adadelta(lr=0.1),metrics=['accuracy']) 112 | model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=30,verbose=vbs,shuffle=True,callbacks=[early_stopping]) 113 | time3 = int(time.time()) # training time 114 | print('training time:',time3-time2) 115 | #model1.save('model'+str(time3)[-5:]+'.h5') # uncomment to save model 116 | 117 | # predict part 118 | pre_all_1 = [] 119 | for i in range(ensemble): 120 | pre_rows_1 = [] 121 | # uncomment below if snapshot ensemble activated 122 | # model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=2,verbose=vbs,shuffle=True) 123 | for j in range(im1x): 124 | #print(j) uncomment to monitor predicing stages 125 | sam_row = c1.all_sample_row(j) 126 | if patch == 7: 127 | sam_row = sam_row.reshape(sam_row.shape[0],patch,patch,im1z,1) 128 | pre_row1 = np.argmax(model1.predict(sam_row),axis=1) 129 | pre_row1 = pre_row1.reshape(1,im1y) 130 | pre_rows_1.append(pre_row1) 131 | pre_all_1.append(np.array(pre_rows_1)) 132 | 133 | time4 = int(time.time()) 134 | print('predict time:',time4-time3) # predict time 135 | 136 | # classification map and confusion matrix for raw classification 137 | pre_all_1 = np.array(pre_all_1).reshape(ensemble,im1x,im1y) 138 | pre1 = np.int8(stats.mode(pre_all_1,axis=0)[0]).reshape(im1x,im1y) 139 | result11 = rscls.gtcfm(pre1+1,c1.gt+1,cls1) 140 | saved[str(seedi)+'a'] = result11 141 | 142 | # 
after object-based refinement 143 | pcmap = rscls.obpc(c1.seg,pre1,c1.obj) 144 | result12 = rscls.gtcfm(pcmap+1,c1.gt+1,cls1) 145 | saved_p[str(seedi)+'b'] = result12 146 | rscls.save_cmap(pre1, 'jet', 'pre.png') 147 | rscls.save_cmap(pcmap, 'jet', 'pcmap.png') 148 | 149 | -------------------------------------------------------------------------------- /demo_keras_loadsamples.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | Last updated on Aug 12 2019 5 | @author: Shengjie Liu 6 | @Email: liushengjie0756@gmail.com 7 | 8 | This script reads *.tif files and saves the extracted samples as *.npy files. 9 | """ 10 | 11 | import numpy as np 12 | from scipy import stats 13 | import rscls 14 | import matplotlib.pyplot as plt 15 | import time 16 | import networks as nw 17 | from keras.utils import to_categorical 18 | import keras 19 | from keras.losses import categorical_crossentropy 20 | from keras.optimizers import Adadelta 21 | from keras import backend as K 22 | import argparse 23 | #from sklearn.covariance import GraphicalLassoCV 24 | from keras.callbacks import EarlyStopping 25 | from scipy.io import loadmat 26 | import gdal 27 | 28 | ## data location 29 | im1_file = 'F:/t0809/data/yubei0603.tif' 30 | gt1_file = 'yubei2.tif' 31 | 32 | ## number of training samples per class 33 | parser = argparse.ArgumentParser(description='manual to this script') 34 | parser.add_argument('--nos', type=int, default = 5) 35 | args = parser.parse_args() 36 | num_per_cls1 = args.nos 37 | 38 | ## network configuration 39 | patch = 5 # If patch==5, WCRN. If patch==7, DBMA. If patch==9, ResNet-avg. 
40 | vbs = 0 # if vbs==0, training in silent mode 41 | bsz1 = 20 # batch size 42 | ensemble = 1 # if ensemble>1, snapshot ensemble activated 43 | # if loss not decrease for 5 epoches, stop training 44 | early_stopping = EarlyStopping(monitor='loss', patience=5, verbose=2) 45 | 46 | 47 | # for Monte Carlo runs 48 | seedx = [0,1,2,3,4,5,6,7,8,9] 49 | seedi = 0 # default seed = 0 50 | 51 | saved = {} # save confusion matrix for raw classification 52 | saved_p = {} # save confusion matrix after object-based refinement 53 | 54 | # 6509,8210 55 | 56 | # bgx,bgy,imx,imy 57 | subsets = [(0,0,6509,4105), 58 | (0,4105,6509,4105)] 59 | 60 | 61 | def setGeo(geotransform,bgx,bgy): 62 | reset0 = geotransform[0] + bgx*geotransform[1] 63 | reset3 = geotransform[3] + bgy*geotransform[5] 64 | reset = (reset0,geotransform[1],geotransform[2], 65 | reset3,geotransform[4],geotransform[5]) 66 | return reset 67 | 68 | 69 | if True: 70 | for subset in subsets: 71 | time1 = int(time.time()) 72 | K.clear_session() # clear session before next loop 73 | print(subset) 74 | 75 | gt = gdal.Open(gt1_file,gdal.GA_ReadOnly) 76 | im = gdal.Open(im1_file,gdal.GA_ReadOnly) 77 | projection = gt.GetProjection() 78 | geotransform = gt.GetGeoTransform() 79 | newgeo = setGeo(geotransform,subset[0],subset[1]) 80 | gt = gt.ReadAsArray(subset[0],subset[1],subset[2],subset[3]) 81 | im = im.ReadAsArray(subset[0],subset[1],subset[2],subset[3]) 82 | im = im.transpose(1,2,0) 83 | cls1 = gt.max() 84 | 85 | im1x,im1y,im1z = im.shape 86 | im = np.float32(im) 87 | 88 | # kind of normalization, the max DN of the original image is 10000 89 | # the goal is to make the DNs range from (-2,2) 90 | im = im/5000.0 91 | 92 | # initilize controller 93 | c1 = rscls.rscls(im,gt,cls=cls1) 94 | c1.padding(patch) 95 | 96 | # random seed for Monte Carlo runs 97 | np.random.seed(seedx[seedi]) 98 | x2_train,y2_train = c1.test_sample() 99 | print(x2_train.shape) 100 | try: 101 | x1_train = np.concatenate([x1_train,x2_train],axis=0) 102 | 
y1_train = np.concatenate([y1_train,y2_train],axis=0) 103 | del x2_train,y2_train 104 | except NameError: # x1_train does not exist yet on the first subset 105 | print('first subset: starting new sample arrays') 106 | x1_train = x2_train 107 | y1_train = y2_train 108 | 109 | np.save('x_train.npy',x1_train) 110 | np.save('y_train.npy',y1_train) 111 | -------------------------------------------------------------------------------- /demo_keras_predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | Last updated on Aug 12 2019 5 | @author: Shengjie Liu 6 | @Email: liushengjie0756@gmail.com 7 | 8 | This script loads the image subsets one by one and predicts each of them. 9 | """ 10 | 11 | import numpy as np 12 | from scipy import stats 13 | import rscls 14 | import matplotlib.pyplot as plt 15 | import time 16 | import networks as nw 17 | from keras.utils import to_categorical 18 | import keras 19 | from keras.losses import categorical_crossentropy 20 | from keras.optimizers import Adadelta 21 | from keras import backend as K 22 | import argparse 23 | #from sklearn.covariance import GraphicalLassoCV 24 | from keras.callbacks import EarlyStopping 25 | from scipy.io import loadmat 26 | import gdal 27 | from keras.models import load_model 28 | 29 | ## data location 30 | im1_file = 'F:/t0809/data/yubei0603.tif' 31 | gt1_file = 'yubei2.tif' 32 | 33 | # load model created by demo_keras_train.py 34 | model_file = 'model99079.h5' 35 | 36 | ## number of training samples per class 37 | parser = argparse.ArgumentParser(description='manual to this script') 38 | parser.add_argument('--nos', type=int, default = 5) 39 | args = parser.parse_args() 40 | num_per_cls1 = args.nos 41 | 42 | ## network configuration 43 | patch = 5 # If patch==5, WCRN. If patch==7, DBMA. If patch==9, ResNet-avg. 
44 | vbs = 0 # if vbs==0, training in silent mode 45 | bsz1 = 20 # batch size 46 | ensemble = 1 # if ensemble>1, snapshot ensemble activated 47 | # if loss not decrease for 5 epoches, stop training 48 | early_stopping = EarlyStopping(monitor='loss', patience=5, verbose=2) 49 | 50 | 51 | # for Monte Carlo runs 52 | seedx = [0,1,2,3,4,5,6,7,8,9] 53 | seedi = 0 # default seed = 0 54 | 55 | saved = {} # save confusion matrix for raw classification 56 | saved_p = {} # save confusion matrix after object-based refinement 57 | 58 | # 6509,8210 59 | 60 | # bgx,bgy,imx,imy 61 | subsets = [(0,0,6509,4105), 62 | (0,4105,6509,4105)] 63 | 64 | 65 | def setGeo(geotransform,bgx,bgy): 66 | reset0 = geotransform[0] + bgx*geotransform[1] 67 | reset3 = geotransform[3] + bgy*geotransform[5] 68 | reset = (reset0,geotransform[1],geotransform[2], 69 | reset3,geotransform[4],geotransform[5]) 70 | return reset 71 | 72 | 73 | if True: 74 | for subset in subsets: 75 | time1 = int(time.time()) 76 | K.clear_session() # clear session before next loop 77 | print(subset) 78 | 79 | gt = gdal.Open(gt1_file,gdal.GA_ReadOnly) 80 | im = gdal.Open(im1_file,gdal.GA_ReadOnly) 81 | projection = gt.GetProjection() 82 | geotransform = gt.GetGeoTransform() 83 | newgeo = setGeo(geotransform,subset[0],subset[1]) 84 | gt = gt.ReadAsArray(subset[0],subset[1],subset[2],subset[3]) 85 | im = im.ReadAsArray(subset[0],subset[1],subset[2],subset[3]) 86 | im = im.transpose(1,2,0) 87 | cls1 = gt.max() 88 | 89 | im1x,im1y,im1z = im.shape 90 | im = np.float32(im) 91 | 92 | # kind of normalization, the max DN of the original image is 10000 93 | # the goal is to make the DNs range from (-2,2) 94 | im = im/5000.0 95 | 96 | # initilize controller 97 | c1 = rscls.rscls(im,gt,cls=cls1) 98 | c1.padding(patch) 99 | 100 | model1 = load_model(model_file) 101 | time3 = int(time.time()) 102 | 103 | # predict part 104 | pre_all_1 = [] 105 | for i in range(ensemble): 106 | pre_rows_1 = [] 107 | # uncomment below if snapshot ensemble 
activated 108 | # model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=2,verbose=vbs,shuffle=True) 109 | for j in range(im1x): 110 | if j%100==0: 111 | print(j) 112 | #print(j) uncomment to monitor predicing stages 113 | sam_row = c1.all_sample_row(j) 114 | if patch == 7: 115 | sam_row = sam_row.reshape(sam_row.shape[0],patch,patch,im1z,1) 116 | pre_row1 = np.argmax(model1.predict(sam_row),axis=1) 117 | pre_row1 = pre_row1.reshape(1,im1y) 118 | pre_rows_1.append(pre_row1) 119 | pre_all_1.append(np.array(pre_rows_1)) 120 | 121 | time4 = int(time.time()) 122 | print('predict time:',time4-time3) # predict time 123 | 124 | # classification map and confusion matrix for raw classification 125 | pre_all_1 = np.array(pre_all_1).reshape(ensemble,im1x,im1y) 126 | pre1 = np.int8(stats.mode(pre_all_1,axis=0)[0]).reshape(im1x,im1y) 127 | result11 = rscls.gtcfm(pre1+1,c1.gt+1,cls1) 128 | saved[str(seedi)+'a'] = result11 129 | rscls.save_cmap(pre1, 'jet', 'pre'+str(time4)[-5:]+'.png') 130 | 131 | # save as geocode-tif 132 | name = 'predict_'+str(time4)[-5:] 133 | outdata = gdal.GetDriverByName('GTiff').Create(name+'.tif', im1y, im1x, 1, gdal.GDT_UInt16) 134 | outdata.SetGeoTransform(newgeo) 135 | outdata.SetProjection(projection) 136 | outdata.GetRasterBand(1).WriteArray(pre1+1) 137 | outdata.FlushCache() ##saves to disk!! 
138 | outdata = None 139 | -------------------------------------------------------------------------------- /demo_keras_tif.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | Last updated on Aug 12 2019 5 | @author: Shengjie Liu 6 | @Email: liushengjie0756@gmail.com 7 | 8 | This implementation reads *.tif files 9 | """ 10 | 11 | import numpy as np 12 | from scipy import stats 13 | import rscls 14 | import matplotlib.pyplot as plt 15 | import time 16 | import networks as nw 17 | from keras.utils import to_categorical 18 | import keras 19 | from keras.losses import categorical_crossentropy 20 | from keras.optimizers import Adadelta 21 | from keras import backend as K 22 | import argparse 23 | #from sklearn.covariance import GraphicalLassoCV 24 | from keras.callbacks import EarlyStopping 25 | from scipy.io import loadmat 26 | import gdal 27 | 28 | ## data location 29 | im1_file = 'F:/t0809/data/yubei0603.tif' 30 | gt1_file = 'yubei2.tif' 31 | 32 | ## number of training samples per class 33 | parser = argparse.ArgumentParser(description='manual to this script') 34 | parser.add_argument('--nos', type=int, default = 5) 35 | args = parser.parse_args() 36 | num_per_cls1 = args.nos 37 | 38 | ## network configuration 39 | patch = 5 # If patch==5, WCRN. If patch==7, DBMA. If patch==9, ResNet-avg. 
40 | vbs = 0 # if vbs==0, training in silent mode 41 | bsz1 = 20 # batch size 42 | ensemble = 1 # if ensemble>1, snapshot ensemble activated 43 | # if loss not decrease for 5 epoches, stop training 44 | early_stopping = EarlyStopping(monitor='loss', patience=5, verbose=2) 45 | 46 | 47 | # for Monte Carlo runs 48 | seedx = [0,1,2,3,4,5,6,7,8,9] 49 | seedi = 0 # default seed = 0 50 | 51 | saved = {} # save confusion matrix for raw classification 52 | saved_p = {} # save confusion matrix after object-based refinement 53 | 54 | bgx,bgy,imx,imy = 0,0,3000,3000 55 | 56 | def setGeo(geotransform,bgx,bgy): 57 | reset0 = geotransform[0] + bgx*geotransform[1] 58 | reset3 = geotransform[3] + bgy*geotransform[5] 59 | reset = (reset0,geotransform[1],geotransform[2], 60 | reset3,geotransform[4],geotransform[5]) 61 | return reset 62 | 63 | if True: 64 | for seedi in range(0,1): 65 | time1 = int(time.time()) 66 | K.clear_session() # clear session before next loop 67 | print('seed'+str(seedi)+','+str(seedx[seedi])) 68 | 69 | gt = gdal.Open(gt1_file,gdal.GA_ReadOnly) 70 | im = gdal.Open(im1_file,gdal.GA_ReadOnly) 71 | projection = im.GetProjection() 72 | geotransform = im.GetGeoTransform() 73 | newgeo = setGeo(geotransform,bgx,bgy) 74 | gt = gt.ReadAsArray(bgx,bgy,imx,imy) 75 | gt[np.where(gt==255)] = 0 76 | im = im.ReadAsArray(bgx,bgy,imx,imy) 77 | im = im.transpose(1,2,0) 78 | cls1 = gt.max() 79 | 80 | im1x,im1y,im1z = im.shape 81 | im = np.float32(im) 82 | 83 | # kind of normalization, the max DN of the original image is 8000 84 | im = im/5000.0 85 | 86 | # initilize controller 87 | c1 = rscls.rscls(im,gt,cls=cls1) 88 | c1.padding(patch) 89 | 90 | # random seed for Monte Carlo runs 91 | np.random.seed(seedx[seedi]) 92 | x1_train,y1_train = c1.train_sample(num_per_cls1) # load train samples 93 | x1_train,y1_train = rscls.make_sample(x1_train,y1_train) # augmentation 94 | y1_train = to_categorical(y1_train) # to one-hot labels 95 | 96 | 97 | ''' training part ''' 98 | im1z = 
im.shape[2] 99 | if patch == 7: 100 | model1 = nw.DBMA(im1z,cls1) # 3D CNN, samples are 5-dimensional 101 | x1_train = x1_train.reshape(x1_train.shape[0],patch,patch,im1z,-1) 102 | elif patch == 5: 103 | model1 = nw.wcrn(im1z,cls1) # WCRN 104 | # model1 = nw.DCCNN(im1z,patch,cls1) # DCCNN 105 | elif patch == 9: 106 | model1 = nw.resnet99_avg(im1z,patch,cls1,l=1) 107 | else: 108 | # print('using resnet_avg') 109 | model1 = nw.resnet99_avg(im1z,patch,cls1,l=1) 110 | time2 = int(time.time()) 111 | 112 | # first train the model with lr=1.0 113 | model1.compile(loss=categorical_crossentropy,optimizer=Adadelta(lr=1.0),metrics=['accuracy']) 114 | model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=170,verbose=vbs,shuffle=True) 115 | 116 | # then train the model with lr=0.1 117 | model1.compile(loss=categorical_crossentropy,optimizer=Adadelta(lr=0.1),metrics=['accuracy']) 118 | model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=30,verbose=vbs,shuffle=True,callbacks=[early_stopping]) 119 | time3 = int(time.time()) # training time 120 | print('training time:',time3-time2) 121 | #model1.save('model'+str(time3)[-5:]+'.h5') # uncomment to save model 122 | 123 | # predict part 124 | pre_all_1 = [] 125 | for i in range(ensemble): 126 | pre_rows_1 = [] 127 | # uncomment below if snapshot ensemble activated 128 | # model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=2,verbose=vbs,shuffle=True) 129 | for j in range(im1x): 130 | if j%100==0: 131 | print(j) 132 | #print(j) uncomment to monitor predicing stages 133 | sam_row = c1.all_sample_row(j) 134 | if patch == 7: 135 | sam_row = sam_row.reshape(sam_row.shape[0],patch,patch,im1z,1) 136 | pre_row1 = np.argmax(model1.predict(sam_row),axis=1) 137 | pre_row1 = pre_row1.reshape(1,im1y) 138 | pre_rows_1.append(pre_row1) 139 | pre_all_1.append(np.array(pre_rows_1)) 140 | 141 | time4 = int(time.time()) 142 | print('predict time:',time4-time3) # predict time 143 | 144 | # classification map and confusion matrix for raw classification 145 | 
pre_all_1 = np.array(pre_all_1).reshape(ensemble,im1x,im1y) 146 | pre1 = np.int8(stats.mode(pre_all_1,axis=0)[0]).reshape(im1x,im1y) 147 | result11 = rscls.gtcfm(pre1+1,c1.gt+1,cls1) 148 | saved[str(seedi)+'a'] = result11 149 | rscls.save_cmap(pre1, 'jet', 'pre.png') 150 | 151 | # save as geocode-tif 152 | name = 'predict_'+str(time4)[-5:] 153 | outdata = gdal.GetDriverByName('GTiff').Create(name+'.tif', im1y, im1x, 1, gdal.GDT_UInt16, [ 'COMPRESS=LZW' ]) 154 | outdata.SetGeoTransform(newgeo) 155 | outdata.SetProjection(projection) 156 | outdata.GetRasterBand(1).WriteArray(pre1+1) 157 | outdata.FlushCache() ##saves to disk!! 158 | outdata = None 159 | 160 | -------------------------------------------------------------------------------- /demo_keras_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | Last updated on Aug 12 2019 5 | @author: Shengjie Liu 6 | @Email: liushengjie0756@gmail.com 7 | 8 | This script reads *.npy samples for training. 9 | """ 10 | 11 | import numpy as np 12 | from scipy import stats 13 | import rscls 14 | import matplotlib.pyplot as plt 15 | import time 16 | import networks as nw 17 | from keras.utils import to_categorical 18 | import keras 19 | from keras.losses import categorical_crossentropy 20 | from keras.optimizers import Adadelta 21 | from keras import backend as K 22 | import argparse 23 | #from sklearn.covariance import GraphicalLassoCV 24 | from keras.callbacks import EarlyStopping 25 | from scipy.io import loadmat 26 | import gdal 27 | 28 | ## data location 29 | im1_file = 'F:/t0809/data/yubei0603.tif' 30 | gt1_file = 'yubei2.tif' 31 | 32 | ## number of training samples per class 33 | parser = argparse.ArgumentParser(description='manual to this script') 34 | parser.add_argument('--nos', type=int, default = 5) 35 | args = parser.parse_args() 36 | num_per_cls1 = args.nos 37 | 38 | ## network configuration 39 | patch = 5 # If patch==5, WCRN. 
If patch==7, DBMA. If patch==9, ResNet-avg. 40 | vbs = 1 # if vbs==0, training in silent mode 41 | bsz1 = 20 # batch size 42 | ensemble = 1 # if ensemble>1, snapshot ensemble activated 43 | # if the loss does not decrease for 5 epochs, stop training 44 | early_stopping = EarlyStopping(monitor='loss', patience=5, verbose=2) 45 | 46 | 47 | # for Monte Carlo runs 48 | seedx = [0,1,2,3,4,5,6,7,8,9] 49 | seedi = 0 # default seed = 0 50 | 51 | saved = {} # save confusion matrix for raw classification 52 | saved_p = {} # save confusion matrix after object-based refinement 53 | 54 | if True: 55 | for seedi in range(0,1): 56 | time1 = int(time.time()) 57 | K.clear_session() # clear session before next loop 58 | 59 | x1_train,y1_train = np.load('x_train.npy'),np.load('y_train.npy') 60 | cls1 = y1_train.max()+1 61 | y1_train = to_categorical(y1_train) # to one-hot labels 62 | 63 | 64 | ''' training part ''' 65 | im1z = x1_train.shape[-1] 66 | if patch == 7: 67 | model1 = nw.DBMA(im1z,cls1) # 3D CNN, samples are 5-dimensional 68 | x1_train = x1_train.reshape(x1_train.shape[0],patch,patch,im1z,-1) 69 | elif patch == 5: 70 | model1 = nw.wcrn(im1z,cls1) # WCRN 71 | # model1 = nw.DCCNN(im1z,patch,cls1) # DCCNN 72 | elif patch == 9: 73 | model1 = nw.resnet99_avg(im1z,patch,cls1,l=1) 74 | else: 75 | # print('using resnet_avg') 76 | model1 = nw.resnet99_avg(im1z,patch,cls1,l=1) 77 | time2 = int(time.time()) 78 | 79 | # first train the model with lr=1.0 80 | print('start training') 81 | model1.compile(loss=categorical_crossentropy,optimizer=Adadelta(lr=1.0),metrics=['accuracy']) 82 | model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=170,verbose=vbs,shuffle=True,callbacks=[early_stopping]) 83 | 84 | # then train the model with lr=0.1 85 | model1.compile(loss=categorical_crossentropy,optimizer=Adadelta(lr=0.1),metrics=['accuracy']) 86 | model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=30,verbose=vbs,shuffle=True,callbacks=[early_stopping]) 87 | time3 = int(time.time()) # training time 88 | 
print('training time:',time3-time2) 89 | model1.save('model'+str(time3)[-5:]+'.h5') # saves the trained model; comment out to skip saving 90 | -------------------------------------------------------------------------------- /demo_pytorch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Aug 6 03:32:28 2019 4 | 5 | @author: Shengjie Liu 6 | @Email: liuishengjie0756@gmail.com 7 | """ 8 | 9 | import numpy as np 10 | import rscls 11 | from scipy import stats 12 | import time 13 | import torch 14 | import torch.nn as nn 15 | import torch.utils.data as Data 16 | from skimage.segmentation import felzenszwalb 17 | import matplotlib.pyplot as plt 18 | 19 | # Device configuration 20 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 21 | 22 | # Training configuration 23 | imfile = 'paviaU_im.npy' # width*height*channel 24 | gtfile = 'paviaU_gt.npy' # width*height, classes begin from 1 with background=0 25 | 26 | ensemble = 1 # number of snapshot-ensemble rounds 27 | num_per_cls = 10 # number of samples per class 28 | bsz = 20 # batch size 29 | patch = 5 # sample size: 5*5*channel 30 | vbs = 1 # show training process 31 | 32 | # Monte Carlo runs 33 | seedx = [0,1,2,3,4,5,6,7,8,9] 34 | seedi = 0 # default seed is 0 35 | 36 | # network definition 37 | # wide contextual residual network (WCRN) 38 | class WCRN(nn.Module): 39 | def __init__(self, num_classes=9): 40 | super(WCRN, self).__init__() 41 | 42 | self.conv1a = nn.Conv2d(103, 64, kernel_size=3, stride=1, padding=0) # 103 input channels are hard-coded 43 | self.conv1b = nn.Conv2d(103, 64, kernel_size=1, stride=1, padding=0) 44 | self.maxp1 = nn.MaxPool2d(kernel_size = 3) 45 | self.maxp2 = nn.MaxPool2d(kernel_size = 5) 46 | 47 | self.bn1 = nn.BatchNorm2d(128) 48 | self.conv2a = nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0) 49 | self.conv2b = nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0) 50 | 51 | self.fc = nn.Linear(128, num_classes) 52 | 53 | def forward(self, x):
54 | out = self.conv1a(x) 55 | out1 = self.conv1b(x) 56 | out = self.maxp1(out) 57 | out1 = self.maxp2(out1) 58 | 59 | out = torch.cat((out,out1),1) 60 | 61 | out1 = self.bn1(out) 62 | out1 = nn.ReLU()(out1) 63 | out1 = self.conv2a(out1) 64 | out1 = nn.ReLU()(out1) 65 | out1 = self.conv2b(out1) 66 | 67 | out = torch.add(out,out1) 68 | out = out.reshape(out.size(0), -1) 69 | out = self.fc(out) 70 | 71 | return out 72 | 73 | 74 | #%% initialize controller and prepare training and testing samples 75 | for seedi in range(1): # for Monte Carlo runs 76 | print('random seed:',seedi) 77 | _ls = [] 78 | if True: 79 | gt = np.load(gtfile) 80 | cls1 = gt.max() 81 | im = np.load(imfile) 82 | imx,imy,imz = im.shape 83 | c = rscls.rscls(im,gt,cls=cls1) 84 | c.padding(patch) 85 | c.normalize(style='-11') 86 | 87 | np.random.seed(seedx[seedi]) 88 | x_train,y_train = c.train_sample(num_per_cls) 89 | x_train,y_train = rscls.make_sample(x_train,y_train) 90 | 91 | x_test,y_test = c.test_sample() 92 | 93 | # segmentation 94 | seg = felzenszwalb(im[:,:,[30,50,90]],scale=0.5,sigma=0.8, 95 | min_size=5,multichannel=True) 96 | c.locate_obj(seg) # locate samples in superpixels 97 | 98 | # pytorch input: (None,channel,width,height) 99 | x_train = np.transpose(x_train, (0,3,1,2)) 100 | x_test = np.transpose(x_test, (0,3,1,2)) 101 | 102 | # convert np.array to torch.tensor 103 | x_train,y_train = torch.from_numpy(x_train),torch.from_numpy(y_train) 104 | x_test,y_test = torch.from_numpy(x_test),torch.from_numpy(y_test) 105 | 106 | # cast labels to int64 (long), which nn.CrossEntropyLoss expects for class-index targets 107 | y_test = y_test.long() 108 | y_train = y_train.long() 109 | 110 | # define dataset for training and testing 111 | train_set = Data.TensorDataset(x_train,y_train) 112 | test_set = Data.TensorDataset(x_test,y_test) 113 | 114 | train_loader = Data.DataLoader( 115 | dataset = train_set, 116 | batch_size = bsz, 117 | shuffle = True, 118 | num_workers = 0, 119 | ) 120 | 121 | test_loader = Data.DataLoader( 122 | dataset = test_set, 123 |
batch_size = bsz, 124 | shuffle = False, 125 | num_workers = 0, 126 | ) 127 | 128 | #%% begin training 129 | time1 = int(time.time()) 130 | model = WCRN(cls1) 131 | model.to(device) # using gpu or cpu 132 | criterion = nn.CrossEntropyLoss() 133 | 134 | # train the model using lr=1.0 135 | train_model = True 136 | if train_model: 137 | lr = 1.0 138 | optimizer = torch.optim.Adadelta(model.parameters(), lr=lr) 139 | model.train() 140 | total_step = len(train_loader) 141 | num_epochs = 25 142 | for epoch in range(num_epochs): 143 | for i, (images, labels) in enumerate(train_loader): 144 | images = images.to(device) 145 | labels = labels.to(device) 146 | 147 | # Forward pass 148 | outputs = model(images) 149 | loss = criterion(outputs, labels) 150 | 151 | # Backward and optimize 152 | optimizer.zero_grad() 153 | loss.backward() 154 | optimizer.step() 155 | 156 | if (i+1) % 100 == 0: 157 | print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 158 | .format(epoch+1, num_epochs, i+1, total_step, loss.item())) 159 | 160 | 161 | #%% train the model using lr=0.8 162 | model.train() 163 | lr = 0.8 164 | optimizer = torch.optim.Adadelta(model.parameters(), lr=lr) 165 | total_step = len(train_loader) 166 | num_epochs = 15 167 | for epoch in range(num_epochs): 168 | for i, (images, labels) in enumerate(train_loader): 169 | images = images.to(device) # sample to gpu/cpu 170 | labels = labels.to(device) # label to gpu/cpu 171 | 172 | # Forward pass 173 | outputs = model(images) 174 | loss = criterion(outputs, labels) 175 | 176 | # Backward and optimize 177 | optimizer.zero_grad() 178 | loss.backward() 179 | optimizer.step() 180 | 181 | if (i+1) % 100 == 0: # print training 182 | print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 183 | .format(epoch+1, num_epochs, i+1, total_step, loss.item())) 184 | 185 | #%% Train the model using lr=0.1 186 | lr = 0.1 187 | optimizer = torch.optim.Adadelta(model.parameters(), lr=lr) 188 | total_step = len(train_loader) 189 | num_epochs = 10 190 | 
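The script trains in three consecutive stages with learning rates 1.0, 0.8 and 0.1 for 25, 15 and 10 epochs. The same schedule can be written as a small lookup table, sketched below. The names `STAGES` and `lr_for_epoch` are hypothetical and not part of the repository. The sketch covers only the learning-rate values: in the script each stage also builds a fresh `Adadelta` optimizer, so the optimizer's accumulated state restarts at every stage.

```python
# (learning rate, number of epochs) for each training stage in the script
STAGES = [(1.0, 25), (0.8, 15), (0.1, 10)]


def lr_for_epoch(epoch, stages=STAGES):
    """Return the learning rate for a 0-based global epoch index."""
    start = 0
    for lr, n_epochs in stages:
        if epoch < start + n_epochs:
            return lr
        start += n_epochs
    raise ValueError("epoch is beyond the last stage")


total_epochs = sum(n_epochs for _, n_epochs in STAGES)
for epoch in (0, 24, 25, 39, 40, total_epochs - 1):
    print(epoch, lr_for_epoch(epoch))
```

Keeping the schedule in one table would let the three near-identical training loops collapse into a single loop over `STAGES`, if that refactoring is wanted.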
for epoch in range(num_epochs): 191 | for i, (images, labels) in enumerate(train_loader): 192 | images = images.to(device) 193 | labels = labels.to(device) 194 | 195 | # Forward pass 196 | outputs = model(images) 197 | loss = criterion(outputs, labels) 198 | 199 | # Backward and optimize 200 | optimizer.zero_grad() 201 | loss.backward() 202 | optimizer.step() 203 | 204 | if (i+1) % 100 == 0: 205 | print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 206 | .format(epoch+1, num_epochs, i+1, total_step, loss.item())) 207 | time2 = int(time.time()) 208 | print('training time:',time2-time1,'s') 209 | 210 | #%% Test the model 211 | model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance) 212 | with torch.no_grad(): 213 | correct = 0 214 | total = 0 215 | for images, labels in test_loader: 216 | images = images.to(device) 217 | labels = labels.to(device) 218 | outputs = model(images) 219 | 220 | _, predicted = torch.max(outputs.data, 1) 221 | total += labels.size(0) 222 | correct += (predicted == labels).sum().item() 223 | print('Test Accuracy of the model: {} %'.format(100 * correct / total)) 224 | 225 | #%% predict image 226 | time3 = int(time.time()) 227 | pre_all_1 = [] 228 | model.eval() 229 | with torch.no_grad(): 230 | for i in range(ensemble): 231 | pre_rows_1 = [] 232 | 233 | # uncomment if ensemble>1 (note: the line below is a Keras-style call; model1 is not defined in this script) 234 | # model1.fit(x1_train,y1_train,batch_size=bsz1,epochs=2,verbose=vbs,shuffle=True) 235 | for j in range(imx): 236 | # print(j) # monitor predicting stages 237 | sam_row = c.all_sample_row(j) 238 | sam_row = np.transpose(sam_row, (0,3,1,2)) 239 | pre_row1 = model(torch.from_numpy(sam_row).to(device)) 240 | pre_row1 = np.argmax(np.array(pre_row1.cpu()),axis=1) 241 | pre_row1 = pre_row1.reshape(1,imy) 242 | pre_rows_1.append(pre_row1) 243 | pre_all_1.append(np.array(pre_rows_1)) 244 | 245 | time4 = int(time.time()) 246 | print('predicted time:',time4-time3,'s') 247 | 248 | # raw classification 249 | pre_all_1 =
np.array(pre_all_1).reshape(ensemble,imx,imy) 250 | pre1 = np.int8(stats.mode(pre_all_1,axis=0)[0]).reshape(imx,imy) 251 | result11 = rscls.gtcfm(pre1+1,c.gt+1,cls1) 252 | 253 | # after post-processing using superpixel-based refinement 254 | pcmap = rscls.obpc(c.seg,pre1,c.obj) 255 | result12 = rscls.gtcfm(pcmap+1,c.gt+1,cls1) 256 | rscls.save_cmap(pre1,'jet','pre.png') 257 | rscls.save_cmap(pcmap,'jet','pcmap.png') 258 | 259 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import keras 3 | from keras.models import Model 4 | from keras.layers import concatenate, Dense, Dropout, Flatten, Add, SpatialDropout2D, Conv3D 5 | from keras.layers import Conv2D, MaxPooling2D, Input, Activation,AveragePooling2D,BatchNormalization 6 | from keras.layers import MaxPooling3D, AveragePooling3D 7 | from keras import backend as K 8 | from keras import regularizers 9 | from keras import initializers 10 | from keras.initializers import he_normal, RandomNormal 11 | from keras.layers import multiply, GlobalAveragePooling2D, GlobalAveragePooling3D 12 | from keras.layers.core import Reshape, Dropout 13 | 14 | def DCCNN(band, imx, ncla1): 15 | input1 = Input(shape=(imx,imx,band)) 16 | 17 | # define network 18 | conv01 = Conv2D(128,kernel_size=(1,1),padding='valid', 19 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 20 | conv02 = Conv2D(128,kernel_size=(3,3),padding='valid', 21 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 22 | conv03 = Conv2D(128,kernel_size=(5,5),padding='valid', 23 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 24 | bn1 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 25 | beta_initializer='zeros',gamma_initializer='ones', 26 | moving_mean_initializer='zeros', 27 | moving_variance_initializer='ones') 28 | bn2 =
BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 29 | beta_initializer='zeros',gamma_initializer='ones', 30 | moving_mean_initializer='zeros', 31 | moving_variance_initializer='ones') 32 | conv0 = Conv2D(128,kernel_size=(1,1),padding='same', 33 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 34 | conv11 = Conv2D(128,kernel_size=(1,1),padding='same', 35 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 36 | conv12 = Conv2D(128,kernel_size=(1,1),padding='same', 37 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 38 | conv21 = Conv2D(128,kernel_size=(1,1),padding='same', 39 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 40 | conv22 = Conv2D(128,kernel_size=(1,1),padding='same', 41 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 42 | conv31 = Conv2D(128,kernel_size=(1,1),padding='same', 43 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 44 | conv32 = Conv2D(128,kernel_size=(1,1),padding='same', 45 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 46 | conv33 = Conv2D(128,kernel_size=(1,1),padding='same', 47 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 48 | fc1 = Dense(ncla1,activation='softmax',name='output1', 49 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 50 | 51 | # begin 52 | x1 = conv01(input1) 53 | x2 = conv02(input1) 54 | x3 = conv03(input1) 55 | x1 = MaxPooling2D(pool_size=(5,5))(x1) 56 | x2 = MaxPooling2D(pool_size=(3,3))(x2) 57 | x1 = concatenate([x1,x2,x3],axis=-1) 58 | 59 | x1 = Activation('relu')(x1) 60 | x1 = bn1(x1) 61 | x1 = conv0(x1) 62 | 63 | x11 = Activation('relu')(x1) 64 | x11 = bn2(x11) 65 | x11 = conv11(x11) 66 | x11 = Activation('relu')(x11) 67 | x11 = conv12(x11) 68 | x1 = Add()([x1,x11]) 69 | 70 | x11 = Activation('relu')(x1) 71 | x11 = conv21(x11) 72 | x11 = Activation('relu')(x11) 73 | x11 = conv22(x11) 74 | x1 = Add()([x1,x11]) 75 | 76 | x1 = Activation('relu')(x1) 77 | x1 = conv31(x1) 78 | x1 = Activation('relu')(x1) 79 
| x1 = Dropout(0.5)(x1) 80 | x1 = conv32(x1) 81 | x1 = Activation('relu')(x1) 82 | x1 = Dropout(0.5)(x1) 83 | x1 = conv33(x1) 84 | 85 | x1 = Flatten()(x1) 86 | pre1 = fc1(x1) 87 | 88 | model1 = Model(inputs=input1, outputs=pre1) 89 | return model1 90 | 91 | 92 | def DBMA(band, ncla1): 93 | input1 = Input(shape=(7,7,band,1)) 94 | 95 | ## spectral branch 96 | conv11 = Conv3D(24,kernel_size=(1,1,7),strides=(1,1,2),padding='valid', 97 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 98 | 99 | bn12 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 100 | beta_initializer='zeros',gamma_initializer='ones', 101 | moving_mean_initializer='zeros', 102 | moving_variance_initializer='ones') 103 | conv12 = Conv3D(24,kernel_size=(1,1,7),strides=(1,1,1),padding='same', 104 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 105 | 106 | bn13 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 107 | beta_initializer='zeros',gamma_initializer='ones', 108 | moving_mean_initializer='zeros', 109 | moving_variance_initializer='ones') 110 | conv13 = Conv3D(24,kernel_size=(1,1,7),strides=(1,1,1),padding='same', 111 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 112 | 113 | bn14 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 114 | beta_initializer='zeros',gamma_initializer='ones', 115 | moving_mean_initializer='zeros', 116 | moving_variance_initializer='ones') 117 | conv14 = Conv3D(24,kernel_size=(1,1,7),strides=(1,1,1),padding='same', 118 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 119 | bn15 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 120 | beta_initializer='zeros',gamma_initializer='ones', 121 | moving_mean_initializer='zeros', 122 | moving_variance_initializer='ones') 123 | conv15 = Conv3D(60,kernel_size=(1,1,4),strides=(1,1,1),padding='valid', 124 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 125 | fc11 = 
Dense(30,activation=None,kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 126 | fc12 = Dense(60,activation=None,kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 127 | 128 | ## spatial branch 129 | conv21 = Conv3D(24,kernel_size=(1,1,band),strides=(1,1,1),padding='valid', 130 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 131 | 132 | bn22 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 133 | beta_initializer='zeros',gamma_initializer='ones', 134 | moving_mean_initializer='zeros', 135 | moving_variance_initializer='ones') 136 | conv22 = Conv3D(12,kernel_size=(3,3,1),strides=(1,1,1),padding='same', 137 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 138 | 139 | bn23 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 140 | beta_initializer='zeros',gamma_initializer='ones', 141 | moving_mean_initializer='zeros', 142 | moving_variance_initializer='ones') 143 | conv23 = Conv3D(12,kernel_size=(3,3,1),strides=(1,1,1),padding='same', 144 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 145 | 146 | bn24 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 147 | beta_initializer='zeros',gamma_initializer='ones', 148 | moving_mean_initializer='zeros', 149 | moving_variance_initializer='ones') 150 | conv24 = Conv3D(12,kernel_size=(3,3,1),strides=(1,1,1),padding='same', 151 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 152 | bn25 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 153 | beta_initializer='zeros',gamma_initializer='ones', 154 | moving_mean_initializer='zeros', 155 | moving_variance_initializer='ones') 156 | conv25 = Conv3D(24,kernel_size=(3,3,1),strides=(1,1,1),padding='same', 157 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 158 | conv26 = Conv3D(1,activation=None,kernel_size=(3,3,2),strides=(1,1,2),padding='same', 159 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 160 | 
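The DBMA layer definitions above fix several tensor sizes implicitly through kernel sizes, strides and padding modes. The sketch below works through that arithmetic with the usual output-length rules for `'valid'` and `'same'` padding. The helper `conv_out_len` is hypothetical, and `band = 103` is an assumption (the band count hard-coded in the PyTorch WCRN of this repository); other scenes would give different depths.

```python
def conv_out_len(n, kernel, stride=1, padding="valid"):
    """Output length along one axis of a convolution.

    'valid': floor((n - kernel) / stride) + 1
    'same' : ceil(n / stride)
    """
    if padding == "same":
        return -(-n // stride)  # integer ceiling division
    if padding == "valid":
        return (n - kernel) // stride + 1
    raise ValueError("unknown padding: " + padding)


band = 103  # assumption: number of spectral bands in the input cube

# Spectral branch: conv11 uses kernel 7, stride 2, 'valid' along the band axis
depth_after_conv11 = conv_out_len(band, kernel=7, stride=2, padding="valid")
# conv12..conv14 use 'same' padding with stride 1, so the depth is unchanged
depth_after_same = conv_out_len(depth_after_conv11, kernel=7, stride=1, padding="same")
# conv15 uses kernel 4, stride 1, 'valid'
depth_after_conv15 = conv_out_len(depth_after_same, kernel=4, stride=1, padding="valid")

# Spatial branch: conv21 spans the whole band axis, collapsing it to length 1
depth_after_conv21 = conv_out_len(band, kernel=band, stride=1, padding="valid")
# Its dense concatenation stacks 24 filters (conv21) and three 12-filter layers
spatial_channels = 24 + 3 * 12

print(depth_after_conv11, depth_after_conv15, depth_after_conv21, spatial_channels)
```

The 60 channels computed for the spatial branch agree with the `(7,7,60,1)` reshape that `DBMA` applies to the concatenated spatial features.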
161 | fc = Dense(ncla1,activation='softmax',name='output1', 162 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 163 | 164 | # spectral 165 | x1 = conv11(input1) 166 | 167 | x11 = bn12(x1) 168 | x11 = Activation('relu')(x11) 169 | x11 = conv12(x11) 170 | 171 | x12 = concatenate([x1,x11],axis=-1) 172 | x12 = bn13(x12) 173 | x12 = Activation('relu')(x12) 174 | x12 = conv13(x12) 175 | 176 | x13 = concatenate([x1,x11,x12],axis=-1) 177 | x13 = bn14(x13) 178 | x13 = Activation('relu')(x13) 179 | x13 = conv14(x13) 180 | 181 | x14 = concatenate([x1,x11,x12,x13],axis=-1) 182 | x14 = bn15(x14) 183 | x14 = Activation('relu')(x14) 184 | x14 = conv15(x14) 185 | 186 | x1_max = MaxPooling3D(pool_size=(7,7,1))(x14) 187 | x1_avg = AveragePooling3D(pool_size=(7,7,1))(x14) 188 | 189 | x1_max = fc11(x1_max) 190 | x1_max = fc12(x1_max) 191 | 192 | x1_avg = fc11(x1_avg) 193 | x1_avg = fc12(x1_avg) 194 | 195 | x1 = Add()([x1_max,x1_avg]) 196 | x1 = Activation('sigmoid')(x1) 197 | x1 = multiply([x1,x14]) 198 | x1 = GlobalAveragePooling3D()(x1) 199 | 200 | # spatial 201 | x2 = conv21(input1) 202 | x21 = bn22(x2) 203 | x21 = Activation('relu')(x21) 204 | x21 = conv22(x21) 205 | 206 | x22 = concatenate([x2,x21],axis=-1) 207 | x22 = bn23(x22) 208 | x22 = Activation('relu')(x22) 209 | x22 = conv23(x22) 210 | 211 | x23 = concatenate([x2,x21,x22],axis=-1) 212 | x23 = bn24(x23) 213 | x23 = Activation('relu')(x23) 214 | x23 = conv24(x23) 215 | 216 | x24 = concatenate([x2,x21,x22,x23],axis=-1) 217 | x24 = Reshape(target_shape=(7,7,60,1))(x24) 218 | 219 | x2_max = MaxPooling3D(pool_size=(1,1,60))(x24) 220 | x2_avg = AveragePooling3D(pool_size=(1,1,60))(x24) 221 | 222 | x2_max = Reshape(target_shape=(7,7,1))(x2_max) 223 | x2_avg = Reshape(target_shape=(7,7,1))(x2_avg) 224 | 225 | x25 = concatenate([x2_max,x2_avg],axis=-1) 226 | x25 = Reshape(target_shape=(7,7,2,1))(x25) 227 | x25 = conv26(x25) 228 | x25 = Activation('sigmoid')(x25) 229 | 230 | x2 = multiply([x24,x25]) 231 | x2 = 
Reshape(target_shape=(7,7,1,60))(x2) 232 | x2 = GlobalAveragePooling3D()(x2) 233 | 234 | x = concatenate([x1,x2],axis=-1) 235 | pre = fc(x) 236 | 237 | model = Model(inputs=input1, outputs=pre) 238 | return model 239 | 240 | def resnet99_avg_se(band, imx, ncla1, l=1): 241 | input1 = Input(shape=(imx,imx,band)) 242 | 243 | # define network 244 | conv0x = Conv2D(32,kernel_size=(3,3),padding='valid', 245 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 246 | conv0 = Conv2D(32,kernel_size=(3,3),padding='valid', 247 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 248 | bn11 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 249 | beta_initializer='zeros',gamma_initializer='ones', 250 | moving_mean_initializer='zeros', 251 | moving_variance_initializer='ones') 252 | conv11 = Conv2D(64,kernel_size=(3,3),padding='same', 253 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 254 | conv12 = Conv2D(64,kernel_size=(3,3),padding='same', 255 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 256 | fc11 = Dense(4,activation=None,kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 257 | fc12 = Dense(64,activation=None,kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 258 | 259 | 260 | bn21 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 261 | beta_initializer='zeros',gamma_initializer='ones', 262 | moving_mean_initializer='zeros', 263 | moving_variance_initializer='ones') 264 | conv21 = Conv2D(64,kernel_size=(3,3),padding='same', 265 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 266 | conv22 = Conv2D(64,kernel_size=(3,3),padding='same', 267 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 268 | fc21 = Dense(4,activation=None,kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 269 | fc22 = Dense(64,activation=None,kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 270 | 271 | 272 | fc1 = Dense(ncla1,activation='softmax',name='output1', 273 | 
kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 274 | 275 | # x1 276 | x1 = conv0(input1) 277 | x1x = conv0x(input1) 278 | # x1 = MaxPooling2D(pool_size=(2,2))(x1) 279 | # x1x = MaxPooling2D(pool_size=(2,2))(x1x) 280 | x1 = concatenate([x1,x1x],axis=-1) 281 | x11 = bn11(x1) 282 | x11 = Activation('relu')(x11) 283 | x11 = conv11(x11) 284 | x11 = Activation('relu')(x11) 285 | x11 = conv12(x11) 286 | x12 = GlobalAveragePooling2D()(x11) 287 | x12 = fc11(x12) 288 | x12 = fc12(x12) 289 | x12 = Activation('sigmoid')(x12) 290 | x11 = multiply([x11,x12]) 291 | x1 = Add()([x1,x11]) 292 | 293 | if l==2: 294 | x11 = bn21(x1) 295 | x11 = Activation('relu')(x11) 296 | x11 = conv21(x11) 297 | x11 = Activation('relu')(x11) 298 | x11 = conv22(x11) 299 | x12 = GlobalAveragePooling2D()(x11) 300 | x12 = fc11(x12) 301 | x12 = fc12(x12) 302 | x12 = Activation('sigmoid')(x12) 303 | x11 = multiply([x11,x12]) 304 | x1 = Add()([x1,x11]) 305 | 306 | x1 = GlobalAveragePooling2D()(x1) 307 | 308 | # x1 = Flatten()(x1) 309 | pre1 = fc1(x1) 310 | 311 | model1 = Model(inputs=input1, outputs=pre1) 312 | return model1 313 | 314 | def resnet99_avg(band, imx, ncla1, l=1): 315 | input1 = Input(shape=(imx,imx,band)) 316 | 317 | # define network 318 | conv0x = Conv2D(32,kernel_size=(3,3),padding='valid', 319 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 320 | conv0 = Conv2D(32,kernel_size=(3,3),padding='valid', 321 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 322 | bn11 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 323 | beta_initializer='zeros',gamma_initializer='ones', 324 | moving_mean_initializer='zeros', 325 | moving_variance_initializer='ones') 326 | conv11 = Conv2D(64,kernel_size=(3,3),padding='same', 327 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 328 | conv12 = Conv2D(64,kernel_size=(3,3),padding='same', 329 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 330 | bn21 = 
BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 331 | beta_initializer='zeros',gamma_initializer='ones', 332 | moving_mean_initializer='zeros', 333 | moving_variance_initializer='ones') 334 | conv21 = Conv2D(64,kernel_size=(3,3),padding='same', 335 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 336 | conv22 = Conv2D(64,kernel_size=(3,3),padding='same', 337 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 338 | 339 | fc1 = Dense(ncla1,activation='softmax',name='output1', 340 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 341 | 342 | # x1 343 | x1 = conv0(input1) 344 | x1x = conv0x(input1) 345 | # x1 = MaxPooling2D(pool_size=(2,2))(x1) 346 | # x1x = MaxPooling2D(pool_size=(2,2))(x1x) 347 | x1 = concatenate([x1,x1x],axis=-1) 348 | x11 = bn11(x1) 349 | x11 = Activation('relu')(x11) 350 | x11 = conv11(x11) 351 | x11 = Activation('relu')(x11) 352 | x11 = conv12(x11) 353 | x1 = Add()([x1,x11]) 354 | 355 | if l==2: 356 | x11 = bn21(x1) 357 | x11 = Activation('relu')(x11) 358 | x11 = conv21(x11) 359 | x11 = Activation('relu')(x11) 360 | x11 = conv22(x11) 361 | x1 = Add()([x1,x11]) 362 | 363 | x1 = GlobalAveragePooling2D()(x1) 364 | 365 | # x1 = Flatten()(x1) 366 | pre1 = fc1(x1) 367 | 368 | model1 = Model(inputs=input1, outputs=pre1) 369 | return model1 370 | 371 | 372 | def resnet99(band, ncla1): 373 | input1 = Input(shape=(9,9,band)) 374 | 375 | # define network 376 | conv0x = Conv2D(32,kernel_size=(3,3),padding='valid', 377 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 378 | conv0 = Conv2D(32,kernel_size=(3,3),padding='valid', 379 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 380 | bn11 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 381 | beta_initializer='zeros',gamma_initializer='ones', 382 | moving_mean_initializer='zeros', 383 | moving_variance_initializer='ones') 384 | conv11 = Conv2D(64,kernel_size=(3,3),padding='same', 385 | 
kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 386 | conv12 = Conv2D(64,kernel_size=(3,3),padding='same', 387 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 388 | bn21 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 389 | beta_initializer='zeros',gamma_initializer='ones', 390 | moving_mean_initializer='zeros', 391 | moving_variance_initializer='ones') 392 | conv21 = Conv2D(64,kernel_size=(3,3),padding='same', 393 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 394 | conv22 = Conv2D(64,kernel_size=(3,3),padding='same', 395 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 396 | 397 | fc1 = Dense(ncla1,activation='softmax',name='output1', 398 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 399 | 400 | # x1 401 | x1 = conv0(input1) 402 | x1x = conv0x(input1) 403 | # x1 = MaxPooling2D(pool_size=(2,2))(x1) 404 | # x1x = MaxPooling2D(pool_size=(2,2))(x1x) 405 | x1 = concatenate([x1,x1x],axis=-1) 406 | x11 = bn11(x1) 407 | x11 = Activation('relu')(x11) 408 | x11 = conv11(x11) 409 | x11 = Activation('relu')(x11) 410 | x11 = conv12(x11) 411 | x1 = Add()([x1,x11]) 412 | 413 | # x11 = bn21(x1) 414 | # x11 = Activation('relu')(x11) 415 | # x11 = conv21(x11) 416 | # x11 = Activation('relu')(x11) 417 | # x11 = conv22(x11) 418 | # x1 = Add()([x1,x11]) 419 | 420 | x1 = Flatten()(x1) 421 | pre1 = fc1(x1) 422 | 423 | model1 = Model(inputs=input1, outputs=pre1) 424 | return model1 425 | 426 | def wcrn3D(band, ncla1): 427 | input1 = Input(shape=(5,5,band)) 428 | 429 | # define network 430 | conv0x = Conv2D(64,kernel_size=(1,1,7),padding='valid', 431 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 432 | conv0 = Conv2D(64,kernel_size=(3,3,1),padding='valid', 433 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 434 | bn11 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 435 | beta_initializer='zeros',gamma_initializer='ones', 436 | 
moving_mean_initializer='zeros', 437 | moving_variance_initializer='ones') 438 | conv11 = Conv2D(128,kernel_size=(1,1),padding='same', 439 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 440 | conv12 = Conv2D(128,kernel_size=(1,1),padding='same', 441 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 442 | fc1 = Dense(ncla1,activation='softmax',name='output1', 443 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 444 | 445 | # x1 446 | x1 = conv0(input1) 447 | x1x = conv0x(input1) 448 | x1 = MaxPooling2D(pool_size=(3,3))(x1) 449 | x1x = MaxPooling2D(pool_size=(5,5))(x1x) 450 | x1 = concatenate([x1,x1x],axis=-1) 451 | x11 = bn11(x1) 452 | x11 = Activation('relu')(x11) 453 | x11 = conv11(x11) 454 | x11 = Activation('relu')(x11) 455 | x11 = conv12(x11) 456 | x1 = Add()([x1,x11]) 457 | 458 | x1 = Flatten()(x1) 459 | pre1 = fc1(x1) 460 | 461 | model1 = Model(inputs=input1, outputs=pre1) 462 | return model1 463 | 464 | def wcrn(band, ncla1): 465 | input1 = Input(shape=(5,5,band)) 466 | 467 | # define network 468 | conv0x = Conv2D(64,kernel_size=(1,1),padding='valid', 469 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 470 | conv0 = Conv2D(64,kernel_size=(3,3),padding='valid', 471 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 472 | bn11 = BatchNormalization(axis=-1,momentum=0.9,epsilon=0.001,center=True,scale=True, 473 | beta_initializer='zeros',gamma_initializer='ones', 474 | moving_mean_initializer='zeros', 475 | moving_variance_initializer='ones') 476 | conv11 = Conv2D(128,kernel_size=(1,1),padding='same', 477 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 478 | conv12 = Conv2D(128,kernel_size=(1,1),padding='same', 479 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 480 | # 481 | fc1 = Dense(ncla1,activation='softmax',name='output1', 482 | kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 483 | 484 | # x1 485 | x1 = conv0(input1) 486 | x1x = conv0x(input1) 487 | x1 = 
MaxPooling2D(pool_size=(3,3))(x1) 488 | x1x = MaxPooling2D(pool_size=(5,5))(x1x) 489 | x1 = concatenate([x1,x1x],axis=-1) 490 | x11 = bn11(x1) 491 | x11 = Activation('relu')(x11) 492 | x11 = conv11(x11) 493 | x11 = Activation('relu')(x11) 494 | x11 = conv12(x11) 495 | x1 = Add()([x1,x11]) 496 | 497 | x1 = Flatten()(x1) 498 | pre1 = fc1(x1) 499 | 500 | model1 = Model(inputs=input1, outputs=pre1) 501 | return model1 502 | -------------------------------------------------------------------------------- /pytorch/demo_pytorch_v1.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.optim.lr_scheduler import StepLR,MultiStepLR 9 | import torch.utils.data as Data 10 | import torchnet 11 | import rscls 12 | import numpy as np 13 | from scipy import stats 14 | import time 15 | 16 | #%% arguments 17 | #Net = torchnet.wcrn 18 | Net = torchnet.resnet99_avg 19 | imfile = 'paU_im.npy' 20 | gtfile = 'paU_gt.npy' 21 | patch = 9 22 | vbs = 1 23 | 24 | seedx = [0,1,2,3,4,5,6,7,8,9] 25 | seedi = 0 26 | criterion = nn.CrossEntropyLoss() 27 | 28 | #%% 29 | def train(args, model, device, train_loader, optimizer, epoch, vbs=0): 30 | model.train() 31 | for batch_idx, (data, target) in enumerate(train_loader): 32 | data, target = data.to(device), target.to(device) 33 | optimizer.zero_grad() 34 | output = model(data) 35 | loss = criterion(output, target) 36 | loss.backward() 37 | optimizer.step() 38 | if vbs==0: 39 | continue 40 | if batch_idx % args.log_interval == 0: 41 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 42 | epoch, batch_idx * len(data), len(train_loader.dataset), 43 | 100. 
* batch_idx / len(train_loader), loss.item())) 44 | 45 | 46 | def test(args, model, device, test_loader): 47 | model.eval() 48 | test_loss = 0 49 | correct = 0 50 | with torch.no_grad(): 51 | for data, target in test_loader: 52 | data, target = data.to(device), target.to(device) 53 | output = model(data) 54 | test_loss += criterion(output, target).item() # sum up batch loss 55 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 56 | correct += pred.eq(target.view_as(pred)).sum().item() 57 | 58 | test_loss /= len(test_loader.dataset) 59 | 60 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 61 | test_loss, correct, len(test_loader.dataset), 62 | 100. * correct / len(test_loader.dataset))) 63 | 64 | #%% begin training 65 | #if True: 66 | oa = [] 67 | for seedi in range(10): # for Monte Carlo runs 68 | print('random seed:',seedi) 69 | parser = argparse.ArgumentParser(description='PyTorch PaviaU') 70 | parser.add_argument('--nps', type=int, default=10) 71 | parser.add_argument('--batch-size', type=int, default=64) 72 | parser.add_argument('--test-batch-size', type=int, default=1000) 73 | parser.add_argument('--epochs', type=int, default=200) 74 | parser.add_argument('--lr', type=float, default=1.0) 75 | parser.add_argument('--gamma', type=float, default=0.1) 76 | parser.add_argument('--no-cuda', action='store_true', default=True, 77 | help='disables CUDA training') 78 | parser.add_argument('--seed', type=int, default=1) 79 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 80 | help='how many batches to wait before logging training status') 81 | parser.add_argument('--save-model', action='store_true', default=False) 82 | args = parser.parse_args() 83 | use_cuda = not args.no_cuda and torch.cuda.is_available() 84 | torch.manual_seed(args.seed) 85 | device = torch.device("cuda" if use_cuda else "cpu") 86 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 87 | 88 | 
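One detail in the argument setup above is worth checking: `--no-cuda` is declared with `action='store_true'` and `default=True`. With that combination the parsed value is `True` whether or not the flag is given, so `use_cuda` is always `False` and the script runs on the CPU. The source does not say whether this is intended. The sketch below demonstrates the behavior with the standard `argparse` module; `build_parser` is a hypothetical helper, and the `default=False` variant is shown only as a possible alternative.

```python
import argparse


def build_parser(no_cuda_default):
    """Build a parser with only the --no-cuda flag, for demonstration."""
    parser = argparse.ArgumentParser(description='flag demo')
    parser.add_argument('--no-cuda', action='store_true', default=no_cuda_default,
                        help='disables CUDA training')
    return parser


# As written in the script: default=True, so the flag cannot change the value
as_written = build_parser(True)
flag_absent = as_written.parse_args([]).no_cuda
flag_present = as_written.parse_args(['--no-cuda']).no_cuda

# Hypothetical alternative: default=False lets the flag switch CUDA off on demand
alternative = build_parser(False)
alt_absent = alternative.parse_args([]).no_cuda
alt_present = alternative.parse_args(['--no-cuda']).no_cuda

print(flag_absent, flag_present, alt_absent, alt_present)
```

If GPU training is wanted, changing the default to `False` would be one way to enable it whenever `torch.cuda.is_available()` returns true.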
#%% 89 | time1 = time.time() 90 | gt = np.load(gtfile) 91 | cls1 = gt.max() 92 | im = np.load(imfile) 93 | im = np.float32(im) 94 | im = im/5000.0 95 | imx,imy,imz = im.shape 96 | c = rscls.rscls(im,gt,cls=cls1) 97 | c.padding(patch) 98 | # c.normalize(style='01') 99 | 100 | np.random.seed(seedx[seedi]) 101 | x_train,y_train = c.train_sample(args.nps) 102 | x_train,y_train = rscls.make_sample(x_train,y_train) 103 | x_test,y_test = c.test_sample() 104 | x_train = np.transpose(x_train, (0,3,1,2)) 105 | x_test = np.transpose(x_test, (0,3,1,2)) 106 | 107 | x_train,y_train = torch.from_numpy(x_train),torch.from_numpy(y_train) 108 | x_test,y_test = torch.from_numpy(x_test),torch.from_numpy(y_test) 109 | 110 | y_test = y_test.long() 111 | y_train = y_train.long() 112 | 113 | train_set = Data.TensorDataset(x_train,y_train) 114 | test_set = Data.TensorDataset(x_test,y_test) 115 | 116 | train_loader = Data.DataLoader( 117 | dataset = train_set, 118 | batch_size = args.batch_size, 119 | shuffle = True, 120 | **kwargs 121 | ) 122 | 123 | test_loader = Data.DataLoader( 124 | dataset = test_set, 125 | batch_size = args.test_batch_size, 126 | shuffle = False, 127 | **kwargs 128 | ) 129 | 130 | time2 = int(time.time()) 131 | print('load time:',time2-time1,'s') 132 | 133 | model = Net().to(device) 134 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 135 | 136 | scheduler = MultiStepLR(optimizer, milestones=[170,200], gamma=args.gamma) 137 | for epoch in range(1, args.epochs + 1): 138 | train(args, model, device, train_loader, optimizer, epoch, vbs=vbs) 139 | #test(args, model, device, test_loader) 140 | scheduler.step() 141 | 142 | time3 = int(time.time()) 143 | print('train time:',time3-time2,'s') 144 | 145 | # single test 146 | # test(args,model,device,test_loader) 147 | time4 = int(time.time()) 148 | # print('test time:',time4-time3,'s') 149 | 150 | # predict 151 | pre_all_1 = [] 152 | model.eval() 153 | with torch.no_grad(): 154 | ensemble = 1 155 | for i in 
range(ensemble): 156 | pre_rows_1 = [] 157 | for j in range(imx): 158 | # print(j) # monitor predicting stages 159 | sam_row = c.all_sample_row(j) 160 | sam_row = np.transpose(sam_row, (0,3,1,2)) 161 | pre_row1 = model(torch.from_numpy(sam_row).to(device)) 162 | pre_row1 = np.argmax(np.array(pre_row1.cpu()),axis=1) 163 | pre_row1 = pre_row1.reshape(1,imy) 164 | pre_rows_1.append(pre_row1) 165 | pre_all_1.append(np.array(pre_rows_1)) 166 | 167 | time5 = int(time.time()) 168 | print('predicted time:',time5-time4,'s') 169 | 170 | pre_all_1 = np.array(pre_all_1).reshape(ensemble,imx,imy) 171 | pre1 = np.int8(stats.mode(pre_all_1,axis=0)[0]).reshape(imx,imy) 172 | result11 = rscls.gtcfm(pre1+1,c.gt+1,cls1) 173 | oa.append(result11[-1,0]) 174 | rscls.save_cmap(pre1,'jet','pre.png') 175 | 176 | 177 | if args.save_model: 178 | torch.save(model.state_dict(), "mnist_cnn.pt") 179 | 180 | #%% 181 | oa2 = np.array(oa) 182 | print(oa2.mean(),oa2.std()) 183 | -------------------------------------------------------------------------------- /pytorch/readme.md: -------------------------------------------------------------------------------- 1 | This is the pytorch implementation 2 | -------------------------------------------------------------------------------- /pytorch/rscls.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | This is a script for satellite image classification 4 | Last updated on Aug 6 2019 5 | 6 | @author: Shengjie Liu 7 | @Email: liushengjie0756@gmail.com 8 | 9 | @functions 10 | 1. generate samples from satellite images 11 | 2. grid search SVM/random forest parameters 12 | 3. object-based post-classification refinement 13 | superpixel-based regularization for classification maps 14 | 15 | 4. confusion matrix: OA, kappa, PA, UA, AA 16 | 5. 
save maps as images 17 | 18 | 19 | @sample code 20 | c = rscls.rscls(image,ground_truth,cls=number_of_classes) 21 | c.padding(patch) 22 | c.normalize(style='-11') # optional 23 | 24 | x_train,y_train = c.train_sample(num_per_cls) 25 | x_train,y_train = rscls.make_sample(x_train,y_train) 26 | 27 | x_test,y_test = c.test_sample() 28 | 29 | # for superpixel refinement 30 | c.locate_obj(seg) 31 | pcmap = rscls.obpc(c.seg,predicted,c.obj) 32 | 33 | 34 | @Notes 35 | The ground truth file should be in uint8 format, with class labels beginning at 1 36 | Background = 0 37 | """ 38 | 39 | 40 | 41 | import numpy as np 42 | import copy 43 | import scipy.stats as stats 44 | from sklearn.svm import SVC 45 | from sklearn.model_selection import GridSearchCV 46 | from sklearn.ensemble import RandomForestClassifier 47 | from sklearn.naive_bayes import GaussianNB 48 | import matplotlib.pyplot as plt 49 | 50 | class rscls: 51 | def __init__(self,im,gt,cls): 52 | if cls==0: 53 | print('num of class not specified !!') 54 | self.im = copy.deepcopy(im) 55 | if gt.max()!=cls: 56 | self.gt = copy.deepcopy(gt-1) 57 | else: 58 | self.gt = copy.deepcopy(gt-1) 59 | self.gt_b = copy.deepcopy(gt) 60 | self.cls = cls 61 | self.patch = 1 62 | self.imx,self.imy,self.imz = self.im.shape 63 | self.record = [] 64 | self.sample = {} 65 | 66 | def padding(self,patch): 67 | self.patch = patch 68 | pad = self.patch//2 69 | r1 = np.repeat([self.im[0,:,:]], pad, axis=0) 70 | r2 = np.repeat([self.im[-1,:,:]], pad, axis=0) 71 | self.im = np.concatenate((r1, self.im, r2)) 72 | r1 = np.reshape(self.im[:,0,:],[self.imx + 2 * pad, 1, self.imz]) 73 | r2 = np.reshape(self.im[:,-1,:],[self.imx + 2 * pad, 1, self.imz]) 74 | r1 = np.repeat(r1, pad, axis=1) 75 | r2 = np.repeat(r2, pad, axis=1) 76 | self.im = np.concatenate((r1, self.im, r2), axis=1) 77 | self.im = self.im.astype('float32') 78 | 79 | def normalize(self,style='01'): 80 | im = self.im 81 | for i in range(im.shape[-1]): 82 | 
im[:,:,i]=(im[:,:,i]-im[:,:,i].min())/(im[:,:,i].max()-im[:,:,i].min()) 83 | if style == '-11': 84 | self.im = im*2-1 # assign back to self.im; rebinding the local name alone would discard the rescaled array 85 | 86 | def locate_sample(self): 87 | sam = [] 88 | for i in range(self.cls): 89 | _xy = np.array(np.where(self.gt==i)).T 90 | _sam = np.concatenate([_xy,i*np.ones([_xy.shape[0],1])],axis=-1) 91 | try: 92 | sam = np.concatenate([sam,_sam],axis=0) 93 | except: 94 | sam = _sam 95 | self.sample = sam.astype(int) 96 | 97 | def get_patch(self, xy): 98 | d = self.patch//2 99 | x = xy[0] 100 | y = xy[1] 101 | try: 102 | self.im[x][y] 103 | except IndexError: 104 | return [] 105 | x += d 106 | y += d 107 | sam = self.im[(x - d):(x + d + 1), (y - d):(y + d + 1)] 108 | return np.array(sam) 109 | 110 | def train_sample(self,pn): 111 | x_train,y_train = [],[] 112 | self.locate_sample() 113 | _samp = self.sample 114 | for _cls in range(self.cls): 115 | _xy = _samp[_samp[:,2]==_cls] 116 | np.random.shuffle(_xy) 117 | _xy = _xy[:pn,:] 118 | for xy in _xy: 119 | self.gt[xy[0],xy[1]] = 255 # !! 
120 | # 121 | x_train.append(self.get_patch(xy[:-1])) 122 | y_train.append(xy[-1]) 123 | # print(_xy) 124 | x_train,y_train = np.array(x_train), np.array(y_train) 125 | idx = np.random.permutation(x_train.shape[0]) 126 | x_train = x_train[idx] 127 | y_train = y_train[idx] 128 | return x_train,y_train.astype(int) 129 | 130 | def test_sample(self): 131 | x_test,y_test = [],[] 132 | self.locate_sample() 133 | _samp = self.sample 134 | for _cls in range(self.cls): 135 | _xy = _samp[_samp[:,2]==_cls] 136 | np.random.shuffle(_xy) 137 | for xy in _xy: 138 | x_test.append(self.get_patch(xy[:-1])) 139 | y_test.append(xy[-1]) 140 | return np.array(x_test), np.array(y_test) 141 | 142 | def all_sample(self): 143 | imx,imy = self.gt.shape 144 | sample = [] 145 | for i in range(imx): 146 | for j in range(imy): 147 | sample.append(self.get_patch(np.array([i,j]))) 148 | return np.array(sample) 149 | 150 | def all_sample_light(self,clip=0,bs=10): 151 | imx,imy = self.gt.shape 152 | imz = self.im.shape[-1] 153 | patch = self.patch 154 | # fp = np.memmap('allsample' + str(clip) + '.h5', dtype='float32', mode='w+', shape=(imgx*self.IMGY,5,5,bs)) 155 | fp = np.zeros([imx*imy,patch,patch,imz]) 156 | countnum = 0 157 | for i in range(imx*clip,imx*(clip+1)): 158 | for j in range(imy): 159 | xy = np.array([i,j]) 160 | fp[countnum,:,:,:] = self.get_patch(xy) 161 | countnum += 1 162 | return fp 163 | 164 | def all_sample_row_hd(self,sub=0): 165 | imx,imy = self.gt.shape 166 | imz = self.im.shape[-1] 167 | patch = self.patch 168 | # fp = np.memmap('allsample' + str(clip) + '.h5', dtype='float32', mode='w+', shape=(imgx*self.IMGY,5,5,bs)) 169 | fp = np.zeros([imx*imy,patch,patch,imz]) 170 | countnum = 0 171 | for i in range(sub): 172 | for j in range(imy): 173 | xy = np.array([i,j]) 174 | fp[countnum,:,:,:] = self.get_patch(xy) 175 | countnum += 1 176 | return fp 177 | 178 | def all_sample_row(self,sub=0): 179 | imx,imy = self.gt.shape 180 | fp = [] 181 | for j in range(imy): 182 | xy = 
np.array([sub,j]) 183 | fp.append(self.get_patch(xy)) 184 | return np.array(fp) 185 | 186 | def all_sample_heavy(self,name,clip=0,bs=10): 187 | imx,imy = self.gt.shape 188 | imz = self.im.shape[-1] 189 | patch = self.patch 190 | try: 191 | fp = np.memmap(name, dtype='float32', mode='w+', shape=(imx*imy,patch,patch,imz)) 192 | except: 193 | fp = np.memmap(name, dtype='float32', mode='r', shape=(imx*imy,patch,patch,imz)) 194 | # fp = np.zeros([imx*imy,patch,patch,imz]) 195 | countnum = 0 196 | for i in range(imx*clip,imx*(clip+1)): 197 | for j in range(imy): 198 | xy = np.array([i,j]) 199 | fp[countnum,:,:,:] = self.get_patch(xy) 200 | countnum += 1 201 | return fp 202 | 203 | def read_all_sample(self,name,clip=0,bs=10): 204 | imx,imy = self.gt.shape 205 | imz = self.im.shape[-1] 206 | patch = self.patch 207 | fp = np.memmap(name, dtype='float32', mode='r', shape=(imx*imy,patch,patch,imz)) 208 | return fp 209 | 210 | def locate_obj(self,seg): 211 | obj = {} 212 | for i in range(seg.min(),seg.max()+1): 213 | obj[str(i)] = np.where(seg==i) 214 | self.obj = obj 215 | self.seg = seg 216 | 217 | def obpc(seg,cmap,obj): 218 | pcmap = copy.deepcopy(cmap) 219 | for (k,v) in obj.items(): 220 | tmplabel = stats.mode(cmap[v])[0] 221 | pcmap[v] = tmplabel 222 | return pcmap 223 | 224 | def cfm(pre, ref, ncl=9): 225 | if ref.min() != 0: 226 | print('warning: label should begin with 0 !!') 227 | return 228 | 229 | nsize = ref.shape[0] 230 | cf = np.zeros((ncl,ncl)) 231 | for i in range(nsize): 232 | cf[pre[i], ref[i]] += 1 233 | 234 | tmp1 = 0 235 | for j in range(ncl): 236 | tmp1 = tmp1 + (cf[j,:].sum()/nsize)*(cf[:,j].sum()/nsize) 237 | cfm = np.zeros((ncl+2,ncl+1)) 238 | cfm[:-2,:-1] = cf 239 | oa = 0 240 | for i in range(ncl): 241 | if cf[i,:].sum(): 242 | cfm[i,ncl] = cf[i,i]/cf[i,:].sum() 243 | if cf[:,i].sum(): 244 | cfm[ncl,i] = cf[i,i]/cf[:,i].sum() 245 | oa += cf[i,i] 246 | cfm[-1, 0] = oa/nsize 247 | cfm[-1, 1] = (cfm[-1, 0]-tmp1)/(1-tmp1) 248 | cfm[-1, 2] = 
cfm[ncl,:-1].mean() 249 | print('oa: ', format(cfm[-1,0],'.5'), ' kappa: ', format(cfm[-1,1],'.5'), 250 | ' mean: ', format(cfm[-1,2],'.5')) 251 | return cfm 252 | 253 | def gtcfm(pre,gt,ncl): 254 | if gt.max()==255: 255 | print('warning: max 255 !!') 256 | cf = np.zeros([ncl,ncl]) 257 | for i in range(gt.shape[0]): 258 | for j in range(gt.shape[1]): 259 | if gt[i,j]: 260 | cf[pre[i,j]-1,gt[i,j]-1] += 1 261 | tmp1 = 0 262 | nsize = np.sum(gt!=0) 263 | for j in range(ncl): 264 | tmp1 = tmp1 + (cf[j,:].sum()/nsize)*(cf[:,j].sum()/nsize) 265 | cfm = np.zeros((ncl+2,ncl+1)) 266 | cfm[:-2,:-1] = cf 267 | oa = 0 268 | for i in range(ncl): 269 | if cf[i,:].sum(): 270 | cfm[i,ncl] = cf[i,i]/cf[i,:].sum() 271 | if cf[:,i].sum(): 272 | cfm[ncl,i] = cf[i,i]/cf[:,i].sum() 273 | oa += cf[i,i] 274 | cfm[-1, 0] = oa/nsize 275 | cfm[-1, 1] = (cfm[-1, 0]-tmp1)/(1-tmp1) 276 | cfm[-1, 2] = cfm[ncl,:-1].mean() 277 | print('oa: ', format(cfm[-1,0],'.5'), ' kappa: ', format(cfm[-1,1],'.5'), 278 | ' mean: ', format(cfm[-1,2],'.5')) 279 | return cfm 280 | 281 | def svm(trainx,trainy): 282 | cost = [] 283 | gamma = [] 284 | for i in range(-5,16,2): 285 | cost.append(np.power(2.0,i)) 286 | for i in range(-15,4,2): 287 | gamma.append(np.power(2.0,i)) 288 | 289 | parameters = {'C':cost,'gamma':gamma} 290 | svm = SVC(verbose=0,kernel='rbf') 291 | clf = GridSearchCV(svm, parameters,cv=3) 292 | p = clf.fit(trainx, trainy) 293 | 294 | print(clf.best_params_) 295 | bestc = clf.best_params_['C'] 296 | bestg = clf.best_params_['gamma'] 297 | tmpc = [-1.75,-1.5,-1.25,-1,-0.75,-0.5,-0.25,0.0, 298 | 0.25,0.5,0.75,1.0,1.25,1.5,1.75] 299 | cost = [] 300 | gamma=[] 301 | for i in tmpc: 302 | cost.append(bestc*np.power(2.0,i)) 303 | gamma.append(bestg*np.power(2.0,i)) 304 | parameters = {'C':cost,'gamma':gamma} 305 | svm = SVC(verbose=0,kernel='rbf') 306 | clf = GridSearchCV(svm, parameters,cv=3) 307 | p = clf.fit(trainx, trainy) 308 | print(clf.best_params_) 309 | p2 = clf.best_estimator_ 310 | return p2 
311 | 312 | def svm_rbf(trainx,trainy): 313 | cost = [] 314 | gamma = [] 315 | for i in range(-3,10,2): 316 | cost.append(np.power(2.0,i)) 317 | for i in range(-5,4,2): 318 | gamma.append(np.power(2.0,i)) 319 | 320 | parameters = {'C':cost,'gamma':gamma} 321 | svm = SVC(verbose=0,kernel='rbf') 322 | clf = GridSearchCV(svm, parameters,cv=3) 323 | clf.fit(trainx, trainy) 324 | 325 | #print(clf.best_params_) 326 | bestc = clf.best_params_['C'] 327 | bestg = clf.best_params_['gamma'] 328 | tmpc = [-1.75,-1.5,-1.25,-1,-0.75,-0.5,-0.25,0.0, 329 | 0.25,0.5,0.75,1.0,1.25,1.5,1.75] 330 | cost = [] 331 | gamma=[] 332 | for i in tmpc: 333 | cost.append(bestc*np.power(2.0,i)) 334 | gamma.append(bestg*np.power(2.0,i)) 335 | parameters = {'C':cost,'gamma':gamma} 336 | svm = SVC(verbose=0,kernel='rbf') 337 | clf = GridSearchCV(svm, parameters,cv=3) 338 | clf.fit(trainx, trainy) 339 | #print(clf.best_params_) 340 | p = clf.best_estimator_ 341 | return p 342 | 343 | def rf(trainx,trainy,sim=1,nj=1): 344 | nest = [] 345 | nfea = [] 346 | for i in range(20, 201, 20): 347 | nest.append(i) 348 | if sim: 349 | for i in range(1,int(trainx.shape[-1])): 350 | nfea.append(i) 351 | parameters = {'n_estimators':nest,'max_features':nfea} 352 | else: 353 | parameters = {'n_estimators':nest} 354 | rf = RandomForestClassifier(n_jobs=nj,verbose=0,oob_score=False) 355 | clf = GridSearchCV(rf, parameters, cv=3) 356 | p = clf.fit(trainx, trainy) 357 | p2 = clf.best_estimator_ 358 | return p2 359 | 360 | def GNB(trainx,trainy): 361 | clf = GaussianNB() 362 | p = clf.fit(trainx, trainy) 363 | return p 364 | 365 | def svm_linear(trainx,trainy): 366 | cost = [] 367 | for i in range(-3,10,2): 368 | cost.append(np.power(2.0,i)) 369 | 370 | parameters = {'C':cost} 371 | svm = SVC(verbose=0,kernel='linear') 372 | clf = GridSearchCV(svm, parameters,cv=3) 373 | clf.fit(trainx, trainy) 374 | 375 | #print(clf.best_params_) 376 | bestc = clf.best_params_['C'] 377 | tmpc = 
[-1.75,-1.5,-1.25,-1,-0.75,-0.5,-0.25,0.0, 378 | 0.25,0.5,0.75,1.0,1.25,1.5,1.75] 379 | cost = [] 380 | for i in tmpc: 381 | cost.append(bestc*np.power(2.0,i)) 382 | parameters = {'C':cost} 383 | svm = SVC(verbose=0,kernel='linear') 384 | clf = GridSearchCV(svm, parameters,cv=3) 385 | clf.fit(trainx, trainy) 386 | p = clf.best_estimator_ 387 | return p 388 | 389 | def make_sample(sample, label): 390 | a = np.flip(sample,1) 391 | b = np.flip(sample,2) 392 | c = np.flip(b,1) 393 | newsample = np.concatenate((a,b,c,sample),axis=0) 394 | newlabel = np.concatenate((label,label,label,label),axis=0) 395 | return newsample, newlabel 396 | 397 | def save_cmap(img, cmap, fname): 398 | 399 | sizes = np.shape(img) 400 | height = float(sizes[0]) 401 | width = float(sizes[1]) 402 | 403 | fig = plt.figure() 404 | fig.set_size_inches(width/height, 1, forward=False) 405 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 406 | ax.set_axis_off() 407 | fig.add_axes(ax) 408 | 409 | ax.imshow(img, cmap=cmap) 410 | plt.savefig(fname, dpi = height) 411 | plt.close() 412 | 413 | -------------------------------------------------------------------------------- /pytorch/torchnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Jan 6 10:07:13 2020 4 | 5 | @author: sjliu.me@gmail.com 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class wcrn(nn.Module): 13 | def __init__(self, num_classes=9): 14 | super(wcrn, self).__init__() 15 | 16 | self.conv1a = nn.Conv2d(103,64,kernel_size=3,stride=1,padding=0,groups=1) 17 | self.conv1b = nn.Conv2d(103,64,kernel_size=1,stride=1,padding=0,groups=1) 18 | self.maxp1 = nn.MaxPool2d(kernel_size=3) 19 | self.maxp2 = nn.MaxPool2d(kernel_size=5) 20 | 21 | # self.bn1 = nn.BatchNorm2d(128,eps=0.001,momentum=0.9) 22 | self.bn1 = nn.BatchNorm2d(128) 23 | self.conv2a = nn.Conv2d(128,128,kernel_size=1,stride=1,padding=0,groups=1) 24 | 
self.conv2b = nn.Conv2d(128,128,kernel_size=1,stride=1,padding=0,groups=1) 25 | 26 | self.fc = nn.Linear(128, num_classes) 27 | # torch.nn.init.normal_(self.fc.weight, mean=0, std=0.01) 28 | 29 | def forward(self, x): 30 | out = self.conv1a(x) 31 | out1 = self.conv1b(x) 32 | out = self.maxp1(out) 33 | out1 = self.maxp2(out1) 34 | 35 | out = torch.cat((out,out1),1) 36 | 37 | out1 = self.bn1(out) 38 | out1 = nn.ReLU()(out1) 39 | out1 = self.conv2a(out1) 40 | out1 = nn.ReLU()(out1) 41 | out1 = self.conv2b(out1) 42 | 43 | out = torch.add(out,out1) 44 | out = out.reshape(out.size(0), -1) 45 | out = self.fc(out) 46 | 47 | return out 48 | 49 | class resnet99_avg(nn.Module): 50 | def __init__(self, num_classes=9): 51 | super(resnet99_avg, self).__init__() 52 | 53 | self.conv1a = nn.Conv2d(103,32,kernel_size=3,stride=1,padding=0,groups=1) 54 | self.conv1b = nn.Conv2d(103,32,kernel_size=3,stride=1,padding=0,groups=1) 55 | 56 | self.bn1 = nn.BatchNorm2d(64,eps=0.001,momentum=0.9) 57 | self.conv2a = nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1,groups=1) 58 | self.conv2b = nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1,groups=1) 59 | 60 | self.bn2 = nn.BatchNorm2d(64,eps=0.001,momentum=0.9) 61 | self.conv3a = nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1,groups=1) 62 | self.conv3b = nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1,groups=1) 63 | 64 | self.fc = nn.Linear(64, num_classes) 65 | 66 | 67 | def forward(self, x): 68 | x1 = self.conv1a(x) 69 | x2 = self.conv1b(x) 70 | 71 | x1 = torch.cat((x1,x2),axis=1) 72 | x2 = self.bn1(x1) 73 | x2 = nn.ReLU()(x2) 74 | x2 = self.conv2a(x2) 75 | x2 = nn.ReLU()(x2) 76 | x2 = self.conv2b(x2) 77 | x1 = torch.add(x1,x2) 78 | 79 | x2 = self.bn2(x1) 80 | x2 = nn.ReLU()(x2) 81 | x2 = self.conv3a(x2) 82 | x2 = nn.ReLU()(x2) 83 | x2 = self.conv3b(x2) 84 | x1 = torch.add(x1,x2) 85 | 86 | x1 = nn.AdaptiveAvgPool2d((1,1))(x1) 87 | x1 = x1.reshape(x1.size(0), -1) 88 | 89 | out = self.fc(x1) 90 | return out 91 | 
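A note on patch size for `wcrn`: its fully connected layer expects exactly 128 features, so both branches have to shrink to a 1x1 map before `torch.cat`. The sketch below checks this with the standard output-size formula for convolution and pooling layers (dilation 1, floor mode), and relies on `nn.MaxPool2d` defaulting its stride to the kernel size. The helper names `conv_out` and `wcrn_branch_sizes` are hypothetical and are not part of this repository; a 5x5 input patch is an assumption made for illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def wcrn_branch_sizes(patch):
    """Spatial sizes of the two wcrn branches just before torch.cat.

    Hypothetical helper: it mirrors the layer settings in torchnet.wcrn,
    it does not call PyTorch.
    """
    # Branch a: 3x3 convolution without padding, then 3x3 max pooling (stride 3).
    a = conv_out(conv_out(patch, kernel=3), kernel=3, stride=3)
    # Branch b: 1x1 convolution, then 5x5 max pooling (stride 5).
    b = conv_out(conv_out(patch, kernel=1), kernel=5, stride=5)
    return a, b


if __name__ == "__main__":
    for patch in (5, 9):
        print(patch, wcrn_branch_sizes(patch))
```

With a 5x5 patch both branches come out as 1x1, so the concatenated tensor flattens to the 128 features that `nn.Linear(128, num_classes)` expects. With a 9x9 patch the two branches end at different sizes (2 and 1), so the concatenation along the channel axis would fail.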
-------------------------------------------------------------------------------- /rscls.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | This is a script for satellite image classification 4 | Last updated on Aug 6 2019 5 | 6 | @author: Shengjie Liu 7 | @Email: liushengjie0756@gmail.com 8 | 9 | @functions 10 | 1. generate samples from satellite images 11 | 2. grid search SVM/random forest parameters 12 | 3. object-based post-classification refinement 13 | superpixel-based regularization for classification maps 14 | 15 | 4. confusion matrix: OA, kappa, PA, UA, AA 16 | 5. save maps as images 17 | 18 | 19 | @sample code 20 | c = rscls.rscls(image,ground_truth,cls=number_of_classes) 21 | c.padding(patch) 22 | c.normalize(style='-11') # optional 23 | 24 | x_train,y_train = c.train_sample(num_per_cls) 25 | x_train,y_train = rscls.make_sample(x_train,y_train) 26 | 27 | x_test,y_test = c.test_sample() 28 | 29 | # for superpixel refinement 30 | c.locate_obj(seg) 31 | pcmap = rscls.obpc(c.seg,predicted,c.obj) 32 | 33 | 34 | @Notes 35 | The ground truth file should be in uint8 format, with class labels beginning at 1 36 | Background = 0 37 | """ 38 | 39 | 40 | 41 | import numpy as np 42 | import copy 43 | import scipy.stats as stats 44 | from sklearn.svm import SVC 45 | from sklearn.model_selection import GridSearchCV 46 | from sklearn.ensemble import RandomForestClassifier 47 | from sklearn.naive_bayes import GaussianNB 48 | import matplotlib.pyplot as plt 49 | 50 | class rscls: 51 | def __init__(self,im,gt,cls): 52 | if cls==0: 53 | print('num of class not specified !!') 54 | self.im = copy.deepcopy(im) 55 | if gt.max()!=cls: 56 | self.gt = copy.deepcopy(gt-1) 57 | else: 58 | self.gt = copy.deepcopy(gt-1) 59 | self.gt_b = copy.deepcopy(gt) 60 | self.cls = cls 61 | self.patch = 1 62 | self.imx,self.imy,self.imz = self.im.shape 63 | self.record = [] 64 | self.sample = {} 65 | 66 | def padding(self,patch): 67 | self.patch = 
patch 68 | pad = self.patch//2 69 | r1 = np.repeat([self.im[0,:,:]], pad, axis=0) 70 | r2 = np.repeat([self.im[-1,:,:]], pad, axis=0) 71 | self.im = np.concatenate((r1, self.im, r2)) 72 | r1 = np.reshape(self.im[:,0,:],[self.imx + 2 * pad, 1, self.imz]) 73 | r2 = np.reshape(self.im[:,-1,:],[self.imx + 2 * pad, 1, self.imz]) 74 | r1 = np.repeat(r1, pad, axis=1) 75 | r2 = np.repeat(r2, pad, axis=1) 76 | self.im = np.concatenate((r1, self.im, r2), axis=1) 77 | self.im = self.im.astype('float32') 78 | 79 | def normalize(self,style='01'): 80 | im = self.im 81 | for i in range(im.shape[-1]): 82 | im[:,:,i]=(im[:,:,i]-im[:,:,i].min())/(im[:,:,i].max()-im[:,:,i].min()) 83 | if style == '-11': 84 | self.im = im*2-1 # assign back to self.im; rebinding the local name alone would discard the rescaled array 85 | 86 | def locate_sample(self): 87 | sam = [] 88 | for i in range(self.cls): 89 | _xy = np.array(np.where(self.gt==i)).T 90 | _sam = np.concatenate([_xy,i*np.ones([_xy.shape[0],1])],axis=-1) 91 | try: 92 | sam = np.concatenate([sam,_sam],axis=0) 93 | except: 94 | sam = _sam 95 | self.sample = sam.astype(int) 96 | 97 | def get_patch(self, xy): 98 | d = self.patch//2 99 | x = xy[0] 100 | y = xy[1] 101 | try: 102 | self.im[x][y] 103 | except IndexError: 104 | return [] 105 | x += d 106 | y += d 107 | sam = self.im[(x - d):(x + d + 1), (y - d):(y + d + 1)] 108 | return np.array(sam) 109 | 110 | def train_sample(self,pn): 111 | x_train,y_train = [],[] 112 | self.locate_sample() 113 | _samp = self.sample 114 | for _cls in range(self.cls): 115 | _xy = _samp[_samp[:,2]==_cls] 116 | np.random.shuffle(_xy) 117 | _xy = _xy[:pn,:] 118 | for xy in _xy: 119 | self.gt[xy[0],xy[1]] = 255 # !! 
120 | # 121 | x_train.append(self.get_patch(xy[:-1])) 122 | y_train.append(xy[-1]) 123 | # print(_xy) 124 | x_train,y_train = np.array(x_train), np.array(y_train) 125 | idx = np.random.permutation(x_train.shape[0]) 126 | x_train = x_train[idx] 127 | y_train = y_train[idx] 128 | return x_train,y_train.astype(int) 129 | 130 | def test_sample(self): 131 | x_test,y_test = [],[] 132 | self.locate_sample() 133 | _samp = self.sample 134 | for _cls in range(self.cls): 135 | _xy = _samp[_samp[:,2]==_cls] 136 | np.random.shuffle(_xy) 137 | for xy in _xy: 138 | x_test.append(self.get_patch(xy[:-1])) 139 | y_test.append(xy[-1]) 140 | return np.array(x_test), np.array(y_test) 141 | 142 | def all_sample(self): 143 | imx,imy = self.gt.shape 144 | sample = [] 145 | for i in range(imx): 146 | for j in range(imy): 147 | sample.append(self.get_patch(np.array([i,j]))) 148 | return np.array(sample) 149 | 150 | def all_sample_light(self,clip=0,bs=10): 151 | imx,imy = self.gt.shape 152 | imz = self.im.shape[-1] 153 | patch = self.patch 154 | # fp = np.memmap('allsample' + str(clip) + '.h5', dtype='float32', mode='w+', shape=(imgx*self.IMGY,5,5,bs)) 155 | fp = np.zeros([imx*imy,patch,patch,imz]) 156 | countnum = 0 157 | for i in range(imx*clip,imx*(clip+1)): 158 | for j in range(imy): 159 | xy = np.array([i,j]) 160 | fp[countnum,:,:,:] = self.get_patch(xy) 161 | countnum += 1 162 | return fp 163 | 164 | def all_sample_row_hd(self,sub=0): 165 | imx,imy = self.gt.shape 166 | imz = self.im.shape[-1] 167 | patch = self.patch 168 | # fp = np.memmap('allsample' + str(clip) + '.h5', dtype='float32', mode='w+', shape=(imgx*self.IMGY,5,5,bs)) 169 | fp = np.zeros([imx*imy,patch,patch,imz]) 170 | countnum = 0 171 | for i in range(sub): 172 | for j in range(imy): 173 | xy = np.array([i,j]) 174 | fp[countnum,:,:,:] = self.get_patch(xy) 175 | countnum += 1 176 | return fp 177 | 178 | def all_sample_row(self,sub=0): 179 | imx,imy = self.gt.shape 180 | fp = [] 181 | for j in range(imy): 182 | xy = 
np.array([sub,j]) 183 | fp.append(self.get_patch(xy)) 184 | return np.array(fp) 185 | 186 | def all_sample_heavy(self,name,clip=0,bs=10): 187 | imx,imy = self.gt.shape 188 | imz = self.im.shape[-1] 189 | patch = self.patch 190 | try: 191 | fp = np.memmap(name, dtype='float32', mode='w+', shape=(imx*imy,patch,patch,imz)) 192 | except: 193 | fp = np.memmap(name, dtype='float32', mode='r', shape=(imx*imy,patch,patch,imz)) 194 | # fp = np.zeros([imx*imy,patch,patch,imz]) 195 | countnum = 0 196 | for i in range(imx*clip,imx*(clip+1)): 197 | for j in range(imy): 198 | xy = np.array([i,j]) 199 | fp[countnum,:,:,:] = self.get_patch(xy) 200 | countnum += 1 201 | return fp 202 | 203 | def read_all_sample(self,name,clip=0,bs=10): 204 | imx,imy = self.gt.shape 205 | imz = self.im.shape[-1] 206 | patch = self.patch 207 | fp = np.memmap(name, dtype='float32', mode='r', shape=(imx*imy,patch,patch,imz)) 208 | return fp 209 | 210 | def locate_obj(self,seg): 211 | obj = {} 212 | for i in range(seg.min(),seg.max()+1): 213 | obj[str(i)] = np.where(seg==i) 214 | self.obj = obj 215 | self.seg = seg 216 | 217 | def obpc(seg,cmap,obj): 218 | pcmap = copy.deepcopy(cmap) 219 | for (k,v) in obj.items(): 220 | tmplabel = stats.mode(cmap[v])[0] 221 | pcmap[v] = tmplabel 222 | return pcmap 223 | 224 | def cfm(pre, ref, ncl=9): 225 | if ref.min() != 0: 226 | print('warning: label should begin with 0 !!') 227 | return 228 | 229 | nsize = ref.shape[0] 230 | cf = np.zeros((ncl,ncl)) 231 | for i in range(nsize): 232 | cf[pre[i], ref[i]] += 1 233 | 234 | tmp1 = 0 235 | for j in range(ncl): 236 | tmp1 = tmp1 + (cf[j,:].sum()/nsize)*(cf[:,j].sum()/nsize) 237 | cfm = np.zeros((ncl+2,ncl+1)) 238 | cfm[:-2,:-1] = cf 239 | oa = 0 240 | for i in range(ncl): 241 | if cf[i,:].sum(): 242 | cfm[i,ncl] = cf[i,i]/cf[i,:].sum() 243 | if cf[:,i].sum(): 244 | cfm[ncl,i] = cf[i,i]/cf[:,i].sum() 245 | oa += cf[i,i] 246 | cfm[-1, 0] = oa/nsize 247 | cfm[-1, 1] = (cfm[-1, 0]-tmp1)/(1-tmp1) 248 | cfm[-1, 2] = 
cfm[ncl,:-1].mean() 249 | print('oa: ', format(cfm[-1,0],'.5'), ' kappa: ', format(cfm[-1,1],'.5'), 250 | ' mean: ', format(cfm[-1,2],'.5')) 251 | return cfm 252 | 253 | def gtcfm(pre,gt,ncl): 254 | if gt.max()==255: 255 | print('warning: max 255 !!') 256 | cf = np.zeros([ncl,ncl]) 257 | for i in range(gt.shape[0]): 258 | for j in range(gt.shape[1]): 259 | if gt[i,j]: 260 | cf[pre[i,j]-1,gt[i,j]-1] += 1 261 | tmp1 = 0 262 | nsize = np.sum(gt!=0) 263 | for j in range(ncl): 264 | tmp1 = tmp1 + (cf[j,:].sum()/nsize)*(cf[:,j].sum()/nsize) 265 | cfm = np.zeros((ncl+2,ncl+1)) 266 | cfm[:-2,:-1] = cf 267 | oa = 0 268 | for i in range(ncl): 269 | if cf[i,:].sum(): 270 | cfm[i,ncl] = cf[i,i]/cf[i,:].sum() 271 | if cf[:,i].sum(): 272 | cfm[ncl,i] = cf[i,i]/cf[:,i].sum() 273 | oa += cf[i,i] 274 | cfm[-1, 0] = oa/nsize 275 | cfm[-1, 1] = (cfm[-1, 0]-tmp1)/(1-tmp1) 276 | cfm[-1, 2] = cfm[ncl,:-1].mean() 277 | print('oa: ', format(cfm[-1,0],'.5'), ' kappa: ', format(cfm[-1,1],'.5'), 278 | ' mean: ', format(cfm[-1,2],'.5')) 279 | return cfm 280 | 281 | def svm(trainx,trainy): 282 | cost = [] 283 | gamma = [] 284 | for i in range(-5,16,2): 285 | cost.append(np.power(2.0,i)) 286 | for i in range(-15,4,2): 287 | gamma.append(np.power(2.0,i)) 288 | 289 | parameters = {'C':cost,'gamma':gamma} 290 | svm = SVC(verbose=0,kernel='rbf') 291 | clf = GridSearchCV(svm, parameters,cv=3) 292 | p = clf.fit(trainx, trainy) 293 | 294 | print(clf.best_params_) 295 | bestc = clf.best_params_['C'] 296 | bestg = clf.best_params_['gamma'] 297 | tmpc = [-1.75,-1.5,-1.25,-1,-0.75,-0.5,-0.25,0.0, 298 | 0.25,0.5,0.75,1.0,1.25,1.5,1.75] 299 | cost = [] 300 | gamma=[] 301 | for i in tmpc: 302 | cost.append(bestc*np.power(2.0,i)) 303 | gamma.append(bestg*np.power(2.0,i)) 304 | parameters = {'C':cost,'gamma':gamma} 305 | svm = SVC(verbose=0,kernel='rbf') 306 | clf = GridSearchCV(svm, parameters,cv=3) 307 | p = clf.fit(trainx, trainy) 308 | print(clf.best_params_) 309 | p2 = clf.best_estimator_ 310 | return p2 
311 | 312 | def svm_rbf(trainx,trainy): 313 | cost = [] 314 | gamma = [] 315 | for i in range(-3,10,2): 316 | cost.append(np.power(2.0,i)) 317 | for i in range(-5,4,2): 318 | gamma.append(np.power(2.0,i)) 319 | 320 | parameters = {'C':cost,'gamma':gamma} 321 | svm = SVC(verbose=0,kernel='rbf') 322 | clf = GridSearchCV(svm, parameters,cv=3) 323 | clf.fit(trainx, trainy) 324 | 325 | #print(clf.best_params_) 326 | bestc = clf.best_params_['C'] 327 | bestg = clf.best_params_['gamma'] 328 | tmpc = [-1.75,-1.5,-1.25,-1,-0.75,-0.5,-0.25,0.0, 329 | 0.25,0.5,0.75,1.0,1.25,1.5,1.75] 330 | cost = [] 331 | gamma=[] 332 | for i in tmpc: 333 | cost.append(bestc*np.power(2.0,i)) 334 | gamma.append(bestg*np.power(2.0,i)) 335 | parameters = {'C':cost,'gamma':gamma} 336 | svm = SVC(verbose=0,kernel='rbf') 337 | clf = GridSearchCV(svm, parameters,cv=3) 338 | clf.fit(trainx, trainy) 339 | #print(clf.best_params_) 340 | p = clf.best_estimator_ 341 | return p 342 | 343 | def rf(trainx,trainy,sim=1,nj=1): 344 | nest = [] 345 | nfea = [] 346 | for i in range(20, 201, 20): 347 | nest.append(i) 348 | if sim: 349 | for i in range(1,int(trainx.shape[-1])): 350 | nfea.append(i) 351 | parameters = {'n_estimators':nest,'max_features':nfea} 352 | else: 353 | parameters = {'n_estimators':nest} 354 | rf = RandomForestClassifier(n_jobs=nj,verbose=0,oob_score=False) 355 | clf = GridSearchCV(rf, parameters, cv=3) 356 | p = clf.fit(trainx, trainy) 357 | p2 = clf.best_estimator_ 358 | return p2 359 | 360 | def GNB(trainx,trainy): 361 | clf = GaussianNB() 362 | p = clf.fit(trainx, trainy) 363 | return p 364 | 365 | def svm_linear(trainx,trainy): 366 | cost = [] 367 | for i in range(-3,10,2): 368 | cost.append(np.power(2.0,i)) 369 | 370 | parameters = {'C':cost} 371 | svm = SVC(verbose=0,kernel='linear') 372 | clf = GridSearchCV(svm, parameters,cv=3) 373 | clf.fit(trainx, trainy) 374 | 375 | #print(clf.best_params_) 376 | bestc = clf.best_params_['C'] 377 | tmpc = 
[-1.75,-1.5,-1.25,-1,-0.75,-0.5,-0.25,0.0, 378 | 0.25,0.5,0.75,1.0,1.25,1.5,1.75] 379 | cost = [] 380 | for i in tmpc: 381 | cost.append(bestc*np.power(2.0,i)) 382 | parameters = {'C':cost} 383 | svm = SVC(verbose=0,kernel='linear') 384 | clf = GridSearchCV(svm, parameters,cv=3) 385 | clf.fit(trainx, trainy) 386 | p = clf.best_estimator_ 387 | return p 388 | 389 | def make_sample(sample, label): 390 | a = np.flip(sample,1) 391 | b = np.flip(sample,2) 392 | c = np.flip(b,1) 393 | newsample = np.concatenate((a,b,c,sample),axis=0) 394 | newlabel = np.concatenate((label,label,label,label),axis=0) 395 | return newsample, newlabel 396 | 397 | def save_cmap(img, cmap, fname): 398 | 399 | sizes = np.shape(img) 400 | height = float(sizes[0]) 401 | width = float(sizes[1]) 402 | 403 | fig = plt.figure() 404 | fig.set_size_inches(width/height, 1, forward=False) 405 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 406 | ax.set_axis_off() 407 | fig.add_axes(ax) 408 | 409 | ax.imshow(img, cmap=cmap) 410 | plt.savefig(fname, dpi = height) 411 | plt.close() 412 | 413 | --------------------------------------------------------------------------------
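A note on the accuracy metrics in `rscls.gtcfm`: the function fills the confusion matrix with a double loop over pixels, then derives overall accuracy (OA) and Cohen's kappa from it. The sketch below computes the same quantities with vectorized NumPy calls, which may be easier to follow and faster on large maps. The name `confusion_summary` is hypothetical and not part of this repository; it assumes, as `gtcfm` does, that labels start at 1, that 0 marks background, and that predictions are stored rows-first in the matrix.

```python
import numpy as np


def confusion_summary(pre, gt, ncl):
    """Confusion matrix, overall accuracy and kappa over labelled pixels.

    Hypothetical helper mirroring the convention of rscls.gtcfm:
    labels run from 1 to ncl and 0 marks background (ignored).
    """
    mask = gt != 0                      # keep only pixels with a reference label
    p = pre[mask].astype(int) - 1       # predicted class, shifted to 0-based
    g = gt[mask].astype(int) - 1        # reference class, shifted to 0-based

    cf = np.zeros((ncl, ncl))
    np.add.at(cf, (p, g), 1)            # rows: prediction, columns: reference

    n = cf.sum()
    oa = np.trace(cf) / n               # observed agreement
    # Chance agreement from the row and column marginals.
    pe = (cf.sum(axis=1) * cf.sum(axis=0)).sum() / n ** 2
    kappa = (oa - pe) / (1 - pe)
    return cf, oa, kappa


if __name__ == "__main__":
    gt = np.array([[1, 1, 2], [2, 0, 2]], dtype=np.uint8)
    pre = np.array([[1, 2, 2], [2, 1, 1]], dtype=np.uint8)
    cf, oa, kappa = confusion_summary(pre, gt, ncl=2)
    print(cf)
    print(round(oa, 4), round(kappa, 4))
```

In this toy map five pixels carry a reference label and three of them are predicted correctly, so OA is 0.6. The chance agreement is 13/25 = 0.52, which gives kappa = 0.08/0.48, roughly 0.1667.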