├── .gitignore ├── Case_article.py ├── Case_article_test1D.py ├── Case_article_test2D.py ├── Case_article_testreal.py ├── Case_small.py ├── Case_small_test.py ├── Cases_define.py ├── Dockerfile ├── LICENSE ├── README.md ├── plot_example.py ├── plot_loss.py ├── plot_prediction.py ├── plot_real_synth.py ├── realdata └── Process_realdata.py ├── reproduce_results.py ├── requirements.txt ├── semblance └── nmo_correction.py └── vrmslearn ├── Inputqueue.py ├── ModelGenerator.py ├── ModelParameters.py ├── RCNN.py ├── SeismicGenerator.py ├── Tester.py └── Trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | scratch* 3 | __pycache__ 4 | logs* 5 | *.png -------------------------------------------------------------------------------- /Case_article.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Script to perform data creation and training for the main case presented in 5 | the article. For a smaller training set, see Case_small.py 6 | """ 7 | 8 | from vrmslearn.ModelParameters import ModelParameters 9 | from Cases_define import Case_article 10 | from vrmslearn.SeismicGenerator import SeismicGenerator, generate_dataset 11 | from vrmslearn.Trainer import Trainer 12 | from vrmslearn.RCNN import RCNN 13 | import os 14 | import argparse 15 | import tensorflow as tf 16 | import fnmatch 17 | 18 | if __name__ == "__main__": 19 | 20 | # Initialize argument parser 21 | parser = argparse.ArgumentParser() 22 | 23 | # Add arguments to parse for training 24 | parser.add_argument( 25 | "--nthread", 26 | type=int, 27 | default=1, 28 | help="Number of threads for data creation" 29 | ) 30 | parser.add_argument( 31 | "--nthread_read", 32 | type=int, 33 | default=1, 34 | help="Number of threads used as input producer" 35 | ) 36 | parser.add_argument( 37 | "--logdir", 38 | type=str, 39 | default="./logs", 40 | help="Directory in which to store the checkpoints" 41 | ) 42 | parser.add_argument( 43 | "--training", 44 | type=int, 45 | default=1, 46 | help="1: training only, 0: create dataset only, 2: training+dataset" 47 | ) 48 | parser.add_argument( 49 | "--workdir", 50 | type=str, 51 | default="./seiscl_workdir", 52 | help="name of SeisCL working directory " 53 | ) 54 | parser.add_argument( 55 | "--lr", 56 | type=float, 57 | default=0.0008, 58 | help="learning rate " 59 | ) 60 | parser.add_argument( 61 | "--eps", 62 | type=float, 63 | default=1e-5, 64 | help="epsilon for adadelta" 65 | ) 66 | parser.add_argument( 67 | "--batchsize", 68 | type=int, 69 | default=40, 70 | help="size of the batches" 71 | ) 72 | parser.add_argument( 73 | "--beta1", 74 | type=float, 75 | default=0.9, 76 | help="beta1 for adadelta" 77 | ) 78 | parser.add_argument( 79 | "--beta2", 80 | type=float, 81 | default=0.98, 82 | help="beta2 for adadelta" 83 | ) 84 | parser.add_argument( 85 | "--nmodel", 86 | type=int, 87 | default=1, 88 | help="Number of models to train" 89 | ) 90 | parser.add_argument( 91 | "--noise", 92 | type=int, 93 | default=1, 94 | help="1: Add noise to the data" 95 | ) 96 | parser.add_argument( 97 | "--use_peepholes", 98 | type=int, 99 | default=1, 100 | help="1: Use peephole version of LSTM" 101 | ) 102 | 103 | 104 | # Parse the input for training parameters 105 | args, unparsed = parser.parse_known_args() 106 | 107 | savepath = "./dataset_article" 108 | logdir = args.logdir 109 | nthread = args.nthread 110 | batch_size = args.batchsize 111 | 112 | """ 113 | 
_______________________Define the parameters ______________________ 114 | """ 115 | pars = Case_article(noise=args.noise) 116 | 117 | """ 118 | _______________________Generate the dataset_____________________________ 119 | """ 120 | gen = SeismicGenerator(model_parameters=pars) 121 | 122 | pars.num_layers = 0 123 | dhmins = [5] 124 | layer_num_mins = [5, 10, 30, 50] 125 | nexamples = 10000 126 | 127 | if not os.path.isdir(savepath): 128 | os.mkdir(savepath) 129 | 130 | if args.training != 1: 131 | for dhmin in dhmins: 132 | for layer_num_min in layer_num_mins: 133 | pars.layer_dh_min = dhmin 134 | pars.layer_num_min = layer_num_min 135 | this_savepath = (savepath 136 | + "/dhmin%d" % dhmin 137 | + "_layer_num_min%d" % layer_num_min) 138 | generate_dataset(pars=pars, 139 | savepath=this_savepath, 140 | nthread=args.nthread, 141 | nexamples=nexamples, 142 | workdir=args.workdir) 143 | 144 | 145 | """ 146 | ___________________________Do the training _____________________________ 147 | 148 | We define 3 stages for inversion, with different alpha, beta gamma in the 149 | loss function: 150 | 1st stage: alpha = 0, beta=1 and gamma=0: we train for reflection 151 | identification 152 | 2nd stage: alpha = 0.2, beta=0.1 and gamma=0.1: we train for reflection 153 | identification and vrms, with regularization on vrms time 154 | derivative (alpha) et higher weights on vrms at reflections 155 | arrival times (gamma) 156 | 3rd stage: alpha = 0.02, beta=0.02 and gamma=0.1, we add weight to vrms 157 | 158 | """ 159 | schedules = [[0.0, 0.95, 0, 0, 0], 160 | [0.05, 0.1, 0, 0, 0], 161 | [0.05, 0.1, 0, 0.35, 0.05], 162 | [0.01, 0.01, 0, 0, 0], 163 | [0, 0, 0, 0.95, 0.05]] 164 | niters = [1000, 10000, 10000, 1000, 1000] 165 | if args.training != 0: 166 | for nmod in range(args.nmodel): 167 | restore_from = None 168 | npass = 0 169 | for ii, schedule in enumerate(schedules): 170 | this_savepath = [] 171 | for layer_num_min in layer_num_mins: 172 | for dhmin in dhmins: 173 | this_savepath.append(savepath 174 | + "/dhmin%d" % dhmin 175 | + "_layer_num_min%d" % layer_num_min) 176 | this_logdir = (logdir 177 | + "%d" % nmod 178 | + "/%d" % npass 179 | + "_schedule%d" % ii 180 | + "_lr%f_eps_%f" % (args.lr, args.eps) 181 | + "_beta1%f" % args.beta1 182 | + "_beta2%f" % args.beta2 183 | + "_batch_size_%d" % batch_size) 184 | 185 | lastfile = this_logdir + 'model.ckpt-' + str(niters[ii]) + '*' 186 | 187 | try: 188 | isckpt = fnmatch.filter(os.listdir(this_logdir), 189 | 'model.ckpt-' + str(niters[ii]) + '*') 190 | except FileNotFoundError: 191 | isckpt =[] 192 | 193 | if not isckpt: 194 | print(this_logdir) 195 | pars.layer_dh_min = dhmin 196 | pars.layer_num_min = layer_num_min 197 | seismic_gen = SeismicGenerator(model_parameters=pars) 198 | nn = RCNN(input_size=seismic_gen.image_size, 199 | batch_size=batch_size, 200 | alpha=schedule[0], 201 | beta=schedule[1], 202 | gamma=schedule[2], 203 | zeta=schedule[3], 204 | omega=schedule[4], 205 | use_peepholes=args.use_peepholes) 206 | 207 | if layer_num_min == layer_num_mins[0] and dhmin == dhmins[0]: 208 | learning_rate = args.lr 209 | else: 210 | learning_rate = args.lr/8 211 | if ii>2: 212 | learning_rate = args.lr/128 213 | # Optimize only last layers during last schedule 214 | if ii == 4: 215 | with nn.graph.as_default(): 216 | var_to_minimize = tf.trainable_variables(scope='rnn_vint') 217 | var_to_minimize.append(tf.trainable_variables(scope='Decode_vint')) 218 | else: 219 | var_to_minimize = None 220 | 221 | trainer = Trainer(NN=nn, 222 | 
data_generator=seismic_gen, 223 | checkpoint_dir=this_logdir, 224 | learning_rate=learning_rate, 225 | beta1=args.beta1, 226 | beta2=args.beta2, 227 | epsilon=args.eps, 228 | var_to_minimize=var_to_minimize) 229 | trainer.train_model(niter=niters[ii], 230 | savepath=this_savepath, 231 | restore_from=restore_from, 232 | thread_read=args.nthread_read) 233 | restore_from = this_logdir + '/model.ckpt-' + str(niters[ii]) 234 | npass += 1 235 | -------------------------------------------------------------------------------- /Case_article_test1D.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Create the test dataset for Case article, performs the testing and plot results 5 | """ 6 | from vrmslearn.ModelParameters import ModelParameters 7 | from vrmslearn.SeismicGenerator import SeismicGenerator, generate_dataset 8 | from vrmslearn.ModelGenerator import interval_velocity_time 9 | from vrmslearn.Tester import Tester 10 | from vrmslearn.RCNN import RCNN 11 | from Cases_define import Case_article 12 | import os 13 | import argparse 14 | import numpy as np 15 | from plot_prediction import plot_predictions_semb3 16 | import h5py as h5 17 | import fnmatch 18 | from scipy.stats import mode 19 | import sys 20 | 21 | 22 | def get_rms(name, masks, vint_pred, vint, vrms_pred, vrms, ref_pred, ref): 23 | print(name) 24 | masks = np.array(masks) 25 | nsamples = np.sum(masks == 1) 26 | vint_pred = np.array(vint_pred) 27 | vint = np.array(vint) 28 | vint_rmse = np.sqrt(np.sum(masks * (vint - vint_pred)**2) / nsamples) 29 | print("Interval velocity RMSE: %f m/s" % vint_rmse) 30 | 31 | vrms_pred = np.array(vrms_pred) 32 | vrms = np.array(vrms) 33 | vrms_rmse = np.sqrt(np.sum(masks * (vrms - vrms_pred) ** 2) / nsamples) 34 | print("RMS velocity RMSE: %f m/s" % vrms_rmse) 35 | 36 | ref_pred = np.array(ref_pred) 37 | ref = np.array(ref) 38 | nsamples = ref.flatten().shape[0] 39 | true_pos = np.sum(((ref - ref_pred) == 0) * (ref == 1)) / nsamples 40 | true_neg = np.sum(((ref - ref_pred) == 0) * (ref == 0)) / nsamples 41 | false_pos = np.sum((ref - ref_pred) == -1) / nsamples 42 | false_neg = np.sum((ref - ref_pred) == 1) / nsamples 43 | 44 | print("True positive: %f, True negative: %f, False positive %f " 45 | "False negative: %f" % (true_pos, true_neg, false_pos, false_neg)) 46 | 47 | print("") 48 | 49 | return vint_rmse, vrms_rmse, true_pos, true_neg, false_pos, false_neg 50 | 51 | 52 | if __name__ == "__main__": 53 | 54 | # Initialize argument parser 55 | parser = argparse.ArgumentParser() 56 | 57 | # Add arguments to parse for training 58 | parser.add_argument( 59 | "--nthread", 60 | type=int, 61 | default=1, 62 | help="Number of threads per gpus for data creation" 63 | ) 64 | parser.add_argument( 65 | "--logdir", 66 | type=str, 67 | default="logs/model.ckpt-5000", 68 | help="Checkpoint filename for which to predict" 69 | ) 70 | parser.add_argument( 71 | "--testing", 72 | type=int, 73 | default=3, 74 | help="1: testing only, 0: create dataset only, 2: testing+dataset, 3: ploting only" 75 | ) 76 | parser.add_argument( 77 | "--workdir", 78 | type=str, 79 | default="./seiscl_workdir", 80 | help="name of SeisCL working directory " 81 | ) 82 | parser.add_argument( 83 | "--dataset_path", 84 | type=str, 85 | default="./dataset_article/test", 86 | help="name of the test dataset directory " 87 | ) 88 | parser.add_argument( 89 | "--niter", 90 | type=int, 91 | default=1000, 92 | help="Iteration number of the checkpoint 
file" 93 | ) 94 | 95 | 96 | # Parse the input for training parameters 97 | args, unparsed = parser.parse_known_args() 98 | 99 | dirs = [] 100 | for dir1 in os.listdir('./'): 101 | if os.path.isdir(dir1): 102 | for dir2 in os.listdir(dir1): 103 | path2 = os.path.join(dir1, dir2) 104 | if os.path.isdir(path2): 105 | dirs.append(path2) 106 | 107 | logdirs = fnmatch.filter(dirs, args.logdir) 108 | logdirs.sort() 109 | logdirs = [d + "/model.ckpt-" + str(args.niter) for d in logdirs] 110 | print("Found %d log directories to test" % len(logdirs), flush=True) 111 | for logdir in logdirs: 112 | print(logdir, flush=True) 113 | 114 | """ 115 | __________________Define the parameters for Case Article________________ 116 | """ 117 | pars = Case_article() 118 | 119 | """ 120 | __________________Generate the dataset______________________ 121 | """ 122 | pars.num_layers = 0 123 | dhmins = [5] 124 | layer_num_mins = [5, 10, 30, 50] 125 | nexamples = 400 126 | 127 | if not os.path.isdir(args.dataset_path): 128 | os.mkdir(args.dataset_path) 129 | 130 | n = 1 131 | if args.testing != 1 and args.testing != 3: 132 | for dhmin in dhmins: 133 | for layer_num_min in layer_num_mins: 134 | this_savepath = (args.dataset_path 135 | + "/dhmin" + str(dhmin) 136 | + "layer_num_min" + str(layer_num_min)) 137 | pars.layer_dh_min = dhmin 138 | pars.layer_num_min = layer_num_min 139 | generate_dataset(pars=pars, 140 | savepath=this_savepath, 141 | nthread=1, 142 | nexamples=nexamples, 143 | workdir=args.workdir, 144 | seed=n) 145 | n+=1 146 | 147 | if args.testing==0: 148 | sys.exit() 149 | 150 | """ 151 | ___________________________Do the testing ______________________________ 152 | """ 153 | seismic_gen = SeismicGenerator(model_parameters=pars) 154 | nn = RCNN(input_size=seismic_gen.image_size, 155 | batch_size=2) 156 | tester = Tester(NN=nn, data_generator=seismic_gen) 157 | toeval = [nn.output_ref, nn.output_vrms, nn.output_vint] 158 | toeval_names = ["ref", "vrms", "vint"] 159 | vint_rmse_all = 0 160 | vrms_rmse_all = 0 161 | true_pos_all = 0 162 | true_neg_all = 0 163 | false_pos_all = 0 164 | false_neg_all = 0 165 | 166 | for dhmin in dhmins: 167 | for layer_num_min in layer_num_mins: 168 | vint = [None for _ in range(len(logdirs))] 169 | vint_pred = [None for _ in range(len(logdirs))] 170 | vrms = [None for _ in range(len(logdirs))] 171 | vrms_pred = [None for _ in range(len(logdirs))] 172 | ref = [None for _ in range(len(logdirs))] 173 | ref_pred = [None for _ in range(len(logdirs))] 174 | 175 | for n, logdir in enumerate(logdirs): 176 | seismic_gen.pars.layer_dh_min = dhmin 177 | seismic_gen.pars.layer_num_min = layer_num_min 178 | this_savepath = os.path.join(args.dataset_path, logdir) + "/dhmin" + str(dhmin) + "layer_num_min" + str(layer_num_min) 179 | dataset_path = args.dataset_path + "/dhmin" + str(dhmin) + "layer_num_min" + str(layer_num_min) 180 | if not os.path.isdir(this_savepath): 181 | os.makedirs(this_savepath) 182 | 183 | if args.testing != 3: 184 | tester.test_dataset(savepath=this_savepath, 185 | toeval=toeval, 186 | toeval_names=toeval_names, 187 | restore_from=logdir, 188 | testpath = dataset_path) 189 | vp, vint_pred[n], masks, lfiles, pfiles = tester.get_preds(labelname="vp", 190 | predname="vint", 191 | maskname="valid", 192 | savepath=this_savepath, 193 | testpath=dataset_path) 194 | vrms[n], vrms_pred[n], _, _ , _ = tester.get_preds(labelname="vrms", 195 | predname="vrms", 196 | savepath=this_savepath, 197 | testpath=dataset_path) 198 | ref[n], ref_pred[n], _, _ , _ = 
tester.get_preds(labelname="tlabels", 199 | predname="ref", 200 | savepath=this_savepath, 201 | testpath=dataset_path) 202 | vint[n] = [None for _ in range(len(vp))] 203 | for ii in range(len(vint[n])): 204 | vint[n][ii] = interval_velocity_time(vp[ii], pars=pars) 205 | vint[n][ii] = vint[n][ii][::pars.resampling] 206 | vint_pred[n][ii] = vint_pred[n][ii]*(pars.vp_max - pars.vp_min) + pars.vp_min 207 | vrms_pred[n][ii] = vrms_pred[n][ii] * (pars.vp_max - pars.vp_min) + pars.vp_min 208 | vrms[n][ii] = vrms[n][ii] * (pars.vp_max - pars.vp_min) + pars.vp_min 209 | ref_pred[n][ii] = np.argmax(ref_pred[n][ii], axis=1) 210 | ind0 = np.nonzero(ref[n][ii])[0][0] 211 | masks[ii][0:ind0] = 0 212 | vint[n] = np.array(vint[n]) 213 | vint_pred[n] = np.array(vint_pred[n]) 214 | vrms[n] = np.array(vrms[n]) 215 | vrms_pred[n] = np.array(vrms_pred[n]) 216 | ref[n] = np.array(ref[n]) 217 | ref_pred[n] = np.array(ref_pred[n]) 218 | 219 | name = "Results for dhmin= %f, layer_num_min= %f, NN %d" % (dhmin, layer_num_min, n) 220 | # get_rms(name, masks, vint_pred[n], vint[n], vrms_pred[n], 221 | # vrms[n], ref_pred[n], ref[n]) 222 | 223 | 224 | vint = np.mean(vint, axis=0) 225 | vint_pred_std = np.std(vint_pred, axis=0) 226 | vint_pred = np.mean(vint_pred, axis=0) 227 | vrms = np.mean(vrms, axis=0) 228 | vrms_pred_std = np.std(vrms_pred, axis=0) 229 | vrms_pred = np.mean(vrms_pred, axis=0) 230 | ref_pred = mode(ref_pred, axis=0).mode[0] 231 | ref = mode(ref, axis=0).mode[0] 232 | 233 | name = "Results for dhmin= %f, layer_num_min= %f, total" % (dhmin, layer_num_min) 234 | (vint_rmse, vrms_rmse, true_pos, 235 | true_neg, false_pos, false_neg) = get_rms(name, masks, 236 | vint_pred, vint, 237 | vrms_pred, vrms, 238 | ref_pred, ref) 239 | print("Standard deviation for vint %f, m/s" % np.mean(vint_pred_std)) 240 | print("Standard deviation for vrms %f, m/s" % np.mean(vrms_pred_std)) 241 | vint_rmse_all += vint_rmse 242 | vrms_rmse_all += vrms_rmse 243 | true_pos_all += true_pos 244 | true_neg_all += true_neg 245 | false_pos_all += false_pos 246 | false_neg_all += false_neg 247 | 248 | masks = np.array(masks) 249 | rmses = np.sqrt(np.sum(masks * (vrms - vrms_pred) ** 2, axis=1) / np.sum( 250 | masks == 1, axis=1)) 251 | sort_rmses = np.argsort(rmses) 252 | perc10 = sort_rmses[int(len(sort_rmses) * 0.1)] 253 | perc50 = sort_rmses[int(len(sort_rmses) * 0.5)] 254 | perc90 = sort_rmses[int(len(sort_rmses) * 0.9)] 255 | file = h5.File(lfiles[perc10], "r") 256 | data10 = file['data'][:] 257 | file.close() 258 | file = h5.File(lfiles[perc50], "r") 259 | data50 = file['data'][:] 260 | file.close() 261 | file = h5.File(lfiles[perc90], "r") 262 | data90 = file['data'][:] 263 | file.close() 264 | 265 | t10 = (np.nonzero(ref[perc10, :])[0][0] - 100) * pars.dt * pars.resampling 266 | t50 = (np.nonzero(ref[perc50, :])[0][0] - 100) * pars.dt * pars.resampling 267 | t90 = (np.nonzero(ref[perc90, :])[0][0] - 100) * pars.dt * pars.resampling 268 | 269 | plot_predictions_semb3([data10[:, :], data50[:, :], data90[:,:]], 270 | [vrms[perc10, :], vrms[perc50, :], vrms[perc90, :]], 271 | [vrms_pred[perc10, :], vrms_pred[perc50, :], vrms_pred[perc90, :]], 272 | [ref[perc10, :], ref[perc50, :], ref[perc90, :]], 273 | [ref_pred[perc10, :], ref_pred[perc50, :], ref_pred[perc90, :]], 274 | [vint[perc10, :], vint[perc50, :], vint[perc90, :]], 275 | [vint_pred[perc10, :], vint_pred[perc50, :], vint_pred[perc90, :]], 276 | [masks[perc10, :], masks[perc50, :], masks[perc90, :]], 277 | pars, clip=0.02, clipsemb=0.6, plot_semb=True, 278 | 
vint_pred_std = [vint_pred_std[perc10, :], vint_pred_std[perc50, :], vint_pred_std[perc90, :]], 279 | vpred_std = [vrms_pred_std[perc10, :], vrms_pred_std[perc50, :], vrms_pred_std[perc90, :]], 280 | tmin = [t10, t50, t90], 281 | textlabels=["$P_{10}$", 282 | "$P_{50}$", 283 | "$P_{90}$"], 284 | savefile="Paper/Fig/Case_article_test_dhmin"+str(dhmin)+"_lnummin" +str(layer_num_min)) 285 | 286 | n = len(dhmins) * len(layer_num_mins) 287 | print("Total Results") 288 | print("Interval velocity RMSE: %f m/s" % (vint_rmse_all/n)) 289 | print("RMS velocity RMSE: %f m/s" % (vrms_rmse_all / n)) 290 | print("True positive: %f, True negative: %f, False positive %f " 291 | "False negative: %f" % (true_pos_all/n, 292 | true_neg_all/n, 293 | false_pos_all/n, 294 | false_neg_all/n)) 295 | 296 | -------------------------------------------------------------------------------- /Case_article_test2D.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Defines parameters for 2D testing, creates the dataset make predictions 5 | """ 6 | from vrmslearn.Trainer import Trainer 7 | from vrmslearn.RCNN import RCNN 8 | from vrmslearn.ModelParameters import ModelParameters 9 | from vrmslearn.ModelGenerator import generate_random_2Dlayered, interval_velocity_time, calculate_vrms 10 | from vrmslearn.SeismicGenerator import SeismicGenerator, mute_direct 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import os 14 | import argparse 15 | from shutil import rmtree 16 | import h5py as h5 17 | import fnmatch 18 | from scipy.signal import medfilt 19 | import matplotlib.gridspec as gridspec 20 | from mpl_toolkits.axes_grid1 import make_axes_locatable 21 | from scipy.ndimage import gaussian_filter 22 | 23 | def cmp_pos(rec_pos, src_pos, bin): 24 | ng = rec_pos.shape[0] / src_pos.shape[0] 25 | 26 | src_pos = np.repeat(src_pos, ng) 27 | cmps = ((src_pos + rec_pos) / 2 / bin).astype(int) * bin 28 | offsets = src_pos - rec_pos 29 | 30 | ind = np.lexsort((offsets, cmps)) 31 | cmps = cmps[ind] 32 | unique_cmps, counts = np.unique(cmps, return_counts=True) 33 | cmax = np.max(counts) 34 | firstcmp = unique_cmps[np.argmax(counts == cmax)] 35 | lastcmp = unique_cmps[-np.argmax(counts[::-1] == cmax) - 1] 36 | ind1 = np.argmax(cmps == firstcmp) 37 | ind2 = np.argmax(cmps > lastcmp) 38 | 39 | 40 | return (ind2-ind1)/cmax 41 | 42 | def sort_cmps(data, rec_pos, src_pos, bin): 43 | 44 | ng = rec_pos.shape[0] / src_pos.shape[0] 45 | 46 | src_pos = np.repeat(src_pos, ng) 47 | cmps = ((src_pos + rec_pos) / 2 / bin).astype(int) * bin 48 | offsets = src_pos - rec_pos 49 | 50 | ind = np.lexsort((offsets, cmps)) 51 | cmps = cmps[ind] 52 | unique_cmps, counts = np.unique(cmps, return_counts=True) 53 | cmax = np.max(counts) 54 | firstcmp = unique_cmps[np.argmax(counts == cmax)] 55 | lastcmp = unique_cmps[-np.argmax(counts[::-1] == cmax) - 1] 56 | ind1 = np.argmax(cmps == firstcmp) 57 | ind2 = np.argmax(cmps > lastcmp) 58 | ntraces = cmps[ind1:ind2].shape[0] 59 | data_cmp = np.zeros([data.shape[0], ntraces]) 60 | n = 0 61 | for ii, jj in enumerate(ind): 62 | if ii >= ind1 and ii < ind2: 63 | data_cmp[:, n] = data[:, jj] 64 | n += 1 65 | return data_cmp 66 | 67 | def get_first_cmp_pos(rec_pos, src_pos, bin): 68 | ng = rec_pos.shape[0] / src_pos.shape[0] 69 | src_pos = np.repeat(src_pos, ng) 70 | cmps = ((src_pos + rec_pos) / 2 / bin).astype(int) * bin 71 | offsets = src_pos - rec_pos 72 | 73 | ind = np.lexsort((offsets, cmps)) 74 | cmps = 
cmps[ind] 75 | unique_cmps, counts = np.unique(cmps, return_counts=True) 76 | cmax = np.max(counts) 77 | firstcmp = unique_cmps[np.argmax(counts == cmax)] 78 | 79 | return firstcmp 80 | 81 | 82 | if __name__ == "__main__": 83 | 84 | # Initialize argument parser 85 | parser = argparse.ArgumentParser() 86 | 87 | # Add arguments to parse for training 88 | parser.add_argument( 89 | "--logdir", 90 | type=str, 91 | default="logs", 92 | help="Checkpoint filename for which to predict" 93 | ) 94 | parser.add_argument( 95 | "--niter", 96 | type=int, 97 | default=1000, 98 | help="Iteration number of the checkpoint file" 99 | ) 100 | parser.add_argument( 101 | "--create_data", 102 | type=int, 103 | default=1, 104 | help="If 1: create the 2D dataset" 105 | ) 106 | parser.add_argument( 107 | "--data_from", 108 | type=int, 109 | default=0, 110 | help="Start that example creation from data_from" 111 | ) 112 | 113 | # Parse the input for training parameters 114 | args, unparsed = parser.parse_known_args() 115 | 116 | ndatasets = 100 117 | 118 | """ 119 | __________________Find model directories______________________ 120 | """ 121 | dirs = [] 122 | for dir1 in os.listdir('./'): 123 | if os.path.isdir(dir1): 124 | for dir2 in os.listdir(dir1): 125 | path2 = os.path.join(dir1, dir2) 126 | if os.path.isdir(path2): 127 | dirs.append(path2) 128 | 129 | logdirs = fnmatch.filter(dirs, args.logdir) 130 | logdirs.sort() 131 | 132 | """ 133 | _________________________Define the parameters______________________ 134 | """ 135 | pars = ModelParameters() 136 | pars.flat = False 137 | pars.NX = 1700 138 | pars.NZ = 750 * 2 139 | pars.dh = 6.25 140 | pars.peak_freq = 26 141 | 142 | pars.num_layers = 0 143 | pars.layer_dh_min = 10 # minimum number of grid cells that a layer must span 144 | pars.layer_num_min = 25 # minimum number of layers 145 | pars.angle_max = 8 146 | pars.dangle_max = 3 147 | pars.amp_max = 0 148 | pars.max_texture = 0.08 149 | pars.texture_xrange = 1 150 | pars.texture_zrange = 1.95*pars.NZ 151 | 152 | pars.vp_min = 1300.0 # maximum value of vp (in m/s) 153 | pars.vp_max = 4000.0 # minimum value of vp (in m/s) 154 | 155 | pars.dt = 0.0004 156 | pars.resampling = 10 157 | pars.NT = int(8.0 / pars.dt) 158 | pars.marine = True 159 | pars.velwater = 1500 160 | pars.d_velwater = 60 161 | pars.water_depth = 3500 162 | pars.dwater_depth = 1000 163 | pars.dg = 8 164 | pars.gmin = int(470 / pars.dh) 165 | pars.gmax = int((470 + 72 * pars.dg * pars.dh) / pars.dh) 166 | pars.minoffset = 470 167 | 168 | pars.fs = False 169 | pars.source_depth = (pars.Npad + 4) * pars.dh 170 | pars.receiver_depth = (pars.Npad + 4) * pars.dh 171 | pars.identify_direct = False 172 | 173 | pars.mute_dir = True 174 | 175 | gen = SeismicGenerator(model_parameters=pars) 176 | ds = pars.dg 177 | ng = 72 178 | dg = pars.dg 179 | nearoffset = int(pars.minoffset / pars.dh) 180 | length = ng * dg + nearoffset 181 | 182 | sx = np.arange(pars.Npad + length + 1, 183 | pars.NX - pars.Npad - length, 184 | ds) * pars.dh 185 | sz = sx * 0 + pars.source_depth 186 | sid = np.arange(0, sx.shape[0]) 187 | gen.F.src_pos = np.stack([sx, 188 | sx * 0, 189 | sz, 190 | sid, 191 | sx * 0 + pars.sourcetype], axis=0) 192 | gen.F.src_pos_all = gen.F.src_pos 193 | gen.F.src = np.empty((gen.F.csts['NT'], 0)) 194 | 195 | gx = np.concatenate([ s - np.arange(nearoffset, length, dg) * pars.dh for s in sx], axis=0) 196 | gz = gx * 0 + pars.receiver_depth 197 | gid = np.arange(0, len(gx)) 198 | gsid = np.repeat(sid, ng) 199 | gen.F.rec_pos = np.stack([gx, 200 | gx * 0, 
201 | gz, 202 | gsid, 203 | gid, 204 | gx * 0 + 2, 205 | gx * 0, 206 | gx * 0], axis=0) 207 | gen.F.rec_pos_all = gen.F.rec_pos 208 | 209 | ncmps = cmp_pos(gen.F.rec_pos[0,:], gen.F.src_pos[0,:], ds * pars.dh) 210 | 211 | """ 212 | _________________________Generate the dataset______________________ 213 | """ 214 | 215 | workdir = "seiscl_workdir" 216 | savedir = "dataset_article/test2D" 217 | if not os.path.isdir(workdir): 218 | os.mkdir(workdir) 219 | 220 | if not os.path.isdir(savedir): 221 | os.mkdir(savedir) 222 | 223 | examples = fnmatch.filter(os.listdir(savedir), 'example_*') 224 | if args.create_data: 225 | pars.save_parameters_to_disk(savedir + "/model_parameters.hdf5") 226 | for ii in range(args.data_from, ndatasets): 227 | savefile = "example_%d" % ii 228 | if savefile not in examples: 229 | 230 | vp, vs, rho, vels, layers, angles = generate_random_2Dlayered(pars, seed=ii) 231 | file = h5.File(savedir + "/" + savefile, "a") 232 | file["vp"] = vp 233 | file["vels"] = vels 234 | file["layers"] = layers 235 | file["angles"] = angles 236 | # cmp0 = get_first_cmp_pos(gen.F.rec_pos[0,:], gen.F.src_pos[0,:], ds * pars.dh) 237 | # ind0 = int(cmp0 / pars.dh) 238 | # indm = int(ind0+ncmps) 239 | # 240 | # plt.imshow(vp[:, ind0:indm], cmap=plt.get_cmap("jet"), aspect="auto", interpolation='bilinear') 241 | # plt.colorbar() 242 | # plt.show() 243 | 244 | gen.F.set_forward(gen.F.src_pos[3, :], 245 | {'vp': vp, 'vs': vs, 'rho': rho}, 246 | workdir, 247 | withgrad=False) 248 | gen.F.execute(workdir) 249 | data = gen.F.read_data(workdir)[0] 250 | file["data"] = data 251 | data_cmp = sort_cmps(data, 252 | gen.F.rec_pos[0,:], 253 | gen.F.src_pos[0,:], 254 | ds * pars.dh) 255 | file["data_cmp"] = data_cmp 256 | file.close() 257 | 258 | rmtree(workdir) 259 | 260 | """ 261 | ______________________Make predictions for each model______________________ 262 | """ 263 | 264 | examples = fnmatch.filter(os.listdir(savedir), 'example_*') 265 | for logdir in logdirs: 266 | preddir = os.path.join(savedir, logdir) 267 | if not os.path.isdir(preddir): 268 | os.makedirs(preddir) 269 | predictions = fnmatch.filter(os.listdir(preddir), 'example_*_pred') 270 | for ii in range(ndatasets): 271 | savefile = "example_%d" % ii 272 | if savefile in examples and savefile + "_pred" not in predictions: 273 | print(preddir) 274 | print(savefile) 275 | file = h5.File(savedir + "/" + savefile, "r") 276 | data_cmp = file["data_cmp"][::pars.resampling,:] 277 | vp = file["vp"][:] 278 | file.close() 279 | 280 | ns = int(data_cmp.shape[1] / ng) 281 | data = np.zeros([ns, data_cmp.shape[0], ng, 1]) 282 | for jj in range(ns): 283 | data[jj, :, :, 0] = mute_direct(data_cmp[:, ng * jj:ng * (jj + 1)], vp[0,0], pars) 284 | 285 | 286 | vrms = np.zeros([data.shape[0], data.shape[1]]) 287 | vint = np.zeros([data.shape[0], data.shape[1]]) 288 | vint = np.zeros([data.shape[0], data.shape[1]]) 289 | valid = np.zeros([data.shape[0], data.shape[1]]) 290 | tlabels = np.zeros([data.shape[0], data.shape[1]]) 291 | 292 | nn = RCNN(input_size=data[0,:,:,0].shape, 293 | batch_size=ns) 294 | trainer = Trainer(NN=nn, 295 | data_generator=gen, 296 | totrain=False) 297 | 298 | preds = trainer.evaluate(toeval=[nn.output_ref, 299 | nn.output_vint, 300 | nn.output_vrms], 301 | niter=args.niter, 302 | dir=logdir, 303 | batch=[data, vrms, vint, valid, tlabels]) 304 | refpred = np.argmax(preds[0], axis=2) 305 | vint_pred = preds[1] 306 | vrms_pred = preds[2] 307 | 308 | vint_pred = vint_pred * (pars.vp_max - pars.vp_min) + pars.vp_min 309 | vrms_pred = 
vrms_pred * (pars.vp_max - pars.vp_min) + pars.vp_min 310 | vint = vint_pred * 0 311 | vrms = vint_pred * 0 312 | valid = vint_pred * 0 313 | for jj in range(vint.shape[0]): 314 | cmp0 = get_first_cmp_pos(gen.F.rec_pos[0,:], gen.F.src_pos[0,:], ds * pars.dh) 315 | ind0 = int(cmp0 / pars.dh) 316 | vint[jj, :] = interval_velocity_time(vp[:, ind0+jj * ds], pars=pars)[ 317 | ::pars.resampling] 318 | vrms[jj, :] = calculate_vrms(vp[:, ind0 + jj* ds], pars.dh, 319 | pars.Npad, pars.NT, pars.dt, pars.tdelay, 320 | pars.source_depth)[::pars.resampling] 321 | z0 = int(pars.source_depth/pars.dh) 322 | vid = int((2*np.sum(pars.dh/vp[z0:, ind0+jj * ds]) + pars.tdelay) /pars.dt /pars.resampling) 323 | valid[jj, :vid] = 1 324 | 325 | ng = int(gen.F.rec_pos[0,:].shape[0] / gen.F.src_pos[0,:].shape[0]) 326 | offsets = np.abs(gen.F.rec_pos[0,:ng] - gen.F.src_pos[0, 0]) 327 | t = np.arange(0, data_cmp.shape[0]) * pars.dt * pars.resampling 328 | stack = np.zeros_like(vint) 329 | vrms_pred_smooth = medfilt(vrms_pred, [11, 1]) 330 | 331 | savefile = h5.File(preddir + "/" + savefile + "_pred", "w") 332 | savefile['vint_pred'] = vint_pred 333 | savefile['vrms_pred'] = vrms_pred 334 | savefile['ref_pred'] = refpred 335 | savefile['vint'] = vint 336 | savefile['vrms'] = vrms 337 | savefile['valid'] = valid 338 | savefile['stack'] = stack 339 | savefile.close() 340 | 341 | rmse_vrms = np.zeros(ndatasets) 342 | rmse_vint = np.zeros(ndatasets) 343 | 344 | """ 345 | __________________Take the mean of predictions of the ensemble______________ 346 | """ 347 | for ii in range(ndatasets): 348 | savefile = "example_%d" % ii 349 | vint_pred = 0 350 | vrms_pred = 0 351 | n = 0 352 | for logdir in logdirs: 353 | preddir = os.path.join(savedir, logdir) 354 | predictions = fnmatch.filter(os.listdir(preddir), 'example_*_pred') 355 | 356 | if savefile in examples and (savefile + "_pred") in predictions: 357 | savefile = h5.File(preddir + "/" + savefile + "_pred", "r") 358 | vint_pred += np.transpose(savefile['vint_pred'][:]) 359 | vrms_pred += np.transpose(savefile['vrms_pred'][:]) 360 | vint = np.transpose(savefile['vint'][:]) 361 | vrms = np.transpose(savefile['vrms'][:]) 362 | valid = np.transpose(savefile['valid'][:]) 363 | savefile.close() 364 | n += 1 365 | for jj in range(vint.shape[1]): 366 | ind0 = np.nonzero(vint - vint[0,jj])[0][0] 367 | valid[0:ind0, jj] = 0 368 | 369 | vint_pred = vint_pred / n 370 | vrms_pred = vrms_pred / n 371 | rmse_vint[ii] = np.sqrt(np.sum(valid*((vint_pred - vint))**2)/np.sum(valid)) 372 | rmse_vrms[ii] = np.sqrt(np.sum(valid * ((vrms_pred - vrms)) ** 2) / np.sum(valid)) 373 | 374 | sort_rmses = np.argsort(rmse_vint) 375 | perc10 = sort_rmses[int(len(sort_rmses) * 0.1)] 376 | perc50 = sort_rmses[int(len(sort_rmses) * 0.5)] 377 | perc90 = sort_rmses[int(len(sort_rmses) * 0.8)] 378 | percs = [perc10, perc50, perc90] 379 | 380 | NX = vint_pred.shape[1] 381 | NZ = vint_pred.shape[0] 382 | ds = 50 383 | 384 | """ 385 | _____________________________Create the plot_______________________________ 386 | """ 387 | def plot_model(thisax, v, label, with_ylabel=True, tmin=0, tmax=8, noyaxis=False): 388 | im = thisax.imshow(v / 1000, cmap=plt.get_cmap("jet"), 389 | aspect="auto", interpolation="bilinear", 390 | vmin=pars.vp_min / 1000, vmax=pars.vp_max / 1000, 391 | extent=[0, (NX + 1) * ds / 1000, 392 | (NZ + 1) * pars.dt * pars.resampling, 0]) 393 | thisax.set_xlabel('x (km)') 394 | if with_ylabel: 395 | thisax.set_ylabel('T (s)') 396 | thisax.set_ylim(top=tmin) 397 | thisax.set_ylim(bottom=tmax) 398 | 
thisax.yaxis.set_ticks(np.arange(tmin,tmax,2)) 399 | if noyaxis: 400 | thisax.yaxis.set_ticks([]) 401 | ymin, ymax =thisax.get_ylim() 402 | xmin, xmax = thisax.get_xlim() 403 | thisax.set_title(label, fontsize="medium") 404 | return im 405 | 406 | fig = plt.figure(figsize=(16/2.54, 8/2.54)) 407 | gs = gridspec.GridSpec(nrows=5, ncols=55, height_ratios=[0.1, 1.2, 1.2, 10, 0.1]) 408 | 409 | labels0 = [ "a)", "b)", "c)"] 410 | labels = [ "True", "Pred", "True", "Pred", "True", "Pred"] 411 | 412 | for ii, perc in enumerate(percs): 413 | savefile = "example_%d" % perc 414 | vint_pred = 0 415 | vrms_pred = 0 416 | n = 0 417 | for logdir in logdirs: 418 | preddir = os.path.join(savedir, logdir) 419 | if not os.path.isdir(preddir): 420 | os.makedirs(preddir) 421 | predictions = fnmatch.filter(os.listdir(preddir), 'example_*_pred') 422 | 423 | if savefile in examples and (savefile + "_pred") in predictions: 424 | savefile = h5.File(preddir + "/" + savefile + "_pred", "r") 425 | vint_pred += np.transpose(savefile['vint_pred'][:]) 426 | vint = np.transpose(savefile['vint'][:]) 427 | valid = np.transpose(savefile['valid'][:]) 428 | savefile.close() 429 | n += 1 430 | vint_pred = vint_pred / n 431 | vint_pred = gaussian_filter(vint_pred, [3, 3]) 432 | vint_pred[valid<1] = np.NaN 433 | 434 | if ii==0: 435 | with_ylabel=True 436 | noyaxis = False 437 | else: 438 | with_ylabel=False 439 | noyaxis = True 440 | ax = fig.add_subplot(gs[3, (19*ii):(19*ii+8)]) 441 | plot_model(ax, vint, label=labels[2*ii], with_ylabel=with_ylabel, tmin=3, tmax=8.01, noyaxis=noyaxis) 442 | ymin, ymax = ax.get_ylim() 443 | xmin, xmax = ax.get_xlim() 444 | ax.text(xmin - 0.2 * (xmax-xmin), ymax + 0.11*(ymax-ymin), 445 | labels0[ii], fontsize="large") 446 | im = plot_model(fig.add_subplot(gs[3, (19*ii+9):(19*ii+17)]), vint_pred, label=labels[2*ii+1], with_ylabel=False, tmin=3, tmax=8.01, noyaxis=True) 447 | 448 | cax=fig.add_subplot(gs[1, 40:55]) 449 | clr = plt.colorbar(im, cax=cax, orientation="horizontal") 450 | cax.xaxis.set_ticks_position("top") 451 | clr.set_ticks(np.arange(1.5, 4.1, 1.25)) 452 | cax.xaxis.tick_top() 453 | cax.set_xlabel('V (km/s)', labelpad=10) 454 | cax.xaxis.set_label_position('top') 455 | 456 | savefile = "Paper/Fig/Case_article_predict2d" 457 | plt.savefig(savefile, dpi=600) 458 | plt.savefig(savefile+"_lowres", dpi=100) 459 | plt.show() 460 | 461 | 462 | print("Vint RMSE is %f m/s" % np.sqrt(np.mean(rmse_vint[rmse_vint!=9999]**2))) 463 | print("Vrms RMSE is %f m/s" % np.sqrt(np.mean(rmse_vrms[rmse_vint != 9999] ** 2))) 464 | 465 | 466 | 467 | 468 | 469 | -------------------------------------------------------------------------------- /Case_article_testreal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Performs the testing on the real dataset (reproduces Figures 5 and 6) 5 | """ 6 | 7 | from plot_prediction import plot_predictions_semb3 8 | from semblance.nmo_correction import stack 9 | import numpy as np 10 | import matplotlib as mpl 11 | import matplotlib.pyplot as plt 12 | mpl.rcParams.update({'font.size': 7}) 13 | import matplotlib.gridspec as gridspec 14 | from mpl_toolkits.axes_grid1 import make_axes_locatable 15 | from scipy.stats import mode 16 | from scipy.signal import medfilt 17 | from scipy.signal import butter, filtfilt 18 | from scipy.ndimage import gaussian_filter 19 | import segyio 20 | 21 | from vrmslearn.Trainer import Trainer 22 | from vrmslearn.RCNN import RCNN 23 | from 
vrmslearn.ModelParameters import ModelParameters 24 | from vrmslearn.ModelGenerator import generate_random_2Dlayered, interval_velocity_time, calculate_vrms 25 | from vrmslearn.SeismicGenerator import SeismicGenerator, mute_direct, random_static 26 | import matplotlib.pyplot as plt 27 | import numpy as np 28 | import os 29 | from shutil import rmtree 30 | import h5py as h5 31 | import tensorflow as tf 32 | import fnmatch 33 | from scipy.signal import medfilt 34 | import argparse 35 | import time 36 | 37 | def butter_bandpass(lowcut, highcut, fs, order=5): 38 | nyq = 0.5 * fs 39 | low = lowcut / nyq 40 | high = highcut / nyq 41 | if lowcut==0: 42 | b, a = butter(order, high, btype='lowpass', analog=False) 43 | elif highcut==0: 44 | b, a = butter(order, low, btype='highpass', analog=False) 45 | else: 46 | b, a = butter(order, [low, high], btype='band', analog=False) 47 | return b, a 48 | 49 | 50 | def bandpass(data, lowcut, highcut, fs, order=5, axis=-1): 51 | b, a = butter_bandpass(lowcut, highcut, fs, order=order) 52 | y = filtfilt(b, a, data, axis=axis) 53 | return y 54 | 55 | if __name__ == "__main__": 56 | 57 | # Initialize argument parser 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--plots", 60 | type=int, 61 | default=1, 62 | help="1: plot only the first CMP results " 63 | "2: plot the2D RMS and interval velocity section and the " 64 | "stacked section" 65 | ) 66 | parser.add_argument("--logdir", 67 | type=str, 68 | default="Case4b2/2_schedule2_lr0.000800_eps_0.000010_beta10.900000_beta20.980000_batch_size_40", 69 | help="name of the directory of the checkpoint: str" 70 | ) 71 | parser.add_argument( 72 | "--niter", 73 | type=int, 74 | default=10000, 75 | help="number of training iterations of the checkpoint" 76 | ) 77 | parser.add_argument( 78 | "--savepred", 79 | type=int, 80 | default=1, 81 | help="Save predictions to a file. 0: no, 1: yes" 82 | ) 83 | parser.add_argument( 84 | "--recompute", 85 | type=int, 86 | default=0, 87 | help="Recompute predictions. 
0: no, 1: yes" 88 | ) 89 | 90 | # Parse the input 91 | args = parser.parse_args() 92 | 93 | dirs = [] 94 | for dir1 in os.listdir('./'): 95 | if os.path.isdir(dir1): 96 | for dir2 in os.listdir(dir1): 97 | path2 = os.path.join(dir1, dir2) 98 | if os.path.isdir(path2): 99 | dirs.append(path2) 100 | 101 | logdirs = fnmatch.filter(dirs, args.logdir) 102 | print(logdirs) 103 | 104 | create_data = True 105 | logdir = args.logdir 106 | niter = args.niter 107 | max_batch = 100 108 | 109 | pars = ModelParameters() 110 | pars.layer_dh_min = 5 111 | pars.layer_num_min = 48 112 | 113 | pars.dh = 6.25 114 | pars.peak_freq = 26 115 | pars.df = 5 116 | pars.wavefuns = [0, 1] 117 | pars.NX = 692 * 2 118 | pars.NZ = 752 * 2 119 | pars.dt = 0.0004 120 | pars.NT = int(8.0 / pars.dt) 121 | pars.resampling = 10 122 | 123 | pars.dg = 8 124 | pars.gmin = int(470 / pars.dh) 125 | pars.gmax = int((470 + 72 * pars.dg * pars.dh) / pars.dh) 126 | pars.minoffset = 470 127 | 128 | pars.vp_min = 1300.0 # maximum value of vp (in m/s) 129 | pars.vp_max = 4000.0 # minimum value of vp (in m/s) 130 | 131 | pars.marine = True 132 | pars.velwater = 1500 133 | pars.d_velwater = 60 134 | pars.water_depth = 3500 135 | pars.dwater_depth = 1000 136 | 137 | pars.fs = False 138 | pars.source_depth = (pars.Npad + 4) * pars.dh 139 | pars.receiver_depth = (pars.Npad + 4) * pars.dh 140 | pars.identify_direct = False 141 | 142 | pars.tdelay*=1.5 143 | padt = int(pars.tdelay / pars.dt / pars.resampling) * 0 144 | 145 | savefile = "./realdata/survey.hdf5" 146 | ng = 72 147 | 148 | file = h5.File(savefile, "r") 149 | data_cmp = file["data_cmp"] 150 | 151 | nbatch = int(data_cmp.shape[1] / ng / max_batch) 152 | ns = int(data_cmp.shape[1] / ng) 153 | 154 | nn = RCNN([data_cmp.shape[0]+padt, ng], 155 | batch_size=max_batch, use_peepholes=False) 156 | 157 | if args.plots ==1: 158 | nbatch = 1 159 | ns = max_batch 160 | 161 | refpred = [] 162 | vint_pred = [] 163 | vpred = [] 164 | 165 | for logdir in logdirs: 166 | if not os.path.isfile(logdir + '/realdatapred.h5') or args.recompute: 167 | print('recomputing') 168 | if os.path.isfile(logdir + '/model.ckpt-' + str(niter) + '.meta'): 169 | data = np.zeros([max_batch, data_cmp.shape[0] + padt, ng, 1]) 170 | refpred.append(np.zeros([data_cmp.shape[0]+ padt, ns])) 171 | vint_pred.append(np.zeros([data_cmp.shape[0]+ padt, ns])) 172 | vpred.append(np.zeros([data_cmp.shape[0]+ padt, ns])) 173 | with nn.graph.as_default(): 174 | saver = tf.train.Saver() 175 | with tf.Session() as sess: 176 | saver.restore(sess, logdir + '/model.ckpt-' + str(niter)) 177 | start_time = time.time() 178 | for ii in range(nbatch): 179 | for jj in range(max_batch): 180 | idmin = ii * max_batch * ng + ng * jj 181 | idmax = ii * max_batch * ng + ng * (jj + 1) 182 | data[jj, padt:, :, 0] = data_cmp[:, idmin:idmax] 183 | evaluated = sess.run([nn.input_scaled, nn.output_ref, 184 | nn.output_vint, nn.output_vrms], 185 | feed_dict={nn.input: data}) 186 | idmin = ii*max_batch 187 | idmax = (ii+1)*max_batch 188 | refpred[-1][:, idmin:idmax] = np.transpose(np.argmax(evaluated[1], axis=2), [1, 0]) 189 | vint_pred[-1][:, idmin:idmax] = np.transpose(evaluated[2]) 190 | vpred[-1][:, idmin:idmax] = np.transpose(evaluated[3]) 191 | print("--- %s seconds ---" % (time.time() - start_time)) 192 | if args.savepred==1: 193 | filesave = h5.File(logdir + '/realdatapred.h5', "w") 194 | filesave['refpred'] = refpred[-1] 195 | filesave['vint_pred'] = vint_pred[-1] 196 | filesave['vpred'] = vpred[-1] 197 | filesave.close() 198 | else: 199 | filesave = 
h5.File(logdir + '/realdatapred.h5', "r") 200 | refpred.append(filesave['refpred'][:]) 201 | vint_pred.append(filesave['vint_pred'][:]) 202 | vpred.append(filesave['vpred'][:]) 203 | filesave.close() 204 | t = np.arange(0, data_cmp.shape[0]) * pars.dt*pars.resampling - pars.tdelay 205 | offsets = (np.arange(pars.gmin, pars.gmax, pars.dg)) * pars.dh 206 | vrms = gaussian_filter(np.mean(vpred, axis=0), [1, 9]) * (pars.vp_max - pars.vp_min) + pars.vp_min 207 | vint = gaussian_filter(np.median(vint_pred, axis=0), [1, 9]) * (pars.vp_max - pars.vp_min) + pars.vp_min 208 | 209 | if not os.path.isfile("./realdata/survey_stacked.hdf5") or (args.recompute and args.plots==2): 210 | 211 | stacked = np.zeros([data_cmp.shape[0], ns]) 212 | for ii in range(ns): 213 | stacked[:, ii] = stack(data_cmp[:, ii*ng:(ii+1)*ng], 214 | t, offsets, vrms[:,ii]) 215 | filesave = h5.File("./realdata/survey_stacked.hdf5", "w") 216 | filesave['stacked'] = stacked 217 | filesave.close() 218 | else: 219 | filesave = h5.File("./realdata/survey_stacked.hdf5", "r") 220 | stacked = filesave['stacked'][:] 221 | filesave.close() 222 | 223 | if args.plots == 1: 224 | shots = [250, 1000, 1750] 225 | datas = [data_cmp[:, ii*ng:(ii+1)*ng] for ii in shots] 226 | vrmss = [np.mean([v[:,ii] for v in vpred], axis=0) for ii in shots] 227 | vrmss = [v * (pars.vp_max - pars.vp_min) + pars.vp_min for v in vrmss] 228 | vints = [np.mean([v[:,ii] for v in vint_pred], axis=0) for ii in shots] 229 | vints = [v * (pars.vp_max - pars.vp_min) + pars.vp_min for v in vints] 230 | refs = [mode([v[:,ii] for v in refpred] , axis=0).mode[0] for ii in shots] 231 | vrms_stds = [np.std([v[:,ii]* (pars.vp_max - pars.vp_min) + pars.vp_min 232 | for v in vpred], axis=0) for ii in shots] 233 | vint_stds = [np.std([v[:,ii] * (pars.vp_max - pars.vp_min) + pars.vp_min 234 | for v in vint_pred], axis=0) for ii in shots] 235 | 236 | plot_predictions_semb3(datas, 237 | None, 238 | vrmss, 239 | None, 240 | refs, 241 | None, 242 | vints, None, 243 | pars, plot_semb=True, vmin=1400, vmax=3400, dv=50, 244 | vpred_std =vrms_stds, 245 | vint_pred_std = vint_stds, clip=0.05, 246 | tmin = 2, tmax=10, 247 | savefile="./Paper/Fig/realdata_semblance", 248 | with_nmo=True 249 | ) 250 | 251 | 252 | if args.plots == 2: 253 | 254 | def plot_model(thisax, v, label, extent = None, cbar=True, vmin=None, vmax=None, 255 | cmap=None): 256 | if cmap is None: 257 | cmap=plt.get_cmap("jet") 258 | im = thisax.imshow(v, cmap=cmap, 259 | interpolation='bilinear', 260 | aspect="auto", 261 | extent=extent, vmin=vmin, vmax=vmax) 262 | thisax.set_xlabel('CMP') 263 | thisax.set_ylabel('T (s)') 264 | thisax.set_ylim(bottom=10, top=2) 265 | thisax.set_xlim(left=1, right=2080) 266 | divider = make_axes_locatable(thisax) 267 | cax = divider.append_axes("right", size="5%", pad=0.1) 268 | if cbar: 269 | clr = plt.colorbar(im, cax=cax) 270 | cax.xaxis.set_ticks_position("top") 271 | cax.xaxis.tick_top() 272 | cax.set_xlabel('V (km/s)', labelpad=10) 273 | cax.xaxis.set_label_position('top') 274 | else: 275 | cax.axis('off') 276 | ymin, ymax =thisax.get_ylim() 277 | xmin, xmax = thisax.get_xlim() 278 | thisax.text(xmin - 0.05 * (xmax - xmin), ymax + 0.15 * (ymax - ymin), 279 | label, ha="right", va="top", fontsize="large") 280 | 281 | fig = plt.figure(figsize=(15 / 2.54, 23 / 2.54)) 282 | gridspec.GridSpec(4,1) 283 | 284 | 285 | extent = [0, vrms.shape[1], np.max(t), 0] 286 | plot_model(plt.subplot2grid( (4,1), (0,0)), vrms/1000, "a)", extent=extent) 287 | plot_model(plt.subplot2grid( (4,1), (1,0)), 
vint/1000, "b)", extent=extent, vmin=1.4, vmax=3.1) 288 | clip = 0.15 289 | stacked = stacked * (np.reshape(t, [-1, 1])**2 + 1e-6) 290 | stacked = stacked / np.sqrt(np.sum(stacked**2, axis=0)) 291 | vmax = np.max(stacked) * clip 292 | vmin = -vmax 293 | plot_model(plt.subplot2grid( (4,1), (2,0)), stacked, "c)", extent=extent, cbar=False, cmap=plt.get_cmap('Greys'), vmax=vmax, vmin=vmin) 294 | 295 | NT = stacked.shape[0] 296 | with segyio.open("./realdata/USGS_line32/CSDS32_1.SGY", "r", 297 | ignore_geometry=True) as segy: 298 | stacked_usgs = np.transpose(np.array([segy.trace[trid] 299 | for trid in range(segy.tracecount)])) 300 | stacked_usgs = stacked_usgs[:, -2401:-160] 301 | stacked_usgs = stacked_usgs[:,::-1] 302 | for kk in range(stacked_usgs.shape[1]): 303 | stacked_usgs[:, kk] = stacked_usgs[:, kk] / np.sqrt(np.sum(stacked_usgs[:, kk] **2)+1e-4) 304 | clip = 0.25 305 | vmax = np.max(stacked_usgs) * clip 306 | vmin = -vmax 307 | plot_model(plt.subplot2grid( (4,1), (3,0)), stacked_usgs, "d)", extent=extent, cbar=False, cmap=plt.get_cmap('Greys'), vmax=vmax, vmin=vmin) 308 | 309 | plt.tight_layout()#rect=[0, 0, 1, 0.995]) 310 | plt.savefig("./Paper/Fig/realdata_stacked", dpi=600) 311 | plt.savefig("./Paper/Fig/realdata_stacked_lowres", dpi=100) 312 | plt.show() 313 | 314 | file.close() 315 | -------------------------------------------------------------------------------- /Case_small.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Defines parameters for Case 1, creates the dataset and train the NN 5 | """ 6 | 7 | from vrmslearn.ModelParameters import ModelParameters 8 | from vrmslearn.SeismicGenerator import SeismicGenerator, generate_dataset 9 | from vrmslearn.Trainer import Trainer 10 | from vrmslearn.RCNN import RCNN 11 | import os 12 | import argparse 13 | import tensorflow as tf 14 | import fnmatch 15 | 16 | if __name__ == "__main__": 17 | 18 | # Initialize argument parser 19 | parser = argparse.ArgumentParser() 20 | 21 | # Add arguments to parse for training 22 | parser.add_argument( 23 | "--nthread", 24 | type=int, 25 | default=3, 26 | help="Number of threads for data creation" 27 | ) 28 | parser.add_argument( 29 | "--nthread_read", 30 | type=int, 31 | default=3, 32 | help="Number of threads used as input producer" 33 | ) 34 | parser.add_argument( 35 | "--logdir", 36 | type=str, 37 | default="./logs", 38 | help="Directory in which to store the checkpoints" 39 | ) 40 | parser.add_argument( 41 | "--training", 42 | type=int, 43 | default=1, 44 | help="1: training only, 0: create dataset only, 2: training+dataset" 45 | ) 46 | parser.add_argument( 47 | "--workdir", 48 | type=str, 49 | default="./seiscl_workdir", 50 | help="name of SeisCL working directory " 51 | ) 52 | parser.add_argument( 53 | "--lr", 54 | type=float, 55 | default=0.0003, 56 | help="learning rate " 57 | ) 58 | parser.add_argument( 59 | "--eps", 60 | type=float, 61 | default=1e-3, 62 | help="epsilon for adadelta" 63 | ) 64 | parser.add_argument( 65 | "--batchsize", 66 | type=int, 67 | default=40, 68 | help="size of the batches" 69 | ) 70 | parser.add_argument( 71 | "--beta1", 72 | type=float, 73 | default=0.9, 74 | help="beta1 for adadelta" 75 | ) 76 | parser.add_argument( 77 | "--beta2", 78 | type=float, 79 | default=0.98, 80 | help="beta2 for adadelta" 81 | ) 82 | 83 | # Parse the input for training parameters 84 | args, unparsed = parser.parse_known_args() 85 | 86 | 87 | savepath = "./dataset_1" 88 | logdir = 
args.logdir 89 | nthread = args.nthread 90 | niter = 5000 91 | batch_size = args.batchsize 92 | 93 | """ 94 | __________________Define the parameters for Case 1______________________ 95 | """ 96 | pars = ModelParameters() 97 | pars.num_layers = 0 98 | dhmins = [40, 30, 20] 99 | layer_num_mins = [5, 12] 100 | nexamples = 10000 101 | 102 | """ 103 | _______________________Generate the dataset_____________________________ 104 | """ 105 | if not os.path.isdir(savepath): 106 | os.mkdir(savepath) 107 | 108 | n = 100000 109 | if args.training != 1: 110 | for dhmin in dhmins: 111 | for layer_num_min in layer_num_mins: 112 | pars.layer_dh_min = dhmin 113 | pars.layer_num_min = layer_num_min 114 | this_savepath = (savepath 115 | + "/dhmin%d" % dhmin 116 | + "_layer_num_min%d" % layer_num_min) 117 | generate_dataset(pars=pars, 118 | savepath=this_savepath, 119 | nthread=1, 120 | nexamples=nexamples, 121 | workdir=args.workdir, 122 | seed=n) 123 | n += 1 124 | 125 | """ 126 | ___________________________Do the training _____________________________ 127 | 128 | We define 3 stages for inversion, with different alpha, beta gamma in the 129 | loss function: 130 | 1st stage: alpha = 0, beta=1 and gamma=0: we train for reflection 131 | identification 132 | 2nd stage: alpha = 0.2, beta=0.1 and gamma=0.1: we train for reflection 133 | identification and vrms, with regularization on vrms time 134 | derivative (alpha) et higher weights on vrms at reflections 135 | arrival times (gamma) 136 | 3rd stage: alpha = 0.02, beta=0.02 and gamma=0.1, we add weight to vrms 137 | 138 | """ 139 | if args.training != 0: 140 | schedules = [[0.01, 0.9, 0, 0, 0], 141 | [0.05, 0.2, 0, 0, 0], 142 | [0, 0, 0, 0.9, 0.1]] 143 | restore_from = None 144 | npass = 0 145 | for layer_num_min in layer_num_mins: 146 | for ii, schedule in enumerate(schedules): 147 | this_savepath = [] 148 | for dhmin in dhmins: 149 | this_logdir = (logdir 150 | + "/%d" % npass 151 | + "_dhmin%d" % dhmin 152 | + "_layer_num_min%d" % layer_num_min 153 | + "_schedule%d" % ii 154 | + "_lr%f_eps_%f" % (args.lr, args.eps) 155 | + "_beta1%f" % args.beta1 156 | + "_beta2%f" % args.beta2 157 | + "_batch_size_%d" % batch_size) 158 | this_savepath.append(savepath 159 | + "/dhmin%d" % dhmin 160 | + "_layer_num_min%d" % layer_num_min) 161 | 162 | lastfile = this_logdir + 'model.ckpt-' + str(niter) + '*' 163 | 164 | try: 165 | isckpt = fnmatch.filter(os.listdir(this_logdir), 166 | 'model.ckpt-' + str(niter) + '*') 167 | except FileNotFoundError: 168 | isckpt =[] 169 | 170 | if not isckpt: 171 | print(this_logdir) 172 | pars.layer_dh_min = dhmin 173 | pars.layer_num_min = layer_num_min 174 | seismic_gen = SeismicGenerator(model_parameters=pars) 175 | nn = RCNN(input_size=seismic_gen.image_size, 176 | batch_size=batch_size, 177 | alpha=schedule[0], 178 | beta=schedule[1], 179 | gamma=schedule[2], 180 | zeta=schedule[3]) 181 | 182 | if layer_num_min == layer_num_mins[0] and dhmin == dhmins[0]: 183 | learning_rate = args.lr 184 | else: 185 | learning_rate = args.lr/8 186 | if ii == 2: 187 | with nn.graph.as_default(): 188 | var_to_minimize = tf.trainable_variables( 189 | scope='rnn_vint') 190 | var_to_minimize.append(tf.trainable_variables( 191 | scope='Decode_vint')) 192 | else: 193 | var_to_minimize = None 194 | 195 | 196 | 197 | trainer = Trainer(NN=nn, 198 | data_generator=seismic_gen, 199 | checkpoint_dir=this_logdir, 200 | learning_rate=learning_rate, 201 | beta1=args.beta1, 202 | beta2=args.beta2, 203 | epsilon=args.eps, 204 | var_to_minimize=var_to_minimize) 205 | 
trainer.train_model(niter=niter, 206 | savepath=this_savepath, 207 | restore_from=restore_from, 208 | thread_read=args.nthread_read) 209 | restore_from = this_logdir + '/model.ckpt-' + str(niter) 210 | npass += 1 211 | -------------------------------------------------------------------------------- /Case_small_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Create the test dataset for Case 1, performs the testing and plot results 5 | """ 6 | 7 | from vrmslearn.ModelParameters import ModelParameters 8 | from vrmslearn.SeismicGenerator import SeismicGenerator, generate_dataset 9 | from vrmslearn.ModelGenerator import interval_velocity_time 10 | from vrmslearn.Tester import Tester 11 | from vrmslearn.RCNN import RCNN 12 | from plot_prediction import plot_predictions_semb3 13 | import os 14 | import argparse 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | import h5py as h5 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | # Initialize argument parser 23 | parser = argparse.ArgumentParser() 24 | 25 | # Add arguments to parse for training 26 | parser.add_argument( 27 | "--nthread", 28 | type=int, 29 | default=3, 30 | help="Number of threads per gpus for data creation" 31 | ) 32 | parser.add_argument( 33 | "--model_file", 34 | type=str, 35 | default="logs/model.ckpt-5000", 36 | help="Checkpoint filename for which to predict" 37 | ) 38 | parser.add_argument( 39 | "--testing", 40 | type=int, 41 | default=2, 42 | help="1: testing only, 0: create dataset only, 2: testing+dataset, 3: ploting only" 43 | ) 44 | parser.add_argument( 45 | "--workdir", 46 | type=str, 47 | default="./seiscl_workdir", 48 | help="name of SeisCL working directory " 49 | ) 50 | parser.add_argument( 51 | "--dataset_path", 52 | type=str, 53 | default="./dataset_1/test", 54 | help="name of SeisCL working directory " 55 | ) 56 | 57 | 58 | # Parse the input for training parameters 59 | args, unparsed = parser.parse_known_args() 60 | 61 | 62 | savepath = args.dataset_path 63 | 64 | """ 65 | __________________Define the parameters for Case 1______________________ 66 | """ 67 | pars = ModelParameters() 68 | pars.num_layers = 0 69 | dhmins = [40, 30, 20] 70 | layer_num_mins = [5, 12] 71 | nexamples = 10 72 | 73 | """ 74 | _______________________Generate the dataset_____________________________ 75 | """ 76 | if not os.path.isdir(savepath): 77 | os.mkdir(savepath) 78 | 79 | n = 1 80 | if args.testing != 1: 81 | for dhmin in dhmins: 82 | for layer_num_min in layer_num_mins: 83 | pars.layer_dh_min = dhmin 84 | pars.layer_num_min = layer_num_min 85 | this_savepath = savepath + "/dhmin" + str(dhmin) + "layer_num_min" + str(layer_num_min) 86 | 87 | generate_dataset(pars=pars, 88 | savepath=this_savepath, 89 | nthread=1, 90 | nexamples=nexamples, 91 | workdir=args.workdir, 92 | seed=n) 93 | n += 1 94 | 95 | """ 96 | ___________________________Do the testing ______________________________ 97 | """ 98 | seismic_gen = SeismicGenerator(model_parameters=pars) 99 | nn = RCNN(input_size=seismic_gen.image_size, 100 | batch_size=2) 101 | tester = Tester(NN=nn, data_generator=seismic_gen) 102 | toeval = [nn.output_ref, nn.output_vrms, nn.output_vint] 103 | toeval_names = ["ref", "vrms", "vint"] 104 | vint_rmse = 0 105 | vrms_rmse = 0 106 | true_pos_all = 0 107 | true_neg_all = 0 108 | false_pos_all = 0 109 | false_neg_all = 0 110 | for dhmin in dhmins: 111 | for layer_num_min in layer_num_mins: 112 | this_savepath = savepath + 
"/dhmin" + str(dhmin) + "layer_num_min" + str(layer_num_min) 113 | if args.testing != 3: 114 | tester.test_dataset(savepath=this_savepath, 115 | toeval=toeval, 116 | toeval_names=toeval_names, 117 | restore_from=args.model_file) 118 | vp, vint_pred, masks, lfiles, pfiles = tester.get_preds(labelname="vp", 119 | predname="vint", 120 | maskname="valid", 121 | savepath=this_savepath) 122 | vrms, vrms_pred, _, _, _ = tester.get_preds(labelname="vrms", 123 | predname="vrms", 124 | savepath=this_savepath) 125 | ref, ref_pred, _, _, _ = tester.get_preds(labelname="tlabels", 126 | predname="ref", 127 | savepath=this_savepath) 128 | 129 | vint = [None] * len(vp) 130 | for ii in range(len(vint)): 131 | vint[ii] = interval_velocity_time(vp[ii], pars=pars) 132 | vint[ii] = vint[ii][::pars.resampling] 133 | vint_pred[ii] = vint_pred[ii]*(pars.vp_max - pars.vp_min) + pars.vp_min 134 | vrms_pred[ii] = vrms_pred[ii] * (pars.vp_max - pars.vp_min) + pars.vp_min 135 | vrms[ii] = vrms[ii] * (pars.vp_max - pars.vp_min) + pars.vp_min 136 | ref_pred[ii] = np.argmax(ref_pred[ii], axis=1) 137 | #plt.plot(vint[ii]) 138 | #plt.plot(vint_pred[ii]) 139 | #plt.show() 140 | 141 | 142 | print("Results for dhmin= %f, layer_num_min= %f" % (dhmin, layer_num_min)) 143 | masks = np.array(masks) 144 | nsamples = np.sum(masks == 1) 145 | vint_pred = np.array(vint_pred) 146 | vint = np.array(vint) 147 | rmse = np.sqrt(np.sum(masks * (vint - vint_pred)**2) / nsamples) 148 | vint_rmse += rmse 149 | print("Interval velocity RMSE: %f m/s" % rmse) 150 | 151 | 152 | vrms_pred = np.array(vrms_pred) 153 | vrms = np.array(vrms) 154 | rmse = np.sqrt(np.sum(masks * (vrms - vrms_pred) ** 2) / nsamples) 155 | vrms_rmse += rmse 156 | print("RMS velocity RMSE: %f m/s" % rmse) 157 | 158 | ref_pred = np.array(ref_pred) 159 | ref = np.array(ref) 160 | nsamples = ref.flatten().shape[0] 161 | true_pos = np.sum(((ref - ref_pred) == 0) * (ref == 1)) / nsamples 162 | true_neg = np.sum(((ref - ref_pred) == 0) * (ref == 0)) / nsamples 163 | false_pos = np.sum((ref - ref_pred) == -1) / nsamples 164 | false_neg = np.sum((ref - ref_pred) == 1) / nsamples 165 | 166 | true_pos_all += true_pos 167 | true_neg_all += true_neg 168 | false_pos_all += false_pos 169 | false_neg_all += false_neg 170 | 171 | print("True positive: %f, True negative: %f, False positive %f " 172 | "False negative: %f" % (true_pos, true_neg, false_pos, false_neg)) 173 | 174 | print("") 175 | 176 | rmses = np.sqrt(np.sum(masks * (vint - vint_pred) ** 2, axis=1) / np.sum( 177 | masks == 1, axis=1)) 178 | sort_rmses = np.argsort(rmses) 179 | perc10 = sort_rmses[int(len(sort_rmses) * 0.1)] 180 | perc50 = sort_rmses[int(len(sort_rmses) * 0.5)] 181 | perc90 = sort_rmses[int(len(sort_rmses) * 0.9)] 182 | file = h5.File(lfiles[perc10], "r") 183 | data10 = file['data'][:] 184 | file.close() 185 | file = h5.File(lfiles[perc50], "r") 186 | data50 = file['data'][:] 187 | file.close() 188 | file = h5.File(lfiles[perc90], "r") 189 | data90 = file['data'][:] 190 | file.close() 191 | 192 | plot_predictions_semb3([data10, data50, data90], 193 | [vrms[perc10, :], vrms[perc50, :], vrms[perc90, :]], 194 | [vrms_pred[perc10, :], vrms_pred[perc50, :], vrms_pred[perc90, :]], 195 | [ref[perc10, :], ref[perc50, :], ref[perc90, :]], 196 | [ref_pred[perc10, :], ref_pred[perc50, :], ref_pred[perc90, :]], 197 | [vint[perc10, :], vint[perc50, :], vint[perc90, :]], 198 | [vint_pred[perc10, :], vint_pred[perc50, :], vint_pred[perc90, :]], 199 | [masks[perc10, :], masks[perc50, :], masks[perc90, :]], 200 | pars, 201 | 
savefile="Paper/Fig/Case1_test_dhmin"+str(dhmin)+"_lnummin" +str(layer_num_min)) 202 | 203 | 204 | n = len(dhmins) * len(layer_num_mins) 205 | print("Total Results") 206 | print("Interval velocity RMSE: %f m/s" % (vint_rmse/n)) 207 | print("RMS velocity RMSE: %f m/s" % (vrms_rmse / n)) 208 | print("True positive: %f, True negative: %f, False positive %f " 209 | "False negative: %f" % (true_pos_all/n, 210 | true_neg_all/n, 211 | false_pos_all/n, 212 | false_neg_all/n)) -------------------------------------------------------------------------------- /Cases_define.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Defines parameters for different cases 5 | """ 6 | 7 | from vrmslearn.ModelParameters import ModelParameters 8 | 9 | def Case_small(): 10 | return ModelParameters() 11 | 12 | def Case_article(noise=0): 13 | pars = ModelParameters() 14 | pars.layer_dh_min = 5 15 | pars.layer_num_min = 48 16 | 17 | pars.dh = 6.25 18 | pars.peak_freq = 26 19 | pars.df = 5 20 | pars.wavefuns = [0, 1] 21 | pars.NX = 692*2 22 | pars.NZ = 752*2 23 | pars.dt = 0.0004 24 | pars.NT = int(8.0 / pars.dt) 25 | pars.resampling = 10 26 | 27 | pars.dg = 8 28 | pars.gmin = int(470 / pars.dh) 29 | pars.gmax = int((470 + 72 * pars.dg * pars.dh) / pars.dh) 30 | pars.minoffset = 470 31 | 32 | pars.vp_min = 1300.0 # maximum value of vp (in m/s) 33 | pars.vp_max = 4000.0 # minimum value of vp (in m/s) 34 | 35 | pars.marine = True 36 | pars.velwater = 1500 37 | pars.d_velwater = 60 38 | pars.water_depth = 3500 39 | pars.dwater_depth = 1000 40 | 41 | pars.fs = False 42 | pars.source_depth = (pars.Npad + 4) * pars.dh 43 | pars.receiver_depth = (pars.Npad + 4) * pars.dh 44 | pars.identify_direct = False 45 | 46 | pars.mute_dir = True 47 | if noise == 1: 48 | pars.random_static = True 49 | pars.random_static_max = 1 50 | pars.random_noise = True 51 | pars.random_noise_max = 0.02 52 | # pars.mute_nearoffset = True 53 | # pars.mute_nearoffset_max = 10 54 | 55 | return pars 56 | 57 | 58 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-devel-centos7 2 | MAINTAINER Gabriel Fabien-Ouellet 3 | 4 | RUN yum -y install epel-release 5 | RUN yum-config-manager --enable epel 6 | RUN yum -y install hdf5-devel \ 7 | && yum -y install make \ 8 | && yum -y install git 9 | ENV CUDA_PATH /usr/local/cuda 10 | 11 | RUN git clone https://github.com/gfabieno/SeisCL.git 12 | RUN cd SeisCL/src \ 13 | && make all api=cuda nompi=1 H5CC=gcc 14 | 15 | ENV PATH="/SeisCL/src:${PATH}" 16 | RUN yum install -y python36 python36-devel 17 | RUN pip3 install tensorflow-gpu==1.14.0 \ 18 | && pip3 install scipy==1.2.0\ 19 | && pip3 install hdf5storage==0.1.15\ 20 | && pip3 install matplotlib==3.0.2\ 21 | && cd /SeisCL && pip3 install . 
22 | 
23 | 
24 | 
25 | 
26 | 
27 | 
28 | 
29 | 
30 | #FROM tensorflow/tensorflow:1.15.0rc2-gpu-py3
31 | #MAINTAINER Gabriel Fabien-Ouellet
32 | 
33 | 
34 | #RUN apt-get install -y git cuda-nvrtc-10-0 cuda-nvrtc-dev-10-0 cuda-toolkit-10-0
35 | #RUN git clone https://github.com/gfabieno/SeisCL.git
36 | #ENV LIBRARY_PATH /usr/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH}
37 | #ENV CUDA_PATH /usr/local/cuda
38 | #RUN cd SeisCL/src \
39 | #    && make all api=cuda nompi=1 H5HEAD=/usr/include/hdf5/serial H5LIB=/usr/lib/x86_64-linux-gnu/hdf5/serial/
40 | #ENV PATH="/SeisCL/src:${PATH}"
41 | #RUN pip install scipy==1.2.0\
42 | #    && pip install hdf5storage==0.1.15\
43 | #    && pip install matplotlib==3.0.2\
44 | #    && cd /SeisCL && pip install .
45 | 
46 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2019 Gabriel Fabien-Ouellet
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3492115.svg)](https://doi.org/10.5281/zenodo.3492115)
 2 | 
 3 | # 1D Velocity estimation using neural-networks
 4 | 
 5 | Code for reproducing results in ["Seismic velocity estimation: a deep recurrent neural-network approach"](tobeannounced)
 6 | 
 7 | 
 8 | ## Installation
 9 | 
10 | You should first clone this repository
11 | 
12 |     git clone https://github.com/gfabieno/Deep_1D_velocity.git
13 | 
14 | #### a) Use Docker (easiest)
15 | 
16 | We provide a Docker image that contains all necessary python libraries like Tensorflow
17 | and the seismic modeling code SeisCL.
18 | 
19 | You first need to install the Docker Engine, following the instructions [here](https://docs.docker.com/install/).
20 | To use GPUs, you also need to install the [Nvidia docker](https://github.com/NVIDIA/nvidia-docker).
21 | For the latter to work, Nvidia drivers should be installed.
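To check that containers can access the GPU, you can run the standard `nvidia-smi`
test (shown here with a CUDA 10.0 base image, to match the Dockerfile; any recent
CUDA image works):

    docker run --gpus all nvidia/cuda:10.0-base nvidia-smi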
22 | Then, when in the project repository, build the docker image as follows:
23 | 
24 |     docker build -t seisai:v0 .
25 | 
26 | You can then launch any of the python scripts in this repo as follows:
27 | 
28 |     docker run --gpus all -it \
29 |     -v `pwd`:`pwd` -w `pwd` \
30 |     --user $(id -u):$(id -g) \
31 |     seisai:v0 Case_article.py --logdir=./Case_article
32 | 
33 | This makes all GPUs accessible (`--gpus all`), mounts the current directory
34 | to the same path inside the container (second line), runs the container as the current user
35 | (for file permissions), and launches the script `Case_article.py`.
36 | 
37 | #### b) Install all requirements
38 | 
39 | It is recommended to create a new virtual environment for this project with Python3
40 | (see the example at the end of this section). The main python requirements are:
41 | * [tensorflow](https://www.tensorflow.org). This project was tested with versions 1.8 to 1.15.
42 | The preferred method of installation is through pip, but many options are available. If using pip, be sure to use a Python version <= 3.7, as tensorflow 1 is not available for Python 3.8 and later.
43 | * [SeisCL](https://github.com/gfabieno/SeisCL). Follow the instructions in the README of
44 | the SeisCL repository. Preferred compiling options for this project are api=opencl (use
45 | OpenCL, which is faster than CUDA for small models) and nompi=1, because no MPI parallelization is required.
46 | Be sure to install SeisCL's python wrapper.
47 | 
48 | Once SeisCL is installed, you can install all other python requirements with
49 | 
50 |     pip install -r requirements.txt
51 | 
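For reference, one possible way to set up the recommended virtual environment with
the built-in `venv` module (assuming a Python 3.7 interpreter is available; any
virtual environment tool works equally well):

    python3.7 -m venv venv
    source venv/bin/activate
    pip install -r requirements.txt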
52 | 
53 | ## Reproducing the results
54 | 
55 | Figures of the article can be reproduced with the script `reproduce_results.py`.
56 | This script automates the following steps.
57 | 
58 | #### 1. Training set creation
59 | 
60 | To create the synthetic training set, do:
61 | 
62 |     python Case_article.py --training=0
63 | 
64 | Note that to speed up data creation, several GPUs can be used as follows:
65 | 
66 |     export CUDA_VISIBLE_DEVICES=0; python Case_article.py --training=0
67 |     export CUDA_VISIBLE_DEVICES=1; python Case_article.py --training=0
68 | 
69 | The same strategy can be applied for the subsequent steps.
70 | 
71 | #### 2. Testing set creation (1D)
72 | 
73 | To create the synthetic 1D test set, do:
74 | 
75 |     python Case_article_test1D.py --testing=0
76 | 
77 | #### 3. Training
78 | 
79 | Perform the training with the following command:
80 | 
81 |     python Case_article.py --training=1 --logdir=Case_article
82 | 
83 | This will train 16 models, for which logs and results will be stored in `Case_article0`,
84 | `Case_article1` and so on.
85 | 
86 | #### 4. Plot training loss (Figure 2)
87 | 
88 | To reproduce Figure 2 for trained model 0:
89 | 
90 |     python plot_loss.py --logdir=Case_article0 --dataset_path=dataset_article/test/dhmin5layer_num_min10
91 | 
92 | #### 5. Testing in 1D (Figure 3)
93 | 
94 | Perform the testing on the 1D test set with
95 | 
96 |     python Case_article_test1D.py --testing=1 --logdir=Case_article*/4_* --niter=1000
97 | 
98 | This will test on all the trained models contained in `Case_articleX`, for models at iteration
99 | 1000 for the 4th step of the training. The predictions for each model are stored in
100 | `/dataset_article/test/Case_articleX`. Figure 3 is reproduced along with the test statistics.
101 | 
102 | #### 6. Testing in 2D (Figure 4)
103 | 
104 | To create the 2D test set and perform the testing:
105 | 
106 |     python Case_article_test2D.py --testing=1 --logdir=Case_article*/4_* --niter=1000
107 | 
108 | This will create the 2D testing set (may take a while) and test on all the trained models.
109 | Figure 4 is produced along with the test statistics.
110 | 
111 | #### 7. Testing on real data (Figures 5 and 6)
112 | 
113 | To download and preprocess the real data set:
114 | 
115 |     cd realdata
116 |     python Process_realdata.py
117 | 
118 | The testing is then carried out with
119 | 
120 |     python Case_article_testreal.py --plots=2 --logdir=Case_article*/4_* --niter=1000
121 | 
122 | This produces Figures 5 and 6.
123 | 
124 | 
125 | 
126 | 
127 | 
128 | 
129 | 
130 | 
131 | 
132 | 
133 | 
134 | 
135 | 
136 | 
137 | 
138 | 
139 | 
140 | 
141 | 
142 | 
143 | 
--------------------------------------------------------------------------------
/plot_example.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | # -*- coding: utf-8 -*-
  3 | """
  4 | Plot one example with generated data
  5 | """
  6 | 
  7 | from vrmslearn.ModelParameters import ModelParameters
  8 | from vrmslearn.SeismicGenerator import SeismicGenerator, mute_direct, random_time_scaling, random_noise, random_static
  9 | import argparse
 10 | import matplotlib.pyplot as plt
 11 | import numpy as np
 12 | import os
 13 | from shutil import rmtree
 14 | import h5py as h5
 15 | 
 16 | def plot_one_example(modeled_data, vrms, vp, tlabels, pars):
 17 |     """
 18 |     Plots one generated example: the interval velocity model, the modeled
 19 |     shot record and the corresponding RMS velocity are displayed side by
 20 |     side in a window.
 21 | 
 22 |     @params:
 23 | 
 24 |     @returns:
 25 |     """
 26 | 
 27 |     # Plot results
 28 |     fig, ax = plt.subplots(1, 3, figsize=[16, 8])
 29 | 
 30 |     im1 = ax[0].imshow(vp, cmap=plt.get_cmap('hot'), aspect='auto', vmin=0.9 * pars.vp_min, vmax=1.1 * pars.vp_max)
 31 |     ax[0].set_xlabel("X Cell Index," + " dh = " + str(pars.dh) + " m",
 32 |                      fontsize=12, fontweight='normal')
 33 |     ax[0].set_ylabel("Z Cell Index," + " dh = " + str(pars.dh) + " m",
 34 |                      fontsize=12, fontweight='normal')
 35 |     ax[0].set_title("P Interval Velocity", fontsize=16, fontweight='bold')
 36 |     p = ax[0].get_position().get_points().flatten()
 37 |     axis_cbar = fig.add_axes([p[0], 0.03, p[2] - p[0], 0.02])
 38 |     plt.colorbar(im1, cax=axis_cbar, orientation='horizontal')
 39 | 
 40 |     clip = 0.1
 41 |     vmax = np.max(modeled_data) * clip
 42 |     vmin = -vmax
 43 | 
 44 |     ax[1].imshow(modeled_data,
 45 |                  interpolation='bilinear',
 46 |                  cmap=plt.get_cmap('Greys'),
 47 |                  vmin=vmin, vmax=vmax,
 48 |                  aspect='auto')
 49 | 
 50 |     refpred = [ii for ii, t in enumerate(tlabels) if t == 1]
 51 |     if pars.minoffset == 0:
 52 |         toff = np.zeros(len(refpred)) + int(modeled_data.shape[1]/2)-2
 53 |     else:
 54 |         toff = np.zeros(len(refpred))
 55 |     ax[1].plot(toff, refpred, 'r*')
 56 | 
 57 |     ax[1].set_xlabel("Receiver Index", fontsize=12, fontweight='normal')
 58 |     ax[1].set_ylabel("Time Index," + " dt = " + str(pars.dt * 1000 * pars.resampling) + " ms",
 59 |                      fontsize=12, fontweight='normal')
 60 |     ax[1].set_title("Shot Gather", fontsize=16, fontweight='bold')
 61 | 
 62 |     ax[2].plot(vrms * (pars.vp_max-pars.vp_min) + pars.vp_min, np.arange(0, len(vrms)))
 63 |     ax[2].invert_yaxis()
 64 |     ax[2].set_ylim(top=0, bottom=len(vrms))
 65 |     ax[2].set_xlim(0.9 * pars.vp_min, 1.1 * pars.vp_max)
 66 |     ax[2].set_xlabel("RMS Velocity (m/s)", fontsize=12, fontweight='normal')
 67 |     ax[2].set_ylabel("Time Index," + " dt = " + str(pars.dt * 1000 * pars.resampling) + " ms",
 68 |                      fontsize=12, fontweight='normal')
 69 |     ax[2].set_title("P RMS Velocity", fontsize=16, fontweight='bold')
 70 | 
 71 |     plt.show()
 72 | 
 73 | 
 74 | if __name__ == "__main__":
 75 | 
 76 |     parser = argparse.ArgumentParser()
 77 | 
 78 |     # Add arguments to parse
 79 |     parser.add_argument("-l", "--nlayers",
 80 |                         type=int,
 81 |                         default=12,
 82 |                         help="number of layers : int > 0, default = 12")
 83 |     parser.add_argument("-d", "--device",
 84 |                         type=int,
 85 |                         default=4,
 86 |                         help="device type : int = 2 or 4, default = 4")
 87 |     parser.add_argument("-f", "--filename",
 88 |                         type=str,
 89 |                         default="",
 90 |                         help="name of the file containing the example")
 91 | 
 92 |     # Parse the input
 93 |     args = parser.parse_args()
 94 | 
 95 |     pars = ModelParameters()
 96 |     pars.dh = 6.25
 97 |     pars.peak_freq = 26
 98 |     pars.NX = 692*2
 99 |     pars.NZ = 752*2
100 |     pars.dt = 0.0004
101 |     pars.NT = int(8.0 / pars.dt)
102 |     pars.resampling = 10
103 | 
104 |     pars.dg = 8
105 |     pars.gmin = int(470 / pars.dh)
106 |     pars.gmax = int((470 + 72 * pars.dg * pars.dh) / pars.dh)
107 |     pars.minoffset = 470
108 | 
109 |     pars.vp_min = 1300.0  # minimum value of vp (in m/s)
110 |     pars.vp_max = 4000.0  # maximum value of vp (in m/s)
111 | 
112 |     pars.marine = True
113 |     pars.velwater = 1500
114 |     pars.d_velwater = 60
115 |     pars.water_depth = 3500
116 |     pars.dwater_depth = 1000
117 | 
118 |     pars.fs = False
119 |     pars.source_depth = (pars.Npad + 4) * pars.dh
120 |     pars.receiver_depth = (pars.Npad + 4) * pars.dh
121 |     pars.identify_direct = False
122 | 
123 |     pars.random_time_scaling = True
124 | 
125 |     gen = SeismicGenerator(pars)
126 |     if args.filename == "":
127 |         workdir = "./seiscl_workdir"
128 |         if not os.path.isdir(workdir):
129 |             os.mkdir(workdir)
130 |         data, vrms, vp, valid, tlabels = gen.compute_example(workdir=workdir)
131 |         if os.path.isdir(workdir):
132 |             rmtree(workdir)
133 |     else:
134 |         file = h5.File(args.filename, "r")
135 |         data = file['data'][:]
136 |         vrms = file['vrms'][:]
137 |         vp = file['vp'][:]
138 |         valid = file['valid'][:]
139 |         tlabels = file['tlabels'][:]
140 |         file.close()
141 | 
142 |     vp = np.stack([vp] * vp.shape[0], axis=1)
143 |     data = mute_direct(data, vp[0, 0], pars)
144 |     data = random_time_scaling(data, pars.dt * pars.resampling, emin=-2, emax=2)
145 |     data = random_noise(data, 0.02)
146 |     data = random_static(data, 2)
147 |     plot_one_example(data, vrms, vp, tlabels, pars)
148 | 
--------------------------------------------------------------------------------
/plot_loss.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python3
 2 | # -*- coding: utf-8 -*-
 3 | """
 4 | Script to plot the loss as a function of epoch (Reproduces Figure 2 of the article)
 5 | """
 6 | import tensorflow as tf
 7 | import argparse
 8 | import matplotlib as mpl
 9 | import matplotlib.pyplot as plt
10 | mpl.rcParams.update({'font.size': 7})
11 | import fnmatch
12 | import os
13 | import numpy as np
14 | from vrmslearn.ModelParameters import ModelParameters
15 | from vrmslearn.SeismicGenerator import SeismicGenerator
16 | from vrmslearn.Tester import Tester
17 | from vrmslearn.RCNN import RCNN
18 | from vrmslearn.ModelGenerator import interval_velocity_time
19 | import h5py as h5
20 | 
21 | def get_test_error(dirlog, savepath, dataset_path):
22 |     """
23 |     Compute the error on a test set
24 | 
25 |     @params:
26 |     dirlog (str) : Directory containing the trained model.
27 | savepath (str): Directory in which to save the predictions 28 | dataset_path (str): Directory of the test set 29 | 30 | @returns: 31 | vrms_rmse (float) : RMSE for Vrms 32 | vint_rmse: RMSE for Vint 33 | true_pos: Primary reflection identification: Ratio of true positive 34 | true_neg: Primary reflection identification: Ratio of true negative 35 | false_pos: Primary reflection identification: Ratio of false positive 36 | false_neg: Primary reflection identification: Ratio of false negative 37 | """ 38 | 39 | if os.path.isfile(savepath + ".hdf5"): 40 | file = h5.File(savepath + ".hdf5") 41 | vint_rmse = file["vint_rmse"].value 42 | vrms_rmse = file["vrms_rmse"].value 43 | true_pos = file["true_pos"].value 44 | true_neg = file["true_neg"].value 45 | false_pos = file["false_pos"].value 46 | false_neg = file["false_neg"].value 47 | file.close() 48 | else: 49 | pars = ModelParameters() 50 | pars.read_parameters_from_disk(dataset_path+"/model_parameters.hdf5") 51 | seismic_gen = SeismicGenerator(model_parameters=pars) 52 | nn = RCNN(input_size=seismic_gen.image_size, 53 | batch_size=100) 54 | tester = Tester(NN=nn, data_generator=seismic_gen) 55 | toeval = [nn.output_ref, nn.output_vrms, nn.output_vint] 56 | toeval_names = ["ref", "vrms", "vint"] 57 | vint_rmse_all = 0 58 | vrms_rmse_all = 0 59 | true_pos_all = 0 60 | true_neg_all = 0 61 | false_pos_all = 0 62 | false_neg_all = 0 63 | 64 | tester.test_dataset(savepath=savepath, 65 | toeval=toeval, 66 | toeval_names=toeval_names, 67 | restore_from=dirlog, 68 | testpath = dataset_path) 69 | vp, vint_pred, masks, lfiles, pfiles = tester.get_preds(labelname="vp", 70 | predname="vint", 71 | maskname="valid", 72 | savepath=savepath, 73 | testpath = dataset_path) 74 | vrms, vrms_pred, _, _ , _ = tester.get_preds(labelname="vrms", 75 | predname="vrms", 76 | savepath=savepath, 77 | testpath = dataset_path) 78 | ref, ref_pred, _, _ , _ = tester.get_preds(labelname="tlabels", 79 | predname="ref", 80 | savepath=savepath, 81 | testpath = dataset_path) 82 | vint = [None for _ in range(len(vp))] 83 | for ii in range(len(vint)): 84 | vint[ii] = interval_velocity_time(vp[ii], pars=pars) 85 | vint[ii] = vint[ii][::pars.resampling] 86 | vint_pred[ii] = vint_pred[ii]*(pars.vp_max - pars.vp_min) + pars.vp_min 87 | vrms_pred[ii] = vrms_pred[ii] * (pars.vp_max - pars.vp_min) + pars.vp_min 88 | vrms[ii] = vrms[ii] * (pars.vp_max - pars.vp_min) + pars.vp_min 89 | ref_pred[ii] = np.argmax(ref_pred[ii], axis=1) 90 | ind0 = np.nonzero(ref[ii])[0][0] 91 | masks[ii][0:ind0] = 0 92 | vint = np.array(vint) 93 | vint_pred = np.array(vint_pred) 94 | vrms = np.array(vrms) 95 | vrms_pred = np.array(vrms_pred) 96 | ref = np.array(ref) 97 | ref_pred = np.array(ref_pred) 98 | 99 | 100 | 101 | masks = np.array(masks) 102 | nsamples = np.sum(masks == 1) 103 | vint_rmse = np.sqrt(np.sum(masks * (vint - vint_pred)**2) / nsamples) 104 | vrms_rmse = np.sqrt(np.sum(masks * (vrms - vrms_pred) ** 2) / nsamples) 105 | 106 | nsamples = ref.flatten().shape[0] 107 | true_pos = np.sum(((ref - ref_pred) == 0) * (ref == 1)) / nsamples 108 | true_neg = np.sum(((ref - ref_pred) == 0) * (ref == 0)) / nsamples 109 | false_pos = np.sum((ref - ref_pred) == -1) / nsamples 110 | false_neg = np.sum((ref - ref_pred) == 1) / nsamples 111 | 112 | file = h5.File(savepath + ".hdf5") 113 | file["vint_rmse"] = vint_rmse 114 | file["vrms_rmse"] = vrms_rmse 115 | file["true_pos"] = true_pos 116 | file["true_neg"] = true_neg 117 | file["false_pos"] = false_pos 118 | file["false_neg"] = false_neg 119 | 
file.close() 120 | 121 | return vrms_rmse, vint_rmse, true_pos, true_neg, false_pos, false_neg 122 | 123 | if __name__ == "__main__": 124 | 125 | 126 | # Initialize argument parser 127 | parser = argparse.ArgumentParser() 128 | 129 | # Add arguments to parse for training 130 | parser.add_argument( 131 | "--logdir", 132 | type=str, 133 | default="Case_article0", 134 | help="name of the directory to save logs : str" 135 | ) 136 | parser.add_argument( 137 | "--dataset_path", 138 | type=str, 139 | default="dataset_article/test/dhmin5layer_num_min10", 140 | help="path of the test dataset" 141 | ) 142 | 143 | # Parse the input for training parameters 144 | args, unparsed = parser.parse_known_args() 145 | training_size = 40000 146 | batch_size = 40 147 | savefile = "Paper/Fig/Case4_loss" 148 | 149 | 150 | # Obtain all subdirectories containing tensorflow models inside args.logdir. 151 | dirs = [] 152 | dir_models = {} 153 | for dir1 in os.listdir(args.logdir): 154 | path1 = os.path.join(args.logdir, dir1) 155 | if os.path.isdir(path1): 156 | files = [] 157 | for dir2 in os.listdir(path1): 158 | path2 = os.path.join(path1, dir2) 159 | if os.path.isfile(path2): 160 | files.append(path2) 161 | efiles = fnmatch.filter(files, os.path.join(path1,"events.*")) 162 | efiles.sort() 163 | dirs.append(efiles) 164 | allmodels = fnmatch.filter(files, os.path.join(path1,"model.ckpt-*.meta")) 165 | allmodels.sort() 166 | dir_models[dirs[-1][-1]] = [a[:-5] for a in allmodels] 167 | for dir in dirs: 168 | print(dir) 169 | 170 | # Create the figure 171 | fig, ax = plt.subplots(3, 1, figsize=[8 / 2.54, 12 / 2.54]) 172 | step0= 0 173 | plots = [[] for _ in range(3)] 174 | labels = ["Phase 0", "Phase 1", "Phase 2"] 175 | for ii, dir in enumerate(dirs[:-2]): 176 | step = [] 177 | loss = [] 178 | # Get Loss for each stage of training and each iteration 179 | for e in dir: 180 | for summary in tf.train.summary_iterator(e): 181 | for v in summary.summary.value: 182 | if v.tag == 'Loss_Function/loss': 183 | loss.append(v.simple_value) 184 | step.append(summary.step + step0) 185 | inds = np.argsort(step) 186 | step = np.array(step)[inds][1:] 187 | loss = np.array(loss)[inds][1:] 188 | plots[ii], = ax[0].semilogy(step * batch_size /training_size, loss, basey=2) 189 | 190 | if ii!=0: 191 | steprms0 = steprms[-1] 192 | vrms0 = vrms[-1] 193 | vint0 = vint[-1] 194 | 195 | # Compute test set error for each model during training (or retrieve it) 196 | steprms = [] 197 | vrms = [] 198 | vint = [] 199 | for dirlog in dir_models[dir[-1]]: 200 | savepath = dirlog + "_test/" + args.dataset_path 201 | if not os.path.isdir(savepath): 202 | os.makedirs(savepath) 203 | 204 | vrms_rmse, vint_rmse, _, _, _, _ = get_test_error(dirlog, savepath, args.dataset_path) 205 | steprms.append(int(dirlog.split("-")[-1]) + step0) 206 | vrms.append(vrms_rmse) 207 | vint.append(vint_rmse) 208 | inds = np.argsort(steprms) 209 | steprms = np.array(steprms)[inds][1:] 210 | vrms = np.array(vrms)[inds][1:] 211 | vint = np.array(vint)[inds][1:] 212 | if ii!=0: 213 | steprms = np.insert(steprms, 0, steprms0) 214 | vrms = np.insert(vrms, 0, vrms0) 215 | vint = np.insert(vint, 0, vint0) 216 | ax[1].plot(steprms * batch_size /training_size, vrms) 217 | ax[2].plot(steprms * batch_size /training_size, vint) 218 | 219 | step0 = step[-1] 220 | 221 | # Figure presentation 222 | ax[0].set_xlabel("Epoch") 223 | ax[0].set_ylabel("Loss") 224 | ax[1].set_xlabel("Epoch") 225 | ax[1].set_ylabel("RMSE (m/s)") 226 | ax[2].set_xlabel("Epoch") 227 | ax[2].set_ylabel("RMSE (m/s)") 
228 |     ax[0].legend(plots, labels,
229 |                  loc='upper right',
230 |                  bbox_to_anchor=(1.15, 1.35),
231 |                  handlelength=0.4)
232 |     plt.tight_layout(rect=[0.001, 0, 0.9999, 1])
233 |     plt.savefig(savefile, dpi=600)
234 |     plt.savefig(savefile+"_lowres", dpi=100)
235 |     plt.show()
236 | 
237 | 
238 | 
--------------------------------------------------------------------------------
/plot_prediction.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | # -*- coding: utf-8 -*-
  3 | """
  4 | Functions to plot the NN predictions
  5 | """
  6 | from vrmslearn.Trainer import Trainer
  7 | from vrmslearn.SeismicGenerator import SeismicGenerator
  8 | from vrmslearn.RCNN import RCNN
  9 | from vrmslearn.ModelParameters import ModelParameters
 10 | from vrmslearn.SeismicGenerator import SeismicGenerator, mute_direct, random_static, random_noise, mute_nearoffset, random_filt
 11 | from semblance.nmo_correction import semblance_gather, nmo_correction
 12 | import argparse
 13 | import matplotlib as mpl
 14 | import matplotlib.pyplot as plt
 15 | mpl.rcParams.update({'font.size': 7})
 16 | from mpl_toolkits.axes_grid1 import make_axes_locatable
 17 | import matplotlib.gridspec as gridspec
 18 | import numpy as np
 19 | import os
 20 | from shutil import rmtree
 21 | import h5py as h5
 22 | from scipy.signal import butter, lfilter
 23 | from scipy import ndimage, misc
 24 | 
 25 | 
 26 | 
 27 | def butter_bandpass(lowcut, highcut, fs, order=5):
 28 |     nyq = 0.5 * fs
 29 |     low = lowcut / nyq
 30 |     high = highcut / nyq
 31 |     b, a = butter(order, [low, high], btype='band')
 32 |     return b, a
 33 | def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
 34 |     b, a = butter_bandpass(lowcut, highcut, fs, order=order)
 35 |     y = lfilter(b, a, data)
 36 |     return y
 37 | 
 38 | def plot_predictions(modeled_data,
 39 |                      vp, vrms, vpred, tlabels, refpred, vint, vint_pred, pars):
 40 |     """
 41 |     Plots the labels and predictions for one example: the velocity model, the
 42 |     shot gather with the identified reflections, and the true and predicted
 43 |     velocities are displayed side by side in a window.
44 | 45 | @params: 46 | 47 | @returns: 48 | """ 49 | 50 | # Plot results 51 | fig, ax = plt.subplots(1, 3, figsize=[16, 8]) 52 | 53 | im1 = ax[0].imshow(vp, cmap=plt.get_cmap('hot'), aspect='auto', 54 | vmin=0.9 * pars.vp_min, vmax=1.1 * pars.vp_max) 55 | ax[0].set_xlabel("X Cell Index," + " dh = " + str(pars.dh) + " m", 56 | fontsize=12, fontweight='normal') 57 | ax[0].set_ylabel("Z Cell Index," + " dh = " + str(pars.dh) + " m", 58 | fontsize=12, fontweight='normal') 59 | ax[0].set_title("P Interval Velocity", fontsize=16, fontweight='bold') 60 | p = ax[0].get_position().get_points().flatten() 61 | axis_cbar = fig.add_axes([p[0], 0.03, p[2] - p[0], 0.02]) 62 | plt.colorbar(im1, cax=axis_cbar, orientation='horizontal') 63 | 64 | clip = 0.05 65 | vmax = np.max(modeled_data) * clip 66 | vmin = -vmax 67 | 68 | ax[1].imshow(modeled_data, 69 | interpolation='bilinear', 70 | cmap=plt.get_cmap('Greys'), 71 | vmin=vmin, vmax=vmax, 72 | aspect='auto') 73 | tlabels = [ii for ii, t in enumerate(tlabels) if t == 1] 74 | 75 | toff = np.zeros(len(tlabels)) + int(modeled_data.shape[1]/2)+1 76 | ax[1].plot(toff, tlabels, '*') 77 | refpred = [ii for ii, t in enumerate(refpred) if t == 1] 78 | toff = np.zeros(len(refpred)) + int(modeled_data.shape[1]/2)-2 79 | ax[1].plot(toff, refpred, 'r*') 80 | ax[1].set_xlabel("Receiver Index", fontsize=12, fontweight='normal') 81 | ax[1].set_ylabel("Time Index," + " dt = " + str(pars.dt * 1000 * pars.resampling) + " ms", 82 | fontsize=12, fontweight='normal') 83 | ax[1].set_title("Shot Gather", fontsize=16, fontweight='bold') 84 | 85 | ax[2].plot(vrms * (pars.vp_max-pars.vp_min) + pars.vp_min, 86 | np.arange(0, len(vrms))) 87 | ax[2].plot(vpred * (pars.vp_max - pars.vp_min) + pars.vp_min, 88 | np.arange(0, len(vpred))) 89 | ax[2].plot(vint * (pars.vp_max-pars.vp_min) + pars.vp_min, 90 | np.arange(0, len(vint))) 91 | ax[2].plot(vint_pred * (pars.vp_max - pars.vp_min) + pars.vp_min, 92 | np.arange(0, len(vint_pred))) 93 | ax[2].invert_yaxis() 94 | ax[2].set_ylim(top=0, bottom=len(vrms)) 95 | ax[2].set_xlim(0.9 * pars.vp_min, 1.1 * pars.vp_max) 96 | ax[2].set_xlabel("RMS Velocity (m/s)", fontsize=12, fontweight='normal') 97 | ax[2].set_ylabel("Time Index," + " dt = " + str(pars.dt * 1000 * pars.resampling) + " ms", 98 | fontsize=12, fontweight='normal') 99 | ax[2].set_title("P RMS Velocity", fontsize=16, fontweight='bold') 100 | 101 | plt.show() 102 | 103 | 104 | def plot_predictions_semb3(modeled_data, 105 | vrms, vpred, 106 | tlabels, refpred, 107 | vint, vint_pred, 108 | masks, 109 | pars, dv=30, vmin=None, vmax = None, 110 | clip=0.05, clipsemb=1.0, 111 | plot_semb = True, 112 | with_nmo = False, 113 | textlabels = None, 114 | savefile=None, 115 | vint_pred_std=None, 116 | vpred_std=None, tmin=None, tmax=None): 117 | """ 118 | This method creates one example by generating a random velocity model, 119 | modeling a shot record with it, and also computes the vrms. The three 120 | results are displayed side by side in a window. 
121 | 122 | @params: 123 | 124 | @returns: 125 | """ 126 | 127 | NT = modeled_data[0].shape[0] 128 | ng = modeled_data[0].shape[1] 129 | dt = pars.resampling * pars.dt 130 | if vmin is None: 131 | vmin = pars.vp_min 132 | if vmax is None: 133 | vmax = pars.vp_max 134 | 135 | if pars.gmin ==-1 or pars.gmax ==-1: 136 | offsets = (np.arange(0, ng) - (ng) / 2) * pars.dh * pars.dg 137 | else: 138 | offsets = (np.arange(pars.gmin, pars.gmax, pars.dg)) * pars.dh 139 | 140 | times = np.reshape(np.arange(0, NT * dt, dt) - pars.tdelay, [-1]) 141 | vels = np.arange(vmin - 5*dv, vmax + 2*dv, dv) 142 | 143 | if with_nmo: 144 | fig, ax = plt.subplots(3, 3, figsize=[11 / 2.54, 18 / 2.54]) 145 | else: 146 | fig, ax = plt.subplots(3, 2, figsize=[8 / 2.54, 18 / 2.54]) 147 | 148 | titles = [["a)", "b)", "c)"], ["d)", "e)", "f)"], ["g)", "h)", "i)"]] 149 | labels = ["True", "Pred", "Vint true", "Vint pred", "Vrms true", "Vrms pred", "Vrms std", "Vint std"] 150 | plots = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 151 | 152 | 153 | for ii in range(3): 154 | if plot_semb: 155 | semb = semblance_gather(modeled_data[ii], times, offsets, vels) 156 | 157 | vmax = np.max(modeled_data[ii]) * clip 158 | vmin = -vmax 159 | ax[ii, 0].imshow(modeled_data[ii], 160 | interpolation='bilinear', 161 | cmap=plt.get_cmap('Greys'), 162 | extent=[offsets[0] / 1000, offsets[-1] / 1000, times[-1], times[0]], 163 | vmin=vmin, vmax=vmax, 164 | aspect='auto') 165 | ymin, ymax = ax[ii, 0].get_ylim() 166 | if tmin is not None: 167 | if type(tmin) is list: 168 | ymax = tmin[ii] 169 | else: 170 | ymax = tmin 171 | if tmax is not None: 172 | if type(tmax) is list: 173 | ymin = tmax[ii] 174 | else: 175 | ymin = tmax 176 | xmin, xmax = ax[ii, 0].get_xlim() 177 | if tlabels is not None: 178 | tlabels[ii] = [jj * dt - pars.tdelay for jj, t in enumerate(tlabels[ii]) if t == 1] 179 | refpred[ii] = [jj * dt - pars.tdelay for jj, t in enumerate(refpred[ii]) if t == 1] 180 | if np.min(offsets) < 0: 181 | if tlabels is not None: 182 | tofflabels = np.zeros(len(tlabels[ii])) - 2 * pars.dh * pars.dg 183 | toffpreds = np.zeros(len(refpred[ii])) + 2 * pars.dh * pars.dg 184 | else: 185 | if tlabels is not None: 186 | tofflabels = np.zeros(len(tlabels[ii])) + np.min(np.abs(offsets)) + 1 * pars.dh * pars.dg 187 | toffpreds = np.zeros(len(refpred[ii])) + np.min(np.abs(offsets)) + 3 * pars.dh * pars.dg 188 | if tlabels is not None: 189 | plots[0], = ax[ii, 0].plot(tofflabels / 1000, tlabels[ii], 'r*', markersize=3) 190 | plots[1], = ax[ii, 0].plot(toffpreds / 1000, refpred[ii], 'b*', markersize=3) 191 | 192 | ax[ii, 0].set_xlabel("Offset (km)") 193 | ax[ii, 0].set_ylabel("Time (s)") 194 | #ax[ii, 0].set_title(titles[0][0]) 195 | 196 | 197 | ax[ii, 0].text(xmin - 0.3 * (xmax-xmin), ymax + 0.1*(ymax-ymin), 198 | titles[0][ii], fontsize="large") 199 | 200 | # ax[ii, 2 * jj].xaxis.set_ticks(np.arange(-1, 1.5, 0.5)) 201 | 202 | if ii == 0: 203 | ax[ii, 0].legend(plots[0:2], labels[0:2], loc='upper right', 204 | bbox_to_anchor=(1.13, 1.29)) 205 | if plot_semb: 206 | vmax = np.max(semb) * clipsemb 207 | vmin = np.min(semb) 208 | ax[ii, 1].imshow(semb, 209 | extent=[(vels[0] - dv / 2) / 1000, 210 | (vels[-1] - dv / 2) / 1000, times[-1], times[0]], 211 | cmap=plt.get_cmap('YlOrRd'), 212 | vmin=vmin, vmax=vmax, 213 | interpolation='bilinear', 214 | aspect='auto') 215 | if masks is not None: 216 | if vint is not None: 217 | vint[ii][masks[ii] == 0] = np.NaN 218 | if vrms is not None: 219 | vrms[ii][masks[ii] == 0] = np.NaN 220 | vint_pred[ii][masks[ii] == 0] = np.NaN 221 | 
vpred[ii][masks[ii] == 0] = np.NaN 222 | if vint is not None: 223 | plots[2], = ax[ii, 1].plot(vint[ii] / 1000, times, '-', color='lightgray') 224 | if vint_pred_std is not None: 225 | plots[6], = ax[ii, 1].plot((vint_pred[ii] + vint_pred_std[ii]) / 1000, times, '-', color='lightgreen', alpha=0.4) 226 | ax[ii, 1].plot((vint_pred[ii] - vint_pred_std[ii]) / 1000, times, '-', color='lightgreen', alpha=0.4) 227 | if vrms is not None: 228 | plots[4], = ax[ii, 1].plot(vrms[ii] / 1000, times, '-g', color='black') 229 | plots[5], = ax[ii, 1].plot(vpred[ii] / 1000, times, '-b') 230 | plots[3], = ax[ii, 1].plot(vint_pred[ii] / 1000, times, '-', color='lightgreen') 231 | 232 | if vpred_std is not None: 233 | plots[7], = ax[ii, 1].plot((vpred[ii] + vpred_std[ii]) / 1000, times, '-b', alpha=0.2) 234 | ax[ii, 1].plot((vpred[ii] - vpred_std[ii]) / 1000, times, '-b', alpha=0.2) 235 | 236 | 237 | ax[ii, 1].xaxis.set_ticks(np.arange(np.ceil(np.min(vels)/1000), 238 | 1+np.floor(np.max(vels)/1000))) 239 | 240 | ax[ii, 1].set_ylim(bottom=ymin, top=ymax) 241 | ax[ii, 0].set_ylim(bottom=ymin, top=ymax) 242 | xmin, xmax = ax[ii, 1].get_xlim() 243 | ax[ii, 1].set_xlabel("Velocity (km/s)") 244 | ax[ii, 1].set_ylabel("Time (s)") 245 | ax[ii, 1].text(xmin - 0.3 * (xmax - xmin), ymax + 0.1 * (ymax - ymin), 246 | titles[1][ii], fontsize="large") 247 | if textlabels: 248 | ax[ii, 1].text(xmin + 0.94 * (xmax - xmin), ymax + - 0.03 * (ymax - ymin), 249 | textlabels[ii], ha="right", va="top", fontsize="large") 250 | 251 | if ii == 0: 252 | ax[ii, 1].legend(plots[2:6], labels[2:6], 253 | loc='upper right', 254 | bbox_to_anchor=(1.15, 1.50), 255 | handlelength=0.4) 256 | if with_nmo: 257 | vmax = np.max(modeled_data[ii]) * clip 258 | vmin = -vmax 259 | data_nmo = nmo_correction(modeled_data[ii], times, offsets, vpred[ii], stretch_mute=0.3) 260 | ax[ii, 2].imshow(data_nmo, 261 | interpolation='bilinear', 262 | cmap=plt.get_cmap('Greys'), 263 | extent=[offsets[0] / 1000, offsets[-1] / 1000, times[-1], times[0]], 264 | vmin=vmin, vmax=vmax, 265 | aspect='auto') 266 | ax[ii, 2].set_ylim(bottom=ymin, top=ymax) 267 | ax[ii, 2].set_xlabel("Offset (km)") 268 | ax[ii, 2].set_ylabel("Time (s)") 269 | xmin, xmax = ax[ii, 0].get_xlim() 270 | ax[ii, 2].text(xmin - 0.3 * (xmax-xmin), ymax + 0.1*(ymax-ymin), 271 | titles[2][ii], fontsize="large") 272 | 273 | plt.tight_layout(rect=[0, 0, 1, 0.995]) 274 | if savefile: 275 | plt.savefig(savefile, dpi=600) 276 | plt.savefig(savefile+"_lowres", dpi=100) 277 | plt.show() 278 | 279 | 280 | 281 | 282 | if __name__ == "__main__": 283 | 284 | # Set pref_device_type = 4 285 | pref_device_type = 4 286 | 287 | # Initialize argument parser 288 | parser = argparse.ArgumentParser() 289 | 290 | # Add arguments to parse for training 291 | parser.add_argument( 292 | "--logdir", 293 | type=str, 294 | default="logs", 295 | help="name of the directory to save logs : str" 296 | ) 297 | parser.add_argument( 298 | "--filename", 299 | type=str, 300 | default="dataset_1/dhmin40_layer_num_min5/example_1_31891", 301 | help="name of the directory to save logs : str" 302 | ) 303 | parser.add_argument( 304 | "--fileparam", 305 | type=str, 306 | default="dataset_1/dhmin40_layer_num_min5/example_1_31891", 307 | help="name of the directory that contains the model parameters: str" 308 | ) 309 | parser.add_argument( 310 | "--niter", 311 | type=int, 312 | default=5000, 313 | help="number of training iterations : int > 0" 314 | ) 315 | parser.add_argument( 316 | "--nbatch", 317 | type=int, 318 | default=10, 319 | help="number 
of gathers in one batch : int > 0"
320 |     )
321 |     parser.add_argument(
322 |         "--nlayers",
323 |         type=int,
324 |         default=2,
325 |         help="number of layers in the model : int > 0"
326 |     )
327 |     parser.add_argument(
328 |         "--layer_num_min",
329 |         type=int,
330 |         default=5,
331 |         help="minimum number of layers in the model : int > 0"
332 |     )
333 |     parser.add_argument("-d", "--device",
334 |                         type=int,
335 |                         default=4,
336 |                         help="device type : int = 2 or 4, default = 4")
337 | 
338 | 
339 |     # Parse the input for training parameters
340 |     args, unparsed = parser.parse_known_args()
341 | 
342 |     # Test for input errors
343 |     def print_usage_error_message():
344 |         print("\nUsage error.\n")
345 |         parser.print_help()
346 | 
347 |     if args.niter < 0:
348 |         print_usage_error_message()
349 |         exit()
350 | 
351 |     if args.nlayers <= -1:
352 |         print_usage_error_message()
353 |         exit()
354 | 
355 |     if args.nbatch <= 0:
356 |         print_usage_error_message()
357 |         exit()
358 | 
359 |     parameters = ModelParameters()
360 |     parameters.read_parameters_from_disk(args.fileparam)
361 |     parameters.device_type = args.device
362 |     parameters.num_layers = args.nlayers
363 |     #parameters.read_parameters_from_disk(filename='dataset_3/dhmin40_layer_num_min5/model_parameters.hdf5')
364 |     gen = SeismicGenerator(parameters)
365 | 
366 |     parameters.mute_nearoffset = False
367 |     parameters.random_static = False
368 |     parameters.random_noise = False
369 |     data, vrms, vint, valid, tlabels = gen.read_example(".", filename=args.filename)
370 | 
371 | 
372 |     # data = mute_direct(data, 1500, parameters)
373 |     # #data = random_static(data, 2)
374 |     ## data = random_noise(data, 0.01)
375 |     ## data = mute_nearoffset(data, 10)
376 |     ## data = random_filt(data, 9)
377 |     data = np.expand_dims(data, axis=-1)
378 |     data = np.expand_dims(data, axis=0)
379 |     vrms = np.expand_dims(vrms, axis=0)
380 |     vint = np.expand_dims(vint, axis=0)
381 |     valid = np.expand_dims(valid, axis=0)
382 |     tlabels = np.expand_dims(tlabels, axis=0)
383 |     f = h5.File(args.filename, "r")
384 |     vp = f['vp'][:]
385 |     f.close()
386 | 
387 | 
388 | 
389 |     nn = RCNN(input_size=gen.image_size,
390 |               batch_size=1)
391 |     trainer = Trainer(NN=nn,
392 |                       data_generator=gen,
393 |                       totrain=False)
394 | 
395 |     preds = trainer.evaluate(toeval=[nn.output_ref, nn.output_vint, nn.output_vrms],
396 |                              niter=args.niter,
397 |                              dir=args.logdir,
398 |                              batch=[data, vrms, vint, valid, tlabels])
399 | 
400 |     refpred = np.argmax(preds[0][0,:], axis=1)
401 |     vint_pred = preds[1]
402 |     vpred = preds[2]
403 |     vp = np.stack([vp] * vp.shape[0], axis=1)
404 | 
405 | 
406 |     plot_predictions(data[0, :, :, 0],
407 |                      vp,
408 |                      vrms[0, :],
409 |                      vpred[0, :],
410 |                      tlabels[0, :],
411 |                      refpred, vint[0, :], vint_pred[0, :], parameters)
412 | 
--------------------------------------------------------------------------------
/plot_real_synth.py:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | from vrmslearn.ModelParameters import ModelParameters
 4 | from vrmslearn.SeismicGenerator import SeismicGenerator, mute_direct
 5 | import argparse
 6 | import matplotlib.pyplot as plt
 7 | import numpy as np
 8 | import os
 9 | from shutil import rmtree
10 | import h5py as h5
11 | 
12 | 
13 | def plot_two_gathers(data1, data2, pars):
14 |     """
15 |     Compares two shot gathers
16 | 
17 |     @params:
18 | 
19 |     @returns:
20 |     """
21 | 
22 |     # Plot results
23 |     fig, ax = plt.subplots(1, 2, figsize=[16, 8])
24 | 
25 | 
26 |     clip = 0.1
27 |     vmax = np.max(data1) * clip
28 |     vmin = -vmax
29 | 
30 |     ax[0].imshow(data1,
31 
| interpolation='bilinear', 32 | cmap=plt.get_cmap('Greys'), 33 | vmin=vmin, vmax=vmax, 34 | aspect='auto') 35 | 36 | clip = 0.1 37 | vmax = np.max(data2) * clip 38 | vmin = -vmax 39 | 40 | ax[1].imshow(data2, 41 | interpolation='bilinear', 42 | cmap=plt.get_cmap('Greys'), 43 | vmin=vmin, vmax=vmax, 44 | aspect='auto') 45 | plt.show() 46 | 47 | def plot_two_traces(data1, data2, pars): 48 | """ 49 | Compares two shot gathers 50 | 51 | @params: 52 | 53 | @returns: 54 | """ 55 | 56 | # Plot results 57 | fig, ax = plt.subplots(2, 1, figsize=[16, 8]) 58 | 59 | 60 | clip = 0.1 61 | vmax = np.max(data1) * clip 62 | vmin = -vmax 63 | 64 | ax[0].plot(data1[:,1]) 65 | ax[1].plot(data2[:,1]) 66 | plt.show() 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | parser = argparse.ArgumentParser() 72 | 73 | # Add arguments to parse 74 | parser.add_argument("-f1", "--filename1", 75 | type=str, 76 | default="", 77 | help="name of the file containing the synth data") 78 | parser.add_argument("-f2", "--filename2", 79 | type=str, 80 | default="", 81 | help="name of the file containing the real data") 82 | 83 | # Parse the input 84 | args = parser.parse_args() 85 | 86 | 87 | def print_usage_error_message(): 88 | print("\nUsage error.\n") 89 | parser.print_help() 90 | 91 | 92 | pars = ModelParameters() 93 | pars.dh = 6.25 94 | pars.peak_freq = 26 95 | pars.NX = 692 * 2 96 | pars.NZ = 752 * 2 97 | pars.dt = 0.0004 98 | pars.NT = int(8.0 / pars.dt) 99 | pars.resampling = 10 100 | 101 | pars.dg = 8 102 | pars.gmin = int(470 / pars.dh) 103 | pars.gmax = int((470 + 72 * pars.dg * pars.dh) / pars.dh) 104 | pars.minoffset = 470 105 | 106 | pars.vp_min = 1300.0 # maximum value of vp (in m/s) 107 | pars.vp_max = 4000.0 # minimum value of vp (in m/s) 108 | 109 | pars.marine = True 110 | pars.velwater = 1500 111 | pars.d_velwater = 60 112 | pars.water_depth = 3500 113 | pars.dwater_depth = 1000 114 | 115 | pars.fs = False 116 | pars.source_depth = (pars.Npad + 4) * pars.dh 117 | pars.receiver_depth = (pars.Npad + 4) * pars.dh 118 | pars.identify_direct = False 119 | 120 | file = h5.File(args.filename1, "r") 121 | data1 = file['data'][:] 122 | vp = file['vp'][:] 123 | data1 = mute_direct(data1, vp[0], pars) 124 | file.close() 125 | 126 | file = h5.File(args.filename2, "r") 127 | data2 = file["data_cmp"][:data1.shape[0], 1:72] 128 | file.close() 129 | 130 | plot_two_gathers(data1, data2, pars) 131 | plot_two_traces(data1, data2, pars) 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /realdata/Process_realdata.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Processing of real data available publicly at: 5 | https://cmgds.marine.usgs.gov/fan_info.php?fan=1978-015-FA 6 | """ 7 | 8 | import urllib.request 9 | import os 10 | import segyio 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | import h5py as h5 14 | import scipy.ndimage as ndimage 15 | import math 16 | 17 | if __name__ == "__main__": 18 | 19 | """ 20 | __________________Download the data______________________ 21 | """ 22 | 23 | datapath = "./USGS_line32" 24 | 25 | files = {"32obslog.pdf": "http://cotuit.er.usgs.gov/files/1978-015-FA/NL/001/01/32-obslogs/32obslog.pdf", 26 | "report.pdf": "https://pubs.usgs.gov/of/1995/0027/report.pdf", 27 | "CSDS32_1.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/39/CSDS32_1.SGY"} 28 | 29 | 30 | dfiles = {"U32A_01.SGY": 
"http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_01.SGY", 31 | "U32A_02.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_02.SGY", 32 | "U32A_03.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_03.SGY", 33 | "U32A_04.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_04.SGY", 34 | "U32A_05.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_05.SGY", 35 | "U32A_06.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_06.SGY", 36 | "U32A_07.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_07.SGY", 37 | "U32A_08.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_08.SGY", 38 | "U32A_09.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_09.SGY"} 39 | # "U32A_10.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_10.SGY", 40 | # "U32A_11.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_11.SGY", 41 | # "U32A_12.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_12.SGY", 42 | # "U32A_13.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_13.SGY", 43 | # "U32A_14.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_14.SGY", 44 | # "U32A_15.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_15.SGY", 45 | # "U32A_16.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_16.SGY", 46 | # "U32A_17.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_17.SGY", 47 | # "U32A_18.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_18.SGY", 48 | # "U32A_19.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_19.SGY", 49 | # "U32A_20.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_20.SGY", 50 | # "U32A_21.SGY": "http://cotuit.er.usgs.gov/files/1978-015-FA/SE/001/18/U32A_21.SGY"} 51 | 52 | 53 | fkeys = sorted(list(dfiles.keys())) 54 | if not os.path.isdir(datapath): 55 | os.mkdir(datapath) 56 | 57 | for file in files: 58 | if not os.path.isfile(datapath + "/" + file): 59 | urllib.request.urlretrieve(files[file], datapath + "/" + file) 60 | 61 | for file in dfiles: 62 | if not os.path.isfile(datapath + "/" + file): 63 | print(file) 64 | urllib.request.urlretrieve(dfiles[file], datapath + "/" + file) 65 | 66 | 67 | """ 68 | __________________Read the segy into numpy______________________ 69 | """ 70 | data = [] 71 | fid = [] 72 | cid = [] 73 | NT = 3071 74 | for file in fkeys: 75 | print(file) 76 | with segyio.open(datapath + "/" + file, "r", ignore_geometry=True) as segy: 77 | fid.append([segy.header[trid][segyio.TraceField.FieldRecord] 78 | for trid in range(segy.tracecount)]) 79 | cid.append([segy.header[trid][segyio.TraceField.TraceNumber] 80 | for trid in range(segy.tracecount)]) 81 | data.append(np.transpose(np.array([segy.trace[trid] 82 | for trid in range(segy.tracecount)]))[:NT,:]) 83 | 84 | 85 | """ 86 | __________________Remove bad shots ______________________ 87 | """ 88 | #correct fid 89 | if len(fid) > 16: 90 | fid[16] = [id if id < 700 else id+200 for id in fid[16]] 91 | if len(fid) > 6: 92 | fid[6] = fid[6][:12180] 93 | cid[6] = cid[6][:12180] 94 | data[6] = data[6][:, :12180] 95 | if len(fid) > 7: 96 | fid[7] = fid[7][36:] 97 | cid[7] = cid[7][36:] 98 | data[7] = data[7][:, 36:] 99 | if len(fid) > 2: #repeated shots between files 03 and 04 100 | fid[2] = fid[2][:8872] 101 | cid[2] = cid[2][:8872] 102 | data[2] = data[2][:, :8872] 103 | fid = np.concatenate(fid) 104 | cid = np.concatenate(cid) 105 | data = 
np.concatenate(data, axis=1) 106 | 107 | #recnoSpn = InterpText() 108 | #recnoSpn.read('recnoSpn.txt') 109 | 110 | #recnoDelrt = InterpText() 111 | #recnoDelrt.read('recnoDelrt.txt') 112 | 113 | prev_fldr=-9999 114 | fldr_bias=0 115 | shot = 0 * cid -1 116 | delrt = 0 * cid -1 117 | 118 | notshots = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 211, 213, 225, 279, 119 | 335, 387, 400, 493, 528, 553, 561, 571, 120 | 668, 669, 698, 699, 700, 727, 728, 780, 816, 826, 1073, 1219, 121 | 1253, 1254, 1300, 1301, 1418, 1419, 1527, 1741, 2089, 2170, 122 | 2303, 2610, 2957, 2980, 3021, 3104, 3167, 3223, 3268, 3476, 123 | 3707, 3784, 3831, 3934, 4051, 4472, 4671, 4757, 4797] 124 | 125 | for ii in range(fid.shape[0]): 126 | fldr = fid[ii] 127 | tracf = cid[ii] 128 | 129 | if fldr < prev_fldr: 130 | fldr_bias += 1000 131 | 132 | prev_fldr = fldr 133 | 134 | fldr += fldr_bias 135 | if fldr not in notshots: 136 | shot[ii] = 6102 - fldr 137 | 138 | # The time 0 of different files changes. We prepad with zero so that all 139 | # shots begin at time 0 140 | if fldr < 15: 141 | delrt[ii] = 4000 142 | elif fldr < 20: 143 | delrt[ii] = 5000 144 | elif fldr < 1043: 145 | delrt[ii] = 4000 146 | elif fldr < 1841: 147 | delrt[ii] = 3000 148 | elif fldr < 2199: 149 | delrt[ii] = 2000 150 | elif fldr < 2472: 151 | delrt[ii] = 1000 152 | else: 153 | delrt[ii] = 0 154 | 155 | valid = shot > 0 156 | shot = shot[valid] 157 | delrt = delrt[valid] 158 | data = data[:, valid] 159 | 160 | 161 | plt.plot(shot) 162 | plt.show() 163 | 164 | dt = 4 # time step, milliseconds 165 | for ii in range(data.shape[1]): 166 | data[:, ii] = np.concatenate([np.zeros(int(delrt[ii]/dt)), data[:,ii]])[:NT] 167 | 168 | # Open the hdf5 file in which to save the pre-processed data 169 | savefile = h5.File("survey.hdf5", "w") 170 | savefile["data"] = data 171 | 172 | """ 173 | ________________________Trace interpolation____________________________ 174 | """ 175 | 176 | #From the observer log, we get the acquisition parameters: 177 | ds = 50 #shot point spacing 178 | dg1 = 100 #geophone spacing for channels 1-24 179 | dg2 = 50 #geophone spacing for channels 25-48 180 | vwater = 1533 181 | ns = int(data.shape[1]/48) 182 | ng = 72 183 | dg = 50 184 | nearoff = 470 #varies for several shots, we take the most common value 185 | 186 | data_i = np.zeros([data.shape[0], ns*ng]) 187 | t0off = 2*np.sqrt((nearoff / 2)**2 +3000**2)/vwater 188 | for ii in range(ns): 189 | data_i[:, ng*ii:ng*ii+23] = data[:, ii*48:ii*48+23] 190 | data_roll = data[:, ii*48+23:(ii+1) * 48] 191 | n = data_roll.shape[1] 192 | for jj in range(n): 193 | toff = 2 * np.sqrt(((nearoff + dg1 * (n - jj)) / 2) ** 2 + 3000 ** 2) / vwater - t0off 194 | data_roll[:, jj] = np.roll(data_roll[:, jj], -int(toff / 0.004)) 195 | data_roll = ndimage.zoom(data_roll, [1, 2], order=1) 196 | n = data_roll.shape[1] 197 | for jj in range(n): 198 | toff = 2 * np.sqrt( 199 | ((nearoff + dg2 * (n - jj)) / 2) ** 2 + 3000 ** 2) / vwater - t0off 200 | data_roll[:, jj] = np.roll(data_roll[:, jj], int(toff / 0.004)) 201 | data_i[:, ng * ii + 23:ng * (ii + 1)] = data_roll[:, :-1] 202 | 203 | savefile['data_i'] = data_i 204 | 205 | """ 206 | ________________________Resort accorging to CMP____________________________ 207 | """ 208 | ns = int(data_i.shape[1]/72) 209 | shots = np.arange(nearoff + ng*dg, nearoff + ng*dg + ns * ds, ds) 210 | recs = np.concatenate([np.arange(0, 0 + ng * dg, dg) + n*ds for n in range(ns)], axis=0) 211 | shots = np.repeat(shots, ng) 212 | cmps = ((shots + recs)/2 / 50).astype(int) * 
50 213 | offsets = shots - recs 214 | 215 | ind = np.lexsort((offsets, cmps)) 216 | cmps = cmps[ind] 217 | unique_cmps, counts = np.unique(cmps, return_counts=True) 218 | firstcmp = unique_cmps[np.argmax(counts == 72)] 219 | lastcmp = unique_cmps[-np.argmax(counts[::-1] == 72)-1] 220 | ind1 = np.argmax(cmps == firstcmp) 221 | ind2 = np.argmax(cmps > lastcmp) 222 | ntraces = cmps[ind1:ind2].shape[0] 223 | data_cmp = np.zeros([data_i.shape[0], ntraces]) 224 | 225 | n = 0 226 | for ii, jj in enumerate(ind): 227 | if ii >= ind1 and ii < ind2: 228 | data_cmp[:, n] = data_i[:, jj] 229 | n += 1 230 | 231 | savefile['data_cmp'] = data_cmp 232 | savefile.close() 233 | 234 | """ 235 | ________________________Plots for quality control___________________________ 236 | """ 237 | # Plot some CMP gather 238 | clip = 0.05 239 | vmax = np.max(data_cmp[:,0]) * clip 240 | vmin = -vmax 241 | plt.imshow(data_cmp[:, :200], 242 | interpolation='bilinear', 243 | cmap=plt.get_cmap('Greys'), 244 | vmin=vmin, vmax=vmax, 245 | aspect='auto') 246 | 247 | plt.show() 248 | 249 | # Constant offset plot 250 | clip = 0.05 251 | vmax = np.max(data_cmp[:,0]) * clip 252 | vmin = -vmax 253 | plt.imshow(data_cmp[:, ::72], 254 | interpolation='bilinear', 255 | cmap=plt.get_cmap('Greys'), 256 | vmin=vmin, vmax=vmax, 257 | aspect='auto') 258 | 259 | plt.show() 260 | 261 | -------------------------------------------------------------------------------- /reproduce_results.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | This scripts reproduce all results found in the article. 5 | """ 6 | 7 | import subprocess 8 | import argparse 9 | 10 | def callcmd(cmd): 11 | pipes = subprocess.Popen(cmd, 12 | stdout=subprocess.PIPE, 13 | stderr=subprocess.PIPE, 14 | shell=True) 15 | stdout, stderr = pipes.communicate() 16 | 17 | if __name__ == "__main__": 18 | 19 | # Initialize argument parser 20 | parser = argparse.ArgumentParser() 21 | 22 | # Add arguments to parse for training 23 | parser.add_argument("--run_commands", 24 | type=int, 25 | default=1, 26 | help="If 1, runs the commands, else just print them" 27 | ) 28 | 29 | # Parse the input for training parameters 30 | args, unparsed = parser.parse_known_args() 31 | 32 | # 1.0 Training dataset creation on n GPUS 33 | print("##Creating training dataset") 34 | cmd = "python Case_article.py --training=0 " 35 | print(cmd,"\n") 36 | if args.run_commands: 37 | callcmd(cmd) 38 | 39 | # 1.1 Perform the training for nmodel models 40 | print("##Training of the models") 41 | cmd= "python Case_article.py --training=1 --logdir=Case_article" 42 | print(cmd,"\n") 43 | if args.run_commands: 44 | callcmd(cmd) 45 | 46 | # 1.2 Create the 1D test dataset 47 | print("##Creating the 1D test dataset") 48 | cmd= "python Case_article_test1D.py --testing=0 " 49 | print(cmd,"\n") 50 | if args.run_commands: 51 | callcmd(cmd) 52 | 53 | # 1.3 Test on the 1D test dataset 54 | print("##Testing in 1D (Figure 3)") 55 | cmd = "python Case_article_test1D.py --testing=1" 56 | cmd+= " --logdir=Case_article*/4_* --niter=1000" 57 | print(cmd,"\n") 58 | if args.run_commands: 59 | callcmd(cmd) 60 | 61 | # 1.4 Test in 2D 62 | print("##Testing in 2D (Figure 4)") 63 | cmd = "python Case_article_test2D.py --testing=1" 64 | cmd+= " --logdir=Case_article*/4_* --niter=1000" 65 | print(cmd,"\n") 66 | if args.run_commands: 67 | callcmd(cmd) 68 | 69 | # 1.5 Test on real data 70 | print("##Testing on real data (Figure 5 and 6)") 71 | cmd = "cd 
realdata;"
72 |     cmd+= "python Process_realdata.py"
73 |     print(cmd,"\n")
74 |     if args.run_commands:
75 |         callcmd(cmd)
76 | 
77 |     cmd = "python Case_article_testreal.py --plots=2"
78 |     cmd+= " --logdir=Case_article*/4_* --niter=1000"
79 |     print(cmd,"\n")
80 |     if args.run_commands:
81 |         callcmd(cmd)
82 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu==1.15.5
2 | scipy==1.2.0
3 | hdf5storage==0.1.15
4 | matplotlib==3.0.2
5 | SeisCL
--------------------------------------------------------------------------------
/semblance/nmo_correction.py:
--------------------------------------------------------------------------------
 1 | """
 2 | A collection of seismic functions to compute semblance, NMO correction and
 3 | seismic velocities.
 4 | """
 5 | import numpy as np
 6 | from scipy.interpolate import CubicSpline
 7 | 
 8 | 
 9 | def stack(cmp, times, offsets, velocities):
10 |     """
11 |     Compute the stacked trace of a CMP gather
12 | 
13 |     @params:
14 |     cmp (numpy.ndarray) : CMP gather NT X Noffset
15 |     times (numpy.ndarray) : 1D array containing the time
16 |     offsets (numpy.ndarray): 1D array containing the offset of each trace
17 |     velocities (numpy.ndarray): 1D array NT containing the velocities
18 | 
19 |     @returns:
20 |     stacked (numpy.ndarray) : a numpy array NT long containing the stacked
21 |                               trace of the CMP
22 |     """
23 | 
24 |     return np.sum(nmo_correction(cmp, times, offsets, velocities), axis=1)
25 | 
26 | def semblance_gather(cmp, times, offsets, velocities):
27 |     """
28 |     Compute the semblance panel of a CMP gather
29 | 
30 |     @params:
31 |     cmp (numpy.ndarray) : CMP gather NT X Noffset
32 |     times (numpy.ndarray) : 1D array containing the time
33 |     offsets (numpy.ndarray): 1D array containing the offset of each trace
34 |     velocities (numpy.ndarray): 1D array containing the test Nv velocities
35 | 
36 |     @returns:
37 |     semb (numpy.ndarray) : numpy array NTxNv containing semblance
38 |     """
39 |     NT = cmp.shape[0]
40 |     semb = np.zeros([NT, len(velocities)])
41 |     for ii, vel in enumerate(velocities):
42 |         nmo = nmo_correction(cmp, times, offsets, np.ones(NT)*vel)
43 |         semb[:,ii] = semblance(nmo)
44 | 
45 |     return semb
46 | 
47 | 
48 | def nmo_correction(cmp, times, offsets, velocities, stretch_mute=None):
49 |     """
50 |     Compute the NMO corrected CMP gather
51 | 
52 |     @params:
53 |     cmp (numpy.ndarray) : CMP gather NT X Noffset
54 |     times (numpy.ndarray) : 1D array containing the time
55 |     offsets (numpy.ndarray): 1D array containing the offset of each trace
56 |     velocities (numpy.ndarray): 1D array containing the NT velocities in time
57 |     stretch_mute (float): optional stretch-mute ratio; samples stretched beyond it are muted
58 | 
59 |     @returns:
60 |     nmo (numpy.ndarray) : array NTxNoffset containing the NMO corrected CMP
61 |     """
62 | 
63 |     nmo = np.zeros_like(cmp)
64 |     for j, x in enumerate(offsets):
65 |         t = [reflection_time(t0, x, velocities[i]) for i, t0 in enumerate(times)]
66 |         interpolator = CubicSpline(times, cmp[:, j], extrapolate=False)
67 |         amps = np.nan_to_num(interpolator(t), copy=False)
68 |         nmo[:, j] = amps
69 |         if stretch_mute is not None:
70 |             nmo[np.abs((times-t)/(times+1e-10)) > stretch_mute, j] = 0
71 |     return nmo
72 | 
73 | 
74 | def reflection_time(t0, x, vnmo):
75 |     """
76 |     Compute the arrival time of a reflection
77 | 
78 |     @params:
79 |     t0 (float) : Two-way travel-time in seconds
80 |     x (float) : Offset in meters
81 |     vnmo (float): NMO velocity
82 | 
83 |     @returns:
84 |     t (float): Reflection travel time
85 |     """
86 | 
87 |     t = np.sqrt(t0**2 + x**2/vnmo**2)
88 |     return t
89 | 
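# Quick sanity check of the hyperbolic moveout formula above (illustrative
# values, not taken from the article): a reflection at t0 = 1.0 s recorded at
# offset x = 1000 m with vnmo = 2000 m/s arrives at
# reflection_time(1.0, 1000.0, 2000.0) = sqrt(1.0**2 + (1000.0/2000.0)**2) ~ 1.118 s.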
""" 86 | 87 | t = np.sqrt(t0**2 + x**2/vnmo**2) 88 | return t 89 | 90 | def semblance(nmo_corrected, window=10): 91 | """ 92 | Compute the semblance of a nmo corrected gather 93 | 94 | @params: 95 | nmo_corrected (numpy.ndarray) : NMO corrected CMP gather NT X Noffset 96 | window (int): Number of time samples to average 97 | 98 | @returns: 99 | semblance (numpy.ndarray): Array NTx1 containing semblance 100 | """ 101 | 102 | num = np.sum(nmo_corrected, axis=1) ** 2 103 | den = np.sum(nmo_corrected ** 2, axis=1) + 1e-12 104 | weights = np.ones(window) / window 105 | num = np.convolve(num, weights, mode='same') 106 | den = np.convolve(den, weights, mode='same') 107 | return num/den 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /vrmslearn/Inputqueue.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Input queue to feed example when training with Tensorflow 5 | """ 6 | 7 | from multiprocessing import Process, Queue, Event, Value 8 | import queue 9 | 10 | class Counter(object): 11 | def __init__(self): 12 | self.val = Value('i', 0) 13 | 14 | def increment(self, n=1): 15 | with self.val.get_lock(): 16 | self.val.value += n 17 | 18 | @property 19 | def value(self): 20 | return self.val.value 21 | 22 | class DataGenerator(Process): 23 | """ 24 | Created a new process that will generate new examples with the generator_fun 25 | and will put them in the data Queue data_q. 26 | """ 27 | def __init__(self, 28 | data_q: Queue, 29 | stop_event: Event, 30 | generator_fun=lambda x: 1): 31 | """ 32 | This is the constructor for the class. 33 | 34 | @params: 35 | data_q (Queue): A Queue in which to put generated examples 36 | stop_event (Event): An Event to kill the DataGenerator process 37 | generator_fun (callable): A callable that generate the data. This could 38 | be for example a function that reads a file 39 | 40 | @returns: 41 | """ 42 | super().__init__() 43 | self.done_q = data_q 44 | self.stop_event = stop_event 45 | self.generator_fun = generator_fun 46 | 47 | def run(self): 48 | """ 49 | When called, starts the DataGenerator process. Stop by setting the stop 50 | event. 51 | 52 | @returns: 53 | """ 54 | while not self.stop_event.is_set(): 55 | if not self.done_q.full(): 56 | try: 57 | batch = self.generator_fun() 58 | self.done_q.put(batch) 59 | except FileNotFoundError: 60 | pass 61 | 62 | class DataAggregator(Process): 63 | """ 64 | Create a new process that will aggregate examples created by DataGenerator 65 | processes into a batch of examples 66 | """ 67 | def __init__(self, 68 | data_q: Queue, 69 | batch_q: Queue, 70 | stop_event: Event, 71 | batch_size: int, 72 | n_in_queue: Counter=None, 73 | postprocess_fun = None): 74 | """ 75 | This is the constructor for the class. 
62 | class DataAggregator(Process):
63 |     """
64 |     Creates a new process that aggregates examples created by DataGenerator
65 |     processes into a batch of examples
66 |     """
67 |     def __init__(self,
68 |                  data_q: Queue,
69 |                  batch_q: Queue,
70 |                  stop_event: Event,
71 |                  batch_size: int,
72 |                  n_in_queue: Counter = None,
73 |                  postprocess_fun=None):
74 |         """
75 |         This is the constructor for the class.
76 |
77 |         @params:
78 |         data_q (Queue): A Queue from which to get new examples
79 |         batch_q (Queue): A Queue in which to put new batches of examples
80 |         stop_event (Event): An Event to kill the DataAggregator process
81 |         batch_size (int): The number of examples in a batch
82 |         n_in_queue (Counter): A counter shareable across processes counting
83 |                               the number of batches
84 |         postprocess_fun (callable): If given, applied to each completed batch
85 |
86 |         @returns:
87 |         """
88 |         super().__init__()
89 |         self.pending_q = data_q
90 |         self.done_q = batch_q
91 |         self.stop_event = stop_event
92 |         self.batch_size = batch_size
93 |         self.batch = []
94 |         if n_in_queue is None:
95 |             self.n_in_queue = Counter()
96 |         else:
97 |             self.n_in_queue = n_in_queue
98 |         self.nsaved = 0
99 |         self.postprocess_fun = postprocess_fun
100 |
101 |     def run(self):
102 |         """
103 |         When called, starts the DataAggregator process. Stop it by setting the
104 |         stop event.
105 |
106 |         @returns:
107 |         """
108 |         while not self.stop_event.is_set():
109 |             if not self.done_q.full():
110 |                 batch = self.pending_q.get()
111 |                 self.batch.append(batch)
112 |
113 |                 if len(self.batch) == self.batch_size:
114 |                     batch = self.batch
115 |                     if self.postprocess_fun is not None:
116 |                         batch = self.postprocess_fun(batch)
117 |                     self.done_q.put(batch)
118 |                     self.batch = []
119 |                     self.n_in_queue.increment()
120 |
121 | class BatchManager:
122 |     """
123 |     Creates an input queue for Tensorflow, managing example creation and
124 |     example aggregation on multiple processes.
125 |     """
126 |     def __init__(self,
127 |                  MAX_CAPACITY: int = 10,
128 |                  batch_size: int = 3,
129 |                  generator_fun=[lambda: 1],
130 |                  postprocess_fun=None,
131 |                  timeout: int = 360):
132 |         """
133 |         Creates the DataGenerator and DataAggregator processes and starts them.
134 |         Use within a with statement, as it will close the processes
135 |         automatically.
136 |
137 |         @params:
138 |         MAX_CAPACITY (int): Maximum number of batches or examples in the
139 |                             DataGenerator and DataAggregator queues
140 |         batch_size (int): The number of examples in a batch
141 |         generator_fun (list): List of callables that generate an example. One
142 |                               DataGenerator process per element in the list
143 |                               will be created.
144 |         timeout (int): Maximum time in seconds to retrieve a batch. Defaults to 360 s; increase it if generating a batch takes longer.
145 |
146 |         @returns:
147 |         """
148 |         self.timeout = timeout
149 |         self.generator_fun = generator_fun
150 |         self.MAX_CAPACITY = MAX_CAPACITY
151 |         self.batch_size = batch_size
152 |         self.postprocess_fun = postprocess_fun
153 |         self.stop_event = None
154 |         self.data_q = None
155 |         self.batch_q = None
156 |         self.n_in_queue = None
157 |         self.data_aggregator = None
158 |         self.data_generators = None
159 |
160 |         self.init()
161 |
162 |     def init(self):
163 |         self.stop_event = Event()
164 |         self.data_q = Queue(self.MAX_CAPACITY)
165 |         self.batch_q = Queue(self.MAX_CAPACITY)
166 |         self.n_in_queue = Counter()
167 |         self.data_aggregator = DataAggregator(self.data_q,
168 |                                               self.batch_q,
169 |                                               self.stop_event,
170 |                                               self.batch_size,
171 |                                               n_in_queue=self.n_in_queue,
172 |                                               postprocess_fun=self.postprocess_fun)
173 |
174 |         self.data_generators = [DataGenerator(self.data_q,
175 |                                               self.stop_event,
176 |                                               generator_fun=self.generator_fun[ii])
177 |                                 for ii in range(len(self.generator_fun))]
178 |
179 |         for w in self.data_generators:
180 |             w.start()
181 |         self.data_aggregator.start()
182 |
183 |     def next_batch(self):
184 |         """
185 |         Output the next batch of examples in the queue
186 |
187 |         @returns:
188 |         """
189 |         batch = None
190 |         while batch is None:
191 |             try:
192 |                 self.n_in_queue.increment(-1)
193 |                 batch = self.batch_q.get(timeout=self.timeout)
194 |             except queue.Empty:
195 |                 print("Restarting data_generators")
196 |                 self.close()
197 |                 self.init()
198 |
199 |         return batch
200 |
201 |     def put_batch(self, batch):
202 |         """
203 |         Put back a batch of examples in the queue
204 |
205 |         @returns:
206 |         """
207 |         if not self.batch_q.full():
208 |             self.batch_q.put(batch)
209 |             self.n_in_queue.increment(1)
210 |
211 |     def close(self, timeout: int = 5):
212 |         """
213 |         Terminate the running processes
214 |
215 |         @returns:
216 |         """
217 |         self.stop_event.set()
218 |
219 |         for w in self.data_generators:
220 |             w.join(timeout=timeout)
221 |             while w.is_alive():
222 |                 w.terminate()
223 |         self.data_aggregator.join(timeout=timeout)
224 |         while self.data_aggregator.is_alive():
225 |             self.data_aggregator.terminate()
226 |
227 |     def __enter__(self):
228 |         return self
229 |
230 |     def __exit__(self, exc_type, exc_value, traceback):
231 |         self.close()
232 |
233 |
234 |
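A minimal sketch of the BatchManager class above in use (illustrative; the trivial lambdas stand in for real example readers, and a fork-based start method is assumed):

with BatchManager(batch_size=2,
                  generator_fun=[lambda: 1, lambda: 2]) as bm:
    batch = bm.next_batch()  # a list of 2 generated examples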
--------------------------------------------------------------------------------
/vrmslearn/ModelGenerator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Class to generate seismic models and labels for training.
5 | """
6 |
7 | import numpy as np
8 | import copy
9 | from scipy.signal import gaussian
10 | from scipy.interpolate import interp1d
11 | import argparse
12 | from vrmslearn.ModelParameters import ModelParameters
13 |
14 |
15 | class ModelGenerator(object):
16 |     """
17 |     Generate a seismic model with the generate_model method and output the
18 |     labels with generate_labels. As of now, this class generates a 1D layered
19 |     model, and the labels correspond to the rms velocity.
20 |     """
21 |
22 |     def __init__(self, model_parameters=ModelParameters()):
23 |         """
24 |         This is the constructor for the class.
25 |
26 |         @params:
27 |         model_parameters (ModelParameters) : A ModelParameters object
28 |
29 |         @returns:
30 |         """
31 |         self.pars = model_parameters
32 |         self.vp = None
33 |
34 |     def generate_model(self):
35 |         """
36 |         Output the media parameters required for seismic modelling, in this
37 |         case vp, vs and rho. To create a 1D model, set pars.flat to True. For
38 |         2D dipping layer models, set it to False.
39 |
40 |         @params:
41 |
42 |         @returns:
43 |         vp (numpy.ndarray) : numpy array (self.pars.NZ, self.pars.NX) for vp.
44 |         vs (numpy.ndarray) : numpy array (self.pars.NZ, self.pars.NX) for vs.
45 |         rho (numpy.ndarray) : numpy array (self.pars.NZ, self.pars.NX) for rho
46 |                               values.
47 |         """
48 |         if self.pars.flat:
49 |             vp, vs, rho = generate_random_1Dlayered(self.pars)
50 |         else:
51 |             vp, vs, rho, _, _, _ = generate_random_2Dlayered(self.pars)
52 |
53 |         self.vp = copy.copy(vp)
54 |         return vp, vs, rho
55 |
56 |     def generate_labels(self):
57 |         """
58 |         Output the labels attached to the modelling of a particular dataset.
59 |         In this case, we want to predict vrms from a CMP gather.
60 |
61 |         @params:
62 |
63 |         @returns:
64 |         vrms (numpy.ndarray) : numpy array of shape (self.pars.NT, ) with vrms
65 |                                values in meters/sec.
66 |         valid (numpy.ndarray) : numpy array of shape (self.pars.NT, ) with 1
67 |                                 before the last reflection, 0 afterwards
68 |         refs (numpy.ndarray) : Two-way travel-times of the reflections
69 |         """
70 |         vp = self.vp[:, 0]
71 |         vrms = calculate_vrms(vp,
72 |                               self.pars.dh,
73 |                               self.pars.Npad,
74 |                               self.pars.NT,
75 |                               self.pars.dt,
76 |                               self.pars.tdelay,
77 |                               self.pars.source_depth)
78 |         refs = generate_reflections_ttime(vp, self.pars)
79 |
80 |         # Normalize so the labels are between 0 and 1
81 |         vrms = (vrms - self.pars.vp_min) / (self.pars.vp_max - self.pars.vp_min)
82 |         indt = np.argwhere(refs > 0.1).flatten()[-1]
83 |         valid = np.ones(len(vrms))
84 |         valid[indt:] = 0
85 |
86 |         return vrms, valid, refs
87 |
88 |
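A short sketch of the intended calling sequence (this mirrors how SeismicGenerator uses the class: the model must be generated before the labels, because generate_labels reads self.vp):

gen = ModelGenerator(ModelParameters())
vp, vs, rho = gen.generate_model()
vrms, valid, refs = gen.generate_labels()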
89 | def calculate_vrms(vp, dh, Npad, NT, dt, tdelay, source_depth):
90 |     """
91 |     This method takes vp as input and outputs the vrms. The arguments define
92 |     the depth spacing, the source and receiver depths, etc. This method
93 |     assumes that the source and receiver depths are the same.
94 |
95 |     The convention used is that the velocity denoted by the interval
96 |     (i, i+1) grid points is given by the constant vp[i+1].
97 |
98 |     @params:
99 |     vp (numpy.ndarray) : 1D vp values in meters/sec.
100 |     dh (float) : the spatial grid size
101 |     Npad (int) : Number of absorbing padding grid points over the source
102 |     NT (int) : Number of time steps of output
103 |     dt (float) : Time step of the output
104 |     tdelay (float): Time before source peak
105 |     source_depth (float) : The source depth in meters
106 |
107 |
108 |     @returns:
109 |     vrms (numpy.ndarray) : numpy array of shape (NT, ) with vrms
110 |                            values in meters/sec.
111 |     """
112 |
113 |     NZ = vp.shape[0]
114 |
115 |     # Create a numpy array of depths corresponding to the vp grid locations
116 |     depth = np.zeros((NZ,))
117 |     for i in range(NZ):
118 |         depth[i] = i * dh
119 |
120 |     # Create a list of tuples of (relative depths, velocity) of the layers
121 |     # following the depth of the source / receiver depths, till the last layer
122 |     # before the padding zone at the bottom
123 |     last_depth = dh * (NZ - Npad - 1)
124 |     rdepth_vel_pairs = [(d - source_depth, vp[i]) for i, d in enumerate(depth)
125 |                         if d > source_depth and d <= last_depth]
126 |     first_layer_vel = rdepth_vel_pairs[0][1]
127 |     rdepth_vel_pairs.insert(0, (0.0, first_layer_vel))
128 |
129 |     # Calculate a list of two-way travel times
130 |     t = [2.0 * (rdepth_vel_pairs[index][0] - rdepth_vel_pairs[index - 1][
131 |         0]) / vel
132 |          for index, (_, vel) in enumerate(rdepth_vel_pairs) if index > 0]
133 |     t.insert(0, 0.0)
134 |     total_time = 0.0
135 |     for i, time in enumerate(t):
136 |         total_time += time
137 |         t[i] = total_time
138 |
139 |     # The last time must be 'dt' * 'NT', so adjust the lists 'rdepth_vel_pairs'
140 |     # and 't' by cropping and adjusting the last sample accordingly
141 |     rdepth_vel_pairs = [(rdepth_vel_pairs[i][0], rdepth_vel_pairs[i][1]) for
142 |                         i, time in enumerate(t)
143 |                         if time <= NT * dt]
144 |     t = [time for time in t if time <= NT * dt]
145 |     last_index = len(t) - 1
146 |     extra_distance = (NT * dt - t[last_index]) * rdepth_vel_pairs[last_index][
147 |         1] / 2.0
148 |     rdepth_vel_pairs[last_index] = (
149 |         extra_distance + rdepth_vel_pairs[last_index][0],
150 |         rdepth_vel_pairs[last_index][1])
151 |     t[last_index] = NT * dt
152 |
153 |     # Compute vrms at the times in t
154 |     vrms = [first_layer_vel]
155 |     sum_numerator = 0.0
156 |     for i in range(1, len(t)):
157 |         sum_numerator += (t[i] - t[i - 1]) * rdepth_vel_pairs[i][1] * \
158 |                          rdepth_vel_pairs[i][1]
159 |         vrms.append((sum_numerator / t[i]) ** 0.5)
160 |
161 |     # Interpolate vrms to a uniform time grid
162 |     tgrid = np.asarray(range(0, NT)) * dt
163 |     vrms = np.interp(tgrid, t, vrms)
164 |     vrms = np.reshape(vrms, [-1])
165 |     # Adjust for the time delay
166 |     t0 = int(tdelay / dt)
167 |     vrms[t0:] = vrms[:-t0]
168 |     vrms[:t0] = vrms[t0]
169 |
170 |     # Return vrms
171 |     return vrms
172 |
173 |
174 | def generate_random_1Dlayered(pars, seed=None):
175 |     if seed is not None:
176 |         np.random.seed(seed)
177 |
178 |     if pars.num_layers == 0:
179 |         nmin = pars.layer_dh_min
180 |         nmax = int(pars.NZ / pars.layer_num_min)
181 |         n_layers = np.random.choice(range(pars.layer_num_min, int(pars.NZ / nmin)))
182 |     else:
183 |         nmin = pars.layer_dh_min
184 |         nmax = int(pars.NZ / pars.layer_num_min)
185 |         n_layers = int(np.clip(pars.num_layers, nmin, nmax))
186 |
187 |     NZ = pars.NZ
188 |     NX = pars.NX
189 |     dh = pars.dh
190 |     top_min = int(pars.source_depth / dh + 2 * pars.layer_dh_min)
191 |     layers = (nmin + np.random.rand(n_layers) * (nmax - nmin)).astype(np.int)
192 |     tops = np.cumsum(layers)
193 |     ntos = np.sum(layers[tops <= top_min])
194 |     if ntos > 0:
195 |         layers = np.concatenate([[ntos], layers[tops > top_min]])
196 |     vels = (pars.vp_min
197 |             + np.random.rand() * (pars.vp_max - pars.vp_min - pars.dvmax)
198 |             + np.random.rand(len(layers)) * pars.dvmax)
199 |     ramp = np.abs(np.max(vels) - pars.vp_max) * np.random.rand() + 0.1
200 |     vels = vels + np.linspace(0, ramp, vels.shape[0])
201 |     vels[vels > pars.vp_max] = pars.vp_max
202 |     vels[vels < pars.vp_min] = pars.vp_min
203 |     if pars.marine:
204 |         vels[0] = pars.velwater + (np.random.rand() - 0.5) * 2 * pars.d_velwater
205 |         layers[0] = int(pars.water_depth / pars.dh + (
206 |             np.random.rand() - 0.5) * 2 * pars.dwater_depth / pars.dh)
207 |
208 |     vel1d = np.concatenate([np.ones(layers[n]) * vels[n]
209 |                             for n in range(len(layers))])
210 |     if len(vel1d) < NZ:
211 |         vel1d = np.concatenate([vel1d, np.ones(NZ - len(vel1d)) * vel1d[-1]])
212 |     elif len(vel1d) > NZ:
213 |         vel1d = vel1d[:NZ]
214 |
215 |     if pars.rho_var:
216 |         rhos = (pars.rho_min
217 |                 + np.random.rand() * (
218 |                     pars.rho_max - pars.rho_min - pars.drhomax)
219 |                 + np.random.rand(len(layers)) * pars.drhomax)
220 |         ramp = np.abs(np.max(rhos) - pars.rho_max) * np.random.rand() + 0.1
221 |         rhos = rhos + np.linspace(0, ramp, rhos.shape[0])
222 |         rhos[rhos > pars.rho_max] = pars.rho_max
223 |         rhos[rhos < pars.rho_min] = pars.rho_min
224 |         rho1d = np.concatenate([np.ones(layers[n]) * rhos[n]
225 |                                 for n in range(len(layers))])
226 |         if len(rho1d) < NZ:
227 |             rho1d = np.concatenate(
228 |                 [rho1d, np.ones(NZ - len(rho1d)) * rho1d[-1]])
229 |         elif len(rho1d) > NZ:
230 |             rho1d = rho1d[:NZ]
231 |     else:
232 |         rho1d = vel1d * 0 + pars.rho_default
233 |
234 |     vp = np.transpose(np.tile(vel1d, [NX, 1]))
235 |     vs = vp * 0
236 |     rho = np.transpose(np.tile(rho1d, [NX, 1]))
237 |
238 |     return vp, vs, rho
239 |
240 |
241 | def texture_1lay(NZ, NX, lz=2, lx=2):
242 |     """
243 |     Creates a random model with bandwidth-limited noise.
244 |
245 |     @params:
246 |     NZ (int): Number of cells in Z
247 |     NX (int): Number of cells in X
248 |     lz (int): High frequency cut-off size in z
249 |     lx (int): High frequency cut-off size in x
250 |     @returns:
251 |
252 |     """
253 |
254 |     noise = np.fft.fft2(np.random.random([NZ, NX]))
255 |     noise[0, :] = 0
256 |     noise[:, 0] = 0
257 |     noise[-1, :] = 0
258 |     noise[:, -1] = 0
259 |
260 |     iz = lz
261 |     ix = lx
262 |     maskz = gaussian(NZ, iz)
263 |     maskz = np.roll(maskz, [int(NZ / 2), 0])
264 |     maskx = gaussian(NX, ix)
265 |     maskx = np.roll(maskx, [int(NX / 2), 0])
266 |     noise = noise * np.reshape(maskz, [-1, 1])
267 |     noise *= maskx
268 |     noise = np.real(np.fft.ifft2(noise))
269 |     noise = noise / np.max(noise)
270 |
271 |     return noise
272 |
273 |
274 | def generate_reflections_ttime(vp,
275 |                                pars,
276 |                                tol=0.015,
277 |                                window_width=0.45):
278 |     """
279 |     Output the reflection travel time at the minimum offset of a CMP gather
280 |
281 |     @params:
282 |     vp (numpy.ndarray) : A 1D array containing the Vp profile in depth
283 |     pars (ModelParameters): Parameters used to generate the model
284 |     tol (float): The minimum relative velocity change to consider a reflection
285 |     window_width (float): Time window width as a fraction of 1/pars.peak_freq
286 |
287 |     @returns:
288 |
289 |     tlabel (numpy.ndarray) : A 1D array with pars.NT elements, with 1 at
290 |                              reflection times +- window_width/pars.peak_freq, 0 elsewhere
291 |     """
292 |
293 |     vp = vp[int(pars.source_depth / pars.dh):]
294 |     vlast = vp[0]
295 |     ind = []
296 |     for ii, v in enumerate(vp):
297 |         if np.abs((v - vlast) / vlast) > tol:
298 |             ind.append(ii - 1)
299 |             vlast = v
300 |
301 |     if pars.minoffset != 0:
302 |         dt = 2.0 * pars.dh / vp
303 |         t0 = np.cumsum(dt)
304 |         vrms = np.sqrt(np.cumsum(vp ** 2 * dt) / t0)  # rms velocity: sqrt(sum(v^2 dt) / t0)
305 |         tref = np.sqrt(
306 |             t0[ind] ** 2 + pars.minoffset ** 2 / vrms[ind] ** 2) + pars.tdelay
307 |     else:
308 |         ttime = 2 * np.cumsum(pars.dh / vp) + pars.tdelay
309 |         tref = ttime[ind]
310 |
311 |     if pars.identify_direct:
312 |         dt = 0
313 |         if pars.minoffset != 0:
314 |             dt = pars.minoffset / vp[0]
315 |         tref = np.insert(tref, 0, pars.tdelay + dt)
316 |
317 |     tlabel = np.zeros(pars.NT)
318 |     for t in tref:
319 |         imin = int(t / pars.dt - window_width / pars.peak_freq / pars.dt)
320 |         imax = int(t / pars.dt + window_width / pars.peak_freq / pars.dt)
321 |         if imin <= pars.NT and imax <= pars.NT:
322 |             tlabel[imin:imax] = 1
323 |
324 |     return tlabel
325 |
326 |
327 | def two_way_travel_time(vp, pars):
328 |     """
329 |     Output the two-way travel-time for each cell in vp
330 |
331 |     @params:
332 |     vp (numpy.ndarray) : A 1D array containing the Vp profile in depth
333 |     pars (ModelParameters): Parameters used to generate the model
334 |
335 |     @returns:
336 |
337 |     vp (numpy.ndarray) : A 1D array containing the Vp profile in depth, cut to
338 |                          have the same size as t
339 |     t (numpy.ndarray) : The two-way travel time of each cell
340 |
341 |     """
342 |     vpt = vp[int(pars.source_depth / pars.dh):]
343 |     t = 2 * np.cumsum(pars.dh / vpt) + pars.tdelay
344 |     t = t[t < pars.NT * pars.dt]
345 |     vpt = vpt[:len(t)]
346 |
347 |     return vpt, t
348 |
349 |
350 | def interval_velocity_time(vp, pars):
351 |     """
352 |     Output the interval velocity in time
353 |
354 |     @params:
355 |     vp (numpy.ndarray) : A 1D array containing the Vp profile in depth
356 |     pars (ModelParameters): Parameters used to generate the model
357 |
358 |     @returns:
359 |
360 |     vint (numpy.ndarray) : The interval velocity in time
361 |
362 |     """
363 |     vpt, t = two_way_travel_time(vp, pars)
364 |     interpolator = interp1d(t, vpt,
365 |                             bounds_error=False,
366 |                             fill_value="extrapolate",
367 |                             kind="nearest")
368 |     vint = interpolator(np.arange(0, pars.NT, 1) * pars.dt)
369 |
370 |     return vint
371 |
372 |
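A small usage sketch for the two helpers above (values are illustrative and the default ModelParameters are assumed): convert a two-layer depth profile to the interval velocity in time.

pars = ModelParameters()
vp_depth = np.full(pars.NZ, 1500.0)   # slow top layer
vp_depth[pars.NZ // 2:] = 3000.0      # faster half-space below
vint = interval_velocity_time(vp_depth, pars)
# vint[k] is the interval velocity at two-way time k * pars.dt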
373 | def generate_random_2Dlayered(pars, seed=None):
374 |     """
375 |     This method generates a random 2D model, with parameters given in pars.
376 |     Important parameters are:
377 |     Model size:
378 |         -pars.NX : Number of grid cells in X
379 |         -pars.NZ : Number of grid cells in Z
380 |         -pars.dh : Cell size in meters
381 |
382 |     Number of layers:
383 |         -pars.num_layers : Minimum number of layers contained in the model
384 |         -pars.layer_dh_min : Minimum thickness of a layer (in grid cells)
385 |         -pars.source_depth: Depth in meters of the source. Velocity above the
386 |                             source is kept constant.
387 |
388 |     Layers dip:
389 |         -pars.angle_max: Maximum dip of a layer in degrees
390 |         -pars.dangle_max: Maximum dip difference between adjacent layers
391 |
392 |     Model velocity:
393 |         -pars.vp_max: Maximum Vp velocity
394 |         -pars.vp_min: Minimum Vp velocity
395 |         -pars.dvmax: Maximum velocity difference of two adjacent layers
396 |
397 |     Marine survey parameters:
398 |         -pars.marine: If True, the first layer is water
399 |         -pars.velwater: water velocity
400 |         -pars.d_velwater: variance of the water velocity
401 |         -pars.water_depth: Mean water depth
402 |         -pars.dwater_depth: variance of the water depth
403 |
404 |     Non-planar layers:
405 |         -pars.max_osci_freq: Maximum spatial frequency (1/m) of a layer interface
406 |         -pars.min_osci_freq: Minimum spatial frequency (1/m) of a layer interface
407 |         -pars.amp_max: Maximum amplitude of the undulation of the layer interface
408 |         -pars.max_osci_nfreq: Maximum number of frequencies of the interface
409 |
410 |     Texture in layers:
411 |         -pars.texture_zrange
412 |         -pars.texture_xrange
413 |         -pars.max_texture
414 |
415 |     @params:
416 |     pars (ModelParameters) : A ModelParameters object containing the
417 |                              parameters for model creation.
418 |     seed (int) : The seed for the random number generator
419 |
420 |     @returns:
421 |     vp, vs, rho, vels, layers, angles
422 |     vp (numpy.ndarray) : An array containing the vp model
423 |     vs (numpy.ndarray) : An array containing the vs model (0 for the moment)
424 |     rho (numpy.ndarray) : An array containing the density model
425 |                           (2000 for the moment)
426 |     vels (numpy.ndarray) : 1D array containing the mean velocity of each layer
427 |     layers (numpy.ndarray) : 1D array containing the mean thickness of each
428 |                              layer, at the center of the model
429 |     angles (numpy.ndarray) : 1D array containing the slope of each layer
430 |     """
431 |
432 |     if seed is not None:
433 |         np.random.seed(seed)
434 |
435 |     # Determine the minimum and maximum number of layers
436 |     if pars.num_layers == 0:
437 |         nmin = pars.layer_dh_min
438 |         nmax = int(pars.NZ / pars.layer_num_min)
439 |         if nmin < nmax:
440 |             n_layers = np.random.choice(range(nmin, nmax))
441 |         else:
442 |             n_layers = nmin
443 |     else:
444 |         nmin = pars.layer_dh_min
445 |         nmax = int(pars.NZ / pars.layer_num_min)
446 |         n_layers = int(np.clip(pars.num_layers, nmin, nmax))
447 |
448 |     # Generate a random number of layers with random thicknesses
449 |     NZ = pars.NZ
450 |     NX = pars.NX
451 |     dh = pars.dh
452 |     top_min = int(pars.source_depth / dh + 2 * pars.layer_dh_min)
453 |     layers = (nmin + np.random.rand(n_layers) * (nmax - nmin)).astype(np.int)
454 |     tops = np.cumsum(layers)
455 |     ntos = np.sum(layers[tops <= top_min])
456 |     if ntos > 0:
457 |         layers = np.concatenate([[ntos], layers[tops > top_min]])
458 |
459 |     # Generate random angles for each layer
460 |     n_angles = len(layers)
461 |     angles = np.zeros(layers.shape)
462 |     angles[1] = -pars.angle_max + np.random.rand() * 2 * pars.angle_max
463 |     for ii in range(2, n_angles):
464 |         angles[ii] = angles[ii - 1] + (
465 |             2.0 * np.random.rand() - 1.0) * pars.dangle_max
466 |         if np.abs(angles[ii]) > pars.angle_max:
467 |             angles[ii] = np.sign(angles[ii]) * pars.angle_max
468 |
469 |     # Generate a random velocity for each layer. Velocities are somewhat biased
470 |     # to increase in depth
471 |     vels = (pars.vp_min
472 |             + np.random.rand() * (pars.vp_max - pars.vp_min - pars.dvmax)
473 |             + np.random.rand(len(layers)) * pars.dvmax)
474 |     ramp = np.abs(np.max(vels) - pars.vp_max) * np.random.rand() + 0.1
475 |     vels = vels + np.linspace(0, ramp, vels.shape[0])
476 |     vels[vels > pars.vp_max] = pars.vp_max
477 |     vels[vels < pars.vp_min] = pars.vp_min
478 |     if pars.marine:
479 |         vels[0] = pars.velwater + (np.random.rand() - 0.5) * 2 * pars.d_velwater
480 |         layers[0] = int(pars.water_depth / pars.dh +
481 |                         (
482 |                             np.random.rand() - 0.5) * 2 * pars.dwater_depth / pars.dh)
483 |
484 |     # Generate the 2D model, from top layers to bottom
485 |     vel2d = np.zeros([NZ, NX]) + vels[0]
486 |     tops = np.cumsum(layers)
487 |     osci = create_oscillation(pars.max_osci_freq,
488 |                               pars.min_osci_freq,
489 |                               pars.amp_max,
490 |                               pars.max_osci_nfreq, NX)
491 |     texture = texture_1lay(2 * NZ,
492 |                            NX,
493 |                            lz=pars.texture_zrange,
494 |                            lx=pars.texture_xrange)
495 |     for ii in range(0, len(layers) - 1):
496 |         if np.random.rand() < pars.prob_osci_change:
497 |             osci += create_oscillation(pars.max_osci_freq,
498 |                                        pars.min_osci_freq,
499 |                                        pars.amp_max,
500 |                                        pars.max_osci_nfreq, NX)
501 |
502 |         texture = texture / np.max(texture) * (
503 |             np.random.rand() + 0.001) * pars.max_texture * vels[ii + 1]
504 |         for jj in range(0, NX):
505 |             # depth of the layer at location x
506 |             dz = int((np.tan(angles[ii + 1] / 360 * 2 * np.pi) * (
507 |                 jj - NX / 2) * dh) / dh)
508 |             # add the oscillation component
509 |             if pars.amp_max > 0:
510 |                 dz = int(dz + osci[jj])
511 |             # Check if the interface is inside the model
512 |             if 0 < tops[ii] + dz < NZ:
513 |                 vel2d[tops[ii] + dz:, jj] = vels[ii + 1]
514 |                 if not (pars.marine and ii == 0) and pars.max_texture > 0:
515 |                     vel2d[tops[ii] + dz:, jj] += texture[tops[ii]:NZ - dz, jj]
516 |             elif tops[ii] + dz <= 0:
517 |                 vel2d[:, jj] = vels[ii + 1]
518 |                 if not (pars.marine and ii == 0) and pars.max_texture > 0:
519 |                     vel2d[:, jj] += texture[:, jj]
520 |
521 |     # Output the 2D model
522 |     vel2d[vel2d > pars.vp_max] = pars.vp_max
523 |     vel2d[vel2d < pars.vp_min] = pars.vp_min
524 |     vp = vel2d
525 |     vs = vp * 0
526 |     rho = vp * 0 + 2000
527 |
528 |     return vp, vs, rho, vels, layers, angles
529 |
530 |
531 | def create_oscillation(max_osci_freq, min_osci_freq,
532 |                        amp_max, max_osci_nfreq, Nmax):
533 |     nfreqs = np.random.randint(max_osci_nfreq)
534 |     freqs = np.random.rand(nfreqs) * (
535 |         max_osci_freq - min_osci_freq) + min_osci_freq
536 |     phases = np.random.rand(nfreqs) * np.pi * 2
537 |     amps = np.random.rand(nfreqs)
538 |     x = np.arange(0, Nmax)
539 |     osci = np.zeros(Nmax)
540 |     for ii in range(nfreqs):
541 |         osci += amps[ii] * np.sin(freqs[ii] * x + phases[ii])
542 |
543 |     dosci = np.max(osci)
544 |     if dosci > 0:
545 |         osci = osci / dosci * amp_max * np.random.rand()
546 |
547 |     return osci
548 |
549 |
550 | if __name__ == "__main__":
551 |     import matplotlib.pyplot as plt
552 |
553 |     # Initialize argument parser
554 |     parser = argparse.ArgumentParser()
555 |
556 |     parser.add_argument(
557 |         "--ND",
558 |         type=int,
559 |         default=1,
560 |         help="Dimension of the model to display"
561 |     )
562 |     # Parse the input arguments
563 |     args, unparsed = parser.parse_known_args()
564 |
565 |     pars = ModelParameters()
566 |     pars.layer_dh_min = 20
567 |     pars.num_layers = 0
568 |     if args.ND == 1:
569 |         vp, vs, rho = generate_random_1Dlayered(pars)
570 |         vp = vp[:, 0]
571 |         vint = interval_velocity_time(vp, pars)
572 |         vrms = calculate_vrms(vp,
573 |                               pars.dh,
574 |                               pars.Npad,
575 |                               pars.NT,
576 |                               pars.dt,
577 |                               pars.tdelay,
578 |                               pars.source_depth)
579 |
580 |         plt.plot(vint)
581 |         plt.plot(vrms)
582 |         plt.show()
583 |     else:
584 |         vp, vs, rho, _, _, _ = generate_random_2Dlayered(pars)
585 |         plt.imshow(vp)
586 |         plt.show()
587 |
--------------------------------------------------------------------------------
/vrmslearn/ModelParameters.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import h5py as h5
4 |
5 | class ModelParameters(object):
6 |     """
7 |     This class contains all the model parameters needed to generate random
8 |     models and seismic data
9 |     """
10 |
11 |     def __init__(self):
12 |
13 |
14 |         self.NX = 256  # number of grid cells in X direction
15 |         self.NZ = 256  # number of grid cells in Z direction
16 |         self.dh = 10.0  # grid spacing in X, Y, Z directions (in meters)
17 |         self.fs = False  # whether a free surface is applied at the top face
18 |         self.Npad = 16  # number of padding cells of the absorbing boundary
19 |         self.NT = 2048  # number of time steps
20 |         self.dt = 0.0009  # time sampling for the seismogram (in seconds)
21 |         self.peak_freq = 10.0  # peak frequency of the input wavelet (in Hertz)
22 |         self.wavefuns = [1]  # Source wave function selection (see SeismicGenerator)
23 |         self.df = 2  # Frequency of the source: peak_freq +- random(df)
24 |         self.tdelay = 2.0 / (self.peak_freq - self.df)  # delay of the source
25 |         self.resampling = 10  # Resampling of the shots time axis
26 |         self.source_depth = (self.Npad + 2) * self.dh  # depth of sources (m)
27 |         self.receiver_depth = (self.Npad + 2) * self.dh  # depth of receivers (m)
28 |         self.dg = 2  # Receiver interval in grid points
29 |         self.ds = 2  # Source interval (in 2D)
30 |         self.gmin = None  # Minimum position of receivers (-1 = minimum of grid)
31 |         self.gmax = None  # Maximum position of receivers (-1 = maximum of grid)
32 |         self.minoffset = 0
33 |         self.sourcetype = 100  # integer used by SeisCL for a pressure source
34 |
35 |         self.mute_dir = False
36 |         self.mask_firstvel = False
37 |         self.random_static = False
38 |         self.random_static_max = 2
39 |         self.random_noise = False
40 |         self.random_noise_max = 0.1
41 |         self.mute_nearoffset = False
42 |         self.mute_nearoffset_max = 10
43 |         self.random_time_scaling = False
44 |
45 |         self.vp_default = 1500.0  # default value of vp (in m/s)
46 |         self.vs_default = 0.0  # default value of vs (in m/s)
47 |         self.rho_default = 2000  # default value of rho (in kg/m3)
48 |         self.vp_min = 1000.0  # minimum value of vp (in m/s)
49 |         self.vp_max = 5000.0  # maximum value of vp (in m/s)
50 |         self.dvmax = 2000  # Maximum velocity difference between 2 layers
51 |         self.marine = False  # if true, the first layer will be at water velocity
52 |         self.velwater = 1500  # mean velocity of water
53 |         self.d_velwater = 30  # amplitude of random variations of the water velocity
54 |         self.water_depth = 3000  # mean water depth (m)
55 |         self.dwater_depth = 2000  # maximum amplitude of water depth variations
56 |
57 |         self.identify_direct = True  # The direct arrival is contained in labels
58 |
59 |         self.rho_var = False
60 |         self.rho_min = 2000.0  # minimum value of rho
61 |         self.rho_max = 3500.0  # maximum value of rho
62 |         self.drhomax = 800  # Maximum density difference between 2 layers
63 |
64 |
65 |         self.layer_dh_min = 50  # minimum thickness of a layer (in grid cells)
66 |         self.layer_num_min = 5  # minimum number of layers
67 |         self.num_layers = 10  # Fixes the number of layers if not 0
68 |
69 |         self.flat = True  # True: 1D, False: 2D model
70 |         self.ngathers = 1  # Number of gathers to train on
71 |         self.train_on_shots = False  # Train on True: shot gathers, False: CMPs
72 |
73 |         self.angle_max = 15  # Maximum dip of a layer
74 |         self.dangle_max = 5  # Maximum dip difference between two adjacent layers
75 |         self.max_osci_freq = 0  # Max frequency of the layer boundary function
76 |         self.min_osci_freq = 0  # Min frequency of the layer boundary function
77 |         self.amp_max = 25  # Maximum amplitude of boundary oscillations
78 |         self.max_osci_nfreq = 20  # Maximum nb of frequencies of a boundary
79 |         self.prob_osci_change = 0.3  # Probability that a boundary shape will
80 |                                      # change between two layers
81 |         self.max_texture = 0.15  # Random noise added to a layer (fraction of velocity)
82 |         self.texture_xrange = 0  # Range of the filter in x for texture creation
83 |         self.texture_zrange = 0  # Range of the filter in z for texture creation
84 |
85 |
86 |         self.device_type = 4  # For SeisCL. 4: GPU, 2: CPU
87 |
88 |
89 |     def save_parameters_to_disk(self, filename):
90 |         """
91 |         Save all parameters to disk
92 |
93 |         @params:
94 |         filename (str) : name of the file for saving parameters
95 |
96 |         @returns:
97 |
98 |         """
99 |         with h5.File(filename, 'w') as file:
100 |             for item in self.__dict__:
101 |                 file.create_dataset(item, data=self.__dict__[item])
102 |
103 |     def read_parameters_from_disk(self, filename):
104 |         """
105 |         Read all parameters from a file
106 |
107 |         @params:
108 |         filename (str) : name of the file containing the parameters
109 |
110 |         @returns:
111 |
112 |         """
113 |         with h5.File(filename, 'r') as file:
114 |             for item in self.__dict__:
115 |                 try:
116 |                     self.__dict__[item] = file[item][()]
117 |                 except KeyError:
118 |                     pass
119 |
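A small persistence sketch (the file name is illustrative). Note that attributes left at None, such as gmin and gmax, cannot be written by h5py's create_dataset, so they are given concrete values first:

pars = ModelParameters()
pars.gmin, pars.gmax = -50, 50  # h5py cannot serialize None
pars.NT = 1024
pars.save_parameters_to_disk("params.hdf5")
pars2 = ModelParameters()
pars2.read_parameters_from_disk("params.hdf5")
assert pars2.NT == 1024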
--------------------------------------------------------------------------------
/vrmslearn/RCNN.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Class to build the neural network for 1D prediction of RMS and interval
5 | velocity in time.
6 | """
7 | import tensorflow as tf
8 | import numpy as np
9 |
10 |
11 | class RCNN(object):
12 |     """
13 |     This class builds a NN based on a recursive CNN that can identify primary
14 |     reflections on a CMP gather and output the RMS and interval velocity in time
15 |     """
16 |
17 |     def __init__(self,
18 |                  input_size: list = [0, 0],
19 |                  batch_size: int = 1,
20 |                  alpha: float = 0,
21 |                  beta: float = 0,
22 |                  gamma: float = 0,
23 |                  zeta: float = 0,
24 |                  omega: float = 0,
25 |                  dec: float = 0,
26 |                  use_peepholes=False,
27 |                  with_masking: bool = False):
28 |         """
29 |         Build the neural net in tensorflow, along with the cost function
30 |
31 |         @params:
32 |         input_size (list): the size of the CMP [NT, NX]
33 |         batch_size (int): Number of CMPs in a batch
34 |         alpha (float): Fraction of the loss dedicated to the vrms derivative in time
35 |         beta (float): Fraction of the loss dedicated to primary identification
36 |         gamma (float): Fraction of the loss dedicated to vrms at reflection times
37 |         zeta (float): Fraction of the loss dedicated to the interval velocity
38 |         omega (float): Fraction of the loss dedicated to the vint derivative in time; dec (float): same, for the input reconstruction error
39 |         with_masking (bool): If true, masks random parts of the CMPs
40 |
41 |         @returns:
42 |         """
43 |
44 |         self.input_size = input_size
45 |         self.graph = tf.Graph()
46 |         self.with_masking = with_masking
47 |         self.feed_dict = []
48 |         self.batch_size = batch_size
49 |         self.use_peepholes = use_peepholes
50 |         with self.graph.as_default():
51 |             self.global_step = tf.train.get_or_create_global_step()
52 |             (self.input,
53 |              self.label_vrms,
54 |              self.weights,
55 |              self.label_ref,
56 |              self.label_vint) = self.generate_io()
57 |             self.feed_dict = [self.input,
58 |                               self.label_vrms,
59 |                               self.weights,
60 |                               self.label_ref,
61 |                               self.label_vint]
62 |             self.input_scaled = self.scale_input()
63 |             (self.output_vrms,
64 |              self.output_ref,
65 |              self.output_vint) = self.build_neural_net()
66 |             self.loss = self.define_loss(alpha=alpha,
67 |                                          beta=beta,
68 |                                          gamma=gamma,
69 |                                          zeta=zeta,
70 |                                          omega=omega,
71 |                                          dec=dec)
72 |
73 |
74 |     def generate_io(self):
75 |         """
76 |         This method creates the input nodes.
77 |
78 |         @params:
79 |
80 |         @returns:
81 |         input_data (tf.tensor) : Placeholder of CMP gather.
82 |         label_vrms (tf.placeholder) : Placeholder of RMS velocity labels.
83 |         weights (tf.placeholder) : Placeholder of time weights
84 |         label_ref (tf.placeholder) : Placeholder of primary reflection labels.
85 |         label_vint (tf.placeholder) : Placeholder of interval velocity labels.
86 |         """
87 |
88 |         with tf.name_scope('Inputs'):
89 |             # Create placeholder for input
90 |             input_data = tf.placeholder(dtype=tf.float32,
91 |                                         shape=[self.batch_size,
92 |                                                self.input_size[0],
93 |                                                self.input_size[1],
94 |                                                1],
95 |                                         name='data')
96 |
97 |             label_vrms = tf.placeholder(dtype=tf.float32,
98 |                                         shape=[self.batch_size,
99 |                                                self.input_size[0]],
100 |                                         name='vrms')
101 |
102 |             weights = tf.placeholder(dtype=tf.float32,
103 |                                      shape=[self.batch_size,
104 |                                             self.input_size[0]],
105 |                                      name='weigths')
106 |
107 |             label_ref = tf.placeholder(dtype=tf.int32,
108 |                                        shape=[self.batch_size,
109 |                                               self.input_size[0]],
110 |                                        name='reflections')
111 |             label_vint = tf.placeholder(dtype=tf.float32,
112 |                                         shape=[self.batch_size,
113 |                                                self.input_size[0]],
114 |                                         name='vint')
115 |
116 |         return input_data, label_vrms, weights, label_ref, label_vint
117 |
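For reference, the scaling applied by the scale_input method below can be written in plain NumPy (a sketch with an illustrative 16 x 8 CMP batch; same math outside of TensorFlow):

import numpy as np
cmp_batch = np.random.randn(1, 16, 8, 1).astype(np.float32)
trace_rms = np.sqrt(np.sum(cmp_batch ** 2, axis=1, keepdims=True))
scaled = cmp_batch / (trace_rms + np.finfo(np.float32).eps)  # per-trace RMS
scaled = 1000 * scaled / np.max(scaled, axis=(1, 2), keepdims=True)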
118 |     def scale_input(self):
119 |         """
120 |         Scale each trace to its RMS value, then scale the whole CMP by its
121 |         maximum amplitude
122 |
123 |         @params:
124 |
125 |         @returns:
126 |         scaled (tf.tensor) : The CMP gather, with each trace divided by its
127 |                              RMS amplitude and the whole gather rescaled so
128 |                              that its maximum is 1000.
129 |
130 |         """
131 |         scaled = self.input / (tf.sqrt(reduce_sum(self.input ** 2, axis=[1],
132 |                                                   keepdims=True))
133 |                                + np.finfo(np.float32).eps)
134 |
135 |         # scaled = scaled / tf.sqrt(reduce_sum(scaled ** 2,
136 |         #                                      axis=[1, 2],
137 |         #                                      keepdims=True))
138 |
139 |         scaled = 1000 * scaled / tf.reduce_max(scaled, axis=[1, 2], keepdims=True)
140 |
141 |         return scaled
142 |
143 |     def build_neural_net(self):
144 |         """
145 |         This method builds the neural net in Tensorflow
146 |
147 |         @params:
148 |
149 |         @returns:
150 |         decode_rms (tf.tensor) : RMS velocity predictions
151 |         decode_ref (tf.tensor) : Primary reflection predictions
152 |         decode_vint (tf.tensor) : Interval velocity predictions.
153 |         """
154 |
155 |         rnn_hidden = 200
156 |         weights = [tf.Variable(tf.random_normal([15, 1, 1, 16], stddev=1e-1),
157 |                                name='w1'),
158 |                    tf.Variable(tf.random_normal([1, 9, 16, 16], stddev=1e-1),
159 |                                name='w2'),
160 |                    tf.Variable(tf.random_normal([15, 1, 16, 32], stddev=1e-1),
161 |                                name='w3'),
162 |                    tf.Variable(tf.random_normal([1, 9, 32, 32], stddev=1e-1),
163 |                                name='w4'),
164 |                    tf.Variable(tf.random_normal([15, 3, 32, 32], stddev=1e-2),
165 |                                name='w5'),
166 |                    tf.Variable(tf.random_normal([1, 2, 32, 32], stddev=1e-0),
167 |                                name='w6')]
168 |
169 |         biases = [tf.Variable(tf.zeros([16]), name='b1'),
170 |                   tf.Variable(tf.zeros([16]), name='b2'),
171 |                   tf.Variable(tf.zeros([32]), name='b3'),
172 |                   tf.Variable(tf.zeros([32]), name='b4'),
173 |                   tf.Variable(tf.zeros([32]), name='b5'),
174 |                   tf.Variable(tf.zeros([32]), name='b6')]
175 |
176 |         weightsr = [tf.Variable(tf.random_normal([15, 1, 1, 16], stddev=1e-1),
177 |                                 name='w1'),
178 |                     tf.Variable(tf.random_normal([1, 9, 16, 16], stddev=1e-1),
179 |                                 name='w2'),
180 |                     tf.Variable(tf.random_normal([15, 1, 16, 32], stddev=1e-1),
181 |                                 name='w3'),
182 |                     tf.Variable(tf.random_normal([1, 9, 32, 32], stddev=1e-1),
183 |                                 name='w4'),
184 |                     tf.Variable(tf.random_normal([15, 3, 32, 32], stddev=1e-2),
185 |                                 name='w5'),
186 |                     tf.Variable(tf.random_normal([1, 2, 32, 32], stddev=1e-0),
187 |                                 name='w6')]
188 |
189 |         biasesr = [tf.Variable(tf.zeros([16]), name='b1'),
190 |                    tf.Variable(tf.zeros([16]), name='b2'),
191 |                    tf.Variable(tf.zeros([32]), name='b3'),
192 |                    tf.Variable(tf.zeros([32]), name='b4'),
193 |                    tf.Variable(tf.zeros([32]), name='b5'),
194 |                    tf.Variable(tf.zeros([32]), name='b6')]
195 |
196 |         data_stream = self.input_scaled
197 |         allout = [self.input_scaled]
198 |         with tf.name_scope('Encoder'):
199 |             for ii in range(len(weights) - 2):
200 |                 with tf.name_scope('CNN_' + str(ii)):
201 |                     data_stream = tf.nn.relu(
202 |                         tf.nn.conv2d(data_stream,
203 |                                      weights[ii],
204 |                                      strides=[1, 1, 1, 1],
205 |                                      padding='SAME') + biases[ii])
206 |                     allout.append(data_stream)
207 |         self.output_encoder = data_stream
208 |
209 |         with tf.name_scope('Time_RCNN'):
210 |             for ii in range(7):
211 |                 data_stream = tf.nn.relu(
212 |                     tf.nn.conv2d(data_stream,
213 |                                  weights[-2],
214 |                                  strides=[1, 1, 1, 1],
215 |                                  padding='SAME') + biases[-2])
216 |                 allout.append(data_stream)
217 |
218 |         self.output_time_rcnn = data_stream
219 |
220 |         decode_stream = data_stream
221 |         with tf.name_scope('Decoder'):
222 |             n = -2
223 |             for ii in range(7):
224 |                 decode_stream = tf.nn.conv2d_transpose(tf.nn.relu(decode_stream) + biasesr[-2],
225 |                                                        weightsr[-2],
226 |                                                        output_shape=allout[n].shape,
227 |                                                        strides=[1, 1, 1, 1],
228 |                                                        padding='SAME')
229 |
230 |                 n -= 1
231 |             for ii in range(len(weights) - 2):
232 |                 decode_stream = tf.nn.conv2d_transpose(tf.nn.relu(decode_stream) + biasesr[len(weights) - 3 - ii],
233 |                                                        weightsr[len(weights) - 3 - ii],
234 |                                                        output_shape=allout[n].shape,
235 |                                                        strides=[1, 1, 1, 1],
236 |                                                        padding='SAME')
237 |
238 |                 n -= 1
239 |         self.decoded = decode_stream
240 |
241 |         with tf.name_scope('Offset_RCNN'):
242 |             while data_stream.get_shape()[2] > 1:
243 |                 data_stream = tf.nn.relu(
244 |                     tf.nn.conv2d(data_stream,
245 |                                  weights[-1],
246 |                                  strides=[1, 1, 2, 1],
247 |                                  padding='VALID') + biases[-1])
248 |             data_stream = reduce_max(data_stream, axis=[2], keepdims=False)
249 |         self.output_offset_rcnn = data_stream
250 |
251 |
252 |         output_size = int(data_stream.get_shape()[-1])
253 |         with tf.name_scope('Decode_refevent'):
254 |             decode_refw = tf.Variable(
255 |                 initial_value=tf.random_normal([output_size, 2],
256 |                                                stddev=1e-4),
257 |                 name='decode_ref')
258 |             final_projection = lambda x: tf.matmul(x, decode_refw)
259 |             decode_ref = tf.map_fn(final_projection, data_stream)
260 |
261 |         with tf.name_scope('RNN_vrms'):
262 |             cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden, state_is_tuple=True, use_peepholes=self.use_peepholes)
263 |             state0 = cell.zero_state(data_stream.get_shape()[0], tf.float32)
264 |             data_stream, rnn_states = tf.nn.dynamic_rnn(cell, data_stream,
265 |                                                         initial_state=state0,
266 |                                                         time_major=False,
267 |                                                         scope="rnn_vrms")
268 |         self.rnn_vrms_out = data_stream
269 |
270 |         with tf.name_scope('Decode_rms'):
271 |             output_size = int(data_stream.get_shape()[-1])
272 |             decode_rmsw = tf.Variable(
273 |                 initial_value=tf.random_normal([output_size, 1], stddev=1e-4),
274 |                 name='decode_rms')
275 |             final_projection = lambda x: tf.matmul(x, decode_rmsw)
276 |             decode_rms = tf.map_fn(final_projection, data_stream)
277 |             decode_rms = decode_rms[:, :, 0]
278 |
279 |
280 |         with tf.name_scope('RNN_vint'):
281 |             cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden, state_is_tuple=True, use_peepholes=self.use_peepholes)
282 |             state0 = cell.zero_state(data_stream.get_shape()[0], tf.float32)
283 |             data_stream, rnn_states = tf.nn.dynamic_rnn(cell, data_stream,
284 |                                                         initial_state=state0,
285 |                                                         time_major=False,
286 |                                                         scope="rnn_vint")
287 |         self.rnn_vint_out = data_stream
288 |
289 |         with tf.name_scope('Decode_vint'):
290 |             output_size = int(data_stream.get_shape()[-1])
291 |             decode_vintw = tf.Variable(
292 |                 initial_value=tf.random_normal([output_size, 1], stddev=1e-4),
293 |                 name='decode_vint')
294 |             final_projection = lambda x: tf.matmul(x, decode_vintw)
295 |             decode_vint = tf.map_fn(final_projection, data_stream)
296 |             decode_vint = decode_vint[:, :, 0]
297 |
298 |         self.allout = allout
299 |         return decode_rms, decode_ref, decode_vint
300 |
301 |     def define_loss(self, alpha=0.2, beta=0.1, gamma=0, zeta=0, omega=0, dec=0):
302 |         """
303 |         This method creates a node to compute the loss function.
304 |         The loss is normalized.
305 |
306 |         @params:
307 |
308 |         @returns:
309 |         loss (tf.tensor) : Output of the node calculating the loss.
310 |         """
311 |         with tf.name_scope("Loss_Function"):
312 |
313 |             losses = []
314 |
315 |             fact1 = (1 - alpha - beta - gamma - zeta - omega)
316 |
317 |             # Calculate the mean squared error of the continuous rms velocity
318 |             if fact1 > 0:
319 |                 num = tf.reduce_sum(self.weights * (self.label_vrms
320 |                                                     - self.output_vrms) ** 2)
321 |                 den = tf.reduce_sum(self.weights * self.label_vrms ** 2)
322 |                 losses.append(fact1 * num / den)
323 |
324 |             # Calculate the mean squared error of the derivative of the
325 |             # continuous rms velocity (normalized)
326 |             if alpha > 0:
327 |                 dlabels = self.label_vrms[:, 1:] - self.label_vrms[:, :-1]
328 |                 dout = self.output_vrms[:, 1:] - self.output_vrms[:, :-1]
329 |                 num = tf.reduce_sum(self.weights[:, :-1] * (dlabels - dout) ** 2)
330 |                 den = tf.reduce_sum(self.weights[:, :-1] * dlabels ** 2 + 0.000001)
331 |                 losses.append(alpha * num / den)
332 |
333 |             # Logistic regression of the zero-offset arrival times of reflections
334 |             if beta > 0:
335 |                 if self.with_masking:
336 |                     weightsr = tf.expand_dims(self.weights, -1)
337 |                 else:
338 |                     weightsr = 1.0
339 |                 preds = self.output_ref * weightsr
340 |                 labels = tf.one_hot(self.label_ref, 2) * weightsr
341 |                 losses.append(beta * tf.reduce_mean(
342 |                     tf.nn.softmax_cross_entropy_with_logits(logits=preds,
343 |                                                             labels=labels)))
344 |
345 |             # Learning vrms only at the times of reflections
346 |             if gamma > 0:
347 |                 mask = tf.cast(self.label_ref, tf.float32)
348 |                 num = tf.reduce_sum(mask * self.weights * (self.label_vrms
349 |                                                            - self.output_vrms) ** 2)
350 |                 den = tf.reduce_sum(mask * self.weights * self.label_vrms ** 2)
351 |                 losses.append(gamma * num / den)
352 |
353 |             # Learning the interval velocity
354 |             if zeta > 0:
355 |                 num = tf.reduce_sum(self.weights * (self.label_vint
356 |                                                     - self.output_vint) ** 2)
357 |                 den = tf.reduce_sum(self.weights * self.label_vint ** 2)
358 |                 losses.append(zeta * num / den)
359 |
360 |             # Minimize the interval velocity gradient (blocky inversion)
361 |             if omega > 0:
362 |                 num = tf.norm((self.output_vint[:, 1:]
363 |                                - self.output_vint[:, :-1]), ord=1)
364 |                 den = tf.norm(self.output_vint, ord=1) / 0.02
365 |                 losses.append(omega * num / den)
366 |
367 |             # Reconstruction error
368 |             if dec > 0:
369 |                 num = tf.norm((self.decoded - self.input_scaled), ord=1)
370 |                 den = tf.norm(self.input_scaled, ord=1)
371 |                 losses.append(dec * num / den)
372 |
373 |             loss = np.sum(losses)
374 |
375 |             tf.summary.scalar("loss", loss)
376 |         return loss
377 |
378 |
379 |
380 | def reduce_sum(a, axis=None, keepdims=True):
381 |     if tf.__version__ == '1.2.0':
382 |         return tf.reduce_sum(a, axis=axis, keep_dims=keepdims)
383 |     else:
384 |         return tf.reduce_sum(a, axis=axis, keepdims=keepdims)
385 |
386 | def reduce_max(a, axis=None, keepdims=True):
387 |     if tf.__version__ == '1.2.0':
388 |         return tf.reduce_max(a, axis=axis, keep_dims=keepdims)
389 |     else:
390 |         return tf.reduce_max(a, axis=axis, keepdims=keepdims)
391 |
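A minimal smoke-test sketch for the graph above (TF1 style; the 16 x 8 input size and batch of one are illustrative assumptions — real runs take the size from SeismicGenerator.image_size):

import numpy as np
import tensorflow as tf
nn = RCNN(input_size=[16, 8], batch_size=1)
with nn.graph.as_default():
    init_op = tf.global_variables_initializer()
with tf.Session(graph=nn.graph) as sess:
    sess.run(init_op)
    vrms = sess.run(nn.output_vrms,
                    feed_dict={nn.input: np.zeros([1, 16, 8, 1])})
    print(vrms.shape)  # (1, 16): one vrms sample per time step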
--------------------------------------------------------------------------------
/vrmslearn/SeismicGenerator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | A class to generate the training examples (seismic data and labels)
5 | """
6 |
7 | from vrmslearn.ModelGenerator import ModelGenerator, interval_velocity_time
8 | from vrmslearn.ModelParameters import ModelParameters
9 | from SeisCL.SeisCL import SeisCL
10 | import numpy as np
11 | import os
12 | import h5py as h5
13 | import fnmatch
14 | from multiprocessing import Process, Queue, Event, Value
15 | from shutil import rmtree
16 |
17 | from scipy.signal import convolve2d
18 |
19 | def gaussian(f0, t, o, amp=1.0, order=2):
20 |
21 |     x = np.pi * f0 * (t + o)
22 |     e = amp * np.exp(-x ** 2)
23 |     if order == 1:
24 |         return e * x
25 |     elif order == 2:
26 |         return (1.0 - 2.0 * x ** 2) * e
27 |     elif order == 3:
28 |         return 2.0 * x * (2.0 * x ** 2 - 3.0) * e
29 |     elif order == 4:
30 |         return (-8.0 * x ** 4 + 24.0 * x ** 2 - 6.0) * e
31 |     elif order == 5:
32 |         return 4.0 * x * (4.0 * x ** 4 - 20.0 * x ** 2 + 15.0) * e
33 |     elif order == 6:
34 |         return -4.0 * (8.0 * x ** 6 - 60.0 * x ** 4 + 90.0 * x ** 2 - 15.0) * e
35 |
36 | def morlet(f0, t, o, amp=1.0, order=5):
37 |     x = f0 * (t + o)
38 |     return amp * np.cos(x * order) * np.exp(-x ** 2)
39 |
40 |
41 | def shift_trace(signal, phase):
42 |     S = np.fft.fft(signal)
43 |     NT = len(signal)
44 |     S[1:NT // 2] *= 2.0
45 |     S[NT // 2 + 1:] *= 0
46 |     s = np.fft.ifft(S)
47 |     return np.real(s) * np.cos(phase) + np.imag(s) * np.sin(phase)
48 |
49 | class SeismicGenerator(object):
50 |     """
51 |     Class to generate seismic data with SeisCL and output an example to build
52 |     a seismic dataset for training.
53 |     """
54 |     def __init__(self,
55 |                  model_parameters=ModelParameters(),
56 |                  gpus=[]):
57 |         """
58 |         This is the constructor for the class.
59 |
60 |         @params:
61 |         model_parameters (ModelParameters) : A ModelParameters object. The
62 |                                              wavelet functions used for the source are selected with its wavefuns attribute.
63 |         gpus (list): A list of GPUs not to use for computation
64 |
65 |         @returns:
66 |         """
67 |
68 |         self.pars = model_parameters
69 |         self.F = SeisCL()
70 |
71 |         # Overload to generate other kinds of models
72 |         self.model_generator = ModelGenerator(model_parameters=model_parameters)
73 |
74 |         self.init_F(gpus)
75 |         self.image_size = [int(np.ceil(self.pars.NT / self.pars.resampling)),
76 |                            self.F.rec_pos.shape[1]]
77 |
78 |         allwavefuns = [lambda f0, t, o: gaussian(f0, t, o, order=1),
79 |                        lambda f0, t, o: gaussian(f0, t, o, order=2),
80 |                        lambda f0, t, o: gaussian(f0, t, o, order=3),
81 |                        lambda f0, t, o: morlet(f0, t, o, order=2),
82 |                        lambda f0, t, o: morlet(f0, t, o, order=3),
83 |                        lambda f0, t, o: morlet(f0, t, o, order=4)]
84 |
85 |
86 |         self.wavefuns = [allwavefuns[ii] for ii in model_parameters.wavefuns]
87 |
88 |         self.files_list = {}
89 |
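A quick sketch of the wavelet helpers defined earlier in this file (numbers are illustrative): draw a 10 Hz second-order Gaussian (Ricker-like) wavelet and rotate its phase with shift_trace.

import numpy as np
t = np.arange(0, 2048) * 0.0009
src = gaussian(10.0, t, -0.25)         # wavelet peaking near 0.25 s
src_rot = shift_trace(src, np.pi / 2)  # constant phase shift of the trace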
90 |     def init_F(self, gpus=[]):
91 |         """
92 |         This method initializes the variable 'self.F', which is used for forward
93 |         modeling with the SeisCL engine. We assume here a 1D vp model, and we
94 |         position a source at the top of the model, centered in the x direction.
95 |
96 |         @params:
97 |         gpus (list) : A list of GPU ids not to use
98 |
99 |         @returns:
100 |         """
101 |         # Initialize the modeling engine
102 |         self.F.csts['N'] = np.array([self.pars.NZ, self.pars.NX])
103 |         self.F.csts['ND'] = 2
104 |         self.F.csts['dh'] = self.pars.dh  # Grid spacing
105 |         self.F.csts['nab'] = self.pars.Npad  # Set padding cells
106 |         self.F.csts['dt'] = self.pars.dt  # Time step size
107 |         self.F.csts['NT'] = self.pars.NT  # Nb of time steps
108 |         self.F.csts['f0'] = self.pars.peak_freq  # Source frequency
109 |         self.F.csts['seisout'] = 2  # Output pressure
110 |         self.F.csts['freesurf'] = int(self.pars.fs)  # Free surface
111 |         self.F.csts['no_use_GPUs'] = np.array(gpus)
112 |         self.F.csts['pref_device_type'] = self.pars.device_type
113 |
114 |         if self.pars.flat:
115 |             # Add a source in the middle
116 |             sx = np.arange(self.pars.NX / 2, 1 + self.pars.NX / 2) * self.pars.dh
117 |         else:
118 |             if self.pars.train_on_shots:
119 |                 l1 = self.pars.Npad + 1
120 |                 if self.pars.gmin and self.pars.gmin < 0:
121 |                     l1 += -self.pars.gmin
122 |                 l2 = self.pars.NX - self.pars.Npad
123 |                 if self.pars.gmax and self.pars.gmax > 0:
124 |                     l2 += -self.pars.gmax
125 |
126 |                 sx = np.arange(l1, l2, self.pars.ds) * self.pars.dh
127 |             else:
128 |                 # We need to compute the true CMP as layers have a slope.
129 |                 # We compute one CMP, with multiple shots with 1 receiver
130 |                 sx = np.arange(self.pars.Npad + 1,
131 |                                self.pars.NX - self.pars.Npad,
132 |                                self.pars.dg) * self.pars.dh
133 |         sz = sx * 0 + self.pars.source_depth
134 |         sid = np.arange(0, sx.shape[0])
135 |
136 |         self.F.src_pos = np.stack([sx,
137 |                                    sx * 0,
138 |                                    sz,
139 |                                    sid,
140 |                                    sx * 0 + self.pars.sourcetype], axis=0)
141 |         self.F.src_pos_all = self.F.src_pos
142 |         self.F.src = np.empty((self.F.csts['NT'], 0))
143 |
144 |         def generate_wavelet():
145 |             t = np.arange(0, self.pars.NT) * self.pars.dt
146 |             fmin = self.pars.peak_freq - self.pars.df
147 |             fmax = self.pars.peak_freq + self.pars.df
148 |             f0 = np.random.rand(1) * (fmax - fmin) + fmin
149 |             phase = np.random.rand(1) * np.pi
150 |             fun = np.random.choice(self.wavefuns)
151 |             tdelay = -self.pars.tdelay
152 |             src = fun(f0, t, tdelay)
153 |             src = shift_trace(src, phase)
154 |
155 |             return src
156 |
157 |         self.F.wavelet_generator = generate_wavelet
158 |
159 |         # Add receivers
160 |         if self.pars.flat or self.pars.train_on_shots:
161 |             if self.pars.gmin:
162 |                 gmin = self.pars.gmin
163 |             else:
164 |                 gmin = -(self.pars.NX - 2 * self.pars.Npad) // 2
165 |             if self.pars.gmax:
166 |                 gmax = self.pars.gmax
167 |             else:
168 |                 gmax = (self.pars.NX - 2 * self.pars.Npad) // 2
169 |
170 |             gx0 = np.arange(gmin, gmax, self.pars.dg) * self.pars.dh
171 |             gx = np.concatenate([s + gx0 for s in sx], axis=0)
172 |             gsid = np.concatenate([s + gx0 * 0 for s in sid], axis=0)
173 |
174 |         else:
175 |             # One receiver per source, with the middle point at NX/2
176 |             gx = (self.pars.NX - sx / self.pars.dh) * self.pars.dh
177 |             gsid = sid
178 |         gz = gx * 0 + self.pars.receiver_depth
179 |         gid = np.arange(0, len(gx))
180 |
181 |         self.F.rec_pos = np.stack([gx,
182 |                                    gx * 0,
183 |                                    gz,
184 |                                    gsid,
185 |                                    gid,
186 |                                    gx * 0 + 2,
187 |                                    gx * 0,
188 |                                    gx * 0], axis=0)
189 |         self.F.rec_pos_all = self.F.rec_pos
190 |
191 |
192 |     def compute_example(self, workdir):
193 |         """
194 |         This method generates one example, which contains the vp model, vrms,
195 |         the seismic data and the valid vrms time samples.
196 |
197 |         @params:
198 |         workdir (str) : A string containing the working directory of SeisCL
199 |
200 |         @returns:
201 |         data (numpy.ndarray) : Contains the modelled seismic data
202 |         vrms (numpy.ndarray) : numpy array of shape (self.pars.NT, ) with vrms
203 |                                values in meters/sec.
204 |         vp (numpy.ndarray) : numpy array (self.pars.NZ, self.pars.NX) for vp.
205 |         valid (numpy.ndarray) : numpy array (self.pars.NT, ) containing the
206 |                                 time samples for which vrms is valid
207 |         tlabels (numpy.ndarray) : numpy array (self.pars.NT, ) with 1 if a
208 |                                   sample is at a primary reflection, 0 otherwise
209 |         """
210 |         vp, vs, rho = self.model_generator.generate_model()
211 |         vrms, valid, tlabels = self.model_generator.generate_labels()
212 |         self.F.set_forward(self.F.src_pos[3, :],
213 |                            {'vp': vp, 'vs': vs, 'rho': rho},
214 |                            workdir,
215 |                            withgrad=False)
216 |         self.F.execute(workdir)
217 |         data = self.F.read_data(workdir)
218 |         data = data[0][::self.pars.resampling, :]
219 |         vrms = vrms[::self.pars.resampling]
220 |         valid = valid[::self.pars.resampling]
221 |         tlabels = tlabels[::self.pars.resampling]
222 |
223 |         return data, vrms, vp[:, 0], valid, tlabels
224 |
225 |     def write_example(self, n, savedir, data, vrms, vp, valid, tlabels,
226 |                       filename=None):
227 |         """
228 |         This method writes one example in the hdf5 format
229 |
230 |         @params:
231 |         savedir (str) : Directory in which to save the example
232 |         n (int) : The example number, used in the file name
233 |         data (numpy.ndarray) : Contains the modelled seismic data
234 |         vrms (numpy.ndarray) : numpy array of shape (self.pars.NT, ) with vrms
235 |                                values in meters/sec.
236 |         vp (numpy.ndarray) : numpy array (self.pars.NZ, self.pars.NX) for vp.
237 |         valid (numpy.ndarray) : numpy array (self.pars.NT, ) containing the
238 |                                 time samples for which vrms is valid
239 |         tlabels (numpy.ndarray) : numpy array (self.pars.NT, ) with 1 at primary reflection times, 0 elsewhere
240 |         filename (str) : If provided, the file name of the example to write
241 |
242 |         @returns:
243 |         n (int) : The example number that was written
244 |         """
245 |
246 |         if filename is None:
247 |             if not os.path.isdir(savedir):
248 |                 os.mkdir(savedir)
249 |             #n = len(fnmatch.filter(os.listdir(savedir), 'example_*'))
250 |             pid = os.getpid()
251 |             filename = savedir + "/example_%d_%d" % (n, pid)
252 |         else:
253 |             filename = savedir + "/" + filename
254 |
255 |         file = h5.File(filename, "w")
256 |         file["data"] = data
257 |         file["vrms"] = vrms
258 |         file["vp"] = vp
259 |         file["valid"] = valid
260 |         file["tlabels"] = tlabels
261 |         file.close()
262 |
263 |         return n
264 |
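For reference, a sketch of reading one example file back directly with h5py (the file name is illustrative; write_example names files example_<n>_<pid>):

import h5py as h5
with h5.File("dataset_article/example_0_1234", "r") as f:
    data = f["data"][:]        # modelled CMP gather
    vrms = f["vrms"][:]        # rms velocity labels
    vp = f["vp"][:]            # 1D vp profile in depth
    valid = f["valid"][:]      # 1 where the vrms labels are valid
    tlabels = f["tlabels"][:]  # 1 at primary reflection times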
265 |     def read_example(self, savedir, filename=None):
266 |         """
267 |         This method retrieves one example written in the hdf5 format
268 |
269 |         @params:
270 |         savedir (str) : A string containing the directory from which to
271 |                         read the example
272 |         filename (str) : If provided, the file name of the example to read
273 |
274 |         @returns:
275 |         data (numpy.ndarray) : Contains the modelled seismic data
276 |         vrms (numpy.ndarray) : numpy array of shape (self.pars.NT, ) with vrms
277 |                                values in meters/sec.
278 |         vint (numpy.ndarray) : numpy array (self.pars.NT,) containing the
279 |                                interval velocity
280 |         valid (numpy.ndarray) : numpy array (self.pars.NT, ) containing the
281 |                                 time samples for which vrms is valid
282 |         tlabels (numpy.ndarray) : numpy array (self.pars.NT, ) with 1 if a
283 |                                   sample is at a primary reflection, 0 otherwise
284 |
285 |         """
286 |         if filename is None:
287 |
288 |             if type(savedir) is list:
289 |                 savedir = np.random.choice(savedir, 1)[0]
290 |             if not os.path.isdir(savedir):
291 |                 os.mkdir(savedir)
292 |             if savedir not in self.files_list:
293 |                 files = fnmatch.filter(os.listdir(savedir), 'example_*')
294 |                 self.files_list[savedir] = files
295 |             else:
296 |                 files = self.files_list[savedir]
297 |             if not files:
298 |                 raise FileNotFoundError()
299 |             filename = savedir + "/" + np.random.choice(files, 1)[0]
300 |         else:
301 |             filename = savedir + "/" + filename
302 |
303 |         file = h5.File(filename, "r")
304 |         data = file['data'][:]
305 |         vrms = file['vrms'][:]
306 |         vp = file['vp'][:]
307 |         if self.pars.random_time_scaling:
308 |             data = random_time_scaling(data, self.pars.dt * self.pars.resampling)
309 |         if self.pars.mute_dir:
310 |             data = mute_direct(data, vp[0], self.pars)
311 |         if self.pars.random_static:
312 |             data = random_static(data, self.pars.random_static_max)
313 |         if self.pars.random_noise:
314 |             data = random_noise(data, self.pars.random_noise_max)
315 |         if self.pars.mute_nearoffset:
316 |             data = mute_nearoffset(data, self.pars.mute_nearoffset_max)
317 |
318 |
319 |         vint = interval_velocity_time(vp, self.pars)[::self.pars.resampling]
320 |         vint = (vint - self.pars.vp_min) / (self.pars.vp_max - self.pars.vp_min)
321 |         valid = file['valid'][:]
322 |         tlabels = file['tlabels'][:]
323 |         if self.pars.mask_firstvel:
324 |             ind0 = np.nonzero(tlabels)[0][0]
325 |             valid[0:ind0] *= 0.05
326 |         file.close()
327 |
328 |         # return data.tolist(), vrms.tolist(), vint.tolist(), valid.tolist(), tlabels.tolist()
329 |         #
330 |         return data, vrms, vint, valid, tlabels
331 |
332 |     def aggregate_examples(self, batch):
333 |         """
334 |         This method aggregates a batch of examples
335 |
336 |         @params:
337 |         batch (list): A list of examples, each one being a list with all the
338 |                       elements of one example.
339 |
340 |         @returns:
341 |         batch (list): A list of numpy arrays that contains all the examples
342 |                       for each element of a batch.
343 |
344 |         """
345 |         data = np.stack([el[0] for el in batch])
346 |         data = np.expand_dims(data, axis=-1)
347 |         vrms = np.stack([el[1] for el in batch])
348 |         vint = np.stack([el[2] for el in batch])
349 |         weights = np.stack([el[3] for el in batch])
350 |         tlabels = np.stack([el[4] for el in batch])
351 |
352 |         return [data, vrms, weights, tlabels, vint]
353 |
354 |
355 | def generate_dataset(pars: ModelParameters,
356 |                      savepath: str,
357 |                      nexamples: int,
358 |                      seed: int = None,
359 |                      nthread: int = 3,
360 |                      workdir: str = "./workdir"):
361 |     """
362 |     This method creates a dataset. If multiple threads or processes generate
363 |     the dataset, it may not be totally reproducible due to a different
364 |     random seed attributed to each process or thread.
def generate_dataset(pars: ModelParameters,
                     savepath: str,
                     nexamples: int,
                     seed: int = None,
                     nthread: int = 3,
                     workdir: str = "./workdir"):
    """
    This function creates a dataset. If multiple threads or processes generate
    the dataset, it may not be totally reproducible due to a different
    random seed attributed to each process or thread.

    @params:
    pars (ModelParameters): A ModelParameters object
    savepath (str)  : Path in which to create the dataset
    nexamples (int) : Number of examples to generate
    seed (int)      : Seed for the random model generator
    nthread (int)   : Number of processes used to generate examples
    workdir (str)   : Name of the directory for temporary files

    @returns:
    """

    if not os.path.isdir(savepath):
        os.mkdir(savepath)

    generators = []
    for jj in range(nthread):
        this_workdir = workdir + "_" + str(jj)
        if seed is not None:
            thisseed = seed * (jj + 1)
        else:
            thisseed = seed
        thisgen = DatasetGenerator(pars,
                                   savepath,
                                   this_workdir,
                                   nexamples,
                                   [],
                                   seed=thisseed)
        thisgen.start()
        generators.append(thisgen)
    for gen in generators:
        gen.join()


class DatasetGenerator(Process):
    """
    This class creates a new process to generate seismic data.
    """

    def __init__(self,
                 parameters,
                 savepath: str,
                 workdir: str,
                 nexamples: int,
                 gpus: list,
                 seed: int = None):
        """
        Initialize the DatasetGenerator.

        @params:
        parameters (ModelParameters): A ModelParameters object
        savepath (str)  : Path in which to create the dataset
        workdir (str)   : Name of the directory for temporary files
        nexamples (int) : Number of examples to generate
        gpus (list)     : List of gpus not to use
        seed (int)      : Seed for the random model generator

        @returns:
        """
        super().__init__()

        self.savepath = savepath
        self.workdir = workdir
        self.nexamples = nexamples
        self.parameters = parameters
        self.gpus = gpus
        self.seed = seed
        if not os.path.isdir(savepath):
            os.mkdir(savepath)
        if not os.path.isdir(workdir):
            os.mkdir(workdir)
        try:
            parameters.save_parameters_to_disk(savepath
                                               + "/model_parameters.hdf5")
        except OSError:
            pass

    def run(self):
        """
        Start the process to generate data.
        """
        n = len(fnmatch.filter(os.listdir(self.savepath), 'example_*'))
        gen = SeismicGenerator(model_parameters=self.parameters,
                               gpus=self.gpus)
        if self.seed is not None:
            np.random.seed(self.seed)

        while n < self.nexamples:
            n = len(fnmatch.filter(os.listdir(self.savepath), 'example_*'))
            if self.seed is None:
                np.random.seed(n)
            data, vrms, vp, valid, tlabels = gen.compute_example(self.workdir)
            try:
                gen.write_example(n, self.savepath, data, vrms, vp, valid,
                                  tlabels)
                if n % 100 == 0:
                    print("%f of examples computed"
                          % (float(n) / self.nexamples))
            except OSError:
                pass
        if os.path.isdir(self.workdir):
            rmtree(self.workdir)
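# Example (sketch): generating a small dataset with two worker processes. The
# paths are hypothetical and ModelParameters is assumed to be constructible
# with its defaults; in the repository's scripts, a pre-configured parameter
# object from Cases_define is used instead. Each worker writes example files
# into savepath until nexamples files exist.
#
#     from vrmslearn.ModelParameters import ModelParameters
#     pars = ModelParameters()
#     generate_dataset(pars=pars,
#                      savepath="./dataset_demo",
#                      nexamples=100,
#                      nthread=2,
#                      workdir="./workdir_demo")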
def mask_batch(batch,
               mask_fraction,
               mask_time_frac):
    """
    Mask random parts of each example in a batch: the late times, the nearest
    and farthest offsets, and a random subset of the remaining traces.
    """
    for ii, el in enumerate(batch):
        data = el[0]
        NT = data.shape[0]
        ng = data.shape[1]

        # Mask the late times and the extreme offsets. The slices are built
        # from NT and ng so that a zero-length window masks nothing (a plain
        # [-twindow:] slice would mask everything when twindow == 0).
        frac = np.random.rand() * mask_time_frac
        twindow = int(frac * NT)
        owindow = int(frac * ng / 2)
        batch[ii][0][NT - twindow:, :] = 0
        batch[ii][0][:, :owindow] = 0
        batch[ii][0][:, ng - owindow:] = 0

        # Kill a random subset of the remaining traces
        ntokill = int(np.random.rand() * mask_fraction * ng * frac)
        tokill = np.random.choice(np.arange(owindow, ng - owindow),
                                  ntokill, replace=False)
        batch[ii][0][:, tokill] = 0

        # Mask the weights for the late times as well
        batch[ii][3][NT - twindow:] = 0

    return batch


def mute_direct(data, vp0, pars, offsets=None):
    """
    Mute the direct arrival in the data, using the velocity vp0 of the first
    layer, and taper the edge of the mute.
    """
    wind_length = int(2 / pars.peak_freq / pars.dt / pars.resampling)
    taper = np.arange(wind_length)
    taper = np.sin(np.pi * taper / (2 * wind_length - 1)) ** 2
    NT = data.shape[0]
    ng = data.shape[1]
    if offsets is None:
        if pars.gmin is None or pars.gmax is None:
            offsets = (np.arange(0, ng) - ng / 2) * pars.dh * pars.dg
        else:
            offsets = np.arange(pars.gmin, pars.gmax, pars.dg) * pars.dh

    for ii, off in enumerate(offsets):
        tmute = int((np.abs(off) / vp0 + 1.5 * pars.tdelay)
                    / pars.dt / pars.resampling)
        if tmute <= NT:
            data[0:tmute, ii] = 0
            mute_max = np.min([tmute + wind_length, NT])
            nmute = mute_max - tmute
            data[tmute:mute_max, ii] = data[tmute:mute_max, ii] * taper[:nmute]
        else:
            data[:, ii] = 0

    return data


def random_static(data, max_static):
    """
    Apply a random time shift (static) to each trace, up to max_static samples.
    """
    ng = data.shape[1]
    shifts = (np.random.rand(ng) - 0.5) * max_static * 2
    for ii in range(ng):
        data[:, ii] = np.roll(data[:, ii], int(shifts[ii]), 0)
    return data


def random_noise(data, max_amp):
    """
    Add uniform random noise with a maximum amplitude relative to the data.
    """
    max_amp = max_amp * np.max(data) * 2.0
    data = data + (np.random.rand(data.shape[0], data.shape[1]) - 0.5) * max_amp
    return data


def mute_nearoffset(data, max_off):
    """
    Mute a random number of near-offset traces, up to max_off traces.
    """
    data[:, :np.random.randint(max_off)] *= 0
    return data


def random_filt(data, filt_length):
    """
    Convolve the data with a random filter of odd length up to filt_length.
    Relies on convolve2d (scipy.signal) being imported at the module level.
    """
    filt_length = int((np.random.randint(filt_length) // 2) * 2 + 1)
    filt = np.random.rand(filt_length, 1)
    data = convolve2d(data, filt, 'same')
    return data


def random_time_scaling(data, dt, emin=-2.0, emax=2.0, scalmax=None):
    """
    Scale the data by a random power of time, t**e with e in [emin, emax],
    optionally capping the gain at scalmax.
    """
    t = np.reshape(np.arange(0, data.shape[0]) * dt, [data.shape[0], 1])
    e = np.random.rand() * (emax - emin) + emin
    scal = (t + 1e-6) ** e
    if scalmax is not None:
        scal[scal > scalmax] = scalmax
    return data * scal
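
if __name__ == "__main__":
    # Minimal usage sketch (illustrative values, not from the article): apply
    # the random time-power gain to a constant gather and verify that only
    # the amplitude envelope changes with time, not the trace timing.
    dummy = np.ones([500, 72])
    scaled = random_time_scaling(dummy, 0.008)
    print("gain at t=0: %f, gain at t_max: %f" % (scaled[0, 0], scaled[-1, 0]))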
--------------------------------------------------------------------------------
/vrmslearn/Tester.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This class tests a NN on a dataset.
"""
from vrmslearn.RCNN import RCNN
from vrmslearn.ModelParameters import ModelParameters
from vrmslearn.Inputqueue import BatchManager
from vrmslearn.SeismicGenerator import SeismicGenerator
import tensorflow as tf
import time
import os
import fnmatch
import h5py as h5
import copy


class Tester(object):
    """
    This class tests a NN on a dataset.
    """

    def __init__(self,
                 NN: RCNN,
                 data_generator: SeismicGenerator):
        """
        Initialize the tester.

        @params:
        NN (RCNN) : A tensorflow neural net
        data_generator (SeismicGenerator): A data generator object

        @returns:
        """
        self.NN = NN
        self.data_generator = data_generator

    def test_dataset(self,
                     savepath: str,
                     toeval: list,
                     toeval_names: list,
                     testpath: str = None,
                     filename: str = 'example_*',
                     restore_from: str = None,
                     tester_n: int = 0):
        """
        This method evaluates the predictions on all examples contained in
        testpath and saves the predictions in hdf5 files.

        @params:
        savepath (str) : The path in which to save the predictions
        toeval (list)  : List of tensors to predict
        toeval_names (list) : Names under which to save each evaluated tensor
        testpath (str) : The path in which the test examples are found.
                         Defaults to savepath.
        filename (str) : The structure of the examples' filenames
        restore_from (str) : Checkpoint file containing the trained weights
        tester_n (int) : Label of the model used for prediction

        @returns:
        """

        if testpath is None:
            testpath = savepath
        # Do the testing
        examples = fnmatch.filter(os.listdir(testpath), filename)
        predictions = fnmatch.filter(os.listdir(savepath), filename)
        with self.NN.graph.as_default():
            saver = tf.train.Saver()
            with tf.Session() as sess:
                saver.restore(sess, restore_from)
                batch = []
                bexamples = []
                for ii, example in enumerate(examples):
                    predname = example + "_pred" + str(tester_n)
                    if "pred" not in example and predname not in predictions:
                        bexamples.append(example)
                        batch.append(self.data_generator.read_example(
                            savedir=testpath, filename=example))

                        if len(batch) == self.NN.batch_size:
                            batch = self.data_generator.aggregate_examples(batch)
                            feed_dict = dict(zip(self.NN.feed_dict, batch))
                            evaluated = sess.run(toeval, feed_dict=feed_dict)
                            for jj, bexample in enumerate(bexamples):
                                savefile = h5.File(savepath + "/" + bexample
                                                   + "_pred" + str(tester_n),
                                                   "w")
                                for kk, el in enumerate(toeval_names):
                                    savefile[el] = evaluated[kk][jj, :]
                                savefile.close()
                            batch = []
                            bexamples = []
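    # Example (sketch): evaluating a trained checkpoint on a test set. The
    # tensor handle nn.output_vrms, the paths and the checkpoint name are
    # hypothetical; toeval_names gives the hdf5 keys under which the
    # predictions are saved.
    #
    #     tester = Tester(NN=nn, data_generator=gen)
    #     tester.test_dataset(savepath="./dataset_test",
    #                         toeval=[nn.output_vrms],
    #                         toeval_names=["vrms"],
    #                         restore_from="./logs/model.ckpt-10000")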
    def get_preds(self,
                  labelname: str,
                  predname: str,
                  savepath: str,
                  testpath: str = None,
                  filename: str = 'example_*',
                  maskname: str = None,
                  tester_n: int = 0):
        """
        This method returns the labels and the predictions for an output.

        @params:
        labelname (str) : Name of the labels in the example file
        predname (str)  : Name of the predictions in the example file
        savepath (str)  : The path in which the predictions are found
        testpath (str)  : The path in which the test examples are found.
                          Defaults to savepath.
        filename (str)  : The structure of the examples' filenames
        maskname (str)  : Name of the valid predictions mask
        tester_n (int)  : Label of the model used for prediction

        @returns:
        labels (list) : List containing the labels of all examples
        preds (list)  : List containing the predictions of all examples
        masks (list)  : List containing the masks, if maskname is provided
        lfiles (list) : List of the label file names
        pfiles (list) : List of the prediction file names
        """

        if testpath is None:
            testpath = savepath
        examples = fnmatch.filter(os.listdir(testpath), filename)
        predictions = fnmatch.filter(os.listdir(savepath), filename)
        labels = []
        preds = []
        masks = []
        lfiles = []
        pfiles = []
        for ii, example in enumerate(examples):
            example_pred = example + "_pred" + str(tester_n)
            if "pred" not in example and example_pred in predictions:
                pfiles.append(savepath + "/" + example_pred)
                pfile = h5.File(pfiles[-1], "r")
                preds.append(pfile[predname][:])
                pfile.close()
                lfiles.append(testpath + "/" + example)
                lfile = h5.File(lfiles[-1], "r")
                labels.append(lfile[labelname][:])
                if maskname is not None:
                    masks.append(lfile[maskname][:])
                lfile.close()

        return labels, preds, masks, lfiles, pfiles
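
if __name__ == "__main__":
    # Minimal sketch of post-processing get_preds output into a single RMS
    # error. "labels" and "preds" stand in for the first two values returned
    # by get_preds; dummy arrays are used so the snippet runs standalone.
    import numpy as np
    labels = [np.zeros(100), np.zeros(100)]
    preds = [np.full(100, 0.1), np.full(100, -0.1)]
    rmse = np.sqrt(np.mean([np.mean((l - p) ** 2)
                            for l, p in zip(labels, preds)]))
    print("RMSE over examples: %f" % rmse)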
--------------------------------------------------------------------------------
/vrmslearn/Trainer.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This class trains the neural network.
"""

from vrmslearn.RCNN import RCNN
from vrmslearn.Inputqueue import BatchManager
from vrmslearn.SeismicGenerator import SeismicGenerator
import tensorflow as tf
import time


class Trainer(object):
    """
    This class takes a NN model defined in tensorflow and performs the
    training.
    """

    def __init__(self,
                 NN: RCNN,
                 data_generator: SeismicGenerator,
                 checkpoint_dir: str = "./logs",
                 replace_examples=False,
                 learning_rate: float = 0.001,
                 beta1: float = 0.9,
                 beta2: float = 0.999,
                 epsilon: float = 1e-8,
                 var_to_minimize: list = None,
                 totrain=True):
        """
        Initialize the Trainer and, if totrain is True, create the optimizer.

        @params:
        NN (RCNN) : A tensorflow neural net
        data_generator (SeismicGenerator): A data generator object
        checkpoint_dir (str) : Directory in which to save the checkpoints
        learning_rate (float): Learning rate of the Adam optimizer
        beta1, beta2, epsilon (float): Adam optimizer hyperparameters
        var_to_minimize (list): If provided, train only these variables
        totrain (bool) : If True, create the optimization node

        @returns:
        """
        self.NN = NN
        self.data_generator = data_generator
        self.checkpoint_dir = checkpoint_dir
        with self.NN.graph.as_default():
            self.global_step = tf.train.get_or_create_global_step()

            # Output the graph for Tensorboard
            writer = tf.summary.FileWriter(self.checkpoint_dir,
                                           graph=tf.get_default_graph())
            writer.close()
            if totrain:
                self.tomin = self.define_optimizer(learning_rate,
                                                   beta1,
                                                   beta2,
                                                   epsilon,
                                                   var_to_minimize)

    def define_optimizer(self,
                         learning_rate: float = 0.001,
                         beta1: float = 0.9,
                         beta2: float = 0.999,
                         epsilon: float = 1e-8,
                         var_to_minimize: list = None):
        """
        This method creates an optimization node.

        @params:
        learning_rate (float): Learning rate of the Adam optimizer
        beta1, beta2, epsilon (float): Adam optimizer hyperparameters
        var_to_minimize (list): If provided, minimize only these variables

        @returns:
        tomin (tf.tensor) : Output of the optimizer node
        """
        with tf.name_scope("Optimizer"):
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                         beta1=beta1,
                                         beta2=beta2,
                                         epsilon=epsilon,
                                         name="Adam")

            # Add node to minimize loss
            if var_to_minimize:
                tomin = opt.minimize(self.NN.loss,
                                     global_step=self.global_step,
                                     var_list=var_to_minimize)
            else:
                tomin = opt.minimize(self.NN.loss,
                                     global_step=self.global_step)

        return tomin

    def train_model(self,
                    niter: int = 10,
                    savepath: str = None,
                    restore_from: str = None,
                    thread_read: int = 1):
        """
        This method trains the model. The training is restarted automatically
        if any checkpoints are found in self.checkpoint_dir.

        @params:
        niter (int)    : Total number of training iterations to run
        savepath (str) : Directory or list of directories of examples
        restore_from (str) : Checkpoint file from which to initialize
                             the parameters
        thread_read (int)  : Number of threads reading examples in parallel

        @returns:
        """

        # Print the training settings being used
        print("number of iterations (niter) = " + str(niter))

        # Do the learning
        generator_fun = [lambda: self.data_generator.read_example(
            savedir=savepath)] * thread_read
        with BatchManager(batch_size=self.NN.batch_size,
                          generator_fun=generator_fun) as batch_queue:

            with self.NN.graph.as_default():
                summary_op = tf.summary.merge_all()

                # The StopAtStepHook handles stopping after running given
                # steps.
                hooks = [tf.train.StopAtStepHook(last_step=niter),
                         tf.train.SummarySaverHook(save_steps=10,
                                                   summary_op=summary_op),
                         tf.train.CheckpointSaverHook(
                             checkpoint_dir=self.checkpoint_dir,
                             save_steps=100,
                             saver=tf.train.Saver(max_to_keep=None))]

                # If restoring, read the trained variables once and build
                # assign ops to initialize the monitored session with them
                if restore_from is not None:
                    saver = tf.train.Saver(tf.trainable_variables())
                    with tf.Session() as sess:
                        saver.restore(sess, restore_from)
                        vars = tf.trainable_variables()
                        trained_variables = sess.run(vars)

                    assigns = [tf.assign(v, trained_variables[ii])
                               for ii, v in enumerate(vars)]

                # Run the training iterations
                with tf.train.MonitoredTrainingSession(
                        checkpoint_dir=self.checkpoint_dir,
                        save_checkpoint_secs=None,
                        save_summaries_steps=1,
                        hooks=hooks) as sess:

                    if restore_from is not None:
                        batch = batch_queue.next_batch()
                        batch = self.data_generator.aggregate_examples(batch)
                        feed_dict = dict(zip(self.NN.feed_dict, batch))
                        step = sess.run(self.global_step, feed_dict=feed_dict)
                        if step == 0:
                            sess.run(assigns, feed_dict=feed_dict)

                    while not sess.should_stop():
                        t0 = time.time()
                        batch = batch_queue.next_batch()
                        batch = self.data_generator.aggregate_examples(batch)
                        t1 = time.time()
                        feed_dict = dict(zip(self.NN.feed_dict, batch))

                        step, loss, _ = sess.run([self.global_step,
                                                  self.NN.loss,
                                                  self.tomin],
                                                 feed_dict=feed_dict)
                        t2 = time.time()
                        print("Iteration %d, loss: %f, t_batch: %f, "
                              "t_graph: %f, nqueue: %d"
                              % (step, loss, t1 - t0, t2 - t1,
                                 batch_queue.n_in_queue.value))
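    # Example (sketch): a typical training call. Paths and values are
    # hypothetical; training resumes automatically from any checkpoint found
    # in checkpoint_dir and stops once the global step reaches niter.
    #
    #     trainer = Trainer(NN=nn,
    #                       data_generator=gen,
    #                       checkpoint_dir="./logs0",
    #                       learning_rate=0.0008)
    #     trainer.train_model(niter=10000,
    #                         savepath="./dataset_demo",
    #                         thread_read=3)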
    def evaluate(self, toeval, niter, dir=None, batch=None):
        """
        This method computes the outputs contained in toeval for a NN
        restored from a checkpoint.

        @params:
        toeval (list) : List of tensors to evaluate
        niter (int)   : Training iteration of the checkpoint to restore
        dir (str)     : Directory of the checkpoint. Defaults to
                        self.checkpoint_dir.
        batch (list)  : A batch as created by aggregate_examples, to
                        predict from

        @returns:
        evaluated (list) : The evaluated tensors of toeval
        """

        if dir is None:
            dir = self.checkpoint_dir

        feed_dict = dict(zip(self.NN.feed_dict, batch))

        with self.NN.graph.as_default():
            saver = tf.train.Saver()
            with tf.Session() as sess:
                saver.restore(sess, dir + '/model.ckpt-' + str(niter))
                evaluated = sess.run(toeval,
                                     feed_dict=feed_dict)

        return evaluated
--------------------------------------------------------------------------------