├── AdvNet ├── HelperFunctions.py ├── InvNet ├── KSNet_FINAL ├── README.md ├── test_KSNet.py ├── test_advNet.py ├── test_invNet.py ├── train_KSNet.py ├── train_advNet.py └── train_invNet.py /AdvNet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MPTCO/FiniteNet/8d49564062c58509cacc411ad77550bbc518ac6b/AdvNet -------------------------------------------------------------------------------- /HelperFunctions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jan 8 17:37:03 2020 4 | 5 | @author: ben91 6 | """ 7 | import torch 8 | from torch.nn.parameter import Parameter 9 | 10 | import numpy as np 11 | 12 | def dudx(u): 13 | derv = 1/12*u[:,:,0]-2/3*u[:,:,1]+0*u[:,:,2]+2/3*u[:,:,3]-1/12*u[:,:,4] 14 | return torch.t(derv).unsqueeze(0) 15 | def d2udx2(u): 16 | derv = -1/12*u[:,:,0]+4/3*u[:,:,1]-5/2*u[:,:,2]+4/3*u[:,:,3]-1/12*u[:,:,4] 17 | return torch.t(derv).unsqueeze(0) 18 | def d3udx3(u): 19 | derv = -1/2*u[:,:,0]+u[:,:,1]+0*u[:,:,2]-1*u[:,:,3]+1/2*u[:,:,4] 20 | return torch.t(derv).unsqueeze(0) 21 | def d4udx4(u): 22 | derv = 1*u[:,:,0]-4*u[:,:,1]+6*u[:,:,2]-4*u[:,:,3]+1*u[:,:,4] 23 | return torch.t(derv).unsqueeze(0) 24 | 25 | def makeIC(L): 26 | def IC(x): 27 | #np.random.seed(N) 28 | f = 0*x 29 | for j in range(0,5): 30 | f = f + torch.rand(1,dtype=torch.double)*torch.sin(2*j*np.pi*(x-torch.rand(1,dtype=torch.double))/L) 31 | #f = f + (x>(L/2)).double()*(5-10*torch.rand(1,dtype=torch.double)) 32 | return f 33 | return IC 34 | 35 | def makeICdsc(L): 36 | def IC(x): 37 | #np.random.seed(N) 38 | f = 0*x 39 | for j in range(0,5): 40 | f = f + torch.rand(1,dtype=torch.double)*torch.sin(2*j*np.pi*(x-torch.rand(1,dtype=torch.double))/L) 41 | #dscc = 1+4*torch.rand(1,dtype=torch.double) 42 | f = f + (x>(L/2)).double()*(5-10*torch.rand(1,dtype=torch.double)) 43 | #f = f + (x>(L/2)).double()*dscc*np.random.choice((-1, 1)) 44 | return f 45 | return IC 46 | 47 | def reshKS(ui): 48 | uj = torch.t(ui.squeeze(0)) 49 | Nx = uj.size(0) 50 | Nt = ui.size(1) 51 | U_proc = torch.zeros(Nx,Nt,7).double() 52 | U_proc[:,:,0] = uj.roll(-3,0) 53 | U_proc[:,:,1] = uj.roll(-2,0) 54 | U_proc[:,:,2] = uj.roll(-1,0) 55 | U_proc[:,:,3] = uj.roll(0,0) 56 | U_proc[:,:,4] = uj.roll(1,0) 57 | U_proc[:,:,5] = uj.roll(2,0) 58 | U_proc[:,:,6] = uj.roll(3,0) 59 | return U_proc 60 | def resh(ui): 61 | uj = torch.t(ui.squeeze(0)) 62 | Nx = uj.size(0) 63 | Nt = ui.size(1) 64 | U_proc = torch.zeros(Nx,Nt,5).double() 65 | #U_proc[:,:,0] = uj.roll(-2,0) 66 | #U_proc[:,:,1] = uj.roll(-1,0) 67 | #U_proc[:,:,2] = uj.roll(0,0) 68 | #U_proc[:,:,3] = uj.roll(1,0) 69 | #U_proc[:,:,4] = uj.roll(2,0) 70 | 71 | U_proc[:,:,0] = uj.roll(2,0) 72 | U_proc[:,:,1] = uj.roll(1,0) 73 | U_proc[:,:,2] = uj.roll(0,0) 74 | U_proc[:,:,3] = uj.roll(-1,0) 75 | U_proc[:,:,4] = uj.roll(-2,0) 76 | return U_proc 77 | 78 | def wenoCoeff(ur): 79 | ep = 1E-6 80 | 81 | B1 = 13/12*(ur[:,:,0] - 2*ur[:,:,1] + ur[:,:,2])**2 + 1/4*(ur[:,:,0] - 4*ur[:,:,1] + 3*ur[:,:,2])**2 82 | B2 = 13/12*(ur[:,:,1] - 2*ur[:,:,2] + ur[:,:,3])**2 + 1/4*(ur[:,:,1] - ur[:,:,3])**2 83 | B3 = 13/12*(ur[:,:,2] - 2*ur[:,:,3] + ur[:,:,4])**2 + 1/4*(3*ur[:,:,2] - 4*ur[:,:,3] + ur[:,:,4])**2 84 | 85 | g1 = 1/10 86 | g2 = 3/5 87 | g3 = 3/10 88 | 89 | wt1 = g1/(ep+B1)**2 90 | wt2 = g2/(ep+B2)**2 91 | wt3 = g3/(ep+B3)**2 92 | wts = wt1 + wt2 + wt3 93 | 94 | w1 = wt1/wts 95 | w2 = wt2/wts 96 | w3 = wt3/wts 97 | 98 | c = 
torch.zeros(ur.size()).double() 99 | 100 | c[:,:,0] = 1/3*w1 #blend the three candidate stencils into the final WENO5 reconstruction coefficients 101 | c[:,:,1] = -7/6*w1 - 1/6*w2 102 | c[:,:,2] = 11/6*w1 + 5/6*w2 + 1/3*w3 103 | c[:,:,3] = 1/3*w2 + 5/6*w3 104 | c[:,:,4] = -1/6*w3 105 | return c 106 | 107 | def exactSol(x,t): #exact solution of linear advection: the initial profile shifted by t on the period-2 domain 108 | def f(x): 109 | #return torch.exp(-25*x**2) 110 | return (x>=1).double() 111 | return f((x-t)%2) 112 | 113 | def randGrid(L,fullInf,S): #sample a random CFL number and grid spacing and return the matching space-time meshgrid 114 | cfl = torch.rand(1,dtype=torch.double)*0.8+0.2 115 | dx = torch.rand(1,dtype=torch.double)*0.046+0.004 116 | dt = cfl*dx 117 | if(fullInf == 0): 118 | T = 1 119 | else: 120 | T = dt*(S+1) 121 | xc = torch.linspace(0,L,int(L/dx)+1,dtype=torch.double) 122 | xc = xc[:-1] 123 | tc = torch.linspace(0,T,int(T/dt),dtype=torch.double) 124 | return torch.meshgrid(xc,tc) 125 | 126 | def scaleAvg(uip): #rescale each 5-point stencil to [0,1]; const_n flags constant stencils (max == min), where this scaling would divide by zero 127 | min_u = uip.min(2)[0] 128 | max_u = uip.max(2)[0] 129 | const_n = min_u==max_u 130 | #print('u: ', u) 131 | u_tmp = torch.zeros_like(uip[:,:,2]) 132 | u_tmp[:] = uip[:,:,2] 133 | for i in range(0,5): 134 | uip[:,:,i] = (uip[:,:,i]-min_u)/(max_u-min_u) 135 | return uip, const_n, u_tmp 136 | -------------------------------------------------------------------------------- /InvNet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MPTCO/FiniteNet/8d49564062c58509cacc411ad77550bbc518ac6b/InvNet -------------------------------------------------------------------------------- /KSNet_FINAL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MPTCO/FiniteNet/8d49564062c58509cacc411ad77550bbc518ac6b/KSNet_FINAL -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FiniteNet 2 | Code for the FiniteNet ICML paper. To run this code, put HelperFunctions.py in your working directory, along with any data you wish to use or trained models you wish to evaluate, then run one of the scripts beginning with 'test' or 'train' to test or train a FiniteNet model. Requires NumPy, PyTorch, and Matplotlib. 3 | 4 | Dataset used to train for the inviscid Burgers' equation: https://figshare.com/articles/invBurg_train_npy/11796201 5 | 6 | Dataset used to train for the KS equation: https://figshare.com/articles/ks_train_csv/11796198 7 | 8 | Dataset used to test for the KS equation: https://figshare.com/articles/ks_test_npy/11796186 9 | 10 | The linear advection equation was trained and tested with data generated on the fly, as was the test data for the inviscid Burgers' equation.
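Below is a minimal sketch of how the test scripts drive a trained model, assuming the `Model` class, the grid variables `dt`, `dx`, `xc`, and an initial condition `IC` as defined in test_advNet.py; it illustrates the calling convention rather than being a standalone script:

```python
import torch

# Build the network and load the trained weights shipped in this repo ('AdvNet').
model = Model(input_size=5, output_size=1, hidden_dim=32, n_layers=3)
model.load_state_dict(torch.load('AdvNet'))

hidden = model.init_hidden(xc.size(0))  # zeroed (hidden, cell) LSTM states
u = IC.unsqueeze(0).unsqueeze(0)        # shape (1, 1, Nx)
# test=1 applies the learned coefficient corrections; test=0 runs plain WENO5.
u, hidden, fi, f1, f2 = model(u, dt, dx, hidden, 1)
```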
11 | -------------------------------------------------------------------------------- /test_KSNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Sep 24 18:08:55 2019 4 | 5 | """ 6 | import torch 7 | from torch import nn 8 | from torch.nn.parameter import Parameter 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import matplotlib as mpl 13 | import time 14 | 15 | from numpy import genfromtxt 16 | from HelperFunctions import makeIC, resh, reshKS, wenoCoeff, exactSol, randGrid 17 | plt.close('all') # close all open figures 18 | # Define and set custom LaTeX style 19 | styleNHN = { 20 | "pgf.rcfonts":False, 21 | "pgf.texsystem": "pdflatex", 22 | "text.usetex": False, 23 | "font.family": "serif" 24 | } 25 | mpl.rcParams.update(styleNHN) 26 | 27 | # Plotting defaults 28 | ALW = 0.75 # AxesLineWidth 29 | FSZ = 12 # Fontsize 30 | LW = 2 # LineWidth 31 | MSZ = 5 # MarkerSize 32 | SMALL_SIZE = 8 # Tiny font size 33 | MEDIUM_SIZE = 10 # Small font size 34 | BIGGER_SIZE = 14 # Large font size 35 | plt.rc('font', size=FSZ) # controls default text sizes 36 | plt.rc('axes', titlesize=FSZ) # fontsize of the axes title 37 | plt.rc('axes', labelsize=FSZ) # fontsize of the x and y labels 38 | plt.rc('xtick', labelsize=FSZ) # fontsize of the x-tick labels 39 | plt.rc('ytick', labelsize=FSZ) # fontsize of the y-tick labels 40 | plt.rc('legend', fontsize=FSZ) # legend fontsize 41 | plt.rc('figure', titlesize=FSZ) # fontsize of the figure title 42 | plt.rcParams['axes.linewidth'] = ALW # sets the default axes lindewidth to ``ALW'' 43 | plt.rcParams["mathtext.fontset"] = 'cm' # Computer Modern mathtext font (applies when ``usetex=False'') 44 | 45 | r2 = np.load('ks_test.npy') 46 | 47 | class Model(nn.Module): 48 | def __init__(self, input_size, output_size, hidden_dim, n_layers): 49 | super(Model, self).__init__() 50 | self.hidden_dim = hidden_dim 51 | self.n_layers = n_layers 52 | 53 | # RNN Layer 54 | self.lstm = nn.LSTM(input_size, hidden_dim, n_layers, batch_first=True).double() 55 | # Fully connected layer 56 | self.fc1 = nn.Linear(hidden_dim, 5).double() 57 | self.fc2 = nn.Linear(hidden_dim, 5).double() 58 | self.fc4 = nn.Linear(hidden_dim, 7).double() 59 | self.trf1 = nn.Linear(5, 5).double() 60 | self.trf2 = nn.Linear(5, 5).double() 61 | self.trf4 = nn.Linear(7, 7).double() 62 | cm1 = Parameter(torch.tensor([[4/35,-9/35,3/35,1/7,-3/35],[-9/35,22/35,-12/35,-6/35,1/7],[3/35,-12/35,18/35,-12/35,3/35],[1/7,-6/35,-12/35,22/35,-9/35],[-3/35,1/7,3/35,-9/35,4/35]]).double())#2nd order 63 | cm2 = Parameter(torch.tensor([[1/70,-2/35,3/35,-2/35,1/70],[-2/35,8/35,-12/35,8/35,-2/35],[3/35,-12/35,18/35,-12/35,3/35],[-2/35,8/35,-12/35,8/35,-2/35],[1/70,-2/35,3/35,-2/35,1/70]]).double())#2nd order 64 | cm4 = Parameter(torch.tensor([[1/924,-1/154,5/308,-5/231,5/308,-1/154,1/924],[-1/154,3/77,-15/154,10/77,-15/154,3/77,-1/154],[5/308,-15/154,75/308,-25/77,75/308,-15/154,5/308], [-5/231,10/77,-25/77,100/231,-25/77,10/77,-5/231],[5/308,-15/154,75/308,-25/77,75/308,-15/154,5/308],[-1/154,3/77,-15/154,10/77,-15/154,3/77,-1/154],[1/924,-1/154,5/308,-5/231,5/308,-1/154,1/924]]).double())#2nd order 65 | cv1 = Parameter(torch.tensor([-0.2,-0.1,0,0.1,0.2]).double())#2nd order 66 | cv2 = Parameter(torch.tensor([2/7,-1/7,-2/7,-1/7,2/7]).double())#2nd order 67 | cv4 = Parameter(torch.tensor([3/11,-7/11,1/11,6/11,1/11,-7/11,3/11]).double())#2nd order 68 | self.trf1.bias = cv1 69 | self.trf1.weight = cm1 70 | self.trf2.bias 
= cv2 71 | self.trf2.weight = cm2 72 | self.trf4.bias = cv4 73 | self.trf4.weight = cm4 74 | 75 | for p in self.trf1.parameters(): 76 | p.requires_grad=False 77 | for p in self.trf2.parameters(): 78 | p.requires_grad=False 79 | for p in self.trf4.parameters(): 80 | p.requires_grad=False 81 | 82 | def forward(self, ui, dt, dx, hidden, test): 83 | Nx = ui.size(2) 84 | Nt = ui.size(1) 85 | uip = torch.zeros(Nx,Nt,7).double() 86 | 87 | uip = reshKS(ui) 88 | c1 = torch.zeros(Nx,Nt,5).double() 89 | c1[:,:,0] = 1/12 90 | c1[:,:,1] = -2/3 91 | c1[:,:,2] = 0 92 | c1[:,:,3] = 2/3 93 | c1[:,:,4] = -1/12 94 | c2 = torch.zeros(Nx,Nt,5).double() 95 | c2[:,:,0] = -1/12 96 | c2[:,:,1] = 4/3 97 | c2[:,:,2] = -5/2 98 | c2[:,:,3] = 4/3 99 | c2[:,:,4] = -1/12 100 | c4 = torch.zeros(Nx,Nt,7).double() 101 | c4[:,:,0] = -1/6 102 | c4[:,:,1] = 2 103 | c4[:,:,2] = -13/2 104 | c4[:,:,3] = 28/3 105 | c4[:,:,4] = -13/2 106 | c4[:,:,5] = 2 107 | c4[:,:,6] = -1/6 108 | vis = 1 109 | hvis = 0.1 110 | if(test==1): 111 | f, hidden = self.lstm(uip, hidden)# Passing in the input and hidden state into the model and obtaining outputs 112 | #f = self.fc1(f) 113 | fi1 = self.fc1(f) 114 | fi2 = self.fc2(f) 115 | fi4 = self.fc4(f) 116 | f1 = fi1 + c1 117 | f2 = fi2 + c2 118 | f4 = fi4 + c4 119 | f1 = self.trf1(f1)#transform coefficients to be consistent 120 | f2 = self.trf2(f2)#transform coefficients to be consistent 121 | f4 = self.trf4(f4)#transform coefficients to be consistent 122 | else: 123 | f1 = c1 124 | f2 = c2 125 | f4 = c4 126 | fi1 = 0 127 | fi2 = 0 128 | fi4 = 0 129 | F1 = 0.5*torch.t(torch.sum(f1*(uip[:,:,1:6]**2), dim = 2)).unsqueeze(0)/dx 130 | F2 = torch.t(torch.sum(f2*uip[:,:,1:6], dim = 2)).unsqueeze(0)/(dx**2) 131 | F4 = torch.t(torch.sum(f4*uip[:,:,:], dim = 2)).unsqueeze(0)/(dx**4) 132 | u1 = ui - dt*(F1 + F2*vis + F4*hvis) 133 | pen = fi1**2 + fi2**2 134 | pen2 = fi4**2 135 | #u1 = ui - dt/dx*(dui**2-dui.roll(1,2)**2)/2 136 | 137 | u1p = reshKS(u1) 138 | if(test==1): 139 | f, hidden = self.lstm(u1p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 140 | #f = self.fc1(f) 141 | f11 = self.fc1(f) 142 | f12 = self.fc2(f) 143 | f14 = self.fc4(f) 144 | f1 = f11 + c1 145 | f2 = f12 + c2 146 | f4 = f14 + c4 147 | f1 = self.trf1(f1)#transform coefficients to be consistent 148 | f2 = self.trf2(f2)#transform coefficients to be consistent 149 | f4 = self.trf4(f4)#transform coefficients to be consistent 150 | else: 151 | f1 = c1 152 | f2 = c2 153 | f4 = c4 154 | f11 = 0 155 | f12 = 0 156 | f14 = 0 157 | F1 = 0.5*torch.t(torch.sum(f1*(u1p[:,:,1:6]**2), dim = 2)).unsqueeze(0)/dx 158 | F2 = torch.t(torch.sum(f2*u1p[:,:,1:6], dim = 2)).unsqueeze(0)/(dx**2) 159 | F4 = torch.t(torch.sum(f4*u1p[:,:,:], dim = 2)).unsqueeze(0)/(dx**4) 160 | u2 = 3/4*ui + 1/4*u1 - 1/4*dt*(F1 + F2*vis + F4*hvis) 161 | pen += f11**2 + f12**2 162 | pen2 += f14**2 163 | u2p = reshKS(u2) 164 | if(test==1): 165 | f, hidden = self.lstm(u2p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 166 | #f = self.fc1(f) 167 | f21 = self.fc1(f) 168 | f22 = self.fc2(f) 169 | f24 = self.fc4(f) 170 | f1 = f21 + c1 171 | f2 = f22 + c2 172 | f4 = f24 + c4 173 | f1 = self.trf1(f1)#transform coefficients to be consistent 174 | f2 = self.trf2(f2)#transform coefficients to be consistent 175 | f4 = self.trf4(f4)#transform coefficients to be consistent 176 | else: 177 | f1 = c1 178 | f2 = c2 179 | f4 = c4 180 | f21 = 0 181 | f22 = 0 182 | f24 = 0 183 | F1 = 0.5*torch.t(torch.sum(f1*(u2p[:,:,1:6]**2), dim 
= 2)).unsqueeze(0)/dx 184 | F2 = torch.t(torch.sum(f2*u2p[:,:,1:6], dim = 2)).unsqueeze(0)/(dx**2) 185 | F4 = torch.t(torch.sum(f4*u2p[:,:,:], dim = 2)).unsqueeze(0)/(dx**4) 186 | out = 1/3*ui + 2/3*u2 - 2/3*dt*(F1 + F2*vis + F4*hvis) 187 | pen = f21**2 + f22**2 188 | pen2 = f24**2 189 | return out, hidden, pen, pen2 190 | 191 | def init_hidden(self, batch_size): 192 | # This method generates the first hidden state of zeros which we'll use in the forward pass 193 | hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double) 194 | cell = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double) 195 | # We'll send the tensor holding the hidden state to the device we specified earlier as well 196 | hc = (hidden,cell) 197 | return hc 198 | 199 | # Instantiate the model with hyperparameters 200 | model = Model(input_size=7, output_size=1, hidden_dim=32, n_layers=3) 201 | 202 | # Define hyperparameters 203 | lr = 0.001 204 | 205 | # Define Loss, Optimizer 206 | criterion = nn.MSELoss() 207 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 208 | 209 | 210 | def KS(xgf,tgf,IC,al,dt,dx): 211 | nx = np.shape(xgf)[0] 212 | nt = np.shape(xgf)[1] 213 | if(al == 0): 214 | u_ex = np.zeros((int(nx/4),int((nt-1)/256+1))) 215 | else: 216 | u_ex = np.zeros((nx,nt)) 217 | u0 = IC 218 | a = 0.1 219 | for i in range(0,nt): 220 | F1 = (1/12*(np.roll(u0,-2)**2)-2/3*(np.roll(u0,-1)**2) + 2/3*(np.roll(u0,1)**2)-1/12*(np.roll(u0,2)**2))/(2*dx); 221 | F2 = (-1/12*np.roll(u0,-2)+4/3*np.roll(u0,-1)-5/2*u0+4/3*np.roll(u0,1)-1/12*np.roll(u0,2))/(dx**2); 222 | F3 = a*(-1/6*np.roll(u0,-3)+2*np.roll(u0,-2)-13/2*np.roll(u0,-1)+28/3*u0-13/2*np.roll(u0,1)+2*np.roll(u0,2)-1/6*np.roll(u0,3))/(dx**4); 223 | u1 = u0 - dt*(F1+F2+F3); 224 | 225 | F1 = (1/12*(np.roll(u1,-2)**2)-2/3*(np.roll(u1,-1)**2) + 2/3*(np.roll(u1,1)**2)-1/12*(np.roll(u1,2)**2))/(2*dx); 226 | F2 = (-1/12*np.roll(u1,-2)+4/3*np.roll(u1,-1)-5/2*u1+4/3*np.roll(u1,1)-1/12*np.roll(u1,2))/(dx**2); 227 | F3 = a*(-1/6*np.roll(u1,-3)+2*np.roll(u1,-2)-13/2*np.roll(u1,-1)+28/3*u1-13/2*np.roll(u1,1)+2*np.roll(u1,2)-1/6*np.roll(u1,3))/(dx**4); 228 | u2 = 3/4*u0 + 1/4*u1 - 1/4*dt*(F1+F2+F3); 229 | 230 | F1 = (1/12*(np.roll(u2,-2)**2)-2/3*(np.roll(u2,-1)**2) + 2/3*(np.roll(u2,1)**2)-1/12*(np.roll(u2,2)**2))/(2*dx); 231 | F2 = (-1/12*np.roll(u2,-2)+4/3*np.roll(u2,-1)-5/2*u2+4/3*np.roll(u2,1)-1/12*np.roll(u2,2))/(dx**2); 232 | F3 = a*(-1/6*np.roll(u2,-3)+2*np.roll(u2,-2)-13/2*np.roll(u2,-1)+28/3*u2-13/2*np.roll(u2,1)+2*np.roll(u2,2)-1/6*np.roll(u2,3))/(dx**4); 233 | u0 = 1/3*u0 + 2/3*u2 - 2/3*dt*(F1+F2+F3); 234 | if(i%256==0 and al == 0): 235 | u_ex[:,int(i/256)] = u0[0::4] 236 | if(al==1): 237 | u_ex[:,i] = u0 238 | return u_ex 239 | #Load trained model 240 | model.load_state_dict(torch.load('KSNet_FINAL')) 241 | 242 | strt = time.time() 243 | L = 20 244 | 245 | nt = 1600 246 | mbs = 5 247 | 248 | cfl = 0.01 249 | cflf = 0.01/64 250 | dx = 0.25 251 | dxf = dx/4 252 | dt = cfl*dx 253 | dtf = cflf*dxf 254 | 255 | T = dt*nt 256 | xc = torch.linspace(0,L,int(L/dx)+1,dtype=torch.double) 257 | xcf = torch.linspace(0,L,int(L/dxf)+1,dtype=torch.double) 258 | xc = xc[:-1] 259 | xcf = xcf[:-1] 260 | tc = torch.linspace(0,T,int(T/dt)+1,dtype=torch.double) 261 | tcf = torch.linspace(0,T,int(T/dtf)+1,dtype=torch.double) 262 | xg,tg = torch.meshgrid(xc,tc)#make the coarse grid 263 | xgf,tgf = torch.meshgrid(xcf,tcf)#make the fine grid 264 | 265 | N_tst = 1000 266 | test_rat = torch.zeros(N_tst) 267 | fdm_err = torch.zeros(N_tst) 268 | fnn_err = 
torch.zeros(N_tst) 269 | min_rat = 1 270 | te = torch.linspace(0,T,int(4/dt)+1,dtype=torch.double) 271 | xge,tge = torch.meshgrid(xc,te)#make the grid of evaluation times 272 | for j in range(0,N_tst): 273 | strt = time.time() 274 | solt = torch.tensor(r2[1600:,j,:]) 275 | target_seq = torch.t(solt[:,:]) 276 | target_seq = solt.unsqueeze(0) 277 | batch_size = xc.size(0) 278 | hidden1 = model.init_hidden(batch_size) 279 | h1_we = model.init_hidden(batch_size) 280 | 281 | x_t = target_seq[0,0,:] 282 | output = torch.zeros_like(target_seq) 283 | output_we = torch.zeros_like(target_seq) 284 | x_t = x_t.unsqueeze(0) 285 | x_t = x_t.unsqueeze(0) 286 | y_t = x_t[:,:,:] 287 | 288 | output[:,0,:] = x_t.detach() 289 | output_we[:,0,:] = x_t.detach() 290 | 291 | for i in range(0,1600): 292 | #print(i) 293 | x_t, hidden1, pen1, pen2 = model(x_t.detach(),dt,dx, hidden1, 1) 294 | y_t, h1_we, abq, abq2 = model(y_t.detach(),dt,dx, h1_we, 0) 295 | output[:,i+1,:] = x_t.detach() 296 | output_we[:,i+1,:] = y_t.detach() 297 | 298 | err_FNN = ((output-target_seq)**2).mean() 299 | err_WE5 = ((output_we-target_seq)**2).mean() 300 | fdm_err[j] = err_WE5.detach() 301 | fnn_err[j] = err_FNN.detach() 302 | test_rat[j] = (err_FNN/err_WE5).detach() 303 | if(min_rat>=test_rat[j]): 304 | min_rat=test_rat[j].detach() 305 | sol_eg = target_seq.detach() 306 | sol_NN = output.detach() 307 | sol_WE = output_we.detach() 308 | print('Iter: ',j,' Err: ',test_rat[j]) 309 | enddt = time.time() 310 | print('Time of test case: ', enddt-strt) 311 | #Animate best case 312 | plt.figure() 313 | for i in range(0,160): 314 | plt.clf() 315 | plt.plot(xc,sol_NN[0,i*10,:].detach(),'.') 316 | plt.plot(xc,sol_WE[0,i*10,:].detach(),'.') 317 | plt.plot(xc,sol_eg[0,i*10,:].detach()) 318 | plt.pause(0.001) 319 | plt.legend(('RNN','FDM','Exact')) 320 | #Animate last case 321 | plt.figure() 322 | for i in range(0,160): 323 | plt.clf() 324 | plt.plot(xc,output[0,i*10,:].detach()) 325 | plt.plot(xc,output_we[0,i*10,:].detach()) 326 | plt.plot(xc,target_seq[0,i*10,:].detach()) 327 | plt.pause(0.001) 328 | plt.legend(('RNN','FDM','Exact')) 329 | #Plot the empirical pmf of the log error ratio 330 | plt.figure(figsize=(6, 2)) 331 | heights,bins = np.histogram(torch.log10(test_rat).detach(),bins=20) 332 | heights = heights/sum(heights) 333 | plt.bar(bins[:-1],heights,width=(max(bins) - min(bins))/len(bins), color="blue", alpha=0.5) 334 | plt.xlabel('Error Ratio') 335 | plt.ylabel('Frequency') 336 | plt.xticks((-2,-1.5,-1,-0.5,0,0.5),['$10^{-2.0}$','$10^{-1.5}$','$10^{-1.0}$','$10^{-0.5}$','$10^0$','$10^{0.5}$']) 337 | plt.tight_layout() 338 | -------------------------------------------------------------------------------- /test_advNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Sep 24 18:08:55 2019 4 | 5 | """ 6 | import torch 7 | from torch import nn 8 | from torch.nn.parameter import Parameter 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import matplotlib as mpl 13 | 14 | from HelperFunctions import makeICdsc, resh, wenoCoeff, exactSol, randGrid 15 | plt.close('all') # close all open figures 16 | # Define and set custom LaTeX style 17 | styleNHN = { 18 | "pgf.rcfonts":False, 19 | "pgf.texsystem": "pdflatex", 20 | "text.usetex": False, 21 | "font.family": "serif" 22 | } 23 | mpl.rcParams.update(styleNHN) 24 | 25 | # Plotting defaults 26 | ALW = 0.75 # AxesLineWidth 27 | FSZ = 12 # Fontsize 28 | LW = 2 # LineWidth 29 | MSZ = 5 # MarkerSize 30 | SMALL_SIZE = 8 # Tiny font size 31 |
MEDIUM_SIZE = 10 # Small font size 32 | BIGGER_SIZE = 14 # Large font size 33 | plt.rc('font', size=FSZ) # controls default text sizes 34 | plt.rc('axes', titlesize=FSZ) # fontsize of the axes title 35 | plt.rc('axes', labelsize=FSZ) # fontsize of the x and y labels 36 | plt.rc('xtick', labelsize=FSZ) # fontsize of the x-tick labels 37 | plt.rc('ytick', labelsize=FSZ) # fontsize of the y-tick labels 38 | plt.rc('legend', fontsize=FSZ) # legend fontsize 39 | plt.rc('figure', titlesize=FSZ) # fontsize of the figure title 40 | plt.rcParams['axes.linewidth'] = ALW # sets the default axes lindewidth to ``ALW'' 41 | plt.rcParams["mathtext.fontset"] = 'cm' # Computer Modern mathtext font (applies when ``usetex=False'') 42 | 43 | class Model(nn.Module): 44 | def __init__(self, input_size, output_size, hidden_dim, n_layers): 45 | super(Model, self).__init__() 46 | self.hidden_dim = hidden_dim 47 | self.n_layers = n_layers 48 | 49 | # RNN Layer 50 | self.lstm = nn.LSTM(input_size, hidden_dim, n_layers, batch_first=True).double() 51 | # Fully connected layer 52 | self.fc1 = nn.Linear(hidden_dim, hidden_dim).double() 53 | self.fc2 = nn.Linear(hidden_dim, 5).double() 54 | self.trf = nn.Linear(5, 5).double() 55 | #cm = Parameter(torch.tensor([[0.8,-0.2,-0.2,-0.2,-0.2],[-0.2,0.8,-0.2,-0.2,-0.2],[-0.2,-0.2,0.8,-0.2,-0.2],[-0.2,-0.2,-0.2,0.8,-0.2],[-0.2,-0.2,-0.2,-0.2,0.8]]).double())#1st order 56 | cm = Parameter(torch.tensor([[0.4,-0.4,-0.2,0,0.2],[-0.4,0.7,-0.2,-0.1,0],[-0.2,-0.2,0.8,-0.2,-0.2],[0,-0.1,-0.2,0.7,-0.4],[0.2,0,-0.2,-0.4,0.4]]).double())#2nd order 57 | #cm = Parameter(torch.tensor([[4/35,-9/35,3/35,1/7,-3/35],[-9/35,22/35,-12/35,-6/35,1/7],[3/35,-12/35,18/35,-12/35,3/35],[1/7,-6/35,-12/35,22/35,-9/35],[-3/35,1/7,3/35,-9/35,4/35]]).double())#3rd order 58 | #cv = Parameter(torch.tensor([0.2,0.2,0.2,0.2,0.2]).double())#1st order 59 | cv = Parameter(torch.tensor([0.1,0.15,0.2,0.25,0.3]).double())#2nd order 60 | #cv = Parameter(torch.tensor([-17/105,59/210,97/210,8/21,4/105]).double())#3rd order 61 | self.trf.bias = cv 62 | self.trf.weight = cm 63 | 64 | for p in self.trf.parameters(): 65 | p.requires_grad=False 66 | 67 | def forward(self, ui, dt, dx, hidden, test): 68 | Nx = ui.size(2) 69 | Nt = ui.size(1) 70 | uip = torch.zeros(Nx,Nt,5).double() 71 | 72 | uip = resh(ui) 73 | ci = wenoCoeff(uip) 74 | cm = torch.zeros_like(uip) 75 | cm[:,:,0] = 2/60 76 | cm[:,:,1] =-13/60 77 | cm[:,:,2] = 47/60 78 | cm[:,:,3] = 27/60 79 | cm[:,:,4] = -3/60 80 | if(test==1): 81 | f, hidden = self.lstm(uip, hidden)# Passing in the input and hidden state into the model and obtaining outputs 82 | #f = self.fc1(f) 83 | fi = self.fc2(f) 84 | f = fi + cm 85 | f = self.trf(f)#transform coefficients to be consistent 86 | else: 87 | f = ci 88 | fi = 0 89 | dui = torch.t(torch.sum(f*uip, dim = 2)).unsqueeze(0) 90 | u1 = ui - dt/dx*(dui-dui.roll(1,2)) 91 | 92 | u1p = resh(u1) 93 | c1 = wenoCoeff(u1p) 94 | if(test==1): 95 | f, hidden = self.lstm(u1p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 96 | #f = self.fc1(f) 97 | f1 = self.fc2(f) 98 | f = f1 + cm 99 | f = self.trf(f)#transform coefficients to be consistent 100 | else: 101 | f = c1 102 | f1 = 0 103 | du1 = torch.t(torch.sum(f*u1p, dim = 2)).unsqueeze(0) 104 | u2 = 3/4*ui + 1/4*u1 - 1/4*dt/dx*(du1-du1.roll(1,2)) 105 | 106 | u2p = resh(u2) 107 | c2 = wenoCoeff(u2p) 108 | if(test==1): 109 | f, hidden = self.lstm(u2p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 110 | #f = self.fc1(f) 111 | f2 = 
self.fc2(f) 112 | f = f2 + cm 113 | f = self.trf(f)#transform coefficients to be consistent 114 | else: 115 | f = c2 116 | f2 = 0 117 | du2 = torch.t(torch.sum(f*u2p, dim = 2)).unsqueeze(0) 118 | out = 1/3*ui + 2/3*u2 - 2/3*dt/dx*(du2-du2.roll(1,2)) 119 | return out, hidden, fi, f1, f2 120 | 121 | def init_hidden(self, batch_size): 122 | # This method generates the first hidden state of zeros which we'll use in the forward pass 123 | hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double) 124 | cell = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double) 125 | # We'll send the tensor holding the hidden state to the device we specified earlier as well 126 | hc = (hidden,cell) 127 | return hc 128 | 129 | # Instantiate the model with hyperparameters 130 | model = Model(input_size=5, output_size=1, hidden_dim=32, n_layers=3) 131 | 132 | # Define hyperparameters 133 | n_epochs = 10 134 | lr = 0.001 135 | 136 | # Define Loss, Optimizer 137 | criterion = nn.MSELoss() 138 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 139 | 140 | errs_WE = torch.zeros(n_epochs) 141 | errs_NN = torch.zeros(n_epochs) 142 | rel_err = torch.zeros(n_epochs) 143 | 144 | randgr = 0 145 | S = 100 146 | L = 1 147 | 148 | def compTV(u): 149 | dif = abs(u.roll(1)-u) 150 | return torch.sum(dif) 151 | # Training Run 152 | if(randgr==0): 153 | cfl = 0.5 154 | dx = 0.01 155 | dt = cfl*dx 156 | T = dt*(S) 157 | xc = torch.linspace(0,L,int(L/dx)+1,dtype=torch.double) 158 | xcf = torch.linspace(0,L,int(L/dx)*4+1,dtype=torch.double) 159 | xc = xc[:-1] 160 | xcf = xcf[:-1] 161 | tc = torch.linspace(0,T,int(T/dt)+1,dtype=torch.double) 162 | tcf = torch.linspace(0,T,int(T/dt)*4+1,dtype=torch.double) 163 | xg,tg = torch.meshgrid(xc,tc)#make the coarse grid 164 | xgf,tgf = torch.meshgrid(xcf,tcf)#make the fine grid 165 | batch_size = xc.size(0) 166 | IC_fx = makeICdsc(L) 167 | 168 | modelTEST = Model(input_size=5, output_size=1, hidden_dim=32, n_layers=3) 169 | modelTEST.load_state_dict(torch.load('AdvNet')) 170 | mbs = 1000 171 | #Test network on new initial conditions without further training. 
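# Evaluation protocol for the loop below: each of the mbs=1000 trials draws a
# fresh random discontinuous IC, whose exact advected solution is simply the
# profile shifted in time, IC((x - t) % L). The same initial state is then
# advanced S steps twice, once with the trained network (test=1) and once with
# plain WENO5 (test=0), and all_ratioTEST[j] stores the ratio
# MSE(FiniteNet)/MSE(WENO5), so values below 1 mean the network beat WENO5.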
172 | TVTEST = torch.zeros((mbs),dtype=torch.double) 173 | discSZ_TEST = torch.zeros((mbs),dtype=torch.double) 174 | all_ratioTEST = torch.zeros((mbs),dtype=torch.double) 175 | all_ICTEST = torch.zeros((mbs, len(xc)),dtype=torch.double) 176 | for j in range(0,mbs): 177 | solt = torch.t(IC_fx((xg-tg)%L)).unsqueeze(0) 178 | IC = solt[0,0,:] 179 | discSZ_TEST[j] = max(abs(IC-IC.roll(1))).detach() 180 | all_ICTEST[j,:] = IC.detach() 181 | target_seq = solt 182 | 183 | hidden1 = modelTEST.init_hidden(batch_size) 184 | h1_we = modelTEST.init_hidden(batch_size) 185 | 186 | x_t = solt[0,0,:] 187 | output = torch.zeros_like(target_seq) 188 | output_we = torch.zeros_like(target_seq) 189 | x_t = x_t.unsqueeze(0) 190 | x_t = x_t.unsqueeze(0) 191 | y_t = x_t[:,:,:] 192 | 193 | output[:,0,:] = x_t 194 | output_we[:,0,:] = x_t 195 | for i in range(0,S): 196 | x_t, hidden1, fi, f1, f2 = modelTEST(x_t,dt,dx, hidden1, 1) 197 | y_t, h1_we, fi_we, fi_we, fi_we = modelTEST(y_t,dt,dx, h1_we, 0) 198 | 199 | output[:,i+1,:] = x_t 200 | output_we[:,i+1,:] = y_t 201 | TVTEST[j] = compTV(IC).detach() 202 | all_ratioTEST[j] = (criterion(output.flatten(), target_seq.flatten())/criterion(output_we.flatten(), target_seq.flatten())).detach() 203 | print('J: ', [j]) 204 | print('Error: ',all_ratioTEST[j]) 205 | 206 | plt.figure() 207 | plt.plot(xc, output[0,-1,:].detach(),'.') 208 | plt.plot(xc, output_we[0,-1,:].detach(),'.') 209 | plt.plot(xc, target_seq[0,-1,:].detach()) 210 | plt.xlabel('$x$') 211 | plt.ylabel('$u$') 212 | plt.legend(('FiniteNet','WENO5','Exact')) 213 | 214 | plt.figure() 215 | plt.plot(discSZ_TEST, all_ratioTEST,'.') 216 | plt.xlabel('Discontinuity Size') 217 | plt.ylabel('Error Ratio') 218 | 219 | plt.figure(figsize=(6, 2)) 220 | heights,bins = np.histogram(torch.log10(all_ratioTEST).detach(),bins=20) 221 | heights = heights/sum(heights) 222 | plt.bar(bins[:-1],heights,width=(max(bins) - min(bins))/len(bins), color="blue", alpha=0.5) 223 | plt.xlabel('Error Ratio') 224 | plt.ylabel('Frequency') 225 | plt.xticks((-0.5,-0.375,-0.25,-0.125,0),['$10^{-0.5}$','$10^{-0.375}$','$10^{-0.25}$','$10^{-0.125}$','$10^0$']) 226 | plt.tight_layout() 227 | -------------------------------------------------------------------------------- /test_invNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Sep 24 18:08:55 2019 4 | 5 | """ 6 | import torch 7 | from torch import nn 8 | from torch.nn.parameter import Parameter 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import matplotlib as mpl 13 | import time 14 | 15 | from numpy import genfromtxt 16 | from HelperFunctions import makeIC, resh, reshKS, wenoCoeff, exactSol, randGrid 17 | plt.close('all') # close all open figures 18 | # Define and set custom LaTeX style 19 | styleNHN = { 20 | "pgf.rcfonts":False, 21 | "pgf.texsystem": "pdflatex", 22 | "text.usetex": False, 23 | "font.family": "serif" 24 | } 25 | mpl.rcParams.update(styleNHN) 26 | 27 | # Plotting defaults 28 | ALW = 0.75 # AxesLineWidth 29 | FSZ = 12 # Fontsize 30 | LW = 2 # LineWidth 31 | MSZ = 5 # MarkerSize 32 | SMALL_SIZE = 8 # Tiny font size 33 | MEDIUM_SIZE = 10 # Small font size 34 | BIGGER_SIZE = 14 # Large font size 35 | plt.rc('font', size=FSZ) # controls default text sizes 36 | plt.rc('axes', titlesize=FSZ) # fontsize of the axes title 37 | plt.rc('axes', labelsize=FSZ) # fontsize of the x and y labels 38 | plt.rc('xtick', labelsize=FSZ) # fontsize of the x-tick labels 39 | 
plt.rc('ytick', labelsize=FSZ) # fontsize of the y-tick labels 40 | plt.rc('legend', fontsize=FSZ) # legend fontsize 41 | plt.rc('figure', titlesize=FSZ) # fontsize of the figure title 42 | plt.rcParams['axes.linewidth'] = ALW # sets the default axes lindewidth to ``ALW'' 43 | plt.rcParams["mathtext.fontset"] = 'cm' # Computer Modern mathtext font (applies when ``usetex=False'') 44 | 45 | # Define the network architecture 46 | class Model(nn.Module): 47 | def __init__(self, input_size, output_size, hidden_dim, n_layers): 48 | super(Model, self).__init__() 49 | self.hidden_dim = hidden_dim 50 | self.n_layers = n_layers 51 | 52 | # RNN Layer 53 | self.lstm = nn.LSTM(input_size, hidden_dim, n_layers, batch_first=True).double() 54 | # Fully connected layer 55 | #self.fc1 = nn.Linear(hidden_dim, hidden_dim).double() 56 | self.fc2 = nn.Linear(hidden_dim, 5).double() 57 | self.trf = nn.Linear(5, 5).double() 58 | #cm = Parameter(torch.tensor([[0.8,-0.2,-0.2,-0.2,-0.2],[-0.2,0.8,-0.2,-0.2,-0.2],[-0.2,-0.2,0.8,-0.2,-0.2],[-0.2,-0.2,-0.2,0.8,-0.2],[-0.2,-0.2,-0.2,-0.2,0.8]]).double())#1st order 59 | cm = Parameter(torch.tensor([[0.4,-0.4,-0.2,0,0.2],[-0.4,0.7,-0.2,-0.1,0],[-0.2,-0.2,0.8,-0.2,-0.2],[0,-0.1,-0.2,0.7,-0.4],[0.2,0,-0.2,-0.4,0.4]]).double())#2nd order 60 | #cm = Parameter(torch.tensor([[4/35,-9/35,3/35,1/7,-3/35],[-9/35,22/35,-12/35,-6/35,1/7],[3/35,-12/35,18/35,-12/35,3/35],[1/7,-6/35,-12/35,22/35,-9/35],[-3/35,1/7,3/35,-9/35,4/35]]).double())#3rd order 61 | #cv = Parameter(torch.tensor([0.2,0.2,0.2,0.2,0.2]).double())#1st order 62 | cv = Parameter(torch.tensor([0.1,0.15,0.2,0.25,0.3]).double())#2nd order 63 | #cv = Parameter(torch.tensor([-17/105,59/210,97/210,8/21,4/105]).double())#3rd order 64 | self.trf.bias = cv 65 | self.trf.weight = cm 66 | 67 | for p in self.trf.parameters(): 68 | p.requires_grad=False 69 | 70 | def forward(self, ui, dt, dx, hidden, test): 71 | Nx = ui.size(2) 72 | Nt = ui.size(1) 73 | uip = torch.zeros(Nx,Nt,5).double() 74 | 75 | uip = resh(ui) 76 | ci = wenoCoeff(uip) 77 | cm = torch.zeros_like(uip) 78 | cm[:,:,0] = 2/60 79 | cm[:,:,1] =-13/60 80 | cm[:,:,2] = 47/60 81 | cm[:,:,3] = 27/60 82 | cm[:,:,4] = -3/60 83 | if(test==1): 84 | f, hidden = self.lstm(uip, hidden)# Passing in the input and hidden state into the model and obtaining outputs 85 | #f = self.fc1(f) 86 | fi = self.fc2(f) 87 | f = fi + cm 88 | f = self.trf(f)#transform coefficients to be consistent 89 | else: 90 | f = ci.detach() 91 | fi = 0 92 | dui = torch.t(torch.sum(f*uip, dim = 2)).unsqueeze(0) 93 | u1 = ui - dt/dx*(dui**2-dui.roll(1,2)**2)/2 94 | 95 | u1p = resh(u1) 96 | c1 = wenoCoeff(u1p) 97 | if(test==1): 98 | f, hidden = self.lstm(u1p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 99 | #f = self.fc1(f) 100 | f1 = self.fc2(f) 101 | f = f1 + cm 102 | f = self.trf(f)#transform coefficients to be consistent 103 | else: 104 | f = c1.detach() 105 | f1 = 0 106 | du1 = torch.t(torch.sum(f*u1p, dim = 2)).unsqueeze(0) 107 | u2 = 3/4*ui + 1/4*u1 - 1/4*dt/dx*(du1**2-du1.roll(1,2)**2)/2 108 | 109 | u2p = resh(u2) 110 | c2 = wenoCoeff(u2p) 111 | if(test==1): 112 | f, hidden = self.lstm(u2p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 113 | #f = self.fc1(f) 114 | f2 = self.fc2(f) 115 | f = f2 + cm 116 | f = self.trf(f)#transform coefficients to be consistent 117 | else: 118 | f = c2.detach() 119 | f2 = 0 120 | du2 = torch.t(torch.sum(f*u2p, dim = 2)).unsqueeze(0) 121 | out = 1/3*ui + 2/3*u2 - 
2/3*dt/dx*(du2**2-du2.roll(1,2)**2)/2 122 | return out, hidden, (fi**2+ f1**2+ f2**2) 123 | 124 | def init_hidden(self, batch_size): 125 | # This method generates the first hidden state of zeros which we'll use in the forward pass 126 | hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double) 127 | cell = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double) 128 | # We'll send the tensor holding the hidden state to the device we specified earlier as well 129 | hc = (hidden,cell) 130 | return hc 131 | 132 | # Instantiate the model with hyperparameters 133 | model = Model(input_size=5, output_size=1, hidden_dim=32, n_layers=3) 134 | 135 | # Define hyperparameters 136 | lr = 0.001 137 | 138 | # Define Loss, Optimizer 139 | criterion = nn.MSELoss() 140 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 141 | n_epochs = 400 142 | 143 | errs_WE = torch.zeros(n_epochs) 144 | errs_NN = torch.zeros(n_epochs) 145 | rel_err = torch.zeros(n_epochs) 146 | def burgEx(xgf,tgf,IC): 147 | Nt = xgf.size()[1]-1 148 | solt = torch.zeros_like(torch.t(xgf)) 149 | solt = solt.unsqueeze(0) 150 | solt[:,0,:] = IC 151 | x_t = IC 152 | x_t = x_t.unsqueeze(0) 153 | x_t = x_t.unsqueeze(0) 154 | for i in range(0,Nt): 155 | x_t, hidden1, p1 = model(x_t,dt,dx, 0, 0) 156 | solt[:,i+1,:] = x_t 157 | return solt[:,0::4,0::4] 158 | 159 | errs_WE = torch.zeros(n_epochs) 160 | errs_NN = torch.zeros(n_epochs) 161 | L = 1 162 | S = 100 163 | cfl = 0.25 164 | dx = 0.01 165 | dt = cfl*dx 166 | T = dt*(S) 167 | xc = torch.linspace(0,L,int(L/dx)+1,dtype=torch.double) 168 | xcf = torch.linspace(0,L,int(L/dx)*4+1,dtype=torch.double) 169 | xc = xc[:-1] 170 | xcf = xcf[:-1] 171 | dxf = xcf[1] - xcf[0] 172 | tc = torch.linspace(0,T,int(T/dt)+1,dtype=torch.double) 173 | tcf = torch.linspace(0,T,int(T/dt)*4+1,dtype=torch.double) 174 | dtf = tcf[1] - tcf[0] 175 | xg,tg = torch.meshgrid(xc,tc)#make the coarse grid 176 | xgf,tgf = torch.meshgrid(xcf,tcf)#make the fine grid 177 | 178 | def compTV(u): 179 | dif = abs(u.roll(1)-u) 180 | return torch.sum(dif) 181 | batch_size = xc.size(0) 182 | IC_fx = makeIC(L) 183 | mbs = 5 184 | 185 | model.load_state_dict(torch.load('InvNet')) 186 | 187 | n_tests = 1000 188 | test_rat = torch.zeros(n_tests) 189 | test_TV = torch.zeros(n_tests) 190 | min_rat = 1 191 | 192 | icf = makeIC(L) 193 | for j in range(0,n_tests): 194 | IC = icf(xcf) 195 | IC = IC - min(IC) 196 | target_seq = burgEx(xgf,tgf,IC) 197 | 198 | hidden1 = model.init_hidden(batch_size) 199 | h1_we = model.init_hidden(batch_size) 200 | 201 | x_t = target_seq[0,0,:] 202 | output = torch.zeros_like(target_seq) 203 | output_we = torch.zeros_like(target_seq) 204 | x_t = x_t.unsqueeze(0) 205 | x_t = x_t.unsqueeze(0) 206 | y_t = x_t[:,:,:] 207 | 208 | output[:,0,:] = x_t.detach() 209 | output_we[:,0,:] = x_t.detach() 210 | 211 | for i in range(0,S): 212 | x_t, hidden1, f1nn = model(x_t.detach(),dt,dx, hidden1, 1) 213 | y_t, h1_we, f1we = model(y_t.detach(),dt,dx, h1_we, 0) 214 | output[:,i+1,:] = x_t.detach() 215 | output_we[:,i+1,:] = y_t.detach() 216 | 217 | err_FNN = ((output-target_seq)**2).mean() 218 | err_WE5 = ((output_we-target_seq)**2).mean() 219 | test_rat[j] = (err_FNN/err_WE5).detach() 220 | test_TV[j] = compTV(target_seq[0,0,:]).detach() 221 | if(min_rat>=test_rat[j]): 222 | min_rat=test_rat[j].detach() 223 | sol_eg = target_seq.detach() 224 | sol_NN = output.detach() 225 | sol_WE = output_we.detach() 226 | print('Iter: ',j,' Err: ',test_rat[j]) 227 | 228 | #Plot total variation vs 
error ratio 229 | plt.figure() 230 | plt.plot(test_TV,test_rat,'.') 231 | plt.xlabel('$TV(u(x,0))$') 232 | plt.ylabel('Error Ratio') 233 | plt.tight_layout() 234 | 235 | #Animate a solution 236 | plt.figure() 237 | for i in range(0,len(tc)): 238 | plt.clf() 239 | plt.plot(xc,output[0,i,:].detach()) 240 | plt.plot(xc,output_we[0,i,:].detach()) 241 | plt.plot(xc,target_seq[0,i,:].detach()) 242 | plt.pause(0.001) 243 | plt.legend(('RNN','FDM','Exact')) 244 | 245 | #Make histogram of error ratio 246 | plt.figure(figsize=(6, 2)) 247 | heights,bins = np.histogram(torch.log10(test_rat).detach(),bins=20) 248 | heights = heights/sum(heights) 249 | plt.bar(bins[:-1],heights,width=(max(bins) - min(bins))/len(bins), color="blue", alpha=0.5) 250 | plt.xlabel('Error Ratio') 251 | plt.ylabel('Frequency') 252 | plt.xticks((-0.8,-0.6,-0.4,-0.2,0,0.2),['$10^{-0.8}$','$10^{-0.6}$','$10^{-0.4}$','$10^{-0.2}$','$10^0$','$10^{0.2}$']) 253 | plt.tight_layout() 254 | -------------------------------------------------------------------------------- /train_KSNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Sep 24 18:08:55 2019 4 | 5 | """ 6 | import torch 7 | from torch import nn 8 | from torch.nn.parameter import Parameter 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import matplotlib as mpl 13 | import time 14 | 15 | from numpy import genfromtxt 16 | from HelperFunctions import makeIC, resh, reshKS, wenoCoeff, exactSol, randGrid 17 | plt.close('all') # close all open figures 18 | # Define and set custom LaTeX style 19 | styleNHN = { 20 | "pgf.rcfonts":False, 21 | "pgf.texsystem": "pdflatex", 22 | "text.usetex": False, 23 | "font.family": "serif" 24 | } 25 | mpl.rcParams.update(styleNHN) 26 | 27 | # Plotting defaults 28 | ALW = 0.75 # AxesLineWidth 29 | FSZ = 12 # Fontsize 30 | LW = 2 # LineWidth 31 | MSZ = 5 # MarkerSize 32 | SMALL_SIZE = 8 # Tiny font size 33 | MEDIUM_SIZE = 10 # Small font size 34 | BIGGER_SIZE = 14 # Large font size 35 | plt.rc('font', size=FSZ) # controls default text sizes 36 | plt.rc('axes', titlesize=FSZ) # fontsize of the axes title 37 | plt.rc('axes', labelsize=FSZ) # fontsize of the x and y labels 38 | plt.rc('xtick', labelsize=FSZ) # fontsize of the x-tick labels 39 | plt.rc('ytick', labelsize=FSZ) # fontsize of the y-tick labels 40 | plt.rc('legend', fontsize=FSZ) # legend fontsize 41 | plt.rc('figure', titlesize=FSZ) # fontsize of the figure title 42 | plt.rcParams['axes.linewidth'] = ALW # sets the default axes lindewidth to ``ALW'' 43 | plt.rcParams["mathtext.fontset"] = 'cm' # Computer Modern mathtext font (applies when ``usetex=False'') 44 | 45 | #load the data here. 
coment out whenever possible 46 | 47 | import numpy as np 48 | from numpy import genfromtxt 49 | all_data = genfromtxt('KS_big_TT_keep.csv', delimiter=',') 50 | r2 = np.reshape(all_data, (3201,750,80)) 51 | 52 | class Model(nn.Module): 53 | def __init__(self, input_size, output_size, hidden_dim, n_layers): 54 | super(Model, self).__init__() 55 | self.hidden_dim = hidden_dim 56 | self.n_layers = n_layers 57 | 58 | # RNN Layer 59 | self.lstm = nn.LSTM(input_size, hidden_dim, n_layers, batch_first=True).double() 60 | # Fully connected layer 61 | self.fc1 = nn.Linear(hidden_dim, 5).double() 62 | self.fc2 = nn.Linear(hidden_dim, 5).double() 63 | self.fc4 = nn.Linear(hidden_dim, 7).double() 64 | self.trf1 = nn.Linear(5, 5).double() 65 | self.trf2 = nn.Linear(5, 5).double() 66 | self.trf4 = nn.Linear(7, 7).double() 67 | #cm = Parameter(torch.tensor([[0.8,-0.2,-0.2,-0.2,-0.2],[-0.2,0.8,-0.2,-0.2,-0.2],[-0.2,-0.2,0.8,-0.2,-0.2],[-0.2,-0.2,-0.2,0.8,-0.2],[-0.2,-0.2,-0.2,-0.2,0.8]]).double())#1st order 68 | cm1 = Parameter(torch.tensor([[4/35,-9/35,3/35,1/7,-3/35],[-9/35,22/35,-12/35,-6/35,1/7],[3/35,-12/35,18/35,-12/35,3/35],[1/7,-6/35,-12/35,22/35,-9/35],[-3/35,1/7,3/35,-9/35,4/35]]).double())#2nd order 69 | cm2 = Parameter(torch.tensor([[1/70,-2/35,3/35,-2/35,1/70],[-2/35,8/35,-12/35,8/35,-2/35],[3/35,-12/35,18/35,-12/35,3/35],[-2/35,8/35,-12/35,8/35,-2/35],[1/70,-2/35,3/35,-2/35,1/70]]).double())#2nd order 70 | cm4 = Parameter(torch.tensor([[1/924,-1/154,5/308,-5/231,5/308,-1/154,1/924],[-1/154,3/77,-15/154,10/77,-15/154,3/77,-1/154],[5/308,-15/154,75/308,-25/77,75/308,-15/154,5/308], [-5/231,10/77,-25/77,100/231,-25/77,10/77,-5/231],[5/308,-15/154,75/308,-25/77,75/308,-15/154,5/308],[-1/154,3/77,-15/154,10/77,-15/154,3/77,-1/154],[1/924,-1/154,5/308,-5/231,5/308,-1/154,1/924]]).double())#2nd order 71 | #cm = Parameter(torch.tensor([[4/35,-9/35,3/35,1/7,-3/35],[-9/35,22/35,-12/35,-6/35,1/7],[3/35,-12/35,18/35,-12/35,3/35],[1/7,-6/35,-12/35,22/35,-9/35],[-3/35,1/7,3/35,-9/35,4/35]]).double())#3rd order 72 | #cv = Parameter(torch.tensor([0.2,0.2,0.2,0.2,0.2]).double())#1st order 73 | cv1 = Parameter(torch.tensor([-0.2,-0.1,0,0.1,0.2]).double())#2nd order 74 | cv2 = Parameter(torch.tensor([2/7,-1/7,-2/7,-1/7,2/7]).double())#2nd order 75 | cv4 = Parameter(torch.tensor([3/11,-7/11,1/11,6/11,1/11,-7/11,3/11]).double())#2nd order 76 | #cv = Parameter(torch.tensor([-17/105,59/210,97/210,8/21,4/105]).double())#3rd order 77 | self.trf1.bias = cv1 78 | self.trf1.weight = cm1 79 | self.trf2.bias = cv2 80 | self.trf2.weight = cm2 81 | self.trf4.bias = cv4 82 | self.trf4.weight = cm4 83 | 84 | for p in self.trf1.parameters(): 85 | p.requires_grad=False 86 | for p in self.trf2.parameters(): 87 | p.requires_grad=False 88 | for p in self.trf4.parameters(): 89 | p.requires_grad=False 90 | 91 | def forward(self, ui, dt, dx, hidden, test): 92 | Nx = ui.size(2) 93 | Nt = ui.size(1) 94 | uip = torch.zeros(Nx,Nt,7).double() 95 | 96 | uip = reshKS(ui) 97 | c1 = torch.zeros(Nx,Nt,5).double() 98 | c1[:,:,0] = 1/12 99 | c1[:,:,1] = -2/3 100 | c1[:,:,2] = 0 101 | c1[:,:,3] = 2/3 102 | c1[:,:,4] = -1/12 103 | c2 = torch.zeros(Nx,Nt,5).double() 104 | c2[:,:,0] = -1/12 105 | c2[:,:,1] = 4/3 106 | c2[:,:,2] = -5/2 107 | c2[:,:,3] = 4/3 108 | c2[:,:,4] = -1/12 109 | c4 = torch.zeros(Nx,Nt,7).double() 110 | c4[:,:,0] = -1/6 111 | c4[:,:,1] = 2 112 | c4[:,:,2] = -13/2 113 | c4[:,:,3] = 28/3 114 | c4[:,:,4] = -13/2 115 | c4[:,:,5] = 2 116 | c4[:,:,6] = -1/6 117 | vis = 1 118 | hvis = 0.1 119 | if(test==1): 120 | f, hidden = 
self.lstm(uip, hidden)# Passing in the input and hidden state into the model and obtaining outputs 121 | #f = self.fc1(f) 122 | fi1 = self.fc1(f) 123 | fi2 = self.fc2(f) 124 | fi4 = self.fc4(f) 125 | f1 = fi1 + c1 126 | f2 = fi2 + c2 127 | f4 = fi4 + c4 128 | f1 = self.trf1(f1)#transform coefficients to be consistent 129 | f2 = self.trf2(f2)#transform coefficients to be consistent 130 | f4 = self.trf4(f4)#transform coefficients to be consistent 131 | else: 132 | f1 = c1 133 | f2 = c2 134 | f4 = c4 135 | fi1 = 0 136 | fi2 = 0 137 | fi4 = 0 138 | F1 = 0.5*torch.t(torch.sum(f1*(uip[:,:,1:6]**2), dim = 2)).unsqueeze(0)/dx 139 | F2 = torch.t(torch.sum(f2*uip[:,:,1:6], dim = 2)).unsqueeze(0)/(dx**2) 140 | F4 = torch.t(torch.sum(f4*uip[:,:,:], dim = 2)).unsqueeze(0)/(dx**4) 141 | u1 = ui - dt*(F1 + F2*vis + F4*hvis) 142 | pen = fi1**2 + fi2**2 143 | pen2 = fi4**2 144 | #u1 = ui - dt/dx*(dui**2-dui.roll(1,2)**2)/2 145 | 146 | u1p = reshKS(u1) 147 | if(test==1): 148 | f, hidden = self.lstm(u1p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 149 | #f = self.fc1(f) 150 | f11 = self.fc1(f) 151 | f12 = self.fc2(f) 152 | f14 = self.fc4(f) 153 | f1 = f11 + c1 154 | f2 = f12 + c2 155 | f4 = f14 + c4 156 | f1 = self.trf1(f1)#transform coefficients to be consistent 157 | f2 = self.trf2(f2)#transform coefficients to be consistent 158 | f4 = self.trf4(f4)#transform coefficients to be consistent 159 | else: 160 | f1 = c1 161 | f2 = c2 162 | f4 = c4 163 | f11 = 0 164 | f12 = 0 165 | f14 = 0 166 | F1 = 0.5*torch.t(torch.sum(f1*(u1p[:,:,1:6]**2), dim = 2)).unsqueeze(0)/dx 167 | F2 = torch.t(torch.sum(f2*u1p[:,:,1:6], dim = 2)).unsqueeze(0)/(dx**2) 168 | F4 = torch.t(torch.sum(f4*u1p[:,:,:], dim = 2)).unsqueeze(0)/(dx**4) 169 | u2 = 3/4*ui + 1/4*u1 - 1/4*dt*(F1 + F2*vis + F4*hvis) 170 | pen += f11**2 + f12**2 171 | pen2 += f14**2 172 | u2p = reshKS(u2) 173 | if(test==1): 174 | f, hidden = self.lstm(u2p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 175 | #f = self.fc1(f) 176 | f21 = self.fc1(f) 177 | f22 = self.fc2(f) 178 | f24 = self.fc4(f) 179 | f1 = f21 + c1 180 | f2 = f22 + c2 181 | f4 = f24 + c4 182 | f1 = self.trf1(f1)#transform coefficients to be consistent 183 | f2 = self.trf2(f2)#transform coefficients to be consistent 184 | f4 = self.trf4(f4)#transform coefficients to be consistent 185 | else: 186 | f1 = c1 187 | f2 = c2 188 | f4 = c4 189 | f21 = 0 190 | f22 = 0 191 | f24 = 0 192 | F1 = 0.5*torch.t(torch.sum(f1*(u2p[:,:,1:6]**2), dim = 2)).unsqueeze(0)/dx 193 | F2 = torch.t(torch.sum(f2*u2p[:,:,1:6], dim = 2)).unsqueeze(0)/(dx**2) 194 | F4 = torch.t(torch.sum(f4*u2p[:,:,:], dim = 2)).unsqueeze(0)/(dx**4) 195 | out = 1/3*ui + 2/3*u2 - 2/3*dt*(F1 + F2*vis + F4*hvis) 196 | pen += f21**2 + f22**2#accumulate the coefficient penalty over all three Runge-Kutta stages 197 | pen2 += f24**2 198 | return out, hidden, pen, pen2 199 | 200 | def init_hidden(self, batch_size): 201 | # This method generates the first hidden state of zeros which we'll use in the forward pass 202 | hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double) 203 | cell = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double) 204 | # We'll send the tensor holding the hidden state to the device we specified earlier as well 205 | hc = (hidden,cell) 206 | return hc 207 | 208 | # Instantiate the model with hyperparameters 209 | model = Model(input_size=7, output_size=1, hidden_dim=32, n_layers=3) 210 | 211 | # Define hyperparameters 212 | lr = 0.001 213 | 214 | # Define Loss, Optimizer 215 |
criterion = nn.MSELoss() 216 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 217 | 218 | def burgEx(xgf,tgf,IC): 219 | Nt = xgf.size()[1]-1 220 | solt = torch.zeros_like(torch.t(xgf)) 221 | solt = solt.unsqueeze(0) 222 | solt[:,0,:] = IC 223 | x_t = IC 224 | x_t = x_t.unsqueeze(0) 225 | x_t = x_t.unsqueeze(0) 226 | for i in range(0,Nt): 227 | x_t, hidden1, p1, p2 = model(x_t,dt,dx, 0, 0) 228 | solt[:,i+1,:] = x_t 229 | return solt 230 | 231 | def KS(xgf,tgf,IC,al,dt,dx): 232 | nx = np.shape(xgf)[0] 233 | nt = np.shape(xgf)[1] 234 | if(al == 0): 235 | u_ex = np.zeros((int(nx/4),int((nt-1)/256+1))) 236 | else: 237 | u_ex = np.zeros((nx,nt)) 238 | u0 = IC 239 | a = 0.1 240 | for i in range(0,nt): 241 | F1 = (1/12*(np.roll(u0,-2)**2)-2/3*(np.roll(u0,-1)**2) + 2/3*(np.roll(u0,1)**2)-1/12*(np.roll(u0,2)**2))/(2*dx); 242 | F2 = (-1/12*np.roll(u0,-2)+4/3*np.roll(u0,-1)-5/2*u0+4/3*np.roll(u0,1)-1/12*np.roll(u0,2))/(dx**2); 243 | F3 = a*(-1/6*np.roll(u0,-3)+2*np.roll(u0,-2)-13/2*np.roll(u0,-1)+28/3*u0-13/2*np.roll(u0,1)+2*np.roll(u0,2)-1/6*np.roll(u0,3))/(dx**4); 244 | u1 = u0 - dt*(F1+F2+F3); 245 | 246 | F1 = (1/12*(np.roll(u1,-2)**2)-2/3*(np.roll(u1,-1)**2) + 2/3*(np.roll(u1,1)**2)-1/12*(np.roll(u1,2)**2))/(2*dx); 247 | F2 = (-1/12*np.roll(u1,-2)+4/3*np.roll(u1,-1)-5/2*u1+4/3*np.roll(u1,1)-1/12*np.roll(u1,2))/(dx**2); 248 | F3 = a*(-1/6*np.roll(u1,-3)+2*np.roll(u1,-2)-13/2*np.roll(u1,-1)+28/3*u1-13/2*np.roll(u1,1)+2*np.roll(u1,2)-1/6*np.roll(u1,3))/(dx**4); 249 | u2 = 3/4*u0 + 1/4*u1 - 1/4*dt*(F1+F2+F3); 250 | 251 | F1 = (1/12*(np.roll(u2,-2)**2)-2/3*(np.roll(u2,-1)**2) + 2/3*(np.roll(u2,1)**2)-1/12*(np.roll(u2,2)**2))/(2*dx); 252 | F2 = (-1/12*np.roll(u2,-2)+4/3*np.roll(u2,-1)-5/2*u2+4/3*np.roll(u2,1)-1/12*np.roll(u2,2))/(dx**2); 253 | F3 = a*(-1/6*np.roll(u2,-3)+2*np.roll(u2,-2)-13/2*np.roll(u2,-1)+28/3*u2-13/2*np.roll(u2,1)+2*np.roll(u2,2)-1/6*np.roll(u2,3))/(dx**4); 254 | u0 = 1/3*u0 + 2/3*u2 - 2/3*dt*(F1+F2+F3); 255 | if(i%256==0 and al == 0): 256 | u_ex[:,int(i/256)] = u0[0::4] 257 | if(al==1): 258 | u_ex[:,i] = u0 259 | return u_ex 260 | 261 | # Training Run 262 | n_epochs = 400 263 | 264 | errs_WE = torch.zeros(n_epochs) 265 | errs_NN = torch.zeros(n_epochs) 266 | 267 | loadIt = 0 268 | if(loadIt == 1): 269 | model.load_state_dict(torch.load('KSNet_temp')) 270 | rel_err = torch.load('ks_storeError.pt') 271 | spe = torch.load('ks_storeEpoch.pt') 272 | else: 273 | rel_err = torch.zeros(n_epochs) 274 | spe = 0 275 | 276 | strt = time.time() 277 | L = 20 278 | 279 | S = 200 280 | nt = len(r2) 281 | mbs = 5 282 | 283 | cfl = 0.01 284 | cflf = 0.01/64 285 | dx = 0.25 286 | dxf = dx/4 287 | dt = cfl*dx 288 | dtf = cflf*dxf 289 | 290 | T = dt*(S) 291 | xc = torch.linspace(0,L,int(L/dx)+1,dtype=torch.double) 292 | xcf = torch.linspace(0,L,int(L/dxf)+1,dtype=torch.double) 293 | xc = xc[:-1] 294 | xcf = xcf[:-1] 295 | tc = torch.linspace(0,T,int(T/dt)+1,dtype=torch.double) 296 | tcf = torch.linspace(0,T,int(T/dtf)+1,dtype=torch.double) 297 | xg,tg = torch.meshgrid(xc,tc)#make the coarse grid 298 | xgf,tgf = torch.meshgrid(xcf,tcf)#make the fine grid 299 | for epoch in range(spe, n_epochs): 300 | #torch.manual_seed(7) 301 | strt = time.time() 302 | optimizer.zero_grad()#TODO: is this is the right place? 
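# Minibatch assembly for this epoch: mbs=5 trajectory segments are drawn from
# the training array r2 (reshaped to (3201, 750, 80)), each starting at a
# random time index sp in [1600, 3200-S-1) of a random trajectory rsn < 600
# (trajectories 600-749 are never sampled here). Every segment is unrolled
# S=200 steps with the learned coefficients (test=1) and, as a baseline, with
# the plain finite-difference scheme (test=0); the loss combines the MSE
# against the data with 0.1 times the mean squared deviation of the learned
# coefficients from the standard ones (pens1_a, pens2_a).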
303 | target_seq_a = torch.zeros((1, S+1, len(xc), mbs),dtype=torch.double) 304 | output_a = torch.zeros_like(target_seq_a) 305 | output_we_a = torch.zeros_like(target_seq_a) 306 | pens2_a = torch.zeros((len(xc), S, 7, mbs),dtype=torch.double) 307 | pens1_a = torch.zeros((len(xc), S, 5, mbs),dtype=torch.double) 308 | for j in range(0,mbs): 309 | sp = np.random.randint(1600,high=(3200-S-1)) 310 | rsn = np.random.randint(600) 311 | target_seq = torch.tensor(r2[sp:sp+S+1,rsn,:]).unsqueeze(0) 312 | 313 | batch_size = xc.size(0) 314 | hidden1 = model.init_hidden(batch_size) 315 | h1_we = model.init_hidden(batch_size) 316 | 317 | x_t = target_seq[0,0,:] 318 | output = torch.zeros_like(target_seq) 319 | output_we = torch.zeros_like(target_seq) 320 | x_t = x_t.unsqueeze(0) 321 | x_t = x_t.unsqueeze(0) 322 | y_t = x_t[:,:,:] 323 | pens1 = torch.zeros((len(xc), S, 5),dtype=torch.double) 324 | pens2 = torch.zeros((len(xc), S, 7),dtype=torch.double) 325 | output[:,0,:] = x_t 326 | output_we[:,0,:] = x_t 327 | for i in range(0,S): 328 | #print(i) 329 | x_t, hidden1, pen1, pen2 = model(x_t,dt,dx, hidden1, 1) 330 | y_t, h1_we, abq, abq2 = model(y_t,dt,dx, h1_we, 0) 331 | pens1[:,i,:] = pen1.squeeze() 332 | pens2[:,i,:] = pen2.squeeze() 333 | output[:,i+1,:] = x_t 334 | output_we[:,i+1,:] = y_t 335 | target_seq_a[:,:,:,j] = target_seq 336 | output_a[:,:,:,j] = output 337 | output_we_a[:,:,:,j] = output_we.detach() 338 | pens2_a[:,:,:,j] = pens2 339 | pens1_a[:,:,:,j] = pens1 340 | loss = criterion(output_a.flatten(), target_seq_a.flatten()) + 0.1*(pens1_a.mean()+pens2_a.mean())#maybe penalize 1-norm 341 | loss.backward() # Does backpropagation and calculates gradients 342 | optimizer.step() # Updates the weights accordingly 343 | errs_NN[epoch] = criterion(output_a.flatten(), target_seq_a.flatten()).detach() 344 | errs_WE[epoch] = criterion(output_we_a.flatten(), target_seq_a.flatten()).detach() 345 | rel_err[epoch] = (errs_NN[epoch].detach())/(errs_WE[epoch].detach()) 346 | if epoch%1 == 0: 347 | print('Epoch: {}/{}.............'.format(epoch, n_epochs), end=' ') 348 | print("Loss: {:.4f}".format(rel_err[epoch].detach())) 349 | plt.clf() 350 | plt.semilogy(rel_err[rel_err!=0].detach()) 351 | plt.pause(0.01) 352 | torch.save(model.state_dict(), 'KSNet_temp') 353 | torch.save(rel_err, 'ks_storeError.pt') 354 | torch.save(epoch, 'ks_storeEpoch.pt') 355 | enddt = time.time() 356 | print('Time of Epoch: ', enddt-strt) 357 | torch.save(model.state_dict(), 'KSNet') 358 | -------------------------------------------------------------------------------- /train_advNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Sep 24 18:08:55 2019 4 | 5 | """ 6 | import torch 7 | from torch import nn 8 | from torch.nn.parameter import Parameter 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import matplotlib as mpl 13 | 14 | from HelperFunctions import makeICdsc, resh, wenoCoeff, exactSol, randGrid 15 | plt.close('all') # close all open figures 16 | # Define and set custom LaTeX style 17 | styleNHN = { 18 | "pgf.rcfonts":False, 19 | "pgf.texsystem": "pdflatex", 20 | "text.usetex": False, 21 | "font.family": "serif" 22 | } 23 | mpl.rcParams.update(styleNHN) 24 | 25 | # Plotting defaults 26 | ALW = 0.75 # AxesLineWidth 27 | FSZ = 12 # Fontsize 28 | LW = 2 # LineWidth 29 | MSZ = 5 # MarkerSize 30 | SMALL_SIZE = 8 # Tiny font size 31 | MEDIUM_SIZE = 10 # Small font size 32 | BIGGER_SIZE = 14 # Large font size 33 | 
plt.rc('font', size=FSZ) # controls default text sizes 34 | plt.rc('axes', titlesize=FSZ) # fontsize of the axes title 35 | plt.rc('axes', labelsize=FSZ) # fontsize of the x and y labels 36 | plt.rc('xtick', labelsize=FSZ) # fontsize of the x-tick labels 37 | plt.rc('ytick', labelsize=FSZ) # fontsize of the y-tick labels 38 | plt.rc('legend', fontsize=FSZ) # legend fontsize 39 | plt.rc('figure', titlesize=FSZ) # fontsize of the figure title 40 | plt.rcParams['axes.linewidth'] = ALW # sets the default axes lindewidth to ``ALW'' 41 | plt.rcParams["mathtext.fontset"] = 'cm' # Computer Modern mathtext font (applies when ``usetex=False'') 42 | 43 | class Model(nn.Module): 44 | def __init__(self, input_size, output_size, hidden_dim, n_layers): 45 | super(Model, self).__init__() 46 | self.hidden_dim = hidden_dim 47 | self.n_layers = n_layers 48 | 49 | # RNN Layer 50 | self.lstm = nn.LSTM(input_size, hidden_dim, n_layers, batch_first=True).double() 51 | # Fully connected layer 52 | self.fc1 = nn.Linear(hidden_dim, hidden_dim).double() 53 | self.fc2 = nn.Linear(hidden_dim, 5).double() 54 | self.trf = nn.Linear(5, 5).double() 55 | #cm = Parameter(torch.tensor([[0.8,-0.2,-0.2,-0.2,-0.2],[-0.2,0.8,-0.2,-0.2,-0.2],[-0.2,-0.2,0.8,-0.2,-0.2],[-0.2,-0.2,-0.2,0.8,-0.2],[-0.2,-0.2,-0.2,-0.2,0.8]]).double())#1st order 56 | cm = Parameter(torch.tensor([[0.4,-0.4,-0.2,0,0.2],[-0.4,0.7,-0.2,-0.1,0],[-0.2,-0.2,0.8,-0.2,-0.2],[0,-0.1,-0.2,0.7,-0.4],[0.2,0,-0.2,-0.4,0.4]]).double())#2nd order 57 | #cm = Parameter(torch.tensor([[4/35,-9/35,3/35,1/7,-3/35],[-9/35,22/35,-12/35,-6/35,1/7],[3/35,-12/35,18/35,-12/35,3/35],[1/7,-6/35,-12/35,22/35,-9/35],[-3/35,1/7,3/35,-9/35,4/35]]).double())#3rd order 58 | #cv = Parameter(torch.tensor([0.2,0.2,0.2,0.2,0.2]).double())#1st order 59 | cv = Parameter(torch.tensor([0.1,0.15,0.2,0.25,0.3]).double())#2nd order 60 | #cv = Parameter(torch.tensor([-17/105,59/210,97/210,8/21,4/105]).double())#3rd order 61 | self.trf.bias = cv 62 | self.trf.weight = cm 63 | 64 | for p in self.trf.parameters(): 65 | p.requires_grad=False 66 | 67 | def forward(self, ui, dt, dx, hidden, test): 68 | Nx = ui.size(2) 69 | Nt = ui.size(1) 70 | uip = torch.zeros(Nx,Nt,5).double() 71 | 72 | uip = resh(ui) 73 | ci = wenoCoeff(uip) 74 | cm = torch.zeros_like(uip) 75 | cm[:,:,0] = 2/60 76 | cm[:,:,1] =-13/60 77 | cm[:,:,2] = 47/60 78 | cm[:,:,3] = 27/60 79 | cm[:,:,4] = -3/60 80 | if(test==1): 81 | f, hidden = self.lstm(uip, hidden)# Passing in the input and hidden state into the model and obtaining outputs 82 | #f = self.fc1(f) 83 | fi = self.fc2(f) 84 | f = fi + cm 85 | f = self.trf(f)#transform coefficients to be consistent 86 | else: 87 | f = ci 88 | fi = 0 89 | dui = torch.t(torch.sum(f*uip, dim = 2)).unsqueeze(0) 90 | u1 = ui - dt/dx*(dui-dui.roll(1,2)) 91 | 92 | u1p = resh(u1) 93 | c1 = wenoCoeff(u1p) 94 | if(test==1): 95 | f, hidden = self.lstm(u1p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 96 | #f = self.fc1(f) 97 | f1 = self.fc2(f) 98 | f = f1 + cm 99 | f = self.trf(f)#transform coefficients to be consistent 100 | else: 101 | f = c1 102 | f1 = 0 103 | du1 = torch.t(torch.sum(f*u1p, dim = 2)).unsqueeze(0) 104 | u2 = 3/4*ui + 1/4*u1 - 1/4*dt/dx*(du1-du1.roll(1,2)) 105 | 106 | u2p = resh(u2) 107 | c2 = wenoCoeff(u2p) 108 | if(test==1): 109 | f, hidden = self.lstm(u2p, hidden)# Passing in the input and hidden state into the model and obtaining outputs 110 | #f = self.fc1(f) 111 | f2 = self.fc2(f) 112 | f = f2 + cm 113 | f = self.trf(f)#transform coefficients to 
129 | # Instantiate the model with hyperparameters
130 | model = Model(input_size=5, output_size=1, hidden_dim=32, n_layers=3)
131 | 
132 | # Define hyperparameters
133 | n_epochs = 400
134 | lr = 0.001
135 | 
136 | # Define loss and optimizer
137 | criterion = nn.MSELoss()
138 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
139 | 
140 | errs_WE = torch.zeros(n_epochs)
141 | errs_NN = torch.zeros(n_epochs)
142 | rel_err = torch.zeros(n_epochs)
143 | 
144 | randgr = 0
145 | S = 100
146 | L = 1
147 | def burgEx(xgf,tgf,IC): # roll the model forward (in baseline mode) from an initial condition
148 |     Nt = xgf.size()[1]-1
149 |     solt = torch.zeros_like(torch.t(xgf))
150 |     solt = solt.unsqueeze(0)
151 |     solt[:,0,:] = IC
152 |     x_t = IC
153 |     x_t = x_t.unsqueeze(0)
154 |     x_t = x_t.unsqueeze(0)
155 |     for i in range(0,Nt):
156 |         x_t, hidden1, fi, f1, f2 = model(x_t,dt,dx, 0, 0)
157 |         solt[:,i+1,:] = x_t
158 |     return solt
159 | 
160 | def compTV(u): # discrete total variation of u on a periodic grid
161 |     dif = abs(u.roll(1)-u)
162 |     return torch.sum(dif)
163 | # Training run
164 | if(randgr==0):
165 |     cfl = 0.5
166 |     dx = 0.01
167 |     dt = cfl*dx
168 |     T = dt*(S)
169 |     xc = torch.linspace(0,L,int(L/dx)+1,dtype=torch.double)
170 |     xcf = torch.linspace(0,L,int(L/dx)*4+1,dtype=torch.double)
171 |     xc = xc[:-1]
172 |     xcf = xcf[:-1]
173 |     tc = torch.linspace(0,T,int(T/dt)+1,dtype=torch.double)
174 |     tcf = torch.linspace(0,T,int(T/dt)*4+1,dtype=torch.double)
175 |     xg,tg = torch.meshgrid(xc,tc) # make the coarse grid
176 |     xgf,tgf = torch.meshgrid(xcf,tcf) # make the fine grid
177 |     batch_size = xc.size(0)
178 |     IC_fx = makeICdsc(L)
179 |     mbs = 5
180 | 
181 | TV = torch.zeros((n_epochs,mbs),dtype=torch.double)
182 | all_ratio = torch.zeros((n_epochs,mbs),dtype=torch.double)
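# The targets in the training loop below are exact: for linear advection
# u_t + u_x = 0 with periodic boundaries on [0, L), the method of
# characteristics gives u(x, t) = u0((x - t) mod L), which is exactly what
# IC_fx((xg - tg) % L) evaluates on the space-time grid. A self-contained
# check (u0 here is an arbitrary illustrative profile):

import math
import torch

L_demo = 1.0
u0 = lambda x: torch.sin(2 * math.pi * x / L_demo)
x = torch.linspace(0, L_demo, 101, dtype=torch.double)[:-1]
t = 0.37
exact = u0((x - t) % L_demo)  # the initial profile shifted right by t
# shifting by one more full period leaves the solution unchanged:
assert torch.allclose(exact, u0((x - t - L_demo) % L_demo))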
183 | for epoch in range(0, n_epochs):
184 |     optimizer.zero_grad()
185 |     target_seq_a = torch.zeros((1,len(tc),len(xc),mbs),dtype=torch.double)
186 |     output_a = torch.zeros((1,len(tc),len(xc),mbs),dtype=torch.double)
187 |     output_we_a = torch.zeros((1,len(tc),len(xc),mbs),dtype=torch.double)
188 |     fis_a = torch.zeros((len(xc),S,5,mbs),dtype=torch.double)
189 |     f1s_a = torch.zeros((len(xc),S,5,mbs),dtype=torch.double)
190 |     f2s_a = torch.zeros((len(xc),S,5,mbs),dtype=torch.double)
191 |     for j in range(0,mbs):
192 | 
193 |         solt = torch.t(IC_fx((xg-tg)%L)).unsqueeze(0) # exact advected target: u0((x - t) mod L)
194 |         IC = solt[0,0,:]
195 |         target_seq = solt
196 | 
197 |         hidden1 = model.init_hidden(batch_size)
198 |         h1_we = model.init_hidden(batch_size)
199 | 
200 |         x_t = solt[0,0,:]
201 |         output = torch.zeros_like(target_seq)
202 |         output_we = torch.zeros_like(target_seq)
203 |         x_t = x_t.unsqueeze(0)
204 |         x_t = x_t.unsqueeze(0)
205 |         y_t = x_t[:,:,:]
206 |         fis = torch.zeros((len(xc), S, 5),dtype=torch.double)
207 |         f1s = torch.zeros((len(xc), S, 5),dtype=torch.double)
208 |         f2s = torch.zeros((len(xc), S, 5),dtype=torch.double)
209 |         output[:,0,:] = x_t
210 |         output_we[:,0,:] = x_t
211 |         for i in range(0,S):
212 |             x_t, hidden1, fi, f1, f2 = model(x_t,dt,dx, hidden1, 1)
213 |             y_t, h1_we, _, _, _ = model(y_t,dt,dx, h1_we, 0) # baseline WENO rollout; coefficient outputs discarded
214 |             fis[:,i,:] = fi.squeeze()
215 |             f1s[:,i,:] = f1.squeeze()
216 |             f2s[:,i,:] = f2.squeeze()
217 |             output[:,i+1,:] = x_t
218 |             output_we[:,i+1,:] = y_t
219 |         target_seq_a[:,:,:,j] = target_seq
220 |         output_a[:,:,:,j] = output
221 |         output_we_a[:,:,:,j] = output_we
222 |         all_ratio[epoch,j] = criterion(output.flatten(), target_seq.flatten())/criterion(output_we.flatten(), target_seq.flatten())
223 |         fis_a[:,:,:,j] = fis
224 |         f1s_a[:,:,:,j] = f1s
225 |         f2s_a[:,:,:,j] = f2s
226 |     loss = criterion(output_a.flatten(), target_seq_a.flatten()) + 0.001*(fis_a**2+f1s_a**2+f2s_a**2).mean() # MSE plus a squared penalty on the coefficient perturbations; penalizing the 1-norm instead is a noted alternative
227 |     loss.backward() # backpropagate to compute the gradients
228 |     optimizer.step() # update the weights
229 |     rel_err[epoch] = ((criterion(output_a.flatten(), target_seq_a.flatten()))/(criterion(output_we_a.flatten(), target_seq_a.flatten()))).detach()
230 |     if epoch%10 == 0:
231 |         '''
232 |         plt.clf()
233 |         plt.plot(output.flatten().detach())
234 |         plt.plot(target_seq.flatten().detach())
235 |         plt.pause(0.001)
236 |         ''' # disabled debug plotting
237 | 
238 |         print('Epoch: {}/{}.............'.format(epoch, n_epochs), end=' ')
239 |         print("Loss: {:.4f}".format(rel_err[epoch])) # note: this prints the relative error, not the training loss
240 |         plt.clf()
241 |         #plt.plot(torch.cat((rel_errppp,rel_err[rel_err!=0])).detach())
242 |         plt.semilogy(rel_err[rel_err!=0].detach())
243 |         plt.pause(0.01)
244 | 
--------------------------------------------------------------------------------
/train_invNet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Sep 24 18:08:55 2019
4 | 
5 | """
6 | import torch
7 | from torch import nn
8 | from torch.nn.parameter import Parameter
9 | 
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import matplotlib as mpl
13 | import time
14 | 
15 | from numpy import genfromtxt # unused in this script
16 | from HelperFunctions import makeIC, resh, reshKS, wenoCoeff, exactSol, randGrid
17 | plt.close('all') # close all open figures
18 | # Define and set custom LaTeX style
19 | styleNHN = {
20 |     "pgf.rcfonts":False,
21 |     "pgf.texsystem": "pdflatex",
22 |     "text.usetex": False,
23 |     "font.family": "serif"
24 | }
25 | mpl.rcParams.update(styleNHN)
26 | 
27 | # Plotting defaults
28 | ALW = 0.75 # axes line width
29 | FSZ = 12 # font size
30 | LW = 2 # line width
31 | MSZ = 5 # marker size
32 | SMALL_SIZE = 8 # small font size
33 | MEDIUM_SIZE = 10 # medium font size
34 | BIGGER_SIZE = 14 # large font size
35 | plt.rc('font', size=FSZ) # controls default text sizes
36 | plt.rc('axes', titlesize=FSZ) # fontsize of the axes title
37 | plt.rc('axes', labelsize=FSZ) # fontsize of the x and y labels
38 | plt.rc('xtick', labelsize=FSZ) # fontsize of the x-tick labels
39 | plt.rc('ytick', labelsize=FSZ) # fontsize of the y-tick labels
40 | plt.rc('legend', fontsize=FSZ) # legend fontsize
41 | plt.rc('figure', titlesize=FSZ) # fontsize of the figure title
42 | plt.rcParams['axes.linewidth'] = ALW # sets the default axes linewidth to ``ALW''
43 | plt.rcParams["mathtext.fontset"] = 'cm' # Computer Modern mathtext font (applies when ``usetex=False'')
44 | 
45 | r2 = np.load('invBurg_train.npy')
46 | 
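# This script consumes the precomputed inviscid Burgers' trajectories linked
# in the README. The indexing used later, r2[int(ordr[tccnt]),:,:], implies a
# (trajectory, time, space) layout, visited once each in shuffled order over
# 400 epochs x 5 minibatch members = 2000 trajectories. A minimal sketch of
# that sampling pattern with a dummy array (the shape is an assumption
# inferred from S = 100 and dx = 0.01 below):

import numpy as np

r2_demo = np.zeros((2000, 101, 100))             # (trajectory, time, space)
ordr_demo = np.random.permutation(len(r2_demo))  # shuffled visit order
tccnt_demo = 0
for epoch in range(400):
    for j in range(5):                           # mbs = 5
        traj = r2_demo[ordr_demo[tccnt_demo]]    # one (S+1, Nx) trajectory
        tccnt_demo += 1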
47 | class Model(nn.Module):
48 |     def __init__(self, input_size, output_size, hidden_dim, n_layers):
49 |         super(Model, self).__init__()
50 |         self.hidden_dim = hidden_dim
51 |         self.n_layers = n_layers
52 | 
53 |         # RNN layer
54 |         self.lstm = nn.LSTM(input_size, hidden_dim, n_layers, batch_first=True).double()
55 |         # Fully connected layer
56 |         #self.fc1 = nn.Linear(hidden_dim, hidden_dim).double()
57 |         self.fc2 = nn.Linear(hidden_dim, 5).double()
58 |         self.trf = nn.Linear(5, 5).double() # fixed (non-trainable) affine map that keeps the stencil coefficients consistent
59 |         #cm = Parameter(torch.tensor([[0.8,-0.2,-0.2,-0.2,-0.2],[-0.2,0.8,-0.2,-0.2,-0.2],[-0.2,-0.2,0.8,-0.2,-0.2],[-0.2,-0.2,-0.2,0.8,-0.2],[-0.2,-0.2,-0.2,-0.2,0.8]]).double())#1st order
60 |         cm = Parameter(torch.tensor([[0.4,-0.4,-0.2,0,0.2],[-0.4,0.7,-0.2,-0.1,0],[-0.2,-0.2,0.8,-0.2,-0.2],[0,-0.1,-0.2,0.7,-0.4],[0.2,0,-0.2,-0.4,0.4]]).double())#2nd order
61 |         #cm = Parameter(torch.tensor([[4/35,-9/35,3/35,1/7,-3/35],[-9/35,22/35,-12/35,-6/35,1/7],[3/35,-12/35,18/35,-12/35,3/35],[1/7,-6/35,-12/35,22/35,-9/35],[-3/35,1/7,3/35,-9/35,4/35]]).double())#3rd order
62 |         #cv = Parameter(torch.tensor([0.2,0.2,0.2,0.2,0.2]).double())#1st order
63 |         cv = Parameter(torch.tensor([0.1,0.15,0.2,0.25,0.3]).double())#2nd order
64 |         #cv = Parameter(torch.tensor([-17/105,59/210,97/210,8/21,4/105]).double())#3rd order
65 |         self.trf.bias = cv
66 |         self.trf.weight = cm
67 | 
68 |         for p in self.trf.parameters():
69 |             p.requires_grad=False
70 | 
71 |     def forward(self, ui, dt, dx, hidden, test):
72 |         Nx = ui.size(2)
73 |         Nt = ui.size(1)
74 |         # resh builds the 5-point periodic stencil tensor of shape (Nx, Nt, 5)
75 | 
76 |         uip = resh(ui)
77 |         ci = wenoCoeff(uip)
78 |         cm = torch.zeros_like(uip) # baseline linear 5-point coefficients; the LSTM output is added as a perturbation
79 |         cm[:,:,0] = 2/60
80 |         cm[:,:,1] = -13/60
81 |         cm[:,:,2] = 47/60
82 |         cm[:,:,3] = 27/60
83 |         cm[:,:,4] = -3/60
84 |         if(test==1):
85 |             f, hidden = self.lstm(uip, hidden) # pass the stencil and hidden state through the LSTM
86 |             #f = self.fc1(f)
87 |             fi = self.fc2(f)
88 |             f = fi + cm
89 |             f = self.trf(f) # transform coefficients to be consistent
90 |         else:
91 |             f = ci.detach() # baseline WENO coefficients; detached so no gradients flow through the baseline
92 |             fi = 0
93 |         dui = torch.t(torch.sum(f*uip, dim = 2)).unsqueeze(0) # reconstructed interface values
94 |         u1 = ui - dt/dx*(dui**2-dui.roll(1,2)**2)/2 # SSP-RK3 stage 1 with conservative Burgers flux F = u**2/2
95 | 
96 |         u1p = resh(u1)
97 |         c1 = wenoCoeff(u1p)
98 |         if(test==1):
99 |             f, hidden = self.lstm(u1p, hidden) # pass the stencil and hidden state through the LSTM
100 |             #f = self.fc1(f)
101 |             f1 = self.fc2(f)
102 |             f = f1 + cm
103 |             f = self.trf(f) # transform coefficients to be consistent
104 |         else:
105 |             f = c1.detach()
106 |             f1 = 0
107 |         du1 = torch.t(torch.sum(f*u1p, dim = 2)).unsqueeze(0)
108 |         u2 = 3/4*ui + 1/4*u1 - 1/4*dt/dx*(du1**2-du1.roll(1,2)**2)/2 # SSP-RK3 stage 2
109 | 
110 |         u2p = resh(u2)
111 |         c2 = wenoCoeff(u2p)
112 |         if(test==1):
113 |             f, hidden = self.lstm(u2p, hidden) # pass the stencil and hidden state through the LSTM
114 |             #f = self.fc1(f)
115 |             f2 = self.fc2(f)
116 |             f = f2 + cm
117 |             f = self.trf(f) # transform coefficients to be consistent
118 |         else:
119 |             f = c2.detach()
120 |             f2 = 0
121 |         du2 = torch.t(torch.sum(f*u2p, dim = 2)).unsqueeze(0)
122 |         out = 1/3*ui + 2/3*u2 - 2/3*dt/dx*(du2**2-du2.roll(1,2)**2)/2 # final SSP-RK3 combination
123 |         return out, hidden, (fi**2 + f1**2 + f2**2) # squared coefficient perturbations, used as a regularizer
124 | 
125 |     def init_hidden(self, batch_size):
126 |         # Generate the initial hidden and cell states of zeros used in the first forward pass
127 |         hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double)
128 |         cell = torch.zeros(self.n_layers, batch_size, self.hidden_dim,dtype=torch.double)
129 |         # Return the (hidden, cell) pair in the form nn.LSTM expects
130 |         hc = (hidden,cell)
131 |         return hc
132 | 
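# Unlike the advection model, each stage above advances the inviscid Burgers'
# equation u_t + (u**2/2)_x = 0 in conservative form:
# u_i <- u_i - (dt/dx) * (F_{i+1/2} - F_{i-1/2}), with interface flux
# F = u**2/2 built from the reconstructed interface values. A minimal sketch
# of one forward-Euler stage on a periodic grid, using a plain two-point
# average as an illustrative stand-in for the 5-point reconstruction:

import math
import torch

def burgers_stage(u, dt, dx):
    u_face = 0.5 * (u + u.roll(-1))             # u at i+1/2 (illustrative reconstruction)
    flux = u_face**2 / 2                        # F = u**2/2
    return u - dt / dx * (flux - flux.roll(1))  # F_{i+1/2} - F_{i-1/2}

dx = 0.01
u = torch.sin(2 * math.pi * torch.arange(0, 1, dx, dtype=torch.double))
u = burgers_stage(u, 0.25 * dx, dx)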
133 | # Instantiate the model with hyperparameters
134 | model = Model(input_size=5, output_size=1, hidden_dim=32, n_layers=3)
135 | 
136 | # Define hyperparameters
137 | lr = 0.001
138 | 
139 | # Define loss and optimizer
140 | criterion = nn.MSELoss()
141 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
142 | n_epochs = 400
143 | 
144 | errs_WE = torch.zeros(n_epochs)
145 | errs_NN = torch.zeros(n_epochs)
146 | rel_err = torch.zeros(n_epochs)
147 | def burgEx(xgf,tgf,IC): # fine-grid baseline rollout, coarsened on return
148 |     Nt = xgf.size()[1]-1
149 |     solt = torch.zeros_like(torch.t(xgf))
150 |     solt = solt.unsqueeze(0)
151 |     solt[:,0,:] = IC
152 |     x_t = IC
153 |     x_t = x_t.unsqueeze(0)
154 |     x_t = x_t.unsqueeze(0)
155 |     for i in range(0,Nt):
156 |         x_t, hidden1, p1 = model(x_t,dt,dx, 0, 0)
157 |         solt[:,i+1,:] = x_t
158 |     return solt[:,0::4,0::4] # subsample the fine grid back to the coarse grid
159 | 
160 | errs_WE = torch.zeros(n_epochs) # (redundant; already initialized above)
161 | errs_NN = torch.zeros(n_epochs)
162 | L = 1
163 | S = 100
164 | cfl = 0.25
165 | dx = 0.01
166 | dt = cfl*dx
167 | T = dt*(S)
168 | xc = torch.linspace(0,L,int(L/dx)+1,dtype=torch.double)
169 | xcf = torch.linspace(0,L,int(L/dx)*4+1,dtype=torch.double)
170 | xc = xc[:-1]
171 | xcf = xcf[:-1]
172 | dxf = xcf[1] - xcf[0]
173 | tc = torch.linspace(0,T,int(T/dt)+1,dtype=torch.double)
174 | tcf = torch.linspace(0,T,int(T/dt)*4+1,dtype=torch.double)
175 | dtf = tcf[1] - tcf[0]
176 | xg,tg = torch.meshgrid(xc,tc) # make the coarse grid
177 | xgf,tgf = torch.meshgrid(xcf,tcf) # make the fine grid
178 | 
179 | def compTV(u): # discrete total variation of u on a periodic grid
180 |     dif = abs(u.roll(1)-u)
181 |     return torch.sum(dif)
182 | batch_size = xc.size(0)
183 | IC_fx = makeIC(L)
184 | mbs = 5
185 | 
186 | ordr = np.linspace(0,1999,num=2000) # visit each of the 2000 stored trajectories once, in random order
187 | np.random.shuffle(ordr)
188 | tccnt = 0
189 | for epoch in range(0, n_epochs):
190 |     #torch.manual_seed(7)
191 |     strt = time.time()
192 |     optimizer.zero_grad()
193 |     target_seq_a = torch.zeros((1, S+1, len(xc), mbs),dtype=torch.double)
194 |     output_a = torch.zeros_like(target_seq_a)
195 |     output_we_a = torch.zeros_like(target_seq_a)
196 |     fis_a = torch.zeros((len(xc),S,5,mbs),dtype=torch.double)
197 |     for j in range(0,mbs):
198 |         target_seq = torch.tensor(r2[int(ordr[tccnt]),:,:]).unsqueeze(0) # next stored Burgers trajectory
199 |         tccnt += 1
200 | 
201 |         batch_size = xc.size(0)
202 |         hidden1 = model.init_hidden(batch_size)
203 |         h1_we = model.init_hidden(batch_size)
204 | 
205 |         x_t = target_seq[0,0,:]
206 |         output = torch.zeros_like(target_seq)
207 |         output_we = torch.zeros_like(target_seq)
208 |         x_t = x_t.unsqueeze(0)
209 |         x_t = x_t.unsqueeze(0)
210 |         y_t = x_t[:,:,:]
211 |         fis = torch.zeros((len(xc), S, 5),dtype=torch.double)
212 |         output[:,0,:] = x_t
213 |         output_we[:,0,:] = x_t
214 |         for i in range(0,S):
215 |             x_t, hidden1, fi = model(x_t,dt,dx, hidden1, 1)
216 |             y_t, h1_we, fi_we = model(y_t,dt,dx, h1_we, 0) # baseline WENO rollout
217 |             fis[:,i,:] = fi.squeeze()
218 |             output[:,i+1,:] = x_t
219 |             output_we[:,i+1,:] = y_t
220 |         target_seq_a[:,:,:,j] = target_seq
221 |         output_a[:,:,:,j] = output
222 |         output_we_a[:,:,:,j] = output_we
223 |         fis_a[:,:,:,j] = fis
224 |     loss = criterion(output_a.flatten(), target_seq_a.flatten()) + 0.001*(fis_a).mean() # MSE plus the mean of the squared coefficient perturbations (fis_a already holds fi**2+f1**2+f2**2); penalizing the 1-norm instead is a noted alternative
225 |     loss.backward() # backpropagate to compute the gradients
226 |     optimizer.step() # update the weights
227 |     errs_NN[epoch] = criterion(output_a.flatten(), target_seq_a.flatten()).detach()
228 |     errs_WE[epoch] = criterion(output_we_a.flatten(), target_seq_a.flatten()).detach()
229 |     rel_err[epoch] = (errs_NN[epoch].detach())/(errs_WE[epoch].detach())
230 |     if epoch%10 == 0:
231 |         print('Epoch: {}/{}.............'.format(epoch, n_epochs), end=' ')
232 |         print("Loss: {:.4f}".format(rel_err[epoch].detach())) # note: this prints the relative error, not the training loss
233 |         plt.clf()
234 |         plt.plot(rel_err[rel_err!=0].detach())
235 |         plt.pause(0.01)
236 |     #torch.save(model.state_dict(), 'InvNet_temp')
237 |     #torch.save(rel_err, 'Inv_storeError.pt')
238 |     #torch.save(epoch, 'Inv_storeEpoch.pt')
239 |     enddt = time.time()
240 |     print('Time of Epoch: ', enddt-strt)
241 | torch.save(model.state_dict(), 'InvNet')
242 | 
--------------------------------------------------------------------------------
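One diagnostic worth highlighting from both training scripts above is `compTV`, the discrete total variation TV(u) = sum_i |u_i - u_{i-1}| on a periodic grid. Monotone profiles have small TV, while the spurious oscillations that high-order schemes develop at discontinuities inflate it, which is why the quantity appears in the scripts for the discontinuous test problems. A worked example:

import torch

def compTV(u):  # same definition as in the training scripts
    dif = abs(u.roll(1) - u)
    return torch.sum(dif)

step = (torch.arange(100) >= 50).double()  # clean periodic shock: one jump up, one jump down
wiggly = step + 0.05 * torch.sin(20.0 * torch.arange(100).double())

print(compTV(step).item())    # 2.0
print(compTV(wiggly).item())  # noticeably larger: oscillations add variation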