├── .gitignore
├── LICENSE
├── README.md
├── data_provider
├── data_factory.py
├── data_loader.py
└── shapenet_utils.py
├── exp
├── exp_basic.py
├── exp_dynamic_autoregressive.py
├── exp_dynamic_conditional.py
├── exp_steady.py
└── exp_steady_design.py
├── layers
├── Basic.py
├── Embedding.py
├── FFNO_Layers.py
├── FNO_Layers.py
├── GeoFNO_Projection.py
├── MWT_Layers.py
├── Neural_Spectral_Block.py
├── Physics_Attention.py
└── UNet_Blocks.py
├── models
├── FNO.py
├── F_FNO.py
├── Factformer.py
├── GNOT.py
├── Galerkin_Transformer.py
├── GraphSAGE.py
├── Graph_UNet.py
├── LSM.py
├── MWT.py
├── ONO.py
├── PointNet.py
├── Swin_Transformer.py
├── Transformer.py
├── Transolver.py
├── U_FNO.py
├── U_NO.py
├── U_Net.py
└── model_factory.py
├── pic
├── logo.png
└── task.png
├── requirements.txt
├── run.py
├── scripts
├── DesignBench
│ └── car
│ │ ├── GNOT.sh
│ │ ├── GraphSAGE.sh
│ │ ├── GraphUNet.sh
│ │ ├── PointNet.sh
│ │ └── Transolver.sh
├── PDEBench
│ ├── 3DCFD
│ │ ├── FNO.sh
│ │ └── U_Net.sh
│ ├── darcy
│ │ ├── FNO.sh
│ │ ├── MWT.sh
│ │ ├── Transfomer.sh
│ │ ├── Transolver.sh
│ │ ├── U_FNO.sh
│ │ ├── U_NO.sh
│ │ └── U_Net.sh
│ └── diff_sorp
│ │ ├── FNO.sh
│ │ ├── MWT.sh
│ │ ├── Transolver.sh
│ │ ├── U_FNO.sh
│ │ ├── U_NO.sh
│ │ └── U_Net.sh
└── StandardBench
│ ├── airfoil
│ ├── FNO.sh
│ ├── F_FNO.sh
│ ├── Factformer.sh
│ ├── GNOT.sh
│ ├── Galerkin_Transformer.sh
│ ├── LSM.sh
│ ├── MWT.sh
│ ├── Swin.sh
│ ├── Transformer.sh
│ ├── Transolver.sh
│ ├── U_FNO.sh
│ ├── U_NO.sh
│ └── U_Net.sh
│ ├── darcy
│ ├── FNO.sh
│ ├── F_FNO.sh
│ ├── Factformer.sh
│ ├── GNOT.sh
│ ├── Galerkin_Transformer.sh
│ ├── LSM.sh
│ ├── MWT.sh
│ ├── ONO.sh
│ ├── Swin.sh
│ ├── Transformer.sh
│ ├── Transolver.sh
│ ├── U_FNO.sh
│ ├── U_NO.sh
│ └── U_Net.sh
│ ├── elasticity
│ ├── FNO.sh
│ ├── F_FNO.sh
│ ├── GNOT.sh
│ ├── Galerkin_Transformer.sh
│ ├── LSM.sh
│ ├── MWT.sh
│ ├── Transformer.sh
│ ├── Transolver.sh
│ ├── U_FNO.sh
│ ├── U_NO.sh
│ └── U_Net.sh
│ ├── ns
│ ├── FNO.sh
│ ├── F_FNO.sh
│ ├── Factformer.sh
│ ├── GNOT.sh
│ ├── Galerkin_Transformer.sh
│ ├── LSM.sh
│ ├── MWT.sh
│ ├── ONO.sh
│ ├── Swin.sh
│ ├── Transformer.sh
│ ├── Transolver.sh
│ ├── U_FNO.sh
│ ├── U_NO.sh
│ └── U_Net.sh
│ ├── pipe
│ ├── FNO.sh
│ ├── F_FNO.sh
│ ├── Factformer.sh
│ ├── GNOT.sh
│ ├── Galerkin_Transformer.sh
│ ├── LSM.sh
│ ├── MWT.sh
│ ├── ONO.sh
│ ├── Swin.sh
│ ├── Transformer.sh
│ ├── Transolver.sh
│ ├── U_FNO.sh
│ ├── U_NO.sh
│ └── U_Net.sh
│ └── plasticity
│ ├── FNO.sh
│ ├── F_FNO.sh
│ ├── Factformer.sh
│ ├── GNOT.sh
│ ├── Galerkin_Transformer.sh
│ ├── LSM.sh
│ ├── MWT.sh
│ ├── ONO.sh
│ ├── Swin.sh
│ ├── Transformer.sh
│ ├── Transolver.sh
│ ├── U_FNO.sh
│ ├── U_NO.sh
│ └── U_Net.sh
└── utils
├── drag_coefficient.py
├── loss.py
├── normalizer.py
└── visual.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # .DS_Store files
10 | .DS_Store
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # UV
101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | #uv.lock
105 |
106 | # poetry
107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108 | # This is especially recommended for binary packages to ensure reproducibility, and is more
109 | # commonly ignored for libraries.
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111 | #poetry.lock
112 |
113 | # pdm
114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115 | #pdm.lock
116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117 | # in version control.
118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
119 | .pdm.toml
120 | .pdm-python
121 | .pdm-build/
122 |
123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124 | __pypackages__/
125 |
126 | # Celery stuff
127 | celerybeat-schedule
128 | celerybeat.pid
129 |
130 | # SageMath parsed files
131 | *.sage.py
132 |
133 | # Environments
134 | .env
135 | .venv
136 | env/
137 | venv/
138 | ENV/
139 | env.bak/
140 | venv.bak/
141 |
142 | # Spyder project settings
143 | .spyderproject
144 | .spyproject
145 |
146 | # Rope project settings
147 | .ropeproject
148 |
149 | # mkdocs documentation
150 | /site
151 |
152 | # mypy
153 | .mypy_cache/
154 | .dmypy.json
155 | dmypy.json
156 |
157 | # Pyre type checker
158 | .pyre/
159 |
160 | # pytype static type analyzer
161 | .pytype/
162 |
163 | # Cython debug symbols
164 | cython_debug/
165 |
166 | # PyCharm
167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169 | # and can be added to the global gitignore or merged into this file. For a more nuclear
170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171 | #.idea/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
176 | # results and checkpoints
177 | results/
178 | checkpoints/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 THUML @ Tsinghua University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural-Solver-Library (NeuralSolver)
2 |
3 | NeuralSolver is an open-source library for deep learning researchers, especially for neural PDE solvers.
4 |
5 | :triangular_flag_on_post:**News** (2025.03) We release NeuralSolver, a simple and neat code base for benchmarking neural PDE solvers, extended from our previous GitHub repository [Transolver](https://github.com/thuml/Transolver).
6 |
7 | ## Features
8 |
9 | This library currently supports the following benchmarks:
10 |
11 | - Six Standard Benchmarks from [[FNO]](https://arxiv.org/abs/2010.08895) and [[geo-FNO]](https://arxiv.org/abs/2207.05209)
12 | - PDEBench [[NeurIPS 2022 Track Datasets and Benchmarks]](https://arxiv.org/abs/2210.07182) for benchmarking autoregressive tasks
13 | - ShapeNet-Car from [[TOG 2018]](https://dl.acm.org/doi/abs/10.1145/3197517.3201325) for benchmarking industrial design tasks
14 |
15 |
16 |
17 |
18 | Figure 1. Examples of supported PDE-solving tasks.
19 |
20 |
21 | ## Supported Neural Solvers
22 |
23 | Here is the list of supported neural PDE solvers:
24 |
25 | - [x] **Transolver** - Transolver: A Fast Transformer Solver for PDEs on General Geometries [[ICML 2024]](https://arxiv.org/abs/2402.02366) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Transolver.py)
26 | - [x] **ONO** - Improved Operator Learning by Orthogonal Attention [[ICML 2024]](https://arxiv.org/abs/2310.12487v3) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/ONO.py)
27 | - [x] **Factformer** - Scalable Transformer for PDE Surrogate Modeling [[NeurIPS 2023]](https://arxiv.org/abs/2305.17560) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Factformer.py)
28 | - [x] **U-NO** - U-NO: U-shaped Neural Operators [[TMLR 2023]](https://openreview.net/pdf?id=j3oQF9coJd) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/U_NO.py)
29 | - [x] **LSM** - Solving High-Dimensional PDEs with Latent Spectral Models [[ICML 2023]](https://arxiv.org/pdf/2301.12664) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/LSM.py)
30 | - [x] **GNOT** - GNOT: A General Neural Operator Transformer for Operator Learning [[ICML 2023]](https://arxiv.org/abs/2302.14376) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/GNOT.py)
31 | - [x] **F-FNO** - Factorized Fourier Neural Operators [[ICLR 2023]](https://arxiv.org/abs/2111.13802) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/F_FNO.py)
32 | - [x] **U-FNO** - An enhanced Fourier neural operator-based deep-learning model for multiphase flow [[Advances in Water Resources 2022]](https://www.sciencedirect.com/science/article/pii/S0309170822000562) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/U_FNO.py)
33 | - [x] **Galerkin Transformer** - Choose a Transformer: Fourier or Galerkin [[NeurIPS 2021]](https://arxiv.org/abs/2105.14995) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Galerkin_Transformer.py)
34 | - [x] **MWT** - Multiwavelet-based Operator Learning for Differential Equations [[NeurIPS 2021]](https://openreview.net/forum?id=LZDiWaC9CGL) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/MWT.py)
35 | - [x] **FNO** - Fourier Neural Operator for Parametric Partial Differential Equations [[ICLR 2021]](https://arxiv.org/pdf/2010.08895) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/FNO.py)
36 | - [x] **Transformer** - Attention Is All You Need [[NeurIPS 2017]](https://arxiv.org/pdf/1706.03762) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Transformer.py)
37 |
38 | Some vision backbones can be good baselines for tasks in structured geometries:
39 |
40 | - [x] **Swin Transformer** - Swin Transformer: Hierarchical Vision Transformer using Shifted Windows [[ICCV 2021]](https://arxiv.org/abs/2103.14030) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Swin_Transformer.py)
41 | - [x] **U-Net** - U-Net: Convolutional Networks for Biomedical Image Segmentation [[MICCAI 2015]](https://arxiv.org/pdf/1505.04597) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/U_Net.py)
42 |
43 | Some classical geometric deep models are also included for design tasks:
44 |
45 | - [x] **Graph-UNet** - Graph U-Nets [[ICML 2019]](https://arxiv.org/pdf/1905.05178) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/Graph_UNet.py)
46 | - [x] **GraphSAGE** - Inductive Representation Learning on Large Graphs [[NeurIPS 2017]](https://arxiv.org/pdf/1706.02216) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/GraphSAGE.py)
47 | - [x] **PointNet** - PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation [[CVPR 2017]](https://arxiv.org/pdf/1612.00593) [[Code]](https://github.com/thuml/Neural-Solver-Library/blob/main/models/PointNet.py)
48 |
49 | 🌟 We have made a great effort to ensure good reproducibility, and are glad to report that the official results of all the above methods can be fully reproduced (and sometimes even surpassed) with this library.
50 |
51 | ## Usage
52 |
53 | 1. Install Python 3.8, then install the dependencies with the following command.
54 |
55 | ```bash
56 | pip install -r requirements.txt
57 | ```
58 |
59 | 2. Prepare Data
60 | 3. Train and evaluate the model. We provide the experiment scripts for all benchmarks under the folder `./scripts/`. You can reproduce the experiment results as follows:
61 |
62 | ```bash
63 | bash ./scripts/StandardBench/airfoil/Transolver.sh
64 | ```
65 |
66 | 4. Develop your own model.
67 |
68 | - Add the model file to the folder `./models`. You can follow the `./models/Transolver.py`.
69 | - Include the newly added model in the `model_dict` of `./models/model_factory.py`.
70 | - Create the corresponding scripts under the folder `./scripts`, where you can set hyperparameters following the provided scripts of other models.
71 |
72 | ## Citation
73 |
74 | If you find this repo useful, please cite our paper.
75 |
76 | ```
77 | @inproceedings{wu2024Transolver,
78 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries},
79 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long},
80 | booktitle={International Conference on Machine Learning},
81 | year={2024}
82 | }
83 | ```
84 |
85 | ## Contact
86 |
87 | If you have any questions or want to use the code, please contact our team or open an issue.
88 |
89 | Current maintenance team:
90 |
91 | - Haixu Wu (Ph.D. student, [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn))
92 | - Yuanxu Sun (Undergraduate, sunyuanx22@mails.tsinghua.edu.cn)
93 | - Hang Zhou (Master student, zhou-h23@mails.tsinghua.edu.cn)
94 | - Yuezhou Ma (Ph.D. student, mayz24@mails.tsinghua.edu.cn)
95 |
96 | ## Acknowledgement
97 |
98 | We greatly appreciate the following GitHub repositories for their valuable code bases and datasets:
99 |
100 | https://github.com/thuml/Transolver
101 |
102 | https://github.com/thuml/Latent-Spectral-Models
103 |
104 | https://github.com/neuraloperator/neuraloperator
105 |
106 | https://github.com/neuraloperator/Geo-FNO
107 |
--------------------------------------------------------------------------------
/data_provider/data_factory.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_loader import airfoil, ns, darcy, pipe, elas, plas, pdebench_autoregressive, \
2 | pdebench_steady_darcy, car_design, cfd3d
3 |
4 |
def get_data(args):
    """Build the dataset selected by ``args.loader`` and return its loaders.

    Args:
        args: experiment namespace; ``args.loader`` picks the dataset class and
            the whole namespace is forwarded to that class's constructor.

    Returns:
        ``(dataset, train_loader, test_loader, shapelist)`` where the last three
        come from ``dataset.get_loader()``.

    Raises:
        KeyError: if ``args.loader`` is not one of the registered names.
    """
    registry = dict(
        car_design=car_design,
        pdebench_autoregressive=pdebench_autoregressive,
        pdebench_steady_darcy=pdebench_steady_darcy,
        elas=elas,
        pipe=pipe,
        airfoil=airfoil,
        darcy=darcy,
        ns=ns,
        plas=plas,
        cfd3d=cfd3d,
    )
    loader_cls = registry[args.loader]
    dataset = loader_cls(args)
    loaders = dataset.get_loader()
    train_loader, test_loader, shapelist = loaders
    return dataset, train_loader, test_loader, shapelist
21 |
--------------------------------------------------------------------------------
/exp/exp_basic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from models.model_factory import get_model
4 | from data_provider.data_factory import get_data
5 |
6 |
def count_parameters(model):
    """Print and return the number of trainable (``requires_grad``) parameters of ``model``."""
    trainable = (param.numel() for _name, param in model.named_parameters() if param.requires_grad)
    total_params = sum(trainable)
    print(f"Total Trainable Params: {total_params}")
    return total_params
15 |
16 |
class Exp_Basic:
    """Shared setup for every experiment driver.

    Loads the data selected by ``args``, builds the model and moves it to CUDA,
    then prints the configuration, the model and its trainable-parameter count.
    The ``shapelist`` returned by ``get_data`` is written onto ``args`` before
    the model is constructed (presumably so ``get_model`` can read it — confirm).
    """

    def __init__(self, args):
        dataset, train_loader, test_loader, shapelist = get_data(args)
        self.dataset = dataset
        self.train_loader = train_loader
        self.test_loader = test_loader
        args.shapelist = shapelist

        model = get_model(args)
        self.model = model.cuda()
        self.args = args

        print(self.args)
        print(self.model)
        count_parameters(self.model)

    def vali(self):
        """Validation hook; subclasses override it. Does nothing here."""

    def train(self):
        """Training hook; subclasses override it. Does nothing here."""

    def test(self):
        """Test hook; subclasses override it. Does nothing here."""
34 |
--------------------------------------------------------------------------------
/exp/exp_dynamic_autoregressive.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from exp.exp_basic import Exp_Basic
4 | from models.model_factory import get_model
5 | from data_provider.data_factory import get_data
6 | from utils.loss import L2Loss
7 | import matplotlib.pyplot as plt
8 | from utils.visual import visual
9 | import numpy as np
10 |
11 |
class Exp_Dynamic_Autoregressive(Exp_Basic):
    """Experiment driver for autoregressive time-dependent benchmarks.

    At each of the ``args.T_out`` rollout steps the model predicts the next
    ``args.out_dim`` channels from the coordinates ``x`` and a window ``fx`` of
    previous states. The window is then shifted by ``out_dim`` channels and the
    newest state (the prediction, or the ground truth under teacher forcing) is
    appended to it.
    """

    def __init__(self, args):
        super(Exp_Dynamic_Autoregressive, self).__init__(args)

    def _build_optimizer(self):
        """Create the optimizer named by ``args.optimizer``.

        Raises:
            ValueError: if ``args.optimizer`` is neither 'AdamW' nor 'Adam'.
        """
        if self.args.optimizer == 'AdamW':
            return torch.optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        if self.args.optimizer == 'Adam':
            return torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        raise ValueError('Optimizer only AdamW or Adam')

    def _build_scheduler(self, optimizer):
        """Create the LR scheduler named by ``args.scheduler``.

        Returns ``None`` for any other name. The learning rate then stays
        constant, because ``train`` only calls ``step()`` for these three names.
        """
        if self.args.scheduler == 'OneCycleLR':
            return torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.args.lr, epochs=self.args.epochs,
                                                       steps_per_epoch=len(self.train_loader),
                                                       pct_start=self.args.pct_start)
        if self.args.scheduler == 'CosineAnnealingLR':
            return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args.epochs)
        if self.args.scheduler == 'StepLR':
            return torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args.step_size, gamma=self.args.gamma)
        return None

    def _save_checkpoint(self, message):
        """Print ``message`` and save the model weights to ``./checkpoints/<save_name>.pt``."""
        os.makedirs('./checkpoints', exist_ok=True)
        print(message)
        torch.save(self.model.state_dict(), os.path.join('./checkpoints', self.args.save_name + '.pt'))

    def _shift_window(self, fx, newest):
        """Drop the oldest ``out_dim`` channels of ``fx`` and append ``newest``.

        When the model takes no function input (``fun_dim == 0``), ``fx`` is
        ``None`` and is returned unchanged. Indexing ``None`` would raise a
        TypeError.
        """
        if fx is None:
            return None
        return torch.cat((fx[..., self.args.out_dim:], newest), dim=-1)

    def vali(self):
        """Roll the model out over the test loader and return the error normalised by ``args.ntest``.

        The per-batch losses (``L2Loss(size_average=False)``, presumably summed
        over the batch) are accumulated and divided by ``args.ntest``.
        """
        myloss = L2Loss(size_average=False)
        test_l2_full = 0
        self.model.eval()
        with torch.no_grad():
            for x, fx, yy in self.test_loader:
                x, fx, yy = x.cuda(), fx.cuda(), yy.cuda()
                if self.args.fun_dim == 0:
                    fx = None
                preds = []
                for _ in range(self.args.T_out):
                    im = self.model(x, fx=fx)
                    preds.append(im)
                    fx = self._shift_window(fx, im)
                # Concatenate once instead of growing a tensor every step.
                pred = torch.cat(preds, -1)
                if self.args.normalize:
                    pred = self.dataset.y_normalizer.decode(pred)
                test_l2_full += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item()
        test_loss_full = test_l2_full / (self.args.ntest)
        return test_loss_full

    def train(self):
        """Train with a multi-step rollout loss, validating every epoch.

        Checkpoints are saved every 100 epochs and once more at the end.
        """
        optimizer = self._build_optimizer()
        scheduler = self._build_scheduler(optimizer)
        myloss = L2Loss(size_average=False)

        for ep in range(self.args.epochs):
            self.model.train()
            train_l2_step = 0
            train_l2_full = 0

            for pos, fx, yy in self.train_loader:
                loss = 0
                x, fx, yy = pos.cuda(), fx.cuda(), yy.cuda()
                if self.args.fun_dim == 0:
                    fx = None
                preds = []
                for t in range(self.args.T_out):
                    y = yy[..., self.args.out_dim * t:self.args.out_dim * (t + 1)]
                    im = self.model(x, fx=fx)
                    loss += myloss(im.reshape(x.shape[0], -1), y.reshape(x.shape[0], -1))
                    preds.append(im)
                    # Teacher forcing feeds the ground truth into the next step. Otherwise the
                    # model consumes its own (non-detached) prediction, so gradients flow through the rollout.
                    fx = self._shift_window(fx, y if self.args.teacher_forcing else im)
                pred = torch.cat(preds, -1)

                train_l2_step += loss.item()
                train_l2_full += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item()
                optimizer.zero_grad()
                loss.backward()

                if self.args.max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                optimizer.step()

                # OneCycleLR is sized with steps_per_epoch=len(train_loader), so it advances per batch.
                if self.args.scheduler == 'OneCycleLR':
                    scheduler.step()
            # Epoch-based schedulers advance once per epoch.
            if self.args.scheduler in ('CosineAnnealingLR', 'StepLR'):
                scheduler.step()

            train_loss_step = train_l2_step / (self.args.ntrain * float(self.args.T_out))
            train_loss_full = train_l2_full / (self.args.ntrain)
            print("Epoch {} Train loss step : {:.5f} Train loss full : {:.5f}".format(ep, train_loss_step,
                                                                                       train_loss_full))

            test_loss_full = self.vali()
            print("Epoch {} Test loss full : {:.5f}".format(ep, test_loss_full))

            if ep % 100 == 0:
                self._save_checkpoint('save models')

        self._save_checkpoint('final save models')

    def test(self):
        """Load the saved checkpoint, report the rollout error and visualise some batches.

        Batch indices start at 1 and a batch is visualised while its index is
        below ``args.vis_num``.
        NOTE(review): this covers the first ``vis_num - 1`` batches; confirm intended.
        """
        self.model.load_state_dict(torch.load("./checkpoints/" + self.args.save_name + ".pt"))
        self.model.eval()
        os.makedirs('./results/' + self.args.save_name + '/', exist_ok=True)

        rel_err = 0.0
        myloss = L2Loss(size_average=False)
        with torch.no_grad():
            for batch_idx, (x, fx, yy) in enumerate(self.test_loader, start=1):
                x, fx, yy = x.cuda(), fx.cuda(), yy.cuda()
                if self.args.fun_dim == 0:
                    fx = None
                preds = []
                for _ in range(self.args.T_out):
                    im = self.model(x, fx=fx)
                    fx = self._shift_window(fx, im)
                    preds.append(im)
                pred = torch.cat(preds, -1)
                if self.args.normalize:
                    pred = self.dataset.y_normalizer.decode(pred)
                rel_err += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item()
                if batch_idx < self.args.vis_num:
                    print('visual: ', batch_idx)
                    for t in range(self.args.T_out):
                        visual(x, yy[:, :, self.args.out_dim * t:self.args.out_dim * (t + 1)],
                               pred[:, :, self.args.out_dim * t:self.args.out_dim * (t + 1)], self.args,
                               str(batch_idx) + '_' + str(t))

        rel_err /= self.args.ntest
        print("rel_err:{}".format(rel_err))
147 |
--------------------------------------------------------------------------------
/exp/exp_dynamic_conditional.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from exp.exp_basic import Exp_Basic
4 | from models.model_factory import get_model
5 | from data_provider.data_factory import get_data
6 | from utils.loss import L2Loss
7 | import matplotlib.pyplot as plt
8 | from utils.visual import visual
9 | import numpy as np
10 |
11 |
class Exp_Dynamic_Conditional(Exp_Basic):
    """Experiment driver for time-conditioned dynamic benchmarks.

    Each of the ``args.T_out`` output steps is predicted directly from the
    inputs plus that step's time value ``time[:, t]`` (passed to the model as
    ``T``). Predictions are never fed back into the model.
    """

    def __init__(self, args):
        super(Exp_Dynamic_Conditional, self).__init__(args)

    def _build_optimizer(self):
        """Create the optimizer named by ``args.optimizer``.

        Raises:
            ValueError: if ``args.optimizer`` is neither 'AdamW' nor 'Adam'.
        """
        if self.args.optimizer == 'AdamW':
            return torch.optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        if self.args.optimizer == 'Adam':
            return torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        raise ValueError('Optimizer only AdamW or Adam')

    def _build_scheduler(self, optimizer):
        """Create the LR scheduler named by ``args.scheduler``.

        Returns ``None`` for any other name. ``train`` only calls ``step()`` for
        these three names.
        """
        if self.args.scheduler == 'OneCycleLR':
            return torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.args.lr, epochs=self.args.epochs,
                                                       steps_per_epoch=len(self.train_loader),
                                                       pct_start=self.args.pct_start)
        if self.args.scheduler == 'CosineAnnealingLR':
            return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args.epochs)
        if self.args.scheduler == 'StepLR':
            return torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args.step_size, gamma=self.args.gamma)
        return None

    def _save_checkpoint(self, message):
        """Print ``message`` and save the model weights to ``./checkpoints/<save_name>.pt``."""
        os.makedirs('./checkpoints', exist_ok=True)
        print(message)
        torch.save(self.model.state_dict(), os.path.join('./checkpoints', self.args.save_name + '.pt'))

    def _predict_all_steps(self, x, time, fx):
        """Run the model once per output step and concatenate the results on the last axis."""
        preds = []
        for t in range(self.args.T_out):
            input_T = time[:, t:t + 1].reshape(x.shape[0], 1)
            preds.append(self.model(x, fx=fx, T=input_T))
        return torch.cat(preds, -1)

    def vali(self):
        """Evaluate on the test loader and return the accumulated loss divided by ``args.ntest``."""
        myloss = L2Loss(size_average=False)
        test_l2_full = 0
        self.model.eval()
        with torch.no_grad():
            for x, time, fx, yy in self.test_loader:
                x, time, fx, yy = x.cuda(), time.cuda(), fx.cuda(), yy.cuda()
                if self.args.fun_dim == 0:
                    fx = None
                pred = self._predict_all_steps(x, time, fx)
                if self.args.normalize:
                    pred = self.dataset.y_normalizer.decode(pred)
                test_l2_full += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item()
        test_loss_full = test_l2_full / (self.args.ntest)
        return test_loss_full

    def train(self):
        """Train with one optimizer update per output time step, validating every epoch.

        Checkpoints are saved every 100 epochs and once more at the end.
        """
        optimizer = self._build_optimizer()
        scheduler = self._build_scheduler(optimizer)
        myloss = L2Loss(size_average=False)

        for ep in range(self.args.epochs):
            self.model.train()
            train_l2_step = 0

            for pos, time, fx, yy in self.train_loader:
                x, time, fx, yy = pos.cuda(), time.cuda(), fx.cuda(), yy.cuda()
                if self.args.fun_dim == 0:
                    fx = None
                for t in range(self.args.T_out):
                    y = yy[..., self.args.out_dim * t:self.args.out_dim * (t + 1)]
                    input_T = time[:, t:t + 1].reshape(x.shape[0], 1)
                    im = self.model(x, fx=fx, T=input_T)
                    loss = myloss(im.reshape(x.shape[0], -1), y.reshape(x.shape[0], -1))
                    train_l2_step += loss.item()
                    optimizer.zero_grad()
                    loss.backward()

                    if self.args.max_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                    optimizer.step()

                # OneCycleLR is sized with steps_per_epoch=len(train_loader), so it advances once per batch.
                if self.args.scheduler == 'OneCycleLR':
                    scheduler.step()
            if self.args.scheduler in ('CosineAnnealingLR', 'StepLR'):
                scheduler.step()

            train_loss_step = train_l2_step / (self.args.ntrain * float(self.args.T_out))
            print("Epoch {} Train loss step : {:.5f} ".format(ep, train_loss_step))

            test_loss_full = self.vali()
            print("Epoch {} Test loss full : {:.5f}".format(ep, test_loss_full))

            if ep % 100 == 0:
                self._save_checkpoint('save models')

        self._save_checkpoint('final save models')

    def test(self):
        """Load the saved checkpoint, report the error and visualise some batches.

        Batch indices start at 1 and a batch is visualised while its index is
        below ``args.vis_num``.
        NOTE(review): this covers the first ``vis_num - 1`` batches; confirm intended.
        """
        self.model.load_state_dict(torch.load("./checkpoints/" + self.args.save_name + ".pt"))
        self.model.eval()
        os.makedirs('./results/' + self.args.save_name + '/', exist_ok=True)

        rel_err = 0.0
        myloss = L2Loss(size_average=False)
        with torch.no_grad():
            for batch_idx, (x, time, fx, yy) in enumerate(self.test_loader, start=1):
                # Presumed shapes (from the original comment): x (B, N, 2), yy (B, N, T) — confirm per dataset.
                x, time, fx, yy = x.cuda(), time.cuda(), fx.cuda(), yy.cuda()
                if self.args.fun_dim == 0:
                    fx = None
                pred = self._predict_all_steps(x, time, fx)
                if self.args.normalize:
                    pred = self.dataset.y_normalizer.decode(pred)
                rel_err += myloss(pred.reshape(x.shape[0], -1), yy.reshape(x.shape[0], -1)).item()

                if batch_idx < self.args.vis_num:
                    print('visual: ', batch_idx)
                    # Plots the magnitude of the last two channels of target and prediction.
                    visual(yy[:, :, -4:-2], torch.sqrt(yy[:, :, -1:] ** 2 + yy[:, :, -2:-1] ** 2),
                           torch.sqrt(pred[:, :, -1:] ** 2 + pred[:, :, -2:-1] ** 2), self.args, batch_idx)

        rel_err /= self.args.ntest
        print("rel_err:{}".format(rel_err))
131 |
--------------------------------------------------------------------------------
/exp/exp_steady.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from exp.exp_basic import Exp_Basic
4 | from models.model_factory import get_model
5 | from data_provider.data_factory import get_data
6 | from utils.loss import L2Loss, DerivLoss
7 | import matplotlib.pyplot as plt
8 | from utils.visual import visual
9 | import numpy as np
10 |
11 |
12 | class Exp_Steady(Exp_Basic):
13 | def __init__(self, args):
14 | super(Exp_Steady, self).__init__(args)
15 |
16 | def vali(self):
17 | myloss = L2Loss(size_average=False)
18 | self.model.eval()
19 | rel_err = 0.0
20 | with torch.no_grad():
21 | for pos, fx, y in self.test_loader:
22 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda()
23 | if self.args.fun_dim == 0:
24 | fx = None
25 | out = self.model(x, fx)
26 | if self.args.normalize:
27 | out = self.dataset.y_normalizer.decode(out)
28 |
29 | tl = myloss(out, y).item()
30 | rel_err += tl
31 |
32 | rel_err /= self.args.ntest
33 | return rel_err
34 |
def train(self):
    """Optimise ``self.model`` on ``self.train_loader`` for ``args.epochs`` epochs.

    The optimiser (``AdamW`` or ``Adam``) and the optional LR scheduler are
    selected from ``self.args``: ``OneCycleLR`` is stepped after every batch,
    ``CosineAnnealingLR`` / ``StepLR`` after every epoch, and any other
    scheduler name disables scheduling. After each epoch the mean training
    loss and the result of ``self.vali()`` are printed. The model state dict is
    written to ``./checkpoints/<save_name>.pt`` every 100 epochs (starting at
    epoch 0) and once more after the final epoch.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``'AdamW'`` nor ``'Adam'``.
    """
    args = self.args

    optimizer_classes = {'AdamW': torch.optim.AdamW, 'Adam': torch.optim.Adam}
    if args.optimizer not in optimizer_classes:
        raise ValueError('Optimizer only AdamW or Adam')
    optimizer = optimizer_classes[args.optimizer](self.model.parameters(), lr=args.lr,
                                                  weight_decay=args.weight_decay)

    lr_sched = torch.optim.lr_scheduler
    if args.scheduler == 'OneCycleLR':
        scheduler = lr_sched.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs,
                                        steps_per_epoch=len(self.train_loader),
                                        pct_start=args.pct_start)
    elif args.scheduler == 'CosineAnnealingLR':
        scheduler = lr_sched.CosineAnnealingLR(optimizer, T_max=args.epochs)
    elif args.scheduler == 'StepLR':
        scheduler = lr_sched.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    # For an unrecognised scheduler name neither flag is set, so no scheduler is stepped.
    step_each_batch = args.scheduler == 'OneCycleLR'
    step_each_epoch = args.scheduler in ('CosineAnnealingLR', 'StepLR')

    myloss = L2Loss(size_average=False)
    regloss = DerivLoss(size_average=False, shapelist=args.shapelist) if args.derivloss else None

    def save_checkpoint(message):
        # Create the checkpoint directory lazily, then overwrite the single checkpoint file.
        if not os.path.exists('./checkpoints'):
            os.makedirs('./checkpoints')
        print(message)
        torch.save(self.model.state_dict(), os.path.join('./checkpoints', args.save_name + '.pt'))

    for ep in range(args.epochs):
        self.model.train()
        train_loss = 0

        for pos, fx, y in self.train_loader:
            x, fx, y = pos.cuda(), fx.cuda(), y.cuda()
            if args.fun_dim == 0:
                fx = None
            out = self.model(x, fx)
            if args.normalize:
                # NOTE(review): the target is decoded here as well, unlike in vali()/test();
                # presumably the training targets are stored normalised. Confirm.
                out = self.dataset.y_normalizer.decode(out)
                y = self.dataset.y_normalizer.decode(y)

            loss = myloss(out, y)
            if regloss is not None:
                loss = loss + 0.1 * regloss(out, y)

            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            if args.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.max_grad_norm)
            optimizer.step()

            if step_each_batch:
                scheduler.step()
        if step_each_epoch:
            scheduler.step()

        train_loss = train_loss / args.ntrain
        print("Epoch {} Train loss : {:.5f}".format(ep, train_loss))

        rel_err = self.vali()
        print("rel_err:{}".format(rel_err))

        if ep % 100 == 0:
            save_checkpoint('save models')

    save_checkpoint('final save models')
103 |
def test(self):
    """Evaluate the saved checkpoint on ``self.test_loader`` and print the mean loss.

    The weights are loaded from ``./checkpoints/<save_name>.pt``, and
    ``./results/<save_name>/`` is created if missing. Predictions are decoded
    with ``y_normalizer`` when ``args.normalize`` is set. The per-batch
    ``L2Loss(size_average=False)`` values are summed and divided by
    ``args.ntest``. Batches numbered 1 .. ``vis_num - 1`` are passed to
    ``visual``.
    """
    args = self.args
    self.model.load_state_dict(torch.load("./checkpoints/" + args.save_name + ".pt"))
    self.model.eval()
    result_dir = './results/' + args.save_name + '/'
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    rel_err = 0.0
    myloss = L2Loss(size_average=False)
    with torch.no_grad():
        # Batch numbering starts at 1 (it is also the id handed to ``visual``).
        for batch_no, (pos, fx, y) in enumerate(self.test_loader, start=1):
            x, fx, y = pos.cuda(), fx.cuda(), y.cuda()
            if args.fun_dim == 0:
                fx = None
            out = self.model(x, fx)
            if args.normalize:
                out = self.dataset.y_normalizer.decode(out)
            rel_err += myloss(out, y).item()
            if batch_no < args.vis_num:
                print('visual: ', batch_no)
                visual(x, y, out, args, batch_no)

    rel_err /= args.ntest
    print("rel_err:{}".format(rel_err))
130 |
--------------------------------------------------------------------------------
/layers/Embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from einops import rearrange
5 | import numpy as np
6 |
7 |
def unified_pos_embedding(shapelist, ref, batchsize=1, device='cuda'):
    """Distances from every mesh point to a coarse uniform reference grid.

    Both the mesh (``shapelist`` points per axis) and the reference grid
    (``ref`` points per axis) are uniform on ``[0, 1]`` along every axis. Points
    of both grids are enumerated in row-major order (first axis slowest).

    Args:
        shapelist: sizes of the structured mesh along each spatial axis; any
            number of axes (>= 1) is supported.
        ref: number of reference points per axis.
        batchsize: number of identical copies stacked along dim 0.
        device: target device. ``None`` selects CUDA when available, else CPU.

    Returns:
        Tensor of shape ``(batchsize, prod(shapelist), ref ** len(shapelist))``.
        Entry ``[b, i, j]`` is the Euclidean distance between mesh point ``i``
        and reference point ``j``.

    Raises:
        ValueError: if ``shapelist`` is empty.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
    ndim = len(shapelist)
    if ndim == 0:
        raise ValueError('shapelist must contain at least one spatial dimension')

    def _uniform_grid(sizes):
        # Coordinates of a uniform [0, 1]^d grid, flattened row-major to (prod(sizes), d).
        coords = []
        for axis, size in enumerate(sizes):
            line = torch.tensor(np.linspace(0, 1, size), dtype=torch.float)
            view_shape = [1] * len(sizes)
            view_shape[axis] = size
            coords.append(line.reshape(view_shape).expand(*sizes))
        return torch.stack(coords, dim=-1).reshape(-1, len(sizes))

    grid = _uniform_grid(list(shapelist)).to(device)
    grid_ref = _uniform_grid([ref] * ndim).to(device)
    # Compute once, then copy per batch element, rather than materialising a B x N x R x d tensor.
    pos = torch.sqrt(torch.sum((grid[:, None, :] - grid_ref[None, :, :]) ** 2, dim=-1))
    return pos.unsqueeze(0).repeat(batchsize, 1, 1).contiguous()
57 |
58 |
class RotaryEmbedding(nn.Module):
    """Rotary-embedding angle table for continuous coordinates.

    ``forward`` maps coordinates of shape ``[..., n]`` to angles of shape
    ``[..., n, 2 * len(inv_freq)]``. Each coordinate is scaled by
    ``scale / min_freq`` and multiplied by every inverse frequency, and the
    resulting block is duplicated along the last axis.
    """

    def __init__(self, dim, min_freq=1 / 2, scale=1.):
        super().__init__()
        self.min_freq = min_freq
        self.scale = scale
        even_idx = torch.arange(0, dim, 2).float()
        self.register_buffer('inv_freq', 1. / (10000 ** (even_idx / dim)))

    def forward(self, coordinates, device='cuda'):
        # coordinates: [b, n]
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        ratio = self.scale / self.min_freq
        t = coordinates.to(device).type_as(self.inv_freq) * ratio
        angles = torch.einsum('... i , j -> ... i j', t, self.inv_freq)  # [b, n, d//2]
        return torch.cat([angles, angles], dim=-1)  # [b, n, d]
74 |
75 |
def rotate_half(x):
    """Split the last axis into halves ``[a, b]`` and return ``[-b, a]``.

    The last axis must have even length; otherwise the reshape raises.
    """
    half = x.shape[-1] // 2
    pair = x.reshape(*x.shape[:-1], 2, half)
    first, second = pair.unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)
80 |
81 |
def apply_rotary_pos_emb(t, freqs):
    """Rotate ``t`` by the angles ``freqs``: ``t * cos(freqs) + rotate_half(t) * sin(freqs)``."""
    cos_part = t * freqs.cos()
    sin_part = rotate_half(t) * freqs.sin()
    return cos_part + sin_part
84 |
85 |
def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y):
    """Apply rotary embeddings separately to the two halves of the feature axis.

    The first ``d // 2`` features of ``t`` are rotated by ``freqs_x`` and the
    remaining features by ``freqs_y``; the halves are concatenated back.
    Expected shapes (per the original annotations): ``t`` is ``[b, h, n, d]`` and
    ``freqs_x`` / ``freqs_y`` are ``[b, n, d]``.
    """
    split = t.shape[-1] // 2
    rotated_x = apply_rotary_pos_emb(t[..., :split], freqs_x)
    rotated_y = apply_rotary_pos_emb(t[..., split:], freqs_y)
    return torch.cat((rotated_x, rotated_y), dim=-1)
95 |
96 |
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence, followed by dropout.

    Column ``2i`` of the table holds ``sin(pos * w_i)`` and column ``2i + 1``
    holds ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)``. Odd
    ``d_model`` is supported: the last column is then a sine column.

    Args:
        d_model: feature width of the input.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed; inputs longer than this fail.
    """

    def __init__(self, d_model, dropout, max_len=421 * 421):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions to ``x`` (``[B, L, d_model]``)."""
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
118 |
119 |
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional. An ``[N, 1]`` tensor is also
                      accepted and yields an ``[N, 1, dim]`` result.
    :param dim: the dimension of the output. For odd ``dim`` the last
                channel is zero-padded.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: unused; kept for signature compatibility.
    :return: an [N x dim] Tensor of positional embeddings (``[N, 1, dim]`` for
             ``[N, 1]`` input). The first ``dim // 2`` channels are cosines and
             the next ``dim // 2`` are sines.
    """

    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad along the channel (last) axis; ``[..., :1]`` works for any input rank.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
    return embedding
139 |
--------------------------------------------------------------------------------
/layers/FFNO_Layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import torch
4 | import numpy as np
5 | import math
6 |
7 | ################################################################
8 | # 1d fourier layer
9 | ################################################################
class SpectralConv1d(nn.Module):
    """1D Fourier layer: rfft -> complex channel mixing of low modes -> irfft.

    In 1D the factorized FNO degenerates to the plain FNO layer. Input
    ``(batch, in_channels, L)`` maps to ``(batch, out_channels, L)``. Only the
    first ``modes1`` rfft coefficients are kept; all others are zeroed.
    """

    def __init__(self, in_channels, out_channels, modes1):
        super(SpectralConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1.
        self.modes1 = modes1
        self.scale = (1 / (in_channels * out_channels))
        init = torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat)
        self.weights1 = nn.Parameter(self.scale * init)

    def compl_mul1d(self, input, weights):
        """(batch, in, x) x (in, out, x) -> (batch, out, x), independently per mode x."""
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        length = x.size(-1)
        kept = self.modes1
        spectrum = torch.fft.rfft(x)
        out_ft = torch.zeros(x.shape[0], self.out_channels, length // 2 + 1,
                             device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :kept] = self.compl_mul1d(spectrum[:, :, :kept], self.weights1)
        return torch.fft.irfft(out_ft, n=length)
45 |
46 | ################################################################
47 | # 2d fourier layer
48 | ################################################################
class SpectralConv2d(nn.Module):
    """Factorized (F-FNO) 2D Fourier layer.

    Channel mixing is applied separately along each spatial axis. For y (last
    axis) and x (second-to-last axis), a 1-D real FFT (``norm='ortho'``) is
    taken along that axis, the lowest ``modes_*`` coefficients are multiplied by
    a per-mode complex ``(in_dim, out_dim)`` matrix, and the result is
    transformed back. The two axis outputs are summed.
    Input ``(B, in_dim, M, N)`` maps to output ``(B, out_dim, M, N)``.
    """

    def __init__(self, in_dim, out_dim, modes_x, modes_y):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.modes_x = modes_x
        self.modes_y = modes_y

        # One real (in_dim, out_dim, n_modes, 2) tensor per axis, viewed as complex in forward.
        self.fourier_weight = nn.ParameterList([])
        for n_modes in [modes_x, modes_y]:
            weight = torch.empty(in_dim, out_dim, n_modes, 2, dtype=torch.float)
            param = nn.Parameter(weight)
            nn.init.xavier_normal_(param)
            self.fourier_weight.append(param)

    def forward(self, x):
        B, _, M, N = x.shape

        # # # Dimension Y # # #
        x_fty = torch.fft.rfft(x, dim=-1, norm='ortho')
        # x_fty.shape == [B, in_dim, M, N // 2 + 1]
        # The buffer needs out_dim channels: the einsum maps in_dim -> out_dim.
        out_ft = x_fty.new_zeros(B, self.out_dim, M, N // 2 + 1)
        out_ft[:, :, :, :self.modes_y] = torch.einsum(
            "bixy,ioy->boxy",
            x_fty[:, :, :, :self.modes_y],
            torch.view_as_complex(self.fourier_weight[1]))
        xy = torch.fft.irfft(out_ft, n=N, dim=-1, norm='ortho')
        # xy.shape == [B, out_dim, M, N]

        # # # Dimension X # # #
        x_ftx = torch.fft.rfft(x, dim=-2, norm='ortho')
        # x_ftx.shape == [B, in_dim, M // 2 + 1, N]
        out_ft = x_ftx.new_zeros(B, self.out_dim, M // 2 + 1, N)
        out_ft[:, :, :self.modes_x, :] = torch.einsum(
            "bixy,iox->boxy",
            x_ftx[:, :, :self.modes_x, :],
            torch.view_as_complex(self.fourier_weight[0]))
        xx = torch.fft.irfft(out_ft, n=M, dim=-2, norm='ortho')
        # xx.shape == [B, out_dim, M, N]

        # # Combining Dimensions # #
        return xx + xy
101 |
102 |
103 | ################################################################
104 | # 3d fourier layers
105 | ################################################################
106 |
class SpectralConv3d(nn.Module):
    """Factorized (F-FNO) 3D Fourier layer.

    Channel mixing is applied separately along each spatial axis. For z (last),
    y and x, a 1-D real FFT (``norm='ortho'``) is taken along that axis, the
    lowest ``modes_*`` coefficients are multiplied by a per-mode complex
    ``(in_dim, out_dim)`` matrix, and the result is transformed back. The three
    axis outputs are summed.
    Input ``(B, in_dim, S1, S2, S3)`` maps to output ``(B, out_dim, S1, S2, S3)``.
    """

    def __init__(self, in_dim, out_dim, modes_x, modes_y, modes_z):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.modes_x = modes_x
        self.modes_y = modes_y
        self.modes_z = modes_z

        # One real (in_dim, out_dim, n_modes, 2) tensor per axis, viewed as complex in forward.
        self.fourier_weight = nn.ParameterList([])
        for n_modes in [modes_x, modes_y, modes_z]:
            weight = torch.empty(in_dim, out_dim, n_modes, 2, dtype=torch.float)
            param = nn.Parameter(weight)
            nn.init.xavier_normal_(param)
            self.fourier_weight.append(param)

    def forward(self, x):
        B, _, S1, S2, S3 = x.shape

        # # # Dimension Z # # #
        x_ftz = torch.fft.rfft(x, dim=-1, norm='ortho')
        # x_ftz.shape == [B, in_dim, S1, S2, S3 // 2 + 1]
        # Buffers need out_dim channels: each einsum maps in_dim -> out_dim.
        out_ft = x_ftz.new_zeros(B, self.out_dim, S1, S2, S3 // 2 + 1)
        out_ft[:, :, :, :, :self.modes_z] = torch.einsum(
            "bixyz,ioz->boxyz",
            x_ftz[:, :, :, :, :self.modes_z],
            torch.view_as_complex(self.fourier_weight[2]))
        xz = torch.fft.irfft(out_ft, n=S3, dim=-1, norm='ortho')

        # # # Dimension Y # # #
        x_fty = torch.fft.rfft(x, dim=-2, norm='ortho')
        # x_fty.shape == [B, in_dim, S1, S2 // 2 + 1, S3]
        out_ft = x_fty.new_zeros(B, self.out_dim, S1, S2 // 2 + 1, S3)
        out_ft[:, :, :, :self.modes_y, :] = torch.einsum(
            "bixyz,ioy->boxyz",
            x_fty[:, :, :, :self.modes_y, :],
            torch.view_as_complex(self.fourier_weight[1]))
        xy = torch.fft.irfft(out_ft, n=S2, dim=-2, norm='ortho')

        # # # Dimension X # # #
        x_ftx = torch.fft.rfft(x, dim=-3, norm='ortho')
        # x_ftx.shape == [B, in_dim, S1 // 2 + 1, S2, S3]
        out_ft = x_ftx.new_zeros(B, self.out_dim, S1 // 2 + 1, S2, S3)
        out_ft[:, :, :self.modes_x, :, :] = torch.einsum(
            "bixyz,iox->boxyz",
            x_ftx[:, :, :self.modes_x, :, :],
            torch.view_as_complex(self.fourier_weight[0]))
        xx = torch.fft.irfft(out_ft, n=S1, dim=-3, norm='ortho')

        # # Combining Dimensions # #
        return xx + xy + xz
175 |
--------------------------------------------------------------------------------
/layers/FNO_Layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import torch
4 | import numpy as np
5 | import math
6 |
7 |
8 | ################################################################
9 | # 1d fourier layer
10 | ################################################################
class SpectralConv1d(nn.Module):
    """1D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Input ``(batch, in_channels, L)`` maps to ``(batch, out_channels, L)``. The
    first ``modes1`` rfft coefficients are mixed across channels by complex
    weights, and the remaining coefficients are dropped.
    """

    def __init__(self, in_channels, out_channels, modes1):
        super(SpectralConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, modes1, dtype=torch.cfloat))

    def compl_mul1d(self, input, weights):
        """Per-mode complex matmul: (batch, in, x), (in, out, x) -> (batch, out, x)."""
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        n_points = x.size(-1)
        coeffs = torch.fft.rfft(x)
        mixed = torch.zeros(x.shape[0], self.out_channels, n_points // 2 + 1,
                            device=x.device, dtype=torch.cfloat)
        low = slice(None, self.modes1)
        mixed[:, :, low] = self.compl_mul1d(coeffs[:, :, low], self.weights1)
        return torch.fft.irfft(mixed, n=n_points)
44 |
45 |
46 | ################################################################
47 | # 2d fourier layer
48 | ################################################################
class SpectralConv2d(nn.Module):
    """2D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    ``weights1`` acts on the first ``modes1`` rows and ``weights2`` on the last
    ``modes1`` rows of the ``rfft2`` spectrum. Both are restricted to the first
    ``modes2`` columns, and every other coefficient is zeroed. If the two row
    ranges overlap, the ``weights2`` product (written second) wins.
    Input ``(batch, in_channels, H, W)`` maps to ``(batch, out_channels, H, W)``.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.scale = (1 / (in_channels * out_channels))
        shape = (in_channels, out_channels, modes1, modes2)
        self.weights1 = nn.Parameter(self.scale * torch.rand(*shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*shape, dtype=torch.cfloat))

    def compl_mul2d(self, input, weights):
        """Per-mode complex matmul: (batch, in, x, y), (in, out, x, y) -> (batch, out, x, y)."""
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        rows, cols = x.size(-2), x.size(-1)
        x_ft = torch.fft.rfft2(x)
        out_ft = torch.zeros(x.shape[0], self.out_channels, rows, cols // 2 + 1,
                             dtype=torch.cfloat, device=x.device)
        m1, m2 = self.modes1, self.modes2
        blocks = ((slice(None, m1), self.weights1), (slice(-m1, None), self.weights2))
        for row_sel, weights in blocks:
            out_ft[:, :, row_sel, :m2] = self.compl_mul2d(x_ft[:, :, row_sel, :m2], weights)
        return torch.fft.irfft2(out_ft, s=(rows, cols))
87 |
88 |
89 | ################################################################
90 | # 3d fourier layers
91 | ################################################################
92 |
class SpectralConv3d(nn.Module):
    """3D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Four weight tensors cover the four (low/high x, low/high y) corners of the
    ``rfftn`` spectrum, each restricted to the first ``modes3`` coefficients of
    the last axis. They are written in the order weights1..weights4, so later
    corners win where corners overlap.
    Input ``(batch, in_channels, X, Y, Z)`` maps to ``(batch, out_channels, X, Y, Z)``.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super(SpectralConv3d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.modes3 = modes3
        self.scale = (1 / (in_channels * out_channels))
        shape = (in_channels, out_channels, modes1, modes2, modes3)
        self.weights1 = nn.Parameter(self.scale * torch.rand(*shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*shape, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(self.scale * torch.rand(*shape, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(self.scale * torch.rand(*shape, dtype=torch.cfloat))

    def compl_mul3d(self, input, weights):
        """Per-mode complex matmul: (batch, in, x, y, t), (in, out, x, y, t) -> (batch, out, x, y, t)."""
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x):
        dims = (x.size(-3), x.size(-2), x.size(-1))
        x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])
        out_ft = torch.zeros(x.shape[0], self.out_channels, dims[0], dims[1], dims[2] // 2 + 1,
                             dtype=torch.cfloat, device=x.device)
        m1, m2, m3 = self.modes1, self.modes2, self.modes3
        low_x, high_x = slice(None, m1), slice(-m1, None)
        low_y, high_y = slice(None, m2), slice(-m2, None)
        corners = (
            (low_x, low_y, self.weights1),
            (high_x, low_y, self.weights2),
            (low_x, high_y, self.weights3),
            (high_x, high_y, self.weights4),
        )
        for sel_x, sel_y, weights in corners:
            out_ft[:, :, sel_x, sel_y, :m3] = self.compl_mul3d(x_ft[:, :, sel_x, sel_y, :m3], weights)
        return torch.fft.irfftn(out_ft, s=dims)
146 |
--------------------------------------------------------------------------------
/layers/GeoFNO_Projection.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import torch
4 | import numpy as np
5 | import math
6 |
7 |
8 | ################################################################
9 | # geo projection
10 | ################################################################
class SpectralConv2d_IrregularGeo(nn.Module):
    """2D Fourier layer that can read from and/or write to point clouds (from geoFNO).

    With ``x_in`` given, the input ``u`` of shape ``(batch, channels, n)`` is
    projected onto Fourier modes by an explicit non-uniform DFT at the point
    locations ``x_in`` (optionally mapped through ``iphi``). The output grid
    then has size ``(s1, s2)``. Without ``x_in``, ``u`` is a regular grid and
    ``rfft2`` is used. Symmetrically, with ``x_out`` given, the mixed modes are
    evaluated at those points by an explicit inverse DFT; otherwise ``irfft2``
    returns a grid.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, s1=32, s2=32):
        super(SpectralConv2d_IrregularGeo, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.s1 = s1  # latent grid size used when the input is a point cloud
        self.s2 = s2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def compl_mul2d(self, input, weights):
        """Per-mode complex matmul: (batch, in, x, y), (in, out, x, y) -> (batch, out, x, y)."""
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, u, x_in=None, x_out=None, iphi=None, code=None):
        batchsize = u.shape[0]

        # Compute Fourier coefficients (regular grid -> rfft2, point cloud -> explicit DFT).
        if x_in is None:
            u_ft = torch.fft.rfft2(u)
            s1 = u.size(-2)
            s2 = u.size(-1)
        else:
            u_ft = self.fft2d(u, x_in, iphi, code)
            s1 = self.s1
            s2 = self.s2

        # Multiply the low positive / low negative row frequencies.
        factor1 = self.compl_mul2d(u_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        factor2 = self.compl_mul2d(u_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Return to physical space (regular grid or the requested points).
        if x_out is None:
            out_ft = torch.zeros(batchsize, self.out_channels, s1, s2 // 2 + 1, dtype=torch.cfloat, device=u.device)
            out_ft[:, :, :self.modes1, :self.modes2] = factor1
            out_ft[:, :, -self.modes1:, :self.modes2] = factor2
            u = torch.fft.irfft2(out_ft, s=(s1, s2))
        else:
            out_ft = torch.cat([factor1, factor2], dim=-2)
            u = self.ifft2d(out_ft, x_out, iphi, code)

        return u

    def _wavenumbers(self, device):
        """Integer wavenumber grids ``(k_x1, k_x2)``, each of shape ``(2*modes1, 2*modes2 - 1)``."""
        m1 = 2 * self.modes1
        m2 = 2 * self.modes2 - 1
        k_x1 = torch.cat((torch.arange(start=0, end=self.modes1, step=1),
                          torch.arange(start=-(self.modes1), end=0, step=1)), 0).reshape(m1, 1).repeat(1, m2).to(device)
        k_x2 = torch.cat((torch.arange(start=0, end=self.modes2, step=1),
                          torch.arange(start=-(self.modes2 - 1), end=0, step=1)), 0).reshape(1, m2).repeat(m1, 1).to(
            device)
        return k_x1, k_x2

    def _phase(self, x_pts, iphi, code):
        """Return ``K[b, n, i, j] = xi_n,0 * k_x1[i, j] + xi_n,1 * k_x2[i, j]`` of shape (batch, N, m1, m2).

        ``xi`` is ``x_pts`` itself, or ``iphi(x_pts, code)`` when ``iphi`` is given.
        """
        batchsize = x_pts.shape[0]
        N = x_pts.shape[1]
        k_x1, k_x2 = self._wavenumbers(x_pts.device)
        m1, m2 = k_x1.shape
        x = x_pts if iphi is None else iphi(x_pts, code)
        # reshape (not view) so that non-contiguous coordinates from iphi also work
        K1 = torch.outer(x[..., 0].reshape(-1), k_x1.reshape(-1)).reshape(batchsize, N, m1, m2)
        K2 = torch.outer(x[..., 1].reshape(-1), k_x2.reshape(-1)).reshape(batchsize, N, m1, m2)
        return K1 + K2

    def fft2d(self, u, x_in, iphi=None, code=None):
        """Forward non-uniform DFT.

        Args:
            u: values ``(batch, channels, n)``.
            x_in: locations ``(batch, n, 2)``, expected in ``[0, 1] x [0, 1]``.
            iphi: optional map ``x_in -> x_c`` applied before the transform.
            code: optional conditioning passed to ``iphi``.

        Returns:
            Complex coefficients ``(batch, channels, 2*modes1, 2*modes2 - 1)``.
        """
        K = self._phase(x_in, iphi, code)
        basis = torch.exp(-1j * 2 * np.pi * K)  # (batch, N, m1, m2)
        u = u + 0j
        return torch.einsum("bcn,bnxy->bcxy", u, basis)

    def ifft2d(self, u_ft, x_out, iphi=None, code=None):
        """Inverse non-uniform DFT evaluated at ``x_out``.

        Args:
            u_ft: coefficients ``(batch, channels, 2*modes1, modes2)``.
            x_out: locations ``(batch, n, 2)``, expected in ``[0, 1] x [0, 1]``.
            iphi: optional map ``x_out -> x_c`` applied before the transform.
            code: optional conditioning passed to ``iphi``.

        Returns:
            Real tensor ``(batch, channels, n)``.
        """
        K = self._phase(x_out, iphi, code)
        basis = torch.exp(1j * 2 * np.pi * K)  # (batch, N, m1, m2)

        # Complete the negative column frequencies with the flipped conjugate coefficients.
        u_ft2 = u_ft[..., 1:].flip(-1, -2).conj()
        u_ft = torch.cat([u_ft, u_ft2], dim=-1)

        Y = torch.einsum("bcxy,bnxy->bcn", u_ft, basis)
        return Y.real
142 |
143 |
class IPHI(nn.Module):
    """Learned coordinate deformation (inverse phi: x -> xi) used by Geo-FNO.

    Each 2-D point is augmented with its angle and radius about a fixed centre
    and lifted with NeRF-style sin/cos features. An MLP then produces a
    correction ``xd``, and the module returns ``x + x * xd``.
    """

    def __init__(self, width=32):
        super(IPHI, self).__init__()
        self.width = width
        self.fc0 = nn.Linear(4, self.width)
        self.fc_code = nn.Linear(42, self.width)
        self.fc_no_code = nn.Linear(3 * self.width, 4 * self.width)
        self.fc1 = nn.Linear(4 * self.width, 4 * self.width)
        self.fc2 = nn.Linear(4 * self.width, 4 * self.width)
        self.fc3 = nn.Linear(4 * self.width, 4 * self.width)
        self.fc4 = nn.Linear(4 * self.width, 2)
        self.activation = torch.tanh
        # Non-persistent buffers follow module.to()/cuda() without changing the state_dict.
        # They were previously hard-coded to device="cuda", which broke CPU-only use.
        self.register_buffer('center', torch.tensor([0.0001, 0.0001]).reshape(1, 1, 2), persistent=False)
        self.register_buffer(
            'B',
            np.pi * torch.pow(2, torch.arange(0, self.width // 4, dtype=torch.float)).reshape(1, 1, 1,
                                                                                                self.width // 4),
            persistent=False)

    def forward(self, x, code=None):
        """Deform points.

        Args:
            x: point coordinates ``(batch, N_grid, 2)``.
            code: optional per-sample features ``(batch, 42)``; when given they
                are embedded and concatenated instead of applying ``fc_no_code``.

        Returns:
            Deformed coordinates with the same shape as ``x``.
        """
        # Align constants with the input device (no-op when already there).
        center = self.center.to(x.device)
        freq = self.B.to(x.device)

        # some feature engineering
        angle = torch.atan2(x[:, :, 1] - center[:, :, 1], x[:, :, 0] - center[:, :, 0])
        radius = torch.norm(x - center, dim=-1, p=2)
        xd = torch.stack([x[:, :, 0], x[:, :, 1], angle, radius], dim=-1)

        # sin features from NeRF
        b, n, d = xd.shape[0], xd.shape[1], xd.shape[2]
        x_sin = torch.sin(freq * xd.view(b, n, d, 1)).view(b, n, d * self.width // 4)
        x_cos = torch.cos(freq * xd.view(b, n, d, 1)).view(b, n, d * self.width // 4)
        xd = self.fc0(xd)
        xd = torch.cat([xd, x_sin, x_cos], dim=-1).reshape(b, n, 3 * self.width)

        if code is not None:
            cd = self.fc_code(code)
            cd = cd.unsqueeze(1).repeat(1, xd.shape[1], 1)
            xd = torch.cat([cd, xd], dim=-1)
        else:
            xd = self.fc_no_code(xd)

        xd = self.fc1(xd)
        xd = self.activation(xd)
        xd = self.fc2(xd)
        xd = self.activation(xd)
        xd = self.fc3(xd)
        xd = self.activation(xd)
        xd = self.fc4(xd)
        return x + x * xd
197 |
--------------------------------------------------------------------------------
/models/FNO.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from timm.models.layers import trunc_normal_
7 | from layers.Basic import MLP
8 | from layers.Embedding import timestep_embedding, unified_pos_embedding
9 | from layers.FNO_Layers import SpectralConv1d, SpectralConv2d, SpectralConv3d
10 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI
11 |
# Lookup tables indexed by the number of spatial dimensions (1, 2 or 3).
# Slot 0 is a placeholder so that ``len(padding)`` can be used directly as the index.
BlockList = [None] + [SpectralConv1d, SpectralConv2d, SpectralConv3d]
ConvList = [None] + [nn.Conv1d, nn.Conv2d, nn.Conv3d]
14 |
15 |
16 | class Model(nn.Module):
def __init__(self, args, s1=96, s2=96):
    """Build the FNO.

    Layers: a point-wise MLP lift, optional time MLP, optional Geo-FNO
    projections (unstructured meshes), four spectral blocks ``conv0..conv3``
    paired with 1x1 convolutions ``w0..w3``, and a two-layer output head.

    Args:
        args: experiment namespace (``n_hidden``, ``modes``, ``shapelist``, ...).
        s1, s2: latent grid size used for unstructured geometries.
    """
    super(Model, self).__init__()
    self.__name__ = 'FNO'
    self.args = args

    ## embedding
    if args.unified_pos and args.geotype != 'unstructured':  # only for structured mesh
        self.pos = unified_pos_embedding(args.shapelist, args.ref)
        lift_in = args.fun_dim + args.ref ** len(args.shapelist)
    else:
        lift_in = args.fun_dim + args.space_dim
    self.preprocess = MLP(lift_in, args.n_hidden * 2, args.n_hidden,
                          n_layers=0, res=False, act=args.act)
    if args.time_input:
        self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(),
                                     nn.Linear(args.n_hidden, args.n_hidden))

    # geometry projection
    if self.args.geotype == 'unstructured':
        self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes,
                                                         s1, s2)
        self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes,
                                                          s1, s2)
        self.iphi = IPHI()
        grid_sizes = [s1, s2]
    else:
        grid_sizes = args.shapelist
    # Pad every spatial axis up to a multiple of 16.
    self.padding = [(16 - size % 16) % 16 for size in grid_sizes]

    n_dims = len(self.padding)
    spectral_cls, pointwise_cls = BlockList[n_dims], ConvList[n_dims]
    modes = [args.modes for _ in range(n_dims)]
    # Registration order (conv0..conv3, then w0..w3) matches the original state_dict layout.
    for i in range(4):
        setattr(self, 'conv%d' % i, spectral_cls(args.n_hidden, args.n_hidden, *modes))
    for i in range(4):
        setattr(self, 'w%d' % i, pointwise_cls(args.n_hidden, args.n_hidden, 1))

    # projectors
    self.fc1 = nn.Linear(args.n_hidden, args.n_hidden)
    self.fc2 = nn.Linear(args.n_hidden, args.out_dim)
57 |
def structured_geo(self, x, fx, T=None):
    """Run the FNO on a structured mesh given in flattened point form.

    Args:
        x: point coordinates indexed as ``(B, N, space_dim)``. When
            ``args.unified_pos`` is set they are replaced by the precomputed
            unified position embedding.
        fx: optional input function values, concatenated with ``x`` on the last axis.
        T: optional timesteps; their sinusoidal embedding, passed through
            ``time_fc``, is added to the lifted features.

    Returns:
        Tensor of shape ``(B, N, args.out_dim)``.
    """
    B = x.shape[0]
    if self.args.unified_pos:
        x = self.pos.repeat(x.shape[0], 1, 1)
    if fx is not None:
        fx = torch.cat((x, fx), -1)
        fx = self.preprocess(fx)
    else:
        fx = self.preprocess(x)

    if T is not None:
        Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
        Time_emb = self.time_fc(Time_emb)
        fx = fx + Time_emb
    x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *self.args.shapelist)

    # Padding is applied only for 2D/3D meshes; 1D inputs pass through unpadded.
    use_padding = len(self.args.shapelist) in (2, 3) and any(p != 0 for p in self.padding)
    if use_padding:
        # F.pad takes (before, after) pairs starting from the LAST spatial axis.
        pad = []
        for p in reversed(self.padding):
            pad += [0, p]
        x = F.pad(x, pad)

    layers = ((self.conv0, self.w0), (self.conv1, self.w1), (self.conv2, self.w2), (self.conv3, self.w3))
    for i, (spectral, pointwise) in enumerate(layers):
        x = spectral(x) + pointwise(x)
        if i < len(layers) - 1:  # no activation after the last Fourier layer
            x = F.gelu(x)

    if use_padding:
        # Crop back to the true extents. Slicing with ``:-p`` would empty any
        # axis whose padding p is 0 (e.g. shape [85, 96] -> padding [11, 0]).
        x = x[(Ellipsis,) + tuple(slice(0, size) for size in self.args.shapelist)]
    x = x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1)
    x = self.fc1(x)
    x = F.gelu(x)
    x = self.fc2(x)
    return x
108 |
def unstructured_geo(self, x, fx, T=None):
    """Run the FNO trunk on an unstructured point cloud.

    The embedded point features are passed through ``self.fftproject_in``
    (together with the raw coordinates and ``self.iphi``), four
    spectral + pointwise layers, and ``self.fftproject_out`` back onto the
    same coordinates, followed by the ``fc1 -> gelu -> fc2`` head.
    """
    coords = x
    features = x if fx is None else torch.cat((x, fx), -1)
    hidden = self.preprocess(features)

    if T is not None:
        t_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
        hidden = hidden + self.time_fc(t_emb)

    grid = self.fftproject_in(hidden.permute(0, 2, 1), x_in=coords, iphi=self.iphi, code=None)

    layers = [(self.conv0, self.w0), (self.conv1, self.w1),
              (self.conv2, self.w2), (self.conv3, self.w3)]
    last = len(layers) - 1
    for depth, (spectral, pointwise) in enumerate(layers):
        grid = spectral(grid) + pointwise(grid)
        if depth != last:
            grid = F.gelu(grid)

    out = self.fftproject_out(grid, x_out=coords, iphi=self.iphi, code=None).permute(0, 2, 1)
    return self.fc2(F.gelu(self.fc1(out)))
148 |
def forward(self, x, fx, T=None, geo=None):
    """Route to the structured or unstructured path based on ``args.geotype``.

    ``geo`` is accepted for interface compatibility and is not used.
    """
    if self.args.geotype != 'unstructured':
        return self.structured_geo(x, fx, T)
    return self.unstructured_geo(x, fx, T)
154 |
--------------------------------------------------------------------------------
/models/F_FNO.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from timm.models.layers import trunc_normal_
7 | from layers.Basic import MLP
8 | from layers.Embedding import timestep_embedding, unified_pos_embedding
9 | from layers.FFNO_Layers import SpectralConv1d, SpectralConv2d, SpectralConv3d
10 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI
11 |
# Lookup tables indexed by the number of spatial dimensions (1, 2 or 3);
# index 0 is an unused placeholder. Model uses BlockList[len(self.padding)];
# ConvList is not referenced by Model in this file.
BlockList = [None, SpectralConv1d, SpectralConv2d, SpectralConv3d]
ConvList = [None, nn.Conv1d, nn.Conv2d, nn.Conv3d]
14 |
15 |
class Model(nn.Module):
    """Factorized FNO (F-FNO) for structured grids and unstructured point clouds.

    Structured inputs are reshaped to ``args.shapelist`` and, for 2D/3D grids,
    zero-padded on the trailing side of each axis up to a multiple of 16.
    Unstructured inputs go through ``SpectralConv2d_IrregularGeo`` projections
    built with ``s1`` x ``s2``. Both paths apply ``args.n_layers`` residual
    spectral layers followed by an ``fc1 -> gelu -> fc2`` head.
    """

    def __init__(self, args, s1=96, s2=96):
        super(Model, self).__init__()
        self.__name__ = 'F-FNO'
        self.args = args
        ## embedding
        if args.unified_pos and args.geotype != 'unstructured':  # only for structured mesh
            self.pos = unified_pos_embedding(args.shapelist, args.ref)
            self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2,
                                  args.n_hidden, n_layers=0, res=False, act=args.act)
        else:
            self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden,
                                  n_layers=0, res=False, act=args.act)
        if args.time_input:
            self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(),
                                         nn.Linear(args.n_hidden, args.n_hidden))
        # geometry projection
        if self.args.geotype == 'unstructured':
            self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1,
                                                             s2)
            self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1,
                                                              s2)
            self.iphi = IPHI()
            self.padding = [(16 - size % 16) % 16 for size in [s1, s2]]
        else:
            self.padding = [(16 - size % 16) % 16 for size in args.shapelist]

        # One factorized spectral layer per depth; the block class is chosen
        # by the number of spatial dimensions.
        n_dims = len(self.padding)
        self.spectral_layers = nn.ModuleList([
            BlockList[n_dims](args.n_hidden, args.n_hidden, *([args.modes] * n_dims))
            for _ in range(args.n_layers)
        ])
        # projectors
        self.fc1 = nn.Linear(args.n_hidden, args.n_hidden)
        self.fc2 = nn.Linear(args.n_hidden, args.out_dim)

    def structured_geo(self, x, fx, T=None):
        """Structured-mesh path; returns ``(B, N, out_dim)``."""
        B = x.shape[0]
        if self.args.unified_pos:
            x = self.pos.repeat(B, 1, 1)
        if fx is not None:
            fx = self.preprocess(torch.cat((x, fx), -1))
        else:
            fx = self.preprocess(x)

        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            fx = fx + self.time_fc(Time_emb)

        shapelist = list(self.args.shapelist)
        # (B, N, C) -> (B, C, *grid)
        x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *shapelist)
        # Padding is only applied to 2D / 3D grids (as before).
        needs_pad = any(self.padding) and len(shapelist) in (2, 3)
        if needs_pad:
            # F.pad takes (left, right) pairs starting from the LAST axis.
            pad = []
            for p in reversed(self.padding):
                pad += [0, p]
            x = F.pad(x, pad)

        for layer in self.spectral_layers:
            x = x + layer(x)

        if needs_pad:
            # Crop with ``:size``; ``:-pad`` selects nothing when pad == 0.
            x = x[(Ellipsis, *(slice(0, s) for s in shapelist))]
        x = x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def unstructured_geo(self, x, fx, T=None):
        """Point-cloud path via the irregular-geometry spectral projections."""
        original_pos = x
        if fx is not None:
            fx = self.preprocess(torch.cat((x, fx), -1))
        else:
            fx = self.preprocess(x)

        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            fx = fx + self.time_fc(Time_emb)

        x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None)
        for layer in self.spectral_layers:
            x = x + layer(x)
        x = self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def forward(self, x, fx, T=None, geo=None):
        """Dispatch on ``args.geotype``; ``geo`` is unused."""
        if self.args.geotype == 'unstructured':
            return self.unstructured_geo(x, fx, T)
        else:
            return self.structured_geo(x, fx, T)
--------------------------------------------------------------------------------
/models/GNOT.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from timm.models.layers import trunc_normal_
6 | from einops import repeat, rearrange
7 | from torch.nn import functional as F
8 | from layers.Basic import MLP, LinearAttention, ACTIVATION
9 | from layers.Embedding import timestep_embedding, unified_pos_embedding
10 |
class GNOT_block(nn.Module):
    """Transformer encoder block in MOE style.

    Cross-attention between ``x`` and ``y``, a gated mixture-of-experts MLP,
    self-attention over ``x`` and a second gated MoE MLP. The expert gate is
    computed point-wise from ``pos`` by ``gatenet``.
    """

    def __init__(self, num_heads: int,
                 hidden_dim: int,
                 dropout: float,
                 act='gelu',
                 mlp_ratio=4,
                 space_dim=2,
                 n_experts=3):
        super(GNOT_block, self).__init__()
        # Fail early with a clear message. Previously an unknown ``act`` left
        # ``self.act`` unset, and construction crashed with an opaque AttributeError.
        if act not in ACTIVATION:
            raise ValueError(f"Unsupported activation {act!r}; expected one of {list(ACTIVATION)}")
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim)
        self.ln4 = nn.LayerNorm(hidden_dim)
        self.ln5 = nn.LayerNorm(hidden_dim)

        self.selfattn = LinearAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout)
        self.crossattn = LinearAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout)
        self.resid_drop1 = nn.Dropout(dropout)
        self.resid_drop2 = nn.Dropout(dropout)

        ## MLP in MOE
        self.n_experts = n_experts
        self.act = ACTIVATION[act]  # activation class, instantiated per layer below
        self.moe_mlp1 = nn.ModuleList([nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * mlp_ratio),
            self.act(),
            nn.Linear(hidden_dim * mlp_ratio, hidden_dim),
        ) for _ in range(self.n_experts)])

        self.moe_mlp2 = nn.ModuleList([nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * mlp_ratio),
            self.act(),
            nn.Linear(hidden_dim * mlp_ratio, hidden_dim),
        ) for _ in range(self.n_experts)])

        self.gatenet = nn.Sequential(
            nn.Linear(space_dim, hidden_dim * mlp_ratio),
            self.act(),
            nn.Linear(hidden_dim * mlp_ratio, hidden_dim * mlp_ratio),
            self.act(),
            nn.Linear(hidden_dim * mlp_ratio, self.n_experts)
        )

    def forward(self, x, y, pos):
        """Apply the block; returns a tensor shaped like ``x``."""
        ## point-wise gate for moe; unsqueeze(2) inserts an axis so the gate
        ## broadcasts against the stacked expert outputs (experts on the last axis).
        ## assumes pos is (B, N, space_dim) -- TODO confirm against callers
        gate_score = F.softmax(self.gatenet(pos), dim=-1).unsqueeze(2)
        ## cross attention between geo and physics observation
        x = x + self.resid_drop1(self.crossattn(self.ln1(x), self.ln2(y)))
        ## moe mlp
        x_moe1 = torch.stack([expert(x) for expert in self.moe_mlp1], dim=-1)
        x = x + self.ln3((gate_score * x_moe1).sum(dim=-1))
        ## self attention among geo
        x = x + self.resid_drop2(self.selfattn(self.ln4(x)))
        ## moe mlp
        x_moe2 = torch.stack([expert(x) for expert in self.moe_mlp2], dim=-1)
        x = x + self.ln5((gate_score * x_moe2).sum(dim=-1))
        return x
73 |
74 |
class Model(nn.Module):
    """GNOT: a Transformer in MOE style.

    Coordinates are embedded by ``preprocess_x`` and (coordinates + input
    features) by ``preprocess_z``. Each ``GNOT_block`` is called as
    ``block(x, fx, pos)``. Its output replaces ``fx`` while the coordinate
    embedding ``x`` stays fixed, and the raw coordinates ``pos`` feed the
    expert gates. A two-layer head maps to ``args.out_dim``.
    """

    def __init__(self, args, n_experts=3):
        super(Model, self).__init__()
        self.__name__ = 'GNOT'
        self.args = args
        ## embedding
        if args.unified_pos and args.geotype != 'unstructured':  # only for structured mesh
            self.pos = unified_pos_embedding(args.shapelist, args.ref)
            self.preprocess_x = MLP(args.ref ** len(args.shapelist), args.n_hidden * 2,
                                    args.n_hidden, n_layers=0, res=False, act=args.act)
            self.preprocess_z = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2,
                                    args.n_hidden, n_layers=0, res=False, act=args.act)
        else:
            self.preprocess_x = MLP(args.space_dim, args.n_hidden * 2, args.n_hidden,
                                    n_layers=0, res=False, act=args.act)
            self.preprocess_z = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden,
                                    n_layers=0, res=False, act=args.act)
        if args.time_input:
            self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(),
                                         nn.Linear(args.n_hidden, args.n_hidden))

        ## models
        self.blocks = nn.ModuleList([GNOT_block(num_heads=args.n_heads,
                                                hidden_dim=args.n_hidden,
                                                dropout=args.dropout,
                                                act=args.act,
                                                mlp_ratio=args.mlp_ratio,
                                                space_dim=args.space_dim,
                                                n_experts=n_experts)
                                     for _ in range(args.n_layers)])
        self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float))
        # projectors
        self.fc1 = nn.Linear(args.n_hidden, args.n_hidden * 2)
        self.fc2 = nn.Linear(args.n_hidden * 2, args.out_dim)
        self.initialize_weights()

    def initialize_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit/zero norm affine."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, fx, T=None, geo=None):
        """Return per-point predictions; ``geo`` is unused."""
        pos = x  # raw coordinates always drive the MoE gates
        # ``self.pos`` only exists for structured meshes (see __init__); with
        # unstructured inputs the preprocessors expect raw coordinates.
        if self.args.unified_pos and self.args.geotype != 'unstructured':
            x = self.pos.repeat(x.shape[0], 1, 1)
        if fx is not None:
            fx = self.preprocess_z(torch.cat((x, fx), -1))
        else:
            fx = self.preprocess_z(x)
        fx = fx + self.placeholder[None, None, :]
        x = self.preprocess_x(x)
        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            fx = fx + self.time_fc(Time_emb)

        for block in self.blocks:
            fx = block(x, fx, pos)
        fx = self.fc1(fx)
        fx = F.gelu(fx)
        fx = self.fc2(fx)
        return fx
146 |
--------------------------------------------------------------------------------
/models/Galerkin_Transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from timm.models.layers import trunc_normal_
6 | from layers.Basic import MLP, LinearAttention
7 | from layers.Embedding import timestep_embedding, unified_pos_embedding
8 | from einops import rearrange, repeat
9 | from einops.layers.torch import Rearrange
10 |
11 |
class Galerkin_Transformer_block(nn.Module):
    """Pre-norm Transformer encoder block using Galerkin-type linear attention.

    When ``last_layer`` is True the block finishes with a LayerNorm and a
    linear projection to ``out_dim`` channels instead of returning the
    hidden features.
    """

    def __init__(
            self,
            num_heads: int,
            hidden_dim: int,
            dropout: float,
            act='gelu',
            mlp_ratio=4,
            last_layer=False,
            out_dim=1,
    ):
        super().__init__()
        self.last_layer = last_layer
        # Separate norms for the two attention inputs.
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.ln_1a = nn.LayerNorm(hidden_dim)
        head_dim = hidden_dim // num_heads
        self.Attn = LinearAttention(hidden_dim, heads=num_heads, dim_head=head_dim,
                                    dropout=dropout, attn_type='galerkin')
        self.ln_2 = nn.LayerNorm(hidden_dim)
        mlp_width = hidden_dim * mlp_ratio
        self.mlp = MLP(hidden_dim, mlp_width, hidden_dim, n_layers=0, res=False, act=act)
        if last_layer:
            self.ln_3 = nn.LayerNorm(hidden_dim)
            self.mlp2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, fx):
        """Attention and MLP sub-layers with residuals; optional output head."""
        fx = fx + self.Attn(self.ln_1(fx), self.ln_1a(fx))
        fx = fx + self.mlp(self.ln_2(fx))
        if not self.last_layer:
            return fx
        return self.mlp2(self.ln_3(fx))
44 |
45 |
class Model(nn.Module):
    """Galerkin Transformer: a stack of ``Galerkin_Transformer_block`` layers.

    The last block projects directly to ``args.out_dim``, so ``forward``
    returns the output of the final block.
    """

    def __init__(self, args):
        super(Model, self).__init__()
        # Was mislabeled 'Factformer' (copy-paste from Factformer.py).
        self.__name__ = 'Galerkin_Transformer'
        self.args = args
        ## embedding
        if args.unified_pos and args.geotype != 'unstructured':  # only for structured mesh
            self.pos = unified_pos_embedding(args.shapelist, args.ref)
            self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2,
                                  args.n_hidden, n_layers=0, res=False, act=args.act)
        else:
            self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden,
                                  n_layers=0, res=False, act=args.act)
        if args.time_input:
            self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(),
                                         nn.Linear(args.n_hidden, args.n_hidden))

        ## models
        self.blocks = nn.ModuleList([Galerkin_Transformer_block(num_heads=args.n_heads, hidden_dim=args.n_hidden,
                                                                dropout=args.dropout,
                                                                act=args.act,
                                                                mlp_ratio=args.mlp_ratio,
                                                                out_dim=args.out_dim,
                                                                last_layer=(i == args.n_layers - 1))
                                     for i in range(args.n_layers)])
        self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float))
        self.initialize_weights()

    def initialize_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit/zero norm affine."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, fx, T=None, geo=None):
        """Return per-point predictions; ``geo`` is unused."""
        # ``self.pos`` only exists for structured meshes (see __init__).
        if self.args.unified_pos and self.args.geotype != 'unstructured':
            x = self.pos.repeat(x.shape[0], 1, 1)
        if fx is not None:
            fx = self.preprocess(torch.cat((x, fx), -1))
        else:
            fx = self.preprocess(x)
        fx = fx + self.placeholder[None, None, :]

        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            fx = fx + self.time_fc(Time_emb)

        for block in self.blocks:
            fx = block(fx)
        return fx
105 |
--------------------------------------------------------------------------------
/models/GraphSAGE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch_geometric.nn as nng
4 | from layers.Basic import MLP
5 |
6 |
class Model(nn.Module):
    """GraphSAGE baseline: MLP encoder, stacked SAGEConv layers, MLP decoder.

    ``forward`` squeezes the leading batch axis, so it presumably expects a
    batch of one graph (TODO confirm). It returns ``(1, num_nodes, out_dim)``.
    """

    def __init__(self, args):
        super(Model, self).__init__()
        self.__name__ = 'GraphSAGE'

        self.nb_hidden_layers = args.n_layers
        self.size_hidden_layers = args.n_hidden
        self.bn_bool = True
        self.activation = nn.ReLU()

        self.encoder = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, n_layers=0, res=False,
                           act=args.act)
        self.decoder = MLP(args.n_hidden, args.n_hidden * 2, args.out_dim, n_layers=0, res=False, act=args.act)

        self.in_layer = nng.SAGEConv(
            in_channels=args.n_hidden,
            out_channels=self.size_hidden_layers
        )

        # nb_hidden_layers - 1 middle layers; together with in_layer they are
        # each followed by (optional) BatchNorm + activation.
        self.hidden_layers = nn.ModuleList()
        for _ in range(self.nb_hidden_layers - 1):
            self.hidden_layers.append(nng.SAGEConv(
                in_channels=self.size_hidden_layers,
                out_channels=self.size_hidden_layers
            ))

        self.out_layer = nng.SAGEConv(
            in_channels=self.size_hidden_layers,
            out_channels=self.size_hidden_layers
        )

        if self.bn_bool:
            self.bn = nn.ModuleList()
            for _ in range(self.nb_hidden_layers):
                self.bn.append(nn.BatchNorm1d(self.size_hidden_layers, track_running_stats=False))

    def forward(self, x, fx, T=None, geo=None):
        """Run the GNN; ``geo`` is the edge index and is required.

        Raises:
            ValueError: if ``geo`` is None.
        """
        if geo is None:
            raise ValueError('Please provide edge index for Graph Neural Networks')
        edge_index = geo
        # Without input features, fall back to coordinates only (as the
        # other models in this package do) instead of crashing in torch.cat.
        feats = x if fx is None else torch.cat((x, fx), dim=-1)
        z = self.encoder(feats.squeeze(0))
        z = self.in_layer(z, edge_index)
        if self.bn_bool:
            z = self.bn[0](z)
        z = self.activation(z)

        for n, layer in enumerate(self.hidden_layers, start=1):
            z = layer(z, edge_index)
            if self.bn_bool:
                z = self.bn[n](z)
            z = self.activation(z)
        z = self.out_layer(z, edge_index)
        z = self.decoder(z)
        return z.unsqueeze(0)
61 |
--------------------------------------------------------------------------------
/models/Graph_UNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch_geometric.nn as nng
4 | import random
5 | from layers.Basic import MLP
6 |
7 |
def DownSample(id, x, edge_index, pos_x, pool, pool_ratio, r, max_neighbors):
    """Sub-sample graph nodes and rebuild connectivity over the kept nodes.

    With a ``pool`` module, the kept nodes and their features come from
    ``pool(x, edge_index)`` (the fifth returned item holds the kept indices).
    Otherwise ``ceil(pool_ratio * n)`` node indices are drawn with
    ``random.sample`` (without replacement). The kept indices are appended
    to the caller-owned list ``id``, which is mutated in place. A radius graph
    (radius ``r``, ``loop=True``, at most ``max_neighbors`` neighbours) is then
    built on the kept positions.

    Returns:
        (features of the kept nodes, edge index of the new radius graph)
    """
    num_nodes = int(x.size(0))
    if pool is None:
        keep_count = int((pool_ratio * torch.tensor(num_nodes, dtype=torch.float)).ceil())
        kept = torch.tensor(random.sample(range(num_nodes), keep_count), dtype=torch.long)
        feats = x.clone()[kept]
    else:
        feats, _, _, _, kept, _ = pool(x.clone(), edge_index)

    kept_pos = pos_x[kept]
    id.append(kept)

    new_edges = nng.radius_graph(x=kept_pos.detach(), r=r, loop=True, max_num_neighbors=max_neighbors)
    return feats, new_edges
26 |
27 |
def UpSample(x, pos_x_up, pos_x_down):
    """Gather coarse-level features onto fine-level nodes.

    ``nng.nearest(pos_x_up, pos_x_down)`` gives, for each row of
    ``pos_x_up``, an index into ``pos_x_down`` (its nearest point per
    torch_geometric). ``x`` rows are gathered with that index.
    """
    nearest_idx = nng.nearest(pos_x_up, pos_x_down)
    return x[nearest_idx]
33 |
34 |
class Model(nn.Module):
    """Graph U-Net over mesh nodes.

    Pipeline (as implemented below): MLP encoder on ``fx`` -> ``scale`` down
    levels (graph conv, optional BatchNorm, ReLU; between levels the nodes are
    sub-sampled by ``DownSample`` and a new radius graph is built) -> mirrored
    up levels (``UpSample`` to the finer level, concatenation with the stored
    skip features, graph conv, optional BatchNorm, ReLU except on the last
    up step) -> MLP decoder. Returns ``(1, num_nodes, out_dim)``.

    NOTE(review): ``list_r`` / ``pool_ratio`` are mutable default arguments;
    they are only read here, never mutated.
    NOTE(review): in the 'GAT' branch the first layer uses ``heads=self.head``
    while later layers hard-code ``heads=2``, and ``bn_in`` is not updated
    inside the loops; channel widths only appear to line up when
    ``head == 2`` -- confirm before using layer='GAT'.
    """

    def __init__(self, args, pool='random', scale=5, list_r=[0.05, 0.2, 0.5, 1, 10],
                 pool_ratio=[0.5, 0.5, 0.5, 0.5, 0.5], max_neighbors=64, layer='SAGE', head=2):
        super(Model, self).__init__()
        self.__name__ = 'GUNet'

        self.L = scale  # number of down levels
        self.layer = layer
        self.pool_type = pool
        self.pool_ratio = pool_ratio  # per-level fraction of nodes kept
        self.list_r = list_r  # per-level radius for the rebuilt radius graph
        self.size_hidden_layers = args.n_hidden
        self.size_hidden_layers_init = args.n_hidden
        self.max_neighbors = max_neighbors
        self.dim_enc = args.n_hidden
        self.bn_bool = True
        self.res = False
        self.head = head
        self.activation = nn.ReLU()

        # The encoder sees only ``fx`` (width args.fun_dim), not the coordinates.
        self.encoder = MLP(args.fun_dim, args.n_hidden * 2, args.n_hidden, n_layers=0, res=False,
                           act=args.act)
        self.decoder = MLP(args.n_hidden, args.n_hidden * 2, args.out_dim, n_layers=0, res=False, act=args.act)

        self.down_layers = nn.ModuleList()

        # Learned TopK pooling unless nodes are sub-sampled at random.
        if self.pool_type != 'random':
            self.pool = nn.ModuleList()
        else:
            self.pool = None

        # First down layer: dim_enc -> n_hidden.
        if self.layer == 'SAGE':
            self.down_layers.append(nng.SAGEConv(
                in_channels=self.dim_enc,
                out_channels=self.size_hidden_layers
            ))
            bn_in = self.size_hidden_layers

        elif self.layer == 'GAT':
            self.down_layers.append(nng.GATConv(
                in_channels=self.dim_enc,
                out_channels=self.size_hidden_layers,
                heads=self.head,
                add_self_loops=False,
                concat=True
            ))
            bn_in = self.head * self.size_hidden_layers

        # self.bn layout: [0 .. L-1] follow the down layers, [L] follows
        # up_layers[0], [L+1 .. 2L-2] follow up_layers[1 .. L-2]
        # (matches the indexing used in forward).
        if self.bn_bool == True:
            self.bn = nn.ModuleList()
            self.bn.append(nng.BatchNorm(
                in_channels=bn_in,
                track_running_stats=False
            ))
        else:
            self.bn = None

        # Remaining down levels; for SAGE each level doubles the width.
        for n in range(1, self.L):
            if self.pool_type != 'random':
                self.pool.append(nng.TopKPooling(
                    in_channels=self.size_hidden_layers,
                    ratio=self.pool_ratio[n - 1],
                    nonlinearity=torch.sigmoid
                ))

            if self.layer == 'SAGE':
                self.down_layers.append(nng.SAGEConv(
                    in_channels=self.size_hidden_layers,
                    out_channels=2 * self.size_hidden_layers,
                ))
                self.size_hidden_layers = 2 * self.size_hidden_layers
                bn_in = self.size_hidden_layers

            elif self.layer == 'GAT':
                self.down_layers.append(nng.GATConv(
                    in_channels=self.head * self.size_hidden_layers,
                    out_channels=self.size_hidden_layers,
                    heads=2,
                    add_self_loops=False,
                    concat=True
                ))

            if self.bn_bool == True:
                self.bn.append(nng.BatchNorm(
                    in_channels=bn_in,
                    track_running_stats=False
                ))

        self.up_layers = nn.ModuleList()

        # up_layers[0] (the last one applied) maps back to dim_enc. Its input is
        # the upsampled features concatenated with the level-0 skip, which
        # gives 3 * n_hidden for SAGE.
        if self.layer == 'SAGE':
            self.up_layers.append(nng.SAGEConv(
                in_channels=3 * self.size_hidden_layers_init,
                out_channels=self.dim_enc
            ))
            self.size_hidden_layers_init = 2 * self.size_hidden_layers_init

        elif self.layer == 'GAT':
            self.up_layers.append(nng.GATConv(
                in_channels=2 * self.head * self.size_hidden_layers,
                out_channels=self.dim_enc,
                heads=2,
                add_self_loops=False,
                concat=False
            ))

        if self.bn_bool == True:
            self.bn.append(nng.BatchNorm(
                in_channels=self.dim_enc,
                track_running_stats=False
            ))

        # up_layers[m] for m >= 1: (2^m + 2^(m-1)) * n_hidden = 3 * 2^(m-1) * n_hidden
        # -> 2^(m-1) * n_hidden... expressed below via size_hidden_layers_init.
        for n in range(1, self.L - 1):
            if self.layer == 'SAGE':
                self.up_layers.append(nng.SAGEConv(
                    in_channels=3 * self.size_hidden_layers_init,
                    out_channels=self.size_hidden_layers_init,
                ))
                bn_in = self.size_hidden_layers_init
                self.size_hidden_layers_init = 2 * self.size_hidden_layers_init

            elif self.layer == 'GAT':
                self.up_layers.append(nng.GATConv(
                    in_channels=2 * self.head * self.size_hidden_layers,
                    out_channels=self.size_hidden_layers,
                    heads=2,
                    add_self_loops=False,
                    concat=True
                ))

            if self.bn_bool == True:
                self.bn.append(nng.BatchNorm(
                    in_channels=bn_in,
                    track_running_stats=False
                ))

    def forward(self, x, fx, T=None, geo=None):
        """Run the Graph U-Net.

        Args:
            x: coordinates; NOTE(review) immediately shadowed below and not
                otherwise used.
            fx: node features; the leading axis is squeezed, so presumably a
                batch of one graph -- TODO confirm.
            T: unused.
            geo: edge index of the input graph (required).

        Raises:
            ValueError: if ``geo`` is None.
        """
        if geo is None:
            raise ValueError('Please provide edge index for Graph Neural Networks')
        x, edge_index = fx.squeeze(0), geo
        id = []  # kept-node indices per level (shadows the builtin ``id``)
        edge_index_list = [edge_index.clone()]
        pos_x_list = []
        z = self.encoder(x)
        if self.res:
            z_res = z.clone()

        z = self.down_layers[0](z, edge_index)

        if self.bn_bool == True:
            z = self.bn[0](z)

        z = self.activation(z)
        z_list = [z.clone()]  # skip features, one per level
        for n in range(self.L - 1):
            # NOTE(review): node positions are taken from the first two
            # columns of ``fx`` (not from the ``x`` argument) -- confirm this
            # matches the dataset's feature layout.
            pos_x = x[:, :2] if n == 0 else pos_x[id[n - 1]]
            pos_x_list.append(pos_x.clone())

            if self.pool_type != 'random':
                z, edge_index = DownSample(id, z, edge_index, pos_x, self.pool[n], self.pool_ratio[n], self.list_r[n],
                                           self.max_neighbors)
            else:
                z, edge_index = DownSample(id, z, edge_index, pos_x, None, self.pool_ratio[n], self.list_r[n],
                                           self.max_neighbors)
            edge_index_list.append(edge_index.clone())

            z = self.down_layers[n + 1](z, edge_index)

            if self.bn_bool == True:
                z = self.bn[n + 1](z)

            z = self.activation(z)
            z_list.append(z.clone())
        # Positions of the coarsest level, so pos_x_list has L entries.
        pos_x_list.append(pos_x[id[-1]].clone())

        # Decoder: walk back from the coarsest level to level 0.
        for n in range(self.L - 1, 0, -1):
            z = UpSample(z, pos_x_list[n - 1], pos_x_list[n])
            z = torch.cat([z, z_list[n - 1]], dim=1)
            z = self.up_layers[n - 1](z, edge_index_list[n - 1])

            if self.bn_bool == True:
                z = self.bn[self.L + n - 1](z)

            # No activation after the final up step.
            z = self.activation(z) if n != 1 else z

        del (z_list, pos_x_list, edge_index_list)

        if self.res:
            z = z + z_res

        z = self.decoder(z)

        return z.unsqueeze(0)
228 |
--------------------------------------------------------------------------------
/models/LSM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from timm.models.layers import trunc_normal_
7 | from layers.Basic import MLP
8 | from layers.Embedding import timestep_embedding, unified_pos_embedding
9 | from layers.Neural_Spectral_Block import NeuralSpectralBlock1D, NeuralSpectralBlock2D, NeuralSpectralBlock3D
10 | from layers.UNet_Blocks import DoubleConv1D, Down1D, Up1D, OutConv1D, DoubleConv2D, Down2D, Up2D, OutConv2D, \
11 | DoubleConv3D, Down3D, Up3D, OutConv3D
12 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI
13 |
# Lookup tables indexed by the number of spatial dimensions (1, 2 or 3);
# index 0 is an unused placeholder. Model selects entries with
# ``len(patch_size)``.
ConvList = [None, DoubleConv1D, DoubleConv2D, DoubleConv3D]
DownList = [None, Down1D, Down2D, Down3D]
UpList = [None, Up1D, Up2D, Up3D]
OutList = [None, OutConv1D, OutConv2D, OutConv3D]
BlockList = [None, NeuralSpectralBlock1D, NeuralSpectralBlock2D, NeuralSpectralBlock3D]
19 |
20 |
21 | class Model(nn.Module):
22 | def __init__(self, args, bilinear=True, num_token=4, num_basis=12, s1=96, s2=96):
23 | super(Model, self).__init__()
24 | self.__name__ = 'LSM'
25 | self.args = args
26 | if args.task == 'steady':
27 | normtype = 'bn'
28 | else:
29 | normtype = 'in' # when conducting dynamic tasks, use instance norm for stability
30 | ## embedding
31 | if args.unified_pos and args.geotype != 'unstructured': # only for structured mesh
32 | self.pos = unified_pos_embedding(args.shapelist, args.ref)
33 | self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2,
34 | args.n_hidden, n_layers=0, res=False, act=args.act)
35 | else:
36 | self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden,
37 | n_layers=0, res=False, act=args.act)
38 | if args.time_input:
39 | self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(),
40 | nn.Linear(args.n_hidden, args.n_hidden))
41 | # geometry projection
42 | if self.args.geotype == 'unstructured':
43 | self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes,
44 | s1, s2)
45 | self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes,
46 | s1, s2)
47 | self.iphi = IPHI()
48 | patch_size = [(size + (16 - size % 16) % 16) // 16 for size in [s1, s2]]
49 | self.padding = [(16 - size % 16) % 16 for size in [s1, s2]]
50 | else:
51 | patch_size = [(size + (16 - size % 16) % 16) // 16 for size in args.shapelist]
52 | self.padding = [(16 - size % 16) % 16 for size in args.shapelist]
53 | # multiscale modules
54 | self.inc = ConvList[len(patch_size)](args.n_hidden, args.n_hidden, normtype=normtype)
55 | self.down1 = DownList[len(patch_size)](args.n_hidden, args.n_hidden * 2, normtype=normtype)
56 | self.down2 = DownList[len(patch_size)](args.n_hidden * 2, args.n_hidden * 4, normtype=normtype)
57 | self.down3 = DownList[len(patch_size)](args.n_hidden * 4, args.n_hidden * 8, normtype=normtype)
58 | factor = 2 if bilinear else 1
59 | self.down4 = DownList[len(patch_size)](args.n_hidden * 8, args.n_hidden * 16 // factor, normtype=normtype)
60 | self.up1 = UpList[len(patch_size)](args.n_hidden * 16, args.n_hidden * 8 // factor, bilinear, normtype=normtype)
61 | self.up2 = UpList[len(patch_size)](args.n_hidden * 8, args.n_hidden * 4 // factor, bilinear, normtype=normtype)
62 | self.up3 = UpList[len(patch_size)](args.n_hidden * 4, args.n_hidden * 2 // factor, bilinear, normtype=normtype)
63 | self.up4 = UpList[len(patch_size)](args.n_hidden * 2, args.n_hidden, bilinear, normtype=normtype)
64 | self.outc = OutList[len(patch_size)](args.n_hidden, args.n_hidden)
65 | # Patchified Neural Spectral Blocks
66 | self.process1 = BlockList[len(patch_size)](args.n_hidden, num_basis, patch_size, num_token, args.n_heads)
67 | self.process2 = BlockList[len(patch_size)](args.n_hidden * 2, num_basis, patch_size, num_token, args.n_heads)
68 | self.process3 = BlockList[len(patch_size)](args.n_hidden * 4, num_basis, patch_size, num_token, args.n_heads)
69 | self.process4 = BlockList[len(patch_size)](args.n_hidden * 8, num_basis, patch_size, num_token, args.n_heads)
70 | self.process5 = BlockList[len(patch_size)](args.n_hidden * 16 // factor, num_basis, patch_size, num_token,
71 | args.n_heads)
72 | # projectors
73 | self.fc1 = nn.Linear(args.n_hidden, args.n_hidden * 2)
74 | self.fc2 = nn.Linear(args.n_hidden * 2, args.out_dim)
75 |
76 | def structured_geo(self, x, fx, T=None):
77 | B, N, _ = x.shape
78 | if self.args.unified_pos:
79 | x = self.pos.repeat(x.shape[0], 1, 1)
80 | if fx is not None:
81 | fx = torch.cat((x, fx), -1)
82 | fx = self.preprocess(fx)
83 | else:
84 | fx = self.preprocess(x)
85 |
86 | if T is not None:
87 | Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
88 | Time_emb = self.time_fc(Time_emb)
89 | fx = fx + Time_emb
90 | x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *self.args.shapelist)
91 | if not all(item == 0 for item in self.padding):
92 | if len(self.args.shapelist) == 2:
93 | x = F.pad(x, [0, self.padding[1], 0, self.padding[0]])
94 | elif len(self.args.shapelist) == 3:
95 | x = F.pad(x, [0, self.padding[2], 0, self.padding[1], 0, self.padding[0]])
96 | x1 = self.inc(x)
97 | x2 = self.down1(x1)
98 | x3 = self.down2(x2)
99 | x4 = self.down3(x3)
100 | x5 = self.down4(x4)
101 | x = self.up1(self.process5(x5), self.process4(x4))
102 | x = self.up2(x, self.process3(x3))
103 | x = self.up3(x, self.process2(x2))
104 | x = self.up4(x, self.process1(x1))
105 | x = self.outc(x)
106 |
107 | if not all(item == 0 for item in self.padding):
108 | if len(self.args.shapelist) == 2:
109 | x = x[..., :-self.padding[0], :-self.padding[1]]
110 | elif len(self.args.shapelist) == 3:
111 | x = x[..., :-self.padding[0], :-self.padding[1], :-self.padding[2]]
112 | x = x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1)
113 | x = self.fc1(x)
114 | x = F.gelu(x)
115 | x = self.fc2(x)
116 | return x
117 |
def unstructured_geo(self, x, fx, T=None):
    """Run the encoder/decoder network on an unstructured point cloud.

    ``x`` (point coordinates) is concatenated with ``fx`` when ``fx`` is given and
    lifted by ``self.preprocess``; a timestep embedding is added when ``T`` is given.
    The features are then mapped by ``self.fftproject_in`` (presumably onto a
    regular latent grid -- TODO confirm against GeoFNO_Projection), processed by
    ``inc``/``down1..4`` and ``up1..4`` whose skip inputs pass through
    ``process1..5``, mapped back to the original points by ``self.fftproject_out``
    and decoded by ``fc1`` -> GELU -> ``fc2``.
    """
    coords = x
    inputs = x if fx is None else torch.cat((x, fx), -1)
    feats = self.preprocess(inputs)

    if T is not None:
        t_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, coords.shape[1], 1)
        feats = feats + self.time_fc(t_emb)

    grid = self.fftproject_in(feats.permute(0, 2, 1), x_in=coords, iphi=self.iphi, code=None)

    # Encoder: keep every level for the skip connections.
    levels = [self.inc(grid)]
    for down in (self.down1, self.down2, self.down3, self.down4):
        levels.append(down(levels[-1]))
    e1, e2, e3, e4, e5 = levels

    # Decoder: each stage fuses the upsampled path with a processed skip.
    h = self.up1(self.process5(e5), self.process4(e4))
    h = self.up2(h, self.process3(e3))
    h = self.up3(h, self.process2(e2))
    h = self.up4(h, self.process1(e1))
    h = self.outc(h)

    out = self.fftproject_out(h, x_out=coords, iphi=self.iphi, code=None).permute(0, 2, 1)
    return self.fc2(F.gelu(self.fc1(out)))
147 |
def forward(self, x, fx, T=None, geo=None):
    """Dispatch to the unstructured or structured pipeline based on ``args.geotype``.

    ``geo`` is accepted for interface compatibility and is not used here.
    """
    handler = self.unstructured_geo if self.args.geotype == 'unstructured' else self.structured_geo
    return handler(x, fx, T)
153 |
--------------------------------------------------------------------------------
/models/MWT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from timm.models.layers import trunc_normal_
7 | from layers.Basic import MLP
8 | from layers.Embedding import timestep_embedding, unified_pos_embedding
9 | from layers.MWT_Layers import MWT_CZ1d, MWT_CZ2d, MWT_CZ3d
10 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI
11 |
# Multiwavelet block class indexed by the number of spatial axes (slot 0 is a placeholder).
BlockList = [None] + [MWT_CZ1d, MWT_CZ2d, MWT_CZ3d]
# torch convolution class indexed by the number of spatial axes (slot 0 is a placeholder).
ConvList = [None] + [nn.Conv1d, nn.Conv2d, nn.Conv3d]
14 |
15 |
class Model(nn.Module):
    """Multiwavelet-transform (MWT) neural operator.

    The multiwavelet blocks run on a grid whose axes all have the same
    power-of-two length. Structured inputs are therefore zero-padded on the high
    side of every spatial axis up to ``target`` (the smallest power of two that
    is >= the largest axis) and cropped back to ``args.shapelist`` afterwards.
    Unstructured inputs are mapped onto an ``s1 x s2`` latent grid by
    ``SpectralConv2d_IrregularGeo`` and mapped back to the points at the end.
    """

    def __init__(self, args, alpha=2, L=0, c=1, base='legendre', s1=128, s2=128):
        super(Model, self).__init__()
        self.__name__ = 'MWT'
        self.args = args
        self.k = args.mwt_k
        # Channels per grid point handled by the multiwavelet blocks: c * k**2 (c * k in 1D).
        self.WMT_dim = c * self.k ** 2
        if args.geotype == 'structured_1D':
            self.WMT_dim = c * self.k
        self.c = c
        self.s1 = s1
        self.s2 = s2
        ## embedding
        if args.unified_pos and args.geotype != 'unstructured':  # only for structured mesh
            self.pos = unified_pos_embedding(args.shapelist, args.ref)
            self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2,
                                  self.WMT_dim, n_layers=0, res=False, act=args.act)
        else:
            self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, self.WMT_dim,
                                  n_layers=0, res=False, act=args.act)
        if args.time_input:
            self.time_fc = nn.Sequential(nn.Linear(self.WMT_dim, args.n_hidden), nn.SiLU(),
                                         nn.Linear(args.n_hidden, self.WMT_dim))
        # geometry projection
        if self.args.geotype == 'unstructured':
            self.fftproject_in = SpectralConv2d_IrregularGeo(self.WMT_dim, self.WMT_dim, args.modes, args.modes,
                                                             s1, s2)
            self.fftproject_out = SpectralConv2d_IrregularGeo(self.WMT_dim, self.WMT_dim, args.modes, args.modes,
                                                              s1, s2)
            self.iphi = IPHI()
            self.augmented_resolution = [s1, s2]
            self.padding = [(16 - size % 16) % 16 for size in [s1, s2]]
        else:
            target = 2 ** (math.ceil(np.log2(max(args.shapelist))))
            self.padding = [(target - size) for size in args.shapelist]
            self.augmented_resolution = [target for _ in range(len(self.padding))]
        self.spectral_layers = nn.ModuleList(
            [BlockList[len(self.padding)](k=self.k, alpha=alpha, L=L, c=c, base=base) for _ in range(args.n_layers)])
        # projectors
        self.fc1 = nn.Linear(self.WMT_dim, args.n_hidden)
        self.fc2 = nn.Linear(args.n_hidden, args.out_dim)

    def _pad(self, x):
        """Zero-pad the trailing spatial axes of ``x`` on their high side by ``self.padding``.

        Works for 1D, 2D and 3D grids alike.
        """
        if all(p == 0 for p in self.padding):
            return x
        pad = []
        # F.pad takes (left, right) pairs starting from the last axis.
        for p in reversed(self.padding):
            pad.extend([0, p])
        return F.pad(x, pad)

    def _unpad(self, x):
        """Crop the trailing spatial axes of ``x`` back to ``args.shapelist``.

        Cropping to explicit sizes keeps axes whose padding is 0 intact:
        the former ``x[..., :-0]`` form produced an empty tensor.
        """
        if all(p == 0 for p in self.padding):
            return x
        return x[(Ellipsis, *(slice(0, size) for size in self.args.shapelist))]

    def structured_geo(self, x, fx, T=None):
        """Forward pass on a structured mesh of shape ``args.shapelist``."""
        B, N, _ = x.shape
        if self.args.unified_pos:
            x = self.pos.repeat(x.shape[0], 1, 1)
        if fx is not None:
            fx = torch.cat((x, fx), -1)
            fx = self.preprocess(fx)
        else:
            fx = self.preprocess(x)

        if T is not None:
            Time_emb = timestep_embedding(T, self.WMT_dim).repeat(1, x.shape[1], 1)
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb
        x = fx.permute(0, 2, 1).reshape(B, self.WMT_dim, *self.args.shapelist)
        # Pad every axis (including 1D) so the grid matches augmented_resolution.
        x = self._pad(x)
        x = x.reshape(B, self.WMT_dim, -1).permute(0, 2, 1).contiguous() \
            .reshape(B, *self.augmented_resolution, self.c,
                     self.k ** 2 if self.args.geotype != 'structured_1D' else self.k)
        for i in range(self.args.n_layers):
            x = self.spectral_layers[i](x)
            if i < self.args.n_layers - 1:
                x = F.gelu(x)
        x = x.reshape(B, -1, self.WMT_dim).permute(0, 2, 1).contiguous() \
            .reshape(B, self.WMT_dim, *self.augmented_resolution)
        x = self._unpad(x)
        x = x.reshape(B, self.WMT_dim, -1).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def unstructured_geo(self, x, fx, T=None):
        """Forward pass on an unstructured point cloud via the latent s1 x s2 grid."""
        B, N, _ = x.shape
        original_pos = x
        if fx is not None:
            fx = torch.cat((x, fx), -1)
            fx = self.preprocess(fx)
        else:
            fx = self.preprocess(x)

        if T is not None:
            Time_emb = timestep_embedding(T, self.WMT_dim).repeat(1, x.shape[1], 1)
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb

        x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None)
        x = x.reshape(B, self.WMT_dim, -1).permute(0, 2, 1).contiguous() \
            .reshape(B, *self.augmented_resolution, self.c, self.k ** 2)
        for i in range(self.args.n_layers):
            x = self.spectral_layers[i](x)
            if i < self.args.n_layers - 1:
                x = F.gelu(x)
        x = x.reshape(B, -1, self.WMT_dim).permute(0, 2, 1).contiguous() \
            .reshape(B, self.WMT_dim, *self.augmented_resolution)
        x = self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def forward(self, x, fx, T=None, geo=None):
        if self.args.geotype == 'unstructured':
            return self.unstructured_geo(x, fx, T)
        else:
            return self.structured_geo(x, fx, T)
131 |
--------------------------------------------------------------------------------
/models/ONO.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from timm.models.layers import trunc_normal_
6 | from layers.Basic import MLP, LinearAttention, FlashAttention, SelfAttention as LinearSelfAttention
7 | from layers.Embedding import timestep_embedding, unified_pos_embedding
8 | from einops import rearrange, repeat
9 | import warnings
10 |
11 |
def psd_safe_cholesky(A, upper=False, out=None, jitter=None):
    """Compute the Cholesky decomposition of A, adding diagonal jitter if A is only p.s.d.

    Args:
        A (Tensor): The (batch of) square matrices to factorize.
        upper (bool, optional): See ``torch.linalg.cholesky``.
        out (Tensor, optional): See ``torch.linalg.cholesky``.
        jitter (float, optional): Initial jitter added to the diagonal when the plain
            factorization fails. If omitted, 1e-6 for float32 and 1e-8 otherwise.
            The jitter grows tenfold per retry, for at most ten retries.

    Returns:
        Tensor: The Cholesky factor of A (or of A plus the smallest jitter that worked).

    Raises:
        ValueError: If A contains NaN.
        RuntimeError: The original factorization error, if every jittered attempt
            also fails or yields NaN.
    """

    def _cholesky(M):
        L = torch.linalg.cholesky(M, upper=upper, out=out)
        # Some backends return NaN instead of raising for non-p.d. input; treat
        # that as a failure on every attempt, not only the first one.
        if torch.isnan(L).any():
            raise RuntimeError("Cholesky factor contains NaN")
        return L

    try:
        return _cholesky(A)
    except RuntimeError as e:
        isnan = torch.isnan(A)
        if isnan.any():
            raise ValueError(
                f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
            )

        if jitter is None:
            jitter = 1e-6 if A.dtype == torch.float32 else 1e-8
        Aprime = A.clone()
        jitter_prev = 0
        for i in range(10):
            jitter_new = jitter * (10 ** i)
            # Add only the increment so the total jitter on the diagonal equals jitter_new.
            Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev)
            jitter_prev = jitter_new
            try:
                L = _cholesky(Aprime)
            except RuntimeError:
                continue
            warnings.warn(
                f"A not p.d., added jitter of {jitter_new} to the diagonal",
                RuntimeWarning,
            )
            return L
        raise e
55 |
56 |
class ONOBlock(nn.Module):
    """ONO encoder block.

    Updates the geometry stream ``x`` with pre-norm attention and an MLP (both
    residual). It then projects ``x`` to ``psi_dim`` channels and whitens the
    projection with the inverse transposed Cholesky factor of its second-moment
    matrix ``feature_cov``. The whitened basis, weighted by ``softplus(mu)``, gives
    a residual low-rank update of the function stream ``fx``, followed by
    ``ln_3`` + ``mlp2``.

    During training the batch second moment is used and a running estimate is
    kept in the ``feature_cov`` buffer (momentum ``momentum``). In eval mode the
    running estimate is used.
    """

    def __init__(
            self,
            num_heads: int,
            hidden_dim: int,
            dropout: float,
            act='gelu',
            attn_type='nystrom',
            mlp_ratio=4,
            last_layer=False,
            momentum=0.9,
            psi_dim=8,
            out_dim=1
    ):
        super().__init__()
        self.momentum = momentum
        self.psi_dim = psi_dim

        # Running second moment of the projected features; created lazily on the first training step.
        self.register_buffer("feature_cov", None)
        self.register_parameter("mu", nn.Parameter(torch.zeros(psi_dim)))
        self.ln_1 = nn.LayerNorm(hidden_dim)
        if attn_type == 'nystrom':
            from nystrom_attention import NystromAttention
            self.Attn = NystromAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout)
        elif attn_type == 'linear':
            self.Attn = LinearAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, dropout=dropout,
                                        attn_type='galerkin')
        elif attn_type == 'selfAttention':
            self.Attn = LinearSelfAttention(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads,
                                            dropout=dropout)
        else:
            raise ValueError('Attn type only supports nystrom, linear or selfAttention')
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
        self.proj = nn.Linear(hidden_dim, psi_dim)
        self.ln_3 = nn.LayerNorm(hidden_dim)
        self.mlp2 = nn.Linear(hidden_dim, out_dim) if last_layer else MLP(hidden_dim, hidden_dim * mlp_ratio,
                                                                          hidden_dim, n_layers=0, res=False, act=act)

    def forward(self, x, fx):
        """Return the updated ``(x, fx)`` pair.

        Raises:
            RuntimeError: In eval mode when no running covariance has been
                recorded yet (no training forward pass has run).
        """
        x = self.Attn(self.ln_1(x)) + x
        x = self.mlp(self.ln_2(x)) + x
        x_ = self.proj(x)
        if self.training:
            # Uncentred second moment, averaged over batch and points: (psi_dim, psi_dim).
            batch_cov = torch.einsum("blc, bld->cd", x_, x_) / x_.shape[0] / x_.shape[1]
            with torch.no_grad():
                # Store a detached copy: keeping batch_cov itself would put a non-leaf
                # tensor (with its autograd graph) in the buffer, which also makes
                # copy.deepcopy of the model fail.
                cov_estimate = batch_cov.detach()
                if self.feature_cov is None:
                    self.feature_cov = cov_estimate.clone()
                else:
                    self.feature_cov.mul_(self.momentum).add_(cov_estimate, alpha=1 - self.momentum)
        else:
            if self.feature_cov is None:
                raise RuntimeError('ONOBlock.feature_cov is not initialised; '
                                   'run at least one training forward pass before evaluation')
            batch_cov = self.feature_cov
        L = psd_safe_cholesky(batch_cov)
        L_inv_T = L.inverse().transpose(-2, -1)
        x_ = x_ @ L_inv_T

        fx = (x_ * torch.nn.functional.softplus(self.mu)) @ (x_.transpose(-2, -1) @ fx) + fx
        fx = self.mlp2(self.ln_3(fx))

        return x, fx
118 |
119 |
class Model(nn.Module):
    """ONO model: two point-wise embeddings followed by a stack of ONOBlocks.

    ``preprocess_x`` builds the geometry stream ``x`` and ``preprocess_z`` the
    function stream ``fx``. Both embed the same input: coordinates (or, for
    structured meshes with ``unified_pos``, the unified positional grid),
    concatenated with ``fx`` when it is given. A learned ``placeholder`` and an
    optional timestep embedding are added to ``fx``. The last block's ``fx``
    output (``args.out_dim`` channels) is returned.
    """
    ## speed up with flash attention
    def __init__(self, args):
        super(Model, self).__init__()
        self.__name__ = 'ONO'
        self.args = args
        ## embedding
        if args.unified_pos and args.geotype != 'unstructured':  # only for structured mesh
            self.pos = unified_pos_embedding(args.shapelist, args.ref)
            # forward() feeds BOTH embeddings the concatenation of the positional
            # grid and fx, so both need the same input width.
            in_dim = args.fun_dim + args.ref ** len(args.shapelist)
        else:
            in_dim = args.fun_dim + args.space_dim
        self.preprocess_x = MLP(in_dim, args.n_hidden * 2, args.n_hidden,
                                n_layers=0, res=False, act=args.act)
        self.preprocess_z = MLP(in_dim, args.n_hidden * 2, args.n_hidden,
                                n_layers=0, res=False, act=args.act)
        if args.time_input:
            self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(),
                                         nn.Linear(args.n_hidden, args.n_hidden))

        ## models
        self.blocks = nn.ModuleList([ONOBlock(num_heads=args.n_heads, hidden_dim=args.n_hidden,
                                              dropout=args.dropout,
                                              act=args.act,
                                              mlp_ratio=args.mlp_ratio,
                                              out_dim=args.out_dim,
                                              psi_dim=args.psi_dim,
                                              attn_type=args.attn_type,
                                              last_layer=(_ == args.n_layers - 1))
                                     for _ in range(args.n_layers)])
        self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float))
        self.initialize_weights()

    def initialize_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit/zero norm affine parameters."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, fx, T=None, geo=None):
        # self.pos only exists for structured meshes with unified_pos (see __init__).
        if self.args.unified_pos and self.args.geotype != 'unstructured':
            x = self.pos.repeat(x.shape[0], 1, 1)
        if fx is not None:
            x = torch.cat((x, fx), -1)
        fx = self.preprocess_z(x)
        x = self.preprocess_x(x)
        fx = fx + self.placeholder[None, None, :]

        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb

        for block in self.blocks:
            x, fx = block(x, fx)
        return fx
187 |
--------------------------------------------------------------------------------
/models/PointNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch_geometric.nn as nng
4 | from layers.Embedding import unified_pos_embedding
5 | from layers.Basic import MLP
6 |
7 |
class Model(nn.Module):
    """PointNet baseline.

    Every point is encoded by shared MLPs. A max-pooled global feature is
    broadcast back to all points and concatenated with the per-point features,
    then the result is decoded to ``args.out_dim`` channels. The leading batch
    dimension is squeezed away and restored at the end, so a single sample per
    call is expected.
    """

    def __init__(self, args):
        super(Model, self).__init__()
        self.__name__ = 'PointNet'

        self.in_block = MLP(args.n_hidden, args.n_hidden * 2, args.n_hidden * 2, n_layers=0, res=False,
                            act=args.act)
        self.max_block = MLP(args.n_hidden * 2, args.n_hidden * 8, args.n_hidden * 32, n_layers=0, res=False,
                             act=args.act)

        self.out_block = MLP(args.n_hidden * (2 + 32), args.n_hidden * 16, args.n_hidden * 4, n_layers=0, res=False,
                             act=args.act)

        self.encoder = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden, n_layers=0, res=False,
                           act=args.act)
        self.decoder = MLP(args.n_hidden, args.n_hidden * 2, args.out_dim, n_layers=0, res=False, act=args.act)

        self.fcfinal = nn.Linear(args.n_hidden * 4, args.n_hidden)

    def forward(self, x, fx, T=None, geo=None):
        # geo is only checked for presence; PointNet itself does not use the edges.
        if geo is None:
            raise ValueError('Please provide edge index for Graph Neural Networks')
        z = torch.cat((x, fx), dim=-1).float().squeeze(0)
        # Every point belongs to sample 0. Allocate on the input's device rather
        # than a hard-coded `.cuda()`, which failed on CPU and on non-default GPUs.
        batch = torch.zeros(x.shape[1], dtype=torch.long, device=z.device)

        z = self.encoder(z)
        z = self.in_block(z)

        global_coef = self.max_block(z)
        global_coef = nng.global_max_pool(global_coef, batch=batch)
        # Points per sample, used to broadcast each pooled feature back to its points.
        nb_points = torch.bincount(batch, minlength=global_coef.shape[0])
        global_coef = torch.repeat_interleave(global_coef, nb_points, dim=0)

        z = torch.cat([z, global_coef], dim=1)
        z = self.out_block(z)
        z = self.fcfinal(z)
        z = self.decoder(z)

        return z.unsqueeze(0)
50 |
--------------------------------------------------------------------------------
/models/Transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from timm.models.layers import trunc_normal_
6 | from layers.Basic import MLP, FlashAttention
7 | from layers.Embedding import timestep_embedding, unified_pos_embedding
8 | from einops import rearrange, repeat
9 |
10 |
class Transformer_block(nn.Module):
    """Transformer encoder block.

    Pre-norm flash attention and a pre-norm MLP, each added residually. The last
    block (``last_layer=True``) also applies ``ln_3`` and a linear head to
    ``out_dim`` channels.
    """

    def __init__(
            self,
            num_heads: int,
            hidden_dim: int,
            dropout: float,
            act='gelu',
            mlp_ratio=4,
            last_layer=False,
            out_dim=1,
    ):
        super().__init__()
        self.last_layer = last_layer
        self.ln_1 = nn.LayerNorm(hidden_dim)

        # flash_attention
        head_dim = hidden_dim // num_heads
        self.Attn = FlashAttention(hidden_dim, heads=num_heads, dim_head=head_dim, dropout=dropout)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
        if last_layer:
            self.ln_3 = nn.LayerNorm(hidden_dim)
            self.mlp2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, fx):
        hidden = fx + self.Attn(self.ln_1(fx))
        hidden = hidden + self.mlp(self.ln_2(hidden))
        if not self.last_layer:
            return hidden
        return self.mlp2(self.ln_3(hidden))
43 |
44 |
class Model(nn.Module):
    """Transformer baseline over mesh points.

    Points are embedded by ``self.preprocess``. Its input is the coordinates, or
    for structured meshes with ``unified_pos`` the unified positional grid,
    concatenated with ``fx`` when it is given. The embedding is shifted by a
    learned ``placeholder`` and an optional timestep embedding, then passed
    through ``args.n_layers`` Transformer blocks; the last block maps to
    ``args.out_dim``.
    """
    ## speed up with flash attention
    def __init__(self, args):
        super(Model, self).__init__()
        self.__name__ = 'Transformer'
        self.args = args
        ## embedding
        if args.unified_pos and args.geotype != 'unstructured':  # only for structured mesh
            self.pos = unified_pos_embedding(args.shapelist, args.ref)
            self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2,
                                  args.n_hidden, n_layers=0, res=False, act=args.act)
        else:
            self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden,
                                  n_layers=0, res=False, act=args.act)
        if args.time_input:
            self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(),
                                         nn.Linear(args.n_hidden, args.n_hidden))

        ## models
        self.blocks = nn.ModuleList([Transformer_block(num_heads=args.n_heads, hidden_dim=args.n_hidden,
                                                       dropout=args.dropout,
                                                       act=args.act,
                                                       mlp_ratio=args.mlp_ratio,
                                                       out_dim=args.out_dim,
                                                       last_layer=(_ == args.n_layers - 1))
                                     for _ in range(args.n_layers)])
        self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float))
        self.initialize_weights()

    def initialize_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit/zero norm affine parameters."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, fx, T=None, geo=None):
        # Same condition as __init__: self.pos (and the matching preprocess width)
        # only exist for structured meshes, so unified_pos on an unstructured mesh
        # must not reach for self.pos.
        if self.args.unified_pos and self.args.geotype != 'unstructured':
            x = self.pos.repeat(x.shape[0], 1, 1)
        if fx is not None:
            fx = torch.cat((x, fx), -1)
            fx = self.preprocess(fx)
        else:
            fx = self.preprocess(x)
        fx = fx + self.placeholder[None, None, :]

        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb

        for block in self.blocks:
            fx = block(fx)
        return fx
104 |
--------------------------------------------------------------------------------
/models/Transolver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from timm.models.layers import trunc_normal_
5 | from layers.Basic import MLP
6 | from layers.Embedding import timestep_embedding, unified_pos_embedding
7 | from layers.Physics_Attention import Physics_Attention_Irregular_Mesh
8 | from layers.Physics_Attention import Physics_Attention_Structured_Mesh_1D
9 | from layers.Physics_Attention import Physics_Attention_Structured_Mesh_2D
10 | from layers.Physics_Attention import Physics_Attention_Structured_Mesh_3D
11 |
# Physics-attention implementation selected by ``args.geotype``.
PHYSICS_ATTENTION = dict(
    unstructured=Physics_Attention_Irregular_Mesh,
    structured_1D=Physics_Attention_Structured_Mesh_1D,
    structured_2D=Physics_Attention_Structured_Mesh_2D,
    structured_3D=Physics_Attention_Structured_Mesh_3D,
)
18 |
19 |
class Transolver_block(nn.Module):
    """Transolver encoder block.

    Pre-norm physics attention, whose class is chosen from PHYSICS_ATTENTION by
    ``geotype``, followed by a pre-norm MLP; both are added residually. The last
    block (``last_layer=True``) also applies ``ln_3`` and a linear head to
    ``out_dim`` channels.
    """

    def __init__(
            self,
            num_heads: int,
            hidden_dim: int,
            dropout: float,
            act='gelu',
            mlp_ratio=4,
            last_layer=False,
            out_dim=1,
            slice_num=32,
            geotype='unstructured',
            shapelist=None
    ):
        super().__init__()
        self.last_layer = last_layer
        self.ln_1 = nn.LayerNorm(hidden_dim)

        attention_cls = PHYSICS_ATTENTION[geotype]
        self.Attn = attention_cls(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads,
                                  dropout=dropout, slice_num=slice_num, shapelist=shapelist)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
        if last_layer:
            self.ln_3 = nn.LayerNorm(hidden_dim)
            self.mlp2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, fx):
        hidden = fx + self.Attn(self.ln_1(fx))
        hidden = hidden + self.mlp(self.ln_2(hidden))
        if not self.last_layer:
            return hidden
        return self.mlp2(self.ln_3(hidden))
55 |
56 |
class Model(nn.Module):
    """Transolver model: embed mesh points, then apply a stack of Transolver blocks.

    ``self.preprocess`` receives the coordinates (or, for structured meshes with
    ``unified_pos``, the unified positional grid), concatenated with ``fx`` when
    ``fx`` is given. A learned ``placeholder`` vector and, if ``T`` is given, a
    timestep embedding are added before the ``args.n_layers`` blocks. The last
    block projects to ``args.out_dim``.
    """

    def __init__(self, args):
        super(Model, self).__init__()
        self.__name__ = 'Transolver'
        self.args = args
        # ---- embedding ----
        use_unified_grid = args.unified_pos and args.geotype != 'unstructured'  # only for structured mesh
        if use_unified_grid:
            self.pos = unified_pos_embedding(args.shapelist, args.ref)
            embed_in = args.fun_dim + args.ref ** len(args.shapelist)
        else:
            embed_in = args.fun_dim + args.space_dim
        self.preprocess = MLP(embed_in, args.n_hidden * 2, args.n_hidden, n_layers=0, res=False, act=args.act)
        if args.time_input:
            self.time_fc = nn.Sequential(
                nn.Linear(args.n_hidden, args.n_hidden),
                nn.SiLU(),
                nn.Linear(args.n_hidden, args.n_hidden),
            )

        # ---- physics-attention blocks ----
        final_idx = args.n_layers - 1
        self.blocks = nn.ModuleList(
            Transolver_block(
                num_heads=args.n_heads,
                hidden_dim=args.n_hidden,
                dropout=args.dropout,
                act=args.act,
                mlp_ratio=args.mlp_ratio,
                out_dim=args.out_dim,
                slice_num=args.slice_num,
                last_layer=(idx == final_idx),
                geotype=args.geotype,
                shapelist=args.shapelist,
            )
            for idx in range(args.n_layers)
        )
        self.placeholder = nn.Parameter((1 / (args.n_hidden)) * torch.rand(args.n_hidden, dtype=torch.float))
        self.initialize_weights()

    def initialize_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit/zero norm affine parameters."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _embed(self, x, fx, T):
        """Lift points to hidden features and add the placeholder and optional time embedding."""
        feats = self.preprocess(x if fx is None else torch.cat((x, fx), -1))
        feats = feats + self.placeholder[None, None, :]
        if T is not None:
            t_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            feats = feats + self.time_fc(t_emb)
        return feats

    def _run_blocks(self, feats):
        """Apply every Transolver block in order."""
        for block in self.blocks:
            feats = block(feats)
        return feats

    def structured_geo(self, x, fx, T=None):
        if self.args.unified_pos:
            x = self.pos.repeat(x.shape[0], 1, 1)
        return self._run_blocks(self._embed(x, fx, T))

    def unstructured_geo(self, x, fx, T=None):
        return self._run_blocks(self._embed(x, fx, T))

    def forward(self, x, fx, T=None, geo=None):
        runner = self.unstructured_geo if self.args.geotype == 'unstructured' else self.structured_geo
        return runner(x, fx, T)
141 |
--------------------------------------------------------------------------------
/models/U_FNO.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from timm.models.layers import trunc_normal_
7 | from layers.Basic import MLP
8 | from layers.Embedding import timestep_embedding, unified_pos_embedding
9 | from layers.FNO_Layers import SpectralConv1d, SpectralConv2d, SpectralConv3d
10 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI
11 | from models.U_Net import Model as U_Net
12 |
# Spectral convolution class indexed by the number of spatial axes (slot 0 is a placeholder).
BlockList = [None] + [SpectralConv1d, SpectralConv2d, SpectralConv3d]
# Point-wise (kernel size 1) skip convolution class indexed by the number of spatial axes.
ConvList = [None] + [nn.Conv1d, nn.Conv2d, nn.Conv3d]
15 |
16 |
class Model(nn.Module):
    """U-FNO neural operator.

    Four Fourier layers (``conv0``..``conv3``) with point-wise convolution skips
    (``w0``..``w3``). The last two layers also add the output of a U-Net branch
    (``u_net2``/``u_net3``; only their ``multiscale`` path is used here).
    Structured inputs are zero-padded so every spatial axis is a multiple of 16
    and are cropped back to ``args.shapelist`` afterwards. Unstructured inputs are
    mapped onto an ``s1 x s2`` latent grid by ``SpectralConv2d_IrregularGeo`` and
    mapped back to the points at the end.
    """

    def __init__(self, args, s1=96, s2=96):
        super(Model, self).__init__()
        self.__name__ = 'U-FNO'
        self.args = args
        ## embedding
        if args.unified_pos and args.geotype != 'unstructured':  # only for structured mesh
            self.pos = unified_pos_embedding(args.shapelist, args.ref)
            self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2,
                                  args.n_hidden, n_layers=0, res=False, act=args.act)
        else:
            self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden,
                                  n_layers=0, res=False, act=args.act)
        if args.time_input:
            self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(),
                                         nn.Linear(args.n_hidden, args.n_hidden))
        # geometry projection
        if self.args.geotype == 'unstructured':
            self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1,
                                                             s2)
            self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes, s1,
                                                              s2)
            self.iphi = IPHI()
            self.padding = [(16 - size % 16) % 16 for size in [s1, s2]]
        else:
            self.padding = [(16 - size % 16) % 16 for size in args.shapelist]
        n_dim = len(self.padding)
        modes = [args.modes for _ in range(n_dim)]
        self.conv0 = BlockList[n_dim](args.n_hidden, args.n_hidden, *modes)
        self.conv1 = BlockList[n_dim](args.n_hidden, args.n_hidden, *modes)
        self.conv2 = BlockList[n_dim](args.n_hidden, args.n_hidden, *modes)
        self.conv3 = BlockList[n_dim](args.n_hidden, args.n_hidden, *modes)
        self.w0 = ConvList[n_dim](args.n_hidden, args.n_hidden, 1)
        self.w1 = ConvList[n_dim](args.n_hidden, args.n_hidden, 1)
        self.w2 = ConvList[n_dim](args.n_hidden, args.n_hidden, 1)
        self.w3 = ConvList[n_dim](args.n_hidden, args.n_hidden, 1)
        self.u_net2 = U_Net(args)
        self.u_net3 = U_Net(args)
        # projectors
        self.fc1 = nn.Linear(args.n_hidden, args.n_hidden)
        self.fc2 = nn.Linear(args.n_hidden, args.out_dim)

    def _pad(self, x):
        """Zero-pad the trailing spatial axes of ``x`` on their high side by ``self.padding``.

        Works for 1D, 2D and 3D grids alike.
        """
        if all(p == 0 for p in self.padding):
            return x
        pad = []
        # F.pad takes (left, right) pairs starting from the last axis.
        for p in reversed(self.padding):
            pad.extend([0, p])
        return F.pad(x, pad)

    def _unpad(self, x):
        """Crop the trailing spatial axes of ``x`` back to ``args.shapelist``.

        Cropping to explicit sizes keeps axes whose padding is 0 intact:
        the former ``x[..., :-0]`` form produced an empty tensor.
        """
        if all(p == 0 for p in self.padding):
            return x
        return x[(Ellipsis, *(slice(0, size) for size in self.args.shapelist))]

    def _spectral_stack(self, x):
        """Apply the four Fourier layers; the last two add a U-Net branch."""
        x = F.gelu(self.conv0(x) + self.w0(x))
        x = F.gelu(self.conv1(x) + self.w1(x))
        x = F.gelu(self.conv2(x) + self.w2(x) + self.u_net2.multiscale(x))
        return self.conv3(x) + self.w3(x) + self.u_net3.multiscale(x)

    def structured_geo(self, x, fx, T=None):
        """Forward pass on a structured mesh of shape ``args.shapelist``."""
        B, N, _ = x.shape
        if self.args.unified_pos:
            x = self.pos.repeat(x.shape[0], 1, 1)
        if fx is not None:
            fx = torch.cat((x, fx), -1)
            fx = self.preprocess(fx)
        else:
            fx = self.preprocess(x)

        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb
        x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *self.args.shapelist)
        x = self._pad(x)
        x = self._spectral_stack(x)
        x = self._unpad(x)
        x = x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def unstructured_geo(self, x, fx, T=None):
        """Forward pass on an unstructured point cloud via the latent s1 x s2 grid."""
        original_pos = x
        if fx is not None:
            fx = torch.cat((x, fx), -1)
            fx = self.preprocess(fx)
        else:
            fx = self.preprocess(x)

        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb

        x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None)
        x = self._spectral_stack(x)
        x = self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def forward(self, x, fx, T=None, geo=None):
        if self.args.geotype == 'unstructured':
            return self.unstructured_geo(x, fx, T)
        else:
            return self.structured_geo(x, fx, T)
161 |
--------------------------------------------------------------------------------
/models/U_Net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from timm.models.layers import trunc_normal_
7 | from layers.Basic import MLP
8 | from layers.Embedding import timestep_embedding, unified_pos_embedding
9 | from layers.UNet_Blocks import DoubleConv1D, Down1D, Up1D, OutConv1D, DoubleConv2D, Down2D, Up2D, OutConv2D, \
10 | DoubleConv3D, Down3D, Up3D, OutConv3D
11 | from layers.GeoFNO_Projection import SpectralConv2d_IrregularGeo, IPHI
12 |
# U-Net building blocks indexed by the number of spatial axes (slot 0 is a placeholder).
ConvList = [None] + [DoubleConv1D, DoubleConv2D, DoubleConv3D]
DownList = [None] + [Down1D, Down2D, Down3D]
UpList = [None] + [Up1D, Up2D, Up3D]
OutList = [None] + [OutConv1D, OutConv2D, OutConv3D]
17 |
18 |
class Model(nn.Module):
    """U-Net neural solver on regular grids.

    Structured inputs are reshaped onto ``args.shapelist``. Unstructured inputs
    are first projected onto an ``s1 x s2`` latent grid with
    ``SpectralConv2d_IrregularGeo`` and projected back to the original points
    afterwards.

    Args:
        args: experiment namespace built in ``run.py``.
        bilinear: forwarded to the ``Up*`` blocks. When True, ``factor = 2``
            halves the output channels of ``down4`` and ``up1``-``up3``.
        s1, s2: latent grid size used for unstructured geometries.
    """

    def __init__(self, args, bilinear=True, s1=96, s2=96):
        super(Model, self).__init__()
        self.__name__ = 'U-Net'
        self.args = args
        # When conducting dynamic tasks, use instance norm for stability.
        normtype = 'bn' if args.task == 'steady' else 'in'
        ## embedding
        if args.unified_pos and args.geotype != 'unstructured':  # only for structured mesh
            self.pos = unified_pos_embedding(args.shapelist, args.ref)
            self.preprocess = MLP(args.fun_dim + args.ref ** len(args.shapelist), args.n_hidden * 2,
                                  args.n_hidden, n_layers=0, res=False, act=args.act)
        else:
            self.preprocess = MLP(args.fun_dim + args.space_dim, args.n_hidden * 2, args.n_hidden,
                                  n_layers=0, res=False, act=args.act)
        if args.time_input:
            self.time_fc = nn.Sequential(nn.Linear(args.n_hidden, args.n_hidden), nn.SiLU(),
                                         nn.Linear(args.n_hidden, args.n_hidden))
        # geometry projection
        if self.args.geotype == 'unstructured':
            self.fftproject_in = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes,
                                                             s1, s2)
            self.fftproject_out = SpectralConv2d_IrregularGeo(args.n_hidden, args.n_hidden, args.modes, args.modes,
                                                              s1, s2)
            self.iphi = IPHI()
            grid_shape = [s1, s2]
        else:
            grid_shape = args.shapelist
        # Right-pad every grid axis up to a multiple of 16 (= 2 ** 4), presumably
        # so that the four Down stages divide it evenly (see layers/UNet_Blocks).
        self.padding = [(16 - size % 16) % 16 for size in grid_shape]
        dim = len(grid_shape)  # number of spatial axes selects the 1D/2D/3D blocks
        # multiscale modules
        self.inc = ConvList[dim](args.n_hidden, args.n_hidden, normtype=normtype)
        self.down1 = DownList[dim](args.n_hidden, args.n_hidden * 2, normtype=normtype)
        self.down2 = DownList[dim](args.n_hidden * 2, args.n_hidden * 4, normtype=normtype)
        self.down3 = DownList[dim](args.n_hidden * 4, args.n_hidden * 8, normtype=normtype)
        factor = 2 if bilinear else 1
        self.down4 = DownList[dim](args.n_hidden * 8, args.n_hidden * 16 // factor, normtype=normtype)
        self.up1 = UpList[dim](args.n_hidden * 16, args.n_hidden * 8 // factor, bilinear, normtype=normtype)
        self.up2 = UpList[dim](args.n_hidden * 8, args.n_hidden * 4 // factor, bilinear, normtype=normtype)
        self.up3 = UpList[dim](args.n_hidden * 4, args.n_hidden * 2 // factor, bilinear, normtype=normtype)
        self.up4 = UpList[dim](args.n_hidden * 2, args.n_hidden, bilinear, normtype=normtype)
        self.outc = OutList[dim](args.n_hidden, args.n_hidden)
        # projectors
        self.fc1 = nn.Linear(args.n_hidden, args.n_hidden)
        self.fc2 = nn.Linear(args.n_hidden, args.out_dim)

    @staticmethod
    def _pad_spatial(x, padding):
        """Zero-pad the trailing spatial axes of ``x`` at their high end.

        ``padding[i]`` is the amount added to the i-th spatial axis, counted from
        the first spatial axis. Returns ``x`` unchanged when all entries are 0.
        Works for any number of spatial axes.
        """
        if not any(padding):
            return x
        pad = []
        for amount in reversed(padding):  # F.pad lists the last dimension first
            pad.extend([0, amount])
        return F.pad(x, pad)

    @staticmethod
    def _crop_spatial(x, shape):
        """Crop the trailing ``len(shape)`` axes of ``x`` back to ``shape``.

        Slicing to explicit sizes avoids the ``[:-0]`` pitfall, which yields an
        empty axis when only some axes were padded.
        """
        index = (Ellipsis,) + tuple(slice(0, size) for size in shape)
        return x[index]

    def multiscale(self, x):
        """Run the encoder-decoder on a channel-first grid tensor.

        The encoder is ``inc`` followed by four ``Down`` stages. Each of the four
        ``Up`` stages takes the matching encoder feature map as its skip input,
        and ``outc`` maps the result back to ``n_hidden`` channels.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x

    def structured_geo(self, x, fx, T=None):
        """Forward pass for structured meshes.

        Args:
            x: point coordinates, unpacked as ``(B, N, _)``. They are replaced by
                the unified position embedding when ``args.unified_pos`` is set.
            fx: optional input function values, concatenated with ``x`` on the
                last axis before ``preprocess``.
            T: optional time input. ``self.time_fc`` only exists when
                ``args.time_input`` is set.

        Returns:
            Tensor ``(B, M, args.out_dim)``, where ``M`` is the number of grid
            points after cropping the padding back to ``args.shapelist``.
        """
        B, _, _ = x.shape
        if self.args.unified_pos:
            x = self.pos.repeat(x.shape[0], 1, 1)
        if fx is not None:
            fx = torch.cat((x, fx), -1)
            fx = self.preprocess(fx)
        else:
            fx = self.preprocess(x)

        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb
        x = fx.permute(0, 2, 1).reshape(B, self.args.n_hidden, *self.args.shapelist)
        x = self._pad_spatial(x, self.padding)
        x = self.multiscale(x)  ## U-Net
        if any(self.padding):
            x = self._crop_spatial(x, self.args.shapelist)
        x = x.reshape(B, self.args.n_hidden, -1).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def unstructured_geo(self, x, fx, T=None):
        """Forward pass for unstructured point clouds.

        Features are projected from the points ``x`` onto the latent grid with
        ``fftproject_in`` (using ``iphi``), processed by the U-Net, then
        projected back to the same points with ``fftproject_out``.

        Args:
            x: point coordinates. ``x.shape[1]`` is used as the point count;
                presumably ``(B, N, space_dim)`` — TODO confirm against loaders.
            fx: optional input function values, concatenated with ``x`` on the
                last axis.
            T: optional time input (requires ``args.time_input``).
        """
        original_pos = x
        if fx is not None:
            fx = torch.cat((x, fx), -1)
            fx = self.preprocess(fx)
        else:
            fx = self.preprocess(x)

        if T is not None:
            Time_emb = timestep_embedding(T, self.args.n_hidden).repeat(1, x.shape[1], 1)
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb

        x = self.fftproject_in(fx.permute(0, 2, 1), x_in=original_pos, iphi=self.iphi, code=None)
        x = self.multiscale(x)  ## U-Net
        x = self.fftproject_out(x, x_out=original_pos, iphi=self.iphi, code=None).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def forward(self, x, fx, T=None, geo=None):
        """Dispatch on ``args.geotype``. ``geo`` is accepted but unused here."""
        if self.args.geotype == 'unstructured':
            return self.unstructured_geo(x, fx, T)
        else:
            return self.structured_geo(x, fx, T)
138 |
--------------------------------------------------------------------------------
/models/model_factory.py:
--------------------------------------------------------------------------------
1 | from models import Transolver, LSM, FNO, U_Net, Transformer, Factformer, Swin_Transformer, Galerkin_Transformer, GNOT, \
2 | U_NO, U_FNO, F_FNO, ONO, MWT, GraphSAGE, Graph_UNet, PointNet
3 |
4 |
def get_model(args):
    """Instantiate the neural solver selected by ``args.model``.

    Args:
        args: experiment namespace. ``args.model`` names the model module, and the
            whole namespace is forwarded to that module's ``Model`` constructor.

    Returns:
        The ``Model`` instance built by the selected module.

    Raises:
        KeyError: if ``args.model`` is not a registered model name. The message
            lists the valid choices.
    """
    model_dict = {
        'PointNet': PointNet,
        'Graph_UNet': Graph_UNet,
        'GraphSAGE': GraphSAGE,
        'MWT': MWT,
        'ONO': ONO,
        'F_FNO': F_FNO,
        'U_FNO': U_FNO,
        'U_NO': U_NO,
        'GNOT': GNOT,
        'Galerkin_Transformer': Galerkin_Transformer,
        'Swin_Transformer': Swin_Transformer,
        'Swin': Swin_Transformer,  # short alias so `--model Swin` also resolves
        'Factformer': Factformer,
        'Transformer': Transformer,
        'U_Net': U_Net,
        'FNO': FNO,
        'Transolver': Transolver,
        'LSM': LSM,
    }
    try:
        module = model_dict[args.model]
    except KeyError:
        raise KeyError(
            f"Unknown model {args.model!r}; choose from: {', '.join(sorted(model_dict))}"
        ) from None
    return module.Model(args)
26 |
--------------------------------------------------------------------------------
/pic/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Neural-Solver-Library/148c2eb890dc22081acc822acd760adcf4f06b30/pic/logo.png
--------------------------------------------------------------------------------
/pic/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Neural-Solver-Library/148c2eb890dc22081acc822acd760adcf4f06b30/pic/task.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py==3.8.0
2 | dgl==1.1.0
3 | scipy==1.7.3
4 | scikit-learn
5 | torch
6 | torch_geometric
7 | torch-cluster
8 | vtk
9 | timm
10 | einops
11 | seaborn
12 | pyvista
13 | sympy
14 | nystrom_attention
15 | matplotlib
16 | tqdm
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from tqdm import *
6 |
def str2bool(value):
    """Parse a boolean command-line value such as ``True``, ``false``, ``1`` or ``0``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--normalize False`` would enable normalization).
    The boolean flags below therefore use this parser instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser('Training Neural PDE Solvers')

## training
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--epochs', type=int, default=500, help='maximum epochs')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='optimizer weight decay')
parser.add_argument('--pct_start', type=float, default=0.3, help='pct_start of the OneCycleLR schedule')
parser.add_argument('--batch-size', type=int, default=8, help='batch size')
parser.add_argument("--gpu", type=str, default='0', help="GPU index to use")
parser.add_argument('--max_grad_norm', type=float, default=None, help='make the training stable')
parser.add_argument('--derivloss', type=str2bool, default=False,
                    help='adopt the spatial derivative as regularization')
parser.add_argument('--teacher_forcing', type=int, default=1,
                    help='adopt teacher forcing in autoregressive to speed up convergence')
parser.add_argument('--optimizer', type=str, default='AdamW', help='optimizer type, select from Adam, AdamW')
parser.add_argument('--scheduler', type=str, default='OneCycleLR',
                    help='learning rate scheduler, select from [OneCycleLR, CosineAnnealingLR, StepLR]')
parser.add_argument('--step_size', type=int, default=100, help='step size for StepLR scheduler')
parser.add_argument('--gamma', type=float, default=0.5, help='decay parameter for StepLR scheduler')

## data
parser.add_argument('--data_path', type=str, default='/data/fno/', help='data folder')
parser.add_argument('--loader', type=str, default='airfoil', help='type of data loader')
parser.add_argument('--train_ratio', type=float, default=0.8, help='training data ratio')
parser.add_argument('--ntrain', type=int, default=1000, help='training data numbers')
parser.add_argument('--ntest', type=int, default=200, help='test data numbers')
parser.add_argument('--normalize', type=str2bool, default=False, help='make normalization to output')
parser.add_argument('--norm_type', type=str, default='UnitTransformer',
                    help='dataset normalize type. select from [UnitTransformer, UnitGaussianNormalizer]')
parser.add_argument('--geotype', type=str, default='unstructured',
                    help='select from [unstructured, structured_1D, structured_2D, structured_3D]')
parser.add_argument('--time_input', type=str2bool, default=False, help='for conditional dynamic task')
parser.add_argument('--space_dim', type=int, default=2, help='position information dimension')
parser.add_argument('--fun_dim', type=int, default=0, help='input observation dimension')
parser.add_argument('--out_dim', type=int, default=1, help='output observation dimension')
# `type=list` would split a CLI string into characters; parse integers instead.
parser.add_argument('--shapelist', type=int, nargs='+', default=None, help='for structured geometry')
parser.add_argument('--downsamplex', type=int, default=1, help='downsample rate in x-axis')
parser.add_argument('--downsampley', type=int, default=1, help='downsample rate in y-axis')
parser.add_argument('--downsamplez', type=int, default=1, help='downsample rate in z-axis')
parser.add_argument('--radius', type=float, default=0.2, help='for constructing geometry')

## task
parser.add_argument('--task', type=str, default='steady',
                    help='select from [steady, steady_design, dynamic_autoregressive, dynamic_conditional]')
parser.add_argument('--T_in', type=int, default=10, help='for input sequence')
parser.add_argument('--T_out', type=int, default=10, help='for output sequence')

## models
parser.add_argument('--model', type=str, default='Transolver')
parser.add_argument('--n_hidden', type=int, default=64, help='hidden dim')
parser.add_argument('--n_layers', type=int, default=3, help='layers')
parser.add_argument('--n_heads', type=int, default=4, help='number of heads')
parser.add_argument('--act', type=str, default='gelu')
parser.add_argument('--mlp_ratio', type=int, default=1, help='mlp ratio for feedforward layers')
parser.add_argument('--dropout', type=float, default=0.0, help='dropout')
parser.add_argument('--unified_pos', type=int, default=0, help='for unified position embedding')
parser.add_argument('--ref', type=int, default=8, help='number of reference points for unified pos embedding')

## model specific configuration
parser.add_argument('--slice_num', type=int, default=32, help='number of physical states for Transolver')
parser.add_argument('--modes', type=int, default=12, help='number of basis functions for LSM and FNO')
parser.add_argument('--psi_dim', type=int, default=8, help='number of psi_dim for ONO')
parser.add_argument('--attn_type', type=str, default='nystrom',
                    help='attn_type for ONO, select from nystrom, linear, selfAttention')
parser.add_argument('--mwt_k', type=int, default=3, help='number of wavelet basis functions for MWT')

## eval
parser.add_argument('--eval', type=int, default=0, help='evaluation or not')
parser.add_argument('--save_name', type=str, default='Transolver_check', help='name of folders')
parser.add_argument('--vis_num', type=int, default=10, help='number of visualization cases')
parser.add_argument('--vis_bound', type=int, nargs='+', default=None, help='size of region for visualization, in list')

args = parser.parse_args()
eval = args.eval  # NOTE: shadows the builtin ``eval``; kept for backward compatibility
save_name = args.save_name
# Set before any experiment module (and hence CUDA) is imported in main().
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
81 |
82 |
def main():
    """Build the experiment for ``args.task``, then evaluate or train-then-test.

    Each experiment class is imported inside its branch, so only the selected
    task's module is loaded. With ``--eval`` set, only ``exp.test()`` runs.
    Otherwise ``exp.train()`` runs first, followed by ``exp.test()``.

    Raises:
        NotImplementedError: if ``args.task`` is not a supported task name.
    """
    if args.task == 'steady':
        from exp.exp_steady import Exp_Steady
        exp = Exp_Steady(args)
    elif args.task == 'steady_design':
        from exp.exp_steady_design import Exp_Steady_Design
        exp = Exp_Steady_Design(args)
    elif args.task == 'dynamic_autoregressive':
        from exp.exp_dynamic_autoregressive import Exp_Dynamic_Autoregressive
        exp = Exp_Dynamic_Autoregressive(args)
    elif args.task == 'dynamic_conditional':
        from exp.exp_dynamic_conditional import Exp_Dynamic_Conditional
        exp = Exp_Dynamic_Conditional(args)
    else:
        raise NotImplementedError(f"Unsupported task: {args.task!r}")

    if args.eval:
        exp.test()
    else:
        exp.train()
        exp.test()


if __name__ == "__main__":
    main()
108 |
--------------------------------------------------------------------------------
/scripts/DesignBench/car/GNOT.sh:
--------------------------------------------------------------------------------
# Train, then test, GNOT on the DesignBench car dataset (steady_design task).
python run.py \
  --gpu 1 \
  --data_path /data/PDE_data/mlcfd_data/ --loader car_design --geotype unstructured \
  --task steady_design \
  --space_dim 3 --fun_dim 7 --out_dim 4 \
  --model GNOT --n_hidden 256 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 32 \
  --unified_pos 0 --ref 8 \
  --batch-size 1 --epochs 200 \
  --eval 0 \
  --save_name car_design_GNOT
--------------------------------------------------------------------------------
/scripts/DesignBench/car/GraphSAGE.sh:
--------------------------------------------------------------------------------
# Train, then test, GraphSAGE on the DesignBench car dataset (steady_design task).
python run.py \
  --gpu 2 \
  --data_path /data/PDE_data/mlcfd_data/ --loader car_design --geotype unstructured \
  --task steady_design \
  --space_dim 3 --fun_dim 7 --out_dim 4 \
  --model GraphSAGE --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 32 \
  --unified_pos 0 --ref 8 \
  --batch-size 1 --epochs 200 \
  --eval 0 \
  --save_name car_design_GraphSAGE
--------------------------------------------------------------------------------
/scripts/DesignBench/car/GraphUNet.sh:
--------------------------------------------------------------------------------
# Train, then test, Graph U-Net on the DesignBench car dataset (steady_design task).
python run.py \
  --gpu 3 \
  --data_path /data/PDE_data/mlcfd_data/ --loader car_design --geotype unstructured \
  --task steady_design \
  --space_dim 3 --fun_dim 7 --out_dim 4 \
  --model Graph_UNet --n_hidden 16 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 32 \
  --unified_pos 0 --ref 8 \
  --batch-size 1 --epochs 200 \
  --eval 0 \
  --save_name car_design_GraphUNet
--------------------------------------------------------------------------------
/scripts/DesignBench/car/PointNet.sh:
--------------------------------------------------------------------------------
# Train, then test, PointNet on the DesignBench car dataset (steady_design task).
python run.py \
  --gpu 5 \
  --data_path /data/PDE_data/mlcfd_data/ --loader car_design --geotype unstructured \
  --task steady_design \
  --space_dim 3 --fun_dim 7 --out_dim 4 \
  --model PointNet --n_hidden 16 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 32 \
  --unified_pos 0 --ref 8 \
  --batch-size 1 --epochs 200 \
  --eval 0 \
  --save_name car_design_PointNet
--------------------------------------------------------------------------------
/scripts/DesignBench/car/Transolver.sh:
--------------------------------------------------------------------------------
# Train, then test, Transolver on the DesignBench car dataset (steady_design task).
python run.py \
  --gpu 6 \
  --data_path /data/PDE_data/mlcfd_data/ --loader car_design --geotype unstructured \
  --task steady_design \
  --space_dim 3 --fun_dim 7 --out_dim 4 \
  --model Transolver --n_hidden 256 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 32 \
  --unified_pos 0 --ref 8 \
  --batch-size 1 --epochs 200 \
  --eval 0 \
  --save_name car_design_Transolver
--------------------------------------------------------------------------------
/scripts/PDEBench/3DCFD/FNO.sh:
--------------------------------------------------------------------------------
# Train, then test, FNO autoregressively on the PDEBench 3D CFD data.
python run.py \
  --gpu 1 \
  --data_path /data/PDEBench/3D/3D/Train/3D_CFD_Rand_M1.0_Eta1e-08_Zeta1e-08_periodic_Train.hdf5 \
  --loader cfd3d --geotype structured_3D \
  --task dynamic_autoregressive --teacher_forcing 0 \
  --lr 0.0005 --weight_decay 1e-4 --scheduler StepLR \
  --space_dim 3 --fun_dim 50 --out_dim 5 \
  --downsamplex 2 --downsampley 2 --downsamplez 2 \
  --model FNO --n_hidden 20 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 5 --epochs 500 \
  --eval 0 \
  --save_name pdebench_FNO_3DCFD
--------------------------------------------------------------------------------
/scripts/PDEBench/3DCFD/U_Net.sh:
--------------------------------------------------------------------------------
# Train, then test, U-Net autoregressively on the PDEBench 3D CFD data.
python run.py \
  --gpu 1 \
  --data_path /data/PDEBench/3D/3D/Train/3D_CFD_Rand_M1.0_Eta1e-08_Zeta1e-08_periodic_Train.hdf5 \
  --loader cfd3d --geotype structured_3D \
  --task dynamic_autoregressive --teacher_forcing 0 \
  --lr 0.001 --weight_decay 1e-4 --scheduler StepLR \
  --space_dim 3 --fun_dim 50 --out_dim 5 \
  --downsamplex 2 --downsampley 2 --downsamplez 2 \
  --model U_Net --n_hidden 20 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 5 --epochs 500 \
  --eval 0 \
  --save_name pdebench_U_Net_3DCFD
--------------------------------------------------------------------------------
/scripts/PDEBench/darcy/FNO.sh:
--------------------------------------------------------------------------------
# Train, then test, FNO on PDEBench 2D Darcy flow (steady task).
python run.py \
  --gpu 1 \
  --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \
  --loader pdebench_steady_darcy --geotype structured_2D \
  --task steady --scheduler StepLR \
  --downsamplex 1 --downsampley 1 \
  --space_dim 2 --fun_dim 1 --out_dim 1 \
  --model FNO --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 1 --ref 8 \
  --batch-size 50 --epochs 500 \
  --eval 0 --ntrain 8000 \
  --save_name pdebench_darcy_FNO
--------------------------------------------------------------------------------
/scripts/PDEBench/darcy/MWT.sh:
--------------------------------------------------------------------------------
# Train, then test, MWT on PDEBench 2D Darcy flow (steady task).
python run.py \
  --gpu 0 \
  --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \
  --loader pdebench_steady_darcy --geotype structured_2D \
  --task steady --scheduler StepLR \
  --downsamplex 1 --downsampley 1 \
  --space_dim 2 --fun_dim 1 --out_dim 1 \
  --model MWT --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 50 --epochs 500 \
  --eval 0 --ntrain 8000 \
  --save_name pdebench_darcy_MWT
--------------------------------------------------------------------------------
/scripts/PDEBench/darcy/Transfomer.sh:
--------------------------------------------------------------------------------
# Train, then test, the Transformer on PDEBench 2D Darcy flow (steady task).
python run.py \
  --gpu 3 \
  --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \
  --loader pdebench_steady_darcy --geotype structured_2D \
  --task steady --scheduler StepLR \
  --downsamplex 1 --downsampley 1 \
  --space_dim 2 --fun_dim 1 --out_dim 1 \
  --model Transformer --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 50 --epochs 500 \
  --eval 0 --ntrain 8000 \
  --save_name pdebench_darcy_Transformer
--------------------------------------------------------------------------------
/scripts/PDEBench/darcy/Transolver.sh:
--------------------------------------------------------------------------------
# Train, then test, Transolver on PDEBench 2D Darcy flow (steady task).
python run.py \
  --gpu 2 \
  --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \
  --loader pdebench_steady_darcy --geotype structured_2D \
  --task steady --scheduler StepLR \
  --downsamplex 1 --downsampley 1 \
  --space_dim 2 --fun_dim 1 --out_dim 1 \
  --model Transolver --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 20 --epochs 500 \
  --eval 0 --ntrain 8000 \
  --save_name pdebench_darcy_Transolver
--------------------------------------------------------------------------------
/scripts/PDEBench/darcy/U_FNO.sh:
--------------------------------------------------------------------------------
# Train, then test, U-FNO on PDEBench 2D Darcy flow (steady task).
python run.py \
  --gpu 0 \
  --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \
  --loader pdebench_steady_darcy --geotype structured_2D \
  --task steady --scheduler StepLR \
  --downsamplex 1 --downsampley 1 \
  --space_dim 2 --fun_dim 1 --out_dim 1 \
  --model U_FNO --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 50 --epochs 500 \
  --eval 0 --ntrain 8000 \
  --save_name pdebench_darcy_U_FNO
--------------------------------------------------------------------------------
/scripts/PDEBench/darcy/U_NO.sh:
--------------------------------------------------------------------------------
# Train, then test, U-NO on PDEBench 2D Darcy flow (steady task).
python run.py \
  --gpu 1 \
  --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \
  --loader pdebench_steady_darcy --geotype structured_2D \
  --task steady --scheduler StepLR \
  --downsamplex 1 --downsampley 1 \
  --space_dim 2 --fun_dim 1 --out_dim 1 \
  --model U_NO --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 50 --epochs 500 \
  --eval 0 --ntrain 8000 \
  --save_name pdebench_darcy_U_NO
--------------------------------------------------------------------------------
/scripts/PDEBench/darcy/U_Net.sh:
--------------------------------------------------------------------------------
# Train, then test, U-Net on PDEBench 2D Darcy flow (steady task).
python run.py \
  --gpu 5 \
  --data_path /data/PDEBench/2D/DarcyFlow/2D_DarcyFlow_beta1.0_Train.hdf5 \
  --loader pdebench_steady_darcy --geotype structured_2D \
  --task steady --scheduler StepLR \
  --downsamplex 1 --downsampley 1 \
  --space_dim 2 --fun_dim 1 --out_dim 1 \
  --model U_Net --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 50 --epochs 500 \
  --eval 0 --ntrain 8000 \
  --save_name pdebench_darcy_U_Net
--------------------------------------------------------------------------------
/scripts/PDEBench/diff_sorp/FNO.sh:
--------------------------------------------------------------------------------
# Autoregressive FNO on PDEBench 1D diffusion-sorption.
# NOTE(review): --eval 1 makes run.py call only exp.test() (no training), unlike
# the sibling diff_sorp scripts, which pass --eval 0. Confirm this is intended.
python run.py \
  --gpu 1 \
  --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \
  --loader pdebench_autoregressive --geotype structured_1D \
  --task dynamic_autoregressive --teacher_forcing 0 \
  --lr 0.0005 --weight_decay 1e-4 --scheduler StepLR \
  --space_dim 1 --fun_dim 10 --out_dim 1 \
  --model FNO --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 20 --epochs 500 \
  --eval 1 \
  --save_name diff_sorp_FNO
--------------------------------------------------------------------------------
/scripts/PDEBench/diff_sorp/MWT.sh:
--------------------------------------------------------------------------------
# Train, then test, MWT autoregressively on PDEBench 1D diffusion-sorption.
python run.py \
  --gpu 1 \
  --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \
  --loader pdebench_autoregressive --geotype structured_1D \
  --task dynamic_autoregressive --teacher_forcing 0 \
  --lr 0.0005 --weight_decay 1e-4 --scheduler StepLR \
  --space_dim 1 --fun_dim 10 --out_dim 1 \
  --model MWT --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 50 --epochs 500 \
  --eval 0 --mwt_k 9 \
  --save_name diff_sorp_MWT
--------------------------------------------------------------------------------
/scripts/PDEBench/diff_sorp/Transolver.sh:
--------------------------------------------------------------------------------
# Train, then test, Transolver autoregressively on PDEBench 1D diffusion-sorption.
python run.py \
  --gpu 1 \
  --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \
  --loader pdebench_autoregressive --geotype structured_1D \
  --task dynamic_autoregressive --teacher_forcing 0 \
  --lr 0.0005 --weight_decay 1e-4 --scheduler StepLR \
  --space_dim 1 --fun_dim 10 --out_dim 1 \
  --model Transolver --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 20 --epochs 500 \
  --eval 0 \
  --save_name diff_sorp_Transolver
--------------------------------------------------------------------------------
/scripts/PDEBench/diff_sorp/U_FNO.sh:
--------------------------------------------------------------------------------
# Train, then test, U-FNO autoregressively on PDEBench 1D diffusion-sorption.
python run.py \
  --gpu 1 \
  --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \
  --loader pdebench_autoregressive --geotype structured_1D \
  --task dynamic_autoregressive --teacher_forcing 0 \
  --lr 0.0005 --weight_decay 1e-4 --scheduler StepLR \
  --space_dim 1 --fun_dim 10 --out_dim 1 \
  --model U_FNO --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 20 --epochs 500 \
  --eval 0 \
  --save_name diff_sorp_U_FNO
--------------------------------------------------------------------------------
/scripts/PDEBench/diff_sorp/U_NO.sh:
--------------------------------------------------------------------------------
# Train, then test, U-NO autoregressively on PDEBench 1D diffusion-sorption.
python run.py \
  --gpu 1 \
  --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \
  --loader pdebench_autoregressive --geotype structured_1D \
  --task dynamic_autoregressive --teacher_forcing 0 \
  --lr 0.0005 --weight_decay 1e-4 --scheduler StepLR \
  --space_dim 1 --fun_dim 10 --out_dim 1 \
  --model U_NO --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 20 --epochs 500 \
  --eval 0 \
  --save_name diff_sorp_U_NO
--------------------------------------------------------------------------------
/scripts/PDEBench/diff_sorp/U_Net.sh:
--------------------------------------------------------------------------------
# Train, then test, U-Net autoregressively on PDEBench 1D diffusion-sorption.
python run.py \
  --gpu 7 \
  --data_path /data/PDEBench/1D/diffusion-sorption/1D_diff-sorp_NA_NA.h5 \
  --loader pdebench_autoregressive --geotype structured_1D \
  --task dynamic_autoregressive --teacher_forcing 0 \
  --lr 0.001 --weight_decay 1e-4 --scheduler StepLR \
  --space_dim 1 --fun_dim 10 --out_dim 1 \
  --model U_Net --n_hidden 64 --n_heads 8 --n_layers 8 \
  --unified_pos 0 --ref 8 \
  --batch-size 20 --epochs 500 \
  --eval 0 \
  --save_name diff_sorp_U_Net
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/FNO.sh:
--------------------------------------------------------------------------------
# Train, then test, FNO on the airfoil benchmark (default steady task).
python run.py \
  --gpu 2 \
  --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
  --space_dim 2 --fun_dim 0 --out_dim 1 \
  --model FNO --n_hidden 32 --n_heads 8 --n_layers 8 --slice_num 64 \
  --unified_pos 0 --ref 8 \
  --batch-size 4 --epochs 500 \
  --vis_bound 40 180 0 35 \
  --eval 0 \
  --save_name airfoil_FNO
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/F_FNO.sh:
--------------------------------------------------------------------------------
# Train, then test, F-FNO on the airfoil benchmark (default steady task).
python run.py \
  --gpu 4 \
  --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
  --lr 0.001 --weight_decay 1e-4 --scheduler StepLR \
  --space_dim 2 --fun_dim 0 --out_dim 1 \
  --model F_FNO --n_hidden 32 --n_heads 8 --n_layers 8 --slice_num 64 \
  --unified_pos 0 --ref 8 \
  --batch-size 20 --epochs 500 \
  --vis_bound 40 180 0 35 \
  --eval 0 \
  --save_name airfoil_F_FNO
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/Factformer.sh:
--------------------------------------------------------------------------------
# Train, then test, Factformer on the airfoil benchmark (default steady task).
python run.py \
  --gpu 2 \
  --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
  --space_dim 2 --fun_dim 0 --out_dim 1 \
  --model Factformer --n_hidden 32 --n_heads 8 --n_layers 8 --slice_num 64 \
  --unified_pos 0 --ref 8 \
  --batch-size 4 --epochs 500 \
  --vis_bound 40 180 0 35 \
  --eval 0 \
  --save_name airfoil_Factformer
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/GNOT.sh:
--------------------------------------------------------------------------------
# Train, then test, GNOT on the airfoil benchmark (default steady task).
python run.py \
  --gpu 5 \
  --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
  --space_dim 2 --fun_dim 0 --out_dim 1 \
  --model GNOT --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
  --unified_pos 0 --ref 8 \
  --batch-size 4 --epochs 500 \
  --vis_bound 40 180 0 35 \
  --eval 0 \
  --save_name airfoil_GNOT
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/Galerkin_Transformer.sh:
--------------------------------------------------------------------------------
# Train, then test, the Galerkin Transformer on the airfoil benchmark (default steady task).
python run.py \
  --gpu 1 \
  --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
  --space_dim 2 --fun_dim 0 --out_dim 1 \
  --model Galerkin_Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
  --unified_pos 0 --ref 8 \
  --batch-size 4 --epochs 500 \
  --vis_bound 40 180 0 35 \
  --eval 0 \
  --save_name airfoil_Galerkin_Transformer
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/LSM.sh:
--------------------------------------------------------------------------------
# Train, then test, LSM on the airfoil benchmark (default steady task).
python run.py \
  --gpu 2 \
  --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
  --space_dim 2 --fun_dim 0 --out_dim 1 \
  --model LSM --n_hidden 32 --n_heads 8 --n_layers 8 --slice_num 64 \
  --unified_pos 0 --ref 8 \
  --batch-size 4 --epochs 500 \
  --vis_bound 40 180 0 35 \
  --eval 0 \
  --save_name airfoil_LSM
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/MWT.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / airfoil / MWT
2 | python run.py \
3 |   --gpu 2 --eval 0 --save_name airfoil_MWT \
4 |   --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model MWT --n_hidden 32 --n_heads 8 --n_layers 8 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --batch-size 4 --epochs 500 --vis_bound 40 180 0 35
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/Swin.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / airfoil / Swin_Transformer
2 | # --model must name the Swin_Transformer model (models/Swin_Transformer.py),
3 | # the same value darcy/Swin.sh and ns/Swin.sh pass; "Swin" matches no file under models/.
4 | python run.py \
5 | --gpu 0 \
6 | --data_path /data/fno/airfoil/naca \
7 | --loader airfoil \
8 | --geotype structured_2D \
9 | --space_dim 2 \
10 | --fun_dim 2 \
11 | --out_dim 1 \
12 | --model Swin_Transformer \
13 | --n_hidden 128 \
14 | --n_heads 8 \
15 | --n_layers 8 \
16 | --mlp_ratio 2 \
17 | --slice_num 64 \
18 | --unified_pos 0 \
19 | --ref 8 \
20 | --batch-size 4 \
21 | --epochs 500 \
22 | --vis_bound 40 180 0 35 \
23 | --eval 0 \
24 | --save_name airfoil_Swin
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/Transformer.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / airfoil / Transformer
2 | python run.py \
3 |   --gpu 1 --eval 0 --save_name airfoil_Transformer \
4 |   --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model Transformer --n_hidden 128 --n_heads 8 --n_layers 8 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --batch-size 4 --epochs 500 --vis_bound 40 180 0 35
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/Transolver.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / airfoil / Transolver
2 | python run.py \
3 |   --gpu 6 --eval 0 --save_name airfoil_Transolver \
4 |   --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
5 |   --space_dim 2 --fun_dim 2 --out_dim 1 \
6 |   --model Transolver --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --lr 0.001 --weight_decay 1e-4 --max_grad_norm 0.1 \
9 |   --batch-size 8 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/U_FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / airfoil / U_FNO
2 | python run.py \
3 |   --gpu 4 --eval 0 --save_name airfoil_U_FNO \
4 |   --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model U_FNO --n_hidden 32 --n_heads 8 --n_layers 8 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --lr 0.001 --weight_decay 1e-4 --scheduler StepLR \
9 |   --batch-size 20 --epochs 500 --vis_bound 40 180 0 35
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/U_NO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / airfoil / U_NO
2 | python run.py \
3 |   --gpu 4 --eval 0 --save_name airfoil_U_NO_AdamW \
4 |   --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model U_NO --n_hidden 32 --n_heads 8 --n_layers 8 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --lr 0.001 --weight_decay 1e-4 --scheduler StepLR \
9 |   --batch-size 20 --epochs 500 --vis_bound 40 180 0 35
--------------------------------------------------------------------------------
/scripts/StandardBench/airfoil/U_Net.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / airfoil / U_Net
2 | python run.py \
3 |   --gpu 2 --eval 0 --save_name airfoil_U_Net \
4 |   --data_path /data/fno/airfoil/naca --loader airfoil --geotype structured_2D \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model U_Net --n_hidden 32 --n_heads 8 --n_layers 8 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --batch-size 4 --epochs 500 --vis_bound 40 180 0 35
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / FNO
2 | # --eval 0 matches every other StandardBench training script.
3 | # NOTE(review): this script used --eval 1; if an evaluation-only darcy_FNO run
4 | # was intended, switch it back — confirm the flag's meaning in run.py.
5 | python run.py \
6 | --gpu 7 \
7 | --data_path /data/fno/ \
8 | --loader darcy \
9 | --geotype structured_2D \
10 | --task steady \
11 | --normalize 1 \
12 | --derivloss 1 \
13 | --downsamplex 5 \
14 | --downsampley 5 \
15 | --space_dim 2 \
16 | --fun_dim 1 \
17 | --out_dim 1 \
18 | --model FNO \
19 | --n_hidden 128 \
20 | --n_heads 8 \
21 | --n_layers 8 \
22 | --mlp_ratio 2 \
23 | --unified_pos 1 \
24 | --ref 8 \
25 | --batch-size 4 \
26 | --epochs 500 \
27 | --eval 0 \
28 | --save_name darcy_FNO
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/F_FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / F_FNO
2 | python run.py \
3 |   --gpu 0 --eval 0 --save_name darcy_F_FNO \
4 |   --data_path /data/fno --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model F_FNO --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 64 --unified_pos 1 --ref 8 \
9 |   --max_grad_norm 0.1 \
10 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/Factformer.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / Factformer
2 | python run.py \
3 |   --gpu 1 --eval 0 --save_name darcy_Factformer \
4 |   --data_path /data/fno/ --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model Factformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --unified_pos 1 --ref 8 \
9 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/GNOT.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / GNOT
2 | python run.py \
3 |   --gpu 7 --eval 0 --save_name darcy_GNOT \
4 |   --data_path /data/fno --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model GNOT --n_hidden 128 --n_heads 8 --n_layers 3 --mlp_ratio 2 \
8 |   --slice_num 64 --unified_pos 0 --ref 8 \
9 |   --weight_decay 0.00005 \
10 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/Galerkin_Transformer.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / Galerkin_Transformer
2 | python run.py \
3 |   --gpu 3 --eval 0 --save_name darcy_Galerkin_Transformer \
4 |   --data_path /data/fno --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model Galerkin_Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 64 --unified_pos 1 --ref 8 \
9 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/LSM.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / LSM
2 | python run.py \
3 |   --gpu 0 --eval 0 --save_name darcy_LSM \
4 |   --data_path /data/fno/ --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --norm_type UnitGaussianNormalizer --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model LSM --n_hidden 64 --n_heads 8 --n_layers 8 \
8 |   --slice_num 64 --unified_pos 0 --ref 8 \
9 |   --lr 0.0005 --weight_decay 1e-4 --optimizer Adam --scheduler StepLR \
10 |   --batch-size 20 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/MWT.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / MWT
2 | python run.py \
3 |   --gpu 1 --eval 0 --save_name darcy_MWT \
4 |   --data_path /data/fno --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model MWT --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 64 --unified_pos 1 --ref 8 \
9 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/ONO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / ONO
2 | python run.py \
3 |   --gpu 3 --eval 0 --save_name darcy_ONO \
4 |   --data_path /data/fno --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model ONO --n_hidden 128 --n_heads 8 --n_layers 10 --mlp_ratio 2 \
8 |   --slice_num 64 --psi_dim 32 --unified_pos 0 --ref 8 \
9 |   --max_grad_norm 0.1 \
10 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/Swin.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / Swin_Transformer
2 | python run.py \
3 |   --gpu 2 --eval 0 --save_name darcy_Swin_Transformer \
4 |   --data_path /data/fno --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model Swin_Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 64 --unified_pos 1 --ref 8 \
9 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/Transformer.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / Transformer
2 | python run.py \
3 |   --gpu 3 --eval 0 --save_name darcy_Transformer \
4 |   --data_path /data/fno --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model Transformer --n_hidden 32 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 64 --unified_pos 1 --ref 8 \
9 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/Transolver.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / Transolver
2 | python run.py \
3 |   --gpu 0 --eval 0 --save_name darcy_Transolver \
4 |   --data_path /data/fno --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model Transolver --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 64 --unified_pos 1 --ref 8 \
9 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/U_FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / U_FNO
2 | python run.py \
3 |   --gpu 5 --eval 0 --save_name darcy_U_FNO \
4 |   --data_path /data/fno/ --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model U_FNO --n_hidden 32 --n_heads 8 --n_layers 8 \
8 |   --slice_num 64 --unified_pos 1 --ref 8 \
9 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/U_NO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / U_NO
2 | python run.py \
3 |   --gpu 3 --eval 0 --save_name darcy_U_NO \
4 |   --data_path /data/fno/ --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model U_NO --n_hidden 32 --n_heads 8 --n_layers 8 \
8 |   --slice_num 64 --unified_pos 1 --ref 8 \
9 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/darcy/U_Net.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / darcy / U_Net
2 | python run.py \
3 |   --gpu 2 --eval 0 --save_name darcy_U_Net \
4 |   --data_path /data/fno --loader darcy --geotype structured_2D --task steady \
5 |   --normalize 1 --derivloss 1 --downsamplex 5 --downsampley 5 \
6 |   --space_dim 2 --fun_dim 1 --out_dim 1 \
7 |   --model U_Net --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 64 --unified_pos 1 --ref 8 \
9 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / FNO
2 | python run.py \
3 |   --gpu 1 --eval 0 --save_name elas_FNO \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model FNO --n_hidden 32 --n_heads 8 --n_layers 8 \
7 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/F_FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / F_FNO
2 | python run.py \
3 |   --gpu 1 --eval 0 --save_name elas_F_FNO \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model F_FNO --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --scheduler CosineAnnealingLR --max_grad_norm 0.1 \
9 |   --batch-size 1 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/GNOT.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / GNOT
2 | python run.py \
3 |   --gpu 1 --eval 0 --save_name elas_GNOT \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model GNOT --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --scheduler CosineAnnealingLR \
9 |   --batch-size 1 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/Galerkin_Transformer.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / Galerkin_Transformer
2 | python run.py \
3 |   --gpu 1 --eval 0 --save_name elas_Galerkin_Transformer \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model Galerkin_Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --scheduler CosineAnnealingLR \
9 |   --batch-size 1 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/LSM.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / LSM
2 | python run.py \
3 |   --gpu 0 --eval 0 --save_name elas_LSM \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model LSM --n_hidden 32 --n_heads 8 --n_layers 8 \
7 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/MWT.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / MWT
2 | python run.py \
3 |   --gpu 3 --eval 0 --save_name elas_MWT \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model MWT --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --scheduler CosineAnnealingLR \
9 |   --batch-size 1 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/Transformer.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / Transformer
2 | python run.py \
3 |   --gpu 1 --eval 0 --save_name elas_Transformer \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model Transformer --n_hidden 128 --n_heads 8 --n_layers 8 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --scheduler CosineAnnealingLR \
9 |   --batch-size 1 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/Transolver.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / Transolver
2 | python run.py \
3 |   --gpu 4 --eval 0 --save_name elas_Transolver \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model Transolver --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --scheduler CosineAnnealingLR \
9 |   --batch-size 1 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/U_FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / U_FNO
2 | python run.py \
3 |   --gpu 4 --eval 0 --save_name elas_U_FNO \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model U_FNO --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --scheduler CosineAnnealingLR \
9 |   --batch-size 1 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/U_NO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / U_NO
2 | python run.py \
3 |   --gpu 4 --eval 0 --save_name elas_U_NO \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model U_NO --n_hidden 32 --n_heads 8 --n_layers 8 \
7 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/elasticity/U_Net.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / elasticity / U_Net
2 | python run.py \
3 |   --gpu 4 --eval 0 --save_name elas_U_Net \
4 |   --data_path /data/fno/ --loader elas --geotype unstructured --normalize 1 \
5 |   --space_dim 2 --fun_dim 0 --out_dim 1 \
6 |   --model U_Net --n_hidden 32 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
7 |   --slice_num 64 --unified_pos 0 --ref 8 \
8 |   --scheduler CosineAnnealingLR \
9 |   --batch-size 1 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / FNO
2 | python run.py \
3 |   --gpu 5 --eval 0 --save_name ns_FNO_wo_teacher_forcing_wo_unipos_real_steplr \
4 |   --data_path /data/fno/ --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive --teacher_forcing 0 \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model FNO --n_hidden 64 --n_heads 8 --n_layers 8 \
8 |   --unified_pos 0 --ref 8 \
9 |   --lr 0.0005 --weight_decay 1e-4 --scheduler StepLR \
10 |   --batch-size 20 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/F_FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / F_FNO
2 | python run.py \
3 |   --gpu 1 --eval 0 --save_name ns_F_FNO \
4 |   --data_path /data/fno --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model F_FNO --n_hidden 20 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 32 --unified_pos 0 --ref 8 \
9 |   --lr 0.0025 --weight_decay 0.0001 --scheduler StepLR --max_grad_norm 0.1 \
10 |   --batch-size 20 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/Factformer.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / Factformer
2 | # --model must be spelled Factformer, matching models/Factformer.py and the
3 | # value darcy/Factformer.sh passes ("Factormer" matches no file under models/).
4 | python run.py \
5 | --gpu 2 \
6 | --data_path /data/fno \
7 | --loader ns \
8 | --geotype structured_2D \
9 | --task dynamic_autoregressive \
10 | --space_dim 2 \
11 | --fun_dim 10 \
12 | --out_dim 1 \
13 | --model Factformer \
14 | --n_hidden 256 \
15 | --n_heads 8 \
16 | --n_layers 8 \
17 | --mlp_ratio 2 \
18 | --slice_num 32 \
19 | --unified_pos 1 \
20 | --ref 8 \
21 | --batch-size 2 \
22 | --epochs 500 \
23 | --eval 0 \
24 | --save_name ns_Factformer
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/GNOT.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / GNOT
2 | python run.py \
3 |   --gpu 0 --eval 0 --save_name ns_GNOT \
4 |   --data_path /data/fno --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model GNOT --n_hidden 128 --n_heads 8 --n_layers 3 --mlp_ratio 2 \
8 |   --slice_num 32 --unified_pos 0 --ref 8 \
9 |   --weight_decay 0.00005 --optimizer AdamW \
10 |   --batch-size 4 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/Galerkin_Transformer.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / Galerkin_Transformer
2 | python run.py \
3 |   --gpu 3 --eval 0 --save_name ns_Galerkin_Transformer \
4 |   --data_path /data/fno --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model Galerkin_Transformer --n_hidden 256 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 32 --unified_pos 1 --ref 8 \
9 |   --batch-size 2 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/LSM.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / LSM
2 | python run.py \
3 |   --gpu 2 --eval 0 --save_name ns_LSM_AdamW \
4 |   --data_path /data/fno/ --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive --teacher_forcing 0 \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model LSM --n_hidden 64 --n_heads 8 --n_layers 8 \
8 |   --unified_pos 0 --ref 8 \
9 |   --lr 0.0005 --weight_decay 1e-4 --scheduler StepLR \
10 |   --batch-size 20 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/MWT.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / MWT
2 | python run.py \
3 |   --gpu 4 --eval 0 --save_name ns_MWT \
4 |   --data_path /data/fno --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model MWT --n_hidden 256 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 32 --unified_pos 1 --ref 8 \
9 |   --batch-size 2 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/ONO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / ONO
2 | python run.py \
3 |   --gpu 6 --eval 0 --save_name ns_ONO \
4 |   --data_path /data/fno --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive --teacher_forcing 0 \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model ONO --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 1 \
8 |   --attn_type selfAttention --psi_dim 16 --slice_num 32 --unified_pos 0 --ref 8 \
9 |   --scheduler OneCycleLR --max_grad_norm 0.1 \
10 |   --batch-size 8 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/Swin.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / Swin_Transformer
2 | python run.py \
3 |   --gpu 7 --eval 0 --save_name ns_Swin_Transformer \
4 |   --data_path /data/fno --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model Swin_Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 32 --unified_pos 0 --ref 8 \
9 |   --batch-size 2 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/Transformer.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / Transformer
2 | python run.py \
3 |   --gpu 0 --eval 0 --save_name ns_Transformer \
4 |   --data_path /data/fno --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model Transformer --n_hidden 256 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 32 --unified_pos 1 --ref 8 \
9 |   --batch-size 2 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/Transolver.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / Transolver
2 | python run.py \
3 |   --gpu 2 --eval 0 --save_name ns_Transolver \
4 |   --data_path /data/fno --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model Transolver --n_hidden 256 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 32 --unified_pos 1 --ref 8 \
9 |   --batch-size 2 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/U_FNO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / U_FNO
2 | python run.py \
3 |   --gpu 6 --eval 0 --save_name ns_U_FNO \
4 |   --data_path /data/fno/ --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive --teacher_forcing 0 \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model U_FNO --n_hidden 64 --n_heads 8 --n_layers 8 \
8 |   --unified_pos 0 --ref 8 \
9 |   --lr 0.0005 --weight_decay 1e-4 --scheduler StepLR \
10 |   --batch-size 20 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/U_NO.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / U_NO
2 | python run.py \
3 |   --gpu 4 --eval 0 --save_name ns_U_NO_Adam \
4 |   --data_path /data/fno/ --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive --teacher_forcing 0 \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model U_NO --n_hidden 64 --n_heads 8 --n_layers 8 \
8 |   --unified_pos 0 --ref 8 \
9 |   --optimizer Adam --scheduler StepLR \
10 |   --batch-size 16 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/ns/U_Net.sh:
--------------------------------------------------------------------------------
1 | # StandardBench / ns / U_Net
2 | python run.py \
3 |   --gpu 6 --eval 0 --save_name ns_U_Net \
4 |   --data_path /data/fno --loader ns --geotype structured_2D \
5 |   --task dynamic_autoregressive \
6 |   --space_dim 2 --fun_dim 10 --out_dim 1 \
7 |   --model U_Net --n_hidden 256 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
8 |   --slice_num 32 --unified_pos 1 --ref 8 \
9 |   --batch-size 2 --epochs 500
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/FNO.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model FNO | save_name pipe_FNO
python run.py \
    --gpu 4 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model FNO --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_FNO
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/F_FNO.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model F_FNO | gradient-norm clipping 0.1 | save_name pipe_F_FNO
python run.py \
    --gpu 6 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model F_FNO --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --max_grad_norm 0.1 --normalize 1 \
    --save_name pipe_F_FNO
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/Factformer.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model Factformer | save_name pipe_Factformer
# (This file previously held a darcy configuration; the pipe settings below
# match the other pipe scripts.)
python run.py \
    --gpu 1 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model Factformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --save_name pipe_Factformer
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/GNOT.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model GNOT | save_name pipe_GNOT
python run.py \
    --gpu 6 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model GNOT --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_GNOT
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/Galerkin_Transformer.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model Galerkin_Transformer | save_name pipe_Galerkin_Transformer
python run.py \
    --gpu 4 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model Galerkin_Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_Galerkin_Transformer
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/LSM.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model LSM (fun_dim 0, n_hidden 32) | save_name pipe_LSM
python run.py \
    --gpu 2 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 0 --out_dim 1 \
    --model LSM --n_hidden 32 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_LSM
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/MWT.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model MWT | save_name pipe_MWT
python run.py \
    --gpu 6 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model MWT --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_MWT
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/ONO.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model ONO | OneCycleLR schedule, gradient-norm clipping 0.1
# save_name pipe_ONO. --normalize 1 is passed exactly once.
python run.py \
    --gpu 0 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --scheduler OneCycleLR \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --normalize 1 --max_grad_norm 0.1 \
    --model ONO --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 \
    --batch-size 4 --epochs 500 --eval 0 \
    --save_name pipe_ONO
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/Swin.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model Swin_Transformer | save_name pipe_Swin_Transformer
python run.py \
    --gpu 3 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model Swin_Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_Swin_Transformer
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/Transformer.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model Transformer | save_name pipe_Transformer
python run.py \
    --gpu 4 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_Transformer
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/Transolver.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model Transolver | save_name pipe_Transolver
python run.py \
    --gpu 5 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model Transolver --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_Transolver
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/U_FNO.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model U_FNO | save_name pipe_U_FNO
python run.py \
    --gpu 0 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model U_FNO --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_U_FNO
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/U_NO.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model U_NO | save_name pipe_U_NO
python run.py \
    --gpu 6 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model U_NO --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_U_NO
--------------------------------------------------------------------------------
/scripts/StandardBench/pipe/U_Net.sh:
--------------------------------------------------------------------------------
# pipe benchmark | model U_Net | save_name pipe_U_Net
python run.py \
    --gpu 2 --data_path /data/fno/pipe \
    --loader pipe --geotype structured_2D \
    --space_dim 2 --fun_dim 2 --out_dim 1 \
    --model U_Net --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --normalize 1 \
    --save_name pipe_U_Net
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/FNO.sh:
--------------------------------------------------------------------------------
# plas benchmark | model FNO | task dynamic_conditional | save_name plas_FNO
python run.py \
    --gpu 0 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model FNO --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_FNO
25 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/F_FNO.sh:
--------------------------------------------------------------------------------
# plas benchmark | model F_FNO | task dynamic_conditional | gradient-norm clipping 0.1
# save_name plas_F_FNO
python run.py \
    --gpu 5 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model F_FNO --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 --max_grad_norm 0.1 \
    --save_name plas_F_FNO
26 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/Factformer.sh:
--------------------------------------------------------------------------------
# plas benchmark | model Factformer | task dynamic_conditional | save_name plas_Factformer
# (This file previously held the pipe_Factformer configuration. Data/task flags
# mirror the other plasticity scripts; model flags mirror the other Factformer
# scripts. Hyperparameters are a best guess — TODO confirm against reported runs.)
python run.py \
    --gpu 2 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model Factformer --n_hidden 128 --n_heads 8 --n_layers 8 --mlp_ratio 2 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_Factformer
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/GNOT.sh:
--------------------------------------------------------------------------------
# plas benchmark | model GNOT | task dynamic_conditional | save_name plas_GNOT
python run.py \
    --gpu 3 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model GNOT --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_GNOT
25 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/Galerkin_Transformer.sh:
--------------------------------------------------------------------------------
# plas benchmark | model Galerkin_Transformer | task dynamic_conditional
# save_name plas_Galerkin_Transformer
python run.py \
    --gpu 7 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model Galerkin_Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_Galerkin_Transformer
25 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/LSM.sh:
--------------------------------------------------------------------------------
# plas benchmark | model LSM (n_hidden 32, batch 4) | task dynamic_conditional
# save_name plas_LSM
python run.py \
    --gpu 3 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model LSM --n_hidden 32 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --save_name plas_LSM
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/MWT.sh:
--------------------------------------------------------------------------------
# plas benchmark | model MWT | task dynamic_conditional | save_name plas_MWT
python run.py \
    --gpu 7 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model MWT --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_MWT
25 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/ONO.sh:
--------------------------------------------------------------------------------
# plas benchmark | model ONO (n_hidden 32, batch 4) | task dynamic_conditional
# save_name plas_ONO
python run.py \
    --gpu 4 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model ONO --n_hidden 32 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 4 --epochs 500 --eval 0 \
    --save_name plas_ONO
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/Swin.sh:
--------------------------------------------------------------------------------
# plas benchmark | model Swin_Transformer | task dynamic_conditional
# save_name plas_Swin_Transformer
python run.py \
    --gpu 5 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model Swin_Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_Swin_Transformer
25 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/Transformer.sh:
--------------------------------------------------------------------------------
# plas benchmark | model Transformer | task dynamic_conditional | save_name plas_Transformer
python run.py \
    --gpu 7 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model Transformer --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_Transformer
25 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/Transolver.sh:
--------------------------------------------------------------------------------
# plas benchmark | model Transolver | task dynamic_conditional | save_name plas_Transolver
python run.py \
    --gpu 7 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model Transolver --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_Transolver
25 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/U_FNO.sh:
--------------------------------------------------------------------------------
# plas benchmark | model U_FNO | task dynamic_conditional | save_name plas_U_FNO
python run.py \
    --gpu 1 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model U_FNO --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_U_FNO
25 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/U_NO.sh:
--------------------------------------------------------------------------------
# plas benchmark | model U_NO | task dynamic_conditional | save_name plas_U_NO_test
python run.py \
    --gpu 5 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model U_NO --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_U_NO_test
25 |
--------------------------------------------------------------------------------
/scripts/StandardBench/plasticity/U_Net.sh:
--------------------------------------------------------------------------------
# plas benchmark | model U_Net | task dynamic_conditional | save_name plas_U_Net
python run.py \
    --gpu 1 --data_path /data/fno/ \
    --loader plas --geotype structured_2D \
    --task dynamic_conditional --ntrain 900 --ntest 80 --T_out 20 --time_input 1 \
    --space_dim 2 --fun_dim 1 --out_dim 4 \
    --model U_Net --n_hidden 128 --n_heads 8 --n_layers 8 --slice_num 64 \
    --unified_pos 0 --ref 8 \
    --batch-size 8 --epochs 500 --eval 0 \
    --save_name plas_U_Net
25 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange
4 |
5 |
class L2Loss(object):
    """Per-sample Lp losses for batched tensors (batch along dim 0).

    Every sample is flattened to a vector before its norm is taken.
    ``reduction`` chooses between returning the per-sample tensor (False)
    and collapsing it to a scalar; ``size_average`` selects mean (True) or
    sum (False) for that collapse. Calling the instance evaluates ``rel``.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super().__init__()

        # Both the exponent d and the norm order p must be positive.
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def _reduce(self, per_sample):
        """Apply the configured reduction to a 1-D tensor of per-sample losses."""
        if not self.reduction:
            return per_sample
        return torch.mean(per_sample) if self.size_average else torch.sum(per_sample)

    def abs(self, x, y):
        """Absolute Lp error per sample, scaled by h ** (d / p).

        h is 1 / (x.size(1) - 1); a size of 1 along dim 1 raises
        ZeroDivisionError.
        """
        batch = x.size()[0]
        spacing = 1.0 / (x.size()[1] - 1.0)
        # ``view`` (not ``reshape``): inputs are expected to be contiguous.
        residual = x.view(batch, -1) - y.view(batch, -1)
        per_sample = (spacing ** (self.d / self.p)) * torch.norm(residual, self.p, 1)
        return self._reduce(per_sample)

    def rel(self, x, y):
        """Relative Lp error ||x - y||_p / ||y||_p per sample."""
        batch = x.size()[0]
        flat_y = y.reshape(batch, -1)
        residual = x.reshape(batch, -1) - flat_y
        ratio = torch.norm(residual, self.p, 1) / torch.norm(flat_y, self.p, 1)
        return self._reduce(ratio)

    def __call__(self, x, y):
        return self.rel(x, y)
48 |
49 |
class DerivLoss(object):
    """Relative-L2 loss on central-difference gradients of 2-D grid fields.

    Inputs are flattened grids of shape (batch, h * w, channels), where
    ``shapelist = (h, w)`` gives the grid size. The prediction and the target
    are both differentiated along each grid axis. The per-axis gradients are
    compared with ``L2Loss`` (relative error by default when called), and
    the two axis losses are summed.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True, shapelist=None):
        super(DerivLoss, self).__init__()

        assert d > 0 and p > 0
        # (h, w) grid size used to unflatten the point dimension; __call__ needs it.
        self.shapelist = shapelist
        self.de_x = L2Loss(d=d, p=p, size_average=size_average, reduction=reduction)
        self.de_y = L2Loss(d=d, p=p, size_average=size_average, reduction=reduction)

    def central_diff(self, x, h1, h2, s1, s2):
        """Central-difference gradients of a flattened (batch, s1 * s2, c) field.

        Args:
            x: tensor of shape (batch, s1 * s2, c), row-major over an (s1, s2) grid.
            h1: grid spacing along the first grid axis (length s1).
            h2: grid spacing along the second grid axis (length s2).
            s1: number of grid points along the first axis.
            s2: number of grid points along the second axis.

        Returns:
            ``(grad_x, grad_y)``, each of shape (batch, s1, s2, c). ``grad_x``
            is the derivative along the second axis (length s2), and ``grad_y``
            along the first axis (length s1).

        The grid is padded with a ring of zeros (not periodically), so
        boundary cells are differenced against zero ghost values.
        """
        x = rearrange(x, 'b (h w) c -> b h w c', h=s1, w=s2)
        # Pad the two grid axes by one cell each (channel axis untouched):
        # result is (b, s1 + 2, s2 + 2, c).
        x = F.pad(x, (0, 0, 1, 1, 1, 1), mode='constant', value=0.)
        # (f(x+h) - f(x-h)) / 2h, each axis divided by its OWN spacing.
        grad_x = (x[:, 1:-1, 2:, :] - x[:, 1:-1, :-2, :]) / (2 * h2)
        grad_y = (x[:, 2:, 1:-1, :] - x[:, :-2, 1:-1, :]) / (2 * h1)

        return grad_x, grad_y

    def __call__(self, out, y):
        """Sum of per-axis gradient losses between prediction ``out`` and target ``y``.

        Both tensors are (batch, h * w, c) with (h, w) = ``self.shapelist``.
        The outermost ring of grid cells of ``out`` is overwritten with zeros
        before differentiation; ``y`` is used as given. (Presumably this is
        because the target has a fixed zero boundary — TODO confirm.)
        """
        s1, s2 = self.shapelist[0], self.shapelist[1]
        h1, h2 = 1.0 / float(s1), 1.0 / float(s2)

        out = rearrange(out, 'b (h w) c -> b c h w', h=s1, w=s2)
        # Keep the interior only, then re-pad with zeros: zeroes the boundary ring.
        out = out[..., 1:-1, 1:-1].contiguous()
        out = F.pad(out, (1, 1, 1, 1), "constant", 0)
        out = rearrange(out, 'b c h w -> b (h w) c')

        gt_grad_x, gt_grad_y = self.central_diff(y, h1, h2, s1, s2)
        pred_grad_x, pred_grad_y = self.central_diff(out, h1, h2, s1, s2)
        deriv_loss = self.de_x(pred_grad_x, gt_grad_x) + self.de_y(pred_grad_y, gt_grad_y)
        return deriv_loss
82 |
--------------------------------------------------------------------------------
/utils/normalizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import *
3 |
4 |
class IdentityTransformer():
    """No-op normalizer with the same interface as ``UnitTransformer``.

    ``encode`` and ``decode`` return their input unchanged. The per-feature
    mean and std over dim 0 (std offset by 1e-8) are still computed and
    stored on the instance.
    """

    def __init__(self, X):
        self.mean = X.mean(dim=0, keepdim=True)
        self.std = X.std(dim=0, keepdim=True) + 1e-8

    def to(self, device):
        """Move the stored statistics to ``device`` and return ``self``."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        return self

    def cuda(self):
        """Move the stored statistics to CUDA and return ``self`` (like ``to``)."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        return self

    def cpu(self):
        """Move the stored statistics to the CPU and return ``self`` (like ``to``)."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self

    def encode(self, x):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
28 |
29 |
class UnitTransformer():
    """Standardize features with statistics pooled over the first two dims.

    ``mean`` and ``std`` are reduced over dims 0 and 1 with ``keepdim=True``
    (std offset by 1e-8 to avoid division by zero). For a 3-D ``X`` of shape
    (N, P, C) they therefore have shape (1, 1, C), with the feature axis last.
    """

    def __init__(self, X):
        self.mean = X.mean(dim=(0, 1), keepdim=True)
        self.std = X.std(dim=(0, 1), keepdim=True) + 1e-8

    def to(self, device):
        """Move the stored statistics to ``device`` and return ``self``."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        return self

    def cuda(self):
        """Move the stored statistics to CUDA and return ``self`` (like ``to``)."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        return self

    def cpu(self):
        """Move the stored statistics to the CPU and return ``self`` (like ``to``)."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self

    def encode(self, x):
        """Return ``(x - mean) / std``."""
        x = (x - self.mean) / (self.std)
        return x

    def decode(self, x):
        """Inverse of ``encode``: return ``x * std + mean``."""
        return x * self.std + self.mean

    def transform(self, X, inverse=True, component='all'):
        """Normalize (``inverse=False``) or denormalize (``inverse=True``) ``X``.

        Args:
            X: tensor to transform.
            inverse: if True, map normalized values back to the original scale.
            component: ``'all'`` or ``'all-reduce'`` to use the full statistics.
                Any other value indexes the last (feature) axis of the
                statistics, so a single feature can be transformed.

        The inverse path subtracts the 1e-8 that ``__init__`` added to ``std``.
        """
        # The previous test `component == 'all' or 'all-reduce'` was always
        # true, so per-component requests silently used the full statistics.
        if component in ('all', 'all-reduce'):
            mean, std = self.mean, self.std
        else:
            # The feature axis of the statistics is the last one.
            mean, std = self.mean[..., component], self.std[..., component]

        if inverse:
            orig_shape = X.shape
            return (X * (std - 1e-8) + mean).view(orig_shape)
        return (X - mean) / std
68 |
69 |
class UnitGaussianNormalizer(object):
    """Element-wise standardization using the mean/std of ``x`` over dim 0.

    ``eps`` is added to ``std`` whenever it is used as a divisor or scale.
    ``time_last`` controls how ``decode`` selects statistics when a spatial
    ``sample_idx`` is given.
    """

    def __init__(self, x, eps=0.00001, time_last=True):
        super(UnitGaussianNormalizer, self).__init__()

        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        self.eps = eps
        self.time_last = time_last  # if the time dimension is the last dim

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        x = (x - self.mean) / (self.std + self.eps)
        return x

    def decode(self, x, sample_idx=None):
        """Inverse of ``encode``, optionally restricted to sampled points.

        Args:
            x: normalized tensor. Per the original note, shaped
                batch*(spatial size) or T*batch*(spatial size) — TODO confirm.
            sample_idx: optional index into the statistics (the spatial
                sampling mask). Statistics are indexed on their leading axis
                when ``sample_idx`` has as many dims as ``mean`` or when
                ``time_last`` is set. Otherwise, if ``mean`` has more dims,
                they are indexed on the trailing axis.

        Raises:
            ValueError: if ``sample_idx`` has more dims than ``mean`` while
                ``time_last`` is False (previously an UnboundLocalError).
        """
        if sample_idx is None:
            std = self.std + self.eps
            mean = self.mean
        elif self.mean.ndim == sample_idx.ndim or self.time_last:
            std = self.std[sample_idx] + self.eps
            mean = self.mean[sample_idx]
        elif self.mean.ndim > sample_idx.ndim:
            std = self.std[..., sample_idx] + self.eps
            mean = self.mean[..., sample_idx]
        else:
            raise ValueError(
                f"sample_idx has {sample_idx.ndim} dims but the statistics have "
                f"{self.mean.ndim} and time_last is False; cannot select statistics"
            )
        x = (x * std) + mean
        return x

    def to(self, device):
        """Move statistics (tensor or numpy array) to ``device``; return ``self``."""
        if torch.is_tensor(self.mean):
            self.mean = self.mean.to(device)
            self.std = self.std.to(device)
        else:
            self.mean = torch.from_numpy(self.mean).to(device)
            self.std = torch.from_numpy(self.std).to(device)
        return self

    def cuda(self):
        """Move the stored statistics to CUDA and return ``self`` (like ``to``)."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        return self

    def cpu(self):
        """Move the stored statistics to the CPU and return ``self`` (like ``to``)."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self
--------------------------------------------------------------------------------