├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── deep_pdes ├── __init__.py ├── attentive_neural_process │ ├── .gitignore │ ├── LICENSE │ ├── anp.py │ ├── module.py │ ├── probconserv.py │ └── softc.py ├── datasets │ ├── base.py │ ├── pbc.py │ └── pme.py └── experiments │ ├── 1b_pme_train_pinp_multi.sh │ ├── 1b_pme_var_m.sh │ ├── 2b_stefan_var_p.sh │ ├── 3b_heat_var_c.sh │ ├── 4b_advection_var_a.sh │ ├── 5b_burgers_var_a.sh │ ├── analyze.py │ ├── conf │ ├── config.yaml │ ├── experiments │ │ ├── 1b_pme_var_m.yaml │ │ ├── 2b_stefan_var_p.yaml │ │ ├── 3b_heat_var_c.yaml │ │ ├── 4b_advection_var_a.yaml │ │ └── 5b_burgers_var_a.yaml │ └── train │ │ ├── 1b_pme_var_m_anp.yaml │ │ ├── 1b_pme_var_m_physnp.yaml │ │ ├── 1b_pme_var_m_physnp_fixedvar.yaml │ │ ├── 1b_pme_var_m_physnp_re.yaml │ │ ├── 1b_pme_var_m_physnp_rhs.yaml │ │ ├── 1b_pme_var_m_pinp.yaml │ │ ├── 1b_pme_var_m_pinp_1e1.yaml │ │ ├── 1b_pme_var_m_pinp_1e2.yaml │ │ ├── 1b_pme_var_m_pinp_1e6.yaml │ │ ├── 1b_pme_var_m_pinp_1en1.yaml │ │ ├── 1b_pme_var_m_pinp_1en2.yaml │ │ ├── 2b_stefan_var_p_anp.yaml │ │ ├── 2b_stefan_var_p_physnp.yaml │ │ ├── 2b_stefan_var_p_physnp_fixedvar.yaml │ │ ├── 2b_stefan_var_p_physnp_re.yaml │ │ ├── 2b_stefan_var_p_pinp.yaml │ │ ├── 3b_heat_var_c_anp.yaml │ │ ├── 3b_heat_var_c_pinp.yaml │ │ ├── 4b_advection_var_a_anp.yaml │ │ ├── 4b_advection_var_a_pinp.yaml │ │ ├── 5b_burgers_var_a_anp.yaml │ │ └── 5b_burgers_var_a_pinp.yaml │ ├── generate.py │ ├── output │ ├── .DS_Store │ └── paper │ │ ├── .DS_Store │ │ ├── 1b_pme_var_m │ │ ├── datasets │ │ │ ├── test.pt │ │ │ ├── train.pt │ │ │ └── valid.pt │ │ └── train │ │ │ ├── anp.pt │ │ │ ├── physnp.pt │ │ │ └── pinp.pt │ │ └── 2b_stefan_var_p │ │ ├── .DS_Store │ │ ├── datasets │ │ ├── test.pt │ │ ├── train.pt │ │ └── valid.pt │ │ └── train │ │ ├── anp.pt │ │ ├── physnp.pt │ │ └── physnp_second_deriv.pt │ ├── plots.py │ └── train.py ├── mypy.ini ├── poetry.lock ├── py.typed ├── pyproject.toml ├── resources ├── diffusion_eqtn_conserv_mass.png ├── schematic.png ├── stefan_shock_position_downstream_task └── stefan_solution_profile_UQ └── tests └── test_pme.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. 
Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Derek Hansen 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProbConserv: Probabilistic Framework to Enforce Conservation Laws 2 | [![wemake-python-styleguide](https://img.shields.io/badge/style-wemake-000000.svg)](https://github.com/wemake-services/wemake-python-styleguide) 3 | 4 | ![Image](resources/schematic.png) 5 | 6 | [Derek Hansen](http://www-personal.umich.edu/~dereklh/), [Danielle C. Maddix](https://dcmaddix.github.io/), [Shima Alizadeh](https://scholar.google.com/citations?user=r3qS03kAAAAJ&hl=en), [Gaurav Gupta](http://guptagaurav.me/index.html), [Michael W. Mahoney](https://www.stat.berkeley.edu/~mmahoney/) \ 7 | **Learning Physical Models that Can Respect Conservation Laws** \ 8 | [Proceedings of the 40th International Conference on Machine Learning (ICML)](https://proceedings.mlr.press/v202/hansen23b/hansen23b.pdf), PMLR 202:12469-12510, 2023. 9 | 10 | ## Installation 11 | This project uses [poetry](https://python-poetry.org/) to manage dependencies. 12 | 13 | From the root directory: 14 | ``` 15 | poetry install 16 | ``` 17 | 18 | Some of the plots require that certain LaTeX packages be present. On Ubuntu, they can be installed with: 19 | ``` 20 | sudo apt install cm-super dvipng texlive-latex-extra texlive-fonts-recommended 21 | ``` 22 | 23 | You can then use `poetry run` followed by a command, or `poetry shell` to open a shell with the correct virtual environment. 24 | 25 | To run the tests: 26 | ``` 27 | poetry run pytest 28 | ``` 29 | The code for this project is located in the `deep_pdes` folder. It consists of two libraries, `attentive_neural_process` and `datasets`, which contain the models and the datasets, respectively. 30 | These libraries are imported by the scripts in `experiments` that configure and run the specific case studies explored in the ProbConserv paper. 31 | ## Running experiments 32 | The experiment code in `deep_pdes/experiments` uses [Hydra](https://hydra.cc/) to manage configuration and run experiments. The different stages of the experiments are split into distinct commands for easier reproducibility: 33 | - `generate.py`: Generate synthetic datasets for training. 34 | - `train.py`: Train ProbConserv-ANP, ANP, and other baseline methods such as Physics-Informed Neural Networks (PINNs). 35 | - `analyze.py`: Evaluate the trained models on test datasets and create tables/plots from the results. 36 | - `plots.py`: Generate all plots used in the paper. This script does not use the Hydra CLI, but it uses the Hydra compose API internally. 37 | 38 | Each script is run by passing an `+experiments=*` flag. The available experiments can be found in `deep_pdes/experiments/conf/experiments`.
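The training script additionally takes a `+train=*` flag, which selects a model configuration from `deep_pdes/experiments/conf/train`. Judging by the config names in that directory, the suffix picks the model variant; for instance, the following invocations (illustrative, inferred from the file names rather than taken from a provided script) would train the ANP baseline, the soft-constrained (PINN-style) ANP, and ProbConserv (PhysNP) on the PME setting:
```
python train.py +experiments=1b_pme_var_m +train=1b_pme_var_m_anp
python train.py +experiments=1b_pme_var_m +train=1b_pme_var_m_pinp
python train.py +experiments=1b_pme_var_m +train=1b_pme_var_m_physnp
```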
For example, to recreate the results on the Stefan GPME setting: 39 | ``` 40 | EXPERIMENT=2b_stefan_var_p 41 | 42 | python generate.py +experiments=$EXPERIMENT 43 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_anp 44 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_pinp 45 | python analyze.py +experiments=$EXPERIMENT 46 | ``` 47 | These commands are also available in convenience scripts; for example, the above is in `deep_pdes/experiments/2b_stefan_var_p.sh`. 48 | 49 | ![Image](resources/stefan_solution_profile_UQ) \ 50 | **Solution Profiles and UQ for the Stefan Equation** \ 51 | ![Image](resources/stefan_shock_position_downstream_task) \ 52 | **Downstream Task: Shock location detection** 53 | 54 | For the diffusion equation with constant diffusivity, see `deep_pdes/experiments/3b_heat_var_c.sh`. 55 | ![Image](resources/diffusion_eqtn_conserv_mass.png) \ 56 | **Conservation of mass** can be violated by black-box deep learning models, even when the PDE is applied as a soft constraint in the loss function, à la Physics-Informed Neural Networks (PINNs). The true mass for this diffusion equation is zero for all time, since there is zero net flux through the domain boundaries and mass cannot be created or destroyed in the interior. 57 | 58 | ## Sources 59 | This repo contains modified versions of the code found in the following repos: 60 | - `https://github.com/a1k12/characterizing-pinns-failure-modes`: For the analytical solution of the diffusion/heat equation (MIT license) 61 | - `https://github.com/soobinseo/Attentive-Neural-Process`: For the implementation of the Attentive Neural Process (Apache 2.0 license) 62 | 63 | ## Citation 64 | If you use this code or our work, please cite: 65 | ``` 66 | @inproceedings{hansen2023learning, 67 | title={Learning Physical Models that Can Respect Conservation Laws}, 68 | author={Hansen, Derek and Maddix, Danielle C. and Alizadeh, Shima and Gupta, Gaurav and Mahoney, Michael W}, 69 | booktitle={International Conference on Machine Learning}, 70 | year={2023}, 71 | volume = {202}, 72 | pages={12469-12510}, 73 | organization={PMLR} 74 | } 75 | ``` -------------------------------------------------------------------------------- /deep_pdes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/__init__.py -------------------------------------------------------------------------------- /deep_pdes/attentive_neural_process/.gitignore: -------------------------------------------------------------------------------- 1 | checkpoint/* 2 | data/* 3 | runs/* 4 | prototype.ipynb 5 | test.ipynb 6 | log.txt 7 | *.pyc 8 | .ipynb_checkpoints/ 9 | -------------------------------------------------------------------------------- /deep_pdes/attentive_neural_process/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /deep_pdes/attentive_neural_process/anp.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import Tensor 4 | from torch.distributions import Normal 5 | from torch.optim import Adam 6 | from tqdm import tqdm 7 | 8 | from deep_pdes.attentive_neural_process.module import Decoder, DeterministicEncoder, LatentEncoder 9 | 10 | 11 | class ANP(pl.LightningModule): # noqa: WPS214 12 | def __init__(self, num_hidden, dim_x=1, dim_y=1, lr=1e-4, free_bits=None, checkpoint=None): 13 | super().__init__() 14 | self.latent_encoder = LatentEncoder(num_hidden, num_hidden, dim_x=dim_x, dim_y=dim_y) 15 | self.deterministic_encoder = DeterministicEncoder( 16 | num_hidden, num_hidden, dim_x=dim_x, dim_y=dim_y 17 | ) 18 | self.decoder = Decoder(num_hidden, dim_x=dim_x, dim_y=dim_y) 19 | self.lr = lr 20 | self.free_bits = free_bits 21 | 22 | if checkpoint is not None: 23 | ckpt = torch.load(checkpoint, map_location=self.device) 24 | self.load_state_dict(ckpt["state_dict"]) 25 | 26 | def forward(self, context_x, context_y, target_x, target_y=None): # noqa: WPS210 27 | num_targets = target_x.size(1) 28 | 29 | context_mu, context_var, context_z = self.latent_encoder(context_x, context_y) 30 | 31 | training = target_y is not None 32 | if training: 33 | target_mu, target_var, target_z = self.latent_encoder(target_x, target_y) 34 | z = target_z 35 | 36 | # For Generation 37 | else: 38 | z = context_z 39 | 40 | z = z.unsqueeze(1).repeat(1, num_targets, 1) 41 | # sizes are [B, T_target, H] 42 | r = self.deterministic_encoder(context_x, context_y, target_x) 43 | # mu should be the prediction of target y 44 | target_y_dist: Normal = self.decoder(r, z, target_x) 45 | 46 | if training: 47 | # get log probability 48 | recon_prob = self.log_prob(target_y_dist, target_y) 49 | recon_loss = -recon_prob.sum() 50 | # get KL divergence between prior and posterior 51 | kl = self.kl_div(context_mu, context_var, target_mu, target_var) 52 | if self.free_bits is not None: 53 | kl = torch.clamp_min(kl, self.free_bits) 54 | 55 | # maximize prob and minimize KL divergence 56 | loss = recon_loss + kl 57 | 58 | # For Generation 59 | else: 60 | kl = None 61 | loss = None 62 | 63 | return target_y_dist, kl, loss, z[:, 0, :] 64 | 65 | def log_prob(self, target_y_dist: Normal, target_y: Tensor) -> Tensor: 66 | return target_y_dist.log_prob(target_y) 67 | 68 | def sample(self, context_x, context_y, target_x, n_samples=1): 69 | loc_list = [] 70 | scale_list = [] 71 | for _ in range(n_samples): 72 | dist, _, _, _ = self.forward(context_x, context_y, target_x) 73 | loc_list.append(dist.loc) 74 | scale_list.append(dist.scale) 75 | loc = torch.stack(loc_list) 76 | scale = torch.stack(scale_list) 77 | return Normal(loc, scale) 78 | 79 | def get_loc_and_scale_batched( 80 | self, input_contexts, output_contexts, input_targets, n_samples=100, batch_size=10000 81 | ): 82 | m_list = [] 83 | s_list = [] 84 | n_targets = 
input_targets.shape[1] 85 | n_batches = n_targets // batch_size + 1 86 | for i in tqdm(range(n_batches)): 87 | b_strt = i * batch_size 88 | b_end = (i + 1) * batch_size 89 | dist = self.get_loc_and_scale( 90 | input_contexts, output_contexts, input_targets[:, b_strt:b_end], n_samples 91 | ) 92 | m_list.append(dist.loc.cpu()) 93 | s_list.append(dist.scale.cpu()) 94 | loc = torch.cat(m_list, 1) 95 | scale = torch.cat(s_list, 1) 96 | return Normal(loc, scale) 97 | 98 | def get_loc_and_scale(self, input_contexts, output_contexts, input_targets, n_samples=100): 99 | dist = self.sample(input_contexts, output_contexts, input_targets, n_samples=n_samples) 100 | loc = dist.loc.mean(0) 101 | if dist.loc.shape[0] == 1: 102 | var_of_locs = 0 103 | else: 104 | var_of_locs = dist.loc.var(0) 105 | var_total = dist.scale.pow(2).mean(0) + var_of_locs 106 | scale = var_total.sqrt() 107 | return Normal(loc, scale) 108 | 109 | def training_step(self, batch, batch_idx, train=True): 110 | context_x, context_y = batch["input_contexts"], batch["output_contexts"] 111 | target_x, target_y = batch["input_targets"], batch["output_targets"] 112 | _, kl, loss, _ = self.forward(context_x, context_y, target_x, target_y) 113 | if train: 114 | self.log("train_loss", loss) 115 | return loss 116 | 117 | def validation_step(self, batch, batch_idx): 118 | loss = self.training_step(batch, batch_idx, train=False) 119 | self.log("val_loss", loss) 120 | return loss 121 | 122 | def configure_optimizers(self): 123 | return Adam(self.parameters(), lr=self.lr) 124 | 125 | def kl_div(self, prior_mu, prior_var, posterior_mu, posterior_var): 126 | kl_div = torch.exp(posterior_var) + (posterior_mu - prior_mu) ** 2 127 | kl_div /= torch.exp(prior_var) 128 | kl_div -= 1.0 129 | kl_div += prior_var - posterior_var 130 | return 0.5 * kl_div.sum() 131 | 132 | def grad_of_mean_wrt_target(self, context_x, context_y, target_x): 133 | target_y_dist, _, _, _ = self.forward(context_x, context_y, target_x) 134 | return self._grad_of_mean_wrt_target(target_y_dist, target_x).loc 135 | 136 | def _grad_of_mean_wrt_target(self, target_y_dist: Normal, target_x): 137 | mean_target_y = target_y_dist.loc 138 | scale_target_y = target_y_dist.scale 139 | grad_mean = torch.autograd.grad( 140 | mean_target_y, 141 | target_x, 142 | torch.ones_like(mean_target_y), 143 | create_graph=True, 144 | retain_graph=True, 145 | )[0] 146 | grad_scale = torch.autograd.grad( 147 | scale_target_y, 148 | target_x, 149 | torch.ones_like(mean_target_y), 150 | create_graph=True, 151 | retain_graph=True, 152 | )[0].abs() 153 | return Normal(grad_mean, grad_scale) 154 | -------------------------------------------------------------------------------- /deep_pdes/attentive_neural_process/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from einops import rearrange 5 | from torch import nn # noqa: WPS458 6 | from torch.distributions import Normal 7 | from torch.nn import functional # noqa: WPS458 8 | 9 | 10 | class Linear(nn.Module): 11 | def __init__(self, in_dim, out_dim, bias=True, w_init="linear"): 12 | super().__init__() 13 | self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) 14 | 15 | nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init)) 16 | 17 | def forward(self, x): # noqa: WPS111 18 | return self.linear_layer(x) 19 | 20 | 21 | class MLP(nn.Module): 22 | def __init__(self, in_dim, hidden_dims, out_dim, residual=False): 23 | super().__init__() 24 | layers = [] 25 |
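# Pair up consecutive layer widths: [in_dim] + hidden_dims feeds into hidden_dims + [out_dim].
# Each Linear below is followed by a ReLU; the trailing ReLU is dropped via layers[:-1] so the
# MLP ends in a purely linear output layer. When residual=True, forward() adds the input back
# onto the output (which requires in_dim == out_dim).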
in_sizes = [in_dim] + hidden_dims 26 | out_sizes = hidden_dims + [out_dim] 27 | for in_size, out_size in zip(in_sizes, out_sizes): 28 | layers.append(nn.Linear(in_size, out_size)) 29 | layers.append(nn.ReLU()) 30 | self.net = nn.Sequential(*layers[:-1]) 31 | self.residual = residual 32 | 33 | def forward(self, x): 34 | y = self.net(x) 35 | if self.residual: 36 | y += x 37 | return y 38 | 39 | 40 | class LatentEncoder(nn.Module): 41 | def __init__(self, num_hidden, num_latent, dim_x, dim_y): 42 | super().__init__() 43 | input_dim = dim_x + dim_y 44 | self.input_projection = Linear(input_dim, num_hidden) 45 | self.self_attentions = nn.ModuleList([Attention(num_hidden) for _ in range(2)]) 46 | self.penultimate_layer = Linear(num_hidden, num_hidden, w_init="relu") 47 | self.mu = Linear(num_hidden, num_latent) 48 | self.log_sigma = Linear(num_hidden, num_latent) 49 | 50 | def forward(self, x, y): # noqa: WPS111 51 | # concat location (x) and value (y) 52 | encoder_input = torch.cat([x, y], dim=-1) 53 | 54 | # project vector with dimension 3 --> num_hidden 55 | encoder_input = self.input_projection(encoder_input) 56 | 57 | # self attention layer 58 | for attention in self.self_attentions: 59 | encoder_input, _ = attention(encoder_input, encoder_input, encoder_input) 60 | 61 | # mean 62 | hidden = encoder_input.mean(dim=1) 63 | hidden = torch.relu(self.penultimate_layer(hidden)) 64 | 65 | # get mu and sigma 66 | mu = self.mu(hidden) 67 | log_sigma = self.log_sigma(hidden) 68 | 69 | # reparameterization trick 70 | std = torch.exp(0.5 * log_sigma) 71 | eps = torch.randn_like(std) 72 | z = eps.mul(std).add_(mu) # noqa: WPS111 73 | 74 | return mu, log_sigma, z 75 | 76 | 77 | class DeterministicEncoder(nn.Module): 78 | def __init__(self, num_hidden, num_latent, dim_x, dim_y): 79 | super().__init__() 80 | self.self_attentions = nn.ModuleList([Attention(num_hidden) for _ in range(2)]) 81 | self.cross_attentions = nn.ModuleList([Attention(num_hidden) for _ in range(2)]) 82 | self.input_projection = Linear(dim_x + dim_y, num_hidden) 83 | self.context_projection = Linear(dim_x, num_hidden) 84 | self.target_projection = Linear(dim_x, num_hidden) 85 | 86 | def forward(self, context_x, context_y, target_x): 87 | # concat context location (x), context value (y) 88 | encoder_input = torch.cat([context_x, context_y], dim=-1) 89 | 90 | # project vector with dimension 3 --> num_hidden 91 | encoder_input = self.input_projection(encoder_input) 92 | 93 | # self attention layer 94 | for self_att in self.self_attentions: 95 | encoder_input, _ = self_att(encoder_input, encoder_input, encoder_input) 96 | 97 | # query: target_x, key: context_x, value: representation 98 | query = self.target_projection(target_x) 99 | keys = self.context_projection(context_x) 100 | 101 | # cross attention layer 102 | for cross_att in self.cross_attentions: 103 | query, _ = cross_att(keys, encoder_input, query) 104 | 105 | return query 106 | 107 | 108 | class Decoder(nn.Module): 109 | def __init__(self, num_hidden, dim_x, dim_y): 110 | super().__init__() 111 | self.target_projection = Linear(dim_x, num_hidden) 112 | self.linears = nn.ModuleList( 113 | [ 114 | Linear(num_hidden * 3, num_hidden * 3, w_init="relu") for _ in range(3) 115 | ] # noqa: WPS221 116 | ) 117 | self.final_projection = Linear(num_hidden * 3, dim_y * 2) 118 | self.dim_y = dim_y 119 | 120 | def forward(self, r, z, target_x, min_sigma=1e-3): # noqa: WPS111 121 | batch_size, num_targets, _ = target_x.size() 122 | # project vector with dimension 2 --> num_hidden 123 | 
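# For each target point, the decoder fuses the deterministic representation r (computed by
# cross-attention over the context set), the sampled global latent z, and the projected target
# location, and then maps the concatenation to both a mean and a (softplus-transformed) scale
# for the predictive distribution over y.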
target_x = self.target_projection(target_x) 124 | 125 | # concat all vectors (r,z,target_x) 126 | hidden = torch.cat((r, z, target_x), dim=-1) 127 | 128 | # mlp layers 129 | for linear in self.linears: 130 | hidden = torch.relu(linear(hidden)) 131 | 132 | # get mu and sigma 133 | y_pred = self.final_projection(hidden) 134 | y_mu, y_sigma = torch.split(y_pred, (self.dim_y, self.dim_y), -1) # noqa: WPS221 135 | y_sigma = functional.softplus(y_sigma) + min_sigma 136 | 137 | return Normal(y_mu, y_sigma) 138 | 139 | 140 | class MultiheadAttention(nn.Module): 141 | def __init__(self, num_hidden_k): 142 | super().__init__() 143 | 144 | self.num_hidden_k = num_hidden_k 145 | self.attn_dropout = nn.Dropout(p=0.1) 146 | 147 | def forward(self, key, value, query): # noqa: WPS110 148 | # Get attention score 149 | attn = torch.bmm(query, key.transpose(1, 2)) 150 | attn = attn / math.sqrt(self.num_hidden_k) 151 | 152 | attn = torch.softmax(attn, dim=-1) 153 | 154 | # Dropout 155 | attn = self.attn_dropout(attn) 156 | 157 | # Get Context Vector 158 | output = torch.bmm(attn, value) 159 | 160 | return output, attn 161 | 162 | 163 | class Attention(nn.Module): 164 | def __init__(self, num_hidden, n_heads=4): 165 | super().__init__() 166 | 167 | self.num_hidden = num_hidden 168 | self.num_hidden_per_attn = num_hidden // n_heads 169 | self.n_heads = n_heads 170 | 171 | self.encoder_key = Linear(num_hidden, num_hidden, bias=False) 172 | self.encoder_value = Linear(num_hidden, num_hidden, bias=False) 173 | self.encoder_query = Linear(num_hidden, num_hidden, bias=False) 174 | 175 | self.multihead = MultiheadAttention(self.num_hidden_per_attn) 176 | 177 | self.residual_dropout = nn.Dropout(p=0.1) 178 | 179 | self.final_linear = Linear(num_hidden * 2, num_hidden) 180 | 181 | self.layer_norm = nn.LayerNorm(num_hidden) 182 | 183 | def forward(self, key, value, query): # noqa: WPS110 184 | pattern = "b n (nh nhpa) -> (nh b) n nhpa" 185 | key_enc = rearrange(self.encoder_key(key), pattern, nh=self.n_heads) 186 | value_enc = rearrange(self.encoder_value(value), pattern, nh=self.n_heads) 187 | query_enc = rearrange(self.encoder_query(query), pattern, nh=self.n_heads) 188 | 189 | # Get context vector 190 | output, attns = self.multihead(key_enc, value_enc, query_enc) 191 | 192 | # Concatenate all multihead context vector 193 | output = rearrange(output, "(nh b) sq nhpa -> b sq (nh nhpa)", nh=self.n_heads) 194 | # Concatenate context vector with input (most important) 195 | output = torch.cat([query, output], dim=-1) 196 | 197 | # Final linear 198 | output = self.final_linear(output) 199 | 200 | # Residual dropout & connection 201 | output = self.residual_dropout(output) 202 | output += query 203 | 204 | # Layer normalization 205 | output = self.layer_norm(output) 206 | 207 | return output, attns 208 | -------------------------------------------------------------------------------- /deep_pdes/attentive_neural_process/probconserv.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Literal, Optional 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | from einops import rearrange, reduce 7 | from scipy.optimize import nnls 8 | from torch import Tensor, nn 9 | from torch.distributions import Normal 10 | from torch.optim import Adam 11 | from tqdm import tqdm 12 | 13 | from deep_pdes.attentive_neural_process.anp import ANP 14 | 15 | LimitingMode = Literal["physnp", "hcnp"] 16 | 17 | 18 | class PhysNP(pl.LightningModule): 19 | def 
__init__( # noqa: WPS211 20 | self, 21 | anp: ANP, 22 | lr: float = 1e-4, 23 | constraint_precision_train=1e2, 24 | min_var_sample=1e-8, 25 | train_precision=True, 26 | riemann_type="trapezoid", 27 | non_linear_ineq_constraint=False, 28 | second_deriv_alpha: Optional[float] = None, 29 | limiting_mode: Optional[LimitingMode] = None, 30 | return_full_cov: bool = True, 31 | use_double_on_constraint: bool = True, 32 | ) -> None: 33 | super().__init__() 34 | self.anp = anp 35 | self.lr = lr 36 | self.train_precision = train_precision 37 | if self.train_precision: 38 | self.log_constraint_precision_train = nn.Parameter( 39 | torch.tensor(constraint_precision_train).log() 40 | ) 41 | else: 42 | self.register_buffer( 43 | "log_constraint_precision_train", torch.tensor(constraint_precision_train).log() 44 | ) 45 | self.min_var_sample = min_var_sample 46 | self.riemann_type = riemann_type 47 | self.non_linear_ineq_constraint = non_linear_ineq_constraint 48 | self.second_deriv_alpha = second_deriv_alpha 49 | self.limiting_mode = limiting_mode 50 | self.return_full_cov = return_full_cov 51 | self.use_double_on_constraint = use_double_on_constraint 52 | 53 | @property 54 | def constraint_precision_train(self): 55 | return self.log_constraint_precision_train.exp() 56 | 57 | def forward(self, context_x, context_y, target_x, mass_rhs, target_y=None): 58 | context_x_flat = rearrange(context_x, "b nt nx d -> b (nt nx) d") 59 | context_y_flat = rearrange(context_y, "b nt nx d -> b (nt nx) d") 60 | target_x_flat = rearrange(target_x, "b nt nx d -> b (nt nx) d") 61 | if target_y is not None: 62 | target_y_flat = rearrange(target_y, "b nt nx d -> b (nt nx) d") 63 | else: 64 | target_y_flat = None 65 | 66 | target_y_dist, kl, _, z = self.anp.forward( 67 | context_x_flat, context_y_flat, target_x_flat, target_y_flat 68 | ) 69 | if target_y is not None: 70 | log_prob = self._constrained_log_prob(target_y_dist, target_x, target_y, mass_rhs, z) 71 | recon_loss = -log_prob.sum() 72 | loss = recon_loss + kl 73 | else: 74 | kl = None 75 | loss = None 76 | return target_y_dist, kl, loss, z 77 | 78 | def training_step(self, batch, batch_idx, train=True): 79 | out = {} 80 | nt = batch["n_targets_t"][0].item() 81 | for k in ("input_contexts", "output_contexts", "input_targets", "output_targets"): 82 | out[k] = rearrange(batch[k], "nf (nt nx) d -> nf nt nx d", nt=nt) 83 | context_x, context_y = out["input_contexts"], out["output_contexts"] 84 | target_x, target_y = out["input_targets"], out["output_targets"] 85 | mass_rhs = batch["mass_rhs"] 86 | _, _, loss, _ = self.forward(context_x, context_y, target_x, mass_rhs, target_y) 87 | if train: 88 | self.log("train_loss", loss) 89 | return loss 90 | 91 | def validation_step(self, batch, batch_idx): 92 | loss = self.training_step(batch, batch_idx, train=False) 93 | self.log("val_loss", loss) 94 | return loss 95 | 96 | def configure_optimizers(self): 97 | return Adam(self.parameters(), lr=self.lr) 98 | 99 | def get_loc_and_scale_batched( # noqa: WPS210 100 | self, 101 | input_contexts, 102 | output_contexts, 103 | input_targets, 104 | mass_rhs, 105 | n_samples=100, 106 | batch_size=500, 107 | ): 108 | m_list = [] 109 | s_list = [] 110 | cov_list = [] 111 | nf, nt, nx, _ = input_targets.shape 112 | n_batches = nt // batch_size + 1 113 | for i in tqdm(range(n_batches)): 114 | t_strt = i * batch_size 115 | t_end = (i + 1) * batch_size 116 | dist, cov = self.get_loc_and_scale( 117 | input_contexts, 118 | output_contexts, 119 | input_targets[:, t_strt:t_end], 120 | mass_rhs[:, 
t_strt:t_end], 121 | n_samples, 122 | ) 123 | m_list.append(dist.loc.cpu()) 124 | s_list.append(dist.scale.cpu()) 125 | if cov is not None: 126 | cov_list.append(cov.cpu()) 127 | loc = torch.cat(m_list, 1) 128 | scale = torch.cat(s_list, 1) 129 | # Check if we are returning covariance. 130 | if len(cov_list) > 0: 131 | cov = torch.cat(cov_list, 1) 132 | else: 133 | cov = None 134 | return Normal(loc, scale), cov 135 | 136 | def get_loc_and_scale( 137 | self, input_contexts, output_contexts, input_targets, mass_rhs, n_samples 138 | ): 139 | dist, cov = self.sample( 140 | input_contexts, output_contexts, input_targets, mass_rhs, n_samples=n_samples 141 | ) 142 | loc = dist.loc.mean(0) 143 | if dist.loc.shape[0] == 1: 144 | var_of_locs = 0 145 | else: 146 | var_of_locs = dist.loc.var(0) 147 | var_total = dist.scale.pow(2).mean(0) + var_of_locs 148 | scale = var_total.sqrt() 149 | 150 | if cov is not None: 151 | cov_total = cov.mean(0) 152 | else: 153 | cov_total = None 154 | 155 | return Normal(loc, scale), cov_total 156 | 157 | def sample( # noqa: WPS210 158 | self, input_contexts, output_contexts, input_targets, mass_rhs, n_samples 159 | ): 160 | loc_list = [] 161 | cov_list = [] if self.return_full_cov else None 162 | scale_list = [] 163 | for _ in range(n_samples): 164 | dist, _, _, z = self.forward(input_contexts, output_contexts, input_targets, mass_rhs) 165 | loc_i, cov_i = self._apply_constraint(dist, input_targets, mass_rhs, z) 166 | var_i = torch.diagonal(cov_i, dim1=2, dim2=3) 167 | var_i = var_i.unsqueeze(-1) 168 | var_i = var_i.clamp_min(self.min_var_sample) 169 | scale_i = var_i.sqrt() 170 | loc_list.append(loc_i) 171 | if cov_list is not None: 172 | cov_list.append(cov_i.cpu()) 173 | scale_list.append(scale_i) 174 | loc = torch.stack(loc_list) 175 | scale = torch.stack(scale_list) 176 | if cov_list is not None: 177 | cov = torch.stack(cov_list) 178 | else: 179 | cov = None 180 | return Normal(loc, scale), cov 181 | 182 | def _apply_constraint( # noqa: WPS210 183 | self, target_y_dist, target_inputs: Tensor, mass_rhs, z: Tensor 184 | ): 185 | # target_inputs: nf nt nx 2 186 | # target_outputs: nf nt nx 1 187 | # mass_rhs: nf nt 188 | nf, nt, nx, _ = target_inputs.shape 189 | 190 | mu = rearrange(target_y_dist.loc, "nf (nt nx) 1 -> nf nt nx 1", nt=nt, nx=nx) 191 | masses_at_t = rearrange(mass_rhs, "nf nt -> nf nt 1 1") 192 | 193 | input_grid = rearrange(target_inputs, "nf nt nx d -> nf nt nx d", nt=nt, nx=nx) 194 | x = input_grid[:, :, :, 1] 195 | 196 | x_delta = self._get_riemman_delta(x) 197 | 198 | g = rearrange(x_delta, "nf nt nx -> nf nt 1 nx") 199 | precis_g = self._get_constraint_precision(z) 200 | precis_g = rearrange(precis_g, "nf nt -> nf nt 1 1") 201 | 202 | eye = torch.eye(nx, device=g.device) 203 | eye = rearrange(eye, "nx1 nx2 -> 1 1 nx1 nx2") 204 | cov = target_y_dist.scale.pow(2) 205 | cov = rearrange(cov, "nf (nt nx) 1 -> nf nt nx 1", nt=nt) 206 | 207 | if self.second_deriv_alpha is not None: 208 | g2 = _get_second_deriv_mat(nx).to(g.device) 209 | g2 = rearrange(g2, "nxm2 nx -> 1 1 nxm2 nx") 210 | var_g2 = _get_second_derivative_var(cov, alpha=self.second_deriv_alpha).to(g.device) 211 | b = torch.zeros(1, 1, device=g2.device) 212 | mu, cov_mat = _apply_g(g2, var_g2, cov, mu, b) 213 | else: 214 | cov_mat = cov * eye 215 | 216 | if self.limiting_mode == "physnp": 217 | var_g = torch.zeros_like(precis_g) 218 | elif self.limiting_mode == "hcnp": 219 | var_g = torch.zeros_like(precis_g) 220 | cov_mat = eye 221 | else: 222 | var_g = 1 / precis_g 223 | 224 | if 
self.use_double_on_constraint: 225 | g = g.double() 226 | var_g = var_g.double() 227 | cov_mat = cov_mat.double() 228 | mu = mu.double() 229 | masses_at_t = masses_at_t.double() 230 | 231 | n_g = g.size(2) 232 | device = g.device 233 | dtype = g.dtype 234 | eye_g = torch.ones(1, 1, n_g, n_g, device=device, dtype=dtype) 235 | g_times_cov = g.matmul(cov_mat) 236 | gtr = g.transpose(3, 2) 237 | small_a = eye_g * var_g + (g_times_cov.matmul(gtr)) 238 | rinv1 = torch.linalg.solve(small_a, g_times_cov) 239 | if self.limiting_mode == "hcnp": 240 | new_cov = cov * eye 241 | else: 242 | gtr_rinv1 = gtr.matmul(rinv1) 243 | new_cov = cov_mat.matmul(eye - gtr_rinv1) 244 | rinv2 = torch.linalg.solve(small_a, g.matmul(mu) - masses_at_t) 245 | new_mu = mu - cov_mat.matmul(gtr.matmul(rinv2)) 246 | 247 | if self.non_linear_ineq_constraint: 248 | raise NotImplementedError() 249 | return new_mu.float(), new_cov.float() 250 | 251 | def _constrained_log_prob( # noqa: WPS210 252 | self, 253 | target_y_dist: Normal, 254 | target_inputs: Tensor, 255 | target_outputs: Tensor, 256 | mass_rhs: Tensor, 257 | z: Tensor, 258 | ) -> Tensor: 259 | # target_inputs: b nt nx 2 260 | # target_outputs: b nt nx 1 261 | # mass_rhs: b nt 262 | b, nt, nx, _ = target_inputs.shape 263 | target_outputs_flat = rearrange(target_outputs, "nf nt nx 1 -> nf (nt nx) 1") 264 | prior_log_prob_flat = self.anp.log_prob(target_y_dist, target_outputs_flat) 265 | prior_log_prob = rearrange(prior_log_prob_flat, "nf (nt nx) 1 -> nf nt nx 1", nt=nt, nx=nx) 266 | 267 | input_grid = rearrange(target_inputs, "b nt nx d -> b nt nx d", nt=nt, nx=nx) 268 | output_grid = rearrange(target_outputs, "b nt nx 1 -> b nt nx", nt=nt, nx=nx) 269 | 270 | x = input_grid[:, :, :, 1] 271 | x_delta = self._get_riemman_delta(x) 272 | 273 | mean_constraint = (x_delta * output_grid).sum(-1) # b nt 274 | precis_constraint = self._get_constraint_precision(z) 275 | sd_constraint = precis_constraint.pow(-0.5) 276 | constraint_dist = Normal(mean_constraint, sd_constraint) 277 | constraint_log_prob = constraint_dist.log_prob(mean_constraint) 278 | 279 | mu: Tensor = rearrange(target_y_dist.loc, "n (nt nx) 1 -> n nt nx", nt=nt, nx=nx) 280 | sd: Tensor = rearrange(target_y_dist.scale, "n (nt nx) 1 -> n nt nx", nt=nt, nx=nx) 281 | variance = sd.pow(2) 282 | mean_normalizing_constant_dist = (x_delta * mu).sum(-1) 283 | x_delta_squared = x_delta.pow(2) 284 | var_normalizing_constant_dist = (x_delta_squared * variance).sum(-1) + sd_constraint.pow(2) 285 | sd_normalizing_constant_dist = var_normalizing_constant_dist.sqrt() 286 | 287 | normalizing_constant_dist = Normal( 288 | mean_normalizing_constant_dist, sd_normalizing_constant_dist 289 | ) 290 | normalizing_constant = normalizing_constant_dist.log_prob(mean_constraint) 291 | 292 | return ( 293 | reduce(prior_log_prob, "nf nt nx 1 -> nf", "sum") 294 | + reduce(constraint_log_prob, "nf nt -> nf", "sum") 295 | - reduce(normalizing_constant, "nf nt -> nf", "sum") 296 | ) 297 | 298 | def _get_riemman_delta(self, x): 299 | x_diff = torch.diff(x, dim=2) 300 | assert torch.all(x_diff >= 0) 301 | zero_pad_shape = (*x.shape[:2], 1) 302 | zero_pad = torch.zeros(*zero_pad_shape, device=x.device) 303 | x_delta_l: Tensor = torch.cat((x_diff, zero_pad), dim=2) 304 | x_delta_r: Tensor = torch.cat((zero_pad, x_diff), dim=2) 305 | if self.riemann_type == "trapezoid": 306 | x_delta = 0.5 * (x_delta_l + x_delta_r) 307 | elif self.riemann_type == "rhs": 308 | x_delta = x_delta_r 309 | else: 310 | raise NotImplementedError() 311 | return x_delta 312
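# Note on _apply_constraint above: with the (diagonal) covariance Sigma returned by the ANP,
# the constraint row G given by the quadrature weights from _get_riemman_delta, the known
# total mass b (mass_rhs), and constraint noise variance var_g = 1 / precis_g, the update is
# the standard Gaussian conditioning identity
#
#     mu_tilde    = mu    - Sigma G^T (G Sigma G^T + var_g I)^{-1} (G mu - b)
#     Sigma_tilde = Sigma - Sigma G^T (G Sigma G^T + var_g I)^{-1} G Sigma
#
# (small_a in the code is G Sigma G^T + var_g I). In the limiting case var_g -> 0
# (limiting_mode="physnp"), the quadrature estimate of the mass is matched exactly,
# i.e. G mu_tilde = b.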
| 313 | def _get_constraint_precision(self, z): 314 | # z: nf d_z 315 | # max_delta: nf nt 316 | # precis_f: nf nt 317 | return self.constraint_precision_train.reshape(1, 1) 318 | 319 | 320 | def _apply_g(g, var_g, cov, mu, mass_rhs): # noqa: WPS210 321 | _, _, nx, _ = mu.shape 322 | _, _, ng, _ = g.shape 323 | eye = torch.eye(nx, device=g.device) 324 | eye = rearrange(eye, "nx1 nx2 -> 1 1 nx1 nx2") 325 | eye_g = torch.eye(ng, device=g.device) 326 | eye_g = rearrange(eye_g, "ng1 ng2 -> 1 1 ng1 ng2") 327 | gtr = g.transpose(3, 2) 328 | small_a = eye_g * var_g + (g.matmul(cov * gtr)) 329 | rinv1 = torch.linalg.solve(small_a, g.matmul(cov * eye)) 330 | new_cov = cov * (eye - gtr.matmul(rinv1)) 331 | 332 | b = mass_rhs.unsqueeze(-1).unsqueeze(-1) 333 | rinv2 = torch.linalg.solve(small_a, g.matmul(mu) - b) 334 | new_mu = mu - cov * gtr.matmul(rinv2) 335 | return new_mu, new_cov 336 | 337 | 338 | def _get_second_deriv_mat(nx): 339 | eye = torch.eye(nx) 340 | eye1 = eye[:-2] 341 | eye2 = eye[1:-1] * -2 342 | eye3 = eye[2:] 343 | return eye1 + eye2 + eye3 344 | 345 | 346 | def _get_second_deriv_mat_autocor(nx, alpha=0.5): 347 | eye = torch.eye(nx) 348 | eye1 = eye[:-2] + ((alpha - 2) * alpha) 349 | eye2 = eye[1:-1] * -2 + alpha 350 | eye3 = eye[2:] 351 | return eye1 + eye2 + eye3 352 | 353 | 354 | def _get_second_derivative_var(cov: Tensor, alpha=0.5): 355 | nf, nt, nx, _ = cov.shape 356 | cov0 = cov[:, :, :-2] 357 | cov1 = cov[:, :, 1:-1] 358 | cov2 = cov[:, :, 2:] 359 | 360 | return ( 361 | cov0 362 | + 4 * cov1 363 | + cov2 364 | - 4 * alpha * cov0.sqrt() * cov1.sqrt() 365 | + 2 * (alpha**2) * cov0.sqrt() * cov2.sqrt() 366 | - 4 * alpha * cov1.sqrt() * cov2.sqrt() 367 | ) 368 | 369 | 370 | def get_mu_tilde_as_projection(mu, variance): 371 | mu = rearrange(mu, "n nt nx 1 -> n nt nx") 372 | variance = rearrange(variance, "n nt nx 1 -> n nt nx") 373 | mu_hat = reduce(mu, "n nt nx -> n nt 1", "mean") 374 | var_mean = reduce(variance, "n nt nx -> n nt 1", "mean") 375 | vec = variance / var_mean 376 | return mu - mu_hat * vec 377 | 378 | 379 | def get_cov_tilde_as_projection(variance): 380 | variance = rearrange(variance, "n nt nx 1 -> n nt nx") 381 | n, nt, nx = variance.shape 382 | eye = torch.eye(nx, device=variance.device) 383 | eye = rearrange(eye, "nx1 nx2 -> 1 1 nx1 nx2") 384 | cov_hat = reduce(variance.unsqueeze(-1) * eye, "n nt nx1 nx2 -> n nt 1 nx2", "mean") 385 | var_mean = reduce(variance, "n nt nx -> n nt 1 1", "mean") 386 | vec = rearrange(variance, "n nt nx -> n nt nx 1") / var_mean 387 | return variance.unsqueeze(-1) * eye - cov_hat * vec 388 | 389 | 390 | InequalityConstraint = Literal["monotone", "nonneg"] 391 | 392 | 393 | def apply_non_linear_ineq_constraint( 394 | new_mu: Tensor, new_cov: Tensor, tol=1e-8, max_iter=1, mode: InequalityConstraint = "monotone" 395 | ): 396 | # Return mean truncated to be decreasing and non-zero. 
397 | loc = rearrange(new_mu, "nf nt nx 1 -> nf nt nx") 398 | nf, nt, nx = loc.shape 399 | new_loc_list = [] 400 | for n in tqdm(range(nf), desc="Applying constraint"): 401 | loc_n = loc[n] 402 | new_cov_n = new_cov[n] 403 | new_locs_n = _apply_non_linear_constraint_one_f(loc_n, new_cov_n, max_iter, tol, mode) 404 | new_loc_list.append(new_locs_n) 405 | new_locs = torch.stack(new_loc_list, dim=0) 406 | return rearrange(new_locs, "nf nt nx 1 -> nf nt nx 1") 407 | 408 | 409 | def _apply_non_linear_constraint_one_f( 410 | loc_n: Tensor, new_cov_n: Tensor, max_iter: int, tol: float, mode: InequalityConstraint 411 | ): 412 | new_locs_i = [] 413 | nt, nx = loc_n.shape 414 | for t in range(nt): 415 | loc_t = loc_n[t].unsqueeze(-1) 416 | cov_t = new_cov_n[t] 417 | new_loc_t = _apply_non_linear_constraint_at_t(loc_t, cov_t, max_iter, tol, mode) 418 | new_locs_i.append(new_loc_t) 419 | return torch.stack(new_locs_i, dim=0) 420 | 421 | 422 | def _apply_non_linear_constraint_at_t( # noqa: WPS210, WPS231 423 | loc_t: Tensor, cov_t: Tensor, max_iter: int, tol: float, mode: InequalityConstraint 424 | ): 425 | nx = loc_t.shape[0] 426 | eye = torch.eye(nx).to(loc_t.device) 427 | chol_t: Tensor = torch.linalg.cholesky(cov_t) 428 | chinv: Tensor = torch.linalg.solve_triangular(chol_t, eye, upper=False) 429 | chinv_loc_t = chinv.matmul(loc_t) 430 | if mode == "nonneg": 431 | a_matrix = chinv.numpy() 432 | b = chinv_loc_t.squeeze(-1).numpy() 433 | loc_t_np, _ = nnls(a_matrix, b) 434 | return torch.from_numpy(loc_t_np).unsqueeze(-1) 435 | if mode == "monotone": 436 | diff_matrix = _construct_diff_matrix(nx) 437 | diff_matrix_inv = np.linalg.inv(diff_matrix) 438 | a_matrix = np.matmul(chinv.numpy(), diff_matrix_inv) 439 | b = chinv_loc_t.squeeze(-1).numpy() 440 | loc_t_np_diff, _ = nnls(a_matrix, b) 441 | loc_t_np = np.matmul(diff_matrix_inv, loc_t_np_diff) 442 | return torch.from_numpy(loc_t_np).unsqueeze(-1) 443 | 444 | 445 | _diff_matrices: Dict[int, np.ndarray] = {} 446 | 447 | 448 | def _construct_diff_matrix(nx: int): 449 | try: 450 | diff_matrix = _diff_matrices[nx] 451 | except Exception: 452 | eye = np.eye(nx) 453 | eye2 = np.eye(nx, nx, k=1) 454 | diff_matrix = eye - eye2 455 | _diff_matrices[nx] = diff_matrix 456 | return diff_matrix 457 | -------------------------------------------------------------------------------- /deep_pdes/attentive_neural_process/softc.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from einops import rearrange 6 | from torch import Tensor 7 | from torch.optim import Adam 8 | 9 | from deep_pdes.attentive_neural_process.anp import ANP 10 | 11 | 12 | class StefanPressureFn: 13 | def __init__(self, k_max: float) -> None: 14 | self.k_max = k_max 15 | 16 | def __call__(self, p: Tensor, params: Tensor) -> Tensor: 17 | p_stars = params[:, 0] 18 | return (p >= p_stars) * self.k_max 19 | 20 | 21 | class PMEPressureFn: 22 | def __call__(self, p: Tensor, params: Tensor) -> Tensor: 23 | degrees = params[:, 0] 24 | return torch.relu(p).pow(degrees) 25 | 26 | 27 | class HeatPressureFn: 28 | def __call__(self, p: Tensor, params: Tensor) -> Tensor: 29 | conductivities = params[:, 0] 30 | return conductivities.expand(*p.shape) 31 | 32 | 33 | PressureFn = Union[StefanPressureFn, PMEPressureFn, HeatPressureFn] 34 | 35 | 36 | class PINP(pl.LightningModule): 37 | def __init__(self, anp: ANP, pressure_fn: PressureFn, pinns_lambda: float = 1.0, lr=1e-3): 38 | 
super().__init__() 39 | self.anp = anp 40 | self.pressure_fn = pressure_fn 41 | self.pinns_lambda = pinns_lambda 42 | self.lr = lr 43 | 44 | def get_loc_and_scale_batched( 45 | self, input_contexts, output_contexts, input_targets, n_samples=100, batch_size=10000 46 | ): 47 | return self.anp.get_loc_and_scale_batched( 48 | input_contexts, 49 | output_contexts, 50 | input_targets, 51 | n_samples=n_samples, 52 | batch_size=batch_size, 53 | ) 54 | 55 | def training_step(self, batch, batch_idx, train=True): 56 | context_x, context_y = batch["input_contexts"], batch["output_contexts"] 57 | target_x, target_y = batch["input_targets"], batch["output_targets"] 58 | target_x.requires_grad_(True) 59 | target_y_dist, _, anp_loss, _ = self.anp.forward(context_x, context_y, target_x, target_y) 60 | 61 | if train: 62 | params = rearrange(batch["params"], "nf d -> nf d 1 1") 63 | p = target_y_dist.loc 64 | pinns_loss = self._get_pinns_loss(params, p, target_x) 65 | else: 66 | pinns_loss = 0 67 | 68 | if train: 69 | self.log("train_anp_loss", anp_loss) 70 | self.log("train_pinn_loss", pinns_loss) 71 | return anp_loss + self.pinns_lambda * pinns_loss 72 | 73 | def validation_step(self, batch, batch_idx): 74 | loss = self.training_step(batch, batch_idx, train=False) 75 | self.log("val_loss", loss) 76 | return loss 77 | 78 | def configure_optimizers(self): 79 | return Adam(self.parameters(), lr=self.lr) 80 | 81 | def _get_pinns_loss(self, params: Tensor, p: Tensor, target_x: Tensor): 82 | p_d = partial_deriv(p, target_x) 83 | p_t, p_x = p_d.split((1, 1), -1) 84 | k_times_p_x = self.pressure_fn(p, params) * p_x 85 | k_times_p_x_d = partial_deriv(k_times_p_x, target_x) 86 | k_times_p_x_x = k_times_p_x_d[:, :, 1:2] 87 | f_pred = p_t - k_times_p_x_x 88 | return torch.mean(f_pred**2) 89 | 90 | 91 | def partial_deriv(out_tensor: Tensor, in_tensor: Tensor): 92 | return torch.autograd.grad( 93 | out_tensor, 94 | in_tensor, 95 | grad_outputs=torch.ones_like(out_tensor), 96 | retain_graph=True, 97 | create_graph=True, 98 | )[0] 99 | 100 | 101 | class GPMEDifferentialPenalty: 102 | def __init__(self, pressure_fn: PressureFn) -> None: 103 | self.pressure_fn = pressure_fn 104 | 105 | def get_pinns_loss(self, params: Tensor, p: Tensor, target_x: Tensor): 106 | p_d = partial_deriv(p, target_x) 107 | p_t, p_x = p_d.split((1, 1), -1) 108 | k_times_p_x = self.pressure_fn(p, params) * p_x 109 | k_times_p_x_d = partial_deriv(k_times_p_x, target_x) 110 | k_times_p_x_x = k_times_p_x_d[:, :, 1:2] 111 | f_pred = p_t - k_times_p_x_x 112 | return torch.mean(f_pred**2) 113 | 114 | 115 | class LinearAdvectionDifferentialPenalty: 116 | def get_pinns_loss(self, params: Tensor, p: Tensor, target_x: Tensor): 117 | p_d = partial_deriv(p, target_x) 118 | p_t, p_x = p_d.split((1, 1), -1) 119 | beta = params 120 | f_pred = p_t + beta * p_x 121 | return torch.mean(f_pred**2) 122 | 123 | 124 | class BurgersDifferentialPenalty: 125 | def get_pinns_loss(self, params: Tensor, p: Tensor, target_x: Tensor): 126 | p_d = partial_deriv(p, target_x) 127 | p_t, p_x = p_d.split((1, 1), -1) 128 | 129 | # Derivative of 0.5 * p**2 wrt x 130 | p2_x = p * p_x 131 | f_pred = p_t + p2_x 132 | return torch.mean(f_pred**2) 133 | 134 | 135 | DifferentialPenalty = Union[GPMEDifferentialPenalty, LinearAdvectionDifferentialPenalty] 136 | 137 | 138 | class SoftcANP(pl.LightningModule): 139 | def __init__( 140 | self, 141 | anp: ANP, 142 | differential_penalty: DifferentialPenalty, 143 | pinns_lambda: float = 1.0, 144 | lr=1e-3, 145 | ): 146 | super().__init__() 147 | 
self.anp = anp 148 | self.differential_penalty = differential_penalty 149 | self.pinns_lambda = pinns_lambda 150 | self.lr = lr 151 | 152 | def get_loc_and_scale_batched( 153 | self, input_contexts, output_contexts, input_targets, n_samples=100, batch_size=10000 154 | ): 155 | return self.anp.get_loc_and_scale_batched( 156 | input_contexts, 157 | output_contexts, 158 | input_targets, 159 | n_samples=n_samples, 160 | batch_size=batch_size, 161 | ) 162 | 163 | def training_step(self, batch, batch_idx, train=True): 164 | context_x, context_y = batch["input_contexts"], batch["output_contexts"] 165 | target_x, target_y = batch["input_targets"], batch["output_targets"] 166 | target_x.requires_grad_(True) 167 | target_y_dist, _, anp_loss, _ = self.anp.forward(context_x, context_y, target_x, target_y) 168 | 169 | if train: 170 | params = rearrange(batch["params"], "nf d -> nf d 1 1") 171 | p = target_y_dist.loc 172 | pinns_loss = self.differential_penalty.get_pinns_loss(params, p, target_x) 173 | else: 174 | pinns_loss = 0 175 | 176 | if train: 177 | self.log("train_anp_loss", anp_loss) 178 | self.log("train_pinn_loss", pinns_loss) 179 | return anp_loss + self.pinns_lambda * pinns_loss 180 | 181 | def validation_step(self, batch, batch_idx): 182 | loss = self.training_step(batch, batch_idx, train=False) 183 | self.log("val_loss", loss) 184 | return loss 185 | 186 | def configure_optimizers(self): 187 | return Adam(self.parameters(), lr=self.lr) 188 | -------------------------------------------------------------------------------- /deep_pdes/datasets/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict, List, Optional, Tuple 3 | 4 | import icontract 5 | import torch 6 | from einops import repeat 7 | from torch import Tensor # noqa: WPS458 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | @icontract.invariant(lambda self: self.validate()) 12 | class ANPDataset(Dataset, icontract.DBC): # noqa: WPS214 13 | def __init__(self) -> None: 14 | super().__init__() 15 | self.mass_rhs: Optional[Tensor] = None 16 | self.n_targets_t: Optional[int] = None 17 | self.n_targets_x: Optional[int] = None 18 | 19 | @property 20 | @abstractmethod 21 | def tensors(self) -> Dict[str, Tensor]: 22 | pass # noqa: WPS420 23 | 24 | @property # type: ignore 25 | @icontract.ensure(lambda result: (result is None) or (len(result.shape) == 2)) 26 | def params(self) -> Optional[Tensor]: 27 | return None # noqa: WPS324 28 | 29 | @abstractmethod 30 | @icontract.require(lambda inputs: len(inputs.shape) == 3) 31 | @icontract.ensure(lambda result: len(result.shape) == 3) 32 | def solution(self, inputs: Tensor) -> Tensor: 33 | pass # noqa: WPS420 34 | 35 | @abstractmethod 36 | def lims(self, dimname: str) -> Tuple[float, float]: 37 | pass # noqa: WPS420 38 | 39 | @property 40 | @abstractmethod 41 | def dimnames(self) -> Tuple[str]: 42 | pass # noqa: WPS420 43 | 44 | @property 45 | @abstractmethod 46 | def batch_size(self) -> int: 47 | pass # noqa: WPS420 48 | 49 | @property 50 | def dimensions(self) -> Dict[str, int]: 51 | return { 52 | "n_functions": self.tensors["input_contexts"].shape[0], 53 | "n_contexts": self.tensors["input_contexts"].shape[1], 54 | "n_targets": self.tensors["input_targets"].shape[1], 55 | "input_dim": self.tensors["input_targets"].shape[2], 56 | "output_dim": self.tensors["output_contexts"].shape[2], 57 | } 58 | 59 | def validate(self): 60 | dim = self.dimensions 61 | shapes = { 62 | "input_contexts": 
(dim["n_functions"], dim["n_contexts"], dim["input_dim"]), 63 | "output_contexts": (dim["n_functions"], dim["n_contexts"], dim["output_dim"]), 64 | "input_targets": (dim["n_functions"], dim["n_targets"], dim["input_dim"]), 65 | "output_targets": (dim["n_functions"], dim["n_targets"], dim["output_dim"]), 66 | } 67 | for nm, shape in shapes.items(): 68 | assert self.tensors[nm].shape == shape, f"{nm} has incorrect shape" 69 | return True 70 | 71 | def __getitem__(self, idx) -> Dict[str, Tensor]: 72 | tensors = {k: v[idx] for k, v in self.tensors.items()} 73 | if self.params is not None: 74 | tensors["params"] = self.params[idx] 75 | if self.mass_rhs is not None: 76 | tensors["mass_rhs"] = self.mass_rhs[idx] 77 | tensors["n_targets_t"] = torch.tensor(self.n_targets_t) 78 | tensors["n_targets_x"] = torch.tensor(self.n_targets_x) 79 | return tensors 80 | 81 | def __len__(self): 82 | return self.dimensions["n_functions"] 83 | 84 | def dataloader(self): 85 | return DataLoader( 86 | self, 87 | batch_size=self.batch_size, 88 | shuffle=True, 89 | ) 90 | 91 | 92 | def sample_points_from_each_interval(n_functions: int, n_pts: int) -> Tensor: 93 | interval_starts = torch.tensor(range(n_pts)) / n_pts 94 | ts = torch.rand((n_functions, n_pts)) * (1 / n_pts) 95 | ts += interval_starts 96 | return ts 97 | 98 | 99 | def sample_points_uniformly(n_functions: int, n_pts: int, min_val=0, max_val=1) -> Tensor: 100 | return sample_uniform((n_functions, n_pts), min_val, max_val) 101 | 102 | 103 | def sample_uniform( 104 | size: Tuple[int, ...], 105 | min_val: float = 0, 106 | max_val: float = 1, 107 | sort: bool = False, 108 | ) -> Tensor: 109 | us = torch.rand(size) 110 | if sort: 111 | us, _ = torch.sort(us) 112 | return min_val + (max_val - min_val) * us 113 | 114 | 115 | def meshgrid(ts: Tensor, xs: Tensor): 116 | nf, nt = ts.shape 117 | _, nx = xs.shape 118 | ts = repeat(ts, "nf nt -> nf nt nx", nx=nx) 119 | xs = repeat(xs, "nf nx -> nf nt nx", nt=nt) 120 | return torch.stack((ts, xs), dim=-1) 121 | -------------------------------------------------------------------------------- /deep_pdes/datasets/pbc.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | from mpl_toolkits.axes_grid1 import make_axes_locatable 6 | 7 | 8 | class OneDimConvection: 9 | def __init__(self, beta: float, source: float = 0): 10 | self.beta = beta 11 | self.source = source 12 | 13 | def calc_over_grid(self, n_x: int, n_t: int): 14 | x = np.linspace(0, 2 * np.pi, n_x, endpoint=False) 15 | t = np.linspace(0, 1, n_t) 16 | x = x.reshape(-1, 1) 17 | t = t.reshape(-1, 1) 18 | u_vals = convection_diffusion(0, self.beta, self.source, n_x, n_t) 19 | u_star = u_vals.reshape(-1, 1) # Exact solution reshaped into (n, 1) 20 | u = u_star.reshape(len(t), len(x)) # Exact on the (x,t) grid 21 | return u, x, t 22 | 23 | 24 | def main(plot_path_str: str = None, beta=5.0, xgrid=256, nt=100): 25 | one_dim_convection = OneDimConvection(beta=beta) 26 | u, x, t = one_dim_convection.calc_over_grid(xgrid, nt) 27 | 28 | if plot_path_str is not None: 29 | plot_path = Path(plot_path_str) 30 | else: 31 | plot_path = Path("./plots") 32 | if not plot_path.exists(): 33 | plot_path.mkdir(parents=True) 34 | exact_u(u, x, t, path=plot_path / "exact.pdf") 35 | 36 | 37 | def convection_diffusion(nu, beta, source=0, nx=256, nt=100): 38 | h = 2 * np.pi / nx 39 | x = np.arange(0, 2 * np.pi, h) # not inclusive of the last point 40 | t = 
np.linspace(0, 1, nt).reshape(-1, 1) 41 | _, t_grid = np.meshgrid(x, t) 42 | u0 = np.sin(x) 43 | 44 | return convection_diffusion_solution(u0, t_grid, nu, beta).flatten() 45 | 46 | 47 | def convection_diffusion_solution( # noqa: WPS210 48 | x_start: np.ndarray, 49 | t_values: np.ndarray, 50 | nu: float, 51 | beta: float, 52 | source: float = 0, 53 | ): 54 | nx = x_start.shape[0] 55 | forcing_term = np.zeros_like(x_start) + source # G is the same size as u0 56 | 57 | ikx_pos = 1j * np.arange(0, nx / 2 + 1, 1) 58 | ikx_neg = 1j * np.arange(-nx / 2 + 1, 0, 1) # noqa: WPS221 59 | ikx = np.concatenate((ikx_pos, ikx_neg)) 60 | ikx2 = ikx * ikx 61 | 62 | uhat0 = np.fft.fft(x_start) 63 | nu_term = nu * ikx2 * t_values 64 | beta_term = beta * ikx * t_values 65 | nu_factor = np.exp(nu_term - beta_term) 66 | uhat = ( 67 | uhat0 * nu_factor + np.fft.fft(forcing_term) * t_values 68 | ) # for constant, fft(p) dt = fft(p)*T 69 | return np.real(np.fft.ifft(uhat)) 70 | 71 | 72 | def exact_u(u, x, t, path): 73 | fig = plt.figure(figsize=(9, 5)) 74 | sp = 111 75 | ax = fig.add_subplot(sp) 76 | 77 | h = ax.imshow( 78 | u.T, 79 | interpolation="nearest", 80 | cmap="rainbow", 81 | extent=[t.min(), t.max(), x.min(), x.max()], 82 | origin="lower", 83 | aspect="auto", 84 | ) 85 | divider = make_axes_locatable(ax) 86 | cax = divider.append_axes("right", size="5%", pad=0.1) 87 | cbar = fig.colorbar(h, cax=cax) 88 | labelsize = 15 89 | cbar.ax.tick_params(labelsize=labelsize) 90 | 91 | fontsize = 30 92 | ax.set_xlabel("t", fontweight="bold", size=fontsize) 93 | ax.set_ylabel("x", fontweight="bold", size=fontsize) 94 | ax.legend( 95 | loc="upper center", 96 | bbox_to_anchor=(0.9, -0.05), 97 | ncol=5, 98 | frameon=False, 99 | prop={"size": 15}, 100 | ) 101 | ax.tick_params(labelsize=labelsize) 102 | 103 | plt.savefig(path) 104 | plt.close() 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /deep_pdes/datasets/pme.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict, List, Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange, repeat 7 | from scipy.optimize import root_scalar 8 | from scipy.special import erf 9 | from torch import Tensor # noqa: WPS458 10 | from torch.nn import functional 11 | 12 | from deep_pdes.datasets.base import ( 13 | ANPDataset, 14 | meshgrid, 15 | sample_points_from_each_interval, 16 | sample_points_uniformly, 17 | ) 18 | from deep_pdes.datasets.pbc import convection_diffusion_solution 19 | 20 | 21 | class GeneralizedPorousMediumEquation(ANPDataset): # noqa: WPS214 22 | def __init__( # noqa: WPS210, WPS211 23 | self, 24 | n_functions: int, 25 | n_contexts_t: int, 26 | n_contexts_x: int, 27 | n_targets_t: int, 28 | n_targets_x: int, 29 | batch_size: int, 30 | t_range: Tuple[float, float] = (0, 1), 31 | x_range: Tuple[float, float] = (0, 1), 32 | load_path: Optional[str] = None, 33 | ): 34 | self._batch_size = batch_size 35 | self.n_contexts_t = n_contexts_t 36 | self.n_contexts_x = n_contexts_x 37 | self.n_targets_t = n_targets_t 38 | self.n_targets_x = n_targets_x 39 | self.t_range = t_range 40 | self.x_range = x_range 41 | if load_path is not None: 42 | tensors, parameters = torch.load(load_path) # noqa: WPS110 43 | else: 44 | tensors = None 45 | parameters = self._sample_parameters(n_functions) # noqa: WPS110 46 | self._parameters = parameters # noqa: WPS110 47 | if 
tensors is None:
48 |             tensors = self._make_solution(
49 |                 n_functions,
50 |                 n_contexts_t,
51 |                 n_contexts_x,
52 |                 n_targets_t,
53 |                 n_targets_x,
54 |             )
55 |         self._tensors: Dict[str, Tensor] = tensors
56 |         self.mass_rhs = self._mass_rhs()
57 |
58 |     @property
59 |     def dimnames(self):
60 |         return ("t", "x")
61 |
62 |     def lims(self, dimname: str):
63 |         if dimname == "x":
64 |             return self.x_range
65 |         elif dimname == "t":
66 |             return self.t_range
67 |         raise ValueError()
68 |
69 |     @property
70 |     def tensors(self):
71 |         return self._tensors
72 |
73 |     @property
74 |     def parameters(self):  # noqa: WPS110
75 |         return self._parameters
76 |
77 |     @property
78 |     def batch_size(self):
79 |         return self._batch_size
80 |
81 |     @property
82 |     def has_mass_rhs(self) -> bool:
83 |         return True
84 |
85 |     @abstractmethod
86 |     def solution(self, inputs: Tensor):
87 |         # inputs: nf (nt nx) 2
88 |         # outputs: nf (nt nx) 1
89 |         pass  # noqa: WPS420
90 |
91 |     def _make_solution(
92 |         self,
93 |         n_functions: int,
94 |         n_contexts_t: int,
95 |         n_contexts_x: int,
96 |         n_targets_t: int,
97 |         n_targets_x: int,
98 |     ) -> Dict[str, Tensor]:
99 |         input_sample_settings = (
100 |             ("contexts", n_contexts_t, n_contexts_x),
101 |             ("targets", n_targets_t, n_targets_x),
102 |         )
103 |         tensors: Dict[str, Tensor] = {}
104 |         for mode, n_t, n_x in input_sample_settings:
105 |             ts = sample_ts(n_functions, n_t, mode=mode, t_range=self.t_range)
106 |             xs = sample_xs(n_functions, n_x, x_range=self.x_range)
107 |             inputs = meshgrid(ts, xs)
108 |             inputs = rearrange(inputs, "nf nt nx d -> nf (nt nx) d")
109 |             outputs = self.solution(inputs)
110 |             tensors.update(
111 |                 {
112 |                     f"input_{mode}": inputs,
113 |                     f"output_{mode}": outputs,
114 |                 }
115 |             )
116 |         return tensors
117 |
118 |     @abstractmethod
119 |     def _mass_rhs(self):
120 |         pass  # noqa: WPS420
121 |
122 |
123 | class PorousMediumEquation(GeneralizedPorousMediumEquation):
124 |     def __init__(
125 |         self,
126 |         n_functions: int,
127 |         *args,
128 |         scale_lims=(0.2, 5),
129 |         degree_min=2,
130 |         degree_max=6,
131 |         degrees=None,
132 |         **kwargs,
133 |     ):
134 |         self.scale_lims = scale_lims
135 |         if degrees is None:
136 |             self._degrees = None
137 |             self.degree_min = degree_min
138 |             self.degree_max = degree_max
139 |         else:
140 |             self._degrees = torch.tensor(degrees)
141 |             self.degree_min = self._degrees.min().item()
142 |             self.degree_max = self._degrees.max().item()
143 |
144 |         super().__init__(n_functions, *args, **kwargs)
145 |
146 |     @property
147 |     def degrees(self):
148 |         return self._parameters[:, 0]
149 |
150 |     @property
151 |     def scales(self):
152 |         return self._parameters[:, 1]
153 |
154 |     @property
155 |     def params(self) -> Tensor:
156 |         return self._parameters
157 |
158 |     def solution(self, inputs: Tensor):
159 |         return self.true_solution(inputs, self.degrees, self.scales)
160 |
161 |     def true_solution(self, inputs: Tensor, degrees: Tensor, scales: Tensor):
162 |         # inputs: nf (nx nt) 2
163 |         degrees = rearrange(degrees, "nf -> nf 1")
164 |         scales = rearrange(scales, "nf -> nf 1")
165 |         ts, xs = torch.split(inputs, (1, 1), -1)
166 |         ts = ts.squeeze(-1)
167 |         xs = xs.squeeze(-1) * scales
168 |
169 |         us = degrees * functional.relu(ts - xs)
170 |         ys = us.pow(1 / degrees)
171 |         return rearrange(ys, "nf nt_nx -> nf nt_nx 1")
172 |
173 |     def shock_points(self, i: int, x_of_interest):
174 |         scale = self.scales[i]
175 |         return torch.tensor(x_of_interest) * scale
176 |
177 |     def _mass_rhs(self):
178 |         degree = self.degrees
179 |         degree = rearrange(degree, "nf -> nf 1")
180 |         input_targets = rearrange(
181
| self.tensors["input_targets"], "nf (nt nx) d -> nf nt nx d", nt=self.n_targets_t 182 | ) 183 | ts = input_targets[:, :, 0, 0] 184 | return mass_pme(degree, ts) 185 | 186 | def _sample_parameters(self, n_functions: int): 187 | if self._degrees is None: 188 | degrees = self.degree_min + torch.rand(n_functions) * ( 189 | self.degree_max - self.degree_min 190 | ) 191 | else: 192 | n_functions_per_degree = n_functions // self._degrees.shape[0] 193 | degrees = repeat(self._degrees, "nd -> (nd nfpd)", nfpd=n_functions_per_degree) 194 | scales = self._sample_scales(n_functions, self.scale_lims) 195 | return torch.stack((degrees, scales), -1) 196 | 197 | def _sample_scales(self, n_functions, scale_lims): 198 | min_scale_log = torch.tensor(scale_lims[0]).log() 199 | max_scale_log = torch.tensor(scale_lims[1]).log() 200 | scales_log = min_scale_log + torch.rand(n_functions) * (max_scale_log - min_scale_log) 201 | return torch.exp(scales_log) 202 | 203 | 204 | def mass_pme(degree: Tensor, ts: Tensor): 205 | a1 = 1 + (1 / degree) 206 | return (degree.pow(a1)) / (degree + 1) * ts.pow(a1) 207 | 208 | 209 | class Stefan: 210 | def __init__(self, p_star=0.5): 211 | self.p_star = p_star 212 | self.k_min = 0 213 | self.k_max = 1 214 | 215 | self._z1: Optional[float] = None 216 | self._alpha: Optional[float] = None 217 | 218 | def true_solution(self, inputs: np.ndarray): 219 | ts, xs = np.split(inputs, 2, -1) 220 | p1 = self.p1(xs, ts) 221 | p2 = self.p2(xs, ts) 222 | x_star = self.alpha * np.sqrt(ts) 223 | p = p1 * (xs <= x_star) + p2 * (xs > x_star) 224 | p[np.isclose(xs, 0)] = 1.0 225 | return p.squeeze(-1) 226 | 227 | def shock_points(self, x_of_interest): 228 | return np.power(x_of_interest / self.alpha, 2) 229 | 230 | def mass_stefan(self, ts: Tensor) -> Tensor: 231 | k_max = self.k_max 232 | c1 = self.c1 233 | a1: float = 2 * np.sqrt(k_max / np.pi) 234 | 235 | return a1 * c1 * torch.sqrt(ts) 236 | 237 | def p1(self, x: np.ndarray, t: np.ndarray) -> np.ndarray: 238 | a = x / (2 * np.sqrt(self.k_max * t)) 239 | return 1 - self.c1 * erf(a) 240 | 241 | def p2(self, x: np.ndarray, t: np.ndarray) -> np.ndarray: 242 | if self.k_min == 0: 243 | return np.zeros_like(x) 244 | a = x / (2 * np.sqrt(self.k_min * t)) 245 | return self.c2 * (1 - erf(a)) 246 | 247 | @property 248 | def c1(self) -> float: 249 | num = 1 - self.p_star 250 | dem = erf(self.alpha / (2 * (np.sqrt(self.k_max)))) 251 | return num / dem 252 | 253 | @property 254 | def c2(self) -> float: 255 | num = self.p_star 256 | a = self.alpha / (2 * np.sqrt(self.k_min)) 257 | dem = 1 - erf(a) 258 | return num / dem 259 | 260 | @property 261 | def alpha(self) -> float: 262 | if self._alpha is None: 263 | self._alpha = 2 * np.sqrt(self.k_max) * self.z1 264 | return self._alpha 265 | 266 | @property 267 | def z1(self) -> float: 268 | if self._z1 is None: 269 | self._z1 = root_scalar(self._z1_objective, bracket=(0, 10)).root 270 | return self._z1 271 | 272 | def _z1_objective(self, z1): 273 | a1 = self.p_star * erf(z1) 274 | a2 = z1 * np.exp(np.power(z1, 2)) 275 | b = (1 - self.p_star) / np.sqrt(np.pi) 276 | return (a1 * a2) - b 277 | 278 | 279 | class StefanPME(GeneralizedPorousMediumEquation): 280 | def __init__( 281 | self, 282 | n_functions, 283 | *args, 284 | p_star_lim: Tuple[float, float] = (0.1, 0.9), 285 | p_stars: Optional[Tuple[float, ...]] = None, 286 | **kwargs, 287 | ): 288 | if p_stars is None: 289 | self._p_stars = None 290 | self.p_star_lim = p_star_lim 291 | else: 292 | self._p_stars = torch.tensor(p_stars) 293 | self.p_star_lim = 
(min(p_stars), max(p_stars)) 294 | super().__init__(n_functions, *args, **kwargs) 295 | 296 | @property 297 | def stefans(self) -> List[Stefan]: 298 | return self._parameters 299 | 300 | @property 301 | def params(self) -> Tensor: 302 | params = torch.tensor([s.p_star for s in self.stefans]) 303 | return params.reshape(-1, 1) 304 | 305 | def solution(self, inputs: Tensor): 306 | # inputs: nf (nt nx) 2 307 | # outputs: nf (nt nx) 1 308 | assert len(inputs.shape) == 3 309 | nf = len(self.stefans) 310 | if inputs.shape[0] == 1: 311 | inputs = inputs.expand(nf, -1, -1) 312 | soln_list = [] 313 | for in_tensor, stefan in zip(inputs, self.stefans): 314 | soln_i = torch.from_numpy(stefan.true_solution(in_tensor.numpy())) 315 | soln_i = soln_i.float() 316 | soln_list.append(soln_i) 317 | return torch.stack(soln_list, dim=0).unsqueeze(-1) 318 | 319 | def shock_points(self, i: int, x_of_interest): 320 | stefan = self.stefans[i] 321 | shock_points = stefan.shock_points(np.array(x_of_interest)) 322 | return torch.from_numpy(shock_points) 323 | 324 | def _mass_rhs(self): 325 | input_targets = rearrange( 326 | self.tensors["input_targets"], "nf (nt nx) d -> nf nt nx d", nt=self.n_targets_t 327 | ) 328 | ts = input_targets[:, :, 0, 0] 329 | stefans = self.stefans 330 | masses: List[Tensor] = [] 331 | for ts_i, stefan in zip(ts, stefans): 332 | mass = stefan.mass_stefan(ts_i) 333 | masses.append(mass) 334 | return torch.stack(masses) 335 | 336 | def _sample_parameters(self, n_functions: int): 337 | if self._p_stars is None: 338 | a = self.p_star_lim[0] 339 | b_minus_a = self.p_star_lim[1] - self.p_star_lim[0] 340 | p_stars = a + torch.rand(n_functions) * b_minus_a 341 | else: 342 | n_functions_per_pstar = n_functions // self._p_stars.shape[0] 343 | p_stars = repeat(self._p_stars, "npstar -> (npstar nfpps)", nfpps=n_functions_per_pstar) 344 | return [Stefan(p_star=p.item()) for p in p_stars] 345 | 346 | 347 | def sample_ts(n_functions: int, n_t: int, mode: str, t_range: Tuple[float, float]) -> Tensor: 348 | if mode == "contexts": 349 | us = sample_points_from_each_interval(n_functions, n_t) 350 | elif mode == "targets": 351 | us = sample_points_uniformly(n_functions, n_t) 352 | us, _ = torch.sort(us) 353 | t_min, t_max = t_range 354 | return t_min + us * (t_max - t_min) 355 | 356 | 357 | def make_dense_grid(nt, nx): 358 | dt = torch.linspace(0, 1, nt).unsqueeze(0) 359 | dx = torch.linspace(0, 1, nx).unsqueeze(0) 360 | return meshgrid(dt, dx) 361 | 362 | 363 | def sample_xs(n_functions: int, n_x: int, x_range: Tuple[float, float]) -> Tensor: 364 | us = sample_points_uniformly(n_functions, n_x) 365 | us, _ = torch.sort(us) 366 | x_min, x_max = x_range 367 | return x_min + us * (x_max - x_min) 368 | 369 | 370 | class HeatEquation(GeneralizedPorousMediumEquation): 371 | def __init__( 372 | self, 373 | n_functions: int, 374 | *args, 375 | conductivity_min=1, 376 | conductivity_max=5, 377 | conductivities=None, 378 | **kwargs, 379 | ): 380 | if conductivities is None: 381 | self._conductivities = None 382 | self.conductivity_min = conductivity_min 383 | self.conductivity_max = conductivity_max 384 | else: 385 | self._conductivities = torch.tensor(conductivities) 386 | self.conductivity_min = self._conductivities.min().item() 387 | self.conductivity_max = self._conductivities.max().item() 388 | 389 | self.nx_soln = 512 390 | x_range = (0, 2 * np.pi) 391 | kwargs["x_range"] = x_range 392 | 393 | super().__init__(n_functions, *args, **kwargs) 394 | 395 | @property 396 | def conductivities(self) -> Tensor: 397 | 
return self._parameters[:, 0] 398 | 399 | @property 400 | def thetas(self) -> Tensor: 401 | return self._parameters[:, 1] 402 | 403 | @property 404 | def params(self) -> Tensor: 405 | return self._parameters 406 | 407 | def solution(self, inputs: Tensor): 408 | nf = len(self.thetas) 409 | if inputs.shape[0] == 1: 410 | inputs = inputs.expand(nf, -1, -1) 411 | assert inputs.shape[0] == nf 412 | return self.true_solution(inputs, self.thetas, self.conductivities) 413 | 414 | def true_solution(self, inputs: Tensor, thetas: Tensor, nus: Tensor) -> Tensor: 415 | ts = inputs[:, :, 0] 416 | xs = inputs[:, :, 1] 417 | nf = len(thetas) 418 | xs = xs.unique(dim=1).reshape(nf, -1) 419 | 420 | ts = ts.unique(dim=1).reshape(nf, -1) 421 | tr_all = self.convection_onedim(ts, thetas, nus) 422 | 423 | nt_nx = inputs.shape[1] 424 | nt = int(np.sqrt(nt_nx)) 425 | 426 | grid = rearrange(inputs, "nf (nt nx) d -> nf nt nx d", nt=nt).clone() 427 | grid[:, :, :, 1] /= np.pi * 2 428 | grid = (grid - 0.5) * 2 429 | # (h w) to (x y) 430 | grid_x = grid[:, :, :, 1] 431 | grid_y = grid[:, :, :, 0] 432 | grid = torch.stack((grid_x, grid_y), dim=-1) 433 | 434 | tr = functional.grid_sample( 435 | tr_all.unsqueeze(1).float(), grid, align_corners=True, mode="bilinear" 436 | ).squeeze(1) 437 | tr = rearrange(tr, "nf nt nx -> nf (nt nx) 1") 438 | return tr.float() 439 | 440 | def convection_onedim(self, t_values: Tensor, thetas: Tensor, nus: Tensor): 441 | n_function_draws = thetas.shape[0] 442 | u_list = [] 443 | for i in range(n_function_draws): 444 | u_i = self._convection_onedim_for_one_parameter( 445 | t_values[i, :], 446 | thetas[i].item(), 447 | nus[i].item(), 448 | ) 449 | u_list.append(u_i) 450 | u = np.stack(u_list, axis=0) 451 | return torch.from_numpy(u) 452 | 453 | def _convection_onedim_for_one_parameter(self, t_values, theta: float, nu: float): 454 | two_pi = 2 * np.pi 455 | dx = two_pi / self.nx_soln 456 | x_grid = np.arange(0, two_pi, dx) 457 | x_start = np.sin(x_grid + theta) 458 | t_grid = repeat(t_values, "nt -> nt nx", nx=self.nx_soln) 459 | t_grid = t_grid.numpy() 460 | return convection_diffusion_solution(x_start, t_grid, nu, beta=0) 461 | 462 | def _mass_rhs(self): 463 | cs = self.conductivities 464 | cs = rearrange(cs, "nf -> nf 1") 465 | return torch.zeros_like(cs) 466 | 467 | def _sample_parameters(self, n_functions: int): 468 | if self._conductivities is None: 469 | degrees = self.conductivity_min + torch.rand(n_functions) * ( 470 | self.conductivity_max - self.conductivity_min 471 | ) 472 | else: 473 | n_functions_per_degree = n_functions // self._conductivities.shape[0] 474 | degrees = repeat(self._conductivities, "nd -> (nd nfpd)", nfpd=n_functions_per_degree) 475 | thetas = torch.zeros(n_functions) 476 | return torch.stack((degrees, thetas), -1) 477 | 478 | def _weighted_average(self, xs, full_output: Tensor): 479 | _, nt, _ = full_output.shape 480 | indx_floating = xs / (2 * np.pi) * (self.nx_soln - 1) 481 | 482 | indx_lower = torch.floor(indx_floating).to(torch.int64) 483 | indx_lower = repeat(indx_lower, "nf nx -> nf nt nx", nt=nt) 484 | 485 | indx_higher = indx_lower + 1 486 | w_lower = indx_higher - indx_floating.unsqueeze(1) 487 | w_higher = indx_floating.unsqueeze(1) - indx_lower 488 | padding = torch.zeros_like(full_output)[:, :, 0:1] 489 | full_output = torch.cat((full_output, padding), -1) 490 | full_output_lower = torch.gather(full_output, dim=-1, index=indx_lower) 491 | full_output_higher = torch.gather(full_output, dim=-1, index=indx_higher) 492 | 493 | return full_output_lower * 
w_lower + full_output_higher * w_higher 494 | 495 | 496 | class LinearAdvection(GeneralizedPorousMediumEquation): 497 | def __init__( 498 | self, 499 | n_functions, 500 | *args, 501 | a_lim: Tuple[float, float] = (1, 10), 502 | a_vals: Optional[Tuple[float, ...]] = None, 503 | **kwargs, 504 | ): 505 | if a_vals is None: 506 | self._a_vals = None 507 | self.a_lim = a_lim 508 | else: 509 | self._a_vals = torch.tensor(a_vals) 510 | self.a_lim = (min(a_vals), max(a_vals)) 511 | super().__init__(n_functions, *args, **kwargs) 512 | 513 | @property 514 | def params(self) -> Tensor: 515 | return self._parameters.unsqueeze(-1) 516 | 517 | def solution(self, inputs: Tensor): 518 | # inputs: nf (nt nx) 2 519 | # outputs: nf (nt nx) 1 520 | assert len(inputs.shape) == 3 521 | nf = self.parameters.shape[0] 522 | if inputs.shape[0] == 1: 523 | inputs = inputs.expand(nf, -1, -1) 524 | t = inputs[:, :, 0] 525 | x = inputs[:, :, 1] 526 | a = rearrange(self.parameters, "nf -> nf 1") 527 | u = self.h(x - t * a) 528 | return rearrange(u, "nf nt_nx -> nf nt_nx 1") 529 | 530 | def h(self, x: Tensor): 531 | return (x <= 0.5).float() # noqa: WPS459 532 | 533 | def mass_advection(self, ts: Tensor, a_vals: Tensor): 534 | max_density_tnsr = torch.tensor(0.5, device=ts.device) 535 | return 0.5 + torch.minimum(ts * a_vals, max_density_tnsr) 536 | 537 | def _mass_rhs(self): 538 | input_targets = rearrange( 539 | self.tensors["input_targets"], "nf (nt nx) d -> nf nt nx d", nt=self.n_targets_t 540 | ) 541 | ts = input_targets[:, :, 0, 0] 542 | a_vals = self.parameters.reshape(-1, 1) 543 | return self.mass_advection(ts, a_vals) 544 | 545 | def _sample_parameters(self, n_functions: int): 546 | if self._a_vals is None: 547 | a = self.a_lim[0] 548 | b_minus_a = self.a_lim[1] - self.a_lim[0] 549 | a_vals = a + torch.rand(n_functions) * b_minus_a 550 | else: 551 | n_functions_per_pstar = n_functions // self._a_vals.shape[0] 552 | a_vals = repeat(self._a_vals, "npstar -> (npstar nfpps)", nfpps=n_functions_per_pstar) 553 | return a_vals 554 | 555 | 556 | class Burgers(GeneralizedPorousMediumEquation): 557 | def __init__( 558 | self, 559 | n_functions, 560 | *args, 561 | a_lim: Tuple[float, float] = (1, 5), 562 | a_vals: Optional[Tuple[float, ...]] = None, 563 | **kwargs, 564 | ): 565 | if a_vals is None: 566 | self._a_vals = None 567 | self.a_lim = a_lim 568 | else: 569 | self._a_vals = torch.tensor(a_vals) 570 | self.a_lim = (min(a_vals), max(a_vals)) 571 | 572 | x_range = (-1, 1) 573 | kwargs["x_range"] = x_range 574 | super().__init__(n_functions, *args, **kwargs) 575 | 576 | @property 577 | def params(self) -> Tensor: 578 | return self._parameters.unsqueeze(-1) 579 | 580 | def solution(self, inputs: Tensor): 581 | # inputs: nf (nt nx) 2 582 | # outputs: nf (nt nx) 1 583 | nf = self.parameters.shape[0] 584 | if inputs.shape[0] == 1: 585 | inputs = inputs.expand(nf, -1, -1) 586 | t = inputs[:, :, 0] 587 | x = inputs[:, :, 1] 588 | a = rearrange(self.parameters, "nf -> nf 1") 589 | u = solution_burgers(t, x, a) 590 | return rearrange(u, "nf nt_nx -> nf nt_nx 1") 591 | 592 | def _mass_rhs(self): 593 | input_targets = rearrange( 594 | self.tensors["input_targets"], "nf (nt nx) d -> nf nt nx d", nt=self.n_targets_t 595 | ) 596 | ts = input_targets[:, :, 0, 0] 597 | a_vals = self.parameters.reshape(-1, 1) 598 | return mass_burgers(ts, a_vals) 599 | 600 | def _sample_parameters(self, n_functions: int): 601 | if self._a_vals is None: 602 | a = self.a_lim[0] 603 | b_minus_a = self.a_lim[1] - self.a_lim[0] 604 | a_vals = a + 
torch.rand(n_functions) * b_minus_a 605 | else: 606 | n_functions_per_pstar = n_functions // self._a_vals.shape[0] 607 | a_vals = repeat(self._a_vals, "npstar -> (npstar nfpps)", nfpps=n_functions_per_pstar) 608 | return a_vals 609 | 610 | 611 | def solution_burgers(t: Tensor, x: Tensor, a: Tensor): 612 | break_time = a.pow(-1) 613 | u_prebreak = _solution_burgers_prebreak(t, x, a) 614 | u_postbreak = _solution_burgers_postbreak(t, x, a) 615 | return u_prebreak * (t <= break_time) + u_postbreak * (t > break_time) 616 | 617 | 618 | def _solution_burgers_prebreak(t: Tensor, x: Tensor, a: Tensor) -> Tensor: 619 | c1 = x <= ((a * t) - 1) 620 | u1 = a * c1 621 | c2 = torch.logical_and(~c1, x <= 0) 622 | u2 = (a * x) / (a * t - 1) * c2 623 | # zero if above u3 624 | return u1 + u2 625 | 626 | 627 | def _solution_burgers_postbreak(t: Tensor, x: Tensor, a: Tensor) -> Tensor: 628 | c = x <= (0.5 * (a * t - 1)) 629 | return c * a 630 | 631 | 632 | def mass_burgers(ts: Tensor, a_vals: Tensor): 633 | return (a_vals / 2) * (1 + (a_vals * ts)) 634 | 635 | 636 | if __name__ == "__main__": 637 | pme = PorousMediumEquation(10, 10, 10, 10, 10, 10) 638 | -------------------------------------------------------------------------------- /deep_pdes/experiments/1b_pme_train_pinp_multi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | EXPERIMENT=1b_pme_var_m 3 | # for l in 1en2 1en1 1e1 1e2 1e6 4 | for l in 1e6 5 | do 6 | echo "Training ANP+SoftC with lambda="${l} 7 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_pinp_${l} 8 | done 9 | -------------------------------------------------------------------------------- /deep_pdes/experiments/1b_pme_var_m.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT=1b_pme_var_m 2 | 3 | python generate.py +experiments=$EXPERIMENT 4 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_anp 5 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_pinp 6 | python analyze.py +experiments=$EXPERIMENT 7 | -------------------------------------------------------------------------------- /deep_pdes/experiments/2b_stefan_var_p.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT=2b_stefan_var_p 2 | 3 | python generate.py +experiments=$EXPERIMENT 4 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_anp 5 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_pinp 6 | python analyze.py +experiments=$EXPERIMENT 7 | -------------------------------------------------------------------------------- /deep_pdes/experiments/3b_heat_var_c.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #set -o errexit 3 | 4 | EXPERIMENT=3b_heat_var_c 5 | 6 | python generate.py +experiments=$EXPERIMENT 7 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_anp 8 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_pinp 9 | python analyze.py +experiments=$EXPERIMENT 10 | -------------------------------------------------------------------------------- /deep_pdes/experiments/4b_advection_var_a.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #set -o errexit 3 | 4 | EXPERIMENT=4b_advection_var_a 5 | 6 | python generate.py +experiments=$EXPERIMENT 7 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_anp 8 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_pinp 
9 | python analyze.py +experiments=$EXPERIMENT 10 | -------------------------------------------------------------------------------- /deep_pdes/experiments/5b_burgers_var_a.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #set -o errexit 3 | 4 | EXPERIMENT=5b_burgers_var_a 5 | 6 | python generate.py +experiments=$EXPERIMENT 7 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_pinp 8 | python train.py +experiments=$EXPERIMENT +train=${EXPERIMENT}_anp 9 | python analyze.py +experiments=$EXPERIMENT 10 | -------------------------------------------------------------------------------- /deep_pdes/experiments/analyze.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | from typing import Dict, List, Optional, Union 4 | 5 | import hydra 6 | import numpy as np 7 | import pandas as pd 8 | import plotnine as p9 9 | import torch 10 | from einops import rearrange, reduce, repeat 11 | from hydra.utils import instantiate 12 | from torch import Tensor 13 | from torch.nn import functional as F # noqa: WPS347 14 | from tqdm import tqdm 15 | 16 | from deep_pdes.attentive_neural_process.anp import ANP 17 | from deep_pdes.attentive_neural_process.probconserv import ( 18 | InequalityConstraint, 19 | PhysNP, 20 | apply_non_linear_ineq_constraint, 21 | ) 22 | from deep_pdes.attentive_neural_process.softc import PINP, SoftcANP 23 | from deep_pdes.datasets.base import ANPDataset 24 | from deep_pdes.datasets.pme import ( 25 | Burgers, 26 | HeatEquation, 27 | LinearAdvection, 28 | PorousMediumEquation, 29 | Stefan, 30 | StefanPME, 31 | mass_burgers, 32 | mass_pme, 33 | meshgrid, 34 | ) 35 | 36 | 37 | @hydra.main(version_base=None, config_path="conf", config_name="config") 38 | def main(cfg): 39 | outdir = Path(cfg.analysis.outdir) 40 | if not outdir.exists(): 41 | outdir.mkdir(parents=True) 42 | 43 | preds = infer(cfg) 44 | 45 | dataset_path = Path(cfg.datasets.save_path) 46 | test_load_path = dataset_path / "test.pt" 47 | test_dataset = instantiate(cfg.datasets.test, load_path=test_load_path) 48 | _, output_targets, _ = get_test_solution(cfg, test_dataset) 49 | mse_at_t_df = analyze_mean_squared_error(cfg, preds, output_targets, test_dataset) 50 | mse_at_t_df.to_pickle(cfg.analysis.mse_at_t_df_path) 51 | 52 | plot_df, true_df = make_plotting_dfs(cfg, preds, output_targets, test_dataset) 53 | plot_df.to_pickle(cfg.analysis.plot_df_path) 54 | true_df.to_pickle(cfg.analysis.true_df_path) 55 | 56 | cons_df, true_cons_df = analyze_conservation(cfg, test_dataset, plot_df) 57 | cons_df.to_pickle(cfg.analysis.cons_df_path) 58 | true_cons_df.to_pickle(cfg.analysis.true_cons_df_path) 59 | 60 | if ( 61 | isinstance(test_dataset, StefanPME) 62 | or isinstance(test_dataset, PorousMediumEquation) 63 | or isinstance(test_dataset, LinearAdvection) 64 | ): 65 | shocks_all = analyze_shocks(cfg, preds, test_dataset) 66 | else: 67 | shocks_all = None 68 | plot_context_points(cfg, test_dataset, plot_df, true_df) 69 | analyze_solution_profiles(cfg, test_dataset, plot_df, true_df) 70 | 71 | inference_results = { 72 | "shocks_all": shocks_all, 73 | } 74 | torch.save(inference_results, cfg.analysis.inference_results) 75 | 76 | 77 | def analyze_mean_squared_error( 78 | cfg, preds: Dict[str, Tensor], output_targets: Tensor, test_dataset: ANPDataset 79 | ) -> pd.DataFrame: 80 | outdir = Path(cfg.analysis.outdir) 81 | if not outdir.exists(): 82 | outdir.mkdir() 83 | t_range 
= cfg.analysis.t_range 84 | params = get_params(cfg, test_dataset) 85 | n_p_stars = len(params) 86 | params_ordered = cfg.analysis.get("params_ordered", params) 87 | nice_names = cfg.analysis.nice_names 88 | models_of_interest: List[str] = list(nice_names.keys()) 89 | 90 | all_preds = torch.stack([p["pred"] for p in preds.values()]) 91 | error = all_preds - rearrange(output_targets, "nf nx_nt -> 1 nf nx_nt") 92 | squared_error = error.pow(2) 93 | 94 | mse_by_t_and_fid = reduce( 95 | squared_error, 96 | "nm (nd nfpd) (nx nt) -> nm nd nfpd nt", 97 | "mean", 98 | nd=n_p_stars, 99 | nx=cfg.analysis.nx, 100 | ) 101 | mse_by_t = reduce(mse_by_t_and_fid, "nm nd nfpd nt -> nm nd nt", "mean", nd=n_p_stars) 102 | mse_sd_by_t = reduce( 103 | mse_by_t_and_fid, 104 | "nm nd nfpd nt -> nm nd nt", 105 | torch.std, 106 | nd=n_p_stars, 107 | ) 108 | nfpd = squared_error.shape[1] / n_p_stars 109 | mse_se_by_t = mse_sd_by_t / torch.tensor(nfpd).sqrt() 110 | ts = np.linspace(*t_range, cfg.analysis.nt) 111 | dim_values = [preds.keys(), params, ts] 112 | dim_names = ["model", "param", "t"] 113 | midx = pd.MultiIndex.from_product(dim_values, names=dim_names) 114 | mse_all_by_t = torch.cat( 115 | [mse_by_t.reshape(-1, 1), mse_sd_by_t.reshape(-1, 1), mse_se_by_t.reshape(-1, 1)], dim=-1 116 | ) 117 | mse_at_t_long_df = pd.DataFrame( 118 | mse_all_by_t, index=midx, columns=["MSE", "MSE_sd", "MSE_se"] 119 | ).reset_index() 120 | 121 | # scales = torch.stack([p["scale"].reshape(p["scale"].shape[0], -1) if p["scale"] is not None else torch.tensor(torch.nan).expand(p["pred"].shape) for p in preds.values()]) 122 | def _get_scale(model, preds): 123 | if "physnp" in model: 124 | # loc = preds[model]["loc"] 125 | # cov = preds["physnp_notrain"]["cov"] 126 | scale = rearrange(preds["scale"], "nf nt nx 1 -> nf (nx nt)", nx=cfg.analysis.nx) 127 | elif ("anp" in model) or ("pinp" in model): 128 | # scale = rearrange(preds[model]["scale"], "nf (nx nt) 1 -> nf nt nx 1", nx=cfg.analysis.nx) 129 | scale = rearrange(preds["scale"], "nf (nx nt) 1 -> nf (nx nt)", nx=cfg.analysis.nx) 130 | elif "hcnp" in model: 131 | scale = rearrange(preds["scale"], "nf nt nx 1 -> nf (nx nt)", nx=cfg.analysis.nx) 132 | else: 133 | scale = torch.tensor(torch.nan).expand(preds["pred"].shape) 134 | return scale 135 | 136 | scales = torch.stack([_get_scale(k, v) for (k, v) in preds.items()]) 137 | loglik = -0.5 * squared_error / scales.pow(2) - scales.log() - 0.5 * np.log(2 * torch.pi) 138 | loglik_by_t_and_fid = reduce( 139 | loglik, "nm (nd nfpd) (nx nt) -> nm nd nfpd nt", "mean", nd=n_p_stars, nx=cfg.analysis.nx 140 | ) 141 | loglik_by_t = reduce(loglik_by_t_and_fid, "nm nd nfpd nt -> nm nd nt", "mean") 142 | loglik_sd_by_t = reduce(loglik_by_t_and_fid, "nm nd nfpd nt -> nm nd nt", torch.std) 143 | loglik_se_by_t = loglik_sd_by_t / torch.tensor(nfpd).sqrt() 144 | 145 | loglik_at_fid = rearrange( 146 | loglik, "nm (nd nfpd) (nx nt) -> nm nd nfpd nx nt", nd=n_p_stars, nx=cfg.analysis.nx 147 | )[:, :, 1] 148 | loglik_by_t_at_fid = reduce(loglik_at_fid, "nm nd nx nt -> nm nd nt", "mean") 149 | 150 | mse_at_t_long_df["loglik"] = loglik_by_t.reshape(-1, 1) 151 | mse_at_t_long_df["loglik_sd"] = loglik_sd_by_t.reshape(-1, 1) 152 | mse_at_t_long_df["loglik_se"] = loglik_se_by_t.reshape(-1, 1) 153 | mse_at_t_long_df["loglik_fid"] = loglik_by_t_at_fid.reshape(-1, 1) 154 | mse_at_t_long_df = ( 155 | mse_at_t_long_df.loc[np.isin(mse_at_t_long_df.model, models_of_interest)] 156 | .assign(model=lambda df: pd.Categorical(df.model, models_of_interest, ordered=True)) 
157 | .assign(param=lambda df: pd.Categorical(df.param, params_ordered, ordered=True)) 158 | ) 159 | 160 | t_of_interest = cfg.analysis.t_of_interest[0] 161 | m: pd.DataFrame = mse_at_t_long_df.loc[np.isclose(mse_at_t_long_df.t, t_of_interest)] 162 | m.to_csv(outdir / "mse_at_t.csv") 163 | 164 | labeller = make_labeller(cfg) 165 | mse_at_t_plot = ( 166 | p9.ggplot(mse_at_t_long_df, p9.aes(x="t", y="MSE", color="model")) # noqa: WPS221 167 | + p9.geom_smooth(span=0.05, se=False) 168 | + p9.facet_grid("~param", labeller=labeller) 169 | + p9.theme_bw(base_size=cfg.analysis.base_font_size) 170 | + p9.scale_color_hue(labels=nice_names.values()) 171 | # + p9.guides(color=None) 172 | + p9.theme( 173 | strip_text=p9.element_text(usetex=True), legend_position="bottom", legend_title=None 174 | ) 175 | ) 176 | 177 | plot_dir = outdir / "plots" 178 | if not plot_dir.exists(): 179 | plot_dir.mkdir() 180 | for ext in ("pdf", "png"): 181 | mse_at_t_plot.save( 182 | plot_dir / f"mse_over_time.{ext}", 183 | dpi=cfg.analysis.dpi, 184 | width=cfg.analysis.mse_plot_width, 185 | height=cfg.analysis.mse_plot_height, 186 | ) 187 | 188 | return mse_at_t_long_df 189 | 190 | 191 | def _get_loc_and_cov(model: str, preds: Dict[str, Dict[str, Tensor]], nx: int): 192 | if ("physnp_notrain" in model) or ("physnp_limit" in model) or ("physnp_second_deriv" in model): 193 | loc = preds[model]["loc"] 194 | cov = preds[model]["cov"] 195 | elif ("anp" in model) or ("pinp" in model): 196 | loc = rearrange(preds[model]["loc"], "nf (nx nt) 1 -> nf nt nx 1", nx=nx) 197 | scale = rearrange(preds[model]["scale"], "nf (nx nt) 1 -> nf nt nx 1", nx=nx) 198 | eye = rearrange(torch.eye(nx), "nx1 nx2 -> 1 1 nx1 nx2") 199 | cov = scale * eye 200 | elif "hcnp" in model: 201 | loc = rearrange(preds[model]["loc"], "nf nt nx 1 -> nf nt nx 1", nx=nx) 202 | scale = rearrange(preds[model]["scale"], "nf nt nx 1 -> nf nt nx 1", nx=nx) 203 | eye = rearrange(torch.eye(nx), "nx1 nx2 -> 1 1 nx1 nx2") 204 | cov = scale * eye 205 | else: 206 | loc = None 207 | cov = None 208 | return loc, cov 209 | 210 | 211 | def analyze_shocks(cfg, preds: Dict[str, Tensor], test_dataset: ANPDataset) -> Tensor: 212 | outdir = Path(cfg.analysis.outdir) 213 | if not outdir.exists(): 214 | outdir.mkdir() 215 | t_range = cfg.analysis.t_range 216 | if isinstance(test_dataset, PorousMediumEquation): 217 | params = cfg.datasets.test.degrees 218 | models = params # shock is constant 219 | elif isinstance(test_dataset, StefanPME): 220 | params = cfg.datasets.test.p_stars 221 | models = [Stefan(p) for p in params] 222 | elif isinstance(test_dataset, LinearAdvection): 223 | params = cfg.datasets.test.a_vals 224 | models = params 225 | else: 226 | raise NotImplementedError() 227 | params_ordered = cfg.analysis.get("params_ordered", params) 228 | nice_names = cfg.analysis.nice_names 229 | x_range = cfg.analysis.x_range 230 | t_of_interest = cfg.analysis.t_of_interest 231 | nx = cfg.analysis.nx 232 | fids_of_interest = cfg.analysis.fids_of_interest 233 | shocks = {} 234 | n_shock_samples = cfg.analysis.n_shock_samples 235 | n_shock_samples_per_batch: Optional[int] = cfg.analysis.get("n_shock_samples_per_batch") 236 | for model in nice_names.keys(): 237 | shock_path_str = cfg.analysis.methods[model].get("shock_path", None) 238 | shock_overwrite = cfg.analysis.methods[model].get("shock_overwrite", False) 239 | if shock_path_str is not None: 240 | shock_path = Path(shock_path_str) 241 | else: 242 | shock_path = None 243 | if (shock_path is None) or (not shock_path.exists()) or 
shock_overwrite: 244 | loc, cov = _get_loc_and_cov(model, preds, nx) 245 | if cov is None: 246 | continue 247 | s = estimate_shock_interval( 248 | loc.cuda(), 249 | cov.cuda(), 250 | n_samples=n_shock_samples, 251 | n_samples_per_batch=n_shock_samples_per_batch, 252 | ) 253 | s = rearrange(s, "ns (nd nfpd) nt -> ns nd nfpd nt", nd=len(models)) 254 | s = (s / nx) * (x_range[1]) 255 | shocks[nice_names[model]] = s.cpu() 256 | if shock_path is not None: 257 | torch.save(s.cpu(), shock_path) 258 | else: 259 | shocks[nice_names[model]] = torch.load(shock_path) 260 | shocks_all: Tensor = torch.stack(tuple(shocks.values())) 261 | for fid in fids_of_interest: 262 | t = t_of_interest[0] 263 | t_idx = int((t / t_range[1]) * (nx - 1)) 264 | dfs = [] 265 | true_shock_dfs = [] 266 | shocks_at_t_and_fid = shocks_all[:, :, :, int(fid), t_idx] 267 | shocks_at_t_and_fid = rearrange(shocks_at_t_and_fid, "nm nsamples nd -> (nm nsamples nd) 1") 268 | midx = pd.MultiIndex.from_product( 269 | [shocks.keys(), range(n_shock_samples), params], names=["model", "sample", "param"] 270 | ) 271 | shock_df = ( 272 | pd.DataFrame(shocks_at_t_and_fid, index=midx, columns=["shock_position"]) 273 | .reset_index() 274 | .assign(param=lambda df: pd.Categorical(df.param, params_ordered, ordered=True)) 275 | ) 276 | if isinstance(test_dataset, StefanPME): 277 | for i, model in enumerate(models): 278 | true_shock = model.alpha * np.sqrt(t_idx / nx * 0.1) 279 | min_stefan = Stefan(p_star=cfg.datasets.p_star_max) 280 | min_shock = min_stefan.alpha * np.sqrt(t_idx / nx * 0.1) 281 | max_stefan = Stefan(p_star=cfg.datasets.p_star_min) 282 | max_shock = max_stefan.alpha * np.sqrt(t_idx / nx * 0.1) 283 | true_shock_dfs.append( 284 | pd.DataFrame( 285 | { 286 | "param": np.array([model.p_star]), 287 | "true_shock": np.array([true_shock]), 288 | "max_shock": np.array([max_shock]), 289 | } 290 | ) 291 | ) 292 | elif isinstance(test_dataset, PorousMediumEquation): 293 | for i, model in enumerate(models): 294 | true_shock = t 295 | true_shock_dfs.append( 296 | pd.DataFrame({"param": np.array([model]), "true_shock": np.array([true_shock])}) 297 | ) 298 | elif isinstance(test_dataset, LinearAdvection): 299 | for i, param in enumerate(params): 300 | true_shock = 0.5 + t * param 301 | true_shock_dfs.append( 302 | pd.DataFrame({"param": np.array([param]), "true_shock": np.array([true_shock])}) 303 | ) 304 | else: 305 | raise NotImplementedError() 306 | 307 | true_shock_df = pd.concat(true_shock_dfs, axis=0).assign( 308 | param=lambda df: pd.Categorical(df.param, params_ordered, ordered=True) 309 | ) 310 | labeller = make_labeller(cfg) 311 | shock_plot_i = ( 312 | p9.ggplot( 313 | shock_df, p9.aes(x="shock_position", color="model", fill="model") 314 | ) # noqa: WPS221 315 | + p9.geom_histogram(bins=50) 316 | + p9.geom_vline(data=true_shock_df, mapping=p9.aes(xintercept="true_shock")) 317 | # + p9.annotate("rect", xmin=min_shock, xmax=max_shock, ymin=0, ymax=50, alpha=0.2) 318 | + p9.facet_grid("model~param", labeller=labeller) 319 | + p9.scale_color_hue(labels=nice_names.values()) 320 | + p9.scale_fill_hue(labels=nice_names.values()) 321 | + p9.theme_bw(base_size=cfg.analysis.base_font_size) 322 | + p9.theme( 323 | strip_text=p9.element_text(usetex=True), 324 | legend_position="top", 325 | legend_title=p9.element_blank(), 326 | ) 327 | + p9.xlab("x") 328 | ) 329 | plot_dir = outdir / "plots" 330 | if not plot_dir.exists(): 331 | plot_dir.mkdir() 332 | for ext in ("pdf", "png"): 333 | shock_plot_i.save( 334 | plot_dir / 
f"shock_plot_fid={fid}_t={t:.3f}.{ext}", 335 | width=cfg.analysis.shock_plot_width, 336 | height=cfg.analysis.shock_plot_height, 337 | dpi=cfg.analysis.dpi, 338 | ) 339 | return shocks_all 340 | # return {"shock_plot_i": shock_plot_i} 341 | 342 | 343 | def get_params(cfg, test_dataset: ANPDataset): 344 | if isinstance(test_dataset, PorousMediumEquation): 345 | params = cfg.datasets.test.degrees 346 | elif isinstance(test_dataset, StefanPME): 347 | params = cfg.datasets.test.p_stars 348 | elif isinstance(test_dataset, HeatEquation): 349 | params = cfg.datasets.test.conductivities 350 | elif isinstance(test_dataset, LinearAdvection): 351 | params = cfg.datasets.test.a_vals 352 | elif isinstance(test_dataset, Burgers): 353 | params = cfg.datasets.test.a_vals 354 | else: 355 | raise NotImplementedError() 356 | return params 357 | 358 | 359 | def make_plotting_dfs( 360 | cfg, preds: Dict[str, Tensor], output_targets: Tensor, test_dataset: ANPDataset 361 | ): 362 | plot_dfs = defaultdict(dict) 363 | params = get_params(cfg, test_dataset) 364 | n_p_stars = len(params) 365 | x_range = cfg.analysis.x_range 366 | t_range = cfg.analysis.t_range 367 | models_of_interest = list(cfg.analysis.nice_names.keys()) 368 | params_ordered = cfg.analysis.get("params_ordered", params) 369 | 370 | # nf = next(iter(preds.values()))["pred"] 371 | nx = cfg.analysis.nx 372 | nt = cfg.analysis.nt 373 | pred_loc_all = torch.stack([p["pred"] for p in preds.values()]) 374 | nf = pred_loc_all.shape[1] 375 | pred_sd_list = [] 376 | for model, pred_dict in preds.items(): 377 | scale = pred_dict.get("scale") 378 | if scale is None: 379 | pred_sd = torch.zeros((nf, nx * nt)) 380 | elif len(scale.shape) == 3: 381 | pred_sd = rearrange(scale, "nf nx_nt 1 -> nf nx_nt") 382 | else: 383 | pred_sd = rearrange(scale, "nf nt nx 1 -> nf (nx nt)") 384 | pred_sd_list.append(pred_sd) 385 | pred_sd_all = torch.stack(pred_sd_list) 386 | pred_all = torch.stack((pred_loc_all, pred_sd_all), dim=-1) 387 | nfpd = nf // n_p_stars 388 | pred_all = rearrange(pred_all, "nm nf nx_nt d -> (nm nf nx_nt) d") 389 | xs = np.linspace(*x_range, nx) 390 | ts = np.linspace(*t_range, nt) 391 | midx = pd.MultiIndex.from_product( 392 | (preds.keys(), params, np.array(range(nfpd)).astype(str), xs, ts), 393 | names=("model", "param", "f_id", "x", "t"), 394 | ) 395 | plot_df = ( 396 | pd.DataFrame(pred_all, index=midx, columns=["u", "u_sd"]) 397 | .reset_index() 398 | .assign(param=lambda df: pd.Categorical(df.param, params_ordered, ordered=True)) 399 | ) 400 | plot_df = plot_df.loc[np.isin(plot_df.model, models_of_interest)].assign( 401 | model=lambda df: pd.Categorical(df.model, models_of_interest, ordered=True) 402 | ) 403 | 404 | true_df = pd.DataFrame(output_targets) 405 | # true_df = pd.DataFrame(output_targets.unsqueeze(0)) 406 | true_df["param"] = pd.Categorical(np.repeat(params, nfpd), params_ordered, ordered=True) 407 | true_df["f_id"] = np.tile(range(nfpd), n_p_stars) 408 | true_df["f_id"] = true_df["f_id"].astype(str) 409 | true_df = true_df.melt(id_vars=("param", "f_id"), value_name="u") 410 | true_df["x"] = np.repeat( 411 | np.repeat(np.linspace(x_range[0], x_range[1], cfg.analysis.nx), cfg.analysis.nt), nf 412 | ) 413 | true_df["t"] = np.repeat( 414 | np.tile(np.linspace(t_range[0], t_range[1], cfg.analysis.nt), cfg.analysis.nx), nf 415 | ) 416 | 417 | return plot_df, true_df 418 | 419 | 420 | def analyze_solution_profiles( 421 | cfg, test_dataset: ANPDataset, plot_df: pd.DataFrame, true_df: pd.DataFrame 422 | ): 423 | outdir = 
Path(cfg.analysis.outdir) 424 | if not outdir.exists(): 425 | outdir.mkdir() 426 | plot_dir = outdir / "plots" 427 | if not plot_dir.exists(): 428 | plot_dir.mkdir() 429 | 430 | t_of_interest = cfg.analysis.t_of_interest 431 | params = get_params(cfg, test_dataset) 432 | n_p_stars = len(params) 433 | nf = len(test_dataset) 434 | nfpd = nf // n_p_stars 435 | n_models = len(cfg.analysis.nice_names) 436 | 437 | fids_of_interest = cfg.analysis.fids_of_interest 438 | params_ordered = cfg.analysis.get("params_ordered", params) 439 | 440 | if t_of_interest is not None: 441 | assert len(t_of_interest) == 1, "only look at one time point" 442 | for df in (plot_df, true_df): 443 | df["t_of_interest"] = False 444 | df["t_label"] = "" 445 | for t in t_of_interest: 446 | df.loc[np.isclose(df.t, t), "t_of_interest"] = True 447 | df.loc[np.isclose(df.t, t), "t_label"] = f"{t:.3f}" 448 | plot_df_at_ts = plot_df.loc[plot_df.t_of_interest] 449 | true_df_at_ts = true_df.loc[true_df.t_of_interest] 450 | 451 | labeller = make_labeller(cfg) 452 | plot_at_ts = ( 453 | p9.ggplot(plot_df_at_ts, mapping=p9.aes(x="x", y="u", group="f_id")) 454 | + p9.geom_line(p9.aes(color="f_id"), alpha=0.6) 455 | + p9.geom_line(data=true_df_at_ts, color="blue", linetype="dashed") 456 | + p9.facet_grid("param~model", labeller=labeller) 457 | + p9.theme_bw(base_size=cfg.analysis.base_font_size) 458 | ) 459 | plot_at_ts.save(plot_dir / f"time_plot.pdf", width=8, height=8, dpi=cfg.analysis.dpi) 460 | if fids_of_interest is not None: 461 | for f_id in fids_of_interest: 462 | input_idxs = range(int(f_id), nf, nfpd) 463 | input_contexts = test_dataset.tensors["input_contexts"] 464 | ic = input_contexts[input_idxs][:, :, 1] 465 | ic = rearrange(ic, "nd (nt nx) -> nd nt nx", nx=test_dataset.n_contexts_x)[:, 0] 466 | ic_df = ( 467 | pd.DataFrame(ic.transpose(1, 0), columns=params) 468 | .melt(var_name="param", value_name="x") 469 | .assign(param=lambda df: pd.Categorical(df.param, params_ordered, ordered=True)) 470 | ) 471 | plot_df_f_id = plot_df_at_ts.loc[plot_df_at_ts.f_id == f_id] 472 | true_df_f_id = true_df_at_ts.loc[true_df_at_ts.f_id == f_id] 473 | plot_at_ts_uncertainty = ( 474 | p9.ggplot(plot_df_f_id, mapping=p9.aes(x="x", y="u")) 475 | + p9.geom_line(p9.aes(color="model")) 476 | + p9.geom_ribbon( 477 | p9.aes( 478 | ymin="u - 3 * u_sd", ymax="u + 3 * u_sd", color="model", fill="model" 479 | ), 480 | alpha=0.2, 481 | ) 482 | + p9.geom_line(data=true_df_f_id, color="black", linetype="dashed") 483 | + p9.geom_point( 484 | data=ic_df, 485 | shape="x", 486 | inherit_aes=False, 487 | mapping=p9.aes(x="x", y=0), 488 | alpha=0.7, 489 | ) 490 | + p9.scale_color_hue(labels=["ANP", "PhysNP"], guide=None) 491 | + p9.scale_fill_hue(labels=["ANP", "PhysNP"], guide=None) 492 | + p9.facet_grid("model~param", labeller=labeller) 493 | + p9.theme_bw(base_size=cfg.analysis.base_font_size) 494 | + p9.theme(strip_text=p9.element_text(usetex=True)) 495 | ) 496 | plot_at_ts_uncertainty.save( 497 | plot_dir / f"time_plot_fid={f_id}_t={t:.2f}.pdf", 498 | width=cfg.analysis.time_plot_width, 499 | height=cfg.analysis.time_plot_height, 500 | dpi=cfg.analysis.dpi, 501 | ) 502 | 503 | 504 | def plot_context_points( 505 | cfg, test_dataset: ANPDataset, plot_df: pd.DataFrame, true_df: pd.DataFrame 506 | ): 507 | outdir = Path(cfg.analysis.outdir) 508 | if not outdir.exists(): 509 | outdir.mkdir() 510 | plot_dir = outdir / "plots" 511 | if not plot_dir.exists(): 512 | plot_dir.mkdir() 513 | ## Plot context points 514 | input_contexts = 
test_dataset.tensors["input_contexts"] 515 | 516 | params = get_params(cfg, test_dataset) 517 | n_p_stars = len(params) 518 | nf = len(test_dataset) 519 | nfpd = nf // n_p_stars 520 | f_ids_of_interest = cfg.analysis.get("fids_of_interest") 521 | params_ordered = cfg.analysis.get("params_ordered", params) 522 | 523 | x_range = cfg.analysis.x_range 524 | t_range = cfg.analysis.t_range 525 | if f_ids_of_interest is not None: 526 | for j, f_id in enumerate(f_ids_of_interest): 527 | input_idxs = range(int(f_id), nf, nfpd) 528 | ic = input_contexts[input_idxs] 529 | ic = rearrange(ic, "nd (nt nx) d -> (nd nt nx) d", nx=test_dataset.n_contexts_x) 530 | midx = pd.MultiIndex.from_product([params, range(100)], names=["param", "i"]) 531 | ic_df = ( 532 | pd.DataFrame(ic, index=midx, columns=["t", "x"]) 533 | .reset_index() 534 | .assign(param=lambda df: pd.Categorical(df.param, params_ordered, ordered=True)) 535 | ) 536 | labeller = make_labeller(cfg) 537 | input_context_param_fid_plot = ( 538 | p9.ggplot(ic_df, mapping=p9.aes(x="x", y="t")) 539 | + p9.geom_point(shape="x") 540 | + p9.facet_grid("~param", labeller=labeller) 541 | + p9.theme_bw(base_size=cfg.analysis.base_font_size) 542 | + p9.theme(strip_text=p9.element_text(usetex=True)) 543 | + p9.scale_x_continuous(limits=x_range) 544 | + p9.scale_y_continuous(limits=t_range) 545 | ) 546 | input_context_param_fid_plot.save( 547 | plot_dir / f"input_context_fid={f_id}.pdf", 548 | width=10, 549 | height=6, 550 | dpi=cfg.analysis.dpi, 551 | ) 552 | 553 | 554 | def analyze_conservation(cfg, test_dataset: ANPDataset, plot_df: pd.DataFrame): 555 | t_range = cfg.analysis.t_range 556 | x_range = cfg.analysis.x_range 557 | domain_length = x_range[1] - x_range[0] 558 | params = get_params(cfg, test_dataset) 559 | n_p_stars = len(params) 560 | params_ordered = cfg.analysis.get("params_ordered", params) 561 | nice_names = cfg.analysis.nice_names 562 | 563 | true_mass = get_analytical_mass_rhs( 564 | test_dataset, 0, len(test_dataset), torch.linspace(*t_range, cfg.analysis.nt) 565 | ) 566 | true_mass = rearrange(true_mass, "(nd nfpd) nt -> nd nfpd nt", nd=n_p_stars)[:, 0] 567 | true_cons_df = pd.DataFrame(true_mass.transpose(1, 0), columns=params) 568 | true_cons_df["t"] = torch.linspace(*t_range, cfg.analysis.nt) 569 | true_cons_df["t_idx"] = torch.arange(0, cfg.analysis.nt) 570 | true_cons_df = true_cons_df.melt( 571 | id_vars=["t", "t_idx"], var_name="param", value_name="true" 572 | ).assign(param=lambda df: pd.Categorical(df.param, params_ordered, ordered=True)) 573 | 574 | cons_df = ( 575 | plot_df.assign( 576 | t_idx=lambda df: (np.round((cfg.analysis.nt - 1) * (df.t / df.t.max()))).astype(int) 577 | ) 578 | .groupby(by=["param", "model", "t_idx", "f_id"]) 579 | .agg( 580 | lhs=pd.NamedAgg("u", lambda x: x[1:].mean() * domain_length), 581 | rhs=pd.NamedAgg("u", lambda x: x[:-1].mean() * domain_length), 582 | ) 583 | .reset_index() 584 | .assign(trap=lambda df: (df["rhs"] + df["lhs"]) / 2) 585 | .set_index(["param", "t_idx"]) 586 | .join(true_cons_df.set_index(["param", "t_idx"])) 587 | .reset_index() 588 | .assign(error=lambda df: df["trap"] - df["true"]) 589 | ) 590 | 591 | cons_df_for_plot = cons_df.loc[lambda df: df.f_id == "1"] 592 | 593 | labeller = make_labeller(cfg) 594 | cons_plot = ( 595 | p9.ggplot(cons_df_for_plot, p9.aes(x="t")) 596 | + p9.geom_line(p9.aes(y="lhs", color="model")) 597 | + p9.geom_line(p9.aes(y="rhs", color="model")) 598 | + p9.geom_line(p9.aes(y="true"), true_cons_df, linetype="dashed") 599 | + p9.facet_grid("~param", 
labeller=labeller) 600 | + p9.theme_bw(base_size=cfg.analysis.base_font_size) 601 | + p9.theme(strip_text=p9.element_text(usetex=True)) 602 | + p9.scale_color_hue(labels=nice_names.values()) # , guide=None) 603 | + p9.ylab("Mass") 604 | ) 605 | 606 | outdir = Path(cfg.analysis.outdir) 607 | if not outdir.exists(): 608 | outdir.mkdir() 609 | plot_dir = outdir / "plots" 610 | if not plot_dir.exists(): 611 | plot_dir.mkdir() 612 | for ext in ("pdf", "png"): 613 | cons_plot.save( 614 | plot_dir / f"cons_plot.{ext}", 615 | dpi=cfg.analysis.dpi, 616 | width=cfg.analysis.cons_plot_width, 617 | height=cfg.analysis.cons_plot_height, 618 | ) 619 | 620 | # cons_df = cons_df.set_index(["param", "t"]).join(true_cons_df.set_index(["param", "t"])).reset_index().assign(error = lambda df: df["trap"] - df["true"]) 621 | # cons_df_at_t = cons_df.loc[lambda df: np.isin(df["t"], cfg.analysis.t_of_interest[0])] 622 | cons_df_at_t = ( 623 | cons_df.loc[lambda df: np.isclose(df.t, cfg.analysis.t_of_interest[0])] 624 | .groupby(by=["param", "model"]) 625 | .agg(error=pd.NamedAgg("error", np.mean)) 626 | .reset_index() 627 | ) 628 | cons_df_at_t.to_csv(plot_dir / "cons_df.csv") 629 | 630 | return cons_df, true_cons_df 631 | 632 | 633 | def make_labeller(cfg): 634 | nice_names = cfg.analysis.nice_names 635 | 636 | def labeller(value): 637 | if value in nice_names: 638 | return nice_names[value] 639 | elif value == "0.5": 640 | return "Outside training range~($u^\\star=0.5$)" 641 | elif value == "0.6": 642 | return "Inside training range~($u^\\star=0.6$)" 643 | else: 644 | return value 645 | 646 | return labeller 647 | 648 | 649 | def infer(cfg) -> Dict[str, Dict[str, Tensor]]: 650 | dataset_path = Path(cfg.datasets.save_path) 651 | test_load_path = dataset_path / "test.pt" 652 | test_dataset = instantiate(cfg.datasets.test, load_path=test_load_path) 653 | 654 | input_targets, _, input_ts = get_test_solution(cfg, test_dataset) 655 | 656 | gpu = torch.device(cfg.analysis.gpu) 657 | 658 | out = {} 659 | 660 | for method_name, method_cfg in cfg.analysis.methods.items(): 661 | if method_cfg.get("use_empirical_mass", False): 662 | mass_rhs_in = get_empirical_mass_rhs(cfg) 663 | else: 664 | mass_rhs_in = None 665 | out[method_name] = run_inference_for_method( 666 | method_name, 667 | method_cfg, 668 | test_dataset, 669 | input_targets, 670 | gpu, 671 | input_ts, 672 | mass_rhs_in=mass_rhs_in, 673 | ) 674 | if method_cfg.get("truncated_version", False): 675 | out[f"{method_name}_trunc"] = truncate_results(out[method_name]) 676 | if method_cfg.get("constrained_version", False): 677 | constrained_results_path = Path(method_cfg.infer_path_constrained) 678 | if (not constrained_results_path.exists()) or method_cfg.overwrite_constrained: 679 | cnstrd = constrained_results(out[method_name], "monotone") 680 | torch.save(cnstrd, constrained_results_path) 681 | else: 682 | cnstrd = torch.load(constrained_results_path) 683 | out[f"{method_name}_cnstrd"] = cnstrd 684 | out[f"{method_name}_cnstrd_trunc"] = truncate_results(cnstrd) 685 | 686 | if method_cfg.get("nonneg_path", False): 687 | nonneg_path = Path(method_cfg.nonneg_path) 688 | if (not nonneg_path.exists()) or method_cfg.overwrite_nonneg: 689 | nonneg = constrained_results(out[method_name], "nonneg") 690 | torch.save(nonneg, nonneg_path) 691 | else: 692 | nonneg = torch.load(nonneg_path) 693 | out[f"{method_name}_nonneg"] = nonneg 694 | out[f"{method_name}_nonneg_trunc"] = truncate_results(nonneg) 695 | 696 | return out 697 | 698 | 699 | Model = Union[ANP, PhysNP] 700 | 701 | 
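# The inference and post-processing entry points below share one compute-once,
# cache-to-disk pattern: recompute only when the target .pt file is missing or an
# explicit overwrite flag is set, otherwise reload the cached tensors. A minimal
# sketch of that pattern, assuming only `torch` and `pathlib.Path` (both already
# imported in this module); the helper name `_load_or_compute` is illustrative and
# is not defined anywhere in the repository:
#
#     def _load_or_compute(path: Path, compute, overwrite: bool = False):
#         if overwrite or not path.exists():
#             result = compute()  # expensive step, e.g. batched inference
#             torch.save(result, path)
#             return result
#         return torch.load(path)  # reuse the cached result
#
# `run_inference_for_method` applies this logic inline through `cfg.infer_path` and
# `cfg.overwrite`; the constrained and non-negative branches of `infer` do the same
# through `infer_path_constrained` / `nonneg_path` and their overwrite flags.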
702 | def run_inference_for_method( 703 | name: str, cfg, test_dataset: ANPDataset, input_targets, gpu, input_ts, mass_rhs_in=None 704 | ): 705 | 706 | seed: Optional[int] = cfg.get("seed", None) 707 | if seed is not None: 708 | torch.manual_seed(seed) 709 | nt = input_ts.shape[0] 710 | save_path = Path(cfg.infer_path) 711 | if (not save_path.exists()) or cfg.overwrite: 712 | model: Model = instantiate(cfg.model) 713 | if cfg.state_dict is not None: 714 | state_dict = torch.load(cfg.state_dict) 715 | model.load_state_dict(state_dict) 716 | else: 717 | state_dict = model.state_dict() 718 | results = {} 719 | if isinstance(model, PhysNP): 720 | constraint_precision = cfg.constraint_precision 721 | if constraint_precision is not None: 722 | state_dict["log_constraint_precision_train"] = torch.tensor( 723 | constraint_precision 724 | ).log() 725 | model.load_state_dict(state_dict) 726 | anp_state_dict_path = cfg.anp_state_dict 727 | if anp_state_dict_path is not None: 728 | anp_state_dict = torch.load(anp_state_dict_path) 729 | model.anp.load_state_dict(anp_state_dict) 730 | loc, scale, cov = physnp_batched( 731 | test_dataset, 732 | input_targets, 733 | model, 734 | gpu, 735 | nt, 736 | input_ts, 737 | n_samples=cfg.n_samples, 738 | mass_rhs_in=mass_rhs_in, 739 | ) 740 | pred = rearrange(loc, "nf nt nx 1 -> nf (nx nt)") 741 | results["cov"] = cov 742 | elif isinstance(model, ANP) or isinstance(model, PINP) or isinstance(model, SoftcANP): 743 | loc, scale = anp_batched( 744 | test_dataset, input_targets, model, gpu, n_samples=cfg.n_samples 745 | ) 746 | pred = rearrange(loc, "nf nx_nt 1 -> nf nx_nt") 747 | else: 748 | raise NotImplementedError() 749 | results.update({"loc": loc, "scale": scale, "pred": pred}) 750 | torch.save(results, save_path) 751 | else: 752 | results = torch.load(save_path) 753 | return results 754 | 755 | 756 | def truncate_results(results: Dict[str, Tensor]): 757 | # Enforce monotonicity 758 | loc = rearrange(results["loc"], "nf nt nx 1 -> nf nt nx") 759 | nf, nt, nx = loc.shape 760 | new_loc = loc.clone() 761 | for i in range(1, nx): 762 | new_loc[:, :, i] = new_loc[:, :, i] - F.relu(new_loc[:, :, i] - new_loc[:, :, i - 1]) 763 | 764 | new_loc = F.relu(new_loc) 765 | 766 | pred = rearrange(new_loc, "nf nt nx -> nf (nx nt)") 767 | return {"loc": new_loc, "scale": results["scale"], "pred": pred} 768 | 769 | 770 | def constrained_results(results: Dict[str, Tensor], mode: InequalityConstraint): 771 | # Enforce monotonicity 772 | new_loc = apply_non_linear_ineq_constraint( 773 | results["loc"], results["cov"], max_iter=10, mode=mode 774 | ) 775 | pred = rearrange(new_loc, "nf nt nx 1 -> nf (nx nt)") 776 | return {"loc": new_loc, "scale": results["scale"], "pred": pred} 777 | 778 | 779 | def anp_batched(test_dataset: ANPDataset, input_targets: Tensor, anp: ANP, gpu, n_samples: int): 780 | nf = test_dataset.tensors["input_contexts"].shape[0] 781 | n_functions_per_batch = 5 782 | anp = anp.to(gpu) 783 | it = input_targets.to(gpu).unsqueeze(0).expand(n_functions_per_batch, -1, -1) 784 | outputs = [] 785 | with torch.no_grad(): 786 | for i in range(0, nf, n_functions_per_batch): 787 | ic = test_dataset.tensors["input_contexts"][i : (i + n_functions_per_batch)] 788 | oc = test_dataset.tensors["output_contexts"][i : (i + n_functions_per_batch)] 789 | anp_dist = anp.get_loc_and_scale_batched( 790 | ic.to(gpu), 791 | oc.to(gpu), 792 | it, 793 | n_samples=n_samples, 794 | batch_size=50_000, 795 | ) 796 | loc = anp_dist.loc.cpu() 797 | scale = anp_dist.scale.cpu() 798 | 
outputs.append((loc, scale)) 799 | loc = torch.cat([x[0] for x in outputs]) 800 | scale = torch.cat([x[1] for x in outputs]) 801 | return loc, scale 802 | 803 | 804 | def physnp_batched( 805 | test_dataset: ANPDataset, 806 | input_targets: Tensor, 807 | anp: PhysNP, 808 | gpu, 809 | nt: int, 810 | ts: int, 811 | n_samples: int, 812 | mass_rhs_in: Optional[Tensor] = None, 813 | ): 814 | nf = test_dataset.tensors["input_contexts"].shape[0] 815 | # nf = 5 816 | n_functions_per_batch = 5 817 | anp = anp.to(gpu) 818 | it = input_targets.to(gpu).unsqueeze(0).expand(n_functions_per_batch, -1, -1) 819 | it = rearrange(it, "nf (nx nt) d -> nf nt nx d", nt=nt) 820 | outputs = [] 821 | ts = rearrange(ts, "nt -> 1 nt") 822 | with torch.no_grad(): 823 | for i in range(0, nf, n_functions_per_batch): 824 | ic = test_dataset.tensors["input_contexts"][i : (i + n_functions_per_batch)] 825 | oc = test_dataset.tensors["output_contexts"][i : (i + n_functions_per_batch)] 826 | ic = rearrange(ic, "nf (nt nx) d -> nf nt nx d", nt=test_dataset.n_contexts_t) 827 | oc = rearrange(oc, "nf (nt nx) d -> nf nt nx d", nt=test_dataset.n_contexts_t) 828 | if mass_rhs_in is not None: 829 | mass_rhs = mass_rhs_in[i : (i + n_functions_per_batch)] 830 | else: 831 | mass_rhs = get_analytical_mass_rhs(test_dataset, i, i + n_functions_per_batch, ts) 832 | mass_rhs_i = mass_rhs.to(it.device) 833 | anp_dist, cov = anp.get_loc_and_scale_batched( 834 | ic.to(gpu), 835 | oc.to(gpu), 836 | it, 837 | n_samples=n_samples, 838 | batch_size=50_000, 839 | mass_rhs=mass_rhs_i.to(gpu), 840 | ) 841 | loc = anp_dist.loc.cpu() 842 | scale = anp_dist.scale.cpu() 843 | outputs.append((loc, scale, cov)) 844 | loc = torch.cat([x[0] for x in outputs]) 845 | scale = torch.cat([x[1] for x in outputs]) 846 | if outputs[0][2] is not None: 847 | cov = torch.cat([x[2] for x in outputs]) 848 | else: 849 | cov = None 850 | return loc, scale, cov 851 | 852 | 853 | def get_analytical_mass_rhs(test_dataset, i, i_end, ts): 854 | if isinstance(test_dataset, PorousMediumEquation): 855 | degree = test_dataset.degrees[i:i_end] 856 | degree = rearrange(degree, "nf -> nf 1") 857 | mass_rhs = mass_pme(degree, ts) 858 | elif isinstance(test_dataset, StefanPME): 859 | stefans = test_dataset.stefans[i:i_end] 860 | mass_rhs_list = [] 861 | for stefan in stefans: 862 | mass_rhs_list.append(stefan.mass_stefan(ts)) 863 | mass_rhs = torch.stack(mass_rhs_list, dim=0) 864 | elif isinstance(test_dataset, HeatEquation): 865 | nf = test_dataset.params.shape[0] 866 | nt = ts.shape[0] 867 | mass_rhs = torch.zeros((nf, nt), device=ts.device) 868 | elif isinstance(test_dataset, LinearAdvection): 869 | a_values = rearrange(test_dataset.parameters[i:i_end], "nf -> nf 1") 870 | mass_rhs = test_dataset.mass_advection(ts, a_values) 871 | elif isinstance(test_dataset, Burgers): 872 | a_values = rearrange(test_dataset.parameters[i:i_end], "nf -> nf 1") 873 | mass_rhs = mass_burgers(ts, a_values) 874 | else: 875 | raise NotImplementedError() 876 | return mass_rhs 877 | 878 | 879 | def get_empirical_mass_rhs(cfg): 880 | dataset_path = Path(cfg.datasets.save_path) 881 | test_load_path = dataset_path / "test.pt" 882 | test_dataset = instantiate(cfg.datasets.test, load_path=test_load_path) 883 | _, output_targets, _ = get_test_solution(cfg, test_dataset) 884 | ot = rearrange(output_targets, "nf (nx nt) -> nf nx nt", nt=cfg.analysis.nt) 885 | return 0.5 * ( 886 | reduce(ot[:, 1:], "nf nx nt -> nf nt", "mean") 887 | + reduce(ot[:, :-1], "nf nx nt -> nf nt", "mean") 888 | ) 889 | 890 | 891 | def 
get_test_solution(cfg, dataset: ANPDataset): 892 | nt: int = cfg.analysis.nt 893 | nx: int = cfg.analysis.nx 894 | 895 | tlims = dataset.lims("t") 896 | xlims = dataset.lims("x") 897 | 898 | ts = torch.linspace(*tlims, nt) 899 | xs = torch.linspace(*xlims, nx) 900 | inputs = meshgrid(ts.unsqueeze(0), xs.unsqueeze(0)) 901 | # input_targets = rearrange(inputs, "nf nt nx d -> nf (nx nt) d") 902 | input_targets = rearrange(inputs, "nf nt nx d -> nf (nt nx) d") 903 | 904 | true_soln = dataset.solution(input_targets) 905 | output_targets = rearrange(true_soln, "nf (nt nx) 1 -> nf (nx nt)", nt=nt) 906 | 907 | input_targets = rearrange(input_targets, "1 (nt nx) d -> (nx nt) d", nt=nt) 908 | 909 | return input_targets, output_targets, ts 910 | 911 | 912 | def estimate_shock_interval( 913 | mean: Tensor, cov: Tensor, n_samples=10, n_samples_per_batch: Optional[int] = None 914 | ): 915 | nf, nt, nx, _ = mean.shape 916 | outlist = [] 917 | if n_samples_per_batch is None: 918 | n_samples_per_batch = n_samples 919 | for fid in tqdm(range(nf), desc="shock positions"): 920 | mean_i = mean[fid] 921 | cov_i = cov[fid] 922 | outlist_fid = [] 923 | for _ in range(n_samples // n_samples_per_batch): 924 | first_less_than_zero = _estimate_shock_interval_for_one_f( 925 | mean_i, cov_i, n_samples_per_batch 926 | ) 927 | outlist_fid.append(first_less_than_zero) 928 | outlist.append(torch.concat(outlist_fid, dim=0)) 929 | return torch.stack(outlist, dim=1) 930 | 931 | 932 | def _estimate_shock_interval_for_one_f(mean_i: Tensor, cov_i: Tensor, n_samples: int): 933 | nt, nx, _ = mean_i.shape 934 | device = mean_i.device 935 | idx = rearrange(torch.arange(0, nx, device=device), "nx -> 1 1 nx") 936 | try: 937 | chol: Tensor = torch.linalg.cholesky(cov_i).unsqueeze(0) 938 | except RuntimeError: # if the factorization fails, retry with a small diagonal jitter on the covariance 939 | chol = torch.linalg.cholesky( 940 | cov_i + torch.eye(nx, device=device).unsqueeze(0) * 1e-8 941 | ).unsqueeze(0) 942 | z = torch.randn(n_samples, *mean_i.shape, device=device) 943 | y = mean_i.unsqueeze(0) + chol.matmul(z) 944 | y = rearrange(y, "ns nt nx 1 -> ns nt nx") 945 | less_than_zero = y <= 0 946 | objective = less_than_zero * idx + (~less_than_zero) * nx 947 | return torch.argmin(objective, dim=2) 948 | 949 | 950 | if __name__ == "__main__": 951 | main() 952 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/config.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | save_best_model: True 3 | trainer: 4 | _target_: pytorch_lightning.Trainer 5 | accelerator: "gpu" 6 | devices: 1 7 | logger: 8 | _target_: pytorch_lightning.loggers.TensorBoardLogger 9 | default_hp_metric: False 10 | checkpoint_callback: 11 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 12 | filename: "epoch={epoch}-val_loss={val_loss:.3f}" 13 | save_top_k: 5 14 | verbose: True 15 | monitor: "val_loss" 16 | mode: "min" 17 | save_on_train_epoch_end: False 18 | auto_insert_metric_name: False 19 | analyze: 20 | plot_shock: False 21 | datasets: 22 | dataset_overwrite: False -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/experiments/1b_pme_var_m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | base_dir: ./output/paper/1b_pme_var_m 3 | datasets: 4 | degree_min: 1 5 | degree_max: 6 6 | save_path: ${base_dir}/datasets 7 | train: 8 | _target_: deep_pdes.datasets.pme.PorousMediumEquation 9 | n_functions: 10000 10 | n_contexts_t: 10
11 | n_contexts_x: 10 12 | n_targets_t: 10 13 | n_targets_x: 10 14 | batch_size: 250 15 | scale_lims: 16 | - 1.0 17 | - 1.0 18 | degree_min: ${datasets.degree_min} 19 | degree_max: ${datasets.degree_max} 20 | valid: 21 | _target_: deep_pdes.datasets.pme.PorousMediumEquation 22 | n_functions: 100 23 | n_contexts_t: ${datasets.train.n_contexts_t} 24 | n_contexts_x: ${datasets.train.n_contexts_x} 25 | n_targets_t: ${datasets.train.n_targets_t} 26 | n_targets_x: ${datasets.train.n_targets_x} 27 | batch_size: 250 28 | scale_lims: 29 | - 1.0 30 | - 1.0 31 | degree_min: ${datasets.degree_min} 32 | degree_max: ${datasets.degree_max} 33 | test: 34 | _target_: deep_pdes.datasets.pme.PorousMediumEquation 35 | n_functions: 150 36 | n_contexts_t: ${datasets.train.n_contexts_t} 37 | n_contexts_x: ${datasets.train.n_contexts_x} 38 | n_targets_t: ${datasets.train.n_targets_t} 39 | n_targets_x: ${datasets.train.n_targets_x} 40 | batch_size: 250 41 | scale_lims: 42 | - 1.0 43 | - 1.0 44 | # degree_min: ${datasets.degree_min} 45 | # degree_max: ${datasets.degree_max} 46 | degrees: 47 | - 1 48 | - 3 49 | - 6 50 | methods: 51 | anp: 52 | model: 53 | _target_: deep_pdes.attentive_neural_process.anp.ANP 54 | num_hidden: 128 55 | dim_x: 2 56 | dim_y: 1 57 | lr: 1e-4 58 | state_dict: ${base_dir}/train/anp.pt 59 | physnp: 60 | model: 61 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 62 | anp: 63 | _target_: deep_pdes.attentive_neural_process.anp.ANP 64 | num_hidden: 128 65 | dim_x: 2 66 | dim_y: 1 67 | lr: 1e-4 68 | state_dict: ${base_dir}/train/physnp.pt 69 | physnp_noretcov: 70 | model: 71 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 72 | anp: 73 | _target_: deep_pdes.attentive_neural_process.anp.ANP 74 | num_hidden: 128 75 | dim_x: 2 76 | dim_y: 1 77 | lr: 1e-4 78 | return_full_cov: False 79 | state_dict: ${base_dir}/train/physnp.pt 80 | physnp_limit: 81 | model: 82 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 83 | anp: 84 | _target_: deep_pdes.attentive_neural_process.anp.ANP 85 | num_hidden: 128 86 | dim_x: 2 87 | dim_y: 1 88 | lr: 1e-4 89 | limiting_mode: physnp 90 | state_dict: ${base_dir}/train/physnp.pt 91 | physnp_limit_noretcov: 92 | model: 93 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 94 | anp: 95 | _target_: deep_pdes.attentive_neural_process.anp.ANP 96 | num_hidden: 128 97 | dim_x: 2 98 | dim_y: 1 99 | lr: 1e-4 100 | limiting_mode: physnp 101 | return_full_cov: False 102 | state_dict: ${base_dir}/train/physnp.pt 103 | hcnp: 104 | model: 105 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 106 | anp: 107 | _target_: deep_pdes.attentive_neural_process.anp.ANP 108 | num_hidden: 128 109 | dim_x: 2 110 | dim_y: 1 111 | lr: 1e-4 112 | limiting_mode: hcnp 113 | return_full_cov: False 114 | state_dict: ${base_dir}/train/physnp.pt 115 | physnp_second_deriv: 116 | model: 117 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 118 | anp: 119 | _target_: deep_pdes.attentive_neural_process.anp.ANP 120 | num_hidden: 128 121 | dim_x: 2 122 | dim_y: 1 123 | lr: 1e-4 124 | limiting_mode: physnp 125 | constraint_precision_train: 1e4 126 | train_precision: False 127 | second_deriv_alpha: 0.9 128 | state_dict: ${base_dir}/train/physnp.pt 129 | physnp_second_deriv_noretcov: 130 | model: 131 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 132 | anp: 133 | _target_: deep_pdes.attentive_neural_process.anp.ANP 134 | num_hidden: 128 135 | dim_x: 2 136 | dim_y: 1 137 | lr: 1e-4 138 | 
return_full_cov: False 139 | limiting_mode: physnp 140 | constraint_precision_train: 1e4 141 | train_precision: False 142 | second_deriv_alpha: 0.9 143 | state_dict: ${base_dir}/train/physnp.pt 144 | pinp: 145 | model: 146 | _target_: deep_pdes.attentive_neural_process.softc.PINP 147 | anp: 148 | _target_: deep_pdes.attentive_neural_process.anp.ANP 149 | num_hidden: 128 150 | dim_x: 2 151 | dim_y: 1 152 | pressure_fn: 153 | _target_: deep_pdes.attentive_neural_process.softc.PMEPressureFn 154 | pinns_lambda: 1.0 155 | lr: 1e-4 156 | state_dict: ${base_dir}/train/pinp.pt 157 | pinp_1en2: 158 | model: 159 | _target_: deep_pdes.attentive_neural_process.softc.PINP 160 | anp: 161 | _target_: deep_pdes.attentive_neural_process.anp.ANP 162 | num_hidden: 128 163 | dim_x: 2 164 | dim_y: 1 165 | pressure_fn: 166 | _target_: deep_pdes.attentive_neural_process.softc.PMEPressureFn 167 | pinns_lambda: 1e-2 168 | lr: 1e-4 169 | state_dict: ${base_dir}/train/pinp_1en2.pt 170 | pinp_1en1: 171 | model: 172 | _target_: deep_pdes.attentive_neural_process.softc.PINP 173 | anp: 174 | _target_: deep_pdes.attentive_neural_process.anp.ANP 175 | num_hidden: 128 176 | dim_x: 2 177 | dim_y: 1 178 | pressure_fn: 179 | _target_: deep_pdes.attentive_neural_process.softc.PMEPressureFn 180 | pinns_lambda: 1e-1 181 | lr: 1e-4 182 | state_dict: ${base_dir}/train/pinp_1en1.pt 183 | pinp_1e1: 184 | model: 185 | _target_: deep_pdes.attentive_neural_process.softc.PINP 186 | anp: 187 | _target_: deep_pdes.attentive_neural_process.anp.ANP 188 | num_hidden: 128 189 | dim_x: 2 190 | dim_y: 1 191 | pressure_fn: 192 | _target_: deep_pdes.attentive_neural_process.softc.PMEPressureFn 193 | pinns_lambda: 1e1 194 | lr: 1e-4 195 | state_dict: ${base_dir}/train/pinp_1e1.pt 196 | pinp_1e2: 197 | model: 198 | _target_: deep_pdes.attentive_neural_process.softc.PINP 199 | anp: 200 | _target_: deep_pdes.attentive_neural_process.anp.ANP 201 | num_hidden: 128 202 | dim_x: 2 203 | dim_y: 1 204 | pressure_fn: 205 | _target_: deep_pdes.attentive_neural_process.softc.PMEPressureFn 206 | pinns_lambda: 1e2 207 | lr: 1e-4 208 | state_dict: ${base_dir}/train/pinp_1e2.pt 209 | pinp_1e6: 210 | model: 211 | _target_: deep_pdes.attentive_neural_process.softc.PINP 212 | anp: 213 | _target_: deep_pdes.attentive_neural_process.anp.ANP 214 | num_hidden: 128 215 | dim_x: 2 216 | dim_y: 1 217 | pressure_fn: 218 | _target_: deep_pdes.attentive_neural_process.softc.PMEPressureFn 219 | pinns_lambda: 1e6 220 | lr: 1e-4 221 | state_dict: ${base_dir}/train/pinp_1e6.pt 222 | analysis: 223 | methods: 224 | anp: 225 | model: ${methods.anp.model} 226 | state_dict: ${methods.anp.state_dict} 227 | infer_path: ${base_dir}/analysis/anp_infer.pt 228 | overwrite: False 229 | truncated_version: False 230 | n_samples: 100 231 | seed: 42 232 | shock_path: ${base_dir}/analysis/anp_shocks.pt 233 | pinp: 234 | model: ${methods.pinp.model} 235 | state_dict: ${methods.pinp.state_dict} 236 | infer_path: ${base_dir}/analysis/pinp_infer.pt 237 | overwrite: False 238 | truncated_version: False 239 | n_samples: 100 240 | shock_path: ${base_dir}/analysis/pinp_shocks.pt 241 | # pinp_1en2: 242 | # model: ${methods.pinp.model} 243 | # state_dict: ${methods.pinp.state_dict} 244 | # infer_path: ${base_dir}/analysis/pinp_infer_1en2.pt 245 | # overwrite: False 246 | # truncated_version: False 247 | # n_samples: 100 248 | # shock_path: ${base_dir}/analysis/pinp_shocks_1en2.pt 249 | # pinp_1en1: 250 | # model: ${methods.pinp.model} 251 | # state_dict: ${methods.pinp.state_dict} 252 | # infer_path: 
${base_dir}/analysis/pinp_infer_1en1.pt 253 | # overwrite: False 254 | # truncated_version: False 255 | # n_samples: 100 256 | # shock_path: ${base_dir}/analysis/pinp_shocks_1en1.pt 257 | # pinp_1e1: 258 | # model: ${methods.pinp.model} 259 | # state_dict: ${methods.pinp.state_dict} 260 | # infer_path: ${base_dir}/analysis/pinp_infer_1e1.pt 261 | # overwrite: False 262 | # truncated_version: False 263 | # n_samples: 100 264 | # shock_path: ${base_dir}/analysis/pinp_shocks_1e1.pt 265 | # pinp_1e2: 266 | # model: ${methods.pinp.model} 267 | # state_dict: ${methods.pinp.state_dict} 268 | # infer_path: ${base_dir}/analysis/pinp_infer_1e2.pt 269 | # overwrite: False 270 | # truncated_version: False 271 | # n_samples: 100 272 | # shock_path: ${base_dir}/analysis/pinp_shocks_1e2.pt 273 | # pinp_1e6: 274 | # model: ${methods.pinp.model} 275 | # state_dict: ${methods.pinp.state_dict} 276 | # infer_path: ${base_dir}/analysis/pinp_infer_1e6.pt 277 | # overwrite: False 278 | # truncated_version: False 279 | # n_samples: 100 280 | # shock_path: ${base_dir}/analysis/pinp_shocks_1e6.pt 281 | # physnp_notrain: 282 | # model: ${methods.physnp.model} 283 | # state_dict: ${methods.physnp.state_dict} 284 | # infer_path: ${base_dir}/analysis/physnp_notrain.pt 285 | # overwrite: False 286 | # constraint_precision: 1e6 287 | # anp_state_dict: ${methods.anp.state_dict} 288 | # truncated_version: False 289 | # constrained_version: False 290 | # infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 291 | # overwrite_constrained: False 292 | # nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 293 | # overwrite_nonneg: False 294 | # n_samples: 100 295 | # seed: 42 296 | physnp_notrain_1en9: 297 | model: ${methods.physnp_noretcov.model} 298 | state_dict: ${methods.physnp_noretcov.state_dict} 299 | infer_path: ${base_dir}/analysis/physnp_notrain_1en9.pt 300 | overwrite: False 301 | constraint_precision: 1e-9 302 | anp_state_dict: ${methods.anp.state_dict} 303 | truncated_version: False 304 | constrained_version: False 305 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 306 | overwrite_constrained: False 307 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 308 | overwrite_nonneg: False 309 | n_samples: 100 310 | seed: 42 311 | physnp_notrain_1en6: 312 | model: ${methods.physnp_noretcov.model} 313 | state_dict: ${methods.physnp_noretcov.state_dict} 314 | infer_path: ${base_dir}/analysis/physnp_notrain_1en6.pt 315 | overwrite: False 316 | constraint_precision: 1e-6 317 | anp_state_dict: ${methods.anp.state_dict} 318 | truncated_version: False 319 | constrained_version: False 320 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 321 | overwrite_constrained: False 322 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 323 | overwrite_nonneg: False 324 | n_samples: 100 325 | seed: 42 326 | physnp_notrain_1en3: 327 | model: ${methods.physnp_noretcov.model} 328 | state_dict: ${methods.physnp_noretcov.state_dict} 329 | infer_path: ${base_dir}/analysis/physnp_notrain_1en3.pt 330 | overwrite: False 331 | constraint_precision: 1e-3 332 | anp_state_dict: ${methods.anp.state_dict} 333 | truncated_version: False 334 | constrained_version: False 335 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 336 | overwrite_constrained: False 337 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 338 | overwrite_nonneg: False 339 | n_samples: 100 340 | seed: 42 341 | physnp_notrain_1en2: 342 | model: 
${methods.physnp_noretcov.model} 343 | state_dict: ${methods.physnp_noretcov.state_dict} 344 | infer_path: ${base_dir}/analysis/physnp_notrain_1en2.pt 345 | overwrite: False 346 | constraint_precision: 1e-2 347 | anp_state_dict: ${methods.anp.state_dict} 348 | truncated_version: False 349 | constrained_version: False 350 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 351 | overwrite_constrained: False 352 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 353 | overwrite_nonneg: False 354 | n_samples: 100 355 | seed: 42 356 | physnp_notrain_1en1: 357 | model: ${methods.physnp_noretcov.model} 358 | state_dict: ${methods.physnp_noretcov.state_dict} 359 | infer_path: ${base_dir}/analysis/physnp_notrain_1en1.pt 360 | overwrite: False 361 | constraint_precision: 1e-1 362 | anp_state_dict: ${methods.anp.state_dict} 363 | truncated_version: False 364 | constrained_version: False 365 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 366 | overwrite_constrained: False 367 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 368 | overwrite_nonneg: False 369 | n_samples: 100 370 | seed: 42 371 | physnp_notrain_1e0: 372 | model: ${methods.physnp_noretcov.model} 373 | state_dict: ${methods.physnp_noretcov.state_dict} 374 | infer_path: ${base_dir}/analysis/physnp_notrain_1e0.pt 375 | overwrite: False 376 | constraint_precision: 1.0 377 | anp_state_dict: ${methods.anp.state_dict} 378 | truncated_version: False 379 | constrained_version: False 380 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 381 | overwrite_constrained: False 382 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 383 | overwrite_nonneg: False 384 | n_samples: 100 385 | seed: 42 386 | physnp_notrain_1e1: 387 | model: ${methods.physnp_noretcov.model} 388 | state_dict: ${methods.physnp_noretcov.state_dict} 389 | infer_path: ${base_dir}/analysis/physnp_notrain_1e1.pt 390 | overwrite: False 391 | constraint_precision: 10.0 392 | anp_state_dict: ${methods.anp.state_dict} 393 | truncated_version: False 394 | constrained_version: False 395 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 396 | overwrite_constrained: False 397 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 398 | overwrite_nonneg: False 399 | n_samples: 100 400 | seed: 42 401 | physnp_notrain_1e2: 402 | model: ${methods.physnp_noretcov.model} 403 | state_dict: ${methods.physnp_noretcov.state_dict} 404 | infer_path: ${base_dir}/analysis/physnp_notrain_1e2.pt 405 | overwrite: False 406 | constraint_precision: 100.0 407 | anp_state_dict: ${methods.anp.state_dict} 408 | truncated_version: False 409 | constrained_version: False 410 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 411 | overwrite_constrained: False 412 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 413 | overwrite_nonneg: False 414 | n_samples: 100 415 | seed: 42 416 | physnp_notrain_1e3: 417 | model: ${methods.physnp_noretcov.model} 418 | state_dict: ${methods.physnp_noretcov.state_dict} 419 | infer_path: ${base_dir}/analysis/physnp_notrain_1e3.pt 420 | overwrite: False 421 | constraint_precision: 1e3 422 | anp_state_dict: ${methods.anp.state_dict} 423 | truncated_version: False 424 | constrained_version: False 425 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 426 | overwrite_constrained: False 427 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 428 | overwrite_nonneg: False 429 | n_samples: 100 430 | seed: 
42 431 | physnp_notrain_1e4: 432 | model: ${methods.physnp_noretcov.model} 433 | state_dict: ${methods.physnp_noretcov.state_dict} 434 | infer_path: ${base_dir}/analysis/physnp_notrain_1e4.pt 435 | overwrite: False 436 | constraint_precision: 1e4 437 | anp_state_dict: ${methods.anp.state_dict} 438 | truncated_version: False 439 | constrained_version: False 440 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 441 | overwrite_constrained: False 442 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 443 | overwrite_nonneg: False 444 | n_samples: 100 445 | seed: 42 446 | physnp_notrain_1e5: 447 | model: ${methods.physnp_noretcov.model} 448 | state_dict: ${methods.physnp_noretcov.state_dict} 449 | infer_path: ${base_dir}/analysis/physnp_notrain_1e5.pt 450 | overwrite: False 451 | constraint_precision: 1e5 452 | anp_state_dict: ${methods.anp.state_dict} 453 | truncated_version: False 454 | constrained_version: False 455 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 456 | overwrite_constrained: False 457 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 458 | overwrite_nonneg: False 459 | n_samples: 100 460 | seed: 42 461 | physnp_notrain_1e6: 462 | model: ${methods.physnp_noretcov.model} 463 | state_dict: ${methods.physnp_noretcov.state_dict} 464 | infer_path: ${base_dir}/analysis/physnp_notrain_1e6.pt 465 | overwrite: False 466 | constraint_precision: 1e6 467 | anp_state_dict: ${methods.anp.state_dict} 468 | truncated_version: False 469 | constrained_version: False 470 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 471 | overwrite_constrained: False 472 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 473 | overwrite_nonneg: False 474 | n_samples: 100 475 | seed: 42 476 | physnp_notrain_1e7: 477 | model: ${methods.physnp_noretcov.model} 478 | state_dict: ${methods.physnp_noretcov.state_dict} 479 | infer_path: ${base_dir}/analysis/physnp_notrain_1e7.pt 480 | overwrite: False 481 | constraint_precision: 1e7 482 | anp_state_dict: ${methods.anp.state_dict} 483 | truncated_version: False 484 | constrained_version: False 485 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 486 | overwrite_constrained: False 487 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 488 | overwrite_nonneg: False 489 | n_samples: 100 490 | seed: 42 491 | physnp_notrain_1e8: 492 | model: ${methods.physnp_noretcov.model} 493 | state_dict: ${methods.physnp_noretcov.state_dict} 494 | infer_path: ${base_dir}/analysis/physnp_notrain_1e8.pt 495 | overwrite: False 496 | constraint_precision: 1e8 497 | anp_state_dict: ${methods.anp.state_dict} 498 | truncated_version: False 499 | constrained_version: False 500 | infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 501 | overwrite_constrained: False 502 | nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 503 | overwrite_nonneg: False 504 | n_samples: 100 505 | seed: 42 506 | physnp_limit: 507 | model: ${methods.physnp_limit.model} 508 | state_dict: ${methods.physnp_limit.state_dict} 509 | infer_path: ${base_dir}/analysis/physnp_limit.pt 510 | overwrite: False 511 | constraint_precision: 1e6 512 | anp_state_dict: ${methods.anp.state_dict} 513 | truncated_version: False 514 | constrained_version: False 515 | # infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 516 | # overwrite_constrained: false 517 | # nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 518 | # overwrite_nonneg: false 519 | 
n_samples: 100 520 | seed: 42 521 | shock_path: ${base_dir}/analysis/physnp_limit_shocks.pt 522 | # physnp_limit_empr: 523 | # model: ${methods.physnp_limit_noretcov.model} 524 | # state_dict: ${methods.physnp_limit_noretcov.state_dict} 525 | # infer_path: ${base_dir}/analysis/physnp_limit_empr.pt 526 | # overwrite: False 527 | # constraint_precision: 1e6 528 | # anp_state_dict: ${methods.anp.state_dict} 529 | # truncated_version: False 530 | # constrained_version: False 531 | # n_samples: 100 532 | # seed: 42 533 | # use_empirical_mass: True 534 | hcnp_notrain: 535 | model: ${methods.hcnp.model} 536 | state_dict: ${methods.hcnp.state_dict} 537 | infer_path: ${base_dir}/analysis/hcnp_notrain.pt 538 | overwrite: False 539 | constraint_precision: 1e6 540 | anp_state_dict: ${methods.anp.state_dict} 541 | truncated_version: False 542 | constrained_version: False 543 | n_samples: 100 544 | seed: 42 545 | shock_path: ${base_dir}/analysis/hcnp_notrain_shocks.pt 546 | # hcnp_notrain_empr: 547 | # model: ${methods.hcnp.model} 548 | # state_dict: ${methods.hcnp.state_dict} 549 | # infer_path: ${base_dir}/analysis/hcnp_notrain_empr.pt 550 | # overwrite: False 551 | # constraint_precision: 1e6 552 | # anp_state_dict: ${methods.anp.state_dict} 553 | # truncated_version: False 554 | # constrained_version: False 555 | # n_samples: 100 556 | # seed: 42 557 | # use_empirical_mass: True 558 | physnp_second_deriv: 559 | model: ${methods.physnp_second_deriv.model} 560 | state_dict: ${methods.physnp_second_deriv.state_dict} 561 | infer_path: ${base_dir}/analysis/physnp_second_deriv.pt 562 | overwrite: False 563 | constraint_precision: 1e6 564 | anp_state_dict: ${methods.anp.state_dict} 565 | truncated_version: True 566 | constrained_version: True 567 | infer_path_constrained: ${base_dir}/analysis/physnp_second_deriv_cnstrd.pt 568 | overwrite_constrained: False 569 | nonneg_path: ${base_dir}/analysis/physnp_second_deriv_nonneg.pt 570 | overwrite_nonneg: False 571 | n_samples: 100 572 | shock_path: ${base_dir}/analysis/physnp_second_deriv_shocks.pt 573 | # physnp_second_deriv_empr: 574 | # model: ${methods.physnp_second_deriv_noretcov.model} 575 | # state_dict: ${methods.physnp_second_deriv_noretcov.state_dict} 576 | # infer_path: ${base_dir}/analysis/physnp_second_deriv_empr.pt 577 | # overwrite: False 578 | # constraint_precision: 1e6 579 | # anp_state_dict: ${methods.anp.state_dict} 580 | # truncated_version: True 581 | # constrained_version: True 582 | # infer_path_constrained: ${base_dir}/analysis/physnp_second_deriv_cnstrd.pt 583 | # overwrite_constrained: False 584 | # nonneg_path: ${base_dir}/analysis/physnp_second_deriv_nonneg.pt 585 | # overwrite_nonneg: False 586 | # n_samples: 100 587 | # use_empirical_mass: True 588 | outdir: ${base_dir}/analysis/ 589 | inference_results : ${analysis.outdir}/inference_results.pt 590 | plot_df_path: ${analysis.outdir}/plot_df.pkl 591 | true_df_path: ${analysis.outdir}/true_df.pkl 592 | mse_at_t_df_path: ${analysis.outdir}/mse_at_t_df.pkl 593 | cons_df_path: ${analysis.outdir}/cons_df.pkl 594 | true_cons_df_path: ${analysis.outdir}/true_cons_df.pkl 595 | nt: 201 596 | nx: 201 597 | dpi: 500 598 | n_shock_samples: 500 599 | n_shock_samples_per_batch: 50 600 | base_font_size: 16 601 | x_range: 602 | - 0.0 603 | - 1.0 604 | t_range: 605 | - 0.0 606 | - 1.0 607 | gpu: "cuda:0" 608 | plot_shock: True 609 | colors: 610 | - "#F8766D" 611 | - "#7CAE00" 612 | - "#00BFC4" 613 | - "#C77CFF" 614 | - "#ff8000" 615 | nice_names: 616 | # physnp_notrain: "PhysNP" 617 | # 
physnp_notrain_cnstrd: "PhysNP" 618 | anp: "ANP" 619 | pinp: "SoftC-ANP" 620 | # pinp_1en2: "PINP(1e-2)" 621 | # pinp_1en1: "PINP(1e-1)" 622 | # pinp_1e1: "PINP(1e1)" 623 | # pinp_1e2: "PINP(1e2)" 624 | # pinp_1e6: "PINP(1e6)" 625 | hcnp_notrain: "HardC-ANP" 626 | # hcnp_notrain_empr: "ANP+HardC (empirical)" 627 | # physnp_notrain: "PhysNP" 628 | physnp_limit: "ProbConserv-ANP" 629 | # physnp_limit_empr: "PhysNP (empirical)" 630 | # physnp_notrain_cnstrd: "PhysNP (monotone)" 631 | # physnp_notrain_nonneg: "PhysNP (non-negative)" 632 | physnp_second_deriv: "ProbConserv-ANP (w/diff)" 633 | # physnp_second_deriv_empr: "PhysNP (w/diffusion,empirical)" 634 | physnp_notrain_1en9: "PhysNP(1e-9)" 635 | physnp_notrain_1en6: "PhysNP(1e-6)" 636 | physnp_notrain_1en3: "PhysNP(1e-3)" 637 | physnp_notrain_1en2: "PhysNP(1e-2)" 638 | physnp_notrain_1en1: "PhysNP(1e-1)" 639 | physnp_notrain_1e0: "PhysNP(0)" 640 | physnp_notrain_1e1: "PhysNP(1e1)" 641 | physnp_notrain_1e2: "PhysNP(1e2)" 642 | physnp_notrain_1e3: "PhysNP(1e3)" 643 | physnp_notrain_1e4: "PhysNP(1e4)" 644 | physnp_notrain_1e5: "PhysNP(1e5)" 645 | physnp_notrain_1e6: "PhysNP(1e6)" 646 | physnp_notrain_1e7: "PhysNP(1e7)" 647 | physnp_notrain_1e8: "PhysNP(1e8)" 648 | # physnp_second_deriv_cnstrd: "PhysNP (diffusion+m)" 649 | # physnp_second_deriv_nonneg: "PhysNP (diffusion+nn)" 650 | t_of_interest: 651 | # - 0.3 652 | - 0.5 653 | # - 0.7 654 | x_of_interest: 655 | - 0.3 656 | - 0.5 657 | - 0.7 658 | fids_of_interest: 659 | - "1" 660 | 661 | mse_plot_width: 6 662 | mse_plot_height: 3 663 | cons_plot_width: 6 664 | cons_plot_height: 3 665 | shock_plot_width: 6 666 | shock_plot_height: 6 667 | time_plot_width: 8 668 | time_plot_height: 16 669 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/experiments/2b_stefan_var_p.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | base_dir: ./output/paper/2b_stefan_var_p 3 | datasets: 4 | p_star_min: 0.55 5 | p_star_max: 0.7 6 | save_path: ${base_dir}/datasets 7 | train: 8 | _target_: deep_pdes.datasets.pme.StefanPME 9 | n_functions: 10000 10 | n_contexts_t: 10 11 | n_contexts_x: 10 12 | n_targets_t: 10 13 | n_targets_x: 10 14 | t_range: 15 | - 0.0 16 | - 0.1 17 | batch_size: 250 18 | p_star_lim: 19 | - ${datasets.p_star_min} 20 | - ${datasets.p_star_max} 21 | valid: 22 | _target_: deep_pdes.datasets.pme.StefanPME 23 | n_functions: 100 24 | n_contexts_t: ${datasets.train.n_contexts_t} 25 | n_contexts_x: ${datasets.train.n_contexts_x} 26 | n_targets_t: ${datasets.train.n_targets_t} 27 | n_targets_x: ${datasets.train.n_targets_x} 28 | batch_size: 250 29 | p_star_lim: 30 | - ${datasets.p_star_min} 31 | - ${datasets.p_star_max} 32 | t_range: ${datasets.train.t_range} 33 | test: 34 | _target_: deep_pdes.datasets.pme.StefanPME 35 | n_functions: 100 36 | n_contexts_t: ${datasets.train.n_contexts_t} 37 | n_contexts_x: ${datasets.train.n_contexts_x} 38 | n_targets_t: ${datasets.train.n_targets_t} 39 | n_targets_x: ${datasets.train.n_targets_x} 40 | batch_size: 250 41 | t_range: ${datasets.train.t_range} 42 | p_stars: 43 | - 0.5 44 | - 0.6 45 | methods: 46 | anp: 47 | model: 48 | _target_: deep_pdes.attentive_neural_process.anp.ANP 49 | num_hidden: 128 50 | dim_x: 2 51 | dim_y: 1 52 | lr: 1e-4 53 | state_dict: ${base_dir}/train/anp.pt 54 | physnp: 55 | model: 56 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 57 | anp: 58 | _target_: deep_pdes.attentive_neural_process.anp.ANP 59 | 
num_hidden: 128 60 | dim_x: 2 61 | dim_y: 1 62 | lr: 1e-4 63 | constraint_precision_train: 1e5 64 | train_precision: False 65 | state_dict: ${base_dir}/train/physnp.pt 66 | hcnp: 67 | model: 68 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 69 | anp: 70 | _target_: deep_pdes.attentive_neural_process.anp.ANP 71 | num_hidden: 128 72 | dim_x: 2 73 | dim_y: 1 74 | lr: 1e-4 75 | constraint_precision_train: 1e5 76 | train_precision: False 77 | limiting_mode: hcnp 78 | state_dict: ${base_dir}/train/physnp.pt 79 | pinp: 80 | model: 81 | _target_: deep_pdes.attentive_neural_process.softc.PINP 82 | anp: 83 | _target_: deep_pdes.attentive_neural_process.anp.ANP 84 | num_hidden: 128 85 | dim_x: 2 86 | dim_y: 1 87 | pressure_fn: 88 | _target_: deep_pdes.attentive_neural_process.softc.StefanPressureFn 89 | k_max: 1.0 90 | pinns_lambda: 1.0 91 | lr: 1e-4 92 | state_dict: ${base_dir}/train/pinp.pt 93 | physnp_second_deriv: 94 | model: 95 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 96 | anp: 97 | _target_: deep_pdes.attentive_neural_process.anp.ANP 98 | num_hidden: 128 99 | dim_x: 2 100 | dim_y: 1 101 | lr: 1e-4 102 | constraint_precision_train: 1e8 103 | train_precision: False 104 | second_deriv_alpha: 0.9 105 | state_dict: ${base_dir}/train/physnp.pt 106 | physnp_limit: 107 | model: 108 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 109 | anp: 110 | _target_: deep_pdes.attentive_neural_process.anp.ANP 111 | num_hidden: 128 112 | dim_x: 2 113 | dim_y: 1 114 | lr: 1e-4 115 | limiting_mode: physnp 116 | state_dict: ${base_dir}/train/physnp.pt 117 | physnp_constr: 118 | model: 119 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 120 | anp: 121 | _target_: deep_pdes.attentive_neural_process.anp.ANP 122 | num_hidden: 128 123 | dim_x: 2 124 | dim_y: 1 125 | lr: 1e-4 126 | non_linear_ineq_constraint: True 127 | state_dict: ${base_dir}/train/physnp.pt 128 | analysis: 129 | outdir: ${base_dir}/analysis/ 130 | inference_results : ${analysis.outdir}/inference_results.pt 131 | plot_df_path: ${analysis.outdir}/plot_df.pkl 132 | true_df_path: ${analysis.outdir}/true_df.pkl 133 | mse_at_t_df_path: ${analysis.outdir}/mse_at_t_df.pkl 134 | cons_df_path: ${analysis.outdir}/cons_df.pkl 135 | true_cons_df_path: ${analysis.outdir}/true_cons_df.pkl 136 | methods: 137 | anp: 138 | model: ${methods.anp.model} 139 | state_dict: ${methods.anp.state_dict} 140 | infer_path: ${base_dir}/analysis/anp_infer.pt 141 | overwrite: False 142 | truncated_version: False 143 | n_samples: 100 144 | shock_path: ${base_dir}/analysis/anp_shocks.pt 145 | pinp: 146 | model: ${methods.pinp.model} 147 | state_dict: ${methods.pinp.state_dict} 148 | infer_path: ${base_dir}/analysis/pinp_infer.pt 149 | overwrite: False 150 | truncated_version: False 151 | n_samples: 100 152 | shock_path: ${base_dir}/analysis/pinp_shocks.pt 153 | # physnp_notrain: 154 | # model: ${methods.physnp.model} 155 | # state_dict: ${methods.physnp.state_dict} 156 | # infer_path: ${base_dir}/analysis/physnp_notrain.pt 157 | # overwrite: False 158 | # constraint_precision: 1e8 159 | # anp_state_dict: ${methods.anp.state_dict} 160 | # truncated_version: True 161 | # constrained_version: True 162 | # infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 163 | # overwrite_constrained: False 164 | # nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 165 | # overwrite_nonneg: False 166 | # n_samples: 100 167 | physnp_limit: 168 | model: ${methods.physnp_limit.model} 169 | state_dict: 
${methods.physnp_limit.state_dict} 170 | infer_path: ${base_dir}/analysis/physnp_limit.pt 171 | overwrite: False 172 | constraint_precision: 1e6 173 | anp_state_dict: ${methods.anp.state_dict} 174 | truncated_version: False 175 | constrained_version: False 176 | # infer_path_constrained: ${base_dir}/analysis/physnp_notrain_cnstrd.pt 177 | # overwrite_constrained: false 178 | # nonneg_path: ${base_dir}/analysis/physnp_notrain_nonneg.pt 179 | # overwrite_nonneg: false 180 | n_samples: 100 181 | shock_path: ${base_dir}/analysis/physnp_limit_shocks.pt 182 | hcnp_notrain: 183 | model: ${methods.hcnp.model} 184 | state_dict: ${methods.hcnp.state_dict} 185 | infer_path: ${base_dir}/analysis/hcnp_notrain.pt 186 | overwrite: False 187 | constraint_precision: 1e8 188 | anp_state_dict: ${methods.anp.state_dict} 189 | truncated_version: False 190 | constrained_version: False 191 | n_samples: 100 192 | shock_path: ${base_dir}/analysis/hcnp_notrain_shocks.pt 193 | physnp_second_deriv: 194 | model: ${methods.physnp_second_deriv.model} 195 | state_dict: ${methods.physnp_second_deriv.state_dict} 196 | infer_path: ${base_dir}/analysis/physnp_second_deriv.pt 197 | overwrite: False 198 | constraint_precision: 1e8 199 | anp_state_dict: ${methods.anp.state_dict} 200 | truncated_version: True 201 | constrained_version: True 202 | infer_path_constrained: ${base_dir}/analysis/physnp_second_deriv_cnstrd.pt 203 | overwrite_constrained: False 204 | nonneg_path: ${base_dir}/analysis/physnp_second_deriv_nonneg.pt 205 | overwrite_nonneg: False 206 | n_samples: 100 207 | shock_path: ${base_dir}/analysis/hcnp_notrain_shocks.pt 208 | nt: 201 209 | nx: 201 210 | dpi: 500 211 | base_font_size: 15 212 | n_shock_samples: 500 213 | n_shock_samples_per_batch: 50 214 | t_range: ${datasets.train.t_range} 215 | x_range: 216 | - 0.0 217 | - 1.0 218 | gpu: "cuda:0" 219 | plot_shock: True 220 | t_of_interest: 221 | - 0.05 222 | x_of_interest: 223 | - 0.32 224 | fids_of_interest: 225 | - "1" 226 | nice_names: 227 | anp: "ANP" 228 | pinp: "SoftC-ANP" 229 | hcnp_notrain: "HardC-ANP" 230 | # physnp_notrain: "PhysNP" 231 | physnp_limit: "ProbConserv-ANP" 232 | # physnp_notrain_cnstrd: "PhysNP (constrained)" 233 | # physnp_second_deriv: "PhysNP (diffusion)" 234 | colors: 235 | - "#F8766D" 236 | - "#7CAE00" 237 | - "#00BFC4" 238 | - "#C77CFF" 239 | params_ordered: 240 | - 0.6 241 | - 0.5 242 | mse_plot_width: 6 243 | mse_plot_height: 3 244 | cons_plot_width: 6 245 | cons_plot_height: 3 246 | shock_plot_width: 6 247 | shock_plot_height: 9 248 | time_plot_width: 6.8 249 | time_plot_height: 1.7 250 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/experiments/3b_heat_var_c.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | base_dir: ./output/paper/3b_heat_var_c 3 | datasets: 4 | c_min: 1 5 | c_max: 5 6 | save_path: ${base_dir}/datasets 7 | train: 8 | _target_: deep_pdes.datasets.pme.HeatEquation 9 | n_functions: 10000 10 | n_contexts_t: 10 11 | n_contexts_x: 10 12 | n_targets_t: 10 13 | n_targets_x: 10 14 | batch_size: 250 15 | conductivity_min: ${datasets.c_min} 16 | conductivity_max: ${datasets.c_max} 17 | valid: 18 | _target_: deep_pdes.datasets.pme.HeatEquation 19 | n_functions: 100 20 | n_contexts_t: ${datasets.train.n_contexts_t} 21 | n_contexts_x: ${datasets.train.n_contexts_x} 22 | n_targets_t: ${datasets.train.n_targets_t} 23 | n_targets_x: ${datasets.train.n_targets_x} 24 | batch_size: 250 25 | conductivity_min: 
${datasets.c_min} 26 | conductivity_max: ${datasets.c_max} 27 | test: 28 | _target_: deep_pdes.datasets.pme.HeatEquation 29 | n_functions: 100 30 | n_contexts_t: ${datasets.train.n_contexts_t} 31 | n_contexts_x: ${datasets.train.n_contexts_x} 32 | n_targets_t: ${datasets.train.n_targets_t} 33 | n_targets_x: ${datasets.train.n_targets_x} 34 | batch_size: 250 35 | conductivities: 36 | - 1 37 | - 5 38 | methods: 39 | anp: 40 | model: 41 | _target_: deep_pdes.attentive_neural_process.anp.ANP 42 | num_hidden: 128 43 | dim_x: 2 44 | dim_y: 1 45 | lr: 1e-4 46 | state_dict: ${base_dir}/train/anp.pt 47 | physnp: 48 | model: 49 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 50 | anp: 51 | _target_: deep_pdes.attentive_neural_process.anp.ANP 52 | num_hidden: 128 53 | dim_x: 2 54 | dim_y: 1 55 | lr: 1e-4 56 | constraint_precision_train: 1e12 57 | train_precision: False 58 | limiting_mode: physnp 59 | state_dict: 60 | # state_dict: ${base_dir}/train/physnp.pt 61 | physnp_second_deriv: 62 | model: 63 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 64 | anp: 65 | _target_: deep_pdes.attentive_neural_process.anp.ANP 66 | num_hidden: 128 67 | dim_x: 2 68 | dim_y: 1 69 | lr: 1e-4 70 | constraint_precision_train: 1e12 71 | train_precision: False 72 | second_deriv_alpha: 0.9 73 | state_dict: 74 | hcnp: 75 | model: 76 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 77 | anp: 78 | _target_: deep_pdes.attentive_neural_process.anp.ANP 79 | num_hidden: 128 80 | dim_x: 2 81 | dim_y: 1 82 | lr: 1e-4 83 | constraint_precision_train: 1e12 84 | train_precision: False 85 | limiting_mode: hcnp 86 | state_dict: 87 | # state_dict: ${base_dir}/train/physnp.pt 88 | pinp: 89 | model: 90 | _target_: deep_pdes.attentive_neural_process.softc.PINP 91 | anp: 92 | _target_: deep_pdes.attentive_neural_process.anp.ANP 93 | num_hidden: 128 94 | dim_x: 2 95 | dim_y: 1 96 | pressure_fn: 97 | _target_: deep_pdes.attentive_neural_process.softc.HeatPressureFn 98 | pinns_lambda: 1.0 99 | lr: 1e-4 100 | state_dict: ${base_dir}/train/pinp.pt 101 | analysis: 102 | outdir: ${base_dir}/analysis/ 103 | inference_results : ${analysis.outdir}/inference_results.pt 104 | plot_df_path: ${analysis.outdir}/plot_df.pkl 105 | true_df_path: ${analysis.outdir}/true_df.pkl 106 | mse_at_t_df_path: ${analysis.outdir}/mse_at_t_df.pkl 107 | cons_df_path: ${analysis.outdir}/cons_df.pkl 108 | true_cons_df_path: ${analysis.outdir}/true_cons_df.pkl 109 | methods: 110 | anp: 111 | model: ${methods.anp.model} 112 | state_dict: ${methods.anp.state_dict} 113 | infer_path: ${base_dir}/analysis/anp_infer.pt 114 | overwrite: False 115 | truncated_version: False 116 | n_samples: 100 117 | pinp: 118 | model: ${methods.pinp.model} 119 | state_dict: ${methods.pinp.state_dict} 120 | infer_path: ${base_dir}/analysis/pinp_infer.pt 121 | overwrite: False 122 | truncated_version: False 123 | n_samples: 100 124 | physnp_notrain: 125 | model: ${methods.physnp.model} 126 | state_dict: ${methods.physnp.state_dict} 127 | infer_path: ${base_dir}/analysis/physnp_notrain.pt 128 | overwrite: False 129 | constraint_precision: 1e12 130 | anp_state_dict: ${methods.anp.state_dict} 131 | truncated_version: False 132 | constrained_version: False 133 | n_samples: 100 134 | hcnp_notrain: 135 | model: ${methods.hcnp.model} 136 | state_dict: ${methods.hcnp.state_dict} 137 | infer_path: ${base_dir}/analysis/hcnp_notrain.pt 138 | overwrite: False 139 | constraint_precision: 1e12 140 | anp_state_dict: ${methods.anp.state_dict} 141 | 
truncated_version: False 142 | constrained_version: False 143 | n_samples: 100 144 | physnp_second_deriv: 145 | model: ${methods.physnp_second_deriv.model} 146 | state_dict: ${methods.physnp_second_deriv.state_dict} 147 | infer_path: ${base_dir}/analysis/physnp_second_deriv.pt 148 | overwrite: False 149 | constraint_precision: 1e12 150 | anp_state_dict: ${methods.anp.state_dict} 151 | truncated_version: False 152 | constrained_version: False 153 | n_samples: 100 154 | nt: 201 155 | nx: 201 156 | dpi: 500 157 | base_font_size: 15 158 | n_shock_samples: 250 159 | n_shock_samples_per_batch: 50 160 | t_range: 161 | - 0.0 162 | - 1.0 163 | x_range: 164 | - 0.0 165 | - 6.283185307179586 166 | gpu: "cuda:0" 167 | plot_shock: True 168 | t_of_interest: 169 | - 0.5 170 | x_of_interest: 171 | - 0.32 172 | fids_of_interest: 173 | - "1" 174 | nice_names: 175 | anp: "ANP" 176 | pinp: "SoftC-ANP" 177 | hcnp_notrain: "HardC-ANP" 178 | physnp_notrain: "ProbConserv-ANP" 179 | # physnp_second_deriv: "ProbConserv-ANP (w/diffusion)" 180 | colors: 181 | - "#F8766D" 182 | - "#7CAE00" 183 | - "#00BFC4" 184 | - "#C77CFF" 185 | - "#000000" 186 | # colors: 187 | # ANP: "#F8766D" 188 | # "ANP+SoftC": "#7CAE00" 189 | # "ANP+HardC": "#00BFC4" 190 | # PhysNP: "#C77CFF" 191 | params_ordered: 192 | - 1 193 | - 5 194 | mse_plot_width: 6 195 | mse_plot_height: 3 196 | cons_plot_width: 4.5 197 | cons_plot_height: 3 198 | shock_plot_width: 6 199 | shock_plot_height: 6 200 | time_plot_width: 8 201 | time_plot_height: 8 202 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/experiments/4b_advection_var_a.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | base_dir: ./output/paper/4b_advection_var_a 3 | datasets: 4 | a_min: 1.0 5 | a_max: 5.0 6 | save_path: ${base_dir}/datasets 7 | dataset_overwrite: True 8 | train: 9 | _target_: deep_pdes.datasets.pme.LinearAdvection 10 | n_functions: 10000 11 | n_contexts_t: 10 12 | n_contexts_x: 10 13 | n_targets_t: 10 14 | n_targets_x: 10 15 | t_range: 16 | - 0.0 17 | - 1.0 18 | batch_size: 250 19 | a_lim: 20 | - ${datasets.a_min} 21 | - ${datasets.a_max} 22 | valid: 23 | _target_: deep_pdes.datasets.pme.LinearAdvection 24 | n_functions: 100 25 | n_contexts_t: ${datasets.train.n_contexts_t} 26 | n_contexts_x: ${datasets.train.n_contexts_x} 27 | n_targets_t: ${datasets.train.n_targets_t} 28 | n_targets_x: ${datasets.train.n_targets_x} 29 | batch_size: 250 30 | a_lim: 31 | - ${datasets.a_min} 32 | - ${datasets.a_max} 33 | t_range: ${datasets.train.t_range} 34 | test: 35 | _target_: deep_pdes.datasets.pme.LinearAdvection 36 | n_functions: 100 37 | n_contexts_t: ${datasets.train.n_contexts_t} 38 | n_contexts_x: ${datasets.train.n_contexts_x} 39 | n_targets_t: ${datasets.train.n_targets_t} 40 | n_targets_x: ${datasets.train.n_targets_x} 41 | batch_size: 250 42 | t_range: ${datasets.train.t_range} 43 | a_vals: 44 | - 1.0 45 | - 3.0 46 | methods: 47 | anp: 48 | model: 49 | _target_: deep_pdes.attentive_neural_process.anp.ANP 50 | num_hidden: 128 51 | dim_x: 2 52 | dim_y: 1 53 | lr: 1e-4 54 | state_dict: ${base_dir}/train/anp.pt 55 | hcnp: 56 | model: 57 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 58 | anp: 59 | _target_: deep_pdes.attentive_neural_process.anp.ANP 60 | num_hidden: 128 61 | dim_x: 2 62 | dim_y: 1 63 | lr: 1e-4 64 | constraint_precision_train: 1e5 65 | train_precision: False 66 | limiting_mode: hcnp 67 | state_dict: 68 | pinp: 69 | model: 70 
| _target_: deep_pdes.attentive_neural_process.softc.SoftcANP 71 | anp: 72 | _target_: deep_pdes.attentive_neural_process.anp.ANP 73 | num_hidden: 128 74 | dim_x: 2 75 | dim_y: 1 76 | differential_penalty: 77 | _target_: deep_pdes.attentive_neural_process.softc.LinearAdvectionDifferentialPenalty 78 | pinns_lambda: 1.0 79 | lr: 1e-4 80 | state_dict: ${base_dir}/train/pinp.pt 81 | physnp_limit: 82 | model: 83 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 84 | anp: 85 | _target_: deep_pdes.attentive_neural_process.anp.ANP 86 | num_hidden: 128 87 | dim_x: 2 88 | dim_y: 1 89 | lr: 1e-4 90 | limiting_mode: physnp 91 | state_dict: 92 | analysis: 93 | outdir: ${base_dir}/analysis/ 94 | inference_results : ${analysis.outdir}/inference_results.pt 95 | plot_df_path: ${analysis.outdir}/plot_df.pkl 96 | true_df_path: ${analysis.outdir}/true_df.pkl 97 | mse_at_t_df_path: ${analysis.outdir}/mse_at_t_df.pkl 98 | cons_df_path: ${analysis.outdir}/cons_df.pkl 99 | true_cons_df_path: ${analysis.outdir}/true_cons_df.pkl 100 | methods: 101 | anp: 102 | model: ${methods.anp.model} 103 | state_dict: ${methods.anp.state_dict} 104 | infer_path: ${base_dir}/analysis/anp_infer.pt 105 | overwrite: False 106 | truncated_version: False 107 | n_samples: 100 108 | shock_path: ${base_dir}/analysis/anp_shocks.pt 109 | pinp: 110 | model: ${methods.pinp.model} 111 | state_dict: ${methods.pinp.state_dict} 112 | infer_path: ${base_dir}/analysis/pinp_infer.pt 113 | overwrite: False 114 | truncated_version: False 115 | n_samples: 100 116 | shock_path: ${base_dir}/analysis/pinp_shocks.pt 117 | physnp_limit: 118 | model: ${methods.physnp_limit.model} 119 | state_dict: ${methods.physnp_limit.state_dict} 120 | infer_path: ${base_dir}/analysis/physnp_limit.pt 121 | overwrite: False 122 | constraint_precision: 1e6 123 | anp_state_dict: ${methods.anp.state_dict} 124 | truncated_version: False 125 | constrained_version: False 126 | n_samples: 100 127 | shock_path: ${base_dir}/analysis/physnp_limit_shocks.pt 128 | hcnp_notrain: 129 | model: ${methods.hcnp.model} 130 | state_dict: ${methods.hcnp.state_dict} 131 | infer_path: ${base_dir}/analysis/hcnp_notrain.pt 132 | overwrite: False 133 | constraint_precision: 1e8 134 | anp_state_dict: ${methods.anp.state_dict} 135 | truncated_version: False 136 | constrained_version: False 137 | n_samples: 100 138 | shock_path: ${base_dir}/analysis/hcnp_notrain_shocks.pt 139 | nt: 201 140 | nx: 201 141 | dpi: 500 142 | base_font_size: 15 143 | n_shock_samples: 500 144 | n_shock_samples_per_batch: 50 145 | t_range: ${datasets.train.t_range} 146 | x_range: 147 | - 0.0 148 | - 1.0 149 | gpu: "cuda:0" 150 | plot_shock: True 151 | t_of_interest: 152 | - 0.10 153 | x_of_interest: 154 | - 0.32 155 | fids_of_interest: 156 | - "1" 157 | nice_names: 158 | anp: "ANP" 159 | pinp: "SoftC-ANP" 160 | hcnp_notrain: "HardC-ANP" 161 | physnp_limit: "ProbConserv-ANP" 162 | colors: 163 | - "#F8766D" 164 | - "#7CAE00" 165 | - "#00BFC4" 166 | - "#C77CFF" 167 | params_ordered: 168 | - 1.0 169 | - 3.0 170 | mse_plot_width: 6 171 | mse_plot_height: 3 172 | cons_plot_width: 6 173 | cons_plot_height: 3 174 | shock_plot_width: 6 175 | shock_plot_height: 9 176 | time_plot_width: 12 177 | time_plot_height: 6 178 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/experiments/5b_burgers_var_a.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | base_dir: ./output/paper/5b_burgers_var_a 3 | 
datasets: 4 | a_min: 1.0 5 | a_max: 4.0 6 | save_path: ${base_dir}/datasets 7 | dataset_overwrite: True 8 | train: 9 | _target_: deep_pdes.datasets.pme.Burgers 10 | n_functions: 10000 11 | n_contexts_t: 10 12 | n_contexts_x: 10 13 | n_targets_t: 10 14 | n_targets_x: 10 15 | t_range: 16 | - 0.0 17 | - 1.0 18 | batch_size: 250 19 | a_lim: 20 | - ${datasets.a_min} 21 | - ${datasets.a_max} 22 | valid: 23 | _target_: deep_pdes.datasets.pme.Burgers 24 | n_functions: 100 25 | n_contexts_t: ${datasets.train.n_contexts_t} 26 | n_contexts_x: ${datasets.train.n_contexts_x} 27 | n_targets_t: ${datasets.train.n_targets_t} 28 | n_targets_x: ${datasets.train.n_targets_x} 29 | batch_size: 250 30 | a_lim: 31 | - ${datasets.a_min} 32 | - ${datasets.a_max} 33 | t_range: ${datasets.train.t_range} 34 | test: 35 | _target_: deep_pdes.datasets.pme.Burgers 36 | n_functions: 100 37 | n_contexts_t: ${datasets.train.n_contexts_t} 38 | n_contexts_x: ${datasets.train.n_contexts_x} 39 | n_targets_t: ${datasets.train.n_targets_t} 40 | n_targets_x: ${datasets.train.n_targets_x} 41 | batch_size: 250 42 | t_range: ${datasets.train.t_range} 43 | a_vals: 44 | - 1.0 45 | - 3.0 46 | methods: 47 | anp: 48 | model: 49 | _target_: deep_pdes.attentive_neural_process.anp.ANP 50 | num_hidden: 128 51 | dim_x: 2 52 | dim_y: 1 53 | lr: 1e-4 54 | state_dict: ${base_dir}/train/anp.pt 55 | hcnp: 56 | model: 57 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 58 | anp: 59 | _target_: deep_pdes.attentive_neural_process.anp.ANP 60 | num_hidden: 128 61 | dim_x: 2 62 | dim_y: 1 63 | lr: 1e-4 64 | constraint_precision_train: 1e5 65 | train_precision: False 66 | limiting_mode: hcnp 67 | state_dict: 68 | pinp: 69 | model: 70 | _target_: deep_pdes.attentive_neural_process.softc.SoftcANP 71 | anp: 72 | _target_: deep_pdes.attentive_neural_process.anp.ANP 73 | num_hidden: 128 74 | dim_x: 2 75 | dim_y: 1 76 | differential_penalty: 77 | _target_: deep_pdes.attentive_neural_process.softc.BurgersDifferentialPenalty 78 | pinns_lambda: 1.0 79 | lr: 1e-4 80 | state_dict: ${base_dir}/train/pinp.pt 81 | physnp_limit: 82 | model: 83 | _target_: deep_pdes.attentive_neural_process.probconserv.PhysNP 84 | anp: 85 | _target_: deep_pdes.attentive_neural_process.anp.ANP 86 | num_hidden: 128 87 | dim_x: 2 88 | dim_y: 1 89 | lr: 1e-4 90 | limiting_mode: physnp 91 | state_dict: 92 | analysis: 93 | outdir: ${base_dir}/analysis/ 94 | inference_results : ${analysis.outdir}/inference_results.pt 95 | plot_df_path: ${analysis.outdir}/plot_df.pkl 96 | true_df_path: ${analysis.outdir}/true_df.pkl 97 | mse_at_t_df_path: ${analysis.outdir}/mse_at_t_df.pkl 98 | cons_df_path: ${analysis.outdir}/cons_df.pkl 99 | true_cons_df_path: ${analysis.outdir}/true_cons_df.pkl 100 | methods: 101 | anp: 102 | model: ${methods.anp.model} 103 | state_dict: ${methods.anp.state_dict} 104 | infer_path: ${base_dir}/analysis/anp_infer.pt 105 | overwrite: False 106 | truncated_version: False 107 | n_samples: 100 108 | shock_path: ${base_dir}/analysis/anp_shocks.pt 109 | pinp: 110 | model: ${methods.pinp.model} 111 | state_dict: ${methods.pinp.state_dict} 112 | infer_path: ${base_dir}/analysis/pinp_infer.pt 113 | overwrite: False 114 | truncated_version: False 115 | n_samples: 100 116 | shock_path: ${base_dir}/analysis/pinp_shocks.pt 117 | physnp_limit: 118 | model: ${methods.physnp_limit.model} 119 | state_dict: ${methods.physnp_limit.state_dict} 120 | infer_path: ${base_dir}/analysis/physnp_limit.pt 121 | overwrite: False 122 | constraint_precision: 1e6 123 | anp_state_dict: 
${methods.anp.state_dict} 124 | truncated_version: False 125 | constrained_version: False 126 | n_samples: 100 127 | shock_path: ${base_dir}/analysis/physnp_limit_shocks.pt 128 | hcnp_notrain: 129 | model: ${methods.hcnp.model} 130 | state_dict: ${methods.hcnp.state_dict} 131 | infer_path: ${base_dir}/analysis/hcnp_notrain.pt 132 | overwrite: False 133 | constraint_precision: 1e8 134 | anp_state_dict: ${methods.anp.state_dict} 135 | truncated_version: False 136 | constrained_version: False 137 | n_samples: 100 138 | shock_path: ${base_dir}/analysis/hcnp_notrain_shocks.pt 139 | nt: 201 140 | nx: 201 141 | dpi: 500 142 | base_font_size: 15 143 | n_shock_samples: 500 144 | n_shock_samples_per_batch: 50 145 | t_range: ${datasets.train.t_range} 146 | x_range: 147 | - -1.0 148 | - 1.0 149 | gpu: "cuda:0" 150 | plot_shock: True 151 | t_of_interest: 152 | - 0.5 153 | x_of_interest: 154 | - 0.32 155 | fids_of_interest: 156 | - "1" 157 | nice_names: 158 | anp: "ANP" 159 | pinp: "SoftC-ANP" 160 | hcnp_notrain: "HardC-ANP" 161 | physnp_limit: "ProbConserv-ANP" 162 | colors: 163 | - "#F8766D" 164 | - "#7CAE00" 165 | - "#00BFC4" 166 | - "#C77CFF" 167 | params_ordered: 168 | - 1.0 169 | - 3.0 170 | mse_plot_width: 6 171 | mse_plot_height: 3 172 | cons_plot_width: 6 173 | cons_plot_height: 3 174 | shock_plot_width: 6 175 | shock_plot_height: 9 176 | time_plot_width: 12 177 | time_plot_height: 6 178 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_anp.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_anp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.anp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.anp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_physnp.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_physnp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.physnp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.physnp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_physnp_fixedvar.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_physnp_fixedvar 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.physnp_fixedvar.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.physnp_fixedvar.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_physnp_re.yaml: 
-------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_physnp_re 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.physnp_re.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.physnp_re.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_physnp_rhs.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_physnp_rhs 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.physnp_rhs.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.physnp_rhs.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_pinp.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_pinp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_pinp_1e1.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_pinp_1e1 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp_1e1.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp_1e1.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_pinp_1e2.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_pinp_1e2 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp_1e2.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp_1e2.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_pinp_1e6.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_pinp_1e6 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp_1e6.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 
| logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp_1e6.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_pinp_1en1.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_pinp_1en1 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp_1en1.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp_1en1.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/1b_pme_var_m_pinp_1en2.yaml: -------------------------------------------------------------------------------- 1 | name: 1b_pme_var_m_pinp_1en2 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp_1en2.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp_1en2.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/2b_stefan_var_p_anp.yaml: -------------------------------------------------------------------------------- 1 | name: 2b_stefan_var_p_anp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.anp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.anp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/2b_stefan_var_p_physnp.yaml: -------------------------------------------------------------------------------- 1 | name: 2b_stefan_var_p_physnp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.physnp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.physnp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/2b_stefan_var_p_physnp_fixedvar.yaml: -------------------------------------------------------------------------------- 1 | name: 2b_stefan_var_p_physnp_fixedvar 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.physnp_fixedvar.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.physnp_fixedvar.model} 15 | 16 | -------------------------------------------------------------------------------- 
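Note on the train configs in this directory: values such as ${methods.physnp_fixedvar.state_dict}, ${base_dir}/${train.name}, and ${datasets.save_path} are OmegaConf interpolations that Hydra resolves against the composed config at run time, which is why each file only has to name the method it trains. Below is a minimal, self-contained sketch of that resolution; the dictionary is a toy stand-in whose keys mirror the layout of these configs, not the repository's actual config tree, and it assumes omegaconf is importable (it ships as a dependency of hydra-core).

from omegaconf import OmegaConf

# Toy stand-in for the composed config; the keys mirror the layout used by the
# experiment and train configs, but the values here are illustrative only.
cfg = OmegaConf.create(
    {
        "base_dir": "./output/paper/1b_pme_var_m",
        "methods": {"anp": {"state_dict": "${base_dir}/train/anp.pt"}},
        "train": {
            "name": "1b_pme_var_m_anp",
            "save_dir": "${base_dir}/${train.name}",
            "state_dict_path": "${methods.anp.state_dict}",
        },
    }
)

# Interpolations resolve lazily, on access, against the config root.
print(cfg.train.save_dir)         # ./output/paper/1b_pme_var_m/1b_pme_var_m_anp
print(cfg.train.state_dict_path)  # ./output/paper/1b_pme_var_m/train/anp.pt

Because resolution happens on access, the per-method state-dict paths stay in sync with base_dir without being repeated in every train config.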
/deep_pdes/experiments/conf/train/2b_stefan_var_p_physnp_re.yaml: -------------------------------------------------------------------------------- 1 | name: 2b_stefan_var_p_physnp_re 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.physnp_re.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.physnp_re.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/2b_stefan_var_p_pinp.yaml: -------------------------------------------------------------------------------- 1 | name: 2b_stefan_var_p_pinp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/3b_heat_var_c_anp.yaml: -------------------------------------------------------------------------------- 1 | name: 3b_heat_var_c_anp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.anp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.anp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/3b_heat_var_c_pinp.yaml: -------------------------------------------------------------------------------- 1 | name: 3b_heat_var_c_pinp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/4b_advection_var_a_anp.yaml: -------------------------------------------------------------------------------- 1 | name: 4b_advection_var_a_anp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.anp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.anp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/4b_advection_var_a_pinp.yaml: -------------------------------------------------------------------------------- 1 | name: 4b_advection_var_a_pinp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | 
check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/5b_burgers_var_a_anp.yaml: -------------------------------------------------------------------------------- 1 | name: 5b_burgers_var_a_anp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.anp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.anp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/conf/train/5b_burgers_var_a_pinp.yaml: -------------------------------------------------------------------------------- 1 | name: 5b_burgers_var_a_pinp 2 | version: "1" 3 | save_dir: ${base_dir}/${train.name} 4 | dataset_path: ${datasets.save_path} 5 | state_dict_path: ${methods.pinp.state_dict} 6 | trainer: 7 | max_epochs: 500 8 | check_val_every_n_epoch: 10 9 | logger: 10 | save_dir: ${train.save_dir} 11 | name: ${train.name} 12 | version: ${train.version} 13 | datasets: ${datasets} 14 | model: ${methods.pinp.model} 15 | 16 | -------------------------------------------------------------------------------- /deep_pdes/experiments/generate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hydra 4 | import torch 5 | from hydra.utils import instantiate 6 | 7 | from deep_pdes.datasets.base import ANPDataset 8 | 9 | 10 | def generate(cfg): 11 | dataset_path = Path(cfg.datasets.save_path) 12 | overwrite: bool = cfg.datasets.dataset_overwrite 13 | dataset_path.mkdir(parents=True, exist_ok=True) 14 | for dataset_type in ("train", "valid", "test", "pinn_grid_train", "pinn_grid_valid"): 15 | path = dataset_path / f"{dataset_type}.pt" 16 | if not overwrite: 17 | assert not path.exists() 18 | dataset_cfg = cfg.datasets.get(dataset_type) 19 | if dataset_cfg is not None: 20 | dataset: ANPDataset = instantiate(dataset_cfg) 21 | data_to_save = (dataset.tensors, dataset.parameters) 22 | torch.save(data_to_save, path) 23 | 24 | 25 | @hydra.main(version_base=None, config_path="conf", config_name="config") 26 | def main(cfg): 27 | generate(cfg) 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /deep_pdes/experiments/output/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/.DS_Store -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/.DS_Store -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/1b_pme_var_m/datasets/test.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/1b_pme_var_m/datasets/test.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/1b_pme_var_m/datasets/train.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/1b_pme_var_m/datasets/train.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/1b_pme_var_m/datasets/valid.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/1b_pme_var_m/datasets/valid.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/1b_pme_var_m/train/anp.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/1b_pme_var_m/train/anp.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/1b_pme_var_m/train/physnp.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/1b_pme_var_m/train/physnp.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/1b_pme_var_m/train/pinp.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/1b_pme_var_m/train/pinp.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/2b_stefan_var_p/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/2b_stefan_var_p/.DS_Store -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/2b_stefan_var_p/datasets/test.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/2b_stefan_var_p/datasets/test.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/2b_stefan_var_p/datasets/train.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/2b_stefan_var_p/datasets/train.pt -------------------------------------------------------------------------------- 
/deep_pdes/experiments/output/paper/2b_stefan_var_p/datasets/valid.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/2b_stefan_var_p/datasets/valid.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/2b_stefan_var_p/train/anp.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/2b_stefan_var_p/train/anp.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/2b_stefan_var_p/train/physnp.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/2b_stefan_var_p/train/physnp.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/output/paper/2b_stefan_var_p/train/physnp_second_deriv.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/deep_pdes/experiments/output/paper/2b_stefan_var_p/train/physnp_second_deriv.pt -------------------------------------------------------------------------------- /deep_pdes/experiments/train.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hydra 4 | import torch 5 | from hydra.utils import instantiate 6 | from pytorch_lightning import Trainer 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | 9 | 10 | def train(cfg): # noqa: WPS210 11 | 12 | dataset_path = Path(cfg.train.dataset_path) 13 | 14 | model = instantiate(cfg.train.model) 15 | train_load_path = dataset_path / "train.pt" 16 | valid_load_path = dataset_path / "valid.pt" 17 | train_dataset = instantiate(cfg.train.datasets.train, load_path=train_load_path) 18 | valid_dataset = instantiate(cfg.train.datasets.valid, load_path=valid_load_path) 19 | 20 | train_loader = train_dataset.dataloader() 21 | val_loader = valid_dataset.dataloader() 22 | 23 | checkpoint_callback: ModelCheckpoint = instantiate(cfg.train.checkpoint_callback) 24 | trainer: Trainer = instantiate(cfg.train.trainer, callbacks=[checkpoint_callback]) 25 | trainer.fit(model, train_loader, val_loader) 26 | 27 | if cfg.train.save_best_model: 28 | model_checkpoint = torch.load(checkpoint_callback.best_model_path, map_location="cpu") 29 | 30 | model_state_dict = model_checkpoint["state_dict"] 31 | else: 32 | model_state_dict = model.state_dict() 33 | state_dict_path = Path(cfg.train.state_dict_path) 34 | if not state_dict_path.parent.exists(): 35 | state_dict_path.parent.mkdir(parents=True) 36 | torch.save(model_state_dict, state_dict_path) 37 | 38 | 39 | @hydra.main(version_base=None, config_path="conf", config_name="config") 40 | def main(cfg): 41 | train(cfg) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | 3 | exclude = deep_pdes/experiments 4 | 5 | 
-------------------------------------------------------------------------------- /py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/py.typed -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "deep_pdes" 3 | version = "0.1.0" 4 | description = "Code to accompany paper 'Learning Physical Models that Can Respect Conservation Laws'" 5 | packages = [{include = "deep_pdes"}] 6 | authors = ["Derek Hansen"] 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.8,<4.0" 10 | numpy = "^1.22.4" 11 | scipy = "^1.8.1" 12 | matplotlib = "^3.5.2" 13 | torch = "^1.11.0" 14 | pytorch-lightning = "^1.6.4" 15 | notebook = "^6.4.12" 16 | einops = "^0.4.1" 17 | pykalman = "^0.9.5" 18 | hydra-core = "^1.2.0" 19 | plotnine = "^0.9.0" 20 | qpsolvers = {extras = ["starter_solvers"], version = "^2.2.0"} 21 | joblib = "^1.1.0" 22 | patchworklib = "^0.4.7" 23 | icontract = "^2.6.2" 24 | 25 | [tool.poetry.dev-dependencies] 26 | black = "^22.3.0" 27 | flake8 = "^4.0.1" 28 | pytest = "^7.1.2" 29 | pylint = "^2.14.0" 30 | pytest-cov = "^3.0.0" 31 | wemake-python-styleguide = "^0.16.1" 32 | mypy = "^0.961" 33 | 34 | [tool.black] 35 | line-length = 100 36 | target-version = ['py38'] 37 | 38 | [build-system] 39 | requires = ["poetry-core>=1.0.0"] 40 | build-backend = "poetry.core.masonry.api" 41 | -------------------------------------------------------------------------------- /resources/diffusion_eqtn_conserv_mass.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/resources/diffusion_eqtn_conserv_mass.png -------------------------------------------------------------------------------- /resources/schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/resources/schematic.png -------------------------------------------------------------------------------- /resources/stefan_shock_position_downstream_task: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/resources/stefan_shock_position_downstream_task -------------------------------------------------------------------------------- /resources/stefan_solution_profile_UQ: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/probconserv/bfd388c8d3b2926dcdcbe12de6658616e2fa92d4/resources/stefan_solution_profile_UQ -------------------------------------------------------------------------------- /tests/test_pme.py: -------------------------------------------------------------------------------- 1 | from hydra import compose, initialize 2 | 3 | from deep_pdes.experiments.generate import generate 4 | from deep_pdes.experiments.train import train 5 | 6 | 7 | def run_pme(experiment: str, model: str, extra_overrides=None): 8 | with initialize( 9 | version_base=None, config_path="../deep_pdes/experiments/conf", job_name="test_app" 10 | ): 11 | 12 | overrides = [ 13 | f"+experiments={experiment}", 14 | 
f"+train={experiment}_{model}", 15 | "base_dir=./output/test", 16 | "train.trainer.accelerator=", 17 | "train.trainer.max_epochs=2", 18 | "train.trainer.check_val_every_n_epoch=2", 19 | "datasets.dataset_overwrite=True", 20 | "datasets.train.n_functions=2", 21 | "datasets.valid.n_functions=1", 22 | "analysis.gpu=cpu", 23 | "analysis.nx=11", 24 | "analysis.nt=11", 25 | ] 26 | if extra_overrides is not None: 27 | overrides += extra_overrides 28 | cfg = compose(config_name="config", overrides=overrides) 29 | generate(cfg) 30 | train(cfg) 31 | 32 | 33 | def test_anp_pme(): 34 | run_pme("1b_pme_var_m", "anp") 35 | 36 | 37 | def test_anp_stefan(): 38 | run_pme("2b_stefan_var_p", "anp") 39 | 40 | 41 | def test_physnp_pme(): 42 | run_pme("1b_pme_var_m", "physnp") 43 | 44 | 45 | def test_physnp_stefan(): 46 | run_pme("2b_stefan_var_p", "physnp") 47 | --------------------------------------------------------------------------------