├── .gitignore
├── Airfoil-Design-AirfRANS
│   ├── LICENSE
│   ├── README.md
│   ├── dataset
│   │   ├── dataset.py
│   │   └── dataset_stats.ipynb
│   ├── fig
│   │   ├── results.png
│   │   └── task.png
│   ├── main.py
│   ├── main_evaluation.py
│   ├── models
│   │   ├── GUNet.py
│   │   ├── GraphSAGE.py
│   │   ├── MLP.py
│   │   ├── NN.py
│   │   ├── PointNet.py
│   │   └── Transolver.py
│   ├── params.yaml
│   ├── requirements.txt
│   ├── scripts
│   │   ├── Evaluation.sh
│   │   ├── GraphSAGE.sh
│   │   └── Transolver.sh
│   ├── train.py
│   └── utils
│       ├── metrics.py
│       ├── metrics_NACA.py
│       ├── naca_generator.py
│       └── reorganize.py
├── Car-Design-ShapeNetCar
│   ├── README.md
│   ├── dataset
│   │   ├── dataset.py
│   │   └── load_dataset.py
│   ├── fig
│   │   ├── car_slice_surf.png
│   │   ├── case_study.png
│   │   ├── results.png
│   │   └── task.png
│   ├── main.py
│   ├── main_evaluation.py
│   ├── models
│   │   └── Transolver.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── Evaluation.sh
│   │   └── Transolver.sh
│   ├── train.py
│   └── utils
│       └── drag_coefficient.py
├── LICENSE
├── PDE-Solving-StandardBenchmark
│   ├── README.md
│   ├── exp_airfoil.py
│   ├── exp_darcy.py
│   ├── exp_elas.py
│   ├── exp_ns.py
│   ├── exp_pipe.py
│   ├── exp_plas.py
│   ├── fig
│   │   ├── scalibility.png
│   │   ├── showcase.png
│   │   └── standard_benchmark.png
│   ├── model
│   │   ├── Embedding.py
│   │   ├── Physics_Attention.py
│   │   ├── Transolver_Irregular_Mesh.py
│   │   ├── Transolver_Structured_Mesh_2D.py
│   │   └── Transolver_Structured_Mesh_3D.py
│   ├── model_dict.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── Transolver_Airfoil.sh
│   │   ├── Transolver_Darcy.sh
│   │   ├── Transolver_Elas.sh
│   │   ├── Transolver_NS.sh
│   │   ├── Transolver_Pipe.sh
│   │   └── Transolver_Plas.sh
│   └── utils
│       ├── normalizer.py
│       └── testloss.py
├── Physics_Attention.py
├── README.md
└── pic
    ├── Transolver.png
    ├── physical_states.png
    └── showcases.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/README.md:
--------------------------------------------------------------------------------
1 | # Transolver for Airfoil Design
2 |
3 | **Paper Correction:** The paper misstates the evaluation metric for the volume and surface physics fields. The reported results are MSE (not Relative L2), the same metric used in the AirfRANS paper. We sincerely apologize for this mistake.
4 |
5 | We test [Transolver](https://arxiv.org/abs/2402.02366) on practical design tasks. The airfoil design task requires the model to estimate the surrounding and surface physical quantities of a 2D airfoil under different Reynolds numbers and angles of attack.
6 |
7 |
8 |
9 |
10 | Figure 1. Airfoil design task. Left: surrounding pressure; Right: x-direction wind speed.
11 |
12 |
13 | ## Get Started
14 |
15 | This part of the code is based on [[AirfRANS]](https://github.com/Extrality/AirfRANS).
16 |
17 | 1. Install Python 3.8, then install the dependencies with the following command.
18 |
19 | ```bash
20 | pip install -r requirements.txt
21 | ```
22 |
23 | Note: You need to install [pytorch_geometric](https://github.com/pyg-team/pytorch_geometric).
24 |
25 | 2. Prepare Data.
26 |
27 | The experiment data is provided by [[AirfRANS]](https://github.com/Extrality/AirfRANS). You can directly download it with this [link](https://data.isir.upmc.fr/extrality/NeurIPS_2022/Dataset.zip) (9.3GB).
28 |
29 | 3. Train and evaluate the model. We provide the experiment scripts under the folder `./scripts/`. You can reproduce the experiment results with the following commands:
30 |
31 | ```bash
32 | bash scripts/Transolver.sh # for Training Transolver (will take 20-24 hours on one single A100)
33 | bash scripts/Evaluation.sh # for Evaluation
34 | bash scripts/GraphSAGE.sh # for Training GraphSAGE (will take 30-36 hours on one single A100)
35 | ```
36 |
37 | Note: You need to change the argument `--my_path` to your dataset path.
38 |
39 | 4. Test the model under different settings. This benchmark supports four settings, selected with the `-t` argument.
40 |
41 | | Settings | Argument |
42 | | -------------------------------------------- | ------------- |
43 | | Use full data | `-t full` |
44 | | Use scarce data | `-t scarce` |
45 | | Test on out-of-distribution Reynolds numbers | `-t reynolds` |
46 | | Test on out-of-distribution angles of attack | `-t aoa` |
47 |
48 | 5. Develop your own model. Here are the instructions:
49 |
50 | - Add the model file under folder `./models/`.
51 |
52 | - Add the training details in `./params.yaml`. If you do not want to change the settings, copy another model's configuration.
53 |
54 | - Add the model configuration into `./main.py`.
55 |
56 | - Add a script file under folder `./scripts/` and change the argument `--model`.
57 |
58 | ## Main Results
59 |
60 | Transolver consistently achieves the best performance in practical design tasks.
61 |
62 |
63 |
64 |
65 | Table 1. Model comparisons on the practical design tasks.
66 |
67 |
68 | ## Citation
69 |
70 | If you find this repo useful, please cite our paper.
71 |
72 | ```
73 | @inproceedings{wu2024Transolver,
74 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries},
75 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long},
76 | booktitle={International Conference on Machine Learning},
77 | year={2024}
78 | }
79 | ```
80 |
81 | ## Contact
82 |
83 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn).
84 |
85 | ## Acknowledgement
86 |
87 | We greatly appreciate the following GitHub repository for its valuable code base and dataset:
88 |
89 | https://github.com/Extrality/AirfRANS
90 |
--------------------------------------------------------------------------------
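The four `-t` settings in the README correspond to split names in the dataset's `manifest.json`. The following is a minimal sketch of that mapping, mirroring the lookup done in `main.py`; the helper name `split_names` is hypothetical and not part of the repository.

```python
def split_names(task: str) -> tuple:
    """Return the (train, test) manifest keys for a benchmark task."""
    valid = ("full", "scarce", "reynolds", "aoa")
    if task not in valid:
        raise ValueError(f"unknown task {task!r}; expected one of {valid}")
    # The scarce task trains on its own subset but is scored on the full test set.
    test = "full_test" if task == "scarce" else f"{task}_test"
    return f"{task}_train", test


print(split_names("scarce"))  # ('scarce_train', 'full_test')
```

Validating the task name up front gives a clearer error than the `KeyError` that an unknown key would raise when indexing the manifest.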
/Airfoil-Design-AirfRANS/dataset/dataset_stats.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import json\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "import seaborn as sns\n",
13 | "import os.path as osp\n",
14 | "\n",
15 | "NU = 1.56e-5\n",
16 | "\n",
17 | "sets = ['full_train', 'scarce_train', 'reynolds_train', 'aoa_train', 'full_test', 'reynolds_test', 'aoa_test']\n",
18 | "colors = ['cornflowerblue']*4 + ['burlywood']*3\n",
19 | "data_dir = 'MY_ROOT_DIRECTORY'\n",
20 | "\n",
21 | "for c, s in zip(colors, sets):\n",
22 | " with open(osp.join(data_dir, 'manifest.json'), 'r') as f:\n",
23 | " manifest = json.load(f)[s]\n",
24 | "\n",
25 | " us = []\n",
26 | " angles = []\n",
27 | " digits4 = []\n",
28 | " digits5 = []\n",
29 | " for sim in manifest:\n",
30 | " params = sim.split('_')\n",
31 | " us.append(float(params[2])/NU)\n",
32 | " angles.append(float(params[3]))\n",
33 | "\n",
34 | " if len(params) == 7:\n",
35 | " digits4.append(list(map(float, params[-3:])))\n",
36 | " else:\n",
37 | " digits5.append(list(map(float, params[-4:])))\n",
38 | "\n",
39 | " digits4 = np.array(digits4)\n",
40 | " digits5 = np.array(digits5)\n",
41 | "\n",
42 | " sns.set()\n",
43 | "\n",
44 | " fig, ax = plt.subplots(3, 3, figsize = (12, 12))\n",
45 | " ax[2, 1].hist(us, bins = 20, color = c)\n",
46 | " ax[2, 1].set_title('Reynolds number')\n",
47 | "\n",
48 | " ax[2, 2].hist(angles, bins = 20, color = c)\n",
49 | " ax[2, 2].set_xlabel('Degrees')\n",
50 | " ax[2, 2].set_title('Angle of attack')\n",
51 | "\n",
52 | " ax[0, 0].hist(digits4[:, 0], bins = 20, color = c)\n",
53 | " ax[0, 0].set_title(r'$1^{st}$ digit')\n",
54 | "\n",
55 | " ax[0, 1].hist(digits4[:, 1], bins = 20, color = c)\n",
56 | " ax[0, 1].set_title(r'$2^{nd}$ digit')\n",
57 | "\n",
58 | " ax[0, 2].hist(digits4[:, 2], bins = 20, color = c)\n",
59 | " ax[0, 2].set_title(r'$3^{rd}$ and $4^{th}$ digits')\n",
60 | "\n",
61 | " ax[1, 0].hist(digits5[:, 0], bins = 20, color = c)\n",
62 | " ax[1, 0].set_title(r'$1^{st}$ digit')\n",
63 | "\n",
64 | " ax[1, 1].hist(digits5[:, 1], bins = 20, color = c)\n",
65 | " ax[1, 1].set_title(r'$2^{nd}$ digit')\n",
66 | "\n",
67 | " ax[2, 0].hist(digits5[:, 2], bins = 2, color = c)\n",
68 | " ax[2, 0].set_title(r'$3^{rd}$ digit')\n",
69 | "\n",
70 | " ax[1, 2].hist(digits5[:, 3], bins = 20, color = c)\n",
71 | " ax[1, 2].set_title(r'$4^{th}$ and $5^{th}$ digits');\n",
72 | " fig.savefig(s, bbox_inches = 'tight', dpi = 150)"
73 | ]
74 | }
75 | ],
76 | "metadata": {
77 | "kernelspec": {
78 | "display_name": "Python 3.9.12 ('isir')",
79 | "language": "python",
80 | "name": "python3"
81 | },
82 | "language_info": {
83 | "codemirror_mode": {
84 | "name": "ipython",
85 | "version": 3
86 | },
87 | "file_extension": ".py",
88 | "mimetype": "text/x-python",
89 | "name": "python",
90 | "nbconvert_exporter": "python",
91 | "pygments_lexer": "ipython3",
92 | "version": "3.9.12"
93 | },
94 | "orig_nbformat": 4,
95 | "vscode": {
96 | "interpreter": {
97 | "hash": "d00e44851a3a4d5201bc229183e4c0de3fea7314717b82800f8d82d2168b4a23"
98 | }
99 | }
100 | },
101 | "nbformat": 4,
102 | "nbformat_minor": 2
103 | }
104 |
--------------------------------------------------------------------------------
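The notebook above recovers the flow parameters from each simulation name by splitting on underscores. A small sketch of that parsing step follows. The function name `parse_simulation_name` and the example name are illustrative assumptions, and reading velocity divided by `NU` as a Reynolds number assumes a unit chord length, which the notebook does not state explicitly.

```python
NU = 1.56e-5  # kinematic viscosity constant used in the notebook


def parse_simulation_name(name: str) -> dict:
    """Split an underscore-separated simulation name into its parameters."""
    params = name.split("_")
    inlet_velocity = float(params[2])
    info = {
        # Velocity / NU is the Reynolds number only for a unit chord (assumption).
        "reynolds": inlet_velocity / NU,
        "angle_of_attack_deg": float(params[3]),
    }
    if len(params) == 7:
        # Seven fields: the last three describe a 4-digit NACA profile.
        info["series"] = 4
        info["naca_params"] = [float(p) for p in params[-3:]]
    else:
        # Otherwise the last four describe a 5-digit NACA profile.
        info["series"] = 5
        info["naca_params"] = [float(p) for p in params[-4:]]
    return info


# Made-up name for illustration; the numeric values are not from the dataset.
example = parse_simulation_name("airFoil2D_SST_31.2_5.0_2.0_4.0_12.0")
print(example["series"], example["naca_params"])  # 4 [2.0, 4.0, 12.0]
```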
/Airfoil-Design-AirfRANS/fig/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Airfoil-Design-AirfRANS/fig/results.png
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/fig/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Airfoil-Design-AirfRANS/fig/task.png
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/main.py:
--------------------------------------------------------------------------------
1 | import argparse, yaml, json
2 | import torch
3 | import train
4 | import utils.metrics as metrics
5 | from dataset.dataset import Dataset
6 | import os.path as osp
7 | import numpy as np
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--model', help='The model you want to train, choose between Transolver, MLP, GraphSAGE, PointNet, GUNet',
11 | type=str)
12 | parser.add_argument('-n', '--nmodel', help='Number of trained models for standard deviation estimation (default: 1)',
13 | default=1, type=int)
14 | parser.add_argument('-w', '--weight', help='Weight in front of the surface loss (default: 1)', default=1, type=float)
15 | parser.add_argument('-t', '--task',
16 | help='Task to train on. Choose between "full", "scarce", "reynolds" and "aoa" (default: full)',
17 | default='full', type=str)
18 | parser.add_argument('-s', '--score',
19 | help='If you want to compute the score of the models on the associated test set. (default: 0)',
20 | default=0, type=int)
21 | parser.add_argument('--my_path',
22 | default='/data/path', type=str)
23 | parser.add_argument('--save_path',
24 | default='metrics', type=str)
25 | args = parser.parse_args()
26 |
27 | with open(args.my_path + '/manifest.json', 'r') as f:
28 | manifest = json.load(f)
29 |
30 | manifest_train = manifest[args.task + '_train']
31 | test_dataset = manifest[args.task + '_test'] if args.task != 'scarce' else manifest['full_test']
32 | n = int(.1 * len(manifest_train))
33 | train_dataset = manifest_train[:-n]
34 | val_dataset = manifest_train[-n:]
35 | print("start load data")
36 | train_dataset, coef_norm = Dataset(train_dataset, norm=True, sample=None, my_path=args.my_path)
37 | val_dataset = Dataset(val_dataset, sample=None, coef_norm=coef_norm, my_path=args.my_path)
38 | print("load data finish")
39 | # Cuda
40 | use_cuda = torch.cuda.is_available()
41 | device = 'cuda:0' if use_cuda else 'cpu'
42 | if use_cuda:
43 | print('Using GPU')
44 | else:
45 | print('Using CPU')
46 |
47 | with open('params.yaml', 'r') as f: # hyperparameters of the model
48 | hparams = yaml.safe_load(f)[args.model]
49 |
50 | from models.MLP import MLP
51 |
52 | models = []
53 | for i in range(args.nmodel):
54 |
55 | if args.model == 'Transolver':
56 | from models.Transolver import Transolver
57 |
58 | model = Transolver(n_hidden=256,
59 | n_layers=8,
60 | space_dim=7,
61 | fun_dim=0,
62 | n_head=8,
63 | mlp_ratio=2,
64 | out_dim=4,
65 | slice_num=32,
66 | unified_pos=1).to(device)
67 | else:
68 | encoder = MLP(hparams['encoder'], batch_norm=False)
69 | decoder = MLP(hparams['decoder'], batch_norm=False)
70 | if args.model == 'GraphSAGE':
71 | from models.GraphSAGE import GraphSAGE
72 |
73 | model = GraphSAGE(hparams, encoder, decoder)
74 |
75 | elif args.model == 'PointNet':
76 | from models.PointNet import PointNet
77 |
78 | model = PointNet(hparams, encoder, decoder)
79 |
80 | elif args.model == 'MLP':
81 | from models.NN import NN
82 |
83 | model = NN(hparams, encoder, decoder)
84 |
85 | elif args.model == 'GUNet':
86 | from models.GUNet import GUNet
87 |
88 | model = GUNet(hparams, encoder, decoder)
89 |
90 | log_path = osp.join(args.save_path, args.task, args.model) # path where you want to save log and figures
91 | print('start training')
92 | model = train.main(device, train_dataset, val_dataset, model, hparams, log_path,
93 | criterion='MSE_weighted', val_iter=10, reg=args.weight, name_mod=args.model, val_sample=True)
94 | print('end training')
95 | models.append(model)
96 | torch.save(models, osp.join(args.save_path, args.task, args.model, args.model))
97 |
98 | if bool(args.score):
99 | print('start score')
100 | s = args.task + '_test' if args.task != 'scarce' else 'full_test'
101 | coefs = metrics.Results_test(device, [models], [hparams], coef_norm, args.my_path, path_out='scores', n_test=3,
102 | criterion='MSE', s=s)
103 | # models can be a stack of the same model (for example MLP) on the task s, if you have another stack of another model (for example GraphSAGE)
104 | # you can put in model argument [models_MLP, models_GraphSAGE] and it will output the results for both models (mean and std) in an ordered array.
105 | np.save(osp.join('scores', args.task, 'true_coefs'), coefs[0])
106 | np.save(osp.join('scores', args.task, 'pred_coefs_mean'), coefs[1])
107 | np.save(osp.join('scores', args.task, 'pred_coefs_std'), coefs[2])
108 | for n, file in enumerate(coefs[3]):
109 | np.save(osp.join('scores', args.task, 'true_surf_coefs_' + str(n)), file)
110 | for n, file in enumerate(coefs[4]):
111 | np.save(osp.join('scores', args.task, 'surf_coefs_' + str(n)), file)
112 | np.save(osp.join('scores', args.task, 'true_bls'), coefs[5])
113 | np.save(osp.join('scores', args.task, 'bls'), coefs[6])
114 | print('end score')
115 |
--------------------------------------------------------------------------------
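`main.py` holds out the last 10% of the training manifest for validation. The sketch below reproduces that split as a standalone function; the name `split_train_val` is hypothetical. It also guards one edge case that the slicing in the script does not: when the hold-out size rounds down to zero, `names[:-0]` evaluates to an empty list.

```python
def split_train_val(names, val_fraction=0.1):
    """Hold out the trailing fraction of `names` for validation."""
    n_val = int(val_fraction * len(names))
    if n_val == 0:
        # names[:-0] would be empty, so keep everything for training instead.
        return list(names), []
    return list(names[:-n_val]), list(names[-n_val:])


train, val = split_train_val([f"sim_{i}" for i in range(25)])
print(len(train), len(val))  # 23 2
```

With the AirfRANS manifests the training lists are long enough that the guard never triggers, so it only matters when the split is reused on very small lists.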
/Airfoil-Design-AirfRANS/main_evaluation.py:
--------------------------------------------------------------------------------
1 | import yaml, json
2 | import torch
3 | import utils.metrics as metrics
4 | from dataset.dataset import Dataset
5 | import os.path as osp
6 | import argparse
7 | import numpy as np
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--my_path', default='/data/path', type=str) # data save path
11 | parser.add_argument('--save_path', default='./', type=str) # model save path
12 | args = parser.parse_args()
13 |
14 | # Compute the normalization used for the training
15 |
16 | use_cuda = torch.cuda.is_available()
17 | device = 'cuda:0' if use_cuda else 'cpu'
18 | if use_cuda:
19 | print('Using GPU')
20 | else:
21 | print('Using CPU')
22 |
23 | data_root_dir = args.my_path
24 | ckpt_root_dir = args.save_path
25 |
26 | tasks = ['full']
27 |
28 | for task in tasks:
29 | print('Generating results for task ' + task + '...')
30 | # task = 'full' # Choose between 'full', 'scarce', 'reynolds', and 'aoa'
31 | s = task + '_test' if task != 'scarce' else 'full_test'
32 | s_train = task + '_train'
33 |
34 | data_dir = osp.join(data_root_dir, 'Dataset')
35 | with open(osp.join(data_dir, 'manifest.json'), 'r') as f:
36 | manifest = json.load(f)
37 |
38 | manifest_train = manifest[s_train]
39 | n = int(.1 * len(manifest_train))
40 | train_dataset = manifest_train[:-n]
41 |
42 | _, coef_norm = Dataset(train_dataset, norm=True, sample=None, my_path=data_dir)
43 |
44 | # Compute the scores on the test set
45 |
46 | model_names = ['Transolver']
47 | models = []
48 | hparams = []
49 |
50 | for model in model_names:
51 | model_path = osp.join(ckpt_root_dir, 'metrics', task, model, model)
52 | mod = torch.load(model_path, map_location=device)
53 | print(mod)
54 | mod = [m.to(device) for m in mod]
55 | models.append(mod)
56 |
57 | with open('params.yaml', 'r') as f:
58 | hparam = yaml.safe_load(f)[model]
59 | hparams.append(hparam)
60 |
61 | results_dir = osp.join(ckpt_root_dir, 'scores', task)
62 | coefs = metrics.Results_test(device, models, hparams, coef_norm, data_dir, results_dir, n_test=3, criterion='MSE',
63 | s=s)
64 | # models can be a stack of the same model (for example MLP) on the task s, if you have another stack of another model (for example GraphSAGE)
65 | # you can put in model argument [models_MLP, models_GraphSAGE] and it will output the results for both models (mean and std) in an ordered array.
66 |
67 | np.save(osp.join(results_dir, 'true_coefs'), coefs[0])
68 | np.save(osp.join(results_dir, 'pred_coefs_mean'), coefs[1])
69 | np.save(osp.join(results_dir, 'pred_coefs_std'), coefs[2])
70 | for n, file in enumerate(coefs[3]):
71 | np.save(osp.join(results_dir, 'true_surf_coefs_' + str(n)), file)
72 | for n, file in enumerate(coefs[4]):
73 | np.save(osp.join(results_dir, 'surf_coefs_' + str(n)), file)
74 | np.save(osp.join(results_dir, 'true_bls'), coefs[5])
75 | np.save(osp.join(results_dir, 'bls'), coefs[6])
76 |
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/models/GUNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch_geometric.nn as nng
4 | import random
5 |
6 | def DownSample(id, x, edge_index, pos_x, pool, pool_ratio, r, max_neighbors):
7 | y = x.clone()
8 | n = int(x.size(0))
9 |
10 | if pool is not None:
11 | y, _, _, _, id_sampled, _ = pool(y, edge_index)
12 | else:
13 | k = int((pool_ratio*torch.tensor(n, dtype = torch.float)).ceil())
14 | id_sampled = random.sample(range(n), k)
15 | id_sampled = torch.tensor(id_sampled, dtype = torch.long)
16 | y = y[id_sampled]
17 |
18 | pos_x = pos_x[id_sampled]
19 | id.append(id_sampled)
20 |
21 | # if training:
22 | # edge_index_sampled = nng.radius_graph(x = pos_x.detach(), r = r, loop = True, max_num_neighbors = 64)
23 | # else:
24 | # edge_index_sampled = nng.radius_graph(x = pos_x.detach(), r = r, loop = True, max_num_neighbors = 512)
25 | edge_index_sampled = nng.radius_graph(x = pos_x.detach(), r = r, loop = True, max_num_neighbors = max_neighbors)
26 |
27 | return y, edge_index_sampled
28 |
29 | def UpSample(x, pos_x_up, pos_x_down):
30 | cluster = nng.nearest(pos_x_up, pos_x_down)
31 | x_up = x[cluster]
32 |
33 | return x_up
34 |
35 | class GUNet(nn.Module):
36 | def __init__(self, hparams, encoder, decoder):
37 | super(GUNet, self).__init__()
38 |
39 | self.L = hparams['nb_scale']
40 | self.layer = hparams['layer']
41 | self.pool_type = hparams['pool']
42 | self.pool_ratio = hparams['pool_ratio']
43 | self.list_r = hparams['list_r']
44 | self.size_hidden_layers = hparams['size_hidden_layers']
45 | self.size_hidden_layers_init = hparams['size_hidden_layers']
46 | self.max_neighbors = hparams['max_neighbors']
47 | self.dim_enc = hparams['encoder'][-1]
48 | self.bn_bool = hparams['batchnorm']
49 | self.res = hparams['res']
50 | self.head = 2
51 | self.activation = nn.ReLU()
52 |
53 | self.encoder = encoder
54 | self.decoder = decoder
55 |
56 | self.down_layers = nn.ModuleList()
57 |
58 | if self.pool_type != 'random':
59 | self.pool = nn.ModuleList()
60 | else:
61 | self.pool = None
62 |
63 | if self.layer == 'SAGE':
64 | self.down_layers.append(nng.SAGEConv(
65 | in_channels = self.dim_enc,
66 | out_channels = self.size_hidden_layers
67 | ))
68 | bn_in = self.size_hidden_layers
69 |
70 | elif self.layer == 'GAT':
71 | self.down_layers.append(nng.GATConv(
72 | in_channels = self.dim_enc,
73 | out_channels = self.size_hidden_layers,
74 | heads = self.head,
75 | add_self_loops = False,
76 | concat = True
77 | ))
78 | bn_in = self.head*self.size_hidden_layers
79 |
80 | if self.bn_bool == True:
81 | self.bn = nn.ModuleList()
82 | self.bn.append(nng.BatchNorm(
83 | in_channels = bn_in,
84 | track_running_stats = False
85 | ))
86 | else:
87 | self.bn = None
88 |
89 |
90 | for n in range(1, self.L):
91 | if self.pool_type != 'random':
92 | self.pool.append(nng.TopKPooling(
93 | in_channels = self.size_hidden_layers,
94 | ratio = self.pool_ratio[n - 1],
95 | nonlinearity = torch.sigmoid
96 | ))
97 |
98 | if self.layer == 'SAGE':
99 | self.down_layers.append(nng.SAGEConv(
100 | in_channels = self.size_hidden_layers,
101 | out_channels = 2*self.size_hidden_layers,
102 | ))
103 | self.size_hidden_layers = 2*self.size_hidden_layers
104 | bn_in = self.size_hidden_layers
105 |
106 | elif self.layer == 'GAT':
107 | self.down_layers.append(nng.GATConv(
108 | in_channels = self.head*self.size_hidden_layers,
109 | out_channels = self.size_hidden_layers,
110 | heads = 2,
111 | add_self_loops = False,
112 | concat = True
113 | ))
114 |
115 | if self.bn_bool == True:
116 | self.bn.append(nng.BatchNorm(
117 | in_channels = bn_in,
118 | track_running_stats = False
119 | ))
120 |
121 | self.up_layers = nn.ModuleList()
122 |
123 | if self.layer == 'SAGE':
124 | self.up_layers.append(nng.SAGEConv(
125 | in_channels = 3*self.size_hidden_layers_init,
126 | out_channels = self.dim_enc
127 | ))
128 | self.size_hidden_layers_init = 2*self.size_hidden_layers_init
129 |
130 | elif self.layer == 'GAT':
131 | self.up_layers.append(nng.GATConv(
132 | in_channels = 2*self.head*self.size_hidden_layers,
133 | out_channels = self.dim_enc,
134 | heads = 2,
135 | add_self_loops = False,
136 | concat = False
137 | ))
138 |
139 | if self.bn_bool == True:
140 | self.bn.append(nng.BatchNorm(
141 | in_channels = self.dim_enc,
142 | track_running_stats = False
143 | ))
144 |
145 | for n in range(1, self.L - 1):
146 | if self.layer == 'SAGE':
147 | self.up_layers.append(nng.SAGEConv(
148 | in_channels = 3*self.size_hidden_layers_init,
149 | out_channels = self.size_hidden_layers_init,
150 | ))
151 | bn_in = self.size_hidden_layers_init
152 | self.size_hidden_layers_init = 2*self.size_hidden_layers_init
153 |
154 | elif self.layer == 'GAT':
155 | self.up_layers.append(nng.GATConv(
156 | in_channels = 2*self.head*self.size_hidden_layers,
157 | out_channels = self.size_hidden_layers,
158 | heads = 2,
159 | add_self_loops = False,
160 | concat = True
161 | ))
162 |
163 | if self.bn_bool == True:
164 | self.bn.append(nng.BatchNorm(
165 | in_channels = bn_in,
166 | track_running_stats = False
167 | ))
168 |
169 | def forward(self, data):
170 | x, edge_index = data.x, data.edge_index
171 | id = []
172 | edge_index_list = [edge_index.clone()]
173 | pos_x_list = []
174 | z = self.encoder(x)
175 | if self.res:
176 | z_res = z.clone()
177 |
178 | z = self.down_layers[0](z, edge_index)
179 |
180 | if self.bn_bool == True:
181 | z = self.bn[0](z)
182 |
183 | z = self.activation(z)
184 | z_list = [z.clone()]
185 | for n in range(self.L - 1):
186 | pos_x = x[:, :2] if n == 0 else pos_x[id[n - 1]]
187 | pos_x_list.append(pos_x.clone())
188 |
189 | if self.pool_type != 'random':
190 | z, edge_index = DownSample(id, z, edge_index, pos_x, self.pool[n], self.pool_ratio[n], self.list_r[n], self.max_neighbors)
191 | else:
192 | z, edge_index = DownSample(id, z, edge_index, pos_x, None, self.pool_ratio[n], self.list_r[n], self.max_neighbors)
193 | edge_index_list.append(edge_index.clone())
194 |
195 | z = self.down_layers[n + 1](z, edge_index)
196 |
197 | if self.bn_bool == True:
198 | z = self.bn[n + 1](z)
199 |
200 | z = self.activation(z)
201 | z_list.append(z.clone())
202 | pos_x_list.append(pos_x[id[-1]].clone())
203 |
204 | for n in range(self.L - 1, 0, -1):
205 | z = UpSample(z, pos_x_list[n - 1], pos_x_list[n])
206 | z = torch.cat([z, z_list[n - 1]], dim = 1)
207 | z = self.up_layers[n - 1](z, edge_index_list[n - 1])
208 |
209 | if self.bn_bool == True:
210 | z = self.bn[self.L + n - 1](z)
211 |
212 | z = self.activation(z) if n != 1 else z
213 |
214 | del(z_list, pos_x_list, edge_index_list)
215 |
216 | if self.res:
217 | z = z + z_res
218 |
219 | z = self.decoder(z)
220 |
221 | return z
222 |
--------------------------------------------------------------------------------
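The `DownSample` and `UpSample` helpers above pair a node-subsampling step with a nearest-neighbour lookup that copies coarse features back to the finer level. The sketch below illustrates the same idea for the random-pooling branch in plain Python, without `torch` or `torch_geometric`, so it is a simplified stand-in rather than the repository's implementation. The function names and the toy data are assumptions.

```python
import math
import random


def random_downsample(features, positions, ratio, rng):
    """Keep ceil(ratio * n) randomly chosen nodes."""
    n = len(features)
    k = math.ceil(ratio * n)
    kept = rng.sample(range(n), k)
    return [features[i] for i in kept], [positions[i] for i in kept], kept


def nearest_upsample(coarse_features, fine_positions, coarse_positions):
    """Give every fine node the features of its nearest coarse node."""
    upsampled = []
    for p in fine_positions:
        nearest = min(range(len(coarse_positions)),
                      key=lambda j: math.dist(p, coarse_positions[j]))
        upsampled.append(coarse_features[nearest])
    return upsampled


rng = random.Random(0)
positions = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
features = [[10.0], [20.0], [30.0], [40.0]]
coarse_f, coarse_p, kept = random_downsample(features, positions, 0.5, rng)
restored = nearest_upsample(coarse_f, positions, coarse_p)
print(len(coarse_f), len(restored))  # 2 4
```

Nodes that survived the downsampling get their own features back, since their nearest coarse node is themselves; all other nodes inherit features from the closest survivor.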
/Airfoil-Design-AirfRANS/models/GraphSAGE.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch_geometric.nn as nng
3 |
4 | class GraphSAGE(nn.Module):
5 | def __init__(self, hparams, encoder, decoder):
6 | super(GraphSAGE, self).__init__()
7 |
8 | self.nb_hidden_layers = hparams['nb_hidden_layers']
9 | self.size_hidden_layers = hparams['size_hidden_layers']
10 | self.bn_bool = hparams['bn_bool']
11 | self.activation = nn.ReLU()
12 |
13 | self.encoder = encoder
14 | self.decoder = decoder
15 |
16 | self.in_layer = nng.SAGEConv(
17 | in_channels = hparams['encoder'][-1],
18 | out_channels = self.size_hidden_layers
19 | )
20 |
21 | self.hidden_layers = nn.ModuleList()
22 | for n in range(self.nb_hidden_layers - 1):
23 | self.hidden_layers.append(nng.SAGEConv(
24 | in_channels = self.size_hidden_layers,
25 | out_channels = self.size_hidden_layers
26 | ))
27 |
28 |
29 | self.out_layer = nng.SAGEConv(
30 | in_channels = self.size_hidden_layers,
31 | out_channels = hparams['decoder'][0]
32 | )
33 |
34 | if self.bn_bool:
35 | self.bn = nn.ModuleList()
36 | for n in range(self.nb_hidden_layers):
37 | self.bn.append(nn.BatchNorm1d(self.size_hidden_layers, track_running_stats = False))
38 |
39 | def forward(self, data):
40 | z, edge_index = data.x, data.edge_index
41 | z = self.encoder(z)
42 |
43 | z = self.in_layer(z, edge_index)
44 | if self.bn_bool:
45 | z = self.bn[0](z)
46 | z = self.activation(z)
47 |
48 | for n in range(self.nb_hidden_layers - 1):
49 | z = self.hidden_layers[n](z, edge_index)
50 | if self.bn_bool:
51 | z = self.bn[n + 1](z)
52 | z = self.activation(z)
53 |
54 | z = self.out_layer(z, edge_index)
55 |
56 | z = self.decoder(z)
57 |
58 | return z
--------------------------------------------------------------------------------
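The constructor above builds one input convolution, `nb_hidden_layers - 1` hidden convolutions, and one output convolution. A small sketch of the resulting channel sizes follows; the function name `graphsage_layer_dims` and the example widths are hypothetical and not taken from `params.yaml`.

```python
def graphsage_layer_dims(encoder_out, hidden, nb_hidden_layers, decoder_in):
    """List the (in_channels, out_channels) pair of every SAGEConv layer."""
    dims = [(encoder_out, hidden)]                        # in_layer
    dims += [(hidden, hidden)] * (nb_hidden_layers - 1)   # hidden_layers
    dims.append((hidden, decoder_in))                     # out_layer
    return dims


print(graphsage_layer_dims(8, 64, 3, 8))
# [(8, 64), (64, 64), (64, 64), (64, 8)]
```

The batch-norm list has `nb_hidden_layers` entries, one for the input layer and one for each hidden layer; the output layer is followed directly by the decoder.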
/Airfoil-Design-AirfRANS/models/MLP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import Tensor
4 | from torch.nn import BatchNorm1d, Identity
5 | from torch_geometric.nn import Linear
6 |
7 | class MLP(torch.nn.Module):
8 | r"""A multi-layer perceptron (MLP) model.
9 |
10 | Args:
11 | channel_list (List[int]): List of input, intermediate and output
12 | channels. :obj:`len(channel_list) - 1` denotes the number of layers
13 | of the MLP.
14 | dropout (float, optional): Dropout probability of each hidden
15 | embedding. (default: :obj:`0.`)
16 | batch_norm (bool, optional): If set to :obj:`False`, will not make use
17 | of batch normalization. (default: :obj:`True`)
18 | relu_first (bool, optional): If set to :obj:`True`, ReLU activation is
19 | applied before batch normalization. (default: :obj:`False`)
20 | """
21 | def __init__(self, channel_list, dropout = 0.,
22 | batch_norm = True, relu_first = False):
23 | super().__init__()
24 | assert len(channel_list) >= 2
25 | self.channel_list = channel_list
26 | self.dropout = dropout
27 | self.relu_first = relu_first
28 |
29 | self.lins = torch.nn.ModuleList()
30 | for dims in zip(self.channel_list[:-1], self.channel_list[1:]):
31 | self.lins.append(Linear(*dims))
32 |
33 | self.norms = torch.nn.ModuleList()
34 | for dim in self.channel_list[1:-1]:
35 | self.norms.append(BatchNorm1d(dim, track_running_stats = False) if batch_norm else Identity())
36 |
37 | self.reset_parameters()
38 |
39 | def reset_parameters(self):
40 | for lin in self.lins:
41 | lin.reset_parameters()
42 | for norm in self.norms:
43 | if hasattr(norm, 'reset_parameters'):
44 | norm.reset_parameters()
45 |
46 |
47 | def forward(self, x: Tensor) -> Tensor:
48 | """"""
49 | x = self.lins[0](x)
50 | for lin, norm in zip(self.lins[1:], self.norms):
51 | if self.relu_first:
52 | x = x.relu_()
53 | x = norm(x)
54 | if not self.relu_first:
55 | x = x.relu_()
56 | x = F.dropout(x, p = self.dropout, training = self.training)
57 | x = lin.forward(x)
58 | return x
59 |
60 |
61 | def __repr__(self) -> str:
62 | return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})'
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/models/NN.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from models.MLP import MLP
3 |
4 | class NN(nn.Module):
5 | def __init__(self, hparams, encoder, decoder):
6 | super(NN, self).__init__()
7 |
8 | self.nb_hidden_layers = hparams['nb_hidden_layers']
9 | self.size_hidden_layers = hparams['size_hidden_layers']
10 | self.bn_bool = hparams['bn_bool']
11 | self.activation = nn.ReLU()
12 |
13 | self.encoder = encoder
14 | self.decoder = decoder
15 |
16 | self.dim_enc = hparams['encoder'][-1]
17 |
18 | self.nn = MLP([self.dim_enc] + [self.size_hidden_layers]*self.nb_hidden_layers + [self.dim_enc], batch_norm = self.bn_bool)
19 |
20 | def forward(self, data):
21 | z = self.encoder(data.x)
22 | z = self.nn(z)
23 | z = self.decoder(z)
24 |
25 | return z
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/models/PointNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch_geometric.nn as nng
4 | from models.MLP import MLP
5 |
6 |
7 | class PointNet(nn.Module):
8 | def __init__(self, hparams, encoder, decoder):
9 | super(PointNet, self).__init__()
10 |
11 | self.base_nb = hparams['base_nb']
12 |
13 | self.in_block = MLP([hparams['encoder'][-1], self.base_nb, self.base_nb * 2], batch_norm=False)
14 | self.max_block = MLP([self.base_nb * 2, self.base_nb * 4, self.base_nb * 8, self.base_nb * 32],
15 | batch_norm=False)
16 |
17 | self.out_block = MLP([self.base_nb * (32 + 2), self.base_nb * 16, self.base_nb * 8, self.base_nb * 4],
18 | batch_norm=False)
19 |
20 | self.encoder = encoder
21 | self.decoder = decoder
22 |
23 | self.fcfinal = nn.Linear(self.base_nb * 4, hparams['encoder'][-1])
24 |
25 | def forward(self, data):
26 | z, batch = data.x.float(), data.batch.long()
27 |
28 | z = self.encoder(z)
29 | z = self.in_block(z)
30 |
31 | global_coef = self.max_block(z)
32 | global_coef = nng.global_max_pool(global_coef, batch=batch)
33 | nb_points = torch.zeros(global_coef.shape[0], device=z.device)
34 |
35 | for i in range(batch.max() + 1):
36 | nb_points[i] = (batch == i).sum()
37 | nb_points = nb_points.long()
38 | global_coef = torch.repeat_interleave(global_coef, nb_points, dim=0)
39 |
40 | z = torch.cat([z, global_coef], dim=1)
41 | z = self.out_block(z)
42 | z = self.fcfinal(z)
43 | z = self.decoder(z)
44 |
45 | return z
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/models/Transolver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from timm.models.layers import trunc_normal_
5 | from einops import rearrange, repeat
6 |
7 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'leaky_relu': lambda: nn.LeakyReLU(0.1),
8 | 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU}
9 |
10 |
11 | class Physics_Attention_Irregular_Mesh(nn.Module):
12 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64):
13 | super().__init__()
14 | inner_dim = dim_head * heads
15 | self.dim_head = dim_head
16 | self.heads = heads
17 | self.scale = dim_head ** -0.5
18 | self.softmax = nn.Softmax(dim=-1)
19 | self.dropout = nn.Dropout(dropout)
20 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
21 |
22 | self.in_project_x = nn.Linear(dim, inner_dim)
23 | self.in_project_fx = nn.Linear(dim, inner_dim)
24 | self.in_project_slice = nn.Linear(dim_head, slice_num)
25 | for l in [self.in_project_slice]:
26 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization
27 | self.to_q = nn.Linear(dim_head, dim_head, bias=False)
28 | self.to_k = nn.Linear(dim_head, dim_head, bias=False)
29 | self.to_v = nn.Linear(dim_head, dim_head, bias=False)
30 | self.to_out = nn.Sequential(
31 | nn.Linear(inner_dim, dim),
32 | nn.Dropout(dropout)
33 | )
34 |
35 | def forward(self, x):
36 | # B N C
37 | B, N, C = x.shape
38 |
39 | ### (1) Slice
40 | fx_mid = self.in_project_fx(x).reshape(B, N, self.heads, self.dim_head) \
41 | .permute(0, 2, 1, 3).contiguous() # B H N C
42 | x_mid = self.in_project_x(x).reshape(B, N, self.heads, self.dim_head) \
43 | .permute(0, 2, 1, 3).contiguous() # B H N C
44 | slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature) # B H N G
45 | slice_norm = slice_weights.sum(2) # B H G
46 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
47 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head))
48 |
49 | ### (2) Attention among slice tokens
50 | q_slice_token = self.to_q(slice_token)
51 | k_slice_token = self.to_k(slice_token)
52 | v_slice_token = self.to_v(slice_token)
53 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
54 | attn = self.softmax(dots)
55 | attn = self.dropout(attn)
56 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D
57 |
58 | ### (3) Deslice
59 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
60 | out_x = rearrange(out_x, 'b h n d -> b n (h d)')
61 | return self.to_out(out_x)
62 |
63 |
64 | class MLP(nn.Module):
65 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True):
66 | super(MLP, self).__init__()
67 |
68 | if act in ACTIVATION.keys():
69 | act = ACTIVATION[act]
70 | else:
71 | raise NotImplementedError
72 | self.n_input = n_input
73 | self.n_hidden = n_hidden
74 | self.n_output = n_output
75 | self.n_layers = n_layers
76 | self.res = res
77 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act())
78 | self.linear_post = nn.Linear(n_hidden, n_output)
79 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)])
80 |
81 | def forward(self, x):
82 | x = self.linear_pre(x)
83 | for i in range(self.n_layers):
84 | if self.res:
85 | x = self.linears[i](x) + x
86 | else:
87 | x = self.linears[i](x)
88 | x = self.linear_post(x)
89 | return x
90 |
91 |
92 | class Transolver_block(nn.Module):
93 | """Transolver block: Physics-Attention followed by an MLP, each with a pre-LayerNorm residual connection."""
94 |
95 | def __init__(
96 | self,
97 | num_heads: int,
98 | hidden_dim: int,
99 | dropout: float,
100 | act='gelu',
101 | mlp_ratio=4,
102 | last_layer=False,
103 | out_dim=1,
104 | slice_num=32,
105 | ):
106 | super().__init__()
107 | self.last_layer = last_layer
108 | self.ln_1 = nn.LayerNorm(hidden_dim)
109 | self.Attn = Physics_Attention_Irregular_Mesh(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads,
110 | dropout=dropout, slice_num=slice_num)
111 | self.ln_2 = nn.LayerNorm(hidden_dim)
112 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
113 | self.mlp_new = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
114 | if self.last_layer:
115 | self.ln_3 = nn.LayerNorm(hidden_dim)
116 | self.mlp2 = nn.Linear(hidden_dim, out_dim)
117 |
118 | def forward(self, fx):
119 | fx = self.Attn(self.ln_1(fx)) + fx
120 | fx = self.mlp(self.ln_2(fx)) + fx
121 | if self.last_layer:
122 | return self.mlp2(self.ln_3(fx))
123 | else:
124 | return fx
125 |
126 |
127 | class Transolver(nn.Module):
128 | def __init__(self,
129 | space_dim=1,
130 | n_layers=5,
131 | n_hidden=256,
132 | dropout=0,
133 | n_head=8,
134 | act='gelu',
135 | mlp_ratio=1,
136 | fun_dim=1,
137 | out_dim=1,
138 | slice_num=32,
139 | ref=8,
140 | unified_pos=False
141 | ):
142 | super(Transolver, self).__init__()
143 | self.__name__ = 'Transolver'
144 | self.ref = ref
145 | self.unified_pos = unified_pos
146 | if self.unified_pos:
147 | self.preprocess = MLP(fun_dim + space_dim + self.ref * self.ref, n_hidden * 2, n_hidden,
148 | n_layers=0,
149 | res=False, act=act)
150 | else:
151 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act)
152 |
153 | self.n_hidden = n_hidden
154 | self.space_dim = space_dim
155 |
156 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden,
157 | dropout=dropout,
158 | act=act,
159 | mlp_ratio=mlp_ratio, out_dim=out_dim,
160 | slice_num=slice_num,
161 | last_layer=(_ == n_layers - 1))
162 | for _ in range(n_layers)])
163 | self.initialize_weights()
164 | self.placeholder = nn.Parameter((1 / (n_hidden)) * torch.rand(n_hidden, dtype=torch.float))
165 |
166 | def initialize_weights(self):
167 | self.apply(self._init_weights)
168 |
169 | def _init_weights(self, m):
170 | if isinstance(m, nn.Linear):
171 | trunc_normal_(m.weight, std=0.02)
172 | if isinstance(m, nn.Linear) and m.bias is not None:
173 | nn.init.constant_(m.bias, 0)
174 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
175 | nn.init.constant_(m.bias, 0)
176 | nn.init.constant_(m.weight, 1.0)
177 |
178 | def get_grid(self, my_pos):
179 | # my_pos: 1 N 2 (2D point coordinates)
180 | batchsize = my_pos.shape[0]
181 |
182 | gridx = torch.tensor(np.linspace(-2, 4, self.ref), dtype=torch.float)
183 | gridx = gridx.reshape(1, self.ref, 1, 1).repeat([batchsize, 1, self.ref, 1])
184 | gridy = torch.tensor(np.linspace(-1.5, 1.5, self.ref), dtype=torch.float)
185 | gridy = gridy.reshape(1, 1, self.ref, 1).repeat([batchsize, self.ref, 1, 1])
186 | grid_ref = torch.cat((gridx, gridy), dim=-1).to(my_pos.device).reshape(batchsize, self.ref ** 2, 2) # B ref*ref 2
187 |
188 | pos = torch.sqrt(
189 | torch.sum((my_pos[:, :, None, :] - grid_ref[:, None, :, :]) ** 2,
190 | dim=-1)). \
191 | reshape(batchsize, my_pos.shape[1], self.ref * self.ref).contiguous()
192 | return pos
193 |
194 | def forward(self, data):
195 | x, fx, T = data.x, None, None
196 | x = x[None, :, :]
197 | if self.unified_pos:
198 | new_pos = self.get_grid(data.pos[None, :, :])
199 | x = torch.cat((x, new_pos), dim=-1)
200 | if fx is not None:
201 | fx = torch.cat((x, fx), -1)
202 | fx = self.preprocess(fx)
203 | else:
204 | fx = self.preprocess(x)
205 | fx = fx + self.placeholder[None, None, :]
206 |
207 | for block in self.blocks:
208 | fx = block(fx)
209 |
210 | return fx[0]
211 |
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/params.yaml:
--------------------------------------------------------------------------------
1 | GraphSAGE:
2 | encoder: [ 7, 64, 64, 8 ]
3 | decoder: [ 8, 64, 64, 4 ]
4 |
5 | nb_hidden_layers: 3
6 | size_hidden_layers: 64
7 | batch_size: 1
8 | nb_epochs: 398
9 | lr: 0.001
10 | max_neighbors: 64
11 | bn_bool: True
12 | subsampling: 32000
13 | r: 0.05
14 |
15 | Transolver:
16 | batch_size: 1
17 | nb_epochs: 398
18 | lr: 0.001
19 | max_neighbors: 64
20 | subsampling: 32000
21 | r: 0.05
22 |
23 | PointNet:
24 | encoder: [ 7, 64, 64, 8 ]
25 | decoder: [ 8, 64, 64, 4 ]
26 |
27 | base_nb: 8
28 | batch_size: 1
29 | nb_epochs: 398
30 | lr: 0.001
31 | subsampling: 32000
32 |
33 | MLP:
34 | encoder: [ 7, 64, 64, 8 ]
35 | decoder: [ 8, 64, 64, 4 ]
36 |
37 | nb_hidden_layers: 3
38 | size_hidden_layers: 64
39 | batch_size: 1
40 | nb_epochs: 398
41 | lr: 0.001
42 | bn_bool: True
43 | subsampling: 32000
44 |
45 | GUNet:
46 | encoder: [ 7, 64, 64, 8 ]
47 | decoder: [ 8, 64, 64, 4 ]
48 |
49 | layer: 'SAGE'
50 | pool: 'random'
51 | nb_scale: 5
52 | pool_ratio: [ .5, .5, .5, .5 ]
53 | list_r: [ .05, .2, .5, 1, 10 ]
54 | size_hidden_layers: 8
55 | batchnorm: True
56 | res: False
57 |
58 | batch_size: 1
59 | nb_epochs: 398
60 | lr: 0.001
61 | max_neighbors: 64
62 | subsampling: 32000
63 | r: 0.05
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torch_geometric
3 | torch-cluster
4 | vtk
5 | timm
6 | einops
7 | seaborn
8 | pyvista
9 |
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/scripts/Evaluation.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=4
2 |
3 | python main_evaluation.py --my_path /data/naca/
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/scripts/GraphSAGE.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=4
2 |
3 | python main.py --model GraphSAGE -t full --my_path /data/naca/Dataset --score 1
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/scripts/Transolver.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=6
2 |
3 | python main.py --model Transolver -t full --my_path /data/naca/Dataset --score 1
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/utils/metrics_NACA.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | from utils.naca_generator import camber_line
5 |
6 | sns.set()
7 |
8 | # Properties of air at sea level and 298.15 K (25 °C)
9 | RHO = 1.184
10 | NU = 1.56e-5
11 | C = 346.1
12 | P_ref = 1.013e5
13 |
14 |
15 | def surface_coefficients(airfoil, aero_name, compressible=False, extrado=False):
16 | u_inf = float(aero_name.split('_')[2])
17 | digits = list(map(float, aero_name.split('_')[4:-1]))
18 | if compressible:
19 | qInf = 0.5 * u_inf ** 2 * RHO
20 | else:
21 | qInf = 0.5 * u_inf ** 2
22 |
23 | if extrado:
24 | camber = camber_line(digits, airfoil.points[:, 0])[0]
25 | idx_extrado = (airfoil.points[:, 1] > camber)
26 | points = airfoil.points[:, 0]
27 | pressure = airfoil.point_data['p']
28 | wss = np.linalg.norm(airfoil.point_data['wallShearStress'][:, :2], axis=1)
29 |
30 | c_p = np.concatenate([points[:, None], pressure[:, None] / qInf], axis=1)
31 | c_f = np.concatenate([points[:, None], wss[:, None] / qInf], axis=1) # skin friction coefficient
32 |
33 | if extrado:
34 | return c_p, c_f, idx_extrado
35 | else:
36 | return c_p, c_f
37 |
38 |
39 | def compare_surface_coefs(coefs1, coefs2, extrado=True, path=None):
40 | ycp1, ycp2, c_p1, c_p2 = coefs1[0][:, 0], coefs2[0][:, 0], coefs1[0][:, 1], coefs2[0][:, 1]
41 | ycl1, ycl2, c_f1, c_f2 = coefs1[1][:, 0], coefs2[1][:, 0], coefs1[1][:, 1], coefs2[1][:, 1]
42 |
43 | fig, ax = plt.subplots(2, figsize=(20, 10))
44 | if extrado:
45 | idx_extrado1, idx_extrado2 = coefs1[2], coefs2[2] # boolean masks, True on the extrados (upper surface)
46 | ax[0].scatter(ycp1[idx_extrado1], c_p1[idx_extrado1], label='Extrado 1')
47 | ax[0].scatter(ycp1[~idx_extrado1], c_p1[~idx_extrado1], color='r', marker='x', label='Intrado 1')
48 | ax[0].scatter(ycp2[idx_extrado2], c_p2[idx_extrado2], color='y', label='Extrado Target')
49 | ax[0].scatter(ycp2[~idx_extrado2], c_p2[~idx_extrado2], color='g', marker='x', label='Intrado Target')
50 |
51 | ax[1].scatter(ycl1[idx_extrado1], c_f1[idx_extrado1], label='Extrado 1')
52 | ax[1].scatter(ycl1[~idx_extrado1], c_f1[~idx_extrado1], color='r', marker='x', label='Intrado 1')
53 | ax[1].scatter(ycl2[idx_extrado2], c_f2[idx_extrado2], color='y', label='Extrado Target')
54 | ax[1].scatter(ycl2[~idx_extrado2], c_f2[~idx_extrado2], color='g', marker='x', label='Intrado Target')
55 |
56 | else:
57 | ax[0].scatter(ycp1, c_p1, label='Experiment 1')
58 | ax[0].scatter(ycp2, c_p2, color='y', label='Experiment Target')
59 |
60 | ax[1].scatter(ycl1, c_f1, label='Experiment 1')
61 | ax[1].scatter(ycl2, c_f2, color='y', label='Experiment Target')
62 |
63 | ax[0].invert_yaxis()
64 | ax[0].set_xlabel('x/c')
65 | ax[1].set_xlabel('x/c')
66 | ax[0].set_ylabel(r'$C_p$')
67 | ax[1].set_ylabel(r'$C_f$')
68 | ax[0].set_title('Pressure coefficient')
69 | ax[1].set_title('Skin friction coefficient')
70 | ax[0].legend(loc='best')
71 | ax[1].legend(loc='best')
72 |
73 | if path is not None:
74 | fig.savefig(path + 'surface_coefs.png', bbox_inches='tight', dpi=150)
75 |
76 |
77 | def boundary_layer(airfoil, internal, aero_name, x, y=1e-3, resolution=int(1e3), direction='normals', rotation=False,
78 | extrado=True):
79 | u_inf = float(aero_name.split('_')[2])
80 | digits = list(map(float, aero_name.split('_')[4:-1]))
81 | camber = camber_line(digits, airfoil.points[:, 0])[0]
82 | idx_extrado = (airfoil.points[:, 1] > camber)
83 |
84 | if extrado:
85 | arg = np.argmin(np.abs(airfoil.points[idx_extrado, 0] - x)) + 1
86 | arg = np.argwhere(idx_extrado.cumsum() == arg).min()
87 | else:
88 | arg = np.argmin(np.abs(airfoil.points[~idx_extrado, 0] - x)) + 1
89 | arg = np.argwhere((~idx_extrado).cumsum() == arg).min()
90 |
91 | if direction == 'normals':
92 | normals = -airfoil.point_data['Normals'][arg]
93 |
94 | elif direction == 'y':
95 | normals = np.array([0, 2 * int(extrado) - 1, 0])
96 |
97 | a, b = airfoil.points[arg], airfoil.points[arg] + y * normals
98 | bl = internal.sample_over_line(a, b, resolution=resolution)
99 |
100 | if rotation:
101 | rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
102 | u = (bl.point_data['U'] * (rot @ normals)).sum(axis=1)
103 | v = (bl.point_data['U'] * normals).sum(axis=1)
104 | else:
105 | u = bl.point_data['U'][:, 0]
106 | v = bl.point_data['U'][:, 1]
107 |
108 | nut = bl.point_data['nut']
109 | yc = bl.points[:, 1] - a[1]
110 |
111 | return yc, u / u_inf, v / u_inf, nut / NU
112 |
113 |
114 | def compare_boundary_layer(coefs1, coefs2, ylim=.1, path=None, ylog=False):
115 | yc1, u1, v1, nut1 = coefs1
116 | yc2, u2, v2, nut2 = coefs2
117 |
118 | fig, ax = plt.subplots(1, 3, figsize=(30, 10))
119 | ax[0].scatter(u1, yc1, label='Experiment 1')
120 | ax[0].scatter(u2, yc2, label='Experiment 2', color='r', marker='x')
121 | ax[0].set_xlabel(r'$u/U_\infty$')
122 | ax[0].set_ylabel(r'$(y-y_0)/c$')
123 | # ax[0].set_xlim([-0.2, 1.4])
124 | # ax[0].set_ylim([0, ylim])
125 | ax[0].legend(loc='best')
126 |
127 | ax[1].scatter(v1, yc1, label='Experiment 1')
128 | ax[1].scatter(v2, yc2, label='Experiment 2', color='r', marker='x')
129 | ax[1].set_xlabel(r'$v/U_\infty$')
130 | ax[1].set_ylabel(r'$(y-y_0)/c$')
131 | # ax[1].set_xlim([-0.2, 0.2])
132 | # ax[1].set_ylim([0, ylim])
133 | ax[1].legend(loc='best')
134 |
135 | ax[2].scatter(nut1, yc1, label='Experiment 1')
136 | ax[2].scatter(nut2, yc2, label='Experiment 2', color='r', marker='x')
137 | # ax[2].set_ylim([0, ylim])
138 | ax[2].set_xlabel(r'$\nu_t/\nu$')
139 | ax[2].set_ylabel(r'$(y-y_0)/c$')
140 | ax[2].legend(loc='best')
141 |
142 | if ylog:
143 | ax[0].set_yscale('log')
144 | ax[1].set_yscale('log')
145 | ax[2].set_yscale('log')
146 |
147 | if path is not None:
148 | fig.savefig(path + 'boundary_layer.png', bbox_inches='tight', dpi=150)
149 |
150 |
151 | def plot_residuals(path, params):
152 | datas = dict()
153 | if params['turbulence'] == 'SA':
154 | fields = ['Ux', 'Uy', 'p', 'nuTilda']
155 | elif params['turbulence'] == 'SST':
156 | fields = ['Ux', 'Uy', 'p', 'k', 'omega']
157 | for field in fields:
158 | data = np.loadtxt(path + 'logs/' + field + '_0')[:, 1]
159 | datas[field] = data
160 |
161 | if params['turbulence'] == 'SA':
162 | fig, ax = plt.subplots(2, 2, figsize=(20, 20))
163 | ax[1, 1].plot(datas['nuTilda'])
164 | ax[1, 1].set_yscale('log')
165 | ax[1, 1].set_title('nuTilda residual')
166 | ax[1, 1].set_xlabel('Number of iterations')
167 |
168 | elif params['turbulence'] == 'SST':
169 | fig, ax = plt.subplots(3, 2, figsize=(30, 20))
170 | ax[1, 1].plot(datas['k'])
171 | ax[1, 1].set_yscale('log')
172 | ax[1, 1].set_title('k residual')
173 | ax[1, 1].set_xlabel('Number of iterations')
174 |
175 | ax[2, 0].plot(datas['omega'])
176 | ax[2, 0].set_yscale('log')
177 | ax[2, 0].set_title('omega residual')
178 | ax[2, 0].set_xlabel('Number of iterations')
179 |
180 | ax[0, 0].plot(datas['Ux'])
181 | ax[0, 0].set_yscale('log')
182 | ax[0, 0].set_title('Ux residual')
183 |
184 | ax[0, 1].plot(datas['Uy'])
185 | ax[0, 1].set_yscale('log')
186 | ax[0, 1].set_title('Uy residual')
187 |
188 | ax[1, 0].plot(datas['p'])
189 | ax[1, 0].set_yscale('log')
190 | ax[1, 0].set_title('p residual')
191 | ax[1, 0].set_xlabel('Number of iterations')
192 |
193 | fig.savefig(path + 'residuals.png', bbox_inches='tight', dpi=150)
194 |
195 | return datas
196 |
197 |
198 | def plot_coef_convergence(path, params):
199 | datas = dict()
200 | datas['c_d'] = np.loadtxt(path + 'postProcessing/forceCoeffs1/0/coefficient.dat')[:, 1]
201 | datas['c_l'] = np.loadtxt(path + 'postProcessing/forceCoeffs1/0/coefficient.dat')[:, 3]
202 | c_d, c_l = datas['c_d'][-1], datas['c_l'][-1]
203 |
204 | fig, ax = plt.subplots(2, figsize=(30, 15))
205 | ax[0].plot(datas['c_d'])
206 | ax[0].set_ylim([.5 * c_d, 1.5 * c_d])
207 | ax[0].set_title('Drag coefficient')
208 | ax[0].set_xlabel('Number of iterations')
209 | ax[0].set_ylabel(r'$C_D$')
210 |
211 | ax[1].plot(datas['c_l'])
212 | ax[1].set_title('Lift coefficient')
213 | ax[1].set_ylim([.5 * c_l, 1.5 * c_l])
214 | ax[1].set_ylabel(r'$C_L$')
215 | ax[1].set_xlabel('Number of iterations')
216 |
217 | print('Drag coefficient: {0:.5}, lift coefficient: {1:.5}'.format(c_d, c_l))
218 |
219 | fig.savefig(path + 'coef_convergence.png', bbox_inches='tight', dpi=150)
220 |
221 | return datas, c_d, c_l
222 |
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/utils/naca_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def thickness_dist(t, x, CTE=True):
5 | # CTE: closed trailing edge
6 | if CTE:
7 | a = -0.1036
8 | else:
9 | a = -0.1015
10 | return 5 * t * (0.2969 * np.sqrt(x) - 0.1260 * x - 0.3516 * x ** 2 + 0.2843 * x ** 3 + a * x ** 4)
11 |
12 |
13 | def camber_line(params, x):
14 | # assert np.all(np.logical_and(x >= -1e-1, x <= 1)), 'Found x > 1 or x < 0'
15 | y_c = np.zeros_like(x)
16 | dy_c = np.zeros_like(x)
17 |
18 | if len(params) == 2:
19 | m = params[0] / 100
20 | p = params[1] / 10
21 |
22 | if p == 0:
23 | dy_c = -2 * m * x
24 | return y_c, dy_c
25 | elif p == 1:
26 | dy_c = 2 * m * (1 - x)
27 | return y_c, dy_c
28 |
29 | mask1 = (x < p)
30 | mask2 = (x >= p)
31 | y_c[mask1] = (m / p ** 2) * (2 * p * x[mask1] - x[mask1] ** 2)
32 | dy_c[mask1] = (2 * m / p ** 2) * (p - x[mask1])
33 | y_c[mask2] = (m / (1 - p) ** 2) * ((1 - 2 * p) + 2 * p * x[mask2] - x[mask2] ** 2)
34 | dy_c[mask2] = (2 * m / (1 - p) ** 2) * (p - x[mask2])
35 |
36 | elif len(params) == 3:
37 | l, p, q = params
38 | c_l, x_f = 3 / 20 * l, p / 20
39 |
40 | f = lambda x: x * (1 - np.sqrt(x / 3)) - x_f
41 | df = lambda x: 1 - 3 * np.sqrt(x / 3) / 2
42 | old_m = 0.5
43 | cond = True
44 | while cond:
45 | new_m = np.max([old_m - f(old_m) / df(old_m), 0])
46 | cond = (np.abs(old_m - new_m) > 1e-15)
47 | old_m = new_m
48 | m = old_m
49 | r = (3 * m - 7 * m ** 2 + 8 * m ** 3 - 4 * m ** 4) / np.sqrt(m * (1 - m)) - 3 / 2 * (1 - 2 * m) * (
50 | np.pi / 2 - np.arcsin(1 - 2 * m))
51 | k_1 = c_l / r
52 |
53 | mask1 = (x <= m)
54 | mask2 = (x > m)
55 | if q == 0:
56 | y_c[mask1] = k_1 * ((x[mask1] ** 3 - 3 * m * x[mask1] ** 2 + m ** 2 * (3 - m) * x[mask1]))
57 | dy_c[mask1] = k_1 * (3 * x[mask1] ** 2 - 6 * m * x[mask1] + m ** 2 * (3 - m))
58 | y_c[mask2] = k_1 * m ** 3 * (1 - x[mask2])
59 | dy_c[mask2] = -k_1 * m ** 3 * np.ones_like(dy_c[mask2])
60 |
61 | elif q == 1:
62 | k = (3 * (m - x_f) ** 2 - m ** 3) / (1 - m) ** 3
63 | y_c[mask1] = k_1 * ((x[mask1] - m) ** 3 - k * (1 - m) ** 3 * x[mask1] - m ** 3 * x[mask1] + m ** 3)
64 | dy_c[mask1] = k_1 * (3 * (x[mask1] - m) ** 2 - k * (1 - m) ** 3 - m ** 3)
65 | y_c[mask2] = k_1 * (k * (x[mask2] - m) ** 3 - k * (1 - m) ** 3 * x[mask2] - m ** 3 * x[mask2] + m ** 3)
66 | dy_c[mask2] = k_1 * (3 * k * (x[mask2] - m) ** 2 - k * (1 - m) ** 3 - m ** 3)
67 |
68 | else:
69 | raise ValueError('Q must be 0 for normal camber or 1 for reflex camber.')
70 |
71 | else:
72 | raise ValueError('The first input must be a tuple of the 2 or 3 digits that represent the camber line.')
73 |
74 | return y_c, dy_c
75 |
76 |
77 | def naca_generator(params, nb_samples=400, scale=1, origin=(0, 0), cosine_spacing=True, verbose=True, CTE=True):
78 | if len(params) == 3:
79 | params_c = params[:2]
80 | t = params[2] / 100
81 | if verbose:
82 | print(f'Generating naca M = {params_c[0]}, P = {params_c[1]}, XX = {t * 100}')
83 | elif len(params) == 4:
84 | params_c = params[:3]
85 | t = params[3] / 100
86 | if verbose:
87 | print(f'Generating naca L = {params_c[0]}, P = {params_c[1]}, Q = {params_c[2]}, XX = {t * 100}')
88 | else:
89 | raise ValueError('The first argument must be a tuple of the 4 or 5 digits of the airfoil.')
90 |
91 | if cosine_spacing:
92 | beta = np.pi * np.linspace(1, 0, nb_samples + 1, endpoint=True)
93 | x = (1 - np.cos(beta)) / 2
94 | else:
95 | x = np.linspace(1, 0, nb_samples + 1, endpoint=True)
96 |
97 | y_c, dy_c = camber_line(params_c, x)
98 | y_t = thickness_dist(t, x, CTE)
99 | theta = np.arctan(dy_c)
100 | x_u = x - y_t * np.sin(theta)
101 | x_l = x + y_t * np.sin(theta)
102 | y_u = y_c + y_t * np.cos(theta)
103 | y_l = y_c - y_t * np.cos(theta)
104 | x = np.concatenate([x_u, x_l[:-1][::-1]], axis=0)
105 | y = np.concatenate([y_u, y_l[:-1][::-1]], axis=0)
106 | pos = np.stack([
107 | x * scale + origin[0],
108 | y * scale + origin[1]
109 | ], axis=-1
110 | )
111 | pos[0], pos[-1] = np.array([1, 0]), np.array([1, 0])
112 | return pos
113 |
--------------------------------------------------------------------------------
/Airfoil-Design-AirfRANS/utils/reorganize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def reorganize(in_order_points, out_order_points, quantity_to_reordered):
4 | n = out_order_points.shape[0]
5 | idx = np.zeros(n)
6 | for i in range(n):
7 | cond = (out_order_points[i] == in_order_points)
8 | cond = cond[:, 0]*cond[:, 1]
9 | idx[i] = np.argwhere(cond)[0][0]
10 | idx = idx.astype('int')
11 |
12 | assert (in_order_points[idx] == out_order_points).all()
13 |
14 | return quantity_to_reordered[idx]
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/README.md:
--------------------------------------------------------------------------------
1 | # Transolver for Car Design
2 |
3 | We test [Transolver](https://arxiv.org/abs/2402.02366) on practical design tasks. The car design task requires the model to estimate the surrounding wind speed and surface pressure for a driving car.
4 |
5 |
6 |
7 |
8 | Figure 1. Car design task.
9 |
10 |
11 | Relative errors of the surrounding wind, surface pressure, and [drag coefficient](https://en.wikipedia.org/wiki/Drag_coefficient) are recorded, along with [Spearman's rank correlation](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient), which quantifies the model's ability to rank different designs.
12 |
13 |
14 |
15 |
16 | Table 1. Model comparisons of the car design task.
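
As a rough illustration of the two kinds of metrics mentioned above, the sketch below computes a relative L2 error and a Spearman rank correlation in plain Python. The function names and the drag values are made up for illustration; they are not taken from this repository's evaluation code, which may normalize or aggregate results differently.

```python
import math


def relative_l2_error(pred, target):
    """Relative L2 error: ||pred - target|| / ||target||."""
    num = math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)))
    den = math.sqrt(sum(t ** 2 for t in target))
    return num / den


def average_ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        shared = (i + j) / 2 + 1  # mean of positions i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = shared
        i = j + 1
    return ranks


def spearman_rank_correlation(pred, target):
    """Pearson correlation computed on the ranks of the two sequences."""
    rp, rt = average_ranks(pred), average_ranks(target)
    n = len(rp)
    mp, mt = sum(rp) / n, sum(rt) / n
    cov = sum((a - mp) * (b - mt) for a, b in zip(rp, rt))
    sp = math.sqrt(sum((a - mp) ** 2 for a in rp))
    st = math.sqrt(sum((b - mt) ** 2 for b in rt))
    return cov / (sp * st)


# Hypothetical drag coefficients for four car designs (not real results).
predicted_cd = [0.31, 0.28, 0.35, 0.30]
reference_cd = [0.30, 0.29, 0.36, 0.33]

print(relative_l2_error(predicted_cd, reference_cd))
print(spearman_rank_correlation(predicted_cd, reference_cd))
```

A correlation close to 1 means the model orders the designs almost as the reference does, even when the absolute drag values are off. That ordering is what matters when the predictions are used to choose between candidate shapes.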
17 |
18 |
19 |
20 | ## Get Started
21 |
22 | 1. Install Python 3.8, then install the dependencies with the following command.
23 |
24 | ```bash
25 | pip install -r requirements.txt
26 | ```
27 |
28 | Note: You need to install [pytorch_geometric](https://github.com/pyg-team/pytorch_geometric).
29 |
30 | 2. Prepare Data.
31 |
32 | The raw data can be found [[here]](http://www.nobuyuki-umetani.com/publication/mlcfd_data.zip), which is provided by [Nobuyuki Umetani](https://dl.acm.org/doi/abs/10.1145/3197517.3201325).
33 |
34 | 3. Train and evaluate the model. We provide the experiment scripts under the folder `./scripts/`. You can reproduce the experiment results with the following commands:
35 |
36 | ```bash
37 | bash scripts/Transolver.sh # for training (takes 8-10 hours on a single A100)
38 | bash scripts/Evaluation.sh # for Evaluation
39 | ```
40 |
41 | Note: You need to change the arguments `--data_dir` and `--save_dir` to your own paths. Here `data_dir` is the location of the raw data and `save_dir` is where the preprocessed data is saved.
42 |
43 | If you have already downloaded or generated the preprocessed data, you can set `--preprocessed` to True to speed up loading.
44 |
45 | 4. Develop your own model. Here are the instructions:
46 |
47 | - Add the model file under folder `./models/`.
48 | - Add the model configuration into `./main.py`.
49 | - Add a script file under folder `./scripts/` and change the argument `--model`.
50 |
51 | ## Slice Visualization
52 |
53 | Transolver proposes to **learn physical states** hidden under the unwieldy meshes.
54 |
55 | The following visualization demonstrates that Transolver learns to assign points in similar physical states to the same slice, such as the windshield, license plate, and headlights.
56 |
57 |
58 |
59 |
60 | Figure 2. Visualization for Transolver learned physical states.
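
To make the notion of a slice more concrete, here is a minimal plain-Python sketch of a soft point-to-slice assignment. It loosely mirrors the slicing step of the `Physics_Attention_Irregular_Mesh` module: a temperature-scaled softmax over per-point slice logits, followed by a weight-normalized average of point features per slice. The helper names, the toy logits and features, and the use of `argmax` to pick one slice per point for coloring are assumptions made for illustration, not a description of how the figure was actually produced.

```python
import math


def softmax(logits, temperature=0.5):
    """Temperature-scaled softmax over one point's slice logits."""
    scaled = [v / temperature for v in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def slice_points(point_logits, point_features, temperature=0.5, eps=1e-5):
    """Return soft slice weights, slice tokens, and a hard slice id per point."""
    weights = [softmax(row, temperature) for row in point_logits]  # N x G
    n_slices = len(weights[0])
    dim = len(point_features[0])
    tokens = []
    for g in range(n_slices):
        norm = sum(w[g] for w in weights) + eps
        tokens.append([
            sum(w[g] * f[c] for w, f in zip(weights, point_features)) / norm
            for c in range(dim)
        ])
    # Hypothetical visualization rule: color each point by its dominant slice.
    hard = [max(range(n_slices), key=lambda g: w[g]) for w in weights]
    return weights, tokens, hard


# Toy example: three mesh points, two slices, one feature channel.
logits = [[4.0, 0.0], [3.5, 0.2], [0.0, 4.0]]
features = [[1.0], [1.2], [5.0]]
weights, tokens, hard = slice_points(logits, features)
print(hard)    # dominant slice index per point
print(tokens)  # one aggregated feature vector per slice
```

Points whose logits favor the same slice contribute to the same token, so each token summarizes one group of physically similar points regardless of where they sit on the mesh.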
61 |
62 |
63 |
64 | ## Showcases
65 |
66 | Transolver achieves the best performance in complex geometries and hybrid physics.
67 |
68 |
69 |
70 |
71 | Figure 3. Case study of Transolver and other models.
72 |
73 |
74 |
75 | ## Citation
76 |
77 | If you find this repo useful, please cite our paper.
78 |
79 | ```
80 | @inproceedings{wu2024Transolver,
81 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries},
82 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long},
83 | booktitle={International Conference on Machine Learning},
84 | year={2024}
85 | }
86 | ```
87 |
88 | ## Contact
89 |
90 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn).
91 |
92 | ## Acknowledgement
93 |
94 | We greatly appreciate the following papers for their valuable code bases and datasets:
95 |
96 | https://dl.acm.org/doi/abs/10.1145/3197517.3201325
97 |
98 | https://openreview.net/forum?id=EyQO9RPhwN
99 |
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/dataset/load_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataset.dataset import get_datalist
3 |
4 |
5 | def get_samples(root):
6 | folds = [f'param{i}' for i in range(9)]
7 | samples = []
8 | for fold in folds:
9 | fold_samples = []
10 | files = os.listdir(os.path.join(root, fold))
11 | for file in files:
12 | path = os.path.join(root, os.path.join(fold, file))
13 | if os.path.isdir(path):
14 | fold_samples.append(os.path.join(fold, file))
15 | samples.append(fold_samples)
16 | return samples # 100 + 99 + 97 + 100 + 100 + 96 + 100 + 98 + 99 = 889 samples
17 |
18 |
19 | def load_train_val_fold(args, preprocessed):
20 | samples = get_samples(args.data_dir)
21 | trainlst = []
22 | for i in range(len(samples)):
23 | if i == args.fold_id:
24 | continue
25 | trainlst += samples[i]
26 | vallst = samples[args.fold_id] if 0 <= args.fold_id < len(samples) else None
27 |
28 | if preprocessed:
29 | print("use preprocessed data")
30 | print("loading data")
31 | train_dataset, coef_norm = get_datalist(args.data_dir, trainlst, norm=True, savedir=args.save_dir,
32 | preprocessed=preprocessed)
33 | val_dataset = get_datalist(args.data_dir, vallst, coef_norm=coef_norm, savedir=args.save_dir,
34 | preprocessed=preprocessed)
35 | print("load data finish")
36 | return train_dataset, val_dataset, coef_norm
37 |
38 |
39 | def load_train_val_fold_file(args, preprocessed):
40 | samples = get_samples(args.data_dir)
41 | trainlst = []
42 | for i in range(len(samples)):
43 | if i == args.fold_id:
44 | continue
45 | trainlst += samples[i]
46 | vallst = samples[args.fold_id] if 0 <= args.fold_id < len(samples) else None
47 |
48 | if preprocessed:
49 | print("use preprocessed data")
50 | print("loading data")
51 | train_dataset, coef_norm = get_datalist(args.data_dir, trainlst, norm=True, savedir=args.save_dir,
52 | preprocessed=preprocessed)
53 | val_dataset = get_datalist(args.data_dir, vallst, coef_norm=coef_norm, savedir=args.save_dir,
54 | preprocessed=preprocessed)
55 | print("load data finish")
56 | return train_dataset, val_dataset, coef_norm, vallst
57 |
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/fig/car_slice_surf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Car-Design-ShapeNetCar/fig/car_slice_surf.png
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/fig/case_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Car-Design-ShapeNetCar/fig/case_study.png
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/fig/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Car-Design-ShapeNetCar/fig/results.png
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/fig/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Car-Design-ShapeNetCar/fig/task.png
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/main.py:
--------------------------------------------------------------------------------
1 | import train
2 | import os
3 | import torch
4 | import argparse
5 |
6 | from dataset.load_dataset import load_train_val_fold
7 | from dataset.dataset import GraphDataset
8 | from models.Transolver import Model
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--data_dir', default='/data/PDE_data/mlcfd_data/training_data')
12 | parser.add_argument('--save_dir', default='/data/PDE_data/mlcfd_data/preprocessed_data')
13 | parser.add_argument('--fold_id', default=0, type=int)
14 | parser.add_argument('--gpu', default=0, type=int)
15 | parser.add_argument('--val_iter', default=10, type=int)
16 | parser.add_argument('--cfd_config_dir', default='cfd/cfd_params.yaml')
17 | parser.add_argument('--cfd_model')
18 | parser.add_argument('--cfd_mesh', action='store_true')
19 | parser.add_argument('--r', default=0.2, type=float)
20 | parser.add_argument('--weight', default=0.5, type=float)
21 | parser.add_argument('--lr', default=0.001, type=float)
22 | parser.add_argument('--batch_size', default=1, type=int)
23 | parser.add_argument('--nb_epochs', default=200, type=int)
24 | parser.add_argument('--preprocessed', default=1, type=int)
25 | args = parser.parse_args()
26 | print(args)
27 |
28 | hparams = {'lr': args.lr, 'batch_size': args.batch_size, 'nb_epochs': args.nb_epochs}
29 |
30 | n_gpu = torch.cuda.device_count()
31 | use_cuda = 0 <= args.gpu < n_gpu and torch.cuda.is_available()
32 | device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu')
33 |
34 | train_data, val_data, coef_norm = load_train_val_fold(args, preprocessed=args.preprocessed)
35 | train_ds = GraphDataset(train_data, use_cfd_mesh=args.cfd_mesh, r=args.r)
36 | val_ds = GraphDataset(val_data, use_cfd_mesh=args.cfd_mesh, r=args.r)
37 |
38 | if args.cfd_model == 'Transolver':
39 |     model = Model(n_hidden=256, n_layers=8, space_dim=7,
40 |                   fun_dim=0,
41 |                   n_head=8,
42 |                   mlp_ratio=2, out_dim=4,
43 |                   slice_num=32,
44 |                   unified_pos=0).to(device)  # use the device selected via --gpu (falls back to CPU)
45 |
46 | path = f'metrics/{args.cfd_model}/{args.fold_id}/{args.nb_epochs}_{args.weight}'
47 | if not os.path.exists(path):
48 | os.makedirs(path)
49 |
50 | model = train.main(device, train_ds, val_ds, model, hparams, path, val_iter=args.val_iter, reg=args.weight,
51 | coef_norm=coef_norm)
52 |
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/main_evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import yaml
5 | import numpy as np
6 | import time
7 | from torch import nn
8 | from torch_geometric.loader import DataLoader
9 | from utils.drag_coefficient import cal_coefficient
10 | from dataset.load_dataset import load_train_val_fold_file
11 | from dataset.dataset import GraphDataset
12 | import scipy as sc
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--data_dir', default='/data/PDE_data/mlcfd_data/training_data')
16 | parser.add_argument('--save_dir', default='/data/PDE_data/mlcfd_data/preprocessed_data')
17 | parser.add_argument('--fold_id', default=0, type=int)
18 | parser.add_argument('--gpu', default=0, type=int)
19 | parser.add_argument('--cfd_model')
20 | parser.add_argument('--cfd_mesh', action='store_true')
21 | parser.add_argument('--r', default=0.2, type=float)
22 | parser.add_argument('--weight', default=0.5, type=float)
23 | parser.add_argument('--nb_epochs', default=200, type=int)
24 | args = parser.parse_args()
25 | print(args)
26 |
27 |
28 | n_gpu = torch.cuda.device_count()
29 | use_cuda = 0 <= args.gpu < n_gpu and torch.cuda.is_available()
30 | device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu')
31 |
32 | train_data, val_data, coef_norm, vallst = load_train_val_fold_file(args, preprocessed=True)
33 | val_ds = GraphDataset(val_data, use_cfd_mesh=args.cfd_mesh, r=args.r)
34 |
35 | path = f'metrics/{args.cfd_model}/{args.fold_id}/{args.nb_epochs}_{args.weight}'
36 | model = torch.load(os.path.join(path, f'model_{args.nb_epochs}.pth')).to(device)
37 |
38 | test_loader = DataLoader(val_ds, batch_size=1)
39 |
40 | if not os.path.exists('./results/' + args.cfd_model + '/'):
41 | os.makedirs('./results/' + args.cfd_model + '/')
42 |
43 | with torch.no_grad():
44 | model.eval()
45 | criterion_func = nn.MSELoss(reduction='none')
46 | l2errs_press = []
47 | l2errs_velo = []
48 | mses_press = []
49 | mses_velo_var = []
50 | times = []
51 | gt_coef_list = []
52 | pred_coef_list = []
53 | coef_error = 0
54 | index = 0
55 | for cfd_data, geom in test_loader:
56 | print(vallst[index])
57 | cfd_data = cfd_data.to(device)
58 | geom = geom.to(device)
59 | tic = time.time()
60 | out = model((cfd_data, geom))
61 | toc = time.time()
62 | targets = cfd_data.y
63 |
64 | if coef_norm is not None:
65 | mean = torch.tensor(coef_norm[2]).to(device)
66 | std = torch.tensor(coef_norm[3]).to(device)
67 | pred_press = out[cfd_data.surf, -1] * std[-1] + mean[-1]
68 | gt_press = targets[cfd_data.surf, -1] * std[-1] + mean[-1]
69 | pred_surf_velo = out[cfd_data.surf, :-1] * std[:-1] + mean[:-1]
70 | gt_surf_velo = targets[cfd_data.surf, :-1] * std[:-1] + mean[:-1]
71 | pred_velo = out[~cfd_data.surf, :-1] * std[:-1] + mean[:-1]
72 | gt_velo = targets[~cfd_data.surf, :-1] * std[:-1] + mean[:-1]
73 | out_denorm = out * std + mean
74 | y_denorm = targets * std + mean
75 |
76 | np.save('./results/' + args.cfd_model + '/' + str(index) + '_pred.npy', out_denorm.detach().cpu().numpy())
77 | np.save('./results/' + args.cfd_model + '/' + str(index) + '_gt.npy', y_denorm.detach().cpu().numpy())
78 |
79 | pred_coef = cal_coefficient(vallst[index].split('/')[1], pred_press[:, None].detach().cpu().numpy(),
80 | pred_surf_velo.detach().cpu().numpy())
81 | gt_coef = cal_coefficient(vallst[index].split('/')[1], gt_press[:, None].detach().cpu().numpy(),
82 | gt_surf_velo.detach().cpu().numpy())
83 |
84 | gt_coef_list.append(gt_coef)
85 | pred_coef_list.append(pred_coef)
86 | coef_error += (abs(pred_coef - gt_coef) / gt_coef)
87 | print(coef_error / (index + 1))
88 |
89 | l2err_press = torch.norm(pred_press - gt_press) / torch.norm(gt_press)
90 | l2err_velo = torch.norm(pred_velo - gt_velo) / torch.norm(gt_velo)
91 |
92 | mse_press = criterion_func(out[cfd_data.surf, -1], targets[cfd_data.surf, -1]).mean(dim=0)
93 | mse_velo_var = criterion_func(out[~cfd_data.surf, :-1], targets[~cfd_data.surf, :-1]).mean(dim=0)
94 |
95 | l2errs_press.append(l2err_press.cpu().numpy())
96 | l2errs_velo.append(l2err_velo.cpu().numpy())
97 | mses_press.append(mse_press.cpu().numpy())
98 | mses_velo_var.append(mse_velo_var.cpu().numpy())
99 | times.append(toc - tic)
100 | index += 1
101 |
102 | gt_coef_list = np.array(gt_coef_list)
103 | pred_coef_list = np.array(pred_coef_list)
104 | spear = sc.stats.spearmanr(gt_coef_list, pred_coef_list)[0]
105 | print("rho_d: ", spear)
106 | print("c_d: ", coef_error / index)
107 | l2err_press = np.mean(l2errs_press)
108 | l2err_velo = np.mean(l2errs_velo)
109 | rmse_press = np.sqrt(np.mean(mses_press))
110 | rmse_velo_var = np.sqrt(np.mean(mses_velo_var, axis=0))
111 | if coef_norm is not None:
112 | rmse_press *= coef_norm[3][-1]
113 | rmse_velo_var *= coef_norm[3][:-1]
114 | print('relative l2 error press:', l2err_press)
115 | print('relative l2 error velo:', l2err_velo)
116 | print('press:', rmse_press)
117 | print('velo:', rmse_velo_var, np.sqrt(np.mean(np.square(rmse_velo_var))))
118 | print('time:', np.mean(times))
119 |
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/models/Transolver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from timm.models.layers import trunc_normal_
5 | from einops import rearrange, repeat
6 |
7 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU,
8 |               'leaky_relu': lambda: nn.LeakyReLU(0.1), 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU}
9 |
10 |
11 | class Physics_Attention_Irregular_Mesh(nn.Module):
12 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64):
13 | super().__init__()
14 | inner_dim = dim_head * heads
15 | self.dim_head = dim_head
16 | self.heads = heads
17 | self.scale = dim_head ** -0.5
18 | self.softmax = nn.Softmax(dim=-1)
19 | self.dropout = nn.Dropout(dropout)
20 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
21 |
22 | self.in_project_x = nn.Linear(dim, inner_dim)
23 | self.in_project_fx = nn.Linear(dim, inner_dim)
24 | self.in_project_slice = nn.Linear(dim_head, slice_num)
25 | for l in [self.in_project_slice]:
26 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization
27 | self.to_q = nn.Linear(dim_head, dim_head, bias=False)
28 | self.to_k = nn.Linear(dim_head, dim_head, bias=False)
29 | self.to_v = nn.Linear(dim_head, dim_head, bias=False)
30 | self.to_out = nn.Sequential(
31 | nn.Linear(inner_dim, dim),
32 | nn.Dropout(dropout)
33 | )
34 |
35 | def forward(self, x):
36 | # B N C
37 | B, N, C = x.shape
38 |
39 | ### (1) Slice
40 | fx_mid = self.in_project_fx(x).reshape(B, N, self.heads, self.dim_head) \
41 | .permute(0, 2, 1, 3).contiguous() # B H N C
42 | x_mid = self.in_project_x(x).reshape(B, N, self.heads, self.dim_head) \
43 | .permute(0, 2, 1, 3).contiguous() # B H N C
44 | slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature) # B H N G
45 | slice_norm = slice_weights.sum(2) # B H G
46 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
47 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head))
48 |
49 | ### (2) Attention among slice tokens
50 | q_slice_token = self.to_q(slice_token)
51 | k_slice_token = self.to_k(slice_token)
52 | v_slice_token = self.to_v(slice_token)
53 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
54 | attn = self.softmax(dots)
55 | attn = self.dropout(attn)
56 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D
57 |
58 | ### (3) Deslice
59 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
60 | out_x = rearrange(out_x, 'b h n d -> b n (h d)')
61 | return self.to_out(out_x)
62 |
63 |
64 | class MLP(nn.Module):
65 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True):
66 | super(MLP, self).__init__()
67 |
68 | if act in ACTIVATION.keys():
69 | act = ACTIVATION[act]
70 | else:
71 | raise NotImplementedError
72 | self.n_input = n_input
73 | self.n_hidden = n_hidden
74 | self.n_output = n_output
75 | self.n_layers = n_layers
76 | self.res = res
77 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act())
78 | self.linear_post = nn.Linear(n_hidden, n_output)
79 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)])
80 |
81 | def forward(self, x):
82 | x = self.linear_pre(x)
83 | for i in range(self.n_layers):
84 | if self.res:
85 | x = self.linears[i](x) + x
86 | else:
87 | x = self.linears[i](x)
88 | x = self.linear_post(x)
89 | return x
90 |
91 |
92 | class Transolver_block(nn.Module):
93 | """Transformer encoder block."""
94 |
95 | def __init__(
96 | self,
97 | num_heads: int,
98 | hidden_dim: int,
99 | dropout: float,
100 | act='gelu',
101 | mlp_ratio=4,
102 | last_layer=False,
103 | out_dim=1,
104 | slice_num=32,
105 | ):
106 | super().__init__()
107 | self.last_layer = last_layer
108 | self.ln_1 = nn.LayerNorm(hidden_dim)
109 | self.Attn = Physics_Attention_Irregular_Mesh(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads,
110 | dropout=dropout, slice_num=slice_num)
111 | self.ln_2 = nn.LayerNorm(hidden_dim)
112 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
113 | if self.last_layer:
114 | self.ln_3 = nn.LayerNorm(hidden_dim)
115 | self.mlp2 = nn.Linear(hidden_dim, out_dim)
116 |
117 | def forward(self, fx):
118 | fx = self.Attn(self.ln_1(fx)) + fx
119 | fx = self.mlp(self.ln_2(fx)) + fx
120 | if self.last_layer:
121 | return self.mlp2(self.ln_3(fx))
122 | else:
123 | return fx
124 |
125 |
126 | class Model(nn.Module):
127 | def __init__(self,
128 | space_dim=1,
129 | n_layers=5,
130 | n_hidden=256,
131 | dropout=0,
132 | n_head=8,
133 | act='gelu',
134 | mlp_ratio=1,
135 | fun_dim=1,
136 | out_dim=1,
137 | slice_num=32,
138 | ref=8,
139 | unified_pos=False
140 | ):
141 | super(Model, self).__init__()
142 | self.__name__ = 'UniPDE_3D'
143 | self.ref = ref
144 | self.unified_pos = unified_pos
145 | if self.unified_pos:
146 | self.preprocess = MLP(fun_dim + self.ref * self.ref * self.ref, n_hidden * 2, n_hidden, n_layers=0,
147 | res=False, act=act)
148 | else:
149 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act)
150 |
151 | self.n_hidden = n_hidden
152 | self.space_dim = space_dim
153 |
154 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden,
155 | dropout=dropout,
156 | act=act,
157 | mlp_ratio=mlp_ratio,
158 | out_dim=out_dim,
159 | slice_num=slice_num,
160 | last_layer=(_ == n_layers - 1))
161 | for _ in range(n_layers)])
162 | self.initialize_weights()
163 | self.placeholder = nn.Parameter((1 / (n_hidden)) * torch.rand(n_hidden, dtype=torch.float))
164 |
165 | def initialize_weights(self):
166 | self.apply(self._init_weights)
167 |
168 | def _init_weights(self, m):
169 | if isinstance(m, nn.Linear):
170 | trunc_normal_(m.weight, std=0.02)
171 | if isinstance(m, nn.Linear) and m.bias is not None:
172 | nn.init.constant_(m.bias, 0)
173 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
174 | nn.init.constant_(m.bias, 0)
175 | nn.init.constant_(m.weight, 1.0)
176 |
177 | def get_grid(self, my_pos):
178 | # my_pos 1 N 3
179 | batchsize = my_pos.shape[0]
180 |
181 | gridx = torch.tensor(np.linspace(-1.5, 1.5, self.ref), dtype=torch.float)
182 | gridx = gridx.reshape(1, self.ref, 1, 1, 1).repeat([batchsize, 1, self.ref, self.ref, 1])
183 | gridy = torch.tensor(np.linspace(0, 2, self.ref), dtype=torch.float)
184 | gridy = gridy.reshape(1, 1, self.ref, 1, 1).repeat([batchsize, self.ref, 1, self.ref, 1])
185 | gridz = torch.tensor(np.linspace(-4, 4, self.ref), dtype=torch.float)
186 | gridz = gridz.reshape(1, 1, 1, self.ref, 1).repeat([batchsize, self.ref, self.ref, 1, 1])
187 |         grid_ref = torch.cat((gridx, gridy, gridz), dim=-1).to(my_pos.device).reshape(batchsize, self.ref ** 3, 3)  # B ref^3 3
188 |
189 | pos = torch.sqrt(
190 | torch.sum((my_pos[:, :, None, :] - grid_ref[:, None, :, :]) ** 2,
191 | dim=-1)). \
192 | reshape(batchsize, my_pos.shape[1], self.ref * self.ref * self.ref).contiguous()
193 | return pos
194 |
195 | def forward(self, data):
196 | cfd_data, geom_data = data
197 | x, fx, T = cfd_data.x, None, None
198 | x = x[None, :, :]
199 | if self.unified_pos:
200 | new_pos = self.get_grid(cfd_data.pos[None, :, :])
201 | x = torch.cat((x, new_pos), dim=-1)
202 |
203 | if fx is not None:
204 | fx = torch.cat((x, fx), -1)
205 | fx = self.preprocess(fx)
206 | else:
207 | fx = self.preprocess(x)
208 | fx = fx + self.placeholder[None, None, :]
209 |
210 | for block in self.blocks:
211 | fx = block(fx)
212 |
213 | return fx[0]
214 |
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torch_geometric
3 | torch-cluster
4 | vtk
5 | timm
6 | einops
7 | scipy
8 | tqdm
9 | pyyaml
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/scripts/Evaluation.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=3
2 |
3 | python main_evaluation.py \
4 | --cfd_model=Transolver \
5 | --data_dir /data/PDE_data/mlcfd_data/training_data \
6 | --save_dir /data/PDE_data/mlcfd_data/preprocessed_data
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/scripts/Transolver.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=7
2 |
3 | python main.py \
4 | --cfd_model=Transolver \
5 | --data_dir /data/PDE_data/mlcfd_data/training_data \
6 | --save_dir /data/PDE_data/mlcfd_data/preprocessed_data
7 |
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time, json, os
3 | import torch
4 | import torch.nn as nn
5 |
6 | from torch_geometric.loader import DataLoader
7 | from tqdm import tqdm
8 |
9 |
10 | def get_nb_trainable_params(model):
11 | '''
12 | Return the number of trainable parameters
13 | '''
14 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
15 | return sum([np.prod(p.size()) for p in model_parameters])
16 |
17 |
18 | def train(device, model, train_loader, optimizer, scheduler, reg=1):
19 | model.train()
20 |
21 | criterion_func = nn.MSELoss(reduction='none')
22 | losses_press = []
23 | losses_velo = []
24 | for cfd_data, geom in train_loader:
25 | cfd_data = cfd_data.to(device)
26 | geom = geom.to(device)
27 | optimizer.zero_grad()
28 | out = model((cfd_data, geom))
29 | targets = cfd_data.y
30 |
31 | loss_press = criterion_func(out[cfd_data.surf, -1], targets[cfd_data.surf, -1]).mean(dim=0)
32 | loss_velo_var = criterion_func(out[:, :-1], targets[:, :-1]).mean(dim=0)
33 | loss_velo = loss_velo_var.mean()
34 | total_loss = loss_velo + reg * loss_press
35 |
36 | total_loss.backward()
37 |
38 | optimizer.step()
39 | scheduler.step()
40 |
41 | losses_press.append(loss_press.item())
42 | losses_velo.append(loss_velo.item())
43 |
44 | return np.mean(losses_press), np.mean(losses_velo)
45 |
46 |
47 | @torch.no_grad()
48 | def test(device, model, test_loader):
49 | model.eval()
50 |
51 | criterion_func = nn.MSELoss(reduction='none')
52 | losses_press = []
53 | losses_velo = []
54 | for cfd_data, geom in test_loader:
55 | cfd_data = cfd_data.to(device)
56 | geom = geom.to(device)
57 | out = model((cfd_data, geom))
58 | targets = cfd_data.y
59 |
60 | loss_press = criterion_func(out[cfd_data.surf, -1], targets[cfd_data.surf, -1]).mean(dim=0)
61 | loss_velo_var = criterion_func(out[:, :-1], targets[:, :-1]).mean(dim=0)
62 | loss_velo = loss_velo_var.mean()
63 |
64 | losses_press.append(loss_press.item())
65 | losses_velo.append(loss_velo.item())
66 |
67 | return np.mean(losses_press), np.mean(losses_velo)
68 |
69 |
70 | class NumpyEncoder(json.JSONEncoder):
71 | def default(self, obj):
72 | if isinstance(obj, np.ndarray):
73 | return obj.tolist()
74 | return json.JSONEncoder.default(self, obj)
75 |
76 |
77 | def main(device, train_dataset, val_dataset, Net, hparams, path, reg=1, val_iter=1, coef_norm=[]):
78 | model = Net.to(device)
79 | optimizer = torch.optim.Adam(model.parameters(), lr=hparams['lr'])
80 | lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
81 | optimizer,
82 | max_lr=hparams['lr'],
83 | total_steps=(len(train_dataset) // hparams['batch_size'] + 1) * hparams['nb_epochs'],
84 | final_div_factor=1000.,
85 | )
86 | start = time.time()
87 |
88 | train_loss, val_loss = 1e5, 1e5
89 | pbar_train = tqdm(range(hparams['nb_epochs']), position=0)
90 | for epoch in pbar_train:
91 | train_loader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True, drop_last=True)
92 |         loss_press, loss_velo = train(device, model, train_loader, optimizer, lr_scheduler, reg=reg)  # train() returns (press, velo)
93 | train_loss = loss_velo + reg * loss_press
94 | del (train_loader)
95 |
96 | if val_iter is not None and (epoch == hparams['nb_epochs'] - 1 or epoch % val_iter == 0):
97 | val_loader = DataLoader(val_dataset, batch_size=1)
98 |
99 |             loss_press, loss_velo = test(device, model, val_loader)  # test() returns (press, velo)
100 | val_loss = loss_velo + reg * loss_press
101 | del (val_loader)
102 |
103 | pbar_train.set_postfix(train_loss=train_loss, val_loss=val_loss)
104 | else:
105 | pbar_train.set_postfix(train_loss=train_loss)
106 |
107 | end = time.time()
108 | time_elapsed = end - start
109 | params_model = get_nb_trainable_params(model).astype('float')
110 | print('Number of parameters:', params_model)
111 | print('Time elapsed: {0:.2f} seconds'.format(time_elapsed))
112 | torch.save(model, path + os.sep + f'model_{hparams["nb_epochs"]}.pth')
113 |
114 | if val_iter is not None:
115 | with open(path + os.sep + f'log_{hparams["nb_epochs"]}.json', 'a') as f:
116 | json.dump(
117 | {
118 | 'nb_parameters': params_model,
119 | 'time_elapsed': time_elapsed,
120 | 'hparams': hparams,
121 | 'train_loss': train_loss,
122 | 'val_loss': val_loss,
123 | 'coef_norm': list(coef_norm),
124 | }, f, indent=12, cls=NumpyEncoder
125 | )
126 |
127 | return model
128 |
--------------------------------------------------------------------------------
/Car-Design-ShapeNetCar/utils/drag_coefficient.py:
--------------------------------------------------------------------------------
1 | import vtk
2 | import os
3 | import numpy as np
4 | from vtk.util.numpy_support import vtk_to_numpy
5 | from scipy.spatial import ConvexHull
6 |
7 |
8 | def unstructured_grid_data_to_poly_data(unstructured_grid_data):
9 | filter = vtk.vtkDataSetSurfaceFilter()
10 | filter.SetInputData(unstructured_grid_data)
11 | filter.Update()
12 | poly_data = filter.GetOutput()
13 | return poly_data, filter
14 |
15 |
16 | def load_unstructured_grid_data(file_name):
17 | reader = vtk.vtkUnstructuredGridReader()
18 | reader.SetFileName(file_name)
19 | reader.Update()
20 | output = reader.GetOutput()
21 | return output
22 |
23 |
24 | ############## frontal area ##############
25 | def calculate_pos(pos):
26 |     hull = ConvexHull(pos[:, :2])
27 |     A = hull.volume  # for a 2D convex hull, `volume` is the enclosed area
28 |     return A
29 |
30 |
31 | ############## surf area ##############
32 | def calculate_mesh_cell_area(unstructured_grid_data):
33 | # Read VTK file
34 | poly_data, _ = unstructured_grid_data_to_poly_data(unstructured_grid_data)
35 |
36 | # Get the points and cells
37 | points = poly_data.GetPoints()
38 | cells = poly_data.GetPolys()
39 |
40 |     # Initialize an array to store cell areas
41 | cell_areas = np.zeros(cells.GetNumberOfCells())
42 |
43 | # Iterate through cells to calculate areas
44 | cells.InitTraversal()
45 | cell = vtk.vtkIdList()
46 | id = 0
47 | while cells.GetNextCell(cell):
48 | # Check if the cell is a quadrilateral
49 | if cell.GetNumberOfIds() == 4:
50 | # Get the four vertices of the quadrilateral
51 | p1 = np.array(points.GetPoint(cell.GetId(0)))
52 | p2 = np.array(points.GetPoint(cell.GetId(1)))
53 | p3 = np.array(points.GetPoint(cell.GetId(2)))
54 | p4 = np.array(points.GetPoint(cell.GetId(3)))
55 | # Calculate the area of the quadrilateral
56 | area = 0.5 * (
57 | np.linalg.norm(np.cross(p2 - p1, p3 - p1)) +
58 | np.linalg.norm(np.cross(p3 - p1, p4 - p1))
59 | )
60 |
61 |             # Store the area of this cell
62 | cell_areas[id] += area
63 | id += 1
64 |
65 | return cell_areas
66 |
67 |
68 | ############## velocity gradient ##############
69 | def calculate_cell_velocity_gradient(unstructured_grid_data, velocity):
70 | # Create a vtkDoubleArray for velocity
71 | velocity_data = vtk.vtkDoubleArray()
72 | velocity_data.SetNumberOfComponents(3) # Assuming 3D velocity field
73 | velocity_data.SetNumberOfTuples(unstructured_grid_data.GetNumberOfPoints())
74 | velocity_data.SetName("Velocity") # Replace "Velocity" with the desired array name
75 |
76 | # Set the velocity array values
77 | for i in range(unstructured_grid_data.GetNumberOfPoints()):
78 | velocity_data.SetTuple(i, velocity[i])
79 |
80 | # Add the velocity array to the point data
81 | unstructured_grid_data.GetPointData().AddArray(velocity_data)
82 |
83 | # Get the points and cell data (assuming velocity is stored as point data)
84 | poly_data, _ = unstructured_grid_data_to_poly_data(unstructured_grid_data)
85 | points = poly_data.GetPoints()
86 |
87 | # Initialize arrays to store velocity gradients
88 | grad_u = np.zeros((poly_data.GetNumberOfCells(), 3)) # Assuming 3D velocity field
89 | # Iterate through cells to calculate gradients
90 | cells = poly_data.GetPolys()
91 | cells.InitTraversal()
92 | cell = vtk.vtkIdList()
93 | id = 0
94 | while cells.GetNextCell(cell):
95 | # Check if the cell is a quadrilateral
96 | if cell.GetNumberOfIds() == 4:
97 | # Get the four vertices of the quadrilateral
98 | p1 = np.array(points.GetPoint(cell.GetId(0)))
99 | p2 = np.array(points.GetPoint(cell.GetId(1)))
100 | p3 = np.array(points.GetPoint(cell.GetId(2)))
101 | p4 = np.array(points.GetPoint(cell.GetId(3)))
102 | # Calculate the velocity at each vertex
103 | u1 = np.array(poly_data.GetPointData().GetArray("Velocity").GetTuple(cell.GetId(0)))
104 | u2 = np.array(poly_data.GetPointData().GetArray("Velocity").GetTuple(cell.GetId(1)))
105 | u3 = np.array(poly_data.GetPointData().GetArray("Velocity").GetTuple(cell.GetId(2)))
106 | u4 = np.array(poly_data.GetPointData().GetArray("Velocity").GetTuple(cell.GetId(3)))
107 |
108 | # Calculate the gradients using finite differences
109 | du_dx = (u2 - u1 + u3 - u4) / (np.linalg.norm(p2 - p1 + p3 - p4) + 1e-8)
110 | du_dy = (u3 - u1 + u4 - u2) / (np.linalg.norm(p3 - p1 + p4 - p2) + 1e-8)
111 | du_dz = (u4 - u1 + u2 - u3) / (np.linalg.norm(p4 - p1 + p2 - p3) + 1e-8)
112 |
113 |             # Store the summed gradient for this cell
114 | grad_u[id] += (du_dx + du_dy + du_dz)
115 | id += 1
116 |
117 | return grad_u
118 |
119 |
120 | ############## calculate drag ##############
121 | def calculate_drag_force(cell_areas, surface_normals, pressure_array, velocity_gradients, dynamic_viscosity):
122 | # Calculate the pressure force component along the flow direction
123 | pressure_force_component = -np.dot(pressure_array.flatten() * cell_areas.flatten(), surface_normals.flatten())
124 |
125 | # Calculate the wall shear stress component along the flow direction
126 | wall_shear_stress_component = -np.dot(velocity_gradients.flatten() * cell_areas.flatten(),
127 | surface_normals.flatten()) * dynamic_viscosity
128 | # Sum the pressure force and wall shear stress components to get the total drag force
129 | drag_force = np.sum(pressure_force_component + wall_shear_stress_component)
130 |
131 | return drag_force
132 |
133 |
134 | ############## calculate norm ##############
135 | def get_normal(unstructured_grid_data):
136 | poly_data, surface_filter = unstructured_grid_data_to_poly_data(unstructured_grid_data)
137 | normal_filter = vtk.vtkPolyDataNormals()
138 | normal_filter.SetInputData(poly_data)
139 | normal_filter.SetAutoOrientNormals(1)
140 | normal_filter.SetConsistency(1)
141 | normal_filter.SetComputeCellNormals(1)
142 | normal_filter.SetComputePointNormals(0)
143 | normal_filter.Update()
144 | return vtk_to_numpy(normal_filter.GetOutput().GetCellData().GetNormals())
145 |
146 |
147 | ############## calculate coefficient ##############
148 | def cal_coefficient(file_name, press_surf=None, velo_surf=None):
149 | root = '/data/PDE_data/mlcfd_data/training_data'
150 | save_path = '/data/PDE_data/mlcfd_data/preprocessed_data/param0/' + file_name
151 | file_name_press = 'param0/' + file_name + '/quadpress_smpl.vtk'
152 | file_name_velo = 'param0/' + file_name + '/hexvelo_smpl.vtk'
153 | file_name_press = os.path.join(root, file_name_press)
154 | file_name_velo = os.path.join(root, file_name_velo)
155 | unstructured_grid_data_press = load_unstructured_grid_data(file_name_press)
156 | unstructured_grid_data_velo = load_unstructured_grid_data(file_name_velo)
157 |
158 | # normal
159 | normal_surf = get_normal(unstructured_grid_data_press)
160 | # front area
161 | points_surf = vtk_to_numpy(unstructured_grid_data_press.GetPoints().GetData())
162 | A = calculate_pos(points_surf)
163 | # mesh area
164 | cell_areas = calculate_mesh_cell_area(unstructured_grid_data_press)
165 | # mesh velo
166 | if velo_surf is None:
167 | velo = vtk_to_numpy(unstructured_grid_data_velo.GetPointData().GetVectors())
168 | points_velo = vtk_to_numpy(unstructured_grid_data_velo.GetPoints().GetData())
169 | velo_dict = {tuple(p): velo[i] for i, p in enumerate(points_velo)}
170 | velo_surf = np.array([velo_dict[tuple(p)] if tuple(p) in velo_dict else np.zeros(3) for p in points_surf])
171 | # gradient u
172 | grad_u = calculate_cell_velocity_gradient(unstructured_grid_data_press, velo_surf)
173 | # press
174 | if press_surf is None:
175 | c2p = vtk.vtkPointDataToCellData()
176 | c2p.SetInputData(unstructured_grid_data_press)
177 | c2p.Update()
178 | unstructured_grid_data_press = c2p.GetOutput()
179 | press_surf = vtk_to_numpy(unstructured_grid_data_press.GetCellData().GetScalars())
180 | else:
181 |         # Create a vtkDoubleArray for press
182 |         press_data = vtk.vtkDoubleArray()
183 |         press_data.SetNumberOfComponents(1)  # Pressure is a scalar field
184 |         press_data.SetNumberOfTuples(unstructured_grid_data_press.GetNumberOfPoints())
185 |         press_data.SetName("my_press")  # Replace "my_press" with the desired array name
186 |
187 |         # Set the pressure array values
188 |         for i in range(unstructured_grid_data_press.GetNumberOfPoints()):
189 |             press_data.SetTuple(i, press_surf[i])
190 |
191 |         # Add the pressure array to the point data
192 |         unstructured_grid_data_press.GetPointData().AddArray(press_data)
193 | c2p = vtk.vtkPointDataToCellData()
194 | c2p.SetInputData(unstructured_grid_data_press)
195 | c2p.Update()
196 | unstructured_grid_data_press = c2p.GetOutput()
197 | press_surf = vtk_to_numpy(unstructured_grid_data_press.GetCellData().GetArray("my_press"))
198 |
199 | drag_force = calculate_drag_force(cell_areas, normal_surf[:, -1], press_surf, grad_u[:, -1], np.array(1.8e-5))
200 |     nu = 72 / 3.6  # used as the flow speed in the c_d formula below: 72 km/h converted to m/s
201 | air_density = 0.3
202 | cd = (2 / ((nu ** 2) * A * air_density)) * drag_force
203 | return cd
204 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 THUML @ Tsinghua University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/README.md:
--------------------------------------------------------------------------------
1 | # Transolver for PDE Solving
2 |
3 | We evaluate [Transolver](https://arxiv.org/abs/2402.02366) on six widely used PDE-solving benchmarks, which are provided by [FNO and GeoFNO](https://github.com/neuraloperator/neuraloperator).
4 |
5 | **Transolver achieves a 22% average relative improvement over the previous second-best model, with favorable efficiency and scalability.**
6 |
7 |
8 |
9 |
10 | Table 1. Comparison in six standard benchmarks. Relative L2 is recorded.
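
The metric in Table 1 can be illustrated with a minimal sketch. It is written in plain Python and is not the repository's `TestLoss` (whose implementation is not shown here); it assumes the per-sample relative L2 is the norm of the error divided by the norm of the ground truth. The experiment scripts then accumulate per-batch sums and divide by `ntest`, which corresponds to the mean below. The function names are illustrative.

```python
import math


def relative_l2(pred, target):
    """Relative L2 error of one sample: ||pred - target||_2 / ||target||_2."""
    diff = math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)))
    norm = math.sqrt(sum(t ** 2 for t in target))
    return diff / norm


def mean_relative_l2(preds, targets):
    """Average the per-sample relative errors over a test set."""
    errors = [relative_l2(p, t) for p, t in zip(preds, targets)]
    return sum(errors) / len(errors)


# Two toy "fields": the first is predicted exactly, the second is off by 1
# at a single point.
preds = [[1.0, 2.0, 2.0], [0.0, 3.0, 4.0]]
targets = [[1.0, 2.0, 2.0], [0.0, 3.0, 5.0]]
print(mean_relative_l2(preds, targets))
```

Because each error is divided by the norm of its own target, samples with large and small magnitudes contribute on the same scale.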
11 |
12 |
13 |
14 | ## Get Started
15 |
16 | 1. Install Python 3.8, then install the dependencies with the following command.
17 |
18 | ```bash
19 | pip install -r requirements.txt
20 | ```
21 |
22 | 2. Prepare Data. You can obtain experimental datasets from the following links.
23 |
24 |
25 | | Dataset | Task | Geometry | Link |
26 | | ------------- | --------------------------------------- | --------------- | ------------------------------------------------------------ |
27 | | Elasticity | Estimate material inner stress | Point Cloud | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) |
28 | | Plasticity | Estimate material deformation over time | Structured Mesh | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) |
29 | | Navier-Stokes | Predict future fluid velocity | Regular Grid | [[Google Cloud]](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-) |
30 | | Darcy | Estimate fluid pressure through medium | Regular Grid | [[Google Cloud]](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-) |
31 | | Airfoil | Estimate airflow velocity around airfoil | Structured Mesh | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) |
32 | | Pipe | Estimate fluid velocity in a pipe | Structured Mesh | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) |
33 |
34 | 3. Train and evaluate the model. We provide experiment scripts for all benchmarks under the folder `./scripts/`. You can reproduce the experimental results as follows:
35 |
36 | ```bash
37 | bash scripts/Transolver_Elas.sh # for Elasticity
38 | bash scripts/Transolver_Plas.sh # for Plasticity
39 | bash scripts/Transolver_NS.sh # for Navier-Stokes
40 | bash scripts/Transolver_Darcy.sh # for Darcy
41 | bash scripts/Transolver_Airfoil.sh # for Airfoil
42 | bash scripts/Transolver_Pipe.sh # for Pipe
43 | ```
44 |
45 | Note: You need to change the argument `--data_path` to your dataset path.
46 |
47 | 4. Develop your own model. Here are the instructions:
48 |
49 |    - Add the model file under folder `./model/`.
50 | - Add the model name into `./model_dict.py`.
51 | - Add a script file under folder `./scripts/` and change the argument `--model`.
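
For step 3, each script joins `--data_path` with fixed file names, so a wrong path only surfaces when the data is loaded. The sketch below is a hypothetical helper (not part of the repository) that checks the layout up front. The file names are copied from `exp_airfoil.py`, `exp_darcy.py`, `exp_elas.py`, and `exp_ns.py`; Pipe and Plasticity are left out and would need to be added from their scripts.

```python
import os
import tempfile

# Paths relative to --data_path, as built inside the experiment scripts.
EXPECTED_FILES = {
    "airfoil": ["NACA_Cylinder_X.npy", "NACA_Cylinder_Y.npy", "NACA_Cylinder_Q.npy"],
    "darcy": ["piececonst_r421_N1024_smooth1.mat",
              "piececonst_r421_N1024_smooth2.mat"],
    "elas": ["elasticity/Meshes/Random_UnitCell_sigma_10.npy",
             "elasticity/Meshes/Random_UnitCell_XY_10.npy"],
    "ns": ["NavierStokes_V1e-5_N1200_T20/NavierStokes_V1e-5_N1200_T20.mat"],
}


def missing_files(benchmark, data_path):
    """Return the expected files that are absent under data_path."""
    return [name for name in EXPECTED_FILES[benchmark]
            if not os.path.isfile(os.path.join(data_path, name))]


# Demo with a throwaway directory that holds only one of the two Darcy files.
with tempfile.TemporaryDirectory() as root:
    open(os.path.join(root, "piececonst_r421_N1024_smooth1.mat"), "w").close()
    print(missing_files("darcy", root))
```

An empty list means the directory matches what the corresponding script expects.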
52 |
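
For step 4, the experiment scripts build the network with `get_model(args).Model(...)`, so whatever `get_model` returns must expose a `Model` class that accepts the keyword arguments the scripts pass (`space_dim`, `n_layers`, `n_hidden`, `dropout`, `n_head`, `Time_Input`, `mlp_ratio`, `fun_dim`, `out_dim`, `slice_num`, `ref`, `unified_pos`, and for structured meshes `H`, `W`). The sketch below shows that contract in plain Python. The registry, the name `My_Model`, and the placeholder class are all hypothetical; the real `model_dict.py` may be organised differently, and a real `Model` would subclass `torch.nn.Module` and implement `forward(x, fx)`.

```python
from types import SimpleNamespace


class Model:
    """Placeholder for a network class defined in your own model file."""

    def __init__(self, space_dim=2, n_layers=3, n_hidden=64, out_dim=1, **extra):
        self.space_dim = space_dim
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.out_dim = out_dim
        # Remaining keyword arguments from the scripts (slice_num, H, W, ...).
        self.extra = extra


# Stand-in for an imported module that defines `Model` (hypothetical).
my_model_module = SimpleNamespace(Model=Model)

# Hypothetical registry keyed by the value given to --model.
MODEL_REGISTRY = {"My_Model": my_model_module}


def get_model(args):
    """Return the module selected by args.model."""
    try:
        return MODEL_REGISTRY[args.model]
    except KeyError:
        raise ValueError(f"unknown model: {args.model!r}") from None


args = SimpleNamespace(model="My_Model")
net = get_model(args).Model(space_dim=2, n_layers=8, n_hidden=128, out_dim=1,
                            slice_num=32, H=221, W=51)
print(net.n_layers, sorted(net.extra))
```

Accepting `**extra` keeps the placeholder compatible with every script, since irregular-mesh scripts such as `exp_elas.py` omit `H` and `W`.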
53 | ## Visualization
54 |
55 | Transolver handles PDEs on various geometries well, such as predicting future fluid velocity and estimating the [[shock wave]](https://en.wikipedia.org/wiki/Shock_wave) around an airfoil.
56 |
57 |
58 |
59 |
60 | Figure 1. Case study of different models.
61 |
62 |
63 | ## PDE Solving at Scale
64 |
65 | To align with previous models, we only experiment with an 8-layer Transolver in the main text. In fact, you can easily obtain better performance by **scaling up Transolver**. The relative L2 generally decreases as more layers are added.
66 |
67 |
68 |
69 |
70 | Figure 2. Scaling up Transolver: relative L2 curve w.r.t. model layers.
71 |
72 |
73 | ## Citation
74 |
75 | If you find this repo useful, please cite our paper.
76 |
77 | ```
78 | @inproceedings{wu2024Transolver,
79 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries},
80 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long},
81 | booktitle={International Conference on Machine Learning},
82 | year={2024}
83 | }
84 | ```
85 |
86 | ## Contact
87 |
88 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn).
89 |
90 | ## Acknowledgement
91 |
92 | We greatly appreciate the following GitHub repos for their valuable code bases and datasets:
93 |
94 | https://github.com/neuraloperator/neuraloperator
95 |
96 | https://github.com/neuraloperator/Geo-FNO
97 |
98 | https://github.com/thuml/Latent-Spectral-Models
99 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/exp_airfoil.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import torch
6 | from tqdm import *
7 | from utils.testloss import TestLoss
8 | from model_dict import get_model
9 |
10 | parser = argparse.ArgumentParser('Training Transformer')
11 |
12 | parser.add_argument('--lr', type=float, default=1e-3)
13 | parser.add_argument('--epochs', type=int, default=500)
14 | parser.add_argument('--weight_decay', type=float, default=1e-5)
15 | parser.add_argument('--model', type=str, default='Transolver_2D')
16 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim')
17 | parser.add_argument('--n-layers', type=int, default=3, help='layers')
18 | parser.add_argument('--n-heads', type=int, default=4)
19 | parser.add_argument('--batch-size', type=int, default=8)
20 | parser.add_argument("--gpu", type=str, default='0', help="GPU index to use")
21 | parser.add_argument('--max_grad_norm', type=float, default=None)
22 | parser.add_argument('--downsamplex', type=int, default=1)
23 | parser.add_argument('--downsampley', type=int, default=1)
24 | parser.add_argument('--mlp_ratio', type=int, default=1)
25 | parser.add_argument('--dropout', type=float, default=0.0)
26 | parser.add_argument('--unified_pos', type=int, default=0)
27 | parser.add_argument('--ref', type=int, default=8)
28 | parser.add_argument('--slice_num', type=int, default=32)
29 | parser.add_argument('--eval', type=int, default=0)
30 | parser.add_argument('--save_name', type=str, default='airfoil_Transolver')
31 | parser.add_argument('--data_path', type=str, default='/data/fno/airfoil/naca')
32 | args = parser.parse_args()
33 | eval = args.eval
34 | save_name = args.save_name
35 |
36 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
37 |
38 |
39 | def count_parameters(model):
40 | total_params = 0
41 | for name, parameter in model.named_parameters():
42 | if not parameter.requires_grad: continue
43 | params = parameter.numel()
44 | total_params += params
45 | print(f"Total Trainable Params: {total_params}")
46 | return total_params
47 |
48 |
49 | def main():
50 | INPUT_X = args.data_path + '/NACA_Cylinder_X.npy'
51 | INPUT_Y = args.data_path + '/NACA_Cylinder_Y.npy'
52 | OUTPUT_Sigma = args.data_path + '/NACA_Cylinder_Q.npy'
53 |
54 | ntrain = 1000
55 | ntest = 200
56 |
57 | r1 = args.downsamplex
58 | r2 = args.downsampley
59 | s1 = int(((221 - 1) / r1) + 1)
60 | s2 = int(((51 - 1) / r2) + 1)
61 |
62 | inputX = np.load(INPUT_X)
63 | inputX = torch.tensor(inputX, dtype=torch.float)
64 | inputY = np.load(INPUT_Y)
65 | inputY = torch.tensor(inputY, dtype=torch.float)
66 | input = torch.stack([inputX, inputY], dim=-1)
67 |
68 | output = np.load(OUTPUT_Sigma)[:, 4]
69 | output = torch.tensor(output, dtype=torch.float)
70 | print(input.shape, output.shape)
71 |
72 | x_train = input[:ntrain, ::r1, ::r2][:, :s1, :s2]
73 | y_train = output[:ntrain, ::r1, ::r2][:, :s1, :s2]
74 | x_test = input[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2]
75 | y_test = output[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2]
76 | x_train = x_train.reshape(ntrain, -1, 2)
77 | x_test = x_test.reshape(ntest, -1, 2)
78 | y_train = y_train.reshape(ntrain, -1)
79 | y_test = y_test.reshape(ntest, -1)
80 |
81 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, x_train, y_train),
82 | batch_size=args.batch_size,
83 | shuffle=True)
84 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, x_test, y_test),
85 | batch_size=args.batch_size,
86 | shuffle=False)
87 |
88 | print("Dataloading is over.")
89 |
90 | model = get_model(args).Model(space_dim=2,
91 | n_layers=args.n_layers,
92 | n_hidden=args.n_hidden,
93 | dropout=args.dropout,
94 | n_head=args.n_heads,
95 | Time_Input=False,
96 | mlp_ratio=args.mlp_ratio,
97 | fun_dim=0,
98 | out_dim=1,
99 | slice_num=args.slice_num,
100 | ref=args.ref,
101 | unified_pos=args.unified_pos,
102 | H=s1, W=s2).cuda()
103 |
104 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
105 | print(args)
106 | print(model)
107 | count_parameters(model)
108 |
109 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs,
110 | steps_per_epoch=len(train_loader))
111 | myloss = TestLoss(size_average=False)
112 |
113 | if eval:
114 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt"))
115 | model.eval()
116 | if not os.path.exists('./results/' + save_name + '/'):
117 | os.makedirs('./results/' + save_name + '/')
118 |
119 | rel_err = 0.0
120 | showcase = 10
121 | id = 0
122 |
123 | with torch.no_grad():
124 | for pos, fx, y in test_loader:
125 | id += 1
126 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda()
127 | out = model(x, None).squeeze(-1)
128 |
129 | tl = myloss(out, y).item()
130 | rel_err += tl
131 | if id < showcase:
132 | print(id)
133 | plt.axis('off')
134 | plt.pcolormesh(x[0, :, 0].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
135 | x[0, :, 1].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
136 | np.zeros([140, 35]),
137 | shading='auto',
138 | edgecolors='black', linewidths=0.1)
139 | plt.colorbar()
140 | plt.savefig(
141 | os.path.join('./results/' + save_name + '/',
142 | "input_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
143 | plt.close()
144 | plt.axis('off')
145 | plt.pcolormesh(x[0, :, 0].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
146 | x[0, :, 1].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
147 | out[0, :].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
148 | shading='auto', cmap='coolwarm')
149 | plt.colorbar()
150 | plt.clim(0, 1.2)
151 | plt.savefig(
152 | os.path.join('./results/' + save_name + '/',
153 | "pred_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
154 | plt.close()
155 | plt.axis('off')
156 | plt.pcolormesh(x[0, :, 0].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
157 | x[0, :, 1].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
158 | y[0, :].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
159 | shading='auto', cmap='coolwarm')
160 | plt.colorbar()
161 | plt.clim(0, 1.2)
162 | plt.savefig(
163 | os.path.join('./results/' + save_name + '/',
164 | "gt_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
165 | plt.close()
166 | plt.axis('off')
167 | plt.pcolormesh(x[0, :, 0].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
168 | x[0, :, 1].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
169 | out[0, :].reshape(221, 51)[40:180, :35].detach().cpu().numpy() - \
170 | y[0, :].reshape(221, 51)[40:180, :35].detach().cpu().numpy(),
171 | shading='auto', cmap='coolwarm')
172 | plt.colorbar()
173 | plt.clim(-0.2, 0.2)
174 | plt.savefig(
175 | os.path.join('./results/' + save_name + '/',
176 | "error_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
177 | plt.close()
178 |
179 | rel_err /= ntest
180 | print("rel_err:{}".format(rel_err))
181 | else:
182 | for ep in range(args.epochs):
183 |
184 | model.train()
185 | train_loss = 0
186 |
187 | for pos, fx, y in train_loader:
188 |
189 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() # x:B,N,2 fx:B,N,2 y:B,N
190 | optimizer.zero_grad()
191 | out = model(x, None).squeeze(-1)
192 | loss = myloss(out, y)
193 | loss.backward()
194 |
195 | if args.max_grad_norm is not None:
196 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
197 | optimizer.step()
198 | train_loss += loss.item()
199 | scheduler.step()
200 |
201 | train_loss = train_loss / ntrain
202 | print("Epoch {} Train loss : {:.5f}".format(ep, train_loss))
203 |
204 | model.eval()
205 | rel_err = 0.0
206 | with torch.no_grad():
207 | for pos, fx, y in test_loader:
208 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda()
209 | out = model(x, None).squeeze(-1)
210 |
211 | tl = myloss(out, y).item()
212 | rel_err += tl
213 |
214 | rel_err /= ntest
215 | print("rel_err:{}".format(rel_err))
216 |
217 | if ep % 100 == 0:
218 | if not os.path.exists('./checkpoints'):
219 | os.makedirs('./checkpoints')
220 | print('save model')
221 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
222 |
223 | if not os.path.exists('./checkpoints'):
224 | os.makedirs('./checkpoints')
225 | print('save model')
226 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
227 |
228 |
229 | if __name__ == "__main__":
230 | main()
231 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/exp_darcy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import scipy.io as scio
5 | import torch
6 | import torch.nn.functional as F
7 | from tqdm import *
8 | from utils.testloss import TestLoss
9 | from einops import rearrange
10 | from model_dict import get_model
11 | from utils.normalizer import UnitTransformer
12 | import matplotlib.pyplot as plt
13 |
14 | parser = argparse.ArgumentParser('Training Transolver')
15 |
16 | parser.add_argument('--lr', type=float, default=1e-3)
17 | parser.add_argument('--epochs', type=int, default=500)
18 | parser.add_argument('--weight_decay', type=float, default=1e-5)
19 | parser.add_argument('--model', type=str, default='Transolver_2D')
20 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim')
21 | parser.add_argument('--n-layers', type=int, default=3, help='layers')
22 | parser.add_argument('--n-heads', type=int, default=4)
23 | parser.add_argument('--batch-size', type=int, default=8)
24 | parser.add_argument("--gpu", type=str, default='1', help="GPU index to use")
25 | parser.add_argument('--max_grad_norm', type=float, default=None)
26 | parser.add_argument('--downsample', type=int, default=5)
27 | parser.add_argument('--mlp_ratio', type=int, default=1)
28 | parser.add_argument('--dropout', type=float, default=0.0)
29 | parser.add_argument('--ntrain', type=int, default=1000)
30 | parser.add_argument('--unified_pos', type=int, default=0)
31 | parser.add_argument('--ref', type=int, default=8)
32 | parser.add_argument('--slice_num', type=int, default=32)
33 | parser.add_argument('--eval', type=int, default=0)
34 | parser.add_argument('--save_name', type=str, default='darcy_Transolver')
35 | parser.add_argument('--data_path', type=str, default='/data/fno')
36 | args = parser.parse_args()
37 |
38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
39 |
40 | train_path = args.data_path + '/piececonst_r421_N1024_smooth1.mat'
41 | test_path = args.data_path + '/piececonst_r421_N1024_smooth2.mat'
42 | ntrain = args.ntrain
43 | ntest = 200
44 | epochs = args.epochs
45 | eval = args.eval
46 | save_name = args.save_name
47 |
48 |
49 | def count_parameters(model):
50 | total_params = 0
51 | for name, parameter in model.named_parameters():
52 | if not parameter.requires_grad: continue
53 | params = parameter.numel()
54 | total_params += params
55 | print(f"Total Trainable Params: {total_params}")
56 | return total_params
57 |
58 |
59 | def central_diff(x: torch.Tensor, h, resolution):
60 |     # The boundary is zero-padded (this is not a periodic boundary condition)
61 |     # x: (batch, n, feats) with n = resolution * resolution; h is the grid step size
62 |     x = rearrange(x, 'b (h w) c -> b h w c', h=resolution, w=resolution)
63 |     x = F.pad(x,
64 |               (0, 0, 1, 1, 1, 1), mode='constant', value=0.)  # [b, h+2, w+2, c]
65 |     grad_x = (x[:, 1:-1, 2:, :] - x[:, 1:-1, :-2, :]) / (2 * h)  # (f(x+h) - f(x-h)) / (2h)
66 |     grad_y = (x[:, 2:, 1:-1, :] - x[:, :-2, 1:-1, :]) / (2 * h)  # (f(y+h) - f(y-h)) / (2h)
67 |
68 |     return grad_x, grad_y
69 |
70 |
71 | def main():
72 | r = args.downsample
73 | h = int(((421 - 1) / r) + 1)
74 | s = h
75 | dx = 1.0 / s
76 |
77 | train_data = scio.loadmat(train_path)
78 | x_train = train_data['coeff'][:ntrain, ::r, ::r][:, :s, :s]
79 | x_train = x_train.reshape(ntrain, -1)
80 | x_train = torch.from_numpy(x_train).float()
81 | y_train = train_data['sol'][:ntrain, ::r, ::r][:, :s, :s]
82 | y_train = y_train.reshape(ntrain, -1)
83 | y_train = torch.from_numpy(y_train)
84 |
85 | test_data = scio.loadmat(test_path)
86 | x_test = test_data['coeff'][:ntest, ::r, ::r][:, :s, :s]
87 | x_test = x_test.reshape(ntest, -1)
88 | x_test = torch.from_numpy(x_test).float()
89 | y_test = test_data['sol'][:ntest, ::r, ::r][:, :s, :s]
90 | y_test = y_test.reshape(ntest, -1)
91 | y_test = torch.from_numpy(y_test)
92 |
93 | x_normalizer = UnitTransformer(x_train)
94 | y_normalizer = UnitTransformer(y_train)
95 |
96 | x_train = x_normalizer.encode(x_train)
97 | x_test = x_normalizer.encode(x_test)
98 | y_train = y_normalizer.encode(y_train)
99 |
100 | x_normalizer.cuda()
101 | y_normalizer.cuda()
102 |
103 | x = np.linspace(0, 1, s)
104 | y = np.linspace(0, 1, s)
105 | x, y = np.meshgrid(x, y)
106 | pos = np.c_[x.ravel(), y.ravel()]
107 | pos = torch.tensor(pos, dtype=torch.float).unsqueeze(0)
108 |
109 | pos_train = pos.repeat(ntrain, 1, 1)
110 | pos_test = pos.repeat(ntest, 1, 1)
111 | print("Dataloading is over.")
112 |
113 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_train, x_train, y_train),
114 | batch_size=args.batch_size, shuffle=True)
115 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_test, x_test, y_test),
116 | batch_size=args.batch_size, shuffle=False)
117 |
118 | model = get_model(args).Model(space_dim=2,
119 | n_layers=args.n_layers,
120 | n_hidden=args.n_hidden,
121 | dropout=args.dropout,
122 | n_head=args.n_heads,
123 | Time_Input=False,
124 | mlp_ratio=args.mlp_ratio,
125 | fun_dim=1,
126 | out_dim=1,
127 | slice_num=args.slice_num,
128 | ref=args.ref,
129 | unified_pos=args.unified_pos,
130 | H=s, W=s).cuda()
131 |
132 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
133 |
134 | print(args)
135 | print(model)
136 | count_parameters(model)
137 |
138 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=epochs,
139 | steps_per_epoch=len(train_loader))
140 | myloss = TestLoss(size_average=False)
141 | de_x = TestLoss(size_average=False)
142 | de_y = TestLoss(size_average=False)
143 |
144 | if eval:
145 | print("model evaluation")
146 | print(s, s)
147 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt"), strict=False)
148 | model.eval()
149 | showcase = 10
150 | id = 0
151 | if not os.path.exists('./results/' + save_name + '/'):
152 | os.makedirs('./results/' + save_name + '/')
153 |
154 | with torch.no_grad():
155 | rel_err = 0.0
156 | with torch.no_grad():
157 | for x, fx, y in test_loader:
158 | id += 1
159 | x, fx, y = x.cuda(), fx.cuda(), y.cuda()
160 | out = model(x, fx=fx.unsqueeze(-1)).squeeze(-1)
161 | out = y_normalizer.decode(out)
162 | tl = myloss(out, y).item()
163 |
164 | rel_err += tl
165 |
166 | if id < showcase:
167 | print(id)
168 | plt.figure()
169 | plt.axis('off')
170 | plt.imshow(out[0, :].reshape(85, 85).detach().cpu().numpy(), cmap='coolwarm')
171 | plt.colorbar()
172 | plt.savefig(
173 | os.path.join('./results/' + save_name + '/',
174 | "case_" + str(id) + "_pred.pdf"))
175 | plt.close()
176 | # ============ #
177 | plt.figure()
178 | plt.axis('off')
179 | plt.imshow(y[0, :].reshape(85, 85).detach().cpu().numpy(), cmap='coolwarm')
180 | plt.colorbar()
181 | plt.savefig(
182 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_gt.pdf"))
183 | plt.close()
184 | # ============ #
185 | plt.figure()
186 | plt.axis('off')
187 | plt.imshow((y[0, :] - out[0, :]).reshape(85, 85).detach().cpu().numpy(), cmap='coolwarm')
188 | plt.colorbar()
189 | plt.clim(-0.0005, 0.0005)
190 | plt.savefig(
191 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_error.pdf"))
192 | plt.close()
193 | # ============ #
194 | plt.figure()
195 | plt.axis('off')
196 | plt.imshow((fx[0, :].unsqueeze(-1)).reshape(85, 85).detach().cpu().numpy(), cmap='coolwarm')
197 | plt.colorbar()
198 | plt.savefig(
199 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_input.pdf"))
200 | plt.close()
201 |
202 | rel_err /= ntest
203 | print("rel_err:{}".format(rel_err))
204 | else:
205 | for ep in range(args.epochs):
206 | model.train()
207 | train_loss = 0
208 | reg = 0
209 | for x, fx, y in train_loader:
210 | x, fx, y = x.cuda(), fx.cuda(), y.cuda()
211 | optimizer.zero_grad()
212 |
213 |                 out = model(x, fx=fx.unsqueeze(-1)).squeeze(-1)  # x: B, N, 2; fx: B, N; y: B, N
214 | out = y_normalizer.decode(out)
215 | y = y_normalizer.decode(y)
216 |
217 | l2loss = myloss(out, y)
218 |
219 | out = rearrange(out.unsqueeze(-1), 'b (h w) c -> b c h w', h=s)
220 | out = out[..., 1:-1, 1:-1].contiguous()
221 | out = F.pad(out, (1, 1, 1, 1), "constant", 0)
222 | out = rearrange(out, 'b c h w -> b (h w) c')
223 | gt_grad_x, gt_grad_y = central_diff(y.unsqueeze(-1), dx, s)
224 | pred_grad_x, pred_grad_y = central_diff(out, dx, s)
225 | deriv_loss = de_x(pred_grad_x, gt_grad_x) + de_y(pred_grad_y, gt_grad_y)
226 | loss = 0.1 * deriv_loss + l2loss
227 | loss.backward()
228 |
229 | if args.max_grad_norm is not None:
230 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
231 | optimizer.step()
232 | train_loss += l2loss.item()
233 | reg += deriv_loss.item()
234 | scheduler.step()
235 |
236 | train_loss /= ntrain
237 | reg /= ntrain
238 | print("Epoch {} Reg : {:.5f} Train loss : {:.5f}".format(ep, reg, train_loss))
239 |
240 | model.eval()
241 | rel_err = 0.0
242 | id = 0
243 | with torch.no_grad():
244 | for x, fx, y in test_loader:
245 | id += 1
246 | if id == 2:
247 | vis = True
248 | else:
249 | vis = False
250 | x, fx, y = x.cuda(), fx.cuda(), y.cuda()
251 | out = model(x, fx=fx.unsqueeze(-1)).squeeze(-1)
252 | out = y_normalizer.decode(out)
253 | tl = myloss(out, y).item()
254 | rel_err += tl
255 |
256 | rel_err /= ntest
257 | print("rel_err:{}".format(rel_err))
258 |
259 | if ep % 100 == 0:
260 | if not os.path.exists('./checkpoints'):
261 | os.makedirs('./checkpoints')
262 | print('save model')
263 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
264 |
265 | if not os.path.exists('./checkpoints'):
266 | os.makedirs('./checkpoints')
267 | print('save model')
268 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
269 |
270 |
271 | if __name__ == "__main__":
272 | main()
273 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/exp_elas.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import torch
6 | from tqdm import *
7 | from utils.testloss import TestLoss
8 | from model_dict import get_model
9 | from utils.normalizer import UnitTransformer
10 |
11 | parser = argparse.ArgumentParser('Training Transformer')
12 |
13 | parser.add_argument('--lr', type=float, default=1e-3)
14 | parser.add_argument('--epochs', type=int, default=500)
15 | parser.add_argument('--weight_decay', type=float, default=1e-5)
16 | parser.add_argument('--model', type=str, default='Transolver_1D')
17 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim')
18 | parser.add_argument('--n-layers', type=int, default=3, help='layers')
19 | parser.add_argument('--n-heads', type=int, default=4)
20 | parser.add_argument('--batch-size', type=int, default=8)
21 | parser.add_argument("--gpu", type=str, default='1', help="GPU index to use")
22 | parser.add_argument('--max_grad_norm', type=float, default=None)
23 | parser.add_argument('--downsample', type=int, default=5)
24 | parser.add_argument('--mlp_ratio', type=int, default=1)
25 | parser.add_argument('--dropout', type=float, default=0.0)
26 | parser.add_argument('--ntrain', type=int, default=1000)
27 | parser.add_argument('--unified_pos', type=int, default=0)
28 | parser.add_argument('--ref', type=int, default=8)
29 | parser.add_argument('--slice_num', type=int, default=32)
30 | parser.add_argument('--eval', type=int, default=0)
31 | parser.add_argument('--save_name', type=str, default='elas_Transolver')
32 | parser.add_argument('--data_path', type=str, default='/data/fno')
33 | args = parser.parse_args()
34 | eval = args.eval
35 | save_name = args.save_name
36 |
37 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
38 |
39 |
40 | def count_parameters(model):
41 | total_params = 0
42 | for name, parameter in model.named_parameters():
43 | if not parameter.requires_grad: continue
44 | params = parameter.numel()
45 | total_params += params
46 | print(f"Total Trainable Params: {total_params}")
47 | return total_params
48 |
49 |
50 | def main():
51 | ntrain = args.ntrain
52 | ntest = 200
53 |
54 | PATH_Sigma = args.data_path + '/elasticity/Meshes/Random_UnitCell_sigma_10.npy'
55 | PATH_XY = args.data_path + '/elasticity/Meshes/Random_UnitCell_XY_10.npy'
56 |
57 | input_s = np.load(PATH_Sigma)
58 | input_s = torch.tensor(input_s, dtype=torch.float).permute(1, 0)
59 | input_xy = np.load(PATH_XY)
60 | input_xy = torch.tensor(input_xy, dtype=torch.float).permute(2, 0, 1)
61 |
62 | train_s = input_s[:ntrain]
63 | test_s = input_s[-ntest:]
64 | train_xy = input_xy[:ntrain]
65 | test_xy = input_xy[-ntest:]
66 |
67 | print(input_s.shape, input_xy.shape)
68 |
69 | y_normalizer = UnitTransformer(train_s)
70 |
71 | train_s = y_normalizer.encode(train_s)
72 | y_normalizer.cuda()
73 |
74 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_xy, train_xy, train_s),
75 | batch_size=args.batch_size,
76 | shuffle=True)
77 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_xy, test_xy, test_s),
78 | batch_size=args.batch_size,
79 | shuffle=False)
80 |
81 | print("Dataloading is over.")
82 |
83 | model = get_model(args).Model(space_dim=2,
84 | n_layers=args.n_layers,
85 | n_hidden=args.n_hidden,
86 | dropout=args.dropout,
87 | n_head=args.n_heads,
88 | Time_Input=False,
89 | mlp_ratio=args.mlp_ratio,
90 | fun_dim=0,
91 | out_dim=1,
92 | slice_num=args.slice_num,
93 | ref=args.ref,
94 | unified_pos=args.unified_pos).cuda()
95 |
96 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
97 |
98 | print(args)
99 | print(model)
100 | count_parameters(model)
101 |
102 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs * len(train_loader))  # stepped once per batch
103 |
104 | myloss = TestLoss(size_average=False)
105 |
106 | if eval:
107 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt"))
108 | model.eval()
109 | if not os.path.exists('./results/' + save_name + '/'):
110 | os.makedirs('./results/' + save_name + '/')
111 | rel_err = 0.0
112 | showcase = 10
113 | id = 0
114 |
115 | with torch.no_grad():
116 | for pos, fx, y in test_loader:
117 | id += 1
118 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda()
119 | out = model(x, None).squeeze(-1)
120 | out = y_normalizer.decode(out)
121 | tl = myloss(out, y).item()
122 | rel_err += tl
123 | if id < showcase:
124 | print(id)
125 | plt.axis('off')
126 | plt.scatter(x=fx[0, :, 0].detach().cpu().numpy(), y=fx[0, :, 1].detach().cpu().numpy(),
127 | c=y[0, :].detach().cpu().numpy(), cmap='coolwarm')
128 | plt.colorbar()
129 | plt.clim(0, 1000)
130 | plt.savefig(
131 | os.path.join('./results/' + save_name + '/',
132 | "gt_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
133 | plt.close()
134 |
135 | plt.axis('off')
136 | plt.scatter(x=fx[0, :, 0].detach().cpu().numpy(), y=fx[0, :, 1].detach().cpu().numpy(),
137 | c=out[0, :].detach().cpu().numpy(), cmap='coolwarm')
138 | plt.colorbar()
139 | plt.clim(0, 1000)
140 | plt.savefig(
141 | os.path.join('./results/' + save_name + '/',
142 | "pred_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
143 | plt.close()
144 |
145 | plt.axis('off')
146 | plt.scatter(x=fx[0, :, 0].detach().cpu().numpy(), y=fx[0, :, 1].detach().cpu().numpy(),
147 | c=((y[0, :] - out[0, :])).detach().cpu().numpy(), cmap='coolwarm')
148 | plt.clim(-8, 8)
149 | plt.colorbar()
150 | plt.savefig(
151 | os.path.join('./results/' + save_name + '/',
152 | "error_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
153 | plt.close()
154 |
155 | rel_err /= ntest
156 | print("rel_err : {}".format(rel_err))
157 | else:
158 | for ep in range(args.epochs):
159 |
160 | model.train()
161 | train_loss = 0
162 |
163 | for pos, fx, y in train_loader:
164 |
165 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() # x:B,N,2 fx:B,N,2 y:B,N,
166 | optimizer.zero_grad()
167 | out = model(x, None).squeeze(-1)
168 | out = y_normalizer.decode(out)
169 | y = y_normalizer.decode(y)
170 | loss = myloss(out, y)
171 | loss.backward()
172 |
173 | if args.max_grad_norm is not None:
174 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
175 | optimizer.step()
176 | train_loss += loss.item()
177 | scheduler.step()
178 |
179 | train_loss = train_loss / ntrain
180 | print("Epoch {} Train loss : {:.5f}".format(ep, train_loss))
181 |
182 | model.eval()
183 | rel_err = 0.0
184 | with torch.no_grad():
185 | for pos, fx, y in test_loader:
186 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda()
187 | out = model(x, None).squeeze(-1)
188 | out = y_normalizer.decode(out)
189 | tl = myloss(out, y).item()
190 | rel_err += tl
191 |
192 | rel_err /= ntest
193 | print("rel_err : {}".format(rel_err))
194 |
195 | if ep % 100 == 0:
196 | if not os.path.exists('./checkpoints'):
197 | os.makedirs('./checkpoints')
198 | print('save model')
199 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
200 |
201 | if not os.path.exists('./checkpoints'):
202 | os.makedirs('./checkpoints')
203 | print('save model')
204 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
205 |
206 |
207 | if __name__ == "__main__":
208 | main()
209 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/exp_ns.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 | import argparse
4 | import scipy.io as scio
5 | import numpy as np
6 | import torch
7 | from tqdm import *
8 | from utils.testloss import TestLoss
9 | from model_dict import get_model
10 |
11 | parser = argparse.ArgumentParser('Training Transformer')
12 |
13 | parser.add_argument('--lr', type=float, default=1e-3)
14 | parser.add_argument('--epochs', type=int, default=500)
15 | parser.add_argument('--weight_decay', type=float, default=1e-5)
16 | parser.add_argument('--model', type=str, default='Transolver_2D')
17 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim')
18 | parser.add_argument('--n-layers', type=int, default=3, help='layers')
19 | parser.add_argument('--n-heads', type=int, default=4)
20 | parser.add_argument('--batch-size', type=int, default=8)
21 | parser.add_argument("--gpu", type=str, default='0', help="GPU index to use")
22 | parser.add_argument('--max_grad_norm', type=float, default=None)
23 | parser.add_argument('--downsample', type=int, default=1)
24 | parser.add_argument('--mlp_ratio', type=int, default=1)
25 | parser.add_argument('--dropout', type=float, default=0.0)
26 | parser.add_argument('--unified_pos', type=int, default=0)
27 | parser.add_argument('--ref', type=int, default=8)
28 | parser.add_argument('--slice_num', type=int, default=32)
29 | parser.add_argument('--eval', type=int, default=0)
30 | parser.add_argument('--save_name', type=str, default='ns_2d_UniPDE')
31 | parser.add_argument('--data_path', type=str, default='/data/fno')
32 | args = parser.parse_args()
33 |
34 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
35 |
36 | data_path = args.data_path + '/NavierStokes_V1e-5_N1200_T20/NavierStokes_V1e-5_N1200_T20.mat'
37 | # data_path = args.data_path + '/NavierStokes_V1e-5_N1200_T20.mat'
38 | ntrain = 1000
39 | ntest = 200
40 | T_in = 10
41 | T = 10
42 | step = 1
43 | eval = args.eval
44 | save_name = args.save_name
45 |
46 |
47 | def count_parameters(model):
48 | total_params = 0
49 | for name, parameter in model.named_parameters():
50 | if not parameter.requires_grad: continue
51 | params = parameter.numel()
52 | total_params += params
53 | print(f"Total Trainable Params: {total_params}")
54 | return total_params
55 |
56 |
57 | def main():
58 | r = args.downsample
59 | h = int(((64 - 1) / r) + 1)
60 |
61 | data = scio.loadmat(data_path)
62 | print(data['u'].shape)
63 | train_a = data['u'][:ntrain, ::r, ::r, :T_in][:, :h, :h, :]
64 | train_a = train_a.reshape(train_a.shape[0], -1, train_a.shape[-1])
65 | train_a = torch.from_numpy(train_a)
66 | train_u = data['u'][:ntrain, ::r, ::r, T_in:T + T_in][:, :h, :h, :]
67 | train_u = train_u.reshape(train_u.shape[0], -1, train_u.shape[-1])
68 | train_u = torch.from_numpy(train_u)
69 |
70 | test_a = data['u'][-ntest:, ::r, ::r, :T_in][:, :h, :h, :]
71 | test_a = test_a.reshape(test_a.shape[0], -1, test_a.shape[-1])
72 | test_a = torch.from_numpy(test_a)
73 | test_u = data['u'][-ntest:, ::r, ::r, T_in:T + T_in][:, :h, :h, :]
74 | test_u = test_u.reshape(test_u.shape[0], -1, test_u.shape[-1])
75 | test_u = torch.from_numpy(test_u)
76 |
77 | x = np.linspace(0, 1, h)
78 | y = np.linspace(0, 1, h)
79 | x, y = np.meshgrid(x, y)
80 | pos = np.c_[x.ravel(), y.ravel()]
81 | pos = torch.tensor(pos, dtype=torch.float).unsqueeze(0)
82 | pos_train = pos.repeat(ntrain, 1, 1)
83 | pos_test = pos.repeat(ntest, 1, 1)
84 |
85 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_train, train_a, train_u),
86 | batch_size=args.batch_size, shuffle=True)
87 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_test, test_a, test_u),
88 | batch_size=args.batch_size, shuffle=False)
89 |
90 | print("Dataloading is over.")
91 |
92 | model = get_model(args).Model(space_dim=2,
93 | n_layers=args.n_layers,
94 | n_hidden=args.n_hidden,
95 | dropout=args.dropout,
96 | n_head=args.n_heads,
97 | Time_Input=False,
98 | mlp_ratio=args.mlp_ratio,
99 | fun_dim=T_in,
100 | out_dim=1,
101 | slice_num=args.slice_num,
102 | ref=args.ref,
103 | unified_pos=args.unified_pos,
104 | H=h, W=h).cuda()
105 |
106 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
107 |
108 | print(args)
109 | print(model)
110 | count_parameters(model)
111 |
112 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs,
113 | steps_per_epoch=len(train_loader))
114 | myloss = TestLoss(size_average=False)
115 |
116 | if eval:
117 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt"), strict=False)
118 | model.eval()
119 | showcase = 10
120 | id = 0
121 |
122 | if not os.path.exists('./results/' + save_name + '/'):
123 | os.makedirs('./results/' + save_name + '/')
124 |
125 | test_l2_full = 0
126 | with torch.no_grad():
127 | for x, fx, yy in test_loader:
128 | id += 1
129 | x, fx, yy = x.cuda(), fx.cuda(), yy.cuda() # x: B, 4096, 2 fx: B, 4096, T_in yy: B, 4096, T
130 | bsz = x.shape[0]
131 | for t in range(0, T, step):
132 | im = model(x, fx=fx)
133 |
134 | fx = torch.cat((fx[..., step:], im), dim=-1)
135 | if t == 0:
136 | pred = im
137 | else:
138 | pred = torch.cat((pred, im), -1)
139 |
140 | if id < showcase:
141 | print(id)
142 | plt.figure()
143 | plt.axis('off')
144 | plt.imshow(im[0, :, 0].reshape(h, h).detach().cpu().numpy(), cmap='coolwarm')
145 | plt.colorbar()
146 | plt.clim(-3, 3)
147 | plt.savefig(
148 | os.path.join('./results/' + save_name + '/',
149 | "case_" + str(id) + "_pred_" + str(20) + ".pdf"))
150 | plt.close()
151 | # ============ #
152 | plt.figure()
153 | plt.axis('off')
154 | plt.imshow(yy[0, :, t].reshape(h, h).detach().cpu().numpy(), cmap='coolwarm')
155 | plt.colorbar()
156 | plt.clim(-3, 3)
157 | plt.savefig(
158 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_gt_" + str(20) + ".pdf"))
159 | plt.close()
160 | # ============ #
161 | plt.figure()
162 | plt.axis('off')
163 | plt.imshow((im[0, :, 0].reshape(h, h) - yy[0, :, t].reshape(h, h)).detach().cpu().numpy(),
164 | cmap='coolwarm')
165 | plt.colorbar()
166 | plt.clim(-2, 2)
167 | plt.savefig(
168 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_error_" + str(20) + ".pdf"))
169 | plt.close()
170 | test_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item()
171 | print(test_l2_full / ntest)
172 | else:
173 | for ep in range(args.epochs):
174 |
175 | model.train()
176 | train_l2_step = 0
177 | train_l2_full = 0
178 |
179 | for x, fx, yy in train_loader:
180 | loss = 0
181 | x, fx, yy = x.cuda(), fx.cuda(), yy.cuda() # x: B,4096,2 fx: B,4096,T y: B,4096,T
182 | bsz = x.shape[0]
183 |
184 | for t in range(0, T, step):
185 | y = yy[..., t:t + step]
186 | im = model(x, fx=fx) # B , 4096 , 1
187 | loss += myloss(im.reshape(bsz, -1), y.reshape(bsz, -1))
188 | if t == 0:
189 | pred = im
190 | else:
191 | pred = torch.cat((pred, im), -1)
192 | fx = torch.cat((fx[..., step:], y), dim=-1) # teacher forcing: slide the input window with the ground-truth frame
193 |
194 | train_l2_step += loss.item()
195 | train_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item()
196 | optimizer.zero_grad()
197 | loss.backward()
198 | if args.max_grad_norm is not None:
199 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
200 | optimizer.step()
201 | scheduler.step()
202 |
203 | test_l2_step = 0
204 | test_l2_full = 0
205 |
206 | model.eval()
207 |
208 | with torch.no_grad():
209 | for x, fx, yy in test_loader:
210 | loss = 0
211 | x, fx, yy = x.cuda(), fx.cuda(), yy.cuda() # x: B, 4096, 2 fx: B, 4096, T_in yy: B, 4096, T
212 | bsz = x.shape[0]
213 | for t in range(0, T, step):
214 | y = yy[..., t:t + step]
215 | im = model(x, fx=fx)
216 | loss += myloss(im.reshape(bsz, -1), y.reshape(bsz, -1))
217 | if t == 0:
218 | pred = im
219 | else:
220 | pred = torch.cat((pred, im), -1)
221 | fx = torch.cat((fx[..., step:], im), dim=-1)
222 |
223 | test_l2_step += loss.item()
224 | test_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item()
225 |
226 | print(
227 | "Epoch {} , train_step_loss:{:.5f} , train_full_loss:{:.5f} , test_step_loss:{:.5f} , test_full_loss:{:.5f}".format(
228 | ep, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step),
229 | test_l2_full / ntest))
230 |
231 | if ep % 100 == 0:
232 | if not os.path.exists('./checkpoints'):
233 | os.makedirs('./checkpoints')
234 | print('save model')
235 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
236 |
237 | if not os.path.exists('./checkpoints'):
238 | os.makedirs('./checkpoints')
239 | print('save model')
240 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
241 |
242 |
243 | if __name__ == "__main__":
244 | main()
245 |
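The Navier-Stokes script above predicts one frame at a time and slides a window of `T_in` input frames. During training the window advances with the ground-truth frame (teacher forcing); during evaluation it advances with the model's own prediction. The sketch below shows both modes in plain Python, with one scalar per frame instead of a field. `toy_model` and `rollout` are hypothetical names used only for illustration.

```python
def toy_model(window):
    """Hypothetical one-step predictor: the mean of the input window."""
    return sum(window) / len(window)


def rollout(window, targets, teacher_forcing):
    """Predict len(targets) frames one step at a time.

    window: the last T_in frames (scalars here, flattened fields in the script).
    targets: ground-truth future frames; fed back as inputs only if teacher_forcing.
    """
    window = list(window)
    preds = []
    for target in targets:
        pred = toy_model(window)
        preds.append(pred)
        next_frame = target if teacher_forcing else pred
        # Drop the oldest frame and append the newest one, mirroring
        # torch.cat((fx[..., step:], next_frame), dim=-1) with step = 1.
        window = window[1:] + [next_frame]
    return preds


history = [0.0, 2.0]
future = [4.0, 6.0]
train_preds = rollout(history, future, teacher_forcing=True)
eval_preds = rollout(history, future, teacher_forcing=False)
print(train_preds, eval_preds)
```

The two modes agree on the first step and diverge afterwards, which is why the script reports a per-step loss and a full-trajectory loss separately.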
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/exp_pipe.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import matplotlib.pyplot as plt
4 |
5 | parser = argparse.ArgumentParser('Training Transformer')
6 |
7 | parser.add_argument('--lr', type=float, default=1e-3)
8 | parser.add_argument('--epochs', type=int, default=500)
9 | parser.add_argument('--weight_decay', type=float, default=1e-5)
10 | parser.add_argument('--model', type=str, default='Transolver_2D')
11 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim')
12 | parser.add_argument('--n-layers', type=int, default=3, help='layers')
13 | parser.add_argument('--n-heads', type=int, default=4)
14 | parser.add_argument('--batch-size', type=int, default=8)
15 | parser.add_argument("--gpu", type=str, default='0', help="GPU index to use")
16 | parser.add_argument('--max_grad_norm', type=float, default=None)
17 | parser.add_argument('--downsamplex', type=int, default=1)
18 | parser.add_argument('--downsampley', type=int, default=1)
19 | parser.add_argument('--mlp_ratio', type=int, default=1)
20 | parser.add_argument('--dropout', type=float, default=0.0)
21 | parser.add_argument('--unified_pos', type=int, default=0)
22 | parser.add_argument('--ref', type=int, default=8)
23 | parser.add_argument('--slice_num', type=int, default=32)
24 | parser.add_argument('--eval', type=int, default=0)
25 | parser.add_argument('--save_name', type=str, default='pipe_UniPDE')
26 | parser.add_argument('--data_path', type=str, default='/data/fno/pipe')
27 | args = parser.parse_args()
28 | eval = args.eval
29 | save_name = args.save_name
30 |
31 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
32 |
33 | import numpy as np
34 | import torch
35 | from tqdm import *
36 | from utils.testloss import TestLoss
37 | from model_dict import get_model
38 | from utils.normalizer import UnitTransformer
39 |
40 |
41 | def count_parameters(model):
42 | total_params = 0
43 | for name, parameter in model.named_parameters():
44 | if not parameter.requires_grad: continue
45 | params = parameter.numel()
46 | total_params += params
47 | print(f"Total Trainable Params: {total_params}")
48 | return total_params
49 |
50 |
51 | def main():
52 | INPUT_X = args.data_path + '/Pipe_X.npy'
53 | INPUT_Y = args.data_path + '/Pipe_Y.npy'
54 | OUTPUT_Sigma = args.data_path + '/Pipe_Q.npy'
55 |
56 | ntrain = 1000
57 | ntest = 200
58 | N = 1200
59 |
60 | r1 = args.downsamplex
61 | r2 = args.downsampley
62 | s1 = int(((129 - 1) / r1) + 1)
63 | s2 = int(((129 - 1) / r2) + 1)
64 |
65 | inputX = np.load(INPUT_X)
66 | inputX = torch.tensor(inputX, dtype=torch.float)
67 | inputY = np.load(INPUT_Y)
68 | inputY = torch.tensor(inputY, dtype=torch.float)
69 | input = torch.stack([inputX, inputY], dim=-1)
70 |
71 | output = np.load(OUTPUT_Sigma)[:, 0]
72 | output = torch.tensor(output, dtype=torch.float)
73 | print(input.shape, output.shape)
74 | x_train = input[:N][:ntrain, ::r1, ::r2][:, :s1, :s2]
75 | y_train = output[:N][:ntrain, ::r1, ::r2][:, :s1, :s2]
76 | x_test = input[:N][-ntest:, ::r1, ::r2][:, :s1, :s2]
77 | y_test = output[:N][-ntest:, ::r1, ::r2][:, :s1, :s2]
78 | x_train = x_train.reshape(ntrain, -1, 2)
79 | x_test = x_test.reshape(ntest, -1, 2)
80 | y_train = y_train.reshape(ntrain, -1)
81 | y_test = y_test.reshape(ntest, -1)
82 |
83 | x_normalizer = UnitTransformer(x_train)
84 | y_normalizer = UnitTransformer(y_train)
85 |
86 | x_train = x_normalizer.encode(x_train)
87 | x_test = x_normalizer.encode(x_test)
88 | y_train = y_normalizer.encode(y_train)
89 |
90 | x_normalizer.cuda()
91 | y_normalizer.cuda()
92 |
93 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, x_train, y_train),
94 | batch_size=args.batch_size,
95 | shuffle=True)
96 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, x_test, y_test),
97 | batch_size=args.batch_size,
98 | shuffle=False)
99 |
100 | print("Dataloading is over.")
101 |
102 | model = get_model(args).Model(space_dim=2,
103 | n_layers=args.n_layers,
104 | n_hidden=args.n_hidden,
105 | dropout=args.dropout,
106 | n_head=args.n_heads,
107 | Time_Input=False,
108 | mlp_ratio=args.mlp_ratio,
109 | fun_dim=0,
110 | out_dim=1,
111 | slice_num=args.slice_num,
112 | ref=args.ref,
113 | unified_pos=args.unified_pos,
114 | H=s1, W=s2).cuda()
115 |
116 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
117 |
118 | print(args)
119 | print(model)
120 | count_parameters(model)
121 |
122 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
123 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs,
124 | steps_per_epoch=len(train_loader))
125 | myloss = TestLoss(size_average=False)
126 |
127 | if eval:
128 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt"), strict=False)
129 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '_resave' + '.pt'))
130 | model.eval()
131 | if not os.path.exists('./results/' + save_name + '/'):
132 | os.makedirs('./results/' + save_name + '/')
133 |
134 | rel_err = 0.0
135 | showcase = 10
136 | id = 0
137 |
138 | with torch.no_grad():
139 | for pos, fx, y in test_loader:
140 | id += 1
141 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda()
142 | out = model(x, None).squeeze(-1)
143 | out = y_normalizer.decode(out)
144 |
145 | tl = myloss(out, y).item()
146 | rel_err += tl
147 |
148 | if id < showcase:
149 | print(id)
150 | plt.axis('off')
151 | plt.pcolormesh(x[0, :, 0].reshape(s1, s2).detach().cpu().numpy(),
152 | x[0, :, 1].reshape(s1, s2).detach().cpu().numpy(),
153 | np.zeros([s1, s2]),
154 | shading='auto',
155 | edgecolors='black', linewidths=0.1)
156 | plt.colorbar()
157 | plt.savefig(
158 | os.path.join('./results/' + save_name + '/',
159 | "input_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
160 | plt.close()
161 | plt.axis('off')
162 | plt.pcolormesh(x[0, :, 0].reshape(s1, s2).detach().cpu().numpy(),
163 | x[0, :, 1].reshape(s1, s2).detach().cpu().numpy(),
164 | out[0, :].reshape(s1, s2).detach().cpu().numpy(),
165 | shading='auto', cmap='coolwarm')
166 | plt.colorbar()
167 | plt.clim(0, 0.3)
168 | plt.savefig(
169 | os.path.join('./results/' + save_name + '/',
170 | "pred_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
171 | plt.close()
172 | plt.axis('off')
173 | plt.pcolormesh(x[0, :, 0].reshape(s1, s2).detach().cpu().numpy(),
174 | x[0, :, 1].reshape(s1, s2).detach().cpu().numpy(),
175 | y[0, :].reshape(s1, s2).detach().cpu().numpy(),
176 | shading='auto', cmap='coolwarm')
177 | plt.colorbar()
178 | plt.clim(0, 0.3)
179 | plt.savefig(
180 | os.path.join('./results/' + save_name + '/',
181 | "gt_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
182 | plt.close()
183 | plt.axis('off')
184 | plt.pcolormesh(x[0, :, 0].reshape(s1, s2).detach().cpu().numpy(),
185 | x[0, :, 1].reshape(s1, s2).detach().cpu().numpy(),
186 | out[0, :].reshape(s1, s2).detach().cpu().numpy() - \
187 | y[0, :].reshape(s1, s2).detach().cpu().numpy(),
188 | shading='auto', cmap='coolwarm')
189 | plt.colorbar()
190 | plt.clim(-0.02, 0.02)
191 | plt.savefig(
192 | os.path.join('./results/' + save_name + '/',
193 | "error_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0)
194 | plt.close()
195 |
196 | rel_err /= ntest
197 | print("rel_err:{}".format(rel_err))
198 | else:
199 | for ep in range(args.epochs):
200 |
201 | model.train()
202 | train_loss = 0
203 |
204 | for pos, fx, y in train_loader:
205 |
206 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() # x:B,N,2 fx:B,N,2 y:B,N
207 | optimizer.zero_grad()
208 | out = model(x, None).squeeze(-1)
209 |
210 | out = y_normalizer.decode(out)
211 | y = y_normalizer.decode(y)
212 |
213 | loss = myloss(out, y)
214 | loss.backward()
215 |
216 | # print("loss:{}".format(loss.item()/batch_size))
217 | if args.max_grad_norm is not None:
218 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
219 | optimizer.step()
220 | train_loss += loss.item()
221 | scheduler.step()
222 |
223 | train_loss = train_loss / ntrain
224 | print("Epoch {} Train loss : {:.5f}".format(ep, train_loss))
225 |
226 | model.eval()
227 | rel_err = 0.0
228 | with torch.no_grad():
229 | for pos, fx, y in test_loader:
230 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda()
231 | out = model(x, None).squeeze(-1)
232 | out = y_normalizer.decode(out)
233 |
234 | tl = myloss(out, y).item()
235 | rel_err += tl
236 |
237 | rel_err /= ntest
238 | print("rel_err:{}".format(rel_err))
239 |
240 | if ep % 100 == 0:
241 | if not os.path.exists('./checkpoints'):
242 | os.makedirs('./checkpoints')
243 | print('save model')
244 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
245 |
246 | if not os.path.exists('./checkpoints'):
247 | os.makedirs('./checkpoints')
248 | print('save model')
249 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))
250 |
251 |
252 | if __name__ == "__main__":
253 | main()
254 |
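The pipe script computes the downsampled grid size as `int(((129 - 1) / r) + 1)`, which should equal the number of points a strided slice `[::r]` keeps on an axis of 129 points. The short check below confirms this for a few strides; `downsampled_size` is a hypothetical helper that only wraps the expression from the script.

```python
def downsampled_size(n_points, stride):
    """Grid size as computed in the script: int((n - 1) / stride + 1)."""
    return int(((n_points - 1) / stride) + 1)


for stride in (1, 2, 3, 4, 5, 8):
    # len(range(0, n, stride)) is the number of elements kept by [::stride].
    kept = len(range(0, 129, stride))
    assert downsampled_size(129, stride) == kept

sizes = {stride: downsampled_size(129, stride) for stride in (1, 2, 4)}
print(sizes)
```

Since the formula already matches the strided slice, the trailing `[:, :s1, :s2]` crop in the script leaves the tensors unchanged for these strides. The same `s1` and `s2` are passed to the model as `H` and `W`.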
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/fig/scalibility.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/PDE-Solving-StandardBenchmark/fig/scalibility.png
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/fig/showcase.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/PDE-Solving-StandardBenchmark/fig/showcase.png
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/fig/standard_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/PDE-Solving-StandardBenchmark/fig/standard_benchmark.png
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/model/Embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from einops import rearrange
5 |
6 |
7 | class RotaryEmbedding(nn.Module):
8 | def __init__(self, dim, min_freq=1 / 2, scale=1.):
9 | super().__init__()
10 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
11 | self.min_freq = min_freq
12 | self.scale = scale
13 | self.register_buffer('inv_freq', inv_freq)
14 |
15 | def forward(self, coordinates, device):
16 | # coordinates [b, n]
17 | t = coordinates.to(device).type_as(self.inv_freq)
18 | t = t * (self.scale / self.min_freq)
19 | freqs = torch.einsum('... i , j -> ... i j', t, self.inv_freq) # [b, n, d//2]
20 | return torch.cat((freqs, freqs), dim=-1) # [b, n, d]
21 |
22 |
23 | def rotate_half(x):
24 | x = rearrange(x, '... (j d) -> ... j d', j=2)
25 | x1, x2 = x.unbind(dim=-2)
26 | return torch.cat((-x2, x1), dim=-1)
27 |
28 |
29 | def apply_rotary_pos_emb(t, freqs):
30 | return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
31 |
32 |
33 | def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y):
34 | # split t into first half and second half
35 | # t: [b, h, n, d]
36 | # freq_x/y: [b, n, d]
37 | d = t.shape[-1]
38 | t_x, t_y = t[..., :d // 2], t[..., d // 2:]
39 |
40 | return torch.cat((apply_rotary_pos_emb(t_x, freqs_x),
41 | apply_rotary_pos_emb(t_y, freqs_y)), dim=-1)
42 |
43 |
44 | class PositionalEncoding(nn.Module):
45 | "Implement the PE function."
46 |
47 | def __init__(self, d_model, dropout, max_len=421 * 421):
48 | super(PositionalEncoding, self).__init__()
49 | self.dropout = nn.Dropout(p=dropout)
50 |
51 | # Compute the positional encodings once in log space.
52 | pe = torch.zeros(max_len, d_model)
53 | position = torch.arange(0, max_len).unsqueeze(1)
54 | div_term = torch.exp(
55 | torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
56 | )
57 | pe[:, 0::2] = torch.sin(position * div_term)
58 | pe[:, 1::2] = torch.cos(position * div_term)
59 | pe = pe.unsqueeze(0)
60 | self.register_buffer("pe", pe)
61 |
62 | def forward(self, x):
63 | x = x + self.pe[:, : x.size(1)].requires_grad_(False)
64 | return self.dropout(x)
65 |
66 |
67 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
68 | """
69 | Create sinusoidal timestep embeddings.
70 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
71 | These may be fractional.
72 | :param dim: the dimension of the output.
73 | :param max_period: controls the minimum frequency of the embeddings.
74 | :return: an [N x dim] Tensor of positional embeddings.
75 | """
76 |
77 | half = dim // 2
78 | freqs = torch.exp(
79 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
80 | ).to(device=timesteps.device)
81 | args = timesteps[:, None].float() * freqs[None]
82 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
83 | if dim % 2:
84 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
85 | return embedding
86 |
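To make the layout of `timestep_embedding` concrete, here is a plain-Python sketch of the same formula for a single scalar timestep. It is a simplified illustration rather than a replacement: there is no batching and no tensors, and the name `timestep_embedding_py` is hypothetical.

```python
import math


def timestep_embedding_py(timestep, dim, max_period=10000):
    """Sinusoidal embedding of one timestep: all cosines first, then all sines."""
    half = dim // 2
    # Frequencies decay geometrically from 1 towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [timestep * f for f in freqs]
    embedding = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        embedding.append(0.0)  # zero-pad odd dimensions, as in the function above
    return embedding


emb = timestep_embedding_py(0.0, 5)
print(emb)
```

At `timestep = 0` every cosine is 1 and every sine is 0, so the padding column for odd `dim` is easy to spot at the end.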
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/model/Physics_Attention.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from einops import rearrange, repeat
4 |
5 |
6 | class Physics_Attention_Irregular_Mesh(nn.Module):
7 | ## for irregular meshes in 1D, 2D or 3D space
8 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64):
9 | super().__init__()
10 | inner_dim = dim_head * heads
11 | self.dim_head = dim_head
12 | self.heads = heads
13 | self.scale = dim_head ** -0.5
14 | self.softmax = nn.Softmax(dim=-1)
15 | self.dropout = nn.Dropout(dropout)
16 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
17 |
18 | self.in_project_x = nn.Linear(dim, inner_dim)
19 | self.in_project_fx = nn.Linear(dim, inner_dim)
20 | self.in_project_slice = nn.Linear(dim_head, slice_num)
21 | for l in [self.in_project_slice]:
22 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization
23 | self.to_q = nn.Linear(dim_head, dim_head, bias=False)
24 | self.to_k = nn.Linear(dim_head, dim_head, bias=False)
25 | self.to_v = nn.Linear(dim_head, dim_head, bias=False)
26 | self.to_out = nn.Sequential(
27 | nn.Linear(inner_dim, dim),
28 | nn.Dropout(dropout)
29 | )
30 |
31 | def forward(self, x):
32 | # B N C
33 | B, N, C = x.shape
34 |
35 | ### (1) Slice
36 | fx_mid = self.in_project_fx(x).reshape(B, N, self.heads, self.dim_head) \
37 | .permute(0, 2, 1, 3).contiguous() # B H N C
38 | x_mid = self.in_project_x(x).reshape(B, N, self.heads, self.dim_head) \
39 | .permute(0, 2, 1, 3).contiguous() # B H N C
40 | slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature) # B H N G
41 | slice_norm = slice_weights.sum(2) # B H G
42 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
43 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head))
44 |
45 | ### (2) Attention among slice tokens
46 | q_slice_token = self.to_q(slice_token)
47 | k_slice_token = self.to_k(slice_token)
48 | v_slice_token = self.to_v(slice_token)
49 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
50 | attn = self.softmax(dots)
51 | attn = self.dropout(attn)
52 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D
53 |
54 | ### (3) Deslice
55 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
56 | out_x = rearrange(out_x, 'b h n d -> b n (h d)')
57 | return self.to_out(out_x)
58 |
59 |
60 | class Physics_Attention_Structured_Mesh_2D(nn.Module):
61 | ## for structured mesh in 2D space
62 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64, H=101, W=31, kernel=3):
63 | super().__init__()
64 | inner_dim = dim_head * heads
65 | self.dim_head = dim_head
66 | self.heads = heads
67 | self.scale = dim_head ** -0.5
68 | self.softmax = nn.Softmax(dim=-1)
69 | self.dropout = nn.Dropout(dropout)
70 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
71 | self.H = H
72 | self.W = W
73 |
74 | self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2)
75 | self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2)
76 | self.in_project_slice = nn.Linear(dim_head, slice_num)
77 | for l in [self.in_project_slice]:
78 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization
79 | self.to_q = nn.Linear(dim_head, dim_head, bias=False)
80 | self.to_k = nn.Linear(dim_head, dim_head, bias=False)
81 | self.to_v = nn.Linear(dim_head, dim_head, bias=False)
82 |
83 | self.to_out = nn.Sequential(
84 | nn.Linear(inner_dim, dim),
85 | nn.Dropout(dropout)
86 | )
87 |
88 | def forward(self, x):
89 | # B N C
90 | B, N, C = x.shape
91 | x = x.reshape(B, self.H, self.W, C).contiguous().permute(0, 3, 1, 2).contiguous() # B C H W
92 |
93 | ### (1) Slice
94 | fx_mid = self.in_project_fx(x).permute(0, 2, 3, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \
95 | .permute(0, 2, 1, 3).contiguous() # B H N C
96 | x_mid = self.in_project_x(x).permute(0, 2, 3, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \
97 | .permute(0, 2, 1, 3).contiguous() # B H N C
98 | slice_weights = self.softmax(
99 | self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5)) # B H N G
100 | slice_norm = slice_weights.sum(2) # B H G
101 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
102 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head))
103 |
104 | ### (2) Attention among slice tokens
105 | q_slice_token = self.to_q(slice_token)
106 | k_slice_token = self.to_k(slice_token)
107 | v_slice_token = self.to_v(slice_token)
108 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
109 | attn = self.softmax(dots)
110 | attn = self.dropout(attn)
111 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D
112 |
113 | ### (3) Deslice
114 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
115 | out_x = rearrange(out_x, 'b h n d -> b n (h d)')
116 | return self.to_out(out_x)
117 |
118 |
119 | class Physics_Attention_Structured_Mesh_3D(nn.Module):
120 | ## for structured mesh in 3D space
121 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=32, H=32, W=32, D=32, kernel=3):
122 | super().__init__()
123 | inner_dim = dim_head * heads
124 | self.dim_head = dim_head
125 | self.heads = heads
126 | self.scale = dim_head ** -0.5
127 | self.softmax = nn.Softmax(dim=-1)
128 | self.dropout = nn.Dropout(dropout)
129 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
130 | self.H = H
131 | self.W = W
132 | self.D = D
133 |
134 | self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2)
135 | self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2)
136 | self.in_project_slice = nn.Linear(dim_head, slice_num)
137 | for l in [self.in_project_slice]:
138 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization
139 | self.to_q = nn.Linear(dim_head, dim_head, bias=False)
140 | self.to_k = nn.Linear(dim_head, dim_head, bias=False)
141 | self.to_v = nn.Linear(dim_head, dim_head, bias=False)
142 | self.to_out = nn.Sequential(
143 | nn.Linear(inner_dim, dim),
144 | nn.Dropout(dropout)
145 | )
146 |
147 | def forward(self, x):
148 | # B N C
149 | B, N, C = x.shape
150 | x = x.reshape(B, self.H, self.W, self.D, C).contiguous().permute(0, 4, 1, 2, 3).contiguous() # B C H W D
151 |
152 | ### (1) Slice
153 | fx_mid = self.in_project_fx(x).permute(0, 2, 3, 4, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \
154 | .permute(0, 2, 1, 3).contiguous() # B H N C
155 | x_mid = self.in_project_x(x).permute(0, 2, 3, 4, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \
156 | .permute(0, 2, 1, 3).contiguous() # B H N C
157 | slice_weights = self.softmax(
158 | self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5)) # B H N G
159 | slice_norm = slice_weights.sum(2) # B H G
160 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
161 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head))
162 |
163 | ### (2) Attention among slice tokens
164 | q_slice_token = self.to_q(slice_token)
165 | k_slice_token = self.to_k(slice_token)
166 | v_slice_token = self.to_v(slice_token)
167 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
168 | attn = self.softmax(dots)
169 | attn = self.dropout(attn)
170 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D
171 |
172 | ### (3) Deslice
173 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
174 | out_x = rearrange(out_x, 'b h n d -> b n (h d)')
175 | return self.to_out(out_x)
176 |
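The three attention classes above share the same slice, attend, deslice pattern. The sketch below walks through that pattern for a single head in plain Python, so the shapes and normalizations are visible without PyTorch. It makes several simplifying assumptions that are not in the source: the learned input projections are replaced by precomputed `slice_logits`, the `to_q`/`to_k`/`to_v` maps are taken as identity, and dropout and the output projection are omitted. All function names are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    """Plain-list matrix product of a (p x q) and b (q x r)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def physics_attention_single_head(features, slice_logits, temperature=0.5, eps=1e-5):
    """features: N x C point features; slice_logits: N x G per-point slice scores."""
    # (1) Slice: soft-assign each point to G slices, then average features per slice.
    weights = [softmax([v / temperature for v in row]) for row in slice_logits]  # N x G
    weights_t = transpose(weights)                                               # G x N
    norms = [sum(row) for row in weights_t]                                      # G
    tokens = matmul(weights_t, features)                                         # G x C
    tokens = [[v / (norm + eps) for v in row] for row, norm in zip(tokens, norms)]

    # (2) Attention among the G slice tokens (identity q/k/v: an assumption of this sketch).
    scale = len(features[0]) ** -0.5
    dots = matmul(tokens, transpose(tokens))                                     # G x G
    attn = [softmax([v * scale for v in row]) for row in dots]
    out_tokens = matmul(attn, tokens)                                            # G x C

    # (3) Deslice: broadcast slice tokens back to the points with the same weights.
    return matmul(weights, out_tokens)                                           # N x C


features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]                      # N=4, C=2
logits = [[2.0, 0.0, -1.0], [0.0, 2.0, 0.0], [1.0, 1.0, 0.0], [-1.0, 0.0, 2.0]]  # G=3
out = physics_attention_single_head(features, logits)
print(len(out), len(out[0]))
```

The softmax attention in step (2) runs over `G` slice tokens rather than `N` mesh points, so its cost grows with `G * G`, while the slice and deslice steps grow with `N * G`.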
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/model/Transolver_Irregular_Mesh.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import trunc_normal_
4 | from model.Embedding import timestep_embedding
5 | import numpy as np
6 | from model.Physics_Attention import Physics_Attention_Irregular_Mesh
7 |
8 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'leaky_relu': lambda: nn.LeakyReLU(0.1), # factory, so act() builds a module like the other entries
9 | 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU}
10 |
11 |
12 | class MLP(nn.Module):
13 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True):
14 | super(MLP, self).__init__()
15 |
16 | if act in ACTIVATION.keys():
17 | act = ACTIVATION[act]
18 | else:
19 | raise NotImplementedError
20 | self.n_input = n_input
21 | self.n_hidden = n_hidden
22 | self.n_output = n_output
23 | self.n_layers = n_layers
24 | self.res = res
25 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act())
26 | self.linear_post = nn.Linear(n_hidden, n_output)
27 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)])
28 |
29 | def forward(self, x):
30 | x = self.linear_pre(x)
31 | for i in range(self.n_layers):
32 | if self.res:
33 | x = self.linears[i](x) + x
34 | else:
35 | x = self.linears[i](x)
36 | x = self.linear_post(x)
37 | return x
38 |
39 |
40 | class Transolver_block(nn.Module):
41 | """Transformer encoder block."""
42 |
43 | def __init__(
44 | self,
45 | num_heads: int,
46 | hidden_dim: int,
47 | dropout: float,
48 | act='gelu',
49 | mlp_ratio=4,
50 | last_layer=False,
51 | out_dim=1,
52 | slice_num=32,
53 | ):
54 | super().__init__()
55 | self.last_layer = last_layer
56 | self.ln_1 = nn.LayerNorm(hidden_dim)
57 | self.Attn = Physics_Attention_Irregular_Mesh(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads,
58 | dropout=dropout, slice_num=slice_num)
59 | self.ln_2 = nn.LayerNorm(hidden_dim)
60 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
61 | if self.last_layer:
62 | self.ln_3 = nn.LayerNorm(hidden_dim)
63 | self.mlp2 = nn.Linear(hidden_dim, out_dim)
64 |
65 | def forward(self, fx):
66 | fx = self.Attn(self.ln_1(fx)) + fx
67 | fx = self.mlp(self.ln_2(fx)) + fx
68 | if self.last_layer:
69 | return self.mlp2(self.ln_3(fx))
70 | else:
71 | return fx
72 |
73 |
74 | class Model(nn.Module):
75 | def __init__(self,
76 | space_dim=1,
77 | n_layers=5,
78 | n_hidden=256,
79 | dropout=0.0,
80 | n_head=8,
81 | Time_Input=False,
82 | act='gelu',
83 | mlp_ratio=1,
84 | fun_dim=1,
85 | out_dim=1,
86 | slice_num=32,
87 | ref=8,
88 | unified_pos=False
89 | ):
90 | super(Model, self).__init__()
91 | self.__name__ = 'Transolver_1D'
92 | self.ref = ref
93 | self.unified_pos = unified_pos
94 | self.Time_Input = Time_Input
95 | self.n_hidden = n_hidden
96 | self.space_dim = space_dim
97 | if self.unified_pos:
98 | self.preprocess = MLP(fun_dim + self.ref * self.ref, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act)
99 | else:
100 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act)
101 | if Time_Input:
102 | self.time_fc = nn.Sequential(nn.Linear(n_hidden, n_hidden), nn.SiLU(), nn.Linear(n_hidden, n_hidden))
103 |
104 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden,
105 | dropout=dropout,
106 | act=act,
107 | mlp_ratio=mlp_ratio,
108 | out_dim=out_dim,
109 | slice_num=slice_num,
110 | last_layer=(_ == n_layers - 1))
111 | for _ in range(n_layers)])
112 | self.initialize_weights()
113 | self.placeholder = nn.Parameter((1 / (n_hidden)) * torch.rand(n_hidden, dtype=torch.float))
114 |
115 | def initialize_weights(self):
116 | self.apply(self._init_weights)
117 |
118 | def _init_weights(self, m):
119 | if isinstance(m, nn.Linear):
120 | trunc_normal_(m.weight, std=0.02)
121 | if isinstance(m, nn.Linear) and m.bias is not None:
122 | nn.init.constant_(m.bias, 0)
123 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
124 | nn.init.constant_(m.bias, 0)
125 | nn.init.constant_(m.weight, 1.0)
126 |
127 | def get_grid(self, x, batchsize=1):
128 | # x: B N 2
129 | # grid_ref
130 | gridx = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
131 | gridx = gridx.reshape(1, self.ref, 1, 1).repeat([batchsize, 1, self.ref, 1])
132 | gridy = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
133 | gridy = gridy.reshape(1, 1, self.ref, 1).repeat([batchsize, self.ref, 1, 1])
134 | grid_ref = torch.cat((gridx, gridy), dim=-1).to(x.device).reshape(batchsize, self.ref * self.ref, 2) # B (ref*ref) 2
135 |
136 | pos = torch.sqrt(torch.sum((x[:, :, None, :] - grid_ref[:, None, :, :]) ** 2, dim=-1)). \
137 | reshape(batchsize, x.shape[1], self.ref * self.ref).contiguous()
138 | return pos
139 |
140 | def forward(self, x, fx, T=None):
141 | if self.unified_pos:
142 | x = self.get_grid(x, x.shape[0])
143 | if fx is not None:
144 | fx = torch.cat((x, fx), -1)
145 | fx = self.preprocess(fx)
146 | else:
147 | fx = self.preprocess(x)
148 | fx = fx + self.placeholder[None, None, :]
149 |
150 | if T is not None:
151 | Time_emb = timestep_embedding(T, self.n_hidden).repeat(1, x.shape[1], 1)
152 | Time_emb = self.time_fc(Time_emb)
153 | fx = fx + Time_emb
154 |
155 | for block in self.blocks:
156 | fx = block(fx)
157 |
158 | return fx
159 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/model/Transolver_Structured_Mesh_2D.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from timm.models.layers import trunc_normal_
5 | from model.Embedding import timestep_embedding
6 | from model.Physics_Attention import Physics_Attention_Structured_Mesh_2D
7 |
8 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'leaky_relu': lambda: nn.LeakyReLU(0.1),
9 | 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU}
10 |
11 |
12 | class MLP(nn.Module):
13 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True):
14 | super(MLP, self).__init__()
15 |
16 | if act in ACTIVATION.keys():
17 | act = ACTIVATION[act]
18 | else:
19 | raise NotImplementedError
20 | self.n_input = n_input
21 | self.n_hidden = n_hidden
22 | self.n_output = n_output
23 | self.n_layers = n_layers
24 | self.res = res
25 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act())
26 | self.linear_post = nn.Linear(n_hidden, n_output)
27 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)])
28 |
29 | def forward(self, x):
30 | x = self.linear_pre(x)
31 | for i in range(self.n_layers):
32 | if self.res:
33 | x = self.linears[i](x) + x
34 | else:
35 | x = self.linears[i](x)
36 | x = self.linear_post(x)
37 | return x
38 |
39 |
40 | class Transolver_block(nn.Module):
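41 | """Transolver block: pre-norm Physics-Attention and MLP, each with a residual connection; the last block also projects to out_dim."""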
42 |
43 | def __init__(
44 | self,
45 | num_heads: int,
46 | hidden_dim: int,
47 | dropout: float,
48 | act='gelu',
49 | mlp_ratio=4,
50 | last_layer=False,
51 | out_dim=1,
52 | slice_num=32,
53 | H=85,
54 | W=85
55 | ):
56 | super().__init__()
57 | self.last_layer = last_layer
58 | self.ln_1 = nn.LayerNorm(hidden_dim)
59 | self.Attn = Physics_Attention_Structured_Mesh_2D(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads,
60 | dropout=dropout, slice_num=slice_num, H=H, W=W)
61 |
62 | self.ln_2 = nn.LayerNorm(hidden_dim)
63 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
64 | if self.last_layer:
65 | self.ln_3 = nn.LayerNorm(hidden_dim)
66 | self.mlp2 = nn.Linear(hidden_dim, out_dim)
67 |
68 | def forward(self, fx):
69 | fx = self.Attn(self.ln_1(fx)) + fx
70 | fx = self.mlp(self.ln_2(fx)) + fx
71 | if self.last_layer:
72 | return self.mlp2(self.ln_3(fx))
73 | else:
74 | return fx
75 |
76 |
77 | class Model(nn.Module):
78 | def __init__(self,
79 | space_dim=1,
80 | n_layers=5,
81 | n_hidden=256,
82 | dropout=0.0,
83 | n_head=8,
84 | Time_Input=False,
85 | act='gelu',
86 | mlp_ratio=1,
87 | fun_dim=1,
88 | out_dim=1,
89 | slice_num=32,
90 | ref=8,
91 | unified_pos=False,
92 | H=85,
93 | W=85,
94 | ):
95 | super(Model, self).__init__()
96 | self.__name__ = 'Transolver_2D'
97 | self.H = H
98 | self.W = W
99 | self.ref = ref
100 | self.unified_pos = unified_pos
101 | if self.unified_pos:
102 | self.pos = self.get_grid()
103 | self.preprocess = MLP(fun_dim + self.ref * self.ref, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act)
104 | else:
105 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act)
106 |
107 | self.Time_Input = Time_Input
108 | self.n_hidden = n_hidden
109 | self.space_dim = space_dim
110 | if Time_Input:
111 | self.time_fc = nn.Sequential(nn.Linear(n_hidden, n_hidden), nn.SiLU(), nn.Linear(n_hidden, n_hidden))
112 |
113 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden,
114 | dropout=dropout,
115 | act=act,
116 | mlp_ratio=mlp_ratio,
117 | out_dim=out_dim,
118 | slice_num=slice_num,
119 | H=H,
120 | W=W,
121 | last_layer=(_ == n_layers - 1))
122 | for _ in range(n_layers)])
123 | self.initialize_weights()
124 | self.placeholder = nn.Parameter((1 / (n_hidden)) * torch.rand(n_hidden, dtype=torch.float))
125 |
126 | def initialize_weights(self):
127 | self.apply(self._init_weights)
128 |
129 | def _init_weights(self, m):
130 | if isinstance(m, nn.Linear):
131 | trunc_normal_(m.weight, std=0.02)
132 | if isinstance(m, nn.Linear) and m.bias is not None:
133 | nn.init.constant_(m.bias, 0)
134 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
135 | nn.init.constant_(m.bias, 0)
136 | nn.init.constant_(m.weight, 1.0)
137 |
138 | def get_grid(self, batchsize=1):
139 | size_x, size_y = self.H, self.W
140 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
141 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
142 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
143 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
144 | grid = torch.cat((gridx, gridy), dim=-1).cuda() # B H W 2
145 |
146 | gridx = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
147 | gridx = gridx.reshape(1, self.ref, 1, 1).repeat([batchsize, 1, self.ref, 1])
148 | gridy = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
149 | gridy = gridy.reshape(1, 1, self.ref, 1).repeat([batchsize, self.ref, 1, 1])
150 | grid_ref = torch.cat((gridx, gridy), dim=-1).cuda() # B ref ref 2
151 |
152 | pos = torch.sqrt(torch.sum((grid[:, :, :, None, None, :] - grid_ref[:, None, None, :, :, :]) ** 2, dim=-1)). \
153 | reshape(batchsize, size_x, size_y, self.ref * self.ref).contiguous()
154 | return pos
155 |
156 | def forward(self, x, fx, T=None):
157 | if self.unified_pos:
158 | x = self.pos.repeat(x.shape[0], 1, 1, 1).reshape(x.shape[0], self.H * self.W, self.ref * self.ref)
159 | if fx is not None:
160 | fx = torch.cat((x, fx), -1)
161 | fx = self.preprocess(fx)
162 | else:
163 | fx = self.preprocess(x)
164 | fx = fx + self.placeholder[None, None, :]
165 |
166 | if T is not None:
167 | Time_emb = timestep_embedding(T, self.n_hidden).repeat(1, x.shape[1], 1)
168 | Time_emb = self.time_fc(Time_emb)
169 | fx = fx + Time_emb
170 |
171 | for block in self.blocks:
172 | fx = block(fx)
173 |
174 | return fx
175 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/model/Transolver_Structured_Mesh_3D.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from timm.models.layers import trunc_normal_
5 | from model.Embedding import timestep_embedding
6 | import torch.utils.checkpoint as checkpoint
7 | from model.Physics_Attention import Physics_Attention_Structured_Mesh_3D
8 |
9 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'leaky_relu': lambda: nn.LeakyReLU(0.1),
10 | 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU}
11 |
12 |
13 | class MLP(nn.Module):
14 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True):
15 | super(MLP, self).__init__()
16 |
17 | if act in ACTIVATION.keys():
18 | act = ACTIVATION[act]
19 | else:
20 | raise NotImplementedError
21 | self.n_input = n_input
22 | self.n_hidden = n_hidden
23 | self.n_output = n_output
24 | self.n_layers = n_layers
25 | self.res = res
26 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act())
27 | self.linear_post = nn.Linear(n_hidden, n_output)
28 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)])
29 |
30 | def forward(self, x):
31 | x = self.linear_pre(x)
32 | for i in range(self.n_layers):
33 | if self.res:
34 | x = self.linears[i](x) + x
35 | else:
36 | x = self.linears[i](x)
37 | x = self.linear_post(x)
38 | return x
39 |
40 |
41 | class Transolver_block(nn.Module):
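42 | """Transolver block: pre-norm Physics-Attention and MLP, each with a residual connection; the last block also projects to out_dim."""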
43 |
44 | def __init__(
45 | self,
46 | num_heads: int,
47 | hidden_dim: int,
48 | dropout: float,
49 | act='gelu',
50 | mlp_ratio=4,
51 | last_layer=False,
52 | out_dim=1,
53 | slice_num=32,
54 | H=32,
55 | W=32,
56 | D=32
57 | ):
58 | super().__init__()
59 | self.last_layer = last_layer
60 | self.ln_1 = nn.LayerNorm(hidden_dim)
61 | self.Attn = Physics_Attention_Structured_Mesh_3D(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads,
62 | dropout=dropout, slice_num=slice_num, H=H, W=W, D=D)
63 |
64 | self.ln_2 = nn.LayerNorm(hidden_dim)
65 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act)
66 | if self.last_layer:
67 | self.ln_3 = nn.LayerNorm(hidden_dim)
68 | self.mlp2 = nn.Linear(hidden_dim, out_dim)
69 |
70 | def forward(self, fx):
71 | fx = self.Attn(self.ln_1(fx)) + fx
72 | fx = self.mlp(self.ln_2(fx)) + fx
73 | if self.last_layer:
74 | return self.mlp2(self.ln_3(fx))
75 | else:
76 | return fx
77 |
78 |
79 | class Model(nn.Module):
80 | def __init__(self,
81 | space_dim=1,
82 | n_layers=5,
83 | n_hidden=256,
84 | dropout=0.0,
85 | n_head=8,
86 | Time_Input=False,
87 | act='gelu',
88 | mlp_ratio=1,
89 | fun_dim=1,
90 | out_dim=1,
91 | slice_num=32,
92 | ref=8,
93 | unified_pos=False,
94 | H=32,
95 | W=32,
96 | D=32,
97 | ):
98 | super(Model, self).__init__()
99 | self.__name__ = 'Transolver_3D'
100 | self.use_checkpoint = False
101 | self.H = H
102 | self.W = W
103 | self.D = D
104 | self.ref = ref
105 | self.unified_pos = unified_pos
106 | if self.unified_pos:
107 | self.pos = self.get_grid()
108 | self.preprocess = MLP(fun_dim + self.ref * self.ref * self.ref, n_hidden * 2, n_hidden, n_layers=0,
109 | res=False, act=act)
110 | else:
111 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act)
112 |
113 | self.Time_Input = Time_Input
114 | self.n_hidden = n_hidden
115 | self.space_dim = space_dim
116 | if Time_Input:
117 | self.time_fc = nn.Sequential(nn.Linear(n_hidden, n_hidden), nn.SiLU(), nn.Linear(n_hidden, n_hidden))
118 |
119 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden,
120 | dropout=dropout,
121 | act=act,
122 | mlp_ratio=mlp_ratio,
123 | out_dim=out_dim,
124 | slice_num=slice_num,
125 | H=H,
126 | W=W,
127 | D=D,
128 | last_layer=(_ == n_layers - 1))
129 | for _ in range(n_layers)])
130 | self.initialize_weights()
131 | self.placeholder = nn.Parameter((1 / (n_hidden)) * torch.rand(n_hidden, dtype=torch.float))
132 |
133 | def initialize_weights(self):
134 | self.apply(self._init_weights)
135 |
136 | def _init_weights(self, m):
137 | if isinstance(m, nn.Linear):
138 | trunc_normal_(m.weight, std=0.02)
139 | if isinstance(m, nn.Linear) and m.bias is not None:
140 | nn.init.constant_(m.bias, 0)
141 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
142 | nn.init.constant_(m.bias, 0)
143 | nn.init.constant_(m.weight, 1.0)
144 |
145 | def get_grid(self, batchsize=1):
146 | size_x, size_y, size_z = self.H, self.W, self.D
147 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
148 | gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1])
149 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
150 | gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1])
151 | gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float)
152 | gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1])
153 | grid = torch.cat((gridx, gridy, gridz), dim=-1).cuda() # B H W D 3
154 |
155 | gridx = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
156 | gridx = gridx.reshape(1, self.ref, 1, 1, 1).repeat([batchsize, 1, self.ref, self.ref, 1])
157 | gridy = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
158 | gridy = gridy.reshape(1, 1, self.ref, 1, 1).repeat([batchsize, self.ref, 1, self.ref, 1])
159 | gridz = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
160 | gridz = gridz.reshape(1, 1, 1, self.ref, 1).repeat([batchsize, self.ref, self.ref, 1, 1])
161 | grid_ref = torch.cat((gridx, gridy, gridz), dim=-1).cuda() # B ref ref ref 3
162 |
163 | pos = torch.sqrt(
164 | torch.sum((grid[:, :, :, :, None, None, None, :] - grid_ref[:, None, None, None, :, :, :, :]) ** 2,
165 | dim=-1)). \
166 | reshape(batchsize, size_x, size_y, size_z, self.ref * self.ref * self.ref).contiguous()
167 | return pos
168 |
169 | def forward(self, x, fx, T=None):
170 | if self.unified_pos:
171 | x = self.pos.repeat(x.shape[0], 1, 1, 1, 1).reshape(x.shape[0], self.H * self.W * self.D,
172 | self.ref * self.ref * self.ref)
173 | if fx is not None:
174 | fx = torch.cat((x, fx), -1)
175 | fx = self.preprocess(fx)
176 | else:
177 | fx = self.preprocess(x)
178 | fx = fx + self.placeholder[None, None, :]
179 |
180 | if T is not None:
181 | Time_emb = timestep_embedding(T, self.n_hidden).repeat(1, x.shape[1], 1)
182 | Time_emb = self.time_fc(Time_emb)
183 | fx = fx + Time_emb
184 |
185 | for block in self.blocks:
186 | if self.use_checkpoint:
187 | fx = checkpoint.checkpoint(block, fx)
188 | else:
189 | fx = block(fx)
190 |
191 | return fx
192 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/model_dict.py:
--------------------------------------------------------------------------------
1 | from model import Transolver_Irregular_Mesh, Transolver_Structured_Mesh_2D, Transolver_Structured_Mesh_3D
2 |
3 |
4 | def get_model(args):
5 | model_dict = {
6 | 'Transolver_Irregular_Mesh': Transolver_Irregular_Mesh, # for PDEs in 1D space or in unstructured meshes
7 | 'Transolver_Structured_Mesh_2D': Transolver_Structured_Mesh_2D,
8 | 'Transolver_Structured_Mesh_3D': Transolver_Structured_Mesh_3D,
9 | }
10 | return model_dict[args.model]
11 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.10.1
2 | h5py==3.8.0
3 | dgl==1.1.0
4 | einops==0.6.1
5 | scipy==1.7.3
6 | timm==0.9.2
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/scripts/Transolver_Airfoil.sh:
--------------------------------------------------------------------------------
1 | python exp_airfoil.py \
2 | --gpu 5 \
3 | --model Transolver_Structured_Mesh_2D \
4 | --n-hidden 128 \
5 | --n-heads 8 \
6 | --n-layers 8 \
7 | --lr 0.001 \
8 | --max_grad_norm 0.1 \
9 | --batch-size 4 \
10 | --slice_num 64 \
11 | --unified_pos 0 \
12 | --ref 8 \
13 | --eval 0 \
14 | --save_name airfoil_Transolver
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/scripts/Transolver_Darcy.sh:
--------------------------------------------------------------------------------
1 | python exp_darcy.py \
2 | --gpu 4 \
3 | --model Transolver_Structured_Mesh_2D \
4 | --n-hidden 128 \
5 | --n-heads 8 \
6 | --n-layers 8 \
7 | --lr 0.001 \
8 | --max_grad_norm 0.1 \
9 | --batch-size 4 \
10 | --slice_num 64 \
11 | --unified_pos 1 \
12 | --ref 8 \
13 | --eval 0 \
14 | --downsample 5 \
15 | --save_name darcy_UniPDE
16 |
17 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/scripts/Transolver_Elas.sh:
--------------------------------------------------------------------------------
1 | python exp_elas.py \
2 | --gpu 6 \
3 | --model Transolver_Irregular_Mesh \
4 | --n-hidden 128 \
5 | --n-heads 8 \
6 | --n-layers 8 \
7 | --lr 0.001 \
8 | --max_grad_norm 0.1 \
9 | --batch-size 1 \
10 | --slice_num 64 \
11 | --unified_pos 0 \
12 | --ref 8 \
13 | --eval 0 \
14 | --save_name elas_Transolver
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/scripts/Transolver_NS.sh:
--------------------------------------------------------------------------------
1 | python exp_ns.py \
2 | --gpu 2 \
3 | --model Transolver_Structured_Mesh_2D \
4 | --n-hidden 256 \
5 | --n-heads 8 \
6 | --n-layers 8 \
7 | --lr 0.001 \
8 | --batch-size 2 \
9 | --slice_num 32 \
10 | --unified_pos 1 \
11 | --ref 8 \
12 | --eval 0 \
13 | --save_name ns_Transolver
14 |
15 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/scripts/Transolver_Pipe.sh:
--------------------------------------------------------------------------------
1 | python exp_pipe.py \
2 | --gpu 7 \
3 | --model Transolver_Structured_Mesh_2D \
4 | --n-hidden 128 \
5 | --n-heads 8 \
6 | --n-layers 8 \
7 | --mlp_ratio 2 \
8 | --lr 0.001 \
9 | --max_grad_norm 0.1 \
10 | --batch-size 8 \
11 | --slice_num 64 \
12 | --unified_pos 0 \
13 | --ref 8 \
14 | --eval 0 \
15 | --save_name pipe_Transolver
16 |
17 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/scripts/Transolver_Plas.sh:
--------------------------------------------------------------------------------
1 | python exp_plas.py \
2 | --gpu 3 \
3 | --model Transolver_Structured_Mesh_2D \
4 | --n-hidden 128 \
5 | --n-heads 8 \
6 | --n-layers 8 \
7 | --lr 0.001 \
8 | --max_grad_norm 0.1 \
9 | --batch-size 8 \
10 | --slice_num 64 \
11 | --unified_pos 0 \
12 | --ref 8 \
13 | --eval 0 \
14 | --save_name plas_Transolver
15 |
16 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/utils/normalizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import *
3 |
4 |
5 | class IdentityTransformer():
6 | def __init__(self, X):
7 | self.mean = X.mean(dim=0, keepdim=True)
8 | self.std = X.std(dim=0, keepdim=True) + 1e-8
9 |
10 | def to(self, device):
11 | self.mean = self.mean.to(device)
12 | self.std = self.std.to(device)
13 | return self
14 |
15 | def cuda(self):
16 | self.mean = self.mean.cuda()
17 | self.std = self.std.cuda()
18 |
19 | def cpu(self):
20 | self.mean = self.mean.cpu()
21 | self.std = self.std.cpu()
22 |
23 | def encode(self, x):
24 | return x
25 |
26 | def decode(self, x):
27 | return x
28 |
29 |
30 | class UnitTransformer():
31 | def __init__(self, X):
32 | self.mean = X.mean(dim=(0, 1), keepdim=True)
33 | self.std = X.std(dim=(0, 1), keepdim=True) + 1e-8
34 |
35 | def to(self, device):
36 | self.mean = self.mean.to(device)
37 | self.std = self.std.to(device)
38 | return self
39 |
40 | def cuda(self):
41 | self.mean = self.mean.cuda()
42 | self.std = self.std.cuda()
43 |
44 | def cpu(self):
45 | self.mean = self.mean.cpu()
46 | self.std = self.std.cpu()
47 |
48 | def encode(self, x):
49 | x = (x - self.mean) / (self.std)
50 | return x
51 |
52 | def decode(self, x):
53 | return x * self.std + self.mean
54 |
55 | def transform(self, X, inverse=True, component='all'):
56 | if component in ('all', 'all-reduce'):
57 | if inverse:
58 | orig_shape = X.shape
59 | return (X * (self.std - 1e-8) + self.mean).view(orig_shape)
60 | else:
61 | return (X - self.mean) / self.std
62 | else:
63 | if inverse:
64 | orig_shape = X.shape
65 | return (X * (self.std[:, component] - 1e-8) + self.mean[:, component]).view(orig_shape)
66 | else:
67 | return (X - self.mean[:, component]) / self.std[:, component]
68 |
69 |
70 | class UnitGaussianNormalizer(object):
71 | def __init__(self, x, eps=0.00001, time_last=True):
72 | super(UnitGaussianNormalizer, self).__init__()
73 |
74 | self.mean = torch.mean(x, 0)
75 | self.std = torch.std(x, 0)
76 | self.eps = eps
77 | self.time_last = time_last # if the time dimension is the last dim
78 |
79 | def encode(self, x):
80 | x = (x - self.mean) / (self.std + self.eps)
81 | return x
82 |
83 | def decode(self, x, sample_idx=None):
84 | # sample_idx is the spatial sampling mask
85 | if sample_idx is None:
86 | std = self.std + self.eps # n
87 | mean = self.mean
88 | else:
89 | if self.mean.ndim == sample_idx.ndim or self.time_last:
90 | std = self.std[sample_idx] + self.eps # batch*n
91 | mean = self.mean[sample_idx]
92 | if self.mean.ndim > sample_idx.ndim and not self.time_last:
93 | std = self.std[..., sample_idx] + self.eps # T*batch*n
94 | mean = self.mean[..., sample_idx]
95 | # x is in shape of batch*(spatial discretization size) or T*batch*(spatial discretization size)
96 | x = (x * std) + mean
97 | return x
98 |
99 | def to(self, device):
100 | if torch.is_tensor(self.mean):
101 | self.mean = self.mean.to(device)
102 | self.std = self.std.to(device)
103 | else:
104 | self.mean = torch.from_numpy(self.mean).to(device)
105 | self.std = torch.from_numpy(self.std).to(device)
106 | return self
107 |
108 | def cuda(self):
109 | self.mean = self.mean.cuda()
110 | self.std = self.std.cuda()
111 |
112 | def cpu(self):
113 | self.mean = self.mean.cpu()
114 | self.std = self.std.cpu()
115 |
--------------------------------------------------------------------------------
/PDE-Solving-StandardBenchmark/utils/testloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TestLoss(object):
5 | def __init__(self, d=2, p=2, size_average=True, reduction=True):
6 | super(TestLoss, self).__init__()
7 |
8 | assert d > 0 and p > 0
9 |
10 | self.d = d
11 | self.p = p
12 | self.reduction = reduction
13 | self.size_average = size_average
14 |
15 | def abs(self, x, y):
16 | num_examples = x.size()[0]
17 |
18 | h = 1.0 / (x.size()[1] - 1.0)
19 |
20 | all_norms = (h ** (self.d / self.p)) * torch.norm(x.view(num_examples, -1) - y.view(num_examples, -1), self.p,
21 | 1)
22 |
23 | if self.reduction:
24 | if self.size_average:
25 | return torch.mean(all_norms)
26 | else:
27 | return torch.sum(all_norms)
28 |
29 | return all_norms
30 |
31 | def rel(self, x, y):
32 | num_examples = x.size()[0]
33 |
34 | diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1)
35 | y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1)
36 | if self.reduction:
37 | if self.size_average:
38 | return torch.mean(diff_norms / y_norms)
39 | else:
40 | return torch.sum(diff_norms / y_norms)
41 |
42 | return diff_norms / y_norms
43 |
44 | def __call__(self, x, y):
45 | return self.rel(x, y)
46 |
--------------------------------------------------------------------------------
/Physics_Attention.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from einops import rearrange, repeat
4 |
5 |
6 | class Physics_Attention_Irregular_Mesh(nn.Module):
7 | ## for irregular meshes in 1D, 2D or 3D space
8 |
9 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64):
10 | super().__init__()
11 | inner_dim = dim_head * heads
12 | self.dim_head = dim_head
13 | self.heads = heads
14 | self.scale = dim_head ** -0.5
15 | self.softmax = nn.Softmax(dim=-1)
16 | self.dropout = nn.Dropout(dropout)
17 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
18 |
19 | self.in_project_x = nn.Linear(dim, inner_dim)
20 | self.in_project_fx = nn.Linear(dim, inner_dim)
21 | self.in_project_slice = nn.Linear(dim_head, slice_num)
22 | for l in [self.in_project_slice]:
23 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization
24 | self.to_q = nn.Linear(dim_head, dim_head, bias=False)
25 | self.to_k = nn.Linear(dim_head, dim_head, bias=False)
26 | self.to_v = nn.Linear(dim_head, dim_head, bias=False)
27 | self.to_out = nn.Sequential(
28 | nn.Linear(inner_dim, dim),
29 | nn.Dropout(dropout)
30 | )
31 |
32 | def forward(self, x):
33 | # B N C
34 | B, N, C = x.shape
35 |
36 | ### (1) Slice
37 | fx_mid = self.in_project_fx(x).reshape(B, N, self.heads, self.dim_head) \
38 | .permute(0, 2, 1, 3).contiguous() # B H N C
39 | x_mid = self.in_project_x(x).reshape(B, N, self.heads, self.dim_head) \
40 | .permute(0, 2, 1, 3).contiguous() # B H N C
41 | slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature) # B H N G
42 | slice_norm = slice_weights.sum(2) # B H G
43 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
44 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head))
45 |
46 | ### (2) Attention among slice tokens
47 | q_slice_token = self.to_q(slice_token)
48 | k_slice_token = self.to_k(slice_token)
49 | v_slice_token = self.to_v(slice_token)
50 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
51 | attn = self.softmax(dots)
52 | attn = self.dropout(attn)
53 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D
54 |
55 | ### (3) Deslice
56 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
57 | out_x = rearrange(out_x, 'b h n d -> b n (h d)')
58 | return self.to_out(out_x)
59 |
60 |
61 | class Physics_Attention_Structured_Mesh_2D(nn.Module):
62 | ## for structured mesh in 2D space
63 |
64 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64, H=101, W=31, kernel=3):
65 | super().__init__()
66 | inner_dim = dim_head * heads
67 | self.dim_head = dim_head
68 | self.heads = heads
69 | self.scale = dim_head ** -0.5
70 | self.softmax = nn.Softmax(dim=-1)
71 | self.dropout = nn.Dropout(dropout)
72 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
73 | self.H = H
74 | self.W = W
75 |
76 | self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2)
77 | self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2)
78 | self.in_project_slice = nn.Linear(dim_head, slice_num)
79 | for l in [self.in_project_slice]:
80 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization
81 | self.to_q = nn.Linear(dim_head, dim_head, bias=False)
82 | self.to_k = nn.Linear(dim_head, dim_head, bias=False)
83 | self.to_v = nn.Linear(dim_head, dim_head, bias=False)
84 |
85 | self.to_out = nn.Sequential(
86 | nn.Linear(inner_dim, dim),
87 | nn.Dropout(dropout)
88 | )
89 |
90 | def forward(self, x):
91 | # B N C
92 | B, N, C = x.shape
93 | x = x.reshape(B, self.H, self.W, C).contiguous().permute(0, 3, 1, 2).contiguous() # B C H W
94 |
95 | ### (1) Slice
96 | fx_mid = self.in_project_fx(x).permute(0, 2, 3, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \
97 | .permute(0, 2, 1, 3).contiguous() # B H N C
98 | x_mid = self.in_project_x(x).permute(0, 2, 3, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \
99 | .permute(0, 2, 1, 3).contiguous() # B H N C
100 | slice_weights = self.softmax(
101 | self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5)) # B H N G
102 | slice_norm = slice_weights.sum(2) # B H G
103 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
104 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head))
105 |
106 | ### (2) Attention among slice tokens
107 | q_slice_token = self.to_q(slice_token)
108 | k_slice_token = self.to_k(slice_token)
109 | v_slice_token = self.to_v(slice_token)
110 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
111 | attn = self.softmax(dots)
112 | attn = self.dropout(attn)
113 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D
114 |
115 | ### (3) Deslice
116 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
117 | out_x = rearrange(out_x, 'b h n d -> b n (h d)')
118 | return self.to_out(out_x)
119 |
120 |
121 | class Physics_Attention_Structured_Mesh_3D(nn.Module):
122 | ## for structured mesh in 3D space
123 |
124 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=32, H=32, W=32, D=32, kernel=3):
125 | super().__init__()
126 | inner_dim = dim_head * heads
127 | self.dim_head = dim_head
128 | self.heads = heads
129 | self.scale = dim_head ** -0.5
130 | self.softmax = nn.Softmax(dim=-1)
131 | self.dropout = nn.Dropout(dropout)
132 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
133 | self.H = H
134 | self.W = W
135 | self.D = D
136 |
137 | self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2)
138 | self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2)
139 | self.in_project_slice = nn.Linear(dim_head, slice_num)
140 | for l in [self.in_project_slice]:
141 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization
142 | self.to_q = nn.Linear(dim_head, dim_head, bias=False)
143 | self.to_k = nn.Linear(dim_head, dim_head, bias=False)
144 | self.to_v = nn.Linear(dim_head, dim_head, bias=False)
145 | self.to_out = nn.Sequential(
146 | nn.Linear(inner_dim, dim),
147 | nn.Dropout(dropout)
148 | )
149 |
150 | def forward(self, x):
151 | # B N C
152 | B, N, C = x.shape
153 |         x = x.reshape(B, self.H, self.W, self.D, C).contiguous().permute(0, 4, 1, 2, 3).contiguous()  # B C H W D
154 |
155 | ### (1) Slice
156 | fx_mid = self.in_project_fx(x).permute(0, 2, 3, 4, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \
157 | .permute(0, 2, 1, 3).contiguous() # B H N C
158 | x_mid = self.in_project_x(x).permute(0, 2, 3, 4, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \
159 |             .permute(0, 2, 1, 3).contiguous()  # B H N C
160 | slice_weights = self.softmax(
161 | self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5)) # B H N G
162 | slice_norm = slice_weights.sum(2) # B H G
163 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
164 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head))
165 |
166 | ### (2) Attention among slice tokens
167 | q_slice_token = self.to_q(slice_token)
168 | k_slice_token = self.to_k(slice_token)
169 | v_slice_token = self.to_v(slice_token)
170 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
171 | attn = self.softmax(dots)
172 | attn = self.dropout(attn)
173 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D
174 |
175 | ### (3) Deslice
176 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
177 | out_x = rearrange(out_x, 'b h n d -> b n (h d)')
178 | return self.to_out(out_x)
179 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transolver (ICML 2024 Spotlight)
2 |
3 | :triangular_flag_on_post:**News** (2025.04) We have released [Neural-Solver-Library](https://github.com/thuml/Neural-Solver-Library) as a simple and neat code base for PDE solving. It contains 17 well-reproduced neural solvers. Feel free to try this library and join the research on solving PDEs.
4 |
5 | :triangular_flag_on_post:**News** (2025.02) We present an upgraded version of Transolver, named [Transolver++](https://arxiv.org/abs/2502.02414v1), which can handle million-scale geometries in one GPU with more accurate results.
6 |
7 | :triangular_flag_on_post:**News** (2024.10) Transolver has been integrated into [NVIDIA modulus](https://github.com/NVIDIA/modulus/tree/main/examples/cfd/darcy_transolver).
8 |
9 | Transolver: A Fast Transformer Solver for PDEs on General Geometries [[Paper]](https://arxiv.org/abs/2402.02366) [[Slides]](https://wuhaixu2016.github.io/pdf/ICML2024_Transolver.pdf) [[Poster]](https://wuhaixu2016.github.io/pdf/poster_ICML2024_Transolver.pdf)
10 |
11 | In real-world applications, PDEs are typically discretized into large-scale meshes with complex geometries. To capture intricate physical correlations hidden under multifarious meshes, we propose the Transolver with the following features:
12 |
13 | - Going beyond previous work, Transolver **calculates attention among learned physical states** instead of mesh points, which empowers the model with **endogenetic geometry-general capability**.
14 | - Transolver achieves **22% error reduction over previous SOTA in six standard benchmarks** and excels in **large-scale industrial simulations**, including car and airfoil designs.
15 | - Transolver presents favorable **efficiency, scalability and out-of-distribution generalizability**.
16 |
17 |
18 |
19 |
20 | Figure 1. Overview of Transolver.
21 |
22 |
23 |
24 | ## Transolver vs. Previous Transformer Operators
25 |
26 | **All of the previous Transformer-based neural operators directly apply attention to mesh points.** However, the massive number of mesh points in practical applications makes this approach computationally expensive and makes physical correlations hard to capture.
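
To make the computation-cost argument concrete, the sketch below counts approximate multiply-adds for one attention head. The slice-based count follows the three einsum/matmul steps in `Physics_Attention.py` and ignores the linear and convolutional projections. The function names and the 10,000-point mesh are illustrative assumptions, not part of the repo; `dim_head=64` and `slice_num=32` are the defaults in the code.

```python
def pointwise_attention_cost(n_points: int, dim_head: int) -> int:
    """Approximate multiply-adds for one head of attention over all mesh points."""
    # Q @ K^T is (N x C) @ (C x N); attn @ V is (N x N) @ (N x C).
    return 2 * n_points * n_points * dim_head


def slice_attention_cost(n_points: int, dim_head: int, slice_num: int) -> int:
    """Approximate multiply-adds for one head of slice -> attention -> deslice."""
    slice_cost = n_points * slice_num * dim_head      # einsum "bhnc,bhng->bhgc"
    attn_cost = 2 * slice_num * slice_num * dim_head  # attention among G slice tokens
    deslice_cost = n_points * slice_num * dim_head    # einsum "bhgc,bhng->bhnc"
    return slice_cost + attn_cost + deslice_cost


if __name__ == "__main__":
    n, c, g = 10_000, 64, 32  # hypothetical mesh size; default dim_head and slice_num
    dense = pointwise_attention_cost(n, c)
    sliced = slice_attention_cost(n, c, g)
    print(f"point-wise: {dense:,}  slice-based: {sliced:,}  ratio: {dense / sliced:.0f}x")
```

Under these assumptions the point-wise cost grows quadratically with the number of mesh points, while the slice-based cost grows linearly for a fixed number of slices.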
27 |
28 | Transolver is based on a more foundational idea: **learning intrinsic physical states under complex geometries**. This design frees our model from superficial and unwieldy meshes and lets it focus on physics modeling.
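
As a rough illustration of this idea, the following single-head sketch mirrors the slice, attention, and deslice steps of `Physics_Attention.py` in plain Python. It is a simplification under stated assumptions: the learned projections (`to_q`, `to_k`, `to_v`, `to_out`) are replaced by identity maps, the slice logits are passed in directly instead of coming from `in_project_slice`, and the function name `physics_attention` is hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def physics_attention(features, slice_logits, temperature=0.5, eps=1e-5):
    """Single-head slice -> attention -> deslice for N points with C channels.

    features:     N x C nested list of per-point features
    slice_logits: N x G nested list of per-point scores for G slices
    Learned projections are replaced by identity maps (illustration only).
    """
    n, c, g = len(features), len(features[0]), len(slice_logits[0])

    # (1) Slice: softly assign each point to G slices, then average the
    # point features per slice using the normalized slice weights.
    weights = [softmax([v / temperature for v in row]) for row in slice_logits]
    norm = [sum(weights[i][j] for i in range(n)) for j in range(g)]
    tokens = [
        [sum(weights[i][j] * features[i][k] for i in range(n)) / (norm[j] + eps)
         for k in range(c)]
        for j in range(g)
    ]

    # (2) Attention among the G slice tokens, not among the N mesh points.
    scale = c ** -0.5
    attn = [
        softmax([scale * sum(a * b for a, b in zip(q, key)) for key in tokens])
        for q in tokens
    ]
    out_tokens = [
        [sum(attn[j][m] * tokens[m][k] for m in range(g)) for k in range(c)]
        for j in range(g)
    ]

    # (3) Deslice: broadcast slice tokens back to points with the same weights.
    return [
        [sum(weights[i][j] * out_tokens[j][k] for j in range(g)) for k in range(c)]
        for i in range(n)
    ]
```

The attention matrix here has size G x G regardless of how many mesh points there are, which is the property the paragraph above describes.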
29 |
30 | As shown below, **Transolver can precisely capture miscellaneous physical states of PDEs**, such as (a) various fluid-structure interactions in a Darcy flow, (b) different extrusion regions of elastic materials, (c) shock wave and wake flow around the airfoil, (d) front-back surfaces and top-bottom spaces of driving cars.
31 |
32 |
33 |
34 |
35 | Figure 2. Visualization of learned physical states.
36 |
37 |
38 | ## Get Started
39 |
40 | 1. Please refer to different folders for detailed experiment instructions.
41 |
42 | 2. List of experiments:
43 |
44 | - Core code: see [./Physics_Attention.py](https://github.com/thuml/Transolver/blob/main/Physics_Attention.py)
45 | - Standard benchmarks: see [./PDE-Solving-StandardBenchmark](https://github.com/thuml/Transolver/tree/main/PDE-Solving-StandardBenchmark)
46 | - Car design task: see [./Car-Design-ShapeNetCar](https://github.com/thuml/Transolver/tree/main/Car-Design-ShapeNetCar)
47 | - Airfoil design task: see [./Airfoil-Design-AirfRANS](https://github.com/thuml/Transolver/tree/main/Airfoil-Design-AirfRANS)
48 |
49 | ## Results
50 |
51 | Transolver consistently achieves state-of-the-art results on **six standard benchmarks and two practical design tasks**. **More than 20 baselines are compared.**
52 |
53 |
54 |
55 |
56 | Table 1. Results on six standard benchmarks.
57 |
58 |
59 |
60 |
61 |
62 | Table 2. Results on two design tasks: Car and Airfoil design.
63 |
64 |
65 | ## Showcases
66 |
67 |
68 |
69 |
70 | Figure 3. Comparison of Transolver and other models.
71 |
72 |
73 | ## Citation
74 |
75 | If you find this repo useful, please cite our paper.
76 |
77 | ```
78 | @inproceedings{wu2024Transolver,
79 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries},
80 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long},
81 | booktitle={International Conference on Machine Learning},
82 | year={2024}
83 | }
84 | ```
85 |
86 | ## Contact
87 |
88 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn).
89 |
90 | ## Acknowledgement
91 |
92 | We greatly appreciate the following GitHub repos for their valuable code bases and datasets:
93 |
94 | https://github.com/neuraloperator/neuraloperator
95 |
96 | https://github.com/neuraloperator/Geo-FNO
97 |
98 | https://github.com/thuml/Latent-Spectral-Models
99 |
100 | https://github.com/Extrality/AirfRANS
101 |
--------------------------------------------------------------------------------
/pic/Transolver.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/pic/Transolver.png
--------------------------------------------------------------------------------
/pic/physical_states.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/pic/physical_states.png
--------------------------------------------------------------------------------
/pic/showcases.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/pic/showcases.png
--------------------------------------------------------------------------------