├── License.txt ├── README.md ├── SECURITY.md ├── exampleBurgers1d.m ├── fno └── spectralConvolution1dLayer.m └── images ├── figure_0.png └── figure_1.png /License.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022, The MathWorks, Inc. 2 | All rights reserved. 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 5 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 6 | 3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings. 7 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fourier Neural Operator 2 | 3 | [![View on File Exchange](https://www.mathworks.com/matlabcentral/images/matlab-file-exchange.svg)](https://www.mathworks.com/matlabcentral/fileexchange/####-file-exchange-title) 4 | 5 | 6 | The Fourier Neural Operator (FNO) [1] is a neural operator with an integral kernel parameterized in Fourier space. This makes the architecture both expressive and efficient. Applications of the FNO include weather forecasting and, more generally, finding efficient solutions to the Navier-Stokes equations, which govern fluid flow. 7 | 8 | ## Setup 9 | Add the `fno` directory to the MATLAB path. 10 | 11 | ```matlab 12 | addpath(genpath('fno')); 13 | ``` 14 | 15 | ## Requirements 16 | 17 | Requires: 18 | - [MATLAB](https://www.mathworks.com/products/matlab.html) (R2021b or newer) 19 | - [Deep Learning Toolbox™](https://www.mathworks.com/products/deep-learning.html) 20 | 21 | ## References 22 | [1] Zongyi Li, Nikola Kovachki, Kamyar Azizzadenesheli, Burigede Liu, Kaushik Bhattacharya, Andrew Stuart, and Anima 23 | Anandkumar. Fourier Neural Operator for Parametric Partial Differential Equations. In International Conference on 24 | Learning Representations (ICLR), 2021a. (https://openreview.net/forum?id=c8P9NQVtmnO) 25 | 26 | # Example: Fourier Neural Operator for 1d Burgers' Equation 27 | 28 | In this example we apply the Fourier Neural Operator to learn the one-dimensional Burgers' equation, defined as follows: 29 | 30 | > $\frac{\partial u}{\partial t} + u\frac{\partial u}{\partial x} = \frac{1}{Re}\frac{\partial^2 u}{\partial x^2}$, $x \in (0,1),\ t \in (0,1]$ 31 | 32 | > $u(x,0) = u_0(x)$, $x \in (0,1)$ 33 | 34 | where $u = u(x,t)$ and $Re$ is the Reynolds number. Periodic boundary conditions are imposed across the spatial domain. We learn the operator mapping the initial condition $u_0$ to the solution at time $t = 1$: $u_0 \mapsto u(x,1)$. 
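
For reference, the network defined later in this example follows the Fourier layer of [1]. In our reading of the layer graph (a summary, not the authors' own notation), a pointwise convolution (`fc0`) first lifts the two input channels to `width` channels, and each of the four blocks then updates the hidden state $v_\ell$ as

$$
v_{\ell+1}(x) = \sigma\Big( W_\ell\, v_\ell(x) + \mathcal{F}^{-1}\big( R_\ell \cdot \mathcal{F}(v_\ell) \big)(x) \Big), \qquad \ell = 0, 1, 2, 3
$$

Here $W_\ell$ is a 1-by-1 convolution with bias (`fc1` to `fc4`), $\mathcal{F}$ is the discrete Fourier transform along $x$, $R_\ell$ multiplies each of the lowest `numModes` Fourier modes by a learnable complex `width`-by-`width` matrix and sets the remaining modes to zero (`spectralConvolution1dLayer`), and $\sigma$ is the ReLU, which is omitted after the fourth block. Two further pointwise convolutions (`fc5` and `fc6`, with a ReLU between them) project the result to a single output channel.
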
35 | 36 | ## Data preparation 37 | 38 | We use the file `burgers_data_R10.mat`, which contains initial velocities $u_0$ and solutions $u(x,1)$ of Burgers' equation; we use these as training inputs and targets, respectively. The network inputs also include the spatial grid over the domain $x \in (0,1)$ at the desired discretization. In this example we choose a grid size of $h = 2^{10}$. 39 | 40 | ```matlab:Code 41 | % Setup. 42 | addpath(genpath('fno')); 43 | 44 | % Download training data. 45 | dataDir = fullfile('data'); 46 | if ~isfolder(dataDir) 47 | mkdir(dataDir); 48 | end 49 | dataFile = fullfile(dataDir,'burgers_data_R10.mat'); 50 | if ~exist(dataFile, 'file') 51 | location = 'https://ssd.mathworks.com/supportfiles/nnet/data/burgers1d/burgers_data_R10.mat'; 52 | websave(dataFile, location); 53 | end 54 | data = load(dataFile, 'a', 'u'); 55 | x = data.a; 56 | t = data.u; 57 | 58 | % Specify the number of observations in training and test data, respectively. 59 | numTrain = 1e3; 60 | numTest = 1e2; 61 | 62 | % Specify grid size and downsampling factor. 63 | h = 2^10; 64 | n = size(x,2); 65 | ns = floor(n./h); 66 | 67 | % Downsample the data for training. 68 | xTrain = x(1:numTrain, 1:ns:n); 69 | tTrain = t(1:numTrain, 1:ns:n); 70 | xTest = x(end-numTest+1:end, 1:ns:n); 71 | tTest = t(end-numTest+1:end, 1:ns:n); 72 | 73 | % Define the grid over the spatial domain x. 74 | xmax = 1; 75 | xgrid = linspace(0, xmax, h); 76 | 77 | % Combine initial velocities and spatial grid to create network 78 | % predictors. 79 | xTrain = cat(3, xTrain, repmat(xgrid, [numTrain 1])); 80 | xTest = cat(3, xTest, repmat(xgrid, [numTest 1])); 81 | ``` 82 | 83 | ## Define network architecture 84 | 85 | Here we create a `dlnetwork` for the Burgers' equation problem. The network accepts inputs of dimension `[h 2 miniBatchSize]` and returns outputs of dimension `[h 1 miniBatchSize]`. The network consists of multiple blocks that combine spectral convolution with regular, linear convolution. 
The convolution in Fourier space filters out higher order oscillations in the solution, while the linear convolution learns local correlations. 86 | 87 | ```matlab:Code 88 | numModes = 16; 89 | width = 64; 90 | 91 | lg = layerGraph([ ... 92 | convolution1dLayer(1, width, Name='fc0') 93 | 94 | spectralConvolution1dLayer(width, numModes, Name='specConv1') 95 | additionLayer(2, Name='add1') 96 | reluLayer(Name='relu1') 97 | 98 | spectralConvolution1dLayer(width, numModes, Name='specConv2') 99 | additionLayer(2, Name='add2') 100 | reluLayer(Name='relu2') 101 | 102 | spectralConvolution1dLayer(width, numModes, Name='specConv3') 103 | additionLayer(2, Name='add3') 104 | reluLayer(Name='relu3') 105 | 106 | spectralConvolution1dLayer(width, numModes, Name='specConv4') 107 | additionLayer(2, Name='add4') 108 | 109 | convolution1dLayer(1, 128, Name='fc5') 110 | reluLayer(Name='relu5') 111 | convolution1dLayer(1, 1, Name='fc6') 112 | ]); 113 | 114 | lg = addLayers(lg, convolution1dLayer(1, width, Name='fc1')); 115 | lg = connectLayers(lg, 'fc0', 'fc1'); 116 | lg = connectLayers(lg, 'fc1', 'add1/in2'); 117 | 118 | lg = addLayers(lg, convolution1dLayer(1, width, Name='fc2')); 119 | lg = connectLayers(lg, 'relu1', 'fc2'); 120 | lg = connectLayers(lg, 'fc2', 'add2/in2'); 121 | 122 | lg = addLayers(lg, convolution1dLayer(1, width, Name='fc3')); 123 | lg = connectLayers(lg, 'relu2', 'fc3'); 124 | lg = connectLayers(lg, 'fc3', 'add3/in2'); 125 | 126 | lg = addLayers(lg, convolution1dLayer(1, width, Name='fc4')); 127 | lg = connectLayers(lg, 'relu3', 'fc4'); 128 | lg = connectLayers(lg, 'fc4', 'add4/in2'); 129 | 130 | numInputChannels = 2; 131 | XInit = dlarray(ones([h numInputChannels 1]), 'SCB'); 132 | net = dlnetwork(lg, XInit); 133 | 134 | analyzeNetwork(net) 135 | ``` 136 | 137 | ## Training options 138 | 139 | The network is trained using the standard SGDM algorithm, where the learn rate is decreased every `stepSize` iterations. 
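
With the values set below (`learnRate = 1e-3`, `gamma = 0.5`, `stepSize = 100`), the in-loop rule is equivalent to the closed form $\eta_k = \eta_0\,\gamma^{\lfloor (k-1)/s \rfloor}$, where $\eta_k$ is the rate used at iteration $k$ and $s$ is `stepSize`. The plain-MATLAB sketch below (no toolbox required) replays the loop's update rule and checks it against that formula; `lrAt` and the other variable names are hypothetical, chosen so as not to overwrite the example's workspace. With 1000 training observations and `batchSize = 20` there are 50 iterations per epoch, so 20 epochs give 1000 iterations and the last 100 updates use $10^{-3}/2^{9} \approx 1.95 \times 10^{-6}$.

```matlab
% Sketch: closed-form view of the step-decay learn rate schedule.
% Assumption: values mirror the training options below; lrAt is a
% hypothetical helper, and names avoid overwriting the example's variables.
lr0 = 1e-3;          % initial learn rate (learnRate)
decayFactor = 0.5;   % gamma
decayStep = 100;     % stepSize
numIterDemo = 1000;  % 20 epochs x 50 iterations per epoch

% Rate used by the SGDM update at iteration k = 1, 2, ...
lrAt = @(k) lr0 .* decayFactor.^floor((k - 1)./decayStep);

% Replay the in-loop rule: update with the current rate, then decay
% whenever the iteration counter reaches a multiple of decayStep.
lrNow = lr0;
lrLoop = zeros(1, numIterDemo);
for k = 1:numIterDemo
    lrLoop(k) = lrNow;
    if mod(k, decayStep) == 0
        lrNow = decayFactor.*lrNow;
    end
end

lrClosed = lrAt(1:numIterDemo);
fprintf('Schedules match: %d, final rate: %g\n', ...
    isequal(lrLoop, lrClosed), lrClosed(end));
```

Because the decay factor is a power of two, both forms are computed exactly in floating point, so `isequal` is a fair comparison here; for other factors a tolerance-based check would be safer.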
140 | 141 | ```matlab:Code 142 | executionEnvironment = "gpu"; 143 | 144 | batchSize = 20; 145 | learnRate = 1e-3; 146 | momentum = 0.9; 147 | 148 | numEpochs = 20; 149 | stepSize = 100; 150 | gamma = 0.5; 151 | expNum = 1; 152 | checkpoint = false; 153 | expDir = sprintf( 'checkpoints/run%g', expNum ); 154 | if ~isfolder( expDir ) && checkpoint 155 | mkdir(expDir) 156 | end 157 | 158 | vel = []; 159 | totalIter = 0; 160 | 161 | numTrain = size(xTrain,1); 162 | numIterPerEpoch = floor(numTrain./batchSize); 163 | ``` 164 | 165 | ## Training loop 166 | 167 | Train the network using a custom training loop. The training loss is plotted after every iteration and the validation loss after every epoch. 168 | 169 | ```matlab:Code 170 | if executionEnvironment == "gpu" && canUseGPU 171 | xTrain = gpuArray(xTrain); 172 | xTest = gpuArray(xTest); 173 | end 174 | 175 | start = tic; 176 | figure; 177 | clf 178 | lineLossTrain = animatedline('Color', [0 0.4470 0.7410]); 179 | lineLossTest = animatedline('Color', 'k', 'LineStyle', '--'); 180 | ylim([0 inf]) 181 | xlabel("Iteration") 182 | ylabel("Loss") 183 | grid on 184 | 185 | % Compute initial validation loss. 186 | y = net.predict( dlarray(xTest, 'BSC') ); 187 | yTest = extractdata(permute(stripdims(y), [3 1 2])); 188 | relLossTest = relativeL2Loss(yTest , tTest); 189 | addpoints(lineLossTest, 0, double(relLossTest/size(xTest,1))) 190 | 191 | % Main loop. 192 | lossfun = dlaccelerate(@modelLoss); 193 | for epoch = 1:numEpochs 194 | % Shuffle the data. 195 | dataIdx = randperm(numTrain); 196 | 197 | for iter = 1:numIterPerEpoch 198 | % Get mini-batch data, indexing through the shuffled order. 199 | batchIdx = (1:batchSize) + (iter-1)*batchSize; 200 | idx = dataIdx(batchIdx); 201 | X = dlarray( xTrain(idx, :, :), 'BSC' ); 202 | T = tTrain(idx, :); 203 | 204 | % Compute loss and gradients. 205 | [loss, dnet] = dlfeval(lossfun, X, T, net); 206 | 207 | % Update model parameters using the SGDM update rule. 208 | [net, vel] = sgdmupdate(net, dnet, vel, learnRate, momentum); 209 | 210 | % Plot training progress. 
211 | totalIter = totalIter + 1; 212 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 213 | addpoints(lineLossTrain,totalIter,double(extractdata(loss/batchSize))) 214 | title("Epoch: " + epoch + ", Elapsed: " + string(D)) 215 | drawnow 216 | 217 | % Learn rate scheduling. 218 | if mod(totalIter, stepSize) == 0 219 | learnRate = gamma.*learnRate; 220 | end 221 | end 222 | % Compute validation loss and MSE. 223 | y = net.predict( dlarray(xTest, 'BSC') ); 224 | yTest = extractdata(permute(stripdims(y), [3 1 2])); 225 | relLossTest = relativeL2Loss( yTest , tTest ); 226 | mseTest = mean( (yTest(:) - tTest(:)).^2 ); 227 | 228 | % Display progress. 229 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 230 | numTest = size(xTest, 1); 231 | fprintf('Epoch = %g, train loss = %g, val loss = %g, val mse = %g, total time = %s. \n', ... 232 | epoch, extractdata(loss)/batchSize, relLossTest/numTest, mseTest/numTest, string(D)); 233 | addpoints(lineLossTest, totalIter, double(relLossTest/numTest)) 234 | 235 | % Checkpoints. 236 | if checkpoint 237 | filename = sprintf('checkpoints/run%g/epoch%g.mat', expNum, epoch); 238 | save(filename, 'net', 'epoch', 'vel', 'totalIter', 'relLossTest', 'mseTest', 'learnRate'); 239 | end 240 | end 241 | ``` 242 | 243 | ```text:Output 244 | Epoch = 1, train loss = 0.226405, val loss = 0.175389, val mse = 7.73286e-05, total time = 00:00:13. 245 | Epoch = 2, train loss = 0.153691, val loss = 0.145805, val mse = 5.99213e-05, total time = 00:00:22. 246 | Epoch = 3, train loss = 0.0923258, val loss = 0.0904608, val mse = 2.49174e-05, total time = 00:00:27. 247 | Epoch = 4, train loss = 0.102122, val loss = 0.0639219, val mse = 1.43723e-05, total time = 00:00:32. 248 | Epoch = 5, train loss = 0.0346076, val loss = 0.0393621, val mse = 9.33419e-06, total time = 00:00:36. 249 | Epoch = 6, train loss = 0.0361029, val loss = 0.032303, val mse = 7.10724e-06, total time = 00:00:45. 
250 | Epoch = 7, train loss = 0.0270364, val loss = 0.0296161, val mse = 6.43696e-06, total time = 00:00:50. 251 | Epoch = 8, train loss = 0.0263171, val loss = 0.0283881, val mse = 5.92292e-06, total time = 00:00:54. 252 | Epoch = 9, train loss = 0.0248211, val loss = 0.0261364, val mse = 5.54218e-06, total time = 00:00:58. 253 | Epoch = 10, train loss = 0.0243392, val loss = 0.0253596, val mse = 5.32946e-06, total time = 00:01:03. 254 | Epoch = 11, train loss = 0.0236119, val loss = 0.0250861, val mse = 5.22886e-06, total time = 00:01:07. 255 | Epoch = 12, train loss = 0.023318, val loss = 0.024752, val mse = 5.12552e-06, total time = 00:01:11. 256 | Epoch = 13, train loss = 0.0230901, val loss = 0.0243369, val mse = 5.04185e-06, total time = 00:01:16. 257 | Epoch = 14, train loss = 0.0229644, val loss = 0.0241713, val mse = 4.99882e-06, total time = 00:01:20. 258 | Epoch = 15, train loss = 0.0228391, val loss = 0.0240904, val mse = 4.99516e-06, total time = 00:01:25. 259 | Epoch = 16, train loss = 0.022768, val loss = 0.0240143, val mse = 4.97173e-06, total time = 00:01:29. 260 | Epoch = 17, train loss = 0.0228152, val loss = 0.023916, val mse = 4.95474e-06, total time = 00:01:33. 261 | Epoch = 18, train loss = 0.022787, val loss = 0.0238792, val mse = 4.94643e-06, total time = 00:01:38. 262 | Epoch = 19, train loss = 0.0227602, val loss = 0.023865, val mse = 4.93665e-06, total time = 00:01:42. 263 | Epoch = 20, train loss = 0.0227464, val loss = 0.0238451, val mse = 4.93358e-06, total time = 00:01:46. 264 | ``` 265 | 266 | ![figure_0.png](images/figure_0.png) 267 | 268 | ## Test on unseen, higher resolution data 269 | 270 | Here we take the trained network and test on unseen data with a higher spatial resolution than the training data. This is an example of zero-shot super-resolution. 
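
One way to see why the trained network can be evaluated on a finer grid: `spectralConvolution1dLayer` applies its learnable weights to the lowest `NumModes` coefficients of a normalized, one-sided FFT, and the 1-by-1 convolutions act at each grid point independently, so no parameter is tied to the number of grid points. The plain-MATLAB sketch below mirrors only the spectral step, for a single channel, using hypothetical random weights `Rdemo` and a hypothetical band-limited test signal `u0`; it also uses a periodic grid that excludes the endpoint, unlike the `linspace` grid in the example. It illustrates the idea and is not the layer itself. For an input whose frequencies all lie below the retained modes, the outputs on a $2^{10}$-point and a $2^{13}$-point grid agree at the shared points up to round-off.

```matlab
% Sketch: a truncated Fourier multiplier gives the same function on any grid.
% Assumptions: a single channel, hypothetical random complex weights Rdemo,
% and a hypothetical band-limited signal u0 on the periodic domain [0,1).
% Names are chosen so as not to overwrite variables used in this example.
demoModes = 16;
rng(0);
Rdemo = randn(demoModes, 1) + 1i*randn(demoModes, 1);
u0 = @(s) sin(2*pi*s) + 0.5*cos(6*pi*s) + 0.2*sin(20*pi*s);

Ns = [2^10, 2^13];            % coarse and fine grid sizes
yDemo = cell(1, numel(Ns));
for j = 1:numel(Ns)
    N = Ns(j);
    xg = (0:N-1)'/N;          % periodic grid, endpoint excluded
    % Normalized, one-sided spectrum, as in iRFFT in the layer file.
    uft = fft(u0(xg))/N;
    uft = uft(1:floor(N/2)+1);
    % Weight the lowest demoModes modes and zero the rest.
    yft = zeros(size(uft));
    yft(1:demoModes) = Rdemo .* uft(1:demoModes);
    % Restore conjugate symmetry and return to physical space.
    yDemo{j} = ifft(N*[yft; conj(yft(ceil(N/2):-1:2))], 'symmetric');
end

% Every (Ns(2)/Ns(1))-th fine-grid point coincides with a coarse-grid point.
ratio = Ns(2)/Ns(1);
mismatch = max(abs(yDemo{1} - yDemo{2}(1:ratio:end)));
fprintf('Max mismatch between resolutions: %g\n', mismatch);
```

Real Burgers' solutions are not exactly band-limited, so in practice the coarse and fine predictions agree only approximately; the sketch isolates the part of the architecture that makes such evaluation possible at all.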
271 | 272 | ```matlab:Code 273 | gridHighRes = linspace(0, xmax, n); 274 | 275 | idxToPlot = numTrain+(1:4); 276 | figure; 277 | for p = 1:4 278 | xn = dlarray(cat(1, x(idxToPlot(p),:), gridHighRes),'CSB'); 279 | yn = predict(net, xn); 280 | 281 | subplot(2, 2, p) 282 | plot(gridHighRes, t(idxToPlot(p),:)), hold on, plot(gridHighRes, extractdata(yn)) 283 | axis tight 284 | xlabel('x') 285 | ylabel('U') 286 | end 287 | ``` 288 | 289 | ![figure_1.png](images/figure_1.png) 290 | 291 | ### Helper functions 292 | 293 | ```matlab:Code 294 | function [loss, grad] = modelLoss(x, t, net) 295 | y = net.forward(x); 296 | y = permute(stripdims(y), [3 1 2]); 297 | y = stripdims(y); 298 | 299 | loss = relativeL2Loss(y, t); 300 | 301 | grad = dlgradient(loss, net.Learnables); 302 | end 303 | 304 | function loss = relativeL2Loss(y, t) 305 | diffNorms = normFcn( (y - t) ); 306 | tNorms = normFcn( t ); 307 | 308 | loss = sum(diffNorms./tNorms, 1); 309 | end 310 | 311 | function n = normFcn(x) 312 | n = sqrt( sum(x.^2, 2) ); 313 | end 314 | ``` 315 | 316 | Copyright 2022-2023 The MathWorks, Inc. 317 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Reporting Security Vulnerabilities 2 | 3 | If you believe you have discovered a security vulnerability, please report it to 4 | [security@mathworks.com](mailto:security@mathworks.com). Please see 5 | [MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html) 6 | for additional information. 
7 | -------------------------------------------------------------------------------- /exampleBurgers1d.m: -------------------------------------------------------------------------------- 1 | %% Fourier Neural Operator for 1d Burgers' Equation 2 | % In this example we apply the Fourier Neural Operator to learn the one-dimensional Burgers' equation with the following 4 | % definition: 5 | % 6 | % $\frac{\partial u}{\partial t} + u\frac{\partial u}{\partial x} = \frac{1}{Re}\frac{\partial^2 7 | % u}{\partial x^2}$, $x \in (0,1),\space t \in (0,1]$ 8 | % 9 | % $u(x,0) = u_0(x) $, $x \in (0,1)$ 10 | % 11 | % where $u=u\left(x,t\right)$ and $Re$ is the Reynolds number. Periodic boundary 12 | % conditions are imposed across the spatial domain. We learn the operator mapping 13 | % the initial condition $u_0$ to the solution at time $t=1$: $u_0 \longmapsto 14 | % u\left(x,1\right)$. 15 | %% Data preparation 16 | % We use the file burgers_data_R10.mat, which contains initial velocities $u_0$ and 17 | % solutions $u\left(x,1\right)$ of Burgers' equation; we use these as training 18 | % inputs and targets, respectively. The network inputs also include the spatial 19 | % grid over the domain $x\in\left(0,1\right)$ at the desired discretization. In this example we 20 | % choose a grid size of $h=2^{10}$. 21 | 22 | % Download training data. 23 | dataDir = fullfile('data'); 24 | if ~isfolder(dataDir) 25 | mkdir(dataDir); 26 | end 27 | dataFile = fullfile(dataDir,'burgers_data_R10.mat'); 28 | if ~exist(dataFile, 'file') 29 | location = 'https://ssd.mathworks.com/supportfiles/nnet/data/burgers1d/burgers_data_R10.mat'; 30 | websave(dataFile, location); 31 | end 32 | data = load(dataFile, 'a', 'u'); 33 | x = data.a; 34 | t = data.u; 35 | 36 | % Setup. 37 | addpath(genpath('fno')); 38 | 39 | % Specify the number of observations in training and test data, respectively. 40 | numTrain = 1e3; 41 | numTest = 1e2; 42 | 43 | % Specify grid size and downsampling factor. 
44 | h = 2^10; 45 | n = size(x,2); 46 | ns = floor(n./h); 47 | 48 | % Downsample the data for training. 49 | xTrain = x(1:numTrain, 1:ns:n); 50 | tTrain = t(1:numTrain, 1:ns:n); 51 | xTest = x(end-numTest+1:end, 1:ns:n); 52 | tTest = t(end-numTest+1:end, 1:ns:n); 53 | 54 | % Define the grid over the spatial domain x. 55 | xmax = 1; 56 | xgrid = linspace(0, xmax, h); 57 | 58 | % Combine initial velocities and spatial grid to create network 59 | % predictors. 60 | xTrain = cat(3, xTrain, repmat(xgrid, [numTrain 1])); 61 | xTest = cat(3, xTest, repmat(xgrid, [numTest 1])); 62 | %% Define network architecture 63 | % Here we create a |dlnetwork| for the Burgers' equation problem. The network 64 | % accepts inputs of dimension |[h 2 miniBatchSize]| and returns outputs of dimension 65 | % |[h 1 miniBatchSize]|. The network consists of multiple blocks that combine 66 | % spectral convolution with regular, linear convolution. The convolution in Fourier 67 | % space filters out higher-order oscillations in the solution, while the linear 68 | % convolution learns local correlations. 69 | 70 | numModes = 16; 71 | width = 64; 72 | 73 | lg = layerGraph([ ... 
74 | convolution1dLayer(1, width, Name='fc0') 75 | 76 | spectralConvolution1dLayer(width, numModes, Name='specConv1') 77 | additionLayer(2, Name='add1') 78 | reluLayer(Name='relu1') 79 | 80 | spectralConvolution1dLayer(width, numModes, Name='specConv2') 81 | additionLayer(2, Name='add2') 82 | reluLayer(Name='relu2') 83 | 84 | spectralConvolution1dLayer(width, numModes, Name='specConv3') 85 | additionLayer(2, Name='add3') 86 | reluLayer(Name='relu3') 87 | 88 | spectralConvolution1dLayer(width, numModes, Name='specConv4') 89 | additionLayer(2, Name='add4') 90 | 91 | convolution1dLayer(1, 128, Name='fc5') 92 | reluLayer(Name='relu5') 93 | convolution1dLayer(1, 1, Name='fc6') 94 | ]); 95 | 96 | lg = addLayers(lg, convolution1dLayer(1, width, Name='fc1')); 97 | lg = connectLayers(lg, 'fc0', 'fc1'); 98 | lg = connectLayers(lg, 'fc1', 'add1/in2'); 99 | 100 | lg = addLayers(lg, convolution1dLayer(1, width, Name='fc2')); 101 | lg = connectLayers(lg, 'relu1', 'fc2'); 102 | lg = connectLayers(lg, 'fc2', 'add2/in2'); 103 | 104 | lg = addLayers(lg, convolution1dLayer(1, width, Name='fc3')); 105 | lg = connectLayers(lg, 'relu2', 'fc3'); 106 | lg = connectLayers(lg, 'fc3', 'add3/in2'); 107 | 108 | lg = addLayers(lg, convolution1dLayer(1, width, Name='fc4')); 109 | lg = connectLayers(lg, 'relu3', 'fc4'); 110 | lg = connectLayers(lg, 'fc4', 'add4/in2'); 111 | 112 | numInputChannels = 2; 113 | XInit = dlarray(ones([h numInputChannels 1]), 'SCB'); 114 | net = dlnetwork(lg, XInit); 115 | 116 | analyzeNetwork(net) 117 | %% Training options 118 | % The network is trained using the standard SGDM algorithm, where the learn rate is decreased every |stepSize| iterations. 119 | 120 | executionEnvironment = "gpu"; 121 | 122 | batchSize = 20; 123 | learnRate = 1e-3; 124 | momentum = 0.9; 125 | 126 | numEpochs = 20; 127 | stepSize = 100; 128 | gamma = 0.5; 129 | expNum = 1; 130 | checkpoint = false; 131 | expDir = sprintf( 'checkpoints/run%g', expNum ); 132 | if ~isfolder( expDir ) && checkpoint 133 | mkdir(expDir) 134 | end 135 | 136 | vel = []; 137 
| totalIter = 0; 138 | 139 | numTrain = size(xTrain,1); 140 | numIterPerEpoch = floor(numTrain./batchSize); 141 | %% Training loop 142 | % Train the network using a custom training loop. The training loss is plotted after every iteration and the validation loss after every epoch. 143 | 144 | if executionEnvironment == "gpu" && canUseGPU 145 | xTrain = gpuArray(xTrain); 146 | xTest = gpuArray(xTest); 147 | end 148 | 149 | start = tic; 150 | figure; 151 | clf 152 | lineLossTrain = animatedline('Color', [0 0.4470 0.7410]); 153 | lineLossTest = animatedline('Color', 'k', 'LineStyle', '--'); 154 | ylim([0 inf]) 155 | xlabel("Iteration") 156 | ylabel("Loss") 157 | grid on 158 | 159 | % Compute initial validation loss. 160 | y = net.predict( dlarray(xTest, 'BSC') ); 161 | yTest = extractdata(permute(stripdims(y), [3 1 2])); 162 | relLossTest = relativeL2Loss(yTest , tTest); 163 | addpoints(lineLossTest, 0, double(relLossTest/size(xTest,1))) 164 | 165 | % Main loop. 166 | lossfun = dlaccelerate(@modelLoss); 167 | for epoch = 1:numEpochs 168 | % Shuffle the data. 169 | dataIdx = randperm(numTrain); 170 | 171 | for iter = 1:numIterPerEpoch 172 | % Get mini-batch data, indexing through the shuffled order. 173 | batchIdx = (1:batchSize) + (iter-1)*batchSize; 174 | idx = dataIdx(batchIdx); 175 | X = dlarray( xTrain(idx, :, :), 'BSC' ); 176 | T = tTrain(idx, :); 177 | 178 | % Compute loss and gradients. 179 | [loss, dnet] = dlfeval(lossfun, X, T, net); 180 | 181 | % Update model parameters using the SGDM update rule. 182 | [net, vel] = sgdmupdate(net, dnet, vel, learnRate, momentum); 183 | 184 | % Plot training progress. 185 | totalIter = totalIter + 1; 186 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 187 | addpoints(lineLossTrain,totalIter,double(extractdata(loss/batchSize))) 188 | title("Epoch: " + epoch + ", Elapsed: " + string(D)) 189 | drawnow 190 | 191 | % Learn rate scheduling. 192 | if mod(totalIter, stepSize) == 0 193 | learnRate = gamma.*learnRate; 194 | end 195 | end 196 | % Compute validation loss and MSE. 
197 | y = net.predict( dlarray(xTest, 'BSC') ); 198 | yTest = extractdata(permute(stripdims(y), [3 1 2])); 199 | relLossTest = relativeL2Loss( yTest , tTest ); 200 | mseTest = mean( (yTest(:) - tTest(:)).^2 ); 201 | 202 | % Display progress. 203 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 204 | numTest = size(xTest, 1); 205 | fprintf('Epoch = %g, train loss = %g, val loss = %g, val mse = %g, total time = %s. \n', ... 206 | epoch, extractdata(loss)/batchSize, relLossTest/numTest, mseTest/numTest, string(D)); 207 | addpoints(lineLossTest, totalIter, double(relLossTest/numTest)) 208 | 209 | % Checkpoints. 210 | if checkpoint 211 | filename = sprintf('checkpoints/run%g/epoch%g.mat', expNum, epoch); 212 | save(filename, 'net', 'epoch', 'vel', 'totalIter', 'relLossTest', 'mseTest', 'learnRate'); 213 | end 214 | end 215 | %% Test on unseen, higher resolution data 216 | % Here we take the trained network and test on unseen data with a higher spatial 217 | % resolution than the training data. This is an example of zero-shot super-resolution. 
218 | 219 | gridHighRes = linspace(0, xmax, n); 220 | 221 | idxToPlot = numTrain+(1:4); 222 | figure; 223 | for p = 1:4 224 | xn = dlarray(cat(1, x(idxToPlot(p),:), gridHighRes),'CSB'); 225 | yn = predict(net, xn); 226 | 227 | subplot(2, 2, p) 228 | plot(gridHighRes, t(idxToPlot(p),:)), hold on, plot(gridHighRes, extractdata(yn)) 229 | axis tight 230 | xlabel('x') 231 | ylabel('U') 232 | end 233 | %% Helper functions 234 | 235 | function [loss, grad] = modelLoss(x, t, net) 236 | y = net.forward(x); 237 | y = permute(stripdims(y), [3 1 2]); 238 | y = stripdims(y); 239 | 240 | loss = relativeL2Loss(y, t); 241 | 242 | grad = dlgradient(loss, net.Learnables); 243 | end 244 | 245 | function loss = relativeL2Loss(y, t) 246 | diffNorms = normFcn( (y - t) ); 247 | tNorms = normFcn( t ); 248 | 249 | loss = sum(diffNorms./tNorms, 1); 250 | end 251 | 252 | function n = normFcn(x) 253 | n = sqrt( sum(x.^2, 2) ); 254 | end -------------------------------------------------------------------------------- /fno/spectralConvolution1dLayer.m: -------------------------------------------------------------------------------- 1 | classdef spectralConvolution1dLayer < nnet.layer.Layer ... 2 | & nnet.layer.Formattable ... 3 | & nnet.layer.Acceleratable 4 | % spectralConvolution1dLayer Spectral convolution 1d 5 | 6 | % Copyright 2022 The MathWorks, Inc. 7 | 8 | properties 9 | Cin 10 | Cout 11 | NumModes 12 | end 13 | 14 | properties (Learnable) 15 | Weights 16 | end 17 | 18 | methods 19 | function this = spectralConvolution1dLayer(outChannels, numModes, nvargs) 20 | % spectralConvolution1dLayer Spectral convolution 1d 21 | % 22 | % layer = spectralConvolution1dLayer(outChannels, numModes) 23 | % creates a spectral convolution 1d layer. outChannels 24 | % specifies the number of channels in the layer output. 25 | % numModes specifies the number of modes which are combined 26 | % in Fourier space. 
27 | % 28 | % layer = spectralConvolution1dLayer(outChannels, numModes, 29 | % Name=Value) specifies additional options using one or more 30 | % name-value arguments: 31 | % 32 | % Name - Name for the layer. The default value is "spectralConv1d". 33 | % 34 | % Weights - Complex learnable array of size 35 | % (inChannels)x(outChannels)x(numModes). The 36 | % default value is []. 37 | arguments 38 | outChannels (1,1) double 39 | numModes (1,1) double 40 | nvargs.Name {mustBeTextScalar} = "spectralConv1d" 41 | nvargs.Weights = [] 42 | end 43 | 44 | this.Cout = outChannels; 45 | this.NumModes = numModes; 46 | this.Name = nvargs.Name; 47 | this.Weights = nvargs.Weights; 48 | end 49 | 50 | function this = initialize(this, ndl) 51 | inChannels = ndl.Size( finddim(ndl,'C') ); 52 | outChannels = this.Cout; 53 | numModes = this.NumModes; 54 | 55 | if isempty(this.Weights) 56 | this.Cin = inChannels; 57 | this.Weights = 1./(inChannels*outChannels).*( ... 58 | rand([inChannels outChannels numModes]) + ... 59 | 1i.*rand([inChannels outChannels numModes]) ); 60 | else 61 | assert( inChannels == size(this.Weights, 1), 'The input channel size must match the first dimension of Weights.' ); 62 | end 63 | end 64 | 65 | function y = predict(this, x) 66 | % First compute the rfft, normalized and one-sided 67 | x = real(x); 68 | x = stripdims(x); 69 | N = size(x, 1); 70 | xft = iRFFT(x, 1, N); 71 | 72 | % Multiply selected Fourier modes 73 | xft = permute(xft(1:this.NumModes, :, :), [3 2 1]); 74 | yft = pagemtimes( xft, this.Weights ); 75 | yft = permute(yft, [3 2 1]); 76 | 77 | S = floor(N/2)+1 - this.NumModes; 78 | yft = cat(1, yft, zeros([S size(yft, 2:3)], 'like', yft)); 79 | 80 | % Return to physical space via irfft, normalized and one-sided 81 | y = iIRFFT(yft, 1, N); 82 | 83 | % Re-apply labels 84 | y = dlarray(y, 'SCB'); 85 | y = real(y); 86 | end 87 | end 88 | end 89 | 90 | function y = iRFFT(x, dim, N) 91 | y = fft(x, [], dim); 92 | y = y(1:floor(N/2)+1, :, :)./N; 93 | end 94 | 95 | function y = iIRFFT(x, dim, N) 96 | 
x(end+1:N, :, :, :) = conj( x(ceil(N/2):-1:2, :, :, :) ); 97 | y = ifft(N.*x, [], dim, 'symmetric'); 98 | end 99 | -------------------------------------------------------------------------------- /images/figure_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/fourier-neural-operator/8ac0e99cce24ca4d517da50f958d0b29d904be04/images/figure_0.png -------------------------------------------------------------------------------- /images/figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/fourier-neural-operator/8ac0e99cce24ca4d517da50f958d0b29d904be04/images/figure_1.png --------------------------------------------------------------------------------