├── requirements.txt ├── README.md ├── model_APPDNET.py ├── APP_Run.py └── utils_APPDNET.py /requirements.txt: -------------------------------------------------------------------------------- 1 | keras 2 | tensorflow-gpu 3 | keras-self-attention 4 | h5py 5 | obspy 6 | pandas 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | An Attention Based Network to Achieve P-arrival picking and First-Motion Determination 2 | Author: Ji Zhang 3 | Date: 2022.04.25 4 | Version 1.0.0 5 | 6 | # APPNET 7 | ## Simultaneous Seismic Phase Picking and Polarity Determination with an Attention-based Neural Network 8 | 9 | ### This repository contains the codes to train and test the network proposed in: 10 | 11 | `Zhang J, Li Z, Zhang J. Simultaneous Seismic Phase Picking and Polarity Determination with an Attention‐Based Neural Network [J]. Seismological Research Letters, 2023.` 12 | 13 | ------------------------------------------- 14 | ### Installation: 15 | 16 | `pip install -r requirements.txt` 17 | 18 | or 19 | 20 | `pip install keras-self-attention` 21 | 22 | ------------------------------------------- 23 | ### Short Description: 24 | 25 | The focal mechanism of a small earthquake is difficult to determine, but it plays an important role in understanding the regional stress field. The focal mechanisms of small earthquakes can be obtained by inversion of first-motion polarities. Machine learning can help determine polarities efficiently and accurately. The first-motion polarity determination is inseparable from the accuracy of picking and it highly depends on the latter. We propose a first attention-based network to tackle two tasks of picking and polarity determination with encouraging results. APPNET consists of one simple encoder, one decoder, and one classifier. 26 | 27 | ------------------------------------------- 28 | ### Dataset: 29 | 30 | Data from Southern California Earthquake Data Center. [(SCEDC)](https://scedc.caltech.edu/data/deeplearning.html#picking_polarity) 31 | Download three hdf5 files 32 | `scsn_p_2000_2017_6sec_0.5r_pick_train.hdf5` 33 | `scsn_p_2000_2017_6sec_0.5r_pick_test.hdf5` 34 | `scsn_p_2000_2017_6sec_0.5r_fm_test.hdf5` 35 | or 36 | [Traing_dataset](https://service.scedc.caltech.edu/ftp/Ross_FinalTrainedModels/scsn_p_2000_2017_6sec_0.5r_pick_train.hdf5) 37 | [Validation_dataset](https://service.scedc.caltech.edu/ftp/Ross_FinalTrainedModels/scsn_p_2000_2017_6sec_0.5r_pick_test.hdf5) 38 | [Test_dataset](https://service.scedc.caltech.edu/ftp/Ross_FinalTrainedModels/scsn_p_2000_2017_6sec_0.5r_fm_test.hdf5) 39 | 40 | ------------------------------------------- 41 | ### Run 42 | Download train dataset, validation data, and test dataset to `./dataset/` file. 43 | 44 | ` Train` 45 | > python APP_Run.py --mode='train' 46 | 47 | `Test` 48 | > python APP_Run.py --mode='test' --plot_figure 49 | 50 | `Predict` 51 | > python APP_Run.py --mode='predict' 52 | 53 | ------------------------------------------- 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /model_APPDNET.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Apr 22 13:27:56 2022 5 | 6 | @author: zhangj2 7 | """ 8 | import datetime 9 | import keras 10 | from keras import losses 11 | from keras import optimizers 12 | from keras.models import Sequential 13 | from keras.models import Model,load_model 14 | from keras.layers import Input, Dense, Dropout, Flatten,Embedding, LSTM,GRU,Bidirectional 15 | from keras.layers import Conv1D,Conv2D,MaxPooling1D,MaxPooling2D,BatchNormalization,Reshape 16 | from keras.layers import UpSampling1D,AveragePooling1D,AveragePooling2D,TimeDistributed 17 | from keras_self_attention import SeqSelfAttention 18 | from keras.callbacks import LearningRateScheduler,EarlyStopping,ModelCheckpoint 19 | # In[] 20 | def build_PP_model(time_input=(400,1),clas=3,filter_size=3,num_filter=[16,32,64],num_dense=128): 21 | 22 | inp = Input(shape=time_input, name='input') 23 | # print(num_filter) 24 | x = Conv1D(num_filter[0], filter_size, padding = 'same', activation = 'relu')(inp) 25 | 26 | x = MaxPooling1D(2)(x) 27 | x = BatchNormalization()(x) 28 | 29 | x = Conv1D(num_filter[1], filter_size, padding = 'same', activation = 'relu')(x) 30 | x = MaxPooling1D(2)(x) 31 | x = BatchNormalization()(x) 32 | 33 | x = Conv1D(num_filter[2], filter_size, padding = 'same', activation = 'relu')(x) 34 | x = MaxPooling1D(2)(x) 35 | x = BatchNormalization()(x) 36 | 37 | x = keras.layers.LSTM(units=num_filter[2]*2, return_sequences=True)(x) 38 | 39 | at_x,wt = SeqSelfAttention(return_attention=True, attention_width= 20, 40 | attention_activation='relu',name='Atten')(x) 41 | #----------------------# 42 | x1 = UpSampling1D(2)(at_x) 43 | x1 = Conv1D(num_filter[2], filter_size, padding = 'same', activation = 'relu')(x1) 44 | x1 = BatchNormalization()(x1) 45 | 46 | x1 = UpSampling1D(2)(x1) 47 | x1 = Conv1D(num_filter[1], filter_size, padding = 'same', activation = 'relu')(x1) 48 | x1 = BatchNormalization()(x1) 49 | 50 | x1 = UpSampling1D(2)(x1) 51 | x1 = Conv1D(num_filter[0], filter_size, padding = 'same', activation = 'relu')(x1) 52 | x1 = BatchNormalization()(x1) 53 | 54 | out1 = Conv1D(1, filter_size, padding = 'same', activation = 'sigmoid',name='pk')(x1) 55 | 56 | #----------------------# 57 | x = Flatten()(at_x) 58 | 59 | x = Dense(num_dense,activation = 'relu')(x) 60 | 61 | out2 = Dense(clas,activation = 'softmax',name='po')(x) 62 | 63 | model = Model(inp, [out1,out2]) 64 | 65 | return model 66 | # In[] 67 | ''' 68 | ##build up model 69 | def build_PP_model(time_input=(400,1),clas=3): 70 | 71 | inp = Input(shape=time_input, name='input') 72 | 73 | x = Conv1D(16, 5, padding = 'same', activation = 'relu')(inp) 74 | 75 | x = MaxPooling1D(2)(x) 76 | x = BatchNormalization()(x) 77 | 78 | x = Conv1D(32, 3, padding = 'same', activation = 'relu')(x) 79 | x = MaxPooling1D(2)(x) 80 | x = BatchNormalization()(x) 81 | 82 | x = Conv1D(64, 3, padding = 'same', activation = 'relu')(x) 83 | x = MaxPooling1D(2)(x) 84 | x = BatchNormalization()(x) 85 | 86 | x = keras.layers.LSTM(units=128, return_sequences=True)(x) 87 | 88 | at_x,wt = SeqSelfAttention(return_attention=True, attention_width= 20, 89 | attention_activation='relu',name='Atten')(x) 90 | #----------------------# 91 | x1 = UpSampling1D(2)(at_x) 92 | x1 = Conv1D(64, 3, padding = 'same', activation = 'relu')(x1) 93 | x1 = BatchNormalization()(x1) 94 | 95 | x1 = UpSampling1D(2)(x1) 96 | x1 = Conv1D(32, 3, padding = 'same', activation = 'relu')(x1) 97 | x1 = BatchNormalization()(x1) 98 | 99 | x1 = UpSampling1D(2)(x1) 100 | x1 = Conv1D(16, 3, padding = 'same', activation = 'relu')(x1) 101 | x1 = BatchNormalization()(x1) 102 | 103 | out1 = Conv1D(1, 3, padding = 'same', activation = 'sigmoid',name='pk')(x1) 104 | 105 | #----------------------# 106 | x = Flatten()(at_x) 107 | 108 | x = Dense(128,activation = 'relu')(x) 109 | 110 | out2 = Dense(clas,activation = 'softmax',name='po')(x) 111 | 112 | model = Model(inp, [out1,out2]) 113 | 114 | return model 115 | ''' -------------------------------------------------------------------------------- /APP_Run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Sep 30 00:00:13 2022 5 | 6 | @author: zhangj2 7 | """ 8 | #!/usr/bin/env python3 9 | # -*- coding: utf-8 -*- 10 | """ 11 | Created on Fri Apr 22 11:57:32 2022 12 | 13 | APPDNET: 14 | picking and polarity determination with an attention network 15 | ` Train` 16 | > python APP_Run.py --mode='train' 17 | 18 | `Test` 19 | > python APP_Run.py --mode='test' --plot_figure 20 | 21 | `Predict` 22 | > python APP_Run.py --mode='predict' 23 | 24 | 25 | @author: zhangj2 26 | """ 27 | 28 | # In[] 29 | cuda_kernel='1' 30 | import os 31 | os.getcwd() 32 | import tensorflow as tf 33 | 34 | import datetime 35 | import h5py 36 | import matplotlib.pyplot as plt 37 | import numpy as np 38 | from skimage import transform 39 | from scipy import interpolate 40 | import argparse 41 | from utils_APPDNET import gaussian,plot_loss,DataGenerator_PP1_S,DataGenerator_PP1_S_test,confusion_matrix,gen_test_data_UD 42 | from model_APPDNET import build_PP_model 43 | from keras_self_attention import SeqSelfAttention 44 | from keras.callbacks import LearningRateScheduler,EarlyStopping,ModelCheckpoint 45 | from keras.models import Model,load_model 46 | from keras.utils.np_utils import to_categorical 47 | import random 48 | import pandas as pd 49 | 50 | import matplotlib 51 | matplotlib.use('Agg') 52 | 53 | # In[] 54 | def start_gpu(args): 55 | cuda_kernel=args.GPU 56 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 57 | os.environ["CUDA_VISIBLE_DEVICES"] = cuda_kernel 58 | # os.environ["CUDA_DEVICE_ORDER"] = args.PCI_BUS_ID 59 | 60 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU') 61 | print('Physical GPU:', len(gpus)) 62 | for gpu in gpus: 63 | tf.config.experimental.set_memory_growth(gpu, True) 64 | 65 | logical_gpus = tf.config.experimental.list_logical_devices('GPU') 66 | print('Logical GPU:', len(logical_gpus)) 67 | 68 | 69 | 70 | 71 | def read_args(): 72 | 73 | parser = argparse.ArgumentParser() 74 | 75 | parser.add_argument("--GPU", 76 | default="0", 77 | help="set gpu ids") 78 | 79 | parser.add_argument("--mode", 80 | default="train", 81 | help="/train/predict/hinet") 82 | 83 | parser.add_argument("--model_name", 84 | default="APPNET_LOLITA", 85 | help="model name") 86 | 87 | parser.add_argument("--epochs", 88 | default=100, 89 | type=int, 90 | help="number of epochs (default: 10)") 91 | 92 | parser.add_argument("--batch_size", 93 | default=256, 94 | type=int, 95 | help="batch size") 96 | 97 | parser.add_argument("--learning_rate", 98 | default=0.001, 99 | type=float, 100 | help="learning rate") 101 | 102 | parser.add_argument("--patience", 103 | default=5, 104 | type=int, 105 | help="early stopping") 106 | 107 | parser.add_argument("--clas", 108 | default=3, 109 | type=int, 110 | help="number of class") 111 | 112 | parser.add_argument("--monitor", 113 | default="val_loss", 114 | help="monitor the val_loss/loss/acc/val_acc") 115 | 116 | parser.add_argument("--monitor_mode", 117 | default="min", 118 | help="min/max/auto") 119 | 120 | parser.add_argument("--use_multiprocessing", 121 | default=False, 122 | help="False/True") 123 | 124 | parser.add_argument("--workers", 125 | default=1, 126 | type=int, 127 | help="workers") 128 | 129 | parser.add_argument("--loss", 130 | default=['mse','categorical_crossentropy'], 131 | type=list, 132 | help="loss fucntion") 133 | 134 | parser.add_argument("--num_filter", 135 | default=[16,32,64], 136 | type=int, 137 | nargs='+', 138 | help="num_filter") 139 | 140 | parser.add_argument("--filter_size", 141 | default=3, 142 | type=int, 143 | help="filter_size") 144 | 145 | parser.add_argument("--num_dense", 146 | default=128, 147 | type=int, 148 | help="num_dense") 149 | 150 | parser.add_argument("--model_dir", 151 | default='./model/', 152 | help="Checkpoint directory (default: None)") 153 | 154 | parser.add_argument("--num_plots", 155 | default=10, 156 | type=int, 157 | help="Plotting trainning results") 158 | 159 | parser.add_argument("--input_length", 160 | default=400, 161 | type=int, 162 | help="input length") 163 | 164 | parser.add_argument("--data_dir", 165 | default=None, 166 | help="Input file directory") 167 | 168 | parser.add_argument("--data_list", 169 | default=None, 170 | help="Input csv file") 171 | 172 | parser.add_argument("--train_dir", 173 | default="./dataset/scsn_p_2000_2017_6sec_0.5r_pick_train.hdf5", 174 | help="Input file directory") 175 | 176 | parser.add_argument("--train_list", 177 | default=None, 178 | help="Input csv file") 179 | 180 | parser.add_argument("--valid_dir", 181 | default='./dataset/scsn_p_2000_2017_6sec_0.5r_pick_test.hdf5', 182 | help="Input file directory") 183 | 184 | parser.add_argument("--valid_list", 185 | default=None, 186 | help="Input csv file") 187 | 188 | parser.add_argument("--org_dir", 189 | default='./dataset/scsn_p_2000_2017_6sec_0.5r_fm_test.hdf5', 190 | help="SCSN testing data") 191 | 192 | 193 | parser.add_argument("--output_dir", 194 | default='./result/', 195 | help="Output directory") 196 | 197 | parser.add_argument("--conf_dir", 198 | default='./model_configure/', 199 | help="Configure directory") 200 | 201 | parser.add_argument("--acc_loss_fig", 202 | default='./acc_loss_fig/', 203 | help="acc&loss directory") 204 | 205 | parser.add_argument("--time_shift", 206 | default=True, 207 | help="False/True") 208 | 209 | parser.add_argument("--shuffle", 210 | default=True, 211 | help="False/True") 212 | 213 | parser.add_argument("--augment", 214 | default=True, 215 | help="False/True") 216 | 217 | parser.add_argument("--plot_figure", 218 | action="store_true", 219 | help="If plot figure for test") 220 | 221 | parser.add_argument("--save_result", 222 | action="store_true", 223 | help="If save result for test") 224 | 225 | parser.add_argument("--exam", 226 | default='example1', 227 | help="example1/example2") 228 | 229 | args = parser.parse_args() 230 | return args 231 | 232 | 233 | # In[]Parameters of network 234 | def set_configure(args): 235 | model_name=args.model_name 236 | time_input=(args.input_length,1) 237 | epochs=args.epochs 238 | patience=args.patience 239 | monitor=args.monitor 240 | mode=args.monitor_mode 241 | batch_size=args.batch_size 242 | num_filter=args.num_filter, 243 | filter_size=args.filter_size, 244 | num_dense=args.num_dense, 245 | 246 | clas=args.clas 247 | loss=args.loss 248 | 249 | 250 | if not os.path.exists(args.conf_dir): 251 | os.mkdir(args.conf_dir) 252 | 253 | # save configure 254 | f1 = open(args.conf_dir+'Conf_%s.txt'%args.model_name,'w') 255 | f1.write('Model: %s'%model_name+'\n') 256 | f1.write('num_filter: %s'%num_filter+'\n') 257 | f1.write('filter_size: %d'%filter_size+'\n') 258 | f1.write('num_dense: %d'%num_dense+'\n') 259 | f1.write('epochs: %d'%epochs+'\n') 260 | f1.write('batch_size: %d'%batch_size+'\n') 261 | f1.write('monitor: %s'%monitor+'\n') 262 | f1.write('mode: %s'%mode+'\n') 263 | f1.write('patience: %d'%patience+'\n') 264 | f1.write('time_input: %s'%str(time_input)+'\n') 265 | f1.write('class: %d'%clas+'\n') 266 | f1.write('loss: %s'%loss+'\n') 267 | 268 | 269 | f1.close() 270 | 271 | 272 | def main(args): 273 | time_input=(args.input_length,1) 274 | gaus=gaussian(np.linspace(-5, 5, 100),1,0) 275 | 276 | if args.mode=='train': 277 | args.clas=3 278 | clas=args.clas 279 | #=======train_dataset======# 280 | print('load training data') 281 | f22=h5py.File(args.train_dir,'r') 282 | # train_index=[i for i in range(len(f22['fm'])) if f22['fm'][i]<3 ] 283 | # file_num=len(train_index) 284 | file_num=len(f22['fm'][:]) 285 | steps_per_epoch=file_num//args.batch_size 286 | train_generator=DataGenerator_PP1_S(f22,file_num,gaus,batch_size=args.batch_size, augment=True, 287 | classes=clas,time_shift=True, shuffle=True)#,indexes=train_index) 288 | #=======validation_dataset======# 289 | print('load validation data') 290 | f44=h5py.File(args.valid_dir,'r') 291 | # test_index=[i for i in range(len(f44['fm'])) if f44['fm'][i]<3 ] 292 | # file_num1=len(test_index) 293 | file_num1=len(f44['fm'][:]) 294 | test_index=[i for i in range(len(f44['fm']))] 295 | np.random.seed(0) 296 | np.random.shuffle(test_index) 297 | file_num1=10240 298 | validation_steps=file_num1//args.batch_size 299 | validation_generator=DataGenerator_PP1_S(f44,file_num1,gaus,batch_size=args.batch_size, 300 | classes=clas,time_shift=False, shuffle=False,indexes=test_index[:file_num1]) 301 | 302 | 303 | if args.mode=='train': 304 | # model=build_PP_model(time_input=time_input,clas=clas) 305 | model=build_PP_model(time_input=time_input,clas=clas,num_filter=args.num_filter, 306 | filter_size=args.filter_size,num_dense=args.num_dense) 307 | print(args.mode) 308 | if not os.path.exists(args.model_dir): 309 | os.mkdir(args.model_dir) 310 | model.compile(loss=args.loss,optimizer=tf.keras.optimizers.Adam(learning_rate=args.learning_rate),metrics=['accuracy']) 311 | saveBestModel= ModelCheckpoint(args.model_dir+'%s.h5'%args.model_name, monitor=args.monitor, 312 | verbose=1, save_best_only=True,mode=args.monitor_mode) 313 | estop = EarlyStopping(monitor=args.monitor, patience=args.patience, verbose=0, mode=args.monitor_mode) 314 | callbacks_list = [saveBestModel,estop] 315 | # fit 316 | begin = datetime.datetime.now() 317 | history_callback=model.fit_generator(generator=train_generator, 318 | steps_per_epoch=steps_per_epoch, 319 | epochs=args.epochs, 320 | verbose=1, 321 | callbacks=callbacks_list, 322 | use_multiprocessing=args.use_multiprocessing, 323 | workers=args.workers, 324 | validation_data=validation_generator, 325 | validation_steps=validation_steps) 326 | 327 | model.save_weights(args.model_dir+'%s_wt.h5'%args.model_name) 328 | end = datetime.datetime.now() 329 | 330 | #=======plot acc & loss======# 331 | print('plot acc & loss') 332 | if not os.path.exists(args.acc_loss_fig): 333 | os.mkdir(args.acc_loss_fig) 334 | plot_loss(history_callback,save_path=args.acc_loss_fig,model=args.model_name) 335 | 336 | elif args.mode=='test': 337 | print(args.mode) 338 | clas=args.clas 339 | ## load model 340 | try: 341 | model=load_model(args.model_dir+args.model_name+'.h5',custom_objects=SeqSelfAttention.get_custom_objects()) 342 | # model.save_weights(args.model_dir+'%s_wt.h5'%args.model_name) 343 | except: 344 | print('Do not exists model!') 345 | 346 | #===========QC=================# 347 | f44=h5py.File(args.valid_dir,'r') 348 | file_num1=len(f44['Y'][:]) 349 | validation_generator2=DataGenerator_PP1_S(f44,file_num1,gaus,batch_size=file_num1, 350 | classes=clas,time_shift=False, shuffle=False, 351 | indexes=None,new_label=None,augment=False) 352 | begin = datetime.datetime.now() 353 | gen=iter(validation_generator2) 354 | tmp=next(gen) 355 | test_data=tmp[0]['input'] 356 | test_label1=tmp[1]['pk'] 357 | test_label2=tmp[1]['po'] 358 | pred1,pred2=model.predict(test_data) 359 | end = datetime.datetime.now() 360 | print('Testing time:',end-begin) 361 | 362 | ## QC time 363 | tp_t=np.argmax(pred1[:,:,0],axis=1) 364 | tp_tr=np.argmax(test_label1[:,:,0],axis=1) 365 | dt_p=tp_tr-tp_t 366 | print('MAE: %.2f (s)'%np.mean(abs(dt_p*0.01))) 367 | 368 | if not os.path.exists(args.output_dir): 369 | os.mkdir(args.output_dir) 370 | 371 | ## save 372 | file_path=args.output_dir+'RES_%s.txt'%args.model_name 373 | f1=open(file_path,'w') 374 | f1.write('Testing data: \n') 375 | f1.write('Picking error: \n') 376 | f1.write('MAE: %.2f (s) \n'%np.mean(abs(dt_p*0.01))) 377 | f1.close() 378 | ## save picking 379 | np.savez(args.output_dir+'new_test_p_erro',tp_tr=tp_tr,tp_t=tp_t) 380 | 381 | font2 = {'family': 'Times New Roman','weight': 'normal', 'size': 18, } 382 | figure, ax = plt.subplots(figsize=(8,6)) 383 | plt.hist(dt_p*0.01,40,edgecolor='black') 384 | plt.tick_params(labelsize=15) 385 | labels = ax.get_xticklabels() + ax.get_yticklabels() 386 | [label.set_fontname('Times New Roman') for label in labels] 387 | plt.xlabel('tp_true-tp_pred (s)',font2) 388 | plt.ylabel('Frequency',font2) 389 | plt.savefig(args.output_dir+'P_error_new_testing.png',dpi=600) 390 | plt.show() 391 | 392 | # Recall Precision 393 | tp_up,tp_dn,tp_uw,ffp_up,ffp_dn,fp_up,fp_dn,fp_uw=confusion_matrix(pred2,test_label2,file_path,name='New testing') 394 | if args.save_result: 395 | np.savez(args.output_dir+'new_test_cm',res=np.array([tp_up,tp_dn,tp_uw,ffp_up,ffp_dn,fp_up,fp_dn,fp_uw],dtype=object)) 396 | 397 | if args.plot_figure: 398 | if not os.path.exists(args.output_dir+'atten_map/'): 399 | os.mkdir(args.output_dir+'atten_map/') 400 | 401 | self_model1=Model(inputs=model.input, outputs=model.get_layer('Atten').output) 402 | save_path=True 403 | for na in ['tp_up','tp_dn','tp_uw','ffp_up','ffp_dn','fp_up','fp_dn','fp_uw']: 404 | if na=='tp_up': 405 | cm=tp_up 406 | if na=='tp_dn': 407 | cm=tp_dn 408 | if na=='tp_uw': 409 | cm=tp_uw 410 | 411 | if na=='ffp_up': 412 | cm=ffp_up 413 | if na=='ffp_dn': 414 | cm=ffp_dn 415 | 416 | if na=='fp_up': 417 | cm=fp_up 418 | if na=='fp_dn': 419 | cm=fp_dn 420 | if na=='fp_uw': 421 | cm=fp_uw 422 | 423 | for k in range(1,args.num_plots,2): 424 | i=cm[k] 425 | _,test_wt0=self_model1.predict(test_data[i:i+1,:,:], verbose=1) 426 | wt=np.mean(test_wt0[0,:,:],axis=0).reshape(-1,1) 427 | wt_map=np.repeat(wt,50,axis=1).T/np.max(wt) 428 | wt_map=transform.resize(wt_map,(50,400)) 429 | #===========================# 430 | fl=np.argmax(pred2[i]) 431 | if fl==0: 432 | fl_p1='Up' 433 | if fl==1: 434 | fl_p1='Down' 435 | if fl==2: 436 | fl_p1='Unknown' 437 | 438 | fl=np.argmax(test_label2[i]) 439 | if fl==0: 440 | fl_tr='Up' 441 | if fl==1: 442 | fl_tr='Down' 443 | if fl==2: 444 | fl_tr='Unknown' 445 | #===============================# 446 | labelsize=22 447 | font2 = {'family' : 'Times New Roman','weight' : 'bold','size' : 20,} 448 | figure, axes = plt.subplots(3,1,figsize=(8,8)) 449 | axes[0].plot(test_data[i,:,0],'k') 450 | axes[0].set_xlim([0,400]) 451 | axes[0].set_ylim([-1,1]) 452 | axes[0].set_xticks(()) 453 | axes[0].tick_params(labelsize=labelsize) 454 | axes[0].set_ylabel('Amplitude',font2) 455 | axes[0].set_title('True: %s, Predicted: %s'%(fl_tr,fl_p1),font2) 456 | labels = axes[0].get_xticklabels() + axes[0].get_yticklabels() 457 | _=[label.set_fontname('Times New Roman') for label in labels] 458 | 459 | axes[1].plot(test_label1[i,:,0],'b',label='True') 460 | axes[1].plot(pred1[i,:,0],'r-.',label='Predicted') 461 | axes[1].set_xlim([0,400]) 462 | # axes[0].set_ylim([0,1]) 463 | axes[1].tick_params(labelsize=labelsize) 464 | axes[1].set_title('The Probability of P arrival picking',font2) 465 | labels = axes[1].get_xticklabels() + axes[1].get_yticklabels() 466 | _=[label.set_fontname('Times New Roman') for label in labels] 467 | axes[1].legend(prop=font2) 468 | axes[1].set_ylabel('Probability',font2) 469 | 470 | axes[2].imshow( wt_map ) 471 | axes[2].set_xlim([0,400]) 472 | plt.xlabel('Samples',font2) 473 | axes[2].tick_params(labelsize=labelsize) 474 | axes[2].set_title('Attention Map',font2) 475 | labels = axes[2].get_xticklabels() + axes[2].get_yticklabels() 476 | _=[label.set_fontname('Times New Roman') for label in labels] 477 | if save_path: 478 | plt.savefig(args.output_dir+'atten_map/%s_%d.png'%(na,i),dpi=600) 479 | plt.show() 480 | 481 | ##SCSN ORG 482 | f5=h5py.File(args.org_dir,'r') 483 | ## get data label 484 | test_data2,test_label22=gen_test_data_UD(f5,2353054,classes=args.clas,time_shift=False) #2353054 485 | 486 | pred21,pred22=model.predict(test_data2) 487 | 488 | ## In[] QC time 489 | tp_t=np.argmax(pred21[:,:,0],axis=1) 490 | dt_p=200-tp_t 491 | print('MAE: %.2f (s)'%np.mean(abs(dt_p*0.01))) 492 | 493 | # save pciking error 494 | f1=open(file_path,'a+') 495 | f1.write('SCSN data: \n') 496 | f1.write('Picking error: \n') 497 | f1.write('MAE: %.2f (s) \n'%np.mean(abs(dt_p*0.01))) 498 | f1.close() 499 | ## save picking 500 | if args.save_result: 501 | np.savez(args.output_dir+'scsn_test_p_erro',tp_t=tp_t) 502 | 503 | font2 = {'family' : 'Times New Roman', 504 | 'weight' : 'normal', 505 | 'size' : 18, 506 | } 507 | figure, ax = plt.subplots(figsize=(8,6)) 508 | plt.hist(dt_p*0.01,40) 509 | plt.tick_params(labelsize=15) 510 | labels = ax.get_xticklabels() + ax.get_yticklabels() 511 | [label.set_fontname('Times New Roman') for label in labels] 512 | plt.xlabel('tp_true-tp_pred (s)',font2) 513 | plt.ylabel('Frequency',font2) 514 | plt.savefig(args.output_dir+'scsn_time_error.png',dpi=600) 515 | plt.show() 516 | ## confusion matrix 517 | 518 | tp_up2,tp_dn2,tp_uw2,ffp_up2,ffp_dn2,fp_up2,fp_dn2,fp_uw2=confusion_matrix(pred22,test_label22,file_path,name='SCSN test') 519 | if args.save_result: 520 | np.savez(args.output_dir+'scsn_test_cm',res=np.array([tp_up2,tp_dn2,tp_uw2,ffp_up2,ffp_dn2,fp_up2,fp_dn2,fp_uw2],dtype=object)) 521 | 522 | elif args.mode=='predict': 523 | print(args.mode) 524 | ## load model 525 | try: 526 | model=load_model(args.model_dir+args.model_name+'.h5',custom_objects=SeqSelfAttention.get_custom_objects()) 527 | except: 528 | print('Do not exists model!') 529 | 530 | #============================# 531 | # example 1 532 | if args.exam=='example1': 533 | f5=h5py.File(args.org_dir,'r') 534 | test_data2,test_label22=gen_test_data_UD(f5,2353054,classes=args.clas,time_shift=False) #2353054 535 | pred_pick,pre_polatiry=model.predict(test_data2) 536 | #exmaple 2 537 | if args.exam=='example2': 538 | f44=h5py.File(args.valid_dir,'r') 539 | validation_generator_test=DataGenerator_PP1_S_test(f44,len(f44['Y']),gaus,batch_size=args.batch_size, classes=args.clas) 540 | pred_pick,pre_polatiry=model.predict_generator(validation_generator_test,verbose=1) 541 | if args.save_result: 542 | np.savez(args.output_dir+'%s_%s'%(args.exam,args.model_name),pred_pick=pred_pick,pre_polatiry=pre_polatiry) 543 | 544 | else: 545 | print("mode should be: train, test, or predict") 546 | 547 | # In[] 548 | 549 | if __name__ == '__main__': 550 | args = read_args() 551 | start_gpu(args) 552 | if args.mode=='train': 553 | set_configure(args) 554 | main(args) 555 | print('Finish !!!') 556 | 557 | 558 | 559 | 560 | 561 | -------------------------------------------------------------------------------- /utils_APPDNET.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Apr 22 13:25:19 2022 5 | 6 | @author: zhangj2 7 | """ 8 | # In[] 9 | import math 10 | import h5py 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | 14 | import os 15 | import numpy as np 16 | from scipy import signal 17 | import math 18 | try: 19 | from keras.utils import Sequence 20 | except: 21 | from tensorflow.keras.utils import Sequence 22 | 23 | from keras.utils.np_utils import to_categorical 24 | import matplotlib.pyplot as plt 25 | from obspy.signal.trigger import recursive_sta_lta,classic_sta_lta,trigger_onset 26 | from scipy import interpolate 27 | import random 28 | 29 | # In[] 30 | ##calculate cM 31 | def confusion_matrix(pred2,test_label2,file_path=None,name='data'): 32 | true_up=[i for i in range(len(test_label2)) if np.argmax(test_label2[i])==0] 33 | true_dn=[i for i in range(len(test_label2)) if np.argmax(test_label2[i])==1] 34 | true_uw=[i for i in range(len(test_label2)) if np.argmax(test_label2[i])==2] 35 | print('Up:%d;Down:%d;Unknown:%d'%(len(true_up),len(true_dn),len(true_uw))) 36 | 37 | pred_up=[i for i in range(len(pred2)) if np.argmax(pred2[i])==0] 38 | pred_dn=[i for i in range(len(pred2)) if np.argmax(pred2[i])==1] 39 | pred_uw=[i for i in range(len(pred2)) if np.argmax(pred2[i])==2] 40 | 41 | print('Up:%d;Down:%d;Unknown:%d'%(len(pred_up),len(pred_dn),len(pred_uw))) 42 | 43 | tp_up=[i for i in range(len(pred2)) if np.argmax(pred2[i])==0 and np.argmax(test_label2[i])==0 ] 44 | tp_dn=[i for i in range(len(pred2)) if np.argmax(pred2[i])==1 and np.argmax(test_label2[i])==1 ] 45 | tp_uw=[i for i in range(len(pred2)) if np.argmax(pred2[i])==2 and np.argmax(test_label2[i])==2 ] 46 | 47 | print('Up:%d;Down:%d;Unknown:%d'%(len(tp_up),len(tp_dn),len(tp_uw))) 48 | 49 | fp_up=[i for i in range(len(pred2)) if np.argmax(pred2[i])==0 and np.argmax(test_label2[i])!=0 ] 50 | fp_dn=[i for i in range(len(pred2)) if np.argmax(pred2[i])==1 and np.argmax(test_label2[i])!=1 ] 51 | fp_uw=[i for i in range(len(pred2)) if np.argmax(pred2[i])==2 and np.argmax(test_label2[i])!=2 ] 52 | 53 | print('Up:%d;Down:%d;Unknown:%d'%(len(fp_up),len(fp_dn),len(fp_uw))) 54 | # Up:96;Down:40;Unknow:112 55 | 56 | ffp_up=[i for i in range(len(pred2)) if np.argmax(pred2[i])==0 and np.argmax(test_label2[i])==1 ] 57 | ffp_dn=[i for i in range(len(pred2)) if np.argmax(pred2[i])==1 and np.argmax(test_label2[i])==0 ] 58 | # fp_uw=[i for i in range(len(pred2)) if np.argmax(pred2[i])==2 and np.argmax(test_label2[i])!=2 ] 59 | 60 | print('Up:%d;Down:%d;Unknown:%d'%(len(ffp_up),len(ffp_dn),len(fp_uw))) 61 | 62 | Pre_U = len(tp_up)/(len(tp_up)+len(ffp_up)) 63 | Pre_D = len(tp_dn)/(len(tp_dn)+len(ffp_dn)) 64 | Pre_K = len(tp_uw)/len(pred_uw) 65 | print('Pre_Up:%.2f;Pre_Down:%.2f;Pre_Unknown:%.2f'%(Pre_U,Pre_D,Pre_K)) 66 | 67 | Re_U = len(tp_up)/len(true_up) 68 | Re_D = len(tp_dn)/len(true_dn) 69 | try: 70 | Re_K = len(tp_uw)/len(true_uw) 71 | except: 72 | Re_K = np.inf 73 | 74 | print('Re_Up:%.2f;Re_Down:%.2f;Re_Unknown:%.2f'%(Re_U,Re_D,Re_K)) 75 | 76 | if file_path: 77 | f1 = open(file_path,'a+') 78 | f1.write('======================\n') 79 | f1.write(name+'\n') 80 | f1.write('True:\n') 81 | f1.write('U; D; K\n') 82 | f1.write('%d %d %d\n'%(len(true_up),len(true_dn),len(true_uw))) 83 | f1.write('Pred:\n') 84 | f1.write('U; D; K\n') 85 | f1.write('%d %d %d\n'%(len(pred_up),len(pred_dn),len(pred_uw))) 86 | f1.write('TP_U; TP_D; TP_K\n') 87 | f1.write('%d %d %d\n'%(len(tp_up),len(tp_dn),len(tp_uw))) 88 | f1.write('FP_U; FP_D; FP_K\n') 89 | f1.write('%d %d %d\n'%(len(fp_up),len(fp_dn),len(fp_uw))) 90 | f1.write('FFP_U; FFP_D \n') 91 | f1.write('%d %d\n' %(len(ffp_up),len(ffp_dn))) 92 | f1.write('Pre_U; Pre_D; Pre_K\n') 93 | f1.write('%.2f %.2f %.2f\n'%(Pre_U,Pre_D,Pre_K)) 94 | f1.write('Re_U; Re_D; Re_K\n') 95 | f1.write('%.2f %.2f %.2f\n'%(Re_U,Re_D,Re_K)) 96 | f1.write('======================\n') 97 | f1.close() 98 | 99 | return tp_up,tp_dn,tp_uw,ffp_up,ffp_dn,fp_up,fp_dn,fp_uw 100 | 101 | 102 | ## generate gauss function 103 | def gaussian(x, sigma, u): 104 | y = np.exp(-(x - u) ** 2 / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi)) 105 | return y/np.max(abs(y)) 106 | 107 | # In[] 108 | class DataGenerator_PP1_S(Sequence): 109 | 110 | def __init__(self, L,file_num,gaus, 111 | batch_size=128, 112 | classes=3, 113 | time_shift=False, 114 | shuffle=True,indexes=None,new_label=None,augment=False): 115 | """ 116 | # Arguments 117 | --- 118 | file_num: number of files . 119 | batch_size: . """ 120 | self.L = L 121 | self.batch_size = batch_size 122 | self.file_num=file_num 123 | self.gaus=gaus 124 | if indexes is None: 125 | self.indexes=np.arange(file_num) 126 | else: 127 | self.indexes=indexes 128 | if new_label is None: 129 | self.flag=0 130 | else: 131 | self.flag=1 132 | self.new_label=new_label 133 | 134 | self.shuffle = shuffle 135 | self.classes= classes 136 | self.time_shift= time_shift 137 | self.augment=augment 138 | 139 | def __len__(self): 140 | """return: steps num of one epoch. """ 141 | 142 | return self.file_num// self.batch_size 143 | 144 | def __getitem__(self, index): 145 | """Gets the `index-th` batch. 146 | --- 147 | # Arguments 148 | index: position of the batch in the Sequence. 149 | # Returns 150 | A batch data. """ 151 | 152 | 153 | # get batch data inds. 154 | batch_inds = self.indexes[index * 155 | self.batch_size:(index+1)*self.batch_size] 156 | # read batch data 157 | X, Y1 ,Y2= self._read_data(batch_inds) 158 | return ({'input': X}, {'pk':Y1, 'po':Y2}) 159 | 160 | def on_epoch_end(self): 161 | """shuffle data after one epoch. """ 162 | if self.shuffle == True: 163 | np.random.shuffle(self.indexes) 164 | 165 | def _add_noise(self,sig,db,k): 166 | n=6000 167 | np.random.seed(k) 168 | noise=np.random.normal(size=(3,n)) 169 | s2=np.sum(sig**2)/len(sig) 170 | n2=np.sum(noise[2,:]**2)/len(noise[2,:]) 171 | a=(s2/n2/(10**(db/10)))**(0.5) 172 | noise=noise*a 173 | return noise 174 | 175 | def _bp_filter(self,data,n,n1,n2,dt): 176 | wn1=n1*2*dt 177 | wn2=n2*2*dt 178 | b, a = signal.butter(n, [wn1,wn2], 'bandpass') 179 | filtedData = signal.filtfilt(b, a, data) 180 | return filtedData 181 | 182 | def _normal3(self,data): 183 | data2=np.zeros((data.shape[0],data.shape[1],data.shape[2])) 184 | for i in range(data.shape[0]): 185 | data1=data[i,:,:] 186 | x_max=np.max(abs(data1)) 187 | if x_max!=0.0: 188 | data2[i,:,:]=data1/x_max 189 | return data2 190 | def _taper(self,data,n,N): 191 | nn=len(data) 192 | if n==1: 193 | w=math.pi/N 194 | F0=0.5 195 | F1=0.5 196 | elif n==2: 197 | w=math.pi/N 198 | F0=0.54 199 | F1=0.46 200 | else: 201 | w=math.pi/N/2 202 | F0=1 203 | F1=1 204 | win=np.ones((nn,1)) 205 | for i in range(N): 206 | win[i]=(F0-F1*math.cos(w*(i-1))) 207 | win1=np.flipud(win) 208 | 209 | data1=data*win.reshape(win.shape[0],) 210 | data1=data1*win1.reshape(win1.shape[0],) 211 | return data1 212 | 213 | def _encoder(self,lab,classes=2): 214 | inx=[i for i in range(len(lab)) if lab[i]>0 ] 215 | lab[inx]=1 216 | return to_categorical(lab,classes) 217 | 218 | 219 | def _read_data(self, batch_inds): 220 | """Read a batch data. 221 | --- 222 | # Arguments 223 | batch_files: the file of batch data. 224 | 225 | # Returns 226 | data: (batch_size, (5000,1,3)). 227 | label: (batch_size, (5000,1,num)). """ 228 | #------------------------# 229 | np.random.seed(0) 230 | data=[] 231 | label1=[] 232 | label2=[] 233 | 234 | for k in batch_inds: 235 | # L 236 | dat=self.L['X'][k] 237 | # pt1=300 238 | pt1=int((self.L['Y'][k])*100) 239 | lab2=self.L['fm'][k] 240 | if self.flag==1: 241 | lab2=self.new_label[k] 242 | 243 | try: 244 | lab1=np.zeros(np.size(dat,0),) 245 | lab1[pt1-50:pt1+50]=self.gaus 246 | except: 247 | print(np.size(dat,0),k,pt1,int((self.L['Y'][k])*100)) 248 | 249 | if self.time_shift : 250 | for time_sf in range(-300,-100,10): 251 | data.append(dat[pt1+time_sf:pt1+time_sf+400]) 252 | label1.append(lab1[pt1+time_sf:pt1+time_sf+400]) 253 | 254 | lab3=to_categorical(lab2,self.classes) 255 | label2.append(lab3) 256 | 257 | if self.augment: 258 | data.append(-dat[pt1+time_sf:pt1+time_sf+400]) 259 | label1.append(lab1[pt1+time_sf:pt1+time_sf+400]) 260 | if lab2==2: 261 | lab3=to_categorical(2,self.classes) 262 | if lab2==1: 263 | lab3=to_categorical(0,self.classes) 264 | if lab2==0: 265 | lab3=to_categorical(1,self.classes) 266 | label2.append(lab3) 267 | else: 268 | data.append(dat[pt1-200:pt1+200]) 269 | label1.append(lab1[pt1-200:pt1+200]) 270 | lab2=to_categorical(lab2,self.classes) 271 | label2.append(lab2) 272 | 273 | # L1 274 | if self.augment: 275 | 276 | num=random.randint(2,4) # zj 277 | x=np.linspace(0,600,600) 278 | f1=interpolate.interp1d(x,dat,kind='cubic') 279 | x1=np.linspace(0,600,num*600) 280 | 281 | dat=f1(x1) 282 | # pt1=300*num 283 | pt1=int((self.L['Y'][k]*num)*100) 284 | lab2=self.L['fm'][k] 285 | if self.flag==1: 286 | lab2=self.new_label[k] 287 | 288 | lab1=np.zeros(np.size(dat,0),) 289 | lab1[pt1-50:pt1+50]=self.gaus 290 | 291 | if self.time_shift : 292 | for time_sf in range(-300,-100,10): 293 | data.append(dat[pt1+time_sf:pt1+time_sf+400]) 294 | label1.append(lab1[pt1+time_sf:pt1+time_sf+400]) 295 | 296 | lab3=to_categorical(lab2,self.classes) 297 | label2.append(lab3) 298 | 299 | if self.augment: 300 | data.append(-dat[pt1+time_sf:pt1+time_sf+400]) 301 | label1.append(lab1[pt1+time_sf:pt1+time_sf+400]) 302 | if lab2==2: 303 | lab3=to_categorical(2,self.classes) 304 | if lab2==1: 305 | lab3=to_categorical(0,self.classes) 306 | if lab2==0: 307 | lab3=to_categorical(1,self.classes) 308 | label2.append(lab3) 309 | 310 | else: 311 | data.append(dat[pt1-200:pt1+200]) 312 | label1.append(lab1[pt1-200:pt1+200]) 313 | lab2=to_categorical(lab2,self.classes) 314 | label2.append(lab2) 315 | 316 | data=np.expand_dims(np.array(data) ,axis=2) 317 | label1=np.array(label1) 318 | label2=np.array(label2) 319 | return data, label1.reshape(-1,400,1) , label2 320 | # In[] 321 | class DataGenerator_PP1_S_test(Sequence): 322 | 323 | def __init__(self, L,file_num,gaus, 324 | batch_size=128, 325 | classes=3, 326 | time_shift=False, 327 | shuffle=True,indexes=None): 328 | """ 329 | # Arguments 330 | --- 331 | file_num: number of files . 332 | batch_size: . """ 333 | self.L = L 334 | self.batch_size = batch_size 335 | self.file_num=file_num 336 | self.gaus=gaus 337 | if indexes is None: 338 | self.indexes=np.arange(file_num) 339 | else: 340 | self.indexes=indexes 341 | 342 | self.shuffle = shuffle 343 | self.classes= classes 344 | self.time_shift= time_shift 345 | 346 | def __len__(self): 347 | """return: steps num of one epoch. """ 348 | 349 | return self.file_num// self.batch_size+1 350 | 351 | def __getitem__(self, index): 352 | """Gets the `index-th` batch. 353 | --- 354 | # Arguments 355 | index: position of the batch in the Sequence. 356 | # Returns 357 | A batch data. """ 358 | 359 | 360 | # get batch data inds. 361 | batch_inds = self.indexes[index * 362 | self.batch_size:(index+1)*self.batch_size] 363 | # read batch data 364 | X, Y1 ,Y2= self._read_data(batch_inds) 365 | return ({'input': X}, {'pk':Y1, 'po':Y2}) 366 | 367 | def on_epoch_end(self): 368 | """shuffle data after one epoch. """ 369 | if self.shuffle == True: 370 | np.random.shuffle(self.indexes) 371 | 372 | def _add_noise(self,sig,db,k): 373 | n=6000 374 | np.random.seed(k) 375 | noise=np.random.normal(size=(3,n)) 376 | s2=np.sum(sig**2)/len(sig) 377 | n2=np.sum(noise[2,:]**2)/len(noise[2,:]) 378 | a=(s2/n2/(10**(db/10)))**(0.5) 379 | noise=noise*a 380 | return noise 381 | 382 | def _bp_filter(self,data,n,n1,n2,dt): 383 | wn1=n1*2*dt 384 | wn2=n2*2*dt 385 | b, a = signal.butter(n, [wn1,wn2], 'bandpass') 386 | filtedData = signal.filtfilt(b, a, data) 387 | return filtedData 388 | 389 | def _normal3(self,data): 390 | data2=np.zeros((data.shape[0],data.shape[1],data.shape[2])) 391 | for i in range(data.shape[0]): 392 | data1=data[i,:,:] 393 | x_max=np.max(abs(data1)) 394 | if x_max!=0.0: 395 | data2[i,:,:]=data1/x_max 396 | return data2 397 | def _taper(self,data,n,N): 398 | nn=len(data) 399 | if n==1: 400 | w=math.pi/N 401 | F0=0.5 402 | F1=0.5 403 | elif n==2: 404 | w=math.pi/N 405 | F0=0.54 406 | F1=0.46 407 | else: 408 | w=math.pi/N/2 409 | F0=1 410 | F1=1 411 | win=np.ones((nn,1)) 412 | for i in range(N): 413 | win[i]=(F0-F1*math.cos(w*(i-1))) 414 | win1=np.flipud(win) 415 | 416 | data1=data*win.reshape(win.shape[0],) 417 | data1=data1*win1.reshape(win1.shape[0],) 418 | return data1 419 | 420 | def _encoder(self,lab,classes=2): 421 | inx=[i for i in range(len(lab)) if lab[i]>0 ] 422 | lab[inx]=1 423 | return to_categorical(lab,classes) 424 | 425 | 426 | def _read_data(self, batch_inds): 427 | """Read a batch data. 428 | --- 429 | # Arguments 430 | batch_files: the file of batch data. 431 | 432 | # Returns 433 | data: (batch_size, (5000,1,3)). 434 | label: (batch_size, (5000,1,num)). """ 435 | #------------------------# 436 | np.random.seed(0) 437 | data=[] 438 | label1=[] 439 | label2=[] 440 | 441 | for k in batch_inds: 442 | # L 443 | dat=self.L['X'][k] 444 | pt1=int((self.L['Y'][k])*100) 445 | lab2=self.L['fm'][k] 446 | 447 | try: 448 | lab1=np.zeros(np.size(dat,0),) 449 | lab1[pt1-50:pt1+50]=self.gaus 450 | except: 451 | print(np.size(dat,0),k,pt1,int((self.L['Y'][k])*100)) 452 | 453 | if self.time_shift : 454 | for time_sf in range(-300,-100,10): 455 | data.append(dat[pt1+time_sf:pt1+time_sf+400]) 456 | label1.append(lab1[pt1+time_sf:pt1+time_sf+400]) 457 | 458 | lab3=to_categorical(lab2,self.classes) 459 | label2.append(lab3) 460 | 461 | else: 462 | data.append(dat[pt1-200:pt1+200]) 463 | label1.append(lab1[pt1-200:pt1+200]) 464 | lab2=to_categorical(lab2,self.classes) 465 | label2.append(lab2) 466 | 467 | data=np.expand_dims(np.array(data) ,axis=2) 468 | label1=np.array(label1) 469 | label2=np.array(label2) 470 | return data, label1.reshape(-1,400,1) , label2 471 | 472 | # In[] 473 | def gen_test_data_UD(L,batch_inds,classes=3,time_shift=False,dl=False): 474 | np.random.seed(0) 475 | data=[] 476 | label=[] 477 | info=[] 478 | 479 | for k in range(batch_inds): 480 | dat=L['X'][k] 481 | lab=L['Y'][k] 482 | 483 | if dl: 484 | dist=L['dist'][k] 485 | evids=L['evids'][k] 486 | mag=L['mag'][k] 487 | sncls=L['sncls'][k] 488 | snr=L['snr'][k] 489 | sncls=str(sncls, encoding='utf-8') 490 | 491 | if time_shift : 492 | for time_sf in range(0,200,10): 493 | data.append(dat[time_sf:time_sf+400]) 494 | lab2=to_categorical(lab,classes) 495 | label.append(lab2) 496 | if dl: 497 | info.append([dist,snr,mag,sncls,evids]) 498 | 499 | else: 500 | data.append(dat[100:500]) 501 | lab2=to_categorical(lab,classes) 502 | label.append(lab2) 503 | if dl: 504 | info.append([dist,evids,snr,sncls]) 505 | data=np.expand_dims(np.array(data),axis=2) 506 | label=np.array(label) 507 | 508 | if dl: 509 | info.append([dist,evids,snr,sncls]) 510 | return data, label, info 511 | else: 512 | return data, label 513 | 514 | 515 | # In[] 516 | 517 | def plot_loss(history_callback,save_path=None,model='model'): 518 | font2 = {'family' : 'Times New Roman', 519 | 'weight' : 'normal', 520 | 'size' : 18, 521 | } 522 | 523 | history_dict=history_callback.history 524 | 525 | loss_value=history_dict['loss'] 526 | val_loss_value=history_dict['val_loss'] 527 | 528 | loss_pk=history_dict['pk_loss'] 529 | val_loss_pk=history_dict['val_pk_loss'] 530 | 531 | loss_po=history_dict['po_loss'] 532 | val_loss_po=history_dict['val_po_loss'] 533 | 534 | try: 535 | acc_pk=history_dict['pk_accuracy'] 536 | val_acc_pk=history_dict['val_pk_accuracy'] 537 | acc_po=history_dict['po_accuracy'] 538 | val_acc_po=history_dict['val_po_accuracy'] 539 | 540 | except: 541 | acc_value=history_dict['accuracy'] 542 | val_acc_value=history_dict['val_accuracy'] 543 | 544 | epochs=range(1,len(acc_pk)+1) 545 | if not save_path is None: 546 | np.savez(save_path+'acc_loss_%s'%model, 547 | loss=loss_value,val_loss=val_loss_value, 548 | loss_pk=loss_pk,val_loss_pk=val_loss_pk, 549 | loss_po=loss_po,val_loss_po=val_loss_po, 550 | acc_pk=acc_pk,val_acc_pk=val_acc_pk, 551 | acc_po=acc_po,val_acc_po=val_acc_po) 552 | 553 | # acc picking 554 | figure, ax = plt.subplots(figsize=(8,6)) 555 | plt.plot(epochs,acc_pk,'b',label='Training acc of picking') 556 | plt.plot(epochs,val_acc_pk,'r',label='Validation acc of picking') 557 | plt.tick_params(labelsize=15) 558 | labels = ax.get_xticklabels() + ax.get_yticklabels() 559 | [label.set_fontname('Times New Roman') for label in labels] 560 | plt.xlabel('Epochs',font2) 561 | plt.ylabel('Accuracy',font2) 562 | plt.legend(prop=font2,loc='lower right') 563 | if not save_path is None: 564 | plt.savefig(save_path+'ACC_PK_%s.png'%model,dpi=600) 565 | plt.show() 566 | 567 | # acc polarity 568 | figure, ax = plt.subplots(figsize=(8,6)) 569 | plt.plot(epochs,acc_po,'b--',label='Training acc of polarity') 570 | plt.plot(epochs,val_acc_po,'r--',label='Validation acc of polarity') 571 | plt.tick_params(labelsize=15) 572 | labels = ax.get_xticklabels() + ax.get_yticklabels() 573 | [label.set_fontname('Times New Roman') for label in labels] 574 | plt.xlabel('Epochs',font2) 575 | plt.ylabel('Accuracy',font2) 576 | plt.legend(prop=font2,loc='lower right') 577 | if not save_path is None: 578 | plt.savefig(save_path+'ACC_PO_%s.png'%model,dpi=600) 579 | plt.show() 580 | 581 | # loss 582 | figure, ax = plt.subplots(figsize=(8,6)) 583 | plt.plot(epochs,loss_value,'b',label='Training loss') 584 | plt.plot(epochs,val_loss_value,'r',label='Validation loss') 585 | plt.tick_params(labelsize=15) 586 | labels = ax.get_xticklabels() + ax.get_yticklabels() 587 | [label.set_fontname('Times New Roman') for label in labels] 588 | plt.xlabel('Epochs',font2) 589 | plt.ylabel('Loss',font2) 590 | plt.legend(prop=font2) 591 | if not save_path is None: 592 | plt.savefig(save_path+'Loss_%s.png'%model,dpi=600) 593 | plt.show() 594 | 595 | # loss Picking 596 | figure, ax = plt.subplots(figsize=(8,6)) 597 | plt.plot(epochs,loss_pk,'b--',label='Training loss of picking') 598 | plt.plot(epochs,val_loss_pk,'r--',label='Validation loss of picking') 599 | plt.tick_params(labelsize=15) 600 | labels = ax.get_xticklabels() + ax.get_yticklabels() 601 | [label.set_fontname('Times New Roman') for label in labels] 602 | plt.xlabel('Epochs',font2) 603 | plt.ylabel('Loss',font2) 604 | plt.legend(prop=font2) 605 | if not save_path is None: 606 | plt.savefig(save_path+'Loss_PK_%s.png'%model,dpi=600) 607 | plt.show() 608 | 609 | # loss Polarity 610 | figure, ax = plt.subplots(figsize=(8,6)) 611 | plt.plot(epochs,loss_po,'b-.',label='Training loss of polarity') 612 | plt.plot(epochs,val_loss_po,'r-.',label='Validation loss of polarity') 613 | 614 | plt.tick_params(labelsize=15) 615 | labels = ax.get_xticklabels() + ax.get_yticklabels() 616 | [label.set_fontname('Times New Roman') for label in labels] 617 | plt.xlabel('Epochs',font2) 618 | plt.ylabel('Loss',font2) 619 | plt.legend(prop=font2) 620 | if not save_path is None: 621 | plt.savefig(save_path+'Loss_PO_%s.png'%model,dpi=600) 622 | plt.show() --------------------------------------------------------------------------------