├── CS-I
│   ├── CS_I_D_DeepONet_GitHub.ipynb
│   ├── CS_I_POD_GP_GitHub.m
│   └── CS_I_VB_DeepONet_GitHub.ipynb
├── CS-II
│   ├── CS_II_D_DeepONet_GitHub.ipynb
│   ├── CS_II_POD_GP_GitHub.m
│   └── CS_II_VB_DeepONet_GitHub.ipynb
├── CS-III
│   ├── CS_III_D_DeepONet_GitHub.ipynb
│   ├── CS_III_DenseED_GitHub.ipynb
│   ├── CS_III_POD_GP_GitHub.m
│   └── CS_III_VB_DeepONet_GitHub.ipynb
├── CS-IV
│   ├── CS_IV_D_DeepONet_GitHub.ipynb
│   ├── CS_IV_DenseED_GitHub.ipynb
│   ├── CS_IV_POD_GP_GitHub.m
│   └── CS_IV_VB_DeepONet_GitHub.ipynb
├── LICENSE
└── README.md
/CS-I/CS_I_D_DeepONet_GitHub.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6dc23078-05ca-4bbe-bcb0-e21d2a079cf7", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%reset -f\n", 11 | "import h5py\n", 12 | "import time as t\n", 13 | "import numpy as np\n", 14 | "import scipy as sp\n", 15 | "import scipy.io as spi\n", 16 | "import tensorflow as tf\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import tensorflow_probability as tfp" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "0e9d027f-dcf9-429d-a2b6-9377b20237fa", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "data": { 29 | "text/plain": [ 30 | "(60000, (100,), (60000, 100), (60000, 1), (60000, 1))" 31 | ] 32 | }, 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "output_type": "execute_result" 36 | } 37 | ], 38 | "source": [ 39 | "isamp = 3000\n", 40 | "t = np.arange(0,1,0.01)\n", 41 | "pps = 20\n", 42 | "tsamp = pps*isamp\n", 43 | "\n", 44 | "data_train = spi.loadmat('AD_5000_DP_TrData.mat')\n", 45 | "\n", 46 | "u_in = data_train['u_in'][0:tsamp,:]\n", 47 | "x_t_in = data_train['x_t_in'][0:tsamp,:]\n", 48 | "s_in = data_train['s_in'][0:tsamp,:]\n", 49 | "\n", 50 | "tsamp, t.shape, u_in.shape, x_t_in.shape, s_in.shape" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "id": "cdc7b4e8-f1a8-4a64-864f-1ad936d997a2", 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "(2.814752324700168,\n", 63 | " -3.0266588699038914,\n", 64 | " 3.6565790163373184,\n", 65 | " -3.6177953494816775,\n", 66 | " 0.99,\n", 67 | " 0.0)" 68 | ] 69 | }, 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "max_u = np.max(u_in)\n", 77 | "min_u = np.min(u_in)\n", 78 | "\n", 79 | "u_in = (u_in-min_u)/(max_u-min_u)\n", 80 | "\n", 81 | "max_t = np.max(x_t_in)\n", 82 | "min_t = np.min(x_t_in)\n", 83 | "\n", 84 | "x_t_in = (x_t_in-min_t)/(max_t-min_t)\n", 85 | "\n", 86 | "max_s = np.max(s_in)\n", 87 | "min_s = np.min(s_in)\n", 88 | "\n", 89 | "s_in = (s_in-min_s)/(max_s-min_s)\n", 90 | "\n", 91 | "max_s, min_s, max_u, min_u, max_t, min_t" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "id": "cc8fdc38-c3cb-47e3-ab99-f0870272984a", 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "Model: \"model\"\n", 105 | "__________________________________________________________________________________________________\n", 106 | " Layer (type) Output Shape Param # Connected to \n", 107 | "==================================================================================================\n", 108 | " inputsB (InputLayer) [(None, 100)] 0 [] \n", 109 | " \n", 110 | " inputsT (InputLayer) [(None, 1)] 0 [] \n", 111 | " \n", 112 | " dense (Dense) (None, 30) 3030 
['inputsB[0][0]'] \n", 113 | " \n", 114 | " dense_3 (Dense) (None, 30) 60 ['inputsT[0][0]'] \n", 115 | " \n", 116 | " dense_1 (Dense) (None, 30) 930 ['dense[0][0]'] \n", 117 | " \n", 118 | " dense_4 (Dense) (None, 30) 930 ['dense_3[0][0]'] \n", 119 | " \n", 120 | " dense_2 (Dense) (None, 30) 930 ['dense_1[0][0]'] \n", 121 | " \n", 122 | " dense_5 (Dense) (None, 30) 930 ['dense_4[0][0]'] \n", 123 | " \n", 124 | " lambda (Lambda) (None, 1) 0 ['dense_2[0][0]', \n", 125 | " 'dense_5[0][0]'] \n", 126 | " \n", 127 | " dense_6 (Dense) (None, 1) 2 ['lambda[0][0]'] \n", 128 | " \n", 129 | "==================================================================================================\n", 130 | "Total params: 6,812\n", 131 | "Trainable params: 6,812\n", 132 | "Non-trainable params: 0\n", 133 | "__________________________________________________________________________________________________\n", 134 | "Iteration: 0\n", 135 | "6/6 [==============================] - 1s 80ms/step - loss: 0.3378\n", 136 | "Iteration: 250\n", 137 | "6/6 [==============================] - 0s 56ms/step - loss: 0.0017\n", 138 | "Iteration: 500\n", 139 | "6/6 [==============================] - 0s 59ms/step - loss: 3.2616e-04\n", 140 | "Iteration: 750\n", 141 | "6/6 [==============================] - 0s 18ms/step - loss: 1.0353e-04\n", 142 | "Iteration: 1000\n", 143 | "6/6 [==============================] - 0s 15ms/step - loss: 4.6628e-05\n", 144 | "Iteration: 1250\n", 145 | "6/6 [==============================] - 0s 17ms/step - loss: 1.4813e-05\n", 146 | "Iteration: 1500\n", 147 | "6/6 [==============================] - 0s 14ms/step - loss: 1.2703e-05\n", 148 | "Iteration: 1750\n", 149 | "6/6 [==============================] - 0s 17ms/step - loss: 4.3009e-06\n", 150 | "Iteration: 2000\n", 151 | "6/6 [==============================] - 0s 79ms/step - loss: 2.3669e-06\n", 152 | "Iteration: 2250\n", 153 | "6/6 [==============================] - 0s 14ms/step - loss: 4.7548e-06\n", 154 | "Iteration: 2500\n", 155 | "6/6 [==============================] - 0s 14ms/step - loss: 1.4258e-06\n", 156 | "Iteration: 2750\n", 157 | "6/6 [==============================] - 0s 17ms/step - loss: 1.2956e-06\n", 158 | "Iteration: 3000\n", 159 | "6/6 [==============================] - 0s 17ms/step - loss: 2.7571e-05\n", 160 | "Iteration: 3250\n", 161 | "6/6 [==============================] - 0s 14ms/step - loss: 9.3125e-07\n", 162 | "Iteration: 3500\n", 163 | "6/6 [==============================] - 0s 20ms/step - loss: 1.0879e-06\n", 164 | "Iteration: 3750\n", 165 | "6/6 [==============================] - 0s 15ms/step - loss: 5.2974e-06\n", 166 | "Iteration: 4000\n", 167 | "6/6 [==============================] - 0s 17ms/step - loss: 9.9447e-07\n", 168 | "Iteration: 4250\n", 169 | "6/6 [==============================] - 0s 11ms/step - loss: 9.8116e-07\n", 170 | "Iteration: 4500\n", 171 | "6/6 [==============================] - 0s 12ms/step - loss: 1.7023e-06\n", 172 | "Iteration: 4750\n", 173 | "6/6 [==============================] - 0s 11ms/step - loss: 7.9464e-07\n", 174 | "Iteration: 5000\n", 175 | "6/6 [==============================] - 0s 14ms/step - loss: 6.9840e-07\n", 176 | "Iteration: 5250\n", 177 | "6/6 [==============================] - 0s 14ms/step - loss: 3.8651e-07\n", 178 | "Iteration: 5500\n", 179 | "6/6 [==============================] - 0s 14ms/step - loss: 2.5284e-05\n", 180 | "Iteration: 5750\n", 181 | "6/6 [==============================] - 0s 12ms/step - loss: 5.0177e-07\n", 182 | "Iteration: 6000\n", 183 | "6/6 
[==============================] - 0s 12ms/step - loss: 4.7943e-07\n", 184 | "Iteration: 6250\n", 185 | "6/6 [==============================] - 0s 12ms/step - loss: 4.2561e-07\n", 186 | "Iteration: 6500\n", 187 | "6/6 [==============================] - 0s 11ms/step - loss: 3.8821e-07\n", 188 | "Iteration: 6750\n", 189 | "6/6 [==============================] - 0s 14ms/step - loss: 8.9071e-06\n", 190 | "Iteration: 7000\n", 191 | "6/6 [==============================] - 0s 9ms/step - loss: 4.8193e-07\n", 192 | "Iteration: 7250\n", 193 | "6/6 [==============================] - 0s 11ms/step - loss: 1.2941e-06\n", 194 | "Iteration: 7500\n", 195 | "6/6 [==============================] - 0s 6ms/step - loss: 1.3633e-06\n", 196 | "Iteration: 7750\n", 197 | "6/6 [==============================] - 0s 6ms/step - loss: 2.4018e-07\n", 198 | "Iteration: 8000\n", 199 | "6/6 [==============================] - 0s 9ms/step - loss: 9.2562e-07\n", 200 | "Iteration: 8250\n", 201 | "6/6 [==============================] - 0s 6ms/step - loss: 3.2380e-07\n", 202 | "Iteration: 8500\n", 203 | "6/6 [==============================] - 0s 9ms/step - loss: 2.9737e-06\n", 204 | "Iteration: 8750\n", 205 | "6/6 [==============================] - 0s 8ms/step - loss: 3.2556e-07\n", 206 | "Iteration: 9000\n", 207 | "6/6 [==============================] - 0s 9ms/step - loss: 1.1267e-06\n", 208 | "Iteration: 9250\n", 209 | "6/6 [==============================] - 0s 7ms/step - loss: 2.1554e-07\n", 210 | "Iteration: 9500\n", 211 | "6/6 [==============================] - 0s 7ms/step - loss: 5.8274e-06\n", 212 | "Iteration: 9750\n", 213 | "6/6 [==============================] - 0s 9ms/step - loss: 1.5994e-07\n", 214 | "Total iterations: 10000\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "from tensorflow.keras.models import Model\n", 220 | "from tensorflow.keras.layers import Input, Lambda, Dense\n", 221 | "\n", 222 | "bs = 10000\n", 223 | "\n", 224 | "def fn(x):\n", 225 | " y = tf.einsum(\"ij, ij->i\", x[0], x[1])\n", 226 | " y = tf.expand_dims(y, axis = 1)\n", 227 | " return y\n", 228 | "\n", 229 | "hln = 30\n", 230 | "\n", 231 | "inputsB = Input(shape = (100,), name = 'inputsB')\n", 232 | "hiddenB = Dense(hln, activation = \"relu\")(inputsB)\n", 233 | "hiddenB = Dense(hln, activation = \"relu\")(hiddenB)\n", 234 | "hiddenB = Dense(hln, activation = \"relu\")(hiddenB)\n", 235 | "\n", 236 | "inputsT = Input(shape = (1,), name = 'inputsT')\n", 237 | "hiddenT = Dense(hln, activation = \"relu\")(inputsT)\n", 238 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 239 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 240 | "\n", 241 | "combined = Lambda(fn, output_shape = [None, 1])([hiddenB, hiddenT])\n", 242 | "output = Dense(1)(combined)\n", 243 | "\n", 244 | "model = Model(inputs = [inputsB, inputsT], outputs = output) \n", 245 | "\n", 246 | "model.compile(optimizer = tf.optimizers.Adam(learning_rate = 0.001), loss = 'mse')\n", 247 | "model.summary()\n", 248 | "\n", 249 | "model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n", 250 | " filepath = './ChkPts/CI_VDN_s1_10000/',\n", 251 | " save_weights_only=True,\n", 252 | " monitor='loss',\n", 253 | " mode='min',\n", 254 | " save_best_only=True)\n", 255 | "\n", 256 | "itr = 0\n", 257 | "for i in range(0, 40):\n", 258 | " print('Iteration: '+str(itr))\n", 259 | " itr = itr+1\n", 260 | " model.fit({\"inputsB\":u_in, \"inputsT\":x_t_in}, s_in, epochs = 1,\n", 261 | " verbose = 1, batch_size = bs, callbacks = [model_checkpoint_callback]) \n", 262 | "\n", 
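"\n", "    # Each pass of this loop runs one verbose epoch (printed above) and then\n", "    # 249 silent ones below, so 40 passes give the 10000 total iterations.\n", "    # With batch_size = 10000 on 60000 training rows, an epoch is the 6 steps\n", "    # ('6/6') seen in the log, and ModelCheckpoint keeps only the weights with\n", "    # the lowest training loss observed so far.\n",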
263 | " itr_ps = 250-1\n", 264 | " model.fit({\"inputsB\":u_in, \"inputsT\":x_t_in}, s_in, epochs = itr_ps,\n", 265 | " verbose = 0, batch_size = bs, callbacks = [model_checkpoint_callback])\n", 266 | " itr = itr+itr_ps\n", 267 | "\n", 268 | "print('Total iterations: '+str(itr))\n", 269 | "\n", 270 | "model.load_weights('./ChkPts/CI_VDN_s1_10000/')\n", 271 | "\n", 272 | "model.save_weights('./model/CI_VDN_s1_10000')\n", 273 | "\n", 274 | "# model.load_weights('./model/Dense_model_TF_weights_CI_VDN')" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 5, 280 | "id": "9cebf033-567b-4508-a037-5951bf7adb06", 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "\n", 288 | "0.00011219960841149184\n", 289 | "0.28067175942433253\n", 290 | "0.00039975382148035524\n" 291 | ] 292 | } 293 | ], 294 | "source": [ 295 | "# data_test = spi.loadmat('test_ODE_AD.mat')\n", 296 | "\n", 297 | "# u_in_test = data_test['X_test0']\n", 298 | "# x_t_in_test = data_test['X_test1']\n", 299 | "# s_in_test = data_test['y_test']\n", 300 | "\n", 301 | "# u_in_test = (u_in_test-min_u)/(max_u-min_u)\n", 302 | "# x_t_in_test = (x_t_in_test-min_t)/(max_t-min_t)\n", 303 | "# s_in_test = (s_in_test-min_s)/(max_s-min_s)\n", 304 | "\n", 305 | "# pred = model({\"inputsB\":u_in_test, \"inputsT\":x_t_in_test})\n", 306 | "\n", 307 | "# print()\n", 308 | "# print(np.mean((s_in_test-pred)**2))\n", 309 | "# print(np.mean((s_in_test)**2))\n", 310 | "# print(np.mean((s_in_test-pred)**2)/np.mean((s_in_test)**2)) \n", 311 | "\n", 312 | "t = np.arange(0,1,0.01)\n", 313 | "testdata = spi.loadmat('AD_TestData.mat')\n", 314 | "\n", 315 | "u_in_test = testdata['u_in_test']\n", 316 | "x_t_in_test = testdata['x_t_in_test']\n", 317 | "s_in_test = testdata['s_in_test']\n", 318 | "\n", 319 | "u_in_test.shape, x_t_in_test.shape, s_in_test.shape\n", 320 | "\n", 321 | "u_in_test = (u_in_test-min_u)/(max_u-min_u)\n", 322 | "x_t_in_test = (x_t_in_test-min_t)/(max_t-min_t)\n", 323 | "s_in_test = (s_in_test-min_s)/(max_s-min_s)\n", 324 | "\n", 325 | "pred = model({\"inputsB\":u_in_test, \"inputsT\":x_t_in_test}).numpy()\n", 326 | "\n", 327 | "pred = (pred*(max_s-min_s))+min_s \n", 328 | "s_in_test = (s_in_test*(max_s-min_s))+min_s\n", 329 | "\n", 330 | "print()\n", 331 | "print(np.mean((s_in_test-pred)**2))\n", 332 | "print(np.mean((s_in_test)**2))\n", 333 | "print(np.mean((s_in_test-pred)**2)/np.mean((s_in_test)**2)) " 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 6, 339 | "id": "df7255ed-aa43-4a72-a18f-04c557e4623e", 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "# s_in_test.shape, pred.shape" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 7, 349 | "id": "5df25778-f186-49eb-bc37-fc246b56d806", 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "# string = './CI_VDN_PDFs.mat'\n", 354 | "# data_test = spi.savemat(string,{'pred':pred, 's_in_test':s_in_test})" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "id": "39d20e8f-c249-48ca-a252-809549ecb5d0", 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [] 364 | } 365 | ], 366 | "metadata": { 367 | "kernelspec": { 368 | "display_name": "Python 3 (ipykernel)", 369 | "language": "python", 370 | "name": "python3" 371 | }, 372 | "language_info": { 373 | "codemirror_mode": { 374 | "name": "ipython", 375 | "version": 3 376 | }, 377 | "file_extension": ".py", 378 | 
"mimetype": "text/x-python", 379 | "name": "python", 380 | "nbconvert_exporter": "python", 381 | "pygments_lexer": "ipython3", 382 | "version": "3.10.2" 383 | } 384 | }, 385 | "nbformat": 4, 386 | "nbformat_minor": 5 387 | } 388 | -------------------------------------------------------------------------------- /CS-I/CS_I_POD_GP_GitHub.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all 4 | 5 | tic 6 | 7 | %% %% 8 | 9 | IFR = 0.95; 10 | load AD_5000_DP_TrData.mat 11 | 12 | n = 60000; 13 | y = u_in(1:n,:); 14 | 15 | [~, sd, vd] = svd(y); 16 | 17 | sd = sd.^2; 18 | nd = 1; 19 | chkd = sum(diag(sd)); 20 | rd = 0; 21 | while rd < IFR 22 | rd = sum(diag(sd(1:nd,1:nd)))/chkd; 23 | nd = nd+1; 24 | end 25 | nd = nd-1; 26 | fprintf('\n\n%d\n\n',nd); 27 | 28 | red = y*vd(:,1:nd); 29 | 30 | toc 31 | 32 | %% 33 | 34 | in = [red, x_t_in(1:n, :)]; 35 | mdl = fitrgp(in, s_in(1:n,:)); 36 | 37 | toc 38 | 39 | save("GPmdl95P_S1.mat",'sd','vd','n','IFR','nd','mdl') 40 | 41 | %% PREDICTION 42 | 43 | IFR 44 | 45 | load AD_TestData.mat 46 | 47 | S_mse = zeros(100,1); 48 | S_nmse = zeros(100,1); 49 | 50 | for i = 1:100 51 | 52 | i 53 | 54 | n = 10000; 55 | y = u_in_test((i-1)*n+1:i*n, :); 56 | 57 | pfr = y*vd(:,1:nd); 58 | in = [pfr, x_t_in_test((i-1)*n+1:i*n, :)]; 59 | 60 | pred = zeros(10,10000,1); 61 | for j = 1:10 62 | pred(j,:,:) = predict(mdl, in); 63 | end 64 | 65 | mpred = squeeze(mean(pred, 1)); 66 | spred = squeeze(std(pred, 1)); 67 | 68 | mse = mean(mean((mpred'-s_in_test((i-1)*n+1:i*n,1)).^2)); 69 | nmse = mean(mean((mpred'-s_in_test((i-1)*n+1:i*n,1)).^2))./mean(mean(s_in_test((i-1)*n+1:i*n,1).^2)); 70 | 71 | S_mse(i) = mse; 72 | S_nmse(i) = nmse; 73 | 74 | end 75 | 76 | MSE = mean(S_mse) 77 | NMSE = mean(S_nmse) 78 | 79 | % MSE = 80 | % 81 | % 0.0021 82 | % 83 | % 84 | % NMSE = 85 | % 86 | % 0.0076 87 | -------------------------------------------------------------------------------- /CS-II/CS_II_D_DeepONet_GitHub.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6dc23078-05ca-4bbe-bcb0-e21d2a079cf7", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%reset -f\n", 11 | "import h5py\n", 12 | "import time as t\n", 13 | "import numpy as np\n", 14 | "import scipy as sp\n", 15 | "import scipy.io as spi\n", 16 | "import tensorflow as tf\n", 17 | "import matplotlib.pyplot as plt" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "id": "c77b56ad-938b-4e14-bf5b-b13dcf195e17", 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "data": { 28 | "text/plain": [ 29 | "((70000, 100), (70000, 1), (70000, 1))" 30 | ] 31 | }, 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "output_type": "execute_result" 35 | } 36 | ], 37 | "source": [ 38 | "isamp = 3500\n", 39 | "t = np.arange(0,1,0.01)\n", 40 | "\n", 41 | "data_train = spi.loadmat('GP_train_data.mat')\n", 42 | "\n", 43 | "u_in = data_train['u_in']\n", 44 | "x_t_in = data_train['x_t_in']\n", 45 | "s_in = data_train['s_in']\n", 46 | "\n", 47 | "u_in.shape, x_t_in.shape, s_in.shape" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "id": "edf85e77-6ee1-403c-9b4f-1ba0b7b50f8a", 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/plain": [ 59 | "(1.2198099329889192,\n", 60 | " -1.3115838274707885,\n", 61 | " 4.208713652908879,\n", 62 | " -3.541514625432276,\n", 63 | " 0.99,\n", 64 | " 0.0)" 
65 | ] 66 | }, 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "max_u = np.max(u_in)\n", 74 | "min_u = np.min(u_in)\n", 75 | "\n", 76 | "u_in = (u_in-min_u)/(max_u-min_u)\n", 77 | "\n", 78 | "max_t = np.max(x_t_in)\n", 79 | "min_t = np.min(x_t_in)\n", 80 | "\n", 81 | "x_t_in = (x_t_in-min_t)/(max_t-min_t)\n", 82 | "\n", 83 | "max_s = np.max(s_in)\n", 84 | "min_s = np.min(s_in)\n", 85 | "\n", 86 | "s_in = (s_in-min_s)/(max_s-min_s)\n", 87 | "\n", 88 | "max_s, min_s, max_u, min_u, max_t, min_t" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "id": "cc8fdc38-c3cb-47e3-ab99-f0870272984a", 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "Model: \"model\"\n", 102 | "__________________________________________________________________________________________________\n", 103 | " Layer (type) Output Shape Param # Connected to \n", 104 | "==================================================================================================\n", 105 | " inputsB (InputLayer) [(None, 100)] 0 [] \n", 106 | " \n", 107 | " inputsT (InputLayer) [(None, 1)] 0 [] \n", 108 | " \n", 109 | " dense (Dense) (None, 25) 2525 ['inputsB[0][0]'] \n", 110 | " \n", 111 | " dense_4 (Dense) (None, 25) 50 ['inputsT[0][0]'] \n", 112 | " \n", 113 | " dense_1 (Dense) (None, 25) 650 ['dense[0][0]'] \n", 114 | " \n", 115 | " dense_5 (Dense) (None, 25) 650 ['dense_4[0][0]'] \n", 116 | " \n", 117 | " dense_2 (Dense) (None, 25) 650 ['dense_1[0][0]'] \n", 118 | " \n", 119 | " dense_6 (Dense) (None, 25) 650 ['dense_5[0][0]'] \n", 120 | " \n", 121 | " dense_3 (Dense) (None, 25) 650 ['dense_2[0][0]'] \n", 122 | " \n", 123 | " dense_7 (Dense) (None, 25) 650 ['dense_6[0][0]'] \n", 124 | " \n", 125 | " lambda (Lambda) (None, 1) 0 ['dense_3[0][0]', \n", 126 | " 'dense_7[0][0]'] \n", 127 | " \n", 128 | " dense_8 (Dense) (None, 1) 2 ['lambda[0][0]'] \n", 129 | " \n", 130 | "==================================================================================================\n", 131 | "Total params: 6,477\n", 132 | "Trainable params: 6,477\n", 133 | "Non-trainable params: 0\n", 134 | "__________________________________________________________________________________________________\n", 135 | "Iteration: 0\n", 136 | "1/1 [==============================] - 1s 562ms/step - loss: 0.2437\n", 137 | "Iteration: 250\n", 138 | "1/1 [==============================] - 0s 282ms/step - loss: 0.0016\n", 139 | "Iteration: 500\n", 140 | "1/1 [==============================] - 0s 347ms/step - loss: 1.6776e-04\n", 141 | "Iteration: 750\n", 142 | "1/1 [==============================] - 0s 322ms/step - loss: 1.4479e-04\n", 143 | "Iteration: 1000\n", 144 | "1/1 [==============================] - 0s 301ms/step - loss: 1.2962e-04\n", 145 | "Iteration: 1250\n", 146 | "1/1 [==============================] - 0s 254ms/step - loss: 6.0995e-05\n", 147 | "Iteration: 1500\n", 148 | "1/1 [==============================] - 0s 284ms/step - loss: 3.6435e-05\n", 149 | "Iteration: 1750\n", 150 | "1/1 [==============================] - 0s 31ms/step - loss: 2.1164e-05\n", 151 | "Iteration: 2000\n", 152 | "1/1 [==============================] - 0s 269ms/step - loss: 1.6539e-05\n", 153 | "Iteration: 2250\n", 154 | "1/1 [==============================] - 0s 401ms/step - loss: 1.5448e-05\n", 155 | "Iteration: 2500\n", 156 | "1/1 [==============================] - 0s 108ms/step - loss: 1.4724e-05\n", 157 | "Iteration: 
2750\n", 158 | "1/1 [==============================] - 0s 109ms/step - loss: 4.2001e-05\n", 159 | "Iteration: 3000\n", 160 | "1/1 [==============================] - 0s 121ms/step - loss: 1.3251e-05\n", 161 | "Iteration: 3250\n", 162 | "1/1 [==============================] - 0s 142ms/step - loss: 1.4724e-05\n", 163 | "Iteration: 3500\n", 164 | "1/1 [==============================] - 0s 121ms/step - loss: 1.2041e-05\n", 165 | "Iteration: 3750\n", 166 | "1/1 [==============================] - 0s 114ms/step - loss: 1.1542e-05\n", 167 | "Iteration: 4000\n", 168 | "1/1 [==============================] - 0s 109ms/step - loss: 1.2047e-05\n", 169 | "Iteration: 4250\n", 170 | "1/1 [==============================] - 0s 127ms/step - loss: 1.0731e-05\n", 171 | "Iteration: 4500\n", 172 | "1/1 [==============================] - 0s 121ms/step - loss: 1.1124e-05\n", 173 | "Iteration: 4750\n", 174 | "1/1 [==============================] - 0s 144ms/step - loss: 1.1224e-05\n", 175 | "Iteration: 5000\n", 176 | "1/1 [==============================] - 0s 80ms/step - loss: 1.0449e-05\n", 177 | "Iteration: 5250\n", 178 | "1/1 [==============================] - 0s 101ms/step - loss: 9.8735e-06\n", 179 | "Iteration: 5500\n", 180 | "1/1 [==============================] - 0s 128ms/step - loss: 9.4909e-06\n", 181 | "Iteration: 5750\n", 182 | "1/1 [==============================] - 0s 129ms/step - loss: 9.6598e-06\n", 183 | "Iteration: 6000\n", 184 | "1/1 [==============================] - 0s 368ms/step - loss: 8.6301e-06\n", 185 | "Iteration: 6250\n", 186 | "1/1 [==============================] - 0s 110ms/step - loss: 9.4722e-06\n", 187 | "Iteration: 6500\n", 188 | "1/1 [==============================] - 0s 123ms/step - loss: 2.2897e-05\n", 189 | "Iteration: 6750\n", 190 | "1/1 [==============================] - 0s 121ms/step - loss: 8.2797e-06\n", 191 | "Iteration: 7000\n", 192 | "1/1 [==============================] - 0s 91ms/step - loss: 4.0560e-05\n", 193 | "Iteration: 7250\n", 194 | "1/1 [==============================] - 0s 85ms/step - loss: 7.3007e-06\n", 195 | "Iteration: 7500\n", 196 | "1/1 [==============================] - 0s 93ms/step - loss: 1.6869e-05\n", 197 | "Iteration: 7750\n", 198 | "1/1 [==============================] - 0s 63ms/step - loss: 6.4992e-06\n", 199 | "Iteration: 8000\n", 200 | "1/1 [==============================] - 0s 85ms/step - loss: 9.0930e-06\n", 201 | "Iteration: 8250\n", 202 | "1/1 [==============================] - 0s 38ms/step - loss: 4.9472e-06\n", 203 | "Iteration: 8500\n", 204 | "1/1 [==============================] - 0s 38ms/step - loss: 9.1173e-06\n", 205 | "Iteration: 8750\n", 206 | "1/1 [==============================] - 0s 311ms/step - loss: 3.2722e-06\n", 207 | "Iteration: 9000\n", 208 | "1/1 [==============================] - 0s 132ms/step - loss: 2.7618e-06\n", 209 | "Iteration: 9250\n", 210 | "1/1 [==============================] - 0s 91ms/step - loss: 2.5110e-06\n", 211 | "Iteration: 9500\n", 212 | "1/1 [==============================] - 0s 121ms/step - loss: 6.5903e-06\n", 213 | "Iteration: 9750\n", 214 | "1/1 [==============================] - 0s 85ms/step - loss: 1.3208e-05\n", 215 | "Total iterations: 10000\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "from tensorflow.keras.models import Model\n", 221 | "from tensorflow.keras.layers import Input, Lambda, Dense\n", 222 | "\n", 223 | "bs = 70000\n", 224 | "\n", 225 | "def fn(x):\n", 226 | " y = tf.einsum(\"ij, ij->i\", x[0], x[1])\n", 227 | " y = tf.expand_dims(y, axis = 1)\n", 228 | " return y\n", 229 | 
"\n", 230 | "hln = 25\n", 231 | "\n", 232 | "inputsB = Input(shape = (100,), name = 'inputsB')\n", 233 | "hiddenB = Dense(hln, activation = \"relu\")(inputsB)\n", 234 | "hiddenB = Dense(hln, activation = \"relu\")(hiddenB)\n", 235 | "hiddenB = Dense(hln, activation = \"relu\")(hiddenB)\n", 236 | "hiddenB = Dense(hln, activation = \"relu\")(hiddenB)\n", 237 | "\n", 238 | "inputsT = Input(shape = (1,), name = 'inputsT')\n", 239 | "hiddenT = Dense(hln, activation = \"relu\")(inputsT)\n", 240 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 241 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 242 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 243 | "\n", 244 | "combined = Lambda(fn, output_shape = [None, 1])([hiddenB, hiddenT])\n", 245 | "output = Dense(1)(combined)\n", 246 | "\n", 247 | "model = Model(inputs = [inputsB, inputsT], outputs = output) \n", 248 | "\n", 249 | "model.compile(optimizer = tf.optimizers.Adam(learning_rate = 0.001), loss = 'mse')\n", 250 | "model.summary()\n", 251 | "\n", 252 | "model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n", 253 | " filepath = './ChkPts/S1_CII_VDN_s1_10000/',\n", 254 | " save_weights_only=True,\n", 255 | " monitor='loss',\n", 256 | " mode='min',\n", 257 | " save_best_only=True)\n", 258 | "\n", 259 | "itr = 0\n", 260 | "for i in range(0, 40):\n", 261 | " print('Iteration: '+str(itr))\n", 262 | " itr = itr+1\n", 263 | " model.fit({\"inputsB\":u_in, \"inputsT\":x_t_in}, s_in, epochs = 1,\n", 264 | " verbose = 1, batch_size = bs, callbacks = [model_checkpoint_callback]) \n", 265 | "\n", 266 | " itr_ps = 250-1\n", 267 | " model.fit({\"inputsB\":u_in, \"inputsT\":x_t_in}, s_in, epochs = itr_ps,\n", 268 | " verbose = 0, batch_size = bs, callbacks = [model_checkpoint_callback])\n", 269 | " itr = itr+itr_ps\n", 270 | "\n", 271 | "print('Total iterations: '+str(itr))\n", 272 | "\n", 273 | "model.load_weights('./ChkPts/S1_CII_VDN_s1_10000/')\n", 274 | "\n", 275 | "model.save_weights('./model/S1_CII_VDN_s1_10000')\n", 276 | "\n", 277 | "# model.load_weights('./model/DenseFlipout_model_TFP_weights_30HL_S1_CII_VDN')" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 5, 283 | "id": "2f794c77-e476-4a04-9e36-17ce06d584a5", 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "(1000000, 100) (1000000, 1) (1000000, 1)\n", 291 | "\n", 292 | "1.4497887e-05\n", 293 | "0.0380138393414362\n", 294 | "0.00038138443647824094\n" 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "data_train = spi.loadmat('GP_matlab.mat')\n", 300 | "\n", 301 | "t = np.arange(0,1,0.01)\n", 302 | "u = data_train['u'][10000:20000,:]\n", 303 | "s = data_train['s1'][:,10000:20000].T\n", 304 | "\n", 305 | "pointer = np.random.randint(0,100,10000)\n", 306 | "\n", 307 | "u_in_test = np.zeros([1000000,100])\n", 308 | "x_t_in_test = np.zeros([1000000,1])\n", 309 | "s_in_test = np.zeros([1000000,1])\n", 310 | "for i in range(0,10000):\n", 311 | " for j in range(0,100):\n", 312 | " u_in_test[100*i+j,:] = u[i,:]\n", 313 | " x_t_in_test[100*i+j,:] = t[j]\n", 314 | " s_in_test[100*i+j,:] = s[i,j]\n", 315 | " \n", 316 | "print(u_in_test.shape, x_t_in_test.shape, s_in_test.shape)\n", 317 | "\n", 318 | "u_in_test = (u_in_test-min_u)/(max_u-min_u)\n", 319 | "x_t_in_test = (x_t_in_test-min_t)/(max_t-min_t)\n", 320 | "s_in_test = (s_in_test-min_s)/(max_s-min_s)\n", 321 | "\n", 322 | "pred = model({\"inputsB\":u_in_test, \"inputsT\":x_t_in_test})\n", 323 | "\n", 324 | 
"pred = (pred*(max_s-min_s))+min_s \n", 325 | "s_in_test = (s_in_test*(max_s-min_s))+min_s\n", 326 | "\n", 327 | "print()\n", 328 | "print(np.mean((s_in_test-pred)**2))\n", 329 | "print(np.mean((s_in_test)**2))\n", 330 | "print(np.mean((s_in_test-pred)**2)/np.mean((s_in_test)**2)) " 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "id": "46488a56-700f-4c3b-8bfb-1454815e7750", 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [] 340 | } 341 | ], 342 | "metadata": { 343 | "kernelspec": { 344 | "display_name": "Python 3 (ipykernel)", 345 | "language": "python", 346 | "name": "python3" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | "file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.10.2" 359 | }, 360 | "widgets": { 361 | "application/vnd.jupyter.widget-state+json": { 362 | "state": {}, 363 | "version_major": 2, 364 | "version_minor": 0 365 | } 366 | } 367 | }, 368 | "nbformat": 4, 369 | "nbformat_minor": 5 370 | } 371 | -------------------------------------------------------------------------------- /CS-II/CS_II_POD_GP_GitHub.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all 4 | 5 | tic 6 | 7 | %% %% 8 | 9 | IFR = 0.95; 10 | load GP_train_data.mat 11 | 12 | n = 70000; 13 | y = u_in(1:n,:); 14 | 15 | [~, sd, vd] = svd(y); 16 | 17 | sd = sd.^2; 18 | nd = 1; 19 | chkd = sum(diag(sd)); 20 | rd = 0; 21 | while rd < IFR 22 | rd = sum(diag(sd(1:nd,1:nd)))/chkd; 23 | nd = nd+1; 24 | end 25 | nd = nd-1; 26 | fprintf('\n\n%d\n\n',nd); 27 | 28 | red = y*vd(:,1:nd); 29 | 30 | toc 31 | 32 | %% 33 | 34 | in = [red, x_t_in(1:n, :)]; 35 | mdl = fitrgp(in, s_in(1:n,:)); 36 | 37 | toc 38 | 39 | save("GPmdl95P_S5.mat",'sd','vd','n','IFR','nd','mdl') 40 | 41 | %% PREDICTION 42 | 43 | load GP_test_data.mat 44 | 45 | S_mse = zeros(100,1); 46 | S_nmse = zeros(100,1); 47 | 48 | for i = 1:100 49 | 50 | i 51 | 52 | n = 10000; 53 | y = u_in_test((i-1)*n+1:i*n, :); 54 | 55 | pfr = y*vd(:,1:nd); 56 | in = [pfr, x_t_in_test((i-1)*n+1:i*n, :)]; 57 | 58 | pred = zeros(10,10000,1); 59 | for j = 1:10 60 | pred(j,:,:) = predict(mdl, in); 61 | end 62 | 63 | mpred = squeeze(mean(pred, 1)); 64 | spred = squeeze(std(pred, 1)); 65 | 66 | mse = mean(mean((mpred'-s_in_test((i-1)*n+1:i*n,1)).^2)); 67 | nmse = mean(mean((mpred'-s_in_test((i-1)*n+1:i*n,1)).^2))./mean(mean(s_in_test((i-1)*n+1:i*n,1).^2)); 68 | 69 | S_mse(i) = mse; 70 | S_nmse(i) = nmse; 71 | end 72 | 73 | MSE = mean(S_mse) 74 | NMSE = mean(S_nmse) 75 | 76 | % MSE = 77 | % 78 | % 7.8033e-04 79 | % 80 | % 81 | % NMSE = 82 | % 83 | % 0.0205 84 | -------------------------------------------------------------------------------- /CS-III/CS_III_D_DeepONet_GitHub.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6dc23078-05ca-4bbe-bcb0-e21d2a079cf7", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((50000, 100), (50000, 2), (50000, 1))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "%reset -f\n", 22 | "import h5py\n", 23 | "import time as t\n", 24 | "import numpy as np\n", 25 | "import scipy as sp\n", 26 | "import 
scipy.io as spi\n", 27 | "import tensorflow as tf\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "\n", 30 | "data_train = spi.loadmat('train_PDE_DR.mat')\n", 31 | "\n", 32 | "u_in = data_train['X_train0'][0:50000,:]\n", 33 | "x_t_in = data_train['X_train1'][0:50000,:]\n", 34 | "s_in = data_train['y_train'][0:50000,:]\n", 35 | "\n", 36 | "u_in.shape, x_t_in.shape, s_in.shape" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "cc8fdc38-c3cb-47e3-ab99-f0870272984a", 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "Model: \"model\"\n", 50 | "__________________________________________________________________________________________________\n", 51 | " Layer (type) Output Shape Param # Connected to \n", 52 | "==================================================================================================\n", 53 | " inputsB (InputLayer) [(None, 100)] 0 [] \n", 54 | " \n", 55 | " inputsT (InputLayer) [(None, 2)] 0 [] \n", 56 | " \n", 57 | " dense (Dense) (None, 25) 2525 ['inputsB[0][0]'] \n", 58 | " \n", 59 | " dense_4 (Dense) (None, 25) 75 ['inputsT[0][0]'] \n", 60 | " \n", 61 | " dense_1 (Dense) (None, 25) 650 ['dense[0][0]'] \n", 62 | " \n", 63 | " dense_5 (Dense) (None, 25) 650 ['dense_4[0][0]'] \n", 64 | " \n", 65 | " dense_2 (Dense) (None, 25) 650 ['dense_1[0][0]'] \n", 66 | " \n", 67 | " dense_6 (Dense) (None, 25) 650 ['dense_5[0][0]'] \n", 68 | " \n", 69 | " dense_3 (Dense) (None, 25) 650 ['dense_2[0][0]'] \n", 70 | " \n", 71 | " dense_7 (Dense) (None, 25) 650 ['dense_6[0][0]'] \n", 72 | " \n", 73 | " lambda (Lambda) (None, 1) 0 ['dense_3[0][0]', \n", 74 | " 'dense_7[0][0]'] \n", 75 | " \n", 76 | " dense_8 (Dense) (None, 1) 2 ['lambda[0][0]'] \n", 77 | " \n", 78 | "==================================================================================================\n", 79 | "Total params: 6,502\n", 80 | "Trainable params: 6,502\n", 81 | "Non-trainable params: 0\n", 82 | "__________________________________________________________________________________________________\n", 83 | "Iteration: 0\n", 84 | "1/1 [==============================] - 1s 827ms/step - loss: 0.2393\n", 85 | "Iteration: 250\n", 86 | "1/1 [==============================] - 0s 332ms/step - loss: 0.0644\n", 87 | "Iteration: 500\n", 88 | "1/1 [==============================] - 0s 296ms/step - loss: 0.0371\n", 89 | "Iteration: 750\n", 90 | "1/1 [==============================] - 0s 250ms/step - loss: 0.0243\n", 91 | "Iteration: 1000\n", 92 | "1/1 [==============================] - 0s 285ms/step - loss: 0.0166\n", 93 | "Iteration: 1250\n", 94 | "1/1 [==============================] - 0s 253ms/step - loss: 0.0123\n", 95 | "Iteration: 1500\n", 96 | "1/1 [==============================] - 0s 51ms/step - loss: 0.0095\n", 97 | "Iteration: 1750\n", 98 | "1/1 [==============================] - 0s 34ms/step - loss: 0.0075\n", 99 | "Iteration: 2000\n", 100 | "1/1 [==============================] - 0s 31ms/step - loss: 0.0060\n", 101 | "Iteration: 2250\n", 102 | "1/1 [==============================] - 0s 38ms/step - loss: 0.0050\n", 103 | "Iteration: 2500\n", 104 | "1/1 [==============================] - 0s 30ms/step - loss: 0.0042\n", 105 | "Iteration: 2750\n", 106 | "1/1 [==============================] - 0s 47ms/step - loss: 0.0036\n", 107 | "Iteration: 3000\n", 108 | "1/1 [==============================] - 1s 520ms/step - loss: 0.0030\n", 109 | "Iteration: 3250\n", 110 | "1/1 [==============================] - 0s 38ms/step - 
loss: 0.0027\n", 111 | "Iteration: 3500\n", 112 | "1/1 [==============================] - 0s 47ms/step - loss: 0.0024\n", 113 | "Iteration: 3750\n", 114 | "1/1 [==============================] - 0s 47ms/step - loss: 0.0020\n", 115 | "Iteration: 4000\n", 116 | "1/1 [==============================] - 0s 62ms/step - loss: 0.0018\n", 117 | "Iteration: 4250\n", 118 | "1/1 [==============================] - 0s 31ms/step - loss: 0.0016\n", 119 | "Iteration: 4500\n", 120 | "1/1 [==============================] - 0s 38ms/step - loss: 0.0014\n", 121 | "Iteration: 4750\n", 122 | "1/1 [==============================] - 0s 31ms/step - loss: 0.0015\n", 123 | "Iteration: 5000\n", 124 | "1/1 [==============================] - 0s 38ms/step - loss: 0.0012\n", 125 | "Iteration: 5250\n", 126 | "1/1 [==============================] - 0s 253ms/step - loss: 0.0011\n", 127 | "Iteration: 5500\n", 128 | "1/1 [==============================] - 0s 47ms/step - loss: 0.0014\n", 129 | "Iteration: 5750\n", 130 | "1/1 [==============================] - 0s 62ms/step - loss: 9.7317e-04\n", 131 | "Iteration: 6000\n", 132 | "1/1 [==============================] - 0s 38ms/step - loss: 9.3186e-04\n", 133 | "Iteration: 6250\n", 134 | "1/1 [==============================] - 0s 283ms/step - loss: 8.6686e-04\n", 135 | "Iteration: 6500\n", 136 | "1/1 [==============================] - 0s 66ms/step - loss: 0.0011\n", 137 | "Iteration: 6750\n", 138 | "1/1 [==============================] - 0s 84ms/step - loss: 8.4038e-04\n", 139 | "Iteration: 7000\n", 140 | "1/1 [==============================] - 0s 65ms/step - loss: 7.7355e-04\n", 141 | "Iteration: 7250\n", 142 | "1/1 [==============================] - 0s 38ms/step - loss: 9.0091e-04\n", 143 | "Iteration: 7500\n", 144 | "1/1 [==============================] - 0s 39ms/step - loss: 7.2588e-04\n", 145 | "Iteration: 7750\n", 146 | "1/1 [==============================] - 0s 39ms/step - loss: 7.0640e-04\n", 147 | "Iteration: 8000\n", 148 | "1/1 [==============================] - 0s 70ms/step - loss: 6.6383e-04\n", 149 | "Iteration: 8250\n", 150 | "1/1 [==============================] - 0s 62ms/step - loss: 6.4144e-04\n", 151 | "Iteration: 8500\n", 152 | "1/1 [==============================] - 0s 103ms/step - loss: 6.6727e-04\n", 153 | "Iteration: 8750\n", 154 | "1/1 [==============================] - 0s 44ms/step - loss: 7.2094e-04\n", 155 | "Iteration: 9000\n", 156 | "1/1 [==============================] - 0s 85ms/step - loss: 5.8546e-04\n", 157 | "Iteration: 9250\n", 158 | "1/1 [==============================] - 0s 84ms/step - loss: 6.0858e-04\n", 159 | "Iteration: 9500\n", 160 | "1/1 [==============================] - 0s 31ms/step - loss: 5.9831e-04\n", 161 | "Iteration: 9750\n", 162 | "1/1 [==============================] - 0s 70ms/step - loss: 7.4892e-04\n", 163 | "Total iterations: 10000\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "from tensorflow.keras.models import Model\n", 169 | "from tensorflow.keras.layers import Input, Lambda, Dense\n", 170 | "\n", 171 | "bs = 50000\n", 172 | "\n", 173 | "def fn(x):\n", 174 | " y = tf.einsum(\"ij, ij->i\", x[0], x[1])\n", 175 | " y = tf.expand_dims(y, axis = 1)\n", 176 | " return y\n", 177 | "\n", 178 | "hln = 25\n", 179 | "\n", 180 | "inputsB = Input(shape = (100,), name = 'inputsB')\n", 181 | "hiddenB = Dense(hln, activation = \"relu\")(inputsB)\n", 182 | "hiddenB = Dense(hln, activation = \"relu\")(hiddenB)\n", 183 | "hiddenB = Dense(hln, activation = \"relu\")(hiddenB)\n", 184 | "hiddenB = Dense(hln, activation = 
\"relu\")(hiddenB)\n", 185 | "\n", 186 | "inputsT = Input(shape = (2,), name = 'inputsT')\n", 187 | "hiddenT = Dense(hln, activation = \"relu\")(inputsT)\n", 188 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 189 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 190 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 191 | "\n", 192 | "combined = Lambda(fn, output_shape = [None, 1])([hiddenB, hiddenT])\n", 193 | "output = Dense(1)(combined)\n", 194 | "\n", 195 | "model = Model(inputs = [inputsB, inputsT], outputs = output) \n", 196 | "\n", 197 | "model.compile(optimizer = tf.optimizers.Adam(learning_rate = 0.001), loss = 'mse')\n", 198 | "model.summary()\n", 199 | "\n", 200 | "model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n", 201 | " filepath = './ChkPts/CIII_V2_VDN_S1_10000/',\n", 202 | " save_weights_only=True,\n", 203 | " monitor='loss',\n", 204 | " mode='min',\n", 205 | " save_best_only=True)\n", 206 | "\n", 207 | "itr = 0\n", 208 | "for i in range(0, 40):\n", 209 | " print('Iteration: '+str(itr))\n", 210 | " itr = itr+1\n", 211 | " model.fit({\"inputsB\":u_in, \"inputsT\":x_t_in}, s_in, epochs = 1,\n", 212 | " verbose = 1, batch_size = bs, callbacks = [model_checkpoint_callback]) \n", 213 | "\n", 214 | " itr_ps = 250-1\n", 215 | " model.fit({\"inputsB\":u_in, \"inputsT\":x_t_in}, s_in, epochs = itr_ps,\n", 216 | " verbose = 0, batch_size = bs, callbacks = [model_checkpoint_callback])\n", 217 | " itr = itr+itr_ps\n", 218 | "\n", 219 | "print('Total iterations: '+str(itr))\n", 220 | "\n", 221 | "model.load_weights('./ChkPts/CIII_V2_VDN_S1_10000/')\n", 222 | "\n", 223 | "model.save_weights('./model/CIII_V2_VDN_S1_10000')\n", 224 | "\n", 225 | "# model.load_weights('./model/CIII_V2_VDN_S1')" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 3, 231 | "id": "9cebf033-567b-4508-a037-5951bf7adb06", 232 | "metadata": {}, 233 | "outputs": [ 234 | { 235 | "name": "stdout", 236 | "output_type": "stream", 237 | "text": [ 238 | "\n", 239 | "0.0021799998\n", 240 | "0.23837758612778512\n", 241 | "0.009145154421609878\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "data_test = spi.loadmat('test_PDE_DR.mat')\n", 247 | "\n", 248 | "u_in_test = data_test['X_test0']\n", 249 | "x_t_in_test = data_test['X_test1']\n", 250 | "s_in_test = data_test['y_test']\n", 251 | "\n", 252 | "pred = model({\"inputsB\":u_in_test, \"inputsT\":x_t_in_test})\n", 253 | " \n", 254 | "print()\n", 255 | "print(np.mean((s_in_test-pred)**2))\n", 256 | "print(np.mean((s_in_test)**2))\n", 257 | "print(np.mean((s_in_test-pred)**2)/np.mean((s_in_test)**2)) " 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "id": "1181fd7b-12ec-4383-a47e-127a22ffdd7a", 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [] 267 | } 268 | ], 269 | "metadata": { 270 | "kernelspec": { 271 | "display_name": "Python 3 (ipykernel)", 272 | "language": "python", 273 | "name": "python3" 274 | }, 275 | "language_info": { 276 | "codemirror_mode": { 277 | "name": "ipython", 278 | "version": 3 279 | }, 280 | "file_extension": ".py", 281 | "mimetype": "text/x-python", 282 | "name": "python", 283 | "nbconvert_exporter": "python", 284 | "pygments_lexer": "ipython3", 285 | "version": "3.10.2" 286 | } 287 | }, 288 | "nbformat": 4, 289 | "nbformat_minor": 5 290 | } 291 | -------------------------------------------------------------------------------- /CS-III/CS_III_POD_GP_GitHub.m: 
-------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all 4 | 5 | tic 6 | 7 | %% %% 8 | 9 | IFR = 0.95; 10 | load train_PDE_DR.mat 11 | 12 | n = 50000; 13 | y = X_train0(1:n,:); 14 | 15 | [~, sd, vd] = svd(y); 16 | 17 | sd = sd.^2; 18 | nd = 1; 19 | chkd = sum(diag(sd)); 20 | rd = 0; 21 | while rd < IFR 22 | rd = sum(diag(sd(1:nd,1:nd)))/chkd; 23 | nd = nd+1; 24 | end 25 | nd = nd-1; 26 | fprintf('\n\n%d\n\n',nd); 27 | 28 | red = y*vd(:,1:nd); 29 | 30 | toc 31 | 32 | %% 33 | 34 | in = [red, X_train1(1:n, :)]; 35 | mdl = fitrgp(in, y_train(1:n,:)); 36 | 37 | toc 38 | 39 | %% 40 | 41 | save("GPmdl95P_new1.mat",'sd','vd','n','IFR','nd','mdl') 42 | 43 | %% PREDICTION 44 | 45 | load test_PDE_DR.mat 46 | 47 | S_mse = zeros(100,1); 48 | S_nmse = zeros(100,1); 49 | 50 | for i = 1:100 51 | 52 | i 53 | 54 | n = 10000; 55 | y = X_test0((i-1)*n+1:i*n, :); 56 | 57 | pfr = y*vd(:,1:nd); 58 | in = [pfr, X_test1((i-1)*n+1:i*n, :)]; 59 | 60 | pred = zeros(10,10000,1); 61 | for j = 1:10 62 | pred(j,:,:) = predict(mdl, in); 63 | end 64 | 65 | mpred = squeeze(mean(pred, 1)); 66 | spred = squeeze(std(pred, 1)); 67 | 68 | mse = mean(mean((mpred'-y_test((i-1)*n+1:i*n,1)).^2)); 69 | nmse = mean(mean((mpred'-y_test((i-1)*n+1:i*n,1)).^2))./mean(mean(y_test((i-1)*n+1:i*n,1).^2)); 70 | 71 | S_mse(i) = mse; 72 | S_nmse(i) = nmse; 73 | 74 | end 75 | 76 | MSE = mean(S_mse) 77 | NMSE = mean(S_nmse) 78 | 79 | % MSE = 80 | % 81 | % 0.0819 82 | % 83 | % 84 | % NMSE = 85 | % 86 | % 0.3446 87 | -------------------------------------------------------------------------------- /CS-III/CS_III_VB_DeepONet_GitHub.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6dc23078-05ca-4bbe-bcb0-e21d2a079cf7", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((50000, 100), (50000, 2), (50000, 1))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "%reset -f\n", 22 | "import h5py\n", 23 | "import time as t\n", 24 | "import numpy as np\n", 25 | "import scipy as sp\n", 26 | "import scipy.io as spi\n", 27 | "import tensorflow as tf\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import tensorflow_probability as tfp\n", 30 | "\n", 31 | "data_train = spi.loadmat('train_PDE_DR.mat')\n", 32 | "\n", 33 | "u_in = data_train['X_train0'][0:50000,:]\n", 34 | "x_t_in = data_train['X_train1'][0:50000,:]\n", 35 | "s_in = data_train['y_train'][0:50000,:]\n", 36 | "\n", 37 | "u_in.shape, x_t_in.shape, s_in.shape" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "id": "cc8fdc38-c3cb-47e3-ab99-f0870272984a", 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stderr", 48 | "output_type": "stream", 49 | "text": [ 50 | "/home/user/anaconda3/envs/SG_Env_TFP/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:99: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. 
Please use `layer.add_weight` method instead.\n", 51 | " loc = add_variable_fn(\n", 52 | "2022-10-24 13:30:26.157439: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory\n", 53 | "2022-10-24 13:30:26.157527: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n", 54 | "Skipping registering GPU devices...\n", 55 | "2022-10-24 13:30:26.159473: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n", 56 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 57 | "/home/user/anaconda3/envs/SG_Env_TFP/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:109: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.\n", 58 | " untransformed_scale = add_variable_fn(\n" 59 | ] 60 | }, 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "Model: \"model\"\n", 66 | "__________________________________________________________________________________________________\n", 67 | " Layer (type) Output Shape Param # Connected to \n", 68 | "==================================================================================================\n", 69 | " inputsB (InputLayer) [(None, 100)] 0 [] \n", 70 | " \n", 71 | " inputsT (InputLayer) [(None, 2)] 0 [] \n", 72 | " \n", 73 | " dense_flipout (DenseFlipout) (None, 25) 5025 ['inputsB[0][0]'] \n", 74 | " \n", 75 | " dense_flipout_4 (DenseFlipout) (None, 25) 125 ['inputsT[0][0]'] \n", 76 | " \n", 77 | " dense_flipout_1 (DenseFlipout) (None, 25) 1275 ['dense_flipout[0][0]'] \n", 78 | " \n", 79 | " dense_flipout_5 (DenseFlipout) (None, 25) 1275 ['dense_flipout_4[0][0]'] \n", 80 | " \n", 81 | " dense_flipout_2 (DenseFlipout) (None, 25) 1275 ['dense_flipout_1[0][0]'] \n", 82 | " \n", 83 | " dense_flipout_6 (DenseFlipout) (None, 25) 1275 ['dense_flipout_5[0][0]'] \n", 84 | " \n", 85 | " dense_flipout_3 (DenseFlipout) (None, 25) 1275 ['dense_flipout_2[0][0]'] \n", 86 | " \n", 87 | " dense_flipout_7 (DenseFlipout) (None, 25) 1275 ['dense_flipout_6[0][0]'] \n", 88 | " \n", 89 | " lambda (Lambda) (None, 1) 0 ['dense_flipout_3[0][0]', \n", 90 | " 'dense_flipout_7[0][0]'] \n", 91 | " \n", 92 | " dense_flipout_8 (DenseFlipout) (None, 2) 6 ['lambda[0][0]'] \n", 93 | " \n", 94 | " distribution_lambda (Distribut ((None, 1), 0 ['dense_flipout_8[0][0]'] \n", 95 | " ionLambda) (None, 1)) \n", 96 | " \n", 97 | "==================================================================================================\n", 98 | "Total params: 12,806\n", 99 | "Trainable params: 12,806\n", 100 | "Non-trainable params: 0\n", 101 | "__________________________________________________________________________________________________\n" 102 | ] 103 | }, 104 | { 105 | "name": "stderr", 106 | "output_type": "stream", 107 | "text": [ 108 | "2022-10-24 13:30:26.993811: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this 
may change in the future, so consider avoiding using them.\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "from tensorflow.keras.models import Model\n", 114 | "from tensorflow.keras.layers import Input, Lambda, Dense\n", 115 | "\n", 116 | "bs = 50000\n", 117 | "\n", 118 | "def fn(x):\n", 119 | " y = tf.einsum(\"ij, ij->i\", x[0], x[1])\n", 120 | " y = tf.expand_dims(y, axis = 1)\n", 121 | " return y\n", 122 | "\n", 123 | "tfd = tfp.distributions\n", 124 | "tfb = tfp.bijectors\n", 125 | "\n", 126 | "def normal_sp(params):\n", 127 | " return tfd.Normal(loc = params[:, 0:1], scale = 0.001+tf.math.softplus(params[:, 1:2])) \n", 128 | "\n", 129 | "def negloglikelihood(y_true, y_pred):\n", 130 | " return tf.keras.backend.sum(-y_pred.log_prob(y_true))+(sum(model.losses)/bs)\n", 131 | "\n", 132 | "hln = 25\n", 133 | "\n", 134 | "inputsB = Input(shape = (100,), name = 'inputsB')\n", 135 | "hiddenB = tfp.layers.DenseFlipout(hln, activation = \"relu\")(inputsB)\n", 136 | "hiddenB = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenB)\n", 137 | "hiddenB = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenB)\n", 138 | "hiddenB = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenB)\n", 139 | "\n", 140 | "inputsT = Input(shape = (2,), name = 'inputsT')\n", 141 | "hiddenT = tfp.layers.DenseFlipout(hln, activation = \"relu\")(inputsT)\n", 142 | "hiddenT = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenT)\n", 143 | "hiddenT = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenT)\n", 144 | "hiddenT = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenT)\n", 145 | "\n", 146 | "combined = Lambda(fn, output_shape = [None, 1])([hiddenB, hiddenT])\n", 147 | "output = tfp.layers.DenseFlipout(2)(combined)\n", 148 | "\n", 149 | "dist = tfp.layers.DistributionLambda(normal_sp)(output)\n", 150 | "model = Model(inputs = [inputsB, inputsT], outputs = dist) \n", 151 | "\n", 152 | "model.summary()" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 3, 158 | "id": "d88e6bb8-eb50-4662-9c72-52ccce4acc7f", 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "Epoch 0, loss 39877.29\n", 166 | "Epoch 10, loss 38531.90\n", 167 | "Epoch 20, loss 36072.61\n", 168 | "Epoch 30, loss 33013.02\n", 169 | "Epoch 40, loss 30713.41\n", 170 | "Epoch 50, loss 28726.76\n", 171 | "Epoch 60, loss 26616.61\n", 172 | "Epoch 70, loss 25168.88\n", 173 | "Epoch 80, loss 23819.35\n", 174 | "Epoch 90, loss 22713.27\n", 175 | "Epoch 100, loss 21371.43\n", 176 | "Epoch 110, loss 19408.15\n", 177 | "Epoch 120, loss 17520.28\n", 178 | "Epoch 130, loss 16635.89\n", 179 | "Epoch 140, loss 13187.33\n", 180 | "Epoch 150, loss 10226.96\n", 181 | "Epoch 160, loss 7357.02\n", 182 | "Epoch 170, loss 5567.42\n", 183 | "Epoch 180, loss 3769.03\n", 184 | "Epoch 190, loss -1537.37\n", 185 | "Epoch 200, loss -5026.68\n", 186 | "Epoch 210, loss -7830.66\n", 187 | "Epoch 220, loss -10840.90\n", 188 | "Epoch 230, loss -9039.54\n", 189 | "Epoch 240, loss -14135.99\n", 190 | "Epoch 250, loss -13071.82\n", 191 | "Epoch 260, loss -14424.20\n", 192 | "Epoch 270, loss -20474.61\n", 193 | "Epoch 280, loss -19320.96\n", 194 | "Epoch 290, loss -6328.44\n", 195 | "Epoch 300, loss -18255.17\n", 196 | "Epoch 310, loss -23530.44\n", 197 | "Epoch 320, loss -29075.37\n", 198 | "Epoch 330, loss -31208.97\n", 199 | "Epoch 340, loss -32969.49\n", 200 | "Epoch 350, loss -34637.25\n", 201 | "Epoch 360, loss -34674.25\n", 202 | "Epoch 370, 
loss -29429.44\n", 203 | "Epoch 380, loss -17477.26\n", 204 | "Epoch 390, loss -21171.82\n", 205 | "Epoch 400, loss -27455.93\n", 206 | "Epoch 410, loss -34836.09\n", 207 | "Epoch 420, loss -41206.99\n", 208 | "Epoch 430, loss -45704.17\n", 209 | "Epoch 440, loss 67094.43\n", 210 | "Epoch 450, loss -7484.27\n", 211 | "Epoch 460, loss -17318.89\n", 212 | "Epoch 470, loss -20708.11\n", 213 | "Epoch 480, loss -25448.49\n", 214 | "Epoch 490, loss -30165.41\n", 215 | "Epoch 500, loss -35283.66\n", 216 | "Epoch 510, loss -40137.30\n", 217 | "Epoch 520, loss -44647.61\n", 218 | "Epoch 530, loss -48670.86\n", 219 | "Epoch 540, loss -51846.25\n", 220 | "Epoch 550, loss -51190.06\n", 221 | "Epoch 560, loss -35153.79\n", 222 | "Epoch 570, loss -43911.79\n", 223 | "Epoch 580, loss -47062.34\n", 224 | "Epoch 590, loss -50320.53\n", 225 | "Epoch 600, loss -53655.75\n", 226 | "Epoch 610, loss -56991.34\n", 227 | "Epoch 620, loss -59769.63\n", 228 | "Epoch 630, loss -50915.04\n", 229 | "Epoch 640, loss -16210.88\n", 230 | "Epoch 650, loss -38177.85\n", 231 | "Epoch 660, loss -41646.18\n", 232 | "Epoch 670, loss -46636.45\n", 233 | "Epoch 680, loss -50874.62\n", 234 | "Epoch 690, loss -55399.22\n", 235 | "Epoch 700, loss -59393.99\n", 236 | "Epoch 710, loss -62611.89\n", 237 | "Epoch 720, loss -64988.91\n", 238 | "Epoch 730, loss -56415.46\n", 239 | "Epoch 740, loss 67544.20\n", 240 | "Epoch 750, loss 1985.83\n", 241 | "Epoch 760, loss -5228.71\n", 242 | "Epoch 770, loss -9084.31\n", 243 | "Epoch 780, loss -11369.56\n", 244 | "Epoch 790, loss -13888.36\n", 245 | "Epoch 800, loss -16372.86\n", 246 | "Epoch 810, loss -18818.07\n", 247 | "Epoch 820, loss -21227.30\n", 248 | "Epoch 830, loss -23452.48\n", 249 | "Epoch 840, loss -25507.50\n", 250 | "Epoch 850, loss -27701.08\n", 251 | "Epoch 860, loss -30014.55\n", 252 | "Epoch 870, loss -32329.88\n", 253 | "Epoch 880, loss -34661.86\n", 254 | "Epoch 890, loss -36806.50\n", 255 | "Epoch 900, loss -39136.26\n", 256 | "Epoch 910, loss -41480.32\n", 257 | "Epoch 920, loss -43789.23\n", 258 | "Epoch 930, loss -45989.72\n", 259 | "Epoch 940, loss -47120.11\n", 260 | "Epoch 950, loss -49160.41\n", 261 | "Epoch 960, loss -51738.08\n", 262 | "Epoch 970, loss -53719.14\n", 263 | "Epoch 980, loss -54594.71\n", 264 | "Epoch 990, loss -37955.09\n", 265 | "Epoch 1000, loss -54098.08\n", 266 | "Epoch 1010, loss -56544.80\n", 267 | "Epoch 1020, loss -57652.64\n", 268 | "Epoch 1030, loss -59762.47\n", 269 | "Epoch 1040, loss -61102.08\n", 270 | "Epoch 1050, loss -62705.71\n", 271 | "Epoch 1060, loss -62518.20\n", 272 | "Epoch 1070, loss -56234.33\n", 273 | "Epoch 1080, loss -57242.17\n", 274 | "Epoch 1090, loss -61324.43\n", 275 | "Epoch 1100, loss -65754.66\n", 276 | "Epoch 1110, loss -63953.59\n", 277 | "Epoch 1120, loss -48741.12\n", 278 | "Epoch 1130, loss -29090.29\n", 279 | "Epoch 1140, loss -53853.59\n", 280 | "Epoch 1150, loss -59865.13\n", 281 | "Epoch 1160, loss -64464.63\n", 282 | "Epoch 1170, loss -66733.52\n", 283 | "Epoch 1180, loss -69037.99\n", 284 | "Epoch 1190, loss -71245.88\n", 285 | "Epoch 1200, loss -73076.67\n", 286 | "Epoch 1210, loss 298286.88\n", 287 | "Epoch 1220, loss 79448.67\n", 288 | "Epoch 1230, loss 4780.92\n", 289 | "Epoch 1240, loss -7073.68\n", 290 | "Epoch 1250, loss -11705.63\n", 291 | "Epoch 1260, loss -14897.88\n", 292 | "Epoch 1270, loss -17584.15\n", 293 | "Epoch 1280, loss -20054.69\n", 294 | "Epoch 1290, loss -22235.57\n", 295 | "Epoch 1300, loss -24325.37\n", 296 | "Epoch 1310, loss -26394.29\n", 297 | "Epoch 1320, loss 
-28445.45\n", 298 | "Epoch 1330, loss -30499.71\n", 299 | "Epoch 1340, loss -32568.94\n", 300 | "Epoch 1350, loss -34519.61\n", 301 | "Epoch 1360, loss -36475.36\n", 302 | "Epoch 1370, loss -38384.16\n", 303 | "Epoch 1380, loss -40302.97\n", 304 | "Epoch 1390, loss -42177.90\n", 305 | "Epoch 1400, loss -44011.34\n", 306 | "Epoch 1410, loss -45831.68\n", 307 | "Epoch 1420, loss -47617.91\n", 308 | "Epoch 1430, loss -49366.58\n", 309 | "Epoch 1440, loss -50875.71\n", 310 | "Epoch 1450, loss -52719.30\n", 311 | "Epoch 1460, loss -54385.65\n", 312 | "Epoch 1470, loss -55951.33\n", 313 | "Epoch 1480, loss -57398.82\n", 314 | "Epoch 1490, loss -58492.64\n", 315 | "Epoch 1500, loss -60026.34\n", 316 | "Epoch 1510, loss -61321.67\n", 317 | "Epoch 1520, loss -62435.44\n", 318 | "Epoch 1530, loss -62949.56\n", 319 | "Epoch 1540, loss -64326.41\n", 320 | "Epoch 1550, loss -65295.13\n", 321 | "Epoch 1560, loss -64732.43\n", 322 | "Epoch 1570, loss -66301.23\n", 323 | "Epoch 1580, loss -66972.06\n", 324 | "Epoch 1590, loss -68025.16\n", 325 | "Epoch 1600, loss -61943.22\n", 326 | "Epoch 1610, loss -68556.55\n", 327 | "Epoch 1620, loss -68170.21\n", 328 | "Epoch 1630, loss -69817.32\n", 329 | "Epoch 1640, loss -68871.11\n", 330 | "Epoch 1650, loss -68355.23\n", 331 | "Epoch 1660, loss -65750.27\n", 332 | "Epoch 1670, loss -72015.88\n", 333 | "Epoch 1680, loss -70920.30\n", 334 | "Epoch 1690, loss -67535.43\n", 335 | "Epoch 1700, loss -70718.39\n", 336 | "Epoch 1710, loss -74319.36\n", 337 | "Epoch 1720, loss -75035.09\n", 338 | "Epoch 1730, loss -73354.46\n", 339 | "Epoch 1740, loss -77267.91\n", 340 | "Epoch 1750, loss -77672.33\n", 341 | "Epoch 1760, loss -76815.55\n", 342 | "Epoch 1770, loss -62680.55\n", 343 | "Epoch 1780, loss -72739.41\n", 344 | "Epoch 1790, loss -73749.89\n", 345 | "Epoch 1800, loss -78774.23\n", 346 | "Epoch 1810, loss -78991.12\n", 347 | "Epoch 1820, loss -80322.95\n", 348 | "Epoch 1830, loss -80686.21\n", 349 | "Epoch 1840, loss -55581.43\n", 350 | "Epoch 1850, loss -46390.55\n", 351 | "Epoch 1860, loss -57801.64\n", 352 | "Epoch 1870, loss -76046.61\n", 353 | "Epoch 1880, loss -76357.95\n", 354 | "Epoch 1890, loss -78617.67\n", 355 | "Epoch 1900, loss -80762.16\n", 356 | "Epoch 1910, loss -82393.91\n", 357 | "Epoch 1920, loss -83980.20\n", 358 | "Epoch 1930, loss -84822.54\n", 359 | "Epoch 1940, loss 345088.59\n", 360 | "Epoch 1950, loss -23741.51\n", 361 | "Epoch 1960, loss -37661.29\n", 362 | "Epoch 1970, loss -43121.84\n", 363 | "Epoch 1980, loss -46113.71\n", 364 | "Epoch 1990, loss -48687.82\n", 365 | "Epoch 2000, loss -51514.29\n", 366 | "Epoch 2010, loss -54268.95\n", 367 | "Epoch 2020, loss -57063.11\n", 368 | "Epoch 2030, loss -59798.37\n", 369 | "Epoch 2040, loss -62479.46\n", 370 | "Epoch 2050, loss -65088.39\n", 371 | "Epoch 2060, loss -67622.91\n", 372 | "Epoch 2070, loss -70044.23\n", 373 | "Epoch 2080, loss -72375.73\n", 374 | "Epoch 2090, loss -74585.80\n", 375 | "Epoch 2100, loss -76670.95\n", 376 | "Epoch 2110, loss -78600.62\n", 377 | "Epoch 2120, loss -80372.09\n", 378 | "Epoch 2130, loss -81969.05\n", 379 | "Epoch 2140, loss -83370.11\n", 380 | "Epoch 2150, loss -84618.47\n", 381 | "Epoch 2160, loss -85771.45\n", 382 | "Epoch 2170, loss -86807.05\n", 383 | "Epoch 2180, loss -87606.51\n", 384 | "Epoch 2190, loss -61470.08\n", 385 | "Epoch 2200, loss -74010.64\n", 386 | "Epoch 2210, loss -76022.68\n", 387 | "Epoch 2220, loss -82089.31\n", 388 | "Epoch 2230, loss -84937.18\n", 389 | "Epoch 2240, loss -86448.73\n", 390 | "Epoch 2250, loss -88282.96\n", 391 | 
"Epoch 2260, loss -89594.67\n", 392 | "Epoch 2270, loss -90739.01\n", 393 | "Epoch 2280, loss -21140.45\n", 394 | "Epoch 2290, loss -39450.83\n", 395 | "Epoch 2300, loss -59282.54\n", 396 | "Epoch 2310, loss -74747.38\n", 397 | "Epoch 2320, loss -78880.91\n", 398 | "Epoch 2330, loss -80913.02\n", 399 | "Epoch 2340, loss -83735.23\n", 400 | "Epoch 2350, loss -86272.65\n", 401 | "Epoch 2360, loss -88344.35\n", 402 | "Epoch 2370, loss -90184.88\n", 403 | "Epoch 2380, loss -91706.80\n", 404 | "Epoch 2390, loss -92963.16\n", 405 | "Epoch 2400, loss -93987.36\n", 406 | "Epoch 2410, loss 744482.00\n", 407 | "Epoch 2420, loss 44513.98\n", 408 | "Epoch 2430, loss -9130.21\n", 409 | "Epoch 2440, loss -15187.98\n", 410 | "Epoch 2450, loss -18211.37\n", 411 | "Epoch 2460, loss -19731.43\n", 412 | "Epoch 2470, loss -21620.98\n", 413 | "Epoch 2480, loss -23381.19\n", 414 | "Epoch 2490, loss -25183.43\n", 415 | "Epoch 2500, loss -27060.05\n", 416 | "Epoch 2510, loss -29001.80\n", 417 | "Epoch 2520, loss -30948.19\n", 418 | "Epoch 2530, loss -32784.82\n", 419 | "Epoch 2540, loss -34430.95\n", 420 | "Epoch 2550, loss -35960.95\n", 421 | "Epoch 2560, loss -37482.44\n", 422 | "Epoch 2570, loss -39036.33\n", 423 | "Epoch 2580, loss -40664.12\n", 424 | "Epoch 2590, loss -42376.24\n", 425 | "Epoch 2600, loss -44134.36\n", 426 | "Epoch 2610, loss -45895.88\n", 427 | "Epoch 2620, loss -47583.29\n", 428 | "Epoch 2630, loss -49187.32\n", 429 | "Epoch 2640, loss -50741.23\n", 430 | "Epoch 2650, loss -52238.79\n", 431 | "Epoch 2660, loss -53694.21\n", 432 | "Epoch 2670, loss -55129.62\n", 433 | "Epoch 2680, loss -56540.00\n", 434 | "Epoch 2690, loss -57914.24\n", 435 | "Epoch 2700, loss -59271.00\n", 436 | "Epoch 2710, loss -60609.45\n", 437 | "Epoch 2720, loss -61903.97\n", 438 | "Epoch 2730, loss -63161.50\n", 439 | "Epoch 2740, loss -64389.88\n", 440 | "Epoch 2750, loss -65577.27\n", 441 | "Epoch 2760, loss -66720.92\n", 442 | "Epoch 2770, loss -67805.08\n", 443 | "Epoch 2780, loss -68828.96\n", 444 | "Epoch 2790, loss -69790.92\n", 445 | "Epoch 2800, loss -70710.38\n", 446 | "Epoch 2810, loss -71576.09\n", 447 | "Epoch 2820, loss -72393.73\n", 448 | "Epoch 2830, loss -73161.41\n", 449 | "Epoch 2840, loss -73874.81\n", 450 | "Epoch 2850, loss -74556.52\n", 451 | "Epoch 2860, loss -75189.28\n", 452 | "Epoch 2870, loss -75763.25\n", 453 | "Epoch 2880, loss -76316.74\n", 454 | "Epoch 2890, loss -76831.20\n", 455 | "Epoch 2900, loss -77305.71\n", 456 | "Epoch 2910, loss -77776.20\n", 457 | "Epoch 2920, loss -78238.48\n", 458 | "Epoch 2930, loss -78645.98\n", 459 | "Epoch 2940, loss -79104.46\n", 460 | "Epoch 2950, loss -79583.48\n", 461 | "Epoch 2960, loss -80040.57\n", 462 | "Epoch 2970, loss -80529.94\n", 463 | "Epoch 2980, loss -81054.62\n", 464 | "Epoch 2990, loss -81649.76\n", 465 | "Epoch 3000, loss -82431.91\n", 466 | "Epoch 3010, loss -83554.46\n", 467 | "Epoch 3020, loss -84913.59\n", 468 | "Epoch 3030, loss -86289.64\n", 469 | "Epoch 3040, loss -81194.52\n", 470 | "Epoch 3050, loss -85428.85\n", 471 | "Epoch 3060, loss -87980.28\n", 472 | "Epoch 3070, loss -90324.62\n", 473 | "Epoch 3080, loss -83613.60\n", 474 | "Epoch 3090, loss -82955.12\n", 475 | "Epoch 3100, loss -90944.62\n", 476 | "Epoch 3110, loss -88823.47\n", 477 | "Epoch 3120, loss -81910.95\n", 478 | "Epoch 3130, loss -92585.73\n", 479 | "Epoch 3140, loss -91417.54\n", 480 | "Epoch 3150, loss -87628.83\n", 481 | "Epoch 3160, loss -75756.72\n", 482 | "Epoch 3170, loss -88894.14\n", 483 | "Epoch 3180, loss -92092.09\n", 484 | "Epoch 3190, loss 
-94124.62\n", 485 | "Epoch 3200, loss -90816.95\n", 486 | "Epoch 3210, loss -88344.04\n", 487 | "Epoch 3220, loss -71449.78\n", 488 | "Epoch 3230, loss -80729.23\n", 489 | "Epoch 3240, loss -89975.44\n", 490 | "Epoch 3250, loss -93350.40\n", 491 | "Epoch 3260, loss -94743.73\n", 492 | "Epoch 3270, loss -96806.68\n", 493 | "Epoch 3280, loss -97470.52\n", 494 | "Epoch 3290, loss -98185.43\n", 495 | "Epoch 3300, loss -6764.65\n", 496 | "Epoch 3310, loss -73645.73\n", 497 | "Epoch 3320, loss -86519.07\n", 498 | "Epoch 3330, loss -91021.43\n", 499 | "Epoch 3340, loss -93277.52\n", 500 | "Epoch 3350, loss -95111.97\n", 501 | "Epoch 3360, loss -96612.22\n", 502 | "Epoch 3370, loss -97856.93\n", 503 | "Epoch 3380, loss -98998.73\n", 504 | "Epoch 3390, loss -99978.87\n", 505 | "Epoch 3400, loss -88397.48\n", 506 | "Epoch 3410, loss 158809.56\n", 507 | "Epoch 3420, loss -55764.23\n", 508 | "Epoch 3430, loss -63518.30\n", 509 | "Epoch 3440, loss -65781.38\n", 510 | "Epoch 3450, loss -69931.09\n", 511 | "Epoch 3460, loss -72585.02\n", 512 | "Epoch 3470, loss -75254.98\n", 513 | "Epoch 3480, loss -78007.34\n", 514 | "Epoch 3490, loss -80588.73\n", 515 | "Epoch 3500, loss -83076.34\n", 516 | "Epoch 3510, loss -85426.57\n", 517 | "Epoch 3520, loss -87668.03\n", 518 | "Epoch 3530, loss -89776.33\n", 519 | "Epoch 3540, loss -91752.87\n", 520 | "Epoch 3550, loss -93594.47\n", 521 | "Epoch 3560, loss -95276.14\n", 522 | "Epoch 3570, loss -96799.57\n", 523 | "Epoch 3580, loss -98166.21\n", 524 | "Epoch 3590, loss -99381.25\n", 525 | "Epoch 3600, loss -100461.64\n", 526 | "Epoch 3610, loss -101393.97\n", 527 | "Epoch 3620, loss -102156.15\n", 528 | "Epoch 3630, loss -88458.48\n", 529 | "Epoch 3640, loss 27747.54\n", 530 | "Epoch 3650, loss -41076.41\n", 531 | "Epoch 3660, loss -48306.17\n", 532 | "Epoch 3670, loss -51584.99\n", 533 | "Epoch 3680, loss -53613.86\n", 534 | "Epoch 3690, loss -55867.57\n", 535 | "Epoch 3700, loss -57911.46\n", 536 | "Epoch 3710, loss -59896.06\n", 537 | "Epoch 3720, loss -61849.99\n", 538 | "Epoch 3730, loss -63768.74\n", 539 | "Epoch 3740, loss -65679.94\n", 540 | "Epoch 3750, loss -67587.01\n", 541 | "Epoch 3760, loss -69490.23\n", 542 | "Epoch 3770, loss -71417.75\n", 543 | "Epoch 3780, loss -73326.37\n", 544 | "Epoch 3790, loss -75199.57\n", 545 | "Epoch 3800, loss -77086.34\n", 546 | "Epoch 3810, loss -78951.89\n", 547 | "Epoch 3820, loss -80784.46\n", 548 | "Epoch 3830, loss -82668.26\n", 549 | "Epoch 3840, loss -84530.95\n", 550 | "Epoch 3850, loss -86328.55\n", 551 | "Epoch 3860, loss -88164.55\n", 552 | "Epoch 3870, loss -89936.71\n", 553 | "Epoch 3880, loss -91658.90\n", 554 | "Epoch 3890, loss -93232.55\n", 555 | "Epoch 3900, loss -94694.48\n", 556 | "Epoch 3910, loss -96044.77\n", 557 | "Epoch 3920, loss -97296.05\n", 558 | "Epoch 3930, loss -98472.09\n", 559 | "Epoch 3940, loss -99584.84\n", 560 | "Epoch 3950, loss -100607.41\n", 561 | "Epoch 3960, loss -101540.94\n", 562 | "Epoch 3970, loss -100977.95\n", 563 | "Epoch 3980, loss -99686.05\n", 564 | "Epoch 3990, loss -100127.66\n", 565 | "Epoch 4000, loss -92953.41\n", 566 | "Epoch 4010, loss -99552.77\n", 567 | "Epoch 4020, loss -99054.23\n", 568 | "Epoch 4030, loss -104116.86\n", 569 | "Epoch 4040, loss -52000.74\n", 570 | "Epoch 4050, loss -94998.57\n", 571 | "Epoch 4060, loss -90036.80\n", 572 | "Epoch 4070, loss -101566.39\n", 573 | "Epoch 4080, loss -93916.91\n", 574 | "Epoch 4090, loss -99944.02\n", 575 | "Epoch 4100, loss -103591.44\n", 576 | "Epoch 4110, loss -103210.05\n", 577 | "Epoch 4120, loss 
-105100.19\n", 578 | "Epoch 4130, loss -106069.65\n", 579 | "Epoch 4140, loss -34873.50\n", 580 | "Epoch 4150, loss -7480.66\n", 581 | "Epoch 4160, loss -74698.80\n", 582 | "Epoch 4170, loss -89828.90\n", 583 | "Epoch 4180, loss -94130.16\n", 584 | "Epoch 4190, loss -96908.14\n", 585 | "Epoch 4200, loss -98987.12\n", 586 | "Epoch 4210, loss -100603.50\n", 587 | "Epoch 4220, loss -102146.91\n", 588 | "Epoch 4230, loss -103498.73\n", 589 | "Epoch 4240, loss -104724.30\n", 590 | "Epoch 4250, loss -105777.87\n", 591 | "Epoch 4260, loss -106697.45\n", 592 | "Epoch 4270, loss -107464.04\n", 593 | "Epoch 4280, loss -107145.18\n", 594 | "Epoch 4290, loss 50179.43\n", 595 | "Epoch 4300, loss -76451.35\n", 596 | "Epoch 4310, loss -84666.50\n", 597 | "Epoch 4320, loss -93507.75\n", 598 | "Epoch 4330, loss -95469.19\n", 599 | "Epoch 4340, loss -97255.57\n", 600 | "Epoch 4350, loss -99428.41\n", 601 | "Epoch 4360, loss -101263.91\n", 602 | "Epoch 4370, loss -102956.08\n", 603 | "Epoch 4380, loss -104434.53\n", 604 | "Epoch 4390, loss -105733.59\n", 605 | "Epoch 4400, loss -106857.23\n", 606 | "Epoch 4410, loss -107808.33\n", 607 | "Epoch 4420, loss -108608.50\n", 608 | "Epoch 4430, loss -109273.43\n", 609 | "Epoch 4440, loss 40229.18\n", 610 | "Epoch 4450, loss 154976.84\n", 611 | "Epoch 4460, loss -25104.65\n", 612 | "Epoch 4470, loss -41701.54\n", 613 | "Epoch 4480, loss -46744.01\n", 614 | "Epoch 4490, loss -48865.25\n", 615 | "Epoch 4500, loss -50956.71\n", 616 | "Epoch 4510, loss -52989.01\n", 617 | "Epoch 4520, loss -54856.98\n", 618 | "Epoch 4530, loss -56613.54\n", 619 | "Epoch 4540, loss -58306.95\n", 620 | "Epoch 4550, loss -59946.86\n", 621 | "Epoch 4560, loss -61556.13\n", 622 | "Epoch 4570, loss -63134.77\n", 623 | "Epoch 4580, loss -64694.33\n", 624 | "Epoch 4590, loss -66251.62\n", 625 | "Epoch 4600, loss -67838.18\n", 626 | "Epoch 4610, loss -69392.95\n", 627 | "Epoch 4620, loss -70971.81\n", 628 | "Epoch 4630, loss -72575.58\n", 629 | "Epoch 4640, loss -74287.96\n", 630 | "Epoch 4650, loss -76096.46\n", 631 | "Epoch 4660, loss -77885.64\n", 632 | "Epoch 4670, loss -79722.05\n", 633 | "Epoch 4680, loss -81563.27\n", 634 | "Epoch 4690, loss -83376.52\n", 635 | "Epoch 4700, loss -85138.78\n", 636 | "Epoch 4710, loss -86869.65\n", 637 | "Epoch 4720, loss -88534.77\n", 638 | "Epoch 4730, loss -90138.80\n", 639 | "Epoch 4740, loss -91715.55\n", 640 | "Epoch 4750, loss -93254.80\n", 641 | "Epoch 4760, loss -94734.90\n", 642 | "Epoch 4770, loss -96136.34\n", 643 | "Epoch 4780, loss -97500.91\n", 644 | "Epoch 4790, loss -98827.72\n", 645 | "Epoch 4800, loss -100062.82\n", 646 | "Epoch 4810, loss -101209.55\n", 647 | "Epoch 4820, loss -102291.44\n", 648 | "Epoch 4830, loss -103297.22\n", 649 | "Epoch 4840, loss -104232.45\n", 650 | "Epoch 4850, loss -105085.96\n", 651 | "Epoch 4860, loss -105858.30\n", 652 | "Epoch 4870, loss -106559.77\n", 653 | "Epoch 4880, loss -107192.84\n", 654 | "Epoch 4890, loss -107609.21\n", 655 | "Epoch 4900, loss -69385.87\n", 656 | "Epoch 4910, loss -100873.80\n", 657 | "Epoch 4920, loss -106412.80\n", 658 | "Epoch 4930, loss -105919.39\n", 659 | "Epoch 4940, loss -108033.14\n", 660 | "Epoch 4950, loss -108386.62\n", 661 | "Epoch 4960, loss -108525.30\n", 662 | "Epoch 4970, loss -47695.86\n", 663 | "Epoch 4980, loss -107796.94\n", 664 | "Epoch 4990, loss -102208.20\n", 665 | "Epoch 5000, loss -101055.08\n", 666 | "Epoch 5010, loss -108069.83\n", 667 | "Epoch 5020, loss -108257.25\n", 668 | "Epoch 5030, loss -108774.73\n", 669 | "Epoch 5040, loss -109636.82\n", 670 | 
"Epoch 5050, loss -109725.91\n", 671 | "Epoch 5060, loss -109514.97\n", 672 | "Epoch 5070, loss -52975.13\n", 673 | "Epoch 5080, loss -95365.26\n", 674 | "Epoch 5090, loss -104794.48\n", 675 | "Epoch 5100, loss -105140.96\n", 676 | "Epoch 5110, loss -107128.32\n", 677 | "Epoch 5120, loss -108073.70\n", 678 | "Epoch 5130, loss -108947.34\n", 679 | "Epoch 5140, loss -109783.16\n", 680 | "Epoch 5150, loss -110474.47\n", 681 | "Epoch 5160, loss -111075.19\n", 682 | "Epoch 5170, loss -111429.46\n", 683 | "Epoch 5180, loss 83548.16\n", 684 | "Epoch 5190, loss -41627.70\n", 685 | "Epoch 5200, loss -87061.01\n", 686 | "Epoch 5210, loss -89843.94\n", 687 | "Epoch 5220, loss -90954.05\n", 688 | "Epoch 5230, loss -93179.61\n", 689 | "Epoch 5240, loss -95698.82\n", 690 | "Epoch 5250, loss -98009.93\n", 691 | "Epoch 5260, loss -99943.44\n", 692 | "Epoch 5270, loss -101844.41\n", 693 | "Epoch 5280, loss -103573.16\n", 694 | "Epoch 5290, loss -105158.73\n", 695 | "Epoch 5300, loss -106580.29\n", 696 | "Epoch 5310, loss -107845.62\n", 697 | "Epoch 5320, loss -108954.88\n", 698 | "Epoch 5330, loss -109920.48\n", 699 | "Epoch 5340, loss -110744.90\n", 700 | "Epoch 5350, loss -111441.09\n", 701 | "Epoch 5360, loss -112030.77\n", 702 | "Epoch 5370, loss -112522.66\n", 703 | "Epoch 5380, loss -112929.64\n", 704 | "Epoch 5390, loss -97046.19\n", 705 | "Epoch 5400, loss 170803.30\n", 706 | "Epoch 5410, loss -22692.87\n", 707 | "Epoch 5420, loss -45831.59\n", 708 | "Epoch 5430, loss -51859.54\n", 709 | "Epoch 5440, loss -54495.63\n", 710 | "Epoch 5450, loss -57058.67\n", 711 | "Epoch 5460, loss -59167.44\n", 712 | "Epoch 5470, loss -61229.66\n", 713 | "Epoch 5480, loss -63261.40\n", 714 | "Epoch 5490, loss -65291.31\n", 715 | "Epoch 5500, loss -67308.91\n", 716 | "Epoch 5510, loss -69326.96\n", 717 | "Epoch 5520, loss -71363.73\n", 718 | "Epoch 5530, loss -73425.38\n", 719 | "Epoch 5540, loss -75458.30\n", 720 | "Epoch 5550, loss -77448.34\n", 721 | "Epoch 5560, loss -79408.96\n", 722 | "Epoch 5570, loss -81327.62\n", 723 | "Epoch 5580, loss -83219.09\n", 724 | "Epoch 5590, loss -85080.94\n", 725 | "Epoch 5600, loss -86903.84\n", 726 | "Epoch 5610, loss -88691.30\n", 727 | "Epoch 5620, loss -90456.96\n", 728 | "Epoch 5630, loss -92189.08\n", 729 | "Epoch 5640, loss -93839.98\n", 730 | "Epoch 5650, loss -95421.43\n", 731 | "Epoch 5660, loss -96938.85\n", 732 | "Epoch 5670, loss -98397.57\n", 733 | "Epoch 5680, loss -99782.93\n", 734 | "Epoch 5690, loss -101098.66\n", 735 | "Epoch 5700, loss -102341.05\n", 736 | "Epoch 5710, loss -103519.83\n", 737 | "Epoch 5720, loss -104622.14\n", 738 | "Epoch 5730, loss -105636.45\n", 739 | "Epoch 5740, loss -106575.08\n", 740 | "Epoch 5750, loss -107442.15\n", 741 | "Epoch 5760, loss -108231.09\n", 742 | "Epoch 5770, loss -108957.66\n", 743 | "Epoch 5780, loss -109612.02\n", 744 | "Epoch 5790, loss -110199.25\n", 745 | "Epoch 5800, loss -110724.89\n", 746 | "Epoch 5810, loss -111214.62\n", 747 | "Epoch 5820, loss -111650.59\n", 748 | "Epoch 5830, loss 5656.66\n", 749 | "Epoch 5840, loss -66662.41\n", 750 | "Epoch 5850, loss -39172.06\n", 751 | "Epoch 5860, loss -66872.05\n", 752 | "Epoch 5870, loss -76977.48\n", 753 | "Epoch 5880, loss -80942.12\n", 754 | "Epoch 5890, loss -83573.24\n", 755 | "Epoch 5900, loss -85763.07\n", 756 | "Epoch 5910, loss -87816.61\n", 757 | "Epoch 5920, loss -89809.44\n", 758 | "Epoch 5930, loss -91741.66\n", 759 | "Epoch 5940, loss -93607.09\n", 760 | "Epoch 5950, loss -95408.85\n", 761 | "Epoch 5960, loss -97137.91\n", 762 | "Epoch 5970, loss 
-98792.87\n", 763 | "Epoch 5980, loss -100368.09\n", 764 | "Epoch 5990, loss -101859.48\n", 765 | "Epoch 6000, loss -103259.14\n", 766 | "Epoch 6010, loss -104566.41\n", 767 | "Epoch 6020, loss -105774.88\n", 768 | "Epoch 6030, loss -106884.38\n", 769 | "Epoch 6040, loss -107897.21\n", 770 | "Epoch 6050, loss -108818.22\n", 771 | "Epoch 6060, loss -109651.15\n", 772 | "Epoch 6070, loss -110395.73\n", 773 | "Epoch 6080, loss -111062.98\n", 774 | "Epoch 6090, loss -111653.27\n", 775 | "Epoch 6100, loss -112171.78\n", 776 | "Epoch 6110, loss -112635.21\n", 777 | "Epoch 6120, loss -113040.23\n", 778 | "Epoch 6130, loss -113401.89\n", 779 | "Epoch 6140, loss -113714.83\n", 780 | "Epoch 6150, loss -108802.47\n", 781 | "Epoch 6160, loss -104949.93\n", 782 | "Epoch 6170, loss -58156.18\n", 783 | "Epoch 6180, loss -89268.23\n", 784 | "Epoch 6190, loss -100014.55\n", 785 | "Epoch 6200, loss -103692.32\n", 786 | "Epoch 6210, loss -104837.87\n", 787 | "Epoch 6220, loss -106163.09\n", 788 | "Epoch 6230, loss -107358.96\n", 789 | "Epoch 6240, loss -108626.21\n", 790 | "Epoch 6250, loss -109711.43\n", 791 | "Epoch 6260, loss -110674.66\n", 792 | "Epoch 6270, loss -111534.79\n", 793 | "Epoch 6280, loss -112284.39\n", 794 | "Epoch 6290, loss -112937.16\n", 795 | "Epoch 6300, loss -113500.41\n", 796 | "Epoch 6310, loss -113985.76\n", 797 | "Epoch 6320, loss -114403.71\n", 798 | "Epoch 6330, loss -114757.48\n", 799 | "Epoch 6340, loss -115053.96\n", 800 | "Epoch 6350, loss -80065.43\n", 801 | "Epoch 6360, loss 42780.45\n", 802 | "Epoch 6370, loss -74112.86\n", 803 | "Epoch 6380, loss -96772.39\n", 804 | "Epoch 6390, loss -96349.12\n", 805 | "Epoch 6400, loss -98951.05\n", 806 | "Epoch 6410, loss -101816.53\n", 807 | "Epoch 6420, loss -103573.51\n", 808 | "Epoch 6430, loss -105382.97\n", 809 | "Epoch 6440, loss -106949.32\n", 810 | "Epoch 6450, loss -108382.36\n", 811 | "Epoch 6460, loss -109691.45\n", 812 | "Epoch 6470, loss -110847.89\n", 813 | "Epoch 6480, loss -111862.71\n", 814 | "Epoch 6490, loss -112745.71\n", 815 | "Epoch 6500, loss -113504.03\n", 816 | "Epoch 6510, loss -114147.58\n", 817 | "Epoch 6520, loss -114685.12\n", 818 | "Epoch 6530, loss -115142.58\n", 819 | "Epoch 6540, loss -115522.51\n", 820 | "Epoch 6550, loss -115835.65\n", 821 | "Epoch 6560, loss -116072.12\n", 822 | "Epoch 6570, loss 63859.48\n", 823 | "Epoch 6580, loss 32298.74\n", 824 | "Epoch 6590, loss -75705.68\n", 825 | "Epoch 6600, loss -93918.87\n", 826 | "Epoch 6610, loss -92042.83\n", 827 | "Epoch 6620, loss -95133.46\n", 828 | "Epoch 6630, loss -98018.44\n", 829 | "Epoch 6640, loss -100114.11\n", 830 | "Epoch 6650, loss -102124.23\n", 831 | "Epoch 6660, loss -104077.86\n", 832 | "Epoch 6670, loss -105891.88\n", 833 | "Epoch 6680, loss -107557.69\n", 834 | "Epoch 6690, loss -109063.29\n", 835 | "Epoch 6700, loss -110411.93\n", 836 | "Epoch 6710, loss -111602.75\n", 837 | "Epoch 6720, loss -112643.05\n", 838 | "Epoch 6730, loss -113537.16\n", 839 | "Epoch 6740, loss -114300.48\n", 840 | "Epoch 6750, loss -114943.32\n", 841 | "Epoch 6760, loss -115480.37\n", 842 | "Epoch 6770, loss -115918.21\n", 843 | "Epoch 6780, loss -116283.47\n", 844 | "Epoch 6790, loss -116577.47\n", 845 | "Epoch 6800, loss -116783.80\n", 846 | "Epoch 6810, loss -73295.74\n", 847 | "Epoch 6820, loss -106503.05\n", 848 | "Epoch 6830, loss -110904.19\n", 849 | "Epoch 6840, loss -111974.62\n", 850 | "Epoch 6850, loss -112636.39\n", 851 | "Epoch 6860, loss -114095.71\n", 852 | "Epoch 6870, loss -114944.34\n", 853 | "Epoch 6880, loss -115682.11\n", 854 | 
"Epoch 6890, loss -116215.46\n", 855 | "Epoch 6900, loss -116671.80\n", 856 | "Epoch 6910, loss -117015.91\n", 857 | "Epoch 6920, loss -117292.39\n", 858 | "Epoch 6930, loss -100912.11\n", 859 | "Epoch 6940, loss -15685.92\n", 860 | "Epoch 6950, loss -96513.90\n", 861 | "Epoch 6960, loss -102837.55\n", 862 | "Epoch 6970, loss -106339.73\n", 863 | "Epoch 6980, loss -108053.76\n", 864 | "Epoch 6990, loss -109473.07\n", 865 | "Epoch 7000, loss -111090.58\n", 866 | "Epoch 7010, loss -112455.77\n", 867 | "Epoch 7020, loss -113676.30\n", 868 | "Epoch 7030, loss -114685.65\n", 869 | "Epoch 7040, loss -115529.86\n", 870 | "Epoch 7050, loss -116223.52\n", 871 | "Epoch 7060, loss -116779.95\n", 872 | "Epoch 7070, loss -117222.30\n", 873 | "Epoch 7080, loss -117572.12\n", 874 | "Epoch 7090, loss -117839.12\n", 875 | "Epoch 7100, loss -118048.28\n", 876 | "Epoch 7110, loss -84183.46\n", 877 | "Epoch 7120, loss -8931.41\n", 878 | "Epoch 7130, loss -72306.54\n", 879 | "Epoch 7140, loss -82578.55\n", 880 | "Epoch 7150, loss -85630.09\n", 881 | "Epoch 7160, loss -88073.62\n", 882 | "Epoch 7170, loss -90509.62\n", 883 | "Epoch 7180, loss -93213.04\n", 884 | "Epoch 7190, loss -95900.32\n", 885 | "Epoch 7200, loss -98422.82\n", 886 | "Epoch 7210, loss -100841.41\n", 887 | "Epoch 7220, loss -103117.41\n", 888 | "Epoch 7230, loss -105240.80\n", 889 | "Epoch 7240, loss -107195.78\n", 890 | "Epoch 7250, loss -108972.08\n", 891 | "Epoch 7260, loss -110569.65\n", 892 | "Epoch 7270, loss -111990.98\n", 893 | "Epoch 7280, loss -113244.09\n", 894 | "Epoch 7290, loss -114320.58\n", 895 | "Epoch 7300, loss -115231.22\n", 896 | "Epoch 7310, loss -115991.73\n", 897 | "Epoch 7320, loss -116616.80\n", 898 | "Epoch 7330, loss -117123.64\n", 899 | "Epoch 7340, loss -117537.57\n", 900 | "Epoch 7350, loss -117860.09\n", 901 | "Epoch 7360, loss -118133.37\n", 902 | "Epoch 7370, loss -118030.54\n", 903 | "Epoch 7380, loss -47835.49\n", 904 | "Epoch 7390, loss -101363.86\n", 905 | "Epoch 7400, loss -83497.20\n", 906 | "Epoch 7410, loss -94927.00\n", 907 | "Epoch 7420, loss -99580.90\n", 908 | "Epoch 7430, loss -102781.48\n", 909 | "Epoch 7440, loss -105393.71\n", 910 | "Epoch 7450, loss -107271.02\n", 911 | "Epoch 7460, loss -109211.84\n", 912 | "Epoch 7470, loss -110933.73\n", 913 | "Epoch 7480, loss -112452.00\n", 914 | "Epoch 7490, loss -113768.66\n", 915 | "Epoch 7500, loss -114892.98\n", 916 | "Epoch 7510, loss -115834.88\n", 917 | "Epoch 7520, loss -116615.83\n", 918 | "Epoch 7530, loss -117250.87\n", 919 | "Epoch 7540, loss -117763.14\n", 920 | "Epoch 7550, loss -118161.00\n", 921 | "Epoch 7560, loss -118486.39\n", 922 | "Epoch 7570, loss -118743.43\n", 923 | "Epoch 7580, loss -118900.87\n", 924 | "Epoch 7590, loss 64195.87\n", 925 | "Epoch 7600, loss -65829.59\n", 926 | "Epoch 7610, loss -99492.86\n", 927 | "Epoch 7620, loss -101486.91\n", 928 | "Epoch 7630, loss -102585.97\n", 929 | "Epoch 7640, loss -106663.02\n", 930 | "Epoch 7650, loss -108187.64\n", 931 | "Epoch 7660, loss -110183.18\n", 932 | "Epoch 7670, loss -111983.46\n", 933 | "Epoch 7680, loss -113478.40\n", 934 | "Epoch 7690, loss -114780.83\n", 935 | "Epoch 7700, loss -115872.18\n", 936 | "Epoch 7710, loss -116771.15\n", 937 | "Epoch 7720, loss -117498.16\n", 938 | "Epoch 7730, loss -118083.88\n", 939 | "Epoch 7740, loss -118535.88\n", 940 | "Epoch 7750, loss -118889.20\n", 941 | "Epoch 7760, loss -119149.26\n", 942 | "Epoch 7770, loss -119353.18\n", 943 | "Epoch 7780, loss -99775.72\n", 944 | "Epoch 7790, loss -46727.00\n", 945 | "Epoch 7800, loss 
-111086.62\n", 946 | "Epoch 7810, loss -105316.71\n", 947 | "Epoch 7820, loss -111437.66\n", 948 | "Epoch 7830, loss -112634.93\n", 949 | "Epoch 7840, loss -114444.47\n", 950 | "Epoch 7850, loss -115644.07\n", 951 | "Epoch 7860, loss -116751.44\n", 952 | "Epoch 7870, loss -117608.38\n", 953 | "Epoch 7880, loss -118289.20\n", 954 | "Epoch 7890, loss -118810.38\n", 955 | "Epoch 7900, loss -119215.91\n", 956 | "Epoch 7910, loss -119524.48\n", 957 | "Epoch 7920, loss -119755.83\n", 958 | "Epoch 7930, loss -119614.84\n", 959 | "Epoch 7940, loss -103011.78\n", 960 | "Epoch 7950, loss -75541.19\n", 961 | "Epoch 7960, loss -103726.48\n", 962 | "Epoch 7970, loss -105124.46\n", 963 | "Epoch 7980, loss -109073.39\n", 964 | "Epoch 7990, loss -111198.39\n", 965 | "Epoch 8000, loss -113063.79\n", 966 | "Epoch 8010, loss -114677.93\n", 967 | "Epoch 8020, loss -116034.30\n", 968 | "Epoch 8030, loss -117143.84\n", 969 | "Epoch 8040, loss -118036.09\n", 970 | "Epoch 8050, loss -118733.73\n", 971 | "Epoch 8060, loss -119267.59\n", 972 | "Epoch 8070, loss -119675.54\n", 973 | "Epoch 8080, loss -119969.70\n", 974 | "Epoch 8090, loss -120188.27\n", 975 | "Epoch 8100, loss -114375.14\n", 976 | "Epoch 8110, loss -67193.92\n", 977 | "Epoch 8120, loss -103463.41\n", 978 | "Epoch 8130, loss -108949.55\n", 979 | "Epoch 8140, loss -110628.45\n", 980 | "Epoch 8150, loss -111818.41\n", 981 | "Epoch 8160, loss -113481.04\n", 982 | "Epoch 8170, loss -115118.09\n", 983 | "Epoch 8180, loss -116507.28\n", 984 | "Epoch 8190, loss -117645.30\n", 985 | "Epoch 8200, loss -118538.22\n", 986 | "Epoch 8210, loss -119230.86\n", 987 | "Epoch 8220, loss -119751.25\n", 988 | "Epoch 8230, loss -120136.62\n", 989 | "Epoch 8240, loss -120429.91\n", 990 | "Epoch 8250, loss -120538.88\n", 991 | "Epoch 8260, loss 184777.78\n", 992 | "Epoch 8270, loss -111113.37\n", 993 | "Epoch 8280, loss -111335.44\n", 994 | "Epoch 8290, loss -111665.68\n", 995 | "Epoch 8300, loss -114595.43\n", 996 | "Epoch 8310, loss -115903.51\n", 997 | "Epoch 8320, loss -117137.14\n", 998 | "Epoch 8330, loss -118243.72\n", 999 | "Epoch 8340, loss -119088.33\n", 1000 | "Epoch 8350, loss -119759.71\n", 1001 | "Epoch 8360, loss -120249.80\n", 1002 | "Epoch 8370, loss -120612.55\n", 1003 | "Epoch 8380, loss -120861.72\n", 1004 | "Epoch 8390, loss -118918.66\n", 1005 | "Epoch 8400, loss -44942.49\n", 1006 | "Epoch 8410, loss -114123.66\n", 1007 | "Epoch 8420, loss -111895.97\n", 1008 | "Epoch 8430, loss -114036.12\n", 1009 | "Epoch 8440, loss -115731.43\n", 1010 | "Epoch 8450, loss -117003.55\n", 1011 | "Epoch 8460, loss -118211.25\n", 1012 | "Epoch 8470, loss -119181.43\n", 1013 | "Epoch 8480, loss -119919.43\n", 1014 | "Epoch 8490, loss -120484.05\n", 1015 | "Epoch 8500, loss -120886.65\n", 1016 | "Epoch 8510, loss -121175.90\n", 1017 | "Epoch 8520, loss -121205.37\n", 1018 | "Epoch 8530, loss 66279.67\n", 1019 | "Epoch 8540, loss -106688.22\n", 1020 | "Epoch 8550, loss -105536.00\n", 1021 | "Epoch 8560, loss -104695.39\n", 1022 | "Epoch 8570, loss -108694.82\n", 1023 | "Epoch 8580, loss -110973.62\n", 1024 | "Epoch 8590, loss -113155.57\n", 1025 | "Epoch 8600, loss -115110.91\n", 1026 | "Epoch 8610, loss -116783.18\n", 1027 | "Epoch 8620, loss -118141.80\n", 1028 | "Epoch 8630, loss -119229.76\n", 1029 | "Epoch 8640, loss -120057.34\n", 1030 | "Epoch 8650, loss -120677.54\n", 1031 | "Epoch 8660, loss -121132.95\n", 1032 | "Epoch 8670, loss -121456.93\n", 1033 | "Epoch 8680, loss -121679.68\n", 1034 | "Epoch 8690, loss -104398.47\n", 1035 | "Epoch 8700, loss -57141.88\n", 
1036 | "Epoch 8710, loss -102278.43\n", 1037 | "Epoch 8720, loss -108114.41\n", 1038 | "Epoch 8730, loss -110978.48\n", 1039 | "Epoch 8740, loss -113095.73\n", 1040 | "Epoch 8750, loss -115049.16\n", 1041 | "Epoch 8760, loss -116840.65\n", 1042 | "Epoch 8770, loss -118310.27\n", 1043 | "Epoch 8780, loss -119440.48\n", 1044 | "Epoch 8790, loss -120323.47\n", 1045 | "Epoch 8800, loss -120983.71\n", 1046 | "Epoch 8810, loss -121463.69\n", 1047 | "Epoch 8820, loss -121810.71\n", 1048 | "Epoch 8830, loss -122016.45\n", 1049 | "Epoch 8840, loss -37323.09\n", 1050 | "Epoch 8850, loss -105820.95\n", 1051 | "Epoch 8860, loss -112040.70\n", 1052 | "Epoch 8870, loss -113147.72\n", 1053 | "Epoch 8880, loss -116080.71\n", 1054 | "Epoch 8890, loss -117591.04\n", 1055 | "Epoch 8900, loss -118807.54\n", 1056 | "Epoch 8910, loss -119996.21\n", 1057 | "Epoch 8920, loss -120910.84\n", 1058 | "Epoch 8930, loss -121601.50\n", 1059 | "Epoch 8940, loss -122123.59\n", 1060 | "Epoch 8950, loss -122551.16\n", 1061 | "Epoch 8960, loss -122840.71\n", 1062 | "Epoch 8970, loss 121901.48\n", 1063 | "Epoch 8980, loss -88795.62\n", 1064 | "Epoch 8990, loss -103422.36\n", 1065 | "Epoch 9000, loss -105195.09\n", 1066 | "Epoch 9010, loss -108088.33\n", 1067 | "Epoch 9020, loss -110715.51\n", 1068 | "Epoch 9030, loss -113259.52\n", 1069 | "Epoch 9040, loss -115477.83\n", 1070 | "Epoch 9050, loss -117404.30\n", 1071 | "Epoch 9060, loss -119024.55\n", 1072 | "Epoch 9070, loss -120362.77\n", 1073 | "Epoch 9080, loss -121434.71\n", 1074 | "Epoch 9090, loss -122258.57\n", 1075 | "Epoch 9100, loss -122906.34\n", 1076 | "Epoch 9110, loss -123428.41\n", 1077 | "Epoch 9120, loss -123641.44\n", 1078 | "Epoch 9130, loss 22052.00\n", 1079 | "Epoch 9140, loss -106779.45\n", 1080 | "Epoch 9150, loss -105741.48\n", 1081 | "Epoch 9160, loss -107795.78\n", 1082 | "Epoch 9170, loss -111509.61\n", 1083 | "Epoch 9180, loss -113945.46\n", 1084 | "Epoch 9190, loss -116126.55\n", 1085 | "Epoch 9200, loss -118112.77\n", 1086 | "Epoch 9210, loss -119820.16\n", 1087 | "Epoch 9220, loss -121232.79\n", 1088 | "Epoch 9230, loss -122367.16\n", 1089 | "Epoch 9240, loss -123259.04\n", 1090 | "Epoch 9250, loss -123944.57\n", 1091 | "Epoch 9260, loss -124468.65\n", 1092 | "Epoch 9270, loss -123448.89\n", 1093 | "Epoch 9280, loss -57126.57\n", 1094 | "Epoch 9290, loss -102568.71\n", 1095 | "Epoch 9300, loss -105603.94\n", 1096 | "Epoch 9310, loss -111867.22\n", 1097 | "Epoch 9320, loss -113318.23\n", 1098 | "Epoch 9330, loss -115809.19\n", 1099 | "Epoch 9340, loss -118005.62\n", 1100 | "Epoch 9350, loss -119948.33\n", 1101 | "Epoch 9360, loss -121486.52\n", 1102 | "Epoch 9370, loss -122744.26\n", 1103 | "Epoch 9380, loss -123728.66\n", 1104 | "Epoch 9390, loss -124496.58\n", 1105 | "Epoch 9400, loss -125096.15\n", 1106 | "Epoch 9410, loss -125341.29\n", 1107 | "Epoch 9420, loss 419679.56\n", 1108 | "Epoch 9430, loss -20509.63\n", 1109 | "Epoch 9440, loss -68183.92\n", 1110 | "Epoch 9450, loss -73763.55\n", 1111 | "Epoch 9460, loss -76417.11\n", 1112 | "Epoch 9470, loss -79788.40\n", 1113 | "Epoch 9480, loss -82700.73\n", 1114 | "Epoch 9490, loss -85698.94\n", 1115 | "Epoch 9500, loss -88672.09\n", 1116 | "Epoch 9510, loss -91592.95\n", 1117 | "Epoch 9520, loss -94436.27\n", 1118 | "Epoch 9530, loss -97197.32\n", 1119 | "Epoch 9540, loss -99870.03\n", 1120 | "Epoch 9550, loss -102436.72\n", 1121 | "Epoch 9560, loss -104889.84\n", 1122 | "Epoch 9570, loss -107221.86\n", 1123 | "Epoch 9580, loss -109435.41\n", 1124 | "Epoch 9590, loss -111541.16\n", 1125 | "Epoch 
9600, loss -113534.41\n", 1126 | "Epoch 9610, loss -115415.64\n", 1127 | "Epoch 9620, loss -117091.48\n", 1128 | "Epoch 9630, loss -118593.51\n", 1129 | "Epoch 9640, loss -119920.29\n", 1130 | "Epoch 9650, loss -121088.91\n", 1131 | "Epoch 9660, loss -122102.80\n", 1132 | "Epoch 9670, loss -122983.23\n", 1133 | "Epoch 9680, loss -123740.38\n", 1134 | "Epoch 9690, loss -124386.51\n", 1135 | "Epoch 9700, loss -124933.59\n", 1136 | "Epoch 9710, loss -125400.37\n", 1137 | "Epoch 9720, loss -125797.27\n", 1138 | "Epoch 9730, loss -126022.37\n", 1139 | "Epoch 9740, loss 64298.67\n", 1140 | "Epoch 9750, loss -93096.83\n", 1141 | "Epoch 9760, loss -110061.53\n", 1142 | "Epoch 9770, loss -110898.59\n", 1143 | "Epoch 9780, loss -112989.97\n", 1144 | "Epoch 9790, loss -115281.79\n", 1145 | "Epoch 9800, loss -117095.16\n", 1146 | "Epoch 9810, loss -118799.05\n", 1147 | "Epoch 9820, loss -120284.80\n", 1148 | "Epoch 9830, loss -121597.21\n", 1149 | "Epoch 9840, loss -122731.80\n", 1150 | "Epoch 9850, loss -123715.04\n", 1151 | "Epoch 9860, loss -124554.30\n", 1152 | "Epoch 9870, loss -125258.61\n", 1153 | "Epoch 9880, loss -125846.72\n", 1154 | "Epoch 9890, loss -126323.46\n", 1155 | "Epoch 9900, loss -126722.23\n", 1156 | "Epoch 9910, loss -124622.50\n", 1157 | "Epoch 9920, loss 33982.61\n", 1158 | "Epoch 9930, loss -103630.72\n", 1159 | "Epoch 9940, loss -103730.55\n", 1160 | "Epoch 9950, loss -108173.68\n", 1161 | "Epoch 9960, loss -110663.20\n", 1162 | "Epoch 9970, loss -113044.00\n", 1163 | "Epoch 9980, loss -115109.23\n", 1164 | "Epoch 9990, loss -117048.94\n", 1165 | "9905\n" 1166 | ] 1167 | } 1168 | ], 1169 | "source": [ 1170 | "optimizer = tf.keras.optimizers.Adam(learning_rate = 0.005)\n", 1171 | "spe = 25\n", 1172 | "string = './model/TRIALmodel1000C3_S1'\n", 1173 | "\n", 1174 | "@tf.function\n", 1175 | "def train_step():\n", 1176 | " with tf.GradientTape() as tape:\n", 1177 | " loss_value = 0\n", 1178 | " for i in range(0,spe): \n", 1179 | " logits = model({\"inputsB\":u_in, \"inputsT\":x_t_in}, training=True)\n", 1180 | " loss_value = loss_value + negloglikelihood(s_in, logits)\n", 1181 | " loss_value = loss_value*(1/spe)\n", 1182 | " grads = tape.gradient(loss_value, model.trainable_weights)\n", 1183 | " optimizer.apply_gradients(zip(grads, model.trainable_weights))\n", 1184 | " return loss_value\n", 1185 | "\n", 1186 | "epochs = 10000\n", 1187 | "loss = np.zeros(epochs)\n", 1188 | "\n", 1189 | "for epoch in range(epochs):\n", 1190 | " loss_value = train_step()\n", 1191 | " loss[epoch] = loss_value.numpy()\n", 1192 | " if loss[epoch] <= np.min(loss[0:epoch+1]):\n", 1193 | " model.save_weights(string)\n", 1194 | " last_saved_wt = epoch\n", 1195 | " if epoch%10 == 0:\n", 1196 | " print(\"Epoch %d, loss %.2f\" % (epoch, loss[epoch]))\n", 1197 | "\n", 1198 | "print(last_saved_wt)" 1199 | ] 1200 | }, 1201 | { 1202 | "cell_type": "code", 1203 | "execution_count": null, 1204 | "id": "e3eb53b1-f0fa-4fce-bdbe-a87281efaa1f", 1205 | "metadata": {}, 1206 | "outputs": [], 1207 | "source": [] 1208 | }, 1209 | { 1210 | "cell_type": "code", 1211 | "execution_count": 5, 1212 | "id": "9cebf033-567b-4508-a037-5951bf7adb06", 1213 | "metadata": {}, 1214 | "outputs": [ 1215 | { 1216 | "name": "stdout", 1217 | "output_type": "stream", 1218 | "text": [ 1219 | "0\n", 1220 | "5\n", 1221 | "10\n", 1222 | "15\n", 1223 | "20\n", 1224 | "25\n", 1225 | "30\n", 1226 | "35\n", 1227 | "40\n", 1228 | "45\n", 1229 | "50\n", 1230 | "55\n", 1231 | "60\n", 1232 | "65\n", 1233 | "70\n", 1234 | "75\n", 1235 | "80\n", 1236 | 
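[Editor's note on the training cell above: the `train_step` averages the negative log-likelihood over `spe = 25` stochastic forward passes of the variational model, a Monte-Carlo estimate of the expected loss under the weight posterior, and the loop re-saves the weights whenever the running loss reaches a new minimum, which is why the final printout reports the epoch of the last saved checkpoint (9905 here). Below is a minimal, self-contained sketch of the same pattern on a toy 1-D regression problem; `toy_model`, `n_mc`, and the synthetic data are illustrative assumptions, not the repository's DeepONet.]

# Minimal sketch of the MC-averaged NLL training step, assuming a small
# DenseFlipout model with a Normal output head (same head construction,
# 0.001 + softplus(scale), as used in these notebooks).
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

x = np.random.rand(256, 1).astype("float32")
y = (np.sin(2 * np.pi * x) + 0.05 * np.random.randn(256, 1)).astype("float32")

toy_model = tf.keras.Sequential([
    tfp.layers.DenseFlipout(16, activation="relu"),
    tfp.layers.DenseFlipout(2),
    tfp.layers.DistributionLambda(
        lambda p: tfd.Normal(loc=p[:, 0:1],
                             scale=1e-3 + tf.math.softplus(p[:, 1:2]))),
])

optimizer = tf.keras.optimizers.Adam(learning_rate=0.005)
n_mc = 25  # stochastic forward passes per step, as spe = 25 in the cell above

def nll(y_true, dist):
    # Negative log-likelihood plus the KL terms the Flipout layers accumulate,
    # scaled by the dataset size as in the notebook's negloglikelihood.
    return tf.reduce_sum(-dist.log_prob(y_true)) + sum(toy_model.losses) / len(x)

@tf.function
def train_step():
    with tf.GradientTape() as tape:
        loss = 0.0
        for _ in range(n_mc):  # average the loss over n_mc weight samples
            dist = toy_model(x, training=True)
            loss = loss + nll(y, dist)
        loss = loss / n_mc
    grads = tape.gradient(loss, toy_model.trainable_weights)
    optimizer.apply_gradients(zip(grads, toy_model.trainable_weights))
    return loss

for epoch in range(3):  # usage: one line of loss per step
    print(float(train_step()))

[Because each step draws fresh weight samples, the logged loss is itself a noisy estimate; the occasional large positive spikes visible in the epoch logs are consistent with that stochasticity, and the best-loss checkpointing guards against keeping a spiked state.]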
"85\n", 1237 | "90\n", 1238 | "95\n", 1239 | "\n", 1240 | "0.0016519009111849896\n", 1241 | "0.23837758612778512\n", 1242 | "0.006929766082535414\n" 1243 | ] 1244 | } 1245 | ], 1246 | "source": [ 1247 | "data_test = spi.loadmat('test_PDE_DR.mat')\n", 1248 | "\n", 1249 | "u_in_test = data_test['X_test0']\n", 1250 | "x_t_in_test = data_test['X_test1']\n", 1251 | "s_in_test = data_test['y_test']\n", 1252 | "\n", 1253 | "nsamples = 100\n", 1254 | "pred = np.zeros([nsamples,1000000])\n", 1255 | "for i in range(0,nsamples):\n", 1256 | " if i%5 == 0:\n", 1257 | " print(i)\n", 1258 | " pred[i,:] = np.squeeze((model({\"inputsB\":u_in_test, \"inputsT\":x_t_in_test})).sample(1))\n", 1259 | " \n", 1260 | "print()\n", 1261 | "print(np.mean((s_in_test-np.mean(pred, axis = 0)[..., np.newaxis])**2))\n", 1262 | "print(np.mean((s_in_test)**2))\n", 1263 | "print(np.mean((s_in_test-np.mean(pred, axis = 0)[..., np.newaxis])**2)/np.mean((s_in_test)**2)) " 1264 | ] 1265 | }, 1266 | { 1267 | "cell_type": "code", 1268 | "execution_count": 6, 1269 | "id": "1181fd7b-12ec-4383-a47e-127a22ffdd7a", 1270 | "metadata": {}, 1271 | "outputs": [], 1272 | "source": [ 1273 | "model.save_weights('./model/TRIALmodel1000C3_S1')" 1274 | ] 1275 | }, 1276 | { 1277 | "cell_type": "code", 1278 | "execution_count": null, 1279 | "id": "ee780291-74d9-44a1-907c-25f20f94b06b", 1280 | "metadata": {}, 1281 | "outputs": [], 1282 | "source": [] 1283 | } 1284 | ], 1285 | "metadata": { 1286 | "kernelspec": { 1287 | "display_name": "Python 3 (ipykernel)", 1288 | "language": "python", 1289 | "name": "python3" 1290 | }, 1291 | "language_info": { 1292 | "codemirror_mode": { 1293 | "name": "ipython", 1294 | "version": 3 1295 | }, 1296 | "file_extension": ".py", 1297 | "mimetype": "text/x-python", 1298 | "name": "python", 1299 | "nbconvert_exporter": "python", 1300 | "pygments_lexer": "ipython3", 1301 | "version": "3.9.13" 1302 | } 1303 | }, 1304 | "nbformat": 4, 1305 | "nbformat_minor": 5 1306 | } 1307 | -------------------------------------------------------------------------------- /CS-IV/CS_IV_D_DeepONet_GitHub.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6dc23078-05ca-4bbe-bcb0-e21d2a079cf7", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((100000, 100), (100000, 2), (100000, 1))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "%reset -f\n", 22 | "import h5py\n", 23 | "import time as t\n", 24 | "import numpy as np\n", 25 | "import scipy as sp\n", 26 | "import scipy.io as spi\n", 27 | "import tensorflow as tf\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "\n", 30 | "data_train = spi.loadmat('train_PDE_ADVD.mat')\n", 31 | "\n", 32 | "u_in = data_train['X_train0']\n", 33 | "x_t_in = data_train['X_train1']\n", 34 | "s_in = data_train['y_train']\n", 35 | "\n", 36 | "u_in.shape, x_t_in.shape, s_in.shape" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "cc8fdc38-c3cb-47e3-ab99-f0870272984a", 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "Model: \"model\"\n", 50 | "__________________________________________________________________________________________________\n", 51 | " Layer (type) Output Shape Param # Connected to \n", 52 | 
"==================================================================================================\n", 53 | " inputsB (InputLayer) [(None, 100)] 0 [] \n", 54 | " \n", 55 | " inputsT (InputLayer) [(None, 2)] 0 [] \n", 56 | " \n", 57 | " dense (Dense) (None, 35) 3535 ['inputsB[0][0]'] \n", 58 | " \n", 59 | " dense_3 (Dense) (None, 35) 105 ['inputsT[0][0]'] \n", 60 | " \n", 61 | " dense_1 (Dense) (None, 35) 1260 ['dense[0][0]'] \n", 62 | " \n", 63 | " dense_4 (Dense) (None, 35) 1260 ['dense_3[0][0]'] \n", 64 | " \n", 65 | " dense_2 (Dense) (None, 35) 1260 ['dense_1[0][0]'] \n", 66 | " \n", 67 | " dense_5 (Dense) (None, 35) 1260 ['dense_4[0][0]'] \n", 68 | " \n", 69 | " lambda (Lambda) (None, 1) 0 ['dense_2[0][0]', \n", 70 | " 'dense_5[0][0]'] \n", 71 | " \n", 72 | " dense_6 (Dense) (None, 1) 2 ['lambda[0][0]'] \n", 73 | " \n", 74 | "==================================================================================================\n", 75 | "Total params: 8,682\n", 76 | "Trainable params: 8,682\n", 77 | "Non-trainable params: 0\n", 78 | "__________________________________________________________________________________________________\n", 79 | "Iteration: 0\n", 80 | "1/1 [==============================] - 1s 846ms/step - loss: 0.6309\n", 81 | "Iteration: 250\n", 82 | "1/1 [==============================] - 0s 335ms/step - loss: 0.1295\n", 83 | "Iteration: 500\n", 84 | "1/1 [==============================] - 0s 381ms/step - loss: 0.0753\n", 85 | "Iteration: 750\n", 86 | "1/1 [==============================] - 0s 362ms/step - loss: 0.0498\n", 87 | "Iteration: 1000\n", 88 | "1/1 [==============================] - 0s 366ms/step - loss: 0.0358\n", 89 | "Iteration: 1250\n", 90 | "1/1 [==============================] - 0s 269ms/step - loss: 0.0271\n", 91 | "Iteration: 1500\n", 92 | "1/1 [==============================] - 0s 304ms/step - loss: 0.0213\n", 93 | "Iteration: 1750\n", 94 | "1/1 [==============================] - 0s 301ms/step - loss: 0.0173\n", 95 | "Iteration: 2000\n", 96 | "1/1 [==============================] - 0s 102ms/step - loss: 0.0145\n", 97 | "Iteration: 2250\n", 98 | "1/1 [==============================] - 0s 249ms/step - loss: 0.0117\n", 99 | "Iteration: 2500\n", 100 | "1/1 [==============================] - 0s 101ms/step - loss: 0.0098\n", 101 | "Iteration: 2750\n", 102 | "1/1 [==============================] - 0s 114ms/step - loss: 0.0083\n", 103 | "Iteration: 3000\n", 104 | "1/1 [==============================] - 0s 68ms/step - loss: 0.0073\n", 105 | "Iteration: 3250\n", 106 | "1/1 [==============================] - 0s 316ms/step - loss: 0.0064\n", 107 | "Iteration: 3500\n", 108 | "1/1 [==============================] - 0s 69ms/step - loss: 0.0057\n", 109 | "Iteration: 3750\n", 110 | "1/1 [==============================] - 0s 131ms/step - loss: 0.0052\n", 111 | "Iteration: 4000\n", 112 | "1/1 [==============================] - 0s 169ms/step - loss: 0.0047\n", 113 | "Iteration: 4250\n", 114 | "1/1 [==============================] - 0s 53ms/step - loss: 0.0043\n", 115 | "Iteration: 4500\n", 116 | "1/1 [==============================] - 0s 38ms/step - loss: 0.0040\n", 117 | "Iteration: 4750\n", 118 | "1/1 [==============================] - 0s 316ms/step - loss: 0.0037\n", 119 | "Iteration: 5000\n", 120 | "1/1 [==============================] - 0s 417ms/step - loss: 0.0034\n", 121 | "Iteration: 5250\n", 122 | "1/1 [==============================] - 0s 370ms/step - loss: 0.0032\n", 123 | "Iteration: 5500\n", 124 | "1/1 [==============================] - 0s 52ms/step - loss: 
0.0030\n", 125 | "Iteration: 5750\n", 126 | "1/1 [==============================] - 0s 47ms/step - loss: 0.0028\n", 127 | "Iteration: 6000\n", 128 | "1/1 [==============================] - 0s 116ms/step - loss: 0.0027\n", 129 | "Iteration: 6250\n", 130 | "1/1 [==============================] - 0s 100ms/step - loss: 0.0026\n", 131 | "Iteration: 6500\n", 132 | "1/1 [==============================] - 0s 38ms/step - loss: 0.0024\n", 133 | "Iteration: 6750\n", 134 | "1/1 [==============================] - 0s 53ms/step - loss: 0.0023\n", 135 | "Iteration: 7000\n", 136 | "1/1 [==============================] - 0s 291ms/step - loss: 0.0021\n", 137 | "Iteration: 7250\n", 138 | "1/1 [==============================] - 0s 100ms/step - loss: 0.0022\n", 139 | "Iteration: 7500\n", 140 | "1/1 [==============================] - 0s 183ms/step - loss: 0.0021\n", 141 | "Iteration: 7750\n", 142 | "1/1 [==============================] - 0s 123ms/step - loss: 0.0019\n", 143 | "Iteration: 8000\n", 144 | "1/1 [==============================] - 0s 116ms/step - loss: 0.0020\n", 145 | "Iteration: 8250\n", 146 | "1/1 [==============================] - 0s 69ms/step - loss: 0.0018\n", 147 | "Iteration: 8500\n", 148 | "1/1 [==============================] - 0s 53ms/step - loss: 0.0018\n", 149 | "Iteration: 8750\n", 150 | "1/1 [==============================] - 0s 141ms/step - loss: 0.0019\n", 151 | "Iteration: 9000\n", 152 | "1/1 [==============================] - 0s 131ms/step - loss: 0.0018\n", 153 | "Iteration: 9250\n", 154 | "1/1 [==============================] - 0s 79ms/step - loss: 0.0017\n", 155 | "Iteration: 9500\n", 156 | "1/1 [==============================] - 0s 132ms/step - loss: 0.0017\n", 157 | "Iteration: 9750\n", 158 | "1/1 [==============================] - 0s 53ms/step - loss: 0.0017\n", 159 | "Total iterations: 10000\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "from tensorflow.keras.models import Model\n", 165 | "from tensorflow.keras.layers import Input, Lambda, Dense\n", 166 | "\n", 167 | "bs = 100000\n", 168 | "\n", 169 | "def fn(x):\n", 170 | " y = tf.einsum(\"ij, ij->i\", x[0], x[1])\n", 171 | " y = tf.expand_dims(y, axis = 1)\n", 172 | " return y\n", 173 | "\n", 174 | "hln = 35\n", 175 | "\n", 176 | "inputsB = Input(shape = (100,), name = 'inputsB')\n", 177 | "hiddenB = Dense(hln, activation = \"relu\")(inputsB)\n", 178 | "hiddenB = Dense(hln, activation = \"relu\")(hiddenB)\n", 179 | "hiddenB = Dense(hln, activation = \"relu\")(hiddenB)\n", 180 | "\n", 181 | "inputsT = Input(shape = (2,), name = 'inputsT')\n", 182 | "hiddenT = Dense(hln, activation = \"relu\")(inputsT)\n", 183 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 184 | "hiddenT = Dense(hln, activation = \"relu\")(hiddenT)\n", 185 | "\n", 186 | "combined = Lambda(fn, output_shape = [None, 1])([hiddenB, hiddenT])\n", 187 | "output = Dense(1)(combined)\n", 188 | "\n", 189 | "model = Model(inputs = [inputsB, inputsT], outputs = output) \n", 190 | "\n", 191 | "model.compile(optimizer = tf.optimizers.Adam(learning_rate = 0.001), loss = 'mse')\n", 192 | "model.summary()\n", 193 | "\n", 194 | "model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n", 195 | " filepath = './ChkPts/CIV_VDN_S1_10000/',\n", 196 | " save_weights_only=True,\n", 197 | " monitor='loss',\n", 198 | " mode='min',\n", 199 | " save_best_only=True)\n", 200 | "\n", 201 | "itr = 0\n", 202 | "for i in range(0, 40):\n", 203 | " print('Iteration: '+str(itr))\n", 204 | " itr = itr+1\n", 205 | " model.fit({\"inputsB\":u_in, \"inputsT\":x_t_in}, s_in, 
epochs = 1,\n", 206 | " verbose = 1, batch_size = bs, callbacks = [model_checkpoint_callback]) \n", 207 | "\n", 208 | " itr_ps = 250-1\n", 209 | " model.fit({\"inputsB\":u_in, \"inputsT\":x_t_in}, s_in, epochs = itr_ps,\n", 210 | " verbose = 0, batch_size = bs, callbacks = [model_checkpoint_callback])\n", 211 | " itr = itr+itr_ps\n", 212 | "\n", 213 | "print('Total iterations: '+str(itr))\n", 214 | "\n", 215 | "model.load_weights('./ChkPts/CIV_VDN_S1_10000/')\n", 216 | "\n", 217 | "model.save_weights('./model/CIV_VDN_S1_10000')\n", 218 | "\n", 219 | "# model.load_weights('./model/Dense_model_TF_weights_CIV_VDN')" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 3, 225 | "id": "9cebf033-567b-4508-a037-5951bf7adb06", 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "\n", 233 | "0.0024013603\n", 234 | "0.41757784508986057\n", 235 | "0.005750688976560404\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "data_test = spi.loadmat('test_PDE_ADVD.mat')\n", 241 | "\n", 242 | "u_in_test = data_test['X_test0']\n", 243 | "x_t_in_test = data_test['X_test1']\n", 244 | "s_in_test = data_test['y_test']\n", 245 | "\n", 246 | "pred = model({\"inputsB\":u_in_test, \"inputsT\":x_t_in_test})\n", 247 | " \n", 248 | "print()\n", 249 | "print(np.mean((s_in_test-pred)**2))\n", 250 | "print(np.mean((s_in_test)**2))\n", 251 | "print(np.mean((s_in_test-pred)**2)/np.mean((s_in_test)**2)) " 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "id": "df7255ed-aa43-4a72-a18f-04c557e4623e", 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [] 261 | } 262 | ], 263 | "metadata": { 264 | "kernelspec": { 265 | "display_name": "Python 3 (ipykernel)", 266 | "language": "python", 267 | "name": "python3" 268 | }, 269 | "language_info": { 270 | "codemirror_mode": { 271 | "name": "ipython", 272 | "version": 3 273 | }, 274 | "file_extension": ".py", 275 | "mimetype": "text/x-python", 276 | "name": "python", 277 | "nbconvert_exporter": "python", 278 | "pygments_lexer": "ipython3", 279 | "version": "3.10.2" 280 | }, 281 | "widgets": { 282 | "application/vnd.jupyter.widget-state+json": { 283 | "state": {}, 284 | "version_major": 2, 285 | "version_minor": 0 286 | } 287 | } 288 | }, 289 | "nbformat": 4, 290 | "nbformat_minor": 5 291 | } 292 | -------------------------------------------------------------------------------- /CS-IV/CS_IV_POD_GP_GitHub.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all 4 | 5 | tic 6 | 7 | %% %% 8 | 9 | IFR = 0.95; 10 | load train_PDE_ADVD.mat 11 | 12 | n = 100000; 13 | y = X_train0(1:n,:); 14 | 15 | [~, sd, vd] = svd(y); 16 | 17 | sd = sd.^2; 18 | nd = 1; 19 | chkd = sum(diag(sd)); 20 | rd = 0; 21 | while rd < IFR 22 | rd = sum(diag(sd(1:nd,1:nd)))/chkd; 23 | nd = nd+1; 24 | end 25 | nd = nd-1; 26 | fprintf('\n\n%d\n\n',nd); 27 | 28 | red = y*vd(:,1:nd); 29 | 30 | toc 31 | 32 | %% 33 | 34 | in = [red, X_train1(1:n, :)]; 35 | mdl = fitrgp(in, y_train(1:n,:)); 36 | 37 | toc 38 | 39 | %% 40 | 41 | if IFR == 0.95 42 | save("GPmdl95P_new_5.mat",'sd','vd','n','IFR','nd','mdl') 43 | else 44 | save("GPmdl99d99P_new_5.mat",'sd','vd','n','IFR','nd','mdl') 45 | end 46 | 47 | %% PREDICTION 48 | 49 | load test_PDE_ADVD.mat 50 | 51 | S_mse = zeros(100,1); 52 | S_nmse = zeros(100,1); 53 | 54 | for i = 1:100 55 | 56 | i 57 | 58 | n = 10000; 59 | y = X_test0((i-1)*n+1:i*n, :); 60 | 61 | pfr = y*vd(:,1:nd); 62 | 
in = [pfr, X_test1((i-1)*n+1:i*n, :)]; 63 | 64 | pred = zeros(10,10000,1); 65 | for j = 1:10 66 | pred(j,:,:) = predict(mdl, in); 67 | end 68 | 69 | mpred = squeeze(mean(pred, 1)); 70 | spred = squeeze(std(pred, 1)); 71 | 72 | mse = mean(mean((mpred'-y_test((i-1)*n+1:i*n,1)).^2)); 73 | nmse = mean(mean((mpred'-y_test((i-1)*n+1:i*n,1)).^2))./mean(mean(y_test((i-1)*n+1:i*n,1).^2)); 74 | 75 | S_mse(i) = mse; 76 | S_nmse(i) = nmse; 77 | end 78 | 79 | MSE = mean(S_mse) 80 | NMSE = mean(S_nmse) 81 | 82 | % MSE = 83 | % 84 | % 0.0473 85 | % 86 | % 87 | % NMSE = 88 | % 89 | % 0.1146 90 | -------------------------------------------------------------------------------- /CS-IV/CS_IV_VB_DeepONet_GitHub.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6dc23078-05ca-4bbe-bcb0-e21d2a079cf7", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((100000, 100), (100000, 2), (100000, 1))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "%reset -f\n", 22 | "import h5py\n", 23 | "import time as t\n", 24 | "import numpy as np\n", 25 | "import scipy as sp\n", 26 | "import scipy.io as spi\n", 27 | "import tensorflow as tf\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import tensorflow_probability as tfp\n", 30 | "\n", 31 | "data_train = spi.loadmat('train_PDE_ADVD.mat')\n", 32 | "\n", 33 | "u_in = data_train['X_train0']\n", 34 | "x_t_in = data_train['X_train1']\n", 35 | "s_in = data_train['y_train']\n", 36 | "\n", 37 | "u_in.shape, x_t_in.shape, s_in.shape" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "id": "cc8fdc38-c3cb-47e3-ab99-f0870272984a", 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stderr", 48 | "output_type": "stream", 49 | "text": [ 50 | "C:\\Users\\shailesh\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\tensorflow_probability\\python\\layers\\util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.\n", 51 | " loc = add_variable_fn(\n", 52 | "C:\\Users\\shailesh\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\tensorflow_probability\\python\\layers\\util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. 
Please use `layer.add_weight` method instead.\n", 53 | " untransformed_scale = add_variable_fn(\n" 54 | ] 55 | }, 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "Model: \"model\"\n", 61 | "__________________________________________________________________________________________________\n", 62 | " Layer (type) Output Shape Param # Connected to \n", 63 | "==================================================================================================\n", 64 | " inputsB (InputLayer) [(None, 100)] 0 [] \n", 65 | " \n", 66 | " inputsT (InputLayer) [(None, 2)] 0 [] \n", 67 | " \n", 68 | " dense_flipout (DenseFlipout) (None, 35) 7035 ['inputsB[0][0]'] \n", 69 | " \n", 70 | " dense_flipout_3 (DenseFlipout) (None, 35) 175 ['inputsT[0][0]'] \n", 71 | " \n", 72 | " dense_flipout_1 (DenseFlipout) (None, 35) 2485 ['dense_flipout[0][0]'] \n", 73 | " \n", 74 | " dense_flipout_4 (DenseFlipout) (None, 35) 2485 ['dense_flipout_3[0][0]'] \n", 75 | " \n", 76 | " dense_flipout_2 (DenseFlipout) (None, 35) 2485 ['dense_flipout_1[0][0]'] \n", 77 | " \n", 78 | " dense_flipout_5 (DenseFlipout) (None, 35) 2485 ['dense_flipout_4[0][0]'] \n", 79 | " \n", 80 | " lambda (Lambda) (None, 1) 0 ['dense_flipout_2[0][0]', \n", 81 | " 'dense_flipout_5[0][0]'] \n", 82 | " \n", 83 | " dense_flipout_6 (DenseFlipout) (None, 2) 6 ['lambda[0][0]'] \n", 84 | " \n", 85 | " distribution_lambda (Distribut ((None, 1), 0 ['dense_flipout_6[0][0]'] \n", 86 | " ionLambda) (None, 1)) \n", 87 | " \n", 88 | "==================================================================================================\n", 89 | "Total params: 17,156\n", 90 | "Trainable params: 17,156\n", 91 | "Non-trainable params: 0\n", 92 | "__________________________________________________________________________________________________\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "from tensorflow.keras.models import Model\n", 98 | "from tensorflow.keras.layers import Input, Lambda, Dense\n", 99 | "\n", 100 | "bs = 100000\n", 101 | "\n", 102 | "def fn(x):\n", 103 | " y = tf.einsum(\"ij, ij->i\", x[0], x[1])\n", 104 | " y = tf.expand_dims(y, axis = 1)\n", 105 | " return y\n", 106 | "\n", 107 | "tfd = tfp.distributions\n", 108 | "tfb = tfp.bijectors\n", 109 | "\n", 110 | "def normal_sp(params):\n", 111 | " return tfd.Normal(loc = params[:, 0:1], scale = 0.001+tf.math.softplus(params[:, 1:2])) \n", 112 | "\n", 113 | "def negloglikelihood(y_true, y_pred):\n", 114 | " return tf.keras.backend.sum(-y_pred.log_prob(y_true))+(sum(model.losses)/bs)\n", 115 | "\n", 116 | "hln = 35\n", 117 | "\n", 118 | "inputsB = Input(shape = (100,), name = 'inputsB')\n", 119 | "hiddenB = tfp.layers.DenseFlipout(hln, activation = \"relu\")(inputsB)\n", 120 | "hiddenB = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenB)\n", 121 | "hiddenB = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenB)\n", 122 | "\n", 123 | "inputsT = Input(shape = (2,), name = 'inputsT')\n", 124 | "hiddenT = tfp.layers.DenseFlipout(hln, activation = \"relu\")(inputsT)\n", 125 | "hiddenT = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenT)\n", 126 | "hiddenT = tfp.layers.DenseFlipout(hln, activation = \"relu\")(hiddenT)\n", 127 | "\n", 128 | "combined = Lambda(fn, output_shape = [None, 1])([hiddenB, hiddenT])\n", 129 | "output = tfp.layers.DenseFlipout(2)(combined)\n", 130 | "\n", 131 | "dist = tfp.layers.DistributionLambda(normal_sp)(output)\n", 132 | "model = Model(inputs = [inputsB, inputsT], outputs = dist) \n", 133 | "\n", 134 | "model.summary()" 135 | ] 
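[Editor's note on the model cell above: the `Lambda(fn, ...)` layer is the DeepONet merge, a row-wise inner product of branch features (a function of the input field u) and trunk features (a function of the query coordinate), i.e. G(u)(y) is approximated by the sum over k of b_k(u) * t_k(y); the final DenseFlipout(2) then parameterises the Normal head. The snippet below is a small standalone sketch of what that einsum computes; the `branch`/`trunk` tensors are random placeholders, not the notebook's features.]

# What fn() computes: a per-sample dot product of branch and trunk features.
import numpy as np
import tensorflow as tf

branch = tf.constant(np.random.rand(4, 35), dtype=tf.float32)  # (batch, hln)
trunk  = tf.constant(np.random.rand(4, 35), dtype=tf.float32)  # (batch, hln)

merged = tf.einsum("ij, ij->i", branch, trunk)   # row-wise inner product
merged = tf.expand_dims(merged, axis=1)          # (batch, 1), as fn() returns

# Equivalent formulation, useful as a sanity check:
check = tf.reduce_sum(branch * trunk, axis=1, keepdims=True)
print(np.allclose(merged.numpy(), check.numpy()))  # True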
},
{
 "cell_type": "code",
 "execution_count": 3,
 "id": "a0e0257a-2383-49ef-90d7-620c522adb45",
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Epoch 0, loss 98778.48\n",
    "Epoch 10, loss 96796.16\n",
    "Epoch 20, loss 92387.11\n",
    "... [output truncated; every 1000th epoch shown below; the loss first turns negative at epoch 560, as expected for a negative log-likelihood objective] ...\n",
    "Epoch 1000, loss -55716.38\n",
    "Epoch 2000, loss -151228.11\n",
    "Epoch 3000, loss -197079.00\n",
    "Epoch 4000, loss -220927.86\n",
    "Epoch 5000, loss -234137.31\n",
    "Epoch 6000, loss -233153.08\n",
    "Epoch 7000, loss -249154.08\n",
    "Epoch 8000, loss -253984.33\n",
    "Epoch 9000, loss -257981.88\n",
    "Epoch 9970, loss -240825.17\n",
    "Epoch 9980, loss -259397.86\n",
    "Epoch 9990, loss -258035.36\n",
    "9946\n"
   ]
  }
 ],
 "source": [
  "optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n",
  "spe = 25  # stochastic forward passes averaged per gradient step\n",
  "string = './model/modelC4_S1_/'+str(spe)  # checkpoint path for the best weights\n",
  "\n",
  "@tf.function\n",
  "def train_step():\n",
  "    with tf.GradientTape() as tape:\n",
  "        loss_value = 0\n",
  "        # Monte Carlo estimate of the expected negative log-likelihood;\n",
  "        # each pass is stochastic because the model output is a distribution.\n",
  "        for _ in range(spe):\n",
  "            logits = model({\"inputsB\": u_in, \"inputsT\": x_t_in}, training=True)\n",
  "            loss_value = loss_value + negloglikelihood(s_in, logits)\n",
  "        loss_value = loss_value*(1/spe)\n",
  "    grads = tape.gradient(loss_value, model.trainable_weights)\n",
  "    optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
  "    return loss_value\n",
  "\n",
  "epochs = 10000\n",
  "loss = np.zeros(epochs)\n",
  "\n",
  "for epoch in range(epochs):\n",
  "    loss_value = train_step()\n",
  "    loss[epoch] = loss_value.numpy()\n",
  "    # checkpoint whenever the loss matches or improves on the best so far\n",
  "    if loss[epoch] <= np.min(loss[0:epoch+1]):\n",
  "        model.save_weights(string)\n",
  "        last_saved_wt = epoch\n",
  "    if epoch%10 == 0:\n",
  "        print(\"Epoch %d, loss %.2f\" % (epoch, loss[epoch]))\n",
  "\n",
  "print(last_saved_wt)  # epoch at which the best weights were last saved\n"
 ]
},
{
 "cell_type": "code",
 "execution_count": 4,
 "id": "9cebf033-567b-4508-a037-5951bf7adb06",
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "0\n",
    "5\n",
    "10\n",
    "15\n",
    "20\n",
    "25\n",
    "30\n",
    "35\n",
    "40\n",
    "45\n",
    "50\n",
    "55\n",
    "60\n",
    "65\n",
    "70\n",
    "75\n",
    "80\n",
    "85\n",
    "90\n",
    "95\n",
    "\n",
    "0.0018931915236530012\n",
    "0.41757784508986057\n",
    "0.004533745134983385\n"
   ]
  }
 ],
 "source": [
  "data_test = spi.loadmat('test_PDE_ADVD.mat')\n",
  "\n",
  "u_in_test = data_test['X_test0']\n",
  "x_t_in_test = data_test['X_test1']\n",
  "s_in_test = data_test['y_test']\n",
  "\n",
  "nsamples = 100  # Monte Carlo draws from the predictive distribution\n",
  "pred = np.zeros([nsamples, s_in_test.shape[0]])\n",
  "for i in range(nsamples):\n",
  "    if i%5 == 0:\n",
  "        print(i)\n",
  "    # the model returns a distribution; draw one realization per pass\n",
  "    pred[i,:] = np.squeeze((model({\"inputsB\": u_in_test, \"inputsT\": x_t_in_test})).sample(1))\n",
  "\n",
  "print()\n",
  "print(np.mean((s_in_test-np.mean(pred, axis = 0)[..., np.newaxis])**2))  # MSE of the predictive mean\n",
  "print(np.mean((s_in_test)**2))  # mean square of the reference solution\n",
  "print(np.mean((s_in_test-np.mean(pred, axis = 0)[..., np.newaxis])**2)/np.mean((s_in_test)**2))  # relative MSE\n"
 ]
},
{
 "cell_type": "code",
 "execution_count": 5,
 "id": "522dce78-9071-4c0e-8a64-d3faf8b27077",
 "metadata": {},
 "outputs": [],
 "source": [
  "model.save_weights('./model/modelC4stat1')"
 ]
}
],
"metadata": {
 "kernelspec": {
  "display_name": "Python 3 (ipykernel)",
  "language": "python",
  "name": "python3"
 },
 "language_info": {
  "codemirror_mode": {
   "name": "ipython",
   "version": 3
  },
  "file_extension": ".py",
"mimetype": "text/x-python", 1273 | "name": "python", 1274 | "nbconvert_exporter": "python", 1275 | "pygments_lexer": "ipython3", 1276 | "version": "3.10.2" 1277 | } 1278 | }, 1279 | "nbformat": 4, 1280 | "nbformat_minor": 5 1281 | } 1282 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shailesh-Garg-SG 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VB-DeepONet 2 | 3 | VB-DeepONet: A Bayesian operator learning framework for uncertainty quantification 4 | 5 | The GitHub repository contains sample codes for the case studies carried out in the research paper titled 'VB-DeepONet: A Bayesian operator learning framework for uncertainty quantification'. Please go through the research paper to understand the implemented algorithm. Note: Results may vary slightly for different iterations of programs as random initializations of neural network is involved. 6 | 7 | Dataset Link: https://csciitd-my.sharepoint.com/:f:/g/personal/amz218308_iitd_ac_in/Ep2kkIW9rXFMs5UAvDFUWdwBC-iL1QwWmKxlVmfDJtEI1g?e=lg0dxa 8 | 9 | ** If there is some ambiguity in the datasets/codes please comment in the repository. 10 | 11 | arXiv Citation details: 12 | 13 | @article{garg2022variational, 14 | title={Variational Bayes Deep Operator Network: A data-driven Bayesian solver for parametric differential equations}, 15 | author={Garg, Shailesh and Chakraborty, Souvik}, 16 | journal={arXiv preprint arXiv:2206.05655}, 17 | year={2022} 18 | } 19 | 20 | **Citation details for the journal paper will be updated later 21 | --------------------------------------------------------------------------------