├── .gitignore
├── LICENSE
├── LSM.py
├── README.md
├── SHAP.py
├── comparison.py
├── figs
│   └── Overview.jpg
├── figure.py
├── meta_LSM.py
├── metatask_sampling
│   └── .gitkeep
├── modeling.py
├── models_of_blocks
│   ├── .gitkeep
│   └── HK
│       └── .gitkeep
├── requirements.txt
├── scene_segmentation.py
├── src_data
│   ├── HK
│   │   ├── composite.tfw
│   │   ├── composite.tif
│   │   ├── composite.tif.aux.xml
│   │   └── composite.tif.ovr
│   ├── Ts_HK.csv
│   ├── grid_samples_HK.csv
│   ├── samples_HK.csv
│   └── samples_HK_noTS.csv
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # root
2 | /*.xlsx
3 | /*.xls
4 | #/*.txt
5 | 
6 | # pycache
7 | .idea/
8 | __pycache__/
9 | 
10 | # figs
11 | figs/*.pdf
12 | 
13 | # models_of_blocks
14 | models_of_blocks/**/*.npz
15 | 
16 | # metatask_sampling
17 | metatask_sampling/*.tif
18 | metatask_sampling/*.xlsx
19 | 
20 | # files
21 | unsupervised_pretraining/
22 | tmp/
23 | checkpoint_dir/
24 | 
25 | *~
26 | 
27 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2021 Li CHEN
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /LSM.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import pandas as pd 3 | import numpy as np 4 | from osgeo import gdal 5 | 6 | from meta_LSM import FLAGS 7 | from modeling import MAML 8 | from utils import batch_generator, read_pts, read_tasks 9 | 10 | 11 | def readfxy_csv(file): 12 | tmp = np.loadtxt(file, dtype=str, delimiter=",", encoding='UTF-8') 13 | features = tmp[1:, :-2].astype(np.float32) 14 | features = features / features.max(axis=0) 15 | xy = tmp[1:, -2:].astype(np.float32) 16 | return features, xy 17 | 18 | 19 | def getclusters(gridpts_xy, taskpts, tifformat_path): 20 | dataset = gdal.Open(tifformat_path) 21 | if not dataset: 22 | print("can not open *.tif file!") 23 | im_geotrans = dataset.GetGeoTransform() 24 | gridcluster = [[] for i in range(len(taskpts))] 25 | for i in range(np.shape(gridpts_xy)[0]): 26 | height = int((gridpts_xy[i][1] - im_geotrans[3]) / im_geotrans[5]) 27 | width = int((gridpts_xy[i][0] - im_geotrans[0]) / im_geotrans[1]) 28 | for j in range(len(taskpts)): 29 | if [height, width] in taskpts[j].tolist(): 30 | gridcluster[j].append(i) 31 | break 32 | return gridcluster 33 | 34 | 35 | def predict_LSM(tasks_samples, features, xy, indexes, savename, num_updates=5): 36 | """restore model from checkpoint""" 37 | tf.compat.v1.disable_eager_execution() 38 | model = MAML(FLAGS.dim_input, FLAGS.dim_output, test_num_updates=5) 39 | input_tensors_input = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_input) 40 | input_tensors_label = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_output) 41 | model.construct_model(input_tensors_input=input_tensors_input, input_tensors_label=input_tensors_label, 42 | prefix='metatrain_') 43 | exp_string = '.mbs' + str(FLAGS.meta_batch_size) + '.ubs_' + \ 44 | str(FLAGS.num_samples_each_task) + '.numstep' + str(FLAGS.num_updates) + \ 45 | '.updatelr' + str(FLAGS.update_lr) + '.meta_lr' + str(FLAGS.meta_lr) 46 | saver = tf.compat.v1.train.Saver(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)) 47 | sess = tf.compat.v1.InteractiveSession() 48 | init = tf.compat.v1.global_variables() # optimizer里会有额外variable需要初始化 49 | sess.run(tf.compat.v1.variables_initializer(var_list=init)) 50 | model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) 51 | if model_file: 52 | print("Restoring model weights from " + model_file) 53 | saver.restore(sess, model_file) # 以model_file初始化sess中图 54 | else: 55 | print("no intermediate model found!") 56 | 57 | savearr = np.arange(4, dtype=np.float32).reshape((1, 4)) # save predicting result 58 | 59 | for i in range(len(tasks_samples)): 60 | np.random.shuffle(tasks_samples[i]) 61 | with tf.compat.v1.variable_scope('model', reuse=True): # Variable reuse in np.normalize() 62 | if len(tasks_samples[i]) > FLAGS.num_samples_each_task: 63 | train_ = tasks_samples[i][:int(len(tasks_samples[i]) / 2)] 64 | batch_size = FLAGS.test_update_batch_size 65 | else: 66 | train_ = tasks_samples[i] 67 | batch_size = int(len(train_) / 2) 68 | fast_weights = model.weights 69 | for j in range(num_updates): 70 | inputa, labela = batch_generator(train_, FLAGS.dim_input, FLAGS.dim_output, 71 | batch_size) 72 | loss = model.loss_func(model.forward(inputa, fast_weights, reuse=True), labela) 73 | grads = tf.gradients(ys=loss, xs=list(fast_weights.values())) 74 | gradients = 
dict(zip(fast_weights.keys(), grads)) 75 | fast_weights = dict(zip(fast_weights.keys(), 76 | [fast_weights[key] - model.update_lr * gradients[key] for key in 77 | fast_weights.keys()])) 78 | 79 | """predict LSM""" 80 | if len(indexes[i]): 81 | features_arr = np.array([features[index] for index in indexes[i]]) 82 | xy_arr = np.array([xy[index] for index in indexes[i]]) 83 | pred = model.forward(features_arr, fast_weights, reuse=True) 84 | pred = sess.run(tf.nn.softmax(pred)) 85 | tmp = np.hstack( 86 | (xy_arr[:, 0].reshape(xy_arr.shape[0], 1), xy_arr[:, 1].reshape(xy_arr.shape[0], 1), pred)) 87 | savearr = np.vstack((savearr, tmp)) 88 | """save model parameters to npz file""" 89 | adapted_weights = sess.run(fast_weights) 90 | np.savez('models_of_blocks/HK/model' + str(i), adapted_weights['w1'], adapted_weights['b1'], 91 | adapted_weights['w2'], adapted_weights['b2'], 92 | adapted_weights['w3'], adapted_weights['b3'], 93 | adapted_weights['w4'], adapted_weights['b4']) 94 | 95 | writer = pd.ExcelWriter('tmp/' + savename) 96 | data_df = pd.DataFrame(savearr) 97 | data_df.to_excel(writer) 98 | writer.close() 99 | 100 | print('save LSM successfully') 101 | sess.close() 102 | 103 | 104 | if __name__ == "__main__": 105 | print('grid points assignment...') 106 | HK_tasks = read_tasks('./metatask_sampling/HK_tasks_K{k}.xlsx'.format(k=FLAGS.K)) 107 | HK_taskpts = read_pts('./metatask_sampling/HKpts_tasks_K{k}.xlsx'.format(k=FLAGS.K)) 108 | HK_gridpts_feature, HK_gridpts_xy = readfxy_csv('./src_data/grid_samples_HK.csv') 109 | HK_gridcluster = getclusters(HK_gridpts_xy, HK_taskpts, './metatask_sampling/' + FLAGS.str_region + \ 110 | '_SLIC_M{m}_K{k}_loop{loop}.tif'.format(loop=0, m=FLAGS.M, k=FLAGS.K)) 111 | 112 | print('adapt and predict...') 113 | predict_LSM(HK_tasks, HK_gridpts_feature, HK_gridpts_xy, HK_gridcluster, 'proposed_prediction.xlsx') 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

3 |

4 | 
5 | # Hong Kong Landslide Susceptibility Mapping in a Meta-learning Way (tf2)
6 | 
7 | [//]: # (# Landslide Susceptibility Assessment in Multiple Landslide-inducing Environments with a Landslide Inventory Augmented by InSAR Techniques)
8 | 
9 | ## Table of Contents
10 | 
11 | - [Background](#background)
12 | - [Data](#data)
13 | - [Dependencies](#dependencies)
14 | - [Usage](#usage)
15 | - [Contact](#contact)
16 | - [Citation](#citation)
17 | 
18 | 
19 | ## Background
20 | Landslide susceptibility assessment (LSA) is vital for landslide hazard mitigation and prevention.
21 | Recently, data-driven LSA methods have been widely applied owing to the increased
22 | availability of high-quality satellite data and landslide statistics. However, two issues remain to
23 | be addressed: (a) landslide records obtained from a landslide inventory (LI) are mainly
24 | based on the interpretation of optical images and site investigation, so current data-driven
25 | models are insensitive to slope dynamics such as slow-moving landslides; (b) most
26 | study areas contain a variety of landslide-inducing environments (LIEs) that a single model
27 | cannot accommodate well. In this study, we proposed using InSAR techniques
28 | to sample weak landslide labels from slow-moving slopes for LI augmentation, and meta-learning
29 | intermediate representations for the fast adaptation of LSA models to different
30 | LIEs. We performed feature permutation to identify dominant landslide-inducing factors (LIFs)
31 | and to provide guidance for targeted landslide prevention schemes. The results obtained in Hong
32 | Kong revealed that deformation in several mountainous regions is closely associated with the
33 | majority of recorded landslides. By augmenting the LI with InSAR techniques, the proposed
34 | method improved the perception of potential dynamic landslides and achieved better statistical
35 | performance. The discussion highlights that slope and the stream power index (SPI) are the key
36 | LIFs in Hong Kong, although the dominant LIFs vary under different LIEs. By comparison, the
37 | proposed method entails a fast-learning strategy and substantially outperforms other data-driven
38 | LSA techniques, e.g., by 3-6% in accuracy, 2-6% in precision, 1-2% in recall, 3-5% in F1-score,
39 | and approximately 10% in Cohen's kappa.
40 | 
41 | 
42 | ![Overview](./figs/Overview.jpg)
43 | Fig. 1: Overview
44 | 
45 | 
46 | ## Data
47 | 
48 | * The landslide inventory can be found [here](https://data.gov.hk/en-data/dataset/hk-cedd-csu-cedd-entli).
49 | * The related thematic information can be found [here](https://geodata.gov.hk/gs).
50 | * The nonlandslide/landslide sample vectors are stored in `./src_data/`, where `samples_HK.csv` and `samples_HK_noTS.csv` are the datasets with and without augmented slow-moving landslides, respectively.
51 | 
52 | [//]: # ()
53 | [//]: # (The source and experiment data will be opened...)
54 | 
55 | 
56 | ## Dependencies
57 | 
58 | The default branch uses a tf2 environment:
59 | * cudatoolkit 11.2.2
60 | * cudnn 8.1.0.77
61 | * python 3.9.13
62 | * tensorflow 2.10.0
63 | 
64 | Install the required packages:
65 | ```
66 | python -m pip install -r requirements.txt
67 | ```
68 | 
69 | ## Usage
70 | 
71 | * For the scene segmentation and task sampling stage, see `./scene_segmentation.py`; the results are output to the `./metatask_sampling` folder.
72 | * For the meta-learner, see `./meta_LSM.py`.
73 | * For model adaptation and landslide susceptibility prediction, see `./LSM.py`. The intermediate model and the adapted models of the blocks are saved in the `./checkpoint_dir` and `./models_of_blocks` folders, respectively. The adapted models predict the susceptibility for each sample vector in `./src_data/grid_samples_HK.csv` (see the loading sketch after this list).
74 | * The `./tmp` folder stores temporary records.
75 | * For plotting the experimental figures, see `./figure.py`; the figures are saved in the `./figs` folder.
76 | 
77 | 
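As a quick illustration of how the per-block outputs can be reused, the snippet below is a minimal sketch that loads one adapted block model and scores the grid samples. It assumes the `.npz` layout written by `LSM.py` (`w1, b1, ..., w4, b4` saved as `arr_0`-`arr_7`), a hypothetical block index `0`, and a plain ReLU/linear forward pass, i.e., it ignores any normalization applied inside `modeling.MAML.forward`. Treat it as a sketch rather than the exact network definition.

```
import numpy as np

# Adapted weights for one block, in the order written by LSM.predict_LSM (block index 0 is hypothetical).
npz = np.load('./models_of_blocks/HK/model0.npz')
w = {'w1': npz['arr_0'], 'b1': npz['arr_1'],
     'w2': npz['arr_2'], 'b2': npz['arr_3'],
     'w3': npz['arr_4'], 'b3': npz['arr_5'],
     'w4': npz['arr_6'], 'b4': npz['arr_7']}

# Grid sample vectors: feature columns first, x/y coordinates last;
# normalize each feature by its column maximum, as in LSM.readfxy_csv.
tmp = np.loadtxt('./src_data/grid_samples_HK.csv', dtype=str, delimiter=',', encoding='UTF-8')
features = tmp[1:, :-2].astype(np.float32)
features = features / features.max(axis=0)

def forward(x, w):
    """Simplified forward pass: ReLU hidden layers, linear output (normalization omitted)."""
    h = x
    for i in (1, 2, 3):
        h = np.maximum(h @ w['w%d' % i] + w['b%d' % i], 0.0)
    return h @ w['w4'] + w['b4']

logits = forward(features, w)
z = logits - logits.max(axis=1, keepdims=True)             # numerically stable softmax
probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)   # two-class susceptibility probabilities
print(probs[:5])
```

For the exact architecture, including the normalization selected by `FLAGS.norm`, refer to `modeling.py`.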
78 | ## Contact
79 | 
80 | To ask questions or report issues, please open an issue on the [issue tracker](https://github.com/CLi-de/Meta_LSM/issues).
81 | 
82 | ## Citation
83 | 
84 | If this repository helps your research, please cite the paper. Here is the BibTeX entry:
85 | 
86 | ```
87 | @article{CHEN2023107342,
88 | title = {Landslide susceptibility assessment in multiple urban slope settings with a landslide inventory augmented by InSAR techniques},
89 | journal = {Engineering Geology},
90 | volume = {327},
91 | pages = {107342},
92 | year = {2023},
93 | issn = {0013-7952},
94 | doi = {10.1016/j.enggeo.2023.107342},
95 | author = {Li Chen and Peifeng Ma and Chang Yu and Yi Zheng and Qing Zhu and Yulin Ding},
96 | }
97 | ```
98 | 
99 | The full text is available [here](https://www.sciencedirect.com/science/article/pii/S0013795223003605).
--------------------------------------------------------------------------------
/SHAP.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Author : CHEN Li
4 | # @Time : 2022/12/2 14:55
5 | # @File : SHAP.py
6 | # @annotation
7 | 
8 | import tensorflow as tf
9 | import xgboost
10 | import shap
11 | import warnings
12 | import matplotlib.pyplot as plt
13 | from sklearn import svm
14 | import numpy as np
15 | import pandas as pd
16 | from meta_LSM import FLAGS
17 | from modeling import MAML
18 | 
19 | from sklearn.model_selection import train_test_split
20 | 
21 | warnings.filterwarnings("ignore")
22 | 
23 | 
24 | # construct model
25 | def init_weights(file):
26 |     """Read the DAS weight parameters from an .npz file."""
27 |     with tf.compat.v1.variable_scope('model'):  # get variable in 'model' scope, to reuse variables
28 |         npzfile = np.load(file)
29 |         weights = {}
30 |         weights['w1'] = npzfile['arr_0']
31 |         weights['b1'] = npzfile['arr_1']
32 |         weights['w2'] = npzfile['arr_2']
33 |         weights['b2'] = npzfile['arr_3']
34 |         weights['w3'] = npzfile['arr_4']
35 |         weights['b3'] = npzfile['arr_5']
36 |         weights['w4'] = npzfile['arr_6']
37 |         weights['b4'] = npzfile['arr_7']
38 |     return weights
39 | 
40 | 
41 | # define model.pred_prob() for shap.KernelExplainer(model, data)
42 | def pred_prob(X_):
43 |     with tf.compat.v1.variable_scope('model', reuse=True):
44 |         return sess.run(tf.nn.softmax(model.forward(X_, model.weights, reuse=True)))
45 | 
46 | 
47 | # read subtasks
48 | def read_tasks(file):
49 |     """Read the subtask samples, one Excel sheet per subtask."""
50 |     f = pd.ExcelFile(file)
51 |     tasks = [[] for i in range(len(f.sheet_names))]
52 |     k = 0
53 |     for sheetname in f.sheet_names:
54 |         # attr = pd.read_excel(file, usecols=[i for i in range(FLAGS.dim_input)], sheet_name=sheetname,
55 |         #                      header=None).values.astype(np.float32)
56 |         arr = pd.read_excel(file, sheet_name=sheetname,
57 |                             header=None).values.astype(np.float32)
58 |         tasks[k] = arr
59 |         k = k + 1
60 |     return tasks
61 | 
62 | 
63 | """construct model"""
64 | tf.compat.v1.disable_eager_execution()
65 | model = MAML(FLAGS.dim_input, FLAGS.dim_output, test_num_updates=5)
66 | input_tensors_input = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_input)
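# Placeholder shapes for the meta-batch: each of the FLAGS.meta_batch_size tasks contributes
# FLAGS.num_samples_each_task samples that are split into two halves downstream (a support
# half for the inner-loop update and a query half for the meta-objective), hence the
# int(FLAGS.num_samples_each_task / 2) in the second dimension.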
input_tensors_label = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_output) 68 | model.construct_model(input_tensors_input=input_tensors_input, input_tensors_label=input_tensors_label, 69 | prefix='metatrain_') 70 | 71 | tmp = np.loadtxt('./src_data/samples_HK.csv', dtype=str, delimiter=",", encoding='UTF-8') 72 | feature_names = tmp[0, :-3].astype(np.str) 73 | task = read_tasks('./metatask_sampling/HK_tasks_K{k}.xlsx'.format(k=FLAGS.K)) 74 | 75 | sess = tf.compat.v1.InteractiveSession() 76 | init = tf.compat.v1.global_variables() # optimizer里会有额外variable需要初始化 77 | sess.run(tf.compat.v1.variables_initializer(var_list=init)) 78 | 79 | # eligible i: [11, 31, 81, ],['planting area', 'catchment', 'mountainous areas with severe deformation', ''] 80 | # SHAP for ith subtasks 81 | for i in range(11, len(task), 10): 82 | model.weights = init_weights('./models_of_blocks/HK/model' + str(i) + '.npz') 83 | 84 | tmp_ = task[i] 85 | np.random.shuffle(tmp_) # shuffle 86 | # # 训练集 87 | # x_train = tmp_[:int(tmp_.shape[0] / 2), :-1] # 加载i行数据部分 88 | # y_train = tmp_[:int(tmp_.shape[0] / 2), -1] # 加载类别标签部分 89 | # # 测试集 90 | # # x_test = tmp_[int(tmp_.shape[0] / 2):, :-1] # 加载i行数据部分 91 | # # y_test = tmp_[int(tmp_.shape[0] / 2):, -1] # 加载类别标签部分 92 | # X, Y 93 | X = tmp_[:, :-1] # 加载i行数据部分 94 | Y = tmp_[:, -1] # 加载类别标签部分 95 | 96 | shap.initjs() 97 | # SHAP demo are using dataframe instead of nparray 98 | X_ = pd.DataFrame(X) # convert np.array to pd.dataframe 99 | # x_test = pd.DataFrame(x_test) 100 | X_.columns = feature_names # 添加特征名称 101 | # x_test.columns = feature_names 102 | 103 | # explainer = shap.KernelExplainer(pred_prob, shap.kmeans(x_train, 80)) 104 | explainer = shap.KernelExplainer(pred_prob, shap.sample(X_, 100)) 105 | shap_values = explainer.shap_values(X_, nsamples=100) # shap_values 106 | # (_prob, n_samples, features) 107 | # TODO: refer https://shap-lrjball.readthedocs.io/en/latest/generated/shap.summary_plot.html to change plot style 108 | # shap.force_plot(explainer.expected_value[1], shap_values[1][0, :], x_test.iloc[0, :], show=True, matplotlib=True) # single feature 109 | shap.summary_plot(shap_values, X_, plot_type="bar", show=False) 110 | plt.savefig('tmp/bar_' + str(i) + '.pdf') 111 | plt.close() 112 | shap.summary_plot(shap_values[1], X_, plot_type="violin", show=False) 113 | plt.savefig('tmp/violin_' + str(i) + '.pdf') 114 | plt.close() 115 | # shap.summary_plot(shap_values[1], x_test, plot_type="compact_dot") 116 | 117 | # shap.force_plot(explainer.expected_value[1], shap_values[1], x_test, link="logit") 118 | 119 | # shap.dependence_plot('DV', shap_values[1], x_test, interaction_index=None) 120 | # shap.dependence_plot('SPI', shap_values[1], x_test, interaction_index='DV') 121 | # shap.plots.beeswarm(shap_values[0]) # the beeswarm plot requires Explanation object as the `shap_values` argument 122 | -------------------------------------------------------------------------------- /comparison.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from sklearn import metrics, svm 7 | 8 | from sklearn.neural_network import MLPClassifier 9 | from sklearn.ensemble import RandomForestClassifier 10 | from unsupervised_pretraining.dbn_.models import SupervisedDBNClassification 11 | 12 | from sklearn.metrics import accuracy_score 13 | from sklearn.metrics import cohen_kappa_score 14 | 15 | from utils import cal_measure 16 | import shap 17 | import 
matplotlib.pyplot as plt 18 | 19 | import warnings 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | """model-agnostic SHAP""" 24 | 25 | 26 | def SHAP_(predict_proba, x_train, x_test, f_name): 27 | shap.initjs() 28 | # SHAP demo are using dataframe instead of nparray 29 | x_train = pd.DataFrame(x_train) # 将numpy的array数组x_test转为dataframe格式。 30 | x_test = pd.DataFrame(x_test) 31 | x_train.columns = f_name # 添加特征名称 32 | x_test.columns = f_name 33 | 34 | explainer = shap.KernelExplainer(predict_proba, shap.kmeans(x_train, 100)) 35 | x_ = x_test 36 | shap_values = explainer.shap_values(x_, nsamples=100) # shap_values(_prob, n_samples, features) 37 | # shap.force_plot(explainer.expected_value[1], shap_values[1][0, :], x_test.iloc[0, :], show=True, matplotlib=True) # single feature 38 | shap.summary_plot(shap_values, x_, plot_type="bar", show=False) 39 | plt.savefig('tmp/bar_HK.pdf') 40 | plt.close() 41 | shap.summary_plot(shap_values[1], x_, plot_type="violin", show=False) # shap_values[k], k表类别,k=1(landslides) 42 | plt.savefig('tmp/violin_HK.pdf') 43 | plt.close() 44 | # shap.summary_plot(shap_values[1], x_test, plot_type="compact_dot") 45 | 46 | # shap.plots.beeswarm(shap_values[0]) # the beeswarm plot requires Explanation object as the `shap_values` argument 47 | 48 | 49 | def pred_LSM(trained_model, xy, samples, name): 50 | """LSM prediction""" 51 | pred = trained_model.predict_proba(samples) 52 | data = np.hstack((xy, pred)) 53 | data_df = pd.DataFrame(data) 54 | writer = pd.ExcelWriter('./tmp/'+name+'_prediction_HK.xlsx') 55 | data_df.to_excel(writer, 'page_1', float_format='%.5f') 56 | writer.close() 57 | 58 | 59 | def SVM_(x_train, y_train, x_test, y_test): 60 | """predict and test""" 61 | print('start SVM evaluation...') 62 | model = svm.SVC(C=1, kernel='rbf', gamma=1 / (2 * x_train.var()), decision_function_shape='ovr', probability=True) 63 | # clf = svm.SVC(C=0.1, kernel='linear', decision_function_shape='ovr') 64 | model.fit(x_train, y_train) 65 | pred_train = model.predict(x_train) 66 | print('train accuracy:' + str(metrics.accuracy_score(y_train, pred_train))) 67 | pred_test = model.predict(x_test) 68 | print('test accuracy:' + str(metrics.accuracy_score(y_test, pred_test))) 69 | # Precision, Recall, F1-score 70 | cal_measure(pred_test, y_test) 71 | kappa_value = cohen_kappa_score(pred_test, y_test) 72 | print('Cohen_Kappa: %f' % kappa_value) 73 | 74 | # feature permutation 75 | print('SHAP...') 76 | SHAP_(model.predict_proba, x_train, x_test, f_names) 77 | 78 | return model 79 | 80 | 81 | # can be deprecated 82 | def ANN_(x_train, y_train, x_test, y_test): 83 | """predict and test""" 84 | print('start ANN evaluation...') 85 | model = MLPClassifier(hidden_layer_sizes=(32, 32, 16), activation='relu', solver='adam', alpha=0.01, 86 | batch_size=32, max_iter=1000) 87 | model.fit(x_train, y_train) 88 | pred_train = model.predict(x_train) 89 | print('Train Accuracy: %f' % accuracy_score(y_train, pred_train)) 90 | pred_test = model.predict(x_test) 91 | print('Test Accuracy: %f' % accuracy_score(y_test, pred_test)) 92 | # Precision, Recall, F1-score 93 | cal_measure(pred_test, y_test) 94 | kappa_value = cohen_kappa_score(pred_test, y_test) 95 | print('Cohen_Kappa: %f' % kappa_value) 96 | 97 | # SHAP 98 | print('SHAP...') 99 | # SHAP_(model.predict_proba, x_train, x_test, f_names) 100 | 101 | return model 102 | 103 | 104 | def DBN_(x_train, y_train, x_test, y_test): 105 | print('start DBN evaluation...') 106 | # Training 107 | model = 
SupervisedDBNClassification(hidden_layers_structure=[32, 32], 108 | learning_rate_rbm=0.001, 109 | learning_rate=0.5, 110 | n_epochs_rbm=10, 111 | n_iter_backprop=200, 112 | batch_size=64, 113 | activation_function='relu', 114 | dropout_p=0.1) 115 | model.fit(x_train, y_train) 116 | 117 | pred_train = np.array(model.predict(x_train)) 118 | pred_test = np.array(model.predict(x_test)) 119 | # 训练精度 120 | print('train_Accuracy: %f' % accuracy_score(y_train, pred_train)) 121 | # 测试精度 122 | print('test_Accuracy: %f' % accuracy_score(y_test, pred_test)) 123 | # pred1 = clf2.predict_proba() # 预测类别概率 124 | cal_measure(pred_test, y_test) 125 | kappa_value = cohen_kappa_score(pred_test, y_test) 126 | print('Cohen_Kappa: %f' % kappa_value) 127 | 128 | # SHAP 129 | print('SHAP...') 130 | # SHAP_(model.predict_proba, x_train, x_test, f_names) 131 | return model 132 | 133 | 134 | def RF_(x_train, y_train, x_test, y_test): 135 | """predict and test""" 136 | print('start RF evaluation...') 137 | model = RandomForestClassifier(n_estimators=200, max_depth=None) 138 | 139 | model.fit(x_train, y_train) 140 | pred_train = model.predict(x_train) 141 | pred_test = model.predict(x_test) 142 | # 训练精度 143 | print('train_Accuracy: %f' % accuracy_score(y_train, pred_train)) 144 | # 测试精度 145 | print('test_Accuracy: %f' % accuracy_score(y_test, pred_test)) 146 | # pred1 = clf2.predict_proba() # 预测类别概率 147 | cal_measure(pred_test, y_test) 148 | kappa_value = cohen_kappa_score(pred_test, y_test) 149 | print('Cohen_Kappa: %f' % kappa_value) 150 | 151 | # SHAP 152 | print('SHAP...') 153 | # TODO: SHAP for RF 154 | # SHAP_(model.predict_proba, x_train, x_test, f_names) 155 | shap.initjs() 156 | explainer = shap.Explainer(model) 157 | shap_values = explainer(x_train) 158 | shap.plots.bar(shap_values[:100, :, 0]) # shap_values(n_samples, features, _prob) 159 | return model 160 | 161 | 162 | if __name__ == "__main__": 163 | """Input data""" 164 | tmp = np.loadtxt('./src_data/samples_HK.csv', dtype=str, delimiter=",", encoding='UTF-8') 165 | f_names = tmp[0, :-3].astype(np.str) 166 | tmp_ = np.hstack((tmp[1:, :-3], tmp[1:, -1].reshape(-1, 1))).astype(np.float32) 167 | np.random.shuffle(tmp_) # shuffle 168 | # 训练集 169 | x_train = tmp_[:int(tmp_.shape[0] / 4 * 3), :-1] # 加载i行数据部分 170 | y_train = tmp_[:int(tmp_.shape[0] / 4 * 3), -1] # 加载类别标签部分 171 | x_train = x_train / x_train.max(axis=0) 172 | # 测试集 173 | x_test = tmp_[int(tmp_.shape[0] / 4 * 3):, :-1] # 加载i行数据部分 174 | y_test = tmp_[int(tmp_.shape[0] / 4 * 3):, -1] # 加载类别标签部分 175 | x_test = x_test / x_test.max(axis=0) 176 | # grid samples 177 | grid_f = np.loadtxt('./src_data/grid_samples_HK.csv', dtype=str, delimiter=",", encoding='UTF-8') 178 | samples_f = grid_f[1:, :-2].astype(np.float32) 179 | xy = grid_f[1:, -2:].astype(np.float32) 180 | samples_f = samples_f / samples_f.max(axis=0) 181 | 182 | """evaluate and save LSM result""" 183 | # SVM-based 184 | # model_svm = SVM_(x_train, y_train, x_test, y_test) 185 | # pred_LSM(model_svm, xy, samples_f, 'SVM') 186 | # print('done SVM-based LSM prediction! \n') 187 | 188 | # # MLP_based 189 | # model_mlp = ANN_(x_train, y_train, x_test, y_test) 190 | # pred_LSM(model_mlp, xy, samples_f, 'MLP') 191 | # print('done MLP-based LSM prediction! \n') 192 | 193 | # # DBN-based 194 | # model_dbn = DBN_(x_train, y_train, x_test, y_test) 195 | # pred_LSM(model_dbn, xy, samples_f, 'DBN') 196 | # print('done DBN-based LSM prediction! 
\n') 197 | 198 | # RF-based 199 | model_rf = RF_(x_train, y_train, x_test, y_test) 200 | pred_LSM(model_rf, xy, samples_f, 'RF') 201 | print('done RF-based LSM prediction! \n') 202 | 203 | 204 | -------------------------------------------------------------------------------- /figs/Overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLi-de/Meta_LSM/c25e2904761e3f4b4d5a797f0b4db7ddfe53236e/figs/Overview.jpg -------------------------------------------------------------------------------- /figure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pyplot as plt 4 | from sklearn import manifold 5 | from sklearn.decomposition import PCA 6 | import umap 7 | 8 | from mpl_toolkits.mplot3d import Axes3D # 3D plot 9 | import pandas as pd 10 | import tensorflow as tf 11 | 12 | from sklearn.metrics import roc_curve, auc 13 | from sklearn.model_selection import train_test_split 14 | from sklearn.preprocessing import label_binarize 15 | 16 | from scipy import interp 17 | from sklearn.metrics import roc_auc_score 18 | from sklearn import svm 19 | from sklearn.neural_network import MLPClassifier 20 | from sklearn.ensemble import RandomForestClassifier 21 | # from unsupervised_pretraining.dbn_.models import SupervisedDBNClassification 22 | from scipy.interpolate import make_interp_spline 23 | from sklearn.metrics._classification import accuracy_score 24 | 25 | """for visiualization""" 26 | 27 | 28 | def read_tasks(file, dim_input=16): 29 | """read csv and obtain tasks""" 30 | f = pd.ExcelFile(file) 31 | tasks = [] 32 | for sheetname in f.sheet_names: 33 | attr = pd.read_excel(file, usecols=dim_input - 1, sheet_name=sheetname).values.astype(np.float32) 34 | label = pd.read_excel(file, usecols=[dim_input], sheet_name=sheetname).values.reshape((-1, 1)).astype( 35 | np.float32) 36 | tasks.append([attr, label]) 37 | return tasks 38 | 39 | 40 | def read_csv(path): 41 | tmp = np.loadtxt(path, dtype=np.str, delimiter=",", encoding='UTF-8') 42 | tmp_feature = tmp[1:, :] 43 | np.random.shuffle(tmp_feature) # shuffle 44 | label_attr = tmp_feature[:, -1].astype(np.float32) # 45 | data_atrr = tmp_feature[:, :-1].astype(np.float32) # 46 | return data_atrr, label_attr 47 | 48 | 49 | def load_weights(npzfile): 50 | npzfile = np.load(npzfile) 51 | weights = {} 52 | weights['w0'] = npzfile['arr_0'] 53 | weights['b0'] = npzfile['arr_1'] 54 | weights['w1'] = npzfile['arr_2'] 55 | weights['b1'] = npzfile['arr_3'] 56 | weights['w2'] = npzfile['arr_4'] 57 | weights['b2'] = npzfile['arr_5'] 58 | return weights 59 | 60 | 61 | def transform_relu(inputX, weights, bias, activations=tf.nn.relu): 62 | return activations(tf.transpose(a=tf.matmul(weights, tf.transpose(a=inputX))) + bias) 63 | 64 | 65 | def forward(inp, weights, sess): 66 | for i in range(int(len(weights) / 2)): # 3 layers 67 | inp = transform_relu(inp, tf.transpose(a=weights['w' + str(i)]), weights['b' + str(i)]) 68 | return sess.run(inp) 69 | 70 | 71 | def _PCA(X, y, figsavename): 72 | pca = PCA(n_components=3) 73 | X_pca = pca.fit_transform(X) 74 | 75 | x_min, x_max = X_pca.min(0), X_pca.max(0) 76 | X_norm = (X_pca - x_min) / (x_max - x_min) 77 | 78 | fig = plt.figure() 79 | ax = Axes3D(fig) 80 | # ax.scatter(x1,x2,x3,c=pre 81 | 82 | landslide_pts_x = [] 83 | landslide_pts_y = [] 84 | landslide_pts_z = [] 85 | nonlandslide_pts_x = [] 86 | nonlandslide_pts_y = [] 87 | nonlandslide_pts_z = [] 88 | 89 | for i in 
range(len(y)): 90 | if y[i] == 0: 91 | nonlandslide_pts_x.append(X_norm[i][0]) 92 | nonlandslide_pts_y.append(X_norm[i][1]) 93 | nonlandslide_pts_z.append(X_norm[i][2]) 94 | if y[i] == 1: 95 | landslide_pts_x.append(X_norm[i][0]) 96 | landslide_pts_y.append(X_norm[i][1]) 97 | landslide_pts_z.append(X_norm[i][2]) 98 | 99 | type_landslide = ax.scatter(landslide_pts_x, landslide_pts_y, landslide_pts_z, c='red') 100 | type_nonlandslide = ax.scatter(nonlandslide_pts_x, nonlandslide_pts_y, nonlandslide_pts_z, c='blue') 101 | 102 | ax.legend((type_landslide, type_nonlandslide), ('landslide points', 'nonlandslide points'), loc=2) 103 | # plt.legend( bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 104 | # 设置坐标标签 105 | ax.set_xlabel('x-axis') 106 | ax.set_ylabel('y-axis') 107 | ax.set_zlabel('z-axis') 108 | 109 | # 设置标题 110 | plt.title("Visualization with PCA") 111 | plt.savefig(figsavename) 112 | # 显示图形 113 | plt.show() 114 | 115 | 116 | def ISOMAP(X, y, figsavename): 117 | isomap = manifold.Isomap(n_components=3) 118 | X_isomap = isomap.fit_transform(X) 119 | 120 | x_min, x_max = X_isomap.min(0), X_isomap.max(0) 121 | X_norm = (X_isomap - x_min) / (x_max - x_min) 122 | 123 | fig = plt.figure() 124 | ax = Axes3D(fig) 125 | # ax.scatter(x1,x2,x3,c=pre 126 | 127 | landslide_pts_x = [] 128 | landslide_pts_y = [] 129 | landslide_pts_z = [] 130 | nonlandslide_pts_x = [] 131 | nonlandslide_pts_y = [] 132 | nonlandslide_pts_z = [] 133 | 134 | for i in range(len(y)): 135 | if y[i] == 0: 136 | nonlandslide_pts_x.append(X_norm[i][0]) 137 | nonlandslide_pts_y.append(X_norm[i][1]) 138 | nonlandslide_pts_z.append(X_norm[i][2]) 139 | if y[i] == 1: 140 | landslide_pts_x.append(X_norm[i][0]) 141 | landslide_pts_y.append(X_norm[i][1]) 142 | landslide_pts_z.append(X_norm[i][2]) 143 | 144 | type_landslide = ax.scatter(landslide_pts_x, landslide_pts_y, landslide_pts_z, c='red') 145 | type_nonlandslide = ax.scatter(nonlandslide_pts_x, nonlandslide_pts_y, nonlandslide_pts_z, c='blue') 146 | 147 | ax.legend((type_landslide, type_nonlandslide), ('landslide points', 'nonlandslide points'), loc=2) 148 | # plt.legend( bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
149 | # 设置坐标标签 150 | ax.set_xlabel('x-axis') 151 | ax.set_ylabel('y-axis') 152 | ax.set_zlabel('z-axis') 153 | # 设置标题 154 | plt.title("Visualization with Isomap") 155 | plt.savefig(figsavename) 156 | # 显示图形 157 | plt.show() 158 | 159 | 160 | def t_SNE(X, y, figsavename): 161 | tsne = manifold.TSNE(n_components=3, init='random', random_state=501) 162 | X_tsne = tsne.fit_transform(X) 163 | 164 | """嵌入空间可视化""" 165 | x_min, x_max = X_tsne.min(0), X_tsne.max(0) 166 | X_norm = (X_tsne - x_min) / (x_max - x_min) 167 | 168 | fig = plt.figure() 169 | ax = Axes3D(fig) 170 | # ax.scatter(x1,x2,x3,c=pre 171 | 172 | landslide_pts_x = [] 173 | landslide_pts_y = [] 174 | landslide_pts_z = [] 175 | nonlandslide_pts_x = [] 176 | nonlandslide_pts_y = [] 177 | nonlandslide_pts_z = [] 178 | 179 | for i in range(len(y)): 180 | if y[i] == 0: 181 | nonlandslide_pts_x.append(X_norm[i][0]) 182 | nonlandslide_pts_y.append(X_norm[i][1]) 183 | nonlandslide_pts_z.append(X_norm[i][2]) 184 | if y[i] == 1: 185 | landslide_pts_x.append(X_norm[i][0]) 186 | landslide_pts_y.append(X_norm[i][1]) 187 | landslide_pts_z.append(X_norm[i][2]) 188 | 189 | type_landslide = ax.scatter(landslide_pts_x, landslide_pts_y, landslide_pts_z, c='red') 190 | type_nonlandslide = ax.scatter(nonlandslide_pts_x, nonlandslide_pts_y, nonlandslide_pts_z, c='blue') 191 | 192 | ax.legend((type_landslide, type_nonlandslide), ('landslide points', 'nonlandslide points'), loc=2) 193 | # plt.legend( bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 194 | # 设置坐标标签 195 | ax.set_xlabel('x-axis') 196 | ax.set_ylabel('y-axis') 197 | ax.set_zlabel('z-axis') 198 | # 设置标题 199 | plt.title("Visualization with t-SNE") 200 | plt.savefig(figsavename) 201 | # 显示图形 202 | plt.show() 203 | 204 | 205 | def UMAP(X, y, figsavename): 206 | reducer = umap.UMAP(n_components=3) 207 | X_umap = reducer.fit_transform(X) 208 | 209 | x_min, x_max = X_umap.min(0), X_umap.max(0) 210 | X_norm = (X_umap - x_min) / (x_max - x_min) 211 | 212 | fig = plt.figure() 213 | ax = Axes3D(fig) 214 | # ax.scatter(x1,x2,x3,c=pre 215 | 216 | landslide_pts_x = [] 217 | landslide_pts_y = [] 218 | landslide_pts_z = [] 219 | nonlandslide_pts_x = [] 220 | nonlandslide_pts_y = [] 221 | nonlandslide_pts_z = [] 222 | 223 | for i in range(len(y)): 224 | if y[i] == 0: 225 | nonlandslide_pts_x.append(X_norm[i][0]) 226 | nonlandslide_pts_y.append(X_norm[i][1]) 227 | nonlandslide_pts_z.append(X_norm[i][2]) 228 | if y[i] == 1: 229 | landslide_pts_x.append(X_norm[i][0]) 230 | landslide_pts_y.append(X_norm[i][1]) 231 | landslide_pts_z.append(X_norm[i][2]) 232 | 233 | type_landslide = ax.scatter(landslide_pts_x, landslide_pts_y, landslide_pts_z, c='red') 234 | type_nonlandslide = ax.scatter(nonlandslide_pts_x, nonlandslide_pts_y, nonlandslide_pts_z, c='blue') 235 | 236 | ax.legend((type_landslide, type_nonlandslide), ('landslide points', 'nonlandslide points'), loc=2) 237 | # plt.legend( bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
238 | # 设置坐标标签 239 | ax.set_xlabel('x-axis') 240 | ax.set_ylabel('y-axis') 241 | ax.set_zlabel('z-axis') 242 | # 设置标题 243 | plt.title("Visualization with UMAP") 244 | plt.savefig(figsavename) 245 | # 显示图形 246 | plt.show() 247 | 248 | 249 | def visualization(): 250 | fj_tasks = read_tasks('./metatask_sampling/FJ_tasks.xlsx') # task里的samplles 251 | # fl_tasks = read_tasks('./metatask_sampling/FL_tasks.xlsx') # num_samples of FL is too scarce to visualize 252 | 253 | """select part of FJ and FL data for visualization""" 254 | # def get_oriandinf_Xs(tasks, regionname): 255 | # ori_Xs, inf_Xs, Ys = [], [], [] 256 | # for i in range(len(tasks)): 257 | # if tasks[i][0].shape[0] > 30: 258 | # """download model parameters""" 259 | # w = load_weights('models_of_blocks/'+ regionname +'/model'+str(i) + '.npz') 260 | # ori_Xs.append(tasks[i][0]) 261 | # inf_Xs.append(forward(tasks[i][0], w)) 262 | # Ys.append(tasks[i][1]) 263 | # return ori_Xs, inf_Xs, Ys 264 | """overall FJ and FL data for visualization""" 265 | 266 | def get_oriandinf_Xs(tasks, regionname): 267 | with tf.compat.v1.Session() as sess: # for tf calculation 268 | w = load_weights('models_of_blocks/' + 'overall_' + regionname + '/model_MAML' + '.npz') 269 | ori_Xs = tasks[0][0] 270 | inf_Xs = forward(tasks[0][0], w, sess) 271 | Ys = tasks[0][1] 272 | for i in range(len(tasks) - 1): 273 | if len(tasks[i + 1][0]) > 0: 274 | ori_Xs = np.vstack((ori_Xs, tasks[i + 1][0])) 275 | inf_Xs = np.vstack((inf_Xs, forward(tasks[i + 1][0], w, sess))) 276 | Ys = np.vstack((Ys, tasks[i + 1][1])) 277 | return ori_Xs, inf_Xs, Ys 278 | 279 | ori_FJ_Xs, inf_FJ_Xs, FJ_Ys = get_oriandinf_Xs(fj_tasks, 'FJ') 280 | 281 | # ori_FL_Xs, inf_FL_Xs, FL_Ys = get_oriandinf_Xs(fl_tasks, 'FL') 282 | 283 | # ori_X, y = read_csv('src_data/FJ_FL.csv') 284 | # 285 | # tmp = np.loadtxt('src_data/FJ_FL.csv', dtype=np.str, delimiter=",",encoding='UTF-8') 286 | # w = load_weights('unsupervised_pretraining/model_init/savedmodel.npz') 287 | # unsupervised_X = forward(ori_X, w) 288 | def plot_points(ori_X, inf_X, Y, regionname): 289 | _PCA(ori_X, Y, './figs/' + regionname + '_ori_PCA.pdf') 290 | _PCA(inf_X, Y, './figs/' + regionname + '_inf_PCA.pdf') 291 | 292 | t_SNE(ori_X, Y, './figs/' + regionname + '_ori_t_SNE.pdf') 293 | t_SNE(inf_X, Y, './figs/' + regionname + '_inf_SNE.pdf') 294 | 295 | ISOMAP(ori_X, Y, './figs/' + regionname + '_ori_Isomap.pdf') 296 | ISOMAP(inf_X, Y, './figs/' + regionname + '_inf_Isomap.pdf') 297 | 298 | UMAP(ori_X, Y, './figs/' + regionname + '_ori_UMAP.pdf') 299 | UMAP(inf_X, Y, './figs/' + regionname + '_inf_UMAP.pdf') 300 | 301 | plot_points(ori_FJ_Xs, inf_FJ_Xs, FJ_Ys, 'FJ') 302 | # plot_points(ori_FL_Xs, inf_FL_Xs, FL_Ys, 'FL') 303 | 304 | 305 | """for figure plotting""" 306 | 307 | 308 | def read_statistic(file): 309 | """读取csv获取statistic""" 310 | f = pd.ExcelFile(file) 311 | K, meanOA, maxOA, minOA, std = [], [], [], [], [] 312 | for sheetname in f.sheet_names: 313 | tmp_K, tmp_meanOA, tmp_maxOA, tmp_minOA, tmp_std = np.transpose( 314 | pd.read_excel(file, sheet_name=sheetname).values) 315 | K.append(tmp_K) 316 | meanOA.append(tmp_meanOA) 317 | maxOA.append(tmp_maxOA) 318 | minOA.append(tmp_minOA) 319 | std.append(tmp_std) 320 | return K, meanOA, maxOA, minOA, std 321 | 322 | 323 | def read_statistic1(file): 324 | """读取csv获取statistic""" 325 | f = pd.ExcelFile(file) 326 | K, meanOA = [], [] 327 | for sheetname in f.sheet_names: 328 | tmp_K, tmp_meanOA = np.transpose(pd.read_excel(file, sheet_name=sheetname).values) 329 | K.append(tmp_K) 330 | 
meanOA.append(tmp_meanOA) 331 | return K, meanOA 332 | 333 | 334 | def read_statistic2(file): 335 | """读取csv获取statistic""" 336 | f = pd.ExcelFile(file) 337 | measures = [] 338 | for sheetname in f.sheet_names: 339 | temp = pd.read_excel(file, sheet_name=sheetname).values 340 | measures.append(temp[:, 1:].tolist()) 341 | return measures 342 | 343 | 344 | def plot_candle(scenes, K, meanOA, maxOA, minOA, std): 345 | # 设置框图 346 | plt.figure("", facecolor="lightgray") 347 | # plt.style.use('ggplot') 348 | # 设置图例并且设置图例的字体及大小 349 | font1 = {'family': 'Times New Roman', 350 | 'weight': 'normal', 351 | 'size': 20, 352 | } 353 | font2 = {'family': 'Times New Roman', 354 | 'weight': 'normal', 355 | 'size': 18, 356 | } 357 | 358 | # legend = plt.legend(handles=[A,B],prop=font1) 359 | # plt.title(scenes, fontdict=font2) 360 | # plt.xlabel("Various methods", fontdict=font1) 361 | plt.ylabel("OA(%)", fontdict=font2) 362 | 363 | my_x_ticks = [1, 2, 3, 4, 5] 364 | # my_x_ticklabels = ['SVM', 'MLP', 'DBN', 'RF', 'Proposed'] 365 | plt.xticks(ticks=my_x_ticks, labels='', fontsize=16) 366 | 367 | plt.ylim((60, 100)) 368 | my_y_ticks = np.arange(60, 100, 5) 369 | plt.yticks(ticks=my_y_ticks, fontsize=16) 370 | 371 | colors = ['dodgerblue', 'lawngreen', 'gold', 'magenta', 'red'] 372 | edge_colors = np.zeros(5, dtype="U1") 373 | edge_colors[:] = 'black' 374 | 375 | '''格网设置''' 376 | plt.grid(linestyle="--", zorder=-1) 377 | 378 | # draw line 379 | # plt.plot(K[0:-1], meanOA[0:-1], color="b", linestyle='solid', 380 | # linewidth=1, label="open", zorder=1) 381 | # plt.plot(K[-2:], meanOA[-2:], color="b", linestyle="--", 382 | # linewidth=1, label="open", zorder=1) 383 | 384 | # draw bar 385 | barwidth = 0.4 386 | plt.bar(K, 2 * std, barwidth, bottom=meanOA - std, color=colors, 387 | edgecolor=edge_colors, linewidth=1, zorder=20, label=['SVM', 'MLP', 'DBN', 'RF', 'Proposed']) 388 | 389 | # draw vertical line 390 | plt.vlines(K, minOA, maxOA, color='black', linestyle='solid', zorder=10) 391 | plt.hlines(meanOA, K - barwidth / 2, K + barwidth / 2, color='black', linestyle='solid', zorder=30) 392 | plt.hlines(minOA, K - barwidth / 4, K + barwidth / 4, color='black', linestyle='solid', zorder=10) 393 | plt.hlines(maxOA, K - barwidth / 4, K + barwidth / 4, color='black', linestyle='solid', zorder=10) 394 | 395 | # 设置图例 396 | legend = plt.legend(loc="lower center", prop=font1, ncol=3, columnspacing=0.1) 397 | 398 | 399 | def plot_scatter(arr): 400 | '''设置框图''' 401 | # plt.figure("", facecolor="lightgray") # 设置框图大小 402 | font1 = {'family': 'Times New Roman', 403 | 'weight': 'normal', 404 | 'size': 16, 405 | } 406 | font2 = {'family': 'Times New Roman', 407 | 'weight': 'normal', 408 | 'size': 12, 409 | } 410 | plt.xlabel("Subtasks", fontdict=font1) 411 | plt.ylabel("Mean accuracy(%)", fontdict=font1) 412 | 413 | '''设置刻度''' 414 | plt.ylim((50, 100)) 415 | my_y_ticks = np.arange(50, 100, 5) 416 | plt.yticks(my_y_ticks) 417 | my_x_ticks = [i for i in range(1, 204, 40)] 418 | my_x_ticklabel = [str(i) + 'th' for i in range(1, 204, 40)] 419 | plt.xticks(ticks=my_x_ticks, labels=my_x_ticklabel) 420 | '''格网设置''' 421 | plt.grid(linestyle="--") 422 | 423 | x_ = [i for i in range(arr.shape[0])] 424 | '''draw scatter''' 425 | L1 = plt.scatter(x_, arr[:, 0], label="L=1", c="none", s=20, edgecolors='magenta') 426 | L2 = plt.scatter(x_, arr[:, 1], label="L=2", c="none", s=20, edgecolors='cyan') 427 | L3 = plt.scatter(x_, arr[:, 2], label="L=3", c="none", s=20, edgecolors='b') 428 | L4 = plt.scatter(x_, arr[:, 3], label="L=4", c="none", 
s=20, edgecolors='g') 429 | L5 = plt.scatter(x_, arr[:, 4], label="L=5", c="none", s=20, edgecolors='r') 430 | 431 | '''设置图例''' 432 | legend = plt.legend(loc="lower left", prop=font2, ncol=3) 433 | # plt.savefig("C:\\Users\\hj\\Desktop\\brokenline_A") 434 | # plt.show() 435 | 436 | 437 | def plot_lines(arr): 438 | '''设置框图''' 439 | # plt.figure("", facecolor="lightgray") # 设置框图大小 440 | font1 = {'family': 'Times New Roman', 441 | 'weight': 'normal', 442 | 'size': 16, 443 | } 444 | font2 = {'family': 'Times New Roman', 445 | 'weight': 'normal', 446 | 'size': 12, 447 | } 448 | plt.xlabel("Subtasks", fontdict=font1) 449 | plt.ylabel("Mean accuracy(%)", fontdict=font1) 450 | 451 | '''设置刻度''' 452 | plt.ylim((50, 100)) 453 | my_y_ticks = np.arange(50, 100, 5) 454 | plt.yticks(my_y_ticks) 455 | my_x_ticks = [i for i in range(6)] 456 | my_x_ticklabel = [str(i + 1) + '/12 M' for i in range(6)] 457 | plt.xticks(ticks=my_x_ticks, labels=my_x_ticklabel) 458 | '''格网设置''' 459 | plt.grid(linestyle="--") 460 | 461 | x_ = np.array([i for i in range(6)]) 462 | # smooth 463 | # x_ = np.linspace(x_.min(), x_.max(), 400) 464 | # arr = make_interp_spline(x_, arr)(x_) 465 | '''draw line''' 466 | L1 = plt.plot(x_, arr[:, 0], color="r", linestyle="solid", 467 | linewidth=1, label="L=1", markerfacecolor='white', ms=10) 468 | L2 = plt.plot(x_, arr[:, 1], color="orange", linestyle="solid", 469 | linewidth=1, label="L=2", markerfacecolor='white', ms=10) 470 | L3 = plt.plot(x_, arr[:, 2], color="gold", linestyle="solid", 471 | linewidth=1, label="L=3", markerfacecolor='white', ms=10) 472 | L4 = plt.plot(x_, arr[:, 3], color="g", linestyle="solid", 473 | linewidth=1, label="L=4", markerfacecolor='white', ms=10) 474 | L5 = plt.plot(x_, arr[:, 4], color="b", linestyle="solid", 475 | linewidth=1, label="L=5", markerfacecolor='white', ms=10) 476 | 477 | 478 | def plot_histogram(region, measures): 479 | '''设置框图''' 480 | plt.figure("", facecolor="lightgray") # 设置框图大小 481 | font1 = {'family': 'Times New Roman', 482 | 'weight': 'normal', 483 | 'size': 14, 484 | } 485 | font2 = {'family': 'Times New Roman', 486 | 'weight': 'normal', 487 | 'size': 18, 488 | } 489 | # plt.xlabel("Statistical measures", fontdict=font1) 490 | plt.ylabel("Performance(%)", fontdict=font1) 491 | plt.title(region, fontdict=font2) 492 | 493 | '''设置刻度''' 494 | plt.ylim((60, 90)) 495 | my_y_ticks = np.arange(60, 90, 3) 496 | plt.yticks(my_y_ticks) 497 | 498 | my_x_ticklabels = ['Accuracy', 'Precision', 'Recall', 'F1-score'] 499 | bar_width = 0.3 500 | interval = 0.2 501 | my_x_ticks = np.arange(bar_width / 2 + 2.5 * bar_width, 4 * 5 * bar_width + 1, bar_width * 6) 502 | plt.xticks(ticks=my_x_ticks, labels=my_x_ticklabels, fontproperties='Times New Roman', size=14) 503 | 504 | '''格网设置''' 505 | plt.grid(linestyle="--") 506 | 507 | '''draw bar''' 508 | rects1 = plt.bar([x - 2 * bar_width for x in my_x_ticks], height=measures[0], width=bar_width, alpha=0.8, 509 | color='dodgerblue', label="MLP") 510 | rects2 = plt.bar([x - 1 * bar_width for x in my_x_ticks], height=measures[1], width=bar_width, alpha=0.8, 511 | color='yellowgreen', label="RF") 512 | rects3 = plt.bar([x for x in my_x_ticks], height=measures[2], width=bar_width, alpha=0.8, color='gold', label="RL") 513 | rects4 = plt.bar([x + 1 * bar_width for x in my_x_ticks], height=measures[3], width=bar_width, alpha=0.8, 514 | color='peru', label="MAML") 515 | rects5 = plt.bar([x + 2 * bar_width for x in my_x_ticks], height=measures[4], width=bar_width, alpha=0.8, 516 | color='crimson', label="proposed") 517 | 
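# Each statistical measure gets a group of five bars (MLP, RF, RL, MAML, proposed), offset by
# whole bar widths around the tick position computed above; the legend below maps colours to methods.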
518 | '''设置图例''' 519 | legend = plt.legend(loc="upper left", prop=font1, ncol=3) 520 | 521 | '''add text''' 522 | # for rect in rects1: 523 | # height = rect.get_height() 524 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom") 525 | # for rect in rects2: 526 | # height = rect.get_height() 527 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom") 528 | # for rect in rects3: 529 | # height = rect.get_height() 530 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom") 531 | # for rect in rects4: 532 | # height = rect.get_height() 533 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom") 534 | # for rect in rects5: 535 | # height = rect.get_height() 536 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom") 537 | 538 | plt.savefig("C:\\Users\\hj\\Desktop\\histogram" + region + '.pdf') 539 | plt.show() 540 | 541 | 542 | """for AUROC plotting""" 543 | 544 | 545 | def load_data(filepath, dim_input): 546 | np.loadtxt(filepath, ) 547 | data = pd.read_excel(filepath).values.astype(np.float32) 548 | attr = data[:, :dim_input] 549 | attr = attr / attr.max(axis=0) 550 | label = data[:, -1].astype(np.int32) 551 | return attr, label 552 | 553 | 554 | def SVM_fit_pred(x_train, x_test, y_train, y_test): 555 | classifier = svm.SVC(C=1, kernel='rbf', gamma=1 / (2 * x_train.var()), decision_function_shape='ovr', 556 | probability=True) 557 | classifier.fit(x_train, y_train) 558 | return classifier.predict_proba(x_test) 559 | 560 | 561 | def MLP_fit_pred(x_train, x_test, y_train, y_test): 562 | classifier = MLPClassifier(hidden_layer_sizes=(32, 32, 16), activation='relu', solver='adam', alpha=0.01, 563 | batch_size=32, max_iter=1000) 564 | classifier.fit(x_train, y_train) 565 | return classifier.predict_proba(x_test) 566 | 567 | 568 | # def DBN_fit_pred(x_train, x_test, y_train, y_test): 569 | # classifier = SupervisedDBNClassification(hidden_layers_structure=[32, 32], 570 | # learning_rate_rbm=0.001, 571 | # learning_rate=0.5, 572 | # n_epochs_rbm=10, 573 | # n_iter_backprop=200, 574 | # batch_size=64, 575 | # activation_function='relu', 576 | # dropout_p=0.1) 577 | # classifier.fit(x_train, y_train) 578 | # pred_prob = classifier.predict_proba(x_test) 579 | # 580 | # # if pred_prob[0][0] > 0.5: 581 | # # pred_prob = np.vstack((pred_prob[:, 0], pred_prob[:, -1])).T # swap 0, 1 prediction 582 | # 583 | # return pred_prob 584 | 585 | 586 | def RF_fit_pred(x_train, x_test, y_train, y_test): 587 | classifier = RandomForestClassifier(n_estimators=200, max_depth=None) 588 | classifier.fit(x_train, y_train) 589 | return classifier.predict_proba(x_test) 590 | 591 | 592 | def plot_auroc(n_times, y_score_SVM, y_score_MLP, y_score_DBN, y_score_RF, y_score_proposed, y_test, y_test_proposed): 593 | # Compute ROC curve and ROC area for each class 594 | def cal_(y_score, y_test): 595 | fpr, tpr = [], [] 596 | for i in range(n_times): 597 | fpr_, tpr_, thresholds = roc_curve(y_test[i], y_score[i][:, -1], pos_label=1) 598 | fpr.append(fpr_) 599 | tpr.append(tpr_) 600 | 601 | # First aggregate all false positive rates 602 | all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_times)])) 603 | 604 | # Then interpolate all ROC curves at this points 605 | mean_tpr = np.zeros_like(all_fpr) 606 | for i in range(n_times): 607 | mean_tpr += np.interp(all_fpr, fpr[i], tpr[i]) 608 | 609 | 
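# Vertical averaging of the ROC curves: each run's TPR has been interpolated onto the shared
# FPR grid (all_fpr) and accumulated above; dividing by n_times below yields the mean curve,
# whose area is reported as the mean AUC.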
# Finally average it and compute AUC 610 | mean_tpr /= n_times 611 | mean_auc = auc(all_fpr, mean_tpr) 612 | return all_fpr, mean_tpr, mean_auc, fpr, tpr 613 | 614 | def plot_(y_score, y_test, color, method): 615 | all_fpr, mean_tpr, mean_auc, fpr, tpr = cal_(y_score, y_test) 616 | # draw mean 617 | plt.plot(all_fpr, mean_tpr, 618 | label=method + '_mean_AUC (area = {0:0.3f})'''.format(mean_auc), 619 | color=color, linewidth=1.5) 620 | # draw each 621 | for i in range(n_times): 622 | plt.plot(fpr[i], tpr[i], 623 | color=color, linewidth=1, alpha=.25) 624 | # plt.savefig(method + '.pdf') 625 | 626 | # Plot all ROC curves 627 | # ax = plt.axes() 628 | # ax.set_facecolor("WhiteSmoke") # background color 629 | plot_(y_score_SVM, y_test, color='dodgerblue', method='SVM') 630 | plot_(y_score_MLP, y_test, color='lawngreen', method='MLP') 631 | plot_(y_score_DBN, y_test, color='gold', method='DBN') 632 | plot_(y_score_RF, y_test, color='magenta', method='RF') 633 | plot_(y_score_proposed, y_test_proposed, color='red', method='Proposed') 634 | 635 | # format 636 | font1 = {'family': 'Times New Roman', 637 | 'weight': 'normal', 638 | 'size': 16, 639 | } 640 | font2 = {'family': 'Times New Roman', 641 | 'weight': 'normal', 642 | 'size': 12, 643 | } 644 | plt.plot([0, 1], [0, 1], 'k--', lw=1, label='random') 645 | plt.xlim([0.0, 1.0]) 646 | plt.ylim([0.0, 1.0]) 647 | plt.xlabel('False Positive Rate', fontdict=font1) 648 | plt.ylabel('True Positive Rate', fontdict=font1) 649 | plt.title('ROC curve by various methods', fontdict=font1) 650 | plt.legend(loc="lower right", prop=font2) 651 | 652 | 653 | """space visualization""" 654 | # visualization() 655 | 656 | """draw histogram""" 657 | 658 | # regions = ['FJ', 'FL'] 659 | # measures = read_statistic2("C:\\Users\\hj\\Desktop\\performance.xlsx") 660 | # for i in range(len(regions)): 661 | # plot_histogram(regions[i], measures[i]) 662 | 663 | 664 | """draw candle""" 665 | 666 | 667 | scenes = ['airport', 'urban1', 'urban2', 'plain', 'catchment', 'reservior'] 668 | K, meanOA, maxOA, minOA, std = read_statistic("C:\\Users\\lichen\\OneDrive\\桌面\\statistics_candle.xlsx") 669 | for i in range(len(scenes)): 670 | plot_candle(scenes[i], K[i], meanOA[i], maxOA[i], minOA[i], std[i]) 671 | plt.savefig("C:\\Users\\lichen\\OneDrive\\桌面\\" + scenes[i] + '_' + 'candle.pdf') 672 | plt.show() 673 | 674 | 675 | def read_f_l_csv(file): 676 | tmp = np.loadtxt(file, dtype=str, delimiter=",", encoding='UTF-8') 677 | features = tmp[1:, :-2].astype(np.float32) 678 | features = features / features.max(axis=0) 679 | label = tmp[1:, -1].astype(np.float32) 680 | return features, label 681 | 682 | 683 | # """draw AUR""" 684 | # print('drawing ROC...') 685 | # x, y = read_f_l_csv('src_data/samples_HK.csv') 686 | # y_score_SVM, y_score_MLP, y_score_DBN, y_score_RF, y_score_proposed, y_test_, y_test_proposed = [], [], [], [], [], [], [] 687 | # n_times = 5 688 | # for i in range(n_times): 689 | # x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.75, test_size=.02, shuffle=True) 690 | # """fit and predict""" 691 | # # for other methods 692 | # y_score_SVM.append(SVM_fit_pred(x_train, x_test, y_train, y_test)) 693 | # y_score_MLP.append(MLP_fit_pred(x_train, x_test, y_train, y_test)) 694 | # y_score_DBN.append(MLP_fit_pred(x_train, x_test, y_train, y_test)) 695 | # y_score_RF.append(RF_fit_pred(x_train, x_test, y_train, y_test)) 696 | # y_test_.append(y_test) 697 | # # for proposed- 698 | # tmp = pd.read_excel('tmp/' + 'proposed_test' + str(i) + 
'.xlsx').values.astype(np.float32) 699 | # y_score_proposed.append(tmp[:, 1:3]) 700 | # y_test_proposed.append(tmp[:, -1]) 701 | # # draw roc 702 | # plt.clf() 703 | # plot_auroc(n_times, y_score_SVM, y_score_MLP, y_score_DBN, y_score_RF, y_score_proposed, y_test_, y_test_proposed) 704 | # plt.savefig('ROC.pdf') 705 | # plt.show() 706 | # print('finish') 707 | 708 | """draw scatters for fast adaption performance""" 709 | # filename = "C:\\Users\\lichen\\OneDrive\\桌面\\fast_adaption_sheet2.csv" 710 | # arr = np.loadtxt(filename, dtype=float, delimiter=",", encoding='utf-8-sig') 711 | # plot_scatter(arr) 712 | # plt.savefig("C:\\Users\\lichen\\OneDrive\\桌面\\scatters.pdf") 713 | # plt.show() 714 | 715 | """draw lines for fast adaption performance""" 716 | # filename = "C:\\Users\\lichen\\OneDrive\\桌面\\fast_adaption1.csv" 717 | # arr = np.loadtxt(filename, dtype=float, delimiter=",", encoding='utf-8-sig') 718 | # plot_lines(arr) 719 | # plt.savefig("C:\\Users\\lichen\\OneDrive\\桌面\\broken.pdf") 720 | # plt.show() 721 | 722 | """ 723 | label: for legend 724 | pos_: -2, -1, 0, 1, 2 725 | """ 726 | 727 | 728 | def plot_candle1(K, meanOA, maxOA, minOA, std, color_, label_, pos_): 729 | # 设置框图 730 | # plt.figure("", facecolor="lightgray") 731 | # plt.style.use('ggplot') 732 | # 设置图例并且设置图例的字体及大小 733 | font1 = {'family': 'Times New Roman', 734 | 'weight': 'normal', 735 | 'size': 14, 736 | } 737 | font2 = {'family': 'Times New Roman', 738 | 'weight': 'normal', 739 | 'size': 16, 740 | } 741 | 742 | # legend = plt.legend(handles=[A,B],prop=font1) 743 | # plt.title(scenes, fontdict=font2) 744 | plt.xlabel("Number of samples", fontdict=font1) 745 | plt.ylabel("OA(%)", fontdict=font2) 746 | 747 | my_x_ticks = [1, 2, 3, 4, 5] 748 | my_x_ticklabels = ['1', '2', '3', '4', '5'] 749 | plt.xticks(ticks=my_x_ticks, labels=my_x_ticklabels, fontsize=14, fontdict=font2) 750 | 751 | plt.ylim((50, 100)) 752 | my_y_ticks = np.arange(50, 100, 5) 753 | plt.yticks(ticks=my_y_ticks, fontsize=14, font=font2) 754 | 755 | '''格网设置''' 756 | plt.grid(linestyle="--", zorder=-1) 757 | 758 | colors = ['dodgerblue', 'lawngreen', 'gold', 'magenta', 'red'] 759 | edge_colors = np.zeros(5, dtype="U1") 760 | edge_colors[:] = 'black' 761 | 762 | # draw bar 763 | barwidth = 0.15 764 | K = K + barwidth * pos_ 765 | plt.bar(K, 2 * std, barwidth, bottom=meanOA - std, color=color_, 766 | edgecolor=edge_colors, linewidth=1, zorder=20, label=label_, alpha=0.5) 767 | # draw vertical line 768 | plt.vlines(K, minOA, meanOA - std, color='black', linestyle='solid', zorder=10) 769 | plt.vlines(K, maxOA, meanOA + std, color='black', linestyle='solid', zorder=10) 770 | plt.hlines(meanOA, K - barwidth / 2, K + barwidth / 2, color='blue', linestyle='solid', zorder=30) 771 | plt.hlines(minOA, K - barwidth / 4, K + barwidth / 4, color='black', linestyle='solid', zorder=10) 772 | plt.hlines(maxOA, K - barwidth / 4, K + barwidth / 4, color='black', linestyle='solid', zorder=10) 773 | # 设置图例 774 | legend = plt.legend(loc="lower right", prop=font1, ncol=3, fontsize=24) 775 | 776 | 777 | """draw candles for fast adaption performance""" 778 | # K, meanOA, maxOA, minOA, std = read_statistic("C:\\Users\\lichen\\OneDrive\\桌面\\fast_adaption_candle.xlsx") 779 | # colors = ['magenta', 'cyan', 'b', 'g', 'r'] 780 | # labels = ['L=1', 'L=2', 'L=3', 'L=4', 'L=5'] 781 | # pos = [-2, -1, 0, 1, 2] 782 | # for i in range(5): 783 | # plot_candle1(K[i], meanOA[i], maxOA[i], minOA[i], std[i], colors[i], labels[i], pos[i]) 784 | # # plt.show() 785 | # 
plt.savefig("C:\\Users\\lichen\\OneDrive\\桌面\\candle.pdf") 786 | -------------------------------------------------------------------------------- /meta_LSM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import pandas as pd 4 | from modeling import MAML 5 | from scene_segmentation import SLICProcessor, TaskSampling 6 | from tensorflow.python.platform import flags 7 | from utils import tasksbatch_generator, batch_generator, meta_train_test1, save_tasks, \ 8 | read_tasks, savepts_fortask, cal_measure 9 | from unsupervised_pretraining.DAS_pretraining_v2 import Unsupervise_pretrain 10 | from sklearn.metrics import accuracy_score 11 | from sklearn.metrics import cohen_kappa_score 12 | 13 | from sklearn.neural_network import MLPClassifier 14 | 15 | from comparison import SHAP_ 16 | import warnings 17 | import os 18 | 19 | warnings.filterwarnings("ignore") 20 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | """for task sampling""" 25 | flags.DEFINE_float('M', 500, 'determine how distance influence the segmentation') 26 | flags.DEFINE_integer('K', 512, 'number of superpixels') 27 | flags.DEFINE_integer('loop', 5, 'number of SLIC iterations') 28 | flags.DEFINE_string('str_region', 'HK', 'the study area') 29 | flags.DEFINE_string('sample_pts', './src_data/samples_HK.csv', 'path to (non)/landslide samples') 30 | flags.DEFINE_string('Ts_pts', './src_data/Ts_HK.csv', 'path to Ts samples') 31 | 32 | """for meta-train""" 33 | flags.DEFINE_string('basemodel', 'DAS', 'MLP: no unsupervised pretraining; DAS: pretraining with DAS') 34 | flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None') 35 | flags.DEFINE_string('log', './tmp/data', 'batch_norm, layer_norm, or None') 36 | flags.DEFINE_string('logdir', './checkpoint_dir', 'directory for summaries and checkpoints.') 37 | 38 | flags.DEFINE_integer('dim_input', 14, 'dim of input data') 39 | flags.DEFINE_integer('dim_output', 2, 'dim of output data') 40 | flags.DEFINE_integer('meta_batch_size', 16, 'number of tasks sampled per meta-update, not nums tasks') 41 | flags.DEFINE_integer('num_samples_each_task', 16, 42 | 'number of samples sampling from each task when training, inner_batch_size') 43 | flags.DEFINE_integer('test_update_batch_size', 8, 44 | 'number of examples used for gradient update during adapting (K=1,3,5 in experiment, K-shot); -1: M.') 45 | flags.DEFINE_integer('metatrain_iterations', 5001, 'number of meta-training iterations.') 46 | flags.DEFINE_integer('num_updates', 5, 'number of inner gradient updates during training.') 47 | flags.DEFINE_integer('pretrain_iterations', 0, 'number of pre-training iterations.') 48 | # flags.DEFINE_integer('num_samples', 18469, 'total number of samples in HK, see samples_HK.') 49 | flags.DEFINE_float('update_lr', 1e-2, 'learning rate of single task objective (inner)') # le-2 is the best 50 | flags.DEFINE_float('meta_lr', 1e-3, 'the base learning rate of meta objective (outer)') # le-2 or le-3 51 | flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)') 52 | flags.DEFINE_bool('resume', True, 'resume training if there is a model available') 53 | 54 | 55 | def train(model, saver, sess, exp_string, tasks, resume_itr): 56 | SUMMARY_INTERVAL = 100 57 | SAVE_INTERVAL = 1000 58 | PRINT_INTERVAL = 1000 59 | 60 | print('Done model initializing, starting training...') 61 | prelosses, postlosses = [], [] 62 | if resume_itr 
!= FLAGS.pretrain_iterations + FLAGS.metatrain_iterations - 1: 63 | if FLAGS.log: 64 | train_writer = tf.compat.v1.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph) 65 | for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations): 66 | batch_x, batch_y, cnt_sample = tasksbatch_generator(tasks, FLAGS.meta_batch_size 67 | , FLAGS.num_samples_each_task, 68 | FLAGS.dim_input, 69 | FLAGS.dim_output) # task_batch[i]: (x, y, features) 70 | # batch_y = _transform_labels_to_network_format(batch_y, FLAGS.num_classes) 71 | inputa = batch_x[:, :int(FLAGS.num_samples_each_task / 2), :] # a used for training 72 | labela = batch_y[:, :int(FLAGS.num_samples_each_task / 2), :] 73 | inputb = batch_x[:, int(FLAGS.num_samples_each_task / 2):, :] # b used for testing 74 | labelb = batch_y[:, int(FLAGS.num_samples_each_task / 2):, :] 75 | # # when deal with few-shot problem 76 | # inputa = batch_x[:, :int(len(batch_x[0]) / 2), :] # a used for training 77 | # labela = batch_y[:, :int(len(batch_y[0]) / 2), :] 78 | # inputb = batch_x[:, int(len(batch_x[0]) / 2):, :] # b used for testing 79 | # labelb = batch_y[:, int(len(batch_y[0]) / 2):, :] 80 | 81 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, 82 | model.labelb: labelb, model.cnt_sample: cnt_sample} 83 | 84 | if itr < FLAGS.pretrain_iterations: 85 | input_tensors = [model.pretrain_op] # for comparison 86 | else: 87 | input_tensors = [model.metatrain_op] # meta_train 88 | 89 | if (itr % SUMMARY_INTERVAL == 0 or itr % PRINT_INTERVAL == 0): 90 | input_tensors.extend([model.summ_op, model.total_loss1, model.total_losses2[FLAGS.num_updates - 1]]) 91 | 92 | result = sess.run(input_tensors, feed_dict) 93 | 94 | if itr % SUMMARY_INTERVAL == 0: 95 | prelosses.append(result[-2]) 96 | if FLAGS.log: 97 | train_writer.add_summary(result[1], itr) # add sum_op 98 | postlosses.append(result[-1]) 99 | 100 | if (itr != 0) and itr % PRINT_INTERVAL == 0: 101 | if itr < FLAGS.pretrain_iterations: 102 | print_str = 'Pretrain Iteration ' + str(itr) 103 | else: 104 | print_str = 'Iteration ' + str(itr - FLAGS.pretrain_iterations) 105 | print_str += ': ' + str(np.mean(prelosses)) + ', ' + str(np.mean(postlosses)) 106 | print(print_str) 107 | print('inner lr:', sess.run(model.update_lr)) 108 | prelosses, postlosses = [], [] 109 | # save model 110 | if (itr != 0) and itr % SAVE_INTERVAL == 0: 111 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr)) 112 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr)) 113 | 114 | 115 | def test(model, saver, sess, exp_string, tasks, num_updates=5): 116 | print('start evaluation...') 117 | print(exp_string) 118 | total_Ytest, total_Ypred, total_Ytest1, total_Ypred1, sum_accuracies, sum_accuracies1 = [], [], [], [], [], [] 119 | 120 | for i in range(len(tasks)): 121 | np.random.shuffle(tasks[i]) 122 | train_ = tasks[i][:int(len(tasks[i]) / 2)] 123 | test_ = tasks[i][int(len(tasks[i]) / 2):] # test_ samples account 25% 124 | """few-steps tuning (不用op跑是因为采用的batch_size(input shape)不一致,且不想更新model.weight)""" 125 | with tf.compat.v1.variable_scope('model', reuse=True): # np.normalize()里Variable重用 126 | fast_weights = model.weights 127 | for j in range(num_updates): 128 | inputa, labela = batch_generator(train_, FLAGS.dim_input, FLAGS.dim_output, 129 | FLAGS.test_update_batch_size) 130 | loss = model.loss_func(model.forward(inputa, fast_weights, reuse=True), 131 | labela) # fast_weight和grads(stopped)有关系,但不影响这里的梯度计算 132 | grads = 
tf.gradients(ys=loss, xs=list(fast_weights.values())) 133 | gradients = dict(zip(fast_weights.keys(), grads)) 134 | fast_weights = dict(zip(fast_weights.keys(), 135 | [fast_weights[key] - model.update_lr * gradients[key] for key in 136 | fast_weights.keys()])) 137 | """Single task test accuracy""" 138 | inputb, labelb = batch_generator(test_, FLAGS.dim_input, FLAGS.dim_output, len(test_)) 139 | Y_array = sess.run(tf.nn.softmax(model.forward(inputb, fast_weights, reuse=True))) # pred_prob 140 | total_Ypred1.extend(Y_array) # pred_prob_test 141 | total_Ytest1.extend(labelb) # label 142 | 143 | Y_test = [] # for single task test 144 | for j in range(len(labelb)): 145 | Y_test.append(labelb[j][0]) 146 | total_Ytest.append(labelb[j][0]) 147 | Y_pred = [] # for single task test 148 | for j in range(len(labelb)): 149 | if Y_array[j][0] > Y_array[j][1]: 150 | Y_pred.append(1) 151 | total_Ypred.append(1) # total_Ypred: 1d-array label 152 | else: 153 | Y_pred.append(0) 154 | total_Ypred.append(0) 155 | accuracy = accuracy_score(Y_test, Y_pred) 156 | sum_accuracies.append(accuracy) 157 | # print('Test_Accuracy: %f' % accuracy) 158 | # print('SHAP...') 159 | # SHAP_() # TODO: SHAP for proposed 160 | """Overall evaluation (test data)""" 161 | total_Ypred = np.array(total_Ypred).reshape(len(total_Ypred), ) 162 | total_Ytest = np.array(total_Ytest) 163 | total_acc = accuracy_score(total_Ytest, total_Ypred) 164 | print('Test_Accuracy: %f' % total_acc) 165 | cal_measure(total_Ypred, total_Ytest) 166 | kappa_value = cohen_kappa_score(total_Ypred, total_Ytest) 167 | print('Cohen_Kappa: %f' % kappa_value) 168 | 169 | # save prediction for test samples, which can be used in calculating statistical measure such as AUROC 170 | pred_prob = np.array(total_Ypred1) 171 | label_bi = np.array(total_Ytest1) 172 | savearr = np.hstack((pred_prob, label_bi)) 173 | writer = pd.ExcelWriter('proposed_test.xlsx') 174 | data_df = pd.DataFrame(savearr) 175 | data_df.to_excel(writer) 176 | writer.close() 177 | 178 | sess.close() 179 | 180 | 181 | def main(): 182 | """1.Unsupervised pretraining; 2.segmentation and meta-task sampling; 3.meta-training and -testing""" 183 | 184 | """Unsupervised pretraining""" 185 | # TODO: if it's necessary to mimic batch normalization in Pretraining? 
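    # Note on the TODO above: TaskSampling.readpts (scene_segmentation.py) feeds the meta-learner
    # per-column max-normalized features (features / features.max(axis=0)), whereas the block below
    # passes the raw columns of FLAGS.sample_pts straight to Unsupervise_pretrain.  Unless the DAS
    # pretraining normalizes internally (its source is not shown in this file), a hedged sketch of
    # giving it the same feature columns and scaling would be:
    #
    #     feats = tmp[1:, :-3].astype(np.float32)   # readpts treats the last 3 columns as x, y, label
    #     feats = feats / feats.max(axis=0)         # same per-column scaling as readpts
    #     Unsupervise_pretrain(feats)
    #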
186 | if not os.path.exists('./unsupervised_pretraining/model_init/savedmodel.npz'): 187 | print("start unsupervised pretraining") 188 | tmp = np.loadtxt(FLAGS.sample_pts, dtype=str, delimiter=",", encoding='UTF-8') 189 | tmp_feature = tmp[1:, :].astype(np.float32) 190 | np.random.shuffle(tmp_feature) 191 | Unsupervise_pretrain(tmp_feature) 192 | 193 | """meta task sampling""" 194 | tasks_path = './metatask_sampling/' + FLAGS.str_region + '_tasks_K' + str(FLAGS.K) + '.xlsx' 195 | if not os.path.exists( 196 | './metatask_sampling/' + FLAGS.str_region + '_SLIC_M{m}_K{k}_loop{loop}.tif'.format(loop=0, m=FLAGS.M, 197 | k=FLAGS.K)): 198 | print('start scene segmentation using SLIC algorithm:') 199 | p = SLICProcessor('./src_data/' + FLAGS.str_region + '/composite.tif', FLAGS.K, FLAGS.M) 200 | p.iterate_times(loop=FLAGS.loop) 201 | print('start meta-task sampling:') 202 | t = TaskSampling(p.clusters) 203 | tasks = t.sampling(p.im_geotrans, FLAGS.sample_pts) 204 | save_tasks(tasks, tasks_path) # save each meta-task samples into respective sheet in a .xlsx file 205 | savepts_fortask(p.clusters, './metatask_sampling/' + FLAGS.str_region + 'pts_tasks_K' + str(FLAGS.K) + '.xlsx') 206 | print('produce meta training and testing datasets...') 207 | HK_tasks = read_tasks(tasks_path) 208 | tasks_train, tasks_test = meta_train_test1(HK_tasks) 209 | 210 | """meta-training and -testing""" 211 | print('model construction...') 212 | model = MAML(FLAGS.dim_input, FLAGS.dim_output, test_num_updates=5) 213 | 214 | input_tensors_input = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_input) 215 | input_tensors_label = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_output) 216 | model.construct_model(input_tensors_input=input_tensors_input, input_tensors_label=input_tensors_label, 217 | prefix='metatrain_') 218 | model.summ_op = tf.compat.v1.summary.merge_all() 219 | 220 | saver = tf.compat.v1.train.Saver(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES), 221 | max_to_keep=10) 222 | 223 | sess = tf.compat.v1.InteractiveSession() 224 | init = tf.compat.v1.global_variables() # optimizer里会有额外variable需要初始化 225 | sess.run(tf.compat.v1.variables_initializer(var_list=init)) 226 | 227 | exp_string = '.mbs' + str(FLAGS.meta_batch_size) + '.ubs_' + \ 228 | str(FLAGS.num_samples_each_task) + '.numstep' + str(FLAGS.num_updates) + \ 229 | '.updatelr' + str(FLAGS.update_lr) + '.meta_lr' + str(FLAGS.meta_lr) 230 | 231 | resume_itr = 0 232 | 233 | # 续点训练 234 | if FLAGS.resume: 235 | model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) 236 | if model_file: 237 | ind1 = model_file.index('model') 238 | resume_itr = int(model_file[ind1 + 5:]) 239 | # print("Restoring model weights from " + model_file) 240 | saver.restore(sess, model_file) # 以model_file初始化sess中图 241 | 242 | train(model, saver, sess, exp_string, tasks_train, resume_itr) 243 | 244 | test(model, saver, sess, exp_string, tasks_test, num_updates=FLAGS.num_updates) 245 | 246 | 247 | # TODO: use tf.estimator 248 | if __name__ == "__main__": 249 | # device=tf.config.list_physical_devices('GPU') 250 | tf.compat.v1.disable_eager_execution() 251 | main() 252 | print('finished!') 253 | -------------------------------------------------------------------------------- /metatask_sampling/.gitkeep: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep 
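The MAML routines in meta_LSM.py above and the per-task graph built in modeling.py (the next file) repeat the same inner-loop adaptation rule: fast_weights = weights - update_lr * grad(task_loss). The following minimal NumPy sketch only illustrates that rule on a toy least-squares task; the names (inner_adapt, w_meta) and the quadratic loss are illustrative and it is not the repository's TensorFlow implementation, which uses cross-entropy and tf.gradients.

import numpy as np


def inner_adapt(weights, x, y, update_lr=1e-2, num_updates=5):
    """Adapt a linear model y ~ x @ w to one task by plain gradient descent."""
    fast_weights = weights.copy()
    for _ in range(num_updates):
        pred = x @ fast_weights
        grad = 2.0 * x.T @ (pred - y) / len(y)             # gradient of the mean squared error
        fast_weights = fast_weights - update_lr * grad     # same form as task_metalearn's update
    return fast_weights


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(16, 14))        # 16 samples per task, 14 features (FLAGS.dim_input)
    w_true = rng.normal(size=14)
    y = x @ w_true
    w_meta = np.zeros(14)                # stand-in for the meta-learned initialization
    w_task = inner_adapt(w_meta, x, y)   # defaults mirror FLAGS.update_lr=1e-2, FLAGS.num_updates=5
    print("task loss before adaptation:", np.mean((x @ w_meta - y) ** 2))
    print("task loss after adaptation :", np.mean((x @ w_task - y) ** 2))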
-------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | import numpy as np 4 | import tensorflow as tf 5 | import os 6 | from tensorflow.python.platform import flags 7 | from utils import mse, xent, normalize 8 | 9 | FLAGS = flags.FLAGS 10 | 11 | 12 | class MAML: 13 | def __init__(self, dim_input=1, dim_output=1, test_num_updates=5): 14 | """ must call construct_model() after initializing MAML! """ 15 | self.dim_input = dim_input 16 | self.dim_output = dim_output 17 | self.meta_lr = tf.compat.v1.placeholder_with_default(FLAGS.meta_lr, ()) 18 | self.test_num_updates = test_num_updates 19 | self.dim_hidden = [32, 32, 16] 20 | self.loss_func = xent 21 | self.forward = self.forward_fc 22 | if FLAGS.basemodel == 'MLP': 23 | self.construct_weights = self.construct_fc_weights # 参数初始化(random) 24 | elif FLAGS.basemodel == 'DAS': 25 | if os.path.exists('unsupervised_pretraining/model_init/savedmodel.npz'): 26 | self.construct_weights = self.construct_DAS_weights # DAS初始化 27 | else: 28 | raise ValueError('No pretrained model found!') 29 | else: 30 | raise ValueError('Unrecognized base model, please specify a base model such as "MLP", "DAS"...') 31 | 32 | def construct_model(self, input_tensors_input=None, input_tensors_label=None, prefix='metatrain_'): 33 | # a: training data for inner gradient, b: test data for meta gradient 34 | self.inputa = tf.compat.v1.placeholder(tf.float32, 35 | shape=input_tensors_input) # for train in a task, shape should be specified for tf.slim (but not should be correct) 36 | self.inputb = tf.compat.v1.placeholder(tf.float32, shape=input_tensors_input) 37 | self.labela = tf.compat.v1.placeholder(tf.float32, shape=input_tensors_label) # for test in a task 38 | self.labelb = tf.compat.v1.placeholder(tf.float32, shape=input_tensors_label) 39 | self.cnt_sample = tf.compat.v1.placeholder(tf.float32) # count number of samples for each task in the batch 40 | 41 | with tf.compat.v1.variable_scope('model', reuse=None) as training_scope: 42 | # Attention module 43 | self.A = tf.Variable(tf.zeros([self.dim_input, self.dim_input])) 44 | # initialize the inner learning rate as tf.Variable within 'model' scope 45 | self.update_lr = tf.Variable(FLAGS.update_lr) 46 | if 'weights' in dir(self): 47 | training_scope.reuse_variables() 48 | weights = self.weights 49 | else: 50 | # Define the weights 51 | self.weights = weights = self.construct_weights() # 初始化FC权重参数 52 | 53 | num_updates = max(self.test_num_updates, FLAGS.num_updates) # training iteration in a task 54 | 55 | def task_metalearn(inp, reuse=True): 56 | """ Perform gradient descent for one task in the meta-batch. 
""" 57 | inputa, inputb, labela, labelb = inp # inputa: Task(i)训练输入,batch_size = m(m个samples) 58 | task_outputbs, task_lossesb = [], [] 59 | 60 | task_outputa = self.forward(inputa, weights, reuse=reuse) # only reuse on the first iter 61 | task_lossa = self.loss_func(task_outputa, labela) 62 | 63 | grads = tf.gradients(ys=task_lossa, xs=list(weights.values())) # 计算梯度 64 | if FLAGS.stop_grad: # maml中的二次求导() 65 | grads = [tf.stop_gradient(grad) for grad in grads] # 使梯度无法进行二次求偏导(BP) 66 | gradients = dict(zip(weights.keys(), grads)) 67 | fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr * gradients[key] for key in 68 | weights.keys()])) # 更新weight 69 | output = self.forward(inputb, fast_weights, reuse=True) # Task(i) test output 70 | task_outputbs.append(output) 71 | task_lossesb.append(self.loss_func(output, labelb)) 72 | 73 | for j in range(num_updates - 1): # num_updates:Task(i)中用batch_size个训练样本更新权值的迭代次数 74 | loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), 75 | labela) # fast_weight和grads(stopped)有关系,但不影响这里的梯度计算 76 | grads = tf.gradients(ys=loss, xs=list(fast_weights.values())) 77 | if FLAGS.stop_grad: 78 | grads = [tf.stop_gradient(grad) for grad in grads] 79 | gradients = dict(zip(fast_weights.keys(), grads)) 80 | fast_weights = dict(zip(fast_weights.keys(), 81 | [fast_weights[key] - self.update_lr * gradients[key] for key in 82 | fast_weights.keys()])) 83 | output = self.forward(inputb, fast_weights, reuse=True) 84 | task_outputbs.append(output) 85 | task_lossesb.append(self.loss_func(output, labelb)) 86 | 87 | return [task_outputa, task_outputbs, task_lossa, task_lossesb] # task_outpouta, task_lossa是仅 88 | 89 | if FLAGS.norm != 'None': # 此处不能删,考虑到reuse 90 | '''to initialize the batch norm vars, might want to combine this, and not run idx 0 twice (use reuse=tf.AUTO_REUSE instead).''' 91 | task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False) 92 | 93 | out_dtype = [tf.float32, [tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates] 94 | 95 | """输入各维度(for batch)进行task_metalearn的并行操作, 相较out_dtype多了batch_size的维度""" 96 | 97 | result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), 98 | dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size) # parallel calculation 99 | 100 | """ outputas:(num_tasks, num_samples, value) 101 | outputbs[i]:i是迭代次数,不同迭代次数的预测值(num_tasks, num_samples, value) 102 | lossesa:(num_tasks, value) 103 | lossesb[i]:i是迭代次数,不同迭代次数的预测值(num_tasks, value)""" 104 | outputas, outputbs, lossesa, lossesb = result # outputas:(num_tasks, num_samples, value) 105 | 106 | ## Performance & Optimization 107 | if 'train' in prefix: 108 | self.total_loss1 = total_loss1 = tf.reduce_sum(input_tensor=lossesa) / tf.cast(FLAGS.meta_batch_size, 109 | dtype=tf.float32) # total loss的均值,finn论文中的pretrain(对比用) 110 | 111 | self.total_losses2 = total_losses2 = [ 112 | tf.reduce_sum(lossesb[j]) / tf.cast(FLAGS.meta_batch_size, dtype=tf.float32) \ 113 | for j in range(num_updates)] 114 | 115 | # w = self.cnt_sample / tf.cast(FLAGS.num_samples, dtype=tf.float32) 116 | # self.total_losses2 = total_losses2 = [tf.reduce_sum( 117 | # input_tensor=tf.multiply(tf.nn.softmax(w), tf.reduce_sum(input_tensor=lossesb[j], axis=1))) 118 | # for j in range(num_updates)] 119 | 120 | # after the map_fn 121 | self.outputas, self.outputbs = outputas, outputbs # outputbs:25个task, 每个task迭代五次,value(25,5,1) 122 | self.pretrain_op = tf.compat.v1.train.AdamOptimizer(self.meta_lr).minimize(total_loss1) 
# inner for test 123 | 124 | optimizer = tf.compat.v1.train.AdamOptimizer(self.meta_lr) 125 | 126 | self.gvs = gvs = optimizer.compute_gradients(self.total_losses2[ 127 | FLAGS.num_updates - 1]) # 取最后一次迭代的Lossb,gvs:gradients and variables,对所有trainable variables求梯度 128 | self.metatrain_op = optimizer.apply_gradients(gvs) # outer 129 | else: # 20/11待删 130 | self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(input_tensor=lossesa) / tf.cast( 131 | FLAGS.meta_batch_size, dtype=tf.float32) # inner loss 132 | self.metaval_total_losses2 = total_losses2 = [ 133 | tf.reduce_sum(input_tensor=lossesb[j]) / tf.cast(FLAGS.meta_batch_size, dtype=tf.float32) for j in 134 | range(num_updates)] # outer losses(每次迭代的) 135 | 136 | ## Summaries 137 | tf.compat.v1.summary.scalar(prefix + 'Pre-update loss', total_loss1) # for test accuracy 138 | for j in range(num_updates): 139 | tf.compat.v1.summary.scalar(prefix + 'Post-update loss, step ' + str(j + 1), total_losses2[j]) 140 | 141 | def construct_fc_weights(self): 142 | weights = {} 143 | weights['w1'] = tf.Variable(tf.random.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01)) 144 | weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]])) 145 | for i in range(1, len(self.dim_hidden)): 146 | weights['w' + str(i + 1)] = tf.Variable( 147 | tf.random.truncated_normal([self.dim_hidden[i - 1], self.dim_hidden[i]], stddev=0.01)) 148 | weights['b' + str(i + 1)] = tf.Variable(tf.zeros([self.dim_hidden[i]])) 149 | weights['w' + str(len(self.dim_hidden) + 1)] = tf.Variable( 150 | tf.random.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01)) 151 | weights['b' + str(len(self.dim_hidden) + 1)] = tf.Variable(tf.zeros([self.dim_output])) 152 | return weights 153 | 154 | def forward_fc(self, inp, weights, reuse=False): 155 | hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0') 156 | for i in range(1, len(self.dim_hidden)): 157 | hidden = normalize(tf.matmul(hidden, weights['w' + str(i + 1)]) + weights['b' + str(i + 1)], 158 | activation=tf.nn.relu, reuse=reuse, scope=str(i + 1)) 159 | return tf.matmul(hidden, weights['w' + str(len(self.dim_hidden) + 1)]) + weights[ 160 | 'b' + str(len(self.dim_hidden) + 1)] 161 | 162 | def construct_DAS_weights(self): 163 | """读取DAS权参""" 164 | npzfile = np.load('unsupervised_pretraining/model_init/savedmodel.npz') 165 | weights = {} 166 | weights['w1'] = tf.Variable(tf.transpose(a=npzfile['arr_0'])) 167 | weights['b1'] = tf.Variable(npzfile['arr_1']) 168 | weights['w2'] = tf.Variable(tf.transpose(a=npzfile['arr_2'])) 169 | weights['b2'] = tf.Variable(npzfile['arr_3']) 170 | weights['w3'] = tf.Variable(tf.transpose(a=npzfile['arr_4'])) 171 | weights['b3'] = tf.Variable(npzfile['arr_5']) 172 | weights['w4'] = tf.Variable(tf.random.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01)) 173 | weights['b4'] = tf.Variable(tf.zeros([self.dim_output])) 174 | return weights 175 | -------------------------------------------------------------------------------- /models_of_blocks/.gitkeep: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep -------------------------------------------------------------------------------- /models_of_blocks/HK/.gitkeep: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep 
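construct_DAS_weights() in modeling.py above loads unsupervised_pretraining/model_init/savedmodel.npz and transposes arr_0, arr_2 and arr_4 into w1-w3, so the matrices on disk are presumably stored as (out_dim, in_dim). The helper below is a small, hedged sketch for verifying the file before meta-training; the function name check_das_npz and the expected shapes (derived from dim_input=14 and dim_hidden=[32, 32, 16]) are assumptions inferred from that method, not part of the repository.

import numpy as np


def check_das_npz(path="unsupervised_pretraining/model_init/savedmodel.npz",
                  dim_input=14, dim_hidden=(32, 32, 16)):
    """Check that the pretrained arrays match the shapes construct_DAS_weights expects."""
    npz = np.load(path)
    expected = [(dim_input, dim_hidden[0]),
                (dim_hidden[0], dim_hidden[1]),
                (dim_hidden[1], dim_hidden[2])]
    for layer, (rows, cols) in enumerate(expected):
        w = npz["arr_%d" % (2 * layer)].T        # transposed exactly as in construct_DAS_weights
        b = npz["arr_%d" % (2 * layer + 1)]
        assert w.shape == (rows, cols), "w%d has shape %s, expected %s" % (layer + 1, w.shape, (rows, cols))
        assert b.shape == (cols,), "b%d has shape %s, expected %s" % (layer + 1, b.shape, (cols,))
    print("savedmodel.npz is consistent with MAML.construct_DAS_weights")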
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | GDAL==3.2.3 2 | matplotlib==3.6.2 3 | numpy==1.23.3 4 | pandas==1.5.0 5 | scikit_image==0.19.3 6 | scikit_learn==1.1.2 7 | scipy==1.9.1 8 | shap==0.41.0 9 | tf_slim==1.1.0 10 | tqdm==4.64.1 11 | umap_learn==0.5.3 12 | xgboost==1.7.1 13 | -------------------------------------------------------------------------------- /scene_segmentation.py: -------------------------------------------------------------------------------- 1 | import math 2 | from skimage import io, color 3 | import numpy as np 4 | from tqdm import trange 5 | from osgeo import gdal 6 | import pandas as pd 7 | from tensorflow.python.platform import flags 8 | 9 | FLAGS = flags.FLAGS 10 | 11 | 12 | class Cluster(object): 13 | cluster_index = 1 14 | 15 | def __init__(self, h, w, dem, aspect, curvature, slope): 16 | self.update(h, w, dem, aspect, curvature, slope) 17 | self.pixels = [] 18 | self.no = self.cluster_index 19 | self.cluster_index += 1 # cluster counter 20 | 21 | def update(self, h, w, dem, aspect, curvature, slope): 22 | self.h = h 23 | self.w = w 24 | self.dem = dem 25 | self.aspect = aspect 26 | self.curvature = curvature 27 | self.slope = slope 28 | 29 | def __str__(self): 30 | return "{},{}:{} {} {} {} ".format(self.h, self.w, self.dem, self.aspect, self.curvature, 31 | self.slope) 32 | 33 | def __repr__(self): 34 | return self.__str__() 35 | 36 | 37 | class SLICProcessor(object): 38 | @staticmethod 39 | def open_image(path): 40 | """ 41 | Return: 42 | 3D array, row col [LAB] 43 | """ 44 | rgb = io.imread(path) 45 | lab_arr = color.rgb2lab(rgb) 46 | return lab_arr 47 | 48 | @staticmethod 49 | def save_lab_image(path, lab_arr): 50 | """ 51 | Convert the array to RGB, then save the image 52 | """ 53 | rgb_arr = color.lab2rgb(lab_arr) 54 | io.imsave(path, rgb_arr) 55 | 56 | def make_cluster(self, h, w): 57 | return Cluster(h, w, 58 | self.data[0][h][w], 59 | self.data[1][h][w], 60 | self.data[2][h][w], 61 | self.data[3][h][w], ) 62 | 63 | def __init__(self, filename, K, M): # K: number of superpixels; M: weight of spatial distance in the distance measure 64 | self.file = filename 65 | self.K = K 66 | self.M = M 67 | self.data = self.readTif(filename) # shape: (bands, height, width) 68 | self.im_geotrans = gdal.Open(filename).GetGeoTransform() 69 | self.image_height = self.data.shape[1] 70 | self.image_width = self.data.shape[2] 71 | self.N = self.image_height * self.image_width 72 | self.S = int(math.sqrt(self.N / self.K)) 73 | 74 | self.clusters = [] 75 | self.label = {} 76 | self.dis = np.full((self.image_height, self.image_width), np.inf) # initialize distances to +inf 77 | 78 | def readTif(self, fileName): 79 | dataset = gdal.Open(fileName) 80 | if dataset is None: 81 | print(fileName + " cannot be opened!") 82 | return 83 | im_width = dataset.RasterXSize # number of columns 84 | im_height = dataset.RasterYSize # number of rows 85 | im_bands = dataset.RasterCount # number of bands 86 | im_data = dataset.ReadAsArray(0, 0, im_width, im_height) # read the raster data 87 | im_geotrans = dataset.GetGeoTransform() # affine geotransform parameters 88 | im_proj = dataset.GetProjection() # projection info 89 | # col = int((coor[i][0] - im_geotrans[0]) / im_geotrans[1]) 90 | # row = int((coor[i][1] - im_geotrans[3]) / im_geotrans[5]) 91 | # im_nirBand = im_data[3,0:im_height,0:im_width] # near-infrared band 92 | return im_data 93 | 94 | def init_clusters(self, data): 95 | h = int(self.S / 2) # first cluster center 96 | w = int(self.S / 2) 97 | while h < self.image_height: 98 | while w < self.image_width: 99 |
if data[0][h][w] != -9999: # -9999为Nodata 100 | self.clusters.append(self.make_cluster(h, w)) 101 | w += self.S 102 | w = int(self.S / 2) 103 | h += self.S 104 | 105 | def get_gradient(self, h, w): 106 | if w + 1 >= self.image_width: 107 | w = self.image_width - 2 108 | if h + 1 >= self.image_height: 109 | h = self.image_height - 2 110 | 111 | gradient = self.data[0][h + 1][w + 1] - self.data[0][h][w] + \ 112 | self.data[1][h + 1][w + 1] - self.data[1][h][w] + \ 113 | self.data[2][h + 1][w + 1] - self.data[2][h][w] + \ 114 | self.data[3][h + 1][w + 1] - self.data[3][h][w] 115 | 116 | return gradient 117 | 118 | def move_clusters(self): 119 | for cluster in self.clusters: 120 | cluster_gradient = self.get_gradient(cluster.h, cluster.w) # 计算每个中心的gradient 121 | for dh in range(-5, 6): 122 | for dw in range(-5, 6): 123 | _h = cluster.h + dh 124 | _w = cluster.w + dw 125 | if self.data[0][_h][_w] and self.data[1][_h][_w] and self.data[2][_h][_w] \ 126 | and self.data[3][_h][_w] != -9999: 127 | new_gradient = self.get_gradient(_h, _w) 128 | if new_gradient < cluster_gradient: # 寻找 4 x 4 邻域内梯度最小的像素点(更聚集),并且移动中心 129 | cluster.update(_h, _w, self.data[0][_h][_w], self.data[1][_h][_w], self.data[2][_h][_w], 130 | self.data[3][_h][_w]) 131 | cluster_gradient = new_gradient 132 | 133 | def assignment(self): 134 | W = [0.3, 0.1, 0.2, 0.4] # 权重:各因素影响 135 | for cluster in self.clusters: 136 | for h in range(cluster.h - self.S, cluster.h + self.S): 137 | if h < 0 or h >= self.image_height: continue # continue进入下一个循环 138 | for w in range(cluster.w - self.S, cluster.w + self.S): 139 | if w < 0 or w >= self.image_width: continue 140 | if self.data[0][h][w] != -9999 and self.data[1][h][w] != -9999 \ 141 | and self.data[2][h][w] != -9999 and self.data[3][h][w] != -9999: 142 | Dc = math.sqrt( 143 | math.pow(self.data[0][h][w] - cluster.dem, 2) * W[0] + 144 | math.pow(self.data[1][h][w] - cluster.aspect, 2) * W[1] + 145 | math.pow(self.data[2][h][w] - cluster.curvature, 2) * W[2] + 146 | math.pow(self.data[3][h][w] - cluster.slope, 2) * W[3] 147 | ) # dbs 148 | Ds = math.sqrt( 149 | math.pow(h - cluster.h, 2) + 150 | math.pow(w - cluster.w, 2)) # dxy 151 | D = math.sqrt(math.pow(Dc / self.M, 2) + math.pow(Ds / self.S, 2)) # Ds 152 | if D < self.dis[h][w]: 153 | if (h, w) not in self.label: # dict中tuple也可以作为key 154 | self.label[(h, w)] = cluster 155 | cluster.pixels.append((h, w)) 156 | else: 157 | self.label[(h, w)].pixels.remove((h, w)) 158 | self.label[(h, w)] = cluster 159 | cluster.pixels.append((h, w)) 160 | self.dis[h][w] = D 161 | 162 | def update_cluster(self): # 计算各SLIC聚类的中心 163 | for cluster in self.clusters: 164 | sum_h = sum_w = number = 0 165 | for p in cluster.pixels: 166 | sum_h += p[0] 167 | sum_w += p[1] 168 | number += 1 169 | _h = int(sum_h / number) 170 | _w = int(sum_w / number) 171 | cluster.update(_h, _w, self.data[0][_h][_w], self.data[1][_h][_w], self.data[2][_h][_w], 172 | self.data[3][_h][_w]) # 计算聚类中心 173 | 174 | def writeTiff(self, im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path): 175 | if 'int8' in im_data.dtype.name: 176 | datatype = gdal.GDT_Byte 177 | elif 'int16' in im_data.dtype.name: 178 | datatype = gdal.GDT_UInt16 179 | else: 180 | datatype = gdal.GDT_Float32 181 | 182 | im_bands, im_height, im_width = im_data.shape 183 | path = 'metatask_sampling\\' + path 184 | # 创建文件 185 | driver = gdal.GetDriverByName("GTiff") 186 | dataset = driver.Create(path, im_width, im_height, im_bands, datatype) 187 | if (dataset != None): 188 | 
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 189 | dataset.SetProjection(im_proj) # 写入投影 190 | for i in range(im_bands): 191 | dataset.GetRasterBand(i + 1).SetNoDataValue(-9999) 192 | dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) 193 | del dataset 194 | 195 | def savetif(self, tiffile, savename, image_arr): 196 | gdal.AllRegister() 197 | dataset = gdal.Open(tiffile) 198 | im_bands = dataset.RasterCount # 波段数 199 | im_width = dataset.RasterXSize # 列数 200 | im_height = dataset.RasterYSize # 行数 201 | im_geotrans = dataset.GetGeoTransform() # 获取仿射矩阵信息 202 | im_proj = dataset.GetProjection() # 获取投影信息 203 | 204 | self.writeTiff(image_arr, im_width, im_height, im_bands, im_geotrans, im_proj, savename) 205 | 206 | def save_current_image(self, tiffile, savename): 207 | image_arr = np.copy(self.data) 208 | for i in range(len(image_arr)): # 初始化 209 | for h in range(self.image_height): 210 | for w in range(self.image_width): 211 | image_arr[i][h][w] = -9999 212 | 213 | c = 0; 214 | interval = int(256 / len(self.clusters)) 215 | for cluster in self.clusters: # 可视化各聚类点 216 | c += 1 217 | r = g = b = d = interval * c - 1 218 | for p in cluster.pixels: # LAB三通道赋值 219 | image_arr[0][p[0]][p[1]] = r 220 | image_arr[1][p[0]][p[1]] = g 221 | image_arr[2][p[0]][p[1]] = b 222 | image_arr[0][cluster.h][cluster.w] = 0 # 让cluster中心为0 223 | image_arr[1][cluster.h][cluster.w] = 0 224 | image_arr[2][cluster.h][cluster.w] = 0 225 | 226 | self.savetif(tiffile, savename, image_arr) 227 | 228 | def iterate_times(self, loop=5): 229 | self.init_clusters(self.data) # 存储所有中心点, clusters = [] 230 | self.move_clusters() 231 | for i in trange(loop): 232 | self.assignment() 233 | self.update_cluster() 234 | savename = FLAGS.str_region + '_SLIC_M{m}_K{k}_loop{loop}.tif'.format(loop=i, m=self.M, 235 | k=self.K) # 生成可视tif 236 | self.save_current_image(self.file, savename) 237 | 238 | 239 | class TaskSampling(object): 240 | def __init__(self, clusters): 241 | self.clusters = clusters 242 | self.tasks = self.init_tasks(len(clusters)) 243 | 244 | def init_tasks(self, num_clusters): 245 | L = [] 246 | for i in range(num_clusters): 247 | L.append([]) 248 | return L 249 | 250 | def readpts(self, filepath): 251 | tmp = np.loadtxt(filepath, dtype=np.str, delimiter=",", encoding='UTF-8') 252 | features = tmp[1:, :-3].astype(np.float32) 253 | features = features / features.max(axis=0) # 减小数值影响 254 | xy = tmp[1:, -3: -1].astype(np.float32) 255 | label = tmp[1:, -1].astype(np.float32) 256 | return features, xy, label 257 | 258 | def sampling(self, im_geotrans, path): 259 | features, xy, label = self.readpts(path) 260 | # features_Ts_, xy_Ts, label_Ts = self.readpts(FLAGS.Ts_pts) 261 | # features = np.vstack((features, features_Ts_)) 262 | # xy = np.vstack((xy, xy_Ts)) 263 | # # labeling Ts pts according to dv value 264 | # for i in range(len(label_Ts)): 265 | # if label_Ts[i] <= 2: 266 | # label_Ts[i] = 0.7 267 | # continue 268 | # if 2 < label_Ts[i] <= 4: 269 | # label_Ts[i] = 0.75 270 | # continue 271 | # if 4 < label_Ts[i] <= 6: 272 | # label_Ts[i] = 0.8 273 | # continue 274 | # if 6 < label_Ts[i] <= 8: 275 | # label_Ts[i] = 0.85 276 | # continue 277 | # if label_Ts[i] > 8: 278 | # label_Ts[i] = 0.9 279 | # label = np.hstack((label, label_Ts)) 280 | # 计算(row, col) 281 | pts = [] 282 | for i in range(xy.shape[0]): 283 | height = int((xy[i][1] - im_geotrans[3]) / im_geotrans[5]) 284 | width = int((xy[i][0] - im_geotrans[0]) / im_geotrans[1]) 285 | pts.append((height, width)) 286 | 287 | pt_index = 0 288 | for pt in pts: 289 | k = 0 # 
count cluster 290 | for cluster in self.clusters: 291 | if (pt[0], pt[1]) in cluster.pixels: 292 | self.tasks[k].append([features[pt_index], label[pt_index]]) 293 | break 294 | else: 295 | k += 1 296 | pt_index += 1 297 | return self.tasks 298 | -------------------------------------------------------------------------------- /src_data/HK/composite.tfw: -------------------------------------------------------------------------------- 1 | 78.1848426438 2 | 0.0000000000 3 | 0.0000000000 4 | -75.3690471051 5 | 801034.7337224468 6 | 846872.9936815350 7 | -------------------------------------------------------------------------------- /src_data/HK/composite.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLi-de/Meta_LSM/c25e2904761e3f4b4d5a797f0b4db7ddfe53236e/src_data/HK/composite.tif -------------------------------------------------------------------------------- /src_data/HK/composite.tif.aux.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -62 6 | 951 7 | 256 8 | 1 9 | 0 10 | 1|0|0|2|1|0|1|5|7|4|18|41|117|311|1053|8851|9688|11455|8042|7095|6026|5065|4483|3916|3527|3282|2989|2840|2857|3486|2536|2472|2494|2389|2157|2165|1648|2120|2009|1997|1945|1924|1829|1845|1727|1744|1653|1716|1700|1640|1535|1537|1569|1468|1447|1449|1404|1462|1425|1347|1302|1257|1376|1247|1290|1280|1208|1223|1157|1256|1139|1132|806|979|984|1021|979|938|923|933|884|859|795|825|787|732|707|741|737|733|697|719|613|665|583|622|557|601|567|556|499|537|503|466|458|433|445|433|398|297|397|389|389|385|315|381|334|316|303|300|251|298|275|266|285|268|256|223|231|208|194|217|175|168|161|152|151|152|126|152|147|117|114|135|145|81|114|79|117|89|112|95|96|92|81|94|64|71|66|59|55|72|55|65|43|49|66|51|44|63|62|42|59|31|40|52|34|42|47|44|42|24|34|35|37|26|29|25|40|34|35|33|35|29|39|34|40|26|20|33|17|19|31|26|19|13|20|22|13|13|10|8|10|14|16|6|10|16|10|10|12|10|6|12|8|9|7|8|8|7|5|1|7|3|8|4|3|3|5|4|6|1|3|1|4|3|1|1|1|1|0|2|1|0|2|1 11 | 12 | 13 | 14 | ATHEMATIC 15 | 18458.44274292592,662.3847044297564,15.69938554122262,609.7599113609483 16 | 17 | 951 18 | 124.16100786648 19 | -62 20 | 1 21 | 1 22 | 135.86185168371 23 | 24 | 25 | 26 | 27 | 28 | -1 29 | 359.7870178222656 30 | 256 31 | 1 32 | 0 33 | 6062|561|682|639|685|638|662|676|767|632|669|673|567|952|568|670|633|723|594|779|609|654|683|620|636|693|615|592|654|600|511|470|1577|462|538|644|664|562|679|670|606|633|672|628|638|793|610|750|621|588|576|925|586|674|723|629|800|712|713|746|671|667|687|588|1263|614|696|839|684|671|777|789|825|732|723|716|662|932|689|744|743|669|641|877|688|623|779|674|844|699|712|733|708|739|672|364|1683|676|704|714|711|764|709|676|813|584|814|684|651|803|674|713|678|668|1005|544|668|641|791|825|597|631|717|647|634|649|630|1241|406|592|626|680|716|690|629|829|563|742|657|594|963|486|662|669|746|669|780|600|672|764|590|727|588|820|574|660|647|589|568|1608|420|637|661|679|660|644|709|787|691|723|615|746|909|608|675|766|664|708|976|654|701|734|771|804|694|802|660|864|743|711|593|1336|695|729|794|685|812|796|703|835|718|777|699|658|978|618|695|751|643|675|787|626|737|810|653|728|717|726|672|714|663|613|453|1608|602|636|663|750|695|707|823|760|721|797|725|685|846|686|777|705|680|525|984|717|715|803|616|824|699|756|765|719|682|675|471|1 34 | 35 | 36 | 37 | ATHEMATIC 38 | 662.3847044297564,11334.03019575368,0.4471348368896256,73.25494859340394 39 | 40 | 359.78701782227 41 | 176.76238026471 42 | -1 43 | 1 44 | 1 45 | 106.46140237548 46 | 47 | 48 | 49 | 
50 | 51 | -21.66348075866699 52 | 22.56612586975098 53 | 256 54 | 1 55 | 0 56 | 1|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|1|0|0|0|0|0|0|1|0|4|0|0|0|0|1|1|0|2|1|0|1|1|1|2|2|1|3|0|1|3|3|4|3|5|3|5|0|5|5|6|5|9|7|12|3|1|11|7|16|11|18|8|23|16|26|11|33|18|44|48|33|73|34|102|60|124|86|177|111|287|159|365|553|320|801|455|1297|802|2003|1279|3364|2163|5299|3456|8857|11860|7213|17833|10400|28238|10688|18076|7496|12015|4696|7681|3172|4951|2041|3141|2277|930|1492|591|998|380|679|274|432|182|290|123|234|187|67|142|75|84|43|83|27|73|15|55|19|28|29|12|33|14|17|10|17|4|12|6|15|8|13|4|3|8|5|6|1|5|3|7|3|1|2|2|3|0|6|1|0|2|1|1|3|1|2|2|0|1|1|0|0|0|1|1|0|1|0|0|1|0|1|0|0|0|0|0|1|0|2|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|1|0|0|0|0|0|0|0|0|0|1 57 | 58 | 59 | 60 | ATHEMATIC 61 | 15.69938554122262,0.4471348368896256,1.404366681951738,0.1800680002874268 62 | 63 | 22.566125869751 64 | -0.012823529684218 65 | -21.663480758667 66 | 1 67 | 1 68 | 1.1850597799064 69 | 70 | 71 | 72 | 73 | 74 | 0 75 | 65.13930511474609 76 | 256 77 | 1 78 | 0 79 | 5017|927|1169|901|1503|2038|1525|1778|2054|1852|2306|1505|1770|2125|1867|1352|2060|1856|1673|1870|1756|1567|1548|1509|1397|1540|1636|1591|1263|1709|1606|1536|1154|1724|1536|1385|1485|1442|1559|1405|1654|1537|1540|1555|1661|1340|1408|1479|1659|1484|1524|1507|1716|1668|1429|1430|1747|1575|1826|1584|1464|1615|1664|1707|1548|1672|1584|1689|1593|1574|1676|1490|1530|1503|1604|1572|1633|1657|1463|1531|1657|1491|1546|1439|1416|1387|1553|1490|1339|1511|1379|1356|1304|1322|1144|1273|1276|1306|1135|1111|1200|1036|1139|1003|1072|969|957|875|901|926|858|780|776|756|723|741|699|670|573|601|579|574|500|504|475|390|498|424|411|380|365|374|336|320|307|285|289|279|240|254|208|199|228|192|154|171|160|154|142|143|116|119|124|109|135|98|92|84|84|69|70|58|66|72|56|49|54|48|46|47|48|31|43|38|34|30|20|24|26|21|22|31|18|17|13|8|22|15|16|17|10|13|12|9|10|10|6|8|5|11|7|4|3|5|8|2|6|5|5|7|2|3|3|3|6|1|1|0|4|2|2|0|2|0|0|0|3|0|0|0|0|3|0|1|1|0|0|1|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|2|1 80 | 81 | 82 | 83 | ATHEMATIC 84 | 609.7599113609483,73.25494859340394,0.1800680002874268,95.64630594975473 85 | 86 | 65.139305114746 87 | 15.048037271844 88 | 0 89 | 1 90 | 1 91 | 9.7798929416305 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /src_data/HK/composite.tif.ovr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLi-de/Meta_LSM/c25e2904761e3f4b4d5a797f0b4db7ddfe53236e/src_data/HK/composite.tif.ovr -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions. 
""" 2 | import numpy as np 3 | import tensorflow as tf 4 | import pandas as pd 5 | 6 | import tf_slim as slim 7 | from tensorflow.python.platform import flags 8 | 9 | FLAGS = flags.FLAGS 10 | 11 | 12 | def normalize(inp, activation, reuse, scope): 13 | if FLAGS.norm == 'batch_norm': 14 | return slim.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 15 | elif FLAGS.norm == 'layer_norm': 16 | return slim.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 17 | elif FLAGS.norm == 'None': 18 | if activation is not None: 19 | return activation(inp) 20 | else: 21 | return inp 22 | 23 | 24 | def mse(pred, label): 25 | pred = tf.reshape(pred, [-1]) 26 | label = tf.reshape(label, [-1]) 27 | return tf.reduce_mean(input_tensor=tf.square(pred - label)) 28 | 29 | 30 | def xent(pred, label): 31 | # Note - with tf version <=0.12, this loss has incorrect 2nd derivatives 32 | return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / tf.cast(tf.shape(input=label)[0], 33 | dtype=tf.float32) # 注意归一 34 | 35 | 36 | def tasksbatch_generator(data, batch_size, num_samples, dim_input, dim_output): 37 | """generate batch tasks""" 38 | init_inputs = np.zeros([batch_size, num_samples, dim_input], dtype=np.float32) 39 | labels = np.zeros([batch_size, num_samples, dim_output], dtype=np.float32) 40 | 41 | np.random.shuffle(data) 42 | start_index = np.random.randint(0, len(data) - batch_size) 43 | batch_tasks = data[start_index:(start_index + batch_size)] 44 | 45 | cnt_sample = [] 46 | for i in range(batch_size): 47 | cnt_sample.append(len(batch_tasks[i])) 48 | 49 | for i in range(batch_size): 50 | np.random.shuffle(batch_tasks[i]) # shuffle samples in each task 51 | start_index1 = np.random.randint(0, len(batch_tasks[i]) - num_samples) 52 | task_samples = batch_tasks[i][start_index1:(start_index1 + num_samples)] 53 | for j in range(num_samples): 54 | init_inputs[i][j] = task_samples[j][0] 55 | if task_samples[j][1] == 1: 56 | labels[i][j][0] = 1 # 滑坡 57 | else: 58 | labels[i][j][1] = 1 # 非滑坡 59 | return init_inputs, labels, np.array(cnt_sample).astype(np.float32) 60 | 61 | 62 | def batch_generator(one_task, dim_input, dim_output, batch_size): 63 | """generate samples from one tasks""" 64 | np.random.shuffle(one_task) 65 | batch_ = one_task[:batch_size] 66 | init_inputs = np.zeros([batch_size, dim_input], dtype=np.float32) 67 | labels = np.zeros([batch_size, dim_output], dtype=np.float32) 68 | for i in range(batch_size): 69 | init_inputs[i] = batch_[i][0] 70 | if batch_[i][1] == 1: 71 | labels[i][1] = 1 72 | else: 73 | labels[i][0] = 1 74 | return init_inputs, labels 75 | 76 | 77 | # # for each region (e.g., FJ&FL) 78 | # def sample_generator_(tasks, dim_input, dim_output): 79 | # all_samples = np.array(tasks[0]) 80 | # num_samples = int(FLAGS.num_samples_each_task / 2) 81 | # for i in range(len(tasks) - 1): 82 | # if len(tasks[i + 1]) > 0: 83 | # all_samples = np.vstack((all_samples, np.array(tasks[i + 1]))) 84 | # init_inputs = np.zeros([1, num_samples, dim_input], dtype=np.float32) 85 | # labels = np.zeros([1, num_samples, dim_output], dtype=np.float32) 86 | # for i in range(num_samples): 87 | # init_inputs[0][i] = all_samples[i][:-1] 88 | # if all_samples[i][-1] == 1: 89 | # labels[0][i][0] = 1 90 | # else: 91 | # labels[0][i][1] = 1 92 | # return init_inputs, labels 93 | 94 | 95 | def meta_train_test(fj_tasks, fl_tasks, mode=0): 96 | test1_fj_tasks, test1_fl_tasks, read_tasks, one_test_tasks = [], [], [], [] 97 | _train, _test = [], [] 98 | # 
np.random.shuffle(tasks) 99 | if mode == 0: 100 | elig_tasks = [] 101 | for i in range(len(fj_tasks)): 102 | if len(fj_tasks[i]) > FLAGS.num_samples_each_task: 103 | elig_tasks.append(fj_tasks[i]) 104 | elif len(fj_tasks[i]) > 10: # set 10 to test K=10-shot learning 105 | test1_fj_tasks.append(fj_tasks[i]) 106 | else: 107 | read_tasks.append(fj_tasks[i]) 108 | _train = elig_tasks[:int(len(elig_tasks) / 4 * 3)] 109 | _test = elig_tasks[int(len(elig_tasks) / 4 * 3):] + test1_fj_tasks 110 | for i in range(len(read_tasks)): # read_tasks暂时不用 111 | one_test_tasks.extend(read_tasks[i]) 112 | return _train, _test 113 | 114 | if mode == 1: 115 | for i in range(len(fj_tasks)): 116 | if len(fj_tasks[i]) > FLAGS.num_samples_each_task: 117 | _train.append(fj_tasks[i]) 118 | for i in range(len(fl_tasks)): 119 | if len(fl_tasks[i]) > 10: 120 | _test.append(fl_tasks[i]) 121 | return _train, _test 122 | 123 | if mode == 2 or mode == 3: 124 | elig_fj_tasks, elig_fl_tasks = [], [] 125 | for i in range(len(fj_tasks)): 126 | if len(fj_tasks[i]) > FLAGS.num_samples_each_task: 127 | elig_fj_tasks.append(fj_tasks[i]) 128 | elif len(fj_tasks[i]) > 10: 129 | test1_fj_tasks.append(fj_tasks[i]) 130 | for i in range(len(fl_tasks)): 131 | if len(fl_tasks[i]) > FLAGS.num_samples_each_task: 132 | elig_fl_tasks.append(fl_tasks[i]) 133 | elif len(fl_tasks[i]) > 10: 134 | test1_fl_tasks.append(fl_tasks[i]) 135 | if mode == 2: 136 | _train = elig_fj_tasks[:int(len(elig_fj_tasks) / 4 * 3)] + elig_fl_tasks 137 | _test = elig_fj_tasks[int(len(elig_fj_tasks) / 4 * 3):] + test1_fj_tasks 138 | return _train, _test 139 | elif mode == 3: 140 | _train = elig_fj_tasks + elig_fl_tasks[:int(len(elig_fj_tasks) / 2)] 141 | _test = elig_fl_tasks[int(len(elig_fl_tasks) / 2):] + test1_fl_tasks 142 | return _train, _test 143 | # _test.extend(resid_tasks) 144 | 145 | 146 | def meta_train_test1(HK_tasks): 147 | test_hk_tasks, one_test_tasks, remain_tasks, elig_tasks = [], [], [], [] 148 | for i in range(len(HK_tasks)): 149 | if len(HK_tasks[i]) > FLAGS.num_samples_each_task: 150 | elig_tasks.append(HK_tasks[i]) 151 | else: 152 | remain_tasks.append(HK_tasks[i]) 153 | np.random.shuffle(elig_tasks) 154 | _train = elig_tasks[:int(len(elig_tasks) / 4 * 3)] 155 | _test = elig_tasks[int(len(elig_tasks) / 4 * 3):] 156 | return _train, _test 157 | 158 | 159 | def save_tasks(tasks, filename): 160 | """将tasks存到csv中""" 161 | writer = pd.ExcelWriter(filename) 162 | for i in range(len(tasks)): 163 | task_sampels = [] 164 | for j in range(len(tasks[i])): 165 | attr_lb = np.append(tasks[i][j][0], tasks[i][j][1]) 166 | task_sampels.append(attr_lb) 167 | data_df = pd.DataFrame(task_sampels) 168 | data_df.to_excel(writer, 'task_' + str(i), float_format='%.5f', header=False, index=False) 169 | writer.close() 170 | 171 | 172 | def read_tasks(file): 173 | """获取tasks""" 174 | f = pd.ExcelFile(file) 175 | tasks = [[] for i in range(len(f.sheet_names))] 176 | k = 0 # count task 177 | for sheetname in f.sheet_names: 178 | attr = pd.read_excel(file, usecols=[i for i in range(FLAGS.dim_input)], sheet_name=sheetname, 179 | header=None).values.astype(np.float32) 180 | label = pd.read_excel(file, usecols=[FLAGS.dim_input], sheet_name=sheetname, header=None).values.reshape( 181 | (-1,)).astype(np.float32) 182 | for j in range(np.shape(attr)[0]): 183 | tasks[k].append([attr[j], label[j]]) 184 | k += 1 185 | return tasks 186 | 187 | 188 | def savepts_fortask(clusters, file): 189 | writer = pd.ExcelWriter(file) 190 | count = 0 191 | for cluster in clusters: 192 | pts = [] 193 | 
for pixel in cluster.pixels: 194 | pts.append(pixel) 195 | data_df = pd.DataFrame(pts) 196 | data_df.to_excel(writer, 'task_' + str(count), float_format='%.5f', header=False, index=False) 197 | count = count + 1 198 | writer.close() 199 | 200 | 201 | def read_pts(file): 202 | """read the per-task points saved by savepts_fortask""" 203 | f = pd.ExcelFile(file) 204 | tasks = [] 205 | for sheetname in f.sheet_names: 206 | arr = pd.read_excel(file, sheet_name=sheetname).values.astype(np.float32) 207 | tasks.append(arr) 208 | return tasks 209 | 210 | 211 | def cal_measure(pred, y_test): 212 | TP = ((pred == 1) * (y_test == 1)).astype(int).sum() 213 | FP = ((pred == 1) * (y_test == 0)).astype(int).sum() 214 | FN = ((pred == 0) * (y_test == 1)).astype(int).sum() 215 | TN = ((pred == 0) * (y_test == 0)).astype(int).sum() 216 | # statistical measures 217 | Precision = TP / (TP + FP) 218 | Recall = TP / (TP + FN) 219 | F_measures = 2 * Precision * Recall / (Precision + Recall) 220 | print('Precision: %f' % Precision, '\nRecall: %f' % Recall, '\nF_measures: %f' % F_measures) 221 | --------------------------------------------------------------------------------
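utils.cal_measure() takes two 1-D arrays of 0/1 labels, prediction first and ground truth second, and prints Precision, Recall and F-measure; the scores are undefined (division by zero) whenever the prediction or the test set contains no positives. Below is a short usage sketch with made-up labels, assuming the repository's dependencies (TensorFlow, tf_slim) are installed so that utils imports cleanly.

import numpy as np

from utils import cal_measure

pred = np.array([1, 0, 1, 1, 0, 0, 1, 0])      # predicted labels (1 = landslide, 0 = non-landslide)
y_test = np.array([1, 0, 0, 1, 0, 1, 1, 0])    # ground-truth labels
cal_measure(pred, y_test)                      # prints Precision 0.75, Recall 0.75, F_measures 0.75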