├── .gitignore ├── FNO-torch.1.6 ├── fourier_1d.py ├── fourier_2d.py ├── fourier_2d_time.py └── fourier_3d.py ├── LICENSE ├── README.md ├── data_generation ├── burgers │ ├── GRF1.m │ ├── burgers1.m │ └── gen_burgers1.m ├── darcy │ ├── GRF.m │ ├── demo.m │ └── solve_gwf.m └── navier_stokes │ ├── ns_2d.py │ └── random_fields.py ├── fourier_1d.py ├── fourier_2d.py ├── fourier_2d_time.py ├── fourier_3d.py ├── lowrank_operators ├── lowrank_1d.py ├── lowrank_2d.py ├── lowrank_2d_time.py └── lowrank_3d.py ├── scripts ├── eval.py ├── fourier_2d_tuned.py ├── fourier_3d_time.py ├── fourier_on_images.py └── super_resolution.py └── utilities3.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mypy 121 | .mypy_cache/ 122 | .dmypy.json 123 | dmypy.json 124 | 125 | # Pyre type checker 126 | .pyre/ -------------------------------------------------------------------------------- /FNO-torch.1.6/fourier_1d.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 1D problem such as the (time-independent) Burgers equation discussed in Section 5.1 in the [paper](https://arxiv.org/pdf/2010.08895.pdf). 4 | """ 5 | 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.parameter import Parameter 12 | import matplotlib.pyplot as plt 13 | 14 | import operator 15 | from functools import reduce 16 | from functools import partial 17 | from timeit import default_timer 18 | from utilities3 import * 19 | 20 | torch.manual_seed(0) 21 | np.random.seed(0) 22 | 23 | #Complex multiplication 24 | def compl_mul1d(a, b): 25 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 26 | op = partial(torch.einsum, "bix,iox->box") 27 | return torch.stack([ 28 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 29 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 30 | ], dim=-1) 31 | 32 | ################################################################ 33 | # 1d fourier layer 34 | ################################################################ 35 | class SpectralConv1d(nn.Module): 36 | def __init__(self, in_channels, out_channels, modes1): 37 | 
super(SpectralConv1d, self).__init__() 38 | 39 | """ 40 | 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. 41 | """ 42 | 43 | self.in_channels = in_channels 44 | self.out_channels = out_channels 45 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 46 | 47 | self.scale = (1 / (in_channels*out_channels)) 48 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, 2)) 49 | 50 | def forward(self, x): 51 | batchsize = x.shape[0] 52 | #Compute Fourier coeffcients up to factor of e^(- something constant) 53 | x_ft = torch.rfft(x, 1, normalized=True, onesided=True) 54 | 55 | # Multiply relevant Fourier modes 56 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1)//2 + 1, 2, device=x.device) 57 | out_ft[:, :, :self.modes1] = compl_mul1d(x_ft[:, :, :self.modes1], self.weights1) 58 | 59 | #Return to physical space 60 | x = torch.irfft(out_ft, 1, normalized=True, onesided=True, signal_sizes=(x.size(-1), )) 61 | return x 62 | 63 | class SimpleBlock1d(nn.Module): 64 | def __init__(self, modes, width): 65 | super(SimpleBlock1d, self).__init__() 66 | 67 | """ 68 | The overall network. It contains 4 layers of the Fourier layer. 69 | 1. Lift the input to the desire channel dimension by self.fc0 . 70 | 2. 4 layers of the integral operators u' = (W + K)(u). 71 | W defined by self.w; K defined by self.conv . 72 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
73 | 74 | input: the solution of the initial condition and location (a(x), x) 75 | input shape: (batchsize, x=s, c=2) 76 | output: the solution of a later timestep 77 | output shape: (batchsize, x=s, c=1) 78 | """ 79 | 80 | self.modes1 = modes 81 | self.width = width 82 | self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x) 83 | 84 | self.conv0 = SpectralConv1d(self.width, self.width, self.modes1) 85 | self.conv1 = SpectralConv1d(self.width, self.width, self.modes1) 86 | self.conv2 = SpectralConv1d(self.width, self.width, self.modes1) 87 | self.conv3 = SpectralConv1d(self.width, self.width, self.modes1) 88 | self.w0 = nn.Conv1d(self.width, self.width, 1) 89 | self.w1 = nn.Conv1d(self.width, self.width, 1) 90 | self.w2 = nn.Conv1d(self.width, self.width, 1) 91 | self.w3 = nn.Conv1d(self.width, self.width, 1) 92 | 93 | 94 | self.fc1 = nn.Linear(self.width, 128) 95 | self.fc2 = nn.Linear(128, 1) 96 | 97 | def forward(self, x): 98 | 99 | x = self.fc0(x) 100 | x = x.permute(0, 2, 1) 101 | 102 | x1 = self.conv0(x) 103 | x2 = self.w0(x) 104 | x = x1 + x2 105 | x = F.relu(x) 106 | 107 | x1 = self.conv1(x) 108 | x2 = self.w1(x) 109 | x = x1 + x2 110 | x = F.relu(x) 111 | 112 | x1 = self.conv2(x) 113 | x2 = self.w2(x) 114 | x = x1 + x2 115 | x = F.relu(x) 116 | 117 | x1 = self.conv3(x) 118 | x2 = self.w3(x) 119 | x = x1 + x2 120 | 121 | x = x.permute(0, 2, 1) 122 | x = self.fc1(x) 123 | x = F.relu(x) 124 | x = self.fc2(x) 125 | return x 126 | 127 | class Net1d(nn.Module): 128 | def __init__(self, modes, width): 129 | super(Net1d, self).__init__() 130 | 131 | """ 132 | A wrapper function 133 | """ 134 | 135 | self.conv1 = SimpleBlock1d(modes, width) 136 | 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | return x.squeeze() 141 | 142 | def count_params(self): 143 | c = 0 144 | for p in self.parameters(): 145 | c += reduce(operator.mul, list(p.size())) 146 | 147 | return c 148 | 149 | 150 | 
################################################################ 151 | # configurations 152 | ################################################################ 153 | ntrain = 1000 154 | ntest = 100 155 | 156 | sub = 2**3 #subsampling rate 157 | h = 2**13 // sub #total grid size divided by the subsampling rate 158 | s = h 159 | 160 | batch_size = 20 161 | learning_rate = 0.001 162 | 163 | epochs = 500 164 | step_size = 100 165 | gamma = 0.5 166 | 167 | modes = 16 168 | width = 64 169 | 170 | 171 | ################################################################ 172 | # read data 173 | ################################################################ 174 | 175 | # Data is of the shape (number of samples, grid size) 176 | dataloader = MatReader('data/burgers_data_R10.mat') 177 | x_data = dataloader.read_field('a')[:,::sub] 178 | y_data = dataloader.read_field('u')[:,::sub] 179 | 180 | x_train = x_data[:ntrain,:] 181 | y_train = y_data[:ntrain,:] 182 | x_test = x_data[-ntest:,:] 183 | y_test = y_data[-ntest:,:] 184 | 185 | # cat the locations information 186 | grid = np.linspace(0, 2*np.pi, s).reshape(1, s, 1) 187 | grid = torch.tensor(grid, dtype=torch.float) 188 | x_train = torch.cat([x_train.reshape(ntrain,s,1), grid.repeat(ntrain,1,1)], dim=2) 189 | x_test = torch.cat([x_test.reshape(ntest,s,1), grid.repeat(ntest,1,1)], dim=2) 190 | 191 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 192 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 193 | 194 | # model 195 | model = Net1d(modes, width).cuda() 196 | print(model.count_params()) 197 | 198 | 199 | ################################################################ 200 | # training and evaluation 201 | ################################################################ 202 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 203 
| scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 204 | 205 | myloss = LpLoss(size_average=False) 206 | for ep in range(epochs): 207 | model.train() 208 | t1 = default_timer() 209 | train_mse = 0 210 | train_l2 = 0 211 | for x, y in train_loader: 212 | x, y = x.cuda(), y.cuda() 213 | 214 | optimizer.zero_grad() 215 | out = model(x) 216 | 217 | mse = F.mse_loss(out, y, reduction='mean') 218 | # mse.backward() 219 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 220 | l2.backward() # use the l2 relative loss 221 | 222 | optimizer.step() 223 | train_mse += mse.item() 224 | train_l2 += l2.item() 225 | 226 | scheduler.step() 227 | model.eval() 228 | test_l2 = 0.0 229 | with torch.no_grad(): 230 | for x, y in test_loader: 231 | x, y = x.cuda(), y.cuda() 232 | 233 | out = model(x) 234 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 235 | 236 | train_mse /= len(train_loader) 237 | train_l2 /= ntrain 238 | test_l2 /= ntest 239 | 240 | t2 = default_timer() 241 | print(ep, t2-t1, train_mse, train_l2, test_l2) 242 | 243 | # torch.save(model, 'model/ns_fourier_burgers_8192') 244 | pred = torch.zeros(y_test.shape) 245 | index = 0 246 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 247 | with torch.no_grad(): 248 | for x, y in test_loader: 249 | test_l2 = 0 250 | x, y = x.cuda(), y.cuda() 251 | 252 | out = model(x) 253 | pred[index] = out 254 | 255 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 256 | print(index, test_l2) 257 | index = index + 1 258 | 259 | # scipy.io.savemat('pred/burger_test.mat', mdict={'pred': pred.cpu().numpy()}) 260 | -------------------------------------------------------------------------------- /FNO-torch.1.6/fourier_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 2D 
problem such as the Darcy Flow discussed in Section 5.2 in the [paper](https://arxiv.org/pdf/2010.08895.pdf). 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn.parameter import Parameter 11 | 12 | import matplotlib.pyplot as plt 13 | 14 | import operator 15 | from functools import reduce 16 | from functools import partial 17 | 18 | from timeit import default_timer 19 | from utilities3 import * 20 | 21 | torch.manual_seed(0) 22 | np.random.seed(0) 23 | 24 | #Complex multiplication 25 | def compl_mul2d(a, b): 26 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 27 | op = partial(torch.einsum, "bixy,ioxy->boxy") 28 | return torch.stack([ 29 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 30 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 31 | ], dim=-1) 32 | 33 | 34 | ################################################################ 35 | # fourier layer 36 | ################################################################ 37 | class SpectralConv2d(nn.Module): 38 | def __init__(self, in_channels, out_channels, modes1, modes2): 39 | super(SpectralConv2d, self).__init__() 40 | 41 | """ 42 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
43 | """ 44 | 45 | self.in_channels = in_channels 46 | self.out_channels = out_channels 47 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 48 | self.modes2 = modes2 49 | 50 | self.scale = (1 / (in_channels * out_channels)) 51 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2)) 52 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2)) 53 | 54 | def forward(self, x): 55 | batchsize = x.shape[0] 56 | #Compute Fourier coeffcients up to factor of e^(- something constant) 57 | x_ft = torch.rfft(x, 2, normalized=True, onesided=True) 58 | 59 | # Multiply relevant Fourier modes 60 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, 2, device=x.device) 61 | out_ft[:, :, :self.modes1, :self.modes2] = \ 62 | compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) 63 | out_ft[:, :, -self.modes1:, :self.modes2] = \ 64 | compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 65 | 66 | #Return to physical space 67 | x = torch.irfft(out_ft, 2, normalized=True, onesided=True, signal_sizes=( x.size(-2), x.size(-1))) 68 | return x 69 | 70 | class SimpleBlock2d(nn.Module): 71 | def __init__(self, modes1, modes2, width): 72 | super(SimpleBlock2d, self).__init__() 73 | 74 | """ 75 | The overall network. It contains 4 layers of the Fourier layer. 76 | 1. Lift the input to the desire channel dimension by self.fc0 . 77 | 2. 4 layers of the integral operators u' = (W + K)(u). 78 | W defined by self.w; K defined by self.conv . 79 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
80 | 81 | input: the solution of the coefficient function and locations (a(x, y), x, y) 82 | input shape: (batchsize, x=s, y=s, c=3) 83 | output: the solution 84 | output shape: (batchsize, x=s, y=s, c=1) 85 | """ 86 | 87 | self.modes1 = modes1 88 | self.modes2 = modes2 89 | self.width = width 90 | self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y) 91 | 92 | self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 93 | self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 94 | self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 95 | self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 96 | self.w0 = nn.Conv1d(self.width, self.width, 1) 97 | self.w1 = nn.Conv1d(self.width, self.width, 1) 98 | self.w2 = nn.Conv1d(self.width, self.width, 1) 99 | self.w3 = nn.Conv1d(self.width, self.width, 1) 100 | 101 | 102 | self.fc1 = nn.Linear(self.width, 128) 103 | self.fc2 = nn.Linear(128, 1) 104 | 105 | def forward(self, x): 106 | batchsize = x.shape[0] 107 | size_x, size_y = x.shape[1], x.shape[2] 108 | 109 | x = self.fc0(x) 110 | x = x.permute(0, 3, 1, 2) 111 | 112 | x1 = self.conv0(x) 113 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 114 | x = x1 + x2 115 | x = F.relu(x) 116 | 117 | x1 = self.conv1(x) 118 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 119 | x = x1 + x2 120 | x = F.relu(x) 121 | 122 | x1 = self.conv2(x) 123 | x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 124 | x = x1 + x2 125 | x = F.relu(x) 126 | 127 | x1 = self.conv3(x) 128 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 129 | x = x1 + x2 130 | 131 | x = x.permute(0, 2, 3, 1) 132 | x = self.fc1(x) 133 | x = F.relu(x) 134 | x = self.fc2(x) 135 | return x 136 | 137 | class Net2d(nn.Module): 138 | def 
__init__(self, modes, width): 139 | super(Net2d, self).__init__() 140 | 141 | """ 142 | A wrapper function 143 | """ 144 | 145 | self.conv1 = SimpleBlock2d(modes, modes, width) 146 | 147 | 148 | def forward(self, x): 149 | x = self.conv1(x) 150 | return x.squeeze() 151 | 152 | 153 | def count_params(self): 154 | c = 0 155 | for p in self.parameters(): 156 | c += reduce(operator.mul, list(p.size())) 157 | 158 | return c 159 | 160 | ################################################################ 161 | # configs 162 | ################################################################ 163 | TRAIN_PATH = 'data/piececonst_r421_N1024_smooth1.mat' 164 | TEST_PATH = 'data/piececonst_r421_N1024_smooth2.mat' 165 | 166 | ntrain = 1000 167 | ntest = 100 168 | 169 | batch_size = 20 170 | learning_rate = 0.001 171 | 172 | epochs = 500 173 | step_size = 100 174 | gamma = 0.5 175 | 176 | modes = 12 177 | width = 32 178 | 179 | r = 5 180 | h = int(((421 - 1)/r) + 1) 181 | s = h 182 | 183 | ################################################################ 184 | # load data and data normalization 185 | ################################################################ 186 | reader = MatReader(TRAIN_PATH) 187 | x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s] 188 | y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s] 189 | 190 | reader.load_file(TEST_PATH) 191 | x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s] 192 | y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s] 193 | 194 | x_normalizer = UnitGaussianNormalizer(x_train) 195 | x_train = x_normalizer.encode(x_train) 196 | x_test = x_normalizer.encode(x_test) 197 | 198 | y_normalizer = UnitGaussianNormalizer(y_train) 199 | y_train = y_normalizer.encode(y_train) 200 | 201 | grids = [] 202 | grids.append(np.linspace(0, 1, s)) 203 | grids.append(np.linspace(0, 1, s)) 204 | grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T 205 | grid = grid.reshape(1,s,s,2) 206 | grid = torch.tensor(grid, 
dtype=torch.float) 207 | x_train = torch.cat([x_train.reshape(ntrain,s,s,1), grid.repeat(ntrain,1,1,1)], dim=3) 208 | x_test = torch.cat([x_test.reshape(ntest,s,s,1), grid.repeat(ntest,1,1,1)], dim=3) 209 | 210 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 211 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 212 | 213 | ################################################################ 214 | # training and evaluation 215 | ################################################################ 216 | model = Net2d(modes, width).cuda() 217 | print(model.count_params()) 218 | 219 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 220 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 221 | 222 | myloss = LpLoss(size_average=False) 223 | y_normalizer.cuda() 224 | for ep in range(epochs): 225 | model.train() 226 | t1 = default_timer() 227 | train_mse = 0 228 | for x, y in train_loader: 229 | x, y = x.cuda(), y.cuda() 230 | 231 | optimizer.zero_grad() 232 | # loss = F.mse_loss(model(x).view(-1), y.view(-1), reduction='mean') 233 | out = model(x) 234 | out = y_normalizer.decode(out) 235 | y = y_normalizer.decode(y) 236 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 237 | loss.backward() 238 | 239 | 240 | optimizer.step() 241 | train_mse += loss.item() 242 | 243 | scheduler.step() 244 | 245 | model.eval() 246 | abs_err = 0.0 247 | rel_err = 0.0 248 | with torch.no_grad(): 249 | for x, y in test_loader: 250 | x, y = x.cuda(), y.cuda() 251 | 252 | out = model(x) 253 | out = y_normalizer.decode(model(x)) 254 | 255 | rel_err += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 256 | 257 | train_mse/= ntrain 258 | abs_err /= ntest 259 | rel_err /= ntest 260 | 261 | t2 = default_timer() 262 | print(ep, t2-t1, train_mse, 
rel_err) 263 | -------------------------------------------------------------------------------- /FNO-torch.1.6/fourier_2d_time.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 2D problem such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf), 4 | which uses a recurrent structure to propagates in time. 5 | """ 6 | 7 | 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | import matplotlib.pyplot as plt 14 | from utilities3 import * 15 | 16 | import operator 17 | from functools import reduce 18 | from functools import partial 19 | 20 | from timeit import default_timer 21 | import scipy.io 22 | 23 | torch.manual_seed(0) 24 | np.random.seed(0) 25 | 26 | #Complex multiplication 27 | def compl_mul2d(a, b): 28 | op = partial(torch.einsum, "bctq,dctq->bdtq") 29 | return torch.stack([ 30 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 31 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 32 | ], dim=-1) 33 | 34 | ################################################################ 35 | # fourier layer 36 | ################################################################ 37 | 38 | class SpectralConv2d_fast(nn.Module): 39 | def __init__(self, in_channels, out_channels, modes1, modes2): 40 | super(SpectralConv2d_fast, self).__init__() 41 | 42 | """ 43 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
44 | """ 45 | 46 | self.in_channels = in_channels 47 | self.out_channels = out_channels 48 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 49 | self.modes2 = modes2 50 | 51 | self.scale = (1 / (in_channels * out_channels)) 52 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2)) 53 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2)) 54 | 55 | def forward(self, x): 56 | batchsize = x.shape[0] 57 | #Compute Fourier coeffcients up to factor of e^(- something constant) 58 | x_ft = torch.rfft(x, 2, normalized=True, onesided=True) 59 | 60 | # Multiply relevant Fourier modes 61 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, 2, device=x.device) 62 | out_ft[:, :, :self.modes1, :self.modes2] = \ 63 | compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) 64 | out_ft[:, :, -self.modes1:, :self.modes2] = \ 65 | compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 66 | 67 | #Return to physical space 68 | x = torch.irfft(out_ft, 2, normalized=True, onesided=True, signal_sizes=(x.size(-2), x.size(-1))) 69 | return x 70 | 71 | class SimpleBlock2d(nn.Module): 72 | def __init__(self, modes1, modes2, width): 73 | super(SimpleBlock2d, self).__init__() 74 | 75 | """ 76 | The overall network. It contains 4 layers of the Fourier layer. 77 | 1. Lift the input to the desire channel dimension by self.fc0 . 78 | 2. 4 layers of the integral operators u' = (W + K)(u). 79 | W defined by self.w; K defined by self.conv . 80 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
81 | 82 | input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) 83 | input shape: (batchsize, x=64, y=64, c=12) 84 | output: the solution of the next timestep 85 | output shape: (batchsize, x=64, y=64, c=1) 86 | """ 87 | 88 | self.modes1 = modes1 89 | self.modes2 = modes2 90 | self.width = width 91 | self.fc0 = nn.Linear(12, self.width) 92 | # input channel is 12: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) 93 | 94 | self.conv0 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2) 95 | self.conv1 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2) 96 | self.conv2 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2) 97 | self.conv3 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2) 98 | self.w0 = nn.Conv1d(self.width, self.width, 1) 99 | self.w1 = nn.Conv1d(self.width, self.width, 1) 100 | self.w2 = nn.Conv1d(self.width, self.width, 1) 101 | self.w3 = nn.Conv1d(self.width, self.width, 1) 102 | self.bn0 = torch.nn.BatchNorm2d(self.width) 103 | self.bn1 = torch.nn.BatchNorm2d(self.width) 104 | self.bn2 = torch.nn.BatchNorm2d(self.width) 105 | self.bn3 = torch.nn.BatchNorm2d(self.width) 106 | 107 | 108 | self.fc1 = nn.Linear(self.width, 128) 109 | self.fc2 = nn.Linear(128, 1) 110 | 111 | def forward(self, x): 112 | batchsize = x.shape[0] 113 | size_x, size_y = x.shape[1], x.shape[2] 114 | 115 | x = self.fc0(x) 116 | x = x.permute(0, 3, 1, 2) 117 | 118 | x1 = self.conv0(x) 119 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 120 | x = self.bn0(x1 + x2) 121 | x = F.relu(x) 122 | x1 = self.conv1(x) 123 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 124 | x = self.bn1(x1 + x2) 125 | x = F.relu(x) 126 | x1 = self.conv2(x) 127 | x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, 
size_x, size_y) 128 | x = self.bn2(x1 + x2) 129 | x = F.relu(x) 130 | x1 = self.conv3(x) 131 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 132 | x = self.bn3(x1 + x2) 133 | 134 | 135 | x = x.permute(0, 2, 3, 1) 136 | x = self.fc1(x) 137 | x = F.relu(x) 138 | x = self.fc2(x) 139 | return x 140 | 141 | class Net2d(nn.Module): 142 | def __init__(self, modes, width): 143 | super(Net2d, self).__init__() 144 | 145 | """ 146 | A wrapper function 147 | """ 148 | 149 | self.conv1 = SimpleBlock2d(modes, modes, width) 150 | 151 | 152 | def forward(self, x): 153 | x = self.conv1(x) 154 | return x 155 | 156 | 157 | def count_params(self): 158 | c = 0 159 | for p in self.parameters(): 160 | c += reduce(operator.mul, list(p.size())) 161 | 162 | return c 163 | 164 | 165 | ################################################################ 166 | # configs 167 | ################################################################ 168 | TRAIN_PATH = 'data/ns_data_V10000_N1200_T20.mat' 169 | TEST_PATH = 'data/ns_data_V10000_N1200_T20.mat' 170 | 171 | ntrain = 1000 172 | ntest = 200 173 | 174 | modes = 12 175 | width = 20 176 | 177 | batch_size = 20 178 | batch_size2 = batch_size 179 | 180 | epochs = 500 181 | learning_rate = 0.0025 182 | scheduler_step = 100 183 | scheduler_gamma = 0.5 184 | 185 | print(epochs, learning_rate, scheduler_step, scheduler_gamma) 186 | 187 | path = 'ns_fourier_2d_rnn_V10000_T20_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width) 188 | path_model = 'model/'+path 189 | path_train_err = 'results/'+path+'train.txt' 190 | path_test_err = 'results/'+path+'test.txt' 191 | path_image = 'image/'+path 192 | 193 | runtime = np.zeros(2, ) 194 | t1 = default_timer() 195 | 196 | sub = 1 197 | S = 64 198 | T_in = 10 199 | T = 10 200 | step = 1 201 | 202 | ################################################################ 203 | # load data 204 | ################################################################ 
205 | 206 | reader = MatReader(TRAIN_PATH) 207 | train_a = reader.read_field('u')[:ntrain,::sub,::sub,:T_in] 208 | train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in] 209 | 210 | reader = MatReader(TEST_PATH) 211 | test_a = reader.read_field('u')[-ntest:,::sub,::sub,:T_in] 212 | test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in] 213 | 214 | print(train_u.shape) 215 | print(test_u.shape) 216 | assert (S == train_u.shape[-2]) 217 | assert (T == train_u.shape[-1]) 218 | 219 | train_a = train_a.reshape(ntrain,S,S,T_in) 220 | test_a = test_a.reshape(ntest,S,S,T_in) 221 | 222 | # append the (x, y) grid coordinates as two extra input channels after the T_in input time steps 223 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 224 | gridx = gridx.reshape(1, S, 1, 1).repeat([1, 1, S, 1]) 225 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 226 | gridy = gridy.reshape(1, 1, S, 1).repeat([1, S, 1, 1]) 227 | 228 | train_a = torch.cat((train_a, gridx.repeat([ntrain,1,1,1]), gridy.repeat([ntrain,1,1,1])), dim=-1) 229 | test_a = torch.cat((test_a, gridx.repeat([ntest,1,1,1]), gridy.repeat([ntest,1,1,1])), dim=-1) 230 | 231 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 232 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 233 | 234 | t2 = default_timer() 235 | 236 | print('preprocessing finished, time used:', t2-t1) 237 | device = torch.device('cuda') 238 | 239 | ################################################################ 240 | # training and evaluation 241 | ################################################################ 242 | 243 | model = Net2d(modes, width).cuda() 244 | # model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20') 245 | 246 | print(model.count_params()) 247 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 248 | scheduler =
torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 249 | 250 | 251 | myloss = LpLoss(size_average=False) 252 | gridx = gridx.to(device) 253 | gridy = gridy.to(device) 254 | 255 | for ep in range(epochs): 256 | model.train() 257 | t1 = default_timer() 258 | train_l2_step = 0 259 | train_l2_full = 0 260 | for xx, yy in train_loader: 261 | loss = 0 262 | xx = xx.to(device) 263 | yy = yy.to(device) 264 | 265 | for t in range(0, T, step): 266 | y = yy[..., t:t + step] 267 | im = model(xx) 268 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 269 | 270 | if t == 0: 271 | pred = im 272 | else: 273 | pred = torch.cat((pred, im), -1) 274 | 275 | xx = torch.cat((xx[..., step:-2], im, 276 | gridx.repeat([batch_size, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1])), dim=-1) 277 | 278 | train_l2_step += loss.item() 279 | l2_full = myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)) 280 | train_l2_full += l2_full.item() 281 | 282 | optimizer.zero_grad() 283 | loss.backward() 284 | # l2_full.backward() 285 | optimizer.step() 286 | 287 | test_l2_step = 0 288 | test_l2_full = 0 289 | with torch.no_grad(): 290 | for xx, yy in test_loader: 291 | loss = 0 292 | xx = xx.to(device) 293 | yy = yy.to(device) 294 | 295 | for t in range(0, T, step): 296 | y = yy[..., t:t + step] 297 | im = model(xx) 298 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 299 | 300 | if t == 0: 301 | pred = im 302 | else: 303 | pred = torch.cat((pred, im), -1) 304 | 305 | xx = torch.cat((xx[..., step:-2], im, 306 | gridx.repeat([batch_size, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1])), dim=-1) 307 | 308 | 309 | test_l2_step += loss.item() 310 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 311 | 312 | t2 = default_timer() 313 | scheduler.step() 314 | print(ep, t2 - t1, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step), 315 |
test_l2_full / ntest) 316 | # torch.save(model, path_model) 317 | 318 | 319 | # pred = torch.zeros(test_u.shape) 320 | # index = 0 321 | # test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 322 | # with torch.no_grad(): 323 | # for x, y in test_loader: 324 | # test_l2 = 0; 325 | # x, y = x.cuda(), y.cuda() 326 | # 327 | # out = model(x) 328 | # out = y_normalizer.decode(out) 329 | # pred[index] = out 330 | # 331 | # test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 332 | # print(index, test_l2) 333 | # index = index + 1 334 | 335 | # scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()}) 336 | 337 | 338 | 339 | 340 | -------------------------------------------------------------------------------- /FNO-torch.1.6/fourier_3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 3D problem such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf), 4 | which takes the 2D spatial + 1D temporal equation directly as a 3D problem 5 | """ 6 | 7 | 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | import matplotlib.pyplot as plt 14 | from utilities3 import * 15 | 16 | import operator 17 | from functools import reduce 18 | from functools import partial 19 | 20 | from timeit import default_timer 21 | import scipy.io 22 | 23 | torch.manual_seed(0) 24 | np.random.seed(0) 25 | 26 | #Complex multiplication of tensors whose last dim holds (real, imag) 27 | def compl_mul3d(a, b): 28 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 29 | op = partial(torch.einsum, "bixyz,ioxyz->boxyz") 30 | return torch.stack([ 31 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 32 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 33 | ], dim=-1) 34 | 35 |
################################################################ 36 | # 3d fourier layers 37 | ################################################################ 38 | 39 | class SpectralConv3d_fast(nn.Module): 40 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3): 41 | super(SpectralConv3d_fast, self).__init__() 42 | 43 | """ 44 | 3D Fourier layer. It does FFT, linear transform, and Inverse FFT. 45 | """ 46 | 47 | self.in_channels = in_channels 48 | self.out_channels = out_channels 49 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 50 | self.modes2 = modes2 51 | self.modes3 = modes3 52 | 53 | self.scale = (1 / (in_channels * out_channels)) 54 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 55 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 56 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 57 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 58 | 59 | def forward(self, x): 60 | batchsize = x.shape[0] 61 | #Compute Fourier coefficients up to factor of e^(- something constant) 62 | x_ft = torch.rfft(x, 3, normalized=True, onesided=True) 63 | 64 | # Multiply relevant Fourier modes 65 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, 2, device=x.device) 66 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \ 67 | compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1) 68 | out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \ 69 | compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2) 70 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \ 71 | compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:,
:self.modes3], self.weights3) 72 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \ 73 | compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4) 74 | 75 | #Return to physical space 76 | x = torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=(x.size(-3), x.size(-2), x.size(-1))) 77 | return x 78 | 79 | class SimpleBlock3d(nn.Module): 80 | def __init__(self, modes1, modes2, modes3, width): 81 | super(SimpleBlock3d, self).__init__() 82 | 83 | """ 84 | The overall network. It contains 4 layers of the Fourier layer. 85 | 1. Lift the input to the desired channel dimension by self.fc0 . 86 | 2. 4 layers of the integral operators u' = (W + K)(u). 87 | W defined by self.w; K defined by self.conv . 88 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 89 | 90 | input: 3 locations + the solution of the first 10 timesteps (x, y, t, u(1, x, y), ..., u(10, x, y)). It's a constant function in time, except for the t coordinate.
91 | input shape: (batchsize, x=64, y=64, t=40, c=13) 92 | output: the solution of the next 40 timesteps 93 | output shape: (batchsize, x=64, y=64, t=40, c=1) 94 | """ 95 | 96 | self.modes1 = modes1 97 | self.modes2 = modes2 98 | self.modes3 = modes3 99 | self.width = width 100 | self.fc0 = nn.Linear(13, self.width) 101 | # input channel is 13: 3 locations + the solution of the first 10 timesteps (x, y, t, u(1, x, y), ..., u(10, x, y)) 102 | 103 | 104 | self.conv0 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 105 | self.conv1 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 106 | self.conv2 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 107 | self.conv3 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 108 | self.w0 = nn.Conv1d(self.width, self.width, 1) 109 | self.w1 = nn.Conv1d(self.width, self.width, 1) 110 | self.w2 = nn.Conv1d(self.width, self.width, 1) 111 | self.w3 = nn.Conv1d(self.width, self.width, 1) 112 | self.bn0 = torch.nn.BatchNorm3d(self.width) 113 | self.bn1 = torch.nn.BatchNorm3d(self.width) 114 | self.bn2 = torch.nn.BatchNorm3d(self.width) 115 | self.bn3 = torch.nn.BatchNorm3d(self.width) 116 | 117 | 118 | self.fc1 = nn.Linear(self.width, 128) 119 | self.fc2 = nn.Linear(128, 1) 120 | 121 | def forward(self, x): 122 | batchsize = x.shape[0] 123 | size_x, size_y, size_z = x.shape[1], x.shape[2], x.shape[3] 124 | 125 | x = self.fc0(x) 126 | x = x.permute(0, 4, 1, 2, 3) 127 | 128 | x1 = self.conv0(x) 129 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 130 | x = self.bn0(x1 + x2) 131 | x = F.relu(x) 132 | x1 = self.conv1(x) 133 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 134 | x = self.bn1(x1 + x2) 135 | x = F.relu(x) 136 | x1 = self.conv2(x) 137 | x2 = self.w2(x.view(batchsize, self.width,
-1)).view(batchsize, self.width, size_x, size_y, size_z) 138 | x = self.bn2(x1 + x2) 139 | x = F.relu(x) 140 | x1 = self.conv3(x) 141 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 142 | x = self.bn3(x1 + x2) 143 | 144 | 145 | x = x.permute(0, 2, 3, 4, 1) 146 | x = self.fc1(x) 147 | x = F.relu(x) 148 | x = self.fc2(x) 149 | return x 150 | 151 | class Net3d(nn.Module): 152 | def __init__(self, modes, width): 153 | super(Net3d, self).__init__() 154 | 155 | """ 156 | A wrapper around SimpleBlock3d; forward() squeezes every size-1 dim of the output (including the batch dim when batchsize == 1). 157 | """ 158 | 159 | self.conv1 = SimpleBlock3d(modes, modes, modes, width) 160 | 161 | 162 | def forward(self, x): 163 | x = self.conv1(x) 164 | return x.squeeze() 165 | 166 | 167 | def count_params(self): 168 | c = 0 169 | for p in self.parameters(): 170 | c += reduce(operator.mul, list(p.size())) 171 | 172 | return c 173 | 174 | ################################################################ 175 | # configs 176 | ################################################################ 177 | 178 | # TRAIN_PATH = 'data/ns_data_V1000_N1000_train.mat' 179 | # TEST_PATH = 'data/ns_data_V1000_N1000_train_2.mat' 180 | # TRAIN_PATH = 'data/ns_data_V1000_N5000.mat' 181 | # TEST_PATH = 'data/ns_data_V1000_N5000.mat' 182 | TRAIN_PATH = 'data/ns_data_V100_N1000_T50_1.mat' 183 | TEST_PATH = 'data/ns_data_V100_N1000_T50_2.mat' 184 | 185 | ntrain = 1000 186 | ntest = 200 187 | 188 | modes = 4 189 | width = 20 190 | 191 | batch_size = 10 192 | batch_size2 = batch_size 193 | 194 | epochs = 10 195 | learning_rate = 0.0025 196 | scheduler_step = 100 197 | scheduler_gamma = 0.5 198 | 199 | print(epochs, learning_rate, scheduler_step, scheduler_gamma) 200 | 201 | path = 'test' 202 | # path = 'ns_fourier_V100_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width) 203 | path_model = 'model/'+path 204 | path_train_err = 'results/'+path+'train.txt' 205 | path_test_err = 'results/'+path+'test.txt' 206 | path_image = 'image/'+path 207 | 208 |
209 | runtime = np.zeros(2, ) 210 | t1 = default_timer() 211 | 212 | 213 | sub = 1 214 | S = 64 // sub 215 | T_in = 10 216 | T = 40 217 | 218 | ################################################################ 219 | # load data 220 | ################################################################ 221 | 222 | reader = MatReader(TRAIN_PATH) 223 | train_a = reader.read_field('u')[:ntrain,::sub,::sub,:T_in] 224 | train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in] 225 | 226 | reader = MatReader(TEST_PATH) 227 | test_a = reader.read_field('u')[-ntest:,::sub,::sub,:T_in] 228 | test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in] 229 | 230 | print(train_u.shape) 231 | print(test_u.shape) 232 | assert (S == train_u.shape[-2]) 233 | assert (T == train_u.shape[-1]) 234 | 235 | 236 | a_normalizer = UnitGaussianNormalizer(train_a) 237 | train_a = a_normalizer.encode(train_a) 238 | test_a = a_normalizer.encode(test_a) 239 | 240 | y_normalizer = UnitGaussianNormalizer(train_u) 241 | train_u = y_normalizer.encode(train_u) 242 | 243 | train_a = train_a.reshape(ntrain,S,S,1,T_in).repeat([1,1,1,T,1]) 244 | test_a = test_a.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1]) 245 | 246 | # prepend the (x, y, t) grid coordinates as 3 extra input channels; final channel order is (x, y, t, u(1), ..., u(T_in)) 247 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 248 | gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1]) 249 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 250 | gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1]) 251 | gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float) 252 | gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1]) 253 | 254 | train_a = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]), 255 | gridt.repeat([ntrain,1,1,1,1]), train_a), dim=-1) 256 | test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]), 257 | gridt.repeat([ntest,1,1,1,1]), test_a), dim=-1) 258 | 259 | train_loader =
torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 260 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 261 | 262 | t2 = default_timer() 263 | 264 | print('preprocessing finished, time used:', t2-t1) 265 | device = torch.device('cuda') 266 | 267 | ################################################################ 268 | # training and evaluation 269 | ################################################################ 270 | model = Net3d(modes, width).cuda() 271 | # model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20') 272 | 273 | print(model.count_params()) 274 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 275 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 276 | 277 | 278 | myloss = LpLoss(size_average=False) 279 | y_normalizer.cuda() 280 | for ep in range(epochs): 281 | model.train() 282 | t1 = default_timer() 283 | train_mse = 0 284 | train_l2 = 0 285 | for x, y in train_loader: 286 | x, y = x.cuda(), y.cuda() 287 | 288 | optimizer.zero_grad() 289 | out = model(x) 290 | 291 | mse = F.mse_loss(out, y, reduction='mean') 292 | # mse.backward() (unused alternative; training below backpropagates the relative L2 loss of the decoded outputs) 293 | 294 | y = y_normalizer.decode(y) 295 | out = y_normalizer.decode(out) 296 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 297 | l2.backward() 298 | 299 | optimizer.step() 300 | train_mse += mse.item() 301 | train_l2 += l2.item() 302 | 303 | scheduler.step() 304 | 305 | model.eval() 306 | test_l2 = 0.0 307 | with torch.no_grad(): 308 | for x, y in test_loader: 309 | x, y = x.cuda(), y.cuda() 310 | 311 | out = model(x) 312 | out = y_normalizer.decode(out) 313 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 314 | 315 | train_mse /= len(train_loader) 316 | train_l2 /= ntrain 317 | test_l2 /= ntest 318 | 319 | t2 = default_timer() 320 |
print(ep, t2-t1, train_mse, train_l2, test_l2) 321 | # torch.save(model, path_model) 322 | 323 | 324 | pred = torch.zeros(test_u.shape) 325 | index = 0 326 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 327 | with torch.no_grad(): 328 | for x, y in test_loader: 329 | test_l2 = 0 330 | x, y = x.cuda(), y.cuda() 331 | 332 | out = model(x) 333 | out = y_normalizer.decode(out) 334 | pred[index] = out 335 | 336 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 337 | print(index, test_l2) 338 | index = index + 1 339 | 340 | scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()}) 341 | 342 | 343 | 344 | 345 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Zongyi Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fourier Neural Operator 2 | 3 | This repository contains the code for the paper: 4 | - [(FNO) Fourier Neural Operator for Parametric Partial Differential Equations](https://arxiv.org/abs/2010.08895) 5 | 6 | In this work, we formulate a new neural operator by parameterizing the integral kernel directly in Fourier space, allowing for an expressive and efficient architecture. 7 | We perform experiments on Burgers' equation, Darcy flow, and the Navier-Stokes equation (including the turbulent regime). 8 | Our Fourier neural operator shows state-of-the-art performance compared to existing neural network methodologies and it is up to three orders of magnitude faster compared to traditional PDE solvers. 
9 | 10 | It builds on the previous works: 11 | - [(GKN) Neural Operator: Graph Kernel Network for Partial Differential Equations](https://arxiv.org/abs/2003.03485) 12 | - [(MGKN) Multipole Graph Neural Operator for Parametric Partial Differential Equations](https://arxiv.org/abs/2006.09535) 13 | 14 | 15 | Follow-ups: 16 | - [(PINO) Physics-Informed Neural Operator for Learning Partial Differential Equations](https://arxiv.org/pdf/2111.03794.pdf) 17 | - [(Geo-FNO) Fourier Neural Operator with Learned Deformations for PDEs on General Geometries](https://arxiv.org/pdf/2207.05209.pdf) 18 | 19 | Examples of applications: 20 | - [Weather Forecast](https://arxiv.org/pdf/2202.11214.pdf) 21 | - [Carbon capture and storage](https://arxiv.org/pdf/2210.17051.pdf) 22 | 23 | ## Requirements 24 | - We have updated the files to support [PyTorch 1.8.0](https://pytorch.org/). 25 | PyTorch 1.8.0 adds support for complex numbers and a new FFT implementation. 26 | As a result, the code is about 30% faster. 27 | - The previous version for [PyTorch 1.6.0](https://pytorch.org/) is available in `FNO-torch.1.6`. 28 | 29 | ## Major Updates: 30 | - Dec 2022: Add an MLP per layer. Add InstanceNorm layers for fourier_2d_time. Add Cosine Annealing scheduler. 31 | - Aug 2021: use GELU instead of ReLU. 32 | - Jan 2021: remove unnecessary BatchNorm layers. 33 | 34 | ## Files 35 | The code is in the form of simple scripts. Each script is stand-alone and directly runnable. 36 | 37 | - `fourier_1d.py` is the Fourier Neural Operator for 1D problems such as the (time-independent) Burgers equation discussed in Section 5.1 in the [paper](https://arxiv.org/pdf/2010.08895.pdf). 38 | The neural operator maps the solution function from time 0 to time 1. 39 | - `fourier_2d.py` is the Fourier Neural Operator for 2D problems such as the Darcy flow discussed in Section 5.2 in the [paper](https://arxiv.org/pdf/2010.08895.pdf).
40 | The neural operator maps from the coefficient function to the solution function. 41 | - `fourier_2d_time.py` is the Fourier Neural Operator for 2D problems such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf), 42 | which uses a recurrent structure to propagate in time. The neural operator maps the solution function from time `[t-10:t]` to time `t+1`. 43 | - `fourier_3d.py` is the Fourier Neural Operator for 3D problems such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf), 44 | which takes the 2D spatial + 1D temporal equation directly as a 3D problem. The neural operator maps the solution function from time `[1:10]` to time `[11:T]`. 45 | - The low-rank methods are similar: the scripts in `lowrank_operators` are the low-rank neural operators for the corresponding settings. 46 | - `data_generation` contains the conventional solvers we used to generate the datasets for the Burgers equation, Darcy flow, and Navier-Stokes equation. 47 | 48 | ## Datasets 49 | We provide the Burgers equation, Darcy flow, and Navier-Stokes equation datasets we used in the paper. 50 | The data generation configuration can be found in the paper. 51 | - [PDE datasets](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-?usp=sharing) 52 | 53 | The datasets are given as MATLAB (`.mat`) files. They can be loaded with the scripts provided in `utilities3.py`. 54 | Each data file is loaded as a tensor. The first index is the sample index; the remaining indices are the discretization. 55 | For example, 56 | - `Burgers_R10.mat` contains the dataset for the Burgers equation. It is of the shape `[1000, 8192]`, 57 | meaning it has `1000` training samples on a grid of `8192`. 58 | - `NavierStokes_V1e-3_N5000_T50.mat` contains the dataset for the 2D Navier-Stokes equation.
It is of the shape `[5000, 64, 64, 50]`, 59 | meaning it has `5000` training samples on a grid of `(64, 64)` with `50` time steps. 60 | 61 | We also provide the data generation scripts in `data_generation`. 62 | 63 | ## Models 64 | Here are the pre-trained models. They can be evaluated using _eval.py_ or _super_resolution.py_. 65 | - [models](https://drive.google.com/drive/folders/1swLA6yKR1f3PKdYSKhLqK4zfNjS9pt_U?usp=sharing) 66 | 67 | ## Citations 68 | 69 | ``` 70 | @misc{li2020fourier, 71 | title={Fourier Neural Operator for Parametric Partial Differential Equations}, 72 | author={Zongyi Li and Nikola Kovachki and Kamyar Azizzadenesheli and Burigede Liu and Kaushik Bhattacharya and Andrew Stuart and Anima Anandkumar}, 73 | year={2020}, 74 | eprint={2010.08895}, 75 | archivePrefix={arXiv}, 76 | primaryClass={cs.LG} 77 | } 78 | 79 | @misc{li2020neural, 80 | title={Neural Operator: Graph Kernel Network for Partial Differential Equations}, 81 | author={Zongyi Li and Nikola Kovachki and Kamyar Azizzadenesheli and Burigede Liu and Kaushik Bhattacharya and Andrew Stuart and Anima Anandkumar}, 82 | year={2020}, 83 | eprint={2003.03485}, 84 | archivePrefix={arXiv}, 85 | primaryClass={cs.LG} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /data_generation/burgers/GRF1.m: -------------------------------------------------------------------------------- 1 | %Random function from N(m, C) on [0 1] where 2 | %C = sigma^2(-Delta + tau^2 I)^(-gamma) 3 | %with periodic, zero Dirichlet, or zero Neumann boundary conditions. 4 | %Dirichlet only supports m = 0. 5 | %N is the # of Fourier modes, usually, grid size / 2.
6 | function u = GRF1(N, m, gamma, tau, sigma, type) 7 | 8 | if type == "dirichlet" 9 | m = 0; 10 | end 11 | 12 | if type == "periodic" 13 | my_const = 2*pi; 14 | else 15 | my_const = pi; 16 | end 17 | 18 | my_eigs = sqrt(2)*(abs(sigma).*((my_const.*(1:N)').^2 + tau^2).^(-gamma/2)); 19 | 20 | if type == "dirichlet" 21 | alpha = zeros(N,1); 22 | else 23 | xi_alpha = randn(N,1); 24 | alpha = my_eigs.*xi_alpha; 25 | end 26 | 27 | if type == "neumann" 28 | beta = zeros(N,1); 29 | else 30 | xi_beta = randn(N,1); 31 | beta = my_eigs.*xi_beta; 32 | end 33 | 34 | a = alpha/2; 35 | b = -beta/2; 36 | 37 | c = [flipud(a) - flipud(b).*1i;m + 0*1i;a + b.*1i]; 38 | 39 | if type == "periodic" 40 | uu = chebfun(c, [0 1], 'trig', 'coeffs'); 41 | u = chebfun(@(t) uu(t - 0.5), [0 1], 'trig'); 42 | else 43 | uu = chebfun(c, [-pi pi], 'trig', 'coeffs'); 44 | u = chebfun(@(t) uu(pi*t), [0 1]); 45 | end -------------------------------------------------------------------------------- /data_generation/burgers/burgers1.m: -------------------------------------------------------------------------------- 1 | function u = burgers1(init, tspan, s, visc) 2 | 3 | S = spinop([0 1], tspan); 4 | dt = tspan(2) - tspan(1); 5 | S.lin = @(u) + visc*diff(u,2); 6 | S.nonlin = @(u) - 0.5*diff(u.^2); 7 | S.init = init; 8 | u = spin(S,s,dt,'plot','off'); 9 | 10 | -------------------------------------------------------------------------------- /data_generation/burgers/gen_burgers1.m: -------------------------------------------------------------------------------- 1 | % number of realizations to generate 2 | N = 1; 3 | 4 | % parameters for the Gaussian random field 5 | gamma = 2.5; 6 | tau = 7; 7 | sigma = 7^(2); 8 | 9 | % viscosity 10 | visc = 1/1000; 11 | 12 | % grid size 13 | s = 1024; 14 | steps = 200; 15 | 16 | 17 | input = zeros(N, s); 18 | if steps == 1 19 | output = zeros(N, s); 20 | else 21 | output = zeros(N, steps, s); 22 | end 23 | 24 | tspan = linspace(0,1,steps+1); 25 | x = linspace(0,1,s+1); 26 |
for j=1:N 27 | u0 = GRF1(s/2, 0, gamma, tau, sigma, "periodic"); 28 | u = burgers1(u0, tspan, s, visc); 29 | 30 | u0eval = u0(x); 31 | input(j,:) = u0eval(1:end-1); 32 | 33 | if steps == 1 34 | output(j,:) = u.values; 35 | else 36 | for k=2:(steps+1) 37 | output(j,k-1,:) = u{k}.values; % k runs 2..steps+1; shift by one so slices 1..steps of the preallocated output are filled 38 | end 39 | end 40 | 41 | disp(j); 42 | end 43 | -------------------------------------------------------------------------------- /data_generation/darcy/GRF.m: -------------------------------------------------------------------------------- 1 | % Return a sample of a Gaussian random field on [0,1]^2 with: 2 | % mean 0 3 | % covariance operator C = tau^(2*alpha-2)*(-Delta + tau^2)^(-alpha) 4 | % where Delta is the Laplacian with zero Neumann boundary conditions. 5 | 6 | 7 | function U = GRF(alpha,tau,s) 8 | 9 | % Random variables in KL expansion 10 | xi = normrnd(0,1,s); 11 | 12 | % Define the (square root of) eigenvalues of the covariance operator 13 | [K1,K2] = meshgrid(0:s-1,0:s-1); 14 | %coef = (pi^2*(K1.^2+K2.^2) + tau^2).^(-alpha/2); 15 | coef = tau^(alpha-1).*(pi^2*(K1.^2+K2.^2) + tau^2).^(-alpha/2); 16 | %coef = (pi^2*(K1.^2+K2.^2)).^(-alpha/2); 17 | % Construct the KL coefficients 18 | L = s*coef.*xi; 19 | L(1,1) = 0; 20 | 21 | U = idct2(L); 22 | 23 | end 24 | -------------------------------------------------------------------------------- /data_generation/darcy/demo.m: -------------------------------------------------------------------------------- 1 | %Number of grid points on [0,1]^2 2 | %i.e.
uniform mesh with step h=1/(s-1) 3 | s = 256; 4 | 5 | %Create mesh (only needed for plotting) 6 | [X,Y] = meshgrid(0:(1/(s-1)):1); 7 | 8 | %Parameters of covariance C = tau^(2*alpha-2)*(-Laplacian + tau^2 I)^(-alpha) 9 | %Note that we need alpha > d/2 (here d = 2) 10 | %Laplacian has zero Neumann boundary 11 | %alpha and tau control smoothness; the bigger they are, the smoother the 12 | %function 13 | alpha = 2; 14 | tau = 3; 15 | 16 | %Generate random coefficients from N(0,C) 17 | norm_a = GRF(alpha, tau, s); 18 | 19 | %Exponentiate it, so that a(x) > 0 20 | %Now a ~ Lognormal(0, C) 21 | %This is done so that the PDE is elliptic 22 | lognorm_a = exp(norm_a); 23 | 24 | %Another way to achieve ellipticity is to threshold the coefficients 25 | thresh_a = zeros(s,s); 26 | thresh_a(norm_a >= 0) = 12; 27 | thresh_a(norm_a < 0) = 4; 28 | 29 | %Forcing function, f(x) = 1 30 | f = ones(s,s); 31 | 32 | %Solve PDE: - div(a(x)*grad(p(x))) = f(x) 33 | lognorm_p = solve_gwf(lognorm_a,f); 34 | thresh_p = solve_gwf(thresh_a,f); 35 | 36 | %Plot coefficients and solutions 37 | subplot(2,2,1) 38 | surf(X,Y,lognorm_a); 39 | view(2); 40 | shading interp; 41 | colorbar; 42 | subplot(2,2,2) 43 | surf(X,Y,lognorm_p); 44 | view(2); 45 | shading interp; 46 | colorbar; 47 | subplot(2,2,3) 48 | surf(X,Y,thresh_a); 49 | view(2); 50 | shading interp; 51 | colorbar; 52 | subplot(2,2,4) 53 | surf(X,Y,thresh_p); 54 | view(2); 55 | shading interp; 56 | colorbar; 57 | -------------------------------------------------------------------------------- /data_generation/darcy/solve_gwf.m: -------------------------------------------------------------------------------- 1 | %% 2 | % Solve the equation -d(coef*dp) = F 3 | 4 | function P = solve_gwf(coef,F,~) 5 | 6 | K = length(coef); 7 | 8 | [X1,Y1] = meshgrid(1/(2*K):1/K:(2*K-1)/(2*K),1/(2*K):1/K:(2*K-1)/(2*K)); 9 | [X2,Y2] = meshgrid(0:1/(K-1):1,0:1/(K-1):1); 10 | 11 | coef = interp2(X1,Y1,coef,X2,Y2,'spline'); 12 | F = interp2(X1,Y1,F,X2,Y2,'spline'); 13 |
14 | F = F(2:K-1,2:K-1); 15 | 16 | d = cell(K-2,K-2); 17 | [d{:}] = deal(sparse(zeros(K-2))); 18 | 19 | for j=2:K-1 20 | d{j-1,j-1} = spdiags([[-(coef(2:K-2,j)+coef(3:K-1,j))/2;0],... 21 | (coef(1:K-2,j)+coef(2:K-1,j))/2 + (coef(3:K,j)+coef(2:K-1,j))/2 ... 22 | + (coef(2:K-1,j-1)+coef(2:K-1,j))/2 + (coef(2:K-1,j+1)+coef(2:K-1,j))/2,... 23 | [0;-(coef(2:K-2,j)+coef(3:K-1,j))/2]],... 24 | -1:1,K-2,K-2); 25 | 26 | if j~=K-1 27 | d{j-1,j} = spdiags(-(coef(2:K-1,j)+coef(2:K-1,j+1))/2,0,K-2,K-2); 28 | d{j,j-1} = d{j-1,j}; 29 | end 30 | end 31 | 32 | A = cell2mat(d)*(K-1)^2; 33 | P =[zeros(1,K);[zeros(K-2,1),vec2mat(A\F(:),K-2),zeros(K-2,1)];zeros(1,K)]; 34 | 35 | P = interp2(X2,Y2,P,X1,Y1,'spline')'; 36 | 37 | end -------------------------------------------------------------------------------- /data_generation/navier_stokes/ns_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import math 4 | 5 | from random_fields import GaussianRF 6 | 7 | from timeit import default_timer 8 | 9 | import scipy.io 10 | 11 | 12 | #w0: initial vorticity 13 | #f: forcing term 14 | #visc: viscosity (1/Re) 15 | #T: final time 16 | #delta_t: internal time-step for solve (decrease if blow-up) 17 | #record_steps: number of in-time snapshots to record 18 | def navier_stokes_2d(w0, f, visc, T, delta_t=1e-4, record_steps=1): 19 | 20 | #Grid size - must be power of 2 21 | N = w0.size()[-1] 22 | 23 | #Maximum frequency 24 | k_max = math.floor(N/2.0) 25 | 26 | #Number of steps to final time 27 | steps = math.ceil(T/delta_t) 28 | 29 | #Initial vorticity to Fourier space 30 | w_h = torch.fft.rfft2(w0) 31 | 32 | #Forcing to Fourier space 33 | f_h = torch.fft.rfft2(f) 34 | 35 | #If same forcing for the whole batch 36 | if len(f_h.size()) < len(w_h.size()): 37 | f_h = torch.unsqueeze(f_h, 0) 38 | 39 | #Record solution every this number of steps. NOTE(review): assumes steps is a multiple of record_steps; otherwise more than record_steps snapshots are taken and sol[...,c] goes out of bounds, and record_steps > steps makes record_time 0 (modulo by zero below) 40 | record_time = math.floor(steps/record_steps) 41 | 42 | #Wavenumbers in y-direction 43 | k_y =
torch.cat((torch.arange(start=0, end=k_max, step=1, device=w0.device), torch.arange(start=-k_max, end=0, step=1, device=w0.device)), 0).repeat(N,1) 44 | #Wavenumbers in x-direction 45 | k_x = k_y.transpose(0,1) 46 | 47 | #Truncate redundant modes 48 | k_x = k_x[..., :k_max + 1] 49 | k_y = k_y[..., :k_max + 1] 50 | 51 | #Negative Laplacian in Fourier space (the zero mode is set to 1 so the Poisson solve w_h / lap never divides by zero) 52 | lap = 4*(math.pi**2)*(k_x**2 + k_y**2) 53 | lap[0,0] = 1.0 54 | #Dealiasing mask 55 | dealias = torch.unsqueeze(torch.logical_and(torch.abs(k_y) <= (2.0/3.0)*k_max, torch.abs(k_x) <= (2.0/3.0)*k_max).float(), 0) 56 | 57 | #Saving solution and time 58 | sol = torch.zeros(*w0.size(), record_steps, device=w0.device) 59 | sol_t = torch.zeros(record_steps, device=w0.device) 60 | 61 | #Record counter 62 | c = 0 63 | #Physical time 64 | t = 0.0 65 | for j in range(steps): 66 | #Stream function in Fourier space: solve Poisson equation 67 | psi_h = w_h / lap 68 | 69 | #Velocity field in x-direction = psi_y 70 | q = 2. * math.pi * k_y * 1j * psi_h 71 | q = torch.fft.irfft2(q, s=(N, N)) 72 | 73 | #Velocity field in y-direction = -psi_x 74 | v = -2. * math.pi * k_x * 1j * psi_h 75 | v = torch.fft.irfft2(v, s=(N, N)) 76 | 77 | #Partial x of vorticity 78 | w_x = 2. * math.pi * k_x * 1j * w_h 79 | w_x = torch.fft.irfft2(w_x, s=(N, N)) 80 | 81 | #Partial y of vorticity 82 | w_y = 2.
* math.pi * k_y * 1j * w_h 83 | w_y = torch.fft.irfft2(w_y, s=(N, N)) 84 | 85 | #Non-linear term (u.grad(w)): compute in physical space then back to Fourier space 86 | F_h = torch.fft.rfft2(q*w_x + v*w_y) 87 | 88 | #Dealias 89 | F_h = dealias* F_h 90 | 91 | #Crank-Nicolson update for the viscous term; nonlinear and forcing terms are treated explicitly 92 | w_h = (-delta_t*F_h + delta_t*f_h + (1.0 - 0.5*delta_t*visc*lap)*w_h)/(1.0 + 0.5*delta_t*visc*lap) 93 | 94 | #Update real time (used only for recording) 95 | t += delta_t 96 | 97 | if (j+1) % record_time == 0: 98 | #Solution in physical space 99 | w = torch.fft.irfft2(w_h, s=(N, N)) 100 | 101 | #Record solution and time 102 | sol[...,c] = w 103 | sol_t[c] = t 104 | 105 | c += 1 106 | 107 | 108 | return sol, sol_t 109 | 110 | 111 | device = torch.device('cuda') 112 | 113 | #Resolution 114 | s = 256 115 | 116 | #Number of solutions to generate 117 | N = 20 118 | 119 | #Set up 2d GRF with covariance parameters 120 | GRF = GaussianRF(2, s, alpha=2.5, tau=7, device=device) 121 | 122 | #Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y))) 123 | t = torch.linspace(0, 1, s+1, device=device) 124 | t = t[0:-1] 125 | 126 | X,Y = torch.meshgrid(t, t, indexing='ij') 127 | f = 0.1*(torch.sin(2*math.pi*(X + Y)) + torch.cos(2*math.pi*(X + Y))) 128 | 129 | #Number of snapshots from solution 130 | record_steps = 200 131 | 132 | #Inputs 133 | a = torch.zeros(N, s, s) 134 | #Solutions 135 | u = torch.zeros(N, s, s, record_steps) 136 | 137 | #Solve equations in batches (order of magnitude speed-up) 138 | 139 | #Batch size (the loop runs N//bsize batches, so N should be a multiple of bsize or trailing samples stay zero) 140 | bsize = 20 141 | 142 | c = 0 143 | t0 =default_timer() 144 | for j in range(N//bsize): 145 | 146 | #Sample random fields 147 | w0 = GRF.sample(bsize) 148 | 149 | #Solve NS 150 | sol, sol_t = navier_stokes_2d(w0, f, 1e-3, 50.0, 1e-4, record_steps) 151 | 152 | a[c:(c+bsize),...] = w0 153 | u[c:(c+bsize),...]
= sol 154 | 155 | c += bsize 156 | t1 = default_timer() 157 | print(j, c, t1-t0) 158 | 159 | scipy.io.savemat('ns_data.mat', mdict={'a': a.cpu().numpy(), 'u': u.cpu().numpy(), 't': sol_t.cpu().numpy()}) 160 | -------------------------------------------------------------------------------- /data_generation/navier_stokes/random_fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from timeit import default_timer 5 | 6 | 7 | class GaussianRF(object): 8 | 9 | def __init__(self, dim, size, alpha=2, tau=3, sigma=None, boundary="periodic", device=None): 10 | 11 | self.dim = dim 12 | self.device = device 13 | 14 | if sigma is None: 15 | sigma = tau**(0.5*(2*alpha - self.dim)) 16 | 17 | k_max = size//2 18 | 19 | if dim == 1: 20 | k = torch.cat((torch.arange(start=0, end=k_max, step=1, device=device), \ 21 | torch.arange(start=-k_max, end=0, step=1, device=device)), 0) 22 | 23 | self.sqrt_eig = size*math.sqrt(2.0)*sigma*((4*(math.pi**2)*(k**2) + tau**2)**(-alpha/2.0)) 24 | self.sqrt_eig[0] = 0.0 25 | 26 | elif dim == 2: 27 | wavenumers = torch.cat((torch.arange(start=0, end=k_max, step=1, device=device), \ 28 | torch.arange(start=-k_max, end=0, step=1, device=device)), 0).repeat(size,1) 29 | 30 | k_x = wavenumers.transpose(0,1) 31 | k_y = wavenumers 32 | 33 | self.sqrt_eig = (size**2)*math.sqrt(2.0)*sigma*((4*(math.pi**2)*(k_x**2 + k_y**2) + tau**2)**(-alpha/2.0)) 34 | self.sqrt_eig[0,0] = 0.0 35 | 36 | elif dim == 3: 37 | wavenumers = torch.cat((torch.arange(start=0, end=k_max, step=1, device=device), \ 38 | torch.arange(start=-k_max, end=0, step=1, device=device)), 0).repeat(size,size,1) 39 | 40 | k_x = wavenumers.transpose(1,2) 41 | k_y = wavenumers 42 | k_z = wavenumers.transpose(0,2) 43 | 44 | self.sqrt_eig = (size**3)*math.sqrt(2.0)*sigma*((4*(math.pi**2)*(k_x**2 + k_y**2 + k_z**2) + tau**2)**(-alpha/2.0)) 45 | self.sqrt_eig[0,0,0] = 0.0 46 | 47 | self.size = [] 48 | for j in 
range(self.dim): 49 | self.size.append(size) 50 | 51 | self.size = tuple(self.size) 52 | 53 | def sample(self, N): 54 | 55 | coeff = torch.randn(N, *self.size, dtype=torch.cfloat, device=self.device) 56 | coeff = self.sqrt_eig * coeff 57 | 58 | return torch.fft.ifftn(coeff, dim=list(range(-1, -self.dim - 1, -1))).real 59 | -------------------------------------------------------------------------------- /fourier_1d.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 1D problem such as the (time-independent) Burgers equation discussed in Section 5.1 in the [paper](https://arxiv.org/pdf/2010.08895.pdf). 4 | """ 5 | 6 | import torch.nn.functional as F 7 | from timeit import default_timer 8 | from utilities3 import * 9 | 10 | torch.manual_seed(0) 11 | np.random.seed(0) 12 | 13 | 14 | ################################################################ 15 | # 1d fourier layer 16 | ################################################################ 17 | class SpectralConv1d(nn.Module): 18 | def __init__(self, in_channels, out_channels, modes1): 19 | super(SpectralConv1d, self).__init__() 20 | 21 | """ 22 | 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
23 | """ 24 | 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 28 | 29 | self.scale = (1 / (in_channels*out_channels)) 30 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat)) 31 | 32 | # Complex multiplication 33 | def compl_mul1d(self, input, weights): 34 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 35 | return torch.einsum("bix,iox->box", input, weights) 36 | 37 | def forward(self, x): 38 | batchsize = x.shape[0] 39 | #Compute Fourier coeffcients up to factor of e^(- something constant) 40 | x_ft = torch.fft.rfft(x) 41 | 42 | # Multiply relevant Fourier modes 43 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1)//2 + 1, device=x.device, dtype=torch.cfloat) 44 | out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1) 45 | 46 | #Return to physical space 47 | x = torch.fft.irfft(out_ft, n=x.size(-1)) 48 | return x 49 | 50 | class MLP(nn.Module): 51 | def __init__(self, in_channels, out_channels, mid_channels): 52 | super(MLP, self).__init__() 53 | self.mlp1 = nn.Conv1d(in_channels, mid_channels, 1) 54 | self.mlp2 = nn.Conv1d(mid_channels, out_channels, 1) 55 | 56 | def forward(self, x): 57 | x = self.mlp1(x) 58 | x = F.gelu(x) 59 | x = self.mlp2(x) 60 | return x 61 | 62 | class FNO1d(nn.Module): 63 | def __init__(self, modes, width): 64 | super(FNO1d, self).__init__() 65 | 66 | """ 67 | The overall network. It contains 4 layers of the Fourier layer. 68 | 1. Lift the input to the desire channel dimension by self.fc0 . 69 | 2. 4 layers of the integral operators u' = (W + K)(u). 70 | W defined by self.w; K defined by self.conv . 71 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
72 | 73 | input: the solution of the initial condition and location (a(x), x) 74 | input shape: (batchsize, x=s, c=2) 75 | output: the solution of a later timestep 76 | output shape: (batchsize, x=s, c=1) 77 | """ 78 | 79 | self.modes1 = modes 80 | self.width = width 81 | self.padding = 8 # pad the domain if input is non-periodic 82 | 83 | self.p = nn.Linear(2, self.width) # input channel_dim is 2: (u0(x), x) 84 | self.conv0 = SpectralConv1d(self.width, self.width, self.modes1) 85 | self.conv1 = SpectralConv1d(self.width, self.width, self.modes1) 86 | self.conv2 = SpectralConv1d(self.width, self.width, self.modes1) 87 | self.conv3 = SpectralConv1d(self.width, self.width, self.modes1) 88 | self.mlp0 = MLP(self.width, self.width, self.width) 89 | self.mlp1 = MLP(self.width, self.width, self.width) 90 | self.mlp2 = MLP(self.width, self.width, self.width) 91 | self.mlp3 = MLP(self.width, self.width, self.width) 92 | self.w0 = nn.Conv1d(self.width, self.width, 1) 93 | self.w1 = nn.Conv1d(self.width, self.width, 1) 94 | self.w2 = nn.Conv1d(self.width, self.width, 1) 95 | self.w3 = nn.Conv1d(self.width, self.width, 1) 96 | self.q = MLP(self.width, 1, self.width*2) # output channel_dim is 1: u1(x) 97 | 98 | def forward(self, x): 99 | grid = self.get_grid(x.shape, x.device) 100 | x = torch.cat((x, grid), dim=-1) 101 | x = self.p(x) 102 | x = x.permute(0, 2, 1) 103 | # x = F.pad(x, [0,self.padding]) # pad the domain if input is non-periodic 104 | 105 | x1 = self.conv0(x) 106 | x1 = self.mlp0(x1) 107 | x2 = self.w0(x) 108 | x = x1 + x2 109 | x = F.gelu(x) 110 | 111 | x1 = self.conv1(x) 112 | x1 = self.mlp1(x1) 113 | x2 = self.w1(x) 114 | x = x1 + x2 115 | x = F.gelu(x) 116 | 117 | x1 = self.conv2(x) 118 | x1 = self.mlp2(x1) 119 | x2 = self.w2(x) 120 | x = x1 + x2 121 | x = F.gelu(x) 122 | 123 | x1 = self.conv3(x) 124 | x1 = self.mlp3(x1) 125 | x2 = self.w3(x) 126 | x = x1 + x2 127 | 128 | # x = x[..., :-self.padding] # pad the domain if input is non-periodic 129 | x = 
self.q(x) 130 | x = x.permute(0, 2, 1) 131 | return x 132 | 133 | def get_grid(self, shape, device): 134 | batchsize, size_x = shape[0], shape[1] 135 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 136 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 137 | return gridx.to(device) 138 | 139 | ################################################################ 140 | # configurations 141 | ################################################################ 142 | ntrain = 1000 143 | ntest = 100 144 | 145 | sub = 2**3 #subsampling rate 146 | h = 2**13 // sub #total grid size divided by the subsampling rate 147 | s = h 148 | 149 | batch_size = 20 150 | learning_rate = 0.001 151 | epochs = 500 152 | iterations = epochs*(ntrain//batch_size) 153 | 154 | modes = 16 155 | width = 64 156 | 157 | ################################################################ 158 | # read data 159 | ################################################################ 160 | 161 | # Data is of the shape (number of samples, grid size) 162 | dataloader = MatReader('data/burgers_data_R10.mat') 163 | x_data = dataloader.read_field('a')[:,::sub] 164 | y_data = dataloader.read_field('u')[:,::sub] 165 | 166 | x_train = x_data[:ntrain,:] 167 | y_train = y_data[:ntrain,:] 168 | x_test = x_data[-ntest:,:] 169 | y_test = y_data[-ntest:,:] 170 | 171 | x_train = x_train.reshape(ntrain,s,1) 172 | x_test = x_test.reshape(ntest,s,1) 173 | 174 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 175 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 176 | 177 | # model 178 | model = FNO1d(modes, width).cuda() 179 | print(count_params(model)) 180 | 181 | ################################################################ 182 | # training and evaluation 183 | ################################################################ 184 | 
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# The scheduler is stepped once per optimizer step (i.e. per batch below),
# so T_max is expressed in iterations rather than epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

# NOTE(review): LpLoss(size_average=False) presumably sums the relative error
# over the samples in a batch, hence the division by ntrain/ntest below.
myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()
        # Use the actual batch size: the final batch can be smaller than
        # `batch_size` when the dataset size is not a multiple of it.
        bs = x.shape[0]

        optimizer.zero_grad()
        out = model(x)

        mse = F.mse_loss(out.reshape(bs, -1), y.reshape(bs, -1), reduction='mean')
        l2 = myloss(out.reshape(bs, -1), y.reshape(bs, -1))
        l2.backward()  # use the l2 relative loss (mse is only tracked)

        optimizer.step()
        scheduler.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()
            bs = x.shape[0]

            out = model(x)
            test_l2 += myloss(out.reshape(bs, -1), y.reshape(bs, -1)).item()

    train_mse /= len(train_loader)  # mse is a per-batch mean -> average over batches
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_mse, train_l2, test_l2)

# torch.save(model, 'model/ns_fourier_burgers')

# Per-sample evaluation on the test set: collect predictions and print each
# sample's relative L2 error.
pred = torch.zeros(y_test.shape)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False)
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader):
        x, y = x.cuda(), y.cuda()

        out = model(x).reshape(-1)
        pred[index] = out

        test_l2 = myloss(out.reshape(1, -1), y.reshape(1, -1)).item()
        print(index, test_l2)

# scipy.io.savemat('pred/burger_test.mat', mdict={'pred': pred.cpu().numpy()})
# --------------------------------------------------------------------------------
# /fourier_2d.py:
# --------------------------------------------------------------------------------
"""
@author: Zongyi Li
This file is the Fourier Neural Operator for 2D problem such as the Darcy Flow discussed in Section 5.2 in the [paper](https://arxiv.org/pdf/2010.08895.pdf).
"""

import torch.nn.functional as F
from timeit import default_timer
from utilities3 import *

torch.manual_seed(0)
np.random.seed(0)

################################################################
# fourier layer
################################################################
class SpectralConv2d(nn.Module):
    """2D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Only the lowest Fourier modes are kept: `modes1` along the second-to-last
    spatial axis (both its leading block and its trailing/negative-frequency
    block, with separate weights) and `modes2` along the last axis of the
    real FFT. All other modes are zeroed.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # NOTE(review): modes2 indexes the rfft axis, whose length is N//2 + 1;
        # modes1 is applied to both ends of a length-N axis, so keeping it
        # <= N//2 avoids the two blocks overlapping.
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        # weights1 acts on the leading modes1 rows, weights2 on the trailing ones.
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """Apply the spectral convolution to x of shape (batch, in_channels, X, Y)."""
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft2(x)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Return to physical space (same spatial size as the input)
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
        return x

class MLP(nn.Module):
    """Two pointwise (1x1) convolutions with a GELU in between."""

    def __init__(self, in_channels, out_channels, mid_channels):
        super(MLP, self).__init__()
        self.mlp1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.mlp2 = nn.Conv2d(mid_channels, out_channels, 1)

    def forward(self, x):
        x = self.mlp1(x)
        x = F.gelu(x)
        x = self.mlp2(x)
        return x

class FNO2d(nn.Module):
    def __init__(self, modes1, modes2, width):
        super(FNO2d, self).__init__()

        """
        The overall network. It contains 4 layers of the Fourier layer.
        1. Lift the input to the desired channel dimension by self.p .
        2. 4 layers of the integral operators u' = (W + K)(u).
            W defined by self.w; K defined by self.conv followed by self.mlp .
        3. Project from the channel space to the output space by self.q .

        input: the solution of the coefficient function and locations (a(x, y), x, y)
        input shape: (batchsize, x=s, y=s, c=3)
        output: the solution
        output shape: (batchsize, x=s, y=s, c=1)
        """

        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.padding = 9 # pad the domain if input is non-periodic

        self.p = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y)
        self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2)
        self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2)
        self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2)
        self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2)
        self.mlp0 = MLP(self.width, self.width, self.width)
        self.mlp1 = MLP(self.width, self.width, self.width)
        self.mlp2 = MLP(self.width, self.width, self.width)
        self.mlp3 = MLP(self.width, self.width, self.width)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)
        self.w3 = nn.Conv2d(self.width, self.width, 1)
        self.q = MLP(self.width, 1, self.width * 4) # output channel is 1: u(x, y)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.p(x)
        x = x.permute(0, 3, 1, 2)
        # Zero-pad the right/bottom edges for non-periodic inputs. Guarded so
        # that padding == 0 does not turn the crop below into an empty slice.
        if self.padding > 0:
            x = F.pad(x, [0, self.padding, 0, self.padding])

        x1 = self.conv0(x)
        x1 = self.mlp0(x1)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv1(x)
        x1 = self.mlp1(x1)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv2(x)
        x1 = self.mlp2(x1)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.gelu(x)

        # No activation after the last Fourier layer.
        x1 = self.conv3(x)
        x1 = self.mlp3(x1)
        x2 = self.w3(x)
        x = x1 + x2

        # Remove the padding added above (x[..., :-0] would be empty).
        if self.padding > 0:
            x = x[..., :-self.padding, :-self.padding]
        x = self.q(x)
        x = x.permute(0, 2, 3, 1)
        return x

    def get_grid(self, shape, device):
        """Return (batch, X, Y, 2) uniform coordinates in [0, 1] for each axis."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)

################################################################
# configs
################################################################
TRAIN_PATH = 'data/piececonst_r421_N1024_smooth1.mat'
TEST_PATH = 'data/piececonst_r421_N1024_smooth2.mat'

ntrain = 1000
ntest = 100

batch_size = 20
learning_rate = 0.001
epochs = 500
iterations = epochs*(ntrain//batch_size)

modes = 12
width = 32

r = 5  # subsampling rate of the 421x421 grid
h = int(((421 - 1)/r) + 1)
s = h

################################################################
# load data and data normalization
################################################################
reader = MatReader(TRAIN_PATH)
x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s]
y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s]

reader.load_file(TEST_PATH)
x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s]
y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s]

# Inputs are normalized with training statistics; outputs are encoded for
# training only (test targets stay in physical units and predictions are decoded).
x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
x_test = x_normalizer.encode(x_test)

y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

x_train = x_train.reshape(ntrain,s,s,1)
x_test = x_test.reshape(ntest,s,s,1)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 189 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 190 | 191 | ################################################################ 192 | # training and evaluation 193 | ################################################################ 194 | model = FNO2d(modes, modes, width).cuda() 195 | print(count_params(model)) 196 | 197 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 198 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations) 199 | 200 | myloss = LpLoss(size_average=False) 201 | y_normalizer.cuda() 202 | for ep in range(epochs): 203 | model.train() 204 | t1 = default_timer() 205 | train_l2 = 0 206 | for x, y in train_loader: 207 | x, y = x.cuda(), y.cuda() 208 | 209 | optimizer.zero_grad() 210 | out = model(x).reshape(batch_size, s, s) 211 | out = y_normalizer.decode(out) 212 | y = y_normalizer.decode(y) 213 | 214 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 215 | loss.backward() 216 | 217 | optimizer.step() 218 | scheduler.step() 219 | train_l2 += loss.item() 220 | 221 | model.eval() 222 | test_l2 = 0.0 223 | with torch.no_grad(): 224 | for x, y in test_loader: 225 | x, y = x.cuda(), y.cuda() 226 | 227 | out = model(x).reshape(batch_size, s, s) 228 | out = y_normalizer.decode(out) 229 | 230 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 231 | 232 | train_l2/= ntrain 233 | test_l2 /= ntest 234 | 235 | t2 = default_timer() 236 | print(ep, t2-t1, train_l2, test_l2) 237 | -------------------------------------------------------------------------------- /fourier_2d_time.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 2D problem 
such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf), 4 | which uses a recurrent structure to propagates in time. 5 | """ 6 | 7 | import torch.nn.functional as F 8 | from utilities3 import * 9 | from timeit import default_timer 10 | 11 | torch.manual_seed(0) 12 | np.random.seed(0) 13 | 14 | ################################################################ 15 | # fourier layer 16 | ################################################################ 17 | 18 | class SpectralConv2d(nn.Module): 19 | def __init__(self, in_channels, out_channels, modes1, modes2): 20 | super(SpectralConv2d, self).__init__() 21 | 22 | """ 23 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 24 | """ 25 | 26 | self.in_channels = in_channels 27 | self.out_channels = out_channels 28 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 29 | self.modes2 = modes2 30 | 31 | self.scale = (1 / (in_channels * out_channels)) 32 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 33 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 34 | 35 | # Complex multiplication 36 | def compl_mul2d(self, input, weights): 37 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 38 | return torch.einsum("bixy,ioxy->boxy", input, weights) 39 | 40 | def forward(self, x): 41 | batchsize = x.shape[0] 42 | #Compute Fourier coeffcients up to factor of e^(- something constant) 43 | x_ft = torch.fft.rfft2(x) 44 | 45 | # Multiply relevant Fourier modes 46 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device) 47 | out_ft[:, :, :self.modes1, :self.modes2] = \ 48 | self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) 49 | out_ft[:, :, -self.modes1:, 
:self.modes2] = \ 50 | self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 51 | 52 | #Return to physical space 53 | x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) 54 | return x 55 | 56 | class MLP(nn.Module): 57 | def __init__(self, in_channels, out_channels, mid_channels): 58 | super(MLP, self).__init__() 59 | self.mlp1 = nn.Conv2d(in_channels, mid_channels, 1) 60 | self.mlp2 = nn.Conv2d(mid_channels, out_channels, 1) 61 | 62 | def forward(self, x): 63 | x = self.mlp1(x) 64 | x = F.gelu(x) 65 | x = self.mlp2(x) 66 | return x 67 | 68 | class FNO2d(nn.Module): 69 | def __init__(self, modes1, modes2, width): 70 | super(FNO2d, self).__init__() 71 | 72 | """ 73 | The overall network. It contains 4 layers of the Fourier layer. 74 | 1. Lift the input to the desire channel dimension by self.fc0 . 75 | 2. 4 layers of the integral operators u' = (W + K)(u). 76 | W defined by self.w; K defined by self.conv . 77 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
78 | 79 | input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) 80 | input shape: (batchsize, x=64, y=64, c=12) 81 | output: the solution of the next timestep 82 | output shape: (batchsize, x=64, y=64, c=1) 83 | """ 84 | 85 | self.modes1 = modes1 86 | self.modes2 = modes2 87 | self.width = width 88 | self.padding = 8 # pad the domain if input is non-periodic 89 | 90 | self.p = nn.Linear(12, self.width) # input channel is 12: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) 91 | self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 92 | self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 93 | self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 94 | self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 95 | self.mlp0 = MLP(self.width, self.width, self.width) 96 | self.mlp1 = MLP(self.width, self.width, self.width) 97 | self.mlp2 = MLP(self.width, self.width, self.width) 98 | self.mlp3 = MLP(self.width, self.width, self.width) 99 | self.w0 = nn.Conv2d(self.width, self.width, 1) 100 | self.w1 = nn.Conv2d(self.width, self.width, 1) 101 | self.w2 = nn.Conv2d(self.width, self.width, 1) 102 | self.w3 = nn.Conv2d(self.width, self.width, 1) 103 | self.norm = nn.InstanceNorm2d(self.width) 104 | self.q = MLP(self.width, 1, self.width * 4) # output channel is 1: u(x, y) 105 | 106 | def forward(self, x): 107 | grid = self.get_grid(x.shape, x.device) 108 | x = torch.cat((x, grid), dim=-1) 109 | x = self.p(x) 110 | x = x.permute(0, 3, 1, 2) 111 | # x = F.pad(x, [0,self.padding, 0,self.padding]) # pad the domain if input is non-periodic 112 | 113 | x1 = self.norm(self.conv0(self.norm(x))) 114 | x1 = self.mlp0(x1) 115 | x2 = self.w0(x) 116 | x = x1 + x2 117 | x = F.gelu(x) 118 | 119 | x1 = self.norm(self.conv1(self.norm(x))) 120 | x1 = self.mlp1(x1) 121 | x2 = self.w1(x) 122 | x = x1 + x2 
123 | x = F.gelu(x) 124 | 125 | x1 = self.norm(self.conv2(self.norm(x))) 126 | x1 = self.mlp2(x1) 127 | x2 = self.w2(x) 128 | x = x1 + x2 129 | x = F.gelu(x) 130 | 131 | x1 = self.norm(self.conv3(self.norm(x))) 132 | x1 = self.mlp3(x1) 133 | x2 = self.w3(x) 134 | x = x1 + x2 135 | 136 | # x = x[..., :-self.padding, :-self.padding] # pad the domain if input is non-periodic 137 | x = self.q(x) 138 | x = x.permute(0, 2, 3, 1) 139 | return x 140 | 141 | def get_grid(self, shape, device): 142 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 143 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 144 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 145 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 146 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 147 | return torch.cat((gridx, gridy), dim=-1).to(device) 148 | 149 | ################################################################ 150 | # configs 151 | ################################################################ 152 | 153 | TRAIN_PATH = 'data/ns_data_V100_N1000_T50_1.mat' 154 | TEST_PATH = 'data/ns_data_V100_N1000_T50_2.mat' 155 | 156 | ntrain = 1000 157 | ntest = 200 158 | 159 | modes = 12 160 | width = 20 161 | 162 | batch_size = 20 163 | learning_rate = 0.001 164 | epochs = 500 165 | iterations = epochs*(ntrain//batch_size) 166 | 167 | path = 'ns_fourier_2d_time_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width) 168 | path_model = 'model/'+path 169 | path_train_err = 'results/'+path+'train.txt' 170 | path_test_err = 'results/'+path+'test.txt' 171 | path_image = 'image/'+path 172 | 173 | sub = 1 174 | S = 64 175 | T_in = 10 176 | T = 40 # T=40 for V1e-3; T=20 for V1e-4; T=10 for V1e-5; 177 | step = 1 178 | 179 | ################################################################ 180 | # load data 181 | ################################################################ 182 | 183 | reader = 
MatReader(TRAIN_PATH) 184 | train_a = reader.read_field('u')[:ntrain,::sub,::sub,:T_in] 185 | train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in] 186 | 187 | reader = MatReader(TEST_PATH) 188 | test_a = reader.read_field('u')[-ntest:,::sub,::sub,:T_in] 189 | test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in] 190 | 191 | print(train_u.shape) 192 | print(test_u.shape) 193 | assert (S == train_u.shape[-2]) 194 | assert (T == train_u.shape[-1]) 195 | 196 | train_a = train_a.reshape(ntrain,S,S,T_in) 197 | test_a = test_a.reshape(ntest,S,S,T_in) 198 | 199 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 200 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 201 | 202 | ################################################################ 203 | # training and evaluation 204 | ################################################################ 205 | model = FNO2d(modes, modes, width).cuda() 206 | print(count_params(model)) 207 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 208 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations) 209 | 210 | myloss = LpLoss(size_average=False) 211 | for ep in range(epochs): 212 | model.train() 213 | t1 = default_timer() 214 | train_l2_step = 0 215 | train_l2_full = 0 216 | for xx, yy in train_loader: 217 | loss = 0 218 | xx = xx.to(device) 219 | yy = yy.to(device) 220 | 221 | for t in range(0, T, step): 222 | y = yy[..., t:t + step] 223 | im = model(xx) 224 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 225 | 226 | if t == 0: 227 | pred = im 228 | else: 229 | pred = torch.cat((pred, im), -1) 230 | 231 | xx = torch.cat((xx[..., step:], im), dim=-1) 232 | 233 | train_l2_step += loss.item() 234 | l2_full = myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, 
-1)) 235 | train_l2_full += l2_full.item() 236 | 237 | optimizer.zero_grad() 238 | loss.backward() 239 | optimizer.step() 240 | scheduler.step() 241 | 242 | test_l2_step = 0 243 | test_l2_full = 0 244 | with torch.no_grad(): 245 | for xx, yy in test_loader: 246 | loss = 0 247 | xx = xx.to(device) 248 | yy = yy.to(device) 249 | 250 | for t in range(0, T, step): 251 | y = yy[..., t:t + step] 252 | im = model(xx) 253 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 254 | 255 | if t == 0: 256 | pred = im 257 | else: 258 | pred = torch.cat((pred, im), -1) 259 | 260 | xx = torch.cat((xx[..., step:], im), dim=-1) 261 | 262 | test_l2_step += loss.item() 263 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 264 | 265 | t2 = default_timer() 266 | print(ep, t2 - t1, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step), 267 | test_l2_full / ntest) 268 | # torch.save(model, path_model) 269 | 270 | -------------------------------------------------------------------------------- /fourier_3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 3D problem such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf), 4 | which takes the 2D spatial + 1D temporal equation directly as a 3D problem 5 | """ 6 | 7 | 8 | import torch.nn.functional as F 9 | from utilities3 import * 10 | from timeit import default_timer 11 | 12 | torch.manual_seed(0) 13 | np.random.seed(0) 14 | 15 | ################################################################ 16 | # 3d fourier layers 17 | ################################################################ 18 | 19 | class SpectralConv3d(nn.Module): 20 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3): 21 | super(SpectralConv3d, self).__init__() 22 | 23 | """ 24 | 3D 
Fourier layer. It does FFT, linear transform, and Inverse FFT. 25 | """ 26 | 27 | self.in_channels = in_channels 28 | self.out_channels = out_channels 29 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 30 | self.modes2 = modes2 31 | self.modes3 = modes3 32 | 33 | self.scale = (1 / (in_channels * out_channels)) 34 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 35 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 36 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 37 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 38 | 39 | # Complex multiplication 40 | def compl_mul3d(self, input, weights): 41 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 42 | return torch.einsum("bixyz,ioxyz->boxyz", input, weights) 43 | 44 | def forward(self, x): 45 | batchsize = x.shape[0] 46 | #Compute Fourier coeffcients up to factor of e^(- something constant) 47 | x_ft = torch.fft.rfftn(x, dim=[-3,-2,-1]) 48 | 49 | # Multiply relevant Fourier modes 50 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device) 51 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \ 52 | self.compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1) 53 | out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \ 54 | self.compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2) 55 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \ 56 | self.compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], 
self.weights3) 57 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \ 58 | self.compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4) 59 | 60 | #Return to physical space 61 | x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1))) 62 | return x 63 | 64 | class MLP(nn.Module): 65 | def __init__(self, in_channels, out_channels, mid_channels): 66 | super(MLP, self).__init__() 67 | self.mlp1 = nn.Conv3d(in_channels, mid_channels, 1) 68 | self.mlp2 = nn.Conv3d(mid_channels, out_channels, 1) 69 | 70 | def forward(self, x): 71 | x = self.mlp1(x) 72 | x = F.gelu(x) 73 | x = self.mlp2(x) 74 | return x 75 | 76 | class FNO3d(nn.Module): 77 | def __init__(self, modes1, modes2, modes3, width): 78 | super(FNO3d, self).__init__() 79 | 80 | """ 81 | The overall network. It contains 4 layers of the Fourier layer. 82 | 1. Lift the input to the desire channel dimension by self.fc0 . 83 | 2. 4 layers of the integral operators u' = (W + K)(u). 84 | W defined by self.w; K defined by self.conv . 85 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 86 | 87 | input: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y), x, y, t). It's a constant function in time, except for the last index. 
88 | input shape: (batchsize, x=64, y=64, t=40, c=13) 89 | output: the solution of the next 40 timesteps 90 | output shape: (batchsize, x=64, y=64, t=40, c=1) 91 | """ 92 | 93 | self.modes1 = modes1 94 | self.modes2 = modes2 95 | self.modes3 = modes3 96 | self.width = width 97 | self.padding = 6 # pad the domain if input is non-periodic 98 | 99 | self.p = nn.Linear(13, self.width)# input channel is 12: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y), x, y, t) 100 | self.conv0 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 101 | self.conv1 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 102 | self.conv2 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 103 | self.conv3 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 104 | self.mlp0 = MLP(self.width, self.width, self.width) 105 | self.mlp1 = MLP(self.width, self.width, self.width) 106 | self.mlp2 = MLP(self.width, self.width, self.width) 107 | self.mlp3 = MLP(self.width, self.width, self.width) 108 | self.w0 = nn.Conv3d(self.width, self.width, 1) 109 | self.w1 = nn.Conv3d(self.width, self.width, 1) 110 | self.w2 = nn.Conv3d(self.width, self.width, 1) 111 | self.w3 = nn.Conv3d(self.width, self.width, 1) 112 | self.q = MLP(self.width, 1, self.width * 4) # output channel is 1: u(x, y) 113 | 114 | def forward(self, x): 115 | grid = self.get_grid(x.shape, x.device) 116 | x = torch.cat((x, grid), dim=-1) 117 | x = self.p(x) 118 | x = x.permute(0, 4, 1, 2, 3) 119 | x = F.pad(x, [0,self.padding]) # pad the domain if input is non-periodic 120 | 121 | x1 = self.conv0(x) 122 | x1 = self.mlp0(x1) 123 | x2 = self.w0(x) 124 | x = x1 + x2 125 | x = F.gelu(x) 126 | 127 | x1 = self.conv1(x) 128 | x1 = self.mlp1(x1) 129 | x2 = self.w1(x) 130 | x = x1 + x2 131 | x = F.gelu(x) 132 | 133 | x1 = self.conv2(x) 134 | x1 = self.mlp2(x1) 135 | x2 = self.w2(x) 136 | x = x1 + x2 137 | 
x = F.gelu(x) 138 | 139 | x1 = self.conv3(x) 140 | x1 = self.mlp3(x1) 141 | x2 = self.w3(x) 142 | x = x1 + x2 143 | 144 | x = x[..., :-self.padding] 145 | x = self.q(x) 146 | x = x.permute(0, 2, 3, 4, 1) # pad the domain if input is non-periodic 147 | return x 148 | 149 | 150 | def get_grid(self, shape, device): 151 | batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3] 152 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 153 | gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1]) 154 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 155 | gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1]) 156 | gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float) 157 | gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1]) 158 | return torch.cat((gridx, gridy, gridz), dim=-1).to(device) 159 | 160 | ################################################################ 161 | # configs 162 | ################################################################ 163 | 164 | TRAIN_PATH = 'data/ns_data_V100_N1000_T50_1.mat' 165 | TEST_PATH = 'data/ns_data_V100_N1000_T50_2.mat' 166 | 167 | ntrain = 1000 168 | ntest = 200 169 | 170 | modes = 8 171 | width = 20 172 | 173 | batch_size = 10 174 | learning_rate = 0.001 175 | epochs = 500 176 | iterations = epochs*(ntrain//batch_size) 177 | 178 | path = 'ns_fourier_3d_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width) 179 | path_model = 'model/'+path 180 | path_train_err = 'results/'+path+'train.txt' 181 | path_test_err = 'results/'+path+'test.txt' 182 | path_image = 'image/'+path 183 | 184 | runtime = np.zeros(2, ) 185 | t1 = default_timer() 186 | 187 | sub = 1 188 | S = 64 // sub 189 | T_in = 10 190 | T = 40 # T=40 for V1e-3; T=20 for V1e-4; T=10 for V1e-5; 191 | 192 | ################################################################ 193 | # load data 194 | 
################################################################

reader = MatReader(TRAIN_PATH)
train_a = reader.read_field('u')[:ntrain, ::sub, ::sub, :T_in]
train_u = reader.read_field('u')[:ntrain, ::sub, ::sub, T_in:T + T_in]

reader = MatReader(TEST_PATH)
test_a = reader.read_field('u')[-ntest:, ::sub, ::sub, :T_in]
test_u = reader.read_field('u')[-ntest:, ::sub, ::sub, T_in:T + T_in]

print(train_u.shape)
print(test_u.shape)
assert (S == train_u.shape[-2])
assert (T == train_u.shape[-1])


# Inputs and training targets are normalized; test targets stay in physical
# units because predictions are decoded before the test loss is computed.
a_normalizer = UnitGaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)

y_normalizer = UnitGaussianNormalizer(train_u)
train_u = y_normalizer.encode(train_u)

# Repeat the T_in input snapshots along a new time axis of length T, giving
# (n, S, S, T, T_in) inputs for the 3D operator.
train_a = train_a.reshape(ntrain, S, S, 1, T_in).repeat([1, 1, 1, T, 1])
test_a = test_a.reshape(ntest, S, S, 1, T_in).repeat([1, 1, 1, T, 1])

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

################################################################
# training and evaluation
################################################################
model = FNO3d(modes, modes, modes, width).cuda()
print(count_params(model))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# T_max is `iterations` (epochs * batches per epoch), so the scheduler is
# stepped once per batch below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

myloss = LpLoss(size_average=False)
y_normalizer.cuda()
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()
        # Use the real batch size so a short final batch cannot break .view().
        bs = x.shape[0]

        optimizer.zero_grad()
        out = model(x).view(bs, S, S, T)

        # The normalized-space MSE is only logged; the relative L2 loss in
        # physical units is the quantity that is back-propagated.
        mse = F.mse_loss(out, y, reduction='mean')
        # mse.backward()

        y = y_normalizer.decode(y)
        out = y_normalizer.decode(out)
        l2 = myloss(out.view(bs, -1), y.view(bs, -1))
        l2.backward()

        optimizer.step()
        scheduler.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()
            bs = x.shape[0]

            out = model(x).view(bs, S, S, T)
            out = y_normalizer.decode(out)
            test_l2 += myloss(out.view(bs, -1), y.view(bs, -1)).item()

    train_mse /= len(train_loader)
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_mse, train_l2, test_l2)
# torch.save(model, path_model)

# Per-sample predictions on the test set (batch size 1), saved to disk.
pred = torch.zeros(test_u.shape)
index = 0
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)
model.eval()
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()

        # FNO3d.forward returns (batch, x, y, t, 1). Like the training and test
        # loops above, reshape to (1, S, S, T) before decoding; decoding the raw
        # 5-D output does not broadcast against the normalizer statistics, and
        # pred[index] expects an (S, S, T) slice.
        out = model(x).view(1, S, S, T)
        out = y_normalizer.decode(out)
        pred[index] = out[0].cpu()

        sample_l2 = myloss(out.view(1, -1), y.view(1, -1)).item()
        print(index, sample_l2)
        index = index + 1

scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()})


# ---- lowrank_operators/lowrank_1d.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import h5py
import scipy.io
import matplotlib.pyplot as plt
from timeit import default_timer
import sys
import math

import operator
from functools import reduce

from timeit import default_timer
from utilities3 import *

torch.manual_seed(0)
np.random.seed(0)

################################################################
# lowrank layer
################################################################
class LowRank1d(nn.Module):
    """Low-rank kernel integral layer on a 1D point set.

    Two coordinate networks, ``phi`` and ``psi``, map each point's input
    features ``a`` to rank-``rank`` factors of an (out_channels x in_channels)
    kernel. The layer applies the resulting separable kernel to ``v`` and
    averages over the ``n`` points.
    """

    def __init__(self, in_channels, out_channels, s, width, rank=1):
        super(LowRank1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.s = s
        self.n = s  # number of points the kernel sums over
        self.rank = rank

        # Both factor networks read the 2 per-point input features.
        # NOTE(review): the output size is built from `width`, while forward()
        # reshapes with in_channels/out_channels; they only agree when both
        # channel counts equal `width` (as MyNet uses them).
        self.phi = DenseNet([2, 64, 128, 256, width * width * rank], torch.nn.ReLU)
        self.psi = DenseNet([2, 64, 128, 256, width * width * rank], torch.nn.ReLU)

    def forward(self, v, a):
        """Apply the low-rank kernel.

        Args:
            v: features, (batch, n, in_channels).
            a: per-point inputs fed to phi/psi, (batch, n, 2).

        Returns:
            Tensor of shape (batch, n, out_channels).
        """
        factor_shape = (v.shape[0], self.n, self.out_channels, self.in_channels, self.rank)
        phi_eval = self.phi(a).reshape(factor_shape)
        psi_eval = self.psi(a).reshape(factor_shape)

        # Contract psi with v over points n and input channels, weight by phi
        # at every output point m, then average over the n points.
        return torch.einsum('bnoir,bni,bmoir->bmo', psi_eval, v, phi_eval) / self.n


class MyNet(torch.nn.Module):
    """Four low-rank kernel layers with pointwise linear skips and batch norm.

    The input has 2 features per point (fc0 is Linear(2, width)). The output
    is the squeezed result of the final Linear(128, 1) projection.
    """

    def __init__(self, s, width=32, rank=4):
        super(MyNet, self).__init__()
        self.s = s
        self.width = width
        self.rank = rank

        self.fc0 = nn.Linear(2, self.width)

        # Keep the original registration order (all kernels, then all skips,
        # then all norms) so initialization under the fixed seed and the
        # state_dict keys are unchanged.
        for k in range(1, 5):
            setattr(self, 'net%d' % k, LowRank1d(self.width, self.width, s, width, rank=self.rank))
        for k in range(1, 5):
            setattr(self, 'w%d' % k, nn.Linear(self.width, self.width))
        for k in range(1, 5):
            setattr(self, 'bn%d' % k, torch.nn.BatchNorm1d(self.width))

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, v):
        batch_size, n = v.shape[0], v.shape[1]
        a = v.clone()  # the raw inputs drive the kernel factors in every layer

        v = self.fc0(v)

        stages = (
            (self.net1, self.w1, self.bn1),
            (self.net2, self.w2, self.bn2),
            (self.net3, self.w3, self.bn3),
            (self.net4, self.w4, self.bn4),
        )
        last = len(stages) - 1
        for depth, (kernel, skip, norm) in enumerate(stages):
            v = kernel(v, a) + skip(v)
            # BatchNorm1d normalizes (N, C): fold batch and points together.
            v = norm(v.reshape(-1, self.width)).view(batch_size, n, self.width)
            if depth != last:  # no activation after the final kernel layer
                v = F.relu(v)

        v = F.relu(self.fc1(v))
        v = self.fc2(v)
        return v.squeeze()

    def count_params(self):
        """Return the total number of scalar parameters."""
        return sum(reduce(operator.mul, list(p.size())) for p in self.parameters())

################################################################
# configs
################################################################

ntrain = 1000
ntest = 200

sub = 1  # subsampling rate
h = 2**13 // sub  # resolution after subsampling (assumes 2**13 raw points)
s = h

batch_size = 5
learning_rate = 0.001


################################################################
# reading data and normalization
################################################################
dataloader = MatReader('data/burgers_data_R10.mat')
x_data = dataloader.read_field('a')[:, ::sub]
y_data = dataloader.read_field('u')[:, ::sub]

x_train = x_data[:ntrain, :]
y_train = y_data[:ntrain, :]
x_test = x_data[-ntest:, :]
y_test = y_data[-ntest:, :]

x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
x_test = x_normalizer.encode(x_test)

y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

# Append the spatial coordinate on [0, 2*pi] as the second per-point feature.
grid = np.linspace(0, 2*np.pi, s).reshape(1, s, 1)
grid = torch.tensor(grid, dtype=torch.float)
x_train = torch.cat([x_train.reshape(ntrain, s, 1), grid.repeat(ntrain, 1, 1)], dim=2)
x_test = torch.cat([x_test.reshape(ntest, s, 1), grid.repeat(ntest, 1, 1)], dim=2)


train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

model = MyNet(s).cuda()
print(model.count_params())

################################################################
# training and evaluation
################################################################

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# Stepped once per epoch below: halves the learning rate every 100 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
epochs = 500

myloss = LpLoss(size_average=False)
y_normalizer.cuda()
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        out = model(x).reshape(batch_size, s)

        # Normalized-space MSE is only logged; the relative L2 loss in
        # physical units is back-propagated.
        mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1), reduction='mean')
        # mse.backward()

        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)
        loss = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
        loss.backward()

        optimizer.step()
        train_mse += mse.item()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()

            out = model(x).reshape(batch_size, s)
            out = y_normalizer.decode(out)
            test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()

    train_mse /= len(train_loader)
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_mse, train_l2, test_l2)


# ---- lowrank_operators/lowrank_2d.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import h5py
import scipy.io
import matplotlib.pyplot as plt
from timeit import default_timer
import sys
import math

import operator
from functools import reduce

from timeit import default_timer
# The repository ships utilities3.py (there is no utilities.py); the sibling
# scripts import MatReader, DenseNet, LpLoss, ... from utilities3 as well.
from utilities3 import *

torch.manual_seed(0)
np.random.seed(0)



class LowRank2d(nn.Module):
    """Low-rank kernel integral layer on an s x s grid (n = s*s points).

    phi/psi map each point's 3 input features to rank-``rank`` factors of an
    (out_channels x in_channels) kernel; the kernel is applied to ``v`` and
    averaged over the n points.
    """

    def __init__(self, in_channels, out_channels, s, width, rank):
        super(LowRank2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.s = s
        self.n = s*s
        self.rank = rank

        self.phi = DenseNet([3, 64, 128, 256, width*width*rank], torch.nn.ReLU)
        self.psi = DenseNet([3, 64, 128, 256, width*width*rank], torch.nn.ReLU)


    def forward(self, v, a):
        # a (batch, n, 3)
        # v (batch, n, f)
        batch_size = v.shape[0]

        phi_eval = self.phi(a).reshape(batch_size, self.n, self.out_channels, self.in_channels, self.rank)
        psi_eval = self.psi(a).reshape(batch_size, self.n, self.out_channels, self.in_channels, self.rank)

        # print(psi_eval.shape, v.shape, phi_eval.shape)
        v = torch.einsum('bnoir,bni,bmoir->bmo', psi_eval, v, phi_eval) / self.n

        return v



class MyNet(torch.nn.Module):
    """Four LowRank2d layers with linear skips and batch norm; 3 input features per point."""

    def __init__(self, s, width=32, rank=1):
        super(MyNet, self).__init__()
        self.s = s
        self.width = width
        self.rank = rank

        self.fc0 = nn.Linear(3, self.width)

        self.net1 = LowRank2d(self.width, self.width, s, width, rank=self.rank)
        self.net2 = LowRank2d(self.width, self.width, s, width, rank=self.rank)
        self.net3 = LowRank2d(self.width, self.width, s, width, rank=self.rank)
        self.net4 = LowRank2d(self.width, self.width, s, width, rank=self.rank)
        self.w1 = nn.Linear(self.width, self.width)
        self.w2 = nn.Linear(self.width, self.width)
        self.w3 = nn.Linear(self.width, self.width)
        self.w4 = nn.Linear(self.width, self.width)

        self.bn1 = torch.nn.BatchNorm1d(self.width)
        self.bn2 = torch.nn.BatchNorm1d(self.width)
        self.bn3 = torch.nn.BatchNorm1d(self.width)
        self.bn4 = torch.nn.BatchNorm1d(self.width)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)


    def forward(self, v):
        batch_size, n = v.shape[0], v.shape[1]
        a = v.clone()  # raw inputs drive the kernel factors in every layer

        v = self.fc0(v)

        v1 = self.net1(v, a)
        v2 = self.w1(v)
        v = v1+v2
        v = self.bn1(v.reshape(-1, self.width)).view(batch_size,n,self.width)
        v = F.relu(v)

        v1 = self.net2(v, a)
        v2 = self.w2(v)
        v = v1+v2
        v = self.bn2(v.reshape(-1, self.width)).view(batch_size,n,self.width)
        v = F.relu(v)

        v1 = self.net3(v, a)
        v2 = self.w3(v)
        v = v1+v2
        v = self.bn3(v.reshape(-1, self.width)).view(batch_size,n,self.width)
        v = F.relu(v)

        v1 = self.net4(v, a)
        v2 = self.w4(v)
        v = v1+v2
        v = self.bn4(v.reshape(-1, self.width)).view(batch_size,n,self.width)


        v = self.fc1(v)
        v = F.relu(v)
        v = self.fc2(v)

        return v.squeeze()

    def count_params(self):
        """Return the total number of scalar parameters."""
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c




TRAIN_PATH = 'data/piececonst_r421_N1024_smooth1.mat'
TEST_PATH = 'data/piececonst_r421_N1024_smooth2.mat'

ntrain = 1000
ntest = 100

batch_size = 10

r = 5
h = int(((421 - 1)/r) + 1)
s = h

learning_rate = 0.00025

reader = MatReader(TRAIN_PATH)
x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s].reshape(ntrain,s*s)
y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s].reshape(ntrain,s*s)

reader.load_file(TEST_PATH)
x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s].reshape(ntest,s*s)
y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s].reshape(ntest,s*s)


x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
x_test = x_normalizer.encode(x_test)

y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

# Append the 2D coordinates on [0, 1]^2 as two extra per-point features.
grids = []
grids.append(np.linspace(0, 1, s))
grids.append(np.linspace(0, 1, s))
grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T
grid = grid.reshape(1,s*s,2)
grid = torch.tensor(grid, dtype=torch.float)
x_train = torch.cat([x_train.reshape(ntrain,s*s,1), grid.repeat(ntrain,1,1)], dim=2)
x_test = torch.cat([x_test.reshape(ntest,s*s,1), grid.repeat(ntest,1,1)], dim=2)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

model = MyNet(s).cuda()
# model = MyNet_old(s).cuda()

print(model.count_params())

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
epochs = 200

myloss = LpLoss(size_average=False)
y_normalizer.cuda()
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        out = model(x).reshape(batch_size, s*s)

        # This script trains on the normalized-space MSE; the relative L2 in
        # physical units is only logged.
        mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1), reduction='mean')
        mse.backward()

        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)
        loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1))
        # loss.backward()

        optimizer.step()
        train_mse += mse.item()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()

            out = model(x).reshape(batch_size, s*s)
            out = y_normalizer.decode(out)
            test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()

    train_mse /= len(train_loader)
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_mse, train_l2, test_l2)


# ---- lowrank_operators/lowrank_2d_time.py ----
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from utilities3 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
import scipy.io

torch.manual_seed(0)
np.random.seed(0)

activation = F.relu

################################################################
# lowrank layers
################################################################
class LowRank2d(nn.Module):
    """Low-rank kernel layer on an s x s grid whose factors come from the features.

    ``phi`` and ``psi`` map each point's current features ``v`` to rank-``rank``
    factors of an (out_channels x in_channels) kernel, which is then applied
    to ``v`` across all n = s*s points.
    """

    def __init__(self, in_channels, out_channels, s, ker_width, rank):
        super(LowRank2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.s = s
        self.n = s * s  # number of grid points
        self.rank = rank

        factor_size = in_channels * out_channels * rank
        self.phi = DenseNet([in_channels, ker_width, factor_size], torch.nn.ReLU)
        self.psi = DenseNet([in_channels, ker_width, factor_size], torch.nn.ReLU)

    def forward(self, v):
        """Map (batch, n, in_channels) features to (batch, n, out_channels)."""
        factor_shape = (v.shape[0], self.n, self.out_channels, self.in_channels, self.rank)
        phi_eval = self.phi(v).reshape(factor_shape)
        psi_eval = self.psi(v).reshape(factor_shape)

        # NOTE(review): unlike the lowrank_1d/lowrank_2d layers, this sum over
        # points is not divided by self.n -- confirm that is intended.
        return torch.einsum('bnoir,bni,bmoir->bmo', psi_eval, v, phi_eval)


class MyNet(torch.nn.Module):
    """Four LowRank2d layers with linear skips and batch norm.

    Takes (batch, size_x, size_y, 12) inputs (fc0 is Linear(12, width)) and
    returns (batch, size_x, size_y, 1).
    """

    def __init__(self, s, width=16, ker_width=256, rank=16):
        super(MyNet, self).__init__()
        self.s = s
        self.width = width
        self.ker_width = ker_width
        self.rank = rank

        self.fc0 = nn.Linear(12, self.width)

        # Same registration order as before (kernels, skips, norms) so seeded
        # initialization and state_dict keys are unchanged.
        for k in range(4):
            setattr(self, 'conv%d' % k, LowRank2d(width, width, s, ker_width, rank))
        for k in range(4):
            setattr(self, 'w%d' % k, nn.Linear(self.width, self.width))
        for k in range(4):
            setattr(self, 'bn%d' % k, torch.nn.BatchNorm1d(self.width))

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        batch_size = x.shape[0]
        size_x, size_y = x.shape[1], x.shape[2]
        n_points = size_x * size_y
        x = x.view(batch_size, n_points, -1)

        x = self.fc0(x)

        stages = (
            (self.conv0, self.w0, self.bn0),
            (self.conv1, self.w1, self.bn1),
            (self.conv2, self.w2, self.bn2),
            (self.conv3, self.w3, self.bn3),
        )
        for depth, (kernel, skip, norm) in enumerate(stages):
            x = kernel(x) + skip(x)
            # BatchNorm1d normalizes (N, C): fold batch and grid points together.
            x = norm(x.reshape(-1, self.width)).view(batch_size, n_points, self.width)
            if depth < len(stages) - 1:  # no activation after the last stage
                x = F.relu(x)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x.view(batch_size, size_x, size_y, -1)

class Net2d(nn.Module):
    """Thin wrapper around MyNet configured for a 64 x 64 grid."""

    def __init__(self, width=12, ker_width=128, rank=4):
        super(Net2d, self).__init__()

        self.conv1 = MyNet(s=64, width=width, ker_width=ker_width, rank=rank)

    def forward(self, x):
        return self.conv1(x)

    def count_params(self):
        """Return the total number of scalar parameters."""
        return sum(reduce(operator.mul, list(p.size())) for p in self.parameters())

################################################################
# configs
################################################################
# TRAIN_PATH = 'data/ns_data_V10000_N1200_T20.mat'
# TEST_PATH = 'data/ns_data_V10000_N1200_T20.mat'
# TRAIN_PATH = 'data/ns_data_V1000_N1000_train.mat'
# TEST_PATH = 'data/ns_data_V1000_N1000_train_2.mat'
# TRAIN_PATH = 'data/ns_data_V1000_N5000.mat'
# TEST_PATH = 'data/ns_data_V1000_N5000.mat'
TRAIN_PATH = 'data/ns_data_V100_N1000_T50_1.mat'
TEST_PATH = 'data/ns_data_V100_N1000_T50_2.mat'

ntrain = 1000
ntest = 200

batch_size = 5
batch_size2 = batch_size

epochs = 500
learning_rate = 0.0025
scheduler_step = 100
scheduler_gamma = 0.5

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

path = 'ns_lowrank_rnn_V100_T40_N'+str(ntrain)+'_ep' + str(epochs) + '_m'
path_model = 'model/'+path
path_train_err = 'results/'+path+'train.txt'
path_test_err = 'results/'+path+'test.txt'
path_image = 'image/'+path


runtime = np.zeros(2, )
t1 = default_timer()


sub = 1
S = 64
T_in = 10
T = 40
step = 1


################################################################
# load dataset
################################################################
reader = MatReader(TRAIN_PATH)
train_a = reader.read_field('u')[:ntrain,::sub,::sub,:T_in]
train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in]

reader = MatReader(TEST_PATH)
test_a = reader.read_field('u')[-ntest:,::sub,::sub,:T_in]
test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in]

print(train_u.shape)
print(test_u.shape)
assert (S == train_u.shape[-2])
assert (T == train_u.shape[-1])


train_a = train_a.reshape(ntrain,S,S,T_in)
test_a = test_a.reshape(ntest,S,S,T_in)

# Two coordinate channels on [0, 1], appended after the T_in time channels
# (T_in + 2 = 12 input channels in total).
gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridx = gridx.reshape(1, S, 1, 1).repeat([1, 1, S, 1])
gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridy = gridy.reshape(1, 1, S, 1).repeat([1, S, 1, 1])

train_a = torch.cat((train_a, gridx.repeat([ntrain,1,1,1]), gridy.repeat([ntrain,1,1,1])), dim=-1)
test_a = torch.cat((test_a, gridx.repeat([ntest,1,1,1]), gridy.repeat([ntest,1,1,1])), dim=-1)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

################################################################
# training and evaluation
################################################################
model = Net2d().cuda()
# model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20')

print(model.count_params())
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)


myloss = LpLoss(size_average=False)
gridx = gridx.to(device)
gridy = gridy.to(device)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2_step = 0
    train_l2_full = 0
    for xx, yy in train_loader:
        loss = 0
        xx = xx.to(device)
        yy = yy.to(device)
        # Real batch size, so a short final batch cannot break reshape/repeat.
        bs = xx.shape[0]

        # Autoregressive rollout: predict the next frame, then slide the input
        # window (drop the oldest `step` frames and the 2 grid channels, append
        # the prediction, and re-append the grid).
        for t in range(0, T, step):
            y = yy[..., t:t + step]
            im = model(xx)
            loss += myloss(im.reshape(bs, -1), y.reshape(bs, -1))

            if t == 0:
                pred = im
            else:
                pred = torch.cat((pred, im), -1)

            xx = torch.cat((xx[..., step:-2], im,
                            gridx.repeat([bs, 1, 1, 1]), gridy.repeat([bs, 1, 1, 1])), dim=-1)

        train_l2_step += loss.item()
        l2_full = myloss(pred.reshape(bs, -1), yy.reshape(bs, -1))
        train_l2_full += l2_full.item()

        optimizer.zero_grad()
        loss.backward()
        # l2_full.backward()
        optimizer.step()

    # The model contains BatchNorm1d layers: without eval() the test pass would
    # normalize with test-batch statistics and update the running estimates
    # with test data.
    model.eval()
    test_l2_step = 0
    test_l2_full = 0
    with torch.no_grad():
        for xx, yy in test_loader:
            loss = 0
            xx = xx.to(device)
            yy = yy.to(device)
            bs = xx.shape[0]

            for t in range(0, T, step):
                y = yy[..., t:t + step]
                im = model(xx)
                loss += myloss(im.reshape(bs, -1), y.reshape(bs, -1))

                if t == 0:
                    pred = im
                else:
                    pred = torch.cat((pred, im), -1)

                xx = torch.cat((xx[..., step:-2], im,
                                gridx.repeat([bs, 1, 1, 1]), gridy.repeat([bs, 1, 1, 1])), dim=-1)


            test_l2_step += loss.item()
            test_l2_full += myloss(pred.reshape(bs, -1), yy.reshape(bs, -1)).item()

    t2 = default_timer()
    scheduler.step()
    print(ep, t2 - t1, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step),
          test_l2_full / ntest)
# torch.save(model, path_model)


# NOTE(review): the disabled block below calls y_normalizer, which this script
# never defines; it needs adapting before it can be re-enabled.
# pred = torch.zeros(test_u.shape)
# index = 0
# test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)
# with torch.no_grad():
#     for x, y in test_loader:
#         test_l2 = 0;
#         x, y = x.cuda(), y.cuda()
#
#         out = model(x)
#         out = y_normalizer.decode(out)
#         pred[index] = out
#
#         test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item()
#         print(index, test_l2)
#         index = index + 1
# scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()})


# --------------------------------------------------------------------------------
# /lowrank_operators/lowrank_3d.py:
# --------------------------------------------------------------------------------
# Low-rank neural operator on a flattened (x, y, t) grid: maps the first T_in
# time steps of the field 'u' (plus x/y/t coordinate channels) to the next T
# time steps in a single forward pass.
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from utilities3 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
import scipy.io

# Fixed seeds. Parameter initialisation below draws from the seeded RNG, so the
# order in which modules are constructed affects the initial weights.
torch.manual_seed(0)
np.random.seed(0)


################################################################
# 3d lowrank layers
################################################################

class LowRank2d(nn.Module):
    """Low-rank kernel integral layer over a flattened set of ``n`` points.

    Despite its name, it is applied here to a 3D (x, y, t) grid flattened to
    ``n`` points. Two point-wise networks ``phi`` and ``psi`` each produce an
    (out_channels, in_channels, rank) tensor per point, and the layer computes

        out[b, m, o] = sum_{n, i, r} psi[b, n, o, i, r] * v[b, n, i] * phi[b, m, o, i, r]

    i.e. a rank-``rank`` kernel applied to ``v``. The sum over the ``n`` points
    is not divided by ``n``.
    """

    def __init__(self, in_channels, out_channels, n, ker_width, rank):
        super(LowRank2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of points; forward() reshapes with it, so every input must
        # contain exactly n points.
        self.n = n
        self.rank = rank

        # DenseNet comes from utilities3 (not visible here); presumably an MLP
        # with these layer sizes applied along the last axis -- confirm there.
        self.phi = DenseNet([in_channels, ker_width, in_channels*out_channels*rank], torch.nn.ReLU)
        self.psi = DenseNet([in_channels, ker_width, in_channels*out_channels*rank], torch.nn.ReLU)


    def forward(self, v):
        """Apply the low-rank integral operator.

        Args:
            v: tensor of shape (batch, n, in_channels) (the einsum fixes this layout).

        Returns:
            Tensor of shape (batch, n, out_channels).
        """
        batch_size = v.shape[0]

        phi_eval = self.phi(v).reshape(batch_size, self.n, self.out_channels, self.in_channels, self.rank)
        psi_eval = self.psi(v).reshape(batch_size, self.n, self.out_channels, self.in_channels, self.rank)

        # print(psi_eval.shape, v.shape, phi_eval.shape)
        # Contract over source points (n), input channels (i) and rank (r);
        # target points (m) index phi_eval.
        v = torch.einsum('bnoir,bni,bmoir->bmo',psi_eval, v, phi_eval)

        return v



class MyNet(torch.nn.Module):
    """Four LowRank2d layers with linear skip paths, batch norm and ReLU.

    Input: (batch, size_x, size_y, size_z, 13), i.e. 13 features per point.
    Output: (batch, size_x, size_y, size_z).
    ``n`` must equal size_x * size_y * size_z, because each LowRank2d reshapes
    with it.
    """

    def __init__(self, n, width=16, ker_width=256, rank=16):
        super(MyNet, self).__init__()
        self.n = n
        self.width = width
        self.ker_width = ker_width
        self.rank = rank

        # 13 input features = x, y, t coordinates + T_in (=10) input time steps,
        # as assembled by the data preparation further down in this file.
        self.fc0 = nn.Linear(13, self.width)

        self.conv0 = LowRank2d(width, width, n, ker_width, rank)
        self.conv1 = LowRank2d(width, width, n, ker_width, rank)
        self.conv2 = LowRank2d(width, width, n, ker_width, rank)
        self.conv3 = LowRank2d(width, width, n, ker_width, rank)

        # Point-wise linear "skip" paths added to each integral layer.
        self.w0 = nn.Linear(self.width, self.width)
        self.w1 = nn.Linear(self.width, self.width)
        self.w2 = nn.Linear(self.width, self.width)
        self.w3 = nn.Linear(self.width, self.width)
        self.bn0 = torch.nn.BatchNorm1d(self.width)
        self.bn1 = torch.nn.BatchNorm1d(self.width)
        self.bn2 = torch.nn.BatchNorm1d(self.width)
        self.bn3 = torch.nn.BatchNorm1d(self.width)

        # Projection from the latent width back to one value per point.
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)


    def forward(self, x):
        """Map (batch, X, Y, Z, 13) inputs to (batch, X, Y, Z) predictions."""
        batch_size = x.shape[0]
        size_x, size_y, size_z = x.shape[1], x.shape[2], x.shape[3]
        # Flatten the grid into a point set of size X*Y*Z.
        x = x.view(batch_size, size_x*size_y*size_z, -1)

        x = self.fc0(x)

        # Each layer: integral operator + linear skip, then BatchNorm1d over
        # all (batch * points) rows per channel, then ReLU (none after layer 3).
        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = self.bn0(x.reshape(-1, self.width)).view(batch_size, size_x*size_y*size_z, self.width)
        x = F.relu(x)
        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = self.bn1(x.reshape(-1, self.width)).view(batch_size, size_x*size_y*size_z, self.width)
        x = F.relu(x)
        x1 = self.conv2(x)
        x2 = self.w2(x)
        x = x1 + x2
        x = self.bn2(x.reshape(-1, self.width)).view(batch_size, size_x*size_y*size_z, self.width)
        x = F.relu(x)
        x1 = self.conv3(x)
        x2 = self.w3(x)
        x = x1 + x2
        x = self.bn3(x.reshape(-1, self.width)).view(batch_size, size_x*size_y*size_z, self.width)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = x.view(batch_size, size_x, size_y, size_z)
        return x

class Net2d(nn.Module):
    """Thin wrapper around MyNet for a fixed 64 x 64 x 40 grid."""

    def __init__(self, width=8, ker_width=128, rank=4):
        super(Net2d, self).__init__()

        # n = S * S * T for the configuration below (S = 64, T = 40).
        self.conv1 = MyNet(n=64*64*40, width=width, ker_width=ker_width, rank=rank)


    def forward(self, x):
        x = self.conv1(x)
        return x


    def count_params(self):
        """Return the total number of scalar entries over all parameters."""
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c

################################################################
# configs
################################################################
# TRAIN_PATH = 'data/ns_data_V10000_N1200_T20.mat'
# TEST_PATH = 'data/ns_data_V10000_N1200_T20.mat'
# TRAIN_PATH = 'data/ns_data_V1000_N1000_train.mat'
# TEST_PATH = 'data/ns_data_V1000_N1000_train_2.mat'
# TRAIN_PATH = 'data/ns_data_V1000_N5000.mat'
# TEST_PATH = 'data/ns_data_V1000_N5000.mat'
TRAIN_PATH = 'data/ns_data_V100_N1000_T50_1.mat'
TEST_PATH = 'data/ns_data_V100_N1000_T50_2.mat'

ntrain = 1000
ntest = 200

batch_size = 2
batch_size2 = batch_size

epochs = 500
learning_rate = 0.0025
scheduler_step = 100
scheduler_gamma = 0.5

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

path = 'ns_lowrank_V100_T40_N'+str(ntrain)+'_ep' + str(epochs)
path_model = 'model/'+path
path_train_err = 'results/'+path+'train.txt'
path_test_err = 'results/'+path+'test.txt'
path_image = 'image/'+path

runtime = np.zeros(2, )
t1 = default_timer()


sub = 1     # spatial sub-sampling stride
S = 64      # spatial resolution after sub-sampling
T_in = 10   # number of input time steps
T = 40      # number of predicted time steps

################################################################
# load data
################################################################

# Inputs are the first T_in steps of 'u'; targets are the following T steps.
# Training uses the first ntrain samples, testing the last ntest samples.
reader = MatReader(TRAIN_PATH)
train_a = reader.read_field('u')[:ntrain,::sub,::sub,:T_in]
train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in]

reader = MatReader(TEST_PATH)
test_a = reader.read_field('u')[-ntest:,::sub,::sub,:T_in]
test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in]

print(train_u.shape)
print(test_u.shape)
assert (S == train_u.shape[-2])
assert (T == train_u.shape[-1])


# Normalisers are fitted on training data only. Test targets stay
# un-normalised; predictions are decoded before comparing against them.
a_normalizer = UnitGaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)

y_normalizer = UnitGaussianNormalizer(train_u)
train_u = y_normalizer.encode(train_u)

# Broadcast the T_in input frames to every one of the T output times.
train_a = train_a.reshape(ntrain,S,S,1,T_in).repeat([1,1,1,T,1])
test_a = test_a.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1])

# Coordinate channels on [0, 1]; t uses the T points of linspace(0, 1, T+1)
# after dropping 0.
gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1])
gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1])
gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float)
gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1])

# Final inputs: (N, S, S, T, 3 + T_in) = 13 channels, matching MyNet.fc0.
train_a = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]),
                       gridt.repeat([ntrain,1,1,1,1]), train_a), dim=-1)
test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]),
                       gridt.repeat([ntest,1,1,1,1]), test_a), dim=-1)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')


################################################################
# training and evaluation
################################################################
model = Net2d().cuda()
# model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20')
# Optimiser, step LR schedule and loss.
print(model.count_params())
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)


# LpLoss comes from utilities3 (not visible here); presumably a relative Lp
# error summed over the batch (size_average=False) -- confirm there.
myloss = LpLoss(size_average=False)
y_normalizer.cuda()
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        out = model(x)

        # MSE on normalised values is only logged, not back-propagated.
        mse = F.mse_loss(out, y, reduction='mean')
        # mse.backward()

        # The optimised loss is computed on decoded (un-normalised) fields.
        y = y_normalizer.decode(y)
        out = y_normalizer.decode(out)
        l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
        l2.backward()

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()

            # test_u was never encoded, so only the prediction is decoded.
            out = model(x)
            out = y_normalizer.decode(out)
            test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()

    train_mse /= len(train_loader)
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_mse, train_l2, test_l2)
# torch.save(model, path_model)


# Per-sample evaluation on the test set (batch size 1); predictions are
# collected in `pred` on the CPU.
pred = torch.zeros(test_u.shape)
index = 0
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)
with torch.no_grad():
    for x, y in test_loader:
        test_l2 = 0
        x, y = x.cuda(), y.cuda()

        out = model(x)
        out = y_normalizer.decode(out)
        pred[index] = out

        test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item()
        print(index, test_l2)
        index = index + 1

# scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()})




# --------------------------------------------------------------------------------
# /scripts/eval.py:
# --------------------------------------------------------------------------------
# Evaluate a saved 3D (x, y, t) Fourier neural operator on a test set and save
# the predictions and ground truth to 'pred/eval.mat'.
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from utilities3 import *
import operator
from functools import reduce
from functools import partial

from timeit import default_timer
import scipy.io

torch.manual_seed(0)
np.random.seed(0)

def compl_mul3d(a, b):
    """Complex channel contraction on tensors storing (real, imag) in the last dim.

    (batch, in_channel, x, y, t, 2), (in_channel, out_channel, x, y, t, 2)
        -> (batch, out_channel, x, y, t, 2)

    Implements (ar + i*ai)(br + i*bi) = (ar*br - ai*bi) + i*(ai*br + ar*bi),
    summed over the input channel.
    """
    op = partial(torch.einsum, "bixyz,ioxyz->boxyz")
    return torch.stack([
        op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]),
        op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1])
    ], dim=-1)

class SpectralConv3d_fast(nn.Module):
    """3D spectral convolution using the legacy real-FFT API.

    Keeps the first ``modes3`` frequencies along the last axis and both the
    first and last ``modes1`` / ``modes2`` entries along the other two axes
    (four corner blocks, one learned weight tensor each).

    NOTE: torch.rfft / torch.irfft were removed in PyTorch 1.8, so this module
    needs an older PyTorch. Migrating to torch.fft would require complex
    weights, which changes the parameter layout of saved models (the weights
    here are real tensors with a trailing (real, imag) dimension of size 2).
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super(SpectralConv3d_fast, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2))
        self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2))
        self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2))

    def forward(self, x):
        """Apply the layer to x of shape (batch, in_channels, X, Y, Z).

        Returns a real tensor of shape (batch, out_channels, X, Y, Z).
        """
        batchsize = x.shape[0]
        # Normalised one-sided real FFT over the last three dims; complex
        # values are stored as a trailing dimension of size 2 (real, imag).
        x_ft = torch.rfft(x, 3, normalized=True, onesided=True)

        # Multiply relevant Fourier modes. compl_mul3d maps in_channels ->
        # out_channels, so the output buffer needs out_channels channels
        # (previously sized with in_channels, which only worked when equal).
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, 2, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \
            compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \
            compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2)
        out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \
            compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3)
        out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \
            compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4)

        #Return to physical space
        x = torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=(x.size(-3), x.size(-2), x.size(-1)))
        return x

class SimpleBlock2d(nn.Module):
    """Four spectral + point-wise (1x1 Conv1d) layers with BatchNorm3d.

    Input: (batch, X, Y, Z, 13); output: (batch, X, Y, Z, 1).
    """

    def __init__(self, modes1, modes2, modes3, width):
        super(SimpleBlock2d, self).__init__()

        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        # 13 input features = x, y, t coordinates + T_in (=10) input frames,
        # as assembled by the data preparation below.
        self.fc0 = nn.Linear(13, self.width)

        self.conv0 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv1 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv2 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv3 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)
        self.bn0 = torch.nn.BatchNorm3d(self.width)
        self.bn1 = torch.nn.BatchNorm3d(self.width)
        self.bn2 = torch.nn.BatchNorm3d(self.width)
        self.bn3 = torch.nn.BatchNorm3d(self.width)


        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        batchsize = x.shape[0]
        size_x, size_y, size_z = x.shape[1], x.shape[2], x.shape[3]

        x = self.fc0(x)
        # Channels-first for the spectral / Conv1d layers.
        x = x.permute(0, 4, 1, 2, 3)

        x1 = self.conv0(x)
        x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z)
        x = self.bn0(x1 + x2)
        x = F.relu(x)
        x1 = self.conv1(x)
        x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z)
        x = self.bn1(x1 + x2)
        x = F.relu(x)
        x1 = self.conv2(x)
        x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z)
        x = self.bn2(x1 + x2)
        x = F.relu(x)
        x1 = self.conv3(x)
        x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z)
        x = self.bn3(x1 + x2)

        x = x.permute(0, 2, 3, 4, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

class Net2d(nn.Module):
    """Wrapper around SimpleBlock2d with modes3 fixed to 6; squeezes size-1 dims."""

    def __init__(self, modes, width):
        super(Net2d, self).__init__()
        self.conv1 = SimpleBlock2d(modes, modes, 6, width)

    def forward(self, x):
        x = self.conv1(x)
        # With batch size 1 this also drops the batch dim: (X, Y, Z).
        return x.squeeze()

    def count_params(self):
        """Return the total number of scalar entries over all parameters."""
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c


t1 = default_timer()

TEST_PATH = 'data/ns_data_V1e-4_N20_T50_R256test.mat'


ntest = 20

sub = 4     # spatial sub-sampling stride
sub_t = 4   # temporal sub-sampling stride of the targets
S = 64
T_in = 10
T = 20

indent = 3  # index of the first input frame

# load data
# Inputs: every 4th frame from `indent` (frames 3, 7, ..., 39 -> T_in frames).
# Targets: every sub_t-th frame from indent + 4*T_in (frames 43, 47, ..., 119
# for sub_t = 4 -> T frames).
reader = MatReader(TEST_PATH)
test_a = reader.read_field('u')[:,::sub,::sub, indent:T_in*4:4] #([0, T_in])
test_u = reader.read_field('u')[:,::sub,::sub, indent+T_in*4:indent+(T+T_in)*4:sub_t] #([T_in, T_in + T])

print(test_a.shape, test_u.shape)

# pad the location information (s,t)
# Rescale S and T for the chosen sub-sampling (no-op for sub = sub_t = 4).
S = S * (4//sub)
T = T * (4//sub_t)

test_a = test_a.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1])

gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1])
gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1])
gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float)
gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1])

test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]),
                       gridt.repeat([ntest,1,1,1,1]), test_a), dim=-1)

t2 = default_timer()
print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

# load model
# torch.load unpickles a whole nn.Module: its classes are resolved by module
# and name, so the definitions above must match those used when it was saved.
# Pickle can execute arbitrary code -- only load trusted checkpoints.
model = torch.load('model/ns_fourier_V1e-4_T20_N9800_ep200_m12_w32')

print(model.count_params())

# test
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)
myloss = LpLoss(size_average=False)
pred = torch.zeros(test_u.shape)
index = 0
with torch.no_grad():
    test_l2 = 0
    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()

        out = model(x)
        pred[index] = out
        loss = myloss(out.view(1, -1), y.view(1, -1)).item()
        test_l2 += loss
        print(index, loss)
        index = index + 1
print(test_l2/ntest)

path = 'eval'
scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy(), 'u': test_u.cpu().numpy()})






# --------------------------------------------------------------------------------
# /scripts/fourier_2d_tuned.py:
# --------------------------------------------------------------------------------
# Fourier neural operator for one-step-ahead prediction of a 2-channel field
# (out_dim = 2) on a 64 x 64 grid, trained with an Hs loss.
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from utilities3 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
import scipy.io

# Fixed seeds; module construction order below determines initial weights.
torch.manual_seed(0)
np.random.seed(0)


def compl_mul2d(a, b):
    """Contract complex Fourier coefficients over the input-channel axis.

    (batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y)
    """
    return torch.einsum("bixy,ioxy->boxy", a, b)

class SpectralConv2d(nn.Module):
    """2D spectral convolution with an optional (square) output size.

    Multiplies the first and the last ``modes1`` entries along dim 2 and the
    first ``modes2`` entries along dim 3 of the input's real FFT by learned
    complex weights, then inverse-transforms to a ``size`` x ``size`` grid.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def forward(self, x, size=None):
        """Apply the layer.

        Args:
            x: real tensor of shape (batch, in_channels, X, Y).
            size: spatial size of the (square) output; defaults to the
                input's last dimension.

        Returns:
            Real tensor of shape (batch, out_channels, size, size).
        """
        if size is None:
            size = x.size(-1)

        batchsize = x.shape[0]
        # Real FFT over the two spatial dims (dims 2 and 3).
        x_ft = torch.fft.rfftn(x, dim=[2,3])

        # Multiply relevant Fourier modes; the low negative frequencies along
        # dim 2 are written to the end of the `size`-long axis.
        out_ft = torch.zeros(batchsize, self.out_channels, size, size//2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)


        #Return to physical space
        x = torch.fft.irfftn(out_ft, s=(size, size), dim=[2,3])
        return x

class SimpleBlock2d(nn.Module):
    """Four Fourier layers with growing widths and shrinking mode counts.

    Before each Fourier layer the 2 grid-coordinate channels are concatenated
    again. Layer k keeps (4 - k)/4 of ``modes1`` / ``modes2`` (k = 0..3).

    NOTE(review): ``size_list`` is hard-coded to 64 and used in the Conv1d
    ``view`` calls, so forward() only works for 64 x 64 inputs.
    """

    def __init__(self, in_dim, out_dim, modes1, modes2, width):
        super(SimpleBlock2d, self).__init__()

        self.modes1 = modes1
        self.modes2 = modes2

        # Channel widths at the input of layer 0 and the outputs of layers 0-3.
        self.width_list = [width*2//4, width*3//4, width*4//4, width*4//4, width*5//4]
        # Spatial resolution at each stage (all 64 -> no resampling).
        self.size_list = [64,] * 5
        self.grid_dim = 2

        self.fc0 = nn.Linear(in_dim+self.grid_dim, self.width_list[0])

        self.conv0 = SpectralConv2d(self.width_list[0]+self.grid_dim, self.width_list[1], self.modes1*4//4, self.modes2*4//4)
        self.conv1 = SpectralConv2d(self.width_list[1]+self.grid_dim, self.width_list[2], self.modes1*3//4, self.modes2*3//4)
        self.conv2 = SpectralConv2d(self.width_list[2]+self.grid_dim, self.width_list[3], self.modes1*2//4, self.modes2*2//4)
        self.conv3 = SpectralConv2d(self.width_list[3]+self.grid_dim, self.width_list[4], self.modes1*1//4, self.modes2*1//4)
        self.w0 = nn.Conv1d(self.width_list[0]+self.grid_dim, self.width_list[1], 1)
        self.w1 = nn.Conv1d(self.width_list[1]+self.grid_dim, self.width_list[2], 1)
        self.w2 = nn.Conv1d(self.width_list[2]+self.grid_dim, self.width_list[3], 1)
        self.w3 = nn.Conv1d(self.width_list[3]+self.grid_dim, self.width_list[4], 1)

        self.fc1 = nn.Linear(self.width_list[4], self.width_list[4]*2)
        self.fc2 = nn.Linear(self.width_list[4]*2, self.width_list[4]*2)
        self.fc3 = nn.Linear(self.width_list[4]*2, out_dim)

    def forward(self, x):
        """Map (batch, X, Y, in_dim) inputs to (batch, X, Y, out_dim)."""

        batchsize = x.shape[0]
        size_x, size_y= x.shape[1], x.shape[2]
        grid = self.get_grid(size_x, batchsize, x.device)
        size_list = self.size_list

        # Append the grid channels-last before lifting with fc0.
        x = torch.cat((x, grid.permute(0, 2, 3, 1).repeat([1,1,1,1])), dim=-1)

        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)

        x = torch.cat((x, grid), dim=1)
        x1 = self.conv0(x, size_list[1])
        x2 = self.w0(x.view(batchsize, self.width_list[0]+self.grid_dim, size_list[0]**2)).view(batchsize, self.width_list[1], size_list[0], size_list[0])
        # x2 = F.interpolate(x2, size=size_list[1], mode='trilinear')
        x = x1 + x2
        x = F.selu(x)

        x = torch.cat((x, grid), dim=1)
        x1 = self.conv1(x, size_list[2])
        x2 = self.w1(x.view(batchsize, self.width_list[1]+self.grid_dim, size_list[1]**2)).view(batchsize, self.width_list[2], size_list[1], size_list[1])
        # x2 = F.interpolate(x2, size=size_list[2], mode='trilinear')
        x = x1 + x2
        x = F.selu(x)

        x = torch.cat((x, grid), dim=1)
        x1 = self.conv2(x, size_list[3])
        x2 = self.w2(x.view(batchsize, self.width_list[2]+self.grid_dim, size_list[2]**2)).view(batchsize, self.width_list[3], size_list[2], size_list[2])
        # x2 = F.interpolate(x2, size=size_list[3], mode='trilinear')
        x = x1 + x2
        x = F.selu(x)

        x = torch.cat((x, grid), dim=1)
        x1 = self.conv3(x, size_list[4])
        x2 = self.w3(x.view(batchsize, self.width_list[3]+self.grid_dim, size_list[3]**2)).view(batchsize, self.width_list[4], size_list[3], size_list[3])
        # x2 = F.interpolate(x2, size=size_list[4], mode='trilinear')
        x = x1 + x2

        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = F.selu(x)
        x = self.fc2(x)
        x = F.selu(x)
        x = self.fc3(x)
        return x

    def get_grid(self, S, batchsize, device):
        """Return (batchsize, 2, S, S) coordinates on [0, 1].

        Channel 0 varies along dim 2, channel 1 along dim 3.
        """
        gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
        gridx = gridx.reshape(1, 1, S, 1).repeat([batchsize, 1, 1, S])
        gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
        gridy = gridy.reshape(1, 1, 1, S).repeat([batchsize, 1, S, 1])
        return torch.cat((gridx, gridy), dim=1).to(device)

class Net2d(nn.Module):
    """Wrapper around SimpleBlock2d (same mode count on both axes)."""

    def __init__(self, in_dim, out_dim, modes, width):
        super(Net2d, self).__init__()
        self.conv1 = SimpleBlock2d(in_dim, out_dim, modes, modes, width)

    def forward(self, x):
        x = self.conv1(x)
        return x.squeeze()

    def count_params(self):
        """Return the total number of scalar entries over all parameters."""
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c

# TRAIN_PATH = 'data/ns_data_V100_N400_T200_0.mat'
# TEST_PATH = 'data/ns_data_V100_N400_T200_0.mat'
# TRAIN_PATH = 'data/ns_data_V1000_N200_T400_0.mat'
# TEST_PATH = 'data/ns_data_V1000_N200_T400_0.mat'


ntrain = 20
ntest = 5

modes = 20
width = 32

# in_dim = 2 grid channels (added in the script below) + 2 field channels.
in_dim = 4
out_dim = 2

batch_size = 5
batch_size2 = 5


epochs = 100
learning_rate = 0.0025
scheduler_step = 20
scheduler_gamma = 0.5

loss_k = 1
loss_group = False

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

path = 'KF_vel_N'+str(ntrain)+'_k' + str(loss_k)+'_g' + str(loss_group)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width)
path_model = 'model/'+path
path_train_err = 'results/'+path+'train.txt'
path_test_err = 'results/'+path+'test.txt'
path_image = 'image/'+path


runtime = np.zeros(2, )
t1 = default_timer()


sub = 4     # spatial sub-sampling stride
S = 64

T_in = 100  # first target time index
T = 400     # number of one-step training pairs per sample
T_out = T_in+T
step = 1



data = np.load('data/KFvelocity_Re40_N25_part1.npy')
data = torch.tensor(data, dtype=torch.float)
print(data.shape)

# Assumes data is laid out as (sample, time, x, y, channel) -- inferred from
# the indexing below; confirm against the data file. Inputs are the frames at
# t-1 and targets the frames at t, permuted to (sample, x, y, time, channel).
train_a = data[:ntrain,T_in-1:T_out-1,::sub,::sub,:].permute(0,2,3,1,4)
train_u = data[:ntrain,T_in:T_out,::sub,::sub,:].permute(0,2,3,1,4)

test_a = data[-ntest:,T_in-1:T_out-1,::sub,::sub,:].permute(0,2,3,1,4)
test_u = data[-ntest:,T_in:T_out,::sub,::sub,:].permute(0,2,3,1,4)

217 | print(train_a.shape) 218 | print(train_u.shape) 219 | assert (S == train_u.shape[2]) 220 | assert (T == train_u.shape[3]) 221 | 222 | 223 | 224 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 225 | gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, 1, 1]) 226 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 227 | gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, 1, 1]) 228 | 229 | train_a = torch.cat((gridx.repeat([ntrain,1,1,T,1]), gridy.repeat([ntrain,1,1,T,1]), train_a), dim=-1) 230 | test_a = torch.cat((gridx.repeat([ntest,1,1,T,1]), gridy.repeat([ntest,1,1,T,1]), test_a), dim=-1) 231 | 232 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 233 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 234 | 235 | t2 = default_timer() 236 | 237 | print('preprocessing finished, time used:', t2-t1) 238 | device = torch.device('cuda') 239 | 240 | model = Net2d(in_dim, out_dim, modes, width).cuda() 241 | # model = torch.load('model/KF_vel_N20_ep200_m12_w32') 242 | # model = torch.load('model/KF_vol_500_N20_ep200_m12_w32') 243 | 244 | 245 | print(model.count_params()) 246 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 247 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 248 | 249 | 250 | lploss = LpLoss(size_average=False) 251 | hsloss = HsLoss(k=loss_k, group=loss_group, size_average=False) 252 | 253 | gridx = gridx.to(device) 254 | gridy = gridy.to(device) 255 | 256 | for ep in range(epochs): 257 | model.train() 258 | t1 = default_timer() 259 | train_l2 = 0 260 | for xx, yy in train_loader: 261 | xx = xx.to(device) 262 | yy = yy.to(device) 263 | 264 | for t in range(0, T): 265 | x = xx[:,:,:,t,:] 266 | y = yy[:,:,:,t] 267 | 268 | out = model(x) 269 | loss = 
hsloss(out.reshape(batch_size, S, S, out_dim), y.reshape(batch_size, S, S, out_dim)) 270 | train_l2 += loss.item() 271 | 272 | optimizer.zero_grad() 273 | loss.backward() 274 | optimizer.step() 275 | 276 | test_l2 = 0 277 | test_l2_hp = 0 278 | with torch.no_grad(): 279 | for xx, yy in test_loader: 280 | xx = xx.to(device) 281 | yy = yy.to(device) 282 | 283 | for t in range(0, T): 284 | x = xx[:, :, :, t, :] 285 | y = yy[:, :, :, t] 286 | 287 | out = model(x) 288 | test_l2 += lploss(out.reshape(batch_size, S, S, out_dim), y.reshape(batch_size, S, S, out_dim)).item() 289 | test_l2_hp += hsloss(out.reshape(batch_size, S, S, out_dim), y.reshape(batch_size, S, S, out_dim)).item() 290 | 291 | 292 | t2 = default_timer() 293 | scheduler.step() 294 | print(ep, t2 - t1, train_l2/ntrain/T, test_l2_hp/ntest/T, test_l2/ntest/T) 295 | # print(ep, t2 - t1, test_l2/ntest/T) 296 | 297 | 298 | # torch.save(model, path_model) 299 | # 300 | # 301 | # model.eval() 302 | # 303 | # test_a = test_a[0,:,:,0,-2:] 304 | # 305 | # T = 1000 - T_in 306 | # pred = torch.zeros(S,S,T,2) 307 | # gridx = gridx.reshape(1,S,S,1) 308 | # gridy = gridy.reshape(1,S,S,1) 309 | # x_out = test_a.reshape(1,S,S,2).cuda() 310 | # with torch.no_grad(): 311 | # for i in range(T): 312 | # print(i) 313 | # x_in = torch.cat([gridx, gridy, x_out], dim=-1) 314 | # x_out = model(x_in) 315 | # pred[:,:,i] = x_out.view(S,S,2) 316 | # 317 | # 318 | # scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()}) 319 | 320 | 321 | 322 | 323 | -------------------------------------------------------------------------------- /scripts/fourier_3d_time.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import matplotlib.pyplot as plt 7 | from utilities3 import * 8 | 9 | import operator 10 | from functools import reduce 11 | from functools import partial 12 | 13 | from timeit 
import default_timer 14 | import scipy.io 15 | 16 | torch.manual_seed(0) 17 | np.random.seed(0) 18 | 19 | ################################################################ 20 | # fourier layers 21 | ################################################################ 22 | 23 | def compl_mul3d(a, b): 24 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 25 | op = partial(torch.einsum, "bixyz,ioxyz->boxyz") 26 | return torch.stack([ 27 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 28 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 29 | ], dim=-1) 30 | 31 | class SpectralConv3d_fast(nn.Module): 32 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3): 33 | super(SpectralConv3d_fast, self).__init__() 34 | self.in_channels = in_channels 35 | self.out_channels = out_channels 36 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 37 | self.modes2 = modes2 38 | self.modes3 = modes3 39 | 40 | self.scale = (1 / (in_channels * out_channels)) 41 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 42 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 43 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 44 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)) 45 | 46 | def forward(self, x): 47 | batchsize = x.shape[0] 48 | #Compute Fourier coeffcients up to factor of e^(- something constant) 49 | x_ft = torch.rfft(x, 3, normalized=True, onesided=True) 50 | 51 | # Multiply relevant Fourier modes 52 | out_ft = torch.zeros(batchsize, self.in_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, 2, device=x.device) 53 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \ 54 | 
compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1) 55 | out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \ 56 | compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2) 57 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \ 58 | compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3) 59 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \ 60 | compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4) 61 | 62 | #Return to physical space 63 | x = torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=(x.size(-3), x.size(-2), x.size(-1))) 64 | return x 65 | 66 | class SimpleBlock2d(nn.Module): 67 | def __init__(self, modes1, modes2, modes3, width): 68 | super(SimpleBlock2d, self).__init__() 69 | 70 | self.modes1 = modes1 71 | self.modes2 = modes2 72 | self.modes3 = modes3 73 | self.width = width 74 | self.fc0 = nn.Linear(4, self.width) 75 | 76 | self.conv0 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 77 | self.conv1 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 78 | self.conv2 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 79 | self.conv3 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3) 80 | self.w0 = nn.Conv1d(self.width, self.width, 1) 81 | self.w1 = nn.Conv1d(self.width, self.width, 1) 82 | self.w2 = nn.Conv1d(self.width, self.width, 1) 83 | self.w3 = nn.Conv1d(self.width, self.width, 1) 84 | self.bn0 = torch.nn.BatchNorm3d(self.width) 85 | self.bn1 = torch.nn.BatchNorm3d(self.width) 86 | self.bn2 = torch.nn.BatchNorm3d(self.width) 87 | self.bn3 = torch.nn.BatchNorm3d(self.width) 88 | 89 | 90 | self.fc1 = nn.Linear(self.width, 128) 91 | self.fc2 = nn.Linear(128, 1) 92 | 93 | def forward(self, x): 94 | batchsize = x.shape[0] 95 | size_x, size_y, size_z = x.shape[1], 
x.shape[2], x.shape[3] 96 | 97 | x = self.fc0(x) 98 | x = x.permute(0, 4, 1, 2, 3) 99 | 100 | x1 = self.conv0(x) 101 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 102 | x = self.bn0(x1 + x2) 103 | x = F.relu(x) 104 | x1 = self.conv1(x) 105 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 106 | x = self.bn1(x1 + x2) 107 | x = F.relu(x) 108 | x1 = self.conv2(x) 109 | x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 110 | x = self.bn2(x1 + x2) 111 | x = F.relu(x) 112 | x1 = self.conv3(x) 113 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z) 114 | x = self.bn3(x1 + x2) 115 | 116 | 117 | x = x.permute(0, 2, 3, 4, 1) 118 | x = self.fc1(x) 119 | x = F.relu(x) 120 | x = self.fc2(x) 121 | return x 122 | 123 | class Net2d(nn.Module): 124 | def __init__(self, modes, width): 125 | super(Net2d, self).__init__() 126 | 127 | self.conv1 = SimpleBlock2d(modes, modes, 4, width) 128 | 129 | 130 | def forward(self, x): 131 | x = self.conv1(x) 132 | return x 133 | 134 | 135 | def count_params(self): 136 | c = 0 137 | for p in self.parameters(): 138 | c += reduce(operator.mul, list(p.size())) 139 | 140 | return c 141 | 142 | ################################################################ 143 | # configs 144 | ################################################################ 145 | 146 | TRAIN_PATH = 'data/ns_data_V10000_N1200_T20.mat' 147 | TEST_PATH = 'data/ns_data_V10000_N1200_T20.mat' 148 | 149 | ntrain = 1000 150 | ntest = 200 151 | 152 | modes = 12 153 | width = 20 154 | 155 | batch_size = 20 156 | batch_size2 = batch_size 157 | 158 | 159 | epochs = 500 160 | learning_rate = 0.0025 161 | scheduler_step = 100 162 | scheduler_gamma = 0.5 163 | 164 | print(epochs, learning_rate, scheduler_step, scheduler_gamma) 165 | 166 | path = 
'ns_fourier_3d_rnn_V10000_T20_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width) 167 | path_model = 'model/'+path 168 | path_train_err = 'results/'+path+'train.txt' 169 | path_test_err = 'results/'+path+'test.txt' 170 | path_image = 'image/'+path 171 | 172 | 173 | runtime = np.zeros(2, ) 174 | t1 = default_timer() 175 | 176 | 177 | sub = 1 178 | S = 64 179 | T_in = 10 180 | T_start = 0 181 | step = T_in - T_start 182 | T = 10 183 | 184 | ################################################################ 185 | # load data 186 | ################################################################ 187 | 188 | reader = MatReader(TRAIN_PATH) 189 | train_a = reader.read_field('u')[:ntrain,::sub,::sub,T_start:T_in] 190 | train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in] 191 | 192 | reader = MatReader(TEST_PATH) 193 | test_a = reader.read_field('u')[-ntest:,::sub,::sub,T_start:T_in] 194 | test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in] 195 | 196 | print(train_u.shape, test_u.shape) 197 | assert (S == train_u.shape[-2]) 198 | assert (T == train_u.shape[-1]) 199 | 200 | 201 | 202 | train_a = train_a.reshape(ntrain,S,S,step,1) 203 | test_a = test_a.reshape(ntest,S,S,step,1) 204 | 205 | # cat the location information (x,y,t) 206 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 207 | gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, step, 1]) 208 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 209 | gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, step, 1]) 210 | gridt = torch.tensor(np.linspace(0, 1, step+1)[1:], dtype=torch.float) 211 | gridt = gridt.reshape(1, 1, 1, step, 1).repeat([1, S, S, 1, 1]) 212 | 213 | train_a = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]), 214 | gridt.repeat([ntrain,1,1,1,1]), train_a), dim=-1) 215 | test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]), 216 | gridt.repeat([ntest,1,1,1,1]), test_a), 
dim=-1) 217 | 218 | print(train_a.shape, train_u.shape) 219 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 220 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 221 | 222 | t2 = default_timer() 223 | 224 | print('preprocessing finished, time used:', t2-t1) 225 | device = torch.device('cuda') 226 | 227 | ################################################################ 228 | # training and evaluation 229 | ################################################################ 230 | model = Net2d(modes, width).cuda() 231 | # model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20') 232 | 233 | print(model.count_params()) 234 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 235 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 236 | 237 | myloss = LpLoss(size_average=False) 238 | 239 | gridx = gridx.to(device) 240 | gridy = gridy.to(device) 241 | gridt = gridt.to(device) 242 | for ep in range(epochs): 243 | model.train() 244 | t1 = default_timer() 245 | train_l2_step = 0 246 | train_l2_full = 0 247 | for xx, yy in train_loader: 248 | loss = 0 249 | xx = xx.to(device) 250 | yy = yy.to(device) 251 | 252 | for t in range(0, T, step): 253 | y = yy[..., t:t+step] 254 | im = model(xx) 255 | loss += myloss(im.reshape(batch_size,-1), y.reshape(batch_size,-1)) 256 | 257 | if t == 0: 258 | pred = im.squeeze() 259 | else: 260 | pred = torch.cat((pred, im.squeeze()), -1) 261 | 262 | im = torch.cat((gridx.repeat([batch_size, 1, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1, 1]), 263 | gridt.repeat([batch_size, 1, 1, 1, 1]), im), dim=-1) 264 | xx = torch.cat([xx[..., step:, :], im], -2) 265 | 266 | train_l2_step += loss.item() 267 | l2_full = myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)) 268 | train_l2_full += 
l2_full.item() 269 | 270 | optimizer.zero_grad() 271 | loss.backward() 272 | # l2_full.backward() 273 | optimizer.step() 274 | 275 | test_l2_step = 0 276 | test_l2_full = 0 277 | with torch.no_grad(): 278 | for xx, yy in test_loader: 279 | loss = 0 280 | xx = xx.to(device) 281 | yy = yy.to(device) 282 | 283 | for t in range(0, T, step): 284 | y = yy[..., t:t + step] 285 | im = model(xx) 286 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 287 | 288 | if t == 0: 289 | pred = im.squeeze() 290 | else: 291 | pred = torch.cat((pred, im.squeeze()), -1) 292 | 293 | im = torch.cat((gridx.repeat([batch_size, 1, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1, 1]), 294 | gridt.repeat([batch_size, 1, 1, 1, 1]), im), dim=-1) 295 | xx = torch.cat([xx[..., step:, :], im], -2) 296 | 297 | test_l2_step += loss.item() 298 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 299 | 300 | t2 = default_timer() 301 | scheduler.step() 302 | print(ep, t2-t1, train_l2_step/ntrain/(T/step), train_l2_full/ntrain, test_l2_step/ntest/(T/step), test_l2_full/ntest) 303 | torch.save(model, path_model) 304 | 305 | 306 | # pred = torch.zeros(test_u.shape) 307 | # index = 0 308 | # test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 309 | # with torch.no_grad(): 310 | # for x, y in test_loader: 311 | # test_l2 = 0; 312 | # x, y = x.cuda(), y.cuda() 313 | # 314 | # out = model(x) 315 | # out = y_normalizer.decode(out) 316 | # pred[index] = out 317 | # 318 | # test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 319 | # print(index, test_l2) 320 | # index = index + 1 321 | 322 | # scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()}) 323 | 324 | 325 | 326 | 327 | -------------------------------------------------------------------------------- /scripts/fourier_on_images.py: -------------------------------------------------------------------------------- 1 | 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import matplotlib.pyplot as plt

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
from utilities3 import *

import torchvision
import torchvision.transforms as transforms

torch.manual_seed(0)
np.random.seed(0)

# Complex multiplication


def compl_mul2d(a, b):
    """Multiply complex spectra stored as trailing (real, imag) pairs.

    a: (batch, in_channel, x, y, 2) spectrum slice.
    b: (in_channel, out_channel, x, y, 2) weights, as created by SpectralConv2d.
    Returns (batch, out_channel, x, y, 2).
    """
    # Contract over the input-channel axis: dim 1 of `a`, dim 0 of `b`.
    op = partial(torch.einsum, "bixy,ioxy->boxy")
    return torch.stack([
        op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]),
        op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1])
    ], dim=-1)


class SpectralConv2d(nn.Module):
    """2D Fourier layer: real FFT, multiply the lowest modes by learned weights, inverse FFT.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        mode: number of Fourier modes kept along each spatial axis
            (at most floor(N/2) + 1).
        modes: keyword alias for ``mode``; the models below construct this
            layer with ``modes=...``.
    """

    def __init__(self, in_channels, out_channels, mode=None, modes=None):
        super(SpectralConv2d, self).__init__()
        if mode is None:
            mode = modes
        if mode is None:
            raise TypeError("SpectralConv2d requires the number of Fourier modes ('mode' or 'modes')")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = mode  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = mode

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))

    def forward(self, x):
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        # NOTE(review): torch.rfft/torch.irfft only exist up to PyTorch 1.7.
        x_ft = torch.rfft(x, 2, normalized=True, onesided=True)

        # Multiply relevant Fourier modes; the output spectrum has out_channels channels.
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, 2, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Return to physical space
        x = torch.irfft(out_ft, 2, normalized=True, onesided=True, signal_sizes=(x.size(-2), x.size(-1)))
        return x


class SimpleBlock2d(nn.Module):
    """Three spectral convolutions, a max-pool, and a 3-layer MLP head with 10 outputs."""

    def __init__(self, modes):
        super(SimpleBlock2d, self).__init__()

        self.conv1 = SpectralConv2d(1, 16, modes=modes)
        self.conv2 = SpectralConv2d(16, 32, modes=modes)
        self.conv3 = SpectralConv2d(32, 64, modes=modes)

        self.pool = nn.MaxPool2d(2, 2)

        # 64 channels of 14x14 after the 2x2 pool (i.e. 28x28 inputs).
        self.fc1 = nn.Linear(64 * 14 * 14, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = self.pool(x)

        x = x.view(-1, 64 * 14 * 14)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x


class Net2d(nn.Module):
    """Thin wrapper around SimpleBlock2d with 5 Fourier modes."""

    def __init__(self):
        super(Net2d, self).__init__()

        self.conv = SimpleBlock2d(5)

    def forward(self, x):
        x = self.conv(x)

        return x.squeeze(-1)

    def count_params(self):
        """Return the total number of scalar entries across all parameters."""
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c


class BasicBlock(nn.Module):
    """ResNet basic block with spectral convolutions in place of 3x3 convolutions."""
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, modes=10):
        super(BasicBlock, self).__init__()
        self.conv1 = SpectralConv2d(in_planes, planes, modes=modes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = SpectralConv2d(planes, planes, modes=modes)
        self.bn2 = nn.BatchNorm2d(planes)

        # stride only decides whether a projection shortcut is used;
        # SpectralConv2d itself does not downsample.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                SpectralConv2d(in_planes, self.expansion*planes, modes=modes),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """ResNet-style classifier built from spectral-convolution blocks (32 channels throughout)."""

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 32

        self.conv1 = SpectralConv2d(3, 32, modes=10)
        self.bn1 = nn.BatchNorm2d(32)
        self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=1, modes=3)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=1, modes=3)
        self.layer3 = self._make_layer(block, 32, num_blocks[2], stride=1, modes=3)
        self.layer4 = self._make_layer(block, 32, num_blocks[3], stride=1, modes=3)
        # 32 channels * 8x8 after avg_pool2d(out, 4) -- presumably 32x32 inputs (CIFAR-10).
        self.linear1 = nn.Linear(32*64*block.expansion, num_classes)
        # self.linear2 = nn.Linear(100, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride, modes=10):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, modes))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.layer1(out)
        # out = F.avg_pool2d(out, 2)
        out = self.layer2(out)
        # out = F.avg_pool2d(out, 2)
        out = self.layer3(out)
        # out = F.avg_pool2d(out, 2)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        # print(out.shape)
        out = out.view(out.size(0), -1)
        out = self.linear1(out)
        # out = F.relu(out)
        # out = self.linear2(out)
        return out


def ResNet18():
    # NOTE(review): [3, 4, 23, 3] are the ResNet-101 block counts; the name is kept for callers.
    return ResNet(BasicBlock, [3, 4, 23, 3])


## Mnist
# transform = transforms.Compose([transforms.ToTensor(),
#                                 transforms.Normalize((0.5,), (0.5,)),
#                                 ])
# trainset = torchvision.datasets.MNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform)
# testset = torchvision.datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True)
# testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True)

## Cifar10
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# model = Net2d().cuda()
model = ResNet18().cuda()
# model = torch.load('results/fourier_on_images')

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.75)

for epoch in range(50):  # loop over the dataset multiple times
    # BatchNorm must use batch statistics while training (evaluation below switches it off).
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].cuda(), data[1].cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:  # print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

    # The StepLR schedule is only effective if it is advanced once per epoch.
    scheduler.step()

    # Evaluate with running BatchNorm statistics and without updating them.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].cuda(), data[1].cuda()

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: %f %%' % (
        100 * correct / total))

# NOTE(review): the file name says "mnist" although this run trains on CIFAR-10 -- confirm intended path.
torch.save(model, 'results/fourier_on_images_mnist_100')

# -------------------------------------------------------------------------------- /scripts/super_resolution.py: --------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from utilities3 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
import scipy.io

torch.manual_seed(0)
np.random.seed(0)


def compl_mul3d(a, b):
    # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
    op = partial(torch.einsum, "bixyz,ioxyz->boxyz")
    return torch.stack([
        op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]),
        op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1])
    ], dim=-1)


class SpectralConv3d_fast(nn.Module):
    """3D Fourier layer: real FFT, multiply the four low-mode corners by learned weights, inverse FFT."""

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super(SpectralConv3d_fast, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2))
        self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2))
        self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2))

    def forward(self, x):
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.rfft(x, 3, normalized=True, onesided=True)

        # Multiply relevant Fourier modes; the output spectrum has out_channels channels.
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, 2, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \
            compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \
            compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2)
        out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \
            compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3)
        out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \
            compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4)

        # Return to physical space
        x = torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=(x.size(-3), x.size(-2), x.size(-1)))
        return x


class SimpleBlock2d(nn.Module):
    """Lift 13 input features to `width`, apply 4 Fourier layers with BatchNorm, project to 1 output."""

    def __init__(self, modes1, modes2, modes3, width):
        super(SimpleBlock2d, self).__init__()

        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.fc0 = nn.Linear(13, self.width)

        self.conv0 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv1 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv2 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv3 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)
        self.bn0 = torch.nn.BatchNorm3d(self.width)
        self.bn1 = torch.nn.BatchNorm3d(self.width)
        self.bn2 = torch.nn.BatchNorm3d(self.width)
        self.bn3 = torch.nn.BatchNorm3d(self.width)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        batchsize = x.shape[0]
        size_x, size_y, size_z = x.shape[1], x.shape[2], x.shape[3]

        x = self.fc0(x)
        x = x.permute(0, 4, 1, 2, 3)

        x1 = self.conv0(x)
        x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z)
        x = self.bn0(x1 + x2)
        x = F.relu(x)
        x1 = self.conv1(x)
        x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z)
        x = self.bn1(x1 + x2)
        x = F.relu(x)
        x1 = self.conv2(x)
        x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z)
        x = self.bn2(x1 + x2)
        x = F.relu(x)
        x1 = self.conv3(x)
        x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z)
        x = self.bn3(x1 + x2)

        x = x.permute(0, 2, 3, 4, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


class Net2d(nn.Module):
    def __init__(self, modes, width):
        super(Net2d, self).__init__()
        self.conv1 = SimpleBlock2d(modes, modes, 6, width)

    def forward(self, x):
        x = self.conv1(x)
        return x.squeeze()

    def count_params(self):
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c


t1 = default_timer()

TEST_PATH = 'data/ns_data_V1e-4_N20_T50_test.mat'


ntest = 20

sub = 1
sub_t = 1
S = 64
T_in = 10
T = 20

indent = 1

# load data
reader = MatReader(TEST_PATH)
test_a = reader.read_field('u')[:,::sub,::sub, 3:T_in*4:4]
test_u = reader.read_field('u')[:,::sub,::sub, indent+T_in*4:indent+(T+T_in)*4:sub_t]

print(test_a.shape, test_u.shape)

# pad the location information (s,t)
S = S * (4//sub)
T = T * (4//sub_t)

test_a = test_a.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1])

gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1])
gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1])
gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float)
gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1])

test_a = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]),
                    gridt.repeat([ntest,1,1,1,1]), test_a), dim=-1)

t2 = default_timer()
print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

# load model
model = torch.load('model/ns_fourier_V1e-4_T20_N9800_ep200_m12_w32')
# The network contains BatchNorm3d layers; inference must use their running statistics
# rather than per-sample (batch_size=1) statistics.
model.eval()

print(model.count_params())

# test
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)
myloss = LpLoss(size_average=False)
pred = torch.zeros(test_u.shape)
index = 0
with torch.no_grad():
    test_l2 = 0
    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()

        out = model(x)
        pred[index] = out
        loss = myloss(out.view(1, -1), y.view(1, -1)).item()
        test_l2 += loss
        print(index, loss)
        index = index + 1
    print(test_l2/ntest)

path = 'eval'
scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy(), 'u': test_u.cpu().numpy()})

# -------------------------------------------------------------------------------- /utilities3.py: --------------------------------------------------------------------------------
import torch
import numpy as np
import scipy.io
import h5py
import torch.nn as nn

import operator
from functools import reduce
from functools import partial

#################################################
#
# Utilities
#
#################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# reading data
class MatReader(object):
    """Read named fields from a MATLAB .mat file.

    The file is first parsed with scipy.io.loadmat. If that fails, the file is
    opened with h5py and kept open; fields read that way have their axis
    order reversed (presumably to undo MATLAB's column-major HDF5 layout).
    """

    def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
        super(MatReader, self).__init__()

        self.to_torch = to_torch
        self.to_cuda = to_cuda
        self.to_float = to_float

        self.file_path = file_path

        self.data = None
        self.old_mat = None
        self._load_file()

    def _load_file(self):
        try:
            self.data = scipy.io.loadmat(self.file_path)
            self.old_mat = True
        except Exception:
            # scipy.io.loadmat rejects MATLAB v7.3 (HDF5) files; fall back to h5py.
            # Open read-only explicitly: older h5py defaulted to an append mode.
            self.data = h5py.File(self.file_path, 'r')
            self.old_mat = False

    def load_file(self, file_path):
        self.file_path = file_path
        self._load_file()

    def read_field(self, field):
        """Return ``field`` converted per the to_float / to_torch / to_cuda flags."""
        x = self.data[field]

        if not self.old_mat:
            x = x[()]
            x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1))

        if self.to_float:
            x = x.astype(np.float32)

        if self.to_torch:
            x = torch.from_numpy(x)

            if self.to_cuda:
                x = x.cuda()

        return x

    def set_cuda(self, to_cuda):
        self.to_cuda = to_cuda

    def set_torch(self, to_torch):
        self.to_torch = to_torch

    def set_float(self, to_float):
        self.to_float = to_float


# normalization, pointwise gaussian
class UnitGaussianNormalizer(object):
    """Pointwise normalization using the mean/std taken over dim 0 of ``x``."""

    def __init__(self, x, eps=0.00001, time_last=True):
        super(UnitGaussianNormalizer, self).__init__()

        # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T in 1D
        # x could be in shape of ntrain*w*l or ntrain*T*w*l or ntrain*w*l*T in 2D
        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        self.eps = eps
        self.time_last = time_last  # if the time dimension is the last dim

    def encode(self, x):
        x = (x - self.mean) / (self.std + self.eps)
        return x

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`, optionally on the spatial subset ``sample_idx``.

        Raises:
            ValueError: if ``sample_idx`` has more dims than the stored mean
                while ``time_last`` is False (no supported indexing layout).
        """
        # sample_idx is the spatial sampling mask
        if sample_idx is None:
            std = self.std + self.eps  # n
            mean = self.mean
        elif self.mean.ndim == sample_idx.ndim or self.time_last:
            std = self.std[sample_idx] + self.eps  # batch*n
            mean = self.mean[sample_idx]
        elif self.mean.ndim > sample_idx.ndim:
            std = self.std[..., sample_idx] + self.eps  # T*batch*n
            mean = self.mean[..., sample_idx]
        else:
            raise ValueError("sample_idx has more dimensions than the normalizer statistics")
        # x is in shape of batch*(spatial discretization size) or T*batch*(spatial discretization size)
        x = (x * std) + mean
        return x

    def to(self, device):
        if torch.is_tensor(self.mean):
            self.mean = self.mean.to(device)
            self.std = self.std.to(device)
        else:
            self.mean = torch.from_numpy(self.mean).to(device)
            self.std = torch.from_numpy(self.std).to(device)
        return self

    def cuda(self):
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        return self

    def cpu(self):
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self
# normalization, Gaussian
class GaussianNormalizer(object):
    """Global Gaussian normalizer: one scalar mean/std over all entries of x."""

    def __init__(self, x, eps=0.00001):
        super(GaussianNormalizer, self).__init__()

        self.mean = torch.mean(x)
        self.std = torch.std(x)
        self.eps = eps

    def encode(self, x):
        """Return (x - mean) / (std + eps)."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Invert `encode`.

        sample_idx is accepted for interface parity with
        UnitGaussianNormalizer and ignored: the statistics are scalars.
        """
        return (x * (self.std + self.eps)) + self.mean

    def to(self, device):
        """Move statistics to `device`; return self."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        return self

    def cuda(self):
        """Move statistics to the GPU; return self."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        return self

    def cpu(self):
        """Move statistics to the CPU; return self."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self


# normalization, scaling by range
class RangeNormalizer(object):
    """Affinely rescale each feature so its training range maps to [low, high].

    Min/max are taken over dim 0; all remaining dims are flattened into a
    single feature axis, so tensors passed to encode/decode must flatten to
    the same number of features per sample.
    """

    def __init__(self, x, low=0.0, high=1.0):
        super(RangeNormalizer, self).__init__()
        mymin = torch.min(x, 0)[0].view(-1)
        mymax = torch.max(x, 0)[0].view(-1)

        # A constant feature has zero range; dividing by it gives an infinite
        # scale and NaN outputs. Use a unit range there instead: the constant
        # then maps to `high` and the transform stays exactly invertible.
        span = mymax - mymin
        span = torch.where(span == 0, torch.ones_like(span), span)

        self.a = (high - low) / span
        self.b = -self.a * mymax + high

    def encode(self, x):
        """Map x into [low, high] feature-wise (reshape tolerates non-contiguous x)."""
        s = x.size()
        x = x.reshape(s[0], -1)
        x = self.a * x + self.b
        return x.reshape(s)

    def decode(self, x, sample_idx=None):
        """Invert `encode`; sample_idx is accepted for parity and ignored."""
        s = x.size()
        x = x.reshape(s[0], -1)
        x = (x - self.b) / self.a
        return x.reshape(s)

    def to(self, device):
        """Move the affine coefficients to `device`; return self."""
        self.a = self.a.to(device)
        self.b = self.b.to(device)
        return self

    def cuda(self):
        """Move the affine coefficients to the GPU; return self."""
        self.a = self.a.cuda()
        self.b = self.b.cuda()
        return self

    def cpu(self):
        """Move the affine coefficients to the CPU; return self."""
        self.a = self.a.cpu()
        self.b = self.b.cpu()
        return self


#loss function with rel/abs Lp loss
class LpLoss(object):
    """Absolute or relative Lp loss between batches of discretized functions.

    Args:
        d: spatial dimension; only used for the mesh weight h**(d/p) in `abs`.
        p: order of the norm.
        size_average: when reducing, average (True) or sum (False) over the batch.
        reduction: if False, return the per-example losses instead.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        # Dimension and Lp-norm type are positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def _reduce(self, values):
        """Apply the configured batch reduction to per-example losses."""
        if not self.reduction:
            return values
        return torch.mean(values) if self.size_average else torch.sum(values)

    def abs(self, x, y):
        """Mesh-weighted absolute Lp error per example (then reduced).

        Assumes a uniform mesh whose spacing is 1 / (x.size(1) - 1); note this
        divides by zero when x.size(1) == 1.
        """
        num_examples = x.size()[0]

        # Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        # reshape (not view) so non-contiguous inputs work, as in `rel`
        all_norms = (h ** (self.d / self.p)) * torch.norm(
            x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1
        )
        return self._reduce(all_norms)

    def rel(self, x, y):
        """Relative Lp error ||x - y|| / ||y|| per example (then reduced)."""
        num_examples = x.size()[0]

        diff_norms = torch.norm(
            x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1
        )
        y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1)
        return self._reduce(diff_norms / y_norms)

    def __call__(self, x, y):
        return self.rel(x, y)


# Sobolev norm (HS norm)
# where we also compare the numerical derivatives between the output and target
class HsLoss(object):
    """Relative Sobolev-type loss on 2-D fields, computed in Fourier space.

    Args:
        d, p: as in LpLoss (p is the norm order used by `rel`; d is validated only).
        k: number of derivative orders to include; orders above 2 are not
            implemented and are ignored.
        a: per-order weights (needs at least k entries); defaults to all ones.
        group: if True, average separate relative losses per order;
            otherwise use one relative loss with a combined weight.
        size_average, reduction: as in LpLoss.
    """

    def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True):
        super(HsLoss, self).__init__()

        # Dimension and Lp-norm type are positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.k = k
        self.balanced = group
        self.reduction = reduction
        self.size_average = size_average

        if a is None:
            a = [1] * k
        self.a = a

    @staticmethod
    def _wavenumbers(n):
        """Integer wavenumbers for an n-point FFT, in FFT output order.

        Matches numpy.fft.fftfreq(n) * n, e.g. n=5 -> [0, 1, 2, -2, -1].
        """
        return torch.cat(
            (torch.arange(start=0, end=(n + 1) // 2, step=1),
             torch.arange(start=-(n // 2), end=0, step=1)),
            0,
        )

    def rel(self, x, y):
        """Relative Lp error ||x - y|| / ||y|| per example (then reduced)."""
        num_examples = x.size()[0]
        diff_norms = torch.norm(
            x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1
        )
        y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1)
        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms / y_norms)
            return torch.sum(diff_norms / y_norms)
        return diff_norms / y_norms

    def __call__(self, x, y, a=None):
        """Compute the loss; x and y are (batch, nx, ny, ...) fields.

        `a`, when given, overrides the per-order weights set at construction.
        """
        nx = x.size()[1]
        ny = x.size()[2]
        k = self.k
        balanced = self.balanced
        a = self.a if a is None else a
        x = x.reshape(x.shape[0], nx, ny, -1)
        y = y.reshape(y.shape[0], nx, ny, -1)

        k_x = self._wavenumbers(nx).reshape(nx, 1).repeat(1, ny)
        k_y = self._wavenumbers(ny).reshape(1, ny).repeat(nx, 1)
        k_x = torch.abs(k_x).reshape(1, nx, ny, 1).to(x.device)
        k_y = torch.abs(k_y).reshape(1, nx, ny, 1).to(x.device)
        k_sq = k_x**2 + k_y**2  # |k|^2 for each Fourier mode

        x = torch.fft.fftn(x, dim=[1, 2])
        y = torch.fft.fftn(y, dim=[1, 2])

        if not balanced:
            weight = 1
            if k >= 1:
                weight = weight + a[0]**2 * k_sq
            if k >= 2:
                weight = weight + a[1]**2 * k_sq**2
            # as_tensor: for k == 0 weight is still the Python int 1,
            # which torch.sqrt would reject
            weight = torch.sqrt(torch.as_tensor(weight))
            loss = self.rel(x * weight, y * weight)
        else:
            loss = self.rel(x, y)
            if k >= 1:
                weight = a[0] * torch.sqrt(k_sq)
                loss += self.rel(x * weight, y * weight)
            if k >= 2:
                weight = a[1] * torch.sqrt(k_sq**2)
                loss += self.rel(x * weight, y * weight)
            loss = loss / (k + 1)

        return loss


# A simple feedforward neural network
class DenseNet(torch.nn.Module):
    """Stack of Linear layers sized by `layers`.

    Between consecutive Linear layers it inserts an optional BatchNorm1d
    (when `normalize`) followed by `nonlinearity()`. `out_nonlinearity()`,
    if given, is appended after the last Linear layer.
    """

    def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False):
        super(DenseNet, self).__init__()

        self.n_layers = len(layers) - 1

        assert self.n_layers >= 1

        self.layers = nn.ModuleList()

        for j in range(self.n_layers):
            self.layers.append(nn.Linear(layers[j], layers[j + 1]))

            if j != self.n_layers - 1:
                if normalize:
                    self.layers.append(nn.BatchNorm1d(layers[j + 1]))

                self.layers.append(nonlinearity())

        if out_nonlinearity is not None:
            self.layers.append(out_nonlinearity())

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# print the number of parameters
def count_params(model):
    """Return the number of real scalars in `model`'s parameters.

    Complex parameters count twice (real and imaginary parts). Uses numel()
    so 0-dim parameters count as 1 instead of breaking reduce() on an empty size.
    """
    return sum(p.numel() * (2 if p.is_complex() else 1) for p in model.parameters())
--------------------------------------------------------------------------------