├── Github_page_images
│   ├── Animation_ns_256_3d_1e-4.gif
│   ├── Animation_ns_64_3d_1e-4.gif
│   ├── Burgers_prediction.png
│   ├── WNN.png
│   └── WNN_parameter.png
├── README.md
├── Test_wno_1d_Burgers.py
├── Version 1.0.0
│   ├── WNO_testing_1d_AV.py
│   ├── WNO_testing_1d_Burgers.py
│   ├── WNO_testing_2d_AC.py
│   ├── WNO_testing_2d_Darcy_notch.py
│   ├── WNO_testing_2d_Darcy_r.py
│   ├── WNO_testing_2d_ERA5.py
│   ├── WNO_testing_2d_ERA5_time.py
│   ├── WNO_testing_2d_NS.py
│   ├── __pycache__
│   │   └── utilities3.cpython-39.pyc
│   ├── utilities3.py
│   ├── wno_1d_Advection_time_III.py
│   ├── wno_1d_Burger_discontinuous.py
│   ├── wno_1d_Burgers.py
│   ├── wno_2d_AC.py
│   ├── wno_2d_Darcy.py
│   ├── wno_2d_Darcy_notch.py
│   ├── wno_2d_ERA5.py
│   ├── wno_2d_ERA5_time.py
│   └── wno_2d_time_NS.py
├── Version 2.0.0
│   ├── README.md
│   ├── Test_wno_super_1d_Burgers.py
│   ├── __pycache__
│   │   ├── utilities3.cpython-39.pyc
│   │   ├── utils.cpython-39.pyc
│   │   └── wavelet_convolution.cpython-39.pyc
│   ├── data
│   │   ├── Burger_data
│   │   │   ├── burgerbc.m
│   │   │   ├── burgeric.m
│   │   │   ├── burgerpde.m
│   │   │   └── main_burger.m
│   │   ├── test_IC2.npz
│   │   └── train_IC2.npz
│   ├── model
│   │   └── WNO_burgers
│   ├── utils.py
│   ├── wavelet_convolution.py
│   ├── wno1d_Advection_time_III.py
│   ├── wno1d_Burger_discontinuous.py
│   ├── wno1d_Burgers.py
│   ├── wno1d_advection_III.py
│   ├── wno2d_AC_dwt.py
│   ├── wno2d_Darcy_dwt.py
│   ├── wno2d_Darcy_notch_cwt.py
│   ├── wno2d_Darcy_notch_dwt.py
│   ├── wno2d_NS_cwt.py
│   ├── wno2d_NS_dwt.py
│   ├── wno2d_Temperature_Daily_Avg.py
│   ├── wno2d_Temperature_Monthly_Avg.py
│   └── wno3d_NS.py
├── utils.py
├── wavelet_convolution_v3.py
├── wno1d_Burgers_v3.py
├── wno2d_Darcy_cwt_v3.py
├── wno2d_Darcy_dwt_v3.py
└── wno3d_NS_dwt_v3.py

/Github_page_images/Animation_ns_256_3d_1e-4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Github_page_images/Animation_ns_256_3d_1e-4.gif
--------------------------------------------------------------------------------
/Github_page_images/Animation_ns_64_3d_1e-4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Github_page_images/Animation_ns_64_3d_1e-4.gif
--------------------------------------------------------------------------------
/Github_page_images/Burgers_prediction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Github_page_images/Burgers_prediction.png
--------------------------------------------------------------------------------
/Github_page_images/WNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Github_page_images/WNN.png
--------------------------------------------------------------------------------
/Github_page_images/WNN_parameter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Github_page_images/WNN_parameter.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Wavelet-Neural-Operator (WNO)
2 | This repository contains the python codes of the paper
3 | > + Tripura, T., & Chakraborty, S. (2023).
Wavelet Neural Operator for solving parametric partial differential equations in computational mechanics problems. Computer Methods in Applied Mechanics and Engineering, 404, 115783. [Article](https://doi.org/10.1016/j.cma.2022.115783)
4 | > + arXiv version: "Wavelet neural operator: a neural operator for parametric partial differential equations". The arXiv version can be accessed [here](https://arxiv.org/abs/2205.02191).
5 | 
6 | ## New in version 3.0.0
7 | ```
8 | > Modified the convolution in the wavelet space.
9 | > Replaced the element-wise multiplication with a secondary convolution.
10 | > The secondary convolution acts on the wavelet coefficients.
11 | > As a result, the WNO-v3 framework is more accurate.
12 | > The secondary convolution is performed in the Fourier space; a new parameter `omega` is added.
13 | > `omega` controls the number of Fourier modes used in the spectral convolution.
14 | ```
15 | 
16 | ## New in version 2.0.0
17 | ```
18 | > Added super-resolution capability to the WNO.
19 | > Added 3D support to the WNO.
20 | > Improved the interface and readability of the codes.
21 | ```
22 | 
23 | ## Architecture of the wavelet neural operator (WNO).
24 | (a) Schematic of the proposed neural operator. (b) A simple WNO with one wavelet kernel integral layer.
25 | ![WNO](/Github_page_images/WNN.png)
26 | 
27 | ## Construction of the parametric space using multi-level wavelet decomposition.
28 | ![Construction of parameterization space in WNO](/Github_page_images/WNN_parameter.png)
29 | 
30 | ## Super resolution using Wavelet Neural Operator.
31 | > Super resolution in Burgers' diffusion dynamics:
32 | ![Train at resolution-1024 and Test at resolution-2048](/Github_page_images/Burgers_prediction.png)
33 | > Super resolution in the Navier-Stokes equation at a Reynolds number of 10000:
34 | ![Train in Low resolution](/Github_page_images/Animation_ns_64_3d_1e-4.gif)
35 | ![Test in High resolution](/Github_page_images/Animation_ns_256_3d_1e-4.gif)
36 | 
37 | ## Files
38 | A short description of the files is provided below for the readers' convenience. For `time-dependent` problems, please implement the autoregressive schemes provided in `Version 2.0.0`.
39 | ```
40 | + `wno1d_Burgers_v3.py`: For 1D Burger's equation (time-independent problem).
41 | + `wno2d_Darcy_cwt_v3.py`: For 2D Darcy equation using Slim Continuous Wavelet Transform (time-independent problem).
42 | + `wno2d_Darcy_dwt_v3.py`: For 2D Darcy equation using Discrete Wavelet Transform (time-independent problem).
43 | + `wno3d_NS_dwt_v3.py`: For 2D Navier-Stokes equation using 3D WNO (as a time-independent problem).
44 | 
45 | + `Test_wno_1d_Burgers.py`: An example of testing on new data.
46 | 
47 | + `utils.py` contains some useful functions for data handling (adapted from the [FNO paper](https://github.com/zongyi-li/fourier_neural_operator)).
48 | + `wavelet_convolution_v3.py` contains functions for 1D, 2D, and 3D convolution in the wavelet domain.
49 | ```
50 | 
51 | ## Essential Python Libraries
52 | The following packages are required to be installed to run the above codes:
53 | + [PyTorch](https://pytorch.org/)
54 | + [PyWavelets - Wavelet Transforms in Python](https://pywavelets.readthedocs.io/en/latest/)
55 | + [Wavelet Transforms in Pytorch](https://github.com/fbcotter/pytorch_wavelets)
56 | + [Wavelet Transform Toolbox](https://github.com/v0lta/PyTorch-Wavelet-Toolbox)
57 | + [Xarray-Grib reader (To read ERA5 data in section 5)](https://docs.xarray.dev/en/stable/getting-started-guide/installing.html?highlight=install)
58 | 
59 | Copy all the data into the folder 'data' and place the folder 'data' inside the same parent folder where the codes are present. In case the location of the data is changed, the correct path should be given.
60 | 
61 | ## Dataset
62 | + The training and testing datasets for the (i) Burgers equation with discontinuity in the solution field (section 4.1), (ii) 2-D Allen-Cahn equation (section 4.5), and (iii) weekly and monthly mean 2m air temperature (section 5) are available at the following link:
63 | > [Dataset-1](https://drive.google.com/drive/folders/1scfrpChQ1wqFu8VAyieoSrdgHYCbrT6T?usp=sharing) \
64 | The datasets for the weekly and monthly mean 2m air temperature are downloaded from the 'European Centre for Medium-Range Weather Forecasts (ECMWF)' database. For more information on the dataset, one can browse the link
65 | [ECMWF](https://www.ecmwf.int/en/forecasts/datasets/browse-reanalysis-datasets).
66 | + The datasets for the (i) 1-D Burgers equation ('burgers_data_R10.zip'), (ii) 2-D Darcy flow equation in a rectangular domain ('Darcy_421.zip'), and (iii) 2-D time-dependent Navier-Stokes equation ('ns_V1e-3_N5000_T50.zip') are taken from the following link:
67 | > [Dataset-2](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-)
68 | + The datasets for the 2-D Darcy flow equation with a notch in a triangular domain ('Darcy_Triangular_FNO.mat') and the 1-D time-dependent wave advection equation are taken from the following link:
69 | > [Dataset-3](https://github.com/lu-group/deeponet-fno/tree/main/data)
70 | 
71 | ## BibTex
72 | If you use any part of our codes, please cite us as:
73 | ```
74 | @article{tripura2023wavelet,
75 |   title={Wavelet Neural Operator for solving parametric partial differential equations in computational mechanics problems},
76 |   author={Tripura, Tapas and Chakraborty, Souvik},
77 |   journal={Computer Methods in Applied Mechanics and Engineering},
78 |   volume={404},
79 |   pages={115783},
80 |   year={2023},
81 |   publisher={Elsevier}
82 | }
83 | ```
84 | 
--------------------------------------------------------------------------------
/Test_wno_1d_Burgers.py:
--------------------------------------------------------------------------------
1 | """
2 | This code belongs to the paper:
3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving
4 |    parametric partial differential equations in computational mechanics problems.
5 | 
6 | -- This code is for 1-D Burger's equation (time-independent problem).
7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import matplotlib.pyplot as plt 14 | 15 | from timeit import default_timer 16 | from utils import * 17 | from wavelet_convolution import WaveConv1d 18 | 19 | torch.manual_seed(0) 20 | np.random.seed(0) 21 | device = ('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | # %% 24 | """ The forward operation """ 25 | class WNO1d(nn.Module): 26 | def __init__(self, width, level, layers, size, wavelet, in_channel, grid_range, omega, padding=0): 27 | super(WNO1d, self).__init__() 28 | 29 | """ 30 | The WNO network. It contains l-layers of the Wavelet integral layer. 31 | 1. Lift the input using v(x) = self.fc0 . 32 | 2. l-layers of the integral operators v(j+1)(x) = g(K.v + W.v)(x). 33 | --> W is defined by self.w; K is defined by self.conv. 34 | 3. Project the output of last layer using self.fc1 and self.fc2. 35 | 36 | Input : 2-channel tensor, Initial condition and location (a(x), x) 37 | : shape: (batchsize * x=s * c=2) 38 | Output: Solution of a later timestep (u(x)) 39 | : shape: (batchsize * x=s * c=1) 40 | 41 | Input parameters: 42 | ----------------- 43 | width : scalar, lifting dimension of input 44 | level : scalar, number of wavelet decomposition 45 | layers: scalar, number of wavelet kernel integral blocks 46 | size : scalar, signal length 47 | wavelet: string, wavelet filter 48 | in_channel: scalar, channels in input including grid 49 | grid_range: scalar (for 1D), right support of 1D domain 50 | padding : scalar, size of zero padding 51 | """ 52 | 53 | self.level = level 54 | self.width = width 55 | self.layers = layers 56 | self.size = size 57 | self.wavelet = wavelet 58 | self.omega = omega 59 | self.in_channel = in_channel 60 | self.grid_range = grid_range 61 | self.padding = padding 62 | 63 | self.conv = nn.ModuleList() 64 | self.w = nn.ModuleList() 65 | 66 | self.fc0 = nn.Linear(self.in_channel, self.width) # input channel is 2: (a(x), x) 67 | for i in range( self.layers ): 68 | self.conv.append( WaveConv1d(self.width, self.width, self.level, size=self.size, 69 | wavelet=self.wavelet, omega=self.omega) ) 70 | self.w.append( nn.Conv1d(self.width, self.width, 1) ) 71 | self.fc1 = nn.Linear(self.width, 128) 72 | self.fc2 = nn.Linear(128, 1) 73 | 74 | def forward(self, x): 75 | grid = self.get_grid(x.shape, x.device) 76 | x = torch.cat((x, grid), dim=-1) 77 | x = self.fc0(x) # Shape: Batch * x * Channel 78 | x = x.permute(0, 2, 1) # Shape: Batch * Channel * x 79 | if self.padding != 0: 80 | x = F.pad(x, [0,self.padding]) 81 | 82 | for index, (convl, wl) in enumerate( zip(self.conv, self.w) ): 83 | x = convl(x) + wl(x) 84 | if index != self.layers - 1: # Final layer has no activation 85 | x = F.mish(x) # Shape: Batch * Channel * x 86 | 87 | if self.padding != 0: 88 | x = x[..., :-self.padding] 89 | x = x.permute(0, 2, 1) # Shape: Batch * x * Channel 90 | x = F.mish( self.fc1(x) ) # Shape: Batch * x * Channel 91 | x = self.fc2(x) # Shape: Batch * x * Channel 92 | return x 93 | 94 | def get_grid(self, shape, device): 95 | # The grid of the solution 96 | batchsize, size_x = shape[0], shape[1] 97 | gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float) 98 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 99 | return gridx.to(device) 100 | 101 | 102 | # %% 103 | """ Model configurations """ 104 | 105 | PATH = '/home/user/Desktop/Papers_codes/P3_WNO/WNO-master/data/burgers_data_R10.mat' 106 | ntrain = 1000 107 | ntest = 100 
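# NOTE: a minimal instantiation sketch for the WNO1d class defined above, kept here as a
# comment for reference only -- this script loads a pre-trained model further below. The
# value of `omega` is an assumption, since it is not part of this configuration block:
#   model = WNO1d(width=64, level=8, layers=4, size=1024, wavelet='db6',
#                 in_channel=2, grid_range=1, omega=4, padding=0).to(device)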
108 | 109 | batch_size = 20 110 | learning_rate = 0.001 111 | 112 | epochs = 500 113 | step_size = 50 # weight-decay step size 114 | gamma = 0.5 # weight-decay rate 115 | 116 | wavelet = 'db6' # wavelet basis function 117 | level = 8 # lavel of wavelet decomposition 118 | width = 64 # uplifting dimension 119 | layers = 4 # no of wavelet layers 120 | 121 | sub = 2**3 # subsampling rate 122 | h = 2**13 // sub # total grid size divided by the subsampling rate 123 | grid_range = 1 124 | in_channel = 2 # (a(x), x) for this case 125 | 126 | # %% 127 | """ Read data """ 128 | dataloader = MatReader(PATH) 129 | x_data = dataloader.read_field('a')[:,::sub] 130 | y_data = dataloader.read_field('u')[:,::sub] 131 | 132 | x_train = x_data[:ntrain,:] 133 | y_train = y_data[:ntrain,:] 134 | x_test = x_data[-ntest:,:] 135 | y_test = y_data[-ntest:,:] 136 | 137 | x_train = x_train[:, :, None] 138 | x_test = x_test[:, :, None] 139 | 140 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), 141 | batch_size=batch_size, shuffle=True) 142 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), 143 | batch_size=batch_size, shuffle=False) 144 | 145 | # %% 146 | """ The model definition """ 147 | model = torch.load('model/WNO_burgers', map_location=device) 148 | print(count_params(model)) 149 | 150 | myloss = LpLoss(size_average=False) 151 | 152 | # %% 153 | """ Prediction """ 154 | pred = [] 155 | test_e = [] 156 | with torch.no_grad(): 157 | 158 | index = 0 159 | for x, y in test_loader: 160 | test_l2 = 0 161 | x, y = x.to(device), y.to(device) 162 | 163 | out = model(x) 164 | test_l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 165 | 166 | test_e.append( test_l2/batch_size ) 167 | pred.append( out ) 168 | print("Batch-{}, Test-loss-{:0.6f}".format( index, test_l2/batch_size )) 169 | index += 1 170 | 171 | pred = torch.cat((pred)) 172 | test_e = torch.tensor((test_e)) 173 | print('Mean Error:', 100*torch.mean(test_e).numpy()) 174 | 175 | # %% 176 | plt.rcParams['font.family'] = 'Times New Roman' 177 | plt.rcParams['font.size'] = 14 178 | plt.rcParams['mathtext.fontset'] = 'dejavuserif' 179 | 180 | colormap = plt.cm.jet 181 | colors = [colormap(i) for i in np.linspace(0, 1, 5)] 182 | 183 | """ Plotting """ 184 | figure7 = plt.figure(figsize = (10, 5), dpi=300) 185 | index = 0 186 | for i in range(y_test.shape[0]): 187 | if i % 20 == 1: 188 | plt.plot(y_test[i, :].cpu().numpy(), color=colors[index], label='Actual') 189 | plt.plot(pred[i,:].cpu().numpy(), '--', color=colors[index], label='Prediction') 190 | index += 1 191 | plt.legend(ncol=5, loc=3, borderaxespad=0.1, columnspacing=0.75, handletextpad=0.25) 192 | plt.grid(True, alpha=0.35) 193 | plt.ylim([-1,1]) 194 | plt.margins(0) 195 | plt.xlabel('Space ($x$)') 196 | plt.ylabel('$u$($x$)') 197 | plt.title('Mean Error: {:0.4f}%'.format(100*torch.mean(test_e).numpy()), fontweight='bold') 198 | plt.show() 199 | 200 | -------------------------------------------------------------------------------- /Version 1.0.0/WNO_testing_1d_AV.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for performing predictions on pre-trained models for 7 | 1-D Advection equation (time-dependent problem). 
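-- The forecast is generated autoregressively, one step at a time. A sketch of the
   update used in the testing loop further below:
       im = model(xx)                                  # predict the next time step
       xx = torch.cat((xx[..., step:], im), dim=-1)    # slide the input time-window forward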
8 | """ 9 | 10 | from IPython import get_ipython 11 | get_ipython().magic('reset -sf') 12 | 13 | import pywt 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import matplotlib.pyplot as plt 19 | 20 | from timeit import default_timer 21 | from utilities3 import * 22 | from pytorch_wavelets import DWT1D, IDWT1D 23 | 24 | torch.manual_seed(0) 25 | np.random.seed(0) 26 | 27 | # %% 28 | 29 | class WaveConv1d(nn.Module): 30 | def __init__(self, in_channels, out_channels, modes1): 31 | super(WaveConv1d, self).__init__() 32 | 33 | """ 34 | 1D Wavelet layer. It does Wavelet Transform, linear transform, and 35 | Inverse Wavelet Transform. 36 | """ 37 | 38 | self.in_channels = in_channels 39 | self.out_channels = out_channels 40 | self.modes1 = modes1 41 | 42 | self.scale = (1 / (in_channels*out_channels)) 43 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1)) 44 | 45 | # Complex multiplication 46 | def compl_mul1d(self, input, weights): 47 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 48 | return torch.einsum("bix,iox->box", input, weights) 49 | 50 | def forward(self, x): 51 | batchsize = x.shape[0] 52 | # Compute single tree Discrete Wavelet coefficients using some wavelet 53 | dwt = DWT1D(wave='db6', J=3, mode='symmetric').to(device) 54 | x_ft, x_coeff = dwt(x) 55 | 56 | # Multiply the final low pass and high pass coefficients 57 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-1], device=x.device) 58 | out_ft = self.compl_mul1d(x_ft, self.weights1) 59 | x_coeff[-1] = self.compl_mul1d(x_coeff[-1], self.weights1) 60 | 61 | idwt = IDWT1D(wave='db6', mode='symmetric').to(device) 62 | x = idwt((out_ft, x_coeff)) 63 | return x 64 | 65 | class WNO1d(nn.Module): 66 | def __init__(self, modes, width): 67 | super(WNO1d, self).__init__() 68 | 69 | """ 70 | The overall network. It contains 4 layers of the Wavelet layer. 71 | 1. Lift the input to the desire channel dimension by self.fc0 . 72 | 2. 4 layers of the integral operators u' = (W + K)(u). 73 | W defined by self.w; K defined by self.conv . 74 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
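        A minimal usage sketch (the argument values follow the settings defined later in
        this script; the shapes are indicative only):
            model = WNO1d(modes=14, width=64).to(device)
            u_next = model(u_window)   # u_window stacks the past time steps along the last axis
        The expected input/output layouts are described below.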
75 | 76 | input: the solution of the initial condition and location (a(x), x) 77 | input shape: (batchsize, x=s, c=2) 78 | output: the solution of a later timestep 79 | output shape: (batchsize, x=s, c=1) 80 | """ 81 | 82 | self.modes1 = modes 83 | self.width = width 84 | self.padding = 2 # pad the domain when required 85 | self.fc0 = nn.Linear(40, self.width) # input channel is 2: (a(x), x) 86 | 87 | self.conv0 = WaveConv1d(self.width, self.width, self.modes1) 88 | self.conv1 = WaveConv1d(self.width, self.width, self.modes1) 89 | self.conv2 = WaveConv1d(self.width, self.width, self.modes1) 90 | self.conv3 = WaveConv1d(self.width, self.width, self.modes1) 91 | self.w0 = nn.Conv1d(self.width, self.width, 1) 92 | self.w1 = nn.Conv1d(self.width, self.width, 1) 93 | self.w2 = nn.Conv1d(self.width, self.width, 1) 94 | self.w3 = nn.Conv1d(self.width, self.width, 1) 95 | 96 | self.fc1 = nn.Linear(self.width, 128) 97 | self.fc2 = nn.Linear(128, 1) 98 | 99 | def forward(self, x): 100 | grid = self.get_grid(x.shape, x.device) 101 | x = torch.cat((x, grid), dim=-1) 102 | x = self.fc0(x) 103 | x = x.permute(0, 2, 1) 104 | # x = F.pad(x, [0,self.padding]) 105 | 106 | x1 = self.conv0(x) 107 | x2 = self.w0(x) 108 | x = x1 + x2 109 | x = F.gelu(x) 110 | 111 | x1 = self.conv1(x) 112 | x2 = self.w1(x) 113 | x = x1 + x2 114 | x = F.gelu(x) 115 | 116 | x1 = self.conv2(x) 117 | x2 = self.w2(x) 118 | x = x1 + x2 119 | x = F.gelu(x) 120 | 121 | x1 = self.conv3(x) 122 | x2 = self.w3(x) 123 | x = x1 + x2 124 | 125 | # x = x[..., :-self.padding] 126 | x = x.permute(0, 2, 1) 127 | x = self.fc1(x) 128 | x = F.gelu(x) 129 | x = self.fc2(x) 130 | return x 131 | 132 | def get_grid(self, shape, device): 133 | batchsize, size_x = shape[0], shape[1] 134 | gridx = torch.tensor(np.linspace(0, 100, size_x), dtype=torch.float) 135 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 136 | return gridx.to(device) 137 | 138 | 139 | # %% 140 | 141 | ntrain = 1000 142 | ntest = 100 143 | s = 40 144 | 145 | batch_size = 25 # 50 146 | learning_rate = 0.001 147 | 148 | epochs = 500 149 | scheduler_step = 50 150 | scheduler_gamma = 0.5 151 | 152 | modes = 14 153 | width = 64 154 | T = 39 155 | step = 1 156 | 157 | # %% 158 | 159 | data = np.load('data/train_IC2.npz') 160 | x, t, u_train = data["x"], data["t"], data["u"] # N x nt x nx 161 | x_train = u_train[:ntrain, :-1, :] # N x nx 162 | y_train = u_train[:ntrain, 1:, :] # one step ahead, 163 | x_train = torch.from_numpy(x_train) 164 | y_train = torch.from_numpy(y_train) 165 | x_train = x_train.permute(0,2,1) 166 | y_train = y_train.permute(0,2,1) 167 | 168 | data = np.load('data/test_IC2.npz') 169 | x, t, u_test = data["x"], data["t"], data["u"] # N x nt x nx 170 | x_test = u_test[:ntest, :-1, :] # N x nx 171 | y_test = u_test[:ntest, 1:, :] # one step ahead, 172 | x_test = torch.from_numpy(x_test) 173 | y_test = torch.from_numpy(y_test) 174 | x_test = x_test.permute(0,2,1) 175 | y_test = y_test.permute(0,2,1) 176 | 177 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 178 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 179 | 180 | # %% 181 | 182 | # model 183 | model = torch.load('model/model_wno_1d_advection_III_time') 184 | print(count_params(model)) 185 | 186 | myloss = LpLoss(size_average=False) 187 | 188 | # %% 189 | pred0 = torch.zeros(y_test.shape) 190 | index = 0 191 | test_e = 
torch.zeros(y_test.shape) 192 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 193 | with torch.no_grad(): 194 | for xx, yy in test_loader: 195 | test_l2_step = 0 196 | test_l2_full = 0 197 | loss = 0 198 | mse = 0 199 | xx = xx.to(device) 200 | yy = yy.to(device) 201 | 202 | for t in range(0, T, step): 203 | y = yy[..., t:t + step] 204 | im = model(xx) 205 | loss += myloss(im.reshape(1, y.size()[-3], y.size()[-2]), y.reshape(1, y.size()[-3], y.size()[-2])) 206 | 207 | if t == 0: 208 | pred = im 209 | else: 210 | pred = torch.cat((pred, im), -1) 211 | 212 | xx = torch.cat((xx[..., step:], im), dim=-1) 213 | 214 | pred0[index] = pred 215 | test_l2_step += loss.item() 216 | test_l2_full += myloss(pred.reshape(1, -1), yy.reshape(1, -1)).item() 217 | mse += F.mse_loss(pred.reshape(1, -1), yy.reshape(1, -1), reduction='mean') 218 | test_e[index] = test_l2_step 219 | 220 | print(index, test_l2_step, test_l2_full, mse.cpu().numpy()) 221 | index = index + 1 222 | 223 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy() /(T/step), '%') 224 | 225 | # %% 226 | plt.rcParams["font.family"] = "serif" 227 | plt.rcParams['font.size'] = 16 228 | 229 | figure1 = plt.figure(figsize = (18, 14)) 230 | figure1.text(0.03,0.17,'\n Error', rotation=90, color='purple', fontsize=20) 231 | figure1.text(0.03,0.34,'\n Prediction', rotation=90, color='green', fontsize=20) 232 | figure1.text(0.03,0.57,'\n Truth', rotation=90, color='red', fontsize=20) 233 | figure1.text(0.03,0.75,'Initial \n Condition', rotation=90, color='b', fontsize=20) 234 | plt.subplots_adjust(wspace=0.7) 235 | index = 0 236 | for value in range(y_test.shape[0]): 237 | if value % 23 == 1 and value != 1: 238 | print(value) 239 | plt.subplot(4,4, index+1) 240 | plt.plot(np.linspace(0,1,39),x_test[value,0,:], linewidth=2, color='blue') 241 | plt.title('IC-{}'.format(index+1), color='b', fontsize=20, fontweight='bold') 242 | plt.xlabel('x', fontweight='bold'); plt.ylabel('u(x,0)', fontweight='bold'); 243 | plt.margins(0) 244 | ax = plt.gca(); 245 | ratio = 0.9 246 | x_left, x_right = ax.get_xlim() 247 | y_low, y_high = ax.get_ylim() 248 | ax.set_aspect(abs((x_right-x_left)/(y_low-y_high))*ratio) 249 | 250 | plt.subplot(4,4, index+1+4) 251 | plt.imshow(y_test[value,:,:], cmap='Spectral', extent=[0,1,0,1], interpolation='Gaussian') 252 | plt.xlabel('x', fontweight='bold'); plt.ylabel('t', fontweight='bold', color='m', fontsize=20); 253 | plt.colorbar(fraction=0.045) 254 | 255 | plt.subplot(4,4, index+1+8) 256 | plt.imshow(pred0[value,:,:], cmap='Spectral', extent=[0,1,0,1], interpolation='Gaussian') 257 | plt.xlabel('x', fontweight='bold'); plt.ylabel('t', fontweight='bold', color='m', fontsize=20); 258 | plt.colorbar(fraction=0.045) 259 | 260 | plt.subplot(4,4, index+1+12) 261 | plt.imshow(np.abs(y_test[value,:,:]-pred0[value,:,:]), cmap='jet', extent=[0,1,0,1], 262 | vmax= 1, interpolation='Gaussian') 263 | plt.xlabel('x', fontweight='bold'); plt.ylabel('t', fontweight='bold', color='m', fontsize=20); 264 | plt.colorbar(fraction=0.045) 265 | 266 | plt.margins(0) 267 | index = index + 1 268 | -------------------------------------------------------------------------------- /Version 1.0.0/WNO_testing_1d_Burgers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). 
Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for performing predictions on pre-trained models for 7 | 1-D Burger's equation (time-independent problem). 8 | """ 9 | 10 | import pywt 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import matplotlib.pyplot as plt 16 | 17 | from timeit import default_timer 18 | from utilities3 import * 19 | from pytorch_wavelets import DWT1D, IDWT1D 20 | 21 | torch.manual_seed(0) 22 | np.random.seed(0) 23 | 24 | # %% 25 | 26 | class WaveConv1d(nn.Module): 27 | def __init__(self, in_channels, out_channels, level1): 28 | super(WaveConv1d, self).__init__() 29 | 30 | """ 31 | 1D Wavelet layer. It does Wavelet Transform, linear transform, and 32 | Inverse Wavelet Transform. 33 | """ 34 | 35 | self.in_channels = in_channels 36 | self.out_channels = out_channels 37 | self.level1 = level1 38 | 39 | self.scale = (1 / (in_channels*out_channels)) 40 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.level1+6)) 41 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.level1+6)) 42 | 43 | # Complex multiplication 44 | def compl_mul1d(self, input, weights): 45 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 46 | return torch.einsum("bix,iox->box", input, weights) 47 | 48 | def forward(self, x): 49 | #Compute single tree Discrete Wavelet coefficients using some wavelet 50 | dwt = DWT1D(wave='db6', J=self.level1, mode='symmetric').to(device) 51 | x_ft, x_coeff = dwt(x) 52 | 53 | # Multiply the final low pass and high pass coefficients 54 | out_ft = self.compl_mul1d(x_ft, self.weights1) 55 | x_coeff[-1] = self.compl_mul1d(x_coeff[-1], self.weights2) 56 | 57 | idwt = IDWT1D(wave='db6', mode='symmetric').to(device) 58 | x = idwt((out_ft, x_coeff)) 59 | return x 60 | 61 | class WNO1d(nn.Module): 62 | def __init__(self, level, width): 63 | super(WNO1d, self).__init__() 64 | 65 | """ 66 | The overall network. It contains 4 layers of the Wavelet layer. 67 | 1. Lift the input to the desire channel dimension by self.fc0 . 68 | 2. 4 layers of the integral operators u' = (W + K)(u). 69 | W defined by self.w; K defined by self.conv . 70 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
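        A minimal usage sketch (the argument values follow the settings defined later in
        this script):
            model = WNO1d(level=8, width=64).to(device)
            u = model(a)   # a: (batchsize, 1024, 1) -> u: (batchsize, 1024, 1)
        The expected input/output layouts are described below.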
71 | 72 | input: the solution of the initial condition and location (a(x), x) 73 | input shape: (batchsize, x=s, c=2) 74 | output: the solution of a later timestep 75 | output shape: (batchsize, x=s, c=1) 76 | """ 77 | 78 | self.level1 = level 79 | self.width = width 80 | self.padding = 2 # pad the domain when required 81 | self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x) 82 | 83 | self.conv0 = WaveConv1d(self.width, self.width, self.level1) 84 | self.conv1 = WaveConv1d(self.width, self.width, self.level1) 85 | self.conv2 = WaveConv1d(self.width, self.width, self.level1) 86 | self.conv3 = WaveConv1d(self.width, self.width, self.level1) 87 | self.w0 = nn.Conv1d(self.width, self.width, 1) 88 | self.w1 = nn.Conv1d(self.width, self.width, 1) 89 | self.w2 = nn.Conv1d(self.width, self.width, 1) 90 | self.w3 = nn.Conv1d(self.width, self.width, 1) 91 | 92 | self.fc1 = nn.Linear(self.width, 128) 93 | self.fc2 = nn.Linear(128, 1) 94 | 95 | def forward(self, x): 96 | grid = self.get_grid(x.shape, x.device) 97 | x = torch.cat((x, grid), dim=-1) 98 | x = self.fc0(x) 99 | x = x.permute(0, 2, 1) 100 | # x = F.pad(x, [0,self.padding]) 101 | 102 | x1 = self.conv0(x) 103 | x2 = self.w0(x) 104 | x = x1 + x2 105 | x = F.gelu(x) 106 | 107 | x1 = self.conv1(x) 108 | x2 = self.w1(x) 109 | x = x1 + x2 110 | x = F.gelu(x) 111 | 112 | x1 = self.conv2(x) 113 | x2 = self.w2(x) 114 | x = x1 + x2 115 | x = F.gelu(x) 116 | 117 | x1 = self.conv3(x) 118 | x2 = self.w3(x) 119 | x = x1 + x2 120 | 121 | # x = x[..., :-self.padding] 122 | x = x.permute(0, 2, 1) 123 | x = self.fc1(x) 124 | x = F.gelu(x) 125 | x = self.fc2(x) 126 | return x 127 | 128 | def get_grid(self, shape, device): 129 | batchsize, size_x = shape[0], shape[1] 130 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 131 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 132 | return gridx.to(device) 133 | 134 | 135 | # %% 136 | 137 | ntrain = 1000 138 | ntest = 100 139 | 140 | sub = 2**3 #subsampling rate 141 | h = 2**13 // sub #total grid size divided by the subsampling rate 142 | s = h 143 | 144 | batch_size = 10 145 | learning_rate = 0.001 146 | 147 | epochs = 800 148 | step_size = 50 149 | gamma = 0.75 150 | 151 | level = 8 152 | width = 64 153 | 154 | # %% 155 | 156 | # Data is of the shape (number of samples, grid size) 157 | dataloader = MatReader('data/burgers_data_R10.mat') 158 | x_data = dataloader.read_field('a')[:,::sub] 159 | y_data = dataloader.read_field('u')[:,::sub] 160 | 161 | x_train = x_data[:ntrain,:] 162 | y_train = y_data[:ntrain,:] 163 | x_test = x_data[-ntest:,:] 164 | y_test = y_data[-ntest:,:] 165 | 166 | x_train = x_train.reshape(ntrain,s,1) 167 | x_test = x_test.reshape(ntest,s,1) 168 | 169 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), 170 | batch_size=batch_size, shuffle=True) 171 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), 172 | batch_size=batch_size, shuffle=False) 173 | 174 | # %% 175 | 176 | model = torch.load('model/model_wno_1d_burgers') 177 | print(count_params(model)) 178 | 179 | myloss = LpLoss(size_average=False) 180 | 181 | # %% 182 | pred = torch.zeros(y_test.shape) 183 | index = 0 184 | test_e = torch.zeros(y_test.shape) 185 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 186 | with torch.no_grad(): 187 | for x, y in test_loader: 188 | test_l2 = 0 189 | x, y = x.to(device), y.to(device) 190 | 191 | 
out = model(x).view(-1) 192 | pred[index] = out 193 | 194 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 195 | test_e[index] = test_l2 196 | print(index, test_l2) 197 | index = index + 1 198 | 199 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy(), '%') 200 | 201 | # %% 202 | 203 | plt.rcParams["font.family"] = "Times New Roman" 204 | plt.rcParams['font.size'] = 16 205 | 206 | figure1 = plt.figure(figsize = (12, 8)) 207 | plt.subplots_adjust(hspace=0.4) 208 | for i in range(y_test.shape[0]): 209 | if i % 23 == 1: 210 | plt.subplot(2,1,1) 211 | plt.plot(np.linspace(0,1,1024),x_test[i, :].numpy()) 212 | plt.title('(a) I.C.') 213 | plt.xlabel('x', fontsize=20, fontweight='bold') 214 | plt.ylabel('u(x,0)', fontsize=20, fontweight='bold') 215 | plt.grid(True) 216 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 217 | plt.margins(0) 218 | 219 | plt.subplot(2,1,2) 220 | plt.plot(np.linspace(0,1,1024),y_test[i, :].numpy()) 221 | plt.plot(np.linspace(0,1,1024),pred[i,:], ':k') 222 | plt.title('(b) Solution') 223 | plt.legend(['Truth', 'Prediction'], ncol=2, loc=3, fontsize=20) 224 | plt.xlabel('x', fontsize=20, fontweight='bold') 225 | plt.ylabel('u(x,1)', fontsize=20, fontweight='bold') 226 | plt.grid(True) 227 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 228 | plt.margins(0) 229 | -------------------------------------------------------------------------------- /Version 1.0.0/WNO_testing_2d_AC.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for performing predictions on pre-trained models for 7 | 2-D Allen Cahn equation (time-independent problem). 8 | """ 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn.parameter import Parameter 15 | import matplotlib.pyplot as plt 16 | 17 | from timeit import default_timer 18 | from utilities3 import * 19 | from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT) 20 | 21 | torch.manual_seed(0) 22 | np.random.seed(0) 23 | 24 | # %% 25 | 26 | class WaveConv2d(nn.Module): 27 | def __init__(self, in_channels, out_channels, modes1, modes2): 28 | super(WaveConv2d, self).__init__() 29 | 30 | """ 31 | 2D Wavelet layer. It does DWT, linear transform, and Inverse dWT. 
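        A usage sketch (hypothetical values; `modes1`/`modes2` must equal the spatial size of
        the approximation coefficients returned by the DWT used in the forward pass):
            layer = WaveConv2d(in_channels=64, out_channels=64, modes1=25, modes2=25)
            y = layer(x)   # x: (batch, 64, H, W)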
32 | """ 33 | 34 | self.in_channels = in_channels 35 | self.out_channels = out_channels 36 | self.modes1 = modes1 37 | self.modes2 = modes2 38 | 39 | self.scale = (1 / (in_channels * out_channels)) 40 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 41 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 42 | 43 | # Complex multiplication 44 | def compl_mul2d(self, input, weights): 45 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 46 | return torch.einsum("bixy,ioxy->boxy", input, weights) 47 | 48 | def forward(self, x): 49 | batchsize = x.shape[0] 50 | #Compute single tree Discrete Wavelet coefficients using some wavelet 51 | dwt = DWT(J=1, mode='symmetric', wave='db4').to(device) 52 | x_ft, x_coeff = dwt(x) 53 | 54 | # Multiply relevant Wavelet modes 55 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-2], x_ft.shape[-1], device=x.device) 56 | out_ft = self.compl_mul2d(x_ft, self.weights1) 57 | # Multiply the finer wavelet coefficients 58 | x_coeff[-1][:,:,0,:,:] = self.compl_mul2d(x_coeff[-1][:,:,0,:,:].clone(), self.weights2) 59 | 60 | # Return to physical space 61 | idwt = IDWT(mode='symmetric', wave='db4').to(device) 62 | x = idwt((out_ft, x_coeff)) 63 | return x 64 | 65 | # %% 66 | class FNO2d(nn.Module): 67 | def __init__(self, modes1, modes2, width): 68 | super(FNO2d, self).__init__() 69 | 70 | """ 71 | The overall network. It contains 4 layers of the Wavelet layer. 72 | 1. Lift the input to the desire channel dimension by self.fc0 . 73 | 2. 4 layers of the integral operators u' = (W + K)(u). 74 | W defined by self.w; K defined by self.conv . 75 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
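        A minimal usage sketch (the argument values follow the settings defined later in
        this script):
            model = FNO2d(modes1=25, modes2=25, width=64).to(device)
            u = model(a)   # a: (batchsize, s, s, 1) -> u: (batchsize, s, s, 1)
        The expected input/output layouts are described below.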
76 | 77 | input: the solution of the coefficient function and locations (a(x, y), x, y) 78 | input shape: (batchsize, x=s, y=s, c=3) 79 | output: the solution 80 | output shape: (batchsize, x=s, y=s, c=1) 81 | """ 82 | 83 | self.modes1 = modes1 84 | self.modes2 = modes2 85 | self.width = width 86 | self.padding = 1 # pad the domain when required 87 | self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y) 88 | 89 | self.conv0 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 90 | self.conv1 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 91 | self.conv2 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 92 | self.conv3 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 93 | self.w0 = nn.Conv2d(self.width, self.width, 1) 94 | self.w1 = nn.Conv2d(self.width, self.width, 1) 95 | self.w2 = nn.Conv2d(self.width, self.width, 1) 96 | self.w3 = nn.Conv2d(self.width, self.width, 1) 97 | 98 | self.fc1 = nn.Linear(self.width, 128) 99 | self.fc2 = nn.Linear(128, 1) 100 | 101 | def forward(self, x): 102 | grid = self.get_grid(x.shape, x.device) 103 | x = torch.cat((x, grid), dim=-1) 104 | 105 | x = self.fc0(x) 106 | x = x.permute(0, 3, 1, 2) 107 | x = F.pad(x, [0,self.padding, 0,self.padding]) 108 | 109 | x1 = self.conv0(x) 110 | x2 = self.w0(x) 111 | x = x1 + x2 112 | x = F.gelu(x) 113 | 114 | x1 = self.conv1(x) 115 | x2 = self.w1(x) 116 | x = x1 + x2 117 | x = F.gelu(x) 118 | 119 | x1 = self.conv2(x) 120 | x2 = self.w2(x) 121 | x = x1 + x2 122 | x = F.gelu(x) 123 | 124 | x1 = self.conv3(x) 125 | x2 = self.w3(x) 126 | x = x1 + x2 127 | 128 | x = x[..., :-self.padding, :-self.padding] 129 | x = x.permute(0, 2, 3, 1) 130 | x = self.fc1(x) 131 | x = F.gelu(x) 132 | x = self.fc2(x) 133 | return x 134 | 135 | def get_grid(self, shape, device): 136 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 137 | gridx = torch.tensor(np.linspace(0, 2*torch.pi, size_x), dtype=torch.float) 138 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 139 | gridy = torch.tensor(np.linspace(0, 2*torch.pi, size_y), dtype=torch.float) 140 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 141 | return torch.cat((gridx, gridy), dim=-1).to(device) 142 | 143 | # %% 144 | 145 | PATH = 'data/Allen_cahn_2d_128_128_T10.mat' 146 | 147 | ntrain = 1400 148 | ntest = 100 149 | 150 | batch_size = 20 151 | learning_rate = 0.001 152 | 153 | epochs = 500 154 | step_size = 25 155 | gamma = 0.75 156 | 157 | modes = 25 158 | width = 64 159 | 160 | r = 3 161 | h = int(((129 - 1)/r) + 1) 162 | s = h 163 | 164 | # %% 165 | 166 | reader = MatReader(PATH) 167 | x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s] 168 | y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s] 169 | 170 | x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s] 171 | y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s] 172 | 173 | x_normalizer = UnitGaussianNormalizer(x_train) 174 | x_train = x_normalizer.encode(x_train) 175 | x_test = x_normalizer.encode(x_test) 176 | 177 | y_normalizer = UnitGaussianNormalizer(y_train) 178 | y_train = y_normalizer.encode(y_train) 179 | 180 | x_train = x_train.reshape(ntrain,s,s,1) 181 | x_test = x_test.reshape(ntest,s,s,1) 182 | 183 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 184 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, 
shuffle=False) 185 | 186 | # %% 187 | 188 | model = torch.load('model/model_wno_2d_AC') 189 | print(count_params(model)) 190 | 191 | myloss = LpLoss(size_average=False) 192 | y_normalizer.cuda() 193 | 194 | # %% 195 | pred = torch.zeros(y_test.shape) 196 | index = 0 197 | test_e = torch.zeros(y_test.shape) 198 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 199 | with torch.no_grad(): 200 | for x, y in test_loader: 201 | test_l2 = 0 202 | x, y = x.to(device), y.to(device) 203 | 204 | out = model(x).reshape(s, s) 205 | out = y_normalizer.decode(out) 206 | pred[index] = out 207 | 208 | test_l2 += myloss(out.reshape(1, s, s), y.reshape(1, s, s)).item() 209 | test_e[index] = test_l2 210 | print(index, test_l2) 211 | index = index + 1 212 | 213 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy(), '%') 214 | 215 | # %% 216 | plt.rcParams["font.family"] = "Serif" 217 | plt.rcParams['font.size'] = 14 218 | 219 | figure1 = plt.figure(figsize = (18, 14)) 220 | figure1.text(0.04,0.17,'\n Error', rotation=90, color='purple', fontsize=20) 221 | figure1.text(0.04,0.34,'\n Prediction \n u(x,y,10)', rotation=90, color='green', fontsize=20) 222 | figure1.text(0.04,0.57,'\n Truth \n u(x,y,10)', rotation=90, color='red', fontsize=20) 223 | figure1.text(0.04,0.75,'Initial \n Condition \n u$_0$(x,y)', rotation=90, color='b', fontsize=20) 224 | plt.subplots_adjust(wspace=0.7) 225 | index = 0 226 | for value in range(y_test.shape[0]): 227 | if value % 25 == 1: 228 | plt.subplot(4,4, index+1) 229 | plt.imshow(x_test[value,:,:,0], cmap='turbo', extent=[0,3,0,3], interpolation='Gaussian') 230 | plt.title('IC-{}'.format(index+1), color='b', fontsize=18, fontweight='bold') 231 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 232 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 233 | 234 | plt.subplot(4,4, index+1+4) 235 | plt.imshow(y_test[value,:,:], cmap='turbo', extent=[0,3,0,3], interpolation='Gaussian') 236 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 237 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 238 | plt.colorbar(fraction=0.045) 239 | 240 | plt.subplot(4,4, index+1+8) 241 | plt.imshow(pred[value,:,:], cmap='turbo', extent=[0,3,0,3], interpolation='Gaussian') 242 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 243 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 244 | plt.colorbar(fraction=0.045) 245 | 246 | plt.subplot(4,4, index+1+12) 247 | plt.imshow(np.abs(y_test[value,:,:]-pred[value,:,:]), cmap='jet', extent=[0,3,0,3], interpolation='Gaussian') 248 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 249 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 250 | plt.colorbar(fraction=0.045,format='%.0e') 251 | 252 | plt.margins(0) 253 | index = index + 1 254 | -------------------------------------------------------------------------------- /Version 1.0.0/WNO_testing_2d_Darcy_r.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for performing predictions on pre-trained models for 7 | 2-D Darcy equation in rectangular domain (time-independent problem). 
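-- Note: the model output lives in the normalized space; in the testing loop further
   below it is decoded before the relative L2 error is computed, e.g.
       out = y_normalizer.decode(model(x).reshape(s, s))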
8 | """ 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn.parameter import Parameter 15 | import matplotlib.pyplot as plt 16 | 17 | from timeit import default_timer 18 | from utilities3 import * 19 | from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT) 20 | 21 | torch.manual_seed(0) 22 | np.random.seed(0) 23 | 24 | # %% 25 | 26 | class WaveConv2d(nn.Module): 27 | def __init__(self, in_channels, out_channels, modes1, modes2): 28 | super(WaveConv2d, self).__init__() 29 | 30 | """ 31 | 2D Wavelet layer. It does DWT, linear transform, and Inverse dWT. 32 | """ 33 | 34 | self.in_channels = in_channels 35 | self.out_channels = out_channels 36 | self.modes1 = modes1 37 | self.modes2 = modes2 38 | 39 | self.scale = (1 / (in_channels * out_channels)) 40 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 41 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 42 | 43 | # Complex multiplication 44 | def compl_mul2d(self, input, weights): 45 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 46 | return torch.einsum("bixy,ioxy->boxy", input, weights) 47 | 48 | def forward(self, x): 49 | batchsize = x.shape[0] 50 | #Compute single tree Discrete Wavelet coefficients using some wavelet 51 | dwt = DWT(J=4, mode='symmetric', wave='db4').to(device) 52 | x_ft, x_coeff = dwt(x) 53 | 54 | # Multiply relevant Wavelet modes 55 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-2], x_ft.shape[-1], device=x.device) 56 | out_ft = self.compl_mul2d(x_ft, self.weights1) 57 | # Multiply the finer wavelet coefficients 58 | x_coeff[-1][:,:,0,:,:] = self.compl_mul2d(x_coeff[-1][:,:,0,:,:].clone(), self.weights2) 59 | 60 | # Return to physical space 61 | idwt = IDWT(mode='symmetric', wave='db4').to(device) 62 | x = idwt((out_ft, x_coeff)) 63 | return x 64 | 65 | # %% 66 | class WNO2d(nn.Module): 67 | def __init__(self, modes1, modes2, width): 68 | super(WNO2d, self).__init__() 69 | 70 | """ 71 | The overall network. It contains 4 layers of the Wavelet layer. 72 | 1. Lift the input to the desire channel dimension by self.fc0 . 73 | 2. 4 layers of the integral operators u' = (W + K)(u). 74 | W defined by self.w; K defined by self.conv . 75 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
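        A minimal usage sketch (the argument values follow the settings defined later in
        this script):
            model = WNO2d(modes1=11, modes2=11, width=64).to(device)
            u = model(a)   # a: (batchsize, s, s, 1) -> u: (batchsize, s, s, 1)
        The expected input/output layouts are described below.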
76 | 77 | input: the solution of the coefficient function and locations (a(x, y), x, y) 78 | input shape: (batchsize, x=s, y=s, c=3) 79 | output: the solution 80 | output shape: (batchsize, x=s, y=s, c=1) 81 | """ 82 | 83 | self.modes1 = modes1 84 | self.modes2 = modes2 85 | self.width = width 86 | self.padding = 1 # pad the domain when required 87 | self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y) 88 | 89 | self.conv0 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 90 | self.conv1 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 91 | self.conv2 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 92 | self.conv3 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 93 | self.w0 = nn.Conv2d(self.width, self.width, 1) 94 | self.w1 = nn.Conv2d(self.width, self.width, 1) 95 | self.w2 = nn.Conv2d(self.width, self.width, 1) 96 | self.w3 = nn.Conv2d(self.width, self.width, 1) 97 | 98 | self.fc1 = nn.Linear(self.width, 128) 99 | self.fc2 = nn.Linear(128, 1) 100 | 101 | def forward(self, x): 102 | grid = self.get_grid(x.shape, x.device) 103 | x = torch.cat((x, grid), dim=-1) 104 | 105 | x = self.fc0(x) 106 | x = x.permute(0, 3, 1, 2) 107 | x = F.pad(x, [0,self.padding, 0,self.padding]) 108 | 109 | x1 = self.conv0(x) 110 | x2 = self.w0(x) 111 | x = x1 + x2 112 | x = F.gelu(x) 113 | 114 | x1 = self.conv1(x) 115 | x2 = self.w1(x) 116 | x = x1 + x2 117 | x = F.gelu(x) 118 | 119 | x1 = self.conv2(x) 120 | x2 = self.w2(x) 121 | x = x1 + x2 122 | x = F.gelu(x) 123 | 124 | x1 = self.conv3(x) 125 | x2 = self.w3(x) 126 | x = x1 + x2 127 | 128 | x = x[..., :-self.padding, :-self.padding] 129 | x = x.permute(0, 2, 3, 1) 130 | x = self.fc1(x) 131 | x = F.gelu(x) 132 | x = self.fc2(x) 133 | return x 134 | 135 | def get_grid(self, shape, device): 136 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 137 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 138 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 139 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 140 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 141 | return torch.cat((gridx, gridy), dim=-1).to(device) 142 | 143 | # %% 144 | 145 | TRAIN_PATH = 'data/piececonst_r421_N1024_smooth1.mat' 146 | TEST_PATH = 'data/piececonst_r421_N1024_smooth2.mat' 147 | 148 | ntrain = 1000 149 | ntest = 100 150 | 151 | batch_size = 20 152 | learning_rate = 0.001 153 | 154 | epochs = 800 155 | step_size = 50 156 | gamma = 0.75 157 | 158 | modes = 11 159 | width = 64 160 | 161 | r = 5 162 | h = int(((421 - 1)/r) + 1) 163 | s = h 164 | 165 | # %% 166 | 167 | reader = MatReader(TRAIN_PATH) 168 | x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s] 169 | y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s] 170 | 171 | reader.load_file(TEST_PATH) 172 | x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s] 173 | y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s] 174 | 175 | x_normalizer = UnitGaussianNormalizer(x_train) 176 | x_train = x_normalizer.encode(x_train) 177 | x_test = x_normalizer.encode(x_test) 178 | 179 | y_normalizer = UnitGaussianNormalizer(y_train) 180 | y_train = y_normalizer.encode(y_train) 181 | 182 | x_train = x_train.reshape(ntrain,s,s,1) 183 | x_test = x_test.reshape(ntest,s,s,1) 184 | 185 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 186 | test_loader = 
torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 187 | 188 | # %% 189 | 190 | model = torch.load('model/model_wno_2d_darcy_rect') 191 | 192 | print(count_params(model)) 193 | 194 | myloss = LpLoss(size_average=False) 195 | y_normalizer.cuda() 196 | 197 | # %% 198 | pred = torch.zeros(y_test.shape) 199 | index = 0 200 | test_e = torch.zeros(y_test.shape[0]) 201 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 202 | with torch.no_grad(): 203 | for x, y in test_loader: 204 | test_l2 = 0 205 | x, y = x.to(device), y.to(device) 206 | 207 | out = model(x).reshape(s, s) 208 | out = y_normalizer.decode(out) 209 | pred[index] = out 210 | 211 | test_l2 += myloss(out.reshape(1, s, s), y.reshape(1, s, s)).item() 212 | test_e[index] = test_l2 213 | print(index, test_l2) 214 | index = index + 1 215 | 216 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy(), '%') 217 | 218 | # %% 219 | plt.rcParams["font.family"] = "serif" 220 | plt.rcParams['font.size'] = 14 221 | 222 | figure1 = plt.figure(figsize = (18, 14)) 223 | figure1.text(0.04,0.17,'\n Error', rotation=90, color='purple', fontsize=20) 224 | figure1.text(0.04,0.34,'\n Prediction', rotation=90, color='green', fontsize=20) 225 | figure1.text(0.04,0.57,'\n Truth', rotation=90, color='red', fontsize=20) 226 | figure1.text(0.04,0.75,'Permeability \n field', rotation=90, color='b', fontsize=20) 227 | plt.subplots_adjust(wspace=0.7) 228 | index = 0 229 | for value in range(y_test.shape[0]): 230 | if value % 26 == 1: 231 | plt.subplot(4,4, index+1) 232 | plt.imshow(x_test[value,:,:,0], cmap='rainbow', extent=[0,1,0,1], interpolation='Gaussian') 233 | plt.title('a(x,y)-{}'.format(index+1), color='b', fontsize=20, fontweight='bold') 234 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 235 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 236 | 237 | plt.subplot(4,4, index+1+4) 238 | plt.imshow(y_test[value,:,:], cmap='rainbow', extent=[0,1,0,1], interpolation='Gaussian') 239 | plt.colorbar(fraction=0.045) 240 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 241 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 242 | 243 | plt.subplot(4,4, index+1+8) 244 | plt.imshow(pred[value,:,:], cmap='rainbow', extent=[0,1,0,1], interpolation='Gaussian') 245 | plt.colorbar(fraction=0.045) 246 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 247 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 248 | 249 | plt.subplot(4,4, index+1+12) 250 | plt.imshow(np.abs(pred[value,:,:]-y_test[value,:,:]), cmap='jet', extent=[0,1,0,1], interpolation='Gaussian') 251 | plt.xlabel('x', fontweight='bold'); plt.ylabel('y', fontweight='bold'); 252 | plt.colorbar(fraction=0.045,format='%.0e') 253 | 254 | plt.margins(0) 255 | index = index + 1 256 | -------------------------------------------------------------------------------- /Version 1.0.0/WNO_testing_2d_ERA5.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for performing predictions on pre-trained models for 7 | forecast of monthly averaged 2m air temperature (time-independent problem). 
8 | """ 9 | 10 | from IPython import get_ipython 11 | get_ipython().magic('reset -sf') 12 | 13 | # %% 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from torch.nn.parameter import Parameter 19 | import matplotlib.pyplot as plt 20 | 21 | import xarray as xr 22 | from timeit import default_timer 23 | from utilities3 import * 24 | from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT) 25 | 26 | torch.manual_seed(0) 27 | np.random.seed(0) 28 | 29 | # %% 30 | 31 | class WaveConv2d(nn.Module): 32 | def __init__(self, in_channels, out_channels, modes1, modes2): 33 | super(WaveConv2d, self).__init__() 34 | 35 | """ 36 | 2D Wavelet layer. It does DWT, linear transform, and Inverse dWT. 37 | """ 38 | 39 | self.in_channels = in_channels 40 | self.out_channels = out_channels 41 | self.modes1 = modes1 42 | self.modes2 = modes2 43 | 44 | self.scale = (1 / (in_channels * out_channels)) 45 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 46 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 47 | 48 | # Complex multiplication 49 | def compl_mul2d(self, input, weights): 50 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 51 | return torch.einsum("bixy,ioxy->boxy", input, weights) 52 | 53 | def forward(self, x): 54 | batchsize = x.shape[0] 55 | #Compute single tree Discrete Wavelet coefficients using some wavelet 56 | dwt = DWT(J=5, mode='symmetric', wave='db4').to(device) 57 | x_ft, x_coeff = dwt(x) 58 | 59 | # Multiply relevant Wavelet modes 60 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-2], x_ft.shape[-1], device=x.device) 61 | out_ft = self.compl_mul2d(x_ft, self.weights1) 62 | # Multiply the finer wavelet coefficients 63 | x_coeff[-1][:,:,0,:,:] = self.compl_mul2d(x_coeff[-1][:,:,0,:,:].clone(), self.weights2) 64 | 65 | # Return to physical space 66 | idwt = IDWT(mode='symmetric', wave='db4').to(device) 67 | x = idwt((out_ft, x_coeff)) 68 | return x 69 | 70 | # %% 71 | class WNO2d(nn.Module): 72 | def __init__(self, modes1, modes2, width): 73 | super(WNO2d, self).__init__() 74 | 75 | """ 76 | The overall network. It contains 4 layers of the Wavelet layer. 77 | 1. Lift the input to the desire channel dimension by self.fc0 . 78 | 2. 4 layers of the integral operators u' = (W + K)(u). 79 | W defined by self.w; K defined by self.conv . 80 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
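        A minimal usage sketch (the argument values follow the settings defined later in
        this script):
            model = WNO2d(modes1=10, modes2=14, width=64).to(device)
            t2m = model(a)   # a: (batchsize, h, s, 1) -> t2m: (batchsize, h, s, 1)
        The expected input/output layouts are described below.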
81 | 82 | input: the solution of the coefficient function and locations (a(x, y), x, y) 83 | input shape: (batchsize, x=s, y=s, c=3) 84 | output: the solution 85 | output shape: (batchsize, x=s, y=s, c=1) 86 | """ 87 | 88 | self.modes1 = modes1 89 | self.modes2 = modes2 90 | self.width = width 91 | self.padding = 1 # pad the domain when required 92 | self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y) 93 | 94 | self.conv0 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 95 | self.conv1 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 96 | self.conv2 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 97 | self.conv3 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 98 | self.w0 = nn.Conv2d(self.width, self.width, 1) 99 | self.w1 = nn.Conv2d(self.width, self.width, 1) 100 | self.w2 = nn.Conv2d(self.width, self.width, 1) 101 | self.w3 = nn.Conv2d(self.width, self.width, 1) 102 | 103 | self.fc1 = nn.Linear(self.width, 128) 104 | self.fc2 = nn.Linear(128, 1) 105 | 106 | def forward(self, x): 107 | grid = self.get_grid(x.shape, x.device) 108 | x = torch.cat((x, grid), dim=-1) 109 | 110 | x = self.fc0(x) 111 | x = x.permute(0, 3, 1, 2) 112 | x = F.pad(x, [0, self.padding, 0, self.padding]) 113 | 114 | x1 = self.conv0(x) 115 | x2 = self.w0(x) 116 | x = x1 + x2 117 | x = F.gelu(x) 118 | 119 | x1 = self.conv1(x) 120 | x2 = self.w1(x) 121 | x = x1 + x2 122 | x = F.gelu(x) 123 | 124 | x1 = self.conv2(x) 125 | x2 = self.w2(x) 126 | x = x1 + x2 127 | x = F.gelu(x) 128 | 129 | x1 = self.conv3(x) 130 | x2 = self.w3(x) 131 | x = x1 + x2 132 | 133 | x = x[..., :-self.padding, :-self.padding] 134 | x = x.permute(0, 2, 3, 1) 135 | x = self.fc1(x) 136 | x = F.gelu(x) 137 | x = self.fc2(x) 138 | return x 139 | 140 | def get_grid(self, shape, device): 141 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 142 | gridx = torch.tensor(np.linspace(0, 360, size_x), dtype=torch.float) # latitudes 143 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 144 | gridy = torch.tensor(np.linspace(90, -90, size_y), dtype=torch.float) # longitudes 145 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 146 | return torch.cat((gridx, gridy), dim=-1).to(device) 147 | 148 | # %% 149 | 150 | PATH = 'data/ERA5_temp.grib' 151 | 152 | ntrain = 460 153 | ntest = 50 154 | 155 | batch_size = 20 156 | learning_rate = 0.001 157 | 158 | epochs = 500 159 | step_size = 25 160 | gamma = 0.75 161 | 162 | modes1 = 10 163 | modes2 = 14 164 | width = 64 165 | 166 | r = 6 167 | h = int((721 - 1)/r+1) 168 | s = int((1441 - 1)/r+1) 169 | 170 | # %% 171 | 172 | ds = xr.open_dataset(PATH, engine='cfgrib') 173 | data = np.array(ds["t2m"]) 174 | data = torch.tensor(data) 175 | # data = data[:, :720, :] 176 | data = F.pad(data, [0,1]) # pad last dimension to make it periodic 177 | 178 | x_train = data[:ntrain, ::r, ::r] 179 | y_train = data[:ntrain, ::r, ::r] 180 | 181 | x_test = data[-ntest:, ::r, ::r] 182 | y_test = data[-ntest:, ::r, ::r] 183 | 184 | x_normalizer = UnitGaussianNormalizer(x_train) 185 | x_train = x_normalizer.encode(x_train) 186 | x_test = x_normalizer.encode(x_test) 187 | 188 | y_normalizer = UnitGaussianNormalizer(y_train) 189 | y_train = y_normalizer.encode(y_train) 190 | 191 | x_train = x_train.reshape(ntrain,h,s,1) 192 | x_test = x_test.reshape(ntest,h,s,1) 193 | 194 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 195 | 
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 196 | 197 | # %% 198 | 199 | model = torch.load('model/model_wno_2d_ERA5_t2m') 200 | print(count_params(model)) 201 | myloss = LpLoss(size_average=False) 202 | y_normalizer.cuda() 203 | 204 | # %% 205 | pred = torch.zeros(y_test.shape) 206 | index = 0 207 | test_e = torch.zeros(y_test.shape[0]) 208 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 209 | with torch.no_grad(): 210 | for x, y in test_loader: 211 | test_l2 = 0 212 | x, y = x.to(device), y.to(device) 213 | 214 | out = model(x).reshape(h, s) 215 | out = y_normalizer.decode(out) 216 | pred[index] = out 217 | 218 | test_l2 += myloss(out.reshape(1, h, s), y.reshape(1, h, s)).item() 219 | test_e[index] = test_l2 220 | print(index, test_l2) 221 | index = index + 1 222 | 223 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy(), '%') 224 | 225 | # %% 226 | plt.rcParams["font.family"] = "serif" 227 | plt.rcParams['font.size'] = 12 228 | 229 | figure1 = plt.figure(figsize = (12, 13)) 230 | plt.subplots_adjust(hspace=0.01, wspace=0.25) 231 | index = 1 232 | for value in range(y_test.shape[0]): 233 | if value % 22 == 1 and value != 1: 234 | ### 235 | img = y_test[value,:,:-1].cpu().numpy() 236 | plt.subplot(3,2, index) 237 | plt.imshow(img, cmap='nipy_spectral', extent=[0,360,-90,+90]) 238 | plt.xlabel('Longitude ($^{\circ}$)'); plt.ylabel('Lattitude ($^{\circ}$)') 239 | plt.grid(True) 240 | if index==1: 241 | plt.title('Truth: Feb 2019, 1st'); 242 | else: 243 | plt.title('Truth: Feb 2021, 1st') 244 | 245 | ### 246 | plt.subplot(3,2, index+2) 247 | plt.imshow(pred[value,:,:-1], cmap='nipy_spectral', extent=[0,360,-90,+90]) 248 | plt.xlabel('Longitude ($^{\circ}$)'); plt.ylabel('Lattitude ($^{\circ}$)') 249 | plt.grid(True) 250 | if index==1: 251 | plt.title('Identification - Feb 2019, 1st \n (Error-{:0.4f})'.format(100*test_e[value].numpy())) 252 | else: 253 | plt.title('Identification - Feb 2021, 1st \n (Error-{:0.4f})'.format(100*test_e[value].numpy())) 254 | 255 | index = index + 1 256 | -------------------------------------------------------------------------------- /Version 1.0.0/WNO_testing_2d_ERA5_time.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for performing predictions on pre-trained models for 7 | weekly forecast of 2m air temperature (time-dependent problem). 8 | """ 9 | 10 | from IPython import get_ipython 11 | get_ipython().magic('reset -sf') 12 | 13 | import torch 14 | import numpy as np 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | import matplotlib.pyplot as plt 19 | from utilities3 import * 20 | 21 | import xarray as xr 22 | from timeit import default_timer 23 | from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT) 24 | 25 | torch.manual_seed(0) 26 | np.random.seed(0) 27 | 28 | # %% 29 | 30 | class WaveConv2d(nn.Module): 31 | def __init__(self, in_channels, out_channels, modes1, modes2): 32 | super(WaveConv2d, self).__init__() 33 | 34 | """ 35 | 2D Wavelet layer. It does DWT, linear transform, and Inverse dWT. 
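To make the data flow of this layer easier to follow, the sketch below (an illustration assuming the pytorch_wavelets conventions, not code from this repository) shows what the forward pass manipulates: the DWT returns one approximation tensor plus a list of per-level detail tensors with a 3-long orientation axis; only the approximation and one detail band at the final level are re-weighted, and everything else passes to the inverse transform unchanged.

```python
# Illustrative sketch of the DWT/IDWT round trip used by this layer.
import torch
from pytorch_wavelets import DWT, IDWT

x = torch.randn(2, 4, 64, 64)                          # (batch, channels, H, W)
dwt = DWT(J=2, mode='symmetric', wave='db4')
idwt = IDWT(mode='symmetric', wave='db4')
approx, details = dwt(x)                               # approx: (2, 4, h, w); details: J tensors of shape (2, 4, 3, h_j, w_j)
recon = idwt((approx, details))                        # reconstruction without any re-weighting
print(recon.shape, (recon - x).abs().max())            # same shape here (64 is divisible by 2**J); error near float precision
```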
36 | """ 37 | 38 | self.in_channels = in_channels 39 | self.out_channels = out_channels 40 | self.modes1 = modes1 41 | self.modes2 = modes2 42 | 43 | self.scale = (1 / (in_channels * out_channels)) 44 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 45 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 46 | 47 | # Complex multiplication 48 | def compl_mul2d(self, input, weights): 49 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 50 | return torch.einsum("bixy,ioxy->boxy", input, weights) 51 | 52 | def forward(self, x): 53 | batchsize = x.shape[0] 54 | #Compute single tree Discrete Wavelet coefficients using some wavelet 55 | dwt = DWT(J=2, mode='symmetric', wave='db4').to(device) 56 | x_ft, x_coeff = dwt(x) 57 | 58 | # Multiply relevant Wavelet modes 59 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-2], x_ft.shape[-1], device=x.device) 60 | out_ft = self.compl_mul2d(x_ft, self.weights1) 61 | # Multiply the finer wavelet coefficients 62 | x_coeff[-1][:,:,0,:,:] = self.compl_mul2d(x_coeff[-1][:,:,0,:,:].clone(), self.weights2) 63 | 64 | # Return to physical space 65 | idwt = IDWT(mode='symmetric', wave='db4').to(device) 66 | x = idwt((out_ft, x_coeff)) 67 | return x 68 | 69 | # %% 70 | class WNO2d(nn.Module): 71 | def __init__(self, modes1, modes2, width): 72 | super(WNO2d, self).__init__() 73 | 74 | """ 75 | The overall network. It contains 4 layers of the Wavelet layer. 76 | 1. Lift the input to the desire channel dimension by self.fc0 . 77 | 2. 4 layers of the integral operators u' = (W + K)(u). 78 | W defined by self.w; K defined by self.conv . 79 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
80 | 81 | input: the solution of the coefficient function and locations (a(x, y), x, y) 82 | input shape: (batchsize, x=s, y=s, c=3) 83 | output: the solution 84 | output shape: (batchsize, x=s, y=s, c=1) 85 | """ 86 | 87 | self.modes1 = modes1 88 | self.modes2 = modes2 89 | self.width = width 90 | self.padding = 2 # pad the domain when required 91 | self.fc0 = nn.Linear(9, self.width) # input channel is 3: (a(x, y), x, y) 92 | 93 | self.conv0 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 94 | self.conv1 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 95 | self.conv2 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 96 | self.conv3 = WaveConv2d(self.width, self.width, self.modes1, self.modes2) 97 | self.w0 = nn.Conv2d(self.width, self.width, 1) 98 | self.w1 = nn.Conv2d(self.width, self.width, 1) 99 | self.w2 = nn.Conv2d(self.width, self.width, 1) 100 | self.w3 = nn.Conv2d(self.width, self.width, 1) 101 | 102 | self.fc1 = nn.Linear(self.width, 128) 103 | self.fc2 = nn.Linear(128, 1) 104 | 105 | def forward(self, x): 106 | grid = self.get_grid(x.shape, x.device) 107 | x = torch.cat((x, grid), dim=-1) 108 | 109 | x = self.fc0(x) 110 | x = x.permute(0, 3, 1, 2) 111 | x = F.pad(x, [0, self.padding, 0, self.padding]) 112 | 113 | x1 = self.conv0(x) 114 | x2 = self.w0(x) 115 | x = x1 + x2 116 | x = F.gelu(x) 117 | 118 | x1 = self.conv1(x) 119 | x2 = self.w1(x) 120 | x = x1 + x2 121 | x = F.gelu(x) 122 | 123 | x1 = self.conv2(x) 124 | x2 = self.w2(x) 125 | x = x1 + x2 126 | x = F.gelu(x) 127 | 128 | x1 = self.conv3(x) 129 | x2 = self.w3(x) 130 | x = x1 + x2 131 | 132 | x = x[..., :-self.padding, :-self.padding] 133 | x = x.permute(0, 2, 3, 1) 134 | x = self.fc1(x) 135 | x = F.gelu(x) 136 | x = self.fc2(x) 137 | return x 138 | 139 | def get_grid(self, shape, device): 140 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 141 | gridx = torch.tensor(np.linspace(0, 360, size_x), dtype=torch.float) # latitudes 142 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 143 | gridy = torch.tensor(np.linspace(90, -90, size_y), dtype=torch.float) # longitudes 144 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 145 | return torch.cat((gridx, gridy), dim=-1).to(device) 146 | 147 | # %% 148 | 149 | PATH = 'data/ERA5_day_5years.grib' 150 | 151 | ntrain = 270 152 | ntest = 6 153 | 154 | batch_size = 3 155 | learning_rate = 0.001 156 | 157 | epochs = 500 158 | step_size = 50 159 | gamma = 0.75 160 | 161 | modes1 = 17 162 | modes2 = 28 163 | width = 20 164 | 165 | r = 2**4 166 | h = int(((721 - 1)/r)) 167 | s = int(((1441 - 1)/r)) 168 | 169 | T = 7 170 | step = 1 171 | 172 | # %% 173 | 174 | ds = xr.open_dataset(PATH, engine='cfgrib') 175 | data = np.array(ds["t2m"]) 176 | data = torch.tensor(data) 177 | # data = data[:,:720,:] 178 | 179 | Tn = 7*int(1937/7) 180 | x_data = data[:-1, :, :] 181 | y_data = data[1:, :, :] 182 | 183 | x_data = x_data[:Tn, :, :] 184 | y_data = y_data[:Tn, :, :] 185 | 186 | x_data = x_data.reshape(1932,721,1440,1) 187 | x_data = list(torch.split(x_data, int(1932/7), dim=0)) 188 | x_data = torch.cat((x_data), dim=3) 189 | 190 | y_data = y_data.reshape(1932,721,1440,1) 191 | y_data = list(torch.split(y_data, int(1932/7), dim=0)) 192 | y_data = torch.cat((y_data), dim=3) 193 | 194 | # %% 195 | x_train = x_data[:ntrain, ::r, ::r] 196 | y_train = y_data[:ntrain, ::r, ::r] 197 | 198 | x_test = y_data[-ntest:, ::r, ::r] 199 | y_test = y_data[-ntest:, ::r, ::r] 200 | 201 | train_loader = 
torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 202 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 203 | 204 | # %% 205 | 206 | model = torch.load('model/model_wno_2d_ERA5_time') 207 | print(count_params(model)) 208 | 209 | myloss = LpLoss(size_average=False) 210 | 211 | # %% 212 | pred0 = torch.zeros(y_test.shape) 213 | index = 0 214 | test_e = torch.zeros(y_test.shape[0]) 215 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 216 | error = torch.zeros(y_test.shape[0],T) 217 | with torch.no_grad(): 218 | for xx, yy in test_loader: 219 | test_l2_step = 0 220 | test_l2_full = 0 221 | loss = 0 222 | xx = xx.to(device) 223 | yy = yy.to(device) 224 | 225 | for t in range(0, T, step): 226 | y = yy[..., t:t + step] 227 | im = model(xx) 228 | loss += myloss(im.reshape(1, y.size()[-3], y.size()[-2]), y.reshape(1, y.size()[-3], y.size()[-2])) 229 | error[index, t] = myloss(im.reshape(1, y.size()[-3], y.size()[-2]), y.reshape(1, y.size()[-3], y.size()[-2])) 230 | if t == 0: 231 | pred = im 232 | else: 233 | pred = torch.cat((pred, im), -1) 234 | 235 | xx = torch.cat((xx[..., step:], im), dim=-1) 236 | 237 | pred0[index] = pred 238 | test_l2_step += loss.item() 239 | test_l2_full += myloss(pred.reshape(1, -1), yy.reshape(1, -1)).item() 240 | test_e[index] = test_l2_step 241 | 242 | print(index, test_l2_step/ ntest/ (T/step), test_l2_full/ ntest) 243 | index = index + 1 244 | 245 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy() / (T/step), '%') 246 | 247 | # %% 248 | plt.rcParams["font.family"] = "serif" 249 | plt.rcParams['font.size'] = 14 250 | plt.rcParams['font.weight'] = 'bold' 251 | 252 | figure1 = plt.figure(figsize = (18, 16)) 253 | plt.subplots_adjust(hspace=0.05, wspace=0.18) 254 | batch_no = 5 255 | index = 0 256 | for tvalue in range(10): 257 | if tvalue < 6: #(printing till Mon.-Sat.) 258 | ### 259 | plt.subplot(4,3, index+1) 260 | plt.imshow(y_test.numpy()[batch_no,:,:,tvalue], cmap='gist_ncar', interpolation='Gaussian') 261 | plt.title('Day-{}'.format(tvalue+1)); plt.xlabel('Longitude ($^{\circ}$)', fontweight='bold'); 262 | plt.grid(True) 263 | if index == 0 or index == 3: 264 | plt.ylabel('Truth \n Latitude ($^{\circ}$)', fontweight='bold') 265 | else: 266 | plt.ylabel('Latitude ($^{\circ}$)', fontweight='bold') 267 | 268 | ### 269 | plt.subplot(4,3, index+1+6) 270 | plt.imshow(pred0[batch_no,:,:,tvalue], cmap='gist_ncar', interpolation='Gaussian') 271 | plt.title('Day-{} (error={:0.4f}%)'.format(tvalue+1,100*error[batch_no,tvalue])); 272 | plt.xlabel('Longitude ($^{\circ}$)', fontweight='bold'); 273 | plt.grid(True) 274 | if index == 0 or index == 3: 275 | plt.ylabel('Prediction \n Latitude ($^{\circ}$)', fontweight='bold') 276 | else: 277 | plt.ylabel('Latitude ($^{\circ}$)', fontweight='bold') 278 | index = index + 1 279 | -------------------------------------------------------------------------------- /Version 1.0.0/WNO_testing_2d_NS.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 
5 | 6 | -- This code is for performing predictions on pre-trained models for 7 | 2-D Navier-Stokes equation (time-dependent problem). 8 | """ 9 | 10 | import torch 11 | import numpy as np 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | import matplotlib.pyplot as plt 16 | from utilities3 import * 17 | 18 | from timeit import default_timer 19 | from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT) 20 | 21 | torch.manual_seed(0) 22 | np.random.seed(0) 23 | 24 | # %% 25 | 26 | class WNOConv2d_fast(nn.Module): 27 | def __init__(self, in_channels, out_channels, modes1, modes2): 28 | super(WNOConv2d_fast, self).__init__() 29 | 30 | """ 31 | 2D Wavelet layer. It does DWT, a linear transform, and inverse DWT. 32 | """ 33 | 34 | self.in_channels = in_channels 35 | self.out_channels = out_channels 36 | self.modes1 = modes1 # Number of wavelet coefficient modes along x; must match the size of the level-J approximation 37 | self.modes2 = modes2 38 | 39 | self.scale = (1 / (in_channels * out_channels)) 40 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 41 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 42 | 43 | # Elementwise multiplication in the wavelet domain 44 | def compl_mul2d(self, input, weights): 45 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 46 | return torch.einsum("bixy,ioxy->boxy", input, weights) 47 | 48 | def forward(self, x): 49 | batchsize = x.shape[0] 50 | 51 | #Compute single tree Discrete Wavelet coefficients using some wavelet 52 | # dwt = DWT(J=3, mode='symmetric', wave='db6').to(device) 53 | dwt = DWT(J=3, mode='symmetric', wave='db4').to(device) 54 | 55 | x_ft, x_coeff = dwt(x) 56 | # print(x_ft.shape) 57 | 58 | # Multiply relevant Wavelet modes 59 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-2], x_ft.shape[-1], device=x.device) 60 | out_ft = self.compl_mul2d(x_ft, self.weights1) 61 | # Multiply the detail (high-frequency) coefficients at the final level 62 | x_coeff[-1][:,:,0,:,:] = self.compl_mul2d(x_coeff[-1][:,:,0,:,:].clone(), self.weights2) 63 | 64 | # Return to physical space 65 | idwt = IDWT(mode='symmetric', wave='db4').to(device) 66 | x = idwt((out_ft, x_coeff)) 67 | 68 | return x 69 | 70 | class WNO2d(nn.Module): 71 | def __init__(self, modes1, modes2, width): 72 | super(WNO2d, self).__init__() 73 | 74 | """ 75 | The overall network. It contains 4 layers of the Wavelet layer. 76 | 1. Lift the input to the desired channel dimension by self.fc0 . 77 | 2. 4 layers of the integral operators u' = (W + K)(u). 78 | W defined by self.w; K defined by self.conv . 79 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 .
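Before the input/output specification below, a note on testing: because this is a time-dependent problem, testing is done autoregressively. The fragment below is a condensed sketch of the rollout used later in this script (variable names follow the script; shapes are for the 64 x 64 Navier-Stokes data): the network maps the last `T_in` vorticity fields to the next one, the prediction is appended to the input window, and the oldest field is dropped.

```python
# Condensed sketch of the autoregressive rollout performed in the test loop below.
T, step = 10, 1
xx = test_a[0:1].to(device)                        # (1, 64, 64, T_in) initial window
preds = []
for t in range(0, T, step):
    im = model(xx)                                 # (1, 64, 64, step) prediction of the next field
    preds.append(im)
    xx = torch.cat((xx[..., step:], im), dim=-1)   # slide the window: drop oldest, append newest
pred = torch.cat(preds, dim=-1)                    # (1, 64, 64, T) full forecast
```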
80 | 81 | input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) 82 | input shape: (batchsize, x=64, y=64, c=12) 83 | output: the solution of the next timestep 84 | output shape: (batchsize, x=64, y=64, c=1) 85 | """ 86 | 87 | self.modes1 = modes1 88 | self.modes2 = modes2 89 | self.width = width 90 | self.padding = 2 # pad the domain when required 91 | self.fc0 = nn.Linear(12, self.width) 92 | # input channel is 12: the solution of the previous 10 timesteps + 93 | # 2 locations (u(t-10, x, y), ..., u(t-1, x, y), x, y) 94 | 95 | self.conv0 = WNOConv2d_fast(self.width, self.width, self.modes1, self.modes2) 96 | self.conv1 = WNOConv2d_fast(self.width, self.width, self.modes1, self.modes2) 97 | self.conv2 = WNOConv2d_fast(self.width, self.width, self.modes1, self.modes2) 98 | self.conv3 = WNOConv2d_fast(self.width, self.width, self.modes1, self.modes2) 99 | self.w0 = nn.Conv2d(self.width, self.width, 1) 100 | self.w1 = nn.Conv2d(self.width, self.width, 1) 101 | self.w2 = nn.Conv2d(self.width, self.width, 1) 102 | self.w3 = nn.Conv2d(self.width, self.width, 1) 103 | 104 | self.fc1 = nn.Linear(self.width, 128) 105 | self.fc2 = nn.Linear(128, 1) 106 | 107 | def forward(self, x): 108 | grid = self.get_grid(x.shape, x.device) 109 | x = torch.cat((x, grid), dim=-1) 110 | x = self.fc0(x) 111 | x = x.permute(0, 3, 1, 2) 112 | # x = F.pad(x, [0,self.padding, 0,self.padding]) 113 | 114 | x1 = self.conv0(x) 115 | x2 = self.w0(x) 116 | x = x1 + x2 117 | x = F.gelu(x) 118 | 119 | x1 = self.conv1(x) 120 | x2 = self.w1(x) 121 | x = x1 + x2 122 | x = F.gelu(x) 123 | 124 | x1 = self.conv2(x) 125 | x2 = self.w2(x) 126 | x = x1 + x2 127 | x = F.gelu(x) 128 | 129 | x1 = self.conv3(x) 130 | x2 = self.w3(x) 131 | x = x1 + x2 132 | 133 | # x = x[..., :-self.padding, :-self.padding] 134 | x = x.permute(0, 2, 3, 1) 135 | x = self.fc1(x) 136 | x = F.gelu(x) 137 | x = self.fc2(x) 138 | return x 139 | 140 | def get_grid(self, shape, device): 141 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 142 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 143 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 144 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 145 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 146 | return torch.cat((gridx, gridy), dim=-1).to(device) 147 | 148 | 149 | # %% 150 | 151 | TRAIN_PATH = 'data/ns_V1e-3_N5000_T50.mat' 152 | 153 | ntrain = 1000 154 | ntest = 20 155 | 156 | modes = 14 157 | width = 26 158 | 159 | batch_size = 20 160 | batch_size2 = batch_size 161 | 162 | epochs = 800 163 | learning_rate = 0.001 164 | scheduler_step = 50 165 | scheduler_gamma = 0.75 166 | 167 | sub = 1 168 | S = 64 169 | T_in = 10 170 | T = 10 171 | step = 1 172 | 173 | # %% 174 | 175 | reader = MatReader(TRAIN_PATH) 176 | data = reader.read_field('u') 177 | train_a = data[:ntrain,::sub,::sub,:T_in] 178 | train_u = data[:ntrain,::sub,::sub,T_in:T+T_in] 179 | 180 | test_a = data[-ntest:,::sub,::sub,:T_in] 181 | test_u = data[-ntest:,::sub,::sub,T_in:T+T_in] 182 | 183 | print(train_u.shape) 184 | print(test_u.shape) 185 | assert (S == train_u.shape[-2]) 186 | assert (T == train_u.shape[-1]) 187 | 188 | train_a = train_a.reshape(ntrain,S,S,T_in) 189 | test_a = test_a.reshape(ntest,S,S,T_in) 190 | 191 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 192 | test_loader = 
torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 193 | 194 | # %% 195 | 196 | model = torch.load('model/model_wno_2d_navier_stokes') 197 | print(count_params(model)) 198 | 199 | myloss = LpLoss(size_average=False) 200 | 201 | # %% 202 | pred0 = torch.zeros(test_u.shape) 203 | index = 0 204 | test_e = torch.zeros(test_u.shape) 205 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 206 | 207 | with torch.no_grad(): 208 | for xx, yy in test_loader: 209 | test_l2_step = 0 210 | test_l2_full = 0 211 | loss = 0 212 | xx = xx.to(device) 213 | yy = yy.to(device) 214 | 215 | for t in range(0, T, step): 216 | y = yy[..., t:t + step] 217 | im = model(xx) 218 | loss += myloss(im.reshape(1, y.size()[-3], y.size()[-2]), y.reshape(1, y.size()[-3], y.size()[-2])) 219 | 220 | if t == 0: 221 | pred = im 222 | else: 223 | pred = torch.cat((pred, im), -1) 224 | 225 | xx = torch.cat((xx[..., step:], im), dim=-1) 226 | 227 | pred0[index] = pred 228 | test_l2_step += loss.item() 229 | test_l2_full += myloss(pred.reshape(1, -1), yy.reshape(1, -1)).item() 230 | test_e[index] = test_l2_step 231 | 232 | print(index, test_l2_step/ ntest/ (T/step), test_l2_full/ ntest) 233 | index = index + 1 234 | 235 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy() / ntest/ (T/step), '%') 236 | 237 | # %% 238 | plt.rcParams["font.family"] = "serif" 239 | plt.rcParams['font.size'] = 14 240 | 241 | figure1 = plt.figure(figsize = (18, 14)) 242 | figure1.text(0.04,0.17,'\n Error', rotation=90, color='purple', fontsize=20) 243 | figure1.text(0.04,0.34,'\n Prediction', rotation=90, color='green', fontsize=20) 244 | figure1.text(0.04,0.57,'\n Truth', rotation=90, color='red', fontsize=20) 245 | figure1.text(0.04,0.75,'Initial \n Condition', rotation=90, color='b', fontsize=20) 246 | plt.subplots_adjust(wspace=0.7) 247 | index = 0 248 | for value in range(test_u.shape[-1]): 249 | if value % 3 == 0: 250 | print(value) 251 | plt.subplot(4,4, index+1) 252 | plt.imshow(test_a.numpy()[15,:,:,0], cmap='jet', extent=[0,1,0,1], interpolation='Gaussian') 253 | plt.title('t={}s'.format(value+10), color='b', fontsize=18, fontweight='bold') 254 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 255 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 256 | 257 | plt.subplot(4,4, index+1+4) 258 | plt.imshow(test_u[15,:,:,value], cmap='jet', extent=[0,1,0,1], interpolation='Gaussian') 259 | plt.colorbar(fraction=0.045) 260 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 261 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 262 | 263 | plt.subplot(4,4, index+1+8) 264 | plt.imshow(pred0[15,:,:,value], cmap='jet', extent=[0,1,0,1], interpolation='Gaussian') 265 | plt.colorbar(fraction=0.045) 266 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 267 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 268 | 269 | plt.subplot(4,4, index+1+12) 270 | plt.imshow(np.abs(test_u[15,:,:,value]-pred0[15,:,:,value]), cmap='jet', extent=[0,1,0,1], interpolation='Gaussian') 271 | plt.xlabel('x', fontweight='bold'); plt.ylabel('y', fontweight='bold'); 272 | plt.colorbar(fraction=0.045,format='%.0e') 273 | 274 | plt.margins(0) 275 | index = index + 1 276 | -------------------------------------------------------------------------------- /Version 1.0.0/__pycache__/utilities3.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Version 1.0.0/__pycache__/utilities3.cpython-39.pyc -------------------------------------------------------------------------------- /Version 1.0.0/utilities3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scipy.io 4 | import h5py 5 | import torch.nn as nn 6 | 7 | import operator 8 | from functools import reduce 9 | from functools import partial 10 | 11 | """ 12 | This code is taken from the repo: https://github.com/zongyi-li/fourier_neural_operator 13 | 14 | The associated article is Fourier Neural Operator 15 | """ 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | # reading data 20 | class MatReader(object): 21 | def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): 22 | super(MatReader, self).__init__() 23 | 24 | self.to_torch = to_torch 25 | self.to_cuda = to_cuda 26 | self.to_float = to_float 27 | 28 | self.file_path = file_path 29 | 30 | self.data = None 31 | self.old_mat = None 32 | self._load_file() 33 | 34 | def _load_file(self): 35 | try: 36 | self.data = scipy.io.loadmat(self.file_path) 37 | self.old_mat = True 38 | except: 39 | self.data = h5py.File(self.file_path) 40 | self.old_mat = False 41 | 42 | def load_file(self, file_path): 43 | self.file_path = file_path 44 | self._load_file() 45 | 46 | def read_field(self, field): 47 | x = self.data[field] 48 | 49 | if not self.old_mat: 50 | x = x[()] 51 | x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1)) 52 | 53 | if self.to_float: 54 | x = x.astype(np.float32) 55 | 56 | if self.to_torch: 57 | x = torch.from_numpy(x) 58 | 59 | if self.to_cuda: 60 | x = x.cuda() 61 | 62 | return x 63 | 64 | def set_cuda(self, to_cuda): 65 | self.to_cuda = to_cuda 66 | 67 | def set_torch(self, to_torch): 68 | self.to_torch = to_torch 69 | 70 | def set_float(self, to_float): 71 | self.to_float = to_float 72 | 73 | # normalization, pointwise gaussian 74 | class UnitGaussianNormalizer(object): 75 | def __init__(self, x, eps=0.00001): 76 | super(UnitGaussianNormalizer, self).__init__() 77 | 78 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 79 | self.mean = torch.mean(x, 0) 80 | self.std = torch.std(x, 0) 81 | self.eps = eps 82 | 83 | def encode(self, x): 84 | x = (x - self.mean) / (self.std + self.eps) 85 | return x 86 | 87 | def decode(self, x, sample_idx=None): 88 | if sample_idx is None: 89 | std = self.std + self.eps # n 90 | mean = self.mean 91 | else: 92 | if len(self.mean.shape) == len(sample_idx[0].shape): 93 | std = self.std[sample_idx] + self.eps # batch*n 94 | mean = self.mean[sample_idx] 95 | if len(self.mean.shape) > len(sample_idx[0].shape): 96 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 97 | mean = self.mean[:,sample_idx] 98 | 99 | # x is in shape of batch*n or T*batch*n 100 | x = (x * std) + mean 101 | return x 102 | 103 | def cuda(self): 104 | self.mean = self.mean.cuda() 105 | self.std = self.std.cuda() 106 | 107 | def cpu(self): 108 | self.mean = self.mean.cpu() 109 | self.std = self.std.cpu() 110 | 111 | # normalization, Gaussian 112 | class GaussianNormalizer(object): 113 | def __init__(self, x, eps=0.00001): 114 | super(GaussianNormalizer, self).__init__() 115 | 116 | self.mean = torch.mean(x) 117 | self.std = torch.std(x) 118 | self.eps = eps 119 | 120 | def encode(self, x): 121 | x = 
(x - self.mean) / (self.std + self.eps) 122 | return x 123 | 124 | def decode(self, x, sample_idx=None): 125 | x = (x * (self.std + self.eps)) + self.mean 126 | return x 127 | 128 | def cuda(self): 129 | self.mean = self.mean.cuda() 130 | self.std = self.std.cuda() 131 | 132 | def cpu(self): 133 | self.mean = self.mean.cpu() 134 | self.std = self.std.cpu() 135 | 136 | 137 | # normalization, scaling by range 138 | class RangeNormalizer(object): 139 | def __init__(self, x, low=0.0, high=1.0): 140 | super(RangeNormalizer, self).__init__() 141 | mymin = torch.min(x, 0)[0].view(-1) 142 | mymax = torch.max(x, 0)[0].view(-1) 143 | 144 | self.a = (high - low)/(mymax - mymin) 145 | self.b = -self.a*mymax + high 146 | 147 | def encode(self, x): 148 | s = x.size() 149 | x = x.view(s[0], -1) 150 | x = self.a*x + self.b 151 | x = x.view(s) 152 | return x 153 | 154 | def decode(self, x): 155 | s = x.size() 156 | x = x.view(s[0], -1) 157 | x = (x - self.b)/self.a 158 | x = x.view(s) 159 | return x 160 | 161 | #loss function with rel/abs Lp loss 162 | class LpLoss(object): 163 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 164 | super(LpLoss, self).__init__() 165 | 166 | #Dimension and Lp-norm type are postive 167 | assert d > 0 and p > 0 168 | 169 | self.d = d 170 | self.p = p 171 | self.reduction = reduction 172 | self.size_average = size_average 173 | 174 | def abs(self, x, y): 175 | num_examples = x.size()[0] 176 | 177 | #Assume uniform mesh 178 | h = 1.0 / (x.size()[1] - 1.0) 179 | 180 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 181 | 182 | if self.reduction: 183 | if self.size_average: 184 | return torch.mean(all_norms) 185 | else: 186 | return torch.sum(all_norms) 187 | 188 | return all_norms 189 | 190 | def rel(self, x, y): 191 | num_examples = x.size()[0] 192 | 193 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 194 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 195 | 196 | if self.reduction: 197 | if self.size_average: 198 | return torch.mean(diff_norms/y_norms) 199 | else: 200 | return torch.sum(diff_norms/y_norms) 201 | 202 | return diff_norms/y_norms 203 | 204 | def __call__(self, x, y): 205 | return self.rel(x, y) 206 | 207 | # Sobolev norm (HS norm) 208 | # where we also compare the numerical derivatives between the output and target 209 | class HsLoss(object): 210 | def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True): 211 | super(HsLoss, self).__init__() 212 | 213 | #Dimension and Lp-norm type are postive 214 | assert d > 0 and p > 0 215 | 216 | self.d = d 217 | self.p = p 218 | self.k = k 219 | self.balanced = group 220 | self.reduction = reduction 221 | self.size_average = size_average 222 | 223 | if a == None: 224 | a = [1,] * k 225 | self.a = a 226 | 227 | def rel(self, x, y): 228 | num_examples = x.size()[0] 229 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 230 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 231 | if self.reduction: 232 | if self.size_average: 233 | return torch.mean(diff_norms/y_norms) 234 | else: 235 | return torch.sum(diff_norms/y_norms) 236 | return diff_norms/y_norms 237 | 238 | def __call__(self, x, y, a=None): 239 | nx = x.size()[1] 240 | ny = x.size()[2] 241 | k = self.k 242 | balanced = self.balanced 243 | a = self.a 244 | x = x.view(x.shape[0], nx, ny, -1) 245 | y = y.view(y.shape[0], nx, ny, -1) 246 | 247 | k_x = 
torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny) 248 | k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1) 249 | k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device) 250 | k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device) 251 | 252 | x = torch.fft.fftn(x, dim=[1, 2]) 253 | y = torch.fft.fftn(y, dim=[1, 2]) 254 | 255 | if balanced==False: 256 | weight = 1 257 | if k >= 1: 258 | weight += a[0]**2 * (k_x**2 + k_y**2) 259 | if k >= 2: 260 | weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 261 | weight = torch.sqrt(weight) 262 | loss = self.rel(x*weight, y*weight) 263 | else: 264 | loss = self.rel(x, y) 265 | if k >= 1: 266 | weight = a[0] * torch.sqrt(k_x**2 + k_y**2) 267 | loss += self.rel(x*weight, y*weight) 268 | if k >= 2: 269 | weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 270 | loss += self.rel(x*weight, y*weight) 271 | loss = loss / (k+1) 272 | 273 | return loss 274 | 275 | # print the number of parameters 276 | def count_params(model): 277 | c = 0 278 | for p in list(model.parameters()): 279 | c += reduce(operator.mul, 280 | list(p.size()+(2,) if p.is_complex() else p.size())) 281 | return c 282 | -------------------------------------------------------------------------------- /Version 1.0.0/wno_1d_Burgers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for 1-D Burger's equation (time-independent problem). 7 | """ 8 | 9 | from IPython import get_ipython 10 | get_ipython().magic('reset -sf') 11 | 12 | import pywt 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import matplotlib.pyplot as plt 18 | 19 | from timeit import default_timer 20 | from utilities3 import * 21 | from pytorch_wavelets import DWT1D, IDWT1D 22 | 23 | torch.manual_seed(0) 24 | np.random.seed(0) 25 | 26 | # %% 27 | """ Def: 1d Wavelet layer """ 28 | class WaveConv1d(nn.Module): 29 | def __init__(self, in_channels, out_channels, level, dummy): 30 | super(WaveConv1d, self).__init__() 31 | 32 | """ 33 | 1D Wavelet layer. It does Wavelet Transform, linear transform, and 34 | Inverse Wavelet Transform. 
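Unlike the layers in the testing scripts above, this layer does not take a `modes` hyperparameter: it infers the weight size from a dummy forward pass through `DWT1D`. The sketch below (a stand-alone illustration assuming pytorch_wavelets, not part of this file) shows that inference for the 1024-point Burgers grid used later.

```python
# Stand-alone sketch: how the number of learnable modes is inferred from a dummy DWT1D pass.
import torch
from pytorch_wavelets import DWT1D

dummy = torch.randn(10, 1, 1024)                   # (batch, channel, grid), as passed to the model below
approx, details = DWT1D(wave='db6', J=8, mode='symmetric')(dummy)
print(approx.shape[-1])                            # length of the coarsest approximation -> self.modes1
print([d.shape[-1] for d in details])              # per-level detail lengths; the last matches the approximation
```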
35 | """ 36 | 37 | self.in_channels = in_channels 38 | self.out_channels = out_channels 39 | self.level = level 40 | self.dwt_ = DWT1D(wave='db6', J=self.level, mode='symmetric').to(dummy.device) 41 | self.mode_data, _ = self.dwt_(dummy) 42 | self.modes1 = self.mode_data.shape[-1] 43 | 44 | self.scale = (1 / (in_channels*out_channels)) 45 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1)) 46 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1)) 47 | 48 | # Convolution 49 | def mul1d(self, input, weights): 50 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 51 | return torch.einsum("bix,iox->box", input, weights) 52 | 53 | def forward(self, x): 54 | batchsize = x.shape[0] 55 | # Compute single tree Discrete Wavelet coefficients using some wavelet 56 | dwt = DWT1D(wave='db6', J=self.level, mode='symmetric').to(device) 57 | x_ft, x_coeff = dwt(x) 58 | 59 | # Multiply the final low pass and high pass coefficients 60 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-1], device=x.device) 61 | out_ft = self.mul1d(x_ft, self.weights1) 62 | x_coeff[-1] = self.mul1d(x_coeff[-1], self.weights2) 63 | 64 | # Reconstruct the signal 65 | idwt = IDWT1D(wave='db6', mode='symmetric').to(device) 66 | x = idwt((out_ft, x_coeff)) 67 | return x 68 | 69 | """ The forward operation """ 70 | class WNO1d(nn.Module): 71 | def __init__(self, width, level, dummy_data): 72 | super(WNO1d, self).__init__() 73 | 74 | """ 75 | The WNO network. It contains 4 layers of the Wavelet integral layer. 76 | 1. Lift the input using v(x) = self.fc0 . 77 | 2. 4 layers of the integral operators v(+1) = g(K(.) + W)(v). 78 | W is defined by self.w_; K is defined by self.conv_. 79 | 3. Project the output of last layer using self.fc1 and self.fc2. 
80 | 81 | input: the solution of the initial condition and location (a(x), x) 82 | input shape: (batchsize, x=s, c=2) 83 | output: the solution of a later timestep 84 | output shape: (batchsize, x=s, c=1) 85 | """ 86 | 87 | self.level = level 88 | self.width = width 89 | self.dummy_data = dummy_data 90 | self.padding = 2 # pad the domain when required 91 | self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x) 92 | 93 | self.conv0 = WaveConv1d(self.width, self.width, self.level, self.dummy_data) 94 | self.conv1 = WaveConv1d(self.width, self.width, self.level, self.dummy_data) 95 | self.conv2 = WaveConv1d(self.width, self.width, self.level, self.dummy_data) 96 | self.conv3 = WaveConv1d(self.width, self.width, self.level, self.dummy_data) 97 | self.w0 = nn.Conv1d(self.width, self.width, 1) 98 | self.w1 = nn.Conv1d(self.width, self.width, 1) 99 | self.w2 = nn.Conv1d(self.width, self.width, 1) 100 | self.w3 = nn.Conv1d(self.width, self.width, 1) 101 | 102 | self.fc1 = nn.Linear(self.width, 128) 103 | self.fc2 = nn.Linear(128, 1) 104 | 105 | def forward(self, x): 106 | grid = self.get_grid(x.shape, x.device) 107 | x = torch.cat((x, grid), dim=-1) 108 | x = self.fc0(x) 109 | x = x.permute(0, 2, 1) 110 | # x = F.pad(x, [0,self.padding]) 111 | 112 | x1 = self.conv0(x) 113 | x2 = self.w0(x) 114 | x = x1 + x2 115 | x = F.gelu(x) 116 | 117 | x1 = self.conv1(x) 118 | x2 = self.w1(x) 119 | x = x1 + x2 120 | x = F.gelu(x) 121 | 122 | x1 = self.conv2(x) 123 | x2 = self.w2(x) 124 | x = x1 + x2 125 | x = F.gelu(x) 126 | 127 | x1 = self.conv3(x) 128 | x2 = self.w3(x) 129 | x = x1 + x2 130 | 131 | # x = x[..., :-self.padding] 132 | x = x.permute(0, 2, 1) 133 | x = self.fc1(x) 134 | x = F.gelu(x) 135 | x = self.fc2(x) 136 | return x 137 | 138 | def get_grid(self, shape, device): 139 | # The grid of the solution 140 | batchsize, size_x = shape[0], shape[1] 141 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 142 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 143 | return gridx.to(device) 144 | 145 | 146 | # %% 147 | """ Model configurations """ 148 | 149 | ntrain = 1000 150 | ntest = 100 151 | 152 | sub = 2**3 # subsampling rate 153 | h = 2**13 // sub # total grid size divided by the subsampling rate 154 | s = h 155 | 156 | batch_size = 10 157 | learning_rate = 0.001 158 | 159 | epochs = 500 160 | step_size = 100 161 | gamma = 0.5 162 | 163 | level = 8 164 | width = 64 165 | 166 | # %% 167 | """ Read data """ 168 | 169 | # Data is of the shape (number of samples, grid size) 170 | dataloader = MatReader('data/burgers_data_R10.mat') 171 | x_data = dataloader.read_field('a')[:,::sub] 172 | y_data = dataloader.read_field('u')[:,::sub] 173 | 174 | x_train = x_data[:ntrain,:] 175 | y_train = y_data[:ntrain,:] 176 | x_test = x_data[-ntest:,:] 177 | y_test = y_data[-ntest:,:] 178 | 179 | x_train = x_train.reshape(ntrain,s,1) 180 | x_test = x_test.reshape(ntest,s,1) 181 | 182 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 183 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 184 | 185 | # %% 186 | """ The model definition """ 187 | model = WNO1d(width, level, x_train.permute(0,2,1)).to(device) 188 | print(count_params(model)) 189 | 190 | """ Training and testing """ 191 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 192 | scheduler = 
torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 193 | 194 | train_loss = torch.zeros(epochs) 195 | test_loss = torch.zeros(epochs) 196 | myloss = LpLoss(size_average=False) 197 | for ep in range(epochs): 198 | model.train() 199 | t1 = default_timer() 200 | train_mse = 0 201 | train_l2 = 0 202 | for x, y in train_loader: 203 | x, y = x.to(device), y.to(device) 204 | 205 | optimizer.zero_grad() 206 | out = model(x) 207 | 208 | mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1), reduction='mean') 209 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 210 | l2.backward() # l2 relative loss 211 | 212 | optimizer.step() 213 | train_mse += mse.item() 214 | train_l2 += l2.item() 215 | 216 | scheduler.step() 217 | model.eval() 218 | test_l2 = 0.0 219 | with torch.no_grad(): 220 | for x, y in test_loader: 221 | x, y = x.to(device), y.to(device) 222 | 223 | out = model(x) 224 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 225 | 226 | train_mse /= len(train_loader) 227 | train_l2 /= ntrain 228 | test_l2 /= ntest 229 | 230 | train_loss[ep] = train_l2 231 | test_loss[ep] = test_l2 232 | 233 | t2 = default_timer() 234 | print(ep, t2-t1, train_mse, train_l2, test_l2) 235 | 236 | # %% 237 | """ Prediction """ 238 | pred = torch.zeros(y_test.shape) 239 | index = 0 240 | test_e = torch.zeros(y_test.shape[0]) 241 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 242 | with torch.no_grad(): 243 | for x, y in test_loader: 244 | test_l2 = 0 245 | x, y = x.to(device), y.to(device) 246 | 247 | out = model(x).view(-1) 248 | pred[index] = out 249 | 250 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 251 | test_e[index] = test_l2 252 | print(index, test_l2) 253 | index = index + 1 254 | 255 | print('Mean Error:', 100*torch.mean(test_e)) 256 | 257 | # %% 258 | """ Plotting """ # for paper figures please see 'WNO_testing_(.).py' files 259 | figure7 = plt.figure(figsize = (10, 8)) 260 | for i in range(y_test.shape[0]): 261 | if i % 20 == 1: 262 | plt.plot(y_test[i, :].numpy(), label='Actual') 263 | plt.plot(pred[i,:].numpy(), 'k', label='Prediction') 264 | plt.legend() 265 | plt.grid(True) 266 | plt.margins(0) 267 | 268 | # %% 269 | """ 270 | For saving the trained model and prediction data 271 | """ 272 | # torch.save(model, 'model/model_wno_burgers') 273 | # scipy.io.savemat('pred/pred_wno_burgers.mat', mdict={'pred': pred.cpu().numpy()}) 274 | # scipy.io.savemat('loss/train_loss_wno_burgers.mat', mdict={'train_loss': train_loss.cpu().numpy()}) 275 | # scipy.io.savemat('loss/test_loss_wno_burgers.mat', mdict={'test_loss': test_loss.cpu().numpy()}) 276 | 277 | # torch.cuda.empty_cache() 278 | -------------------------------------------------------------------------------- /Version 1.0.0/wno_2d_AC.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for 2-D Allen-Cahn equation (time-independent problem). 
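Throughout these scripts the reported numbers are relative L2 errors from `LpLoss` in `utilities3.py` (with `size_average=False` the per-sample errors are summed and later divided by the number of samples). Below is a small self-contained sketch of that metric, using mock tensors rather than Allen-Cahn data.

```python
# Sketch of the relative L2 metric used for training and testing in this repository.
import torch
from utilities3 import LpLoss

myloss = LpLoss(size_average=False)
true = torch.randn(4, 43, 43)                             # mock batch of 2-D fields
pred = true + 0.01 * torch.randn_like(true)               # ~1% perturbation
batch_sum = myloss(pred.view(4, -1), true.view(4, -1))    # sum over the batch of ||pred - true|| / ||true||
print(batch_sum.item() / 4)                               # mean relative error, roughly 0.01
```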
7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn.parameter import Parameter 14 | import matplotlib.pyplot as plt 15 | 16 | from timeit import default_timer 17 | from utilities3 import * 18 | from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT) 19 | 20 | torch.manual_seed(0) 21 | np.random.seed(0) 22 | 23 | # %% 24 | """ Def: 2d Wavelet layer """ 25 | class WaveConv2d(nn.Module): 26 | def __init__(self, in_channels, out_channels, level, dummy): 27 | super(WaveConv2d, self).__init__() 28 | 29 | """ 30 | 2D Wavelet layer. It does DWT, linear transform, and Inverse dWT. 31 | """ 32 | 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.level = level 36 | self.dwt_ = DWT(J=self.level, mode='symmetric', wave='db4').to(dummy.device) 37 | self.mode_data, _ = self.dwt_(dummy) 38 | self.modes1 = self.mode_data.shape[-2] 39 | self.modes2 = self.mode_data.shape[-1] 40 | 41 | self.scale = (1 / (in_channels * out_channels)) 42 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 43 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 44 | # self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 45 | # self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 46 | 47 | # Convolution 48 | def mul2d(self, input, weights): 49 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 50 | return torch.einsum("bixy,ioxy->boxy", input, weights) 51 | 52 | def forward(self, x): 53 | batchsize = x.shape[0] 54 | # Compute single tree Discrete Wavelet coefficients using some wavelet 55 | dwt = DWT(J=self.level, mode='symmetric', wave='db4').to(device) 56 | x_ft, x_coeff = dwt(x) 57 | 58 | # Multiply relevant Wavelet modes 59 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-2], x_ft.shape[-1], device=x.device) 60 | out_ft = self.mul2d(x_ft, self.weights1) 61 | # Multiply the finer wavelet coefficients 62 | x_coeff[-1][:,:,0,:,:] = self.mul2d(x_coeff[-1][:,:,0,:,:].clone(), self.weights2) 63 | # x_coeff[-1][:,:,1,:,:] = self.mul2d(x_coeff[-1][:,:,1,:,:].clone(), self.weights3) 64 | # x_coeff[-1][:,:,2,:,:] = self.mul2d(x_coeff[-1][:,:,2,:,:].clone(), self.weights4) 65 | 66 | # Return to physical space 67 | idwt = IDWT(mode='symmetric', wave='db4').to(device) 68 | x = idwt((out_ft, x_coeff)) 69 | return x 70 | 71 | """ The forward operation """ 72 | class WNO2d(nn.Module): 73 | def __init__(self, width, level, dummy_data): 74 | super(WNO2d, self).__init__() 75 | 76 | """ 77 | The WNO network. It contains 4 layers of the Wavelet integral layer. 78 | 1. Lift the input using v(x) = self.fc0 . 79 | 2. 4 layers of the integral operators v(+1) = g(K(.) + W)(v). 80 | W is defined by self.w_; K is defined by self.conv_. 81 | 3. Project the output of last layer using self.fc1 and self.fc2. 
82 | 83 | input: the solution of the coefficient function and locations (a(x, y), x, y) 84 | input shape: (batchsize, x=s, y=s, c=3) 85 | output: the solution 86 | output shape: (batchsize, x=s, y=s, c=1) 87 | """ 88 | 89 | self.level = level 90 | self.dummy_data = dummy_data 91 | self.width = width 92 | self.padding = 1 # pad the domain if input is non-periodic 93 | self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y) 94 | 95 | self.conv0 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 96 | self.conv1 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 97 | self.conv2 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 98 | self.conv3 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 99 | self.w0 = nn.Conv2d(self.width, self.width, 1) 100 | self.w1 = nn.Conv2d(self.width, self.width, 1) 101 | self.w2 = nn.Conv2d(self.width, self.width, 1) 102 | self.w3 = nn.Conv2d(self.width, self.width, 1) 103 | 104 | self.fc1 = nn.Linear(self.width, 128) 105 | self.fc2 = nn.Linear(128, 1) 106 | 107 | def forward(self, x): 108 | grid = self.get_grid(x.shape, x.device) 109 | x = torch.cat((x, grid), dim=-1) 110 | 111 | x = self.fc0(x) 112 | x = x.permute(0, 3, 1, 2) 113 | x = F.pad(x, [0,self.padding, 0,self.padding]) 114 | 115 | x1 = self.conv0(x) 116 | x2 = self.w0(x) 117 | x = x1 + x2 118 | x = F.gelu(x) 119 | 120 | x1 = self.conv1(x) 121 | x2 = self.w1(x) 122 | x = x1 + x2 123 | x = F.gelu(x) 124 | 125 | x1 = self.conv2(x) 126 | x2 = self.w2(x) 127 | x = x1 + x2 128 | x = F.gelu(x) 129 | 130 | x1 = self.conv3(x) 131 | x2 = self.w3(x) 132 | x = x1 + x2 133 | 134 | x = x[..., :-self.padding, :-self.padding] 135 | x = x.permute(0, 2, 3, 1) 136 | x = self.fc1(x) 137 | x = F.gelu(x) 138 | x = self.fc2(x) 139 | return x 140 | 141 | def get_grid(self, shape, device): 142 | # The grid of the solution 143 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 144 | gridx = torch.tensor(np.linspace(0, 3, size_x), dtype=torch.float) 145 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 146 | gridy = torch.tensor(np.linspace(0, 3, size_y), dtype=torch.float) 147 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 148 | return torch.cat((gridx, gridy), dim=-1).to(device) 149 | 150 | # %% 151 | """ Model configurations """ 152 | 153 | PATH = 'data/Allen_cahn_2d_128_128_T10.mat' 154 | ntrain = 1400 155 | ntest = 100 156 | 157 | batch_size = 20 158 | learning_rate = 0.001 159 | 160 | epochs = 500 161 | step_size = 25 162 | gamma = 0.75 163 | 164 | level = 2 165 | width = 64 166 | 167 | r = 3 168 | h = int(((129 - 1)/r) + 1) 169 | s = h 170 | 171 | # %% 172 | """ Read data """ 173 | reader = MatReader(PATH) 174 | x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s] 175 | y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s] 176 | 177 | x_test = reader.read_field('coeff')[-ntest:,::r,::r][:,:s,:s] 178 | y_test = reader.read_field('sol')[-ntest:,::r,::r][:,:s,:s] 179 | 180 | x_normalizer = UnitGaussianNormalizer(x_train) 181 | x_train = x_normalizer.encode(x_train) 182 | x_test = x_normalizer.encode(x_test) 183 | 184 | y_normalizer = UnitGaussianNormalizer(y_train) 185 | y_train = y_normalizer.encode(y_train) 186 | 187 | x_train = x_train.reshape(ntrain,s,s,1) 188 | x_test = x_test.reshape(ntest,s,s,1) 189 | 190 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 191 | test_loader = 
torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 192 | 193 | # %% 194 | """ The model definition """ 195 | model = WNO2d(width, level, x_train.permute(0,3,1,2)).to(device) 196 | print(count_params(model)) 197 | 198 | """ Training and testing """ 199 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6) 200 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 201 | 202 | train_loss = torch.zeros(epochs) 203 | test_loss = torch.zeros(epochs) 204 | myloss = LpLoss(size_average=False) 205 | y_normalizer.cuda() 206 | for ep in range(epochs): 207 | model.train() 208 | t1 = default_timer() 209 | train_l2 = 0 210 | for x, y in train_loader: 211 | x, y = x.to(device), y.to(device) 212 | 213 | optimizer.zero_grad() 214 | out = model(x).reshape(batch_size, s, s) 215 | out = y_normalizer.decode(out) 216 | y = y_normalizer.decode(y) 217 | 218 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 219 | loss.backward() 220 | optimizer.step() 221 | train_l2 += loss.item() 222 | 223 | scheduler.step() 224 | model.eval() 225 | test_l2 = 0.0 226 | with torch.no_grad(): 227 | for x, y in test_loader: 228 | x, y = x.to(device), y.to(device) 229 | 230 | out = model(x).reshape(batch_size, s, s) 231 | out = y_normalizer.decode(out) 232 | 233 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 234 | 235 | train_l2/= ntrain 236 | test_l2 /= ntest 237 | 238 | train_loss[ep] = train_l2 239 | test_loss[ep] = test_l2 240 | 241 | t2 = default_timer() 242 | print(ep, t2-t1, train_l2, test_l2) 243 | 244 | # %% 245 | """ Prediction """ 246 | pred = torch.zeros(y_test.shape) 247 | index = 0 248 | test_e = torch.zeros(y_test.shape) 249 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 250 | with torch.no_grad(): 251 | for x, y in test_loader: 252 | test_l2 = 0 253 | x, y = x.to(device), y.to(device) 254 | 255 | out = model(x).reshape(s, s) 256 | out = y_normalizer.decode(out) 257 | pred[index] = out 258 | 259 | test_l2 += myloss(out.reshape(1, s, s), y.reshape(1, s, s)).item() 260 | test_e[index] = test_l2 261 | print(index, test_l2) 262 | index = index + 1 263 | 264 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy(), '%') 265 | 266 | # %% 267 | """ Plotting """ # for paper figures please see 'WNO_testing_(.).py' files 268 | figure7 = plt.figure(figsize = (10, 5)) 269 | plt.subplots_adjust(hspace=0.01) 270 | index = 0 271 | for value in range(y_test.shape[0]): 272 | if value % 20 == 1: 273 | plt.subplot(2,5, index+1) 274 | plt.imshow(y_test[value,:,:], label='True', cmap='seismic') 275 | plt.title('Actual') 276 | plt.subplot(2,5, index+1+5) 277 | plt.imshow(pred.cpu().detach().numpy()[value,:,:], cmap='seismic') 278 | plt.title('Identified') 279 | plt.margins(0) 280 | index = index + 1 281 | 282 | # %% 283 | """ 284 | For saving the trained model and prediction data 285 | """ 286 | # torch.save(model, 'model/model_wno_AC2d') 287 | # scipy.io.savemat('pred/pred_wno_AC2d.mat', mdict={'pred': pred.cpu().numpy()}) 288 | # scipy.io.savemat('loss/train_loss_wno_AC2d.mat', mdict={'train_loss': train_loss.cpu().numpy()}) 289 | # scipy.io.savemat('loss/test_loss_wno_AC2d.mat', mdict={'test_loss': test_loss.cpu().numpy()}) 290 | 291 | # torch.cuda.empty_cache() 292 | -------------------------------------------------------------------------------- /Version 1.0.0/wno_2d_Darcy.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for 2-D Darcy equation (time-independent problem). 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn.parameter import Parameter 14 | import matplotlib.pyplot as plt 15 | 16 | from timeit import default_timer 17 | from utilities3 import * 18 | from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT) 19 | 20 | torch.manual_seed(0) 21 | np.random.seed(0) 22 | 23 | # %% 24 | """ Def: 2d Wavelet layer """ 25 | class WaveConv2d(nn.Module): 26 | def __init__(self, in_channels, out_channels, level, dummy): 27 | super(WaveConv2d, self).__init__() 28 | 29 | """ 30 | 2D Wavelet layer. It does DWT, linear transform, and Inverse dWT. 31 | """ 32 | 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.level = level 36 | self.dwt_ = DWT(J=self.level, mode='symmetric', wave='db4').to(dummy.device) 37 | self.mode_data, _ = self.dwt_(dummy) 38 | self.modes1 = self.mode_data.shape[-2] 39 | self.modes2 = self.mode_data.shape[-1] 40 | 41 | self.scale = (1 / (in_channels * out_channels)) 42 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 43 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 44 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 45 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 46 | 47 | # Convolution 48 | def mul2d(self, input, weights): 49 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 50 | return torch.einsum("bixy,ioxy->boxy", input, weights) 51 | 52 | def forward(self, x): 53 | batchsize = x.shape[0] 54 | # Compute single tree Discrete Wavelet coefficients using some wavelet 55 | dwt = DWT(J=self.level, mode='symmetric', wave='db4').to(device) 56 | x_ft, x_coeff = dwt(x) 57 | 58 | # Multiply relevant Wavelet modes 59 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-2], x_ft.shape[-1], device=x.device) 60 | out_ft = self.mul2d(x_ft, self.weights1) 61 | # Multiply the finer wavelet coefficients 62 | x_coeff[-1][:,:,0,:,:] = self.mul2d(x_coeff[-1][:,:,0,:,:].clone(), self.weights2) 63 | x_coeff[-1][:,:,1,:,:] = self.mul2d(x_coeff[-1][:,:,1,:,:].clone(), self.weights3) 64 | x_coeff[-1][:,:,2,:,:] = self.mul2d(x_coeff[-1][:,:,2,:,:].clone(), self.weights4) 65 | 66 | # Return to physical space 67 | idwt = IDWT(mode='symmetric', wave='db4').to(device) 68 | x = idwt((out_ft, x_coeff)) 69 | return x 70 | 71 | """ The forward operation """ 72 | class WNO2d(nn.Module): 73 | def __init__(self, width, level, dummy_data): 74 | super(WNO2d, self).__init__() 75 | 76 | """ 77 | The WNO network. It contains 4 layers of the Wavelet integral layer. 78 | 1. Lift the input using v(x) = self.fc0 . 79 | 2. 4 layers of the integral operators v(+1) = g(K(.) + W)(v). 80 | W is defined by self.w_; K is defined by self.conv_. 81 | 3. Project the output of last layer using self.fc1 and self.fc2. 
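One difference from the Allen-Cahn layer above is that this Darcy variant re-weights all three detail orientations at the final decomposition level (weights2, weights3, weights4), not just the first. The sketch below (an illustration assuming the pytorch_wavelets coefficient layout, not repository code) shows where those three bands live:

```python
# Illustration of the detail-coefficient layout parameterised by weights2/3/4.
import torch
from pytorch_wavelets import DWT

x = torch.randn(2, 8, 86, 86)                      # 85-point Darcy grid plus 1 padding, as configured below
approx, details = DWT(J=4, mode='symmetric', wave='db4')(x)
last = details[-1]                                 # (2, 8, 3, h, w): three orientation bands at the final level
band0, band1, band2 = last[:, :, 0], last[:, :, 1], last[:, :, 2]   # conventionally the LH, HL and HH details
print(approx.shape, last.shape)
```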
82 | 83 | input: the solution of the coefficient function and locations (a(x, y), x, y) 84 | input shape: (batchsize, x=s, y=s, c=3) 85 | output: the solution 86 | output shape: (batchsize, x=s, y=s, c=1) 87 | """ 88 | 89 | self.level = level 90 | self.dummy_data = dummy_data 91 | self.width = width 92 | self.padding = 1 # pad the domain when required 93 | self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y) 94 | 95 | self.conv0 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 96 | self.conv1 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 97 | self.conv2 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 98 | self.conv3 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 99 | self.w0 = nn.Conv2d(self.width, self.width, 1) 100 | self.w1 = nn.Conv2d(self.width, self.width, 1) 101 | self.w2 = nn.Conv2d(self.width, self.width, 1) 102 | self.w3 = nn.Conv2d(self.width, self.width, 1) 103 | 104 | self.fc1 = nn.Linear(self.width, 192) 105 | self.fc2 = nn.Linear(192, 1) 106 | 107 | def forward(self, x): 108 | grid = self.get_grid(x.shape, x.device) 109 | x = torch.cat((x, grid), dim=-1) 110 | 111 | x = self.fc0(x) 112 | x = x.permute(0, 3, 1, 2) 113 | x = F.pad(x, [0,self.padding, 0,self.padding]) 114 | 115 | x1 = self.conv0(x) 116 | x2 = self.w0(x) 117 | x = x1 + x2 118 | x = F.gelu(x) 119 | 120 | x1 = self.conv1(x) 121 | x2 = self.w1(x) 122 | x = x1 + x2 123 | x = F.gelu(x) 124 | 125 | x1 = self.conv2(x) 126 | x2 = self.w2(x) 127 | x = x1 + x2 128 | x = F.gelu(x) 129 | 130 | x1 = self.conv3(x) 131 | x2 = self.w3(x) 132 | x = x1 + x2 133 | 134 | x = x[..., :-self.padding, :-self.padding] 135 | x = x.permute(0, 2, 3, 1) 136 | x = self.fc1(x) 137 | x = F.gelu(x) 138 | x = self.fc2(x) 139 | return x 140 | 141 | def get_grid(self, shape, device): 142 | # The grid of the solution 143 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 144 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 145 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 146 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 147 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 148 | return torch.cat((gridx, gridy), dim=-1).to(device) 149 | 150 | # %% 151 | """ Model configurations """ 152 | 153 | TRAIN_PATH = 'data/piececonst_r421_N1024_smooth1.mat' 154 | TEST_PATH = 'data/piececonst_r421_N1024_smooth2.mat' 155 | ntrain = 1000 156 | ntest = 100 157 | 158 | batch_size = 20 159 | learning_rate = 0.001 160 | 161 | epochs = 1000 162 | step_size = 50 163 | gamma = 0.75 164 | 165 | level = 4 166 | width = 64 167 | 168 | r = 5 169 | h = int(((421 - 1)/r) + 1) 170 | s = h 171 | 172 | # %% 173 | """ Read data """ 174 | reader = MatReader(TRAIN_PATH) 175 | x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s] 176 | y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s] 177 | 178 | reader.load_file(TEST_PATH) 179 | x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s] 180 | y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s] 181 | 182 | x_normalizer = UnitGaussianNormalizer(x_train) 183 | x_train = x_normalizer.encode(x_train) 184 | x_test = x_normalizer.encode(x_test) 185 | 186 | y_normalizer = UnitGaussianNormalizer(y_train) 187 | y_train = y_normalizer.encode(y_train) 188 | 189 | x_train = x_train.reshape(ntrain,s,s,1) 190 | x_test = x_test.reshape(ntest,s,s,1) 191 | 192 | train_loader = 
torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 193 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 194 | 195 | # %% 196 | """ The model definition """ 197 | model = WNO2d(width, level, x_train.permute(0,3,1,2)).to(device) 198 | print(count_params(model)) 199 | 200 | """ Training and testing """ 201 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 202 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 203 | 204 | train_loss = torch.zeros(epochs) 205 | test_loss = torch.zeros(epochs) 206 | myloss = LpLoss(size_average=False) 207 | y_normalizer.cuda() 208 | for ep in range(epochs): 209 | model.train() 210 | t1 = default_timer() 211 | train_l2 = 0 212 | for x, y in train_loader: 213 | x, y = x.to(device), y.to(device) 214 | 215 | optimizer.zero_grad() 216 | out = model(x).reshape(batch_size, s, s) 217 | out = y_normalizer.decode(out) 218 | y = y_normalizer.decode(y) 219 | 220 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 221 | loss.backward() 222 | optimizer.step() 223 | train_l2 += loss.item() 224 | 225 | scheduler.step() 226 | model.eval() 227 | test_l2 = 0.0 228 | with torch.no_grad(): 229 | for x, y in test_loader: 230 | x, y = x.to(device), y.to(device) 231 | 232 | out = model(x).reshape(batch_size, s, s) 233 | out = y_normalizer.decode(out) 234 | 235 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 236 | 237 | train_l2/= ntrain 238 | test_l2 /= ntest 239 | 240 | train_loss[ep] = train_l2 241 | test_loss[ep] = test_l2 242 | 243 | t2 = default_timer() 244 | print(ep, t2-t1, train_l2, test_l2) 245 | 246 | # %% 247 | """ Prediction """ 248 | pred = torch.zeros(y_test.shape) 249 | index = 0 250 | test_e = torch.zeros(y_test.shape[0]) 251 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 252 | with torch.no_grad(): 253 | for x, y in test_loader: 254 | test_l2 = 0 255 | x, y = x.to(device), y.to(device) 256 | 257 | out = model(x).reshape(s, s) 258 | out = y_normalizer.decode(out) 259 | pred[index] = out 260 | 261 | test_l2 += myloss(out.reshape(1, s, s), y.reshape(1, s, s)).item() 262 | test_e[index] = test_l2 263 | print(index, test_l2) 264 | index = index + 1 265 | 266 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy(), '%') 267 | 268 | # %% 269 | """ Plotting """ # for paper figures please see 'WNO_testing_(.).py' files 270 | figure7 = plt.figure(figsize = (10, 5)) 271 | plt.subplots_adjust(hspace=0.01) 272 | index = 0 273 | for value in range(y_test.shape[0]): 274 | if value % 20 == 1: 275 | plt.subplot(2,5, index+1) 276 | plt.imshow(y_test[value,:,:], label='True', cmap='Spectral') 277 | plt.title('Actual') 278 | plt.subplot(2,5, index+1+5) 279 | plt.imshow(pred.cpu().detach().numpy()[value,:,:], cmap='Spectral') 280 | plt.title('Identified') 281 | plt.margins(0) 282 | index = index + 1 283 | 284 | # %% 285 | """ 286 | For saving the trained model and prediction data 287 | """ 288 | # torch.save(model, 'model/model_wno_darcy2d') 289 | # scipy.io.savemat('pred/pred_wno_darcy2d.mat', mdict={'pred': pred.cpu().numpy()}) 290 | # scipy.io.savemat('loss/train_loss_wno_darcy2d.mat', mdict={'train_loss': train_loss.cpu().numpy()}) 291 | # scipy.io.savemat('loss/test_loss_wno_darcy2d.mat', mdict={'test_loss': test_loss.cpu().numpy()}) 292 | 293 | # 
torch.cuda.empty_cache() 294 | -------------------------------------------------------------------------------- /Version 1.0.0/wno_2d_ERA5.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for forecast of monthly averaged 2m air temperature (time-independent problem). 7 | """ 8 | 9 | from IPython import get_ipython 10 | get_ipython().magic('reset -sf') 11 | 12 | # %% 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torch.nn.parameter import Parameter 18 | import matplotlib.pyplot as plt 19 | 20 | import xarray as xr 21 | from timeit import default_timer 22 | from utilities3 import * 23 | from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT) 24 | 25 | torch.manual_seed(0) 26 | np.random.seed(0) 27 | 28 | # %% 29 | """ Def: 2d Wavelet layer """ 30 | class WaveConv2d(nn.Module): 31 | def __init__(self, in_channels, out_channels, level, dummy): 32 | super(WaveConv2d, self).__init__() 33 | 34 | """ 35 | 2D Wavelet layer. It does DWT, linear transform, and Inverse dWT. 36 | """ 37 | 38 | self.in_channels = in_channels 39 | self.out_channels = out_channels 40 | self.level = level 41 | self.dwt_ = DWT(J=self.level, mode='symmetric', wave='db4').to(dummy.device) 42 | self.mode_data, _ = self.dwt_(dummy) 43 | self.modes1 = self.mode_data.shape[-2] 44 | self.modes2 = self.mode_data.shape[-1] 45 | 46 | self.scale = (1 / (in_channels * out_channels)) 47 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 48 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 49 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 50 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2)) 51 | 52 | # Convolution 53 | def mul2d(self, input, weights): 54 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 55 | return torch.einsum("bixy,ioxy->boxy", input, weights) 56 | 57 | def forward(self, x): 58 | batchsize = x.shape[0] 59 | # Compute single tree Discrete Wavelet coefficients using some wavelet 60 | dwt = DWT(J=self.level, mode='symmetric', wave='db4').to(device) 61 | x_ft, x_coeff = dwt(x) 62 | 63 | # Multiply relevant Wavelet modes 64 | out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-2], x_ft.shape[-1], device=x.device) 65 | out_ft = self.mul2d(x_ft, self.weights1) 66 | # Multiply the finer wavelet coefficients 67 | x_coeff[-1][:,:,0,:,:] = self.mul2d(x_coeff[-1][:,:,0,:,:].clone(), self.weights2) 68 | x_coeff[-1][:,:,1,:,:] = self.mul2d(x_coeff[-1][:,:,1,:,:].clone(), self.weights3) 69 | x_coeff[-1][:,:,2,:,:] = self.mul2d(x_coeff[-1][:,:,2,:,:].clone(), self.weights4) 70 | 71 | # Return to physical space 72 | idwt = IDWT(mode='symmetric', wave='db4').to(device) 73 | x = idwt((out_ft, x_coeff)) 74 | return x 75 | 76 | """ The forward operation """ 77 | class WNO2d(nn.Module): 78 | def __init__(self, width, level, dummy_data): 79 | super(WNO2d, self).__init__() 80 | 81 | """ 82 | The WNO network. It contains 4 layers of the Wavelet integral layer. 83 | 1. Lift the input using v(x) = self.fc0 . 84 | 2. 
4 layers of the integral operators v(+1) = g(K(.) + W)(v). 85 | W is defined by self.w_; K is defined by self.conv_. 86 | 3. Project the output of last layer using self.fc1 and self.fc2. 87 | 88 | input: the solution of the coefficient function and locations (a(x, y), x, y) 89 | input shape: (batchsize, x=s, y=s, c=3) 90 | output: the solution 91 | output shape: (batchsize, x=s, y=s, c=1) 92 | """ 93 | 94 | self.level = level 95 | self.dummy_data = dummy_data 96 | self.width = width 97 | self.padding = 2 # pad the domain if input is non-periodic 98 | self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y) 99 | 100 | self.conv0 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 101 | self.conv1 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 102 | self.conv2 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 103 | self.conv3 = WaveConv2d(self.width, self.width, self.level, self.dummy_data) 104 | self.w0 = nn.Conv2d(self.width, self.width, 1) 105 | self.w1 = nn.Conv2d(self.width, self.width, 1) 106 | self.w2 = nn.Conv2d(self.width, self.width, 1) 107 | self.w3 = nn.Conv2d(self.width, self.width, 1) 108 | 109 | self.fc1 = nn.Linear(self.width, 128) 110 | self.fc2 = nn.Linear(128, 1) 111 | 112 | def forward(self, x): 113 | grid = self.get_grid(x.shape, x.device) 114 | x = torch.cat((x, grid), dim=-1) 115 | 116 | x = self.fc0(x) 117 | x = x.permute(0, 3, 1, 2) 118 | x = F.pad(x, [0,self.padding, 0,self.padding]) # padding, if required 119 | 120 | x1 = self.conv0(x) 121 | x2 = self.w0(x) 122 | x = x1 + x2 123 | x = F.gelu(x) 124 | 125 | x1 = self.conv1(x) 126 | x2 = self.w1(x) 127 | x = x1 + x2 128 | x = F.gelu(x) 129 | 130 | x1 = self.conv2(x) 131 | x2 = self.w2(x) 132 | x = x1 + x2 133 | x = F.gelu(x) 134 | 135 | x1 = self.conv3(x) 136 | x2 = self.w3(x) 137 | x = x1 + x2 138 | 139 | x = x[..., :-self.padding, :-self.padding] # removing padding, when applicable 140 | x = x.permute(0, 2, 3, 1) 141 | x = self.fc1(x) 142 | x = F.gelu(x) 143 | x = self.fc2(x) 144 | return x 145 | 146 | def get_grid(self, shape, device): 147 | # The grid of the solution 148 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 149 | gridx = torch.tensor(np.linspace(0, 360, size_x), dtype=torch.float) 150 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 151 | gridy = torch.tensor(np.linspace(90, -90, size_y), dtype=torch.float) 152 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 153 | return torch.cat((gridx, gridy), dim=-1).to(device) 154 | 155 | # %% 156 | """ Model configurations """ 157 | PATH = 'data/ERA5_temp.grib' 158 | ntrain = 460 159 | ntest = 50 160 | 161 | batch_size = 10 162 | learning_rate = 0.001 163 | 164 | epochs = 500 165 | step_size = 25 166 | gamma = 0.75 167 | 168 | level = 5 169 | width = 64 170 | 171 | r = 6 172 | h = int(((721 - 1)/r)) 173 | s = int(((1441 - 1)/r)) 174 | 175 | # %% 176 | """ Read data """ 177 | 178 | ds = xr.open_dataset(PATH, engine='cfgrib') 179 | data = np.array(ds["t2m"]) 180 | data = torch.tensor(data) 181 | data = data[:, :720, :] 182 | # data = F.pad(data, [0,1]) # pad last dimension to make it periodic, if required 183 | 184 | x_train = data[:-1][:ntrain, ::r, ::r] 185 | y_train = data[1:][:ntrain, ::r, ::r] 186 | 187 | x_test = data[:-1][-ntest:, ::r, ::r] 188 | y_test = data[1:][-ntest:, ::r, ::r] 189 | 190 | x_normalizer = UnitGaussianNormalizer(x_train) 191 | x_train = x_normalizer.encode(x_train) 192 | x_test = x_normalizer.encode(x_test) 193 
| 194 | y_normalizer = UnitGaussianNormalizer(y_train) 195 | y_train = y_normalizer.encode(y_train) 196 | 197 | x_train = x_train.reshape(ntrain,h,s,1) 198 | x_test = x_test.reshape(ntest,h,s,1) 199 | 200 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), 201 | batch_size=batch_size, shuffle=True) 202 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), 203 | batch_size=batch_size, shuffle=False) 204 | 205 | # %% 206 | """ The model definition """ 207 | model = WNO2d(width, level, x_train.permute(0,3,1,2)).to(device) 208 | print(count_params(model)) 209 | 210 | """ Training and testing """ 211 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 212 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 213 | 214 | myloss = LpLoss(size_average=False) 215 | y_normalizer.cuda() 216 | for ep in range(epochs): 217 | model.train() 218 | t1 = default_timer() 219 | train_l2 = 0 220 | for x, y in train_loader: 221 | x, y = x.to(device), y.to(device) 222 | 223 | optimizer.zero_grad() 224 | out = model(x).reshape(batch_size, h, s) 225 | out = y_normalizer.decode(out) 226 | y = y_normalizer.decode(y) 227 | 228 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 229 | loss.backward() 230 | optimizer.step() 231 | train_l2 += loss.item() 232 | 233 | scheduler.step() 234 | model.eval() 235 | test_l2 = 0.0 236 | with torch.no_grad(): 237 | for x, y in test_loader: 238 | x, y = x.to(device), y.to(device) 239 | 240 | out = model(x).reshape(x.shape[0], h, s) 241 | out = y_normalizer.decode(out) 242 | 243 | test_l2 += myloss(out.view(x.shape[0],-1), y.view(x.shape[0],-1)).item() 244 | 245 | train_l2/= ntrain 246 | test_l2 /= ntest 247 | t2 = default_timer() 248 | print(ep, t2-t1, train_l2, test_l2) 249 | 250 | # %% 251 | """ Prediction """ 252 | pred = torch.zeros(y_test.shape) 253 | index = 0 254 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 255 | with torch.no_grad(): 256 | for x, y in test_loader: 257 | test_l2 = 0 258 | x, y = x.to(device), y.to(device) 259 | 260 | out = model(x).reshape(h, s) 261 | out = y_normalizer.decode(out) 262 | pred[index] = out 263 | 264 | test_l2 += myloss(out.reshape(1, h, s), y.reshape(1, h, s)).item() 265 | print(index, test_l2) 266 | index = index + 1 267 | 268 | # %% 269 | """ Plotting """ # for paper figures please see 'WNO_testing_(.).py' files 270 | figure7 = plt.figure(figsize = (12, 8)) 271 | plt.subplots_adjust(hspace=0.01, wspace=0.25) 272 | index = 1 273 | for value in range(y_test.shape[0]): 274 | if value % 15 == 1 and value != 1: 275 | plt.subplot(3,3, index) 276 | plt.imshow(y_test[value,:,:-1].cpu().numpy(), label='True', cmap='turbo') 277 | plt.title('Actual') 278 | plt.subplot(3,3, index+3) 279 | plt.imshow(pred.cpu().detach().numpy()[value,:,:-1], cmap='turbo') 280 | plt.title('Identified') 281 | plt.subplot(3,3, index+6) 282 | plt.imshow(pred[value,:,:-2]-y_test[value,:,:-2], cmap='turbo') 283 | plt.title('Error') 284 | plt.colorbar(fraction=0.024, pad=0.01) 285 | plt.margins(0) 286 | index = index + 1 287 | 288 | # %% 289 | """ 290 | For saving the trained model and prediction data 291 | """ 292 | # torch.save(model, 'model/model_wno_ERA5_t2m') 293 | # scipy.io.savemat('pred/pred_wno_ERA5_t2m.mat', mdict={'pred': pred.cpu().numpy()}) 294 | 295 | # torch.cuda.empty_cache() 296 | 
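# %%
"""
Illustrative extra (not part of the original script): a minimal autoregressive
rollout of the trained one-step model, feeding each forecast back in as the next
input. It reuses the objects defined above (model, data, x_normalizer,
y_normalizer, r, h, s, ntest, device) and is meant only as a hedged sketch of how
such a rollout could be done, not as the evaluation protocol used in the paper.
"""
with torch.no_grad():
    # start from the first test-period snapshot, normalized and shaped (1, h, s, 1)
    xx = x_normalizer.encode(data[-ntest-1, ::r, ::r]).reshape(1, h, s, 1).to(device)
    rollout = []
    for step in range(ntest):
        out = y_normalizer.decode(model(xx).reshape(1, h, s))    # one-step forecast in physical units
        rollout.append(out.cpu())
        # re-normalize the forecast and feed it back as the next input
        xx = x_normalizer.encode(out.cpu()).reshape(1, h, s, 1).to(device)
    rollout = torch.cat(rollout)                                 # shape: (ntest, h, s)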
-------------------------------------------------------------------------------- /Version 2.0.0/README.md: -------------------------------------------------------------------------------- 1 | # Wavelet-Neural-Operator (WNO) 2 | This repository contains the python codes of the paper 3 | > + Tripura, T., & Chakraborty, S. (2023). Wavelet Neural Operator for solving parametric partial differential equations in computational mechanics problems. Computer Methods in Applied Mechanics and Engineering, 404, 115783. [Article](https://doi.org/10.1016/j.cma.2022.115783) 4 | > + ArXiv version- "Wavelet neural operator: a neural operator for parametric partial differential equations". The arXiv version can be accessed [here](https://arxiv.org/abs/2205.02191). 5 | 6 | ## New in version 2.0.0 7 | ``` 8 | > Added superresolution attribute to the WNO. 9 | > Added 3D support to the WNO. 10 | > Improved the interface and readability of the codes. 11 | ``` 12 | 13 | ## Architecture of the wavelet neural operator (WNO). 14 | (a) Schematic of the proposed neural operator. (b) A simple WNO with one wavelet kernel integral layer. 15 | ![WNO](/Github_page_images/WNN.png) 16 | 17 | ## Construction of the parametric space using multi-level wavelet decomposition. 18 | ![Construction of parameterization space in WNO](/Github_page_images/WNN_parameter.png) 19 | 20 | ## Super resolution using Wavelet Neural Operator. 21 | > Super resolution in Burgers' diffusion dynamics: 22 | ![Train at resolution-1024 and Test at resolution-2048](/Github_page_images/Burgers_prediction.png) 23 | > Super resolution in Navier-Stokes equation with 10000 Reynolds number: 24 | ![Train in Low resolution](/Github_page_images/Animation_ns_64_3d_1e-4.gif) 25 | ![Test in High resolution](/Github_page_images/Animation_ns_256_3d_1e-4.gif) 26 | 27 | ## Files 28 | A short despcription on the files are provided below for ease of readers. 29 | ``` 30 | + `wno1d_advection_III.py`: For 1D wave advection equation (time-independent problem). 31 | + `wno1d_Advection_time_III.py`: For 1D wave advection equation with time-marching (time-dependent problem). 32 | + `wno1d_Burger_discontinuous.py`: For 1D Burgers' equation with discontinuous field (time-dependent problem). 33 | + `wno1d_Burgers.py`: For 1D Burger's equation (time-independent problem). 34 | + `wno2d_AC_dwt.py`: For 2D Allen-Cahn reaction-diffusion equation (time-independent problem). 35 | + `wno2d_Darcy_notch_cwt.py`: For 2D Darcy equation using Slim Continuous Wavelet Transform (time-independent problem). 36 | + `wno2d_Darcy_notch_dwt.py`: For 2D Darcy equation using Discrete wavelet transform (time-independent problem). 37 | + `wno2d_NS_cwt.py`: For 2D Navier-Stokes equation using Slim Continuous Wavelet Transform (time-dependent problem). 38 | + `wno2d_NS_dwt.py`: For 2D Navier-Stokes equation using Discrete Wavelet Transform (time-dependent problem). 39 | + `wno2d_Temperature_Daily_Avg.py`: For forecasting daily averaged 2m air temperature (time-dependent problem). 40 | + `wno2d_Temperature_Monthly_Avg.py`: For forecasting monthly averaged 2m air temperature (time-independent problem). 41 | + `wno3d_NS.py`: For 2D Navier-Stokes equation using 3D WNO (as a time-independent problem). 42 | 43 | + `Test_wno_super_1d_Burgers.py`: An example of Testing on new data with supersolution. 44 | 45 | + `utils.py` contains some useful functions for data handling (improvised from [FNO paper](https://github.com/zongyi-li/fourier_neural_operator)). 
46 | + `wavelet_convolution.py` contains functions for 1D, 2D, and 3D convolution in the wavelet domain. 47 | ``` 48 | 49 | ## Essential Python Libraries 50 | The following packages are required to be installed to run the above codes: 51 | + [PyTorch](https://pytorch.org/) 52 | + [PyWavelets - Wavelet Transforms in Python](https://pywavelets.readthedocs.io/en/latest/) 53 | + [Wavelet Transforms in Pytorch](https://github.com/fbcotter/pytorch_wavelets) 54 | + [Wavelet Transform Toolbox](https://github.com/v0lta/PyTorch-Wavelet-Toolbox) 55 | + [Xarray-Grib reader (To read ERA5 data in section 5)](https://docs.xarray.dev/en/stable/getting-started-guide/installing.html?highlight=install) 56 | 57 | Copy all the data into the folder 'data' and place the folder 'data' inside the same parent folder where the codes are present. In case the location of the data is changed, the correct path should be given. 58 | 59 | ## Testing 60 | For performing predictions on new inputs, one can use the 'WNO_testing_(.).py' codes given in the `Testing` folder of the previous version. The trained models that were used to produce the results in the WNO paper can be found at the following link: 61 | > [Models](https://drive.google.com/drive/folders/1scfrpChQ1wqFu8VAyieoSrdgHYCbrT6T?usp=sharing) 62 | 63 | However, these trained models are not compatible with the version 2.0.0 codes. Hence, new models need to be trained accordingly. 64 | 65 | ## Dataset 66 | + The training and testing datasets for the (i) Burgers equation with discontinuity in the solution field (section 4.1), (ii) 2-D Allen-Cahn equation (section 4.5), and (iii) weekly and monthly mean 2m air temperature (section 5) are available at the following link: 67 | > [Dataset-1](https://drive.google.com/drive/folders/1scfrpChQ1wqFu8VAyieoSrdgHYCbrT6T?usp=sharing) \ 68 | The datasets for the weekly and monthly mean 2m air temperature are downloaded from the 'European Centre for Medium-Range Weather Forecasts (ECMWF)' database. For more information on the dataset, one can browse the link 69 | [ECMWF](https://www.ecmwf.int/en/forecasts/datasets/browse-reanalysis-datasets). 70 | + The datasets for the (i) 1-D Burgers equation ('burgers_data_R10.zip'), (ii) 2-D Darcy flow equation in a rectangular domain ('Darcy_421.zip'), and (iii) 2-D time-dependent Navier-Stokes equation ('ns_V1e-3_N5000_T50.zip') are taken from the following link: 71 | > [Dataset-2](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-) 72 | + The datasets for the 2-D Darcy flow equation with a notch in a triangular domain ('Darcy_Triangular_FNO.mat') and the 1-D time-dependent wave advection equation are taken from the following link: 73 | > [Dataset-3](https://github.com/lu-group/deeponet-fno/tree/main/data) 74 | 75 | ## BibTex 76 | If you use any part of our codes, please cite us as, 77 | ``` 78 | @article{tripura2023wavelet, 79 | title={Wavelet Neural Operator for solving parametric partial differential equations in computational mechanics problems}, 80 | author={Tripura, Tapas and Chakraborty, Souvik}, 81 | journal={Computer Methods in Applied Mechanics and Engineering}, 82 | volume={404}, 83 | pages={115783}, 84 | year={2023}, 85 | publisher={Elsevier} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /Version 2.0.0/Test_wno_super_1d_Burgers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). 
Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for 1-D Burger's equation (time-independent problem). 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import matplotlib.pyplot as plt 14 | 15 | from timeit import default_timer 16 | from utils import * 17 | from wavelet_convolution import WaveConv1d 18 | 19 | torch.manual_seed(0) 20 | np.random.seed(0) 21 | 22 | # %% 23 | """ The forward operation """ 24 | class WNO1d(nn.Module): 25 | def __init__(self, width, level, layers, size, wavelet, in_channel, grid_range, padding=0): 26 | super(WNO1d, self).__init__() 27 | 28 | """ 29 | The WNO network. It contains l-layers of the Wavelet integral layer. 30 | 1. Lift the input using v(x) = self.fc0 . 31 | 2. l-layers of the integral operators v(j+1)(x) = g(K.v + W.v)(x). 32 | --> W is defined by self.w; K is defined by self.conv. 33 | 3. Project the output of last layer using self.fc1 and self.fc2. 34 | 35 | Input : 2-channel tensor, Initial condition and location (a(x), x) 36 | : shape: (batchsize * x=s * c=2) 37 | Output: Solution of a later timestep (u(x)) 38 | : shape: (batchsize * x=s * c=1) 39 | 40 | Input parameters: 41 | ----------------- 42 | width : scalar, lifting dimension of input 43 | level : scalar, number of wavelet decomposition 44 | layers: scalar, number of wavelet kernel integral blocks 45 | size : scalar, signal length 46 | wavelet: string, wavelet filter 47 | in_channel: scalar, channels in input including grid 48 | grid_range: scalar (for 1D), right support of 1D domain 49 | padding : scalar, size of zero padding 50 | """ 51 | 52 | self.level = level 53 | self.width = width 54 | self.layers = layers 55 | self.size = size 56 | self.wavelet = wavelet 57 | self.in_channel = in_channel 58 | self.grid_range = grid_range 59 | self.padding = padding 60 | 61 | self.conv = nn.ModuleList() 62 | self.w = nn.ModuleList() 63 | 64 | self.fc0 = nn.Linear(self.in_channel, self.width) # input channel is 2: (a(x), x) 65 | for i in range( self.layers ): 66 | self.conv.append( WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet) ) 67 | self.w.append( nn.Conv1d(self.width, self.width, 1) ) 68 | self.fc1 = nn.Linear(self.width, 128) 69 | self.fc2 = nn.Linear(128, 1) 70 | 71 | def forward(self, x): 72 | grid = self.get_grid(x.shape, x.device) 73 | x = torch.cat((x, grid), dim=-1) 74 | x = self.fc0(x) # Shape: Batch * x * Channel 75 | x = x.permute(0, 2, 1) # Shape: Batch * Channel * x 76 | if self.padding != 0: 77 | x = F.pad(x, [0,self.padding]) 78 | 79 | for index, (convl, wl) in enumerate( zip(self.conv, self.w) ): 80 | x = convl(x) + wl(x) 81 | if index != self.layers - 1: # Final layer has no activation 82 | x = F.mish(x) # Shape: Batch * Channel * x 83 | 84 | if self.padding != 0: 85 | x = x[..., :-self.padding] 86 | x = x.permute(0, 2, 1) # Shape: Batch * x * Channel 87 | x = F.gelu( self.fc1(x) ) # Shape: Batch * x * Channel 88 | x = self.fc2(x) # Shape: Batch * x * Channel 89 | return x 90 | 91 | def get_grid(self, shape, device): 92 | # The grid of the solution 93 | batchsize, size_x = shape[0], shape[1] 94 | gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float) 95 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 96 | return gridx.to(device) 97 | 98 | 99 | # %% 100 | """ Model configurations """ 101 | 102 | PATH = 'data/burgers_data_R10.mat' 103 | ntrain = 1000 104 | 
ntest = 100 105 | 106 | batch_size = 20 107 | learning_rate = 0.001 108 | 109 | epochs = 500 110 | step_size = 50 # weight-decay step size 111 | gamma = 0.5 # weight-decay rate 112 | 113 | wavelet = 'db6' # wavelet basis function 114 | level = 8 # lavel of wavelet decomposition 115 | width = 64 # uplifting dimension 116 | layers = 4 # no of wavelet layers 117 | 118 | sub = 2**3 # subsampling rate 119 | test_sub = 2**2 120 | h = 2**13 // sub # total grid size divided by the subsampling rate 121 | grid_range = 1 122 | in_channel = 2 # (a(x), x) for this case 123 | 124 | # %% 125 | """ Read data """ 126 | 127 | dataloader = MatReader(PATH) 128 | x_data_1024 = dataloader.read_field('a')[:,::sub] 129 | y_data_1024 = dataloader.read_field('u')[:,::sub] 130 | 131 | x_data_2048 = dataloader.read_field('a')[:,::test_sub] 132 | y_data_2048 = dataloader.read_field('u')[:,::test_sub] 133 | 134 | x_train_1024, y_train_1024 = x_data_1024[:ntrain,:], y_data_1024[:ntrain,:] 135 | x_test_1024, y_test_1024 = x_data_1024[-ntest:,:], y_data_1024[-ntest:,:] 136 | 137 | x_train_2048, y_train_2048 = x_data_2048[:ntrain,:], y_data_2048[:ntrain,:] 138 | x_test_2048, y_test_2048 = x_data_2048[-ntest:,:], y_data_2048[-ntest:,:] 139 | 140 | x_train_1024 = x_train_1024[:, :, None] 141 | x_test_1024 = x_test_1024[:, :, None] 142 | 143 | x_train_2048 = x_train_2048[:, :, None] 144 | x_test_2048 = x_test_2048[:, :, None] 145 | 146 | train_loader_1024 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_1024, y_train_1024), 147 | batch_size=batch_size, shuffle=True) 148 | test_loader_1024 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_1024, y_test_1024), 149 | batch_size=batch_size, shuffle=False) 150 | 151 | train_loader_2048 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_2048, y_train_2048), 152 | batch_size=batch_size, shuffle=True) 153 | test_loader_2048 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_2048, y_test_2048), 154 | batch_size=batch_size, shuffle=False) 155 | 156 | # %% 157 | """ The model definition """ 158 | model = WNO1d(width=width, level=level, layers=layers, size=h, wavelet=wavelet, 159 | in_channel=in_channel, grid_range=grid_range).to(device) 160 | print(count_params(model)) 161 | 162 | model.load_state_dict(torch.load('model/WNO_burgers', map_location=device).state_dict()) 163 | myloss = LpLoss(size_average=False) 164 | 165 | # %% 166 | """ Prediction """ 167 | pred_1024, pred_2048 = [], [] 168 | test_e_1024, test_e_2048 = [], [] 169 | with torch.no_grad(): 170 | 171 | index = 0 172 | for x, y in test_loader_1024: 173 | test_l2 = 0 174 | x, y = x.to(device), y.to(device) 175 | 176 | out = model(x) 177 | test_l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 178 | 179 | test_e_1024.append( test_l2/batch_size ) 180 | pred_1024.append( out.cpu() ) 181 | print("Batch-{}, Train and Test at {}, Test-loss-{:0.6f}".format( index, 2**13//sub, test_l2/batch_size )) 182 | index += 1 183 | 184 | index = 0 185 | for x, y in test_loader_2048: 186 | test_l2 = 0 187 | x, y = x.to(device), y.to(device) 188 | 189 | out = model(x) 190 | test_l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 191 | 192 | test_e_2048.append( test_l2/batch_size ) 193 | pred_2048.append( out.cpu() ) 194 | print("Batch-{}, Train at {} and Test at {}, Test-loss-{:0.6f}".format( index, 2**13//sub, 195 | 2**13//test_sub, test_l2/batch_size )) 196 | index += 1 197 | 198 | pred_1024 = torch.cat(( pred_1024 )) 199 | pred_2048 = 
torch.cat(( pred_2048 )) 200 | test_e_1024 = torch.tensor(( test_e_1024 )) 201 | test_e_2048 = torch.tensor(( test_e_2048 )) 202 | 203 | print('\nMean Error: Resolution-1024-{:0.4f}, Resolution-2048-{:0.4f}' 204 | .format(100*torch.mean(test_e_1024).numpy(), 100*torch.mean(test_e_2048).numpy())) 205 | 206 | # %% 207 | plt.rcParams['font.family'] = 'Times New Roman' 208 | plt.rcParams['font.size'] = 14 209 | plt.rcParams['mathtext.fontset'] = 'dejavuserif' 210 | 211 | colormap = plt.cm.rainbow 212 | colors = [colormap(i) for i in np.linspace(0, 1, 5)] 213 | 214 | """ Plotting """ 215 | figure7, ax = plt.subplots(nrows=2, ncols=1, figsize=(12, 8), dpi=300) 216 | plt.subplots_adjust(hspace=0.35) 217 | index = 0 218 | for i in range(ntest): 219 | if i % 20 == 0: 220 | ax[0].plot(y_test_1024[i, :].cpu().numpy(), color=colors[index], label='Truth-s{}'.format(i)) 221 | ax[0].plot(pred_1024[i,:].cpu().numpy(), '--', color=colors[index], label='WNO-s{}'.format(i)) 222 | ax[0].grid(True, alpha=0.35) 223 | ax[0].legend(ncol=5, columnspacing=0.5, labelspacing=0.25, handletextpad=0.25, borderpad=0.15) 224 | ax[0].margins(0) 225 | ax[0].set_title('Train at resolution 1024', fontweight='bold', fontsize=plt.rcParams['font.size']*1.2) 226 | 227 | ax[1].plot(y_test_2048[i, :].cpu().numpy(), color=colors[index], label='Truth-s{}'.format(i)) 228 | ax[1].plot(pred_2048[i,:].cpu().numpy(), '--', color=colors[index], label='WNO-s{}'.format(i)) 229 | ax[1].grid(True, alpha=0.35) 230 | ax[1].legend(ncol=5, columnspacing=0.5, labelspacing=0.25, handletextpad=0.25, borderpad=0.15) 231 | ax[1].margins(0) 232 | ax[1].set_title('Test at resolution 2048', fontweight='bold', fontsize=plt.rcParams['font.size']*1.2) 233 | index += 1 234 | ax[0].set_xlabel('Space') 235 | ax[1].set_xlabel('Space') 236 | ax[0].set_ylabel('$u(x,1)$') 237 | ax[1].set_ylabel('$u(x,1)$') 238 | figure7.suptitle('Superresolution in Burgers equation', fontweight='bold', y=0.95, fontsize=plt.rcParams['font.size']*1.4) 239 | plt.show() 240 | 241 | # figure7.savefig('Burgers_prediction.png', format='png', dpi=300, bbox_inches='tight') 242 | -------------------------------------------------------------------------------- /Version 2.0.0/__pycache__/utilities3.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Version 2.0.0/__pycache__/utilities3.cpython-39.pyc -------------------------------------------------------------------------------- /Version 2.0.0/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Version 2.0.0/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /Version 2.0.0/__pycache__/wavelet_convolution.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Version 2.0.0/__pycache__/wavelet_convolution.cpython-39.pyc -------------------------------------------------------------------------------- /Version 2.0.0/data/Burger_data/burgerbc.m: -------------------------------------------------------------------------------- 1 | function [pl,ql,pr,qr] = burgerbc(xl,ul,xr,ur,t) 2 | pl = ul; 3 | ql = 0; 4 | pr = ur; 5 | qr = 0; 6 | end 
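% Note on the boundary conditions: pdepe expects them in the form
%   p(x,t,u) + q(x,t,u) * f(x,t,u,dudx) = 0,
% so pl = ul, ql = 0 and pr = ur, qr = 0 impose homogeneous Dirichlet conditions
% u(-L,t) = u(L,t) = 0 on the Burgers solutions generated by main_burger.m.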
-------------------------------------------------------------------------------- /Version 2.0.0/data/Burger_data/burgeric.m: -------------------------------------------------------------------------------- 1 | function u0 = burgeric(x) 2 | global np 3 | u0 = -sin(pi*x)+np*sin(pi*x); 4 | end -------------------------------------------------------------------------------- /Version 2.0.0/data/Burger_data/burgerpde.m: -------------------------------------------------------------------------------- 1 | function [c,f,s] = burgerpde(x,t,u,dudx) 2 | c = 1; 3 | f = (0.01/pi)*dudx; 4 | s = -u*dudx; 5 | end -------------------------------------------------------------------------------- /Version 2.0.0/data/Burger_data/main_burger.m: -------------------------------------------------------------------------------- 1 | clear 2 | close all 3 | clc 4 | 5 | L = 1; 6 | x = linspace(-L, L, 512); 7 | t = linspace(0, 1, 51); 8 | sample = 500; 9 | a = 0; 10 | b = 0.5; 11 | r = (b-a).*rand(sample,1) + a; 12 | 13 | global np 14 | rng(1) 15 | 16 | sol = zeros(sample,512,51); 17 | m=0; 18 | % np=0; 19 | % sol1 = pdepe(m,@burgerpde,@burgeric,@burgerbc,x,t); 20 | % surf(sol1) 21 | 22 | for i=1:sample 23 | i 24 | np = r(i); 25 | sol1 = pdepe(m,@burgerpde,@burgeric,@burgerbc,x,t); 26 | sol(i,:,:) = sol1'; 27 | end 28 | 29 | %% 30 | figure; 31 | index = 1; 32 | for i=1:sample 33 | if mod(i,50)==0 34 | i 35 | subplot(2,5,index); imagesc(squeeze(sol(i,:,:))); 36 | index = index+1; 37 | end 38 | end 39 | 40 | %% 41 | figure; surf(squeeze(sol(499,:,:))); 42 | 43 | save ('burgers_data_512_51.mat', 't', 'x', 'sol' ) 44 | -------------------------------------------------------------------------------- /Version 2.0.0/data/test_IC2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Version 2.0.0/data/test_IC2.npz -------------------------------------------------------------------------------- /Version 2.0.0/data/train_IC2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Version 2.0.0/data/train_IC2.npz -------------------------------------------------------------------------------- /Version 2.0.0/model/WNO_burgers: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccm-iitd/WNO/ddd810f68424d8ef0dbde3ca3ff4783146bad349/Version 2.0.0/model/WNO_burgers -------------------------------------------------------------------------------- /Version 2.0.0/utils.py: -------------------------------------------------------------------------------- 1 | """ Load required packages """ 2 | 3 | import torch 4 | import numpy as np 5 | import scipy.io 6 | import h5py 7 | import torch.nn as nn 8 | 9 | import operator 10 | from functools import reduce 11 | from functools import partial 12 | 13 | 14 | """ Utility functions for --loading data, 15 | --data normalization, 16 | --data standerdization, and 17 | --loss evaluation 18 | 19 | The base codes are taken from the repo: https://github.com/zongyi-li/fourier_neural_operator 20 | """ 21 | 22 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 23 | 24 | # reading data 25 | class MatReader(object): 26 | def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): 27 | super(MatReader, self).__init__() 28 | 29 | self.to_torch = to_torch 30 | self.to_cuda = 
to_cuda 31 | self.to_float = to_float 32 | 33 | self.file_path = file_path 34 | 35 | self.data = None 36 | self.old_mat = None 37 | self._load_file() 38 | 39 | def _load_file(self): 40 | try: 41 | self.data = scipy.io.loadmat(self.file_path) 42 | self.old_mat = True 43 | except: 44 | self.data = h5py.File(self.file_path) 45 | self.old_mat = False 46 | 47 | def load_file(self, file_path): 48 | self.file_path = file_path 49 | self._load_file() 50 | 51 | def read_field(self, field): 52 | x = self.data[field] 53 | 54 | if not self.old_mat: 55 | x = x[()] 56 | x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1)) 57 | 58 | if self.to_float: 59 | x = x.astype(np.float32) 60 | 61 | if self.to_torch: 62 | x = torch.from_numpy(x) 63 | 64 | if self.to_cuda: 65 | x = x.to(device) 66 | 67 | return x 68 | 69 | def set_cuda(self, to_cuda): 70 | self.to_cuda = to_cuda 71 | 72 | def set_torch(self, to_torch): 73 | self.to_torch = to_torch 74 | 75 | def set_float(self, to_float): 76 | self.to_float = to_float 77 | 78 | # normalization, pointwise gaussian 79 | class UnitGaussianNormalizer(object): 80 | def __init__(self, x, eps=0.00001): 81 | super(UnitGaussianNormalizer, self).__init__() 82 | 83 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 84 | self.mean = torch.mean(x, 0) 85 | self.std = torch.std(x, 0) 86 | self.eps = eps 87 | 88 | def encode(self, x): 89 | x = (x - self.mean) / (self.std + self.eps) 90 | return x 91 | 92 | def decode(self, x, sample_idx=None): 93 | if sample_idx is None: 94 | std = self.std + self.eps # n 95 | mean = self.mean 96 | else: 97 | if len(self.mean.shape) == len(sample_idx[0].shape): 98 | std = self.std[sample_idx] + self.eps # batch*n 99 | mean = self.mean[sample_idx] 100 | if len(self.mean.shape) > len(sample_idx[0].shape): 101 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 102 | mean = self.mean[:,sample_idx] 103 | 104 | # x is in shape of batch*n or T*batch*n 105 | x = (x * std) + mean 106 | return x 107 | 108 | def cuda(self): 109 | self.mean = self.mean.cuda() 110 | self.std = self.std.cuda() 111 | 112 | def cpu(self): 113 | self.mean = self.mean.cpu() 114 | self.std = self.std.cpu() 115 | 116 | def to(self, device): 117 | self.mean = self.mean.to(device) 118 | self.std = self.std.to(device) 119 | 120 | # normalization, scaling by range 121 | class RangeNormalizer(object): 122 | def __init__(self, x, low=0.0, high=1.0): 123 | super(RangeNormalizer, self).__init__() 124 | mymin = torch.min(x, 0)[0].view(-1) 125 | mymax = torch.max(x, 0)[0].view(-1) 126 | 127 | self.a = (high - low)/(mymax - mymin) 128 | self.b = -self.a*mymax + high 129 | 130 | def encode(self, x): 131 | s = x.size() 132 | x = x.view(s[0], -1) 133 | x = self.a*x + self.b 134 | x = x.view(s) 135 | return x 136 | 137 | def decode(self, x): 138 | s = x.size() 139 | x = x.view(s[0], -1) 140 | x = (x - self.b)/self.a 141 | x = x.view(s) 142 | return x 143 | 144 | # loss function with rel/abs Lp loss 145 | class LpLoss(object): 146 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 147 | super(LpLoss, self).__init__() 148 | 149 | #Dimension and Lp-norm type are postive 150 | assert d > 0 and p > 0 151 | 152 | self.d = d 153 | self.p = p 154 | self.reduction = reduction 155 | self.size_average = size_average 156 | 157 | def abs(self, x, y): 158 | num_examples = x.size()[0] 159 | 160 | #Assume uniform mesh 161 | h = 1.0 / (x.size()[1] - 1.0) 162 | 163 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 164 | 165 
| if self.reduction: 166 | if self.size_average: 167 | return torch.mean(all_norms) 168 | else: 169 | return torch.sum(all_norms) 170 | 171 | return all_norms 172 | 173 | def rel(self, x, y): 174 | num_examples = x.size()[0] 175 | 176 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 177 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 178 | 179 | if self.reduction: 180 | if self.size_average: 181 | return torch.mean(diff_norms/y_norms) 182 | else: 183 | return torch.sum(diff_norms/y_norms) 184 | 185 | return diff_norms/y_norms 186 | 187 | def __call__(self, x, y): 188 | return self.rel(x, y) 189 | 190 | # print the number of parameters 191 | def count_params(model): 192 | c = 0 193 | for p in list(model.parameters()): 194 | c += reduce(operator.mul, 195 | list(p.size()+(2,) if p.is_complex() else p.size())) 196 | return c 197 | -------------------------------------------------------------------------------- /Version 2.0.0/wno1d_Burgers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for 1-D Burger's equation (time-independent problem). 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import matplotlib.pyplot as plt 14 | 15 | from timeit import default_timer 16 | from utils import * 17 | from wavelet_convolution import WaveConv1d 18 | 19 | torch.manual_seed(0) 20 | np.random.seed(0) 21 | 22 | # %% 23 | """ The forward operation """ 24 | class WNO1d(nn.Module): 25 | def __init__(self, width, level, layers, size, wavelet, in_channel, grid_range, padding=0): 26 | super(WNO1d, self).__init__() 27 | 28 | """ 29 | The WNO network. It contains l-layers of the Wavelet integral layer. 30 | 1. Lift the input using v(x) = self.fc0 . 31 | 2. l-layers of the integral operators v(j+1)(x) = g(K.v + W.v)(x). 32 | --> W is defined by self.w; K is defined by self.conv. 33 | 3. Project the output of last layer using self.fc1 and self.fc2. 
34 | 35 | Input : 2-channel tensor, Initial condition and location (a(x), x) 36 | : shape: (batchsize * x=s * c=2) 37 | Output: Solution of a later timestep (u(x)) 38 | : shape: (batchsize * x=s * c=1) 39 | 40 | Input parameters: 41 | ----------------- 42 | width : scalar, lifting dimension of input 43 | level : scalar, number of wavelet decomposition 44 | layers: scalar, number of wavelet kernel integral blocks 45 | size : scalar, signal length 46 | wavelet: string, wavelet filter 47 | in_channel: scalar, channels in input including grid 48 | grid_range: scalar (for 1D), right support of 1D domain 49 | padding : scalar, size of zero padding 50 | """ 51 | 52 | self.level = level 53 | self.width = width 54 | self.layers = layers 55 | self.size = size 56 | self.wavelet = wavelet 57 | self.in_channel = in_channel 58 | self.grid_range = grid_range 59 | self.padding = padding 60 | 61 | self.conv = nn.ModuleList() 62 | self.w = nn.ModuleList() 63 | 64 | self.fc0 = nn.Linear(self.in_channel, self.width) # input channel is 2: (a(x), x) 65 | for i in range( self.layers ): 66 | self.conv.append( WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet) ) 67 | self.w.append( nn.Conv1d(self.width, self.width, 1) ) 68 | self.fc1 = nn.Linear(self.width, 128) 69 | self.fc2 = nn.Linear(128, 1) 70 | 71 | def forward(self, x): 72 | grid = self.get_grid(x.shape, x.device) 73 | x = torch.cat((x, grid), dim=-1) 74 | x = self.fc0(x) # Shape: Batch * x * Channel 75 | x = x.permute(0, 2, 1) # Shape: Batch * Channel * x 76 | if self.padding != 0: 77 | x = F.pad(x, [0,self.padding]) 78 | 79 | for index, (convl, wl) in enumerate( zip(self.conv, self.w) ): 80 | x = convl(x) + wl(x) 81 | if index != self.layers - 1: # Final layer has no activation 82 | x = F.mish(x) # Shape: Batch * Channel * x 83 | 84 | if self.padding != 0: 85 | x = x[..., :-self.padding] 86 | x = x.permute(0, 2, 1) # Shape: Batch * x * Channel 87 | x = F.gelu( self.fc1(x) ) # Shape: Batch * x * Channel 88 | x = self.fc2(x) # Shape: Batch * x * Channel 89 | return x 90 | 91 | def get_grid(self, shape, device): 92 | # The grid of the solution 93 | batchsize, size_x = shape[0], shape[1] 94 | gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float) 95 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 96 | return gridx.to(device) 97 | 98 | 99 | # %% 100 | """ Model configurations """ 101 | 102 | PATH = 'data/burgers_data_R10.mat' 103 | ntrain = 1000 104 | ntest = 100 105 | 106 | batch_size = 20 107 | learning_rate = 0.001 108 | 109 | epochs = 500 110 | step_size = 50 # weight-decay step size 111 | gamma = 0.5 # weight-decay rate 112 | 113 | wavelet = 'db6' # wavelet basis function 114 | level = 8 # lavel of wavelet decomposition 115 | width = 64 # uplifting dimension 116 | layers = 4 # no of wavelet layers 117 | 118 | sub = 2**3 # subsampling rate 119 | h = 2**13 // sub # total grid size divided by the subsampling rate 120 | grid_range = 1 121 | in_channel = 2 # (a(x), x) for this case 122 | 123 | # %% 124 | """ Read data """ 125 | 126 | dataloader = MatReader(PATH) 127 | x_data = dataloader.read_field('a')[:,::sub] 128 | y_data = dataloader.read_field('u')[:,::sub] 129 | 130 | x_train = x_data[:ntrain,:] 131 | y_train = y_data[:ntrain,:] 132 | x_test = x_data[-ntest:,:] 133 | y_test = y_data[-ntest:,:] 134 | 135 | x_train = x_train[:, :, None] 136 | x_test = x_test[:, :, None] 137 | 138 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), 139 | 
batch_size=batch_size, shuffle=True) 140 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), 141 | batch_size=batch_size, shuffle=False) 142 | 143 | # %% 144 | """ The model definition """ 145 | model = WNO1d(width=width, level=level, layers=layers, size=h, wavelet=wavelet, 146 | in_channel=in_channel, grid_range=grid_range).to(device) 147 | print(count_params(model)) 148 | 149 | """ Training and testing """ 150 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6) 151 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 152 | 153 | train_loss = torch.zeros(epochs) 154 | test_loss = torch.zeros(epochs) 155 | myloss = LpLoss(size_average=False) 156 | for ep in range(epochs): 157 | model.train() 158 | t1 = default_timer() 159 | train_mse = 0 160 | train_l2 = 0 161 | for x, y in train_loader: 162 | x, y = x.to(device), y.to(device) 163 | 164 | optimizer.zero_grad() 165 | out = model(x) 166 | 167 | mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1)) 168 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 169 | l2.backward() # l2 relative loss 170 | 171 | optimizer.step() 172 | train_mse += mse.item() 173 | train_l2 += l2.item() 174 | 175 | scheduler.step() 176 | model.eval() 177 | test_l2 = 0.0 178 | with torch.no_grad(): 179 | for x, y in test_loader: 180 | x, y = x.to(device), y.to(device) 181 | 182 | out = model(x) 183 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 184 | 185 | train_mse /= len(train_loader) 186 | train_l2 /= ntrain 187 | test_l2 /= ntest 188 | 189 | train_loss[ep] = train_l2 190 | test_loss[ep] = test_l2 191 | 192 | t2 = default_timer() 193 | print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}' 194 | .format(ep, t2-t1, train_mse, train_l2, test_l2)) 195 | 196 | # %% 197 | """ Prediction """ 198 | pred = [] 199 | test_e = [] 200 | with torch.no_grad(): 201 | 202 | index = 0 203 | for x, y in test_loader: 204 | test_l2 = 0 205 | x, y = x.to(device), y.to(device) 206 | 207 | out = model(x) 208 | test_l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 209 | 210 | test_e.append( test_l2/batch_size ) 211 | pred.append( out ) 212 | print("Batch-{}, Test-loss-{:0.6f}".format( index, test_l2/batch_size )) 213 | index += 1 214 | 215 | pred = torch.cat((pred)) 216 | test_e = torch.tensor((test_e)) 217 | print('Mean Error:', 100*torch.mean(test_e).numpy()) 218 | 219 | # %% 220 | plt.rcParams['font.family'] = 'Times New Roman' 221 | plt.rcParams['font.size'] = 12 222 | plt.rcParams['mathtext.fontset'] = 'dejavuserif' 223 | 224 | colormap = plt.cm.jet 225 | colors = [colormap(i) for i in np.linspace(0, 1, 5)] 226 | 227 | """ Plotting """ 228 | figure7 = plt.figure(figsize = (10, 4), dpi=300) 229 | index = 0 230 | for i in range(y_test.shape[0]): 231 | if i % 20 == 1: 232 | plt.plot(y_test[i, :].cpu().numpy(), color=colors[index], label='Actual') 233 | plt.plot(pred[i,:].cpu().numpy(), '--', color=colors[index], label='Prediction') 234 | index += 1 235 | plt.legend(ncol=5) 236 | plt.grid(True) 237 | plt.margins(0) 238 | 239 | # %% 240 | """ 241 | For saving the trained model and prediction data 242 | """ 243 | torch.save(model, 'model/WNO_burgers') 244 | scipy.io.savemat('results/wno_results_burgers.mat', mdict={'x_test':x_test.cpu().numpy(), 245 | 'y_test':y_test.cpu().numpy(), 246 | 'pred':pred.cpu().numpy(), 247 | 'test_e':test_e.cpu().numpy()}) 248 | 
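# %%
"""
Illustrative extra (not part of the original script): a stripped-down 1D analogue
of the wavelet kernel integral layer, added only to make the imported
`wavelet_convolution.WaveConv1d` easier to follow. It mirrors the pattern of the
Version 1.0.0 WaveConv2d layer (forward DWT -> learned weighting of the
approximation and last-level detail coefficients -> inverse DWT). The real
WaveConv1d may differ in which coefficients it parametrizes and how it initializes
the weights; the pytorch_wavelets classes DWT1DForward/DWT1DInverse used here are
an assumption based on the README's dependency list.
"""
from pytorch_wavelets import DWT1DForward, DWT1DInverse

class WaveConv1dSketch(nn.Module):
    def __init__(self, in_channels, out_channels, level, size, wavelet='db6'):
        super().__init__()
        self.dwt = DWT1DForward(J=level, wave=wavelet, mode='symmetric')
        self.idwt = DWT1DInverse(wave=wavelet, mode='symmetric')
        # infer coefficient lengths from a dummy signal at the working resolution
        with torch.no_grad():
            yl, yh = self.dwt(torch.zeros(1, in_channels, size))
        scale = 1.0 / (in_channels * out_channels)
        self.w_approx = nn.Parameter(scale * torch.rand(in_channels, out_channels, yl.shape[-1]))
        self.w_detail = nn.Parameter(scale * torch.rand(in_channels, out_channels, yh[-1].shape[-1]))

    def forward(self, x):                       # x: (batch, in_channels, size)
        yl, yh = self.dwt(x)                    # approximation + per-level detail coefficients
        yl = torch.einsum('bix,iox->box', yl, self.w_approx)          # weight the approximation
        yh[-1] = torch.einsum('bix,iox->box', yh[-1], self.w_detail)  # weight the last-level details
        # intermediate detail levels pass through unchanged, so this sketch assumes
        # in_channels == out_channels (as in the WNO blocks, where both equal `width`)
        out = self.idwt((yl, yh))               # back to physical space
        return out[..., :x.shape[-1]]           # crop any extra samples from boundary handling

# quick shape check (hypothetical usage):
# WaveConv1dSketch(width, width, level, h)(torch.randn(2, width, h)).shape -> (2, width, h)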
-------------------------------------------------------------------------------- /Version 2.0.0/wno1d_advection_III.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | This code belongs to the paper: 5 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet neural operator: a neural 6 | operator for parametric partial differential equations. arXiv preprint arXiv:2205.02191. 7 | 8 | This code is for 1-D wave advection equation (time-independent problem). 9 | """ 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import matplotlib.pyplot as plt 16 | 17 | from timeit import default_timer 18 | from utils import * 19 | from wavelet_convolution import WaveConv1d 20 | 21 | torch.manual_seed(0) 22 | np.random.seed(0) 23 | 24 | # %% 25 | """ The forward operation """ 26 | class WNO1d(nn.Module): 27 | def __init__(self, width, level, layers, size, wavelet, in_channel, grid_range, padding=0): 28 | super(WNO1d, self).__init__() 29 | 30 | """ 31 | The WNO network. It contains l-layers of the Wavelet integral layer. 32 | 1. Lift the input using v(x) = self.fc0 . 33 | 2. l-layers of the integral operators v(j+1)(x) = g(K.v + W.v)(x). 34 | --> W is defined by self.w; K is defined by self.conv. 35 | 3. Project the output of last layer using self.fc1 and self.fc2. 36 | 37 | Input : 2-channel tensor, Initial condition and location (a(x), x) 38 | : shape: (batchsize * x=s * c=2) 39 | Output: Solution of a later timestep (u(x)) 40 | : shape: (batchsize * x=s * c=1) 41 | 42 | Input parameters: 43 | ----------------- 44 | width : scalar, lifting dimension of input 45 | level : scalar, number of wavelet decomposition 46 | layers: scalar, number of wavelet kernel integral blocks 47 | size : scalar, signal length 48 | wavelet: string, wavelet filter 49 | in_channel: scalar, channels in input including grid 50 | grid_range: scalar (for 1D), right support of 1D domain 51 | padding : scalar, size of zero padding 52 | """ 53 | 54 | self.level = level 55 | self.width = width 56 | self.layers = layers 57 | self.size = size 58 | self.wavelet = wavelet 59 | self.in_channel = in_channel 60 | self.grid_range = grid_range 61 | self.padding = padding 62 | 63 | self.conv = nn.ModuleList() 64 | self.w = nn.ModuleList() 65 | 66 | self.fc0 = nn.Linear(self.in_channel, self.width) # input channel is 2: (a(x), x) 67 | for i in range( self.layers ): 68 | self.conv.append( WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet) ) 69 | self.w.append( nn.Conv1d(self.width, self.width, 1) ) 70 | self.fc1 = nn.Linear(self.width, 128) 71 | self.fc2 = nn.Linear(128, 1) 72 | 73 | def forward(self, x): 74 | grid = self.get_grid(x.shape, x.device) 75 | x = torch.cat((x, grid), dim=-1) 76 | x = self.fc0(x) # Shape: Batch * x * Channel 77 | x = x.permute(0, 2, 1) # Shape: Batch * Channel * x 78 | if self.padding != 0: 79 | x = F.pad(x, [0,self.padding]) 80 | 81 | for index, (convl, wl) in enumerate( zip(self.conv, self.w) ): 82 | x = convl(x) + wl(x) 83 | if index != self.layers - 1: # Final layer has no activation 84 | x = F.mish(x) # Shape: Batch * Channel * x 85 | 86 | if self.padding != 0: 87 | x = x[..., :-self.padding] 88 | x = x.permute(0, 2, 1) # Shape: Batch * x * Channel 89 | x = F.gelu( self.fc1(x) ) # Shape: Batch * x * Channel 90 | x = self.fc2(x) # Shape: Batch * x * Channel 91 | return x 92 | 93 | def get_grid(self, shape, device): 94 | # The grid of the 
solution 95 | batchsize, size_x = shape[0], shape[1] 96 | gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float) 97 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 98 | return gridx.to(device) 99 | 100 | # %% 101 | """ Model configurations """ 102 | 103 | PATH = 'data/train_IC2.npz' 104 | ntrain = 900 105 | ntest = 100 106 | 107 | batch_size = 20 108 | learning_rate = 0.001 109 | 110 | epochs = 500 111 | step_size = 50 # weight-decay step size 112 | gamma = 0.5 # weight-decay rate 113 | 114 | wavelet = 'db6' # wavelet basis function 115 | level = 3 # lavel of wavelet decomposition 116 | width = 96 # uplifting dimension 117 | layers = 4 # no of wavelet layers 118 | 119 | h = 40 # total grid size divided by the subsampling rate 120 | grid_range = 1 121 | in_channel = 2 # (a(x), x) for this case 122 | 123 | # %% 124 | """ Read data """ 125 | 126 | # Data is of the shape (number of samples, grid size) 127 | data = np.load(PATH) 128 | x, t, u_train = data["x"], data["t"], data["u"] # N x nt x nx 129 | 130 | x_data = u_train[:, 0, :] # N x nx, initial solution 131 | y_data = u_train[:, -2, :] # N x nx, final solution 132 | 133 | x_data = torch.tensor(x_data) 134 | y_data = torch.tensor(y_data) 135 | 136 | x_train = x_data[:ntrain,:] 137 | y_train = y_data[:ntrain,:] 138 | x_test = x_data[-ntest:,:] 139 | y_test = y_data[-ntest:,:] 140 | 141 | x_train = x_train[:, :, None] 142 | x_test = x_test[:, :, None] 143 | 144 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), 145 | batch_size=batch_size, shuffle=True) 146 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), 147 | batch_size=batch_size, shuffle=False) 148 | 149 | # %% 150 | """ The model definition """ 151 | model = WNO1d(width=width, level=level, layers=layers, size=h, wavelet=wavelet, 152 | in_channel=in_channel, grid_range=grid_range).to(device) 153 | print(count_params(model)) 154 | 155 | 156 | """ Training and testing """ 157 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6) 158 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 159 | 160 | train_loss = torch.zeros(epochs) 161 | test_loss = torch.zeros(epochs) 162 | myloss = LpLoss(size_average=False) 163 | for ep in range(epochs): 164 | model.train() 165 | t1 = default_timer() 166 | train_mse = 0 167 | train_l2 = 0 168 | for x, y in train_loader: 169 | x, y = x.to(device), y.to(device) 170 | 171 | optimizer.zero_grad() 172 | out = model(x) 173 | 174 | mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1)) 175 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 176 | l2.backward() # l2 relative loss 177 | 178 | optimizer.step() 179 | train_mse += mse.item() 180 | train_l2 += l2.item() 181 | 182 | scheduler.step() 183 | model.eval() 184 | test_l2 = 0.0 185 | with torch.no_grad(): 186 | for x, y in test_loader: 187 | x, y = x.to(device), y.to(device) 188 | 189 | out = model(x) 190 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 191 | 192 | train_mse /= len(train_loader) 193 | train_l2 /= ntrain 194 | test_l2 /= ntest 195 | 196 | train_loss[ep] = train_l2 197 | test_loss[ep] = test_l2 198 | 199 | t2 = default_timer() 200 | print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}' 201 | .format(ep, t2-t1, train_mse, train_l2, test_l2)) 202 | 203 | # %% 204 | """ Prediction """ 205 | pred = [] 206 | test_e = [] 
207 | with torch.no_grad(): 208 | 209 | index = 0 210 | for x, y in test_loader: 211 | test_l2 = 0 212 | x, y = x.to(device), y.to(device) 213 | 214 | out = model(x) 215 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 216 | 217 | test_e.append( test_l2/batch_size ) 218 | pred.append( out ) 219 | print("Batch-{}, Test-loss-{:0.6f}".format( index, test_l2/batch_size )) 220 | index += 1 221 | 222 | pred = torch.cat((pred)) 223 | test_e = torch.tensor((test_e)) 224 | print('Mean Error:', 100*torch.mean(test_e).numpy()) 225 | 226 | # %% 227 | """ Plotting """ 228 | plt.rcParams['font.family'] = 'Times New Roman' 229 | plt.rcParams['font.size'] = 12 230 | plt.rcParams['mathtext.fontset'] = 'dejavuserif' 231 | 232 | colormap = plt.cm.jet 233 | colors = [colormap(i) for i in np.linspace(0, 1, 5)] 234 | 235 | """ Plotting """ 236 | figure7 = plt.figure(figsize = (10, 4), dpi=300) 237 | index = 0 238 | for i in range(y_test.shape[0]): 239 | if i % 20 == 1: 240 | plt.plot(y_test[i, :].cpu().numpy(), color=colors[index], label='Actual') 241 | plt.plot(pred[i,:].cpu().numpy(), '--', color=colors[index], label='Prediction') 242 | index += 1 243 | plt.legend(ncol=5) 244 | plt.grid(True) 245 | plt.margins(0) 246 | 247 | # %% 248 | """ 249 | For saving the trained model and prediction data 250 | """ 251 | torch.save(model, 'model/WNO_advection_time_independent') 252 | scipy.io.savemat('results/wno_results_advection_time_independent.mat', mdict={'x_test':x_test.cpu().numpy(), 253 | 'y_test':y_test.cpu().numpy(), 254 | 'pred':pred.cpu().numpy(), 255 | 'test_e':test_e.cpu().numpy()}) 256 | -------------------------------------------------------------------------------- /Version 2.0.0/wno2d_Darcy_dwt.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for 2-D Darcy equation (time-independent problem). 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn.parameter import Parameter 14 | import matplotlib.pyplot as plt 15 | 16 | from timeit import default_timer 17 | from utils import * 18 | from wavelet_convolution import WaveConv2d 19 | 20 | torch.manual_seed(0) 21 | np.random.seed(0) 22 | 23 | # %% 24 | """ The forward operation """ 25 | class WNO2d(nn.Module): 26 | def __init__(self, width, level, layers, size, wavelet, in_channel, grid_range, padding=0): 27 | super(WNO2d, self).__init__() 28 | 29 | """ 30 | The WNO network. It contains l-layers of the Wavelet integral layer. 31 | 1. Lift the input using v(x) = self.fc0 . 32 | 2. l-layers of the integral operators v(j+1)(x,y) = g(K.v + W.v)(x,y). 33 | --> W is defined by self.w; K is defined by self.conv. 34 | 3. Project the output of last layer using self.fc1 and self.fc2. 
35 | 36 | Input : 3-channel tensor, Initial input and location (a(x,y), x,y) 37 | : shape: (batchsize * x=width * x=height * c=3) 38 | Output: Solution of a later timestep (u(x,y)) 39 | : shape: (batchsize * x=width * x=height * c=1) 40 | 41 | Input parameters: 42 | ----------------- 43 | width : scalar, lifting dimension of input 44 | level : scalar, number of wavelet decomposition 45 | layers: scalar, number of wavelet kernel integral blocks 46 | size : list with 2 elements (for 2D), image size 47 | wavelet: string, wavelet filter 48 | in_channel: scalar, channels in input including grid 49 | grid_range: list with 2 elements (for 2D), right supports of 2D domain 50 | padding : scalar, size of zero padding 51 | """ 52 | 53 | self.level = level 54 | self.width = width 55 | self.layers = layers 56 | self.size = size 57 | self.wavelet = wavelet 58 | self.in_channel = in_channel 59 | self.grid_range = grid_range 60 | self.padding = padding 61 | 62 | self.conv = nn.ModuleList() 63 | self.w = nn.ModuleList() 64 | 65 | self.fc0 = nn.Linear(self.in_channel, self.width) # input channel is 3: (a(x, y), x, y) 66 | for i in range( self.layers ): 67 | self.conv.append( WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet) ) 68 | self.w.append( nn.Conv2d(self.width, self.width, 1) ) 69 | self.fc1 = nn.Linear(self.width, 128) 70 | self.fc2 = nn.Linear(128, 1) 71 | 72 | def forward(self, x): 73 | grid = self.get_grid(x.shape, x.device) 74 | x = torch.cat((x, grid), dim=-1) 75 | x = self.fc0(x) # Shape: Batch * x * y * Channel 76 | x = x.permute(0, 3, 1, 2) # Shape: Batch * Channel * x * y 77 | if self.padding != 0: 78 | x = F.pad(x, [0,self.padding, 0,self.padding]) 79 | 80 | for index, (convl, wl) in enumerate( zip(self.conv, self.w) ): 81 | x = convl(x) + wl(x) 82 | if index != self.layers - 1: # Final layer has no activation 83 | x = F.mish(x) # Shape: Batch * Channel * x * y 84 | 85 | if self.padding != 0: 86 | x = x[..., :-self.padding, :-self.padding] 87 | x = x.permute(0, 2, 3, 1) # Shape: Batch * x * y * Channel 88 | x = F.gelu( self.fc1(x) ) # Shape: Batch * x * y * Channel 89 | x = self.fc2(x) # Shape: Batch * x * y * Channel 90 | return x 91 | 92 | def get_grid(self, shape, device): 93 | # The grid of the solution 94 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 95 | gridx = torch.tensor(np.linspace(0, self.grid_range[0], size_x), dtype=torch.float) 96 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 97 | gridy = torch.tensor(np.linspace(0, self.grid_range[1], size_y), dtype=torch.float) 98 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 99 | return torch.cat((gridx, gridy), dim=-1).to(device) 100 | 101 | # %% 102 | """ Model configurations """ 103 | 104 | PATH_Train = 'data/piececonst_r421_N1024_smooth1.mat' 105 | PATH_Test = 'data/piececonst_r421_N1024_smooth2.mat' 106 | ntrain = 1000 107 | ntest = 100 108 | 109 | batch_size = 20 110 | learning_rate = 0.001 111 | 112 | epochs = 500 113 | step_size = 50 # weight-decay step size 114 | gamma = 0.5 # weight-decay rate 115 | 116 | wavelet = 'db6' # wavelet basis function 117 | level = 4 # lavel of wavelet decomposition 118 | width = 64 # uplifting dimension 119 | layers = 4 # no of wavelet layers 120 | 121 | sub = 5 122 | h = int(((421 - 1)/sub) + 1) # total grid size divided by the subsampling rate 123 | grid_range = [1, 1] # The grid boundary in x and y direction 124 | in_channel = 3 # (a(x, y), x, y) for this case 125 | 126 | # %% 127 | """ Read data """ 128 | 
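A note on the data layout used below: the two `.mat` files store the permeability field under the key `'coeff'` and the solution field under `'sol'` on a 421 x 421 grid (as assumed by the computation of `h` above), and `sub = 5` reduces this to h = (421 - 1)/5 + 1 = 85 points per direction. A small standalone check of that subsampling arithmetic, using a dummy array in place of the `.mat` file:

```python
import numpy as np

sub = 5
h = int(((421 - 1) / sub) + 1)                            # 85 grid points per direction
coeff = np.random.rand(8, 421, 421).astype(np.float32)    # stand-in for read_field('coeff')
coeff_sub = coeff[:, ::sub, ::sub][:, :h, :h]             # same slicing pattern as below
print(h, coeff_sub.shape)                                 # 85 (8, 85, 85)
```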
reader = MatReader(PATH_Train) 129 | x_train = reader.read_field('coeff')[:ntrain,::sub,::sub][:,:h,:h] 130 | y_train = reader.read_field('sol')[:ntrain,::sub,::sub][:,:h,:h] 131 | 132 | reader.load_file(PATH_Test) 133 | x_test = reader.read_field('coeff')[:ntest,::sub,::sub][:,:h,:h] 134 | y_test = reader.read_field('sol')[:ntest,::sub,::sub][:,:h,:h] 135 | 136 | x_normalizer = UnitGaussianNormalizer(x_train) 137 | x_train = x_normalizer.encode(x_train) 138 | x_test = x_normalizer.encode(x_test) 139 | 140 | y_normalizer = UnitGaussianNormalizer(y_train) 141 | y_train = y_normalizer.encode(y_train) 142 | 143 | x_train = x_train.reshape(ntrain,h,h,1) 144 | x_test = x_test.reshape(ntest,h,h,1) 145 | 146 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), 147 | batch_size=batch_size, shuffle=True) 148 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), 149 | batch_size=batch_size, shuffle=False) 150 | 151 | # %% 152 | """ The model definition """ 153 | model = WNO2d(width=width, level=level, layers=layers, size=[h,h], wavelet=wavelet, 154 | in_channel=in_channel, grid_range=grid_range, padding=1).to(device) 155 | print(count_params(model)) 156 | 157 | """ Training and testing """ 158 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6) 159 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 160 | 161 | train_loss = torch.zeros(epochs) 162 | test_loss = torch.zeros(epochs) 163 | myloss = LpLoss(size_average=False) 164 | y_normalizer.to(device) 165 | for ep in range(epochs): 166 | model.train() 167 | t1 = default_timer() 168 | train_mse = 0 169 | train_l2 = 0 170 | for x, y in train_loader: 171 | x, y = x.to(device), y.to(device) 172 | 173 | optimizer.zero_grad() 174 | out = model(x).reshape(batch_size, h, h) 175 | out = y_normalizer.decode(out) 176 | y = y_normalizer.decode(y) 177 | 178 | mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1)) 179 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 180 | loss.backward() 181 | optimizer.step() 182 | 183 | train_mse += mse.item() 184 | train_l2 += loss.item() 185 | 186 | scheduler.step() 187 | model.eval() 188 | test_l2 = 0.0 189 | with torch.no_grad(): 190 | for x, y in test_loader: 191 | x, y = x.to(device), y.to(device) 192 | 193 | out = model(x).reshape(batch_size, h, h) 194 | out = y_normalizer.decode(out) 195 | 196 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 197 | 198 | train_mse /= len(train_loader) 199 | train_l2/= ntrain 200 | test_l2 /= ntest 201 | 202 | train_loss[ep] = train_l2 203 | test_loss[ep] = test_l2 204 | 205 | t2 = default_timer() 206 | print("Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}" 207 | .format(ep, t2-t1, train_mse, train_l2, test_l2)) 208 | 209 | # %% 210 | """ Prediction """ 211 | pred = [] 212 | test_e = [] 213 | with torch.no_grad(): 214 | 215 | index = 0 216 | for x, y in test_loader: 217 | test_l2 = 0 218 | x, y = x.to(device), y.to(device) 219 | 220 | out = model(x).reshape(batch_size, h, h) 221 | out = y_normalizer.decode(out) 222 | pred.append( out.cpu() ) 223 | 224 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 225 | test_e.append( test_l2/batch_size ) 226 | 227 | print("Batch-{}, Loss-{}".format(index, test_l2/batch_size) ) 228 | index += 1 229 | 230 | pred = torch.cat((pred)) 231 | test_e = torch.tensor((test_e)) 232 | print('Mean Testing 
Error:', 100*torch.mean(test_e).numpy(), '%') 233 | 234 | # %% 235 | """ Plotting """ 236 | plt.rcParams["font.family"] = "serif" 237 | plt.rcParams['font.size'] = 14 238 | 239 | figure1 = plt.figure(figsize = (18, 14)) 240 | figure1.text(0.04,0.17,'\n Error', rotation=90, color='purple', fontsize=20) 241 | figure1.text(0.04,0.34,'\n Prediction', rotation=90, color='green', fontsize=20) 242 | figure1.text(0.04,0.57,'\n Truth', rotation=90, color='red', fontsize=20) 243 | figure1.text(0.04,0.75,'Permeability \n field', rotation=90, color='b', fontsize=20) 244 | plt.subplots_adjust(wspace=0.7) 245 | index = 0 246 | for value in range(y_test.shape[0]): 247 | if value % 26 == 1: 248 | plt.subplot(4,4, index+1) 249 | plt.imshow(x_test[value,:,:,0], cmap='rainbow', extent=[0,1,0,1], interpolation='Gaussian') 250 | plt.title('a(x,y)-{}'.format(index+1), color='b', fontsize=20, fontweight='bold') 251 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 252 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 253 | 254 | plt.subplot(4,4, index+1+4) 255 | plt.imshow(y_test[value,:,:], cmap='rainbow', extent=[0,1,0,1], interpolation='Gaussian') 256 | plt.colorbar(fraction=0.045) 257 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 258 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 259 | 260 | plt.subplot(4,4, index+1+8) 261 | plt.imshow(pred[value,:,:], cmap='rainbow', extent=[0,1,0,1], interpolation='Gaussian') 262 | plt.colorbar(fraction=0.045) 263 | plt.xlabel('x',fontweight='bold'); plt.ylabel('y',fontweight='bold') 264 | plt.xticks(fontweight='bold'); plt.yticks(fontweight='bold'); 265 | 266 | plt.subplot(4,4, index+1+12) 267 | plt.imshow(np.abs(pred[value,:,:]-y_test[value,:,:]), cmap='jet', extent=[0,1,0,1], interpolation='Gaussian') 268 | plt.xlabel('x', fontweight='bold'); plt.ylabel('y', fontweight='bold'); 269 | plt.colorbar(fraction=0.045,format='%.0e') 270 | 271 | plt.margins(0) 272 | index = index + 1 273 | 274 | # %% 275 | """ 276 | For saving the trained model and prediction data 277 | """ 278 | torch.save(model, 'model/WNO_darcy') 279 | scipy.io.savemat('results/wno_results_darcy.mat', mdict={'x_test':x_test.cpu().numpy(), 280 | 'y_test':y_test.cpu().numpy(), 281 | 'pred':pred.cpu().numpy(), 282 | 'test_e':test_e.cpu().numpy()}) 283 | -------------------------------------------------------------------------------- /Version 2.0.0/wno2d_Temperature_Monthly_Avg.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for forecast of monthly averaged 2m air temperature (time-independent problem). 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn.parameter import Parameter 14 | import matplotlib.pyplot as plt 15 | 16 | import xarray as xr 17 | from timeit import default_timer 18 | from utils import * 19 | from wavelet_convolution import WaveConv2dCwt 20 | 21 | torch.manual_seed(0) 22 | np.random.seed(0) 23 | 24 | # %% 25 | """ The forward operation """ 26 | class WNO2d(nn.Module): 27 | def __init__(self, width, level, layers, size, wavelet, in_channel, xgrid_range, ygrid_range, padding=0): 28 | super(WNO2d, self).__init__() 29 | 30 | """ 31 | The WNO network. 
It contains l-layers of the Wavelet integral layer. 32 | 1. Lift the input using v(x) = self.fc0 . 33 | 2. l-layers of the integral operators v(j+1)(x,y) = g(K.v + W.v)(x,y). 34 | --> W is defined by self.w; K is defined by self.conv. 35 | 3. Project the output of last layer using self.fc1 and self.fc2. 36 | 37 | Input : (T_in+1)-channel tensor, solution at t0-t_T and location (u(x,y,t0),...u(x,y,t_T), x,y) 38 | : shape: (batchsize * x=width * x=height * c=T_in+1) 39 | Output: Solution of a later timestep (u(x, T_in+1)) 40 | : shape: (batchsize * x=width * x=height * c=1) 41 | 42 | Input parameters: 43 | ----------------- 44 | width : scalar, lifting dimension of input 45 | level : scalar, number of wavelet decomposition 46 | layers: scalar, number of wavelet kernel integral blocks 47 | size : list with 2 elements (for 2D), image size 48 | wavelet: list of strings for 2D, wavelet filter 49 | in_channel: scalar, channels in input including grid 50 | grid_range: list with 2 elements (for 2D), right supports of 2D domain 51 | padding : scalar, size of zero padding 52 | """ 53 | 54 | self.level = level 55 | self.width = width 56 | self.layers = layers 57 | self.size = size 58 | self.wavelet1 = wavelet[0] 59 | self.wavelet2 = wavelet[1] 60 | self.in_channel = in_channel 61 | self.xgrid_range = xgrid_range 62 | self.ygrid_range = ygrid_range 63 | self.padding = padding 64 | 65 | self.conv = nn.ModuleList() 66 | self.w = nn.ModuleList() 67 | 68 | self.fc0 = nn.Linear(self.in_channel, self.width) # input channel is 3: (a(x, y), x, y) 69 | for i in range( self.layers ): 70 | self.conv.append( WaveConv2dCwt(self.width, self.width, self.level, self.size, 71 | self.wavelet1, self.wavelet2) ) 72 | self.w.append( nn.Conv2d(self.width, self.width, 1) ) 73 | self.fc1 = nn.Linear(self.width, 128) 74 | self.fc2 = nn.Linear(128, 1) 75 | 76 | def forward(self, x): 77 | grid = self.get_grid(x.shape, x.device) 78 | x = torch.cat((x, grid), dim=-1) 79 | x = self.fc0(x) # Shape: Batch * x * y * Channel 80 | x = x.permute(0, 3, 1, 2) # Shape: Batch * Channel * x * y 81 | if self.padding != 0: 82 | x = F.pad(x, [0,self.padding, 0,self.padding]) 83 | 84 | for index, (convl, wl) in enumerate( zip(self.conv, self.w) ): 85 | x = convl(x) + wl(x) 86 | if index != self.layers - 1: # Final layer has no activation 87 | x = F.mish(x) # Shape: Batch * Channel * x * y 88 | 89 | if self.padding != 0: 90 | x = x[..., :-self.padding, :-self.padding] 91 | x = x.permute(0, 2, 3, 1) # Shape: Batch * x * y * Channel 92 | x = F.gelu( self.fc1(x) ) # Shape: Batch * x * y * Channel 93 | x = self.fc2(x) # Shape: Batch * x * y * Channel 94 | return x 95 | 96 | def get_grid(self, shape, device): 97 | # The grid of the solution 98 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 99 | gridx = torch.tensor(np.linspace(self.xgrid_range[0], self.xgrid_range[1], size_x), dtype=torch.float) 100 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 101 | gridy = torch.tensor(np.linspace(self.ygrid_range[0], self.ygrid_range[1], size_y), dtype=torch.float) 102 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 103 | return torch.cat((gridx, gridy), dim=-1).to(device) 104 | 105 | # %% 106 | """ Model configurations """ 107 | 108 | PATH = 'data/ERA5_monthly_average.grib' 109 | ntrain = 500 110 | ntest = 20 111 | 112 | batch_size = 10 113 | learning_rate = 0.001 114 | 115 | epochs = 500 116 | step_size = 50 # weight-decay step size 117 | gamma = 0.5 # weight-decay rate 118 | 119 | wavelet = 
['near_sym_b', 'qshift_b'] # wavelet basis function 120 | level = 5 # lavel of wavelet decomposition 121 | width = 64 # uplifting dimension 122 | layers = 4 # no of wavelet layers 123 | 124 | sub = 6 125 | h = int(((721 - 1)/sub)+1) # total grid size divided by the subsampling rate 126 | s = int(((1441 - 1)/sub)+1) 127 | 128 | xgrid_range = [0, 360] # The grid boundary in x direction 129 | ygrid_range = [90, -90] # The grid boundary in y direction 130 | 131 | in_channel = 3 # (a(x, y), x, y) for this case 132 | 133 | # %% 134 | """ Read data """ 135 | 136 | ds = xr.open_dataset(PATH, engine='cfgrib') 137 | data = np.array(ds["t2m"]) 138 | data = torch.tensor(data) 139 | # data = data[:, :720, :] 140 | data = F.pad(data, [0,1]) # pad last dimension to make it periodic, if required 141 | 142 | x_train = data[:ntrain, ::sub, ::sub] 143 | y_train = data[:ntrain, ::sub, ::sub] 144 | 145 | x_test = data[-ntest:, ::sub, ::sub] 146 | y_test = data[-ntest:, ::sub, ::sub] 147 | 148 | x_normalizer = UnitGaussianNormalizer(x_train) 149 | x_train = x_normalizer.encode(x_train) 150 | x_test = x_normalizer.encode(x_test) 151 | 152 | y_normalizer = UnitGaussianNormalizer(y_train) 153 | y_train = y_normalizer.encode(y_train) 154 | 155 | x_train = x_train.reshape(ntrain,h,s,1) 156 | x_test = x_test.reshape(ntest,h,s,1) 157 | 158 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), 159 | batch_size=batch_size, shuffle=True) 160 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), 161 | batch_size=batch_size, shuffle=False) 162 | 163 | # %% 164 | """ The model definition """ 165 | model = WNO2d(width=width, level=level, layers=layers, size=[h,s], wavelet=wavelet, 166 | in_channel=in_channel, xgrid_range=xgrid_range, ygrid_range=ygrid_range, padding=1).to(device) 167 | print(count_params(model)) 168 | 169 | """ Training and testing """ 170 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6) 171 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 172 | 173 | train_loss = torch.zeros(epochs) 174 | test_loss = torch.zeros(epochs) 175 | myloss = LpLoss(size_average=False) 176 | y_normalizer.to(device) 177 | for ep in range(epochs): 178 | model.train() 179 | t1 = default_timer() 180 | train_mse = 0 181 | train_l2 = 0 182 | for x, y in train_loader: 183 | x, y = x.to(device), y.to(device) 184 | 185 | optimizer.zero_grad() 186 | out = model(x).reshape(batch_size, h, s) 187 | out = y_normalizer.decode(out) 188 | y = y_normalizer.decode(y) 189 | 190 | mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1)) 191 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 192 | loss.backward() 193 | optimizer.step() 194 | 195 | train_mse += mse.item() 196 | train_l2 += loss.item() 197 | 198 | scheduler.step() 199 | model.eval() 200 | test_l2 = 0.0 201 | with torch.no_grad(): 202 | for x, y in test_loader: 203 | x, y = x.to(device), y.to(device) 204 | 205 | out = model(x).reshape(batch_size, h, s) 206 | out = y_normalizer.decode(out) 207 | 208 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 209 | 210 | train_mse /= len(train_loader) 211 | train_l2/= ntrain 212 | test_l2 /= ntest 213 | 214 | train_loss[ep] = train_l2 215 | test_loss[ep] = test_l2 216 | 217 | t2 = default_timer() 218 | print("Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}" 219 | .format(ep, t2-t1, train_mse, train_l2, test_l2)) 220 | 
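The `wavelet = ['near_sym_b', 'qshift_b']` pair configured above names the level-1 (biorthogonal) and level-2+ (q-shift) filter banks of the dual-tree complex wavelet transform, and `WaveConv2dCwt` presumably builds on an implementation such as `pytorch_wavelets`. The snippet below is a minimal forward/inverse round trip with that library, given here as an assumption-labelled illustration rather than a copy of the repository's layer:

```python
import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse   # assumed dependency of WaveConv2dCwt

xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')

x = torch.randn(2, 8, 64, 64)   # Batch * Channel * x * y, as inside the WNO blocks
yl, yh = xfm(x)                 # yl: coarse approximation; yh: per-level complex detail coefficients
x_rec = ifm((yl, yh))           # near-perfect reconstruction of the input
print(yl.shape, [lvl.shape for lvl in yh], x_rec.shape)
```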
221 | # %% 222 | """ Prediction """ 223 | pred = [] 224 | test_e = [] 225 | with torch.no_grad(): 226 | 227 | index = 0 228 | for x, y in test_loader: 229 | test_l2 = 0 230 | x, y = x.to(device), y.to(device) 231 | 232 | out = model(x).reshape(batch_size, h, h) 233 | out = y_normalizer.decode(out) 234 | pred.append( out.cpu() ) 235 | 236 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 237 | test_e.append( test_l2/batch_size ) 238 | 239 | print("Batch-{}, Loss-{}".format(index, test_l2/batch_size) ) 240 | index += 1 241 | 242 | pred = torch.cat((pred)) 243 | test_e = torch.tensor((test_e)) 244 | print('Mean Testing Error:', 100*torch.mean(test_e).numpy(), '%') 245 | 246 | # %% 247 | """ Plotting """ 248 | plt.rcParams["font.family"] = "serif" 249 | plt.rcParams['font.size'] = 12 250 | 251 | figure1 = plt.figure(figsize = (12, 13)) 252 | plt.subplots_adjust(hspace=0.01, wspace=0.25) 253 | index = 1 254 | for value in range(y_test.shape[0]): 255 | if value % 5 == 1 and value != 1: 256 | ### 257 | img = y_test[value,:,:-1].cpu().numpy() 258 | plt.subplot(3,2, index) 259 | plt.imshow(img, cmap='nipy_spectral', extent=[0,360,-90,+90]) 260 | plt.xlabel('Longitude ($^{\circ}$)'); plt.ylabel('Lattitude ($^{\circ}$)') 261 | plt.grid(True) 262 | if index==1: 263 | plt.title('Truth: Feb 2019, 1st'); 264 | else: 265 | plt.title('Truth: Feb 2021, 1st') 266 | 267 | ### 268 | plt.subplot(3,2, index+2) 269 | plt.imshow(pred[value,:,:-1], cmap='nipy_spectral', extent=[0,360,-90,+90]) 270 | plt.xlabel('Longitude ($^{\circ}$)'); plt.ylabel('Lattitude ($^{\circ}$)') 271 | plt.grid(True) 272 | if index==1: 273 | plt.title('Identification - Feb 2019, 1st \n (Error-{:0.4f})'.format(100*test_e[value].numpy())) 274 | else: 275 | plt.title('Identification - Feb 2021, 1st \n (Error-{:0.4f})'.format(100*test_e[value].numpy())) 276 | 277 | index = index + 1 278 | 279 | # %% 280 | """ 281 | For saving the trained model and prediction data 282 | """ 283 | torch.save(model, 'model/WNO_ERA5_t2m') 284 | scipy.io.savemat('results/wno_results_ERA5_t2m.mat', mdict={'x_test':x_test.cpu().numpy(), 285 | 'y_test':y_test.cpu().numpy(), 286 | 'pred':pred.cpu().numpy(), 287 | 'test_e':test_e.cpu().numpy()}) 288 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ Load required packages """ 2 | 3 | import torch 4 | import numpy as np 5 | import scipy.io 6 | import h5py 7 | import torch.nn as nn 8 | 9 | import operator 10 | from functools import reduce 11 | from functools import partial 12 | 13 | 14 | """ Utility functions for --loading data, 15 | --data normalization, 16 | --data standerdization, and 17 | --loss evaluation 18 | 19 | The base codes are taken from the repo: https://github.com/zongyi-li/fourier_neural_operator 20 | """ 21 | 22 | 23 | # reading data 24 | class MatReader(object): 25 | def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): 26 | super(MatReader, self).__init__() 27 | 28 | self.to_torch = to_torch 29 | self.to_cuda = to_cuda 30 | self.to_float = to_float 31 | 32 | self.file_path = file_path 33 | 34 | self.data = None 35 | self.old_mat = None 36 | self._load_file() 37 | 38 | def _load_file(self): 39 | try: 40 | self.data = scipy.io.loadmat(self.file_path) 41 | self.old_mat = True 42 | except: 43 | self.data = h5py.File(self.file_path) 44 | self.old_mat = False 45 | 46 | def load_file(self, file_path): 47 | self.file_path = file_path 
48 | self._load_file() 49 | 50 | def read_field(self, field): 51 | x = self.data[field] 52 | 53 | if not self.old_mat: 54 | x = x[()] 55 | x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1)) 56 | 57 | if self.to_float: 58 | x = x.astype(np.float32) 59 | 60 | if self.to_torch: 61 | x = torch.from_numpy(x) 62 | return x 63 | 64 | def set_cuda(self, to_cuda): 65 | self.to_cuda = to_cuda 66 | 67 | def set_torch(self, to_torch): 68 | self.to_torch = to_torch 69 | 70 | def set_float(self, to_float): 71 | self.to_float = to_float 72 | 73 | # normalization, pointwise gaussian 74 | class UnitGaussianNormalizer(object): 75 | def __init__(self, x, eps=0.00001): 76 | super(UnitGaussianNormalizer, self).__init__() 77 | 78 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 79 | self.mean = torch.mean(x, 0) 80 | self.std = torch.std(x, 0) 81 | self.eps = eps 82 | 83 | def encode(self, x): 84 | x = (x - self.mean) / (self.std + self.eps) 85 | return x 86 | 87 | def decode(self, x, sample_idx=None): 88 | if sample_idx is None: 89 | std = self.std + self.eps # n 90 | mean = self.mean 91 | else: 92 | if len(self.mean.shape) == len(sample_idx[0].shape): 93 | std = self.std[sample_idx] + self.eps # batch*n 94 | mean = self.mean[sample_idx] 95 | if len(self.mean.shape) > len(sample_idx[0].shape): 96 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 97 | mean = self.mean[:,sample_idx] 98 | 99 | # x is in shape of batch*n or T*batch*n 100 | x = (x * std) + mean 101 | return x 102 | 103 | def cuda(self): 104 | self.mean = self.mean.cuda() 105 | self.std = self.std.cuda() 106 | 107 | def cpu(self): 108 | self.mean = self.mean.cpu() 109 | self.std = self.std.cpu() 110 | 111 | def to(self, device): 112 | self.mean = self.mean.to(device) 113 | self.std = self.std.to(device) 114 | 115 | # normalization, scaling by range 116 | class RangeNormalizer(object): 117 | def __init__(self, x, low=0.0, high=1.0): 118 | super(RangeNormalizer, self).__init__() 119 | mymin = torch.min(x, 0)[0].view(-1) 120 | mymax = torch.max(x, 0)[0].view(-1) 121 | 122 | self.a = (high - low)/(mymax - mymin) 123 | self.b = -self.a*mymax + high 124 | 125 | def encode(self, x): 126 | s = x.size() 127 | x = x.view(s[0], -1) 128 | x = self.a*x + self.b 129 | x = x.view(s) 130 | return x 131 | 132 | def decode(self, x): 133 | s = x.size() 134 | x = x.view(s[0], -1) 135 | x = (x - self.b)/self.a 136 | x = x.view(s) 137 | return x 138 | 139 | def to(self, device): 140 | self.mean = self.mean.to(device) 141 | self.std = self.std.to(device) 142 | 143 | # loss function with rel/abs Lp loss 144 | class LpLoss(object): 145 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 146 | super(LpLoss, self).__init__() 147 | 148 | #Dimension and Lp-norm type are postive 149 | assert d > 0 and p > 0 150 | 151 | self.d = d 152 | self.p = p 153 | self.reduction = reduction 154 | self.size_average = size_average 155 | 156 | def abs(self, x, y): 157 | num_examples = x.size()[0] 158 | 159 | #Assume uniform mesh 160 | h = 1.0 / (x.size()[1] - 1.0) 161 | 162 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 163 | 164 | if self.reduction: 165 | if self.size_average: 166 | return torch.mean(all_norms) 167 | else: 168 | return torch.sum(all_norms) 169 | 170 | return all_norms 171 | 172 | def rel(self, x, y): 173 | num_examples = x.size()[0] 174 | 175 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 176 | y_norms = 
torch.norm(y.reshape(num_examples,-1), self.p, 1) 177 | 178 | if self.reduction: 179 | if self.size_average: 180 | return torch.mean(diff_norms/y_norms) 181 | else: 182 | return torch.sum(diff_norms/y_norms) 183 | 184 | return diff_norms/y_norms 185 | 186 | def __call__(self, x, y): 187 | return self.rel(x, y) 188 | 189 | # print the number of parameters 190 | def count_params(model): 191 | c = 0 192 | for p in list(model.parameters()): 193 | c += reduce(operator.mul, 194 | list(p.size()+(2,) if p.is_complex() else p.size())) 195 | return c 196 | -------------------------------------------------------------------------------- /wno1d_Burgers_v3.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code belongs to the paper: 3 | -- Tripura, T., & Chakraborty, S. (2022). Wavelet Neural Operator for solving 4 | parametric partialdifferential equations in computational mechanics problems. 5 | 6 | -- This code is for 1-D Burger's equation (time-independent problem). 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import matplotlib.pyplot as plt 14 | 15 | from timeit import default_timer 16 | from utils import * 17 | from wavelet_convolution_v3 import WaveConv1d 18 | 19 | torch.manual_seed(0) 20 | np.random.seed(0) 21 | device = ('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | # %% 24 | """ The forward operation """ 25 | class WNO1d(nn.Module): 26 | def __init__(self, width, level, layers, size, wavelet, in_channel, grid_range, omega, padding=0): 27 | super(WNO1d, self).__init__() 28 | 29 | """ 30 | The WNO network. It contains l-layers of the Wavelet integral layer. 31 | 1. Lift the input using v(x) = self.fc0 . 32 | 2. l-layers of the integral operators v(j+1)(x) = g(K.v + W.v)(x). 33 | --> W is defined by self.w; K is defined by self.conv. 34 | 3. Project the output of last layer using self.fc1 and self.fc2. 
35 | 36 | Input : 2-channel tensor, Initial condition and location (a(x), x) 37 | : shape: (batchsize * x=s * c=2) 38 | Output: Solution of a later timestep (u(x)) 39 | : shape: (batchsize * x=s * c=1) 40 | 41 | Input parameters: 42 | ----------------- 43 | width : scalar, lifting dimension of input 44 | level : scalar, number of wavelet decomposition 45 | layers: scalar, number of wavelet kernel integral blocks 46 | size : scalar, signal length 47 | wavelet: string, wavelet filter 48 | in_channel: scalar, channels in input including grid 49 | grid_range: scalar (for 1D), right support of 1D domain 50 | padding : scalar, size of zero padding 51 | """ 52 | 53 | self.level = level 54 | self.width = width 55 | self.layers = layers 56 | self.size = size 57 | self.wavelet = wavelet 58 | self.omega = omega 59 | self.in_channel = in_channel 60 | self.grid_range = grid_range 61 | self.padding = padding 62 | 63 | self.conv = nn.ModuleList() 64 | self.w = nn.ModuleList() 65 | 66 | self.fc0 = nn.Linear(self.in_channel, self.width) # input channel is 2: (a(x), x) 67 | for i in range( self.layers ): 68 | self.conv.append( WaveConv1d(self.width, self.width, self.level, size=self.size, 69 | wavelet=self.wavelet, omega=self.omega) ) 70 | self.w.append( nn.Conv1d(self.width, self.width, 1) ) 71 | self.fc1 = nn.Linear(self.width, 128) 72 | self.fc2 = nn.Linear(128, 1) 73 | 74 | def forward(self, x): 75 | grid = self.get_grid(x.shape, x.device) 76 | x = torch.cat((x, grid), dim=-1) 77 | x = self.fc0(x) # Shape: Batch * x * Channel 78 | x = x.permute(0, 2, 1) # Shape: Batch * Channel * x 79 | if self.padding != 0: 80 | x = F.pad(x, [0,self.padding]) 81 | 82 | for index, (convl, wl) in enumerate( zip(self.conv, self.w) ): 83 | x = convl(x) + wl(x) 84 | if index != self.layers - 1: # Final layer has no activation 85 | x = F.mish(x) # Shape: Batch * Channel * x 86 | 87 | if self.padding != 0: 88 | x = x[..., :-self.padding] 89 | x = x.permute(0, 2, 1) # Shape: Batch * x * Channel 90 | x = F.mish( self.fc1(x) ) # Shape: Batch * x * Channel 91 | x = self.fc2(x) # Shape: Batch * x * Channel 92 | return x 93 | 94 | def get_grid(self, shape, device): 95 | # The grid of the solution 96 | batchsize, size_x = shape[0], shape[1] 97 | gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float) 98 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 99 | return gridx.to(device) 100 | 101 | 102 | # %% 103 | """ Model configurations """ 104 | 105 | PATH = '/home/user/Desktop/Papers_codes/P3_WNO/WNO-master/data/burgers_data_R10.mat' 106 | ntrain = 1000 107 | ntest = 100 108 | 109 | batch_size = 20 110 | learning_rate = 0.001 111 | 112 | epochs = 500 113 | step_size = 50 # weight-decay step size 114 | gamma = 0.5 # weight-decay rate 115 | 116 | wavelet = 'db2' # wavelet basis function 117 | level = 5 # lavel of wavelet decomposition 118 | width = 40 # uplifting dimension 119 | layers = 4 # no of wavelet layers 120 | 121 | sub = 2**3 # subsampling rate 122 | h = 2**13 // sub # total grid size divided by the subsampling rate 123 | grid_range = 1 124 | in_channel = 2 # (a(x), x) for this case 125 | 126 | # %% 127 | """ Read data """ 128 | 129 | dataloader = MatReader(PATH) 130 | x_data = dataloader.read_field('a')[:,::sub] 131 | y_data = dataloader.read_field('u')[:,::sub] 132 | 133 | x_train = x_data[:ntrain,:] 134 | y_train = y_data[:ntrain,:] 135 | x_test = x_data[-ntest:,:] 136 | y_test = y_data[-ntest:,:] 137 | 138 | x_train = x_train[:, :, None] 139 | x_test = x_test[:, :, None] 140 | 141 | 
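`wavelet_convolution_v3.py` itself is not part of this excerpt, but the model constructed just below is passed `omega=4`: as described in the repository README, version 3 replaces the element-wise product on the wavelet coefficients with a secondary convolution carried out in Fourier space, where `omega` sets the number of retained Fourier modes. The following is only an illustrative sketch of such a truncated spectral convolution on a 1D coefficient tensor, not the repository's `WaveConv1d`:

```python
import torch
import torch.nn as nn

class SpectralConv1d(nn.Module):
    """Illustrative Fourier-space convolution keeping the first `omega` modes."""
    def __init__(self, in_channels, out_channels, omega):
        super().__init__()
        self.omega = omega
        scale = 1.0 / (in_channels * out_channels)
        self.weights = nn.Parameter(
            scale * torch.rand(in_channels, out_channels, omega, dtype=torch.cfloat))

    def forward(self, x):                      # x: (batch, channel, length)
        x_ft = torch.fft.rfft(x)               # (batch, channel, length//2 + 1)
        out_ft = torch.zeros(x.shape[0], self.weights.shape[1], x_ft.shape[-1],
                             dtype=torch.cfloat, device=x.device)
        out_ft[..., :self.omega] = torch.einsum(
            'bix,iox->box', x_ft[..., :self.omega], self.weights)
        return torch.fft.irfft(out_ft, n=x.shape[-1])

# e.g. applied to a batch of wavelet approximation coefficients of length 128
coeffs = torch.randn(4, 40, 128)                          # (batch, width, coefficient length)
print(SpectralConv1d(40, 40, omega=4)(coeffs).shape)      # torch.Size([4, 40, 128])
```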
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), 142 | batch_size=batch_size, shuffle=True) 143 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), 144 | batch_size=batch_size, shuffle=False) 145 | 146 | # %% 147 | """ The model definition """ 148 | model = WNO1d(width=width, level=level, layers=layers, size=h, wavelet=wavelet, 149 | in_channel=in_channel, grid_range=grid_range, omega=4).to(device) 150 | print(count_params(model)) 151 | 152 | """ Training and testing """ 153 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6) 154 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 155 | 156 | train_loss = torch.zeros(epochs) 157 | test_loss = torch.zeros(epochs) 158 | myloss = LpLoss(size_average=False) 159 | for ep in range(epochs): 160 | model.train() 161 | t1 = default_timer() 162 | train_mse = 0 163 | train_l2 = 0 164 | for x, y in train_loader: 165 | x, y = x.to(device), y.to(device) 166 | 167 | optimizer.zero_grad() 168 | out = model(x) 169 | 170 | mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1)) 171 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 172 | l2.backward() # l2 relative loss 173 | 174 | optimizer.step() 175 | train_mse += mse.item() 176 | train_l2 += l2.item() 177 | 178 | scheduler.step() 179 | model.eval() 180 | test_l2 = 0.0 181 | with torch.no_grad(): 182 | for x, y in test_loader: 183 | x, y = x.to(device), y.to(device) 184 | 185 | out = model(x) 186 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 187 | 188 | train_mse /= len(train_loader) 189 | train_l2 /= ntrain 190 | test_l2 /= ntest 191 | 192 | train_loss[ep] = train_l2 193 | test_loss[ep] = test_l2 194 | 195 | t2 = default_timer() 196 | print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}' 197 | .format(ep, t2-t1, train_mse, train_l2, test_l2)) 198 | 199 | # %% 200 | """ Prediction """ 201 | pred = [] 202 | test_e = [] 203 | with torch.no_grad(): 204 | 205 | index = 0 206 | for x, y in test_loader: 207 | test_l2 = 0 208 | x, y = x.to(device), y.to(device) 209 | 210 | out = model(x) 211 | test_l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 212 | 213 | test_e.append( test_l2/batch_size ) 214 | pred.append( out ) 215 | print("Batch-{}, Test-loss-{:0.6f}".format( index, test_l2/batch_size )) 216 | index += 1 217 | 218 | pred = torch.cat((pred)) 219 | test_e = torch.tensor((test_e)) 220 | print('Mean Error:', 100*torch.mean(test_e).numpy()) 221 | 222 | # %% 223 | plt.rcParams['font.family'] = 'Times New Roman' 224 | plt.rcParams['font.size'] = 14 225 | plt.rcParams['mathtext.fontset'] = 'dejavuserif' 226 | 227 | colormap = plt.cm.jet 228 | colors = [colormap(i) for i in np.linspace(0, 1, 5)] 229 | 230 | """ Plotting """ 231 | figure7 = plt.figure(figsize = (10, 5), dpi=300) 232 | index = 0 233 | for i in range(y_test.shape[0]): 234 | if i % 20 == 1: 235 | plt.plot(y_test[i, :].cpu().numpy(), color=colors[index], label='Actual') 236 | plt.plot(pred[i,:].cpu().numpy(), '--', color=colors[index], label='Prediction') 237 | index += 1 238 | plt.legend(ncol=5, loc=3, borderaxespad=0.1, columnspacing=0.75, handletextpad=0.25) 239 | plt.grid(True, alpha=0.35) 240 | plt.ylim([-1,1]) 241 | plt.margins(0) 242 | plt.xlabel('Space ($x$)') 243 | plt.ylabel('$u$($x$)') 244 | plt.title('Mean Error: {:0.4f}%'.format(100*torch.mean(test_e).numpy()), 
fontweight='bold') 245 | plt.show() 246 | 247 | # %% 248 | """ 249 | For saving the trained model and prediction data 250 | """ 251 | torch.save(model, 'model/WNO_burgers_v2p1') 252 | scipy.io.savemat('results/wno_results_burgers_v2p1.mat', mdict={'x_test':x_test.cpu().numpy(), 253 | 'y_test':y_test.cpu().numpy(), 254 | 'pred':pred.cpu().numpy(), 255 | 'test_e':test_e.cpu().numpy()}) 256 | --------------------------------------------------------------------------------
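Finally, a short sketch of how the artifacts saved by these scripts can be reloaded for later inspection. The paths follow the `torch.save`/`scipy.io.savemat` calls above; since the whole module object is pickled, the `WNO1d` and wavelet-convolution class definitions must be available at load time.

```python
import torch
import scipy.io

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Recent PyTorch versions may additionally require weights_only=False here,
# because a full nn.Module object (not just a state_dict) was pickled.
model = torch.load('model/WNO_burgers_v2p1', map_location=device)
model.eval()

results = scipy.io.loadmat('results/wno_results_burgers_v2p1.mat')
x_test, y_test, pred = results['x_test'], results['y_test'], results['pred']
print(x_test.shape, y_test.shape, pred.shape)

with torch.no_grad():
    sample = torch.tensor(x_test[:1], dtype=torch.float).to(device)
    print(model(sample).shape)   # re-evaluate one stored test input
```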