├── .gitignore ├── LICENSE ├── README.md ├── data_provider ├── data_factory.py ├── data_loader.py └── shapenet_utils.py ├── exp ├── exp_basic.py ├── exp_dynamic_autoregressive.py ├── exp_dynamic_conditional.py ├── exp_steady.py └── exp_steady_design.py ├── layers ├── Basic.py ├── Embedding.py ├── FFNO_Layers.py ├── FNO_Layers.py ├── GeoFNO_Projection.py ├── MWT_Layers.py ├── Neural_Spectral_Block.py ├── Physics_Attention.py └── UNet_Blocks.py ├── models ├── FNO.py ├── F_FNO.py ├── Factformer.py ├── GNOT.py ├── Galerkin_Transformer.py ├── GraphSAGE.py ├── Graph_UNet.py ├── LSM.py ├── MWT.py ├── ONO.py ├── PointNet.py ├── Swin_Transformer.py ├── Transformer.py ├── Transolver.py ├── U_FNO.py ├── U_NO.py ├── U_Net.py └── model_factory.py ├── pic ├── logo.png └── task.png ├── requirements.txt ├── run.py ├── scripts ├── DesignBench │ └── car │ │ ├── GNOT.sh │ │ ├── GraphSAGE.sh │ │ ├── GraphUNet.sh │ │ ├── PointNet.sh │ │ └── Transolver.sh ├── PDEBench │ ├── 3DCFD │ │ ├── FNO.sh │ │ └── U_Net.sh │ ├── darcy │ │ ├── FNO.sh │ │ ├── MWT.sh │ │ ├── Transfomer.sh │ │ ├── Transolver.sh │ │ ├── U_FNO.sh │ │ ├── U_NO.sh │ │ └── U_Net.sh │ └── diff_sorp │ │ ├── FNO.sh │ │ ├── MWT.sh │ │ ├── Transolver.sh │ │ ├── U_FNO.sh │ │ ├── U_NO.sh │ │ └── U_Net.sh └── StandardBench │ ├── airfoil │ ├── FNO.sh │ ├── F_FNO.sh │ ├── Factformer.sh │ ├── GNOT.sh │ ├── Galerkin_Transformer.sh │ ├── LSM.sh │ ├── MWT.sh │ ├── Swin.sh │ ├── Transformer.sh │ ├── Transolver.sh │ ├── U_FNO.sh │ ├── U_NO.sh │ └── U_Net.sh │ ├── darcy │ ├── FNO.sh │ ├── F_FNO.sh │ ├── Factformer.sh │ ├── GNOT.sh │ ├── Galerkin_Transformer.sh │ ├── LSM.sh │ ├── MWT.sh │ ├── ONO.sh │ ├── Swin.sh │ ├── Transformer.sh │ ├── Transolver.sh │ ├── U_FNO.sh │ ├── U_NO.sh │ └── U_Net.sh │ ├── elasticity │ ├── FNO.sh │ ├── F_FNO.sh │ ├── GNOT.sh │ ├── Galerkin_Transformer.sh │ ├── LSM.sh │ ├── MWT.sh │ ├── Transformer.sh │ ├── Transolver.sh │ ├── U_FNO.sh │ ├── U_NO.sh │ └── U_Net.sh │ ├── ns │ ├── FNO.sh │ ├── F_FNO.sh │ ├── 
Factformer.sh │ ├── GNOT.sh │ ├── Galerkin_Transformer.sh │ ├── LSM.sh │ ├── MWT.sh │ ├── ONO.sh │ ├── Swin.sh │ ├── Transformer.sh │ ├── Transolver.sh │ ├── U_FNO.sh │ ├── U_NO.sh │ └── U_Net.sh │ ├── pipe │ ├── FNO.sh │ ├── F_FNO.sh │ ├── Factformer.sh │ ├── GNOT.sh │ ├── Galerkin_Transformer.sh │ ├── LSM.sh │ ├── MWT.sh │ ├── ONO.sh │ ├── Swin.sh │ ├── Transformer.sh │ ├── Transolver.sh │ ├── U_FNO.sh │ ├── U_NO.sh │ └── U_Net.sh │ └── plasticity │ ├── FNO.sh │ ├── F_FNO.sh │ ├── Factformer.sh │ ├── GNOT.sh │ ├── Galerkin_Transformer.sh │ ├── LSM.sh │ ├── MWT.sh │ ├── ONO.sh │ ├── Swin.sh │ ├── Transformer.sh │ ├── Transolver.sh │ ├── U_FNO.sh │ ├── U_NO.sh │ └── U_Net.sh └── utils ├── drag_coefficient.py ├── loss.py ├── normalizer.py └── visual.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # .DS_Store files 10 | .DS_Store 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # UV 101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | #uv.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
171 | #.idea/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | # results and checkpoints 177 | results/ 178 | checkpoints/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 THUML @ Tsinghua University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LogoNeural-Solver-Library (NeuralSolver) 2 | 3 | NeuralSolver is an open-source library for deep learning researchers, especially for neural PDE solvers. 
4 | 5 | :triangular_flag_on_post:**News** (2025.03) We released NeuralSolver, a simple and neat code base for benchmarking neural PDE solvers, extended from our previous GitHub repository [Transolver](https://github.com/thuml/Transolver). 6 | 7 | ## Features 8 | 9 | This library currently supports the following benchmarks: 10 | 11 | - Six Standard Benchmarks from [[FNO]](https://arxiv.org/abs/2010.08895) and [[geo-FNO]](https://arxiv.org/abs/2207.05209) 12 | - PDEBench [[NeurIPS 2022 Track Datasets and Benchmarks]](https://arxiv.org/abs/2210.07182) for benchmarking autoregressive and steady-state tasks 13 | - ShapeNet-Car from [[TOG 2018]](https://dl.acm.org/doi/abs/10.1145/3197517.3201325) for benchmarking industrial design tasks 14 | 15 |

16 | 17 |

18 | Figure 1. Examples of supported PDE-solving tasks. 19 |

20 | 21 | ## Supported Neural Solvers 22 | 23 | Here is the list of supported neural PDE solvers: 24 | 25 | - [x] **Transolver** - Transolver: A Fast Transformer Solver for PDEs on General Geometries [[ICML 2024]](https://arxiv.org/abs/2402.02366) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Transolver.py) 26 | - [x] **ONO** - Improved Operator Learning by Orthogonal Attention [[ICML 2024]](https://arxiv.org/abs/2310.12487v3) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/ONO.py) 27 | - [x] **Factformer** - Scalable Transformer for PDE Surrogate Modeling [[NeurIPS 2023]](https://arxiv.org/abs/2305.17560) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Factformer.py) 28 | - [x] **U-NO** - U-NO: U-shaped Neural Operators [[TMLR 2023]](https://openreview.net/pdf?id=j3oQF9coJd) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/U_NO.py) 29 | - [x] **LSM** - Solving High-Dimensional PDEs with Latent Spectral Models [[ICML 2023]](https://arxiv.org/pdf/2301.12664) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/LSM.py) 30 | - [x] **GNOT** - GNOT: A General Neural Operator Transformer for Operator Learning [[ICML 2023]](https://arxiv.org/abs/2302.14376) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/GNOT.py) 31 | - [x] **F-FNO** - Factorized Fourier Neural Operators [[ICLR 2023]](https://arxiv.org/abs/2111.13802) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/F_FNO.py) 32 | - [x] **U-FNO** - An enhanced Fourier neural operator-based deep-learning model for multiphase flow [[Advances in Water Resources 2022]](https://www.sciencedirect.com/science/article/pii/S0309170822000562) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/U_FNO.py) 33 | - [x] **Galerkin Transformer** - Choose a Transformer: Fourier or Galerkin [[NeurIPS 2021]](https://arxiv.org/abs/2105.14995) 
[[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Galerkin_Transformer.py) 34 | - [x] **MWT** - Multiwavelet-based Operator Learning for Differential Equations [[NeurIPS 2021]](https://openreview.net/forum?id=LZDiWaC9CGL) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/MWT.py) 35 | - [x] **FNO** - Fourier Neural Operator for Parametric Partial Differential Equations [[ICLR 2021]](https://arxiv.org/pdf/2010.08895) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/FNO.py) 36 | - [x] **Transformer** - Attention Is All You Need [[NeurIPS 2017]](https://arxiv.org/pdf/1706.03762) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Transformer.py) 37 | 38 | Some vision backbones can be good baselines for tasks in structured geometries: 39 | 40 | - [x] **Swin Transformer** - Swin Transformer: Hierarchical Vision Transformer using Shifted Windows [[ICCV 2021]](https://arxiv.org/abs/2103.14030) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Swin_Transformer.py) 41 | - [x] **U-Net** - U-Net: Convolutional Networks for Biomedical Image Segmentation [[MICCAI 2015]](https://arxiv.org/pdf/1505.04597) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/U_Net.py) 42 | 43 | Some classical geometric deep models are also included for design tasks: 44 | 45 | - [x] **Graph-UNet** - Graph U-Nets [[ICML 2019]](https://arxiv.org/pdf/1905.05178) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Graph_UNet.py) 46 | - [x] **GraphSAGE** - Inductive Representation Learning on Large Graphs [[NeurIPS 2017]](https://arxiv.org/pdf/1706.02216) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/GraphSAGE.py) 47 | - [x] **PointNet** - PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation [[CVPR 2017]](https://arxiv.org/pdf/1612.00593) 
[[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/PointNet.py) 48 | 49 | 🌟 We have made a great effort to ensure reproducibility: the officially reported results of all the above methods can be fully reproduced (sometimes even surpassed) with this library. 50 | 51 | ## Usage 52 | 53 | 1. Install Python 3.8, then install the dependencies with the following command. 54 | 55 | ```bash 56 | pip install -r requirements.txt 57 | ``` 58 | 59 | 2. Prepare the datasets for the benchmarks listed in the Features section. 60 | 3. Train and evaluate the model. We provide the experiment scripts for all benchmarks under the folder `./scripts/`. You can reproduce the experiment results as follows: 61 | 62 | ```bash 63 | bash ./scripts/StandardBench/airfoil/Transolver.sh 64 | ``` 65 | 66 | 4. Develop your own model. 67 | 68 | - Add the model file to the folder `./models`, using `./models/Transolver.py` as a reference. 69 | - Include the newly added model in the `model_dict` of `./models/model_factory.py`. 70 | - Create the corresponding scripts under the folder `./scripts`, setting hyperparameters by following the provided scripts of other models. 71 | 72 | ## Citation 73 | 74 | If you find this repo useful, please cite our paper. 75 | 76 | ``` 77 | @inproceedings{wu2024Transolver, 78 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries}, 79 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long}, 80 | booktitle={International Conference on Machine Learning}, 81 | year={2024} 82 | } 83 | ``` 84 | 85 | ## Contact 86 | 87 | If you have any questions or want to use the code, please contact our team or open a GitHub issue. 88 | 89 | Current maintenance team: 90 | 91 | - Haixu Wu (Ph.D. student, [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn)) 92 | - Yuanxu Sun (Undergraduate, sunyuanx22@mails.tsinghua.edu.cn) 93 | - Hang Zhou (Master's student, zhou-h23@mails.tsinghua.edu.cn) 94 | - Yuezhou Ma (Ph.D.
student, mayz24@mails.tsinghua.edu.cn) 95 | 96 | ## Acknowledgement 97 | 98 | We appreciate the following GitHub repos a lot for their valuable code base or datasets: 99 | 100 | https://github.com/thuml/Transolver 101 | 102 | https://github.com/thuml/Latent-Spectral-Models 103 | 104 | https://github.com/neuraloperator/neuraloperator 105 | 106 | https://github.com/neuraloperator/Geo-FNO 107 | -------------------------------------------------------------------------------- /data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import airfoil, ns, darcy, pipe, elas, plas, pdebench_autoregressive, \ 2 | pdebench_steady_darcy, car_design, cfd3d 3 | 4 | 5 | def get_data(args): 6 | data_dict = { 7 | 'car_design': car_design, 8 | 'pdebench_autoregressive': pdebench_autoregressive, 9 | 'pdebench_steady_darcy': pdebench_steady_darcy, 10 | 'elas': elas, 11 | 'pipe': pipe, 12 | 'airfoil': airfoil, 13 | 'darcy': darcy, 14 | 'ns': ns, 15 | 'plas': plas, 16 | 'cfd3d': cfd3d, 17 | } 18 | dataset = data_dict[args.loader](args) 19 | train_loader, test_loader, shapelist = dataset.get_loader() 20 | return dataset, train_loader, test_loader, shapelist 21 | -------------------------------------------------------------------------------- /exp/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from models.model_factory import get_model 4 | from data_provider.data_factory import get_data 5 | 6 | 7 | def count_parameters(model): 8 | total_params = 0 9 | for name, parameter in model.named_parameters(): 10 | if not parameter.requires_grad: continue 11 | params = parameter.numel() 12 | total_params += params 13 | print(f"Total Trainable Params: {total_params}") 14 | return total_params 15 | 16 | 17 | class Exp_Basic(object): 18 | def __init__(self, args): 19 | self.dataset, self.train_loader, self.test_loader, args.shapelist 
= get_data(args) 20 | self.model = get_model(args).cuda() 21 | self.args = args 22 | print(self.args) 23 | print(self.model) 24 | count_parameters(self.model) 25 | 26 | def vali(self): 27 | pass 28 | 29 | def train(self): 30 | pass 31 | 32 | def test(self): 33 | pass 34 | -------------------------------------------------------------------------------- /exp/exp_dynamic_autoregressive.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from exp.exp_basic import Exp_Basic 4 | from models.model_factory import get_model 5 | from data_provider.data_factory import get_data 6 | from utils.loss import L2Loss 7 | import matplotlib.pyplot as plt 8 | from utils.visual import visual 9 | import numpy as np 10 | 11 | 12 | class Exp_Dynamic_Autoregressive(Exp_Basic): 13 | def __init__(self, args): 14 | super(Exp_Dynamic_Autoregressive, self).__init__(args) 15 | 16 | def vali(self): 17 | myloss = L2Loss(size_average=False) 18 | test_l2_full = 0 19 | self.model.eval() 20 | with torch.no_grad(): 21 | for x, fx, yy in self.test_loader: 22 | x, fx, yy = x.cuda(), fx.cuda(), yy.cuda() 23 | for t in range(self.args.T_out): 24 | if self.args.fun_dim == 0: 25 | fx = None 26 | im = self.model(x, fx=fx) 27 | if t == 0: 28 | pred = im 29 | else: 30 | pred = torch.cat((pred, im), -1) 31 | fx = torch.cat((fx[..., self.args.out_dim:], im), dim=-1) 32 | if self.args.normalize: 33 | pred = self.dataset.y_normalizer.decode(pred) 34 | test_l2_full += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item() 35 | test_loss_full = test_l2_full / (self.args.ntest) 36 | return test_loss_full 37 | 38 | def train(self): 39 | if self.args.optimizer == 'AdamW': 40 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) 41 | elif self.args.optimizer == 'Adam': 42 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) 43 
| else: 44 | raise ValueError('Optimizer only AdamW or Adam') 45 | if self.args.scheduler == 'OneCycleLR': 46 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.args.lr, epochs=self.args.epochs, 47 | steps_per_epoch=len(self.train_loader), 48 | pct_start=self.args.pct_start) 49 | elif self.args.scheduler == 'CosineAnnealingLR': 50 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args.epochs) 51 | elif self.args.scheduler == 'StepLR': 52 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args.step_size, gamma=self.args.gamma) 53 | 54 | myloss = L2Loss(size_average=False) 55 | 56 | for ep in range(self.args.epochs): 57 | self.model.train() 58 | train_l2_step = 0 59 | train_l2_full = 0 60 | 61 | for pos, fx, yy in self.train_loader: 62 | loss = 0 63 | x, fx, yy = pos.cuda(), fx.cuda(), yy.cuda() 64 | for t in range(self.args.T_out): 65 | y = yy[..., self.args.out_dim * t:self.args.out_dim * (t + 1)] 66 | if self.args.fun_dim == 0: 67 | fx = None 68 | im = self.model(x, fx=fx) 69 | loss += myloss(im.reshape(x.shape[0], -1), y.reshape(x.shape[0], -1)) 70 | if t == 0: 71 | pred = im 72 | else: 73 | pred = torch.cat((pred, im), -1) 74 | 75 | if self.args.teacher_forcing: 76 | fx = torch.cat((fx[..., self.args.out_dim:], y), dim=-1) 77 | else: 78 | fx = torch.cat((fx[..., self.args.out_dim:], im), dim=-1) 79 | 80 | train_l2_step += loss.item() 81 | train_l2_full += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item() 82 | optimizer.zero_grad() 83 | loss.backward() 84 | 85 | if self.args.max_grad_norm is not None: 86 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 87 | optimizer.step() 88 | 89 | if self.args.scheduler == 'OneCycleLR': 90 | scheduler.step() 91 | if self.args.scheduler == 'CosineAnnealingLR' or self.args.scheduler == 'StepLR': 92 | scheduler.step() 93 | 94 | train_loss_step = train_l2_step / (self.args.ntrain * float(self.args.T_out)) 
95 | train_loss_full = train_l2_full / (self.args.ntrain) 96 | print("Epoch {} Train loss step : {:.5f} Train loss full : {:.5f}".format(ep, train_loss_step, 97 | train_loss_full)) 98 | 99 | test_loss_full = self.vali() 100 | print("Epoch {} Test loss full : {:.5f}".format(ep, test_loss_full)) 101 | 102 | if ep % 100 == 0: 103 | if not os.path.exists('./checkpoints'): 104 | os.makedirs('./checkpoints') 105 | print('save models') 106 | torch.save(self.model.state_dict(), os.path.join('./checkpoints', self.args.save_name + '.pt')) 107 | 108 | if not os.path.exists('./checkpoints'): 109 | os.makedirs('./checkpoints') 110 | print('final save models') 111 | torch.save(self.model.state_dict(), os.path.join('./checkpoints', self.args.save_name + '.pt')) 112 | 113 | def test(self): 114 | self.model.load_state_dict(torch.load("./checkpoints/" + self.args.save_name + ".pt")) 115 | self.model.eval() 116 | if not os.path.exists('./results/' + self.args.save_name + '/'): 117 | os.makedirs('./results/' + self.args.save_name + '/') 118 | 119 | rel_err = 0.0 120 | id = 0 121 | myloss = L2Loss(size_average=False) 122 | with torch.no_grad(): 123 | for x, fx, yy in self.test_loader: 124 | id += 1 125 | x, fx, yy = x.cuda(), fx.cuda(), yy.cuda() 126 | for t in range(self.args.T_out): 127 | if self.args.fun_dim == 0: 128 | fx = None 129 | im = self.model(x, fx=fx) 130 | fx = torch.cat((fx[..., self.args.out_dim:], im), dim=-1) 131 | if t == 0: 132 | pred = im 133 | else: 134 | pred = torch.cat((pred, im), -1) 135 | if self.args.normalize: 136 | pred = self.dataset.y_normalizer.decode(pred) 137 | rel_err += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item() 138 | if id < self.args.vis_num: 139 | print('visual: ', id) 140 | for t in range(self.args.T_out): 141 | visual(x, yy[:, :, self.args.out_dim * t:self.args.out_dim * (t + 1)], 142 | pred[:, :, self.args.out_dim * t:self.args.out_dim * (t + 1)], self.args, 143 | str(id) + '_' + str(t)) 144 | 145 | rel_err /= 
self.args.ntest 146 | print("rel_err:{}".format(rel_err)) 147 | -------------------------------------------------------------------------------- /exp/exp_dynamic_conditional.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from exp.exp_basic import Exp_Basic 4 | from models.model_factory import get_model 5 | from data_provider.data_factory import get_data 6 | from utils.loss import L2Loss 7 | import matplotlib.pyplot as plt 8 | from utils.visual import visual 9 | import numpy as np 10 | 11 | 12 | class Exp_Dynamic_Conditional(Exp_Basic): 13 | def __init__(self, args): 14 | super(Exp_Dynamic_Conditional, self).__init__(args) 15 | 16 | def vali(self): 17 | myloss = L2Loss(size_average=False) 18 | test_l2_full = 0 19 | self.model.eval() 20 | with torch.no_grad(): 21 | for x, time, fx, yy in self.test_loader: 22 | x, time, fx, yy = x.cuda(), time.cuda(), fx.cuda(), yy.cuda() 23 | for t in range(self.args.T_out): 24 | input_T = time[:, t:t + 1].reshape(x.shape[0], 1) 25 | if self.args.fun_dim == 0: 26 | fx = None 27 | im = self.model(x, fx=fx, T=input_T) 28 | if t == 0: 29 | pred = im 30 | else: 31 | pred = torch.cat((pred, im), -1) 32 | if self.args.normalize: 33 | pred = self.dataset.y_normalizer.decode(pred) 34 | test_l2_full += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item() 35 | test_loss_full = test_l2_full / (self.args.ntest) 36 | return test_loss_full 37 | 38 | def train(self): 39 | if self.args.optimizer == 'AdamW': 40 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) 41 | elif self.args.optimizer == 'Adam': 42 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) 43 | else: 44 | raise ValueError('Optimizer only AdamW or Adam') 45 | if self.args.scheduler == 'OneCycleLR': 46 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 
max_lr=self.args.lr, epochs=self.args.epochs, 47 | steps_per_epoch=len(self.train_loader), 48 | pct_start=self.args.pct_start) 49 | elif self.args.scheduler == 'CosineAnnealingLR': 50 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args.epochs) 51 | elif self.args.scheduler == 'StepLR': 52 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args.step_size, gamma=self.args.gamma) 53 | myloss = L2Loss(size_average=False) 54 | 55 | for ep in range(self.args.epochs): 56 | self.model.train() 57 | train_l2_step = 0 58 | 59 | for pos, time, fx, yy in self.train_loader: 60 | x, time, fx, yy = pos.cuda(), time.cuda(), fx.cuda(), yy.cuda() 61 | for t in range(self.args.T_out): 62 | y = yy[..., self.args.out_dim * t:self.args.out_dim * (t + 1)] 63 | input_T = time[:, t:t + 1].reshape(x.shape[0], 1) 64 | if self.args.fun_dim == 0: 65 | fx = None 66 | im = self.model(x, fx=fx, T=input_T) 67 | loss = myloss(im.reshape(x.shape[0], -1), y.reshape(x.shape[0], -1)) 68 | train_l2_step += loss.item() 69 | optimizer.zero_grad() 70 | loss.backward() 71 | 72 | if self.args.max_grad_norm is not None: 73 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 74 | optimizer.step() 75 | 76 | if self.args.scheduler == 'OneCycleLR': 77 | scheduler.step() 78 | if self.args.scheduler == 'CosineAnnealingLR' or self.args.scheduler == 'StepLR': 79 | scheduler.step() 80 | 81 | train_loss_step = train_l2_step / (self.args.ntrain * float(self.args.T_out)) 82 | print("Epoch {} Train loss step : {:.5f} ".format(ep, train_loss_step)) 83 | 84 | test_loss_full = self.vali() 85 | print("Epoch {} Test loss full : {:.5f}".format(ep, test_loss_full)) 86 | 87 | if ep % 100 == 0: 88 | if not os.path.exists('./checkpoints'): 89 | os.makedirs('./checkpoints') 90 | print('save models') 91 | torch.save(self.model.state_dict(), os.path.join('./checkpoints', self.args.save_name + '.pt')) 92 | 93 | if not os.path.exists('./checkpoints'): 94 
| os.makedirs('./checkpoints') 95 | print('final save models') 96 | torch.save(self.model.state_dict(), os.path.join('./checkpoints', self.args.save_name + '.pt')) 97 | 98 | def test(self): 99 | self.model.load_state_dict(torch.load("./checkpoints/" + self.args.save_name + ".pt")) 100 | self.model.eval() 101 | if not os.path.exists('./results/' + self.args.save_name + '/'): 102 | os.makedirs('./results/' + self.args.save_name + '/') 103 | 104 | rel_err = 0.0 105 | id = 0 106 | myloss = L2Loss(size_average=False) 107 | with torch.no_grad(): 108 | for x, time, fx, yy in self.test_loader: 109 | id += 1 110 | x, time, fx, yy = x.cuda(), time.cuda(), fx.cuda(), yy.cuda() # x : B, 4096, 2 fx : B, 4096 y : B, 4096, T 111 | for t in range(self.args.T_out): 112 | input_T = time[:, t:t + 1].reshape(x.shape[0], 1) # B,step 113 | if self.args.fun_dim == 0: 114 | fx = None 115 | im = self.model(x, fx=fx, T=input_T) 116 | if t == 0: 117 | pred = im 118 | else: 119 | pred = torch.cat((pred, im), -1) 120 | if self.args.normalize: 121 | pred = self.dataset.y_normalizer.decode(pred) 122 | rel_err += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item() 123 | 124 | if id < self.args.vis_num: 125 | print('visual: ', id) 126 | visual(yy[:, :, -4:-2], torch.sqrt(yy[:, :, -1:] ** 2 + yy[:, :, -2:-1] ** 2), 127 | torch.sqrt(pred[:, :, -1:] ** 2 + pred[:, :, -2:-1] ** 2), self.args, id) 128 | 129 | rel_err /= self.args.ntest 130 | print("rel_err:{}".format(rel_err)) 131 | -------------------------------------------------------------------------------- /exp/exp_steady.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from exp.exp_basic import Exp_Basic 4 | from models.model_factory import get_model 5 | from data_provider.data_factory import get_data 6 | from utils.loss import L2Loss, DerivLoss 7 | import matplotlib.pyplot as plt 8 | from utils.visual import visual 9 | import numpy as np 10 | 11 | 12 | 
class Exp_Steady(Exp_Basic): 13 | def __init__(self, args): 14 | super(Exp_Steady, self).__init__(args) 15 | 16 | def vali(self): 17 | myloss = L2Loss(size_average=False) 18 | self.model.eval() 19 | rel_err = 0.0 20 | with torch.no_grad(): 21 | for pos, fx, y in self.test_loader: 22 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() 23 | if self.args.fun_dim == 0: 24 | fx = None 25 | out = self.model(x, fx) 26 | if self.args.normalize: 27 | out = self.dataset.y_normalizer.decode(out) 28 | 29 | tl = myloss(out, y).item() 30 | rel_err += tl 31 | 32 | rel_err /= self.args.ntest 33 | return rel_err 34 | 35 | def train(self): 36 | if self.args.optimizer == 'AdamW': 37 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) 38 | elif self.args.optimizer == 'Adam': 39 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) 40 | else: 41 | raise ValueError('Optimizer only AdamW or Adam') 42 | 43 | if self.args.scheduler == 'OneCycleLR': 44 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.args.lr, epochs=self.args.epochs, 45 | steps_per_epoch=len(self.train_loader), 46 | pct_start=self.args.pct_start) 47 | elif self.args.scheduler == 'CosineAnnealingLR': 48 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args.epochs) 49 | elif self.args.scheduler == 'StepLR': 50 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args.step_size, gamma=self.args.gamma) 51 | myloss = L2Loss(size_average=False) 52 | if self.args.derivloss: 53 | regloss = DerivLoss(size_average=False, shapelist=self.args.shapelist) 54 | 55 | for ep in range(self.args.epochs): 56 | 57 | self.model.train() 58 | train_loss = 0 59 | 60 | for pos, fx, y in self.train_loader: 61 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() 62 | if self.args.fun_dim == 0: 63 | fx = None 64 | out = self.model(x, fx) 65 | if self.args.normalize: 66 | out = 
self.dataset.y_normalizer.decode(out) 67 | y = self.dataset.y_normalizer.decode(y) 68 | 69 | if self.args.derivloss: 70 | loss = myloss(out, y) + 0.1 * regloss(out, y) 71 | else: 72 | loss = myloss(out, y) 73 | 74 | train_loss += loss.item() 75 | optimizer.zero_grad() 76 | loss.backward() 77 | 78 | if self.args.max_grad_norm is not None: 79 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 80 | optimizer.step() 81 | 82 | if self.args.scheduler == 'OneCycleLR': 83 | scheduler.step() 84 | if self.args.scheduler == 'CosineAnnealingLR' or self.args.scheduler == 'StepLR': 85 | scheduler.step() 86 | 87 | train_loss = train_loss / self.args.ntrain 88 | print("Epoch {} Train loss : {:.5f}".format(ep, train_loss)) 89 | 90 | rel_err = self.vali() 91 | print("rel_err:{}".format(rel_err)) 92 | 93 | if ep % 100 == 0: 94 | if not os.path.exists('./checkpoints'): 95 | os.makedirs('./checkpoints') 96 | print('save models') 97 | torch.save(self.model.state_dict(), os.path.join('./checkpoints', self.args.save_name + '.pt')) 98 | 99 | if not os.path.exists('./checkpoints'): 100 | os.makedirs('./checkpoints') 101 | print('final save models') 102 | torch.save(self.model.state_dict(), os.path.join('./checkpoints', self.args.save_name + '.pt')) 103 | 104 | def test(self): 105 | self.model.load_state_dict(torch.load("./checkpoints/" + self.args.save_name + ".pt")) 106 | self.model.eval() 107 | if not os.path.exists('./results/' + self.args.save_name + '/'): 108 | os.makedirs('./results/' + self.args.save_name + '/') 109 | 110 | rel_err = 0.0 111 | id = 0 112 | myloss = L2Loss(size_average=False) 113 | with torch.no_grad(): 114 | for pos, fx, y in self.test_loader: 115 | id += 1 116 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() 117 | if self.args.fun_dim == 0: 118 | fx = None 119 | out = self.model(x, fx) 120 | if self.args.normalize: 121 | out = self.dataset.y_normalizer.decode(out) 122 | tl = myloss(out, y).item() 123 | rel_err += tl 124 | if id < 
self.args.vis_num: 125 | print('visual: ', id) 126 | visual(x, y, out, self.args, id) 127 | 128 | rel_err /= self.args.ntest 129 | print("rel_err:{}".format(rel_err)) 130 | -------------------------------------------------------------------------------- /layers/Embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | import numpy as np 6 | 7 | 8 | def unified_pos_embedding(shapelist, ref, batchsize=1, device='cuda'): 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device 10 | if len(shapelist) == 1: 11 | size_x = shapelist[0] 12 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 13 | grid = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]).to(device) # B N 1 14 | gridx = torch.tensor(np.linspace(0, 1, ref), dtype=torch.float) 15 | grid_ref = gridx.reshape(1, ref, 1).repeat([batchsize, 1, 1]).to(device) # B N 1 16 | pos = torch.sqrt(torch.sum((grid[:, :, None, :] - grid_ref[:, None, :, :]) ** 2, dim=-1)). 
\ 17 | reshape(batchsize, size_x, ref).contiguous() 18 | if len(shapelist) == 2: 19 | size_x, size_y = shapelist[0], shapelist[1] 20 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 21 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 22 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 23 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 24 | grid = torch.cat((gridx, gridy), dim=-1).to(device) # B H W 2 25 | 26 | gridx = torch.tensor(np.linspace(0, 1, ref), dtype=torch.float) 27 | gridx = gridx.reshape(1, ref, 1, 1).repeat([batchsize, 1, ref, 1]) 28 | gridy = torch.tensor(np.linspace(0, 1, ref), dtype=torch.float) 29 | gridy = gridy.reshape(1, 1, ref, 1).repeat([batchsize, ref, 1, 1]) 30 | grid_ref = torch.cat((gridx, gridy), dim=-1).to(device) # B H W 8 8 2 31 | 32 | pos = torch.sqrt(torch.sum((grid[:, :, :, None, None, :] - grid_ref[:, None, None, :, :, :]) ** 2, dim=-1)). \ 33 | reshape(batchsize, size_x * size_y, ref * ref).contiguous() 34 | if len(shapelist) == 3: 35 | size_x, size_y, size_z = shapelist[0], shapelist[1], shapelist[2] 36 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 37 | gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1]) 38 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 39 | gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1]) 40 | gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float) 41 | gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1]) 42 | grid = torch.cat((gridx, gridy, gridz), dim=-1).to(device) # B H W D 3 43 | 44 | gridx = torch.tensor(np.linspace(0, 1, ref), dtype=torch.float) 45 | gridx = gridx.reshape(1, ref, 1, 1, 1).repeat([batchsize, 1, ref, ref, 1]) 46 | gridy = torch.tensor(np.linspace(0, 1, ref), dtype=torch.float) 47 | gridy = gridy.reshape(1, 1, ref, 1, 1).repeat([batchsize, ref, 1, 
ref, 1]) 48 | gridz = torch.tensor(np.linspace(0, 1, ref), dtype=torch.float) 49 | gridz = gridz.reshape(1, 1, 1, ref, 1).repeat([batchsize, ref, ref, 1, 1]) 50 | grid_ref = torch.cat((gridx, gridy, gridz), dim=-1).to(device) # B 4 4 4 3 51 | 52 | pos = torch.sqrt( 53 | torch.sum((grid[:, :, :, :, None, None, None, :] - grid_ref[:, None, None, None, :, :, :, :]) ** 2, 54 | dim=-1)). \ 55 | reshape(batchsize, size_x * size_y * size_z, ref * ref * ref).contiguous() 56 | return pos 57 | 58 | 59 | class RotaryEmbedding(nn.Module): 60 | def __init__(self, dim, min_freq=1 / 2, scale=1.): 61 | super().__init__() 62 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 63 | self.min_freq = min_freq 64 | self.scale = scale 65 | self.register_buffer('inv_freq', inv_freq) 66 | 67 | def forward(self, coordinates, device='cuda'): 68 | # coordinates [b, n] 69 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device 70 | t = coordinates.to(device).type_as(self.inv_freq) 71 | t = t * (self.scale / self.min_freq) 72 | freqs = torch.einsum('... i , j -> ... i j', t, self.inv_freq) # [b, n, d//2] 73 | return torch.cat((freqs, freqs), dim=-1) # [b, n, d] 74 | 75 | 76 | def rotate_half(x): 77 | x = rearrange(x, '... (j d) -> ... j d', j=2) 78 | x1, x2 = x.unbind(dim=-2) 79 | return torch.cat((-x2, x1), dim=-1) 80 | 81 | 82 | def apply_rotary_pos_emb(t, freqs): 83 | return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) 84 | 85 | 86 | def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y): 87 | # split t into first half and second half 88 | # t: [b, h, n, d] 89 | # freq_x/y: [b, n, d] 90 | d = t.shape[-1] 91 | t_x, t_y = t[..., :d // 2], t[..., d // 2:] 92 | 93 | return torch.cat((apply_rotary_pos_emb(t_x, freqs_x), 94 | apply_rotary_pos_emb(t_y, freqs_y)), dim=-1) 95 | 96 | 97 | class PositionalEncoding(nn.Module): 98 | "Implement the PE function." 
99 | 100 | def __init__(self, d_model, dropout, max_len=421 * 421): 101 | super(PositionalEncoding, self).__init__() 102 | self.dropout = nn.Dropout(p=dropout) 103 | 104 | # Compute the positional encodings once in log space. 105 | pe = torch.zeros(max_len, d_model) 106 | position = torch.arange(0, max_len).unsqueeze(1) 107 | div_term = torch.exp( 108 | torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) 109 | ) 110 | pe[:, 0::2] = torch.sin(position * div_term) 111 | pe[:, 1::2] = torch.cos(position * div_term) 112 | pe = pe.unsqueeze(0) 113 | self.register_buffer("pe", pe) 114 | 115 | def forward(self, x): 116 | x = x + self.pe[:, : x.size(1)].requires_grad_(False) 117 | return self.dropout(x) 118 | 119 | 120 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 121 | """ 122 | Create sinusoidal timestep embeddings. 123 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 124 | These may be fractional. 125 | :param dim: the dimension of the output. 126 | :param max_period: controls the minimum frequency of the embeddings. 127 | :return: an [N x dim] Tensor of positional embeddings. 
128 | """ 129 | 130 | half = dim // 2 131 | freqs = torch.exp( 132 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 133 | ).to(device=timesteps.device) 134 | args = timesteps[:, None].float() * freqs[None] 135 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 136 | if dim % 2: 137 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:,:,:1])], dim=-1) 138 | return embedding 139 | -------------------------------------------------------------------------------- /layers/FFNO_Layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | import numpy as np 5 | import math 6 | 7 | ################################################################ 8 | # 1d fourier layer 9 | ################################################################ 10 | class SpectralConv1d(nn.Module): 11 | 12 | ## FFNO degenerate to FNO in 1D space 13 | def __init__(self, in_channels, out_channels, modes1): 14 | super(SpectralConv1d, self).__init__() 15 | 16 | """ 17 | 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
18 | """ 19 | 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 23 | 24 | self.scale = (1 / (in_channels * out_channels)) 25 | self.weights1 = nn.Parameter( 26 | self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat)) 27 | 28 | # Complex multiplication 29 | def compl_mul1d(self, input, weights): 30 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 31 | return torch.einsum("bix,iox->box", input, weights) 32 | 33 | def forward(self, x): 34 | batchsize = x.shape[0] 35 | # Compute Fourier coeffcients up to factor of e^(- something constant) 36 | x_ft = torch.fft.rfft(x) 37 | 38 | # Multiply relevant Fourier modes 39 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1) // 2 + 1, device=x.device, dtype=torch.cfloat) 40 | out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1) 41 | 42 | # Return to physical space 43 | x = torch.fft.irfft(out_ft, n=x.size(-1)) 44 | return x 45 | 46 | ################################################################ 47 | # 2d fourier layer 48 | ################################################################ 49 | class SpectralConv2d(nn.Module): 50 | def __init__(self, in_dim, out_dim, modes_x, modes_y): 51 | super().__init__() 52 | self.in_dim = in_dim 53 | self.out_dim = out_dim 54 | self.modes_x = modes_x 55 | self.modes_y = modes_y 56 | 57 | self.fourier_weight = nn.ParameterList([]) 58 | for n_modes in [modes_x, modes_y]: 59 | weight = torch.FloatTensor(in_dim, out_dim, n_modes, 2) 60 | param = nn.Parameter(weight) 61 | nn.init.xavier_normal_(param) 62 | self.fourier_weight.append(param) 63 | 64 | def forward(self, x): 65 | B, I, M, N = x.shape 66 | 67 | # # # Dimesion Y # # # 68 | x_fty = torch.fft.rfft(x, dim=-1, norm='ortho') 69 | # x_ft.shape == [batch_size, in_dim, grid_size, grid_size // 2 + 1] 70 | 71 | out_ft = 
x_fty.new_zeros(B, I, M, N // 2 + 1) 72 | # out_ft.shape == [batch_size, in_dim, grid_size, grid_size // 2 + 1, 2] 73 | 74 | out_ft[:, :, :, :self.modes_y] = torch.einsum( 75 | "bixy,ioy->boxy", 76 | x_fty[:, :, :, :self.modes_y], 77 | torch.view_as_complex(self.fourier_weight[1])) 78 | 79 | xy = torch.fft.irfft(out_ft, n=N, dim=-1, norm='ortho') 80 | # x.shape == [batch_size, in_dim, grid_size, grid_size] 81 | 82 | # # # Dimesion X # # # 83 | x_ftx = torch.fft.rfft(x, dim=-2, norm='ortho') 84 | # x_ft.shape == [batch_size, in_dim, grid_size // 2 + 1, grid_size] 85 | 86 | out_ft = x_ftx.new_zeros(B, I, M // 2 + 1, N) 87 | # out_ft.shape == [batch_size, in_dim, grid_size // 2 + 1, grid_size, 2] 88 | 89 | out_ft[:, :, :self.modes_x, :] = torch.einsum( 90 | "bixy,iox->boxy", 91 | x_ftx[:, :, :self.modes_x, :], 92 | torch.view_as_complex(self.fourier_weight[0])) 93 | 94 | xx = torch.fft.irfft(out_ft, n=M, dim=-2, norm='ortho') 95 | # x.shape == [batch_size, in_dim, grid_size, grid_size] 96 | 97 | # # Combining Dimensions # # 98 | x = xx + xy 99 | 100 | return x 101 | 102 | 103 | ################################################################ 104 | # 3d fourier layers 105 | ################################################################ 106 | 107 | class SpectralConv3d(nn.Module): 108 | def __init__(self, in_dim, out_dim, modes_x, modes_y, modes_z): 109 | super().__init__() 110 | self.in_dim = in_dim 111 | self.out_dim = out_dim 112 | self.modes_x = modes_x 113 | self.modes_y = modes_y 114 | self.modes_z = modes_z 115 | 116 | self.fourier_weight = nn.ParameterList([]) 117 | for n_modes in [modes_x, modes_y, modes_z]: 118 | weight = torch.FloatTensor(in_dim, out_dim, n_modes, 2) 119 | param = nn.Parameter(weight) 120 | nn.init.xavier_normal_(param) 121 | self.fourier_weight.append(param) 122 | 123 | def forward(self, x): 124 | B, I, S1, S2, S3 = x.shape 125 | 126 | # # # Dimesion Z # # # 127 | x_ftz = torch.fft.rfft(x, dim=-1, norm='ortho') 128 | # x_ft.shape == 
[batch_size, in_dim, grid_size, grid_size // 2 + 1] 129 | 130 | out_ft = x_ftz.new_zeros(B, I, S1, S2, S3 // 2 + 1) 131 | # out_ft.shape == [batch_size, in_dim, grid_size, grid_size // 2 + 1, 2] 132 | 133 | out_ft[:, :, :, :, :self.modes_z] = torch.einsum( 134 | "bixyz,ioz->boxyz", 135 | x_ftz[:, :, :, :, :self.modes_z], 136 | torch.view_as_complex(self.fourier_weight[2])) 137 | 138 | xz = torch.fft.irfft(out_ft, n=S3, dim=-1, norm='ortho') 139 | # x.shape == [batch_size, in_dim, grid_size, grid_size] 140 | 141 | # # # Dimesion Y # # # 142 | x_fty = torch.fft.rfft(x, dim=-2, norm='ortho') 143 | # x_ft.shape == [batch_size, in_dim, grid_size // 2 + 1, grid_size] 144 | 145 | out_ft = x_fty.new_zeros(B, I, S1, S2 // 2 + 1, S3) 146 | # out_ft.shape == [batch_size, in_dim, grid_size // 2 + 1, grid_size, 2] 147 | 148 | out_ft[:, :, :, :self.modes_y, :] = torch.einsum( 149 | "bixyz,ioy->boxyz", 150 | x_fty[:, :, :, :self.modes_y, :], 151 | torch.view_as_complex(self.fourier_weight[1])) 152 | 153 | xy = torch.fft.irfft(out_ft, n=S2, dim=-2, norm='ortho') 154 | # x.shape == [batch_size, in_dim, grid_size, grid_size] 155 | 156 | # # # Dimesion X # # # 157 | x_ftx = torch.fft.rfft(x, dim=-3, norm='ortho') 158 | # x_ft.shape == [batch_size, in_dim, grid_size // 2 + 1, grid_size] 159 | 160 | out_ft = x_ftx.new_zeros(B, I, S1 // 2 + 1, S2, S3) 161 | # out_ft.shape == [batch_size, in_dim, grid_size // 2 + 1, grid_size, 2] 162 | 163 | out_ft[:, :, :self.modes_x, :, :] = torch.einsum( 164 | "bixyz,iox->boxyz", 165 | x_ftx[:, :, :self.modes_x, :, :], 166 | torch.view_as_complex(self.fourier_weight[0])) 167 | 168 | xx = torch.fft.irfft(out_ft, n=S1, dim=-3, norm='ortho') 169 | # x.shape == [batch_size, in_dim, grid_size, grid_size] 170 | 171 | # # Combining Dimensions # # 172 | x = xx + xy + xz 173 | 174 | return x 175 | -------------------------------------------------------------------------------- /layers/FNO_Layers.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | import numpy as np 5 | import math 6 | 7 | 8 | ################################################################ 9 | # 1d fourier layer 10 | ################################################################ 11 | class SpectralConv1d(nn.Module): 12 | def __init__(self, in_channels, out_channels, modes1): 13 | super(SpectralConv1d, self).__init__() 14 | 15 | """ 16 | 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. 17 | """ 18 | 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 22 | 23 | self.scale = (1 / (in_channels * out_channels)) 24 | self.weights1 = nn.Parameter( 25 | self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat)) 26 | 27 | # Complex multiplication 28 | def compl_mul1d(self, input, weights): 29 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 30 | return torch.einsum("bix,iox->box", input, weights) 31 | 32 | def forward(self, x): 33 | batchsize = x.shape[0] 34 | # Compute Fourier coeffcients up to factor of e^(- something constant) 35 | x_ft = torch.fft.rfft(x) 36 | 37 | # Multiply relevant Fourier modes 38 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1) // 2 + 1, device=x.device, dtype=torch.cfloat) 39 | out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1) 40 | 41 | # Return to physical space 42 | x = torch.fft.irfft(out_ft, n=x.size(-1)) 43 | return x 44 | 45 | 46 | ################################################################ 47 | # 2d fourier layer 48 | ################################################################ 49 | class SpectralConv2d(nn.Module): 50 | def __init__(self, in_channels, out_channels, modes1, modes2): 51 | super(SpectralConv2d, 
self).__init__() 52 | """ 53 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 54 | """ 55 | self.in_channels = in_channels 56 | self.out_channels = out_channels 57 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 58 | self.modes2 = modes2 59 | 60 | self.scale = (1 / (in_channels * out_channels)) 61 | self.weights1 = nn.Parameter( 62 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 63 | self.weights2 = nn.Parameter( 64 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 65 | 66 | # Complex multiplication 67 | def compl_mul2d(self, input, weights): 68 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 69 | return torch.einsum("bixy,ioxy->boxy", input, weights) 70 | 71 | def forward(self, x): 72 | batchsize = x.shape[0] 73 | # Compute Fourier coeffcients up to factor of e^(- something constant) 74 | x_ft = torch.fft.rfft2(x) 75 | 76 | # Multiply relevant Fourier modes 77 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, 78 | device=x.device) 79 | out_ft[:, :, :self.modes1, :self.modes2] = \ 80 | self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) 81 | out_ft[:, :, -self.modes1:, :self.modes2] = \ 82 | self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 83 | 84 | # Return to physical space 85 | x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) 86 | return x 87 | 88 | 89 | ################################################################ 90 | # 3d fourier layers 91 | ################################################################ 92 | 93 | class SpectralConv3d(nn.Module): 94 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3): 95 | super(SpectralConv3d, self).__init__() 96 | 97 | """ 98 | 3D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
99 | """ 100 | 101 | self.in_channels = in_channels 102 | self.out_channels = out_channels 103 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 104 | self.modes2 = modes2 105 | self.modes3 = modes3 106 | 107 | self.scale = (1 / (in_channels * out_channels)) 108 | self.weights1 = nn.Parameter( 109 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 110 | dtype=torch.cfloat)) 111 | self.weights2 = nn.Parameter( 112 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 113 | dtype=torch.cfloat)) 114 | self.weights3 = nn.Parameter( 115 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 116 | dtype=torch.cfloat)) 117 | self.weights4 = nn.Parameter( 118 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 119 | dtype=torch.cfloat)) 120 | 121 | # Complex multiplication 122 | def compl_mul3d(self, input, weights): 123 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 124 | return torch.einsum("bixyz,ioxyz->boxyz", input, weights) 125 | 126 | def forward(self, x): 127 | batchsize = x.shape[0] 128 | # Compute Fourier coeffcients up to factor of e^(- something constant) 129 | x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1]) 130 | 131 | # Multiply relevant Fourier modes 132 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1) // 2 + 1, 133 | dtype=torch.cfloat, device=x.device) 134 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \ 135 | self.compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1) 136 | out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \ 137 | self.compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2) 138 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \ 139 | self.compl_mul3d(x_ft[:, :, :self.modes1, 
-self.modes2:, :self.modes3], self.weights3) 140 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \ 141 | self.compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4) 142 | 143 | # Return to physical space 144 | x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1))) 145 | return x 146 | -------------------------------------------------------------------------------- /layers/GeoFNO_Projection.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | import numpy as np 5 | import math 6 | 7 | 8 | ################################################################ 9 | # geo projection 10 | ################################################################ 11 | class SpectralConv2d_IrregularGeo(nn.Module): 12 | def __init__(self, in_channels, out_channels, modes1, modes2, s1=32, s2=32): 13 | super(SpectralConv2d_IrregularGeo, self).__init__() 14 | 15 | """ 16 | from geoFNO 17 | """ 18 | 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 22 | self.modes2 = modes2 23 | self.s1 = s1 24 | self.s2 = s2 25 | 26 | self.scale = (1 / (in_channels * out_channels)) 27 | self.weights1 = nn.Parameter( 28 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 29 | self.weights2 = nn.Parameter( 30 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 31 | 32 | # Complex multiplication 33 | def compl_mul2d(self, input, weights): 34 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 35 | return torch.einsum("bixy,ioxy->boxy", input, weights) 36 | 37 | def forward(self, u, x_in=None, x_out=None, iphi=None, code=None): 38 | batchsize = u.shape[0] 39 | 40 | # Compute Fourier coeffcients up to 
factor of e^(- something constant) 41 | if x_in == None: 42 | u_ft = torch.fft.rfft2(u) 43 | s1 = u.size(-2) 44 | s2 = u.size(-1) 45 | else: 46 | u_ft = self.fft2d(u, x_in, iphi, code) 47 | s1 = self.s1 48 | s2 = self.s2 49 | 50 | # Multiply relevant Fourier modes 51 | # print(u.shape, u_ft.shape) 52 | factor1 = self.compl_mul2d(u_ft[:, :, :self.modes1, :self.modes2], self.weights1) 53 | factor2 = self.compl_mul2d(u_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 54 | 55 | # Return to physical space 56 | if x_out == None: 57 | out_ft = torch.zeros(batchsize, self.out_channels, s1, s2 // 2 + 1, dtype=torch.cfloat, device=u.device) 58 | out_ft[:, :, :self.modes1, :self.modes2] = factor1 59 | out_ft[:, :, -self.modes1:, :self.modes2] = factor2 60 | u = torch.fft.irfft2(out_ft, s=(s1, s2)) 61 | else: 62 | out_ft = torch.cat([factor1, factor2], dim=-2) 63 | u = self.ifft2d(out_ft, x_out, iphi, code) 64 | 65 | return u 66 | 67 | def fft2d(self, u, x_in, iphi=None, code=None): 68 | # u (batch, channels, n) 69 | # x_in (batch, n, 2) locations in [0,1]*[0,1] 70 | # iphi: function: x_in -> x_c 71 | 72 | batchsize = x_in.shape[0] 73 | N = x_in.shape[1] 74 | device = x_in.device 75 | m1 = 2 * self.modes1 76 | m2 = 2 * self.modes2 - 1 77 | 78 | # wavenumber (m1, m2) 79 | k_x1 = torch.cat((torch.arange(start=0, end=self.modes1, step=1), \ 80 | torch.arange(start=-(self.modes1), end=0, step=1)), 0).reshape(m1, 1).repeat(1, m2).to(device) 81 | k_x2 = torch.cat((torch.arange(start=0, end=self.modes2, step=1), \ 82 | torch.arange(start=-(self.modes2 - 1), end=0, step=1)), 0).reshape(1, m2).repeat(m1, 1).to( 83 | device) 84 | 85 | if iphi == None: 86 | x = x_in 87 | else: 88 | x = iphi(x_in, code) 89 | 90 | # K = , (batch, N, m1, m2) 91 | K1 = torch.outer(x[..., 0].view(-1), k_x1.view(-1)).reshape(batchsize, N, m1, m2) 92 | K2 = torch.outer(x[..., 1].view(-1), k_x2.view(-1)).reshape(batchsize, N, m1, m2) 93 | K = K1 + K2 94 | 95 | # basis (batch, N, m1, m2) 96 | basis = 
torch.exp(-1j * 2 * np.pi * K).to(device) 97 | 98 | # Y (batch, channels, N) 99 | u = u + 0j 100 | Y = torch.einsum("bcn,bnxy->bcxy", u, basis) 101 | return Y 102 | 103 | def ifft2d(self, u_ft, x_out, iphi=None, code=None): 104 | # u_ft (batch, channels, kmax, kmax) 105 | # x_out (batch, N, 2) locations in [0,1]*[0,1] 106 | # iphi: function: x_out -> x_c 107 | 108 | batchsize = x_out.shape[0] 109 | N = x_out.shape[1] 110 | device = x_out.device 111 | m1 = 2 * self.modes1 112 | m2 = 2 * self.modes2 - 1 113 | 114 | # wavenumber (m1, m2) 115 | k_x1 = torch.cat((torch.arange(start=0, end=self.modes1, step=1), \ 116 | torch.arange(start=-(self.modes1), end=0, step=1)), 0).reshape(m1, 1).repeat(1, m2).to(device) 117 | k_x2 = torch.cat((torch.arange(start=0, end=self.modes2, step=1), \ 118 | torch.arange(start=-(self.modes2 - 1), end=0, step=1)), 0).reshape(1, m2).repeat(m1, 1).to( 119 | device) 120 | 121 | if iphi == None: 122 | x = x_out 123 | else: 124 | x = iphi(x_out, code) 125 | 126 | # K = , (batch, N, m1, m2) 127 | K1 = torch.outer(x[:, :, 0].view(-1), k_x1.view(-1)).reshape(batchsize, N, m1, m2) 128 | K2 = torch.outer(x[:, :, 1].view(-1), k_x2.view(-1)).reshape(batchsize, N, m1, m2) 129 | K = K1 + K2 130 | 131 | # basis (batch, N, m1, m2) 132 | basis = torch.exp(1j * 2 * np.pi * K).to(device) 133 | 134 | # coeff (batch, channels, m1, m2) 135 | u_ft2 = u_ft[..., 1:].flip(-1, -2).conj() 136 | u_ft = torch.cat([u_ft, u_ft2], dim=-1) 137 | 138 | # Y (batch, channels, N) 139 | Y = torch.einsum("bcxy,bnxy->bcn", u_ft, basis) 140 | Y = Y.real 141 | return Y 142 | 143 | 144 | class IPHI(nn.Module): 145 | def __init__(self, width=32): 146 | super(IPHI, self).__init__() 147 | 148 | """ 149 | inverse phi: x -> xi 150 | """ 151 | self.width = width 152 | self.fc0 = nn.Linear(4, self.width) 153 | self.fc_code = nn.Linear(42, self.width) 154 | self.fc_no_code = nn.Linear(3 * self.width, 4 * self.width) 155 | self.fc1 = nn.Linear(4 * self.width, 4 * self.width) 156 | self.fc2 = 
nn.Linear(4 * self.width, 4 * self.width) 157 | self.fc3 = nn.Linear(4 * self.width, 4 * self.width) 158 | self.fc4 = nn.Linear(4 * self.width, 2) 159 | self.activation = torch.tanh 160 | self.center = torch.tensor([0.0001, 0.0001], device="cuda").reshape(1, 1, 2) 161 | 162 | self.B = np.pi * torch.pow(2, torch.arange(0, self.width // 4, dtype=torch.float, device="cuda")).reshape(1, 1, 163 | 1, 164 | self.width // 4) 165 | 166 | def forward(self, x, code=None): 167 | # x (batch, N_grid, 2) 168 | # code (batch, N_features) 169 | 170 | # some feature engineering 171 | angle = torch.atan2(x[:, :, 1] - self.center[:, :, 1], x[:, :, 0] - self.center[:, :, 0]) 172 | radius = torch.norm(x - self.center, dim=-1, p=2) 173 | xd = torch.stack([x[:, :, 0], x[:, :, 1], angle, radius], dim=-1) 174 | 175 | # sin features from NeRF 176 | b, n, d = xd.shape[0], xd.shape[1], xd.shape[2] 177 | x_sin = torch.sin(self.B * xd.view(b, n, d, 1)).view(b, n, d * self.width // 4) 178 | x_cos = torch.cos(self.B * xd.view(b, n, d, 1)).view(b, n, d * self.width // 4) 179 | xd = self.fc0(xd) 180 | xd = torch.cat([xd, x_sin, x_cos], dim=-1).reshape(b, n, 3 * self.width) 181 | 182 | if code != None: 183 | cd = self.fc_code(code) 184 | cd = cd.unsqueeze(1).repeat(1, xd.shape[1], 1) 185 | xd = torch.cat([cd, xd], dim=-1) 186 | else: 187 | xd = self.fc_no_code(xd) 188 | 189 | xd = self.fc1(xd) 190 | xd = self.activation(xd) 191 | xd = self.fc2(xd) 192 | xd = self.activation(xd) 193 | xd = self.fc3(xd) 194 | xd = self.activation(xd) 195 | xd = self.fc4(xd) 196 | return x + x * xd 197 | -------------------------------------------------------------------------------- /models/FNO.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from timm.models.layers import trunc_normal_ 7 | from layers.Basic import MLP 8 | from layers.Embedding import 
timestep_embedding, unified_pos_embedding 9 | from layers.FNO_Layers import SpectralConv1d, SpectralConv2d, SpectralConv3d 10 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI 11 | 12 | BlockList = [None, SpectralConv1d, SpectralConv2d, SpectralConv3d] 13 | ConvList = [None, nn.Conv1d, nn.Conv2d, nn.Conv3d] 14 | 15 | 16 | class Model(nn.Module): 17 | def __init__(self, args, s1=96, s2=96): 18 | super(Model, self).__init__() 19 | self.__name__ = 'FNO' 20 | self.args = args 21 | ## embedding 22 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 23 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 24 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 25 | args.n_hidden, n_layers=0, res=False, act=args.act) 26 | else: 27 | self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 28 | n_layers=0, res=False, act=args.act) 29 | if args.time_input: 30 | self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 31 | nn.Linear(args.n_hidden, args.n_hidden)) 32 | # geometry projection 33 | if self.args.geotype == 'unstructured': 34 | self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1, 35 | s2) 36 | self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1, 37 | s2) 38 | self.iphi = IPHI() 39 | self.padding = [(16 - size % 16) % 16 for size in [s1, s2]] 40 | else: 41 | self.padding = [(16 - size % 16) % 16 for size in args.shapelist] 42 | self.conv0 = BlockList[len(self.padding)](args.n_hidden, args.n_hidden, 43 | *[args.modes for _ in range(len(self.padding))]) 44 | self.conv1 = BlockList[len(self.padding)](args.n_hidden, args.n_hidden, 45 | *[args.modes for _ in range(len(self.padding))]) 46 | self.conv2 = BlockList[len(self.padding)](args.n_hidden, args.n_hidden, 47 | *[args.modes for _ in 
range(len(self.padding))]) 48 | self.conv3 = BlockList[len(self.padding)](args.n_hidden, args.n_hidden, 49 | *[args.modes for _ in range(len(self.padding))]) 50 | self.w0 = ConvList[len(self.padding)](args.n_hidden, args.n_hidden, 1) 51 | self.w1 = ConvList[len(self.padding)](args.n_hidden, args.n_hidden, 1) 52 | self.w2 = ConvList[len(self.padding)](args.n_hidden, args.n_hidden, 1) 53 | self.w3 = ConvList[len(self.padding)](args.n_hidden, args.n_hidden, 1) 54 | # projectors 55 | self.fc1 = nn.Linear(args.n_hidden, args.n_hidden) 56 | self.fc2 = nn.Linear(args.n_hidden, args.out_dim) 57 | 58 | def structured_geo(self, x, fx, T=None): 59 | B, N, _ = x.shape 60 | if self.args.unified_pos: 61 | x = self.pos.repeat(x.shape[0], 1, 1) 62 | if fx is not None: 63 | fx = torch.cat((x, fx), -1) 64 | fx = self.preprocess(fx) 65 | else: 66 | fx = self.preprocess(x) 67 | 68 | if T is not None: 69 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 70 | Time_emb = self.time_fc(Time_emb) 71 | fx = fx + Time_emb 72 | x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *self.args.shapelist) 73 | if not all(item == 0 for item in self.padding): 74 | if len(self.args.shapelist) == 2: 75 | x = F.pad(x, [0, self.padding[1], 0, self.padding[0]]) 76 | elif len(self.args.shapelist) == 3: 77 | x = F.pad(x, [0, self.padding[2], 0, self.padding[1], 0, self.padding[0]]) 78 | 79 | x1 = self.conv0(x) 80 | x2 = self.w0(x) 81 | x = x1 + x2 82 | x = F.gelu(x) 83 | 84 | x1 = self.conv1(x) 85 | x2 = self.w1(x) 86 | x = x1 + x2 87 | x = F.gelu(x) 88 | 89 | x1 = self.conv2(x) 90 | x2 = self.w2(x) 91 | x = x1 + x2 92 | x = F.gelu(x) 93 | 94 | x1 = self.conv3(x) 95 | x2 = self.w3(x) 96 | x = x1 + x2 97 | 98 | if not all(item == 0 for item in self.padding): 99 | if len(self.args.shapelist) == 2: 100 | x = x[..., :-self.padding[0], :-self.padding[1]] 101 | elif len(self.args.shapelist) == 3: 102 | x = x[..., :-self.padding[0], :-self.padding[1], :-self.padding[2]] 103 | x = 
x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1) 104 | x = self.fc1(x) 105 | x = F.gelu(x) 106 | x = self.fc2(x) 107 | return x 108 | 109 | def unstructured_geo(self, x, fx, T=None): 110 | original_pos = x 111 | if fx is not None: 112 | fx = torch.cat((x, fx), -1) 113 | fx = self.preprocess(fx) 114 | else: 115 | fx = self.preprocess(x) 116 | 117 | if T is not None: 118 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 119 | Time_emb = self.time_fc(Time_emb) 120 | fx = fx + Time_emb 121 | 122 | x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None) 123 | 124 | x1 = self.conv0(x) 125 | x2 = self.w0(x) 126 | x = x1 + x2 127 | x = F.gelu(x) 128 | 129 | x1 = self.conv1(x) 130 | x2 = self.w1(x) 131 | x = x1 + x2 132 | x = F.gelu(x) 133 | 134 | x1 = self.conv2(x) 135 | x2 = self.w2(x) 136 | x = x1 + x2 137 | x = F.gelu(x) 138 | 139 | x1 = self.conv3(x) 140 | x2 = self.w3(x) 141 | x = x1 + x2 142 | 143 | x = self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1) 144 | x = self.fc1(x) 145 | x = F.gelu(x) 146 | x = self.fc2(x) 147 | return x 148 | 149 | def forward(self, x, fx, T=None, geo=None): 150 | if self.args.geotype == 'unstructured': 151 | return self.unstructured_geo(x, fx, T) 152 | else: 153 | return self.structured_geo(x, fx, T) 154 | -------------------------------------------------------------------------------- /models/F_FNO.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from timm.models.layers import trunc_normal_ 7 | from layers.Basic import MLP 8 | from layers.Embedding import timestep_embedding, unified_pos_embedding 9 | from layers.FFNO_Layers import SpectralConv1d, SpectralConv2d, SpectralConv3d 10 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI 11 | 12 | BlockList = [None, 
SpectralConv1d, SpectralConv2d, SpectralConv3d] 13 | ConvList = [None, nn.Conv1d, nn.Conv2d, nn.Conv3d] 14 | 15 | 16 | class Model(nn.Module): 17 | def __init__(self, args, s1=96, s2=96): 18 | super(Model, self).__init__() 19 | self.__name__ = 'F-FNO' 20 | self.args = args 21 | ## embedding 22 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 23 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 24 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 25 | args.n_hidden, n_layers=0, res=False, act=args.act) 26 | else: 27 | self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 28 | n_layers=0, res=False, act=args.act) 29 | if args.time_input: 30 | self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 31 | nn.Linear(args.n_hidden, args.n_hidden)) 32 | # geometry projection 33 | if self.args.geotype == 'unstructured': 34 | self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1, 35 | s2) 36 | self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1, 37 | s2) 38 | self.iphi = IPHI() 39 | self.padding = [(16 - size % 16) % 16 for size in [s1, s2]] 40 | else: 41 | self.padding = [(16 - size % 16) % 16 for size in args.shapelist] 42 | 43 | self.spectral_layers = nn.ModuleList([]) 44 | for _ in range(args.n_layers): 45 | self.spectral_layers.append(BlockList[len(self.padding)](args.n_hidden, args.n_hidden, 46 | *[args.modes for _ in range(len(self.padding))])) 47 | # projectors 48 | self.fc1 = nn.Linear(args.n_hidden, args.n_hidden) 49 | self.fc2 = nn.Linear(args.n_hidden, args.out_dim) 50 | 51 | def structured_geo(self, x, fx, T=None): 52 | B, N, _ = x.shape 53 | if self.args.unified_pos: 54 | x = self.pos.repeat(x.shape[0], 1, 1) 55 | if fx is not None: 56 | fx = torch.cat((x, fx), -1) 57 | fx = self.preprocess(fx) 58 | else: 59 | fx 
= self.preprocess(x) 60 | 61 | if T is not None: 62 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 63 | Time_emb = self.time_fc(Time_emb) 64 | fx = fx + Time_emb 65 | x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *self.args.shapelist) 66 | if not all(item == 0 for item in self.padding): 67 | if len(self.args.shapelist) == 2: 68 | x = F.pad(x, [0, self.padding[1], 0, self.padding[0]]) 69 | elif len(self.args.shapelist) == 3: 70 | x = F.pad(x, [0, self.padding[2], 0, self.padding[1], 0, self.padding[0]]) 71 | 72 | for i in range(self.args.n_layers): 73 | x = x + self.spectral_layers[i](x) 74 | 75 | if not all(item == 0 for item in self.padding): 76 | if len(self.args.shapelist) == 2: 77 | x = x[..., :-self.padding[0], :-self.padding[1]] 78 | elif len(self.args.shapelist) == 3: 79 | x = x[..., :-self.padding[0], :-self.padding[1], :-self.padding[2]] 80 | x = x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1) 81 | x = self.fc1(x) 82 | x = F.gelu(x) 83 | x = self.fc2(x) 84 | return x 85 | 86 | def unstructured_geo(self, x, fx, T=None): 87 | original_pos = x 88 | if fx is not None: 89 | fx = torch.cat((x, fx), -1) 90 | fx = self.preprocess(fx) 91 | else: 92 | fx = self.preprocess(x) 93 | 94 | if T is not None: 95 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 96 | Time_emb = self.time_fc(Time_emb) 97 | fx = fx + Time_emb 98 | 99 | x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None) 100 | for i in range(self.args.n_layers): 101 | x = x + self.spectral_layers[i](x) 102 | x = self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1) 103 | x = self.fc1(x) 104 | x = F.gelu(x) 105 | x = self.fc2(x) 106 | return x 107 | 108 | def forward(self, x, fx, T=None, geo=None): 109 | if self.args.geotype == 'unstructured': 110 | return self.unstructured_geo(x, fx, T) 111 | else: 112 | return self.structured_geo(x, fx, T) 113 | 
-------------------------------------------------------------------------------- /models/GNOT.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from timm.models.layers import trunc_normal_ 6 | from einops import repeat, rearrange 7 | from torch.nn import functional as F 8 | from layers.Basic import MLP, LinearAttention, ACTIVATION 9 | from layers.Embedding import timestep_embedding, unified_pos_embedding 10 | 11 | class GNOT_block(nn.Module): 12 | """Transformer encoder block in MOE style.""" 13 | 14 | def __init__(self, num_heads: int, 15 | hidden_dim: int, 16 | dropout: float, 17 | act='gelu', 18 | mlp_ratio=4, 19 | space_dim=2, 20 | n_experts=3): 21 | super(GNOT_block, self).__init__() 22 | self.ln1 = nn.LayerNorm(hidden_dim) 23 | self.ln2 = nn.LayerNorm(hidden_dim) 24 | self.ln3 = nn.LayerNorm(hidden_dim) 25 | self.ln4 = nn.LayerNorm(hidden_dim) 26 | self.ln5 = nn.LayerNorm(hidden_dim) 27 | 28 | self.selfattn = LinearAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout) 29 | self.crossattn = LinearAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout) 30 | self.resid_drop1 = nn.Dropout(dropout) 31 | self.resid_drop2 = nn.Dropout(dropout) 32 | 33 | ## MLP in MOE 34 | self.n_experts = n_experts 35 | if act in ACTIVATION.keys(): 36 | self.act = ACTIVATION[act] 37 | self.moe_mlp1 = nn.ModuleList([nn.Sequential( 38 | nn.Linear(hidden_dim, hidden_dim * mlp_ratio), 39 | self.act(), 40 | nn.Linear(hidden_dim * mlp_ratio, hidden_dim), 41 | ) for _ in range(self.n_experts)]) 42 | 43 | self.moe_mlp2 = nn.ModuleList([nn.Sequential( 44 | nn.Linear(hidden_dim, hidden_dim * mlp_ratio), 45 | self.act(), 46 | nn.Linear(hidden_dim * mlp_ratio, hidden_dim), 47 | ) for _ in range(self.n_experts)]) 48 | 49 | self.gatenet = nn.Sequential( 50 | nn.Linear(space_dim, hidden_dim * mlp_ratio), 51 | 
self.act(), 52 | nn.Linear(hidden_dim * mlp_ratio, hidden_dim * mlp_ratio), 53 | self.act(), 54 | nn.Linear(hidden_dim * mlp_ratio, self.n_experts) 55 | ) 56 | 57 | def forward(self, x, y, pos): 58 | ## point-wise gate for moe 59 | gate_score = F.softmax(self.gatenet(pos), dim=-1).unsqueeze(2) 60 | ## cross attention between geo and physics observation 61 | x = x + self.resid_drop1(self.crossattn(self.ln1(x), self.ln2(y))) 62 | ## moe mlp 63 | x_moe1 = torch.stack([self.moe_mlp1[i](x) for i in range(self.n_experts)], dim=-1) 64 | x_moe1 = (gate_score * x_moe1).sum(dim=-1, keepdim=False) 65 | x = x + self.ln3(x_moe1) 66 | ## self attention among geo 67 | x = x + self.resid_drop2(self.selfattn(self.ln4(x))) 68 | ## moe mlp 69 | x_moe2 = torch.stack([self.moe_mlp2[i](x) for i in range(self.n_experts)], dim=-1) 70 | x_moe2 = (gate_score * x_moe2).sum(dim=-1, keepdim=False) 71 | x = x + self.ln5(x_moe2) 72 | return x 73 | 74 | 75 | class Model(nn.Module): 76 | ## GNOT: Transformer in MOE style 77 | def __init__(self, args, n_experts=3): 78 | super(Model, self).__init__() 79 | self.__name__ = 'GNOT' 80 | self.args = args 81 | ## embedding 82 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 83 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 84 | self.preprocess_x = MLP(args.ref ** len(args.shapelist), args.n_hidden * 2, 85 | args.n_hidden, n_layers=0, res=False, act=args.act) 86 | self.preprocess_z = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 87 | args.n_hidden, n_layers=0, res=False, act=args.act) 88 | else: 89 | self.preprocess_x = MLP(args.space_dim, args.n_hidden * 2, args.n_hidden, 90 | n_layers=0, res=False, act=args.act) 91 | self.preprocess_z = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 92 | n_layers=0, res=False, act=args.act) 93 | if args.time_input: 94 | self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 95 | 
nn.Linear(args.n_hidden, args.n_hidden)) 96 | 97 | ## models 98 | self.blocks = nn.ModuleList([GNOT_block(num_heads=args.n_heads, 99 | hidden_dim=args.n_hidden, 100 | dropout=args.dropout, 101 | act=args.act, 102 | mlp_ratio=args.mlp_ratio, 103 | space_dim=args.space_dim, 104 | n_experts=n_experts) 105 | for _ in range(args.n_layers)]) 106 | self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float)) 107 | # projectors 108 | self.fc1 = nn.Linear(args.n_hidden, args.n_hidden * 2) 109 | self.fc2 = nn.Linear(args.n_hidden * 2, args.out_dim) 110 | self.initialize_weights() 111 | 112 | def initialize_weights(self): 113 | self.apply(self._init_weights) 114 | 115 | def _init_weights(self, m): 116 | if isinstance(m, nn.Linear): 117 | trunc_normal_(m.weight, std=0.02) 118 | if isinstance(m, nn.Linear) and m.bias is not None: 119 | nn.init.constant_(m.bias, 0) 120 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 121 | nn.init.constant_(m.bias, 0) 122 | nn.init.constant_(m.weight, 1.0) 123 | 124 | def forward(self, x, fx, T=None, geo=None): 125 | pos = x 126 | if self.args.unified_pos: 127 | x = self.pos.repeat(x.shape[0], 1, 1) 128 | if fx is not None: 129 | fx = torch.cat((x, fx), -1) 130 | fx = self.preprocess_z(fx) 131 | else: 132 | fx = self.preprocess_z(x) 133 | fx = fx + self.placeholder[None, None, :] 134 | x = self.preprocess_x(x) 135 | if T is not None: 136 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 137 | Time_emb = self.time_fc(Time_emb) 138 | fx = fx + Time_emb 139 | 140 | for block in self.blocks: 141 | fx = block(x, fx, pos) 142 | fx = self.fc1(fx) 143 | fx = F.gelu(fx) 144 | fx = self.fc2(fx) 145 | return fx 146 | -------------------------------------------------------------------------------- /models/Galerkin_Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional 
as F 4 | import numpy as np 5 | from timm.models.layers import trunc_normal_ 6 | from layers.Basic import MLP, LinearAttention 7 | from layers.Embedding import timestep_embedding, unified_pos_embedding 8 | from einops import rearrange, repeat 9 | from einops.layers.torch import Rearrange 10 | 11 | 12 | class Galerkin_Transformer_block(nn.Module): 13 | """Transformer encoder block.""" 14 | 15 | def __init__( 16 | self, 17 | num_heads: int, 18 | hidden_dim: int, 19 | dropout: float, 20 | act='gelu', 21 | mlp_ratio=4, 22 | last_layer=False, 23 | out_dim=1, 24 | ): 25 | super().__init__() 26 | self.last_layer = last_layer 27 | self.ln_1 = nn.LayerNorm(hidden_dim) 28 | self.ln_1a = nn.LayerNorm(hidden_dim) 29 | self.Attn = LinearAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, 30 | dropout=dropout, attn_type='galerkin') 31 | self.ln_2 = nn.LayerNorm(hidden_dim) 32 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) 33 | if self.last_layer: 34 | self.ln_3 = nn.LayerNorm(hidden_dim) 35 | self.mlp2 = nn.Linear(hidden_dim, out_dim) 36 | 37 | def forward(self, fx): 38 | fx = self.Attn(self.ln_1(fx), self.ln_1a(fx)) + fx 39 | fx = self.mlp(self.ln_2(fx)) + fx 40 | if self.last_layer: 41 | return self.mlp2(self.ln_3(fx)) 42 | else: 43 | return fx 44 | 45 | 46 | class Model(nn.Module): 47 | ## Factformer 48 | def __init__(self, args): 49 | super(Model, self).__init__() 50 | self.__name__ = 'Factformer' 51 | self.args = args 52 | ## embedding 53 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 54 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 55 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 56 | args.n_hidden, n_layers=0, res=False, act=args.act) 57 | else: 58 | self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 59 | n_layers=0, res=False, act=args.act) 60 | if args.time_input: 61 | 
self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 62 | nn.Linear(args.n_hidden, args.n_hidden)) 63 | 64 | ## models 65 | self.blocks = nn.ModuleList([Galerkin_Transformer_block(num_heads=args.n_heads, hidden_dim=args.n_hidden, 66 | dropout=args.dropout, 67 | act=args.act, 68 | mlp_ratio=args.mlp_ratio, 69 | out_dim=args.out_dim, 70 | last_layer=(_ == args.n_layers - 1)) 71 | for _ in range(args.n_layers)]) 72 | self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float)) 73 | self.initialize_weights() 74 | 75 | def initialize_weights(self): 76 | self.apply(self._init_weights) 77 | 78 | def _init_weights(self, m): 79 | if isinstance(m, nn.Linear): 80 | trunc_normal_(m.weight, std=0.02) 81 | if isinstance(m, nn.Linear) and m.bias is not None: 82 | nn.init.constant_(m.bias, 0) 83 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 84 | nn.init.constant_(m.bias, 0) 85 | nn.init.constant_(m.weight, 1.0) 86 | 87 | def forward(self, x, fx, T=None, geo=None): 88 | if self.args.unified_pos: 89 | x = self.pos.repeat(x.shape[0], 1, 1) 90 | if fx is not None: 91 | fx = torch.cat((x, fx), -1) 92 | fx = self.preprocess(fx) 93 | else: 94 | fx = self.preprocess(x) 95 | fx = fx + self.placeholder[None, None, :] 96 | 97 | if T is not None: 98 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 99 | Time_emb = self.time_fc(Time_emb) 100 | fx = fx + Time_emb 101 | 102 | for block in self.blocks: 103 | fx = block(fx) 104 | return fx 105 | -------------------------------------------------------------------------------- /models/GraphSAGE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric.nn as nng 4 | from layers.Basic import MLP 5 | 6 | 7 | class Model(nn.Module): 8 | def __init__(self, args): 9 | super(Model, self).__init__() 10 | self.__name__ = 'GraphSAGE' 11 | 12 | 
self.nb_hidden_layers = args.n_layers 13 | self.size_hidden_layers = args.n_hidden 14 | self.bn_bool = True 15 | self.activation = nn.ReLU() 16 | 17 | self.encoder = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, n_layers=0, res=False, 18 | act=args.act) 19 | self.decoder = MLP(args.n_hidden, args.n_hidden * 2, args.out_dim, n_layers=0, res=False, act=args.act) 20 | 21 | self.in_layer = nng.SAGEConv( 22 | in_channels=args.n_hidden, 23 | out_channels=self.size_hidden_layers 24 | ) 25 | 26 | self.hidden_layers = nn.ModuleList() 27 | for n in range(self.nb_hidden_layers - 1): 28 | self.hidden_layers.append(nng.SAGEConv( 29 | in_channels=self.size_hidden_layers, 30 | out_channels=self.size_hidden_layers 31 | )) 32 | 33 | self.out_layer = nng.SAGEConv( 34 | in_channels=self.size_hidden_layers, 35 | out_channels=self.size_hidden_layers 36 | ) 37 | 38 | if self.bn_bool: 39 | self.bn = nn.ModuleList() 40 | for n in range(self.nb_hidden_layers): 41 | self.bn.append(nn.BatchNorm1d(self.size_hidden_layers, track_running_stats=False)) 42 | 43 | def forward(self, x, fx, T=None, geo=None): 44 | if geo is None: 45 | raise ValueError('Please provide edge index for Graph Neural Networks') 46 | z, edge_index = torch.cat((x, fx), dim=-1).squeeze(0), geo 47 | z = self.encoder(z) 48 | z = self.in_layer(z, edge_index) 49 | if self.bn_bool: 50 | z = self.bn[0](z) 51 | z = self.activation(z) 52 | 53 | for n in range(self.nb_hidden_layers - 1): 54 | z = self.hidden_layers[n](z, edge_index) 55 | if self.bn_bool: 56 | z = self.bn[n + 1](z) 57 | z = self.activation(z) 58 | z = self.out_layer(z, edge_index) 59 | z = self.decoder(z) 60 | return z.unsqueeze(0) 61 | -------------------------------------------------------------------------------- /models/Graph_UNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric.nn as nng 4 | import random 5 | from layers.Basic import 
MLP 6 | 7 | 8 | def DownSample(id, x, edge_index, pos_x, pool, pool_ratio, r, max_neighbors): 9 | y = x.clone() 10 | n = int(x.size(0)) 11 | 12 | if pool is not None: 13 | y, _, _, _, id_sampled, _ = pool(y, edge_index) 14 | else: 15 | k = int((pool_ratio * torch.tensor(n, dtype=torch.float)).ceil()) 16 | id_sampled = random.sample(range(n), k) 17 | id_sampled = torch.tensor(id_sampled, dtype=torch.long) 18 | y = y[id_sampled] 19 | 20 | pos_x = pos_x[id_sampled] 21 | id.append(id_sampled) 22 | 23 | edge_index_sampled = nng.radius_graph(x=pos_x.detach(), r=r, loop=True, max_num_neighbors=max_neighbors) 24 | 25 | return y, edge_index_sampled 26 | 27 | 28 | def UpSample(x, pos_x_up, pos_x_down): 29 | cluster = nng.nearest(pos_x_up, pos_x_down) 30 | x_up = x[cluster] 31 | 32 | return x_up 33 | 34 | 35 | class Model(nn.Module): 36 | def __init__(self, args, pool='random', scale=5, list_r=[0.05, 0.2, 0.5, 1, 10], 37 | pool_ratio=[0.5, 0.5, 0.5, 0.5, 0.5], max_neighbors=64, layer='SAGE', head=2): 38 | super(Model, self).__init__() 39 | self.__name__ = 'GUNet' 40 | 41 | self.L = scale 42 | self.layer = layer 43 | self.pool_type = pool 44 | self.pool_ratio = pool_ratio 45 | self.list_r = list_r 46 | self.size_hidden_layers = args.n_hidden 47 | self.size_hidden_layers_init = args.n_hidden 48 | self.max_neighbors = max_neighbors 49 | self.dim_enc = args.n_hidden 50 | self.bn_bool = True 51 | self.res = False 52 | self.head = head 53 | self.activation = nn.ReLU() 54 | 55 | self.encoder = MLP(args.fun_dim, args.n_hidden * 2, args.n_hidden, n_layers=0, res=False, 56 | act=args.act) 57 | self.decoder = MLP(args.n_hidden, args.n_hidden * 2, args.out_dim, n_layers=0, res=False, act=args.act) 58 | 59 | self.down_layers = nn.ModuleList() 60 | 61 | if self.pool_type != 'random': 62 | self.pool = nn.ModuleList() 63 | else: 64 | self.pool = None 65 | 66 | if self.layer == 'SAGE': 67 | self.down_layers.append(nng.SAGEConv( 68 | in_channels=self.dim_enc, 69 | 
out_channels=self.size_hidden_layers 70 | )) 71 | bn_in = self.size_hidden_layers 72 | 73 | elif self.layer == 'GAT': 74 | self.down_layers.append(nng.GATConv( 75 | in_channels=self.dim_enc, 76 | out_channels=self.size_hidden_layers, 77 | heads=self.head, 78 | add_self_loops=False, 79 | concat=True 80 | )) 81 | bn_in = self.head * self.size_hidden_layers 82 | 83 | if self.bn_bool == True: 84 | self.bn = nn.ModuleList() 85 | self.bn.append(nng.BatchNorm( 86 | in_channels=bn_in, 87 | track_running_stats=False 88 | )) 89 | else: 90 | self.bn = None 91 | 92 | for n in range(1, self.L): 93 | if self.pool_type != 'random': 94 | self.pool.append(nng.TopKPooling( 95 | in_channels=self.size_hidden_layers, 96 | ratio=self.pool_ratio[n - 1], 97 | nonlinearity=torch.sigmoid 98 | )) 99 | 100 | if self.layer == 'SAGE': 101 | self.down_layers.append(nng.SAGEConv( 102 | in_channels=self.size_hidden_layers, 103 | out_channels=2 * self.size_hidden_layers, 104 | )) 105 | self.size_hidden_layers = 2 * self.size_hidden_layers 106 | bn_in = self.size_hidden_layers 107 | 108 | elif self.layer == 'GAT': 109 | self.down_layers.append(nng.GATConv( 110 | in_channels=self.head * self.size_hidden_layers, 111 | out_channels=self.size_hidden_layers, 112 | heads=2, 113 | add_self_loops=False, 114 | concat=True 115 | )) 116 | 117 | if self.bn_bool == True: 118 | self.bn.append(nng.BatchNorm( 119 | in_channels=bn_in, 120 | track_running_stats=False 121 | )) 122 | 123 | self.up_layers = nn.ModuleList() 124 | 125 | if self.layer == 'SAGE': 126 | self.up_layers.append(nng.SAGEConv( 127 | in_channels=3 * self.size_hidden_layers_init, 128 | out_channels=self.dim_enc 129 | )) 130 | self.size_hidden_layers_init = 2 * self.size_hidden_layers_init 131 | 132 | elif self.layer == 'GAT': 133 | self.up_layers.append(nng.GATConv( 134 | in_channels=2 * self.head * self.size_hidden_layers, 135 | out_channels=self.dim_enc, 136 | heads=2, 137 | add_self_loops=False, 138 | concat=False 139 | )) 140 | 141 | if 
self.bn_bool == True: 142 | self.bn.append(nng.BatchNorm( 143 | in_channels=self.dim_enc, 144 | track_running_stats=False 145 | )) 146 | 147 | for n in range(1, self.L - 1): 148 | if self.layer == 'SAGE': 149 | self.up_layers.append(nng.SAGEConv( 150 | in_channels=3 * self.size_hidden_layers_init, 151 | out_channels=self.size_hidden_layers_init, 152 | )) 153 | bn_in = self.size_hidden_layers_init 154 | self.size_hidden_layers_init = 2 * self.size_hidden_layers_init 155 | 156 | elif self.layer == 'GAT': 157 | self.up_layers.append(nng.GATConv( 158 | in_channels=2 * self.head * self.size_hidden_layers, 159 | out_channels=self.size_hidden_layers, 160 | heads=2, 161 | add_self_loops=False, 162 | concat=True 163 | )) 164 | 165 | if self.bn_bool == True: 166 | self.bn.append(nng.BatchNorm( 167 | in_channels=bn_in, 168 | track_running_stats=False 169 | )) 170 | 171 | def forward(self, x, fx, T=None, geo=None): 172 | if geo is None: 173 | raise ValueError('Please provide edge index for Graph Neural Networks') 174 | x, edge_index = fx.squeeze(0), geo 175 | id = [] 176 | edge_index_list = [edge_index.clone()] 177 | pos_x_list = [] 178 | z = self.encoder(x) 179 | if self.res: 180 | z_res = z.clone() 181 | 182 | z = self.down_layers[0](z, edge_index) 183 | 184 | if self.bn_bool == True: 185 | z = self.bn[0](z) 186 | 187 | z = self.activation(z) 188 | z_list = [z.clone()] 189 | for n in range(self.L - 1): 190 | pos_x = x[:, :2] if n == 0 else pos_x[id[n - 1]] 191 | pos_x_list.append(pos_x.clone()) 192 | 193 | if self.pool_type != 'random': 194 | z, edge_index = DownSample(id, z, edge_index, pos_x, self.pool[n], self.pool_ratio[n], self.list_r[n], 195 | self.max_neighbors) 196 | else: 197 | z, edge_index = DownSample(id, z, edge_index, pos_x, None, self.pool_ratio[n], self.list_r[n], 198 | self.max_neighbors) 199 | edge_index_list.append(edge_index.clone()) 200 | 201 | z = self.down_layers[n + 1](z, edge_index) 202 | 203 | if self.bn_bool == True: 204 | z = self.bn[n + 1](z) 205 
| 206 | z = self.activation(z) 207 | z_list.append(z.clone()) 208 | pos_x_list.append(pos_x[id[-1]].clone()) 209 | 210 | for n in range(self.L - 1, 0, -1): 211 | z = UpSample(z, pos_x_list[n - 1], pos_x_list[n]) 212 | z = torch.cat([z, z_list[n - 1]], dim=1) 213 | z = self.up_layers[n - 1](z, edge_index_list[n - 1]) 214 | 215 | if self.bn_bool == True: 216 | z = self.bn[self.L + n - 1](z) 217 | 218 | z = self.activation(z) if n != 1 else z 219 | 220 | del (z_list, pos_x_list, edge_index_list) 221 | 222 | if self.res: 223 | z = z + z_res 224 | 225 | z = self.decoder(z) 226 | 227 | return z.unsqueeze(0) 228 | -------------------------------------------------------------------------------- /models/LSM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from timm.models.layers import trunc_normal_ 7 | from layers.Basic import MLP 8 | from layers.Embedding import timestep_embedding, unified_pos_embedding 9 | from layers.Neural_Spectral_Block import NeuralSpectralBlock1D, NeuralSpectralBlock2D, NeuralSpectralBlock3D 10 | from layers.UNet_Blocks import DoubleConv1D, Down1D, Up1D, OutConv1D, DoubleConv2D, Down2D, Up2D, OutConv2D, \ 11 | DoubleConv3D, Down3D, Up3D, OutConv3D 12 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI 13 | 14 | ConvList = [None, DoubleConv1D, DoubleConv2D, DoubleConv3D] 15 | DownList = [None, Down1D, Down2D, Down3D] 16 | UpList = [None, Up1D, Up2D, Up3D] 17 | OutList = [None, OutConv1D, OutConv2D, OutConv3D] 18 | BlockList = [None, NeuralSpectralBlock1D, NeuralSpectralBlock2D, NeuralSpectralBlock3D] 19 | 20 | 21 | class Model(nn.Module): 22 | def __init__(self, args, bilinear=True, num_token=4, num_basis=12, s1=96, s2=96): 23 | super(Model, self).__init__() 24 | self.__name__ = 'LSM' 25 | self.args = args 26 | if args.task == 'steady': 27 | normtype = 'bn' 28 | else: 29 | 
normtype = 'in' # when conducting dynamic tasks, use instance norm for stability 30 | ## embedding 31 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 32 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 33 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 34 | args.n_hidden, n_layers=0, res=False, act=args.act) 35 | else: 36 | self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 37 | n_layers=0, res=False, act=args.act) 38 | if args.time_input: 39 | self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 40 | nn.Linear(args.n_hidden, args.n_hidden)) 41 | # geometry projection 42 | if self.args.geotype == 'unstructured': 43 | self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, 44 | s1, s2) 45 | self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, 46 | s1, s2) 47 | self.iphi = IPHI() 48 | patch_size = [(size + (16 - size % 16) % 16) // 16 for size in [s1, s2]] 49 | self.padding = [(16 - size % 16) % 16 for size in [s1, s2]] 50 | else: 51 | patch_size = [(size + (16 - size % 16) % 16) // 16 for size in args.shapelist] 52 | self.padding = [(16 - size % 16) % 16 for size in args.shapelist] 53 | # multiscale modules 54 | self.inc = ConvList[len(patch_size)](args.n_hidden, args.n_hidden, normtype=normtype) 55 | self.down1 = DownList[len(patch_size)](args.n_hidden, args.n_hidden * 2, normtype=normtype) 56 | self.down2 = DownList[len(patch_size)](args.n_hidden * 2, args.n_hidden * 4, normtype=normtype) 57 | self.down3 = DownList[len(patch_size)](args.n_hidden * 4, args.n_hidden * 8, normtype=normtype) 58 | factor = 2 if bilinear else 1 59 | self.down4 = DownList[len(patch_size)](args.n_hidden * 8, args.n_hidden * 16 // factor, normtype=normtype) 60 | self.up1 = UpList[len(patch_size)](args.n_hidden * 16, args.n_hidden * 8 // 
factor, bilinear, normtype=normtype) 61 | self.up2 = UpList[len(patch_size)](args.n_hidden * 8, args.n_hidden * 4 // factor, bilinear, normtype=normtype) 62 | self.up3 = UpList[len(patch_size)](args.n_hidden * 4, args.n_hidden * 2 // factor, bilinear, normtype=normtype) 63 | self.up4 = UpList[len(patch_size)](args.n_hidden * 2, args.n_hidden, bilinear, normtype=normtype) 64 | self.outc = OutList[len(patch_size)](args.n_hidden, args.n_hidden) 65 | # Patchified Neural Spectral Blocks 66 | self.process1 = BlockList[len(patch_size)](args.n_hidden, num_basis, patch_size, num_token, args.n_heads) 67 | self.process2 = BlockList[len(patch_size)](args.n_hidden * 2, num_basis, patch_size, num_token, args.n_heads) 68 | self.process3 = BlockList[len(patch_size)](args.n_hidden * 4, num_basis, patch_size, num_token, args.n_heads) 69 | self.process4 = BlockList[len(patch_size)](args.n_hidden * 8, num_basis, patch_size, num_token, args.n_heads) 70 | self.process5 = BlockList[len(patch_size)](args.n_hidden * 16 // factor, num_basis, patch_size, num_token, 71 | args.n_heads) 72 | # projectors 73 | self.fc1 = nn.Linear(args.n_hidden, args.n_hidden * 2) 74 | self.fc2 = nn.Linear(args.n_hidden * 2, args.out_dim) 75 | 76 | def structured_geo(self, x, fx, T=None): 77 | B, N, _ = x.shape 78 | if self.args.unified_pos: 79 | x = self.pos.repeat(x.shape[0], 1, 1) 80 | if fx is not None: 81 | fx = torch.cat((x, fx), -1) 82 | fx = self.preprocess(fx) 83 | else: 84 | fx = self.preprocess(x) 85 | 86 | if T is not None: 87 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 88 | Time_emb = self.time_fc(Time_emb) 89 | fx = fx + Time_emb 90 | x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *self.args.shapelist) 91 | if not all(item == 0 for item in self.padding): 92 | if len(self.args.shapelist) == 2: 93 | x = F.pad(x, [0, self.padding[1], 0, self.padding[0]]) 94 | elif len(self.args.shapelist) == 3: 95 | x = F.pad(x, [0, self.padding[2], 0, self.padding[1], 0, 
self.padding[0]]) 96 | x1 = self.inc(x) 97 | x2 = self.down1(x1) 98 | x3 = self.down2(x2) 99 | x4 = self.down3(x3) 100 | x5 = self.down4(x4) 101 | x = self.up1(self.process5(x5), self.process4(x4)) 102 | x = self.up2(x, self.process3(x3)) 103 | x = self.up3(x, self.process2(x2)) 104 | x = self.up4(x, self.process1(x1)) 105 | x = self.outc(x) 106 | 107 | if not all(item == 0 for item in self.padding): 108 | if len(self.args.shapelist) == 2: 109 | x = x[..., :-self.padding[0], :-self.padding[1]] 110 | elif len(self.args.shapelist) == 3: 111 | x = x[..., :-self.padding[0], :-self.padding[1], :-self.padding[2]] 112 | x = x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1) 113 | x = self.fc1(x) 114 | x = F.gelu(x) 115 | x = self.fc2(x) 116 | return x 117 | 118 | def unstructured_geo(self, x, fx, T=None): 119 | original_pos = x 120 | if fx is not None: 121 | fx = torch.cat((x, fx), -1) 122 | fx = self.preprocess(fx) 123 | else: 124 | fx = self.preprocess(x) 125 | 126 | if T is not None: 127 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 128 | Time_emb = self.time_fc(Time_emb) 129 | fx = fx + Time_emb 130 | 131 | x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None) 132 | x1 = self.inc(x) 133 | x2 = self.down1(x1) 134 | x3 = self.down2(x2) 135 | x4 = self.down3(x3) 136 | x5 = self.down4(x4) 137 | x = self.up1(self.process5(x5), self.process4(x4)) 138 | x = self.up2(x, self.process3(x3)) 139 | x = self.up3(x, self.process2(x2)) 140 | x = self.up4(x, self.process1(x1)) 141 | x = self.outc(x) 142 | x = self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1) 143 | x = self.fc1(x) 144 | x = F.gelu(x) 145 | x = self.fc2(x) 146 | return x 147 | 148 | def forward(self, x, fx, T=None, geo=None): 149 | if self.args.geotype == 'unstructured': 150 | return self.unstructured_geo(x, fx, T) 151 | else: 152 | return self.structured_geo(x, fx, T) 153 | 
-------------------------------------------------------------------------------- /models/MWT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from timm.models.layers import trunc_normal_ 7 | from layers.Basic import MLP 8 | from layers.Embedding import timestep_embedding, unified_pos_embedding 9 | from layers.MWT_Layers import MWT_CZ1d, MWT_CZ2d, MWT_CZ3d 10 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI 11 | 12 | BlockList = [None, MWT_CZ1d, MWT_CZ2d, MWT_CZ3d] 13 | ConvList = [None, nn.Conv1d, nn.Conv2d, nn.Conv3d] 14 | 15 | 16 | class Model(nn.Module): 17 | # this model requires H = W = Z and H, W, Z is the power of two 18 | def __init__(self, args, alpha=2, L=0, c=1, base='legendre', s1=128, s2=128): 19 | super(Model, self).__init__() 20 | self.__name__ = 'MWT' 21 | self.args = args 22 | self.k = args.mwt_k 23 | self.WMT_dim = c * self.k ** 2 24 | if args.geotype == 'structured_1D': 25 | self.WMT_dim = c * self.k 26 | self.c = c 27 | self.s1 = s1 28 | self.s2 = s2 29 | ## embedding 30 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 31 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 32 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 33 | self.WMT_dim, n_layers=0, res=False, act=args.act) 34 | else: 35 | self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, self.WMT_dim, 36 | n_layers=0, res=False, act=args.act) 37 | if args.time_input: 38 | self.time_fc = nn.Sequential(nn.Linear(self.WMT_dim, args.n_hidden), nn.SiLU(), 39 | nn.Linear(args.n_hidden, self.WMT_dim)) 40 | # geometry projection 41 | if self.args.geotype == 'unstructured': 42 | self.fftproject_in = SpectralConv2d_IrregularGeo(self.WMT_dim, self.WMT_dim, args.modes, args.modes, s1, s2) 43 | self.fftproject_out = 
SpectralConv2d_IrregularGeo(self.WMT_dim, self.WMT_dim, args.modes, args.modes, s1, 44 | s2) 45 | self.iphi = IPHI() 46 | self.augmented_resolution = [s1, s2] 47 | self.padding = [(16 - size % 16) % 16 for size in [s1, s2]] 48 | else: 49 | target = 2 ** (math.ceil(np.log2(max(args.shapelist)))) 50 | self.padding = [(target - size) for size in args.shapelist] 51 | self.augmented_resolution = [target for _ in range(len(self.padding))] 52 | self.spectral_layers = nn.ModuleList( 53 | [BlockList[len(self.padding)](k=self.k, alpha=alpha, L=L, c=c, base=base) for _ in range(args.n_layers)]) 54 | # projectors 55 | self.fc1 = nn.Linear(self.WMT_dim, args.n_hidden) 56 | self.fc2 = nn.Linear(args.n_hidden, args.out_dim) 57 | 58 | def structured_geo(self, x, fx, T=None): 59 | B, N, _ = x.shape 60 | if self.args.unified_pos: 61 | x = self.pos.repeat(x.shape[0], 1, 1) 62 | if fx is not None: 63 | fx = torch.cat((x, fx), -1) 64 | fx = self.preprocess(fx) 65 | else: 66 | fx = self.preprocess(x) 67 | 68 | if T is not None: 69 | Time_emb = timestep_embedding(T, self.WMT_dim).repeat(1, x.shape[1], 1) 70 | Time_emb = self.time_fc(Time_emb) 71 | fx = fx + Time_emb 72 | x = fx.permute(0, 2, 1).reshape(B, self.WMT_dim, *self.args.shapelist) 73 | if not all(item == 0 for item in self.padding): 74 | if len(self.args.shapelist) == 2: 75 | x = F.pad(x, [0, self.padding[1], 0, self.padding[0]]) 76 | elif len(self.args.shapelist) == 3: 77 | x = F.pad(x, [0, self.padding[2], 0, self.padding[1], 0, self.padding[0]]) 78 | x = x.reshape(B, self.WMT_dim, -1).permute(0, 2, 1).contiguous() \ 79 | .reshape(B, *self.augmented_resolution, self.c, self.k ** 2 if self.args.geotype != 'structured_1D' else self.k) 80 | for i in range(self.args.n_layers): 81 | x = self.spectral_layers[i](x) 82 | if i < self.args.n_layers - 1: 83 | x = F.gelu(x) 84 | x = x.reshape(B, -1, self.WMT_dim).permute(0, 2, 1).contiguous() \ 85 | .reshape(B, self.WMT_dim, *self.augmented_resolution) 86 | if not all(item == 0 for item 
in self.padding): 87 | if len(self.args.shapelist) == 2: 88 | x = x[..., :-self.padding[0], :-self.padding[1]] 89 | elif len(self.args.shapelist) == 3: 90 | x = x[..., :-self.padding[0], :-self.padding[1], :-self.padding[2]] 91 | x = x.reshape(B, self.WMT_dim, -1).permute(0, 2, 1) 92 | x = self.fc1(x) 93 | x = F.gelu(x) 94 | x = self.fc2(x) 95 | return x 96 | 97 | def unstructured_geo(self, x, fx, T=None): 98 | B, N, _ = x.shape 99 | original_pos = x 100 | if fx is not None: 101 | fx = torch.cat((x, fx), -1) 102 | fx = self.preprocess(fx) 103 | else: 104 | fx = self.preprocess(x) 105 | 106 | if T is not None: 107 | Time_emb = timestep_embedding(T, self.WMT_dim).repeat(1, x.shape[1], 1) 108 | Time_emb = self.time_fc(Time_emb) 109 | fx = fx + Time_emb 110 | 111 | x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None) 112 | x = x.reshape(B, self.WMT_dim, -1).permute(0, 2, 1).contiguous() \ 113 | .reshape(B, *self.augmented_resolution, self.c, self.k ** 2) 114 | for i in range(self.args.n_layers): 115 | x = self.spectral_layers[i](x) 116 | if i < self.args.n_layers - 1: 117 | x = F.gelu(x) 118 | x = x.reshape(B, -1, self.WMT_dim).permute(0, 2, 1).contiguous() \ 119 | .reshape(B, self.WMT_dim, *self.augmented_resolution) 120 | x = self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1) 121 | x = self.fc1(x) 122 | x = F.gelu(x) 123 | x = self.fc2(x) 124 | return x 125 | 126 | def forward(self, x, fx, T=None, geo=None): 127 | if self.args.geotype == 'unstructured': 128 | return self.unstructured_geo(x, fx, T) 129 | else: 130 | return self.structured_geo(x, fx, T) 131 | -------------------------------------------------------------------------------- /models/ONO.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from timm.models.layers import trunc_normal_ 6 | from 
layers.Basic import MLP, LinearAttention, FlashAttention, SelfAttention as LinearSelfAttention 7 | from layers.Embedding import timestep_embedding, unified_pos_embedding 8 | from einops import rearrange, repeat 9 | import warnings 10 | 11 | 12 | def psd_safe_cholesky(A, upper=False, out=None, jitter=None): 13 | """Compute the Cholesky decomposition of A. If A is only p.s.d, add a small jitter to the diagonal. 14 | Args: 15 | :attr:`A` (Tensor): 16 | The tensor to compute the Cholesky decomposition of 17 | :attr:`upper` (bool, optional): 18 | See torch.cholesky 19 | :attr:`out` (Tensor, optional): 20 | See torch.cholesky 21 | :attr:`jitter` (float, optional): 22 | The jitter to add to the diagonal of A in case A is only p.s.d. If omitted, chosen 23 | as 1e-6 (float) or 1e-8 (double) 24 | """ 25 | try: 26 | L = torch.linalg.cholesky(A, upper=upper, out=out) 27 | if torch.isnan(L).any(): 28 | raise RuntimeError 29 | return L 30 | except RuntimeError as e: 31 | isnan = torch.isnan(A) 32 | if isnan.any(): 33 | raise ValueError( 34 | f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN." 
35 | ) 36 | 37 | if jitter is None: 38 | jitter = 1e-6 if A.dtype == torch.float32 else 1e-8 39 | Aprime = A.clone() 40 | jitter_prev = 0 41 | for i in range(10): 42 | jitter_new = jitter * (10 ** i) 43 | Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev) 44 | jitter_prev = jitter_new 45 | try: 46 | L = torch.linalg.cholesky(Aprime, upper=upper, out=out) 47 | warnings.warn( 48 | f"A not p.d., added jitter of {jitter_new} to the diagonal", 49 | RuntimeWarning, 50 | ) 51 | return L 52 | except RuntimeError: 53 | continue 54 | raise e 55 | 56 | 57 | class ONOBlock(nn.Module): 58 | """ONO encoder block.""" 59 | 60 | def __init__( 61 | self, 62 | num_heads: int, 63 | hidden_dim: int, 64 | dropout: float, 65 | act='gelu', 66 | attn_type='nystrom', 67 | mlp_ratio=4, 68 | last_layer=False, 69 | momentum=0.9, 70 | psi_dim=8, 71 | out_dim=1 72 | ): 73 | super().__init__() 74 | self.momentum = momentum 75 | self.psi_dim = psi_dim 76 | 77 | self.register_buffer("feature_cov", None) 78 | self.register_parameter("mu", nn.Parameter(torch.zeros(psi_dim))) 79 | self.ln_1 = nn.LayerNorm(hidden_dim) 80 | if attn_type == 'nystrom': 81 | from nystrom_attention import NystromAttention 82 | self.Attn = NystromAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout) 83 | elif attn_type == 'linear': 84 | self.Attn = LinearAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout, 85 | attn_type='galerkin') 86 | elif attn_type == 'selfAttention': 87 | self.Attn = LinearSelfAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout) 88 | else: 89 | raise ValueError('Attn type only supports nystrom or linear') 90 | self.ln_2 = nn.LayerNorm(hidden_dim) 91 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) 92 | self.proj = nn.Linear(hidden_dim, psi_dim) 93 | self.ln_3 = nn.LayerNorm(hidden_dim) 94 | self.mlp2 = nn.Linear(hidden_dim, out_dim) 
if last_layer else MLP(hidden_dim, hidden_dim * mlp_ratio, 95 | hidden_dim, n_layers=0, res=False, act=act) 96 | 97 | def forward(self, x, fx): 98 | x = self.Attn(self.ln_1(x)) + x 99 | x = self.mlp(self.ln_2(x)) + x 100 | x_ = self.proj(x) 101 | if self.training: 102 | batch_cov = torch.einsum("blc, bld->cd", x_, x_) / x_.shape[0] / x_.shape[1] 103 | with torch.no_grad(): 104 | if self.feature_cov is None: 105 | self.feature_cov = batch_cov 106 | else: 107 | self.feature_cov.mul_(self.momentum).add_(batch_cov, alpha=1 - self.momentum) 108 | else: 109 | batch_cov = self.feature_cov 110 | L = psd_safe_cholesky(batch_cov) 111 | L_inv_T = L.inverse().transpose(-2, -1) 112 | x_ = x_ @ L_inv_T 113 | 114 | fx = (x_ * torch.nn.functional.softplus(self.mu)) @ (x_.transpose(-2, -1) @ fx) + fx 115 | fx = self.mlp2(self.ln_3(fx)) 116 | 117 | return x, fx 118 | 119 | 120 | class Model(nn.Module): 121 | ## speed up with flash attention 122 | def __init__(self, args): 123 | super(Model, self).__init__() 124 | self.__name__ = 'ONO' 125 | self.args = args 126 | ## embedding 127 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 128 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 129 | self.preprocess_x = MLP(args.ref ** len(args.shapelist), args.n_hidden * 2, 130 | args.n_hidden, n_layers=0, res=False, act=args.act) 131 | self.preprocess_z = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 132 | args.n_hidden, n_layers=0, res=False, act=args.act) 133 | else: 134 | self.preprocess_x = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 135 | n_layers=0, res=False, act=args.act) 136 | self.preprocess_z = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 137 | n_layers=0, res=False, act=args.act) 138 | if args.time_input: 139 | self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 140 | nn.Linear(args.n_hidden, args.n_hidden)) 141 | 142 | ## models 143 | 
self.blocks = nn.ModuleList([ONOBlock(num_heads=args.n_heads, hidden_dim=args.n_hidden, 144 | dropout=args.dropout, 145 | act=args.act, 146 | mlp_ratio=args.mlp_ratio, 147 | out_dim=args.out_dim, 148 | psi_dim=args.psi_dim, 149 | attn_type=args.attn_type, 150 | last_layer=(_ == args.n_layers - 1)) 151 | for _ in range(args.n_layers)]) 152 | self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float)) 153 | self.initialize_weights() 154 | 155 | def initialize_weights(self): 156 | self.apply(self._init_weights) 157 | 158 | def _init_weights(self, m): 159 | if isinstance(m, nn.Linear): 160 | trunc_normal_(m.weight, std=0.02) 161 | if isinstance(m, nn.Linear) and m.bias is not None: 162 | nn.init.constant_(m.bias, 0) 163 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 164 | nn.init.constant_(m.bias, 0) 165 | nn.init.constant_(m.weight, 1.0) 166 | 167 | def forward(self, x, fx, T=None, geo=None): 168 | if self.args.unified_pos: 169 | x = self.pos.repeat(x.shape[0], 1, 1) 170 | if fx is not None: 171 | x = torch.cat((x, fx), -1) 172 | fx = self.preprocess_z(x) 173 | x = self.preprocess_x(x) 174 | else: 175 | fx = self.preprocess_z(x) 176 | x = self.preprocess_x(x) 177 | fx = fx + self.placeholder[None, None, :] 178 | 179 | if T is not None: 180 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 181 | Time_emb = self.time_fc(Time_emb) 182 | fx = fx + Time_emb 183 | 184 | for block in self.blocks: 185 | x, fx = block(x, fx) 186 | return fx 187 | -------------------------------------------------------------------------------- /models/PointNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric.nn as nng 4 | from layers.Embedding import unified_pos_embedding 5 | from layers.Basic import MLP 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args): 10 | super(Model, self).__init__() 11 | 
self.__name__ = 'PointNet' 12 | 13 | self.in_block = MLP(args.n_hidden, args.n_hidden * 2, args.n_hidden * 2, n_layers=0, res=False, 14 | act=args.act) 15 | self.max_block = MLP(args.n_hidden * 2, args.n_hidden * 8, args.n_hidden * 32, n_layers=0, res=False, 16 | act=args.act) 17 | 18 | self.out_block = MLP(args.n_hidden * (2 + 32), args.n_hidden * 16, args.n_hidden * 4, n_layers=0, res=False, 19 | act=args.act) 20 | 21 | self.encoder = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, n_layers=0, res=False, 22 | act=args.act) 23 | self.decoder = MLP(args.n_hidden, args.n_hidden * 2, args.out_dim, n_layers=0, res=False, act=args.act) 24 | 25 | self.fcfinal = nn.Linear(args.n_hidden * 4, args.n_hidden) 26 | 27 | def forward(self, x, fx, T=None, geo=None): 28 | if geo is None: 29 | raise ValueError('Please provide edge index for Graph Neural Networks') 30 | z, batch = torch.cat((x, fx), dim=-1).float().squeeze(0), torch.zeros([x.shape[1]]).cuda().long() 31 | 32 | z = self.encoder(z) 33 | z = self.in_block(z) 34 | 35 | global_coef = self.max_block(z) 36 | global_coef = nng.global_max_pool(global_coef, batch=batch) 37 | nb_points = torch.zeros(global_coef.shape[0], device=z.device) 38 | 39 | for i in range(batch.max() + 1): 40 | nb_points[i] = (batch == i).sum() 41 | nb_points = nb_points.long() 42 | global_coef = torch.repeat_interleave(global_coef, nb_points, dim=0) 43 | 44 | z = torch.cat([z, global_coef], dim=1) 45 | z = self.out_block(z) 46 | z = self.fcfinal(z) 47 | z = self.decoder(z) 48 | 49 | return z.unsqueeze(0) 50 | -------------------------------------------------------------------------------- /models/Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from timm.models.layers import trunc_normal_ 6 | from layers.Basic import MLP, FlashAttention 7 | from layers.Embedding import 
timestep_embedding, unified_pos_embedding 8 | from einops import rearrange, repeat 9 | 10 | 11 | class Transformer_block(nn.Module): 12 | """Transformer encoder block.""" 13 | 14 | def __init__( 15 | self, 16 | num_heads: int, 17 | hidden_dim: int, 18 | dropout: float, 19 | act='gelu', 20 | mlp_ratio=4, 21 | last_layer=False, 22 | out_dim=1, 23 | ): 24 | super().__init__() 25 | self.last_layer = last_layer 26 | self.ln_1 = nn.LayerNorm(hidden_dim) 27 | 28 | # flash_attention 29 | self.Attn = FlashAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout) 30 | self.ln_2 = nn.LayerNorm(hidden_dim) 31 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) 32 | if self.last_layer: 33 | self.ln_3 = nn.LayerNorm(hidden_dim) 34 | self.mlp2 = nn.Linear(hidden_dim, out_dim) 35 | 36 | def forward(self, fx): 37 | fx = self.Attn(self.ln_1(fx)) + fx 38 | fx = self.mlp(self.ln_2(fx)) + fx 39 | if self.last_layer: 40 | return self.mlp2(self.ln_3(fx)) 41 | else: 42 | return fx 43 | 44 | 45 | class Model(nn.Module): 46 | ## speed up with flash attention 47 | def __init__(self, args): 48 | super(Model, self).__init__() 49 | self.__name__ = 'Transformer' 50 | self.args = args 51 | ## embedding 52 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 53 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 54 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 55 | args.n_hidden, n_layers=0, res=False, act=args.act) 56 | else: 57 | self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 58 | n_layers=0, res=False, act=args.act) 59 | if args.time_input: 60 | self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 61 | nn.Linear(args.n_hidden, args.n_hidden)) 62 | 63 | ## models 64 | self.blocks = nn.ModuleList([Transformer_block(num_heads=args.n_heads, hidden_dim=args.n_hidden, 65 | 
dropout=args.dropout, 66 | act=args.act, 67 | mlp_ratio=args.mlp_ratio, 68 | out_dim=args.out_dim, 69 | last_layer=(_ == args.n_layers - 1)) 70 | for _ in range(args.n_layers)]) 71 | self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float)) 72 | self.initialize_weights() 73 | 74 | def initialize_weights(self): 75 | self.apply(self._init_weights) 76 | 77 | def _init_weights(self, m): 78 | if isinstance(m, nn.Linear): 79 | trunc_normal_(m.weight, std=0.02) 80 | if isinstance(m, nn.Linear) and m.bias is not None: 81 | nn.init.constant_(m.bias, 0) 82 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 83 | nn.init.constant_(m.bias, 0) 84 | nn.init.constant_(m.weight, 1.0) 85 | 86 | def forward(self, x, fx, T=None, geo=None): 87 | if self.args.unified_pos: 88 | x = self.pos.repeat(x.shape[0], 1, 1) 89 | if fx is not None: 90 | fx = torch.cat((x, fx), -1) 91 | fx = self.preprocess(fx) 92 | else: 93 | fx = self.preprocess(x) 94 | fx = fx + self.placeholder[None, None, :] 95 | 96 | if T is not None: 97 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 98 | Time_emb = self.time_fc(Time_emb) 99 | fx = fx + Time_emb 100 | 101 | for block in self.blocks: 102 | fx = block(fx) 103 | return fx 104 | -------------------------------------------------------------------------------- /models/Transolver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from timm.models.layers import trunc_normal_ 5 | from layers.Basic import MLP 6 | from layers.Embedding import timestep_embedding, unified_pos_embedding 7 | from layers.Physics_Attention import Physics_Attention_Irregular_Mesh 8 | from layers.Physics_Attention import Physics_Attention_Structured_Mesh_1D 9 | from layers.Physics_Attention import Physics_Attention_Structured_Mesh_2D 10 | from layers.Physics_Attention import Physics_Attention_Structured_Mesh_3D 11 
| 12 | PHYSICS_ATTENTION = { 13 | 'unstructured': Physics_Attention_Irregular_Mesh, 14 | 'structured_1D': Physics_Attention_Structured_Mesh_1D, 15 | 'structured_2D': Physics_Attention_Structured_Mesh_2D, 16 | 'structured_3D': Physics_Attention_Structured_Mesh_3D 17 | } 18 | 19 | 20 | class Transolver_block(nn.Module): 21 | """Transolver encoder block.""" 22 | 23 | def __init__( 24 | self, 25 | num_heads: int, 26 | hidden_dim: int, 27 | dropout: float, 28 | act='gelu', 29 | mlp_ratio=4, 30 | last_layer=False, 31 | out_dim=1, 32 | slice_num=32, 33 | geotype='unstructured', 34 | shapelist=None 35 | ): 36 | super().__init__() 37 | self.last_layer = last_layer 38 | self.ln_1 = nn.LayerNorm(hidden_dim) 39 | 40 | self.Attn = PHYSICS_ATTENTION[geotype](hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, 41 | dropout=dropout, slice_num=slice_num, shapelist=shapelist) 42 | self.ln_2 = nn.LayerNorm(hidden_dim) 43 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) 44 | if self.last_layer: 45 | self.ln_3 = nn.LayerNorm(hidden_dim) 46 | self.mlp2 = nn.Linear(hidden_dim, out_dim) 47 | 48 | def forward(self, fx): 49 | fx = self.Attn(self.ln_1(fx)) + fx 50 | fx = self.mlp(self.ln_2(fx)) + fx 51 | if self.last_layer: 52 | return self.mlp2(self.ln_3(fx)) 53 | else: 54 | return fx 55 | 56 | 57 | class Model(nn.Module): 58 | def __init__(self, args): 59 | super(Model, self).__init__() 60 | self.__name__ = 'Transolver' 61 | self.args = args 62 | ## embedding 63 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 64 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 65 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 66 | args.n_hidden, n_layers=0, res=False, act=args.act) 67 | else: 68 | self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 69 | n_layers=0, res=False, act=args.act) 70 | if args.time_input: 71 | 
self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 72 | nn.Linear(args.n_hidden, args.n_hidden)) 73 | 74 | ## models 75 | self.blocks = nn.ModuleList([Transolver_block(num_heads=args.n_heads, hidden_dim=args.n_hidden, 76 | dropout=args.dropout, 77 | act=args.act, 78 | mlp_ratio=args.mlp_ratio, 79 | out_dim=args.out_dim, 80 | slice_num=args.slice_num, 81 | last_layer=(_ == args.n_layers - 1), 82 | geotype=args.geotype, 83 | shapelist=args.shapelist) 84 | for _ in range(args.n_layers)]) 85 | self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float)) 86 | self.initialize_weights() 87 | 88 | def initialize_weights(self): 89 | self.apply(self._init_weights) 90 | 91 | def _init_weights(self, m): 92 | if isinstance(m, nn.Linear): 93 | trunc_normal_(m.weight, std=0.02) 94 | if isinstance(m, nn.Linear) and m.bias is not None: 95 | nn.init.constant_(m.bias, 0) 96 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 97 | nn.init.constant_(m.bias, 0) 98 | nn.init.constant_(m.weight, 1.0) 99 | 100 | def structured_geo(self, x, fx, T=None): 101 | if self.args.unified_pos: 102 | x = self.pos.repeat(x.shape[0], 1, 1) 103 | if fx is not None: 104 | fx = torch.cat((x, fx), -1) 105 | fx = self.preprocess(fx) 106 | else: 107 | fx = self.preprocess(x) 108 | fx = fx + self.placeholder[None, None, :] 109 | 110 | if T is not None: 111 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 112 | Time_emb = self.time_fc(Time_emb) 113 | fx = fx + Time_emb 114 | 115 | for block in self.blocks: 116 | fx = block(fx) 117 | return fx 118 | 119 | def unstructured_geo(self, x, fx, T=None): 120 | if fx is not None: 121 | fx = torch.cat((x, fx), -1) 122 | fx = self.preprocess(fx) 123 | else: 124 | fx = self.preprocess(x) 125 | fx = fx + self.placeholder[None, None, :] 126 | 127 | if T is not None: 128 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 129 | Time_emb = 
self.time_fc(Time_emb) 130 | fx = fx + Time_emb 131 | 132 | for block in self.blocks: 133 | fx = block(fx) 134 | return fx 135 | 136 | def forward(self, x, fx, T=None, geo=None): 137 | if self.args.geotype == 'unstructured': 138 | return self.unstructured_geo(x, fx, T) 139 | else: 140 | return self.structured_geo(x, fx, T) 141 | -------------------------------------------------------------------------------- /models/U_FNO.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from timm.models.layers import trunc_normal_ 7 | from layers.Basic import MLP 8 | from layers.Embedding import timestep_embedding, unified_pos_embedding 9 | from layers.FNO_Layers import SpectralConv1d, SpectralConv2d, SpectralConv3d 10 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI 11 | from models.U_Net import Model as U_Net 12 | 13 | BlockList = [None, SpectralConv1d, SpectralConv2d, SpectralConv3d] 14 | ConvList = [None, nn.Conv1d, nn.Conv2d, nn.Conv3d] 15 | 16 | 17 | class Model(nn.Module): 18 | def __init__(self, args, s1=96, s2=96): 19 | super(Model, self).__init__() 20 | self.__name__ = 'U-FNO' 21 | self.args = args 22 | ## embedding 23 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 24 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 25 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 26 | args.n_hidden, n_layers=0, res=False, act=args.act) 27 | else: 28 | self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, 29 | n_layers=0, res=False, act=args.act) 30 | if args.time_input: 31 | self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 32 | nn.Linear(args.n_hidden, args.n_hidden)) 33 | # geometry projection 34 | if self.args.geotype == 'unstructured': 35 | 
self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1, 36 | s2) 37 | self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1, 38 | s2) 39 | self.iphi = IPHI() 40 | self.padding = [(16 - size % 16) % 16 for size in [s1, s2]] 41 | else: 42 | self.padding = [(16 - size % 16) % 16 for size in args.shapelist] 43 | self.conv0 = BlockList[len(self.padding)](args.n_hidden, args.n_hidden, 44 | *[args.modes for _ in range(len(self.padding))]) 45 | self.conv1 = BlockList[len(self.padding)](args.n_hidden, args.n_hidden, 46 | *[args.modes for _ in range(len(self.padding))]) 47 | self.conv2 = BlockList[len(self.padding)](args.n_hidden, args.n_hidden, 48 | *[args.modes for _ in range(len(self.padding))]) 49 | self.conv3 = BlockList[len(self.padding)](args.n_hidden, args.n_hidden, 50 | *[args.modes for _ in range(len(self.padding))]) 51 | self.w0 = ConvList[len(self.padding)](args.n_hidden, args.n_hidden, 1) 52 | self.w1 = ConvList[len(self.padding)](args.n_hidden, args.n_hidden, 1) 53 | self.w2 = ConvList[len(self.padding)](args.n_hidden, args.n_hidden, 1) 54 | self.w3 = ConvList[len(self.padding)](args.n_hidden, args.n_hidden, 1) 55 | self.u_net2 = U_Net(args) 56 | self.u_net3 = U_Net(args) 57 | # projectors 58 | self.fc1 = nn.Linear(args.n_hidden, args.n_hidden) 59 | self.fc2 = nn.Linear(args.n_hidden, args.out_dim) 60 | 61 | def structured_geo(self, x, fx, T=None): 62 | B, N, _ = x.shape 63 | if self.args.unified_pos: 64 | x = self.pos.repeat(x.shape[0], 1, 1) 65 | if fx is not None: 66 | fx = torch.cat((x, fx), -1) 67 | fx = self.preprocess(fx) 68 | else: 69 | fx = self.preprocess(x) 70 | 71 | if T is not None: 72 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 73 | Time_emb = self.time_fc(Time_emb) 74 | fx = fx + Time_emb 75 | x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *self.args.shapelist) 76 | if not all(item == 0 for item in 
self.padding): 77 | if len(self.args.shapelist) == 2: 78 | x = F.pad(x, [0, self.padding[1], 0, self.padding[0]]) 79 | elif len(self.args.shapelist) == 3: 80 | x = F.pad(x, [0, self.padding[2], 0, self.padding[1], 0, self.padding[0]]) 81 | 82 | x1 = self.conv0(x) 83 | x2 = self.w0(x) 84 | x = x1 + x2 85 | x = F.gelu(x) 86 | 87 | x1 = self.conv1(x) 88 | x2 = self.w1(x) 89 | x = x1 + x2 90 | x = F.gelu(x) 91 | 92 | x1 = self.conv2(x) 93 | x2 = self.w2(x) 94 | x3 = self.u_net2.multiscale(x) 95 | x = x1 + x2 + x3 96 | x = F.gelu(x) 97 | 98 | x1 = self.conv3(x) 99 | x2 = self.w3(x) 100 | x3 = self.u_net3.multiscale(x) 101 | x = x1 + x2 + x3 102 | 103 | if not all(item == 0 for item in self.padding): 104 | if len(self.args.shapelist) == 2: 105 | x = x[..., :-self.padding[0], :-self.padding[1]] 106 | elif len(self.args.shapelist) == 3: 107 | x = x[..., :-self.padding[0], :-self.padding[1], :-self.padding[2]] 108 | x = x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1) 109 | x = self.fc1(x) 110 | x = F.gelu(x) 111 | x = self.fc2(x) 112 | return x 113 | 114 | def unstructured_geo(self, x, fx, T=None): 115 | original_pos = x 116 | if fx is not None: 117 | fx = torch.cat((x, fx), -1) 118 | fx = self.preprocess(fx) 119 | else: 120 | fx = self.preprocess(x) 121 | 122 | if T is not None: 123 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 124 | Time_emb = self.time_fc(Time_emb) 125 | fx = fx + Time_emb 126 | 127 | x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None) 128 | 129 | x1 = self.conv0(x) 130 | x2 = self.w0(x) 131 | x = x1 + x2 132 | x = F.gelu(x) 133 | 134 | x1 = self.conv1(x) 135 | x2 = self.w1(x) 136 | x = x1 + x2 137 | x = F.gelu(x) 138 | 139 | x1 = self.conv2(x) 140 | x2 = self.w2(x) 141 | x3 = self.u_net2.multiscale(x) 142 | x = x1 + x2 + x3 143 | x = F.gelu(x) 144 | 145 | x1 = self.conv3(x) 146 | x2 = self.w3(x) 147 | x3 = self.u_net3.multiscale(x) 148 | x = x1 + x2 + x3 149 | 150 | x = 
self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1) 151 | x = self.fc1(x) 152 | x = F.gelu(x) 153 | x = self.fc2(x) 154 | return x 155 | 156 | def forward(self, x, fx, T=None, geo=None): 157 | if self.args.geotype == 'unstructured': 158 | return self.unstructured_geo(x, fx, T) 159 | else: 160 | return self.structured_geo(x, fx, T) 161 | -------------------------------------------------------------------------------- /models/U_Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from timm.models.layers import trunc_normal_ 7 | from layers.Basic import MLP 8 | from layers.Embedding import timestep_embedding, unified_pos_embedding 9 | from layers.UNet_Blocks import DoubleConv1D, Down1D, Up1D, OutConv1D, DoubleConv2D, Down2D, Up2D, OutConv2D, \ 10 | DoubleConv3D, Down3D, Up3D, OutConv3D 11 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI 12 | 13 | ConvList = [None, DoubleConv1D, DoubleConv2D, DoubleConv3D] 14 | DownList = [None, Down1D, Down2D, Down3D] 15 | UpList = [None, Up1D, Up2D, Up3D] 16 | OutList = [None, OutConv1D, OutConv2D, OutConv3D] 17 | 18 | 19 | class Model(nn.Module): 20 | def __init__(self, args, bilinear=True, s1=96, s2=96): 21 | super(Model, self).__init__() 22 | self.__name__ = 'U-Net' 23 | self.args = args 24 | if args.task == 'steady': 25 | normtype = 'bn' 26 | else: 27 | normtype = 'in' # when conducting dynamic tasks, use instance norm for stability 28 | ## embedding 29 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh 30 | self.pos = unified_pos_embedding(args.shapelist, args.ref) 31 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2, 32 | args.n_hidden, n_layers=0, res=False, act=args.act) 33 | else: 34 | self.preprocess = MLP(args.fun_dim + args.space_dim, 
args.n_hidden * 2, args.n_hidden, 35 | n_layers=0, res=False, act=args.act) 36 | if args.time_input: 37 | self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(), 38 | nn.Linear(args.n_hidden, args.n_hidden)) 39 | # geometry projection 40 | if self.args.geotype == 'unstructured': 41 | self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, 42 | s1, s2) 43 | self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, 44 | s1, s2) 45 | self.iphi = IPHI() 46 | patch_size = [(size + (16 - size % 16) % 16) // 16 for size in [s1, s2]] 47 | self.padding = [(16 - size % 16) % 16 for size in [s1, s2]] 48 | else: 49 | patch_size = [(size + (16 - size % 16) % 16) // 16 for size in args.shapelist] 50 | self.padding = [(16 - size % 16) % 16 for size in args.shapelist] 51 | # multiscale modules 52 | self.inc = ConvList[len(patch_size)](args.n_hidden, args.n_hidden, normtype=normtype) 53 | self.down1 = DownList[len(patch_size)](args.n_hidden, args.n_hidden * 2, normtype=normtype) 54 | self.down2 = DownList[len(patch_size)](args.n_hidden * 2, args.n_hidden * 4, normtype=normtype) 55 | self.down3 = DownList[len(patch_size)](args.n_hidden * 4, args.n_hidden * 8, normtype=normtype) 56 | factor = 2 if bilinear else 1 57 | self.down4 = DownList[len(patch_size)](args.n_hidden * 8, args.n_hidden * 16 // factor, normtype=normtype) 58 | self.up1 = UpList[len(patch_size)](args.n_hidden * 16, args.n_hidden * 8 // factor, bilinear, normtype=normtype) 59 | self.up2 = UpList[len(patch_size)](args.n_hidden * 8, args.n_hidden * 4 // factor, bilinear, normtype=normtype) 60 | self.up3 = UpList[len(patch_size)](args.n_hidden * 4, args.n_hidden * 2 // factor, bilinear, normtype=normtype) 61 | self.up4 = UpList[len(patch_size)](args.n_hidden * 2, args.n_hidden, bilinear, normtype=normtype) 62 | self.outc = OutList[len(patch_size)](args.n_hidden, args.n_hidden) 63 | # projectors 64 | 
self.fc1 = nn.Linear(args.n_hidden, args.n_hidden) 65 | self.fc2 = nn.Linear(args.n_hidden, args.out_dim) 66 | 67 | def multiscale(self, x): 68 | x1 = self.inc(x) 69 | x2 = self.down1(x1) 70 | x3 = self.down2(x2) 71 | x4 = self.down3(x3) 72 | x5 = self.down4(x4) 73 | x = self.up1(x5, x4) 74 | x = self.up2(x, x3) 75 | x = self.up3(x, x2) 76 | x = self.up4(x, x1) 77 | x = self.outc(x) 78 | return x 79 | 80 | def structured_geo(self, x, fx, T=None): 81 | B, N, _ = x.shape 82 | if self.args.unified_pos: 83 | x = self.pos.repeat(x.shape[0], 1, 1) 84 | if fx is not None: 85 | fx = torch.cat((x, fx), -1) 86 | fx = self.preprocess(fx) 87 | else: 88 | fx = self.preprocess(x) 89 | 90 | if T is not None: 91 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 92 | Time_emb = self.time_fc(Time_emb) 93 | fx = fx + Time_emb 94 | x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *self.args.shapelist) 95 | if not all(item == 0 for item in self.padding): 96 | if len(self.args.shapelist) == 2: 97 | x = F.pad(x, [0, self.padding[1], 0, self.padding[0]]) 98 | elif len(self.args.shapelist) == 3: 99 | x = F.pad(x, [0, self.padding[2], 0, self.padding[1], 0, self.padding[0]]) 100 | x = self.multiscale(x) ## U-Net 101 | if not all(item == 0 for item in self.padding): 102 | if len(self.args.shapelist) == 2: 103 | x = x[..., :-self.padding[0], :-self.padding[1]] 104 | elif len(self.args.shapelist) == 3: 105 | x = x[..., :-self.padding[0], :-self.padding[1], :-self.padding[2]] 106 | x = x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1) 107 | x = self.fc1(x) 108 | x = F.gelu(x) 109 | x = self.fc2(x) 110 | return x 111 | 112 | def unstructured_geo(self, x, fx, T=None): 113 | original_pos = x 114 | if fx is not None: 115 | fx = torch.cat((x, fx), -1) 116 | fx = self.preprocess(fx) 117 | else: 118 | fx = self.preprocess(x) 119 | 120 | if T is not None: 121 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1) 122 | Time_emb = 
self.time_fc(Time_emb) 123 | fx = fx + Time_emb 124 | 125 | x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None) 126 | x = self.multiscale(x) ## U-Net 127 | x = self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1) 128 | x = self.fc1(x) 129 | x = F.gelu(x) 130 | x = self.fc2(x) 131 | return x 132 | 133 | def forward(self, x, fx, T=None, geo=None): 134 | if self.args.geotype == 'unstructured': 135 | return self.unstructured_geo(x, fx, T) 136 | else: 137 | return self.structured_geo(x, fx, T) 138 | -------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | from models import Transolver, LSM, FNO, U_Net, Transformer, Factformer, Swin_Transformer, Galerkin_Transformer, GNOT, \ 2 | U_NO, U_FNO, F_FNO, ONO, MWT, GraphSAGE, Graph_UNet, PointNet 3 | 4 | 5 | def get_model(args): 6 | model_dict = { 7 | 'PointNet': PointNet, 8 | 'Graph_UNet': Graph_UNet, 9 | 'GraphSAGE': GraphSAGE, 10 | 'MWT': MWT, 11 | 'ONO': ONO, 12 | 'F_FNO': F_FNO, 13 | 'U_FNO': U_FNO, 14 | 'U_NO': U_NO, 15 | 'GNOT': GNOT, 16 | 'Galerkin_Transformer': Galerkin_Transformer, 17 | 'Swin_Transformer': Swin_Transformer, 18 | 'Factformer': Factformer, 19 | 'Transformer': Transformer, 20 | 'U_Net': U_Net, 21 | 'FNO': FNO, 22 | 'Transolver': Transolver, 23 | 'LSM': LSM, 24 | } 25 | return model_dict[args.model].Model(args) 26 | -------------------------------------------------------------------------------- /pic/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Neural-Solver-Library/148c2eb890dc22081acc822acd760adcf4f06b30/pic/logo.png -------------------------------------------------------------------------------- /pic/task.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thuml/Neural-Solver-Library/148c2eb890dc22081acc822acd760adcf4f06b30/pic/task.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==3.8.0 2 | dgl==1.1.0 3 | scipy==1.7.3 4 | scikit-learn 5 | torch 6 | torch_geometric 7 | torch-cluster 8 | vtk 9 | timm 10 | einops 11 | seaborn 12 | pyvista 13 | sympy 14 | nystrom_attention -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from tqdm import * 6 | 7 | parser = argparse.ArgumentParser('Training Neural PDE Solvers') 8 | 9 | ## training 10 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 11 | parser.add_argument('--epochs', type=int, default=500, help='maximum epochs') 12 | parser.add_argument('--weight_decay', type=float, default=1e-5, help='optimizer weight decay') 13 | parser.add_argument('--pct_start', type=float, default=0.3, help='oncycle lr schedule') 14 | parser.add_argument('--batch-size', type=int, default=8, help='batch size') 15 | parser.add_argument("--gpu", type=str, default='0', help="GPU index to use") 16 | parser.add_argument('--max_grad_norm', type=float, default=None, help='make the training stable') 17 | parser.add_argument('--derivloss', type=bool, default=False, help='adopt the spatial derivate as regularization') 18 | parser.add_argument('--teacher_forcing', type=int, default=1, 19 | help='adopt teacher forcing in autoregressive to speed up convergence') 20 | parser.add_argument('--optimizer', type=str, default='AdamW', help='optimizer type, select from Adam, AdamW') 21 | parser.add_argument('--scheduler', type=str, default='OneCycleLR', 22 | help='learning rate scheduler, select from 
[OneCycleLR, CosineAnnealingLR, StepLR]') 23 | parser.add_argument('--step_size', type=int, default=100, help='step size for StepLR scheduler') 24 | parser.add_argument('--gamma', type=float, default=0.5, help='decay parameter for StepLR scheduler') 25 | 26 | ## data 27 | parser.add_argument('--data_path', type=str, default='/data/fno/', help='data folder') 28 | parser.add_argument('--loader', type=str, default='airfoil', help='type of data loader') 29 | parser.add_argument('--train_ratio', type=float, default=0.8, help='training data ratio') 30 | parser.add_argument('--ntrain', type=int, default=1000, help='training data numbers') 31 | parser.add_argument('--ntest', type=int, default=200, help='test data numbers') 32 | parser.add_argument('--normalize', type=bool, default=False, help='make normalization to output') 33 | parser.add_argument('--norm_type', type=str, default='UnitTransformer', 34 | help='dataset normalize type. select from [UnitTransformer, UnitGaussianNormalizer]') 35 | parser.add_argument('--geotype', type=str, default='unstructured', 36 | help='select from [unstructured, structured_1D, structured_2D, structured_3D]') 37 | parser.add_argument('--time_input', type=bool, default=False, help='for conditional dynamic task') 38 | parser.add_argument('--space_dim', type=int, default=2, help='position information dimension') 39 | parser.add_argument('--fun_dim', type=int, default=0, help='input observation dimension') 40 | parser.add_argument('--out_dim', type=int, default=1, help='output observation dimension') 41 | parser.add_argument('--shapelist', type=list, default=None, help='for structured geometry') 42 | parser.add_argument('--downsamplex', type=int, default=1, help='downsample rate in x-axis') 43 | parser.add_argument('--downsampley', type=int, default=1, help='downsample rate in y-axis') 44 | parser.add_argument('--downsamplez', type=int, default=1, help='downsample rate in z-axis') 45 | parser.add_argument('--radius', type=float, default=0.2, 
help='for construct geometry') 46 | 47 | ## task 48 | parser.add_argument('--task', type=str, default='steady', 49 | help='select from [steady, dynamic_autoregressive, dynamic_conditional]') 50 | parser.add_argument('--T_in', type=int, default=10, help='for input sequence') 51 | parser.add_argument('--T_out', type=int, default=10, help='for output sequence') 52 | 53 | ## models 54 | parser.add_argument('--model', type=str, default='Transolver') 55 | parser.add_argument('--n_hidden', type=int, default=64, help='hidden dim') 56 | parser.add_argument('--n_layers', type=int, default=3, help='layers') 57 | parser.add_argument('--n_heads', type=int, default=4, help='number of heads') 58 | parser.add_argument('--act', type=str, default='gelu') 59 | parser.add_argument('--mlp_ratio', type=int, default=1, help='mlp ratio for feedforward layers') 60 | parser.add_argument('--dropout', type=float, default=0.0, help='dropout') 61 | parser.add_argument('--unified_pos', type=int, default=0, help='for unified position embedding') 62 | parser.add_argument('--ref', type=int, default=8, help='number of reference points for unified pos embedding') 63 | 64 | ## model specific configuration 65 | parser.add_argument('--slice_num', type=int, default=32, help='number of physical states for Transolver') 66 | parser.add_argument('--modes', type=int, default=12, help='number of basis functions for LSM and FNO') 67 | parser.add_argument('--psi_dim', type=int, default=8, help='number of psi_dim for ONO') 68 | parser.add_argument('--attn_type', type=str, default='nystrom',help='attn_type for ONO, select from nystrom, linear, selfAttention') 69 | parser.add_argument('--mwt_k', type=int, default=3,help='number of wavelet basis functions for MWT') 70 | 71 | ## eval 72 | parser.add_argument('--eval', type=int, default=0, help='evaluation or not') 73 | parser.add_argument('--save_name', type=str, default='Transolver_check', help='name of folders') 74 | parser.add_argument('--vis_num', type=int, 
default=10, help='number of visualization cases') 75 | parser.add_argument('--vis_bound', type=int, nargs='+', default=None, help='size of region for visualization, in list') 76 | 77 | args = parser.parse_args() 78 | eval = args.eval 79 | save_name = args.save_name 80 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 81 | 82 | 83 | def main(): 84 | if args.task == 'steady': 85 | from exp.exp_steady import Exp_Steady 86 | exp = Exp_Steady(args) 87 | elif args.task == 'steady_design': 88 | from exp.exp_steady_design import Exp_Steady_Design 89 | exp = Exp_Steady_Design(args) 90 | elif args.task == 'dynamic_autoregressive': 91 | from exp.exp_dynamic_autoregressive import Exp_Dynamic_Autoregressive 92 | exp = Exp_Dynamic_Autoregressive(args) 93 | elif args.task == 'dynamic_conditional': 94 | from exp.exp_dynamic_conditional import Exp_Dynamic_Conditional 95 | exp = Exp_Dynamic_Conditional(args) 96 | else: 97 | raise NotImplementedError 98 | 99 | if eval: 100 | exp.test() 101 | else: 102 | exp.train() 103 | exp.test() 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /scripts/DesignBench/car/GNOT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDE_data/mlcfd_data/ \ 4 | --loader car_design \ 5 | --geotype unstructured \ 6 | --task steady_design \ 7 | --space_dim 3 \ 8 | --fun_dim 7 \ 9 | --out_dim 4 \ 10 | --model GNOT \ 11 | --n_hidden 256 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 0 \ 17 | --ref 8 \ 18 | --batch-size 1 \ 19 | --epochs 200 \ 20 | --eval 0 \ 21 | --save_name car_design_GNOT -------------------------------------------------------------------------------- /scripts/DesignBench/car/GraphSAGE.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | 
--data_path /data/PDE_data/mlcfd_data/ \ 4 | --loader car_design \ 5 | --geotype unstructured \ 6 | --task steady_design \ 7 | --space_dim 3 \ 8 | --fun_dim 7 \ 9 | --out_dim 4 \ 10 | --model GraphSAGE \ 11 | --n_hidden 128 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 0 \ 17 | --ref 8 \ 18 | --batch-size 1 \ 19 | --epochs 200 \ 20 | --eval 0 \ 21 | --save_name car_design_GraphSAGE -------------------------------------------------------------------------------- /scripts/DesignBench/car/GraphUNet.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/PDE_data/mlcfd_data/ \ 4 | --loader car_design \ 5 | --geotype unstructured \ 6 | --task steady_design \ 7 | --space_dim 3 \ 8 | --fun_dim 7 \ 9 | --out_dim 4 \ 10 | --model Graph_UNet \ 11 | --n_hidden 16 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 0 \ 17 | --ref 8 \ 18 | --batch-size 1 \ 19 | --epochs 200 \ 20 | --eval 0 \ 21 | --save_name car_design_GraphUNet -------------------------------------------------------------------------------- /scripts/DesignBench/car/PointNet.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 5 \ 3 | --data_path /data/PDE_data/mlcfd_data/ \ 4 | --loader car_design \ 5 | --geotype unstructured \ 6 | --task steady_design \ 7 | --space_dim 3 \ 8 | --fun_dim 7 \ 9 | --out_dim 4 \ 10 | --model PointNet \ 11 | --n_hidden 16 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 0 \ 17 | --ref 8 \ 18 | --batch-size 1 \ 19 | --epochs 200 \ 20 | --eval 0 \ 21 | --save_name car_design_PointNet -------------------------------------------------------------------------------- /scripts/DesignBench/car/Transolver.sh: 
-------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 6 \ 3 | --data_path /data/PDE_data/mlcfd_data/ \ 4 | --loader car_design \ 5 | --geotype unstructured \ 6 | --task steady_design \ 7 | --space_dim 3 \ 8 | --fun_dim 7 \ 9 | --out_dim 4 \ 10 | --model Transolver \ 11 | --n_hidden 256 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 0 \ 17 | --ref 8 \ 18 | --batch-size 1 \ 19 | --epochs 200 \ 20 | --eval 0 \ 21 | --save_name car_design_Transolver -------------------------------------------------------------------------------- /scripts/PDEBench/3DCFD/FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDEBench/3D/3D/Train/3D_CFD_Rand_M1.0_Eta1e-08_Zeta1e-08_periodic_Train.hdf5 \ 4 | --loader cfd3d \ 5 | --geotype structured_3D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.0005 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 3 \ 12 | --fun_dim 50 \ 13 | --out_dim 5 \ 14 | --downsamplex 2 \ 15 | --downsampley 2 \ 16 | --downsamplez 2 \ 17 | --model FNO \ 18 | --n_hidden 20 \ 19 | --n_heads 8 \ 20 | --n_layers 8 \ 21 | --unified_pos 0 \ 22 | --ref 8 \ 23 | --batch-size 5 \ 24 | --epochs 500 \ 25 | --eval 0 \ 26 | --save_name pdebench_FNO_3DCFD -------------------------------------------------------------------------------- /scripts/PDEBench/3DCFD/U_Net.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDEBench/3D/3D/Train/3D_CFD_Rand_M1.0_Eta1e-08_Zeta1e-08_periodic_Train.hdf5 \ 4 | --loader cfd3d \ 5 | --geotype structured_3D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.001 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 3 \ 12 | --fun_dim 50 \ 13 | --out_dim 5 \ 14 | 
--downsamplex 2 \ 15 | --downsampley 2 \ 16 | --downsamplez 2 \ 17 | --model U_Net \ 18 | --n_hidden 20 \ 19 | --n_heads 8 \ 20 | --n_layers 8 \ 21 | --unified_pos 0 \ 22 | --ref 8 \ 23 | --batch-size 5 \ 24 | --epochs 500 \ 25 | --eval 0 \ 26 | --save_name pdebench_U_Net_3DCFD -------------------------------------------------------------------------------- /scripts/PDEBench/darcy/FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \ 4 | --loader pdebench_steady_darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --scheduler StepLR \ 8 | --downsamplex 1 \ 9 | --downsampley 1 \ 10 | --space_dim 2 \ 11 | --fun_dim 1 \ 12 | --out_dim 1 \ 13 | --model FNO \ 14 | --n_hidden 64 \ 15 | --n_heads 8 \ 16 | --n_layers 8 \ 17 | --unified_pos 1 \ 18 | --ref 8 \ 19 | --batch-size 50 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --ntrain 8000 \ 23 | --save_name pdebench_darcy_FNO -------------------------------------------------------------------------------- /scripts/PDEBench/darcy/MWT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \ 4 | --loader pdebench_steady_darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --scheduler StepLR \ 8 | --downsamplex 1 \ 9 | --downsampley 1 \ 10 | --space_dim 2 \ 11 | --fun_dim 1 \ 12 | --out_dim 1 \ 13 | --model MWT \ 14 | --n_hidden 64 \ 15 | --n_heads 8 \ 16 | --n_layers 8 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 50 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --ntrain 8000 \ 23 | --save_name pdebench_darcy_MWT -------------------------------------------------------------------------------- /scripts/PDEBench/darcy/Transfomer.sh: -------------------------------------------------------------------------------- 1 | python 
run.py \ 2 | --gpu 3 \ 3 | --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \ 4 | --loader pdebench_steady_darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --scheduler StepLR \ 8 | --downsamplex 1 \ 9 | --downsampley 1 \ 10 | --space_dim 2 \ 11 | --fun_dim 1 \ 12 | --out_dim 1 \ 13 | --model Transformer \ 14 | --n_hidden 64 \ 15 | --n_heads 8 \ 16 | --n_layers 8 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 50 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --ntrain 8000 \ 23 | --save_name pdebench_darcy_Transformer -------------------------------------------------------------------------------- /scripts/PDEBench/darcy/Transolver.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \ 4 | --loader pdebench_steady_darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --scheduler StepLR \ 8 | --downsamplex 1 \ 9 | --downsampley 1 \ 10 | --space_dim 2 \ 11 | --fun_dim 1 \ 12 | --out_dim 1 \ 13 | --model Transolver \ 14 | --n_hidden 64 \ 15 | --n_heads 8 \ 16 | --n_layers 8 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 20 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --ntrain 8000 \ 23 | --save_name pdebench_darcy_Transolver -------------------------------------------------------------------------------- /scripts/PDEBench/darcy/U_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \ 4 | --loader pdebench_steady_darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --scheduler StepLR \ 8 | --downsamplex 1 \ 9 | --downsampley 1 \ 10 | --space_dim 2 \ 11 | --fun_dim 1 \ 12 | --out_dim 1 \ 13 | --model U_FNO \ 14 | --n_hidden 64 \ 15 | --n_heads 8 \ 16 | --n_layers 8 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 
50 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --ntrain 8000 \ 23 | --save_name pdebench_darcy_U_FNO -------------------------------------------------------------------------------- /scripts/PDEBench/darcy/U_NO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \ 4 | --loader pdebench_steady_darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --scheduler StepLR \ 8 | --downsamplex 1 \ 9 | --downsampley 1 \ 10 | --space_dim 2 \ 11 | --fun_dim 1 \ 12 | --out_dim 1 \ 13 | --model U_NO \ 14 | --n_hidden 64 \ 15 | --n_heads 8 \ 16 | --n_layers 8 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 50 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --ntrain 8000 \ 23 | --save_name pdebench_darcy_U_NO -------------------------------------------------------------------------------- /scripts/PDEBench/darcy/U_Net.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 5 \ 3 | --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \ 4 | --loader pdebench_steady_darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --scheduler StepLR \ 8 | --downsamplex 1 \ 9 | --downsampley 1 \ 10 | --space_dim 2 \ 11 | --fun_dim 1 \ 12 | --out_dim 1 \ 13 | --model U_Net \ 14 | --n_hidden 64 \ 15 | --n_heads 8 \ 16 | --n_layers 8 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 50 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --ntrain 8000 \ 23 | --save_name pdebench_darcy_U_Net -------------------------------------------------------------------------------- /scripts/PDEBench/diff_sorp/FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \ 4 | --loader pdebench_autoregressive \ 5 | --geotype structured_1D \ 6 | 
--task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.0005 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 1 \ 12 | --fun_dim 10 \ 13 | --out_dim 1 \ 14 | --model FNO \ 15 | --n_hidden 64 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 20 \ 21 | --epochs 500 \ 22 | --eval 1 \ 23 | --save_name diff_sorp_FNO -------------------------------------------------------------------------------- /scripts/PDEBench/diff_sorp/MWT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \ 4 | --loader pdebench_autoregressive \ 5 | --geotype structured_1D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.0005 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 1 \ 12 | --fun_dim 10 \ 13 | --out_dim 1 \ 14 | --model MWT \ 15 | --n_hidden 64 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 50 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --mwt_k 9 \ 24 | --save_name diff_sorp_MWT -------------------------------------------------------------------------------- /scripts/PDEBench/diff_sorp/Transolver.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \ 4 | --loader pdebench_autoregressive \ 5 | --geotype structured_1D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.0005 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 1 \ 12 | --fun_dim 10 \ 13 | --out_dim 1 \ 14 | --model Transolver \ 15 | --n_hidden 64 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 20 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --save_name diff_sorp_Transolver 
-------------------------------------------------------------------------------- /scripts/PDEBench/diff_sorp/U_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \ 4 | --loader pdebench_autoregressive \ 5 | --geotype structured_1D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.0005 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 1 \ 12 | --fun_dim 10 \ 13 | --out_dim 1 \ 14 | --model U_FNO \ 15 | --n_hidden 64 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 20 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --save_name diff_sorp_U_FNO -------------------------------------------------------------------------------- /scripts/PDEBench/diff_sorp/U_NO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \ 4 | --loader pdebench_autoregressive \ 5 | --geotype structured_1D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.0005 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 1 \ 12 | --fun_dim 10 \ 13 | --out_dim 1 \ 14 | --model U_NO \ 15 | --n_hidden 64 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 20 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --save_name diff_sorp_U_NO -------------------------------------------------------------------------------- /scripts/PDEBench/diff_sorp/U_Net.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 7 \ 3 | --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \ 4 | --loader pdebench_autoregressive \ 5 | --geotype structured_1D \ 6 | --task dynamic_autoregressive \ 7 | 
--teacher_forcing 0 \ 8 | --lr 0.001 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 1 \ 12 | --fun_dim 10 \ 13 | --out_dim 1 \ 14 | --model U_Net \ 15 | --n_hidden 64 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 20 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --save_name diff_sorp_U_Net -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --model FNO \ 10 | --n_hidden 32 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --slice_num 64 \ 14 | --unified_pos 0 \ 15 | --ref 8 \ 16 | --batch-size 4 \ 17 | --epochs 500 \ 18 | --vis_bound 40 180 0 35 \ 19 | --eval 0 \ 20 | --save_name airfoil_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/F_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --lr 0.001 \ 7 | --weight_decay 1e-4 \ 8 | --scheduler StepLR \ 9 | --space_dim 2 \ 10 | --fun_dim 0 \ 11 | --out_dim 1 \ 12 | --model F_FNO \ 13 | --n_hidden 32 \ 14 | --n_heads 8 \ 15 | --n_layers 8 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 20 \ 20 | --epochs 500 \ 21 | --vis_bound 40 180 0 35 \ 22 | --eval 0 \ 23 | --save_name airfoil_F_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/Factformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path 
/data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --model Factformer \ 10 | --n_hidden 32 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --slice_num 64 \ 14 | --unified_pos 0 \ 15 | --ref 8 \ 16 | --batch-size 4 \ 17 | --epochs 500 \ 18 | --vis_bound 40 180 0 35 \ 19 | --eval 0 \ 20 | --save_name airfoil_Factformer -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/GNOT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 5 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --model GNOT \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --slice_num 64 \ 14 | --unified_pos 0 \ 15 | --ref 8 \ 16 | --batch-size 4 \ 17 | --epochs 500 \ 18 | --vis_bound 40 180 0 35 \ 19 | --eval 0 \ 20 | --save_name airfoil_GNOT -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/Galerkin_Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --model Galerkin_Transformer \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --slice_num 64 \ 14 | --unified_pos 0 \ 15 | --ref 8 \ 16 | --batch-size 4 \ 17 | --epochs 500 \ 18 | --vis_bound 40 180 0 35 \ 19 | --eval 0 \ 20 | --save_name airfoil_Galerkin_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/LSM.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path 
/data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --model LSM \ 10 | --n_hidden 32 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --slice_num 64 \ 14 | --unified_pos 0 \ 15 | --ref 8 \ 16 | --batch-size 4 \ 17 | --epochs 500 \ 18 | --vis_bound 40 180 0 35 \ 19 | --eval 0 \ 20 | --save_name airfoil_LSM -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/MWT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --model MWT \ 10 | --n_hidden 32 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --slice_num 64 \ 14 | --unified_pos 0 \ 15 | --ref 8 \ 16 | --batch-size 4 \ 17 | --epochs 500 \ 18 | --vis_bound 40 180 0 35 \ 19 | --eval 0 \ 20 | --save_name airfoil_MWT -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/Swin.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model Swin \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --vis_bound 40 180 0 35 \ 20 | --eval 0 \ 21 | --save_name airfoil_Swin -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader 
airfoil \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --model Transformer \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --slice_num 64 \ 14 | --unified_pos 0 \ 15 | --ref 8 \ 16 | --batch-size 4 \ 17 | --epochs 500 \ 18 | --vis_bound 40 180 0 35 \ 19 | --eval 0 \ 20 | --save_name airfoil_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/Transolver.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 6 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --lr 0.001 \ 7 | --weight_decay 1e-4 \ 8 | --space_dim 2 \ 9 | --fun_dim 2 \ 10 | --out_dim 1 \ 11 | --model Transolver \ 12 | --n_hidden 128 \ 13 | --n_heads 8 \ 14 | --n_layers 8 \ 15 | --mlp_ratio 2 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 8 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --max_grad_norm 0.1 \ 23 | --save_name airfoil_Transolver -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/U_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --lr 0.001 \ 7 | --weight_decay 1e-4 \ 8 | --scheduler StepLR \ 9 | --space_dim 2 \ 10 | --fun_dim 0 \ 11 | --out_dim 1 \ 12 | --model U_FNO \ 13 | --n_hidden 32 \ 14 | --n_heads 8 \ 15 | --n_layers 8 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 20 \ 20 | --epochs 500 \ 21 | --vis_bound 40 180 0 35 \ 22 | --eval 0 \ 23 | --save_name airfoil_U_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/U_NO.sh: 
-------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --lr 0.001 \ 7 | --weight_decay 1e-4 \ 8 | --scheduler StepLR \ 9 | --space_dim 2 \ 10 | --fun_dim 0 \ 11 | --out_dim 1 \ 12 | --model U_NO \ 13 | --n_hidden 32 \ 14 | --n_heads 8 \ 15 | --n_layers 8 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 20 \ 20 | --epochs 500 \ 21 | --vis_bound 40 180 0 35 \ 22 | --eval 0 \ 23 | --save_name airfoil_U_NO_AdamW -------------------------------------------------------------------------------- /scripts/StandardBench/airfoil/U_Net.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno/airfoil/naca \ 4 | --loader airfoil \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --model U_Net \ 10 | --n_hidden 32 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --slice_num 64 \ 14 | --unified_pos 0 \ 15 | --ref 8 \ 16 | --batch-size 4 \ 17 | --epochs 500 \ 18 | --vis_bound 40 180 0 35 \ 19 | --eval 0 \ 20 | --save_name airfoil_U_Net -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 7 \ 3 | --data_path /data/fno/ \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model FNO \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --unified_pos 1 \ 20 | --ref 8 \ 21 | --batch-size 4 \ 22 | --epochs 500 \ 23 | --eval 1 \ 24 | --save_name darcy_FNO 
-------------------------------------------------------------------------------- /scripts/StandardBench/darcy/F_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model F_FNO \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 | --unified_pos 1 \ 21 | --ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --max_grad_norm 0.1 \ 26 | --save_name darcy_F_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/Factformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/ \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model Factformer \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --unified_pos 1 \ 20 | --ref 8 \ 21 | --batch-size 4 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name darcy_Factformer -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/GNOT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 7 \ 3 | --data_path /data/fno \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --weight_decay 0.00005 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model GNOT \ 15 | --n_hidden 
128 \ 16 | --n_heads 8 \ 17 | --n_layers 3 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 | --unified_pos 0 \ 21 | --ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --save_name darcy_GNOT -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/Galerkin_Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/fno \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model Galerkin_Transformer \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 | --unified_pos 1 \ 21 | --ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --save_name darcy_Galerkin_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/LSM.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno/ \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --scheduler StepLR \ 7 | --task steady \ 8 | --normalize 1 \ 9 | --lr 0.0005 \ 10 | --optimizer Adam \ 11 | --weight_decay 1e-4 \ 12 | --norm_type UnitGaussianNormalizer \ 13 | --downsamplex 5 \ 14 | --downsampley 5 \ 15 | --space_dim 2 \ 16 | --fun_dim 1 \ 17 | --out_dim 1 \ 18 | --model LSM \ 19 | --n_hidden 64 \ 20 | --n_heads 8 \ 21 | --n_layers 8 \ 22 | --slice_num 64 \ 23 | --unified_pos 0 \ 24 | --ref 8 \ 25 | --batch-size 20 \ 26 | --epochs 500 \ 27 | --eval 0 \ 28 | --save_name darcy_LSM -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/MWT.sh: 
-------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model MWT \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 | --unified_pos 1 \ 21 | --ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --save_name darcy_MWT -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/ONO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/fno \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model ONO \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 10 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 | --unified_pos 0 \ 21 | --ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --psi_dim 32 \ 26 | --max_grad_norm 0.1 \ 27 | --save_name darcy_ONO 28 | 29 | -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/Swin.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model Swin_Transformer \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 
| --unified_pos 1 \ 21 | --ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --save_name darcy_Swin_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/fno \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model Transformer \ 15 | --n_hidden 32 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 | --unified_pos 1 \ 21 | --ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --save_name darcy_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/Transolver.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model Transolver \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 | --unified_pos 1 \ 21 | --ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --save_name darcy_Transolver -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/U_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 5 \ 3 | --data_path /data/fno/ \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | 
--downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model U_FNO \ 15 | --n_hidden 32 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 1 \ 20 | --ref 8 \ 21 | --batch-size 4 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name darcy_U_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/U_NO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/fno/ \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model U_NO \ 15 | --n_hidden 32 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 1 \ 20 | --ref 8 \ 21 | --batch-size 4 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name darcy_U_NO -------------------------------------------------------------------------------- /scripts/StandardBench/darcy/U_Net.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model U_Net \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 | --unified_pos 1 \ 21 | --ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --save_name darcy_U_Net -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 
| --data_path /data/fno/ \ 4 | --loader elas \ 5 | --geotype unstructured \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --normalize 1 \ 10 | --model FNO \ 11 | --n_hidden 32 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --batch-size 4 \ 15 | --epochs 500 \ 16 | --eval 0 \ 17 | --save_name elas_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/F_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/ \ 4 | --loader elas \ 5 | --geotype unstructured \ 6 | --scheduler CosineAnnealingLR \ 7 | --space_dim 2 \ 8 | --fun_dim 0 \ 9 | --out_dim 1 \ 10 | --normalize 1 \ 11 | --model F_FNO \ 12 | --n_hidden 128 \ 13 | --n_heads 8 \ 14 | --n_layers 8 \ 15 | --mlp_ratio 2 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 1 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --max_grad_norm 0.1 \ 23 | --save_name elas_F_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/GNOT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/ \ 4 | --loader elas \ 5 | --geotype unstructured \ 6 | --scheduler CosineAnnealingLR \ 7 | --space_dim 2 \ 8 | --fun_dim 0 \ 9 | --out_dim 1 \ 10 | --normalize 1 \ 11 | --model GNOT \ 12 | --n_hidden 128 \ 13 | --n_heads 8 \ 14 | --n_layers 8 \ 15 | --mlp_ratio 2 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 1 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --save_name elas_GNOT -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/Galerkin_Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/ \ 4 | --loader 
elas \ 5 | --geotype unstructured \ 6 | --scheduler CosineAnnealingLR \ 7 | --space_dim 2 \ 8 | --fun_dim 0 \ 9 | --out_dim 1 \ 10 | --normalize 1 \ 11 | --model Galerkin_Transformer \ 12 | --n_hidden 128 \ 13 | --n_heads 8 \ 14 | --n_layers 8 \ 15 | --mlp_ratio 2 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 1 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --save_name elas_Galerkin_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/LSM.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno/ \ 4 | --loader elas \ 5 | --geotype unstructured \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --normalize 1 \ 10 | --model LSM \ 11 | --n_hidden 32 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --batch-size 4 \ 15 | --epochs 500 \ 16 | --eval 0 \ 17 | --save_name elas_LSM -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/MWT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/fno/ \ 4 | --loader elas \ 5 | --geotype unstructured \ 6 | --scheduler CosineAnnealingLR \ 7 | --space_dim 2 \ 8 | --fun_dim 0 \ 9 | --out_dim 1 \ 10 | --normalize 1 \ 11 | --model MWT \ 12 | --n_hidden 128 \ 13 | --n_heads 8 \ 14 | --n_layers 8 \ 15 | --mlp_ratio 2 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 1 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --save_name elas_MWT -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/ \ 4 | --loader elas \ 5 | --geotype unstructured \ 6 | --scheduler 
CosineAnnealingLR \ 7 | --space_dim 2 \ 8 | --fun_dim 0 \ 9 | --out_dim 1 \ 10 | --normalize 1 \ 11 | --model Transformer \ 12 | --n_hidden 128 \ 13 | --n_heads 8 \ 14 | --n_layers 8 \ 15 | --slice_num 64 \ 16 | --unified_pos 0 \ 17 | --ref 8 \ 18 | --batch-size 1 \ 19 | --epochs 500 \ 20 | --eval 0 \ 21 | --save_name elas_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/Transolver.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/ \ 4 | --loader elas \ 5 | --geotype unstructured \ 6 | --scheduler CosineAnnealingLR \ 7 | --space_dim 2 \ 8 | --fun_dim 0 \ 9 | --out_dim 1 \ 10 | --normalize 1 \ 11 | --model Transolver \ 12 | --n_hidden 128 \ 13 | --n_heads 8 \ 14 | --n_layers 8 \ 15 | --mlp_ratio 2 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 1 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --save_name elas_Transolver -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/U_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/ \ 4 | --loader elas \ 5 | --geotype unstructured \ 6 | --scheduler CosineAnnealingLR \ 7 | --space_dim 2 \ 8 | --fun_dim 0 \ 9 | --out_dim 1 \ 10 | --normalize 1 \ 11 | --model U_FNO \ 12 | --n_hidden 128 \ 13 | --n_heads 8 \ 14 | --n_layers 8 \ 15 | --mlp_ratio 2 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 1 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --save_name elas_U_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/U_NO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/ \ 4 | --loader elas \ 
5 | --geotype unstructured \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --normalize 1 \ 10 | --model U_NO \ 11 | --n_hidden 32 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --batch-size 4 \ 15 | --epochs 500 \ 16 | --eval 0 \ 17 | --save_name elas_U_NO -------------------------------------------------------------------------------- /scripts/StandardBench/elasticity/U_Net.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/ \ 4 | --loader elas \ 5 | --geotype unstructured \ 6 | --scheduler CosineAnnealingLR \ 7 | --space_dim 2 \ 8 | --fun_dim 0 \ 9 | --out_dim 1 \ 10 | --normalize 1 \ 11 | --model U_Net \ 12 | --n_hidden 32 \ 13 | --n_heads 8 \ 14 | --n_layers 8 \ 15 | --mlp_ratio 2 \ 16 | --slice_num 64 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 19 | --batch-size 1 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --save_name elas_U_Net -------------------------------------------------------------------------------- /scripts/StandardBench/ns/FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 5 \ 3 | --data_path /data/fno/ \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.0005 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 2 \ 12 | --fun_dim 10 \ 13 | --out_dim 1 \ 14 | --model FNO \ 15 | --n_hidden 64 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 20 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --save_name ns_FNO_wo_teacher_forcing_wo_unipos_real_steplr -------------------------------------------------------------------------------- /scripts/StandardBench/ns/F_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 
| --scheduler StepLR \ 7 | --task dynamic_autoregressive \ 8 | --space_dim 2 \ 9 | --lr 0.0025 \ 10 | --weight_decay 0.0001 \ 11 | --fun_dim 10 \ 12 | --out_dim 1 \ 13 | --model F_FNO \ 14 | --n_hidden 20 \ 15 | --n_heads 8 \ 16 | --n_layers 8 \ 17 | --mlp_ratio 2 \ 18 | --slice_num 32 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 20 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --max_grad_norm 0.1 \ 25 | --save_name ns_F_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/ns/Factformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --space_dim 2 \ 8 | --fun_dim 10 \ 9 | --out_dim 1 \ 10 | --model Factormer \ 11 | --n_hidden 256 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 1 \ 17 | --ref 8 \ 18 | --batch-size 2 \ 19 | --epochs 500 \ 20 | --eval 0 \ 21 | --save_name ns_Factormer -------------------------------------------------------------------------------- /scripts/StandardBench/ns/GNOT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --space_dim 2 \ 8 | --weight_decay 0.00005 \ 9 | --fun_dim 10 \ 10 | --out_dim 1 \ 11 | --model GNOT \ 12 | --optimizer AdamW \ 13 | --n_hidden 128 \ 14 | --n_heads 8 \ 15 | --n_layers 3 \ 16 | --mlp_ratio 2 \ 17 | --slice_num 32 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 4 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --save_name ns_GNOT -------------------------------------------------------------------------------- /scripts/StandardBench/ns/Galerkin_Transformer.sh: 
-------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --space_dim 2 \ 8 | --fun_dim 10 \ 9 | --out_dim 1 \ 10 | --model Galerkin_Transformer \ 11 | --n_hidden 256 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 1 \ 17 | --ref 8 \ 18 | --batch-size 2 \ 19 | --epochs 500 \ 20 | --eval 0 \ 21 | --save_name ns_Galerkin_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/ns/LSM.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno/ \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.0005 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 2 \ 12 | --fun_dim 10 \ 13 | --out_dim 1 \ 14 | --model LSM \ 15 | --n_hidden 64 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 20 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --save_name ns_LSM_AdamW -------------------------------------------------------------------------------- /scripts/StandardBench/ns/MWT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --space_dim 2 \ 8 | --fun_dim 10 \ 9 | --out_dim 1 \ 10 | --model MWT \ 11 | --n_hidden 256 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 1 \ 17 | --ref 8 \ 18 | --batch-size 2 \ 19 | --epochs 500 \ 20 | --eval 0 \ 21 | --save_name ns_MWT -------------------------------------------------------------------------------- 
/scripts/StandardBench/ns/ONO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 6 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --scheduler OneCycleLR \ 8 | --space_dim 2 \ 9 | --fun_dim 10 \ 10 | --attn_type selfAttention \ 11 | --out_dim 1 \ 12 | --model ONO \ 13 | --n_hidden 128 \ 14 | --n_heads 8 \ 15 | --n_layers 8 \ 16 | --mlp_ratio 1 \ 17 | --slice_num 32 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 8 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --psi_dim 16 \ 24 | --teacher_forcing 0 \ 25 | --max_grad_norm 0.1 \ 26 | --save_name ns_ONO -------------------------------------------------------------------------------- /scripts/StandardBench/ns/Swin.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 7 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --space_dim 2 \ 8 | --fun_dim 10 \ 9 | --out_dim 1 \ 10 | --model Swin_Transformer \ 11 | --n_hidden 128 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 0 \ 17 | --ref 8 \ 18 | --batch-size 2 \ 19 | --epochs 500 \ 20 | --eval 0 \ 21 | --save_name ns_Swin_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/ns/Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --space_dim 2 \ 8 | --fun_dim 10 \ 9 | --out_dim 1 \ 10 | --model Transformer \ 11 | --n_hidden 256 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 1 \ 17 | --ref 8 \ 18 | --batch-size 2 \ 19 | --epochs 500 \ 20 | --eval 0 
\ 21 | --save_name ns_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/ns/Transolver.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --space_dim 2 \ 8 | --fun_dim 10 \ 9 | --out_dim 1 \ 10 | --model Transolver \ 11 | --n_hidden 256 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 1 \ 17 | --ref 8 \ 18 | --batch-size 2 \ 19 | --epochs 500 \ 20 | --eval 0 \ 21 | --save_name ns_Transolver -------------------------------------------------------------------------------- /scripts/StandardBench/ns/U_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 6 \ 3 | --data_path /data/fno/ \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --lr 0.0005 \ 9 | --weight_decay 1e-4 \ 10 | --scheduler StepLR \ 11 | --space_dim 2 \ 12 | --fun_dim 10 \ 13 | --out_dim 1 \ 14 | --model U_FNO \ 15 | --n_hidden 64 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --unified_pos 0 \ 19 | --ref 8 \ 20 | --batch-size 20 \ 21 | --epochs 500 \ 22 | --eval 0 \ 23 | --save_name ns_U_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/ns/U_NO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/ \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --teacher_forcing 0 \ 8 | --scheduler StepLR \ 9 | --optimizer Adam \ 10 | --space_dim 2 \ 11 | --fun_dim 10 \ 12 | --out_dim 1 \ 13 | --model U_NO \ 14 | --n_hidden 64 \ 15 | --n_heads 8 \ 16 | --n_layers 8 \ 17 | --unified_pos 0 \ 18 | --ref 8 \ 
19 | --batch-size 16 \ 20 | --epochs 500 \ 21 | --eval 0 \ 22 | --save_name ns_U_NO_Adam -------------------------------------------------------------------------------- /scripts/StandardBench/ns/U_Net.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 6 \ 3 | --data_path /data/fno \ 4 | --loader ns \ 5 | --geotype structured_2D \ 6 | --task dynamic_autoregressive \ 7 | --space_dim 2 \ 8 | --fun_dim 10 \ 9 | --out_dim 1 \ 10 | --model U_Net \ 11 | --n_hidden 256 \ 12 | --n_heads 8 \ 13 | --n_layers 8 \ 14 | --mlp_ratio 2 \ 15 | --slice_num 32 \ 16 | --unified_pos 1 \ 17 | --ref 8 \ 18 | --batch-size 2 \ 19 | --epochs 500 \ 20 | --eval 0 \ 21 | --save_name ns_U_Net -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model FNO \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name pipe_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/F_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 6 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model F_FNO \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --max_grad_norm 0.1 \ 21 | 
--normalize 1 \ 22 | --save_name pipe_F_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/Factformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/ \ 4 | --loader darcy \ 5 | --geotype structured_2D \ 6 | --task steady \ 7 | --normalize 1 \ 8 | --derivloss 1 \ 9 | --downsamplex 5 \ 10 | --downsampley 5 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 1 \ 14 | --model Factformer \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --unified_pos 1 \ 20 | --ref 8 \ 21 | --batch-size 4 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --normalize 1 \ 25 | --save_name darcy_Factformer -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/GNOT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 6 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model GNOT \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name pipe_GNOT -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/Galerkin_Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model Galerkin_Transformer \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | 
--batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name pipe_Galerkin_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/LSM.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 0 \ 8 | --out_dim 1 \ 9 | --model LSM \ 10 | --n_hidden 32 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --slice_num 64 \ 14 | --unified_pos 0 \ 15 | --ref 8 \ 16 | --batch-size 4 \ 17 | --epochs 500 \ 18 | --eval 0 \ 19 | --normalize 1 \ 20 | --save_name pipe_LSM -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/MWT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 6 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model MWT \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name pipe_MWT -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/ONO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --scheduler OneCycleLR \ 7 | --space_dim 2 \ 8 | --fun_dim 2 \ 9 | --out_dim 1 \ 10 | --normalize 1 \ 11 | --max_grad_norm 0.1 \ 12 | --model ONO \ 13 | --n_hidden 128 \ 14 | --n_heads 8 \ 15 | --n_layers 8 \ 16 | --mlp_ratio 2 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | 
--save_name pipe_ONO -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/Swin.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model Swin_Transformer \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name pipe_Swin_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model Transformer \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name pipe_Transformer -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/Transolver.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 5 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model Transolver \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name 
pipe_Transolver -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/U_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model U_FNO \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name pipe_U_FNO -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/U_NO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 6 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model U_NO \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name pipe_U_NO -------------------------------------------------------------------------------- /scripts/StandardBench/pipe/U_Net.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno/pipe \ 4 | --loader pipe \ 5 | --geotype structured_2D \ 6 | --space_dim 2 \ 7 | --fun_dim 2 \ 8 | --out_dim 1 \ 9 | --model U_Net \ 10 | --n_hidden 128 \ 11 | --n_heads 8 \ 12 | --n_layers 8 \ 13 | --mlp_ratio 2 \ 14 | --slice_num 64 \ 15 | --unified_pos 0 \ 16 | --ref 8 \ 17 | --batch-size 4 \ 18 | --epochs 500 \ 19 | --eval 0 \ 20 | --normalize 1 \ 21 | --save_name pipe_U_Net
-------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 0 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model FNO \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_FNO 25 | -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/F_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 5 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model F_FNO \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --max_grad_norm 0.1 \ 25 | --save_name plas_F_FNO 26 | -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/Factformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 2 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model Factformer \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --mlp_ratio 2 \ 19 | --slice_num 64 \ 20 | --unified_pos 0 \ 21 | 
--ref 8 \ 22 | --batch-size 4 \ 23 | --epochs 500 \ 24 | --eval 0 \ 25 | --save_name plas_Factformer -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/GNOT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model GNOT \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_GNOT 25 | -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/Galerkin_Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 7 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model Galerkin_Transformer \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_Galerkin_Transformer 25 | -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/LSM.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 3 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | 
--time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model LSM \ 15 | --n_hidden 32 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 4 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_LSM -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/MWT.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 7 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model MWT \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_MWT 25 | -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/ONO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 4 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model ONO \ 15 | --n_hidden 32 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 4 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_ONO -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/Swin.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 5 \ 3 | --data_path /data/fno/ \ 4 | --loader 
plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model Swin_Transformer \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_Swin_Transformer 25 | -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/Transformer.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 7 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model Transformer \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_Transformer 25 | -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/Transolver.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 7 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model Transolver \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_Transolver 25 | 
-------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/U_FNO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model U_FNO \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_U_FNO 25 | -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/U_NO.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 5 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model U_NO \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | --n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_U_NO_test 25 | -------------------------------------------------------------------------------- /scripts/StandardBench/plasticity/U_Net.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | --gpu 1 \ 3 | --data_path /data/fno/ \ 4 | --loader plas \ 5 | --geotype structured_2D \ 6 | --task dynamic_conditional \ 7 | --ntrain 900 \ 8 | --ntest 80 \ 9 | --T_out 20 \ 10 | --time_input 1 \ 11 | --space_dim 2 \ 12 | --fun_dim 1 \ 13 | --out_dim 4 \ 14 | --model U_Net \ 15 | --n_hidden 128 \ 16 | --n_heads 8 \ 17 | 
--n_layers 8 \ 18 | --slice_num 64 \ 19 | --unified_pos 0 \ 20 | --ref 8 \ 21 | --batch-size 8 \ 22 | --epochs 500 \ 23 | --eval 0 \ 24 | --save_name plas_U_Net 25 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | 5 | 6 | class L2Loss(object): 7 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 8 | super(L2Loss, self).__init__() 9 | 10 | assert d > 0 and p > 0 11 | 12 | self.d = d 13 | self.p = p 14 | self.reduction = reduction 15 | self.size_average = size_average 16 | 17 | def abs(self, x, y): 18 | num_examples = x.size()[0] 19 | 20 | h = 1.0 / (x.size()[1] - 1.0) 21 | 22 | all_norms = (h ** (self.d / self.p)) * torch.norm(x.view(num_examples, -1) - y.view(num_examples, -1), self.p, 23 | 1) 24 | 25 | if self.reduction: 26 | if self.size_average: 27 | return torch.mean(all_norms) 28 | else: 29 | return torch.sum(all_norms) 30 | 31 | return all_norms 32 | 33 | def rel(self, x, y): 34 | num_examples = x.size()[0] 35 | 36 | diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1) 37 | y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1) 38 | if self.reduction: 39 | if self.size_average: 40 | return torch.mean(diff_norms / y_norms) 41 | else: 42 | return torch.sum(diff_norms / y_norms) 43 | 44 | return diff_norms / y_norms 45 | 46 | def __call__(self, x, y): 47 | return self.rel(x, y) 48 | 49 | 50 | class DerivLoss(object): 51 | def __init__(self, d=2, p=2, size_average=True, reduction=True, shapelist=None): 52 | super(DerivLoss, self).__init__() 53 | 54 | assert d > 0 and p > 0 55 | self.shapelist = shapelist 56 | self.de_x = L2Loss(d=d, p=p, size_average=size_average, reduction=reduction) 57 | self.de_y = L2Loss(d=d, p=p, size_average=size_average, reduction=reduction) 58 | 59 | def 
central_diff(self, x, h1, h2, s1, s2): 60 | # assuming PBC 61 | # x: (batch, n, feats), h is the step size, assuming n = h*w 62 | x = rearrange(x, 'b (h w) c -> b h w c', h=s1, w=s2) 63 | x = F.pad(x, 64 | (0, 0, 1, 1, 1, 1), mode='constant', value=0.) # [b c t h+2 w+2] 65 | grad_x = (x[:, 1:-1, 2:, :] - x[:, 1:-1, :-2, :]) / (2 * h1) # f(x+h) - f(x-h) / 2h 66 | grad_y = (x[:, 2:, 1:-1, :] - x[:, :-2, 1:-1, :]) / (2 * h2) # f(x+h) - f(x-h) / 2h 67 | 68 | return grad_x, grad_y 69 | 70 | def __call__(self, out, y): 71 | out = rearrange(out, 'b (h w) c -> b c h w', h=self.shapelist[0], w=self.shapelist[1]) 72 | out = out[..., 1:-1, 1:-1].contiguous() 73 | out = F.pad(out, (1, 1, 1, 1), "constant", 0) 74 | out = rearrange(out, 'b c h w -> b (h w) c') 75 | gt_grad_x, gt_grad_y = self.central_diff(y, 1.0 / float(self.shapelist[0]), 76 | 1.0 / float(self.shapelist[1]), self.shapelist[0], self.shapelist[1]) 77 | pred_grad_x, pred_grad_y = self.central_diff(out, 1.0 / float(self.shapelist[0]), 78 | 1.0 / float(self.shapelist[1]), self.shapelist[0], 79 | self.shapelist[1]) 80 | deriv_loss = self.de_x(pred_grad_x, gt_grad_x) + self.de_y(pred_grad_y, gt_grad_y) 81 | return deriv_loss 82 | -------------------------------------------------------------------------------- /utils/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import * 3 | 4 | 5 | class IdentityTransformer(): 6 | def __init__(self, X): 7 | self.mean = X.mean(dim=0, keepdim=True) 8 | self.std = X.std(dim=0, keepdim=True) + 1e-8 9 | 10 | def to(self, device): 11 | self.mean = self.mean.to(device) 12 | self.std = self.std.to(device) 13 | return self 14 | 15 | def cuda(self): 16 | self.mean = self.mean.cuda() 17 | self.std = self.std.cuda() 18 | 19 | def cpu(self): 20 | self.mean = self.mean.cpu() 21 | self.std = self.std.cpu() 22 | 23 | def encode(self, x): 24 | return x 25 | 26 | def decode(self, x): 27 | return x 28 | 29 | 30 | class 
UnitTransformer(): 31 | def __init__(self, X): 32 | self.mean = X.mean(dim=(0, 1), keepdim=True) 33 | self.std = X.std(dim=(0, 1), keepdim=True) + 1e-8 34 | 35 | def to(self, device): 36 | self.mean = self.mean.to(device) 37 | self.std = self.std.to(device) 38 | return self 39 | 40 | def cuda(self): 41 | self.mean = self.mean.cuda() 42 | self.std = self.std.cuda() 43 | 44 | def cpu(self): 45 | self.mean = self.mean.cpu() 46 | self.std = self.std.cpu() 47 | 48 | def encode(self, x): 49 | x = (x - self.mean) / (self.std) 50 | return x 51 | 52 | def decode(self, x): 53 | return x * self.std + self.mean 54 | 55 | def transform(self, X, inverse=True, component='all'): 56 | if component == 'all' or 'all-reduce': 57 | if inverse: 58 | orig_shape = X.shape 59 | return (X * (self.std - 1e-8) + self.mean).view(orig_shape) 60 | else: 61 | return (X - self.mean) / self.std 62 | else: 63 | if inverse: 64 | orig_shape = X.shape 65 | return (X * (self.std[:, component] - 1e-8) + self.mean[:, component]).view(orig_shape) 66 | else: 67 | return (X - self.mean[:, component]) / self.std[:, component] 68 | 69 | 70 | class UnitGaussianNormalizer(object): 71 | def __init__(self, x, eps=0.00001, time_last=True): 72 | super(UnitGaussianNormalizer, self).__init__() 73 | 74 | self.mean = torch.mean(x, 0) 75 | self.std = torch.std(x, 0) 76 | self.eps = eps 77 | self.time_last = time_last # if the time dimension is the last dim 78 | 79 | def encode(self, x): 80 | x = (x - self.mean) / (self.std + self.eps) 81 | return x 82 | 83 | def decode(self, x, sample_idx=None): 84 | # sample_idx is the spatial sampling mask 85 | if sample_idx is None: 86 | std = self.std + self.eps # n 87 | mean = self.mean 88 | else: 89 | if self.mean.ndim == sample_idx.ndim or self.time_last: 90 | std = self.std[sample_idx] + self.eps # batch*n 91 | mean = self.mean[sample_idx] 92 | if self.mean.ndim > sample_idx.ndim and not self.time_last: 93 | std = self.std[..., sample_idx] + self.eps # T*batch*n 94 | mean = 
self.mean[..., sample_idx] 95 | # x is in shape of batch*(spatial discretization size) or T*batch*(spatial discretization size) 96 | x = (x * std) + mean 97 | return x 98 | 99 | def to(self, device): 100 | if torch.is_tensor(self.mean): 101 | self.mean = self.mean.to(device) 102 | self.std = self.std.to(device) 103 | else: 104 | self.mean = torch.from_numpy(self.mean).to(device) 105 | self.std = torch.from_numpy(self.std).to(device) 106 | return self 107 | 108 | def cuda(self): 109 | self.mean = self.mean.cuda() 110 | self.std = self.std.cuda() 111 | 112 | def cpu(self): 113 | self.mean = self.mean.cpu() 114 | self.std = self.std.cpu() 115 | --------------------------------------------------------------------------------