├── .gitignore
├── LICENSE
├── LSM.py
├── README.md
├── SHAP.py
├── comparison.py
├── figs
│   └── Overview.jpg
├── figure.py
├── meta_LSM.py
├── metatask_sampling
│   └── .gitkeep
├── modeling.py
├── models_of_blocks
│   ├── .gitkeep
│   └── HK
│       └── .gitkeep
├── requirements.txt
├── scene_segmentation.py
├── src_data
│   ├── HK
│   │   ├── composite.tfw
│   │   ├── composite.tif
│   │   ├── composite.tif.aux.xml
│   │   └── composite.tif.ovr
│   ├── Ts_HK.csv
│   ├── grid_samples_HK.csv
│   ├── samples_HK.csv
│   └── samples_HK_noTS.csv
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # root
2 | /*.xlsx
3 | /*.xls
4 | #/*.txt
5 |
6 | # pycache
7 | .idea/
8 | __pycache__/
9 |
10 | # figs
11 | figs/*.pdf
12 |
13 | # models_of_blocks
14 | models_of_blocks/**/*.npz
15 |
16 | # metatask_sampling
17 | metatask_sampling/*.tif
18 | metatask_sampling/*.xlsx
19 |
20 | # files
21 | unsupervised_pretraining/
22 | tmp/
23 | checkpoint_dir/
24 |
25 | *~
26 |
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2021 Li CHEN
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LSM.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import pandas as pd
3 | import numpy as np
4 | from osgeo import gdal
5 |
6 | from meta_LSM import FLAGS
7 | from modeling import MAML
8 | from utils import batch_generator, read_pts, read_tasks
9 |
10 |
11 | def readfxy_csv(file):
12 | tmp = np.loadtxt(file, dtype=str, delimiter=",", encoding='UTF-8')
13 | features = tmp[1:, :-2].astype(np.float32)
14 | features = features / features.max(axis=0)
15 | xy = tmp[1:, -2:].astype(np.float32)
16 | return features, xy
17 |
18 |
19 | def getclusters(gridpts_xy, taskpts, tifformat_path):
20 | dataset = gdal.Open(tifformat_path)
21 | if not dataset:
22 |         print("cannot open the *.tif file!")
23 | im_geotrans = dataset.GetGeoTransform()
24 | gridcluster = [[] for i in range(len(taskpts))]
25 | for i in range(np.shape(gridpts_xy)[0]):
26 | height = int((gridpts_xy[i][1] - im_geotrans[3]) / im_geotrans[5])
27 | width = int((gridpts_xy[i][0] - im_geotrans[0]) / im_geotrans[1])
28 | for j in range(len(taskpts)):
29 | if [height, width] in taskpts[j].tolist():
30 | gridcluster[j].append(i)
31 | break
32 | return gridcluster
33 |
34 |
35 | def predict_LSM(tasks_samples, features, xy, indexes, savename, num_updates=5):
36 | """restore model from checkpoint"""
37 | tf.compat.v1.disable_eager_execution()
38 | model = MAML(FLAGS.dim_input, FLAGS.dim_output, test_num_updates=5)
39 | input_tensors_input = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_input)
40 | input_tensors_label = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_output)
41 | model.construct_model(input_tensors_input=input_tensors_input, input_tensors_label=input_tensors_label,
42 | prefix='metatrain_')
43 | exp_string = '.mbs' + str(FLAGS.meta_batch_size) + '.ubs_' + \
44 | str(FLAGS.num_samples_each_task) + '.numstep' + str(FLAGS.num_updates) + \
45 | '.updatelr' + str(FLAGS.update_lr) + '.meta_lr' + str(FLAGS.meta_lr)
46 | saver = tf.compat.v1.train.Saver(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES))
47 | sess = tf.compat.v1.InteractiveSession()
48 |     init = tf.compat.v1.global_variables()  # the optimizer creates extra variables that also need initialization
49 | sess.run(tf.compat.v1.variables_initializer(var_list=init))
50 | model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
51 | if model_file:
52 | print("Restoring model weights from " + model_file)
53 |         saver.restore(sess, model_file)  # initialize the graph in the session from model_file
54 | else:
55 | print("no intermediate model found!")
56 |
57 |     savearr = np.arange(4, dtype=np.float32).reshape((1, 4))  # placeholder first row; prediction rows [x, y, prob_0, prob_1] are stacked below
58 |
59 | for i in range(len(tasks_samples)):
60 | np.random.shuffle(tasks_samples[i])
61 |         with tf.compat.v1.variable_scope('model', reuse=True):  # reuse the variables defined in the 'model' scope
62 | if len(tasks_samples[i]) > FLAGS.num_samples_each_task:
63 | train_ = tasks_samples[i][:int(len(tasks_samples[i]) / 2)]
64 | batch_size = FLAGS.test_update_batch_size
65 | else:
66 | train_ = tasks_samples[i]
67 | batch_size = int(len(train_) / 2)
68 |             fast_weights = model.weights  # start the inner-loop adaptation from the meta-learned initialization
69 | for j in range(num_updates):
70 | inputa, labela = batch_generator(train_, FLAGS.dim_input, FLAGS.dim_output,
71 | batch_size)
72 | loss = model.loss_func(model.forward(inputa, fast_weights, reuse=True), labela)
73 | grads = tf.gradients(ys=loss, xs=list(fast_weights.values()))
74 | gradients = dict(zip(fast_weights.keys(), grads))
75 | fast_weights = dict(zip(fast_weights.keys(),
76 | [fast_weights[key] - model.update_lr * gradients[key] for key in
77 | fast_weights.keys()]))
78 |
79 | """predict LSM"""
80 | if len(indexes[i]):
81 | features_arr = np.array([features[index] for index in indexes[i]])
82 | xy_arr = np.array([xy[index] for index in indexes[i]])
83 | pred = model.forward(features_arr, fast_weights, reuse=True)
84 | pred = sess.run(tf.nn.softmax(pred))
85 | tmp = np.hstack(
86 | (xy_arr[:, 0].reshape(xy_arr.shape[0], 1), xy_arr[:, 1].reshape(xy_arr.shape[0], 1), pred))
87 | savearr = np.vstack((savearr, tmp))
88 | """save model parameters to npz file"""
89 | adapted_weights = sess.run(fast_weights)
90 | np.savez('models_of_blocks/HK/model' + str(i), adapted_weights['w1'], adapted_weights['b1'],
91 | adapted_weights['w2'], adapted_weights['b2'],
92 | adapted_weights['w3'], adapted_weights['b3'],
93 | adapted_weights['w4'], adapted_weights['b4'])
94 |
95 | writer = pd.ExcelWriter('tmp/' + savename)
96 | data_df = pd.DataFrame(savearr)
97 | data_df.to_excel(writer)
98 | writer.close()
99 |
100 | print('save LSM successfully')
101 | sess.close()
102 |
103 |
104 | if __name__ == "__main__":
105 | print('grid points assignment...')
106 | HK_tasks = read_tasks('./metatask_sampling/HK_tasks_K{k}.xlsx'.format(k=FLAGS.K))
107 | HK_taskpts = read_pts('./metatask_sampling/HKpts_tasks_K{k}.xlsx'.format(k=FLAGS.K))
108 | HK_gridpts_feature, HK_gridpts_xy = readfxy_csv('./src_data/grid_samples_HK.csv')
109 | HK_gridcluster = getclusters(HK_gridpts_xy, HK_taskpts, './metatask_sampling/' + FLAGS.str_region + \
110 | '_SLIC_M{m}_K{k}_loop{loop}.tif'.format(loop=0, m=FLAGS.M, k=FLAGS.K))
111 |
112 | print('adapt and predict...')
113 | predict_LSM(HK_tasks, HK_gridpts_feature, HK_gridpts_xy, HK_gridcluster, 'proposed_prediction.xlsx')
114 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Hong Kong Landslide Susceptibility Mapping in a Meta-learning Way (tf2).
6 |
7 | [//]: # (# Landslide Susceptibility Assessment in Multiple Landslide-inducing Environments with a Landslide Inventory Augmented by InSAR Techniques)
8 |
9 | ## Table of Contents
10 |
11 | - [Background](#background)
12 | - [Data](#data)
13 | - [Dependencies](#dependencies)
14 | - [Usage](#usage)
15 | - [Contact](#contact)
16 | - [Citation](#citation)
17 |
18 |
19 | ## Background
20 | Landslide susceptibility assessment (LSA) is vital for landslide hazard mitigation and prevention.
21 | Recently, there have been vast applications of data-driven LSA methods owing to the increased
22 | availability of high-quality satellite data and landslide statistics. However, two issues remain to
23 | be addressed, as follows: (a) landslide records obtained from a landslide inventory (LI) are mainly
24 | based on the interpretation of optical images and site investigation, resulting in current data-driven
25 | models being insensitive to slope dynamics, such as slow-moving landslides; (b) Most
26 | study areas contain a variety of landslide-inducing environments (LIEs) that a single model
27 | cannot accommodate well. In this study, we proposed the utilization of InSAR techniques
28 | to sample weak landslide labels from slow-moving slopes for LI augmentation; and meta-learn
29 | intermediate representations for the fast adaptation of LSA models corresponding to different
30 | LIEs. We performed feature permutation to identify dominant landslide-inducing factors (LIFs)
31 | and fostered guidance for targeted landslide prevention schemes. The results obtained in Hong
32 | Kong revealed that deformation in several mountainous regions is closely associated with the
33 | majority of recorded landslides. By augmenting the LI using InSAR techniques, the proposed
34 | method improved the perception of potential dynamic landslides and achieved better statistical
35 | performance. The discussion highlights that slope and stream power index (SPI) are the key
36 | LIFs in Hong Kong, but the dominant LIFs will vary under different LIEs. By comparison, the
37 | proposed method entails a fast-learning strategy and consistently outperforms other data-driven
38 | LSA techniques, e.g., by 3-6% in accuracy, 2-6% in precision, 1-2% in recall, 3-5% in F1-score,
39 | and approximately 10% in Cohen Kappa.
40 |
41 |
42 |
43 | Fig. 1: Overview
44 |
45 |
46 | ## Data
47 |
48 | * The landslide inventory can be found [here](https://data.gov.hk/en-data/dataset/hk-cedd-csu-cedd-entli).
49 | * The related thematic information can be found [here](https://geodata.gov.hk/gs).
50 | * The nonlandslide/landslide sample vectors are stored in `./src_data/`, where `samples_HK.csv` and `samples_HK_noTS.csv` are the datasets with and without augmented slow-moving landslides, respectively.
51 |
52 | [//]: # ()
53 | [//]: # (The source and experiment data will be opened...)
54 |
55 |
56 | ## Dependencies
57 |
58 | The default branch uses a tf2 environment:
59 | * cudatoolkit 11.2.2
60 | * cudnn 8.1.0.77
61 | * python 3.9.13
62 | * tensorflow 2.10.0
63 |
64 | Install the required packages:
65 | ```
66 | python -m pip install -r requirements.txt
67 | ```
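
The `cudatoolkit`/`cudnn` versions listed above are not installed by pip. One optional way to provide them, assuming a conda environment is used, is:

```
conda install -c conda-forge cudatoolkit=11.2.2 cudnn=8.1.0.77
```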
68 |
69 | ## Usage
70 |
71 | * For the scene segmentation and task sampling stage, see `./scene_segmentation.py`; the results are written to the `./metatask_sampling` folder (a minimal run order is sketched below).
72 | * For the meta learner, see `./meta_LSM.py`.
73 | * For the model adaptation and landslide susceptibility prediction, see `./LSM.py`. The intermediate model and the adapted per-block models are saved in the `./checkpoint_dir` and `./models_of_blocks` folders, respectively. The adapted models predict the susceptibility for each sample vector in `./src_data/grid_samples_HK.csv`.
74 | * The `./tmp` folder stores some temporary records.
75 | * For the figures in the experiments, see `./figure.py`; the figures are saved in the `./figs` folder.
76 |
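A minimal end-to-end run with the default flags might look like the following (script names are those in this repository; flags and paths are defined in `meta_LSM.py` and can be adjusted as needed):

```
python scene_segmentation.py   # scene segmentation (SLIC) and meta-task sampling
python meta_LSM.py             # meta-training over the sampled tasks
python LSM.py                  # per-block adaptation and susceptibility prediction
python figure.py               # optional: reproduce the figures
```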
77 |
78 | ## Contact
79 |
80 | To ask questions or report issues, please open an issue on the [issue tracker](https://github.com/CLi-de/Meta_LSM/issues).
81 |
82 | ## Citation
83 |
84 | If this repository helps your research, please cite the paper. Here is the BibTeX entry:
85 |
86 | ```
87 | @article{CHEN2023107342,
88 | title = {Landslide susceptibility assessment in multiple urban slope settings with a landslide inventory augmented by InSAR techniques},
89 | journal = {Engineering Geology},
90 | volume = {327},
91 | pages = {107342},
92 | year = {2023},
93 | issn = {0013-7952},
94 | doi = {https://doi.org/10.1016/j.enggeo.2023.107342},
95 | author = {Li Chen and Peifeng Ma and Chang Yu and Yi Zheng and Qing Zhu and Yulin Ding},
96 | }
97 | ```
98 |
99 | The full text can be found [here](https://www.sciencedirect.com/science/article/pii/S0013795223003605)
--------------------------------------------------------------------------------
/SHAP.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Author : CHEN Li
4 | # @Time : 2022/12/2 14:55
5 | # @File : SHAP.py
6 | # @annotation
7 |
8 | import tensorflow as tf
9 | import xgboost
10 | import shap
11 | import warnings
12 | import matplotlib.pyplot as plt
13 | from sklearn import svm
14 | import numpy as np
15 | import pandas as pd
16 | from meta_LSM import FLAGS
17 | from modeling import MAML
18 |
19 | from sklearn.model_selection import train_test_split
20 |
21 | warnings.filterwarnings("ignore")
22 |
23 |
24 | # construct model
25 | def init_weights(file):
26 | """读取DAS权参"""
27 | with tf.compat.v1.variable_scope('model'): # get variable in 'model' scope, to reuse variables
28 | npzfile = np.load(file)
29 | weights = {}
30 | weights['w1'] = npzfile['arr_0']
31 | weights['b1'] = npzfile['arr_1']
32 | weights['w2'] = npzfile['arr_2']
33 | weights['b2'] = npzfile['arr_3']
34 | weights['w3'] = npzfile['arr_4']
35 | weights['b3'] = npzfile['arr_5']
36 | weights['w4'] = npzfile['arr_6']
37 | weights['b4'] = npzfile['arr_7']
38 | return weights
39 |
40 |
41 | # define model.pred_prob() for shap.KernelExplainer(model, data)
42 | def pred_prob(X_):
43 | with tf.compat.v1.variable_scope('model', reuse=True):
44 | return sess.run(tf.nn.softmax(model.forward(X_, model.weights, reuse=True)))
45 |
46 |
47 | # read subtasks
48 | def read_tasks(file):
49 | """获取tasks"""
50 | f = pd.ExcelFile(file)
51 | tasks = [[] for i in range(len(f.sheet_names))]
52 | k = 0
53 | for sheetname in f.sheet_names:
54 | # attr = pd.read_excel(file, usecols=[i for i in range(FLAGS.dim_input)], sheet_name=sheetname,
55 | # header=None).values.astype(np.float32)
56 | arr = pd.read_excel(file, sheet_name=sheetname,
57 | header=None).values.astype(np.float32)
58 | tasks[k] = arr
59 | k = k + 1
60 | return tasks
61 |
62 |
63 | """construct model"""
64 | tf.compat.v1.disable_eager_execution()
65 | model = MAML(FLAGS.dim_input, FLAGS.dim_output, test_num_updates=5)
66 | input_tensors_input = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_input)
67 | input_tensors_label = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_output)
68 | model.construct_model(input_tensors_input=input_tensors_input, input_tensors_label=input_tensors_label,
69 | prefix='metatrain_')
70 |
71 | tmp = np.loadtxt('./src_data/samples_HK.csv', dtype=str, delimiter=",", encoding='UTF-8')
72 | feature_names = tmp[0, :-3].astype(str)
73 | task = read_tasks('./metatask_sampling/HK_tasks_K{k}.xlsx'.format(k=FLAGS.K))
74 |
75 | sess = tf.compat.v1.InteractiveSession()
76 | init = tf.compat.v1.global_variables()  # the optimizer creates extra variables that also need initialization
77 | sess.run(tf.compat.v1.variables_initializer(var_list=init))
78 |
79 | # eligible i: [11, 31, 81, ],['planting area', 'catchment', 'mountainous areas with severe deformation', '']
80 | # SHAP for ith subtasks
81 | for i in range(11, len(task), 10):
82 | model.weights = init_weights('./models_of_blocks/HK/model' + str(i) + '.npz')
83 |
84 | tmp_ = task[i]
85 | np.random.shuffle(tmp_) # shuffle
86 |     # # training set
87 |     # x_train = tmp_[:int(tmp_.shape[0] / 2), :-1]  # feature columns
88 |     # y_train = tmp_[:int(tmp_.shape[0] / 2), -1]  # class labels
89 |     # # test set
90 |     # # x_test = tmp_[int(tmp_.shape[0] / 2):, :-1]  # feature columns
91 |     # # y_test = tmp_[int(tmp_.shape[0] / 2):, -1]  # class labels
92 |     # X, Y
93 |     X = tmp_[:, :-1]  # feature columns
94 |     Y = tmp_[:, -1]  # class labels
95 |
96 | shap.initjs()
97 | # SHAP demo are using dataframe instead of nparray
98 | X_ = pd.DataFrame(X) # convert np.array to pd.dataframe
99 | # x_test = pd.DataFrame(x_test)
100 |     X_.columns = feature_names  # assign feature names
101 | # x_test.columns = feature_names
102 |
103 | # explainer = shap.KernelExplainer(pred_prob, shap.kmeans(x_train, 80))
104 | explainer = shap.KernelExplainer(pred_prob, shap.sample(X_, 100))
105 | shap_values = explainer.shap_values(X_, nsamples=100) # shap_values
106 | # (_prob, n_samples, features)
107 | # TODO: refer https://shap-lrjball.readthedocs.io/en/latest/generated/shap.summary_plot.html to change plot style
108 | # shap.force_plot(explainer.expected_value[1], shap_values[1][0, :], x_test.iloc[0, :], show=True, matplotlib=True) # single feature
109 | shap.summary_plot(shap_values, X_, plot_type="bar", show=False)
110 | plt.savefig('tmp/bar_' + str(i) + '.pdf')
111 | plt.close()
112 | shap.summary_plot(shap_values[1], X_, plot_type="violin", show=False)
113 | plt.savefig('tmp/violin_' + str(i) + '.pdf')
114 | plt.close()
115 | # shap.summary_plot(shap_values[1], x_test, plot_type="compact_dot")
116 |
117 | # shap.force_plot(explainer.expected_value[1], shap_values[1], x_test, link="logit")
118 |
119 | # shap.dependence_plot('DV', shap_values[1], x_test, interaction_index=None)
120 | # shap.dependence_plot('SPI', shap_values[1], x_test, interaction_index='DV')
121 | # shap.plots.beeswarm(shap_values[0]) # the beeswarm plot requires Explanation object as the `shap_values` argument
122 |
--------------------------------------------------------------------------------
/comparison.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from sklearn import metrics, svm
7 |
8 | from sklearn.neural_network import MLPClassifier
9 | from sklearn.ensemble import RandomForestClassifier
10 | from unsupervised_pretraining.dbn_.models import SupervisedDBNClassification
11 |
12 | from sklearn.metrics import accuracy_score
13 | from sklearn.metrics import cohen_kappa_score
14 |
15 | from utils import cal_measure
16 | import shap
17 | import matplotlib.pyplot as plt
18 |
19 | import warnings
20 |
21 | warnings.filterwarnings("ignore")
22 |
23 | """model-agnostic SHAP"""
24 |
25 |
26 | def SHAP_(predict_proba, x_train, x_test, f_name):
27 | shap.initjs()
28 | # SHAP demo are using dataframe instead of nparray
29 |     x_train = pd.DataFrame(x_train)  # convert the numpy arrays to DataFrames
30 |     x_test = pd.DataFrame(x_test)
31 |     x_train.columns = f_name  # assign feature names
32 | x_test.columns = f_name
33 |
34 | explainer = shap.KernelExplainer(predict_proba, shap.kmeans(x_train, 100))
35 | x_ = x_test
36 | shap_values = explainer.shap_values(x_, nsamples=100) # shap_values(_prob, n_samples, features)
37 | # shap.force_plot(explainer.expected_value[1], shap_values[1][0, :], x_test.iloc[0, :], show=True, matplotlib=True) # single feature
38 | shap.summary_plot(shap_values, x_, plot_type="bar", show=False)
39 | plt.savefig('tmp/bar_HK.pdf')
40 | plt.close()
41 |     shap.summary_plot(shap_values[1], x_, plot_type="violin", show=False)  # shap_values[k]: k indexes the class, k=1 (landslides)
42 | plt.savefig('tmp/violin_HK.pdf')
43 | plt.close()
44 | # shap.summary_plot(shap_values[1], x_test, plot_type="compact_dot")
45 |
46 | # shap.plots.beeswarm(shap_values[0]) # the beeswarm plot requires Explanation object as the `shap_values` argument
47 |
48 |
49 | def pred_LSM(trained_model, xy, samples, name):
50 | """LSM prediction"""
51 | pred = trained_model.predict_proba(samples)
52 | data = np.hstack((xy, pred))
53 | data_df = pd.DataFrame(data)
54 | writer = pd.ExcelWriter('./tmp/'+name+'_prediction_HK.xlsx')
55 | data_df.to_excel(writer, 'page_1', float_format='%.5f')
56 | writer.close()
57 |
58 |
59 | def SVM_(x_train, y_train, x_test, y_test):
60 | """predict and test"""
61 | print('start SVM evaluation...')
62 | model = svm.SVC(C=1, kernel='rbf', gamma=1 / (2 * x_train.var()), decision_function_shape='ovr', probability=True)
63 | # clf = svm.SVC(C=0.1, kernel='linear', decision_function_shape='ovr')
64 | model.fit(x_train, y_train)
65 | pred_train = model.predict(x_train)
66 | print('train accuracy:' + str(metrics.accuracy_score(y_train, pred_train)))
67 | pred_test = model.predict(x_test)
68 | print('test accuracy:' + str(metrics.accuracy_score(y_test, pred_test)))
69 | # Precision, Recall, F1-score
70 | cal_measure(pred_test, y_test)
71 | kappa_value = cohen_kappa_score(pred_test, y_test)
72 | print('Cohen_Kappa: %f' % kappa_value)
73 |
74 | # feature permutation
75 | print('SHAP...')
76 | SHAP_(model.predict_proba, x_train, x_test, f_names)
77 |
78 | return model
79 |
80 |
81 | # can be deprecated
82 | def ANN_(x_train, y_train, x_test, y_test):
83 | """predict and test"""
84 | print('start ANN evaluation...')
85 | model = MLPClassifier(hidden_layer_sizes=(32, 32, 16), activation='relu', solver='adam', alpha=0.01,
86 | batch_size=32, max_iter=1000)
87 | model.fit(x_train, y_train)
88 | pred_train = model.predict(x_train)
89 | print('Train Accuracy: %f' % accuracy_score(y_train, pred_train))
90 | pred_test = model.predict(x_test)
91 | print('Test Accuracy: %f' % accuracy_score(y_test, pred_test))
92 | # Precision, Recall, F1-score
93 | cal_measure(pred_test, y_test)
94 | kappa_value = cohen_kappa_score(pred_test, y_test)
95 | print('Cohen_Kappa: %f' % kappa_value)
96 |
97 | # SHAP
98 | print('SHAP...')
99 | # SHAP_(model.predict_proba, x_train, x_test, f_names)
100 |
101 | return model
102 |
103 |
104 | def DBN_(x_train, y_train, x_test, y_test):
105 | print('start DBN evaluation...')
106 | # Training
107 | model = SupervisedDBNClassification(hidden_layers_structure=[32, 32],
108 | learning_rate_rbm=0.001,
109 | learning_rate=0.5,
110 | n_epochs_rbm=10,
111 | n_iter_backprop=200,
112 | batch_size=64,
113 | activation_function='relu',
114 | dropout_p=0.1)
115 | model.fit(x_train, y_train)
116 |
117 | pred_train = np.array(model.predict(x_train))
118 | pred_test = np.array(model.predict(x_test))
119 |     # training accuracy
120 |     print('train_Accuracy: %f' % accuracy_score(y_train, pred_train))
121 |     # test accuracy
122 |     print('test_Accuracy: %f' % accuracy_score(y_test, pred_test))
123 |     # pred1 = clf2.predict_proba()  # predicted class probabilities
124 | cal_measure(pred_test, y_test)
125 | kappa_value = cohen_kappa_score(pred_test, y_test)
126 | print('Cohen_Kappa: %f' % kappa_value)
127 |
128 | # SHAP
129 | print('SHAP...')
130 | # SHAP_(model.predict_proba, x_train, x_test, f_names)
131 | return model
132 |
133 |
134 | def RF_(x_train, y_train, x_test, y_test):
135 | """predict and test"""
136 | print('start RF evaluation...')
137 | model = RandomForestClassifier(n_estimators=200, max_depth=None)
138 |
139 | model.fit(x_train, y_train)
140 | pred_train = model.predict(x_train)
141 | pred_test = model.predict(x_test)
142 |     # training accuracy
143 |     print('train_Accuracy: %f' % accuracy_score(y_train, pred_train))
144 |     # test accuracy
145 |     print('test_Accuracy: %f' % accuracy_score(y_test, pred_test))
146 |     # pred1 = clf2.predict_proba()  # predicted class probabilities
147 | cal_measure(pred_test, y_test)
148 | kappa_value = cohen_kappa_score(pred_test, y_test)
149 | print('Cohen_Kappa: %f' % kappa_value)
150 |
151 | # SHAP
152 | print('SHAP...')
153 | # TODO: SHAP for RF
154 | # SHAP_(model.predict_proba, x_train, x_test, f_names)
155 | shap.initjs()
156 | explainer = shap.Explainer(model)
157 | shap_values = explainer(x_train)
158 | shap.plots.bar(shap_values[:100, :, 0]) # shap_values(n_samples, features, _prob)
159 | return model
160 |
161 |
162 | if __name__ == "__main__":
163 | """Input data"""
164 | tmp = np.loadtxt('./src_data/samples_HK.csv', dtype=str, delimiter=",", encoding='UTF-8')
165 |     f_names = tmp[0, :-3].astype(str)
166 | tmp_ = np.hstack((tmp[1:, :-3], tmp[1:, -1].reshape(-1, 1))).astype(np.float32)
167 | np.random.shuffle(tmp_) # shuffle
168 |     # training set
169 |     x_train = tmp_[:int(tmp_.shape[0] / 4 * 3), :-1]  # feature columns
170 |     y_train = tmp_[:int(tmp_.shape[0] / 4 * 3), -1]  # class labels
171 |     x_train = x_train / x_train.max(axis=0)
172 |     # test set
173 |     x_test = tmp_[int(tmp_.shape[0] / 4 * 3):, :-1]  # feature columns
174 |     y_test = tmp_[int(tmp_.shape[0] / 4 * 3):, -1]  # class labels
175 | x_test = x_test / x_test.max(axis=0)
176 | # grid samples
177 | grid_f = np.loadtxt('./src_data/grid_samples_HK.csv', dtype=str, delimiter=",", encoding='UTF-8')
178 | samples_f = grid_f[1:, :-2].astype(np.float32)
179 | xy = grid_f[1:, -2:].astype(np.float32)
180 | samples_f = samples_f / samples_f.max(axis=0)
181 |
182 | """evaluate and save LSM result"""
183 | # SVM-based
184 | # model_svm = SVM_(x_train, y_train, x_test, y_test)
185 | # pred_LSM(model_svm, xy, samples_f, 'SVM')
186 | # print('done SVM-based LSM prediction! \n')
187 |
188 | # # MLP_based
189 | # model_mlp = ANN_(x_train, y_train, x_test, y_test)
190 | # pred_LSM(model_mlp, xy, samples_f, 'MLP')
191 | # print('done MLP-based LSM prediction! \n')
192 |
193 | # # DBN-based
194 | # model_dbn = DBN_(x_train, y_train, x_test, y_test)
195 | # pred_LSM(model_dbn, xy, samples_f, 'DBN')
196 | # print('done DBN-based LSM prediction! \n')
197 |
198 | # RF-based
199 | model_rf = RF_(x_train, y_train, x_test, y_test)
200 | pred_LSM(model_rf, xy, samples_f, 'RF')
201 | print('done RF-based LSM prediction! \n')
202 |
203 |
204 |
--------------------------------------------------------------------------------
/figs/Overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLi-de/Meta_LSM/c25e2904761e3f4b4d5a797f0b4db7ddfe53236e/figs/Overview.jpg
--------------------------------------------------------------------------------
/figure.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import matplotlib.pyplot as plt
4 | from sklearn import manifold
5 | from sklearn.decomposition import PCA
6 | import umap
7 |
8 | from mpl_toolkits.mplot3d import Axes3D # 3D plot
9 | import pandas as pd
10 | import tensorflow as tf
11 |
12 | from sklearn.metrics import roc_curve, auc
13 | from sklearn.model_selection import train_test_split
14 | from sklearn.preprocessing import label_binarize
15 |
16 | from scipy import interp
17 | from sklearn.metrics import roc_auc_score
18 | from sklearn import svm
19 | from sklearn.neural_network import MLPClassifier
20 | from sklearn.ensemble import RandomForestClassifier
21 | # from unsupervised_pretraining.dbn_.models import SupervisedDBNClassification
22 | from scipy.interpolate import make_interp_spline
23 | from sklearn.metrics._classification import accuracy_score
24 |
25 | """for visiualization"""
26 |
27 |
28 | def read_tasks(file, dim_input=16):
29 | """read csv and obtain tasks"""
30 | f = pd.ExcelFile(file)
31 | tasks = []
32 | for sheetname in f.sheet_names:
33 | attr = pd.read_excel(file, usecols=dim_input - 1, sheet_name=sheetname).values.astype(np.float32)
34 | label = pd.read_excel(file, usecols=[dim_input], sheet_name=sheetname).values.reshape((-1, 1)).astype(
35 | np.float32)
36 | tasks.append([attr, label])
37 | return tasks
38 |
39 |
40 | def read_csv(path):
41 |     tmp = np.loadtxt(path, dtype=str, delimiter=",", encoding='UTF-8')
42 | tmp_feature = tmp[1:, :]
43 | np.random.shuffle(tmp_feature) # shuffle
44 | label_attr = tmp_feature[:, -1].astype(np.float32) #
45 | data_atrr = tmp_feature[:, :-1].astype(np.float32) #
46 | return data_atrr, label_attr
47 |
48 |
49 | def load_weights(npzfile):
50 | npzfile = np.load(npzfile)
51 | weights = {}
52 | weights['w0'] = npzfile['arr_0']
53 | weights['b0'] = npzfile['arr_1']
54 | weights['w1'] = npzfile['arr_2']
55 | weights['b1'] = npzfile['arr_3']
56 | weights['w2'] = npzfile['arr_4']
57 | weights['b2'] = npzfile['arr_5']
58 | return weights
59 |
60 |
61 | def transform_relu(inputX, weights, bias, activations=tf.nn.relu):
62 | return activations(tf.transpose(a=tf.matmul(weights, tf.transpose(a=inputX))) + bias)
63 |
64 |
65 | def forward(inp, weights, sess):
66 | for i in range(int(len(weights) / 2)): # 3 layers
67 | inp = transform_relu(inp, tf.transpose(a=weights['w' + str(i)]), weights['b' + str(i)])
68 | return sess.run(inp)
69 |
70 |
71 | def _PCA(X, y, figsavename):
72 | pca = PCA(n_components=3)
73 | X_pca = pca.fit_transform(X)
74 |
75 | x_min, x_max = X_pca.min(0), X_pca.max(0)
76 | X_norm = (X_pca - x_min) / (x_max - x_min)
77 |
78 | fig = plt.figure()
79 | ax = Axes3D(fig)
80 | # ax.scatter(x1,x2,x3,c=pre
81 |
82 | landslide_pts_x = []
83 | landslide_pts_y = []
84 | landslide_pts_z = []
85 | nonlandslide_pts_x = []
86 | nonlandslide_pts_y = []
87 | nonlandslide_pts_z = []
88 |
89 | for i in range(len(y)):
90 | if y[i] == 0:
91 | nonlandslide_pts_x.append(X_norm[i][0])
92 | nonlandslide_pts_y.append(X_norm[i][1])
93 | nonlandslide_pts_z.append(X_norm[i][2])
94 | if y[i] == 1:
95 | landslide_pts_x.append(X_norm[i][0])
96 | landslide_pts_y.append(X_norm[i][1])
97 | landslide_pts_z.append(X_norm[i][2])
98 |
99 | type_landslide = ax.scatter(landslide_pts_x, landslide_pts_y, landslide_pts_z, c='red')
100 | type_nonlandslide = ax.scatter(nonlandslide_pts_x, nonlandslide_pts_y, nonlandslide_pts_z, c='blue')
101 |
102 | ax.legend((type_landslide, type_nonlandslide), ('landslide points', 'nonlandslide points'), loc=2)
103 | # plt.legend( bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
104 |     # set axis labels
105 |     ax.set_xlabel('x-axis')
106 |     ax.set_ylabel('y-axis')
107 |     ax.set_zlabel('z-axis')
108 |
109 |     # set the title
110 |     plt.title("Visualization with PCA")
111 |     plt.savefig(figsavename)
112 |     # show the figure
113 | plt.show()
114 |
115 |
116 | def ISOMAP(X, y, figsavename):
117 | isomap = manifold.Isomap(n_components=3)
118 | X_isomap = isomap.fit_transform(X)
119 |
120 | x_min, x_max = X_isomap.min(0), X_isomap.max(0)
121 | X_norm = (X_isomap - x_min) / (x_max - x_min)
122 |
123 | fig = plt.figure()
124 | ax = Axes3D(fig)
125 | # ax.scatter(x1,x2,x3,c=pre
126 |
127 | landslide_pts_x = []
128 | landslide_pts_y = []
129 | landslide_pts_z = []
130 | nonlandslide_pts_x = []
131 | nonlandslide_pts_y = []
132 | nonlandslide_pts_z = []
133 |
134 | for i in range(len(y)):
135 | if y[i] == 0:
136 | nonlandslide_pts_x.append(X_norm[i][0])
137 | nonlandslide_pts_y.append(X_norm[i][1])
138 | nonlandslide_pts_z.append(X_norm[i][2])
139 | if y[i] == 1:
140 | landslide_pts_x.append(X_norm[i][0])
141 | landslide_pts_y.append(X_norm[i][1])
142 | landslide_pts_z.append(X_norm[i][2])
143 |
144 | type_landslide = ax.scatter(landslide_pts_x, landslide_pts_y, landslide_pts_z, c='red')
145 | type_nonlandslide = ax.scatter(nonlandslide_pts_x, nonlandslide_pts_y, nonlandslide_pts_z, c='blue')
146 |
147 | ax.legend((type_landslide, type_nonlandslide), ('landslide points', 'nonlandslide points'), loc=2)
148 | # plt.legend( bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
149 |     # set axis labels
150 |     ax.set_xlabel('x-axis')
151 |     ax.set_ylabel('y-axis')
152 |     ax.set_zlabel('z-axis')
153 |     # set the title
154 |     plt.title("Visualization with Isomap")
155 |     plt.savefig(figsavename)
156 |     # show the figure
157 | plt.show()
158 |
159 |
160 | def t_SNE(X, y, figsavename):
161 | tsne = manifold.TSNE(n_components=3, init='random', random_state=501)
162 | X_tsne = tsne.fit_transform(X)
163 |
164 | """嵌入空间可视化"""
165 | x_min, x_max = X_tsne.min(0), X_tsne.max(0)
166 | X_norm = (X_tsne - x_min) / (x_max - x_min)
167 |
168 | fig = plt.figure()
169 | ax = Axes3D(fig)
170 | # ax.scatter(x1,x2,x3,c=pre
171 |
172 | landslide_pts_x = []
173 | landslide_pts_y = []
174 | landslide_pts_z = []
175 | nonlandslide_pts_x = []
176 | nonlandslide_pts_y = []
177 | nonlandslide_pts_z = []
178 |
179 | for i in range(len(y)):
180 | if y[i] == 0:
181 | nonlandslide_pts_x.append(X_norm[i][0])
182 | nonlandslide_pts_y.append(X_norm[i][1])
183 | nonlandslide_pts_z.append(X_norm[i][2])
184 | if y[i] == 1:
185 | landslide_pts_x.append(X_norm[i][0])
186 | landslide_pts_y.append(X_norm[i][1])
187 | landslide_pts_z.append(X_norm[i][2])
188 |
189 | type_landslide = ax.scatter(landslide_pts_x, landslide_pts_y, landslide_pts_z, c='red')
190 | type_nonlandslide = ax.scatter(nonlandslide_pts_x, nonlandslide_pts_y, nonlandslide_pts_z, c='blue')
191 |
192 | ax.legend((type_landslide, type_nonlandslide), ('landslide points', 'nonlandslide points'), loc=2)
193 | # plt.legend( bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
194 |     # set axis labels
195 |     ax.set_xlabel('x-axis')
196 |     ax.set_ylabel('y-axis')
197 |     ax.set_zlabel('z-axis')
198 |     # set the title
199 |     plt.title("Visualization with t-SNE")
200 |     plt.savefig(figsavename)
201 |     # show the figure
202 | plt.show()
203 |
204 |
205 | def UMAP(X, y, figsavename):
206 | reducer = umap.UMAP(n_components=3)
207 | X_umap = reducer.fit_transform(X)
208 |
209 | x_min, x_max = X_umap.min(0), X_umap.max(0)
210 | X_norm = (X_umap - x_min) / (x_max - x_min)
211 |
212 | fig = plt.figure()
213 | ax = Axes3D(fig)
214 | # ax.scatter(x1,x2,x3,c=pre
215 |
216 | landslide_pts_x = []
217 | landslide_pts_y = []
218 | landslide_pts_z = []
219 | nonlandslide_pts_x = []
220 | nonlandslide_pts_y = []
221 | nonlandslide_pts_z = []
222 |
223 | for i in range(len(y)):
224 | if y[i] == 0:
225 | nonlandslide_pts_x.append(X_norm[i][0])
226 | nonlandslide_pts_y.append(X_norm[i][1])
227 | nonlandslide_pts_z.append(X_norm[i][2])
228 | if y[i] == 1:
229 | landslide_pts_x.append(X_norm[i][0])
230 | landslide_pts_y.append(X_norm[i][1])
231 | landslide_pts_z.append(X_norm[i][2])
232 |
233 | type_landslide = ax.scatter(landslide_pts_x, landslide_pts_y, landslide_pts_z, c='red')
234 | type_nonlandslide = ax.scatter(nonlandslide_pts_x, nonlandslide_pts_y, nonlandslide_pts_z, c='blue')
235 |
236 | ax.legend((type_landslide, type_nonlandslide), ('landslide points', 'nonlandslide points'), loc=2)
237 | # plt.legend( bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
238 |     # set axis labels
239 |     ax.set_xlabel('x-axis')
240 |     ax.set_ylabel('y-axis')
241 |     ax.set_zlabel('z-axis')
242 |     # set the title
243 |     plt.title("Visualization with UMAP")
244 |     plt.savefig(figsavename)
245 |     # show the figure
246 | plt.show()
247 |
248 |
249 | def visualization():
250 |     fj_tasks = read_tasks('./metatask_sampling/FJ_tasks.xlsx')  # samples within each task
251 | # fl_tasks = read_tasks('./metatask_sampling/FL_tasks.xlsx') # num_samples of FL is too scarce to visualize
252 |
253 | """select part of FJ and FL data for visualization"""
254 | # def get_oriandinf_Xs(tasks, regionname):
255 | # ori_Xs, inf_Xs, Ys = [], [], []
256 | # for i in range(len(tasks)):
257 | # if tasks[i][0].shape[0] > 30:
258 | # """download model parameters"""
259 | # w = load_weights('models_of_blocks/'+ regionname +'/model'+str(i) + '.npz')
260 | # ori_Xs.append(tasks[i][0])
261 | # inf_Xs.append(forward(tasks[i][0], w))
262 | # Ys.append(tasks[i][1])
263 | # return ori_Xs, inf_Xs, Ys
264 | """overall FJ and FL data for visualization"""
265 |
266 | def get_oriandinf_Xs(tasks, regionname):
267 | with tf.compat.v1.Session() as sess: # for tf calculation
268 | w = load_weights('models_of_blocks/' + 'overall_' + regionname + '/model_MAML' + '.npz')
269 | ori_Xs = tasks[0][0]
270 | inf_Xs = forward(tasks[0][0], w, sess)
271 | Ys = tasks[0][1]
272 | for i in range(len(tasks) - 1):
273 | if len(tasks[i + 1][0]) > 0:
274 | ori_Xs = np.vstack((ori_Xs, tasks[i + 1][0]))
275 | inf_Xs = np.vstack((inf_Xs, forward(tasks[i + 1][0], w, sess)))
276 | Ys = np.vstack((Ys, tasks[i + 1][1]))
277 | return ori_Xs, inf_Xs, Ys
278 |
279 | ori_FJ_Xs, inf_FJ_Xs, FJ_Ys = get_oriandinf_Xs(fj_tasks, 'FJ')
280 |
281 | # ori_FL_Xs, inf_FL_Xs, FL_Ys = get_oriandinf_Xs(fl_tasks, 'FL')
282 |
283 | # ori_X, y = read_csv('src_data/FJ_FL.csv')
284 | #
285 | # tmp = np.loadtxt('src_data/FJ_FL.csv', dtype=np.str, delimiter=",",encoding='UTF-8')
286 | # w = load_weights('unsupervised_pretraining/model_init/savedmodel.npz')
287 | # unsupervised_X = forward(ori_X, w)
288 | def plot_points(ori_X, inf_X, Y, regionname):
289 | _PCA(ori_X, Y, './figs/' + regionname + '_ori_PCA.pdf')
290 | _PCA(inf_X, Y, './figs/' + regionname + '_inf_PCA.pdf')
291 |
292 | t_SNE(ori_X, Y, './figs/' + regionname + '_ori_t_SNE.pdf')
293 | t_SNE(inf_X, Y, './figs/' + regionname + '_inf_SNE.pdf')
294 |
295 | ISOMAP(ori_X, Y, './figs/' + regionname + '_ori_Isomap.pdf')
296 | ISOMAP(inf_X, Y, './figs/' + regionname + '_inf_Isomap.pdf')
297 |
298 | UMAP(ori_X, Y, './figs/' + regionname + '_ori_UMAP.pdf')
299 | UMAP(inf_X, Y, './figs/' + regionname + '_inf_UMAP.pdf')
300 |
301 | plot_points(ori_FJ_Xs, inf_FJ_Xs, FJ_Ys, 'FJ')
302 | # plot_points(ori_FL_Xs, inf_FL_Xs, FL_Ys, 'FL')
303 |
304 |
305 | """for figure plotting"""
306 |
307 |
308 | def read_statistic(file):
309 | """读取csv获取statistic"""
310 | f = pd.ExcelFile(file)
311 | K, meanOA, maxOA, minOA, std = [], [], [], [], []
312 | for sheetname in f.sheet_names:
313 | tmp_K, tmp_meanOA, tmp_maxOA, tmp_minOA, tmp_std = np.transpose(
314 | pd.read_excel(file, sheet_name=sheetname).values)
315 | K.append(tmp_K)
316 | meanOA.append(tmp_meanOA)
317 | maxOA.append(tmp_maxOA)
318 | minOA.append(tmp_minOA)
319 | std.append(tmp_std)
320 | return K, meanOA, maxOA, minOA, std
321 |
322 |
323 | def read_statistic1(file):
324 | """读取csv获取statistic"""
325 | f = pd.ExcelFile(file)
326 | K, meanOA = [], []
327 | for sheetname in f.sheet_names:
328 | tmp_K, tmp_meanOA = np.transpose(pd.read_excel(file, sheet_name=sheetname).values)
329 | K.append(tmp_K)
330 | meanOA.append(tmp_meanOA)
331 | return K, meanOA
332 |
333 |
334 | def read_statistic2(file):
335 | """读取csv获取statistic"""
336 | f = pd.ExcelFile(file)
337 | measures = []
338 | for sheetname in f.sheet_names:
339 | temp = pd.read_excel(file, sheet_name=sheetname).values
340 | measures.append(temp[:, 1:].tolist())
341 | return measures
342 |
343 |
344 | def plot_candle(scenes, K, meanOA, maxOA, minOA, std):
345 |     # set up the figure
346 |     plt.figure("", facecolor="lightgray")
347 |     # plt.style.use('ggplot')
348 |     # set the legend font family and size
349 | font1 = {'family': 'Times New Roman',
350 | 'weight': 'normal',
351 | 'size': 20,
352 | }
353 | font2 = {'family': 'Times New Roman',
354 | 'weight': 'normal',
355 | 'size': 18,
356 | }
357 |
358 | # legend = plt.legend(handles=[A,B],prop=font1)
359 | # plt.title(scenes, fontdict=font2)
360 | # plt.xlabel("Various methods", fontdict=font1)
361 | plt.ylabel("OA(%)", fontdict=font2)
362 |
363 | my_x_ticks = [1, 2, 3, 4, 5]
364 | # my_x_ticklabels = ['SVM', 'MLP', 'DBN', 'RF', 'Proposed']
365 | plt.xticks(ticks=my_x_ticks, labels='', fontsize=16)
366 |
367 | plt.ylim((60, 100))
368 | my_y_ticks = np.arange(60, 100, 5)
369 | plt.yticks(ticks=my_y_ticks, fontsize=16)
370 |
371 | colors = ['dodgerblue', 'lawngreen', 'gold', 'magenta', 'red']
372 | edge_colors = np.zeros(5, dtype="U1")
373 | edge_colors[:] = 'black'
374 |
375 |     '''grid settings'''
376 | plt.grid(linestyle="--", zorder=-1)
377 |
378 | # draw line
379 | # plt.plot(K[0:-1], meanOA[0:-1], color="b", linestyle='solid',
380 | # linewidth=1, label="open", zorder=1)
381 | # plt.plot(K[-2:], meanOA[-2:], color="b", linestyle="--",
382 | # linewidth=1, label="open", zorder=1)
383 |
384 | # draw bar
385 | barwidth = 0.4
386 | plt.bar(K, 2 * std, barwidth, bottom=meanOA - std, color=colors,
387 | edgecolor=edge_colors, linewidth=1, zorder=20, label=['SVM', 'MLP', 'DBN', 'RF', 'Proposed'])
388 |
389 | # draw vertical line
390 | plt.vlines(K, minOA, maxOA, color='black', linestyle='solid', zorder=10)
391 | plt.hlines(meanOA, K - barwidth / 2, K + barwidth / 2, color='black', linestyle='solid', zorder=30)
392 | plt.hlines(minOA, K - barwidth / 4, K + barwidth / 4, color='black', linestyle='solid', zorder=10)
393 | plt.hlines(maxOA, K - barwidth / 4, K + barwidth / 4, color='black', linestyle='solid', zorder=10)
394 |
395 |     # set the legend
396 | legend = plt.legend(loc="lower center", prop=font1, ncol=3, columnspacing=0.1)
397 |
398 |
399 | def plot_scatter(arr):
400 |     '''set up the figure'''
401 |     # plt.figure("", facecolor="lightgray")  # set the figure size
402 | font1 = {'family': 'Times New Roman',
403 | 'weight': 'normal',
404 | 'size': 16,
405 | }
406 | font2 = {'family': 'Times New Roman',
407 | 'weight': 'normal',
408 | 'size': 12,
409 | }
410 | plt.xlabel("Subtasks", fontdict=font1)
411 | plt.ylabel("Mean accuracy(%)", fontdict=font1)
412 |
413 |     '''set the ticks'''
414 | plt.ylim((50, 100))
415 | my_y_ticks = np.arange(50, 100, 5)
416 | plt.yticks(my_y_ticks)
417 | my_x_ticks = [i for i in range(1, 204, 40)]
418 | my_x_ticklabel = [str(i) + 'th' for i in range(1, 204, 40)]
419 | plt.xticks(ticks=my_x_ticks, labels=my_x_ticklabel)
420 |     '''grid settings'''
421 | plt.grid(linestyle="--")
422 |
423 | x_ = [i for i in range(arr.shape[0])]
424 | '''draw scatter'''
425 | L1 = plt.scatter(x_, arr[:, 0], label="L=1", c="none", s=20, edgecolors='magenta')
426 | L2 = plt.scatter(x_, arr[:, 1], label="L=2", c="none", s=20, edgecolors='cyan')
427 | L3 = plt.scatter(x_, arr[:, 2], label="L=3", c="none", s=20, edgecolors='b')
428 | L4 = plt.scatter(x_, arr[:, 3], label="L=4", c="none", s=20, edgecolors='g')
429 | L5 = plt.scatter(x_, arr[:, 4], label="L=5", c="none", s=20, edgecolors='r')
430 |
431 |     '''set the legend'''
432 | legend = plt.legend(loc="lower left", prop=font2, ncol=3)
433 | # plt.savefig("C:\\Users\\hj\\Desktop\\brokenline_A")
434 | # plt.show()
435 |
436 |
437 | def plot_lines(arr):
438 |     '''set up the figure'''
439 |     # plt.figure("", facecolor="lightgray")  # set the figure size
440 | font1 = {'family': 'Times New Roman',
441 | 'weight': 'normal',
442 | 'size': 16,
443 | }
444 | font2 = {'family': 'Times New Roman',
445 | 'weight': 'normal',
446 | 'size': 12,
447 | }
448 | plt.xlabel("Subtasks", fontdict=font1)
449 | plt.ylabel("Mean accuracy(%)", fontdict=font1)
450 |
451 |     '''set the ticks'''
452 | plt.ylim((50, 100))
453 | my_y_ticks = np.arange(50, 100, 5)
454 | plt.yticks(my_y_ticks)
455 | my_x_ticks = [i for i in range(6)]
456 | my_x_ticklabel = [str(i + 1) + '/12 M' for i in range(6)]
457 | plt.xticks(ticks=my_x_ticks, labels=my_x_ticklabel)
458 |     '''grid settings'''
459 | plt.grid(linestyle="--")
460 |
461 | x_ = np.array([i for i in range(6)])
462 | # smooth
463 | # x_ = np.linspace(x_.min(), x_.max(), 400)
464 | # arr = make_interp_spline(x_, arr)(x_)
465 | '''draw line'''
466 | L1 = plt.plot(x_, arr[:, 0], color="r", linestyle="solid",
467 | linewidth=1, label="L=1", markerfacecolor='white', ms=10)
468 | L2 = plt.plot(x_, arr[:, 1], color="orange", linestyle="solid",
469 | linewidth=1, label="L=2", markerfacecolor='white', ms=10)
470 | L3 = plt.plot(x_, arr[:, 2], color="gold", linestyle="solid",
471 | linewidth=1, label="L=3", markerfacecolor='white', ms=10)
472 | L4 = plt.plot(x_, arr[:, 3], color="g", linestyle="solid",
473 | linewidth=1, label="L=4", markerfacecolor='white', ms=10)
474 | L5 = plt.plot(x_, arr[:, 4], color="b", linestyle="solid",
475 | linewidth=1, label="L=5", markerfacecolor='white', ms=10)
476 |
477 |
478 | def plot_histogram(region, measures):
479 |     '''set up the figure'''
480 |     plt.figure("", facecolor="lightgray")  # set the figure size
481 | font1 = {'family': 'Times New Roman',
482 | 'weight': 'normal',
483 | 'size': 14,
484 | }
485 | font2 = {'family': 'Times New Roman',
486 | 'weight': 'normal',
487 | 'size': 18,
488 | }
489 | # plt.xlabel("Statistical measures", fontdict=font1)
490 | plt.ylabel("Performance(%)", fontdict=font1)
491 | plt.title(region, fontdict=font2)
492 |
493 |     '''set the ticks'''
494 | plt.ylim((60, 90))
495 | my_y_ticks = np.arange(60, 90, 3)
496 | plt.yticks(my_y_ticks)
497 |
498 | my_x_ticklabels = ['Accuracy', 'Precision', 'Recall', 'F1-score']
499 | bar_width = 0.3
500 | interval = 0.2
501 | my_x_ticks = np.arange(bar_width / 2 + 2.5 * bar_width, 4 * 5 * bar_width + 1, bar_width * 6)
502 | plt.xticks(ticks=my_x_ticks, labels=my_x_ticklabels, fontproperties='Times New Roman', size=14)
503 |
504 |     '''grid settings'''
505 | plt.grid(linestyle="--")
506 |
507 | '''draw bar'''
508 | rects1 = plt.bar([x - 2 * bar_width for x in my_x_ticks], height=measures[0], width=bar_width, alpha=0.8,
509 | color='dodgerblue', label="MLP")
510 | rects2 = plt.bar([x - 1 * bar_width for x in my_x_ticks], height=measures[1], width=bar_width, alpha=0.8,
511 | color='yellowgreen', label="RF")
512 | rects3 = plt.bar([x for x in my_x_ticks], height=measures[2], width=bar_width, alpha=0.8, color='gold', label="RL")
513 | rects4 = plt.bar([x + 1 * bar_width for x in my_x_ticks], height=measures[3], width=bar_width, alpha=0.8,
514 | color='peru', label="MAML")
515 | rects5 = plt.bar([x + 2 * bar_width for x in my_x_ticks], height=measures[4], width=bar_width, alpha=0.8,
516 | color='crimson', label="proposed")
517 |
518 |     '''set the legend'''
519 | legend = plt.legend(loc="upper left", prop=font1, ncol=3)
520 |
521 | '''add text'''
522 | # for rect in rects1:
523 | # height = rect.get_height()
524 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom")
525 | # for rect in rects2:
526 | # height = rect.get_height()
527 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom")
528 | # for rect in rects3:
529 | # height = rect.get_height()
530 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom")
531 | # for rect in rects4:
532 | # height = rect.get_height()
533 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom")
534 | # for rect in rects5:
535 | # height = rect.get_height()
536 | # plt.text(rect.get_x() + rect.get_width() / 2, height+1, str(height)+'%', ha="center", va="bottom")
537 |
538 | plt.savefig("C:\\Users\\hj\\Desktop\\histogram" + region + '.pdf')
539 | plt.show()
540 |
541 |
542 | """for AUROC plotting"""
543 |
544 |
545 | def load_data(filepath, dim_input):
546 |     # np.loadtxt(filepath)  # not needed: the file is an Excel sheet, read with pandas below
547 | data = pd.read_excel(filepath).values.astype(np.float32)
548 | attr = data[:, :dim_input]
549 | attr = attr / attr.max(axis=0)
550 | label = data[:, -1].astype(np.int32)
551 | return attr, label
552 |
553 |
554 | def SVM_fit_pred(x_train, x_test, y_train, y_test):
555 | classifier = svm.SVC(C=1, kernel='rbf', gamma=1 / (2 * x_train.var()), decision_function_shape='ovr',
556 | probability=True)
557 | classifier.fit(x_train, y_train)
558 | return classifier.predict_proba(x_test)
559 |
560 |
561 | def MLP_fit_pred(x_train, x_test, y_train, y_test):
562 | classifier = MLPClassifier(hidden_layer_sizes=(32, 32, 16), activation='relu', solver='adam', alpha=0.01,
563 | batch_size=32, max_iter=1000)
564 | classifier.fit(x_train, y_train)
565 | return classifier.predict_proba(x_test)
566 |
567 |
568 | # def DBN_fit_pred(x_train, x_test, y_train, y_test):
569 | # classifier = SupervisedDBNClassification(hidden_layers_structure=[32, 32],
570 | # learning_rate_rbm=0.001,
571 | # learning_rate=0.5,
572 | # n_epochs_rbm=10,
573 | # n_iter_backprop=200,
574 | # batch_size=64,
575 | # activation_function='relu',
576 | # dropout_p=0.1)
577 | # classifier.fit(x_train, y_train)
578 | # pred_prob = classifier.predict_proba(x_test)
579 | #
580 | # # if pred_prob[0][0] > 0.5:
581 | # # pred_prob = np.vstack((pred_prob[:, 0], pred_prob[:, -1])).T # swap 0, 1 prediction
582 | #
583 | # return pred_prob
584 |
585 |
586 | def RF_fit_pred(x_train, x_test, y_train, y_test):
587 | classifier = RandomForestClassifier(n_estimators=200, max_depth=None)
588 | classifier.fit(x_train, y_train)
589 | return classifier.predict_proba(x_test)
590 |
591 |
592 | def plot_auroc(n_times, y_score_SVM, y_score_MLP, y_score_DBN, y_score_RF, y_score_proposed, y_test, y_test_proposed):
593 | # Compute ROC curve and ROC area for each class
594 | def cal_(y_score, y_test):
595 | fpr, tpr = [], []
596 | for i in range(n_times):
597 | fpr_, tpr_, thresholds = roc_curve(y_test[i], y_score[i][:, -1], pos_label=1)
598 | fpr.append(fpr_)
599 | tpr.append(tpr_)
600 |
601 | # First aggregate all false positive rates
602 | all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_times)]))
603 |
604 | # Then interpolate all ROC curves at this points
605 | mean_tpr = np.zeros_like(all_fpr)
606 | for i in range(n_times):
607 | mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
608 |
609 | # Finally average it and compute AUC
610 | mean_tpr /= n_times
611 | mean_auc = auc(all_fpr, mean_tpr)
612 | return all_fpr, mean_tpr, mean_auc, fpr, tpr
613 |
614 | def plot_(y_score, y_test, color, method):
615 | all_fpr, mean_tpr, mean_auc, fpr, tpr = cal_(y_score, y_test)
616 | # draw mean
617 | plt.plot(all_fpr, mean_tpr,
618 | label=method + '_mean_AUC (area = {0:0.3f})'''.format(mean_auc),
619 | color=color, linewidth=1.5)
620 | # draw each
621 | for i in range(n_times):
622 | plt.plot(fpr[i], tpr[i],
623 | color=color, linewidth=1, alpha=.25)
624 | # plt.savefig(method + '.pdf')
625 |
626 | # Plot all ROC curves
627 | # ax = plt.axes()
628 | # ax.set_facecolor("WhiteSmoke") # background color
629 | plot_(y_score_SVM, y_test, color='dodgerblue', method='SVM')
630 | plot_(y_score_MLP, y_test, color='lawngreen', method='MLP')
631 | plot_(y_score_DBN, y_test, color='gold', method='DBN')
632 | plot_(y_score_RF, y_test, color='magenta', method='RF')
633 | plot_(y_score_proposed, y_test_proposed, color='red', method='Proposed')
634 |
635 | # format
636 | font1 = {'family': 'Times New Roman',
637 | 'weight': 'normal',
638 | 'size': 16,
639 | }
640 | font2 = {'family': 'Times New Roman',
641 | 'weight': 'normal',
642 | 'size': 12,
643 | }
644 | plt.plot([0, 1], [0, 1], 'k--', lw=1, label='random')
645 | plt.xlim([0.0, 1.0])
646 | plt.ylim([0.0, 1.0])
647 | plt.xlabel('False Positive Rate', fontdict=font1)
648 | plt.ylabel('True Positive Rate', fontdict=font1)
649 | plt.title('ROC curve by various methods', fontdict=font1)
650 | plt.legend(loc="lower right", prop=font2)
651 |
652 |
653 | """space visualization"""
654 | # visualization()
655 |
656 | """draw histogram"""
657 |
658 | # regions = ['FJ', 'FL']
659 | # measures = read_statistic2("C:\\Users\\hj\\Desktop\\performance.xlsx")
660 | # for i in range(len(regions)):
661 | # plot_histogram(regions[i], measures[i])
662 |
663 |
664 | """draw candle"""
665 |
666 |
667 | scenes = ['airport', 'urban1', 'urban2', 'plain', 'catchment', 'reservoir']
668 | K, meanOA, maxOA, minOA, std = read_statistic("C:\\Users\\lichen\\OneDrive\\桌面\\statistics_candle.xlsx")
669 | for i in range(len(scenes)):
670 | plot_candle(scenes[i], K[i], meanOA[i], maxOA[i], minOA[i], std[i])
671 | plt.savefig("C:\\Users\\lichen\\OneDrive\\桌面\\" + scenes[i] + '_' + 'candle.pdf')
672 | plt.show()
673 |
674 |
675 | def read_f_l_csv(file):
676 | tmp = np.loadtxt(file, dtype=str, delimiter=",", encoding='UTF-8')
677 | features = tmp[1:, :-2].astype(np.float32)
678 | features = features / features.max(axis=0)
679 | label = tmp[1:, -1].astype(np.float32)
680 | return features, label
681 |
682 |
683 | # """draw AUR"""
684 | # print('drawing ROC...')
685 | # x, y = read_f_l_csv('src_data/samples_HK.csv')
686 | # y_score_SVM, y_score_MLP, y_score_DBN, y_score_RF, y_score_proposed, y_test_, y_test_proposed = [], [], [], [], [], [], []
687 | # n_times = 5
688 | # for i in range(n_times):
689 | # x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.75, test_size=.02, shuffle=True)
690 | # """fit and predict"""
691 | # # for other methods
692 | # y_score_SVM.append(SVM_fit_pred(x_train, x_test, y_train, y_test))
693 | # y_score_MLP.append(MLP_fit_pred(x_train, x_test, y_train, y_test))
694 | # y_score_DBN.append(MLP_fit_pred(x_train, x_test, y_train, y_test))
695 | # y_score_RF.append(RF_fit_pred(x_train, x_test, y_train, y_test))
696 | # y_test_.append(y_test)
697 | # # for proposed-
698 | # tmp = pd.read_excel('tmp/' + 'proposed_test' + str(i) + '.xlsx').values.astype(np.float32)
699 | # y_score_proposed.append(tmp[:, 1:3])
700 | # y_test_proposed.append(tmp[:, -1])
701 | # # draw roc
702 | # plt.clf()
703 | # plot_auroc(n_times, y_score_SVM, y_score_MLP, y_score_DBN, y_score_RF, y_score_proposed, y_test_, y_test_proposed)
704 | # plt.savefig('ROC.pdf')
705 | # plt.show()
706 | # print('finish')
707 |
708 | """draw scatters for fast adaption performance"""
709 | # filename = "C:\\Users\\lichen\\OneDrive\\桌面\\fast_adaption_sheet2.csv"
710 | # arr = np.loadtxt(filename, dtype=float, delimiter=",", encoding='utf-8-sig')
711 | # plot_scatter(arr)
712 | # plt.savefig("C:\\Users\\lichen\\OneDrive\\桌面\\scatters.pdf")
713 | # plt.show()
714 |
715 | """draw lines for fast adaption performance"""
716 | # filename = "C:\\Users\\lichen\\OneDrive\\桌面\\fast_adaption1.csv"
717 | # arr = np.loadtxt(filename, dtype=float, delimiter=",", encoding='utf-8-sig')
718 | # plot_lines(arr)
719 | # plt.savefig("C:\\Users\\lichen\\OneDrive\\桌面\\broken.pdf")
720 | # plt.show()
721 |
722 | """
723 | label: for legend
724 | pos_: -2, -1, 0, 1, 2
725 | """
726 |
727 |
728 | def plot_candle1(K, meanOA, maxOA, minOA, std, color_, label_, pos_):
729 |     # set up the figure
730 |     # plt.figure("", facecolor="lightgray")
731 |     # plt.style.use('ggplot')
732 |     # set the legend font family and size
733 | font1 = {'family': 'Times New Roman',
734 | 'weight': 'normal',
735 | 'size': 14,
736 | }
737 | font2 = {'family': 'Times New Roman',
738 | 'weight': 'normal',
739 | 'size': 16,
740 | }
741 |
742 | # legend = plt.legend(handles=[A,B],prop=font1)
743 | # plt.title(scenes, fontdict=font2)
744 | plt.xlabel("Number of samples", fontdict=font1)
745 | plt.ylabel("OA(%)", fontdict=font2)
746 |
747 | my_x_ticks = [1, 2, 3, 4, 5]
748 | my_x_ticklabels = ['1', '2', '3', '4', '5']
749 | plt.xticks(ticks=my_x_ticks, labels=my_x_ticklabels, fontsize=14, fontdict=font2)
750 |
751 | plt.ylim((50, 100))
752 | my_y_ticks = np.arange(50, 100, 5)
753 | plt.yticks(ticks=my_y_ticks, fontsize=14, font=font2)
754 |
755 |     '''grid settings'''
756 | plt.grid(linestyle="--", zorder=-1)
757 |
758 | colors = ['dodgerblue', 'lawngreen', 'gold', 'magenta', 'red']
759 | edge_colors = np.zeros(5, dtype="U1")
760 | edge_colors[:] = 'black'
761 |
762 | # draw bar
763 | barwidth = 0.15
764 | K = K + barwidth * pos_
765 | plt.bar(K, 2 * std, barwidth, bottom=meanOA - std, color=color_,
766 | edgecolor=edge_colors, linewidth=1, zorder=20, label=label_, alpha=0.5)
767 | # draw vertical line
768 | plt.vlines(K, minOA, meanOA - std, color='black', linestyle='solid', zorder=10)
769 | plt.vlines(K, maxOA, meanOA + std, color='black', linestyle='solid', zorder=10)
770 | plt.hlines(meanOA, K - barwidth / 2, K + barwidth / 2, color='blue', linestyle='solid', zorder=30)
771 | plt.hlines(minOA, K - barwidth / 4, K + barwidth / 4, color='black', linestyle='solid', zorder=10)
772 | plt.hlines(maxOA, K - barwidth / 4, K + barwidth / 4, color='black', linestyle='solid', zorder=10)
773 |     # set the legend
774 | legend = plt.legend(loc="lower right", prop=font1, ncol=3, fontsize=24)
775 |
776 |
777 | """draw candles for fast adaption performance"""
778 | # K, meanOA, maxOA, minOA, std = read_statistic("C:\\Users\\lichen\\OneDrive\\桌面\\fast_adaption_candle.xlsx")
779 | # colors = ['magenta', 'cyan', 'b', 'g', 'r']
780 | # labels = ['L=1', 'L=2', 'L=3', 'L=4', 'L=5']
781 | # pos = [-2, -1, 0, 1, 2]
782 | # for i in range(5):
783 | # plot_candle1(K[i], meanOA[i], maxOA[i], minOA[i], std[i], colors[i], labels[i], pos[i])
784 | # # plt.show()
785 | # plt.savefig("C:\\Users\\lichen\\OneDrive\\桌面\\candle.pdf")
786 |
--------------------------------------------------------------------------------
/meta_LSM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import pandas as pd
4 | from modeling import MAML
5 | from scene_segmentation import SLICProcessor, TaskSampling
6 | from tensorflow.python.platform import flags
7 | from utils import tasksbatch_generator, batch_generator, meta_train_test1, save_tasks, \
8 | read_tasks, savepts_fortask, cal_measure
9 | from unsupervised_pretraining.DAS_pretraining_v2 import Unsupervise_pretrain
10 | from sklearn.metrics import accuracy_score
11 | from sklearn.metrics import cohen_kappa_score
12 |
13 | from sklearn.neural_network import MLPClassifier
14 |
15 | from comparison import SHAP_
16 | import warnings
17 | import os
18 |
19 | warnings.filterwarnings("ignore")
20 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
21 |
22 | FLAGS = flags.FLAGS
23 |
24 | """for task sampling"""
25 | flags.DEFINE_float('M', 500, 'determines how much distance influences the segmentation')
26 | flags.DEFINE_integer('K', 512, 'number of superpixels')
27 | flags.DEFINE_integer('loop', 5, 'number of SLIC iterations')
28 | flags.DEFINE_string('str_region', 'HK', 'the study area')
29 | flags.DEFINE_string('sample_pts', './src_data/samples_HK.csv', 'path to (non)/landslide samples')
30 | flags.DEFINE_string('Ts_pts', './src_data/Ts_HK.csv', 'path to Ts samples')
31 |
32 | """for meta-train"""
33 | flags.DEFINE_string('basemodel', 'DAS', 'MLP: no unsupervised pretraining; DAS: pretraining with DAS')
34 | flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None')
35 | flags.DEFINE_string('log', './tmp/data', 'directory for log data; also acts as a switch that enables summary writing')
36 | flags.DEFINE_string('logdir', './checkpoint_dir', 'directory for summaries and checkpoints.')
37 |
38 | flags.DEFINE_integer('dim_input', 14, 'dim of input data')
39 | flags.DEFINE_integer('dim_output', 2, 'dim of output data')
40 | flags.DEFINE_integer('meta_batch_size', 16, 'number of tasks sampled per meta-update (not the total number of tasks)')
41 | flags.DEFINE_integer('num_samples_each_task', 16,
42 |                      'number of samples drawn from each task during training (inner batch size)')
43 | flags.DEFINE_integer('test_update_batch_size', 8,
44 |                      'number of examples used for the gradient update during adaptation (K-shot, K=1,3,5 in the experiments); -1: M.')
45 | flags.DEFINE_integer('metatrain_iterations', 5001, 'number of meta-training iterations.')
46 | flags.DEFINE_integer('num_updates', 5, 'number of inner gradient updates during training.')
47 | flags.DEFINE_integer('pretrain_iterations', 0, 'number of pre-training iterations.')
48 | # flags.DEFINE_integer('num_samples', 18469, 'total number of samples in HK, see samples_HK.')
49 | flags.DEFINE_float('update_lr', 1e-2, 'learning rate of the single-task (inner) objective')  # 1e-2 works best
50 | flags.DEFINE_float('meta_lr', 1e-3, 'base learning rate of the meta (outer) objective')  # 1e-2 or 1e-3
51 | flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')
52 | flags.DEFINE_bool('resume', True, 'resume training if there is a model available')
53 |
54 |
55 | def train(model, saver, sess, exp_string, tasks, resume_itr):
56 | SUMMARY_INTERVAL = 100
57 | SAVE_INTERVAL = 1000
58 | PRINT_INTERVAL = 1000
59 |
60 | print('Done model initializing, starting training...')
61 | prelosses, postlosses = [], []
62 | if resume_itr != FLAGS.pretrain_iterations + FLAGS.metatrain_iterations - 1:
63 | if FLAGS.log:
64 | train_writer = tf.compat.v1.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph)
65 | for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations):
66 | batch_x, batch_y, cnt_sample = tasksbatch_generator(tasks, FLAGS.meta_batch_size
67 | , FLAGS.num_samples_each_task,
68 | FLAGS.dim_input,
69 | FLAGS.dim_output) # task_batch[i]: (x, y, features)
70 | # batch_y = _transform_labels_to_network_format(batch_y, FLAGS.num_classes)
71 | inputa = batch_x[:, :int(FLAGS.num_samples_each_task / 2), :] # a used for training
72 | labela = batch_y[:, :int(FLAGS.num_samples_each_task / 2), :]
73 | inputb = batch_x[:, int(FLAGS.num_samples_each_task / 2):, :] # b used for testing
74 | labelb = batch_y[:, int(FLAGS.num_samples_each_task / 2):, :]
75 | # # when deal with few-shot problem
76 | # inputa = batch_x[:, :int(len(batch_x[0]) / 2), :] # a used for training
77 | # labela = batch_y[:, :int(len(batch_y[0]) / 2), :]
78 | # inputb = batch_x[:, int(len(batch_x[0]) / 2):, :] # b used for testing
79 | # labelb = batch_y[:, int(len(batch_y[0]) / 2):, :]
80 |
81 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela,
82 | model.labelb: labelb, model.cnt_sample: cnt_sample}
83 |
84 | if itr < FLAGS.pretrain_iterations:
85 | input_tensors = [model.pretrain_op] # for comparison
86 | else:
87 | input_tensors = [model.metatrain_op] # meta_train
88 |
89 | if (itr % SUMMARY_INTERVAL == 0 or itr % PRINT_INTERVAL == 0):
90 | input_tensors.extend([model.summ_op, model.total_loss1, model.total_losses2[FLAGS.num_updates - 1]])
91 |
92 | result = sess.run(input_tensors, feed_dict)
93 |
94 | if itr % SUMMARY_INTERVAL == 0:
95 | prelosses.append(result[-2])
96 | if FLAGS.log:
97 | train_writer.add_summary(result[1], itr) # add sum_op
98 | postlosses.append(result[-1])
99 |
100 | if (itr != 0) and itr % PRINT_INTERVAL == 0:
101 | if itr < FLAGS.pretrain_iterations:
102 | print_str = 'Pretrain Iteration ' + str(itr)
103 | else:
104 | print_str = 'Iteration ' + str(itr - FLAGS.pretrain_iterations)
105 | print_str += ': ' + str(np.mean(prelosses)) + ', ' + str(np.mean(postlosses))
106 | print(print_str)
107 | print('inner lr:', sess.run(model.update_lr))
108 | prelosses, postlosses = [], []
109 | # save model
110 | if (itr != 0) and itr % SAVE_INTERVAL == 0:
111 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))
112 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))
113 |
114 |
115 | def test(model, saver, sess, exp_string, tasks, num_updates=5):
116 | print('start evaluation...')
117 | print(exp_string)
118 | total_Ytest, total_Ypred, total_Ytest1, total_Ypred1, sum_accuracies, sum_accuracies1 = [], [], [], [], [], []
119 |
120 | for i in range(len(tasks)):
121 | np.random.shuffle(tasks[i])
122 | train_ = tasks[i][:int(len(tasks[i]) / 2)]
123 |         test_ = tasks[i][int(len(tasks[i]) / 2):]  # the remaining 50% of the samples are used for testing
124 |         """few-step tuning (not run through the training op because the batch_size/input shape differs and model.weights should not be updated)"""
125 |         with tf.compat.v1.variable_scope('model', reuse=True):  # reuse the Variables created inside normalize()
126 | fast_weights = model.weights
127 | for j in range(num_updates):
128 | inputa, labela = batch_generator(train_, FLAGS.dim_input, FLAGS.dim_output,
129 | FLAGS.test_update_batch_size)
130 | loss = model.loss_func(model.forward(inputa, fast_weights, reuse=True),
131 |                                        labela)  # fast_weights depends on the (stopped) grads, but that does not affect the gradient computed here
132 | grads = tf.gradients(ys=loss, xs=list(fast_weights.values()))
133 | gradients = dict(zip(fast_weights.keys(), grads))
134 | fast_weights = dict(zip(fast_weights.keys(),
135 | [fast_weights[key] - model.update_lr * gradients[key] for key in
136 | fast_weights.keys()]))
137 | """Single task test accuracy"""
138 | inputb, labelb = batch_generator(test_, FLAGS.dim_input, FLAGS.dim_output, len(test_))
139 | Y_array = sess.run(tf.nn.softmax(model.forward(inputb, fast_weights, reuse=True))) # pred_prob
140 | total_Ypred1.extend(Y_array) # pred_prob_test
141 | total_Ytest1.extend(labelb) # label
142 |
143 | Y_test = [] # for single task test
144 | for j in range(len(labelb)):
145 | Y_test.append(labelb[j][0])
146 | total_Ytest.append(labelb[j][0])
147 | Y_pred = [] # for single task test
148 | for j in range(len(labelb)):
149 | if Y_array[j][0] > Y_array[j][1]:
150 | Y_pred.append(1)
151 | total_Ypred.append(1) # total_Ypred: 1d-array label
152 | else:
153 | Y_pred.append(0)
154 | total_Ypred.append(0)
155 | accuracy = accuracy_score(Y_test, Y_pred)
156 | sum_accuracies.append(accuracy)
157 | # print('Test_Accuracy: %f' % accuracy)
158 | # print('SHAP...')
159 | # SHAP_() # TODO: SHAP for proposed
160 | """Overall evaluation (test data)"""
161 | total_Ypred = np.array(total_Ypred).reshape(len(total_Ypred), )
162 | total_Ytest = np.array(total_Ytest)
163 | total_acc = accuracy_score(total_Ytest, total_Ypred)
164 | print('Test_Accuracy: %f' % total_acc)
165 | cal_measure(total_Ypred, total_Ytest)
166 | kappa_value = cohen_kappa_score(total_Ypred, total_Ytest)
167 | print('Cohen_Kappa: %f' % kappa_value)
168 |
169 | # save prediction for test samples, which can be used in calculating statistical measure such as AUROC
170 | pred_prob = np.array(total_Ypred1)
171 | label_bi = np.array(total_Ytest1)
172 | savearr = np.hstack((pred_prob, label_bi))
173 | writer = pd.ExcelWriter('proposed_test.xlsx')
174 | data_df = pd.DataFrame(savearr)
175 | data_df.to_excel(writer)
176 | writer.close()
177 |
178 | sess.close()
179 |
180 |
181 | def main():
182 | """1.Unsupervised pretraining; 2.segmentation and meta-task sampling; 3.meta-training and -testing"""
183 |
184 | """Unsupervised pretraining"""
185 |     # TODO: is it necessary to mimic batch normalization during pretraining?
186 | if not os.path.exists('./unsupervised_pretraining/model_init/savedmodel.npz'):
187 | print("start unsupervised pretraining")
188 | tmp = np.loadtxt(FLAGS.sample_pts, dtype=str, delimiter=",", encoding='UTF-8')
189 | tmp_feature = tmp[1:, :].astype(np.float32)
190 | np.random.shuffle(tmp_feature)
191 | Unsupervise_pretrain(tmp_feature)
192 |
193 | """meta task sampling"""
194 | tasks_path = './metatask_sampling/' + FLAGS.str_region + '_tasks_K' + str(FLAGS.K) + '.xlsx'
195 | if not os.path.exists(
196 | './metatask_sampling/' + FLAGS.str_region + '_SLIC_M{m}_K{k}_loop{loop}.tif'.format(loop=0, m=FLAGS.M,
197 | k=FLAGS.K)):
198 | print('start scene segmentation using SLIC algorithm:')
199 | p = SLICProcessor('./src_data/' + FLAGS.str_region + '/composite.tif', FLAGS.K, FLAGS.M)
200 | p.iterate_times(loop=FLAGS.loop)
201 | print('start meta-task sampling:')
202 | t = TaskSampling(p.clusters)
203 | tasks = t.sampling(p.im_geotrans, FLAGS.sample_pts)
204 | save_tasks(tasks, tasks_path) # save each meta-task samples into respective sheet in a .xlsx file
205 | savepts_fortask(p.clusters, './metatask_sampling/' + FLAGS.str_region + 'pts_tasks_K' + str(FLAGS.K) + '.xlsx')
206 | print('produce meta training and testing datasets...')
207 | HK_tasks = read_tasks(tasks_path)
208 | tasks_train, tasks_test = meta_train_test1(HK_tasks)
209 |
210 | """meta-training and -testing"""
211 | print('model construction...')
212 | model = MAML(FLAGS.dim_input, FLAGS.dim_output, test_num_updates=5)
213 |
214 | input_tensors_input = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_input)
215 | input_tensors_label = (FLAGS.meta_batch_size, int(FLAGS.num_samples_each_task / 2), FLAGS.dim_output)
216 | model.construct_model(input_tensors_input=input_tensors_input, input_tensors_label=input_tensors_label,
217 | prefix='metatrain_')
218 | model.summ_op = tf.compat.v1.summary.merge_all()
219 |
220 | saver = tf.compat.v1.train.Saver(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES),
221 | max_to_keep=10)
222 |
223 | sess = tf.compat.v1.InteractiveSession()
224 |     init = tf.compat.v1.global_variables()  # the optimizer creates extra variables that also need to be initialised
225 | sess.run(tf.compat.v1.variables_initializer(var_list=init))
226 |
227 | exp_string = '.mbs' + str(FLAGS.meta_batch_size) + '.ubs_' + \
228 | str(FLAGS.num_samples_each_task) + '.numstep' + str(FLAGS.num_updates) + \
229 | '.updatelr' + str(FLAGS.update_lr) + '.meta_lr' + str(FLAGS.meta_lr)
230 |
231 | resume_itr = 0
232 |
233 |     # resume training from the latest checkpoint
234 | if FLAGS.resume:
235 | model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
236 | if model_file:
237 | ind1 = model_file.index('model')
238 | resume_itr = int(model_file[ind1 + 5:])
239 | # print("Restoring model weights from " + model_file)
240 |             saver.restore(sess, model_file)  # initialise the graph in sess from model_file
241 |
242 | train(model, saver, sess, exp_string, tasks_train, resume_itr)
243 |
244 | test(model, saver, sess, exp_string, tasks_test, num_updates=FLAGS.num_updates)
245 |
246 |
247 | # TODO: use tf.estimator
248 | if __name__ == "__main__":
249 | # device=tf.config.list_physical_devices('GPU')
250 | tf.compat.v1.disable_eager_execution()
251 | main()
252 | print('finished!')
253 |
--------------------------------------------------------------------------------
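Note on the adaptation loop in test() above: each task is split 50/50 into an adaptation half (inputa/labela) and an evaluation half (inputb/labelb), and num_updates plain gradient steps are run on the adaptation half before the other half is scored. The following is a minimal NumPy sketch of that idea with a hypothetical toy logistic model (not the repository's MAML network), included only to make the control flow concrete:

import numpy as np

def adapt_and_eval(task_x, task_y, w, lr=1e-2, num_updates=5):
    """Toy analogue of the fast-adaptation loop in test()."""
    n = len(task_x) // 2
    xa, ya = task_x[:n], task_y[:n]              # adaptation half ("inputa"/"labela")
    xb, yb = task_x[n:], task_y[n:]              # evaluation half ("inputb"/"labelb")
    for _ in range(num_updates):                 # a few gradient steps on a per-task copy of w
        p = 1.0 / (1.0 + np.exp(-xa @ w))        # sigmoid predictions
        w = w - lr * xa.T @ (p - ya) / len(xa)   # logistic-loss gradient step
    pred = (1.0 / (1.0 + np.exp(-xb @ w)) > 0.5).astype(int)
    return (pred == yb).mean()                   # accuracy on the held-out half

rng = np.random.default_rng(0)
x = rng.normal(size=(16, 14))
y = (x[:, 0] > 0).astype(int)
print(adapt_and_eval(x, y, w=np.zeros(14)))
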
/metatask_sampling/.gitkeep:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file !.gitkeep
--------------------------------------------------------------------------------
/modeling.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import print_function
3 | import numpy as np
4 | import tensorflow as tf
5 | import os
6 | from tensorflow.python.platform import flags
7 | from utils import mse, xent, normalize
8 |
9 | FLAGS = flags.FLAGS
10 |
11 |
12 | class MAML:
13 | def __init__(self, dim_input=1, dim_output=1, test_num_updates=5):
14 | """ must call construct_model() after initializing MAML! """
15 | self.dim_input = dim_input
16 | self.dim_output = dim_output
17 | self.meta_lr = tf.compat.v1.placeholder_with_default(FLAGS.meta_lr, ())
18 | self.test_num_updates = test_num_updates
19 | self.dim_hidden = [32, 32, 16]
20 | self.loss_func = xent
21 | self.forward = self.forward_fc
22 | if FLAGS.basemodel == 'MLP':
23 |             self.construct_weights = self.construct_fc_weights  # random parameter initialisation
24 | elif FLAGS.basemodel == 'DAS':
25 | if os.path.exists('unsupervised_pretraining/model_init/savedmodel.npz'):
26 |                 self.construct_weights = self.construct_DAS_weights  # initialise from the DAS pretrained weights
27 | else:
28 | raise ValueError('No pretrained model found!')
29 | else:
30 | raise ValueError('Unrecognized base model, please specify a base model such as "MLP", "DAS"...')
31 |
32 | def construct_model(self, input_tensors_input=None, input_tensors_label=None, prefix='metatrain_'):
33 | # a: training data for inner gradient, b: test data for meta gradient
34 | self.inputa = tf.compat.v1.placeholder(tf.float32,
35 |                                                 shape=input_tensors_input)  # within-task training (a) data; the shape must be specified for tf.slim, though it need not be exact
36 | self.inputb = tf.compat.v1.placeholder(tf.float32, shape=input_tensors_input)
37 |         self.labela = tf.compat.v1.placeholder(tf.float32, shape=input_tensors_label)  # labels for the within-task training (a) data
38 | self.labelb = tf.compat.v1.placeholder(tf.float32, shape=input_tensors_label)
39 | self.cnt_sample = tf.compat.v1.placeholder(tf.float32) # count number of samples for each task in the batch
40 |
41 | with tf.compat.v1.variable_scope('model', reuse=None) as training_scope:
42 | # Attention module
43 | self.A = tf.Variable(tf.zeros([self.dim_input, self.dim_input]))
44 | # initialize the inner learning rate as tf.Variable within 'model' scope
45 | self.update_lr = tf.Variable(FLAGS.update_lr)
46 | if 'weights' in dir(self):
47 | training_scope.reuse_variables()
48 | weights = self.weights
49 | else:
50 | # Define the weights
51 |                 self.weights = weights = self.construct_weights()  # initialise the FC weight parameters
52 |
53 |             num_updates = max(self.test_num_updates, FLAGS.num_updates)  # number of inner-loop training iterations within a task
54 |
55 | def task_metalearn(inp, reuse=True):
56 | """ Perform gradient descent for one task in the meta-batch. """
57 |                 inputa, inputb, labela, labelb = inp  # inputa: training input of Task(i), batch_size = m (m samples)
58 | task_outputbs, task_lossesb = [], []
59 |
60 | task_outputa = self.forward(inputa, weights, reuse=reuse) # only reuse on the first iter
61 | task_lossa = self.loss_func(task_outputa, labela)
62 |
63 |                 grads = tf.gradients(ys=task_lossa, xs=list(weights.values()))  # compute gradients
64 |                 if FLAGS.stop_grad:  # controls the second-order derivatives in MAML
65 |                     grads = [tf.stop_gradient(grad) for grad in grads]  # stop the gradients so no second-order derivatives are back-propagated
66 | gradients = dict(zip(weights.keys(), grads))
67 | fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr * gradients[key] for key in
68 |                                                          weights.keys()]))  # update the weights
69 | output = self.forward(inputb, fast_weights, reuse=True) # Task(i) test output
70 | task_outputbs.append(output)
71 | task_lossesb.append(self.loss_func(output, labelb))
72 |
73 |                 for j in range(num_updates - 1):  # num_updates: number of iterations that update the weights with batch_size training samples of Task(i)
74 | loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True),
75 |                                           labela)  # fast_weights depends on the (stopped) grads, but that does not affect the gradient computed here
76 | grads = tf.gradients(ys=loss, xs=list(fast_weights.values()))
77 | if FLAGS.stop_grad:
78 | grads = [tf.stop_gradient(grad) for grad in grads]
79 | gradients = dict(zip(fast_weights.keys(), grads))
80 | fast_weights = dict(zip(fast_weights.keys(),
81 | [fast_weights[key] - self.update_lr * gradients[key] for key in
82 | fast_weights.keys()]))
83 | output = self.forward(inputb, fast_weights, reuse=True)
84 | task_outputbs.append(output)
85 | task_lossesb.append(self.loss_func(output, labelb))
86 |
87 |                 return [task_outputa, task_outputbs, task_lossa, task_lossesb]  # task_outputa and task_lossa are computed with the pre-update weights only
88 |
89 |             if FLAGS.norm != 'None':  # do not delete: needed here because of variable reuse
90 | '''to initialize the batch norm vars, might want to combine this, and not run idx 0 twice (use reuse=tf.AUTO_REUSE instead).'''
91 | task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)
92 |
93 | out_dtype = [tf.float32, [tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates]
94 |
95 |         """Run task_metalearn in parallel over the batch dimension; relative to out_dtype the inputs carry an extra batch_size dimension."""
96 |
97 | result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb),
98 | dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size) # parallel calculation
99 |
100 |         """ outputas: (num_tasks, num_samples, value)
101 |             outputbs[i]: predictions after the i-th inner update, (num_tasks, num_samples, value)
102 |             lossesa: (num_tasks, value)
103 |             lossesb[i]: losses after the i-th inner update, (num_tasks, value)"""
104 |         outputas, outputbs, lossesa, lossesb = result  # outputas: (num_tasks, num_samples, value)
105 |
106 | ## Performance & Optimization
107 | if 'train' in prefix:
108 | self.total_loss1 = total_loss1 = tf.reduce_sum(input_tensor=lossesa) / tf.cast(FLAGS.meta_batch_size,
109 |                                                                                            dtype=tf.float32)  # mean of the pre-update task losses; the 'pretrain' baseline from Finn et al. (for comparison)
110 |
111 | self.total_losses2 = total_losses2 = [
112 | tf.reduce_sum(lossesb[j]) / tf.cast(FLAGS.meta_batch_size, dtype=tf.float32) \
113 | for j in range(num_updates)]
114 |
115 | # w = self.cnt_sample / tf.cast(FLAGS.num_samples, dtype=tf.float32)
116 | # self.total_losses2 = total_losses2 = [tf.reduce_sum(
117 | # input_tensor=tf.multiply(tf.nn.softmax(w), tf.reduce_sum(input_tensor=lossesb[j], axis=1)))
118 | # for j in range(num_updates)]
119 |
120 | # after the map_fn
121 |             self.outputas, self.outputbs = outputas, outputbs  # outputbs: e.g. 25 tasks, 5 inner updates each, value -> (25, 5, 1)
122 |             self.pretrain_op = tf.compat.v1.train.AdamOptimizer(self.meta_lr).minimize(total_loss1)  # optimises the pre-update loss (comparison baseline)
123 |
124 | optimizer = tf.compat.v1.train.AdamOptimizer(self.meta_lr)
125 |
126 | self.gvs = gvs = optimizer.compute_gradients(self.total_losses2[
127 |                                                              FLAGS.num_updates - 1])  # use loss_b from the last inner update; gvs: (gradient, variable) pairs over all trainable variables
128 |             self.metatrain_op = optimizer.apply_gradients(gvs)  # outer (meta) update
129 |         else:  # TODO: remove (marked 20/11)
130 | self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(input_tensor=lossesa) / tf.cast(
131 | FLAGS.meta_batch_size, dtype=tf.float32) # inner loss
132 | self.metaval_total_losses2 = total_losses2 = [
133 | tf.reduce_sum(input_tensor=lossesb[j]) / tf.cast(FLAGS.meta_batch_size, dtype=tf.float32) for j in
134 |                 range(num_updates)]  # outer losses (one per inner update)
135 |
136 | ## Summaries
137 | tf.compat.v1.summary.scalar(prefix + 'Pre-update loss', total_loss1) # for test accuracy
138 | for j in range(num_updates):
139 | tf.compat.v1.summary.scalar(prefix + 'Post-update loss, step ' + str(j + 1), total_losses2[j])
140 |
141 | def construct_fc_weights(self):
142 | weights = {}
143 | weights['w1'] = tf.Variable(tf.random.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01))
144 | weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]]))
145 | for i in range(1, len(self.dim_hidden)):
146 | weights['w' + str(i + 1)] = tf.Variable(
147 | tf.random.truncated_normal([self.dim_hidden[i - 1], self.dim_hidden[i]], stddev=0.01))
148 | weights['b' + str(i + 1)] = tf.Variable(tf.zeros([self.dim_hidden[i]]))
149 | weights['w' + str(len(self.dim_hidden) + 1)] = tf.Variable(
150 | tf.random.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01))
151 | weights['b' + str(len(self.dim_hidden) + 1)] = tf.Variable(tf.zeros([self.dim_output]))
152 | return weights
153 |
154 | def forward_fc(self, inp, weights, reuse=False):
155 | hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0')
156 | for i in range(1, len(self.dim_hidden)):
157 | hidden = normalize(tf.matmul(hidden, weights['w' + str(i + 1)]) + weights['b' + str(i + 1)],
158 | activation=tf.nn.relu, reuse=reuse, scope=str(i + 1))
159 | return tf.matmul(hidden, weights['w' + str(len(self.dim_hidden) + 1)]) + weights[
160 | 'b' + str(len(self.dim_hidden) + 1)]
161 |
162 | def construct_DAS_weights(self):
163 |         """Load the pretrained DAS weight parameters."""
164 | npzfile = np.load('unsupervised_pretraining/model_init/savedmodel.npz')
165 | weights = {}
166 | weights['w1'] = tf.Variable(tf.transpose(a=npzfile['arr_0']))
167 | weights['b1'] = tf.Variable(npzfile['arr_1'])
168 | weights['w2'] = tf.Variable(tf.transpose(a=npzfile['arr_2']))
169 | weights['b2'] = tf.Variable(npzfile['arr_3'])
170 | weights['w3'] = tf.Variable(tf.transpose(a=npzfile['arr_4']))
171 | weights['b3'] = tf.Variable(npzfile['arr_5'])
172 | weights['w4'] = tf.Variable(tf.random.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01))
173 | weights['b4'] = tf.Variable(tf.zeros([self.dim_output]))
174 | return weights
175 |
--------------------------------------------------------------------------------
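For readers new to MAML, the two-level update built in construct_model() can be summarised as: adapt a copy of the weights on each task's "a" data with a few inner gradient steps, then update the shared initialisation with the gradient of the "b" losses. The sketch below is a minimal NumPy illustration of the first-order variant (the stop_grad=True case) on made-up linear-regression tasks; names such as fomaml_step and make_task are hypothetical and not part of this repository:

import numpy as np

def loss_and_grad(w, x, y):
    """Squared-error loss of a linear model and its gradient."""
    err = x @ w - y
    return 0.5 * np.mean(err ** 2), x.T @ err / len(x)

def fomaml_step(w, tasks, inner_lr=1e-2, meta_lr=1e-3, num_updates=5):
    """One first-order MAML meta-update over a batch of (xa, ya, xb, yb) tasks."""
    meta_grad = np.zeros_like(w)
    for xa, ya, xb, yb in tasks:
        fast_w = w.copy()
        for _ in range(num_updates):             # inner-loop adaptation on the "a" data
            _, g = loss_and_grad(fast_w, xa, ya)
            fast_w -= inner_lr * g
        _, gq = loss_and_grad(fast_w, xb, yb)    # gradient of the "b" (query) loss
        meta_grad += gq                          # first-order approximation (second derivatives dropped)
    return w - meta_lr * meta_grad / len(tasks)  # outer (meta) update of the initialisation

rng = np.random.default_rng(1)
w_true = rng.normal(size=3)

def make_task():
    x = rng.normal(size=(8, 3))
    y = x @ (w_true + 0.1 * rng.normal(size=3))  # each task is a small perturbation of w_true
    return x[:4], y[:4], x[4:], y[4:]

w = np.zeros(3)
for _ in range(3000):
    w = fomaml_step(w, [make_task() for _ in range(4)])
print(np.round(w, 2), np.round(w_true, 2))  # the meta-learned initialisation ends up close to w_true
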
/models_of_blocks/.gitkeep:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file !.gitkeep
--------------------------------------------------------------------------------
/models_of_blocks/HK/.gitkeep:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file !.gitkeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | GDAL==3.2.3
2 | matplotlib==3.6.2
3 | numpy==1.23.3
4 | pandas==1.5.0
5 | scikit_image==0.19.3
6 | scikit_learn==1.1.2
7 | scipy==1.9.1
8 | shap==0.41.0
9 | tf_slim==1.1.0
10 | tqdm==4.64.1
11 | umap_learn==0.5.3
12 | xgboost==1.7.1
14 |
--------------------------------------------------------------------------------
/scene_segmentation.py:
--------------------------------------------------------------------------------
1 | import math
2 | from skimage import io, color
3 | import numpy as np
4 | from tqdm import trange
5 | from osgeo import gdal
6 | import pandas as pd
7 | from tensorflow.python.platform import flags
8 |
9 | FLAGS = flags.FLAGS
10 |
11 |
12 | class Cluster(object):
13 | cluster_index = 1
14 |
15 | def __init__(self, h, w, dem, aspect, curvature, slope):
16 | self.update(h, w, dem, aspect, curvature, slope)
17 | self.pixels = []
18 | self.no = self.cluster_index
19 |         Cluster.cluster_index += 1  # increment the class-level counter so each new cluster gets a unique id
20 |
21 | def update(self, h, w, dem, aspect, curvature, slope):
22 | self.h = h
23 | self.w = w
24 | self.dem = dem
25 | self.aspect = aspect
26 | self.curvature = curvature
27 | self.slope = slope
28 |
29 | def __str__(self):
30 | return "{},{}:{} {} {} ".format(self.h, self.w, self.dem, self.aspect, self.curvature,
31 | self.slope)
32 |
33 | def __repr__(self):
34 | return self.__str__()
35 |
36 |
37 | class SLICProcessor(object):
38 | @staticmethod
39 | def open_image(path):
40 | """
41 | Return:
42 | 3D array, row col [LAB]
43 | """
44 | rgb = io.imread(path)
45 | lab_arr = color.rgb2lab(rgb)
46 | return lab_arr
47 |
48 | @staticmethod
49 | def save_lab_image(path, lab_arr):
50 | """
51 | Convert the array to RBG, then save the image
52 | """
53 | rgb_arr = color.lab2rgb(lab_arr)
54 | io.imsave(path, rgb_arr)
55 |
56 | def make_cluster(self, h, w):
57 | return Cluster(h, w,
58 | self.data[0][h][w],
59 | self.data[1][h][w],
60 | self.data[2][h][w],
61 | self.data[3][h][w], )
62 |
63 |     def __init__(self, filename, K, M):  # K: number of superpixels; M: weight of the spatial distance in the overall distance measure
64 | self.file = filename
65 | self.K = K
66 | self.M = M
67 | self.data = self.readTif(filename) # shape:(6, , )
68 | self.im_geotrans = gdal.Open(filename).GetGeoTransform()
69 | self.image_height = self.data.shape[1]
70 | self.image_width = self.data.shape[2]
71 | self.N = self.image_height * self.image_width
72 | self.S = int(math.sqrt(self.N / self.K))
73 |
74 | self.clusters = []
75 | self.label = {}
76 |         self.dis = np.full((self.image_height, self.image_width), np.inf)  # np.inf: positive infinity
77 |
78 | def readTif(self, fileName):
79 | dataset = gdal.Open(fileName)
80 |         if dataset is None:
81 |             print(fileName + " could not be opened")
82 | return
83 |         im_width = dataset.RasterXSize  # number of columns of the raster
84 |         im_height = dataset.RasterYSize  # number of rows of the raster
85 |         im_bands = dataset.RasterCount  # number of bands
86 |         im_data = dataset.ReadAsArray(0, 0, im_width, im_height)  # read the data
87 |         im_geotrans = dataset.GetGeoTransform()  # affine (geotransform) parameters
88 |         im_proj = dataset.GetProjection()  # projection information
89 |         # col = int((coor[i][0] - im_geotrans[0]) / im_geotrans[1])
90 |         # row = int((coor[i][1] - im_geotrans[3]) / im_geotrans[5])
91 |         # im_nirBand = im_data[3,0:im_height,0:im_width]  # get the near-infrared band
92 | return im_data
93 |
94 | def init_clusters(self, data):
95 |         h = int(self.S / 2)  # position of the first cluster centre
96 | w = int(self.S / 2)
97 | while h < self.image_height:
98 | while w < self.image_width:
99 |                 if data[0][h][w] != -9999:  # -9999 is the NoData value
100 | self.clusters.append(self.make_cluster(h, w))
101 | w += self.S
102 | w = int(self.S / 2)
103 | h += self.S
104 |
105 | def get_gradient(self, h, w):
106 | if w + 1 >= self.image_width:
107 | w = self.image_width - 2
108 | if h + 1 >= self.image_height:
109 | h = self.image_height - 2
110 |
111 | gradient = self.data[0][h + 1][w + 1] - self.data[0][h][w] + \
112 | self.data[1][h + 1][w + 1] - self.data[1][h][w] + \
113 | self.data[2][h + 1][w + 1] - self.data[2][h][w] + \
114 | self.data[3][h + 1][w + 1] - self.data[3][h][w]
115 |
116 | return gradient
117 |
118 | def move_clusters(self):
119 | for cluster in self.clusters:
120 |             cluster_gradient = self.get_gradient(cluster.h, cluster.w)  # gradient at each cluster centre
121 | for dh in range(-5, 6):
122 | for dw in range(-5, 6):
123 | _h = cluster.h + dh
124 | _w = cluster.w + dw
125 |                     if self.data[0][_h][_w] != -9999 and self.data[1][_h][_w] != -9999 \
126 |                             and self.data[2][_h][_w] != -9999 and self.data[3][_h][_w] != -9999:
127 |                         new_gradient = self.get_gradient(_h, _w)
128 |                         if new_gradient < cluster_gradient:  # move the centre to the pixel with the smallest gradient in the neighbourhood
129 | cluster.update(_h, _w, self.data[0][_h][_w], self.data[1][_h][_w], self.data[2][_h][_w],
130 | self.data[3][_h][_w])
131 | cluster_gradient = new_gradient
132 |
133 | def assignment(self):
134 |         W = [0.3, 0.1, 0.2, 0.4]  # weights for the influence of each factor
135 | for cluster in self.clusters:
136 | for h in range(cluster.h - self.S, cluster.h + self.S):
137 |                 if h < 0 or h >= self.image_height: continue  # skip to the next iteration
138 | for w in range(cluster.w - self.S, cluster.w + self.S):
139 | if w < 0 or w >= self.image_width: continue
140 | if self.data[0][h][w] != -9999 and self.data[1][h][w] != -9999 \
141 | and self.data[2][h][w] != -9999 and self.data[3][h][w] != -9999:
142 | Dc = math.sqrt(
143 | math.pow(self.data[0][h][w] - cluster.dem, 2) * W[0] +
144 | math.pow(self.data[1][h][w] - cluster.aspect, 2) * W[1] +
145 | math.pow(self.data[2][h][w] - cluster.curvature, 2) * W[2] +
146 | math.pow(self.data[3][h][w] - cluster.slope, 2) * W[3]
147 | ) # dbs
148 | Ds = math.sqrt(
149 | math.pow(h - cluster.h, 2) +
150 | math.pow(w - cluster.w, 2)) # dxy
151 |                         D = math.sqrt(math.pow(Dc / self.M, 2) + math.pow(Ds / self.S, 2))  # overall distance D
152 | if D < self.dis[h][w]:
153 |                         if (h, w) not in self.label:  # a tuple can be used as a dict key
154 | self.label[(h, w)] = cluster
155 | cluster.pixels.append((h, w))
156 | else:
157 | self.label[(h, w)].pixels.remove((h, w))
158 | self.label[(h, w)] = cluster
159 | cluster.pixels.append((h, w))
160 | self.dis[h][w] = D
161 |
162 |     def update_cluster(self):  # recompute the centre of each SLIC cluster
163 | for cluster in self.clusters:
164 | sum_h = sum_w = number = 0
165 | for p in cluster.pixels:
166 | sum_h += p[0]
167 | sum_w += p[1]
168 | number += 1
169 | _h = int(sum_h / number)
170 | _w = int(sum_w / number)
171 | cluster.update(_h, _w, self.data[0][_h][_w], self.data[1][_h][_w], self.data[2][_h][_w],
172 |                            self.data[3][_h][_w])  # update the cluster centre
173 |
174 | def writeTiff(self, im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
175 | if 'int8' in im_data.dtype.name:
176 | datatype = gdal.GDT_Byte
177 | elif 'int16' in im_data.dtype.name:
178 | datatype = gdal.GDT_UInt16
179 | else:
180 | datatype = gdal.GDT_Float32
181 |
182 | im_bands, im_height, im_width = im_data.shape
183 |         path = 'metatask_sampling/' + path
184 |         # create the output file
185 | driver = gdal.GetDriverByName("GTiff")
186 | dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
187 | if (dataset != None):
188 |             dataset.SetGeoTransform(im_geotrans)  # write the geotransform (affine) parameters
189 |             dataset.SetProjection(im_proj)  # write the projection
190 | for i in range(im_bands):
191 | dataset.GetRasterBand(i + 1).SetNoDataValue(-9999)
192 | dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
193 | del dataset
194 |
195 | def savetif(self, tiffile, savename, image_arr):
196 | gdal.AllRegister()
197 | dataset = gdal.Open(tiffile)
198 |         im_bands = dataset.RasterCount  # number of bands
199 |         im_width = dataset.RasterXSize  # number of columns
200 |         im_height = dataset.RasterYSize  # number of rows
201 |         im_geotrans = dataset.GetGeoTransform()  # affine (geotransform) parameters
202 |         im_proj = dataset.GetProjection()  # projection information
203 |
204 | self.writeTiff(image_arr, im_width, im_height, im_bands, im_geotrans, im_proj, savename)
205 |
206 | def save_current_image(self, tiffile, savename):
207 | image_arr = np.copy(self.data)
208 |         for i in range(len(image_arr)):  # initialise all bands to NoData
209 | for h in range(self.image_height):
210 | for w in range(self.image_width):
211 | image_arr[i][h][w] = -9999
212 |
213 |         c = 0
214 |         interval = int(256 / len(self.clusters))
215 |         for cluster in self.clusters:  # visualise each cluster
216 | c += 1
217 | r = g = b = d = interval * c - 1
218 |             for p in cluster.pixels:  # assign the same grey value to the three (LAB-style) channels
219 | image_arr[0][p[0]][p[1]] = r
220 | image_arr[1][p[0]][p[1]] = g
221 | image_arr[2][p[0]][p[1]] = b
222 |             image_arr[0][cluster.h][cluster.w] = 0  # mark the cluster centre with 0
223 | image_arr[1][cluster.h][cluster.w] = 0
224 | image_arr[2][cluster.h][cluster.w] = 0
225 |
226 | self.savetif(tiffile, savename, image_arr)
227 |
228 | def iterate_times(self, loop=5):
229 |         self.init_clusters(self.data)  # store all initial cluster centres in self.clusters
230 | self.move_clusters()
231 | for i in trange(loop):
232 | self.assignment()
233 | self.update_cluster()
234 | savename = FLAGS.str_region + '_SLIC_M{m}_K{k}_loop{loop}.tif'.format(loop=i, m=self.M,
235 |                                                                                   k=self.K)  # produce a visualisation GeoTIFF
236 | self.save_current_image(self.file, savename)
237 |
238 |
239 | class TaskSampling(object):
240 | def __init__(self, clusters):
241 | self.clusters = clusters
242 | self.tasks = self.init_tasks(len(clusters))
243 |
244 | def init_tasks(self, num_clusters):
245 | L = []
246 | for i in range(num_clusters):
247 | L.append([])
248 | return L
249 |
250 | def readpts(self, filepath):
251 |         tmp = np.loadtxt(filepath, dtype=str, delimiter=",", encoding='UTF-8')  # np.str was removed in recent NumPy; use the built-in str
252 | features = tmp[1:, :-3].astype(np.float32)
253 |         features = features / features.max(axis=0)  # rescale each column to reduce the influence of differing value ranges
254 | xy = tmp[1:, -3: -1].astype(np.float32)
255 | label = tmp[1:, -1].astype(np.float32)
256 | return features, xy, label
257 |
258 | def sampling(self, im_geotrans, path):
259 | features, xy, label = self.readpts(path)
260 | # features_Ts_, xy_Ts, label_Ts = self.readpts(FLAGS.Ts_pts)
261 | # features = np.vstack((features, features_Ts_))
262 | # xy = np.vstack((xy, xy_Ts))
263 | # # labeling Ts pts according to dv value
264 | # for i in range(len(label_Ts)):
265 | # if label_Ts[i] <= 2:
266 | # label_Ts[i] = 0.7
267 | # continue
268 | # if 2 < label_Ts[i] <= 4:
269 | # label_Ts[i] = 0.75
270 | # continue
271 | # if 4 < label_Ts[i] <= 6:
272 | # label_Ts[i] = 0.8
273 | # continue
274 | # if 6 < label_Ts[i] <= 8:
275 | # label_Ts[i] = 0.85
276 | # continue
277 | # if label_Ts[i] > 8:
278 | # label_Ts[i] = 0.9
279 | # label = np.hstack((label, label_Ts))
280 |         # compute (row, col) for each sample point
281 | pts = []
282 | for i in range(xy.shape[0]):
283 | height = int((xy[i][1] - im_geotrans[3]) / im_geotrans[5])
284 | width = int((xy[i][0] - im_geotrans[0]) / im_geotrans[1])
285 | pts.append((height, width))
286 |
287 | pt_index = 0
288 | for pt in pts:
289 | k = 0 # count cluster
290 | for cluster in self.clusters:
291 | if (pt[0], pt[1]) in cluster.pixels:
292 | self.tasks[k].append([features[pt_index], label[pt_index]])
293 | break
294 | else:
295 | k += 1
296 | pt_index += 1
297 | return self.tasks
298 |
--------------------------------------------------------------------------------
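The assignment() step above scores each candidate pixel with a weighted terrain distance Dc (over dem, aspect, curvature and slope) combined with a spatial distance Ds, scaled by M and the grid step S: D = sqrt((Dc/M)^2 + (Ds/S)^2). A small standalone sketch of that measure follows; slic_distance is a hypothetical helper, the weights mirror the W hard-coded in assignment(), and the feature values are made up:

import math

def slic_distance(pixel_feats, centre_feats, h, w, ch, cw, S, M,
                  weights=(0.3, 0.1, 0.2, 0.4)):
    """D = sqrt((Dc / M)^2 + (Ds / S)^2), with Dc a weighted feature distance."""
    dc = math.sqrt(sum(wt * (p - c) ** 2
                       for wt, p, c in zip(weights, pixel_feats, centre_feats)))
    ds = math.hypot(h - ch, w - cw)          # plain pixel-space distance
    return math.hypot(dc / M, ds / S)

# a pixel 3 rows and 4 columns away from a cluster centre with similar terrain values
print(slic_distance((120.0, 180.0, 0.1, 25.0), (118.0, 175.0, 0.0, 24.0),
                    h=10, w=14, ch=7, cw=10, S=20, M=500))
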
/src_data/HK/composite.tfw:
--------------------------------------------------------------------------------
1 | 78.1848426438
2 | 0.0000000000
3 | 0.0000000000
4 | -75.3690471051
5 | 801034.7337224468
6 | 846872.9936815350
7 |
--------------------------------------------------------------------------------
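The world file above stores the affine parameters of composite.tif in .tfw order (pixel width, two rotation terms, pixel height, then the upper-left x and y). Essentially the same parameters, reordered by GDAL's GetGeoTransform(), are what sampling() in scene_segmentation.py uses to turn projected coordinates into raster (row, col). A minimal sketch of that conversion with the values from this .tfw; xy_to_rowcol is a hypothetical helper and the example coordinate is made up:

# .tfw order: pixel_width, rot, rot, pixel_height, upper_left_x, upper_left_y
tfw = (78.1848426438, 0.0, 0.0, -75.3690471051, 801034.7337224468, 846872.9936815350)

def xy_to_rowcol(x, y, tfw):
    """Map a projected (x, y) to raster (row, col), mirroring sampling()."""
    row = int((y - tfw[5]) / tfw[3])   # offset in y divided by the (negative) pixel height
    col = int((x - tfw[4]) / tfw[0])   # offset in x divided by the pixel width
    return row, col

print(xy_to_rowcol(810000.0, 840000.0, tfw))  # a point a few kilometres from the upper-left corner
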
/src_data/HK/composite.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLi-de/Meta_LSM/c25e2904761e3f4b4d5a797f0b4db7ddfe53236e/src_data/HK/composite.tif
--------------------------------------------------------------------------------
/src_data/HK/composite.tif.aux.xml:
--------------------------------------------------------------------------------
(GDAL auxiliary statistics for composite.tif; the XML markup was stripped in this dump. For each of the four raster bands the file records the minimum, maximum, mean, standard deviation, a covariance matrix and a 256-bucket histogram.)
--------------------------------------------------------------------------------
/src_data/HK/composite.tif.ovr:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLi-de/Meta_LSM/c25e2904761e3f4b4d5a797f0b4db7ddfe53236e/src_data/HK/composite.tif.ovr
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """ Utility functions. """
2 | import numpy as np
3 | import tensorflow as tf
4 | import pandas as pd
5 |
6 | import tf_slim as slim
7 | from tensorflow.python.platform import flags
8 |
9 | FLAGS = flags.FLAGS
10 |
11 |
12 | def normalize(inp, activation, reuse, scope):
13 | if FLAGS.norm == 'batch_norm':
14 | return slim.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
15 | elif FLAGS.norm == 'layer_norm':
16 | return slim.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
17 | elif FLAGS.norm == 'None':
18 | if activation is not None:
19 | return activation(inp)
20 | else:
21 | return inp
22 |
23 |
24 | def mse(pred, label):
25 | pred = tf.reshape(pred, [-1])
26 | label = tf.reshape(label, [-1])
27 | return tf.reduce_mean(input_tensor=tf.square(pred - label))
28 |
29 |
30 | def xent(pred, label):
31 | # Note - with tf version <=0.12, this loss has incorrect 2nd derivatives
32 | return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / tf.cast(tf.shape(input=label)[0],
33 |                                                                      dtype=tf.float32)  # note the normalisation by batch size
34 |
35 |
36 | def tasksbatch_generator(data, batch_size, num_samples, dim_input, dim_output):
37 | """generate batch tasks"""
38 | init_inputs = np.zeros([batch_size, num_samples, dim_input], dtype=np.float32)
39 | labels = np.zeros([batch_size, num_samples, dim_output], dtype=np.float32)
40 |
41 | np.random.shuffle(data)
42 | start_index = np.random.randint(0, len(data) - batch_size)
43 | batch_tasks = data[start_index:(start_index + batch_size)]
44 |
45 | cnt_sample = []
46 | for i in range(batch_size):
47 | cnt_sample.append(len(batch_tasks[i]))
48 |
49 | for i in range(batch_size):
50 | np.random.shuffle(batch_tasks[i]) # shuffle samples in each task
51 | start_index1 = np.random.randint(0, len(batch_tasks[i]) - num_samples)
52 | task_samples = batch_tasks[i][start_index1:(start_index1 + num_samples)]
53 | for j in range(num_samples):
54 | init_inputs[i][j] = task_samples[j][0]
55 | if task_samples[j][1] == 1:
56 |                 labels[i][j][0] = 1  # landslide
57 |             else:
58 |                 labels[i][j][1] = 1  # non-landslide
59 | return init_inputs, labels, np.array(cnt_sample).astype(np.float32)
60 |
61 |
62 | def batch_generator(one_task, dim_input, dim_output, batch_size):
63 |     """generate samples from one task"""
64 | np.random.shuffle(one_task)
65 | batch_ = one_task[:batch_size]
66 | init_inputs = np.zeros([batch_size, dim_input], dtype=np.float32)
67 | labels = np.zeros([batch_size, dim_output], dtype=np.float32)
68 | for i in range(batch_size):
69 | init_inputs[i] = batch_[i][0]
70 | if batch_[i][1] == 1:
71 | labels[i][1] = 1
72 | else:
73 | labels[i][0] = 1
74 | return init_inputs, labels
75 |
76 |
77 | # # for each region (e.g., FJ&FL)
78 | # def sample_generator_(tasks, dim_input, dim_output):
79 | # all_samples = np.array(tasks[0])
80 | # num_samples = int(FLAGS.num_samples_each_task / 2)
81 | # for i in range(len(tasks) - 1):
82 | # if len(tasks[i + 1]) > 0:
83 | # all_samples = np.vstack((all_samples, np.array(tasks[i + 1])))
84 | # init_inputs = np.zeros([1, num_samples, dim_input], dtype=np.float32)
85 | # labels = np.zeros([1, num_samples, dim_output], dtype=np.float32)
86 | # for i in range(num_samples):
87 | # init_inputs[0][i] = all_samples[i][:-1]
88 | # if all_samples[i][-1] == 1:
89 | # labels[0][i][0] = 1
90 | # else:
91 | # labels[0][i][1] = 1
92 | # return init_inputs, labels
93 |
94 |
95 | def meta_train_test(fj_tasks, fl_tasks, mode=0):
96 | test1_fj_tasks, test1_fl_tasks, read_tasks, one_test_tasks = [], [], [], []
97 | _train, _test = [], []
98 | # np.random.shuffle(tasks)
99 | if mode == 0:
100 | elig_tasks = []
101 | for i in range(len(fj_tasks)):
102 | if len(fj_tasks[i]) > FLAGS.num_samples_each_task:
103 | elig_tasks.append(fj_tasks[i])
104 | elif len(fj_tasks[i]) > 10: # set 10 to test K=10-shot learning
105 | test1_fj_tasks.append(fj_tasks[i])
106 | else:
107 | read_tasks.append(fj_tasks[i])
108 | _train = elig_tasks[:int(len(elig_tasks) / 4 * 3)]
109 | _test = elig_tasks[int(len(elig_tasks) / 4 * 3):] + test1_fj_tasks
110 |         for i in range(len(read_tasks)):  # read_tasks is not used for now
111 | one_test_tasks.extend(read_tasks[i])
112 | return _train, _test
113 |
114 | if mode == 1:
115 | for i in range(len(fj_tasks)):
116 | if len(fj_tasks[i]) > FLAGS.num_samples_each_task:
117 | _train.append(fj_tasks[i])
118 | for i in range(len(fl_tasks)):
119 | if len(fl_tasks[i]) > 10:
120 | _test.append(fl_tasks[i])
121 | return _train, _test
122 |
123 | if mode == 2 or mode == 3:
124 | elig_fj_tasks, elig_fl_tasks = [], []
125 | for i in range(len(fj_tasks)):
126 | if len(fj_tasks[i]) > FLAGS.num_samples_each_task:
127 | elig_fj_tasks.append(fj_tasks[i])
128 | elif len(fj_tasks[i]) > 10:
129 | test1_fj_tasks.append(fj_tasks[i])
130 | for i in range(len(fl_tasks)):
131 | if len(fl_tasks[i]) > FLAGS.num_samples_each_task:
132 | elig_fl_tasks.append(fl_tasks[i])
133 | elif len(fl_tasks[i]) > 10:
134 | test1_fl_tasks.append(fl_tasks[i])
135 | if mode == 2:
136 | _train = elig_fj_tasks[:int(len(elig_fj_tasks) / 4 * 3)] + elig_fl_tasks
137 | _test = elig_fj_tasks[int(len(elig_fj_tasks) / 4 * 3):] + test1_fj_tasks
138 | return _train, _test
139 | elif mode == 3:
140 | _train = elig_fj_tasks + elig_fl_tasks[:int(len(elig_fj_tasks) / 2)]
141 | _test = elig_fl_tasks[int(len(elig_fl_tasks) / 2):] + test1_fl_tasks
142 | return _train, _test
143 | # _test.extend(resid_tasks)
144 |
145 |
146 | def meta_train_test1(HK_tasks):
147 | test_hk_tasks, one_test_tasks, remain_tasks, elig_tasks = [], [], [], []
148 | for i in range(len(HK_tasks)):
149 | if len(HK_tasks[i]) > FLAGS.num_samples_each_task:
150 | elig_tasks.append(HK_tasks[i])
151 | else:
152 | remain_tasks.append(HK_tasks[i])
153 | np.random.shuffle(elig_tasks)
154 | _train = elig_tasks[:int(len(elig_tasks) / 4 * 3)]
155 | _test = elig_tasks[int(len(elig_tasks) / 4 * 3):]
156 | return _train, _test
157 |
158 |
159 | def save_tasks(tasks, filename):
160 |     """Save the tasks to an .xlsx file, one sheet per task."""
161 | writer = pd.ExcelWriter(filename)
162 | for i in range(len(tasks)):
163 | task_sampels = []
164 | for j in range(len(tasks[i])):
165 | attr_lb = np.append(tasks[i][j][0], tasks[i][j][1])
166 | task_sampels.append(attr_lb)
167 | data_df = pd.DataFrame(task_sampels)
168 | data_df.to_excel(writer, 'task_' + str(i), float_format='%.5f', header=False, index=False)
169 | writer.close()
170 |
171 |
172 | def read_tasks(file):
173 |     """Read the tasks back from the .xlsx file."""
174 | f = pd.ExcelFile(file)
175 | tasks = [[] for i in range(len(f.sheet_names))]
176 | k = 0 # count task
177 | for sheetname in f.sheet_names:
178 | attr = pd.read_excel(file, usecols=[i for i in range(FLAGS.dim_input)], sheet_name=sheetname,
179 | header=None).values.astype(np.float32)
180 | label = pd.read_excel(file, usecols=[FLAGS.dim_input], sheet_name=sheetname, header=None).values.reshape(
181 | (-1,)).astype(np.float32)
182 | for j in range(np.shape(attr)[0]):
183 | tasks[k].append([attr[j], label[j]])
184 | k += 1
185 | return tasks
186 |
187 |
188 | def savepts_fortask(clusters, file):
189 | writer = pd.ExcelWriter(file)
190 | count = 0
191 | for cluster in clusters:
192 | pts = []
193 | for pixel in cluster.pixels:
194 | pts.append(pixel)
195 | data_df = pd.DataFrame(pts)
196 | data_df.to_excel(writer, 'task_' + str(count), float_format='%.5f', header=False, index=False)
197 | count = count + 1
198 | writer.close()
199 |
200 |
201 | def read_pts(file):
202 |     """Read the per-task pixel coordinates from the .xlsx file."""
203 | f = pd.ExcelFile(file)
204 | tasks = []
205 | for sheetname in f.sheet_names:
206 | arr = pd.read_excel(file, sheet_name=sheetname).values.astype(np.float32)
207 | tasks.append(arr)
208 | return tasks
209 |
210 |
211 | def cal_measure(pred, y_test):
212 | TP = ((pred == 1) * (y_test == 1)).astype(int).sum()
213 | FP = ((pred == 1) * (y_test == 0)).astype(int).sum()
214 | FN = ((pred == 0) * (y_test == 1)).astype(int).sum()
215 | TN = ((pred == 0) * (y_test == 0)).astype(int).sum()
216 | # statistical measure
217 | Precision = TP / (TP + FP)
218 | Recall = TP / (TP + FN)
219 | F_measures = 2 * Precision * Recall / (Precision + Recall)
220 | print('Precision: %f' % Precision, '\nRecall: %f' % Recall, '\nF_measures: %f' % F_measures)
221 |
--------------------------------------------------------------------------------
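Usage note: cal_measure() above expects two 1-D binary NumPy arrays (predicted and true labels) and prints precision, recall and F-measure; it is what test() in meta_LSM.py calls on the pooled test predictions. A minimal sketch with made-up values, assuming utils.py is importable from the working directory:

import numpy as np
from utils import cal_measure

pred   = np.array([1, 0, 1, 1, 0, 1])
y_test = np.array([1, 0, 0, 1, 0, 1])
cal_measure(pred, y_test)  # prints Precision, Recall and F_measures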