├── requirements.txt ├── resources └── mwtMain.png ├── Data ├── kDV │ ├── KdV1.m │ ├── GRF_kdv.m │ ├── gen_KdV_fluctuation.m │ └── gen_kdv_smooth.m └── EulerBern │ ├── EB_beam_gen.m │ └── EB_beam_gen_3order.m ├── README.md ├── kDV.ipynb ├── Darcy.ipynb ├── tests ├── test_NS_MWT_N_1000.py └── test_NS_MWT_N_10000.py ├── models ├── utils.py └── models.py └── NS.ipynb /requirements.txt: -------------------------------------------------------------------------------- 1 | sympy 2 | jupterlab 3 | scipy 4 | h5py 5 | matplotlib -------------------------------------------------------------------------------- /resources/mwtMain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaurav71531/mwt-operator/HEAD/resources/mwtMain.png -------------------------------------------------------------------------------- /Data/kDV/KdV1.m: -------------------------------------------------------------------------------- 1 | function u = KdV1(init, tspan, s) 2 | 3 | S = spinop([0 1], tspan); 4 | dt = tspan(2) - tspan(1); 5 | S.lin = @(u) - diff(u,3); 6 | S.nonlin = @(u) -.5*diff(u.^2); 7 | S.init = init; 8 | u = spin(S,s,dt,'plot','off'); 9 | 10 | -------------------------------------------------------------------------------- /Data/kDV/GRF_kdv.m: -------------------------------------------------------------------------------- 1 | function u = GRF_kdv(N, m, gamma, tau, sigma) 2 | 3 | my_const = 2*pi; 4 | 5 | my_eigs = sqrt(2)*(abs(sigma).*((my_const.*(1:N)').^2 + tau^2).^(-gamma/2)); 6 | 7 | xi_alpha = randn(N,1); 8 | alpha = my_eigs.*xi_alpha; 9 | 10 | xi_beta = randn(N,1); 11 | beta = my_eigs.*xi_beta; 12 | 13 | a = alpha/2; 14 | b = -beta/2; 15 | c = [flipud(a) - flipud(b).*1i;m + 0*1i;a + b.*1i]; 16 | 17 | uu = chebfun(c, [0 1],'trig','coeffs'); 18 | u = chebfun(@(t) uu(t - 0.5), [0 1],'trig'); 19 | 20 | end -------------------------------------------------------------------------------- /Data/EulerBern/EB_beam_gen.m: -------------------------------------------------------------------------------- 1 | num = 1; 2 | s = 1024; 3 | w = 215; 4 | lamda = 0.2; 5 | F = 5e3; 6 | N = chebop([0,1]); 7 | N.op = @(x,u) diff(u,4)-w^2.*u; 8 | 9 | N.lbc = @(u) [u;diff(u)]; 10 | N.rbc = @(u) [u;diff(u)]; 11 | x = linspace(0,1,s+1); 12 | input = zeros(num, s); 13 | output = zeros(num,s); 14 | %% Force function 15 | for i = 1:num 16 | f = randnfun(lamda,[0,1],'trig'); %randn 17 | f = (f + 2)*F; 18 | 19 | u = N\f; 20 | u_output = u(x); 21 | output(i,:) = u_output(1:end-1); 22 | a = f/F; % input 23 | a_input = a(x); 24 | input(i,:) = a_input(1:end-1); 25 | figure; 26 | subplot(1, 2, 1);plot(input); 27 | subplot(1, 2, 2);plot(output); 28 | end 29 | % save('EB_beam_1024_02_215.mat','input','output'); 30 | -------------------------------------------------------------------------------- /Data/EulerBern/EB_beam_gen_3order.m: -------------------------------------------------------------------------------- 1 | num = 1; 2 | s = 1024; 3 | w = 215; 4 | lamda = 0.2; 5 | F = 2e4; 6 | N = chebop([0,1]); 7 | N.op = @(x,u) diff(u,3)-w^2.*u; 8 | 9 | N.lbc = @(u) [u;diff(u)]; 10 | N.rbc = 0; 11 | 12 | x = linspace(0,1,s+1); 13 | input = zeros(num,s); 14 | output = zeros(num,s); 15 | %% Force function 16 | for i = 1:num 17 | 18 | f = randnfun(lamda,[0,1],'trig'); %randn 19 | f = f*F; 20 | 21 | u = N\f; 22 | u_output = u(x); 23 | output(i,:) = u_output(1:end-1); 24 | a = f/F; % input 25 | a_input = a(x); 26 | input(i,:) = a_input(1:end-1); 27 | 28 | figure; 29 | subplot(1, 2, 1);plot(input); 30 | subplot(1, 2, 2);plot(output); 31 | end 32 | % save('EB_beam_1024_03_215_02.mat','input','output'); 33 | -------------------------------------------------------------------------------- /Data/kDV/gen_KdV_fluctuation.m: -------------------------------------------------------------------------------- 1 | % number of realizations to generate 2 | N = 1; 3 | s = 2048;% grid size 4 | 5 | steps = 1; 6 | 7 | input = zeros(N, s); 8 | 9 | if steps == 1 10 | output = zeros(N, s); 11 | else 12 | output = zeros(N, steps, s); 13 | end 14 | 15 | tspan = linspace(0,1,steps+1); 16 | x = linspace(0,1,s+1); 17 | lamda = 0.05; 18 | 19 | for j=1:N 20 | u0 = 0.5*randnfun(lamda,[0,1],'trig'); 21 | u = KdV1(u0, tspan, s); 22 | 23 | u0eval = u0(x); 24 | input(j,:) = u0eval(1:end-1); 25 | 26 | if steps == 1 27 | output(j,:) = u.values; 28 | else 29 | for k=2:(steps+1) 30 | output(j,k,:) = u{k}.values; 31 | end 32 | end 33 | end 34 | 35 | figure; 36 | plot(input) 37 | title('input') 38 | figure; 39 | plot(output) 40 | title('output') 41 | 42 | % save('kdv_fluc_005_2048.mat','input','output'); -------------------------------------------------------------------------------- /Data/kDV/gen_kdv_smooth.m: -------------------------------------------------------------------------------- 1 | % number of realizations to generate 2 | N = 1; 3 | s = 8192; 4 | % parameters for the Gaussian random field 5 | gamma = 2.5; 6 | tau = 7; 7 | sigma = 7^(2); 8 | 9 | steps = 1; 10 | 11 | input = zeros(N, s); 12 | 13 | if steps == 1 14 | output = zeros(N, s); 15 | else 16 | output = zeros(N, steps, s); 17 | end 18 | 19 | tspan = linspace(0,1,steps+1); 20 | x = linspace(0,1,s+1); 21 | for j=1:N 22 | u0 = GRF_kdv(s/2, 0, gamma, tau, sigma); 23 | u = KdV1(u0, tspan, s); 24 | 25 | u0eval = u0(x); 26 | input(j,:) = u0eval(1:end-1); 27 | 28 | if steps == 1 29 | output(j,:) = u.values; 30 | else 31 | for k=2:(steps+1) 32 | output(j,k,:) = u{k}.values; 33 | end 34 | end 35 | 36 | % disp(j); 37 | end 38 | 39 | figure; 40 | plot(input) 41 | title('input') 42 | figure; 43 | plot(output) 44 | title('output') 45 | % save('kdv_train_test.mat','input','output'); 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multiwavelet model for Operator maps 2 | 3 | ![Image](resources/mwtMain.png) 4 | Gaurav Gupta, Xiongye Xiao, and Paul Bogdan\ 5 | **Multiwavelet-based Operator Learning for Differential Equations**\ 6 | In NeurIPS 2021. [arXiv:2109.13459](https://arxiv.org/abs/2109.13459) 7 | 8 | 9 | ## Setup 10 | 11 | ### Requirements 12 | The code package is developed using Python 3.8 and Pytorch 1.8 with cuda 11.0. For running the experiments first install the required packages using 'requirements.txt' 13 | 14 | ## Experiments 15 | ### Data 16 | Generate the data using the scripts provided in the 'Data' directory. The scripts use Matlab 2018+. A sample generated dataset for KdV is uploaded at [KdV data](https://drive.google.com/drive/folders/1--KYHPjl-pkrrGRtH8eg0aG7q8hUjiKg). 17 | 18 | For the experiments on Burgers, Darcy, and Navier Stokes, the code package uses the datasets as provided in the following repository by the Authors Zongyi Li et al. 19 | 20 | [PDE datasets](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-) 21 | 22 | ### Scripts 23 | Choose the required model from the `models` (1-d, 2-d, 2-d time-varying) and pass-in the required polynomial: 'legendre' or 'chebyshev'. Next, choose the desired value of multiwavelets 'k'. 24 | 25 | ### kDV 26 | As an example, a complete pipeline is shown for the kDV equation in the attached `kDV.ipynb` notebook. 27 | 28 | ### Navier Stokes 29 | The pre-trained models for Navier Stokes equation is provided using the following link: 30 | 31 | [NS Pre trained](https://drive.google.com/drive/folders/1VDnz_8OdvfQYOneYQ2TFKryJ9Q6oXmnr) 32 | 33 | A visual of time-evolution of the estimated outputs of the pre-trained models is available [Here](https://drive.google.com/drive/folders/1yLCy5C_z37nWP9H8LeFqY_4yLHuNnCmB?usp=sharing). 34 | 35 | To test the model, first download the models to the 'ptmodels' directory. Next, 36 | For N=1000, T = 50, \nu = 1e-3 37 | ``` 38 | python test_NS_MWT_N_1000.py 39 | ``` 40 | For N = 10000, T = 30, \nu = 1e-4 41 | ``` 42 | python test_NS_MWT_N_10000.py 43 | ``` 44 | 45 | **Note:** The NS experiments were done using Pytorch 1.7 cuda 11.0 46 | 47 | ## Citation 48 | If you use this code, or our work, please cite: 49 | ``` 50 | @misc{gupta2021multiwavelet, 51 | title={Multiwavelet-based Operator Learning for Differential Equations}, 52 | author={Gaurav Gupta and Xiongye Xiao and Paul Bogdan}, 53 | year={2021}, 54 | eprint={2109.13459}, 55 | archivePrefix={arXiv}, 56 | primaryClass={cs.LG} 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /kDV.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "source": [ 7 | "import torch\n", 8 | "import torch.nn as nn\n", 9 | "\n", 10 | "import numpy as np\n", 11 | "from scipy.io import loadmat, savemat\n", 12 | "import math\n", 13 | "import os\n", 14 | "import h5py\n", 15 | "\n", 16 | "from functools import partial\n", 17 | "from models.models import MWT1d\n", 18 | "from models.utils import train, test, LpLoss, get_filter, UnitGaussianNormalizer" 19 | ], 20 | "outputs": [], 21 | "metadata": {} 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "source": [ 27 | "torch.manual_seed(0)\n", 28 | "np.random.seed(0)" 29 | ], 30 | "outputs": [], 31 | "metadata": {} 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "source": [ 37 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 38 | ], 39 | "outputs": [], 40 | "metadata": {} 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "source": [ 46 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 47 | ], 48 | "outputs": [], 49 | "metadata": {} 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 5, 54 | "source": [ 55 | "def get_initializer(name):\n", 56 | " \n", 57 | " if name == 'xavier_normal':\n", 58 | " init_ = partial(nn.init.xavier_normal_)\n", 59 | " elif name == 'kaiming_uniform':\n", 60 | " init_ = partial(nn.init.kaiming_uniform_)\n", 61 | " elif name == 'kaiming_normal':\n", 62 | " init_ = partial(nn.init.kaiming_normal_)\n", 63 | " return init_" 64 | ], 65 | "outputs": [], 66 | "metadata": {} 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 6, 71 | "source": [ 72 | "ntrain = 1000\n", 73 | "ntest = 200\n", 74 | "\n", 75 | "sub = 2**3 #subsampling rate\n", 76 | "h = 2**13 // sub #total grid size divided by the subsampling rate\n", 77 | "s = h\n", 78 | "batch_size = 20\n", 79 | "\n", 80 | "rw_ = loadmat('Data/KDV/kdv_train_test.mat')\n", 81 | "x_data = rw_['input'].astype(np.float32)\n", 82 | "y_data = rw_['output'].astype(np.float32)" 83 | ], 84 | "outputs": [], 85 | "metadata": {} 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 7, 90 | "source": [ 91 | "x_train = x_data[:ntrain,::sub]\n", 92 | "y_train = y_data[:ntrain,::sub]\n", 93 | "x_test = x_data[-ntest:,::sub]\n", 94 | "y_test = y_data[-ntest:,::sub]\n", 95 | "\n", 96 | "x_train = torch.from_numpy(x_train)\n", 97 | "x_test = torch.from_numpy(x_test)\n", 98 | "y_train = torch.from_numpy(y_train)\n", 99 | "y_test = torch.from_numpy(y_test)\n", 100 | "\n", 101 | "x_train = x_train.unsqueeze(-1)\n", 102 | "x_test = x_test.unsqueeze(-1)\n", 103 | "\n", 104 | "train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)\n", 105 | "test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)" 106 | ], 107 | "outputs": [], 108 | "metadata": {} 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 8, 113 | "source": [ 114 | "ich = 1\n", 115 | "initializer = get_initializer('xavier_normal') # xavier_normal, kaiming_normal, kaiming_uniform\n", 116 | "\n", 117 | "model = MWT1d(ich,\n", 118 | " alpha = 10,\n", 119 | " c = 4*4,\n", 120 | " k = 4,\n", 121 | " base = 'legendre', # chebyshev\n", 122 | " nCZ = 2,\n", 123 | " initializer = initializer,\n", 124 | " ).to(device)\n", 125 | "learning_rate = 0.001\n", 126 | "\n", 127 | "epochs = 500\n", 128 | "step_size = 100\n", 129 | "gamma = 0.5" 130 | ], 131 | "outputs": [], 132 | "metadata": {} 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "source": [ 138 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n", 139 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n", 140 | "\n", 141 | "myloss = LpLoss(size_average=False)\n", 142 | "\n", 143 | "for epoch in range(1, epochs+1):\n", 144 | " train_l2 = train(model, train_loader, optimizer, epoch, device,\n", 145 | " lossFn = myloss, lr_schedule = scheduler)\n", 146 | " \n", 147 | " test_l2 = test(model, test_loader, device, lossFn=myloss)\n", 148 | " print(f'epoch: {epoch}, train l2 = {train_l2}, test l2 = {test_l2}')" 149 | ], 150 | "outputs": [], 151 | "metadata": {} 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "pde", 157 | "language": "python", 158 | "name": "pde" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.8.8" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 5 175 | } -------------------------------------------------------------------------------- /Darcy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "source": [ 7 | "import torch\n", 8 | "import torch.nn as nn\n", 9 | "\n", 10 | "import numpy as np\n", 11 | "from scipy.io import loadmat, savemat\n", 12 | "import math\n", 13 | "import os\n", 14 | "import h5py\n", 15 | "\n", 16 | "from functools import partial\n", 17 | "from models.models import MWT2d\n", 18 | "from models.utils import train, test, LpLoss, get_filter, UnitGaussianNormalizer" 19 | ], 20 | "outputs": [], 21 | "metadata": {} 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "source": [ 27 | "torch.manual_seed(0)\n", 28 | "np.random.seed(0)" 29 | ], 30 | "outputs": [], 31 | "metadata": {} 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "source": [ 37 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 38 | ], 39 | "outputs": [], 40 | "metadata": {} 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "source": [ 46 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 47 | ], 48 | "outputs": [], 49 | "metadata": {} 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 5, 54 | "source": [ 55 | "def get_initializer(name):\n", 56 | " \n", 57 | " if name == 'xavier_normal':\n", 58 | " init_ = partial(nn.init.xavier_normal_)\n", 59 | " elif name == 'kaiming_uniform':\n", 60 | " init_ = partial(nn.init.kaiming_uniform_)\n", 61 | " elif name == 'kaiming_normal':\n", 62 | " init_ = partial(nn.init.kaiming_normal_)\n", 63 | " return init_" 64 | ], 65 | "outputs": [], 66 | "metadata": {} 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 6, 71 | "source": [ 72 | "ntrain = 1000\n", 73 | "ntest = 200\n", 74 | "\n", 75 | "r = 1\n", 76 | "h = int(((512 - 1)/r) + 1)\n", 77 | "s = h" 78 | ], 79 | "outputs": [], 80 | "metadata": {} 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 7, 85 | "source": [ 86 | "# laod data\n", 87 | "dataloader = h5py.File('../../data/all_train_test_Darcy.mat','r')\n", 88 | "\n", 89 | "a_data = dataloader['a_train']\n", 90 | "p_data = dataloader['p_train']\n", 91 | "\n", 92 | "a_numpy = []\n", 93 | "for i in range(len(a_data)):\n", 94 | " obj = dataloader[a_data[i][0]]\n", 95 | " a_numpy.append(obj[:])\n", 96 | "a_tensor = torch.from_numpy(np.array(a_numpy).astype(np.float32))\n", 97 | "\n", 98 | "p_numpy = []\n", 99 | "for i in range(len(p_data)):\n", 100 | " obj = dataloader[p_data[i][0]]\n", 101 | " p_numpy.append(obj[:])\n", 102 | "p_tensor = torch.from_numpy(np.array(p_numpy).astype(np.float32))" 103 | ], 104 | "outputs": [], 105 | "metadata": {} 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 8, 110 | "source": [ 111 | "x_train = a_tensor[:ntrain,::r,::r][:,:s,:s]\n", 112 | "y_train = p_tensor[:ntrain,::r,::r][:,:s,:s]\n", 113 | "\n", 114 | "x_test = a_tensor[-ntest:,::r,::r][:,:s,:s]\n", 115 | "y_test = p_tensor[-ntest:,::r,::r][:,:s,:s]" 116 | ], 117 | "outputs": [], 118 | "metadata": {} 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 9, 123 | "source": [ 124 | "x_normalizer = UnitGaussianNormalizer(x_train)\n", 125 | "x_train = x_normalizer.encode(x_train)\n", 126 | "x_test = x_normalizer.encode(x_test)\n", 127 | "\n", 128 | "y_normalizer = UnitGaussianNormalizer(y_train)\n", 129 | "y_train = y_normalizer.encode(y_train)\n", 130 | "\n", 131 | "grids = []\n", 132 | "grids.append(np.linspace(0, 1, s))\n", 133 | "grids.append(np.linspace(0, 1, s))\n", 134 | "grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T\n", 135 | "grid = grid.reshape(1,s,s,2)\n", 136 | "grid = torch.tensor(grid, dtype=torch.float)\n", 137 | "x_train = torch.cat([x_train.reshape(ntrain,s,s,1), grid.repeat(ntrain,1,1,1)], dim=3)\n", 138 | "x_test = torch.cat([x_test.reshape(ntest,s,s,1), grid.repeat(ntest,1,1,1)], dim=3)" 139 | ], 140 | "outputs": [], 141 | "metadata": {} 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 10, 146 | "source": [ 147 | "batch_size = 10\n", 148 | "train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)\n", 149 | "test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)" 150 | ], 151 | "outputs": [], 152 | "metadata": {} 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 11, 157 | "source": [ 158 | "ich = 3\n", 159 | "initializer = get_initializer('xavier_normal') # xavier_normal, kaiming_normal, kaiming_uniform\n", 160 | "\n", 161 | "torch.manual_seed(0)\n", 162 | "np.random.seed(0)\n", 163 | "\n", 164 | "model = MWT2d(ich, \n", 165 | " alpha = 12,\n", 166 | " c = 4,\n", 167 | " k = 4, \n", 168 | " base = 'legendre', # 'chebyshev'\n", 169 | " nCZ = 4,\n", 170 | " L = 0,\n", 171 | " initializer = initializer,\n", 172 | " ).to(device)\n", 173 | "learning_rate = 0.001\n", 174 | "\n", 175 | "epochs = 500\n", 176 | "step_size = 100\n", 177 | "gamma = 0.5" 178 | ], 179 | "outputs": [], 180 | "metadata": {} 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "source": [ 186 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n", 187 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n", 188 | "\n", 189 | "myloss = LpLoss(size_average=False)\n", 190 | "y_normalizer.cuda()\n", 191 | "\n", 192 | "for epoch in range(1, epochs+1):\n", 193 | " train_l2 = train(model, train_loader, optimizer, epoch, device,\n", 194 | " lossFn = myloss, lr_schedule = scheduler,\n", 195 | " post_proc = y_normalizer.decode)\n", 196 | " \n", 197 | " test_l2 = test(model, test_loader, device, lossFn=myloss, post_proc=y_normalizer.decode)\n", 198 | " print(f'epoch: {epoch}, train l2 = {train_l2}, test l2 = {test_l2}')" 199 | ], 200 | "outputs": [], 201 | "metadata": {} 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "source": [], 207 | "outputs": [], 208 | "metadata": {} 209 | } 210 | ], 211 | "metadata": { 212 | "kernelspec": { 213 | "display_name": "pde", 214 | "language": "python", 215 | "name": "pde" 216 | }, 217 | "language_info": { 218 | "codemirror_mode": { 219 | "name": "ipython", 220 | "version": 3 221 | }, 222 | "file_extension": ".py", 223 | "mimetype": "text/x-python", 224 | "name": "python", 225 | "nbconvert_exporter": "python", 226 | "pygments_lexer": "ipython3", 227 | "version": "3.8.8" 228 | } 229 | }, 230 | "nbformat": 4, 231 | "nbformat_minor": 5 232 | } -------------------------------------------------------------------------------- /tests/test_NS_MWT_N_1000.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | from typing import List, Tuple 6 | 7 | import numpy as np 8 | import math 9 | import os 10 | import h5py 11 | 12 | from functools import partial 13 | from models.utils import train, test, LpLoss, get_filter, UnitGaussianNormalizer 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | def get_initializer(name): 18 | 19 | if name == 'xavier_normal': 20 | init_ = partial(nn.init.xavier_normal_) 21 | elif name == 'kaiming_uniform': 22 | init_ = partial(nn.init.kaiming_uniform_) 23 | elif name == 'kaiming_normal': 24 | init_ = partial(nn.init.kaiming_normal_) 25 | return init_ 26 | 27 | class sparseKernel(nn.Module): 28 | def __init__(self, 29 | k, alpha, c=1, 30 | nl = 1, 31 | initializer = None, 32 | **kwargs): 33 | super(sparseKernel,self).__init__() 34 | 35 | self.k = k 36 | self.conv = self.convBlock(alpha*k**2, alpha*k**2) 37 | self.Lo = nn.Conv1d(alpha*k**2, c*k**2, 1) 38 | 39 | def forward(self, x): 40 | B, c, ich, Nx, Ny, T = x.shape # (B, c, ich, Nx, Ny, T) 41 | x = x.reshape(B, -1, Nx, Ny, T) 42 | x = self.conv(x) 43 | x = self.Lo(x.view(B, c*ich, -1)).view(B, c, ich, Nx, Ny, T) 44 | return x 45 | 46 | 47 | def convBlock(self, ich, och): 48 | net = nn.Sequential( 49 | nn.Conv3d(och, och, 3, 1, 1), 50 | nn.ReLU(inplace=True), 51 | ) 52 | return net 53 | 54 | 55 | def compl_mul3d(a, b): 56 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 57 | op = partial(torch.einsum, "bixyz,ioxyz->boxyz") 58 | return torch.stack([ 59 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 60 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 61 | ], dim=-1) 62 | 63 | 64 | # fft conv taken from: https://github.com/zongyi-li/fourier_neural_operator 65 | class sparseKernelFT(nn.Module): 66 | def __init__(self, 67 | k, alpha, c=1, 68 | nl = 1, 69 | initializer = None, 70 | **kwargs): 71 | super(sparseKernelFT, self).__init__() 72 | 73 | self.modes = alpha 74 | 75 | self.weights1 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, 2)) 76 | self.weights2 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, 2)) 77 | self.weights3 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, 2)) 78 | self.weights4 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, 2)) 79 | nn.init.xavier_normal_(self.weights1) 80 | nn.init.xavier_normal_(self.weights2) 81 | nn.init.xavier_normal_(self.weights3) 82 | nn.init.xavier_normal_(self.weights4) 83 | 84 | self.Lo = nn.Conv1d(c*k**2, c*k**2, 1) 85 | self.k = k 86 | 87 | def forward(self, x): 88 | B, c, ich, Nx, Ny, T = x.shape # (B, c, ich, N, N, T) 89 | 90 | x = x.reshape(B, -1, Nx, Ny, T) 91 | x_fft = torch.rfft(x, 3, normalized=True, onesided=True) 92 | 93 | # Multiply relevant Fourier modes 94 | l1 = min(self.modes, Nx//2+1) 95 | l2 = min(self.modes, Ny//2+1) 96 | out_ft = torch.zeros(B, c*ich, Nx, Ny, T//2 +1, 2, device=x.device) 97 | 98 | out_ft[:, :, :l1, :l2, :self.modes] = compl_mul3d( 99 | x_fft[:, :, :l1, :l2, :self.modes], self.weights1[:, :, :l1, :l2, :]) 100 | out_ft[:, :, -l1:, :l2, :self.modes] = compl_mul3d( 101 | x_fft[:, :, -l1:, :l2, :self.modes], self.weights2[:, :, :l1, :l2, :]) 102 | out_ft[:, :, :l1, -l2:, :self.modes] = compl_mul3d( 103 | x_fft[:, :, :l1, -l2:, :self.modes], self.weights3[:, :, :l1, :l2, :]) 104 | out_ft[:, :, -l1:, -l2:, :self.modes] = compl_mul3d( 105 | x_fft[:, :, -l1:, -l2:, :self.modes], self.weights4[:, :, :l1, :l2, :]) 106 | 107 | #Return to physical space 108 | x = torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=(Nx, Ny, T)) 109 | 110 | x = F.relu(x) 111 | x = self.Lo(x.view(B, c*ich, -1)).view(B, c, ich, Nx, Ny, T) 112 | return x 113 | 114 | 115 | class MWT_CZ(nn.Module): 116 | def __init__(self, 117 | k = 3, alpha = 5, 118 | L = 0, c = 1, 119 | base = 'legendre', 120 | initializer = None, 121 | **kwargs): 122 | super(MWT_CZ, self).__init__() 123 | 124 | self.k = k 125 | self.L = L 126 | H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) 127 | H0r = H0@PHI0 128 | G0r = G0@PHI0 129 | H1r = H1@PHI1 130 | G1r = G1@PHI1 131 | 132 | H0r[np.abs(H0r)<1e-8]=0 133 | H1r[np.abs(H1r)<1e-8]=0 134 | G0r[np.abs(G0r)<1e-8]=0 135 | G1r[np.abs(G1r)<1e-8]=0 136 | 137 | self.A = sparseKernelFT(k, alpha, c) 138 | self.B = sparseKernelFT(k, alpha, c) 139 | self.C = sparseKernelFT(k, alpha, c) 140 | 141 | self.T0 = nn.Conv1d(c*k**2, c*k**2, 1) 142 | 143 | if initializer is not None: 144 | self.reset_parameters(initializer) 145 | 146 | self.register_buffer('ec_s', torch.Tensor( 147 | np.concatenate((np.kron(H0, H0).T, 148 | np.kron(H0, H1).T, 149 | np.kron(H1, H0).T, 150 | np.kron(H1, H1).T, 151 | ), axis=0))) 152 | self.register_buffer('ec_d', torch.Tensor( 153 | np.concatenate((np.kron(G0, G0).T, 154 | np.kron(G0, G1).T, 155 | np.kron(G1, G0).T, 156 | np.kron(G1, G1).T, 157 | ), axis=0))) 158 | 159 | self.register_buffer('rc_ee', torch.Tensor( 160 | np.concatenate((np.kron(H0r, H0r), 161 | np.kron(G0r, G0r), 162 | ), axis=0))) 163 | self.register_buffer('rc_eo', torch.Tensor( 164 | np.concatenate((np.kron(H0r, H1r), 165 | np.kron(G0r, G1r), 166 | ), axis=0))) 167 | self.register_buffer('rc_oe', torch.Tensor( 168 | np.concatenate((np.kron(H1r, H0r), 169 | np.kron(G1r, G0r), 170 | ), axis=0))) 171 | self.register_buffer('rc_oo', torch.Tensor( 172 | np.concatenate((np.kron(H1r, H1r), 173 | np.kron(G1r, G1r), 174 | ), axis=0))) 175 | 176 | 177 | def forward(self, x): 178 | 179 | B, c, ich, Nx, Ny, T = x.shape # (B, c, k^2, Nx, Ny, T) 180 | ns = math.floor(np.log2(Nx)) 181 | 182 | Ud = torch.jit.annotate(List[Tensor], []) 183 | Us = torch.jit.annotate(List[Tensor], []) 184 | 185 | # decompose 186 | for i in range(ns-self.L): 187 | d, x = self.wavelet_transform(x) 188 | Ud += [self.A(d) + self.B(x)] 189 | Us += [self.C(d)] 190 | x = self.T0(x.reshape(B, c*ich, -1)).view( 191 | B, c, ich, 2**self.L, 2**self.L, T) # coarsest scale transform 192 | 193 | # reconstruct 194 | for i in range(ns-1-self.L,-1,-1): 195 | x = x + Us[i] 196 | x = torch.cat((x, Ud[i]), 2) 197 | x = self.evenOdd(x) 198 | 199 | return x 200 | 201 | 202 | def wavelet_transform(self, x): 203 | xa = torch.cat([x[:, :, :, ::2 , ::2 , :], 204 | x[:, :, :, ::2 , 1::2, :], 205 | x[:, :, :, 1::2, ::2 , :], 206 | x[:, :, :, 1::2, 1::2, :] 207 | ], 2) 208 | waveFil = partial(torch.einsum, 'bcixyt,io->bcoxyt') 209 | d = waveFil(xa, self.ec_d) 210 | s = waveFil(xa, self.ec_s) 211 | return d, s 212 | 213 | 214 | def evenOdd(self, x): 215 | 216 | B, c, ich, Nx, Ny, T = x.shape # (B, c, 2*k^2, Nx, Ny) 217 | assert ich == 2*self.k**2 218 | evOd = partial(torch.einsum, 'bcixyt,io->bcoxyt') 219 | x_ee = evOd(x, self.rc_ee) 220 | x_eo = evOd(x, self.rc_eo) 221 | x_oe = evOd(x, self.rc_oe) 222 | x_oo = evOd(x, self.rc_oo) 223 | 224 | x = torch.zeros(B, c, self.k**2, Nx*2, Ny*2, T, 225 | device = x.device) 226 | x[:, :, :, ::2 , ::2 , :] = x_ee 227 | x[:, :, :, ::2 , 1::2, :] = x_eo 228 | x[:, :, :, 1::2, ::2 , :] = x_oe 229 | x[:, :, :, 1::2, 1::2, :] = x_oo 230 | return x 231 | 232 | def reset_parameters(self, initializer): 233 | initializer(self.T0.weight) 234 | 235 | 236 | class MWT(nn.Module): 237 | def __init__(self, 238 | ich = 1, k = 3, alpha = 2, c = 1, 239 | nCZ = 3, 240 | L = 0, 241 | base = 'legendre', 242 | initializer = None, 243 | **kwargs): 244 | super(MWT,self).__init__() 245 | 246 | self.k = k 247 | self.c = c 248 | self.L = L 249 | self.nCZ = nCZ 250 | self.Lk = nn.Linear(ich, c*k**2) 251 | 252 | self.MWT_CZ = nn.ModuleList( 253 | [MWT_CZ(k, alpha, L, c, base, 254 | initializer) for _ in range(nCZ)] 255 | ) 256 | self.BN = nn.ModuleList( 257 | [nn.BatchNorm3d(c*k**2) for _ in range(nCZ)] 258 | ) 259 | self.Lc0 = nn.Linear(c*k**2, 128) 260 | self.Lc1 = nn.Linear(128, 1) 261 | 262 | if initializer is not None: 263 | self.reset_parameters(initializer) 264 | 265 | def forward(self, x): 266 | 267 | B, Nx, Ny, T, ich = x.shape # (B, Nx, Ny, T, d) 268 | ns = math.floor(np.log2(Nx)) 269 | x = self.Lk(x) 270 | x = x.view(B, Nx, Ny, T, self.c, self.k**2) 271 | x = x.permute(0, 4, 5, 1, 2, 3) 272 | 273 | for i in range(self.nCZ): 274 | x = self.MWT_CZ[i](x) 275 | x = self.BN[i](x.view(B, -1, Nx, Ny, T)).view( 276 | B, self.c, self.k**2, Nx, Ny, T) 277 | if i < self.nCZ-1: 278 | x = F.relu(x) 279 | 280 | x = x.view(B, -1, Nx, Ny, T) # collapse c and k**2 281 | x = x.permute(0, 2, 3, 4, 1) 282 | x = self.Lc0(x) 283 | x = F.relu(x) 284 | x = self.Lc1(x) 285 | return x.squeeze() 286 | 287 | def reset_parameters(self, initializer): 288 | initializer(self.Lc0.weight) 289 | initializer(self.Lc1.weight) 290 | 291 | 292 | def load_data(): 293 | 294 | data_path = 'Data/NS/ns_V1e-3_N5000_T50.mat' 295 | ntrain = 1000 296 | ntest = 200 297 | 298 | sub = 1 299 | S = 64 // sub 300 | T_in = 10 301 | T = 40 302 | 303 | dataloader = h5py.File(data_path) 304 | u_data = dataloader['u'] 305 | t_data = dataloader['u'] 306 | 307 | train_a = torch.from_numpy(u_data[:T_in, ::sub,::sub,:ntrain] 308 | ).permute(3, 1, 2, 0) 309 | train_u = torch.from_numpy(u_data[T_in:T_in+T, ::sub,::sub,:ntrain] 310 | ).permute(3, 1, 2, 0) 311 | 312 | test_a = torch.from_numpy(u_data[:T_in, ::sub,::sub,-ntest:] 313 | ).permute(3, 1, 2, 0) 314 | test_u = torch.from_numpy(u_data[T_in:T_in+T, ::sub,::sub,-ntest:] 315 | ).permute(3, 1, 2, 0) 316 | 317 | assert (S == train_u.shape[-2]) 318 | assert (T == train_u.shape[-1]) 319 | 320 | a_normalizer = UnitGaussianNormalizer(train_a) 321 | x_train = a_normalizer.encode(train_a) 322 | x_test = a_normalizer.encode(test_a) 323 | 324 | y_normalizer = UnitGaussianNormalizer(train_u) 325 | y_train = y_normalizer.encode(train_u) 326 | 327 | x_train = x_train.reshape(ntrain,S,S,1,T_in).repeat([1,1,1,T,1]) 328 | x_test = x_test.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1]) 329 | 330 | # pad locations (x,y,t) 331 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 332 | gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1]) 333 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 334 | gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1]) 335 | gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float) 336 | gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1]) 337 | 338 | x_train = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]), 339 | gridt.repeat([ntrain,1,1,1,1]), x_train), dim=-1) 340 | x_test = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]), 341 | gridt.repeat([ntest,1,1,1,1]), x_test), dim=-1) 342 | 343 | return x_train, y_train, x_test, test_u, y_normalizer 344 | 345 | 346 | 347 | def main(): 348 | x_train, y_train, x_test, y_test, y_normalizer = load_data() 349 | 350 | batch_size = 10 351 | train_loader = torch.utils.data.DataLoader( 352 | torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 353 | test_loader = torch.utils.data.DataLoader( 354 | torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 355 | 356 | myloss = LpLoss(size_average=False) 357 | y_normalizer.cuda() 358 | 359 | # Legendre 360 | model = torch.load('ptmodels/NS_v_1e-3_N1000_T50_alpha_12_c_4_k_3_nCZ_4_L_0_3CNN_BN_Leg_epoch_500.pt') 361 | model.to(device) 362 | 363 | l2_test = test(model, test_loader, device, lossFn=myloss, post_proc=y_normalizer.decode) 364 | print(f'test relative L2 error for N=1000, T=50, nu=1e-3 with Legendre = {l2_test}') 365 | 366 | # Chebyshev 367 | 368 | model = torch.load('ptmodels/NS_v_1e-3_N1000_T50_alpha_12_c_4_k_3_nCZ_4_L_0_3CNN_BN_Chb_epoch_500.pt') 369 | model.to(device) 370 | 371 | l2_test = test(model, test_loader, device, lossFn=myloss, post_proc=y_normalizer.decode) 372 | print(f'test relative L2 error for N=1000, T=50, nu=1e-3 with Chebyshev = {l2_test}') 373 | 374 | 375 | if __name__ == '__main__': 376 | main() 377 | 378 | -------------------------------------------------------------------------------- /tests/test_NS_MWT_N_10000.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | from typing import List, Tuple 6 | 7 | import numpy as np 8 | import math 9 | import os 10 | import h5py 11 | 12 | from functools import partial 13 | from models.utils import train, test, LpLoss, get_filter, UnitGaussianNormalizer 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | def get_initializer(name): 18 | 19 | if name == 'xavier_normal': 20 | init_ = partial(nn.init.xavier_normal_) 21 | elif name == 'kaiming_uniform': 22 | init_ = partial(nn.init.kaiming_uniform_) 23 | elif name == 'kaiming_normal': 24 | init_ = partial(nn.init.kaiming_normal_) 25 | return init_ 26 | 27 | class sparseKernel(nn.Module): 28 | def __init__(self, 29 | k, alpha, c=1, 30 | nl = 1, 31 | initializer = None, 32 | **kwargs): 33 | super(sparseKernel,self).__init__() 34 | 35 | self.k = k 36 | self.conv = self.convBlock(alpha*k**2, alpha*k**2) 37 | self.Lo = nn.Conv1d(alpha*k**2, c*k**2, 1) 38 | 39 | def forward(self, x): 40 | B, c, ich, Nx, Ny, T = x.shape # (B, c, ich, Nx, Ny, T) 41 | x = x.reshape(B, -1, Nx, Ny, T) 42 | x = self.conv(x) 43 | x = self.Lo(x.view(B, c*ich, -1)).view(B, c, ich, Nx, Ny, T) 44 | 45 | return x 46 | 47 | 48 | def convBlock(self, ich, och): 49 | net = nn.Sequential( 50 | nn.Conv3d(och, och, 3, 1, 1), 51 | nn.ReLU(inplace=True), 52 | ) 53 | return net 54 | 55 | 56 | def compl_mul3d(a, b): 57 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 58 | op = partial(torch.einsum, "bixyz,ioxyz->boxyz") 59 | return torch.stack([ 60 | op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]), 61 | op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1]) 62 | ], dim=-1) 63 | 64 | 65 | class sparseKernelFT(nn.Module): 66 | def __init__(self, 67 | k, alpha1, alpha2, alpha3, c=1, 68 | nl = 1, 69 | initializer = None, 70 | **kwargs): 71 | super(sparseKernelFT, self).__init__() 72 | 73 | self.modes1 = alpha1 74 | self.modes2 = alpha2 75 | self.modes3 = alpha3 76 | 77 | self.weights1 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes1, self.modes2, self.modes3, 2)) 78 | self.weights2 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes1, self.modes2, self.modes3, 2)) 79 | self.weights3 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes1, self.modes2, self.modes3, 2)) 80 | self.weights4 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes1, self.modes2, self.modes3, 2)) 81 | nn.init.xavier_normal_(self.weights1) 82 | nn.init.xavier_normal_(self.weights2) 83 | nn.init.xavier_normal_(self.weights3) 84 | nn.init.xavier_normal_(self.weights4) 85 | 86 | self.Lo = nn.Conv1d(c*k**2, c*k**2, 1) 87 | self.k = k 88 | 89 | def forward(self, x): 90 | B, c, ich, Nx, Ny, T = x.shape # (B, c, ich, N, N, T) 91 | 92 | x = x.reshape(B, -1, Nx, Ny, T) 93 | x_fft = torch.rfft(x, 3, normalized=True, onesided=True) 94 | 95 | # Multiply relevant Fourier modes 96 | l1 = min(self.modes1, Nx//2+1) 97 | l2 = min(self.modes2, Ny//2+1) 98 | out_ft = torch.zeros(B, c*ich, Nx, Ny, T//2 +1, 2, device=x.device) 99 | 100 | out_ft[:, :, :l1, :l2, :self.modes3] = compl_mul3d( 101 | x_fft[:, :, :l1, :l2, :self.modes3], self.weights1[:, :, :l1, :l2, :]) 102 | out_ft[:, :, -l1:, :l2, :self.modes3] = compl_mul3d( 103 | x_fft[:, :, -l1:, :l2, :self.modes3], self.weights2[:, :, :l1, :l2, :]) 104 | out_ft[:, :, :l1, -l2:, :self.modes3] = compl_mul3d( 105 | x_fft[:, :, :l1, -l2:, :self.modes3], self.weights3[:, :, :l1, :l2, :]) 106 | out_ft[:, :, -l1:, -l2:, :self.modes3] = compl_mul3d( 107 | x_fft[:, :, -l1:, -l2:, :self.modes3], self.weights4[:, :, :l1, :l2, :]) 108 | 109 | #Return to physical space 110 | x = torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=(Nx, Ny, T)) 111 | 112 | x = F.relu(x) 113 | x = self.Lo(x.view(B, c*ich, -1)).view(B, c, ich, Nx, Ny, T) 114 | return x 115 | 116 | 117 | class MWT_CZ(nn.Module): 118 | def __init__(self, 119 | k = 3, alpha = 5, 120 | L = 0, c = 1, 121 | base = 'legendre', 122 | initializer = None, 123 | **kwargs): 124 | super(MWT_CZ, self).__init__() 125 | 126 | self.k = k 127 | self.L = L 128 | H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) 129 | H0r = H0@PHI0 130 | G0r = G0@PHI0 131 | H1r = H1@PHI1 132 | G1r = G1@PHI1 133 | 134 | H0r[np.abs(H0r)<1e-8]=0 135 | H1r[np.abs(H1r)<1e-8]=0 136 | G0r[np.abs(G0r)<1e-8]=0 137 | G1r[np.abs(G1r)<1e-8]=0 138 | 139 | self.A = sparseKernelFT(k, alpha, alpha, 10, c) 140 | self.B = sparseKernelFT(k, alpha, alpha, 10, c) 141 | self.C = sparseKernelFT(k, alpha, alpha, 10, c) 142 | 143 | self.T0 = nn.Conv1d(c*k**2, c*k**2, 1) 144 | 145 | if initializer is not None: 146 | self.reset_parameters(initializer) 147 | 148 | self.register_buffer('ec_s', torch.Tensor( 149 | np.concatenate((np.kron(H0, H0).T, 150 | np.kron(H0, H1).T, 151 | np.kron(H1, H0).T, 152 | np.kron(H1, H1).T, 153 | ), axis=0))) 154 | self.register_buffer('ec_d', torch.Tensor( 155 | np.concatenate((np.kron(G0, G0).T, 156 | np.kron(G0, G1).T, 157 | np.kron(G1, G0).T, 158 | np.kron(G1, G1).T, 159 | ), axis=0))) 160 | 161 | self.register_buffer('rc_ee', torch.Tensor( 162 | np.concatenate((np.kron(H0r, H0r), 163 | np.kron(G0r, G0r), 164 | ), axis=0))) 165 | self.register_buffer('rc_eo', torch.Tensor( 166 | np.concatenate((np.kron(H0r, H1r), 167 | np.kron(G0r, G1r), 168 | ), axis=0))) 169 | self.register_buffer('rc_oe', torch.Tensor( 170 | np.concatenate((np.kron(H1r, H0r), 171 | np.kron(G1r, G0r), 172 | ), axis=0))) 173 | self.register_buffer('rc_oo', torch.Tensor( 174 | np.concatenate((np.kron(H1r, H1r), 175 | np.kron(G1r, G1r), 176 | ), axis=0))) 177 | 178 | 179 | def forward(self, x): 180 | 181 | B, c, ich, Nx, Ny, T = x.shape # (B, c, k^2, Nx, Ny, T) 182 | ns = math.floor(np.log2(Nx)) 183 | 184 | Ud = torch.jit.annotate(List[Tensor], []) 185 | Us = torch.jit.annotate(List[Tensor], []) 186 | 187 | # decompose 188 | for i in range(ns-self.L): 189 | d, x = self.wavelet_transform(x) 190 | Ud += [self.A(d) + self.B(x)] 191 | Us += [self.C(d)] 192 | x = self.T0(x.reshape(B, c*ich, -1)).view( 193 | B, c, ich, 2**self.L, 2**self.L, T) # coarsest scale transform 194 | 195 | # reconstruct 196 | for i in range(ns-1-self.L,-1,-1): 197 | x = x + Us[i] 198 | x = torch.cat((x, Ud[i]), 2) 199 | x = self.evenOdd(x) 200 | 201 | return x 202 | 203 | 204 | def wavelet_transform(self, x): 205 | xa = torch.cat([x[:, :, :, ::2 , ::2 , :], 206 | x[:, :, :, ::2 , 1::2, :], 207 | x[:, :, :, 1::2, ::2 , :], 208 | x[:, :, :, 1::2, 1::2, :] 209 | ], 2) 210 | waveFil = partial(torch.einsum, 'bcixyt,io->bcoxyt') 211 | d = waveFil(xa, self.ec_d) 212 | s = waveFil(xa, self.ec_s) 213 | return d, s 214 | 215 | 216 | def evenOdd(self, x): 217 | 218 | B, c, ich, Nx, Ny, T = x.shape # (B, c, 2*k^2, Nx, Ny) 219 | assert ich == 2*self.k**2 220 | evOd = partial(torch.einsum, 'bcixyt,io->bcoxyt') 221 | x_ee = evOd(x, self.rc_ee) 222 | x_eo = evOd(x, self.rc_eo) 223 | x_oe = evOd(x, self.rc_oe) 224 | x_oo = evOd(x, self.rc_oo) 225 | 226 | x = torch.zeros(B, c, self.k**2, Nx*2, Ny*2, T, 227 | device = x.device) 228 | x[:, :, :, ::2 , ::2 , :] = x_ee 229 | x[:, :, :, ::2 , 1::2, :] = x_eo 230 | x[:, :, :, 1::2, ::2 , :] = x_oe 231 | x[:, :, :, 1::2, 1::2, :] = x_oo 232 | return x 233 | 234 | def reset_parameters(self, initializer): 235 | initializer(self.T0.weight) 236 | 237 | 238 | class MWT(nn.Module): 239 | def __init__(self, 240 | ich = 1, k = 3, alpha = 2, c = 1, 241 | nCZ = 3, 242 | L = 0, 243 | base = 'legendre', 244 | initializer = None, 245 | **kwargs): 246 | super(MWT,self).__init__() 247 | 248 | self.k = k 249 | self.c = c 250 | self.L = L 251 | self.nCZ = nCZ 252 | self.Lk = nn.Linear(ich, c*k**2) 253 | 254 | self.MWT_CZ = nn.ModuleList( 255 | [MWT_CZ(k, alpha, L, c, base, 256 | initializer) for _ in range(nCZ)] 257 | ) 258 | self.BN = nn.ModuleList( 259 | [nn.BatchNorm3d(c*k**2) for _ in range(nCZ)] 260 | ) 261 | self.Lc0 = nn.Linear(c*k**2, 128) 262 | self.Lc1 = nn.Linear(128, 1) 263 | 264 | if initializer is not None: 265 | self.reset_parameters(initializer) 266 | 267 | def forward(self, x): 268 | 269 | B, Nx, Ny, T, ich = x.shape # (B, Nx, Ny, T, d) 270 | ns = math.floor(np.log2(Nx)) 271 | x = self.Lk(x) 272 | x = x.view(B, Nx, Ny, T, self.c, self.k**2) 273 | x = x.permute(0, 4, 5, 1, 2, 3) 274 | 275 | for i in range(self.nCZ): 276 | x = self.MWT_CZ[i](x) 277 | x = self.BN[i](x.view(B, -1, Nx, Ny, T)).view( 278 | B, self.c, self.k**2, Nx, Ny, T) 279 | if i < self.nCZ-1: 280 | x = F.relu(x) 281 | 282 | x = x.view(B, -1, Nx, Ny, T) # collapse c and k**2 283 | x = x.permute(0, 2, 3, 4, 1) 284 | x = self.Lc0(x) 285 | x = F.relu(x) 286 | x = self.Lc1(x) 287 | return x.squeeze() 288 | 289 | def reset_parameters(self, initializer): 290 | initializer(self.Lc0.weight) 291 | initializer(self.Lc1.weight) 292 | 293 | 294 | def load_data(): 295 | 296 | data_path = 'Data/NS/ns_V1e-4_N10000_T30.mat' 297 | 298 | ntest = 200 299 | ntrain = 10000-ntest 300 | 301 | sub = 1 302 | S = 64 // sub 303 | T_in = 10 304 | T = 20 305 | 306 | dataloader = h5py.File(data_path) 307 | u_data = dataloader['u'] 308 | t_data = dataloader['u'] 309 | 310 | train_a = torch.from_numpy(u_data[:T_in, ::sub,::sub,:ntrain] 311 | ).permute(3, 1, 2, 0) 312 | train_u = torch.from_numpy(u_data[T_in:T_in+T, ::sub,::sub,:ntrain] 313 | ).permute(3, 1, 2, 0) 314 | 315 | test_a = torch.from_numpy(u_data[:T_in, ::sub,::sub,-ntest:] 316 | ).permute(3, 1, 2, 0) 317 | test_u = torch.from_numpy(u_data[T_in:T_in+T, ::sub,::sub,-ntest:] 318 | ).permute(3, 1, 2, 0) 319 | 320 | print('data loading complete') 321 | assert (S == train_u.shape[-2]) 322 | assert (T == train_u.shape[-1]) 323 | 324 | a_normalizer = UnitGaussianNormalizer(train_a) 325 | # x_train = a_normalizer.encode(train_a) 326 | x_test = a_normalizer.encode(test_a) 327 | 328 | y_normalizer = UnitGaussianNormalizer(train_u) 329 | # y_train = y_normalizer.encode(train_u) 330 | 331 | # x_train = x_train.reshape(ntrain,S,S,1,T_in).repeat([1,1,1,T,1]) 332 | x_test = x_test.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1]) 333 | 334 | # pad locations (x,y,t) 335 | gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 336 | gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1]) 337 | gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float) 338 | gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1]) 339 | gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float) 340 | gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1]) 341 | 342 | # x_train = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]), 343 | # gridt.repeat([ntrain,1,1,1,1]), x_train), dim=-1) 344 | x_test = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]), 345 | gridt.repeat([ntest,1,1,1,1]), x_test), dim=-1) 346 | 347 | return x_test, test_u, y_normalizer 348 | 349 | 350 | def main(): 351 | x_test, y_test, y_normalizer = load_data() 352 | 353 | batch_size = 10 354 | # train_loader = torch.utils.data.DataLoader( 355 | # torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 356 | test_loader = torch.utils.data.DataLoader( 357 | torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 358 | 359 | myloss = LpLoss(size_average=False) 360 | y_normalizer.cuda() 361 | 362 | # Legendre 363 | model = torch.load('ptmodels/NS_v_1e-4_N9800_T30_alpha_12_c_4_k_3_nCZ_4_L_0_3CNN_BN_epoch_200.pt') 364 | model.to(device) 365 | 366 | l2_test = test(model, test_loader, device, lossFn=myloss, post_proc=y_normalizer.decode) 367 | print(f'test relative L2 error for N=10000, T=30, nu=1e-4 with Legendre = {l2_test}') 368 | 369 | # Chebyshev 370 | model = torch.load('ptmodels/NS_v_1e-4_N9800_T20_alpha_12_c_4_k_3_nCZ_4_L_0_3CNN_BN_Chb_epoch_200.pt') 371 | model.to(device) 372 | 373 | l2_test = test(model, test_loader, device, lossFn=myloss, post_proc=y_normalizer.decode) 374 | print(f'test relative L2 error for N=10000, T=30, nu=1e-4 with Chebyshev = {l2_test}') 375 | 376 | 377 | if __name__ == '__main__': 378 | main() 379 | 380 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | from functools import partial 6 | 7 | from scipy.special import eval_legendre 8 | from sympy import Poly, legendre, Symbol, chebyshevt 9 | 10 | def legendreDer(k, x): 11 | def _legendre(k, x): 12 | return (2*k+1) * eval_legendre(k, x) 13 | out = 0 14 | for i in np.arange(k-1,-1,-2): 15 | out += _legendre(i, x) 16 | return out 17 | 18 | def phi_(phi_c, x, lb = 0, ub = 1): 19 | mask = np.logical_or(xub) * 1.0 20 | return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1-mask) 21 | 22 | def get_phi_psi(k, base): 23 | 24 | x = Symbol('x') 25 | phi_coeff = np.zeros((k,k)) 26 | phi_2x_coeff = np.zeros((k,k)) 27 | if base == 'legendre': 28 | for ki in range(k): 29 | coeff_ = Poly(legendre(ki, 2*x-1), x).all_coeffs() 30 | phi_coeff[ki,:ki+1] = np.flip(np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64)) 31 | coeff_ = Poly(legendre(ki, 4*x-1), x).all_coeffs() 32 | phi_2x_coeff[ki,:ki+1] = np.flip(np.sqrt(2) * np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64)) 33 | 34 | psi1_coeff = np.zeros((k, k)) 35 | psi2_coeff = np.zeros((k, k)) 36 | for ki in range(k): 37 | psi1_coeff[ki,:] = phi_2x_coeff[ki,:] 38 | for i in range(k): 39 | a = phi_2x_coeff[ki,:ki+1] 40 | b = phi_coeff[i, :i+1] 41 | prod_ = np.convolve(a, b) 42 | prod_[np.abs(prod_)<1e-8] = 0 43 | proj_ = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum() 44 | psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:] 45 | psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:] 46 | for j in range(ki): 47 | a = phi_2x_coeff[ki,:ki+1] 48 | b = psi1_coeff[j, :] 49 | prod_ = np.convolve(a, b) 50 | prod_[np.abs(prod_)<1e-8] = 0 51 | proj_ = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum() 52 | psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:] 53 | psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:] 54 | 55 | a = psi1_coeff[ki,:] 56 | prod_ = np.convolve(a, a) 57 | prod_[np.abs(prod_)<1e-8] = 0 58 | norm1 = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum() 59 | 60 | a = psi2_coeff[ki,:] 61 | prod_ = np.convolve(a, a) 62 | prod_[np.abs(prod_)<1e-8] = 0 63 | norm2 = (prod_ * 1/(np.arange(len(prod_))+1) * (1-np.power(0.5, 1+np.arange(len(prod_))))).sum() 64 | norm_ = np.sqrt(norm1 + norm2) 65 | psi1_coeff[ki,:] /= norm_ 66 | psi2_coeff[ki,:] /= norm_ 67 | psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0 68 | psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0 69 | 70 | phi = [np.poly1d(np.flip(phi_coeff[i,:])) for i in range(k)] 71 | psi1 = [np.poly1d(np.flip(psi1_coeff[i,:])) for i in range(k)] 72 | psi2 = [np.poly1d(np.flip(psi2_coeff[i,:])) for i in range(k)] 73 | 74 | elif base == 'chebyshev': 75 | for ki in range(k): 76 | if ki == 0: 77 | phi_coeff[ki,:ki+1] = np.sqrt(2/np.pi) 78 | phi_2x_coeff[ki,:ki+1] = np.sqrt(2/np.pi) * np.sqrt(2) 79 | else: 80 | coeff_ = Poly(chebyshevt(ki, 2*x-1), x).all_coeffs() 81 | phi_coeff[ki,:ki+1] = np.flip(2/np.sqrt(np.pi) * np.array(coeff_).astype(np.float64)) 82 | coeff_ = Poly(chebyshevt(ki, 4*x-1), x).all_coeffs() 83 | phi_2x_coeff[ki,:ki+1] = np.flip(np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64)) 84 | 85 | phi = [partial(phi_, phi_coeff[i,:]) for i in range(k)] 86 | 87 | x = Symbol('x') 88 | kUse = 2*k 89 | roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots() 90 | x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) 91 | # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) 92 | # not needed for our purpose here, we use even k always to avoid 93 | wm = np.pi / kUse / 2 94 | 95 | psi1_coeff = np.zeros((k, k)) 96 | psi2_coeff = np.zeros((k, k)) 97 | 98 | psi1 = [[] for _ in range(k)] 99 | psi2 = [[] for _ in range(k)] 100 | 101 | for ki in range(k): 102 | psi1_coeff[ki,:] = phi_2x_coeff[ki,:] 103 | for i in range(k): 104 | proj_ = (wm * phi[i](x_m) * np.sqrt(2)* phi[ki](2*x_m)).sum() 105 | psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:] 106 | psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:] 107 | 108 | for j in range(ki): 109 | proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum() 110 | psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:] 111 | psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:] 112 | 113 | psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5) 114 | psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5, ub = 1) 115 | 116 | norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum() 117 | norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum() 118 | 119 | norm_ = np.sqrt(norm1 + norm2) 120 | psi1_coeff[ki,:] /= norm_ 121 | psi2_coeff[ki,:] /= norm_ 122 | psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0 123 | psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0 124 | 125 | psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5+1e-16) 126 | psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5+1e-16, ub = 1) 127 | 128 | return phi, psi1, psi2 129 | 130 | 131 | def get_filter(base, k): 132 | 133 | def psi(psi1, psi2, i, inp): 134 | mask = (inp<=0.5) * 1.0 135 | return psi1[i](inp) * mask + psi2[i](inp) * (1-mask) 136 | 137 | if base not in ['legendre', 'chebyshev']: 138 | raise Exception('Base not supported') 139 | 140 | x = Symbol('x') 141 | H0 = np.zeros((k,k)) 142 | H1 = np.zeros((k,k)) 143 | G0 = np.zeros((k,k)) 144 | G1 = np.zeros((k,k)) 145 | PHI0 = np.zeros((k,k)) 146 | PHI1 = np.zeros((k,k)) 147 | phi, psi1, psi2 = get_phi_psi(k, base) 148 | if base == 'legendre': 149 | roots = Poly(legendre(k, 2*x-1)).all_roots() 150 | x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) 151 | wm = 1/k/legendreDer(k,2*x_m-1)/eval_legendre(k-1,2*x_m-1) 152 | 153 | for ki in range(k): 154 | for kpi in range(k): 155 | H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum() 156 | G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum() 157 | H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum() 158 | G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum() 159 | 160 | PHI0 = np.eye(k) 161 | PHI1 = np.eye(k) 162 | 163 | elif base == 'chebyshev': 164 | x = Symbol('x') 165 | kUse = 2*k 166 | roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots() 167 | x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) 168 | # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) 169 | # not needed for our purpose here, we use even k always to avoid 170 | wm = np.pi / kUse / 2 171 | 172 | for ki in range(k): 173 | for kpi in range(k): 174 | H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum() 175 | G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum() 176 | H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum() 177 | G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum() 178 | 179 | PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2 180 | PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2 181 | 182 | PHI0[np.abs(PHI0)<1e-8] = 0 183 | PHI1[np.abs(PHI1)<1e-8] = 0 184 | 185 | H0[np.abs(H0)<1e-8] = 0 186 | H1[np.abs(H1)<1e-8] = 0 187 | G0[np.abs(G0)<1e-8] = 0 188 | G1[np.abs(G1)<1e-8] = 0 189 | 190 | return H0, H1, G0, G1, PHI0, PHI1 191 | 192 | 193 | def train(model, train_loader, optimizer, epoch, device, verbose = 0, 194 | lossFn = None, lr_schedule=None, 195 | post_proc = lambda args: args): 196 | 197 | if lossFn is None: 198 | lossFn = nn.MSELoss() 199 | 200 | model.train() 201 | 202 | total_loss = 0. 203 | 204 | for batch_idx, (data, target) in enumerate(train_loader): 205 | 206 | bs = len(data) 207 | data, target = data.to(device), target.to(device) 208 | optimizer.zero_grad() 209 | 210 | output = model(data) 211 | 212 | target = post_proc(target) 213 | output = post_proc(output) 214 | loss = lossFn(output.view(bs, -1), target.view(bs, -1)) 215 | 216 | loss.backward() 217 | optimizer.step() 218 | total_loss += loss.sum().item() 219 | if lr_schedule is not None: lr_schedule.step() 220 | 221 | if verbose>0: 222 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 223 | epoch, batch_idx * len(data), len(train_loader.dataset), 224 | 100. * batch_idx / len(train_loader), loss.item())) 225 | 226 | return total_loss/len(train_loader.dataset) 227 | 228 | 229 | def test(model, test_loader, device, verbose=0, lossFn=None, 230 | post_proc = lambda args: args): 231 | 232 | model.eval() 233 | if lossFn is None: 234 | lossFn = nn.MSELoss() 235 | 236 | 237 | total_loss = 0. 238 | predictions = [] 239 | 240 | with torch.no_grad(): 241 | for data, target in test_loader: 242 | bs = len(data) 243 | 244 | data, target = data.to(device), target.to(device) 245 | output = model(data) 246 | output = post_proc(output) 247 | 248 | loss = lossFn(output.view(bs, -1), target.view(bs, -1)) 249 | total_loss += loss.sum().item() 250 | 251 | return total_loss/len(test_loader.dataset) 252 | 253 | 254 | # Till EoF 255 | # taken from FNO paper: 256 | # https://github.com/zongyi-li/fourier_neural_operator 257 | 258 | # normalization, pointwise gaussian 259 | class UnitGaussianNormalizer(object): 260 | def __init__(self, x, eps=0.00001): 261 | super(UnitGaussianNormalizer, self).__init__() 262 | 263 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 264 | self.mean = torch.mean(x, 0) 265 | self.std = torch.std(x, 0) 266 | self.eps = eps 267 | 268 | def encode(self, x): 269 | x = (x - self.mean) / (self.std + self.eps) 270 | return x 271 | 272 | def decode(self, x, sample_idx=None): 273 | if sample_idx is None: 274 | std = self.std + self.eps # n 275 | mean = self.mean 276 | else: 277 | if len(self.mean.shape) == len(sample_idx[0].shape): 278 | std = self.std[sample_idx] + self.eps # batch*n 279 | mean = self.mean[sample_idx] 280 | if len(self.mean.shape) > len(sample_idx[0].shape): 281 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 282 | mean = self.mean[:,sample_idx] 283 | 284 | # x is in shape of batch*n or T*batch*n 285 | x = (x * std) + mean 286 | return x 287 | 288 | def cuda(self): 289 | self.mean = self.mean.cuda() 290 | self.std = self.std.cuda() 291 | 292 | def cpu(self): 293 | self.mean = self.mean.cpu() 294 | self.std = self.std.cpu() 295 | 296 | # normalization, Gaussian 297 | class GaussianNormalizer(object): 298 | def __init__(self, x, eps=0.00001): 299 | super(GaussianNormalizer, self).__init__() 300 | 301 | self.mean = torch.mean(x) 302 | self.std = torch.std(x) 303 | self.eps = eps 304 | 305 | def encode(self, x): 306 | x = (x - self.mean) / (self.std + self.eps) 307 | return x 308 | 309 | def decode(self, x, sample_idx=None): 310 | x = (x * (self.std + self.eps)) + self.mean 311 | return x 312 | 313 | def cuda(self): 314 | self.mean = self.mean.cuda() 315 | self.std = self.std.cuda() 316 | 317 | def cpu(self): 318 | self.mean = self.mean.cpu() 319 | self.std = self.std.cpu() 320 | 321 | 322 | # normalization, scaling by range 323 | class RangeNormalizer(object): 324 | def __init__(self, x, low=0.0, high=1.0): 325 | super(RangeNormalizer, self).__init__() 326 | mymin = torch.min(x, 0)[0].view(-1) 327 | mymax = torch.max(x, 0)[0].view(-1) 328 | 329 | self.a = (high - low)/(mymax - mymin) 330 | self.b = -self.a*mymax + high 331 | 332 | def encode(self, x): 333 | s = x.size() 334 | x = x.view(s[0], -1) 335 | x = self.a*x + self.b 336 | x = x.view(s) 337 | return x 338 | 339 | def decode(self, x): 340 | s = x.size() 341 | x = x.view(s[0], -1) 342 | x = (x - self.b)/self.a 343 | x = x.view(s) 344 | return x 345 | 346 | class LpLoss(object): 347 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 348 | super(LpLoss, self).__init__() 349 | 350 | #Dimension and Lp-norm type are postive 351 | assert d > 0 and p > 0 352 | 353 | self.d = d 354 | self.p = p 355 | self.reduction = reduction 356 | self.size_average = size_average 357 | 358 | def abs(self, x, y): 359 | num_examples = x.size()[0] 360 | 361 | #Assume uniform mesh 362 | h = 1.0 / (x.size()[1] - 1.0) 363 | 364 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 365 | 366 | if self.reduction: 367 | if self.size_average: 368 | return torch.mean(all_norms) 369 | else: 370 | return torch.sum(all_norms) 371 | 372 | return all_norms 373 | 374 | def rel(self, x, y): 375 | num_examples = x.size()[0] 376 | 377 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 378 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 379 | 380 | if self.reduction: 381 | if self.size_average: 382 | return torch.mean(diff_norms/y_norms) 383 | else: 384 | return torch.sum(diff_norms/y_norms) 385 | 386 | return diff_norms/y_norms 387 | 388 | def __call__(self, x, y): 389 | return self.rel(x, y) 390 | -------------------------------------------------------------------------------- /NS.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "source": [ 7 | "import torch\n", 8 | "import torch.nn as nn\n", 9 | "import torch.nn.functional as F\n", 10 | "from torch import Tensor\n", 11 | "from typing import List, Tuple\n", 12 | "\n", 13 | "import numpy as np\n", 14 | "import math\n", 15 | "import os\n", 16 | "import h5py\n", 17 | "\n", 18 | "from functools import partial\n", 19 | "\n", 20 | "from models.utils import train, test, LpLoss, get_filter, UnitGaussianNormalizer" 21 | ], 22 | "outputs": [], 23 | "metadata": {} 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "source": [ 29 | "def get_initializer(name):\n", 30 | " \n", 31 | " if name == 'xavier_normal':\n", 32 | " init_ = partial(nn.init.xavier_normal_)\n", 33 | " elif name == 'kaiming_uniform':\n", 34 | " init_ = partial(nn.init.kaiming_uniform_)\n", 35 | " elif name == 'kaiming_normal':\n", 36 | " init_ = partial(nn.init.kaiming_normal_)\n", 37 | " return init_" 38 | ], 39 | "outputs": [], 40 | "metadata": {} 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "source": [ 46 | "class sparseKernel(nn.Module):\n", 47 | " def __init__(self,\n", 48 | " k, alpha, c=1, \n", 49 | " nl = 1,\n", 50 | " initializer = None,\n", 51 | " **kwargs):\n", 52 | " super(sparseKernel,self).__init__()\n", 53 | " \n", 54 | " self.k = k\n", 55 | " self.conv = self.convBlock(alpha*k**2, alpha*k**2)\n", 56 | " self.Lo = nn.Conv1d(alpha*k**2, c*k**2, 1)\n", 57 | " \n", 58 | " def forward(self, x):\n", 59 | " B, c, ich, Nx, Ny, T = x.shape # (B, c, ich, Nx, Ny, T)\n", 60 | " x = x.reshape(B, -1, Nx, Ny, T)\n", 61 | " x = self.conv(x)\n", 62 | " x = self.Lo(x.view(B, c*ich, -1)).view(B, c, ich, Nx, Ny, T)\n", 63 | " return x\n", 64 | " \n", 65 | " \n", 66 | " def convBlock(self, ich, och):\n", 67 | " net = nn.Sequential(\n", 68 | " nn.Conv3d(och, och, 3, 1, 1),\n", 69 | " nn.ReLU(inplace=True),\n", 70 | " )\n", 71 | " return net \n", 72 | " \n", 73 | "\n", 74 | "def compl_mul3d(a, b):\n", 75 | " # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)\n", 76 | " op = partial(torch.einsum, \"bixyz,ioxyz->boxyz\")\n", 77 | " return torch.stack([\n", 78 | " op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]),\n", 79 | " op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1])\n", 80 | " ], dim=-1)\n", 81 | "\n", 82 | "\n", 83 | "# fft conv taken from: https://github.com/zongyi-li/fourier_neural_operator\n", 84 | "class sparseKernelFT(nn.Module):\n", 85 | " def __init__(self,\n", 86 | " k, alpha, c=1, \n", 87 | " nl = 1,\n", 88 | " initializer = None,\n", 89 | " **kwargs):\n", 90 | " super(sparseKernelFT, self).__init__() \n", 91 | " \n", 92 | " self.modes = alpha\n", 93 | "\n", 94 | " self.weights1 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, 2))\n", 95 | " self.weights2 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, 2)) \n", 96 | " self.weights3 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, 2)) \n", 97 | " self.weights4 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, 2)) \n", 98 | " nn.init.xavier_normal_(self.weights1)\n", 99 | " nn.init.xavier_normal_(self.weights2)\n", 100 | " nn.init.xavier_normal_(self.weights3)\n", 101 | " nn.init.xavier_normal_(self.weights4)\n", 102 | " \n", 103 | " self.Lo = nn.Conv1d(c*k**2, c*k**2, 1)\n", 104 | "# self.Wo = nn.Conv1d(c*k**2, c*k**2, 1)\n", 105 | " self.k = k\n", 106 | " \n", 107 | " def forward(self, x):\n", 108 | " B, c, ich, Nx, Ny, T = x.shape # (B, c, ich, N, N, T)\n", 109 | " \n", 110 | " x = x.reshape(B, -1, Nx, Ny, T)\n", 111 | " x_fft = torch.rfft(x, 3, normalized=True, onesided=True)\n", 112 | " \n", 113 | " # Multiply relevant Fourier modes\n", 114 | " l1 = min(self.modes, Nx//2+1)\n", 115 | " l2 = min(self.modes, Ny//2+1)\n", 116 | " out_ft = torch.zeros(B, c*ich, Nx, Ny, T//2 +1, 2, device=x.device)\n", 117 | " \n", 118 | " out_ft[:, :, :l1, :l2, :self.modes] = compl_mul3d(\n", 119 | " x_fft[:, :, :l1, :l2, :self.modes], self.weights1[:, :, :l1, :l2, :])\n", 120 | " out_ft[:, :, -l1:, :l2, :self.modes] = compl_mul3d(\n", 121 | " x_fft[:, :, -l1:, :l2, :self.modes], self.weights2[:, :, :l1, :l2, :])\n", 122 | " out_ft[:, :, :l1, -l2:, :self.modes] = compl_mul3d(\n", 123 | " x_fft[:, :, :l1, -l2:, :self.modes], self.weights3[:, :, :l1, :l2, :])\n", 124 | " out_ft[:, :, -l1:, -l2:, :self.modes] = compl_mul3d(\n", 125 | " x_fft[:, :, -l1:, -l2:, :self.modes], self.weights4[:, :, :l1, :l2, :])\n", 126 | " \n", 127 | " #Return to physical space\n", 128 | " x = torch.irfft(out_ft, 3, normalized=True, onesided=True, signal_sizes=(Nx, Ny, T))\n", 129 | " \n", 130 | " x = F.relu(x)\n", 131 | " x = self.Lo(x.view(B, c*ich, -1)).view(B, c, ich, Nx, Ny, T)\n", 132 | " return x\n", 133 | " \n", 134 | " \n", 135 | "class MWT_CZ(nn.Module):\n", 136 | " def __init__(self,\n", 137 | " k = 3, alpha = 5, \n", 138 | " L = 0, c = 1,\n", 139 | " base = 'legendre',\n", 140 | " initializer = None,\n", 141 | " **kwargs):\n", 142 | " super(MWT_CZ, self).__init__()\n", 143 | " \n", 144 | " self.k = k\n", 145 | " self.L = L\n", 146 | " H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)\n", 147 | " H0r = H0@PHI0\n", 148 | " G0r = G0@PHI0\n", 149 | " H1r = H1@PHI1\n", 150 | " G1r = G1@PHI1\n", 151 | " \n", 152 | " H0r[np.abs(H0r)<1e-8]=0\n", 153 | " H1r[np.abs(H1r)<1e-8]=0\n", 154 | " G0r[np.abs(G0r)<1e-8]=0\n", 155 | " G1r[np.abs(G1r)<1e-8]=0\n", 156 | " \n", 157 | " self.A = sparseKernelFT(k, alpha, c)\n", 158 | " self.B = sparseKernelFT(k, alpha, c)\n", 159 | " self.C = sparseKernelFT(k, alpha, c)\n", 160 | " \n", 161 | " self.T0 = nn.Conv1d(c*k**2, c*k**2, 1)\n", 162 | "\n", 163 | " if initializer is not None:\n", 164 | " self.reset_parameters(initializer)\n", 165 | "\n", 166 | " self.register_buffer('ec_s', torch.Tensor(\n", 167 | " np.concatenate((np.kron(H0, H0).T, \n", 168 | " np.kron(H0, H1).T,\n", 169 | " np.kron(H1, H0).T,\n", 170 | " np.kron(H1, H1).T,\n", 171 | " ), axis=0)))\n", 172 | " self.register_buffer('ec_d', torch.Tensor(\n", 173 | " np.concatenate((np.kron(G0, G0).T,\n", 174 | " np.kron(G0, G1).T,\n", 175 | " np.kron(G1, G0).T,\n", 176 | " np.kron(G1, G1).T,\n", 177 | " ), axis=0)))\n", 178 | " \n", 179 | " self.register_buffer('rc_ee', torch.Tensor(\n", 180 | " np.concatenate((np.kron(H0r, H0r), \n", 181 | " np.kron(G0r, G0r),\n", 182 | " ), axis=0)))\n", 183 | " self.register_buffer('rc_eo', torch.Tensor(\n", 184 | " np.concatenate((np.kron(H0r, H1r), \n", 185 | " np.kron(G0r, G1r),\n", 186 | " ), axis=0)))\n", 187 | " self.register_buffer('rc_oe', torch.Tensor(\n", 188 | " np.concatenate((np.kron(H1r, H0r), \n", 189 | " np.kron(G1r, G0r),\n", 190 | " ), axis=0)))\n", 191 | " self.register_buffer('rc_oo', torch.Tensor(\n", 192 | " np.concatenate((np.kron(H1r, H1r), \n", 193 | " np.kron(G1r, G1r),\n", 194 | " ), axis=0)))\n", 195 | " \n", 196 | " \n", 197 | " def forward(self, x):\n", 198 | " \n", 199 | " B, c, ich, Nx, Ny, T = x.shape # (B, c, k^2, Nx, Ny, T)\n", 200 | " ns = math.floor(np.log2(Nx))\n", 201 | "\n", 202 | " Ud = torch.jit.annotate(List[Tensor], [])\n", 203 | " Us = torch.jit.annotate(List[Tensor], [])\n", 204 | "\n", 205 | "# decompose\n", 206 | " for i in range(ns-self.L):\n", 207 | " d, x = self.wavelet_transform(x)\n", 208 | " Ud += [self.A(d) + self.B(x)]\n", 209 | " Us += [self.C(d)]\n", 210 | " x = self.T0(x.reshape(B, c*ich, -1)).view(\n", 211 | " B, c, ich, 2**self.L, 2**self.L, T) # coarsest scale transform\n", 212 | "\n", 213 | "# reconstruct \n", 214 | " for i in range(ns-1-self.L,-1,-1):\n", 215 | " x = x + Us[i]\n", 216 | " x = torch.cat((x, Ud[i]), 2)\n", 217 | " x = self.evenOdd(x)\n", 218 | "\n", 219 | " return x\n", 220 | "\n", 221 | " \n", 222 | " def wavelet_transform(self, x):\n", 223 | " xa = torch.cat([x[:, :, :, ::2 , ::2 , :], \n", 224 | " x[:, :, :, ::2 , 1::2, :], \n", 225 | " x[:, :, :, 1::2, ::2 , :], \n", 226 | " x[:, :, :, 1::2, 1::2, :]\n", 227 | " ], 2)\n", 228 | " waveFil = partial(torch.einsum, 'bcixyt,io->bcoxyt') \n", 229 | " d = waveFil(xa, self.ec_d)\n", 230 | " s = waveFil(xa, self.ec_s)\n", 231 | " return d, s\n", 232 | " \n", 233 | " \n", 234 | " def evenOdd(self, x):\n", 235 | " \n", 236 | " B, c, ich, Nx, Ny, T = x.shape # (B, c, 2*k^2, Nx, Ny)\n", 237 | " assert ich == 2*self.k**2\n", 238 | " evOd = partial(torch.einsum, 'bcixyt,io->bcoxyt')\n", 239 | " x_ee = evOd(x, self.rc_ee)\n", 240 | " x_eo = evOd(x, self.rc_eo)\n", 241 | " x_oe = evOd(x, self.rc_oe)\n", 242 | " x_oo = evOd(x, self.rc_oo)\n", 243 | " \n", 244 | " x = torch.zeros(B, c, self.k**2, Nx*2, Ny*2, T,\n", 245 | " device = x.device)\n", 246 | " x[:, :, :, ::2 , ::2 , :] = x_ee\n", 247 | " x[:, :, :, ::2 , 1::2, :] = x_eo\n", 248 | " x[:, :, :, 1::2, ::2 , :] = x_oe\n", 249 | " x[:, :, :, 1::2, 1::2, :] = x_oo\n", 250 | " return x\n", 251 | " \n", 252 | " def reset_parameters(self, initializer):\n", 253 | " initializer(self.T0.weight)\n", 254 | " \n", 255 | " \n", 256 | "class MWT(nn.Module):\n", 257 | " def __init__(self,\n", 258 | " ich = 1, k = 3, alpha = 2, c = 1,\n", 259 | " nCZ = 3,\n", 260 | " L = 0,\n", 261 | " base = 'legendre',\n", 262 | " initializer = None,\n", 263 | " **kwargs):\n", 264 | " super(MWT,self).__init__()\n", 265 | " \n", 266 | " self.k = k\n", 267 | " self.c = c\n", 268 | " self.L = L\n", 269 | " self.nCZ = nCZ\n", 270 | " self.Lk = nn.Linear(ich, c*k**2)\n", 271 | " \n", 272 | " self.MWT_CZ = nn.ModuleList(\n", 273 | " [MWT_CZ(k, alpha, L, c, base, \n", 274 | " initializer) for _ in range(nCZ)]\n", 275 | " )\n", 276 | " self.BN = nn.ModuleList(\n", 277 | " [nn.BatchNorm3d(c*k**2) for _ in range(nCZ)]\n", 278 | " )\n", 279 | " self.Lc0 = nn.Linear(c*k**2, 128)\n", 280 | " self.Lc1 = nn.Linear(128, 1)\n", 281 | " \n", 282 | " if initializer is not None:\n", 283 | " self.reset_parameters(initializer)\n", 284 | " \n", 285 | " def forward(self, x):\n", 286 | " \n", 287 | " B, Nx, Ny, T, ich = x.shape # (B, Nx, Ny, T, d)\n", 288 | " ns = math.floor(np.log2(Nx))\n", 289 | " x = model.Lk(x)\n", 290 | " x = x.view(B, Nx, Ny, T, self.c, self.k**2)\n", 291 | " x = x.permute(0, 4, 5, 1, 2, 3)\n", 292 | " \n", 293 | " for i in range(self.nCZ):\n", 294 | " x = self.MWT_CZ[i](x)\n", 295 | " x = self.BN[i](x.view(B, -1, Nx, Ny, T)).view(\n", 296 | " B, self.c, self.k**2, Nx, Ny, T)\n", 297 | " if i < self.nCZ-1:\n", 298 | " x = F.relu(x)\n", 299 | "\n", 300 | " x = x.view(B, -1, Nx, Ny, T) # collapse c and k**2\n", 301 | " x = x.permute(0, 2, 3, 4, 1)\n", 302 | " x = self.Lc0(x)\n", 303 | " x = F.relu(x)\n", 304 | " x = self.Lc1(x)\n", 305 | " return x.squeeze()\n", 306 | " \n", 307 | " def reset_parameters(self, initializer):\n", 308 | " initializer(self.Lc0.weight)\n", 309 | " initializer(self.Lc1.weight)" 310 | ], 311 | "outputs": [], 312 | "metadata": {} 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 4, 317 | "source": [ 318 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 319 | ], 320 | "outputs": [], 321 | "metadata": {} 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 5, 326 | "source": [ 327 | "data_path = '../../data/ns_V1e-3_N5000_T50.mat'\n", 328 | "\n", 329 | "ntrain = 1000\n", 330 | "ntest = 200\n", 331 | "\n", 332 | "batch_size = 20" 333 | ], 334 | "outputs": [], 335 | "metadata": {} 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 6, 340 | "source": [ 341 | "sub = 1\n", 342 | "S = 64 // sub\n", 343 | "T_in = 10\n", 344 | "T = 40\n", 345 | "\n", 346 | "dataloader = h5py.File(data_path)\n", 347 | "u_data = dataloader['u']\n", 348 | "t_data = dataloader['u']\n", 349 | "\n", 350 | "train_a = torch.from_numpy(u_data[:T_in, ::sub,::sub,:ntrain]\n", 351 | " ).permute(3, 1, 2, 0)\n", 352 | "train_u = torch.from_numpy(u_data[T_in:T_in+T, ::sub,::sub,:ntrain]\n", 353 | " ).permute(3, 1, 2, 0)\n", 354 | "\n", 355 | "test_a = torch.from_numpy(u_data[:T_in, ::sub,::sub,-ntest:]\n", 356 | " ).permute(3, 1, 2, 0)\n", 357 | "test_u = torch.from_numpy(u_data[T_in:T_in+T, ::sub,::sub,-ntest:]\n", 358 | " ).permute(3, 1, 2, 0)\n", 359 | "\n", 360 | "print(train_u.shape)\n", 361 | "print(test_u.shape)\n", 362 | "assert (S == train_u.shape[-2])\n", 363 | "assert (T == train_u.shape[-1])" 364 | ], 365 | "outputs": [ 366 | { 367 | "output_type": "stream", 368 | "name": "stdout", 369 | "text": [ 370 | "torch.Size([1000, 64, 64, 40])\n", 371 | "torch.Size([200, 64, 64, 40])\n" 372 | ] 373 | } 374 | ], 375 | "metadata": {} 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 7, 380 | "source": [ 381 | "a_normalizer = UnitGaussianNormalizer(train_a)\n", 382 | "x_train = a_normalizer.encode(train_a)\n", 383 | "x_test = a_normalizer.encode(test_a)\n", 384 | "\n", 385 | "y_normalizer = UnitGaussianNormalizer(train_u)\n", 386 | "y_train = y_normalizer.encode(train_u)\n", 387 | "\n", 388 | "x_train = x_train.reshape(ntrain,S,S,1,T_in).repeat([1,1,1,T,1])\n", 389 | "x_test = x_test.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1])" 390 | ], 391 | "outputs": [], 392 | "metadata": {} 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 8, 397 | "source": [ 398 | "# pad locations (x,y,t)\n", 399 | "gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)\n", 400 | "gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1])\n", 401 | "gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)\n", 402 | "gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1])\n", 403 | "gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float)\n", 404 | "gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1])\n", 405 | "\n", 406 | "x_train = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]),\n", 407 | " gridt.repeat([ntrain,1,1,1,1]), x_train), dim=-1)\n", 408 | "x_test = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]),\n", 409 | " gridt.repeat([ntest,1,1,1,1]), x_test), dim=-1)" 410 | ], 411 | "outputs": [], 412 | "metadata": {} 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 9, 417 | "source": [ 418 | "batch_size = 10\n", 419 | "train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)\n", 420 | "test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, test_u), batch_size=batch_size, shuffle=False)" 421 | ], 422 | "outputs": [], 423 | "metadata": {} 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 10, 428 | "source": [ 429 | "ich = 13\n", 430 | "initializer = get_initializer('xavier_normal') # xavier_normal, kaiming_normal, kaiming_uniform\n", 431 | "\n", 432 | "torch.manual_seed(0)\n", 433 | "np.random.seed(0)\n", 434 | "\n", 435 | "alpha = 12\n", 436 | "c = 4\n", 437 | "k = 3\n", 438 | "nCZ = 4\n", 439 | "L = 0\n", 440 | "model = MWT(ich, \n", 441 | " alpha = alpha,\n", 442 | " c = c,\n", 443 | " k = k, \n", 444 | " base = 'legendre', # chebyshev\n", 445 | " nCZ = nCZ,\n", 446 | " L = L,\n", 447 | " initializer = initializer,\n", 448 | " ).to(device)\n", 449 | "learning_rate = 0.001\n", 450 | "\n", 451 | "epochs = 500\n", 452 | "step_size = 100\n", 453 | "gamma = 0.5" 454 | ], 455 | "outputs": [], 456 | "metadata": {} 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": null, 461 | "source": [ 462 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n", 463 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n", 464 | "\n", 465 | "myloss = LpLoss(size_average=False)\n", 466 | "y_normalizer.cuda()\n", 467 | "\n", 468 | "for epoch in range(1, epochs+1):\n", 469 | " train_l2 = train(model, train_loader, optimizer, epoch, device,\n", 470 | " lossFn = myloss, lr_schedule = scheduler,\n", 471 | " post_proc = y_normalizer.decode)\n", 472 | " \n", 473 | " test_l2 = test(model, test_loader, device, lossFn=myloss, post_proc=y_normalizer.decode)\n", 474 | " print(f'epoch: {epoch}, train l2 = {train_l2}, test l2 = {test_l2}')" 475 | ], 476 | "outputs": [], 477 | "metadata": {} 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "source": [], 483 | "outputs": [], 484 | "metadata": {} 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "source": [], 490 | "outputs": [], 491 | "metadata": {} 492 | } 493 | ], 494 | "metadata": { 495 | "kernelspec": { 496 | "display_name": "torch_RL", 497 | "language": "python", 498 | "name": "torch_rl" 499 | }, 500 | "language_info": { 501 | "codemirror_mode": { 502 | "name": "ipython", 503 | "version": 3 504 | }, 505 | "file_extension": ".py", 506 | "mimetype": "text/x-python", 507 | "name": "python", 508 | "nbconvert_exporter": "python", 509 | "pygments_lexer": "ipython3", 510 | "version": "3.6.10" 511 | } 512 | }, 513 | "nbformat": 4, 514 | "nbformat_minor": 5 515 | } -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from typing import List, Tuple 8 | import math 9 | 10 | from utils import get_filter 11 | 12 | 13 | class sparseKernel1d(nn.Module): 14 | def __init__(self, 15 | k, alpha, c=1, 16 | nl = 1, 17 | initializer = None, 18 | **kwargs): 19 | super(sparseKernel1d,self).__init__() 20 | 21 | self.k = k 22 | self.Li = nn.Linear(c*k, 128) 23 | self.conv = self.convBlock(c*k, 128) 24 | self.Lo = nn.Linear(128, c*k) 25 | 26 | def forward(self, x): 27 | B, N, c, ich = x.shape # (B, N, c, k) 28 | x = x.view(B, N, -1) 29 | x = x.permute(0, 2, 1) 30 | x = self.conv(x) 31 | x = x.permute(0, 2, 1) 32 | x = self.Lo(x) 33 | x = x.view(B, N, c, ich) 34 | return x 35 | 36 | 37 | def convBlock(self, ich, och): 38 | net = nn.Sequential( 39 | nn.Conv1d(ich, och, 3, 1, 1), 40 | nn.ReLU(inplace=True), 41 | ) 42 | return net 43 | 44 | def compl_mul1d(x, weights): 45 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 46 | return torch.einsum("bix,iox->box", x, weights) 47 | 48 | class sparseKernelFT1d(nn.Module): 49 | def __init__(self, 50 | k, alpha, c=1, 51 | nl = 1, 52 | initializer = None, 53 | **kwargs): 54 | super(sparseKernelFT1d, self).__init__() 55 | 56 | self.modes1 = alpha 57 | self.scale = (1 / (c*k*c*k)) 58 | self.weights1 = nn.Parameter(self.scale * torch.rand(c*k, c*k, self.modes1, dtype=torch.cfloat)) 59 | self.weights1.requires_grad = True 60 | self.k = k 61 | 62 | def forward(self, x): 63 | B, N, c, k = x.shape # (B, N, c, k) 64 | 65 | x = x.view(B, N, -1) 66 | x = x.permute(0, 2, 1) 67 | x_fft = torch.fft.rfft(x) 68 | # Multiply relevant Fourier modes 69 | l = min(self.modes1, N//2+1) 70 | out_ft = torch.zeros(B, c*k, N//2 + 1, device=x.device, dtype=torch.cfloat) 71 | out_ft[:, :, :l] = compl_mul1d(x_fft[:, :, :l], self.weights1[:, :, :l]) 72 | 73 | #Return to physical space 74 | x = torch.fft.irfft(out_ft, n=N) 75 | x = x.permute(0, 2, 1).view(B, N, c, k) 76 | return x 77 | 78 | 79 | class MWT_CZ1d(nn.Module): 80 | def __init__(self, 81 | k = 3, alpha = 5, 82 | L = 0, c = 1, 83 | base = 'legendre', 84 | initializer = None, 85 | **kwargs): 86 | super(MWT_CZ1d, self).__init__() 87 | 88 | self.k = k 89 | self.L = L 90 | H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) 91 | H0r = H0@PHI0 92 | G0r = G0@PHI0 93 | H1r = H1@PHI1 94 | G1r = G1@PHI1 95 | 96 | H0r[np.abs(H0r)<1e-8]=0 97 | H1r[np.abs(H1r)<1e-8]=0 98 | G0r[np.abs(G0r)<1e-8]=0 99 | G1r[np.abs(G1r)<1e-8]=0 100 | 101 | self.A = sparseKernelFT1d(k, alpha, c) 102 | self.B = sparseKernelFT1d(k, alpha, c) 103 | self.C = sparseKernelFT1d(k, alpha, c) 104 | 105 | self.T0 = nn.Linear(k, k) 106 | 107 | self.register_buffer('ec_s', torch.Tensor( 108 | np.concatenate((H0.T, H1.T), axis=0))) 109 | self.register_buffer('ec_d', torch.Tensor( 110 | np.concatenate((G0.T, G1.T), axis=0))) 111 | 112 | self.register_buffer('rc_e', torch.Tensor( 113 | np.concatenate((H0r, G0r), axis=0))) 114 | self.register_buffer('rc_o', torch.Tensor( 115 | np.concatenate((H1r, G1r), axis=0))) 116 | 117 | 118 | def forward(self, x): 119 | 120 | B, N, c, ich = x.shape # (B, N, k) 121 | ns = math.floor(np.log2(N)) 122 | 123 | Ud = torch.jit.annotate(List[Tensor], []) 124 | Us = torch.jit.annotate(List[Tensor], []) 125 | # decompose 126 | for i in range(ns-self.L): 127 | d, x = self.wavelet_transform(x) 128 | Ud += [self.A(d) + self.B(x)] 129 | Us += [self.C(d)] 130 | x = self.T0(x) # coarsest scale transform 131 | 132 | # reconstruct 133 | for i in range(ns-1-self.L,-1,-1): 134 | x = x + Us[i] 135 | x = torch.cat((x, Ud[i]), -1) 136 | x = self.evenOdd(x) 137 | return x 138 | 139 | 140 | def wavelet_transform(self, x): 141 | xa = torch.cat([x[:, ::2, :, :], 142 | x[:, 1::2, :, :], 143 | ], -1) 144 | d = torch.matmul(xa, self.ec_d) 145 | s = torch.matmul(xa, self.ec_s) 146 | return d, s 147 | 148 | 149 | def evenOdd(self, x): 150 | 151 | B, N, c, ich = x.shape # (B, N, c, k) 152 | assert ich == 2*self.k 153 | x_e = torch.matmul(x, self.rc_e) 154 | x_o = torch.matmul(x, self.rc_o) 155 | 156 | x = torch.zeros(B, N*2, c, self.k, 157 | device = x.device) 158 | x[..., ::2, :, :] = x_e 159 | x[..., 1::2, :, :] = x_o 160 | return x 161 | 162 | 163 | class MWT1d(nn.Module): 164 | def __init__(self, 165 | ich = 1, k = 3, alpha = 2, c = 1, 166 | nCZ = 3, 167 | L = 0, 168 | base = 'legendre', 169 | initializer = None, 170 | **kwargs): 171 | super(MWT1d,self).__init__() 172 | 173 | self.k = k 174 | self.c = c 175 | self.L = L 176 | self.nCZ = nCZ 177 | self.Lk = nn.Linear(ich, c*k) 178 | 179 | self.MWT_CZ = nn.ModuleList( 180 | [MWT_CZ1d(k, alpha, L, c, base, 181 | initializer) for _ in range(nCZ)] 182 | ) 183 | self.Lc0 = nn.Linear(c*k, 128) 184 | self.Lc1 = nn.Linear(128, 1) 185 | 186 | if initializer is not None: 187 | self.reset_parameters(initializer) 188 | 189 | def forward(self, x): 190 | 191 | B, N, ich = x.shape # (B, N, d) 192 | ns = math.floor(np.log2(N)) 193 | x = self.Lk(x) 194 | x = x.view(B, N, self.c, self.k) 195 | 196 | for i in range(self.nCZ): 197 | x = self.MWT_CZ[i](x) 198 | if i < self.nCZ-1: 199 | x = F.relu(x) 200 | 201 | x = x.view(B, N, -1) # collapse c and k 202 | x = self.Lc0(x) 203 | x = F.relu(x) 204 | x = self.Lc1(x) 205 | return x.squeeze() 206 | 207 | def reset_parameters(self, initializer): 208 | initializer(self.Lc0.weight) 209 | initializer(self.Lc1.weight) 210 | 211 | 212 | 213 | class sparseKernel2d(nn.Module): 214 | def __init__(self, 215 | k, alpha, c=1, 216 | nl = 1, 217 | initializer = None, 218 | **kwargs): 219 | super(sparseKernel2d,self).__init__() 220 | 221 | self.k = k 222 | self.conv = self.convBlock(k, c*k**2, alpha) 223 | self.Lo = nn.Linear(alpha*k**2, c*k**2) 224 | 225 | def forward(self, x): 226 | B, Nx, Ny, c, ich = x.shape # (B, Nx, Ny, c, k**2) 227 | x = x.view(B, Nx, Ny, -1) 228 | x = x.permute(0, 3, 1, 2) 229 | x = self.conv(x) 230 | x = x.permute(0, 2, 3, 1) 231 | x = self.Lo(x) 232 | x = x.view(B, Nx, Ny, c, ich) 233 | 234 | return x 235 | 236 | 237 | def convBlock(self, k, W, alpha): 238 | och = alpha * k**2 239 | net = nn.Sequential( 240 | nn.Conv2d(W, och, 3, 1, 1), 241 | nn.ReLU(inplace=True), 242 | ) 243 | return net 244 | 245 | 246 | def compl_mul2d(x, weights): 247 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 248 | return torch.einsum("bixy,ioxy->boxy", x, weights) 249 | 250 | 251 | class sparseKernelFT2d(nn.Module): 252 | def __init__(self, 253 | k, alpha, c=1, 254 | nl = 1, 255 | initializer = None, 256 | **kwargs): 257 | super(sparseKernelFT2d, self).__init__() 258 | 259 | self.modes = alpha 260 | 261 | self.weights1 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, dtype=torch.cfloat)) 262 | self.weights2 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, dtype=torch.cfloat)) 263 | nn.init.xavier_normal_(self.weights1) 264 | nn.init.xavier_normal_(self.weights2) 265 | 266 | self.Lo = nn.Linear(c*k**2, c*k**2) 267 | self.k = k 268 | 269 | def forward(self, x): 270 | B, Nx, Ny, c, ich = x.shape # (B, N, N, c, k^2) 271 | 272 | x = x.view(B, Nx, Ny, -1) 273 | x = x.permute(0, 3, 1, 2) 274 | x_fft = torch.fft.rfft2(x) 275 | 276 | # Multiply relevant Fourier modes 277 | l1 = min(self.modes, Nx//2+1) 278 | l1l = min(self.modes, Nx//2-1) 279 | l2 = min(self.modes, Ny//2+1) 280 | out_ft = torch.zeros(B, c*ich, Nx, Ny//2 + 1, device=x.device, dtype=torch.cfloat) 281 | 282 | out_ft[:, :, :l1, :l2] = compl_mul2d( 283 | x_fft[:, :, :l1, :l2], self.weights1[:, :, :l1, :l2]) 284 | out_ft[:, :, -l1:, :l2] = compl_mul2d( 285 | x_fft[:, :, -l1:, :l2], self.weights2[:, :, :l1, :l2]) 286 | 287 | #Return to physical space 288 | x = torch.fft.irfft2(out_ft, s = (Nx, Ny)) 289 | 290 | x = x.permute(0, 2, 3, 1) 291 | x = F.relu(x) 292 | x = self.Lo(x) 293 | x = x.view(B, Nx, Ny, c, ich) 294 | return x 295 | 296 | 297 | class MWT_CZ2d(nn.Module): 298 | def __init__(self, 299 | k = 3, alpha = 5, 300 | L = 0, c = 1, 301 | base = 'legendre', 302 | initializer = None, 303 | **kwargs): 304 | super(MWT_CZ2d, self).__init__() 305 | 306 | self.k = k 307 | self.L = L 308 | H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) 309 | H0r = H0@PHI0 310 | G0r = G0@PHI0 311 | H1r = H1@PHI1 312 | G1r = G1@PHI1 313 | H0r[np.abs(H0r)<1e-8]=0 314 | H1r[np.abs(H1r)<1e-8]=0 315 | G0r[np.abs(G0r)<1e-8]=0 316 | G1r[np.abs(G1r)<1e-8]=0 317 | 318 | self.A = sparseKernelFT2d(k, alpha, c) 319 | self.B = sparseKernel2d(k, c, c) 320 | self.C = sparseKernel2d(k, c, c) 321 | 322 | self.T0 = nn.Linear(c*k**2, c*k**2) 323 | 324 | if initializer is not None: 325 | self.reset_parameters(initializer) 326 | 327 | self.register_buffer('ec_s', torch.Tensor( 328 | np.concatenate((np.kron(H0, H0).T, 329 | np.kron(H0, H1).T, 330 | np.kron(H1, H0).T, 331 | np.kron(H1, H1).T, 332 | ), axis=0))) 333 | self.register_buffer('ec_d', torch.Tensor( 334 | np.concatenate((np.kron(G0, G0).T, 335 | np.kron(G0, G1).T, 336 | np.kron(G1, G0).T, 337 | np.kron(G1, G1).T, 338 | ), axis=0))) 339 | 340 | self.register_buffer('rc_ee', torch.Tensor( 341 | np.concatenate((np.kron(H0r, H0r), 342 | np.kron(G0r, G0r), 343 | ), axis=0))) 344 | self.register_buffer('rc_eo', torch.Tensor( 345 | np.concatenate((np.kron(H0r, H1r), 346 | np.kron(G0r, G1r), 347 | ), axis=0))) 348 | self.register_buffer('rc_oe', torch.Tensor( 349 | np.concatenate((np.kron(H1r, H0r), 350 | np.kron(G1r, G0r), 351 | ), axis=0))) 352 | self.register_buffer('rc_oo', torch.Tensor( 353 | np.concatenate((np.kron(H1r, H1r), 354 | np.kron(G1r, G1r), 355 | ), axis=0))) 356 | 357 | 358 | def forward(self, x): 359 | 360 | B, Nx, Ny, c, ich = x.shape # (B, Nx, Ny, c, k**2) 361 | ns = math.floor(np.log2(Nx)) 362 | 363 | Ud = torch.jit.annotate(List[Tensor], []) 364 | Us = torch.jit.annotate(List[Tensor], []) 365 | 366 | # decompose 367 | for i in range(ns-self.L): 368 | d, x = self.wavelet_transform(x) 369 | Ud += [self.A(d) + self.B(x)] 370 | Us += [self.C(d)] 371 | x = self.T0(x.view(B, 2**self.L, 2**self.L, -1)).view( 372 | B, 2**self.L, 2**self.L, c, ich) # coarsest scale transform 373 | 374 | # reconstruct 375 | for i in range(ns-1-self.L,-1,-1): 376 | x = x + Us[i] 377 | x = torch.cat((x, Ud[i]), -1) 378 | x = self.evenOdd(x) 379 | 380 | return x 381 | 382 | 383 | def wavelet_transform(self, x): 384 | xa = torch.cat([x[:, ::2 , ::2 , :, :], 385 | x[:, ::2 , 1::2, :, :], 386 | x[:, 1::2, ::2 , :, :], 387 | x[:, 1::2, 1::2, :, :] 388 | ], -1) 389 | d = torch.matmul(xa, self.ec_d) 390 | s = torch.matmul(xa, self.ec_s) 391 | return d, s 392 | 393 | 394 | def evenOdd(self, x): 395 | 396 | B, Nx, Ny, c, ich = x.shape # (B, Nx, Ny, c, k**2) 397 | assert ich == 2*self.k**2 398 | x_ee = torch.matmul(x, self.rc_ee) 399 | x_eo = torch.matmul(x, self.rc_eo) 400 | x_oe = torch.matmul(x, self.rc_oe) 401 | x_oo = torch.matmul(x, self.rc_oo) 402 | 403 | x = torch.zeros(B, Nx*2, Ny*2, c, self.k**2, 404 | device = x.device) 405 | x[:, ::2 , ::2 , :, :] = x_ee 406 | x[:, ::2 , 1::2, :, :] = x_eo 407 | x[:, 1::2, ::2 , :, :] = x_oe 408 | x[:, 1::2, 1::2, :, :] = x_oo 409 | return x 410 | 411 | def reset_parameters(self, initializer): 412 | initializer(self.T0.weight) 413 | 414 | 415 | class MWT2d(nn.Module): 416 | def __init__(self, 417 | ich = 1, k = 3, alpha = 2, c = 1, 418 | nCZ = 3, 419 | L = 0, 420 | base = 'legendre', 421 | initializer = None, 422 | **kwargs): 423 | super(MWT2d,self).__init__() 424 | 425 | self.k = k 426 | self.c = c 427 | self.L = L 428 | self.nCZ = nCZ 429 | self.Lk = nn.Linear(ich, c*k**2) 430 | 431 | self.MWT_CZ = nn.ModuleList( 432 | [MWT_CZ2d(k, alpha, L, c, base, 433 | initializer) for _ in range(nCZ)] 434 | ) 435 | self.Lc0 = nn.Linear(c*k**2, 128) 436 | self.Lc1 = nn.Linear(128, 1) 437 | 438 | if initializer is not None: 439 | self.reset_parameters(initializer) 440 | 441 | def forward(self, x): 442 | 443 | B, Nx, Ny, ich = x.shape # (B, Nx, Ny, d) 444 | ns = math.floor(np.log2(Nx)) 445 | x = self.Lk(x) 446 | x = x.view(B, Nx, Ny, self.c, self.k**2) 447 | 448 | for i in range(self.nCZ): 449 | x = self.MWT_CZ[i](x) 450 | if i < self.nCZ-1: 451 | x = F.relu(x) 452 | 453 | x = x.view(B, Nx, Ny, -1) # collapse c and k**2 454 | x = self.Lc0(x) 455 | x = F.relu(x) 456 | x = self.Lc1(x) 457 | return x.squeeze() 458 | 459 | def reset_parameters(self, initializer): 460 | initializer(self.Lc0.weight) 461 | initializer(self.Lc1.weight) 462 | 463 | 464 | class sparseKernel(nn.Module): 465 | def __init__(self, 466 | k, alpha, c=1, 467 | nl = 1, 468 | initializer = None, 469 | **kwargs): 470 | super(sparseKernel,self).__init__() 471 | 472 | self.k = k 473 | self.conv = self.convBlock(k, c*k**2, alpha) 474 | self.Lo = nn.Linear(alpha*k**2, c*k**2) 475 | 476 | def forward(self, x): 477 | B, Nx, Ny, c, ich = x.shape # (B, Nx, Ny, c, k**2) 478 | x = x.view(B, Nx, Ny, -1) 479 | x = x.permute(0, 3, 1, 2) 480 | x = self.conv(x) 481 | x = x.permute(0, 2, 3, 1) 482 | x = self.Lo(x) 483 | x = x.view(B, Nx, Ny, c, ich) 484 | 485 | return x 486 | 487 | 488 | def convBlock(self, k, W, alpha): 489 | och = alpha * k**2 490 | net = nn.Sequential( 491 | nn.Conv2d(W, och, 3, 1, 1), 492 | nn.ReLU(inplace=True), 493 | ) 494 | return net 495 | 496 | 497 | class sparseKernel3d(nn.Module): 498 | def __init__(self, 499 | k, alpha, c=1, 500 | nl = 1, 501 | initializer = None, 502 | **kwargs): 503 | super(sparseKernel3d,self).__init__() 504 | 505 | self.k = k 506 | self.conv = self.convBlock(alpha*k**2, alpha*k**2) 507 | self.Lo = nn.Linear(alpha*k**2, c*k**2) 508 | 509 | def forward(self, x): 510 | B, Nx, Ny, T, c, ich = x.shape # (B, Nx, Ny, T, c, k**2) 511 | x = x.view(B, Nx, Ny, T, -1) 512 | x = x.permute(0, 4, 1, 2, 3) 513 | x = self.conv(x) 514 | x = x.permute(0, 2, 3, 4, 1) 515 | x = self.Lo(x) 516 | x = x.view(B, Nx, Ny, T, c, ich) 517 | 518 | return x 519 | 520 | 521 | def convBlock(self, ich, och): 522 | net = nn.Sequential( 523 | nn.Conv3d(och, och, 3, 1, 1), 524 | nn.ReLU(inplace=True), 525 | ) 526 | return net 527 | 528 | 529 | def compl_mul3d(input, weights): 530 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 531 | return torch.einsum("bixyz,ioxyz->boxyz", input, weights) 532 | 533 | 534 | class sparseKernelFT3d(nn.Module): 535 | def __init__(self, 536 | k, alpha, c=1, 537 | nl = 1, 538 | initializer = None, 539 | **kwargs): 540 | super(sparseKernelFT3d, self).__init__() 541 | 542 | self.modes = alpha 543 | 544 | self.weights1 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat)) 545 | self.weights2 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat)) 546 | self.weights3 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat)) 547 | self.weights4 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat)) 548 | nn.init.xavier_normal_(self.weights1) 549 | nn.init.xavier_normal_(self.weights2) 550 | nn.init.xavier_normal_(self.weights3) 551 | nn.init.xavier_normal_(self.weights4) 552 | 553 | self.Lo = nn.Linear(c*k**2, c*k**2) 554 | self.k = k 555 | 556 | def forward(self, x): 557 | B, Nx, Ny, T, c, ich = x.shape # (B, N, N, T, c, k^2) 558 | 559 | x = x.view(B, Nx, Ny, T, -1) 560 | x = x.permute(0, 4, 1, 2, 3) 561 | x_fft = torch.fft.rfftn(x, dim = [-3, -2, -1]) 562 | 563 | # Multiply relevant Fourier modes 564 | l1 = min(self.modes, Nx//2+1) 565 | l2 = min(self.modes, Ny//2+1) 566 | out_ft = torch.zeros(B, c*ich, Nx, Ny, T//2 +1, device=x.device, dtype=torch.cfloat) 567 | 568 | out_ft[:, :, :l1, :l2, :self.modes] = compl_mul3d( 569 | x_fft[:, :, :l1, :l2, :self.modes], self.weights1[:, :, :l1, :l2, :]) 570 | out_ft[:, :, -l1:, :l2, :self.modes] = compl_mul3d( 571 | x_fft[:, :, -l1:, :l2, :self.modes], self.weights2[:, :, :l1, :l2, :]) 572 | out_ft[:, :, :l1, -l2:, :self.modes] = compl_mul3d( 573 | x_fft[:, :, :l1, -l2:, :self.modes], self.weights3[:, :, :l1, :l2, :]) 574 | out_ft[:, :, -l1:, -l2:, :self.modes] = compl_mul3d( 575 | x_fft[:, :, -l1:, -l2:, :self.modes], self.weights4[:, :, :l1, :l2, :]) 576 | 577 | #Return to physical space 578 | x = torch.fft.irfftn(out_ft, s = (Nx, Ny, T)) 579 | 580 | x = x.permute(0, 2, 3, 4, 1) 581 | x = F.relu(x) 582 | x = self.Lo(x) 583 | x = x.view(B, Nx, Ny, T, c, ich) 584 | return x 585 | 586 | 587 | class MWT_CZ3d(nn.Module): 588 | def __init__(self, 589 | k = 3, alpha = 5, 590 | L = 0, c = 1, 591 | base = 'legendre', 592 | initializer = None, 593 | **kwargs): 594 | super(MWT_CZ3d, self).__init__() 595 | 596 | self.k = k 597 | self.L = L 598 | H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) 599 | H0r = H0@PHI0 600 | G0r = G0@PHI0 601 | H1r = H1@PHI1 602 | G1r = G1@PHI1 603 | 604 | H0r[np.abs(H0r)<1e-8]=0 605 | H1r[np.abs(H1r)<1e-8]=0 606 | G0r[np.abs(G0r)<1e-8]=0 607 | G1r[np.abs(G1r)<1e-8]=0 608 | 609 | self.A = sparseKernelFT3d(k, alpha, c) 610 | self.B = sparseKernel3d(k, c, c) 611 | self.C = sparseKernel3d(k, c, c) 612 | 613 | self.T0 = nn.Linear(c*k**2, c*k**2) 614 | 615 | if initializer is not None: 616 | self.reset_parameters(initializer) 617 | 618 | self.register_buffer('ec_s', torch.Tensor( 619 | np.concatenate((np.kron(H0, H0).T, 620 | np.kron(H0, H1).T, 621 | np.kron(H1, H0).T, 622 | np.kron(H1, H1).T, 623 | ), axis=0))) 624 | self.register_buffer('ec_d', torch.Tensor( 625 | np.concatenate((np.kron(G0, G0).T, 626 | np.kron(G0, G1).T, 627 | np.kron(G1, G0).T, 628 | np.kron(G1, G1).T, 629 | ), axis=0))) 630 | 631 | self.register_buffer('rc_ee', torch.Tensor( 632 | np.concatenate((np.kron(H0r, H0r), 633 | np.kron(G0r, G0r), 634 | ), axis=0))) 635 | self.register_buffer('rc_eo', torch.Tensor( 636 | np.concatenate((np.kron(H0r, H1r), 637 | np.kron(G0r, G1r), 638 | ), axis=0))) 639 | self.register_buffer('rc_oe', torch.Tensor( 640 | np.concatenate((np.kron(H1r, H0r), 641 | np.kron(G1r, G0r), 642 | ), axis=0))) 643 | self.register_buffer('rc_oo', torch.Tensor( 644 | np.concatenate((np.kron(H1r, H1r), 645 | np.kron(G1r, G1r), 646 | ), axis=0))) 647 | 648 | 649 | def forward(self, x): 650 | 651 | B, Nx, Ny, T, c, ich = x.shape # (B, Nx, Ny, T, c, k**2) 652 | ns = math.floor(np.log2(Nx)) 653 | 654 | Ud = torch.jit.annotate(List[Tensor], []) 655 | Us = torch.jit.annotate(List[Tensor], []) 656 | 657 | # decompose 658 | for i in range(ns-self.L): 659 | d, x = self.wavelet_transform(x) 660 | Ud += [self.A(d) + self.B(x)] 661 | Us += [self.C(d)] 662 | x = self.T0(x.view(B, 2**self.L, 2**self.L, T, -1)).view( 663 | B, 2**self.L, 2**self.L, T, c, ich) # coarsest scale transform 664 | 665 | # reconstruct 666 | for i in range(ns-1-self.L,-1,-1): 667 | x = x + Us[i] 668 | x = torch.cat((x, Ud[i]), -1) 669 | x = self.evenOdd(x) 670 | 671 | return x 672 | 673 | 674 | def wavelet_transform(self, x): 675 | xa = torch.cat([x[:, ::2 , ::2 , :, :, :], 676 | x[:, ::2 , 1::2, :, :, :], 677 | x[:, 1::2, ::2 , :, :, :], 678 | x[:, 1::2, 1::2, :, :, :] 679 | ], -1) 680 | d = torch.matmul(xa, self.ec_d) 681 | s = torch.matmul(xa, self.ec_s) 682 | return d, s 683 | 684 | 685 | def evenOdd(self, x): 686 | 687 | B, Nx, Ny, T, c, ich = x.shape # (B, Nx, Ny, c, k**2) 688 | assert ich == 2*self.k**2 689 | x_ee = torch.matmul(x, self.rc_ee) 690 | x_eo = torch.matmul(x, self.rc_eo) 691 | x_oe = torch.matmul(x, self.rc_oe) 692 | x_oo = torch.matmul(x, self.rc_oo) 693 | 694 | x = torch.zeros(B, Nx*2, Ny*2, T, c, self.k**2, 695 | device = x.device) 696 | x[:, ::2 , ::2 , :, :, :] = x_ee 697 | x[:, ::2 , 1::2, :, :, :] = x_eo 698 | x[:, 1::2, ::2 , :, :, :] = x_oe 699 | x[:, 1::2, 1::2, :, :, :] = x_oo 700 | return x 701 | 702 | def reset_parameters(self, initializer): 703 | initializer(self.T0.weight) 704 | 705 | 706 | class MWT3d(nn.Module): 707 | def __init__(self, 708 | ich = 1, k = 3, alpha = 2, c = 1, 709 | nCZ = 3, 710 | L = 0, 711 | base = 'legendre', 712 | initializer = None, 713 | **kwargs): 714 | super(MWT3d,self).__init__() 715 | 716 | self.k = k 717 | self.c = c 718 | self.L = L 719 | self.nCZ = nCZ 720 | self.Lk = nn.Linear(ich, c*k**2) 721 | 722 | self.MWT_CZ = nn.ModuleList( 723 | [MWT_CZ3d(k, alpha, L, c, base, 724 | initializer) for _ in range(nCZ)] 725 | ) 726 | self.Lc0 = nn.Linear(c*k**2, 128) 727 | self.Lc1 = nn.Linear(128, 1) 728 | 729 | if initializer is not None: 730 | self.reset_parameters(initializer) 731 | 732 | def forward(self, x): 733 | 734 | B, Nx, Ny, T, ich = x.shape # (B, Nx, Ny, T, d) 735 | ns = math.floor(np.log2(Nx)) 736 | x = self.Lk(x) 737 | x = x.view(B, Nx, Ny, T, self.c, self.k**2) 738 | 739 | for i in range(self.nCZ): 740 | x = self.MWT_CZ[i](x) 741 | if i < self.nCZ-1: 742 | x = F.relu(x) 743 | 744 | x = x.view(B, Nx, Ny, T, -1) # collapse c and k**2 745 | x = self.Lc0(x) 746 | x = F.relu(x) 747 | x = self.Lc1(x) 748 | return x.squeeze() 749 | 750 | def reset_parameters(self, initializer): 751 | initializer(self.Lc0.weight) 752 | initializer(self.Lc1.weight) --------------------------------------------------------------------------------