├── .flake8
├── .github
│   └── workflows
│       └── run_tests.yml
├── .gitignore
├── CONTRIBUTING.md
├── DOCUMENTATION.md
├── LICENSE
├── README.md
├── assets
│   └── latent_sde.gif
├── benchmarks
│   ├── __init__.py
│   ├── brownian.py
│   └── profile_btree.py
├── diagnostics
│   ├── __init__.py
│   ├── inspection.py
│   ├── ito_additive.py
│   ├── ito_diagonal.py
│   ├── ito_general.py
│   ├── ito_scalar.py
│   ├── run_all.py
│   ├── stratonovich_additive.py
│   ├── stratonovich_diagonal.py
│   ├── stratonovich_general.py
│   ├── stratonovich_scalar.py
│   └── utils.py
├── examples
│   ├── __init__.py
│   ├── cont_ddpm.py
│   ├── demo.ipynb
│   ├── latent_sde.py
│   ├── latent_sde_lorenz.py
│   ├── sde_gan.py
│   └── unet.py
├── pyproject.toml
├── setup.py
├── tests
│   ├── __init__.py
│   ├── problems.py
│   ├── test_adjoint.py
│   ├── test_brownian_interval.py
│   ├── test_brownian_path.py
│   ├── test_brownian_tree.py
│   ├── test_sdeint.py
│   └── utils.py
└── torchsde
    ├── __init__.py
    ├── _brownian
    │   ├── __init__.py
    │   ├── brownian_base.py
    │   ├── brownian_interval.py
    │   └── derived.py
    ├── _core
    │   ├── __init__.py
    │   ├── adaptive_stepping.py
    │   ├── adjoint.py
    │   ├── adjoint_sde.py
    │   ├── base_sde.py
    │   ├── base_solver.py
    │   ├── better_abc.py
    │   ├── interp.py
    │   ├── methods
    │   │   ├── __init__.py
    │   │   ├── euler.py
    │   │   ├── euler_heun.py
    │   │   ├── heun.py
    │   │   ├── log_ode.py
    │   │   ├── midpoint.py
    │   │   ├── milstein.py
    │   │   ├── reversible_heun.py
    │   │   ├── srk.py
    │   │   └── tableaus
    │   │       ├── __init__.py
    │   │       ├── sra1.py
    │   │       ├── sra2.py
    │   │       ├── sra3.py
    │   │       ├── srid1.py
    │   │       └── srid2.py
    │   ├── misc.py
    │   └── sdeint.py
    ├── settings.py
    └── types.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = W291,W503,W504,E123,E126,E203,E402,E701
4 | per-file-ignores = __init__.py: F401
5 |
--------------------------------------------------------------------------------
/.github/workflows/run_tests.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 | 
4 | name: Run tests
5 | 
6 | on:
7 |   pull_request:
8 |   schedule:
9 |     - cron: "0 2 * * 6"
10 |   workflow_dispatch:
11 | 
12 | jobs:
13 |   build:
14 |     strategy:
15 |       matrix:
16 |         python-version: [ 3.8 ]
17 |         os: [ ubuntu-latest, macOS-latest, windows-latest ]
18 |       fail-fast: false
19 |     runs-on: ${{ matrix.os }}
20 |     steps:
21 |       - name: Checkout code
22 |         uses: actions/checkout@v4
23 | 
24 |       - name: Set up Python ${{ matrix.python-version }}
25 |         uses: actions/setup-python@v4
26 |         with:
27 |           python-version: ${{ matrix.python-version }}
28 |           cache: pip
29 |           cache-dependency-path: setup.py
30 | 
31 |       - name: Install
32 |         run: pip install pytest -e . --only-binary=numpy,scipy,matplotlib,torch
33 |         env:
34 |           PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu
35 | 
36 |       - name: Test with pytest
37 |         run: python -m pytest -v
38 | 
39 |   lint:
40 |     runs-on: ubuntu-latest
41 |     steps:
42 |       - name: Checkout code
43 |         uses: actions/checkout@v4
44 | 
45 |       - uses: actions/setup-python@v4
46 |         with:
47 |           python-version: "3.11"
48 |           cache: pip
49 |           cache-dependency-path: setup.py
50 | 
51 |       - name: Lint with flake8
52 |         run: |
53 |           python -m pip install flake8
54 |           python -m flake8 .
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .DS_Store
3 | .lib
4 | __pycache__
5 | .ipynb_checkpoints
6 | diagnostics/plots/
7 | build/
8 | dist/
9 | .vscode/
10 | *.egg-info/
11 | benchmarks/plots/
12 | CMakeLists.txt
13 | restats
14 | *-darwin.so
15 | **.pyc
16 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | a few small guidelines to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch Implementation of Differentiable SDE Solvers 
2 | This library provides [stochastic differential equation (SDE)](https://en.wikipedia.org/wiki/Stochastic_differential_equation) solvers with GPU support and efficient backpropagation.
3 |
4 | ---
5 |
6 | ![latent SDE animation](./assets/latent_sde.gif)
9 | ## Installation
10 | ```shell script
11 | pip install torchsde
12 | ```
13 |
14 | **Requirements:** Python >=3.8 and PyTorch >=1.6.0.
15 |
16 | ## Documentation
17 | Available [here](./DOCUMENTATION.md).
18 |
19 | ## Examples
20 | ### Quick example
21 | ```python
22 | import torch
23 | import torchsde
24 |
25 | batch_size, state_size, brownian_size = 32, 3, 2
26 | t_size = 20
27 |
28 | class SDE(torch.nn.Module):
29 |     noise_type = 'general'
30 |     sde_type = 'ito'
31 | 
32 |     def __init__(self):
33 |         super().__init__()
34 |         self.mu = torch.nn.Linear(state_size,
35 |                                   state_size)
36 |         self.sigma = torch.nn.Linear(state_size,
37 |                                      state_size * brownian_size)
38 | 
39 |     # Drift
40 |     def f(self, t, y):
41 |         return self.mu(y)  # shape (batch_size, state_size)
42 | 
43 |     # Diffusion
44 |     def g(self, t, y):
45 |         return self.sigma(y).view(batch_size,
46 |                                   state_size,
47 |                                   brownian_size)
48 |
49 | sde = SDE()
50 | y0 = torch.full((batch_size, state_size), 0.1)
51 | ts = torch.linspace(0, 1, t_size)
52 | # Initial state y0, the SDE is solved over the interval [ts[0], ts[-1]].
53 | # ys will have shape (t_size, batch_size, state_size)
54 | ys = torchsde.sdeint(sde, y0, ts)
55 | ```
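
Gradients can be backpropagated through the solver, so the quick example extends directly to training. A minimal sketch of plain backpropagation, reusing `sde` and `ts` from above (`torchsde.sdeint_adjoint` is the memory-efficient alternative described in the documentation):

```python
y0 = torch.full((batch_size, state_size), 0.1, requires_grad=True)
ys = torchsde.sdeint(sde, y0, ts)  # differentiable in y0 and in sde's parameters
loss = ys[-1].pow(2).sum()         # any scalar function of the solution
loss.backward()                    # fills y0.grad and the grads of sde.mu / sde.sigma
```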
56 |
57 | ### Notebook
58 |
59 | [`examples/demo.ipynb`](examples/demo.ipynb) gives a short guide on how to solve SDEs, including subtle points such as fixing the randomness in the solver and the choice of *noise types*.
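
For instance, randomness can be fixed by constructing the Brownian motion yourself and passing it to the solver. A small sketch, assuming the names from the quick example above and the `entropy` seeding argument of `BrownianInterval`:

```python
bm = torchsde.BrownianInterval(t0=0.0, t1=1.0,
                               size=(batch_size, brownian_size),
                               entropy=42)  # fixed seed, so sample paths are reproducible
ys = torchsde.sdeint(sde, y0, ts, bm=bm)    # the same bm gives the same ys on every call
```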
60 |
61 | ### Latent SDE
62 |
63 | [`examples/latent_sde.py`](examples/latent_sde.py) learns a *latent stochastic differential equation*, as in Section 5 of [\[1\]](https://arxiv.org/pdf/2001.01328.pdf).
64 | The example fits an SDE to data, whilst regularizing it to be like an [Ornstein-Uhlenbeck](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process) prior process.
65 | The model can be loosely viewed as a [variational autoencoder](https://en.wikipedia.org/wiki/Autoencoder#Variational_autoencoder_(VAE)) with its prior and approximate posterior being SDEs. This example can be run via
66 | ```shell script
67 | python -m examples.latent_sde --train-dir <TRAIN_DIR>
68 | ```
69 | The program outputs figures to the path specified by `<TRAIN_DIR>`.
70 | Training should stabilize after 500 iterations with the default hyperparameters.
71 |
72 | ### Neural SDEs as GANs
73 | [`examples/sde_gan.py`](examples/sde_gan.py) learns an SDE as a GAN, as in [\[2\]](https://arxiv.org/abs/2102.03657), [\[3\]](https://arxiv.org/abs/2105.13493). The example trains an SDE as the generator of a GAN, whilst using a [neural CDE](https://github.com/patrick-kidger/NeuralCDE) [\[4\]](https://arxiv.org/abs/2005.08926) as the discriminator. This example can be run via
74 |
75 | ```shell script
76 | python -m examples.sde_gan
77 | ```
78 |
79 | ## Citation
80 |
81 | If you found this codebase useful in your research, please consider citing either or both of:
82 |
83 | ```
84 | @article{li2020scalable,
85 | title={Scalable gradients for stochastic differential equations},
86 | author={Li, Xuechen and Wong, Ting-Kam Leonard and Chen, Ricky T. Q. and Duvenaud, David},
87 | journal={International Conference on Artificial Intelligence and Statistics},
88 | year={2020}
89 | }
90 | ```
91 |
92 | ```
93 | @article{kidger2021neuralsde,
94 | title={Neural {SDE}s as {I}nfinite-{D}imensional {GAN}s},
95 | author={Kidger, Patrick and Foster, James and Li, Xuechen and Oberhauser, Harald and Lyons, Terry},
96 | journal={International Conference on Machine Learning},
97 | year={2021}
98 | }
99 | ```
100 |
101 | ## References
102 |
103 | \[1\] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations". *International Conference on Artificial Intelligence and Statistics.* 2020. [[arXiv]](https://arxiv.org/pdf/2001.01328.pdf)
104 |
105 | \[2\] Patrick Kidger, James Foster, Xuechen Li, Harald Oberhauser, Terry Lyons. "Neural SDEs as Infinite-Dimensional GANs". *International Conference on Machine Learning.* 2021. [[arXiv]](https://arxiv.org/abs/2102.03657)
106 |
107 | \[3\] Patrick Kidger, James Foster, Xuechen Li, Terry Lyons. "Efficient and Accurate Gradients for Neural SDEs". 2021. [[arXiv]](https://arxiv.org/abs/2105.13493)
108 |
109 | \[4\] Patrick Kidger, James Morrill, James Foster, Terry Lyons. "Neural Controlled Differential Equations for Irregular Time Series". *Neural Information Processing Systems.* 2020. [[arXiv]](https://arxiv.org/abs/2005.08926)
110 |
111 | ---
112 | This is a research project, not an official Google product.
113 |
--------------------------------------------------------------------------------
/assets/latent_sde.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/torchsde/eb3a00e31cbd56176270066ed2f62c394cf6acb7/assets/latent_sde.gif
--------------------------------------------------------------------------------
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/torchsde/eb3a00e31cbd56176270066ed2f62c394cf6acb7/benchmarks/__init__.py
--------------------------------------------------------------------------------
/benchmarks/brownian.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Compare the speed of 5 Brownian motion variants on problems of different sizes."""
16 | import argparse
17 | import logging
18 | import os
19 | import time
20 |
21 | import numpy.random as npr
22 | import torch
23 |
24 | import torchsde
25 | from diagnostics import utils
26 |
27 | t0, t1 = 0.0, 1.0
28 | reps, steps = 3, 100
29 | small_batch_size, small_d = 128, 5
30 | large_batch_size, large_d = 256, 128
31 | huge_batch_size, huge_d = 512, 256
32 |
33 |
34 | def _time_query(bm, ts):
35 | now = time.perf_counter()
36 | for _ in range(reps):
37 | for ta, tb in zip(ts[:-1], ts[1:]):
38 | if ta > tb:
39 | ta, tb = tb, ta
40 | bm(ta, tb)
41 | return time.perf_counter() - now
42 |
43 |
44 | def _compare(w0, ts, msg=''):
45 | bm = torchsde.BrownianPath(t0=t0, w0=w0)
46 | bp_py_time = _time_query(bm, ts)
47 | logging.warning(f'{msg} (torchsde.BrownianPath): {bp_py_time:.4f}')
48 |
49 | bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, tol=1e-5)
50 | bt_py_time = _time_query(bm, ts)
51 | logging.warning(f'{msg} (torchsde.BrownianTree): {bt_py_time:.4f}')
52 |
53 | bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=w0.shape, dtype=w0.dtype, device=w0.device)
54 | bi_py_time = _time_query(bm, ts)
55 | logging.warning(f'{msg} (torchsde.BrownianInterval): {bi_py_time:.4f}')
56 |
57 | return bp_py_time, bt_py_time, bi_py_time
58 |
59 |
60 | def sequential_access():
61 | ts = torch.linspace(t0, t1, steps=steps)
62 |
63 | w0 = torch.zeros(small_batch_size, small_d).to(device)
64 | bp_py_time_s, bt_py_time_s, bi_py_time_s = _compare(w0, ts, msg='small sequential access')
65 |
66 | w0 = torch.zeros(large_batch_size, large_d).to(device)
67 | bp_py_time_l, bt_py_time_l, bi_py_time_l = _compare(w0, ts, msg='large sequential access')
68 |
69 | w0 = torch.zeros(huge_batch_size, huge_d).to(device)
70 | bp_py_time_h, bt_py_time_h, bi_py_time_h = _compare(w0, ts, msg="huge sequential access")
71 |
72 | img_path = os.path.join('.', 'benchmarks', f'plots-{device}', 'sequential_access.png')
73 | if not os.path.exists(os.path.dirname(img_path)):
74 | os.makedirs(os.path.dirname(img_path))
75 |
76 | xaxis = [small_batch_size * small_d, large_batch_size * large_batch_size, huge_batch_size * huge_d]
77 |
78 | utils.swiss_knife_plotter(
79 | img_path,
80 | plots=[
81 | {'x': xaxis, 'y': [bp_py_time_s, bp_py_time_l, bp_py_time_h], 'label': 'bp_py', 'marker': 'x'},
82 | {'x': xaxis, 'y': [bt_py_time_s, bt_py_time_l, bt_py_time_h], 'label': 'bt_py', 'marker': 'x'},
83 | {'x': xaxis, 'y': [bi_py_time_s, bi_py_time_l, bi_py_time_h], 'label': 'bi_py', 'marker': 'x'},
84 | ],
85 | options={
86 | 'xscale': 'log',
87 | 'yscale': 'log',
88 | 'xlabel': 'size of tensor',
89 | 'ylabel': f'wall time on {device}',
90 | 'title': 'sequential access'
91 | }
92 | )
93 |
94 |
95 | def random_access():
96 | generator = torch.Generator().manual_seed(456789)
97 | ts = torch.empty(steps).uniform_(t0, t1, generator=generator)
98 |
99 | w0 = torch.zeros(small_batch_size, small_d).to(device)
100 | bp_py_time_s, bt_py_time_s, bi_py_time_s = _compare(w0, ts, msg='small random access')
101 |
102 | w0 = torch.zeros(large_batch_size, large_d).to(device)
103 | bp_py_time_l, bt_py_time_l, bi_py_time_l = _compare(w0, ts, msg='large random access')
104 |
105 | w0 = torch.zeros(huge_batch_size, huge_d).to(device)
106 | bp_py_time_h, bt_py_time_h, bi_py_time_h = _compare(w0, ts, msg="huge random access")
107 |
108 | img_path = os.path.join('.', 'benchmarks', f'plots-{device}', 'random_access.png')
109 | if not os.path.exists(os.path.dirname(img_path)):
110 | os.makedirs(os.path.dirname(img_path))
111 |
112 | xaxis = [small_batch_size * small_d, large_batch_size * large_batch_size, huge_batch_size * huge_d]
113 |
114 | utils.swiss_knife_plotter(
115 | img_path,
116 | plots=[
117 | {'x': xaxis, 'y': [bp_py_time_s, bp_py_time_l, bp_py_time_h], 'label': 'bp_py', 'marker': 'x'},
118 | {'x': xaxis, 'y': [bt_py_time_s, bt_py_time_l, bt_py_time_h], 'label': 'bt_py', 'marker': 'x'},
119 | {'x': xaxis, 'y': [bi_py_time_s, bi_py_time_l, bi_py_time_h], 'label': 'bi_py', 'marker': 'x'},
120 | ],
121 | options={
122 | 'xscale': 'log',
123 | 'yscale': 'log',
124 | 'xlabel': 'size of tensor',
125 | 'ylabel': f'wall time on {device}',
126 | 'title': 'random access'
127 | }
128 | )
129 |
130 |
131 | class SDE(torchsde.SDEIto):
132 | def __init__(self):
133 | super(SDE, self).__init__(noise_type="diagonal")
134 |
135 | def f(self, t, y):
136 | return y
137 |
138 | def g(self, t, y):
139 | return torch.exp(-y)
140 |
141 |
142 | def _time_sdeint(sde, y0, ts, bm):
143 | now = time.perf_counter()
144 | with torch.no_grad():
145 | torchsde.sdeint(sde, y0, ts, bm, method='euler')
146 | return time.perf_counter() - now
147 |
148 |
149 | def _time_sdeint_bp(sde, y0, ts, bm):
150 | now = time.perf_counter()
151 | sde.zero_grad()
152 | y0 = y0.clone().requires_grad_(True)
153 | ys = torchsde.sdeint(sde, y0, ts, bm, method='euler')
154 | ys.sum().backward()
155 | return time.perf_counter() - now
156 |
157 |
158 | def _time_sdeint_adjoint(sde, y0, ts, bm):
159 | now = time.perf_counter()
160 | sde.zero_grad()
161 | y0 = y0.clone().requires_grad_(True)
162 | ys = torchsde.sdeint_adjoint(sde, y0, ts, bm, method='euler')
163 | ys.sum().backward()
164 | return time.perf_counter() - now
165 |
166 |
167 | def _compare_sdeint(w0, sde, y0, ts, func, msg=''):
168 | bm = torchsde.BrownianPath(t0=t0, w0=w0)
169 | bp_py_time = func(sde, y0, ts, bm)
170 | logging.warning(f'{msg} (torchsde.BrownianPath): {bp_py_time:.4f}')
171 |
172 | bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, tol=1e-5)
173 | bt_py_time = func(sde, y0, ts, bm)
174 | logging.warning(f'{msg} (torchsde.BrownianTree): {bt_py_time:.4f}')
175 |
176 | bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=w0.shape, dtype=w0.dtype, device=w0.device)
177 | bi_py_time = func(sde, y0, ts, bm)
178 | logging.warning(f'{msg} (torchsde.BrownianInterval): {bi_py_time:.4f}')
179 |
180 | return bp_py_time, bt_py_time, bi_py_time
181 |
182 |
183 | def solver_access(func=_time_sdeint):
184 | ts = torch.linspace(t0, t1, steps)
185 | sde = SDE().to(device)
186 |
187 | y0 = w0 = torch.zeros(small_batch_size, small_d).to(device)
188 | bp_py_time_s, bt_py_time_s, bi_py_time_s = _compare_sdeint(w0, sde, y0, ts, func, msg='small')
189 |
190 | y0 = w0 = torch.zeros(large_batch_size, large_d).to(device)
191 | bp_py_time_l, bt_py_time_l, bi_py_time_l = _compare_sdeint(w0, sde, y0, ts, func, msg='large')
192 |
193 | y0 = w0 = torch.zeros(huge_batch_size, huge_d).to(device)
194 | bp_py_time_h, bt_py_time_h, bi_py_time_h = _compare_sdeint(w0, sde, y0, ts, func, msg='huge')
195 |
196 | name = {
197 | _time_sdeint: 'sdeint',
198 | _time_sdeint_bp: 'sdeint-backprop-solver',
199 | _time_sdeint_adjoint: 'sdeint-backprop-adjoint'
200 | }[func]
201 |
202 | img_path = os.path.join('.', 'benchmarks', f'plots-{device}', f'{name}.png')
203 | if not os.path.exists(os.path.dirname(img_path)):
204 | os.makedirs(os.path.dirname(img_path))
205 |
206 | xaxis = [small_batch_size * small_d, large_batch_size * large_batch_size, huge_batch_size * huge_d]
207 |
208 | utils.swiss_knife_plotter(
209 | img_path,
210 | plots=[
211 | {'x': xaxis, 'y': [bp_py_time_s, bp_py_time_l, bp_py_time_h], 'label': 'bp_py', 'marker': 'x'},
212 | {'x': xaxis, 'y': [bt_py_time_s, bt_py_time_l, bt_py_time_h], 'label': 'bt_py', 'marker': 'x'},
213 | {'x': xaxis, 'y': [bi_py_time_s, bi_py_time_l, bi_py_time_h], 'label': 'bi_py', 'marker': 'x'},
214 | ],
215 | options={
216 | 'xscale': 'log',
217 | 'yscale': 'log',
218 | 'xlabel': 'size of tensor',
219 | 'ylabel': f'wall time on {device}',
220 | 'title': name
221 | }
222 | )
223 |
224 |
225 | def main():
226 | sequential_access()
227 | random_access()
228 |
229 | solver_access(func=_time_sdeint)
230 | solver_access(func=_time_sdeint_bp)
231 | solver_access(func=_time_sdeint_adjoint)
232 |
233 |
234 | if __name__ == "__main__":
235 | parser = argparse.ArgumentParser()
236 | parser.add_argument('--no-gpu', action='store_true')
237 | parser.add_argument('--debug', action='store_true')
238 | parser.add_argument('--seed', type=int, default=0)
239 |
240 | args = parser.parse_args()
241 | device = torch.device('cuda' if torch.cuda.is_available() and not args.no_gpu else 'cpu')
242 |
243 | npr.seed(args.seed)
244 | torch.manual_seed(args.seed)
245 |
246 | if args.debug:
247 | logging.getLogger().setLevel(logging.INFO)
248 |
249 | main()
250 |
--------------------------------------------------------------------------------
/benchmarks/profile_btree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | import os
17 | import time
18 |
19 | import matplotlib.pyplot as plt
20 | import torch
21 | import tqdm
22 |
23 | from torchsde import BrownianTree
24 |
25 |
26 | def run_torch(ks=(0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12)):
27 |     w0 = torch.zeros(b, d)
28 | 
29 |     t_cons = []
30 |     t_queries = []
31 |     t_alls = []
32 |     for k in tqdm.tqdm(ks):
33 |         now = time.time()
34 |         bm_vanilla = BrownianTree(t0=t0, t1=t1, w0=w0, cache_depth=k)
35 |         t_con = time.time() - now
36 |         t_cons.append(t_con)
37 | 
38 |         now = time.time()
39 |         for t in ts:
40 |             bm_vanilla(t).to(device)
41 |         t_query = time.time() - now
42 |         t_queries.append(t_query)
43 | 
44 |         t_all = t_con + t_query
45 |         t_alls.append(t_all)
46 |         logging.warning(f'k={k}, t_con={t_con:.4f}, t_query={t_query:.4f}, t_all={t_all:.4f}')
47 | 
48 |     img_path = os.path.join('.', 'diagnostics', 'plots', 'profile_btree.png')
49 |     plt.figure()
50 |     plt.plot(ks, t_cons, label='cons')
51 |     plt.plot(ks, t_queries, label='queries')
52 |     plt.plot(ks, t_alls, label='all')
53 |     plt.title(f'b={b}, d={d}, repetitions={reps}, device={w0.device}')
54 |     plt.xlabel('Cache level')
55 |     plt.ylabel('Time (secs)')
56 |     plt.legend()
57 |     plt.savefig(img_path)
58 |     plt.close()
59 | 
60 | 
61 | def main():
62 |     run_torch()
63 | 
64 | 
65 | if __name__ == "__main__":
66 |     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
67 |     torch.manual_seed(1147481649)
68 | 
69 |     reps = 500
70 |     b, d = 512, 10
71 | 
72 |     t0, t1 = 0., 1.
73 |     ts = torch.rand(size=(reps,)).numpy()
74 | 
75 |     main()
76 |
--------------------------------------------------------------------------------
/diagnostics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/diagnostics/inspection.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import sys
17 |
18 | import torch
19 | import tqdm
20 |
21 | from torchsde import BaseBrownian, BaseSDE, sdeint
22 | from torchsde.settings import SDE_TYPES
23 | from torchsde.types import Tensor, Vector, Scalar, Tuple, Optional, Callable
24 | from . import utils
25 |
26 |
27 | sys.setrecursionlimit(5000)
28 |
29 |
30 | @torch.no_grad()
31 | def inspect_samples(y0: Tensor,
32 |                     ts: Vector,
33 |                     dt: Scalar,
34 |                     sde: BaseSDE,
35 |                     bm: BaseBrownian,
36 |                     img_dir: str,
37 |                     methods: Tuple[str, ...],
38 |                     options: Optional[Tuple] = None,
39 |                     labels: Optional[Tuple[str, ...]] = None,
40 |                     vis_dim=0,
41 |                     dt_true: Optional[float] = 2 ** -14):
42 |     if options is None:
43 |         options = (None,) * len(methods)
44 |     if labels is None:
45 |         labels = methods
46 | 
47 |     solns = [
48 |         sdeint(sde, y0, ts, bm, method=method, dt=dt, options=options_)
49 |         for method, options_ in zip(methods, options)
50 |     ]
51 | 
52 |     method_for_true = 'euler' if sde.sde_type == SDE_TYPES.ito else 'midpoint'
53 |     true = sdeint(sde, y0, ts, bm, method=method_for_true, dt=dt_true)
54 | 
55 |     labels += ('true',)
56 |     solns += [true]
57 | 
58 |     # (T, batch_size, d) -> (T, batch_size) -> (batch_size, T).
59 |     solns = [soln[..., vis_dim].t() for soln in solns]
60 | 
61 |     for i, samples in enumerate(zip(*solns)):
62 |         utils.swiss_knife_plotter(
63 |             img_path=os.path.join(img_dir, f'{i}'),
64 |             plots=[
65 |                 {'x': ts, 'y': sample, 'label': label, 'marker': 'x'}
66 |                 for sample, label in zip(samples, labels)
67 |             ]
68 |         )
69 | 
70 | 
71 | @torch.no_grad()
72 | def inspect_orders(y0: Tensor,
73 |                    t0: Scalar,
74 |                    t1: Scalar,
75 |                    dts: Vector,
76 |                    sde: BaseSDE,
77 |                    bm: BaseBrownian,
78 |                    img_dir: str,
79 |                    methods: Tuple[str, ...],
80 |                    options: Optional[Tuple] = None,
81 |                    labels: Optional[Tuple[str, ...]] = None,
82 |                    dt_true: Optional[float] = 2 ** -14,
83 |                    test_func: Optional[Callable] = lambda x: (x ** 2).flatten(start_dim=1).sum(dim=1)):
84 |     if options is None:
85 |         options = (None,) * len(methods)
86 |     if labels is None:
87 |         labels = methods
88 | 
89 |     ts = torch.tensor([t0, t1], device=y0.device)
90 | 
91 |     solns = [
92 |         [
93 |             sdeint(sde, y0, ts, bm, method=method, dt=dt, options=options_)[-1]
94 |             for method, options_ in zip(methods, options)
95 |         ]
96 |         for dt in tqdm.tqdm(dts)
97 |     ]
98 | 
99 |     if hasattr(sde, 'analytical_sample'):
100 |         true = sde.analytical_sample(y0, ts, bm)[-1]
101 |     else:
102 |         method_for_true = 'euler' if sde.sde_type == SDE_TYPES.ito else 'midpoint'
103 |         true = sdeint(sde, y0, ts, bm, method=method_for_true, dt=dt_true)[-1]
104 | 
105 |     mses = []
106 |     maes = []
107 |     for dt, solns_ in zip(dts, solns):
108 |         mses_for_dt = [utils.mse(soln, true) for soln in solns_]
109 |         mses.append(mses_for_dt)
110 | 
111 |         maes_for_dt = [utils.mae(soln, true, test_func) for soln in solns_]
112 |         maes.append(maes_for_dt)
113 |
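    # Order estimation via log-log regression: for strong order p, the strong
    # error satisfies E||Y_dt(t1) - Y_true(t1)||^2 ~ C * dt^(2p), so the slope
    # of 0.5 * log(MSE) against log(dt) is roughly p. Likewise, for weak order
    # q, |E[phi(Y_dt)] - E[phi(Y_true)]| ~ C * dt^q for a smooth test function
    # phi (test_func above), so the slope of log(MAE) against log(dt) is
    # roughly q.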
114 |     strong_order_slopes = [
115 |         utils.linregress_slope(utils.log(dts), .5 * utils.log(mses_for_method))
116 |         for mses_for_method in zip(*mses)
117 |     ]
118 | 
119 |     weak_order_slopes = [
120 |         utils.linregress_slope(utils.log(dts), utils.log(maes_for_method))
121 |         for maes_for_method in zip(*maes)
122 |     ]
123 | 
124 |     utils.swiss_knife_plotter(
125 |         img_path=os.path.join(img_dir, 'strong_order'),
126 |         plots=[
127 |             {'x': dts, 'y': mses_for_method, 'label': f'{label}(k={slope:.4f})', 'marker': 'x'}
128 |             for mses_for_method, label, slope in zip(zip(*mses), labels, strong_order_slopes)
129 |         ],
130 |         options={'xscale': 'log', 'yscale': 'log', 'cycle_line_style': True}
131 |     )
132 | 
133 |     utils.swiss_knife_plotter(
134 |         img_path=os.path.join(img_dir, 'weak_order'),
135 |         plots=[
136 |             {'x': dts, 'y': maes_for_method, 'label': f'{label}(k={slope:.4f})', 'marker': 'x'}
137 |             for maes_for_method, label, slope in zip(zip(*maes), labels, weak_order_slopes)
138 |         ],
139 |         options={'xscale': 'log', 'yscale': 'log', 'cycle_line_style': True}
140 |     )
141 |
--------------------------------------------------------------------------------
/diagnostics/ito_additive.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import torch
18 |
19 | from tests.problems import NeuralAdditive
20 | from torchsde import BrownianInterval
21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS
22 | from . import inspection
23 | from . import utils
24 |
25 |
26 | def main():
27 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 |     torch.set_default_dtype(torch.float64)
29 |     utils.manual_seed()
30 | 
31 |     small_batch_size, large_batch_size, d, m = 16, 16384, 3, 5
32 |     t0, t1, steps, dt = 0., 2., 10, 1e-1
33 |     ts = torch.linspace(t0, t1, steps=steps, device=device)
34 |     dts = tuple(2 ** -i for i in range(1, 7))  # For checking strong order.
35 |     sde = NeuralAdditive(d=d, m=m).to(device)
36 |     methods = ('euler', 'milstein', 'milstein', 'srk')
37 |     options = (None, None, dict(grad_free=True), None)
38 |     labels = ('euler', 'milstein', 'gradient-free milstein', 'srk')
39 |     img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'ito_additive')
40 | 
41 |     y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device)
42 |     bm = BrownianInterval(
43 |         t0=t0, t1=t1, size=(small_batch_size, m), dtype=y0.dtype, device=device,
44 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time
45 |     )
46 |     inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels)
47 | 
48 |     y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device)
49 |     bm = BrownianInterval(
50 |         t0=t0, t1=t1, size=(large_batch_size, m), dtype=y0.dtype, device=device,
51 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time
52 |     )
53 |     inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels)
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     main()
58 |
--------------------------------------------------------------------------------
/diagnostics/ito_diagonal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import torch
18 |
19 | from tests.problems import NeuralDiagonal
20 | from torchsde import BrownianInterval
21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS
22 | from . import inspection
23 | from . import utils
24 |
25 |
26 | def main():
27 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 |     torch.set_default_dtype(torch.float64)
29 |     utils.manual_seed()
30 | 
31 |     small_batch_size, large_batch_size, d = 16, 16384, 5
32 |     t0, t1, steps, dt = 0., 2., 10, 1e-1
33 |     ts = torch.linspace(t0, t1, steps=steps, device=device)
34 |     dts = tuple(2 ** -i for i in range(1, 7))  # For checking strong order.
35 |     sde = NeuralDiagonal(d=d).to(device)
36 |     methods = ('euler', 'milstein', 'milstein', 'srk')
37 |     options = (None, None, dict(grad_free=True), None)
38 |     labels = ('euler', 'milstein', 'gradient-free milstein', 'srk')
39 |     img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'ito_diagonal')
40 | 
41 |     y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device)
42 |     bm = BrownianInterval(
43 |         t0=t0, t1=t1, size=(small_batch_size, d), dtype=y0.dtype, device=device,
44 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time
45 |     )
46 |     inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels)
47 | 
48 |     y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device)
49 |     bm = BrownianInterval(
50 |         t0=t0, t1=t1, size=(large_batch_size, d), dtype=y0.dtype, device=device,
51 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time
52 |     )
53 |     inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels)
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     main()
58 |
--------------------------------------------------------------------------------
/diagnostics/ito_general.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import torch
18 |
19 | from tests.problems import NeuralGeneral
20 | from torchsde import BrownianInterval
21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS
22 | from . import inspection
23 | from . import utils
24 |
25 |
26 | def main():
27 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 |     torch.set_default_dtype(torch.float64)
29 |     utils.manual_seed()
30 | 
31 |     small_batch_size, large_batch_size, d, m = 16, 16384, 3, 5
32 |     t0, t1, steps, dt = 0., 2., 10, 1e-1
33 |     ts = torch.linspace(t0, t1, steps=steps, device=device)
34 |     dts = tuple(2 ** -i for i in range(1, 7))  # For checking strong order.
35 |     sde = NeuralGeneral(d=d, m=m).to(device)
36 |     methods = ('euler',)
37 |     options = (None,)
38 |     labels = ('euler',)
39 |     img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'ito_general')
40 | 
41 |     y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device)
42 |     bm = BrownianInterval(
43 |         t0=t0, t1=t1, size=(small_batch_size, m), dtype=y0.dtype, device=device,
44 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
45 |     )
46 |     inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels)
47 | 
48 |     y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device)
49 |     bm = BrownianInterval(
50 |         t0=t0, t1=t1, size=(large_batch_size, m), dtype=y0.dtype, device=device,
51 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
52 |     )
53 |     inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels)
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     main()
58 |
--------------------------------------------------------------------------------
/diagnostics/ito_scalar.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import torch
18 |
19 | from tests.problems import NeuralScalar
20 | from torchsde import BrownianInterval
21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS
22 | from . import inspection
23 | from . import utils
24 |
25 |
26 | def main():
27 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 |     torch.set_default_dtype(torch.float64)
29 |     utils.manual_seed()
30 | 
31 |     small_batch_size, large_batch_size, d = 16, 16384, 5
32 |     t0, t1, steps, dt = 0., 2., 10, 1e-1
33 |     ts = torch.linspace(t0, t1, steps=steps, device=device)
34 |     dts = tuple(2 ** -i for i in range(1, 7))  # For checking strong order.
35 |     sde = NeuralScalar(d=d).to(device)
36 |     methods = ('euler', 'milstein', 'milstein', 'srk')
37 |     options = (None, None, dict(grad_free=True), None)
38 |     labels = ('euler', 'milstein', 'gradient-free milstein', 'srk')
39 |     img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'ito_scalar')
40 | 
41 |     y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device)
42 |     bm = BrownianInterval(
43 |         t0=t0, t1=t1, size=(small_batch_size, 1), dtype=y0.dtype, device=device,
44 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time
45 |     )
46 |     inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels)
47 | 
48 |     y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device)
49 |     bm = BrownianInterval(
50 |         t0=t0, t1=t1, size=(large_batch_size, 1), dtype=y0.dtype, device=device,
51 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time
52 |     )
53 |     inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels)
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     main()
58 |
--------------------------------------------------------------------------------
/diagnostics/run_all.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import ito_additive, ito_diagonal, ito_general, ito_scalar
16 | from . import stratonovich_additive, stratonovich_diagonal, stratonovich_general, stratonovich_scalar
17 |
18 | if __name__ == '__main__':
19 |     for module in (ito_additive, ito_diagonal, ito_general, ito_scalar, stratonovich_additive, stratonovich_diagonal,
20 |                    stratonovich_general, stratonovich_scalar):
21 |         module.main()
22 |
--------------------------------------------------------------------------------
/diagnostics/stratonovich_additive.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import torch
18 |
19 | from tests.problems import NeuralAdditive
20 | from torchsde import BrownianInterval
21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS, SDE_TYPES
22 | from . import inspection
23 | from . import utils
24 |
25 |
26 | def main():
27 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 |     torch.set_default_dtype(torch.float64)
29 |     utils.manual_seed()
30 | 
31 |     small_batch_size, large_batch_size, d, m = 16, 16384, 3, 5
32 |     t0, t1, steps, dt = 0., 2., 10, 1e-1
33 |     ts = torch.linspace(t0, t1, steps=steps, device=device)
34 |     dts = tuple(2 ** -i for i in range(1, 7))  # For checking strong order.
35 |     sde = NeuralAdditive(d=d, m=m, sde_type=SDE_TYPES.stratonovich).to(device)
36 |     # Milstein and log-ODE should in theory both superfluously compute zeros here.
37 |     # We include them anyway to check that they do what they claim to do.
38 |     methods = ('euler_heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'milstein', 'log_ode')
39 |     options = (None, None, None, None, None, dict(grad_free=True), None)
40 |     labels = ('euler-heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'grad-free milstein', 'log_ode')
41 |     img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'stratonovich_additive')
42 | 
43 |     y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device)
44 |     bm = BrownianInterval(
45 |         t0=t0, t1=t1, size=(small_batch_size, m), dtype=y0.dtype, device=device,
46 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
47 |     )
48 |     inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels)
49 | 
50 |     y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device)
51 |     bm = BrownianInterval(
52 |         t0=t0, t1=t1, size=(large_batch_size, m), dtype=y0.dtype, device=device,
53 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
54 |     )
55 |     inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels)
56 | 
57 | 
58 | if __name__ == '__main__':
59 |     main()
60 |
--------------------------------------------------------------------------------
/diagnostics/stratonovich_diagonal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import torch
18 |
19 | from tests.problems import NeuralDiagonal
20 | from torchsde import BrownianInterval
21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS, SDE_TYPES
22 | from . import inspection
23 | from . import utils
24 |
25 |
26 | def main():
27 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 |     torch.set_default_dtype(torch.float64)
29 |     utils.manual_seed()
30 | 
31 |     small_batch_size, large_batch_size, d = 16, 16384, 3
32 |     t0, t1, steps, dt = 0., 2., 10, 1e-1
33 |     ts = torch.linspace(t0, t1, steps=steps, device=device)
34 |     dts = tuple(2 ** -i for i in range(1, 7))  # For checking strong order.
35 |     sde = NeuralDiagonal(d=d, sde_type=SDE_TYPES.stratonovich).to(device)
36 |     methods = ('euler_heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'milstein', 'log_ode')
37 |     options = (None, None, None, None, None, dict(grad_free=True), None)
38 |     labels = ('euler-heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'grad-free milstein', 'log_ode')
39 |     img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'stratonovich_diagonal')
40 | 
41 |     y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device)
42 |     bm = BrownianInterval(
43 |         t0=t0, t1=t1, size=(small_batch_size, d), dtype=y0.dtype, device=device,
44 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
45 |     )
46 |     inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels)
47 | 
48 |     y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device)
49 |     bm = BrownianInterval(
50 |         t0=t0, t1=t1, size=(large_batch_size, d), dtype=y0.dtype, device=device,
51 |         levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
52 |     )
53 |     inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels)
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     main()
58 |
--------------------------------------------------------------------------------
/diagnostics/stratonovich_general.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import torch
18 |
19 | from tests.problems import NeuralGeneral
20 | from torchsde import BrownianInterval
21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS, SDE_TYPES
22 | from . import inspection
23 | from . import utils
24 |
25 |
26 | def main():
27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 | torch.set_default_dtype(torch.float64)
29 | utils.manual_seed()
30 |
31 | small_batch_size, large_batch_size, d, m = 16, 16384, 3, 5
32 | t0, t1, steps, dt = 0., 2., 10, 1e-1
33 | ts = torch.linspace(t0, t1, steps=steps, device=device)
34 | dts = tuple(2 ** -i for i in range(1, 7)) # For checking strong order.
35 | sde = NeuralGeneral(d=d, m=m, sde_type=SDE_TYPES.stratonovich).to(device)
36 | # Don't include Milstein as it doesn't work for general noise.
37 | methods = ('euler_heun', 'heun', 'midpoint', 'reversible_heun', 'log_ode')
38 | options = (None, None, None, None, None)
39 | labels = ('euler-heun', 'heun', 'midpoint', 'reversible_heun', 'log_ode')
40 |
41 | img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'stratonovich_general')
42 |
43 | y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device)
44 | bm = BrownianInterval(
45 | t0=t0, t1=t1, size=(small_batch_size, m), dtype=y0.dtype, device=device,
46 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
47 | )
48 | inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels)
49 |
50 | y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device)
51 | bm = BrownianInterval(
52 | t0=t0, t1=t1, size=(large_batch_size, m), dtype=y0.dtype, device=device,
53 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
54 | )
55 | inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels)
56 |
57 |
58 | if __name__ == '__main__':
59 | main()
60 |
--------------------------------------------------------------------------------
/diagnostics/stratonovich_scalar.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import torch
18 |
19 | from tests.problems import NeuralScalar
20 | from torchsde import BrownianInterval
21 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS, SDE_TYPES
22 | from . import inspection
23 | from . import utils
24 |
25 |
26 | def main():
27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 | torch.set_default_dtype(torch.float64)
29 | utils.manual_seed()
30 |
31 | small_batch_size, large_batch_size, d = 16, 16384, 3
32 | t0, t1, steps, dt = 0., 2., 10, 1e-1
33 | ts = torch.linspace(t0, t1, steps=steps, device=device)
34 | dts = tuple(2 ** -i for i in range(1, 7)) # For checking strong order.
35 | sde = NeuralScalar(d=d, sde_type=SDE_TYPES.stratonovich).to(device)
36 | methods = ('euler_heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'milstein', 'log_ode')
37 | options = (None, None, None, None, None, dict(grad_free=True), None)
38 | labels = ('euler-heun', 'heun', 'midpoint', 'reversible_heun', 'milstein', 'grad-free milstein', 'log_ode')
39 | img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'stratonovich_scalar')
40 |
41 | y0 = torch.full((small_batch_size, d), fill_value=0.1, device=device)
42 | bm = BrownianInterval(
43 | t0=t0, t1=t1, size=(small_batch_size, 1), dtype=y0.dtype, device=device,
44 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
45 | )
46 | inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods, options, labels)
47 |
48 | y0 = torch.full((large_batch_size, d), fill_value=0.1, device=device)
49 | bm = BrownianInterval(
50 | t0=t0, t1=t1, size=(large_batch_size, 1), dtype=y0.dtype, device=device,
51 | levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
52 | )
53 | inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods, options, labels)
54 |
55 |
56 | if __name__ == '__main__':
57 | main()
58 |
--------------------------------------------------------------------------------
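The three Stratonovich diagnostics above share one recipe: fix a single Brownian sample, solve at each step size in `dts`, and read the strong order off the log-log slope of error against step size. A minimal standalone sketch of that recipe, using geometric Brownian motion (a made-up toy problem, not one of the repo's test problems) so that an exact solution is available as the reference:

import numpy as np
import torch
import torchsde
from scipy import stats


class GBM(torch.nn.Module):
    # Geometric Brownian motion: dY = mu * Y dt + sigma * Y dW, with a known exact solution.
    noise_type = "diagonal"
    sde_type = "ito"
    mu, sigma = 0.5, 0.4

    def f(self, t, y):
        return self.mu * y

    def g(self, t, y):
        return self.sigma * y


torch.manual_seed(0)
torch.set_default_dtype(torch.float64)
batch_size, d, t0, t1 = 512, 1, 0., 1.
sde = GBM()
y0 = torch.full((batch_size, d), 0.1)
# One Brownian sample, reused across all step sizes so the errors are comparable.
bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=(batch_size, d))

dts = tuple(2 ** -i for i in range(1, 7))
errors = []
for dt in dts:
    y1 = torchsde.sdeint(sde, y0, torch.tensor([t0, t1]), dt=dt, bm=bm, method="milstein")[-1]
    w1 = bm(t0, t1)  # the same Brownian increment the solver consumed
    exact = y0 * torch.exp((sde.mu - 0.5 * sde.sigma ** 2) * (t1 - t0) + sde.sigma * w1)
    errors.append((y1 - exact).norm(dim=1).pow(2).mean().sqrt().item())

# Strong order is the slope of log(error) against log(dt); ~1 for Milstein with diagonal noise.
slope = stats.linregress(np.log(dts), np.log(errors)).slope
print(f"estimated strong order: {slope:.2f}")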
/diagnostics/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import itertools
16 | import os
17 | import random
18 |
19 | import matplotlib.pyplot as plt
20 | import numpy as np
21 | import torch
22 | from scipy import stats
23 |
24 | from torchsde.types import Optional, Tensor, Sequence, Union, Callable
25 |
26 |
27 | def to_numpy(*args):
28 | """Convert a sequence which might contain Tensors to numpy arrays."""
29 | if len(args) == 1:
30 | arg = args[0]
31 | if isinstance(arg, torch.Tensor):
32 | arg = _to_numpy_single(arg)
33 | return arg
34 | else:
35 | return tuple(_to_numpy_single(arg) if isinstance(arg, torch.Tensor) else arg for arg in args)
36 |
37 |
38 | def _to_numpy_single(arg: torch.Tensor) -> np.ndarray:
39 | return arg.detach().cpu().numpy()
40 |
41 |
42 | def mse(x: Tensor, y: Tensor, norm_dim: Optional[int] = 1, mean_dim: Optional[int] = 0) -> np.ndarray:
43 | """Compute mean squared error."""
44 | return _to_numpy_single((torch.norm(x - y, dim=norm_dim) ** 2).mean(dim=mean_dim))
45 |
46 |
47 | def mae(x: Tensor, y: Tensor, test_func: Callable, mean_dim: Optional[int] = 0) -> np.ndarray:
48 | return _to_numpy_single(
49 | abs(test_func(x).mean(mean_dim) - test_func(y).mean(mean_dim))
50 | )
51 |
52 |
53 | def log(x: Union[Sequence[float], np.ndarray]) -> np.ndarray:
54 | """Compute element-wise log of a sequence of floats."""
55 | return np.log(np.array(x))
56 |
57 |
58 | def linregress_slope(x, y):
59 | """Return the slope of a least-squares regression for two sets of measurements."""
60 | return stats.linregress(x, y)[0]
61 |
62 |
63 | def swiss_knife_plotter(img_path, plots=None, scatters=None, hists=None, options=None):
64 | """A multi-functional *standalone* wrapper; reduces boilerplate.
65 |
66 | Args:
67 | img_path (str): A path to the place where the image should be written.
68 |         plots (list of dict, optional): A list of curves to be drawn with `plt.plot`.
69 |         scatters (list of dict, optional): A list of scatter plots to be drawn with `plt.scatter`.
70 |         hists (list of dict, optional): A list of histograms to be drawn with `plt.hist`.
71 | options (dict, optional): A dictionary of optional arguments. Possible entries include
72 | - xscale (str): Scale of xaxis.
73 | - yscale (str): Scale of yaxis.
74 | - xlabel (str): Label of xaxis.
75 | - ylabel (str): Label of yaxis.
76 | - title (str): Title of the plot.
77 | - cycle_linestyle (bool): Cycle through matplotlib's possible line styles if True.
78 |
79 | Returns:
80 | Nothing.
81 | """
82 | img_dir = os.path.dirname(img_path)
83 | if not os.path.exists(img_dir):
84 | os.makedirs(img_dir)
85 |
86 | if plots is None: plots = ()
87 | if scatters is None: scatters = ()
88 | if hists is None: hists = ()
89 | if options is None: options = {}
90 |
91 | plt.figure(dpi=300)
92 | if 'xscale' in options: plt.xscale(options['xscale'])
93 | if 'yscale' in options: plt.yscale(options['yscale'])
94 | if 'xlabel' in options: plt.xlabel(options['xlabel'])
95 | if 'ylabel' in options: plt.ylabel(options['ylabel'])
96 | if 'title' in options: plt.title(options['title'])
97 |
98 | cycle_linestyle = options.get('cycle_linestyle', False)
99 | cycler = itertools.cycle(["-", "--", "-.", ":"]) if cycle_linestyle else None
100 | for entry in plots:
101 | kwargs = {key: entry[key] for key in entry if key != 'x' and key != 'y'}
102 | entry['x'], entry['y'] = to_numpy(entry['x'], entry['y'])
103 | if cycle_linestyle:
104 | kwargs['linestyle'] = next(cycler)
105 | plt.plot(entry['x'], entry['y'], **kwargs)
106 |
107 | for entry in scatters:
108 | kwargs = {key: entry[key] for key in entry if key != 'x' and key != 'y'}
109 | entry['x'], entry['y'] = to_numpy(entry['x'], entry['y'])
110 | plt.scatter(entry['x'], entry['y'], **kwargs)
111 |
112 | for entry in hists:
113 | kwargs = {key: entry[key] for key in entry if key != 'x'}
114 | entry['x'] = to_numpy(entry['x'])
115 | plt.hist(entry['x'], **kwargs)
116 |
117 | if len(plots) > 0 or len(scatters) > 0: plt.legend()
118 | plt.tight_layout()
119 | plt.savefig(img_path)
120 | plt.close()
121 |
122 |
123 | def manual_seed(seed: Optional[int] = 1147481649):
124 | """Set seeds for default generators of 1) torch, 2) numpy, and 3) Python's random library."""
125 | torch.manual_seed(seed)
126 | np.random.seed(seed)
127 | random.seed(seed)
128 |
--------------------------------------------------------------------------------
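A hypothetical usage sketch of `swiss_knife_plotter` above, assuming the repo root is on the import path; the step sizes and errors below are made up purely for illustration:

from diagnostics.utils import swiss_knife_plotter

dts = (0.5, 0.25, 0.125, 0.0625)
errors = (0.21, 0.11, 0.052, 0.027)  # made-up errors, decaying at roughly order 1

swiss_knife_plotter(
    img_path='plots/demo/strong_order.png',
    plots=[
        {'x': dts, 'y': errors, 'label': 'milstein'},
        {'x': dts, 'y': [0.4 * dt for dt in dts], 'label': 'order-1 reference'},
    ],
    options={'xscale': 'log', 'yscale': 'log', 'xlabel': 'dt', 'ylabel': 'error', 'title': 'Strong order'},
)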
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/torchsde/eb3a00e31cbd56176270066ed2f62c394cf6acb7/examples/__init__.py
--------------------------------------------------------------------------------
/examples/cont_ddpm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A min example for continuous-time Denoising Diffusion Probabilistic Models.
16 |
17 | Trains the backward dynamics to be close to the reverse of a fixed forward
18 | dynamics via a score-matching-type objective.
19 |
20 | Trains a simple model on MNIST and samples from both the reverse ODE and
21 | SDE formulation.
22 |
23 | To run this file, first run the following to install extra requirements:
24 | pip install kornia
25 | pip install einops
26 | pip install torchdiffeq
27 | pip install fire
28 |
29 | To run, execute:
30 | python -m examples.cont_ddpm
31 | """
32 | import abc
33 | import logging
34 | import math
35 | import os
36 |
37 | import fire
38 | import torch
39 | import torchdiffeq
40 | import torchvision as tv
41 | import tqdm
42 | from torch import nn, optim
43 | from torch.utils import data
44 |
45 | import torchsde
46 | from . import unet
47 |
48 |
49 | def fill_tail_dims(y: torch.Tensor, y_like: torch.Tensor):
50 | """Fill in missing trailing dimensions for y according to y_like."""
51 | return y[(...,) + (None,) * (y_like.dim() - y.dim())]
52 |
53 |
54 | class Module(abc.ABC, nn.Module):
55 | """A wrapper module that's more convenient to use."""
56 |
57 | def __init__(self):
58 | super(Module, self).__init__()
59 | self._checkpoint = False
60 |
61 | def zero_grad(self) -> None:
62 | for p in self.parameters(): p.grad = None
63 |
64 | @property
65 | def device(self):
66 | return next(self.parameters()).device
67 |
68 |
69 | class ScoreMatchingSDE(Module):
70 |     """Wraps the score network with analytical sampling and conditional score computation.
71 |
72 | The variance preserving formulation in
73 | Score-Based Generative Modeling through Stochastic Differential Equations
74 | https://arxiv.org/abs/2011.13456
75 | """
76 |
77 | def __init__(self, denoiser, input_size=(1, 28, 28), t0=0., t1=1., beta_min=.1, beta_max=20.):
78 | super(ScoreMatchingSDE, self).__init__()
79 | if t0 > t1:
80 | raise ValueError(f"Expected t0 <= t1, but found t0={t0:.4f}, t1={t1:.4f}")
81 |
82 | self.input_size = input_size
83 | self.denoiser = denoiser
84 |
85 | self.t0 = t0
86 | self.t1 = t1
87 |
88 | self.beta_min = beta_min
89 | self.beta_max = beta_max
90 |
91 | def score(self, t, y):
92 | if isinstance(t, float):
93 | t = y.new_tensor(t)
94 | if t.dim() == 0:
95 | t = t.repeat(y.shape[0])
96 | return self.denoiser(t, y)
97 |
98 | def _beta(self, t):
99 | return self.beta_min + t * (self.beta_max - self.beta_min)
100 |
101 | def _indefinite_int(self, t):
102 | """Indefinite integral of beta(t)."""
103 | return self.beta_min * t + .5 * t ** 2 * (self.beta_max - self.beta_min)
104 |
105 | def analytical_mean(self, t, x_t0):
106 | mean_coeff = (-.5 * (self._indefinite_int(t) - self._indefinite_int(self.t0))).exp()
107 | mean = x_t0 * fill_tail_dims(mean_coeff, x_t0)
108 | return mean
109 |
110 | def analytical_var(self, t, x_t0):
111 | analytical_var = 1 - (-self._indefinite_int(t) + self._indefinite_int(self.t0)).exp()
112 | return analytical_var
113 |
114 | @torch.no_grad()
115 | def analytical_sample(self, t, x_t0):
116 | mean = self.analytical_mean(t, x_t0)
117 | var = self.analytical_var(t, x_t0)
118 | return mean + torch.randn_like(mean) * fill_tail_dims(var.sqrt(), mean)
119 |
120 | @torch.no_grad()
121 | def analytical_score(self, x_t, t, x_t0):
122 | mean = self.analytical_mean(t, x_t0)
123 | var = self.analytical_var(t, x_t0)
124 | return - (x_t - mean) / fill_tail_dims(var, mean).clamp_min(1e-5)
125 |
126 | def f(self, t, y):
127 | return -0.5 * self._beta(t) * y
128 |
129 | def g(self, t, y):
130 | return fill_tail_dims(self._beta(t).sqrt(), y).expand_as(y)
131 |
132 | def sample_t1_marginal(self, batch_size, tau=1.):
133 | return torch.randn(size=(batch_size, *self.input_size), device=self.device) * math.sqrt(tau)
134 |
135 | def lambda_t(self, t):
136 | return self.analytical_var(t, None)
137 |
138 | def forward(self, x_t0, partitions=1):
139 | """Compute the score matching objective.
140 | Split [t0, t1] into partitions; sample uniformly on each partition to reduce gradient variance.
141 | """
142 | u = torch.rand(size=(x_t0.shape[0], partitions), dtype=x_t0.dtype, device=x_t0.device)
143 | u.mul_((self.t1 - self.t0) / partitions)
144 | shifts = torch.arange(0, partitions, device=x_t0.device, dtype=x_t0.dtype)[None, :]
145 | shifts.mul_((self.t1 - self.t0) / partitions).add_(self.t0)
146 | t = (u + shifts).reshape(-1)
147 | lambda_t = self.lambda_t(t)
148 |
149 | x_t0 = x_t0.repeat_interleave(partitions, dim=0)
150 | x_t = self.analytical_sample(t, x_t0)
151 |
152 | fake_score = self.score(t, x_t)
153 | true_score = self.analytical_score(x_t, t, x_t0)
154 | loss = (lambda_t * ((fake_score - true_score) ** 2).flatten(start_dim=1).sum(dim=1))
155 | return loss
156 |
157 |
158 | class ReverseDiffeqWrapper(Module):
159 | """Wrapper of the score network for odeint/sdeint.
160 |
161 |     We split this module out so that `forward` of the score network is used solely
162 |     for computing the score, while the `forward` here is used for odeint.
163 |     This helps with data parallelism.
164 | """
165 | noise_type = "diagonal"
166 | sde_type = "stratonovich"
167 |
168 | def __init__(self, module: ScoreMatchingSDE):
169 | super(ReverseDiffeqWrapper, self).__init__()
170 | self.module = module
171 |
172 | # --- odeint ---
173 | def forward(self, t, y):
174 | return -(self.module.f(-t, y) - .5 * self.module.g(-t, y) ** 2 * self.module.score(-t, y))
175 |
176 | # --- sdeint ---
177 | def f(self, t, y):
178 | y = y.view(-1, *self.module.input_size)
179 | out = -(self.module.f(-t, y) - self.module.g(-t, y) ** 2 * self.module.score(-t, y))
180 | return out.flatten(start_dim=1)
181 |
182 | def g(self, t, y):
183 | y = y.view(-1, *self.module.input_size)
184 | out = -self.module.g(-t, y)
185 | return out.flatten(start_dim=1)
186 |
187 | # --- sample ---
188 | def sample_t1_marginal(self, batch_size, tau=1.):
189 | return self.module.sample_t1_marginal(batch_size, tau)
190 |
191 | @torch.no_grad()
192 | def ode_sample(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2):
193 | self.module.eval()
194 |
195 | t = torch.tensor([-self.t1, -self.t0], device=self.device) if t is None else t
196 | y = self.sample_t1_marginal(batch_size, tau) if y is None else y
197 | return torchdiffeq.odeint(self, y, t, method="rk4", options={"step_size": dt})
198 |
199 | @torch.no_grad()
200 | def ode_sample_final(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2):
201 | return self.ode_sample(batch_size, tau, t, y, dt)[-1]
202 |
203 | @torch.no_grad()
204 | def sde_sample(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2, tweedie_correction=True):
205 | self.module.eval()
206 |
207 | t = torch.tensor([-self.t1, -self.t0], device=self.device) if t is None else t
208 | y = self.sample_t1_marginal(batch_size, tau) if y is None else y
209 |
210 | ys = torchsde.sdeint(self, y.flatten(start_dim=1), t, dt=dt)
211 | ys = ys.view(len(t), *y.size())
212 | if tweedie_correction:
213 | ys[-1] = self.tweedie_correction(self.t0, ys[-1], dt)
214 | return ys
215 |
216 | @torch.no_grad()
217 | def sde_sample_final(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2):
218 | return self.sde_sample(batch_size, tau, t, y, dt)[-1]
219 |
220 | def tweedie_correction(self, t, y, dt):
221 | return y + dt ** 2 * self.module.score(t, y)
222 |
223 | @property
224 | def t0(self):
225 | return self.module.t0
226 |
227 | @property
228 | def t1(self):
229 | return self.module.t1
230 |
231 |
232 | def preprocess(x, logit_transform, alpha=0.95):
233 | if logit_transform:
234 | x = alpha + (1 - 2 * alpha) * x
235 | x = (x / (1 - x)).log()
236 | else:
237 | x = (x - 0.5) * 2
238 | return x
239 |
240 |
241 | def postprocess(x, logit_transform, alpha=0.95, clamp=True):
242 | if logit_transform:
243 | x = (x.sigmoid() - alpha) / (1 - 2 * alpha)
244 | else:
245 | x = x * 0.5 + 0.5
246 | return x.clamp(min=0., max=1.) if clamp else x
247 |
248 |
249 | def make_loader(
250 | root="./data/mnist",
251 | train_batch_size=128,
252 | shuffle=True,
253 | pin_memory=True,
254 | num_workers=0,
255 | drop_last=True
256 | ):
257 | """Make a simple loader for training images in MNIST."""
258 |
259 | def dequantize(x, nvals=256):
260 |         """[0, 1] -> [0, nvals - 1] -> add uniform noise -> [0, 1]"""
261 | noise = x.new().resize_as_(x).uniform_()
262 | x = x * (nvals - 1) + noise
263 | x = x / nvals
264 | return x
265 |
266 | train_transform = tv.transforms.Compose([tv.transforms.ToTensor(), dequantize])
267 | train_data = tv.datasets.MNIST(root, train=True, transform=train_transform, download=True)
268 | train_loader = data.DataLoader(
269 | train_data,
270 | batch_size=train_batch_size,
271 | drop_last=drop_last,
272 | shuffle=shuffle,
273 | pin_memory=pin_memory,
274 | num_workers=num_workers
275 | )
276 | return train_loader
277 |
278 |
279 | def main(
280 | train_dir="./dump/cont_ddpm/",
281 | epochs=100,
282 | lr=1e-4,
283 | batch_size=128,
284 | pause_every=1000,
285 | tau=1.,
286 | logit_transform=True,
287 | ):
288 | """Train and sample once in a while.
289 |
290 | Args:
291 | train_dir: Path to a folder to dump things.
292 | epochs: Number of training epochs.
293 | lr: Learning rate for Adam.
294 | batch_size: Batch size for training.
295 | pause_every: Log and write figures once in this many iterations.
296 | tau: The temperature for sampling.
297 | logit_transform: Applies the typical logit transformation if True.
298 | """
299 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
300 |
301 | # Data.
302 | train_loader = make_loader(root=os.path.join(train_dir, 'data'), train_batch_size=batch_size)
303 |
304 | # Model + optimizer.
305 | denoiser = unet.Unet(
306 | input_size=(1, 28, 28),
307 | dim_mults=(1, 2, 4,),
308 | attention_cls=unet.LinearTimeSelfAttention,
309 | )
310 | forward = ScoreMatchingSDE(denoiser=denoiser).to(device)
311 | reverse = ReverseDiffeqWrapper(forward)
312 | optimizer = optim.Adam(params=forward.parameters(), lr=lr)
313 |
314 | def plot(imgs, path):
315 | assert not torch.any(torch.isnan(imgs)), "Found nans in images"
316 | os.makedirs(os.path.dirname(path), exist_ok=True)
317 | imgs = postprocess(imgs, logit_transform=logit_transform).detach().cpu()
318 | tv.utils.save_image(imgs, path)
319 |
320 | global_step = 0
321 | for epoch in range(epochs):
322 | for x, _ in tqdm.tqdm(train_loader):
323 | forward.train()
324 | forward.zero_grad()
325 | x = preprocess(x.to(device), logit_transform=logit_transform)
326 | loss = forward(x).mean(dim=0)
327 | loss.backward()
328 | optimizer.step()
329 | global_step += 1
330 |
331 | if global_step % pause_every == 0:
332 | logging.warning(f'global_step: {global_step:06d}, loss: {loss:.4f}')
333 |
334 | img_path = os.path.join(train_dir, 'ode_samples', f'global_step_{global_step:07d}.png')
335 | ode_samples = reverse.ode_sample_final(tau=tau)
336 | plot(ode_samples, img_path)
337 |
338 | img_path = os.path.join(train_dir, 'sde_samples', f'global_step_{global_step:07d}.png')
339 | sde_samples = reverse.sde_sample_final(tau=tau)
340 | plot(sde_samples, img_path)
341 |
342 |
343 | if __name__ == "__main__":
344 | fire.Fire(main)
345 |
--------------------------------------------------------------------------------
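A standalone numerical check (plain Python/torch, independent of the classes above) of the variance-preserving marginals that `analytical_mean` and `analytical_var` implement: with beta(t) = beta_min + t * (beta_max - beta_min), the marginal mean coefficient is exp(-0.5 * int_beta) and the marginal variance is 1 - exp(-int_beta):

import math
import torch

beta_min, beta_max, t0 = 0.1, 20., 0.

def indefinite_int(t):
    # Indefinite integral of beta(t) = beta_min + t * (beta_max - beta_min).
    return beta_min * t + .5 * t ** 2 * (beta_max - beta_min)

t = 0.7
int_beta = indefinite_int(t) - indefinite_int(t0)
mean_coeff = math.exp(-.5 * int_beta)  # x_t mean = mean_coeff * x_t0
var = 1 - math.exp(-int_beta)          # x_t variance, independent of x_t0

# Cross-check the closed-form integral against numerical quadrature
# (beta is linear in t, so the trapezoid rule is exact up to rounding).
ts = torch.linspace(t0, t, 100001, dtype=torch.float64)
int_beta_num = torch.trapz(beta_min + ts * (beta_max - beta_min), ts).item()
assert abs(int_beta - int_beta_num) < 1e-8
print(mean_coeff, var)  # the mean shrinks towards 0, the variance grows towards 1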
/examples/latent_sde_lorenz.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Train a latent SDE on data from a stochastic Lorenz attractor.
16 |
17 | Reproduce the toy example in Section 7.2 of https://arxiv.org/pdf/2001.01328.pdf
18 |
19 | To run this file, first run the following to install extra requirements:
20 | pip install fire
21 |
22 | To run, execute:
23 | python -m examples.latent_sde_lorenz
24 | """
25 | import logging
26 | import os
27 | from typing import Sequence
28 |
29 | import fire
30 | import matplotlib.gridspec as gridspec
31 | import matplotlib.pyplot as plt
32 | import numpy as np
33 | import torch
34 | import tqdm
35 | from torch import nn
36 | from torch import optim
37 | from torch.distributions import Normal
38 |
39 | import torchsde
40 |
41 |
42 | class LinearScheduler(object):
43 | def __init__(self, iters, maxval=1.0):
44 | self._iters = max(1, iters)
45 | self._val = maxval / self._iters
46 | self._maxval = maxval
47 |
48 | def step(self):
49 | self._val = min(self._maxval, self._val + self._maxval / self._iters)
50 |
51 | @property
52 | def val(self):
53 | return self._val
54 |
55 |
56 | class StochasticLorenz(object):
57 | """Stochastic Lorenz attractor.
58 |
59 | Used for simulating ground truth and obtaining noisy data.
60 |     Details described in Section 7.2 of https://arxiv.org/pdf/2001.01328.pdf
61 | Default a, b from https://openreview.net/pdf?id=HkzRQhR9YX
62 | """
63 | noise_type = "diagonal"
64 | sde_type = "ito"
65 |
66 | def __init__(self, a: Sequence = (10., 28., 8 / 3), b: Sequence = (.1, .28, .3)):
67 | super(StochasticLorenz, self).__init__()
68 | self.a = a
69 | self.b = b
70 |
71 | def f(self, t, y):
72 | x1, x2, x3 = torch.split(y, split_size_or_sections=(1, 1, 1), dim=1)
73 | a1, a2, a3 = self.a
74 |
75 | f1 = a1 * (x2 - x1)
76 | f2 = a2 * x1 - x2 - x1 * x3
77 | f3 = x1 * x2 - a3 * x3
78 | return torch.cat([f1, f2, f3], dim=1)
79 |
80 | def g(self, t, y):
81 | x1, x2, x3 = torch.split(y, split_size_or_sections=(1, 1, 1), dim=1)
82 | b1, b2, b3 = self.b
83 |
84 | g1 = x1 * b1
85 | g2 = x2 * b2
86 | g3 = x3 * b3
87 | return torch.cat([g1, g2, g3], dim=1)
88 |
89 | @torch.no_grad()
90 | def sample(self, x0, ts, noise_std, normalize):
91 |         """Sample data for training; optionally normalize and add observation noise."""
92 | xs = torchsde.sdeint(self, x0, ts)
93 | if normalize:
94 | mean, std = torch.mean(xs, dim=(0, 1)), torch.std(xs, dim=(0, 1))
95 | xs.sub_(mean).div_(std).add_(torch.randn_like(xs) * noise_std)
96 | return xs
97 |
98 |
99 | class Encoder(nn.Module):
100 | def __init__(self, input_size, hidden_size, output_size):
101 | super(Encoder, self).__init__()
102 | self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size)
103 | self.lin = nn.Linear(hidden_size, output_size)
104 |
105 | def forward(self, inp):
106 | out, _ = self.gru(inp)
107 | out = self.lin(out)
108 | return out
109 |
110 |
111 | class LatentSDE(nn.Module):
112 | sde_type = "ito"
113 | noise_type = "diagonal"
114 |
115 | def __init__(self, data_size, latent_size, context_size, hidden_size):
116 | super(LatentSDE, self).__init__()
117 | # Encoder.
118 | self.encoder = Encoder(input_size=data_size, hidden_size=hidden_size, output_size=context_size)
119 | self.qz0_net = nn.Linear(context_size, latent_size + latent_size)
120 |
121 | # Decoder.
122 | self.f_net = nn.Sequential(
123 | nn.Linear(latent_size + context_size, hidden_size),
124 | nn.Softplus(),
125 | nn.Linear(hidden_size, hidden_size),
126 | nn.Softplus(),
127 | nn.Linear(hidden_size, latent_size),
128 | )
129 | self.h_net = nn.Sequential(
130 | nn.Linear(latent_size, hidden_size),
131 | nn.Softplus(),
132 | nn.Linear(hidden_size, hidden_size),
133 | nn.Softplus(),
134 | nn.Linear(hidden_size, latent_size),
135 | )
136 | # This needs to be an element-wise function for the SDE to satisfy diagonal noise.
137 | self.g_nets = nn.ModuleList(
138 | [
139 | nn.Sequential(
140 | nn.Linear(1, hidden_size),
141 | nn.Softplus(),
142 | nn.Linear(hidden_size, 1),
143 | nn.Sigmoid()
144 | )
145 | for _ in range(latent_size)
146 | ]
147 | )
148 | self.projector = nn.Linear(latent_size, data_size)
149 |
150 | self.pz0_mean = nn.Parameter(torch.zeros(1, latent_size))
151 | self.pz0_logstd = nn.Parameter(torch.zeros(1, latent_size))
152 |
153 | self._ctx = None
154 |
155 | def contextualize(self, ctx):
156 | self._ctx = ctx # A tuple of tensors of sizes (T,), (T, batch_size, d).
157 |
158 | def f(self, t, y):
159 | ts, ctx = self._ctx
160 | i = min(torch.searchsorted(ts, t, right=True), len(ts) - 1)
161 | return self.f_net(torch.cat((y, ctx[i]), dim=1))
162 |
163 | def h(self, t, y):
164 | return self.h_net(y)
165 |
166 | def g(self, t, y): # Diagonal diffusion.
167 | y = torch.split(y, split_size_or_sections=1, dim=1)
168 | out = [g_net_i(y_i) for (g_net_i, y_i) in zip(self.g_nets, y)]
169 | return torch.cat(out, dim=1)
170 |
171 | def forward(self, xs, ts, noise_std, adjoint=False, method="euler"):
172 | # Contextualization is only needed for posterior inference.
173 | ctx = self.encoder(torch.flip(xs, dims=(0,)))
174 | ctx = torch.flip(ctx, dims=(0,))
175 | self.contextualize((ts, ctx))
176 |
177 | qz0_mean, qz0_logstd = self.qz0_net(ctx[0]).chunk(chunks=2, dim=1)
178 | z0 = qz0_mean + qz0_logstd.exp() * torch.randn_like(qz0_mean)
179 |
180 | if adjoint:
181 | # Must use the argument `adjoint_params`, since `ctx` is not part of the input to `f`, `g`, and `h`.
182 | adjoint_params = (
183 | (ctx,) +
184 | tuple(self.f_net.parameters()) + tuple(self.g_nets.parameters()) + tuple(self.h_net.parameters())
185 | )
186 | zs, log_ratio = torchsde.sdeint_adjoint(
187 | self, z0, ts, adjoint_params=adjoint_params, dt=1e-2, logqp=True, method=method)
188 | else:
189 | zs, log_ratio = torchsde.sdeint(self, z0, ts, dt=1e-2, logqp=True, method=method)
190 |
191 | _xs = self.projector(zs)
192 | xs_dist = Normal(loc=_xs, scale=noise_std)
193 | log_pxs = xs_dist.log_prob(xs).sum(dim=(0, 2)).mean(dim=0)
194 |
195 | qz0 = torch.distributions.Normal(loc=qz0_mean, scale=qz0_logstd.exp())
196 | pz0 = torch.distributions.Normal(loc=self.pz0_mean, scale=self.pz0_logstd.exp())
197 | logqp0 = torch.distributions.kl_divergence(qz0, pz0).sum(dim=1).mean(dim=0)
198 | logqp_path = log_ratio.sum(dim=0).mean(dim=0)
199 | return log_pxs, logqp0 + logqp_path
200 |
201 | @torch.no_grad()
202 | def sample(self, batch_size, ts, bm=None):
203 | eps = torch.randn(size=(batch_size, *self.pz0_mean.shape[1:]), device=self.pz0_mean.device)
204 | z0 = self.pz0_mean + self.pz0_logstd.exp() * eps
205 | zs = torchsde.sdeint(self, z0, ts, names={'drift': 'h'}, dt=1e-3, bm=bm)
206 |         # Most of the time in ML, we don't sample the observation noise for visualization purposes.
207 | _xs = self.projector(zs)
208 | return _xs
209 |
210 |
211 | def make_dataset(t0, t1, batch_size, noise_std, train_dir, device):
212 | data_path = os.path.join(train_dir, 'lorenz_data.pth')
213 | if os.path.exists(data_path):
214 | data_dict = torch.load(data_path)
215 | xs, ts = data_dict['xs'], data_dict['ts']
216 | logging.warning(f'Loaded toy data at: {data_path}')
217 | if xs.shape[1] != batch_size:
218 | raise ValueError("Batch size has changed; please delete and regenerate the data.")
219 | if ts[0] != t0 or ts[-1] != t1:
220 |             raise ValueError("Time interval [t0, t1] has changed; please delete and regenerate the data.")
221 | else:
222 | _y0 = torch.randn(batch_size, 3, device=device)
223 | ts = torch.linspace(t0, t1, steps=100, device=device)
224 | xs = StochasticLorenz().sample(_y0, ts, noise_std, normalize=True)
225 |
226 | os.makedirs(os.path.dirname(data_path), exist_ok=True)
227 | torch.save({'xs': xs, 'ts': ts}, data_path)
228 | logging.warning(f'Stored toy data at: {data_path}')
229 | return xs, ts
230 |
231 |
232 | def vis(xs, ts, latent_sde, bm_vis, img_path, num_samples=10):
233 | fig = plt.figure(figsize=(20, 9))
234 | gs = gridspec.GridSpec(1, 2)
235 | ax00 = fig.add_subplot(gs[0, 0], projection='3d')
236 | ax01 = fig.add_subplot(gs[0, 1], projection='3d')
237 |
238 | # Left plot: data.
239 | z1, z2, z3 = np.split(xs.cpu().numpy(), indices_or_sections=3, axis=-1)
240 | [ax00.plot(z1[:, i, 0], z2[:, i, 0], z3[:, i, 0]) for i in range(num_samples)]
241 |     ax00.scatter(z1[0, :num_samples, 0], z2[0, :num_samples, 0], z3[0, :num_samples, 0], marker='x')
242 | ax00.set_yticklabels([])
243 | ax00.set_xticklabels([])
244 | ax00.set_zticklabels([])
245 | ax00.set_xlabel('$z_1$', labelpad=0., fontsize=16)
246 | ax00.set_ylabel('$z_2$', labelpad=.5, fontsize=16)
247 | ax00.set_zlabel('$z_3$', labelpad=0., horizontalalignment='center', fontsize=16)
248 | ax00.set_title('Data', fontsize=20)
249 | xlim = ax00.get_xlim()
250 | ylim = ax00.get_ylim()
251 | zlim = ax00.get_zlim()
252 |
253 | # Right plot: samples from learned model.
254 | xs = latent_sde.sample(batch_size=xs.size(1), ts=ts, bm=bm_vis).cpu().numpy()
255 | z1, z2, z3 = np.split(xs, indices_or_sections=3, axis=-1)
256 |
257 | [ax01.plot(z1[:, i, 0], z2[:, i, 0], z3[:, i, 0]) for i in range(num_samples)]
258 |     ax01.scatter(z1[0, :num_samples, 0], z2[0, :num_samples, 0], z3[0, :num_samples, 0], marker='x')
259 | ax01.set_yticklabels([])
260 | ax01.set_xticklabels([])
261 | ax01.set_zticklabels([])
262 | ax01.set_xlabel('$z_1$', labelpad=0., fontsize=16)
263 | ax01.set_ylabel('$z_2$', labelpad=.5, fontsize=16)
264 | ax01.set_zlabel('$z_3$', labelpad=0., horizontalalignment='center', fontsize=16)
265 | ax01.set_title('Samples', fontsize=20)
266 | ax01.set_xlim(xlim)
267 | ax01.set_ylim(ylim)
268 | ax01.set_zlim(zlim)
269 |
270 | plt.savefig(img_path)
271 | plt.close()
272 |
273 |
274 | def main(
275 | batch_size=1024,
276 | latent_size=4,
277 | context_size=64,
278 | hidden_size=128,
279 | lr_init=1e-2,
280 | t0=0.,
281 | t1=2.,
282 | lr_gamma=0.997,
283 | num_iters=5000,
284 | kl_anneal_iters=1000,
285 | pause_every=50,
286 | noise_std=0.01,
287 | adjoint=False,
288 | train_dir='./dump/lorenz/',
289 | method="euler",
290 | ):
291 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
292 |
293 | xs, ts = make_dataset(t0=t0, t1=t1, batch_size=batch_size, noise_std=noise_std, train_dir=train_dir, device=device)
294 | latent_sde = LatentSDE(
295 | data_size=3,
296 | latent_size=latent_size,
297 | context_size=context_size,
298 | hidden_size=hidden_size,
299 | ).to(device)
300 | optimizer = optim.Adam(params=latent_sde.parameters(), lr=lr_init)
301 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=lr_gamma)
302 | kl_scheduler = LinearScheduler(iters=kl_anneal_iters)
303 |
304 | # Fix the same Brownian motion for visualization.
305 | bm_vis = torchsde.BrownianInterval(
306 | t0=t0, t1=t1, size=(batch_size, latent_size,), device=device, levy_area_approximation="space-time")
307 |
308 | for global_step in tqdm.tqdm(range(1, num_iters + 1)):
309 | latent_sde.zero_grad()
310 | log_pxs, log_ratio = latent_sde(xs, ts, noise_std, adjoint, method)
311 | loss = -log_pxs + log_ratio * kl_scheduler.val
312 | loss.backward()
313 | optimizer.step()
314 | scheduler.step()
315 | kl_scheduler.step()
316 |
317 | if global_step % pause_every == 0:
318 | lr_now = optimizer.param_groups[0]['lr']
319 | logging.warning(
320 | f'global_step: {global_step:06d}, lr: {lr_now:.5f}, '
321 |                 f'log_pxs: {log_pxs:.4f}, log_ratio: {log_ratio:.4f}, loss: {loss:.4f}, kl_coeff: {kl_scheduler.val:.4f}'
322 | )
323 | img_path = os.path.join(train_dir, f'global_step_{global_step:06d}.pdf')
324 | vis(xs, ts, latent_sde, bm_vis, img_path)
325 |
326 |
327 | if __name__ == "__main__":
328 | fire.Fire(main)
329 |
--------------------------------------------------------------------------------
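A small standalone illustration (plain torch) of the piecewise-constant context lookup that `LatentSDE.f` performs with `torch.searchsorted`; the grid and context values below are made up:

import torch

ts = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])  # the grid the encoder context lives on
ctx = torch.arange(5.)                         # stand-in for the (T, batch_size, d) context

def lookup(t):
    # right=True gives the index of the first grid point strictly greater than t;
    # the min(...) clamp keeps the index in range when the solver queries t = t1.
    i = min(torch.searchsorted(ts, torch.tensor(t), right=True), len(ts) - 1)
    return ctx[i]

print(lookup(0.7))  # tensor(2.): uses the context at ts[2] = 1.0
print(lookup(2.0))  # tensor(4.): clamped to the last grid point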
/examples/unet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """U-Nets for continuous-time Denoising Diffusion Probabilistic Models.
16 |
17 | This file only serves as a helper for `examples/cont_ddpm.py`.
18 |
19 | To use this file, run the following to install extra requirements:
20 |
21 | pip install kornia
22 | pip install einops
23 | """
24 | import math
25 |
26 | import kornia
27 | import torch
28 | import torch.nn.functional as F
29 | from einops import rearrange
30 | from torch import nn
31 |
32 |
33 | class Mish(nn.Module):
34 | def forward(self, x):
35 | return _mish(x)
36 |
37 |
38 | @torch.jit.script
39 | def _mish(x):
40 | return x * torch.tanh(F.softplus(x))
41 |
42 |
43 | class SinusoidalPosEmb(nn.Module):
44 | def __init__(self, dim):
45 | super().__init__()
46 | half_dim = dim // 2
47 | emb = math.log(10000) / (half_dim - 1)
48 | self.register_buffer('emb', torch.exp(torch.arange(half_dim) * -emb))
49 |
50 | def forward(self, x):
51 | emb = x[:, None] * self.emb[None, :]
52 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
53 | return emb
54 |
55 |
56 | class SelfAttention(nn.Module):
57 | def __init__(self, dim, groups=32, **kwargs):
58 | super().__init__()
59 | self.group_norm = nn.GroupNorm(groups, dim)
60 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
61 | self.out = nn.Conv2d(dim, dim, kernel_size=1)
62 |
63 | def forward(self, x):
64 | b, c, h, w = x.size()
65 | x = self.group_norm(x)
66 | q, k, v = tuple(t.view(b, c, h * w) for t in self.qkv(x).chunk(chunks=3, dim=1))
67 | attn_matrix = (torch.bmm(k.permute(0, 2, 1), q) / math.sqrt(c)).softmax(dim=-2)
68 | out = torch.bmm(v, attn_matrix).view(b, c, h, w)
69 | return self.out(out)
70 |
71 |
72 | class LinearTimeSelfAttention(nn.Module):
73 | def __init__(self, dim, heads=4, dim_head=32, groups=32):
74 | super().__init__()
75 | self.group_norm = nn.GroupNorm(groups, dim)
76 | self.heads = heads
77 | hidden_dim = dim_head * heads
78 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1)
79 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
80 |
81 | def forward(self, x):
82 | b, c, h, w = x.shape
83 | x = self.group_norm(x)
84 | qkv = self.to_qkv(x)
85 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
86 | k = k.softmax(dim=-1)
87 | context = torch.einsum('bhdn,bhen->bhde', k, v)
88 | out = torch.einsum('bhde,bhdn->bhen', context, q)
89 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
90 | return self.to_out(out)
91 |
92 |
93 | class ResnetBlock(nn.Module):
94 | def __init__(self, dim, dim_out, *, time_emb_dim, groups=8, dropout_rate=0.):
95 | super().__init__()
96 | # groups: Used in group norm.
97 | self.dim = dim
98 | self.dim_out = dim_out
99 | self.groups = groups
100 | self.dropout_rate = dropout_rate
101 | self.time_emb_dim = time_emb_dim
102 |
103 | self.mlp = nn.Sequential(
104 | Mish(),
105 | nn.Linear(time_emb_dim, dim_out)
106 | )
107 | # Norm -> non-linearity -> conv format follows
108 | # https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/nn.py#L55
109 | self.block1 = nn.Sequential(
110 | nn.GroupNorm(groups, dim),
111 | Mish(),
112 | nn.Conv2d(dim, dim_out, 3, padding=1),
113 | )
114 | self.block2 = nn.Sequential(
115 | nn.GroupNorm(groups, dim_out),
116 | Mish(),
117 | nn.Dropout(p=dropout_rate),
118 | nn.Conv2d(dim_out, dim_out, 3, padding=1),
119 | )
120 | self.res_conv = nn.Conv2d(dim, dim_out, 1)
121 |
122 | def forward(self, x, t):
123 | h = self.block1(x)
124 | h += self.mlp(t)[..., None, None]
125 | h = self.block2(h)
126 | return h + self.res_conv(x)
127 |
128 | def __repr__(self):
129 | return (f"{self.__class__.__name__}(dim={self.dim}, dim_out={self.dim_out}, time_emb_dim="
130 | f"{self.time_emb_dim}, groups={self.groups}, dropout_rate={self.dropout_rate})")
131 |
132 |
133 | class Residual(nn.Module):
134 | def __init__(self, fn):
135 | super().__init__()
136 | self.fn = fn
137 |
138 | def forward(self, x, *args, **kwargs):
139 | return self.fn(x, *args, **kwargs) + x
140 |
141 |
142 | class Blur(nn.Module):
143 | def __init__(self):
144 | super().__init__()
145 | f = torch.Tensor([1, 2, 1])
146 | f = f[None, None, :] * f[None, :, None]
147 | self.register_buffer('f', f)
148 |
149 | def forward(self, x):
150 | return kornia.filter2D(x, self.f, normalized=True)
151 |
152 |
153 | class Downsample(nn.Module):
154 | def __init__(self, dim, blur=True):
155 | super().__init__()
156 | if blur:
157 | self.conv = nn.Sequential(
158 | Blur(),
159 | nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1),
160 | )
161 | else:
162 | self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)
163 |
164 | def forward(self, x):
165 | return self.conv(x)
166 |
167 |
168 | class Upsample(nn.Module):
169 | def __init__(self, dim, blur=True):
170 | super().__init__()
171 | if blur:
172 | self.conv = nn.Sequential(
173 | nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1),
174 | Blur()
175 | )
176 | else:
177 | self.conv = nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)
178 |
179 | def forward(self, x):
180 | return self.conv(x)
181 |
182 |
183 | class Unet(nn.Module):
184 | def __init__(self,
185 | input_size=(3, 32, 32),
186 | hidden_channels=64,
187 | dim_mults=(1, 2, 4, 8),
188 | groups=32,
189 | heads=4,
190 | dim_head=32,
191 | dropout_rate=0.,
192 | num_res_blocks=2,
193 | attn_resolutions=(16,),
194 | attention_cls=SelfAttention):
195 | super().__init__()
196 | in_channels, in_height, in_width = input_size
197 | dims = [hidden_channels, *map(lambda m: hidden_channels * m, dim_mults)]
198 | in_out = list(zip(dims[:-1], dims[1:]))
199 |
200 | self.time_pos_emb = SinusoidalPosEmb(hidden_channels)
201 | self.mlp = nn.Sequential(
202 | nn.Linear(hidden_channels, hidden_channels * 4),
203 | Mish(),
204 | nn.Linear(hidden_channels * 4, hidden_channels)
205 | )
206 |
207 | self.first_conv = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1)
208 |
209 | h, w = in_height, in_width
210 | self.down_res_blocks = nn.ModuleList([])
211 | self.down_attn_blocks = nn.ModuleList([])
212 | self.down_spatial_blocks = nn.ModuleList([])
213 | for ind, (dim_in, dim_out) in enumerate(in_out):
214 | res_blocks = nn.ModuleList([
215 | ResnetBlock(
216 | dim=dim_in,
217 | dim_out=dim_out,
218 | time_emb_dim=hidden_channels,
219 | groups=groups,
220 | dropout_rate=dropout_rate
221 | )
222 | ])
223 | res_blocks.extend([
224 | ResnetBlock(
225 | dim=dim_out,
226 | dim_out=dim_out,
227 | time_emb_dim=hidden_channels,
228 | groups=groups,
229 | dropout_rate=dropout_rate
230 | ) for _ in range(num_res_blocks - 1)
231 | ])
232 | self.down_res_blocks.append(res_blocks)
233 |
234 | attn_blocks = nn.ModuleList([])
235 | if h in attn_resolutions and w in attn_resolutions:
236 | attn_blocks.extend(
237 | [Residual(attention_cls(dim_out, heads=heads, dim_head=dim_head, groups=groups))
238 | for _ in range(num_res_blocks)]
239 | )
240 | self.down_attn_blocks.append(attn_blocks)
241 |
242 | if ind < (len(in_out) - 1):
243 | spatial_blocks = nn.ModuleList([Downsample(dim_out)])
244 | h, w = h // 2, w // 2
245 | else:
246 | spatial_blocks = nn.ModuleList()
247 | self.down_spatial_blocks.append(spatial_blocks)
248 |
249 | mid_dim = dims[-1]
250 | self.mid_block1 = ResnetBlock(
251 | dim=mid_dim,
252 | dim_out=mid_dim,
253 | time_emb_dim=hidden_channels,
254 | groups=groups,
255 | dropout_rate=dropout_rate
256 | )
257 | self.mid_attn = Residual(attention_cls(mid_dim, heads=heads, dim_head=dim_head, groups=groups))
258 | self.mid_block2 = ResnetBlock(
259 | dim=mid_dim,
260 | dim_out=mid_dim,
261 | time_emb_dim=hidden_channels,
262 | groups=groups,
263 | dropout_rate=dropout_rate
264 | )
265 |
266 | self.ups_res_blocks = nn.ModuleList([])
267 | self.ups_attn_blocks = nn.ModuleList([])
268 | self.ups_spatial_blocks = nn.ModuleList([])
269 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
270 | res_blocks = nn.ModuleList([
271 | ResnetBlock(
272 | dim=dim_out * 2,
273 | dim_out=dim_out,
274 | time_emb_dim=hidden_channels,
275 | groups=groups,
276 | dropout_rate=dropout_rate
277 | ) for _ in range(num_res_blocks)
278 | ])
279 | res_blocks.extend([
280 | ResnetBlock(
281 | dim=dim_out + dim_in,
282 | dim_out=dim_in,
283 | time_emb_dim=hidden_channels,
284 | groups=groups,
285 | dropout_rate=dropout_rate
286 | )
287 | ])
288 | self.ups_res_blocks.append(res_blocks)
289 |
290 | attn_blocks = nn.ModuleList([])
291 | if h in attn_resolutions and w in attn_resolutions:
292 | attn_blocks.extend(
293 | [Residual(attention_cls(dim_out, heads=heads, dim_head=dim_head, groups=groups))
294 | for _ in range(num_res_blocks)]
295 | )
296 | attn_blocks.append(
297 | Residual(attention_cls(dim_in, heads=heads, dim_head=dim_head, groups=groups))
298 | )
299 | self.ups_attn_blocks.append(attn_blocks)
300 |
301 | spatial_blocks = nn.ModuleList()
302 | if ind < (len(in_out) - 1):
303 | spatial_blocks.append(Upsample(dim_in))
304 | h, w = h * 2, w * 2
305 | self.ups_spatial_blocks.append(spatial_blocks)
306 |
307 | self.final_conv = nn.Sequential(
308 | nn.GroupNorm(groups, hidden_channels),
309 | Mish(),
310 | nn.Conv2d(hidden_channels, in_channels, 1)
311 | )
312 |
313 | def forward(self, t, x):
314 | t = self.mlp(self.time_pos_emb(t))
315 |
316 | hs = [self.first_conv(x)]
317 | for i, (res_blocks, attn_blocks, spatial_blocks) in enumerate(
318 | zip(self.down_res_blocks, self.down_attn_blocks, self.down_spatial_blocks)):
319 | if len(attn_blocks) > 0:
320 | for res_block, attn_block in zip(res_blocks, attn_blocks):
321 | h = res_block(hs[-1], t)
322 | h = attn_block(h)
323 | hs.append(h)
324 | else:
325 | for res_block in res_blocks:
326 | h = res_block(hs[-1], t)
327 | hs.append(h)
328 | if len(spatial_blocks) > 0:
329 | spatial_block, = spatial_blocks
330 | hs.append(spatial_block(hs[-1]))
331 |
332 | h = hs[-1]
333 | h = self.mid_block1(h, t)
334 | h = self.mid_attn(h)
335 | h = self.mid_block2(h, t)
336 |
337 | for i, (res_blocks, attn_blocks, spatial_blocks) in enumerate(
338 | zip(self.ups_res_blocks, self.ups_attn_blocks, self.ups_spatial_blocks)):
339 | if len(attn_blocks) > 0:
340 | for res_block, attn_block in zip(res_blocks, attn_blocks):
341 | h = res_block(torch.cat((h, hs.pop()), dim=1), t)
342 | h = attn_block(h)
343 | else:
344 | for res_block in res_blocks:
345 | h = res_block(torch.cat((h, hs.pop()), dim=1), t)
346 | if len(spatial_blocks) > 0:
347 | spatial_block, = spatial_blocks
348 | h = spatial_block(h)
349 | return self.final_conv(h)
350 |
--------------------------------------------------------------------------------
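A shape-level sketch of the associativity trick behind `LinearTimeSelfAttention` above: normalizing the keys over positions lets the key-value contraction happen first, so the cost is linear in the number of spatial positions `n` rather than quadratic:

import torch

b, heads, d, n = 2, 4, 32, 28 * 28  # n = number of spatial positions (h * w)
q = torch.randn(b, heads, d, n)
k = torch.randn(b, heads, d, n).softmax(dim=-1)  # normalize over positions
v = torch.randn(b, heads, d, n)

# Contract k with v first: O(n * d^2) compute, never forming an n x n attention matrix.
context = torch.einsum('bhdn,bhen->bhde', k, v)    # (b, heads, d, d)
out = torch.einsum('bhde,bhdn->bhen', context, q)  # (b, heads, d, n)
assert out.shape == (b, heads, d, n)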
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | # These are the assumed default build requirements from pip:
3 | # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
4 | requires = ["setuptools>=40.8.0", "wheel"]
5 | build-backend = "setuptools.build_meta"
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import re
17 |
18 | import setuptools
19 |
20 | # For simplicity, the version is stored in the __version__ attribute in the source.
21 | here = os.path.realpath(os.path.dirname(__file__))
22 | with open(os.path.join(here, 'torchsde', '__init__.py')) as f:
23 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
24 | if meta_match:
25 | version = meta_match.group(1)
26 | else:
27 | raise RuntimeError("Unable to find __version__ string.")
28 |
29 | with open(os.path.join(here, 'README.md')) as f:
30 | readme = f.read()
31 |
32 | setuptools.setup(
33 | name="torchsde",
34 | version=version,
35 | author="Xuechen Li, Patrick Kidger",
36 | author_email="lxuechen@cs.stanford.edu, hello@kidger.site",
37 | description="SDE solvers and stochastic adjoint sensitivity analysis in PyTorch.",
38 | long_description=readme,
39 | long_description_content_type="text/markdown",
40 | url="https://github.com/google-research/torchsde",
41 | packages=setuptools.find_packages(exclude=['benchmarks', 'diagnostics', 'examples', 'tests']),
42 | install_requires=[
43 | "numpy>=1.19",
44 | "scipy>=1.5",
45 | "torch>=1.6.0",
46 | "trampoline>=0.1.2",
47 | ],
48 | python_requires='>=3.8',
49 | classifiers=[
50 | "Programming Language :: Python :: 3",
51 | "License :: OSI Approved :: Apache Software License",
52 | ],
53 | )
54 |
--------------------------------------------------------------------------------
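A quick standalone demonstration of the `__version__` extraction regex used above; the version string is a made-up stand-in for the real contents of `torchsde/__init__.py`:

import re

source = "__version__ = '0.2.5'\n"  # hypothetical file contents
match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", source, re.M)
assert match is not None and match.group(1) == '0.2.5'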
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/test_adjoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Compare gradients computed with adjoint vs analytical solution."""
16 | import sys
17 |
18 | sys.path = sys.path[1:] # A hack so that we always import the installed library.
19 |
20 | import pytest
21 | import torch
22 | import torchsde
23 | from torchsde.settings import LEVY_AREA_APPROXIMATIONS, METHODS, NOISE_TYPES, SDE_TYPES
24 |
25 | from . import utils
26 | from . import problems
27 |
28 | torch.manual_seed(1147481649)
29 | torch.set_default_dtype(torch.float64)
30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31 | dtype = torch.get_default_dtype()
32 |
33 |
34 | def _methods():
35 | yield SDE_TYPES.ito, METHODS.milstein, None
36 | yield SDE_TYPES.ito, METHODS.srk, None
37 | yield SDE_TYPES.stratonovich, METHODS.midpoint, None
38 | yield SDE_TYPES.stratonovich, METHODS.reversible_heun, None
39 |
40 |
41 | @pytest.mark.parametrize("sde_cls", [problems.ExDiagonal, problems.ExScalar, problems.ExAdditive,
42 | problems.NeuralGeneral])
43 | @pytest.mark.parametrize("sde_type, method, options", _methods())
44 | @pytest.mark.parametrize('adaptive', (False,))
45 | def test_against_numerical(sde_cls, sde_type, method, options, adaptive):
46 | # Skipping below, since method not supported for corresponding noise types.
47 | if sde_cls.noise_type == NOISE_TYPES.general and method in (METHODS.milstein, METHODS.srk):
48 | return
49 |
50 | d = 3
51 | m = {
52 | NOISE_TYPES.scalar: 1,
53 | NOISE_TYPES.diagonal: d,
54 | NOISE_TYPES.general: 2,
55 | NOISE_TYPES.additive: 2
56 | }[sde_cls.noise_type]
57 | batch_size = 4
58 | t0, t1 = ts = torch.tensor([0.0, 0.5], device=device)
59 | dt = 1e-3
60 | y0 = torch.full((batch_size, d), 0.1, device=device)
61 | sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device)
62 |
63 | if method == METHODS.srk:
64 | levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
65 | else:
66 | levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
67 | bm = torchsde.BrownianInterval(
68 | t0=t0, t1=t1, size=(batch_size, m), dtype=dtype, device=device,
69 | levy_area_approximation=levy_area_approximation
70 | )
71 |
72 | if method == METHODS.reversible_heun:
73 | tol = 1e-6
74 | adjoint_method = METHODS.adjoint_reversible_heun
75 | adjoint_options = options
76 | else:
77 | tol = 1e-2
78 | adjoint_method = None
79 | adjoint_options = None
80 |
81 | def func(inputs, modules):
82 | y0, sde = inputs[0], modules[0]
83 | ys = torchsde.sdeint_adjoint(sde, y0, ts, dt=dt, method=method, adjoint_method=adjoint_method,
84 | adaptive=adaptive, bm=bm, options=options, adjoint_options=adjoint_options)
85 | return (ys[-1] ** 2).sum(dim=1).mean(dim=0)
86 |
87 | # `grad_inputs=True` also works, but we really only care about grad wrt params and want fast tests.
88 | utils.gradcheck(func, y0, sde, eps=1e-6, rtol=tol, atol=tol, grad_params=True)
89 |
90 |
91 | def _methods_dt_tol():
92 | for sde_type, method, options in _methods():
93 | if method == METHODS.reversible_heun:
94 | yield sde_type, method, options, 2**-3, 1e-3, 1e-4
95 | yield sde_type, method, options, 1e-3, 1e-3, 1e-4
96 | else:
97 | yield sde_type, method, options, 1e-3, 1e-2, 1e-2
98 |
99 |
100 | @pytest.mark.parametrize("sde_cls", [problems.NeuralDiagonal, problems.NeuralScalar, problems.NeuralAdditive,
101 | problems.NeuralGeneral])
102 | @pytest.mark.parametrize("sde_type, method, options, dt, rtol, atol", _methods_dt_tol())
103 | @pytest.mark.parametrize("len_ts", [2, 9])
104 | def test_against_sdeint(sde_cls, sde_type, method, options, dt, rtol, atol, len_ts):
105 | # Skipping below, since method not supported for corresponding noise types.
106 | if sde_cls.noise_type == NOISE_TYPES.general and method in (METHODS.milstein, METHODS.srk):
107 | return
108 |
109 | d = 3
110 | m = {
111 | NOISE_TYPES.scalar: 1,
112 | NOISE_TYPES.diagonal: d,
113 | NOISE_TYPES.general: 2,
114 | NOISE_TYPES.additive: 2
115 | }[sde_cls.noise_type]
116 | batch_size = 4
117 | ts = torch.linspace(0.0, 1.0, len_ts, device=device, dtype=torch.float64)
118 | t0 = ts[0]
119 | t1 = ts[-1]
120 | y0 = torch.full((batch_size, d), 0.1, device=device, dtype=torch.float64, requires_grad=True)
121 | sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device, torch.float64)
122 |
123 | if method == METHODS.srk:
124 | levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
125 | else:
126 | levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
127 | bm = torchsde.BrownianInterval(
128 | t0=t0, t1=t1, size=(batch_size, m), dtype=torch.float64, device=device,
129 | levy_area_approximation=levy_area_approximation
130 | )
131 |
132 | if method == METHODS.reversible_heun:
133 | adjoint_method = METHODS.adjoint_reversible_heun
134 | adjoint_options = options
135 | else:
136 | adjoint_method = None
137 | adjoint_options = None
138 |
139 | ys_true = torchsde.sdeint(sde, y0, ts, dt=dt, method=method, bm=bm, options=options)
140 | grad = torch.randn_like(ys_true)
141 | ys_true.backward(grad)
142 |
143 | true_grad = torch.cat([y0.grad.view(-1)] + [param.grad.view(-1) for param in sde.parameters()])
144 | y0.grad.zero_()
145 | for param in sde.parameters():
146 | param.grad.zero_()
147 |
148 | ys_test = torchsde.sdeint_adjoint(sde, y0, ts, dt=dt, method=method, bm=bm, adjoint_method=adjoint_method,
149 | options=options, adjoint_options=adjoint_options)
150 | ys_test.backward(grad)
151 | test_grad = torch.cat([y0.grad.view(-1)] + [param.grad.view(-1) for param in sde.parameters()])
152 |
153 | torch.testing.assert_allclose(ys_true, ys_test)
154 | torch.testing.assert_allclose(true_grad, test_grad, rtol=rtol, atol=atol)
155 |
156 |
157 | @pytest.mark.parametrize("problem", [problems.BasicSDE1, problems.BasicSDE2, problems.BasicSDE3, problems.BasicSDE4])
158 | @pytest.mark.parametrize("method", ["milstein", "srk"])
159 | @pytest.mark.parametrize('adaptive', (False, True))
160 | def test_basic(problem, method, adaptive):
161 | d = 10
162 | batch_size = 128
163 | ts = torch.tensor([0.0, 0.5], device=device)
164 | dt = 1e-3
165 | y0 = torch.zeros(batch_size, d).to(device).fill_(0.1)
166 |
167 | problem = problem(d).to(device)
168 |
169 | num_before = _count_differentiable_params(problem)
170 |
171 | problem.zero_grad()
172 | _, yt = torchsde.sdeint_adjoint(problem, y0, ts, method=method, dt=dt, adaptive=adaptive)
173 | loss = yt.sum(dim=1).mean(dim=0)
174 | loss.backward()
175 |
176 | num_after = _count_differentiable_params(problem)
177 | assert num_before == num_after
178 |
179 |
180 | def _count_differentiable_params(module):
181 | return len([p for p in module.parameters() if p.requires_grad])
182 |
--------------------------------------------------------------------------------
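A minimal end-to-end sketch in the spirit of `test_against_sdeint`, on a made-up toy SDE rather than one of the problems in `tests/problems.py`: gradients from backpropagating through `torchsde.sdeint` should approximately match those from `torchsde.sdeint_adjoint` when both consume the same Brownian motion:

import torch
import torchsde

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)


class TinySDE(torch.nn.Module):
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self):
        super().__init__()
        self.theta = torch.nn.Parameter(torch.tensor(0.5))

    def f(self, t, y):
        return -self.theta * y

    def g(self, t, y):
        return 0.2 * torch.sigmoid(y)


sde = TinySDE()
y0 = torch.full((4, 1), 0.1)
ts = torch.linspace(0., 1., 5)
bm = torchsde.BrownianInterval(t0=ts[0], t1=ts[-1], size=(4, 1))

# Backprop through the solver's internal operations.
torchsde.sdeint(sde, y0, ts, dt=1e-3, bm=bm, method="milstein").sum().backward()
grad_direct = sde.theta.grad.clone()

# Recompute the same gradient with the stochastic adjoint.
sde.theta.grad = None
torchsde.sdeint_adjoint(sde, y0, ts, dt=1e-3, bm=bm, method="milstein").sum().backward()
assert torch.allclose(grad_direct, sde.theta.grad, rtol=1e-2, atol=1e-2)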
/tests/test_brownian_path.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Test `BrownianPath`.
16 |
17 | The suite tests running on both CPU and CUDA (if available).
18 | """
19 | import sys
20 |
21 | sys.path = sys.path[1:] # A hack so that we always import the installed library.
22 |
23 | import math
24 | import numpy as np
25 | import numpy.random as npr
26 | import torch
27 | from scipy.stats import norm, kstest
28 |
29 | import torchsde
30 | import pytest
31 |
32 | torch.manual_seed(1147481649)
33 | torch.set_default_dtype(torch.float64)
34 |
35 | D = 3
36 | BATCH_SIZE = 131072
37 | REPS = 3
38 | ALPHA = 0.00001
39 |
40 | devices = [cpu, gpu] = [torch.device('cpu'), torch.device('cuda')]
41 |
42 |
43 | def _setup(device):
44 | t0, t1 = torch.tensor([0., 1.], device=device)
45 | w0, w1 = torch.randn([2, BATCH_SIZE, D], device=device)
46 | t = torch.rand([], device=device)
47 | bm = torchsde.BrownianPath(t0=t0, w0=w0)
48 | return t, bm
49 |
50 |
51 | @pytest.mark.parametrize("device", devices)
52 | def test_basic(device):
53 | if device == gpu and not torch.cuda.is_available():
54 | pytest.skip(reason="CUDA not available.")
55 |
56 | t, bm = _setup(device)
57 | sample = bm(t)
58 | assert sample.size() == (BATCH_SIZE, D)
59 |
60 |
61 | @pytest.mark.parametrize("device", devices)
62 | def test_determinism(device):
63 | if device == gpu and not torch.cuda.is_available():
64 | pytest.skip(reason="CUDA not available.")
65 |
66 | t, bm = _setup(device)
67 | vals = [bm(t) for _ in range(REPS)]
68 | for val in vals[1:]:
69 | assert torch.allclose(val, vals[0])
70 |
71 |
72 | @pytest.mark.parametrize("device", devices)
73 | def test_normality(device):
74 | if device == gpu and not torch.cuda.is_available():
75 | pytest.skip(reason="CUDA not available.")
76 |
77 | t0_, t1_ = 0.0, 1.0
78 | eps = 1e-2
79 | for _ in range(REPS):
80 | w0_ = npr.randn() * math.sqrt(t1_)
81 | w0 = torch.tensor(w0_, device=device).repeat(BATCH_SIZE)
82 |
83 | bm = torchsde.BrownianPath(t0=t0_, w0=w0) # noqa
84 |
85 | w1_ = bm(t1_).cpu().numpy()
86 |
87 | t_ = npr.uniform(low=t0_ + eps, high=t1_ - eps) # Avoid sampling too close to the boundary.
88 | samples_ = bm(t_).cpu().numpy()
89 |
90 |             # Conditional mean and std from the Brownian bridge.
91 | mean_ = ((t1_ - t_) * w0_ + (t_ - t0_) * w1_) / (t1_ - t0_)
92 | std_ = math.sqrt((t1_ - t_) * (t_ - t0_) / (t1_ - t0_))
93 |             ref_dist = norm(loc=0., scale=1.)  # Standard normal; samples are standardized below.
94 |
95 | _, pval = kstest((samples_ - mean_) / std_, ref_dist.cdf)
96 | assert pval >= ALPHA
97 |
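98 |
99 | # Illustrative addendum (not part of the original suite): the mean/std in
100 | # `test_normality` are the Brownian bridge statistics. Conditional on W(t0) = w0
101 | # and W(t1) = w1, the value W(t) for t0 < t < t1 is Gaussian with
102 | #     mean = ((t1 - t) * w0 + (t - t0) * w1) / (t1 - t0)
103 | #     var  = (t1 - t) * (t - t0) / (t1 - t0)
104 | # A quick Monte Carlo sanity check of these formulas:
105 | def _sketch_bridge_statistics():
106 |     t0_, t_, t1_ = 0.0, 0.3, 1.0
107 |     bm = torchsde.BrownianPath(t0=t0_, w0=torch.zeros(BATCH_SIZE, 1))
108 |     w1 = bm(t1_)                # Realize the endpoint first.
109 |     samples = bm(t_)            # Then sample the interior point.
110 |     mean = (t_ - t0_) * w1 / (t1_ - t0_)   # w0 = 0 here.
111 |     std = math.sqrt((t1_ - t_) * (t_ - t0_) / (t1_ - t0_))
112 |     z = (samples - mean) / std  # Should be approximately standard normal.
113 |     assert abs(z.mean().item()) < 0.05 and abs(z.std().item() - 1.0) < 0.05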
--------------------------------------------------------------------------------
/tests/test_brownian_tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Test `BrownianTree`.
16 |
17 | The suite tests running on both CPU and CUDA (if available).
18 | """
19 | import sys
20 |
21 | sys.path = sys.path[1:] # A hack so that we always import the installed library.
22 |
23 | import numpy as np
24 | import numpy.random as npr
25 | import torch
26 | from scipy.stats import norm, kstest
27 |
28 | import pytest
29 | import torchsde
30 |
31 | torch.manual_seed(0)
32 | torch.set_default_dtype(torch.float64)
33 |
34 | D = 3
35 | SMALL_BATCH_SIZE = 16
36 | LARGE_BATCH_SIZE = 16384
37 | REPS = 3
38 | ALPHA = 0.00001
39 |
40 | devices = [cpu, gpu] = [torch.device('cpu'), torch.device('cuda')]
41 |
42 |
43 | def _setup(device, batch_size):
44 | t0, t1 = torch.tensor([0., 1.], device=device)
45 | w0 = torch.zeros(batch_size, D, device=device)
46 | t = torch.rand([]).to(device)
47 | bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, entropy=0)
48 | return t, bm
49 |
50 |
51 | def _dict_to_sorted_list(*dicts):
52 | lists = tuple([d[k] for k in sorted(d.keys())] for d in dicts)
53 | if len(lists) == 1:
54 | return lists[0]
55 | return lists
56 |
57 |
58 | @pytest.mark.parametrize("device", devices)
59 | def test_basic(device):
60 | if device == gpu and not torch.cuda.is_available():
61 | pytest.skip(reason="CUDA not available.")
62 |
63 | t, bm = _setup(device, SMALL_BATCH_SIZE)
64 | sample = bm(t)
65 | assert sample.size() == (SMALL_BATCH_SIZE, D)
66 |
67 |
68 | @pytest.mark.parametrize("device", devices)
69 | def test_determinism(device):
70 | if device == gpu and not torch.cuda.is_available():
71 | pytest.skip(reason="CUDA not available.")
72 |
73 | t, bm = _setup(device, SMALL_BATCH_SIZE)
74 | vals = [bm(t) for _ in range(REPS)]
75 | for val in vals[1:]:
76 | assert torch.allclose(val, vals[0])
77 |
78 |
79 | @pytest.mark.parametrize("device", devices)
80 | def test_normality(device):
81 | if device == gpu and not torch.cuda.is_available():
82 | pytest.skip(reason="CUDA not available.")
83 |
84 | t0_, t1_ = 0.0, 1.0
85 | t0, t1 = torch.tensor([t0_, t1_], device=device)
86 | eps = 1e-5
87 | for _ in range(REPS):
88 | w0_, w1_ = 0.0, npr.randn()
89 | w0 = torch.tensor(w0_, device=device).repeat(LARGE_BATCH_SIZE)
90 | w1 = torch.tensor(w1_, device=device).repeat(LARGE_BATCH_SIZE)
91 | bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, pool_size=100, tol=1e-14) # noqa
92 |
93 | for _ in range(REPS):
94 | t_ = npr.uniform(low=t0_ + eps, high=t1_ - eps)
95 | samples = bm(t_)
96 | samples_ = samples.cpu().detach().numpy()
97 |
98 | mean_ = ((t1_ - t_) * w0_ + (t_ - t0_) * w1_) / (t1_ - t0_)
99 | std_ = np.sqrt((t1_ - t_) * (t_ - t0_) / (t1_ - t0_))
100 | ref_dist = norm(loc=mean_, scale=std_)
101 |
102 | _, pval = kstest(samples_, ref_dist.cdf)
103 | assert pval >= ALPHA
104 |
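105 |
106 | # Illustrative addendum (not part of the original suite): with a fixed `entropy`,
107 | # the map from entropy to Brownian motion doesn't depend on the order of the
108 | # queries, so two identically-seeded trees agree at common times even when
109 | # queried in different orders.
110 | def _sketch_query_order_independence():
111 |     w0 = torch.zeros(SMALL_BATCH_SIZE, D)
112 |     bm1 = torchsde.BrownianTree(t0=0., t1=1., w0=w0, entropy=42)
113 |     bm2 = torchsde.BrownianTree(t0=0., t1=1., w0=w0, entropy=42)
114 |     a = bm1(0.5)      # Query 0.5 first.
115 |     bm2(0.25)         # Query 0.25 before 0.5 on the other tree.
116 |     b = bm2(0.5)
117 |     assert torch.allclose(a, b)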
--------------------------------------------------------------------------------
/tests/test_sdeint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import sys
16 |
17 | sys.path = sys.path[1:] # A hack so that we always import the installed library.
18 |
19 | import pytest
20 | import torch
21 | import torchsde
22 | from torchsde.settings import NOISE_TYPES
23 |
24 | from . import problems
25 |
26 | torch.manual_seed(1147481649)
27 | torch.set_default_dtype(torch.float64)
28 | devices = ['cpu']
29 | if torch.cuda.is_available():
30 | devices.append('cuda')
31 |
32 | batch_size = 4
33 | d = 3
34 | m = 2
35 | t0 = 0.0
36 | t1 = 0.3
37 | T = 5
38 | dt = 0.05
39 | dtype = torch.get_default_dtype()
40 |
41 |
42 | class _nullcontext:
43 | def __enter__(self):
44 | pass
45 |
46 | def __exit__(self, exc_type, exc_val, exc_tb):
47 | pass
48 |
49 |
50 | @pytest.mark.parametrize('device', devices)
51 | def test_rename_methods(device):
52 | """Test renaming works with a subset of names."""
53 | sde = problems.CustomNamesSDE().to(device)
54 | y0 = torch.ones(batch_size, d, device=device)
55 | ts = torch.linspace(t0, t1, steps=T, device=device)
56 | ans = torchsde.sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'})
57 | assert ans.shape == (T, batch_size, d)
58 |
59 |
60 | @pytest.mark.parametrize('device', devices)
61 | def test_rename_methods_logqp(device):
62 | """Test renaming works with a subset of names when `logqp=True`."""
63 | sde = problems.CustomNamesSDELogqp().to(device)
64 | y0 = torch.ones(batch_size, d, device=device)
65 | ts = torch.linspace(t0, t1, steps=T, device=device)
66 | ans = torchsde.sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True)
67 | assert ans[0].shape == (T, batch_size, d)
68 | assert ans[1].shape == (T - 1, batch_size)
69 |
70 |
71 | def _use_bm__levy_area_approximation():
72 | yield False, None
73 | yield True, 'none'
74 | yield True, 'space-time'
75 | yield True, 'davie'
76 | yield True, 'foster'
77 |
78 |
79 | @pytest.mark.parametrize('sde_type,method', [('ito', 'euler'), ('stratonovich', 'midpoint')])
80 | def test_specialised_functions(sde_type, method):
81 | vector = torch.randn(m)
82 | fg = problems.FGSDE(sde_type, vector)
83 | f_and_g = problems.FAndGSDE(sde_type, vector)
84 | g_prod = problems.GProdSDE(sde_type, vector)
85 | f_and_g_prod = problems.FAndGProdSDE(sde_type, vector)
86 | f_and_g_with_g_prod1 = problems.FAndGGProdSDE1(sde_type, vector)
87 | f_and_g_with_g_prod2 = problems.FAndGGProdSDE2(sde_type, vector)
88 |
89 | y0 = torch.randn(batch_size, d)
90 |
91 | outs = []
92 | for sde in (fg, f_and_g, g_prod, f_and_g_prod, f_and_g_with_g_prod1, f_and_g_with_g_prod2):
93 | bm = torchsde.BrownianInterval(t0, t1, (batch_size, m), entropy=45678)
94 | outs.append(torchsde.sdeint(sde, y0, [t0, t1], dt=dt, bm=bm)[1])
95 | for o in outs[1:]:
96 |         # Exact floating-point equality, because every formulation should compute exactly the same thing.
97 | assert o.shape == outs[0].shape
98 | assert (o == outs[0]).all()
99 |
100 |
101 | @pytest.mark.parametrize('sde_cls', [problems.ExDiagonal, problems.ExScalar, problems.ExAdditive,
102 | problems.NeuralGeneral])
103 | @pytest.mark.parametrize('use_bm,levy_area_approximation', _use_bm__levy_area_approximation())
104 | @pytest.mark.parametrize('sde_type', ['ito', 'stratonovich'])
105 | @pytest.mark.parametrize('method',
106 | ['blah', 'euler', 'milstein', 'milstein_grad_free', 'srk', 'euler_heun', 'heun', 'midpoint',
107 | 'log_ode'])
108 | @pytest.mark.parametrize('adaptive', [False, True])
109 | @pytest.mark.parametrize('logqp', [True, False])
110 | @pytest.mark.parametrize('device', devices)
111 | def test_sdeint_run_shape_method(sde_cls, use_bm, levy_area_approximation, sde_type, method, adaptive, logqp, device):
112 | """Tests that sdeint:
113 | (a) runs/raises an error as appropriate
114 | (b) produces tensors of the right shape
115 | (c) accepts every method
116 | """
117 |
118 | if method == 'milstein_grad_free':
119 | method = 'milstein'
120 | options = dict(grad_free=True)
121 | else:
122 | options = dict()
123 |
124 | should_fail = False
125 | if sde_type == 'ito':
126 | if method not in ('euler', 'srk', 'milstein'):
127 | should_fail = True
128 | else:
129 | if method not in ('euler_heun', 'heun', 'midpoint', 'log_ode', 'milstein'):
130 | should_fail = True
131 | if method in ('milstein', 'srk') and sde_cls.noise_type == 'general':
132 | should_fail = True
133 | if method == 'srk' and levy_area_approximation == 'none':
134 | should_fail = True
135 | if method == 'log_ode' and levy_area_approximation in ('none', 'space-time'):
136 | should_fail = True
137 |
138 | if sde_cls.noise_type in (NOISE_TYPES.scalar, NOISE_TYPES.diagonal):
139 | kwargs = {'d': d}
140 | else:
141 | kwargs = {'d': d, 'm': m}
142 | sde = sde_cls(sde_type=sde_type, **kwargs).to(device)
143 |
144 | if use_bm:
145 | if sde_cls.noise_type == 'scalar':
146 | size = (batch_size, 1)
147 | elif sde_cls.noise_type == 'diagonal':
148 | size = (batch_size, d + 1) if logqp else (batch_size, d)
149 | else:
150 | assert sde_cls.noise_type in ('additive', 'general')
151 | size = (batch_size, m)
152 | bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device,
153 | levy_area_approximation=levy_area_approximation)
154 | else:
155 | bm = None
156 |
157 | _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail, options)
158 |
159 |
160 | @pytest.mark.parametrize("sde_cls", [problems.BasicSDE1, problems.BasicSDE2, problems.BasicSDE3, problems.BasicSDE4])
161 | @pytest.mark.parametrize('method', ['euler', 'milstein', 'milstein_grad_free', 'srk'])
162 | @pytest.mark.parametrize('adaptive', [False, True])
163 | @pytest.mark.parametrize('device', devices)
164 | def test_sdeint_dependencies(sde_cls, method, adaptive, device):
165 | """This test uses diagonal noise. This checks if the solvers still work when some of the functions don't depend on
166 | the states/params and when some states/params don't require gradients.
167 | """
168 |
169 | if method == 'milstein_grad_free':
170 | method = 'milstein'
171 | options = dict(grad_free=True)
172 | else:
173 | options = dict()
174 |
175 | sde = sde_cls(d=d).to(device)
176 | bm = None
177 | logqp = False
178 | should_fail = False
179 | _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail, options)
180 |
181 |
182 | def _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail, options):
183 | y0 = torch.ones(batch_size, d, device=device)
184 | ts = torch.linspace(t0, t1, steps=T, device=device)
185 | if adaptive and method == 'euler' and sde.noise_type != 'additive':
186 | ctx = pytest.warns(UserWarning)
187 | else:
188 | ctx = _nullcontext()
189 |
190 | # Using `f` as drift.
191 | with torch.no_grad():
192 | try:
193 | with ctx:
194 | ans = torchsde.sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, logqp=logqp,
195 | options=options)
196 | except ValueError:
197 | if should_fail:
198 | return
199 | raise
200 | else:
201 | if should_fail:
202 | pytest.fail("Expected an error; did not get one.")
203 | if logqp:
204 | ans, log_ratio = ans
205 | assert log_ratio.shape == (T - 1, batch_size)
206 | assert ans.shape == (T, batch_size, d)
207 |
208 | # Using `h` as drift.
209 | with torch.no_grad():
210 | with ctx:
211 | ans = torchsde.sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, names={'drift': 'h'},
212 | logqp=logqp, options=options)
213 | if logqp:
214 | ans, log_ratio = ans
215 | assert log_ratio.shape == (T - 1, batch_size)
216 | assert ans.shape == (T, batch_size, d)
217 |
218 |
219 | @pytest.mark.parametrize("sde_cls", [problems.NeuralDiagonal, problems.NeuralScalar, problems.NeuralAdditive,
220 | problems.NeuralGeneral])
221 | def test_reversibility(sde_cls):
222 | batch_size = 32
223 | state_size = 4
224 | t_size = 20
225 | dt = 0.1
226 |
227 | brownian_size = {
228 | NOISE_TYPES.scalar: 1,
229 | NOISE_TYPES.diagonal: state_size,
230 | NOISE_TYPES.general: 2,
231 | NOISE_TYPES.additive: 2
232 | }[sde_cls.noise_type]
233 |
234 | class MinusSDE(torch.nn.Module):
235 | def __init__(self, sde):
236 | self.noise_type = sde.noise_type
237 | self.sde_type = sde.sde_type
238 | self.f = lambda t, y: -sde.f(-t, y)
239 | self.g = lambda t, y: -sde.g(-t, y)
240 |
241 | sde = sde_cls(d=state_size, m=brownian_size, sde_type='stratonovich')
242 | minus_sde = MinusSDE(sde)
243 | y0 = torch.full((batch_size, state_size), 0.1)
244 | ts = torch.linspace(0, (t_size - 1) * dt, t_size)
245 | bm = torchsde.BrownianInterval(t0=ts[0], t1=ts[-1], size=(batch_size, brownian_size))
246 | ys, (f, g, z) = torchsde.sdeint(sde, y0, ts, bm=bm, method='reversible_heun', dt=dt, extra=True)
247 | backward_ts = -ts.flip(0)
248 | backward_ys = torchsde.sdeint(minus_sde, ys[-1], backward_ts, bm=torchsde.ReverseBrownian(bm),
249 | method='reversible_heun', dt=dt, extra_solver_state=(-f, -g, z))
250 | backward_ys = backward_ys.flip(0)
251 |
252 | torch.testing.assert_allclose(ys, backward_ys, rtol=1e-6, atol=1e-6)
253 |
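254 |
255 | # Illustrative addendum (not an actual test): the minimal call pattern that the
256 | # shape tests above exercise, on a hand-written general-noise SDE.
257 | def _sketch_minimal_sdeint():
258 |     class SDE(torchsde.SDEStratonovich):
259 |         def __init__(self):
260 |             super().__init__(noise_type='general')
261 |             self.linear = torch.nn.Linear(d, d * m)
262 |
263 |         def f(self, t, y):
264 |             return -y
265 |
266 |         def g(self, t, y):
267 |             return self.linear(y).view(y.size(0), d, m)
268 |
269 |     y0 = torch.randn(batch_size, d)
270 |     ts = torch.linspace(t0, t1, steps=T)
271 |     ys = torchsde.sdeint(SDE(), y0, ts, dt=dt, method='midpoint')
272 |     assert ys.shape == (T, batch_size, d)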
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 |
17 | import torch
18 | from torch import nn
19 |
20 | from torchsde.types import Callable, ModuleOrModules, Optional, TensorOrTensors
21 |
22 |
23 | # These tolerances need to be this large: for gradients to match up in the Ito case, we typically need large
24 | # values; not so much in the Stratonovich case.
25 | def assert_allclose(actual, expected, rtol=1e-3, atol=1e-2):
26 | if actual is None:
27 | assert expected is None
28 | else:
29 | torch.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)
30 |
31 |
32 | def gradcheck(func: Callable,
33 | inputs: TensorOrTensors,
34 | modules: Optional[ModuleOrModules] = (),
35 | eps: float = 1e-6,
36 | atol: float = 1e-5,
37 | rtol: float = 1e-3,
38 | grad_inputs=False,
39 | gradgrad_inputs=False,
40 | grad_params=False,
41 | gradgrad_params=False):
42 | """Check grad and grad of grad wrt inputs and parameters of Modules.
43 |
44 | When `func` is vector-valued, the checks compare autodiff vjp against
45 | finite-difference vjp, where v is a sampled standard normal vector.
46 |
47 | This function is aimed to be as self-contained as possible so that it could
48 | be copied/pasted across different projects.
49 |
50 | Args:
51 | func (callable): A Python function that takes in a sequence of tensors
52 | (inputs) and a sequence of nn.Module (modules), and outputs a tensor
53 | or a sequence of tensors.
54 | inputs (sequence of Tensors): The input tensors.
55 | modules (sequence of nn.Module): The modules whose parameter gradient
56 | needs to be tested.
57 | eps (float, optional): Magnitude of two-sided finite difference
58 | perturbation.
59 | atol (float, optional): Absolute tolerance.
60 | rtol (float, optional): Relative tolerance.
61 | grad_inputs (bool, optional): Check gradients wrt inputs if True.
62 | gradgrad_inputs (bool, optional): Check gradients of gradients wrt
63 | inputs if True.
64 | grad_params (bool, optional): Check gradients wrt differentiable
65 | parameters of modules if True.
66 | gradgrad_params (bool, optional): Check gradients of gradients wrt
67 | differentiable parameters of modules if True.
68 |
69 | Returns:
70 | None.
71 | """
72 |
73 | def convert_none_to_zeros(sequence, like_sequence):
74 | return [torch.zeros_like(q) if p is None else p for p, q in zip(sequence, like_sequence)]
75 |
76 | def flatten(sequence):
77 | return torch.cat([p.reshape(-1) for p in sequence]) if len(sequence) > 0 else torch.tensor([])
78 |
79 | if isinstance(inputs, torch.Tensor):
80 | inputs = (inputs,)
81 |
82 | if isinstance(modules, nn.Module):
83 | modules = (modules,)
84 |
85 | # Don't modify original objects.
86 | modules = tuple(copy.deepcopy(m) for m in modules)
87 | inputs = tuple(i.clone().requires_grad_() for i in inputs)
88 |
89 | func = _make_scalar_valued_func(func, inputs, modules)
90 | func_only_inputs = lambda *args: func(args, modules) # noqa
91 |
92 | # Grad wrt inputs.
93 | if grad_inputs:
94 | torch.autograd.gradcheck(func_only_inputs, inputs, eps=eps, atol=atol, rtol=rtol)
95 |
96 | # Grad of grad wrt inputs.
97 | if gradgrad_inputs:
98 | torch.autograd.gradgradcheck(func_only_inputs, inputs, eps=eps, atol=atol, rtol=rtol)
99 |
100 | # Grad wrt params.
101 | if grad_params:
102 | params = [p for m in modules for p in m.parameters() if p.requires_grad]
103 | loss = func(inputs, modules)
104 | framework_grad = flatten(convert_none_to_zeros(torch.autograd.grad(loss, params, create_graph=True), params))
105 |
106 | numerical_grad = []
107 | for param in params:
108 | flat_param = param.reshape(-1)
109 | for i in range(len(flat_param)):
110 | flat_param[i].data.add_(eps)
111 | plus_eps = func(inputs, modules).detach()
112 | flat_param[i].data.sub_(eps)
113 |
114 | flat_param[i].data.sub_(eps)
115 | minus_eps = func(inputs, modules).detach()
116 | flat_param[i].data.add_(eps)
117 |
118 | numerical_grad.append((plus_eps - minus_eps) / (2 * eps))
119 | del plus_eps, minus_eps
120 | numerical_grad = torch.stack(numerical_grad)
121 | torch.testing.assert_allclose(numerical_grad, framework_grad, rtol=rtol, atol=atol)
122 |
123 | # Grad of grad wrt params.
124 | if gradgrad_params:
125 | def func_high_order(inputs, modules):
126 | params = [p for m in modules for p in m.parameters() if p.requires_grad]
127 | grads = torch.autograd.grad(func(inputs, modules), params, create_graph=True, allow_unused=True)
128 | return tuple(grad for grad in grads if grad is not None)
129 |
130 | gradcheck(func_high_order, inputs, modules, rtol=rtol, atol=atol, eps=eps, grad_params=True)
131 |
132 |
133 | def _make_scalar_valued_func(func, inputs, modules):
134 | outputs = func(inputs, modules)
135 | output_size = outputs.numel() if torch.is_tensor(outputs) else sum(o.numel() for o in outputs)
136 |
137 | if output_size > 1:
138 | # Define this outside `func_scalar_valued` so that random tensors are generated only once.
139 | grad_outputs = tuple(torch.randn_like(o) for o in outputs)
140 |
141 | def func_scalar_valued(inputs, modules):
142 | outputs = func(inputs, modules)
143 |             return sum((output * grad_output).sum() for output, grad_output in zip(outputs, grad_outputs))
144 |
145 | return func_scalar_valued
146 |
147 | return func
148 |
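149 |
150 | # Usage sketch (illustrative): `gradcheck` above expects a callable of the form
151 | # func(inputs, modules) -> tensor(s), with `inputs` and `modules` as sequences.
152 | def _sketch_gradcheck_usage():
153 |     net = nn.Linear(3, 3).double()
154 |
155 |     def func(inputs, modules):
156 |         (x,), (module,) = inputs, modules
157 |         return module(x).tanh()
158 |
159 |     x = torch.randn(2, 3, dtype=torch.float64)
160 |     gradcheck(func, x, modules=net, grad_inputs=True, grad_params=True)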
--------------------------------------------------------------------------------
/torchsde/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from ._brownian import (BaseBrownian, BrownianInterval, BrownianPath, BrownianTree, ReverseBrownian,
16 | brownian_interval_like)
17 | from ._core.adjoint import sdeint_adjoint
18 | from ._core.base_sde import BaseSDE, SDEIto, SDEStratonovich
19 | from ._core.sdeint import sdeint
20 |
21 | BrownianInterval.__init__.__annotations__ = {}
22 | BrownianPath.__init__.__annotations__ = {}
23 | BrownianTree.__init__.__annotations__ = {}
24 | sdeint.__annotations__ = {}
25 | sdeint_adjoint.__annotations__ = {}
26 |
27 | __version__ = '0.2.6'
28 |
--------------------------------------------------------------------------------
/torchsde/_brownian/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .brownian_base import BaseBrownian
16 | from .brownian_interval import BrownianInterval
17 | from .derived import ReverseBrownian, BrownianPath, BrownianTree, brownian_interval_like
18 |
--------------------------------------------------------------------------------
/torchsde/_brownian/brownian_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import abc
16 |
17 |
18 | class BaseBrownian(metaclass=abc.ABCMeta):
19 | __slots__ = ()
20 |
21 | @abc.abstractmethod
22 | def __call__(self, ta, tb=None, return_U=False, return_A=False):
23 | raise NotImplementedError
24 |
25 | @abc.abstractmethod
26 | def __repr__(self):
27 | raise NotImplementedError
28 |
29 | @property
30 | @abc.abstractmethod
31 | def dtype(self):
32 | raise NotImplementedError
33 |
34 | @property
35 | @abc.abstractmethod
36 | def device(self):
37 | raise NotImplementedError
38 |
39 | @property
40 | @abc.abstractmethod
41 | def shape(self):
42 | raise NotImplementedError
43 |
44 | @property
45 | @abc.abstractmethod
46 | def levy_area_approximation(self):
47 | raise NotImplementedError
48 |
49 | def size(self):
50 | return self.shape
51 |
--------------------------------------------------------------------------------
/torchsde/_brownian/derived.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | from . import brownian_base
18 | from . import brownian_interval
19 | from ..types import Optional, Scalar, Tensor, Tuple, Union
20 |
21 |
22 | class ReverseBrownian(brownian_base.BaseBrownian):
23 | def __init__(self, base_brownian):
24 | super(ReverseBrownian, self).__init__()
25 | self.base_brownian = base_brownian
26 |
27 | def __call__(self, ta, tb=None, return_U=False, return_A=False):
28 | # Whether or not to negate the statistics depends on the return value of the adjoint SDE. Currently, the adjoint
29 | # returns negated drift and diffusion, so we don't negate here.
30 | return self.base_brownian(-tb, -ta, return_U=return_U, return_A=return_A)
31 |
32 | def __repr__(self):
33 | return f"{self.__class__.__name__}(base_brownian={self.base_brownian})"
34 |
35 | @property
36 | def dtype(self):
37 | return self.base_brownian.dtype
38 |
39 | @property
40 | def device(self):
41 | return self.base_brownian.device
42 |
43 | @property
44 | def shape(self):
45 | return self.base_brownian.shape
46 |
47 | @property
48 | def levy_area_approximation(self):
49 | return self.base_brownian.levy_area_approximation
50 |
51 |
52 | class BrownianPath(brownian_base.BaseBrownian):
53 | """Brownian path, storing every computed value.
54 |
55 | Useful for speed, when memory isn't a concern.
56 |
57 | To use:
58 | >>> bm = BrownianPath(t0=0.0, w0=torch.zeros(4, 1))
59 | >>> bm(0., 0.5)
60 | tensor([[ 0.0733],
61 | [-0.5692],
62 | [ 0.1872],
63 | [-0.3889]])
64 | """
65 |
66 | def __init__(self, t0: Scalar, w0: Tensor, window_size: int = 8):
67 | """Initialize Brownian path.
68 | Arguments:
69 | t0: Initial time.
70 | w0: Initial state.
71 | window_size: Unused; deprecated.
72 | """
73 | t1 = t0 + 1
74 | self._w0 = w0
75 | self._interval = brownian_interval.BrownianInterval(t0=t0, t1=t1, size=w0.shape, dtype=w0.dtype,
76 | device=w0.device, cache_size=None)
77 | super(BrownianPath, self).__init__()
78 |
79 | def __call__(self, t, tb=None, return_U=False, return_A=False):
80 | # Deliberately called t rather than ta, for backward compatibility
81 | out = self._interval(t, tb, return_U=return_U, return_A=return_A)
82 | if tb is None and not return_U and not return_A:
83 | out = out + self._w0
84 | return out
85 |
86 | def __repr__(self):
87 | return f"{self.__class__.__name__}(interval={self._interval})"
88 |
89 | @property
90 | def dtype(self):
91 | return self._interval.dtype
92 |
93 | @property
94 | def device(self):
95 | return self._interval.device
96 |
97 | @property
98 | def shape(self):
99 | return self._interval.shape
100 |
101 | @property
102 | def levy_area_approximation(self):
103 | return self._interval.levy_area_approximation
104 |
105 |
106 | class BrownianTree(brownian_base.BaseBrownian):
107 | """Brownian tree with fixed entropy.
108 |
109 | Useful when the map from entropy -> Brownian motion shouldn't depend on the
110 | locations and order of the query points. (As the usual BrownianInterval
111 | does - note that BrownianTree is slower as a result though.)
112 |
113 | To use:
114 | >>> bm = BrownianTree(t0=0.0, w0=torch.zeros(4, 1))
115 | >>> bm(0., 0.5)
116 | tensor([[ 0.0733],
117 | [-0.5692],
118 | [ 0.1872],
119 |             [-0.3889]])
120 | """
121 |
122 | def __init__(self, t0: Scalar,
123 | w0: Tensor,
124 | t1: Optional[Scalar] = None,
125 | w1: Optional[Tensor] = None,
126 | entropy: Optional[int] = None,
127 | tol: float = 1e-6,
128 | pool_size: int = 24,
129 | cache_depth: int = 9,
130 | safety: Optional[float] = None):
131 | """Initialize the Brownian tree.
132 |
133 | The random value generation process exploits the parallel random number paradigm and uses
134 | `numpy.random.SeedSequence`. The default generator is PCG64 (used by `default_rng`).
135 |
136 | Arguments:
137 | t0: Initial time.
138 | w0: Initial state.
139 | t1: Terminal time.
140 | w1: Terminal state.
141 | entropy: Global seed, defaults to `None` for random entropy.
142 |             tol: Error tolerance before the binary search is terminated; the search depth is ~ log2(1/tol).
143 | pool_size: Size of the pooled entropy. This parameter affects the query speed significantly.
144 | cache_depth: Unused; deprecated.
145 | safety: Unused; deprecated.
146 | """
147 |
148 | if t1 is None:
149 | t1 = t0 + 1
150 | if w1 is None:
151 | W = None
152 | else:
153 | W = w1 - w0
154 | self._w0 = w0
155 | self._interval = brownian_interval.BrownianInterval(t0=t0,
156 | t1=t1,
157 | size=w0.shape,
158 | dtype=w0.dtype,
159 | device=w0.device,
160 | entropy=entropy,
161 | tol=tol,
162 | pool_size=pool_size,
163 | halfway_tree=True,
164 | W=W)
165 | super(BrownianTree, self).__init__()
166 |
167 | def __call__(self, t, tb=None, return_U=False, return_A=False):
168 | # Deliberately called t rather than ta, for backward compatibility
169 | out = self._interval(t, tb, return_U=return_U, return_A=return_A)
170 | if tb is None and not return_U and not return_A:
171 | out = out + self._w0
172 | return out
173 |
174 | def __repr__(self):
175 | return f"{self.__class__.__name__}(interval={self._interval})"
176 |
177 | @property
178 | def dtype(self):
179 | return self._interval.dtype
180 |
181 | @property
182 | def device(self):
183 | return self._interval.device
184 |
185 | @property
186 | def shape(self):
187 | return self._interval.shape
188 |
189 | @property
190 | def levy_area_approximation(self):
191 | return self._interval.levy_area_approximation
192 |
193 |
194 | def brownian_interval_like(y: Tensor,
195 | t0: Optional[Scalar] = 0.,
196 | t1: Optional[Scalar] = 1.,
197 | size: Optional[Tuple[int, ...]] = None,
198 | dtype: Optional[torch.dtype] = None,
199 | device: Optional[Union[str, torch.device]] = None,
200 | **kwargs):
201 | """Returns a BrownianInterval object with the same size, device, and dtype as a given tensor."""
202 | size = y.shape if size is None else size
203 | dtype = y.dtype if dtype is None else dtype
204 | device = y.device if device is None else device
205 | return brownian_interval.BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device, **kwargs)
206 |
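207 |
208 | # Usage sketch for the helpers above (illustrative, not part of the public docs):
209 | def _sketch_usage():
210 |     y = torch.randn(4, 3)
211 |     bm = brownian_interval_like(y, t0=0., t1=1.)
212 |     increment = bm(0.2, 0.7)   # W(0.7) - W(0.2), shape (4, 3).
213 |     rev = ReverseBrownian(bm)
214 |     # ReverseBrownian negates and swaps the query times, so this is the same increment.
215 |     assert torch.allclose(increment, rev(-0.7, -0.2))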
--------------------------------------------------------------------------------
/torchsde/_core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/torchsde/_core/adaptive_stepping.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | from . import misc
18 | from ..types import TensorOrTensors
19 |
20 |
21 | def update_step_size(error_estimate, prev_step_size, safety=0.9, facmin=0.2, facmax=1.4, prev_error_ratio=None):
22 | """Adaptively propose the next step size based on estimated errors."""
23 | if error_estimate > 1:
24 | pfactor = 0
25 | ifactor = 1 / 1.5 # 1 / 5
26 | else:
27 | pfactor = 0.13
28 | ifactor = 1 / 4.5 # 1 / 15
29 |
30 | error_ratio = safety / error_estimate
31 | if prev_error_ratio is None:
32 | prev_error_ratio = error_ratio
33 | factor = error_ratio ** ifactor * (error_ratio / prev_error_ratio) ** pfactor
34 | if error_estimate <= 1:
35 | prev_error_ratio = error_ratio
36 | facmin = 1.0
37 | factor = min(facmax, max(facmin, factor))
38 | new_step_size = prev_step_size * factor
39 | return new_step_size, prev_error_ratio
40 |
41 |
42 | def compute_error(y11: TensorOrTensors, y12: TensorOrTensors, rtol, atol, eps=1e-7):
43 | """Computer error estimate.
44 |
45 | Args:
46 | y11: A tensor or a sequence of tensors obtained with a full update.
47 | y12: A tensor or a sequence of tensors obtained with two half updates.
48 | rtol: Relative tolerance.
49 | atol: Absolute tolerance.
50 | eps: A small constant to avoid division by zero.
51 |
52 | Returns:
53 | A float for the aggregated error estimate.
54 | """
55 | if torch.is_tensor(y11):
56 | y11 = (y11,)
57 | if torch.is_tensor(y12):
58 | y12 = (y12,)
59 | tol = [
60 | (rtol * torch.max(torch.abs(y11_), torch.abs(y12_)) + atol).clamp_min(eps)
61 | for y11_, y12_ in zip(y11, y12)
62 | ]
63 | error_estimate = _rms(
64 | [(y11_ - y12_) / tol_ for y11_, y12_, tol_ in zip(y11, y12, tol)], eps
65 | )
66 | assert not misc.is_nan(error_estimate), (
67 | 'Found nans in the error estimate. Try increasing the tolerance or regularizing the dynamics.'
68 | )
69 | return error_estimate.detach().cpu().item()
70 |
71 |
72 | def _rms(x, eps=1e-7):
73 | if torch.is_tensor(x):
74 | return torch.sqrt((x ** 2.).sum() / x.numel()).clamp_min(eps)
75 | else:
76 | return torch.sqrt(sum((x_ ** 2.).sum() for x_ in x) / sum(x_.numel() for x_ in x)).clamp_min(eps)
77 |
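78 |
79 | # Worked example (illustrative): a step whose error estimate exceeds 1 leads to a
80 | # smaller proposed step size (the caller would reject and retry the step).
81 | def _sketch_controller():
82 |     y_full = torch.tensor([1.00, 2.00])   # One full step.
83 |     y_half = torch.tensor([1.05, 1.98])   # Two half steps.
84 |     err = compute_error(y_full, y_half, rtol=1e-3, atol=1e-2)
85 |     new_dt, _ = update_step_size(error_estimate=err, prev_step_size=0.1)
86 |     assert err > 1 and new_dt < 0.1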
--------------------------------------------------------------------------------
/torchsde/_core/base_sde.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import abc
16 |
17 | import torch
18 | from torch import nn
19 |
20 | from . import misc
21 | from ..settings import NOISE_TYPES, SDE_TYPES
22 | from ..types import Tensor
23 |
24 |
25 | class BaseSDE(abc.ABC, nn.Module):
26 | """Base class for all SDEs.
27 |
28 | Inheriting from this class ensures `noise_type` and `sde_type` are valid attributes, which the solver depends on.
29 | """
30 |
31 | def __init__(self, noise_type, sde_type):
32 | super(BaseSDE, self).__init__()
33 | if noise_type not in NOISE_TYPES:
34 | raise ValueError(f"Expected noise type in {NOISE_TYPES}, but found {noise_type}")
35 | if sde_type not in SDE_TYPES:
36 | raise ValueError(f"Expected sde type in {SDE_TYPES}, but found {sde_type}")
37 | # Making these Python properties breaks `torch.jit.script`.
38 | self.noise_type = noise_type
39 | self.sde_type = sde_type
40 |
41 |
42 | class ForwardSDE(BaseSDE):
43 |
44 | def __init__(self, sde, fast_dg_ga_jvp_column_sum=False):
45 | super(ForwardSDE, self).__init__(sde_type=sde.sde_type, noise_type=sde.noise_type)
46 | self._base_sde = sde
47 |
48 | # Register the core functions. This avoids polluting the codebase with if-statements and achieves speed-ups
49 | # by making sure it's a one-time cost.
50 |
51 | if hasattr(sde, 'f_and_g_prod'):
52 | self.f_and_g_prod = sde.f_and_g_prod
53 | elif hasattr(sde, 'f') and hasattr(sde, 'g_prod'):
54 | self.f_and_g_prod = self.f_and_g_prod_default1
55 | else: # (f_and_g,) or (f, g,).
56 | self.f_and_g_prod = self.f_and_g_prod_default2
57 |
58 | self.f = getattr(sde, 'f', self.f_default)
59 | self.g = getattr(sde, 'g', self.g_default)
60 | self.f_and_g = getattr(sde, 'f_and_g', self.f_and_g_default)
61 | self.g_prod = getattr(sde, 'g_prod', self.g_prod_default)
62 | self.prod = {
63 | NOISE_TYPES.diagonal: self.prod_diagonal
64 | }.get(sde.noise_type, self.prod_default)
65 | self.g_prod_and_gdg_prod = {
66 | NOISE_TYPES.diagonal: self.g_prod_and_gdg_prod_diagonal,
67 | NOISE_TYPES.additive: self.g_prod_and_gdg_prod_additive,
68 | }.get(sde.noise_type, self.g_prod_and_gdg_prod_default)
69 | self.dg_ga_jvp_column_sum = {
70 | NOISE_TYPES.general: (
71 | self.dg_ga_jvp_column_sum_v2 if fast_dg_ga_jvp_column_sum else self.dg_ga_jvp_column_sum_v1
72 | )
73 | }.get(sde.noise_type, self._return_zero)
74 |
75 | ########################################
76 | # f #
77 | ########################################
78 | def f_default(self, t, y):
79 | raise RuntimeError("Method `f` has not been provided, but is required for this method.")
80 |
81 | ########################################
82 | # g #
83 | ########################################
84 | def g_default(self, t, y):
85 | raise RuntimeError("Method `g` has not been provided, but is required for this method.")
86 |
87 | ########################################
88 | # f_and_g #
89 | ########################################
90 |
91 | def f_and_g_default(self, t, y):
92 | return self.f(t, y), self.g(t, y)
93 |
94 | ########################################
95 | # prod #
96 | ########################################
97 |
98 | def prod_diagonal(self, g, v):
99 | return g * v
100 |
101 | def prod_default(self, g, v):
102 | return misc.batch_mvp(g, v)
103 |
104 | ########################################
105 | # g_prod #
106 | ########################################
107 |
108 | def g_prod_default(self, t, y, v):
109 | return self.prod(self.g(t, y), v)
110 |
111 | ########################################
112 | # f_and_g_prod #
113 | ########################################
114 |
115 | def f_and_g_prod_default1(self, t, y, v):
116 | return self.f(t, y), self.g_prod(t, y, v)
117 |
118 | def f_and_g_prod_default2(self, t, y, v):
119 | f, g = self.f_and_g(t, y)
120 | return f, self.prod(g, v)
121 |
122 | ########################################
123 | # g_prod_and_gdg_prod #
124 | ########################################
125 |
126 |     # Computes: g_prod and sum_{j, l} g_{j, l} (d g_{j, l} / d x_i) v2_l.
127 | def g_prod_and_gdg_prod_default(self, t, y, v1, v2):
128 | requires_grad = torch.is_grad_enabled()
129 | with torch.enable_grad():
130 | y = y if y.requires_grad else y.detach().requires_grad_(True)
131 | g = self.g(t, y)
132 | vg_dg_vjp, = misc.vjp(
133 | outputs=g,
134 | inputs=y,
135 | grad_outputs=g * v2.unsqueeze(-2),
136 | retain_graph=True,
137 | create_graph=requires_grad,
138 | allow_unused=True
139 | )
140 | return self.prod(g, v1), vg_dg_vjp
141 |
142 | def g_prod_and_gdg_prod_diagonal(self, t, y, v1, v2):
143 | requires_grad = torch.is_grad_enabled()
144 | with torch.enable_grad():
145 | y = y if y.requires_grad else y.detach().requires_grad_(True)
146 | g = self.g(t, y)
147 | vg_dg_vjp, = misc.vjp(
148 | outputs=g,
149 | inputs=y,
150 | grad_outputs=g * v2,
151 | retain_graph=True,
152 | create_graph=requires_grad,
153 | allow_unused=True
154 | )
155 | return self.prod(g, v1), vg_dg_vjp
156 |
157 | def g_prod_and_gdg_prod_additive(self, t, y, v1, v2):
158 | return self.g_prod(t, y, v1), 0.
159 |
160 | ########################################
161 | # dg_ga_jvp #
162 | ########################################
163 |
164 | # Computes: sum_{j,k,l} d g_{i,l} / d x_j g_{j,k} A_{k,l}.
165 | def dg_ga_jvp_column_sum_v1(self, t, y, a):
166 | requires_grad = torch.is_grad_enabled()
167 | with torch.enable_grad():
168 | y = y if y.requires_grad else y.detach().requires_grad_(True)
169 | g = self.g(t, y)
170 | ga = torch.bmm(g, a)
171 | dg_ga_jvp = [
172 | misc.jvp(
173 | outputs=g[..., col_idx],
174 | inputs=y,
175 | grad_inputs=ga[..., col_idx],
176 | retain_graph=True,
177 | create_graph=requires_grad,
178 | allow_unused=True
179 | )[0]
180 | for col_idx in range(g.size(-1))
181 | ]
182 | dg_ga_jvp = sum(dg_ga_jvp)
183 | return dg_ga_jvp
184 |
185 | def dg_ga_jvp_column_sum_v2(self, t, y, a):
186 | # Faster, but more memory intensive.
187 | requires_grad = torch.is_grad_enabled()
188 | with torch.enable_grad():
189 | y = y if y.requires_grad else y.detach().requires_grad_(True)
190 | g = self.g(t, y)
191 | ga = torch.bmm(g, a)
192 |
193 | batch_size, d, m = g.size()
194 | y_dup = torch.repeat_interleave(y, repeats=m, dim=0)
195 | g_dup = self.g(t, y_dup)
196 | ga_flat = ga.transpose(1, 2).flatten(0, 1)
197 | dg_ga_jvp, = misc.jvp(
198 | outputs=g_dup,
199 | inputs=y_dup,
200 | grad_inputs=ga_flat,
201 | create_graph=requires_grad,
202 | allow_unused=True
203 | )
204 | dg_ga_jvp = dg_ga_jvp.reshape(batch_size, m, d, m).permute(0, 2, 1, 3)
205 | dg_ga_jvp = dg_ga_jvp.diagonal(dim1=-2, dim2=-1).sum(-1)
206 | return dg_ga_jvp
207 |
208 | def _return_zero(self, t, y, v): # noqa
209 | return 0.
210 |
211 |
212 | class RenameMethodsSDE(BaseSDE):
213 |
214 | def __init__(self, sde, drift='f', diffusion='g', prior_drift='h', diffusion_prod='g_prod',
215 | drift_and_diffusion='f_and_g', drift_and_diffusion_prod='f_and_g_prod'):
216 | super(RenameMethodsSDE, self).__init__(noise_type=sde.noise_type, sde_type=sde.sde_type)
217 | self._base_sde = sde
218 | for name, value in zip(('f', 'g', 'h', 'g_prod', 'f_and_g', 'f_and_g_prod'),
219 | (drift, diffusion, prior_drift, diffusion_prod, drift_and_diffusion,
220 | drift_and_diffusion_prod)):
221 | try:
222 | setattr(self, name, getattr(sde, value))
223 | except AttributeError:
224 | pass
225 |
226 |
227 | class SDEIto(BaseSDE):
228 |
229 | def __init__(self, noise_type):
230 | super(SDEIto, self).__init__(noise_type=noise_type, sde_type=SDE_TYPES.ito)
231 |
232 |
233 | class SDEStratonovich(BaseSDE):
234 |
235 | def __init__(self, noise_type):
236 | super(SDEStratonovich, self).__init__(noise_type=noise_type, sde_type=SDE_TYPES.stratonovich)
237 |
238 |
239 | # --- Backwards compatibility: v0.1.1. ---
240 | class SDELogqp(BaseSDE):
241 |
242 | def __init__(self, sde):
243 | super(SDELogqp, self).__init__(noise_type=sde.noise_type, sde_type=sde.sde_type)
244 | self._base_sde = sde
245 |
246 | # Make this redirection a one-time cost.
247 | try:
248 | self._base_f = sde.f
249 | self._base_g = sde.g
250 | self._base_h = sde.h
251 | except AttributeError as e:
252 | # TODO: relax this requirement, and use f_and_g, f_and_g_prod, f_and_g_and_h and f_and_g_prod_and_h if
253 | # they're available.
254 | raise AttributeError("If using logqp then drift, diffusion and prior drift must all be specified.") from e
255 |
256 | # Make this method selection a one-time cost.
257 | if sde.noise_type == NOISE_TYPES.diagonal:
258 | self.f = self.f_diagonal
259 | self.g = self.g_diagonal
260 | self.f_and_g = self.f_and_g_diagonal
261 | else:
262 | self.f = self.f_general
263 | self.g = self.g_general
264 | self.f_and_g = self.f_and_g_general
265 |
266 | def f_diagonal(self, t, y: Tensor):
267 | y = y[:, :-1]
268 | f, g, h = self._base_f(t, y), self._base_g(t, y), self._base_h(t, y)
269 | u = misc.stable_division(f - h, g)
270 | f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True)
271 | return torch.cat([f, f_logqp], dim=1)
272 |
273 | def g_diagonal(self, t, y: Tensor):
274 | y = y[:, :-1]
275 | g = self._base_g(t, y)
276 | g_logqp = y.new_zeros(size=(y.size(0), 1))
277 | return torch.cat([g, g_logqp], dim=1)
278 |
279 | def f_and_g_diagonal(self, t, y: Tensor):
280 | y = y[:, :-1]
281 | f, g, h = self._base_f(t, y), self._base_g(t, y), self._base_h(t, y)
282 | u = misc.stable_division(f - h, g)
283 | f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True)
284 | g_logqp = y.new_zeros(size=(y.size(0), 1))
285 | return torch.cat([f, f_logqp], dim=1), torch.cat([g, g_logqp], dim=1)
286 |
287 | def f_general(self, t, y: Tensor):
288 | y = y[:, :-1]
289 | f, g, h = self._base_f(t, y), self._base_g(t, y), self._base_h(t, y)
290 | u = misc.batch_mvp(g.pinverse(), f - h) # (batch_size, brownian_size).
291 | f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True)
292 | return torch.cat([f, f_logqp], dim=1)
293 |
294 | def g_general(self, t, y: Tensor):
295 | y = y[:, :-1]
296 |         g = self._base_g(t, y)
297 | g_logqp = y.new_zeros(size=(g.size(0), 1, g.size(-1)))
298 | return torch.cat([g, g_logqp], dim=1)
299 |
300 | def f_and_g_general(self, t, y: Tensor):
301 | y = y[:, :-1]
302 | f, g, h = self._base_f(t, y), self._base_g(t, y), self._base_h(t, y)
303 | u = misc.batch_mvp(g.pinverse(), f - h) # (batch_size, brownian_size).
304 | f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True)
305 | g_logqp = y.new_zeros(size=(g.size(0), 1, g.size(-1)))
306 | return torch.cat([f, f_logqp], dim=1), torch.cat([g, g_logqp], dim=1)
307 | # ----------------------------------------
308 |
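309 |
310 | # Usage sketch (illustrative): `RenameMethodsSDE` above aliases user-defined method
311 | # names onto the canonical `f`/`g`/`h` interface, which is what powers
312 | # `sdeint(..., names={'drift': 'forward'})`-style calls.
313 | def _sketch_rename():
314 |     class MySDE(SDEIto):
315 |         def __init__(self):
316 |             super().__init__(noise_type=NOISE_TYPES.diagonal)
317 |
318 |         def forward(self, t, y):   # Drift under a custom name.
319 |             return -y
320 |
321 |         def g(self, t, y):
322 |             return torch.ones_like(y)
323 |
324 |     sde = RenameMethodsSDE(MySDE(), drift='forward')
325 |     y = torch.randn(4, 3)
326 |     assert torch.equal(sde.f(0., y), -y)   # `f` now dispatches to `forward`.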
--------------------------------------------------------------------------------
/torchsde/_core/base_solver.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import abc
16 | import warnings
17 |
18 | import torch
19 |
20 | from . import adaptive_stepping
21 | from . import better_abc
22 | from . import interp
23 | from .base_sde import BaseSDE
24 | from .._brownian import BaseBrownian
25 | from ..settings import NOISE_TYPES
26 | from ..types import Scalar, Tensor, Dict, Tensors, Tuple
27 |
28 |
29 | class BaseSDESolver(metaclass=better_abc.ABCMeta):
30 | """API for solvers with possibly adaptive time stepping."""
31 |
32 | strong_order = better_abc.abstract_attribute()
33 | weak_order = better_abc.abstract_attribute()
34 | sde_type = better_abc.abstract_attribute()
35 | noise_types = better_abc.abstract_attribute()
36 | levy_area_approximations = better_abc.abstract_attribute()
37 |
38 | def __init__(self,
39 | sde: BaseSDE,
40 | bm: BaseBrownian,
41 | dt: Scalar,
42 | adaptive: bool,
43 | rtol: Scalar,
44 | atol: Scalar,
45 | dt_min: Scalar,
46 | options: Dict,
47 | **kwargs):
48 | super(BaseSDESolver, self).__init__(**kwargs)
49 | if sde.sde_type != self.sde_type:
50 | raise ValueError(f"SDE is of type {sde.sde_type} but solver is for type {self.sde_type}")
51 | if sde.noise_type not in self.noise_types:
52 | raise ValueError(f"SDE has noise type {sde.noise_type} but solver only supports noise types "
53 | f"{self.noise_types}")
54 | if bm.levy_area_approximation not in self.levy_area_approximations:
55 | raise ValueError(f"SDE solver requires one of {self.levy_area_approximations} set as the "
56 | f"`levy_area_approximation` on the Brownian motion.")
57 | if sde.noise_type == NOISE_TYPES.scalar and torch.Size(bm.shape[1:]).numel() != 1: # noqa
58 | raise ValueError("The Brownian motion for scalar SDEs must of dimension 1.")
59 |
60 | self.sde = sde
61 | self.bm = bm
62 | self.dt = dt
63 | self.adaptive = adaptive
64 | self.rtol = rtol
65 | self.atol = atol
66 | self.dt_min = dt_min
67 | self.options = options
68 |
69 | def __repr__(self):
70 | return f"{self.__class__.__name__} of strong order: {self.strong_order}, and weak order: {self.weak_order}"
71 |
72 | def init_extra_solver_state(self, t0, y0) -> Tensors:
73 | return ()
74 |
75 | @abc.abstractmethod
76 | def step(self, t0: Scalar, t1: Scalar, y0: Tensor, extra0: Tensors) -> Tuple[Tensor, Tensors]:
77 | """Propose a step with step size from time t to time next_t, with
78 | current state y.
79 |
80 | Args:
81 | t0: float or Tensor of size (,).
82 | t1: float or Tensor of size (,).
83 | y0: Tensor of size (batch_size, d).
84 | extra0: Any extra state for the solver.
85 |
86 | Returns:
87 | y1, where y1 is a Tensor of size (batch_size, d).
88 | extra1: Modified extra state for the solver.
89 | """
90 | raise NotImplementedError
91 |
92 | def integrate(self, y0: Tensor, ts: Tensor, extra0: Tensors) -> Tuple[Tensor, Tensors]:
93 | """Integrate along trajectory.
94 |
95 | Args:
96 | y0: Tensor of size (batch_size, d)
97 | ts: Tensor of size (T,).
98 | extra0: Any extra state for the solver.
99 |
100 | Returns:
101 | ys, where ys is a Tensor of size (T, batch_size, d).
102 | extra_solver_state, which is a tuple of Tensors of shape (T, ...), where ... is arbitrary and
103 | solver-dependent.
104 | """
105 | step_size = self.dt
106 |
107 | prev_t = curr_t = ts[0]
108 | prev_y = curr_y = y0
109 | curr_extra = extra0
110 |
111 | ys = [y0]
112 | prev_error_ratio = None
113 |
114 | for out_t in ts[1:]:
115 | while curr_t < out_t:
116 | next_t = min(curr_t + step_size, ts[-1])
117 | if self.adaptive:
118 | # Take 1 full step.
119 | next_y_full, _ = self.step(curr_t, next_t, curr_y, curr_extra)
120 | # Take 2 half steps.
121 | midpoint_t = 0.5 * (curr_t + next_t)
122 | midpoint_y, midpoint_extra = self.step(curr_t, midpoint_t, curr_y, curr_extra)
123 | next_y, next_extra = self.step(midpoint_t, next_t, midpoint_y, midpoint_extra)
124 |
125 | # Estimate error based on difference between 1 full step and 2 half steps.
126 | with torch.no_grad():
127 | error_estimate = adaptive_stepping.compute_error(next_y_full, next_y, self.rtol, self.atol)
128 | step_size, prev_error_ratio = adaptive_stepping.update_step_size(
129 | error_estimate=error_estimate,
130 | prev_step_size=step_size,
131 | prev_error_ratio=prev_error_ratio
132 | )
133 |
134 | if step_size < self.dt_min:
135 | warnings.warn("Hitting minimum allowed step size in adaptive time-stepping.")
136 | step_size = self.dt_min
137 | prev_error_ratio = None
138 |
139 | # Accept step.
140 | if error_estimate <= 1 or step_size <= self.dt_min:
141 | prev_t, prev_y = curr_t, curr_y
142 | curr_t, curr_y, curr_extra = next_t, next_y, next_extra
143 | else:
144 | prev_t, prev_y = curr_t, curr_y
145 | curr_y, curr_extra = self.step(curr_t, next_t, curr_y, curr_extra)
146 | curr_t = next_t
147 | ys.append(interp.linear_interp(t0=prev_t, y0=prev_y, t1=curr_t, y1=curr_y, t=out_t))
148 |
149 | return torch.stack(ys, dim=0), curr_extra
150 |
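151 |
152 | # Illustrative numeric check of the step-doubling idea in `integrate` above: for
153 | # Euler on dy/dt = -y, one full step and two half steps differ at O(dt^2), and that
154 | # difference is what `adaptive_stepping.compute_error` normalizes.
155 | def _sketch_step_doubling():
156 |     y0, dt = 1.0, 0.1
157 |     full = y0 + dt * (-y0)           # One full step.
158 |     mid = y0 + 0.5 * dt * (-y0)      # First half step.
159 |     half = mid + 0.5 * dt * (-mid)   # Second half step.
160 |     assert abs(abs(full - half) - dt ** 2 / 4) < 1e-12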
--------------------------------------------------------------------------------
/torchsde/_core/better_abc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Useful trick adapted from https://stackoverflow.com/a/50381071/12254339
16 | # Allows one to define abstract instance attributes
17 |
18 | import abc
19 |
20 |
21 | class DummyAttribute:
22 | pass
23 |
24 |
25 | def abstract_attribute(obj=None):
26 | if obj is None:
27 | obj = DummyAttribute()
28 | obj.__is_abstract_attribute__ = True
29 | return obj
30 |
31 |
32 | class ABCMeta(abc.ABCMeta):
33 | def __call__(cls, *args, **kwargs):
34 | instance = super(ABCMeta, cls).__call__(*args, **kwargs)
35 | abstract_attributes = {
36 | name
37 | for name in dir(instance)
38 | if getattr(getattr(instance, name), '__is_abstract_attribute__', False)
39 | }
40 | if abstract_attributes:
41 | raise NotImplementedError(
42 | "Can't instantiate abstract class {} with"
43 | " abstract attributes: {}".format(
44 | cls.__name__,
45 | ', '.join(abstract_attributes)
46 | )
47 | )
48 | return instance
49 |
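50 |
51 | # Usage sketch (illustrative): instantiation fails until the abstract attribute has
52 | # been assigned, typically in `__init__` of a concrete subclass.
53 | class _Base(metaclass=ABCMeta):
54 |     order = abstract_attribute()
55 |
56 |
57 | class _Concrete(_Base):
58 |     def __init__(self):
59 |         self.order = 1.0
60 |
61 |
62 | def _sketch():
63 |     _Concrete()       # Fine: `order` was set in __init__.
64 |     try:
65 |         _Base()       # Raises: abstract attribute `order` was never assigned.
66 |     except NotImplementedError:
67 |         pass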
--------------------------------------------------------------------------------
/torchsde/_core/interp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | def linear_interp(t0, y0, t1, y1, t):
16 | assert t0 <= t <= t1, f"Incorrect time order for linear interpolation: t0={t0}, t={t}, t1={t1}."
17 | y = (t1 - t) / (t1 - t0) * y0 + (t - t0) / (t1 - t0) * y1
18 | return y
19 |
--------------------------------------------------------------------------------
/torchsde/_core/methods/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .euler import Euler
16 | from .euler_heun import EulerHeun
17 | from .heun import Heun
18 | from .log_ode import LogODEMidpoint
19 | from .midpoint import Midpoint
20 | from .milstein import MilsteinIto, MilsteinStratonovich
21 | from .reversible_heun import ReversibleHeun, AdjointReversibleHeun
22 | from .srk import SRK
23 | from ...settings import METHODS, SDE_TYPES
24 |
25 |
26 | def select(method, sde_type):
27 | if method == METHODS.euler:
28 | return Euler
29 | elif method == METHODS.milstein and sde_type == SDE_TYPES.ito:
30 | return MilsteinIto
31 | elif method == METHODS.srk:
32 | return SRK
33 | elif method == METHODS.midpoint:
34 | return Midpoint
35 | elif method == METHODS.reversible_heun:
36 | return ReversibleHeun
37 | elif method == METHODS.adjoint_reversible_heun:
38 | return AdjointReversibleHeun
39 | elif method == METHODS.heun:
40 | return Heun
41 | elif method == METHODS.milstein and sde_type == SDE_TYPES.stratonovich:
42 | return MilsteinStratonovich
43 | elif method == METHODS.log_ode_midpoint:
44 | return LogODEMidpoint
45 | elif method == METHODS.euler_heun:
46 | return EulerHeun
47 | else:
48 | raise ValueError(f"Method '{method}' does not match any known method.")
49 |
--------------------------------------------------------------------------------
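`select` resolves a method string plus the declared SDE type to a solver class; note in particular that 'milstein' maps to different classes for Ito and Stratonovich SDEs. For example:

    from torchsde._core.methods import select
    from torchsde.settings import METHODS, SDE_TYPES

    print(select(METHODS.milstein, SDE_TYPES.ito).__name__)           # MilsteinIto
    print(select(METHODS.milstein, SDE_TYPES.stratonovich).__name__)  # MilsteinStratonovich
    print(select(METHODS.euler, SDE_TYPES.ito).__name__)              # Euler
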
/torchsde/_core/methods/euler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .. import base_solver
16 | from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS
17 |
18 |
19 | class Euler(base_solver.BaseSDESolver):
20 | weak_order = 1.0
21 | sde_type = SDE_TYPES.ito
22 | noise_types = NOISE_TYPES.all()
23 | levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()
24 |
25 | def __init__(self, sde, **kwargs):
26 | self.strong_order = 1.0 if sde.noise_type == NOISE_TYPES.additive else 0.5
27 | super(Euler, self).__init__(sde=sde, **kwargs)
28 |
29 | def step(self, t0, t1, y0, extra0):
30 | del extra0
31 | dt = t1 - t0
32 | I_k = self.bm(t0, t1)
33 |
34 | f, g_prod = self.sde.f_and_g_prod(t0, y0, I_k)
35 |
36 | y1 = y0 + f * dt + g_prod
37 | return y1, ()
38 |
--------------------------------------------------------------------------------
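The update in `Euler.step` is Euler–Maruyama: y1 = y0 + f(t0, y0) dt + g(t0, y0) dW, with dW ~ N(0, dt). A self-contained sketch of the same update on scalar geometric Brownian motion, without any torchsde machinery (`mu` and `sigma` are illustrative constants):

    import torch

    torch.manual_seed(0)
    mu, sigma = 0.05, 0.2              # illustrative drift/diffusion coefficients
    f = lambda t, y: mu * y            # drift
    g = lambda t, y: sigma * y         # diffusion (scalar noise)

    dt, n_steps = 1e-3, 1000
    y = torch.tensor(1.0)
    for i in range(n_steps):
        dW = torch.randn(()) * dt ** 0.5   # Brownian increment, std sqrt(dt)
        y = y + f(i * dt, y) * dt + g(i * dt, y) * dW
    print(y)  # one sample path value of GBM at t = 1
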
/torchsde/_core/methods/euler_heun.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .. import base_solver
16 | from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS
17 |
18 |
19 | class EulerHeun(base_solver.BaseSDESolver):
20 | weak_order = 1.0
21 | sde_type = SDE_TYPES.stratonovich
22 | noise_types = NOISE_TYPES.all()
23 | levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()
24 |
25 | def __init__(self, sde, **kwargs):
26 | self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0
27 | super(EulerHeun, self).__init__(sde=sde, **kwargs)
28 |
29 | def step(self, t0, t1, y0, extra0):
30 | del extra0
31 | dt = t1 - t0
32 | I_k = self.bm(t0, t1)
33 |
34 | f, g_prod = self.sde.f_and_g_prod(t0, y0, I_k)
35 |
36 | y_prime = y0 + g_prod
37 |
38 | g_prod_prime = self.sde.g_prod(t1, y_prime, I_k)
39 |
40 | y1 = y0 + dt * f + (g_prod + g_prod_prime) * 0.5
41 |
42 | return y1, ()
43 |
--------------------------------------------------------------------------------
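`EulerHeun` is a predictor-corrector scheme for Stratonovich SDEs: the predictor moves along the diffusion only, and the corrector averages the diffusion term at both endpoints while the drift stays explicit. The same step written out for a scalar SDE (illustrative coefficients):

    import torch

    f = lambda t, y: -y                   # drift
    g = lambda t, y: 0.5 * torch.sin(y)   # diffusion

    def euler_heun_step(t0, t1, y0, dW):
        dt = t1 - t0
        g_prod = g(t0, y0) * dW
        y_prime = y0 + g_prod                 # diffusion-only predictor
        g_prod_prime = g(t1, y_prime) * dW    # diffusion re-evaluated at the predictor
        return y0 + dt * f(t0, y0) + 0.5 * (g_prod + g_prod_prime)

    t0, t1 = torch.tensor(0.), torch.tensor(0.01)
    print(euler_heun_step(t0, t1, torch.tensor(1.), torch.randn(()) * (t1 - t0).sqrt()))
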
/torchsde/_core/methods/heun.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Stratonovich Heun method (strong order 1.0 scheme) from
16 |
17 | Burrage K., Burrage P. M. and Tian T. 2004 "Numerical methods for strong solutions
18 | of stochastic differential equations: an overview" Proc. R. Soc. Lond. A. 460: 373–402.
19 | """
20 |
21 | from .. import base_solver
22 | from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS
23 |
24 |
25 | class Heun(base_solver.BaseSDESolver):
26 | weak_order = 1.0
27 | sde_type = SDE_TYPES.stratonovich
28 | noise_types = NOISE_TYPES.all()
29 | levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()
30 |
31 | def __init__(self, sde, **kwargs):
32 | self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0
33 | super(Heun, self).__init__(sde=sde, **kwargs)
34 |
35 | def step(self, t0, t1, y0, extra0):
36 | del extra0
37 | dt = t1 - t0
38 | I_k = self.bm(t0, t1)
39 |
40 | f, g_prod = self.sde.f_and_g_prod(t0, y0, I_k)
41 |
42 | y0_prime = y0 + dt * f + g_prod
43 |
44 | f_prime, g_prod_prime = self.sde.f_and_g_prod(t1, y0_prime, I_k)
45 |
46 | y1 = y0 + (dt * (f + f_prime) + g_prod + g_prod_prime) * 0.5
47 |
48 | return y1, ()
49 |
--------------------------------------------------------------------------------
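`Heun` differs from `EulerHeun` in that its predictor takes a full Euler–Maruyama step and its corrector averages the drift as well as the diffusion (a stochastic trapezoidal rule). For comparison with the sketch above, with the same illustrative coefficients:

    import torch

    f = lambda t, y: -y
    g = lambda t, y: 0.5 * torch.sin(y)

    def heun_step(t0, t1, y0, dW):
        dt = t1 - t0
        f0, g_prod0 = f(t0, y0), g(t0, y0) * dW
        y_prime = y0 + dt * f0 + g_prod0                        # full predictor step
        f1, g_prod1 = f(t1, y_prime), g(t1, y_prime) * dW
        return y0 + 0.5 * (dt * (f0 + f1) + g_prod0 + g_prod1)  # average both terms
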
/torchsde/_core/methods/log_ode.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Log-ODE scheme constructed by combining Lie-Trotter splitting with the explicit midpoint method.
16 |
17 | The scheme uses Levy area approximations.
18 | """
19 |
20 | from .. import adjoint_sde
21 | from .. import base_solver
22 | from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS
23 |
24 |
25 | class LogODEMidpoint(base_solver.BaseSDESolver):
26 | weak_order = 1.0
27 | sde_type = SDE_TYPES.stratonovich
28 | noise_types = NOISE_TYPES.all()
29 | levy_area_approximations = (LEVY_AREA_APPROXIMATIONS.davie, LEVY_AREA_APPROXIMATIONS.foster)
30 |
31 | def __init__(self, sde, **kwargs):
32 | if isinstance(sde, adjoint_sde.AdjointSDE):
33 | raise ValueError("Log-ODE schemes cannot be used for adjoint SDEs, because they require "
34 | "direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
35 | "diffusion-vector product. Use a different method instead.")
36 | self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0
37 | super(LogODEMidpoint, self).__init__(sde=sde, **kwargs)
38 |
39 | def step(self, t0, t1, y0, extra0):
40 | del extra0
41 | dt = t1 - t0
42 | I_k, A = self.bm(t0, t1, return_A=True)
43 |
44 | f, g_prod = self.sde.f_and_g_prod(t0, y0, I_k)
45 |
46 | half_dt = 0.5 * dt
47 |
48 | t_prime = t0 + half_dt
49 | y_prime = y0 + half_dt * f + .5 * g_prod
50 |
51 | f_prime, g_prod_prime = self.sde.f_and_g_prod(t_prime, y_prime, I_k)
52 | dg_ga_prime = self.sde.dg_ga_jvp_column_sum(t_prime, y_prime, A)
53 |
54 | y1 = y0 + dt * f_prime + g_prod_prime + dg_ga_prime
55 |
56 | return y1, ()
57 |
--------------------------------------------------------------------------------
/torchsde/_core/methods/midpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .. import base_solver
16 | from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS
17 |
18 |
19 | class Midpoint(base_solver.BaseSDESolver):
20 | weak_order = 1.0
21 | sde_type = SDE_TYPES.stratonovich
22 | noise_types = NOISE_TYPES.all()
23 | levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()
24 |
25 | def __init__(self, sde, **kwargs):
26 | self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0
27 | super(Midpoint, self).__init__(sde=sde, **kwargs)
28 |
29 | def step(self, t0, t1, y0, extra0):
30 | del extra0
31 | dt = t1 - t0
32 | I_k = self.bm(t0, t1)
33 |
34 | f, g_prod = self.sde.f_and_g_prod(t0, y0, I_k)
35 |
36 | half_dt = 0.5 * dt
37 |
38 | t_prime = t0 + half_dt
39 | y_prime = y0 + half_dt * f + 0.5 * g_prod
40 |
41 | f_prime, g_prod_prime = self.sde.f_and_g_prod(t_prime, y_prime, I_k)
42 |
43 | y1 = y0 + dt * f_prime + g_prod_prime
44 |
45 | return y1, ()
46 |
--------------------------------------------------------------------------------
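`Midpoint` instead evaluates drift and diffusion once, at t0 + dt/2 with the state advanced by half of each increment, and then takes the full step with those midpoint values. A scalar sketch, same illustrative coefficients as above:

    import torch

    f = lambda t, y: -y
    g = lambda t, y: 0.5 * torch.sin(y)

    def midpoint_step(t0, t1, y0, dW):
        dt = t1 - t0
        t_mid = t0 + 0.5 * dt
        y_mid = y0 + 0.5 * dt * f(t0, y0) + 0.5 * g(t0, y0) * dW  # half-step state
        return y0 + dt * f(t_mid, y_mid) + g(t_mid, y_mid) * dW
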
/torchsde/_core/methods/milstein.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import abc
16 |
17 | from .. import adjoint_sde
18 | from .. import base_solver
19 | from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHOD_OPTIONS
20 |
21 |
22 | class BaseMilstein(base_solver.BaseSDESolver, metaclass=abc.ABCMeta):
23 | strong_order = 1.0
24 | weak_order = 1.0
25 | noise_types = (NOISE_TYPES.additive, NOISE_TYPES.diagonal, NOISE_TYPES.scalar)
26 | levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()
27 |
28 | def __init__(self, sde, options, **kwargs):
29 | if METHOD_OPTIONS.grad_free not in options:
30 | options[METHOD_OPTIONS.grad_free] = False
31 | if options[METHOD_OPTIONS.grad_free]:
32 | if sde.noise_type == NOISE_TYPES.additive:
33 | # dg=0 in this case, and gdg_prod is already set up to handle that, whilst the grad_free code path isn't.
34 | options[METHOD_OPTIONS.grad_free] = False
35 | if options[METHOD_OPTIONS.grad_free]:
36 | if isinstance(sde, adjoint_sde.AdjointSDE):
37 | # We need access to the diffusion to do things grad-free.
38 | raise ValueError(f"Derivative-free Milstein cannot be used for adjoint SDEs, because it requires "
39 | f"direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
40 | f"diffusion-vector product. Use derivative-using Milstein instead: "
41 | f"`adjoint_options=dict({METHOD_OPTIONS.grad_free}=False)`")
42 | super(BaseMilstein, self).__init__(sde=sde, options=options, **kwargs)
43 |
44 | @abc.abstractmethod
45 | def v_term(self, I_k, dt):
46 | raise NotImplementedError
47 |
48 | @abc.abstractmethod
49 | def y_prime_f_factor(self, dt, f):
50 | raise NotImplementedError
51 |
52 | def step(self, t0, t1, y0, extra0):
53 | del extra0
54 | dt = t1 - t0
55 | I_k = self.bm(t0, t1)
56 | v = self.v_term(I_k, dt)
57 |
58 | if self.options[METHOD_OPTIONS.grad_free]:
59 | f, g = self.sde.f_and_g(t0, y0)
60 | g_ = g.squeeze(2) if g.dim() == 3 else g # scalar noise vs diagonal noise
61 | sqrt_dt = dt.sqrt()
62 | # TODO: This y_prime_f_factor looks unnecessary: whether it's there or not we get the correct Taylor
63 | # expansion. I've (Patrick) not been able to find a reference making clear why it's sometimes included.
64 | y0_prime = y0 + self.y_prime_f_factor(dt, f) + g_ * sqrt_dt
65 | g_prime = self.sde.g(t0, y0_prime)
66 | g_prod_I_k = self.sde.prod(g, I_k)
67 | gdg_prod = self.sde.prod(g_prime - g, v) / (2 * sqrt_dt)
68 | else:
69 | f = self.sde.f(t0, y0)
70 | g_prod_I_k, gdg_prod = self.sde.g_prod_and_gdg_prod(t0, y0, I_k, 0.5 * v)
71 |
72 | y1 = y0 + f * dt + g_prod_I_k + gdg_prod
73 |
74 | return y1, ()
75 |
76 |
77 | class MilsteinIto(BaseMilstein):
78 | sde_type = SDE_TYPES.ito
79 |
80 | def v_term(self, I_k, dt):
81 | return I_k ** 2 - dt
82 |
83 | def y_prime_f_factor(self, dt, f):
84 | return dt * f
85 |
86 |
87 | class MilsteinStratonovich(BaseMilstein):
88 | sde_type = SDE_TYPES.stratonovich
89 |
90 | def v_term(self, I_k, dt):
91 | return I_k ** 2
92 |
93 | def y_prime_f_factor(self, dt, f):
94 | return 0.
95 |
--------------------------------------------------------------------------------
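For diagonal noise, the Ito Milstein scheme adds the correction 0.5 * g * (dg/dy) * (dW^2 - dt) on top of Euler–Maruyama, which is `v_term` above combined with the `0.5 * v` weighting passed to `g_prod_and_gdg_prod` (the Stratonovich variant drops the `- dt`). A scalar sketch of the derivative-using Ito update, with the diffusion derivative known in closed form (illustrative coefficients):

    import torch

    mu, sigma = 0.05, 0.2
    f = lambda t, y: mu * y
    g = lambda t, y: sigma * y
    dg = lambda t, y: sigma            # dg/dy, closed form for this toy diffusion

    def milstein_ito_step(t0, t1, y0, dW):
        dt = t1 - t0
        v = dW ** 2 - dt                               # v_term for the Ito scheme
        return (y0 + f(t0, y0) * dt + g(t0, y0) * dW
                + 0.5 * g(t0, y0) * dg(t0, y0) * v)    # Milstein correction
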
/torchsde/_core/methods/reversible_heun.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | """Reversible Heun method from
17 |
18 | https://arxiv.org/abs/2105.13493
19 |
20 | Known to be strong order 0.5 in general and strong order 1.0 for additive noise.
21 | Precise strong orders for diagonal/scalar noise, and weak order in general, are
22 | for the time being unknown.
23 |
24 | This solver uses some extra state such that it is _algebraically reversible_:
25 | it is possible to reconstruct its input (y0, f0, g0, z0) given its output
26 | (y1, f1, g1, z1).
27 |
28 | This means we can backpropagate by (a) inverting these operations, (b) doing a local
29 | forward operation to construct a computation graph, and (c) differentiating the local
30 | forward. This is what the adjoint method here does.
31 |
32 | This is in contrast to standard backpropagation, which requires holding all of these
33 | values in memory.
34 |
35 | This is in contrast to the standard continuous adjoint method (sdeint_adjoint), which
36 | can only perform this procedure approximately, and only produces approximate gradients
37 | as a result.
38 | """
39 |
40 | import torch
41 |
42 | from .. import adjoint_sde
43 | from .. import base_solver
44 | from .. import misc
45 | from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHODS
46 |
47 |
48 | class ReversibleHeun(base_solver.BaseSDESolver):
49 | weak_order = 1.0
50 | sde_type = SDE_TYPES.stratonovich
51 | noise_types = NOISE_TYPES.all()
52 | levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()
53 |
54 | def __init__(self, sde, **kwargs):
55 | self.strong_order = 1.0 if sde.noise_type == NOISE_TYPES.additive else 0.5
56 | super(ReversibleHeun, self).__init__(sde=sde, **kwargs)
57 |
58 | def init_extra_solver_state(self, t0, y0):
59 | return self.sde.f_and_g(t0, y0) + (y0,)
60 |
61 | def step(self, t0, t1, y0, extra0):
62 | f0, g0, z0 = extra0
63 | # f is a drift-like quantity
64 | # g is a diffusion-like quantity
65 | # z is a state-like quantity (like y)
66 | dt = t1 - t0
67 | dW = self.bm(t0, t1)
68 |
69 | z1 = 2 * y0 - z0 + f0 * dt + self.sde.prod(g0, dW)
70 | f1, g1 = self.sde.f_and_g(t1, z1)
71 | y1 = y0 + (f0 + f1) * (0.5 * dt) + self.sde.prod(g0 + g1, 0.5 * dW)
72 |
73 | return y1, (f1, g1, z1)
74 |
75 |
76 | class AdjointReversibleHeun(base_solver.BaseSDESolver):
77 | weak_order = 1.0
78 | sde_type = SDE_TYPES.stratonovich
79 | noise_types = NOISE_TYPES.all()
80 | levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()
81 |
82 | def __init__(self, sde, **kwargs):
83 | if not isinstance(sde, adjoint_sde.AdjointSDE):
84 | raise ValueError(f"{METHODS.adjoint_reversible_heun} can only be used for adjoint_method.")
85 | self.strong_order = 1.0 if sde.noise_type == NOISE_TYPES.additive else 0.5
86 | super(AdjointReversibleHeun, self).__init__(sde=sde, **kwargs)
87 | self.forward_sde = sde.forward_sde
88 |
89 | if self.forward_sde.noise_type == NOISE_TYPES.diagonal:
90 | self._adjoint_of_prod = lambda tensor1, tensor2: tensor1 * tensor2
91 | else:
92 | self._adjoint_of_prod = lambda tensor1, tensor2: tensor1.unsqueeze(-1) * tensor2.unsqueeze(-2)
93 |
94 | def init_extra_solver_state(self, t0, y0):
95 | # We expect to always be given the extra state from the forward pass.
96 | raise RuntimeError("Please report a bug to torchsde.")
97 |
98 | def step(self, t0, t1, y0, extra0):
99 | forward_f0, forward_g0, forward_z0 = extra0
100 | dt = t1 - t0
101 | dW = self.bm(t0, t1)
102 | half_dt = 0.5 * dt
103 | half_dW = 0.5 * dW
104 | forward_y0, adj_y0, (adj_f0, adj_g0, adj_z0, *adj_params), requires_grad = self.sde.get_state(t0, y0,
105 | extra_states=True)
106 | adj_y0_half_dt = adj_y0 * half_dt
107 | adj_y0_half_dW = self._adjoint_of_prod(adj_y0, half_dW)
108 |
109 | forward_z1 = 2 * forward_y0 - forward_z0 - forward_f0 * dt - self.forward_sde.prod(forward_g0, dW)
110 |
111 | adj_y1 = adj_y0
112 | adj_f1 = adj_y0_half_dt
113 | adj_f0 = adj_f0 + adj_y0_half_dt
114 | adj_g1 = adj_y0_half_dW
115 | adj_g0 = adj_g0 + adj_y0_half_dW
116 |
117 | # TODO: efficiency. It should be possible to make one fewer forward call by re-using the forward computation
118 | # in the previous step.
119 | with torch.enable_grad():
120 | if not forward_z0.requires_grad:
121 | forward_z0 = forward_z0.detach().requires_grad_()
122 | re_forward_f0, re_forward_g0 = self.forward_sde.f_and_g(-t0, forward_z0)
123 |
124 | vjp_z, *vjp_params = misc.vjp(outputs=(re_forward_f0, re_forward_g0),
125 | inputs=[forward_z0] + self.sde.params,
126 | grad_outputs=[adj_f0, adj_g0],
127 | allow_unused=True,
128 | retain_graph=True,
129 | create_graph=requires_grad)
130 | adj_z0 = adj_z0 + vjp_z
131 | adj_params = misc.seq_add(adj_params, vjp_params)
132 |
133 | forward_f1, forward_g1 = self.forward_sde.f_and_g(-t1, forward_z1)
134 | forward_y1 = forward_y0 - (forward_f0 + forward_f1) * half_dt - self.forward_sde.prod(forward_g0 + forward_g1,
135 | half_dW)
136 |
137 | adj_y1 = adj_y1 + 2 * adj_z0
138 | adj_z1 = -adj_z0
139 | adj_f1 = adj_f1 + adj_z0 * dt
140 | adj_g1 = adj_g1 + self._adjoint_of_prod(adj_z0, dW)
141 |
142 | y1 = misc.flatten([forward_y1, adj_y1, adj_f1, adj_g1, adj_z1] + adj_params).unsqueeze(0)
143 |
144 | return y1, (forward_f1, forward_g1, forward_z1)
145 |
--------------------------------------------------------------------------------
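The defining property of `ReversibleHeun` is that its update is algebraically invertible: given (y1, f1, g1, z1) together with the same dt and dW, one can recover (y0, f0, g0, z0) exactly up to floating point, which is what `AdjointReversibleHeun` exploits on the backward pass. A scalar sketch checking that inversion (illustrative coefficients):

    import torch

    f = lambda t, y: -y
    g = lambda t, y: 0.5 * torch.sin(y)

    def forward_step(t0, t1, y0, f0, g0, z0, dW):
        dt = t1 - t0
        z1 = 2 * y0 - z0 + f0 * dt + g0 * dW
        f1, g1 = f(t1, z1), g(t1, z1)
        y1 = y0 + (f0 + f1) * (0.5 * dt) + (g0 + g1) * (0.5 * dW)
        return y1, f1, g1, z1

    def inverse_step(t0, t1, y1, f1, g1, z1, dW):
        dt = t1 - t0
        z0 = 2 * y1 - z1 - f1 * dt - g1 * dW   # solve the two updates for z0
        f0, g0 = f(t0, z0), g(t0, z0)
        y0 = y1 - (f0 + f1) * (0.5 * dt) - (g0 + g1) * (0.5 * dW)
        return y0, f0, g0, z0

    t0, t1 = torch.tensor(0.), torch.tensor(0.1)
    y0 = torch.tensor(1.)
    f0, g0, z0 = f(t0, y0), g(t0, y0), y0            # init_extra_solver_state
    dW = torch.randn(()) * (t1 - t0).sqrt()
    y1, f1, g1, z1 = forward_step(t0, t1, y0, f0, g0, z0, dW)
    y0_rec, _, _, z0_rec = inverse_step(t0, t1, y1, f1, g1, z1, dW)
    print(torch.allclose(y0_rec, y0), torch.allclose(z0_rec, z0))  # True True
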
/torchsde/_core/methods/srk.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Strong order 1.5 scheme from
16 |
17 | Rößler, Andreas. "Runge–Kutta methods for the strong approximation of solutions
18 | of stochastic differential equations." SIAM Journal on Numerical Analysis 48,
19 | no. 3 (2010): 922-952.
20 | """
21 |
22 | from .tableaus import sra1, srid2
23 | from .. import adjoint_sde
24 | from .. import base_solver
25 | from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS
26 |
27 | _r2 = 1 / 2
28 | _r6 = 1 / 6
29 |
30 |
31 | class SRK(base_solver.BaseSDESolver):
32 | strong_order = 1.5
33 | weak_order = 1.5
34 | sde_type = SDE_TYPES.ito
35 | noise_types = (NOISE_TYPES.additive, NOISE_TYPES.diagonal, NOISE_TYPES.scalar)
36 | levy_area_approximations = (LEVY_AREA_APPROXIMATIONS.space_time,
37 | LEVY_AREA_APPROXIMATIONS.davie,
38 | LEVY_AREA_APPROXIMATIONS.foster)
39 |
40 | def __init__(self, sde, **kwargs):
41 | if sde.noise_type == NOISE_TYPES.additive:
42 | self.step = self.additive_step
43 | else:
44 | self.step = self.diagonal_or_scalar_step
45 |
46 | if isinstance(sde, adjoint_sde.AdjointSDE):
47 | raise ValueError("Stochastic Runge–Kutta methods cannot be used for adjoint SDEs, because it requires "
48 | "direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
49 | "diffusion-vector product. Use a different method instead.")
50 |
51 | super(SRK, self).__init__(sde=sde, **kwargs)
52 |
53 | def step(self, t0, t1, y, extra0):
54 | # Just to make @abstractmethod happy, as we assign during __init__.
55 | raise RuntimeError
56 |
57 | def diagonal_or_scalar_step(self, t0, t1, y0, extra0):
58 | del extra0
59 | dt = t1 - t0
60 | rdt = 1 / dt
61 | sqrt_dt = dt.sqrt()
62 | I_k, I_k0 = self.bm(t0, t1, return_U=True)
63 | I_kk = (I_k ** 2 - dt) * _r2
64 | I_kkk = (I_k ** 3 - 3 * dt * I_k) * _r6
65 |
66 | y1 = y0
67 | H0, H1 = [], []
68 | for s in range(srid2.STAGES):
69 | H0s, H1s = y0, y0 # Values at the current stage to be accumulated.
70 | for j in range(s):
71 | f = self.sde.f(t0 + srid2.C0[j] * dt, H0[j])
72 | g = self.sde.g(t0 + srid2.C1[j] * dt, H1[j])
73 | g = g.squeeze(2) if g.dim() == 3 else g
74 | H0s = H0s + srid2.A0[s][j] * f * dt + srid2.B0[s][j] * g * I_k0 * rdt
75 | H1s = H1s + srid2.A1[s][j] * f * dt + srid2.B1[s][j] * g * sqrt_dt
76 | H0.append(H0s)
77 | H1.append(H1s)
78 |
79 | f = self.sde.f(t0 + srid2.C0[s] * dt, H0s)
80 | g_weight = (
81 | srid2.beta1[s] * I_k +
82 | srid2.beta2[s] * I_kk / sqrt_dt +
83 | srid2.beta3[s] * I_k0 * rdt +
84 | srid2.beta4[s] * I_kkk * rdt
85 | )
86 | g_prod = self.sde.g_prod(t0 + srid2.C1[s] * dt, H1s, g_weight)
87 | y1 = y1 + srid2.alpha[s] * f * dt + g_prod
88 | return y1, ()
89 |
90 | def additive_step(self, t0, t1, y0, extra0):
91 | del extra0
92 | dt = t1 - t0
93 | rdt = 1 / dt
94 | I_k, I_k0 = self.bm(t0, t1, return_U=True)
95 |
96 | y1 = y0
97 | H0 = []
98 | for i in range(sra1.STAGES):
99 | H0i = y0
100 | for j in range(i):
101 | f = self.sde.f(t0 + sra1.C0[j] * dt, H0[j])
102 | g_weight = sra1.B0[i][j] * I_k0 * rdt
103 | g_prod = self.sde.g_prod(t0 + sra1.C1[j] * dt, y0, g_weight)
104 | H0i = H0i + sra1.A0[i][j] * f * dt + g_prod
105 | H0.append(H0i)
106 |
107 | f = self.sde.f(t0 + sra1.C0[i] * dt, H0i)
108 | g_weight = sra1.beta1[i] * I_k + sra1.beta2[i] * I_k0 * rdt
109 | g_prod = self.sde.g_prod(t0 + sra1.C1[i] * dt, y0, g_weight)
110 | y1 = y1 + sra1.alpha[i] * f * dt + g_prod
111 | return y1, ()
112 |
--------------------------------------------------------------------------------
/torchsde/_core/methods/tableaus/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/torchsde/_core/methods/tableaus/sra1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS".
16 | # For additive noise structure.
17 | # (ODE order, SDE strong order) = (2.0, 1.5).
18 |
19 | STAGES = 2
20 |
21 | C0 = (0, 3 / 4)
22 | C1 = (1, 0)
23 |
24 | A0 = (
25 | (),
26 | (3 / 4,),
27 | )
28 |
29 | B0 = (
30 | (),
31 | (3 / 2,),
32 | )
33 |
34 | alpha = (1 / 3, 2 / 3)
35 | beta1 = (1, 0)
36 | beta2 = (-1, 1)
37 |
--------------------------------------------------------------------------------
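With the diffusion switched off, the drift tableau above (C0 = (0, 3/4), lower-triangular A0 entry 3/4, alpha = (1/3, 2/3)) is a two-stage explicit Runge–Kutta method of order 2 (Ralston's method in some textbooks' convention), which is the "ODE order 2.0" claimed in the comment. A quick sketch checking that behaviour on dy/dt = y, using the coefficients above:

    import math
    from torchsde._core.methods.tableaus import sra1

    def drift_rk_step(t, y, h, f):
        # Deterministic part of SRA1: a two-stage explicit Runge-Kutta step.
        k = []
        for i in range(sra1.STAGES):
            yi = y + h * sum(sra1.A0[i][j] * k[j] for j in range(i))
            k.append(f(t + sra1.C0[i] * h, yi))
        return y + h * sum(a * ki for a, ki in zip(sra1.alpha, k))

    f = lambda t, y: y
    for n in (10, 20, 40):
        h, y = 1.0 / n, 1.0
        for i in range(n):
            y = drift_rk_step(i * h, y, h, f)
        print(n, abs(y - math.e))  # error shrinks ~4x per halving of h, i.e. order 2
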
/torchsde/_core/methods/tableaus/sra2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS".
16 | # For additive noise structure.
17 | # (ODE order, SDE strong order) = (2.0, 1.5).
18 |
19 | STAGES = 2
20 |
21 | C0 = (0, 3 / 4)
22 | C1 = (1 / 3, 1)
23 |
24 | A0 = (
25 | (),
26 | (3 / 4,),
27 | )
28 |
29 | B0 = (
30 | (),
31 | (3 / 2,),
32 | )
33 |
34 | alpha = (1 / 3, 2 / 3)
35 | beta1 = (0, 1)
36 | beta2 = (-3 / 2, 3 / 2)
37 |
--------------------------------------------------------------------------------
/torchsde/_core/methods/tableaus/sra3.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS".
16 | # For additive noise structure.
17 | # (ODE order, SDE strong order) = (3.0, 1.5).
18 |
19 | STAGES = 3
20 |
21 | C0 = (0, 1, 1 / 2)
22 | C1 = (1, 0, 0)
23 |
24 | A0 = (
25 | (),
26 | (1,),
27 | (1 / 4, 1 / 4),
28 | )
29 |
30 | B0 = (
31 | (),
32 | (0,),
33 | (1, 1 / 2),
34 | )
35 |
36 | alpha = (1 / 6, 1 / 6, 2 / 3)
37 | beta1 = (1, 0, 0)
38 | beta2 = (1, -1, 0)
39 |
--------------------------------------------------------------------------------
/torchsde/_core/methods/tableaus/srid1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS".
16 | # For diagonal noise structure.
17 | # (ODE order, SDE strong order) = (2.0, 1.5).
18 |
19 | STAGES = 4
20 |
21 | C0 = (0, 3 / 4, 0, 0)
22 | C1 = (0, 1 / 4, 1, 1 / 4)
23 |
24 | A0 = (
25 | (),
26 | (3 / 4,),
27 | (0, 0),
28 | (0, 0, 0),
29 | )
30 | A1 = (
31 | (),
32 | (1 / 4,),
33 | (1, 0),
34 | (0, 0, 1 / 4)
35 | )
36 |
37 | B0 = (
38 | (),
39 | (3 / 2,),
40 | (0, 0),
41 | (0, 0, 0),
42 | )
43 | B1 = (
44 | (),
45 | (1 / 2,),
46 | (-1, 0),
47 | (-5, 3, 1 / 2)
48 | )
49 |
50 | alpha = (1 / 3, 2 / 3, 0, 0)
51 | beta1 = (-1, 4 / 3, 2 / 3, 0)
52 | beta2 = (-1, 4 / 3, -1 / 3, 0)
53 | beta3 = (2, -4 / 3, -2 / 3, 0)
54 | beta4 = (-2, 5 / 3, -2 / 3, 1)
55 |
--------------------------------------------------------------------------------
/torchsde/_core/methods/tableaus/srid2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # From "RUNGE-KUTTA METHODS FOR THE STRONG APPROXIMATION OF SOLUTIONS OF STOCHASTIC DIFFERENTIAL EQUATIONS".
16 | # For diagonal noise structure.
17 | # (ODE order, SDE strong order) = (3.0, 1.5).
18 |
19 | STAGES = 4
20 |
21 | C0 = (0, 1, 1 / 2, 0)
22 | C1 = (0, 1 / 4, 1, 1 / 4)
23 |
24 | A0 = (
25 | (),
26 | (1,),
27 | (1 / 4, 1 / 4),
28 | (0, 0, 0)
29 | )
30 | A1 = (
31 | (),
32 | (1 / 4,),
33 | (1, 0),
34 | (0, 0, 1 / 4)
35 | )
36 |
37 | B0 = (
38 | (),
39 | (0,),
40 | (1, 1 / 2),
41 | (0, 0, 0),
42 | )
43 | B1 = (
44 | (),
45 | (-1 / 2,),
46 | (1, 0),
47 | (2, -1, 1 / 2)
48 | )
49 |
50 | alpha = (1 / 6, 1 / 6, 2 / 3, 0)
51 | beta1 = (-1, 4 / 3, 2 / 3, 0)
52 | beta2 = (1, -4 / 3, 1 / 3, 0)
53 | beta3 = (2, -4 / 3, -2 / 3, 0)
54 | beta4 = (-2, 5 / 3, -2 / 3, 1)
55 |
--------------------------------------------------------------------------------
/torchsde/_core/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | import torch
18 |
19 |
20 | def assert_no_grad(names, maybe_tensors):
21 | for name, maybe_tensor in zip(names, maybe_tensors):
22 | if torch.is_tensor(maybe_tensor) and maybe_tensor.requires_grad:
23 | raise ValueError(f"Argument {name} must not require gradient.")
24 |
25 |
26 | def handle_unused_kwargs(unused_kwargs, msg=None):
27 | if len(unused_kwargs) > 0:
28 | if msg is not None:
29 | warnings.warn(f"{msg}: Unexpected arguments {unused_kwargs}")
30 | else:
31 | warnings.warn(f"Unexpected arguments {unused_kwargs}")
32 |
33 |
34 | def flatten(sequence):
35 | return torch.cat([p.reshape(-1) for p in sequence]) if len(sequence) > 0 else torch.tensor([])
36 |
37 |
38 | def convert_none_to_zeros(sequence, like_sequence):
39 | return [torch.zeros_like(q) if p is None else p for p, q in zip(sequence, like_sequence)]
40 |
41 |
42 | def make_seq_requires_grad(sequence):
43 | return [p if p.requires_grad else p.detach().requires_grad_(True) for p in sequence]
44 |
45 |
46 | def is_strictly_increasing(ts):
47 | return all(x < y for x, y in zip(ts[:-1], ts[1:]))
48 |
49 |
50 | def is_nan(t):
51 | return torch.any(torch.isnan(t))
52 |
53 |
54 | def seq_add(*seqs):
55 | return [sum(seq) for seq in zip(*seqs)]
56 |
57 |
58 | def seq_sub(xs, ys):
59 | return [x - y for x, y in zip(xs, ys)]
60 |
61 |
62 | def batch_mvp(m, v):
63 | return torch.bmm(m, v.unsqueeze(-1)).squeeze(dim=-1)
64 |
65 |
66 | def stable_division(a, b, epsilon=1e-7):
67 | b = torch.where(b.abs().detach() > epsilon, b, torch.full_like(b, fill_value=epsilon) * b.sign())
68 | return a / b
69 |
70 |
71 | def vjp(outputs, inputs, **kwargs):
72 | if torch.is_tensor(inputs):
73 | inputs = [inputs]
74 | _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs]  # Workaround for PyTorch bug #39784.  # noqa: F841
75 |
76 | if torch.is_tensor(outputs):
77 | outputs = [outputs]
78 | outputs = make_seq_requires_grad(outputs)
79 |
80 | _vjp = torch.autograd.grad(outputs, inputs, **kwargs)
81 | return convert_none_to_zeros(_vjp, inputs)
82 |
83 |
84 | def jvp(outputs, inputs, grad_inputs=None, **kwargs):
85 | # Unlike `torch.autograd.functional.jvp`, this function avoids repeating forward computation.
86 | if torch.is_tensor(inputs):
87 | inputs = [inputs]
88 | _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs]  # Workaround for PyTorch bug #39784.  # noqa: F841
89 |
90 | if torch.is_tensor(outputs):
91 | outputs = [outputs]
92 | outputs = make_seq_requires_grad(outputs)
93 |
94 | dummy_outputs = [torch.zeros_like(o, requires_grad=True) for o in outputs]
95 | _vjp = torch.autograd.grad(outputs, inputs, grad_outputs=dummy_outputs, create_graph=True, allow_unused=True)
96 | _vjp = make_seq_requires_grad(convert_none_to_zeros(_vjp, inputs))
97 |
98 | _jvp = torch.autograd.grad(_vjp, dummy_outputs, grad_outputs=grad_inputs, **kwargs)
99 | return convert_none_to_zeros(_jvp, dummy_outputs)
100 |
101 |
102 | def flat_to_shape(flat_tensor, shapes):
103 | """Convert a flat tensor to a list of tensors with specified shapes.
104 |
105 | `flat_tensor` must have exactly the number of elements as stated in `shapes`.
106 | """
107 | numels = [shape.numel() for shape in shapes]
108 | return [flat.reshape(shape) for flat, shape in zip(flat_tensor.split(split_size=numels), shapes)]
109 |
--------------------------------------------------------------------------------
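Two of the helpers above pair up: `flatten` packs a sequence of tensors into one flat vector (as the adjoint code does with its augmented state), and `flat_to_shape` undoes it given the original shapes. A round-trip sketch, plus `stable_division` guarding against tiny denominators:

    import torch
    from torchsde._core import misc

    tensors = [torch.randn(2, 3), torch.randn(5)]
    flat = misc.flatten(tensors)                       # shape (11,)
    restored = misc.flat_to_shape(flat, [t.shape for t in tensors])
    print(all(torch.equal(a, b) for a, b in zip(tensors, restored)))  # True

    num = torch.tensor([1., 1.])
    den = torch.tensor([1e-12, 2.])
    print(misc.stable_division(num, den))  # denominator clamped near 1e-7; no inf
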
/torchsde/settings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | class ContainerMeta(type):
17 | def all(cls):
18 | return sorted(getattr(cls, x) for x in dir(cls) if not x.startswith('__'))
19 |
20 | def __str__(cls):
21 | return str(cls.all())
22 |
23 | def __contains__(cls, item):
24 | return item in cls.all()
25 |
26 |
27 | # TODO: consider moving all these enums into some appropriate section of the code, rather than having them be global
28 | # like this. (e.g. instead set METHODS = {'euler': Euler, ...} in methods/__init__.py)
29 | class METHODS(metaclass=ContainerMeta):
30 | euler = 'euler'
31 | milstein = 'milstein'
32 | srk = 'srk'
33 | midpoint = 'midpoint'
34 | reversible_heun = 'reversible_heun'
35 | adjoint_reversible_heun = 'adjoint_reversible_heun'
36 | heun = 'heun'
37 | log_ode_midpoint = 'log_ode'
38 | euler_heun = 'euler_heun'
39 |
40 |
41 | class NOISE_TYPES(metaclass=ContainerMeta): # noqa
42 | general = 'general'
43 | diagonal = 'diagonal'
44 | scalar = 'scalar'
45 | additive = 'additive'
46 |
47 |
48 | class SDE_TYPES(metaclass=ContainerMeta): # noqa
49 | ito = 'ito'
50 | stratonovich = 'stratonovich'
51 |
52 |
53 | class LEVY_AREA_APPROXIMATIONS(metaclass=ContainerMeta): # noqa
54 | none = 'none' # Don't compute any Levy area approximation
55 | space_time = 'space-time' # Only compute an (exact) space-time Levy area
56 | davie = 'davie' # Compute Davie's approximation to Levy area
57 | foster = 'foster' # Compute Foster's correction to Davie's approximation to Levy area
58 |
59 |
60 | class METHOD_OPTIONS(metaclass=ContainerMeta): # noqa
61 | grad_free = 'grad_free'
62 |
--------------------------------------------------------------------------------
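`ContainerMeta` turns these classes into simple string-enum containers: membership testing, listing and printing all work at the class level, without instantiation. For example:

    from torchsde.settings import METHODS, NOISE_TYPES, SDE_TYPES

    print('euler' in METHODS)    # True, via ContainerMeta.__contains__
    print('rk45' in METHODS)     # False
    print(NOISE_TYPES.all())     # ['additive', 'diagonal', 'general', 'scalar']
    print(SDE_TYPES)             # prints "['ito', 'stratonovich']" via __str__
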
/torchsde/types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # We import more from `typing` than we strictly need, so that other modules can import from this file and not `typing`.
16 | from typing import Sequence, Union, Optional, Any, Dict, Tuple, Callable # noqa: F401
17 |
18 | import torch
19 |
20 | Tensor = torch.Tensor
21 | Tensors = Sequence[Tensor]
22 | TensorOrTensors = Union[Tensor, Tensors]
23 |
24 | Scalar = Union[float, Tensor]
25 | Vector = Union[Sequence[float], Tensor]
26 |
27 | Module = torch.nn.Module
28 | Modules = Sequence[Module]
29 | ModuleOrModules = Union[Module, Modules]
30 |
31 | Size = torch.Size
32 | Sizes = Sequence[Size]
33 |
--------------------------------------------------------------------------------