├── .gitattributes ├── .github └── workflows │ ├── pytest.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── LICENSE-MIT ├── README.md ├── examples ├── Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb ├── Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb ├── McWilliams2d.svg ├── check_SFNO_shapes.py ├── ex2_FNO3d_train_normalized.ipynb ├── ex2_SFNO_5ep_spectra.ipynb ├── ex2_SFNO_finetune_McWilliams2d.ipynb ├── ex2_SFNO_finetune_fnodata.ipynb ├── ex2_SFNO_train.ipynb └── ex2_SFNO_train_fnodata.ipynb ├── fno ├── README.md ├── __init__.py ├── base.py ├── data_gen │ ├── __init__.py │ ├── data_gen_Kolmogorov2d.py │ ├── data_gen_McWilliams2d.py │ ├── data_gen_fno.py │ ├── data_gen_fno_legacy.py │ ├── data_utils.py │ ├── grf.py │ └── solvers.py ├── datasets.py ├── finetune.py ├── fno3d.py ├── losses.py ├── pipeline.py ├── sfno.py ├── sfno_pytest.py ├── train.py ├── utils.py └── visualizations.py ├── requirements.txt ├── setup.py └── torch_cfd ├── README.md ├── __init__.py ├── advection.py ├── boundaries.py ├── equations.py ├── finite_differences.py ├── forcings.py ├── fvm.py ├── grids.py ├── initial_conditions.py ├── pressure.py ├── spectral.py ├── tensor_utils.py ├── test_utils.py └── tests ├── __init__.py ├── test_advection.py ├── test_boundaries.py ├── test_finite_differences.py └── test_grids.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 🧪 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | test: 8 | name: Run pytest 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: "3.10" 18 | 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -r requirements.txt 23 | 24 | - name: Extract tag name 25 | id: tag 26 | run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> $GITHUB_OUTPUT 27 | 28 | - name: Update version in setup.py 29 | run: >- 30 | sed -i "s/{{VERSION_PLACEHOLDER}}/${{ steps.tag.outputs.TAG_NAME }}/g" setup.py 31 | 32 | - name: Install build 33 | run: python -m pip install build 34 | 35 | - name: Run pytest 36 | run: | 37 | pytest --pyargs torch_cfd --verbose -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Publish Python 🐍 distributions 📦 to PyPI 10 | 11 | on: 12 | push: 13 | tags: 14 | - '*' 15 | 16 | jobs: 17 | build-n-publish: 18 | name: Build and publish Python 🐍 distributions 📦 to PyPI 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v2 22 | 23 | - name: Set up Python 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: "3.10" 27 | 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install -r requirements.txt 32 | 33 | - name: Extract tag name 34 | id: tag 35 | run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> $GITHUB_OUTPUT 36 | 37 | - name: Update version in setup.py 38 | run: >- 39 | sed -i "s/{{VERSION_PLACEHOLDER}}/${{ steps.tag.outputs.TAG_NAME }}/g" setup.py 40 | 41 | - name: Install build 42 | run: python -m pip install build 43 | 44 | - name: Build a binary wheel and a source tarball 45 | run: python -m build --sdist --wheel --outdir dist/ 46 | 47 | - name: Publish distribution 📦 to PyPI 48 | uses: pypa/gh-action-pypi-publish@release/v1 49 | with: 50 | user: __token__ 51 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # data files 163 | *.pt 164 | *.pth 165 | *.pkl 166 | 167 | # mac os 168 | .DS_Store 169 | *.code-workspace 170 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2021] Google LLC 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Shuhao Cao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Operator-Assisted Computational Fluid Dynamics in PyTorch 2 | [![Python 3.10](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org/downloads/release/python-3100) ![CFD tests](https://github.com/scaomath/torch-cfd/actions/workflows/pytest.yml/badge.svg) 3 | 4 | ![A decaying turbulence (McWilliams 1984)](examples/McWilliams2d.svg) 5 | 6 | ## Summary 7 | 8 | This repository contains mainly two parts: 9 | 10 | ### Part I: a native PyTorch port of [Google's Computational Fluid Dynamics package in Jax](https://github.com/google/jax-cfd) 11 | The main changes are documented in the `README.md` under the [`torch_cfd` directory](./torch_cfd/). The most significant changes in all routines include: 12 | - Routines that rely on the functional programming of Jax have been rewritten in PyTorch's tensor-in-tensor-out style, which is arguably more debugging-friendly, as one can view intermediate tensors and set breakpoints in the VS Code debugger. 13 | - Functions and operators are in general implemented as `nn.Module`, following a factory-template pattern. 14 | - Jax-cfd's `funcutils.trajectory` function supports tracking only one field variable (vorticity or velocity). For this port, the computation and tracking of extra fields are made more accessible, such as the time derivative $\partial_t\mathbf{u}_h$ and the PDE residual $R(\mathbf{u}_h):=\mathbf{f}-\partial_t \mathbf{u}_h-(\mathbf{u}_h\cdot\nabla)\mathbf{u}_h + \nu \Delta \mathbf{u}_h$. 15 | - All ops take the batch dimension of tensors `(b, *, n, m)` into consideration regardless of the `*` dimensions, for example, `(b, T, C, n, m)`, which is similar to PyTorch behavior; in Google Research's original Jax-CFD package, only a single trajectory is implemented. The stencil operations generally start from the last dimension using negative indexing, following `torch.nn.functional.pad`'s behavior (see the sketch below). 16 | 
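A minimal illustration of this batch-agnostic convention (a sketch for this README only, not `torch_cfd`'s actual API; the helper name `central_diff_x` and the sizes below are made up for the demo):

```python
import torch

def central_diff_x(w: torch.Tensor, dx: float) -> torch.Tensor:
    # Periodic central difference along the last dimension; because the
    # stencil indexes from the last dimension (dim=-1), the same code handles
    # inputs shaped (n, m), (b, n, m), or (b, T, C, n, m) without special-casing.
    return (w.roll(-1, dims=-1) - w.roll(1, dims=-1)) / (2 * dx)

w = torch.randn(4, 10, 1, 64, 64)  # (b, T, C, n, m)
print(central_diff_x(w, dx=2 * torch.pi / 64).shape)  # torch.Size([4, 10, 1, 64, 64])
```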
17 | ### Part II: Spectral-Refiner: Neural Operator-Assisted Navier-Stokes Equations simulator. 18 | - The **Spatiotemporal Fourier Neural Operator** (SFNO) is a spacetime tensor-to-tensor learner (or trajectory-to-trajectory), available in the [`fno` directory](./fno). Different components of FNO have been re-implemented, keeping the conciseness of the original implementation while allowing modern extensions. We draw inspiration from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator), [Transformers-based neural operators](https://github.com/thuml/Neural-Solver-Library), as well as Temam's book on functional analysis for the NSE. 19 | - Major architectural changes: learnable spatiotemporal positional encodings, layernorm to replace a hard-coded global Gaussian normalizer, and many others. For more details please see [the documentation of the `SFNO` class](./fno/sfno.py#L485). 20 | - Data generation for the meta-example of the isotropic turbulence in [McWilliams1984]. After the warmup phase, the energy spectra match the inverse cascade of Kolmogorov flow in a periodic box. 21 | - Pipelines for the *a posteriori* error estimation to fine-tune the SFNO to reach the scientific-computing level of accuracy ($\le 10^{-6}$) in the Bochner norm, using FLOPs on par with a single evaluation and only a fraction of the FLOPs of a single `.backward()`. 22 | - [Examples](#examples) can be found below. 23 | 24 | [McWilliams1984]: McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. *Journal of Fluid Mechanics*, 146, 21-43. 25 | 26 | ## Installation 27 | To install `torch-cfd`'s current release, simply do: 28 | ```bash 29 | pip install torch-cfd 30 | ``` 31 | 32 | If one wants to play with the neural operator part, it is recommended to clone this repo and play with it locally by creating a venv using [`requirements.txt`](./requirements.txt). Note: even if you do not want to install the dependencies, using PyTorch version >=2.0.0 is recommended for its broadcasting semantics. 33 | ```bash 34 | python3.11 -m venv venv 35 | source venv/bin/activate 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | ## Data 40 | The data are available at [https://huggingface.co/datasets/scaomath/navier-stokes-dataset](https://huggingface.co/datasets/scaomath/navier-stokes-dataset). 41 | Data generation instructions are available in the [SFNO folder](./fno). 42 | 43 | 44 | ## Examples 45 | - Demos of different simulation setups: 46 | - [2D simulation with a pseudo-spectral solver](./examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb) 47 | - [2D simulation with a finite volume solver](./examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb) 48 | - Demos of Spatiotemporal FNO's training and evaluation using the neural operator-assisted fluid simulation pipelines 49 | - [Training of SFNO for only 15 epochs for the isotropic turbulence example](./examples/ex2_SFNO_train.ipynb) 50 | - [Training of SFNO for only ***10*** epochs with 1k samples, reaching a `1e-2` level of relative error](./examples/ex2_SFNO_train_fnodata.ipynb) using the data in the FNO paper, which, to the best of our knowledge, no other operator learner can achieve in <100 epochs in the small-data regime.
51 | - [Fine-tuning of SFNO on a `256x256` grid for only 50 ADAM iterations to reach `1e-6` residual in the functional norm using FNO data](./examples/ex2_SFNO_finetune_fnodata.ipynb) 52 | - [Fine-tuning of SFNO on the `256x256` grid for the McWilliams 2d isotropic turbulence](./examples/ex2_SFNO_finetune_McWilliams2d.ipynb) 53 | - [Training of SFNO for only 5 epochs to match the inverse cascade of Kolmogorov flow](./examples/ex2_SFNO_5ep_spectra.ipynb) 54 | - [Baseline of FNO3d for fixed step size that requires preloading a normalizer](./examples/ex2_FNO3d_train_normalized.ipynb) 55 | 56 | ## Licenses 57 | The Apache 2.0 License in the root folder applies to the `torch-cfd` folder of the repo and is inherited from Google's original license file for `Jax-cfd`. The `fno` folder has the MIT license inherited from [NVIDIA's Neural Operator repo](https://github.com/neuraloperator/neuraloperator). Note: the license(s) in the subfolders take precedence. 58 | 59 | ## Contributions 60 | PRs are welcome. Currently, the port of `torch-cfd` includes: 61 | - The pseudospectral method for vorticity, which uses anti-aliasing filtering techniques for nonlinear terms to maintain stability. 62 | - The finite volume method on a MAC grid for velocity, using the projection scheme to impose the divergence-free condition. 63 | - Temporal discretization: currently only RK4, which uses explicit time-stepping for advection and either implicit or explicit time-stepping for diffusion. 64 | - Boundary conditions: only periodic boundary conditions. 65 | 66 | ## Reference 67 | 68 | If you would like to use `torch-cfd`, please use the following [paper](https://arxiv.org/abs/2405.17211) as the citation. 69 | 70 | ```bibtex 71 | @inproceedings{2025SpectralRefiner, 72 | title = {Spectral-Refiner: Accurate Fine-Tuning of Spatiotemporal Fourier Neural Operator for Turbulent Flows}, 73 | author = {Shuhao Cao and Francesco Brarda and Ruipeng Li and Yuanzhe Xi}, 74 | booktitle = {The Thirteenth International Conference on Learning Representations}, 75 | year = {2025}, 76 | url = {https://openreview.net/forum?id=MKP1g8wU0P}, 77 | eprint = {arXiv:2405.17211}, 78 | } 79 | ``` 80 | 81 | ## Acknowledgments 82 | I am grateful for the support from [Long Chen (UC Irvine)](https://github.com/lyc102/ifem) and 83 | [Ludmil Zikatanov (Penn State)](https://github.com/HAZmathTeam/hazmath) over the years, and their efforts in open-sourcing scientific computing codes. I also appreciate the support from the National Science Foundation (NSF) to junior researchers. I also want to thank the University of Missouri for the free A6000 credits on the SSE ML cluster. 84 | 85 | (Added after `0.2.0`) I also want to acknowledge the University of Missouri's OpenAI Enterprise API key. After version `0.1.0`, I began prompting VSCode Copilot with existing code (using the OpenAI Enterprise API), which arguably improves the efficiency of the "porting->debugging->refactoring" cycle significantly; e.g., Copilot helps design unit tests and `.vscode/launch.json` for debugging. For details of Copilot's suggestions on code refactoring, please see the [README.md](./torch_cfd/README.md) in the `torch_cfd` folder. 86 | 87 | For individual papers' acknowledgments please see [here](./fno/README.md).
88 | -------------------------------------------------------------------------------- /examples/check_SFNO_shapes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fno.sfno import SFNO 3 | 4 | 5 | if __name__ == "__main__": 6 | """ 7 | testing arbitrary-size inference for both the 8 | spatial and temporal dimensions of SFNO 9 | """ 10 | modes = 8 11 | modes_t = 2 12 | width = 10 13 | bsz = 5 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | sizes = [(n, n, n_t) for (n, n_t) in zip([64, 128, 256], [5, 10, 20])] 16 | model = SFNO(modes, modes, modes_t, width, 17 | latent_steps=3).to(device) 18 | x = torch.randn(bsz, *sizes[0]).to(device) 19 | _ = model(x) 20 | 21 | try: 22 | from torchinfo import summary 23 | 24 | """ 25 | torchinfo has not resolved the complex number problem 26 | """ 27 | summary(model, input_size=(bsz, *sizes[-1])) 28 | except ImportError: 29 | raise ImportError( 30 | "torchinfo is not installed, please install it to get the model summary" 31 | ) 32 | del model 33 | 34 | print("\n" * 3) 35 | for k, size in enumerate(sizes): 36 | torch.cuda.empty_cache() 37 | model = SFNO(modes, modes, modes_t, width, latent_steps=3).to(device) 38 | model.add_latent_hook("activations") 39 | x = torch.randn(bsz, *size).to(device) 40 | pred = model(x) 41 | print(f"\n\ninput shape: {list(x.size())}") 42 | print(f"output shape: {list(pred.size())}") 43 | for name, v in model.latent_tensors.items(): 44 | print(name, list(v.shape)) 45 | del model 46 | 47 | print("\n") 48 | # test evaluation speed 49 | from time import time 50 | 51 | torch.cuda.empty_cache() 52 | model = SFNO(modes, modes, modes_t, width, latent_steps=3).to(device) 53 | model.eval() 54 | x = torch.randn(bsz, *sizes[1]).to(device) 55 | start_time = time() 56 | for _ in range(100): 57 | pred = model(x) 58 | end_time = time() 59 | print(f"Average eval time: {(end_time - start_time) / 100:.6f} seconds") 60 | del model -------------------------------------------------------------------------------- /fno/README.md: -------------------------------------------------------------------------------- 1 | # Spatiotemporal Fourier Neural Operator (SFNO) 2 | This is a new concise implementation of the Fourier Neural Operator; see [`base.py`](./base.py#L172) for the template class. 3 | 4 | ## Learning maps between Bochner spaces 5 | SFNO can now learn a `trajectory-to-trajectory` map that takes an arbitrary-length trajectory as input and outputs an arbitrary-length trajectory (if the output length is not specified, it defaults to the input length). The tests on its trajectory-to-trajectory shapes can be found in [`sfno_pytest.py`](sfno_pytest.py) and [`check_SFNO_shapes.py`](../examples/check_SFNO_shapes.py). 6 | 7 | ## Data generation 8 | 9 | ### FNO NSE dataset 10 | Generate the original FNO data where the right-hand side is a fixed forcing $0.1(\sin(2\pi(x+y))+\cos(2\pi(x+y)))$.
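As a quick sanity check, here is a sketch of this fixed forcing evaluated on a uniform periodic grid (not part of the data-generation CLI; the grid size and the unit-torus domain are assumptions matching the setup in the FNO paper):

```python
import torch

n = 64
x = torch.linspace(0, 1, n + 1)[:-1]  # uniform grid on [0, 1), periodic
X, Y = torch.meshgrid(x, x, indexing="ij")
f = 0.1 * (torch.sin(2 * torch.pi * (X + Y)) + torch.cos(2 * torch.pi * (X + Y)))
print(f.shape, f.abs().max())  # torch.Size([64, 64]); max |f| <= 0.1 * sqrt(2)
```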
11 | 12 | - Training and validation data (training using the first 1152 and validation using the last 128) for the paper 13 | ```bash 14 | >>> python data_gen_fno.py --num-samples 1280 --batch-size 256 --grid-size 256 --subsample 4 --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --visc 1e-3 15 | ``` 16 | 17 | - Test data on `256x256` grid 18 | ```bash 19 | >>> python data_gen_fno.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --replicable-init --seed 42 20 | ``` 21 | 22 | ### McWilliams 2d dataset 23 | 24 | Generate the isotropic turbulence in [1] with the inverse-cascade spectral signature that Kolmogorov discovered. 25 | 26 | [1]: McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. Journal of Fluid Mechanics, 146, 21-43. 27 | 28 | - Training dataset: 29 | ```bash 30 | >>> python data_gen_McWilliams2d.py --num-samples 1152 --batch-size 128 --grid-size 256 --subsample 4 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" 31 | ``` 32 | 33 | - Testing dataset for plotting the enstrophy spectrum in the paper 34 | ```bash 35 | >>> python data_gen_McWilliams2d.py --num-samples 16 --grid-size 256 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" 36 | ``` 37 | 38 | 39 | ## Training and evaluation scripts 40 | 41 | ### VSCode workspace for development 42 | Please add the following to your VSCode workspace settings: 43 | ```json 44 | "settings": { 45 | "terminal.integrated.env.osx": {"PYTHONPATH": "${workspaceFolder}"}, 46 | "terminal.integrated.env.linux": {"PYTHONPATH": "${workspaceFolder}"}, 47 | "jupyter.notebookFileRoot": "${workspaceFolder}", 48 | } 49 | ``` 50 | 51 | 52 | ### Testing the arbitrary input and output discretization sizes (including time) 53 | Run the part below `__name__ == "__main__"` in [`sfno.py`](sfno.py): 54 | ```bash 55 | >>> python fno/sfno.py 56 | ``` 57 | 58 | ### FNO NSE dataset 59 | Train SFNO for the FNO dataset: 60 | ```bash 61 | >>> python train.py --example "fno" --num-samples 1152 --num-val-samples 128 --epochs 10 --width 20 --modes 12 --modes-t 5 --time-steps 10 --out-time-steps 40 --beta 0.02 62 | ``` 63 | 64 | Evaluate the model only and plot the predictions. Note that for evaluation there is no need to specify the out_steps when initializing the model. One should get around `1e-2` relative accuracy against the ground truth in 10 epochs of training; if this level is not reached, something must be wrong with the setup. 65 | ```bash 66 | >>> python train.py --example "fno" --eval-only --epochs 10 --width 20 --modes 12 --modes-t 5 --beta 0.02 --out-time-steps 40 --demo-plots 10 67 | ``` 68 | 69 | ### The McWilliams 2d dataset 70 | The isotropic turbulence with the inverse-cascade $-5/3$ spectral decay signature. 71 | 72 | Train SFNO for the McWilliams2d dataset. One should get around `6e-2` relative accuracy against the ground truth after 15 epochs of training. 73 | ```bash 74 | >>> python train.py --example "McWilliams2d" --epochs 15 --width 10 --modes 32 --modes-t 5 --beta -0.01 75 | ``` 76 | 77 | Evaluation for the McWilliams2d dataset: note that there will be aliasing errors caused by super-resolution when the solution is not smooth. 78 | ```bash 79 | >>> python train.py --example "McWilliams2d" --eval-only --width 10 --modes 32 --modes-t 5 --beta -0.01 --demo-plots 10 80 | ``` 81 | 82 | ## Licenses 83 | This folder has the MIT license.
Note: the license(s) in the subfolders take precedence. 84 | 85 | ## Acknowledgments 86 | The research of Brarda and Xi is supported by the National Science Foundation award DMS-2208412. 87 | The work of Li was performed under the auspices of 88 | the U.S. Department of Energy by Lawrence Livermore National Laboratory under Contract DE-AC52-07NA27344 and was supported by the LLNL-LDRD program under Project No. 24ERD033. The research of Cao is also supported in part by the National Science Foundation DMS-2309778. -------------------------------------------------------------------------------- /fno/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaomath/torch-cfd/6abe65ad8c49b31bc09c661974c7c4b120ab6729/fno/__init__.py -------------------------------------------------------------------------------- /fno/base.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2025 Shuhao Cao 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
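"""Building blocks shared by the FNO models in this folder: LayerNormnd (GroupNorm used as a LayerNorm), PointwiseFFN (a pointwise channel-expansion FFN), the SpectralConv template (FFT, complex matmul on the retained modes, inverse FFT), and the FNOBase model template."""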
9 | from __future__ import annotations 10 | 11 | from abc import abstractmethod 12 | 13 | from copy import deepcopy 14 | 15 | from functools import partial 16 | from typing import List, Tuple, Union 17 | 18 | import torch 19 | import torch.fft as fft 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.nn.init import constant_, xavier_uniform_ 23 | 24 | 25 | conv_dict = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d} 26 | 27 | ACTIVATION_FUNCTIONS = [ 28 | "CELU", 29 | "ELU", 30 | "GELU", 31 | "GLU", 32 | "Hardtanh", 33 | "Hardshrink", 34 | "Hardsigmoid", 35 | "Hardswish", 36 | "LeakyReLU", 37 | "LogSigmoid", 38 | "MultiheadAttention", 39 | "PReLU", 40 | "ReLU", 41 | "ReLU6", 42 | "RReLU", 43 | "SELU", 44 | "SiLU", 45 | "Sigmoid", 46 | "Softplus", 47 | "Softmax", 48 | "Softmax2d", 49 | "Softshrink", 50 | "Softsign", 51 | "Tanh", 52 | "Tanhshrink", 53 | "Threshold", 54 | "Mish", 55 | ] 56 | 57 | # Type hint for activation functions 58 | ActivationType = str 59 | 60 | 61 | class LayerNormnd(nn.GroupNorm): 62 | """ 63 | a wrapper for GroupNorm used as LayerNorm 64 | https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html 65 | input and output shapes: (bsz, C, *) 66 | * can be (H, W) or (H, W, T) 67 | Note: by default the layernorm is applied to the last dimension 68 | """ 69 | 70 | def __init__( 71 | self, num_channels, eps=1e-07, elementwise_affine=True, device=None, dtype=None 72 | ): 73 | super().__init__( 74 | num_groups=1, 75 | num_channels=num_channels, 76 | eps=eps, 77 | affine=elementwise_affine, 78 | device=device, 79 | dtype=dtype, 80 | ) 81 | 82 | def forward(self, v: torch.Tensor): 83 | return super().forward(v) 84 | 85 | 86 | class PointwiseFFN(nn.Module): 87 | def __init__( 88 | self, 89 | in_channels: int, 90 | out_channels: int, 91 | mid_channels: int, 92 | activation: ActivationType = "ReLU", 93 | dim: int = 3, 94 | ): 95 | super().__init__() 96 | """ 97 | Pointwisely-applied 2-layer FFN with a channel expansion 98 | """ 99 | 100 | if dim not in conv_dict: 101 | raise ValueError(f"Unsupported dimension: {dim}, expected 1, 2, or 3") 102 | 103 | Conv = conv_dict[dim] 104 | self.linear1 = Conv(in_channels, mid_channels, 1) 105 | self.linear2 = Conv(mid_channels, out_channels, 1) 106 | self.activation = getattr(nn, activation)() 107 | 108 | def forward(self, v: torch.Tensor): 109 | for b in [self.linear1, self.activation, self.linear2]: 110 | v = b(v) 111 | return v 112 | 113 | 114 | class SpectralConv(nn.Module): 115 | def __init__( 116 | self, 117 | in_channels: int, 118 | out_channels: int, 119 | modes: List[int], 120 | dim: int, 121 | bias: bool = False, 122 | norm: str = "backward", 123 | ) -> None: 124 | super().__init__() 125 | 126 | """ 127 | Spacetime Fourier layer template: 128 | FFT, linear transform, and inverse FFT,
129 | focusing on the spatial dimensions. 130 | modes: the number of Fourier modes in each dimension; 131 | modes' length needs to be the same as the dimension 132 | """ 133 | 134 | self.in_channels = in_channels 135 | self.out_channels = out_channels 136 | self.dim = dim 137 | self.bias = bias 138 | assert len(modes) == dim, "modes should match the dimension" 139 | size = [in_channels, out_channels, *modes, 2] 140 | gain = 0.5 / (in_channels * out_channels) 141 | dims = tuple(range(-self.dim, 0)) 142 | self.fft = partial(fft.rfftn, dim=dims, norm=norm) 143 | self.ifft = partial(fft.irfftn, dim=dims, norm=norm) 144 | self._initialize_weights(size, gain) 145 | 146 | def _initialize_weights(self, size, gain=1e-4): 147 | """ 148 | # of weight groups = 2*(ndim - 1), e.g., 4 when ndim == 3 149 | """ 150 | self.weight = nn.ParameterList( 151 | [ 152 | nn.Parameter(gain * torch.rand(*size)) 153 | for _ in range(2 * (self.dim - 1)) 154 | ] # 2*(ndim - 1) 155 | ) 156 | if self.bias: 157 | self.bias = nn.ParameterList( 158 | [ 159 | nn.Parameter( 160 | gain 161 | * torch.zeros( 162 | *size[2:], 163 | ) 164 | ) 165 | for _ in range(2 * (self.dim - 1)) # 2*(ndim - 1) 166 | ] 167 | ) 168 | 169 | def _reset_parameters(self, gain=1e-6): 170 | for name, param in self.named_parameters(): 171 | if "bias" in name: 172 | constant_(param, 0.0) 173 | else: 174 | xavier_uniform_(param, gain) 175 | 176 | @staticmethod 177 | def complex_matmul(x, w, **kwargs): 178 | """ 179 | Override this method in a subclass to customize the complex matmul; 180 | this is a general implementation for arbitrary dimension: 181 | (b, c_i, *mesh_dims), (c_i, c_o, *mesh_dims) -> (b, c_o, *mesh_dims) 182 | in pure einsum benchmarks, the ellipsis version runs about 30% slower; 183 | however, when used inside the FNO, the performance difference is negligible 184 | one can implement a more specific einsum for the dimension 185 | 1D: (b, c_i, x), (c_i, c_o, x) -> (b, c_o, x) 186 | 2D: (b, c_i, x, y), (c_i, c_o, x, y) -> (b, c_o, x, y) 187 | (2+1)D: (b, c_i, x, y, t), (c_i, c_o, x, y, t) -> (b, c_o, x, y, t) 188 | """ 189 | return torch.einsum("bi...,io...->bo...", x, w) 190 | 191 | def _set_complex_matmul_nd(self, dim: int = None): 192 | """ 193 | Generate einsum string based on dimension.
194 | 1D: "bix,iox->box" 195 | 2D: "bixy,ioxy->boxy" 196 | 3D: "bixyz,ioxyz->boxyz" 197 | 4D: "biwxyz,iowxyz->bowxyz" 198 | 199 | Args: 200 | dim: The dimension of the data 201 | 202 | Side effect: 203 | sets self.complex_matmul to a torch.einsum partial for the generated equation 204 | """ 205 | dim = self.dim if dim is None else dim 206 | assert dim >= 1 207 | 208 | # Start with the basic components 209 | inp = "bi" 210 | w = "io" 211 | out = "bo" 212 | 213 | # Add dimension-specific characters 214 | mesh_dims = "".join([chr(ord("z") + i) for i in range(1 - dim, 1)]) 215 | 216 | inp += mesh_dims 217 | w += mesh_dims 218 | out += mesh_dims 219 | 220 | equation = f"{inp},{w}->{out}" 221 | self.complex_matmul = partial(torch.einsum, equation) 222 | 223 | @abstractmethod 224 | def spectral_conv(self, vhat, *fft_mesh_size, **kwargs): 225 | raise NotImplementedError( 226 | "Subclasses must implement spectral_conv() to perform spectral convolution" 227 | ) 228 | 229 | def forward(self, v, out_mesh_size=None, **kwargs): 230 | bsz, _, *mesh_size = v.size() 231 | out_mesh_size = mesh_size if out_mesh_size is None else out_mesh_size 232 | fft_mesh_size = mesh_size.copy() 233 | fft_mesh_size[-1] = mesh_size[-1] // 2 + 1 234 | v_hat = self.fft(v) 235 | v_hat = self.spectral_conv(v_hat, *fft_mesh_size) 236 | v = self.ifft(v_hat, s=out_mesh_size) 237 | return v 238 | 239 | 240 | class FNOBase(nn.Module): 241 | def __init__( 242 | self, 243 | *, 244 | num_spectral_layers: int = 4, 245 | fft_norm="backward", 246 | activation: ActivationType = "ReLU", 247 | spatial_padding: int = 0, 248 | channel_expansion: int = 4, 249 | spatial_random_feats: bool = False, 250 | lift_activation: bool = False, 251 | debug=False, 252 | **kwargs, 253 | ): 254 | super().__init__() 255 | """New implementation for the base class for Fourier Neural Operator (FNO) models.
256 | The users need to implement 257 | - the lifting operator 258 | - the output operator 259 | - the forward method 260 | 261 | add_latent_hook() registers a forward hook to record latent tensors. 262 | Example: 263 | model.add_latent_hook("spectral_conv") # "spectral_conv" is the name of a layer attribute 264 | The hook saves the output of that layer (or of every layer in a ModuleList) 265 | to self.latent_tensors, keyed by the layer name. 266 | """ 267 | 268 | self.spatial_padding = spatial_padding 269 | self.fft_norm = fft_norm 270 | self.activation = activation 271 | self.spatial_random_feats = spatial_random_feats 272 | self.lift_activation = lift_activation 273 | self.channel_expansion = channel_expansion 274 | self.debug = debug 275 | self.num_spectral_layers = num_spectral_layers 276 | self.latent_tensors = {} # per-instance storage for add_latent_hook; a mutable class attribute would be shared across instances 277 | 278 | @staticmethod 279 | def _set_modulelist(module: nn.Module, num_layers, *args): 280 | return nn.ModuleList([deepcopy(module(*args)) for _ in range(num_layers)]) 281 | 282 | @property 283 | @abstractmethod 284 | def set_lifting_operator(self, *args, **kwargs): 285 | """Implement this method in subclass to return the lifting operator""" 286 | raise NotImplementedError("Subclasses must implement lifting_operator property") 287 | 288 | @property 289 | @abstractmethod 290 | def set_output_operator(self, *args, **kwargs): 291 | """Implement this method in subclass to return the output operator""" 292 | raise NotImplementedError("Subclasses must implement output_operator property") 293 | 294 | def _set_spectral_layers( 295 | self, 296 | num_layers: int, 297 | modes: List[int], 298 | width: int, 299 | activation: ActivationType, 300 | spectral_conv: SpectralConv, 301 | mlp: PointwiseFFN, 302 | linear: Union[nn.Conv1d, nn.Conv2d, nn.Conv3d], 303 | channel_expansion: int = 4, 304 | ) -> None: 305 | """ 306 | In SFNO 307 | spectral_conv: SpectralConvS 308 | mlp: MLP with dim=3 309 | linear: nn.Conv3d 310 | """ 311 | act_func = getattr(nn, activation) 312 | for attr, module, args in zip( 313 | ["spectral_conv", "mlp", "w", "activations"], 314 | [spectral_conv, mlp, linear, act_func], 315 | [ 316 | (width, width, *modes), 317 | (width, width, channel_expansion * width, activation), 318 | (width, width, 1), 319 | (), 320 | ], 321 | ): 322 | setattr( 323 | self, 324 | attr, 325 | self._set_modulelist(module, num_layers, *args), 326 | ) 327 | 328 | 329 | 330 | def add_latent_hook(self, layer_name: str): 331 | def _get_latent_tensors(name): 332 | def hook(model, input, output): 333 | self.latent_tensors[name] = output.detach() 334 | 335 | return hook 336 | 337 | module = getattr(self, layer_name) 338 | 339 | if hasattr(module, "__iter__"): 340 | for k, b in enumerate(module): 341 | b.register_forward_hook(_get_latent_tensors(f"{layer_name}_{k}")) 342 | else: 343 | module.register_forward_hook(_get_latent_tensors(layer_name)) 344 | 345 | def double(self): 346 | for param in self.parameters(): 347 | if param.dtype == torch.float32: 348 | param.data = param.data.to(torch.float64) 349 | elif param.dtype == torch.complex64: 350 | param.data = param.data.to(torch.complex128) 351 | return self 352 | 353 | def forward(self, *args, **kwargs): 354 | raise NotImplementedError("Subclasses of FNO must implement the forward method") 355 | -------------------------------------------------------------------------------- /fno/data_gen/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils import * 2 | from .solvers import *
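A self-contained shape check (a sketch, not a test shipped with the repo) of the contraction contract used by `SpectralConv.complex_matmul` and the dimension-specialized equation built by `_set_complex_matmul_nd` in `fno/base.py` above:

```python
import torch

b, c_in, c_out, kx, ky = 2, 3, 5, 8, 8
x = torch.randn(b, c_in, kx, ky, dtype=torch.cfloat)      # (b, c_i, *mesh)
w = torch.randn(c_in, c_out, kx, ky, dtype=torch.cfloat)  # (c_i, c_o, *mesh)

# The generic ellipsis form: (b, c_i, *mesh) x (c_i, c_o, *mesh) -> (b, c_o, *mesh)
out = torch.einsum("bi...,io...->bo...", x, w)
assert out.shape == (b, c_out, kx, ky)

# For dim=2, _set_complex_matmul_nd generates the mesh labels "yz", i.e. the
# equation "biyz,ioyz->boyz"; it contracts identically to the ellipsis version.
out2 = torch.einsum("biyz,ioyz->boyz", x, w)
assert torch.allclose(out, out2)
```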
-------------------------------------------------------------------------------- /fno/data_gen/data_gen_Kolmogorov2d.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2024 Shuhao Cao 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | 10 | import os 11 | 12 | import torch 13 | import torch.fft as fft 14 | from torch_cfd.finite_differences import curl_2d 15 | from torch_cfd.forcings import KolmogorovForcing 16 | 17 | from torch_cfd.grids import Grid 18 | from torch_cfd.initial_conditions import filtered_velocity_field 19 | from torch_cfd.equations import * 20 | from data_gen.solvers import get_trajectory_imex 21 | 22 | from data_utils import * 23 | 24 | from fno.pipeline import DATA_PATH, LOG_PATH 25 | 26 | 27 | def main(args): 28 | """ 29 | Generate the Kolmogorov 2d flow data in [1] that are used as examples in [2]. 30 | 31 | [1]: Kolmogorov, A. N. (1941). The local structure of turbulence in incompressible viscous fluid for very large Reynolds numbers. In Dokl. Akad. Nauk SSSR, 30, 301. 32 | 33 | [2]: Kochkov, D., Smith, J. A., Alieva, A., Wang, Q., Brenner, M. P., & Hoyer, S. (2021). Machine learning-accelerated computational fluid dynamics. Proceedings of the National Academy of Sciences, 118(21), e2101784118.
34 | 35 | Training dataset: 36 | >>> python data_gen_Kolmogorov2d.py --num-samples 1152 --batch-size 128 --grid-size 256 --subsample 4 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" 37 | 38 | Testing dataset for plotting the enstrophy spectrum: 39 | >>> python data_gen_Kolmogorov2d.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double 40 | 41 | Testing if the data generation works: 42 | >>> python data_gen_Kolmogorov2d.py --num-samples 4 --batch-size 2 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double --demo 43 | """ 44 | args = args.parse_args() 45 | 46 | current_time = datetime.now().strftime("%d_%b_%Y_%Hh%Mm") 47 | log_name = "".join(os.path.basename(__file__).split(".")[:-1]) 48 | 49 | log_filename = os.path.join(LOG_PATH, f"{current_time}_{log_name}.log") 50 | logger = get_logger(log_filename) 51 | 52 | total_samples = args.num_samples 53 | batch_size = args.batch_size # 128 54 | assert batch_size <= total_samples, "batch_size must be <= num_samples" 55 | assert total_samples % batch_size == 0, "total_samples must be divisible by batch_size" 56 | n = args.grid_size # 256 57 | viscosity = args.visc if args.Re is None else 1 / args.Re 58 | Re = 1 / viscosity 59 | dt = args.dt # 1e-3 60 | T = args.time # 10 61 | subsample = args.subsample # 4 62 | ns = n // subsample 63 | scale = args.scale # 1 64 | T_warmup = args.time_warmup # 4.5 65 | num_snapshots = args.num_steps # 100 66 | random_state = args.seed 67 | peak_wavenumber = args.peak_wavenumber # 4 68 | diam = ( 69 | eval(args.diam) if isinstance(args.diam, str) else args.diam 70 | ) # "2 * torch.pi" 71 | force_rerun = args.force_rerun 72 | 73 | logger = logging.getLogger() 74 | logger.info(f"Generating data for Kolmogorov2d flow with {total_samples} samples") 75 | 76 | max_velocity = args.max_velocity # 5 77 | dt = stable_time_step(diam / n, dt, max_velocity, viscosity=viscosity) 78 | logger.info(f"Using dt = {dt}") 79 | 80 | warmup_steps = int(T_warmup / dt) 81 | total_steps = int((T - T_warmup) / dt) 82 | record_every_iters = int(total_steps / num_snapshots) 83 | 84 | dtype = torch.float64 if args.double else torch.float32 85 | cdtype = torch.complex128 if args.double else torch.complex64 86 | dtype_str = "_fp64" if args.double else "" 87 | filename = args.filename 88 | if filename is None: 89 | filename = f"Kolmogorov2d{dtype_str}_{ns}x{ns}_N{total_samples}_Re{int(Re)}_T{num_snapshots}.pt" 90 | args.filename = filename 91 | data_filepath = os.path.join(DATA_PATH, filename) 92 | if os.path.exists(data_filepath) and not force_rerun: 93 | logger.info(f"Data already exists at {data_filepath}") 94 | return 95 | elif os.path.exists(data_filepath) and force_rerun: 96 | logger.info(f"Force rerun and save data to {data_filepath}") 97 | os.remove(data_filepath) 98 | else: 99 | logger.info(f"Save data to {data_filepath}") 100 | 101 | cuda = not args.no_cuda and torch.cuda.is_available() 102 | no_tqdm = args.no_tqdm 103 | device = torch.device("cuda:0" if cuda else "cpu") 104 | 105 | torch.set_default_dtype(torch.float64) 106 | logger.info( 107 | f"Using device: {device} | save dtype: {dtype} | compute dtype: {torch.get_default_dtype()}" 108 | ) 109 | 110 | grid = Grid(shape=(n, n), domain=((0, diam), (0, diam)), device=device) 111 | 112 | forcing_fn = KolmogorovForcing( 113 | grid=grid, 114 | scale=scale, 115 | wave_number=peak_wavenumber,
116 | swap_xy=False, 117 | ) 118 | 119 | ns2d = NavierStokes2DSpectral( 120 | viscosity=viscosity, 121 | grid=grid, 122 | drag=0.1, 123 | smooth=True, 124 | forcing_fn=forcing_fn, 125 | step_fn=RK4CrankNicolsonStepper(), 126 | ).to(device) 127 | 128 | num_batches = total_samples // batch_size 129 | for i, idx in enumerate(range(0, total_samples, batch_size)): 130 | logger.info(f"Generate trajectory for batch [{i+1}/{num_batches}]") 131 | logger.info( 132 | f"random states: {random_state + idx} to {random_state + idx + batch_size-1}" 133 | ) 134 | 135 | vort_init = torch.stack( 136 | [ 137 | curl_2d( 138 | filtered_velocity_field( 139 | grid, 140 | max_velocity, 141 | peak_wavenumber, 142 | random_state=random_state + i + k, 143 | ) 144 | ).data 145 | for k in range(batch_size) 146 | ] 147 | ) 148 | vort_hat = fft.rfft2(vort_init).to(device) 149 | 150 | with tqdm(total=warmup_steps, disable=no_tqdm) as pbar: 151 | for j in range(warmup_steps): 152 | vort_hat, _ = ns2d.step(vort_hat, dt) 153 | if j % 100 == 0: 154 | vort_norm = torch.linalg.norm(fft.irfft2(vort_hat)).item()/n 155 | desc = datetime.now().strftime("%d-%b-%Y %H:%M:%S") + f" - Warmup | vort_hat ell2 norm {vort_norm:.4e}" 156 | pbar.set_description(desc) 157 | pbar.update(100) 158 | 159 | result = get_trajectory_imex( 160 | ns2d, 161 | vort_hat, 162 | dt, 163 | num_steps=total_steps, 164 | record_every_steps=record_every_iters, 165 | pbar=not no_tqdm, 166 | dtype=cdtype, 167 | ) 168 | 169 | for field, value in result.items(): 170 | logger.info( 171 | f"freq variable: {field:<12} | shape: {value.shape} | dtype: {value.dtype}" 172 | ) 173 | value = fft.irfft2(value).real.cpu().to(dtype) 174 | logger.info( 175 | f"saved variable: {field:<12} | shape: {value.shape} | dtype: {value.dtype}" 176 | ) 177 | if subsample > 1: 178 | result[field] = F.interpolate(value, size=(ns, ns), mode="bilinear") 179 | else: 180 | result[field] = value 181 | 182 | result["random_states"] = torch.tensor( 183 | [random_state + idx + k for k in range(batch_size)], dtype=torch.int32 184 | ) 185 | logger.info(f"Saving batch [{i+1}/{num_batches}] to {data_filepath}") 186 | if not args.demo: 187 | save_pickle(result, data_filepath, append=True) 188 | del result 189 | 190 | 191 | if not args.demo: 192 | pickle_to_pt(data_filepath) 193 | logger.info(f"Done saving.") 194 | else: 195 | try: 196 | verify_trajectories( 197 | data_filepath, 198 | dt=record_every_iters * dt, 199 | T_warmup=T_warmup, 200 | n_samples=1, 201 | ) 202 | except Exception as e: 203 | logger.error(f"Error in plotting sample trajectories: {e}") 204 | return 0 205 | 206 | 207 | if __name__ == "__main__": 208 | args = get_args_ns2d("Params Kolmogorov 2d flow data generation") 209 | main(args) 210 | -------------------------------------------------------------------------------- /fno/data_gen/data_gen_McWilliams2d.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2024 Shuhao Cao 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | # The above copyright notice and this permission notice shall be included in all copies or 
substantial portions of the Software.
7 |
8 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9 |
10 | import os
11 |
12 | import torch
13 | import torch.fft as fft
14 | import torch.nn.functional as F
15 | from torch_cfd.grids import Grid
16 | from torch_cfd.initial_conditions import vorticity_field
17 | from torch_cfd.equations import *
18 |
19 | from solvers import get_trajectory_imex
20 | from data_utils import *
21 |
22 | import logging
23 |
24 | from fno.pipeline import DATA_PATH, LOG_PATH
25 |
26 |
27 | def main(args):
28 | """
29 | Generate the isotropic turbulence in [1]
30 |
31 | [1]: McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. Journal of Fluid Mechanics, 146, 21-43.
32 |
33 | Training dataset for the SFNO ICLR 2025 paper:
34 | >>> python data_gen_McWilliams2d.py --num-samples 1152 --batch-size 128 --grid-size 256 --subsample 4 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi"
35 |
36 | Testing dataset for plotting the enstrophy spectrum:
37 | >>> python data_gen_McWilliams2d.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double
38 |
39 | Training dataset with Re=5k:
40 | >>> python data_gen_McWilliams2d.py --num-samples 1152 --batch-size 128 --grid-size 512 --subsample 1 --Re 5e3 --dt 5e-4 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi"
41 |
42 | Demo dataset to test if the data generation works:
43 | >>> python data_gen_McWilliams2d.py --num-samples 4 --batch-size 2 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double --demo
44 |
45 | """
46 | args = args.parse_args()
47 |
48 | current_time = datetime.now().strftime("%d_%b_%Y_%Hh%Mm")
49 | log_name = "".join(os.path.basename(__file__).split(".")[:-1])
50 |
51 | log_filename = os.path.join(LOG_PATH, f"{current_time}_{log_name}.log")
52 | logger = get_logger(log_filename)
53 |
54 | total_samples = args.num_samples
55 | batch_size = args.batch_size # 128
56 | assert batch_size <= total_samples, "batch_size <= num_samples"
57 | assert total_samples % batch_size == 0, "total_samples divisible by batch_size"
58 | n = args.grid_size # 256
59 | viscosity = args.visc if args.Re is None else 1 / args.Re
60 | Re = 1 / viscosity
61 | dt = args.dt # 1e-3
62 | T = args.time # 10
63 | subsample = args.subsample # 4
64 | ns = n // subsample
65 | T_warmup = args.time_warmup # 4.5
66 | num_snapshots = args.num_steps # 100
67 | random_state = args.seed
68 | peak_wavenumber = args.peak_wavenumber # 4
69 | diam = (
70 | eval(args.diam) if isinstance(args.diam, str) else args.diam
71 | ) # "2 * torch.pi"
72 | force_rerun = args.force_rerun
73 |
74 | logger = logging.getLogger()
75 | logger.info(f"Generating data for McWilliams2d with {total_samples} samples")
76 |
77 | max_velocity = args.max_velocity # 5
78 | dt = stable_time_step(diam / n, dt, max_velocity, viscosity=viscosity)
79 | logger.info(f"Using dt = {dt}")
80 |
81 | warmup_steps = int(T_warmup / dt)
82 | total_steps = int((T - T_warmup) / dt)
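# Worked example, assuming stable_time_step keeps the requested dt = 1e-3:
# with T = 10 and T_warmup = 4.5, warmup_steps = 4500 and total_steps = 5500,
# so num_snapshots = 100 makes the line below record every 55 solver steps,
# i.e. one snapshot per 0.055 time units.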
83 | record_every_iters = int(total_steps / num_snapshots)
84 |
85 | dtype = torch.float64 if args.double else torch.float32
86 | cdtype = torch.complex128 if args.double else torch.complex64
87 | dtype_str = "_fp64" if args.double else ""
88 | filename = args.filename
89 | if filename is None:
90 | # filename = f"McWilliams2d{dtype_str}_{ns}x{ns}_N{total_samples}_v{viscosity:.0e}_T{num_snapshots}.pt".replace(
91 | # "e-0", "e-"
92 | # )
93 | filename = f"McWilliams2d{dtype_str}_{ns}x{ns}_N{total_samples}_Re{int(Re)}_T{num_snapshots}.pt"
94 | args.filename = filename
95 | data_filepath = os.path.join(DATA_PATH, filename)
96 | if os.path.exists(data_filepath) and not force_rerun:
97 | logger.info(f"Data already exists at {data_filepath}")
98 | return
99 | elif os.path.exists(data_filepath) and force_rerun:
100 | logger.info(f"Force rerun and save data to {data_filepath}")
101 | os.remove(data_filepath)
102 | else:
103 | logger.info(f"Save data to {data_filepath}")
104 |
105 | cuda = not args.no_cuda and torch.cuda.is_available()
106 | no_tqdm = args.no_tqdm
107 | device = torch.device("cuda:0" if cuda else "cpu")
108 |
109 | torch.set_default_dtype(torch.float64)
110 | logger.info(
111 | f"Using device: {device} | save dtype: {dtype} | compute dtype: {torch.get_default_dtype()}"
112 | )
113 |
114 | grid = Grid(shape=(n, n), domain=((0, diam), (0, diam)), device=device)
115 |
116 | ns2d = NavierStokes2DSpectral(
117 | viscosity=viscosity,
118 | grid=grid,
119 | drag=0,
120 | smooth=True,
121 | forcing_fn=None,
122 | step_fn=RK4CrankNicolsonStepper(),
123 | ).to(device)
124 |
125 | num_batches = total_samples // batch_size
126 | for i, idx in enumerate(range(0, total_samples, batch_size)):
127 | logger.info(f"Generate trajectory for batch [{i+1}/{num_batches}]")
128 | logger.info(
129 | f"random states: {random_state + idx} to {random_state + idx + batch_size-1}"
130 | )
131 |
132 | vort_init = torch.stack(
133 | [
134 | vorticity_field(
135 | grid, peak_wavenumber, random_state=random_state + idx + k
136 | ).data
137 | for k in range(batch_size)
138 | ]
139 | )
140 | vort_hat = fft.rfft2(vort_init).to(device)
141 |
142 | with tqdm(total=warmup_steps, disable=no_tqdm) as pbar:
143 | for j in range(warmup_steps):
144 | vort_hat, _ = ns2d.step(vort_hat, dt)
145 | if j % 100 == 0:
146 | vort_norm = torch.linalg.norm(fft.irfft2(vort_hat)).item() / n
147 | desc = (
148 | datetime.now().strftime("%d-%b-%Y %H:%M:%S")
149 | + f" - Warmup | vort_hat ell2 norm {vort_norm:.4e}"
150 | )
151 | pbar.set_description(desc)
152 | pbar.update(100)
153 |
154 | result = get_trajectory_imex(
155 | ns2d,
156 | vort_hat,
157 | dt,
158 | num_steps=total_steps,
159 | record_every_steps=record_every_iters,
160 | pbar=not no_tqdm,
161 | dtype=cdtype,
162 | )
163 |
164 | for field, value in result.items():
165 | logger.info(
166 | f"freq variable: {field:<12} | shape: {value.shape} | dtype: {value.dtype}"
167 | )
168 | value = fft.irfft2(value).real.cpu().to(dtype)
169 | logger.info(
170 | f"saved variable: {field:<12} | shape: {value.shape} | dtype: {value.dtype}"
171 | )
172 | if subsample > 1:
173 | result[field] = F.interpolate(value, size=(ns, ns), mode="bilinear")
174 | else:
175 | result[field] = value
176 |
177 | result["random_states"] = torch.tensor(
178 | [random_state + idx + k for k in range(batch_size)], dtype=torch.int32
179 | )
180 | if not args.demo:
181 | save_pickle(result, data_filepath, append=True)
182 | if not args.demo: del result
183 |
184 | if not args.demo:
185 | pickle_to_pt(data_filepath)
186 |
logger.info(f"Done saving.") 187 | else: 188 | try: 189 | verify_trajectories( 190 | result, 191 | dt=record_every_iters * dt, 192 | T_warmup=T_warmup, 193 | n_samples=1, 194 | ) 195 | except Exception as e: 196 | logger.error(f"Error in plotting sample trajectories: {e}") 197 | return 0 198 | 199 | 200 | if __name__ == "__main__": 201 | args = get_args_ns2d( 202 | "Parameters for generating NSE 2d with McWilliams 2d example" 203 | ) 204 | main(args) 205 | -------------------------------------------------------------------------------- /fno/data_gen/data_gen_fno.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2025 Shuhao Cao 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | 10 | import os 11 | 12 | import torch 13 | import torch.fft as fft 14 | import torch.nn.functional as F 15 | 16 | from torch_cfd.grids import Grid 17 | from torch_cfd.equations import * 18 | from torch_cfd.forcings import SinCosForcing 19 | 20 | from grf import GRF2d 21 | from solvers import get_trajectory_imex 22 | from data_utils import * 23 | import logging 24 | 25 | from fno.pipeline import DATA_PATH, LOG_PATH 26 | 27 | 28 | def main(args): 29 | """ 30 | Generate the original FNO data 31 | the right hand side is a fixed forcing 32 | 0.1*(torch.sin(2*math.pi*(x+y))+torch.cos(2*math.pi*(x+y))) 33 | 34 | It stores data after each batch, and will resume using a fixed formula'd seed 35 | when starting again. 36 | The default values of the params for the Gaussian Random Field (GRF) are printed. 
37 | 38 | Sample usage: 39 | 40 | - Training data for Spectral-Refiner ICLR 2025 paper 'fnodata_extra_64x64_N1280_v1e-3_T50_steps100_alpha2.5_tau7.pt' 41 | >>> python data_gen_fno.py --num-samples 1280 --batch-size 256 --grid-size 256 --subsample 4 --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --visc 1e-3 --scale 0.1 42 | 43 | - Test data 44 | >>> python data_gen_fno.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --scale 0.1 --replicable-init --seed 42 45 | 46 | - Test data fine 47 | >>> python data_gen_fno.py --num-samples 2 --batch-size 1 --grid-size 512 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 200 --dt 5e-4 --scale 0.1 --replicable-init --seed 42 48 | 49 | - Testing if the code works 50 | >>> python data_gen/data_gen_fno.py --num-samples 4 --batch-size 2 --grid-size 128 --subsample 1 --double --extra-vars --time 2 --time-warmup 1 --num-steps 10 --dt 1e-3 --scale 0.1 --replicable-init --seed 42 --demo 51 | 52 | """ 53 | 54 | args = args.parse_args() 55 | 56 | current_time = datetime.now().strftime("%d_%b_%Y_%Hh%Mm") 57 | log_name = "".join(os.path.basename(__file__).split(".")[:-1]) 58 | logpath = args.logpath if args.logpath is not None else LOG_PATH 59 | log_filename = os.path.join(logpath, f"{current_time}_{log_name}.log") 60 | logger = get_logger(log_filename) 61 | 62 | logger.info(f"Using the following arguments: ") 63 | all_args = {k: v for k, v in vars(args).items() if not callable(v)} 64 | logger.info(" | ".join(f"{k}={v}" for k, v in all_args.items())) 65 | 66 | n_grid_max = 2048 67 | n = args.grid_size # 256 68 | subsample = args.subsample # 4 69 | ns = n // subsample 70 | diam = args.diam # 1.0 71 | diam = eval(diam) if isinstance(diam, str) else diam 72 | if n > n_grid_max: 73 | raise ValueError( 74 | f"Grid size {n} is larger than the maximum allowed {n_grid_max}" 75 | ) 76 | scale = args.scale 77 | visc = args.visc if args.Re is None else 1 / args.Re # 1e-3 78 | T = args.time # 50 79 | T_warmup = args.time_warmup # 30 80 | T_new = T - T_warmup 81 | record_steps = args.num_steps 82 | dt = args.dt # 1e-4 83 | logger.info(f"Using dt = {dt}") 84 | 85 | warmup_steps = int(T_warmup / dt) 86 | total_steps = int(T_new / dt) 87 | record_every_iters = int(total_steps / record_steps) 88 | 89 | alpha = args.alpha # 2.5 90 | tau = args.tau # 7 91 | peak_wavenumber = args.peak_wavenumber 92 | 93 | dtype = torch.float64 if args.double else torch.float32 94 | normalize = args.normalize 95 | filename = args.filename 96 | force_rerun = args.force_rerun 97 | replicate_init = args.replicable_init 98 | dealias = not args.no_dealias 99 | pbar = not args.no_tqdm 100 | 101 | # Number of solutions to generate 102 | total_samples = args.num_samples # 8 103 | 104 | # Batch size 105 | batch_size = args.batch_size # 8 106 | 107 | extra = "_extra" if args.extra_vars else "" 108 | dtype_str = "_fp64" if args.double else "" 109 | if filename is None: 110 | filename = ( 111 | f"fnodata{extra}{dtype_str}_{ns}x{ns}_N{total_samples}" 112 | + f"_v{visc:.0e}_T{int(T)}_steps{record_steps}_alpha{alpha:.1f}_tau{tau:.0f}.pt" 113 | ).replace("e-0", "e-") 114 | args.filename = filename 115 | 116 | filepath = args.filepath if args.filepath is not None else DATA_PATH 117 | for p in [filepath]: 118 | if not os.path.exists(p): 119 | os.makedirs(p) 120 | logging.info(f"Created directory {p}") 121 | data_filepath = os.path.join(DATA_PATH, filename) 122 | 123 | data_exist = 
os.path.exists(data_filepath)
124 | if data_exist and not force_rerun:
125 | logger.info(f"File {filename} exists with current data as follows:")
126 | data = torch.load(data_filepath)
127 |
128 | for key, v in data.items():
129 | if isinstance(v, torch.Tensor):
130 | logger.info(f"{key:<12} | {v.shape} | {v.dtype}")
131 | else:
132 | logger.info(f"{key:<12} | {v.dtype}")
133 | if len(data[key]) == total_samples:
134 | return
135 | elif len(data[key]) < total_samples:
136 | total_samples -= len(data[key])
137 | else:
138 | logger.info(f"Generating data and saving in {filename}")
139 |
140 | cuda = not args.no_cuda and torch.cuda.is_available()
141 | no_tqdm = args.no_tqdm
142 | device = torch.device("cuda:0" if cuda else "cpu")
143 |
144 | torch.set_default_dtype(torch.float64)
145 | logger.info(
146 | f"Using device: {device} | save dtype: {dtype} | compute dtype: {torch.get_default_dtype()}"
147 | )
148 | # Set up 2d GRF with covariance parameters
149 | # Parameters of covariance C = tau^0.5*(2*alpha-2)*(-Laplacian + tau^2 I)^(-alpha)
150 | # Note that we need alpha > d/2 (here d = 2)
151 |
152 | grid = Grid(shape=(n, n), domain=((0, diam), (0, diam)), device=device)
153 |
154 | forcing_fn = SinCosForcing(
155 | grid=grid,
156 | scale=scale,
157 | diam=diam,
158 | k=peak_wavenumber,
159 | vorticity=True,
160 | )
161 | # Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y)))
162 |
163 | grf = GRF2d(
164 | n=n,
165 | alpha=alpha,
166 | tau=tau,
167 | normalize=normalize,
168 | device=device,
169 | dtype=torch.float64,
170 | )
171 |
172 | step_fn = IMEXStepper(order=2, requires_grad=False)
173 |
174 | ns2d = NavierStokes2DSpectral(
175 | viscosity=visc,
176 | grid=grid,
177 | smooth=True,
178 | forcing_fn=forcing_fn,
179 | step_fn=step_fn,
180 | ).to(device)
181 |
182 | if os.path.exists(data_filepath) and not force_rerun:
183 | logger.info(f"Data already exists at {data_filepath}")
184 | return
185 | elif os.path.exists(data_filepath) and force_rerun:
186 | logger.info(f"Force rerun and save data to {data_filepath}")
187 | os.remove(data_filepath)
188 | else:
189 | logger.info(f"Save data to {data_filepath}")
190 |
191 | num_batches = total_samples // batch_size
192 | for i, idx in enumerate(range(0, total_samples, batch_size)):
193 | logger.info(f"Generate trajectory for batch [{i+1}/{num_batches}]")
194 | logger.info(
195 | f"random states: {args.seed + idx} to {args.seed + idx + batch_size-1}"
196 | )
197 |
198 | # Sample random fields
199 | seeds = [args.seed + idx + k for k in range(batch_size)]
200 | n0 = n_grid_max if replicate_init else n
201 | vort_init = [
202 | grf.sample(1, n0, random_state=s) for _, s in zip(range(batch_size), seeds)
203 | ]
204 | vort_init = torch.stack(vort_init)
205 | if n != n0:
206 | vort_init = F.interpolate(vort_init, size=(n, n), mode="nearest")
207 | vort_init = vort_init.squeeze(1)
208 | vort_hat = fft.rfft2(vort_init).to(device)
209 |
210 | logger.info(f"initial condition {vort_init.shape}")
211 |
212 | if T_warmup > 0:
213 | with tqdm(total=warmup_steps, disable=no_tqdm) as pbar:
214 | for j in range(warmup_steps):
215 | vort_hat, _ = ns2d.step(vort_hat, dt)
216 | if j % 100 == 0:
217 | vort_norm = torch.linalg.norm(fft.irfft2(vort_hat)).item() / n
218 | desc = (
219 | datetime.now().strftime("%d-%b-%Y %H:%M:%S")
220 | + f" - Warmup | vort_hat ell2 norm {vort_norm:.4e}"
221 | )
222 | pbar.set_description(desc)
223 | pbar.update(100)
224 |
225 | logger.info(f"generate data from {T_warmup} to {T}")
226 | result = get_trajectory_imex(
227 | ns2d,
228 |
vort_hat,
229 | dt,
230 | num_steps=total_steps,
231 | record_every_steps=record_every_iters,
232 | pbar=not no_tqdm,
233 | )
234 |
235 | for field, value in result.items():
236 | value = fft.irfft2(value).real.cpu().to(dtype)
237 | logger.info(
238 | f"variable: {field} | shape: {value.shape} | dtype: {value.dtype}"
239 | )
240 | if subsample > 1:
241 | assert (
242 | value.ndim == 4
243 | ), f"Subsampling only works for (b, c, h, w) tensors, current shape: {value.shape}"
244 | value = F.interpolate(value, size=(ns, ns), mode="bilinear")
245 | result[field] = value
246 | logger.info(f"{field:<15} | {value.shape} | {value.dtype}")
247 |
248 | if not extra:
249 | for key in ["vort_t", "stream", "residual"]:
250 | result[key] = torch.empty(0, device="cpu")
251 | result["random_states"] = torch.as_tensor(seeds, dtype=torch.int32)
252 |
253 | logger.info(f"Saving batch [{i+1}/{num_batches}] to {data_filepath}")
254 | if not args.demo:
255 | save_pickle(result, data_filepath, append=True)
256 | if not args.demo: del result
257 |
258 | if not args.demo:
259 | pickle_to_pt(data_filepath)
260 | logger.info(f"Done saving.")
261 | else:
262 | try:
263 | verify_trajectories(
264 | result,
265 | dt=record_every_iters * dt,
266 | T_warmup=T_warmup,
267 | n_samples=1,
268 | )
269 | except Exception as e:
270 | logger.error(f"Error in plotting sample trajectories: {e}")
271 | return 0
272 |
273 |
274 | if __name__ == "__main__":
275 | args = get_args_ns2d("Generate the original FNO data for NSE in 2D")
276 | main(args)
277 | -------------------------------------------------------------------------------- /fno/data_gen/data_gen_fno_legacy.py: -------------------------------------------------------------------------------- 1 | import argparse
2 | import math
3 | import os
4 | from functools import partial
5 |
6 | import torch
7 | import torch.fft as fft
8 | import torch.nn.functional as F
9 |
10 | from grf import GRF2d
11 | from solvers import *
12 | from data_utils import *
13 | from fno.pipeline import DATA_PATH, LOG_PATH
14 |
15 | def main(args):
16 | """
17 | Generate the original FNO data
18 | the right hand side is a fixed forcing
19 | 0.1*(torch.sin(2*math.pi*(x+y))+torch.cos(2*math.pi*(x+y)))
20 | This is modified from the original FNO data generation code.
21 | For the new code using torch_cfd, please refer to
22 | fno/data_gen/data_gen_fno.py
23 |
24 | It stores data after each batch, and will resume using a fixed formula'd seed
25 | when starting again.
26 | The default values of the params for the Gaussian Random Field (GRF) are printed.
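The forcing can also be overridden on the command line; the --forcing flag is
eval'd into `lambda x, y: ...`, e.g. (hypothetical value):
>>> python data_gen_fno_legacy.py --forcing "0.1*torch.sin(2*math.pi*x)" ...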
27 |
28 | Sample usage:
29 |
30 | - Training data for Spectral-Refiner ICLR 2025 paper 'fnodata_extra_64x64_N1280_v1e-3_T50_steps100_alpha2.5_tau7.pt'
31 | >>> python data_gen_fno_legacy.py --num-samples 1280 --batch-size 256 --grid-size 256 --subsample 4 --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --visc 1e-3
32 |
33 | - Test data
34 | >>> python data_gen_fno_legacy.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --replicable-init --seed 42
35 |
36 | - Test data fine
37 | >>> python data_gen_fno_legacy.py --num-samples 2 --batch-size 1 --grid-size 512 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 200 --dt 5e-4 --replicable-init --seed 42
38 |
39 | """
40 |
41 | args = args.parse_args()
42 |
43 | current_time = datetime.now().strftime("%d_%b_%Y_%Hh%Mm")
44 | log_name = "".join(os.path.basename(__file__).split(".")[:-1])
45 | logpath = args.logpath if args.logpath is not None else LOG_PATH
46 | log_filename = os.path.join(logpath, f"{current_time}_{log_name}.log")
47 | logger = get_logger(log_filename)
48 |
49 | cuda = not args.no_cuda and torch.cuda.is_available()
50 | device = torch.device("cuda" if cuda else "cpu")
51 | logger.info(f"Using device: {device}")
52 | logger.info(f"Using the following arguments: ")
53 | all_args = {k: v for k, v in vars(args).items() if not callable(v)}
54 | logger.info(" | ".join(f"{k}={v}" for k, v in all_args.items()))
55 |
56 | n_grid_max = 2048
57 | n = args.grid_size # 256
58 | subsample = args.subsample # 4
59 | ns = n // subsample
60 | diam = args.diam # 1.0
61 | diam = eval(diam) if isinstance(diam, str) else diam
62 | if n > n_grid_max:
63 | raise ValueError(
64 | f"Grid size {n} is larger than the maximum allowed {n_grid_max}"
65 | )
66 | visc = args.visc if args.Re is None else 1/args.Re # 1e-3
67 | T = args.time # 50
68 | T_warmup = args.time_warmup # 30
69 | T_new = T - T_warmup
70 | delta_t = args.dt # 1e-4
71 |
72 | alpha = args.alpha # 2.5
73 | tau = args.tau # 7
74 | f = args.forcing # FNO's default sin+cos
75 | dtype = torch.float64 if args.double else torch.float32
76 | normalize = args.normalize
77 | filename = args.filename
78 | force_rerun = args.force_rerun
79 | replicate_init = args.replicable_init
80 | dealias = not args.no_dealias
81 | pbar = not args.no_tqdm
82 | torch.set_default_dtype(torch.float64)
83 | logger.info(f"Using device: {device} | save dtype: {dtype} | compute dtype: {torch.get_default_dtype()}")
84 |
85 | # Number of solutions to generate
86 | total_samples = args.num_samples # 8
87 |
88 | # Number of snapshots from solution
89 | record_steps = args.num_steps
90 |
91 | # Batch size
92 | batch_size = args.batch_size # 8
93 |
94 | solver_kws = dict(visc=visc,
95 | delta_t=delta_t,
96 | diam=diam,
97 | dealias=dealias,
98 | dtype=torch.float64)
99 |
100 | extra = "_extra" if args.extra_vars else ""
101 | dtype_str = "_fp64" if args.double else ""
102 | if filename is None:
103 | filename = (
104 | f"fnodata{extra}{dtype_str}_{ns}x{ns}_N{total_samples}"
105 | + f"_v{visc:.0e}_T{int(T)}_steps{record_steps}_alpha{alpha:.1f}_tau{tau:.0f}.pt"
106 | ).replace("e-0", "e-")
107 | args.filename = filename
108 |
109 | filepath = args.filepath if args.filepath is not None else DATA_PATH
110 | for p in [filepath]:
111 | if not os.path.exists(p):
112 | os.makedirs(p)
113 | logging.info(f"Created directory {p}")
114 | data_filepath = os.path.join(DATA_PATH, filename)
115 |
116 | data_exist =
os.path.exists(data_filepath) 117 | if data_exist and not force_rerun: 118 | logger.info(f"File {filename} exists with current data as follows:") 119 | data = torch.load(data_filepath) 120 | 121 | for key, v in data.items(): 122 | if isinstance(v, torch.Tensor): 123 | logger.info(f"{key:<12} | {v.shape} | {v.dtype}") 124 | else: 125 | logger.info(f"{key:<12} | {v.dtype}") 126 | if len(data[key]) == total_samples: 127 | return 128 | elif len(data[key]) < total_samples: 129 | total_samples -= len(data[key]) 130 | else: 131 | logger.info(f"Generating data and saving in {filename}") 132 | 133 | # Set up 2d GRF with covariance parameters 134 | # Parameters of covariance C = tau^0.5*(2*alpha-2)*(-Laplacian + tau^2 I)^(-alpha) 135 | # Note that we need alpha > d/2 (here d= 2) 136 | grf = GRF2d( 137 | n=n, 138 | alpha=alpha, 139 | tau=tau, 140 | normalize=normalize, 141 | device=device, 142 | dtype=torch.float64, 143 | ) 144 | 145 | # Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y))) 146 | grid = torch.linspace(0, 1, n + 1, device=device) 147 | grid = grid[0:-1] 148 | 149 | X, Y = torch.meshgrid(grid, grid, indexing="ij") 150 | # FNO's original implementation 151 | # fh = 0.1 * (torch.sin(2 * math.pi * (X + Y)) + torch.cos(2 * math.pi * (X + Y))) 152 | fh = f(X, Y) 153 | 154 | if os.path.exists(data_filepath) and not force_rerun: 155 | logger.info(f"Data already exists at {data_filepath}") 156 | return 157 | elif os.path.exists(data_filepath) and force_rerun: 158 | logger.info(f"Force rerun and save data to {data_filepath}") 159 | os.remove(data_filepath) 160 | else: 161 | logger.info(f"Save data to {data_filepath}") 162 | 163 | num_batches = total_samples // batch_size 164 | for i, idx in enumerate(range(0, total_samples, batch_size)): 165 | logger.info(f"Generate trajectory for batch [{i+1}/{num_batches}]") 166 | logger.info(f"random states: {args.seed + idx} to {args.seed + idx + batch_size-1}") 167 | 168 | # Sample random fields 169 | seeds = [args.seed + idx + k for k in range(batch_size)] 170 | n0 = n_grid_max if replicate_init else n 171 | w0 = [grf.sample(1, n0, random_state=s) for _, s in zip(range(batch_size), seeds)] 172 | w0 = torch.stack(w0) 173 | if n != n0: 174 | w0 = F.interpolate(w0, size=(n, n), mode="nearest") 175 | w0 = w0.squeeze(1) 176 | 177 | logger.info(f"initial condition {w0.shape}") 178 | 179 | if T_warmup > 0: 180 | logger.info(f"warm up till {T_warmup}") 181 | tmp = get_trajectory_imex_crank_nicolson( 182 | w0, 183 | fh, 184 | T=T_warmup, 185 | record_steps=record_steps, 186 | subsample=1, 187 | pbar=pbar, 188 | **solver_kws, 189 | ) 190 | w0 = tmp["vorticity"][:, -1].to(device) 191 | del tmp 192 | logger.info(f"warmup initial condition {w0.shape}") 193 | 194 | logger.info(f"generate data from {T_warmup} to {T}") 195 | result = get_trajectory_imex_crank_nicolson( 196 | w0, 197 | fh, 198 | T=T_new, 199 | record_steps=record_steps, 200 | subsample=subsample, 201 | pbar=pbar, 202 | **solver_kws, 203 | ) 204 | 205 | for field, value in result.items(): 206 | if subsample > 1 and value.ndim == 4: 207 | value = F.interpolate(value, size=(ns, ns), mode="bilinear") 208 | result[field] = value.cpu().to(dtype) 209 | logger.info(f"{field:<15} | {value.shape} | {value.dtype}") 210 | 211 | 212 | if not extra: 213 | for key in ["vort_t", "stream", "residual"]: 214 | result[key] = torch.empty(0, device="cpu") 215 | result["random_states"] = torch.as_tensor(seeds, dtype=torch.int32) 216 | 217 | logger.info(f"Saving batch [{i+1}/{num_batches}] to {data_filepath}") 218 | 
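# save_pickle (append mode by default) dill-dumps each batch dict onto the same
# file; pickle_to_pt below then loads every dump and torch.cat's each field, so
# an interrupted run only loses the in-flight batch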
save_pickle(result, data_filepath) 219 | del result 220 | 221 | pickle_to_pt(data_filepath) 222 | logger.info(f"Done converting to pt.") 223 | if args.demo_plots: 224 | try: 225 | verify_trajectories( 226 | data_filepath, 227 | dt=T_new/record_steps, 228 | T_warmup=T_warmup, 229 | n_samples=1, 230 | ) 231 | except Exception as e: 232 | logger.error(f"Error in plotting: {e}") 233 | finally: 234 | pass 235 | return 236 | 237 | 238 | if __name__ == "__main__": 239 | args = get_args_ns2d("Generate the original FNO data for NSE in 2D") 240 | main(args) 241 | -------------------------------------------------------------------------------- /fno/data_gen/data_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import sys 6 | from collections import defaultdict 7 | from datetime import datetime 8 | 9 | import dill 10 | import h5py 11 | 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import seaborn as sns 15 | import torch 16 | import xarray 17 | from tqdm.auto import tqdm 18 | 19 | feval = lambda s: eval("lambda x, y:" + s, globals()) 20 | 21 | 22 | class TqdmLoggingHandler(logging.Handler): 23 | def __init__(self, level=logging.NOTSET): 24 | super().__init__(level) 25 | 26 | def emit(self, record): 27 | try: 28 | msg = self.format(record) 29 | tqdm.write(msg) 30 | self.flush() 31 | except Exception: 32 | self.handleError(record) 33 | 34 | 35 | def get_logger(filename, tqdm=True): 36 | stream_handler = TqdmLoggingHandler() if tqdm else logging.StreamHandler(sys.stdout) 37 | logging.basicConfig( 38 | level=logging.INFO, 39 | format="%(asctime)s - %(message)s", 40 | datefmt="%d-%b-%Y %H:%M:%S", 41 | handlers=[ 42 | logging.FileHandler(filename=filename), 43 | stream_handler, 44 | ], 45 | ) 46 | return logging.getLogger() 47 | 48 | 49 | def get_args_2d(desc="Data generation in 2D"): 50 | parser = argparse.ArgumentParser(description=desc) 51 | parser.add_argument( 52 | "--example", 53 | type=str, 54 | default=None, 55 | metavar="example name", 56 | help="data name (default: None)", 57 | ) 58 | parser.add_argument( 59 | "--grid-size", 60 | type=int, 61 | default=256, 62 | metavar="n", 63 | help="grid size (including boundary nodes) in a square domain (default: 256)", 64 | ) 65 | parser.add_argument( 66 | "--boundary", 67 | type=str, 68 | default="periodic", 69 | metavar="a", 70 | help="boundary type: periodic, dirichlet, neumann", 71 | ) 72 | parser.add_argument( 73 | "--subsample", 74 | type=int, 75 | default=1, 76 | metavar="s", 77 | help="subsample (default: 1)", 78 | ) 79 | parser.add_argument( 80 | "--diam", 81 | default=1.0, 82 | metavar="diam", 83 | help="domain is (0,d)x(0,d) (default: 1.0)", 84 | ) 85 | parser.add_argument( 86 | "--scale", 87 | default=1.0, 88 | type=float, 89 | metavar="scale", 90 | help="spatial scaling of the domain (default: 1.0)", 91 | ) 92 | parser.add_argument( 93 | "--batch-size", 94 | type=int, 95 | default=8, 96 | metavar="bsz", 97 | help="batch size for data generation (default: 8)", 98 | ) 99 | parser.add_argument( 100 | "--num-samples", 101 | type=int, 102 | default=1200, 103 | metavar="N", 104 | help="number of samples for data generation (default: 1200)", 105 | ) 106 | parser.add_argument( 107 | "--normalize", 108 | action="store_true", 109 | default=False, 110 | help="use normalized GRF in IV to have L2 norm = 1 (default: False)", 111 | ) 112 | parser.add_argument( 113 | "--double", 114 | action="store_true", 115 | default=False, 116 
| help="use double precision torch to save data", 117 | ) 118 | parser.add_argument( 119 | "--alpha", 120 | type=float, 121 | default=2.5, 122 | metavar="alpha", 123 | help="smoothness of the GRF, spatial covariance (default: 2.5)", 124 | ) 125 | parser.add_argument( 126 | "--tau", 127 | type=float, 128 | default=7.0, 129 | metavar="tau", 130 | help="strength of diagonal regularizer in the covariance (default: 7.0)", 131 | ) 132 | parser.add_argument( 133 | "--epsilon", 134 | type=float, 135 | default=1e-2, 136 | metavar="eps", 137 | help="singular coefficient in -eps*\Delta u + gamma*u= f", 138 | ) 139 | parser.add_argument( 140 | "--filepath", 141 | type=str, 142 | default=None, 143 | metavar="file path", 144 | help="path to save the data (default: None)", 145 | ) 146 | parser.add_argument( 147 | "--logpath", 148 | type=str, 149 | default=None, 150 | metavar="log path", 151 | help="path to save the logs (default: None)", 152 | ) 153 | parser.add_argument( 154 | "--filename", 155 | type=str, 156 | default=None, 157 | metavar="file name", 158 | help="file name for Navier-Stokes data (default: None)", 159 | ) 160 | parser.add_argument( 161 | "--no-cuda", action="store_true", default=False, help="disables CUDA" 162 | ) 163 | parser.add_argument( 164 | "--extra-vars", 165 | action="store_true", 166 | default=False, 167 | help="store extra variables in the data file", 168 | ) 169 | parser.add_argument( 170 | "--force-rerun", 171 | action="store_true", 172 | default=False, 173 | help="Force regenerate data even if it exists", 174 | ) 175 | parser.add_argument( 176 | "--no-tqdm", 177 | action="store_true", 178 | default=False, 179 | help="Disable program bar for data generation", 180 | ) 181 | parser.add_argument( 182 | "--verify-data", 183 | action="store_true", 184 | default=False, 185 | help="verify the generated data shape, device", 186 | ) 187 | parser.add_argument( 188 | "--seed", 189 | type=int, 190 | default=1127825, 191 | metavar="Seed", 192 | help="random seed (default: 1127825)", 193 | ) 194 | 195 | return parser 196 | 197 | def get_args_ns2d(desc="Data generation of Navier-Stokes in 2D"): 198 | parser = get_args_2d(desc=desc) 199 | parser.add_argument( 200 | "--visc", 201 | type=float, 202 | default=1e-3, 203 | metavar="viscosity", 204 | help="viscosity in front of Laplacian, 1/Re (default: 0.001)", 205 | ) 206 | parser.add_argument( 207 | "--Re", 208 | type=float, 209 | default=None, 210 | metavar="Reynolds number", 211 | help="Re (default: None)", 212 | ) 213 | parser.add_argument( 214 | "--time", 215 | type=float, 216 | default=20.0, 217 | metavar="T", 218 | help="total time for simulation (default: 20.0)", 219 | ) 220 | parser.add_argument( 221 | "--time-warmup", 222 | type=float, 223 | default=4.5, 224 | metavar="T_warmup", 225 | help="warm up for simulation (default: 4.5)", 226 | ) 227 | parser.add_argument( 228 | "--dt", 229 | type=float, 230 | default=1e-4, 231 | metavar="delta_t", 232 | help="time step size for simulation (default: 1e-4)", 233 | ) 234 | parser.add_argument( 235 | "--num-steps", 236 | type=int, 237 | default=50, 238 | metavar="nt", 239 | help="number of recorded snapshots (default: 50)", 240 | ) 241 | parser.add_argument( 242 | "--gamma", 243 | type=float, 244 | default=0.0, 245 | metavar="gamma", 246 | help="L2 coefficient in elliptic problem or NSE (drag) (default: 0.0)", 247 | ) 248 | parser.add_argument( 249 | "--forcing", 250 | type=feval, 251 | nargs="?", 252 | default="0.1*(torch.sin(2*math.pi*(x+y))+torch.cos(2*math.pi*(x+y)))", 253 | metavar="f", 254 
| help="rhs in vorticity equation in lambda x: f(x) (default: FNO's default)", 255 | ) 256 | parser.add_argument( 257 | "--peak-wavenumber", 258 | type=int, 259 | default=4, 260 | metavar="kappa", 261 | help="wavenumber of the highest energy density for the initial condition (default: 4)", 262 | ) 263 | parser.add_argument( 264 | "--max-velocity", 265 | type=float, 266 | default=5, 267 | metavar="v_max", 268 | help="the maximum speed in the init velocity field (default: 5)", 269 | ) 270 | parser.add_argument( 271 | "--replicable-init", 272 | action="store_true", 273 | default=False, 274 | help="Use the GRF on a reference max mesh size then downsample to get a replicable initial condition", 275 | ) 276 | parser.add_argument( 277 | "--no-dealias", 278 | action="store_true", 279 | default=False, 280 | help="Disable the dealias masking to the nonlinear convection term", 281 | ) 282 | parser.add_argument( 283 | "--demo", 284 | action="store_true", 285 | default=False, 286 | help="Only demo and plot several trajectories for the generated data (not save to disk)", 287 | ) 288 | 289 | return parser 290 | 291 | 292 | def save_pickle(data, save_path, append=True): 293 | mode = "ab" if append else "wb" 294 | with open(save_path, mode) as f: 295 | dill.dump(data, f) 296 | 297 | 298 | def load_pickle(load_path, mode="rb"): 299 | """ 300 | convert serialized data from pickle to pytorch pt file 301 | using dill instead of pickle 302 | https://stackoverflow.com/a/28745948/622119 303 | """ 304 | data = [] 305 | with open(load_path, mode=mode) as f: 306 | try: 307 | while True: 308 | data.append(dill.load(f)) 309 | except EOFError: 310 | pass 311 | return data 312 | 313 | 314 | def pickle_to_pt(data_path, save_path=None): 315 | """ 316 | Change: defaultdict or list is deemed not safe for serialization in PyTorch 2.6.0 317 | a workaround is to create a new dict after serialization 318 | """ 319 | save_path = data_path.replace(".pkl", ".pt") if save_path is None else save_path 320 | result = load_pickle(data_path) 321 | 322 | data = defaultdict(list) 323 | for _res in result: 324 | for field, value in _res.items(): 325 | data[field].append(value) 326 | 327 | for field, value in data.items(): 328 | v = torch.cat(value) 329 | if v.ndim == 1: # time steps or seed 330 | v = torch.unique(v) 331 | data[field] = v 332 | 333 | torch.save({k: v for k, v in data.items() if not callable(v)}, data_path) 334 | 335 | 336 | def matlab_to_pt(data_path, save_path=None): 337 | """ 338 | Convert MATLAB .mat files to PyTorch .pt files. 
339 | """
340 | save_path = data_path.replace(".mat", ".pt") if save_path is None else save_path
341 | with h5py.File(data_path, "r") as f:
342 | mat_data = {key: np.array(f[key]) for key in f.keys()}
343 |
344 | data = defaultdict(list)
345 | for key, value in mat_data.items():
346 | value = np.transpose(value, axes=range(len(value.shape) - 1, -1, -1))
347 | data[key] = torch.from_numpy(value)
348 |
349 | torch.save({k: v for k, v in data.items() if not callable(v)}, save_path)
350 |
351 |
352 | def verify_trajectories(
353 | data: dict,
354 | n_samples=5,
355 | dt=1e-3,
356 | T_warmup=4.5,
357 | diam=2 * torch.pi,
358 | ):
359 | import matplotlib
360 | matplotlib.use('TkAgg')
361 | for k, v in data.items():
362 | if isinstance(v, torch.Tensor):
363 | print(k, v.shape, v.dtype)
364 | N, T, ns, _ = data["vorticity"].shape
365 | n_samples = min(n_samples, N // 2)
366 | idxes = torch.randint(0, N, (n_samples,))
367 | gridx = gridy = torch.arange(ns) * diam / ns
368 | coords = {
369 | "time": dt * torch.arange(T) + T_warmup,
370 | "x": gridx,
371 | "y": gridy,
372 | }
373 |
374 | for idx in idxes:
375 | w_data = xarray.DataArray(
376 | data["vorticity"][idx, :T],
377 | dims=["time", "x", "y"],
378 | coords=coords,
379 | ).to_dataset(name="vorticity")
380 |
381 | g = (
382 | w_data["vorticity"]
383 | .isel(time=slice(2, None))
384 | .thin(time=T // 5)
385 | .plot.imshow(
386 | col="time",
387 | col_wrap=5,
388 | cmap=sns.cm.icefire,
389 | robust=True,
390 | xticks=None,
391 | yticks=None,
392 | cbar_kwargs={"label": f"Vorticity in Sample {idx}"},
393 | )
394 | )
395 |
396 | g.set_xlabels("")
397 | g.set_ylabels("")
398 | plt.show()
399 | -------------------------------------------------------------------------------- /fno/data_gen/grf.py: -------------------------------------------------------------------------------- 1 | import argparse
2 | import math
3 | import os
4 |
5 | import torch
6 | import torch.fft as fft
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11 |
12 |
13 | class GRF2d(nn.Module):
14 | """
15 | Params: alpha, tau
16 | Output: Gaussian random field on [0,1]^2 with mean 0
17 | and covariance operator C = (-Delta + tau^2)^(-alpha)
18 | where Delta is the Laplacian with zero Neumann boundary conditions.
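Minimal sampling sketch (hypothetical sizes):
>>> grf = GRF2d(n=64, alpha=2.5, tau=7.0)
>>> w0 = grf.sample(4, random_state=0)  # (4, 64, 64) mean-zero field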
19 |
20 | Params:
21 | - alpha: controls the smoothness of the problem
22 | alpha = 1 is somewhat limiting in 2d
23 | the closer alpha is to 1, the less smooth the isocurve is
24 |
25 | - tau: controls the high frequency modes of the problem
26 | the bigger tau is, the more oscillatory the output is
27 | # TODO: write this as an n-d function
28 | """
29 |
30 | def __init__(
31 | self,
32 | *,
33 | dim=2,
34 | n=128,
35 | alpha=2,
36 | tau=3,
37 | device: torch.device = device,
38 | dtype=torch.float,
39 | normalize=False,
40 | smoothing=False,
41 | **kwargs,
42 | ):
43 | super().__init__()
44 | self.dim, self.n = dim, n
45 | self.device = device
46 | self.dtype = dtype
47 | self.normalize = normalize
48 | self.alpha = alpha
49 | self.tau = tau
50 | self.smoothing = smoothing
51 | self.max_mesh_size = 2048
52 | self._initialize()
53 |
54 | def _initialize(self, n=None, device=None, alpha=None, tau=None, sigma=None):
55 | n = self.n if n is None else n
56 | device = self.device if device is None else device
57 |
58 | alpha = self.alpha if alpha is None else alpha
59 | tau = self.tau if tau is None else tau
60 | sigma = tau ** (0.5 * (2 * alpha - self.dim)) if sigma is None else sigma
61 |
62 | k_max = n // 2 # Nyquist freq
63 | h = 1 / n
64 |
65 | # this is basically fft.fftfreq(n)*n
66 |
67 | kx = fft.fftfreq(n, d=h, device=device)
68 | ky = fft.fftfreq(n, d=h, device=device)
69 | kx, ky = torch.meshgrid(kx, ky, indexing="ij")
70 |
71 | sqrt_eig = ((n**self.dim)
72 | * math.sqrt(2.0)
73 | * sigma
74 | * ((4 * (math.pi**2) * (kx**2 + ky**2) + tau**2) ** (-alpha / 2.0))
75 | )
76 | sqrt_eig[0, 0] = 0.0
77 | self.sqrt_eig = sqrt_eig
78 |
79 | def sample(self, bsz, n=None, random_state=0, **kwargs):
80 | if n is None or n == self.n:
81 | n = self.n
82 | elif n != self.n:
83 | self._initialize(n=n, **kwargs)
84 | else:
85 | raise ValueError
86 |
87 | mesh_size = [n for _ in range(self.dim)]
88 | torch.cuda.manual_seed(random_state)
89 | torch.random.manual_seed(random_state)
90 | if self.smoothing:
91 | # this is smoothing in the frequency domain
92 | # by interpolating the neighboring frequencies
93 | max_mesh_size = [self.max_mesh_size for _ in range(self.dim)]
94 | coeff = torch.randn(bsz, 2, *max_mesh_size, dtype=self.dtype, device=self.device)
95 | # interpolate needs the channel dimension, and needs real input
96 | # which we use to represent the real and imaginary parts
97 | coeff = F.interpolate(coeff, size=mesh_size, mode='bilinear')
98 | # because coeff is interpolated, need to call contiguous to have stride 1 to use view_as_complex
99 | # or use the simplified implementation as follows
100 | # coeff = coeff.permute(0, 2, 3, 1).contiguous()
101 | # coeff = torch.view_as_complex(coeff)
102 | else:
103 | coeff = torch.randn(bsz, 2, *mesh_size, dtype=self.dtype, device=self.device)
104 | coeff = coeff[:, 0] + 1j*coeff[:, 1]
105 |
106 | # coeff = fft.fftn(
107 | # torch.randn(bsz, *mesh_size, dtype=self.dtype, device=self.device),
108 | # dim=list(range(-1, -self.dim - 1, -1)),
109 | # )
110 |
111 | coeff = self.sqrt_eig * coeff
112 | s = fft.ifftn(coeff, dim=list(range(-1, -self.dim - 1, -1))).real
113 | if self.normalize:
114 | s = s / torch.linalg.norm(s / n, dim=(-1, -2), keepdim=True)
115 | return s
116 |
117 | def forward(self, x, **kwargs):
118 | """
119 | input: (bsz, C, n, n)
120 | """
121 | device = x.device
122 | bsz, _, *mesh_size = x.size()
123 | n = max(mesh_size)
124 | return self.sample(bsz, n=n, device=device, **kwargs)
125 |
126 |
127 | def main():
128 | parser = argparse.ArgumentParser()
129 |
parser.add_argument("--n", type=int, default=256) 130 | parser.add_argument("--bsz", type=int, default=32) 131 | parser.add_argument("--alpha", type=float, default=2) 132 | parser.add_argument("--tau", type=float, default=5) 133 | parser.add_argument("--normalize", action="store_true") 134 | parser.add_argument("--smoothing", action="store_true") 135 | args = parser.parse_args() 136 | 137 | grf = GRF2d( 138 | n=args.n, 139 | alpha=args.alpha, 140 | tau=args.tau, 141 | device=device, 142 | normalize=args.normalize, 143 | smoothing=args.smoothing, 144 | ) 145 | 146 | sample = grf.sample(args.bsz) 147 | print(sample.shape) 148 | 149 | import matplotlib.pyplot as plt 150 | import seaborn as sns 151 | from mpl_toolkits.axes_grid1 import make_axes_locatable 152 | 153 | idxes = torch.randint(0, args.bsz, (min(args.bsz//2, 4),)) 154 | fig, axs = plt.subplots(1, len(idxes), figsize=(5*len(idxes), 5)) 155 | for i, ax in enumerate(axs.flatten()): 156 | im = ax.imshow(sample[idxes[i]].cpu().numpy(), cmap=sns.cm.icefire) 157 | ax.set_title(f"GRF sample {idxes[i]}") 158 | ax.xaxis.set_visible(False) 159 | ax.yaxis.set_visible(False) 160 | divider = make_axes_locatable(ax) 161 | cax = divider.append_axes("right", size="7%", pad=0.07) 162 | fig.colorbar(im, cax=cax) 163 | plt.show() 164 | 165 | if __name__ == "__main__": 166 | main() 167 | 168 | -------------------------------------------------------------------------------- /fno/finetune.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2024 Shuhao Cao 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
9 | 10 | from __future__ import annotations 11 | 12 | from typing import Tuple 13 | 14 | import torch 15 | import torch.fft as fft 16 | import torch.nn as nn 17 | 18 | from .sfno import OutConv, SpectralConvT 19 | from data_gen.solvers import * 20 | from einops import rearrange, repeat 21 | 22 | 23 | class OutConvFT(OutConv): 24 | def __init__( 25 | self, 26 | modes_x, 27 | modes_y, 28 | modes_t, 29 | batch_size: int = 1, 30 | diam=1.0, 31 | n_grid: int = 256, 32 | out_steps=None, 33 | spatial_padding: int = 0, 34 | temporal_padding: bool = True, 35 | norm="backward", 36 | finetune=True, 37 | dealias=True, 38 | delta=5e-2, 39 | visc=1e-3, 40 | dt=1e-6, # marching step for the solver 41 | bdf_weight=(0, 1), 42 | dtype=torch.float64, 43 | debug=False, 44 | ) -> None: 45 | super().__init__( 46 | modes_x=modes_x, 47 | modes_y=modes_y, 48 | modes_t=modes_t, 49 | delta=delta, 50 | n_grid=n_grid, 51 | norm=norm, 52 | out_steps=out_steps, 53 | spatial_padding=spatial_padding, 54 | temporal_padding=temporal_padding, 55 | ) 56 | """ 57 | from latent steps to output steps 58 | finetuning 59 | n_grid is only needed for building the spectral "mesh" 60 | """ 61 | self.finetune = finetune 62 | self.out_steps = out_steps 63 | self.batch_size = batch_size 64 | self.dealias = dealias 65 | self.diam = diam 66 | self.dtype = dtype 67 | self.visc = visc 68 | self.dt = dt 69 | self.bdf_weight = bdf_weight 70 | self._initialize_fftmesh() 71 | 72 | def _initialize_fftmesh(self): 73 | 74 | kx, ky = fft_mesh_2d(self.n_grid, self.diam) 75 | kmax = self.n_grid // 2 76 | kx, ky = [repeat(z, "x y -> b x y", b=self.batch_size) for z in [kx, ky]] 77 | kx = kx[..., : kmax + 1] 78 | ky = ky[..., : kmax + 1] 79 | 80 | lap = spectral_laplacian_2d(fft_mesh=(kx, ky)) 81 | 82 | dealias_filter = ( 83 | torch.logical_and( 84 | ky.abs() <= (2.0 / 3.0) * kmax, 85 | kx.abs() <= (2.0 / 3.0) * kmax, 86 | ).to(self.dtype) 87 | if self.dealias 88 | else torch.tensor(True) 89 | ) 90 | self.register_buffer("lap", lap) 91 | self.register_buffer("kx", kx) 92 | self.register_buffer("ky", ky) 93 | self.register_buffer("dealias_filter", dealias_filter) 94 | 95 | def _update_spectral_conv_weights( 96 | self, 97 | modes_x, 98 | modes_y, 99 | modes_t, 100 | device: torch.device = None, 101 | model: nn.Module = None, 102 | debug=False, 103 | ): 104 | """ 105 | update the last spectral conv layer for fine-tuning 106 | modes_t <= out_steps // 2 + 1 but not explicitly checked 107 | """ 108 | # self.train() 109 | model = self if model is None else model 110 | old_conv = model.conv 111 | size = [1, 1, modes_x, modes_y, modes_t] 112 | conv = SpectralConvT( 113 | *size, 114 | bias=True, 115 | delta=self.delta, 116 | temporal_padding=self.temporal_padding, 117 | out_steps=self.out_steps, 118 | ).to(device) 119 | conv._reset_parameters() 120 | 121 | if not debug: 122 | mx_ = old_conv.modes_x 123 | my_ = old_conv.modes_y 124 | mt_ = old_conv.modes_t 125 | 126 | slice_x = [slice(0, mx_), slice(-mx_, None)] 127 | slice_y = [slice(0, my_), slice(-my_, None)] 128 | st = slice(0, mt_) 129 | for ix, sx in enumerate(slice_x): 130 | for iy, sy in enumerate(slice_y): 131 | old_weights = old_conv.weight[ix + 2 * iy].data 132 | old_bias = old_conv.bias[ix + 2 * iy].data 133 | conv.weight[ix + 2 * iy].data[..., sx, sy, st, :] = old_weights 134 | conv.bias[ix + 2 * iy].data[..., sx, sy, st, :] = old_bias 135 | 136 | self.conv = conv 137 | self.mode_x = modes_x 138 | self.mode_y = modes_y 139 | self.mode_t = modes_t 140 | 141 | @staticmethod 142 | def 
get_temporal_derivative(w_h, f_h, dt, weight=(0, 1), **kwargs): 143 | """ 144 | v: (b, x, y, t) 145 | kwargs needed 146 | rfftmesh: (kx, ky) 147 | laplacian: -4 * (torch.pi**2) * (kx**2 + ky**2) 148 | dealias_filter 149 | dealias optional 150 | """ 151 | w_t = [] 152 | w = [] 153 | for dt in [-dt, dt]: 154 | w_, w_t_, *_ = imex_crank_nicolson_step( 155 | w_h, 156 | f_h, 157 | delta_t=dt, 158 | **kwargs, 159 | ) 160 | w_t.append(w_t_) 161 | w.append(w_) 162 | w_t = weight[0] * w_t[0] + weight[1] * w_t[1] 163 | w = weight[0] * w[0] + weight[1] * w[1] 164 | return w, w_t 165 | 166 | def _fine_tune(self, w, f, **solver_kws): 167 | bsz, *s, nt = w.shape # s = (x, y, t) 168 | ft_kws = {"s": s, "norm": self.norm} 169 | dt = self.dt 170 | w = rearrange(w, "b x y t -> b t x y") 171 | if f is None: # for testing 172 | f = torch.zeros_like(w).to(w.device) 173 | w_h, f_h = [fft.rfftn(v, **ft_kws) for v in [w, f]] # f: (b, x, y) 174 | 175 | w_h, w_h_t = self.get_temporal_derivative(w_h, f_h, dt, **solver_kws) 176 | 177 | res_h = update_residual( 178 | w_h, 179 | w_h_t, 180 | f_h, 181 | **solver_kws, 182 | ) 183 | w, w_t, res = [fft.irfftn(v, **ft_kws).real for v in [w_h, w_h_t, res_h]] 184 | w, w_t, res = [rearrange(v, "b t x y -> b x y t") for v in [w, w_t, res]] 185 | 186 | return dict(w=w, w_t=w_t, residual=res) 187 | 188 | def forward(self, v, v_res, f=None, out_steps: int = None, original=False): 189 | """ 190 | v_latent: (b, 1, x, y, t) 191 | w_inp: (b, x, y, t) 192 | f: (b, x, y) 193 | """ 194 | solver_kws = { 195 | "visc": self.visc, 196 | "laplacian": self.lap, 197 | "dealias_filter": self.dealias_filter, 198 | "dealias": self.dealias, 199 | "rfftmesh": (self.kx, self.ky), 200 | "diam": self.diam, 201 | "weight": self.bdf_weight, 202 | } 203 | 204 | v = super().forward(v, v_res, out_steps=out_steps) 205 | 206 | if not self.finetune or original: 207 | return v 208 | else: 209 | return self._fine_tune(v, f, **solver_kws) 210 | 211 | 212 | if __name__ == "__main__": 213 | modes = 128 214 | modes_t = 6 215 | qft = OutConvFT(modes, modes, modes_t, n_grid=256, delta=1) 216 | 217 | for n_t in [10, 50]: 218 | v_latent = torch.randn(1, 1, 256, 256, n_t) 219 | w_res = torch.randn(1, 256, 256, n_t) 220 | f = torch.randn(1, 256, 256) 221 | out = qft(v_latent, w_res, f, out_steps=n_t) 222 | 223 | for k in ["w", "w_t", "residual"]: 224 | print(f"{k:<10} | shape: {list(out[k].shape)}") 225 | -------------------------------------------------------------------------------- /fno/fno3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | minor modification from the original FNO3d code: 3 | Note: the original code is from the master branch of the neural operator repo 4 | However, as of Aug 2024, the master branch has been deleted by the maintainers. 
5 | https://github.com/neuraloperator/neuraloperator/blob/master/fourier_3d.py
6 | For an unchanged fork please see
7 | https://github.com/scaomath/fourier_neural_operator/blob/master/fourier_3d.py
8 | which is up to date through the commit de514f2 with shasum
9 | de514f2adc0de483f99253d9c6630e1fb6e653f1
10 | https://github.com/scaomath/fourier_neural_operator/commit/de514f2adc0de483f99253d9c6630e1fb6e653f1
11 | """
12 |
13 | import torch
14 | import torch.fft as fft
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 |
18 |
19 | class SpectralConv3d(nn.Module):
20 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
21 | super(SpectralConv3d, self).__init__()
22 |
23 | """
24 | 3D Fourier layer. It does FFT, linear transform, and Inverse FFT.
25 | """
26 |
27 | self.in_channels = in_channels
28 | self.out_channels = out_channels
29 | self.modes1 = (
30 | modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1
31 | )
32 | self.modes2 = modes2
33 | self.modes3 = modes3
34 |
35 | self.scale = 1 / (in_channels * out_channels)
36 | self.weights1 = nn.Parameter(
37 | self.scale
38 | * torch.rand(
39 | in_channels,
40 | out_channels,
41 | self.modes1,
42 | self.modes2,
43 | self.modes3,
44 | dtype=torch.cfloat,
45 | )
46 | )
47 | self.weights2 = nn.Parameter(
48 | self.scale
49 | * torch.rand(
50 | in_channels,
51 | out_channels,
52 | self.modes1,
53 | self.modes2,
54 | self.modes3,
55 | dtype=torch.cfloat,
56 | )
57 | )
58 | self.weights3 = nn.Parameter(
59 | self.scale
60 | * torch.rand(
61 | in_channels,
62 | out_channels,
63 | self.modes1,
64 | self.modes2,
65 | self.modes3,
66 | dtype=torch.cfloat,
67 | )
68 | )
69 | self.weights4 = nn.Parameter(
70 | self.scale
71 | * torch.rand(
72 | in_channels,
73 | out_channels,
74 | self.modes1,
75 | self.modes2,
76 | self.modes3,
77 | dtype=torch.cfloat,
78 | )
79 | )
80 |
81 | # Complex multiplication
82 | def compl_mul3d(self, inp, weights):
83 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
84 | return torch.einsum("bixyz,ioxyz->boxyz", inp, weights)
85 |
86 | def forward(self, x):
87 | batchsize = x.shape[0]
88 | # Compute Fourier coefficients up to factor of e^(- something constant)
89 | x_ft = fft.rfftn(x, dim=[-3, -2, -1])
90 |
91 | # Multiply relevant Fourier modes
92 | out_ft = torch.zeros(
93 | batchsize,
94 | self.out_channels,
95 | x.size(-3),
96 | x.size(-2),
97 | x.size(-1) // 2 + 1,
98 | dtype=torch.cfloat,
99 | device=x.device,
100 | )
101 | out_ft[:, :, : self.modes1, : self.modes2, : self.modes3] = self.compl_mul3d(
102 | x_ft[:, :, : self.modes1, : self.modes2, : self.modes3], self.weights1
103 | )
104 | out_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3] = self.compl_mul3d(
105 | x_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3], self.weights2
106 | )
107 | out_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3] = self.compl_mul3d(
108 | x_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3], self.weights3
109 | )
110 | out_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3] = self.compl_mul3d(
111 | x_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3], self.weights4
112 | )
113 |
114 | # Return to physical space
115 | x = fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1)))
116 | return x
117 |
118 |
119 | class MLP(nn.Module):
120 | def __init__(self, in_channels, out_channels, mid_channels, activation=True):
121 | super(MLP, self).__init__()
122 | self.mlp1 = nn.Conv3d(in_channels, mid_channels, 1)
123
| self.mlp2 = nn.Conv3d(mid_channels, out_channels, 1)
124 | self.activation = nn.GELU() if activation else nn.Identity()
125 |
126 | def forward(self, x):
127 | for layer in [self.mlp1, self.activation, self.mlp2]:
128 | x = layer(x)
129 | return x
130 |
131 |
132 | class FNO3d(nn.Module):
133 | def __init__(
134 | self,
135 | modes1,
136 | modes2,
137 | modes3,
138 | width,
139 | dim=3,
140 | input_channel=10,
141 | num_spectral_layers=4,
142 | last_activation=False,
143 | padding=0,
144 | extra_mlp=True,
145 | channel_expansion=128,
146 | debug=False,
147 | ):
148 | super().__init__()
149 |
150 | """
151 | The overall network reimplemented.
152 |
153 | It contains n (=4 by default) layers of the Fourier layer.
154 | 1. Lift the input to the desired channel dimension by self.fc0 .
155 | 2. n layers of the integral operators u' = (W + K)(u).
156 | W defined by self.w; K defined by self.conv .
157 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 .
158 |
159 |
160 | last_activation: if True, then the last spectral layer activation is gelu, otherwise, it's linear
161 | channel_expansion: the channel expansion of the MLP after the last spectral layer
162 |
163 | input: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y), x, y, t). It's a constant function in time, except for the last index.
164 | input shape: (batchsize, x=64, y=64, t=40, c=13)
165 | output: the solution of the next 40 timesteps
166 | output shape: (batchsize, x=64, y=64, t=40, c=1)
167 | """
168 |
169 | self.modes1 = modes1
170 | self.modes2 = modes2
171 | self.modes3 = modes3
172 | self.width = width
173 | self.input_channel = input_channel
174 | self.padding = padding # pad the domain if input is non-periodic
175 | self.extra_mlp = extra_mlp
176 | self.channel_expansion = channel_expansion
177 |
178 | self.p = nn.Conv3d(input_channel + dim, self.width, 1)
179 | # input channel is 13: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y), x, y, t)
180 |
181 | self.spectral_conv = nn.ModuleList(
182 | [
183 | SpectralConv3d(width, width, modes1, modes2, modes3)
184 | for _ in range(num_spectral_layers)
185 | ]
186 | )
187 |
188 | self.mlp = nn.ModuleList(
189 | [MLP(width, width, width) for _ in range(num_spectral_layers)]
190 | )
191 |
192 | self.w = nn.ModuleList(
193 | [nn.Conv3d(width, width, 1) for _ in range(num_spectral_layers)]
194 | )
195 |
196 | self.activation = nn.ModuleList(
197 | [nn.GELU() for _ in range(num_spectral_layers - 1)]
198 | )
199 | self.activation.append(nn.GELU() if last_activation else nn.Identity())
200 |
201 | self.q = MLP(self.width, 1, self.channel_expansion, activation=last_activation)
202 | # output channel is 1: u(x, y)
203 | self.debug = debug
204 |
205 | def forward(self, x):
206 | """
207 | the treatment of grid is different from the FNO official code,
208 | which gave my autograd trouble
209 | """
210 | # bsz = x.size(0)
211 | # grid_size = self.grid.size()
212 | # grid = self.grid[None, ...].expand(bsz, *grid_size).to(x.device)
213 | # x = torch.cat((x, grid), dim=-1)
214 |
215 | x = self.p(x) # (b,13,x,y,t) -> (b,c,x,y,t)
216 |
217 | x = F.pad(
218 | x,
219 | [0, 0, self.padding, self.padding, self.padding, self.padding],
220 | mode="circular",
221 | ) # pad the domain if input is non-periodic
222 |
223 | for conv, mlp, w, nonlinear in zip(
224 | self.spectral_conv, self.mlp, self.w, self.activation
225 | ):
226 | x1 = conv(x) # (b,C,x,y,t)
227 | x1 = mlp(x1) # conv3d (N, C_{in}, D, H, W) -> (N,
C_{out}, D, H, W) 228 | x2 = w(x) 229 | x = x1 + x2 230 | x = nonlinear(x) 231 | 232 | if self.padding != 0: 233 | x = x[..., self.padding : -self.padding, self.padding : -self.padding, :] 234 | 235 | x = self.q(x) # (b,C,x,y,t) -> (b,1,x,y,t) 236 | return x.squeeze(1), None 237 | 238 | 239 | if __name__ == "__main__": 240 | modes = 8 241 | modes_t = 11 242 | width = 20 243 | model = FNO3d(modes, modes, modes_t, width, extra_mlp=True) 244 | """ 245 | torchinfo has not resolved the complex number problem 246 | """ 247 | for layer in model.children(): 248 | if hasattr(layer, "out_features"): 249 | print(layer.out_features) 250 | try: 251 | from torchinfo import summary 252 | 253 | summary(model, input_size=(5, 13, 128, 128, 40)) 254 | print("\n" * 3) 255 | model_orig = FNO3d(modes, modes, modes_t, width) 256 | summary( 257 | model_orig, input_size=(5, 13, 64, 64, 40) 258 | ) # the reported number of parameters (6563417) is not correct 259 | except ImportError as e: 260 | print(e) 261 | -------------------------------------------------------------------------------- /fno/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.tensorboard import SummaryWriter 7 | from tqdm import tqdm 8 | 9 | 10 | def default(value, d): 11 | """ 12 | helper taken from https://github.com/lucidrains/linear-attention-transformer 13 | """ 14 | return d if value is None else value 15 | 16 | 17 | current_path = os.path.abspath(__file__) 18 | SRC_ROOT = os.path.dirname(current_path) 19 | ROOT = os.path.dirname(SRC_ROOT) 20 | MODEL_PATH = default(os.environ.get("MODEL_PATH"), os.path.join(SRC_ROOT, "models")) 21 | LOG_PATH = default(os.environ.get("LOG_PATH"), os.path.join(SRC_ROOT, "logs")) 22 | DATA_PATH = default(os.environ.get("DATA_PATH"), os.path.join(ROOT, "data")) 23 | FIG_PATH = default(os.environ.get("FIG_PATH"), os.path.join(ROOT, "figures")) 24 | for p in [MODEL_PATH, LOG_PATH, DATA_PATH, FIG_PATH]: 25 | if not os.path.exists(p): 26 | os.makedirs(p) 27 | 28 | EPOCH_SCHEDULERS = [ 29 | "ReduceLROnPlateau", 30 | "StepLR", 31 | "MultiplicativeLR", 32 | "MultiStepLR", 33 | "ExponentialLR", 34 | "LambdaLR", 35 | ] 36 | 37 | 38 | def train_batch_ns( 39 | model, 40 | loss_func, 41 | data, 42 | optimizer, 43 | device, 44 | grad_clip=0, 45 | fname="vorticity", 46 | normalizer=None, 47 | ): 48 | optimizer.zero_grad() 49 | a = data[0][fname].to(device) 50 | u = data[1][fname].to(device) 51 | out = model(a) 52 | if normalizer is not None: 53 | out = normalizer[fname].inverse_transform(out) 54 | u = normalizer[fname].inverse_transform(u) 55 | 56 | loss = loss_func(out, u) 57 | 58 | loss.backward() 59 | if grad_clip > 0: 60 | nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 61 | 62 | optimizer.step() 63 | return loss 64 | 65 | 66 | def eval_epoch_ns( 67 | model, 68 | metric_func, 69 | valid_loader, 70 | device, 71 | fname="vorticity", 72 | out_steps=None, 73 | normalizer=None, 74 | return_output=False, 75 | ): 76 | model.eval() 77 | metric_vals = [] 78 | preds = [] 79 | targets = [] 80 | 81 | with torch.no_grad(): 82 | for _, data in enumerate(valid_loader): 83 | a = data[0][fname].to(device) 84 | u = data[1][fname].to(device) 85 | out = model(a, out_steps=out_steps) 86 | 87 | if normalizer is not None: 88 | out = normalizer[fname].inverse_transform(out) 89 | u = normalizer[fname].inverse_transform(u) 90 | 91 | if return_output: 92 | preds.append(out.cpu()) 93 |
targets.append(u.cpu()) 94 | 95 | metric_val = metric_func(out, u) 96 | metric_vals.append(metric_val.item()) 97 | 98 | metric = np.mean(np.asarray(metric_vals), axis=0) 99 | 100 | if return_output: 101 | return metric, torch.cat(preds), torch.cat(targets) 102 | else: 103 | return metric 104 | -------------------------------------------------------------------------------- /fno/sfno_pytest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.fft as fft 4 | 5 | from .sfno import ( 6 | HelmholtzProjection, 7 | LiftingOperator, 8 | OutConv, 9 | SFNO, 10 | SpaceTimePositionalEncoding, 11 | SpectralConvS, 12 | SpectralConvT, 13 | ) 14 | from torch_cfd.spectral import * 15 | from contextlib import contextmanager 16 | 17 | 18 | @contextmanager 19 | def set_default_dtype(dtype): 20 | old_dtype = torch.get_default_dtype() 21 | try: 22 | torch.set_default_dtype(dtype) 23 | yield 24 | finally: 25 | torch.set_default_dtype(old_dtype) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "input_shape, modes", 30 | [ 31 | ((32, 32, 5), (8, 8, 3)), 32 | ((64, 64, 10), (16, 16, 4)), 33 | ((128, 128, 20), (16, 16, 4)), 34 | ], 35 | ) 36 | def test_space_time_positional_encoding_shape(input_shape, modes): 37 | modes_x, modes_y, modes_t = modes 38 | width = 16 39 | pe = SpaceTimePositionalEncoding( 40 | modes_x=modes_x, 41 | modes_y=modes_y, 42 | modes_t=modes_t, 43 | num_channels=width, 44 | input_shape=input_shape, 45 | ) 46 | 47 | bsz = 2 48 | v = torch.randn(bsz, 1, *input_shape) 49 | output = pe(v) 50 | 51 | assert output.shape == (bsz, width, *input_shape) 52 | 53 | 54 | @pytest.mark.parametrize("n_grid", [64, 128, 256]) 55 | def test_helmholtz_fft_mesh_2d(n_grid): 56 | 57 | diam = 2 * torch.pi 58 | kx, ky = fft_mesh_2d(n_grid, diam) 59 | hzproj = HelmholtzProjection(n_grid=n_grid, diam=diam) 60 | assert (kx == hzproj.kx).all() and (ky == hzproj.ky).all() 61 | 62 | 63 | @pytest.mark.parametrize("n_grid", [64, 128, 256]) 64 | def test_helmholtz_laplacian_2d(n_grid): 65 | diam = 2 * torch.pi 66 | kx, ky = fft_mesh_2d(n_grid, diam) 67 | lap = spectral_laplacian_2d((kx, ky)) 68 | hzproj = HelmholtzProjection(n_grid=n_grid, diam=diam) 69 | assert (lap == hzproj.lap).all() 70 | 71 | 72 | @pytest.mark.parametrize("n_grid", [64, 128, 256, 512]) 73 | def test_helmholtz_divergence_free_fp32(n_grid): 74 | bsz = 2 75 | T = 6 76 | 77 | hzproj = HelmholtzProjection(n_grid=n_grid) 78 | kx, ky = hzproj.kx, hzproj.ky 79 | # Create some random vector fields 80 | vhat = [] 81 | for t in range(T): 82 | lap = hzproj.lap 83 | vhat_ = [ 84 | fft.fft2(torch.randn(bsz, n_grid, n_grid)) / (5e-1 + lap) for _ in range(2) 85 | ] 86 | vhat_ = torch.stack(vhat_, dim=1) 87 | vhat.append(vhat_) 88 | vhat = torch.stack(vhat, dim=-1) 89 | 90 | # Apply Helmholtz projection 91 | w_hat = hzproj(vhat) 92 | 93 | # Check if result is divergence free 94 | div_w_hat = hzproj.div(w_hat, (kx, ky)) 95 | div_w = fft.irfft2(div_w_hat, s=(n_grid, n_grid), dim=(1, 2)).real 96 | 97 | assert torch.linalg.norm(div_w) < 1e-5 98 | 99 | 100 | @pytest.mark.parametrize("n_grid", [64, 128, 256, 512]) 101 | def test_helmholtz_divergence_free_fp64(n_grid): 102 | with set_default_dtype(torch.float64): 103 | bsz = 2 104 | T = 6 105 | diam = 2 * torch.pi 106 | hzproj = HelmholtzProjection(n_grid=n_grid, diam=diam, dtype=torch.float64) 107 | kx, ky = hzproj.kx, hzproj.ky 108 | lap = hzproj.lap 109 | 110 | # Create some random vector fields 111 | vhat = [] 112 | for t in range(T):
113 | vhat_ = [ 114 | fft.fft2(torch.randn(bsz, n_grid, n_grid, dtype=torch.float64)) 115 | / (5e-1 + lap) 116 | for _ in range(2) 117 | ] 118 | vhat_ = torch.stack(vhat_, dim=1) 119 | vhat.append(vhat_) 120 | vhat = torch.stack(vhat, dim=-1) 121 | 122 | # Apply Helmholtz projection 123 | w_hat = hzproj(vhat) 124 | 125 | # Check if result is divergence free 126 | div_w_hat = hzproj.div(w_hat, (kx, ky)) 127 | div_w = fft.irfft2(div_w_hat, s=(n_grid, n_grid), dim=(1, 2)).real 128 | 129 | assert torch.linalg.norm(div_w) < 1e-12 130 | 131 | 132 | @pytest.mark.parametrize( 133 | "input_shape, modes", 134 | [ 135 | ((32, 32, 5), (8, 8, 3)), 136 | ((64, 64, 10), (16, 16, 4)), 137 | ((128, 128, 20), (16, 16, 4)), 138 | ], 139 | ) 140 | def test_lifting_operator_shape(input_shape, modes): 141 | width = 16 142 | modes_x, modes_y, modes_t = modes 143 | latent_steps = 5 # latent_steps should be <= time steps 144 | 145 | lifting = LiftingOperator( 146 | width=width, 147 | modes_x=modes_x, 148 | modes_y=modes_y, 149 | modes_t=modes_t, 150 | latent_steps=latent_steps, 151 | ) 152 | 153 | bsz = 8 154 | nx, ny, nt = input_shape 155 | v = torch.randn(bsz, 1, nx, ny, nt) 156 | 157 | output = lifting(v) 158 | 159 | assert output.shape == (bsz, width, nx, ny, latent_steps) 160 | 161 | 162 | @pytest.mark.parametrize( 163 | "input_shape, modes", 164 | [ 165 | ((32, 32, 12), (8, 8, 3)), 166 | ((64, 64, 15), (16, 16, 4)), 167 | ((128, 128, 20), (32, 32, 10)), 168 | ], 169 | ) 170 | def test_out_conv_shape(input_shape, modes): 171 | modes_x, modes_y, modes_t = modes 172 | out_dim = 1 173 | 174 | out_conv = OutConv( 175 | modes_x=modes_x, 176 | modes_y=modes_y, 177 | modes_t=modes_t, 178 | out_dim=out_dim, 179 | ) 180 | 181 | bsz = 8 182 | nx, ny, nt = input_shape 183 | latent_steps = 10 184 | out_steps = 40 185 | 186 | v = torch.randn(bsz, out_dim, nx, ny, latent_steps) 187 | v_res = torch.randn(bsz, nx, ny, nt) # Input with 5 steps 188 | 189 | output = out_conv(v, v_res, out_steps=out_steps) 190 | 191 | assert output.shape == (bsz, nx, ny, out_steps) 192 | 193 | 194 | @pytest.mark.parametrize( 195 | "input_shape, modes", 196 | [ 197 | ((32, 32, 12), (8, 8, 3)), 198 | ((64, 64, 15), (16, 16, 4)), 199 | ((128, 128, 20), (32, 32, 10)), 200 | ], 201 | ) 202 | def test_spectral_conv_s(input_shape, modes): 203 | in_channels = 16 204 | out_channels = 16 205 | modes_x, modes_y, modes_t = modes 206 | 207 | conv = SpectralConvS( 208 | in_channels=in_channels, 209 | out_channels=out_channels, 210 | modes_x=modes_x, 211 | modes_y=modes_y, 212 | modes_t=modes_t, 213 | ) 214 | 215 | bsz = 2 216 | nx, ny, nt = input_shape 217 | v = torch.randn(bsz, in_channels, nx, ny, nt) 218 | 219 | output = conv(v) 220 | 221 | assert output.shape == (bsz, out_channels, nx, ny, nt) 222 | 223 | 224 | @pytest.mark.parametrize( 225 | "out_steps", 226 | [10, 20, 40], 227 | ) 228 | def test_spectral_conv_t_with_different_out_steps(out_steps): 229 | in_channels = 16 230 | out_channels = 16 231 | modes_x, modes_y, modes_t = 8, 8, 4 232 | 233 | conv = SpectralConvT( 234 | in_channels=in_channels, 235 | out_channels=out_channels, 236 | modes_x=modes_x, 237 | modes_y=modes_y, 238 | modes_t=modes_t, 239 | ) 240 | 241 | bsz = 2 242 | nx, ny, nt = 64, 64, 10 243 | 244 | v = torch.randn(bsz, in_channels, nx, ny, nt) 245 | 246 | output = conv(v, out_steps=out_steps) 247 | 248 | assert output.shape == (bsz, out_channels, nx, ny, out_steps) 249 | 250 | @pytest.mark.parametrize( 251 | "mesh_size", 252 | [ 253 | (64, 64, 10), 254 | (128, 128, 20), 255 | (256, 
256, 40), 256 | ], 257 | ) 258 | def test_sfno_with_different_input_sizes(mesh_size): 259 | modes = 8 260 | modes_t = 4 261 | width = 16 262 | bsz = 2 263 | # note: input steps >= latent steps 264 | model = SFNO(modes, modes, modes_t, width) 265 | 266 | x = torch.randn(bsz, *mesh_size) 267 | pred = model(x) 268 | 269 | # Output should match input size if out_steps is not specified 270 | assert pred.shape == (bsz, *mesh_size) 271 | 272 | 273 | @pytest.mark.parametrize( 274 | "out_steps", 275 | [10, 20, 40], 276 | ) 277 | def test_sfno_forward_with_different_output_steps(out_steps): 278 | modes = 8 279 | modes_t = 4 280 | width = 16 281 | 282 | model = SFNO( 283 | modes_x=modes, 284 | modes_y=modes, 285 | modes_t=modes_t, 286 | width=width, 287 | output_steps=out_steps, # Default output steps 288 | ) 289 | 290 | bsz = 2 291 | nx, ny, nt = 64, 64, 10 292 | x = torch.randn(bsz, nx, ny, nt) 293 | 294 | # Should use default output steps 295 | pred = model(x) 296 | assert pred.shape == (bsz, nx, ny, out_steps) 297 | -------------------------------------------------------------------------------- /fno/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import os 4 | import subprocess 5 | import sys 6 | from contextlib import contextmanager 7 | from time import ctime, time 8 | from typing import Generator, Callable 9 | 10 | import numpy as np 11 | import psutil 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | def get_seed(s, quiet=True, cudnn=True, logger=None): 17 | os.environ["PYTHONHASHSEED"] = str(s) 18 | np.random.seed(s) 19 | # Torch 20 | torch.manual_seed(s) 21 | torch.cuda.manual_seed(s) 22 | if cudnn: 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = False 25 | if torch.cuda.is_available(): 26 | torch.cuda.manual_seed_all(s) 27 | message = f"""""" 28 | message += f""" 29 | os.environ['PYTHONHASHSEED'] = str({s}) 30 | numpy.random.seed({s}) 31 | torch.manual_seed({s}) 32 | torch.cuda.manual_seed({s}) 33 | """ 34 | if cudnn: 35 | message += f""" 36 | torch.backends.cudnn.deterministic = True 37 | torch.backends.cudnn.benchmark = False""" 38 | 39 | if torch.cuda.is_available(): 40 | message += f""" 41 | torch.cuda.manual_seed_all({s})""" 42 | 43 | if not quiet and not logger: 44 | print("\n") 45 | print(f"The following code snippets have been run.") 46 | print("=" * 50) 47 | print(message) 48 | print("=" * 50) 49 | elif not quiet and logger: 50 | logger.info( 51 | "The following code snippets have been run:" 52 | + " | ".join(message.splitlines()) 53 | ) 54 | 55 | 56 | class Colors: 57 | """Defining Color Codes to color the text displayed on terminal.""" 58 | 59 | red = "\033[91m" 60 | green = "\033[92m" 61 | yellow = "\033[93m" 62 | blue = "\033[94m" 63 | magenta = "\033[95m" 64 | end = "\033[0m" 65 | 66 | 67 | def color(string: str, color: Colors = Colors.yellow) -> str: 68 | return f"{color}{string}{Colors.end}" 69 | 70 | 71 | @contextmanager 72 | def timer(label: str = "", compact=False, quiet=False) -> Generator[None, None, None]: 73 | """ 74 | https://www.kaggle.com/c/riiid-test-answer-prediction/discussion/203020#1111022 75 | print 76 | 1. the time the code block takes to run 77 | 2. the memory usage. 
78 | """ 79 | p = psutil.Process(os.getpid()) 80 | m0 = p.memory_info()[0] / 2.0**30 81 | start = time() # Setup - __enter__ 82 | if not compact and not quiet: 83 | print(color(f"{label}:\nStart at {ctime(start)};", color=Colors.blue)) 84 | try: 85 | yield # yield to body of `with` statement 86 | finally: # Teardown - __exit__ 87 | m1 = p.memory_info()[0] / 2.0**30 88 | delta = m1 - m0 89 | sign = "+" if delta >= 0 else "-" 90 | delta = math.fabs(delta) 91 | end = time() 92 | print( 93 | color( 94 | f"Done at {ctime(end)} ({end - start:.6f} secs elapsed);", 95 | color=Colors.blue, 96 | ) 97 | ) 98 | print(color(f"\nLocal RAM usage at START: {m0:.2f} GB", color=Colors.green)) 99 | print( 100 | color( 101 | f"Local RAM usage at END: {m1:.2f}GB ({sign}{delta:.2f}GB)", 102 | color=Colors.green, 103 | ) 104 | ) 105 | print("\n") 106 | elif compact and not quiet: 107 | yield 108 | print( 109 | color( 110 | f"{label} - done in {time() - start:.6f} seconds. \n", color=Colors.blue 111 | ) 112 | ) 113 | else: 114 | try: 115 | yield 116 | finally: 117 | pass 118 | 119 | 120 | def pretty_tensor_size(size): 121 | """Pretty prints a torch.Size object""" 122 | assert isinstance(size, torch.Size) 123 | return " x ".join(map(str, size)) 124 | 125 | 126 | def get_size(bytes, suffix="B"): 127 | """ 128 | by Fred Cirera, https://stackoverflow.com/a/1094933/1870254, modified 129 | Scale bytes to its proper format 130 | e.g: 131 | 1253656 => '1.20MiB' 132 | 1253656678 => '1.17GiB' 133 | """ 134 | for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: 135 | if abs(bytes) < 1024.0: 136 | return f"{bytes:3.2f} {unit}{suffix}" 137 | bytes /= 1024.0 138 | return f"{bytes:3.2f} 'Yi'{suffix}" 139 | 140 | 141 | def dump_tensors(gpu_only=True): 142 | """Prints a list of the Tensors being tracked by the garbage collector.""" 143 | import gc 144 | 145 | total_size = 0 146 | for obj in gc.get_objects(): 147 | try: 148 | if torch.is_tensor(obj): 149 | if not gpu_only or obj.is_cuda: 150 | print( 151 | "%s:%s%s %s" 152 | % ( 153 | type(obj).__name__, 154 | " GPU" if obj.is_cuda else "", 155 | " pinned" if obj.is_pinned else "", 156 | pretty_tensor_size(obj.size()), 157 | ) 158 | ) 159 | total_size += obj.numel() 160 | elif hasattr(obj, "data") and torch.is_tensor(obj.data): 161 | if not gpu_only or obj.is_cuda: 162 | info = ( 163 | f"{type(obj).__name__} → {type(obj.data).__name__}" + " GPU" 164 | if obj.is_cuda 165 | else ( 166 | "" + " pinned" 167 | if obj.data.is_pinned 168 | else ( 169 | "" + " grad" 170 | if obj.requires_grad 171 | else ( 172 | "" + " volatile" 173 | if obj.volatile 174 | else "" + f" {pretty_tensor_size(obj.data.size())}" 175 | ) 176 | ) 177 | ) 178 | ) 179 | print(info) 180 | total_size += obj.data.numel() 181 | except Exception as e: 182 | pass 183 | print("Total size:", get_size(total_size)) 184 | 185 | 186 | def get_num_params(model): 187 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 188 | num_params = 0 189 | for p in model_parameters: 190 | num_params += np.prod(p.size() + (2,) if p.is_complex() else p.size()) 191 | return num_params 192 | 193 | 194 | def get_config(module: nn.Module, quiet=True, logger=None): 195 | config = {} 196 | _config = filter(lambda x: not x.startswith("_"), dir(module)) 197 | for a in _config: 198 | if not isinstance(getattr(module, a), (Callable, nn.Parameter, torch.Tensor)): 199 | config[a] = getattr(module, a) 200 | if not quiet and not logger: 201 | for k, v in config.items(): 202 | print(f"{k:<25}: {v}") 203 | elif logger: 204 | 
logger.info(f"args of {module.__repr__()}: "+" | ".join(f"{k}={v}" for k, v in config.items())) 205 | return config 206 | 207 | 208 | def default(value, d): 209 | """ 210 | helper taken from https://github.com/lucidrains/linear-attention-transformer 211 | """ 212 | return d if value is None else value 213 | 214 | 215 | def clones(module, N): 216 | """ 217 | Input: 218 | - module: nn.Module obj 219 | Output: 220 | - zip identical N layers (not stacking) 221 | 222 | Refs: 223 | - https://nlp.seas.harvard.edu/2018/04/03/attention.html 224 | """ 225 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 226 | 227 | 228 | def print_config(model: nn.Module) -> None: 229 | try: 230 | for a in model.config.keys(): 231 | print(f"{a:<25}: ", getattr(model, a)) 232 | except: 233 | config = filter(lambda x: not x.startswith("__"), dir(model)) 234 | for a in config: 235 | print(f"{a:<25}: ", getattr(model, a)) 236 | 237 | def check_nan(tensor, tensor_name=""): 238 | if tensor.isnan().any(): 239 | tensor = tensor[~torch.isnan(tensor)] 240 | raise ValueError(f"{tensor_name} has nan with norm {torch.linalg.norm(tensor)}") 241 | 242 | def get_core_optimizer(name: str): 243 | """ 244 | ASGD Adadelta Adagrad Adam AdamW Adamax LBFGS NAdam Optimizer RAdam RMSprop Rprop SGD 245 | """ 246 | import torch.optim as optim 247 | return getattr(optim, name) 248 | 249 | if __name__ == "__main__": 250 | get_seed(42) 251 | else: 252 | with timer("", quiet=True): 253 | try: 254 | import plotly.express as px 255 | import plotly.figure_factory as ff 256 | import plotly.graph_objects as go 257 | import plotly.io as pio 258 | except ImportError as err: 259 | sys.stderr.write(f"Error: failed to import module ({err})") 260 | subprocess.check_call([sys.executable, "-m", "pip", "install", "plotly"]) 261 | -------------------------------------------------------------------------------- /fno/visualizations.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import plotly.express as px 4 | import plotly.figure_factory as ff 5 | import plotly.graph_objects as go 6 | import seaborn as sns 7 | import torch 8 | import torch.fft as fft 9 | 10 | import xarray 11 | from mpl_toolkits.axes_grid1 import make_axes_locatable 12 | 13 | 14 | def plot_contour(z, func=plt.imshow, **kwargs): 15 | if isinstance(z, torch.Tensor): 16 | z = z.cpu().numpy() 17 | _, ax = plt.subplots(figsize=(3, 3)) 18 | f = func(z, cmap=sns.cm.icefire) 19 | ax.xaxis.set_visible(False) 20 | ax.yaxis.set_visible(False) 21 | divider = make_axes_locatable(ax) 22 | cax = divider.append_axes("right", size="7%", pad=0.1) 23 | cbar = plt.colorbar(f, ax=ax, cax=cax) 24 | cbar.ax.tick_params(labelsize=10) 25 | cbar.ax.locator_params(nbins=9) 26 | cbar.update_ticks() 27 | 28 | 29 | def plot_contour_plotly( 30 | z, 31 | colorscale="RdYlBu", 32 | showscale=False, 33 | showlabels=False, 34 | continuous_coloring=False, 35 | reversescale=True, 36 | dimensions=(200, 200), 37 | line_smoothing=0.7, 38 | ncontours=20, 39 | **plot_kwargs, 40 | ): 41 | """ 42 | show 2D solution z of its contour 43 | colorscale: balance (MATLAB new) or Jet (MATLAB old) 44 | """ 45 | 46 | if not plot_kwargs: 47 | plot_kwargs = dict( 48 | contour_kwargs=dict( 49 | colorscale=colorscale, 50 | line_smoothing=line_smoothing, 51 | line_width=0.1, 52 | ncontours=ncontours, 53 | reversescale=reversescale, 54 | # ) 55 | ), 56 | figure_kwargs=dict( 57 | layout={ 58 | "xaxis": { 59 | "title": "x-label", 60 | 
"visible": False, 61 | "showticklabels": False, 62 | }, 63 | "yaxis": { 64 | "title": "y-label", 65 | "visible": False, 66 | "showticklabels": False, 67 | }, 68 | } 69 | ), 70 | layout_kwargs=dict( 71 | margin=dict(l=0, r=0, t=0, b=0), 72 | width=dimensions[0], 73 | height=dimensions[1], 74 | template="plotly_white", 75 | ), 76 | ) 77 | 78 | contour_kwargs = plot_kwargs["contour_kwargs"] 79 | figure_kwargs = plot_kwargs["figure_kwargs"] 80 | layout_kwargs = plot_kwargs["layout_kwargs"] 81 | if showscale: 82 | contour_kwargs["showscale"] = True 83 | contour_kwargs["colorbar"] = dict( 84 | thickness=0.15 * layout_kwargs["height"], 85 | tickwidth=0.3, 86 | exponentformat="e", 87 | ) 88 | layout_kwargs["width"] = 1.32 * layout_kwargs["height"] 89 | else: 90 | contour_kwargs["showscale"] = False 91 | 92 | if continuous_coloring: 93 | contour_kwargs["contours_coloring"] = "heatmap" 94 | 95 | if showlabels: 96 | contour_kwargs["contours"] = dict( 97 | coloring="heatmap", 98 | showlabels=True, # show labels on contours 99 | labelfont=dict( # label font properties 100 | size=12, 101 | color="gray", 102 | ), 103 | ) 104 | 105 | uplot = go.Contour(z=z, **contour_kwargs) 106 | fig = go.Figure(data=uplot, **figure_kwargs) 107 | if "template" not in layout_kwargs.keys(): 108 | fig.update_layout(template="plotly_dark", **layout_kwargs) 109 | else: 110 | fig.update_layout(**layout_kwargs) 111 | return fig 112 | 113 | 114 | def get_enstrophy_spectrum(vorticity, h): 115 | if isinstance(vorticity, np.ndarray): 116 | vorticity = torch.from_numpy(vorticity) 117 | n = vorticity.shape[0] 118 | kx = fft.fftfreq(n, d=h) 119 | ky = fft.fftfreq(n, d=h) 120 | kx, ky = torch.meshgrid([kx, ky], indexing="ij") 121 | kmax = n // 2 122 | kx = kx[..., : kmax + 1] 123 | ky = ky[..., : kmax + 1] 124 | k2 = (4 * torch.pi**2) * (kx**2 + ky**2) 125 | k2[0, 0] = 1.0 126 | 127 | wh = fft.rfft2(vorticity) 128 | 129 | tke = (0.5 * wh * wh.conj()).real 130 | kmod = torch.sqrt(k2) 131 | k = torch.arange(1, kmax, dtype=torch.float64) # Nyquist limit for this grid 132 | Ens = torch.zeros_like(k) 133 | dk = (torch.max(k) - torch.min(k)) / (2 * n) 134 | for i in range(len(k)): 135 | Ens[i] += (tke[(kmod < k[i] + dk) & (kmod >= k[i] - dk)]).sum() 136 | 137 | Ens = Ens / Ens.sum() 138 | return Ens 139 | 140 | 141 | def plot_enstrophy_spectrum( 142 | fields: list, 143 | h=None, 144 | slope=5, 145 | factor=None, 146 | cutoff=1e-15, 147 | plot_cutoff_factor=1 / 8, 148 | labels=None, 149 | title=None, 150 | legend_loc="upper right", 151 | fontsize=15, 152 | subplot_kw={"figsize": (5, 5), "dpi": 100, "facecolor": "w"}, 153 | **kwargs, 154 | ): 155 | for k, field in enumerate(fields): 156 | if isinstance(field, np.ndarray): 157 | fields[k] = torch.from_numpy(field) 158 | if labels is None: 159 | labels = [f"Field {i}" for i in range(len(fields))] 160 | n = fields[0].shape[0] 161 | if h is None: 162 | h = 1 / n 163 | kmax = n // 2 164 | k = torch.arange(1, kmax, dtype=torch.float64) # Nyquist limit for this grid 165 | Es = [get_enstrophy_spectrum(field, h) for field in fields] 166 | if factor is None: 167 | factor = Es[-1].quantile(0.8) / (k[-1] ** (-slope)) 168 | # print(factor) 169 | 170 | fig, ax = plt.subplots(**subplot_kw) 171 | plot_cutoff = int(n * plot_cutoff_factor) 172 | for i, E in enumerate(Es): 173 | if cutoff is not None: 174 | E[E < cutoff] = np.nan 175 | E[-plot_cutoff:] = np.nan 176 | plt.loglog(k, E, label=f"{labels[i]}") 177 | 178 | plt.loglog( 179 | k[:-plot_cutoff], 180 | (factor * k ** (-slope))[:-plot_cutoff], 181 | "b--", 
182 | label=f"$O(k^{{{-slope:.3g}}})$", 183 | ) 184 | plt.grid(True, which="both", ls="--", linewidth=0.4) 185 | plt.autoscale(enable=True, axis="x", tight=True) 186 | plt.legend(fontsize=fontsize, loc=legend_loc) 187 | plt.title(title, fontsize=fontsize) 188 | plt.xlabel("Wavenumber", fontsize=fontsize) 189 | ax.xaxis.set_tick_params(labelsize=fontsize) 190 | ax.yaxis.set_tick_params(labelsize=fontsize) 191 | 192 | 193 | def plot_contour_trajectory( 194 | field, 195 | num_snapshots=5, 196 | col_wrap=5, 197 | contourf=False, 198 | T_start=4.5, 199 | dt=1e-1, 200 | title=None, 201 | cb_kws=dict(orientation="vertical", pad=0.01, aspect=10), 202 | subplot_kws=dict( 203 | xticks=[], 204 | yticks=[], 205 | ylabel="", 206 | xlabel="", 207 | ), 208 | **plot_kws, 209 | ): 210 | """ 211 | plot trajectory using xarray's imshow or contourf wrapper 212 | """ 213 | field = field.detach().cpu().numpy() 214 | *size, T = field.shape 215 | grid = np.linspace(0, 1, size[0] + 1)[:-1] 216 | time = np.arange(T) * dt + T_start 217 | coords = { 218 | "x": grid, 219 | "y": grid, 220 | "t": time, 221 | } 222 | ds = xarray.DataArray(field, dims=["x", "y", "t"], coords=coords) 223 | t_steps = T // num_snapshots 224 | T_rem = T % num_snapshots 225 | ds = ds.isel(t=slice(T_rem, None)).thin({"t": t_steps}) 226 | plot_func = ds.plot.contourf if contourf else ds.plot.imshow 227 | 228 | # fig = plt.figure() 229 | _plot_kws = dict( 230 | col_wrap=col_wrap, 231 | cmap=sns.cm.icefire, 232 | interpolation="hermite", 233 | robust=True, 234 | add_colorbar=True, 235 | xticks=None, 236 | yticks=None, 237 | size=3, 238 | aspect=1, 239 | ) 240 | _plot_kws.update(plot_kws) 241 | 242 | im = plot_func( 243 | col="t", 244 | subplot_kws=subplot_kws, 245 | cbar_kwargs=cb_kws, 246 | **_plot_kws, 247 | ) 248 | if title is not None: 249 | im.fig.suptitle(title, y=0.05) 250 | # plt.show() 251 | 252 | return im 253 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl_py==2.2.2 2 | dill==0.4.0 3 | einops==0.8.1 4 | h5py==3.13.0 5 | matplotlib==3.10.3 6 | numpy==2.2.6 7 | plotly==6.0.1 8 | psutil==7.0.0 9 | pytest==8.3.5 10 | scipy==1.15.3 11 | seaborn==0.13.2 12 | tensordict==0.7.2 13 | torch==2.6.0 14 | tqdm==4.67.1 15 | xarray==2025.3.1 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'torch-cfd', 5 | packages=find_packages(include=['torch_cfd', 'torch_cfd.*']), 6 | version='{{VERSION_PLACEHOLDER}}', 7 | license='Apache-2.0', 8 | description = 'PyTorch CFD', 9 | long_description='PyTorch Computational Fluid Dynamics Library', 10 | long_description_content_type="text/markdown", 11 | author = 'Shuhao Cao', 12 | author_email = 'scao.math@gmail.com', 13 | url = 'https://github.com/scaomath/torch-cfd', 14 | keywords = ['pytorch', 'cfd', 'pde', 'spectral', 'fluid dynamics', 'deep learning', 'neural operator'], 15 | python_requires='>=3.10', 16 | install_requires=[ 17 | 'numpy>=2.2.0', 18 | 'torch>=2.5.0', 19 | 'xarray>=2025.3.1', 20 | 'tqdm>=4.62.0', 21 | 'einops>=0.8.0', 22 | 'dill>=0.4.0', 23 | 'matplotlib>=3.5.0', 24 | 'seaborn>=0.13.0', 25 | ], 26 | classifiers=[ 27 | 'Development Status :: 4 - Beta', 28 | 'Intended Audience :: Science/Research', 29 | 'Topic :: Scientific/Engineering :: Mathematics', 30 | 
'License :: OSI Approved :: Apache Software License', 31 | 'Programming Language :: Python :: 3.10', 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /torch_cfd/README.md: -------------------------------------------------------------------------------- 1 | # TODO 2 | 3 | - [x] add native PyTorch implementation for applying `torch.linalg` and `torch.fft` functions directly on `GridArray` and `GridVariable` (added in 0.1.0). 4 | - [x] add discrete Helmholtz decomposition (pressure projection) in both spatial and spectral domains (added in 0.0.1). 5 | - [x] adjust the functions and routines to act on `(batch, time, *spatial)` tensors; currently only `(*spatial)` is supported (added for key routines in 0.0.1). 6 | - [x] add native FFT-based vorticity computation, instead of taking finite differences for pseudo-spectral (added in 0.0.4). 7 | - [ ] add no-slip boundary. 8 | 9 | # Changelog 10 | 11 | ### 0.2.0 12 | 13 | After version `0.1.0`, I began prompting with the existing code in VSCode Copilot (using the OpenAI Enterprise API kindly provided by UM), which arguably significantly improved the "porting->debugging->refactoring" cycle. I recorded several good refactoring suggestions by GPT o4-mini and some by ***Claude Sonnet 3.7*** here. There were definitely over-complicated "poor" refactoring suggestions, which have been stashed after benchmarking. I found that Sonnet 3.7 is exceptionally good at providing templates for me to fill in the details, when it is properly prompted with details of the functionality of the current code. Another highlight is that, based on the errors or exceptions raised in the unittests, Sonnet 3.7 directly added configurations in `.vscode/launch.json`, saving me quite some time copy-pasting boilerplate and then changing it by hand. 14 | 15 | #### Major change: batch dimension for FVM 16 | The finite volume solver now accepts the batch dimension; some key updates include 17 | - Re-implemented flux computations of $(\boldsymbol{u}\cdot\nabla)\boldsymbol{u}$ as `nn.Module`. I originally implemented a tensor-only version that pre-assigned `target_offsets`, but it did not quite work and was buggy for the second component of the velocity. Sonnet 3.7 provided a very good refactoring template after being given both the original code and my implementation, after which I pretty much just filled in the blanks in [`advection.py`](./advection.py). Later I found out the bug was pretty silly on my side, going from 18 | ```python 19 | for i in range(2): 20 | u = v[i] 21 | for j in range(2): 22 | v[i] = flux_interpolation(u, v[j]) # offset is updated here 23 | ``` 24 | to 25 | ```python 26 | # this is gonna be buggy of course because the offset alignment will go wrong 27 | # the target_offsets are looped inside flux_interpolation of AdvectionVanLeer 28 | for offset in target_offsets: 29 | u = v[i]; u.offset = offset 30 | for j in range(2): 31 | v[i] = flux_interpolation(u, v[j]) 32 | v[i].offset = offset 33 | ``` 34 | The fixed version, which loops inside `__call__` of the `AdvectionVanLeer` class, [is here](./advection.py#L451). 35 | - Implemented a `_constant_pad_tensor` function to improve the behavior of `F.pad` and to help impose non-homogeneous boundary conditions. It uses naturally ordered `pad` args (like Jax, unlike `F.pad`), while taking the batch dimension into consideration. 36 | - Changed the behavior of `u.shift` to take the batch dimension into consideration.
In general, these methods within the `bc` class or `GridVariable` start from the last dimension instead of the first, e.g., `for dim in range(u.grid.ndim): ...` changes to `for dim in range(-u.grid.ndim, 0): ...`. 37 | 38 | 39 | #### Retaining only `GridVariable` class 40 | This refactoring was suggested by ***Claude Sonnet 3.7***. In [`grids.py`](./grids.py#442), following `numpy`'s practice in `np.lib.mixins.NDArrayOperatorsMixin` (see the update notes in [0.0.1](#001)), I originally implemented two mixin classes, [`GridArrayOperatorsMixin`](https://github.com/scaomath/torch-cfd/blob/475c7385549225570b61d8c3dcf1d415d8977f19/torch_cfd/grids.py#L304) and [`GridVariableOperatorsMixin`](https://github.com/scaomath/torch-cfd/blob/475c7385549225570b61d8c3dcf1d415d8977f19/torch_cfd/grids.py#L616), using the same boilerplate to enable operations such as `v + u` or computing the upwinding flux for two `GridArray` instances: 41 | ```python 42 | def _binary_method(name, op): 43 | def method(self, other): 44 | ... 45 | method.__name__ = f"__{name}__" 46 | return method 47 | 48 | class GridArrayOperatorsMixin: 49 | __slots__ = () 50 | __lt__ = _binary_method("lt", operator.lt) 51 | ... 52 | 53 | @dataclasses.dataclass 54 | class GridArray(GridArrayOperatorsMixin): 55 | 56 | @classmethod 57 | def __torch_function__(self, ufunc, types, args=(), kwargs=None): 58 | ... 59 | ``` 60 | `GridVariable` is implemented largely the same way, recycling the code. Note that `GridVariable` is only a container that wraps a `GridArray` together with the boundary conditions of a field, whereas `GridArray`, arguably the more vital piece of the whole scheme, determines an array `v`'s location by its `offset` (cell center, faces, or nodes) in Jax-CFD's original design. After a detailed prompt introducing each class's functions and after reading my workspace, **Sonnet 3.7** suggested keeping only a single `GridVariable` class: when performing binary methods on two fields with the same offsets, the boundary conditions are set to `None` if the two fields don't share the same bc. This is already the case for some flux computations in the original `Jax-CFD`, but implemented in a more hard-coded way. Now the implementation is much more concise, and the boundary condition for flux computation is handled automatically. 61 | 62 | #### Adding a GridVectorBase class 63 | Yet again, ***Claude Sonnet 3.7*** gave awesome refactoring advice here. In `0.1.0`'s `grids.py`, the vector field wrappers recycle lots of [boilerplate code I learned from numpy back in 0.0.1](https://github.com/scaomath/torch-cfd/blob/475c7385549225570b61d8c3dcf1d415d8977f19/torch_cfd/grids.py#L801). The code is largely the same for `GridArray` and `GridVariable`, defining their behaviors when performing `__add__` and `__mul__` with a scalar, etc.: 64 | ```python 65 | class GridArrayVector(tuple): 66 | def __new__(cls, arrays): 67 | ... 68 | 69 | def __add__(self, other): 70 | ... 71 | 72 | __radd__ = __add__ 73 | 74 | class GridVariableVector(tuple): 75 | def __new__(cls, variables): 76 | ... 77 | 78 | def __add__(self, other): 79 | # largely the same 80 | ...
81 | __radd__ = __add__ 82 | ``` 83 | The refactored code by Sonnet 3.7 is just amazing, cleverly exploiting the `classmethod` decorator and `super()`: 84 | ```python 85 | from typing import Generic, Sequence, TypeVar 86 | T = TypeVar("T") 87 | class GridVectorBase(tuple, Generic[T]): 88 | 89 | def __new__(cls, v: Sequence[T]): 90 | if not all(isinstance(x, cls._element_type()) for x in v): 91 | raise TypeError 92 | return super().__new__(cls, v) 93 | 94 | @classmethod 95 | def _element_type(cls): 96 | raise NotImplementedError 97 | 98 | def __add__(self, other): 99 | ... 100 | __radd__ = __add__ 101 | 102 | 103 | class GridVariableVector(GridVectorBase[GridVariable]): 104 | @classmethod 105 | def _element_type(cls): 106 | return GridVariable 107 | 108 | ``` 109 | 110 | #### Unittests 111 | Another great feat by ***Sonnet 3.7*** is coming up with unittests using `absl.testing`'s parametrized testing. Based on [`test_grids.py`](tests/test_grids.py), which I ported and tweaked by hand example-wise, Sonnet 3.7 generated [corresponding tests using finite differences](tests/test_finite_differences.py). Even though its "reasoning" regarding numerical PDEs is sometimes wrong, for example when working out what the shape should be after trimming the boundary for MAC grid variables, most tests are correctly formulated and helped figure out several bugs in the batch implementation of the finite volume method. 112 | 113 | 114 | ### 0.1.0 115 | - Implemented the FVM method on a staggered MAC grid (pressure on cell centers). 116 | - Added native PyTorch implementation for applying `torch.linalg` and `torch.fft` functions directly on `GridArray` and `GridVariable`. 117 | - Added native implementation of arithmetic manipulation directly on `GridVariableVector`. 118 | - Added several helper functions such as `consistent_grid` to replace `consistent_grid_arrays`. 119 | - Removed the dependence on `from torch.utils._pytree import register_pytree_node` 120 | - Minor notes: 121 | - Added native PyTorch dense implementation of `scipy.linalg.circulant`: for a 1d array `column` 122 | ```python 123 | # scipy version 124 | mat = scipy.linalg.circulant(column) 125 | 126 | # torch version 127 | idx = (n - torch.arange(n)[None].T + torch.arange(n)[None]) % n 128 | mat = torch.gather(column[None, ...].expand(n, -1), 1, idx) 129 | ``` 130 | 131 | 132 | ### 0.0.8 133 | - Starting from PyTorch 2.6.0, if data are saved using serialization (a for loop with `pickle` or `dill`), then `torch.load` will raise an error. If you want to load the data, you can either add the following to the imports or re-generate the data using this version. 134 | ```python 135 | torch.serialization.add_safe_globals([defaultdict]) 136 | torch.serialization.add_safe_globals([list]) 137 | ``` 138 | 139 | ### 0.0.6 140 | - Minor changes in function names; added the `sfno` directory and moved `get_trajectory_imex` and `get_trajectory_rk4` to the data generation folder. 141 | 142 | ### 0.0.5 143 | - Added a batch dimension in the solver. By default, the solver should work for input shapes `(batch, kx, ky)` or `(kx, ky)`. `get_trajectory()` output is either `(n_t, kx, ky)` or `(batch, n_t, kx, ky)`. 144 | 145 | 146 | ### 0.0.4 147 | - The forcing functions are now implemented as `nn.Module` and utilize a wrapper decorator for the potential function. 148 | - Added some common time stepping schemes; additional ones that Jax-CFD did not have include the commonly used Crank-Nicolson IMEX (a sketch of one such step is given after this list). 149 | - Combined the implementation of the step size satisfying the CFL condition.
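For reference, a minimal sketch of one Crank-Nicolson IMEX step in spectral space (not the library's API: `vhat` denotes the solution in Fourier space, `lin` the diagonal spectral operator of the stiff linear term, and `nonlinear` a hypothetical callable evaluating the explicit term):
```python
def crank_nicolson_imex_step(vhat, nonlinear, lin, dt):
    # (v^{n+1} - v^n) / dt = 0.5 * lin * (v^{n+1} + v^n) + N(v^n),
    # solved pointwise since `lin` is diagonal in Fourier space
    return (vhat * (1 + 0.5 * dt * lin) + dt * nonlinear(vhat)) / (1 - 0.5 * dt * lin)
```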
150 | 151 | 152 | ### 0.0.1 153 | - `grids.GridArray` is implemented as a subclass of `torch.Tensor`, unlike the original Jax implementation, which uses inheritance from `np.lib.mixins.NDArrayOperatorsMixin`. `__array_ufunc__()` is replaced by `__torch_function__()`. 154 | - The padding of `torch.nn.functional.pad()` is different from `jax.numpy.pad()`: PyTorch's pad starts from the last dimension, while Jax's pad starts from the first dimension. For example, `F.pad(x, (0, 0, 1, 0, 1, 1))` is equivalent to `jax.numpy.pad(x, ((1, 1), (1, 0), (0, 0)))` for an array of size `(*, t, h, w)`. 155 | - A handy outer sum, which is useful in getting the n-dimensional Laplacian in the frequency domain, is implemented as follows to replace `reduce(np.add.outer, eigenvalues)` 156 | ```python 157 | def outer_sum(x: Union[List[torch.Tensor], Tuple[torch.Tensor]]) -> torch.Tensor: 158 | """ 159 | Returns the outer sum of a list of one dimensional arrays 160 | Example: 161 | x = [a, b, c] 162 | out = a[..., None, None] + b[..., None] + c 163 | """ 164 | 165 | def _sum(a, b): 166 | return a[..., None] + b 167 | 168 | return reduce(_sum, x) 169 | ``` -------------------------------------------------------------------------------- /torch_cfd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaomath/torch-cfd/6abe65ad8c49b31bc09c661974c7c4b120ab6729/torch_cfd/__init__.py -------------------------------------------------------------------------------- /torch_cfd/forcings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Modifications copyright (C) 2024 S.Cao 16 | # ported Google's Jax-CFD functional template to PyTorch's tensor ops 17 | 18 | from typing import Optional, Tuple, Union 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | from torch_cfd import grids 24 | 25 | 26 | Grid = grids.Grid 27 | GridVariable = grids.GridVariable 28 | 29 | 30 | def forcing_eval(eval_func): 31 | """ 32 | A decorator for forcing evaluators. 33 | This decorator simplifies the conversion of a standalone forcing evaluation function 34 | to a method that can be called on a class instance. It standardizes the interface 35 | for forcing functions by ensuring they accept grid and field parameters. 36 | Parameters 37 | ---------- 38 | eval_func : callable 39 | The forcing evaluation function to be decorated. Should accept grid and field parameters 40 | and return a torch.Tensor representing the forcing term. 41 | Returns 42 | ------- 43 | callable 44 | A wrapper function that can be used as a class method for evaluating forcing terms. 45 | The wrapper maintains the same signature as the decorated function but ignores the 46 | class instance (self) parameter.
47 | Examples 48 | -------- 49 | @forcing_eval 50 | def constant_forcing(field, grid): 51 | return torch.ones_like(field) 52 | """ 53 | 54 | def wrapper( 55 | cls, 56 | grid: Grid, 57 | field: Optional[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]], 58 | ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: 59 | return eval_func(grid, field) 60 | 61 | return wrapper 62 | 63 | 64 | class ForcingFn(nn.Module): 65 | """ 66 | A meta class for forcing functions 67 | 68 | Args: 69 | vorticity: whether the forcing function is a vorticity forcing 70 | 71 | Notes: 72 | - the grid variable is the first argument in the __call__ so that the second variable can be velocity or vorticity 73 | - the forcing term does not have boundary conditions; when evaluated, it is simply added to the velocity or vorticity (on the same grid) 74 | 75 | TODO: 76 | - [ ] MAC grid: the components of the velocity do not live on the same grid. 77 | """ 78 | 79 | def __init__( 80 | self, 81 | grid: Grid, 82 | scale: float = 1, 83 | wave_number: int = 1, 84 | diam: float = 1.0, 85 | swap_xy: bool = False, 86 | vorticity: bool = False, 87 | offsets: Optional[Tuple[Tuple[float, ...], ...]] = None, 88 | device: Optional[torch.device] = None, 89 | **kwargs, 90 | ): 91 | super().__init__() 92 | self.grid = grid 93 | self.scale = scale 94 | self.wave_number = wave_number 95 | self.diam = diam 96 | self.swap_xy = swap_xy 97 | self.vorticity = vorticity 98 | self.offsets = grid.cell_faces if offsets is None else offsets 99 | self.device = grid.device if device is None else device 100 | 101 | @forcing_eval 102 | def velocity_eval( 103 | grid: Grid, velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] 104 | ) -> Tuple[torch.Tensor, torch.Tensor]: 105 | raise NotImplementedError 106 | 107 | @forcing_eval 108 | def vorticity_eval(grid: Grid, vorticity: Optional[torch.Tensor]) -> torch.Tensor: 109 | raise NotImplementedError 110 | 111 | def forward( 112 | self, 113 | grid: Optional[Union[Grid, Tuple[Grid, Grid]]] = None, 114 | velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 115 | vorticity: Optional[torch.Tensor] = None, 116 | ) -> Tuple[torch.Tensor, torch.Tensor]: 117 | if not self.vorticity: 118 | return self.velocity_eval(grid, velocity) 119 | else: 120 | return self.vorticity_eval(grid, vorticity) 121 | 122 | 123 | class KolmogorovForcing(ForcingFn): 124 | """ 125 | The Kolmogorov forcing function. 126 | Sets up the flow that is used in Kochkov et al. [1], 127 | which is based on Boffetta et al. [2]. 128 | 129 | Note in the port: this forcing belongs to a larger class 130 | of isotropic turbulence forcings. See [3]. 131 | 132 | References: 133 | [1] Machine learning-accelerated computational fluid dynamics. Dmitrii 134 | Kochkov, Jamie A. Smith, Ayya Alieva, Qing Wang, Michael P. Brenner, Stephan 135 | Hoyer Proceedings of the National Academy of Sciences May 2021, 118 (21) 136 | e2101784118; DOI: 10.1073/pnas.2101784118. 137 | https://doi.org/10.1073/pnas.2101784118 138 | 139 | [2] Boffetta, Guido, and Robert E. Ecke. "Two-dimensional turbulence." 140 | Annual review of fluid mechanics 44 (2012): 427-451. 141 | https://doi.org/10.1146/annurev-fluid-120710-101240 142 | 143 | [3] McWilliams, J. C. (1984). "The emergence of isolated coherent vortices 144 | in turbulent flow". Journal of Fluid Mechanics, 146, 21-43.
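    Example (a usage sketch; `grid` is assumed to be a `grids.Grid` instance):

        forcing = KolmogorovForcing(grid=grid, scale=1.0, wave_number=4, vorticity=True)
        fw = forcing()  # evaluates the forcing on the vorticity grid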
145 | """ 146 | 147 | def __init__( 148 | self, 149 | *args, 150 | diam=2 * torch.pi, 151 | offsets=((0, 0), (0, 0)), 152 | vorticity=False, 153 | wave_number=1, 154 | **kwargs, 155 | ): 156 | super().__init__( 157 | *args, 158 | diam=diam, 159 | offsets=offsets, 160 | vorticity=vorticity, 161 | wave_number=wave_number, 162 | **kwargs, 163 | ) 164 | 165 | def velocity_eval( 166 | self, 167 | grid: Optional[Grid], 168 | velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 169 | ) -> Tuple[torch.Tensor, torch.Tensor]: 170 | offsets = self.offsets 171 | grid = self.grid if grid is None else grid 172 | domain_factor = 2 * torch.pi / self.diam 173 | 174 | if self.swap_xy: 175 | x = grid.mesh(offsets[1])[0] 176 | v = GridVariable( 177 | self.scale * torch.sin(self.wave_number * domain_factor * x), 178 | offsets[1], 179 | grid, 180 | ) 181 | u = GridVariable(torch.zeros_like(v.data), (1, 1 / 2), grid) 182 | else: 183 | y = grid.mesh(offsets[0])[1] 184 | u = GridVariable( 185 | self.scale * torch.sin(self.wave_number * domain_factor * y), 186 | offsets[0], 187 | grid, 188 | ) 189 | v = GridVariable(torch.zeros_like(u.data), (1 / 2, 1), grid) 190 | return tuple((u, v)) 191 | 192 | def vorticity_eval( 193 | self, 194 | grid: Optional[Grid], 195 | vorticity: Optional[torch.Tensor] = None, 196 | ) -> torch.Tensor: 197 | offsets = self.offsets 198 | grid = self.grid if grid is None else grid 199 | domain_factor = 2 * torch.pi / self.diam 200 | 201 | if self.swap_xy: 202 | x = grid.mesh(offsets[1])[0] 203 | w = GridVariable( 204 | -self.scale 205 | * self.wave_number 206 | * domain_factor 207 | * torch.cos(self.wave_number * domain_factor * x), 208 | offsets[1], 209 | grid, 210 | ) 211 | else: 212 | y = grid.mesh(offsets[0])[1] 213 | w = GridVariable( 214 | -self.scale 215 | * self.wave_number 216 | * domain_factor 217 | * torch.cos(self.wave_number * domain_factor * y), 218 | offsets[0], 219 | grid, 220 | ) 221 | return w 222 | 223 | 224 | def scalar_potential(potential_func): 225 | def wrapper( 226 | cls, x: torch.Tensor, y: torch.Tensor, s: float, k: float 227 | ) -> torch.Tensor: 228 | return potential_func(x, y, s, k) 229 | 230 | return wrapper 231 | 232 | 233 | class SimpleSolenoidalForcing(ForcingFn): 234 | """ 235 | A simple solenoidal (rotating, divergence free) forcing function template. 
236 | The template forcing is F = (psi, -psi) for a scalar potential psi, or (-psi, psi) when swap_xy is True. 237 | 238 | Args: 239 | grid: grid on which to simulate the flow 240 | scale: amplitude a of the forcing 241 | k: wavenumber of the forcing 242 | """ 243 | 244 | def __init__( 245 | self, 246 | scale=1, 247 | diam=1.0, 248 | k=1.0, 249 | offsets=((0, 0), (0, 0)), 250 | vorticity=True, 251 | *args, 252 | **kwargs, 253 | ): 254 | super().__init__( 255 | *args, 256 | scale=scale, 257 | diam=diam, 258 | wave_number=k, 259 | offsets=offsets, 260 | vorticity=vorticity, 261 | **kwargs, 262 | ) 263 | 264 | @scalar_potential 265 | def potential(*args, **kwargs) -> torch.Tensor: 266 | raise NotImplementedError 267 | 268 | @scalar_potential 269 | def vort_potential(*args, **kwargs) -> torch.Tensor: 270 | raise NotImplementedError 271 | 272 | def velocity_eval( 273 | self, 274 | grid: Optional[Grid], 275 | velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 276 | ) -> Tuple[torch.Tensor, torch.Tensor]: 277 | offsets = self.offsets 278 | grid = self.grid if grid is None else grid 279 | domain_factor = 2 * torch.pi / self.diam 280 | k = self.wave_number * domain_factor 281 | scale = 0.5 * self.scale / (2 * torch.pi) / self.wave_number 282 | 283 | if self.swap_xy: 284 | x = grid.mesh(offsets[1])[0] 285 | y = grid.mesh(offsets[0])[1] 286 | rot = self.potential(x, y, scale, k) 287 | v = GridVariable(rot, offsets[1], grid) 288 | u = GridVariable(-rot, (1, 1 / 2), grid) 289 | else: 290 | x = grid.mesh(offsets[0])[0] 291 | y = grid.mesh(offsets[1])[1] 292 | rot = self.potential(x, y, scale, k) 293 | u = GridVariable(rot, offsets[0], grid) 294 | v = GridVariable(-rot, (1 / 2, 1), grid) 295 | return tuple((u, v)) 296 | 297 | def vorticity_eval( 298 | self, 299 | grid: Optional[Grid], 300 | vorticity: Optional[torch.Tensor] = None, 301 | ) -> torch.Tensor: 302 | offsets = self.offsets 303 | grid = self.grid if grid is None else grid 304 | domain_factor = 2 * torch.pi / self.diam 305 | k = self.wave_number * domain_factor 306 | scale = self.scale 307 | 308 | if self.swap_xy: 309 | x = grid.mesh(offsets[1])[0] 310 | y = grid.mesh(offsets[0])[1] 311 | return self.vort_potential(x, y, scale, k) 312 | else: 313 | x = grid.mesh(offsets[0])[0] 314 | y = grid.mesh(offsets[1])[1] 315 | return self.vort_potential(x, y, scale, k) 316 | 317 | 318 | class SinCosForcing(SimpleSolenoidalForcing): 319 | """ 320 | The solenoidal (divergence free) forcing function used in [4]. 321 | 322 | Note: in the vorticity-streamfunction formulation, the forcing 323 | is actually the curl of the velocity field, which 324 | is a*(sin(2*pi*k*(x+y)) + cos(2*pi*k*(x+y))), 325 | with a=0.1, k=1 in [4] 326 | 327 | References: 328 | [4] Li, Zongyi, et al. "Fourier Neural Operator for 329 | Parametric Partial Differential Equations." 330 | ICLR. 2020.
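    Example (a usage sketch; `grid` is assumed to be a `grids.Grid` instance):

        forcing = SinCosForcing(grid=grid)  # a=0.1, k=1 as in [4]
        fw = forcing()  # evaluates the vorticity forcing on the grid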
331 | 332 | Args: 333 | grid: grid on which to simulate the flow 334 | scale: a in the equation above, amplitude of the forcing 335 | k: k in the equation above, wavenumber of the forcing 336 | """ 337 | 338 | def __init__( 339 | self, 340 | scale=0.1, 341 | diam=1.0, 342 | k=1.0, 343 | offsets=((0, 0), (0, 0)), 344 | *args, 345 | **kwargs, 346 | ): 347 | super().__init__( 348 | *args, 349 | scale=scale, 350 | diam=diam, 351 | k=k, 352 | offsets=offsets, 353 | **kwargs, 354 | ) 355 | 356 | @scalar_potential 357 | def potential(x: torch.Tensor, y: torch.Tensor, s: float, k: float) -> torch.Tensor: 358 | return s * (torch.sin(k * (x + y)) - torch.cos(k * (x + y))) 359 | 360 | @scalar_potential 361 | def vort_potential( 362 | x: torch.Tensor, y: torch.Tensor, s: float, k: float 363 | ) -> torch.Tensor: 364 | return s * (torch.cos(k * (x + y)) + torch.sin(k * (x + y))) 365 | -------------------------------------------------------------------------------- /torch_cfd/fvm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Modifications copyright (C) 2025 S.Cao 16 | # ported Google's Jax-CFD functional template to PyTorch's tensor ops 17 | from __future__ import annotations 18 | 19 | from typing import Callable, Dict, List, Optional, Sequence, Tuple 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | import torch_cfd.finite_differences as fdm 25 | from torch_cfd import advection, boundaries, forcings, grids, pressure 26 | 27 | 28 | Grid = grids.Grid 29 | GridVariable = grids.GridVariable 30 | GridVariableVector = grids.GridVariableVector 31 | ForcingFn = forcings.ForcingFn 32 | 33 | 34 | def wrap_field_same_bcs(v, field_ref): 35 | return GridVariableVector( 36 | tuple( 37 | GridVariable(a.data, a.offset, a.grid, w.bc) for a, w in zip(v, field_ref) 38 | ) 39 | ) 40 | 41 | 42 | class ProjectionExplicitODE(nn.Module): 43 | r"""Navier-Stokes equation in 2D with explicit stepping and a pressure projection (discrete Helmholtz decomposition by modding the gradient of a Laplacian inverse of the extra divergence). 44 | 45 | \partial u/ \partial t = explicit_terms(u) 46 | u <- pressure_projection(u) 47 | """ 48 | 49 | def explicit_terms(self, *args, **kwargs) -> GridVariableVector: 50 | """ 51 | Explicit forcing term as du/dt. 52 | """ 53 | raise NotImplementedError 54 | 55 | def pressure_projection(self, *args, **kwargs) -> Tuple[GridVariableVector, GridVariable]: 56 | """Pressure projection step.""" 57 | raise NotImplementedError 58 | 59 | def forward(self, u: GridVariableVector, dt: float) -> GridVariableVector: 60 | """Perform one time step. 61 | 62 | Args: 63 | u: Initial state (velocity field) 64 | dt: Time step size 65 | 66 | Returns: 67 | Updated velocity field after one time step 68 | """ 69 | raise NotImplementedError 70 | 71 | 72 | class RKStepper(nn.Module): 73 | """Base class for Explicit Runge-Kutta stepper. 
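    Given a Butcher tableau (a, b) with s stages, one step for du/dt = F(u) reads

        k_1 = F(u_0),
        k_i = F(P(u_0 + dt * sum_{j<i} a_{i-1,j} k_j)),  i = 2, ..., s,
        u_1 = P(u_0 + dt * sum_{j=1}^{s} b_j k_j),

    where P denotes the pressure projection applied after each stage
    (see the `forward` method below).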
74 | 75 | Input: 76 | tableau: Butcher tableau (a, b) for the Runge-Kutta method as a dictionary 77 | method: String name of built-in RK method if tableau not provided 78 | 79 | Examples: 80 | stepper = RKStepper.from_method("classic_rk4", ...) 81 | """ 82 | 83 | _METHOD_MAP = { 84 | "forward_euler": {"a": [], "b": [1.0]}, 85 | "midpoint": {"a": [[1 / 2]], "b": [0, 1.0]}, 86 | "heun_rk2": {"a": [[1.0]], "b": [1 / 2, 1 / 2]}, 87 | "classic_rk4": { 88 | "a": [[1 / 2], [0.0, 1 / 2], [0.0, 0.0, 1.0]], 89 | "b": [1 / 6, 1 / 3, 1 / 3, 1 / 6], 90 | }, 91 | } 92 | 93 | def __init__( 94 | self, 95 | tableau: Optional[Dict[str, List]] = None, 96 | method: Optional[str] = "forward_euler", 97 | dtype: Optional[torch.dtype] = torch.float32, 98 | requires_grad=False, 99 | **kwargs, 100 | ): 101 | super().__init__() 102 | 103 | self._method = None 104 | self.dtype = dtype 105 | self.requires_grad = requires_grad 106 | 107 | # Set the tableau, either directly or from the method name 108 | if tableau is not None: 109 | self.tableau = tableau 110 | else: 111 | self.method = method 112 | self._set_params() 113 | 114 | @property 115 | def method(self): 116 | """Get the current Runge-Kutta method name.""" 117 | return self._method 118 | 119 | @method.setter 120 | def method(self, name: str): 121 | """Set the tableau based on the method name.""" 122 | if name not in self._METHOD_MAP: 123 | raise ValueError(f"Unknown RK method: {name}") 124 | self._method = name 125 | self._tableau = self._METHOD_MAP[name] 126 | 127 | @property 128 | def tableau(self): 129 | """Get the current tableau.""" 130 | return self._tableau 131 | 132 | @tableau.setter 133 | def tableau(self, tab: Dict[str, List]): 134 | """Set the tableau directly.""" 135 | self._tableau = tab 136 | self._method = None # Clear method name when setting tableau directly 137 | 138 | def _set_params(self): 139 | """Set the parameters of the Butcher tableau.""" 140 | try: 141 | a, b = self._tableau["a"], self._tableau["b"] 142 | if len(a) + 1 != len(b): 143 | raise ValueError("Inconsistent Butcher tableau: len(a) + 1 != len(b)") 144 | self.params = nn.ParameterDict() 145 | self.params["a"] = nn.ParameterList() 146 | for a_ in a: 147 | self.params["a"].append( 148 | nn.Parameter( 149 | torch.tensor( 150 | a_, dtype=self.dtype, requires_grad=self.requires_grad 151 | ) 152 | ) 153 | ) 154 | self.params["b"] = nn.Parameter( 155 | torch.tensor(b, dtype=self.dtype, requires_grad=self.requires_grad) 156 | ) 157 | except KeyError as e: 158 | print(f"{e}: Either `tableau` or `method` must be given.") 159 | 160 | @classmethod 161 | def from_tableau( 162 | cls, 163 | tableau: Dict[str, List], 164 | requires_grad: bool = False, 165 | dtype: Optional[torch.dtype] = torch.float32, 166 | ): 167 | """Factory method to create an RKStepper from a Butcher tableau.""" 168 | return cls(tableau=tableau, requires_grad=requires_grad, dtype=dtype) 169 | 170 | @classmethod 171 | def from_method( 172 | cls, method: str = "forward_euler", requires_grad: bool = False, **kwargs 173 | ): 174 | """Factory method to create an RKStepper by name.""" 175 | return cls(method=method, requires_grad=requires_grad, **kwargs) 176 | 177 | def forward( 178 | self, u0: GridVariableVector, dt: float, equation: ProjectionExplicitODE 179 | ) -> Tuple[GridVariableVector, GridVariable]: 180 | """Perform one time step.
181 | 
182 |         Args:
183 |             u0: Initial state (velocity field)
184 |             dt: Time step size
185 |             equation: The ODE to solve
186 | 
187 |         Returns:
188 |             Updated velocity field and pressure after one time step
189 |         """
190 |         alpha = self.params["a"]
191 |         beta = self.params["b"]
192 |         num_steps = len(beta)
193 | 
194 |         u = [None] * num_steps
195 |         k = [None] * num_steps
196 | 
197 |         # First stage
198 |         u[0] = u0
199 |         k[0] = equation.explicit_terms(u0, dt)
200 | 
201 |         # Intermediate stages
202 |         for i in range(1, num_steps):
203 |             u_star = GridVariableVector(tuple(v.clone() for v in u0))
204 | 
205 |             for j in range(i):
206 |                 if alpha[i - 1][j] != 0:
207 |                     u_star = u_star + dt * alpha[i - 1][j] * k[j]
208 | 
209 |             u[i], _ = equation.pressure_projection(u_star)
210 |             k[i] = equation.explicit_terms(u[i], dt)
211 | 
212 |         u_star = GridVariableVector(tuple(v.clone() for v in u0))
213 |         for j in range(num_steps):
214 |             if beta[j] != 0:
215 |                 u_star = u_star + dt * beta[j] * k[j]
216 | 
217 |         u_final, p = equation.pressure_projection(u_star)
218 | 
219 |         return u_final, p
220 | 
221 | 
222 | class NavierStokes2DFVMProjection(ProjectionExplicitODE):
223 |     r"""Incompressible Navier-Stokes equations in velocity-pressure formulation.
224 | 
225 |     Runge-Kutta time stepping for the NSE discretized with a finite volume method (FVM) on a staggered MAC grid, combined with Chorin's pressure projection. The x- and y-dofs of the velocity
226 |     live on staggered faces, which is reflected in the offset attr.
227 | 
228 |     Original implementation in the Jax-CFD repository:
229 | 
230 |     - semi_implicit_navier_stokes in jax_cfd.base.fvm returns a stepper function `time_stepper(ode, dt)` where `ode` specifies the explicit terms and the pressure projection.
231 |     - The time stepper is a wrapper created by jax.named_call(
232 | navier_stokes_rk()) that implements the various Runge-Kutta methods according to the Butcher tableau.
233 |     - navier_stokes_rk() implements Runge-Kutta time stepping for the NSE; the user supplies the equation, i.e., the explicit terms and the pressure projection.
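
    Example:
        A minimal sketch of intended usage (``v0`` is assumed to be a
        GridVariableVector of staggered velocity components):

        >>> grid = Grid((128, 128), domain=((0, 2 * torch.pi), (0, 2 * torch.pi)))
        >>> stepper = RKStepper.from_method("classic_rk4")
        >>> ns2d = NavierStokes2DFVMProjection(viscosity=1e-3, grid=grid, step_fn=stepper)
        >>> # v1, p = ns2d(v0, dt)  # one projected RK4 step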
234 | 
235 |     (Original reference listed in Jax-CFD)
236 |     This class implements the reference method (equations 16-21) from:
237 |     "Fast-Projection Methods for the Incompressible Navier-Stokes Equations"
238 |     Fluids 2020, 5, 222; doi:10.3390/fluids5040222
239 |     """
240 | 
241 |     def __init__(
242 |         self,
243 |         viscosity: float,
244 |         grid: Grid,
245 |         bcs: Optional[Sequence[boundaries.BoundaryConditions]] = None,
246 |         drag: float = 0.0,
247 |         density: float = 1.0,
248 |         convection: Optional[Callable] = None,
249 |         pressure_proj: Optional[Callable] = None,
250 |         forcing: Optional[ForcingFn] = None,
251 |         step_fn: Optional[RKStepper] = None,
252 |         **kwargs,
253 |     ):
254 |         """
255 |         Args:
256 |             viscosity: kinematic viscosity of the fluid.
257 |             grid: the staggered (MAC) grid on which the velocity field is defined.
258 |             step_fn: RKStepper that advances the equation in time.
259 |             The remaining args (bcs, drag, density, convection, pressure_proj, forcing) are optional.
260 |         """
261 |         super().__init__()
262 |         self.viscosity = viscosity
263 |         self.density = density
264 |         self.grid = grid
265 |         self.bcs = bcs
266 |         self.drag = drag
267 |         self.forcing = forcing
268 |         self.convection = convection
269 |         self.step_fn = step_fn
270 |         self.pressure_proj = pressure_proj
271 |         self._set_pressure_bc()
272 |         self._set_convect()
273 |         self._set_pressure_projection()
274 | 
275 |     def _set_convect(self):
276 |         if self.convection is not None:
277 |             self._convect = self.convection
278 |         else:
279 |             self._convect = advection.ConvectionVector(grid=self.grid, bcs=self.bcs)
280 | 
281 |     def _set_pressure_projection(self):
282 |         if self.pressure_proj is not None:
283 |             self._projection = self.pressure_proj
284 |             return
285 |         self._projection = pressure.PressureProjection(
286 |             grid=self.grid,
287 |             bc=self.pressure_bc,
288 |         )
289 | 
290 |     def _set_pressure_bc(self):
291 |         if self.bcs is None:
292 |             self.bcs = (
293 |                 boundaries.periodic_boundary_conditions(ndim=self.grid.ndim),
294 |                 boundaries.periodic_boundary_conditions(ndim=self.grid.ndim),
295 |             )
296 |         self.pressure_bc = boundaries.get_pressure_bc_from_velocity_bc(bcs=self.bcs)
297 | 
298 |     def _diffusion(self, v: GridVariableVector) -> GridVariableVector:
299 |         """Returns the diffusion term for the velocity field."""
300 |         alpha = self.viscosity / self.density
301 |         lapv = GridVariableVector(tuple(alpha * fdm.laplacian(u) for u in v))
302 |         return lapv
303 | 
304 |     def _explicit_terms(self, v, dt, **kwargs):
305 |         dv_dt = self._convect(v, v, dt)
306 |         grid = self.grid
307 |         density = self.density
308 |         forcing = self.forcing
309 |         dv_dt += self._diffusion(v)
310 |         if forcing is not None:
311 |             dv_dt += GridVariableVector(forcing(grid, v)) / density
312 |         dv_dt = wrap_field_same_bcs(dv_dt, v)
313 |         if self.drag > 0.0:
314 |             dv_dt += -self.drag * v
315 |         return dv_dt
316 | 
317 |     def explicit_terms(self, *args, **kwargs):
318 |         return self._explicit_terms(*args, **kwargs)
319 | 
320 |     def pressure_projection(self, *args, **kwargs):
321 |         return self._projection(*args, **kwargs)
322 | 
323 |     def forward(self, u: GridVariableVector, dt: float) -> Tuple[GridVariableVector, GridVariable]:
324 |         """Perform one time step.
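        Delegates to ``self.step_fn`` (an RKStepper), which evaluates this
        equation's explicit_terms and pressure_projection at each RK stage.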
325 | 
326 |         Args:
327 |             u: Initial state (velocity field)
328 |             dt: Time step size
329 | 
330 |         Returns:
331 |             Updated velocity field and pressure after one time step
332 |         """
333 | 
334 |         return self.step_fn(u, dt, self)
335 | 
--------------------------------------------------------------------------------
/torch_cfd/initial_conditions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | # Modifications copyright (C) 2024 S.Cao
16 | # ported Google's Jax-CFD functional template to PyTorch's tensor ops
17 | 
18 | """Prepare initial conditions for simulations."""
19 | import math
20 | from typing import Callable, Optional, Sequence, Union
21 | 
22 | import torch
23 | import torch.fft as fft
24 | 
25 | from torch_cfd import grids, pressure, boundaries
26 | 
27 | Grid = grids.Grid
28 | GridVariable = grids.GridVariable
29 | GridVariableVector = grids.GridVariableVector
30 | BoundaryConditions = grids.BoundaryConditions
31 | 
32 | 
33 | def wrap_velocities(
34 |     v: Sequence[torch.Tensor],
35 |     grid: Grid,
36 |     bcs: Sequence[BoundaryConditions],
37 |     device: Optional[torch.device] = None,
38 | ) -> GridVariableVector:
39 |     """Wrap velocity arrays for input into simulations."""
40 |     device = grid.device if device is None else device
41 |     return GridVariableVector(tuple(
42 |         GridVariable(u, offset, grid, bc).to(device)
43 |         for u, offset, bc in zip(v, grid.cell_faces, bcs)
44 |     ))
45 | 
46 | 
47 | def wrap_vorticity(
48 |     w: torch.Tensor,
49 |     grid: Grid,
50 |     bc: BoundaryConditions,
51 |     device: Optional[torch.device] = None,
52 | ) -> GridVariable:
53 |     """Wrap vorticity arrays for input into simulations."""
54 |     device = grid.device if device is None else device
55 |     return GridVariable(w, grid.cell_faces, grid, bc).to(device)
56 | 
57 | 
58 | def _log_normal_density(k, mode: float, variance=0.25):
59 |     """
60 |     Unscaled PDF of a log-normal distribution with the given `mode` and log-variance `variance`.
61 |     """
62 |     mean = math.log(mode) + variance
63 |     logk = torch.log(k)
64 |     return torch.exp(-((mean - logk) ** 2) / 2 / variance - logk)
65 | 
66 | 
67 | def McWilliams_density(k, mode: float, tau: float = 1.0):
68 |     r"""Implements the McWilliams spectral density function.
69 |     |\psi|^2 \sim k^{-1} (\tau^2 + (k/k_0)^4)^{-1}
70 |     k_0 is a prescribed wavenumber at which the energy peaks.
71 |     \tau flattens the spectral density at low wavenumbers.
72 | 
73 |     Refs:
74 |       McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow.
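      Journal of Fluid Mechanics, 146, 21-43.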
75 |     """
76 |     return (k * (tau**2 + (k / mode) ** 4)) ** (-1)
77 | 
78 | 
79 | def _angular_frequency_magnitude(grid: grids.Grid) -> torch.Tensor:
80 |     frequencies = [
81 |         2 * torch.pi * fft.fftfreq(size, step)
82 |         for size, step in zip(grid.shape, grid.step)
83 |     ]
84 |     freq_vector = torch.stack(torch.meshgrid(*frequencies, indexing="ij"), dim=0)
85 |     return torch.linalg.norm(freq_vector, dim=0)
86 | 
87 | 
88 | def spectral_filter(
89 |     spectral_density: Callable[[torch.Tensor], torch.Tensor],
90 |     v: Union[torch.Tensor, GridVariable],
91 |     grid: Grid,
92 | ) -> torch.Tensor:
93 |     """Filter a torch.Tensor with white noise to match a prescribed spectral density."""
94 |     k = _angular_frequency_magnitude(grid)
95 |     filters = torch.where(k > 0, spectral_density(k), 0.0).to(v.device)
96 |     # The output signal can safely be assumed to be real if our input signal was
97 |     # real, because our spectral density only depends on norm(k).
98 |     return fft.ifftn(fft.fftn(v) * filters).real
99 | 
100 | 
101 | def streamfunc_normalize(k, psi):
102 |     nx, ny = psi.shape
103 |     psih = fft.fft2(psi)
104 |     uh_mag = k * psih
105 |     kinetic_energy = (2 * uh_mag.abs() ** 2 / (nx * ny) ** 2).sum()
106 |     return psi / kinetic_energy.sqrt()
107 | 
108 | 
109 | def project_and_normalize(
110 |     v: GridVariableVector, maximum_velocity: float = 1
111 | ) -> GridVariableVector:
112 |     grid = grids.consistent_grid_arrays(*v)
113 |     pressure_bc = boundaries.get_pressure_bc_from_velocity(v)
114 |     projection = pressure.PressureProjection(grid, pressure_bc).to(v.device)
115 |     v, _ = projection(v)
116 |     vmax = torch.linalg.norm(torch.stack([u.data for u in v]), dim=0).max()
117 |     v = GridVariableVector(tuple(GridVariable(maximum_velocity * u.data / vmax, u.offset, u.grid, u.bc) for u in v))
118 |     return v
119 | 
120 | 
121 | def filtered_velocity_field(
122 |     grid: Grid,
123 |     maximum_velocity: float = 1,
124 |     peak_wavenumber: float = 3,
125 |     iterations: int = 3,
126 |     random_state: int = 0,
127 |     batch_size: int = 1,
128 |     device: torch.device = torch.device("cpu"),
129 | ) -> GridVariableVector:
130 |     """Create divergence-free velocity fields with appropriate spectral filtering.
131 | 
132 |     Args:
133 |         grid: the grid on which the velocity field is defined.
134 |         maximum_velocity: the maximum speed in the velocity field.
135 |         peak_wavenumber: the velocity field will be filtered so that the largest
136 |             magnitudes are associated with this wavenumber.
137 |         iterations: the number of repeated pressure projection and normalization
138 |             iterations to apply.
139 |         random_state: seed for the random initial velocity field.
140 |     Returns:
141 |         A divergence-free velocity field with the given maximum velocity and
142 |         periodic boundary conditions on the velocity components.
143 |     """
144 | 
145 |     # Log normal distribution peaked at `peak_wavenumber`. Note that we have to
146 |     # divide by `k ** (ndim - 1)` to account for the volume of the
147 |     # `ndim - 1`-sphere of values with wavenumber `k`.
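    # Concretely, with sigma^2 = 0.25 as in _log_normal_density above:
    #   density(k) ~ exp(-(log k - mu)^2 / (2 sigma^2)) / k / k ** (ndim - 1),
    # where mu = log(peak_wavenumber) + sigma^2.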
148 |     spectral_density = lambda k: _log_normal_density(k, peak_wavenumber) / k ** (grid.ndim - 1)
149 |     result = []
150 | 
151 |     for b in range(batch_size):
152 |         random_states = [random_state + i + b * batch_size for i in range(grid.ndim)]
153 |         rng = torch.Generator(device=device)
154 |         velocity_components = []
155 |         boundary_conditions = []
156 |         for seed in random_states:
157 |             rng.manual_seed(seed)
158 |             noise = torch.randn(grid.shape, generator=rng, device=device)
159 |             velocity_components.append(spectral_filter(spectral_density, noise, grid))
160 |             boundary_conditions.append(boundaries.periodic_boundary_conditions(grid.ndim))
161 |         velocity = wrap_velocities(velocity_components, grid, boundary_conditions, device=device)
162 |         # Due to numerical precision issues, we repeatedly project and normalize
163 |         # the velocity field so that it is divergence-free and attains the
164 |         # specified maximum velocity.
165 |         for _ in range(iterations):
166 |             velocity = project_and_normalize(velocity, maximum_velocity)
167 |         # velocity is a ((n, n), (n, n)) GridVariableVector
168 |         result.append(velocity)
169 | 
170 |     return grids.stack_gridvariable_vectors(*result)
171 | 
172 | 
173 | def vorticity_field(
174 |     grid: Grid,
175 |     peak_wavenumber: float = 3,
176 |     random_state: int = 0,
177 |     batch_size: int = 1,
178 | ) -> GridVariable:
179 |     """Create vorticity field with spectral filtering
180 |     using the McWilliams power spectral density function.
181 | 
182 |     Args:
183 |         grid: the grid on which the vorticity field is defined.
184 |         peak_wavenumber: the vorticity field will be filtered so that the largest
185 |             magnitudes are associated with this wavenumber.
186 |         random_state: seed for the random initial vorticity field.
187 | 
188 |     Returns:
189 |         A vorticity field with periodic boundary conditions.
190 |     """
191 |     spectral_density = lambda k: McWilliams_density(k, peak_wavenumber)
192 | 
193 |     rng = torch.Generator()
194 |     result = []
195 | 
196 |     for b in range(batch_size):
197 |         seed = random_state + b  # do not mutate random_state across the batch
198 |         rng.manual_seed(seed)
199 |         noise = torch.randn(grid.shape, generator=rng)
200 |         k = _angular_frequency_magnitude(grid)
201 |         psi = spectral_filter(spectral_density, noise, grid)
202 |         psi = streamfunc_normalize(k, psi)
203 |         vorticity = fft.ifftn(fft.fftn(psi) * k**2).real
204 |         boundary_condition = boundaries.periodic_boundary_conditions(grid.ndim)
205 |         vorticity = wrap_vorticity(vorticity, grid, boundary_condition)
206 |         result.append(vorticity)
207 | 
208 |     return grids.stack_gridvariables(*result)
--------------------------------------------------------------------------------
/torch_cfd/spectral.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | # Modifications copyright (C) 2025 S.Cao
16 | # ported Google's Jax-CFD functional template to PyTorch's tensor ops
17 | 
18 | from typing import Optional, Tuple
19 | 
20 | import torch
21 | import torch.fft as fft
22 | 
23 | from torch_cfd import grids
24 | from einops import repeat
25 | 
26 | Grid = grids.Grid
27 | 
28 | 
29 | def fft_mesh_2d(n, diam, device=None):
30 |     kx, ky = [fft.fftfreq(n, d=diam / n) for _ in range(2)]
31 |     kx, ky = torch.meshgrid([kx, ky], indexing="ij")
32 |     return kx.to(device), ky.to(device)
33 | 
34 | 
35 | def fft_expand_dims(fft_mesh, batch_size):
36 |     kx, ky = fft_mesh
37 |     kx, ky = [repeat(z, "x y -> b x y 1", b=batch_size) for z in [kx, ky]]
38 |     return kx, ky
39 | 
40 | 
41 | def spectral_laplacian_2d(fft_mesh, device=None):
42 |     kx, ky = fft_mesh
43 |     lap = -4 * (torch.pi**2) * (abs(kx) ** 2 + abs(ky) ** 2)  # = ((2 * torch.pi * 1j) * |k|)**2
44 |     # set the zero mode to 1 so that division by the Laplacian is well-defined
45 |     lap[..., 0, 0] = 1
46 |     return lap.to(device)
47 | 
48 | 
49 | def spectral_curl_2d(vhat, rfft_mesh):
50 |     r"""
51 |     Computes the 2D curl in the Fourier basis.
52 |     det [d_x d_y \\ u v] = d_x v - d_y u
53 |     """
54 |     uhat, vhat = vhat
55 |     kx, ky = rfft_mesh
56 |     return 2j * torch.pi * (vhat * kx - uhat * ky)
57 | 
58 | 
59 | def spectral_div_2d(vhat, rfft_mesh):
60 |     r"""
61 |     Computes the 2D divergence in the Fourier basis.
62 |     """
63 |     uhat, vhat = vhat
64 |     kx, ky = rfft_mesh
65 |     return 2j * torch.pi * (uhat * kx + vhat * ky)
66 | 
67 | 
68 | def spectral_grad_2d(vhat, rfft_mesh):
69 |     kx, ky = rfft_mesh
70 |     return 2j * torch.pi * kx * vhat, 2j * torch.pi * ky * vhat
71 | 
72 | 
73 | def spectral_rot_2d(vhat, rfft_mesh):
74 |     vgradx, vgrady = spectral_grad_2d(vhat, rfft_mesh)
75 |     return vgrady, -vgradx
76 | 
77 | 
78 | def brick_wall_filter_2d(grid: Grid):
79 |     """Implements the 2/3-rule dealiasing filter (keeps only the lowest 2/3 of modes)."""
80 |     n, _ = grid.shape
81 |     filter_ = torch.zeros((n, n // 2 + 1))
82 |     filter_[: int(2 / 3 * n) // 2, : int(2 / 3 * (n // 2 + 1))] = 1
83 |     filter_[-int(2 / 3 * n) // 2 :, : int(2 / 3 * (n // 2 + 1))] = 1
84 |     return filter_
85 | 
86 | 
87 | def vorticity_to_velocity(
88 |     grid: Grid, w_hat: torch.Tensor, rfft_mesh: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
89 | ):
90 |     """Converts vorticity to velocity, both in the Fourier (rfftn) domain.
91 | 
92 |     Solves for the stream function and then uses the stream function to compute
93 |     the velocity. This is the standard approach. A quick sketch can be found in
94 |     [1].
95 | 
96 |     Args:
97 |         grid: the grid underlying the vorticity field.
98 |         w_hat: the rfftn of the vorticity field.
99 |         rfft_mesh: optional precomputed rfft mesh (kx, ky); defaults to grid.rfft_mesh().
100 | 
101 |     Returns:
102 |         ((u_hat, v_hat), psi_hat): the rfftn of the velocity components and the stream function.
103 | 
104 |     Reference:
105 |         [1] Z. Yin, H.J.H. Clercx, D.C. Montgomery, An easily implemented task-based
106 |         parallel scheme for the Fourier pseudospectral solver applied to 2D
107 |         Navier-Stokes turbulence, Computers & Fluids, Volume 33, Issue 4, 2004,
108 |         Pages 509-520, ISSN 0045-7930, https://doi.org/10.1016/j.compfluid.2003.06.003.
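
    In 2D this amounts to solving -lap(psi) = w for the stream function and
    taking (u, v) = (d psi/dy, -d psi/dx); the body below does this spectrally via

        psi_hat = -w_hat / lap_hat,    (u_hat, v_hat) = spectral_rot_2d(psi_hat),

    where lap_hat is the spectral Laplacian with its zero mode set to 1.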
109 | """ 110 | kx, ky = rfft_mesh if rfft_mesh is not None else grid.rfft_mesh() 111 | assert kx.shape[-2:] == w_hat.shape[-2:] 112 | laplace = spectral_laplacian_2d((kx, ky)) 113 | psi_hat = -1 / laplace * w_hat 114 | u_hat, v_hat = spectral_rot_2d(psi_hat, (kx, ky)) 115 | return (u_hat, v_hat), psi_hat -------------------------------------------------------------------------------- /torch_cfd/tensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Modifications copyright (C) 2024 S.Cao 16 | # ported Google's Jax-CFD functional template to PyTorch's tensor ops 17 | 18 | from __future__ import annotations 19 | 20 | from typing import Any, Callable, List, Sequence, Tuple, Union 21 | 22 | import torch 23 | import torch.utils._pytree as pytree 24 | 25 | def _normalize_axis(axis: int, ndim: int) -> int: 26 | """Validates and returns positive `axis` value.""" 27 | if not -ndim <= axis < ndim: 28 | raise ValueError(f"invalid axis {axis} for ndim {ndim}") 29 | if axis < 0: 30 | axis += ndim 31 | return axis 32 | 33 | 34 | def slice_along_axis( 35 | inputs: Any, axis: int, idx: Union[slice, int], expect_same_dims: bool = True 36 | ) -> Any: 37 | """Returns slice of `inputs` defined by `idx` along axis `axis`. 38 | 39 | Args: 40 | inputs: tensor or a tuple of tensors to slice. 41 | axis: axis along which to slice the `inputs`. 42 | idx: index or slice along axis `axis` that is returned. 43 | expect_same_dims: whether all arrays should have same number of dimensions. 44 | 45 | Returns: 46 | Slice of `inputs` defined by `idx` along axis `axis`. 47 | """ 48 | arrays, tree_def = pytree.tree_flatten(inputs) 49 | ndims = set(a.ndim for a in arrays) 50 | if expect_same_dims and len(ndims) != 1: 51 | raise ValueError( 52 | "arrays in `inputs` expected to have same ndims, but have " 53 | f"{ndims}. To allow this, pass expect_same_dims=False" 54 | ) 55 | sliced = [] 56 | for array in arrays: 57 | ndim = array.ndim 58 | slc = tuple( 59 | idx if j == _normalize_axis(axis, ndim) else slice(None) 60 | for j in range(ndim) 61 | ) 62 | sliced.append(array[slc]) 63 | return pytree.tree_unflatten(tree_def, sliced) 64 | 65 | 66 | def split_along_axis( 67 | inputs: Any, split_idx: int, axis: int, expect_same_dims: bool = True 68 | ) -> Tuple[Any, Any]: 69 | """Returns a tuple of slices of `inputs` split along `axis` at `split_idx`. 70 | 71 | Args: 72 | inputs: pytree of arrays to split. 73 | split_idx: index along `axis` where the second split starts. 74 | axis: axis along which to split the `inputs`. 75 | expect_same_dims: whether all arrays should have same number of dimensions. 76 | 77 | Returns: 78 | Tuple of slices of `inputs` split along `axis` at `split_idx`. 
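
    Example:
        A small sketch with a plain tensor:

        >>> x = torch.arange(6).reshape(2, 3)
        >>> first, second = split_along_axis(x, split_idx=1, axis=1)
        >>> first.shape, second.shape
        (torch.Size([2, 1]), torch.Size([2, 2]))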
79 |     """
80 | 
81 |     first_slice = slice_along_axis(inputs, axis, slice(0, split_idx), expect_same_dims)
82 |     second_slice = slice_along_axis(
83 |         inputs, axis, slice(split_idx, None), expect_same_dims
84 |     )
85 |     return first_slice, second_slice
86 | 
87 | 
88 | def split_axis(inputs: Any, dim: int, keep_dims: bool = False) -> Tuple[Any, ...]:
89 |     """Splits the arrays in `inputs` along dimension `dim`.
90 | 
91 |     Args:
92 |         inputs: pytree to be split.
93 |         dim: dimension along which to split the `inputs`.
94 |         keep_dims: whether to keep the `dim` dimension.
95 | 
96 |     Returns:
97 |         Tuple of pytrees that correspond to slices of `inputs` along `dim`. The
98 |         `dim` dimension is removed if `keep_dims` is False.
99 | 
100 |     Raises:
101 |         ValueError: if arrays in `inputs` don't have a unique size along `dim`.
102 |     """
103 |     arrays, tree_def = pytree.tree_flatten(inputs)
104 |     axis_shapes = set(a.shape[dim] for a in arrays)
105 |     if len(axis_shapes) != 1:
106 |         raise ValueError(f"Arrays must have equal sized axis but got {axis_shapes}")
107 |     (axis_shape,) = axis_shapes
108 |     splits = [torch.split(a, 1, dim=dim) for a in arrays]  # size-1 slices; note torch.split takes a chunk size, unlike np.split
109 |     if not keep_dims:
110 |         splits = pytree.tree_map(lambda a: torch.squeeze(a, dim), splits)
111 |     splits = zip(*splits)
112 |     return tuple(pytree.tree_unflatten(tree_def, leaves) for leaves in splits)
113 | 
114 | 
115 | 
--------------------------------------------------------------------------------
/torch_cfd/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | # Modifications copyright (C) 2025 S.Cao
16 | # ported Google's Jax-CFD functional template to torch.Tensor operations
17 | 
18 | """Shared test utilities."""
19 | 
20 | import torch
21 | from absl.testing import parameterized
22 | 
23 | from torch_cfd import grids
24 | 
25 | # Enable CUDA deterministic mode if needed
26 | torch.backends.cudnn.deterministic = True
27 | torch.backends.cudnn.benchmark = False
28 | 
29 | class TestCase(parameterized.TestCase):
30 |     """TestCase with assertions for grids.GridVariable."""
31 | 
32 |     def _check_and_remove_alignment_and_grid(self, *arrays):
33 |         """Check that array-like data values and other attributes match.
34 | 
35 |         If args type is GridArray, verify their offsets and grids match.
36 |         If args type is GridVariable, verify their offsets, grids, and bc match.
37 | 
38 |         Args:
39 |             *arrays: one or more Array, GridArray or GridVariable, but they must
40 |                 all be the same type.
41 | 
42 |         Returns:
43 |             The data-only arrays, with other attributes removed.
44 | """ 45 | # GridVariable 46 | is_gridvariable = [isinstance(array, grids.GridVariable) for array in arrays] 47 | if any(is_gridvariable): 48 | self.assertTrue( 49 | all(is_gridvariable), msg=f"arrays have mixed types: {arrays}" 50 | ) 51 | try: 52 | grids.consistent_offset_arrays(*arrays) 53 | except ValueError as e: 54 | raise AssertionError(str(e)) from None 55 | try: 56 | grids.consistent_grid_arrays(*arrays) 57 | except ValueError as e: 58 | raise AssertionError(str(e)) from None 59 | arrays = tuple(array.data for array in arrays) 60 | return arrays 61 | 62 | # pylint: disable=unbalanced-tuple-unpacking 63 | def assertArrayEqual(self, actual, expected, **kwargs): 64 | actual, expected = self._check_and_remove_alignment_and_grid(actual, expected) 65 | atol = torch.finfo(expected.data.dtype).eps 66 | rtol = expected.abs().max() * atol 67 | torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol, **kwargs) 68 | 69 | def assertAllClose(self, actual, expected, **kwargs): 70 | actual, expected = self._check_and_remove_alignment_and_grid(actual, expected) 71 | torch.testing.assert_close(actual, expected, **kwargs) 72 | 73 | # pylint: enable=unbalanced-tuple-unpacking 74 | -------------------------------------------------------------------------------- /torch_cfd/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaomath/torch-cfd/6abe65ad8c49b31bc09c661974c7c4b120ab6729/torch_cfd/tests/__init__.py -------------------------------------------------------------------------------- /torch_cfd/tests/test_advection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Modifications copyright (C) 2025 S.Cao 16 | # ported Google's Jax-CFD functional template to PyTorch's tensor ops 17 | """Tests for torch_cfd.advection.""" 18 | 19 | import math 20 | from functools import partial 21 | 22 | import torch 23 | from absl.testing import absltest, parameterized 24 | 25 | from torch_cfd import advection, boundaries, grids, test_utils 26 | 27 | identity = lambda x: x 28 | 29 | Grid = grids.Grid 30 | GridVariable = grids.GridVariable 31 | GridVariableVector = grids.GridVariableVector 32 | 33 | 34 | def step_func(c, v, dt, method): 35 | c_new = c.data + dt * method(c, v, dt) 36 | return c.bc.impose_bc(c_new) 37 | 38 | 39 | def _gaussian_concentration(grid, bc): 40 | offset = tuple(-int(math.ceil(s / 2.0)) for s in grid.shape) 41 | 42 | mesh_coords = grid.mesh(offset=offset) 43 | squared_sum = torch.zeros_like(mesh_coords[0]) 44 | for m in mesh_coords: 45 | squared_sum += torch.square(m) * 30.0 46 | 47 | return GridVariable(torch.exp(-squared_sum), (0.5,) * len(grid.shape), grid, bc) 48 | 49 | 50 | def _square_concentration(grid, bc): 51 | select_square = lambda x: torch.where(torch.logical_and(x > 0.4, x < 0.6), 1.0, 0.0) 52 | mesh_coords = grid.mesh() 53 | concentration = torch.ones_like(mesh_coords[0]) 54 | for m in mesh_coords: 55 | concentration *= select_square(m) 56 | 57 | return GridVariable(concentration, (0.5,) * len(grid.shape), grid, bc) 58 | 59 | 60 | def _unit_velocity(grid, velocity_sign=1.0): 61 | ndim = grid.ndim 62 | offsets = (torch.eye(ndim) + torch.ones([ndim, ndim])) / 2.0 63 | offsets = [offsets[i].tolist() for i in range(ndim)] 64 | return GridVariableVector( 65 | tuple( 66 | GridVariable( 67 | ( 68 | velocity_sign * torch.ones(grid.shape) 69 | if ax == 0 70 | else torch.zeros(grid.shape) 71 | ), 72 | tuple(offset), 73 | grid, 74 | ) 75 | for ax, offset in enumerate(offsets) 76 | ) 77 | ) 78 | 79 | 80 | def _total_variation(c: GridVariable, dim: int = 0): 81 | next_values = c.shift(1, dim) 82 | variation = torch.abs(next_values.data - c.data).sum().item() 83 | return variation 84 | 85 | 86 | advect_linear = advection.AdvectionLinear 87 | advect_upwind = advection.AdvectionUpwind 88 | advect_van_leer = partial(advection.AdvectionVanLeer, limiter=identity) 89 | advect_van_leer_using_limiters = advection.AdvectionVanLeer 90 | 91 | 92 | class AdvectionTestAnalytical(test_utils.TestCase): 93 | 94 | @parameterized.named_parameters( 95 | dict( 96 | testcase_name="linear_2D", 97 | shape=(101, 101), 98 | dt=1 / 100, 99 | method=advect_linear, 100 | num_steps=100, 101 | cfl_number=0.1, 102 | atol=5e-2, 103 | rtol=1e-3, 104 | ), 105 | dict( 106 | testcase_name="upwind_2D", 107 | shape=(101, 101), 108 | dt=1 / 100, 109 | method=advect_upwind, 110 | num_steps=100, 111 | cfl_number=0.5, 112 | atol=7e-2, 113 | rtol=1e-3, 114 | ), 115 | dict( 116 | testcase_name="van_leer_2D", 117 | shape=(101, 101), 118 | dt=1 / 100, 119 | method=advect_van_leer, 120 | num_steps=100, 121 | cfl_number=0.5, 122 | atol=7e-2, 123 | rtol=1e-3, 124 | ), 125 | dict( 126 | testcase_name="van_leer_using_limiters_2D", 127 | shape=(101, 101), 128 | dt=1 / 100, 129 | method=advect_van_leer_using_limiters, 130 | num_steps=100, 131 | cfl_number=0.1, 132 | atol=2e-2, 133 | rtol=1e-3, 134 | ), 135 | ) 136 | def test_advection_analytical( 137 | self, shape, dt, method, num_steps, cfl_number, atol, rtol 138 | ): 139 | """Tests advection of a Gaussian concentration on a periodic grid.""" 140 | step = tuple(1.0 / s for s in shape) 141 | grid = Grid(shape, step) 142 | v_sign = 1.0 
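        # With unit advection speed, after num_steps steps the profile travels
        # num_steps * (cfl_number * dt) in physical units, i.e. roughly
        # cfl_number * num_steps grid cells; expected_shift below encodes this.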
143 | bc = boundaries.periodic_boundary_conditions(len(shape)) 144 | v = GridVariableVector(tuple(u for u in _unit_velocity(grid, v_sign))) 145 | c = _gaussian_concentration(grid, bc) 146 | advect = method(grid, c.offset) 147 | dt = cfl_number * dt 148 | ct = c.clone() 149 | for _ in range(num_steps): 150 | ct = step_func(ct, v, dt, method=advect) 151 | 152 | expected_shift = int(round(-cfl_number * num_steps * v_sign)) 153 | expected = c.shift(expected_shift, dim=0) 154 | 155 | self.assertAllClose(expected.data, ct.data, atol=atol, rtol=rtol) 156 | 157 | @parameterized.named_parameters( 158 | dict( 159 | testcase_name="dirichlet_1d_200_cell_center", 160 | shape=(200,), 161 | atol=0.00025, 162 | rtol=1 / 200, 163 | offset=0.5, 164 | ), 165 | dict( 166 | testcase_name="dirichlet_1d_400_cell_center", 167 | shape=(400,), 168 | atol=0.00007, 169 | rtol=1 / 400, 170 | offset=0.5, 171 | ), 172 | dict( 173 | testcase_name="dirichlet_1d_200_cell_edge_0", 174 | shape=(200,), 175 | atol=0.0005, 176 | rtol=1 / 200, 177 | offset=0.0, 178 | ), 179 | dict( 180 | testcase_name="dirichlet_1d_400_cell_edge_0", 181 | shape=(400,), 182 | atol=0.000125, 183 | rtol=1 / 400, 184 | offset=0.0, 185 | ), 186 | dict( 187 | testcase_name="dirichlet_1d_200_cell_edge_1", 188 | shape=(200,), 189 | atol=0.0005, 190 | rtol=1 / 200, 191 | offset=1.0, 192 | ), 193 | dict( 194 | testcase_name="dirichlet_1d_400_cell_edge_1", 195 | shape=(400,), 196 | atol=0.000125, 197 | rtol=1 / 400, 198 | offset=1.0, 199 | ), 200 | ) 201 | def test_burgers_analytical_dirichlet_convergence( 202 | self, 203 | shape, 204 | atol, 205 | rtol, 206 | offset, 207 | ): 208 | def _step_func(v, dt, method): 209 | dv_dt = method(c=v[0], v=v, dt=dt) / 2 210 | return (bc.impose_bc(v[0].data + dt * dv_dt),) 211 | 212 | def _velocity_implicit(grid, offset, u, t): 213 | """Returns solution of a Burgers equation at time `t`.""" 214 | x = grid.mesh(offset)[0] 215 | return grids.GridVariable(torch.sin(x - u * t), offset, grid) 216 | 217 | num_steps = 1000 218 | cfl_number = 0.01 219 | step = 2 * math.pi / 1000 220 | offset = (offset,) 221 | grid = grids.Grid(shape, domain=((0.0, 2 * math.pi),)) 222 | bc = boundaries.dirichlet_boundary_conditions(grid.ndim) 223 | v = (bc.impose_bc(_velocity_implicit(grid, offset, 0, 0)),) 224 | dt = cfl_number * step 225 | advect = advect_van_leer(grid, offset) 226 | 227 | for _ in range(num_steps): 228 | """ 229 | dt/2 is used because for Burgers equation 230 | the flux is u_t + (0.5*u^2)_x = 0 231 | """ 232 | v = _step_func(v, dt, method=advect) 233 | 234 | expected = bc.impose_bc( 235 | _velocity_implicit(grid, offset, v[0].data, dt * num_steps) 236 | ).data 237 | self.assertAllClose(expected, v[0].data, atol=atol, rtol=rtol) 238 | 239 | class AdvectionTestProperties(test_utils.TestCase): 240 | 241 | @parameterized.named_parameters( 242 | dict( 243 | testcase_name="upwind_2D", 244 | shape=(101, 51), 245 | method=advect_upwind, 246 | ), 247 | dict( 248 | testcase_name="van_leer_2D", 249 | shape=(101, 51), 250 | method=advect_van_leer, 251 | ), 252 | dict( 253 | testcase_name="van_leer_using_limiters_2D", 254 | shape=(101, 101), 255 | method=advect_van_leer_using_limiters, 256 | ), 257 | ) 258 | def test_tvd_property(self, shape, method): 259 | atol = 1e-6 260 | step = tuple(1.0 / s for s in shape) 261 | grid = Grid(shape, step) 262 | bc = boundaries.periodic_boundary_conditions(grid.ndim) 263 | v = GridVariableVector(tuple(u for u in _unit_velocity(grid))) 264 | c = _square_concentration(grid, bc) 265 | dt = min(step) 266 | 
num_steps = 300 267 | ct = c.clone() 268 | 269 | advect = method(grid, c.offset) 270 | 271 | initial_total_variation = _total_variation(c, 0) + atol 272 | 273 | for _ in range(num_steps): 274 | ct = step_func(ct, v, dt, method=advect) 275 | current_total_variation = _total_variation(ct, 0) 276 | self.assertLessEqual(current_total_variation, initial_total_variation) 277 | 278 | @parameterized.named_parameters( 279 | dict( 280 | testcase_name="upwind_2D", 281 | shape=(201, 101), 282 | method=advect_upwind, 283 | ), 284 | dict( 285 | testcase_name="van_leer_2D", 286 | shape=(101, 201), 287 | method=advect_van_leer, 288 | ), 289 | dict( 290 | testcase_name="van_leer_using_limiters_2D", 291 | shape=(101, 101), 292 | method=advect_van_leer_using_limiters, 293 | ), 294 | ) 295 | def test_mass_conservation(self, shape, method): 296 | """ 297 | Note: when the mass integral is close to zero 298 | the relative error will be big~O(1e1) for fp32 299 | """ 300 | offset = (0.5, 0.5) 301 | offsets_v = ((1.0, 0.5), (0.5, 1.0)) 302 | cfl_number = 0.1 303 | dt = cfl_number / shape[0] 304 | num_steps = 1000 305 | 306 | grid = Grid(shape, domain=((-1.0, 1.0), (-1.0, 1.0))) 307 | bc = boundaries.dirichlet_boundary_conditions(grid.ndim) 308 | c_bc = boundaries.dirichlet_and_periodic_boundary_conditions(bc_vals=(0.0, 2.0)) 309 | 310 | phi = lambda t: torch.sin(math.pi * t) 311 | 312 | def _velocity(grid, offsets): 313 | x, y = grid.mesh(offsets[0]) 314 | u1 = GridVariable(-phi(x) * phi(y), offsets[0], grid) 315 | u2 = GridVariable(torch.zeros_like(u1.data), offsets[1], grid) 316 | return GridVariableVector((u1, u2)) 317 | 318 | def c0(grid, offset): 319 | x = grid.mesh(offset)[0] + 1 320 | return GridVariable(x, offset, grid) 321 | 322 | v = tuple(bc.impose_bc(u) for u in _velocity(grid, offsets_v)) 323 | c = c_bc.impose_bc(c0(grid, offset)) 324 | 325 | ct = c.clone() 326 | 327 | advect = method(grid, c.offset) 328 | 329 | initial_mass = c.data.sum().item() 330 | for _ in range(num_steps): 331 | ct = step_func(ct, v, dt, method=advect) 332 | current_total_mass = ct.data.sum().item() 333 | self.assertAllClose(current_total_mass, initial_mass, atol=1e-6, rtol=1e-2) 334 | 335 | 336 | if __name__ == "__main__": 337 | absltest.main() 338 | --------------------------------------------------------------------------------