├── .gitignore ├── Airfoil-Design-AirfRANS ├── LICENSE ├── README.md ├── dataset │ ├── dataset.py │ └── dataset_stats.ipynb ├── fig │ ├── results.png │ └── task.png ├── main.py ├── main_evaluation.py ├── models │ ├── GUNet.py │ ├── GraphSAGE.py │ ├── MLP.py │ ├── NN.py │ ├── PointNet.py │ └── Transolver.py ├── params.yaml ├── requirements.txt ├── scripts │ ├── Evaluation.sh │ ├── GraphSAGE.sh │ └── Transolver.sh ├── train.py └── utils │ ├── metrics.py │ ├── metrics_NACA.py │ ├── naca_generator.py │ └── reorganize.py ├── Car-Design-ShapeNetCar ├── README.md ├── dataset │ ├── dataset.py │ └── load_dataset.py ├── fig │ ├── car_slice_surf.png │ ├── case_study.png │ ├── results.png │ └── task.png ├── main.py ├── main_evaluation.py ├── models │ └── Transolver.py ├── requirements.txt ├── scripts │ ├── Evaluation.sh │ └── Transolver.sh ├── train.py └── utils │ └── drag_coefficient.py ├── LICENSE ├── PDE-Solving-StandardBenchmark ├── README.md ├── exp_airfoil.py ├── exp_darcy.py ├── exp_elas.py ├── exp_ns.py ├── exp_pipe.py ├── exp_plas.py ├── fig │ ├── scalibility.png │ ├── showcase.png │ └── standard_benchmark.png ├── model │ ├── Embedding.py │ ├── Physics_Attention.py │ ├── Transolver_Irregular_Mesh.py │ ├── Transolver_Structured_Mesh_2D.py │ └── Transolver_Structured_Mesh_3D.py ├── model_dict.py ├── requirements.txt ├── scripts │ ├── Transolver_Airfoil.sh │ ├── Transolver_Darcy.sh │ ├── Transolver_Elas.sh │ ├── Transolver_NS.sh │ ├── Transolver_Pipe.sh │ └── Transolver_Plas.sh └── utils │ ├── normalizer.py │ └── testloss.py ├── Physics_Attention.py ├── README.md └── pic ├── Transolver.png ├── physical_states.png └── showcases.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 
17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/README.md: -------------------------------------------------------------------------------- 1 | # Transolver for Airfoil Design 2 | 3 | **Paper Correction:** There is a typo in our description of the evaluation metrics for the volume and surface physics fields. The results reported in the paper are MSE (not Relative L2), which is exactly the metric used in the AirfRANS paper. We sincerely apologize for this mistake. 4 | 5 | We test [Transolver](https://arxiv.org/abs/2402.02366) on practical design tasks. The airfoil design task requires the model to estimate the surrounding and surface physical quantities of a 2D airfoil under different Reynolds numbers and angles of attack. 6 | 7 |

8 | 9 |

10 | Figure 1. Airfoil design task. Left: surrounding pressure; Right: x-direction wind speed. 11 |

12 | 13 | ## Get Started 14 | 15 | This part of the code is built on [[AirfRANS]](https://github.com/Extrality/AirfRANS). 16 | 17 | 1. Install Python 3.8. For convenience, execute the following command. 18 | 19 | ```bash 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | Note: You need to install [pytorch_geometric](https://github.com/pyg-team/pytorch_geometric). 24 | 25 | 2. Prepare Data. 26 | 27 | The experiment data is provided by [[AirfRANS]](https://github.com/Extrality/AirfRANS). You can directly download it with this [link](https://data.isir.upmc.fr/extrality/NeurIPS_2022/Dataset.zip) (9.3GB). 28 | 29 | 3. Train and evaluate the model. We provide the experiment scripts under the folder `./scripts/`. You can reproduce the experiment results as in the following examples: 30 | 31 | ```bash 32 | bash scripts/Transolver.sh # for Training Transolver (takes 20-24 hours on a single A100) 33 | bash scripts/Evaluation.sh # for Evaluation 34 | bash scripts/GraphSAGE.sh # for Training GraphSAGE (takes 30-36 hours on a single A100) 35 | ``` 36 | 37 | Note: You need to change the argument `--my_path` to your dataset path. 38 | 39 | 4. Test the model under different settings. This benchmark supports four types of settings. 40 | 41 | | Settings | Argument | 42 | | -------------------------------------------- | ------------- | 43 | | Use full data | `-t full` | 44 | | Use scarce data | `-t scarce` | 45 | | Test on out-of-distribution Reynolds | `-t reynolds` | 46 | | Test on out-of-distribution Angles of Attack | `-t aoa` | 47 | 48 | 5. Develop your own model. Here are the instructions: 49 | 50 | - Add the model file under the folder `./models/`. 51 | 52 | - Add the training details in `./params.yaml`. If you do not want to change the settings, just copy another model's configuration. 53 | 54 | - Add the model configuration into `./main.py`. 55 | 56 | - Add a script file under the folder `./scripts/` and change the argument `--model`.
57 | 58 | ## Main Results 59 | 60 | Transolver consistently achieves the best performance on practical design tasks. 61 | 62 |

63 | 64 |

65 | Table 1. Model comparisons on the practical design tasks. 66 |

67 | 68 | ## Citation 69 | 70 | If you find this repo useful, please cite our paper. 71 | 72 | ``` 73 | @inproceedings{wu2024Transolver, 74 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries}, 75 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long}, 76 | booktitle={International Conference on Machine Learning}, 77 | year={2024} 78 | } 79 | ``` 80 | 81 | ## Contact 82 | 83 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn). 84 | 85 | ## Acknowledgement 86 | 87 | We appreciate the following github repos a lot for their valuable code base or datasets: 88 | 89 | https://github.com/Extrality/AirfRANS 90 | -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/dataset/dataset_stats.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import json\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import seaborn as sns\n", 13 | "import os.path as osp\n", 14 | "\n", 15 | "NU = 1.56e-5\n", 16 | "\n", 17 | "sets = ['full_train', 'scarce_train', 'reynolds_train', 'aoa_train', 'full_test', 'reynolds_test', 'aoa_test']\n", 18 | "colors = ['cornflowerblue']*4 + ['burlywood']*3\n", 19 | "data_dir = 'MY_ROOT_DIRECTORY'\n", 20 | "\n", 21 | "for c, s in zip(colors, sets):\n", 22 | " with open(osp.join(data_dir, 'manifest.json'), 'r') as f:\n", 23 | " manifest = json.load(f)[s]\n", 24 | "\n", 25 | " us = []\n", 26 | " angles = []\n", 27 | " digits4 = []\n", 28 | " digits5 = []\n", 29 | " for sim in manifest:\n", 30 | " params = sim.split('_')\n", 31 | " us.append(float(params[2])/NU)\n", 32 | " angles.append(float(params[3]))\n", 33 | "\n", 34 | " if len(params) == 7:\n", 35 | " 
digits4.append(list(map(float, params[-3:])))\n", 36 | " else:\n", 37 | " digits5.append(list(map(float, params[-4:])))\n", 38 | "\n", 39 | " digits4 = np.array(digits4)\n", 40 | " digits5 = np.array(digits5)\n", 41 | "\n", 42 | " sns.set()\n", 43 | "\n", 44 | " fig, ax = plt.subplots(3, 3, figsize = (12, 12))\n", 45 | " ax[2, 1].hist(us, bins = 20, color = c)\n", 46 | " ax[2, 1].set_title('Reynolds number')\n", 47 | "\n", 48 | " ax[2, 2].hist(angles, bins = 20, color = c)\n", 49 | " ax[2, 2].set_xlabel('Degrees')\n", 50 | " ax[2, 2].set_title('Angle of attack')\n", 51 | "\n", 52 | " ax[0, 0].hist(digits4[:, 0], bins = 20, color = c)\n", 53 | " ax[0, 0].set_title(r'$1^{st}$ digit')\n", 54 | "\n", 55 | " ax[0, 1].hist(digits4[:, 1], bins = 20, color = c)\n", 56 | " ax[0, 1].set_title(r'$2^{nd}$ digit')\n", 57 | "\n", 58 | " ax[0, 2].hist(digits4[:, 2], bins = 20, color = c)\n", 59 | " ax[0, 2].set_title(r'$3^{rd}$ and $4^{th}$ digits')\n", 60 | "\n", 61 | " ax[1, 0].hist(digits5[:, 0], bins = 20, color = c)\n", 62 | " ax[1, 0].set_title(r'$1^{st}$ digit')\n", 63 | "\n", 64 | " ax[1, 1].hist(digits5[:, 1], bins = 20, color = c)\n", 65 | " ax[1, 1].set_title(r'$2^{nd}$ digit')\n", 66 | "\n", 67 | " ax[2, 0].hist(digits5[:, 2], bins = 2, color = c)\n", 68 | " ax[2, 0].set_title(r'$3^{rd}$ digit')\n", 69 | "\n", 70 | " ax[1, 2].hist(digits5[:, 3], bins = 20, color = c)\n", 71 | " ax[1, 2].set_title(r'$4^{th}$ and $5^{th}$ digits');\n", 72 | " fig.savefig(s, bbox_inches = 'tight', dpi = 150)" 73 | ] 74 | } 75 | ], 76 | "metadata": { 77 | "kernelspec": { 78 | "display_name": "Python 3.9.12 ('isir')", 79 | "language": "python", 80 | "name": "python3" 81 | }, 82 | "language_info": { 83 | "codemirror_mode": { 84 | "name": "ipython", 85 | "version": 3 86 | }, 87 | "file_extension": ".py", 88 | "mimetype": "text/x-python", 89 | "name": "python", 90 | "nbconvert_exporter": "python", 91 | "pygments_lexer": "ipython3", 92 | "version": "3.9.12" 93 | }, 94 | "orig_nbformat": 4, 95 
| "vscode": { 96 | "interpreter": { 97 | "hash": "d00e44851a3a4d5201bc229183e4c0de3fea7314717b82800f8d82d2168b4a23" 98 | } 99 | } 100 | }, 101 | "nbformat": 4, 102 | "nbformat_minor": 2 103 | } 104 | -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/fig/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Airfoil-Design-AirfRANS/fig/results.png -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/fig/task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Airfoil-Design-AirfRANS/fig/task.png -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/main.py: -------------------------------------------------------------------------------- 1 | import argparse, yaml, json 2 | import torch 3 | import train 4 | import utils.metrics as metrics 5 | from dataset.dataset import Dataset 6 | import os.path as osp 7 | import numpy as np 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--model', help='The model you want to train, choose between MLP, GraphSAGE, PointNet, GUNet', 11 | type=str) 12 | parser.add_argument('-n', '--nmodel', help='Number of trained models for standard deviation estimation (default: 1)', 13 | default=1, type=int) 14 | parser.add_argument('-w', '--weight', help='Weight in front of the surface loss (default: 1)', default=1, type=float) 15 | parser.add_argument('-t', '--task', 16 | help='Task to train on. 
Choose between "full", "scarce", "reynolds" and "aoa" (default: full)', 17 | default='full', type=str) 18 | parser.add_argument('-s', '--score', 19 | help='If you want to compute the score of the models on the associated test set. (default: 0)', 20 | default=0, type=int) 21 | parser.add_argument('--my_path', 22 | default='/data/path', type=str) 23 | parser.add_argument('--save_path', 24 | default='metrics', type=str) 25 | args = parser.parse_args() 26 | 27 | with open(args.my_path + '/manifest.json', 'r') as f: 28 | manifest = json.load(f) 29 | 30 | manifest_train = manifest[args.task + '_train'] 31 | test_dataset = manifest[args.task + '_test'] if args.task != 'scarce' else manifest['full_test'] 32 | n = int(.1 * len(manifest_train)) 33 | train_dataset = manifest_train[:-n] 34 | val_dataset = manifest_train[-n:] 35 | print("start load data") 36 | train_dataset, coef_norm = Dataset(train_dataset, norm=True, sample=None, my_path=args.my_path) 37 | val_dataset = Dataset(val_dataset, sample=None, coef_norm=coef_norm, my_path=args.my_path) 38 | print("load data finish") 39 | # Cuda 40 | use_cuda = torch.cuda.is_available() 41 | device = 'cuda:0' if use_cuda else 'cpu' 42 | if use_cuda: 43 | print('Using GPU') 44 | else: 45 | print('Using CPU') 46 | 47 | with open('params.yaml', 'r') as f: # hyperparameters of the model 48 | hparams = yaml.safe_load(f)[args.model] 49 | 50 | from models.MLP import MLP 51 | 52 | models = [] 53 | for i in range(args.nmodel): 54 | 55 | if args.model == 'Transolver': 56 | from models.Transolver import Transolver 57 | 58 | model = Transolver(n_hidden=256, 59 | n_layers=8, 60 | space_dim=7, 61 | fun_dim=0, 62 | n_head=8, 63 | mlp_ratio=2, 64 | out_dim=4, 65 | slice_num=32, 66 | unified_pos=1).cuda() 67 | else: 68 | encoder = MLP(hparams['encoder'], batch_norm=False) 69 | decoder = MLP(hparams['decoder'], batch_norm=False) 70 | if args.model == 'GraphSAGE': 71 | from models.GraphSAGE import GraphSAGE 72 | 73 | model = GraphSAGE(hparams, 
encoder, decoder) 74 | 75 | elif args.model == 'PointNet': 76 | from models.PointNet import PointNet 77 | 78 | model = PointNet(hparams, encoder, decoder) 79 | 80 | elif args.model == 'MLP': 81 | from models.NN import NN 82 | 83 | model = NN(hparams, encoder, decoder) 84 | 85 | elif args.model == 'GUNet': 86 | from models.GUNet import GUNet 87 | 88 | model = GUNet(hparams, encoder, decoder) 89 | 90 | log_path = osp.join(args.save_path, args.task, args.model) # path where you want to save log and figures 91 | print('start training') 92 | model = train.main(device, train_dataset, val_dataset, model, hparams, log_path, 93 | criterion='MSE_weighted', val_iter=10, reg=args.weight, name_mod=args.model, val_sample=True) 94 | print('end training') 95 | models.append(model) 96 | torch.save(models, osp.join(args.save_path, args.task, args.model, args.model)) 97 | 98 | if bool(args.score): 99 | print('start score') 100 | s = args.task + '_test' if args.task != 'scarce' else 'full_test' 101 | coefs = metrics.Results_test(device, [models], [hparams], coef_norm, args.my_path, path_out='scores', n_test=3, 102 | criterion='MSE', s=s) 103 | # models can be a stack of the same model (for example MLP) on the task s, if you have another stack of another model (for example GraphSAGE) 104 | # you can put in model argument [models_MLP, models_GraphSAGE] and it will output the results for both models (mean and std) in an ordered array. 
105 | np.save(osp.join('scores', args.task, 'true_coefs'), coefs[0]) 106 | np.save(osp.join('scores', args.task, 'pred_coefs_mean'), coefs[1]) 107 | np.save(osp.join('scores', args.task, 'pred_coefs_std'), coefs[2]) 108 | for n, file in enumerate(coefs[3]): 109 | np.save(osp.join('scores', args.task, 'true_surf_coefs_' + str(n)), file) 110 | for n, file in enumerate(coefs[4]): 111 | np.save(osp.join('scores', args.task, 'surf_coefs_' + str(n)), file) 112 | np.save(osp.join('scores', args.task, 'true_bls'), coefs[5]) 113 | np.save(osp.join('scores', args.task, 'bls'), coefs[6]) 114 | print('end score') 115 | -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/main_evaluation.py: -------------------------------------------------------------------------------- 1 | import yaml, json 2 | import torch 3 | import utils.metrics as metrics 4 | from dataset.dataset import Dataset 5 | import os.path as osp 6 | import argparse 7 | import numpy as np 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--my_path', default='/data/path', type=str) # data save path 11 | parser.add_argument('--save_path', default='./', type=str) # model save path 12 | args = parser.parse_args() 13 | 14 | # Compute the normalization used for the training 15 | 16 | use_cuda = torch.cuda.is_available() 17 | device = 'cuda:0' if use_cuda else 'cpu' 18 | if use_cuda: 19 | print('Using GPU') 20 | else: 21 | print('Using CPU') 22 | 23 | data_root_dir = args.my_path 24 | ckpt_root_dir = args.save_path 25 | 26 | tasks = ['full'] 27 | 28 | for task in tasks: 29 | print('Generating results for task ' + task + '...') 30 | # task = 'full' # Choose between 'full', 'scarce', 'reynolds', and 'aoa' 31 | s = task + '_test' if task != 'scarce' else 'full_test' 32 | s_train = task + '_train' 33 | 34 | data_dir = osp.join(data_root_dir, 'Dataset') 35 | with open(osp.join(data_dir, 'manifest.json'), 'r') as f: 36 | manifest = json.load(f) 37 | 38 | 
manifest_train = manifest[s_train] 39 | n = int(.1 * len(manifest_train)) 40 | train_dataset = manifest_train[:-n] 41 | 42 | _, coef_norm = Dataset(train_dataset, norm=True, sample=None, my_path=data_dir) 43 | 44 | # Compute the scores on the test set 45 | 46 | model_names = ['Transolver'] 47 | models = [] 48 | hparams = [] 49 | 50 | for model in model_names: 51 | model_path = osp.join(ckpt_root_dir, 'metrics', task, model, model) 52 | mod = torch.load(model_path) 53 | print(mod) 54 | mod = [m.to(device) for m in mod] 55 | models.append(mod) 56 | 57 | with open('params.yaml', 'r') as f: 58 | hparam = yaml.safe_load(f)[model] 59 | hparams.append(hparam) 60 | 61 | results_dir = osp.join(ckpt_root_dir, 'scores', task) 62 | coefs = metrics.Results_test(device, models, hparams, coef_norm, data_dir, results_dir, n_test=3, criterion='MSE', 63 | s=s) 64 | # models can be a stack of the same model (for example MLP) on the task s, if you have another stack of another model (for example GraphSAGE) 65 | # you can put in model argument [models_MLP, models_GraphSAGE] and it will output the results for both models (mean and std) in an ordered array. 
66 | 67 | np.save(osp.join(results_dir, 'true_coefs'), coefs[0]) 68 | np.save(osp.join(results_dir, 'pred_coefs_mean'), coefs[1]) 69 | np.save(osp.join(results_dir, 'pred_coefs_std'), coefs[2]) 70 | for n, file in enumerate(coefs[3]): 71 | np.save(osp.join(results_dir, 'true_surf_coefs_' + str(n)), file) 72 | for n, file in enumerate(coefs[4]): 73 | np.save(osp.join(results_dir, 'surf_coefs_' + str(n)), file) 74 | np.save(osp.join(results_dir, 'true_bls'), coefs[5]) 75 | np.save(osp.join(results_dir, 'bls'), coefs[6]) 76 | -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/models/GUNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric.nn as nng 4 | import random 5 | 6 | def DownSample(id, x, edge_index, pos_x, pool, pool_ratio, r, max_neighbors): 7 | y = x.clone() 8 | n = int(x.size(0)) 9 | 10 | if pool is not None: 11 | y, _, _, _, id_sampled, _ = pool(y, edge_index) 12 | else: 13 | k = int((pool_ratio*torch.tensor(n, dtype = torch.float)).ceil()) 14 | id_sampled = random.sample(range(n), k) 15 | id_sampled = torch.tensor(id_sampled, dtype = torch.long) 16 | y = y[id_sampled] 17 | 18 | pos_x = pos_x[id_sampled] 19 | id.append(id_sampled) 20 | 21 | # if training: 22 | # edge_index_sampled = nng.radius_graph(x = pos_x.detach(), r = r, loop = True, max_num_neighbors = 64) 23 | # else: 24 | # edge_index_sampled = nng.radius_graph(x = pos_x.detach(), r = r, loop = True, max_num_neighbors = 512) 25 | edge_index_sampled = nng.radius_graph(x = pos_x.detach(), r = r, loop = True, max_num_neighbors = max_neighbors) 26 | 27 | return y, edge_index_sampled 28 | 29 | def UpSample(x, pos_x_up, pos_x_down): 30 | cluster = nng.nearest(pos_x_up, pos_x_down) 31 | x_up = x[cluster] 32 | 33 | return x_up 34 | 35 | class GUNet(nn.Module): 36 | def __init__(self, hparams, encoder, decoder): 37 | super(GUNet, self).__init__() 38 | 
39 | self.L = hparams['nb_scale'] 40 | self.layer = hparams['layer'] 41 | self.pool_type = hparams['pool'] 42 | self.pool_ratio = hparams['pool_ratio'] 43 | self.list_r = hparams['list_r'] 44 | self.size_hidden_layers = hparams['size_hidden_layers'] 45 | self.size_hidden_layers_init = hparams['size_hidden_layers'] 46 | self.max_neighbors = hparams['max_neighbors'] 47 | self.dim_enc = hparams['encoder'][-1] 48 | self.bn_bool = hparams['batchnorm'] 49 | self.res = hparams['res'] 50 | self.head = 2 51 | self.activation = nn.ReLU() 52 | 53 | self.encoder = encoder 54 | self.decoder = decoder 55 | 56 | self.down_layers = nn.ModuleList() 57 | 58 | if self.pool_type != 'random': 59 | self.pool = nn.ModuleList() 60 | else: 61 | self.pool = None 62 | 63 | if self.layer == 'SAGE': 64 | self.down_layers.append(nng.SAGEConv( 65 | in_channels = self.dim_enc, 66 | out_channels = self.size_hidden_layers 67 | )) 68 | bn_in = self.size_hidden_layers 69 | 70 | elif self.layer == 'GAT': 71 | self.down_layers.append(nng.GATConv( 72 | in_channels = self.dim_enc, 73 | out_channels = self.size_hidden_layers, 74 | heads = self.head, 75 | add_self_loops = False, 76 | concat = True 77 | )) 78 | bn_in = self.head*self.size_hidden_layers 79 | 80 | if self.bn_bool == True: 81 | self.bn = nn.ModuleList() 82 | self.bn.append(nng.BatchNorm( 83 | in_channels = bn_in, 84 | track_running_stats = False 85 | )) 86 | else: 87 | self.bn = None 88 | 89 | 90 | for n in range(1, self.L): 91 | if self.pool_type != 'random': 92 | self.pool.append(nng.TopKPooling( 93 | in_channels = self.size_hidden_layers, 94 | ratio = self.pool_ratio[n - 1], 95 | nonlinearity = torch.sigmoid 96 | )) 97 | 98 | if self.layer == 'SAGE': 99 | self.down_layers.append(nng.SAGEConv( 100 | in_channels = self.size_hidden_layers, 101 | out_channels = 2*self.size_hidden_layers, 102 | )) 103 | self.size_hidden_layers = 2*self.size_hidden_layers 104 | bn_in = self.size_hidden_layers 105 | 106 | elif self.layer == 'GAT': 107 | 
self.down_layers.append(nng.GATConv( 108 | in_channels = self.head*self.size_hidden_layers, 109 | out_channels = self.size_hidden_layers, 110 | heads = 2, 111 | add_self_loops = False, 112 | concat = True 113 | )) 114 | 115 | if self.bn_bool == True: 116 | self.bn.append(nng.BatchNorm( 117 | in_channels = bn_in, 118 | track_running_stats = False 119 | )) 120 | 121 | self.up_layers = nn.ModuleList() 122 | 123 | if self.layer == 'SAGE': 124 | self.up_layers.append(nng.SAGEConv( 125 | in_channels = 3*self.size_hidden_layers_init, 126 | out_channels = self.dim_enc 127 | )) 128 | self.size_hidden_layers_init = 2*self.size_hidden_layers_init 129 | 130 | elif self.layer == 'GAT': 131 | self.up_layers.append(nng.GATConv( 132 | in_channels = 2*self.head*self.size_hidden_layers, 133 | out_channels = self.dim_enc, 134 | heads = 2, 135 | add_self_loops = False, 136 | concat = False 137 | )) 138 | 139 | if self.bn_bool == True: 140 | self.bn.append(nng.BatchNorm( 141 | in_channels = self.dim_enc, 142 | track_running_stats = False 143 | )) 144 | 145 | for n in range(1, self.L - 1): 146 | if self.layer == 'SAGE': 147 | self.up_layers.append(nng.SAGEConv( 148 | in_channels = 3*self.size_hidden_layers_init, 149 | out_channels = self.size_hidden_layers_init, 150 | )) 151 | bn_in = self.size_hidden_layers_init 152 | self.size_hidden_layers_init = 2*self.size_hidden_layers_init 153 | 154 | elif self.layer == 'GAT': 155 | self.up_layers.append(nng.GATConv( 156 | in_channels = 2*self.head*self.size_hidden_layers, 157 | out_channels = self.size_hidden_layers, 158 | heads = 2, 159 | add_self_loops = False, 160 | concat = True 161 | )) 162 | 163 | if self.bn_bool == True: 164 | self.bn.append(nng.BatchNorm( 165 | in_channels = bn_in, 166 | track_running_stats = False 167 | )) 168 | 169 | def forward(self, data): 170 | x, edge_index = data.x, data.edge_index 171 | id = [] 172 | edge_index_list = [edge_index.clone()] 173 | pos_x_list = [] 174 | z = self.encoder(x) 175 | if self.res: 176 | 
z_res = z.clone() 177 | 178 | z = self.down_layers[0](z, edge_index) 179 | 180 | if self.bn_bool == True: 181 | z = self.bn[0](z) 182 | 183 | z = self.activation(z) 184 | z_list = [z.clone()] 185 | for n in range(self.L - 1): 186 | pos_x = x[:, :2] if n == 0 else pos_x[id[n - 1]] 187 | pos_x_list.append(pos_x.clone()) 188 | 189 | if self.pool_type != 'random': 190 | z, edge_index = DownSample(id, z, edge_index, pos_x, self.pool[n], self.pool_ratio[n], self.list_r[n], self.max_neighbors) 191 | else: 192 | z, edge_index = DownSample(id, z, edge_index, pos_x, None, self.pool_ratio[n], self.list_r[n], self.max_neighbors) 193 | edge_index_list.append(edge_index.clone()) 194 | 195 | z = self.down_layers[n + 1](z, edge_index) 196 | 197 | if self.bn_bool == True: 198 | z = self.bn[n + 1](z) 199 | 200 | z = self.activation(z) 201 | z_list.append(z.clone()) 202 | pos_x_list.append(pos_x[id[-1]].clone()) 203 | 204 | for n in range(self.L - 1, 0, -1): 205 | z = UpSample(z, pos_x_list[n - 1], pos_x_list[n]) 206 | z = torch.cat([z, z_list[n - 1]], dim = 1) 207 | z = self.up_layers[n - 1](z, edge_index_list[n - 1]) 208 | 209 | if self.bn_bool == True: 210 | z = self.bn[self.L + n - 1](z) 211 | 212 | z = self.activation(z) if n != 1 else z 213 | 214 | del(z_list, pos_x_list, edge_index_list) 215 | 216 | if self.res: 217 | z = z + z_res 218 | 219 | z = self.decoder(z) 220 | 221 | return z 222 | -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/models/GraphSAGE.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch_geometric.nn as nng 3 | 4 | class GraphSAGE(nn.Module): 5 | def __init__(self, hparams, encoder, decoder): 6 | super(GraphSAGE, self).__init__() 7 | 8 | self.nb_hidden_layers = hparams['nb_hidden_layers'] 9 | self.size_hidden_layers = hparams['size_hidden_layers'] 10 | self.bn_bool = hparams['bn_bool'] 11 | self.activation = nn.ReLU() 12 | 13 | 
13 | self.encoder = encoder 14 | self.decoder = decoder 15 | 16 | self.in_layer = nng.SAGEConv( 17 | in_channels = hparams['encoder'][-1], 18 | out_channels = self.size_hidden_layers 19 | ) 20 | 21 | self.hidden_layers = nn.ModuleList() 22 | for n in range(self.nb_hidden_layers - 1): 23 | self.hidden_layers.append(nng.SAGEConv( 24 | in_channels = self.size_hidden_layers, 25 | out_channels = self.size_hidden_layers 26 | )) 27 | 28 | 29 | self.out_layer = nng.SAGEConv( 30 | in_channels = self.size_hidden_layers, 31 | out_channels = hparams['decoder'][0] 32 | ) 33 | 34 | if self.bn_bool: 35 | self.bn = nn.ModuleList() 36 | for n in range(self.nb_hidden_layers): 37 | self.bn.append(nn.BatchNorm1d(self.size_hidden_layers, track_running_stats = False)) 38 | 39 | def forward(self, data): 40 | z, edge_index = data.x, data.edge_index 41 | z = self.encoder(z) 42 | 43 | z = self.in_layer(z, edge_index) 44 | if self.bn_bool: 45 | z = self.bn[0](z) 46 | z = self.activation(z) 47 | 48 | for n in range(self.nb_hidden_layers - 1): 49 | z = self.hidden_layers[n](z, edge_index) 50 | if self.bn_bool: 51 | z = self.bn[n + 1](z) 52 | z = self.activation(z) 53 | 54 | z = self.out_layer(z, edge_index) 55 | 56 | z = self.decoder(z) 57 | 58 | return z -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/models/MLP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import Tensor 4 | from torch.nn import BatchNorm1d, Identity 5 | from torch_geometric.nn import Linear 6 | 7 | class MLP(torch.nn.Module): 8 | r"""A multi-layer perceptron (MLP) model. 9 | 10 | Args: 11 | channel_list (List[int]): List of input, intermediate and output 12 | channels. :obj:`len(channel_list) - 1` denotes the number of layers 13 | of the MLP. 14 | dropout (float, optional): Dropout probability of each hidden 15 | embedding.
(default: :obj:`0.`) 16 | batch_norm (bool, optional): If set to :obj:`False`, will not make use 17 | of batch normalization. (default: :obj:`True`) 18 | relu_first (bool, optional): If set to :obj:`True`, ReLU activation is 19 | applied before batch normalization. (default: :obj:`False`) 20 | """ 21 | def __init__(self, channel_list, dropout = 0., 22 | batch_norm = True, relu_first = False): 23 | super().__init__() 24 | assert len(channel_list) >= 2 25 | self.channel_list = channel_list 26 | self.dropout = dropout 27 | self.relu_first = relu_first 28 | 29 | self.lins = torch.nn.ModuleList() 30 | for dims in zip(self.channel_list[:-1], self.channel_list[1:]): 31 | self.lins.append(Linear(*dims)) 32 | 33 | self.norms = torch.nn.ModuleList() 34 | for dim in self.channel_list[1:-1]: # one norm per hidden width; plain ints (zip() of a single list yielded 1-tuples) 35 | self.norms.append(BatchNorm1d(dim, track_running_stats = False) if batch_norm else Identity()) 36 | 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self): 40 | for lin in self.lins: 41 | lin.reset_parameters() 42 | for norm in self.norms: 43 | if hasattr(norm, 'reset_parameters'): 44 | norm.reset_parameters() 45 | 46 | 47 | def forward(self, x: Tensor) -> Tensor: 48 | """Apply lins[0]; then, per remaining layer: norm and ReLU (order set by relu_first), dropout, Linear.""" 49 | x = self.lins[0](x) 50 | for lin, norm in zip(self.lins[1:], self.norms): 51 | if self.relu_first: 52 | x = x.relu_() 53 | x = norm(x) 54 | if not self.relu_first: 55 | x = x.relu_() 56 | x = F.dropout(x, p = self.dropout, training = self.training) 57 | x = lin(x) # call the module, not .forward(), so registered hooks run 58 | return x 59 | 60 | 61 | def __repr__(self) -> str: 62 | return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})' -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/models/NN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.MLP import MLP 3 | 4 | class NN(nn.Module): 5 | def __init__(self, hparams, encoder, decoder): 6 | super(NN, self).__init__() 7 | 8 | self.nb_hidden_layers =
hparams['nb_hidden_layers'] 9 | self.size_hidden_layers = hparams['size_hidden_layers'] 10 | self.bn_bool = hparams['bn_bool'] 11 | self.activation = nn.ReLU() 12 | 13 | self.encoder = encoder 14 | self.decoder = decoder 15 | 16 | self.dim_enc = hparams['encoder'][-1] 17 | 18 | self.nn = MLP([self.dim_enc] + [self.size_hidden_layers]*self.nb_hidden_layers + [self.dim_enc], batch_norm = self.bn_bool) 19 | 20 | def forward(self, data): 21 | z = self.encoder(data.x) 22 | z = self.nn(z) 23 | z = self.decoder(z) 24 | 25 | return z -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/models/PointNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric.nn as nng 4 | from models.MLP import MLP 5 | 6 | 7 | class PointNet(nn.Module): 8 | def __init__(self, hparams, encoder, decoder): 9 | super(PointNet, self).__init__() 10 | 11 | self.base_nb = hparams['base_nb'] 12 | 13 | self.in_block = MLP([hparams['encoder'][-1], self.base_nb, self.base_nb * 2], batch_norm=False) 14 | self.max_block = MLP([self.base_nb * 2, self.base_nb * 4, self.base_nb * 8, self.base_nb * 32], 15 | batch_norm=False) 16 | 17 | self.out_block = MLP([self.base_nb * (32 + 2), self.base_nb * 16, self.base_nb * 8, self.base_nb * 4], 18 | batch_norm=False) 19 | 20 | self.encoder = encoder 21 | self.decoder = decoder 22 | 23 | self.fcfinal = nn.Linear(self.base_nb * 4, hparams['encoder'][-1]) 24 | 25 | def forward(self, data): 26 | z, batch = data.x.float(), data.batch.long() 27 | 28 | z = self.encoder(z) 29 | z = self.in_block(z) 30 | 31 | global_coef = self.max_block(z) 32 | global_coef = nng.global_max_pool(global_coef, batch=batch) 33 | nb_points = torch.zeros(global_coef.shape[0], device=z.device) 34 | 35 | for i in range(batch.max() + 1): 36 | nb_points[i] = (batch == i).sum() 37 | nb_points = nb_points.long() 38 | global_coef =
torch.repeat_interleave(global_coef, nb_points, dim=0) 39 | 40 | z = torch.cat([z, global_coef], dim=1) 41 | z = self.out_block(z) 42 | z = self.fcfinal(z) 43 | z = self.decoder(z) 44 | 45 | return z -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/models/Transolver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from timm.models.layers import trunc_normal_ 5 | from einops import rearrange, repeat 6 | 7 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'leaky_relu': lambda: nn.LeakyReLU(0.1), # every value must be a zero-arg factory: MLP calls act() 8 | 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU} 9 | 10 | 11 | class Physics_Attention_Irregular_Mesh(nn.Module): 12 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64): 13 | super().__init__() 14 | inner_dim = dim_head * heads 15 | self.dim_head = dim_head 16 | self.heads = heads 17 | self.scale = dim_head ** -0.5 18 | self.softmax = nn.Softmax(dim=-1) 19 | self.dropout = nn.Dropout(dropout) 20 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) 21 | 22 | self.in_project_x = nn.Linear(dim, inner_dim) 23 | self.in_project_fx = nn.Linear(dim, inner_dim) 24 | self.in_project_slice = nn.Linear(dim_head, slice_num) 25 | for l in [self.in_project_slice]: 26 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization 27 | self.to_q = nn.Linear(dim_head, dim_head, bias=False) 28 | self.to_k = nn.Linear(dim_head, dim_head, bias=False) 29 | self.to_v = nn.Linear(dim_head, dim_head, bias=False) 30 | self.to_out = nn.Sequential( 31 | nn.Linear(inner_dim, dim), 32 | nn.Dropout(dropout) 33 | ) 34 | 35 | def forward(self, x): 36 | # B N C 37 | B, N, C = x.shape 38 | 39 | ### (1) Slice 40 | fx_mid = self.in_project_fx(x).reshape(B, N, self.heads, self.dim_head) \ 41 | .permute(0, 2, 1, 3).contiguous() # B H N C 42 | x_mid =
self.in_project_x(x).reshape(B, N, self.heads, self.dim_head) \ 43 | .permute(0, 2, 1, 3).contiguous() # B H N C 44 | slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature) # B H N G 45 | slice_norm = slice_weights.sum(2) # B H G 46 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) 47 | slice_token = slice_token / (slice_norm + 1e-5)[:, :, :, None] # broadcasts over the channel dim; no repeat needed 48 | 49 | ### (2) Attention among slice tokens 50 | q_slice_token = self.to_q(slice_token) 51 | k_slice_token = self.to_k(slice_token) 52 | v_slice_token = self.to_v(slice_token) 53 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale 54 | attn = self.softmax(dots) 55 | attn = self.dropout(attn) 56 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D 57 | 58 | ### (3) Deslice 59 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) 60 | out_x = rearrange(out_x, 'b h n d -> b n (h d)') 61 | return self.to_out(out_x) 62 | 63 | 64 | class MLP(nn.Module): 65 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True): 66 | super(MLP, self).__init__() 67 | 68 | if act in ACTIVATION.keys(): 69 | act = ACTIVATION[act] 70 | else: 71 | raise NotImplementedError(f'unsupported activation {act!r}; expected one of {list(ACTIVATION)}') 72 | self.n_input = n_input 73 | self.n_hidden = n_hidden 74 | self.n_output = n_output 75 | self.n_layers = n_layers 76 | self.res = res 77 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act()) 78 | self.linear_post = nn.Linear(n_hidden, n_output) 79 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)]) 80 | 81 | def forward(self, x): 82 | x = self.linear_pre(x) 83 | for i in range(self.n_layers): 84 | if self.res: 85 | x = self.linears[i](x) + x 86 | else: 87 | x = self.linears[i](x) 88 | x = self.linear_post(x) 89 | return x 90 | 91 | 92 | class Transolver_block(nn.Module): 93 | """Transformer encoder block.""" 94 | 95 | def __init__( 96 |
self, 97 | num_heads: int, 98 | hidden_dim: int, 99 | dropout: float, 100 | act='gelu', 101 | mlp_ratio=4, 102 | last_layer=False, 103 | out_dim=1, 104 | slice_num=32, 105 | ): 106 | super().__init__() 107 | self.last_layer = last_layer 108 | self.ln_1 = nn.LayerNorm(hidden_dim) 109 | self.Attn = Physics_Attention_Irregular_Mesh(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, 110 | dropout=dropout, slice_num=slice_num) 111 | self.ln_2 = nn.LayerNorm(hidden_dim) 112 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) 113 | self.mlp_new = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) # NOTE: unused in forward(); removing it would change the state_dict keys of saved checkpoints 114 | if self.last_layer: 115 | self.ln_3 = nn.LayerNorm(hidden_dim) 116 | self.mlp2 = nn.Linear(hidden_dim, out_dim) 117 | 118 | def forward(self, fx): 119 | fx = self.Attn(self.ln_1(fx)) + fx 120 | fx = self.mlp(self.ln_2(fx)) + fx 121 | if self.last_layer: 122 | return self.mlp2(self.ln_3(fx)) 123 | else: 124 | return fx 125 | 126 | 127 | class Transolver(nn.Module): 128 | def __init__(self, 129 | space_dim=1, 130 | n_layers=5, 131 | n_hidden=256, 132 | dropout=0, 133 | n_head=8, 134 | act='gelu', 135 | mlp_ratio=1, 136 | fun_dim=1, 137 | out_dim=1, 138 | slice_num=32, 139 | ref=8, 140 | unified_pos=False 141 | ): 142 | super(Transolver, self).__init__() 143 | self.__name__ = 'Transolver' 144 | self.ref = ref 145 | self.unified_pos = unified_pos 146 | if self.unified_pos: 147 | self.preprocess = MLP(fun_dim + space_dim + self.ref * self.ref, n_hidden * 2, n_hidden, 148 | n_layers=0, 149 | res=False, act=act) 150 | else: 151 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act) 152 | 153 | self.n_hidden = n_hidden 154 | self.space_dim = space_dim 155 | 156 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden, 157 | dropout=dropout, 158 | act=act, 159 | mlp_ratio=mlp_ratio, out_dim=out_dim, 160 |
slice_num=slice_num, 161 | last_layer=(_ == n_layers - 1)) 162 | for _ in range(n_layers)]) 163 | self.initialize_weights() 164 | self.placeholder = nn.Parameter((1 / (n_hidden)) * torch.rand(n_hidden, dtype=torch.float)) 165 | 166 | def initialize_weights(self): 167 | self.apply(self._init_weights) 168 | 169 | def _init_weights(self, m): 170 | if isinstance(m, nn.Linear): 171 | trunc_normal_(m.weight, std=0.02) 172 | if isinstance(m, nn.Linear) and m.bias is not None: 173 | nn.init.constant_(m.bias, 0) 174 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 175 | nn.init.constant_(m.bias, 0) 176 | nn.init.constant_(m.weight, 1.0) 177 | 178 | def get_grid(self, my_pos): 179 | # my_pos: B x N x 2 -- the last dim must be 2 to broadcast against the 2-D reference grid (gridx, gridy) built below 180 | batchsize = my_pos.shape[0] 181 | 182 | gridx = torch.tensor(np.linspace(-2, 4, self.ref), dtype=torch.float) 183 | gridx = gridx.reshape(1, self.ref, 1, 1).repeat([batchsize, 1, self.ref, 1]) 184 | gridy = torch.tensor(np.linspace(-1.5, 1.5, self.ref), dtype=torch.float) 185 | gridy = gridy.reshape(1, 1, self.ref, 1).repeat([batchsize, self.ref, 1, 1]) 186 | grid_ref = torch.cat((gridx, gridy), dim=-1).to(my_pos.device).reshape(batchsize, self.ref ** 2, 2) # B, ref*ref, 2 (on my_pos's device, not a hard-coded .cuda()) 187 | 188 | pos = torch.sqrt( 189 | torch.sum((my_pos[:, :, None, :] - grid_ref[:, None, :, :]) ** 2, 190 | dim=-1)).
\ 191 | reshape(batchsize, my_pos.shape[1], self.ref * self.ref).contiguous() 192 | return pos 193 | 194 | def forward(self, data): 195 | x, fx, T = data.x, None, None 196 | x = x[None, :, :] 197 | if self.unified_pos: 198 | new_pos = self.get_grid(data.pos[None, :, :]) 199 | x = torch.cat((x, new_pos), dim=-1) 200 | if fx is not None: 201 | fx = torch.cat((x, fx), -1) 202 | fx = self.preprocess(fx) 203 | else: 204 | fx = self.preprocess(x) 205 | fx = fx + self.placeholder[None, None, :] 206 | 207 | for block in self.blocks: 208 | fx = block(fx) 209 | 210 | return fx[0] 211 | -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/params.yaml: -------------------------------------------------------------------------------- 1 | GraphSAGE: 2 | encoder: [ 7, 64, 64, 8 ] 3 | decoder: [ 8, 64, 64, 4 ] 4 | 5 | nb_hidden_layers: 3 6 | size_hidden_layers: 64 7 | batch_size: 1 8 | nb_epochs: 398 9 | lr: 0.001 10 | max_neighbors: 64 11 | bn_bool: True 12 | subsampling: 32000 13 | r: 0.05 14 | 15 | Transolver: 16 | batch_size: 1 17 | nb_epochs: 398 18 | lr: 0.001 19 | max_neighbors: 64 20 | subsampling: 32000 21 | r: 0.05 22 | 23 | PointNet: 24 | encoder: [ 7, 64, 64, 8 ] 25 | decoder: [ 8, 64, 64, 4 ] 26 | 27 | base_nb: 8 28 | batch_size: 1 29 | nb_epochs: 398 30 | lr: 0.001 31 | subsampling: 32000 32 | 33 | MLP: 34 | encoder: [ 7, 64, 64, 8 ] 35 | decoder: [ 8, 64, 64, 4 ] 36 | 37 | nb_hidden_layers: 3 38 | size_hidden_layers: 64 39 | batch_size: 1 40 | nb_epochs: 398 41 | lr: 0.001 42 | bn_bool: True 43 | subsampling: 32000 44 | 45 | GUNet: 46 | encoder: [ 7, 64, 64, 8 ] 47 | decoder: [ 8, 64, 64, 4 ] 48 | 49 | layer: 'SAGE' 50 | pool: 'random' 51 | nb_scale: 5 52 | pool_ratio: [ .5, .5, .5, .5 ] 53 | list_r: [ .05, .2, .5, 1, 10 ] 54 | size_hidden_layers: 8 55 | batchnorm: True 56 | res: False 57 | 58 | batch_size: 1 59 | nb_epochs: 398 60 | lr: 0.001 61 | max_neighbors: 64 62 | subsampling: 32000 63 | r: 0.05
-------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torch_geometric 3 | torch-cluster 4 | vtk 5 | timm 6 | einops 7 | seaborn 8 | pyvista 9 | -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/scripts/Evaluation.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4 2 | 3 | python main_evaluation.py --my_path /data/naca/ -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/scripts/GraphSAGE.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4 2 | 3 | python main.py --model GraphSAGE -t full --my_path /data/naca/Dataset --score 1 -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/scripts/Transolver.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=6 2 | 3 | python main.py --model Transolver -t full --my_path /data/naca/Dataset --score 1 -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/utils/metrics_NACA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | from utils.naca_generator import camber_line 5 | 6 | sns.set() 7 | 8 | # Properties of air at sea level and 293.15K 9 | RHO = 1.184 10 | NU = 1.56e-5 11 | C = 346.1 12 | P_ref = 1.013e5 13 | 14 | 15 | def surface_coefficients(airfoil, aero_name, compressible=False, extrado=False): 16 | u_inf = float(aero_name.split('_')[2]) 17 | digits = list(map(float, aero_name.split('_')[4:-1])) 18 | if compressible: 19 | qInf = 0.5 *
u_inf ** 2 * RHO 20 | else: 21 | qInf = 0.5 * u_inf ** 2 22 | 23 | if extrado: 24 | camber = camber_line(digits, airfoil.points[:, 0])[0] 25 | idx_extrado = (airfoil.points[:, 1] > camber) 26 | points = airfoil.points[:, 0] 27 | pressure = airfoil.point_data['p'] 28 | wss = np.linalg.norm(airfoil.point_data['wallShearStress'][:, :2], axis=1) 29 | 30 | c_p = np.concatenate([points[:, None], pressure[:, None] / qInf], axis=1) 31 | c_l = np.concatenate([points[:, None], wss[:, None] / qInf], axis=1) # skin-friction coefficient (|wall shear| / qInf), plotted as C_f by compare_surface_coefs 32 | 33 | if extrado: 34 | return c_p, c_l, idx_extrado 35 | else: 36 | return c_p, c_l 37 | 38 | 39 | def compare_surface_coefs(coefs1, coefs2, extrado=True, path=None): 40 | ycp1, ycp2, c_p1, c_p2 = coefs1[0][:, 0], coefs2[0][:, 0], coefs1[0][:, 1], coefs2[0][:, 1] 41 | ycl1, ycl2, c_f1, c_f2 = coefs1[1][:, 0], coefs2[1][:, 0], coefs1[1][:, 1], coefs2[1][:, 1] 42 | 43 | fig, ax = plt.subplots(2, figsize=(20, 10)) 44 | if extrado: 45 | n_extrado1, n_extrado2 = coefs1[2], coefs2[2] # NOTE(review): used below as integer slice bounds (extrados points assumed first), while surface_coefficients(..., extrado=True) returns a boolean mask -- confirm callers convert 46 | ax[0].scatter(ycp1[:n_extrado1], c_p1[:n_extrado1], label='Extrado 1') 47 | ax[0].scatter(ycp1[n_extrado1:], c_p1[n_extrado1:], color='r', marker='x', label='Intrado 1') 48 | ax[0].scatter(ycp2[:n_extrado2], c_p2[:n_extrado2], color='y', label='Extrado Target') 49 | ax[0].scatter(ycp2[n_extrado2:], c_p2[n_extrado2:], color='g', marker='x', label='Intrado Target') 50 | 51 | ax[1].scatter(ycl1[:n_extrado1], c_f1[:n_extrado1], label='Extrado 1') 52 | ax[1].scatter(ycl1[n_extrado1:], c_f1[n_extrado1:], color='r', marker='x', label='Intrado 1') 53 | ax[1].scatter(ycl2[:n_extrado2], c_f2[:n_extrado2], color='y', label='Extrado Target') 54 | ax[1].scatter(ycl2[n_extrado2:], c_f2[n_extrado2:], color='g', marker='x', label='Intrado Target') 55 | 56 | else: 57 | ax[0].scatter(ycp1, c_p1, label='Experiment 1') 58 | ax[0].scatter(ycp2, c_p2, color='y', label='Experiment Target') 59 | 60 | ax[1].scatter(ycl1, c_f1, label='Experiment 1') 61 | ax[1].scatter(ycl2, c_f2, color='y', label='Experiment
Target') 62 | 63 | ax[0].invert_yaxis() 64 | ax[0].set_xlabel('x/c') 65 | ax[1].set_xlabel('x/c') 66 | ax[0].set_ylabel(r'$C_p$') 67 | ax[1].set_ylabel(r'$C_f$') 68 | ax[0].set_title('Pressure coefficient') 69 | ax[1].set_title('Skin friction coefficient') 70 | ax[0].legend(loc='best') 71 | ax[1].legend(loc='best') 72 | 73 | if path is not None: 74 | fig.savefig(path + 'surface_coefs.png', bbox_inches='tight', dpi=150) 75 | 76 | 77 | def boundary_layer(airfoil, internal, aero_name, x, y=1e-3, resolution=int(1e3), direction='normals', rotation=False, 78 | extrado=True): 79 | u_inf = float(aero_name.split('_')[2]) 80 | digits = list(map(float, aero_name.split('_')[4:-1])) 81 | camber = camber_line(digits, airfoil.points[:, 0])[0] 82 | idx_extrado = (airfoil.points[:, 1] > camber) 83 | 84 | if extrado: 85 | arg = np.argmin(np.abs(airfoil.points[idx_extrado, 0] - x)) + 1 86 | arg = np.argwhere(idx_extrado.cumsum() == arg).min() 87 | else: 88 | arg = np.argmin(np.abs(airfoil.points[~idx_extrado, 0] - x)) + 1 89 | arg = np.argwhere((~idx_extrado).cumsum() == arg).min() 90 | 91 | if direction == 'normals': 92 | normals = -airfoil.point_data['Normals'][arg] 93 | 94 | elif direction == 'y': 95 | normals = np.array([0, 2 * int(extrado) - 1, 0]) 96 | else: raise ValueError(f"direction must be 'normals' or 'y', got {direction!r}") 97 | a, b = airfoil.points[arg], airfoil.points[arg] + y * normals 98 | bl = internal.sample_over_line(a, b, resolution=resolution) 99 | 100 | if rotation: 101 | rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) 102 | u = (bl.point_data['U'] * (rot @ normals)).sum(axis=1) 103 | v = (bl.point_data['U'] * normals).sum(axis=1) 104 | else: 105 | u = bl.point_data['U'][:, 0] 106 | v = bl.point_data['U'][:, 1] 107 | 108 | nut = bl.point_data['nut'] 109 | yc = bl.points[:, 1] - a[1] 110 | 111 | return yc, u / u_inf, v / u_inf, nut / NU 112 | 113 | 114 | def compare_boundary_layer(coefs1, coefs2, ylim=.1, path=None, ylog=False): 115 | yc1, u1, v1, nut1 = coefs1 116 | yc2, u2, v2, nut2 = coefs2 117 | 118 | fig, ax = plt.subplots(1, 3,
figsize=(30, 10)) 119 | ax[0].scatter(u1, yc1, label='Experiment 1') 120 | ax[0].scatter(u2, yc2, label='Experiment 2', color='r', marker='x') 121 | ax[0].set_xlabel(r'$u/U_\infty$') 122 | ax[0].set_ylabel(r'$(y-y_0)/c$') 123 | # ax[0].set_xlim([-0.2, 1.4]) 124 | # ax[0].set_ylim([0, ylim]) 125 | ax[0].legend(loc='best') 126 | 127 | ax[1].scatter(v1, yc1, label='Experiment 1') 128 | ax[1].scatter(v2, yc2, label='Experiment 2', color='r', marker='x') 129 | ax[1].set_xlabel(r'$v/U_\infty$') 130 | ax[1].set_ylabel(r'$(y-y_0)/c$') 131 | # ax[1].set_xlim([-0.2, 0.2]) 132 | # ax[1].set_ylim([0, ylim]) 133 | ax[1].legend(loc='best') 134 | 135 | ax[2].scatter(nut1, yc1, label='Experiment 1') 136 | ax[2].scatter(nut2, yc2, label='Experiment 2', color='r', marker='x') 137 | # ax[2].set_ylim([0, ylim]) 138 | ax[2].set_xlabel(r'$\nu_t/\nu$') 139 | ax[2].set_ylabel(r'$(y-y_0)/c$') 140 | ax[2].legend(loc='best') 141 | 142 | if ylog: 143 | ax[0].set_yscale('log') 144 | ax[1].set_yscale('log') 145 | ax[2].set_yscale('log') 146 | 147 | if path is not None: 148 | fig.savefig(path + 'boundary_layer.png', bbox_inches='tight', dpi=150) 149 | 150 | 151 | def plot_residuals(path, params): 152 | datas = dict() 153 | if params['turbulence'] == 'SA': 154 | fields = ['Ux', 'Uy', 'p', 'nuTilda'] 155 | elif params['turbulence'] == 'SST': 156 | fields = ['Ux', 'Uy', 'p', 'k', 'omega'] 157 | for field in fields: 158 | data = np.loadtxt(path + 'logs/' + field + '_0')[:, 1] 159 | datas[field] = data 160 | 161 | if params['turbulence'] == 'SA': 162 | fig, ax = plt.subplots(2, 2, figsize=(20, 20)) 163 | ax[1, 1].plot(datas['nuTilda']) 164 | ax[1, 1].set_yscale('log') 165 | ax[1, 1].set_title('nuTilda residual') 166 | ax[1, 1].set_xlabel('Number of iterations') 167 | 168 | elif params['turbulence'] == 'SST': 169 | fig, ax = plt.subplots(3, 2, figsize=(30, 20)) 170 | ax[1, 1].plot(datas['k']) 171 | ax[1, 1].set_yscale('log') 172 | ax[1, 1].set_title('k residual') 173 | ax[1, 1].set_xlabel('Number of
iterations') 174 | 175 | ax[2, 0].plot(datas['omega']) 176 | ax[2, 0].set_yscale('log') 177 | ax[2, 0].set_title('omega residual') 178 | ax[2, 0].set_xlabel('Number of iterations'); 179 | 180 | ax[0, 0].plot(datas['Ux']) 181 | ax[0, 0].set_yscale('log') 182 | ax[0, 0].set_title('Ux residual') 183 | 184 | ax[0, 1].plot(datas['Uy']) 185 | ax[0, 1].set_yscale('log') 186 | ax[0, 1].set_title('Uy residual') 187 | 188 | ax[1, 0].plot(datas['p']) 189 | ax[1, 0].set_yscale('log') 190 | ax[1, 0].set_title('p residual') 191 | ax[1, 0].set_xlabel('Number of iterations'); 192 | 193 | fig.savefig(path + 'residuals.png', bbox_inches='tight', dpi=150) 194 | 195 | return datas 196 | 197 | 198 | def plot_coef_convergence(path, params): 199 | datas = dict() 200 | coefs = np.loadtxt(path + 'postProcessing/forceCoeffs1/0/coefficient.dat') # read the file once; column 1 is C_D, column 3 is C_L 201 | datas['c_d'], datas['c_l'] = coefs[:, 1], coefs[:, 3] 202 | c_d, c_l = datas['c_d'][-1], datas['c_l'][-1] 203 | 204 | fig, ax = plt.subplots(2, figsize=(30, 15)) 205 | ax[0].plot(datas['c_d']) 206 | ax[0].set_ylim([.5 * c_d, 1.5 * c_d]) 207 | ax[0].set_title('Drag coefficient') 208 | ax[0].set_xlabel('Number of iterations') 209 | ax[0].set_ylabel(r'$C_D$') 210 | 211 | ax[1].plot(datas['c_l']) 212 | ax[1].set_title('Lift coefficient') 213 | ax[1].set_ylim([.5 * c_l, 1.5 * c_l]) 214 | ax[1].set_ylabel(r'$C_L$') 215 | ax[1].set_xlabel('Number of iterations'); 216 | 217 | print('Drag coefficient: {0:.5}, lift coefficient: {1:.5}'.format(c_d, c_l)) 218 | 219 | fig.savefig(path + 'coef_convergence.png', bbox_inches='tight', dpi=150) 220 | 221 | return datas, c_d, c_l 222 | -------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/utils/naca_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def thickness_dist(t, x, CTE=True): 5 | # CTE for close trailing edge 6 | if CTE: 7 | a =
-0.1036 8 | else: 9 | a = -0.1015 10 | return 5 * t * (0.2969 * np.sqrt(x) - 0.1260 * x - 0.3516 * x ** 2 + 0.2843 * x ** 3 + a * x ** 4) 11 | 12 | 13 | def camber_line(params, x): 14 | # assert np.all(np.logical_and(x >= -1e-1, x <= 1)), 'Found x > 1 or x < 0' 15 | y_c = np.zeros_like(x) 16 | dy_c = np.zeros_like(x) 17 | 18 | if len(params) == 2: 19 | m = params[0] / 100 20 | p = params[1] / 10 21 | 22 | if p == 0: 23 | dy_c = -2 * m * x # NOTE(review): y_c is left at 0 although the p -> 0 limit of the formula below is m * (1 - x ** 2); only consistent when m == 0 -- confirm intended 24 | return y_c, dy_c 25 | elif p == 1: 26 | dy_c = 2 * m * (1 - x) # NOTE(review): likewise y_c stays 0 (the p -> 1 limit is m * (2 * x - x ** 2)) 27 | return y_c, dy_c 28 | 29 | mask1 = (x < p) 30 | mask2 = (x >= p) 31 | y_c[mask1] = (m / p ** 2) * (2 * p * x[mask1] - x[mask1] ** 2) 32 | dy_c[mask1] = (2 * m / p ** 2) * (p - x[mask1]) 33 | y_c[mask2] = (m / (1 - p) ** 2) * ((1 - 2 * p) + 2 * p * x[mask2] - x[mask2] ** 2) 34 | dy_c[mask2] = (2 * m / (1 - p) ** 2) * (p - x[mask2]) 35 | 36 | elif len(params) == 3: 37 | l, p, q = params 38 | c_l, x_f = 3 / 20 * l, p / 20 39 | 40 | f = lambda x: x * (1 - np.sqrt(x / 3)) - x_f 41 | df = lambda x: 1 - 3 * np.sqrt(x / 3) / 2 42 | old_m = 0.5 43 | cond = True 44 | while cond: 45 | new_m = np.max([old_m - f(old_m) / df(old_m), 0]) 46 | cond = (np.abs(old_m - new_m) > 1e-15) 47 | old_m = new_m 48 | m = old_m 49 | r = (3 * m - 7 * m ** 2 + 8 * m ** 3 - 4 * m ** 4) / np.sqrt(m * (1 - m)) - 3 / 2 * (1 - 2 * m) * ( 50 | np.pi / 2 - np.arcsin(1 - 2 * m)) 51 | k_1 = c_l / r 52 | 53 | mask1 = (x <= m) 54 | mask2 = (x > m) 55 | if q == 0: 56 | y_c[mask1] = k_1 * ((x[mask1] ** 3 - 3 * m * x[mask1] ** 2 + m ** 2 * (3 - m) * x[mask1])) 57 | dy_c[mask1] = k_1 * (3 * x[mask1] ** 2 - 6 * m * x[mask1] + m ** 2 * (3 - m)) 58 | y_c[mask2] = k_1 * m ** 3 * (1 - x[mask2]) 59 | dy_c[mask2] = -k_1 * m ** 3 * np.ones_like(dy_c[mask2]) 60 | 61 | elif q == 1: 62 | k = (3 * (m - x_f) ** 2 - m ** 3) / (1 - m) ** 3 63 | y_c[mask1] = k_1 * ((x[mask1] - m) ** 3 - k * (1 - m) ** 3 * x[mask1] - m ** 3 * x[mask1] + m ** 3) 64 | dy_c[mask1] = k_1 * (3 * (x[mask1] - m) ** 2 - k * (1 - m) ** 3 - m **
3) 65 | y_c[mask2] = k_1 * (k * (x[mask2] - m) ** 3 - k * (1 - m) ** 3 * x[mask2] - m ** 3 * x[mask2] + m ** 3) 66 | dy_c[mask2] = k_1 * (3 * k * (x[mask2] - m) ** 2 - k * (1 - m) ** 3 - m ** 3) 67 | 68 | else: 69 | raise ValueError('Q must be 0 for normal camber or 1 for reflex camber.') 70 | 71 | else: 72 | raise ValueError('The first input must be a tuple of the 2 or 3 digits that represent the camber line.') 73 | 74 | return y_c, dy_c 75 | 76 | 77 | def naca_generator(params, nb_samples=400, scale=1, origin=(0, 0), cosine_spacing=True, verbose=True, CTE=True): 78 | if len(params) == 3: 79 | params_c = params[:2] 80 | t = params[2] / 100 81 | if verbose: 82 | print(f'Generating naca M = {params_c[0]}, P = {params_c[1]}, XX = {t * 100}') 83 | elif len(params) == 4: 84 | params_c = params[:3] 85 | t = params[3] / 100 86 | if verbose: 87 | print(f'Generating naca L = {params_c[0]}, P = {params_c[1]}, Q = {params_c[2]}, XX = {t * 100}') 88 | else: 89 | raise ValueError('The first argument must be a tuple of the 4 or 5 digits of the airfoil.') 90 | 91 | if cosine_spacing: 92 | beta = np.pi * np.linspace(1, 0, nb_samples + 1, endpoint=True) 93 | x = (1 - np.cos(beta)) / 2 94 | else: 95 | x = np.linspace(1, 0, nb_samples + 1, endpoint=True) 96 | 97 | y_c, dy_c = camber_line(params_c, x) 98 | y_t = thickness_dist(t, x, CTE) 99 | theta = np.arctan(dy_c) 100 | x_u = x - y_t * np.sin(theta) 101 | x_l = x + y_t * np.sin(theta) 102 | y_u = y_c + y_t * np.cos(theta) 103 | y_l = y_c - y_t * np.cos(theta) 104 | x = np.concatenate([x_u, x_l[:-1][::-1]], axis=0) 105 | y = np.concatenate([y_u, y_l[:-1][::-1]], axis=0) 106 | pos = np.stack([ 107 | x * scale + origin[0], 108 | y * scale + origin[1] 109 | ], axis=-1 110 | ) 111 | pos[0] = pos[-1] = np.array([1 * scale + origin[0], 0 * scale + origin[1]]) # close the contour at the trailing edge: chord point (1, 0) mapped through the same scale/origin as the rest 112 | return pos 113 |
-------------------------------------------------------------------------------- /Airfoil-Design-AirfRANS/utils/reorganize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def reorganize(in_order_points, out_order_points, quantity_to_reordered): 4 | # Map each (x, y) pair to its FIRST index in in_order_points, once, instead of rescanning all points per query (O(n) instead of O(n^2)). 5 | first_index = {} 6 | for i, (px, py) in enumerate(in_order_points[:, :2].tolist()): 7 | first_index.setdefault((px, py), i) 8 | idx = np.array([first_index[(px, py)] for px, py in out_order_points[:, :2].tolist()], dtype=int) 9 | 10 | assert (in_order_points[idx] == out_order_points).all() 11 | 12 | return quantity_to_reordered[idx] -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/README.md: -------------------------------------------------------------------------------- 1 | # Transolver for Car Design 2 | 3 | We test [Transolver](https://arxiv.org/abs/2402.02366) on practical design tasks. The car design task requires the model to estimate the surrounding wind speed and surface pressure for a driving car. 4 | 5 |

6 | 7 |

8 | Figure 1. Car design task. 9 |

10 | 11 | We record the relative errors of the surrounding wind, the surface pressure and the [drag coefficient](https://en.wikipedia.org/wiki/Drag_coefficient), as well as [Spearman's rank correlations](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient), which quantify the model's ability to rank different designs. 12 | 13 |

14 | 15 |

16 | Table 1. Model comparisons of the car design task. 17 |

18 | 19 | 20 | ## Get Started 21 | 22 | 1. Install Python 3.8. For convenience, execute the following command to install the dependencies. 23 | 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | Note: You need to install [pytorch_geometric](https://github.com/pyg-team/pytorch_geometric). 29 | 30 | 2. Prepare Data. 31 | 32 | The raw data can be found [[here]](http://www.nobuyuki-umetani.com/publication/mlcfd_data.zip), which is provided by [Nobuyuki Umetani](https://dl.acm.org/doi/abs/10.1145/3197517.3201325). 33 | 34 | 3. Train and evaluate the model. We provide the experiment scripts under the folder `./scripts/`. You can reproduce the experiment results with the following examples: 35 | 36 | ```bash 37 | bash scripts/Transolver.sh # for Training (takes 8-10 hours on a single A100) 38 | bash scripts/Evaluation.sh # for Evaluation 39 | ``` 40 | 41 | Note: You need to change the arguments `--data_dir` and `--save_dir` to your own paths. Here `data_dir` is the directory of the raw data and `save_dir` is where the preprocessed data is saved. 42 | 43 | If you have already downloaded or generated the preprocessed data, pass `--preprocessed 1` (the argument is an integer flag) to speed up loading. 44 | 45 | 4. Develop your own model. Here are the instructions: 46 | 47 | - Add the model file under folder `./models/`. 48 | - Add the model configuration into `./main.py`. 49 | - Add a script file under folder `./scripts/` and change the argument `--cfd_model`. 50 | 51 | ## Slice Visualization 52 | 53 | Transolver proposes to **learn physical states** hidden under the unwieldy meshes. 54 | 55 | The following visualization demonstrates that Transolver successfully learns to assign points in similar physical states to the same slice, such as the windshield, license plate and headlights. 56 | 57 |

58 | 59 |

60 | Figure 2. Visualization for Transolver learned physical states. 61 |

62 | 63 | 64 | ## Showcases 65 | 66 | Transolver achieves the best performance in complex geometries and hybrid physics. 67 | 68 |

69 | 70 |

71 | Figure 3. Case study of Transolver and other models. 72 |

73 | 74 | 75 | ## Citation 76 | 77 | If you find this repo useful, please cite our paper. 78 | 79 | ``` 80 | @inproceedings{wu2024Transolver, 81 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries}, 82 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long}, 83 | booktitle={International Conference on Machine Learning}, 84 | year={2024} 85 | } 86 | ``` 87 | 88 | ## Contact 89 | 90 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn). 91 | 92 | ## Acknowledgement 93 | 94 | We greatly appreciate the following papers for their valuable code bases and datasets: 95 | 96 | https://dl.acm.org/doi/abs/10.1145/3197517.3201325 97 | 98 | https://openreview.net/forum?id=EyQO9RPhwN 99 | -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/dataset/load_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataset.dataset import get_datalist 3 | 4 | 5 | def get_samples(root): 6 | folds = [f'param{i}' for i in range(9)] 7 | samples = [] 8 | for fold in folds: 9 | fold_samples = [] 10 | files = os.listdir(os.path.join(root, fold)) 11 | for file in files: 12 | path = os.path.join(root, os.path.join(fold, file)) 13 | if os.path.isdir(path): 14 | fold_samples.append(os.path.join(fold, file)) 15 | samples.append(fold_samples) 16 | return samples # 100 + 99 + 97 + 100 + 100 + 96 + 100 + 98 + 99 = 889 samples 17 | 18 | 19 | def load_train_val_fold(args, preprocessed): 20 | # Same split, logging and loading as load_train_val_fold_file; only drops the list of validation sample paths. 21 | train_dataset, val_dataset, coef_norm, _ = load_train_val_fold_file(args, preprocessed) 22 | return train_dataset, val_dataset, coef_norm 23 | 24 | 25 | def load_train_val_fold_file(args, preprocessed): 26 | # Fold `args.fold_id` is held out for validation; every other fold goes to training.
27 | samples = get_samples(args.data_dir) 28 | trainlst = [s for i, fold_samples in enumerate(samples) if i != args.fold_id for s in fold_samples] 29 | vallst = samples[args.fold_id] if 0 <= args.fold_id < len(samples) else None 30 | 31 | if preprocessed: 32 | print("use preprocessed data") 33 | print("loading data") 34 | train_dataset, coef_norm = get_datalist(args.data_dir, trainlst, norm=True, savedir=args.save_dir, 35 | preprocessed=preprocessed) 36 | val_dataset = get_datalist(args.data_dir, vallst, coef_norm=coef_norm, savedir=args.save_dir, 37 | preprocessed=preprocessed) 38 | print("load data finish") 39 | return train_dataset, val_dataset, coef_norm, vallst -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/fig/car_slice_surf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Car-Design-ShapeNetCar/fig/car_slice_surf.png -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/fig/case_study.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Car-Design-ShapeNetCar/fig/case_study.png -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/fig/results.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Car-Design-ShapeNetCar/fig/results.png -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/fig/task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/Car-Design-ShapeNetCar/fig/task.png -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/main.py: -------------------------------------------------------------------------------- 1 | import train 2 | import os 3 | import torch 4 | import argparse 5 | 6 | from dataset.load_dataset import load_train_val_fold 7 | from dataset.dataset import GraphDataset 8 | from models.Transolver import Model 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--data_dir', default='/data/PDE_data/mlcfd_data/training_data') 12 | parser.add_argument('--save_dir', default='/data/PDE_data/mlcfd_data/preprocessed_data') 13 | parser.add_argument('--fold_id', default=0, type=int) 14 | parser.add_argument('--gpu', default=0, type=int) 15 | parser.add_argument('--val_iter', default=10, type=int) 16 | parser.add_argument('--cfd_config_dir', default='cfd/cfd_params.yaml') 17 | parser.add_argument('--cfd_model') 18 | parser.add_argument('--cfd_mesh', action='store_true') 19 | parser.add_argument('--r', default=0.2, type=float) 20 | parser.add_argument('--weight', default=0.5, type=float) 21 | parser.add_argument('--lr', default=0.001, type=float) 22 | parser.add_argument('--batch_size', default=1, type=float) 23 | parser.add_argument('--nb_epochs', default=200, type=float) 24 | parser.add_argument('--preprocessed', default=1, type=int) 25 | args = parser.parse_args() 26 | print(args) 27 | 28 | hparams = {'lr': args.lr, 'batch_size': args.batch_size, 'nb_epochs': args.nb_epochs} 29 | 30 | n_gpu = 
torch.cuda.device_count() 31 | use_cuda = 0 <= args.gpu < n_gpu and torch.cuda.is_available() 32 | device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu') 33 | 34 | train_data, val_data, coef_norm = load_train_val_fold(args, preprocessed=args.preprocessed) 35 | train_ds = GraphDataset(train_data, use_cfd_mesh=args.cfd_mesh, r=args.r) 36 | val_ds = GraphDataset(val_data, use_cfd_mesh=args.cfd_mesh, r=args.r) 37 | 38 | if args.cfd_model == 'Transolver': 39 | model = Model(n_hidden=256, n_layers=8, space_dim=7, 40 | fun_dim=0, 41 | n_head=8, 42 | mlp_ratio=2, out_dim=4, 43 | slice_num=32, 44 | unified_pos=0).cuda() 45 | 46 | path = f'metrics/{args.cfd_model}/{args.fold_id}/{args.nb_epochs}_{args.weight}' 47 | if not os.path.exists(path): 48 | os.makedirs(path) 49 | 50 | model = train.main(device, train_ds, val_ds, model, hparams, path, val_iter=args.val_iter, reg=args.weight, 51 | coef_norm=coef_norm) 52 | -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/main_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import yaml 5 | import numpy as np 6 | import time 7 | from torch import nn 8 | from torch_geometric.loader import DataLoader 9 | from utils.drag_coefficient import cal_coefficient 10 | from dataset.load_dataset import load_train_val_fold_file 11 | from dataset.dataset import GraphDataset 12 | import scipy as sc 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--data_dir', default='/data/PDE_data/mlcfd_data/training_data') 16 | parser.add_argument('--save_dir', default='/data/PDE_data/mlcfd_data/preprocessed_data') 17 | parser.add_argument('--fold_id', default=0, type=int) 18 | parser.add_argument('--gpu', default=0, type=int) 19 | parser.add_argument('--cfd_model') 20 | parser.add_argument('--cfd_mesh', action='store_true') 21 | parser.add_argument('--r', default=0.2, 
type=float) 22 | parser.add_argument('--weight', default=0.5, type=float) 23 | parser.add_argument('--nb_epochs', default=200, type=float) 24 | args = parser.parse_args() 25 | print(args) 26 | 27 | 28 | n_gpu = torch.cuda.device_count() 29 | use_cuda = 0 <= args.gpu < n_gpu and torch.cuda.is_available() 30 | device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu') 31 | 32 | train_data, val_data, coef_norm, vallst = load_train_val_fold_file(args, preprocessed=True) 33 | val_ds = GraphDataset(val_data, use_cfd_mesh=args.cfd_mesh, r=args.r) 34 | 35 | path = f'metrics/{args.cfd_model}/{args.fold_id}/{args.nb_epochs}_{args.weight}' 36 | model = torch.load(os.path.join(path, f'model_{args.nb_epochs}.pth')).to(device) 37 | 38 | test_loader = DataLoader(val_ds, batch_size=1) 39 | 40 | if not os.path.exists('./results/' + args.cfd_model + '/'): 41 | os.makedirs('./results/' + args.cfd_model + '/') 42 | 43 | with torch.no_grad(): 44 | model.eval() 45 | criterion_func = nn.MSELoss(reduction='none') 46 | l2errs_press = [] 47 | l2errs_velo = [] 48 | mses_press = [] 49 | mses_velo_var = [] 50 | times = [] 51 | gt_coef_list = [] 52 | pred_coef_list = [] 53 | coef_error = 0 54 | index = 0 55 | for cfd_data, geom in test_loader: 56 | print(vallst[index]) 57 | cfd_data = cfd_data.to(device) 58 | geom = geom.to(device) 59 | tic = time.time() 60 | out = model((cfd_data, geom)) 61 | toc = time.time() 62 | targets = cfd_data.y 63 | 64 | if coef_norm is not None: 65 | mean = torch.tensor(coef_norm[2]).to(device) 66 | std = torch.tensor(coef_norm[3]).to(device) 67 | pred_press = out[cfd_data.surf, -1] * std[-1] + mean[-1] 68 | gt_press = targets[cfd_data.surf, -1] * std[-1] + mean[-1] 69 | pred_surf_velo = out[cfd_data.surf, :-1] * std[:-1] + mean[:-1] 70 | gt_surf_velo = targets[cfd_data.surf, :-1] * std[:-1] + mean[:-1] 71 | pred_velo = out[~cfd_data.surf, :-1] * std[:-1] + mean[:-1] 72 | gt_velo = targets[~cfd_data.surf, :-1] * std[:-1] + mean[:-1] 73 | out_denorm = out * std + 
mean 74 | y_denorm = targets * std + mean 75 | 76 | np.save('./results/' + args.cfd_model + '/' + str(index) + '_pred.npy', out_denorm.detach().cpu().numpy()) 77 | np.save('./results/' + args.cfd_model + '/' + str(index) + '_gt.npy', y_denorm.detach().cpu().numpy()) 78 | 79 | pred_coef = cal_coefficient(vallst[index].split('/')[1], pred_press[:, None].detach().cpu().numpy(), 80 | pred_surf_velo.detach().cpu().numpy()) 81 | gt_coef = cal_coefficient(vallst[index].split('/')[1], gt_press[:, None].detach().cpu().numpy(), 82 | gt_surf_velo.detach().cpu().numpy()) 83 | 84 | gt_coef_list.append(gt_coef) 85 | pred_coef_list.append(pred_coef) 86 | coef_error += (abs(pred_coef - gt_coef) / gt_coef) 87 | print(coef_error / (index + 1)) 88 | 89 | l2err_press = torch.norm(pred_press - gt_press) / torch.norm(gt_press) 90 | l2err_velo = torch.norm(pred_velo - gt_velo) / torch.norm(gt_velo) 91 | 92 | mse_press = criterion_func(out[cfd_data.surf, -1], targets[cfd_data.surf, -1]).mean(dim=0) 93 | mse_velo_var = criterion_func(out[~cfd_data.surf, :-1], targets[~cfd_data.surf, :-1]).mean(dim=0) 94 | 95 | l2errs_press.append(l2err_press.cpu().numpy()) 96 | l2errs_velo.append(l2err_velo.cpu().numpy()) 97 | mses_press.append(mse_press.cpu().numpy()) 98 | mses_velo_var.append(mse_velo_var.cpu().numpy()) 99 | times.append(toc - tic) 100 | index += 1 101 | 102 | gt_coef_list = np.array(gt_coef_list) 103 | pred_coef_list = np.array(pred_coef_list) 104 | spear = sc.stats.spearmanr(gt_coef_list, pred_coef_list)[0] 105 | print("rho_d: ", spear) 106 | print("c_d: ", coef_error / index) 107 | l2err_press = np.mean(l2errs_press) 108 | l2err_velo = np.mean(l2errs_velo) 109 | rmse_press = np.sqrt(np.mean(mses_press)) 110 | rmse_velo_var = np.sqrt(np.mean(mses_velo_var, axis=0)) 111 | if coef_norm is not None: 112 | rmse_press *= coef_norm[3][-1] 113 | rmse_velo_var *= coef_norm[3][:-1] 114 | print('relative l2 error press:', l2err_press) 115 | print('relative l2 error velo:', l2err_velo) 116 | 
print('press:', rmse_press) 117 | print('velo:', rmse_velo_var, np.sqrt(np.mean(np.square(rmse_velo_var)))) 118 | print('time:', np.mean(times)) 119 | -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/models/Transolver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from timm.models.layers import trunc_normal_ 5 | from einops import rearrange, repeat 6 | 7 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'leaky_relu': nn.LeakyReLU(0.1), 8 | 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU} 9 | 10 | 11 | class Physics_Attention_Irregular_Mesh(nn.Module): 12 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64): 13 | super().__init__() 14 | inner_dim = dim_head * heads 15 | self.dim_head = dim_head 16 | self.heads = heads 17 | self.scale = dim_head ** -0.5 18 | self.softmax = nn.Softmax(dim=-1) 19 | self.dropout = nn.Dropout(dropout) 20 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) 21 | 22 | self.in_project_x = nn.Linear(dim, inner_dim) 23 | self.in_project_fx = nn.Linear(dim, inner_dim) 24 | self.in_project_slice = nn.Linear(dim_head, slice_num) 25 | for l in [self.in_project_slice]: 26 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization 27 | self.to_q = nn.Linear(dim_head, dim_head, bias=False) 28 | self.to_k = nn.Linear(dim_head, dim_head, bias=False) 29 | self.to_v = nn.Linear(dim_head, dim_head, bias=False) 30 | self.to_out = nn.Sequential( 31 | nn.Linear(inner_dim, dim), 32 | nn.Dropout(dropout) 33 | ) 34 | 35 | def forward(self, x): 36 | # B N C 37 | B, N, C = x.shape 38 | 39 | ### (1) Slice 40 | fx_mid = self.in_project_fx(x).reshape(B, N, self.heads, self.dim_head) \ 41 | .permute(0, 2, 1, 3).contiguous() # B H N C 42 | x_mid = self.in_project_x(x).reshape(B, N, self.heads, self.dim_head) 
\ 43 | .permute(0, 2, 1, 3).contiguous() # B H N C 44 | slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature) # B H N G 45 | slice_norm = slice_weights.sum(2) # B H G 46 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) 47 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head)) 48 | 49 | ### (2) Attention among slice tokens 50 | q_slice_token = self.to_q(slice_token) 51 | k_slice_token = self.to_k(slice_token) 52 | v_slice_token = self.to_v(slice_token) 53 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale 54 | attn = self.softmax(dots) 55 | attn = self.dropout(attn) 56 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D 57 | 58 | ### (3) Deslice 59 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) 60 | out_x = rearrange(out_x, 'b h n d -> b n (h d)') 61 | return self.to_out(out_x) 62 | 63 | 64 | class MLP(nn.Module): 65 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True): 66 | super(MLP, self).__init__() 67 | 68 | if act in ACTIVATION.keys(): 69 | act = ACTIVATION[act] 70 | else: 71 | raise NotImplementedError 72 | self.n_input = n_input 73 | self.n_hidden = n_hidden 74 | self.n_output = n_output 75 | self.n_layers = n_layers 76 | self.res = res 77 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act()) 78 | self.linear_post = nn.Linear(n_hidden, n_output) 79 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)]) 80 | 81 | def forward(self, x): 82 | x = self.linear_pre(x) 83 | for i in range(self.n_layers): 84 | if self.res: 85 | x = self.linears[i](x) + x 86 | else: 87 | x = self.linears[i](x) 88 | x = self.linear_post(x) 89 | return x 90 | 91 | 92 | class Transolver_block(nn.Module): 93 | """Transformer encoder block.""" 94 | 95 | def __init__( 96 | self, 97 | num_heads: int, 98 | hidden_dim: int, 99 | 
dropout: float, 100 | act='gelu', 101 | mlp_ratio=4, 102 | last_layer=False, 103 | out_dim=1, 104 | slice_num=32, 105 | ): 106 | super().__init__() 107 | self.last_layer = last_layer 108 | self.ln_1 = nn.LayerNorm(hidden_dim) 109 | self.Attn = Physics_Attention_Irregular_Mesh(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, 110 | dropout=dropout, slice_num=slice_num) 111 | self.ln_2 = nn.LayerNorm(hidden_dim) 112 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) 113 | if self.last_layer: 114 | self.ln_3 = nn.LayerNorm(hidden_dim) 115 | self.mlp2 = nn.Linear(hidden_dim, out_dim) 116 | 117 | def forward(self, fx): 118 | fx = self.Attn(self.ln_1(fx)) + fx 119 | fx = self.mlp(self.ln_2(fx)) + fx 120 | if self.last_layer: 121 | return self.mlp2(self.ln_3(fx)) 122 | else: 123 | return fx 124 | 125 | 126 | class Model(nn.Module): 127 | def __init__(self, 128 | space_dim=1, 129 | n_layers=5, 130 | n_hidden=256, 131 | dropout=0, 132 | n_head=8, 133 | act='gelu', 134 | mlp_ratio=1, 135 | fun_dim=1, 136 | out_dim=1, 137 | slice_num=32, 138 | ref=8, 139 | unified_pos=False 140 | ): 141 | super(Model, self).__init__() 142 | self.__name__ = 'UniPDE_3D' 143 | self.ref = ref 144 | self.unified_pos = unified_pos 145 | if self.unified_pos: 146 | self.preprocess = MLP(fun_dim + self.ref * self.ref * self.ref, n_hidden * 2, n_hidden, n_layers=0, 147 | res=False, act=act) 148 | else: 149 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act) 150 | 151 | self.n_hidden = n_hidden 152 | self.space_dim = space_dim 153 | 154 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden, 155 | dropout=dropout, 156 | act=act, 157 | mlp_ratio=mlp_ratio, 158 | out_dim=out_dim, 159 | slice_num=slice_num, 160 | last_layer=(_ == n_layers - 1)) 161 | for _ in range(n_layers)]) 162 | self.initialize_weights() 163 | self.placeholder = nn.Parameter((1 / (n_hidden)) * 
torch.rand(n_hidden, dtype=torch.float)) 164 | 165 | def initialize_weights(self): 166 | self.apply(self._init_weights) 167 | 168 | def _init_weights(self, m): 169 | if isinstance(m, nn.Linear): 170 | trunc_normal_(m.weight, std=0.02) 171 | if isinstance(m, nn.Linear) and m.bias is not None: 172 | nn.init.constant_(m.bias, 0) 173 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 174 | nn.init.constant_(m.bias, 0) 175 | nn.init.constant_(m.weight, 1.0) 176 | 177 | def get_grid(self, my_pos): 178 | # my_pos 1 N 3 179 | batchsize = my_pos.shape[0] 180 | 181 | gridx = torch.tensor(np.linspace(-1.5, 1.5, self.ref), dtype=torch.float) 182 | gridx = gridx.reshape(1, self.ref, 1, 1, 1).repeat([batchsize, 1, self.ref, self.ref, 1]) 183 | gridy = torch.tensor(np.linspace(0, 2, self.ref), dtype=torch.float) 184 | gridy = gridy.reshape(1, 1, self.ref, 1, 1).repeat([batchsize, self.ref, 1, self.ref, 1]) 185 | gridz = torch.tensor(np.linspace(-4, 4, self.ref), dtype=torch.float) 186 | gridz = gridz.reshape(1, 1, 1, self.ref, 1).repeat([batchsize, self.ref, self.ref, 1, 1]) 187 | grid_ref = torch.cat((gridx, gridy, gridz), dim=-1).cuda().reshape(batchsize, self.ref ** 3, 3) # B 4 4 4 3 188 | 189 | pos = torch.sqrt( 190 | torch.sum((my_pos[:, :, None, :] - grid_ref[:, None, :, :]) ** 2, 191 | dim=-1)). 
\ 192 | reshape(batchsize, my_pos.shape[1], self.ref * self.ref * self.ref).contiguous() 193 | return pos 194 | 195 | def forward(self, data): 196 | cfd_data, geom_data = data 197 | x, fx, T = cfd_data.x, None, None 198 | x = x[None, :, :] 199 | if self.unified_pos: 200 | new_pos = self.get_grid(cfd_data.pos[None, :, :]) 201 | x = torch.cat((x, new_pos), dim=-1) 202 | 203 | if fx is not None: 204 | fx = torch.cat((x, fx), -1) 205 | fx = self.preprocess(fx) 206 | else: 207 | fx = self.preprocess(x) 208 | fx = fx + self.placeholder[None, None, :] 209 | 210 | for block in self.blocks: 211 | fx = block(fx) 212 | 213 | return fx[0] 214 | -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torch_geometric 3 | torch-cluster 4 | vtk 5 | timm 6 | einops -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/scripts/Evaluation.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | python main_evaluation.py \ 4 | --cfd_model=Transolver \ 5 | --data_dir /data/PDE_data/mlcfd_data/training_data \ 6 | --save_dir /data/PDE_data/mlcfd_data/preprocessed_data \ -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/scripts/Transolver.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=7 2 | 3 | python main.py \ 4 | --cfd_model=Transolver \ 5 | --data_dir /data/PDE_data/mlcfd_data/training_data \ 6 | --save_dir /data/PDE_data/mlcfd_data/preprocessed_data \ 7 | -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
time, json, os 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch_geometric.loader import DataLoader 7 | from tqdm import tqdm 8 | 9 | 10 | def get_nb_trainable_params(model): 11 | ''' 12 | Return the number of trainable parameters 13 | ''' 14 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 15 | return sum([np.prod(p.size()) for p in model_parameters]) 16 | 17 | 18 | def train(device, model, train_loader, optimizer, scheduler, reg=1): 19 | model.train() 20 | 21 | criterion_func = nn.MSELoss(reduction='none') 22 | losses_press = [] 23 | losses_velo = [] 24 | for cfd_data, geom in train_loader: 25 | cfd_data = cfd_data.to(device) 26 | geom = geom.to(device) 27 | optimizer.zero_grad() 28 | out = model((cfd_data, geom)) 29 | targets = cfd_data.y 30 | 31 | loss_press = criterion_func(out[cfd_data.surf, -1], targets[cfd_data.surf, -1]).mean(dim=0) 32 | loss_velo_var = criterion_func(out[:, :-1], targets[:, :-1]).mean(dim=0) 33 | loss_velo = loss_velo_var.mean() 34 | total_loss = loss_velo + reg * loss_press 35 | 36 | total_loss.backward() 37 | 38 | optimizer.step() 39 | scheduler.step() 40 | 41 | losses_press.append(loss_press.item()) 42 | losses_velo.append(loss_velo.item()) 43 | 44 | return np.mean(losses_press), np.mean(losses_velo) 45 | 46 | 47 | @torch.no_grad() 48 | def test(device, model, test_loader): 49 | model.eval() 50 | 51 | criterion_func = nn.MSELoss(reduction='none') 52 | losses_press = [] 53 | losses_velo = [] 54 | for cfd_data, geom in test_loader: 55 | cfd_data = cfd_data.to(device) 56 | geom = geom.to(device) 57 | out = model((cfd_data, geom)) 58 | targets = cfd_data.y 59 | 60 | loss_press = criterion_func(out[cfd_data.surf, -1], targets[cfd_data.surf, -1]).mean(dim=0) 61 | loss_velo_var = criterion_func(out[:, :-1], targets[:, :-1]).mean(dim=0) 62 | loss_velo = loss_velo_var.mean() 63 | 64 | losses_press.append(loss_press.item()) 65 | losses_velo.append(loss_velo.item()) 66 | 67 | return np.mean(losses_press), 
np.mean(losses_velo) 68 | 69 | 70 | class NumpyEncoder(json.JSONEncoder): 71 | def default(self, obj): 72 | if isinstance(obj, np.ndarray): 73 | return obj.tolist() 74 | return json.JSONEncoder.default(self, obj) 75 | 76 | 77 | def main(device, train_dataset, val_dataset, Net, hparams, path, reg=1, val_iter=1, coef_norm=[]): 78 | model = Net.to(device) 79 | optimizer = torch.optim.Adam(model.parameters(), lr=hparams['lr']) 80 | lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( 81 | optimizer, 82 | max_lr=hparams['lr'], 83 | total_steps=(len(train_dataset) // hparams['batch_size'] + 1) * hparams['nb_epochs'], 84 | final_div_factor=1000., 85 | ) 86 | start = time.time() 87 | 88 | train_loss, val_loss = 1e5, 1e5 89 | pbar_train = tqdm(range(hparams['nb_epochs']), position=0) 90 | for epoch in pbar_train: 91 | train_loader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True, drop_last=True) 92 | loss_velo, loss_press = train(device, model, train_loader, optimizer, lr_scheduler, reg=reg) 93 | train_loss = loss_velo + reg * loss_press 94 | del (train_loader) 95 | 96 | if val_iter is not None and (epoch == hparams['nb_epochs'] - 1 or epoch % val_iter == 0): 97 | val_loader = DataLoader(val_dataset, batch_size=1) 98 | 99 | loss_velo, loss_press = test(device, model, val_loader) 100 | val_loss = loss_velo + reg * loss_press 101 | del (val_loader) 102 | 103 | pbar_train.set_postfix(train_loss=train_loss, val_loss=val_loss) 104 | else: 105 | pbar_train.set_postfix(train_loss=train_loss) 106 | 107 | end = time.time() 108 | time_elapsed = end - start 109 | params_model = get_nb_trainable_params(model).astype('float') 110 | print('Number of parameters:', params_model) 111 | print('Time elapsed: {0:.2f} seconds'.format(time_elapsed)) 112 | torch.save(model, path + os.sep + f'model_{hparams["nb_epochs"]}.pth') 113 | 114 | if val_iter is not None: 115 | with open(path + os.sep + f'log_{hparams["nb_epochs"]}.json', 'a') as f: 116 | json.dump( 117 | { 118 | 
'nb_parameters': params_model, 119 | 'time_elapsed': time_elapsed, 120 | 'hparams': hparams, 121 | 'train_loss': train_loss, 122 | 'val_loss': val_loss, 123 | 'coef_norm': list(coef_norm), 124 | }, f, indent=12, cls=NumpyEncoder 125 | ) 126 | 127 | return model 128 | -------------------------------------------------------------------------------- /Car-Design-ShapeNetCar/utils/drag_coefficient.py: -------------------------------------------------------------------------------- 1 | import vtk 2 | import os 3 | import numpy as np 4 | from vtk.util.numpy_support import vtk_to_numpy 5 | from scipy.spatial import ConvexHull 6 | 7 | 8 | def unstructured_grid_data_to_poly_data(unstructured_grid_data): 9 | filter = vtk.vtkDataSetSurfaceFilter() 10 | filter.SetInputData(unstructured_grid_data) 11 | filter.Update() 12 | poly_data = filter.GetOutput() 13 | return poly_data, filter 14 | 15 | 16 | def load_unstructured_grid_data(file_name): 17 | reader = vtk.vtkUnstructuredGridReader() 18 | reader.SetFileName(file_name) 19 | reader.Update() 20 | output = reader.GetOutput() 21 | return output 22 | 23 | 24 | ############## calculate rectangle ############## 25 | def calculate_pos(pos): 26 | hull = ConvexHull(pos[:, :2]) 27 | A = hull.volume 28 | return A 29 | 30 | 31 | ############## surf area ############## 32 | def calculate_mesh_cell_area(unstructured_grid_data): 33 | # Read VTK file 34 | poly_data, _ = unstructured_grid_data_to_poly_data(unstructured_grid_data) 35 | 36 | # Get the points and cells 37 | points = poly_data.GetPoints() 38 | cells = poly_data.GetPolys() 39 | 40 | # Initialize an array to store point areas 41 | cell_areas = np.zeros(cells.GetNumberOfCells()) 42 | 43 | # Iterate through cells to calculate areas 44 | cells.InitTraversal() 45 | cell = vtk.vtkIdList() 46 | id = 0 47 | while cells.GetNextCell(cell): 48 | # Check if the cell is a quadrilateral 49 | if cell.GetNumberOfIds() == 4: 50 | # Get the four vertices of the quadrilateral 51 | p1 = 
np.array(points.GetPoint(cell.GetId(0))) 52 | p2 = np.array(points.GetPoint(cell.GetId(1))) 53 | p3 = np.array(points.GetPoint(cell.GetId(2))) 54 | p4 = np.array(points.GetPoint(cell.GetId(3))) 55 | # Calculate the area of the quadrilateral 56 | area = 0.5 * ( 57 | np.linalg.norm(np.cross(p2 - p1, p3 - p1)) + 58 | np.linalg.norm(np.cross(p3 - p1, p4 - p1)) 59 | ) 60 | 61 | # Add the area to each vertex of the quadrilateral 62 | cell_areas[id] += area 63 | id += 1 64 | 65 | return cell_areas 66 | 67 | 68 | ############## velocity gradient ############## 69 | def calculate_cell_velocity_gradient(unstructured_grid_data, velocity): 70 | # Create a vtkDoubleArray for velocity 71 | velocity_data = vtk.vtkDoubleArray() 72 | velocity_data.SetNumberOfComponents(3) # Assuming 3D velocity field 73 | velocity_data.SetNumberOfTuples(unstructured_grid_data.GetNumberOfPoints()) 74 | velocity_data.SetName("Velocity") # Replace "Velocity" with the desired array name 75 | 76 | # Set the velocity array values 77 | for i in range(unstructured_grid_data.GetNumberOfPoints()): 78 | velocity_data.SetTuple(i, velocity[i]) 79 | 80 | # Add the velocity array to the point data 81 | unstructured_grid_data.GetPointData().AddArray(velocity_data) 82 | 83 | # Get the points and cell data (assuming velocity is stored as point data) 84 | poly_data, _ = unstructured_grid_data_to_poly_data(unstructured_grid_data) 85 | points = poly_data.GetPoints() 86 | 87 | # Initialize arrays to store velocity gradients 88 | grad_u = np.zeros((poly_data.GetNumberOfCells(), 3)) # Assuming 3D velocity field 89 | # Iterate through cells to calculate gradients 90 | cells = poly_data.GetPolys() 91 | cells.InitTraversal() 92 | cell = vtk.vtkIdList() 93 | id = 0 94 | while cells.GetNextCell(cell): 95 | # Check if the cell is a quadrilateral 96 | if cell.GetNumberOfIds() == 4: 97 | # Get the four vertices of the quadrilateral 98 | p1 = np.array(points.GetPoint(cell.GetId(0))) 99 | p2 = 
np.array(points.GetPoint(cell.GetId(1))) 100 | p3 = np.array(points.GetPoint(cell.GetId(2))) 101 | p4 = np.array(points.GetPoint(cell.GetId(3))) 102 | # Calculate the velocity at each vertex 103 | u1 = np.array(poly_data.GetPointData().GetArray("Velocity").GetTuple(cell.GetId(0))) 104 | u2 = np.array(poly_data.GetPointData().GetArray("Velocity").GetTuple(cell.GetId(1))) 105 | u3 = np.array(poly_data.GetPointData().GetArray("Velocity").GetTuple(cell.GetId(2))) 106 | u4 = np.array(poly_data.GetPointData().GetArray("Velocity").GetTuple(cell.GetId(3))) 107 | 108 | # Calculate the gradients using finite differences 109 | du_dx = (u2 - u1 + u3 - u4) / (np.linalg.norm(p2 - p1 + p3 - p4) + 1e-8) 110 | du_dy = (u3 - u1 + u4 - u2) / (np.linalg.norm(p3 - p1 + p4 - p2) + 1e-8) 111 | du_dz = (u4 - u1 + u2 - u3) / (np.linalg.norm(p4 - p1 + p2 - p3) + 1e-8) 112 | 113 | # Add the gradients to each vertex of the quadrilateral 114 | grad_u[id] += (du_dx + du_dy + du_dz) 115 | id += 1 116 | 117 | return grad_u 118 | 119 | 120 | ############## calculate drag ############## 121 | def calculate_drag_force(cell_areas, surface_normals, pressure_array, velocity_gradients, dynamic_viscosity): 122 | # Calculate the pressure force component along the flow direction 123 | pressure_force_component = -np.dot(pressure_array.flatten() * cell_areas.flatten(), surface_normals.flatten()) 124 | 125 | # Calculate the wall shear stress component along the flow direction 126 | wall_shear_stress_component = -np.dot(velocity_gradients.flatten() * cell_areas.flatten(), 127 | surface_normals.flatten()) * dynamic_viscosity 128 | # Sum the pressure force and wall shear stress components to get the total drag force 129 | drag_force = np.sum(pressure_force_component + wall_shear_stress_component) 130 | 131 | return drag_force 132 | 133 | 134 | ############## calculate norm ############## 135 | def get_normal(unstructured_grid_data): 136 | poly_data, surface_filter = 
unstructured_grid_data_to_poly_data(unstructured_grid_data) 137 | normal_filter = vtk.vtkPolyDataNormals() 138 | normal_filter.SetInputData(poly_data) 139 | normal_filter.SetAutoOrientNormals(1) 140 | normal_filter.SetConsistency(1) 141 | normal_filter.SetComputeCellNormals(1) 142 | normal_filter.SetComputePointNormals(0) 143 | normal_filter.Update() 144 | return vtk_to_numpy(normal_filter.GetOutput().GetCellData().GetNormals()) 145 | 146 | 147 | ############## calculate coefficient ############## 148 | def cal_coefficient(file_name, press_surf=None, velo_surf=None): 149 | root = '/data/PDE_data/mlcfd_data/training_data' 150 | save_path = '/data/PDE_data/mlcfd_data/preprocessed_data/param0/' + file_name 151 | file_name_press = 'param0/' + file_name + '/quadpress_smpl.vtk' 152 | file_name_velo = 'param0/' + file_name + '/hexvelo_smpl.vtk' 153 | file_name_press = os.path.join(root, file_name_press) 154 | file_name_velo = os.path.join(root, file_name_velo) 155 | unstructured_grid_data_press = load_unstructured_grid_data(file_name_press) 156 | unstructured_grid_data_velo = load_unstructured_grid_data(file_name_velo) 157 | 158 | # normal 159 | normal_surf = get_normal(unstructured_grid_data_press) 160 | # front area 161 | points_surf = vtk_to_numpy(unstructured_grid_data_press.GetPoints().GetData()) 162 | A = calculate_pos(points_surf) 163 | # mesh area 164 | cell_areas = calculate_mesh_cell_area(unstructured_grid_data_press) 165 | # mesh velo 166 | if velo_surf is None: 167 | velo = vtk_to_numpy(unstructured_grid_data_velo.GetPointData().GetVectors()) 168 | points_velo = vtk_to_numpy(unstructured_grid_data_velo.GetPoints().GetData()) 169 | velo_dict = {tuple(p): velo[i] for i, p in enumerate(points_velo)} 170 | velo_surf = np.array([velo_dict[tuple(p)] if tuple(p) in velo_dict else np.zeros(3) for p in points_surf]) 171 | # gradient u 172 | grad_u = calculate_cell_velocity_gradient(unstructured_grid_data_press, velo_surf) 173 | # press 174 | if press_surf is None: 175 
| c2p = vtk.vtkPointDataToCellData() 176 | c2p.SetInputData(unstructured_grid_data_press) 177 | c2p.Update() 178 | unstructured_grid_data_press = c2p.GetOutput() 179 | press_surf = vtk_to_numpy(unstructured_grid_data_press.GetCellData().GetScalars()) 180 | else: 181 | # Create a vtkDoubleArray for press 182 | press_data = vtk.vtkDoubleArray() 183 | press_data.SetNumberOfComponents(1) # Assuming 3D velocity field 184 | press_data.SetNumberOfTuples(unstructured_grid_data_press.GetNumberOfPoints()) 185 | press_data.SetName("my_press") # Replace "my_press" with the desired array name 186 | 187 | # Set the velocity array values 188 | for i in range(unstructured_grid_data_press.GetNumberOfPoints()): 189 | press_data.SetTuple(i, press_surf[i]) 190 | 191 | # Add the velocity array to the point data 192 | unstructured_grid_data_press.GetPointData().AddArray(press_data) 193 | c2p = vtk.vtkPointDataToCellData() 194 | c2p.SetInputData(unstructured_grid_data_press) 195 | c2p.Update() 196 | unstructured_grid_data_press = c2p.GetOutput() 197 | press_surf = vtk_to_numpy(unstructured_grid_data_press.GetCellData().GetArray("my_press")) 198 | 199 | drag_force = calculate_drag_force(cell_areas, normal_surf[:, -1], press_surf, grad_u[:, -1], np.array(1.8e-5)) 200 | nu = 72 / 3.6 201 | air_density = 0.3 202 | cd = (2 / ((nu ** 2) * A * air_density)) * drag_force 203 | return cd 204 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 THUML @ Tsinghua University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit 
persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/README.md: -------------------------------------------------------------------------------- 1 | # Transolver for PDE Solving 2 | 3 | We evaluate [Transolver](https://arxiv.org/abs/2402.02366) on six widely used PDE-solving benchmarks, which are provided by [FNO and GeoFNO](https://github.com/neuraloperator/neuraloperator). 4 | 5 | **Transolver achieves a 22% average relative improvement over the previous second-best model, while presenting favorable efficiency and scalability.** 6 | 7 |

8 | 9 |

10 | Table 1. Comparison on six standard benchmarks. Relative L2 error is reported. 11 |

12 | 13 | 14 | ## Get Started 15 | 16 | 1. Install Python 3.8. For convenience, execute the following command. 17 | 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | 2. Prepare Data. You can obtain experimental datasets from the following links. 23 | 24 | 25 | | Dataset | Task | Geometry | Link | 26 | | ------------- | --------------------------------------- | --------------- | ------------------------------------------------------------ | 27 | | Elasticity | Estimate material inner stress | Point Cloud | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) | 28 | | Plasticity | Estimate material deformation over time | Structured Mesh | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) | 29 | | Navier-Stokes | Predict future fluid velocity | Regular Grid | [[Google Cloud]](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-) | 30 | | Darcy | Estimate fluid pressure through medium | Regular Grid | [[Google Cloud]](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-) | 31 | | AirFoil | Estimate airflow velocity around airfoil | Structured Mesh | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) | 32 | | Pipe | Estimate fluid velocity in a pipe | Structured Mesh | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) | 33 | 34 | 3. Train and evaluate model. We provide the experiment scripts of all benchmarks under the folder `./scripts/`. 
You can reproduce the experiment results with the following examples: 35 | 36 | ```bash 37 | bash scripts/Transolver_Elas.sh # for Elasticity 38 | bash scripts/Transolver_Plas.sh # for Plasticity 39 | bash scripts/Transolver_NS.sh # for Navier-Stokes 40 | bash scripts/Transolver_Darcy.sh # for Darcy 41 | bash scripts/Transolver_Airfoil.sh # for Airfoil 42 | bash scripts/Transolver_Pipe.sh # for Pipe 43 | ``` 44 | 45 | Note: You need to change the argument `--data_path` to your dataset path. 46 | 47 | 4. Develop your own model. Here are the instructions: 48 | 49 | - Add the model file under the folder `./model/`. 50 | - Add the model name to `./model_dict.py`. 51 | - Add a script file under the folder `./scripts/` and change the argument `--model`. 52 | 53 | ## Visualization 54 | 55 | Transolver handles PDEs on various geometries well, such as predicting future fluid states and estimating the [[shock wave]](https://en.wikipedia.org/wiki/Shock_wave) around an airfoil. 56 | 57 |

58 | 59 |

60 | Figure 1. Case study of different models. 61 |

62 | 63 | ## PDE Solving at Scale 64 | 65 | To align with previous models, we only experiment with an 8-layer Transolver in the main text. In practice, you can easily obtain better performance by **scaling up Transolver**: the relative L2 error generally decreases as more layers are added. 66 | 67 |

68 | 69 |

70 | Figure 2. Scaling up Transolver: relative L2 curve w.r.t. the number of model layers. 71 |

72 | 73 | ## Citation 74 | 75 | If you find this repo useful, please cite our paper. 76 | 77 | ``` 78 | @inproceedings{wu2024Transolver, 79 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries}, 80 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long}, 81 | booktitle={International Conference on Machine Learning}, 82 | year={2024} 83 | } 84 | ``` 85 | 86 | ## Contact 87 | 88 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn). 89 | 90 | ## Acknowledgement 91 | 92 | We appreciate the following github repos a lot for their valuable code base or datasets: 93 | 94 | https://github.com/neuraloperator/neuraloperator 95 | 96 | https://github.com/neuraloperator/Geo-FNO 97 | 98 | https://github.com/thuml/Latent-Spectral-Models 99 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/exp_airfoil.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | from tqdm import * 7 | from utils.testloss import TestLoss 8 | from model_dict import get_model 9 | 10 | parser = argparse.ArgumentParser('Training Transformer') 11 | 12 | parser.add_argument('--lr', type=float, default=1e-3) 13 | parser.add_argument('--epochs', type=int, default=500) 14 | parser.add_argument('--weight_decay', type=float, default=1e-5) 15 | parser.add_argument('--model', type=str, default='Transolver_2D') 16 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim') 17 | parser.add_argument('--n-layers', type=int, default=3, help='layers') 18 | parser.add_argument('--n-heads', type=int, default=4) 19 | parser.add_argument('--batch-size', type=int, default=8) 20 | parser.add_argument("--gpu", type=str, default='0', help="GPU index to use") 21 | 
parser.add_argument('--max_grad_norm', type=float, default=None) 22 | parser.add_argument('--downsamplex', type=int, default=1) 23 | parser.add_argument('--downsampley', type=int, default=1) 24 | parser.add_argument('--mlp_ratio', type=int, default=1) 25 | parser.add_argument('--dropout', type=float, default=0.0) 26 | parser.add_argument('--unified_pos', type=int, default=0) 27 | parser.add_argument('--ref', type=int, default=8) 28 | parser.add_argument('--slice_num', type=int, default=32) 29 | parser.add_argument('--eval', type=int, default=0) 30 | parser.add_argument('--save_name', type=str, default='airfoil_Transolver') 31 | parser.add_argument('--data_path', type=str, default='/data/fno/airfoil/naca') 32 | args = parser.parse_args() 33 | eval = args.eval 34 | save_name = args.save_name 35 | 36 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 37 | 38 | 39 | def count_parameters(model): 40 | total_params = 0 41 | for name, parameter in model.named_parameters(): 42 | if not parameter.requires_grad: continue 43 | params = parameter.numel() 44 | total_params += params 45 | print(f"Total Trainable Params: {total_params}") 46 | return total_params 47 | 48 | 49 | def main(): 50 | INPUT_X = args.data_path + '/NACA_Cylinder_X.npy' 51 | INPUT_Y = args.data_path + '/NACA_Cylinder_Y.npy' 52 | OUTPUT_Sigma = args.data_path + '/NACA_Cylinder_Q.npy' 53 | 54 | ntrain = 1000 55 | ntest = 200 56 | 57 | r1 = args.downsamplex 58 | r2 = args.downsampley 59 | s1 = int(((221 - 1) / r1) + 1) 60 | s2 = int(((51 - 1) / r2) + 1) 61 | 62 | inputX = np.load(INPUT_X) 63 | inputX = torch.tensor(inputX, dtype=torch.float) 64 | inputY = np.load(INPUT_Y) 65 | inputY = torch.tensor(inputY, dtype=torch.float) 66 | input = torch.stack([inputX, inputY], dim=-1) 67 | 68 | output = np.load(OUTPUT_Sigma)[:, 4] 69 | output = torch.tensor(output, dtype=torch.float) 70 | print(input.shape, output.shape) 71 | 72 | x_train = input[:ntrain, ::r1, ::r2][:, :s1, :s2] 73 | y_train = output[:ntrain, ::r1, 
::r2][:, :s1, :s2] 74 | x_test = input[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2] 75 | y_test = output[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2] 76 | x_train = x_train.reshape(ntrain, -1, 2) 77 | x_test = x_test.reshape(ntest, -1, 2) 78 | y_train = y_train.reshape(ntrain, -1) 79 | y_test = y_test.reshape(ntest, -1) 80 | 81 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, x_train, y_train), 82 | batch_size=args.batch_size, 83 | shuffle=True) 84 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, x_test, y_test), 85 | batch_size=args.batch_size, 86 | shuffle=False) 87 | 88 | print("Dataloading is over.") 89 | 90 | model = get_model(args).Model(space_dim=2, 91 | n_layers=args.n_layers, 92 | n_hidden=args.n_hidden, 93 | dropout=args.dropout, 94 | n_head=args.n_heads, 95 | Time_Input=False, 96 | mlp_ratio=args.mlp_ratio, 97 | fun_dim=0, 98 | out_dim=1, 99 | slice_num=args.slice_num, 100 | ref=args.ref, 101 | unified_pos=args.unified_pos, 102 | H=s1, W=s2).cuda() 103 | 104 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 105 | print(args) 106 | print(model) 107 | count_parameters(model) 108 | 109 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs, 110 | steps_per_epoch=len(train_loader)) 111 | myloss = TestLoss(size_average=False) 112 | 113 | if eval: 114 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt")) 115 | model.eval() 116 | if not os.path.exists('./results/' + save_name + '/'): 117 | os.makedirs('./results/' + save_name + '/') 118 | 119 | rel_err = 0.0 120 | showcase = 10 121 | id = 0 122 | 123 | with torch.no_grad(): 124 | for pos, fx, y in test_loader: 125 | id += 1 126 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() 127 | out = model(x, None).squeeze(-1) 128 | 129 | tl = myloss(out, y).item() 130 | rel_err += tl 131 | if id < showcase: 132 | print(id) 133 | plt.axis('off') 
134 | plt.pcolormesh(x[0, :, 0].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 135 | x[0, :, 1].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 136 | np.zeros([140, 35]), 137 | shading='auto', 138 | edgecolors='black', linewidths=0.1) 139 | plt.colorbar() 140 | plt.savefig( 141 | os.path.join('./results/' + save_name + '/', 142 | "input_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 143 | plt.close() 144 | plt.axis('off') 145 | plt.pcolormesh(x[0, :, 0].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 146 | x[0, :, 1].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 147 | out[0, :].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 148 | shading='auto', cmap='coolwarm') 149 | plt.colorbar() 150 | plt.clim(0, 1.2) 151 | plt.savefig( 152 | os.path.join('./results/' + save_name + '/', 153 | "pred_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 154 | plt.close() 155 | plt.axis('off') 156 | plt.pcolormesh(x[0, :, 0].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 157 | x[0, :, 1].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 158 | y[0, :].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 159 | shading='auto', cmap='coolwarm') 160 | plt.colorbar() 161 | plt.clim(0, 1.2) 162 | plt.savefig( 163 | os.path.join('./results/' + save_name + '/', 164 | "gt_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 165 | plt.close() 166 | plt.axis('off') 167 | plt.pcolormesh(x[0, :, 0].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 168 | x[0, :, 1].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 169 | out[0, :].reshape(221, 51)[40:180, :35].detach().cpu().numpy() - \ 170 | y[0, :].reshape(221, 51)[40:180, :35].detach().cpu().numpy(), 171 | shading='auto', cmap='coolwarm') 172 | plt.colorbar() 173 | plt.clim(-0.2, 0.2) 174 | plt.savefig( 175 | os.path.join('./results/' + save_name + '/', 176 | "error_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 177 | plt.close() 178 | 179 | rel_err /= ntest 
180 | print("rel_err:{}".format(rel_err)) 181 | else: 182 | for ep in range(args.epochs): 183 | 184 | model.train() 185 | train_loss = 0 186 | 187 | for pos, fx, y in train_loader: 188 | 189 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() # x:B,N,2 fx:B,N,2 y:B,N 190 | optimizer.zero_grad() 191 | out = model(x, None).squeeze(-1) 192 | loss = myloss(out, y) 193 | loss.backward() 194 | 195 | if args.max_grad_norm is not None: 196 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 197 | optimizer.step() 198 | train_loss += loss.item() 199 | scheduler.step() 200 | 201 | train_loss = train_loss / ntrain 202 | print("Epoch {} Train loss : {:.5f}".format(ep, train_loss)) 203 | 204 | model.eval() 205 | rel_err = 0.0 206 | with torch.no_grad(): 207 | for pos, fx, y in test_loader: 208 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() 209 | out = model(x, None).squeeze(-1) 210 | 211 | tl = myloss(out, y).item() 212 | rel_err += tl 213 | 214 | rel_err /= ntest 215 | print("rel_err:{}".format(rel_err)) 216 | 217 | if ep % 100 == 0: 218 | if not os.path.exists('./checkpoints'): 219 | os.makedirs('./checkpoints') 220 | print('save model') 221 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 222 | 223 | if not os.path.exists('./checkpoints'): 224 | os.makedirs('./checkpoints') 225 | print('save model') 226 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 227 | 228 | 229 | if __name__ == "__main__": 230 | main() 231 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/exp_darcy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import scipy.io as scio 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm import * 8 | from utils.testloss import TestLoss 9 | from einops import rearrange 10 | from model_dict import get_model 11 | from 
utils.normalizer import UnitTransformer 12 | import matplotlib.pyplot as plt 13 | 14 | parser = argparse.ArgumentParser('Training Transolver') 15 | 16 | parser.add_argument('--lr', type=float, default=1e-3) 17 | parser.add_argument('--epochs', type=int, default=500) 18 | parser.add_argument('--weight_decay', type=float, default=1e-5) 19 | parser.add_argument('--model', type=str, default='Transolver_2D') 20 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim') 21 | parser.add_argument('--n-layers', type=int, default=3, help='layers') 22 | parser.add_argument('--n-heads', type=int, default=4) 23 | parser.add_argument('--batch-size', type=int, default=8) 24 | parser.add_argument("--gpu", type=str, default='1', help="GPU index to use") 25 | parser.add_argument('--max_grad_norm', type=float, default=None) 26 | parser.add_argument('--downsample', type=int, default=5) 27 | parser.add_argument('--mlp_ratio', type=int, default=1) 28 | parser.add_argument('--dropout', type=float, default=0.0) 29 | parser.add_argument('--ntrain', type=int, default=1000) 30 | parser.add_argument('--unified_pos', type=int, default=0) 31 | parser.add_argument('--ref', type=int, default=8) 32 | parser.add_argument('--slice_num', type=int, default=32) 33 | parser.add_argument('--eval', type=int, default=0) 34 | parser.add_argument('--save_name', type=str, default='darcy_Transolver') 35 | parser.add_argument('--data_path', type=str, default='/data/fno') 36 | args = parser.parse_args() 37 | 38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 39 | 40 | train_path = args.data_path + '/piececonst_r421_N1024_smooth1.mat' 41 | test_path = args.data_path + '/piececonst_r421_N1024_smooth2.mat' 42 | ntrain = args.ntrain 43 | ntest = 200 44 | epochs = args.epochs 45 | eval = args.eval 46 | save_name = args.save_name 47 | 48 | 49 | def count_parameters(model): 50 | total_params = 0 51 | for name, parameter in model.named_parameters(): 52 | if not parameter.requires_grad: continue 53 | 
params = parameter.numel() 54 | total_params += params 55 | print(f"Total Trainable Params: {total_params}") 56 | return total_params 57 | 58 | 59 | def central_diff(x: torch.Tensor, h, resolution): 60 | # assuming PBC 61 | # x: (batch, n, feats), h is the step size, assuming n = h*w 62 | x = rearrange(x, 'b (h w) c -> b h w c', h=resolution, w=resolution) 63 | x = F.pad(x, 64 | (0, 0, 1, 1, 1, 1), mode='constant', value=0.) # [b c t h+2 w+2] 65 | grad_x = (x[:, 1:-1, 2:, :] - x[:, 1:-1, :-2, :]) / (2 * h) # f(x+h) - f(x-h) / 2h 66 | grad_y = (x[:, 2:, 1:-1, :] - x[:, :-2, 1:-1, :]) / (2 * h) # f(x+h) - f(x-h) / 2h 67 | 68 | return grad_x, grad_y 69 | 70 | 71 | def main(): 72 | r = args.downsample 73 | h = int(((421 - 1) / r) + 1) 74 | s = h 75 | dx = 1.0 / s 76 | 77 | train_data = scio.loadmat(train_path) 78 | x_train = train_data['coeff'][:ntrain, ::r, ::r][:, :s, :s] 79 | x_train = x_train.reshape(ntrain, -1) 80 | x_train = torch.from_numpy(x_train).float() 81 | y_train = train_data['sol'][:ntrain, ::r, ::r][:, :s, :s] 82 | y_train = y_train.reshape(ntrain, -1) 83 | y_train = torch.from_numpy(y_train) 84 | 85 | test_data = scio.loadmat(test_path) 86 | x_test = test_data['coeff'][:ntest, ::r, ::r][:, :s, :s] 87 | x_test = x_test.reshape(ntest, -1) 88 | x_test = torch.from_numpy(x_test).float() 89 | y_test = test_data['sol'][:ntest, ::r, ::r][:, :s, :s] 90 | y_test = y_test.reshape(ntest, -1) 91 | y_test = torch.from_numpy(y_test) 92 | 93 | x_normalizer = UnitTransformer(x_train) 94 | y_normalizer = UnitTransformer(y_train) 95 | 96 | x_train = x_normalizer.encode(x_train) 97 | x_test = x_normalizer.encode(x_test) 98 | y_train = y_normalizer.encode(y_train) 99 | 100 | x_normalizer.cuda() 101 | y_normalizer.cuda() 102 | 103 | x = np.linspace(0, 1, s) 104 | y = np.linspace(0, 1, s) 105 | x, y = np.meshgrid(x, y) 106 | pos = np.c_[x.ravel(), y.ravel()] 107 | pos = torch.tensor(pos, dtype=torch.float).unsqueeze(0) 108 | 109 | pos_train = pos.repeat(ntrain, 1, 1) 110 | 
pos_test = pos.repeat(ntest, 1, 1) 111 | print("Dataloading is over.") 112 | 113 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_train, x_train, y_train), 114 | batch_size=args.batch_size, shuffle=True) 115 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_test, x_test, y_test), 116 | batch_size=args.batch_size, shuffle=False) 117 | 118 | model = get_model(args).Model(space_dim=2, 119 | n_layers=args.n_layers, 120 | n_hidden=args.n_hidden, 121 | dropout=args.dropout, 122 | n_head=args.n_heads, 123 | Time_Input=False, 124 | mlp_ratio=args.mlp_ratio, 125 | fun_dim=1, 126 | out_dim=1, 127 | slice_num=args.slice_num, 128 | ref=args.ref, 129 | unified_pos=args.unified_pos, 130 | H=s, W=s).cuda() 131 | 132 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 133 | 134 | print(args) 135 | print(model) 136 | count_parameters(model) 137 | 138 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=epochs, 139 | steps_per_epoch=len(train_loader)) 140 | myloss = TestLoss(size_average=False) 141 | de_x = TestLoss(size_average=False) 142 | de_y = TestLoss(size_average=False) 143 | 144 | if eval: 145 | print("model evaluation") 146 | print(s, s) 147 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt"), strict=False) 148 | model.eval() 149 | showcase = 10 150 | id = 0 151 | if not os.path.exists('./results/' + save_name + '/'): 152 | os.makedirs('./results/' + save_name + '/') 153 | 154 | with torch.no_grad(): 155 | rel_err = 0.0 156 | with torch.no_grad(): 157 | for x, fx, y in test_loader: 158 | id += 1 159 | x, fx, y = x.cuda(), fx.cuda(), y.cuda() 160 | out = model(x, fx=fx.unsqueeze(-1)).squeeze(-1) 161 | out = y_normalizer.decode(out) 162 | tl = myloss(out, y).item() 163 | 164 | rel_err += tl 165 | 166 | if id < showcase: 167 | print(id) 168 | plt.figure() 169 | plt.axis('off') 170 | plt.imshow(out[0, :].reshape(85, 
85).detach().cpu().numpy(), cmap='coolwarm') 171 | plt.colorbar() 172 | plt.savefig( 173 | os.path.join('./results/' + save_name + '/', 174 | "case_" + str(id) + "_pred.pdf")) 175 | plt.close() 176 | # ============ # 177 | plt.figure() 178 | plt.axis('off') 179 | plt.imshow(y[0, :].reshape(85, 85).detach().cpu().numpy(), cmap='coolwarm') 180 | plt.colorbar() 181 | plt.savefig( 182 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_gt.pdf")) 183 | plt.close() 184 | # ============ # 185 | plt.figure() 186 | plt.axis('off') 187 | plt.imshow((y[0, :] - out[0, :]).reshape(85, 85).detach().cpu().numpy(), cmap='coolwarm') 188 | plt.colorbar() 189 | plt.clim(-0.0005, 0.0005) 190 | plt.savefig( 191 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_error.pdf")) 192 | plt.close() 193 | # ============ # 194 | plt.figure() 195 | plt.axis('off') 196 | plt.imshow((fx[0, :].unsqueeze(-1)).reshape(85, 85).detach().cpu().numpy(), cmap='coolwarm') 197 | plt.colorbar() 198 | plt.savefig( 199 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_input.pdf")) 200 | plt.close() 201 | 202 | rel_err /= ntest 203 | print("rel_err:{}".format(rel_err)) 204 | else: 205 | for ep in range(args.epochs): 206 | model.train() 207 | train_loss = 0 208 | reg = 0 209 | for x, fx, y in train_loader: 210 | x, fx, y = x.cuda(), fx.cuda(), y.cuda() 211 | optimizer.zero_grad() 212 | 213 | out = model(x, fx=fx.unsqueeze(-1)).squeeze(-1) # B, N , 2, fx: B, N, y: B, N 214 | out = y_normalizer.decode(out) 215 | y = y_normalizer.decode(y) 216 | 217 | l2loss = myloss(out, y) 218 | 219 | out = rearrange(out.unsqueeze(-1), 'b (h w) c -> b c h w', h=s) 220 | out = out[..., 1:-1, 1:-1].contiguous() 221 | out = F.pad(out, (1, 1, 1, 1), "constant", 0) 222 | out = rearrange(out, 'b c h w -> b (h w) c') 223 | gt_grad_x, gt_grad_y = central_diff(y.unsqueeze(-1), dx, s) 224 | pred_grad_x, pred_grad_y = central_diff(out, dx, s) 225 | deriv_loss = de_x(pred_grad_x, 
gt_grad_x) + de_y(pred_grad_y, gt_grad_y) 226 | loss = 0.1 * deriv_loss + l2loss 227 | loss.backward() 228 | 229 | if args.max_grad_norm is not None: 230 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 231 | optimizer.step() 232 | train_loss += l2loss.item() 233 | reg += deriv_loss.item() 234 | scheduler.step() 235 | 236 | train_loss /= ntrain 237 | reg /= ntrain 238 | print("Epoch {} Reg : {:.5f} Train loss : {:.5f}".format(ep, reg, train_loss)) 239 | 240 | model.eval() 241 | rel_err = 0.0 242 | id = 0 243 | with torch.no_grad(): 244 | for x, fx, y in test_loader: 245 | id += 1 246 | if id == 2: 247 | vis = True 248 | else: 249 | vis = False 250 | x, fx, y = x.cuda(), fx.cuda(), y.cuda() 251 | out = model(x, fx=fx.unsqueeze(-1)).squeeze(-1) 252 | out = y_normalizer.decode(out) 253 | tl = myloss(out, y).item() 254 | rel_err += tl 255 | 256 | rel_err /= ntest 257 | print("rel_err:{}".format(rel_err)) 258 | 259 | if ep % 100 == 0: 260 | if not os.path.exists('./checkpoints'): 261 | os.makedirs('./checkpoints') 262 | print('save model') 263 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 264 | 265 | if not os.path.exists('./checkpoints'): 266 | os.makedirs('./checkpoints') 267 | print('save model') 268 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 269 | 270 | 271 | if __name__ == "__main__": 272 | main() 273 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/exp_elas.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | from tqdm import * 7 | from utils.testloss import TestLoss 8 | from model_dict import get_model 9 | from utils.normalizer import UnitTransformer 10 | 11 | parser = argparse.ArgumentParser('Training Transformer') 12 | 13 | parser.add_argument('--lr', 
type=float, default=1e-3) 14 | parser.add_argument('--epochs', type=int, default=500) 15 | parser.add_argument('--weight_decay', type=float, default=1e-5) 16 | parser.add_argument('--model', type=str, default='Transolver_1D') 17 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim') 18 | parser.add_argument('--n-layers', type=int, default=3, help='layers') 19 | parser.add_argument('--n-heads', type=int, default=4) 20 | parser.add_argument('--batch-size', type=int, default=8) 21 | parser.add_argument("--gpu", type=str, default='1', help="GPU index to use") 22 | parser.add_argument('--max_grad_norm', type=float, default=None) 23 | parser.add_argument('--downsample', type=int, default=5) 24 | parser.add_argument('--mlp_ratio', type=int, default=1) 25 | parser.add_argument('--dropout', type=float, default=0.0) 26 | parser.add_argument('--ntrain', type=int, default=1000) 27 | parser.add_argument('--unified_pos', type=int, default=0) 28 | parser.add_argument('--ref', type=int, default=8) 29 | parser.add_argument('--slice_num', type=int, default=32) 30 | parser.add_argument('--eval', type=int, default=0) 31 | parser.add_argument('--save_name', type=str, default='elas_Transolver') 32 | parser.add_argument('--data_path', type=str, default='/data/fno') 33 | args = parser.parse_args() 34 | eval = args.eval 35 | save_name = args.save_name 36 | 37 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 38 | 39 | 40 | def count_parameters(model): 41 | total_params = 0 42 | for name, parameter in model.named_parameters(): 43 | if not parameter.requires_grad: continue 44 | params = parameter.numel() 45 | total_params += params 46 | print(f"Total Trainable Params: {total_params}") 47 | return total_params 48 | 49 | 50 | def main(): 51 | ntrain = args.ntrain 52 | ntest = 200 53 | 54 | PATH_Sigma = args.data_path + '/elasticity/Meshes/Random_UnitCell_sigma_10.npy' 55 | PATH_XY = args.data_path + '/elasticity/Meshes/Random_UnitCell_XY_10.npy' 56 | 57 | input_s = 
np.load(PATH_Sigma) 58 | input_s = torch.tensor(input_s, dtype=torch.float).permute(1, 0) 59 | input_xy = np.load(PATH_XY) 60 | input_xy = torch.tensor(input_xy, dtype=torch.float).permute(2, 0, 1) 61 | 62 | train_s = input_s[:ntrain] 63 | test_s = input_s[-ntest:] 64 | train_xy = input_xy[:ntrain] 65 | test_xy = input_xy[-ntest:] 66 | 67 | print(input_s.shape, input_xy.shape) 68 | 69 | y_normalizer = UnitTransformer(train_s) 70 | 71 | train_s = y_normalizer.encode(train_s) 72 | y_normalizer.cuda() 73 | 74 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_xy, train_xy, train_s), 75 | batch_size=args.batch_size, 76 | shuffle=True) 77 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_xy, test_xy, test_s), 78 | batch_size=args.batch_size, 79 | shuffle=False) 80 | 81 | print("Dataloading is over.") 82 | 83 | model = get_model(args).Model(space_dim=2, 84 | n_layers=args.n_layers, 85 | n_hidden=args.n_hidden, 86 | dropout=args.dropout, 87 | n_head=args.n_heads, 88 | Time_Input=False, 89 | mlp_ratio=args.mlp_ratio, 90 | fun_dim=0, 91 | out_dim=1, 92 | slice_num=args.slice_num, 93 | ref=args.ref, 94 | unified_pos=args.unified_pos).cuda() 95 | 96 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 97 | 98 | print(args) 99 | print(model) 100 | count_parameters(model) 101 | 102 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) 103 | 104 | myloss = TestLoss(size_average=False) 105 | 106 | if eval: 107 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt")) 108 | model.eval() 109 | if not os.path.exists('./results/' + save_name + '/'): 110 | os.makedirs('./results/' + save_name + '/') 111 | rel_err = 0.0 112 | showcase = 10 113 | id = 0 114 | 115 | with torch.no_grad(): 116 | for pos, fx, y in test_loader: 117 | id += 1 118 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() 119 | out = model(x, None).squeeze(-1) 120 | out = 
y_normalizer.decode(out) 121 | tl = myloss(out, y).item() 122 | rel_err += tl 123 | if id < showcase: 124 | print(id) 125 | plt.axis('off') 126 | plt.scatter(x=fx[0, :, 0].detach().cpu().numpy(), y=fx[0, :, 1].detach().cpu().numpy(), 127 | c=y[0, :].detach().cpu().numpy(), cmap='coolwarm') 128 | plt.colorbar() 129 | plt.clim(0, 1000) 130 | plt.savefig( 131 | os.path.join('./results/' + save_name + '/', 132 | "gt_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 133 | plt.close() 134 | 135 | plt.axis('off') 136 | plt.scatter(x=fx[0, :, 0].detach().cpu().numpy(), y=fx[0, :, 1].detach().cpu().numpy(), 137 | c=out[0, :].detach().cpu().numpy(), cmap='coolwarm') 138 | plt.colorbar() 139 | plt.clim(0, 1000) 140 | plt.savefig( 141 | os.path.join('./results/' + save_name + '/', 142 | "pred_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 143 | plt.close() 144 | 145 | plt.axis('off') 146 | plt.scatter(x=fx[0, :, 0].detach().cpu().numpy(), y=fx[0, :, 1].detach().cpu().numpy(), 147 | c=((y[0, :] - out[0, :])).detach().cpu().numpy(), cmap='coolwarm') 148 | plt.clim(-8, 8) 149 | plt.colorbar() 150 | plt.savefig( 151 | os.path.join('./results/' + save_name + '/', 152 | "error_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 153 | plt.close() 154 | 155 | rel_err /= ntest 156 | print("rel_err : {}".format(rel_err)) 157 | else: 158 | for ep in range(args.epochs): 159 | 160 | model.train() 161 | train_loss = 0 162 | 163 | for pos, fx, y in train_loader: 164 | 165 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() # x:B,N,2 fx:B,N,2 y:B,N, 166 | optimizer.zero_grad() 167 | out = model(x, None).squeeze(-1) 168 | out = y_normalizer.decode(out) 169 | y = y_normalizer.decode(y) 170 | loss = myloss(out, y) 171 | loss.backward() 172 | 173 | if args.max_grad_norm is not None: 174 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 175 | optimizer.step() 176 | train_loss += loss.item() 177 | scheduler.step() 178 | 179 | train_loss = train_loss / ntrain 
180 | print("Epoch {} Train loss : {:.5f}".format(ep, train_loss)) 181 | 182 | model.eval() 183 | rel_err = 0.0 184 | with torch.no_grad(): 185 | for pos, fx, y in test_loader: 186 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() 187 | out = model(x, None).squeeze(-1) 188 | out = y_normalizer.decode(out) 189 | tl = myloss(out, y).item() 190 | rel_err += tl 191 | 192 | rel_err /= ntest 193 | print("rel_err : {}".format(rel_err)) 194 | 195 | if ep % 100 == 0: 196 | if not os.path.exists('./checkpoints'): 197 | os.makedirs('./checkpoints') 198 | print('save model') 199 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 200 | 201 | if not os.path.exists('./checkpoints'): 202 | os.makedirs('./checkpoints') 203 | print('save model') 204 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 205 | 206 | 207 | if __name__ == "__main__": 208 | main() 209 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/exp_ns.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import argparse 4 | import scipy.io as scio 5 | import numpy as np 6 | import torch 7 | from tqdm import * 8 | from utils.testloss import TestLoss 9 | from model_dict import get_model 10 | 11 | parser = argparse.ArgumentParser('Training Transformer') 12 | 13 | parser.add_argument('--lr', type=float, default=1e-3) 14 | parser.add_argument('--epochs', type=int, default=500) 15 | parser.add_argument('--weight_decay', type=float, default=1e-5) 16 | parser.add_argument('--model', type=str, default='Transolver_2D') 17 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim') 18 | parser.add_argument('--n-layers', type=int, default=3, help='layers') 19 | parser.add_argument('--n-heads', type=int, default=4) 20 | parser.add_argument('--batch-size', type=int, default=8) 21 | parser.add_argument("--gpu", 
type=str, default='0', help="GPU index to use") 22 | parser.add_argument('--max_grad_norm', type=float, default=None) 23 | parser.add_argument('--downsample', type=int, default=1) 24 | parser.add_argument('--mlp_ratio', type=int, default=1) 25 | parser.add_argument('--dropout', type=float, default=0.0) 26 | parser.add_argument('--unified_pos', type=int, default=0) 27 | parser.add_argument('--ref', type=int, default=8) 28 | parser.add_argument('--slice_num', type=int, default=32) 29 | parser.add_argument('--eval', type=int, default=0) 30 | parser.add_argument('--save_name', type=str, default='ns_2d_UniPDE') 31 | parser.add_argument('--data_path', type=str, default='/data/fno') 32 | args = parser.parse_args() 33 | 34 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 35 | 36 | data_path = args.data_path + '/NavierStokes_V1e-5_N1200_T20/NavierStokes_V1e-5_N1200_T20.mat' 37 | # data_path = args.data_path + '/NavierStokes_V1e-5_N1200_T20.mat' 38 | ntrain = 1000 39 | ntest = 200 40 | T_in = 10 41 | T = 10 42 | step = 1 43 | eval = args.eval 44 | save_name = args.save_name 45 | 46 | 47 | def count_parameters(model): 48 | total_params = 0 49 | for name, parameter in model.named_parameters(): 50 | if not parameter.requires_grad: continue 51 | params = parameter.numel() 52 | total_params += params 53 | print(f"Total Trainable Params: {total_params}") 54 | return total_params 55 | 56 | 57 | def main(): 58 | r = args.downsample 59 | h = int(((64 - 1) / r) + 1) 60 | 61 | data = scio.loadmat(data_path) 62 | print(data['u'].shape) 63 | train_a = data['u'][:ntrain, ::r, ::r, :T_in][:, :h, :h, :] 64 | train_a = train_a.reshape(train_a.shape[0], -1, train_a.shape[-1]) 65 | train_a = torch.from_numpy(train_a) 66 | train_u = data['u'][:ntrain, ::r, ::r, T_in:T + T_in][:, :h, :h, :] 67 | train_u = train_u.reshape(train_u.shape[0], -1, train_u.shape[-1]) 68 | train_u = torch.from_numpy(train_u) 69 | 70 | test_a = data['u'][-ntest:, ::r, ::r, :T_in][:, :h, :h, :] 71 | test_a = 
test_a.reshape(test_a.shape[0], -1, test_a.shape[-1]) 72 | test_a = torch.from_numpy(test_a) 73 | test_u = data['u'][-ntest:, ::r, ::r, T_in:T + T_in][:, :h, :h, :] 74 | test_u = test_u.reshape(test_u.shape[0], -1, test_u.shape[-1]) 75 | test_u = torch.from_numpy(test_u) 76 | 77 | x = np.linspace(0, 1, h) 78 | y = np.linspace(0, 1, h) 79 | x, y = np.meshgrid(x, y) 80 | pos = np.c_[x.ravel(), y.ravel()] 81 | pos = torch.tensor(pos, dtype=torch.float).unsqueeze(0) 82 | pos_train = pos.repeat(ntrain, 1, 1) 83 | pos_test = pos.repeat(ntest, 1, 1) 84 | 85 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_train, train_a, train_u), 86 | batch_size=args.batch_size, shuffle=True) 87 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_test, test_a, test_u), 88 | batch_size=args.batch_size, shuffle=False) 89 | 90 | print("Dataloading is over.") 91 | 92 | model = get_model(args).Model(space_dim=2, 93 | n_layers=args.n_layers, 94 | n_hidden=args.n_hidden, 95 | dropout=args.dropout, 96 | n_head=args.n_heads, 97 | Time_Input=False, 98 | mlp_ratio=args.mlp_ratio, 99 | fun_dim=T_in, 100 | out_dim=1, 101 | slice_num=args.slice_num, 102 | ref=args.ref, 103 | unified_pos=args.unified_pos, 104 | H=h, W=h).cuda() 105 | 106 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 107 | 108 | print(args) 109 | print(model) 110 | count_parameters(model) 111 | 112 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs, 113 | steps_per_epoch=len(train_loader)) 114 | myloss = TestLoss(size_average=False) 115 | 116 | if eval: 117 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt"), strict=False) 118 | model.eval() 119 | showcase = 10 120 | id = 0 121 | 122 | if not os.path.exists('./results/' + save_name + '/'): 123 | os.makedirs('./results/' + save_name + '/') 124 | 125 | test_l2_full = 0 126 | with torch.no_grad(): 127 | for x, fx, yy 
in test_loader: 128 | id += 1 129 | x, fx, yy = x.cuda(), fx.cuda(), yy.cuda() # x : B, 4096, 2 fx : B, 4096 y : B, 4096, T 130 | bsz = x.shape[0] 131 | for t in range(0, T, step): 132 | im = model(x, fx=fx) 133 | 134 | fx = torch.cat((fx[..., step:], im), dim=-1) 135 | if t == 0: 136 | pred = im 137 | else: 138 | pred = torch.cat((pred, im), -1) 139 | 140 | if id < showcase: 141 | print(id) 142 | plt.figure() 143 | plt.axis('off') 144 | plt.imshow(im[0, :, 0].reshape(64, 64).detach().cpu().numpy(), cmap='coolwarm') 145 | plt.colorbar() 146 | plt.clim(-3, 3) 147 | plt.savefig( 148 | os.path.join('./results/' + save_name + '/', 149 | "case_" + str(id) + "_pred_" + str(20) + ".pdf")) 150 | plt.close() 151 | # ============ # 152 | plt.figure() 153 | plt.axis('off') 154 | plt.imshow(yy[0, :, t].reshape(64, 64).detach().cpu().numpy(), cmap='coolwarm') 155 | plt.colorbar() 156 | plt.clim(-3, 3) 157 | plt.savefig( 158 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_gt_" + str(20) + ".pdf")) 159 | plt.close() 160 | # ============ # 161 | plt.figure() 162 | plt.axis('off') 163 | plt.imshow((im[0, :, 0].reshape(64, 64) - yy[0, :, t].reshape(64, 64)).detach().cpu().numpy(), 164 | cmap='coolwarm') 165 | plt.colorbar() 166 | plt.clim(-2, 2) 167 | plt.savefig( 168 | os.path.join('./results/' + save_name + '/', "case_" + str(id) + "_error_" + str(20) + ".pdf")) 169 | plt.close() 170 | test_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item() 171 | print(test_l2_full / ntest) 172 | else: 173 | for ep in range(args.epochs): 174 | 175 | model.train() 176 | train_l2_step = 0 177 | train_l2_full = 0 178 | 179 | for x, fx, yy in train_loader: 180 | loss = 0 181 | x, fx, yy = x.cuda(), fx.cuda(), yy.cuda() # x: B,4096,2 fx: B,4096,T y: B,4096,T 182 | bsz = x.shape[0] 183 | 184 | for t in range(0, T, step): 185 | y = yy[..., t:t + step] 186 | im = model(x, fx=fx) # B , 4096 , 1 187 | loss += myloss(im.reshape(bsz, -1), y.reshape(bsz, -1)) 188 | if t 
== 0: 189 | pred = im 190 | else: 191 | pred = torch.cat((pred, im), -1) 192 | fx = torch.cat((fx[..., step:], y), dim=-1) # detach() & groundtruth 193 | 194 | train_l2_step += loss.item() 195 | train_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item() 196 | optimizer.zero_grad() 197 | loss.backward() 198 | if args.max_grad_norm is not None: 199 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 200 | optimizer.step() 201 | scheduler.step() 202 | 203 | test_l2_step = 0 204 | test_l2_full = 0 205 | 206 | model.eval() 207 | 208 | with torch.no_grad(): 209 | for x, fx, yy in test_loader: 210 | loss = 0 211 | x, fx, yy = x.cuda(), fx.cuda(), yy.cuda() # x : B, 4096, 2 fx : B, 4096 y : B, 4096, T 212 | bsz = x.shape[0] 213 | for t in range(0, T, step): 214 | y = yy[..., t:t + step] 215 | im = model(x, fx=fx) 216 | loss += myloss(im.reshape(bsz, -1), y.reshape(bsz, -1)) 217 | if t == 0: 218 | pred = im 219 | else: 220 | pred = torch.cat((pred, im), -1) 221 | fx = torch.cat((fx[..., step:], im), dim=-1) 222 | 223 | test_l2_step += loss.item() 224 | test_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item() 225 | 226 | print( 227 | "Epoch {} , train_step_loss:{:.5f} , train_full_loss:{:.5f} , test_step_loss:{:.5f} , test_full_loss:{:.5f}".format( 228 | ep, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step), 229 | test_l2_full / ntest)) 230 | 231 | if ep % 100 == 0: 232 | if not os.path.exists('./checkpoints'): 233 | os.makedirs('./checkpoints') 234 | print('save model') 235 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 236 | 237 | if not os.path.exists('./checkpoints'): 238 | os.makedirs('./checkpoints') 239 | print('save model') 240 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 241 | 242 | 243 | if __name__ == "__main__": 244 | main() 245 | 
-------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/exp_pipe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | 5 | parser = argparse.ArgumentParser('Training Transformer') 6 | 7 | parser.add_argument('--lr', type=float, default=1e-3) 8 | parser.add_argument('--epochs', type=int, default=500) 9 | parser.add_argument('--weight_decay', type=float, default=1e-5) 10 | parser.add_argument('--model', type=str, default='Transolver_2D') 11 | parser.add_argument('--n-hidden', type=int, default=64, help='hidden dim') 12 | parser.add_argument('--n-layers', type=int, default=3, help='layers') 13 | parser.add_argument('--n-heads', type=int, default=4) 14 | parser.add_argument('--batch-size', type=int, default=8) 15 | parser.add_argument("--gpu", type=str, default='0', help="GPU index to use") 16 | parser.add_argument('--max_grad_norm', type=float, default=None) 17 | parser.add_argument('--downsamplex', type=int, default=1) 18 | parser.add_argument('--downsampley', type=int, default=1) 19 | parser.add_argument('--mlp_ratio', type=int, default=1) 20 | parser.add_argument('--dropout', type=float, default=0.0) 21 | parser.add_argument('--unified_pos', type=int, default=0) 22 | parser.add_argument('--ref', type=int, default=8) 23 | parser.add_argument('--slice_num', type=int, default=32) 24 | parser.add_argument('--eval', type=int, default=0) 25 | parser.add_argument('--save_name', type=str, default='pipe_UniPDE') 26 | parser.add_argument('--data_path', type=str, default='/data/fno/pipe') 27 | args = parser.parse_args() 28 | eval = args.eval 29 | save_name = args.save_name 30 | 31 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 32 | 33 | import numpy as np 34 | import torch 35 | from tqdm import * 36 | from utils.testloss import TestLoss 37 | from model_dict import get_model 38 | from utils.normalizer import 
UnitTransformer 39 | 40 | 41 | def count_parameters(model): 42 | total_params = 0 43 | for name, parameter in model.named_parameters(): 44 | if not parameter.requires_grad: continue 45 | params = parameter.numel() 46 | total_params += params 47 | print(f"Total Trainable Params: {total_params}") 48 | return total_params 49 | 50 | 51 | def main(): 52 | INPUT_X = args.data_path + '/Pipe_X.npy' 53 | INPUT_Y = args.data_path + '/Pipe_Y.npy' 54 | OUTPUT_Sigma = args.data_path + '/Pipe_Q.npy' 55 | 56 | ntrain = 1000 57 | ntest = 200 58 | N = 1200 59 | 60 | r1 = args.downsamplex 61 | r2 = args.downsampley 62 | s1 = int(((129 - 1) / r1) + 1) 63 | s2 = int(((129 - 1) / r2) + 1) 64 | 65 | inputX = np.load(INPUT_X) 66 | inputX = torch.tensor(inputX, dtype=torch.float) 67 | inputY = np.load(INPUT_Y) 68 | inputY = torch.tensor(inputY, dtype=torch.float) 69 | input = torch.stack([inputX, inputY], dim=-1) 70 | 71 | output = np.load(OUTPUT_Sigma)[:, 0] 72 | output = torch.tensor(output, dtype=torch.float) 73 | print(input.shape, output.shape) 74 | x_train = input[:N][:ntrain, ::r1, ::r2][:, :s1, :s2] 75 | y_train = output[:N][:ntrain, ::r1, ::r2][:, :s1, :s2] 76 | x_test = input[:N][-ntest:, ::r1, ::r2][:, :s1, :s2] 77 | y_test = output[:N][-ntest:, ::r1, ::r2][:, :s1, :s2] 78 | x_train = x_train.reshape(ntrain, -1, 2) 79 | x_test = x_test.reshape(ntest, -1, 2) 80 | y_train = y_train.reshape(ntrain, -1) 81 | y_test = y_test.reshape(ntest, -1) 82 | 83 | x_normalizer = UnitTransformer(x_train) 84 | y_normalizer = UnitTransformer(y_train) 85 | 86 | x_train = x_normalizer.encode(x_train) 87 | x_test = x_normalizer.encode(x_test) 88 | y_train = y_normalizer.encode(y_train) 89 | 90 | x_normalizer.cuda() 91 | y_normalizer.cuda() 92 | 93 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, x_train, y_train), 94 | batch_size=args.batch_size, 95 | shuffle=True) 96 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, x_test, 
y_test), 97 | batch_size=args.batch_size, 98 | shuffle=False) 99 | 100 | print("Dataloading is over.") 101 | 102 | model = get_model(args).Model(space_dim=2, 103 | n_layers=args.n_layers, 104 | n_hidden=args.n_hidden, 105 | dropout=args.dropout, 106 | n_head=args.n_heads, 107 | Time_Input=False, 108 | mlp_ratio=args.mlp_ratio, 109 | fun_dim=0, 110 | out_dim=1, 111 | slice_num=args.slice_num, 112 | ref=args.ref, 113 | unified_pos=args.unified_pos, 114 | H=s1, W=s2).cuda() 115 | 116 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 117 | 118 | print(args) 119 | print(model) 120 | count_parameters(model) 121 | 122 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) 123 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs, 124 | steps_per_epoch=len(train_loader)) 125 | myloss = TestLoss(size_average=False) 126 | 127 | if eval: 128 | model.load_state_dict(torch.load("./checkpoints/" + save_name + ".pt"), strict=False) 129 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '_resave' + '.pt')) 130 | model.eval() 131 | if not os.path.exists('./results/' + save_name + '/'): 132 | os.makedirs('./results/' + save_name + '/') 133 | 134 | rel_err = 0.0 135 | showcase = 10 136 | id = 0 137 | 138 | with torch.no_grad(): 139 | for pos, fx, y in test_loader: 140 | id += 1 141 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() 142 | out = model(x, None).squeeze(-1) 143 | out = y_normalizer.decode(out) 144 | 145 | tl = myloss(out, y).item() 146 | rel_err += tl 147 | 148 | if id < showcase: 149 | print(id) 150 | plt.axis('off') 151 | plt.pcolormesh(x[0, :, 0].reshape(129, 129).detach().cpu().numpy(), 152 | x[0, :, 1].reshape(129, 129).detach().cpu().numpy(), 153 | np.zeros([129, 129]), 154 | shading='auto', 155 | edgecolors='black', linewidths=0.1) 156 | plt.colorbar() 157 | plt.savefig( 158 | os.path.join('./results/' + save_name + '/', 159 | 
"input_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 160 | plt.close() 161 | plt.axis('off') 162 | plt.pcolormesh(x[0, :, 0].reshape(129, 129).detach().cpu().numpy(), 163 | x[0, :, 1].reshape(129, 129).detach().cpu().numpy(), 164 | out[0, :].reshape(129, 129).detach().cpu().numpy(), 165 | shading='auto', cmap='coolwarm') 166 | plt.colorbar() 167 | plt.clim(0, 0.3) 168 | plt.savefig( 169 | os.path.join('./results/' + save_name + '/', 170 | "pred_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 171 | plt.close() 172 | plt.axis('off') 173 | plt.pcolormesh(x[0, :, 0].reshape(129, 129).detach().cpu().numpy(), 174 | x[0, :, 1].reshape(129, 129).detach().cpu().numpy(), 175 | y[0, :].reshape(129, 129).detach().cpu().numpy(), 176 | shading='auto', cmap='coolwarm') 177 | plt.colorbar() 178 | plt.clim(0, 0.3) 179 | plt.savefig( 180 | os.path.join('./results/' + save_name + '/', 181 | "gt_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 182 | plt.close() 183 | plt.axis('off') 184 | plt.pcolormesh(x[0, :, 0].reshape(129, 129).detach().cpu().numpy(), 185 | x[0, :, 1].reshape(129, 129).detach().cpu().numpy(), 186 | out[0, :].reshape(129, 129).detach().cpu().numpy() - \ 187 | y[0, :].reshape(129, 129).detach().cpu().numpy(), 188 | shading='auto', cmap='coolwarm') 189 | plt.colorbar() 190 | plt.clim(-0.02, 0.02) 191 | plt.savefig( 192 | os.path.join('./results/' + save_name + '/', 193 | "error_" + str(id) + ".pdf"), bbox_inches='tight', pad_inches=0) 194 | plt.close() 195 | 196 | rel_err /= ntest 197 | print("rel_err:{}".format(rel_err)) 198 | else: 199 | for ep in range(args.epochs): 200 | 201 | model.train() 202 | train_loss = 0 203 | 204 | for pos, fx, y in train_loader: 205 | 206 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() # x:B,N,2 fx:B,N,2 y:B,N 207 | optimizer.zero_grad() 208 | out = model(x, None).squeeze(-1) 209 | 210 | out = y_normalizer.decode(out) 211 | y = y_normalizer.decode(y) 212 | 213 | loss = myloss(out, y) 214 | loss.backward() 215 
| 216 | # print("loss:{}".format(loss.item()/batch_size)) 217 | if args.max_grad_norm is not None: 218 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 219 | optimizer.step() 220 | train_loss += loss.item() 221 | scheduler.step() 222 | 223 | train_loss = train_loss / ntrain 224 | print("Epoch {} Train loss : {:.5f}".format(ep, train_loss)) 225 | 226 | model.eval() 227 | rel_err = 0.0 228 | with torch.no_grad(): 229 | for pos, fx, y in test_loader: 230 | x, fx, y = pos.cuda(), fx.cuda(), y.cuda() 231 | out = model(x, None).squeeze(-1) 232 | out = y_normalizer.decode(out) 233 | 234 | tl = myloss(out, y).item() 235 | rel_err += tl 236 | 237 | rel_err /= ntest 238 | print("rel_err:{}".format(rel_err)) 239 | 240 | if ep % 100 == 0: 241 | if not os.path.exists('./checkpoints'): 242 | os.makedirs('./checkpoints') 243 | print('save model') 244 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 245 | 246 | if not os.path.exists('./checkpoints'): 247 | os.makedirs('./checkpoints') 248 | print('save model') 249 | torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt')) 250 | 251 | 252 | if __name__ == "__main__": 253 | main() 254 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/fig/scalibility.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/PDE-Solving-StandardBenchmark/fig/scalibility.png -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/fig/showcase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/PDE-Solving-StandardBenchmark/fig/showcase.png 
-------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/fig/standard_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/PDE-Solving-StandardBenchmark/fig/standard_benchmark.png -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/model/Embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | 6 | 7 | class RotaryEmbedding(nn.Module): 8 | def __init__(self, dim, min_freq=1 / 2, scale=1.): 9 | super().__init__() 10 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 11 | self.min_freq = min_freq 12 | self.scale = scale 13 | self.register_buffer('inv_freq', inv_freq) 14 | 15 | def forward(self, coordinates, device): 16 | # coordinates [b, n] 17 | t = coordinates.to(device).type_as(self.inv_freq) 18 | t = t * (self.scale / self.min_freq) 19 | freqs = torch.einsum('... i , j -> ... i j', t, self.inv_freq) # [b, n, d//2] 20 | return torch.cat((freqs, freqs), dim=-1) # [b, n, d] 21 | 22 | 23 | def rotate_half(x): 24 | x = rearrange(x, '... (j d) -> ... j d', j=2) 25 | x1, x2 = x.unbind(dim=-2) 26 | return torch.cat((-x2, x1), dim=-1) 27 | 28 | 29 | def apply_rotary_pos_emb(t, freqs): 30 | return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) 31 | 32 | 33 | def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y): 34 | # split t into first half and second half 35 | # t: [b, h, n, d] 36 | # freq_x/y: [b, n, d] 37 | d = t.shape[-1] 38 | t_x, t_y = t[..., :d // 2], t[..., d // 2:] 39 | 40 | return torch.cat((apply_rotary_pos_emb(t_x, freqs_x), 41 | apply_rotary_pos_emb(t_y, freqs_y)), dim=-1) 42 | 43 | 44 | class PositionalEncoding(nn.Module): 45 | "Implement the PE function." 
46 | 47 | def __init__(self, d_model, dropout, max_len=421 * 421): 48 | super(PositionalEncoding, self).__init__() 49 | self.dropout = nn.Dropout(p=dropout) 50 | 51 | # Compute the positional encodings once in log space. 52 | pe = torch.zeros(max_len, d_model) 53 | position = torch.arange(0, max_len).unsqueeze(1) 54 | div_term = torch.exp( 55 | torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) 56 | ) 57 | pe[:, 0::2] = torch.sin(position * div_term) 58 | pe[:, 1::2] = torch.cos(position * div_term) 59 | pe = pe.unsqueeze(0) 60 | self.register_buffer("pe", pe) 61 | 62 | def forward(self, x): 63 | x = x + self.pe[:, : x.size(1)].requires_grad_(False) 64 | return self.dropout(x) 65 | 66 | 67 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 68 | """ 69 | Create sinusoidal timestep embeddings. 70 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 71 | These may be fractional. 72 | :param dim: the dimension of the output. 73 | :param max_period: controls the minimum frequency of the embeddings. 74 | :return: an [N x dim] Tensor of positional embeddings. 
75 | """ 76 | 77 | half = dim // 2 78 | freqs = torch.exp( 79 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 80 | ).to(device=timesteps.device) 81 | args = timesteps[:, None].float() * freqs[None] 82 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 83 | if dim % 2: 84 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 85 | return embedding 86 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/model/Physics_Attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from einops import rearrange, repeat 4 | 5 | 6 | class Physics_Attention_Irregular_Mesh(nn.Module): 7 | ## for irregular meshes in 1D, 2D or 3D space 8 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64): 9 | super().__init__() 10 | inner_dim = dim_head * heads 11 | self.dim_head = dim_head 12 | self.heads = heads 13 | self.scale = dim_head ** -0.5 14 | self.softmax = nn.Softmax(dim=-1) 15 | self.dropout = nn.Dropout(dropout) 16 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) 17 | 18 | self.in_project_x = nn.Linear(dim, inner_dim) 19 | self.in_project_fx = nn.Linear(dim, inner_dim) 20 | self.in_project_slice = nn.Linear(dim_head, slice_num) 21 | for l in [self.in_project_slice]: 22 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization 23 | self.to_q = nn.Linear(dim_head, dim_head, bias=False) 24 | self.to_k = nn.Linear(dim_head, dim_head, bias=False) 25 | self.to_v = nn.Linear(dim_head, dim_head, bias=False) 26 | self.to_out = nn.Sequential( 27 | nn.Linear(inner_dim, dim), 28 | nn.Dropout(dropout) 29 | ) 30 | 31 | def forward(self, x): 32 | # B N C 33 | B, N, C = x.shape 34 | 35 | ### (1) Slice 36 | fx_mid = self.in_project_fx(x).reshape(B, N, self.heads, self.dim_head) \ 37 | .permute(0, 2, 1, 3).contiguous() # B 
H N C 38 | x_mid = self.in_project_x(x).reshape(B, N, self.heads, self.dim_head) \ 39 | .permute(0, 2, 1, 3).contiguous() # B H N C 40 | slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature) # B H N G 41 | slice_norm = slice_weights.sum(2) # B H G 42 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) 43 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head)) 44 | 45 | ### (2) Attention among slice tokens 46 | q_slice_token = self.to_q(slice_token) 47 | k_slice_token = self.to_k(slice_token) 48 | v_slice_token = self.to_v(slice_token) 49 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale 50 | attn = self.softmax(dots) 51 | attn = self.dropout(attn) 52 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D 53 | 54 | ### (3) Deslice 55 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) 56 | out_x = rearrange(out_x, 'b h n d -> b n (h d)') 57 | return self.to_out(out_x) 58 | 59 | 60 | class Physics_Attention_Structured_Mesh_2D(nn.Module): 61 | ## for structured mesh in 2D space 62 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64, H=101, W=31, kernel=3): # kernel=3): 63 | super().__init__() 64 | inner_dim = dim_head * heads 65 | self.dim_head = dim_head 66 | self.heads = heads 67 | self.scale = dim_head ** -0.5 68 | self.softmax = nn.Softmax(dim=-1) 69 | self.dropout = nn.Dropout(dropout) 70 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) 71 | self.H = H 72 | self.W = W 73 | 74 | self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) 75 | self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) 76 | self.in_project_slice = nn.Linear(dim_head, slice_num) 77 | for l in [self.in_project_slice]: 78 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization 79 | self.to_q = nn.Linear(dim_head, dim_head, bias=False) 80 | self.to_k = 
nn.Linear(dim_head, dim_head, bias=False) 81 | self.to_v = nn.Linear(dim_head, dim_head, bias=False) 82 | 83 | self.to_out = nn.Sequential( 84 | nn.Linear(inner_dim, dim), 85 | nn.Dropout(dropout) 86 | ) 87 | 88 | def forward(self, x): 89 | # B N C 90 | B, N, C = x.shape 91 | x = x.reshape(B, self.H, self.W, C).contiguous().permute(0, 3, 1, 2).contiguous() # B C H W 92 | 93 | ### (1) Slice 94 | fx_mid = self.in_project_fx(x).permute(0, 2, 3, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \ 95 | .permute(0, 2, 1, 3).contiguous() # B H N C 96 | x_mid = self.in_project_x(x).permute(0, 2, 3, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \ 97 | .permute(0, 2, 1, 3).contiguous() # B H N G 98 | slice_weights = self.softmax( 99 | self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5)) # B H N G 100 | slice_norm = slice_weights.sum(2) # B H G 101 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) 102 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head)) 103 | 104 | ### (2) Attention among slice tokens 105 | q_slice_token = self.to_q(slice_token) 106 | k_slice_token = self.to_k(slice_token) 107 | v_slice_token = self.to_v(slice_token) 108 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale 109 | attn = self.softmax(dots) 110 | attn = self.dropout(attn) 111 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D 112 | 113 | ### (3) Deslice 114 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) 115 | out_x = rearrange(out_x, 'b h n d -> b n (h d)') 116 | return self.to_out(out_x) 117 | 118 | 119 | class Physics_Attention_Structured_Mesh_3D(nn.Module): 120 | ## for structured mesh in 3D space 121 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=32, H=32, W=32, D=32, kernel=3): 122 | super().__init__() 123 | inner_dim = dim_head * heads 124 | self.dim_head = dim_head 125 | self.heads = heads 
126 | self.scale = dim_head ** -0.5 127 | self.softmax = nn.Softmax(dim=-1) 128 | self.dropout = nn.Dropout(dropout) 129 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) 130 | self.H = H 131 | self.W = W 132 | self.D = D 133 | 134 | self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) 135 | self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) 136 | self.in_project_slice = nn.Linear(dim_head, slice_num) 137 | for l in [self.in_project_slice]: 138 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization 139 | self.to_q = nn.Linear(dim_head, dim_head, bias=False) 140 | self.to_k = nn.Linear(dim_head, dim_head, bias=False) 141 | self.to_v = nn.Linear(dim_head, dim_head, bias=False) 142 | self.to_out = nn.Sequential( 143 | nn.Linear(inner_dim, dim), 144 | nn.Dropout(dropout) 145 | ) 146 | 147 | def forward(self, x): 148 | # B N C 149 | B, N, C = x.shape 150 | x = x.reshape(B, self.H, self.W, self.D, C).contiguous().permute(0, 4, 1, 2, 3).contiguous() # B C H W 151 | 152 | ### (1) Slice 153 | fx_mid = self.in_project_fx(x).permute(0, 2, 3, 4, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \ 154 | .permute(0, 2, 1, 3).contiguous() # B H N C 155 | x_mid = self.in_project_x(x).permute(0, 2, 3, 4, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \ 156 | .permute(0, 2, 1, 3).contiguous() # B H N G 157 | slice_weights = self.softmax( 158 | self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5)) # B H N G 159 | slice_norm = slice_weights.sum(2) # B H G 160 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) 161 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head)) 162 | 163 | ### (2) Attention among slice tokens 164 | q_slice_token = self.to_q(slice_token) 165 | k_slice_token = self.to_k(slice_token) 166 | v_slice_token = self.to_v(slice_token) 167 | dots = torch.matmul(q_slice_token, 
k_slice_token.transpose(-1, -2)) * self.scale 168 | attn = self.softmax(dots) 169 | attn = self.dropout(attn) 170 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D 171 | 172 | ### (3) Deslice 173 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) 174 | out_x = rearrange(out_x, 'b h n d -> b n (h d)') 175 | return self.to_out(out_x) 176 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/model/Transolver_Irregular_Mesh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import trunc_normal_ 4 | from model.Embedding import timestep_embedding 5 | import numpy as np 6 | from model.Physics_Attention import Physics_Attention_Irregular_Mesh 7 | 8 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'leaky_relu': nn.LeakyReLU(0.1), 9 | 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU} 10 | 11 | 12 | class MLP(nn.Module): 13 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True): 14 | super(MLP, self).__init__() 15 | 16 | if act in ACTIVATION.keys(): 17 | act = ACTIVATION[act] 18 | else: 19 | raise NotImplementedError 20 | self.n_input = n_input 21 | self.n_hidden = n_hidden 22 | self.n_output = n_output 23 | self.n_layers = n_layers 24 | self.res = res 25 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act()) 26 | self.linear_post = nn.Linear(n_hidden, n_output) 27 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)]) 28 | 29 | def forward(self, x): 30 | x = self.linear_pre(x) 31 | for i in range(self.n_layers): 32 | if self.res: 33 | x = self.linears[i](x) + x 34 | else: 35 | x = self.linears[i](x) 36 | x = self.linear_post(x) 37 | return x 38 | 39 | 40 | class Transolver_block(nn.Module): 41 | """Transformer encoder block.""" 42 
| 43 | def __init__( 44 | self, 45 | num_heads: int, 46 | hidden_dim: int, 47 | dropout: float, 48 | act='gelu', 49 | mlp_ratio=4, 50 | last_layer=False, 51 | out_dim=1, 52 | slice_num=32, 53 | ): 54 | super().__init__() 55 | self.last_layer = last_layer 56 | self.ln_1 = nn.LayerNorm(hidden_dim) 57 | self.Attn = Physics_Attention_Irregular_Mesh(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, 58 | dropout=dropout, slice_num=slice_num) 59 | self.ln_2 = nn.LayerNorm(hidden_dim) 60 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) 61 | if self.last_layer: 62 | self.ln_3 = nn.LayerNorm(hidden_dim) 63 | self.mlp2 = nn.Linear(hidden_dim, out_dim) 64 | 65 | def forward(self, fx): 66 | fx = self.Attn(self.ln_1(fx)) + fx 67 | fx = self.mlp(self.ln_2(fx)) + fx 68 | if self.last_layer: 69 | return self.mlp2(self.ln_3(fx)) 70 | else: 71 | return fx 72 | 73 | 74 | class Model(nn.Module): 75 | def __init__(self, 76 | space_dim=1, 77 | n_layers=5, 78 | n_hidden=256, 79 | dropout=0.0, 80 | n_head=8, 81 | Time_Input=False, 82 | act='gelu', 83 | mlp_ratio=1, 84 | fun_dim=1, 85 | out_dim=1, 86 | slice_num=32, 87 | ref=8, 88 | unified_pos=False 89 | ): 90 | super(Model, self).__init__() 91 | self.__name__ = 'Transolver_1D' 92 | self.ref = ref 93 | self.unified_pos = unified_pos 94 | self.Time_Input = Time_Input 95 | self.n_hidden = n_hidden 96 | self.space_dim = space_dim 97 | if self.unified_pos: 98 | self.preprocess = MLP(fun_dim + self.ref * self.ref, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act) 99 | else: 100 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act) 101 | if Time_Input: 102 | self.time_fc = nn.Sequential(nn.Linear(n_hidden, n_hidden), nn.SiLU(), nn.Linear(n_hidden, n_hidden)) 103 | 104 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden, 105 | dropout=dropout, 106 | act=act, 107 | mlp_ratio=mlp_ratio, 108 | 
out_dim=out_dim, 109 | slice_num=slice_num, 110 | last_layer=(_ == n_layers - 1)) 111 | for _ in range(n_layers)]) 112 | self.initialize_weights() 113 | self.placeholder = nn.Parameter((1 / (n_hidden)) * torch.rand(n_hidden, dtype=torch.float)) 114 | 115 | def initialize_weights(self): 116 | self.apply(self._init_weights) 117 | 118 | def _init_weights(self, m): 119 | if isinstance(m, nn.Linear): 120 | trunc_normal_(m.weight, std=0.02) 121 | if isinstance(m, nn.Linear) and m.bias is not None: 122 | nn.init.constant_(m.bias, 0) 123 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 124 | nn.init.constant_(m.bias, 0) 125 | nn.init.constant_(m.weight, 1.0) 126 | 127 | def get_grid(self, x, batchsize=1): 128 | # x: B N 2 129 | # grid_ref 130 | gridx = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float) 131 | gridx = gridx.reshape(1, self.ref, 1, 1).repeat([batchsize, 1, self.ref, 1]) 132 | gridy = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float) 133 | gridy = gridy.reshape(1, 1, self.ref, 1).repeat([batchsize, self.ref, 1, 1]) 134 | grid_ref = torch.cat((gridx, gridy), dim=-1).cuda().reshape(batchsize, self.ref * self.ref, 2) # B H W 8 8 2 135 | 136 | pos = torch.sqrt(torch.sum((x[:, :, None, :] - grid_ref[:, None, :, :]) ** 2, dim=-1)). 
\ 137 | reshape(batchsize, x.shape[1], self.ref * self.ref).contiguous() 138 | return pos 139 | 140 | def forward(self, x, fx, T=None): 141 | if self.unified_pos: 142 | x = self.get_grid(x, x.shape[0]) 143 | if fx is not None: 144 | fx = torch.cat((x, fx), -1) 145 | fx = self.preprocess(fx) 146 | else: 147 | fx = self.preprocess(x) 148 | fx = fx + self.placeholder[None, None, :] 149 | 150 | if T is not None: 151 | Time_emb = timestep_embedding(T, self.n_hidden).repeat(1, x.shape[1], 1) 152 | Time_emb = self.time_fc(Time_emb) 153 | fx = fx + Time_emb 154 | 155 | for block in self.blocks: 156 | fx = block(fx) 157 | 158 | return fx 159 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/model/Transolver_Structured_Mesh_2D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from timm.models.layers import trunc_normal_ 5 | from model.Embedding import timestep_embedding 6 | from model.Physics_Attention import Physics_Attention_Structured_Mesh_2D 7 | 8 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'leaky_relu': nn.LeakyReLU(0.1), 9 | 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU} 10 | 11 | 12 | class MLP(nn.Module): 13 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True): 14 | super(MLP, self).__init__() 15 | 16 | if act in ACTIVATION.keys(): 17 | act = ACTIVATION[act] 18 | else: 19 | raise NotImplementedError 20 | self.n_input = n_input 21 | self.n_hidden = n_hidden 22 | self.n_output = n_output 23 | self.n_layers = n_layers 24 | self.res = res 25 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act()) 26 | self.linear_post = nn.Linear(n_hidden, n_output) 27 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)]) 28 | 29 | def forward(self, x): 30 | x = 
self.linear_pre(x) 31 | for i in range(self.n_layers): 32 | if self.res: 33 | x = self.linears[i](x) + x 34 | else: 35 | x = self.linears[i](x) 36 | x = self.linear_post(x) 37 | return x 38 | 39 | 40 | class Transolver_block(nn.Module): 41 | """Transformer encoder block.""" 42 | 43 | def __init__( 44 | self, 45 | num_heads: int, 46 | hidden_dim: int, 47 | dropout: float, 48 | act='gelu', 49 | mlp_ratio=4, 50 | last_layer=False, 51 | out_dim=1, 52 | slice_num=32, 53 | H=85, 54 | W=85 55 | ): 56 | super().__init__() 57 | self.last_layer = last_layer 58 | self.ln_1 = nn.LayerNorm(hidden_dim) 59 | self.Attn = Physics_Attention_Structured_Mesh_2D(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, 60 | dropout=dropout, slice_num=slice_num, H=H, W=W) 61 | 62 | self.ln_2 = nn.LayerNorm(hidden_dim) 63 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) 64 | if self.last_layer: 65 | self.ln_3 = nn.LayerNorm(hidden_dim) 66 | self.mlp2 = nn.Linear(hidden_dim, out_dim) 67 | 68 | def forward(self, fx): 69 | fx = self.Attn(self.ln_1(fx)) + fx 70 | fx = self.mlp(self.ln_2(fx)) + fx 71 | if self.last_layer: 72 | return self.mlp2(self.ln_3(fx)) 73 | else: 74 | return fx 75 | 76 | 77 | class Model(nn.Module): 78 | def __init__(self, 79 | space_dim=1, 80 | n_layers=5, 81 | n_hidden=256, 82 | dropout=0.0, 83 | n_head=8, 84 | Time_Input=False, 85 | act='gelu', 86 | mlp_ratio=1, 87 | fun_dim=1, 88 | out_dim=1, 89 | slice_num=32, 90 | ref=8, 91 | unified_pos=False, 92 | H=85, 93 | W=85, 94 | ): 95 | super(Model, self).__init__() 96 | self.__name__ = 'Transolver_2D' 97 | self.H = H 98 | self.W = W 99 | self.ref = ref 100 | self.unified_pos = unified_pos 101 | if self.unified_pos: 102 | self.pos = self.get_grid() 103 | self.preprocess = MLP(fun_dim + self.ref * self.ref, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act) 104 | else: 105 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, 
act=act) 106 | 107 | self.Time_Input = Time_Input 108 | self.n_hidden = n_hidden 109 | self.space_dim = space_dim 110 | if Time_Input: 111 | self.time_fc = nn.Sequential(nn.Linear(n_hidden, n_hidden), nn.SiLU(), nn.Linear(n_hidden, n_hidden)) 112 | 113 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden, 114 | dropout=dropout, 115 | act=act, 116 | mlp_ratio=mlp_ratio, 117 | out_dim=out_dim, 118 | slice_num=slice_num, 119 | H=H, 120 | W=W, 121 | last_layer=(_ == n_layers - 1)) 122 | for _ in range(n_layers)]) 123 | self.initialize_weights() 124 | self.placeholder = nn.Parameter((1 / (n_hidden)) * torch.rand(n_hidden, dtype=torch.float)) 125 | 126 | def initialize_weights(self): 127 | self.apply(self._init_weights) 128 | 129 | def _init_weights(self, m): 130 | if isinstance(m, nn.Linear): 131 | trunc_normal_(m.weight, std=0.02) 132 | if isinstance(m, nn.Linear) and m.bias is not None: 133 | nn.init.constant_(m.bias, 0) 134 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 135 | nn.init.constant_(m.bias, 0) 136 | nn.init.constant_(m.weight, 1.0) 137 | 138 | def get_grid(self, batchsize=1): 139 | size_x, size_y = self.H, self.W 140 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 141 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 142 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 143 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 144 | grid = torch.cat((gridx, gridy), dim=-1).cuda() # B H W 2 145 | 146 | gridx = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float) 147 | gridx = gridx.reshape(1, self.ref, 1, 1).repeat([batchsize, 1, self.ref, 1]) 148 | gridy = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float) 149 | gridy = gridy.reshape(1, 1, self.ref, 1).repeat([batchsize, self.ref, 1, 1]) 150 | grid_ref = torch.cat((gridx, gridy), dim=-1).cuda() # B H W 8 8 2 151 | 152 | pos = torch.sqrt(torch.sum((grid[:, :, :, 
None, None, :] - grid_ref[:, None, None, :, :, :]) ** 2, dim=-1)). \ 153 | reshape(batchsize, size_x, size_y, self.ref * self.ref).contiguous() 154 | return pos 155 | 156 | def forward(self, x, fx, T=None): 157 | if self.unified_pos: 158 | x = self.pos.repeat(x.shape[0], 1, 1, 1).reshape(x.shape[0], self.H * self.W, self.ref * self.ref) 159 | if fx is not None: 160 | fx = torch.cat((x, fx), -1) 161 | fx = self.preprocess(fx) 162 | else: 163 | fx = self.preprocess(x) 164 | fx = fx + self.placeholder[None, None, :] 165 | 166 | if T is not None: 167 | Time_emb = timestep_embedding(T, self.n_hidden).repeat(1, x.shape[1], 1) 168 | Time_emb = self.time_fc(Time_emb) 169 | fx = fx + Time_emb 170 | 171 | for block in self.blocks: 172 | fx = block(fx) 173 | 174 | return fx 175 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/model/Transolver_Structured_Mesh_3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from timm.models.layers import trunc_normal_ 5 | from model.Embedding import timestep_embedding 6 | import torch.utils.checkpoint as checkpoint 7 | from model.Physics_Attention import Physics_Attention_Structured_Mesh_3D 8 | 9 | ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'leaky_relu': nn.LeakyReLU(0.1), 10 | 'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU} 11 | 12 | 13 | class MLP(nn.Module): 14 | def __init__(self, n_input, n_hidden, n_output, n_layers=1, act='gelu', res=True): 15 | super(MLP, self).__init__() 16 | 17 | if act in ACTIVATION.keys(): 18 | act = ACTIVATION[act] 19 | else: 20 | raise NotImplementedError 21 | self.n_input = n_input 22 | self.n_hidden = n_hidden 23 | self.n_output = n_output 24 | self.n_layers = n_layers 25 | self.res = res 26 | self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act()) 27 | self.linear_post = 
nn.Linear(n_hidden, n_output) 28 | self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)]) 29 | 30 | def forward(self, x): 31 | x = self.linear_pre(x) 32 | for i in range(self.n_layers): 33 | if self.res: 34 | x = self.linears[i](x) + x 35 | else: 36 | x = self.linears[i](x) 37 | x = self.linear_post(x) 38 | return x 39 | 40 | 41 | class Transolver_block(nn.Module): 42 | """Transformer encoder block.""" 43 | 44 | def __init__( 45 | self, 46 | num_heads: int, 47 | hidden_dim: int, 48 | dropout: float, 49 | act='gelu', 50 | mlp_ratio=4, 51 | last_layer=False, 52 | out_dim=1, 53 | slice_num=32, 54 | H=32, 55 | W=32, 56 | D=32 57 | ): 58 | super().__init__() 59 | self.last_layer = last_layer 60 | self.ln_1 = nn.LayerNorm(hidden_dim) 61 | self.Attn = Physics_Attention_Structured_Mesh_3D(hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, 62 | dropout=dropout, slice_num=slice_num, H=H, W=W, D=D) 63 | 64 | self.ln_2 = nn.LayerNorm(hidden_dim) 65 | self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim, n_layers=0, res=False, act=act) 66 | if self.last_layer: 67 | self.ln_3 = nn.LayerNorm(hidden_dim) 68 | self.mlp2 = nn.Linear(hidden_dim, out_dim) 69 | 70 | def forward(self, fx): 71 | fx = self.Attn(self.ln_1(fx)) + fx 72 | fx = self.mlp(self.ln_2(fx)) + fx 73 | if self.last_layer: 74 | return self.mlp2(self.ln_3(fx)) 75 | else: 76 | return fx 77 | 78 | 79 | class Model(nn.Module): 80 | def __init__(self, 81 | space_dim=1, 82 | n_layers=5, 83 | n_hidden=256, 84 | dropout=0.0, 85 | n_head=8, 86 | Time_Input=False, 87 | act='gelu', 88 | mlp_ratio=1, 89 | fun_dim=1, 90 | out_dim=1, 91 | slice_num=32, 92 | ref=8, 93 | unified_pos=False, 94 | H=32, 95 | W=32, 96 | D=32, 97 | ): 98 | super(Model, self).__init__() 99 | self.__name__ = 'Transolver_3D' 100 | self.use_checkpoint = False 101 | self.H = H 102 | self.W = W 103 | self.D = D 104 | self.ref = ref 105 | self.unified_pos = unified_pos 106 | if 
self.unified_pos: 107 | self.pos = self.get_grid() 108 | self.preprocess = MLP(fun_dim + self.ref * self.ref * self.ref, n_hidden * 2, n_hidden, n_layers=0, 109 | res=False, act=act) 110 | else: 111 | self.preprocess = MLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act) 112 | 113 | self.Time_Input = Time_Input 114 | self.n_hidden = n_hidden 115 | self.space_dim = space_dim 116 | if Time_Input: 117 | self.time_fc = nn.Sequential(nn.Linear(n_hidden, n_hidden), nn.SiLU(), nn.Linear(n_hidden, n_hidden)) 118 | 119 | self.blocks = nn.ModuleList([Transolver_block(num_heads=n_head, hidden_dim=n_hidden, 120 | dropout=dropout, 121 | act=act, 122 | mlp_ratio=mlp_ratio, 123 | out_dim=out_dim, 124 | slice_num=slice_num, 125 | H=H, 126 | W=W, 127 | D=D, 128 | last_layer=(_ == n_layers - 1)) 129 | for _ in range(n_layers)]) 130 | self.initialize_weights() 131 | self.placeholder = nn.Parameter((1 / (n_hidden)) * torch.rand(n_hidden, dtype=torch.float)) 132 | 133 | def initialize_weights(self): 134 | self.apply(self._init_weights) 135 | 136 | def _init_weights(self, m): 137 | if isinstance(m, nn.Linear): 138 | trunc_normal_(m.weight, std=0.02) 139 | if isinstance(m, nn.Linear) and m.bias is not None: 140 | nn.init.constant_(m.bias, 0) 141 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): 142 | nn.init.constant_(m.bias, 0) 143 | nn.init.constant_(m.weight, 1.0) 144 | 145 | def get_grid(self, batchsize=1): 146 | size_x, size_y, size_z = self.H, self.W, self.D 147 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 148 | gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1]) 149 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 150 | gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1]) 151 | gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float) 152 | gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1]) 153 | grid = 
torch.cat((gridx, gridy, gridz), dim=-1).cuda() # B H W D 3 154 | 155 | gridx = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float) 156 | gridx = gridx.reshape(1, self.ref, 1, 1, 1).repeat([batchsize, 1, self.ref, self.ref, 1]) 157 | gridy = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float) 158 | gridy = gridy.reshape(1, 1, self.ref, 1, 1).repeat([batchsize, self.ref, 1, self.ref, 1]) 159 | gridz = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float) 160 | gridz = gridz.reshape(1, 1, 1, self.ref, 1).repeat([batchsize, self.ref, self.ref, 1, 1]) 161 | grid_ref = torch.cat((gridx, gridy, gridz), dim=-1).cuda() # B 4 4 4 3 162 | 163 | pos = torch.sqrt( 164 | torch.sum((grid[:, :, :, :, None, None, None, :] - grid_ref[:, None, None, None, :, :, :, :]) ** 2, 165 | dim=-1)). \ 166 | reshape(batchsize, size_x, size_y, size_z, self.ref * self.ref * self.ref).contiguous() 167 | return pos 168 | 169 | def forward(self, x, fx, T=None): 170 | if self.unified_pos: 171 | x = self.pos.repeat(x.shape[0], 1, 1, 1, 1).reshape(x.shape[0], self.H * self.W * self.D, 172 | self.ref * self.ref * self.ref) 173 | if fx is not None: 174 | fx = torch.cat((x, fx), -1) 175 | fx = self.preprocess(fx) 176 | else: 177 | fx = self.preprocess(x) 178 | fx = fx + self.placeholder[None, None, :] 179 | 180 | if T is not None: 181 | Time_emb = timestep_embedding(T, self.n_hidden).repeat(1, x.shape[1], 1) 182 | Time_emb = self.time_fc(Time_emb) 183 | fx = fx + Time_emb 184 | 185 | for block in self.blocks: 186 | if self.use_checkpoint: 187 | fx = checkpoint.checkpoint(block, fx) 188 | else: 189 | fx = block(fx) 190 | 191 | return fx 192 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/model_dict.py: -------------------------------------------------------------------------------- 1 | from model import Transolver_Irregular_Mesh, Transolver_Structured_Mesh_2D, Transolver_Structured_Mesh_3D 2 | 3 | 4 | def 
get_model(args): 5 | model_dict = { 6 | 'Transolver_Irregular_Mesh': Transolver_Irregular_Mesh, # for PDEs in 1D space or in unstructured meshes 7 | 'Transolver_Structured_Mesh_2D': Transolver_Structured_Mesh_2D, 8 | 'Transolver_Structured_Mesh_3D': Transolver_Structured_Mesh_3D, 9 | } 10 | return model_dict[args.model] 11 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.1 2 | h5py==3.8.0 3 | dgl==1.1.0 4 | einops==0.6.1 5 | scipy==1.7.3 6 | timm==0.9.2 -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/scripts/Transolver_Airfoil.sh: -------------------------------------------------------------------------------- 1 | python exp_airfoil.py \ 2 | --gpu 5 \ 3 | --model Transolver_Structured_Mesh_2D \ 4 | --n-hidden 128 \ 5 | --n-heads 8 \ 6 | --n-layers 8 \ 7 | --lr 0.001 \ 8 | --max_grad_norm 0.1 \ 9 | --batch-size 4 \ 10 | --slice_num 64 \ 11 | --unified_pos 0 \ 12 | --ref 8 \ 13 | --eval 0 \ 14 | --save_name airfoil_Transolver -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/scripts/Transolver_Darcy.sh: -------------------------------------------------------------------------------- 1 | python exp_darcy.py \ 2 | --gpu 4 \ 3 | --model Transolver_Structured_Mesh_2D \ 4 | --n-hidden 128 \ 5 | --n-heads 8 \ 6 | --n-layers 8 \ 7 | --lr 0.001 \ 8 | --max_grad_norm 0.1 \ 9 | --batch-size 4 \ 10 | --slice_num 64 \ 11 | --unified_pos 1 \ 12 | --ref 8 \ 13 | --eval 0 \ 14 | --downsample 5 \ 15 | --save_name darcy_UniPDE 16 | 17 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/scripts/Transolver_Elas.sh: -------------------------------------------------------------------------------- 1 | python exp_elas.py \ 2 | 
--gpu 6 \ 3 | --model Transolver_Irregular_Mesh \ 4 | --n-hidden 128 \ 5 | --n-heads 8 \ 6 | --n-layers 8 \ 7 | --lr 0.001 \ 8 | --max_grad_norm 0.1 \ 9 | --batch-size 1 \ 10 | --slice_num 64 \ 11 | --unified_pos 0 \ 12 | --ref 8 \ 13 | --eval 0 \ 14 | --save_name elas_Transolver -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/scripts/Transolver_NS.sh: -------------------------------------------------------------------------------- 1 | python exp_ns.py \ 2 | --gpu 2 \ 3 | --model Transolver_Structured_Mesh_2D \ 4 | --n-hidden 256 \ 5 | --n-heads 8 \ 6 | --n-layers 8 \ 7 | --lr 0.001 \ 8 | --batch-size 2 \ 9 | --slice_num 32 \ 10 | --unified_pos 1 \ 11 | --ref 8 \ 12 | --eval 0 \ 13 | --save_name ns_Transolver 14 | 15 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/scripts/Transolver_Pipe.sh: -------------------------------------------------------------------------------- 1 | python exp_pipe.py \ 2 | --gpu 7 \ 3 | --model Transolver_Structured_Mesh_2D \ 4 | --n-hidden 128 \ 5 | --n-heads 8 \ 6 | --n-layers 8 \ 7 | --mlp_ratio 2 \ 8 | --lr 0.001 \ 9 | --max_grad_norm 0.1 \ 10 | --batch-size 8 \ 11 | --slice_num 64 \ 12 | --unified_pos 0 \ 13 | --ref 8 \ 14 | --eval 0 \ 15 | --save_name pipe_Transolver 16 | 17 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/scripts/Transolver_Plas.sh: -------------------------------------------------------------------------------- 1 | python exp_plas.py \ 2 | --gpu 3 \ 3 | --model Transolver_Structured_Mesh_2D \ 4 | --n-hidden 128 \ 5 | --n-heads 8 \ 6 | --n-layers 8 \ 7 | --lr 0.001 \ 8 | --max_grad_norm 0.1 \ 9 | --batch-size 8 \ 10 | --slice_num 64 \ 11 | --unified_pos 0 \ 12 | --ref 8 \ 13 | --eval 0 \ 14 | --save_name plas_Transolver 15 | 16 | -------------------------------------------------------------------------------- 
/PDE-Solving-StandardBenchmark/utils/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import * 3 | 4 | 5 | class IdentityTransformer(): 6 | def __init__(self, X): 7 | self.mean = X.mean(dim=0, keepdim=True) 8 | self.std = X.std(dim=0, keepdim=True) + 1e-8 9 | 10 | def to(self, device): 11 | self.mean = self.mean.to(device) 12 | self.std = self.std.to(device) 13 | return self 14 | 15 | def cuda(self): 16 | self.mean = self.mean.cuda() 17 | self.std = self.std.cuda() 18 | 19 | def cpu(self): 20 | self.mean = self.mean.cpu() 21 | self.std = self.std.cpu() 22 | 23 | def encode(self, x): 24 | return x 25 | 26 | def decode(self, x): 27 | return x 28 | 29 | 30 | class UnitTransformer(): 31 | def __init__(self, X): 32 | self.mean = X.mean(dim=(0, 1), keepdim=True) 33 | self.std = X.std(dim=(0, 1), keepdim=True) + 1e-8 34 | 35 | def to(self, device): 36 | self.mean = self.mean.to(device) 37 | self.std = self.std.to(device) 38 | return self 39 | 40 | def cuda(self): 41 | self.mean = self.mean.cuda() 42 | self.std = self.std.cuda() 43 | 44 | def cpu(self): 45 | self.mean = self.mean.cpu() 46 | self.std = self.std.cpu() 47 | 48 | def encode(self, x): 49 | x = (x - self.mean) / (self.std) 50 | return x 51 | 52 | def decode(self, x): 53 | return x * self.std + self.mean 54 | 55 | def transform(self, X, inverse=True, component='all'): 56 | if component == 'all' or 'all-reduce': 57 | if inverse: 58 | orig_shape = X.shape 59 | return (X * (self.std - 1e-8) + self.mean).view(orig_shape) 60 | else: 61 | return (X - self.mean) / self.std 62 | else: 63 | if inverse: 64 | orig_shape = X.shape 65 | return (X * (self.std[:, component] - 1e-8) + self.mean[:, component]).view(orig_shape) 66 | else: 67 | return (X - self.mean[:, component]) / self.std[:, component] 68 | 69 | 70 | class UnitGaussianNormalizer(object): 71 | def __init__(self, x, eps=0.00001, time_last=True): 72 | super(UnitGaussianNormalizer, 
self).__init__() 73 | 74 | self.mean = torch.mean(x, 0) 75 | self.std = torch.std(x, 0) 76 | self.eps = eps 77 | self.time_last = time_last # if the time dimension is the last dim 78 | 79 | def encode(self, x): 80 | x = (x - self.mean) / (self.std + self.eps) 81 | return x 82 | 83 | def decode(self, x, sample_idx=None): 84 | # sample_idx is the spatial sampling mask 85 | if sample_idx is None: 86 | std = self.std + self.eps # n 87 | mean = self.mean 88 | else: 89 | if self.mean.ndim == sample_idx.ndim or self.time_last: 90 | std = self.std[sample_idx] + self.eps # batch*n 91 | mean = self.mean[sample_idx] 92 | if self.mean.ndim > sample_idx.ndim and not self.time_last: 93 | std = self.std[..., sample_idx] + self.eps # T*batch*n 94 | mean = self.mean[..., sample_idx] 95 | # x is in shape of batch*(spatial discretization size) or T*batch*(spatial discretization size) 96 | x = (x * std) + mean 97 | return x 98 | 99 | def to(self, device): 100 | if torch.is_tensor(self.mean): 101 | self.mean = self.mean.to(device) 102 | self.std = self.std.to(device) 103 | else: 104 | self.mean = torch.from_numpy(self.mean).to(device) 105 | self.std = torch.from_numpy(self.std).to(device) 106 | return self 107 | 108 | def cuda(self): 109 | self.mean = self.mean.cuda() 110 | self.std = self.std.cuda() 111 | 112 | def cpu(self): 113 | self.mean = self.mean.cpu() 114 | self.std = self.std.cpu() 115 | -------------------------------------------------------------------------------- /PDE-Solving-StandardBenchmark/utils/testloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TestLoss(object): 5 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 6 | super(TestLoss, self).__init__() 7 | 8 | assert d > 0 and p > 0 9 | 10 | self.d = d 11 | self.p = p 12 | self.reduction = reduction 13 | self.size_average = size_average 14 | 15 | def abs(self, x, y): 16 | num_examples = x.size()[0] 17 | 18 | h = 1.0 / 
(x.size()[1] - 1.0) 19 | 20 | all_norms = (h ** (self.d / self.p)) * torch.norm(x.view(num_examples, -1) - y.view(num_examples, -1), self.p, 21 | 1) 22 | 23 | if self.reduction: 24 | if self.size_average: 25 | return torch.mean(all_norms) 26 | else: 27 | return torch.sum(all_norms) 28 | 29 | return all_norms 30 | 31 | def rel(self, x, y): 32 | num_examples = x.size()[0] 33 | 34 | diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1) 35 | y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1) 36 | if self.reduction: 37 | if self.size_average: 38 | return torch.mean(diff_norms / y_norms) 39 | else: 40 | return torch.sum(diff_norms / y_norms) 41 | 42 | return diff_norms / y_norms 43 | 44 | def __call__(self, x, y): 45 | return self.rel(x, y) 46 | -------------------------------------------------------------------------------- /Physics_Attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from einops import rearrange, repeat 4 | 5 | 6 | class Physics_Attention_Irregular_Mesh(nn.Module): 7 | ## for irregular meshes in 1D, 2D or 3D space 8 | 9 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64): 10 | super().__init__() 11 | inner_dim = dim_head * heads 12 | self.dim_head = dim_head 13 | self.heads = heads 14 | self.scale = dim_head ** -0.5 15 | self.softmax = nn.Softmax(dim=-1) 16 | self.dropout = nn.Dropout(dropout) 17 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) 18 | 19 | self.in_project_x = nn.Linear(dim, inner_dim) 20 | self.in_project_fx = nn.Linear(dim, inner_dim) 21 | self.in_project_slice = nn.Linear(dim_head, slice_num) 22 | for l in [self.in_project_slice]: 23 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization 24 | self.to_q = nn.Linear(dim_head, dim_head, bias=False) 25 | self.to_k = nn.Linear(dim_head, dim_head, bias=False) 26 | self.to_v = nn.Linear(dim_head, 
dim_head, bias=False) 27 | self.to_out = nn.Sequential( 28 | nn.Linear(inner_dim, dim), 29 | nn.Dropout(dropout) 30 | ) 31 | 32 | def forward(self, x): 33 | # B N C 34 | B, N, C = x.shape 35 | 36 | ### (1) Slice 37 | fx_mid = self.in_project_fx(x).reshape(B, N, self.heads, self.dim_head) \ 38 | .permute(0, 2, 1, 3).contiguous() # B H N C 39 | x_mid = self.in_project_x(x).reshape(B, N, self.heads, self.dim_head) \ 40 | .permute(0, 2, 1, 3).contiguous() # B H N C 41 | slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature) # B H N G 42 | slice_norm = slice_weights.sum(2) # B H G 43 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) 44 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head)) 45 | 46 | ### (2) Attention among slice tokens 47 | q_slice_token = self.to_q(slice_token) 48 | k_slice_token = self.to_k(slice_token) 49 | v_slice_token = self.to_v(slice_token) 50 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale 51 | attn = self.softmax(dots) 52 | attn = self.dropout(attn) 53 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D 54 | 55 | ### (3) Deslice 56 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) 57 | out_x = rearrange(out_x, 'b h n d -> b n (h d)') 58 | return self.to_out(out_x) 59 | 60 | 61 | class Physics_Attention_Structured_Mesh_2D(nn.Module): 62 | ## for structured mesh in 2D space 63 | 64 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64, H=101, W=31, kernel=3): # kernel=3): 65 | super().__init__() 66 | inner_dim = dim_head * heads 67 | self.dim_head = dim_head 68 | self.heads = heads 69 | self.scale = dim_head ** -0.5 70 | self.softmax = nn.Softmax(dim=-1) 71 | self.dropout = nn.Dropout(dropout) 72 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) 73 | self.H = H 74 | self.W = W 75 | 76 | self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 
2) 77 | self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) 78 | self.in_project_slice = nn.Linear(dim_head, slice_num) 79 | for l in [self.in_project_slice]: 80 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization 81 | self.to_q = nn.Linear(dim_head, dim_head, bias=False) 82 | self.to_k = nn.Linear(dim_head, dim_head, bias=False) 83 | self.to_v = nn.Linear(dim_head, dim_head, bias=False) 84 | 85 | self.to_out = nn.Sequential( 86 | nn.Linear(inner_dim, dim), 87 | nn.Dropout(dropout) 88 | ) 89 | 90 | def forward(self, x): 91 | # B N C 92 | B, N, C = x.shape 93 | x = x.reshape(B, self.H, self.W, C).contiguous().permute(0, 3, 1, 2).contiguous() # B C H W 94 | 95 | ### (1) Slice 96 | fx_mid = self.in_project_fx(x).permute(0, 2, 3, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \ 97 | .permute(0, 2, 1, 3).contiguous() # B H N C 98 | x_mid = self.in_project_x(x).permute(0, 2, 3, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \ 99 | .permute(0, 2, 1, 3).contiguous() # B H N G 100 | slice_weights = self.softmax( 101 | self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5)) # B H N G 102 | slice_norm = slice_weights.sum(2) # B H G 103 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) 104 | slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head)) 105 | 106 | ### (2) Attention among slice tokens 107 | q_slice_token = self.to_q(slice_token) 108 | k_slice_token = self.to_k(slice_token) 109 | v_slice_token = self.to_v(slice_token) 110 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale 111 | attn = self.softmax(dots) 112 | attn = self.dropout(attn) 113 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D 114 | 115 | ### (3) Deslice 116 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) 117 | out_x = rearrange(out_x, 'b h n d -> b n (h d)') 118 | return self.to_out(out_x) 119 | 
120 | 121 | class Physics_Attention_Structured_Mesh_3D(nn.Module): 122 | ## for structured mesh in 3D space 123 | 124 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=32, H=32, W=32, D=32, kernel=3): 125 | super().__init__() 126 | inner_dim = dim_head * heads 127 | self.dim_head = dim_head 128 | self.heads = heads 129 | self.scale = dim_head ** -0.5 130 | self.softmax = nn.Softmax(dim=-1) 131 | self.dropout = nn.Dropout(dropout) 132 | self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) 133 | self.H = H 134 | self.W = W 135 | self.D = D 136 | 137 | self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) 138 | self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) 139 | self.in_project_slice = nn.Linear(dim_head, slice_num) 140 | for l in [self.in_project_slice]: 141 | torch.nn.init.orthogonal_(l.weight) # use a principled initialization 142 | self.to_q = nn.Linear(dim_head, dim_head, bias=False) 143 | self.to_k = nn.Linear(dim_head, dim_head, bias=False) 144 | self.to_v = nn.Linear(dim_head, dim_head, bias=False) 145 | self.to_out = nn.Sequential( 146 | nn.Linear(inner_dim, dim), 147 | nn.Dropout(dropout) 148 | ) 149 | 150 | def forward(self, x): 151 | # B N C 152 | B, N, C = x.shape 153 | x = x.reshape(B, self.H, self.W, self.D, C).contiguous().permute(0, 4, 1, 2, 3).contiguous() # B C H W 154 | 155 | ### (1) Slice 156 | fx_mid = self.in_project_fx(x).permute(0, 2, 3, 4, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \ 157 | .permute(0, 2, 1, 3).contiguous() # B H N C 158 | x_mid = self.in_project_x(x).permute(0, 2, 3, 4, 1).contiguous().reshape(B, N, self.heads, self.dim_head) \ 159 | .permute(0, 2, 1, 3).contiguous() # B H N G 160 | slice_weights = self.softmax( 161 | self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5)) # B H N G 162 | slice_norm = slice_weights.sum(2) # B H G 163 | slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) 164 | 
slice_token = slice_token / ((slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head)) 165 | 166 | ### (2) Attention among slice tokens 167 | q_slice_token = self.to_q(slice_token) 168 | k_slice_token = self.to_k(slice_token) 169 | v_slice_token = self.to_v(slice_token) 170 | dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale 171 | attn = self.softmax(dots) 172 | attn = self.dropout(attn) 173 | out_slice_token = torch.matmul(attn, v_slice_token) # B H G D 174 | 175 | ### (3) Deslice 176 | out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) 177 | out_x = rearrange(out_x, 'b h n d -> b n (h d)') 178 | return self.to_out(out_x) 179 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transolver (ICML 2024 Spotlight) 2 | 3 | :triangular_flag_on_post:**News** (2025.04) We have released [Neural-Solver-Library](https://github.com/thuml/Neural-Solver-Library) as a simple and neat code base for PDE solving. It contains 17 well-reproduced neural solvers. Welcome to try this library and join the research in solving PDEs. 4 | 5 | :triangular_flag_on_post:**News** (2025.02) We present an upgraded version of Transolver, named [Transolver++](https://arxiv.org/abs/2502.02414v1), which can handle million-scale geometries in one GPU with more accurate results. 6 | 7 | :triangular_flag_on_post:**News** (2024.10) Transolver has been integrated into [NVIDIA modulus](https://github.com/NVIDIA/modulus/tree/main/examples/cfd/darcy_transolver). 
8 | 9 | Transolver: A Fast Transformer Solver for PDEs on General Geometries [[Paper]](https://arxiv.org/abs/2402.02366) [[Slides]](https://wuhaixu2016.github.io/pdf/ICML2024_Transolver.pdf) [[Poster]](https://wuhaixu2016.github.io/pdf/poster_ICML2024_Transolver.pdf) 10 | 11 | In real-world applications, PDEs are typically discretized into large-scale meshes with complex geometries. To capture intricate physical correlations hidden under multifarious meshes, we propose Transolver with the following features: 12 | 13 | - Going beyond previous work, Transolver **calculates attention among learned physical states** instead of mesh points, which empowers the model with **endogenetic geometry-general capability**. 14 | - Transolver achieves **22% error reduction over previous SOTA on six standard benchmarks** and excels in **large-scale industrial simulations**, including car and airfoil designs. 15 | - Transolver presents favorable **efficiency, scalability and out-of-distribution generalizability**. 16 | 17 |

18 | 19 |

20 | Figure 1. Overview of Transolver. 21 |

22 | 23 | 24 | ## Transolver v.s. Previous Transformer Operators 25 | 26 | **All of the previous Transformer-based neural operators directly apply attention to mesh points.** However, the massive number of mesh points in practical applications poses challenges for both computational cost and capturing physical correlations. 27 | 28 | Transolver is based on a more foundational idea, namely **learning intrinsic physical states under complex geometries**. This design frees our model from superficial and unwieldy meshes and focuses more on physics modeling. 29 | 30 | As shown below, **Transolver can precisely capture miscellaneous physical states of PDEs**, such as (a) various fluid-structure interactions in a Darcy flow, (b) different extrusion regions of elastic materials, (c) shock wave and wake flow around the airfoil, and (d) front-back surfaces and up-bottom spaces of driving cars. 31 | 32 |

33 | 34 |

35 | Figure 2. Visualization of learned physical states. 36 |

37 | 38 | ## Get Started 39 | 40 | 1. Please refer to the corresponding folders for detailed experiment instructions. 41 | 42 | 2. List of experiments: 43 | 44 | - Core code: see [./Physics_Attention.py](https://github.com/thuml/Transolver/blob/main/Physics_Attention.py) 45 | - Standard benchmarks: see [./PDE-Solving-StandardBenchmark](https://github.com/thuml/Transolver/tree/main/PDE-Solving-StandardBenchmark) 46 | - Car design task: see [./Car-Design-ShapeNetCar](https://github.com/thuml/Transolver/tree/main/Car-Design-ShapeNetCar) 47 | - Airfoil design task: see [./Airfoil-Design-AirfRANS](https://github.com/thuml/Transolver/tree/main/Airfoil-Design-AirfRANS) 48 | 49 | ## Results 50 | 51 | Transolver achieves consistent state-of-the-art performance on **six standard benchmarks and two practical design tasks**. **More than 20 baselines are compared.** 52 | 53 |

54 | 55 |

56 | Table 1. Results on six standard benchmarks. 57 |

58 | 59 |

60 | 61 |

62 | Table 2. Results on two design tasks: Car and Airfoil design. 63 |

64 | 65 | ## Showcases 66 | 67 |

68 | 69 |

70 | Figure 3. Comparison of Transolver and other models. 71 |

72 | 73 | ## Citation 74 | 75 | If you find this repo useful, please cite our paper. 76 | 77 | ``` 78 | @inproceedings{wu2024Transolver, 79 | title={Transolver: A Fast Transformer Solver for PDEs on General Geometries}, 80 | author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long}, 81 | booktitle={International Conference on Machine Learning}, 82 | year={2024} 83 | } 84 | ``` 85 | 86 | ## Contact 87 | 88 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn). 89 | 90 | ## Acknowledgement 91 | 92 | We greatly appreciate the following GitHub repositories for their valuable code bases and datasets: 93 | 94 | https://github.com/neuraloperator/neuraloperator 95 | 96 | https://github.com/neuraloperator/Geo-FNO 97 | 98 | https://github.com/thuml/Latent-Spectral-Models 99 | 100 | https://github.com/Extrality/AirfRANS 101 | -------------------------------------------------------------------------------- /pic/Transolver.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/pic/Transolver.png -------------------------------------------------------------------------------- /pic/physical_states.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/pic/physical_states.png -------------------------------------------------------------------------------- /pic/showcases.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transolver/fcc58dfc7a761418903452fb6c59a04be874f983/pic/showcases.png --------------------------------------------------------------------------------