├── .flake8 ├── .github └── workflows │ └── run_tests.yml ├── .gitignore ├── CONTRIBUTING.md ├── DOCUMENTATION.md ├── LICENSE ├── README.md ├── assets └── latent_sde.gif ├── benchmarks ├── __init__.py ├── brownian.py └── profile_btree.py ├── diagnostics ├── __init__.py ├── inspection.py ├── ito_additive.py ├── ito_diagonal.py ├── ito_general.py ├── ito_scalar.py ├── run_all.py ├── stratonovich_additive.py ├── stratonovich_diagonal.py ├── stratonovich_general.py ├── stratonovich_scalar.py └── utils.py ├── examples ├── __init__.py ├── cont_ddpm.py ├── demo.ipynb ├── latent_sde.py ├── latent_sde_lorenz.py ├── sde_gan.py └── unet.py ├── pyproject.toml ├── setup.py ├── tests ├── __init__.py ├── problems.py ├── test_adjoint.py ├── test_brownian_interval.py ├── test_brownian_path.py ├── test_brownian_tree.py ├── test_sdeint.py └── utils.py └── torchsde ├── __init__.py ├── _brownian ├── __init__.py ├── brownian_base.py ├── brownian_interval.py └── derived.py ├── _core ├── __init__.py ├── adaptive_stepping.py ├── adjoint.py ├── adjoint_sde.py ├── base_sde.py ├── base_solver.py ├── better_abc.py ├── interp.py ├── methods │ ├── __init__.py │ ├── euler.py │ ├── euler_heun.py │ ├── heun.py │ ├── log_ode.py │ ├── midpoint.py │ ├── milstein.py │ ├── reversible_heun.py │ ├── srk.py │ └── tableaus │ │ ├── __init__.py │ │ ├── sra1.py │ │ ├── sra2.py │ │ ├── sra3.py │ │ ├── srid1.py │ │ └── srid2.py ├── misc.py └── sdeint.py ├── settings.py └── types.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = W291,W503,W504,E123,E126,E203,E402,E701 4 | per-file-ignores = __init__.py: F401 5 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For 
more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Run tests 5 | 6 | on: 7 | pull_request: 8 | schedule: 9 | - cron: "0 2 * * 6" 10 | workflow_dispatch: 11 | 12 | jobs: 13 | build: 14 | strategy: 15 | matrix: 16 | python-version: [ 3.8 ] 17 | os: [ ubuntu-latest, macOS-latest, windows-latest ] 18 | fail-fast: false 19 | runs-on: ${{ matrix.os }} 20 | steps: 21 | - name: Checkout code 22 | uses: actions/checkout@v4 23 | 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v4 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | cache: pip 29 | cache-dependency-path: setup.py 30 | 31 | - name: Install 32 | run: pip install pytest -e . --only-binary=numpy,scipy,matplotlib,torch 33 | env: 34 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu 35 | 36 | - name: Test with pytest 37 | run: python -m pytest -v 38 | 39 | lint: 40 | runs-on: ubuntu-latest 41 | steps: 42 | - name: Checkout code 43 | uses: actions/checkout@v4 44 | 45 | - uses: actions/setup-python@v4 46 | with: 47 | python-version: "3.11" 48 | cache: pip 49 | cache-dependency-path: setup.py 50 | 51 | - name: Lint with flake8 52 | run: | 53 | python -m pip install flake8 54 | python -m flake8 . 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .DS_Store 3 | .lib 4 | __pycache__ 5 | .ipynb_checkpoints 6 | diagnostics/plots/ 7 | build/ 8 | dist/ 9 | .vscode/ 10 | *.egg-info/ 11 | benchmarks/plots/ 12 | CMakeLists.txt 13 | restats 14 | *-darwin.so 15 | **.pyc 16 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. 
There are 4 | a few small guidelines to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of Differentiable SDE Solvers ![Python package](https://github.com/google-research/torchsde/actions/workflows/run_tests.yml/badge.svg) 2 | This library provides [stochastic differential equation (SDE)](https://en.wikipedia.org/wiki/Stochastic_differential_equation) solvers with GPU support and efficient backpropagation. 3 | 4 | --- 5 |

6 | 7 |

8 | 9 | ## Installation 10 | ```shell script 11 | pip install torchsde 12 | ``` 13 | 14 | **Requirements:** Python >=3.8 and PyTorch >=1.6.0. 15 | 16 | ## Documentation 17 | Available [here](./DOCUMENTATION.md). 18 | 19 | ## Examples 20 | ### Quick example 21 | ```python 22 | import torch 23 | import torchsde 24 | 25 | batch_size, state_size, brownian_size = 32, 3, 2 26 | t_size = 20 27 | 28 | class SDE(torch.nn.Module): 29 | noise_type = 'general' 30 | sde_type = 'ito' 31 | 32 | def __init__(self): 33 | super().__init__() 34 | self.mu = torch.nn.Linear(state_size, 35 | state_size) 36 | self.sigma = torch.nn.Linear(state_size, 37 | state_size * brownian_size) 38 | 39 | # Drift 40 | def f(self, t, y): 41 | return self.mu(y) # shape (batch_size, state_size) 42 | 43 | # Diffusion 44 | def g(self, t, y): 45 | return self.sigma(y).view(batch_size, 46 | state_size, 47 | brownian_size) 48 | 49 | sde = SDE() 50 | y0 = torch.full((batch_size, state_size), 0.1) 51 | ts = torch.linspace(0, 1, t_size) 52 | # Initial state y0, the SDE is solved over the interval [ts[0], ts[-1]]. 53 | # ys will have shape (t_size, batch_size, state_size) 54 | ys = torchsde.sdeint(sde, y0, ts) 55 | ``` 56 | 57 | ### Notebook 58 | 59 | [`examples/demo.ipynb`](examples/demo.ipynb) gives a short guide on how to solve SDEs, including subtle points such as fixing the randomness in the solver and the choice of *noise types*. 60 | 61 | ### Latent SDE 62 | 63 | [`examples/latent_sde.py`](examples/latent_sde.py) learns a *latent stochastic differential equation*, as in Section 5 of [\[1\]](https://arxiv.org/pdf/2001.01328.pdf). 64 | The example fits an SDE to data, whilst regularizing it to be like an [Ornstein-Uhlenbeck](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process) prior process. 65 | The model can be loosely viewed as a [variational autoencoder](https://en.wikipedia.org/wiki/Autoencoder#Variational_autoencoder_(VAE)) with its prior and approximate posterior being SDEs. 
This example can be run via 66 | ```shell script 67 | python -m examples.latent_sde --train-dir 68 | ``` 69 | The program outputs figures to the path specified by ``. 70 | Training should stabilize after 500 iterations with the default hyperparameters. 71 | 72 | ### Neural SDEs as GANs 73 | [`examples/sde_gan.py`](examples/sde_gan.py) learns an SDE as a GAN, as in [\[2\]](https://arxiv.org/abs/2102.03657), [\[3\]](https://arxiv.org/abs/2105.13493). The example trains an SDE as the generator of a GAN, whilst using a [neural CDE](https://github.com/patrick-kidger/NeuralCDE) [\[4\]](https://arxiv.org/abs/2005.08926) as the discriminator. This example can be run via 74 | 75 | ```shell script 76 | python -m examples.sde_gan 77 | ``` 78 | 79 | ## Citation 80 | 81 | If you found this codebase useful in your research, please consider citing either or both of: 82 | 83 | ``` 84 | @article{li2020scalable, 85 | title={Scalable gradients for stochastic differential equations}, 86 | author={Li, Xuechen and Wong, Ting-Kam Leonard and Chen, Ricky T. Q. and Duvenaud, David}, 87 | journal={International Conference on Artificial Intelligence and Statistics}, 88 | year={2020} 89 | } 90 | ``` 91 | 92 | ``` 93 | @article{kidger2021neuralsde, 94 | title={Neural {SDE}s as {I}nfinite-{D}imensional {GAN}s}, 95 | author={Kidger, Patrick and Foster, James and Li, Xuechen and Oberhauser, Harald and Lyons, Terry}, 96 | journal={International Conference on Machine Learning}, 97 | year={2021} 98 | } 99 | ``` 100 | 101 | ## References 102 | 103 | \[1\] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations". *International Conference on Artificial Intelligence and Statistics.* 2020. [[arXiv]](https://arxiv.org/pdf/2001.01328.pdf) 104 | 105 | \[2\] Patrick Kidger, James Foster, Xuechen Li, Harald Oberhauser, Terry Lyons. "Neural SDEs as Infinite-Dimensional GANs". *International Conference on Machine Learning* 2021. 
[[arXiv]](https://arxiv.org/abs/2102.03657) 106 | 107 | \[3\] Patrick Kidger, James Foster, Xuechen Li, Terry Lyons. "Efficient and Accurate Gradients for Neural SDEs". 2021. [[arXiv]](https://arxiv.org/abs/2105.13493) 108 | 109 | \[4\] Patrick Kidger, James Morrill, James Foster, Terry Lyons, "Neural Controlled Differential Equations for Irregular Time Series". *Neural Information Processing Systems* 2020. [[arXiv]](https://arxiv.org/abs/2005.08926) 110 | 111 | --- 112 | This is a research project, not an official Google product. 113 | -------------------------------------------------------------------------------- /assets/latent_sde.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/torchsde/eb3a00e31cbd56176270066ed2f62c394cf6acb7/assets/latent_sde.gif -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/torchsde/eb3a00e31cbd56176270066ed2f62c394cf6acb7/benchmarks/__init__.py -------------------------------------------------------------------------------- /benchmarks/brownian.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Compare the speed of 5 Brownian motion variants on problems of different sizes.""" 16 | import argparse 17 | import logging 18 | import os 19 | import time 20 | 21 | import numpy.random as npr 22 | import torch 23 | 24 | import torchsde 25 | from diagnostics import utils 26 | 27 | t0, t1 = 0.0, 1.0 28 | reps, steps = 3, 100 29 | small_batch_size, small_d = 128, 5 30 | large_batch_size, large_d = 256, 128 31 | huge_batch_size, huge_d = 512, 256 32 | 33 | 34 | def _time_query(bm, ts): 35 | now = time.perf_counter() 36 | for _ in range(reps): 37 | for ta, tb in zip(ts[:-1], ts[1:]): 38 | if ta > tb: 39 | ta, tb = tb, ta 40 | bm(ta, tb) 41 | return time.perf_counter() - now 42 | 43 | 44 | def _compare(w0, ts, msg=''): 45 | bm = torchsde.BrownianPath(t0=t0, w0=w0) 46 | bp_py_time = _time_query(bm, ts) 47 | logging.warning(f'{msg} (torchsde.BrownianPath): {bp_py_time:.4f}') 48 | 49 | bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, tol=1e-5) 50 | bt_py_time = _time_query(bm, ts) 51 | logging.warning(f'{msg} (torchsde.BrownianTree): {bt_py_time:.4f}') 52 | 53 | bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=w0.shape, dtype=w0.dtype, device=w0.device) 54 | bi_py_time = _time_query(bm, ts) 55 | logging.warning(f'{msg} (torchsde.BrownianInterval): {bi_py_time:.4f}') 56 | 57 | return bp_py_time, bt_py_time, bi_py_time 58 | 59 | 60 | def sequential_access(): 61 | ts = torch.linspace(t0, t1, steps=steps) 62 | 63 | w0 = torch.zeros(small_batch_size, small_d).to(device) 64 | bp_py_time_s, bt_py_time_s, bi_py_time_s = _compare(w0, ts, msg='small sequential access') 65 | 66 | w0 = torch.zeros(large_batch_size, large_d).to(device) 67 | bp_py_time_l, bt_py_time_l, bi_py_time_l = _compare(w0, ts, msg='large sequential access') 68 | 69 | w0 = torch.zeros(huge_batch_size, huge_d).to(device) 70 | bp_py_time_h, bt_py_time_h, bi_py_time_h = _compare(w0, ts, msg="huge sequential access") 71 | 72 | img_path = os.path.join('.', 'benchmarks', f'plots-{device}', 
'sequential_access.png') 73 | if not os.path.exists(os.path.dirname(img_path)): 74 | os.makedirs(os.path.dirname(img_path)) 75 | 76 | xaxis = [small_batch_size * small_d, large_batch_size * large_batch_size, huge_batch_size * huge_d] 77 | 78 | utils.swiss_knife_plotter( 79 | img_path, 80 | plots=[ 81 | {'x': xaxis, 'y': [bp_py_time_s, bp_py_time_l, bp_py_time_h], 'label': 'bp_py', 'marker': 'x'}, 82 | {'x': xaxis, 'y': [bt_py_time_s, bt_py_time_l, bt_py_time_h], 'label': 'bt_py', 'marker': 'x'}, 83 | {'x': xaxis, 'y': [bi_py_time_s, bi_py_time_l, bi_py_time_h], 'label': 'bi_py', 'marker': 'x'}, 84 | ], 85 | options={ 86 | 'xscale': 'log', 87 | 'yscale': 'log', 88 | 'xlabel': 'size of tensor', 89 | 'ylabel': f'wall time on {device}', 90 | 'title': 'sequential access' 91 | } 92 | ) 93 | 94 | 95 | def random_access(): 96 | generator = torch.Generator().manual_seed(456789) 97 | ts = torch.empty(steps).uniform_(t0, t1, generator=generator) 98 | 99 | w0 = torch.zeros(small_batch_size, small_d).to(device) 100 | bp_py_time_s, bt_py_time_s, bi_py_time_s = _compare(w0, ts, msg='small random access') 101 | 102 | w0 = torch.zeros(large_batch_size, large_d).to(device) 103 | bp_py_time_l, bt_py_time_l, bi_py_time_l = _compare(w0, ts, msg='large random access') 104 | 105 | w0 = torch.zeros(huge_batch_size, huge_d).to(device) 106 | bp_py_time_h, bt_py_time_h, bi_py_time_h = _compare(w0, ts, msg="huge random access") 107 | 108 | img_path = os.path.join('.', 'benchmarks', f'plots-{device}', 'random_access.png') 109 | if not os.path.exists(os.path.dirname(img_path)): 110 | os.makedirs(os.path.dirname(img_path)) 111 | 112 | xaxis = [small_batch_size * small_d, large_batch_size * large_batch_size, huge_batch_size * huge_d] 113 | 114 | utils.swiss_knife_plotter( 115 | img_path, 116 | plots=[ 117 | {'x': xaxis, 'y': [bp_py_time_s, bp_py_time_l, bp_py_time_h], 'label': 'bp_py', 'marker': 'x'}, 118 | {'x': xaxis, 'y': [bt_py_time_s, bt_py_time_l, bt_py_time_h], 'label': 'bt_py', 'marker': 
'x'}, 119 | {'x': xaxis, 'y': [bi_py_time_s, bi_py_time_l, bi_py_time_h], 'label': 'bi_py', 'marker': 'x'}, 120 | ], 121 | options={ 122 | 'xscale': 'log', 123 | 'yscale': 'log', 124 | 'xlabel': 'size of tensor', 125 | 'ylabel': f'wall time on {device}', 126 | 'title': 'random access' 127 | } 128 | ) 129 | 130 | 131 | class SDE(torchsde.SDEIto): 132 | def __init__(self): 133 | super(SDE, self).__init__(noise_type="diagonal") 134 | 135 | def f(self, t, y): 136 | return y 137 | 138 | def g(self, t, y): 139 | return torch.exp(-y) 140 | 141 | 142 | def _time_sdeint(sde, y0, ts, bm): 143 | now = time.perf_counter() 144 | with torch.no_grad(): 145 | torchsde.sdeint(sde, y0, ts, bm, method='euler') 146 | return time.perf_counter() - now 147 | 148 | 149 | def _time_sdeint_bp(sde, y0, ts, bm): 150 | now = time.perf_counter() 151 | sde.zero_grad() 152 | y0 = y0.clone().requires_grad_(True) 153 | ys = torchsde.sdeint(sde, y0, ts, bm, method='euler') 154 | ys.sum().backward() 155 | return time.perf_counter() - now 156 | 157 | 158 | def _time_sdeint_adjoint(sde, y0, ts, bm): 159 | now = time.perf_counter() 160 | sde.zero_grad() 161 | y0 = y0.clone().requires_grad_(True) 162 | ys = torchsde.sdeint_adjoint(sde, y0, ts, bm, method='euler') 163 | ys.sum().backward() 164 | return time.perf_counter() - now 165 | 166 | 167 | def _compare_sdeint(w0, sde, y0, ts, func, msg=''): 168 | bm = torchsde.BrownianPath(t0=t0, w0=w0) 169 | bp_py_time = func(sde, y0, ts, bm) 170 | logging.warning(f'{msg} (torchsde.BrownianPath): {bp_py_time:.4f}') 171 | 172 | bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, tol=1e-5) 173 | bt_py_time = func(sde, y0, ts, bm) 174 | logging.warning(f'{msg} (torchsde.BrownianTree): {bt_py_time:.4f}') 175 | 176 | bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=w0.shape, dtype=w0.dtype, device=w0.device) 177 | bi_py_time = func(sde, y0, ts, bm) 178 | logging.warning(f'{msg} (torchsde.BrownianInterval): {bi_py_time:.4f}') 179 | 180 | return bp_py_time, bt_py_time, 
bi_py_time 181 | 182 | 183 | def solver_access(func=_time_sdeint): 184 | ts = torch.linspace(t0, t1, steps) 185 | sde = SDE().to(device) 186 | 187 | y0 = w0 = torch.zeros(small_batch_size, small_d).to(device) 188 | bp_py_time_s, bt_py_time_s, bi_py_time_s = _compare_sdeint(w0, sde, y0, ts, func, msg='small') 189 | 190 | y0 = w0 = torch.zeros(large_batch_size, large_d).to(device) 191 | bp_py_time_l, bt_py_time_l, bi_py_time_l = _compare_sdeint(w0, sde, y0, ts, func, msg='large') 192 | 193 | y0 = w0 = torch.zeros(huge_batch_size, huge_d).to(device) 194 | bp_py_time_h, bt_py_time_h, bi_py_time_h = _compare_sdeint(w0, sde, y0, ts, func, msg='huge') 195 | 196 | name = { 197 | _time_sdeint: 'sdeint', 198 | _time_sdeint_bp: 'sdeint-backprop-solver', 199 | _time_sdeint_adjoint: 'sdeint-backprop-adjoint' 200 | }[func] 201 | 202 | img_path = os.path.join('.', 'benchmarks', f'plots-{device}', f'{name}.png') 203 | if not os.path.exists(os.path.dirname(img_path)): 204 | os.makedirs(os.path.dirname(img_path)) 205 | 206 | xaxis = [small_batch_size * small_d, large_batch_size * large_batch_size, huge_batch_size * huge_d] 207 | 208 | utils.swiss_knife_plotter( 209 | img_path, 210 | plots=[ 211 | {'x': xaxis, 'y': [bp_py_time_s, bp_py_time_l, bp_py_time_h], 'label': 'bp_py', 'marker': 'x'}, 212 | {'x': xaxis, 'y': [bt_py_time_s, bt_py_time_l, bt_py_time_h], 'label': 'bt_py', 'marker': 'x'}, 213 | {'x': xaxis, 'y': [bi_py_time_s, bi_py_time_l, bi_py_time_h], 'label': 'bi_py', 'marker': 'x'}, 214 | ], 215 | options={ 216 | 'xscale': 'log', 217 | 'yscale': 'log', 218 | 'xlabel': 'size of tensor', 219 | 'ylabel': f'wall time on {device}', 220 | 'title': name 221 | } 222 | ) 223 | 224 | 225 | def main(): 226 | sequential_access() 227 | random_access() 228 | 229 | solver_access(func=_time_sdeint) 230 | solver_access(func=_time_sdeint_bp) 231 | solver_access(func=_time_sdeint_adjoint) 232 | 233 | 234 | if __name__ == "__main__": 235 | parser = argparse.ArgumentParser() 236 | 
parser.add_argument('--no-gpu', action='store_true') 237 | parser.add_argument('--debug', action='store_true') 238 | parser.add_argument('--seed', type=int, default=0) 239 | 240 | args = parser.parse_args() 241 | device = torch.device('cuda' if torch.cuda.is_available() and not args.no_gpu else 'cpu') 242 | 243 | npr.seed(args.seed) 244 | torch.manual_seed(args.seed) 245 | 246 | if args.debug: 247 | logging.getLogger().setLevel(logging.INFO) 248 | 249 | main() 250 | -------------------------------------------------------------------------------- /benchmarks/profile_btree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import logging 16 | import os 17 | import time 18 | 19 | import matplotlib.pyplot as plt 20 | import torch 21 | import tqdm 22 | 23 | from torchsde import BrownianTree 24 | 25 | 26 | def run_torch(ks=(0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12)): 27 | w0 = torch.zeros(b, d) 28 | 29 | t_cons = [] 30 | t_queries = [] 31 | t_alls = [] 32 | for k in tqdm.tqdm(ks): 33 | now = time.time() 34 | bm_vanilla = BrownianTree(t0=t0, t1=t1, w0=w0, cache_depth=k) 35 | t_con = time.time() - now 36 | t_cons.append(t_con) 37 | 38 | now = time.time() 39 | for t in ts: 40 | bm_vanilla(t).to(device) 41 | t_query = time.time() - now 42 | t_queries.append(t_query) 43 | 44 | t_all = t_con + t_query 45 | t_alls.append(t_all) 46 | logging.warning(f'k={k}, t_con={t_con:.4f}, t_query={t_query:.4f}, t_all={t_all:.4f}') 47 | 48 | img_path = os.path.join('.', 'diagnostics', 'plots', 'profile_btree.png') 49 | plt.figure() 50 | plt.plot(ks, t_cons, label='cons') 51 | plt.plot(ks, t_queries, label='queries') 52 | plt.plot(ks, t_alls, label='all') 53 | plt.title(f'b={b}, d={d}, repetitions={reps}, device={w0.device}') 54 | plt.xlabel('Cache level') 55 | plt.ylabel('Time (secs)') 56 | plt.legend() 57 | plt.savefig(img_path) 58 | plt.close() 59 | 60 | 61 | def main(): 62 | run_torch() 63 | 64 | 65 | if __name__ == "__main__": 66 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 67 | torch.manual_seed(1147481649) 68 | 69 | reps = 500 70 | b, d = 512, 10 71 | 72 | t0, t1 = 0., 1. 73 | ts = torch.rand(size=(reps,)).numpy() 74 | 75 | main() 76 | -------------------------------------------------------------------------------- /diagnostics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /diagnostics/inspection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import sys 17 | 18 | import torch 19 | import tqdm 20 | 21 | from torchsde import BaseBrownian, BaseSDE, sdeint 22 | from torchsde.settings import SDE_TYPES 23 | from torchsde.types import Tensor, Vector, Scalar, Tuple, Optional, Callable 24 | from . 
import utils 25 | 26 | 27 | sys.setrecursionlimit(5000) 28 | 29 | 30 | @torch.no_grad() 31 | def inspect_samples(y0: Tensor, 32 | ts: Vector, 33 | dt: Scalar, 34 | sde: BaseSDE, 35 | bm: BaseBrownian, 36 | img_dir: str, 37 | methods: Tuple[str, ...], 38 | options: Optional[Tuple] = None, 39 | labels: Optional[Tuple[str, ...]] = None, 40 | vis_dim=0, 41 | dt_true: Optional[float] = 2 ** -14): 42 | if options is None: 43 | options = (None,) * len(methods) 44 | if labels is None: 45 | labels = methods 46 | 47 | solns = [ 48 | sdeint(sde, y0, ts, bm, method=method, dt=dt, options=options_) 49 | for method, options_ in zip(methods, options) 50 | ] 51 | 52 | method_for_true = 'euler' if sde.sde_type == SDE_TYPES.ito else 'midpoint' 53 | true = sdeint(sde, y0, ts, bm, method=method_for_true, dt=dt_true) 54 | 55 | labels += ('true',) 56 | solns += [true] 57 | 58 | # (T, batch_size, d) -> (T, batch_size) -> (batch_size, T). 59 | solns = [soln[..., vis_dim].t() for soln in solns] 60 | 61 | for i, samples in enumerate(zip(*solns)): 62 | utils.swiss_knife_plotter( 63 | img_path=os.path.join(img_dir, f'{i}'), 64 | plots=[ 65 | {'x': ts, 'y': sample, 'label': label, 'marker': 'x'} 66 | for sample, label in zip(samples, labels) 67 | ] 68 | ) 69 | 70 | 71 | @torch.no_grad() 72 | def inspect_orders(y0: Tensor, 73 | t0: Scalar, 74 | t1: Scalar, 75 | dts: Vector, 76 | sde: BaseSDE, 77 | bm: BaseBrownian, 78 | img_dir: str, 79 | methods: Tuple[str, ...], 80 | options: Optional[Tuple] = None, 81 | labels: Optional[Tuple[str, ...]] = None, 82 | dt_true: Optional[float] = 2 ** -14, 83 | test_func: Optional[Callable] = lambda x: (x ** 2).flatten(start_dim=1).sum(dim=1)): 84 | if options is None: 85 | options = (None,) * len(methods) 86 | if labels is None: 87 | labels = methods 88 | 89 | ts = torch.tensor([t0, t1], device=y0.device) 90 | 91 | solns = [ 92 | [ 93 | sdeint(sde, y0, ts, bm, method=method, dt=dt, options=options_)[-1] 94 | for method, options_ in zip(methods, options) 95 | ] 
96 | for dt in tqdm.tqdm(dts) 97 | ] 98 | 99 | if hasattr(sde, 'analytical_sample'): 100 | true = sde.analytical_sample(y0, ts, bm)[-1] 101 | else: 102 | method_for_true = 'euler' if sde.sde_type == SDE_TYPES.ito else 'midpoint' 103 | true = sdeint(sde, y0, ts, bm, method=method_for_true, dt=dt_true)[-1] 104 | 105 | mses = [] 106 | maes = [] 107 | for dt, solns_ in zip(dts, solns): 108 | mses_for_dt = [utils.mse(soln, true) for soln in solns_] 109 | mses.append(mses_for_dt) 110 | 111 | maes_for_dt = [utils.mae(soln, true, test_func) for soln in solns_] 112 | maes.append(maes_for_dt) 113 | 114 | strong_order_slopes = [ 115 | utils.linregress_slope(utils.log(dts), .5 * utils.log(mses_for_method)) 116 | for mses_for_method in zip(*mses) 117 | ] 118 | 119 | weak_order_slopes = [ 120 | utils.linregress_slope(utils.log(dts), utils.log(maes_for_method)) 121 | for maes_for_method in zip(*maes) 122 | ] 123 | 124 | utils.swiss_knife_plotter( 125 | img_path=os.path.join(img_dir, 'strong_order'), 126 | plots=[ 127 | {'x': dts, 'y': mses_for_method, 'label': f'{label}(k={slope:.4f})', 'marker': 'x'} 128 | for mses_for_method, label, slope in zip(zip(*mses), labels, strong_order_slopes) 129 | ], 130 | options={'xscale': 'log', 'yscale': 'log', 'cycle_line_style': True} 131 | ) 132 | 133 | utils.swiss_knife_plotter( 134 | img_path=os.path.join(img_dir, 'weak_order'), 135 | plots=[ 136 | {'x': dts, 'y': mres_for_method, 'label': f'{label}(k={slope:.4f})', 'marker': 'x'} 137 | for mres_for_method, label, slope in zip(zip(*maes), labels, weak_order_slopes) 138 | ], 139 | options={'xscale': 'log', 'yscale': 'log', 'cycle_line_style': True} 140 | ) 141 | -------------------------------------------------------------------------------- /diagnostics/ito_additive.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file 
except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | 19 | from tests.problems import NeuralAdditive 20 | from torchsde import BrownianInterval 21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS 22 | from . import inspection 23 | from . import utils 24 | 25 | 26 | def main(): 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | torch.set_default_dtype(torch.float64) 29 | utils.manual_seed() 30 | 31 | small_batch_size, large_batch_size, d, m = 16, 16384, 3, 5 32 | t0, t1, steps, dt = 0., 2., 10, 1e-1 33 | ts = torch.linspace(t0, t1, steps=steps, device=device) 34 | dts = tuple(2 ** -i for i in range(1, 7)) # For checking strong order. 
35 | sde = NeuralAdditive(d=d, m=m).to(device) 36 | methods = ('euler', 'milstein', 'milstein', 'srk') 37 | options = (None, None, dict(grad_free=True), None) 38 | labels = ('euler', 'milstein', 'gradient-free milstein', 'srk') 39 | img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'ito_additive') 40 | 41 | y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device) 42 | bm = BrownianInterval( 43 | t0=t0, t1=t1, size=(small_batch_size, m), dtype=y0.dtype, device=device, 44 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time 45 | ) 46 | inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels) 47 | 48 | y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device) 49 | bm = BrownianInterval( 50 | t0=t0, t1=t1, size=(large_batch_size, m), dtype=y0.dtype, device=device, 51 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time 52 | ) 53 | inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels) 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /diagnostics/ito_diagonal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | 17 | import torch 18 | 19 | from tests.problems import NeuralDiagonal 20 | from torchsde import BrownianInterval 21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS 22 | from . import inspection 23 | from . import utils 24 | 25 | 26 | def main(): 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | torch.set_default_dtype(torch.float64) 29 | utils.manual_seed() 30 | 31 | small_batch_size, large_batch_size, d = 16, 16384, 5 32 | t0, t1, steps, dt = 0., 2., 10, 1e-1 33 | ts = torch.linspace(t0, t1, steps=steps, device=device) 34 | dts = tuple(2 ** -i for i in range(1, 7)) # For checking strong order. 35 | sde = NeuralDiagonal(d=d).to(device) 36 | methods = ('euler', 'milstein', 'milstein', 'srk') 37 | options = (None, None, dict(grad_free=True), None) 38 | labels = ('euler', 'milstein', 'gradient-free milstein', 'srk') 39 | img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'ito_diagonal') 40 | 41 | y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device) 42 | bm = BrownianInterval( 43 | t0=t0, t1=t1, size=(small_batch_size, d), dtype=y0.dtype, device=device, 44 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time 45 | ) 46 | inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels) 47 | 48 | y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device) 49 | bm = BrownianInterval( 50 | t0=t0, t1=t1, size=(large_batch_size, d), dtype=y0.dtype, device=device, 51 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time 52 | ) 53 | inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels) 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /diagnostics/ito_general.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, 
Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | 19 | from tests.problems import NeuralGeneral 20 | from torchsde import BrownianInterval 21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS 22 | from . import inspection 23 | from . import utils 24 | 25 | 26 | def main(): 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | torch.set_default_dtype(torch.float64) 29 | utils.manual_seed() 30 | 31 | small_batch_size, large_batch_size, d, m = 16, 16384, 3, 5 32 | t0, t1, steps, dt = 0., 2., 10, 1e-1 33 | ts = torch.linspace(t0, t1, steps=steps, device=device) 34 | dts = tuple(2 ** -i for i in range(1, 7)) # For checking strong order. 
35 | sde = NeuralGeneral(d=d, m=m).to(device) 36 | methods = ('euler',) 37 | options = (None,) 38 | labels = ('euler',) 39 | img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'ito_general') 40 | 41 | y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device) 42 | bm = BrownianInterval( 43 | t0=t0, t1=t1, size=(small_batch_size, m), dtype=y0.dtype, device=device, 44 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster 45 | ) 46 | inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels) 47 | 48 | y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device) 49 | bm = BrownianInterval( 50 | t0=t0, t1=t1, size=(large_batch_size, m), dtype=y0.dtype, device=device, 51 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster 52 | ) 53 | inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels) 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /diagnostics/ito_scalar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | 19 | from tests.problems import NeuralScalar 20 | from torchsde import BrownianInterval 21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS 22 | from . import inspection 23 | from . 
import utils 24 | 25 | 26 | def main(): 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | torch.set_default_dtype(torch.float64) 29 | utils.manual_seed() 30 | 31 | small_batch_size, large_batch_size, d = 16, 16384, 5 32 | t0, t1, steps, dt = 0., 2., 10, 1e-1 33 | ts = torch.linspace(t0, t1, steps=steps, device=device) 34 | dts = tuple(2 ** -i for i in range(1, 7)) # For checking strong order. 35 | sde = NeuralScalar(d=d).to(device) 36 | methods = ('euler', 'milstein', 'milstein', 'srk') 37 | options = (None, None, dict(grad_free=True), None) 38 | labels = ('euler', 'milstein', 'gradient-free milstein', 'srk') 39 | img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'ito_scalar') 40 | 41 | y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device) 42 | bm = BrownianInterval( 43 | t0=t0, t1=t1, size=(small_batch_size, 1), dtype=y0.dtype, device=device, 44 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time 45 | ) 46 | inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels) 47 | 48 | y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device) 49 | bm = BrownianInterval( 50 | t0=t0, t1=t1, size=(large_batch_size, 1), dtype=y0.dtype, device=device, 51 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time 52 | ) 53 | inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels) 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /diagnostics/run_all.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import ito_additive, ito_diagonal, ito_general, ito_scalar 16 | from . import stratonovich_additive, stratonovich_diagonal, stratonovich_general, stratonovich_scalar 17 | 18 | if __name__ == '__main__': 19 | for module in (ito_additive, ito_diagonal, ito_general, ito_scalar, stratonovich_additive, stratonovich_diagonal, 20 | stratonovich_general, stratonovich_scalar): 21 | module.main() 22 | -------------------------------------------------------------------------------- /diagnostics/stratonovich_additive.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | 19 | from tests.problems import NeuralAdditive 20 | from torchsde import BrownianInterval 21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS, SDE_TYPES 22 | from . import inspection 23 | from . 
import utils 24 | 25 | 26 | def main(): 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | torch.set_default_dtype(torch.float64) 29 | utils.manual_seed() 30 | 31 | small_batch_size, large_batch_size, d, m = 16, 16384, 3, 5 32 | t0, t1, steps, dt = 0., 2., 10, 1e-1 33 | ts = torch.linspace(t0, t1, steps=steps, device=device) 34 | dts = tuple(2 ** -i for i in range(1, 7)) # For checking strong order. 35 | sde = NeuralAdditive(d=d, m=m, sde_type=SDE_TYPES.stratonovich).to(device) 36 | # Milstein and log-ODE should in theory both superfluously compute zeros here. 37 | # We include them anyway to check that they do what they claim to do. 38 | methods = ('euler_heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'milstein', 'log_ode') 39 | options = (None, None, None, None, None, dict(grad_free=True), None) 40 | labels = ('euler-heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'grad-free milstein', 'log_ode') 41 | img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'stratonovich_additive') 42 | 43 | y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device) 44 | bm = BrownianInterval( 45 | t0=t0, t1=t1, size=(small_batch_size, m), dtype=y0.dtype, device=device, 46 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster 47 | ) 48 | inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels) 49 | 50 | y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device) 51 | bm = BrownianInterval( 52 | t0=t0, t1=t1, size=(large_batch_size, m), dtype=y0.dtype, device=device, 53 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster 54 | ) 55 | inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /diagnostics/stratonovich_diagonal.py: -------------------------------------------------------------------------------- 
1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | 19 | from tests.problems import NeuralDiagonal 20 | from torchsde import BrownianInterval 21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS, SDE_TYPES 22 | from . import inspection 23 | from . import utils 24 | 25 | 26 | def main(): 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | torch.set_default_dtype(torch.float64) 29 | utils.manual_seed() 30 | 31 | small_batch_size, large_batch_size, d = 16, 16384, 3 32 | t0, t1, steps, dt = 0., 2., 10, 1e-1 33 | ts = torch.linspace(t0, t1, steps=steps, device=device) 34 | dts = tuple(2 ** -i for i in range(1, 7)) # For checking strong order. 
35 | sde = NeuralDiagonal(d=d, sde_type=SDE_TYPES.stratonovich).to(device) 36 | methods = ('euler_heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'milstein', 'log_ode') 37 | options = (None, None, None, None, None, dict(grad_free=True), None) 38 | labels = ('euler-heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'grad-free milstein', 'log_ode') 39 | img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'stratonovich_diagonal') 40 | 41 | y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device) 42 | bm = BrownianInterval( 43 | t0=t0, t1=t1, size=(small_batch_size, d), dtype=y0.dtype, device=device, 44 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster 45 | ) 46 | inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels) 47 | 48 | y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device) 49 | bm = BrownianInterval( 50 | t0=t0, t1=t1, size=(large_batch_size, d), dtype=y0.dtype, device=device, 51 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster 52 | ) 53 | inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels) 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /diagnostics/stratonovich_general.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from tests.problems import NeuralGeneral
from torchsde import BrownianInterval
from torchsde.settings import LEVY_AREA_APPROXIMATIONS, SDE_TYPES
from . import inspection
from . import utils


def main():
    """Sample-path and convergence-order diagnostics for the Stratonovich general-noise problem."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.set_default_dtype(torch.float64)
    utils.manual_seed()

    batch_small, batch_large, d, m = 16, 16384, 3, 5
    t0, t1, steps, dt = 0., 2., 10, 1e-1
    ts = torch.linspace(t0, t1, steps=steps, device=device)
    dts = tuple(2 ** -i for i in range(1, 7))  # Step sizes for the strong-order check.
    sde = NeuralGeneral(d=d, m=m, sde_type=SDE_TYPES.stratonovich).to(device)
    # Milstein is omitted: it doesn't work for general noise.
    methods = ('euler_heun', 'heun', 'midpoint', 'reversible_heun', 'log_ode')
    options = (None, None, None, None, None)
    labels = ('euler-heun', 'heun', 'midpoint', 'reversible_heun', 'log_ode')
    img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'stratonovich_general')

    def make_problem(batch_size):
        # General noise requires full Levy-area approximation (foster).
        y0 = torch.full((batch_size, d), fill_value=0.1, device=device)
        bm = BrownianInterval(
            t0=t0, t1=t1, size=(batch_size, m), dtype=y0.dtype, device=device,
            levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
        )
        return y0, bm

    y0, bm = make_problem(batch_small)
    inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels)

    y0, bm = make_problem(batch_large)
    inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels)


if __name__ == '__main__':
    main()

-------------------------------------------------------------------------------- /diagnostics/stratonovich_scalar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | 19 | from tests.problems import NeuralScalar 20 | from torchsde import BrownianInterval 21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS, SDE_TYPES 22 | from . import inspection 23 | from . import utils 24 | 25 | 26 | def main(): 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | torch.set_default_dtype(torch.float64) 29 | utils.manual_seed() 30 | 31 | small_batch_size, large_batch_size, d = 16, 16384, 3 32 | t0, t1, steps, dt = 0., 2., 10, 1e-1 33 | ts = torch.linspace(t0, t1, steps=steps, device=device) 34 | dts = tuple(2 ** -i for i in range(1, 7)) # For checking strong order. 
35 | sde = NeuralScalar(d=d, sde_type=SDE_TYPES.stratonovich).to(device) 36 | methods = ('euler_heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'milstein', 'log_ode') 37 | options = (None, None, None, None, None, dict(grad_free=True), None) 38 | labels = ('euler-heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'grad-free milstein', 'log_ode') 39 | img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'stratonovich_scalar') 40 | 41 | y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device) 42 | bm = BrownianInterval( 43 | t0=t0, t1=t1, size=(small_batch_size, 1), dtype=y0.dtype, device=device, 44 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster 45 | ) 46 | inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels) 47 | 48 | y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device) 49 | bm = BrownianInterval( 50 | t0=t0, t1=t1, size=(large_batch_size, 1), dtype=y0.dtype, device=device, 51 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster 52 | ) 53 | inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels) 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /diagnostics/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import itertools 16 | import os 17 | import random 18 | 19 | import matplotlib.pyplot as plt 20 | import numpy as np 21 | import torch 22 | from scipy import stats 23 | 24 | from torchsde.types import Optional, Tensor, Sequence, Union, Callable 25 | 26 | 27 | def to_numpy(*args): 28 | """Convert a sequence which might contain Tensors to numpy arrays.""" 29 | if len(args) == 1: 30 | arg = args[0] 31 | if isinstance(arg, torch.Tensor): 32 | arg = _to_numpy_single(arg) 33 | return arg 34 | else: 35 | return tuple(_to_numpy_single(arg) if isinstance(arg, torch.Tensor) else arg for arg in args) 36 | 37 | 38 | def _to_numpy_single(arg: torch.Tensor) -> np.ndarray: 39 | return arg.detach().cpu().numpy() 40 | 41 | 42 | def mse(x: Tensor, y: Tensor, norm_dim: Optional[int] = 1, mean_dim: Optional[int] = 0) -> np.ndarray: 43 | """Compute mean squared error.""" 44 | return _to_numpy_single((torch.norm(x - y, dim=norm_dim) ** 2).mean(dim=mean_dim)) 45 | 46 | 47 | def mae(x: Tensor, y: Tensor, test_func: Callable, mean_dim: Optional[int] = 0) -> np.ndarray: 48 | return _to_numpy_single( 49 | abs(test_func(x).mean(mean_dim) - test_func(y).mean(mean_dim)) 50 | ) 51 | 52 | 53 | def log(x: Union[Sequence[float], np.ndarray]) -> np.ndarray: 54 | """Compute element-wise log of a sequence of floats.""" 55 | return np.log(np.array(x)) 56 | 57 | 58 | def linregress_slope(x, y): 59 | """Return the slope of a least-squares regression for two sets of measurements.""" 60 | return stats.linregress(x, y)[0] 61 | 62 | 63 | def swiss_knife_plotter(img_path, plots=None, scatters=None, hists=None, options=None): 64 | """A multi-functional *standalone* wrapper; reduces boilerplate. 65 | 66 | Args: 67 | img_path (str): A path to the place where the image should be written. 68 | plots (list of dict, optional): A list of curves that needs `plt.plot`. 
69 | scatters (list of dict, optional): A list of scatter plots that needs `plt.scatter`. 70 | hists (list of histograms, optional): A list of histograms that needs `plt.hist`. 71 | options (dict, optional): A dictionary of optional arguments. Possible entries include 72 | - xscale (str): Scale of xaxis. 73 | - yscale (str): Scale of yaxis. 74 | - xlabel (str): Label of xaxis. 75 | - ylabel (str): Label of yaxis. 76 | - title (str): Title of the plot. 77 | - cycle_linestyle (bool): Cycle through matplotlib's possible line styles if True. 78 | 79 | Returns: 80 | Nothing. 81 | """ 82 | img_dir = os.path.dirname(img_path) 83 | if not os.path.exists(img_dir): 84 | os.makedirs(img_dir) 85 | 86 | if plots is None: plots = () 87 | if scatters is None: scatters = () 88 | if hists is None: hists = () 89 | if options is None: options = {} 90 | 91 | plt.figure(dpi=300) 92 | if 'xscale' in options: plt.xscale(options['xscale']) 93 | if 'yscale' in options: plt.yscale(options['yscale']) 94 | if 'xlabel' in options: plt.xlabel(options['xlabel']) 95 | if 'ylabel' in options: plt.ylabel(options['ylabel']) 96 | if 'title' in options: plt.title(options['title']) 97 | 98 | cycle_linestyle = options.get('cycle_linestyle', False) 99 | cycler = itertools.cycle(["-", "--", "-.", ":"]) if cycle_linestyle else None 100 | for entry in plots: 101 | kwargs = {key: entry[key] for key in entry if key != 'x' and key != 'y'} 102 | entry['x'], entry['y'] = to_numpy(entry['x'], entry['y']) 103 | if cycle_linestyle: 104 | kwargs['linestyle'] = next(cycler) 105 | plt.plot(entry['x'], entry['y'], **kwargs) 106 | 107 | for entry in scatters: 108 | kwargs = {key: entry[key] for key in entry if key != 'x' and key != 'y'} 109 | entry['x'], entry['y'] = to_numpy(entry['x'], entry['y']) 110 | plt.scatter(entry['x'], entry['y'], **kwargs) 111 | 112 | for entry in hists: 113 | kwargs = {key: entry[key] for key in entry if key != 'x'} 114 | entry['x'] = to_numpy(entry['x']) 115 | plt.hist(entry['x'], **kwargs) 
116 | 117 | if len(plots) > 0 or len(scatters) > 0: plt.legend() 118 | plt.tight_layout() 119 | plt.savefig(img_path) 120 | plt.close() 121 | 122 | 123 | def manual_seed(seed: Optional[int] = 1147481649): 124 | """Set seeds for default generators of 1) torch, 2) numpy, and 3) Python's random library.""" 125 | torch.manual_seed(seed) 126 | np.random.seed(seed) 127 | random.seed(seed) 128 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/torchsde/eb3a00e31cbd56176270066ed2f62c394cf6acb7/examples/__init__.py -------------------------------------------------------------------------------- /examples/cont_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A min example for continuous-time Denoising Diffusion Probabilistic Models. 16 | 17 | Trains the backward dynamics to be close to the reverse of a fixed forward 18 | dynamics via a score-matching-type objective. 19 | 20 | Trains a simple model on MNIST and samples from both the reverse ODE and 21 | SDE formulation. 
def fill_tail_dims(y: torch.Tensor, y_like: torch.Tensor):
    """Append singleton trailing dimensions to `y` until it matches `y_like`'s rank.

    Lets a per-batch scalar (e.g. a time-dependent coefficient) broadcast
    against a batched image tensor.
    """
    index = (...,) + (None,) * (y_like.dim() - y.dim())
    return y[index]
    def __init__(self, denoiser, input_size=(1, 28, 28), t0=0., t1=1., beta_min=.1, beta_max=20.):
        """Construct the VP-SDE wrapper.

        Args:
            denoiser: Score network with signature `denoiser(t, y)`.
            input_size: Per-sample shape, used when sampling the t1 marginal.
            t0, t1: Time horizon; requires t0 <= t1.
            beta_min, beta_max: Endpoints of the linear noise schedule beta(t).
        """
        super(ScoreMatchingSDE, self).__init__()
        if t0 > t1:
            raise ValueError(f"Expected t0 <= t1, but found t0={t0:.4f}, t1={t1:.4f}")

        self.input_size = input_size
        self.denoiser = denoiser

        self.t0 = t0
        self.t1 = t1

        self.beta_min = beta_min
        self.beta_max = beta_max

    def score(self, t, y):
        """Evaluate the score network, broadcasting a scalar t over the batch."""
        if isinstance(t, float):
            t = y.new_tensor(t)
        if t.dim() == 0:
            # Scalar time: replicate so each batch element gets its own t.
            t = t.repeat(y.shape[0])
        return self.denoiser(t, y)

    def _beta(self, t):
        # Linear schedule: beta(t) = beta_min + t * (beta_max - beta_min).
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def _indefinite_int(self, t):
        """Indefinite integral of beta(t)."""
        return self.beta_min * t + .5 * t ** 2 * (self.beta_max - self.beta_min)

    def analytical_mean(self, t, x_t0):
        """Closed-form mean of x_t given x_t0: x_t0 * exp(-0.5 * int_{t0}^{t} beta)."""
        mean_coeff = (-.5 * (self._indefinite_int(t) - self._indefinite_int(self.t0))).exp()
        mean = x_t0 * fill_tail_dims(mean_coeff, x_t0)
        return mean

    def analytical_var(self, t, x_t0):
        """Closed-form variance of x_t given x_t0: 1 - exp(-int_{t0}^{t} beta)."""
        analytical_var = 1 - (-self._indefinite_int(t) + self._indefinite_int(self.t0)).exp()
        return analytical_var

    @torch.no_grad()
    def analytical_sample(self, t, x_t0):
        """Sample x_t ~ N(analytical_mean, analytical_var) without simulating the SDE."""
        mean = self.analytical_mean(t, x_t0)
        var = self.analytical_var(t, x_t0)
        return mean + torch.randn_like(mean) * fill_tail_dims(var.sqrt(), mean)

    @torch.no_grad()
    def analytical_score(self, x_t, t, x_t0):
        """Score of the Gaussian transition kernel: -(x_t - mean) / var (var clamped for stability)."""
        mean = self.analytical_mean(t, x_t0)
        var = self.analytical_var(t, x_t0)
        return - (x_t - mean) / fill_tail_dims(var, mean).clamp_min(1e-5)

    def f(self, t, y):
        # VP-SDE drift: -0.5 * beta(t) * y.
        return -0.5 * self._beta(t) * y

    def g(self, t, y):
        # VP-SDE diffusion: sqrt(beta(t)), broadcast to y's shape (diagonal noise).
        return fill_tail_dims(self._beta(t).sqrt(), y).expand_as(y)

    def sample_t1_marginal(self, batch_size, tau=1.):
        """Sample from the (approximate) t1 marginal N(0, tau * I)."""
        return torch.randn(size=(batch_size, *self.input_size), device=self.device) * math.sqrt(tau)

    def lambda_t(self, t):
        # Loss weighting; equals the transition variance at time t.
        return self.analytical_var(t, None)

    def forward(self, x_t0, partitions=1):
        """Compute the score matching objective.
        Split [t0, t1] into partitions; sample uniformly on each partition to reduce gradient variance.
        """
        # u: a uniform draw inside each of the `partitions` sub-intervals.
        u = torch.rand(size=(x_t0.shape[0], partitions), dtype=x_t0.dtype, device=x_t0.device)
        u.mul_((self.t1 - self.t0) / partitions)
        # shifts: left endpoints of the sub-intervals, offset by t0.
        shifts = torch.arange(0, partitions, device=x_t0.device, dtype=x_t0.dtype)[None, :]
        shifts.mul_((self.t1 - self.t0) / partitions).add_(self.t0)
        t = (u + shifts).reshape(-1)
        lambda_t = self.lambda_t(t)

        # Replicate each datum once per partition so batch and time line up.
        x_t0 = x_t0.repeat_interleave(partitions, dim=0)
        x_t = self.analytical_sample(t, x_t0)

        fake_score = self.score(t, x_t)
        true_score = self.analytical_score(x_t, t, x_t0)
        # Weighted squared error summed over non-batch dims; one loss per sample.
        loss = (lambda_t * ((fake_score - true_score) ** 2).flatten(start_dim=1).sum(dim=1))
        return loss
164 | """ 165 | noise_type = "diagonal" 166 | sde_type = "stratonovich" 167 | 168 | def __init__(self, module: ScoreMatchingSDE): 169 | super(ReverseDiffeqWrapper, self).__init__() 170 | self.module = module 171 | 172 | # --- odeint --- 173 | def forward(self, t, y): 174 | return -(self.module.f(-t, y) - .5 * self.module.g(-t, y) ** 2 * self.module.score(-t, y)) 175 | 176 | # --- sdeint --- 177 | def f(self, t, y): 178 | y = y.view(-1, *self.module.input_size) 179 | out = -(self.module.f(-t, y) - self.module.g(-t, y) ** 2 * self.module.score(-t, y)) 180 | return out.flatten(start_dim=1) 181 | 182 | def g(self, t, y): 183 | y = y.view(-1, *self.module.input_size) 184 | out = -self.module.g(-t, y) 185 | return out.flatten(start_dim=1) 186 | 187 | # --- sample --- 188 | def sample_t1_marginal(self, batch_size, tau=1.): 189 | return self.module.sample_t1_marginal(batch_size, tau) 190 | 191 | @torch.no_grad() 192 | def ode_sample(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2): 193 | self.module.eval() 194 | 195 | t = torch.tensor([-self.t1, -self.t0], device=self.device) if t is None else t 196 | y = self.sample_t1_marginal(batch_size, tau) if y is None else y 197 | return torchdiffeq.odeint(self, y, t, method="rk4", options={"step_size": dt}) 198 | 199 | @torch.no_grad() 200 | def ode_sample_final(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2): 201 | return self.ode_sample(batch_size, tau, t, y, dt)[-1] 202 | 203 | @torch.no_grad() 204 | def sde_sample(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2, tweedie_correction=True): 205 | self.module.eval() 206 | 207 | t = torch.tensor([-self.t1, -self.t0], device=self.device) if t is None else t 208 | y = self.sample_t1_marginal(batch_size, tau) if y is None else y 209 | 210 | ys = torchsde.sdeint(self, y.flatten(start_dim=1), t, dt=dt) 211 | ys = ys.view(len(t), *y.size()) 212 | if tweedie_correction: 213 | ys[-1] = self.tweedie_correction(self.t0, ys[-1], dt) 214 | return ys 215 | 216 | @torch.no_grad() 
def preprocess(x, logit_transform, alpha=0.95):
    """Map images from [0, 1] to the model's input space.

    With `logit_transform`, affinely squeezes x away from {0, 1} (controlled by
    `alpha`) and applies the logit; otherwise linearly rescales to [-1, 1].
    Inverse of `postprocess`.
    """
    if not logit_transform:
        return (x - 0.5) * 2
    squeezed = alpha + (1 - 2 * alpha) * x
    return (squeezed / (1 - squeezed)).log()
289 | 290 | Args: 291 | train_dir: Path to a folder to dump things. 292 | epochs: Number of training epochs. 293 | lr: Learning rate for Adam. 294 | batch_size: Batch size for training. 295 | pause_every: Log and write figures once in this many iterations. 296 | tau: The temperature for sampling. 297 | logit_transform: Applies the typical logit transformation if True. 298 | """ 299 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 300 | 301 | # Data. 302 | train_loader = make_loader(root=os.path.join(train_dir, 'data'), train_batch_size=batch_size) 303 | 304 | # Model + optimizer. 305 | denoiser = unet.Unet( 306 | input_size=(1, 28, 28), 307 | dim_mults=(1, 2, 4,), 308 | attention_cls=unet.LinearTimeSelfAttention, 309 | ) 310 | forward = ScoreMatchingSDE(denoiser=denoiser).to(device) 311 | reverse = ReverseDiffeqWrapper(forward) 312 | optimizer = optim.Adam(params=forward.parameters(), lr=lr) 313 | 314 | def plot(imgs, path): 315 | assert not torch.any(torch.isnan(imgs)), "Found nans in images" 316 | os.makedirs(os.path.dirname(path), exist_ok=True) 317 | imgs = postprocess(imgs, logit_transform=logit_transform).detach().cpu() 318 | tv.utils.save_image(imgs, path) 319 | 320 | global_step = 0 321 | for epoch in range(epochs): 322 | for x, _ in tqdm.tqdm(train_loader): 323 | forward.train() 324 | forward.zero_grad() 325 | x = preprocess(x.to(device), logit_transform=logit_transform) 326 | loss = forward(x).mean(dim=0) 327 | loss.backward() 328 | optimizer.step() 329 | global_step += 1 330 | 331 | if global_step % pause_every == 0: 332 | logging.warning(f'global_step: {global_step:06d}, loss: {loss:.4f}') 333 | 334 | img_path = os.path.join(train_dir, 'ode_samples', f'global_step_{global_step:07d}.png') 335 | ode_samples = reverse.ode_sample_final(tau=tau) 336 | plot(ode_samples, img_path) 337 | 338 | img_path = os.path.join(train_dir, 'sde_samples', f'global_step_{global_step:07d}.png') 339 | sde_samples = reverse.sde_sample_final(tau=tau) 340 | 
class LinearScheduler(object):
    """Linearly anneal a value from `maxval / iters` up to `maxval`.

    Each `step()` raises the value by one increment of `maxval / iters`,
    capping at `maxval`. Used to anneal the KL weight during training.
    """

    def __init__(self, iters, maxval=1.0):
        self._num_steps = max(1, iters)  # Guard against iters <= 0.
        self._increment = maxval / self._num_steps
        self._maxval = maxval
        self._val = self._increment

    def step(self):
        self._val = min(self._maxval, self._val + self._increment)

    @property
    def val(self):
        return self._val
class StochasticLorenz(object):
    """Stochastic Lorenz attractor.

    Used for simulating ground truth and obtaining noisy data.
    Details described in Section 7.2 https://arxiv.org/pdf/2001.01328.pdf
    Default a, b from https://openreview.net/pdf?id=HkzRQhR9YX
    """
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, a: Sequence = (10., 28., 8 / 3), b: Sequence = (.1, .28, .3)):
        super(StochasticLorenz, self).__init__()
        self.a = a  # Drift coefficients (sigma, rho, beta in Lorenz notation).
        self.b = b  # Per-coordinate diffusion scales.

    def f(self, t, y):
        """Lorenz drift, computed per coordinate on a (batch, 3) state."""
        sigma, rho, beta = self.a
        x1, x2, x3 = torch.split(y, split_size_or_sections=(1, 1, 1), dim=1)
        return torch.cat(
            [sigma * (x2 - x1), rho * x1 - x2 - x1 * x3, x1 * x2 - beta * x3],
            dim=1,
        )

    def g(self, t, y):
        """State-proportional diagonal diffusion."""
        components = torch.split(y, split_size_or_sections=(1, 1, 1), dim=1)
        return torch.cat([scale * x for scale, x in zip(self.b, components)], dim=1)

    @torch.no_grad()
    def sample(self, x0, ts, noise_std, normalize):
        """Sample data for training. Store data normalization constants if necessary."""
        paths = torchsde.sdeint(self, x0, ts)
        if normalize:
            # Normalize over both time and batch, then add observation noise.
            mean, std = torch.mean(paths, dim=(0, 1)), torch.std(paths, dim=(0, 1))
            paths.sub_(mean).div_(std).add_(torch.randn_like(paths) * noise_std)
        return paths
        # (continuation of __init__)
        self.encoder = Encoder(input_size=data_size, hidden_size=hidden_size, output_size=context_size)
        # Produces concatenated (mean, log-std) for the approximate posterior q(z0).
        self.qz0_net = nn.Linear(context_size, latent_size + latent_size)

        # Decoder.
        # Posterior drift: conditions on the encoder context at time t.
        self.f_net = nn.Sequential(
            nn.Linear(latent_size + context_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, latent_size),
        )
        # Prior drift: context-free.
        self.h_net = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, latent_size),
        )
        # This needs to be an element-wise function for the SDE to satisfy diagonal noise.
        self.g_nets = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(1, hidden_size),
                    nn.Softplus(),
                    nn.Linear(hidden_size, 1),
                    nn.Sigmoid()
                )
                for _ in range(latent_size)
            ]
        )
        # Maps latents back to observation space.
        self.projector = nn.Linear(latent_size, data_size)

        # Learnable parameters of the prior p(z0).
        self.pz0_mean = nn.Parameter(torch.zeros(1, latent_size))
        self.pz0_logstd = nn.Parameter(torch.zeros(1, latent_size))

        self._ctx = None

    def contextualize(self, ctx):
        """Stash the encoder context so `f` can look it up during integration."""
        self._ctx = ctx  # A tuple of tensors of sizes (T,), (T, batch_size, d).

    def f(self, t, y):
        """Posterior drift: concatenates y with the context nearest to time t."""
        ts, ctx = self._ctx
        # Index of the last context time <= t (clamped to the final index).
        i = min(torch.searchsorted(ts, t, right=True), len(ts) - 1)
        return self.f_net(torch.cat((y, ctx[i]), dim=1))

    def h(self, t, y):
        """Prior drift."""
        return self.h_net(y)

    def g(self, t, y):  # Diagonal diffusion.
        # Apply each scalar net to its own latent dimension to keep the noise diagonal.
        y = torch.split(y, split_size_or_sections=1, dim=1)
        out = [g_net_i(y_i) for (g_net_i, y_i) in zip(self.g_nets, y)]
        return torch.cat(out, dim=1)

    def forward(self, xs, ts, noise_std, adjoint=False, method="euler"):
        """Return (log-likelihood term, KL term) of the ELBO for a batch of paths.

        Args:
            xs: Observations of shape (T, batch_size, data_size).
            ts: Observation times of shape (T,).
            noise_std: Std of the Gaussian observation model.
            adjoint: Use the adjoint method for backprop through the solver.
            method: SDE solver name passed to torchsde.
        """
        # Contextualization is only needed for posterior inference.
        # Encode in reverse time so the context at t summarizes the future; flip back.
        ctx = self.encoder(torch.flip(xs, dims=(0,)))
        ctx = torch.flip(ctx, dims=(0,))
        self.contextualize((ts, ctx))

        # Reparameterized sample of the initial latent from q(z0 | data).
        qz0_mean, qz0_logstd = self.qz0_net(ctx[0]).chunk(chunks=2, dim=1)
        z0 = qz0_mean + qz0_logstd.exp() * torch.randn_like(qz0_mean)

        if adjoint:
            # Must use the argument `adjoint_params`, since `ctx` is not part of the input to `f`, `g`, and `h`.
            adjoint_params = (
                (ctx,) +
                tuple(self.f_net.parameters()) + tuple(self.g_nets.parameters()) + tuple(self.h_net.parameters())
            )
            zs, log_ratio = torchsde.sdeint_adjoint(
                self, z0, ts, adjoint_params=adjoint_params, dt=1e-2, logqp=True, method=method)
        else:
            zs, log_ratio = torchsde.sdeint(self, z0, ts, dt=1e-2, logqp=True, method=method)

        # Gaussian observation likelihood, summed over time and data dims, averaged over batch.
        _xs = self.projector(zs)
        xs_dist = Normal(loc=_xs, scale=noise_std)
        log_pxs = xs_dist.log_prob(xs).sum(dim=(0, 2)).mean(dim=0)

        # KL = initial-state KL + pathwise KL (the logqp term from the solver).
        qz0 = torch.distributions.Normal(loc=qz0_mean, scale=qz0_logstd.exp())
        pz0 = torch.distributions.Normal(loc=self.pz0_mean, scale=self.pz0_logstd.exp())
        logqp0 = torch.distributions.kl_divergence(qz0, pz0).sum(dim=1).mean(dim=0)
        logqp_path = log_ratio.sum(dim=0).mean(dim=0)
        return log_pxs, logqp0 + logqp_path
def vis(xs, ts, latent_sde, bm_vis, img_path, num_samples=10):
    """Save a side-by-side 3D plot of data trajectories vs. model samples.

    Args:
        xs: Data tensor of shape (T, batch_size, 3).
        ts: Observation times of shape (T,).
        latent_sde: Trained `LatentSDE`; its `sample` is called for the right panel.
        bm_vis: Fixed Brownian motion so samples are comparable across calls.
        img_path: Where to save the figure.
        num_samples: Number of trajectories to draw in each panel.
    """
    fig = plt.figure(figsize=(20, 9))
    gs = gridspec.GridSpec(1, 2)
    ax00 = fig.add_subplot(gs[0, 0], projection='3d')
    ax01 = fig.add_subplot(gs[0, 1], projection='3d')

    # Left plot: data.
    z1, z2, z3 = np.split(xs.cpu().numpy(), indices_or_sections=3, axis=-1)
    for i in range(num_samples):
        ax00.plot(z1[:, i, 0], z2[:, i, 0], z3[:, i, 0])
    # Bug fix: the z3 slice previously used a hard-coded `:10`, ignoring `num_samples`.
    ax00.scatter(z1[0, :num_samples, 0], z2[0, :num_samples, 0], z3[0, :num_samples, 0], marker='x')
    ax00.set_yticklabels([])
    ax00.set_xticklabels([])
    ax00.set_zticklabels([])
    ax00.set_xlabel('$z_1$', labelpad=0., fontsize=16)
    ax00.set_ylabel('$z_2$', labelpad=.5, fontsize=16)
    ax00.set_zlabel('$z_3$', labelpad=0., horizontalalignment='center', fontsize=16)
    ax00.set_title('Data', fontsize=20)
    xlim = ax00.get_xlim()
    ylim = ax00.get_ylim()
    zlim = ax00.get_zlim()

    # Right plot: samples from learned model, drawn with the same axis limits.
    xs = latent_sde.sample(batch_size=xs.size(1), ts=ts, bm=bm_vis).cpu().numpy()
    z1, z2, z3 = np.split(xs, indices_or_sections=3, axis=-1)

    for i in range(num_samples):
        ax01.plot(z1[:, i, 0], z2[:, i, 0], z3[:, i, 0])
    # Bug fix: same hard-coded `:10` replaced with `:num_samples`.
    ax01.scatter(z1[0, :num_samples, 0], z2[0, :num_samples, 0], z3[0, :num_samples, 0], marker='x')
    ax01.set_yticklabels([])
    ax01.set_xticklabels([])
    ax01.set_zticklabels([])
    ax01.set_xlabel('$z_1$', labelpad=0., fontsize=16)
    ax01.set_ylabel('$z_2$', labelpad=.5, fontsize=16)
    ax01.set_zlabel('$z_3$', labelpad=0., horizontalalignment='center', fontsize=16)
    ax01.set_title('Samples', fontsize=20)
    ax01.set_xlim(xlim)
    ax01.set_ylim(ylim)
    ax01.set_zlim(zlim)

    plt.savefig(img_path)
    plt.close()
make_dataset(t0=t0, t1=t1, batch_size=batch_size, noise_std=noise_std, train_dir=train_dir, device=device) 294 | latent_sde = LatentSDE( 295 | data_size=3, 296 | latent_size=latent_size, 297 | context_size=context_size, 298 | hidden_size=hidden_size, 299 | ).to(device) 300 | optimizer = optim.Adam(params=latent_sde.parameters(), lr=lr_init) 301 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=lr_gamma) 302 | kl_scheduler = LinearScheduler(iters=kl_anneal_iters) 303 | 304 | # Fix the same Brownian motion for visualization. 305 | bm_vis = torchsde.BrownianInterval( 306 | t0=t0, t1=t1, size=(batch_size, latent_size,), device=device, levy_area_approximation="space-time") 307 | 308 | for global_step in tqdm.tqdm(range(1, num_iters + 1)): 309 | latent_sde.zero_grad() 310 | log_pxs, log_ratio = latent_sde(xs, ts, noise_std, adjoint, method) 311 | loss = -log_pxs + log_ratio * kl_scheduler.val 312 | loss.backward() 313 | optimizer.step() 314 | scheduler.step() 315 | kl_scheduler.step() 316 | 317 | if global_step % pause_every == 0: 318 | lr_now = optimizer.param_groups[0]['lr'] 319 | logging.warning( 320 | f'global_step: {global_step:06d}, lr: {lr_now:.5f}, ' 321 | f'log_pxs: {log_pxs:.4f}, log_ratio: {log_ratio:.4f} loss: {loss:.4f}, kl_coeff: {kl_scheduler.val:.4f}' 322 | ) 323 | img_path = os.path.join(train_dir, f'global_step_{global_step:06d}.pdf') 324 | vis(xs, ts, latent_sde, bm_vis, img_path) 325 | 326 | 327 | if __name__ == "__main__": 328 | fire.Fire(main) 329 | -------------------------------------------------------------------------------- /examples/unet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal timestep embedding.

    Precomputes `dim // 2` geometrically spaced frequencies; `forward` maps a
    batch of scalars `x` to `[sin(x * freqs), cos(x * freqs)]`, giving an
    output of width `dim` (assuming `dim` is even).
    """

    def __init__(self, dim):
        super().__init__()
        num_freqs = dim // 2
        log_step = math.log(10000) / (num_freqs - 1)
        freqs = torch.exp(-log_step * torch.arange(num_freqs))
        # Buffer (not parameter): moves with .to(device) but is not trained.
        self.register_buffer('emb', freqs)

    def forward(self, x):
        angles = x[:, None] * self.emb[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class LinearTimeSelfAttention(nn.Module):
    """Multi-head self-attention with cost linear in the number of pixels.

    Uses the efficient-attention formulation: softmax over the key axis,
    aggregate a per-head context matrix, then read it out with the queries.
    """

    def __init__(self, dim, heads=4, dim_head=32, groups=32):
        super().__init__()
        inner_dim = dim_head * heads
        self.group_norm = nn.GroupNorm(groups, dim)
        self.heads = heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        batch, channels, height, width = x.shape
        normed = self.group_norm(x)
        qkv = self.to_qkv(normed)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        # Normalize keys across spatial positions, then contract with values to
        # form a (dim_head x dim_head) context per head.
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(attended, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=height, w=width)
        return self.to_out(attended)
class Residual(nn.Module):
    """Wrap a callable `fn` as a residual branch: `x -> fn(x, ...) + x`."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
class Upsample(nn.Module):
    """2x spatial upsampling via a stride-2 transposed convolution.

    With `blur=True`, follows the transposed conv with an anti-aliasing blur.
    """

    def __init__(self, dim, blur=True):
        super().__init__()
        upconv = nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)
        self.conv = nn.Sequential(upconv, Blur()) if blur else upconv

    def forward(self, x):
        return self.conv(x)
self.down_res_blocks.append(res_blocks) 233 | 234 | attn_blocks = nn.ModuleList([]) 235 | if h in attn_resolutions and w in attn_resolutions: 236 | attn_blocks.extend( 237 | [Residual(attention_cls(dim_out, heads=heads, dim_head=dim_head, groups=groups)) 238 | for _ in range(num_res_blocks)] 239 | ) 240 | self.down_attn_blocks.append(attn_blocks) 241 | 242 | if ind < (len(in_out) - 1): 243 | spatial_blocks = nn.ModuleList([Downsample(dim_out)]) 244 | h, w = h // 2, w // 2 245 | else: 246 | spatial_blocks = nn.ModuleList() 247 | self.down_spatial_blocks.append(spatial_blocks) 248 | 249 | mid_dim = dims[-1] 250 | self.mid_block1 = ResnetBlock( 251 | dim=mid_dim, 252 | dim_out=mid_dim, 253 | time_emb_dim=hidden_channels, 254 | groups=groups, 255 | dropout_rate=dropout_rate 256 | ) 257 | self.mid_attn = Residual(attention_cls(mid_dim, heads=heads, dim_head=dim_head, groups=groups)) 258 | self.mid_block2 = ResnetBlock( 259 | dim=mid_dim, 260 | dim_out=mid_dim, 261 | time_emb_dim=hidden_channels, 262 | groups=groups, 263 | dropout_rate=dropout_rate 264 | ) 265 | 266 | self.ups_res_blocks = nn.ModuleList([]) 267 | self.ups_attn_blocks = nn.ModuleList([]) 268 | self.ups_spatial_blocks = nn.ModuleList([]) 269 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 270 | res_blocks = nn.ModuleList([ 271 | ResnetBlock( 272 | dim=dim_out * 2, 273 | dim_out=dim_out, 274 | time_emb_dim=hidden_channels, 275 | groups=groups, 276 | dropout_rate=dropout_rate 277 | ) for _ in range(num_res_blocks) 278 | ]) 279 | res_blocks.extend([ 280 | ResnetBlock( 281 | dim=dim_out + dim_in, 282 | dim_out=dim_in, 283 | time_emb_dim=hidden_channels, 284 | groups=groups, 285 | dropout_rate=dropout_rate 286 | ) 287 | ]) 288 | self.ups_res_blocks.append(res_blocks) 289 | 290 | attn_blocks = nn.ModuleList([]) 291 | if h in attn_resolutions and w in attn_resolutions: 292 | attn_blocks.extend( 293 | [Residual(attention_cls(dim_out, heads=heads, dim_head=dim_head, groups=groups)) 294 | for _ in 
range(num_res_blocks)] 295 | ) 296 | attn_blocks.append( 297 | Residual(attention_cls(dim_in, heads=heads, dim_head=dim_head, groups=groups)) 298 | ) 299 | self.ups_attn_blocks.append(attn_blocks) 300 | 301 | spatial_blocks = nn.ModuleList() 302 | if ind < (len(in_out) - 1): 303 | spatial_blocks.append(Upsample(dim_in)) 304 | h, w = h * 2, w * 2 305 | self.ups_spatial_blocks.append(spatial_blocks) 306 | 307 | self.final_conv = nn.Sequential( 308 | nn.GroupNorm(groups, hidden_channels), 309 | Mish(), 310 | nn.Conv2d(hidden_channels, in_channels, 1) 311 | ) 312 | 313 | def forward(self, t, x): 314 | t = self.mlp(self.time_pos_emb(t)) 315 | 316 | hs = [self.first_conv(x)] 317 | for i, (res_blocks, attn_blocks, spatial_blocks) in enumerate( 318 | zip(self.down_res_blocks, self.down_attn_blocks, self.down_spatial_blocks)): 319 | if len(attn_blocks) > 0: 320 | for res_block, attn_block in zip(res_blocks, attn_blocks): 321 | h = res_block(hs[-1], t) 322 | h = attn_block(h) 323 | hs.append(h) 324 | else: 325 | for res_block in res_blocks: 326 | h = res_block(hs[-1], t) 327 | hs.append(h) 328 | if len(spatial_blocks) > 0: 329 | spatial_block, = spatial_blocks 330 | hs.append(spatial_block(hs[-1])) 331 | 332 | h = hs[-1] 333 | h = self.mid_block1(h, t) 334 | h = self.mid_attn(h) 335 | h = self.mid_block2(h, t) 336 | 337 | for i, (res_blocks, attn_blocks, spatial_blocks) in enumerate( 338 | zip(self.ups_res_blocks, self.ups_attn_blocks, self.ups_spatial_blocks)): 339 | if len(attn_blocks) > 0: 340 | for res_block, attn_block in zip(res_blocks, attn_blocks): 341 | h = res_block(torch.cat((h, hs.pop()), dim=1), t) 342 | h = attn_block(h) 343 | else: 344 | for res_block in res_blocks: 345 | h = res_block(torch.cat((h, hs.pop()), dim=1), t) 346 | if len(spatial_blocks) > 0: 347 | spatial_block, = spatial_blocks 348 | h = spatial_block(h) 349 | return self.final_conv(h) 350 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | # These are the assumed default build requirements from pip: 3 | # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support 4 | requires = ["setuptools>=40.8.0", "wheel"] 5 | build-backend = "setuptools.build_meta" 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import re 17 | 18 | import setuptools 19 | 20 | # for simplicity we actually store the version in the __version__ attribute in the source 21 | here = os.path.realpath(os.path.dirname(__file__)) 22 | with open(os.path.join(here, 'torchsde', '__init__.py')) as f: 23 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) 24 | if meta_match: 25 | version = meta_match.group(1) 26 | else: 27 | raise RuntimeError("Unable to find __version__ string.") 28 | 29 | with open(os.path.join(here, 'README.md')) as f: 30 | readme = f.read() 31 | 32 | setuptools.setup( 33 | name="torchsde", 34 | version=version, 35 | author="Xuechen Li, Patrick Kidger", 36 | author_email="lxuechen@cs.stanford.edu, hello@kidger.site", 37 | description="SDE solvers and stochastic adjoint sensitivity analysis in PyTorch.", 38 | long_description=readme, 39 | long_description_content_type="text/markdown", 40 | url="https://github.com/google-research/torchsde", 41 | packages=setuptools.find_packages(exclude=['benchmarks', 'diagnostics', 'examples', 'tests']), 42 | install_requires=[ 43 | "numpy>=1.19", 44 | "scipy>=1.5", 45 | "torch>=1.6.0", 46 | "trampoline>=0.1.2", 47 | ], 48 | python_requires='>=3.8', 49 | classifiers=[ 50 | "Programming Language :: Python :: 3", 51 | "License :: OSI Approved :: Apache Software License", 52 | ], 53 | ) 54 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/test_adjoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Compare gradients computed with adjoint vs analytical solution.""" 16 | import sys 17 | 18 | sys.path = sys.path[1:] # A hack so that we always import the installed library. 19 | 20 | import pytest 21 | import torch 22 | import torchsde 23 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS, METHODS, NOISE_TYPES, SDE_TYPES 24 | 25 | from . import utils 26 | from . 
import problems

torch.manual_seed(1147481649)
torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.get_default_dtype()


def _methods():
    # (sde_type, method, method-specific options) triples under test.
    yield SDE_TYPES.ito, METHODS.milstein, None
    yield SDE_TYPES.ito, METHODS.srk, None
    yield SDE_TYPES.stratonovich, METHODS.midpoint, None
    yield SDE_TYPES.stratonovich, METHODS.reversible_heun, None


@pytest.mark.parametrize("sde_cls", [problems.ExDiagonal, problems.ExScalar, problems.ExAdditive,
                                     problems.NeuralGeneral])
@pytest.mark.parametrize("sde_type, method, options", _methods())
@pytest.mark.parametrize('adaptive', (False,))
def test_against_numerical(sde_cls, sde_type, method, options, adaptive):
    """Check adjoint parameter gradients against central finite differences."""
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]
    batch_size = 4
    t0, t1 = ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.full((batch_size, d), 0.1, device=device)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device)

    # srk requires space-time Levy area from the Brownian motion.
    if method == METHODS.srk:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
    else:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
    bm = torchsde.BrownianInterval(
        t0=t0, t1=t1, size=(batch_size, m), dtype=dtype, device=device,
        levy_area_approximation=levy_area_approximation
    )

    # reversible_heun is algebraically reversible, so a much tighter tolerance
    # is used for it.
    if method == METHODS.reversible_heun:
        tol = 1e-6
        adjoint_method = METHODS.adjoint_reversible_heun
        adjoint_options = options
    else:
        tol = 1e-2
        adjoint_method = None
        adjoint_options = None

    def func(inputs, modules):
        # Scalar loss of the terminal state, as required by utils.gradcheck.
        y0, sde = inputs[0], modules[0]
        ys = torchsde.sdeint_adjoint(sde, y0, ts, dt=dt, method=method, adjoint_method=adjoint_method,
                                     adaptive=adaptive, bm=bm, options=options, adjoint_options=adjoint_options)
        return (ys[-1] ** 2).sum(dim=1).mean(dim=0)

    # `grad_inputs=True` also works, but we really only care about grad wrt params and want fast tests.
    utils.gradcheck(func, y0, sde, eps=1e-6, rtol=tol, atol=tol, grad_params=True)


def _methods_dt_tol():
    # Extends _methods() with (dt, rtol, atol); reversible_heun is checked at
    # two step sizes and tighter tolerances.
    for sde_type, method, options in _methods():
        if method == METHODS.reversible_heun:
            yield sde_type, method, options, 2**-3, 1e-3, 1e-4
            yield sde_type, method, options, 1e-3, 1e-3, 1e-4
        else:
            yield sde_type, method, options, 1e-3, 1e-2, 1e-2


@pytest.mark.parametrize("sde_cls", [problems.NeuralDiagonal, problems.NeuralScalar, problems.NeuralAdditive,
                                     problems.NeuralGeneral])
@pytest.mark.parametrize("sde_type, method, options, dt, rtol, atol", _methods_dt_tol())
@pytest.mark.parametrize("len_ts", [2, 9])
def test_against_sdeint(sde_cls, sde_type, method, options, dt, rtol, atol, len_ts):
    """Gradients from sdeint_adjoint should match plain sdeint + autograd."""
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]
    batch_size = 4
    ts = torch.linspace(0.0, 1.0, len_ts, device=device, dtype=torch.float64)
    t0 = ts[0]
    t1 = ts[-1]
    y0 = torch.full((batch_size, d), 0.1, device=device, dtype=torch.float64, requires_grad=True)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device, torch.float64)

    if method == METHODS.srk:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
    else:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
    # Shared Brownian sample so both solves see identical noise.
    bm = torchsde.BrownianInterval(
        t0=t0, t1=t1, size=(batch_size, m), dtype=torch.float64, device=device,
        levy_area_approximation=levy_area_approximation
    )

    if method == METHODS.reversible_heun:
        adjoint_method = METHODS.adjoint_reversible_heun
        adjoint_options = options
    else:
        adjoint_method = None
        adjoint_options = None

    # Reference: plain sdeint with ordinary autograd backprop.
    ys_true = torchsde.sdeint(sde, y0, ts, dt=dt, method=method, bm=bm, options=options)
    grad = torch.randn_like(ys_true)
    ys_true.backward(grad)

    true_grad = torch.cat([y0.grad.view(-1)] + [param.grad.view(-1) for param in sde.parameters()])
    # Reset gradients so the adjoint pass starts from a clean slate.
    y0.grad.zero_()
    for param in sde.parameters():
        param.grad.zero_()

    ys_test = torchsde.sdeint_adjoint(sde, y0, ts, dt=dt, method=method, bm=bm, adjoint_method=adjoint_method,
                                      options=options, adjoint_options=adjoint_options)
    ys_test.backward(grad)
    test_grad = torch.cat([y0.grad.view(-1)] + [param.grad.view(-1) for param in sde.parameters()])

    torch.testing.assert_allclose(ys_true, ys_test)
    torch.testing.assert_allclose(true_grad, test_grad, rtol=rtol, atol=atol)


@pytest.mark.parametrize("problem", [problems.BasicSDE1, problems.BasicSDE2, problems.BasicSDE3, problems.BasicSDE4])
@pytest.mark.parametrize("method", ["milstein", "srk"])
@pytest.mark.parametrize('adaptive', (False, True))
def test_basic(problem, method, adaptive):
    """Smoke test: sdeint_adjoint backprops without creating new parameters."""
    d = 10
    batch_size = 128
    ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.zeros(batch_size, d).to(device).fill_(0.1)

    problem = problem(d).to(device)

    num_before = _count_differentiable_params(problem)

    problem.zero_grad()
    # ts has two entries, so the solution unpacks into (initial, terminal).
    _, yt = torchsde.sdeint_adjoint(problem, y0, ts, method=method, dt=dt, adaptive=adaptive)
    loss = yt.sum(dim=1).mean(dim=0)
    loss.backward()

    num_after = _count_differentiable_params(problem)
    assert num_before == num_after


def _count_differentiable_params(module):
    # Number of parameters with requires_grad=True.
    return len([p for p in module.parameters() if p.requires_grad])
--------------------------------------------------------------------------------
/tests/test_brownian_path.py:
--------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test `BrownianPath`.

The suite tests both running on CPU and CUDA (if available).
18 | """ 19 | import sys 20 | 21 | sys.path = sys.path[1:] # A hack so that we always import the installed library. 22 | 23 | import math 24 | import numpy as np 25 | import numpy.random as npr 26 | import torch 27 | from scipy.stats import norm, kstest 28 | 29 | import torchsde 30 | import pytest 31 | 32 | torch.manual_seed(1147481649) 33 | torch.set_default_dtype(torch.float64) 34 | 35 | D = 3 36 | BATCH_SIZE = 131072 37 | REPS = 3 38 | ALPHA = 0.00001 39 | 40 | devices = [cpu, gpu] = [torch.device('cpu'), torch.device('cuda')] 41 | 42 | 43 | def _setup(device): 44 | t0, t1 = torch.tensor([0., 1.], device=device) 45 | w0, w1 = torch.randn([2, BATCH_SIZE, D], device=device) 46 | t = torch.rand([], device=device) 47 | bm = torchsde.BrownianPath(t0=t0, w0=w0) 48 | return t, bm 49 | 50 | 51 | @pytest.mark.parametrize("device", devices) 52 | def test_basic(device): 53 | if device == gpu and not torch.cuda.is_available(): 54 | pytest.skip(reason="CUDA not available.") 55 | 56 | t, bm = _setup(device) 57 | sample = bm(t) 58 | assert sample.size() == (BATCH_SIZE, D) 59 | 60 | 61 | @pytest.mark.parametrize("device", devices) 62 | def test_determinism(device): 63 | if device == gpu and not torch.cuda.is_available(): 64 | pytest.skip(reason="CUDA not available.") 65 | 66 | t, bm = _setup(device) 67 | vals = [bm(t) for _ in range(REPS)] 68 | for val in vals[1:]: 69 | assert torch.allclose(val, vals[0]) 70 | 71 | 72 | @pytest.mark.parametrize("device", devices) 73 | def test_normality(device): 74 | if device == gpu and not torch.cuda.is_available(): 75 | pytest.skip(reason="CUDA not available.") 76 | 77 | t0_, t1_ = 0.0, 1.0 78 | eps = 1e-2 79 | for _ in range(REPS): 80 | w0_ = npr.randn() * math.sqrt(t1_) 81 | w0 = torch.tensor(w0_, device=device).repeat(BATCH_SIZE) 82 | 83 | bm = torchsde.BrownianPath(t0=t0_, w0=w0) # noqa 84 | 85 | w1_ = bm(t1_).cpu().numpy() 86 | 87 | t_ = npr.uniform(low=t0_ + eps, high=t1_ - eps) # Avoid sampling too close to the boundary. 
88 | samples_ = bm(t_).cpu().numpy() 89 | 90 | # True expected mean from Brownian bridge. 91 | mean_ = ((t1_ - t_) * w0_ + (t_ - t0_) * w1_) / (t1_ - t0_) 92 | std_ = math.sqrt((t1_ - t_) * (t_ - t0_) / (t1_ - t0_)) 93 | ref_dist = norm(loc=np.zeros_like(mean_), scale=np.ones_like(std_)) 94 | 95 | _, pval = kstest((samples_ - mean_) / std_, ref_dist.cdf) 96 | assert pval >= ALPHA 97 | -------------------------------------------------------------------------------- /tests/test_brownian_tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Test `BrownianTree`. 16 | 17 | The suite tests both running on CPU and CUDA (if available). 18 | """ 19 | import sys 20 | 21 | sys.path = sys.path[1:] # A hack so that we always import the installed library. 

import numpy as np
import numpy.random as npr
import torch
from scipy.stats import norm, kstest

import pytest
import torchsde

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

D = 3
SMALL_BATCH_SIZE = 16
LARGE_BATCH_SIZE = 16384
REPS = 3
ALPHA = 0.00001

# NOTE(review): the CUDA device object is constructed even without CUDA;
# tests skip at run time instead.
devices = [cpu, gpu] = [torch.device('cpu'), torch.device('cuda')]


def _setup(device, batch_size):
    """Return a random query time and a seeded BrownianTree on [0, 1]."""
    t0, t1 = torch.tensor([0., 1.], device=device)
    w0 = torch.zeros(batch_size, D, device=device)
    t = torch.rand([]).to(device)
    bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, entropy=0)
    return t, bm


def _dict_to_sorted_list(*dicts):
    # Values of each dict ordered by sorted key; a single dict returns one list.
    lists = tuple([d[k] for k in sorted(d.keys())] for d in dicts)
    if len(lists) == 1:
        return lists[0]
    return lists


@pytest.mark.parametrize("device", devices)
def test_basic(device):
    """Sampling at a single time returns the expected (batch, D) shape."""
    if device == gpu and not torch.cuda.is_available():
        pytest.skip(reason="CUDA not available.")

    t, bm = _setup(device, SMALL_BATCH_SIZE)
    sample = bm(t)
    assert sample.size() == (SMALL_BATCH_SIZE, D)


@pytest.mark.parametrize("device", devices)
def test_determinism(device):
    """Repeated queries at the same time return identical samples."""
    if device == gpu and not torch.cuda.is_available():
        pytest.skip(reason="CUDA not available.")

    t, bm = _setup(device, SMALL_BATCH_SIZE)
    vals = [bm(t) for _ in range(REPS)]
    for val in vals[1:]:
        assert torch.allclose(val, vals[0])


@pytest.mark.parametrize("device", devices)
def test_normality(device):
    """Interior samples follow the Brownian-bridge normal law (KS test)."""
    if device == gpu and not torch.cuda.is_available():
        pytest.skip(reason="CUDA not available.")

    t0_, t1_ = 0.0, 1.0
    t0, t1 = torch.tensor([t0_, t1_], device=device)
    eps = 1e-5
    for _ in range(REPS):
        # Pinned endpoints w0, w1, so interior points are a Brownian bridge.
        w0_, w1_ = 0.0, npr.randn()
        w0 = torch.tensor(w0_, device=device).repeat(LARGE_BATCH_SIZE)
        w1 = torch.tensor(w1_, device=device).repeat(LARGE_BATCH_SIZE)
        bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, pool_size=100, tol=1e-14)  # noqa

        for _ in range(REPS):
            t_ = npr.uniform(low=t0_ + eps, high=t1_ - eps)
            samples = bm(t_)
            samples_ = samples.cpu().detach().numpy()

            mean_ = ((t1_ - t_) * w0_ + (t_ - t0_) * w1_) / (t1_ - t0_)
            std_ = np.sqrt((t1_ - t_) * (t_ - t0_) / (t1_ - t0_))
            ref_dist = norm(loc=mean_, scale=std_)

            _, pval = kstest(samples_, ref_dist.cdf)
            assert pval >= ALPHA
--------------------------------------------------------------------------------
/tests/test_sdeint.py:
--------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

sys.path = sys.path[1:]  # A hack so that we always import the installed library.

import pytest
import torch
import torchsde
from torchsde.settings import NOISE_TYPES

from . 
import problems

torch.manual_seed(1147481649)
torch.set_default_dtype(torch.float64)
devices = ['cpu']
if torch.cuda.is_available():
    devices.append('cuda')

batch_size = 4
d = 3
m = 2
t0 = 0.0
t1 = 0.3
T = 5
dt = 0.05
dtype = torch.get_default_dtype()


class _nullcontext:
    # Minimal stand-in for contextlib.nullcontext, kept local to this file.
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


@pytest.mark.parametrize('device', devices)
def test_rename_methods(device):
    """Test renaming works with a subset of names."""
    sde = problems.CustomNamesSDE().to(device)
    y0 = torch.ones(batch_size, d, device=device)
    ts = torch.linspace(t0, t1, steps=T, device=device)
    ans = torchsde.sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'})
    assert ans.shape == (T, batch_size, d)


@pytest.mark.parametrize('device', devices)
def test_rename_methods_logqp(device):
    """Test renaming works with a subset of names when `logqp=True`."""
    sde = problems.CustomNamesSDELogqp().to(device)
    y0 = torch.ones(batch_size, d, device=device)
    ts = torch.linspace(t0, t1, steps=T, device=device)
    ans = torchsde.sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True)
    assert ans[0].shape == (T, batch_size, d)
    assert ans[1].shape == (T - 1, batch_size)


def _use_bm__levy_area_approximation():
    # (use_bm, levy_area_approximation) pairs under test.
    yield False, None
    yield True, 'none'
    yield True, 'space-time'
    yield True, 'davie'
    yield True, 'foster'


@pytest.mark.parametrize('sde_type,method', [('ito', 'euler'), ('stratonovich', 'midpoint')])
def test_specialised_functions(sde_type, method):
    """All equivalent drift/diffusion interfaces must give identical paths."""
    vector = torch.randn(m)
    fg = problems.FGSDE(sde_type, vector)
    f_and_g = problems.FAndGSDE(sde_type, vector)
    g_prod = problems.GProdSDE(sde_type, vector)
    f_and_g_prod = problems.FAndGProdSDE(sde_type, vector)
    f_and_g_with_g_prod1 = problems.FAndGGProdSDE1(sde_type, vector)
    f_and_g_with_g_prod2 = problems.FAndGGProdSDE2(sde_type, vector)

    y0 = torch.randn(batch_size, d)

    outs = []
    for sde in (fg, f_and_g, g_prod, f_and_g_prod, f_and_g_with_g_prod1, f_and_g_with_g_prod2):
        # Fixed entropy so every solve sees the same Brownian sample.
        bm = torchsde.BrownianInterval(t0, t1, (batch_size, m), entropy=45678)
        outs.append(torchsde.sdeint(sde, y0, [t0, t1], dt=dt, bm=bm)[1])
    for o in outs[1:]:
        # Equality of floating points, because we expect them to do everything exactly the same.
        assert o.shape == outs[0].shape
        assert (o == outs[0]).all()


@pytest.mark.parametrize('sde_cls', [problems.ExDiagonal, problems.ExScalar, problems.ExAdditive,
                                     problems.NeuralGeneral])
@pytest.mark.parametrize('use_bm,levy_area_approximation', _use_bm__levy_area_approximation())
@pytest.mark.parametrize('sde_type', ['ito', 'stratonovich'])
@pytest.mark.parametrize('method',
                         ['blah', 'euler', 'milstein', 'milstein_grad_free', 'srk', 'euler_heun', 'heun', 'midpoint',
                          'log_ode'])
@pytest.mark.parametrize('adaptive', [False, True])
@pytest.mark.parametrize('logqp', [True, False])
@pytest.mark.parametrize('device', devices)
def test_sdeint_run_shape_method(sde_cls, use_bm, levy_area_approximation, sde_type, method, adaptive, logqp, device):
    """Tests that sdeint:
    (a) runs/raises an error as appropriate
    (b) produces tensors of the right shape
    (c) accepts every method
    """

    if method == 'milstein_grad_free':
        method = 'milstein'
        options = dict(grad_free=True)
    else:
        options = dict()

    # Work out from the parametrization whether this combination is invalid
    # and therefore expected to raise.
    should_fail = False
    if sde_type == 'ito':
        if method not in ('euler', 'srk', 'milstein'):
            should_fail = True
    else:
        if method not in ('euler_heun', 'heun', 'midpoint', 'log_ode', 'milstein'):
            should_fail = True
    if method in ('milstein', 'srk') and sde_cls.noise_type == 'general':
        should_fail = True
    if method == 'srk' and levy_area_approximation == 'none':
        should_fail = True
    if method == 'log_ode' and levy_area_approximation in ('none', 'space-time'):
        should_fail = True

    if sde_cls.noise_type in (NOISE_TYPES.scalar, NOISE_TYPES.diagonal):
        kwargs = {'d': d}
    else:
        kwargs = {'d': d, 'm': m}
    sde = sde_cls(sde_type=sde_type, **kwargs).to(device)

    if use_bm:
        if sde_cls.noise_type == 'scalar':
            size = (batch_size, 1)
        elif sde_cls.noise_type == 'diagonal':
            # logqp appends an extra channel for the log-ratio.
            size = (batch_size, d + 1) if logqp else (batch_size, d)
        else:
            assert sde_cls.noise_type in ('additive', 'general')
            size = (batch_size, m)
        bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device,
                                       levy_area_approximation=levy_area_approximation)
    else:
        bm = None

    _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail, options)


@pytest.mark.parametrize("sde_cls", [problems.BasicSDE1, problems.BasicSDE2, problems.BasicSDE3, problems.BasicSDE4])
@pytest.mark.parametrize('method', ['euler', 'milstein', 'milstein_grad_free', 'srk'])
@pytest.mark.parametrize('adaptive', [False, True])
@pytest.mark.parametrize('device', devices)
def test_sdeint_dependencies(sde_cls, method, adaptive, device):
    """This test uses diagonal noise. This checks if the solvers still work when some of the functions don't depend on
    the states/params and when some states/params don't require gradients.
    """

    if method == 'milstein_grad_free':
        method = 'milstein'
        options = dict(grad_free=True)
    else:
        options = dict()

    sde = sde_cls(d=d).to(device)
    bm = None
    logqp = False
    should_fail = False
    _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail, options)


def _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail, options):
    """Shared driver: run sdeint with both `f` and `h` as drift and check
    shapes, expected failures, and the adaptive-euler warning."""
    y0 = torch.ones(batch_size, d, device=device)
    ts = torch.linspace(t0, t1, steps=T, device=device)
    # Adaptive Euler on non-additive noise is expected to warn.
    if adaptive and method == 'euler' and sde.noise_type != 'additive':
        ctx = pytest.warns(UserWarning)
    else:
        ctx = _nullcontext()

    # Using `f` as drift.
    with torch.no_grad():
        try:
            with ctx:
                ans = torchsde.sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, logqp=logqp,
                                      options=options)
        except ValueError:
            if should_fail:
                return
            raise
        else:
            if should_fail:
                pytest.fail("Expected an error; did not get one.")
    if logqp:
        ans, log_ratio = ans
        assert log_ratio.shape == (T - 1, batch_size)
    assert ans.shape == (T, batch_size, d)

    # Using `h` as drift.
    with torch.no_grad():
        with ctx:
            ans = torchsde.sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, names={'drift': 'h'},
                                  logqp=logqp, options=options)
    if logqp:
        ans, log_ratio = ans
        assert log_ratio.shape == (T - 1, batch_size)
    assert ans.shape == (T, batch_size, d)


@pytest.mark.parametrize("sde_cls", [problems.NeuralDiagonal, problems.NeuralScalar, problems.NeuralAdditive,
                                     problems.NeuralGeneral])
def test_reversibility(sde_cls):
    """Solving forwards then backwards with reversible_heun must recover the path."""
    batch_size = 32
    state_size = 4
    t_size = 20
    dt = 0.1

    brownian_size = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: state_size,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]

    class MinusSDE(torch.nn.Module):
        # Time-reversed SDE: negated drift/diffusion evaluated at -t.
        # NOTE(review): __init__ does not call super().__init__() — works here
        # only because the instance is never used as a real nn.Module.
        def __init__(self, sde):
            self.noise_type = sde.noise_type
            self.sde_type = sde.sde_type
            self.f = lambda t, y: -sde.f(-t, y)
            self.g = lambda t, y: -sde.g(-t, y)

    sde = sde_cls(d=state_size, m=brownian_size, sde_type='stratonovich')
    minus_sde = MinusSDE(sde)
    y0 = torch.full((batch_size, state_size), 0.1)
    ts = torch.linspace(0, (t_size - 1) * dt, t_size)
    bm = torchsde.BrownianInterval(t0=ts[0], t1=ts[-1], size=(batch_size, brownian_size))
    # `extra=True` returns the solver state (f, g, z) needed to run in reverse.
    ys, (f, g, z) = torchsde.sdeint(sde, y0, ts, bm=bm, method='reversible_heun', dt=dt, extra=True)
    backward_ts = -ts.flip(0)
    backward_ys = torchsde.sdeint(minus_sde, ys[-1], backward_ts, bm=torchsde.ReverseBrownian(bm),
                                  method='reversible_heun', dt=dt, extra_solver_state=(-f, -g, z))
    backward_ys = backward_ys.flip(0)

    torch.testing.assert_allclose(ys, backward_ys, rtol=1e-6, atol=1e-6)
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# 
Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import torch
from torch import nn

from torchsde.types import Callable, ModuleOrModules, Optional, TensorOrTensors


# These tolerances don't need to be this large. For gradients to match up in the Ito case, we typically need large
# values; not so much as in the Stratonovich case.
def assert_allclose(actual, expected, rtol=1e-3, atol=1e-2):
    # None is treated as a value: both must be None, or both close tensors.
    if actual is None:
        assert expected is None
    else:
        torch.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)


def gradcheck(func: Callable,
              inputs: TensorOrTensors,
              modules: Optional[ModuleOrModules] = (),
              eps: float = 1e-6,
              atol: float = 1e-5,
              rtol: float = 1e-3,
              grad_inputs=False,
              gradgrad_inputs=False,
              grad_params=False,
              gradgrad_params=False):
    """Check grad and grad of grad wrt inputs and parameters of Modules.

    When `func` is vector-valued, the checks compare autodiff vjp against
    finite-difference vjp, where v is a sampled standard normal vector.

    This function is aimed to be as self-contained as possible so that it could
    be copied/pasted across different projects.

    Args:
        func (callable): A Python function that takes in a sequence of tensors
            (inputs) and a sequence of nn.Module (modules), and outputs a tensor
            or a sequence of tensors.
        inputs (sequence of Tensors): The input tensors.
        modules (sequence of nn.Module): The modules whose parameter gradient
            needs to be tested.
        eps (float, optional): Magnitude of two-sided finite difference
            perturbation.
        atol (float, optional): Absolute tolerance.
        rtol (float, optional): Relative tolerance.
        grad_inputs (bool, optional): Check gradients wrt inputs if True.
        gradgrad_inputs (bool, optional): Check gradients of gradients wrt
            inputs if True.
        grad_params (bool, optional): Check gradients wrt differentiable
            parameters of modules if True.
        gradgrad_params (bool, optional): Check gradients of gradients wrt
            differentiable parameters of modules if True.

    Returns:
        None.
    """

    def convert_none_to_zeros(sequence, like_sequence):
        return [torch.zeros_like(q) if p is None else p for p, q in zip(sequence, like_sequence)]

    def flatten(sequence):
        return torch.cat([p.reshape(-1) for p in sequence]) if len(sequence) > 0 else torch.tensor([])

    if isinstance(inputs, torch.Tensor):
        inputs = (inputs,)

    if isinstance(modules, nn.Module):
        modules = (modules,)

    # Don't modify original objects.
    modules = tuple(copy.deepcopy(m) for m in modules)
    inputs = tuple(i.clone().requires_grad_() for i in inputs)

    func = _make_scalar_valued_func(func, inputs, modules)
    func_only_inputs = lambda *args: func(args, modules)  # noqa

    # Grad wrt inputs.
    if grad_inputs:
        torch.autograd.gradcheck(func_only_inputs, inputs, eps=eps, atol=atol, rtol=rtol)

    # Grad of grad wrt inputs.
    if gradgrad_inputs:
        torch.autograd.gradgradcheck(func_only_inputs, inputs, eps=eps, atol=atol, rtol=rtol)

    # Grad wrt params.
    if grad_params:
        params = [p for m in modules for p in m.parameters() if p.requires_grad]
        loss = func(inputs, modules)
        framework_grad = flatten(convert_none_to_zeros(torch.autograd.grad(loss, params, create_graph=True), params))

        # Central finite differences: perturb each parameter entry in place by
        # +/- eps, restoring it after each evaluation.
        numerical_grad = []
        for param in params:
            flat_param = param.reshape(-1)
            for i in range(len(flat_param)):
                flat_param[i].data.add_(eps)
                plus_eps = func(inputs, modules).detach()
                flat_param[i].data.sub_(eps)

                flat_param[i].data.sub_(eps)
                minus_eps = func(inputs, modules).detach()
                flat_param[i].data.add_(eps)

                numerical_grad.append((plus_eps - minus_eps) / (2 * eps))
                del plus_eps, minus_eps
        numerical_grad = torch.stack(numerical_grad)
        torch.testing.assert_allclose(numerical_grad, framework_grad, rtol=rtol, atol=atol)

    # Grad of grad wrt params.
    if gradgrad_params:
        def func_high_order(inputs, modules):
            params = [p for m in modules for p in m.parameters() if p.requires_grad]
            grads = torch.autograd.grad(func(inputs, modules), params, create_graph=True, allow_unused=True)
            return tuple(grad for grad in grads if grad is not None)

        # Recurse: first-order check of the gradient function itself.
        gradcheck(func_high_order, inputs, modules, rtol=rtol, atol=atol, eps=eps, grad_params=True)


def _make_scalar_valued_func(func, inputs, modules):
    """Reduce a (possibly vector-valued) func to a scalar via a fixed random vjp vector."""
    outputs = func(inputs, modules)
    output_size = outputs.numel() if torch.is_tensor(outputs) else sum(o.numel() for o in outputs)

    if output_size > 1:
        # Define this outside `func_scalar_valued` so that random tensors are generated only once.
        grad_outputs = tuple(torch.randn_like(o) for o in outputs)

        def func_scalar_valued(inputs, modules):
            outputs = func(inputs, modules)
            return sum((output * grad_output).sum() for output, grad_output, in zip(outputs, grad_outputs))

        return func_scalar_valued

    return func
--------------------------------------------------------------------------------
/torchsde/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._brownian import (BaseBrownian, BrownianInterval, BrownianPath, BrownianTree, ReverseBrownian,
                        brownian_interval_like)
from ._core.adjoint import sdeint_adjoint
from ._core.base_sde import BaseSDE, SDEIto, SDEStratonovich
from ._core.sdeint import sdeint

# Wipe the type annotations from the public entry points' signatures,
# presumably so rendered docs/IDE help don't show the internal type aliases --
# NOTE(review): confirm the motivation.
for _public_callable in (BrownianInterval.__init__,
                         BrownianPath.__init__,
                         BrownianTree.__init__,
                         sdeint,
                         sdeint_adjoint):
    _public_callable.__annotations__ = {}
del _public_callable

__version__ = '0.2.6'
# --------------------------------------------------------------------------------
# torchsde/_brownian/__init__.py
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .brownian_base import BaseBrownian
from .brownian_interval import BrownianInterval
from .derived import ReverseBrownian, BrownianPath, BrownianTree, brownian_interval_like
# --------------------------------------------------------------------------------
# torchsde/_brownian/brownian_base.py
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc


class BaseBrownian(metaclass=abc.ABCMeta):
    """Abstract interface that every Brownian-motion sampler implements.

    Subclasses are queried via ``__call__``; the ``return_U``/``return_A``
    flags and the ``levy_area_approximation`` property suggest optional
    Levy-area statistics -- see `BrownianInterval` for the concrete semantics.
    """

    # No per-instance storage here; subclasses declare their own slots.
    __slots__ = ()

    @abc.abstractmethod
    def __call__(self, ta, tb=None, return_U=False, return_A=False):
        """Query the process over ``[ta, tb]`` (or at a single time ``ta``)."""
        raise NotImplementedError

    @abc.abstractmethod
    def __repr__(self):
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def dtype(self):
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def device(self):
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def shape(self):
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def levy_area_approximation(self):
        raise NotImplementedError

    def size(self):
        # Tensor-like alias for `shape`.
        return self.shape
# --------------------------------------------------------------------------------
# torchsde/_brownian/derived.py
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from . import brownian_base
from . import brownian_interval
from ..types import Optional, Scalar, Tensor, Tuple, Union


class ReverseBrownian(brownian_base.BaseBrownian):
    """Time-reversed view of another Brownian motion.

    A query over ``[ta, tb]`` is forwarded to the base Brownian motion over
    ``[-tb, -ta]``.
    """

    def __init__(self, base_brownian):
        super(ReverseBrownian, self).__init__()
        self.base_brownian = base_brownian

    def __call__(self, ta, tb=None, return_U=False, return_A=False):
        # Whether or not to negate the statistics depends on the return value of the adjoint SDE. Currently, the
        # adjoint returns negated drift and diffusion, so we don't negate here.
        # NOTE(review): a single-time query (tb=None) would negate None and
        # fail; callers appear to always pass both endpoints -- confirm.
        return self.base_brownian(-tb, -ta, return_U=return_U, return_A=return_A)

    def __repr__(self):
        return f"{self.__class__.__name__}(base_brownian={self.base_brownian})"

    @property
    def dtype(self):
        return self.base_brownian.dtype

    @property
    def device(self):
        return self.base_brownian.device

    @property
    def shape(self):
        return self.base_brownian.shape

    @property
    def levy_area_approximation(self):
        return self.base_brownian.levy_area_approximation


class BrownianPath(brownian_base.BaseBrownian):
    """Brownian path, storing every computed value.

    Useful for speed, when memory isn't a concern.

    To use:
    >>> bm = BrownianPath(t0=0.0, w0=torch.zeros(4, 1))
    >>> bm(0., 0.5)
    tensor([[ 0.0733],
            [-0.5692],
            [ 0.1872],
            [-0.3889]])
    """

    def __init__(self, t0: Scalar, w0: Tensor, window_size: int = 8):
        """Initialize Brownian path.

        Arguments:
            t0: Initial time.
            w0: Initial state.
            window_size: Unused; deprecated.
        """
        super(BrownianPath, self).__init__()
        self._w0 = w0
        # Delegate sampling to a cache-everything BrownianInterval over [t0, t0 + 1].
        self._interval = brownian_interval.BrownianInterval(t0=t0, t1=t0 + 1, size=w0.shape, dtype=w0.dtype,
                                                            device=w0.device, cache_size=None)

    def __call__(self, t, tb=None, return_U=False, return_A=False):
        # Deliberately called t rather than ta, for backward compatibility.
        value = self._interval(t, tb, return_U=return_U, return_A=return_A)
        if tb is None and not (return_U or return_A):
            # Single-time query: shift the increment by the initial state w0.
            value = value + self._w0
        return value

    def __repr__(self):
        return f"{self.__class__.__name__}(interval={self._interval})"

    @property
    def dtype(self):
        return self._interval.dtype

    @property
    def device(self):
        return self._interval.device

    @property
    def shape(self):
        return self._interval.shape

    @property
    def levy_area_approximation(self):
        return self._interval.levy_area_approximation


class BrownianTree(brownian_base.BaseBrownian):
    """Brownian tree with fixed entropy.

    Useful when the map from entropy -> Brownian motion shouldn't depend on the
    locations and order of the query points. (As the usual BrownianInterval
    does - note that BrownianTree is slower as a result though.)

    To use:
    >>> bm = BrownianTree(t0=0.0, w0=torch.zeros(4, 1))
    >>> bm(0., 0.5)
    tensor([[ 0.0733],
            [-0.5692],
            [ 0.1872],
            [-0.3889]])
    """

    def __init__(self, t0: Scalar,
                 w0: Tensor,
                 t1: Optional[Scalar] = None,
                 w1: Optional[Tensor] = None,
                 entropy: Optional[int] = None,
                 tol: float = 1e-6,
                 pool_size: int = 24,
                 cache_depth: int = 9,
                 safety: Optional[float] = None):
        """Initialize the Brownian tree.

        The random value generation process exploits the parallel random number paradigm and uses
        `numpy.random.SeedSequence`. The default generator is PCG64 (used by `default_rng`).

        Arguments:
            t0: Initial time.
            w0: Initial state.
            t1: Terminal time.
            w1: Terminal state.
            entropy: Global seed, defaults to `None` for random entropy.
            tol: Error tolerance before the binary search is terminated; the search depth ~ log2(tol).
            pool_size: Size of the pooled entropy. This parameter affects the query speed significantly.
            cache_depth: Unused; deprecated.
            safety: Unused; deprecated.
        """
        super(BrownianTree, self).__init__()
        end_time = t0 + 1 if t1 is None else t1
        # If a terminal state is given, pin the total increment W = w1 - w0.
        total_increment = None if w1 is None else w1 - w0
        self._w0 = w0
        self._interval = brownian_interval.BrownianInterval(t0=t0,
                                                            t1=end_time,
                                                            size=w0.shape,
                                                            dtype=w0.dtype,
                                                            device=w0.device,
                                                            entropy=entropy,
                                                            tol=tol,
                                                            pool_size=pool_size,
                                                            halfway_tree=True,
                                                            W=total_increment)

    def __call__(self, t, tb=None, return_U=False, return_A=False):
        # Deliberately called t rather than ta, for backward compatibility.
        value = self._interval(t, tb, return_U=return_U, return_A=return_A)
        if tb is None and not (return_U or return_A):
            # Single-time query: shift the increment by the initial state w0.
            value = value + self._w0
        return value

    def __repr__(self):
        return f"{self.__class__.__name__}(interval={self._interval})"

    @property
    def dtype(self):
        return self._interval.dtype

    @property
    def device(self):
        return self._interval.device

    @property
    def shape(self):
        return self._interval.shape

    @property
    def levy_area_approximation(self):
        return self._interval.levy_area_approximation


def brownian_interval_like(y: Tensor,
                           t0: Optional[Scalar] = 0.,
                           t1: Optional[Scalar] = 1.,
                           size: Optional[Tuple[int, ...]] = None,
                           dtype: Optional[torch.dtype] = None,
                           device: Optional[Union[str, torch.device]] = None,
                           **kwargs):
    """Returns a BrownianInterval object with the same size, device, and dtype as a given tensor."""
    if size is None:
        size = y.shape
    if dtype is None:
        dtype = y.dtype
    if device is None:
        device = y.device
    return brownian_interval.BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device, **kwargs)
# --------------------------------------------------------------------------------
# torchsde/_core/__init__.py
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------------
# torchsde/_core/adaptive_stepping.py
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from . import misc
from ..types import TensorOrTensors


def update_step_size(error_estimate, prev_step_size, safety=0.9, facmin=0.2, facmax=1.4, prev_error_ratio=None):
    """Adaptively propose the next step size based on estimated errors.

    A PI-controller-style update: `ifactor` is the integral exponent and
    `pfactor` the proportional exponent (dropped when the step was rejected).

    Args:
        error_estimate: Aggregated error of the last step; > 1 means rejected.
        prev_step_size: The step size that produced `error_estimate`.
        safety: Multiplicative safety margin on the target error.
        facmin: Smallest allowed shrink factor (only binding on rejection).
        facmax: Largest allowed growth factor.
        prev_error_ratio: `safety / error_estimate` of the last accepted step, if any.

    Returns:
        A pair (new_step_size, prev_error_ratio).
    """
    if error_estimate > 1:
        # Rejected step: no proportional term; shrink comparatively fast.
        pfactor, ifactor = 0, 1 / 1.5  # 1 / 5
    else:
        pfactor, ifactor = 0.13, 1 / 4.5  # 1 / 15

    error_ratio = safety / error_estimate
    if prev_error_ratio is None:
        prev_error_ratio = error_ratio
    factor = error_ratio ** ifactor * (error_ratio / prev_error_ratio) ** pfactor
    if error_estimate <= 1:
        # Accepted step: remember the ratio, and never shrink the step size.
        prev_error_ratio = error_ratio
        facmin = 1.0
    factor = min(facmax, max(facmin, factor))
    return prev_step_size * factor, prev_error_ratio


def compute_error(y11: TensorOrTensors, y12: TensorOrTensors, rtol, atol, eps=1e-7):
    """Compute the error estimate between a full step and two half steps.

    Args:
        y11: A tensor or a sequence of tensors obtained with a full update.
        y12: A tensor or a sequence of tensors obtained with two half updates.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        eps: A small constant to avoid division by zero.

    Returns:
        A float for the aggregated error estimate.
    """
    if torch.is_tensor(y11):
        y11 = (y11,)
    if torch.is_tensor(y12):
        y12 = (y12,)
    # Per-entry tolerance: atol + rtol * max(|y11|, |y12|), floored at eps.
    tolerances = [
        (rtol * torch.max(y11_.abs(), y12_.abs()) + atol).clamp_min(eps)
        for y11_, y12_ in zip(y11, y12)
    ]
    scaled_diffs = [(y11_ - y12_) / tol_ for y11_, y12_, tol_ in zip(y11, y12, tolerances)]
    error_estimate = _rms(scaled_diffs, eps)
    assert not misc.is_nan(error_estimate), (
        'Found nans in the error estimate. Try increasing the tolerance or regularizing the dynamics.'
    )
    return error_estimate.detach().cpu().item()


def _rms(x, eps=1e-7):
    """Root-mean-square of a tensor (or sequence of tensors), floored at `eps`."""
    if torch.is_tensor(x):
        return torch.sqrt((x ** 2.).sum() / x.numel()).clamp_min(eps)
    total = sum((x_ ** 2.).sum() for x_ in x)
    count = sum(x_.numel() for x_ in x)
    return torch.sqrt(total / count).clamp_min(eps)
# --------------------------------------------------------------------------------
# torchsde/_core/base_sde.py
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc

import torch
from torch import nn

from . import misc
from ..settings import NOISE_TYPES, SDE_TYPES
from ..types import Tensor


class BaseSDE(abc.ABC, nn.Module):
    """Base class for all SDEs.

    Inheriting from this class ensures `noise_type` and `sde_type` are valid attributes, which the solver depends on.
    """

    def __init__(self, noise_type, sde_type):
        super(BaseSDE, self).__init__()
        if noise_type not in NOISE_TYPES:
            raise ValueError(f"Expected noise type in {NOISE_TYPES}, but found {noise_type}")
        if sde_type not in SDE_TYPES:
            raise ValueError(f"Expected sde type in {SDE_TYPES}, but found {sde_type}")
        # Making these Python properties breaks `torch.jit.script`.
        self.noise_type = noise_type
        self.sde_type = sde_type


class ForwardSDE(BaseSDE):
    """Wraps a user SDE and fills in whichever of f/g/f_and_g/g_prod/... it didn't define.

    The dispatch happens once at construction time, so each solver step pays
    only a plain attribute call.
    """

    def __init__(self, sde, fast_dg_ga_jvp_column_sum=False):
        super(ForwardSDE, self).__init__(sde_type=sde.sde_type, noise_type=sde.noise_type)
        self._base_sde = sde

        # Register the core functions. This avoids polluting the codebase with if-statements and achieves speed-ups
        # by making sure it's a one-time cost.
        if hasattr(sde, 'f_and_g_prod'):
            self.f_and_g_prod = sde.f_and_g_prod
        elif hasattr(sde, 'f') and hasattr(sde, 'g_prod'):
            self.f_and_g_prod = self.f_and_g_prod_default1
        else:  # (f_and_g,) or (f, g,).
            self.f_and_g_prod = self.f_and_g_prod_default2

        self.f = getattr(sde, 'f', self.f_default)
        self.g = getattr(sde, 'g', self.g_default)
        self.f_and_g = getattr(sde, 'f_and_g', self.f_and_g_default)
        self.g_prod = getattr(sde, 'g_prod', self.g_prod_default)
        self.prod = {
            NOISE_TYPES.diagonal: self.prod_diagonal
        }.get(sde.noise_type, self.prod_default)
        self.g_prod_and_gdg_prod = {
            NOISE_TYPES.diagonal: self.g_prod_and_gdg_prod_diagonal,
            NOISE_TYPES.additive: self.g_prod_and_gdg_prod_additive,
        }.get(sde.noise_type, self.g_prod_and_gdg_prod_default)
        self.dg_ga_jvp_column_sum = {
            NOISE_TYPES.general: (
                self.dg_ga_jvp_column_sum_v2 if fast_dg_ga_jvp_column_sum else self.dg_ga_jvp_column_sum_v1
            )
        }.get(sde.noise_type, self._return_zero)

    def f_default(self, t, y):
        # Reached only when the user SDE defines neither `f` nor a substitute.
        raise RuntimeError("Method `f` has not been provided, but is required for this method.")

    def g_default(self, t, y):
        raise RuntimeError("Method `g` has not been provided, but is required for this method.")

    def f_and_g_default(self, t, y):
        # Fallback: evaluate drift and diffusion separately.
        return self.f(t, y), self.g(t, y)

    def prod_diagonal(self, g, v):
        # Diagonal noise: the diffusion is a vector, so the product is elementwise.
        return g * v

    def prod_default(self, g, v):
        # General noise: batched matrix-vector product.
        return misc.batch_mvp(g, v)

    def g_prod_default(self, t, y, v):
        return self.prod(self.g(t, y), v)

    def f_and_g_prod_default1(self, t, y, v):
        # The user supplied `f` and `g_prod` separately.
        return self.f(t, y), self.g_prod(t, y, v)

    def f_and_g_prod_default2(self, t, y, v):
        # The user supplied `f_and_g` (or `f` and `g`).
        f, g = self.f_and_g(t, y)
        return f, self.prod(g, v)

    # Computes: g_prod and sum_{j, l} g_{j, l} d g_{j, l} d x_i v2_l.
    def g_prod_and_gdg_prod_default(self, t, y, v1, v2):
        build_graph = torch.is_grad_enabled()
        with torch.enable_grad():
            y = y if y.requires_grad else y.detach().requires_grad_(True)
            g = self.g(t, y)
            vg_dg_vjp, = misc.vjp(
                outputs=g,
                inputs=y,
                grad_outputs=g * v2.unsqueeze(-2),
                retain_graph=True,
                create_graph=build_graph,
                allow_unused=True
            )
            return self.prod(g, v1), vg_dg_vjp

    def g_prod_and_gdg_prod_diagonal(self, t, y, v1, v2):
        build_graph = torch.is_grad_enabled()
        with torch.enable_grad():
            y = y if y.requires_grad else y.detach().requires_grad_(True)
            g = self.g(t, y)
            vg_dg_vjp, = misc.vjp(
                outputs=g,
                inputs=y,
                grad_outputs=g * v2,
                retain_graph=True,
                create_graph=build_graph,
                allow_unused=True
            )
            return self.prod(g, v1), vg_dg_vjp

    def g_prod_and_gdg_prod_additive(self, t, y, v1, v2):
        # Additive noise: g doesn't depend on y, so the correction term vanishes.
        return self.g_prod(t, y, v1), 0.

    # Computes: sum_{j,k,l} d g_{i,l} / d x_j g_{j,k} A_{k,l}.
    def dg_ga_jvp_column_sum_v1(self, t, y, a):
        build_graph = torch.is_grad_enabled()
        with torch.enable_grad():
            y = y if y.requires_grad else y.detach().requires_grad_(True)
            g = self.g(t, y)
            ga = torch.bmm(g, a)
            # One jvp per Brownian-motion column, then summed.
            column_jvps = [
                misc.jvp(
                    outputs=g[..., col_idx],
                    inputs=y,
                    grad_inputs=ga[..., col_idx],
                    retain_graph=True,
                    create_graph=build_graph,
                    allow_unused=True
                )[0]
                for col_idx in range(g.size(-1))
            ]
            return sum(column_jvps)

    def dg_ga_jvp_column_sum_v2(self, t, y, a):
        # Faster, but more memory intensive: batch all columns into a single
        # jvp by duplicating y once per Brownian-motion column.
        build_graph = torch.is_grad_enabled()
        with torch.enable_grad():
            y = y if y.requires_grad else y.detach().requires_grad_(True)
            g = self.g(t, y)
            ga = torch.bmm(g, a)

            batch_size, d, m = g.size()
            y_dup = torch.repeat_interleave(y, repeats=m, dim=0)
            g_dup = self.g(t, y_dup)
            ga_flat = ga.transpose(1, 2).flatten(0, 1)
            dg_ga_jvp, = misc.jvp(
                outputs=g_dup,
                inputs=y_dup,
                grad_inputs=ga_flat,
                create_graph=build_graph,
                allow_unused=True
            )
            dg_ga_jvp = dg_ga_jvp.reshape(batch_size, m, d, m).permute(0, 2, 1, 3)
            dg_ga_jvp = dg_ga_jvp.diagonal(dim1=-2, dim2=-1).sum(-1)
            return dg_ga_jvp

    def _return_zero(self, t, y, v):  # noqa
        return 0.


class RenameMethodsSDE(BaseSDE):
    """Adapter exposing a user SDE's methods under the canonical names f/g/h/...

    Names the user object doesn't provide are simply skipped.
    """

    def __init__(self, sde, drift='f', diffusion='g', prior_drift='h', diffusion_prod='g_prod',
                 drift_and_diffusion='f_and_g', drift_and_diffusion_prod='f_and_g_prod'):
        super(RenameMethodsSDE, self).__init__(noise_type=sde.noise_type, sde_type=sde.sde_type)
        self._base_sde = sde
        canonical_names = ('f', 'g', 'h', 'g_prod', 'f_and_g', 'f_and_g_prod')
        user_names = (drift, diffusion, prior_drift, diffusion_prod, drift_and_diffusion, drift_and_diffusion_prod)
        for canonical, user in zip(canonical_names, user_names):
            try:
                setattr(self, canonical, getattr(sde, user))
            except AttributeError:
                pass


class SDEIto(BaseSDE):
    """Convenience base class for Ito SDEs."""

    def __init__(self, noise_type):
        super(SDEIto, self).__init__(noise_type=noise_type, sde_type=SDE_TYPES.ito)


class SDEStratonovich(BaseSDE):
    """Convenience base class for Stratonovich SDEs."""

    def __init__(self, noise_type):
        super(SDEStratonovich, self).__init__(noise_type=noise_type, sde_type=SDE_TYPES.stratonovich)


# --- Backwards compatibility: v0.1.1.
--- 240 | class SDELogqp(BaseSDE): 241 | 242 | def __init__(self, sde): 243 | super(SDELogqp, self).__init__(noise_type=sde.noise_type, sde_type=sde.sde_type) 244 | self._base_sde = sde 245 | 246 | # Make this redirection a one-time cost. 247 | try: 248 | self._base_f = sde.f 249 | self._base_g = sde.g 250 | self._base_h = sde.h 251 | except AttributeError as e: 252 | # TODO: relax this requirement, and use f_and_g, f_and_g_prod, f_and_g_and_h and f_and_g_prod_and_h if 253 | # they're available. 254 | raise AttributeError("If using logqp then drift, diffusion and prior drift must all be specified.") from e 255 | 256 | # Make this method selection a one-time cost. 257 | if sde.noise_type == NOISE_TYPES.diagonal: 258 | self.f = self.f_diagonal 259 | self.g = self.g_diagonal 260 | self.f_and_g = self.f_and_g_diagonal 261 | else: 262 | self.f = self.f_general 263 | self.g = self.g_general 264 | self.f_and_g = self.f_and_g_general 265 | 266 | def f_diagonal(self, t, y: Tensor): 267 | y = y[:, :-1] 268 | f, g, h = self._base_f(t, y), self._base_g(t, y), self._base_h(t, y) 269 | u = misc.stable_division(f - h, g) 270 | f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True) 271 | return torch.cat([f, f_logqp], dim=1) 272 | 273 | def g_diagonal(self, t, y: Tensor): 274 | y = y[:, :-1] 275 | g = self._base_g(t, y) 276 | g_logqp = y.new_zeros(size=(y.size(0), 1)) 277 | return torch.cat([g, g_logqp], dim=1) 278 | 279 | def f_and_g_diagonal(self, t, y: Tensor): 280 | y = y[:, :-1] 281 | f, g, h = self._base_f(t, y), self._base_g(t, y), self._base_h(t, y) 282 | u = misc.stable_division(f - h, g) 283 | f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True) 284 | g_logqp = y.new_zeros(size=(y.size(0), 1)) 285 | return torch.cat([f, f_logqp], dim=1), torch.cat([g, g_logqp], dim=1) 286 | 287 | def f_general(self, t, y: Tensor): 288 | y = y[:, :-1] 289 | f, g, h = self._base_f(t, y), self._base_g(t, y), self._base_h(t, y) 290 | u = misc.batch_mvp(g.pinverse(), f - h) # (batch_size, 
brownian_size). 291 | f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True) 292 | return torch.cat([f, f_logqp], dim=1) 293 | 294 | def g_general(self, t, y: Tensor): 295 | y = y[:, :-1] 296 | g = self._base_sde.g(t, y) 297 | g_logqp = y.new_zeros(size=(g.size(0), 1, g.size(-1))) 298 | return torch.cat([g, g_logqp], dim=1) 299 | 300 | def f_and_g_general(self, t, y: Tensor): 301 | y = y[:, :-1] 302 | f, g, h = self._base_f(t, y), self._base_g(t, y), self._base_h(t, y) 303 | u = misc.batch_mvp(g.pinverse(), f - h) # (batch_size, brownian_size). 304 | f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True) 305 | g_logqp = y.new_zeros(size=(g.size(0), 1, g.size(-1))) 306 | return torch.cat([f, f_logqp], dim=1), torch.cat([g, g_logqp], dim=1) 307 | # ---------------------------------------- 308 | -------------------------------------------------------------------------------- /torchsde/_core/base_solver.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import abc 16 | import warnings 17 | 18 | import torch 19 | 20 | from . import adaptive_stepping 21 | from . import better_abc 22 | from . 
class BaseSDESolver(metaclass=better_abc.ABCMeta):
    """API for solvers with possibly adaptive time stepping.

    Concrete solvers must set the abstract attributes below and implement
    `step`. `integrate` drives the solve, optionally with adaptive step-size
    control based on step doubling.
    """

    strong_order = better_abc.abstract_attribute()
    weak_order = better_abc.abstract_attribute()
    sde_type = better_abc.abstract_attribute()
    noise_types = better_abc.abstract_attribute()
    levy_area_approximations = better_abc.abstract_attribute()

    def __init__(self,
                 sde: BaseSDE,
                 bm: BaseBrownian,
                 dt: Scalar,
                 adaptive: bool,
                 rtol: Scalar,
                 atol: Scalar,
                 dt_min: Scalar,
                 options: Dict,
                 **kwargs):
        super(BaseSDESolver, self).__init__(**kwargs)
        # Validate solver/SDE/Brownian-motion compatibility up front, so a
        # misconfiguration fails at construction time rather than mid-solve.
        if sde.sde_type != self.sde_type:
            raise ValueError(f"SDE is of type {sde.sde_type} but solver is for type {self.sde_type}")
        if sde.noise_type not in self.noise_types:
            raise ValueError(f"SDE has noise type {sde.noise_type} but solver only supports noise types "
                             f"{self.noise_types}")
        if bm.levy_area_approximation not in self.levy_area_approximations:
            raise ValueError(f"SDE solver requires one of {self.levy_area_approximations} set as the "
                             f"`levy_area_approximation` on the Brownian motion.")
        if sde.noise_type == NOISE_TYPES.scalar and torch.Size(bm.shape[1:]).numel() != 1:  # noqa
            # Fixed error-message grammar ("must of" -> "must be of").
            raise ValueError("The Brownian motion for scalar SDEs must be of dimension 1.")

        self.sde = sde
        self.bm = bm
        self.dt = dt
        self.adaptive = adaptive
        self.rtol = rtol
        self.atol = atol
        self.dt_min = dt_min
        self.options = options

    def __repr__(self):
        return f"{self.__class__.__name__} of strong order: {self.strong_order}, and weak order: {self.weak_order}"

    def init_extra_solver_state(self, t0, y0) -> Tensors:
        """Return any extra solver state needed at `t0`; defaults to none."""
        return ()

    @abc.abstractmethod
    def step(self, t0: Scalar, t1: Scalar, y0: Tensor, extra0: Tensors) -> Tuple[Tensor, Tensors]:
        """Propose a step with step size from time t to time next_t, with
        current state y.

        Args:
            t0: float or Tensor of size (,).
            t1: float or Tensor of size (,).
            y0: Tensor of size (batch_size, d).
            extra0: Any extra state for the solver.

        Returns:
            y1, where y1 is a Tensor of size (batch_size, d).
            extra1: Modified extra state for the solver.
        """
        raise NotImplementedError

    def integrate(self, y0: Tensor, ts: Tensor, extra0: Tensors) -> Tuple[Tensor, Tensors]:
        """Integrate along trajectory.

        Args:
            y0: Tensor of size (batch_size, d)
            ts: Tensor of size (T,).
            extra0: Any extra state for the solver.

        Returns:
            ys, where ys is a Tensor of size (T, batch_size, d).
            extra_solver_state, which is a tuple of Tensors of shape (T, ...), where ... is arbitrary and
                solver-dependent.
        """
        step_size = self.dt

        prev_t = curr_t = ts[0]
        prev_y = curr_y = y0
        curr_extra = extra0

        ys = [y0]
        prev_error_ratio = None

        for out_t in ts[1:]:
            # Advance the internal state until it has stepped past `out_t`,
            # then interpolate back to `out_t` for the output.
            while curr_t < out_t:
                next_t = min(curr_t + step_size, ts[-1])
                if self.adaptive:
                    # Take 1 full step.
                    next_y_full, _ = self.step(curr_t, next_t, curr_y, curr_extra)
                    # Take 2 half steps.
                    midpoint_t = 0.5 * (curr_t + next_t)
                    midpoint_y, midpoint_extra = self.step(curr_t, midpoint_t, curr_y, curr_extra)
                    next_y, next_extra = self.step(midpoint_t, next_t, midpoint_y, midpoint_extra)

                    # Estimate error based on difference between 1 full step and 2 half steps.
                    with torch.no_grad():
                        error_estimate = adaptive_stepping.compute_error(next_y_full, next_y, self.rtol, self.atol)
                        step_size, prev_error_ratio = adaptive_stepping.update_step_size(
                            error_estimate=error_estimate,
                            prev_step_size=step_size,
                            prev_error_ratio=prev_error_ratio
                        )

                    if step_size < self.dt_min:
                        warnings.warn("Hitting minimum allowed step size in adaptive time-stepping.")
                        step_size = self.dt_min
                        prev_error_ratio = None

                    # Accept the step if the error is small enough or we are already
                    # clamped at the minimum step size; otherwise retry with the
                    # smaller step size computed above.
                    if error_estimate <= 1 or step_size <= self.dt_min:
                        prev_t, prev_y = curr_t, curr_y
                        curr_t, curr_y, curr_extra = next_t, next_y, next_extra
                else:
                    prev_t, prev_y = curr_t, curr_y
                    curr_y, curr_extra = self.step(curr_t, next_t, curr_y, curr_extra)
                    curr_t = next_t
            ys.append(interp.linear_interp(t0=prev_t, y0=prev_y, t1=curr_t, y1=curr_y, t=out_t))

        return torch.stack(ys, dim=0), curr_extra
class DummyAttribute:
    """Placeholder object standing in for a not-yet-provided attribute."""
    pass


def abstract_attribute(obj=None):
    """Mark `obj` (or a fresh placeholder) as an abstract instance attribute.

    Instances of classes using the `ABCMeta` below must overwrite every
    attribute carrying this mark before `__init__` returns.
    """
    marker = DummyAttribute() if obj is None else obj
    marker.__is_abstract_attribute__ = True
    return marker


class ABCMeta(abc.ABCMeta):
    """`abc.ABCMeta` extended to also enforce abstract instance *attributes*.

    Adapted from https://stackoverflow.com/a/50381071/12254339.
    """

    def __call__(cls, *args, **kwargs):
        instance = super(ABCMeta, cls).__call__(*args, **kwargs)
        # After construction, any attribute still carrying the abstract marker
        # was never overwritten by the subclass.
        missing = {
            name for name in dir(instance)
            if getattr(getattr(instance, name), '__is_abstract_attribute__', False)
        }
        if missing:
            raise NotImplementedError(
                "Can't instantiate abstract class {} with"
                " abstract attributes: {}".format(cls.__name__, ', '.join(missing))
            )
        return instance
def linear_interp(t0, y0, t1, y1, t):
    """Linearly interpolate between (t0, y0) and (t1, y1) at time t.

    Args:
        t0, t1: Scalars (floats or 0-dim tensors) with t0 <= t <= t1.
        y0, y1: States at times t0 and t1 respectively.
        t: Query time.

    Returns:
        The linearly interpolated state at time t.
    """
    assert t0 <= t <= t1, f"Incorrect time order for linear interpolation: t0={t0}, t={t}, t1={t1}."
    if t0 == t1:
        # Degenerate zero-length interval (t0 == t == t1): the two endpoint
        # states describe the same time point, so return y1 directly instead
        # of dividing by zero below.
        return y1
    y = (t1 - t) / (t1 - t0) * y0 + (t - t0) / (t1 - t0) * y1
    return y
class Euler(base_solver.BaseSDESolver):
    """Euler--Maruyama scheme for Ito SDEs."""

    weak_order = 1.0
    sde_type = SDE_TYPES.ito
    noise_types = NOISE_TYPES.all()
    levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()

    def __init__(self, sde, **kwargs):
        # The strong order improves to 1.0 when the noise is additive.
        self.strong_order = 1.0 if sde.noise_type == NOISE_TYPES.additive else 0.5
        super(Euler, self).__init__(sde=sde, **kwargs)

    def step(self, t0, t1, y0, extra0):
        del extra0  # This solver carries no extra state.
        dW = self.bm(t0, t1)  # Brownian increment over [t0, t1].

        drift, diffusion_inc = self.sde.f_and_g_prod(t0, y0, dW)

        y1 = y0 + drift * (t1 - t0) + diffusion_inc
        return y1, ()
class EulerHeun(base_solver.BaseSDESolver):
    """Euler--Heun scheme for Stratonovich SDEs (predictor-corrector in the diffusion only)."""

    weak_order = 1.0
    sde_type = SDE_TYPES.stratonovich
    noise_types = NOISE_TYPES.all()
    levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()

    def __init__(self, sde, **kwargs):
        # General (non-commutative) noise drops the strong order to 0.5.
        self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0
        super(EulerHeun, self).__init__(sde=sde, **kwargs)

    def step(self, t0, t1, y0, extra0):
        del extra0  # No extra solver state.
        dt = t1 - t0
        dW = self.bm(t0, t1)

        drift, diff0 = self.sde.f_and_g_prod(t0, y0, dW)

        # Predictor: advance by the diffusion increment alone, then
        # re-evaluate the diffusion there.
        predictor = y0 + diff0
        diff1 = self.sde.g_prod(t1, predictor, dW)

        # Corrector: explicit drift plus the averaged diffusion increments.
        y1 = y0 + dt * drift + 0.5 * (diff0 + diff1)

        return y1, ()
class Heun(base_solver.BaseSDESolver):
    """Stratonovich Heun scheme: a full explicit predictor step followed by
    trapezoidal averaging of drift and diffusion increments.
    """

    weak_order = 1.0
    sde_type = SDE_TYPES.stratonovich
    noise_types = NOISE_TYPES.all()
    levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()

    def __init__(self, sde, **kwargs):
        # General (non-commutative) noise drops the strong order to 0.5.
        self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0
        super(Heun, self).__init__(sde=sde, **kwargs)

    def step(self, t0, t1, y0, extra0):
        del extra0  # No extra solver state.
        dt = t1 - t0
        dW = self.bm(t0, t1)

        drift0, diff0 = self.sde.f_and_g_prod(t0, y0, dW)

        # Predictor: one full explicit Euler step.
        predictor = y0 + dt * drift0 + diff0

        drift1, diff1 = self.sde.f_and_g_prod(t1, predictor, dW)

        # Corrector: trapezoidal average over both endpoints.
        y1 = y0 + (dt * (drift0 + drift1) + diff0 + diff1) * 0.5

        return y1, ()
class LogODEMidpoint(base_solver.BaseSDESolver):
    """Log-ODE scheme: Lie--Trotter splitting combined with the explicit
    midpoint method. Requires a Levy-area approximation on the Brownian motion.
    """

    weak_order = 1.0
    sde_type = SDE_TYPES.stratonovich
    noise_types = NOISE_TYPES.all()
    levy_area_approximations = (LEVY_AREA_APPROXIMATIONS.davie, LEVY_AREA_APPROXIMATIONS.foster)

    def __init__(self, sde, **kwargs):
        if isinstance(sde, adjoint_sde.AdjointSDE):
            raise ValueError("Log-ODE schemes cannot be used for adjoint SDEs, because they require "
                             "direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
                             "diffusion-vector product. Use a different method instead.")
        # General (non-commutative) noise drops the strong order to 0.5.
        self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0
        super(LogODEMidpoint, self).__init__(sde=sde, **kwargs)

    def step(self, t0, t1, y0, extra0):
        del extra0  # No extra solver state.
        dt = t1 - t0
        # Brownian increment and its Levy-area approximation over [t0, t1].
        dW, levy_area = self.bm(t0, t1, return_A=True)

        drift, diff = self.sde.f_and_g_prod(t0, y0, dW)

        # Explicit midpoint: evaluate drift/diffusion at the half-step point.
        t_mid = t0 + 0.5 * dt
        y_mid = y0 + 0.5 * dt * drift + 0.5 * diff

        drift_mid, diff_mid = self.sde.f_and_g_prod(t_mid, y_mid, dW)
        area_term = self.sde.dg_ga_jvp_column_sum(t_mid, y_mid, levy_area)

        y1 = y0 + dt * drift_mid + diff_mid + area_term

        return y1, ()
class Midpoint(base_solver.BaseSDESolver):
    """Explicit midpoint scheme for Stratonovich SDEs."""

    weak_order = 1.0
    sde_type = SDE_TYPES.stratonovich
    noise_types = NOISE_TYPES.all()
    levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()

    def __init__(self, sde, **kwargs):
        # General (non-commutative) noise drops the strong order to 0.5.
        self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0
        super(Midpoint, self).__init__(sde=sde, **kwargs)

    def step(self, t0, t1, y0, extra0):
        del extra0  # No extra solver state.
        dt = t1 - t0
        dW = self.bm(t0, t1)

        drift, diff = self.sde.f_and_g_prod(t0, y0, dW)

        # Evaluate drift and diffusion at the half-step point, then take the
        # full step using those midpoint values.
        t_mid = t0 + 0.5 * dt
        y_mid = y0 + 0.5 * dt * drift + 0.5 * diff

        drift_mid, diff_mid = self.sde.f_and_g_prod(t_mid, y_mid, dW)

        y1 = y0 + dt * drift_mid + diff_mid

        return y1, ()
class BaseMilstein(base_solver.BaseSDESolver, metaclass=abc.ABCMeta):
    """Shared machinery for the Ito and Stratonovich Milstein schemes.

    Subclasses supply `v_term` (the squared-increment term, which differs
    between the two calculi) and `y_prime_f_factor` (the drift contribution
    to the auxiliary state used by the derivative-free variant).
    """

    strong_order = 1.0
    weak_order = 1.0
    noise_types = (NOISE_TYPES.additive, NOISE_TYPES.diagonal, NOISE_TYPES.scalar)
    levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()

    def __init__(self, sde, options, **kwargs):
        # Default to the derivative-using variant.
        options.setdefault(METHOD_OPTIONS.grad_free, False)
        if options[METHOD_OPTIONS.grad_free] and sde.noise_type == NOISE_TYPES.additive:
            # dg=0 in this case, and gdg_prod is already setup to handle that, whilst the grad_free code path isn't.
            options[METHOD_OPTIONS.grad_free] = False
        if options[METHOD_OPTIONS.grad_free] and isinstance(sde, adjoint_sde.AdjointSDE):
            # We need access to the diffusion to do things grad-free.
            raise ValueError(f"Derivative-free Milstein cannot be used for adjoint SDEs, because it requires "
                             f"direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
                             f"diffusion-vector product. Use derivative-using Milstein instead: "
                             f"`adjoint_options=dict({METHOD_OPTIONS.grad_free}=False)`")
        super(BaseMilstein, self).__init__(sde=sde, options=options, **kwargs)

    @abc.abstractmethod
    def v_term(self, I_k, dt):
        raise NotImplementedError

    @abc.abstractmethod
    def y_prime_f_factor(self, dt, f):
        raise NotImplementedError

    def step(self, t0, t1, y0, extra0):
        del extra0  # No extra solver state.
        dt = t1 - t0
        dW = self.bm(t0, t1)
        v = self.v_term(dW, dt)

        if self.options[METHOD_OPTIONS.grad_free]:
            # Derivative-free variant: approximate the g*dg/dy correction by a
            # finite difference of g at an auxiliary state.
            f, g = self.sde.f_and_g(t0, y0)
            g_flat = g.squeeze(2) if g.dim() == 3 else g  # scalar noise vs diagonal noise
            sqrt_dt = dt.sqrt()
            # TODO (kept from original review note): this y_prime_f_factor looks
            # unnecessary — the Taylor expansion comes out correct with or without
            # it, and no reference clearly explains why it is sometimes included.
            y_aux = y0 + self.y_prime_f_factor(dt, f) + g_flat * sqrt_dt
            g_aux = self.sde.g(t0, y_aux)
            g_prod_I_k = self.sde.prod(g, dW)
            gdg_prod = self.sde.prod(g_aux - g, v) / (2 * sqrt_dt)
        else:
            f = self.sde.f(t0, y0)
            g_prod_I_k, gdg_prod = self.sde.g_prod_and_gdg_prod(t0, y0, dW, 0.5 * v)

        y1 = y0 + f * dt + g_prod_I_k + gdg_prod

        return y1, ()


class MilsteinIto(BaseMilstein):
    """Milstein scheme in the Ito calculus."""

    sde_type = SDE_TYPES.ito

    def v_term(self, I_k, dt):
        # Ito correction term: has mean zero.
        return I_k ** 2 - dt

    def y_prime_f_factor(self, dt, f):
        return dt * f


class MilsteinStratonovich(BaseMilstein):
    """Milstein scheme in the Stratonovich calculus."""

    sde_type = SDE_TYPES.stratonovich

    def v_term(self, I_k, dt):
        return I_k ** 2

    def y_prime_f_factor(self, dt, f):
        return 0.
class ReversibleHeun(base_solver.BaseSDESolver):
    """Algebraically reversible Heun method (https://arxiv.org/abs/2105.13493).

    Carries extra state (f, g, z) such that the step is invertible: the input
    (y0, f0, g0, z0) can be reconstructed exactly from the output
    (y1, f1, g1, z1). `AdjointReversibleHeun` below exploits this for exact
    backpropagation.
    """

    weak_order = 1.0
    sde_type = SDE_TYPES.stratonovich
    noise_types = NOISE_TYPES.all()
    levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()

    def __init__(self, sde, **kwargs):
        # Strong order improves to 1.0 for additive noise.
        self.strong_order = 1.0 if sde.noise_type == NOISE_TYPES.additive else 0.5
        super(ReversibleHeun, self).__init__(sde=sde, **kwargs)

    def init_extra_solver_state(self, t0, y0):
        # Extra state is (f(t0, y0), g(t0, y0), y0).
        return self.sde.f_and_g(t0, y0) + (y0,)

    def step(self, t0, t1, y0, extra0):
        f0, g0, z0 = extra0
        # f is a drift-like quantity
        # g is a diffusion-like quantity
        # z is a state-like quantity (like y)
        dt = t1 - t0
        dW = self.bm(t0, t1)

        # Auxiliary state: a "leapfrog"-style extrapolation through y0,
        # advanced with the previous step's drift/diffusion evaluations.
        z1 = 2 * y0 - z0 + f0 * dt + self.sde.prod(g0, dW)
        f1, g1 = self.sde.f_and_g(t1, z1)
        # Heun-style average of old and new drift/diffusion evaluations.
        y1 = y0 + (f0 + f1) * (0.5 * dt) + self.sde.prod(g0 + g1, 0.5 * dW)

        return y1, (f1, g1, z1)


class AdjointReversibleHeun(base_solver.BaseSDESolver):
    """Adjoint (backward) companion to `ReversibleHeun`.

    Inverts the forward step exactly to recover the forward state, rebuilds a
    local computation graph, and accumulates vector-Jacobian products into the
    adjoint variables. Only usable as `adjoint_method` with an `AdjointSDE`.
    """

    weak_order = 1.0
    sde_type = SDE_TYPES.stratonovich
    noise_types = NOISE_TYPES.all()
    levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()

    def __init__(self, sde, **kwargs):
        if not isinstance(sde, adjoint_sde.AdjointSDE):
            raise ValueError(f"{METHODS.adjoint_reversible_heun} can only be used for adjoint_method.")
        self.strong_order = 1.0 if sde.noise_type == NOISE_TYPES.additive else 0.5
        super(AdjointReversibleHeun, self).__init__(sde=sde, **kwargs)
        self.forward_sde = sde.forward_sde

        # Adjoint of `prod`: diagonal noise multiplies elementwise, otherwise
        # form the outer product over the last two dimensions.
        if self.forward_sde.noise_type == NOISE_TYPES.diagonal:
            self._adjoint_of_prod = lambda tensor1, tensor2: tensor1 * tensor2
        else:
            self._adjoint_of_prod = lambda tensor1, tensor2: tensor1.unsqueeze(-1) * tensor2.unsqueeze(-2)

    def init_extra_solver_state(self, t0, y0):
        # We expect to always be given the extra state from the forward pass.
        raise RuntimeError("Please report a bug to torchsde.")

    def step(self, t0, t1, y0, extra0):
        forward_f0, forward_g0, forward_z0 = extra0
        dt = t1 - t0
        dW = self.bm(t0, t1)
        half_dt = 0.5 * dt
        half_dW = 0.5 * dW
        # Unpack flattened adjoint state: forward state, adjoint wrt y, then
        # adjoints wrt (f, g, z) and the SDE parameters.
        forward_y0, adj_y0, (adj_f0, adj_g0, adj_z0, *adj_params), requires_grad = self.sde.get_state(t0, y0,
                                                                                                      extra_states=True)
        adj_y0_half_dt = adj_y0 * half_dt
        adj_y0_half_dW = self._adjoint_of_prod(adj_y0, half_dW)

        # Invert the forward step to recover the forward auxiliary state.
        forward_z1 = 2 * forward_y0 - forward_z0 - forward_f0 * dt - self.forward_sde.prod(forward_g0, dW)

        adj_y1 = adj_y0
        adj_f1 = adj_y0_half_dt
        adj_f0 = adj_f0 + adj_y0_half_dt
        adj_g1 = adj_y0_half_dW
        adj_g0 = adj_g0 + adj_y0_half_dW

        # TODO: efficiency. It should be possible to make one fewer forward call by re-using the forward computation
        # in the previous step.
        with torch.enable_grad():
            if not forward_z0.requires_grad:
                forward_z0 = forward_z0.detach().requires_grad_()
            # NOTE(review): times are negated when calling the forward SDE here;
            # this appears to be how the adjoint pass runs the forward dynamics
            # in reversed time — confirm against adjoint_sde.py.
            re_forward_f0, re_forward_g0 = self.forward_sde.f_and_g(-t0, forward_z0)

            vjp_z, *vjp_params = misc.vjp(outputs=(re_forward_f0, re_forward_g0),
                                          inputs=[forward_z0] + self.sde.params,
                                          grad_outputs=[adj_f0, adj_g0],
                                          allow_unused=True,
                                          retain_graph=True,
                                          create_graph=requires_grad)
        adj_z0 = adj_z0 + vjp_z
        adj_params = misc.seq_add(adj_params, vjp_params)

        forward_f1, forward_g1 = self.forward_sde.f_and_g(-t1, forward_z1)
        forward_y1 = forward_y0 - (forward_f0 + forward_f1) * half_dt - self.forward_sde.prod(forward_g0 + forward_g1,
                                                                                              half_dW)

        adj_y1 = adj_y1 + 2 * adj_z0
        adj_z1 = -adj_z0
        adj_f1 = adj_f1 + adj_z0 * dt
        adj_g1 = adj_g1 + self._adjoint_of_prod(adj_z0, dW)

        # Re-flatten everything back into the single adjoint state vector.
        y1 = misc.flatten([forward_y1, adj_y1, adj_f1, adj_g1, adj_z1] + adj_params).unsqueeze(0)

        return y1, (forward_f1, forward_g1, forward_z1)
class SRK(base_solver.BaseSDESolver):
    """Strong order 1.5 stochastic Runge--Kutta scheme (Rossler 2010).

    Uses the SRA1 tableau for additive noise and the SRID2 tableau for
    diagonal/scalar noise; the appropriate `step` implementation is selected
    in `__init__` based on the SDE's noise type.
    """

    strong_order = 1.5
    weak_order = 1.5
    sde_type = SDE_TYPES.ito
    noise_types = (NOISE_TYPES.additive, NOISE_TYPES.diagonal, NOISE_TYPES.scalar)
    levy_area_approximations = (LEVY_AREA_APPROXIMATIONS.space_time,
                                LEVY_AREA_APPROXIMATIONS.davie,
                                LEVY_AREA_APPROXIMATIONS.foster)

    def __init__(self, sde, **kwargs):
        # Pick the step implementation matching the noise structure.
        if sde.noise_type == NOISE_TYPES.additive:
            self.step = self.additive_step
        else:
            self.step = self.diagonal_or_scalar_step

        if isinstance(sde, adjoint_sde.AdjointSDE):
            raise ValueError("Stochastic Runge–Kutta methods cannot be used for adjoint SDEs, because it requires "
                             "direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
                             "diffusion-vector product. Use a different method instead.")

        super(SRK, self).__init__(sde=sde, **kwargs)

    def step(self, t0, t1, y, extra0):
        # Just to make @abstractmethod happy, as we assign during __init__.
        raise RuntimeError

    def diagonal_or_scalar_step(self, t0, t1, y0, extra0):
        del extra0  # No extra solver state.
        dt = t1 - t0
        rdt = 1 / dt
        sqrt_dt = dt.sqrt()
        # I_k: Brownian increment; I_k0: space-time Levy area approximation.
        I_k, I_k0 = self.bm(t0, t1, return_U=True)
        # Iterated integrals built from the increment.
        I_kk = (I_k ** 2 - dt) * _r2
        I_kkk = (I_k ** 3 - 3 * dt * I_k) * _r6

        y1 = y0
        H0, H1 = [], []  # Drift-stage and diffusion-stage values.
        for s in range(srid2.STAGES):
            H0s, H1s = y0, y0  # Values at the current stage to be accumulated.
            for j in range(s):
                f = self.sde.f(t0 + srid2.C0[j] * dt, H0[j])
                g = self.sde.g(t0 + srid2.C1[j] * dt, H1[j])
                g = g.squeeze(2) if g.dim() == 3 else g  # scalar vs diagonal noise.
                H0s = H0s + srid2.A0[s][j] * f * dt + srid2.B0[s][j] * g * I_k0 * rdt
                H1s = H1s + srid2.A1[s][j] * f * dt + srid2.B1[s][j] * g * sqrt_dt
            H0.append(H0s)
            H1.append(H1s)

            f = self.sde.f(t0 + srid2.C0[s] * dt, H0s)
            # Combine the stochastic integrals with the tableau's beta weights.
            g_weight = (
                    srid2.beta1[s] * I_k +
                    srid2.beta2[s] * I_kk / sqrt_dt +
                    srid2.beta3[s] * I_k0 * rdt +
                    srid2.beta4[s] * I_kkk * rdt
            )
            g_prod = self.sde.g_prod(t0 + srid2.C1[s] * dt, H1s, g_weight)
            y1 = y1 + srid2.alpha[s] * f * dt + g_prod
        return y1, ()

    def additive_step(self, t0, t1, y0, extra0):
        del extra0  # No extra solver state.
        dt = t1 - t0
        rdt = 1 / dt
        # I_k: Brownian increment; I_k0: space-time Levy area approximation.
        I_k, I_k0 = self.bm(t0, t1, return_U=True)

        y1 = y0
        H0 = []  # Drift-stage values.
        for i in range(sra1.STAGES):
            H0i = y0
            for j in range(i):
                f = self.sde.f(t0 + sra1.C0[j] * dt, H0[j])
                g_weight = sra1.B0[i][j] * I_k0 * rdt
                g_prod = self.sde.g_prod(t0 + sra1.C1[j] * dt, y0, g_weight)
                H0i = H0i + sra1.A0[i][j] * f * dt + g_prod
            H0.append(H0i)

            f = self.sde.f(t0 + sra1.C0[i] * dt, H0i)
            g_weight = sra1.beta1[i] * I_k + sra1.beta2[i] * I_k0 * rdt
            g_prod = self.sde.g_prod(t0 + sra1.C1[i] * dt, y0, g_weight)
            y1 = y1 + sra1.alpha[i] * f * dt + g_prod
        return y1, ()
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /torchsde/_core/methods/tableaus/sra1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS". 16 | # For additive noise structure. 17 | # (ODE order, SDE strong order) = (2.0, 1.5). 

STAGES = 2  # Number of Runge-Kutta stages in this tableau.

# Evaluation nodes, as fractions of the step: C0 for drift, C1 for diffusion.
C0 = (0, 3 / 4)
C1 = (1, 0)

# Lower-triangular stage-coupling coefficients (explicit scheme):
# A0[i][j] weights the drift of stage j in stage i.
A0 = (
    (),
    (3 / 4,),
)

# B0[i][j] weights the diffusion contribution of stage j in stage i.
B0 = (
    (),
    (3 / 2,),
)

# Final-update weights: alpha for the drift, beta1/beta2 for the diffusion
# terms multiplying the Brownian increment and the space-time integral.
alpha = (1 / 3, 2 / 3)
beta1 = (1, 0)
beta2 = (-1, 1)


# ---- torchsde/_core/methods/tableaus/sra2.py ----
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS".
# For additive noise structure.
# (ODE order, SDE strong order) = (2.0, 1.5).

STAGES = 2  # Number of Runge-Kutta stages in this tableau.

# Evaluation nodes, as fractions of the step: C0 for drift, C1 for diffusion.
C0 = (0, 3 / 4)
C1 = (1 / 3, 1)

# Lower-triangular stage-coupling coefficients (explicit scheme).
A0 = (
    (),
    (3 / 4,),
)

B0 = (
    (),
    (3 / 2,),
)

# Final-update weights (same roles as in sra1.py).
alpha = (1 / 3, 2 / 3)
beta1 = (0, 1)
beta2 = (-3 / 2, 3 / 2)


# ---- torchsde/_core/methods/tableaus/sra3.py ----
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS".
# For additive noise structure.
# (ODE order, SDE strong order) = (3.0, 1.5).

STAGES = 3  # Number of Runge-Kutta stages in this tableau.

# Evaluation nodes, as fractions of the step: C0 for drift, C1 for diffusion.
C0 = (0, 1, 1 / 2)
C1 = (1, 0, 0)

# Lower-triangular stage-coupling coefficients (explicit scheme):
# A0[i][j] weights the drift of stage j in stage i.
A0 = (
    (),
    (1,),
    (1 / 4, 1 / 4),
)

# B0[i][j] weights the diffusion contribution of stage j in stage i.
B0 = (
    (),
    (0,),
    (1, 1 / 2),
)

# Final-update weights: alpha for the drift, beta1/beta2 for the diffusion
# terms multiplying the Brownian increment and the space-time integral.
alpha = (1 / 6, 1 / 6, 2 / 3)
beta1 = (1, 0, 0)
beta2 = (1, -1, 0)


# ---- torchsde/_core/methods/tableaus/srid1.py ----
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS".
# For diagonal noise structure.
# (ODE order, SDE strong order) = (2.0, 1.5).

STAGES = 4  # Number of Runge-Kutta stages in this tableau.

# Evaluation nodes, as fractions of the step: C0 for drift, C1 for diffusion.
C0 = (0, 3 / 4, 0, 0)
C1 = (0, 1 / 4, 1, 1 / 4)

# Lower-triangular stage-coupling coefficients (explicit scheme).
# A0/B0 feed the drift-side stage values H0; A1/B1 feed the diffusion-side
# stage values H1.
A0 = (
    (),
    (3 / 4,),
    (0, 0),
    (0, 0, 0),
)
A1 = (
    (),
    (1 / 4,),
    (1, 0),
    (0, 0, 1 / 4)
)

B0 = (
    (),
    (3 / 2,),
    (0, 0),
    (0, 0, 0),
)
B1 = (
    (),
    (1 / 2,),
    (-1, 0),
    (-5, 3, 1 / 2)
)

# Final-update weights: alpha for the drift; beta1..beta4 weight the Brownian
# increment and the iterated/space-time integrals in the diffusion update.
alpha = (1 / 3, 2 / 3, 0, 0)
beta1 = (-1, 4 / 3, 2 / 3, 0)
beta2 = (-1, 4 / 3, -1 / 3, 0)
beta3 = (2, -4 / 3, -2 / 3, 0)
beta4 = (-2, 5 / 3, -2 / 3, 1)


# ---- torchsde/_core/methods/tableaus/srid2.py ----
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS".
# For diagonal noise structure.
# (ODE order, SDE strong order) = (3.0, 1.5).

STAGES = 4  # Number of Runge-Kutta stages in this tableau.

# Evaluation nodes, as fractions of the step: C0 for drift, C1 for diffusion.
C0 = (0, 1, 1 / 2, 0)
C1 = (0, 1 / 4, 1, 1 / 4)

# Lower-triangular stage-coupling coefficients (explicit scheme).
# A0/B0 feed the drift-side stage values H0; A1/B1 feed the diffusion-side
# stage values H1.
A0 = (
    (),
    (1,),
    (1 / 4, 1 / 4),
    (0, 0, 0)
)
A1 = (
    (),
    (1 / 4,),
    (1, 0),
    (0, 0, 1 / 4)
)

B0 = (
    (),
    (0,),
    (1, 1 / 2),
    (0, 0, 0),
)
B1 = (
    (),
    (-1 / 2,),
    (1, 0),
    (2, -1, 1 / 2)
)

# Final-update weights: alpha for the drift; beta1..beta4 weight the Brownian
# increment and the iterated/space-time integrals in the diffusion update.
alpha = (1 / 6, 1 / 6, 2 / 3, 0)
beta1 = (-1, 4 / 3, 2 / 3, 0)
beta2 = (1, -4 / 3, 1 / 3, 0)
beta3 = (2, -4 / 3, -2 / 3, 0)
beta4 = (-2, 5 / 3, -2 / 3, 1)


# ---- torchsde/_core/misc.py ----
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import torch


def assert_no_grad(names, maybe_tensors):
    """Raise ValueError if any tensor among `maybe_tensors` requires grad.

    `names` supplies the matching argument names for the error message;
    non-tensor entries are skipped.
    """
    for name, maybe_tensor in zip(names, maybe_tensors):
        if torch.is_tensor(maybe_tensor) and maybe_tensor.requires_grad:
            raise ValueError(f"Argument {name} must not require gradient.")


def handle_unused_kwargs(unused_kwargs, msg=None):
    """Warn about keyword arguments that a caller passed but nobody consumed."""
    if len(unused_kwargs) > 0:
        if msg is not None:
            warnings.warn(f"{msg}: Unexpected arguments {unused_kwargs}")
        else:
            warnings.warn(f"Unexpected arguments {unused_kwargs}")


def flatten(sequence):
    """Concatenate a sequence of tensors into one 1-D tensor.

    Returns an empty (default-dtype) tensor when the sequence is empty.
    """
    return torch.cat([p.reshape(-1) for p in sequence]) if len(sequence) > 0 else torch.tensor([])


def convert_none_to_zeros(sequence, like_sequence):
    """Replace each None in `sequence` with zeros shaped like `like_sequence`'s entry."""
    return [torch.zeros_like(q) if p is None else p for p, q in zip(sequence, like_sequence)]


def make_seq_requires_grad(sequence):
    """Return the tensors with requires_grad=True, detaching those that lacked it."""
    return [p if p.requires_grad else p.detach().requires_grad_(True) for p in sequence]


def is_strictly_increasing(ts):
    """Return True iff consecutive elements of `ts` are strictly increasing."""
    return all(x < y for x, y in zip(ts[:-1], ts[1:]))


def is_nan(t):
    """Return a scalar bool tensor: does `t` contain any NaN entry?"""
    return torch.any(torch.isnan(t))


def seq_add(*seqs):
    """Elementwise sum across sequences: result[i] = sum(seq[i] for each seq)."""
    return [sum(seq) for seq in zip(*seqs)]


def seq_sub(xs, ys):
    """Elementwise difference of two sequences: result[i] = xs[i] - ys[i]."""
    return [x - y for x, y in zip(xs, ys)]


def batch_mvp(m, v):
    """Batched matrix-vector product: (B, n, m) @ (B, m) -> (B, n)."""
    return torch.bmm(m, v.unsqueeze(-1)).squeeze(dim=-1)


def stable_division(a, b, epsilon=1e-7):
    """Divide `a` by `b`, clamping `b` away from zero by `epsilon`.

    NOTE(review): entries where b == 0 exactly give sign() == 0, so the
    clamped denominator is 0 and the division still produces inf/nan --
    confirm callers never pass exact zeros.
    """
    b = torch.where(b.abs().detach() > epsilon, b, torch.full_like(b, fill_value=epsilon) * b.sign())
    return a / b


def vjp(outputs, inputs, **kwargs):
    """Vector-Jacobian product via autograd; None grads become zeros.

    Extra `kwargs` (e.g. grad_outputs, create_graph) are forwarded to
    `torch.autograd.grad`.
    """
    if torch.is_tensor(inputs):
        inputs = [inputs]
    _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs]  # Workaround for PyTorch bug #39784. # noqa: 74

    if torch.is_tensor(outputs):
        outputs = [outputs]
    outputs = make_seq_requires_grad(outputs)

    _vjp = torch.autograd.grad(outputs, inputs, **kwargs)
    return convert_none_to_zeros(_vjp, inputs)


def jvp(outputs, inputs, grad_inputs=None, **kwargs):
    # Unlike `torch.autograd.functional.jvp`, this function avoids repeating forward computation.
    # Implements the classic double-vjp trick: differentiate the vjp w.r.t. a
    # dummy cotangent to obtain the jvp.
    if torch.is_tensor(inputs):
        inputs = [inputs]
    _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs]  # Workaround for PyTorch bug #39784. # noqa: 88

    if torch.is_tensor(outputs):
        outputs = [outputs]
    outputs = make_seq_requires_grad(outputs)

    dummy_outputs = [torch.zeros_like(o, requires_grad=True) for o in outputs]
    _vjp = torch.autograd.grad(outputs, inputs, grad_outputs=dummy_outputs, create_graph=True, allow_unused=True)
    _vjp = make_seq_requires_grad(convert_none_to_zeros(_vjp, inputs))

    _jvp = torch.autograd.grad(_vjp, dummy_outputs, grad_outputs=grad_inputs, **kwargs)
    return convert_none_to_zeros(_jvp, dummy_outputs)


def flat_to_shape(flat_tensor, shapes):
    """Convert a flat tensor to a list of tensors with specified shapes.

    `flat_tensor` must have exactly the number of elements as stated in `shapes`.
    """
    numels = [shape.numel() for shape in shapes]
    return [flat.reshape(shape) for flat, shape in zip(flat_tensor.split(split_size=numels), shapes)]


# ---- torchsde/settings.py ----
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class ContainerMeta(type):
    """Metaclass turning a class into a read-only container of string constants.

    Classes created with this metaclass gain `in` membership tests, a `str()`
    rendering, and `.all()`, which lists the values of every non-dunder
    attribute in sorted order.
    """

    def all(cls):
        # `dir()` on a class omits metaclass attributes, so only the constants
        # declared on the container class itself (and `object`) are scanned.
        members = [getattr(cls, name) for name in dir(cls) if not name.startswith('__')]
        return sorted(members)

    def __str__(cls):
        return str(cls.all())

    def __contains__(cls, item):
        return item in cls.all()


# TODO: consider moving all these enums into some appropriate section of the code, rather than having them be global
# like this. (e.g. instead set METHODS = {'euler': Euler, ...} in methods/__init__.py)
class METHODS(metaclass=ContainerMeta):
    # Names of the available SDE solvers, as accepted by `sdeint`.
    euler = 'euler'
    milstein = 'milstein'
    srk = 'srk'
    midpoint = 'midpoint'
    reversible_heun = 'reversible_heun'
    adjoint_reversible_heun = 'adjoint_reversible_heun'
    heun = 'heun'
    log_ode_midpoint = 'log_ode'
    euler_heun = 'euler_heun'


class NOISE_TYPES(metaclass=ContainerMeta):  # noqa
    # Supported noise structures for the diffusion term.
    general = 'general'
    diagonal = 'diagonal'
    scalar = 'scalar'
    additive = 'additive'


class SDE_TYPES(metaclass=ContainerMeta):  # noqa
    # Stochastic-integral conventions.
    ito = 'ito'
    stratonovich = 'stratonovich'


class LEVY_AREA_APPROXIMATIONS(metaclass=ContainerMeta):  # noqa
    none = 'none'  # Don't compute any Levy area approximation
    space_time = 'space-time'  # Only compute an (exact) space-time Levy area
    davie = 'davie'  # Compute Davie's approximation to Levy area
    foster = 'foster'  # Compute Foster's correction to Davie's approximation to Levy area


class METHOD_OPTIONS(metaclass=ContainerMeta):  # noqa
    grad_free = 'grad_free'


# ---- torchsde/types.py ----
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We import from `typing` more than what's enough, so that other modules can import from this file and not `typing`.
from typing import Sequence, Union, Optional, Any, Dict, Tuple, Callable  # noqa: F401

import torch

# Convenience aliases used throughout the package's signatures.
Tensor = torch.Tensor
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

Scalar = Union[float, Tensor]
Vector = Union[Sequence[float], Tensor]

Module = torch.nn.Module
Modules = Sequence[Module]
ModuleOrModules = Union[Module, Modules]

Size = torch.Size
Sizes = Sequence[Size]