├── README.md ├── data_generation ├── random_fields.py └── ns_2d.py ├── lowrank_2d.py ├── utilities3.py ├── lowrank_1d.py ├── fourier_1d.py ├── scripts ├── super_resolution.py ├── eval.py ├── fourier_on_images.py └── ns_fourier_3d_rnn.py ├── fourier_2d.py ├── lowrank_3d.py ├── lowrank_2d_rnn.py ├── fourier_2d_rnn.py └── fourier_3d.py /README.md: -------------------------------------------------------------------------------- 1 | # Fourier Neural Operator 2 | This repository contains the code for the paper: 3 | - [(FNO) Fourier Neural Operator for Parametric Partial Differential Equations](https://arxiv.org/abs/2010.08895) 4 | 5 | In this work, we formulate a new neural operator by parameterizing the integral kernel directly in Fourier space, allowing for an expressive and efficient architecture. We perform experiments on Burgers' equation, Darcy flow, and the Navier-Stokes equation (including the turbulent regime). Our Fourier neural operator shows state-of-the-art performance compared to existing neural network methodologies and it is up to three orders of magnitude faster compared to traditional PDE solvers. 6 | 7 | It builds on the following previous works: 8 | - [(GKN) Neural Operator: Graph Kernel Network for Partial Differential Equations](https://arxiv.org/abs/2003.03485) 9 | - [(MGKN) Multipole Graph Neural Operator for Parametric Partial Differential Equations](https://arxiv.org/abs/2006.09535) 10 | 11 | 12 | ## Requirements 13 | - [PyTorch](https://pytorch.org/) 14 | 15 | ## Files 16 | The code is in the form of simple scripts. Each script is stand-alone and directly runnable. 17 | 18 | ## Datasets 19 | We provide the Burgers equation and Darcy flow datasets we used in the paper. The data generation procedure is described in the paper. 20 | The data are provided as MATLAB (.mat) files. They can be loaded with the `MatReader` class provided in utilities3.py. 
class GaussianRF(object):
    """Sampler for Gaussian random fields on a uniform grid, built with the FFT.

    Every Fourier mode ``k`` gets an independent complex Gaussian coefficient
    scaled by

        size**dim * sqrt(2) * sigma * (4*pi**2*|k|**2 + tau**2)**(-alpha/2)

    with the zero-frequency (mean) mode forced to 0.  ``sample`` returns the
    real part of the inverse FFT of those coefficients.

    Args:
        dim: Spatial dimension of the field; must be 1, 2 or 3.
        size: Number of grid points along every axis (even or odd).
        alpha: Exponent controlling how fast the spectrum decays.
        tau: Shift added (squared) to the scaled squared wavenumber.
        sigma: Overall amplitude.  Defaults to ``tau**(0.5*(2*alpha - dim))``.
        boundary: Accepted for backward compatibility; it is not used by the
            implementation.
        device: Torch device on which the spectrum and the samples live.

    Raises:
        ValueError: If ``dim`` is not 1, 2 or 3.
    """

    def __init__(self, dim, size, alpha=2, tau=3, sigma=None, boundary="periodic", device=None):
        # Previously an unsupported dim left ``sqrt_eig`` undefined and the
        # failure only surfaced later inside ``sample``.
        if dim not in (1, 2, 3):
            raise ValueError("GaussianRF supports dim in (1, 2, 3), got {!r}".format(dim))

        self.dim = dim
        self.device = device

        if sigma is None:
            sigma = tau**(0.5*(2*alpha - self.dim))

        k_max = size//2

        # Wavenumbers in FFT order: 0, 1, ..., followed by the negative ones.
        # ``size - k_max`` non-negative entries gives exactly ``size`` values
        # for odd sizes too (for even sizes it equals ``k_max`` as before).
        k = torch.cat((torch.arange(start=0, end=size - k_max, step=1, device=device),
                       torch.arange(start=-k_max, end=0, step=1, device=device)), 0)

        # |k|^2 on the full grid: add the squared 1-D wavenumbers along each
        # axis through broadcasting.
        k_sq = 0
        for axis in range(dim):
            shape = [1]*dim
            shape[axis] = size
            k_sq = k_sq + k.reshape(shape)**2

        self.sqrt_eig = (size**dim)*math.sqrt(2.0)*sigma*((4*(math.pi**2)*k_sq + tau**2)**(-alpha/2.0))
        # Remove the mean mode so that every sample averages to zero.
        self.sqrt_eig[(0,)*dim] = 0.0

        self.size = (size,)*dim

    def sample(self, N):
        """Draw ``N`` independent fields.

        Args:
            N: Number of samples.

        Returns:
            Real tensor of shape ``(N, size, ..., size)`` with ``dim`` spatial
            axes.
        """
        # Last axis holds (real, imaginary) parts of the Fourier coefficients.
        coeff = torch.randn(N, *self.size, 2, device=self.device)
        coeff = self.sqrt_eig.unsqueeze(-1)*coeff

        if hasattr(torch, "ifft"):
            # Legacy API (PyTorch < 1.8): complex numbers as trailing pairs.
            return torch.ifft(coeff, self.dim, normalized=False)[..., 0]

        # torch.ifft was removed in PyTorch 1.8; the default "backward" norm
        # of ifftn matches normalized=False of the legacy call.
        spatial_axes = tuple(range(-self.dim, 0))
        return torch.fft.ifftn(torch.view_as_complex(coeff), dim=spatial_axes).real
torch.cat((torch.arange(start=0, end=k_max, step=1, device=w0.device), torch.arange(start=-k_max, end=0, step=1, device=w0.device)), 0).repeat(N,1) 48 | #Wavenumbers in x-direction 49 | k_x = k_y.transpose(0,1) 50 | #Negative Laplacian in Fourier space 51 | lap = 4*(math.pi**2)*(k_x**2 + k_y**2) 52 | lap[0,0] = 1.0 53 | #Dealiasing mask 54 | dealias = torch.unsqueeze(torch.logical_and(torch.abs(k_y) <= (2.0/3.0)*k_max, torch.abs(k_x) <= (2.0/3.0)*k_max).float(), 0) 55 | 56 | #Saving solution and time 57 | sol = torch.zeros(*w0.size(), record_steps, device=w0.device) 58 | sol_t = torch.zeros(record_steps, device=w0.device) 59 | 60 | #Record counter 61 | c = 0 62 | #Physical time 63 | t = 0.0 64 | for j in range(steps): 65 | #Stream function in Fourier space: solve Poisson equation 66 | psi_h = w_h.clone() 67 | psi_h[...,0] = psi_h[...,0]/lap 68 | psi_h[...,1] = psi_h[...,1]/lap 69 | 70 | #Velocity field in x-direction = psi_y 71 | q = psi_h.clone() 72 | temp = q[...,0].clone() 73 | q[...,0] = -2*math.pi*k_y*q[...,1] 74 | q[...,1] = 2*math.pi*k_y*temp 75 | q = torch.irfft(q, 2, normalized=False, onesided=False, signal_sizes=(N,N)) 76 | 77 | #Velocity field in y-direction = -psi_x 78 | v = psi_h.clone() 79 | temp = v[...,0].clone() 80 | v[...,0] = 2*math.pi*k_x*v[...,1] 81 | v[...,1] = -2*math.pi*k_x*temp 82 | v = torch.irfft(v, 2, normalized=False, onesided=False, signal_sizes=(N,N)) 83 | 84 | #Partial x of vorticity 85 | w_x = w_h.clone() 86 | temp = w_x[...,0].clone() 87 | w_x[...,0] = -2*math.pi*k_x*w_x[...,1] 88 | w_x[...,1] = 2*math.pi*k_x*temp 89 | w_x = torch.irfft(w_x, 2, normalized=False, onesided=False, signal_sizes=(N,N)) 90 | 91 | #Partial y of vorticity 92 | w_y = w_h.clone() 93 | temp = w_y[...,0].clone() 94 | w_y[...,0] = -2*math.pi*k_y*w_y[...,1] 95 | w_y[...,1] = 2*math.pi*k_y*temp 96 | w_y = torch.irfft(w_y, 2, normalized=False, onesided=False, signal_sizes=(N,N)) 97 | 98 | #Non-linear term (u.grad(w)): compute in physical space then back to Fourier 
space 99 | F_h = torch.rfft(q*w_x + v*w_y, 2, normalized=False, onesided=False) 100 | 101 | #Dealias 102 | F_h[...,0] = dealias* F_h[...,0] 103 | F_h[...,1] = dealias* F_h[...,1] 104 | 105 | #Cranck-Nicholson update 106 | w_h[...,0] = (-delta_t*F_h[...,0] + delta_t*f_h[...,0] + (1.0 - 0.5*delta_t*visc*lap)*w_h[...,0])/(1.0 + 0.5*delta_t*visc*lap) 107 | w_h[...,1] = (-delta_t*F_h[...,1] + delta_t*f_h[...,1] + (1.0 - 0.5*delta_t*visc*lap)*w_h[...,1])/(1.0 + 0.5*delta_t*visc*lap) 108 | 109 | #Update real time (used only for recording) 110 | t += delta_t 111 | 112 | if (j+1) % record_time == 0: 113 | #Solution in physical space 114 | w = torch.irfft(w_h, 2, normalized=False, onesided=False, signal_sizes=(N,N)) 115 | 116 | #Record solution and time 117 | sol[...,c] = w 118 | sol_t[c] = t 119 | 120 | c += 1 121 | 122 | 123 | return sol, sol_t 124 | 125 | 126 | device = torch.device('cuda') 127 | 128 | #Resolution 129 | s = 256 130 | sub = 1 131 | 132 | #Number of solutions to generate 133 | N = 20 134 | 135 | #Set up 2d GRF with covariance parameters 136 | GRF = GaussianRF(2, s, alpha=2.5, tau=7, device=device) 137 | 138 | #Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y))) 139 | t = torch.linspace(0, 1, s+1, device=device) 140 | t = t[0:-1] 141 | 142 | X,Y = torch.meshgrid(t, t) 143 | f = 0.1*(torch.sin(2*math.pi*(X + Y)) + torch.cos(2*math.pi*(X + Y))) 144 | 145 | #Number of snapshots from solution 146 | record_steps = 200 147 | 148 | #Inputs 149 | a = torch.zeros(N, s, s) 150 | #Solutions 151 | u = torch.zeros(N, s, s, record_steps) 152 | 153 | #Solve equations in batches (order of magnitude speed-up) 154 | 155 | #Batch size 156 | bsize = 20 157 | 158 | c = 0 159 | t0 =default_timer() 160 | for j in range(N//bsize): 161 | 162 | #Sample random feilds 163 | w0 = GRF.sample(bsize) 164 | 165 | #Solve NS 166 | sol, sol_t = navier_stokes_2d(w0, f, 1e-3, 50.0, 1e-4, record_steps) 167 | 168 | a[c:(c+bsize),...] = w0 169 | u[c:(c+bsize),...] 
class LowRank2d(nn.Module):
    """Low-rank kernel integral layer for functions sampled at ``n`` points.

    Two networks ``phi`` and ``psi`` map the 3 per-point input features ``a``
    to rank-``rank`` factors of an ``out_channels x in_channels`` kernel; the
    layer then computes

        out[b, m, o] = (1/n) * sum_{n, i, r} psi[b, n, o, i, r] * v[b, n, i] * phi[b, m, o, i, r]

    Args:
        in_channels: Number of channels of the input ``v``.
        out_channels: Number of channels of the output.
        s: Grid resolution per axis; stored, and ``s*s`` is kept as ``self.n``.
        width: Kept for backward compatibility.  The kernel networks are now
            sized from ``in_channels*out_channels*rank`` (the shape the
            forward pass reshapes to), which is identical whenever
            ``width == in_channels == out_channels``.
        rank: Rank of the kernel factorisation.

    NOTE: ``DenseNet`` is expected to be provided by the utilities module
    imported at the top of this file.
    """

    def __init__(self, in_channels, out_channels, s, width, rank):
        super(LowRank2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.s = s
        self.n = s*s
        self.rank = rank

        # Output size must match the reshape in forward(); using width*width
        # only worked when width happened to equal both channel counts.
        kernel_size = out_channels*in_channels*rank
        self.phi = DenseNet([3, 64, 128, 256, kernel_size], torch.nn.ReLU)
        self.psi = DenseNet([3, 64, 128, 256, kernel_size], torch.nn.ReLU)

    def forward(self, v, a):
        """Apply the layer.

        Args:
            v: Tensor of shape (batch, n, in_channels).
            a: Tensor of shape (batch, n, 3) with the per-point features fed
                to the kernel networks.

        Returns:
            Tensor of shape (batch, n, out_channels).
        """
        # Take n from the input so the layer also works at resolutions other
        # than the one given at construction.
        batch_size, n = v.shape[0], v.shape[1]

        kernel_shape = (batch_size, n, self.out_channels, self.in_channels, self.rank)
        phi_eval = self.phi(a).reshape(kernel_shape)
        psi_eval = self.psi(a).reshape(kernel_shape)

        return torch.einsum('bnoir,bni,bmoir->bmo', psi_eval, v, phi_eval) / n


class MyNet(torch.nn.Module):
    """Four low-rank kernel layers with pointwise linear skip paths.

    Input is (batch, n, 3); it is lifted to ``width`` channels by ``fc0``,
    passed through four ``LowRank2d + Linear -> BatchNorm`` stages (ReLU after
    all but the last) and projected to one value per point by ``fc1``/``fc2``.

    Args:
        s: Grid resolution per axis, forwarded to the low-rank layers.
        width: Number of hidden channels.
        rank: Rank of every low-rank kernel.
    """

    def __init__(self, s, width=32, rank=1):
        super(MyNet, self).__init__()
        self.s = s
        self.width = width
        self.rank = rank

        self.fc0 = nn.Linear(3, self.width)

        # Creation order is kept so that seeded initialisation is unchanged.
        self.net1 = LowRank2d(self.width, self.width, s, width, rank=self.rank)
        self.net2 = LowRank2d(self.width, self.width, s, width, rank=self.rank)
        self.net3 = LowRank2d(self.width, self.width, s, width, rank=self.rank)
        self.net4 = LowRank2d(self.width, self.width, s, width, rank=self.rank)
        self.w1 = nn.Linear(self.width, self.width)
        self.w2 = nn.Linear(self.width, self.width)
        self.w3 = nn.Linear(self.width, self.width)
        self.w4 = nn.Linear(self.width, self.width)

        self.bn1 = torch.nn.BatchNorm1d(self.width)
        self.bn2 = torch.nn.BatchNorm1d(self.width)
        self.bn3 = torch.nn.BatchNorm1d(self.width)
        self.bn4 = torch.nn.BatchNorm1d(self.width)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, v):
        """Map (batch, n, 3) inputs to (batch, n) outputs."""
        batch_size, n = v.shape[0], v.shape[1]
        # The raw input is only read (never modified in place), so it can be
        # shared with the kernel networks without cloning.
        a = v

        v = self.fc0(v)

        stages = (
            (self.net1, self.w1, self.bn1),
            (self.net2, self.w2, self.bn2),
            (self.net3, self.w3, self.bn3),
            (self.net4, self.w4, self.bn4),
        )
        last = len(stages) - 1
        for idx, (kernel, linear, norm) in enumerate(stages):
            v = kernel(v, a) + linear(v)
            # BatchNorm1d normalises over channels: flatten (batch, n).
            v = norm(v.reshape(-1, self.width)).view(batch_size, n, self.width)
            if idx != last:
                v = F.relu(v)

        v = self.fc1(v)
        v = F.relu(v)
        v = self.fc2(v)

        # Drop only the trailing singleton channel; a bare squeeze() would
        # also remove the batch axis for a batch of one.
        return v.squeeze(-1)

    def count_params(self):
        """Return the total number of scalar parameters."""
        return sum(p.numel() for p in self.parameters())
ntrain = 1000 129 | ntest = 100 130 | 131 | batch_size = 10 132 | 133 | r = 5 134 | h = int(((421 - 1)/r) + 1) 135 | s = h 136 | 137 | learning_rate = 0.00025 138 | 139 | reader = MatReader(TRAIN_PATH) 140 | x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s].reshape(ntrain,s*s) 141 | y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s].reshape(ntrain,s*s) 142 | 143 | reader.load_file(TEST_PATH) 144 | x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s].reshape(ntest,s*s) 145 | y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s].reshape(ntest,s*s) 146 | 147 | 148 | x_normalizer = UnitGaussianNormalizer(x_train) 149 | x_train = x_normalizer.encode(x_train) 150 | x_test = x_normalizer.encode(x_test) 151 | 152 | y_normalizer = UnitGaussianNormalizer(y_train) 153 | y_train = y_normalizer.encode(y_train) 154 | 155 | grids = [] 156 | grids.append(np.linspace(0, 1, s)) 157 | grids.append(np.linspace(0, 1, s)) 158 | grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T 159 | grid = grid.reshape(1,s*s,2) 160 | grid = torch.tensor(grid, dtype=torch.float) 161 | x_train = torch.cat([x_train.reshape(ntrain,s*s,1), grid.repeat(ntrain,1,1)], dim=2) 162 | x_test = torch.cat([x_test.reshape(ntest,s*s,1), grid.repeat(ntest,1,1)], dim=2) 163 | 164 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 165 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 166 | 167 | model = MyNet(s).cuda() 168 | # model = MyNet_old(s).cuda() 169 | 170 | print(model.count_params()) 171 | 172 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 173 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) 174 | epochs = 200 175 | 176 | myloss = LpLoss(size_average=False) 177 | y_normalizer.cuda() 178 | for ep in range(epochs): 179 | model.train() 
#################################################
#
# Utilities
#
#################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# reading data
class MatReader(object):
    """Read named fields from a MATLAB ``.mat`` file.

    The file is first opened with ``scipy.io.loadmat``; when SciPy rejects it
    the file is opened read-only with ``h5py`` instead.

    Args:
        file_path: Path of the ``.mat`` file.
        to_torch: Convert fields to ``torch.Tensor``.
        to_cuda: Move converted fields to the GPU.
        to_float: Cast fields to ``float32``.
    """

    def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
        super(MatReader, self).__init__()

        self.to_torch = to_torch
        self.to_cuda = to_cuda
        self.to_float = to_float

        self.file_path = file_path

        self.data = None
        self.old_mat = None
        self._load_file()

    def _load_file(self):
        """Open ``self.file_path`` and record which backend succeeded."""
        try:
            self.data = scipy.io.loadmat(self.file_path)
            self.old_mat = True
        except (NotImplementedError, ValueError):
            # SciPy raises NotImplementedError for v7.3 (HDF5) files and
            # ValueError for files it does not recognise.  Other errors such
            # as a missing file are no longer swallowed.  Open read-only so
            # the data file can never be created or modified.
            self.data = h5py.File(self.file_path, 'r')
            self.old_mat = False

    def load_file(self, file_path):
        """Switch the reader to another file."""
        # Release the previous HDF5 handle, if any, before replacing it.
        if self.old_mat is False and self.data is not None:
            self.data.close()
        self.file_path = file_path
        self._load_file()

    def read_field(self, field):
        """Return the array stored under ``field``, converted as configured."""
        x = self.data[field]

        if not self.old_mat:
            # Materialise the h5py dataset and reverse the axis order
            # (presumably to undo MATLAB's column-major layout -- confirm).
            x = x[()]
            x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1))

        if self.to_float:
            x = x.astype(np.float32)

        if self.to_torch:
            x = torch.from_numpy(x)

            if self.to_cuda:
                x = x.cuda()

        return x

    def set_cuda(self, to_cuda):
        self.to_cuda = to_cuda

    def set_torch(self, to_torch):
        self.to_torch = to_torch

    def set_float(self, to_float):
        self.to_float = to_float


# normalization, pointwise gaussian
class UnitGaussianNormalizer(object):
    """Standardise every entry independently across the first (sample) axis.

    Args:
        x: Training data of shape ntrain*n, ntrain*T*n or ntrain*n*T.
        eps: Added to the standard deviation to avoid division by zero.
    """

    def __init__(self, x, eps=0.00001):
        super(UnitGaussianNormalizer, self).__init__()

        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        self.eps = eps

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Invert ``encode``.

        Args:
            x: Encoded data of shape batch*n or T*batch*n.
            sample_idx: Optional indices selecting the statistics that belong
                to the entries of ``x``.

        Raises:
            ValueError: If ``sample_idx`` has more index dimensions than the
                stored statistics.
        """
        if sample_idx is None:
            std = self.std + self.eps  # n
            mean = self.mean
        else:
            stat_rank = len(self.mean.shape)
            idx_rank = len(sample_idx[0].shape)
            if stat_rank == idx_rank:
                std = self.std[sample_idx] + self.eps  # batch*n
                mean = self.mean[sample_idx]
            elif stat_rank > idx_rank:
                std = self.std[:, sample_idx] + self.eps  # T*batch*n
                mean = self.mean[:, sample_idx]
            else:
                # Previously this fell through with std/mean unbound.
                raise ValueError("sample_idx has more dimensions than the stored statistics")

        return (x * std) + mean

    def cuda(self):
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()

    def cpu(self):
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()


# normalization, Gaussian
class GaussianNormalizer(object):
    """Standardise with a single mean and standard deviation of all entries."""

    def __init__(self, x, eps=0.00001):
        super(GaussianNormalizer, self).__init__()

        self.mean = torch.mean(x)
        self.std = torch.std(x)
        self.eps = eps

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Invert ``encode``; ``sample_idx`` is accepted but ignored."""
        return (x * (self.std + self.eps)) + self.mean

    def cuda(self):
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()

    def cpu(self):
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()


# normalization, scaling by range
class RangeNormalizer(object):
    """Affinely map every feature from its observed range to ``[low, high]``.

    Features that are constant over the samples (zero range) are mapped to
    ``high`` instead of producing inf/NaN.

    Args:
        x: Data whose first axis enumerates samples.
        low: Lower end of the target range.
        high: Upper end of the target range.
    """

    def __init__(self, x, low=0.0, high=1.0):
        super(RangeNormalizer, self).__init__()
        mymin = torch.min(x, 0)[0].reshape(-1)
        mymax = torch.max(x, 0)[0].reshape(-1)

        # A zero range would make the scale infinite; use a unit range there.
        span = mymax - mymin
        span = torch.where(span == 0, torch.ones_like(span), span)

        self.a = (high - low)/span
        self.b = -self.a*mymax + high

    def encode(self, x):
        """Scale ``x`` (samples along axis 0) into the target range."""
        s = x.size()
        # reshape (not view) so that non-contiguous inputs are accepted
        x = x.reshape(s[0], -1)
        x = self.a*x + self.b
        return x.reshape(s)

    def decode(self, x):
        """Invert ``encode``."""
        s = x.size()
        x = x.reshape(s[0], -1)
        x = (x - self.b)/self.a
        return x.reshape(s)

    def cuda(self):
        self.a = self.a.cuda()
        self.b = self.b.cuda()

    def cpu(self):
        self.a = self.a.cpu()
        self.b = self.b.cpu()


#loss function with rel/abs Lp loss
class LpLoss(object):
    """Absolute or relative Lp loss computed per sample.

    Calling the object evaluates the relative loss.

    Args:
        d: Dimension used in the mesh-size weight of the absolute loss.
        p: Order of the norm.
        size_average: With ``reduction``, average (True) or sum (False) the
            per-sample values.
        reduction: If False, return the per-sample values unreduced.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        #Dimension and Lp-norm type are positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def _reduce(self, values):
        # Shared reduction logic for abs() and rel().
        if not self.reduction:
            return values
        if self.size_average:
            return torch.mean(values)
        return torch.sum(values)

    def abs(self, x, y):
        """Mesh-weighted Lp norm of ``x - y`` for every sample."""
        num_examples = x.size()[0]

        #Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        # reshape (as in rel) so that non-contiguous inputs are accepted
        diff = x.reshape(num_examples, -1) - y.reshape(num_examples, -1)
        all_norms = (h**(self.d/self.p))*torch.norm(diff, self.p, 1)

        return self._reduce(all_norms)

    def rel(self, x, y):
        """Lp norm of ``x - y`` divided by the Lp norm of ``y``, per sample."""
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1)

        return self._reduce(diff_norms/y_norms)

    def __call__(self, x, y):
        return self.rel(x, y)


# A simple feedforward neural network
class DenseNet(torch.nn.Module):
    """Fully connected network.

    Args:
        layers: Layer widths, input first; at least two entries.
        nonlinearity: Module class instantiated after every hidden layer.
        out_nonlinearity: Optional module class applied to the output.
        normalize: Insert ``BatchNorm1d`` before every hidden nonlinearity.
    """

    def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False):
        super(DenseNet, self).__init__()

        self.n_layers = len(layers) - 1

        assert self.n_layers >= 1

        self.layers = nn.ModuleList()

        for j in range(self.n_layers):
            self.layers.append(nn.Linear(layers[j], layers[j+1]))

            # Hidden layers only: optional batch norm, then the nonlinearity.
            if j != self.n_layers - 1:
                if normalize:
                    self.layers.append(nn.BatchNorm1d(layers[j+1]))

                self.layers.append(nonlinearity())

        if out_nonlinearity is not None:
            self.layers.append(out_nonlinearity())

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x
################################################################
# lowrank layer
################################################################
class LowRank1d(nn.Module):
    """Low-rank kernel integral layer for functions sampled at ``n`` points.

    Two networks ``phi`` and ``psi`` map the 2 per-point input features ``a``
    to rank-``rank`` factors of an ``out_channels x in_channels`` kernel; the
    layer then computes

        out[b, m, o] = (1/n) * sum_{n, i, r} psi[b, n, o, i, r] * v[b, n, i] * phi[b, m, o, i, r]

    Args:
        in_channels: Number of channels of the input ``v``.
        out_channels: Number of channels of the output.
        s: Grid resolution; stored and also kept as ``self.n``.
        width: Kept for backward compatibility.  The kernel networks are now
            sized from ``in_channels*out_channels*rank`` (the shape the
            forward pass reshapes to), which is identical whenever
            ``width == in_channels == out_channels``.
        rank: Rank of the kernel factorisation.
    """

    def __init__(self, in_channels, out_channels, s, width, rank=1):
        super(LowRank1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.s = s
        self.n = s
        self.rank = rank

        # Output size must match the reshape in forward(); using width*width
        # only worked when width happened to equal both channel counts.
        kernel_size = out_channels*in_channels*rank
        self.phi = DenseNet([2, 64, 128, 256, kernel_size], torch.nn.ReLU)
        self.psi = DenseNet([2, 64, 128, 256, kernel_size], torch.nn.ReLU)

    def forward(self, v, a):
        """Apply the layer.

        Args:
            v: Tensor of shape (batch, n, in_channels).
            a: Tensor of shape (batch, n, 2) with the per-point features fed
                to the kernel networks.

        Returns:
            Tensor of shape (batch, n, out_channels).
        """
        # Take n from the input so the layer also works at resolutions other
        # than the one given at construction.
        batch_size, n = v.shape[0], v.shape[1]

        kernel_shape = (batch_size, n, self.out_channels, self.in_channels, self.rank)
        phi_eval = self.phi(a).reshape(kernel_shape)
        psi_eval = self.psi(a).reshape(kernel_shape)

        return torch.einsum('bnoir,bni,bmoir->bmo', psi_eval, v, phi_eval) / n


class MyNet(torch.nn.Module):
    """Four low-rank kernel layers with pointwise linear skip paths.

    Input is (batch, n, 2); it is lifted to ``width`` channels by ``fc0``,
    passed through four ``LowRank1d + Linear -> BatchNorm`` stages (ReLU after
    all but the last) and projected to one value per point by ``fc1``/``fc2``.

    Args:
        s: Grid resolution, forwarded to the low-rank layers.
        width: Number of hidden channels.
        rank: Rank of every low-rank kernel.
    """

    def __init__(self, s, width=32, rank=4):
        super(MyNet, self).__init__()
        self.s = s
        self.width = width
        self.rank = rank

        self.fc0 = nn.Linear(2, self.width)

        # Creation order is kept so that seeded initialisation is unchanged.
        self.net1 = LowRank1d(self.width, self.width, s, width, rank=self.rank)
        self.net2 = LowRank1d(self.width, self.width, s, width, rank=self.rank)
        self.net3 = LowRank1d(self.width, self.width, s, width, rank=self.rank)
        self.net4 = LowRank1d(self.width, self.width, s, width, rank=self.rank)
        self.w1 = nn.Linear(self.width, self.width)
        self.w2 = nn.Linear(self.width, self.width)
        self.w3 = nn.Linear(self.width, self.width)
        self.w4 = nn.Linear(self.width, self.width)

        self.bn1 = torch.nn.BatchNorm1d(self.width)
        self.bn2 = torch.nn.BatchNorm1d(self.width)
        self.bn3 = torch.nn.BatchNorm1d(self.width)
        self.bn4 = torch.nn.BatchNorm1d(self.width)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, v):
        """Map (batch, n, 2) inputs to (batch, n) outputs."""
        batch_size, n = v.shape[0], v.shape[1]
        # The raw input is only read (never modified in place), so it can be
        # shared with the kernel networks without cloning.
        a = v

        v = self.fc0(v)

        stages = (
            (self.net1, self.w1, self.bn1),
            (self.net2, self.w2, self.bn2),
            (self.net3, self.w3, self.bn3),
            (self.net4, self.w4, self.bn4),
        )
        last = len(stages) - 1
        for idx, (kernel, linear, norm) in enumerate(stages):
            v = kernel(v, a) + linear(v)
            # BatchNorm1d normalises over channels: flatten (batch, n).
            v = norm(v.reshape(-1, self.width)).view(batch_size, n, self.width)
            if idx != last:
                v = F.relu(v)

        v = self.fc1(v)
        v = F.relu(v)
        v = self.fc2(v)

        # Drop only the trailing singleton channel; a bare squeeze() would
        # also remove the batch axis for a batch of one.
        return v.squeeze(-1)

    def count_params(self):
        """Return the total number of scalar parameters."""
        return sum(p.numel() for p in self.parameters())
x_data[:ntrain,:] 146 | y_train = y_data[:ntrain,:] 147 | x_test = x_data[-ntest:,:] 148 | y_test = y_data[-ntest:,:] 149 | 150 | x_normalizer = UnitGaussianNormalizer(x_train) 151 | x_train = x_normalizer.encode(x_train) 152 | x_test = x_normalizer.encode(x_test) 153 | 154 | y_normalizer = UnitGaussianNormalizer(y_train) 155 | y_train = y_normalizer.encode(y_train) 156 | 157 | grid = np.linspace(0, 2*np.pi, s).reshape(1, s, 1) 158 | grid = torch.tensor(grid, dtype=torch.float) 159 | x_train = torch.cat([x_train.reshape(ntrain,s,1), grid.repeat(ntrain,1,1)], dim=2) 160 | x_test = torch.cat([x_test.reshape(ntest,s,1), grid.repeat(ntest,1,1)], dim=2) 161 | 162 | 163 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 164 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 165 | 166 | model = MyNet(s).cuda() 167 | print(model.count_params()) 168 | 169 | ################################################################ 170 | # training and evaluation 171 | ################################################################ 172 | 173 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 174 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) 175 | epochs = 500 176 | 177 | myloss = LpLoss(size_average=False) 178 | y_normalizer.cuda() 179 | for ep in range(epochs): 180 | model.train() 181 | t1 = default_timer() 182 | train_mse = 0 183 | train_l2 = 0 184 | for x, y in train_loader: 185 | x, y = x.cuda(), y.cuda() 186 | 187 | optimizer.zero_grad() 188 | out = model(x).reshape(batch_size, s) 189 | 190 | mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1), reduction='mean') 191 | # mse.backward() 192 | 193 | out = y_normalizer.decode(out) 194 | y = y_normalizer.decode(y) 195 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 196 | 
loss.backward() 197 | 198 | optimizer.step() 199 | train_mse += mse.item() 200 | train_l2 += loss.item() 201 | 202 | scheduler.step() 203 | 204 | model.eval() 205 | test_l2 = 0.0 206 | with torch.no_grad(): 207 | for x, y in test_loader: 208 | x, y = x.cuda(), y.cuda() 209 | 210 | out = model(x).reshape(batch_size, s) 211 | out = y_normalizer.decode(out) 212 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 213 | 214 | train_mse /= len(train_loader) 215 | train_l2 /= ntrain 216 | test_l2 /= ntest 217 | 218 | t2 = default_timer() 219 | print(ep, t2-t1, train_mse, train_l2, test_l2) 220 | -------------------------------------------------------------------------------- /fourier_1d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.parameter import Parameter 6 | import matplotlib.pyplot as plt 7 | 8 | import operator 9 | from functools import reduce 10 | from functools import partial 11 | from timeit import default_timer 12 | from utilities3 import * 13 | 14 | torch.manual_seed(0) 15 | np.random.seed(0) 16 | 17 | #Complex multiplication 18 | def compl_mul1d(a, b): 19 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 20 | op = partial(torch.einsum, "bix,iox->box") 21 | return torch.stack([ 22 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 23 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 24 | ], dim=-1) 25 | 26 | ################################################################ 27 | # 1d fourier layer 28 | ################################################################ 29 | class SpectralConv1d(nn.Module): 30 | def __init__(self, in_channels, out_channels, modes1): 31 | super(SpectralConv1d, self).__init__() 32 | self.in_channels = in_channels 33 | self.out_channels = out_channels 34 | self.modes1 = modes1 #Number of Fourier modes to multiply, 
at most floor(N/2) + 1 35 | 36 | 37 | self.scale = (1 / (in_channels*out_channels)) 38 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, 2)) 39 | 40 | def forward(self, x): 41 | batchsize = x.shape[0] 42 | #Compute Fourier coeffcients up to factor of e^(- something constant) 43 | x_ft = torch.rfft(x, 1, normalized=True, onesided=True) 44 | 45 | # Multiply relevant Fourier modes 46 | out_ft = torch.zeros(batchsize, self.in_channels, x.size(-1)//2 + 1, 2, device=x.device) 47 | out_ft[:, :, :self.modes1] = compl_mul1d(x_ft[:, :, :self.modes1], self.weights1) 48 | 49 | #Return to physical space 50 | x = torch.irfft(out_ft, 1, normalized=True, onesided=True, signal_sizes=(x.size(-1), )) 51 | return x 52 | 53 | class SimpleBlock1d(nn.Module): 54 | def __init__(self, modes, width): 55 | super(SimpleBlock1d, self).__init__() 56 | 57 | self.modes1 = modes 58 | self.width = width 59 | self.fc0 = nn.Linear(2, self.width) 60 | 61 | self.conv0 = SpectralConv1d(self.width, self.width, self.modes1) 62 | self.conv1 = SpectralConv1d(self.width, self.width, self.modes1) 63 | self.conv2 = SpectralConv1d(self.width, self.width, self.modes1) 64 | self.conv3 = SpectralConv1d(self.width, self.width, self.modes1) 65 | self.w0 = nn.Conv1d(self.width, self.width, 1) 66 | self.w1 = nn.Conv1d(self.width, self.width, 1) 67 | self.w2 = nn.Conv1d(self.width, self.width, 1) 68 | self.w3 = nn.Conv1d(self.width, self.width, 1) 69 | self.bn0 = torch.nn.BatchNorm1d(self.width) 70 | self.bn1 = torch.nn.BatchNorm1d(self.width) 71 | self.bn2 = torch.nn.BatchNorm1d(self.width) 72 | self.bn3 = torch.nn.BatchNorm1d(self.width) 73 | 74 | 75 | self.fc1 = nn.Linear(self.width, 128) 76 | self.fc2 = nn.Linear(128, 1) 77 | 78 | def forward(self, x): 79 | 80 | x = self.fc0(x) 81 | x = x.permute(0, 2, 1) 82 | 83 | x1 = self.conv0(x) 84 | x2 = self.w0(x) 85 | x = self.bn0(x1 + x2) 86 | x = F.relu(x) 87 | x1 = self.conv1(x) 88 | x2 = self.w1(x) 89 | x = self.bn1(x1 + x2) 
class Net1d(nn.Module):
    """Top-level 1-D model: a ``SimpleBlock1d`` followed by a squeeze."""

    def __init__(self, modes, width):
        """Build the wrapped block.

        Args:
            modes: number of Fourier modes, forwarded to ``SimpleBlock1d``.
            width: channel width, forwarded to ``SimpleBlock1d``.
        """
        super(Net1d, self).__init__()

        self.conv1 = SimpleBlock1d(modes, width)

    def forward(self, x):
        """Run the wrapped block and drop every size-1 dimension.

        Note that ``squeeze()`` without an argument also removes the batch
        dimension when the batch holds a single sample.
        """
        x = self.conv1(x)
        return x.squeeze()

    def count_params(self):
        """Return the total number of scalar entries over all parameters."""
        # numel() also handles 0-dim parameters, for which reducing an empty
        # size list with operator.mul raises TypeError.
        return sum(p.numel() for p in self.parameters())
torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 165 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 166 | 167 | # model 168 | model = Net1d(modes, width).cuda() 169 | print(model.count_params()) 170 | 171 | 172 | ################################################################ 173 | # training and evaluation 174 | ################################################################ 175 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 176 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 177 | 178 | myloss = LpLoss(size_average=False) 179 | for ep in range(epochs): 180 | model.train() 181 | t1 = default_timer() 182 | train_mse = 0 183 | train_l2 = 0 184 | for x, y in train_loader: 185 | x, y = x.cuda(), y.cuda() 186 | 187 | optimizer.zero_grad() 188 | out = model(x) 189 | 190 | mse = F.mse_loss(out, y, reduction='mean') 191 | # mse.backward() 192 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 193 | l2.backward() 194 | 195 | optimizer.step() 196 | train_mse += mse.item() 197 | train_l2 += l2.item() 198 | 199 | scheduler.step() 200 | model.eval() 201 | test_l2 = 0.0 202 | with torch.no_grad(): 203 | for x, y in test_loader: 204 | x, y = x.cuda(), y.cuda() 205 | 206 | out = model(x) 207 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 208 | 209 | train_mse /= len(train_loader) 210 | train_l2 /= ntrain 211 | test_l2 /= ntest 212 | 213 | t2 = default_timer() 214 | print(ep, t2-t1, train_mse, train_l2, test_l2) 215 | 216 | # torch.save(model, 'model/ns_fourier_burgers_8192') 217 | pred = torch.zeros(y_test.shape) 218 | index = 0 219 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 220 | with torch.no_grad(): 221 | for x, y in 
def compl_mul3d(a, b):
    """Complex multiply-and-contract for tensors stored as (..., 2) real pairs.

    The last axis of both arguments holds (real, imaginary).  The remaining
    axes are contracted as
    (batch, in_channel, x, y, t), (in_channel, out_channel, x, y, t)
    -> (batch, out_channel, x, y, t).
    """
    pattern = "bixyz,ioxyz->boxyz"
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]

    # (a_re + i a_im) * (b_re + i b_im)
    real_part = torch.einsum(pattern, a_re, b_re) - torch.einsum(pattern, a_im, b_im)
    imag_part = torch.einsum(pattern, a_im, b_re) + torch.einsum(pattern, a_re, b_im)
    return torch.stack([real_part, imag_part], dim=-1)
def forward(self, x):
    """Apply the truncated 3-D spectral convolution to ``x``.

    A real FFT is taken over the last three axes.  Four corner blocks of the
    spectrum (low/high along the first two transformed axes, low along the
    last one) are multiplied by ``weights1`` .. ``weights4``; every other
    mode stays zero.  The result is transformed back to the input's spatial
    size.

    Args:
        x: real tensor; the indexing below treats it as
            (batch, channel, x, y, t).

    Returns:
        Real tensor with ``self.out_channels`` channels and the same last
        three sizes as ``x``.
    """
    batchsize = x.shape[0]
    size_x, size_y, size_z = x.size(-3), x.size(-2), x.size(-1)
    m1, m2, m3 = self.modes1, self.modes2, self.modes3
    # torch.rfft / torch.irfft were removed in PyTorch 1.8; use torch.fft
    # there and keep the legacy call on versions that still provide it.
    legacy_fft = hasattr(torch, "rfft")

    # Compute Fourier coefficients up to a constant factor
    if legacy_fft:
        x_ft = torch.rfft(x, 3, normalized=True, onesided=True)
    else:
        x_ft = torch.view_as_real(torch.fft.rfftn(x, dim=(-3, -2, -1), norm="ortho"))

    # Multiply relevant Fourier modes.  The buffer is sized by out_channels:
    # compl_mul3d returns (batch, out_channel, ...), so sizing it by
    # in_channels fails whenever the two channel counts differ.
    out_ft = torch.zeros(batchsize, self.out_channels, size_x, size_y, size_z // 2 + 1, 2,
                         device=x.device)
    out_ft[:, :, :m1, :m2, :m3] = compl_mul3d(x_ft[:, :, :m1, :m2, :m3], self.weights1)
    out_ft[:, :, -m1:, :m2, :m3] = compl_mul3d(x_ft[:, :, -m1:, :m2, :m3], self.weights2)
    out_ft[:, :, :m1, -m2:, :m3] = compl_mul3d(x_ft[:, :, :m1, -m2:, :m3], self.weights3)
    out_ft[:, :, -m1:, -m2:, :m3] = compl_mul3d(x_ft[:, :, -m1:, -m2:, :m3], self.weights4)

    # Return to physical space
    if legacy_fft:
        return torch.irfft(out_ft, 3, normalized=True, onesided=True,
                           signal_sizes=(size_x, size_y, size_z))
    return torch.fft.irfftn(torch.view_as_complex(out_ft), s=(size_x, size_y, size_z),
                            dim=(-3, -2, -1), norm="ortho")
class Net2d(nn.Module):
    """Top-level space-time model: a ``SimpleBlock2d`` followed by a squeeze."""

    def __init__(self, modes, width):
        """Build the wrapped block.

        ``modes`` is used for the first two mode counts of ``SimpleBlock2d``;
        the third mode count is fixed to 6.
        """
        super(Net2d, self).__init__()
        self.conv1 = SimpleBlock2d(modes, modes, 6, width)

    def forward(self, x):
        """Run the wrapped block and drop every size-1 dimension.

        ``squeeze()`` without an argument also removes the batch dimension
        when the batch holds a single sample.
        """
        x = self.conv1(x)
        return x.squeeze()

    def count_params(self):
        """Return the total number of scalar entries over all parameters."""
        # numel() is well defined for 0-dim parameters too, where reducing an
        # empty size list with operator.mul raises TypeError.
        return sum(p.numel() for p in self.parameters())
'data/ns_data_V1e-4_N20_T50_test.mat' 138 | 139 | 140 | ntest = 20 141 | 142 | sub = 1 143 | sub_t = 1 144 | S = 64 145 | T_in = 10 146 | T = 20 147 | 148 | indent = 1 149 | 150 | # load data 151 | reader = MatReader(TEST_PATH) 152 | test_a = reader.read_field('u')[:,::sub,::sub, 3:T_in*4:4] 153 | test_u = reader.read_field('u')[:,::sub,::sub, indent+T_in*4:indent+(T+T_in)*4:sub_t] 154 | 155 | print(test_a.shape, test_u.shape) 156 | 157 | # pad the location information (s,t) 158 | S = S * (4//sub) 159 | T = T * (4//sub_t) 160 | 161 | test_a = test_a.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1]) 162 | 163 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 164 | gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1]) 165 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 166 | gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1]) 167 | gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float) 168 | gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1]) 169 | 170 | test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]), 171 | gridt.repeat([ntest,1,1,1,1]), test_a), dim=-1) 172 | 173 | t2 = default_timer() 174 | print('preprocessing finished, time used:', t2-t1) 175 | device = torch.device('cuda') 176 | 177 | # load model 178 | model = torch.load('model/ns_fourier_V1e-4_T20_N9800_ep200_m12_w32') 179 | 180 | print(model.count_params()) 181 | 182 | # test 183 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 184 | myloss = LpLoss(size_average=False) 185 | pred = torch.zeros(test_u.shape) 186 | index = 0 187 | with torch.no_grad(): 188 | test_l2 = 0 189 | for x, y in test_loader: 190 | x, y = x.cuda(), y.cuda() 191 | 192 | out = model(x) 193 | pred[index] = out 194 | loss = myloss(out.view(1, -1), y.view(1, -1)).item() 195 | test_l2 += loss 196 | print(index, loss) 197 | index = index + 1 198 | print(test_l2/ntest) 199 | 
def compl_mul3d(a, b):
    """Multiply two complex tensors given as (..., 2) real/imaginary pairs.

    Channel contraction:
    (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
    with the trailing axis of size 2 carried through as (real, imaginary).
    """
    def contract(lhs, rhs):
        # Sum over the shared in_channel axis, mode by mode.
        return torch.einsum("bixyz,ioxyz->boxyz", lhs, rhs)

    re_a = a[..., 0]
    im_a = a[..., 1]
    re_b = b[..., 0]
    im_b = b[..., 1]

    re_out = contract(re_a, re_b) - contract(im_a, im_b)
    im_out = contract(im_a, re_b) + contract(re_a, im_b)
    return torch.stack([re_out, im_out], dim=-1)
def forward(self, x):
    """Spectral convolution over the last three axes of ``x``.

    Steps: real FFT over the last three axes, multiplication of four corner
    blocks of retained modes by ``weights1`` .. ``weights4`` (all other modes
    stay zero), inverse real FFT back to the input's spatial size.

    Args:
        x: real tensor; the indexing below treats it as
            (batch, channel, x, y, t).

    Returns:
        Real tensor with ``self.out_channels`` channels whose last three
        sizes equal those of ``x``.
    """
    batchsize = x.shape[0]
    spatial = (x.size(-3), x.size(-2), x.size(-1))
    m1, m2, m3 = self.modes1, self.modes2, self.modes3
    fft_dims = (-3, -2, -1)
    # The legacy torch.rfft / torch.irfft functions no longer exist from
    # PyTorch 1.8 on; use them only when the installed version has them.
    has_legacy_fft = hasattr(torch, "rfft")

    # Compute Fourier coefficients up to a constant factor
    if has_legacy_fft:
        x_ft = torch.rfft(x, 3, normalized=True, onesided=True)
    else:
        x_ft = torch.view_as_real(torch.fft.rfftn(x, dim=fft_dims, norm="ortho"))

    # Multiply relevant Fourier modes.  compl_mul3d yields out_channel
    # channels, so the buffer must use out_channels (not in_channels) or the
    # slice assignments fail when the counts differ.
    out_ft = torch.zeros(batchsize, self.out_channels, spatial[0], spatial[1],
                         spatial[2] // 2 + 1, 2, device=x.device)
    out_ft[:, :, :m1, :m2, :m3] = compl_mul3d(x_ft[:, :, :m1, :m2, :m3], self.weights1)
    out_ft[:, :, -m1:, :m2, :m3] = compl_mul3d(x_ft[:, :, -m1:, :m2, :m3], self.weights2)
    out_ft[:, :, :m1, -m2:, :m3] = compl_mul3d(x_ft[:, :, :m1, -m2:, :m3], self.weights3)
    out_ft[:, :, -m1:, -m2:, :m3] = compl_mul3d(x_ft[:, :, -m1:, -m2:, :m3], self.weights4)

    # Return to physical space
    if has_legacy_fft:
        return torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=spatial)
    return torch.fft.irfftn(torch.view_as_complex(out_ft), s=spatial, dim=fft_dims, norm="ortho")
class Net2d(nn.Module):
    """Evaluation-time wrapper around ``SimpleBlock2d`` that squeezes the output."""

    def __init__(self, modes, width):
        """Create the wrapped block with mode counts ``(modes, modes, 6)``."""
        super(Net2d, self).__init__()
        self.conv1 = SimpleBlock2d(modes, modes, 6, width)

    def forward(self, x):
        """Return the block output with all size-1 dimensions removed.

        Because ``squeeze()`` is called without an argument, a batch of one
        sample loses its batch dimension as well.
        """
        x = self.conv1(x)
        return x.squeeze()

    def count_params(self):
        """Return how many scalar values the model's parameters hold in total."""
        # Tensor.numel() avoids the TypeError that reduce(operator.mul, [])
        # raises for a 0-dim parameter.
        return sum(p.numel() for p in self.parameters())
149 | reader = MatReader(TEST_PATH) 150 | test_a = reader.read_field('u')[:,::sub,::sub, indent:T_in*4:4] #([0, T_in]) 151 | test_u = reader.read_field('u')[:,::sub,::sub, indent+T_in*4:indent+(T+T_in)*4:sub_t] #([T_in, T_in + T]) 152 | 153 | print(test_a.shape, test_u.shape) 154 | 155 | # pad the location information (s,t) 156 | S = S * (4//sub) 157 | T = T * (4//sub_t) 158 | 159 | test_a = test_a.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1]) 160 | 161 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 162 | gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1]) 163 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 164 | gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1]) 165 | gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float) 166 | gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1]) 167 | 168 | test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]), 169 | gridt.repeat([ntest,1,1,1,1]), test_a), dim=-1) 170 | 171 | t2 = default_timer() 172 | print('preprocessing finished, time used:', t2-t1) 173 | device = torch.device('cuda') 174 | 175 | # load model 176 | model = torch.load('model/ns_fourier_V1e-4_T20_N9800_ep200_m12_w32') 177 | 178 | print(model.count_params()) 179 | 180 | # test 181 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 182 | myloss = LpLoss(size_average=False) 183 | pred = torch.zeros(test_u.shape) 184 | index = 0 185 | with torch.no_grad(): 186 | test_l2 = 0 187 | for x, y in test_loader: 188 | x, y = x.cuda(), y.cuda() 189 | 190 | out = model(x) 191 | pred[index] = out 192 | loss = myloss(out.view(1, -1), y.view(1, -1)).item() 193 | test_l2 += loss 194 | print(index, loss) 195 | index = index + 1 196 | print(test_l2/ntest) 197 | 198 | path = 'eval' 199 | scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy(), 'u': test_u.cpu().numpy()}) 200 | 201 | 202 | 203 | 204 | 
#Complex multiplication
def compl_mul2d(a, b):
    """Complex multiply-and-contract for tensors stored as (..., 2) real pairs.

    (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
    with the trailing axis of size 2 holding (real, imaginary).
    """
    pattern = "bixy,ioxy->boxy"
    return torch.stack([
        torch.einsum(pattern, a[..., 0], b[..., 0]) - torch.einsum(pattern, a[..., 1], b[..., 1]),
        torch.einsum(pattern, a[..., 1], b[..., 0]) + torch.einsum(pattern, a[..., 0], b[..., 1])
    ], dim=-1)


################################################################
# fourier layer
################################################################
class SpectralConv2d(nn.Module):
    """2-D Fourier layer: FFT, linear transform on retained modes, inverse FFT."""

    def __init__(self, in_channels, out_channels, modes1, modes2):
        """Create the layer.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            modes1: number of Fourier modes to multiply along the first
                transformed axis, at most floor(N/2) + 1.
            modes2: number of Fourier modes kept along the last axis.
        """
        super(SpectralConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        # Complex weights stored as (in, out, modes1, modes2, 2) real pairs;
        # weights1 acts on the low block, weights2 on the high (negative
        # index) block of the first transformed axis.
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))

    def forward(self, x):
        """Apply the spectral convolution over the last two axes of ``x``.

        Args:
            x: real tensor; the indexing below treats it as
                (batch, channel, x, y).

        Returns:
            Real tensor with ``self.out_channels`` channels and the same last
            two sizes as ``x``.
        """
        batchsize = x.shape[0]
        size_x, size_y = x.size(-2), x.size(-1)
        m1, m2 = self.modes1, self.modes2
        # torch.rfft / torch.irfft were removed in PyTorch 1.8; fall back to
        # torch.fft there and keep the legacy call where it still exists.
        legacy_fft = hasattr(torch, "rfft")

        # Compute Fourier coefficients up to a constant factor
        if legacy_fft:
            x_ft = torch.rfft(x, 2, normalized=True, onesided=True)
        else:
            x_ft = torch.view_as_real(torch.fft.rfftn(x, dim=(-2, -1), norm="ortho"))

        # Multiply relevant Fourier modes.  The buffer uses out_channels:
        # compl_mul2d returns (batch, out_channel, ...), so sizing it by
        # in_channels fails whenever the two channel counts differ.
        out_ft = torch.zeros(batchsize, self.out_channels, size_x, size_y // 2 + 1, 2, device=x.device)
        out_ft[:, :, :m1, :m2] = compl_mul2d(x_ft[:, :, :m1, :m2], self.weights1)
        out_ft[:, :, -m1:, :m2] = compl_mul2d(x_ft[:, :, -m1:, :m2], self.weights2)

        # Return to physical space
        if legacy_fft:
            return torch.irfft(out_ft, 2, normalized=True, onesided=True, signal_sizes=(size_x, size_y))
        return torch.fft.irfftn(torch.view_as_complex(out_ft), s=(size_x, size_y), dim=(-2, -1), norm="ortho")
class Net2d(nn.Module):
    """Top-level 2-D model: a ``SimpleBlock2d`` followed by a squeeze."""

    def __init__(self, modes, width):
        """Build the wrapped block, using ``modes`` for both mode counts."""
        super(Net2d, self).__init__()

        self.conv1 = SimpleBlock2d(modes, modes, width)

    def forward(self, x):
        """Run the wrapped block and drop every size-1 dimension.

        ``squeeze()`` without an argument also removes the batch dimension
        when the batch holds a single sample.
        """
        x = self.conv1(x)
        return x.squeeze()

    def count_params(self):
        """Return the total number of scalar entries over all parameters."""
        # numel() also covers 0-dim parameters, for which reducing an empty
        # size list with operator.mul raises TypeError.
        return sum(p.numel() for p in self.parameters())
reader.read_field('sol')[:ntest,::r,::r][:,:s,:s] 169 | 170 | x_normalizer = UnitGaussianNormalizer(x_train) 171 | x_train = x_normalizer.encode(x_train) 172 | x_test = x_normalizer.encode(x_test) 173 | 174 | y_normalizer = UnitGaussianNormalizer(y_train) 175 | y_train = y_normalizer.encode(y_train) 176 | 177 | grids = [] 178 | grids.append(np.linspace(0, 1, s)) 179 | grids.append(np.linspace(0, 1, s)) 180 | grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T 181 | grid = grid.reshape(1,s,s,2) 182 | grid = torch.tensor(grid, dtype=torch.float) 183 | x_train = torch.cat([x_train.reshape(ntrain,s,s,1), grid.repeat(ntrain,1,1,1)], dim=3) 184 | x_test = torch.cat([x_test.reshape(ntest,s,s,1), grid.repeat(ntest,1,1,1)], dim=3) 185 | 186 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 187 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 188 | 189 | ################################################################ 190 | # training and evaluation 191 | ################################################################ 192 | model = Net2d(modes, width).cuda() 193 | print(model.count_params()) 194 | 195 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 196 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 197 | 198 | myloss = LpLoss(size_average=False) 199 | y_normalizer.cuda() 200 | for ep in range(epochs): 201 | model.train() 202 | t1 = default_timer() 203 | train_mse = 0 204 | for x, y in train_loader: 205 | x, y = x.cuda(), y.cuda() 206 | 207 | optimizer.zero_grad() 208 | # loss = F.mse_loss(model(x).view(-1), y.view(-1), reduction='mean') 209 | out = model(x) 210 | out = y_normalizer.decode(out) 211 | y = y_normalizer.decode(y) 212 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 213 | loss.backward() 214 | 
#Complex multiplication

def compl_mul2d(a, b):
    """Complex multiply-and-contract for tensors stored as (..., 2) real pairs.

    (batch, in_channel, x, y), (in_channel, out_channel, x, y)
    -> (batch, out_channel, x, y), with the trailing axis of size 2 holding
    (real, imaginary).

    The contraction runs over the *first* axis of ``b`` so that it matches the
    (in_channels, out_channels, ...) layout in which ``SpectralConv2d``
    allocates its weights; contracting over the second axis only works when
    both channel counts are equal.
    """
    pattern = "bixy,ioxy->boxy"
    return torch.stack([
        torch.einsum(pattern, a[..., 0], b[..., 0]) - torch.einsum(pattern, a[..., 1], b[..., 1]),
        torch.einsum(pattern, a[..., 1], b[..., 0]) + torch.einsum(pattern, a[..., 0], b[..., 1])
    ], dim=-1)


class SpectralConv2d(nn.Module):
    """2-D Fourier layer using the same number of modes along both axes."""

    def __init__(self, in_channels, out_channels, mode=None, modes=None):
        """Create the layer.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            mode: number of Fourier modes to multiply along each axis, at
                most floor(N/2) + 1 (original positional parameter).
            modes: keyword alias for ``mode``.  Every call site in this file
                passes ``modes=...``, which the original signature rejected
                with a TypeError.

        Raises:
            TypeError: if neither ``mode`` nor ``modes`` is given.
        """
        super(SpectralConv2d, self).__init__()
        if modes is None:
            modes = mode
        if modes is None:
            raise TypeError("SpectralConv2d requires the number of Fourier modes")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes
        self.modes2 = modes

        self.scale = (1 / (in_channels * out_channels))
        # Complex weights stored as (in, out, modes1, modes2, 2) real pairs.
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))

    def forward(self, x):
        """Apply the spectral convolution over the last two axes of ``x``.

        Args:
            x: real tensor; the indexing below treats it as
                (batch, channel, height, width).

        Returns:
            Real tensor with ``self.out_channels`` channels and the same last
            two sizes as ``x``.
        """
        batchsize = x.shape[0]
        size_x, size_y = x.size(-2), x.size(-1)
        m1, m2 = self.modes1, self.modes2
        # torch.rfft / torch.irfft were removed in PyTorch 1.8; fall back to
        # torch.fft there and keep the legacy call where it still exists.
        legacy_fft = hasattr(torch, "rfft")

        # Compute Fourier coefficients up to a constant factor
        if legacy_fft:
            x_ft = torch.rfft(x, 2, normalized=True, onesided=True)
        else:
            x_ft = torch.view_as_real(torch.fft.rfftn(x, dim=(-2, -1), norm="ortho"))

        # Multiply relevant Fourier modes.  The buffer is sized by
        # out_channels; this file builds layers such as 3 -> 32 channels, for
        # which an in_channels-sized buffer cannot hold the result.
        out_ft = torch.zeros(batchsize, self.out_channels, size_x, size_y // 2 + 1, 2, device=x.device)
        out_ft[:, :, :m1, :m2] = compl_mul2d(x_ft[:, :, :m1, :m2], self.weights1)
        out_ft[:, :, -m1:, :m2] = compl_mul2d(x_ft[:, :, -m1:, :m2], self.weights2)

        # Return to physical space
        if legacy_fft:
            return torch.irfft(out_ft, 2, normalized=True, onesided=True, signal_sizes=(size_x, size_y))
        return torch.fft.irfftn(torch.view_as_complex(out_ft), s=(size_x, size_y), dim=(-2, -1), norm="ortho")
class BasicBlock(nn.Module):
    """Residual block built from two spectral convolutions with batch norm.

    The skip path is the identity unless ``stride`` is not 1 or the channel
    count changes, in which case it is a spectral convolution followed by
    batch norm.  ``stride`` only influences that choice; it is not passed to
    any layer.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, modes=10):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes

        # Main path
        self.conv1 = SpectralConv2d(in_planes, planes, modes=modes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = SpectralConv2d(planes, planes, modes=modes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip path
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                SpectralConv2d(in_planes, out_planes, modes=modes),
                nn.BatchNorm2d(out_planes)
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)

        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        hidden += self.shortcut(x)
        return F.relu(hidden)
self.layer1(out) 162 | # out = F.avg_pool2d(out, 2) 163 | out = self.layer2(out) 164 | # out = F.avg_pool2d(out, 2) 165 | out = self.layer3(out) 166 | # out = F.avg_pool2d(out, 2) 167 | out = self.layer4(out) 168 | out = F.avg_pool2d(out, 4) 169 | # print(out.shape) 170 | out = out.view(out.size(0), -1) 171 | out = self.linear1(out) 172 | # out = F.relu(out) 173 | # out = self.linear2(out) 174 | return out 175 | 176 | def ResNet18(): 177 | return ResNet(BasicBlock, [3, 4, 23, 3]) 178 | 179 | 180 | ## Mnist 181 | # transform = transforms.Compose([transforms.ToTensor(), 182 | # transforms.Normalize((0.5,), (0.5,)), 183 | # ]) 184 | # trainset = torchvision.datasets.MNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform) 185 | # testset = torchvision.datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform) 186 | # trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True) 187 | # testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True) 188 | 189 | ## Cifar10 190 | transform = transforms.Compose( 191 | [transforms.ToTensor(), 192 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 193 | 194 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 195 | download=True, transform=transform) 196 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, 197 | shuffle=True, num_workers=4) 198 | 199 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 200 | download=True, transform=transform) 201 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, 202 | shuffle=False, num_workers=4) 203 | 204 | classes = ('plane', 'car', 'bird', 'cat', 205 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 206 | 207 | 208 | # model = Net2d().cuda() 209 | model = ResNet18().cuda() 210 | # model = torch.load('results/fourier_on_images') 211 | 212 | criterion = nn.CrossEntropyLoss() 213 | optimizer = 
torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4) 214 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.75) 215 | 216 | for epoch in range(50): # loop over the dataset multiple times 217 | running_loss = 0.0 218 | for i, data in enumerate(trainloader, 0): 219 | # get the inputs; data is a list of [inputs, labels] 220 | inputs, labels = data[0].cuda(), data[1].cuda() 221 | 222 | # zero the parameter gradients 223 | optimizer.zero_grad() 224 | 225 | # forward + backward + optimize 226 | outputs = model(inputs) 227 | loss = criterion(outputs, labels) 228 | loss.backward() 229 | optimizer.step() 230 | 231 | # print statistics 232 | running_loss += loss.item() 233 | if i % 100 == 99: # print every 2000 mini-batches 234 | print('[%d, %5d] loss: %.3f' % 235 | (epoch + 1, i + 1, running_loss / 100)) 236 | running_loss = 0.0 237 | 238 | correct = 0 239 | total = 0 240 | with torch.no_grad(): 241 | for data in testloader: 242 | images, labels = data[0].cuda(), data[1].cuda() 243 | 244 | outputs = model(images) 245 | _, predicted = torch.max(outputs.data, 1) 246 | total += labels.size(0) 247 | correct += (predicted == labels).sum().item() 248 | print('Accuracy of the network on the 10000 test images: %f %%' % ( 249 | 100 * correct / total)) 250 | 251 | torch.save(model, 'results/fourier_on_images_mnist_100') 252 | -------------------------------------------------------------------------------- /lowrank_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import matplotlib.pyplot as plt 7 | from utilities3 import * 8 | 9 | import operator 10 | from functools import reduce 11 | from functools import partial 12 | 13 | from timeit import default_timer 14 | import scipy.io 15 | 16 | torch.manual_seed(0) 17 | np.random.seed(0) 18 | 19 | 20 | 
class LowRank2d(nn.Module):
    """Low-rank kernel integral layer.

    Two point-wise networks, ``phi`` and ``psi``, map every input point to a
    ``(out_channels, in_channels, rank)`` factor; the layer output is the
    contraction of these factors with the input over points, input channels
    and rank (see ``forward``).

    ``DenseNet`` is not defined in this block; it is presumably provided by
    the ``from utilities3 import *`` at the top of the file.
    """

    def __init__(self, in_channels, out_channels, n, ker_width, rank):
        super(LowRank2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Nominal number of points; kept for callers that read it, but the
        # forward pass no longer depends on it.
        self.n = n
        self.rank = rank

        self.phi = DenseNet([in_channels, ker_width, in_channels*out_channels*rank], torch.nn.ReLU)
        self.psi = DenseNet([in_channels, ker_width, in_channels*out_channels*rank], torch.nn.ReLU)

    def forward(self, v):
        """Apply the low-rank kernel to ``v`` of shape (batch, points, in_channels).

        Returns a tensor of shape (batch, points, out_channels).
        """
        # Take the point count from the input instead of the constructor-time
        # ``self.n`` so the layer also accepts inputs with a different number
        # of points; for inputs with exactly ``self.n`` points this is the
        # same reshape as before.
        batch_size, num_points = v.shape[0], v.shape[1]
        factor_shape = (batch_size, num_points, self.out_channels, self.in_channels, self.rank)

        phi_eval = self.phi(v).reshape(factor_shape)
        psi_eval = self.psi(v).reshape(factor_shape)

        # Sum over source points n, input channels i and rank r.
        v = torch.einsum('bnoir,bni,bmoir->bmo', psi_eval, v, phi_eval)

        return v
class Net2d(nn.Module):
    """Wrapper that builds a ``MyNet`` for a fixed 64*64*40 point grid.

    ``width``, ``ker_width`` and ``rank`` are passed straight through to
    ``MyNet``.
    """

    def __init__(self, width=8, ker_width=128, rank=4):
        super(Net2d, self).__init__()

        self.conv1 = MyNet(n=64*64*40, width=width, ker_width=ker_width, rank=rank)

    def forward(self, x):
        """Delegate to the wrapped ``MyNet``."""
        return self.conv1(x)

    def count_params(self):
        """Return the total number of scalar entries over all parameters.

        ``numel()`` is used instead of reducing ``p.size()`` with a product:
        the reduction has no initial value and therefore raises ``TypeError``
        for a 0-dim (scalar) parameter, whose size list is empty.
        """
        return sum(p.numel() for p in self.parameters())
'data/ns_data_V1000_N5000.mat' 137 | # TEST_PATH = 'data/ns_data_V1000_N5000.mat' 138 | TRAIN_PATH = 'data/ns_data_V100_N1000_T50_1.mat' 139 | TEST_PATH = 'data/ns_data_V100_N1000_T50_2.mat' 140 | 141 | ntrain = 1000 142 | ntest = 200 143 | 144 | batch_size = 2 145 | batch_size2 = batch_size 146 | 147 | epochs = 500 148 | learning_rate = 0.0025 149 | scheduler_step = 100 150 | scheduler_gamma = 0.5 151 | 152 | print(epochs, learning_rate, scheduler_step, scheduler_gamma) 153 | 154 | path = 'ns_lowrank_V100_T40_N'+str(ntrain)+'_ep' + str(epochs) 155 | path_model = 'model/'+path 156 | path_train_err = 'results/'+path+'train.txt' 157 | path_test_err = 'results/'+path+'test.txt' 158 | path_image = 'image/'+path 159 | 160 | runtime = np.zeros(2, ) 161 | t1 = default_timer() 162 | 163 | 164 | sub = 1 165 | S = 64 166 | T_in = 10 167 | T = 40 168 | 169 | ################################################################ 170 | # load data 171 | ################################################################ 172 | 173 | reader = MatReader(TRAIN_PATH) 174 | train_a = reader.read_field('u')[:ntrain,::sub,::sub,:T_in] 175 | train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in] 176 | 177 | reader = MatReader(TEST_PATH) 178 | test_a = reader.read_field('u')[-ntest:,::sub,::sub,:T_in] 179 | test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in] 180 | 181 | print(train_u.shape) 182 | print(test_u.shape) 183 | assert (S == train_u.shape[-2]) 184 | assert (T == train_u.shape[-1]) 185 | 186 | 187 | a_normalizer = UnitGaussianNormalizer(train_a) 188 | train_a = a_normalizer.encode(train_a) 189 | test_a = a_normalizer.encode(test_a) 190 | 191 | y_normalizer = UnitGaussianNormalizer(train_u) 192 | train_u = y_normalizer.encode(train_u) 193 | 194 | train_a = train_a.reshape(ntrain,S,S,1,T_in).repeat([1,1,1,T,1]) 195 | test_a = test_a.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1]) 196 | 197 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 198 | gridx = 
gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1]) 199 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 200 | gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1]) 201 | gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float) 202 | gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1]) 203 | 204 | train_a = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]), 205 | gridt.repeat([ntrain,1,1,1,1]), train_a), dim=-1) 206 | test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]), 207 | gridt.repeat([ntest,1,1,1,1]), test_a), dim=-1) 208 | 209 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 210 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 211 | 212 | t2 = default_timer() 213 | 214 | print('preprocessing finished, time used:', t2-t1) 215 | device = torch.device('cuda') 216 | 217 | 218 | ################################################################ 219 | # training and evaluation 220 | ################################################################ 221 | model = Net2d().cuda() 222 | # model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20') 223 | 224 | print(model.count_params()) 225 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 226 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 227 | 228 | 229 | myloss = LpLoss(size_average=False) 230 | y_normalizer.cuda() 231 | for ep in range(epochs): 232 | model.train() 233 | t1 = default_timer() 234 | train_mse = 0 235 | train_l2 = 0 236 | for x, y in train_loader: 237 | x, y = x.cuda(), y.cuda() 238 | 239 | optimizer.zero_grad() 240 | out = model(x) 241 | 242 | mse = F.mse_loss(out, y, reduction='mean') 243 | # mse.backward() 244 | 245 | y = y_normalizer.decode(y) 246 | 
out = y_normalizer.decode(out) 247 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 248 | l2.backward() 249 | 250 | optimizer.step() 251 | train_mse += mse.item() 252 | train_l2 += l2.item() 253 | 254 | scheduler.step() 255 | 256 | model.eval() 257 | test_l2 = 0.0 258 | with torch.no_grad(): 259 | for x, y in test_loader: 260 | x, y = x.cuda(), y.cuda() 261 | 262 | out = model(x) 263 | out = y_normalizer.decode(out) 264 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 265 | 266 | train_mse /= len(train_loader) 267 | train_l2 /= ntrain 268 | test_l2 /= ntest 269 | 270 | t2 = default_timer() 271 | print(ep, t2-t1, train_mse, train_l2, test_l2) 272 | # torch.save(model, path_model) 273 | 274 | 275 | pred = torch.zeros(test_u.shape) 276 | index = 0 277 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 278 | with torch.no_grad(): 279 | for x, y in test_loader: 280 | test_l2 = 0; 281 | x, y = x.cuda(), y.cuda() 282 | 283 | out = model(x) 284 | out = y_normalizer.decode(out) 285 | pred[index] = out 286 | 287 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 288 | print(index, test_l2) 289 | index = index + 1 290 | 291 | # scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()}) 292 | 293 | 294 | 295 | 296 | -------------------------------------------------------------------------------- /lowrank_2d_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import matplotlib.pyplot as plt 7 | from utilities3 import * 8 | 9 | import operator 10 | from functools import reduce 11 | from functools import partial 12 | 13 | from timeit import default_timer 14 | import scipy.io 15 | 16 | torch.manual_seed(0) 17 | np.random.seed(0) 18 | 19 | activation = F.relu 20 | 21 | 
class LowRank2d(nn.Module):
    """Low-rank kernel integral layer on an ``s`` x ``s`` grid.

    Two point-wise networks, ``phi`` and ``psi``, map every input point to a
    ``(out_channels, in_channels, rank)`` factor; the layer output is the
    contraction of these factors with the input over points, input channels
    and rank (see ``forward``).

    ``DenseNet`` is not defined in this block; it is presumably provided by
    the ``from utilities3 import *`` at the top of the file.
    """

    def __init__(self, in_channels, out_channels, s, ker_width, rank):
        super(LowRank2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.s = s
        # Nominal number of grid points; kept for callers that read it, but
        # the forward pass no longer depends on it.
        self.n = s*s
        self.rank = rank

        self.phi = DenseNet([in_channels, ker_width, in_channels*out_channels*rank], torch.nn.ReLU)
        self.psi = DenseNet([in_channels, ker_width, in_channels*out_channels*rank], torch.nn.ReLU)

    def forward(self, v):
        """Apply the low-rank kernel to ``v`` of shape (batch, points, in_channels).

        Returns a tensor of shape (batch, points, out_channels).
        """
        # Take the point count from the input instead of the constructor-time
        # ``self.n`` so the layer also accepts inputs sampled at a different
        # resolution; for inputs with exactly ``s*s`` points this is the same
        # reshape as before.
        batch_size, num_points = v.shape[0], v.shape[1]
        factor_shape = (batch_size, num_points, self.out_channels, self.in_channels, self.rank)

        phi_eval = self.phi(v).reshape(factor_shape)
        psi_eval = self.psi(v).reshape(factor_shape)

        # Sum over source points n, input channels i and rank r.
        v = torch.einsum('bnoir,bni,bmoir->bmo', psi_eval, v, phi_eval)

        return v
class Net2d(nn.Module):
    """Wrapper that builds a ``MyNet`` for a 64 x 64 grid (``s=64``).

    ``width``, ``ker_width`` and ``rank`` are passed straight through to
    ``MyNet``.
    """

    def __init__(self, width=12, ker_width=128, rank=4):
        super(Net2d, self).__init__()

        self.conv1 = MyNet(s=64, width=width, ker_width=ker_width, rank=rank)

    def forward(self, x):
        """Delegate to the wrapped ``MyNet``."""
        return self.conv1(x)

    def count_params(self):
        """Return the total number of scalar entries over all parameters.

        ``numel()`` is used instead of reducing ``p.size()`` with a product:
        the reduction has no initial value and therefore raises ``TypeError``
        for a 0-dim (scalar) parameter, whose size list is empty.
        """
        return sum(p.numel() for p in self.parameters())
'data/ns_data_V1000_N5000.mat' 139 | TRAIN_PATH = 'data/ns_data_V100_N1000_T50_1.mat' 140 | TEST_PATH = 'data/ns_data_V100_N1000_T50_2.mat' 141 | 142 | ntrain = 1000 143 | ntest = 200 144 | 145 | batch_size = 5 146 | batch_size2 = batch_size 147 | 148 | epochs = 500 149 | learning_rate = 0.0025 150 | scheduler_step = 100 151 | scheduler_gamma = 0.5 152 | 153 | print(epochs, learning_rate, scheduler_step, scheduler_gamma) 154 | 155 | path = 'ns_lowrank_rnn_V100_T40_N'+str(ntrain)+'_ep' + str(epochs) + '_m' 156 | path_model = 'model/'+path 157 | path_train_err = 'results/'+path+'train.txt' 158 | path_test_err = 'results/'+path+'test.txt' 159 | path_image = 'image/'+path 160 | 161 | 162 | runtime = np.zeros(2, ) 163 | t1 = default_timer() 164 | 165 | 166 | sub = 1 167 | S = 64 168 | T_in = 10 169 | T = 40 170 | step = 1 171 | 172 | 173 | ################################################################ 174 | # load dataset 175 | ################################################################ 176 | reader = MatReader(TRAIN_PATH) 177 | train_a = reader.read_field('u')[:ntrain,::sub,::sub,:T_in] 178 | train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in] 179 | 180 | reader = MatReader(TEST_PATH) 181 | test_a = reader.read_field('u')[-ntest:,::sub,::sub,:T_in] 182 | test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in] 183 | 184 | print(train_u.shape) 185 | print(test_u.shape) 186 | assert (S == train_u.shape[-2]) 187 | assert (T == train_u.shape[-1]) 188 | 189 | 190 | train_a = train_a.reshape(ntrain,S,S,T_in) 191 | test_a = test_a.reshape(ntest,S,S,T_in) 192 | 193 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 194 | gridx = gridx.reshape(1, S, 1, 1).repeat([1, 1, S, 1]) 195 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 196 | gridy = gridy.reshape(1, 1, S, 1).repeat([1, S, 1, 1]) 197 | 198 | train_a = torch.cat((gridx.repeat([ntrain,1,1,1]), gridy.repeat([ntrain,1,1,1]), train_a), dim=-1) 199 | test_a = 
torch.cat((gridx.repeat([ntest,1,1,1]), gridy.repeat([ntest,1,1,1]), test_a), dim=-1) 200 | 201 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 202 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 203 | 204 | t2 = default_timer() 205 | 206 | print('preprocessing finished, time used:', t2-t1) 207 | device = torch.device('cuda') 208 | 209 | ################################################################ 210 | # training and evaluation 211 | ################################################################ 212 | model = Net2d().cuda() 213 | # model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20') 214 | 215 | print(model.count_params()) 216 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 217 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 218 | 219 | 220 | myloss = LpLoss(size_average=False) 221 | gridx = gridx.to(device) 222 | gridy = gridy.to(device) 223 | 224 | for ep in range(epochs): 225 | model.train() 226 | t1 = default_timer() 227 | train_l2_step = 0 228 | train_l2_full = 0 229 | for xx, yy in train_loader: 230 | loss = 0 231 | xx = xx.to(device) 232 | yy = yy.to(device) 233 | 234 | for t in range(0, T, step): 235 | y = yy[..., t:t + step] 236 | im = model(xx) 237 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 238 | 239 | if t == 0: 240 | pred = im 241 | else: 242 | pred = torch.cat((pred, im), -1) 243 | 244 | xx = torch.cat((xx[..., step:-2], im, 245 | gridx.repeat([batch_size, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1])), dim=-1) 246 | 247 | train_l2_step += loss.item() 248 | l2_full = myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)) 249 | train_l2_full += l2_full.item() 250 | 251 | optimizer.zero_grad() 252 | loss.backward() 253 | # 
l2_full.backward() 254 | optimizer.step() 255 | 256 | test_l2_step = 0 257 | test_l2_full = 0 258 | with torch.no_grad(): 259 | for xx, yy in test_loader: 260 | loss = 0 261 | xx = xx.to(device) 262 | yy = yy.to(device) 263 | 264 | for t in range(0, T, step): 265 | y = yy[..., t:t + step] 266 | im = model(xx) 267 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 268 | 269 | if t == 0: 270 | pred = im 271 | else: 272 | pred = torch.cat((pred, im), -1) 273 | 274 | xx = torch.cat((xx[..., step:-2], im, 275 | gridx.repeat([batch_size, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1])), dim=-1) 276 | 277 | 278 | test_l2_step += loss.item() 279 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 280 | 281 | t2 = default_timer() 282 | scheduler.step() 283 | print(ep, t2 - t1, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step), 284 | test_l2_full / ntest) 285 | # torch.save(model, path_model) 286 | 287 | 288 | # pred = torch.zeros(test_u.shape) 289 | # index = 0 290 | # test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 291 | # with torch.no_grad(): 292 | # for x, y in test_loader: 293 | # test_l2 = 0; 294 | # x, y = x.cuda(), y.cuda() 295 | # 296 | # out = model(x) 297 | # out = y_normalizer.decode(out) 298 | # pred[index] = out 299 | # 300 | # test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 301 | # print(index, test_l2) 302 | # index = index + 1 303 | 304 | # scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()}) 305 | 306 | -------------------------------------------------------------------------------- /fourier_2d_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import matplotlib.pyplot as plt 7 | from utilities3 import * 8 | 9 
def compl_mul2d(a, b):
    """Multiply-and-contract two complex tensors stored with a trailing (real, imag) axis.

    The einsum sums over the second axis of both operands, so with
    ``a`` of shape (batch, in_channel, x, y, 2) and ``b`` of shape
    (out_channel, in_channel, x, y, 2) the result has shape
    (batch, out_channel, x, y, 2).
    """
    op = partial(torch.einsum, "bctq,dctq->bdtq")
    return torch.stack([
        op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]),
        op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1])
    ], dim=-1)


class SpectralConv2d_fast(nn.Module):
    """2D Fourier layer: FFT, per-mode linear map on the lowest modes, inverse FFT.

    Only the ``modes1`` lowest positive and negative frequencies along the
    first spatial axis and the ``modes2`` lowest frequencies along the second
    (one-sided) axis are transformed; all other modes are set to zero.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d_fast, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        # ``compl_mul2d`` contracts the second axis of the weights with the
        # channel axis of the input, i.e. it reads the weights as
        # (out_channels, in_channels, ...). Allocating them in that order makes
        # the layer work for in_channels != out_channels; when the two are
        # equal the shape (and therefore any saved state) is unchanged.
        weight_shape = (out_channels, in_channels, self.modes1, self.modes2, 2)
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape))

    def forward(self, x):
        """Map ``x`` of shape (batch, in_channels, X, Y) to (batch, out_channels, X, Y)."""
        batchsize = x.shape[0]
        size_x, size_y = x.size(-2), x.size(-1)

        # ``torch.rfft``/``torch.irfft`` were removed in PyTorch 1.8; use them
        # when present and otherwise the ``torch.fft`` module, with
        # ``norm="ortho"`` matching ``normalized=True``.
        legacy_fft = hasattr(torch, "rfft")

        #Compute Fourier coefficients, stored as (..., 2) real/imag pairs
        if legacy_fft:
            x_ft = torch.rfft(x, 2, normalized=True, onesided=True)
        else:
            x_ft = torch.view_as_real(torch.fft.rfft2(x, norm="ortho"))

        # Multiply relevant Fourier modes. The buffer must have out_channels
        # channels, because that is what ``compl_mul2d`` returns.
        out_ft = torch.zeros(batchsize, self.out_channels, size_x, size_y//2 + 1, 2,
                             dtype=x_ft.dtype, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        #Return to physical space
        if legacy_fft:
            return torch.irfft(out_ft, 2, normalized=True, onesided=True,
                               signal_sizes=(size_x, size_y))
        return torch.fft.irfft2(torch.view_as_complex(out_ft), s=(size_x, size_y), norm="ortho")
class Net2d(nn.Module):
    """Wrapper around ``SimpleBlock2d`` using the same mode count on both axes.

    ``modes`` is passed as both ``modes1`` and ``modes2``; ``width`` is passed
    through unchanged.
    """

    def __init__(self, modes, width):
        super(Net2d, self).__init__()

        self.conv1 = SimpleBlock2d(modes, modes, width)

    def forward(self, x):
        """Delegate to the wrapped ``SimpleBlock2d``."""
        return self.conv1(x)

    def count_params(self):
        """Return the total number of scalar entries over all parameters.

        ``numel()`` is used instead of reducing ``p.size()`` with a product:
        the reduction has no initial value and therefore raises ``TypeError``
        for a 0-dim (scalar) parameter, whose size list is empty.
        """
        return sum(p.numel() for p in self.parameters())
reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in] 183 | 184 | print(train_u.shape) 185 | print(test_u.shape) 186 | assert (S == train_u.shape[-2]) 187 | assert (T == train_u.shape[-1]) 188 | 189 | train_a = train_a.reshape(ntrain,S,S,T_in) 190 | test_a = test_a.reshape(ntest,S,S,T_in) 191 | 192 | # pad the location (x,y) 193 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 194 | gridx = gridx.reshape(1, S, 1, 1).repeat([1, 1, S, 1]) 195 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 196 | gridy = gridy.reshape(1, 1, S, 1).repeat([1, S, 1, 1]) 197 | 198 | train_a = torch.cat((gridx.repeat([ntrain,1,1,1]), gridy.repeat([ntrain,1,1,1]), train_a), dim=-1) 199 | test_a = torch.cat((gridx.repeat([ntest,1,1,1]), gridy.repeat([ntest,1,1,1]), test_a), dim=-1) 200 | 201 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 202 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 203 | 204 | t2 = default_timer() 205 | 206 | print('preprocessing finished, time used:', t2-t1) 207 | device = torch.device('cuda') 208 | 209 | ################################################################ 210 | # training and evaluation 211 | ################################################################ 212 | 213 | model = Net2d(modes, width).cuda() 214 | # model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20') 215 | 216 | print(model.count_params()) 217 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 218 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 219 | 220 | 221 | myloss = LpLoss(size_average=False) 222 | gridx = gridx.to(device) 223 | gridy = gridy.to(device) 224 | 225 | for ep in range(epochs): 226 | model.train() 227 | t1 = default_timer() 228 | train_l2_step = 0 229 | train_l2_full = 0 230 
| for xx, yy in train_loader: 231 | loss = 0 232 | xx = xx.to(device) 233 | yy = yy.to(device) 234 | 235 | for t in range(0, T, step): 236 | y = yy[..., t:t + step] 237 | im = model(xx) 238 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 239 | 240 | if t == 0: 241 | pred = im 242 | else: 243 | pred = torch.cat((pred, im), -1) 244 | 245 | xx = torch.cat((xx[..., step:-2], im, 246 | gridx.repeat([batch_size, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1])), dim=-1) 247 | 248 | train_l2_step += loss.item() 249 | l2_full = myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)) 250 | train_l2_full += l2_full.item() 251 | 252 | optimizer.zero_grad() 253 | loss.backward() 254 | # l2_full.backward() 255 | optimizer.step() 256 | 257 | test_l2_step = 0 258 | test_l2_full = 0 259 | with torch.no_grad(): 260 | for xx, yy in test_loader: 261 | loss = 0 262 | xx = xx.to(device) 263 | yy = yy.to(device) 264 | 265 | for t in range(0, T, step): 266 | y = yy[..., t:t + step] 267 | im = model(xx) 268 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 269 | 270 | if t == 0: 271 | pred = im 272 | else: 273 | pred = torch.cat((pred, im), -1) 274 | 275 | xx = torch.cat((xx[..., step:-2], im, 276 | gridx.repeat([batch_size, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1])), dim=-1) 277 | 278 | 279 | test_l2_step += loss.item() 280 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 281 | 282 | t2 = default_timer() 283 | scheduler.step() 284 | print(ep, t2 - t1, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step), 285 | test_l2_full / ntest) 286 | # torch.save(model, path_model) 287 | 288 | 289 | # pred = torch.zeros(test_u.shape) 290 | # index = 0 291 | # test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 292 | # with torch.no_grad(): 293 | # for x, y in test_loader: 294 | # test_l2 = 0; 295 
| # x, y = x.cuda(), y.cuda() 296 | # 297 | # out = model(x) 298 | # out = y_normalizer.decode(out) 299 | # pred[index] = out 300 | # 301 | # test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 302 | # print(index, test_l2) 303 | # index = index + 1 304 | 305 | # scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()}) 306 | 307 | 308 | 309 | 310 | -------------------------------------------------------------------------------- /fourier_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import matplotlib.pyplot as plt 7 | from utilities3 import * 8 | 9 | import operator 10 | from functools import reduce 11 | from functools import partial 12 | 13 | from timeit import default_timer 14 | import scipy.io 15 | 16 | torch.manual_seed(0) 17 | np.random.seed(0) 18 | 19 | activation = F.relu 20 | 21 | ################################################################ 22 | # 3d fourier layers 23 | ################################################################ 24 | 25 | def compl_mul3d(a, b): 26 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 27 | op = partial(torch.einsum, "bixyz,ioxyz->boxyz") 28 | return torch.stack([ 29 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 30 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 31 | ], dim=-1) 32 | 33 | class SpectralConv3d_fast(nn.Module): 34 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3): 35 | super(SpectralConv3d_fast, self).__init__() 36 | self.in_channels = in_channels 37 | self.out_channels = out_channels 38 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 39 | self.modes2 = modes2 40 | self.modes3 = modes3 41 | 42 | self.scale = (1 / (in_channels * out_channels)) 43 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, 
out_channels, self.modes1, self.modes2, self.modes3, 2)) 44 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 45 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 46 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 47 | 48 | def forward(self, x): 49 | batchsize = x.shape[0] 50 | #Compute Fourier coeffcients up to factor of e^(- something constant) 51 | x_ft = torch.rfft(x, 3, normalized=True, onesided=True) 52 | 53 | # Multiply relevant Fourier modes 54 | out_ft = torch.zeros(batchsize, self.in_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, 2, device=x.device) 55 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \ 56 | compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1) 57 | out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \ 58 | compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2) 59 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \ 60 | compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3) 61 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \ 62 | compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4) 63 | 64 | #Return to physical space 65 | x = torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=(x.size(-3), x.size(-2), x.size(-1))) 66 | return x 67 | 68 | class SimpleBlock2d(nn.Module): 69 | def __init__(self, modes1, modes2, modes3, width): 70 | super(SimpleBlock2d, self).__init__() 71 | 72 | self.modes1 = modes1 73 | self.modes2 = modes2 74 | self.modes3 = modes3 75 | self.width = width 76 | self.fc0 = nn.Linear(13, self.width) 77 | 78 | self.conv0 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 79 | self.conv1 = 
SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 80 | self.conv2 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 81 | self.conv3 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 82 | self.w0 = nn.Conv1d(self.width, self.width, 1) 83 | self.w1 = nn.Conv1d(self.width, self.width, 1) 84 | self.w2 = nn.Conv1d(self.width, self.width, 1) 85 | self.w3 = nn.Conv1d(self.width, self.width, 1) 86 | self.bn0 = torch.nn.BatchNorm3d(self.width) 87 | self.bn1 = torch.nn.BatchNorm3d(self.width) 88 | self.bn2 = torch.nn.BatchNorm3d(self.width) 89 | self.bn3 = torch.nn.BatchNorm3d(self.width) 90 | 91 | 92 | self.fc1 = nn.Linear(self.width, 128) 93 | self.fc2 = nn.Linear(128, 1) 94 | 95 | def forward(self, x): 96 | batchsize = x.shape[0] 97 | size_x, size_y, size_z = x.shape[1], x.shape[2], x.shape[3] 98 | 99 | x = self.fc0(x) 100 | x = x.permute(0, 4, 1, 2, 3) 101 | 102 | x1 = self.conv0(x) 103 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 104 | x = self.bn0(x1 + x2) 105 | x = F.relu(x) 106 | x1 = self.conv1(x) 107 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 108 | x = self.bn1(x1 + x2) 109 | x = F.relu(x) 110 | x1 = self.conv2(x) 111 | x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 112 | x = self.bn2(x1 + x2) 113 | x = F.relu(x) 114 | x1 = self.conv3(x) 115 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 116 | x = self.bn3(x1 + x2) 117 | 118 | 119 | x = x.permute(0, 2, 3, 4, 1) 120 | x = self.fc1(x) 121 | x = F.relu(x) 122 | x = self.fc2(x) 123 | return x 124 | 125 | class Net2d(nn.Module): 126 | def __init__(self, modes, width): 127 | super(Net2d, self).__init__() 128 | 129 | self.conv1 = SimpleBlock2d(modes, modes, modes, width) 130 | 131 | 132 | def 
forward(self, x): 133 | x = self.conv1(x) 134 | return x.squeeze() 135 | 136 | 137 | def count_params(self): 138 | c = 0 139 | for p in self.parameters(): 140 | c += reduce(operator.mul, list(p.size())) 141 | 142 | return c 143 | 144 | ################################################################ 145 | # configs 146 | ################################################################ 147 | 148 | # TRAIN_PATH = 'data/ns_data_V1000_N1000_train.mat' 149 | # TEST_PATH = 'data/ns_data_V1000_N1000_train_2.mat' 150 | # TRAIN_PATH = 'data/ns_data_V1000_N5000.mat' 151 | # TEST_PATH = 'data/ns_data_V1000_N5000.mat' 152 | TRAIN_PATH = 'data/ns_data_V100_N1000_T50_1.mat' 153 | TEST_PATH = 'data/ns_data_V100_N1000_T50_2.mat' 154 | 155 | ntrain = 1000 156 | ntest = 200 157 | 158 | modes = 4 159 | width = 20 160 | 161 | batch_size = 10 162 | batch_size2 = batch_size 163 | 164 | epochs = 10 165 | learning_rate = 0.0025 166 | scheduler_step = 100 167 | scheduler_gamma = 0.5 168 | 169 | print(epochs, learning_rate, scheduler_step, scheduler_gamma) 170 | 171 | path = 'test' 172 | # path = 'ns_fourier_V100_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width) 173 | path_model = 'model/'+path 174 | path_train_err = 'results/'+path+'train.txt' 175 | path_test_err = 'results/'+path+'test.txt' 176 | path_image = 'image/'+path 177 | 178 | 179 | runtime = np.zeros(2, ) 180 | t1 = default_timer() 181 | 182 | 183 | sub = 2 184 | S = 64 // sub 185 | T_in = 10 186 | T = 40 187 | 188 | ################################################################ 189 | # load data 190 | ################################################################ 191 | 192 | reader = MatReader(TRAIN_PATH) 193 | train_a = reader.read_field('u')[:ntrain,::sub,::sub,:T_in] 194 | train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in] 195 | 196 | reader = MatReader(TEST_PATH) 197 | test_a = reader.read_field('u')[-ntest:,::sub,::sub,:T_in] 198 | test_u = 
reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in] 199 | 200 | print(train_u.shape) 201 | print(test_u.shape) 202 | assert (S == train_u.shape[-2]) 203 | assert (T == train_u.shape[-1]) 204 | 205 | 206 | a_normalizer = UnitGaussianNormalizer(train_a) 207 | train_a = a_normalizer.encode(train_a) 208 | test_a = a_normalizer.encode(test_a) 209 | 210 | y_normalizer = UnitGaussianNormalizer(train_u) 211 | train_u = y_normalizer.encode(train_u) 212 | 213 | train_a = train_a.reshape(ntrain,S,S,1,T_in).repeat([1,1,1,T,1]) 214 | test_a = test_a.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1]) 215 | 216 | # pad locations (x,y,t) 217 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 218 | gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1]) 219 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 220 | gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1]) 221 | gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float) 222 | gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1]) 223 | 224 | train_a = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]), 225 | gridt.repeat([ntrain,1,1,1,1]), train_a), dim=-1) 226 | test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]), 227 | gridt.repeat([ntest,1,1,1,1]), test_a), dim=-1) 228 | 229 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 230 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 231 | 232 | t2 = default_timer() 233 | 234 | print('preprocessing finished, time used:', t2-t1) 235 | device = torch.device('cuda') 236 | 237 | ################################################################ 238 | # training and evaluation 239 | ################################################################ 240 | model = Net2d(modes, width).cuda() 241 | # model = 
torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20') 242 | 243 | print(model.count_params()) 244 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 245 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 246 | 247 | 248 | myloss = LpLoss(size_average=False) 249 | y_normalizer.cuda() 250 | for ep in range(epochs): 251 | model.train() 252 | t1 = default_timer() 253 | train_mse = 0 254 | train_l2 = 0 255 | for x, y in train_loader: 256 | x, y = x.cuda(), y.cuda() 257 | 258 | optimizer.zero_grad() 259 | out = model(x) 260 | 261 | mse = F.mse_loss(out, y, reduction='mean') 262 | # mse.backward() 263 | 264 | y = y_normalizer.decode(y) 265 | out = y_normalizer.decode(out) 266 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 267 | l2.backward() 268 | 269 | optimizer.step() 270 | train_mse += mse.item() 271 | train_l2 += l2.item() 272 | 273 | scheduler.step() 274 | 275 | model.eval() 276 | test_l2 = 0.0 277 | with torch.no_grad(): 278 | for x, y in test_loader: 279 | x, y = x.cuda(), y.cuda() 280 | 281 | out = model(x) 282 | out = y_normalizer.decode(out) 283 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 284 | 285 | train_mse /= len(train_loader) 286 | train_l2 /= ntrain 287 | test_l2 /= ntest 288 | 289 | t2 = default_timer() 290 | print(ep, t2-t1, train_mse, train_l2, test_l2) 291 | # torch.save(model, path_model) 292 | 293 | 294 | pred = torch.zeros(test_u.shape) 295 | index = 0 296 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 297 | with torch.no_grad(): 298 | for x, y in test_loader: 299 | test_l2 = 0 300 | x, y = x.cuda(), y.cuda() 301 | 302 | out = model(x) 303 | out = y_normalizer.decode(out) 304 | pred[index] = out 305 | 306 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 307 | print(index, test_l2) 308 | index = index + 1 309 | 310 | 
scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()}) 311 | 312 | 313 | 314 | 315 | -------------------------------------------------------------------------------- /scripts/ns_fourier_3d_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import matplotlib.pyplot as plt 7 | from utilities3 import * 8 | 9 | import operator 10 | from functools import reduce 11 | from functools import partial 12 | 13 | from timeit import default_timer 14 | import scipy.io 15 | 16 | torch.manual_seed(0) 17 | np.random.seed(0) 18 | 19 | ################################################################ 20 | # fourier layers 21 | ################################################################ 22 | 23 | def compl_mul3d(a, b): 24 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 25 | op = partial(torch.einsum, "bixyz,ioxyz->boxyz") 26 | return torch.stack([ 27 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 28 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 29 | ], dim=-1) 30 | 31 | class SpectralConv3d_fast(nn.Module): 32 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3): 33 | super(SpectralConv3d_fast, self).__init__() 34 | self.in_channels = in_channels 35 | self.out_channels = out_channels 36 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 37 | self.modes2 = modes2 38 | self.modes3 = modes3 39 | 40 | self.scale = (1 / (in_channels * out_channels)) 41 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 42 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 43 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, 
self.modes2, self.modes3, 2)) 44 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 45 | 46 | def forward(self, x): 47 | batchsize = x.shape[0] 48 | #Compute Fourier coeffcients up to factor of e^(- something constant) 49 | x_ft = torch.rfft(x, 3, normalized=True, onesided=True) 50 | 51 | # Multiply relevant Fourier modes 52 | out_ft = torch.zeros(batchsize, self.in_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, 2, device=x.device) 53 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \ 54 | compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1) 55 | out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \ 56 | compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2) 57 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \ 58 | compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3) 59 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \ 60 | compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4) 61 | 62 | #Return to physical space 63 | x = torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=(x.size(-3), x.size(-2), x.size(-1))) 64 | return x 65 | 66 | class SimpleBlock2d(nn.Module): 67 | def __init__(self, modes1, modes2, modes3, width): 68 | super(SimpleBlock2d, self).__init__() 69 | 70 | self.modes1 = modes1 71 | self.modes2 = modes2 72 | self.modes3 = modes3 73 | self.width = width 74 | self.fc0 = nn.Linear(4, self.width) 75 | 76 | self.conv0 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 77 | self.conv1 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 78 | self.conv2 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 79 | self.conv3 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 80 | self.w0 
= nn.Conv1d(self.width, self.width, 1) 81 | self.w1 = nn.Conv1d(self.width, self.width, 1) 82 | self.w2 = nn.Conv1d(self.width, self.width, 1) 83 | self.w3 = nn.Conv1d(self.width, self.width, 1) 84 | self.bn0 = torch.nn.BatchNorm3d(self.width) 85 | self.bn1 = torch.nn.BatchNorm3d(self.width) 86 | self.bn2 = torch.nn.BatchNorm3d(self.width) 87 | self.bn3 = torch.nn.BatchNorm3d(self.width) 88 | 89 | 90 | self.fc1 = nn.Linear(self.width, 128) 91 | self.fc2 = nn.Linear(128, 1) 92 | 93 | def forward(self, x): 94 | batchsize = x.shape[0] 95 | size_x, size_y, size_z = x.shape[1], x.shape[2], x.shape[3] 96 | 97 | x = self.fc0(x) 98 | x = x.permute(0, 4, 1, 2, 3) 99 | 100 | x1 = self.conv0(x) 101 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 102 | x = self.bn0(x1 + x2) 103 | x = F.relu(x) 104 | x1 = self.conv1(x) 105 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 106 | x = self.bn1(x1 + x2) 107 | x = F.relu(x) 108 | x1 = self.conv2(x) 109 | x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 110 | x = self.bn2(x1 + x2) 111 | x = F.relu(x) 112 | x1 = self.conv3(x) 113 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 114 | x = self.bn3(x1 + x2) 115 | 116 | 117 | x = x.permute(0, 2, 3, 4, 1) 118 | x = self.fc1(x) 119 | x = F.relu(x) 120 | x = self.fc2(x) 121 | return x 122 | 123 | class Net2d(nn.Module): 124 | def __init__(self, modes, width): 125 | super(Net2d, self).__init__() 126 | 127 | self.conv1 = SimpleBlock2d(modes, modes, 4, width) 128 | 129 | 130 | def forward(self, x): 131 | x = self.conv1(x) 132 | return x 133 | 134 | 135 | def count_params(self): 136 | c = 0 137 | for p in self.parameters(): 138 | c += reduce(operator.mul, list(p.size())) 139 | 140 | return c 141 | 142 | ################################################################ 143 | # configs 144 | 
################################################################
# configs
################################################################

TRAIN_PATH = 'data/ns_data_V10000_N1200_T20.mat'
TEST_PATH = 'data/ns_data_V10000_N1200_T20.mat'

ntrain = 1000
ntest = 200

modes = 12
width = 20

batch_size = 20
batch_size2 = batch_size


epochs = 500
learning_rate = 0.0025
scheduler_step = 100
scheduler_gamma = 0.5

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

path = 'ns_fourier_3d_rnn_V10000_T20_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width)
path_model = 'model/'+path
path_train_err = 'results/'+path+'train.txt'
path_test_err = 'results/'+path+'test.txt'
path_image = 'image/'+path


runtime = np.zeros(2, )
t1 = default_timer()


sub = 1
S = 64
T_in = 10
T_start = 0
step = T_in - T_start  # number of time steps consumed and produced per model call
T = 10

################################################################
# load data
################################################################

# Train and test come from the same file: the first `ntrain` and the last
# `ntest` samples respectively.
reader = MatReader(TRAIN_PATH)
train_a = reader.read_field('u')[:ntrain,::sub,::sub,T_start:T_in]
train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in]

reader = MatReader(TEST_PATH)
test_a = reader.read_field('u')[-ntest:,::sub,::sub,T_start:T_in]
test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in]

print(train_u.shape, test_u.shape)
assert (S == train_u.shape[-2])
assert (T == train_u.shape[-1])



train_a = train_a.reshape(ntrain,S,S,step,1)
test_a = test_a.reshape(ntest,S,S,step,1)

# cat the location information (x,y,t)
gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, step, 1])
gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, step, 1])
gridt = torch.tensor(np.linspace(0, 1, step+1)[1:], dtype=torch.float)
gridt = gridt.reshape(1, 1, 1, step, 1).repeat([1, S, S, 1, 1])

train_a = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]),
                     gridt.repeat([ntrain,1,1,1,1]), train_a), dim=-1)
test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]),
                    gridt.repeat([ntest,1,1,1,1]), test_a), dim=-1)

print(train_a.shape, train_u.shape)
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

################################################################
# training and evaluation
################################################################
model = Net2d(modes, width).cuda()
# model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20')

print(model.count_params())
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

myloss = LpLoss(size_average=False)

gridx = gridx.to(device)
gridy = gridy.to(device)
gridt = gridt.to(device)


def _rollout(net, xx, yy):
    """Roll `net` forward over the T target steps in chunks of `step`.

    Returns (summed per-chunk loss, prediction of shape (batch, S, S, T)).
    """
    bs = xx.shape[0]  # actual batch size; the last batch may be smaller
    # Coordinate channels that are re-attached to every predicted chunk.
    coords = torch.cat((gridx.repeat([bs, 1, 1, 1, 1]), gridy.repeat([bs, 1, 1, 1, 1]),
                        gridt.repeat([bs, 1, 1, 1, 1])), dim=-1)
    loss = 0
    chunks = []
    for t in range(0, T, step):
        y = yy[..., t:t + step]
        im = net(xx)
        loss += myloss(im.reshape(bs, -1), y.reshape(bs, -1))

        # Drop only the channel axis; a bare squeeze() would also remove the
        # time axis when step == 1 (or the batch axis when bs == 1).
        chunks.append(im.squeeze(-1))

        # Slide the input window: discard the oldest `step` frames and append
        # the prediction together with its coordinates.
        im = torch.cat((coords, im), dim=-1)
        xx = torch.cat([xx[..., step:, :], im], -2)
    return loss, torch.cat(chunks, -1)


for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2_step = 0
    train_l2_full = 0
    for xx, yy in train_loader:
        xx = xx.to(device)
        yy = yy.to(device)

        loss, pred = _rollout(model, xx, yy)

        train_l2_step += loss.item()
        l2_full = myloss(pred.reshape(pred.shape[0], -1), yy.reshape(yy.shape[0], -1))
        train_l2_full += l2_full.item()

        optimizer.zero_grad()
        loss.backward()
        # l2_full.backward()
        optimizer.step()

    # NOTE: evaluate in eval mode so BatchNorm uses its running statistics and
    # is not updated by test batches (same convention as fourier_3d.py).
    model.eval()
    test_l2_step = 0
    test_l2_full = 0
    with torch.no_grad():
        for xx, yy in test_loader:
            xx = xx.to(device)
            yy = yy.to(device)

            loss, pred = _rollout(model, xx, yy)

            test_l2_step += loss.item()
            test_l2_full += myloss(pred.reshape(pred.shape[0], -1), yy.reshape(yy.shape[0], -1)).item()

    t2 = default_timer()
    scheduler.step()
    print(ep, t2-t1, train_l2_step/ntrain/(T/step), train_l2_full/ntrain, test_l2_step/ntest/(T/step), test_l2_full/ntest)
# Persist the trained model once training has finished.
torch.save(model, path_model)


# pred = torch.zeros(test_u.shape)
# index = 0
# test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)
# with torch.no_grad():
#     for x, y in test_loader:
#         test_l2 = 0;
#         x, y = x.cuda(), y.cuda()
#
314 | # out = model(x) 315 | # out = y_normalizer.decode(out) 316 | # pred[index] = out 317 | # 318 | # test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 319 | # print(index, test_l2) 320 | # index = index + 1 321 | 322 | # scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()}) 323 | 324 | 325 | 326 | 327 | --------------------------------------------------------------------------------