├── .gitignore ├── LICENSE ├── README.md ├── data ├── case1_pack.txt ├── case2_airfoil.txt └── w1015.dat ├── model └── .gitignore ├── outs └── .gitignore ├── poster.png ├── problems.png ├── requirements.txt ├── results.png └── src ├── FBPINN └── fbpinn.py ├── HC ├── hard_constraint.py ├── hard_constraint_collector.py ├── l_functions.py └── normal_function.py ├── PFNN └── pfnn.py ├── __init__.py ├── case1.py ├── case2.py ├── case3.py ├── case4.py ├── configs ├── __init__.py ├── case1 │ ├── __init__.py │ ├── fbpinn.py │ ├── hc.py │ ├── params.py │ ├── pfnn.py │ ├── pinn.py │ └── xpinn.py ├── case2 │ ├── __init__.py │ ├── fbpinn.py │ ├── hc.py │ ├── params.py │ ├── pfnn.py │ └── pinn.py └── case3 │ ├── hc.py │ ├── params.py │ ├── pfnn.py │ └── pinn.py ├── utils ├── __init__.py ├── nn_wrapper.py ├── no_stdout_context.py ├── pinn_bc.py ├── pinn_callback.py ├── pinn_geometry.py ├── resnet.py ├── torch_interp.py └── utils.py └── xPINN ├── interface_conditions.py └── xPINN.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # vscode
132 | .vscode/
133 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Songming Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unified Hard-Constraint Framework
2 |
3 | This repository is the official implementation of the NeurIPS 2022 paper: [A Unified Hard-Constraint Framework for Solving Geometrically Complex PDEs](https://arxiv.org/abs/2210.03526).
4 |
5 | ## Introduction
6 |
7 | Physics-informed neural networks (PINNs) suffer from an unbalanced competition between the terms of their loss function: $\mathcal{L}_{\mathrm{PDE}}$ vs. $\mathcal{L}_{\mathrm{BC}}$. Hard-constraint methods have been developed to alleviate this issue, but they are **limited to Dirichlet** boundary conditions. Our work provides a unified hard-constraint framework for the three most common types of boundary conditions: **Dirichlet, Neumann, and Robin**. The proposed method achieves promising performance on **real-world geometrically complex problems**.
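The core idea is to build the boundary condition into the network ansatz itself instead of penalizing it in the loss. Below is a minimal single-boundary Dirichlet sketch (illustration only, with a hypothetical `hard_constrained_u` helper and gate `l`; the general multi-boundary construction, which also covers Neumann and Robin conditions via extra fields, is implemented in `src/HC/`):

```python
import torch
from torch import nn

# Toy "hardness" gate for the boundary x = 0: l(0) = 1, decaying inside
# the domain (beta plays the role of a hardness hyper-parameter).
def l(x, beta=5.0):
    return torch.exp(-beta * x)

def hard_constrained_u(net, x, g):
    # u(x) = (1 - l(x)) * N(x) + l(x) * g(x), so u(0) = g(0) holds exactly
    # and no boundary loss term is needed during training.
    return (1.0 - l(x)) * net(x) + l(x) * g(x)

net = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))
x = torch.rand(16, 1)
u = hard_constrained_u(net, x, g=lambda x_: torch.sin(x_))
```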
8 |
9 | ![](poster.png)
10 |
11 | ## Directory Tree
12 |
13 | This repository is organized as:
14 |
15 | ```bash
16 | HardConstraint/
17 | │ README.md
18 | │ requirements.txt # required dependencies
19 | │
20 | ├─data # data used in this paper
21 | │ case1_pack.txt # ground truth for "Simulation of a 2D battery pack (Heat Equation)"
22 | │ case2_airfoil.txt # ground truth for "Simulation of an Airfoil (Navier-Stokes Equations)"
23 | │ w1015.dat # anchor points of the airfoil
24 | │
25 | ├─model/ # saved model weights (empty)
26 | ├─outs/ # outputs (empty)
27 | └─src
28 | │ case1.py # scripts for each experiment
29 | │ ...
30 | │
31 | ├─configs # configurations for each experiment
32 | │ │
33 | │ ├─case1
34 | │ │ ...
35 | │ │
36 | │ ├─case2
37 | │ │ ...
38 | │ │
39 | │ └─case3
40 | │ ...
41 | │
42 | ├─FBPINN/ # implementations of each model
43 | ├─HC/
44 | ├─PFNN/
45 | ├─xPINN/
46 | │
47 | └─utils/ # some utils
48 | ```
49 |
50 | ## Getting Started
51 |
52 | 1. Install the necessary dependencies:
53 |
54 | ```bash
55 | pip install -r requirements.txt
56 | ```
57 |
58 | 2. To train and evaluate the models in the paper, run this command:
59 |
60 | ```bash
61 | # In the root directory of this repository
62 | DDEBACKEND=pytorch python -m src.caseX
63 | ```
64 |
65 | where X = 1, 2, 3, 4, corresponding to *Simulation of a 2D battery pack (Heat Equation)*, *Simulation of an Airfoil (Navier-Stokes Equations)*, *High-dimensional Heat Equation*, and *Ablation Study: Extra fields*.
66 |
67 | To run a different model (i.e., the proposed method or one of the baselines), modify the global variables in `src/caseX.py`.
68 |
69 | To run the *Ablation Study: Hyper-parameters of Hardness* experiment, change the value of:
70 |
71 | - $\beta_s$ in `src/HC/l_functions.py` (line 31, default: $\beta_s=5$)
72 | - $\beta_t$ in `src/configs/case1/hc.py` (line 127, default: $\beta_t=10$, case 1) or `src/configs/case3/hc.py` (line 56, default: $\beta_t=10$, case 3)
73 |
74 | ## Experimental Results
75 |
76 | **Settings**:
77 |
78 | - Evaluation Metrics: Mean Absolute Error (**MAE**), Mean Absolute Percentage Error (**MAPE**), and Weighted Mean Absolute Percentage Error (**WMAPE**)
79 | - Baselines:
80 | - **PINN**: vanilla PINN
81 | - **PINN-LA** **& PINN-LA-2**: PINN with learning rate annealing
82 | - **xPINN** **& FBPINN**: PINN with domain decomposition for geometrically complex PDEs
83 | - **PFNN** **& PFNN-2**: hard-constraint methods based on the variational formulation of PDEs
84 | - **HC**: our proposed method
85 |
86 | **Problems**: a 2D battery pack, an airfoil, and a high-dimensional heat equation
87 |
88 | ![](problems.png)
89 |
90 | **Results**:
91 |
92 | ![](results.png)
93 |
94 | ## Problems & Solutions
95 |
96 | 1. Scalar Type Error
97 |
98 | ```bash
99 | ...
100 | File "ENV_PATH/lib/python3.9/site-packages/deepxde/model.py", line 228, in outputs_losses
101 | outputs_ = self.net(self.net.inputs)
102 | ...
103 | File "ENV_PATH/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear 104 | return torch._C._nn.linear(input, weight, bias) 105 | RuntimeError: expected scalar type Float but found Double 106 | ``` 107 | 108 | Please modify `ENV_PATH/lib/python3.9/site-packages/deepxde/model.py (line 228)` from: 109 | 110 | ```python 111 | self.net.train(mode=training) 112 | self.net.inputs = torch.as_tensor(inputs) 113 | self.net.inputs.requires_grad_() 114 | outputs_ = self.net(self.net.inputs) 115 | ``` 116 | 117 | to: 118 | 119 | ```python 120 | self.net.train(mode=training) 121 | self.net.inputs = torch.as_tensor(inputs) 122 | self.net.inputs.requires_grad_() 123 | outputs_ = self.net(self.net.inputs.float()) # add this 124 | ``` 125 | 126 | ## Citation 127 | 128 | If you find this work is helpful for your research, please **cite us** with the following BibTeX entry: 129 | 130 | ``` 131 | @article{liu2022unified, 132 | title={A Unified Hard-Constraint Framework for Solving Geometrically Complex PDEs}, 133 | author={Liu, Songming and Hao, Zhongkai and Ying, Chengyang and Su, Hang and Zhu, Jun and Cheng, Ze}, 134 | journal={arXiv preprint arXiv:2210.03526}, 135 | year={2022} 136 | } 137 | ``` 138 | 139 | -------------------------------------------------------------------------------- /data/w1015.dat: -------------------------------------------------------------------------------- 1 | W1015 2 | 1.000000 0.000833 3 | 0.995381 0.001427 4 | 0.987977 0.002574 5 | 0.979865 0.003883 6 | 0.971058 0.005266 7 | 0.961675 0.006732 8 | 0.951852 0.008255 9 | 0.941694 0.009784 10 | 0.931242 0.011292 11 | 0.920520 0.012783 12 | 0.909558 0.014262 13 | 0.898393 0.015735 14 | 0.887068 0.017209 15 | 0.875629 0.018690 16 | 0.864117 0.020176 17 | 0.852562 0.021663 18 | 0.840979 0.023149 19 | 0.829375 0.024634 20 | 0.817752 0.026118 21 | 0.806117 0.027601 22 | 0.794473 0.029083 23 | 0.782828 0.030563 24 | 0.771183 0.032040 25 | 0.759543 0.033513 26 | 0.747907 0.034981 27 | 0.736277 0.036445 28 | 0.724654 0.037902 29 | 0.713040 0.039353 30 | 0.701436 0.040795 31 | 0.689844 0.042228 32 | 0.678266 0.043650 33 | 0.666701 0.045059 34 | 0.655152 0.046454 35 | 0.643618 0.047833 36 | 0.632100 0.049194 37 | 0.620598 0.050536 38 | 0.609111 0.051858 39 | 0.597637 0.053157 40 | 0.586177 0.054433 41 | 0.574730 0.055685 42 | 0.563294 0.056910 43 | 0.551870 0.058109 44 | 0.540458 0.059279 45 | 0.529057 0.060420 46 | 0.517668 0.061530 47 | 0.506292 0.062608 48 | 0.494928 0.063653 49 | 0.483579 0.064663 50 | 0.472243 0.065637 51 | 0.460923 0.066574 52 | 0.449619 0.067471 53 | 0.438331 0.068327 54 | 0.427060 0.069140 55 | 0.415807 0.069910 56 | 0.404573 0.070634 57 | 0.393359 0.071310 58 | 0.382166 0.071936 59 | 0.370997 0.072510 60 | 0.359853 0.073031 61 | 0.348736 0.073496 62 | 0.337647 0.073902 63 | 0.326589 0.074248 64 | 0.315563 0.074531 65 | 0.304574 0.074748 66 | 0.293625 0.074896 67 | 0.282717 0.074973 68 | 0.271857 0.074976 69 | 0.261048 0.074901 70 | 0.250295 0.074746 71 | 0.239604 0.074506 72 | 0.228982 0.074179 73 | 0.218436 0.073760 74 | 0.207976 0.073247 75 | 0.197613 0.072636 76 | 0.187358 0.071922 77 | 0.177226 0.071102 78 | 0.167232 0.070175 79 | 0.157394 0.069135 80 | 0.147735 0.067983 81 | 0.138276 0.066717 82 | 0.129044 0.065339 83 | 0.120071 0.063850 84 | 0.111388 0.062257 85 | 0.103029 0.060568 86 | 0.095027 0.058791 87 | 0.087414 0.056939 88 | 0.080213 0.055026 89 | 0.073440 0.053067 90 | 0.067105 0.051075 91 | 0.061209 0.049066 92 | 0.055741 0.047053 93 | 0.050685 0.045047 94 | 0.046025 0.043058 
95 | 0.041736 0.041093 96 | 0.037793 0.039160 97 | 0.034168 0.037261 98 | 0.030837 0.035398 99 | 0.027777 0.033571 100 | 0.024965 0.031780 101 | 0.022377 0.030026 102 | 0.019994 0.028308 103 | 0.017800 0.026623 104 | 0.015782 0.024970 105 | 0.013927 0.023347 106 | 0.012223 0.021752 107 | 0.010657 0.020186 108 | 0.009220 0.018646 109 | 0.007904 0.017132 110 | 0.006701 0.015643 111 | 0.005606 0.014179 112 | 0.004611 0.012738 113 | 0.003712 0.011319 114 | 0.002906 0.009920 115 | 0.002193 0.008540 116 | 0.001576 0.007179 117 | 0.001055 0.005831 118 | 0.000635 0.004494 119 | 0.000321 0.003174 120 | 0.000114 0.001882 121 | 0.000013 0.000623 122 | 0.000013 -0.000623 123 | 0.000114 -0.001882 124 | 0.000321 -0.003174 125 | 0.000635 -0.004494 126 | 0.001055 -0.005831 127 | 0.001576 -0.007179 128 | 0.002193 -0.008540 129 | 0.002906 -0.009920 130 | 0.003712 -0.011319 131 | 0.004611 -0.012738 132 | 0.005606 -0.014179 133 | 0.006701 -0.015643 134 | 0.007904 -0.017132 135 | 0.009220 -0.018646 136 | 0.010657 -0.020186 137 | 0.012223 -0.021752 138 | 0.013927 -0.023347 139 | 0.015782 -0.024970 140 | 0.017800 -0.026623 141 | 0.019994 -0.028308 142 | 0.022377 -0.030026 143 | 0.024965 -0.031780 144 | 0.027777 -0.033571 145 | 0.030837 -0.035398 146 | 0.034168 -0.037261 147 | 0.037793 -0.039160 148 | 0.041736 -0.041093 149 | 0.046025 -0.043058 150 | 0.050685 -0.045047 151 | 0.055741 -0.047053 152 | 0.061209 -0.049065 153 | 0.067105 -0.051075 154 | 0.073440 -0.053067 155 | 0.080213 -0.055026 156 | 0.087414 -0.056939 157 | 0.095027 -0.058791 158 | 0.103029 -0.060568 159 | 0.111388 -0.062257 160 | 0.120071 -0.063850 161 | 0.129044 -0.065339 162 | 0.138276 -0.066717 163 | 0.147735 -0.067983 164 | 0.157394 -0.069135 165 | 0.167232 -0.070175 166 | 0.177226 -0.071102 167 | 0.187358 -0.071922 168 | 0.197613 -0.072636 169 | 0.207976 -0.073247 170 | 0.218436 -0.073760 171 | 0.228982 -0.074179 172 | 0.239604 -0.074506 173 | 0.250295 -0.074746 174 | 0.261048 -0.074901 175 | 0.271857 -0.074976 176 | 0.282717 -0.074973 177 | 0.293625 -0.074896 178 | 0.304574 -0.074748 179 | 0.315563 -0.074531 180 | 0.326588 -0.074248 181 | 0.337647 -0.073902 182 | 0.348736 -0.073496 183 | 0.359853 -0.073031 184 | 0.370998 -0.072510 185 | 0.382166 -0.071936 186 | 0.393359 -0.071309 187 | 0.404573 -0.070633 188 | 0.415807 -0.069910 189 | 0.427060 -0.069140 190 | 0.438331 -0.068327 191 | 0.449619 -0.067470 192 | 0.460924 -0.066573 193 | 0.472243 -0.065637 194 | 0.483579 -0.064663 195 | 0.494928 -0.063653 196 | 0.506292 -0.062608 197 | 0.517668 -0.061530 198 | 0.529057 -0.060420 199 | 0.540458 -0.059279 200 | 0.551871 -0.058109 201 | 0.563294 -0.056910 202 | 0.574730 -0.055685 203 | 0.586177 -0.054433 204 | 0.597637 -0.053157 205 | 0.609111 -0.051858 206 | 0.620598 -0.050536 207 | 0.632100 -0.049194 208 | 0.643618 -0.047832 209 | 0.655152 -0.046453 210 | 0.666701 -0.045058 211 | 0.678266 -0.043649 212 | 0.689844 -0.042228 213 | 0.701436 -0.040795 214 | 0.713040 -0.039353 215 | 0.724655 -0.037902 216 | 0.736277 -0.036445 217 | 0.747907 -0.034981 218 | 0.759543 -0.033513 219 | 0.771184 -0.032039 220 | 0.782828 -0.030563 221 | 0.794474 -0.029083 222 | 0.806117 -0.027601 223 | 0.817752 -0.026118 224 | 0.829375 -0.024634 225 | 0.840979 -0.023149 226 | 0.852562 -0.021662 227 | 0.864117 -0.020175 228 | 0.875629 -0.018690 229 | 0.887068 -0.017209 230 | 0.898393 -0.015735 231 | 0.909558 -0.014261 232 | 0.920520 -0.012782 233 | 0.931242 -0.011291 234 | 0.941694 -0.009783 235 | 0.951852 -0.008255 236 | 0.961675 -0.006732 237 | 0.971058 -0.005266 238 | 
0.979866 -0.003883
239 | 0.987977 -0.002573
240 | 0.995381 -0.001426
241 | 1.000000 -0.000833
242 |
--------------------------------------------------------------------------------
/model/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/outs/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csuastt/HardConstraint/c9363e901cb6220801d9f90ae0a19cc2b7e9677d/poster.png
--------------------------------------------------------------------------------
/problems.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csuastt/HardConstraint/c9363e901cb6220801d9f90ae0a19cc2b7e9677d/problems.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2021.10.8
2 | cycler==0.11.0
3 | DeepXDE==1.1.1
4 | fonttools==4.30.0
5 | joblib==1.1.0
6 | kiwisolver==1.3.2
7 | matplotlib==3.5.1
8 | numpy==1.22.3
9 | packaging==21.3
10 | Pillow==9.0.1
11 | pyaml==21.10.1
12 | pyparsing==3.0.7
13 | python-dateutil==2.8.2
14 | PyYAML==6.0
15 | scikit-learn==1.0.2
16 | scikit-optimize==0.9.0
17 | scipy==1.8.0
18 | six==1.16.0
19 | threadpoolctl==3.1.0
20 | torch==1.10.2
21 | typing_extensions==4.1.1
22 |
--------------------------------------------------------------------------------
/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csuastt/HardConstraint/c9363e901cb6220801d9f90ae0a19cc2b7e9677d/results.png
--------------------------------------------------------------------------------
/src/FBPINN/fbpinn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from src.utils.resnet import ResNet
4 | import deepxde as dde
5 |
6 |
7 | class FBPINN(nn.Module):
8 | """
9 | Domain decomposition PINN with window functions.
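Each sub-network handles one rectangular sub-domain; a product of sigmoid window functions gates its output so that each contribution decays smoothly to zero outside its own sub-domain, and the final prediction is the sum over all sub-domains.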
10 | ## Parameter 11 | sigma - scale factor\n 12 | L - length\n 13 | H - height\n 14 | n_col - the number of columns of the sub-domains\n 15 | n_row - the number of rows of the sub-domains\n 16 | num_layers, input_dim, hidden_dim, output_dim - the parameters of each sub-network\n 17 | xmin - the offset of the left-bottom point 18 | """ 19 | def __init__( 20 | self, sigma, L, H, n_col, n_row, 21 | num_layers, input_dim, hidden_dim, output_dim, 22 | is_res_net=False, xmin=None 23 | ) -> None: 24 | super(FBPINN, self).__init__() 25 | self.sigma = sigma 26 | self.L = L 27 | self.H = H 28 | if xmin is None: 29 | self.xmin = [- L / 2, - H / 2] 30 | else: 31 | self.xmin = xmin 32 | self.n_col = n_col 33 | self.n_row = n_row 34 | if is_res_net: 35 | nets = [ResNet(num_layers, input_dim, hidden_dim, output_dim) for _ in range(n_col * n_row)] 36 | else: 37 | nets = [dde.nn.FNN([input_dim] + num_layers * [hidden_dim] + [output_dim], "tanh", "Glorot normal") for _ in range(n_col * n_row)] 38 | self.nets = nn.ModuleList(nets) 39 | self.lower_bs = None 40 | 41 | def forward(self, x) -> torch.Tensor: 42 | if self.lower_bs is None: 43 | self.lower_bs = [ 44 | [ 45 | torch.tensor([self.L * j / self.n_col + self.xmin[0], 46 | self.H * i / self.n_row + self.xmin[1]]) 47 | for j in range(self.n_col) 48 | ] for i in range(self.n_row) 49 | ] 50 | self.upper_bs = [ 51 | [ 52 | torch.tensor([self.L * (j+1) / self.n_col + self.xmin[0], 53 | self.H * (i+1) / self.n_row + self.xmin[1]]) 54 | for j in range(self.n_col) 55 | ] for i in range(self.n_row) 56 | ] 57 | self.centers = [ 58 | [ 59 | torch.tensor([(self.L * j / self.n_col + self.L * (j+1) / self.n_col) / 2 + self.xmin[0], 60 | (self.H * i / self.n_row + self.H * (i+1) / self.n_row) / 2 + self.xmin[1]]) 61 | for j in range(self.n_col) 62 | ] for i in range(self.n_row) 63 | ] 64 | self.subdomain_size = torch.tensor([self.L / self.n_col, self.H / self.n_row]) 65 | res = None 66 | spatial_x = x[:, :2] 67 | for i in range(self.n_row): 68 | for j in range(self.n_col): 69 | window_res = torch.sigmoid( 70 | (spatial_x[:, 0:1] - self.lower_bs[i][j][0]) / self.sigma 71 | ) * torch.sigmoid( 72 | (self.upper_bs[i][j][0:1] - spatial_x[:, 0:1]) / self.sigma 73 | ) * torch.sigmoid( 74 | (spatial_x[:, 1:] - self.lower_bs[i][j][1]) / self.sigma 75 | ) * torch.sigmoid( 76 | (self.upper_bs[i][j][1] - spatial_x[:, 1:]) / self.sigma 77 | ) 78 | # normalization 79 | spatial_x_normalized = (spatial_x - self.centers[i][j]) / \ 80 | self.subdomain_size 81 | # recover temporal dimension 82 | if x.shape[1] > 2: 83 | x_normalized = torch.cat([spatial_x_normalized, x[:, 2:3]], dim=1) 84 | else: 85 | x_normalized = spatial_x_normalized 86 | nn_res = self.nets[i * self.n_col + j](x_normalized) 87 | cal_res = window_res.expand(-1, nn_res.shape[1]) * nn_res 88 | if res is None: 89 | res = cal_res 90 | else: 91 | res += cal_res 92 | return res 93 | -------------------------------------------------------------------------------- /src/HC/hard_constraint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from src.utils.nn_wrapper import NNWrapper 4 | 5 | 6 | class HardConstraintRobin2D(nn.Module): 7 | """ 8 | Build hard constraint formula for Robin BCs in 2D. 9 | """ 10 | def __init__(self, f_normal, NN: NNWrapper, f_a, f_b, f_g) -> None: 11 | ''' 12 | f_normal - function for calculating normals.\n 13 | NN - a neural network.\n 14 | ... 
- parameters for the Robin BCs 15 | ''' 16 | super(HardConstraintRobin2D, self).__init__() 17 | self.f_normal = f_normal 18 | self.NN = NN 19 | self.f_a = f_a 20 | self.f_b = f_b 21 | self.f_g = f_g 22 | self.normal_res = None 23 | 24 | def get_u(self, x): 25 | if self.normal_res is None: 26 | self.normal_res = self.f_normal(x) 27 | ns = self.normal_res 28 | hs = self.NN(x) 29 | a = self.f_a(x) 30 | b = self.f_b(x) 31 | res = b * (-ns[:, 1:2] * hs[:, 0:1] + ns[:, 0:1] * hs[:, 1:2]) + a * self.f_g(x) 32 | return res / (a ** 2 + b ** 2) 33 | 34 | def get_p_1(self, x): 35 | if self.normal_res is None: 36 | self.normal_res = self.f_normal(x) 37 | ns = self.normal_res 38 | hs = self.NN(x) 39 | a = self.f_a(x) 40 | b = self.f_b(x) 41 | res = b * (ns[:, 1:2] * hs[:, 2:3] + ns[:, 0:1] * self.f_g(x)) - a * hs[:, 1:2] 42 | return res / (a ** 2 + b ** 2) 43 | 44 | def get_p_2(self, x): 45 | if self.normal_res is None: 46 | self.normal_res = self.f_normal(x) 47 | ns = self.normal_res 48 | hs = self.NN(x) 49 | a = self.f_a(x) 50 | b = self.f_b(x) 51 | res = b * (-ns[:, 0:1] * hs[:, 2:3] + ns[:, 1:2] * self.f_g(x)) + a * hs[:, 0:1] 52 | return res / (a ** 2 + b ** 2) 53 | 54 | def clear_res(self): 55 | self.NN.clear_res() 56 | self.normal_res = None 57 | 58 | def save(self, path_prefix: str, name: str): 59 | torch.save(self.NN, path_prefix + name) 60 | 61 | def load(self, path_prefix: str, name: str): 62 | self.NN = torch.load(path_prefix + name) 63 | 64 | 65 | class HardConstraintNeumann2D(nn.Module): 66 | """ 67 | Build hard constraint formula for Neumann BCs in 2D. 68 | """ 69 | def __init__(self, f_normal, NN: NNWrapper, f_g) -> None: 70 | ''' 71 | f_normal - function for calculating normals.\n 72 | NN - a neural network.\n 73 | ... - parameters for the Neumann BCs 74 | ''' 75 | super(HardConstraintNeumann2D, self).__init__() 76 | self.f_normal = f_normal 77 | self.NN = NN 78 | self.f_g = f_g 79 | self.normal_res = None 80 | 81 | def get_p_1(self, x): 82 | if self.normal_res is None: 83 | self.normal_res = self.f_normal(x) 84 | ns = self.normal_res 85 | h = self.NN(x) 86 | return ns[:, 1:2] * h + ns[:, 0:1] * self.f_g(x) 87 | 88 | def get_p_2(self, x): 89 | if self.normal_res is None: 90 | self.normal_res = self.f_normal(x) 91 | ns = self.normal_res 92 | h = self.NN(x) 93 | return -ns[:, 0:1] * h + ns[:, 1:2] * self.f_g(x) 94 | 95 | def clear_res(self): 96 | self.NN.clear_res() 97 | self.normal_res = None 98 | 99 | def save(self, path_prefix: str, name: str): 100 | torch.save(self.NN, path_prefix + name) 101 | 102 | def load(self, path_prefix: str, name: str): 103 | self.NN = torch.load(path_prefix + name) 104 | 105 | 106 | class HardConstraintNeumannND(nn.Module): 107 | """ 108 | Build hard constraint formula for Neumann BCs in ND. 109 | """ 110 | def __init__(self, f_normal, NN: NNWrapper, f_g) -> None: 111 | ''' 112 | f_normal - function for calculating normals.\n 113 | NN - a neural network.\n 114 | ... 
- parameters for the Neumann BCs 115 | ''' 116 | super(HardConstraintNeumannND, self).__init__() 117 | self.f_normal = f_normal 118 | self.NN = NN 119 | self.f_g = f_g 120 | self.res = None 121 | 122 | def get_p_i(self, x, i): 123 | if self.res is None: 124 | ns = self.f_normal(x) 125 | hs = self.NN(x) 126 | self.res = hs + ns * (self.f_g(x) - torch.sum(ns * hs, dim=1, keepdim=True)) 127 | return self.res[:, i:i+1] 128 | 129 | def clear_res(self): 130 | self.NN.clear_res() 131 | self.res = None 132 | 133 | def save(self, path_prefix: str, name: str): 134 | torch.save(self.NN, path_prefix + name) 135 | 136 | def load(self, path_prefix: str, name: str): 137 | self.NN = torch.load(path_prefix + name) -------------------------------------------------------------------------------- /src/HC/hard_constraint_collector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class HardConstraintCollector(nn.Module): 5 | """ 6 | Parametrize a single state variable u_i. 7 | """ 8 | def __init__(self, i, M, ls, us, N) -> None: 9 | ''' 10 | i - index of the state variable.\n 11 | M - M function, taking B x m tensors as inputs, where B is the batch size, 12 | and m is the number of boundaries, outputing B x 1 tensors.\n 13 | ls - a list of m callable objects (lambda functions), each taking B x d tensors as inputs, 14 | where d is the dimensionality, outputing B x 1 tensors.\n 15 | us - a list of m callable objects (the general solutions of u_i at each boundary, B x 1), 16 | taking B x d tensors as inputs.\n 17 | N - a callable object to generate the raw output (B x d'), 18 | taking B x d tensors and the index as inputs. 19 | ''' 20 | super(HardConstraintCollector, self).__init__() 21 | self.i = i 22 | self.M = M 23 | self.ls = ls 24 | self.us = us 25 | self.N = N 26 | 27 | def forward(self, x) -> torch.Tensor: 28 | """ 29 | Map the coordinates to the state variable u_i.\n 30 | x - coordinates, a B x d tensor, where B is the batch size, and d is the dimensionality.\n 31 | Return a B x 1 tensor corresponding to u_i. 32 | """ 33 | dists = torch.cat([l.get_dist(x) for l in self.ls], dim=1) # output: B x m 34 | u_res = self.M(dists) * self.N(x, self.i) # output: B x 1 35 | for j in range(len(self.ls)): 36 | u_res += self.ls[j](x) * self.us[j](x) 37 | return u_res 38 | -------------------------------------------------------------------------------- /src/HC/l_functions.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import numpy as np 3 | import torch 4 | from torch import optim, nn 5 | import deepxde as dde 6 | from src.utils.utils import cart2pol_pt, lineseg_dists 7 | 8 | class LFunctionBase: 9 | """ 10 | Lambda function. 11 | """ 12 | def __init__(self, X, geom) -> None: 13 | self.X = X 14 | self.geom = geom 15 | self.alpha = None 16 | 17 | def get_dist(self, x): 18 | ''' 19 | Calculate the extended distance to x (minibatch). 20 | ''' 21 | del x 22 | 23 | def get_alpha(self): 24 | r''' 25 | Calculate the parameter $\alpha$. 
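Here $\alpha = 5 / \min(\mathrm{dist})$ over the sample points off this boundary, so the coefficient $e^{-\alpha \cdot \mathrm{dist}}$ decays to $e^{-5}$ by the nearest other boundary.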
26 | ''' 27 | dists = self.get_dist(torch.tensor(self.X).float()).detach().cpu().numpy() 28 | dists_ = dists[~self.geom.on_boundary(self.X)] 29 | # make sure when it comes to the other nearest boundary 30 | # the coefficient goes down to exp(-5) 31 | self.alpha = 5 / np.min(dists_) 32 | 33 | def __call__(self, x): 34 | return torch.exp(-self.alpha * self.get_dist(x)) 35 | 36 | 37 | class LFunctionDisk(LFunctionBase): 38 | """ 39 | Lambda function for a 2D disk. 40 | """ 41 | def __init__(self, X, disk, inner=True) -> None: 42 | super().__init__(X, disk) 43 | self.center = torch.tensor(disk.center) 44 | self.radius = disk.radius 45 | self.inner = inner 46 | self.get_alpha() 47 | 48 | def get_dist(self, x): 49 | x = x[:, :2] 50 | if self.inner: 51 | return torch.linalg.norm(x - self.center, dim=1, keepdim=True) - self.radius 52 | else: 53 | return self.radius - torch.linalg.norm(x - self.center, dim=1, keepdim=True) 54 | 55 | 56 | class LFunctionRectangle(LFunctionBase): 57 | """ 58 | Lambda function for a 2D rectangle. 59 | """ 60 | def __init__(self, X, rec, m_function: Callable) -> None: 61 | super().__init__(X, rec) 62 | self.xmin = rec.xmin 63 | self.xmax = rec.xmax 64 | self.m_function = m_function 65 | self.get_alpha() 66 | 67 | def get_dist(self, x): 68 | dist = torch.stack([ 69 | x[:, 0] - self.xmin[0], - x[:, 0] + self.xmax[0], 70 | x[:, 1] - self.xmin[1], - x[:, 1] + self.xmax[1], 71 | ], dim=1) 72 | return self.m_function(dist) 73 | 74 | 75 | class LFunctionOpenRectangle(LFunctionBase): 76 | """ 77 | Lambda function for a 2D (right) open rectangle.\n 78 | |------------------\n 79 | |\n 80 | |------------------ 81 | """ 82 | def __init__(self, X, rec, m_function: Callable) -> None: 83 | super().__init__(X, rec) 84 | self.xmin = rec.xmin 85 | self.xmax = rec.xmax 86 | self.m_function = m_function 87 | self.get_alpha() 88 | 89 | def get_dist(self, x): 90 | dist = torch.stack([ 91 | x[:, 0] - self.xmin[0], 92 | x[:, 1] - self.xmin[1], - x[:, 1] + self.xmax[1], 93 | ], dim=1) 94 | return self.m_function(dist) 95 | 96 | 97 | class LFunctionAxisLine(LFunctionBase): 98 | """ 99 | Lambda function for a line perpendicular to the axis. 100 | """ 101 | def __init__(self, X, geom, x_0, j, is_left=True) -> None: 102 | ''' 103 | x_0 - intersection point 104 | j - axis number (start from zero) 105 | is_left - left or right boundary 106 | ''' 107 | super().__init__(X, geom) 108 | self.x_0 = x_0 109 | self.j = j 110 | self.is_left = is_left 111 | self.get_alpha() 112 | 113 | def get_dist(self, x): 114 | if self.is_left: 115 | return x[:, self.j:self.j+1] - self.x_0 116 | else: 117 | return -x[:, self.j:self.j+1] + self.x_0 118 | 119 | 120 | class DistNet(nn.Module): 121 | """ 122 | Network to produce a prediction of distance. 123 | """ 124 | def __init__(self, reference_points) -> None: 125 | super(DistNet, self).__init__() 126 | self.reference_points = [ 127 | torch.tensor(reference_point) 128 | for reference_point in reference_points 129 | ] 130 | self.net = dde.nn.FNN([2 * len(reference_points)] + 3 * [30] + [1], 131 | "tanh", "Glorot normal") 132 | 133 | def forward(self, x) -> torch.Tensor: 134 | x_polars = [] 135 | for reference_point in self.reference_points: 136 | delta_x = x - reference_point 137 | x_polars.extend(cart2pol_pt(delta_x[:, 0:1], delta_x[:, 1:])) 138 | x_polars = torch.cat(x_polars, dim=1) 139 | dist_pred = self.net(x_polars) 140 | return dist_pred 141 | 142 | 143 | class LFunctionPolygon(LFunctionBase): 144 | """ 145 | Lambda function for a 2D polygon. 
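The extended distance to a polygon has no simple closed form, so it is approximated by a small auxiliary network (DistNet) fitted to point-to-edge distances sampled around the polygon.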
146 | """ 147 | def __init__(self, X, polygon, spatial_domain) -> None: 148 | super().__init__(X, polygon) 149 | self.polygon = polygon 150 | self.vertices_left = polygon.vertices 151 | self.vertices_right = np.roll(polygon.vertices, 1, axis=0) 152 | center_1 = np.mean(polygon.vertices[polygon.vertices[:,0]<0.5,:], axis=0) 153 | center_2 = np.mean(polygon.vertices[polygon.vertices[:,0]>=0.5,:], axis=0) 154 | # sample points 155 | eps = 0.01 156 | self.bbox = dde.geometry.CSGDifference( 157 | dde.geometry.Rectangle( 158 | xmin=[np.min(polygon.vertices[:, 0]) - eps, np.min(polygon.vertices[:, 1]) - eps], 159 | xmax=[np.max(polygon.vertices[:, 0]) + eps, np.max(polygon.vertices[:, 1]) + eps] 160 | ), polygon) 161 | self.spatial_domain = spatial_domain 162 | X, dists = self.sample_points(1024 * 6) 163 | self.model = DistNet([center_1, center_2]) 164 | X = torch.tensor(X).float() 165 | dists = torch.tensor(dists).float() 166 | self.train(X, dists) 167 | self.get_alpha() 168 | 169 | def sample_points(self, n): 170 | points = np.concatenate(( 171 | self.bbox.random_points(n * 5 // 6), 172 | self.spatial_domain.random_points(n * 1 // 6) 173 | )) 174 | dists = [] 175 | for point in points: 176 | dists.append( 177 | [np.min(lineseg_dists(point, self.vertices_left, self.vertices_right))] 178 | ) 179 | return points, np.array(dists) 180 | 181 | def loss_fn(self, dists, Y): 182 | loss = Y - dists 183 | return torch.mean(torch.abs(loss)) 184 | 185 | def train(self, X, dists): 186 | print("Training extended dist for 2D polygon...") 187 | n_epochs = 10000 188 | optimizer = optim.Adam(self.model.parameters(), lr=1e-3) 189 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 190 | optimizer, patience=100, factor=0.75, min_lr=1e-5 191 | ) 192 | for i in range(n_epochs): 193 | Y = self.model(X) 194 | loss = self.loss_fn(dists, Y) 195 | # Backpropagation 196 | optimizer.zero_grad() 197 | loss.backward() 198 | optimizer.step() 199 | scheduler.step(loss.item()) 200 | if (i+1) % 1000 == 0 or i == 0: 201 | print(f"[Epoch {i+1}/{n_epochs}] loss: {loss.item():>7f}") 202 | # test 203 | X, dists = self.sample_points(1024) 204 | Y = self.get_dist(torch.tensor(X).float()).detach().cpu().numpy() 205 | print(f"Finish training!\nTesting loss: {np.mean(np.abs(Y - dists)):>7f}") 206 | 207 | def get_dist(self, x): 208 | x = x[:, :2] 209 | return self.model(x) 210 | -------------------------------------------------------------------------------- /src/HC/normal_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from src.utils.torch_interp import Interp1d 4 | from src.utils.utils import cart2pol_np, cart2pol_pt 5 | 6 | 7 | class NormalFunctionDisk: 8 | """ 9 | Normal function for a 2D disk. 10 | """ 11 | def __init__(self, center, inner=True) -> None: 12 | self.center = torch.tensor(center) 13 | self.inner = inner 14 | 15 | def __call__(self, x): 16 | ''' 17 | Calcute the extended (outer) normal at x (minibatch) 18 | ''' 19 | x = x[:, :2] 20 | if self.inner: 21 | d_x = self.center - x 22 | else: 23 | d_x = x - self.center 24 | return d_x / torch.linalg.norm(d_x, dim=1, keepdim=True) 25 | 26 | 27 | class NormalFunctionSphere: 28 | """ 29 | Normal function for a ND sphere. 
30 | """ 31 | def __init__(self, center, inner=True) -> None: 32 | self.n_dim = len(center) 33 | self.center = torch.tensor(center) 34 | self.inner = inner 35 | 36 | def __call__(self, x): 37 | ''' 38 | Calcute the extended (outer) normal at x (minibatch) 39 | ''' 40 | x = x[:, :self.n_dim] 41 | if self.inner: 42 | d_x = self.center - x 43 | else: 44 | d_x = x - self.center 45 | return d_x / torch.linalg.norm(d_x, dim=1, keepdim=True) 46 | 47 | 48 | class NormalFunctionRectangle: 49 | """ 50 | Normal function for a 2D rectangle (centered at (0,0), outer boundary). 51 | """ 52 | def __init__(self, H, L) -> None: 53 | self.H = H 54 | self.L = L 55 | 56 | def __call__(self, x): 57 | ''' 58 | Calcute the extended (outer) normal at x (minibatch) 59 | ''' 60 | x = x[:, :2] 61 | k = self.H / self.L 62 | n = torch.zeros_like(x) 63 | n[torch.where(torch.isclose(x[:,0], torch.tensor(0.)))] = torch.tensor([0., 1.]) 64 | n[torch.where(torch.logical_and(torch.logical_or(x[:,1] >= x[:,0] * k, x[:,1] <= -x[:,0] * k), x[:,1] >= 0))] = torch.tensor([0., 1.]) 65 | n[torch.where(torch.logical_and(torch.logical_or(x[:,1] >= x[:,0] * k, x[:,1] <= -x[:,0] * k), x[:,1] <= 0))] = torch.tensor([0., -1.]) 66 | n[torch.where(torch.logical_and(torch.logical_and(x[:,1] <= x[:,0] * k, x[:,1] >= -x[:,0] * k), x[:,0] >= 0))] = torch.tensor([1., 0.]) 67 | n[torch.where(torch.logical_and(torch.logical_and(x[:,1] <= x[:,0] * k, x[:,1] >= -x[:,0] * k), x[:,0] <= 0))] = torch.tensor([-1., 0.]) 68 | return n 69 | 70 | class NormalFunctionPolygon: 71 | """ 72 | Normal function for a 2D polygon. 73 | """ 74 | def __init__(self, polygon, inner=True) -> None: 75 | self.polygon = polygon 76 | self.inner = inner 77 | # reference center 78 | self.center_np = np.mean( 79 | polygon.vertices, axis=0 80 | ) 81 | self.center = torch.tensor(self.center_np) 82 | # middle points in the polygon 83 | X = (polygon.vertices + np.roll(polygon.vertices, 1, axis=0)) / 2 84 | if self.inner: 85 | normals = -self.polygon.boundary_normal(X) 86 | else: 87 | normals = self.polygon.boundary_normal(X) 88 | thetas = self.get_thetas_np(X) 89 | normal_thetas = cart2pol_np(normals[:, 0], normals[:, 1])[1] 90 | # correct the thetas 91 | eps = 1e-66 92 | normal_thetas[0] = -np.pi 93 | normal_thetas = np.concatenate((normal_thetas, [np.pi])) 94 | thetas[0, :] = [0] 95 | thetas = np.concatenate((thetas, [[-eps]])) 96 | sorted_indices = np.argsort(thetas[:, 0]) 97 | thetas = thetas[sorted_indices, :] 98 | normal_thetas = normal_thetas[sorted_indices] 99 | thetas[0, :] = [-np.pi] 100 | thetas = np.concatenate((thetas, [[np.pi]])) 101 | normal_thetas = np.concatenate((normal_thetas, [0])) 102 | # interpolation 103 | self.thetas = torch.tensor(thetas.T) 104 | self.normal_thetas = torch.tensor(normal_thetas) 105 | self.interp = Interp1d() 106 | 107 | def get_thetas_pt(self, x): 108 | delta_x = x - self.center 109 | theta_x = cart2pol_pt(delta_x[:, 0:1], delta_x[:, 1:])[1] 110 | return theta_x 111 | 112 | def get_thetas_np(self, x): 113 | delta_x = x - self.center_np 114 | theta_x = cart2pol_np(delta_x[:, 0:1], delta_x[:, 1:])[1] 115 | return theta_x 116 | 117 | def __call__(self, x): 118 | ''' 119 | Calcute the extended (outer) normal at x (minibatch) 120 | ''' 121 | x = x[:, :2] 122 | x_thetas = self.get_thetas_pt(x) 123 | normal_thetas = self.interp( 124 | self.thetas, self.normal_thetas, 125 | torch.transpose(x_thetas, 0, 1) 126 | ) 127 | normal_thetas = torch.transpose(normal_thetas, 0, 1) 128 | return torch.concat([torch.cos(normal_thetas), 
torch.sin(normal_thetas)], dim=1) 129 | -------------------------------------------------------------------------------- /src/PFNN/pfnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import deepxde as dde 4 | 5 | class PFNN(nn.Module): 6 | """ 7 | Penalty-Free Neural Network for time-dependent PDEs. 8 | """ 9 | def __init__(self, ic_fn, num_input, num_output, ddpinn=None) -> None: 10 | ''' 11 | ddpinn - a domain-decomposition based PINN. Specify this nn to select the PFNN-2. 12 | ''' 13 | super(PFNN, self).__init__() 14 | self.ic_fn = ic_fn 15 | if ddpinn is None: 16 | self.net = dde.nn.FNN([num_input] + 4 * [50] + [num_output], 17 | "tanh", "Glorot normal") 18 | else: 19 | self.net = ddpinn 20 | 21 | def forward(self, x) -> torch.Tensor: 22 | # -1 is the time dimension 23 | return self.net(x) * x[:, -1:] + self.ic_fn(x) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csuastt/HardConstraint/c9363e901cb6220801d9f90ae0a19cc2b7e9677d/src/__init__.py -------------------------------------------------------------------------------- /src/case1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | import numpy as np 4 | import deepxde as dde 5 | from deepxde import gradients as dde_grad 6 | from src.FBPINN.fbpinn import FBPINN 7 | from src.PFNN.pfnn import PFNN 8 | from src.configs.case1.xpinn import pde_xpinn 9 | from src.utils.no_stdout_context import nostdout 10 | from src.utils.utils import plot_lines, test_time, Tester 11 | from src.utils.pinn_callback import PINNLRAdaptor, PINNLRScheduler, PINNModelSaver, PINNTester 12 | from src.configs.case1.params import model_path_prefix, data_path, H, L 13 | from src.configs.case1.pinn import spatial_time_domain, time_domain, pde_pinn, ic_bcs, num_bcs, num_pdes, lr_alpha 14 | from src.configs.case1.hc import pde_hc, HCNN 15 | from src.configs.case1.fbpinn import sigma 16 | from src.configs.case1.pfnn import tot_points, loss_pfnn, initial_condition 17 | from src.xPINN.interface_conditions import Subdomains 18 | from src.xPINN.xPINN import xPINN 19 | 20 | 21 | # Model Selection (default: HC) 22 | TEST_PINN = False 23 | TEST_FBPINN = False # this choice is valid if TEST_PINN == True 24 | PINN_LR_ANNEALING = False # this choice is valid if TEST_PINN == True 25 | PINN_LR_ANNEALING_2 = False # this choice is valid if TEST_PINN == True 26 | TEST_XPINN = False 27 | TEST_PFNN = False 28 | TEST_PFNN_2 = False # this choice is valid if TEST_PFNN == True 29 | 30 | # Other Configurations 31 | LOAD_MODEL = False 32 | SAVE_MODEL = False 33 | TEST_WHILE_TRAIN = False 34 | 35 | 36 | def train_pinn(): 37 | n_epochs = 5000 38 | lr = 0.01 39 | data = dde.data.TimePDE( 40 | spatial_time_domain, 41 | pde_pinn, 42 | ic_bcs, 43 | num_domain=8192, 44 | num_boundary=512, 45 | num_initial=512 46 | ) 47 | if TEST_FBPINN: 48 | net = torch.nn.DataParallel(FBPINN(sigma, L, H, 4, 6, 3, 3, 30, 1)) 49 | else: 50 | net = dde.nn.FNN([3] + 4 * [50] + [1], "tanh", "Glorot normal") 51 | if LOAD_MODEL: 52 | net = torch.load(model_path_prefix + "pinn.pth") 53 | # train 54 | optimizer = optim.Adam( 55 | net.parameters(), 56 | lr=lr 57 | ) 58 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 59 | optimizer, patience=100, factor=0.75, min_lr=1e-5 60 | ) 61 | 
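# optional callbacks: test-while-training metrics, loss-weight annealing
# (PINN-LA / PINN-LA-2), LR scheduling on plateau, and periodic resampling
# of collocation points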
test_callback = PINNTester(data_path, spatial_time_domain.dim, Tester.test_while_train) 62 | loss_weights = [1.] * (num_pdes + num_bcs) 63 | lr_adaptor_callback = PINNLRAdaptor( 64 | loss_weights, num_pdes, lr_alpha, 65 | mode="max" if PINN_LR_ANNEALING else "mean" 66 | ) 67 | lr_scheduler_callback = PINNLRScheduler(scheduler) 68 | resampler = dde.callbacks.PDEResidualResampler(period=10) 69 | callbacks = [lr_scheduler_callback, resampler] 70 | if TEST_WHILE_TRAIN: 71 | callbacks.append(test_callback) 72 | if PINN_LR_ANNEALING or PINN_LR_ANNEALING_2: 73 | callbacks.append(lr_adaptor_callback) 74 | model = dde.Model(data, net) 75 | model.compile("adam", lr=lr, loss_weights=loss_weights) 76 | with nostdout(): 77 | model.train(epochs=n_epochs, callbacks=callbacks, display_every=1) 78 | # lbfgs 79 | model_saver = PINNModelSaver() 80 | resampler = dde.callbacks.PDEResidualResampler(period=1) 81 | model.compile("L-BFGS", loss_weights=loss_weights) 82 | model.train(callbacks=[resampler, model_saver]) 83 | if model_saver.got_nan: 84 | model.net.load_state_dict(model_saver.weights) 85 | m_loss_res = test_callback.m_loss_res 86 | m_abs_e_res = test_callback.m_abs_e_res 87 | m_r_abs_e_res = test_callback.m_r_abs_e_res 88 | net = model.net 89 | if TEST_FBPINN: 90 | net = net.module 91 | # save the model 92 | if SAVE_MODEL: 93 | torch.save(net, model_path_prefix + "pinn.pth") 94 | return net, m_loss_res, m_abs_e_res, m_r_abs_e_res 95 | 96 | 97 | def train_hc(): 98 | n_epochs = 5000 99 | lr = 0.01 100 | net = HCNN(model_path_prefix if LOAD_MODEL else "") 101 | optimizer = optim.Adam( 102 | net.parameters(), 103 | lr=lr) 104 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 105 | optimizer, patience=100, factor=0.75, min_lr=1e-5 106 | ) 107 | # prepare data set 108 | X = torch.tensor(spatial_time_domain.random_points(8192)).float() 109 | X.requires_grad = True 110 | # train 111 | m_abs_e_res, m_r_abs_e_res, m_loss_res = [], [], [] 112 | print("Training...") 113 | for i in range(n_epochs): 114 | pred = net(X) 115 | loss = torch.cat(pde_hc(X, pred), dim=1) 116 | loss = torch.sum(loss ** 2, dim=1) 117 | loss = torch.mean(loss) 118 | # Backpropagation 119 | optimizer.zero_grad() 120 | loss.backward() 121 | optimizer.step() 122 | # update the lr 123 | loss_val = loss.item() 124 | scheduler.step(loss_val) 125 | m_loss_res.append(loss_val) 126 | # test while train 127 | if TEST_WHILE_TRAIN: 128 | # need not net.eval() right now 129 | test_res = Tester.test_while_train(data_path, spatial_time_domain.dim, net) 130 | m_abs_e_res.append(test_res[0][0]) 131 | m_r_abs_e_res.append(test_res[1][0]) 132 | if (i+1) % 100 == 0 or i == 0: 133 | print(f"[Epoch {i+1}/{n_epochs}] loss: {loss_val:>7f}") 134 | # clear the cache of the grads 135 | dde_grad.clear() 136 | if i % 10 == 0: 137 | X = torch.tensor(spatial_time_domain.random_points(8192)).float() 138 | X.requires_grad = True 139 | # l-bfgs 140 | resampler = dde.callbacks.PDEResidualResampler(period=1) 141 | data = dde.data.TimePDE( 142 | spatial_time_domain, 143 | pde_hc, 144 | [], 145 | num_domain=8192 146 | ) 147 | model = dde.Model(data, net) 148 | model.compile("L-BFGS") 149 | model.train(callbacks=[resampler]) 150 | print("Finish training!") 151 | # save the model 152 | if SAVE_MODEL: 153 | net.save(model_path_prefix) 154 | return net, m_loss_res, m_abs_e_res, m_r_abs_e_res 155 | 156 | 157 | def train_xpinn(): 158 | # prepare for xPINN 159 | subdomains = Subdomains(L, H, 4, 6, 1, spatial_time_domain, temporal_domain=time_domain) 160 | interface_points = 
subdomains.generate_interface_points(512) 161 | interface_conditions = subdomains.generate_interface_conditions() 162 | # set data 163 | n_epochs = 5000 164 | lr = 0.01 165 | data = dde.data.TimePDE( 166 | spatial_time_domain, 167 | pde_xpinn, 168 | ic_bcs + interface_conditions, 169 | num_domain=8192, 170 | num_boundary=512, 171 | num_initial=512, 172 | anchors=interface_points 173 | ) 174 | net = torch.nn.DataParallel(xPINN(L, H, 4, 6, 3, 3, 30, 1, pde_pinn)) 175 | if LOAD_MODEL: 176 | net = torch.load(model_path_prefix + "xpinn.pth") 177 | # train 178 | optimizer = optim.Adam( 179 | net.parameters(), 180 | lr=lr 181 | ) 182 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 183 | optimizer, patience=100, factor=0.75, min_lr=1e-5 184 | ) 185 | test_callback = PINNTester(data_path, spatial_time_domain.dim, Tester.test_while_train) 186 | lr_scheduler_callback = PINNLRScheduler(scheduler) 187 | resampler = dde.callbacks.PDEResidualResampler(period=10) 188 | callbacks = [lr_scheduler_callback, resampler] 189 | if TEST_WHILE_TRAIN: 190 | callbacks.append(test_callback) 191 | model = dde.Model(data, net) 192 | model.compile("adam", lr=lr) 193 | with nostdout(): 194 | model.train(epochs=n_epochs, callbacks=callbacks, display_every=1) 195 | # lbfgs 196 | model_saver = PINNModelSaver() 197 | resampler = dde.callbacks.PDEResidualResampler(period=1) 198 | model.compile("L-BFGS") 199 | model.train(callbacks=[resampler, model_saver]) 200 | if model_saver.got_nan: 201 | model.net.load_state_dict(model_saver.weights) 202 | m_loss_res = test_callback.m_loss_res 203 | m_abs_e_res = test_callback.m_abs_e_res 204 | m_r_abs_e_res = test_callback.m_r_abs_e_res 205 | net = model.net.module 206 | # save the model 207 | if SAVE_MODEL: 208 | torch.save(net, model_path_prefix + "xpinn.pth") 209 | # set model to evaluation mode 210 | net.set_eval() 211 | return net, m_loss_res, m_abs_e_res, m_r_abs_e_res 212 | 213 | 214 | def train_pfnn(): 215 | n_epochs = 5000 216 | lr = 0.01 217 | data = dde.data.TimePDE( 218 | spatial_time_domain, 219 | loss_pfnn, 220 | [], 221 | anchors=tot_points 222 | ) 223 | net = PFNN( 224 | initial_condition, 3, 1, 225 | ddpinn=torch.nn.DataParallel(FBPINN(sigma, L, H, 4, 6, 3, 3, 30, 1)) if TEST_PFNN_2 else None 226 | ) 227 | if LOAD_MODEL: 228 | net = torch.load(model_path_prefix + "pfnn.pth") 229 | # train 230 | optimizer = optim.Adam( 231 | net.parameters(), 232 | lr=lr 233 | ) 234 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 235 | optimizer, patience=100, factor=0.75, min_lr=1e-5 236 | ) 237 | test_callback = PINNTester(data_path, spatial_time_domain.dim, Tester.test_while_train) 238 | lr_scheduler_callback = PINNLRScheduler(scheduler) 239 | resampler = dde.callbacks.PDEResidualResampler(period=10) 240 | callbacks = [lr_scheduler_callback, resampler] 241 | if TEST_WHILE_TRAIN: 242 | callbacks.append(test_callback) 243 | model = dde.Model(data, net) 244 | model.compile("adam", lr=lr) 245 | with nostdout(): 246 | model.train(epochs=n_epochs, callbacks=callbacks, display_every=1) 247 | # lbfgs 248 | model_saver = PINNModelSaver() 249 | model.compile("L-BFGS") 250 | model.train(callbacks=[model_saver]) 251 | if model_saver.got_nan: 252 | model.net.load_state_dict(model_saver.weights) 253 | m_loss_res = test_callback.m_loss_res 254 | m_abs_e_res = test_callback.m_abs_e_res 255 | m_r_abs_e_res = test_callback.m_r_abs_e_res 256 | net = model.net 257 | if TEST_PFNN_2: 258 | net.net = net.net.module 259 | # save the model 260 | if SAVE_MODEL: 261 | torch.save(net, model_path_prefix + 
"pfnn.pth") 262 | return net, m_loss_res, m_abs_e_res, m_r_abs_e_res 263 | 264 | 265 | if __name__ == "__main__": 266 | if TEST_PINN: 267 | net, m_loss_res, m_abs_e_res, m_r_abs_e_res = train_pinn() 268 | elif TEST_XPINN: 269 | net, m_loss_res, m_abs_e_res, m_r_abs_e_res = train_xpinn() 270 | elif TEST_PFNN: 271 | net, m_loss_res, m_abs_e_res, m_r_abs_e_res = train_pfnn() 272 | else: 273 | net, m_loss_res, m_abs_e_res, m_r_abs_e_res = train_hc() 274 | 275 | # test the model 276 | # net.eval() 277 | test_time(data_path, spatial_time_domain.dim - 1, net) 278 | # plot lines of testing res while training 279 | if TEST_WHILE_TRAIN: 280 | plot_lines( 281 | [list(range(1, len(m_abs_e_res) + 1)), 282 | np.array(m_loss_res) / np.max(m_loss_res), 283 | np.array(m_abs_e_res) / np.max(m_abs_e_res), 284 | np.array(m_r_abs_e_res) / np.max(m_r_abs_e_res), 285 | ], 286 | "Epochs", 287 | "Normalized result", 288 | ["loss", "m_abs_e", "m_r_abs_e"], "outs/e_while_training.png", 289 | is_log=True 290 | ) 291 | -------------------------------------------------------------------------------- /src/case2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | import numpy as np 4 | import deepxde as dde 5 | from deepxde import gradients as dde_grad 6 | from src.FBPINN.fbpinn import FBPINN 7 | from src.utils.no_stdout_context import nostdout 8 | from src.utils.utils import plot_lines, test, Tester 9 | from src.utils.pinn_callback import PINNLRAdaptor, PINNLRScheduler, PINNModelSaver, PINNTester 10 | from src.configs.case2.params import model_path_prefix, data_path, xmin, xmax 11 | from src.configs.case2.pinn import spatial_domain, pde_pinn, ic_bcs, num_bcs, num_pdes, lr_alpha 12 | from src.configs.case2.hc import pde_hc, HCNN 13 | from src.configs.case2.fbpinn import sigma 14 | from src.xPINN.interface_conditions import Subdomains 15 | from src.xPINN.xPINN import xPINN 16 | 17 | 18 | # Model Selection (default: HC) 19 | TEST_PINN = False 20 | TEST_FBPINN = False # this choice is valid if TEST_PINN == True 21 | PINN_LR_ANNEALING = False # this choice is valid if TEST_PINN == True 22 | PINN_LR_ANNEALING_2 = False # this choice is valid if TEST_PINN == True 23 | TEST_XPINN = False 24 | 25 | # Other Configurations 26 | LOAD_MODEL = False 27 | SAVE_MODEL = False 28 | TEST_WHILE_TRAIN = False 29 | 30 | 31 | def train_pinn(): 32 | n_epochs = 5000 33 | lr = 1e-3 34 | data = dde.data.PDE( 35 | spatial_domain, 36 | pde_pinn, 37 | ic_bcs, 38 | num_domain=10000, 39 | num_boundary=2048 40 | ) 41 | if TEST_FBPINN: 42 | net = torch.nn.DataParallel(FBPINN(sigma, xmax[0] - xmin[0], xmax[1] - xmin[1], 6, 3, 4, 2, 30, 3, xmin=xmin)) 43 | else: 44 | net = dde.nn.FNN([2] + 6 * [50] + [3], "tanh", "Glorot normal") 45 | if LOAD_MODEL: 46 | net = torch.load(model_path_prefix + "pinn.pth") 47 | # train 48 | optimizer = optim.Adam( 49 | net.parameters(), 50 | lr=lr 51 | ) 52 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 53 | optimizer, patience=100, factor=0.75, min_lr=1e-5 54 | ) 55 | test_callback = PINNTester(data_path, spatial_domain.dim, Tester.test_while_train) 56 | loss_weights = [1.] 
* (num_pdes + num_bcs) 57 | lr_adaptor_callback = PINNLRAdaptor( 58 | loss_weights, num_pdes, lr_alpha, 59 | mode="max" if PINN_LR_ANNEALING else "mean" 60 | ) 61 | lr_scheduler_callback = PINNLRScheduler(scheduler) 62 | resampler = dde.callbacks.PDEResidualResampler(period=10) 63 | callbacks = [lr_scheduler_callback, resampler] 64 | if TEST_WHILE_TRAIN: 65 | callbacks.append(test_callback) 66 | if PINN_LR_ANNEALING or PINN_LR_ANNEALING_2: 67 | callbacks.append(lr_adaptor_callback) 68 | model = dde.Model(data, net) 69 | model.compile("adam", lr=lr, loss_weights=loss_weights) 70 | with nostdout(): 71 | model.train(epochs=n_epochs, callbacks=callbacks, display_every=1) 72 | # lbfgs 73 | model_saver = PINNModelSaver() 74 | resampler = dde.callbacks.PDEResidualResampler(period=1) 75 | model.compile("L-BFGS", loss_weights=loss_weights) 76 | model.train(callbacks=[resampler, model_saver]) 77 | if model_saver.got_nan: 78 | model.net.load_state_dict(model_saver.weights) 79 | m_loss_res = test_callback.m_loss_res 80 | m_abs_e_res = test_callback.m_abs_e_res 81 | m_r_abs_e_res = test_callback.m_r_abs_e_res 82 | net = model.net 83 | if TEST_FBPINN: 84 | net = net.module 85 | # save the model 86 | if SAVE_MODEL: 87 | torch.save(net, model_path_prefix + "pinn.pth") 88 | return net, m_loss_res, m_abs_e_res, m_r_abs_e_res 89 | 90 | 91 | def train_hc(): 92 | n_epochs = 5000 93 | lr = 1e-3 94 | net = HCNN(model_path_prefix if LOAD_MODEL else "") 95 | optimizer = optim.Adam( 96 | net.parameters(), 97 | lr=lr) 98 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 99 | optimizer, patience=100, factor=0.75, min_lr=1e-5 100 | ) 101 | # prepare data set 102 | X = torch.tensor(spatial_domain.random_points(10000)).float() 103 | X.requires_grad = True 104 | # train 105 | m_abs_e_res, m_r_abs_e_res, m_loss_res = [], [], [] 106 | print("Training...") 107 | for i in range(n_epochs): 108 | pred = net(X) 109 | loss = torch.cat(pde_hc(X, pred), dim=1) 110 | loss = torch.sum(loss ** 2, dim=1) 111 | loss = torch.mean(loss) 112 | # Backpropagation 113 | optimizer.zero_grad() 114 | loss.backward() 115 | optimizer.step() 116 | # update the lr 117 | loss_val = loss.item() 118 | scheduler.step(loss_val) 119 | m_loss_res.append(loss_val) 120 | # test while train 121 | if TEST_WHILE_TRAIN: 122 | # need not net.eval() right now 123 | test_res = Tester.test_while_train(data_path, spatial_domain.dim, net) 124 | m_abs_e_res.append(test_res[0][0]) 125 | m_r_abs_e_res.append(test_res[1][0]) 126 | if (i+1) % 100 == 0 or i == 0: 127 | print(f"[Epoch {i+1}/{n_epochs}] loss: {loss_val:>7f}") 128 | # clear the cache of the grads 129 | dde_grad.clear() 130 | if i % 10 == 0: 131 | X = torch.tensor(spatial_domain.random_points(10000)).float() 132 | X.requires_grad = True 133 | # l-bfgs 134 | resampler = dde.callbacks.PDEResidualResampler(period=1) 135 | model_saver = PINNModelSaver() 136 | data = dde.data.TimePDE( 137 | spatial_domain, 138 | pde_hc, 139 | [], 140 | num_domain=10000 141 | ) 142 | model = dde.Model(data, net) 143 | model.compile("L-BFGS") 144 | model.train(callbacks=[resampler, model_saver]) 145 | if model_saver.got_nan: 146 | net.load_state_dict(model_saver.weights) 147 | print("Finish training!") 148 | # save the model 149 | if SAVE_MODEL: 150 | net.save(model_path_prefix) 151 | return net, m_loss_res, m_abs_e_res, m_r_abs_e_res 152 | 153 | 154 | def train_xpinn(): 155 | # prepare for xPINN 156 | subdomains = Subdomains(xmax[0] - xmin[0], xmax[1] - xmin[1], 6, 3, 3, spatial_domain) 157 | interface_points = 
subdomains.generate_interface_points(2048) 158 | interface_conditions = subdomains.generate_interface_conditions() 159 | # set data 160 | n_epochs = 5000 161 | lr = 1e-3 162 | data = dde.data.PDE( 163 | spatial_domain, 164 | pde_pinn, 165 | ic_bcs + interface_conditions, 166 | num_domain=10000, 167 | num_boundary=2048, 168 | anchors=interface_points 169 | ) 170 | net = torch.nn.DataParallel(xPINN(xmax[0] - xmin[0], xmax[1] - xmin[1], 6, 3, 4, 2, 30, 3, pde_pinn)) 171 | if LOAD_MODEL: 172 | net = torch.load(model_path_prefix + "xpinn.pth") 173 | # train 174 | optimizer = optim.Adam( 175 | net.parameters(), 176 | lr=lr 177 | ) 178 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 179 | optimizer, patience=100, factor=0.75, min_lr=1e-5 180 | ) 181 | test_callback = PINNTester(data_path, spatial_domain.dim, Tester.test_while_train) 182 | lr_scheduler_callback = PINNLRScheduler(scheduler) 183 | resampler = dde.callbacks.PDEResidualResampler(period=10) 184 | callbacks = [lr_scheduler_callback, resampler] 185 | if TEST_WHILE_TRAIN: 186 | callbacks.append(test_callback) 187 | model = dde.Model(data, net) 188 | model.compile("adam", lr=lr) 189 | with nostdout(): 190 | model.train(epochs=n_epochs, callbacks=callbacks, display_every=1) 191 | # lbfgs 192 | model_saver = PINNModelSaver() 193 | resampler = dde.callbacks.PDEResidualResampler(period=1) 194 | model.compile("L-BFGS") 195 | model.train(callbacks=[resampler, model_saver]) 196 | if model_saver.got_nan: 197 | model.net.load_state_dict(model_saver.weights) 198 | m_loss_res = test_callback.m_loss_res 199 | m_abs_e_res = test_callback.m_abs_e_res 200 | m_r_abs_e_res = test_callback.m_r_abs_e_res 201 | net = model.net.module 202 | # save the model 203 | if SAVE_MODEL: 204 | torch.save(net, model_path_prefix + "xpinn.pth") 205 | # set model to evaluation mode 206 | net.set_eval() 207 | return net, m_loss_res, m_abs_e_res, m_r_abs_e_res 208 | 209 | 210 | if __name__ == "__main__": 211 | if TEST_PINN: 212 | net, m_loss_res, m_abs_e_res, m_r_abs_e_res = train_pinn() 213 | elif TEST_XPINN: 214 | net, m_loss_res, m_abs_e_res, m_r_abs_e_res = train_xpinn() 215 | else: 216 | net, m_loss_res, m_abs_e_res, m_r_abs_e_res = train_hc() 217 | 218 | # test the model 219 | # net.eval() 220 | test(data_path, spatial_domain.dim, net) 221 | # plot lines of testing res while training 222 | if TEST_WHILE_TRAIN: 223 | plot_lines( 224 | [list(range(1, len(m_abs_e_res) + 1)), 225 | np.array(m_loss_res) / np.max(m_loss_res), 226 | np.array(m_abs_e_res) / np.max(m_abs_e_res), 227 | np.array(m_r_abs_e_res) / np.max(m_r_abs_e_res), 228 | ], 229 | "Epochs", 230 | "Normalized result", 231 | ["loss", "m_abs_e", "m_r_abs_e"], "outs/e_while_training.png", 232 | is_log=True 233 | ) 234 | -------------------------------------------------------------------------------- /src/case3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | import numpy as np 4 | import deepxde as dde 5 | from deepxde import gradients as dde_grad 6 | from src.PFNN.pfnn import PFNN 7 | from src.configs.case3.hc import HCNN, pde_hc 8 | from src.utils.no_stdout_context import nostdout 9 | from src.utils.utils import test_time_with_reference_solution 10 | from src.utils.pinn_callback import PINNLRAdaptor, PINNLRScheduler, PINNModelSaver 11 | from src.configs.case3.params import model_path_prefix, d 12 | from src.configs.case3.pinn import reference_solution, spatial_time_domain, \ 13 | pde_pinn, ic_bcs, num_bcs, num_pdes, lr_alpha, 
spatial_domain 14 | from src.configs.case3.pfnn import loss_pfnn, reference_solution_pt, tot_points 15 | 16 | 17 | # Model Selection (default: HC) 18 | TEST_PINN = False 19 | TEST_PFNN = False 20 | PINN_LR_ANNEALING = False # this choice is valid if TEST_PINN == True 21 | PINN_LR_ANNEALING_2 = False # this choice is valid if TEST_PINN == True 22 | 23 | # Other Configurations 24 | LOAD_MODEL = False 25 | SAVE_MODEL = False 26 | 27 | 28 | def train_pinn(): 29 | n_epochs = 5000 30 | lr = 0.01 31 | data = dde.data.TimePDE( 32 | spatial_time_domain, 33 | pde_pinn, 34 | ic_bcs, 35 | num_domain=1000, 36 | num_boundary=100, 37 | num_initial=100 38 | ) 39 | net = dde.nn.FNN([d + 1] + 4 * [50] + [1], "tanh", "Glorot normal") 40 | if LOAD_MODEL: 41 | net = torch.load(model_path_prefix + "pinn.pth") 42 | # train 43 | optimizer = optim.Adam( 44 | net.parameters(), 45 | lr=lr 46 | ) 47 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 48 | optimizer, patience=100, factor=0.75, min_lr=1e-5 49 | ) 50 | loss_weights = [1.] * (num_pdes + num_bcs) 51 | lr_adaptor_callback = PINNLRAdaptor( 52 | loss_weights, num_pdes, lr_alpha, 53 | mode="max" if PINN_LR_ANNEALING else "mean" 54 | ) 55 | lr_scheduler_callback = PINNLRScheduler(scheduler) 56 | resampler = dde.callbacks.PDEResidualResampler(period=10) 57 | callbacks = [lr_scheduler_callback, resampler] 58 | if PINN_LR_ANNEALING or PINN_LR_ANNEALING_2: 59 | callbacks.append(lr_adaptor_callback) 60 | model = dde.Model(data, net) 61 | model.compile("adam", lr=lr, loss_weights=loss_weights) 62 | with nostdout(): 63 | model.train(epochs=n_epochs, callbacks=callbacks, display_every=1) 64 | # lbfgs 65 | model_saver = PINNModelSaver() 66 | resampler = dde.callbacks.PDEResidualResampler(period=1) 67 | model.compile("L-BFGS", loss_weights=loss_weights) 68 | model.train(callbacks=[resampler, model_saver]) 69 | if model_saver.got_nan: 70 | model.net.load_state_dict(model_saver.weights) 71 | net = model.net 72 | # save the model 73 | if SAVE_MODEL: 74 | torch.save(net, model_path_prefix + "pinn.pth") 75 | return net 76 | 77 | 78 | def train_hc(): 79 | n_epochs = 5000 80 | lr = 0.01 81 | net = HCNN(model_path_prefix if LOAD_MODEL else "") 82 | optimizer = optim.Adam( 83 | net.parameters(), 84 | lr=lr) 85 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 86 | optimizer, patience=100, factor=0.75, min_lr=1e-5 87 | ) 88 | # prepare data set 89 | X = torch.tensor(spatial_time_domain.random_points(1000)).float() 90 | X.requires_grad = True 91 | # train 92 | print("Training...") 93 | for i in range(n_epochs): 94 | pred = net(X) 95 | loss = torch.cat(pde_hc(X, pred), dim=1) 96 | loss = torch.sum(loss ** 2, dim=1) 97 | loss = torch.mean(loss) 98 | # Backpropagation 99 | optimizer.zero_grad() 100 | loss.backward() 101 | optimizer.step() 102 | # update the lr 103 | loss_val = loss.item() 104 | scheduler.step(loss_val) 105 | if (i+1) % 100 == 0 or i == 0: 106 | print(f"[Epoch {i+1}/{n_epochs}] loss: {loss_val:>7f}") 107 | # clear the cache of the grads 108 | dde_grad.clear() 109 | if i % 10 == 0: 110 | X = torch.tensor(spatial_time_domain.random_points(1000)).float() 111 | X.requires_grad = True 112 | # l-bfgs 113 | resampler = dde.callbacks.PDEResidualResampler(period=1) 114 | model_saver = PINNModelSaver(net.state_dict()) 115 | data = dde.data.TimePDE( 116 | spatial_time_domain, 117 | pde_hc, 118 | [], 119 | num_domain=1000 120 | ) 121 | model = dde.Model(data, net) 122 | model.compile("L-BFGS") 123 | model.train(callbacks=[resampler, model_saver]) 124 | if model_saver.got_nan: 
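            # L-BFGS can diverge to NaN on stiff PDE residuals; the PINNModelSaver
            # callback keeps the last state_dict whose losses were all finite,
            # and it is restored here so the returned model is always usable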
125 | net.load_state_dict(model_saver.weights) 126 | print("Finish training!") 127 | # save the model 128 | if SAVE_MODEL: 129 | net.save(model_path_prefix) 130 | return net 131 | 132 | 133 | def train_pfnn(): 134 | n_epochs = 5000 135 | lr = 0.01 136 | data = dde.data.TimePDE( 137 | spatial_time_domain, 138 | loss_pfnn, 139 | [], 140 | anchors=tot_points 141 | ) 142 | net = PFNN(reference_solution_pt, d + 1, 1) 143 | if LOAD_MODEL: 144 | net = torch.load(model_path_prefix + "pfnn.pth") 145 | # train 146 | optimizer = optim.Adam( 147 | net.parameters(), 148 | lr=lr 149 | ) 150 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 151 | optimizer, patience=100, factor=0.75, min_lr=1e-5 152 | ) 153 | lr_scheduler_callback = PINNLRScheduler(scheduler) 154 | resampler = dde.callbacks.PDEResidualResampler(period=10) 155 | callbacks = [lr_scheduler_callback, resampler] 156 | model = dde.Model(data, net) 157 | model.compile("adam", lr=lr) 158 | with nostdout(): 159 | model.train(epochs=n_epochs, callbacks=callbacks, display_every=1) 160 | # lbfgs 161 | model_saver = PINNModelSaver() 162 | model.compile("L-BFGS") 163 | model.train(callbacks=[model_saver]) 164 | if model_saver.got_nan: 165 | model.net.load_state_dict(model_saver.weights) 166 | net = model.net 167 | # save the model 168 | if SAVE_MODEL: 169 | torch.save(net, model_path_prefix + "pfnn.pth") 170 | return net 171 | 172 | 173 | if __name__ == "__main__": 174 | if TEST_PINN: 175 | net = train_pinn() 176 | elif TEST_PFNN: 177 | net = train_pfnn() 178 | else: 179 | net = train_hc() 180 | 181 | # test the model 182 | # net.eval() 183 | test_X = spatial_domain.random_points(1000) 184 | test_X = np.concatenate([ 185 | np.concatenate((test_X, np.ones((test_X.shape[0], 1)) * t / 10), axis=1) 186 | for t in range(11) 187 | ], axis=0) 188 | test_time_with_reference_solution(reference_solution, test_X, net) 189 | -------------------------------------------------------------------------------- /src/case4.py: -------------------------------------------------------------------------------- 1 | import deepxde as dde 2 | import numpy as np 3 | import torch 4 | from src.utils.no_stdout_context import nostdout 5 | from src.utils.pinn_callback import PINNGradientTracker 6 | from src.utils.utils import plot_lines 7 | 8 | 9 | # Problem Selection (default: Poisson equation) 10 | TEST_SH = False # test on the nonlinear Schrödinger equation 11 | 12 | # Other Configuration 13 | a = 2 14 | 15 | # DO NOT MODIFY THIS VARIABLE 16 | EXTRA_FIELDS = False 17 | 18 | 19 | def train_poisson(): 20 | def pde(x, y): 21 | if EXTRA_FIELDS: 22 | p = y[:, 1:] 23 | du_x = dde.grad.jacobian(y, x, i=0, j=0) 24 | dp_x = dde.grad.jacobian(y, x, i=1, j=0) 25 | return [ 26 | dp_x + a ** 2 * torch.sin(a * x), 27 | p - du_x 28 | ] 29 | dy_xx = dde.grad.hessian(y, x) 30 | return dy_xx + a ** 2 * torch.sin(a * x) 31 | 32 | def boundary(_, on_boundary): 33 | return on_boundary 34 | 35 | def y_exact(x): 36 | return np.sin(a * x) 37 | # training 38 | geom = dde.geometry.Interval(0, 2 * np.pi) 39 | bc = dde.icbc.DirichletBC(geom, y_exact, boundary) 40 | data = dde.data.PDE(geom, pde, bc, 128, 2) 41 | 42 | if EXTRA_FIELDS: 43 | layer_size = [1] + [50] * 3 + [2] 44 | else: 45 | layer_size = [1] + [50] * 3 + [1] 46 | activation = "tanh" 47 | initializer = "Glorot uniform" 48 | net = dde.nn.FNN(layer_size, activation, initializer) 49 | 50 | model = dde.Model(data, net) 51 | model.compile("adam", lr=0.001) 52 | callback = PINNGradientTracker(num_pdes=2 if EXTRA_FIELDS else 1) 53 | 
model.train(epochs=10000, display_every=1, callbacks=[callback]) 54 | m_gradients = np.array(callback.m_gradients) 55 | steps = np.array(callback.steps) 56 | conds = np.array(callback.conds) 57 | return steps, m_gradients, conds 58 | 59 | 60 | def train_schrodinger(): 61 | ''' 62 | Source: anonymous 63 | ''' 64 | x_lower = -5 65 | x_upper = 5 66 | t_lower = 0 67 | t_upper = np.pi / 2 68 | def pde(x, y): 69 | """ 70 | INPUTS: 71 | x: x[:,0] is x-coordinate 72 | x[:,1] is t-coordinate 73 | y: Network output, in this case: 74 | y[:,0] is u(x,t) the real part 75 | y[:,1] is v(x,t) the imaginary part 76 | OUTPUT: 77 | The pde in standard form i.e. something that must be zero 78 | """ 79 | if EXTRA_FIELDS: 80 | u = y[:, 0:1] 81 | v = y[:, 1:2] 82 | _u_x = y[:, 2:3] 83 | _v_x = y[:, 3:4] 84 | 85 | # In 'jacobian', i is the output component and j is the input component 86 | u_t = dde.grad.jacobian(y, x, i=0, j=1) 87 | v_t = dde.grad.jacobian(y, x, i=1, j=1) 88 | 89 | u_x = dde.grad.jacobian(y, x, i=0, j=0) 90 | v_x = dde.grad.jacobian(y, x, i=1, j=0) 91 | 92 | # In 'hessian', i and j are both input components. (The Hessian could be in principle something like d^2y/dxdt, d^2y/d^2x etc) 93 | # The output component is selected by "component" 94 | u_xx = dde.grad.jacobian(y, x, i=2, j=0) 95 | v_xx = dde.grad.jacobian(y, x, i=3, j=0) 96 | 97 | f_u = u_t + 0.5 * v_xx + (u ** 2 + v ** 2) * v 98 | f_v = v_t - 0.5 * u_xx - (u ** 2 + v ** 2) * u 99 | 100 | return [f_u, f_v, u_x - _u_x, v_x - _v_x] 101 | u = y[:, 0:1] 102 | v = y[:, 1:2] 103 | 104 | # In 'jacobian', i is the output component and j is the input component 105 | u_t = dde.grad.jacobian(y, x, i=0, j=1) 106 | v_t = dde.grad.jacobian(y, x, i=1, j=1) 107 | 108 | u_x = dde.grad.jacobian(y, x, i=0, j=0) 109 | v_x = dde.grad.jacobian(y, x, i=1, j=0) 110 | 111 | # In 'hessian', i and j are both input components. 
(The Hessian could be in principle something like d^2y/dxdt, d^2y/d^2x etc) 112 | # The output component is selected by "component" 113 | u_xx = dde.grad.hessian(y, x, component=0, i=0, j=0) 114 | v_xx = dde.grad.hessian(y, x, component=1, i=0, j=0) 115 | 116 | f_u = u_t + 0.5 * v_xx + (u ** 2 + v ** 2) * v 117 | f_v = v_t - 0.5 * u_xx - (u ** 2 + v ** 2) * u 118 | 119 | return [f_u, f_v] 120 | 121 | # Space and time domains/geometry (for the deepxde model) 122 | space_domain = dde.geometry.Interval(x_lower, x_upper) 123 | time_domain = dde.geometry.TimeDomain(t_lower, t_upper) 124 | geomtime = dde.geometry.GeometryXTime(space_domain, time_domain) 125 | 126 | # Boundary and Initial conditions 127 | # Periodic Boundary conditions 128 | bc_u_0 = dde.icbc.PeriodicBC( 129 | geomtime, 0, lambda _, on_boundary: on_boundary, derivative_order=0, component=0 130 | ) 131 | bc_u_1 = dde.icbc.PeriodicBC( 132 | geomtime, 0, lambda _, on_boundary: on_boundary, derivative_order=1, component=0 133 | ) 134 | bc_v_0 = dde.icbc.PeriodicBC( 135 | geomtime, 0, lambda _, on_boundary: on_boundary, derivative_order=0, component=1 136 | ) 137 | bc_v_1 = dde.icbc.PeriodicBC( 138 | geomtime, 0, lambda _, on_boundary: on_boundary, derivative_order=1, component=1 139 | ) 140 | 141 | # Initial conditions 142 | def init_cond_u(x): 143 | "2 sech(x)" 144 | return 2 / np.cosh(x[:, 0:1]) 145 | 146 | 147 | def init_cond_v(x): 148 | return 0 149 | 150 | 151 | ic_u = dde.icbc.IC(geomtime, init_cond_u, lambda _, on_initial: on_initial, component=0) 152 | ic_v = dde.icbc.IC(geomtime, init_cond_v, lambda _, on_initial: on_initial, component=1) 153 | 154 | # training 155 | data = dde.data.TimePDE( 156 | geomtime, 157 | pde, 158 | [bc_u_0, bc_u_1, bc_v_0, bc_v_1, ic_u, ic_v], 159 | num_domain=1000, 160 | num_boundary=20, 161 | num_initial=200, 162 | train_distribution="pseudo", 163 | ) 164 | 165 | if EXTRA_FIELDS: 166 | layer_size = [2] + [100] * 5 + [4] 167 | else: 168 | layer_size = [2] + [100] * 5 + [2] 169 | activation = "tanh" 170 | initializer = "Glorot uniform" 171 | net = dde.nn.FNN(layer_size, activation, initializer) 172 | 173 | model = dde.Model(data, net) 174 | model.compile("adam", lr=0.001) 175 | callback = PINNGradientTracker(num_pdes=4 if EXTRA_FIELDS else 2) 176 | model.train(epochs=10000, display_every=1, callbacks=[callback]) 177 | m_gradients = np.array(callback.m_gradients) 178 | steps = np.array(callback.steps) 179 | conds = np.array(callback.conds) 180 | return steps, m_gradients, conds 181 | 182 | 183 | if __name__ == "__main__": 184 | with nostdout(): 185 | steps, m_gradients_raw, conds_raw = train_schrodinger() if TEST_SH else train_poisson() 186 | EXTRA_FIELDS = True 187 | steps, m_gradients_ef, conds_ef = train_schrodinger() if TEST_SH else train_poisson() 188 | # plot the gradients history 189 | plot_lines( 190 | [steps[1::100], np.abs(m_gradients_raw[1::100]), np.abs(m_gradients_ef[1::100])], 191 | "Steps", 192 | "Mean absolute gradients (abs)", 193 | ["Origin", "Extra Fields"], "outs/mean_gradients_while_training.png", 194 | is_log=True 195 | ) 196 | # plot the cond history 197 | plot_lines( 198 | [steps[1::100], np.abs(conds_raw[1::100]), np.abs(conds_ef[1::100])], 199 | "Steps", 200 | "Condition numbers (abs)", 201 | ["Origin", "Extra Fields"], "outs/conds_while_training.png", 202 | is_log=True 203 | ) 204 | -------------------------------------------------------------------------------- /src/configs/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/csuastt/HardConstraint/c9363e901cb6220801d9f90ae0a19cc2b7e9677d/src/configs/__init__.py -------------------------------------------------------------------------------- /src/configs/case1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csuastt/HardConstraint/c9363e901cb6220801d9f90ae0a19cc2b7e9677d/src/configs/case1/__init__.py -------------------------------------------------------------------------------- /src/configs/case1/fbpinn.py: -------------------------------------------------------------------------------- 1 | # scale factor of the window function 2 | sigma = 0.4 -------------------------------------------------------------------------------- /src/configs/case1/hc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import deepxde as dde 5 | from src.HC.hard_constraint import HardConstraintRobin2D 6 | from src.HC.hard_constraint_collector import HardConstraintCollector 7 | from src.HC.l_functions import LFunctionDisk, LFunctionRectangle 8 | from src.HC.normal_function import NormalFunctionDisk, NormalFunctionRectangle 9 | from src.configs.case1.pinn import rec, disks_c, disks_w 10 | from src.utils.nn_wrapper import NNWrapper 11 | from src.configs.case1.params import * 12 | 13 | def pde_hc(x, u): 14 | p_1, p_2 = u[:, 1:2], u[:, 2:] 15 | T_x = dde.grad.jacobian(u, x, i=0, j=0) 16 | T_y = dde.grad.jacobian(u, x, i=0, j=1) 17 | T_t = dde.grad.jacobian(u, x, i=0, j=2) 18 | p_1_x = dde.grad.jacobian(u, x, i=1, j=0) 19 | p_2_y = dde.grad.jacobian(u, x, i=2, j=1) 20 | 21 | res = T_t - k * (p_1_x + p_2_y) 22 | p_1_res = p_1 - T_x 23 | p_2_res = p_2 - T_y 24 | return [res, p_1_res, p_2_res] 25 | 26 | 27 | # Helper functions (input: x, a B x d tensor) 28 | disks = disks_c + disks_w 29 | X = rec.random_boundary_points(512) 30 | for disk in disks: 31 | X = np.concatenate((X, disk.random_boundary_points(256)), axis=0) 32 | # M function 33 | beta = 4.0 34 | M = lambda x: torch.logsumexp(-beta * x, dim=1, keepdim=True) / (-beta) 35 | # Lambda functions 36 | l_gamma_rec = LFunctionRectangle(X, rec, m_function=M) 37 | l_gamma_disks_c = [LFunctionDisk(X, disk) for disk in disks_c] 38 | l_gamma_disks_w = [LFunctionDisk(X, disk) for disk in disks_w] 39 | # Normal functions 40 | n_gamma_outer = NormalFunctionRectangle(H, L) 41 | n_gamma_disks_c = [NormalFunctionDisk(center=disk_centers_c[i]) for i in range(len(disk_centers_c))] 42 | n_gamma_disks_w = [NormalFunctionDisk(center=disk_centers_w[i]) for i in range(len(disk_centers_w))] 43 | 44 | # hard constraints 45 | hc_gamma_outer = HardConstraintRobin2D( 46 | n_gamma_outer, 47 | NNWrapper(dde.nn.FNN([3] + 3 * [20] + [3], "tanh", "Glorot normal")), 48 | lambda _: 1., 49 | lambda _: 1., 50 | lambda _: 1e-1 51 | ) 52 | hc_gamma_disks_c = [ 53 | HardConstraintRobin2D( 54 | n_gamma_disk, 55 | NNWrapper(dde.nn.FNN([3] + 3 * [20] + [3], "tanh", "Glorot normal")), 56 | lambda _: 1., 57 | lambda _: 1., 58 | lambda _: 5. 59 | ) for n_gamma_disk in n_gamma_disks_c 60 | ] 61 | hc_gamma_disks_w = [ 62 | HardConstraintRobin2D( 63 | n_gamma_disk, 64 | NNWrapper(dde.nn.FNN([3] + 3 * [20] + [3], "tanh", "Glorot normal")), 65 | lambda _: 1., 66 | lambda _: 1., 67 | lambda _: 1. 68 | ) for n_gamma_disk in n_gamma_disks_w 69 | ] 70 | 71 | # model 72 | class HCNN(nn.Module): 73 | """ 74 | Hard constraint model. 
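    Each solution component (T, p_1, p_2) is assembled by a
    HardConstraintCollector, which blends one auxiliary boundary network per
    boundary (the outer rectangle and the inner disks) with the shared main
    network N_main through the L-functions and the smooth minimum M, so each
    Robin boundary condition is satisfied by construction instead of being
    penalized in the loss. The exp(-10 * t) factor in forward() additionally
    pins the initial condition T = 1e-1 at t = 0.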
75 | """ 76 | def __init__(self, path_prefix="") -> None: 77 | super(HCNN, self).__init__() 78 | # NNs 79 | self.hc_gamma_outer = hc_gamma_outer 80 | self.hc_gamma_disks = nn.ModuleList(hc_gamma_disks_c + hc_gamma_disks_w) 81 | self.N_main = NNWrapper(dde.nn.FNN([3] + 4 * [50] + [3], "tanh", "Glorot normal")) 82 | if path_prefix != "": 83 | self.load(path_prefix) 84 | # hard constraint for each components 85 | self.HCC_T = HardConstraintCollector( 86 | 0, M, [l_gamma_rec] + l_gamma_disks_c + l_gamma_disks_w, 87 | [ 88 | hc_gamma_outer.get_u 89 | ] + [ 90 | hc_gamma_disk.get_u for hc_gamma_disk in self.hc_gamma_disks 91 | ], self.N_main 92 | ) 93 | self.HCC_p_1 = HardConstraintCollector( 94 | 1, M, [l_gamma_rec] + l_gamma_disks_c + l_gamma_disks_w, 95 | [ 96 | hc_gamma_outer.get_p_1 97 | ] + [ 98 | hc_gamma_disk.get_p_1 for hc_gamma_disk in self.hc_gamma_disks 99 | ], self.N_main 100 | ) 101 | self.HCC_p_2 = HardConstraintCollector( 102 | 2, M, [l_gamma_rec] + l_gamma_disks_c + l_gamma_disks_w, 103 | [ 104 | hc_gamma_outer.get_p_2 105 | ] + [ 106 | hc_gamma_disk.get_p_2 for hc_gamma_disk in self.hc_gamma_disks 107 | ], self.N_main 108 | ) 109 | 110 | def save(self, path_prefix: str): 111 | torch.save(self.N_main, path_prefix + "hc_main.pth") 112 | self.hc_gamma_outer.save(path_prefix, "hc_g_o.pth") 113 | for i in range(len(self.hc_gamma_disks)): 114 | self.hc_gamma_disks[i].save(path_prefix, "hc_g_d%d.pth"%i) 115 | 116 | def load(self, path_prefix: str): 117 | self.N_main = torch.load(path_prefix + "hc_main.pth") 118 | self.hc_gamma_outer.load(path_prefix, "hc_g_o.pth") 119 | for i in range(len(self.hc_gamma_disks)): 120 | self.hc_gamma_disks[i].load(path_prefix, "hc_g_d%d.pth"%i) 121 | 122 | def forward(self, x) -> torch.Tensor: 123 | self.N_main.clear_res() 124 | self.hc_gamma_outer.clear_res() 125 | for hc_gamma_disk in self.hc_gamma_disks: 126 | hc_gamma_disk.clear_res() 127 | time_factor = torch.exp(-10 * x[:, 2:3]) 128 | return torch.cat(( 129 | self.HCC_T(x) * (1 - time_factor) + 1e-1 * time_factor, 130 | self.HCC_p_1(x), 131 | self.HCC_p_2(x) 132 | ), dim=1) 133 | -------------------------------------------------------------------------------- /src/configs/case1/params.py: -------------------------------------------------------------------------------- 1 | from src.utils.utils import load_params_disks 2 | 3 | # path 4 | data_path = "data/case1_pack.txt" 5 | model_path_prefix = "model/case1_" 6 | 7 | # parameters for PDEs 8 | H = 24 9 | L = 16 10 | k = 1. 
# thermal conductivity 11 | 12 | # parameters of the geometry 13 | # read from the input file 14 | disk_centers, disk_rs = load_params_disks(data_path) 15 | disk_centers_c = disk_centers[:-6] 16 | disk_rs_c = disk_rs[:-6] 17 | disk_centers_w = disk_centers[-6:] 18 | disk_rs_w = disk_rs[-6:] 19 | -------------------------------------------------------------------------------- /src/configs/case1/pfnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import deepxde as dde 4 | from src.configs.case1.params import * 5 | from src.configs.case1.pinn import rec, disks_c, disks_w, time_domain, spatial_time_domain 6 | 7 | # Sample Points 8 | omega_points = spatial_time_domain.random_points(8192) 9 | rec_points = dde.geometry.GeometryXTime(rec, time_domain).random_boundary_points(512 // 3) 10 | disk_c_points = None 11 | for disk_c in disks_c: 12 | new_points = dde.geometry.GeometryXTime(disk_c, time_domain) \ 13 | .random_boundary_points(512 // 3 // len(disks_c)) 14 | if disk_c_points is None: 15 | disk_c_points = new_points 16 | else: 17 | disk_c_points = np.concatenate(( 18 | disk_c_points, new_points 19 | )) 20 | disk_w_points = None 21 | for disk_w in disks_w: 22 | new_points = dde.geometry.GeometryXTime(disk_w, time_domain) \ 23 | .random_boundary_points(512 // 3 // len(disks_w)) 24 | if disk_w_points is None: 25 | disk_w_points = new_points 26 | else: 27 | disk_w_points = np.concatenate(( 28 | disk_w_points, new_points 29 | )) 30 | tot_points = np.concatenate((omega_points, rec_points, disk_c_points, disk_w_points)) 31 | 32 | # Area and numbers 33 | area_omega = L * H - np.pi * (disk_rs_c[0] ** 2) * len(disk_centers_c) \ 34 | - np.pi * (disk_rs_w[0] ** 2) * len(disk_centers_w) 35 | num_omega = omega_points.shape[0] 36 | len_rec = 2 * (L + H) 37 | num_rec = rec_points.shape[0] 38 | len_disk_c = 2 * np.pi * disk_rs_c[0] * len(disk_centers_c) 39 | num_disk_c = disk_c_points.shape[0] 40 | len_disk_w = 2 * np.pi * disk_rs_w[0] * len(disk_centers_w) 41 | num_disk_w = disk_w_points.shape[0] 42 | 43 | def loss_pfnn(x, T): 44 | T_x = dde.grad.jacobian(T, x, i=0, j=0) 45 | T_y = dde.grad.jacobian(T, x, i=0, j=1) 46 | T_t = dde.grad.jacobian(T, x, i=0, j=2) 47 | # variational form 48 | # a(u,u) 49 | start = 0 50 | a = area_omega * torch.mean(k * (T_x[start:start+num_omega, :] ** 2 51 | + T_y[start:start+num_omega, :] ** 2)) 52 | start += num_omega 53 | a += len_rec * torch.mean(T[start:start+num_rec, :] ** 2) 54 | start += num_rec 55 | a += len_disk_c * torch.mean(T[start:start+num_disk_c, :] ** 2) 56 | start += num_disk_c 57 | a += len_disk_w * torch.mean(T[start:start+num_disk_w, :] ** 2) 58 | start += num_disk_w 59 | # L(u) 60 | start = 0 61 | L = area_omega * torch.mean(-T_t[start:start+num_omega, :] * T[start:start+num_omega, :]) 62 | start += num_omega 63 | L += len_rec * torch.mean(T[start:start+num_rec, :] * 1e-1) 64 | start += num_rec 65 | L += len_disk_c * torch.mean(T[start:start+num_disk_c, :] * 5.) 66 | start += num_disk_c 67 | L += len_disk_w * torch.mean(T[start:start+num_disk_w, :] * 1.) 
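    # each boundary term above is a Monte-Carlo estimate of a line integral:
    # (boundary length) * (mean over its sampled points), with the Robin data
    # 1e-1, 5. and 1. on the rectangle and the two disk groups; together with
    # a(u,u) they form the Ritz energy 0.5 * a - L returned below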
68 | start += num_disk_w 69 | 70 | return (0.5 * a - L).reshape(1, 1) 71 | 72 | initial_condition = lambda _: 1e-1 -------------------------------------------------------------------------------- /src/configs/case1/pinn.py: -------------------------------------------------------------------------------- 1 | import deepxde as dde 2 | from src.configs.case1.params import * 3 | from src.utils.pinn_geometry import RecDiskDomain 4 | 5 | rec = dde.geometry.Rectangle(xmin=[-L/2, -H/2], xmax=[L/2, H/2]) 6 | disks_c = [dde.geometry.Disk(center=disk_centers_c[i], radius=disk_rs_c[i]) for i in range(len(disk_centers_c))] 7 | disks_w = [dde.geometry.Disk(center=disk_centers_w[i], radius=disk_rs_w[i]) for i in range(len(disk_centers_w))] 8 | spatial_domain = RecDiskDomain(rec, disks_c + disks_w) 9 | time_domain = dde.geometry.TimeDomain(0, 1) 10 | spatial_time_domain = dde.geometry.GeometryXTime(spatial_domain, time_domain) 11 | 12 | 13 | def pde_pinn(x, T): 14 | T_t = dde.grad.jacobian(T, x, i=0, j=2) 15 | T_xx = dde.grad.hessian(T, x, i=0, j=0) 16 | T_yy = dde.grad.hessian(T, x, i=1, j=1) 17 | return T_t - k * (T_xx + T_yy) 18 | 19 | 20 | rec_bc = dde.icbc.RobinBC( 21 | spatial_time_domain, lambda _, T: 1e-1 - T, 22 | lambda x, on_bc: on_bc and rec.on_boundary(x[:2]) 23 | ) 24 | 25 | bc_disks = [] 26 | for i in range(len(disks_c)): 27 | bc_disks.append( 28 | dde.icbc.RobinBC( 29 | spatial_time_domain, lambda _, T: 5. - T, 30 | lambda x, on_bc, j=i: on_bc and disks_c[j].on_boundary(x[:2]) 31 | ) 32 | ) 33 | 34 | for i in range(len(disks_w)): 35 | bc_disks.append( 36 | dde.icbc.RobinBC( 37 | spatial_time_domain, lambda _, T: 1. - T, 38 | lambda x, on_bc, j=i: on_bc and disks_w[j].on_boundary(x[0:2]) 39 | ) 40 | ) 41 | 42 | ic = dde.icbc.IC( 43 | spatial_time_domain, lambda _: 1e-1, 44 | lambda _, on_initial: on_initial 45 | ) 46 | 47 | ic_bcs = [rec_bc] + bc_disks + [ic] 48 | 49 | # parameters for PINN 50 | # moving average in learning rate annealing 51 | lr_alpha = 0.1 52 | # num pdes 53 | num_pdes = 1 54 | num_bcs = len(ic_bcs) -------------------------------------------------------------------------------- /src/configs/case1/xpinn.py: -------------------------------------------------------------------------------- 1 | import deepxde as dde 2 | from src.configs.case1.params import * 3 | 4 | def pde_xpinn(x, T): 5 | T_t = dde.grad.jacobian(T, x, i=0, j=2) 6 | T_xx = dde.grad.hessian(T, x, i=0, j=0, component=0) 7 | T_yy = dde.grad.hessian(T, x, i=1, j=1, component=0) 8 | return T_t - k * (T_xx + T_yy) -------------------------------------------------------------------------------- /src/configs/case2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csuastt/HardConstraint/c9363e901cb6220801d9f90ae0a19cc2b7e9677d/src/configs/case2/__init__.py -------------------------------------------------------------------------------- /src/configs/case2/fbpinn.py: -------------------------------------------------------------------------------- 1 | # scale factor of the window function 2 | sigma = 0.2 -------------------------------------------------------------------------------- /src/configs/case2/hc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import deepxde as dde 4 | from src.HC.hard_constraint import HardConstraintNeumann2D 5 | from src.HC.hard_constraint_collector import HardConstraintCollector 6 | from src.HC.l_functions import LFunctionAxisLine, 
LFunctionOpenRectangle, LFunctionPolygon 7 | from src.HC.normal_function import NormalFunctionPolygon 8 | from src.configs.case2.pinn import spatial_domain, rec, airfoil 9 | from src.configs.case2.params import * 10 | from src.utils.nn_wrapper import NNWrapper 11 | 12 | 13 | def pde_hc(x, u): 14 | u_vel, v_vel, p, _u_x, _u_y, _v_x, _v_y = \ 15 | u[:, 0:1], u[:, 1:2], u[:, 2:3], u[:, 3:4], u[:, 4:5], u[:, 5:6], u[:, 6:7] 16 | u_vel_x = dde.grad.jacobian(u, x, i=0, j=0) 17 | u_vel_y = dde.grad.jacobian(u, x, i=0, j=1) 18 | v_vel_x = dde.grad.jacobian(u, x, i=1, j=0) 19 | v_vel_y = dde.grad.jacobian(u, x, i=1, j=1) 20 | p_x = dde.grad.jacobian(u, x, i=2, j=0) 21 | p_y = dde.grad.jacobian(u, x, i=2, j=1) 22 | 23 | u_vel_xx = dde.grad.jacobian(u, x, i=3, j=0) 24 | u_vel_yy = dde.grad.jacobian(u, x, i=4, j=1) 25 | v_vel_xx = dde.grad.jacobian(u, x, i=5, j=0) 26 | v_vel_yy = dde.grad.jacobian(u, x, i=6, j=1) 27 | 28 | res_u_x = (_u_x - u_vel_x) 29 | res_u_y = (_u_y - u_vel_y) 30 | res_v_x = (_v_x - v_vel_x) 31 | res_v_y = (_v_y - v_vel_y) 32 | momentum_x = ( 33 | u_vel * _u_x + v_vel * _u_y + p_x - nu * (u_vel_xx + u_vel_yy) 34 | ) 35 | momentum_y = ( 36 | u_vel * _v_x + v_vel * _v_y + p_y - nu * (v_vel_xx + v_vel_yy) 37 | ) 38 | continuity = _u_x + _v_y 39 | 40 | return [res_u_x, res_u_y, res_v_x, res_v_y, momentum_x, momentum_y, continuity] 41 | 42 | # Helper functions (input: x, a B x d tensor) 43 | X_rec_bc = rec.random_boundary_points(512) 44 | X_airfoil_bc = airfoil.random_boundary_points(512) 45 | # M function 46 | beta = 4.0 47 | M = lambda x: torch.logsumexp(-beta * x, dim=1, keepdim=True) / (-beta) 48 | # Lambda functions 49 | l_gamma_openrec = LFunctionOpenRectangle(X_airfoil_bc, rec, m_function=M) 50 | l_gamma_right = LFunctionAxisLine(X_airfoil_bc, rec, xmax[0], 0, is_left=False) 51 | l_gamma_airfoil = LFunctionPolygon(X_rec_bc, airfoil, spatial_domain) 52 | # Normal functions 53 | n_gamma_airfoil = NormalFunctionPolygon(airfoil) 54 | 55 | # hard constraints 56 | hc_gamma_airfoil = HardConstraintNeumann2D( 57 | n_gamma_airfoil, 58 | NNWrapper(dde.nn.FNN([2] + 4 * [40] + [1], "tanh", "Glorot normal")), 59 | lambda _: 0. 60 | ) 61 | 62 | # model 63 | class HCNN(nn.Module): 64 | """ 65 | Hard constraint model. 
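    HCC_u and HCC_v pin the velocity to the free-stream (inlet) values
    (u, v) = (1, 0) on the open rectangle and couple to hc_gamma_airfoil on
    the airfoil surface; HCC_p fixes p = 1 on the outflow line. Output
    columns 3-6 are the auxiliary first-order fields (_u_x, _u_y, _v_x, _v_y)
    consumed by pde_hc and are returned unconstrained.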
66 | """ 67 | def __init__(self, path_prefix="") -> None: 68 | super(HCNN, self).__init__() 69 | # NNs 70 | self.hc_gamma_airfoil = hc_gamma_airfoil 71 | self.N_main = NNWrapper(dde.nn.FNN([2] + 6 * [50] + [7], "tanh", "Glorot normal")) 72 | if path_prefix != "": 73 | self.load(path_prefix) 74 | # hard constraint for each components 75 | self.HCC_u = HardConstraintCollector( 76 | 0, M, [l_gamma_openrec, l_gamma_airfoil], 77 | [lambda _: 1., hc_gamma_airfoil.get_p_1], self.N_main 78 | ) 79 | self.HCC_v = HardConstraintCollector( 80 | 1, M, [l_gamma_openrec, l_gamma_airfoil], 81 | [lambda _: 0., hc_gamma_airfoil.get_p_2], self.N_main 82 | ) 83 | self.HCC_p = HardConstraintCollector( 84 | 2, M, [l_gamma_right], 85 | [lambda _: 1.], self.N_main 86 | ) 87 | 88 | def save(self, path_prefix: str): 89 | torch.save(self.N_main, path_prefix + "hc_main.pth") 90 | self.hc_gamma_airfoil.save(path_prefix, "hc_g_a.pth") 91 | 92 | def load(self, path_prefix: str): 93 | self.N_main = torch.load(path_prefix + "hc_main.pth") 94 | self.hc_gamma_airfoil.load(path_prefix, "hc_g_a.pth") 95 | 96 | def forward(self, x) -> torch.Tensor: 97 | self.N_main.clear_res() 98 | self.hc_gamma_airfoil.clear_res() 99 | return torch.cat( 100 | [ 101 | self.HCC_u(x), 102 | self.HCC_v(x), 103 | self.HCC_p(x) 104 | ] + [self.N_main(x, i) for i in range(3, 7)] 105 | , dim=1) 106 | -------------------------------------------------------------------------------- /src/configs/case2/params.py: -------------------------------------------------------------------------------- 1 | from src.utils.utils import load_airfoil_points 2 | 3 | # path 4 | data_path = "data/case2_airfoil.txt" 5 | airfoil_path = "data/w1015.dat" 6 | model_path_prefix = "model/case2_" 7 | 8 | # parameters for PDEs 9 | xmin = [-2, -2] 10 | xmax = [6, 2] 11 | Re = 50. 12 | nu = 1 / Re # viscosity 13 | 14 | # the anchor points of the airfoil 15 | anchor_points = load_airfoil_points(airfoil_path) 16 | -------------------------------------------------------------------------------- /src/configs/case2/pfnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.PFNN.pfnn import PFNN 3 | from src.configs.case2.hc import l_gamma_openrec, l_gamma_right 4 | 5 | class PFNN_NS(PFNN): 6 | """ 7 | Penalty-Free Neural Network for NS equation. 8 | """ 9 | 10 | def forward(self, x) -> torch.Tensor: 11 | raw_output = self.net(x) 12 | u_hat = raw_output[:, 0] * l_gamma_openrec.get_dist(x) + 1. 13 | v_hat = raw_output[:, 1] * l_gamma_openrec.get_dist(x) + 0. 14 | p_hat = raw_output[:, 2] * l_gamma_right.get_dist(x) + 1. 
15 | return torch.cat([u_hat, v_hat, p_hat], dim=1) -------------------------------------------------------------------------------- /src/configs/case2/pinn.py: -------------------------------------------------------------------------------- 1 | import deepxde as dde 2 | import numpy as np 3 | from src.configs.case2.params import * 4 | from src.utils.pinn_bc import NormalBC 5 | 6 | 7 | rec = dde.geometry.Rectangle(xmin=xmin, xmax=xmax) 8 | airfoil = dde.geometry.Polygon(anchor_points) 9 | spatial_domain = dde.geometry.CSGDifference(rec, airfoil) 10 | 11 | 12 | def pde_pinn(x, u): 13 | u_vel, v_vel, p = u[:, 0:1], u[:, 1:2], u[:, 2:3] 14 | u_vel_x = dde.grad.jacobian(u, x, i=0, j=0) 15 | u_vel_y = dde.grad.jacobian(u, x, i=0, j=1) 16 | u_vel_xx = dde.grad.hessian(u, x, component=0, i=0, j=0) 17 | u_vel_yy = dde.grad.hessian(u, x, component=0, i=1, j=1) 18 | v_vel_x = dde.grad.jacobian(u, x, i=1, j=0) 19 | v_vel_y = dde.grad.jacobian(u, x, i=1, j=1) 20 | v_vel_xx = dde.grad.hessian(u, x, component=1, i=0, j=0) 21 | v_vel_yy = dde.grad.hessian(u, x, component=1, i=1, j=1) 22 | p_x = dde.grad.jacobian(u, x, i=2, j=0) 23 | p_y = dde.grad.jacobian(u, x, i=2, j=1) 24 | 25 | momentum_x = ( 26 | u_vel * u_vel_x + v_vel * u_vel_y + p_x - nu * (u_vel_xx + u_vel_yy) 27 | ) 28 | momentum_y = ( 29 | u_vel * v_vel_x + v_vel * v_vel_y + p_y - nu * (v_vel_xx + v_vel_yy) 30 | ) 31 | continuity = u_vel_x + v_vel_y 32 | 33 | return [momentum_x, momentum_y, continuity] 34 | 35 | # u0, v0 36 | in_u_bc = dde.icbc.DirichletBC( 37 | spatial_domain, lambda _: 1., 38 | lambda x, on_bc: on_bc and (np.isclose(x[0], xmin[0]) or 39 | np.isclose(x[1], xmin[1]) or np.isclose(x[1], xmax[1])), 40 | component=0 41 | ) 42 | 43 | in_v_bc = dde.icbc.DirichletBC( 44 | spatial_domain, lambda _: 0., 45 | lambda x, on_bc: on_bc and (np.isclose(x[0], xmin[0]) or 46 | np.isclose(x[1], xmin[1]) or np.isclose(x[1], xmax[1])), 47 | component=1 48 | ) 49 | 50 | out_p_bc = dde.icbc.DirichletBC( 51 | spatial_domain, lambda _: 1., 52 | lambda x, on_bc: on_bc and np.isclose(x[0], xmax[0]), 53 | component=2 54 | ) 55 | 56 | airfoil_bc = NormalBC( 57 | spatial_domain, lambda _: 0., 58 | lambda x, on_bc: on_bc and (not rec.on_boundary(x)), 59 | component=0 60 | ) 61 | 62 | ic_bcs = [in_u_bc, in_v_bc, out_p_bc, airfoil_bc] 63 | 64 | # parameters for PINN 65 | # moving average in learning rate annealing 66 | lr_alpha = 0.1 67 | # num pdes 68 | num_pdes = 3 69 | num_bcs = len(ic_bcs) -------------------------------------------------------------------------------- /src/configs/case3/hc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import deepxde as dde 4 | from src.HC.hard_constraint import HardConstraintNeumannND 5 | from src.HC.normal_function import NormalFunctionSphere 6 | from src.utils.nn_wrapper import NNWrapper 7 | from src.configs.case3.params import * 8 | 9 | def pde_hc(x, u): 10 | ps = [u[:, i+1:i+2] for i in range(d)] 11 | T_xs = [dde.grad.jacobian(u, x, i=0, j=i) for i in range(d)] 12 | ps_xs = [dde.grad.jacobian(u, x, i=i+1, j=i) for i in range(d)] 13 | T_t = dde.grad.jacobian(u, x, i=0, j=d) 14 | squared_norm = torch.sum(x[:, :d] ** 2, dim=1, keepdim=True) 15 | f = -alpha * squared_norm * torch.exp(0.5 * squared_norm + x[:, d:d+1]) 16 | 17 | delta_T = 0. 
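    # assemble the Laplacian from the auxiliary outputs p_i ~ dT/dx_i, so only
    # first-order autograd is needed and the Neumann data can be imposed on
    # the p_i as a hard constraint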
18 | for i in range(d): 19 | delta_T += ps_xs[i] 20 | 21 | res = T_t - alpha * delta_T - f 22 | ps_res = [ps[i] - T_xs[i] for i in range(d)] 23 | return [res] + ps_res 24 | 25 | # Distance function 26 | dist_sphere = lambda x: 1. - torch.sum(x[:, :d] ** 2, dim=1, keepdim=True) 27 | # Normal functions 28 | n_gamma_sphere = NormalFunctionSphere(center=[0] * d, inner=False) 29 | 30 | def reference_solution_pt(x): 31 | squared_norm = torch.sum(x[:, :d] ** 2, dim=1, keepdim=True) 32 | return torch.exp(0.5 * squared_norm + x[:, d:d+1]) 33 | 34 | # hard constraints 35 | hc_gamma_sphere = HardConstraintNeumannND( 36 | n_gamma_sphere, 37 | NNWrapper(dde.nn.FNN([d+1] + 3 * [20] + [d], "tanh", "Glorot normal")), 38 | reference_solution_pt 39 | ) 40 | 41 | # model 42 | class HCNN(nn.Module): 43 | """ 44 | Hard constraint model. 45 | """ 46 | def __init__(self, path_prefix="") -> None: 47 | super(HCNN, self).__init__() 48 | # NNs 49 | self.hc_gamma = hc_gamma_sphere 50 | self.N_main = NNWrapper(dde.nn.FNN([d+1] + 4 * [50] + [d+1], "tanh", "Glorot normal")) 51 | if path_prefix != "": 52 | self.load(path_prefix) 53 | 54 | # hard constraint for each component 55 | def HCC_T(self, x): 56 | time_factor = torch.exp(-10 * x[:, d:d+1]) 57 | return self.N_main(x, 0) * (1 - time_factor) + \ 58 | reference_solution_pt(x) * time_factor 59 | 60 | def HCC_p_i(self, x, i): 61 | return self.N_main(x, i+1) * dist_sphere(x) + \ 62 | self.hc_gamma.get_p_i(x, i) 63 | 64 | def save(self, path_prefix: str): 65 | torch.save(self.N_main, path_prefix + "hc_main.pth") 66 | self.hc_gamma.save(path_prefix, "hc_g_o.pth") 67 | # unlike case 1, this model has no per-disk boundary networks, 68 | # so only N_main and hc_gamma are saved 69 | 70 | def load(self, path_prefix: str): 71 | self.N_main = torch.load(path_prefix + "hc_main.pth") 72 | self.hc_gamma.load(path_prefix, "hc_g_o.pth") 73 | # unlike case 1, this model has no per-disk boundary networks, 74 | # so only N_main and hc_gamma are loaded 75 | 76 | def forward(self, x) -> torch.Tensor: 77 | self.N_main.clear_res() 78 | self.hc_gamma.clear_res() 79 | return torch.cat( 80 | [self.HCC_T(x)] + [ 81 | self.HCC_p_i(x, i) for i in range(d) 82 | ], dim=1) 83 | -------------------------------------------------------------------------------- /src/configs/case3/params.py: -------------------------------------------------------------------------------- 1 | # path 2 | model_path_prefix = "model/case3_" 3 | 4 | # parameters for PDEs 5 | d = 10 6 | alpha = 1 / d 7 | -------------------------------------------------------------------------------- /src/configs/case3/pfnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import deepxde as dde 5 | from src.configs.case3.params import * 6 | from src.configs.case3.pinn import spatial_time_domain, reference_solution 7 | 8 | # Sample Points 9 | omega_points = spatial_time_domain.random_points(1000) 10 | boundary_points = spatial_time_domain.random_boundary_points(100) 11 | tot_points = np.concatenate((omega_points, boundary_points)) 12 | 13 | # Area and numbers 14 | volume_omega = math.pi ** (d/2) / math.gamma(d/2 + 1) 15 | num_omega = omega_points.shape[0] 16 | area_boundary = 2 * math.pi ** (d/2) / math.gamma(d/2) 17 | num_boundary = boundary_points.shape[0] 18 | 19 | def loss_pfnn(x, T): 20 | T_t = dde.grad.jacobian(T, x, i=0, j=d) 21 | nabla_dot_nabla_T = 0.
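    # accumulate the |grad T|^2 term of the Ritz energy 0.5 * a(T,T) - L(T);
    # the Neumann data enters only through the boundary integral in L below,
    # so no boundary penalty term is needed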
22 | for i in range(d): 23 | nabla_dot_nabla_T += dde.grad.jacobian(T, x, i=0, j=i) ** 2 24 | squared_norm = torch.sum(x[:, :d] ** 2, dim=1, keepdim=True) 25 | f = -alpha * squared_norm * torch.exp(0.5 * squared_norm + x[:, d:d+1]) 26 | 27 | # variational form 28 | # a(u,u) 29 | start = 0 30 | a = volume_omega * torch.mean(alpha * nabla_dot_nabla_T) 31 | start += num_omega 32 | # L(u) 33 | start = 0 34 | L = volume_omega * torch.mean((f[start:start+num_omega, :]-T_t[start:start+num_omega, :]) 35 | * T[start:start+num_omega, :]) 36 | start += num_omega 37 | L += area_boundary * torch.mean(alpha * T[start:start+num_boundary, :] 38 | * reference_solution_pt(x[start:start+num_boundary, :])) 39 | start += num_boundary 40 | 41 | return (0.5 * a - L).reshape(1, 1) 42 | 43 | 44 | def reference_solution_pt(x): 45 | squared_norm = torch.sum(x[:, :d] ** 2, dim=1, keepdim=True) 46 | return torch.exp(0.5 * squared_norm + x[:, d:d+1]) -------------------------------------------------------------------------------- /src/configs/case3/pinn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import deepxde as dde 4 | from src.configs.case3.params import * 5 | 6 | spatial_domain = dde.geometry.Hypersphere([0] * d, 1) 7 | time_domain = dde.geometry.TimeDomain(0, 1) 8 | spatial_time_domain = dde.geometry.GeometryXTime(spatial_domain, time_domain) 9 | 10 | 11 | def pde_pinn(x, T): 12 | T_t = dde.grad.jacobian(T, x, i=0, j=d) 13 | delta_T = 0. 14 | for i in range(d): 15 | delta_T += dde.grad.hessian(T, x, i=i, j=i) 16 | squared_norm = torch.sum(x[:, :d] ** 2, dim=1, keepdim=True) 17 | f = -alpha * squared_norm * torch.exp(0.5 * squared_norm + x[:, d:d+1]) 18 | return T_t - alpha * delta_T - f 19 | 20 | 21 | def reference_solution(x): 22 | squared_norm = np.sum(x[:, :d] ** 2, axis=1, keepdims=True) 23 | return np.exp(0.5 * squared_norm + x[:, d:d+1]) 24 | 25 | 26 | bc = dde.icbc.NeumannBC( 27 | spatial_time_domain, reference_solution, 28 | lambda _, on_bc: on_bc 29 | ) 30 | 31 | 32 | ic = dde.icbc.IC( 33 | spatial_time_domain, reference_solution, 34 | lambda _, on_initial: on_initial 35 | ) 36 | 37 | ic_bcs = [bc, ic] 38 | 39 | # parameters for PINN 40 | # moving average in learning rate annealing 41 | lr_alpha = 0.1 42 | # num pdes 43 | num_pdes = 1 44 | num_bcs = len(ic_bcs) -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csuastt/HardConstraint/c9363e901cb6220801d9f90ae0a19cc2b7e9677d/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/nn_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class NNWrapper(nn.Module): 5 | """ 6 | Wrapper a NN to avoid repeating calculations. 7 | """ 8 | def __init__(self, NN: nn.Module) -> None: 9 | super(NNWrapper, self).__init__() 10 | self.NN = NN 11 | self.res = None # calculation result 12 | 13 | def clear_res(self) -> None: 14 | ''' 15 | Clear the calculation result of the last batch. 16 | Note: call it at the beginning of a batch. 
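        Several HardConstraintCollector instances share one wrapped network,
        so the forward result is cached per batch; a stale cache would
        silently reuse the previous batch's outputs.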
17 | ''' 18 | self.res = None 19 | 20 | def forward(self, x, i=None) -> torch.Tensor: 21 | ''' 22 | x - coordinate 23 | i - the ith column of the result will be returned (keepdim) 24 | ''' 25 | if self.res is None: 26 | self.res = self.NN(x) 27 | if self.res.dim() == 1: 28 | assert i == 0 or i is None 29 | return self.res[:] 30 | if i is None: 31 | return self.res 32 | else: 33 | return self.res[:, i:i+1] 34 | -------------------------------------------------------------------------------- /src/utils/no_stdout_context.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrap a block of code to suppress its stdout output. 3 | Source: anonymous 4 | """ 5 | import contextlib 6 | import sys 7 | 8 | class DummyFile(object): 9 | def write(self, _): pass 10 | 11 | def flush(self): pass 12 | 13 | @contextlib.contextmanager 14 | def nostdout(): 15 | save_stdout = sys.stdout 16 | sys.stdout = DummyFile() 17 | try: yield 18 | finally: sys.stdout = save_stdout # restore stdout even if the body raises 19 | -------------------------------------------------------------------------------- /src/utils/pinn_bc.py: -------------------------------------------------------------------------------- 1 | import deepxde.utils as utils 2 | import deepxde.backend as bkd 3 | from deepxde.icbc.boundary_conditions import BC, npfunc_range_autocache 4 | 5 | class NormalBC(BC): 6 | """Normal boundary conditions: y_{component} * n_1(x) + y_{component+1} * n_2(x) = func(x).""" 7 | 8 | def __init__(self, geom, func, on_boundary, component=0): 9 | super().__init__(geom, on_boundary, component) 10 | self.func = npfunc_range_autocache(utils.return_tensor(func)) 11 | 12 | def error(self, X, inputs, outputs, beg, end): 13 | values = self.func(X, beg, end) 14 | n = self.boundary_normal(X, beg, end) 15 | y = outputs[beg:end, self.component : self.component + 2] 16 | return bkd.sum(y * n, 1, keepdims=True) - values -------------------------------------------------------------------------------- /src/utils/pinn_callback.py: -------------------------------------------------------------------------------- 1 | import deepxde.losses as losses_module 2 | from deepxde.callbacks import Callback 3 | import copy 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class PINNTester(Callback): 9 | 10 | def __init__(self, data_path, dim, test_fn): 11 | super().__init__() 12 | self.data_path = data_path 13 | self.dim = dim 14 | self.test_fn = test_fn 15 | self.m_abs_e_res = [] 16 | self.m_r_abs_e_res = [] 17 | self.m_loss_res = [] 18 | 19 | def on_epoch_end(self): 20 | net = self.model.net 21 | test_res = self.test_fn(self.data_path, self.dim, net) 22 | self.m_abs_e_res.append(test_res[0][0]) 23 | self.m_r_abs_e_res.append(test_res[1][0]) 24 | self.m_loss_res.append(np.sum(self.model.train_state.loss_train)) 25 | 26 | 27 | class PINNGradientTracker(Callback): 28 | 29 | def __init__(self, num_pdes): 30 | super().__init__() 31 | self.num_pdes = num_pdes 32 | self.m_gradients = [] 33 | self.conds = [] 34 | self.last_loss = None 35 | self.last_params = None 36 | self.steps = [] 37 | self.loss_fn = losses_module.get("MSE") 38 | 39 | def on_epoch_end(self): 40 | model = self.model.net 41 | params = [] 42 | for param in model.parameters(): 43 | params.append(param.reshape(-1)) 44 | params = torch.cat(params) 45 | # get the loss and parameters 46 | outputs = self.model.net(self.model.net.inputs.float()) 47 | losses = self.model.data.losses(None, outputs, self.loss_fn, self.model) 48 | # find mean|\nabla_{\theta}L_r| 49 | losses_r =
torch.sum(torch.stack(losses[:self.num_pdes])) 50 | m_grad_r = [] 51 | for param in self.model.net.parameters(): 52 | grads = torch.autograd.grad(losses_r, param, retain_graph=True, allow_unused=True) 53 | if grads[0] is not None: 54 | m_grad_r.append(torch.abs(grads[0]).reshape(-1)) 55 | else: 56 | m_grad_r.append(torch.zeros_like(param)) 57 | self.m_gradients.append(torch.mean(torch.cat(m_grad_r)).item()) 58 | self.steps.append(self.model.train_state.epoch) 59 | loss = np.sum(self.model.train_state.loss_train[:self.num_pdes]) 60 | loss = np.sum(loss) 61 | # calculate cond 62 | if self.last_params is None: 63 | # the first epoch 64 | self.conds.append(None) 65 | else: 66 | self.conds.append( 67 | np.abs(loss - self.last_loss) / torch.norm(params-self.last_params).item() 68 | ) 69 | self.last_params = params 70 | self.last_loss = loss 71 | 72 | 73 | class PINNLRAdaptor(Callback): 74 | """ 75 | PINN callback for learning rate annealing algorithm of physics-informed neural networks. 76 | """ 77 | 78 | def __init__(self, loss_weight, num_pdes, alpha, mode="max"): 79 | ''' 80 | loss_weight - initial loss weights\n 81 | num_pdes - the number of the PDEs (boundary conditions excluded)\n 82 | alpha - parameter of moving average\n 83 | mode - "max" (PINN-LA), "mean" (PINN-LA-2) 84 | ''' 85 | super().__init__() 86 | self.loss_weight = loss_weight 87 | self.num_pdes = num_pdes 88 | self.alpha = alpha 89 | self.loss_fn = losses_module.get("MSE") 90 | self.mode = mode 91 | 92 | def on_epoch_end(self): 93 | # get the loss and parameters 94 | outputs = self.model.net(self.model.net.inputs.float()) 95 | losses = self.model.data.losses(None, outputs, self.loss_fn, self.model) 96 | # find max|\nabla_{\theta}L_r| 97 | losses_r = torch.sum(torch.stack(losses[:self.num_pdes])) 98 | m_grad_r = [] 99 | for param in self.model.net.parameters(): 100 | grads = torch.autograd.grad(losses_r, param, retain_graph=True, allow_unused=True) 101 | if grads[0] is not None: 102 | m_grad_r.append(torch.abs(grads[0]).reshape(-1)) 103 | else: 104 | m_grad_r.append(torch.zeros_like(param)) 105 | if self.mode == "mean": 106 | m_grad_r = torch.mean(torch.cat(m_grad_r)).item() 107 | else: 108 | m_grad_r = torch.max(torch.cat(m_grad_r)).item() 109 | # adapt the weights for each bc term 110 | for i in range(self.num_pdes, len(self.loss_weight)): 111 | grads_bc = [] 112 | for param in self.model.net.parameters(): 113 | grads = torch.autograd.grad(losses[i], param, retain_graph=True, allow_unused=True) 114 | if grads[0] is not None: 115 | grads_bc.append(torch.abs(grads[0]).reshape(-1)) 116 | else: 117 | grads_bc.append(torch.zeros_like(param)) 118 | lambda_hat = m_grad_r / (torch.mean(torch.cat(grads_bc)).item() * self.loss_weight[i]) 119 | self.loss_weight[i] = (1 - self.alpha) * self.loss_weight[i] + self.alpha * lambda_hat 120 | 121 | 122 | class PINNLRScheduler(Callback): 123 | """ 124 | PINN callback for learning rate scheduler. 125 | """ 126 | def __init__(self, scheduler): 127 | super().__init__() 128 | self.scheduler = scheduler 129 | 130 | def on_epoch_end(self): 131 | # get the loss and parameters 132 | losses = np.array(self.model.train_state.loss_train) 133 | m_loss = np.mean(losses) 134 | self.scheduler.step(m_loss) 135 | 136 | 137 | class PINNModelSaver(Callback): 138 | """ 139 | PINN callback for saving the weights of physics-informed neural networks. 
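    A deep copy of the state_dict is stored after every epoch whose losses are
    all finite; if a later step produces NaN (as L-BFGS sometimes does),
    got_nan is set and the cached weights can be loaded back.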
140 | """ 141 | 142 | def __init__(self, init_weights=None): 143 | super().__init__() 144 | # if the model got nan outputs 145 | self.got_nan = False 146 | # the weights of the last epoch 147 | if init_weights is not None: 148 | self.weights = copy.deepcopy(init_weights) 149 | 150 | def on_epoch_end(self): 151 | # get the loss and parameters 152 | losses = np.array(self.model.train_state.loss_train) 153 | if np.isnan(losses).any(): 154 | # do not save when it has nan outputs 155 | self.got_nan = True 156 | else: 157 | self.weights = copy.deepcopy(self.model.net.state_dict()) -------------------------------------------------------------------------------- /src/utils/pinn_geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from deepxde.geometry.geometry import Geometry 3 | 4 | class RecDiskDomain(Geometry): 5 | """ 6 | Rectangular outer boundary & Circular inner boundary 7 | """ 8 | 9 | def __init__(self, rec, disks): 10 | super().__init__(rec.dim, rec.bbox, rec.diam) 11 | self.rec = rec 12 | self.disks = disks 13 | 14 | def inside(self, x): 15 | inside_all = self.rec.inside(x) 16 | for disk in self.disks: 17 | inside_all = np.logical_and(inside_all, ~disk.inside(x)) 18 | return inside_all 19 | 20 | def on_boundary(self, x): 21 | on_boundary_all = self.rec.on_boundary(x) 22 | for disk in self.disks: 23 | on_boundary_all = np.logical_or(on_boundary_all, disk.on_boundary(x)) 24 | return on_boundary_all 25 | 26 | def random_points(self, n, random="pseudo"): 27 | x = np.empty(shape=(n, self.dim)) 28 | i = 0 29 | while i < n: 30 | tmp = self.rec.random_points(n, random=random) 31 | for disk in self.disks: 32 | tmp = tmp[~disk.inside(tmp)] 33 | 34 | if len(tmp) > n - i: 35 | tmp = tmp[: n - i] 36 | x[i : i + len(tmp)] = tmp 37 | i += len(tmp) 38 | return x 39 | 40 | def random_boundary_points(self, n, random="pseudo"): 41 | x = np.empty(shape=(n, self.dim)) 42 | i = 0 43 | while i < n: 44 | 45 | tmp = self.rec.random_boundary_points(n, random=random) 46 | for disk in self.disks: 47 | disk_boundary_potins = disk.random_boundary_points(n, random=random) 48 | tmp = np.concatenate((tmp, disk_boundary_potins)) 49 | 50 | tmp = np.random.permutation(tmp) 51 | 52 | if len(tmp) > n - i: 53 | tmp = tmp[: n - i] 54 | x[i : i + len(tmp)] = tmp 55 | i += len(tmp) 56 | return x 57 | 58 | def boundary_normal(self, x): 59 | n = self.rec.on_boundary(x).reshape(-1, 1) * self.rec.boundary_normal(x) 60 | for disk in self.disks: 61 | n += disk.on_boundary(x).reshape(-1, 1) * -disk.boundary_normal(x) 62 | return n -------------------------------------------------------------------------------- /src/utils/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | # MLP with residual connection 6 | class ResNet(nn.Module): 7 | def __init__(self, num_res_layers, input_dim, hidden_dim, output_dim): 8 | ''' 9 | num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to linear model. 
10 | input_dim: dimensionality of input features 11 | hidden_dim: dimensionality of hidden units at ALL layers 12 | output_dim: number of classes for prediction 13 | device: which device to use 14 | ''' 15 | 16 | super(ResNet, self).__init__() 17 | self.num_layers = num_res_layers 18 | # Multi-layer model 19 | self.linears = torch.nn.ModuleList() 20 | self.fc1 = nn.Linear(input_dim, hidden_dim) 21 | for _ in range(num_res_layers): 22 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 23 | self.fc2 = nn.Linear(hidden_dim, output_dim) 24 | 25 | def forward(self, x_loc, x_bou=None): 26 | if x_bou is None: 27 | x = x_loc 28 | else: 29 | x = torch.cat((x_bou, x_loc), 1) 30 | x = torch.tanh(self.fc1(x)) 31 | last_t = x 32 | res_connect = False 33 | for layer in range(self.num_layers): 34 | if res_connect: 35 | x = torch.tanh(self.linears[layer](x)) + last_t 36 | last_t = x 37 | else: 38 | x = torch.tanh(self.linears[layer](x)) 39 | res_connect = not res_connect 40 | x = self.fc2(x) 41 | return x 42 | -------------------------------------------------------------------------------- /src/utils/torch_interp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source: anonymous 3 | """ 4 | import torch 5 | import contextlib 6 | 7 | class Interp1d(torch.autograd.Function): 8 | def __call__(self, x, y, xnew, out=None): 9 | return self.forward(x, y, xnew, out) 10 | 11 | def forward(ctx, x, y, xnew, out=None): 12 | """ 13 | Linear 1D interpolation on the GPU for Pytorch. 14 | This function returns interpolated values of a set of 1-D functions at 15 | the desired query points `xnew`. 16 | This function is working similarly to Matlab™ or scipy functions with 17 | the `linear` interpolation mode on, except that it parallelises over 18 | any number of desired interpolation problems. 19 | The code will run on GPU if all the tensors provided are on a cuda 20 | device. 21 | 22 | Parameters 23 | ---------- 24 | x : (N, ) or (D, N) Pytorch Tensor 25 | A 1-D or 2-D tensor of real values. 26 | y : (N,) or (D, N) Pytorch Tensor 27 | A 1-D or 2-D tensor of real values. The length of `y` along its 28 | last dimension must be the same as that of `x` 29 | xnew : (P,) or (D, P) Pytorch Tensor 30 | A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if 31 | _both_ `x` and `y` are 1-D. Otherwise, its length along the first 32 | dimension must be the same as that of whichever `x` and `y` is 2-D. 33 | out : Pytorch Tensor, same shape as `xnew` 34 | Tensor for the output. If None: allocated automatically. 35 | 36 | """ 37 | # making the vectors at least 2D 38 | is_flat = {} 39 | require_grad = {} 40 | v = {} 41 | device = [] 42 | eps = torch.finfo(y.dtype).eps 43 | for name, vec in {'x': x, 'y': y, 'xnew': xnew}.items(): 44 | assert len(vec.shape) <= 2, 'interp1d: all inputs must be '\ 45 | 'at most 2-D.' 46 | if len(vec.shape) == 1: 47 | v[name] = vec[None, :] 48 | else: 49 | v[name] = vec 50 | is_flat[name] = v[name].shape[0] == 1 51 | require_grad[name] = vec.requires_grad 52 | device = list(set(device + [str(vec.device)])) 53 | assert len(device) == 1, 'All parameters must be on the same device.' 
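        # all tensors are on a single device; unwrap it to allocate the output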
54 | device = device[0] 55 | 56 | # Checking for the dimensions 57 | assert (v['x'].shape[1] == v['y'].shape[1] 58 | and ( 59 | v['x'].shape[0] == v['y'].shape[0] 60 | or v['x'].shape[0] == 1 61 | or v['y'].shape[0] == 1 62 | ) 63 | ), ("x and y must have the same number of columns, and either " 64 | "the same number of row or one of them having only one " 65 | "row.") 66 | 67 | reshaped_xnew = False 68 | if ((v['x'].shape[0] == 1) and (v['y'].shape[0] == 1) 69 | and (v['xnew'].shape[0] > 1)): 70 | # if there is only one row for both x and y, there is no need to 71 | # loop over the rows of xnew because they will all have to face the 72 | # same interpolation problem. We should just stack them together to 73 | # call interp1d and put them back in place afterwards. 74 | original_xnew_shape = v['xnew'].shape 75 | v['xnew'] = v['xnew'].contiguous().view(1, -1) 76 | reshaped_xnew = True 77 | 78 | # identify the dimensions of output and check if the one provided is ok 79 | D = max(v['x'].shape[0], v['xnew'].shape[0]) 80 | shape_ynew = (D, v['xnew'].shape[-1]) 81 | if out is not None: 82 | if out.numel() != shape_ynew[0]*shape_ynew[1]: 83 | # The output provided is of incorrect shape. 84 | # Going for a new one 85 | out = None 86 | else: 87 | ynew = out.reshape(shape_ynew) 88 | if out is None: 89 | ynew = torch.zeros(*shape_ynew, device=device) 90 | 91 | # moving everything to the desired device in case it was not there 92 | # already (not handling the case things do not fit entirely, user will 93 | # do it if required.) 94 | for name in v: 95 | v[name] = v[name].to(device) 96 | 97 | # calling searchsorted on the x values. 98 | ind = ynew.long() 99 | 100 | # expanding xnew to match the number of rows of x in case only one xnew is 101 | # provided 102 | if v['xnew'].shape[0] == 1: 103 | v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1) 104 | 105 | torch.searchsorted(v['x'].contiguous(), 106 | v['xnew'].contiguous(), out=ind) 107 | 108 | # the `-1` is because searchsorted looks for the index where the values 109 | # must be inserted to preserve order. And we want the index of the 110 | # preceeding value. 111 | ind -= 1 112 | # we clamp the index, because the number of intervals is x.shape-1, 113 | # and the left neighbour should hence be at most number of intervals 114 | # -1, i.e. number of columns in x -2 115 | ind = torch.clamp(ind, 0, v['x'].shape[1] - 1 - 1) 116 | 117 | # helper function to select stuff according to the found indices. 118 | def sel(name): 119 | if is_flat[name]: 120 | return v[name].contiguous().view(-1)[ind] 121 | return torch.gather(v[name], 1, ind) 122 | 123 | # activating gradient storing for everything now 124 | enable_grad = False 125 | saved_inputs = [] 126 | for name in ['x', 'y', 'xnew']: 127 | if require_grad[name]: 128 | enable_grad = True 129 | saved_inputs += [v[name]] 130 | else: 131 | saved_inputs += [None, ] 132 | # assuming x are sorted in the dimension 1, computing the slopes for 133 | # the segments 134 | is_flat['slopes'] = is_flat['x'] 135 | # now we have found the indices of the neighbors, we start building the 136 | # output. 
Hence, we start also activating gradient tracking 137 | with torch.enable_grad() if enable_grad else contextlib.suppress(): 138 | v['slopes'] = ( 139 | (v['y'][:, 1:]-v['y'][:, :-1]) 140 | / 141 | (eps + (v['x'][:, 1:]-v['x'][:, :-1])) 142 | ) 143 | 144 | # now build the linear interpolation 145 | ynew = sel('y') + sel('slopes')*( 146 | v['xnew'] - sel('x')) 147 | 148 | if reshaped_xnew: 149 | ynew = ynew.view(original_xnew_shape) 150 | 151 | ctx.save_for_backward(ynew, *saved_inputs) 152 | return ynew 153 | 154 | @staticmethod 155 | def backward(ctx, grad_out): 156 | inputs = ctx.saved_tensors[1:] 157 | gradients = torch.autograd.grad( 158 | ctx.saved_tensors[0], 159 | [i for i in inputs if i is not None], 160 | grad_out, retain_graph=True) 161 | result = [None, ] * 5 162 | pos = 0 163 | for index in range(len(inputs)): 164 | if inputs[index] is not None: 165 | result[index] = gradients[pos] 166 | pos += 1 167 | return (*result,) -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from array import array 3 | from typing import Callable 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | from scipy import interpolate 8 | 9 | 10 | epsilon = 1e-6 11 | 12 | 13 | def load_params_disks(data_path: str) -> tuple[list[list[float]], list[float]]: 14 | ''' 15 | Load the parameters of the disks from the first line of the input file. 16 | ''' 17 | disk_centers = [] 18 | disk_rs = [] 19 | with open(data_path, "r") as f: 20 | line = f.readline() 21 | ls = line.split() 22 | for i in range(1, len(ls), 3): 23 | disk_centers.append([float(ls[i]), float(ls[i+1])]) 24 | disk_rs.append(float(ls[i+2])) 25 | return disk_centers, disk_rs 26 | 27 | 28 | def load_airfoil_points(airfoil_path: str) -> list[list[float]]: 29 | points = [] 30 | with open(airfoil_path, "r") as f: 31 | for line in f.readlines(): 32 | line = line.split() 33 | if len(line) == 2: 34 | points.append( 35 | [float(line[0]), 36 | float(line[1])] 37 | ) 38 | return np.array(points) 39 | 40 | 41 | def load(data_path: str, dim_X: int) -> tuple[array, array]: 42 | X = [] 43 | Y = [] 44 | with open(data_path, "r") as f: 45 | for line in f.readlines(): 46 | ls = line.split() 47 | if ls[0] == '%': 48 | continue 49 | ls = [float(s) for s in ls] 50 | X.append(ls[:dim_X]) 51 | Y.append(ls[dim_X:]) 52 | return np.array(X), np.array(Y) 53 | 54 | 55 | def load_time(data_path: str, dim_X: int) -> tuple[array, array]: 56 | ''' 57 | Load time-dependent data 58 | ''' 59 | X = [] 60 | Y = [] 61 | with open(data_path, "r") as f: 62 | for line in f.readlines(): 63 | ls = line.split() 64 | if ls[0] == '%': 65 | continue 66 | ls = [float(s) for s in ls] 67 | X.extend([ls[:dim_X] + [i/10] for i in range(11)]) 68 | Y.extend([[ls[dim_X + i]] for i in range(11)]) 69 | return np.array(X), np.array(Y) 70 | 71 | 72 | def test(data_path: str, dim_X: int, model: nn.Module) -> None: 73 | test_X, test_Y = load(data_path, dim_X) 74 | pred_Y = model(torch.tensor(test_X).float()).detach().cpu().numpy() 75 | for j in range(test_Y.shape[1]): 76 | abs_e = np.absolute(pred_Y[:, j] - test_Y[:, j]) 77 | m_abs_e = np.mean(abs_e) 78 | print("Mean abs error of u_%d: %.4f"%(j, m_abs_e)) 79 | wmape = np.sum(abs_e) / np.sum(np.absolute(test_Y[:, j])) 80 | print("Weighted mean abs percentage error of u_%d: %.4f"%(j, wmape)) 81 | # plot the heat map 82 | vmin = max(np.min( 83 | test_Y[:, j] 84 | ), 0) 85 | 
vmax = np.max( 86 | test_Y[:, j] 87 | ) 88 | plot_heatmap( 89 | test_X[:,0].reshape(-1), test_X[:,1].reshape(-1), 90 | test_Y[:, j], "outs/heatmap_exact_u_%d.png"%j, 91 | title="Heatmap of exact u_%d"%j, vmin=vmin, vmax=vmax 92 | ) 93 | plot_heatmap( 94 | test_X[:,0].reshape(-1), test_X[:,1].reshape(-1), 95 | pred_Y[:, j], "outs/heatmap_pred_u_%d.png"%j, 96 | title="Heatmap of pred u_%d"%j, vmin=vmin, vmax=vmax 97 | ) 98 | vmin = max(np.min( 99 | abs_e 100 | ), 0) 101 | plot_heatmap( 102 | test_X[:,0].reshape(-1), test_X[:,1].reshape(-1), 103 | abs_e, "outs/heatmap_abs_e_u_%d.png"%j, 104 | title="Heatmap of absolute error of u_%d"%j, vmin=vmin 105 | ) 106 | # save the results 107 | with open("outs/result.txt", "w") as f: 108 | for i in range(test_Y.shape[0]): 109 | for j in range(test_X.shape[1]): 110 | f.write("%f "%test_X[i, j]) 111 | for j in range(test_Y.shape[1]): 112 | f.write("%f "%pred_Y[i, j]) 113 | f.write("\n") 114 | 115 | 116 | def test_time(data_path: str, dim_X: int, model: nn.Module) -> None: 117 | test_X, test_Y = load_time(data_path, dim_X) 118 | pred_Y = model(torch.tensor(test_X).float()).detach().cpu().numpy() 119 | for j in range(test_Y.shape[1]): 120 | m_abs_es = [] 121 | m_r_abs_es = [] 122 | for t in range(11): 123 | t = t / 10 124 | index = np.where(np.isclose(test_X[:,2], t)) 125 | pred_Y_t = pred_Y[index] 126 | test_Y_t = test_Y[index] 127 | test_X_t = test_X[index][:, :2] 128 | abs_e = np.absolute(pred_Y_t[:, j] - test_Y_t[:, j]) 129 | m_abs_e = np.mean(abs_e) 130 | print("Mean abs error of u_%d (t=%.1f): %.4f"%(j, t, m_abs_e)) 131 | m_abs_es.append(m_abs_e) 132 | r_abs_e = abs_e / np.absolute(test_Y_t[:, j]) 133 | m_r_abs_e = np.mean(r_abs_e) 134 | print("Mean abs percentage error of u_%d (t=%.1f): %.4f"%(j, t, m_r_abs_e)) 135 | m_r_abs_es.append(m_r_abs_e) 136 | # plot the heat map 137 | vmin = max(np.min( 138 | test_Y_t[:, j] 139 | ), 0) 140 | vmax = np.max( 141 | test_Y_t[:, j] 142 | ) 143 | plot_heatmap( 144 | test_X_t[:,0].reshape(-1), test_X_t[:,1].reshape(-1), 145 | test_Y_t[:, j], "outs/heatmap_exact_u_%d (t=%.1f).png"%(j, t), 146 | title="Heatmap of exact u_%d (t=%.1f)"%(j, t), vmin=vmin, vmax=vmax 147 | ) 148 | plot_heatmap( 149 | test_X_t[:,0].reshape(-1), test_X_t[:,1].reshape(-1), 150 | pred_Y_t[:, j], "outs/heatmap_pred_u_%d (t=%.1f).png"%(j, t), 151 | title="Heatmap of pred u_%d (t=%.1f)"%(j, t), vmin=vmin, vmax=vmax 152 | ) 153 | vmin = max(np.min( 154 | r_abs_e 155 | ), 0) 156 | plot_heatmap( 157 | test_X_t[:,0].reshape(-1), test_X_t[:,1].reshape(-1), 158 | r_abs_e, "outs/heatmap_r_abs_e_u_%d (t=%.1f).png"%(j, t), 159 | title="Heatmap of absolute percentage error of u_%d (t=%.1f)"%(j, t), vmin=vmin 160 | ) 161 | print("Overall mean abs error of u_%d: %.4f"%(j, np.mean(m_abs_es))) 162 | print("Overall mean abs percentage error of u_%d: %.4f"%(j, np.mean(m_r_abs_es))) 163 | # save the results 164 | with open("outs/result.txt", "w") as f: 165 | for i in range(test_X.shape[0] // 11): 166 | f.write("%f %f "%(test_X[i * 11, 0], test_X[i * 11, 1])) 167 | for j in range(11): 168 | for k in range(test_Y.shape[1]): 169 | f.write("%f "%pred_Y[i * 11 + j, k]) 170 | f.write("\n") 171 | 172 | 173 | def test_time_with_reference_solution( 174 | reference_solution: Callable, test_X, 175 | model: nn.Module 176 | ): 177 | test_Y = reference_solution(test_X) 178 | pred_Y = model(torch.tensor(test_X).float()).detach().cpu().numpy() 179 | for j in range(test_Y.shape[1]): 180 | m_abs_es = [] 181 | m_r_abs_es = [] 182 | for t in range(11): 183 | t = t / 10 184 | 
index = np.where(np.isclose(test_X[:,-1], t)) 185 | pred_Y_t = pred_Y[index] 186 | test_Y_t = test_Y[index] 187 | abs_e = np.absolute(pred_Y_t[:, j] - test_Y_t[:, j]) 188 | m_abs_e = np.mean(abs_e) 189 | print("Mean abs error of u_%d (t=%.1f): %.4f"%(j, t, m_abs_e)) 190 | m_abs_es.append(m_abs_e) 191 | r_abs_e = abs_e / np.absolute(test_Y_t[:, j]) 192 | m_r_abs_e = np.mean(r_abs_e) 193 | print("Mean abs percentage error of u_%d (t=%.1f): %.4f"%(j, t, m_r_abs_e)) 194 | m_r_abs_es.append(m_r_abs_e) 195 | print("Overall mean abs error of u_%d: %.4f"%(j, np.mean(m_abs_es))) 196 | print("Overall mean abs percentage error of u_%d: %.4f"%(j, np.mean(m_r_abs_es))) 197 | 198 | 199 | class Tester: 200 | test_X_while_train = None 201 | test_Y_while_train = None 202 | def test_while_train(data_path: str, dim_X: int, model: nn.Module) -> tuple[list, list]: 203 | if Tester.test_X_while_train is None: 204 | Tester.test_X_while_train, Tester.test_Y_while_train = load(data_path, dim_X) 205 | Tester.test_X_while_train = torch.tensor(Tester.test_X_while_train).float() 206 | pred_Y = model(Tester.test_X_while_train).detach().cpu().numpy() 207 | m_abs_e_res, m_r_abs_e_res = [], [] 208 | for j in range(Tester.test_Y_while_train.shape[1]): 209 | abs_e = np.absolute(pred_Y[:, j] - Tester.test_Y_while_train[:, j]) 210 | m_abs_e = np.mean(abs_e) 211 | m_abs_e_res.append(m_abs_e) 212 | r_abs_e = abs_e / np.absolute(Tester.test_Y_while_train[:, j]) 213 | m_r_abs_e = np.mean(r_abs_e) 214 | m_r_abs_e_res.append(m_r_abs_e) 215 | return m_abs_e_res, m_r_abs_e_res 216 | 217 | 218 | def plot_distribution(data, xlabel, ylabel, path, title="", log_scal=True): 219 | ''' 220 | Plot the distribution of data on a log scale. 221 | ''' 222 | plt.cla() 223 | plt.figure() 224 | _, bins, _ = plt.hist(data, bins=30) 225 | plt.close() 226 | plt.cla() 227 | plt.figure() 228 | logbins = np.logspace(np.log10(bins[0] + epsilon),np.log10(bins[-1] + epsilon),len(bins)) 229 | plt.hist(data, bins=logbins) 230 | if log_scal: 231 | plt.xscale('log') 232 | plt.xlabel(xlabel) 233 | plt.ylabel(ylabel) 234 | plt.axvline(x=np.mean(data), c="r", ls="--", lw=2) 235 | plt.title(title) 236 | plt.savefig(path) 237 | plt.close() 238 | 239 | 240 | def plot_lines( 241 | data, xlabel, ylabel, labels, 242 | path, is_log=False, title="", 243 | sort_=False 244 | ): 245 | ''' 246 | Plot the lines data[1], data[2], ... against the shared x-data data[0]. 247 | ''' 248 | plt.cla() 249 | plt.figure() 250 | for i in range(1, len(data)): 251 | if sort_: 252 | x = np.array(data[0]) 253 | y = np.array(data[i]) 254 | sorted_indices = np.argsort(x) 255 | sorted_x = x[sorted_indices] 256 | sorted_y = y[sorted_indices] 257 | plt.plot(sorted_x, sorted_y, label=labels[i-1]) 258 | else: 259 | plt.plot(data[0], data[i], label=labels[i-1]) 260 | plt.legend() 261 | if is_log: 262 | plt.yscale('log') 263 | plt.xlabel(xlabel) 264 | plt.ylabel(ylabel) 265 | plt.title(title) 266 | plt.savefig(path) 267 | plt.close() 268 | 269 | 270 | def plot_heatmap( 271 | x, y, z, path, vmin=None, vmax=None, 272 | title="", xlabel="x", ylabel="y" 273 | ): 274 | ''' 275 | Plot a heat map for 3-dimensional scattered data (x, y, z). 276 | ''' 277 | plt.cla() 278 | plt.figure() 279 | xx = np.linspace(np.min(x), np.max(x)) 280 | yy = np.linspace(np.min(y), np.max(y)) 281 | xx, yy = np.meshgrid(xx, yy) 282 | yy = yy[::-1,:] 283 | 284 | vals = interpolate.griddata(np.array([x, y]).T, np.array(z), 285 | (xx, yy), method='cubic') 286 | vals_0 = interpolate.griddata(np.array([x, y]).T, np.array(z), 287 | (xx, yy), method='nearest') 288 | vals[np.isnan(vals)] = vals_0[np.isnan(vals)] 289 | 290 | 
if vmin is not None and vmax is not None: 291 | fig = plt.imshow(vals, 292 | extent=[np.min(x), np.max(x),np.min(y), np.max(y)], 293 | aspect="equal", interpolation="bicubic", 294 | vmin=vmin, vmax=vmax) 295 | elif vmin is not None: 296 | fig = plt.imshow(vals, 297 | extent=[np.min(x), np.max(x),np.min(y), np.max(y)], 298 | aspect="equal", interpolation="bicubic", 299 | vmin=vmin) 300 | else: 301 | fig = plt.imshow(vals, 302 | extent=[np.min(x), np.max(x),np.min(y), np.max(y)], 303 | aspect="equal", interpolation="bicubic") 304 | fig.axes.set_autoscale_on(False) 305 | plt.xlabel(xlabel) 306 | plt.ylabel(ylabel) 307 | plt.title(title) 308 | plt.colorbar() 309 | plt.savefig(path) 310 | plt.close() 311 | 312 | def cart2pol_np(x, y): 313 | ''' 314 | From Cartesian coordinates to polar coordinates (NumPy implementation). 315 | Return: (r, theta) 316 | ''' 317 | r = np.sqrt(x**2 + y**2) 318 | theta = np.arctan2(y, x) 319 | return (r, theta) 320 | 321 | def cart2pol_pt(x, y): 322 | ''' 323 | From Cartesian coordinates to polar coordinates (PyTorch implementation). 324 | Return: (r, theta) 325 | ''' 326 | r = torch.sqrt(x**2 + y**2) 327 | theta = torch.atan2(y, x) 328 | return (r, theta) 329 | 330 | def lineseg_dists(p, a, b): 331 | """Cartesian distance from points to line segments 332 | 333 | Edited to support arguments as series, from: anonymous 334 | 335 | Args: 336 | - p: np.array of a single point, shape (2,), or 2D array, shape (x, 2) 337 | - a: np.array of shape (x, 2) 338 | - b: np.array of shape (x, 2) 339 | """ 340 | # normalized tangent vectors 341 | d_ba = b - a 342 | d = np.divide(d_ba, (np.hypot(d_ba[:, 0], d_ba[:, 1]) 343 | .reshape(-1, 1))) 344 | 345 | # signed parallel distance components 346 | # rowwise dot products of 2D vectors 347 | s = np.multiply(a - p, d).sum(axis=1) 348 | t = np.multiply(p - b, d).sum(axis=1) 349 | 350 | # clamped parallel distance 351 | h = np.maximum.reduce([s, t, np.zeros(len(s))]) 352 | 353 | # perpendicular distance component 354 | # rowwise cross products of 2D vectors 355 | d_pa = p - a 356 | c = d_pa[:, 0] * d[:, 1] - d_pa[:, 1] * d[:, 0] 357 | 358 | return np.hypot(h, c) -------------------------------------------------------------------------------- /src/xPINN/interface_conditions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import deepxde as dde 3 | 4 | 5 | class Subdomains: 6 | """ 7 | Generate subdomains and their interface conditions for the xPINN. 
8 | ## Parameter 9 | L - length\n 10 | H - height\n 11 | n_col - the number of columns of the sub-domains\n 12 | n_row - the number of rows of the sub-domains\n 13 | xmin - the offset of the bottom-left point\n 14 | spatial_temporal_domain - the problem domain; pass temporal_domain to wrap each sub-block in a GeometryXTime\n 15 | num_output - the number of outputs 16 | """ 17 | def __init__( 18 | self, L, H, n_col, n_row, num_output, spatial_temporal_domain, xmin=None, temporal_domain=None 19 | ) -> None: 20 | self.L = L 21 | self.H = H 22 | self.n_col = n_col 23 | self.n_row = n_row 24 | self.num_output = num_output 25 | self.domain = spatial_temporal_domain 26 | if xmin is None: 27 | xmin = [- L / 2, - H / 2] 28 | self.lower_bs = [ 29 | [ 30 | [L * j / n_col + xmin[0], 31 | H * i / n_row + xmin[1]] 32 | for j in range(n_col) 33 | ] for i in range(n_row) 34 | ] 35 | self.upper_bs = [ 36 | [ 37 | [L * (j+1) / n_col + xmin[0], 38 | H * (i+1) / n_row + xmin[1]] 39 | for j in range(n_col) 40 | ] for i in range(n_row) 41 | ] 42 | self.sub_blocks = [ 43 | [ 44 | dde.geometry.Rectangle( 45 | xmin=self.lower_bs[i][j], xmax=self.upper_bs[i][j] 46 | ) if temporal_domain is None else 47 | dde.geometry.GeometryXTime( 48 | dde.geometry.Rectangle( 49 | xmin=self.lower_bs[i][j], xmax=self.upper_bs[i][j] 50 | ), temporal_domain 51 | ) 52 | for j in range(n_col) 53 | ] for i in range(n_row) 54 | ] 55 | 56 | def generate_interface_points(self, n): 57 | points = [] 58 | for i in range(self.n_row): 59 | for j in range(self.n_col): 60 | points.append( 61 | self.sub_blocks[i][j].random_boundary_points(n) 62 | ) 63 | points = np.concatenate(points, axis=0) 64 | np.random.shuffle(points) 65 | return points[:n, :] 66 | 67 | # check if the index is legal 68 | def check_index(self, i, j): 69 | return (0 <= i < self.n_row) and \ 70 | (0 <= j < self.n_col) 71 | 72 | def generate_interface_conditions(self): 73 | interface_conditions = [] 74 | cnt = self.num_output 75 | for i in range(self.n_row): 76 | for j in range(self.n_col): 77 | if self.check_index(i, j-1): 78 | interface_conditions.append( 79 | dde.icbc.DirichletBC( 80 | self.domain, lambda _: 0., 81 | lambda x, on_bc, i=i, j=j: np.isclose(x[0], self.lower_bs[i][j][0]) and \ 82 | (self.lower_bs[i][j][1] <= x[1] <= self.upper_bs[i][j][1]), 83 | component=cnt 84 | ) 85 | ) 86 | cnt += 1 87 | if self.check_index(i, j+1): 88 | interface_conditions.append( 89 | dde.icbc.DirichletBC( 90 | self.domain, lambda _: 0., 91 | lambda x, on_bc, i=i, j=j: np.isclose(x[0], self.upper_bs[i][j][0]) and \ 92 | (self.lower_bs[i][j][1] <= x[1] <= self.upper_bs[i][j][1]), 93 | component=cnt 94 | ) 95 | ) 96 | cnt += 1 97 | if self.check_index(i-1, j): 98 | interface_conditions.append( 99 | dde.icbc.DirichletBC( 100 | self.domain, lambda _: 0., 101 | lambda x, on_bc, i=i, j=j: np.isclose(x[1], self.lower_bs[i][j][1]) and \ 102 | (self.lower_bs[i][j][0] <= x[0] <= self.upper_bs[i][j][0]), 103 | component=cnt 104 | ) 105 | ) 106 | cnt += 1 107 | if self.check_index(i+1, j): 108 | interface_conditions.append( 109 | dde.icbc.DirichletBC( 110 | self.domain, lambda _: 0., 111 | lambda x, on_bc, i=i, j=j: np.isclose(x[1], self.upper_bs[i][j][1]) and \ 112 | (self.lower_bs[i][j][0] <= x[0] <= self.upper_bs[i][j][0]), 113 | component=cnt 114 | ) 115 | ) 116 | cnt += 1 117 | return interface_conditions 118 | -------------------------------------------------------------------------------- /src/xPINN/xPINN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import deepxde as dde 4 | 5 | 6 | class xPINN(nn.Module): 7 | """ 8 | Domain-decomposition 
PINN (xPINN): one sub-network per rectangular subdomain, with interface residuals appended to the output during training. 9 | ## Parameter 10 | L - length\n 11 | H - height\n 12 | n_col - the number of columns of the sub-domains\n 13 | n_row - the number of rows of the sub-domains\n 14 | num_layers, input_dim, hidden_dim, output_dim - the parameters of each sub-network\n 15 | pde_fn - function to calculate pde residuals\n 16 | xmin - the offset of the bottom-left point 17 | """ 18 | def __init__( 19 | self, L, H, n_col, n_row, 20 | num_layers, input_dim, hidden_dim, output_dim, pde_fn, 21 | xmin=None 22 | ) -> None: 23 | super(xPINN, self).__init__() 24 | self.L = L 25 | self.H = H 26 | self.n_col = n_col 27 | self.n_row = n_row 28 | self.pde_fn = pde_fn 29 | if xmin is None: 30 | self.xmin = [- L / 2, - H / 2] 31 | else: 32 | self.xmin = xmin 33 | nets = [dde.nn.FNN([input_dim] + num_layers * [hidden_dim] + [output_dim], "tanh", "Glorot normal") for _ in range(n_col * n_row)] 34 | self.nets = nn.ModuleList(nets) 35 | self.lower_bs = None 36 | self.eval_mode = False 37 | 38 | # check if the index is legal 39 | def check_index(self, i, j): 40 | return (0 <= i < self.n_row) and \ 41 | (0 <= j < self.n_col) 42 | 43 | # loss at an interface: output discontinuity plus PDE-residual discontinuity; 44 | # x (this) and y (other) are the outputs of the 2 neighbouring sub-networks 45 | def loss_at_inter(self, x, y, pde_x, pde_y): 46 | loss_discont = torch.sum((x - (x + y) / 2) ** 2, dim=1, keepdim=True) 47 | loss_pde_discont = torch.sum((pde_x - pde_y) ** 2, dim=1, keepdim=True) 48 | return torch.sqrt(loss_discont + loss_pde_discont) 49 | 50 | def append_loss(self, i, j, delta_i=0, delta_j=0): 51 | self.residuals_inter.append(self.loss_at_inter( 52 | self.subdomain_res[i][j], 53 | self.subdomain_res[i+delta_i][j+delta_j], 54 | self.subdomain_pde_res[i][j], 55 | self.subdomain_pde_res[i+delta_i][j+delta_j] 56 | )) 57 | 58 | def set_eval(self): 59 | ''' 60 | Set the mode to evaluation, which does not calculate interface losses. 61 | ''' 62 | self.eval_mode = True 63 | 64 | def set_training(self): 65 | ''' 66 | Set the mode to training. 67 | ''' 68 | self.eval_mode = False 69 | 70 | def forward(self, x) -> torch.Tensor: 71 | # lazily build the subdomain bounds and buffers on the first call 72 | if self.lower_bs is None: 73 | self.lower_bs = [ 74 | [ 75 | torch.tensor([self.L * j / self.n_col + self.xmin[0], 76 | self.H * i / self.n_row + self.xmin[1]]) 77 | for j in range(self.n_col) 78 | ] for i in range(self.n_row) 79 | ] 80 | self.upper_bs = [ 81 | [ 82 | torch.tensor([self.L * (j+1) / self.n_col + self.xmin[0], 83 | self.H * (i+1) / self.n_row + self.xmin[1]]) 84 | for j in range(self.n_col) 85 | ] for i in range(self.n_row) 86 | ] 87 | self.subdomain_res = [[None for _ in range(self.n_col)] for _ in range(self.n_row)] 88 | self.subdomain_pde_res = [[None for _ in range(self.n_col)] for _ in range(self.n_row)] 89 | self.residuals_inter = [] 90 | # accumulate sub-network outputs, each masked by its subdomain indicator 91 | tot_res = 0. 
92 | spatial_x = x[:, :2] 93 | for i in range(self.n_row): 94 | for j in range(self.n_col): 95 | indicator_res = torch.prod( 96 | torch.logical_and( 97 | self.lower_bs[i][j] <= spatial_x, 98 | spatial_x <= self.upper_bs[i][j] 99 | ), dim=1, keepdim=True 100 | ) 101 | nn_res = self.nets[i * self.n_col + j](x) 102 | tot_res += indicator_res * nn_res 103 | self.subdomain_res[i][j] = nn_res 104 | if not self.eval_mode: 105 | pde_res = self.pde_fn(x, nn_res) 106 | if isinstance(pde_res, list): 107 | pde_res = torch.cat(pde_res, dim=1) 108 | self.subdomain_pde_res[i][j] = pde_res 109 | if self.eval_mode: 110 | return tot_res 111 | # residuals at interface 112 | self.residuals_inter.clear() 113 | for i in range(self.n_row): 114 | for j in range(self.n_col): 115 | if self.check_index(i, j-1): 116 | self.append_loss(i, j, delta_j=-1) 117 | if self.check_index(i, j+1): 118 | self.append_loss(i, j, delta_j=1) 119 | if self.check_index(i-1, j): 120 | self.append_loss(i, j, delta_i=-1) 121 | if self.check_index(i+1, j): 122 | self.append_loss(i, j, delta_i=1) 123 | return torch.cat((tot_res, torch.cat(self.residuals_inter, dim=1)), dim=1) 124 | --------------------------------------------------------------------------------
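The forward/backward pair in src/utils/torch_interp.py implements batched piecewise-linear interpolation with a custom gradient. A minimal usage sketch, assuming the enclosing torch.autograd.Function is exposed as Interp1d (the class header sits above this excerpt, so that name is an assumption):

import torch
from src.utils.torch_interp import Interp1d  # assumed class name

x = torch.linspace(0., 1., 11).unsqueeze(0)     # (1, 11) sorted abscissae
y = torch.sin(6.28 * x).requires_grad_(True)    # (1, 11) ordinates
xnew = torch.rand(1, 5)                         # (1, 5) query points
ynew = Interp1d.apply(x, y, xnew)               # piecewise-linear estimate, shape (1, 5)
ynew.sum().backward()                           # custom backward routes gradients to y
print(ynew.shape, y.grad.shape)

Because save_for_backward stores only the inputs that require gradients, backward re-differentiates the saved ynew with torch.autograd.grad, which is why forward builds ynew inside an enable_grad context.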
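The test() routine in src/utils/utils.py reports WMAPE = sum|pred - exact| / sum|exact|, which sums the errors before normalizing so that near-zero targets do not dominate the metric, unlike the per-point percentage error used in test_time(). A small numeric check of the two formulas:

import numpy as np

true = np.array([0.01, 1.0, 2.0])
pred = np.array([0.02, 1.1, 1.9])
wmape = np.sum(np.abs(pred - true)) / np.sum(np.abs(true))  # -> 0.0698
mape = np.mean(np.abs(pred - true) / np.abs(true))          # -> 0.3833, dominated by the 0.01 target
print(wmape, mape)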
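lineseg_dists clamps the projection of each point onto its segment and combines the parallel overshoot h with the perpendicular offset c via np.hypot. A quick sanity check on the unit segment from (0, 0) to (1, 0):

import numpy as np
from src.utils.utils import lineseg_dists

p = np.array([[0.5, 1.0],        # above the middle of the segment
              [2.0, 0.0]])       # beyond its right endpoint
a = np.zeros((2, 2))             # segment start points
b = np.tile([1.0, 0.0], (2, 1))  # segment end points
print(lineseg_dists(p, a, b))    # -> [1. 1.]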
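Subdomains splits the bounding box into an n_row x n_col grid of deepxde rectangles and emits one zero-valued DirichletBC per directed interface, each pinned to its own extra output component. A minimal sketch for a 2 x 2 split of [-1, 1]^2 (the argument values here are illustrative only):

import deepxde as dde
from src.xPINN.interface_conditions import Subdomains

domain = dde.geometry.Rectangle(xmin=[-1, -1], xmax=[1, 1])
subs = Subdomains(L=2, H=2, n_col=2, n_row=2, num_output=1,
                  spatial_temporal_domain=domain)
anchors = subs.generate_interface_points(1000)     # shuffled boundary points of the sub-blocks
conditions = subs.generate_interface_conditions()  # 8 DirichletBCs on components 1..8
print(anchors.shape, len(conditions))              # (1000, 2) 8

Each directed neighbour pair gets its own component index starting at num_output, which must line up with the extra residual columns that xPINN.forward concatenates to its output.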
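xPINN.forward stitches the sub-network predictions together with box indicators and, in training mode, appends one interface-residual column per directed neighbour pair. A sketch with a placeholder residual (a real setup would compute PDE derivatives with autograd; the residual below is illustrative only, and the shapes assume the environment pinned in requirements.txt):

import torch
from src.xPINN.xPINN import xPINN

def pde_fn(x, u):
    # placeholder residual that needs no autograd, for shape checking only
    return u - torch.sin(x[:, :1])

model = xPINN(L=2, H=2, n_col=2, n_row=2,
              num_layers=3, input_dim=2, hidden_dim=32, output_dim=1,
              pde_fn=pde_fn)
x = torch.rand(64, 2) * 2 - 1  # points in [-1, 1]^2
out = model(x)                 # (64, 1 + 8): stitched prediction + interface residuals
model.set_eval()
pred = model(x)                # (64, 1): stitched prediction only
print(out.shape, pred.shape)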