├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── notebooks ├── examples_1d_multi_step.ipynb ├── examples_1d_single_step.ipynb └── examples_3d_multi_step.ipynb ├── resources ├── heat_neumann.png ├── ns_initial_condition.png ├── ns_lid_cavity.mp4 ├── ns_lid_cavity_rel_err.mp4 ├── operator_bdy.png └── stokes.png └── src ├── models ├── base │ ├── FNO1d.py │ ├── FNO2d.py │ ├── FNO3d.py │ └── __init__.py ├── corrections.py ├── multi_step │ ├── BOON_2d.py │ ├── BOON_3d.py │ └── __init__.py ├── operator.py └── single_step │ ├── BOON_1d.py │ └── __init__.py └── utils └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 
41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to work on. Since our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BOON: Boundary correction for neural operators 2 | 3 | ![Image](resources/operator_bdy.png) 4 | 5 | [Nadim Saad*](https://profiles.stanford.edu/nadim-saad), [Gaurav Gupta*](http://guptagaurav.me/index.html), [Shima Alizadeh](https://scholar.google.com/citations?user=r3qS03kAAAAJ&hl=en), [Danielle C. Maddix](https://dcmaddix.github.io/)\ 6 | **Guiding continuous operator learning through Physics-based boundary constraints,**\ 7 | [International Conference on Learning Representations](https://openreview.net/forum?id=gfWNItGOES6), 2023\ 8 | (*equal contribution authors) 9 | 10 | 11 | ## Setup 12 | 13 | ### Requirements 14 | The code package is developed using Python 3.8 and PyTorch 1.11 with CUDA 11.6. The code can be executed on either CPU or GPU, but GPU is preferred. All experiments were conducted on a Tesla V100 (16GB) GPU. 15 | 16 | ## Experiments 17 | ### Data 18 | Generate the data using the scripts provided in the 'Data' directory. The scripts require MATLAB 2018 or later. A sample generated dataset for all the experiments is available below. 19 | 20 | [BOON PDE datasets](https://drive.google.com/drive/folders/1tj3dBlM6NQk6qo9cwyLaJmvLnXTho0yD?usp=sharing) 21 | 22 | ### Scripts 23 | Detailed notebooks for reproducing all the experiments in the paper are provided. The 1D, 1D time-varying, and 2D time-varying cases are shown in the respective notebooks for all three boundary conditions: Dirichlet, Neumann, and periodic. A minimal data-loading sketch is shown below. 
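All notebooks share the same data-loading pattern through the `DataDownloader` helper in `src/utils/utils.py`. Below is a minimal sketch of that pattern, using the Burgers/Dirichlet `id`/`tag` values from `examples_1d_single_step.ipynb`; each downloaded file stores input/output pairs under the keys `'a'` and `'u'`:

```python
import torch
from src.utils.utils import DataDownloader

downloader = DataDownloader()
downloader.download(id="Burgers_Dir_1D")  # fetch the files for this experiment
rw = downloader.locate(id="Burgers_Dir_1D", tag="nu_0_point_1")

# (a, u) pairs: randomly generated initial conditions / PDE parameters and solutions
a = torch.from_numpy(rw['a'].astype('float32'))  # shape (num_random_simulations, num_grid_points)
u = torch.from_numpy(rw['u'].astype('float32'))  # shape (num_random_simulations, num_grid_points)
```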
24 | 25 | ### 1D Heat equation motivating example 26 | As an example, a complete pipeline is shown for the 1D single-step PDE with Neumann boundary condition in the attached `examples_1d_single_step.ipynb` notebook. 27 | 28 | ![Image](resources/heat_neumann.png) \ 29 | **Non-physical solution**: Nonzero flux suggests heat flow through an insulator. 30 | 31 | ### 1D Stokes' second problem 32 | As an example, a complete pipeline is shown for the 1D time-varying PDE with Dirichlet boundary condition in the attached `examples_1d_multi_step.ipynb` notebook. 33 | 34 | ![Image](resources/stokes.png) 35 | 36 | ### 2D Navier-Stokes lid-driven cavity flow 37 | A complete pipeline is shown for the 2D time-varying PDE with Dirichlet boundary condition in the attached `examples_3d_multi_step.ipynb` notebook. 38 | 39 | https://user-images.githubusercontent.com/19197210/217733438-211a4499-d2b3-4830-8bba-2d3d7ad5dfb3.mp4 40 | 41 | 42 | 43 | 44 | https://user-images.githubusercontent.com/19197210/217733831-9d9336a3-6709-40f3-b326-716ab98a6d30.mp4 45 | 46 | 47 | 48 | ## Citation 49 | If you use this code, or our work, please cite: 50 | ``` 51 | @inproceedings{saad2022BOON, 52 | author = {Saad, Nadim and Gupta, Gaurav and Alizadeh, Shima and Maddix, Danielle C.}, 53 | title = {Guiding continuous operator learning through Physics-based boundary constraints}, 54 | booktitle={International Conference on Learning Representations}, 55 | year={2023}, 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /notebooks/examples_1d_single_step.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "5256e5aa", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys\n", 11 | "from pathlib import Path\n", 12 | "project_root = Path.cwd().parent.absolute() # get project root path for loading modules in notebook\n", 13 | "sys.path.insert(0, str(project_root))\n", 14 | "\n", 15 | "\n", 16 | "import torch\n", 17 | "import torch.nn as nn\n", 18 | "import torch.nn.functional as F\n", 19 | "\n", 20 | "import os\n", 21 | "import numpy as np\n", 22 | "from scipy.io import loadmat\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "from functools import reduce, partial\n", 25 | "from timeit import default_timer\n", 26 | "\n", 27 | "from src.utils.utils import *\n", 28 | "from src.models.base import FNO1d\n", 29 | "from src.models.single_step import BOON_FNO1d" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "d5194a27", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "torch.manual_seed(0)\n", 40 | "np.random.seed(0)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "id": "bbf56dca", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "id": "c13799cb", 56 | "metadata": {}, 57 | "source": [ 58 | "# Dirichlet" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "id": "84bb3be5", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "ntrain = 500\n", 69 | "ntest = 100\n", 70 | "\n", 71 | "sub = 1 #subsampling rate\n", 72 | "h = 500 // sub #total grid size divided by the subsampling rate\n", 73 | "\n", 74 | "s = h\n", 75 | "\n", 76 | "batch_size = 20\n", 77 | "learning_rate = 0.001\n", 78 | 
"\n", 79 | "epochs = 500\n", 80 | "step_size = 50\n", 81 | "gamma = 0.5\n", 82 | "\n", 83 | "modes = 16\n", 84 | "width = 64" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 9, 90 | "id": "69d0bf53-c6f8-426b-9eb9-1022cfabd197", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "################################################################\n", 95 | "# read data\n", 96 | "################################################################\n", 97 | "\n", 98 | "# LOAD YOUR DATA HERE consists of pairs (a,u) with randomly generated initial conditions or PDE parameters a\n", 99 | "# and solution u\n", 100 | "\n", 101 | "downloader = DataDownloader()\n", 102 | "# Use downloader.download(id = \"Burgers_Dir_1D\", tag = \"nu_0_point_1\") if only downloading a single file \n", 103 | "downloader.download(id = \"Burgers_Dir_1D\")\n", 104 | "rw_ = downloader.locate(id = \"Burgers_Dir_1D\", tag = \"nu_0_point_1\")\n", 105 | "\n", 106 | "x_data = rw_['a'].astype(np.float32) # shape (num_random_simulations, number_grid_points)\n", 107 | "y_data = rw_['u'].astype(np.float32) # shape (num_random_simulations, number_grid_points)\n", 108 | "\n", 109 | "x_data = torch.from_numpy(x_data)\n", 110 | "y_data = torch.from_numpy(y_data)\n", 111 | "\n", 112 | "x_train = x_data[:ntrain,:]\n", 113 | "y_train = y_data[:ntrain,:]\n", 114 | "x_test = x_data[-ntest:,:]\n", 115 | "y_test = y_data[-ntest:,:]\n", 116 | "\n", 117 | "x_train = x_train.unsqueeze(-1)\n", 118 | "x_test = x_test.unsqueeze(-1)\n", 119 | "\n", 120 | "left_bdry_train = y_train[:,0].unsqueeze(-1)\n", 121 | "right_bdry_train = y_train[:,-1].unsqueeze(-1)\n", 122 | "\n", 123 | "left_bdry_test = y_test[:,0].unsqueeze(-1)\n", 124 | "right_bdry_test = y_test[:,-1].unsqueeze(-1)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 11, 130 | "id": "473fbef6", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train, left_bdry_train, right_bdry_train), \n", 135 | " batch_size=batch_size, shuffle=True)\n", 136 | "test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test, left_bdry_test, right_bdry_test), \n", 137 | " batch_size=batch_size, shuffle=False)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 13, 143 | "id": "64e61bf1", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "base_no = FNO1d(modes, width)\n", 148 | "model = BOON_FNO1d(width, \n", 149 | " base_no,\n", 150 | " bdy_type='dirichlet').to(device)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "b04284f9", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n", 161 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n", 162 | "\n", 163 | "myloss = LpLoss(size_average=False)\n", 164 | "\n", 165 | "for ep in range(epochs):\n", 166 | " model.train()\n", 167 | " t1 = default_timer()\n", 168 | " train_l2 = 0\n", 169 | " for batch in train_loader:\n", 170 | " x, y, left, right = batch\n", 171 | " x, y, left, right = x.to(device), y.to(device), left.to(device), right.to(device)\n", 172 | " \n", 173 | " optimizer.zero_grad()\n", 174 | " out = model(x, \n", 175 | " bdy_left={'val':left}, \n", 176 | " bdy_right={'val':right}\n", 177 | " )\n", 178 | "\n", 179 | " l2 = 
myloss(out.view(batch_size, -1), y.view(batch_size, -1))\n", 180 | " l2.backward() # use the l2 relative loss\n", 181 | "\n", 182 | " optimizer.step()\n", 183 | " train_l2 += l2.item()\n", 184 | " \n", 185 | "\n", 186 | " scheduler.step()\n", 187 | " model.eval()\n", 188 | " test_l2 = 0.0\n", 189 | " with torch.no_grad():\n", 190 | " for batch in test_loader:\n", 191 | " x, y, left, right = batch\n", 192 | " x, y, left, right = x.to(device), y.to(device), left.to(device), right.to(device)\n", 193 | "\n", 194 | " out = model(x, \n", 195 | " bdy_left={'val':left}, \n", 196 | " bdy_right={'val':right}\n", 197 | " )\n", 198 | "\n", 199 | " test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()\n", 200 | "\n", 201 | " train_l2 /= ntrain\n", 202 | " test_l2 /= ntest\n", 203 | "\n", 204 | " t2 = default_timer()\n", 205 | " print(ep, t2-t1, train_l2, test_l2)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "id": "956401f0", 211 | "metadata": {}, 212 | "source": [ 213 | "# Neumann" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 8, 219 | "id": "25945916", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "################################################################\n", 224 | "# configurations\n", 225 | "################################################################\n", 226 | "ntrain = 500\n", 227 | "ntest = 100\n", 228 | "\n", 229 | "sub = 1 #subsampling rate\n", 230 | "h =500 // sub #total grid size divided by the subsampling rate\n", 231 | "\n", 232 | "s = h\n", 233 | "N = h\n", 234 | "\n", 235 | "batch_size = 20\n", 236 | "learning_rate = 0.001\n", 237 | "\n", 238 | "epochs = 500\n", 239 | "step_size = 50\n", 240 | "gamma = 0.5\n", 241 | "\n", 242 | "modes = 16\n", 243 | "width = 64" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "id": "a3c543f2", 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "################################################################\n", 254 | "# read data\n", 255 | "################################################################\n", 256 | "# LOAD YOUR DATA HERE consists of pairs (a,u) with randomly generated initial conditions or PDE parameters a\n", 257 | "# and solution u\n", 258 | "\n", 259 | "downloader = DataDownloader()\n", 260 | "downloader.download(id = \"Heat_Neu_1D\")\n", 261 | "rw_ = downloader.locate(id = \"Heat_Neu_1D\", tag = \"1D\")\n", 262 | "\n", 263 | "x_data = rw_['a'].astype(np.float32) # shape (num_random_simulations, number_grid_points)\n", 264 | "y_data = rw_['u'].astype(np.float32) # shape (num_random_simulations, number_grid_points)\n", 265 | "\n", 266 | "x_data = torch.from_numpy(x_data)\n", 267 | "y_data = torch.from_numpy(y_data)\n", 268 | "\n", 269 | "x_train = x_data[:ntrain,::sub]\n", 270 | "y_train = y_data[:ntrain,::sub]\n", 271 | "x_test = x_data[-ntest:,::sub]\n", 272 | "y_test = y_data[-ntest:,::sub]\n", 273 | "\n", 274 | "x_train = x_train.unsqueeze(-1)\n", 275 | "x_test = x_test.unsqueeze(-1)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 11, 281 | "id": "17bb3909", 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), \n", 286 | " batch_size=batch_size, shuffle=True)\n", 287 | "test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), \n", 288 | " batch_size=batch_size, shuffle=False)" 289 | ] 290 | }, 291 | { 292 | 
"cell_type": "code", 293 | "execution_count": 13, 294 | "id": "37202626", 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "h = 1/(N-1)\n", 299 | "\n", 300 | "coeffs_finite_difference_right = np.array([-1/(3*h), 3/(2*h), -3/h , 11/(6*h)])\n", 301 | "coeffs_finite_difference_right = coeffs_finite_difference_right.astype(np.float32)\n", 302 | "\n", 303 | "normalized_coeff_right = np.array([-2/11, 9/11, -18/11, 6*h/11]).astype(np.float32)\n", 304 | "normalized_coeff_right = torch.from_numpy(normalized_coeff_right).to(device)\n", 305 | "\n", 306 | "diff_fn_right = partial(compute_finite_diff, normalized_coeff_right, loc=-1)\n", 307 | "\n", 308 | "coeffs_finite_difference_left = np.array([1, -1])\n", 309 | "coeffs_finite_difference_left = coeffs_finite_difference_left.astype(np.float32)\n", 310 | "\n", 311 | "normalized_coeff_left = np.array([1, -1]).astype(np.float32)\n", 312 | "normalized_coeff_left = torch.from_numpy(normalized_coeff_left).to(device)\n", 313 | "diff_fn_left = partial(compute_finite_diff, normalized_coeff_left, loc=0)\n", 314 | "\n", 315 | "neumann_bdy_left = 0.0\n", 316 | "\n", 317 | "U = np.array([5])\n", 318 | "U = torch.from_numpy(U).to(device)\n", 319 | "\n", 320 | "t = np.array([0.5])\n", 321 | "t = torch.from_numpy(t).to(device)\n", 322 | "neumann_bdy_right = U*torch.sin(torch.pi*t)" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 14, 328 | "id": "7c5c5ddb", 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "base_no = FNO1d(modes, width)\n", 333 | "model = BOON_FNO1d(width, \n", 334 | " base_no,\n", 335 | " bdy_type='neumann').to(device)" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "id": "28ef2f8a", 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [ 345 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n", 346 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n", 347 | "\n", 348 | "myloss = LpLoss(size_average=False)\n", 349 | "\n", 350 | "for ep in range(epochs):\n", 351 | " model.train()\n", 352 | " t1 = default_timer()\n", 353 | " train_l2 = 0\n", 354 | " for batch in train_loader:\n", 355 | " x, y = batch\n", 356 | " x, y = x.to(device), y.to(device)\n", 357 | " \n", 358 | " optimizer.zero_grad()\n", 359 | " out = model(x, \n", 360 | " bdy_left={'val':neumann_bdy_left, 'diff_fn':diff_fn_left},\n", 361 | " bdy_right={'val':neumann_bdy_right, 'diff_fn':diff_fn_right},\n", 362 | " )\n", 363 | "\n", 364 | " l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))\n", 365 | " l2.backward() # use the l2 relative loss\n", 366 | "\n", 367 | " optimizer.step()\n", 368 | " train_l2 += l2.item()\n", 369 | " \n", 370 | "\n", 371 | " scheduler.step()\n", 372 | " model.eval()\n", 373 | " test_l2 = 0.0\n", 374 | " with torch.no_grad():\n", 375 | " for batch in test_loader:\n", 376 | " x, y = batch\n", 377 | " x, y = x.to(device), y.to(device)\n", 378 | "\n", 379 | " out = model(x, \n", 380 | " bdy_left={'val':neumann_bdy_left, 'diff_fn':diff_fn_left},\n", 381 | " bdy_right={'val':neumann_bdy_right, 'diff_fn':diff_fn_right},\n", 382 | " )\n", 383 | "\n", 384 | " test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()\n", 385 | "\n", 386 | " train_l2 /= ntrain\n", 387 | " test_l2 /= ntest\n", 388 | "\n", 389 | " t2 = default_timer()\n", 390 | " print(ep, t2-t1, train_l2, test_l2)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "id": "1c68d539", 
396 | "metadata": {}, 397 | "source": [ 398 | "# Periodic" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 9, 404 | "id": "cf20a7ce", 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "################################################################\n", 409 | "# configurations\n", 410 | "################################################################\n", 411 | "ntrain = 500\n", 412 | "ntest = 100\n", 413 | "\n", 414 | "sub = 2**6 #subsampling rate\n", 415 | "h = 2**13 // sub #total grid size divided by the subsampling rate\n", 416 | "\n", 417 | "s = h\n", 418 | "\n", 419 | "batch_size = 20\n", 420 | "learning_rate = 0.001\n", 421 | "\n", 422 | "epochs = 500\n", 423 | "step_size = 50\n", 424 | "gamma = 0.5\n", 425 | "\n", 426 | "modes = 16\n", 427 | "width = 64" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": null, 433 | "id": "353968f0", 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "################################################################\n", 438 | "# read data\n", 439 | "################################################################\n", 440 | "# LOAD YOUR DATA HERE consists of pairs (a,u) with randomly generated initial conditions or PDE parameters a\n", 441 | "# and solution u\n", 442 | "\n", 443 | "downloader = DataDownloader()\n", 444 | "downloader.download(id = \"Burgers_Per_1D\")\n", 445 | "rw_ = downloader.locate(id = \"Burgers_Per_1D\", tag = \"R10\")\n", 446 | "\n", 447 | "x_data = rw_['a'].astype(np.float32) # shape (num_random_simulations, number_grid_points)\n", 448 | "y_data = rw_['u'].astype(np.float32) # shape (num_random_simulations, number_grid_points)\n", 449 | "\n", 450 | "x_data = torch.from_numpy(x_data)\n", 451 | "y_data = torch.from_numpy(y_data)\n", 452 | "\n", 453 | "x_train = x_data[:ntrain,::sub]\n", 454 | "y_train = y_data[:ntrain,::sub]\n", 455 | "x_test = x_data[-ntest:,::sub]\n", 456 | "y_test = y_data[-ntest:,::sub]\n", 457 | "\n", 458 | "# strictly put periodic bdy condition\n", 459 | "x_train[:, -1] = x_train[:, 0]\n", 460 | "y_train[:, -1] = y_train[:, 0]\n", 461 | "x_test[:, -1] = x_test[:, 0]\n", 462 | "y_test[:, -1] = y_test[:, 0]\n", 463 | "\n", 464 | "x_train = x_train.unsqueeze(-1)\n", 465 | "x_test = x_test.unsqueeze(-1)" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 12, 471 | "id": "eee50758", 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [ 475 | "train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), \n", 476 | " batch_size=batch_size, shuffle=True)\n", 477 | "test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), \n", 478 | " batch_size=batch_size, shuffle=False)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 14, 484 | "id": "b3982b28", 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "base_no = FNO1d(modes, width)\n", 489 | "model = BOON_FNO1d(width, \n", 490 | " base_no,\n", 491 | " bdy_type='periodic').to(device)" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "id": "f2b30b8a", 498 | "metadata": {}, 499 | "outputs": [], 500 | "source": [ 501 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n", 502 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n", 503 | "\n", 504 | "myloss = LpLoss(size_average=False)\n", 505 | "\n", 506 | "for ep in 
range(epochs):\n", 507 | " model.train()\n", 508 | " t1 = default_timer()\n", 509 | " train_l2 = 0\n", 510 | " for batch in train_loader:\n", 511 | " x, y = batch\n", 512 | " x, y = x.to(device), y.to(device)\n", 513 | " \n", 514 | " optimizer.zero_grad()\n", 515 | " out = model(x)\n", 516 | "\n", 517 | " l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))\n", 518 | " l2.backward() # use the l2 relative loss\n", 519 | "\n", 520 | " optimizer.step()\n", 521 | " train_l2 += l2.item()\n", 522 | " \n", 523 | "\n", 524 | " scheduler.step()\n", 525 | " model.eval()\n", 526 | " test_l2 = 0.0\n", 527 | " test_l2_near_bdy = 0.0\n", 528 | " with torch.no_grad():\n", 529 | " for batch in test_loader:\n", 530 | " x, y = batch\n", 531 | " x, y = x.to(device), y.to(device)\n", 532 | " out = model(x)\n", 533 | " test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()\n", 534 | "\n", 535 | " train_l2 /= ntrain\n", 536 | " test_l2 /= ntest\n", 537 | "\n", 538 | " t2 = default_timer()\n", 539 | " print(ep, t2-t1, train_l2, test_l2)" 540 | ] 541 | } 542 | ], 543 | "metadata": { 544 | "kernelspec": { 545 | "display_name": "Python 3 (ipykernel)", 546 | "language": "python", 547 | "name": "python3" 548 | }, 549 | "language_info": { 550 | "codemirror_mode": { 551 | "name": "ipython", 552 | "version": 3 553 | }, 554 | "file_extension": ".py", 555 | "mimetype": "text/x-python", 556 | "name": "python", 557 | "nbconvert_exporter": "python", 558 | "pygments_lexer": "ipython3", 559 | "version": "3.8.15" 560 | } 561 | }, 562 | "nbformat": 4, 563 | "nbformat_minor": 5 564 | } 565 | -------------------------------------------------------------------------------- /notebooks/examples_3d_multi_step.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "385b82a3", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys\n", 11 | "from pathlib import Path\n", 12 | "project_root = Path.cwd().parent.absolute() # get project root path for loading modules in notebook\n", 13 | "sys.path.insert(0, str(project_root))\n", 14 | "\n", 15 | "\n", 16 | "import torch\n", 17 | "import torch.nn as nn\n", 18 | "import torch.nn.functional as F\n", 19 | "\n", 20 | "import os\n", 21 | "import numpy as np\n", 22 | "from scipy.io import loadmat\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "from functools import reduce, partial\n", 25 | "from timeit import default_timer\n", 26 | "\n", 27 | "from src.utils.utils import *\n", 28 | "from src.models.base import FNO3d\n", 29 | "from src.models.multi_step import BOON_FNO3d" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "cc6ad850", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "torch.manual_seed(0)\n", 40 | "np.random.seed(0)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "id": "2a30370c", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "id": "a173177d", 56 | "metadata": {}, 57 | "source": [ 58 | "# Dirichlet" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "id": "bcc2cdc4", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "ntrain = 1000\n", 69 | "ntest = 200\n", 70 | "\n", 71 | "modes = 8\n", 72 | "width = 20\n", 73 | "\n", 74 | "batch_size = 10\n", 
75 | "batch_size2 = batch_size\n", 76 | "\n", 77 | "epochs = 500\n", 78 | "learning_rate = 0.001\n", 79 | "scheduler_step = 100\n", 80 | "scheduler_gamma = 0.5\n", 81 | "\n", 82 | "sub = 2\n", 83 | "N = 100 // sub #total grid size divided by the subsampling rate\n", 84 | "S = N\n", 85 | "\n", 86 | "T_in = 1\n", 87 | "T = 25" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 17, 93 | "id": "fb45335c", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "downloader = DataDownloader()\n", 98 | "# Use downloader.download(id = \"NV_Dir_3D\", tag = \"Re_100\") if only downloading a single file \n", 99 | "downloader.download(id = \"NV_Dir_3D\")\n", 100 | "rw = downloader.locate(id = \"NV_Dir_3D\", tag = \"Re_100\")\n", 101 | "\n", 102 | "train_a = rw['a'][:ntrain,::sub, ::sub,:T_in]\n", 103 | "train_a = train_a.astype(np.float32)\n", 104 | "train_a = torch.from_numpy(train_a)\n", 105 | "\n", 106 | "train_u = rw['u'][:ntrain,::sub, ::sub,-T:]\n", 107 | "train_u = train_u.astype(np.float32)\n", 108 | "train_u = torch.from_numpy(train_u)\n", 109 | "\n", 110 | "test_a = rw['a'][-ntest:,::sub, ::sub,:T_in]\n", 111 | "test_a = test_a.astype(np.float32)\n", 112 | "test_a = torch.from_numpy(test_a)\n", 113 | "\n", 114 | "test_u = rw['u'][-ntest:,::sub, ::sub,-T:]\n", 115 | "test_u = test_u.astype(np.float32)\n", 116 | "test_u = torch.from_numpy(test_u)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 20, 122 | "id": "3ddc17dd", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "train_a = train_a.reshape(ntrain,S,S,1,T_in).repeat([1,1,1,T,1])\n", 127 | "test_a = test_a.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1])\n", 128 | "\n", 129 | "train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)\n", 130 | "test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 21, 136 | "id": "4474187e", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "base_no = FNO3d(modes, modes, modes, width)\n", 141 | "model = BOON_FNO3d(width,\n", 142 | " base_no,\n", 143 | " bdy_type = 'dirichlet').to(device)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "17d9014b", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n", 154 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)\n", 155 | "\n", 156 | "myloss = LpLoss(size_average=False)\n", 157 | "for ep in range(epochs):\n", 158 | " model.train()\n", 159 | " t1 = default_timer()\n", 160 | " train_l2 = 0\n", 161 | " for x, y in train_loader:\n", 162 | " bs, nx, ny, T, _ = x.shape\n", 163 | " x, y = x.to(device), y.to(device)\n", 164 | "\n", 165 | " optimizer.zero_grad()\n", 166 | " \n", 167 | " bdy_left = y[:, 0, :, :].reshape(bs, 1, ny, T) # add extra dimension to take care of \n", 168 | "# model channel structure\n", 169 | " bdy_right = y[:,-1, :, :].reshape(bs, 1, ny, T)\n", 170 | " bdy_top = y[:, :, 0, :].reshape(bs, 1, nx, T)\n", 171 | " bdy_down = y[:, :,-1, :].reshape(bs, 1, nx, T)\n", 172 | " \n", 173 | " out = model(x, \n", 174 | " bdy_left = {'val':bdy_left}, \n", 175 | " bdy_right = {'val':bdy_right}, \n", 176 | " bdy_top = {'val':bdy_top}, \n", 177 | " 
bdy_down = {'val':bdy_down}\n", 178 | " ).view(bs, S, S, T)\n", 179 | "\n", 180 | " l2 = myloss(out.view(bs, -1), y.view(bs, -1))\n", 181 | " l2.backward()\n", 182 | "\n", 183 | " optimizer.step()\n", 184 | " train_l2 += l2.item()\n", 185 | "\n", 186 | " scheduler.step()\n", 187 | "\n", 188 | " model.eval()\n", 189 | " test_l2 = 0.0\n", 190 | " with torch.no_grad():\n", 191 | " for x, y in test_loader:\n", 192 | " bs, nx, ny, T, _ = x.shape\n", 193 | " x, y = x.to(device), y.to(device)\n", 194 | " \n", 195 | " bdy_left = y[:, 0, :, :].reshape(bs, 1, ny, T) # add extra dimension to take care of \n", 196 | "# model channel structure\n", 197 | " bdy_right = y[:,-1, :, :].reshape(bs, 1, ny, T)\n", 198 | " bdy_top = y[:, :, 0, :].reshape(bs, 1, nx, T)\n", 199 | " bdy_down = y[:, :,-1, :].reshape(bs, 1, nx, T)\n", 200 | "\n", 201 | " out = model(x,\n", 202 | " bdy_left = {'val':bdy_left}, \n", 203 | " bdy_right = {'val':bdy_right}, \n", 204 | " bdy_top = {'val':bdy_top}, \n", 205 | " bdy_down = {'val':bdy_down}\n", 206 | " ).view(bs, S, S, T)\n", 207 | " test_l2 += myloss(out.view(bs, -1), y.view(bs, -1)).item()\n", 208 | "\n", 209 | " train_l2 /= ntrain\n", 210 | " test_l2 /= ntest\n", 211 | "\n", 212 | " t2 = default_timer()\n", 213 | " print(ep, t2-t1, train_l2, test_l2)\n", 214 | "# torch.save(model, path_model)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "ef7af4c7", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "bert", 229 | "language": "python", 230 | "name": "bert" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.8.15" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 5 247 | } 248 | -------------------------------------------------------------------------------- /resources/heat_neumann.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/boon/19882ef6b6f970e7862bb432ab5b3076f6b70a44/resources/heat_neumann.png -------------------------------------------------------------------------------- /resources/ns_initial_condition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/boon/19882ef6b6f970e7862bb432ab5b3076f6b70a44/resources/ns_initial_condition.png -------------------------------------------------------------------------------- /resources/ns_lid_cavity.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/boon/19882ef6b6f970e7862bb432ab5b3076f6b70a44/resources/ns_lid_cavity.mp4 -------------------------------------------------------------------------------- /resources/ns_lid_cavity_rel_err.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/boon/19882ef6b6f970e7862bb432ab5b3076f6b70a44/resources/ns_lid_cavity_rel_err.mp4 -------------------------------------------------------------------------------- /resources/operator_bdy.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/boon/19882ef6b6f970e7862bb432ab5b3076f6b70a44/resources/operator_bdy.png -------------------------------------------------------------------------------- /resources/stokes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/boon/19882ef6b6f970e7862bb432ab5b3076f6b70a44/resources/stokes.png -------------------------------------------------------------------------------- /src/models/base/FNO1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | # This code is borrowed from FNO git repository: https://github.com/zongyi-li/fourier_neural_operator 8 | 9 | 10 | ################################################################ 11 | # 1d fourier layer 12 | ################################################################ 13 | class SpectralConv1d(nn.Module): 14 | def __init__(self, in_channels, out_channels, modes1): 15 | super(SpectralConv1d, self).__init__() 16 | 17 | """ 18 | 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. 19 | """ 20 | 21 | self.in_channels = in_channels 22 | self.out_channels = out_channels 23 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 24 | 25 | self.scale = (1 / (in_channels*out_channels)) 26 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat)) 27 | 28 | # Complex multiplication 29 | def compl_mul1d(self, input, weights): 30 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 31 | return torch.einsum("bix,iox->box", input, weights) 32 | 33 | def forward(self, x): 34 | batchsize = x.shape[0] 35 | # Compute Fourier coefficients up to a normalization factor 36 | x_ft = torch.fft.rfft(x) 37 | 38 | # Multiply relevant Fourier modes 39 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1)//2 + 1, device=x.device, dtype=torch.cfloat) 40 | out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1) 41 | 42 | # Return to physical space 43 | x = torch.fft.irfft(out_ft, n=x.size(-1)) 44 | return x 45 | 46 | 47 | 48 | class FNO1d(nn.Module): 49 | def __init__(self, 50 | modes, 51 | width, 52 | lb=0, 53 | ub=1): 54 | super().__init__() 55 | 56 | """ 57 | The overall network. It contains 4 layers of the Fourier layer. 58 | 1. Lift the input to the desired channel dimension by self.fc0. 59 | 2. 4 layers of the integral operators u' = (W + K)(u). 60 | W defined by self.w; K defined by self.conv. 61 | 3. Project from the channel space to the output space by self.fc1 and self.fc2. 
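Each Fourier layer computes conv_i(x) + w_i(x); a GELU nonlinearity follows every layer except the last. (The same structure is used by the 2D and 3D variants below.)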
62 | 63 | input: the initial condition and location (a(x), x) 64 | input shape: (batchsize, x=s, c=2) 65 | output: the solution of a later timestep 66 | output shape: (batchsize, x=s, c=1) 67 | """ 68 | 69 | self.modes1 = modes 70 | self.width = width 71 | self.lb = lb # lower value of the domain 72 | self.ub = ub # upper value of the domain 73 | self.padding = 2 # pad the domain if input is non-periodic 74 | self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x) 75 | 76 | self.conv0 = SpectralConv1d(self.width, self.width, self.modes1) 77 | self.conv1 = SpectralConv1d(self.width, self.width, self.modes1) 78 | self.conv2 = SpectralConv1d(self.width, self.width, self.modes1) 79 | self.conv3 = SpectralConv1d(self.width, self.width, self.modes1) 80 | 81 | self.w0 = nn.Conv1d(self.width, self.width, 1) 82 | self.w1 = nn.Conv1d(self.width, self.width, 1) 83 | self.w2 = nn.Conv1d(self.width, self.width, 1) 84 | self.w3 = nn.Conv1d(self.width, self.width, 1) 85 | 86 | self.fc1 = nn.Linear(self.width, 128) 87 | self.fc2 = nn.Linear(128, 1) 88 | 89 | 90 | 91 | def forward(self, x): 92 | """ 93 | Forward function of the Neural Operator. 94 | 95 | Args: 96 | x: array representing the input of the Neural Operator, 97 | given by the initial condition of the PDE 98 | """ 99 | 100 | grid = self.get_grid(x.shape, x.device) 101 | 102 | x = torch.cat((x, grid), dim=-1) 103 | x = self.fc0(x) 104 | 105 | x = x.permute(0, 2, 1) 106 | 107 | x1 = self.conv0(x) 108 | x2 = self.w0(x) 109 | x = F.gelu(x1 + x2) 110 | 111 | x1 = self.conv1(x) 112 | x2 = self.w1(x) 113 | x = F.gelu(x1 + x2) 114 | 115 | x1 = self.conv2(x) 116 | x2 = self.w2(x) 117 | x = F.gelu(x1 + x2) 118 | 119 | x1 = self.conv3(x) 120 | x2 = self.w3(x) 121 | x = x1 + x2 122 | 123 | x = x.permute(0, 2, 1) 124 | 125 | x = self.fc1(x) 126 | x = F.gelu(x) 127 | 128 | x = self.fc2(x) 129 | 130 | return x 131 | 132 | def get_grid(self, shape, device): 133 | batchsize, size_x = shape[0], shape[1] 134 | gridx = torch.tensor(np.linspace(self.lb, self.ub, size_x), dtype=torch.float) 135 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 136 | return gridx.to(device) -------------------------------------------------------------------------------- /src/models/base/FNO2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | # This code is borrowed from FNO git repository: https://github.com/zongyi-li/fourier_neural_operator 8 | 9 | 10 | ################################################################ 11 | # 2D fourier layer 12 | ################################################################ 13 | class SpectralConv2d(nn.Module): 14 | def __init__(self, in_channels, out_channels, modes1, modes2): 15 | super(SpectralConv2d, self).__init__() 16 | 17 | """ 18 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
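For an input of shape (batch, in_channels, x, y), torch.fft.rfft2 yields coefficients of shape (batch, in_channels, x, y//2 + 1); only the lowest modes1 x modes2 blocks (the positive- and negative-frequency corners along the first spatial axis, weighted by weights1 and weights2) are mixed across channels before irfft2 maps back to physical space.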
19 | """ 20 | 21 | self.in_channels = in_channels 22 | self.out_channels = out_channels 23 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 24 | self.modes2 = modes2 25 | 26 | self.scale = (1 / (in_channels * out_channels)) 27 | self.weights1 = nn.Parameter( 28 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 29 | self.weights2 = nn.Parameter( 30 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 31 | 32 | # Complex multiplication 33 | def compl_mul2d(self, input, weights): 34 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 35 | return torch.einsum("bixy,ioxy->boxy", input, weights) 36 | 37 | def forward(self, x): 38 | batchsize = x.shape[0] 39 | # Compute Fourier coefficients up to a normalization factor 40 | x_ft = torch.fft.rfft2(x) 41 | 42 | # Multiply relevant Fourier modes 43 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, 44 | device=x.device) 45 | out_ft[:, :, :self.modes1, :self.modes2] = \ 46 | self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) 47 | out_ft[:, :, -self.modes1:, :self.modes2] = \ 48 | self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 49 | 50 | # Return to physical space 51 | x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) 52 | return x 53 | 54 | 55 | class FNO2d(nn.Module): 56 | def __init__(self, 57 | modes1, 58 | modes2, 59 | width, 60 | lb=0, 61 | ub=1): 62 | super(FNO2d, self).__init__() 63 | 64 | """ 65 | The overall network. It contains 4 layers of the Fourier layer. 66 | 1. Lift the input to the desired channel dimension by self.fc0. 67 | 2. 4 layers of the integral operators u' = (W + K)(u). 68 | W defined by self.w; K defined by self.conv. 69 | 3. Project from the channel space to the output space by self.fc1 and self.fc2. 
70 | 71 | input: the coefficient function and locations (a(x, y), x, y) 72 | input shape: (batchsize, x=s, y=s, c=3) 73 | output: the solution 74 | output shape: (batchsize, x=s, y=s, c=1) 75 | """ 76 | 77 | self.modes1 = modes1 78 | self.modes2 = modes2 79 | self.width = width 80 | self.lb = lb 81 | self.ub = ub 82 | self.padding = 9 # pad the domain if input is non-periodic 83 | self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y) 84 | 85 | self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 86 | self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 87 | self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 88 | self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 89 | 90 | self.w0 = nn.Conv2d(self.width, self.width, 1) 91 | self.w1 = nn.Conv2d(self.width, self.width, 1) 92 | self.w2 = nn.Conv2d(self.width, self.width, 1) 93 | self.w3 = nn.Conv2d(self.width, self.width, 1) 94 | 95 | self.fc1 = nn.Linear(self.width, 128) 96 | self.fc2 = nn.Linear(128, 1) 97 | 98 | def forward(self, x): 99 | grid = self.get_grid(x.shape, x.device) 100 | x = torch.cat((x, grid), dim=-1) 101 | x = self.fc0(x) 102 | x = x.permute(0, 3, 1, 2) 103 | x = F.pad(x, [0, self.padding, 0, self.padding]) 104 | 105 | x1 = self.conv0(x) 106 | x2 = self.w0(x) 107 | x = x1 + x2 108 | x = F.gelu(x) 109 | 110 | x1 = self.conv1(x) 111 | x2 = self.w1(x) 112 | x = x1 + x2 113 | x = F.gelu(x) 114 | 115 | x1 = self.conv2(x) 116 | x2 = self.w2(x) 117 | x = x1 + x2 118 | x = F.gelu(x) 119 | 120 | x1 = self.conv3(x) 121 | x2 = self.w3(x) 122 | x = x1 + x2 123 | 124 | x = x[..., :-self.padding, :-self.padding] 125 | x = x.permute(0, 2, 3, 1) 126 | x = self.fc1(x) 127 | x = F.gelu(x) 128 | 129 | x = self.fc2(x) 130 | return x 131 | 132 | def get_grid(self, shape, device): 133 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 134 | gridx = torch.tensor(np.linspace(self.lb, self.ub, size_x), dtype=torch.float) 135 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 136 | gridy = torch.tensor(np.linspace(self.lb, self.ub, size_y), dtype=torch.float) 137 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 138 | return torch.cat((gridx, gridy), dim=-1).to(device) 139 | -------------------------------------------------------------------------------- /src/models/base/FNO3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | # This code is borrowed from FNO git repository: https://github.com/zongyi-li/fourier_neural_operator 8 | 9 | 10 | ################################################################ 11 | # 3d fourier layers 12 | ################################################################ 13 | 14 | class SpectralConv3d(nn.Module): 15 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3): 16 | super(SpectralConv3d, self).__init__() 17 | 18 | """ 19 | 3D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
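For an input of shape (batch, in_channels, x, y, t), torch.fft.rfftn over the last three dimensions yields coefficients of shape (batch, in_channels, x, y, t//2 + 1); only the lowest modes1 x modes2 x modes3 blocks are retained, with weights1-weights4 covering the four positive/negative-frequency corners of the two full axes, before irfftn maps back to physical space.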
20 | """ 21 | 22 | self.in_channels = in_channels 23 | self.out_channels = out_channels 24 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 25 | self.modes2 = modes2 26 | self.modes3 = modes3 27 | 28 | self.scale = (1 / (in_channels * out_channels)) 29 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 30 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 31 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 32 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 33 | 34 | # Complex multiplication 35 | def compl_mul3d(self, input, weights): 36 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 37 | return torch.einsum("bixyz,ioxyz->boxyz", input, weights) 38 | 39 | def forward(self, x): 40 | batchsize = x.shape[0] 41 | # Compute Fourier coefficients up to a normalization factor 42 | x_ft = torch.fft.rfftn(x, dim=[-3,-2,-1]) 43 | 44 | # Multiply relevant Fourier modes 45 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device) 46 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \ 47 | self.compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1) 48 | out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \ 49 | self.compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2) 50 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \ 51 | self.compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3) 52 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \ 53 | self.compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4) 54 | 55 | # Return to physical space 56 | x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1))) 57 | return x 58 | 59 | class FNO3d(nn.Module): 60 | def __init__(self, 61 | modes1, 62 | modes2, 63 | modes3, 64 | width, 65 | lb=0, 66 | ub=1): 67 | super(FNO3d, self).__init__() 68 | 69 | """ 70 | The overall network. It contains 4 layers of the Fourier layer. 71 | 1. Lift the input to the desired channel dimension by self.fc0. 72 | 2. 4 layers of the integral operators u' = (W + K)(u). 73 | W defined by self.w; K defined by self.conv. 74 | 3. Project from the channel space to the output space by self.fc1 and self.fc2. 75 | 76 | input: the initial condition repeated along the time dimension + 3 locations (a(x, y), x, y, t); it is constant in time except for the last (time) coordinate. 
77 |         input shape: (batchsize, x=64, y=64, t=40, c=1)
78 |         output: the solution at the next 40 timesteps
79 |         output shape: (batchsize, x=64, y=64, t=40, c=1)
80 |         """
81 | 
82 |         self.modes1 = modes1
83 |         self.modes2 = modes2
84 |         self.modes3 = modes3
85 |         self.width = width
86 |         self.lb = lb # lower value of the domain
87 |         self.ub = ub # upper value of the domain
88 |         self.padding = 6 # pad the domain if input is non-periodic
89 |         self.fc0 = nn.Linear(4, self.width)
90 |         # input channel is 4: the input function + 3 locations (u(x, y, t), x, y, t)
91 | 
92 |         self.conv0 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
93 |         self.conv1 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
94 |         self.conv2 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
95 |         self.conv3 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
96 |         self.w0 = nn.Conv3d(self.width, self.width, 1)
97 |         self.w1 = nn.Conv3d(self.width, self.width, 1)
98 |         self.w2 = nn.Conv3d(self.width, self.width, 1)
99 |         self.w3 = nn.Conv3d(self.width, self.width, 1)
100 |         self.bn0 = torch.nn.BatchNorm3d(self.width)
101 |         self.bn1 = torch.nn.BatchNorm3d(self.width)
102 |         self.bn2 = torch.nn.BatchNorm3d(self.width)
103 |         self.bn3 = torch.nn.BatchNorm3d(self.width)
104 | 
105 |         self.fc1 = nn.Linear(self.width, 128)
106 |         self.fc2 = nn.Linear(128, 1)
107 | 
108 |     def forward(self, x):
109 |         grid = self.get_grid(x.shape, x.device)
110 |         x = torch.cat((x, grid), dim=-1)
111 |         x = self.fc0(x)
112 |         x = x.permute(0, 4, 1, 2, 3)
113 |         x = F.pad(x, [0, self.padding]) # pad the domain if input is non-periodic
114 | 
115 |         x1 = self.conv0(x)
116 |         x2 = self.w0(x)
117 |         x = x1 + x2
118 |         x = F.gelu(x)
119 | 
120 |         x1 = self.conv1(x)
121 |         x2 = self.w1(x)
122 |         x = x1 + x2
123 |         x = F.gelu(x)
124 | 
125 |         x1 = self.conv2(x)
126 |         x2 = self.w2(x)
127 |         x = x1 + x2
128 |         x = F.gelu(x)
129 | 
130 |         x1 = self.conv3(x)
131 |         x2 = self.w3(x)
132 |         x = x1 + x2
133 | 
134 |         x = x[..., :-self.padding]
135 |         x = x.permute(0, 2, 3, 4, 1) # move the channel dimension back to the end
136 |         x = self.fc1(x)
137 |         x = F.gelu(x)
138 |         x = self.fc2(x)
139 |         return x
140 | 
141 |     def get_grid(self, shape, device):
142 |         batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3]
143 |         gridx = torch.tensor(np.linspace(self.lb, self.ub, size_x), dtype=torch.float)
144 |         gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1])
145 |         gridy = torch.tensor(np.linspace(self.lb, self.ub, size_y), dtype=torch.float)
146 |         gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1])
147 |         gridz = torch.tensor(np.linspace(self.lb, self.ub, size_z), dtype=torch.float)
148 |         gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1])
149 |         return torch.cat((gridx, gridy, gridz), dim=-1).to(device)
--------------------------------------------------------------------------------
/src/models/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .FNO1d import FNO1d
2 | from .FNO2d import FNO2d
3 | from .FNO3d import FNO3d
4 | 
--------------------------------------------------------------------------------
/src/models/corrections.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
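# The corrections below follow a common recipe: probe the linear layer K with the
# boundary indicator functions e1 and eN to read off the kernel's boundary entries
# K00 = (K e1)[0], KN0 = (K e1)[-1], K0N = (K eN)[0], KNN = (K eN)[-1], then modify
# the layer input so that the output hits the prescribed boundary values, as in a
# step of Gaussian elimination. A minimal 1d sketch of the probing step (illustrative
# only; `_probe_boundary_entries` is not part of this module and is not used by the
# classes below):
#
#     def _probe_boundary_entries(layer, x):
#         e1 = torch.zeros_like(x); e1[:, :, 0] = 1
#         eN = torch.zeros_like(x); eN[:, :, -1] = 1
#         Ke1, KeN = layer(e1), layer(eN)
#         return Ke1[0, 0, 0], Ke1[0, 0, -1], KeN[0, 0, 0], KeN[0, 0, -1]  # K00, KN0, K0N, KNN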
6 | ################################################################
7 | # 1d kernel correction
8 | ################################################################
9 | DEFAULT_BDY = {'val':None, 'diff_fn':None}
10 | 
11 | 
12 | class gaussian_elim_layer1d(nn.Module):
13 |     def __init__(
14 |         self,
15 |         layer,
16 |         tol = 1e-7):
17 |         super().__init__()
18 |         self.layer = layer
19 |         self.tol = tol
20 | 
21 | 
22 |     def forward(self, x):
23 |         raise NotImplementedError
24 | 
25 | 
26 |     def gauss_elimination(self, x, left, right):
27 | 
28 |         e1 = torch.zeros_like(x)
29 |         e1[:, :, 0] = 1
30 |         eN = torch.zeros_like(x)
31 |         eN[:, :, -1] = 1
32 | 
33 |         K00 = self.layer(e1)[0, 0, 0]
34 |         KN0 = self.layer(e1)[0, 0, -1]
35 |         if torch.abs(K00) < self.tol: K00 = self.tol
36 | 
37 |         KNN = self.layer(eN)[0, 0, -1]
38 |         K0N = self.layer(eN)[0, 0, 0]
39 |         if torch.abs(KNN) < self.tol: KNN = self.tol
40 | 
41 |         if left is not None: # first apply left and then right bdy
42 |             # the opposite order could also be used (may give slightly different results)
43 |             Kx = self.layer(x)
44 |             T2_left_x = x.clone()
45 |             tilde_K00 = K00 - KN0*K0N/KNN
46 |             if torch.abs(tilde_K00) < self.tol: tilde_K00 = self.tol
47 |             T2_left_x[:, :, 0] = 2*x[:, :, 0] - (1/tilde_K00)*(
48 |                 Kx[:, :, 0] - Kx[:, :, -1]*K0N/KNN
49 |                 + K0N*x[:, :, -1])
50 |         else:
51 |             T2_left_x = x
52 | 
53 |         Ky = self.layer(T2_left_x)
54 | 
55 |         if right is not None: # left bdy is already corrected, if it exists
56 |             T2_right_y = T2_left_x.clone()
57 |             T2_right_y[:, :, -1] = 2*T2_right_y[:, :, -1] - Ky[:, :, -1]/KNN
58 |         else:
59 |             T2_right_y = T2_left_x
60 | 
61 |         T = self.layer(T2_right_y)
62 |         return T
63 | 
64 | 
65 | class gaussian_elim_layer2d(nn.Module):
66 |     def __init__(
67 |         self,
68 |         layer,
69 |         tol = 1e-4):
70 |         super().__init__()
71 |         self.layer = layer
72 |         self.tol = tol
73 | 
74 |     def forward(self, x):
75 |         raise NotImplementedError
76 | 
77 |     def gauss_elimination(self, x, left, right):
78 | 
79 |         e1 = torch.zeros_like(x)
80 |         eN = torch.zeros_like(x)
81 | 
82 |         e1[:, :, 0, :] = 1
83 |         eN[:, :, -1, :] = 1
84 | 
85 |         K00 = self.layer(e1)[0, 0, 0, :]
86 |         KN0 = self.layer(e1)[0, 0, -1, :]
87 |         K0N = self.layer(eN)[0, 0, 0, :]
88 |         KNN = self.layer(eN)[0, 0, -1, :]
89 | 
90 |         indices_K00_less_tol = torch.nonzero(torch.abs(K00) < self.tol)
91 |         indices_KNN_less_tol = torch.nonzero(torch.abs(KNN) < self.tol)
92 | 
93 |         if len(indices_K00_less_tol) != 0:
94 |             K00[indices_K00_less_tol] = self.tol
95 | 
96 |         if len(indices_KNN_less_tol) != 0:
97 |             KNN[indices_KNN_less_tol] = self.tol
98 | 
99 |         if left is not None: # first apply left and then right bdy
100 |             # the opposite order could also be used (may give slightly different results)
101 |             Kx = self.layer(x)
102 |             T2_left_x = x.clone()
103 | 
104 |             tilde_K00 = K00 - KN0 * K0N / KNN
105 | 
106 |             indices_tilde_K00_less_tol = torch.nonzero(torch.abs(tilde_K00) < self.tol)
107 | 
108 |             if len(indices_tilde_K00_less_tol) != 0:
109 |                 tilde_K00[indices_tilde_K00_less_tol] = self.tol
110 | 
111 |             T2_left_x[:, :, 0, :] = 2 * x[:, :, 0, :] - (1 / tilde_K00) * (
112 |                 Kx[:, :, 0, :] - Kx[:, :, -1, :] * K0N / KNN + K0N * x[:, :, -1, :])
113 | 
114 |         else:
115 |             T2_left_x = x
116 | 
117 |         Ky = self.layer(T2_left_x)
118 | 
119 |         if right is not None: # left bdy is already corrected, if it exists
120 |             T2_right_y = T2_left_x.clone()
121 |             T2_right_y[:, :, -1, :] = 2 * T2_right_y[:, :, -1, :] - Ky[:, :, -1, :] / KNN
122 |         else:
123 |             T2_right_y = T2_left_x
124 | 
125 |         T = self.layer(T2_right_y)
126 |         return T
127 | 
128 | 
129 | class gaussian_elim_layer3d(nn.Module):
130 |     def __init__(
131 |         self,
132 |         layer,
133 |         tol = 1e-4):
134 |         super().__init__()
135 |         self.layer = layer
136 |         self.tol = tol
137 | 
138 | 
139 |     def forward(self, x):
140 |         raise NotImplementedError
141 | 
142 | 
143 |     def gauss_elimination(self, x):
144 | 
145 |         e1_x = torch.zeros_like(x)
146 |         e1_x[:, :, 0, :, :] = 1
147 |         eN_x = torch.zeros_like(x)
148 |         eN_x[:, :, -1, :, :] = 1
149 | 
150 |         K00_x = self.layer(e1_x)[0, 0, 0, 0, :]
151 |         KN0_x = self.layer(e1_x)[0, 0, -1, 0, :]
152 |         idx_K00_below_tol = torch.nonzero(torch.abs(K00_x) < self.tol)
153 |         if len(idx_K00_below_tol) > 0:
154 |             K00_x[idx_K00_below_tol] = self.tol
155 | 
156 |         KNN_x = self.layer(eN_x)[0, 0, -1, 0, :]
157 |         K0N_x = self.layer(eN_x)[0, 0, 0, 0, :]
158 |         idx_KNN_below_tol = torch.nonzero(torch.abs(KNN_x) < self.tol)
159 |         if len(idx_KNN_below_tol) > 0:
160 |             KNN_x[idx_KNN_below_tol] = self.tol
161 | 
162 | 
163 |         tilde_K00_x = K00_x - KN0_x*K0N_x/KNN_x
164 | 
165 |         idx_tilde_K00x_below_tol = torch.nonzero(torch.abs(tilde_K00_x) < self.tol)
166 |         if len(idx_tilde_K00x_below_tol) > 0:
167 |             tilde_K00_x[idx_tilde_K00x_below_tol] = self.tol
168 | 
169 |         Kx = self.layer(x)
170 |         T2_left_x = x.clone()
171 | 
172 |         T2_left_x[:, :, 0, :, :] = 2*x[:, :, 0, :, :] - (1/tilde_K00_x)*(
173 |             Kx[:, :, 0, :, :] - Kx[:, :, -1, :, :]*K0N_x/KNN_x
174 |             + K0N_x*x[:, :, -1, :, :])
175 | 
176 |         Ky = self.layer(T2_left_x)
177 |         T2_right_y = T2_left_x.clone()
178 | 
179 |         T2_right_y[:, :, -1, :, :] = 2*T2_left_x[:, :, -1, :, :] - Ky[:, :, -1, :, :]/KNN_x
180 | 
181 |         T = self.layer(T2_right_y)
182 |         return T
183 | 
184 | 
185 | class dirkernelcorrection1d(gaussian_elim_layer1d):
186 |     def __init__(self, layer):
187 |         super().__init__(layer=layer)
188 |         """
189 |         1D Corrected Layer. It modifies the kernel to enforce a Dirichlet boundary condition.
190 |         """
191 | 
192 |     def forward(self, x, left=DEFAULT_BDY, right=DEFAULT_BDY):
193 |         if left['val'] is None and right['val'] is None: # no bdy correction
194 |             return self.layer(x)
195 | 
196 |         T = self.gauss_elimination(x, left['val'], right['val'])
197 | 
198 |         T[..., 0] = left['val']
199 |         T[..., -1] = right['val']
200 | 
201 |         return T
202 | 
203 | 
204 | class dirkernelcorrection2d(gaussian_elim_layer2d):
205 |     def __init__(self, layer):
206 |         super().__init__(layer=layer)
207 |         """
208 |         2D Corrected Layer. It modifies the kernel to enforce a Dirichlet boundary condition.
209 |         """
210 | 
211 |     def forward(self, x, left=DEFAULT_BDY, right=DEFAULT_BDY,
212 |                 num_smooth: int = 5): # TODO: update default to no smoothing
213 | 
214 |         if left['val'] is None and right['val'] is None: # no bdy correction
215 |             return self.layer(x)
216 | 
217 |         T = self.gauss_elimination(x, left['val'], right['val'])
218 | 
219 |         T[:, :, 0, :] = left['val']
220 |         T[:, :, -1, :] = right['val']
221 | 
222 |         # Apply mollifier for smoothing at the boundary
223 |         T[:, :, 1:num_smooth, :] = (T[:, :, 0:num_smooth - 1, :] + T[:, :, 1:num_smooth, :]
224 |                                     + T[:, :, 2:num_smooth + 1, :]) / 3
225 |         T[:, :, -num_smooth:-1, :] = (T[:, :, -num_smooth - 1:-2, :] + T[:, :, -num_smooth:-1, :]
226 |                                       + T[:, :, -num_smooth+1:, :]) / 3
227 |         return T
228 | 
229 | 
230 | class dirkernelcorrection3d(gaussian_elim_layer3d):
231 |     def __init__(self, layer):
232 |         super().__init__(layer=layer)
233 |         """
234 |         3D Corrected Layer. It modifies the kernel to enforce a Dirichlet boundary condition.
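        Note that gauss_elimination() corrects the kernel along the x-faces only;
        the top/down (y-face) values are then written directly onto the output in
        forward().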
235 |         """
236 | 
237 |     def forward(self, x, left=DEFAULT_BDY, right=DEFAULT_BDY,
238 |                 top=DEFAULT_BDY, down=DEFAULT_BDY):
239 |         if (left['val'] is None and right['val'] is None and
240 |                 top['val'] is None and down['val'] is None): # no bdy correction
241 |             return self.layer(x)
242 | 
243 |         T = self.gauss_elimination(x)
244 | 
245 |         if left['val'] is not None: T[:, :, 0, :, :] = left['val']
246 |         if right['val'] is not None: T[:, :, -1, :, :] = right['val']
247 |         if top['val'] is not None: T[:, :, :, 0, :] = top['val']
248 |         if down['val'] is not None: T[:, :, :, -1, :] = down['val']
249 | 
250 |         return T
251 | 
252 | 
253 | class neukernelcorrection1d(gaussian_elim_layer1d):
254 |     def __init__(self, layer):
255 |         super().__init__(layer=layer)
256 |         """
257 |         1D Corrected Layer. It modifies the kernel to enforce a Neumann boundary condition.
258 |         """
259 | 
260 |     def forward(self, x, left=DEFAULT_BDY, right=DEFAULT_BDY):
261 | 
262 |         if left['val'] is None and right['val'] is None: # no bdy correction
263 |             return self.layer(x)
264 | 
265 |         T = self.gauss_elimination(x, None, right['val'])
266 | 
267 |         inv_c0_l, _diff_l = left['diff_fn'](T)
268 |         inv_c0_r, _diff_r = right['diff_fn'](T)
269 | 
270 |         T[..., 0] = left['val'] * inv_c0_l - _diff_l
271 |         T[..., -1] = right['val'] * inv_c0_r - _diff_r
272 | 
273 |         return T/2 # the 0.5 factor has been found to help the model learn
274 | 
275 | 
276 | class neukernelcorrection2d(gaussian_elim_layer2d):
277 |     def __init__(self, layer):
278 |         super().__init__(layer=layer)
279 |         """
280 |         2D Corrected Layer. It modifies the kernel to enforce a Neumann boundary condition.
281 |         """
282 | 
283 |     def forward(self, x, left=DEFAULT_BDY, right=DEFAULT_BDY):
284 | 
285 |         if left['val'] is None and right['val'] is None: # no bdy correction
286 |             return self.layer(x)
287 | 
288 |         # TODO: Fix the OOM error that occurs when left['val'] is passed in
289 |         T = self.gauss_elimination(x, left['val'], right['val'])
290 | 
291 |         inv_c0_l, _diff_l = left['diff_fn'](T.permute(0, 1, 3, 2))
292 |         inv_c0_r, _diff_r = right['diff_fn'](T.permute(0, 1, 3, 2))
293 | 
294 |         T[:, :, 0, :] = left['val'] * inv_c0_l - _diff_l
295 |         T[:, :, -1, :] = right['val'] * inv_c0_r - _diff_r
296 | 
297 |         return T/2 # the 0.5 factor has been found to help the model learn
298 | 
299 | 
300 | class perkernelcorrection1d(nn.Module):
301 |     def __init__(self, layer):
302 |         super().__init__()
303 | 
304 |         """
305 |         1D Corrected Layer. It modifies the kernel to enforce a periodic boundary condition.
306 |         """
307 | 
308 |         self.layer = layer
309 |         self.alpha = 0.5 # other values are possible as long as alpha + beta = 1 and alpha, beta >= 0
310 |         self.beta = 0.5
311 | 
312 |     def forward(self, x, *args):
313 |         x = self.layer(x)
314 |         bdy_val = self.alpha*x[..., 0] + self.beta*x[..., -1]
315 |         x[..., 0] = bdy_val
316 |         x[..., -1] = bdy_val
317 | 
318 |         return x
319 | 
320 | 
321 | class perkernelcorrection2d(nn.Module):
322 |     def __init__(self, layer):
323 |         super().__init__()
324 | 
325 |         """
326 |         2D Corrected Layer. It modifies the kernel to enforce a periodic boundary condition.
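        The two spatial endpoints of the layer output are averaged (alpha = beta = 0.5),
        so the returned tensor satisfies x[:, :, 0, :] == x[:, :, -1, :], i.e. the
        discrete periodicity condition.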
327 |         """
328 | 
329 |         self.layer = layer
330 |         self.alpha = 0.5 # other values are possible as long as alpha + beta = 1 and alpha, beta >= 0
331 |         self.beta = 0.5
332 | 
333 |     def forward(self, x, *args):
334 |         x = self.layer(x)
335 |         bdy_val = self.alpha * x[:, :, 0, :] + self.beta * x[:, :, -1, :] # shape (bs, channel_dim, N, T)
336 |         x[:, :, 0, :] = bdy_val
337 |         x[:, :, -1, :] = bdy_val
338 | 
339 |         return x
340 | 
341 | 
--------------------------------------------------------------------------------
/src/models/multi_step/BOON_2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from ..base import FNO2d
6 | from ..corrections import (
7 |     dirkernelcorrection2d,
8 |     neukernelcorrection2d,
9 |     perkernelcorrection2d,
10 | )
11 | from ..operator import BOON_2d
12 | 
13 | DEFAULT_BDY = {'val':None, 'diff_fn':None}
14 | 
15 | bdy_correction2d = {
16 |     'dirichlet':dirkernelcorrection2d,
17 |     'periodic': perkernelcorrection2d,
18 |     'neumann':neukernelcorrection2d,
19 | }
20 | 
21 | 
22 | # 1D space with time
23 | class BOON_FNO2d(BOON_2d):
24 |     def __init__(
25 |         self,
26 |         width,
27 |         base_no,
28 |         lb = 0,
29 |         ub = 1,
30 |         bdy_type = 'dirichlet'): # dirichlet, periodic, neumann
31 |         super().__init__(width=width, base_no=base_no, lb=lb, ub=ub,
32 |                          bdy_type=bdy_type)
33 | 
34 |         assert isinstance(base_no, FNO2d), (
35 |             'BOON-FNO2d only accepts FNO 2D as base_no')
36 | 
37 |         self.padding = 9
38 |         self.bdy_type = bdy_type
39 | 
40 |         self.conv_correction0 = bdy_correction2d[bdy_type](self.base_no.conv0)
41 |         self.conv_correction1 = bdy_correction2d[bdy_type](self.base_no.conv1)
42 |         self.conv_correction2 = bdy_correction2d[bdy_type](self.base_no.conv2)
43 |         self.conv_correction3 = bdy_correction2d[bdy_type](self.base_no.conv3)
44 | 
45 |         self.w_correction0 = bdy_correction2d[bdy_type](self.base_no.w0)
46 |         self.w_correction1 = bdy_correction2d[bdy_type](self.base_no.w1)
47 |         self.w_correction2 = bdy_correction2d[bdy_type](self.base_no.w2)
48 |         self.w_correction3 = bdy_correction2d[bdy_type](self.base_no.w3)
49 | 
50 | 
51 |     def forward(self, x, bdy_left=DEFAULT_BDY,
52 |                 bdy_right=DEFAULT_BDY):
53 | 
54 |         non_bdy = self.get_non_bdy(x, bdy_left, bdy_right)
55 |         grid = self.get_grid(x.shape, x.device)
56 | 
57 |         bdy_left_padded = bdy_left.copy()
58 |         bdy_right_padded = bdy_right.copy()
59 | 
60 |         if self.bdy_type != 'periodic':
61 |             bdy_left_padded['val'] = F.pad(bdy_left['val'], [0, self.padding])
62 |             bdy_right_padded['val'] = F.pad(bdy_right['val'], [0, self.padding])
63 | 
64 |         x = torch.cat((x, grid), dim=-1)
65 |         x = self.fc0(x)
66 | 
67 |         x = x.permute(0, 3, 1, 2)
68 | 
69 |         if self.bdy_type != 'periodic':
70 |             x = F.pad(x, [0, self.padding])
71 | 
72 |         x1 = self.conv_correction0(x, bdy_left_padded, bdy_right_padded)
73 |         x2 = self.w_correction0(x, bdy_left_padded, bdy_right_padded)
74 |         x = x1 + x2
75 |         x[:, :, non_bdy, :] = F.gelu(
76 |             x[:, :, non_bdy, :].clone())
77 | 
78 | 
79 |         x1 = self.conv_correction1(x, bdy_left_padded, bdy_right_padded)
80 |         x2 = self.w_correction1(x, bdy_left_padded, bdy_right_padded)
81 |         x = x1 + x2
82 |         x[:, :, non_bdy, :] = F.gelu(
83 |             x[:, :, non_bdy, :].clone())
84 | 
85 | 
86 |         x1 = self.conv_correction2(x, bdy_left_padded, bdy_right_padded)
87 |         x2 = self.w_correction2(x, bdy_left_padded, bdy_right_padded)
88 |         x = x1 + x2
89 |         x[:, :, non_bdy, :] = F.gelu(
90 |             x[:, :, non_bdy, :].clone())
91 | 
92 | 
93 |         x1 = self.conv_correction3(x, bdy_left_padded, bdy_right_padded)
94 |         x2 = self.w_correction3(x, bdy_left_padded, bdy_right_padded)
95 |         x = x1 + x2
96 | 
97 |         if self.bdy_type != 'periodic':
98 |             x = x[..., :-self.padding]
99 |         x = x.permute(0, 2, 3, 1)
100 | 
101 |         x = self.fc1(x)
102 |         x[:, non_bdy, :, :] = F.gelu(
103 |             x[:, non_bdy, :, :].clone())
104 | 
105 |         x = self.fc2(x)
106 |         x = self.strict_enforce_bdy(x, bdy_left, bdy_right)
107 | 
108 |         return x
109 | 
--------------------------------------------------------------------------------
/src/models/multi_step/BOON_3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from ..base import FNO3d
6 | from ..corrections import (
7 |     dirkernelcorrection3d,
8 |     # neukernelcorrection3d,
9 |     # perkernelcorrection3d,
10 | )
11 | from ..operator import BOON_3d
12 | 
13 | DEFAULT_BDY = {'val':None, 'diff_fn':None}
14 | 
15 | # TODO: Add 3d support for periodic and neumann
16 | bdy_correction3d = {
17 |     'dirichlet':dirkernelcorrection3d,
18 |     # 'periodic': perkernelcorrection3d,
19 |     # 'neumann':neukernelcorrection3d,
20 | }
21 | 
22 | 
23 | class BOON_FNO3d(BOON_3d):
24 |     def __init__(
25 |         self,
26 |         width,
27 |         base_no,
28 |         lb = 0,
29 |         ub = 1,
30 |         bdy_type = 'dirichlet'): # dirichlet, periodic, neumann
31 |         super().__init__(width=width, base_no=base_no, lb=lb, ub=ub,
32 |                          bdy_type=bdy_type)
33 | 
34 |         assert isinstance(base_no, FNO3d), (
35 |             'BOON-FNO3d only accepts FNO 3D as base_no')
36 | 
37 |         self.padding = 6
38 |         self.bdy_type = bdy_type
39 | 
40 |         self.conv_correction0 = bdy_correction3d[bdy_type](self.base_no.conv0)
41 |         self.conv_correction1 = bdy_correction3d[bdy_type](self.base_no.conv1)
42 |         self.conv_correction2 = bdy_correction3d[bdy_type](self.base_no.conv2)
43 |         self.conv_correction3 = bdy_correction3d[bdy_type](self.base_no.conv3)
44 | 
45 |         self.w_correction0 = bdy_correction3d[bdy_type](self.base_no.w0)
46 |         self.w_correction1 = bdy_correction3d[bdy_type](self.base_no.w1)
47 |         self.w_correction2 = bdy_correction3d[bdy_type](self.base_no.w2)
48 |         self.w_correction3 = bdy_correction3d[bdy_type](self.base_no.w3)
49 | 
50 | 
51 |     def forward(self, x, bdy_left=DEFAULT_BDY,
52 |                 bdy_right=DEFAULT_BDY,
53 |                 bdy_top=DEFAULT_BDY,
54 |                 bdy_down=DEFAULT_BDY,):
55 | 
56 |         non_bdy_x, non_bdy_y = self.get_non_bdy(x, bdy_left, bdy_right, bdy_top, bdy_down)
57 |         grid = self.get_grid(x.shape, x.device)
58 | 
59 |         bdy_left_padded = bdy_left.copy()
60 |         bdy_right_padded = bdy_right.copy()
61 |         bdy_top_padded = bdy_top.copy()
62 |         bdy_down_padded = bdy_down.copy()
63 | 
64 |         if self.bdy_type != 'periodic':
65 |             bdy_left_padded['val'] = F.pad(bdy_left['val'], [0, self.padding])
66 |             bdy_right_padded['val'] = F.pad(bdy_right['val'], [0, self.padding])
67 |             bdy_top_padded['val'] = F.pad(bdy_top['val'], [0, self.padding])
68 |             bdy_down_padded['val'] = F.pad(bdy_down['val'], [0, self.padding])
69 | 
70 | 
71 |         x = torch.cat((x, grid), dim=-1)
72 |         x = self.fc0(x)
73 | 
74 |         x = x.permute(0, 4, 1, 2, 3)
75 | 
76 |         if self.bdy_type != 'periodic':
77 |             x = F.pad(x, [0, self.padding])
78 | 
79 |         x1 = self.conv_correction0(x, bdy_left_padded, bdy_right_padded, bdy_top_padded, bdy_down_padded)
80 |         x2 = self.w_correction0(x, bdy_left_padded, bdy_right_padded, bdy_top_padded, bdy_down_padded)
81 |         x = x1 + x2
82 |         x[:, :, non_bdy_x[0]:non_bdy_x[1], non_bdy_y[0]:non_bdy_y[1], :] = F.gelu(
83 |             x[:, :, non_bdy_x[0]:non_bdy_x[1], non_bdy_y[0]:non_bdy_y[1], :].clone())
84 | 
85 | 
86 |         x1 = self.conv_correction1(x, bdy_left_padded, 
bdy_right_padded, bdy_top_padded, bdy_down_padded) 87 | x2 = self.w_correction1(x,bdy_left_padded, bdy_right_padded, bdy_top_padded, bdy_down_padded) 88 | x = x1 + x2 89 | x[:, :, non_bdy_x[0]:non_bdy_x[1], non_bdy_y[0]:non_bdy_y[1], :] = F.gelu( 90 | x[:, :, non_bdy_x[0]:non_bdy_x[1], non_bdy_y[0]:non_bdy_y[1], :].clone()) 91 | 92 | 93 | x1 = self.conv_correction2(x,bdy_left_padded, bdy_right_padded, bdy_top_padded, bdy_down_padded) 94 | x2 = self.w_correction2(x,bdy_left_padded, bdy_right_padded, bdy_top_padded, bdy_down_padded) 95 | x = x1 + x2 96 | x[:, :, non_bdy_x[0]:non_bdy_x[1], non_bdy_y[0]:non_bdy_y[1], :] = F.gelu( 97 | x[:, :, non_bdy_x[0]:non_bdy_x[1], non_bdy_y[0]:non_bdy_y[1], :].clone()) 98 | 99 | 100 | x1 = self.conv_correction3(x,bdy_left_padded, bdy_right_padded, bdy_top_padded, bdy_down_padded) 101 | x2 = self.w_correction3(x,bdy_left_padded, bdy_right_padded, bdy_top_padded, bdy_down_padded) 102 | x = x1 + x2 103 | 104 | if self.bdy_type != 'periodic': 105 | x = x[...,:-self.padding] 106 | x = x.permute(0, 2, 3, 4, 1) 107 | 108 | x = self.fc1(x) 109 | x[:, non_bdy_x[0]:non_bdy_x[1], non_bdy_y[0]:non_bdy_y[1], :, :] = F.gelu( 110 | x[:, non_bdy_x[0]:non_bdy_x[1], non_bdy_y[0]:non_bdy_y[1], :, :].clone()) 111 | 112 | x = self.fc2(x) 113 | x = self.strict_enforce_bdy(x, bdy_left, bdy_right, bdy_top, bdy_down) 114 | 115 | return x 116 | -------------------------------------------------------------------------------- /src/models/multi_step/__init__.py: -------------------------------------------------------------------------------- 1 | from .BOON_3d import BOON_FNO3d 2 | from .BOON_2d import BOON_FNO2d 3 | -------------------------------------------------------------------------------- /src/models/operator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | 6 | DEFAULT_BDY = {'val':None, 'diff_fn':None} 7 | 8 | 9 | class add_layers(nn.Module): 10 | def __init__(self, layers): 11 | super(add_layers, self).__init__() 12 | self.layers = nn.ModuleList(layers) 13 | 14 | def forward(self, x): 15 | return torch.stack( 16 | [layer(x) for layer in self.layers], dim=-1).sum(dim=-1) 17 | 18 | 19 | class BOON_1d(nn.Module): 20 | def __init__( 21 | self, 22 | base_no, 23 | width, 24 | lb, 25 | ub, 26 | bdy_type): 27 | super().__init__() 28 | 29 | self.base_no = base_no 30 | self.width = width 31 | self.lb = lb # lower value of the domain 32 | self.ub = ub # upper value of the domain 33 | self.bdy_type = bdy_type 34 | 35 | self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x) 36 | self.fc1 = nn.Linear(self.width, 128) 37 | self.fc2 = nn.Linear(128, 1) 38 | 39 | 40 | def get_non_bdy(self, x, left=None, right=None): 41 | N = x.shape[1] 42 | non_bdy = torch.arange(N) 43 | if self.bdy_type == 'periodic': 44 | # return np.s_[:] # torch autograd giving issue 45 | return non_bdy 46 | else: # TODO: Do we need left['val'] and right['val'] here? 
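            # The 'val' entries answer the TODO above: a boundary dict whose 'val' is a
            # tensor removes that endpoint from non_bdy, so the GELU nonlinearities in
            # forward() are applied to interior grid points only.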
47 |             left_interior = 1 if isinstance(left['val'], torch.Tensor) else None
48 |             right_interior = -1 if isinstance(right['val'], torch.Tensor) else None
49 | 
50 |             return non_bdy[left_interior:right_interior]
51 | 
52 | 
53 |     def strict_enforce_bdy(self, x, left=None, right=None):
54 |         if self.bdy_type == 'dirichlet':
55 |             if left['val'] is not None:
56 |                 x[:, 0] = left['val']
57 |             if right['val'] is not None:
58 |                 x[:, -1] = right['val']
59 |         elif self.bdy_type == 'periodic':
60 |             bdy_val = 0.5*x[:, 0, :] + 0.5*x[:, -1, :]
61 |             x[:, 0, :] = bdy_val
62 |             x[:, -1, :] = bdy_val
63 |         elif self.bdy_type == 'neumann':
64 |             if left['val'] is not None:
65 |                 inv_c0_l, _diff_l = left['diff_fn'](x.permute(0, 2, 1))
66 |                 x[:, 0, :] = left['val']*inv_c0_l - _diff_l
67 |             if right['val'] is not None:
68 |                 inv_c0_r, _diff_r = right['diff_fn'](x.permute(0, 2, 1))
69 |                 x[:, -1, :] = right['val']*inv_c0_r - _diff_r
70 |         else:
71 |             raise NotImplementedError
72 |         return x
73 | 
74 | 
75 |     def get_grid(self, shape, device):
76 |         batchsize, size_x = shape[0], shape[1]
77 |         gridx = torch.tensor(np.linspace(self.lb, self.ub, size_x), dtype=torch.float)
78 |         gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
79 |         return gridx.to(device)
80 | 
81 | 
82 | class BOON_2d(nn.Module):
83 |     def __init__(
84 |         self,
85 |         base_no,
86 |         width,
87 |         lb,
88 |         ub,
89 |         bdy_type):
90 |         super().__init__()
91 | 
92 |         self.base_no = base_no
93 |         self.width = width
94 |         self.lb = lb
95 |         self.ub = ub
96 |         self.bdy_type = bdy_type
97 | 
98 |         self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y)
99 |         self.fc1 = nn.Linear(self.width, 128)
100 |         self.fc2 = nn.Linear(128, 1)
101 | 
102 |     def get_non_bdy(self, x, left=None, right=None):
103 |         N = x.shape[1]
104 |         non_bdy = torch.arange(N)
105 |         if self.bdy_type == 'periodic':
106 |             # return np.s_[:] # torch autograd giving issue
107 |             return non_bdy
108 |         else:
109 |             left_interior = 1 if isinstance(left['val'], torch.Tensor) else None
110 |             right_interior = -1 if isinstance(right['val'], torch.Tensor) else None
111 | 
112 |             return non_bdy[left_interior:right_interior]
113 | 
114 |     def strict_enforce_bdy(self, x, left=None, right=None, num_smooth=5):
115 |         if self.bdy_type == 'dirichlet':
116 |             if left['val'] is not None:
117 |                 x[:, 0, :, 0] = left['val'][:, 0, :]
118 |             if right['val'] is not None:
119 |                 x[:, -1, :, 0] = right['val'][:, 0, :]
120 | 
121 |             # Apply mollifier
122 |             # TODO: Add custom smoothing stencil input
123 |             x[:, 1:num_smooth, :] = (x[:, 0:num_smooth - 1, :] + x[:, 1:num_smooth, :] + x[:, 2:num_smooth + 1, :]) / 3
124 |             x[:, -num_smooth:-1, :] = (x[:, -num_smooth - 1:-2, :] + x[:, -num_smooth:-1, :] + x[:, -num_smooth + 1:, :]) / 3
125 |         elif self.bdy_type == 'periodic':
126 |             bdy_val = 0.5 * x[:, 0, :, :] + 0.5 * x[:, -1, :, :]
127 |             x[:, 0, :, :] = bdy_val
128 |             x[:, -1, :, :] = bdy_val
129 |         elif self.bdy_type == 'neumann':
130 |             if left['val'] is not None:
131 |                 inv_c0_l, _diff_l = left['diff_fn'](x.permute(0, 2, 3, 1))
132 |                 x[:, 0, :, 0] = left['val'][:, 0, :] * inv_c0_l - _diff_l.squeeze()
133 |             if right['val'] is not None:
134 |                 inv_c0_r, _diff_r = right['diff_fn'](x.permute(0, 2, 3, 1))
135 |                 x[:, -1, :, 0] = right['val'][:, 0, :] * inv_c0_r - _diff_r.squeeze()
136 |         else:
137 |             raise NotImplementedError
138 |         return x
139 | 
140 |     def get_grid(self, shape, device):
141 |         batchsize, size_x, size_y = shape[0], shape[1], shape[2]
142 |         gridx = torch.tensor(np.linspace(self.lb, self.ub, size_x), dtype=torch.float)
143 |         gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
144 |         # 
TODO: Time dim may want to pass a separate lb_y, ub_y 145 | gridy = torch.tensor(np.linspace(self.lb, self.ub, size_y), dtype=torch.float) 146 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 147 | return torch.cat((gridx, gridy), dim=-1).to(device) 148 | 149 | 150 | class BOON_3d(nn.Module): 151 | def __init__( 152 | self, 153 | base_no, 154 | width, 155 | lb, 156 | ub, 157 | bdy_type): 158 | super().__init__() 159 | 160 | self.base_no = base_no 161 | self.width = width 162 | self.lb = lb 163 | self.ub = ub 164 | self.bdy_type = bdy_type 165 | 166 | self.fc0 = nn.Linear(4, self.width) # input channel is 4: (u(x), (x,y,z)) 167 | self.fc1 = nn.Linear(self.width, 128) 168 | self.fc2 = nn.Linear(128, 1) 169 | 170 | 171 | def get_non_bdy(self, x, left=None, right=None, top=None, down=None): 172 | Nx, Ny = x.shape[1], x.shape[2] 173 | if self.bdy_type == 'periodic': 174 | # return np.s_[:] # torch autograd giving issue 175 | return (0, Nx), (0, Ny) 176 | else: 177 | left_interior = 1 if isinstance(left['val'], torch.Tensor) else None 178 | right_interior = -1 if isinstance(right['val'], torch.Tensor) else None 179 | top_interior = 1 if isinstance(top['val'], torch.Tensor) else None 180 | down_interior = -1 if isinstance(down['val'], torch.Tensor) else None 181 | 182 | return (left_interior, right_interior), (top_interior, down_interior) 183 | 184 | 185 | def strict_enforce_bdy(self, x, left=None, right=None, top=None, down=None): 186 | if self.bdy_type == 'dirichlet': 187 | if left['val'] is not None: 188 | x[:, 0, :, :, 0]=left['val'][:,0,:,:] # squeeze expanded dimension at second axes 189 | if right['val'] is not None: 190 | x[:,-1, :, :, 0]=right['val'][:,0,:,:] 191 | if top['val'] is not None: 192 | x[:, :, 0, :, 0]=top['val'][:,0,:,:] 193 | if down['val'] is not None: 194 | x[:, :,-1, :, 0]=down['val'][:,0,:,:] 195 | elif self.bdy_type == 'periodic': 196 | bdy_val_x = 0.5*x[:,0] + 0.5*x[:,-1] 197 | x[:, 0] = bdy_val_x 198 | x[:,-1] = bdy_val_x 199 | bdy_val_y = 0.5*x[:,:,0] + 0.5*x[:,:,-1] 200 | x[:,:, 0] = bdy_val_y 201 | x[:,:,-1] = bdy_val_y 202 | 203 | elif self.bdy_type == 'neumann': 204 | raise NotImplementedError 205 | else: 206 | raise NotImplementedError 207 | return x 208 | 209 | 210 | def get_grid(self, shape, device): 211 | batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3] 212 | gridx = torch.tensor(np.linspace(self.lb, self.ub, size_x), dtype=torch.float) 213 | gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1]) 214 | # TODO: y dim may want to pass a separate lb_y, ub_y and lb_z, ub_z (time) 215 | gridy = torch.tensor(np.linspace(self.lb, self.ub, size_y), dtype=torch.float) 216 | gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1]) 217 | gridz = torch.tensor(np.linspace(self.lb, self.ub, size_z), dtype=torch.float) 218 | gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1]) 219 | return torch.cat((gridx, gridy, gridz), dim=-1).to(device) 220 | -------------------------------------------------------------------------------- /src/models/single_step/BOON_1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..base import FNO1d 6 | from ..corrections import ( 7 | dirkernelcorrection1d, 8 | neukernelcorrection1d, 9 | perkernelcorrection1d, 10 | ) 11 | from ..operator import BOON_1d 12 | 13 | DEFAULT_BDY = 
{'val':None, 'diff_fn':None}
14 | 
15 | bdy_correction1d = {
16 |     'dirichlet':dirkernelcorrection1d,
17 |     'periodic': perkernelcorrection1d,
18 |     'neumann':neukernelcorrection1d,
19 | }
20 | class BOON_FNO1d(BOON_1d):
21 |     def __init__(
22 |         self,
23 |         width,
24 |         base_no,
25 |         lb = 0,
26 |         ub = 1,
27 |         bdy_type = 'dirichlet'): # dirichlet, neumann, periodic
28 |         super().__init__(width=width, base_no=base_no, lb=lb, ub=ub,
29 |                          bdy_type=bdy_type)
30 | 
31 |         assert isinstance(base_no, FNO1d), (
32 |             'BOON-FNO1d only accepts FNO 1D as base_no')
33 | 
34 |         self.conv_correction0 = bdy_correction1d[bdy_type](self.base_no.conv0)
35 |         self.conv_correction1 = bdy_correction1d[bdy_type](self.base_no.conv1)
36 |         self.conv_correction2 = bdy_correction1d[bdy_type](self.base_no.conv2)
37 |         self.conv_correction3 = bdy_correction1d[bdy_type](self.base_no.conv3)
38 | 
39 |         self.w_correction0 = bdy_correction1d[bdy_type](self.base_no.w0)
40 |         self.w_correction1 = bdy_correction1d[bdy_type](self.base_no.w1)
41 |         self.w_correction2 = bdy_correction1d[bdy_type](self.base_no.w2)
42 |         self.w_correction3 = bdy_correction1d[bdy_type](self.base_no.w3)
43 | 
44 | 
45 |     def forward(self, x, bdy_left=DEFAULT_BDY,
46 |                 bdy_right=DEFAULT_BDY):
47 | 
48 |         non_bdy = self.get_non_bdy(x, bdy_left, bdy_right)
49 |         grid = self.get_grid(x.shape, x.device)
50 | 
51 |         x = torch.cat((x, grid), dim=-1)
52 |         x = self.fc0(x)
53 | 
54 |         x = x.permute(0, 2, 1)
55 | 
56 |         x1 = self.conv_correction0(x, bdy_left, bdy_right)
57 |         x2 = self.w_correction0(x, bdy_left, bdy_right)
58 |         x = x1 + x2
59 |         x[:, :, non_bdy] = F.gelu(x[:, :, non_bdy].clone())
60 | 
61 | 
62 |         x1 = self.conv_correction1(x, bdy_left, bdy_right)
63 |         x2 = self.w_correction1(x, bdy_left, bdy_right)
64 |         x = x1 + x2
65 |         x[:, :, non_bdy] = F.gelu(x[:, :, non_bdy].clone())
66 | 
67 | 
68 |         x1 = self.conv_correction2(x, bdy_left, bdy_right)
69 |         x2 = self.w_correction2(x, bdy_left, bdy_right)
70 |         x = x1 + x2
71 |         x[:, :, non_bdy] = F.gelu(x[:, :, non_bdy].clone())
72 | 
73 | 
74 |         x1 = self.conv_correction3(x, bdy_left, bdy_right)
75 |         x2 = self.w_correction3(x, bdy_left, bdy_right)
76 |         x = x1 + x2
77 |         x = x.permute(0, 2, 1)
78 | 
79 |         x = self.fc1(x)
80 |         x[:, non_bdy, :] = F.gelu(x[:, non_bdy, :].clone())
81 | 
82 |         x = self.fc2(x)
83 |         x = self.strict_enforce_bdy(x, bdy_left, bdy_right)
84 | 
85 |         return x
86 | 
--------------------------------------------------------------------------------
/src/models/single_step/__init__.py:
--------------------------------------------------------------------------------
1 | from .BOON_1d import BOON_FNO1d
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import scipy.io
4 | import h5py
5 | import gdown
6 | import glob
7 | import os
8 | from scipy.io import loadmat
9 | import torch.nn as nn
10 | 
11 | import operator
12 | from functools import reduce
13 | from functools import partial
14 | 
15 | #################################################
16 | #
17 | # Utilities
18 | #
19 | #################################################
20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21 | 
22 | 
23 | def compute_finite_diff(c_norm, x, loc=0):
24 |     """
25 |     loc = 0, -1 for left, right, respectively
26 |     returns: 1/c0 and the c0-normalized differencing term;
27 |     the first/last term of c_norm is 1/c0, 1/cN for loc = 0, -1, respectively.
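    e.g., for the one-sided first-order stencil u'(x0) = (x[1] - x[0])/h, i.e.
    coefficients c = (-1/h, 1/h), one would pass c_norm = (1/c0, c1/c0) = (-h, -1):
    the function then returns inv_c0 = -h and diff = -x[..., 1], and the update
    x[..., 0] = val*inv_c0 - diff used by the Neumann corrections recovers
    x[..., 0] = x[..., 1] - h*val.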
28 | """ 29 | len_c = len(c_norm) 30 | 31 | if loc==0: 32 | return c_norm[0], torch.sum(c_norm[1:] * x[...,1:len_c], dim=-1) 33 | elif loc==-1: 34 | return c_norm[-1], torch.sum(c_norm[:-1] * x[...,-len_c:-1], dim=-1) 35 | else: 36 | raise NotImplementedError 37 | 38 | 39 | # reading data 40 | class MatReader(object): 41 | def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): 42 | super(MatReader, self).__init__() 43 | 44 | self.to_torch = to_torch 45 | self.to_cuda = to_cuda 46 | self.to_float = to_float 47 | 48 | self.file_path = file_path 49 | 50 | self.data = None 51 | self.old_mat = None 52 | self._load_file() 53 | 54 | def _load_file(self): 55 | try: 56 | self.data = scipy.io.loadmat(self.file_path) 57 | self.old_mat = True 58 | except: 59 | self.data = h5py.File(self.file_path) 60 | self.old_mat = False 61 | 62 | def load_file(self, file_path): 63 | self.file_path = file_path 64 | self._load_file() 65 | 66 | def read_field(self, field): 67 | x = self.data[field] 68 | 69 | if not self.old_mat: 70 | x = x[()] 71 | x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1)) 72 | 73 | if self.to_float: 74 | x = x.astype(np.float32) 75 | 76 | if self.to_torch: 77 | x = torch.from_numpy(x) 78 | 79 | if self.to_cuda: 80 | x = x.cuda() 81 | 82 | return x 83 | 84 | def set_cuda(self, to_cuda): 85 | self.to_cuda = to_cuda 86 | 87 | def set_torch(self, to_torch): 88 | self.to_torch = to_torch 89 | 90 | def set_float(self, to_float): 91 | self.to_float = to_float 92 | 93 | # normalization, pointwise gaussian 94 | class UnitGaussianNormalizer(object): 95 | def __init__(self, x, eps=0.00001): 96 | super(UnitGaussianNormalizer, self).__init__() 97 | 98 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 99 | self.mean = torch.mean(x, 0) 100 | self.std = torch.std(x, 0) 101 | self.eps = eps 102 | 103 | def encode(self, x): 104 | x = (x - self.mean) / (self.std + self.eps) 105 | return x 106 | 107 | def decode(self, x, sample_idx=None): 108 | if sample_idx is None: 109 | std = self.std + self.eps # n 110 | mean = self.mean 111 | else: 112 | if len(self.mean.shape) == len(sample_idx[0].shape): 113 | std = self.std[sample_idx] + self.eps # batch*n 114 | mean = self.mean[sample_idx] 115 | if len(self.mean.shape) > len(sample_idx[0].shape): 116 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 117 | mean = self.mean[:,sample_idx] 118 | 119 | # x is in shape of batch*n or T*batch*n 120 | x = (x * std) + mean 121 | return x 122 | 123 | def cuda(self): 124 | self.mean = self.mean.cuda() 125 | self.std = self.std.cuda() 126 | 127 | def cpu(self): 128 | self.mean = self.mean.cpu() 129 | self.std = self.std.cpu() 130 | 131 | # normalization, Gaussian 132 | class GaussianNormalizer(object): 133 | def __init__(self, x, eps=0.00001): 134 | super(GaussianNormalizer, self).__init__() 135 | 136 | self.mean = torch.mean(x) 137 | self.std = torch.std(x) 138 | self.eps = eps 139 | 140 | def encode(self, x): 141 | x = (x - self.mean) / (self.std + self.eps) 142 | return x 143 | 144 | def decode(self, x, sample_idx=None): 145 | x = (x * (self.std + self.eps)) + self.mean 146 | return x 147 | 148 | def cuda(self): 149 | self.mean = self.mean.cuda() 150 | self.std = self.std.cuda() 151 | 152 | def cpu(self): 153 | self.mean = self.mean.cpu() 154 | self.std = self.std.cpu() 155 | 156 | 157 | # normalization, scaling by range 158 | class RangeNormalizer(object): 159 | def __init__(self, x, low=0.0, high=1.0): 160 | super(RangeNormalizer, self).__init__() 161 | mymin = torch.min(x, 
0)[0].view(-1)
162 |         mymax = torch.max(x, 0)[0].view(-1)
163 | 
164 |         self.a = (high - low)/(mymax - mymin)
165 |         self.b = -self.a*mymax + high
166 | 
167 |     def encode(self, x):
168 |         s = x.size()
169 |         x = x.view(s[0], -1)
170 |         x = self.a*x + self.b
171 |         x = x.view(s)
172 |         return x
173 | 
174 |     def decode(self, x):
175 |         s = x.size()
176 |         x = x.view(s[0], -1)
177 |         x = (x - self.b)/self.a
178 |         x = x.view(s)
179 |         return x
180 | 
181 | # loss function with rel/abs Lp loss
182 | class LpLoss(object):
183 |     def __init__(self, d=2, p=2, size_average=True, reduction=True):
184 |         super().__init__()
185 | 
186 |         # Dimension and Lp-norm type are positive
187 |         assert d > 0 and p > 0
188 | 
189 |         self.d = d
190 |         self.p = p
191 |         self.reduction = reduction
192 |         self.size_average = size_average
193 | 
194 |     def abs(self, x, y):
195 |         num_examples = x.size()[0]
196 | 
197 |         # Assume uniform mesh
198 |         h = 1.0 / (x.size()[1] - 1.0)
199 | 
200 |         all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples, -1) - y.view(num_examples, -1), self.p, 1)
201 | 
202 |         if self.reduction:
203 |             if self.size_average:
204 |                 return torch.mean(all_norms)
205 |             else:
206 |                 return torch.sum(all_norms)
207 | 
208 |         return all_norms
209 | 
210 |     def rel(self, x, y):
211 |         num_examples = x.size()[0]
212 | 
213 |         diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1)
214 |         y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1)
215 | 
216 |         if self.reduction:
217 |             if self.size_average:
218 |                 return torch.mean(diff_norms/y_norms)
219 |             else:
220 |                 return torch.sum(diff_norms/y_norms)
221 | 
222 |         return diff_norms/y_norms
223 | 
224 |     def __call__(self, x, y):
225 |         return self.rel(x, y)
226 | 
227 | # Sobolev norm (HS norm)
228 | # where we also compare the numerical derivatives between the output and target
229 | class HsLoss(object):
230 |     def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True):
231 |         super(HsLoss, self).__init__()
232 | 
233 |         # Dimension and Lp-norm type are positive
234 |         assert d > 0 and p > 0
235 | 
236 |         self.d = d
237 |         self.p = p
238 |         self.k = k
239 |         self.balanced = group
240 |         self.reduction = reduction
241 |         self.size_average = size_average
242 | 
243 |         if a is None:
244 |             a = [1,] * k
245 |         self.a = a
246 | 
247 |     def rel(self, x, y):
248 |         num_examples = x.size()[0]
249 |         diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1)
250 |         y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1)
251 |         if self.reduction:
252 |             if self.size_average:
253 |                 return torch.mean(diff_norms/y_norms)
254 |             else:
255 |                 return torch.sum(diff_norms/y_norms)
256 |         return diff_norms/y_norms
257 | 
258 |     def __call__(self, x, y, a=None):
259 |         nx = x.size()[1]
260 |         ny = x.size()[2]
261 |         k = self.k
262 |         balanced = self.balanced
263 |         a = self.a
264 |         x = x.view(x.shape[0], nx, ny, -1)
265 |         y = y.view(y.shape[0], nx, ny, -1)
266 | 
267 |         k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1), torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx, 1).repeat(1, ny)
268 |         k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1), torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1, ny).repeat(nx, 1)
269 |         k_x = torch.abs(k_x).reshape(1, nx, ny, 1).to(x.device)
270 |         k_y = torch.abs(k_y).reshape(1, nx, ny, 1).to(x.device)
271 | 
272 |         x = torch.fft.fftn(x, dim=[1, 2])
273 |         y = torch.fft.fftn(y, dim=[1, 2])
274 | 
275 |         if not balanced:
276 |             weight = 1
277 |             if k >= 1:
278 |                 weight += a[0]**2 * (k_x**2 + k_y**2)
279 | 
if k >= 2: 280 | weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 281 | weight = torch.sqrt(weight) 282 | loss = self.rel(x*weight, y*weight) 283 | else: 284 | loss = self.rel(x, y) 285 | if k >= 1: 286 | weight = a[0] * torch.sqrt(k_x**2 + k_y**2) 287 | loss += self.rel(x*weight, y*weight) 288 | if k >= 2: 289 | weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 290 | loss += self.rel(x*weight, y*weight) 291 | loss = loss / (k+1) 292 | 293 | return loss 294 | 295 | # A simple feedforward neural network 296 | class DenseNet(torch.nn.Module): 297 | def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False): 298 | super(DenseNet, self).__init__() 299 | 300 | self.n_layers = len(layers) - 1 301 | 302 | assert self.n_layers >= 1 303 | 304 | self.layers = nn.ModuleList() 305 | 306 | for j in range(self.n_layers): 307 | self.layers.append(nn.Linear(layers[j], layers[j+1])) 308 | 309 | if j != self.n_layers - 1: 310 | if normalize: 311 | self.layers.append(nn.BatchNorm1d(layers[j+1])) 312 | 313 | self.layers.append(nonlinearity()) 314 | 315 | if out_nonlinearity is not None: 316 | self.layers.append(out_nonlinearity()) 317 | 318 | def forward(self, x): 319 | for _, l in enumerate(self.layers): 320 | x = l(x) 321 | 322 | return x 323 | 324 | 325 | # print the number of parameters 326 | def count_params(model): 327 | c = 0 328 | for p in list(model.parameters()): 329 | c += reduce(operator.mul, 330 | list(p.size()+(2,) if p.is_complex() else p.size())) 331 | return c 332 | 333 | 334 | class DataDownloader(): 335 | dir_id = { 336 | "Heat_Neu_1D":"1xIe-lPFk7z91CeEZtSuD47H7ajxaJC2j", 337 | "Burgers_Per_1D":"1wr1rUpT4jSYJNEk81mK4elm5BzihIQ3b", 338 | "Burgers_Dir_1D":"1Ehxxj6751AzQChF6gh3TggCtWSPRLPOT", 339 | "Stokes_Dir_1D":"1ILpBKD__iddtm-CEp2j_UqjXkcM3vOXT", 340 | "Heat_Neu_2D":"1UwNzd40DiStP0GNn9VO0yytMDxQEc1y3", 341 | "Burgers_Per_2D":"1Wvnx-8_MJG9bUhrNnMQfOZsdkiro2x1v", 342 | "Burgers_Dir_2D":"1v4J5T2OAFgPOjEawqIyUqZVsISBZwfSg", 343 | "Stokes_Dir_2D":"1XdKe_4_TeEpoF3kYMDRC1-osDALg7D-N", 344 | "NV_Dir_3D":"125T9UvHIgmvabtxDy2a1hqXA1AwFREdv", 345 | "Wave_Neu_3D":"1JFlvlVpFAvRSx-RMbzTm6F2NvitNQtY5", 346 | } 347 | 348 | def __init__( 349 | self, 350 | output = 'Data', 351 | quiet = False, 352 | ): 353 | self.output = output 354 | self.quiet = quiet 355 | 356 | def download(self, id, tag=None): 357 | if id not in self.dir_id: 358 | assert 0, "Data ID not present in the repo. Check again! Supported ones are " \ 359 | f"[{','.join(self.dir_id.keys())}]" 360 | 361 | if tag is None: 362 | gdown.download_folder(id=self.dir_id[id], output=os.path.join(self.output, id), quiet=self.quiet) 363 | else: 364 | id_and_names = self.dir_list(self.dir_id[id]) 365 | located = False 366 | for child_id, child_name, child_type in id_and_names: 367 | if tag in child_name: 368 | located=True 369 | break 370 | if located: 371 | parent = os.path.join(self.output, id) 372 | if not os.path.exists(parent): 373 | os.makedirs(parent) 374 | gdown.download(id=child_id, output=os.path.join(parent, child_name), quiet=self.quiet) 375 | else: 376 | assert 0, f"Provided tag:{tag} not present in the {id}" 377 | 378 | 379 | def locate(self, id, tag): 380 | files = glob.glob(os.path.join(self.output, id, '*.mat')) 381 | located = False 382 | for file in files: 383 | if tag in file: 384 | located = True 385 | break 386 | if located: 387 | return loadmat(file) 388 | else: 389 | assert 0, f"No file found with tag: {tag}." 
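    # Typical usage (sketch; the dataset tags depend on the files stored in each
    # Google Drive folder, so `tag` below is a placeholder):
    #
    #     dl = DataDownloader(output='Data')
    #     dl.download('Heat_Neu_1D')            # fetch a whole dataset folder
    #     mat = dl.locate('Heat_Neu_1D', tag)   # load a downloaded .mat file by tag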
390 | 
391 | 
392 |     def dir_list(self, id):
393 |         from gdown.download_folder import _parse_google_drive_file, _get_session
394 |         import requests, sys  # used in the error handling below
395 |         url = "https://drive.google.com/drive/folders/{id}".format(id=id)
396 |         sess = _get_session(proxy=None, use_cookies=True)
397 |         # canonicalize the language into English
398 | 
399 |         if "?" in url:
400 |             url += "&hl=en"
401 |         else:
402 |             url += "?hl=en"
403 | 
404 |         try:
405 |             res = sess.get(url, verify=True)
406 |         except requests.exceptions.ProxyError as e:
407 |             print(
408 |                 "An error has occurred using proxy:", sess.proxies, file=sys.stderr
409 |             )
410 |             print(e, file=sys.stderr)
411 |             return None
412 | 
413 |         if res.status_code != 200:
414 |             return None
415 | 
416 |         gdrive_file, id_name_type_iter = _parse_google_drive_file(
417 |             url=url,
418 |             content=res.text,
419 |         )
420 |         return id_name_type_iter
--------------------------------------------------------------------------------
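# Example usage (illustrative sketch: the module paths, sizes, and hyperparameters
# below are arbitrary choices, not fixed by the repo; only the constructor and
# forward() signatures are taken from the code above):
#
#     import torch
#     from src.models.base import FNO3d
#     from src.models.multi_step import BOON_FNO3d
#
#     base = FNO3d(modes1=8, modes2=8, modes3=8, width=20)
#     model = BOON_FNO3d(width=20, base_no=base, bdy_type='dirichlet')
#
#     x = torch.randn(4, 64, 64, 40, 1)                  # (batch, x, y, t, c=1)
#     bc = lambda: {'val': torch.zeros(4, 1, 64, 40),    # boundary dicts mirror DEFAULT_BDY
#                   'diff_fn': None}
#     out = model(x, bdy_left=bc(), bdy_right=bc(),
#                 bdy_top=bc(), bdy_down=bc())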