├── README.md
└── MCCNN.py

/README.md:
--------------------------------------------------------------------------------
# AIgravity: Recovering Gravity from Satellite Altimetry Data Using Deep Learning

Welcome to the AIgravity project! This project uses deep learning to recover the ocean gravity field from satellite altimetry data.

## Paper

Our paper, 'Recovering Gravity from Satellite Altimetry Data using Deep Learning Network', has been accepted by IEEE Transactions on Geoscience and Remote Sensing (TGRS). The paper details the methods and results of this project.

## Code

The code for this project is available in this repository as `MCCNN.py`, with comments explaining how it works. The input data files are not included and must be prepared separately (see Usage below).

## Usage

To run this code, you will need Python with TensorFlow/Keras, NumPy, and netCDF4 installed. The code runs on any system that supports Python.
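Before running `MCCNN.py`, the following input files are expected in the working directory (file and variable names are taken from the code; the grids must cover the study area at the resolution used in the paper):

```
top.nc      # marine topography grid (netCDF variable "elevation")
e.nc        # residual DOV grid, east component (netCDF variable "z")
n.nc        # residual DOV grid, north component (netCDF variable "z")
train0.dat  # ship-borne points; residual gravity anomaly in the fourth column
```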
## Citation

If you find this code useful in your research, please consider citing our paper:

```
@article{zhu_recovering_2023,
  title   = {Recovering Gravity from Satellite Altimetry Data using Deep Learning Network},
  author  = {Zhu, Chengcheng and Yang, Lei and Bian, Hongwei and Li, Houpu and Guo, Jinyun and Liu, Na and Lin, Lina},
  journal = {IEEE Transactions on Geoscience and Remote Sensing},
  year    = {2023},
  issn    = {1558-0644},
  doi     = {10.1109/TGRS.2023.3280261},
}
```

## Link

For more details, please refer to our paper on [IEEE Xplore](https://ieeexplore.ieee.org/document/10136743).

## Contact

If you have any questions or suggestions about this project, please feel free to open an issue or submit a pull request.

/MCCNN.py:
--------------------------------------------------------------------------------
import numpy as np
import math
import gc
from keras.layers import Dense, Flatten, Concatenate
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
import netCDF4 as nc
from keras import regularizers, Model, Input


def r2_score(y_true, y_pred):
    '''
    R^2 (coefficient of determination) regression score.

    The best possible score is 1.0, and it can be negative (the model can
    be arbitrarily worse). A constant model that always predicts the mean
    of y_true, disregarding the input features, gets an R^2 score of 0.0.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Ground truth (correct) target values.
    y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Estimated target values.
    '''
    numerator = ((y_true - y_pred) ** 2).sum()
    denominator = ((y_true - np.average(y_true)) ** 2).sum()
    return 1.0 - numerator / denominator


def define_model(xx_train):
    '''
    Build the MCCNN: seven identical convolutional branches, one per input
    channel, whose flattened features are concatenated and passed through a
    dense head that regresses the residual gravity anomaly.
    '''
    inputs = []
    branches = []
    for _ in range(7):
        In_i = Input(shape=(xx_train.shape[1], xx_train.shape[2], 1))
        x = Conv2D(filters=32, kernel_size=4, strides=1, activation='tanh')(In_i)
        x = MaxPooling2D(pool_size=2)(x)
        x = Conv2D(filters=8, kernel_size=4, strides=1, activation='tanh')(x)
        x = MaxPooling2D(pool_size=2)(x)
        x = Flatten()(x)
        inputs.append(In_i)
        branches.append(x)
    # merge the seven branches and regress a single value
    merged = Concatenate()(branches)
    dense1 = Dense(256, activation='tanh', use_bias=True,
                   kernel_regularizer=regularizers.l1(0.01))(merged)
    output = Dense(1, use_bias=True,
                   kernel_regularizer=regularizers.l1(0.01))(dense1)
    model = Model(inputs=inputs, outputs=output)
    # compile with MSE loss and the Adam optimizer
    model.compile(loss='mse', optimizer=Adam(learning_rate=0.001), metrics=['mse'])
    return model
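# Quick architecture check (added for illustration; not part of the original
# pipeline): instantiating the model on a dummy 64x64x7 array and printing its
# summary makes the seven parallel branches and the merged dense head easy to
# inspect before loading the real data.
_demo = define_model(np.zeros((1, 64, 64, 7)))
_demo.summary()
del _demo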
# import grids of marine topography and residual DOVs
Data = nc.Dataset('top.nc')   # marine topography
h_gw = (Data.variables["elevation"][:].data).T
Data2 = nc.Dataset('e.nc')    # residual DOV, east component (altimeter-derived minus reference)
Data3 = nc.Dataset('n.nc')    # residual DOV, north component (altimeter-derived minus reference)
e_gw = (Data2.variables["z"][:].data).T
n_gw = (Data3.variables["z"][:].data).T

# import ship-borne data for training; the residual gravity anomaly
# (ship-borne minus reference) is read from column index 3
ifile = 'train0.dat'
Z1 = np.loadtxt(ifile)
y_train = Z1[:, 3]  # residual gravity anomaly for training

# build a 64x64x7 input patch centred on each ship-borne point
# (grid origin 124.002E, 9.0021N; 15-arcsecond cells, i.e. 1/240 degree)
x_train = np.zeros((len(Z1), 64, 64, 7))
for i in range(len(Z1)):
    lon_num_min = math.floor((Z1[i, 0] - 124.002) * 60.0 * 4.0) - 31
    lon_num_max = lon_num_min + 64
    lat_num_min = math.floor((Z1[i, 1] - 9.0021) * 60.0 * 4.0) - 31
    lat_num_max = lat_num_min + 64
    for k in range(64):
        for l in range(64):
            # channels 0-1: longitude/latitude of each grid node
            x_train[i, k, l, 0] = lon_num_min / 240.0 + 124.002 + k / 240.0
            x_train[i, k, l, 1] = lat_num_min / 240.0 + 9.0021 + l / 240.0
            # channels 2-3: offsets of each node from the ship-borne point
            x_train[i, k, l, 2] = x_train[i, k, l, 0] - Z1[i, 0]
            x_train[i, k, l, 3] = x_train[i, k, l, 1] - Z1[i, 1]
    # channels 4-6: north DOV, east DOV and topography patches
    x_train[i, :, :, 6] = h_gw[lon_num_min:lon_num_max, lat_num_min:lat_num_max]
    x_train[i, :, :, 5] = e_gw[lon_num_min:lon_num_max, lat_num_min:lat_num_max]
    x_train[i, :, :, 4] = n_gw[lon_num_min:lon_num_max, lat_num_min:lat_num_max]
del h_gw, e_gw, n_gw
gc.collect()

# standardize each channel by removing its mean and scaling to unit variance
mean_train = np.zeros(7)
std_train = np.zeros(7)
for i in range(7):
    mean_train[i] = np.mean(x_train[:, :, :, i])
    std_train[i] = np.std(x_train[:, :, :, i])
    x_train[:, :, :, i] = (x_train[:, :, :, i] - mean_train[i]) / std_train[i]
    print(mean_train[i], std_train[i])

# train the model
print('training~~~')
model = define_model(x_train)
delta = 0.02
# split the patches into seven single-channel inputs, keeping the channel axis
channels_train = [x_train[:, :, :, i:i + 1] for i in range(7)]
# note: early stopping monitors the training MSE metric
model.fit(channels_train, y_train, batch_size=512, shuffle=True,
          validation_split=0.05, epochs=20,
          callbacks=[EarlyStopping(monitor='mse', min_delta=delta, patience=3)])

# R^2 score on the training data
score_train = r2_score(y_train, np.transpose(model.predict(channels_train)))
print(score_train)

# save the trained CNN model
model.save('CNN20.h5')
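# Hedged usage sketch (an addition, not in the original script): the saved
# model can be reloaded and applied to new 64x64x7 patches built and
# standardized exactly like x_train above. Here we simply verify that the
# saved file round-trips and reproduces the training-set score.
from keras.models import load_model

restored = load_model('CNN20.h5')
score_check = r2_score(y_train, np.transpose(restored.predict(channels_train)))
print('reloaded-model R^2 on training data:', score_check)
--------------------------------------------------------------------------------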