├── .gitignore ├── LICENSE ├── README.md ├── conf ├── ACEquation.yaml ├── Burgers.yaml ├── Diffusion.yaml ├── KDV.yaml ├── Schrodinger.yaml ├── data_conf │ └── default.yaml ├── evaluation_conf.yaml ├── global_conf │ └── default.yaml ├── model_conf │ └── default.yaml └── train_conf │ └── default.yaml ├── evaluate.py ├── ground_true ├── ACEquation.npz ├── Burgers.npz ├── Diffusion.npz ├── KDV.npz └── Schrodinger.npz ├── images └── DMIS-diagram.png ├── pretrain ├── ACEquation │ ├── PINN-DMIS │ │ └── best.pth │ ├── PINN-N │ │ └── best.pth │ └── PINN-O │ │ └── best.pth ├── Burgers │ ├── PINN-DMIS │ │ └── best.pth │ ├── PINN-N │ │ └── best.pth │ └── PINN-O │ │ └── best.pth ├── Diffusion │ ├── PINN-DMIS │ │ └── best.pth │ ├── PINN-N │ │ └── best.pth │ └── PINN-O │ │ └── best.pth ├── KDV │ ├── PINN-DMIS │ │ └── best.pth │ ├── PINN-N │ │ └── best.pth │ └── PINN-O │ │ └── best.pth └── Schrodinger │ ├── PINN-DMIS │ └── best.pth │ ├── PINN-N │ └── best.pth │ └── PINN-O │ └── best.pth ├── scripts ├── plot_tools.py └── tensorboard_data_export_tools.py ├── test.py ├── train.py └── utils ├── data_utils.py ├── equations ├── ACEquation.py ├── Burgers.py ├── Diffusion.py ├── KDV.py ├── Schrodinger.py ├── __init__.py └── basic_define.py ├── models.py ├── pde_utils.py ├── plot_utils.py ├── reweightings.py └── samplers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | 163 | # output files 164 | outputs/ 165 | multirun/ 166 | data/ 167 | Simulation/simulation_data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 MatrixBrain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DMIS (AAAI 2023) 2 | Official code for "[DMIS: Dynamic Mesh-based Importance Sampling for Training Physics-Informed Neural Networks](https://ojs.aaai.org/index.php/AAAI/article/view/25669)" (AAAI 2023) 3 | 4 | ![](./images/DMIS-diagram.png) 5 | 6 | Modeling dynamics in the form of partial differential equations (PDEs) is an effective way to understand real-world physical processes. For complex physical systems, analytical solutions are unavailable and numerical solutions are widely used. However, traditional numerical algorithms are computationally expensive and struggle with multiphysics systems. Recently, solving PDEs with neural networks, an approach known as physics-informed neural networks (PINNs), has made significant progress. PINNs encode physical laws into neural networks and learn continuous solutions of PDEs. During training, however, existing methods suffer from inefficiency and unstable convergence, since evaluating the PDE residuals requires automatic differentiation. In this paper, we propose **D**ynamic **M**esh-based **I**mportance **S**ampling (DMIS) to tackle these problems. DMIS is a novel sampling scheme based on importance sampling that constructs a dynamic triangular mesh to estimate sample weights efficiently. DMIS has broad applicability and can be easily integrated into existing methods. Evaluation on three widely used benchmarks shows that DMIS improves both convergence speed and accuracy. In particular, on the highly nonlinear Schrödinger Equation, DMIS achieves up to 46% smaller root mean square error and five times faster convergence than state-of-the-art methods.
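As a rough illustration of the sampling-and-reweighting idea only (not the actual DMIS algorithm, which estimates sample weights with a dynamic triangular mesh in `utils/samplers.py`), the self-contained sketch below draws collocation points in proportion to their current PDE-residual magnitude and forms a weighted PDE loss in the same `sum(residual**2 * weight)` form used by `utils/equations/*.py`. The function name `sample_and_reweight` and the exponent `k` are illustrative assumptions, only loosely related to the `k_init`/`k_final` options in `conf/*.yaml`.

```python
import torch


def sample_and_reweight(residuals: torch.Tensor, batch_size: int, k: float = 1.0):
    """Illustrative residual-based importance sampling (not the DMIS algorithm).

    Draw collocation indices with probability proportional to |residual|**k and
    return self-normalized inverse-probability weights for the loss terms.
    """
    probs = residuals.abs().pow(k)
    probs = probs / probs.sum()                    # sampling distribution over all points
    idx = torch.multinomial(probs, batch_size, replacement=True)
    weights = 1.0 / (probs[idx] * probs.numel())   # inverse-probability correction
    weights = weights / weights.sum()              # normalize, as compute_loss does
    return idx, weights


# Toy usage: weighted PDE loss of the form sum(residual**2 * weight),
# mirroring the training losses defined in utils/equations/*.py.
residuals = torch.rand(60000)                      # stand-in for per-point |PDE residual|
idx, w = sample_and_reweight(residuals, batch_size=20000, k=1.5)
pde_loss = torch.sum(residuals[idx] ** 2 * w)
```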
7 | 8 | ## Quick Start 9 | 10 | ### Installation 11 | 12 | #### Setup environment 13 | 14 | Dependencies: 15 | 16 | * PyTorch == 1.11.0 17 | * hydra == 1.2.0 18 | * tensorboard == 2.9.0 19 | * sympy == 1.10.1 20 | * scipy == 1.8.1 21 | * pandas == 1.4.3 22 | * numpy == 1.22.4 23 | * matplotlib == 3.5.2 24 | 25 | ```bash 26 | conda create --name DMIS python=3.7 27 | conda activate DMIS 28 | conda install --file requirements.txt 29 | ``` 30 | 31 | All the code has been tested on Ubuntu 16.04, Python 3.7.12, PyTorch 1.11.0, and CUDA 11.3. 32 | 33 | #### Clone this repository 34 | 35 | ```bash 36 | git clone git@github.com:MatrixBrain/DMIS.git 37 | cd DMIS 38 | ``` 39 | 40 | ### Training 41 | 42 | To train PINNs with DMIS for solving the Schrödinger Equation: 43 | 44 | ```bash 45 | python train.py --config-name=Schrodinger train_conf.pde_sampler=SamplerWithDMIS train_conf.pde_reweighting=BiasedReWeighting hydra.job.chdir=True 46 | ``` 47 | 48 | To train PINNs with uniform sampling for solving the Schrödinger Equation: 49 | 50 | ```bash 51 | python train.py --config-name=Schrodinger train_conf.pde_sampler=UniformSampler train_conf.pde_reweighting=NoReWeighting hydra.job.chdir=True 52 | ``` 53 | 54 | For other equations, replace ```Schrodinger``` with ```KDV``` (KdV Equation), ```Burgers``` (Burgers' Equation), ```Diffusion``` (Diffusion Equation), or ```ACEquation``` (Allen-Cahn Equation). 55 | 56 | ### Evaluation 57 | 58 | To evaluate the PINN-O, PINN-N, and PINN-DMIS models used in our paper, run: 59 | 60 | ```bash 61 | python evaluate.py hydra.job.chdir=True 62 | ``` 63 | 64 | ## Results 65 | 66 | * Schrödinger Equation 67 | 68 | |Method|ME|MAE|RMSE| 69 | |:-:|:-:|:-:|:-:| 70 | |PINN-O|1.360|0.186|0.4092| 71 | |PINN-N|0.948|0.149|0.2906| 72 | |xPINN|0.546|0.045|0.0089| 73 | |cPINN|0.591|0.069|0.0169| 74 | |**PINN-DMIS(ours)**|0.647|0.127|0.2196| 75 | |**xPINN-DMIS(ours)**|0.867|0.036|0.0129| 76 | |**cPINN-DMIS(ours)**|**0.358**|**0.025**|**0.0033**| 77 | 78 | * Burgers' Equation 79 | 80 | |Method|ME|MAE|RMSE| 81 | |:-:|:-:|:-:|:-:| 82 | |PINN-O|0.451|0.0738|0.1100| 83 | |PINN-N|0.358|0.0579|0.0859| 84 | |xPINN|0.261|0.0099|0.0010| 85 | |cPINN|0.324|**0.0084**|**0.0007**| 86 | |**PINN-DMIS(ours)**|**0.225**|0.0294|0.0495| 87 | |**xPINN-DMIS(ours)**|0.420|0.0115|0.0017| 88 | |**cPINN-DMIS(ours)**|0.397|0.0111|0.0016| 89 | 90 | * KdV Equation 91 | 92 | |Method|ME|MAE|RMSE| 93 | |:-:|:-:|:-:|:-:| 94 | |PINN-O|2.140|0.363|0.520| 95 | |PINN-N|1.860|0.292|0.441| 96 | |xPINN|2.462|0.272|0.230| 97 | |cPINN|2.925|0.258|0.248| 98 | |**PINN-DMIS(ours)**|**1.170**|0.391|0.492| 99 | |**xPINN-DMIS(ours)**|2.380|0.233|**0.196**| 100 | |**cPINN-DMIS(ours)**|2.680|**0.230**|0.200| 101 | 102 | Note: ME, MAE, and RMSE denote max error, mean absolute error, and root mean square error, respectively. The results of PINN-O differ from those reported in the original PINN paper because we use **extrapolation precision**, whereas the original PINN paper uses **interpolation precision**. 103 | 104 | ## Citation 105 | If you find the code and pre-trained models useful for your research, please consider citing our paper.
:blush: 106 | ``` 107 | @InProceedings{yang2022dmis, 108 | author = {Yang, Zijiang and Qiu, Zhongwei and Fu, Dongmei}, 109 | title = {DMIS: Dynamic Mesh-based Importance Sampling for Training Physics-Informed Neural Networks}, 110 | booktitle = {AAAI}, 111 | year = {2023}, 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /conf/ACEquation.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - global_conf: default 3 | - data_conf: default 4 | - model_conf: default 5 | - train_conf: default 6 | - _self_ 7 | 8 | name: ACEquation 9 | 10 | data_conf: 11 | initial_data_n: 200 12 | boundary_data_n: 200 13 | pde_data_n: 60000 14 | 15 | global_conf: 16 | seed: 8734 17 | 18 | model_conf: 19 | layer: 20 | layer_n: 5 21 | layer_size: [64, 64, 64, 64, 64] 22 | 23 | problem_conf: 24 | dims: 2 25 | x_range: [-1, 1] 26 | t_range: [0., 1.] 27 | initial_cond: cos(pi * x) * (x**2) 28 | boundary_cond: periodic 29 | 30 | train_conf: 31 | train_t_range: [ 0, 0.5 ] 32 | eval_t_range: [ 0.5, 0.75 ] 33 | test_t_range: [ 0.75, 1.0 ] 34 | pde_sampler: SamplerWithDMIS 35 | pde_reweighting: BiasedReWeighting 36 | reweighting_params: 37 | k_init: 1.5 38 | k_final: 1.5 39 | iter_n: ${train_conf.main_conf.max_steps} 40 | optim_conf: 41 | lr: 1e-3 42 | main_conf: 43 | max_steps: 100000 44 | pde_batch_size: 20000 45 | initial_batch_size: 50 46 | boundary_batch_size: 50 47 | print_frequency: 1 48 | model_basic_save_name: ${name} 49 | sampler_conf: 50 | forward_batch_size: ${train_conf.main_conf.pde_batch_size} 51 | mesh_update_thres: 0.4 52 | addon_points: [ [ 0., -1. ], [ 0., 1. ], [ 0.5, 1. ], [ 0.5, -1. ] ] 53 | seed_n: 1000 54 | -------------------------------------------------------------------------------- /conf/Burgers.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - global_conf: default 3 | - data_conf: default 4 | - model_conf: default 5 | - train_conf: default 6 | - _self_ 7 | 8 | name: Burgers 9 | 10 | global_conf: 11 | seed: 2000 12 | 13 | model_conf: 14 | layer: 15 | layer_n: 3 16 | layer_size: [32, 32, 32] 17 | 18 | problem_conf: 19 | dims: 2 20 | x_range: [-1, 1] 21 | t_range: [0., 1.] 22 | initial_cond: "-sin(pi * x)" 23 | boundary_cond: "0 * x" 24 | 25 | train_conf: 26 | train_t_range: [ 0, 0.5 ] 27 | eval_t_range: [ 0.5, 0.75 ] 28 | test_t_range: [ 0.75, 1.0 ] 29 | pde_sampler: SamplerWithDMIS 30 | pde_reweighting: BiasedReWeighting 31 | reweighting_params: 32 | k_init: 1.5 33 | k_final: 1.5 34 | iter_n: ${train_conf.main_conf.max_steps} 35 | optim_conf: 36 | lr: 5e-3 37 | main_conf: 38 | max_steps: 10000 39 | print_frequency: 1 40 | pde_batch_size: 10000 41 | initial_batch_size: 500 42 | boundary_batch_size: 500 43 | model_basic_save_name: ${name} 44 | sampler_conf: 45 | forward_batch_size: ${train_conf.main_conf.pde_batch_size} 46 | mesh_update_thres: 0.4 47 | addon_points: [ [ 0., 1. ], [ 0., -1. ], [ 0.5, 1. ], [ 0.5, -1. ] ] 48 | seed_n: 1000 49 | -------------------------------------------------------------------------------- /conf/Diffusion.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - global_conf: default 3 | - data_conf: default 4 | - model_conf: default 5 | - train_conf: default 6 | - _self_ 7 | 8 | name: Diffusion 9 | 10 | global_conf: 11 | seed: 1024 12 | 13 | model_conf: 14 | layer_n: 2 15 | layer_size: [16, 16] 16 | 17 | problem_conf: 18 | dims: 2 19 | x_range: [0, 1.] 
20 | t_range: [0., 1.] 21 | initial_cond: "2 * sin(pi * x) + 2 * (x - x**3)" 22 | boundary_cond: "0 * x" 23 | 24 | train_conf: 25 | train_t_range: [ 0, 0.5 ] 26 | eval_t_range: [ 0.5, 0.75 ] 27 | test_t_range: [ 0.75, 1.0 ] 28 | pde_sampler: SamplerWithDMIS 29 | pde_reweighting: BiasedReWeighting 30 | reweighting_params: 31 | k_init: 2 32 | k_final: 2 33 | iter_n: ${train_conf.main_conf.max_steps} 34 | optim_conf: 35 | lr: 2e-3 36 | main_conf: 37 | max_steps: 4000 38 | pde_batch_size: 20000 39 | initial_batch_size: 50 40 | boundary_batch_size: 50 41 | print_frequency: 1 42 | model_basic_save_name: ${name} 43 | sampler_conf: 44 | forward_batch_size: ${train_conf.main_conf.pde_batch_size} 45 | mesh_update_thres: 0.4 46 | addon_points: [[0., 1.], [0., 0.], [0.5, 0.], [0.5, 1.]] 47 | seed_n: 1000 48 | 49 | plot_conf: 50 | device: cuda 51 | cut_indxs: [10, 350, 500] 52 | pinn_path: outputs\compares\Diffusion\Uniform\9\Diffusion_3000.pth 53 | pinn_s_path: outputs\compares\Diffusion\Interpolation\9\Diffusion_1000.pth 54 | -------------------------------------------------------------------------------- /conf/KDV.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - global_conf: default 3 | - data_conf: default 4 | - model_conf: default 5 | - train_conf: default 6 | - _self_ 7 | 8 | name: KDV 9 | 10 | data_conf: 11 | initial_data_n: 200 12 | boundary_data_n: 200 13 | pde_data_n: 60000 14 | 15 | global_conf: 16 | seed: 1024 17 | 18 | model_conf: 19 | layer: 20 | layer_n: 4 21 | layer_size: [64, 64, 64, 64] 22 | 23 | problem_conf: 24 | dims: 2 25 | x_range: [-1, 1] 26 | t_range: [0., 1] 27 | initial_cond: cos(pi * x) 28 | boundary_cond: periodic 29 | 30 | train_conf: 31 | train_t_range: [ 0, 0.5 ] 32 | eval_t_range: [ 0.5, 0.75 ] 33 | test_t_range: [ 0.75, 1.0 ] 34 | pde_sampler: SamplerWithDMIS 35 | pde_reweighting: BiasedReWeighting 36 | reweighting_params: 37 | k_init: 2 38 | k_final: 2 39 | iter_n: ${train_conf.main_conf.max_steps} 40 | optim_conf: 41 | lr: 1e-3 42 | main_conf: 43 | print_frequency: 1 44 | max_steps: 50000 45 | pde_batch_size: 20000 46 | initial_batch_size: 50 47 | boundary_batch_size: 50 48 | model_basic_save_name: ${name} 49 | sampler_conf: 50 | forward_batch_size: ${train_conf.main_conf.pde_batch_size} 51 | mesh_update_thres: 0.4 52 | addon_points: [ [ 0., -1. ], [ 0., 1. ], [ 0.5, 1. ], [ 0.5, -1. 
] ] 53 | seed_n: 1000 54 | -------------------------------------------------------------------------------- /conf/Schrodinger.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - global_conf: default 3 | - data_conf: default 4 | - model_conf: default 5 | - train_conf: default 6 | - _self_ 7 | 8 | name: Schrodinger 9 | 10 | data_conf: 11 | initial_data_n: 200 12 | boundary_data_n: 200 13 | pde_data_n: 60000 14 | 15 | global_conf: 16 | seed: 10 17 | 18 | model_conf: 19 | dim: 20 | output_dim: 2 21 | layer: 22 | layer_n: 4 23 | layer_size: [64, 64, 64, 64] 24 | 25 | problem_conf: 26 | dims: 2 27 | x_range: [-5, 5] 28 | t_range: [0., 1.57] 29 | initial_cond: 4 / (E**x + E**(-x)) 30 | boundary_cond: periodic 31 | 32 | train_conf: 33 | train_t_range: [ 0, 0.9 ] 34 | eval_t_range: [ 0.9, 1.1 ] 35 | test_t_range: [ 1.1, 1.57 ] 36 | pde_sampler: SamplerWithDMIS 37 | pde_reweighting: BiasedReWeighting 38 | reweighting_params: 39 | k_init: 2 40 | k_final: 2 41 | iter_n: ${train_conf.main_conf.max_steps} 42 | optim_conf: 43 | lr: 1e-3 44 | main_conf: 45 | max_steps: 100000 46 | print_frequency: 1 47 | pde_batch_size: 20000 48 | initial_batch_size: 50 49 | boundary_batch_size: 50 50 | model_basic_save_name: ${name} 51 | sampler_conf: 52 | forward_batch_size: ${train_conf.main_conf.pde_batch_size} 53 | mesh_update_thres: 0.4 54 | addon_points: [ [ 0., -5. ], [ 0., 5. ], [ 0.9, -5.], [ 0.9, 5. ] ] 55 | seed_n: 1000 56 | -------------------------------------------------------------------------------- /conf/data_conf/default.yaml: -------------------------------------------------------------------------------- 1 | initial_data_n: 2000 2 | boundary_data_n: 2000 3 | pde_data_n: 100000 -------------------------------------------------------------------------------- /conf/evaluation_conf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - global_conf: default 3 | - model_conf: default 4 | - _self_ 5 | 6 | equation_conf: 7 | Burgers: 8 | layer_n: 3 9 | layer_size: [ 32, 32, 32 ] 10 | output_dim: 1 11 | weight_dict: 12 | PINN-O: "PINN-O/best.pth" 13 | PINN-N: "PINN-N/best.pth" 14 | PINN-DMIS: "PINN-DMIS/best.pth" 15 | KDV: 16 | layer_n: 4 17 | layer_size: [ 64, 64, 64, 64 ] 18 | output_dim: 1 19 | weight_dict: 20 | PINN-O: "PINN-O/best.pth" 21 | PINN-N: "PINN-N/best.pth" 22 | PINN-DMIS: "PINN-DMIS/best.pth" 23 | Schrodinger: 24 | layer_n: 4 25 | layer_size: [ 64, 64, 64, 64 ] 26 | output_dim: 2 27 | weight_dict: 28 | PINN-O: "PINN-O/best.pth" 29 | PINN-N: "PINN-N/best.pth" 30 | PINN-DMIS: "PINN-DMIS/best.pth" 31 | Diffusion: 32 | layer_n: 4 33 | layer_size: [ 32, 32, 32, 32 ] 34 | output_dim: 1 35 | weight_dict: 36 | PINN-O: "PINN-O/best.pth" 37 | PINN-N: "PINN-N/best.pth" 38 | PINN-DMIS: "PINN-DMIS/best.pth" 39 | ACEquation: 40 | layer_n: 5 41 | layer_size: [ 64, 64, 64, 64, 64 ] 42 | output_dim: 1 43 | weight_dict: 44 | PINN-O: "PINN-O/best.pth" 45 | PINN-N: "PINN-N/best.pth" 46 | PINN-DMIS: "PINN-DMIS/best.pth" 47 | 48 | evaluation_metrics: [ 49 | "max error", 50 | "mean absolute error", 51 | "RMSE" 52 | ] 53 | -------------------------------------------------------------------------------- /conf/global_conf/default.yaml: -------------------------------------------------------------------------------- 1 | seed: 1024 2 | device: cuda 3 | tensorboard_path: ./tensorboard_log -------------------------------------------------------------------------------- /conf/model_conf/default.yaml: 
-------------------------------------------------------------------------------- 1 | load_model: False 2 | model_path: "./" 3 | dim: 4 | input_dim: 2 5 | output_dim: 1 6 | layer: 7 | layer_n: 4 8 | activate: "tanh" 9 | final_activate: "Identify" 10 | norm: False 11 | layer_size: [32, 32, 32, 32] -------------------------------------------------------------------------------- /conf/train_conf/default.yaml: -------------------------------------------------------------------------------- 1 | train_t_range: None 2 | eval_t_range: None 3 | test_t_range: None 4 | pde_sampler: UniformSampler 5 | pde_reweighting: NoReWeighting 6 | optim_conf: 7 | lr: 1e-3 8 | betas: [ 0.9, 0.999 ] 9 | eps: 1e-8 10 | weight_decay: 0.0 11 | amsgrad: False 12 | main_conf: 13 | max_steps: 5000 14 | pde_batch_size: 10000 15 | initial_batch_size: 500 16 | boundary_batch_size: 500 17 | print_frequency: 100 18 | eval_frequency: 500 19 | model_save_folder: ./ 20 | model_basic_save_name: None -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin python3 2 | # encoding : utf-8 -*- 3 | # @author : Zijiang Yang 4 | # @file : evaluation_metrics_result.py 5 | # @Time : 2022/7/28 0:56 6 | import logging 7 | import os.path 8 | 9 | import hydra 10 | from hydra.utils import get_original_cwd 11 | import numpy as np 12 | import torch 13 | 14 | from utils.models import FullyConnectedNetwork 15 | 16 | 17 | @hydra.main(version_base=None, config_path="./conf", config_name="evaluation_conf") 18 | def evaluation_setup(cfg): 19 | log = logging.getLogger("evaluation") 20 | project_root = get_original_cwd() 21 | 22 | model_conf = cfg["model_conf"] 23 | equation_conf = cfg["equation_conf"] 24 | evaluation_metrics = cfg["evaluation_metrics"] 25 | 26 | device = torch.device("cuda") 27 | 28 | for equation_key, conf in equation_conf.items(): 29 | log.info(equation_key) 30 | 31 | log.info("create model and load weight") 32 | model_dict = dict() 33 | for model_key, weight_path in conf["weight_dict"].items(): 34 | if weight_path != "": 35 | weight_path = os.path.join(project_root, 36 | "pretrain/{}".format(equation_key), 37 | weight_path) 38 | model_conf["layer"]["layer_n"] = conf["layer_n"] 39 | model_conf["layer"]["layer_size"] = conf["layer_size"] 40 | model_conf["dim"]["output_dim"] = conf["output_dim"] 41 | model_dict[model_key] = FullyConnectedNetwork(model_conf).to(device) 42 | model_dict[model_key].load_state_dict(torch.load(weight_path)) 43 | 44 | log.info("load ground true") 45 | ground_true_path = os.path.join(project_root, "ground_true/{}.npz".format(equation_key)) 46 | ground_true_numpy = np.load(ground_true_path) 47 | 48 | x_input = ground_true_numpy["input_x"] 49 | t_input = ground_true_numpy["input_t"] 50 | max_t = np.max(t_input) 51 | output = ground_true_numpy["output"] 52 | test_data_indices = np.argwhere(t_input > max_t * 0.75).reshape(-1) 53 | 54 | test_data_x = x_input[test_data_indices].reshape(-1, 1) 55 | test_data_t = t_input[test_data_indices].reshape(-1, 1) 56 | test_data_ground_true = output[test_data_indices].reshape(-1, 1) 57 | 58 | test_data = np.concatenate([ 59 | test_data_t, 60 | test_data_x, 61 | test_data_ground_true 62 | ], axis=1) 63 | 64 | log.info("pred") 65 | input_tensor = torch.from_numpy(test_data[:, :2]).to(device=device, dtype=torch.float) 66 | pred_dict = dict() 67 | 68 | for model_key, model in model_dict.items(): 69 | _pred = model(input_tensor).detach().cpu().numpy() 
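# Single-output equations predict shape (N, 1); the Schrodinger model predicts two channels
# (real and imaginary parts), so their modulus is formed below before comparison with the stored solution.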
70 | if _pred.shape[1] == 1: 71 | _pred = _pred.reshape(-1) 72 | elif _pred.shape[1] == 2: 73 | _pred = np.sqrt(_pred[:, 0] ** 2 + _pred[:, 1] ** 2) 74 | pred_dict[model_key] = _pred 75 | 76 | log.info("evaluation") 77 | evaluation_dict = dict() 78 | for model_key, pred_result in pred_dict.items(): 79 | for metric in evaluation_metrics: 80 | if metric == "max error": 81 | evaluation_result = np.max(np.abs(pred_result - test_data[:, -1])) 82 | elif metric == "l2 norm": 83 | evaluation_result = np.linalg.norm(pred_result - test_data[:, -1]) 84 | elif metric == "RMSE": 85 | evaluation_result = np.sqrt(np.mean((pred_result - test_data[:, -1]) ** 2)) 86 | elif metric == "mean absolute error": 87 | evaluation_result = np.mean(np.abs(pred_result - test_data[:, -1])) 88 | else: 89 | raise KeyError 90 | evaluation_dict["{}_{}".format(model_key, metric)] = evaluation_result 91 | 92 | log.info("print evaluation result") 93 | for key, value in evaluation_dict.items(): 94 | log.info("{}: {:.5e}".format(key, value)) 95 | 96 | 97 | if __name__ == "__main__": 98 | evaluation_setup() 99 | -------------------------------------------------------------------------------- /ground_true/ACEquation.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/ground_true/ACEquation.npz -------------------------------------------------------------------------------- /ground_true/Burgers.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/ground_true/Burgers.npz -------------------------------------------------------------------------------- /ground_true/Diffusion.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/ground_true/Diffusion.npz -------------------------------------------------------------------------------- /ground_true/KDV.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/ground_true/KDV.npz -------------------------------------------------------------------------------- /ground_true/Schrodinger.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/ground_true/Schrodinger.npz -------------------------------------------------------------------------------- /images/DMIS-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/images/DMIS-diagram.png -------------------------------------------------------------------------------- /pretrain/ACEquation/PINN-DMIS/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/ACEquation/PINN-DMIS/best.pth -------------------------------------------------------------------------------- /pretrain/ACEquation/PINN-N/best.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/ACEquation/PINN-N/best.pth -------------------------------------------------------------------------------- /pretrain/ACEquation/PINN-O/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/ACEquation/PINN-O/best.pth -------------------------------------------------------------------------------- /pretrain/Burgers/PINN-DMIS/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/Burgers/PINN-DMIS/best.pth -------------------------------------------------------------------------------- /pretrain/Burgers/PINN-N/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/Burgers/PINN-N/best.pth -------------------------------------------------------------------------------- /pretrain/Burgers/PINN-O/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/Burgers/PINN-O/best.pth -------------------------------------------------------------------------------- /pretrain/Diffusion/PINN-DMIS/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/Diffusion/PINN-DMIS/best.pth -------------------------------------------------------------------------------- /pretrain/Diffusion/PINN-N/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/Diffusion/PINN-N/best.pth -------------------------------------------------------------------------------- /pretrain/Diffusion/PINN-O/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/Diffusion/PINN-O/best.pth -------------------------------------------------------------------------------- /pretrain/KDV/PINN-DMIS/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/KDV/PINN-DMIS/best.pth -------------------------------------------------------------------------------- /pretrain/KDV/PINN-N/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/KDV/PINN-N/best.pth -------------------------------------------------------------------------------- /pretrain/KDV/PINN-O/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/KDV/PINN-O/best.pth -------------------------------------------------------------------------------- /pretrain/Schrodinger/PINN-DMIS/best.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/Schrodinger/PINN-DMIS/best.pth -------------------------------------------------------------------------------- /pretrain/Schrodinger/PINN-N/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/Schrodinger/PINN-N/best.pth -------------------------------------------------------------------------------- /pretrain/Schrodinger/PINN-O/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatrixBrain/DMIS/cf2a4d8da58f4bf0dbb971826e47ff0011945690/pretrain/Schrodinger/PINN-O/best.pth -------------------------------------------------------------------------------- /scripts/plot_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def plot_train_process(data_dict: dict, if_log=True): 8 | """ 9 | plot convergence curves 10 | """ 11 | keys = list(data_dict.keys()) 12 | model_name_list = set([key.split("_")[0] for key in keys]) 13 | plot_data_dict = dict() 14 | for model_name in model_name_list: 15 | plot_data_dict["{}_step".format(model_name)] = None 16 | plot_data_dict["{}_time".format(model_name)] = None 17 | plot_data_dict["{}_data".format(model_name)] = list() 18 | 19 | for key, data in data_dict.items(): 20 | model_name = key.split("_")[0] 21 | if plot_data_dict["{}_step".format(model_name)] is None: 22 | plot_data_dict["{}_step".format(model_name)] = data["step"] 23 | plot_data_dict["{}_time".format(model_name)] = data["walltime"] - np.min(data["walltime"]) 24 | plot_data_dict["{}_data".format(model_name)].append(data["value"]) 25 | 26 | for model_name in model_name_list: 27 | key = "{}_data".format(model_name) 28 | model_name = key.split("_")[0] 29 | plot_data_dict[key] = np.array(plot_data_dict[key]) 30 | upper_list = np.max(plot_data_dict[key], axis=0) 31 | lower_list = np.min(plot_data_dict[key], axis=0) 32 | if if_log: 33 | mean_list = np.power(10, (np.log10(upper_list) + np.log10(lower_list)) / 2) 34 | else: 35 | mean_list = (upper_list + lower_list) / 2 36 | plot_data_dict["{}_mean".format(model_name)] = mean_list 37 | plot_data_dict["{}_upper".format(model_name)] = upper_list 38 | plot_data_dict["{}_lower".format(model_name)] = lower_list 39 | 40 | plt.figure(figsize=(16, 8)) 41 | iter_subplot = plt.subplot(1, 2, 1) 42 | if if_log: 43 | iter_subplot.set_yscale("log") 44 | for model_name in model_name_list: 45 | step_list = plot_data_dict["{}_step".format(model_name)] 46 | mean_list = plot_data_dict["{}_mean".format(model_name)] 47 | upper_list = plot_data_dict["{}_upper".format(model_name)] 48 | lower_list = plot_data_dict["{}_lower".format(model_name)] 49 | 50 | iter_subplot.plot(step_list, mean_list, label=model_name, linewidth=2) 51 | iter_subplot.fill_between(step_list, upper_list, lower_list, alpha=0.3) 52 | 53 | iter_subplot.set_xlabel("iterations") 54 | iter_subplot.set_ylabel("loss") 55 | iter_subplot.legend() 56 | 57 | time_subplot = plt.subplot(1, 2, 2) 58 | if if_log: 59 | time_subplot.set_yscale("log") 60 | for model_name in model_name_list: 61 | time_list = plot_data_dict["{}_time".format(model_name)] 62 | mean_list = 
plot_data_dict["{}_mean".format(model_name)] 63 | upper_list = plot_data_dict["{}_upper".format(model_name)] 64 | lower_list = plot_data_dict["{}_lower".format(model_name)] 65 | 66 | time_subplot.plot(time_list, mean_list, label=model_name, linewidth=2) 67 | time_subplot.fill_between(time_list, upper_list, lower_list, alpha=0.3) 68 | time_subplot.set_xlabel("time") 69 | time_subplot.set_ylabel("loss") 70 | time_subplot.legend() 71 | plt.tight_layout() 72 | plt.show() 73 | 74 | 75 | if __name__ == "__main__": 76 | 77 | # testing 78 | # curve 1: y = e^(-x) - 1 79 | # curve 2: y = e^(-2x) - 1 80 | 81 | test_data_dict = dict() 82 | test_basic_input = np.arange(200) * 0.1 83 | for i in range(10): 84 | test_model1_data = np.exp(-0.05*test_basic_input) - 1 + np.random.random(200) * np.exp(-0.07*test_basic_input) 85 | test_model1_data = (test_model1_data - np.min(test_model1_data) + 1e-1) * 1000 86 | test_model1_step = np.arange(200) * 50 87 | test_model1_time = np.arange(200) * 1.1 + np.random.random(200) 88 | 89 | test_data = { 90 | "value": test_model1_data, 91 | "step": test_model1_step, 92 | "walltime": test_model1_time 93 | } 94 | 95 | test_data_dict["model1_{}".format(i)] = test_data 96 | 97 | test_model2_data = np.exp(-0.2*test_basic_input) - 1 + np.random.random(200) * np.exp(-0.1*test_basic_input) 98 | test_model2_data = (test_model2_data - np.min(test_model2_data) + 1e-1) * 1000 99 | test_model2_step = np.arange(200) * 50 100 | test_model2_time = np.arange(200) * 1.3 + np.random.random(200) 101 | 102 | test_data = { 103 | "value": test_model2_data, 104 | "step": test_model2_step, 105 | "walltime": test_model2_time 106 | } 107 | test_data_dict["model2_{}".format(i)] = test_data 108 | 109 | plot_train_process(test_data_dict, if_log=True) 110 | 111 | 112 | -------------------------------------------------------------------------------- /scripts/tensorboard_data_export_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from tensorboard.backend.event_processing import event_accumulator 6 | import pandas as pd 7 | import argparse 8 | 9 | 10 | def _to_csv(export_path, save_data, columns, index=False): 11 | df = pd.DataFrame(data=save_data, columns=columns) 12 | df.to_csv(export_path, index=index) 13 | 14 | 15 | def single_event_scalars_export_tool(in_path, ex_path): 16 | event_data = event_accumulator.EventAccumulator(in_path) 17 | event_data.Reload() 18 | 19 | keys = event_data.scalars.Keys() 20 | print("keys list: {}".format(keys)) 21 | 22 | export_items = ["step", "wall_time", "value"] 23 | for key in keys: 24 | print("process key: {}".format(key)) 25 | export_names = ["{}_{}".format(key, item) for item in export_items] 26 | save_data = list() 27 | for e in event_data.Scalars(key): 28 | temp_data = [e.step, e.wall_time, e.value] 29 | save_data.append(temp_data) 30 | 31 | save_name = "{}.csv".format(key.replace("/", "_")) 32 | save_path = os.path.join(ex_path, save_name) 33 | _to_csv(save_path, save_data, export_names) 34 | print("export data of {} done...".format(key)) 35 | 36 | 37 | def multi_event_scalars_export_tool(in_path, ex_path): 38 | root, dirs, _ = next(os.walk(in_path)) 39 | for _dir in dirs: 40 | print("process summary: {}".format(_dir)) 41 | event_in_path = os.path.join(root, _dir, "tensorboard_log") 42 | event_ex_path = os.path.join(ex_path, _dir) 43 | if not os.path.exists(event_ex_path): 44 | os.makedirs(event_ex_path) 45 | single_event_scalars_export_tool(event_in_path, 
event_ex_path) 46 | 47 | 48 | def single_event_images_export_tool(): 49 | pass 50 | 51 | 52 | def multi_event_images_export_tool(): 53 | pass 54 | 55 | 56 | if __name__ == "__main__": 57 | 58 | parser = argparse.ArgumentParser(description="Export tensorboard data") 59 | parser.add_argument("--in_path", type=str, help="tensorboard file location") 60 | parser.add_argument("--ex_path", type=str, default="./", help="export path") 61 | parser.add_argument("--state", type=str, default="single", help="single summary or multi summaries") 62 | parser.add_argument("--ex_type", type=str, default="scales", help="export data type") 63 | args = parser.parse_args() 64 | 65 | if args.state == "single": 66 | if args.ex_type == "scales": 67 | single_event_scalars_export_tool(args.in_path, args.ex_path) 68 | else: 69 | raise KeyError 70 | elif args.state == "multi": 71 | if args.ex_type == "scales": 72 | multi_event_scalars_export_tool(args.in_path, args.ex_path) 73 | else: 74 | raise KeyError 75 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from mpl_toolkits.axes_grid1 import make_axes_locatable 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | 8 | from utils.plot_utils import interpolate_2d 9 | 10 | 11 | def test2d(model, t_range, x_range, *, delta_t=0.01, delta_x=0.01, ground_true=None): 12 | 13 | model.eval() 14 | x_input = np.arange(x_range[0], x_range[1], delta_x) 15 | t_input = np.arange(t_range[0], t_range[1], delta_t) 16 | 17 | xx, tt = np.meshgrid(x_input, t_input) 18 | input_array = np.concatenate([tt.reshape(-1, 1), xx.reshape(-1, 1)], axis=1) 19 | input_tensor = torch.from_numpy(input_array).to(device=torch.device("cuda"), dtype=torch.float) 20 | 21 | pred = model(input_tensor).detach().cpu().numpy() 22 | if pred.shape[1] == 1: 23 | pred = pred.reshape(-1) 24 | elif pred.shape[1] == 2: 25 | pred = np.sqrt(pred[:, 0] ** 2 + pred[:, 1] ** 2) 26 | else: 27 | raise ValueError("don't support {} dims".format(pred.shape[1])) 28 | 29 | pred_extent, pred_image, pred_mesh = interpolate_2d(input_array, pred) 30 | 31 | plt.figure(figsize=(8, 8)) 32 | plt.pcolor(pred_mesh[0], pred_mesh[1], pred_image, cmap="rainbow") 33 | plt.xlabel("t") 34 | plt.ylabel("x") 35 | divider = make_axes_locatable(plt.gca()) 36 | cax = divider.append_axes("right", size="5%", pad="3%") 37 | plt.colorbar(cax=cax) 38 | plt.tight_layout() 39 | plt.savefig("./pred.png") 40 | plt.close() 41 | 42 | if ground_true is not None: 43 | gt_extent, gt_image, gt_mesh = interpolate_2d(ground_true[:, 0:2], ground_true[:, 2]) 44 | 45 | plt.figure(figsize=(8, 8)) 46 | plt.pcolor(gt_mesh[0], gt_mesh[1], gt_image, cmap="rainbow") 47 | # plt.imshow(ground_true_mesh.T, origin="lower", extent=ground_true_extent) 48 | plt.xlabel("t") 49 | plt.ylabel("x") 50 | divider = make_axes_locatable(plt.gca()) 51 | cax = divider.append_axes("right", size="5%", pad="3%") 52 | plt.colorbar(cax=cax) 53 | plt.tight_layout() 54 | plt.savefig("./ground_true.png") 55 | plt.close() 56 | 57 | difference_image = gt_image - pred_image 58 | plt.figure(figsize=(8, 8)) 59 | plt.pcolor(gt_mesh[0], gt_mesh[1], difference_image, cmap="rainbow") 60 | # plt.imshow(difference_mesh.T, origin="lower", extent=ground_true_extent) 61 | plt.xlabel("t") 62 | plt.ylabel("x") 63 | divider = make_axes_locatable(plt.gca()) 64 | cax = divider.append_axes("right", size="5%", pad="3%") 65 | 
plt.colorbar(cax=cax) 66 | plt.tight_layout() 67 | plt.savefig("./difference.png") 68 | plt.close() 69 | 70 | 71 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import queue 5 | import hydra 6 | import random 7 | import logging 8 | import numpy as np 9 | from omegaconf import OmegaConf 10 | from hydra.utils import get_original_cwd 11 | 12 | import torch 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from utils.equations import equation_dict 16 | from utils.data_utils import split_data 17 | from utils.samplers import sampler_dict 18 | from utils.reweightings import reweighting_dict 19 | from utils.models import FullyConnectedNetwork, model_saver 20 | from test import test2d 21 | 22 | 23 | @hydra.main(version_base=None, config_path="./conf", config_name="ACEquation") 24 | def train_setup(cfg): 25 | log = logging.getLogger("Train") 26 | problem_conf = cfg["problem_conf"] 27 | global_conf = cfg["global_conf"] 28 | model_conf = cfg["model_conf"] 29 | train_conf = cfg["train_conf"] 30 | data_conf = cfg["data_conf"] 31 | tensorboard_writer = SummaryWriter(cfg["global_conf"]["tensorboard_path"]) 32 | log.info(OmegaConf.to_yaml(cfg)) 33 | 34 | # --------------- 35 | # global 36 | # --------------- 37 | if global_conf["seed"]: 38 | np.random.seed(global_conf["seed"]) 39 | random.seed(global_conf["seed"]) 40 | torch.manual_seed(global_conf["seed"]) 41 | torch.cuda.manual_seed(global_conf["seed"]) 42 | 43 | device = torch.device(global_conf["device"]) 44 | log.info(f"device: {device}") 45 | 46 | # ------------- 47 | # model 48 | # ------------- 49 | log.info("create model...") 50 | model = FullyConnectedNetwork(model_conf) 51 | model.to(device) 52 | log.info(model) 53 | if model_conf.load_model: 54 | log.info("load weights") 55 | model.load_state_dict(torch.load(model_conf.model_path)) 56 | log.info("load done...") 57 | 58 | # ------------ 59 | # create data 60 | # ------------ 61 | problem_define = equation_dict[cfg["name"]](problem_conf, data_conf) # create data_manager 62 | problem_define.data_generator(global_conf["seed"]) # create dataset 63 | log.info("create problem data successful...") 64 | 65 | # --------------------- 66 | # split training, validating, testing 67 | # --------------------- 68 | split_t_dict = { 69 | "train": train_conf["train_t_range"], 70 | "eval": train_conf["eval_t_range"], 71 | "test": train_conf["test_t_range"] 72 | } 73 | boundary_data_split_result = split_data(problem_define.boundary_data, split_t_dict, 0) 74 | pde_data_split_result = split_data(problem_define.pde_data, split_t_dict, 0) 75 | 76 | log.info("split dataset successful...") 77 | 78 | # --------- 79 | # create sampler 80 | # --------- 81 | # train data sampler 82 | train_initial_tensor = torch.from_numpy(problem_define.initial_data).to(device=device, dtype=torch.float) 83 | train_boundary_tensor = torch.from_numpy(boundary_data_split_result["train"]).to(device=device, dtype=torch.float) 84 | train_pde_tensor = torch.from_numpy(pde_data_split_result["train"]).to(device=device, dtype=torch.float) 85 | train_pde_tensor.requires_grad = True 86 | if problem_conf["boundary_cond"] == "periodic": 87 | train_boundary_tensor.requires_grad = True 88 | 89 | train_pde_sampler = sampler_dict[train_conf["pde_sampler"]]( 90 | train_pde_tensor, reweighting_dict[train_conf["pde_reweighting"]](train_conf["reweighting_params"]), 91 | 
model=model, 92 | loss_func=problem_define.compute_loss_basic_weights, 93 | **train_conf["sampler_conf"] 94 | ) 95 | train_initial_sampler = sampler_dict["UniformSampler"](train_initial_tensor, reweighting_dict["NoReWeighting"]()) 96 | train_boundary_sampler = sampler_dict["UniformSampler"](train_boundary_tensor, reweighting_dict["NoReWeighting"]()) 97 | 98 | # validate data 99 | project_root = get_original_cwd() 100 | ground_true_numpy = np.load("{}/ground_true/{}.npz".format(project_root, cfg["name"])) 101 | 102 | x_input = ground_true_numpy["input_x"].reshape(-1, 1) 103 | t_input = ground_true_numpy["input_t"].reshape(-1, 1) 104 | output = ground_true_numpy["output"].reshape(-1, 1) 105 | ground_true = np.concatenate([t_input, x_input, output], axis=1) 106 | 107 | # ground true 108 | ground_true_split_data = split_data(ground_true, split_t_dict, 0) 109 | for key, data in ground_true_split_data.items(): 110 | ground_true_split_data[key] = torch.from_numpy(data).to(device=torch.device("cuda"), dtype=torch.float) 111 | 112 | # ------------- 113 | # optimizer 114 | # ------------- 115 | optim = torch.optim.Adam(model.parameters(), **train_conf["optim_conf"]) 116 | 117 | # ------------- 118 | # main loop 119 | # ------------- 120 | best_eval_loss = 1e6 121 | best_model_save_path = None 122 | train_main_conf = train_conf["main_conf"] 123 | model_save_queue = queue.Queue(maxsize=5) 124 | for step in range(train_main_conf["max_steps"]): 125 | 126 | train_pde_data = train_pde_sampler.sampler(train_main_conf["pde_batch_size"]) 127 | train_initial_data = train_initial_sampler.sampler(train_main_conf["initial_batch_size"]) 128 | train_boundary_data = train_boundary_sampler.sampler(train_main_conf["boundary_batch_size"]) 129 | 130 | optim.zero_grad() 131 | loss_dict = problem_define.compute_loss(model, train_pde_data, train_initial_data, train_boundary_data, "train") 132 | optim.step() 133 | 134 | if step % train_main_conf["print_frequency"] == 0: 135 | log.info(f"step: {step}") 136 | for key, value in loss_dict.items(): 137 | log.info("{} loss: {:.5e}".format(key, value)) 138 | tensorboard_writer.add_scalar(f"TrainLoss/{key}", value, step) 139 | 140 | if step % train_main_conf["eval_frequency"] == 0: 141 | log.info("evaluation") 142 | model.eval() 143 | 144 | # evaluation 145 | loss_dict = dict() 146 | for key, data in ground_true_split_data.items(): 147 | _pred = model(data[:, 0:2]) 148 | if _pred.shape[1] == 2: 149 | _pred = torch.sqrt(_pred[:, 0:1] ** 2 + _pred[:, 1:2] ** 2) 150 | _error = torch.abs(_pred - data[:, 2:3]) 151 | _absolute_error = torch.mean(_error).item() 152 | _l2_error = torch.mean(_error**2).item() 153 | _peak_error = torch.max(_error).item() 154 | log.info("{} area: peak error:{:.4e}, " 155 | "absolute error:{:.4e}, " 156 | "l2 error:{:.4e}".format(key, _peak_error, _absolute_error, _l2_error)) 157 | 158 | tensorboard_writer.add_scalar(f"Error/{key} peak", _peak_error, step) 159 | tensorboard_writer.add_scalar(f"Error/{key} l2", _l2_error, step) 160 | tensorboard_writer.add_scalar(f"Error/{key} absolute", _absolute_error, step) 161 | 162 | loss_dict[key] = _l2_error 163 | 164 | if best_eval_loss > loss_dict["eval"]: 165 | best_eval_loss = loss_dict["eval"] 166 | best_model_save_path = model_saver( 167 | save_folder=train_main_conf["model_save_folder"], 168 | model=model, 169 | save_name=train_main_conf["model_basic_save_name"], 170 | step=step 171 | ) 172 | 173 | if model_save_queue.full(): 174 | del_step = model_save_queue.get() 175 | del_path = 
os.path.join(train_main_conf["model_save_folder"], 176 | "{}_{}.pth".format(train_main_conf["model_basic_save_name"], del_step)) 177 | os.remove(del_path) 178 | 179 | model_save_queue.put(step) 180 | 181 | model.train() 182 | 183 | log.info("train done...") 184 | 185 | # --------- 186 | # testing 187 | # --------- 188 | log.info("begin test...") 189 | model.load_state_dict(torch.load(best_model_save_path)) 190 | 191 | if problem_conf["dims"] == 2: 192 | 193 | test2d(model, problem_conf["t_range"], problem_conf["x_range"], ground_true=ground_true) 194 | log.info("test done...") 195 | 196 | 197 | if __name__ == "__main__": 198 | train_setup() 199 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numbers 4 | import numpy as np 5 | from functools import singledispatch 6 | 7 | 8 | @singledispatch 9 | def uniform_sampling(minmax_range, data_n): 10 | _temp_data = np.random.random_sample((data_n, 1)) 11 | _temp_data = (minmax_range[1] - minmax_range[0]) * _temp_data + minmax_range[0] 12 | return _temp_data 13 | 14 | 15 | @uniform_sampling.register(numbers.Real) 16 | def _(_value, data_n): 17 | _temp_data = np.ones((data_n, 1)) * _value 18 | return _temp_data 19 | 20 | 21 | @uniform_sampling.register(tuple) 22 | def _(_value, data_n): 23 | random_mask = np.random.randn(data_n, 1) > 0 24 | _temp_data = random_mask * _value[0] + (1-random_mask) * _value[1] 25 | return _temp_data 26 | 27 | 28 | def split_data(data, t_split_dict, t_dim_index): 29 | 30 | split_result = dict() 31 | for split_key, t_list in t_split_dict.items(): 32 | _temp_data_indices = np.argwhere((data[:, t_dim_index] > t_list[0]) & 33 | (data[:, t_dim_index] < t_list[1])).reshape(-1) 34 | 35 | split_result[split_key] = data[_temp_data_indices, :] 36 | 37 | return split_result 38 | -------------------------------------------------------------------------------- /utils/equations/ACEquation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from ..pde_utils import fwd_gradients 9 | from .basic_define import problem_decorator, ProblemDefine2d 10 | 11 | 12 | @problem_decorator 13 | class ACEquation(ProblemDefine2d): 14 | 15 | def __init__(self, problem_conf, data_conf): 16 | super(ACEquation, self).__init__(problem_conf, data_conf) 17 | 18 | def data_generator(self, random_state=None): 19 | """ 20 | create data from domain 21 | """ 22 | if random_state is not None: 23 | random.seed(random_state) 24 | np.random.seed(random_state) 25 | 26 | self.create_boundary_data() 27 | self.create_initial_data() 28 | self.create_pde_data() 29 | 30 | def create_boundary_data(self): 31 | self._create_boundary_data_periodic() 32 | 33 | def pde_loss(self, pred, input_tensor): 34 | df_dt_dx = fwd_gradients(pred, input_tensor) 35 | df_dt = df_dt_dx[:, 0:1] 36 | df_dx = df_dt_dx[:, 1:2] 37 | 38 | df_dxx = fwd_gradients(df_dx, input_tensor)[:, 1:2] 39 | pde_output = df_dt - 0.0001 * df_dxx + 5 * pred ** 3 - 5 * pred 40 | return pde_output 41 | 42 | def boundary_loss(self, input_lower, input_upper, pred_lower, pred_upper): 43 | df_dx_lower = fwd_gradients(pred_lower, input_lower)[:, 1:2] 44 | df_dx_upper = fwd_gradients(pred_upper, input_upper)[:, 1:2] 45 | 46 | boundary_value_loss = torch.mean((pred_lower - pred_upper) ** 2) 47 | boundary_gradient_loss 
= torch.mean((df_dx_upper - df_dx_lower) ** 2) 48 | 49 | return boundary_gradient_loss + boundary_value_loss 50 | 51 | def compute_loss(self, model, pde_data, initial_data, boundary_data, state="train"): 52 | loss_dict = dict() 53 | if state == "train": 54 | 55 | # ------------- 56 | # initial conditions 57 | # ------------- 58 | _initial_data = initial_data["data"] 59 | initial_input = _initial_data[:, 0:2] 60 | initial_ground_true = _initial_data[:, 2:3] 61 | initial_pred = model(initial_input) 62 | initial_loss = torch.mean((initial_pred - initial_ground_true)**2) 63 | 64 | # ------------- 65 | # boundary conditions 66 | # ------------- 67 | _boundary_data = boundary_data["data"] 68 | boundary_input_lower = _boundary_data[:, [0, 1]] 69 | boundary_input_upper = _boundary_data[:, [0, 2]] 70 | boundary_pred_lower = model(boundary_input_lower) 71 | boundary_pred_upper = model(boundary_input_upper) 72 | boundary_loss = self.boundary_loss( 73 | boundary_input_lower, 74 | boundary_input_upper, 75 | boundary_pred_lower, 76 | boundary_pred_upper 77 | ) 78 | # ------------- 79 | # pde 80 | # ------------- 81 | _pde_data = pde_data["data"] 82 | _pde_weight = pde_data["weights"] / torch.sum(pde_data["weights"]) 83 | pde_pred = model(_pde_data) 84 | _pde_weight = torch.reshape(_pde_weight, pde_pred.shape) 85 | pde_loss = torch.sum((self.pde_loss(pde_pred, _pde_data)**2).mul(_pde_weight)) 86 | 87 | total_loss = initial_loss + boundary_loss + pde_loss 88 | total_loss.backward() 89 | 90 | loss_dict["pde"] = pde_loss.item() 91 | loss_dict["initial"] = initial_loss.item() 92 | loss_dict["boundary"] = boundary_loss.item() 93 | loss_dict["total"] = total_loss.item() 94 | 95 | elif state == "eval": 96 | boundary_input = boundary_data[:, 0:2] 97 | boundary_ground_true = boundary_data[:, 2:3] 98 | boundary_pred = model(boundary_input) 99 | boundary_loss = torch.mean((boundary_pred - boundary_ground_true) ** 2) 100 | 101 | pde_pred = model(pde_data) 102 | pde_loss = torch.mean(self.pde_loss(pde_pred, pde_data) ** 2) 103 | 104 | total_loss = boundary_loss + pde_loss 105 | 106 | loss_dict["pde"] = pde_loss.item() 107 | loss_dict["boundary"] = boundary_loss.item() 108 | loss_dict["total"] = total_loss.item() 109 | 110 | return loss_dict 111 | 112 | def compute_loss_basic_weights(self, model, data): 113 | pde_pred = model(data) 114 | pde_loss = torch.abs(self.pde_loss(pde_pred, data)) 115 | return pde_loss -------------------------------------------------------------------------------- /utils/equations/Burgers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from ..pde_utils import fwd_gradients 9 | from .basic_define import problem_decorator, ProblemDefine2d 10 | 11 | 12 | @problem_decorator 13 | class Burgers(ProblemDefine2d): 14 | 15 | def __init__(self, problem_conf, data_conf): 16 | super(Burgers, self).__init__(problem_conf, data_conf) 17 | 18 | def data_generator(self, random_state=None): 19 | """ 20 | create data from domain 21 | """ 22 | if random_state is not None: 23 | random.seed(random_state) 24 | np.random.seed(random_state) 25 | 26 | self.create_boundary_data() 27 | self.create_initial_data() 28 | self.create_pde_data() 29 | 30 | def pde_loss(self, pred, input_tensor): 31 | df_dt_dx = fwd_gradients(pred, input_tensor) 32 | df_dt = df_dt_dx[:, 0:1] 33 | df_dx = df_dt_dx[:, 1:2] 34 | 35 | df_dxx = fwd_gradients(df_dx, input_tensor)[:, 1:2] 36 | 37 | pde_output = 
df_dt + pred * df_dx - (0.04/torch.pi) * df_dxx 38 | return pde_output 39 | 40 | def compute_loss(self, model, pde_data, initial_data, boundary_data, state="train"): 41 | loss_dict = dict() 42 | if state == "train": 43 | 44 | # ------------- 45 | # initial conditions 46 | # ------------- 47 | _initial_data = initial_data["data"] 48 | initial_input = _initial_data[:, 0:2] 49 | initial_ground_true = _initial_data[:, 2:3] 50 | initial_pred = model(initial_input) 51 | initial_loss = torch.mean((initial_pred - initial_ground_true)**2) 52 | 53 | # ------------- 54 | # boundary conditions 55 | # ------------- 56 | _boundary_data = boundary_data["data"] 57 | boundary_input = _boundary_data[:, 0:2] 58 | boundary_ground_true = _boundary_data[:, 2:3] 59 | boundary_pred = model(boundary_input) 60 | boundary_loss = torch.mean((boundary_pred - boundary_ground_true)**2) 61 | 62 | # ------------- 63 | # pde 64 | # ------------- 65 | _pde_data = pde_data["data"] 66 | _pde_weight = pde_data["weights"] / torch.sum(pde_data["weights"]) 67 | pde_pred = model(_pde_data) 68 | _pde_weight = torch.reshape(_pde_weight, pde_pred.shape) 69 | pde_loss = torch.sum((self.pde_loss(pde_pred, _pde_data)**2).mul(_pde_weight)) 70 | 71 | total_loss = initial_loss + boundary_loss + pde_loss 72 | total_loss.backward() 73 | 74 | loss_dict["pde"] = pde_loss.item() 75 | loss_dict["initial"] = initial_loss.item() 76 | loss_dict["boundary"] = boundary_loss.item() 77 | loss_dict["total"] = total_loss.item() 78 | 79 | return loss_dict 80 | 81 | def compute_loss_basic_weights(self, model, data): 82 | pde_pred = model(data) 83 | pde_loss = torch.abs(self.pde_loss(pde_pred, data)) 84 | return pde_loss 85 | 86 | -------------------------------------------------------------------------------- /utils/equations/Diffusion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from ..pde_utils import fwd_gradients 9 | from .basic_define import problem_decorator, ProblemDefine2d 10 | 11 | 12 | @problem_decorator 13 | class Diffusion(ProblemDefine2d): 14 | 15 | def __init__(self, problem_conf, data_conf): 16 | super(Diffusion, self).__init__(problem_conf, data_conf) 17 | 18 | def data_generator(self, random_state=None): 19 | """ 20 | create data from domain 21 | """ 22 | if random_state is not None: 23 | random.seed(random_state) 24 | np.random.seed(random_state) 25 | 26 | self.create_boundary_data() 27 | self.create_initial_data() 28 | self.create_pde_data() 29 | 30 | def pde_loss(self, pred, input_tensor): 31 | df_dt_dx = fwd_gradients(pred, input_tensor) 32 | df_dt = df_dt_dx[:, 0:1] 33 | df_dx = df_dt_dx[:, 1:2] 34 | 35 | df_dxx = fwd_gradients(df_dx, input_tensor)[:, 1:2] 36 | 37 | pde_output = df_dt - 1.2 * df_dxx - 5 * input_tensor[:, 1:2] * torch.exp(-input_tensor[:, 0:1]) 38 | return pde_output 39 | 40 | def compute_loss(self, model, pde_data, initial_data, boundary_data, state="train"): 41 | loss_dict = dict() 42 | if state == "train": 43 | 44 | # ------------- 45 | # initial conditions 46 | # ------------- 47 | _initial_data = initial_data["data"] 48 | initial_input = _initial_data[:, 0:2] 49 | initial_ground_true = _initial_data[:, 2:3] 50 | initial_pred = model(initial_input) 51 | initial_loss = torch.mean((initial_pred - initial_ground_true)**2) 52 | 53 | # ------------- 54 | # boundary conditions 55 | # ------------- 56 | _boundary_data = boundary_data["data"] 57 | boundary_input = _boundary_data[:, 
0:2] 58 | boundary_ground_true = _boundary_data[:, 2:3] 59 | boundary_pred = model(boundary_input) 60 | boundary_loss = torch.mean((boundary_pred - boundary_ground_true)**2) 61 | 62 | # ------------- 63 | # pde 64 | # ------------- 65 | _pde_data = pde_data["data"] 66 | _pde_weight = pde_data["weights"] / torch.sum(pde_data["weights"]) 67 | pde_pred = model(_pde_data) 68 | _pde_weight = torch.reshape(_pde_weight, pde_pred.shape) 69 | pde_loss = torch.sum((self.pde_loss(pde_pred, _pde_data)**2).mul(_pde_weight)) 70 | 71 | total_loss = initial_loss + boundary_loss + pde_loss 72 | total_loss.backward() 73 | 74 | loss_dict["pde"] = pde_loss.item() 75 | loss_dict["initial"] = initial_loss.item() 76 | loss_dict["boundary"] = boundary_loss.item() 77 | loss_dict["total"] = total_loss.item() 78 | 79 | return loss_dict 80 | 81 | def compute_loss_basic_weights(self, model, data): 82 | pde_pred = model(data) 83 | pde_loss = torch.abs(self.pde_loss(pde_pred, data)) 84 | return pde_loss -------------------------------------------------------------------------------- /utils/equations/KDV.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from ..pde_utils import fwd_gradients 9 | from .basic_define import problem_decorator, ProblemDefine2d 10 | 11 | 12 | @problem_decorator 13 | class KDV(ProblemDefine2d): 14 | 15 | def __init__(self, problem_conf, data_conf): 16 | super(KDV, self).__init__(problem_conf, data_conf) 17 | 18 | def data_generator(self, random_state=None): 19 | """ 20 | create data from domain 21 | """ 22 | if random_state is not None: 23 | random.seed(random_state) 24 | np.random.seed(random_state) 25 | 26 | self.create_boundary_data() 27 | self.create_initial_data() 28 | self.create_pde_data() 29 | 30 | def create_boundary_data(self): 31 | self._create_boundary_data_periodic() 32 | 33 | def pde_loss(self, pred, input_tensor): 34 | df_dt_dx = fwd_gradients(pred, input_tensor) 35 | df_dt = df_dt_dx[:, 0:1] 36 | df_dx = df_dt_dx[:, 1:2] 37 | 38 | df_dxx = fwd_gradients(df_dx, input_tensor)[:, 1:2] 39 | df_dxxx = fwd_gradients(df_dxx, input_tensor)[:, 1:2] 40 | 41 | pde_output = df_dt + pred * df_dx + 0.0025 * df_dxxx 42 | return pde_output 43 | 44 | def boundary_loss(self, input_lower, input_upper, pred_lower, pred_upper): 45 | df_dx_lower = fwd_gradients(pred_lower, input_lower)[:, 1:2] 46 | df_dx_upper = fwd_gradients(pred_upper, input_upper)[:, 1:2] 47 | 48 | boundary_value_loss = torch.mean((pred_lower - pred_upper) ** 2) 49 | boundary_gradient_loss = torch.mean((df_dx_upper - df_dx_lower) ** 2) 50 | 51 | return boundary_gradient_loss + boundary_value_loss 52 | 53 | def compute_loss(self, model, pde_data, initial_data, boundary_data, state="train"): 54 | loss_dict = dict() 55 | if state == "train": 56 | 57 | # ------------- 58 | # initial conditions 59 | # ------------- 60 | _initial_data = initial_data["data"] 61 | initial_input = _initial_data[:, 0:2] 62 | initial_ground_true = _initial_data[:, 2:3] 63 | initial_pred = model(initial_input) 64 | initial_loss = torch.mean((initial_pred - initial_ground_true)**2) 65 | 66 | # ------------- 67 | # boundary conditions 68 | # ------------- 69 | _boundary_data = boundary_data["data"] 70 | boundary_input_lower = _boundary_data[:, [0, 1]] 71 | boundary_input_upper = _boundary_data[:, [0, 2]] 72 | boundary_pred_lower = model(boundary_input_lower) 73 | boundary_pred_upper = model(boundary_input_upper) 74 | 
boundary_loss = self.boundary_loss( 75 | boundary_input_lower, 76 | boundary_input_upper, 77 | boundary_pred_lower, 78 | boundary_pred_upper 79 | ) 80 | # ------------- 81 | # pde 82 | # ------------- 83 | _pde_data = pde_data["data"] 84 | _pde_weight = pde_data["weights"] / torch.sum(pde_data["weights"]) 85 | pde_pred = model(_pde_data) 86 | _pde_weight = torch.reshape(_pde_weight, pde_pred.shape) 87 | pde_loss = torch.sum((self.pde_loss(pde_pred, _pde_data)**2).mul(_pde_weight)) 88 | 89 | total_loss = initial_loss + boundary_loss + pde_loss 90 | total_loss.backward() 91 | 92 | loss_dict["pde"] = pde_loss.item() 93 | loss_dict["initial"] = initial_loss.item() 94 | loss_dict["boundary"] = boundary_loss.item() 95 | loss_dict["total"] = total_loss.item() 96 | 97 | return loss_dict 98 | 99 | def compute_loss_basic_weights(self, model, data): 100 | pde_pred = model(data) 101 | pde_loss = torch.abs(self.pde_loss(pde_pred, data)) 102 | return pde_loss 103 | -------------------------------------------------------------------------------- /utils/equations/Schrodinger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import random 4 | import numpy as np 5 | import sympy as sym 6 | from sympy.parsing.sympy_parser import parse_expr 7 | from ..data_utils import uniform_sampling 8 | 9 | import torch 10 | 11 | from ..pde_utils import fwd_gradients 12 | from .basic_define import problem_decorator, ProblemDefine2d 13 | 14 | 15 | @problem_decorator 16 | class Schrodinger(ProblemDefine2d): 17 | 18 | def __init__(self, problem_conf, data_conf): 19 | super(Schrodinger, self).__init__(problem_conf, data_conf) 20 | 21 | def data_generator(self, random_state=None): 22 | """ 23 | create data from domain 24 | """ 25 | if random_state is not None: 26 | random.seed(random_state) 27 | np.random.seed(random_state) 28 | 29 | self.create_boundary_data() 30 | self.create_initial_data() 31 | self.create_pde_data() 32 | 33 | def create_boundary_data(self): 34 | self._create_boundary_data_periodic() 35 | 36 | def create_initial_data(self): 37 | x = sym.symbols("x") 38 | t = sym.symbols("t") 39 | _temp_x_data = uniform_sampling(list(self.x_range), self.initial_data_n) 40 | _temp_t_data = uniform_sampling(self.t_range[0], self.initial_data_n) 41 | 42 | _expr_real = parse_expr(self.initial_condition, evaluate=False) 43 | _expr_real = sym.lambdify((x, t), _expr_real, "numpy") 44 | _ground_true_real = _expr_real(_temp_x_data, _temp_t_data) 45 | _ground_true_imaginary = np.zeros_like(_ground_true_real) 46 | self.initial_data = np.concatenate( 47 | ( 48 | _temp_t_data, 49 | _temp_x_data, 50 | _ground_true_real, 51 | _ground_true_imaginary 52 | ), 53 | axis=1 54 | ) 55 | 56 | def pde_loss(self, pred, input_tensor): 57 | 58 | pred_real = pred[:, 0:1] 59 | pred_imag = pred[:, 1:2] 60 | 61 | df_dt_dx_real = fwd_gradients(pred_real, input_tensor) 62 | df_dt_real = df_dt_dx_real[:, 0:1] 63 | df_dx_real = df_dt_dx_real[:, 1:2] 64 | 65 | df_dt_dx_imag = fwd_gradients(pred_imag, input_tensor) 66 | df_dt_imag = df_dt_dx_imag[:, 0:1] 67 | df_dx_imag = df_dt_dx_imag[:, 1:2] 68 | 69 | df_dxx_real = fwd_gradients(df_dx_real, input_tensor)[:, 1:2] 70 | df_dxx_imag = fwd_gradients(df_dx_imag, input_tensor)[:, 1:2] 71 | 72 | pde_output_real = -df_dt_imag + 0.5 * df_dxx_real + (pred_real ** 2 + pred_imag ** 2) * pred_real 73 | pde_output_imag = df_dt_real + 0.5 * df_dxx_imag + (pred_real ** 2 + pred_imag ** 2) * pred_imag 74 | return pde_output_real, pde_output_imag 75 | 76 | 
def boundary_loss(self, input_lower, input_upper, pred_lower, pred_upper): 77 | 78 | df_dx_lower_real = fwd_gradients(pred_lower[:, 0:1], input_lower)[:, 1:2] 79 | df_dx_lower_imag = fwd_gradients(pred_lower[:, 1:2], input_lower)[:, 1:2] 80 | 81 | df_dx_upper_real = fwd_gradients(pred_upper[:, 0:1], input_upper)[:, 1:2] 82 | df_dx_upper_imag = fwd_gradients(pred_upper[:, 1:2], input_upper)[:, 1:2] 83 | 84 | boundary_value_loss_real = torch.mean((pred_lower[:, 0:1] - pred_upper[:, 0:1]) ** 2) 85 | boundary_value_loss_imag = torch.mean((pred_lower[:, 1:2] - pred_upper[:, 1:2]) ** 2) 86 | boundary_gradient_loss_real = torch.mean((df_dx_lower_real - df_dx_upper_real) ** 2) 87 | boundary_gradient_loss_imag = torch.mean((df_dx_lower_imag - df_dx_upper_imag) ** 2) 88 | 89 | total_boundary_loss = boundary_value_loss_real + boundary_value_loss_imag +\ 90 | boundary_gradient_loss_real + boundary_gradient_loss_imag 91 | 92 | return total_boundary_loss 93 | 94 | def compute_loss(self, model, pde_data, initial_data, boundary_data, state="train"): 95 | loss_dict = dict() 96 | if state == "train": 97 | 98 | # ------------- 99 | # initial conditions 100 | # ------------- 101 | _initial_data = initial_data["data"] 102 | initial_input = _initial_data[:, 0:2] 103 | initial_ground_true = _initial_data[:, 2:4] 104 | initial_pred = model(initial_input) 105 | initial_loss = torch.mean((initial_pred - initial_ground_true) ** 2) 106 | 107 | # ------------- 108 | # boundary conditions 109 | # ------------- 110 | _boundary_data = boundary_data["data"] 111 | boundary_input_lower = _boundary_data[:, [0, 1]] 112 | boundary_input_upper = _boundary_data[:, [0, 2]] 113 | boundary_pred_lower = model(boundary_input_lower) 114 | boundary_pred_upper = model(boundary_input_upper) 115 | boundary_loss = self.boundary_loss( 116 | boundary_input_lower, 117 | boundary_input_upper, 118 | boundary_pred_lower, 119 | boundary_pred_upper 120 | ) 121 | # ------------- 122 | # pde 123 | # ------------- 124 | _pde_data = pde_data["data"] 125 | _pde_weight = pde_data["weights"] / torch.sum(pde_data["weights"]) 126 | pde_pred = model(_pde_data) 127 | _pde_weight = torch.reshape(_pde_weight, (pde_pred.shape[0], 1)) 128 | _pde_real, _pde_imag = self.pde_loss(pde_pred, _pde_data) 129 | pde_loss = torch.sum((_pde_real ** 2 + _pde_imag ** 2).mul(_pde_weight)) 130 | 131 | total_loss = initial_loss + boundary_loss + pde_loss 132 | total_loss.backward() 133 | 134 | loss_dict["pde"] = pde_loss.item() 135 | loss_dict["initial"] = initial_loss.item() 136 | loss_dict["boundary"] = boundary_loss.item() 137 | loss_dict["total"] = total_loss.item() 138 | 139 | return loss_dict 140 | 141 | def compute_loss_basic_weights(self, model, data): 142 | pde_pred = model(data) 143 | _pde_real, _pde_imag = self.pde_loss(pde_pred, data) 144 | pde_loss = torch.sqrt(_pde_real ** 2 + _pde_imag ** 2) 145 | return pde_loss 146 | 147 | -------------------------------------------------------------------------------- /utils/equations/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .Diffusion import * 4 | from .Burgers import * 5 | from .ACEquation import * 6 | from .Schrodinger import * 7 | from .KDV import * 8 | from .basic_define import equation_dict 9 | -------------------------------------------------------------------------------- /utils/equations/basic_define.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import 
matplotlib.pyplot as plt 4 | import numpy as np 5 | import sympy as sym 6 | from sympy.parsing.sympy_parser import parse_expr 7 | 8 | from ..data_utils import uniform_sampling 9 | 10 | 11 | equation_dict = dict() 12 | 13 | 14 | def problem_decorator(problem): 15 | equation_dict[problem.__name__] = problem 16 | return problem 17 | 18 | 19 | class ProblemDefine2d: 20 | 21 | def __init__(self, problem_conf, data_conf): 22 | self.x_range = problem_conf["x_range"] 23 | self.t_range = problem_conf["t_range"] 24 | self.initial_condition = problem_conf["initial_cond"] 25 | self.boundary_condition = problem_conf["boundary_cond"] 26 | 27 | self.initial_data_n = data_conf["initial_data_n"] 28 | self.boundary_data_n = data_conf["boundary_data_n"] 29 | self.pde_data_n = data_conf["pde_data_n"] 30 | 31 | self.initial_data = None 32 | self.boundary_data = None 33 | self.pde_data = None 34 | self.data_dict = dict() 35 | 36 | def create_initial_data(self): 37 | self._create_initial_data_basic() 38 | 39 | def create_boundary_data(self): 40 | self._create_boundary_data_basic() 41 | 42 | def create_pde_data(self): 43 | self._create_pde_data_basic() 44 | 45 | def _create_initial_data_basic(self): 46 | x = sym.symbols("x") 47 | t = sym.symbols("t") 48 | _temp_x_data = uniform_sampling(list(self.x_range), self.initial_data_n) 49 | _temp_t_data = uniform_sampling(self.t_range[0], self.initial_data_n) 50 | _expr = parse_expr(self.initial_condition, evaluate=False) 51 | _expr = sym.lambdify((x, t), _expr, "numpy") 52 | _ground_true = _expr(_temp_x_data, _temp_t_data) 53 | self.initial_data = np.concatenate((_temp_t_data, _temp_x_data, _ground_true), axis=1) 54 | self.data_dict["initial"] = self.initial_data 55 | 56 | def _create_boundary_data_basic(self): 57 | x = sym.symbols("x") 58 | t = sym.symbols("t") 59 | _temp_x_data = uniform_sampling(tuple(self.x_range), self.boundary_data_n) 60 | _temp_t_data = uniform_sampling(list(self.t_range), self.boundary_data_n) 61 | _expr = parse_expr(self.boundary_condition, evaluate=False) 62 | _expr = sym.lambdify((x, t), _expr, "numpy") 63 | _ground_true = _expr(_temp_x_data, _temp_t_data) 64 | self.boundary_data = np.concatenate((_temp_t_data, _temp_x_data, _ground_true), axis=1) 65 | self.data_dict["boundary"] = self.boundary_data 66 | 67 | def _create_boundary_data_periodic(self): 68 | _temp_t_data = uniform_sampling(list(self.t_range), self.boundary_data_n) 69 | _lower_x_data = np.ones_like(_temp_t_data) * self.x_range[0] 70 | _upper_x_data = np.ones_like(_temp_t_data) * self.x_range[1] 71 | self.boundary_data = np.concatenate((_temp_t_data, _lower_x_data, _upper_x_data), axis=1) 72 | self.data_dict["boundary"] = self.boundary_data 73 | 74 | def _create_pde_data_basic(self): 75 | _temp_x_data = uniform_sampling(list(self.x_range), self.pde_data_n) 76 | _temp_t_data = uniform_sampling(list(self.t_range), self.pde_data_n) 77 | self.pde_data = np.concatenate((_temp_t_data, _temp_x_data), axis=1) 78 | self.data_dict["pde"] = self.pde_data 79 | 80 | def pde_loss(self, pred, input_tensor): 81 | raise NotImplementedError 82 | 83 | def compute_loss(self, model, pde_data, initial_data, boundary_data, state="train"): 84 | raise NotImplementedError 85 | 86 | def plot_samples(self): 87 | plt.figure() 88 | plt.scatter(self.boundary_data[:, 0], self.boundary_data[:, 1], label="boundary", s=2) 89 | plt.scatter(self.initial_data[:, 0], self.initial_data[:, 1], label="initial", s=2) 90 | plt.scatter(self.pde_data[:, 0], self.pde_data[:, 1], label="pde", s=2) 91 | plt.xlabel("t") 92 | 
plt.ylabel("x") 93 | plt.legend() 94 | plt.show() 95 | -------------------------------------------------------------------------------- /utils/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | from numbers import Number 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class FullyConnectedNetwork(nn.Module): 10 | 11 | def __init__(self, conf): 12 | super(FullyConnectedNetwork, self).__init__() 13 | 14 | self.layer_conf = conf["layer"] 15 | self.dim_conf = conf["dim"] 16 | 17 | if isinstance(self.layer_conf["layer_size"], Number): 18 | layer_size = [self.layer_conf["layer_size"]] * self.layer_conf["layer_n"] 19 | else: 20 | layer_size = self.layer_conf["layer_size"] 21 | assert len(layer_size) == self.layer_conf["layer_n"] 22 | 23 | self._network = nn.Sequential() 24 | curr_dim = self.dim_conf["input_dim"] 25 | activate_func = self.layer_conf["activate"] 26 | norm_flag = self.layer_conf["norm"] 27 | for layer_id, layer_dim in enumerate(layer_size): 28 | self._network.add_module( 29 | "layer_{}".format(layer_id + 1), 30 | self._make_layer(curr_dim, layer_dim, norm_flag, activate_func) 31 | ) 32 | curr_dim = layer_dim 33 | 34 | self._network.add_module( 35 | "layer_{}".format(len(layer_size) + 1), 36 | self._make_layer(curr_dim, 37 | self.dim_conf["output_dim"], 38 | activate_func=self.layer_conf["final_activate"]) 39 | ) 40 | 41 | def _forward_impl(self, x): 42 | return self._network(x) 43 | 44 | def forward(self, x): 45 | return self._forward_impl(x) 46 | 47 | @staticmethod 48 | def _make_layer(input_dim, output_dim, norm=False, activate_func="tanh"): 49 | layers = list() 50 | 51 | layers.append( 52 | nn.Linear(input_dim, output_dim) 53 | ) 54 | 55 | if norm: 56 | layers.append( 57 | nn.BatchNorm1d(output_dim) 58 | ) 59 | 60 | if activate_func == "tanh": 61 | layers.append(nn.Tanh()) 62 | elif activate_func == "Identify": 63 | pass 64 | else: 65 | raise ValueError 66 | 67 | return nn.Sequential(*layers) 68 | 69 | 70 | def model_saver(save_folder, model, save_name, step=None): 71 | if not os.path.exists(save_folder): 72 | os.makedirs(save_folder) 73 | if step is not None: 74 | save_path = os.path.join(save_folder, "{}_{}.pth".format(save_name, step)) 75 | else: 76 | save_path = os.path.join(save_folder, "{}.pth".format(save_name)) 77 | torch.save(model.state_dict(), save_path) 78 | return save_path 79 | -------------------------------------------------------------------------------- /utils/pde_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def fwd_gradients(obj, x): 7 | dummy = torch.ones_like(obj) 8 | derivative = torch.autograd.grad(obj, x, dummy, create_graph= True)[0] 9 | return derivative 10 | -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import numpy as np 5 | from scipy.interpolate import griddata 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def interpolate_2d(invar, outvar, plotsize_x=100, plotsize_y=100): 10 | 11 | # create grid 12 | extent = (invar[:, 0].min(), invar[:, 0].max(), invar[:, 1].min(), invar[:, 1].max()) 13 | _plot_mesh = np.meshgrid( 14 | np.linspace(extent[0], extent[1], plotsize_x), 15 | np.linspace(extent[2], extent[3], plotsize_y), 16 | indexing="ij" 17 | ) 18 | 19 | 
outvar_interp = griddata( 20 | invar, outvar, tuple(_plot_mesh) 21 | ) 22 | return extent, outvar_interp, _plot_mesh 23 | 24 | 25 | def mesh_plotter_2d(coords, simplices, step=None, ex_path="./mesh_data", name="mesh"): 26 | """ 27 | function to plot triangular meshes 28 | """ 29 | assert coords.shape[1] == 2 30 | plt.figure(figsize=(20, 20), dpi=100) 31 | 32 | plt.triplot(coords[:, 0], coords[:, 1], simplices) 33 | 34 | if step is not None: 35 | plt.savefig(os.path.join(ex_path, "{}_{}.png".format(name, step))) 36 | else: 37 | plt.savefig(os.path.join(ex_path, "{}.png".format(name))) 38 | plt.close() 39 | 40 | -------------------------------------------------------------------------------- /utils/reweightings.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | 6 | reweighting_dict = dict() 7 | 8 | 9 | def reweighting_decorator(reweighting): 10 | reweighting_dict[reweighting.__name__] = reweighting 11 | return reweighting 12 | 13 | 14 | class BasicReWeighting: 15 | 16 | def __init__(self, reweighting_params=None): 17 | self.reweighting_params = reweighting_params 18 | 19 | def sample_weights(self, indices, scores): 20 | raise NotImplementedError 21 | 22 | 23 | @reweighting_decorator 24 | class NoReWeighting(BasicReWeighting): 25 | 26 | def __init__(self, reweighting_parmas=None): 27 | super(NoReWeighting, self).__init__(reweighting_parmas) 28 | 29 | def sample_weights(self, indxs, scores): 30 | """ 31 | sample weight = 1 32 | """ 33 | return torch.ones(len(indxs)) 34 | 35 | 36 | @reweighting_decorator 37 | class BiasedReWeighting(BasicReWeighting): 38 | 39 | def __init__(self, reweighting_params=None): 40 | """ 41 | ref: 42 | A. Katharopoulos, F. Fleuret. 43 | Biased importance sampling for deep neural network training[J]. 44 | arXiv preprint arXiv:1706.00043, 2017. 
45 | """ 46 | super(BiasedReWeighting, self).__init__(reweighting_params) 47 | self.k_zero = self.reweighting_params["k_init"] 48 | self.k_end = self.reweighting_params["k_final"] 49 | self.max_step = self.reweighting_params["iter_n"] 50 | self.decay_step = int(self.max_step * 0.25) 51 | self.decrease_step = self.max_step - self.decay_step 52 | self.step_count = 0 53 | 54 | def sample_weights(self, indxs, scores): 55 | 56 | # rate 57 | if self.step_count <= self.decay_step: 58 | k = self.k_zero 59 | else: 60 | k = self.k_zero + (self.k_end - self.k_zero) * (self.step_count / self.decrease_step) 61 | 62 | # reweighting 63 | samples_len = len(scores) 64 | samples_scores = scores[indxs] 65 | samplers_weight = np.sum(scores) / (samples_len * samples_scores) 66 | samplers_weight = samplers_weight ** k 67 | 68 | self.step_count += 1 69 | return torch.from_numpy(samplers_weight) 70 | 71 | -------------------------------------------------------------------------------- /utils/samplers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import torch 6 | import logging 7 | import numpy as np 8 | 9 | from scipy import spatial, interpolate 10 | from .plot_utils import mesh_plotter_2d 11 | from .models import model_saver 12 | 13 | 14 | sampler_dict = dict() 15 | 16 | 17 | def sampler_decorator(sampler): 18 | 19 | sampler_dict[sampler.__name__] = sampler 20 | return sampler 21 | 22 | 23 | @sampler_decorator 24 | class BaseSampler: 25 | """ 26 | Abstract class of sampler 27 | """ 28 | 29 | def __init__(self, data, reweighting): 30 | self.data = data 31 | self.reweighting = reweighting 32 | 33 | self.data_n = len(self.data) 34 | self.indices = torch.arange(self.data_n) 35 | self.log = logging.getLogger("sampler") 36 | 37 | def __len__(self): 38 | return self.data_n 39 | 40 | def compute_scores(self): 41 | raise NotImplementedError 42 | 43 | def sampler(self, batch_size, replace=True): 44 | # sample weights 45 | indices, scores = self.compute_scores() 46 | 47 | # sampling 48 | sample_p = scores / np.sum(scores) 49 | sample_indices = np.random.choice(len(indices), batch_size, p=sample_p, replace=replace) 50 | 51 | sample_weights = self.reweighting.sample_weights(sample_indices, scores) 52 | 53 | batch_dict = dict() 54 | batch_dict["data"] = self.data[sample_indices] 55 | batch_dict["weights"] = sample_weights.to(torch.device("cuda")) 56 | 57 | return batch_dict 58 | 59 | 60 | @sampler_decorator 61 | class UniformSampler(BaseSampler): 62 | """ 63 | Implement of PINN-O 64 | """ 65 | 66 | def __init__(self, data, reweighting, *args, **kwargs): 67 | super(UniformSampler, self).__init__(data, reweighting) 68 | 69 | self.scores = np.ones(len(self)) 70 | 71 | def compute_scores(self): 72 | return self.indices, self.scores 73 | 74 | 75 | class InterpolationSampler(BaseSampler): 76 | 77 | def __init__(self, data, reweighting, *args, **kwargs): 78 | super(InterpolationSampler, self).__init__(data, reweighting) 79 | 80 | # mesh simplex 81 | self.interp_simplex_result = None 82 | # interpolation weights 83 | self.interp_bary_weights = None 84 | # mesh update flag 85 | self.mesh_update_flag = False 86 | # set of mesh points 87 | self.seed_indxs = None 88 | 89 | # interpolation result 90 | self.interp_data = None 91 | 92 | def compute_scores(self): 93 | raise NotImplementedError 94 | 95 | def mesh_update(self): 96 | """ 97 | update mesh 98 | 1. Delaunay 99 | 2. 
Compute the triangular of each point 100 | """ 101 | 102 | seed_points = self.interp_data[self.seed_indxs] 103 | n_dim = self.interp_data.shape[1] 104 | 105 | # Delaunay 106 | tri = spatial.Delaunay(seed_points) 107 | 108 | # Compute the triangular of each point 109 | simplex = tri.find_simplex(self.data_numpy) 110 | 111 | self.interp_simplex_result = np.take(tri.simplices, simplex, axis=0) 112 | temp = np.take(tri.transform, simplex, axis=0) 113 | delta = self.data_numpy - temp[:, n_dim] 114 | bary = np.einsum('njk,nk->nj', temp[:, :n_dim, :], delta) 115 | self.interp_bary_weights = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True))) 116 | 117 | def seed_update(self): 118 | raise NotImplementedError 119 | 120 | 121 | @sampler_decorator 122 | class SamplerWithDMIS(InterpolationSampler): 123 | """ 124 | Implement of PINN-DMIS 125 | """ 126 | 127 | def __init__( 128 | self, 129 | data, 130 | reweighting, 131 | *args, 132 | **kwargs 133 | ): 134 | super(SamplerWithDMIS, self).__init__(data, reweighting) 135 | 136 | # model 137 | self.model = kwargs["model"] 138 | # loss function 139 | self.loss_func = kwargs["loss_func"] 140 | # mesh update threshold 141 | self.mesh_update_thres = kwargs["mesh_update_thres"] 142 | # batch size of computing sample weights 143 | self.forward_batch_size = kwargs["forward_batch_size"] 144 | 145 | # create data 146 | addon_points = torch.tensor(list(kwargs["addon_points"])).to(device=torch.device("cuda"), dtype=torch.float) # 获取额外边界点 147 | self.data = torch.concat([self.data, addon_points], dim=0) 148 | self.data_np = self.data.detach().cpu().numpy() 149 | 150 | # total number 151 | self.addon_n = len(addon_points) 152 | self.seed_n = kwargs["seed_n"] - self.addon_n 153 | 154 | # init recoder 155 | self.seed_scores_t0 = np.ones(kwargs["seed_n"]) 156 | self.seed_scores_t = np.ones(kwargs["seed_n"]) 157 | self.step_count = 0 158 | if not os.path.exists("./mesh_data"): 159 | os.mkdir("./mesh_data") 160 | 161 | # init set of mesh points 162 | self._build_seed_indices() 163 | # init the triangular mesh 164 | self.mesh_update() 165 | 166 | def compute_scores(self): 167 | 168 | # -------------- 169 | # mesh update 170 | # -------------- 171 | if self.mesh_update_flag: 172 | self.seed_update() 173 | self.mesh_update() 174 | 175 | # ------------------ 176 | # compute weight of mesh points 177 | # ------------------ 178 | seed_data = self.data[self.seed_indxs] 179 | seed_len = len(seed_data) 180 | self.model.eval() 181 | _index = 0 182 | scores_tensor = torch.tensor([]).to(torch.device("cuda")) 183 | while True: 184 | if _index == seed_len: 185 | break 186 | 187 | last_index = min(_index + self.forward_batch_size, seed_len) 188 | 189 | input_tensor = seed_data[_index:last_index, :] 190 | _loss = self.loss_func(self.model, input_tensor) 191 | scores_tensor = torch.cat([scores_tensor, _loss], dim=0) 192 | _index = last_index 193 | 194 | self.model.train() 195 | seed_scores = scores_tensor.detach().cpu().numpy().reshape(-1) 196 | 197 | # ------------------ 198 | # update weight recorders 199 | # ------------------ 200 | if self.mesh_update_flag: 201 | self.seed_scores_t = seed_scores 202 | self.seed_scores_t0 = seed_scores.copy() 203 | self.mesh_update_flag = False 204 | else: 205 | self.seed_scores_t = seed_scores 206 | 207 | # --------------- 208 | # interpolation 209 | # --------------- 210 | interp_scores = np.einsum( 211 | 'nj,nj->n', 212 | np.take(self.seed_scores_t, self.interp_simplex_result), 213 | self.interp_bary_weights 214 | ) 215 | if np.min(interp_scores) < 
5e-3: 216 | # check min value and make sure it is positive(>5e-3) 217 | # interp_scores -= np.min(interp_scores) 218 | interp_scores += 5e-3 219 | 220 | # check update 221 | self.mesh_update_check() 222 | self.step_count += 1 223 | 224 | # return the list of datapoints and sample weights 225 | return self.indices, interp_scores 226 | 227 | def mesh_update_check(self): 228 | # cosine similarity 229 | norm_scores = np.linalg.norm(self.seed_scores_t) 230 | norm_history_scores = np.linalg.norm(self.seed_scores_t0) 231 | cos_sim = self.seed_scores_t0.dot(self.seed_scores_t) / (norm_scores * norm_history_scores) 232 | 233 | if cos_sim < self.mesh_update_thres: 234 | self.mesh_update_flag = True 235 | self.log.info("change mesh") 236 | 237 | def seed_update(self): 238 | """update the set of mesh points""" 239 | scores_differance = np.abs(self.seed_scores_t - self.seed_scores_t0) 240 | 241 | interp_differance = np.einsum( 242 | 'nj,nj->n', 243 | np.take(scores_differance, self.interp_simplex_result), 244 | self.interp_bary_weights 245 | ) 246 | 247 | p = interp_differance / np.sum(interp_differance) 248 | 249 | # re-select mesh points 250 | self._build_seed_indices(p) 251 | 252 | def _build_seed_indices(self, p=None): 253 | """create the set of mesh points""" 254 | self.seed_indxs = np.random.choice( 255 | self.data_n, 256 | self.seed_n, 257 | p=p, 258 | replace=False 259 | ) 260 | self.seed_indxs = np.append(self.seed_indxs, np.arange(self.data_n, self.data_n + self.addon_n)) 261 | 262 | def mesh_update(self): 263 | """create triangular according to the set of mesh points""" 264 | 265 | seed_points = self.data_np[self.seed_indxs] 266 | n_dim = self.data_np.shape[1] 267 | 268 | # Delaunay 269 | tri = spatial.Delaunay(seed_points) 270 | simplex = tri.find_simplex(self.data_np[:self.data_n, :]) 271 | self.interp_simplex_result = np.take(tri.simplices, simplex, axis=0) 272 | temp = np.take(tri.transform, simplex, axis=0) 273 | delta = self.data_np[:self.data_n, :] - temp[:, n_dim] 274 | bary = np.einsum('njk,nk->nj', temp[:, :n_dim, :], delta) 275 | self.interp_bary_weights = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True))) 276 | 277 | # save data 278 | save_dict = { 279 | "seeds": seed_points, 280 | "data": self.data_np[:self.data_n, :] 281 | } 282 | 283 | model_saver( 284 | save_folder="./mesh_data", 285 | model=self.model, 286 | save_name="mesh", 287 | step=self.step_count 288 | ) 289 | 290 | output_full_path = os.path.join("./mesh_data", "{}_{}.npz".format("mesh", self.step_count)) 291 | np.savez(output_full_path, **save_dict) 292 | 293 | # quick plot 294 | mesh_plotter_2d(seed_points, tri.simplices, self.step_count) 295 | 296 | 297 | @sampler_decorator 298 | class SamplerWithBasicIS(InterpolationSampler): 299 | """ 300 | Implement of PINN-BasicIS 301 | """ 302 | 303 | def __init__( 304 | self, 305 | data, 306 | reweighting, 307 | *args, 308 | **kwargs 309 | ): 310 | super(SamplerWithBasicIS, self).__init__(data, reweighting) 311 | 312 | # model 313 | self.model = kwargs["model"] 314 | # loss function 315 | self.loss_func = kwargs["loss_func"] 316 | # batch size of computing sample weights 317 | self.forward_batch_size = kwargs["forward_batch_size"] 318 | 319 | # create data 320 | addon_points = torch.tensor(list(kwargs["addon_points"])).to(device=torch.device("cuda"), dtype=torch.float) # 获取额外边界点 321 | self.data = torch.concat([self.data, addon_points], dim=0) 322 | self.data_np = self.data.detach().cpu().numpy() 323 | 324 | # total number 325 | self.addon_n = len(addon_points) 326 | 
self.seed_n = kwargs["seed_n"] - self.addon_n 327 | 328 | # init recoder 329 | self.seed_scores_t0 = np.ones(kwargs["seed_n"]) 330 | self.seed_scores_t = np.ones(kwargs["seed_n"]) 331 | self.step_count = 0 332 | if not os.path.exists("./mesh_data"): 333 | os.mkdir("./mesh_data") 334 | 335 | # init the set of mesh points 336 | self._build_seed_indices() 337 | # init the triangular mesh 338 | self.mesh_update() 339 | 340 | def compute_scores(self): 341 | 342 | # ------------------ 343 | # compute weight of mesh points 344 | # ------------------ 345 | seed_data = self.data[self.seed_indxs] 346 | seed_len = len(seed_data) 347 | self.model.eval() 348 | _index = 0 349 | scores_tensor = torch.tensor([]).to(torch.device("cuda")) 350 | while True: 351 | if _index == seed_len: 352 | break 353 | 354 | last_index = min(_index + self.forward_batch_size, seed_len) 355 | 356 | input_tensor = seed_data[_index:last_index, :] 357 | _loss = self.loss_func(self.model, input_tensor) 358 | scores_tensor = torch.cat([scores_tensor, _loss], dim=0) 359 | _index = last_index 360 | 361 | self.model.train() 362 | seed_scores = scores_tensor.detach().cpu().numpy().reshape(-1) 363 | 364 | # ------------------ 365 | # update weight recorders 366 | # ------------------ 367 | if self.mesh_update_flag: 368 | self.seed_scores_t = seed_scores 369 | self.seed_scores_t0 = seed_scores.copy() 370 | self.mesh_update_flag = False 371 | else: 372 | self.seed_scores_t = seed_scores 373 | 374 | # --------------- 375 | # interpolation 376 | # --------------- 377 | interp_scores = np.einsum( 378 | 'nj,nj->n', 379 | np.take(self.seed_scores_t, self.interp_simplex_result), 380 | self.interp_bary_weights 381 | ) 382 | if np.min(interp_scores) < 5e-3: 383 | # check min value and make sure it is positive(>5e-3) 384 | # interp_scores -= np.min(interp_scores) 385 | interp_scores += 5e-3 386 | 387 | # check update 388 | self.mesh_update_check() 389 | self.step_count += 1 390 | 391 | # return the list of datapoints and sample weights 392 | return self.indices, interp_scores 393 | 394 | def _build_seed_indices(self, p=None): 395 | self.seed_indxs = np.random.choice( 396 | self.data_n, 397 | self.seed_n, 398 | p=p, 399 | replace=False 400 | ) 401 | self.seed_indxs = np.append(self.seed_indxs, np.arange(self.data_n, self.data_n + self.addon_n)) 402 | 403 | def seed_update(self): 404 | pass 405 | 406 | def mesh_update(self): 407 | """create triangular according to the set of mesh points""" 408 | 409 | seed_points = self.data_np[self.seed_indxs] 410 | n_dim = self.data_np.shape[1] 411 | 412 | # Delaunay 413 | tri = spatial.Delaunay(seed_points) 414 | simplex = tri.find_simplex(self.data_np[:self.data_n, :]) 415 | self.interp_simplex_result = np.take(tri.simplices, simplex, axis=0) 416 | temp = np.take(tri.transform, simplex, axis=0) 417 | delta = self.data_np[:self.data_n, :] - temp[:, n_dim] 418 | bary = np.einsum('njk,nk->nj', temp[:, :n_dim, :], delta) 419 | self.interp_bary_weights = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True))) 420 | 421 | # save data 422 | save_dict = { 423 | "seeds": seed_points, 424 | "data": self.data_np[:self.data_n, :] 425 | } 426 | 427 | model_saver( 428 | save_folder="./mesh_data", 429 | model=self.model, 430 | save_name="mesh", 431 | step=self.step_count 432 | ) 433 | 434 | output_full_path = os.path.join("./mesh_data", "{}_{}.npz".format("mesh", self.step_count)) 435 | np.savez(output_full_path, **save_dict) 436 | 437 | # quick plot 438 | mesh_plotter_2d(seed_points, tri.simplices, self.step_count) 439 | 440 
| 441 | @sampler_decorator 442 | class SamplerWithNabianMethod(InterpolationSampler): 443 | """ 444 | Implement of PINN-N 445 | Nabian, M. A.; Gladstone, R. J.; and Meidani, H. 2021. 446 | Efficient training of physics-informed neural networks via importance sampling. 447 | Computer-Aided Civil and Infrastructure Engineering, 36(8): 962–977. 448 | """ 449 | 450 | def __init__( 451 | self, 452 | data, 453 | reweighting, 454 | *args, 455 | **kwargs 456 | ): 457 | super(SamplerWithNabianMethod, self).__init__(data, reweighting) 458 | 459 | self.model = kwargs["model"] 460 | self.loss_func = kwargs["loss_func"] 461 | 462 | self.forward_batch_size = kwargs["forward_batch_size"] 463 | 464 | addon_points = torch.tensor(list(kwargs["addon_points"])).to(device=torch.device("cuda"), dtype=torch.float) # 获取额外边界点 465 | self.data = torch.concat([self.data, addon_points], dim=0) 466 | self.data_np = self.data.detach().cpu().numpy() 467 | self.addon_n = len(addon_points) 468 | self.seed_n = kwargs["seed_n"] - self.addon_n 469 | 470 | self.seed_scores_t = np.ones(kwargs["seed_n"]) 471 | self.step_count = 0 472 | if not os.path.exists("./mesh_data"): 473 | os.mkdir("./mesh_data") 474 | 475 | self._build_seed_indices() 476 | 477 | def compute_scores(self): 478 | 479 | # ------------------ 480 | # compute sample weights 481 | # ------------------ 482 | seed_data = self.data[self.seed_indxs] 483 | seed_len = len(seed_data) 484 | self.model.eval() 485 | _index = 0 486 | scores_tensor = torch.tensor([]).to(torch.device("cuda")) 487 | while True: 488 | if _index == seed_len: 489 | break 490 | 491 | last_index = min(_index + self.forward_batch_size, seed_len) 492 | 493 | input_tensor = seed_data[_index:last_index, :] 494 | _loss = self.loss_func(self.model, input_tensor) 495 | scores_tensor = torch.cat([scores_tensor, _loss], dim=0) 496 | _index = last_index 497 | 498 | self.model.train() 499 | seed_scores = scores_tensor.detach().cpu().numpy().reshape(-1) 500 | 501 | self.seed_scores_t = seed_scores 502 | 503 | # --------------- 504 | # interpolation 505 | # --------------- 506 | interp_scores = interpolate.griddata(self.data_np[self.seed_indxs], 507 | self.seed_scores_t, 508 | self.data_np[:self.data_n], 509 | method="nearest") 510 | interp_scores -= np.min(interp_scores) 511 | interp_scores += 1e-15 512 | 513 | self.step_count += 1 514 | 515 | return self.indices, interp_scores 516 | 517 | def _build_seed_indices(self, p=None): 518 | self.seed_indxs = np.random.choice( 519 | self.data_n, 520 | self.seed_n, 521 | p=p, 522 | replace=False 523 | ) 524 | self.seed_indxs = np.append(self.seed_indxs, np.arange(self.data_n, self.data_n + self.addon_n)) 525 | 526 | def seed_update(self): 527 | """no need to update the set of mesh points""" 528 | pass 529 | 530 | def mesh_update(self): 531 | """no need to update the triangular mesh""" 532 | pass 533 | --------------------------------------------------------------------------------
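Note (illustrative, not part of the repository): the sketch below condenses the interpolation-based scoring that SamplerWithDMIS and SamplerWithBasicIS share — residual scores are computed only on a sparse seed set, spread to the full collocation pool via a Delaunay triangulation and barycentric weights, and then used as importance-sampling probabilities as in BaseSampler.sampler(). The array sizes, the random stand-in scores, and the use of plain NumPy instead of the CUDA tensors used in the samplers above are assumptions made purely for illustration.

import numpy as np
from scipy import spatial

rng = np.random.default_rng(0)

# A dense pool of (t, x) collocation points and a sparse "seed" subset on
# which the PDE residual would actually be evaluated (hypothetical sizes).
points = rng.uniform(size=(2000, 2))
seed_idx = rng.choice(len(points), size=200, replace=False)
seeds = points[seed_idx]

# Delaunay triangulation of the seeds and barycentric weights for every
# collocation point, mirroring SamplerWithDMIS.mesh_update().
tri = spatial.Delaunay(seeds)
simplex = tri.find_simplex(points)                      # containing triangle per point
vertices = np.take(tri.simplices, simplex, axis=0)      # (N, 3) seed indices
transform = np.take(tri.transform, simplex, axis=0)     # (N, 3, 2) affine maps
delta = points - transform[:, 2]
bary = np.einsum('njk,nk->nj', transform[:, :2, :], delta)
bary_weights = np.hstack([bary, 1 - bary.sum(axis=1, keepdims=True)])

# Stand-in residual magnitudes on the seeds (in the repository these come
# from compute_loss_basic_weights); interpolate them to the whole pool.
seed_scores = rng.uniform(size=len(seeds))
scores = np.einsum('nj,nj->n', np.take(seed_scores, vertices), bary_weights)
scores = np.clip(scores, 5e-3, None)   # keep sampling probabilities strictly positive

# Importance sampling of a training batch, as in BaseSampler.sampler().
p = scores / scores.sum()
batch_idx = rng.choice(len(points), size=256, p=p, replace=True)
print("sampled indices:", batch_idx[:8])

Points that fall outside the convex hull of the seeds get simplex == -1 and a meaningless interpolant; this appears to be why the samplers above always append the fixed addon_points (extra boundary points) to the seed set. The np.clip here merely stands in for the repository's adjustment of adding 5e-3 to all scores whenever their minimum drops below 5e-3.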