├── .flake8
├── .github
├── FUNDING.yml
└── workflows
│ ├── publish.yaml
│ └── run_tests.yaml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── example
├── irregular_data.py
├── logsignature_example.py
└── time_series_classification.py
├── imgs
└── main.png
├── setup.py
├── test
├── markers.py
├── pytest.ini
├── test_cdeint.py
├── test_example.py
├── test_hermite_cubic.py
├── test_linear_interpolation.py
├── test_log_ode.py
├── test_misc.py
├── test_natural_cubic_spline.py
└── test_tricks.py
└── torchcde
├── __init__.py
├── interpolation_base.py
├── interpolation_cubic.py
├── interpolation_hermite_cubic_bdiff.py
├── interpolation_linear.py
├── log_ode.py
├── misc.py
└── solver.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = W291,W503,E203
4 | per-file-ignores = __init__.py: F401
5 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: [patrick-kidger]
2 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | name: Publish
2 | on:
3 | release:
4 | types: [published]
5 | branches: [master]
6 |
7 | jobs:
8 | test_and_build_and_publish:
9 | strategy:
10 | matrix:
11 | python-version: [ 3.6, 3.8 ]
12 | os: [ macos-latest, ubuntu-latest, windows-latest ]
13 | fail-fast: false
14 | runs-on: ${{ matrix.os }}
15 | steps:
16 | - name: Checkout code
17 | uses: actions/checkout@v2
18 |
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 |
24 | - name: Check version
25 | shell: bash
26 | run: |
27 | python -m pip install --upgrade pip
28 | python -m pip install torchcde
29 | pypi_info=$(pip list | grep torchcde)
30 | pypi_version=$(echo ${pypi_info} | cut -d " " -f2)
31 | python -m pip uninstall -y torchcde
32 | python setup.py install
33 | master_info=$(pip list | grep torchcde)
34 | master_version=$(echo ${master_info} | cut -d " " -f2)
35 | python -m pip uninstall -y torchcde
36 | python -c "import itertools as it;
37 | import sys;
38 | _, pypi_version, master_version = sys.argv;
39 | pypi_version_ = [int(i) for i in pypi_version.split('.')];
40 | master_version_ = [int(i) for i in master_version.split('.')];
41 | pypi_version__ = tuple(p for m, p in it.zip_longest(master_version_, pypi_version_, fillvalue=0));
42 | master_version__ = tuple(m for m, p in it.zip_longest(master_version_, pypi_version_, fillvalue=0));
43 | sys.exit(master_version__ <= pypi_version__)" ${pypi_version} ${master_version}
44 |
45 | - name: Install dependencies
46 | run: |
47 | python -m pip install flake8 pytest wheel
48 |
49 | - name: Lint with flake8
50 | run: |
51 | python -m flake8 .
52 |
53 | # For some reason egg files see to be getting uploaded to PyPI;
54 | # not sure why they're being created.
55 | - name: Build and install sdist
56 | shell: bash
57 | run: |
58 | python -m pip install torch==1.9.0
59 | python setup.py sdist bdist_wheel
60 | rm -f dist/*.egg
61 | python -m pip install dist/*.tar.gz
62 |
63 | # Happens after install the sdist, so that PyTorch is already installed.
64 | # We then detect the version of PyTorch installed, and install the
65 | # appropriate version of Signatory.
66 | - name: Install Signatory
67 | if: matrix.os != 'macos-latest'
68 | shell: bash
69 | run: |
70 | signatory_version=$(python -c "import re
71 | import subprocess
72 | version_msg = subprocess.run('pip install --use-deprecated=legacy-resolver signatory==', shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
73 | version_re = re.compile(rb'from versions: ([0-9\. ,]*)\)')
74 | last_version = version_re.search(version_msg.stderr).group(1).split(b', ')[-1].decode('utf-8').split('.')
75 | assert len(last_version) == 6
76 | last_version = '.'.join(last_version[:3])
77 | print(last_version)")
78 | torch_info=$(pip list | grep '^torch ')
79 | torch_version=$(echo ${torch_info} | cut -d " " -f2)
80 | python -m pip install signatory==${signatory_version}.${torch_version}
81 |
82 | - name: Run sdist tests
83 | run: |
84 | python -m pytest
85 | python -m pip uninstall -y torchcde
86 |
87 | - name: Run bdist_wheel tests
88 | shell: bash
89 | run: |
90 | python -m pip install dist/*.whl
91 | python -m pytest
92 | python -m pip uninstall -y torchcde
93 |
94 | - name: Publish to PyPI
95 | if: matrix.python-version == '3.8' && matrix.os == 'ubuntu-latest'
96 | uses: pypa/gh-action-pypi-publish@v1.4.2
97 | with:
98 | user: ${{ secrets.pypi_username }}
99 | password: ${{ secrets.pypi_password }}
100 |
--------------------------------------------------------------------------------
/.github/workflows/run_tests.yaml:
--------------------------------------------------------------------------------
1 | name: Run test suite
2 | on: [pull_request]
3 |
4 | jobs:
5 | check_version:
6 | strategy:
7 | matrix:
8 | python-version: [ 3.8 ]
9 | os: [ ubuntu-latest ]
10 | runs-on: ${{ matrix.os }}
11 | steps:
12 | - name: Checkout code
13 | uses: actions/checkout@v2
14 |
15 | - name: Set up Python ${{ matrix.python-version }}
16 | uses: actions/setup-python@v2
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 |
20 | - name: Check version
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install git+https://github.com/patrick-kidger/torchcde.git
24 | master_info=$(pip list | grep torchcde)
25 | master_version=$(echo ${master_info} | cut -d " " -f2)
26 | pip uninstall -y torchcde
27 | python setup.py install
28 | pr_info=$(pip list | grep torchcde)
29 | pr_version=$(echo ${pr_info} | cut -d " " -f2)
30 | python -c "import itertools as it;
31 | import sys;
32 | master_version = sys.argv[1];
33 | pr_version = sys.argv[2];
34 | master_version_ = [int(i) for i in master_version.split('.')];
35 | pr_version_ = [int(i) for i in pr_version.split('.')];
36 | master_version__ = tuple(m for p, m in it.zip_longest(pr_version_, master_version_, fillvalue=0));
37 | pr_version__ = tuple(p for p, m in it.zip_longest(pr_version_, master_version_, fillvalue=0));
38 | sys.exit(pr_version__ <= master_version__)" ${master_version} ${pr_version}
39 |
40 | test:
41 | needs: [ check_version ]
42 | strategy:
43 | matrix:
44 | python-version: [ 3.6, 3.8 ]
45 | os: [ ubuntu-latest, macOS-latest, windows-latest ]
46 | fail-fast: false
47 | runs-on: ${{ matrix.os }}
48 | steps:
49 | - name: Checkout code
50 | uses: actions/checkout@v2
51 |
52 | - name: Set up Python ${{ matrix.python-version }}
53 | uses: actions/setup-python@v2
54 | with:
55 | python-version: ${{ matrix.python-version }}
56 |
57 | - name: Install dependencies
58 | run: |
59 | python -m pip install --upgrade pip
60 | python -m pip install flake8 pytest
61 |
62 | - name: Install torchcde
63 | run: |
64 | python -m pip install torch==1.9.0
65 | python setup.py install
66 |
67 | - name: Install Signatory
68 | if: matrix.os != 'macos-latest'
69 | shell: bash
70 | run: |
71 | signatory_version=$(python -c "import re
72 | import subprocess
73 | version_msg = subprocess.run('pip install --use-deprecated=legacy-resolver signatory==', shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
74 | version_re = re.compile(rb'from versions: ([0-9\. ,]*)\)')
75 | last_version = version_re.search(version_msg.stderr).group(1).split(b', ')[-1].decode('utf-8').split('.')
76 | assert len(last_version) == 6
77 | last_version = '.'.join(last_version[:3])
78 | print(last_version)")
79 | torch_info=$(pip list | grep '^torch ')
80 | torch_version=$(echo ${torch_info} | cut -d " " -f2)
81 | python -m pip install signatory==${signatory_version}.${torch_version}
82 |
83 | - name: Lint with flake8
84 | run: |
85 | python -m flake8 .
86 |
87 | - name: Test with pytest
88 | run: |
89 | python -m pytest
90 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 | **/.ipynb_checkpoints/
3 | *.py[cod]
4 | .idea/
5 | .vs/
6 | build/
7 | dist/
8 | *.egg_info/
9 | *.egg
10 | *.so
11 | *.egg-info/
12 | **/.mypy_cache/
13 | env/
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | prune test
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
torchcde
2 | Differentiable GPU-capable solvers for CDEs
3 |
4 | **Update: for any new projects, I would now recommend using [Diffrax](https://github.com/patrick-kidger/diffrax) instead. This is much faster, and producion-quality. torchcde was its prototype as a research project!**
5 |
6 | This library provides differentiable GPU-capable solvers for controlled differential equations (CDEs). Backpropagation through the solver or via the adjoint method is supported; the latter allows for improved memory efficiency.
7 |
8 | In particular this allows for building [Neural Controlled Differential Equation](https://github.com/patrick-kidger/NeuralCDE) models, which are state-of-the-art models for (arbitrarily irregular!) time series. Neural CDEs can be thought of as a "continuous time RNN".
9 |
10 | ---
11 |
12 |
13 |
14 |
15 |
16 | ## Installation
17 |
18 | ```bash
19 | pip install torchcde
20 | ```
21 |
22 | Requires PyTorch >=1.7.
23 |
24 | ## Example
25 | ```python
26 | import torch
27 | import torchcde
28 |
29 | # Create some data
30 | batch, length, input_channels = 1, 10, 2
31 | hidden_channels = 3
32 | t = torch.linspace(0, 1, length)
33 | t_ = t.unsqueeze(0).unsqueeze(-1).expand(batch, length, 1)
34 | x_ = torch.rand(batch, length, input_channels - 1)
35 | x = torch.cat([t_, x_], dim=2) # include time as a channel
36 |
37 | # Interpolate it
38 | coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)
39 | X = torchcde.CubicSpline(coeffs)
40 |
41 | # Create the Neural CDE system
42 | class F(torch.nn.Module):
43 | def __init__(self):
44 | super(F, self).__init__()
45 | self.linear = torch.nn.Linear(hidden_channels,
46 | hidden_channels * input_channels)
47 |
48 | def forward(self, t, z):
49 | return self.linear(z).view(batch, hidden_channels, input_channels)
50 |
51 | func = F()
52 | z0 = torch.rand(batch, hidden_channels)
53 |
54 | # Integrate it
55 | torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval)
56 | ```
57 |
58 | See [time_series_classification.py](./example/time_series_classification.py), which demonstrates how to use the library to train a Neural CDE model to predict the chirality of a spiral.
59 |
60 | Also see [irregular_data.py](./example/irregular_data.py), for demonstrations on how to handle variable-length inputs, irregular sampling, or missing data, all of which can be handled easily, without changing the model.
61 |
62 | ## Citation
63 | If you found use this library useful, please consider citing
64 |
65 | ```bibtex
66 | @article{kidger2020neuralcde,
67 | title={{N}eural {C}ontrolled {D}ifferential {E}quations for {I}rregular {T}ime {S}eries},
68 | author={Kidger, Patrick and Morrill, James and Foster, James and Lyons, Terry},
69 | journal={Advances in Neural Information Processing Systems},
70 | year={2020}
71 | }
72 | ```
73 |
74 | ## Documentation
75 |
76 | The library consists of two main components: (1) integrators for solving controlled differential equations, and (2) ways of constructing controls from data.
77 |
78 | ### Integrators
79 |
80 | The library provides the `cdeint` function, which solves the system of controlled differential equations:
81 | ```
82 | dz(t) = f(t, z(t))dX(t) z(t_0) = z0
83 | ```
84 |
85 | The goal is to find the response `z` driven by the control `X`. This can be re-written as the following differential equation:
86 | ```
87 | dz/dt(t) = f(t, z)dX/dt(t) z(t_0) = z0
88 | ```
89 | where the right hand side describes a matrix-vector product between `f(t, z)` and `dX/dt(t)`.
90 |
91 | This is solved by
92 | ```python
93 | cdeint(X, func, z0, t, adjoint, backend, **kwargs)
94 | ```
95 | where letting `...` denote an arbitrary number of batch dimensions:
96 | * `X` is a `torch.nn.Module` with method `derivative`, such that `X.derivative(t)` is a Tensor of shape `(..., input_channels)`,
97 | * `func` is a `torch.nn.Module`, such that `func(t, z)` returns a Tensor of shape `(..., hidden_channels, input_channels)`,
98 | * `z0` is a Tensor of shape `(..., hidden_channels)`,
99 | * `t` is a one-dimensional Tensor of times to output `z` at.
100 | * `adjoint` is a boolean (defaulting to `True`).
101 | * `backend` is a string (defaulting to `"torchdiffeq"`).
102 |
103 | Adjoint backpropagation (which is slower but more memory efficient) can be toggled with `adjoint=True/False`.
104 |
105 | The `backend` should be either `"torchdiffeq"` or `"torchsde"`, corresponding to which underlying library to use for the solvers. If using torchsde then the stochastic term is zero -- so the CDE is still reduced to an ODE. This is useful if one library supports a feature that the other doesn't. (For example torchsde supports a reversible solver, the [reversible Heun method](https://arxiv.org/abs/2105.13493); at time of writing torchdiffeq does not support any reversible solvers.)
106 |
107 | Any additional `**kwargs` are passed on to `torchdiffeq.odeint[_adjoint]` or `torchsde.sdeint[_adjoint]`, for example to specify the solver.
108 |
109 | ### Constructing controls
110 |
111 | A very common scenario is to construct the continuous control`X` from discrete data (which may be irregularly sampled with missing values). To support this, we provide three main interpolation schemes:
112 |
113 | * Hermite cubic splines with backwards differences
114 | * Linear interpolation
115 | * Rectilinear interpolation
116 |
117 | _Note that if for some reason you already have a continuous control `X` then you won't need an interpolation scheme at all!_
118 |
119 | Hermite cubic splines are usually the best choice, if possible. Linear and rectilinear interpolations are particularly useful in causal settings -- when at inference time the data is arriving over time. We go into further details in the [Further Documentation](#further-documentation) below.
120 |
121 | Just demonstrating Hermite cubic splines for now:
122 | ```python
123 | coeffs = hermite_cubic_coefficients_with_backward_differences(x)
124 |
125 | # coeffs is a torch.Tensor you can save, load,
126 | # pass through Datasets and DataLoaders etc.
127 |
128 | X = CubicSpline(coeffs)
129 | ```
130 | where:
131 | * `x` is a Tensor of shape `(..., length, input_channels)`, where `...` is some number of batch dimensions. Missing data should be represented as a `NaN`.
132 |
133 | The interface provided by `CubicSpline` is:
134 |
135 | * `.interval`, which gives the time interval the spline is defined over. (Often used as the `t` argument in `cdeint`.) This is determined implicitly from the length of the data, and so does _not_ in general correspond to the time your data was actually observed at. (See the [Further Documentation](#further-documentation) note on reparameterisation invariance.)
136 | * `.grid_points` is all of the knots in the spline, so that for example `X.evaluate(X.grid_points)` will recover the original data.
137 | * `.evaluate(t)`, where `t` is an any-dimensional Tensor, to evaluate the spline at any (collection of) time(s).
138 | * `.derivative(t)`, where `t` is an any-dimensional Tensor, to evaluate the derivative of the spline at any (collection of) time(s).
139 |
140 | Usually `hermite_cubic_coefficients_with_backward_differences` should be computed as a preprocessing step, whilst `CubicSpline` should be called inside the forward pass of your model. See [time_series_classification.py](./example/time_series_classification.py) for a worked example.
141 |
142 | Then call:
143 | ```python
144 | cdeint(X=X, func=... z0=..., t=X.interval)
145 | ```
146 |
147 | ## Further documentation
148 | The earlier documentation section should give everything you need to get up and running.
149 |
150 | Here we discuss a few more advanced bits of functionality:
151 | * The reparameterisation invariance property of CDEs.
152 | * Other interpolation methods, and the differences between them.
153 | * The use of fixed solvers. (They just work.)
154 | * Stacking CDEs (i.e. controlling one by the output of another).
155 | * Computing logsignatures for the log-ODE method.
156 |
157 | #### Reparameterisation invariance
158 | This is a classical fact about CDEs.
159 |
160 | Let
be differentiable and increasing, with
and
. Let
, let
, let
, and let
. Then substituting
into a CDE (and just using the standard change of variables formula):
161 |
162 |
163 |
164 | We see that
**also** satisfies the neural CDE equation, just with
as input instead of
. In other words, using
changes the speed at which we traverse the input
, and correspondingly changes the speed at which we traverse the output
-- and that's it! In particular the CDE itself doesn't need any adjusting.
165 |
166 | This ends up being a really useful fact for writing neater software. We can handle things like messy data (e.g. variable length time series) just during data preprocessing, without it complicating the model code. In [time_series_classification.py](/example/time_series_classification.py), the region we integrate over is given by `X.interval` as a standardised region to integrate over. In the example [irregular_data.py](/example/irregular_data.py), we use this to handle variable-length data.
167 |
168 | #### Different interpolation methods
169 | For a full breakdown into the interpolation schemes, see [Neural Controlled Differential Equations for Online Prediction Tasks](https://arxiv.org/pdf/2106.11028.pdf) where each interpolation scheme is scrutinised, and best practices are presented.
170 |
171 | In brief:
172 | * Will your data: (a) be arriving in an online fashion at inference time; and (b) be multivariate; and (c) potentially have missing values?
173 | * Yes: rectilinear interpolation.
174 | * No: Are you using an adaptive step size solver (e.g. the default `dopri5`)?
175 | * Yes: Hermite cubic splines with backwards differences.
176 | * No: linear interpolation.
177 | * Not sure / both: Hermite cubic splines with backwards differences.
178 |
179 | In more detail:
180 |
181 | * Linear interpolation: these are "kind-of" causal.
182 |
183 | During inference we can simply wait at each time point for the next data point to arrive, and then interpolate towards the next data point when it arrives, and solve the CDE over that interval.
184 |
185 | If there is missing data, however, then this isn't possible. (As some of the channels might not have observations you can interpolate to.) In this case use rectilinear interpolation, below.
186 |
187 | Example:
188 | ```python
189 | coeffs = linear_interpolation_coeffs(x)
190 | X = LinearInterpolation(coeffs)
191 | cdeint(X=X, ...)
192 | ```
193 |
194 | Linear interpolation has kinks. If using adaptive step size solvers then it should be told about the kinks. (Rather than expensively finding them for itself -- slowing down to resolve the kink, and then speeding up again afterwards.) This is done with the `jump_t` option when using the `torchdiffeq` backend:
195 | ```python
196 | cdeint(...,
197 | backend='torchdiffeq',
198 | method='dopri5',
199 | options=dict(jump_t=X.grid_points))
200 | ```
201 | Although adaptive step size solvers will probably find it easier to resolve Hermite cubic splines with backward differences, below.
202 |
203 | * Hermite cubic splines with backwards differences: these are "kind-of" causal in the same way as linear interpolation, but dont have kinks, which makes them faster with adaptive step size solvers. (But is simply an unnecessary overhead when working with fixed step size solvers, which is why we recommend linear interpolation is you know you're only going to be using fixed step size solvers.)
204 |
205 | Example:
206 | ```python
207 | coeffs = hermite_cubic_coefficients_with_backward_differences(x)
208 | X = CubicSpline(coeffs)
209 | cdeint(X=X, ...)
210 | ```
211 |
212 | * Rectilinear interpolation: This is appropriate if there is multivariate missing data, and you need causality.
213 |
214 | What is done is to linearly interpolate forward in time (keeping the observations constant), and then linearly interpolate the values (keeping the time constant). This is possible because time is a channel (and doesn't need to line up with the "time" used in the differential equation solver, as per the reparameterisation invariance of the previous section).
215 |
216 | This can be a bit unintuitive at first. We suggest firing up matplotlib and plotting things to get a feel for what's going on. As a fun sidenote, using rectilinear interpolation makes neural CDEs generalise [ODE-RNNs](https://arxiv.org/abs/1907.03907).
217 |
218 | Example:
219 | ```python
220 | # standard setup for a neural CDE: include time as a channel
221 | t = torch.linspace(0, 1, 10)
222 | x = torch.rand(2, 10, 3)
223 | t_ = t.unsqueeze(0).unsqueeze(-1).expand(2, 10, 1)
224 | x = torch.cat([t_, x], dim=-1)
225 | del t, t_ # won't need these again!
226 | # The `rectilinear` argument is the channel index corresponding to time
227 | coeffs = linear_interpolation_coeffs(x, rectilinear=0)
228 | X = LinearInterpolation(coeffs)
229 | cdeint(X=X, ...)
230 | ```
231 |
232 | As before, if using an adaptive step size solver, it should be informed about the kinks.
233 | ```python
234 | cdeint(...,
235 | backend='torchdiffeq',
236 | method='dopri5',
237 | options=dict(jump_t=X.grid_points))
238 | ```
239 |
240 | #### Fixed solvers
241 | Solving CDEs (regardless of the choice of interpolation scheme in a Neural CDE) with fixed solvers like `euler`, `midpoint`, `rk4` etc. is pretty much exactly the same as solving an ODE with a fixed solver. Just make sure to set the `step_size` option to something sensible; for example the smallest gap between times:
242 | ```python
243 | X = LinearInterpolation(coeffs)
244 | step_size = (X.grid_points[1:] - X.grid_points[:-1]).min()
245 | cdeint(
246 | X=X, t=X.interval, func=..., method='rk4',
247 | options=dict(step_size=step_size)
248 | )
249 | ```
250 |
251 | #### Stacking CDEs
252 | You may wish to use the output of one CDE to control another. That is, to solve the coupled CDEs:
253 | ```
254 | du(t) = g(t, u(t))dz(t) u(t_0) = u0
255 | dz(t) = f(t, z(t))dX(t) z(t_0) = z0
256 | ```
257 |
258 | There are two ways to do this. The first way is to put everything inside a single `cdeint` call, by solving the system
259 | ```
260 | v = [u, z]
261 | v0 = [u0, z0]
262 | h(t, v) = [g(t, u)f(t, z), f(t, z)]
263 |
264 | dv(t) = h(t, v(t))dX(t) v(t_0) = v0
265 | ```
266 | and using `cdeint` as normal. This is usually the best way to do it! It's simpler and usually faster. (But forces you to use the same solver for the whole system, for example.)
267 |
268 | The second way is to have `cdeint` output `z(t)` at multiple times `t`, interpolate the discrete output into a continuous path, and then call `cdeint` again. This is probably less memory efficient, but allows for different choices of solver for each call to `cdeint`.
269 |
270 | _For example, this could be used to create multi-layer Neural CDEs, just like multi-layer RNNs. Although as of writing this, no-one seems to have tried this yet!_
271 |
272 | #### The log-ODE method
273 | This is a way of reducing the length of data by using extra channels. (For example, this may help train Neural CDE models faster, as the extra channels can be parallelised, but extra length cannot.)
274 |
275 | This is done by splitting the control `X` up into windows, and computing the _logsignature_ of the control over each window. The logsignature is a transform known to extract the information that is most important to describing how `X` controls a CDE.
276 |
277 | This is supported by the `logsig_windows` function, which takes in data, and produces a transformed path, that now exists in logsignature space:
278 | ```python
279 | batch, length, channels = 1, 100, 2
280 | x = torch.rand(batch, length, channels)
281 | depth, window = 3, 10.0
282 | x = torchcde.logsig_windows(x, depth, window)
283 | # use x as you would normally: interpolate, etc.
284 | ```
285 |
286 | See the paper [Neural Rough Differential Equations for Long Time Series](https://arxiv.org/abs/2009.08295) for more information. See [logsignature_example.py](./example/logsignature_example.py) for a worked example.
287 |
288 | _Note that this requires installing the [Signatory](https://github.com/patrick-kidger/signatory) package._
289 |
--------------------------------------------------------------------------------
/example/irregular_data.py:
--------------------------------------------------------------------------------
1 | ######################
2 | # Processing irregular data is sometimes a little finickity.
3 | # With neural CDEs, it is instead relatively straightforward.
4 | #
5 | # Here we'll look at how you can handle:
6 | # - irregular sampling
7 | # - missing data
8 | # - variable-length sequences
9 | #
10 | # In every case, the only thing that needs changing is the data preprocessing. You won't need to change your model at
11 | # all.
12 | #
13 | # Note that there's little magical going on here -- the way in which we're going to prepare the data is actually
14 | # pretty similar to how we would do so for an RNN etc.
15 | ######################
16 |
17 | import torch
18 | import torchcde
19 |
20 |
21 | ######################
22 | # We begin with a helper for solving a CDE over some data.
23 | ######################
24 |
25 | def _solve_cde(x):
26 | # x should be a tensor of shape (..., length, channels), and may have missing data represented by NaNs.
27 |
28 | # Create dataset
29 | coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)
30 |
31 | # Create model
32 | input_channels = x.size(-1)
33 | hidden_channels = 4 # hyperparameter, we can pick whatever we want for this
34 | output_channels = 10 # e.g. to perform 10-way multiclass classification
35 |
36 | class F(torch.nn.Module):
37 | def __init__(self):
38 | super(F, self).__init__()
39 | # For illustrative purposes only. You should usually use an MLP or something. A single linear layer won't be
40 | # that great.
41 | self.linear = torch.nn.Linear(hidden_channels,
42 | hidden_channels * input_channels)
43 |
44 | def forward(self, t, z):
45 | batch_dims = z.shape[:-1]
46 | return self.linear(z).tanh().view(*batch_dims, hidden_channels, input_channels)
47 |
48 | class Model(torch.nn.Module):
49 | def __init__(self):
50 | super(Model, self).__init__()
51 | self.initial = torch.nn.Linear(input_channels, hidden_channels)
52 | self.func = F()
53 | self.readout = torch.nn.Linear(hidden_channels, output_channels)
54 |
55 | def forward(self, coeffs):
56 | X = torchcde.CubicSpline(coeffs)
57 | X0 = X.evaluate(X.interval[0])
58 | z0 = self.initial(X0)
59 | zt = torchcde.cdeint(X=X, func=self.func, z0=z0, t=X.interval)
60 | zT = zt[..., -1, :] # get the terminal value of the CDE
61 | return self.readout(zT)
62 |
63 | model = Model()
64 |
65 | # Run model
66 | return model(coeffs)
67 |
68 |
69 | ######################
70 | # Okay, now for the meat of it: handling irregular data.
71 | ######################
72 |
73 | def irregular_data():
74 | ######################
75 | # Begin by generating some example data.
76 | ######################
77 |
78 | # Batch of three elements, each of two channels. Each element and channel are sampled at different times, and at a
79 | # different number of times.
80 | t1a, t1b = torch.rand(7).sort().values, torch.rand(5).sort().values
81 | t2a, t2b = torch.rand(9).sort().values, torch.rand(7).sort().values
82 | t3a, t3b = torch.rand(8).sort().values, torch.rand(3).sort().values
83 | x1a, x1b = torch.rand_like(t1a), torch.rand_like(t1b)
84 | x2a, x2b = torch.rand_like(t2a), torch.rand_like(t2b)
85 | x3a, x3b = torch.rand_like(t3a), torch.rand_like(t3b)
86 | # Overall this has irregular sampling, missing data, and variable lengths.
87 |
88 | ######################
89 | # We begin by putting handling each batch element individually. Here we handle the problems of irregular sampling
90 | # and missing data.
91 | ######################
92 |
93 | def process_batch_element(ta, tb, xa, xb):
94 | # First get all the times that the batch element was sampled at, across all channels.
95 | t, sort_indices = torch.cat([ta, tb]).sort()
96 | # Now add NaNs to each channel where the other channel was sampled.
97 | xa_ = torch.cat([xa, torch.full_like(xb, float('nan'))])[sort_indices]
98 | xb_ = torch.cat([torch.full_like(xa, float('nan')), xb])[sort_indices]
99 | # Add observational masks
100 | maska = (~torch.isnan(xa_)).cumsum(dim=0)
101 | maskb = (~torch.isnan(xb_)).cumsum(dim=0)
102 | # Stack (time, observation, mask) together into a tensor of shape (length, channels).
103 | return torch.stack([t, xa_, xb_, maska, maskb], dim=1)
104 |
105 | x1 = process_batch_element(t1a, t1b, x1a, x1b)
106 | x2 = process_batch_element(t2a, t2b, x2a, x2b)
107 | x3 = process_batch_element(t3a, t3b, x3a, x3b)
108 |
109 | # Note that observational masks can of course be omitted if the data is regularly sampled and has no missing data.
110 | # Similarly the observational mask may be only a single channel (rather than on a per-channel basis) if there is
111 | # irregular sampling but no missing data.
112 |
113 | ######################
114 | # Now pad out every shorter sequence by filling the last value forward. The choice of fill-forward here is crucial.
115 | ######################
116 |
117 | max_length = max(x1.size(0), x2.size(0), x3.size(0))
118 |
119 | def fill_forward(x):
120 | return torch.cat([x, x[-1].unsqueeze(0).expand(max_length - x.size(0), x.size(1))])
121 |
122 | x1 = fill_forward(x1)
123 | x2 = fill_forward(x2)
124 | x3 = fill_forward(x3)
125 |
126 | ######################
127 | # Batch everything together
128 | ######################
129 | x = torch.stack([x1, x2, x3])
130 |
131 | ######################
132 | # Solve a Neural CDE: this bit is standard, and just included for completeness.
133 | ######################
134 |
135 | zT = _solve_cde(x)
136 | return zT
137 |
138 | ######################
139 | # Let's recap what's happened here.
140 | ######################
141 |
142 | ######################
143 | # Irregular sampling is easy to solve. We don't have to care that things were sampled at different time points, as
144 | # time is just another channel of the data.
145 | ######################
146 |
147 | ######################
148 | # Missing data is next. We indicated missing values by putting in some NaNs in `x`.
149 | # Then when `hermite_cubic_coefficients_with_backward_differences` is called inside `_solve_cde`, it just did the
150 | # interpolation over the missing values.
151 | ######################
152 |
153 | ######################
154 | # We made sure not to lose any information (due to the interpolation) by adding extra channels corresponding to
155 | # (cumulative) masks for whether a channel has been updated. This means that the the NCDE knows how out-of-date
156 | # (or perhaps "how reliable") its input information is.
157 | #
158 | # This is sometimes called "informative missingness": e.g. the notion that doctors may take more frequest
159 | # measurments of patients they believe to be at risk, so the mere presence of an observation tells you something.
160 | #
161 | # That's not 100% accurate, though. These extra channels should always be included when you have missing data, even
162 | # if the missingness probably isn't important. That's simply so the network knows how out-of-date its input is, and
163 | # thus how much it can trust it.
164 | ######################
165 |
166 | ######################
167 | # We handled variable length data by filling everything forward. That might look a little odd: we solved for the
168 | # _final_ value of the CDE, despite having applied padding to our sequences. Shouldn't we have had to get some of
169 | # the intermediate values as well, to get the final value for each individual batch element?
170 | #
171 | # Not so!
172 | # This is a neat trick: Remember that (in differential equation form), a CDE is given by:
173 | # dz/dt(t) = f(t, z)dX/dt(t)
174 | # So when we chose to use fill-forward to pad in our data, then the data is _constant_ over the padding. That means
175 | # that its derivative, dX/dt, is zero. Once the data stops changing, then the hidden state will stop changing as
176 | # well.
177 | #
178 | # Importantly: we applied padding _after_ doing everything else like appending time. If we did it the other way
179 | # around then e.g. the time channel would still keep changing, and this wouldn't work.
180 | #
181 | # Note that technically speaking, a cubic spline interpolation, being smooth, will still have small perturbations in
182 | # dX/dt: it won't _quite_ be zero. Practically speaking this is unlikely to be an issue, but if you prefer then use
183 | # linear interpolation instead, which will set dX/dt to exactly zero.
184 | ######################
185 |
186 | ######################
187 | # Finally, it's worth remarking that all of this is very similar to handling irregular data with RNNs. There's a
188 | # few differences:
189 | # - Time and observational masks are presented cumulatively, rather than as e.g. delta-time increments.
190 | # - It's fine for there to be NaN values in the data (rather than filling them in with zeros or something), because
191 | # the interpolation routines for torchcde handle that for you.
192 | # - Variable length data can be extracted at the end of the CDE, rather than evaluating it at lots of different
193 | # times. (Incidentally doing so is also more efficient when using the adjoint method, as you only have a single
194 | # backward solve to make, rather than lots of small ones between all the final times.)
195 | ######################
196 |
--------------------------------------------------------------------------------
/example/logsignature_example.py:
--------------------------------------------------------------------------------
1 | ######################
2 | # In this script we code up a Neural CDE using the log-ode method for a long time series thus becoming a Neural RDE.
3 | # This paper describing this methodology can be found at https://arxiv.org/pdf/2009.08295.pdf
4 | # This method assumes familiarity with the standard Neural CDE example at `time_series_classification.py`. We will only
5 | # describe the differences from that example.
6 | ######################
7 | import time
8 | import torch
9 | import torchcde
10 | from time_series_classification import NeuralCDE, get_data
11 |
12 |
13 | def _train(train_X, train_y, test_X, test_y, depth, num_epochs, window_length):
14 | # Time the training process
15 | start_time = time.time()
16 |
17 | # Logsignature computation step
18 | train_logsig = torchcde.logsig_windows(train_X, depth, window_length=window_length)
19 | print("Logsignature shape: {}".format(train_logsig.size()))
20 |
21 | model = NeuralCDE(
22 | input_channels=train_logsig.size(-1), hidden_channels=8, output_channels=1, interpolation="linear"
23 | )
24 | optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
25 |
26 | train_coeffs = torchcde.linear_interpolation_coeffs(train_logsig)
27 |
28 | train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
29 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
30 | for epoch in range(num_epochs):
31 | for batch in train_dataloader:
32 | batch_coeffs, batch_y = batch
33 | pred_y = model(batch_coeffs).squeeze(-1)
34 | loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
35 | loss.backward()
36 | optimizer.step()
37 | optimizer.zero_grad()
38 | print("Epoch: {} Training loss: {}".format(epoch, loss.item()))
39 |
40 | # Remember to compute the logsignatures of the test data too!
41 | test_logsig = torchcde.logsig_windows(test_X, depth, window_length=window_length)
42 | test_coeffs = torchcde.linear_interpolation_coeffs(test_logsig)
43 | pred_y = model(test_coeffs).squeeze(-1)
44 | binary_prediction = (torch.sigmoid(pred_y) > 0.5).to(test_y.dtype)
45 | prediction_matches = (binary_prediction == test_y).to(test_y.dtype)
46 | proportion_correct = prediction_matches.sum() / test_y.size(0)
47 | print("Test Accuracy: {}".format(proportion_correct))
48 |
49 | # Total time
50 | elapsed = time.time() - start_time
51 |
52 | return proportion_correct, elapsed
53 |
54 |
55 | def print_heading(message):
56 | # Print a message inbetween rows of #'s
57 | string_sep = "#" * 50
58 | print("\n" + string_sep + "\n{}\n".format(message) + string_sep)
59 |
60 |
61 | def main(num_epochs=15):
62 | ######################
63 | # Here we load a high frequency version of the spiral data using in `torchcde.example`. Each sample contains 5000
64 | # time points. This is too long to sensibly expect a Neural CDE to handle, training time will be very long and it
65 | # will struggle to remember information from early on in the sequence.
66 | ######################
67 | num_timepoints = 5000
68 | train_X, train_y = get_data(num_timepoints=num_timepoints)
69 | test_X, test_y = get_data(num_timepoints=num_timepoints)
70 |
71 | ######################
72 | # We test the model over logsignature depths [1, 2, 3] with a window length of 50. This reduces the effective
73 | # length of the path to just 100. The only change is an application of `torchcde.logsig_windows`
74 |
75 | # The raw signal has 3 input channels. Taking logsignatures of depths [1, 2, 3] results in a path of logsignatures
76 | # of dimension [3, 6, 14] respectively. We see that higher logsignature depths contain more information about the
77 | # path over the intervals, at a cost of increased numbers of channels.
78 | ######################
79 | depths = [1, 2, 3]
80 | window_length = 50
81 | accuracies = []
82 | training_times = []
83 | for depth in depths:
84 | print_heading('Running for logsignature depth: {}'.format(depth))
85 | acc, elapsed = _train(
86 | train_X, train_y, test_X, test_y, depth, num_epochs, window_length
87 | )
88 | training_times.append(elapsed)
89 | accuracies.append(acc)
90 |
91 | # Finally log the results to the console for a comparison
92 | print_heading("Final results")
93 | for acc, elapsed, depth in zip(accuracies, training_times, depths):
94 | print(
95 | "Depth: {}\n\tAccuracy on test set: {:.1f}%\n\tTime per epoch: {:.1f}s".format(
96 | depth, acc * 100, elapsed / num_epochs
97 | )
98 | )
99 |
100 |
101 | if __name__ == "__main__":
102 | main()
103 |
--------------------------------------------------------------------------------
/example/time_series_classification.py:
--------------------------------------------------------------------------------
1 | ######################
2 | # So you want to train a Neural CDE model?
3 | # Let's get started!
4 | ######################
5 |
6 | import math
7 | import torch
8 | import torchcde
9 |
10 |
11 | ######################
12 | # A CDE model looks like
13 | #
14 | # z_t = z_0 + \int_0^t f_\theta(z_s) dX_s
15 | #
16 | # Where X is your data and f_\theta is a neural network. So the first thing we need to do is define such an f_\theta.
17 | # That's what this CDEFunc class does.
18 | # Here we've built a small single-hidden-layer neural network, whose hidden layer is of width 128.
19 | ######################
20 | class CDEFunc(torch.nn.Module):
21 | def __init__(self, input_channels, hidden_channels):
22 | ######################
23 | # input_channels is the number of input channels in the data X. (Determined by the data.)
24 | # hidden_channels is the number of channels for z_t. (Determined by you!)
25 | ######################
26 | super(CDEFunc, self).__init__()
27 | self.input_channels = input_channels
28 | self.hidden_channels = hidden_channels
29 |
30 | self.linear1 = torch.nn.Linear(hidden_channels, 128)
31 | self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)
32 |
33 | ######################
34 | # For most purposes the t argument can probably be ignored; unless you want your CDE to behave differently at
35 | # different times, which would be unusual. But it's there if you need it!
36 | ######################
37 | def forward(self, t, z):
38 | # z has shape (batch, hidden_channels)
39 | z = self.linear1(z)
40 | z = z.relu()
41 | z = self.linear2(z)
42 | ######################
43 | # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
44 | ######################
45 | z = z.tanh()
46 | ######################
47 | # Ignoring the batch dimension, the shape of the output tensor must be a matrix,
48 | # because we need it to represent a linear map from R^input_channels to R^hidden_channels.
49 | ######################
50 | z = z.view(z.size(0), self.hidden_channels, self.input_channels)
51 | return z
52 |
53 |
54 | ######################
55 | # Next, we need to package CDEFunc up into a model that computes the integral.
56 | ######################
57 | class NeuralCDE(torch.nn.Module):
58 | def __init__(self, input_channels, hidden_channels, output_channels, interpolation="cubic"):
59 | super(NeuralCDE, self).__init__()
60 |
61 | self.func = CDEFunc(input_channels, hidden_channels)
62 | self.initial = torch.nn.Linear(input_channels, hidden_channels)
63 | self.readout = torch.nn.Linear(hidden_channels, output_channels)
64 | self.interpolation = interpolation
65 |
66 | def forward(self, coeffs):
67 | if self.interpolation == 'cubic':
68 | X = torchcde.CubicSpline(coeffs)
69 | elif self.interpolation == 'linear':
70 | X = torchcde.LinearInterpolation(coeffs)
71 | else:
72 | raise ValueError("Only 'linear' and 'cubic' interpolation methods are implemented.")
73 |
74 | ######################
75 | # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
76 | ######################
77 | X0 = X.evaluate(X.interval[0])
78 | z0 = self.initial(X0)
79 |
80 | ######################
81 | # Actually solve the CDE.
82 | ######################
83 | z_T = torchcde.cdeint(X=X,
84 | z0=z0,
85 | func=self.func,
86 | t=X.interval)
87 |
88 | ######################
89 | # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
90 | # and then apply a linear map.
91 | ######################
92 | z_T = z_T[:, 1]
93 | pred_y = self.readout(z_T)
94 | return pred_y
95 |
96 |
97 | ######################
98 | # Now we need some data.
99 | # Here we have a simple example which generates some spirals, some going clockwise, some going anticlockwise.
100 | ######################
101 | def get_data(num_timepoints=100):
102 | t = torch.linspace(0., 4 * math.pi, num_timepoints)
103 |
104 | start = torch.rand(128) * 2 * math.pi
105 | x_pos = torch.cos(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
106 | x_pos[:64] *= -1
107 | y_pos = torch.sin(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
108 | x_pos += 0.01 * torch.randn_like(x_pos)
109 | y_pos += 0.01 * torch.randn_like(y_pos)
110 | ######################
111 | # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
112 | # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
113 | ######################
114 | X = torch.stack([t.unsqueeze(0).repeat(128, 1), x_pos, y_pos], dim=2)
115 | y = torch.zeros(128)
116 | y[:64] = 1
117 |
118 | perm = torch.randperm(128)
119 | X = X[perm]
120 | y = y[perm]
121 |
122 | ######################
123 | # X is a tensor of observations, of shape (batch=128, sequence=100, channels=3)
124 | # y is a tensor of labels, of shape (batch=128,), either 0 or 1 corresponding to anticlockwise or clockwise
125 | # respectively.
126 | ######################
127 | return X, y
128 |
129 |
130 | def main(num_epochs=30):
131 | train_X, train_y = get_data()
132 |
133 | ######################
134 | # input_channels=3 because we have both the horizontal and vertical position of a point in the spiral, and time.
135 | # hidden_channels=8 is the number of hidden channels for the evolving z_t, which we get to choose.
136 | # output_channels=1 because we're doing binary classification.
137 | ######################
138 | model = NeuralCDE(input_channels=3, hidden_channels=8, output_channels=1)
139 | optimizer = torch.optim.Adam(model.parameters())
140 |
141 | ######################
142 | # Now we turn our dataset into a continuous path. We do this here via Hermite cubic spline interpolation.
143 | # The resulting `train_coeffs` is a tensor describing the path.
144 | # For most problems, it's probably easiest to save this tensor and treat it as the dataset.
145 | ######################
146 | train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(train_X)
147 |
148 | train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
149 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
150 | for epoch in range(num_epochs):
151 | for batch in train_dataloader:
152 | batch_coeffs, batch_y = batch
153 | pred_y = model(batch_coeffs).squeeze(-1)
154 | loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
155 | loss.backward()
156 | optimizer.step()
157 | optimizer.zero_grad()
158 | print('Epoch: {} Training loss: {}'.format(epoch, loss.item()))
159 |
160 | test_X, test_y = get_data()
161 | test_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(test_X)
162 | pred_y = model(test_coeffs).squeeze(-1)
163 | binary_prediction = (torch.sigmoid(pred_y) > 0.5).to(test_y.dtype)
164 | prediction_matches = (binary_prediction == test_y).to(test_y.dtype)
165 | proportion_correct = prediction_matches.sum() / test_y.size(0)
166 | print('Test Accuracy: {}'.format(proportion_correct))
167 |
168 |
169 | if __name__ == '__main__':
170 | main()
171 |
--------------------------------------------------------------------------------
/imgs/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrick-kidger/torchcde/9ff6aba4738989dc5fe3aee86d45812c318f6231/imgs/main.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import re
4 | import setuptools
5 |
6 | here = os.path.realpath(os.path.dirname(__file__))
7 |
8 |
9 | name = 'torchcde'
10 |
11 | # for simplicity we actually store the version in the __version__ attribute in the source
12 | with io.open(os.path.join(here, name, '__init__.py')) as f:
13 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
14 | if meta_match:
15 | version = meta_match.group(1)
16 | else:
17 | raise RuntimeError("Unable to find __version__ string.")
18 |
19 | author = 'Patrick Kidger'
20 |
21 | author_email = 'contact@kidger.site'
22 |
23 | description = "Differentiable controlled differential equation solvers for PyTorch with GPU support and " \
24 | "memory-efficient adjoint backpropagation."
25 |
26 | with io.open(os.path.join(here, 'README.md'), 'r', encoding='utf-8') as f:
27 | readme = f.read()
28 |
29 | url = "https://github.com/patrick-kidger/torchcde"
30 |
31 | license = "Apache-2.0"
32 |
33 | classifiers = ["Development Status :: 4 - Beta",
34 | "Intended Audience :: Developers",
35 | "Intended Audience :: Financial and Insurance Industry",
36 | "Intended Audience :: Information Technology",
37 | "Intended Audience :: Science/Research",
38 | "License :: OSI Approved :: Apache Software License",
39 | "Natural Language :: English",
40 | "Operating System :: MacOS :: MacOS X",
41 | "Operating System :: Microsoft :: Windows",
42 | "Operating System :: Unix",
43 | "Programming Language :: Python :: 3",
44 | "Programming Language :: Python :: Implementation :: CPython",
45 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
46 | "Topic :: Scientific/Engineering :: Information Analysis",
47 | "Topic :: Scientific/Engineering :: Mathematics"]
48 |
49 | python_requires = "~=3.6"
50 |
51 | install_requires = ['torch>=1.7.0', 'torchdiffeq>=0.2.0', 'torchsde>=0.2.5']
52 |
53 | setuptools.setup(name=name,
54 | version=version,
55 | author=author,
56 | author_email=author_email,
57 | maintainer=author,
58 | maintainer_email=author_email,
59 | description=description,
60 | long_description=readme,
61 | long_description_content_type="text/markdown",
62 | url=url,
63 | license=license,
64 | classifiers=classifiers,
65 | zip_safe=False,
66 | python_requires=python_requires,
67 | install_requires=install_requires,
68 | packages=[name])
69 |
--------------------------------------------------------------------------------
/test/markers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sys
3 |
4 |
5 | uses_signatory = pytest.mark.skipif(sys.platform == "darwin", reason="Signatory does not support MacOS")
6 |
--------------------------------------------------------------------------------
/test/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | addopts = --tb=long -ra --durations=0
--------------------------------------------------------------------------------
/test/test_cdeint.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torchcde
4 |
5 |
6 | @pytest.mark.parametrize("backend, method, kwargs", (('torchdiffeq', 'rk4', {"options": {"step_size": 1.0}}),
7 | ('torchdiffeq', 'dopri5', {}),
8 | ('torchsde', 'midpoint', {"dt": 1.0})))
9 | def test_shape(backend, method, kwargs):
10 | for _ in range(5):
11 | num_points = torch.randint(low=5, high=100, size=(1,)).item()
12 | num_channels = torch.randint(low=1, high=3, size=(1,)).item()
13 | num_hidden_channels = torch.randint(low=1, high=5, size=(1,)).item()
14 | if backend == "torchdiffeq":
15 | num_batch_dims = torch.randint(low=0, high=3, size=(1,)).item()
16 | batch_dims = []
17 | for _ in range(num_batch_dims):
18 | batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
19 | elif backend == "torchsde":
20 | num_batch_dims = 1
21 | batch_dims = [torch.randint(low=1, high=3, size=(1,)).item()]
22 | else:
23 | raise ValueError
24 |
25 | values = torch.rand(*batch_dims, num_points, num_channels)
26 |
27 | coeffs = torchcde.natural_cubic_coeffs(values)
28 | spline = torchcde.CubicSpline(coeffs)
29 |
30 | class _Func(torch.nn.Module):
31 | def __init__(self):
32 | super(_Func, self).__init__()
33 | self.variable = torch.nn.Parameter(torch.rand(*[1 for _ in range(num_batch_dims)], 1, num_channels))
34 |
35 | def forward(self, t, z):
36 | return z.sigmoid().unsqueeze(-1) + self.variable
37 |
38 | f = _Func()
39 | z0 = torch.rand(*batch_dims, num_hidden_channels)
40 |
41 | num_out_times = torch.randint(low=2, high=10, size=(1,)).item()
42 | start, end = spline.interval
43 | out_times = torch.rand(num_out_times, dtype=torch.float64).sort().values * (end - start) + start
44 |
45 | out = torchcde.cdeint(spline, f, z0, out_times, backend=backend, method=method, rtol=1e-1, atol=1e-1, **kwargs)
46 | assert out.shape == (*batch_dims, num_out_times, num_hidden_channels)
47 |
48 |
49 | def test_backend():
50 | x = torch.randn(1, 10, 2)
51 | coeffs = torchcde.natural_cubic_coeffs(x)
52 | X = torchcde.CubicSpline(coeffs)
53 |
54 | def func(t, z):
55 | return -z.unsqueeze(-1).expand(1, 3, 2)
56 |
57 | z0 = torch.randn(1, 3)
58 |
59 | torchdiffeq_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchdiffeq", method="midpoint",
60 | options=dict(step_size=1.0))
61 | torchsde_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchsde", method="midpoint", dt=1.0)
62 |
63 | torch.testing.assert_allclose(torchdiffeq_out, torchsde_out)
64 |
65 |
66 | def test_tuple_input():
67 | xa = torch.rand(2, 10, 2)
68 | xb = torch.rand(10, 1)
69 |
70 | coeffs_a = torchcde.natural_cubic_coeffs(xa)
71 | coeffs_b = torchcde.natural_cubic_coeffs(xb)
72 | spline_a = torchcde.CubicSpline(coeffs_a)
73 | spline_b = torchcde.CubicSpline(coeffs_b)
74 | X = torchcde.TupleControl(spline_a, spline_b)
75 |
76 | def func(t, z):
77 | za, zb = z
78 | return za.sigmoid().unsqueeze(-1).repeat_interleave(2, dim=-1), zb.tanh().unsqueeze(-1)
79 |
80 | z0 = torch.rand(2, 3), torch.rand(5, requires_grad=True)
81 | out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, adjoint_params=())
82 | out[0].sum().backward()
83 | assert (z0[1].grad == 0).all()
84 |
85 |
86 | def test_prod():
87 | x = torch.rand(2, 5, 1)
88 | X = torchcde.CubicSpline(torchcde.natural_cubic_coeffs(x))
89 |
90 | class F:
91 | def prod(self, t, z, dXdt):
92 | assert t.shape == ()
93 | assert z.shape == (2, 3)
94 | assert dXdt.shape == (2, 1)
95 | return -z * dXdt
96 |
97 | z0 = torch.rand(2, 3, requires_grad=True)
98 | out = torchcde.cdeint(X=X, func=F(), z0=z0, t=X.interval, adjoint_params=())
99 | out.sum().backward()
100 |
--------------------------------------------------------------------------------
/test/test_example.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import sys
3 |
4 | _here = pathlib.Path(__file__).resolve().parent
5 | sys.path.append(str(_here / '../example'))
6 |
7 | import irregular_data # noqa: E402
8 | import logsignature_example # noqa: E402
9 | import time_series_classification # noqa: E402
10 |
11 | import markers # noqa: E402
12 |
13 |
14 | def test_irregular_data():
15 | irregular_data.irregular_data()
16 |
17 |
18 | def test_time_series_classification():
19 | time_series_classification.main(num_epochs=3)
20 |
21 |
22 | @markers.uses_signatory
23 | def test_logsignature_example():
24 | logsignature_example.main(num_epochs=1)
25 |
--------------------------------------------------------------------------------
/test/test_hermite_cubic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchcde import hermite_cubic_coefficients_with_backward_differences, CubicSpline
3 |
4 |
5 | # Represents a random Hermite cubic spline with unit time jumps
6 | class _HermiteUnitTime:
7 | def __init__(self, data):
8 | x_next = data[..., 1:, :]
9 | x_prev = data[..., :-1, :]
10 | derivs_next = x_next - x_prev
11 | derivs_prev = torch.cat([derivs_next[..., [0], :], derivs_next[..., :-1, :]], axis=-2)
12 | self._a = x_prev
13 | self._b = derivs_prev
14 | self._two_c = 2 * 2 * (derivs_next - derivs_prev)
15 | self._three_d = -3 * (derivs_next - derivs_prev)
16 |
17 | def evaluate(self, fractional_part, index):
18 | fractional_part = fractional_part.unsqueeze(-1)
19 | inner = 0.5 * self._two_c[..., index, :] + self._three_d[..., index, :] * fractional_part / 3
20 | inner = self._b[..., index, :] + inner * fractional_part
21 | return self._a[..., index, :] + inner * fractional_part
22 |
23 |
24 | def test_hermite_cubic_unit_time():
25 | for num_channels in (1, 3, 6):
26 | for batch_dims in ((1,), (2, 3)):
27 | for length in (2, 5, 10):
28 | data = torch.randn(*batch_dims, length, num_channels, dtype=torch.float64)
29 | # Hermite
30 | hermite_coeffs = hermite_cubic_coefficients_with_backward_differences(data)
31 | spline = CubicSpline(hermite_coeffs)
32 | # Hermite with unit time
33 | hermite_cubic_unit = _HermiteUnitTime(data)
34 | # Test close
35 | times = torch.linspace(0, length, 10)
36 | for time in times:
37 | fractional_part, index = spline._interpret_t(time)
38 | assert torch.allclose(spline.evaluate(time), hermite_cubic_unit.evaluate(fractional_part, index))
39 |
--------------------------------------------------------------------------------
/test/test_linear_interpolation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchcde
3 | import pytest
4 |
5 |
6 | def test_random():
7 | def _points():
8 | yield 2
9 | yield 3
10 | yield 100
11 | for _ in range(10):
12 | yield torch.randint(low=2, high=100, size=(1,)).item()
13 |
14 | for drop in (False, True):
15 | for use_t in (False, True):
16 | for num_points in _points():
17 | if use_t:
18 | start = torch.rand(1).item() * 10 - 5
19 | end = torch.rand(1).item() * 10 - 5
20 | start, end = min(start, end), max(start, end)
21 | t = torch.linspace(start, end, num_points, dtype=torch.float64)
22 | t_ = t
23 | else:
24 | t = torch.linspace(0, num_points - 1, num_points, dtype=torch.float64)
25 | t_ = None
26 | num_channels = torch.randint(low=1, high=5, size=(1,)).item()
27 | m = torch.rand(num_channels, dtype=torch.float64) * 10 - 5
28 | c = torch.rand(num_channels, dtype=torch.float64) * 10 - 5
29 | values = m * t.unsqueeze(-1) + c
30 |
31 | values_clone = values.clone()
32 | if drop:
33 | for values_slice in values_clone.unbind(dim=-1):
34 | num_drop = int(num_points * torch.randint(low=1, high=4, size=(1,)).item() / 10)
35 | num_drop = min(num_drop, num_points - 4)
36 | to_drop = torch.randperm(num_points - 2)[:num_drop] + 1 # don't drop first or last
37 | values_slice[to_drop] = float('nan')
38 |
39 | coeffs = torchcde.linear_interpolation_coeffs(values_clone, t=t_)
40 | linear = torchcde.LinearInterpolation(coeffs, t=t_)
41 |
42 | for time, value in zip(t, values):
43 | linear_evaluate = linear.evaluate(time)
44 | assert value.shape == linear_evaluate.shape
45 | assert value.allclose(linear_evaluate, rtol=1e-4, atol=1e-6)
46 | linear_derivative = linear.derivative(time)
47 | assert m.shape == linear_derivative.shape
48 | assert m.allclose(linear_derivative, rtol=1e-4, atol=1e-6)
49 |
50 |
51 | def test_small():
52 | for use_t in (False, True):
53 | if use_t:
54 | start = torch.rand(1).item() * 10 - 5
55 | end = torch.rand(1).item() * 10 - 5
56 | start, end = min(start, end), max(start, end)
57 | t = torch.tensor([start, end], dtype=torch.float64)
58 | t_ = t
59 | else:
60 | start = 0
61 | end = 1
62 | t = torch.tensor([0., 1.], dtype=torch.float64)
63 | t_ = None
64 | x = torch.rand(2, 1, dtype=torch.float64)
65 | true_deriv = (x[1] - x[0]) / (end - start)
66 | coeffs = torchcde.linear_interpolation_coeffs(x, t=t_)
67 | linear = torchcde.LinearInterpolation(coeffs, t=t_)
68 | for time in torch.linspace(-1, 2, 100):
69 | true = x[0] + true_deriv * (time - t[0])
70 | pred = linear.evaluate(time)
71 | deriv = linear.derivative(time)
72 | assert true_deriv.shape == deriv.shape
73 | assert true_deriv.allclose(deriv)
74 | assert true.shape == pred.shape
75 | assert true.allclose(pred)
76 |
77 |
78 | def test_specification_and_derivative():
79 | for use_t in (False, True):
80 | for _ in range(10):
81 | for num_batch_dims in (0, 1, 2, 3):
82 | batch_dims = []
83 | for _ in range(num_batch_dims):
84 | batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
85 | length = torch.randint(low=5, high=10, size=(1,)).item()
86 | channels = torch.randint(low=1, high=5, size=(1,)).item()
87 | if use_t:
88 | t = torch.linspace(0, 1, length, dtype=torch.float64)
89 | t_ = t
90 | else:
91 | t = torch.linspace(0, length - 1, length, dtype=torch.float64)
92 | t_ = None
93 | x = torch.rand(*batch_dims, length, channels, dtype=torch.float64)
94 | coeffs = torchcde.linear_interpolation_coeffs(x, t=t_)
95 | spline = torchcde.LinearInterpolation(coeffs, t=t_)
96 | # Test specification
97 | for i, point in enumerate(t):
98 | evaluate = spline.evaluate(point)
99 | xi = x[..., i, :]
100 | assert evaluate.allclose(xi, atol=1e-5, rtol=1e-5)
101 | # Test derivative
102 | for point in torch.rand(100, dtype=torch.float64):
103 | point_with_grad = point.detach().requires_grad_(True)
104 | evaluate = spline.evaluate(point_with_grad)
105 | derivative = spline.derivative(point)
106 | autoderivative = []
107 | for elem in evaluate.view(-1):
108 | elem.backward(retain_graph=True)
109 | with torch.no_grad():
110 | autoderivative.append(point_with_grad.grad.clone())
111 | point_with_grad.grad.zero_()
112 | autoderivative = torch.stack(autoderivative).view(*evaluate.shape)
113 | assert derivative.shape == autoderivative.shape
114 | assert derivative.allclose(autoderivative, atol=1e-5, rtol=1e-5)
115 |
116 |
117 | def test_rectilinear_preparation():
118 | devices = ['cpu']
119 | if torch.cuda.is_available():
120 | devices.append('cuda')
121 |
122 | for device in devices:
123 | # Simple test
124 | nan = float('nan')
125 | t1 = torch.tensor([0.1, 0.2, 0.9]).view(-1, 1).to(device)
126 | t2 = torch.tensor([0.2, 0.3]).view(-1, 1).to(device)
127 | x1 = torch.tensor([0.4, nan, 1.1]).view(-1, 1).to(device)
128 | x2 = torch.tensor([nan, 2.]).view(-1, 1).to(device)
129 | x = torch.nn.utils.rnn.pad_sequence(
130 | [torch.cat((t1, x1), -1), torch.cat((t2, x2), -1)], batch_first=True, padding_value=nan
131 | )
132 | # We have to fill the time index forward because we currently dont allow nan times for rectilinear
133 | x[:, :, 0] = torchcde.misc.forward_fill(x[:, :, 0], fill_index=-1)
134 | # Build true solution
135 | x1_true = torch.tensor([[0.1, 0.2, 0.2, 0.9, 0.9], [0.4, 0.4, 0.4, 0.4, 1.1]]).T.view(-1, 2).to(device)
136 | x2_true = torch.tensor([[0.2, 0.3, 0.3, 0.3, 0.3], [2., 2., 2., 2., 2.]]).T.view(-1, 2).to(device)
137 | rect_true = torch.stack((x1_true, x2_true))
138 | # Apply rectilinear and compare
139 | rectilinear = torchcde.linear_interpolation_coeffs(x, rectilinear=0)
140 | assert torch.equal(rect_true[~torch.isnan(rect_true)], rectilinear[~torch.isnan(rectilinear)])
141 | # Test also if we swap time time dimension
142 | x_swap = x[:, :, [1, 0]]
143 | rectilinear_swap = torchcde.linear_interpolation_coeffs(x_swap, rectilinear=1)
144 | rect_swp = rect_true[:, :, [1, 0]]
145 | assert torch.equal(rect_swp, rectilinear_swap)
146 |
147 | # Additionally try a 2d case
148 | assert torch.equal(rect_true[0], torchcde.linear_interpolation_coeffs(x[0], rectilinear=0))
149 | # And a 4d case
150 | x_4d = torch.stack([x, x])
151 | rect_true_4d = torch.stack([rect_true, rect_true])
152 | assert torch.equal(rect_true_4d, torchcde.linear_interpolation_coeffs(x_4d, rectilinear=0))
153 |
154 | # Ensure error is thrown if time has a nan value anywhere
155 | x_time_nan = x.clone()
156 | x_time_nan[0, 1, 0] = float('nan')
157 | pytest.raises(AssertionError, torchcde.linear_interpolation_coeffs, x_time_nan, rectilinear=0)
158 |
159 | # Some randoms tests
160 | for _ in range(5):
161 | # Build some data with time
162 | t_starts = torch.randn(5).to(device) ** 2
163 | ts = [torch.linspace(s, s + 10, torch.randint(2, 50, (1,)).item()).to(device) for s in t_starts]
164 | xs = [torch.randn(len(t), 10 - 1).to(device) for t in ts]
165 | x = torch.nn.utils.rnn.pad_sequence(
166 | [torch.cat([t_.view(-1, 1), x_], dim=1) for t_, x_ in zip(ts, xs)], batch_first=True, padding_value=nan
167 | )
168 | # Add some random nans about the place
169 | mask = torch.randint(0, 5, (x.size(0), x.size(1), x.size(2) - 1), dtype=torch.float).to(device)
170 | mask[mask == 0] = float('nan')
171 | x[:, :, 1:] = x[:, :, 1:] * mask
172 | # We have to fill the time index forward because we currently dont allow nan times for rectilinear
173 | x[:, :, 0] = torchcde.misc.forward_fill(x[:, :, 0], fill_index=-1)
174 | # Fill
175 | x_ffilled = torchcde.misc.forward_fill(x)
176 | # Compute the true solution
177 | N, L, C = x_ffilled.shape
178 | rect_true = torch.zeros(N, 2 * L - 1, C).to(device)
179 | lag = torch.cat([x_ffilled[:, 1:, [0]], x_ffilled[:, :-1, 1:]], dim=-1)
180 | rect_true[:, ::2, ] = x_ffilled
181 | rect_true[:, 1::2] = lag
182 | # Need to backfill rect true
183 | # Rectilinear solution
184 | rectilinear = torchcde.linear_interpolation_coeffs(x, rectilinear=0)
185 | assert torch.equal(rect_true[~torch.isnan(rect_true)], rectilinear[~torch.isnan(rect_true)])
186 |
--------------------------------------------------------------------------------
/test/test_log_ode.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchcde
3 |
4 | import markers
5 |
6 |
7 | @markers.uses_signatory
8 | def test_with_linear_interpolation():
9 | import signatory
10 | window_length = 4
11 | for depth in (1, 2, 3, 4):
12 | compute_logsignature = signatory.Logsignature(depth)
13 | for pieces in (1, 2, 3, 5, 10):
14 | num_channels = torch.randint(low=1, high=4, size=(1,)).item()
15 | x_ = [torch.randn(1, num_channels, dtype=torch.float64)]
16 | logsignatures = []
17 | for _ in range(pieces):
18 | x = torch.randn(window_length, num_channels, dtype=torch.float64)
19 | logsignature = compute_logsignature(torch.cat([x_[-1][-1:], x]).unsqueeze(0))
20 | x_.append(x)
21 | logsignatures.append(logsignature)
22 |
23 | x = torch.cat(x_)
24 |
25 | logsig_x = torchcde.logsig_windows(x, depth, window_length)
26 | coeffs = torchcde.linear_interpolation_coeffs(logsig_x)
27 | X = torchcde.LinearInterpolation(coeffs)
28 |
29 | point = 0.5
30 | for logsignature in logsignatures:
31 | interp_logsignature = X.derivative(point)
32 | assert interp_logsignature.allclose(logsignature)
33 | point += 1
34 |
--------------------------------------------------------------------------------
/test/test_misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchcde.misc # testing an implementation detail
3 |
4 |
5 | def test_cheap_stack():
6 | for num in range(1, 4):
7 | for dim in (-2, -1, 0, 1):
8 | xs = [torch.rand(1, 1) for _ in range(num)]
9 | s = torchcde.misc.cheap_stack(xs, dim)
10 | s2 = torch.stack(xs, dim)
11 | assert s.shape == s2.shape
12 | assert (s == s2).all()
13 |
14 |
15 | def test_tridiagonal_solve():
16 | for _ in range(5):
17 | size = torch.randint(low=2, high=10, size=(1,)).item()
18 | diagonal = torch.randn(size, dtype=torch.float64)
19 | upper = torch.randn(size - 1, dtype=torch.float64)
20 | lower = torch.randn(size - 1, dtype=torch.float64)
21 | A = torch.zeros(size, size, dtype=torch.float64)
22 | A[range(size), range(size)] = diagonal
23 | A[range(1, size), range(size - 1)] = lower
24 | A[range(size - 1), range(1, size)] = upper
25 | b = torch.randn(size, dtype=torch.float64)
26 | x = torchcde.misc.tridiagonal_solve(b, upper, diagonal, lower)
27 | mul = A @ x
28 | assert mul.allclose(b)
29 |
30 |
31 | def test_forward_fill():
32 | devices = ['cpu']
33 | if torch.cuda.is_available():
34 | devices.append('cuda')
35 |
36 | for device in devices:
37 | # Check ffill
38 | for N, L, C in [(1, 5, 3), (2, 2, 2), (3, 2, 1)]:
39 | x = torch.randn(N, L, C).to(device)
40 | # Drop mask
41 | tensor_num = x.numel()
42 | mask = torch.randperm(tensor_num)[:int(0.3 * tensor_num)].to(device)
43 | x.view(-1)[mask] = float('nan')
44 | x_ffilled = x.clone().float()
45 | for i in range(0, x.size(0)):
46 | for j in range(x.size(1)):
47 | for k in range(x.size(2)):
48 | non_nan = x_ffilled[i, :j + 1, k][~torch.isnan(x[i, :j + 1, k])]
49 | input_val = non_nan[-1].item() if len(non_nan) > 0 else float('nan')
50 | x_ffilled[i, j, k] = input_val
51 | x_ffilled_actual = torchcde.misc.forward_fill(x)
52 | assert x_ffilled.allclose(x_ffilled_actual, equal_nan=True)
53 |
--------------------------------------------------------------------------------
/test/test_natural_cubic_spline.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torchcde
4 |
5 |
6 | # Represents a random natural cubic spline with a single knot in the middle
7 | class _Cubic:
8 | def __init__(self, batch_dims, num_channels, start, end):
9 | self.a = torch.randn(*batch_dims, num_channels, dtype=torch.float64) * 10
10 | self.b = torch.randn(*batch_dims, num_channels, dtype=torch.float64) * 10
11 | self.c = torch.randn(*batch_dims, num_channels, dtype=torch.float64) * 10
12 | self.d1 = -self.c / (3 * start)
13 | self.d2 = -self.c / (3 * end)
14 |
15 | def _normalise_dims(self, t):
16 | a = self.a
17 | b = self.b
18 | c = self.c
19 | d1 = self.d1
20 | d2 = self.d2
21 | for _ in t.shape:
22 | a = a.unsqueeze(-2)
23 | b = b.unsqueeze(-2)
24 | c = c.unsqueeze(-2)
25 | d1 = d1.unsqueeze(-2)
26 | d2 = d2.unsqueeze(-2)
27 | t = t.unsqueeze(-1)
28 | d = torch.where(t >= 0, d2, d1)
29 | return a, b, c, d, t
30 |
31 | def evaluate(self, t):
32 | a, b, c, d, t = self._normalise_dims(t)
33 | t_sq = t ** 2
34 | t_cu = t_sq * t
35 | return a + b * t + c * t_sq + d * t_cu
36 |
37 | def derivative(self, t):
38 | a, b, c, d, t = self._normalise_dims(t)
39 | t_sq = t ** 2
40 | return b + 2 * c * t + 3 * d * t_sq
41 |
42 |
43 | class _Offset:
44 | def __init__(self, batch_dims, num_channels, start, end, offset):
45 | self.cubic = _Cubic(batch_dims, num_channels, start - offset, end - offset)
46 | self.offset = offset
47 |
48 | def evaluate(self, t):
49 | t = t - self.offset
50 | return self.cubic.evaluate(t)
51 |
52 | def derivative(self, t):
53 | t = t - self.offset
54 | return self.cubic.derivative(t)
55 |
56 |
57 | @pytest.mark.skip(reason="Test is flaky. Not obvious whether the problem is in the test or in the natural cubic "
58 | "spline code. As natural cubic splines are being phased out in favour of alternatives "
59 | "anyway, I'm just marking this test as a skip.")
60 | def test_interp():
61 | for interp_fn in (torchcde.natural_cubic_coeffs, torchcde.natural_cubic_spline_coeffs):
62 | for _ in range(3):
63 | for use_t in (True, False):
64 | for drop in (False, True):
65 | num_points = torch.randint(low=5, high=100, size=(1,)).item()
66 | half_num_points = num_points // 2
67 | num_points = 2 * half_num_points + 1
68 | if use_t:
69 | times1 = torch.rand(half_num_points, dtype=torch.float64) - 1
70 | times2 = torch.rand(half_num_points, dtype=torch.float64)
71 | t = torch.cat([times1, times2, torch.tensor([0.], dtype=torch.float64)]).sort().values
72 | t_ = t
73 | start, end = -1.5, 1.5
74 | del times1, times2
75 | else:
76 | t = torch.linspace(0, num_points - 1, num_points, dtype=torch.float64)
77 | t_ = None
78 | start = 0
79 | end = num_points - 0.5
80 | num_channels = torch.randint(low=1, high=3, size=(1,)).item()
81 | num_batch_dims = torch.randint(low=0, high=3, size=(1,)).item()
82 | batch_dims = []
83 | for _ in range(num_batch_dims):
84 | batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
85 | if use_t:
86 | cubic = _Cubic(batch_dims, num_channels, start=t[0], end=t[-1])
87 | knot = 0
88 | else:
89 | cubic = _Offset(batch_dims, num_channels, start=t[0], end=t[-1], offset=t[1] - t[0])
90 | knot = t[1] - t[0]
91 | values = cubic.evaluate(t)
92 | if drop:
93 | for values_slice in values.unbind(dim=-1):
94 | num_drop = int(num_points * torch.randint(low=1, high=4, size=(1,)).item() / 10)
95 | num_drop = min(num_drop, num_points - 4)
96 | # don't drop first or last
97 | to_drop = torch.randperm(num_points - 2)[:num_drop] + 1
98 | to_drop = [x for x in to_drop if x != knot]
99 | values_slice[..., to_drop] = float('nan')
100 | del num_drop, to_drop, values_slice
101 | coeffs = interp_fn(values, t_)
102 | spline = torchcde.CubicSpline(coeffs, t_)
103 | _test_equal(batch_dims, num_channels, cubic, spline, torch.float64, start, end, 1e-3)
104 |
105 |
106 | def test_linear():
107 | for interp_fn in (torchcde.natural_cubic_coeffs, torchcde.natural_cubic_spline_coeffs):
108 | for use_t in (False, True):
109 | start = torch.rand(1).item() * 5 - 2.5
110 | end = torch.rand(1).item() * 5 - 2.5
111 | start, end = min(start, end), max(start, end)
112 | num_points = torch.randint(low=2, high=10, size=(1,)).item()
113 | num_channels = torch.randint(low=1, high=4, size=(1,)).item()
114 | m = torch.rand(num_channels, dtype=torch.float64) * 5 - 2.5
115 | c = torch.rand(num_channels, dtype=torch.float64) * 5 - 2.5
116 | if use_t:
117 | t = torch.linspace(start, end, num_points, dtype=torch.float64)
118 | t_ = t
119 | else:
120 | t = torch.linspace(0, num_points - 1, num_points, dtype=torch.float64)
121 | t_ = None
122 | values = m * t.unsqueeze(-1) + c
123 | coeffs = interp_fn(values, t_)
124 | spline = torchcde.CubicSpline(coeffs, t_)
125 | coeffs2 = torchcde.linear_interpolation_coeffs(values, t_)
126 | linear = torchcde.LinearInterpolation(coeffs2, t_)
127 | batch_dims = []
128 | _test_equal(batch_dims, num_channels, linear, spline, torch.float32, -1.5, 1.5, 1e-4)
129 |
130 |
131 | def test_short():
132 | for interp_fn in (torchcde.natural_cubic_coeffs, torchcde.natural_cubic_spline_coeffs):
133 | for use_t in (False, True):
134 | if use_t:
135 | t = torch.tensor([0., 1.])
136 | else:
137 | t = None
138 | values = torch.rand(2, 1)
139 | coeffs = interp_fn(values, t)
140 | spline = torchcde.CubicSpline(coeffs, t)
141 | coeffs2 = torchcde.linear_interpolation_coeffs(values, t)
142 | linear = torchcde.LinearInterpolation(coeffs2, t)
143 | batch_dims = []
144 | num_channels = 1
145 | _test_equal(batch_dims, num_channels, linear, spline, torch.float32, -1.5, 1.5, 1e-4)
146 |
147 |
148 | # TODO: test other edge cases
149 |
150 |
151 | def _test_equal(batch_dims, num_channels, obj1, obj2, dtype, start, end, tol):
152 | for dimension in (0, 1, 2):
153 | sizes = []
154 | for _ in range(dimension):
155 | sizes.append(torch.randint(low=1, high=4, size=(1,)).item())
156 | expected_size = tuple(batch_dims) + tuple(sizes) + (num_channels,)
157 | eval_times = torch.rand(sizes, dtype=dtype) * (end - start) + start
158 | obj1_evaluate = obj1.evaluate(eval_times)
159 | obj2_evaluate = obj2.evaluate(eval_times)
160 | obj1_derivative = obj1.derivative(eval_times)
161 | obj2_derivative = obj2.derivative(eval_times)
162 | assert obj1_evaluate.shape == expected_size
163 | assert obj2_evaluate.shape == expected_size
164 | assert obj1_derivative.shape == expected_size
165 | assert obj2_derivative.shape == expected_size
166 | torch.testing.assert_allclose(obj1_evaluate, obj2_evaluate, rtol=tol, atol=tol)
167 | torch.testing.assert_allclose(obj1_derivative, obj2_derivative, rtol=tol, atol=tol)
168 |
169 |
170 | def test_specification_and_derivative():
171 | for interp_fn in (torchcde.natural_cubic_coeffs, torchcde.natural_cubic_spline_coeffs):
172 | for _ in range(10):
173 | for use_t in (False, True):
174 | for num_batch_dims in (0, 1, 2, 3):
175 | batch_dims = []
176 | for _ in range(num_batch_dims):
177 | batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
178 | length = torch.randint(low=5, high=10, size=(1,)).item()
179 | channels = torch.randint(low=1, high=5, size=(1,)).item()
180 | if use_t:
181 | t = torch.linspace(0, 1, length, dtype=torch.float64)
182 | else:
183 | t = torch.linspace(0, length - 1, length, dtype=torch.float64)
184 | x = torch.rand(*batch_dims, length, channels, dtype=torch.float64)
185 | coeffs = interp_fn(x, t)
186 | spline = torchcde.CubicSpline(coeffs, t)
187 | # Test specification
188 | for i, point in enumerate(t):
189 | evaluate = spline.evaluate(point)
190 | xi = x[..., i, :]
191 | assert evaluate.allclose(xi, atol=1e-5, rtol=1e-5)
192 | # Test derivative
193 | for point in torch.rand(100, dtype=torch.float64):
194 | point_with_grad = point.detach().requires_grad_(True)
195 | evaluate = spline.evaluate(point_with_grad)
196 | derivative = spline.derivative(point)
197 | autoderivative = []
198 | for elem in evaluate.view(-1):
199 | elem.backward(retain_graph=True)
200 | with torch.no_grad():
201 | autoderivative.append(point_with_grad.grad.clone())
202 | point_with_grad.grad.zero_()
203 | autoderivative = torch.stack(autoderivative).view(*evaluate.shape)
204 | assert derivative.shape == autoderivative.shape
205 | assert derivative.allclose(autoderivative, atol=1e-5, rtol=1e-5)
206 |
--------------------------------------------------------------------------------
/test/test_tricks.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torchcde
4 |
5 |
6 | class _Func(torch.nn.Module):
7 | def __init__(self, input_size, hidden_size):
8 | super(_Func, self).__init__()
9 | self.input_size = input_size
10 | self.hidden_size = hidden_size
11 | self.variable = torch.nn.Parameter(torch.rand(1, 1, input_size))
12 |
13 | def forward(self, t, z):
14 | assert z.shape == (1, self.hidden_size)
15 | out = z.sigmoid().unsqueeze(-1) + self.variable
16 | assert out.shape == (1, self.hidden_size, self.input_size)
17 | return out
18 |
19 |
20 | # Test that gradients can propagate through the controlling path at all
21 | def test_grad_paths():
22 | for method in ('rk4', 'dopri5'):
23 | for adjoint in (True, False):
24 | t = torch.linspace(0, 9, 10, requires_grad=True)
25 | path = torch.rand(1, 10, 3, requires_grad=True)
26 | coeffs = torchcde.natural_cubic_coeffs(path, t)
27 | cubic_spline = torchcde.CubicSpline(coeffs, t)
28 | z0 = torch.rand(1, 3, requires_grad=True)
29 | func = _Func(input_size=3, hidden_size=3)
30 | t_ = torch.tensor([0., 9.], requires_grad=True)
31 |
32 | if adjoint:
33 | kwargs = dict(adjoint_params=tuple(func.parameters()) + (coeffs, t))
34 | else:
35 | kwargs = {}
36 | z = torchcde.cdeint(X=cubic_spline, func=func, z0=z0, t=t_, adjoint=adjoint, method=method, rtol=1e-4,
37 | atol=1e-6, **kwargs)
38 | assert z.shape == (1, 2, 3)
39 | assert t.grad is None
40 | assert path.grad is None
41 | assert z0.grad is None
42 | assert func.variable.grad is None
43 | assert t_.grad is None
44 | z[:, 1].sum().backward()
45 | assert isinstance(t.grad, torch.Tensor)
46 | assert isinstance(path.grad, torch.Tensor)
47 | assert isinstance(z0.grad, torch.Tensor)
48 | assert isinstance(func.variable.grad, torch.Tensor)
49 | assert isinstance(t_.grad, torch.Tensor)
50 |
51 |
52 | # Test that gradients flow back through multiple CDEs stacked on top of one another, and that they do so correctly
53 | # without going through earlier parts of the graph multiple times.
54 | def test_stacked_paths():
55 | class Record(torch.autograd.Function):
56 | @staticmethod
57 | def forward(ctx, name, x):
58 | ctx.name = name
59 | return x
60 |
61 | @staticmethod
62 | def backward(ctx, x):
63 | if hasattr(ctx, 'been_here_before'):
64 | pytest.fail(ctx.name)
65 | ctx.been_here_before = True
66 | return None, x
67 |
68 | coeff_paths = [(torchcde.linear_interpolation_coeffs, torchcde.LinearInterpolation),
69 | (torchcde.natural_cubic_coeffs, torchcde.CubicSpline)]
70 | for adjoint in (False, True):
71 | for first_coeffs, First in coeff_paths:
72 | for second_coeffs, Second in coeff_paths:
73 | first_path = torch.rand(1, 1000, 2, requires_grad=True)
74 | first_coeff = first_coeffs(first_path)
75 | first_X = First(first_coeff)
76 | first_func = _Func(input_size=2, hidden_size=2)
77 |
78 | second_t = torch.linspace(0, 999, 100)
79 | if adjoint:
80 | kwargs = dict(adjoint_params=tuple(first_func.parameters()) + (first_coeff,))
81 | else:
82 | kwargs = {}
83 | second_path = torchcde.cdeint(X=first_X, func=first_func, z0=torch.rand(1, 2),
84 | t=second_t, adjoint=adjoint, method='rk4', options=dict(step_size=10),
85 | **kwargs)
86 | second_path = Record.apply('second', second_path)
87 | second_coeff = second_coeffs(second_path, second_t)
88 | second_X = Second(second_coeff, second_t)
89 | second_func = _Func(input_size=2, hidden_size=2)
90 |
91 | third_t = torch.linspace(0, 999, 10)
92 | if adjoint:
93 | kwargs = dict(adjoint_params=tuple(second_func.parameters()) + (second_coeff, second_t))
94 | else:
95 | kwargs = {}
96 | third_path = torchcde.cdeint(X=second_X, func=second_func, z0=torch.rand(1, 2),
97 | t=third_t, adjoint=adjoint, method='rk4', options=dict(step_size=10),
98 | **kwargs)
99 | third_path = Record.apply('third', third_path)
100 | assert first_func.variable.grad is None
101 | assert second_func.variable.grad is None
102 | assert first_path.grad is None
103 | third_path[:, -1].sum().backward()
104 | assert isinstance(second_func.variable.grad, torch.Tensor)
105 | assert isinstance(first_func.variable.grad, torch.Tensor)
106 | assert isinstance(first_path.grad, torch.Tensor)
107 |
108 |
109 | # Tests that the trick in which we use detaches in the backward pass if possible, does in fact work.
110 | # It's a bit superfluous to test it here now that we've upstreamed it into torchdiffeq, but oh well.
111 | def test_detach_trick():
112 | path = torch.rand(1, 10, 3)
113 | interp = torchcde.CubicSpline(torchcde.natural_cubic_coeffs(path))
114 |
115 | func = _Func(input_size=3, hidden_size=3)
116 |
117 | for adjoint in (True, False):
118 | variable_grads = []
119 | z0 = torch.rand(1, 3)
120 | for t_grad in (True, False):
121 | t_ = torch.tensor([0., 9.], requires_grad=t_grad)
122 | # Don't test dopri5. We will get different results then, because the t variable will force smaller step
123 | # sizes and thus slightly different results.
124 | z = torchcde.cdeint(X=interp, z0=z0, func=func, t=t_, adjoint=adjoint, method='rk4',
125 | options=dict(step_size=0.5))
126 | z[:, -1].sum().backward()
127 | variable_grads.append(func.variable.grad.clone())
128 | func.variable.grad.zero_()
129 |
130 | for elem in variable_grads[1:]:
131 | assert (elem == variable_grads[0]).all()
132 |
--------------------------------------------------------------------------------
/torchcde/__init__.py:
--------------------------------------------------------------------------------
1 | from .interpolation_base import InterpolationBase
2 | from .interpolation_cubic import natural_cubic_spline_coeffs, natural_cubic_coeffs, CubicSpline
3 | from .interpolation_linear import linear_interpolation_coeffs, LinearInterpolation
4 | from .interpolation_hermite_cubic_bdiff import hermite_cubic_coefficients_with_backward_differences
5 | from .log_ode import logsignature_windows, logsig_windows
6 | from .misc import TupleControl
7 | from .solver import cdeint
8 |
9 | __version__ = "0.2.5"
10 |
--------------------------------------------------------------------------------
/torchcde/interpolation_base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 |
4 |
5 | class InterpolationBase(torch.nn.Module, metaclass=abc.ABCMeta):
6 | @property
7 | @abc.abstractmethod
8 | def grid_points(self):
9 | raise NotImplementedError
10 |
11 | @property
12 | @abc.abstractmethod
13 | def interval(self):
14 | raise NotImplementedError
15 |
16 | @abc.abstractmethod
17 | def evaluate(self, t):
18 | raise NotImplementedError
19 |
20 | @abc.abstractmethod
21 | def derivative(self, t):
22 | raise NotImplementedError
23 |
--------------------------------------------------------------------------------
/torchcde/interpolation_cubic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchcde import interpolation_base
3 |
4 | from . import misc
5 |
6 |
7 | def _natural_cubic_spline_coeffs_without_missing_values(t, x):
8 | # x should be a tensor of shape (..., length)
9 | # Will return the b, two_c, three_d coefficients of the derivative of the cubic spline interpolating the path.
10 |
11 | length = x.size(-1)
12 |
13 | if length < 2:
14 | # In practice this should always already be caught in __init__.
15 | raise ValueError("Must have a time dimension of size at least 2.")
16 | elif length == 2:
17 | a = x[..., :1]
18 | b = (x[..., 1:] - x[..., :1]) / (t[..., 1:] - t[..., :1])
19 | two_c = torch.zeros(*x.shape[:-1], 1, dtype=x.dtype, device=x.device)
20 | three_d = torch.zeros(*x.shape[:-1], 1, dtype=x.dtype, device=x.device)
21 | else:
22 | # Set up some intermediate values
23 | time_diffs = t[1:] - t[:-1]
24 | time_diffs_reciprocal = time_diffs.reciprocal()
25 | time_diffs_reciprocal_squared = time_diffs_reciprocal ** 2
26 | three_path_diffs = 3 * (x[..., 1:] - x[..., :-1])
27 | six_path_diffs = 2 * three_path_diffs
28 | path_diffs_scaled = three_path_diffs * time_diffs_reciprocal_squared
29 |
30 | # Solve a tridiagonal linear system to find the derivatives at the knots
31 | system_diagonal = torch.empty(length, dtype=x.dtype, device=x.device)
32 | system_diagonal[:-1] = time_diffs_reciprocal
33 | system_diagonal[-1] = 0
34 | system_diagonal[1:] += time_diffs_reciprocal
35 | system_diagonal *= 2
36 | system_rhs = torch.empty_like(x)
37 | system_rhs[..., :-1] = path_diffs_scaled
38 | system_rhs[..., -1] = 0
39 | system_rhs[..., 1:] += path_diffs_scaled
40 | knot_derivatives = misc.tridiagonal_solve(system_rhs, time_diffs_reciprocal, system_diagonal,
41 | time_diffs_reciprocal)
42 |
43 | # Do some algebra to find the coefficients of the spline
44 | a = x[..., :-1]
45 | b = knot_derivatives[..., :-1]
46 | two_c = (six_path_diffs * time_diffs_reciprocal
47 | - 4 * knot_derivatives[..., :-1]
48 | - 2 * knot_derivatives[..., 1:]) * time_diffs_reciprocal
49 | three_d = (-six_path_diffs * time_diffs_reciprocal
50 | + 3 * (knot_derivatives[..., :-1]
51 | + knot_derivatives[..., 1:])) * time_diffs_reciprocal_squared
52 |
53 | return a, b, two_c, three_d
54 |
55 |
56 | def _natural_cubic_spline_coeffs_with_missing_values(t, x, _version):
57 | if x.ndimension() == 1:
58 | # We have to break everything down to individual scalar paths because of the possibility of missing values
59 | # being different in different channels
60 | return _natural_cubic_spline_coeffs_with_missing_values_scalar(t, x, _version)
61 | else:
62 | a_pieces = []
63 | b_pieces = []
64 | two_c_pieces = []
65 | three_d_pieces = []
66 | for p in x.unbind(dim=0): # TODO: parallelise over this
67 | a, b, two_c, three_d = _natural_cubic_spline_coeffs_with_missing_values(t, p, _version)
68 | a_pieces.append(a)
69 | b_pieces.append(b)
70 | two_c_pieces.append(two_c)
71 | three_d_pieces.append(three_d)
72 | return (misc.cheap_stack(a_pieces, dim=0),
73 | misc.cheap_stack(b_pieces, dim=0),
74 | misc.cheap_stack(two_c_pieces, dim=0),
75 | misc.cheap_stack(three_d_pieces, dim=0))
76 |
77 |
78 | def _natural_cubic_spline_coeffs_with_missing_values_scalar(t, x, _version):
79 | # t and x both have shape (length,)
80 |
81 | nan = torch.isnan(x)
82 | not_nan = ~nan
83 | path_no_nan = x.masked_select(not_nan)
84 |
85 | if path_no_nan.size(0) == 0:
86 | # Every entry is a NaN, so we take a constant path with derivative zero, so return zero coefficients.
87 | # Note that we may assume that X.size(0) >= 2 by the checks in __init__ so "X.size(0) - 1" is a valid
88 | # thing to do.
89 | return (torch.zeros(x.size(0) - 1, dtype=x.dtype, device=x.device),
90 | torch.zeros(x.size(0) - 1, dtype=x.dtype, device=x.device),
91 | torch.zeros(x.size(0) - 1, dtype=x.dtype, device=x.device),
92 | torch.zeros(x.size(0) - 1, dtype=x.dtype, device=x.device))
93 | # else we have at least one non-NaN entry, in which case we're going to impute at least one more entry (as
94 | # the path is of length at least 2 so the start and the end aren't the same), so we will then have at least two
95 | # non-Nan entries. In particular we can call _compute_coeffs safely later.
96 |
97 | # How to deal with missing values at the start or end of the time series? We're creating some splines, so one
98 | # option is just to extend the first piece backwards, and the final piece forwards. But polynomials tend to
99 | # behave badly when extended beyond the interval they were constructed on, so the results can easily end up
100 | # being awful.
101 | if _version == 0:
102 | # Instead we impute an observation at the very start equal to the first actual observation made, and impute an
103 | # observation at the very end equal to the last actual observation made, and then proceed with splines as
104 | # normal.
105 | need_new_not_nan = False
106 | if torch.isnan(x[0]):
107 | if not need_new_not_nan:
108 | x = x.clone()
109 | need_new_not_nan = True
110 | x[0] = path_no_nan[0]
111 | if torch.isnan(x[-1]):
112 | if not need_new_not_nan:
113 | x = x.clone()
114 | need_new_not_nan = True
115 | x[-1] = path_no_nan[-1]
116 | if need_new_not_nan:
117 | not_nan = ~torch.isnan(x)
118 | path_no_nan = x.masked_select(not_nan)
119 | else:
120 | # Instead we fill forward and backward from the first/last observation made. This is better than the previous
121 | # approach as the splines instead rapidly stabilise to the first/last value.
122 | cumsum_mask = not_nan.cumsum(dim=0)
123 | cumsum_mask[nan] = -1
124 | last_non_nan_index = cumsum_mask.argmax(dim=0)
125 | cumsum_mask[nan] = 1 + last_non_nan_index
126 | first_non_nan_index = cumsum_mask.argmin(dim=0)
127 | x = x.clone()
128 | x[:first_non_nan_index] = x[first_non_nan_index]
129 | x[last_non_nan_index + 1:] = x[last_non_nan_index]
130 | not_nan = ~torch.isnan(x)
131 | path_no_nan = x.masked_select(not_nan)
132 | times_no_nan = t.masked_select(not_nan)
133 |
134 | # Find the coefficients on the pieces we do understand
135 | # These all have shape (len - 1,)
136 | (a_pieces_no_nan,
137 | b_pieces_no_nan,
138 | two_c_pieces_no_nan,
139 | three_d_pieces_no_nan) = _natural_cubic_spline_coeffs_without_missing_values(times_no_nan, path_no_nan)
140 |
141 | # Now we're going to normalise them to give coefficients on every interval
142 | a_pieces = []
143 | b_pieces = []
144 | two_c_pieces = []
145 | three_d_pieces = []
146 |
147 | iter_times_no_nan = iter(times_no_nan)
148 | iter_coeffs_no_nan = iter(zip(a_pieces_no_nan, b_pieces_no_nan, two_c_pieces_no_nan, three_d_pieces_no_nan))
149 | next_time_no_nan = next(iter_times_no_nan)
150 | for time in t[:-1]:
151 | # will always trigger on the first iteration because of how we've imputed missing values at the start and
152 | # end of the time series.
153 | if time >= next_time_no_nan:
154 | prev_time_no_nan = next_time_no_nan
155 | next_time_no_nan = next(iter_times_no_nan)
156 | next_a_no_nan, next_b_no_nan, next_two_c_no_nan, next_three_d_no_nan = next(iter_coeffs_no_nan)
157 | offset = prev_time_no_nan - time
158 | a_inner = (0.5 * next_two_c_no_nan - next_three_d_no_nan * offset / 3) * offset
159 | a_pieces.append(next_a_no_nan + (a_inner - next_b_no_nan) * offset)
160 | b_pieces.append(next_b_no_nan + (next_three_d_no_nan * offset - next_two_c_no_nan) * offset)
161 | two_c_pieces.append(next_two_c_no_nan - 2 * next_three_d_no_nan * offset)
162 | three_d_pieces.append(next_three_d_no_nan)
163 |
164 | return (misc.cheap_stack(a_pieces, dim=0),
165 | misc.cheap_stack(b_pieces, dim=0),
166 | misc.cheap_stack(two_c_pieces, dim=0),
167 | misc.cheap_stack(three_d_pieces, dim=0))
168 |
169 |
170 | # The mathematics of this are adapted from http://mathworld.wolfram.com/CubicSpline.html, although they only treat the
171 | # case of each piece being parameterised by [0, 1]. (We instead take the length of each piece to be the difference in
172 | # time stamps.)
173 | def _natural_cubic_spline_coeffs(x, t, _version):
174 | t = misc.validate_input_path(x, t)
175 |
176 | if torch.isnan(x).any():
177 | # Transpose because channels are a batch dimension for the purpose of finding interpolating polynomials.
178 | # b, two_c, three_d have shape (..., channels, length - 1)
179 | a, b, two_c, three_d = _natural_cubic_spline_coeffs_with_missing_values(t, x.transpose(-1, -2), _version)
180 | else:
181 | # Can do things more quickly in this case.
182 | a, b, two_c, three_d = _natural_cubic_spline_coeffs_without_missing_values(t, x.transpose(-1, -2))
183 |
184 | # These all have shape (..., length - 1, channels)
185 | a = a.transpose(-1, -2)
186 | b = b.transpose(-1, -2)
187 | two_c = two_c.transpose(-1, -2)
188 | three_d = three_d.transpose(-1, -2)
189 | coeffs = torch.cat([a, b, two_c, three_d], dim=-1) # for simplicity put them all together
190 | return coeffs
191 |
192 |
193 | def natural_cubic_spline_coeffs(x, t=None):
194 | """Calculates the coefficients of the natural cubic spline approximation to the batch of controls given.
195 |
196 | ********************
197 | DEPRECATED: this now exists for backward compatibility. For new projects please use `natural_cubic_coeffs` instead,
198 | which handles missing data at the start/end of a time series better.
199 | ********************
200 |
201 | Arguments:
202 | x: tensor of values, of shape (..., length, input_channels), where ... is some number of batch dimensions. This
203 | is interpreted as a (batch of) paths taking values in an input_channels-dimensional real vector space, with
204 | length-many observations. Missing values are supported, and should be represented as NaNs.
205 | t: Optional one dimensional tensor of times. Must be monotonically increasing. If not passed will default to
206 | tensor([0., 1., ..., length - 1]). If you are using neural CDEs then you **do not need to use this
207 | argument**. See the Further Documentation in README.md.
208 |
209 | Warning:
210 | If there are missing values then calling this function can be pretty slow. Make sure to cache the result, and
211 | don't reinstantiate it on every forward pass, if at all possible.
212 |
213 | Returns:
214 | A tensor, which should in turn be passed to `torchcde.CubicSpline`.
215 |
216 | Why do we do it like this? Because typically you want to use PyTorch tensors at various interfaces, for example
217 | when loading a batch from a DataLoader. If we wrapped all of this up into just the
218 | `torchcde.CubicSpline` class then that sort of thing wouldn't be possible.
219 |
220 | As such the suggested use is to:
221 | (a) Load your data.
222 | (b) Preprocess it with this function.
223 | (c) Save the result.
224 | (d) Treat the result as your dataset as far as PyTorch's `torch.utils.data.Dataset` and
225 | `torch.utils.data.DataLoader` classes are concerned.
226 | (e) Call CubicSpline as the first part of your model.
227 |
228 | See also the accompanying example.py.
229 | """
230 | return _natural_cubic_spline_coeffs(x, t, _version=0)
231 |
232 |
233 | def natural_cubic_coeffs(x, t=None):
234 | """Calculates the coefficients of the natural cubic spline approximation to the batch of controls given.
235 |
236 | Arguments:
237 | x: tensor of values, of shape (..., length, input_channels), where ... is some number of batch dimensions. This
238 | is interpreted as a (batch of) paths taking values in an input_channels-dimensional real vector space, with
239 | length-many observations. Missing values are supported, and should be represented as NaNs.
240 | t: Optional one dimensional tensor of times. Must be monotonically increasing. If not passed will default to
241 | tensor([0., 1., ..., length - 1]). If you are using neural CDEs then you **do not need to use this
242 | argument**. See the Further Documentation in README.md.
243 |
244 | Warning:
245 | If there are missing values then calling this function can be pretty slow. Make sure to cache the result, and
246 | don't reinstantiate it on every forward pass, if at all possible.
247 |
248 | Returns:
249 | A tensor, which should in turn be passed to `torchcde.CubicSpline`.
250 |
251 | Why do we do it like this? Because typically you want to use PyTorch tensors at various interfaces, for example
252 | when loading a batch from a DataLoader. If we wrapped all of this up into just the
253 | `torchcde.CubicSpline` class then that sort of thing wouldn't be possible.
254 |
255 | As such the suggested use is to:
256 | (a) Load your data.
257 | (b) Preprocess it with this function.
258 | (c) Save the result.
259 | (d) Treat the result as your dataset as far as PyTorch's `torch.utils.data.Dataset` and
260 | `torch.utils.data.DataLoader` classes are concerned.
261 | (e) Call CubicSpline as the first part of your model.
262 |
263 | See also the accompanying example.py.
264 | """
265 | return _natural_cubic_spline_coeffs(x, t, _version=1)
266 |
267 |
268 | class CubicSpline(interpolation_base.InterpolationBase):
269 | """Calculates the cubic spline approximation to the batch of controls given. Also calculates its derivative.
270 |
271 | Example:
272 | # (2, 1) are batch dimensions. 7 is the time dimension (of the same length as t). 3 is the channel dimension.
273 | x = torch.rand(2, 1, 7, 3)
274 | coeffs = natural_cubic_coeffs(x)
275 | # ...at this point you can save coeffs, put it through PyTorch's Datasets and DataLoaders, etc...
276 | spline = CubicSpline(coeffs)
277 | point = torch.tensor(0.4)
278 | # will be a tensor of shape (2, 1, 3), corresponding to batch and channel dimensions
279 | out = spline.derivative(point)
280 | """
281 |
282 | def __init__(self, coeffs, t=None, **kwargs):
283 | """
284 | Arguments:
285 | coeffs: As returned by `torchcde.natural_cubic_coeffs`.
286 | t: As passed to linear_interpolation_coeffs. (If it was passed. If you are using neural CDEs then you **do
287 | not need to use this argument**. See the Further Documentation in README.md.)
288 | """
289 | super(CubicSpline, self).__init__(**kwargs)
290 |
291 | if t is None:
292 | t = torch.linspace(0, coeffs.size(-2), coeffs.size(-2) + 1, dtype=coeffs.dtype, device=coeffs.device)
293 |
294 | channels = coeffs.size(-1) // 4
295 | if channels * 4 != coeffs.size(-1): # check that it's a multiple of 4
296 | raise ValueError("Passed invalid coeffs.")
297 | a, b, two_c, three_d = (coeffs[..., :channels], coeffs[..., channels:2 * channels],
298 | coeffs[..., 2 * channels:3 * channels], coeffs[..., 3 * channels:])
299 |
300 | self.register_buffer('_t', t)
301 | self.register_buffer('_a', a)
302 | self.register_buffer('_b', b)
303 | # as we're typically computing derivatives, we store the multiples of these coefficients that are more useful
304 | self.register_buffer('_two_c', two_c)
305 | self.register_buffer('_three_d', three_d)
306 |
307 | @property
308 | def grid_points(self):
309 | return self._t
310 |
311 | @property
312 | def interval(self):
313 | return torch.stack([self._t[0], self._t[-1]])
314 |
315 | def _interpret_t(self, t):
316 | t = torch.as_tensor(t, dtype=self._b.dtype, device=self._b.device)
317 | maxlen = self._b.size(-2) - 1
318 | # clamp because t may go outside of [t[0], t[-1]]; this is fine
319 | index = torch.bucketize(t.detach(), self._t.detach()).sub(1).clamp(0, maxlen)
320 | # will never access the last element of self._t; this is correct behaviour
321 | fractional_part = t - self._t[index]
322 | return fractional_part, index
323 |
324 | def evaluate(self, t):
325 | fractional_part, index = self._interpret_t(t)
326 | fractional_part = fractional_part.unsqueeze(-1)
327 | inner = 0.5 * self._two_c[..., index, :] + self._three_d[..., index, :] * fractional_part / 3
328 | inner = self._b[..., index, :] + inner * fractional_part
329 | return self._a[..., index, :] + inner * fractional_part
330 |
331 | def derivative(self, t):
332 | fractional_part, index = self._interpret_t(t)
333 | fractional_part = fractional_part.unsqueeze(-1)
334 | inner = self._two_c[..., index, :] + self._three_d[..., index, :] * fractional_part
335 | deriv = self._b[..., index, :] + inner * fractional_part
336 | return deriv
337 |
338 |
339 | class NaturalCubicSpline(CubicSpline):
340 | """Calculates the coefficients of the natural cubic spline approximation to the batch of controls given.
341 |
342 | ********************
343 | DEPRECATED: this now exists for backward compatibility. For new projects please use `CubicSpline` instead. This
344 | class is general for any cubic coeffs (currently natural cubic or Hermite with backwards differences).
345 | ********************
346 | """
347 |
--------------------------------------------------------------------------------
/torchcde/interpolation_hermite_cubic_bdiff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchcde.interpolation_linear import linear_interpolation_coeffs
3 |
4 |
5 | def _setup_hermite_cubic_coeffs_w_backward_differences(times, coeffs, derivs, device=None):
6 | """Compute backward hermite from linear coeffs."""
7 | x_prev = coeffs[..., :-1, :]
8 | x_next = coeffs[..., 1:, :]
9 | # Let x_0 - x_{-1} = x_1 - x_0
10 | derivs_prev = torch.cat((derivs[..., [0], :], derivs[..., :-1, :]), axis=-2)
11 | derivs_next = derivs
12 | x_diff = x_next - x_prev
13 | t_diff = (times[1:] - times[:-1]).unsqueeze(-1)
14 | # Coeffs
15 | a = x_prev
16 | b = derivs_prev
17 | two_c = 2 * (3 * (x_diff / t_diff - b) - derivs_next + derivs_prev) / t_diff
18 | three_d = (1 / t_diff ** 2) * (derivs_next - b) - (two_c) / t_diff
19 | coeffs = torch.cat([a, b, two_c, three_d], dim=-1).to(device)
20 | return coeffs
21 |
22 |
23 | def hermite_cubic_coefficients_with_backward_differences(x, t=None):
24 | """Computes the coefficients for hermite cubic splines with backward differences.
25 |
26 | Arguments:
27 | As `torchcde.linear_interpolation_coeffs`.
28 |
29 | Returns:
30 | A tensor, which should in turn be passed to `torchcde.CubicSpline`.
31 | """
32 | # Linear coeffs
33 | coeffs = linear_interpolation_coeffs(x, t=t, rectilinear=None)
34 |
35 | if t is None:
36 | t = torch.linspace(0, coeffs.size(-2) - 1, coeffs.size(-2), dtype=coeffs.dtype, device=coeffs.device)
37 |
38 | # Linear derivs
39 | derivs = (coeffs[..., 1:, :] - coeffs[..., :-1, :]) / (t[1:] - t[:-1]).unsqueeze(-1)
40 |
41 | # Use the above to compute hermite coeffs
42 | hermite_coeffs = _setup_hermite_cubic_coeffs_w_backward_differences(t, coeffs, derivs, device=coeffs.device)
43 |
44 | return hermite_coeffs
45 |
--------------------------------------------------------------------------------
/torchcde/interpolation_linear.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import warnings
4 |
5 | from . import interpolation_base
6 | from . import misc
7 |
8 |
9 | _two_pi = 2 * math.pi
10 | _inv_two_pi = 1 / _two_pi
11 |
12 |
13 | def _linear_interpolation_coeffs_with_missing_values_scalar(t, x):
14 | # t and X both have shape (length,)
15 |
16 | not_nan = ~torch.isnan(x)
17 | path_no_nan = x.masked_select(not_nan)
18 |
19 | if path_no_nan.size(0) == 0:
20 | # Every entry is a NaN, so we take a constant path with derivative zero, so return zero coefficients.
21 | return torch.zeros(x.size(0), dtype=x.dtype, device=x.device)
22 |
23 | if path_no_nan.size(0) == x.size(0):
24 | # Every entry is not-NaN, so just return.
25 | return x
26 |
27 | x = x.clone()
28 | # How to deal with missing values at the start or end of the time series? We impute an observation at the very start
29 | # equal to the first actual observation made, and impute an observation at the very end equal to the last actual
30 | # observation made, and then proceed as normal.
31 | if torch.isnan(x[0]):
32 | x[0] = path_no_nan[0]
33 | if torch.isnan(x[-1]):
34 | x[-1] = path_no_nan[-1]
35 |
36 | nan_indices = torch.arange(x.size(0), device=x.device).masked_select(torch.isnan(x))
37 |
38 | if nan_indices.size(0) == 0:
39 | # We only had missing values at the start or end
40 | return x
41 |
42 | prev_nan_index = nan_indices[0]
43 | prev_not_nan_index = prev_nan_index - 1
44 | prev_not_nan_indices = [prev_not_nan_index]
45 | for nan_index in nan_indices[1:]:
46 | if prev_nan_index != nan_index - 1:
47 | prev_not_nan_index = nan_index - 1
48 | prev_nan_index = nan_index
49 | prev_not_nan_indices.append(prev_not_nan_index)
50 |
51 | next_nan_index = nan_indices[-1]
52 | next_not_nan_index = next_nan_index + 1
53 | next_not_nan_indices = [next_not_nan_index]
54 | for nan_index in reversed(nan_indices[:-1]):
55 | if next_nan_index != nan_index + 1:
56 | next_not_nan_index = nan_index + 1
57 | next_nan_index = nan_index
58 | next_not_nan_indices.append(next_not_nan_index)
59 | next_not_nan_indices = reversed(next_not_nan_indices)
60 | for prev_not_nan_index, nan_index, next_not_nan_index in zip(prev_not_nan_indices,
61 | nan_indices,
62 | next_not_nan_indices):
63 | prev_stream = x[prev_not_nan_index]
64 | next_stream = x[next_not_nan_index]
65 | prev_time = t[prev_not_nan_index]
66 | next_time = t[next_not_nan_index]
67 | time = t[nan_index]
68 | ratio = (time - prev_time) / (next_time - prev_time)
69 | x[nan_index] = prev_stream + ratio * (next_stream - prev_stream)
70 |
71 | return x
72 |
73 |
74 | def _linear_interpolation_coeffs_with_missing_values(t, x):
75 | if x.ndimension() == 1:
76 | # We have to break everything down to individual scalar paths because of the possibility of missing values
77 | # being different in different channels
78 | return _linear_interpolation_coeffs_with_missing_values_scalar(t, x)
79 | else:
80 | out_pieces = []
81 | for p in x.unbind(dim=0): # TODO: parallelise over this
82 | out = _linear_interpolation_coeffs_with_missing_values(t, p)
83 | out_pieces.append(out)
84 | return misc.cheap_stack(out_pieces, dim=0)
85 |
86 |
87 | def _prepare_rectilinear_interpolation(data, time_index):
88 | """Prepares data for rectilinear interpolation.
89 |
90 | This function performs the relevant filling and lagging of the data needed to convert raw data into a format such
91 | standard linear interpolation will give the rectilinear interpolation.
92 |
93 | Arguments:
94 | x: tensor of values with first channel index being time, of shape (..., length, input_channels), where ... is
95 | some number of batch dimensions.
96 | time_index: integer giving the index of the time channel.
97 |
98 | Example:
99 | Suppose we have data:
100 | data = [(t1, x1), (t2, NaN), (t3, x3), ...]
101 | that we wish to interpolate using a rectilinear scheme. The key point is that this is equivalent to a linear
102 | interpolation on
103 | data_rect = [(t1, x1), (t2, x1), (t2, x1), (t3, x1), (t3, x3) ...]
104 | This function simply performs the conversion from `data` to `data_rect` so that we can apply the inbuilt
105 | torchcde linear interpolation scheme to achieve rectilinear interpolation.
106 |
107 | Returns:
108 | A tensor, now of shape (..., 2 * length - 1, input_channels] that can be fed to linear interpolation coeffs to
109 | give rectilinear coeffs.
110 | """
111 | # Check time_index is of the correct format
112 | n_channels = data.size(-1)
113 | assert isinstance(time_index, int), "Index of the time channel must be an integer in [0, {}]".format(n_channels - 1)
114 | assert 0 <= time_index < n_channels, "Time index must be in [0, {}], was given {}." \
115 | "".format(n_channels - 1, time_index)
116 |
117 | times = data[..., time_index]
118 | assert not torch.isnan(times).any(), "There exist nan values in the time column which is not allowed. If the " \
119 | "times are padded with nans after final time, a simple solution is to " \
120 | "forward fill the final time."
121 |
122 | # Forward fill and perform lag interleaving for rectilinear
123 | data_filled = misc.forward_fill(data)
124 | data_repeat = data_filled.repeat_interleave(2, dim=-2)
125 | data_repeat[..., :-1, time_index] = data_repeat[..., 1:, time_index]
126 | data_rect = data_repeat[..., :-1, :]
127 |
128 | return data_rect
129 |
130 |
131 | def linear_interpolation_coeffs(x, t=None, rectilinear=None):
132 | """Calculates the knots of the linear interpolation of the batch of controls given.
133 |
134 | Arguments:
135 | x: tensor of values, of shape (..., length, input_channels), where ... is some number of batch dimensions. This
136 | is interpreted as a (batch of) paths taking values in an input_channels-dimensional real vector space, with
137 | length-many observations. Missing values are supported, and should be represented as NaNs.
138 | t: Optional one dimensional tensor of times. Must be monotonically increasing. If not passed will default to
139 | tensor([0., 1., ..., length - 1]). If you are using neural CDEs then you **do not need to use this
140 | argument**. See the Further Documentation in README.md.
141 | rectilinear: Optional integer. Used for performing rectilinear interpolation. This means that interpolation
142 | between each two adjoint points is done by first interpolating in the time direction, and then interpolating
143 | in the feature direction. (This is useful for causal missing data, see the Further Documentation in
144 | README.md.) Defaults to None, i.e. not performing rectilinear interpolation. For rectilinear interpolation
145 | time *must* be a channel in x and the `rectilinear` parameter must be an integer specifying the channel
146 | index location of the time index in x.
147 |
148 | Warning:
149 | If there are missing values then calling this function can be pretty slow. Make sure to cache the result, and
150 | don't call it on every forward pass, if at all possible.
151 |
152 | Returns:
153 | A tensor, which should in turn be passed to `torchcde.LinearInterpolation`.
154 |
155 | See the docstring for `torchcde.natural_cubic_coeffs` for more information on why we do it this way.
156 | """
157 | if rectilinear is not None:
158 | if torch.isnan(x[..., 0, :]).any():
159 | warnings.warn("The data `x` begins with missing values in some channels. The path will be constructed by "
160 | "backward-filling the first observed value, which is not causal. Raising a warning as the "
161 | "`rectilinear` argument has also been passed, which is nearly always only used when "
162 | "causality is desired. If you need causality then fill in the missing value at the start of "
163 | "each channel with whatever you'd like it to be. (The mean over that channel is a common "
164 | "choice.)")
165 | x = _prepare_rectilinear_interpolation(x, rectilinear)
166 |
167 | t = misc.validate_input_path(x, t)
168 |
169 | if torch.isnan(x).any():
170 | x = _linear_interpolation_coeffs_with_missing_values(t, x.transpose(-1, -2)).transpose(-1, -2)
171 | return x
172 |
173 |
174 | class LinearInterpolation(interpolation_base.InterpolationBase):
175 | """Calculates the linear interpolation to the batch of controls given. Also calculates its derivative."""
176 |
177 | def __init__(self, coeffs, t=None, **kwargs):
178 | """
179 | Arguments:
180 | coeffs: As returned by linear_interpolation_coeffs.
181 | t: As passed to linear_interpolation_coeffs. (If it was passed. If you are using neural CDEs then you **do
182 | not need to use this argument**. See the Further Documentation in README.md.)
183 | """
184 | super(LinearInterpolation, self).__init__(**kwargs)
185 |
186 | if t is None:
187 | t = torch.linspace(0, coeffs.size(-2) - 1, coeffs.size(-2), dtype=coeffs.dtype, device=coeffs.device)
188 |
189 | derivs = (coeffs[..., 1:, :] - coeffs[..., :-1, :]) / (t[1:] - t[:-1]).unsqueeze(-1)
190 |
191 | self.register_buffer('_t', t)
192 | self.register_buffer('_coeffs', coeffs)
193 | self.register_buffer('_derivs', derivs)
194 |
195 | @property
196 | def grid_points(self):
197 | return self._t
198 |
199 | @property
200 | def interval(self):
201 | return torch.stack([self._t[0], self._t[-1]])
202 |
203 | def _interpret_t(self, t):
204 | t = torch.as_tensor(t, dtype=self._derivs.dtype, device=self._derivs.device)
205 | maxlen = self._derivs.size(-2) - 1
206 | # clamp because t may go outside of [t[0], t[-1]]; this is fine
207 | index = torch.bucketize(t.detach(), self._t.detach()).sub(1).clamp(0, maxlen)
208 | # will never access the last element of self._t; this is correct behaviour
209 | fractional_part = t - self._t[index]
210 | return fractional_part, index
211 |
212 | def evaluate(self, t):
213 | fractional_part, index = self._interpret_t(t)
214 | fractional_part = fractional_part.unsqueeze(-1)
215 | prev_coeff = self._coeffs[..., index, :]
216 | next_coeff = self._coeffs[..., index + 1, :]
217 | prev_t = self._t[index]
218 | next_t = self._t[index + 1]
219 | diff_t = next_t - prev_t
220 | return prev_coeff + fractional_part * (next_coeff - prev_coeff) / diff_t.unsqueeze(-1)
221 |
222 | def derivative(self, t):
223 | fractional_part, index = self._interpret_t(t)
224 | deriv = self._derivs[..., index, :]
225 | return deriv
226 |
--------------------------------------------------------------------------------
/torchcde/log_ode.py:
--------------------------------------------------------------------------------
1 | try:
2 | import signatory
3 | except ImportError:
4 | class DummyModule:
5 | def __getattr__(self, item):
6 | raise ImportError("signatory has not been installed. Please install it from "
7 | "https://github.com/patrick-kidger/signatory to use the log-ODE method.")
8 | signatory = DummyModule()
9 | import torch
10 |
11 | from . import interpolation_linear
12 | from . import misc
13 |
14 |
15 | def _logsignature_windows(x, depth, window_length, t, _version):
16 | t = misc.validate_input_path(x, t)
17 |
18 | # slightly roundabout way of doing things (rather than using arange) so that it's constructed differentiably
19 | timespan = t[-1] - t[0]
20 | num_pieces = (timespan / window_length).ceil().to(int).item()
21 | end_t = t[0] + num_pieces * window_length
22 | new_t = torch.linspace(t[0], end_t, num_pieces + 1, dtype=t.dtype, device=t.device)
23 | new_t = torch.min(new_t, t.max())
24 |
25 | t_index = 0
26 | new_t_unique = []
27 | new_t_indices = []
28 | for new_t_elem in new_t:
29 | while True:
30 | lequal = (new_t_elem <= t[t_index])
31 | close = new_t_elem.allclose(t[t_index])
32 | if lequal or close:
33 | break
34 | t_index += 1
35 | new_t_indices.append(t_index + len(new_t_unique))
36 | if close:
37 | continue
38 | new_t_unique.append(new_t_elem.unsqueeze(0))
39 |
40 | batch_dimensions = x.shape[:-2]
41 |
42 | missing_X = torch.full((1,), float('nan'), dtype=x.dtype, device=x.device).expand(*batch_dimensions, 1, x.size(-1))
43 | if len(new_t_unique) > 0: # no-op if len == 0, so skip for efficiency
44 | t, indices = torch.cat([t, *new_t_unique]).sort()
45 | x = torch.cat([x, missing_X], dim=-2)[..., indices.clamp(0, x.size(-2)), :]
46 |
47 | # Fill in any missing data linearly (linearly because that's what signatures do in between observations anyway)
48 | # and conveniently that's what this already does. Here 'missing data' includes the NaNs we've just added.
49 | x = interpolation_linear.linear_interpolation_coeffs(x, t)
50 |
51 | # Flatten batch dimensions for compatibility with Signatory
52 | flatten_X = x.reshape(-1, x.size(-2), x.size(-1))
53 | first_increment = torch.zeros(*batch_dimensions, signatory.logsignature_channels(x.size(-1), depth), dtype=x.dtype,
54 | device=x.device)
55 | first_increment[..., :x.size(-1)] = x[..., 0, :]
56 | logsignatures = [first_increment]
57 | compute_logsignature = signatory.Logsignature(depth=depth)
58 | for index, next_index, time, next_time in zip(new_t_indices[:-1], new_t_indices[1:], new_t[:-1], new_t[1:]):
59 | logsignature = compute_logsignature(flatten_X[..., index:next_index + 1, :])
60 | logsignature = logsignature.view(*batch_dimensions, -1)
61 | if _version == 0:
62 | logsignature = logsignature * (next_time - time)
63 | elif _version == 1:
64 | pass
65 | else:
66 | raise RuntimeError
67 | logsignatures.append(logsignature)
68 |
69 | logsignatures = torch.stack(logsignatures, dim=-2)
70 | logsignatures = logsignatures.cumsum(dim=-2)
71 |
72 | if _version == 0:
73 | return logsignatures, new_t
74 | elif _version == 1:
75 | return logsignatures
76 | else:
77 | raise RuntimeError
78 |
79 |
80 | def logsignature_windows(x, depth, window_length, t=None):
81 | """Calculates logsignatures over multiple windows, for the batch of controls given, as in the log-ODE method.
82 |
83 | ********************
84 | DEPRECATED: this now exists for backward compatibility. For new projects please use `logsig_windows` instead,
85 | which has a corrected rescaling coefficient.
86 | ********************
87 |
88 | This corresponds to a transform of the time series, and should be used prior to applying one of the interpolation
89 | schemes.
90 |
91 | Arguments:
92 | x: tensor of values, of shape (..., length, input_channels), where ... is some number of batch dimensions. This
93 | is interpreted as a (batch of) paths taking values in an input_channels-dimensional real vector space, with
94 | length-many observations. Missing values are supported, and should be represented as NaNs.
95 | depth: What depth to compute the logsignatures to.
96 | window_length: How long a time interval to compute logsignatures over.
97 | t: Optional one dimensional tensor of times. Must be monotonically increasing. If not passed will default to
98 | tensor([0., 1., ..., length - 1]).
99 |
100 | Warning:
101 | If there are missing values then calling this function can be pretty slow. Make sure to cache the result, and
102 | don't reinstantiate it on every forward pass, if at all possible.
103 |
104 | Returns:
105 | A tuple of two tensors, which are the values and times of the transformed path.
106 | """
107 | return _logsignature_windows(x, depth, window_length, t, _version=0)
108 |
109 |
110 | def logsig_windows(x, depth, window_length, t=None):
111 | """Calculates logsignatures over multiple windows, for the batch of controls given, as in the log-ODE method.
112 |
113 | This corresponds to a transform of the time series, and should be used prior to applying one of the interpolation
114 | schemes.
115 |
116 | Arguments:
117 | x: tensor of values, of shape (..., length, input_channels), where ... is some number of batch dimensions. This
118 | is interpreted as a (batch of) paths taking values in an `input_channels`-dimensional real vector space,
119 | with `length`-many observations. Missing values are supported, and should be represented as NaNs.
120 | depth: What depth to compute the logsignatures to.
121 | window_length: How long a time interval to compute logsignatures over.
122 | t: Optional one dimensional tensor of times. Must be monotonically increasing. If not passed will default to
123 | `tensor([0., 1., ..., length - 1])`.
124 |
125 | Warning:
126 | If there are missing values then calling this function can be pretty slow. Make sure to cache the result, and
127 | don't reinstantiate it on every forward pass, if at all possible.
128 |
129 | Returns:
130 | A tensor, which are the values of the transformed path. Times are _not_ returned: the return value is
131 | always scaled such that the corresponding times are just `tensor([0., 1., ..., length - 1])`.
132 | """
133 | return _logsignature_windows(x, depth, window_length, t, _version=1)
134 |
--------------------------------------------------------------------------------
/torchcde/misc.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def cheap_stack(tensors, dim):
7 | if len(tensors) == 1:
8 | return tensors[0].unsqueeze(dim)
9 | else:
10 | return torch.stack(tensors, dim=dim)
11 |
12 |
13 | def tridiagonal_solve(b, A_upper, A_diagonal, A_lower):
14 | """Solves a tridiagonal system Ax = b.
15 |
16 | The arguments A_upper, A_digonal, A_lower correspond to the three diagonals of A. Letting U = A_upper, D=A_digonal
17 | and L = A_lower, and assuming for simplicity that there are no batch dimensions, then the matrix A is assumed to be
18 | of size (k, k), with entries:
19 |
20 | D[0] U[0]
21 | L[0] D[1] U[1]
22 | L[1] D[2] U[2] 0
23 | L[2] D[3] U[3]
24 | . . .
25 | . . .
26 | . . .
27 | L[k - 3] D[k - 2] U[k - 2]
28 | 0 L[k - 2] D[k - 1] U[k - 1]
29 | L[k - 1] D[k]
30 |
31 | Arguments:
32 | b: A tensor of shape (..., k), where '...' is zero or more batch dimensions
33 | A_upper: A tensor of shape (..., k - 1).
34 | A_diagonal: A tensor of shape (..., k).
35 | A_lower: A tensor of shape (..., k - 1).
36 |
37 | Returns:
38 | A tensor of shape (..., k), corresponding to the x solving Ax = b
39 |
40 | Warning:
41 | This implementation isn't super fast. You probably want to cache the result, if possible.
42 | """
43 |
44 | # This implementation is very much written for clarity rather than speed.
45 |
46 | A_upper, _ = torch.broadcast_tensors(A_upper, b[..., :-1])
47 | A_lower, _ = torch.broadcast_tensors(A_lower, b[..., :-1])
48 | A_diagonal, b = torch.broadcast_tensors(A_diagonal, b)
49 |
50 | channels = b.size(-1)
51 |
52 | new_b = np.empty(channels, dtype=object)
53 | new_A_diagonal = np.empty(channels, dtype=object)
54 | outs = np.empty(channels, dtype=object)
55 |
56 | new_b[0] = b[..., 0]
57 | new_A_diagonal[0] = A_diagonal[..., 0]
58 | for i in range(1, channels):
59 | w = A_lower[..., i - 1] / new_A_diagonal[i - 1]
60 | new_A_diagonal[i] = A_diagonal[..., i] - w * A_upper[..., i - 1]
61 | new_b[i] = b[..., i] - w * new_b[i - 1]
62 |
63 | outs[channels - 1] = new_b[channels - 1] / new_A_diagonal[channels - 1]
64 | for i in range(channels - 2, -1, -1):
65 | outs[i] = (new_b[i] - A_upper[..., i] * outs[i + 1]) / new_A_diagonal[i]
66 |
67 | return torch.stack(outs.tolist(), dim=-1)
68 |
69 |
70 | def validate_input_path(x, t):
71 | if not x.is_floating_point():
72 | raise ValueError("X must both be floating point.")
73 |
74 | if x.ndimension() < 2:
75 | raise ValueError("X must have at least two dimensions, corresponding to time and channels. It instead has "
76 | "shape {}.".format(tuple(x.shape)))
77 |
78 | if t is None:
79 | t = torch.linspace(0, x.size(-2) - 1, x.size(-2), dtype=x.dtype, device=x.device)
80 |
81 | if not t.is_floating_point():
82 | raise ValueError("t must both be floating point.")
83 | if len(t.shape) != 1:
84 | raise ValueError("t must be one dimensional. It instead has shape {}.".format(tuple(t.shape)))
85 | prev_t_i = -math.inf
86 | for t_i in t:
87 | if t_i <= prev_t_i:
88 | raise ValueError("t must be monotonically increasing.")
89 | prev_t_i = t_i
90 |
91 | if x.size(-2) != t.size(0):
92 | raise ValueError("The time dimension of X must equal the length of t. X has shape {} and t has shape {}, "
93 | "corresponding to time dimensions of {} and {} respectively."
94 | .format(tuple(x.shape), tuple(t.shape), x.size(-2), t.size(0)))
95 |
96 | if t.size(0) < 2:
97 | raise ValueError("Must have a time dimension of size at least 2. It instead has shape {}, corresponding to a "
98 | "time dimension of size {}.".format(tuple(t.shape), t.size(0)))
99 |
100 | return t
101 |
102 |
103 | def forward_fill(x, fill_index=-2):
104 | """Forward fills data in a torch tensor of shape (..., length, input_channels) along the length dim.
105 |
106 | Arguments:
107 | x: tensor of values with first channel index being time, of shape (..., length, input_channels), where ... is
108 | some number of batch dimensions.
109 | fill_index: int that denotes the index to fill down. Default is -2 as we tend to use the convention (...,
110 | length, input_channels) filling down the length dimension.
111 |
112 | Returns:
113 | A tensor with forward filled data.
114 | """
115 | # Checks
116 | assert isinstance(x, torch.Tensor)
117 | assert x.dim() >= 2
118 |
119 | mask = torch.isnan(x)
120 | if mask.any():
121 | cumsum_mask = (~mask).cumsum(dim=fill_index)
122 | cumsum_mask[mask] = 0
123 | _, index = cumsum_mask.cummax(dim=fill_index)
124 | x = x.gather(dim=fill_index, index=index)
125 |
126 | return x
127 |
128 |
129 | class TupleControl(torch.nn.Module):
130 | def __init__(self, *controls):
131 | super(TupleControl, self).__init__()
132 |
133 | if len(controls) == 0:
134 | raise ValueError("Expected one or more controls to batch together.")
135 |
136 | self._interval = controls[0].interval
137 | grid_points = controls[0].grid_points
138 | same_grid_points = True
139 | for control in controls[1:]:
140 | if (control.interval != self._interval).any():
141 | raise ValueError("Can only batch togehter controls over the same interval.")
142 | if same_grid_points and (control.grid_points != grid_points).any():
143 | same_grid_points = False
144 |
145 | if same_grid_points:
146 | self._grid_points = grid_points
147 | else:
148 | self._grid_points = None
149 |
150 | self.controls = torch.nn.ModuleList(controls)
151 |
152 | @property
153 | def interval(self):
154 | return self._interval
155 |
156 | @property
157 | def grid_points(self):
158 | if self._grid_points is None:
159 | raise RuntimeError("Batch of controls have different grid points.")
160 | return self._grid_points
161 |
162 | def evaluate(self, t):
163 | return tuple(control.evaluate(t) for control in self.controls)
164 |
165 | def derivative(self, t):
166 | return tuple(control.derivative(t) for control in self.controls)
167 |
--------------------------------------------------------------------------------
/torchcde/solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchdiffeq
3 | import torchsde
4 | import warnings
5 |
6 |
7 | def _check_compatability_per_tensor_base(control_gradient, z0):
8 | if control_gradient.shape[:-1] != z0.shape[:-1]:
9 | raise ValueError("X.derivative did not return a tensor with the same number of batch dimensions as z0. "
10 | "X.derivative returned shape {} (meaning {} batch dimensions), whilst z0 has shape {} "
11 | "(meaning {} batch dimensions)."
12 | "".format(tuple(control_gradient.shape), tuple(control_gradient.shape[:-1]), tuple(z0.shape),
13 | tuple(z0.shape[:-1])))
14 |
15 |
16 | def _check_compatability_per_tensor_forward(control_gradient, system, z0):
17 | _check_compatability_per_tensor_base(control_gradient, z0)
18 | if system.shape[:-2] != z0.shape[:-1]:
19 | raise ValueError("func did not return a tensor with the same number of batch dimensions as z0. func returned "
20 | "shape {} (meaning {} batch dimensions), whilst z0 has shape {} (meaning {} batch"
21 | " dimensions)."
22 | "".format(tuple(system.shape), tuple(system.shape[:-2]), tuple(z0.shape),
23 | tuple(z0.shape[:-1])))
24 | if system.size(-2) != z0.size(-1):
25 | raise ValueError("func did not return a tensor with the same number of hidden channels as z0. func returned "
26 | "shape {} (meaning {} channels), whilst z0 has shape {} (meaning {} channels)."
27 | "".format(tuple(system.shape), system.size(-2), tuple(z0.shape), z0.size(-1)))
28 | if system.size(-1) != control_gradient.size(-1):
29 | raise ValueError("func did not return a tensor with the same number of input channels as X.derivative "
30 | "returned. func returned shape {} (meaning {} channels), whilst X.derivative returned shape "
31 | "{} (meaning {} channels)."
32 | "".format(tuple(system.shape), system.size(-1), tuple(control_gradient.shape),
33 | control_gradient.size(-1)))
34 |
35 |
36 | def _check_compatability_per_tensor_prod(control_gradient, vector_field, z0):
37 | _check_compatability_per_tensor_base(control_gradient, z0)
38 | if vector_field.shape != z0.shape:
39 | raise ValueError("func.prod did not return a tensor with the same shape as z0. func.prod returned shape {} "
40 | "whilst z0 has shape {}."
41 | "".format(tuple(vector_field.shape), tuple(z0.shape)))
42 |
43 |
44 | def _check_compatability(X, func, z0, t):
45 | if not hasattr(X, 'derivative'):
46 | raise ValueError("X must have a 'derivative' method.")
47 | control_gradient = X.derivative(t[0].detach())
48 | if hasattr(func, 'prod'):
49 | is_prod = True
50 | vector_field = func.prod(t[0], z0, control_gradient)
51 | else:
52 | is_prod = False
53 | system = func(t[0], z0)
54 |
55 | if isinstance(z0, torch.Tensor):
56 | is_tensor = True
57 | if not isinstance(control_gradient, torch.Tensor):
58 | raise ValueError("z0 is a tensor and so X.derivative must return a tensor as well.")
59 | if is_prod:
60 | if not isinstance(vector_field, torch.Tensor):
61 | raise ValueError("z0 is a tensor and so func.prod must return a tensor as well.")
62 | _check_compatability_per_tensor_prod(control_gradient, vector_field, z0)
63 | else:
64 | if not isinstance(system, torch.Tensor):
65 | raise ValueError("z0 is a tensor and so func must return a tensor as well.")
66 | _check_compatability_per_tensor_forward(control_gradient, system, z0)
67 |
68 | elif isinstance(z0, (tuple, list)):
69 | is_tensor = False
70 | if not isinstance(control_gradient, (tuple, list)):
71 | raise ValueError("z0 is a tuple/list and so X.derivative must return a tuple/list as well.")
72 | if len(z0) != len(control_gradient):
73 | raise ValueError("z0 and X.derivative(t) must be tuples of the same length.")
74 | if is_prod:
75 | if not isinstance(vector_field, (tuple, list)):
76 | raise ValueError("z0 is a tuple/list and so func.prod must return a tuple/list as well.")
77 | if len(z0) != len(vector_field):
78 | raise ValueError("z0 and func.prod(t, z, dXdt) must be tuples of the same length.")
79 | for control_gradient_, vector_Field_, z0_ in zip(control_gradient, vector_field, z0):
80 | if not isinstance(control_gradient_, torch.Tensor):
81 | raise ValueError("X.derivative must return a tensor or tuple of tensors.")
82 | if not isinstance(vector_Field_, torch.Tensor):
83 | raise ValueError("func.prod must return a tensor or tuple/list of tensors.")
84 | _check_compatability_per_tensor_prod(control_gradient_, vector_Field_, z0_)
85 | else:
86 | if not isinstance(system, (tuple, list)):
87 | raise ValueError("z0 is a tuple/list and so func must return a tuple/list as well.")
88 | if len(z0) != len(system):
89 | raise ValueError("z0 and func(t, z) must be tuples of the same length.")
90 | for control_gradient_, system_, z0_ in zip(control_gradient, system, z0):
91 | if not isinstance(control_gradient_, torch.Tensor):
92 | raise ValueError("X.derivative must return a tensor or tuple of tensors.")
93 | if not isinstance(system_, torch.Tensor):
94 | raise ValueError("func must return a tensor or tuple/list of tensors.")
95 | _check_compatability_per_tensor_forward(control_gradient_, system_, z0_)
96 |
97 | else:
98 | raise ValueError("z0 must either a tensor or a tuple/list of tensors.")
99 |
100 | return is_tensor, is_prod
101 |
102 |
103 | class _VectorField(torch.nn.Module):
104 | def __init__(self, X, func, is_tensor, is_prod):
105 | super(_VectorField, self).__init__()
106 |
107 | self.X = X
108 | self.func = func
109 | self.is_tensor = is_tensor
110 | self.is_prod = is_prod
111 |
112 | # torchsde backend
113 | self.sde_type = getattr(func, "sde_type", "stratonovich")
114 | self.noise_type = getattr(func, "noise_type", "additive")
115 |
116 | # torchdiffeq backend
117 | def forward(self, t, z):
118 | # control_gradient is of shape (..., input_channels)
119 | control_gradient = self.X.derivative(t)
120 |
121 | if self.is_prod:
122 | # out is of shape (..., hidden_channels)
123 | out = self.func.prod(t, z, control_gradient)
124 | else:
125 | # vector_field is of shape (..., hidden_channels, input_channels)
126 | vector_field = self.func(t, z)
127 | if self.is_tensor:
128 | # out is of shape (..., hidden_channels)
129 | # (The squeezing is necessary to make the matrix-multiply properly batch in all cases)
130 | out = (vector_field @ control_gradient.unsqueeze(-1)).squeeze(-1)
131 | else:
132 | out = tuple((vector_field_ @ control_gradient_.unsqueeze(-1)).squeeze(-1)
133 | for vector_field_, control_gradient_ in zip(vector_field, control_gradient))
134 |
135 | return out
136 |
137 | # torchsde backend
138 | f = forward
139 |
140 | def g(self, t, z):
141 | return torch.zeros_like(z).unsqueeze(-1)
142 |
143 |
144 | def cdeint(X, func, z0, t, adjoint=True, backend="torchdiffeq", **kwargs):
145 | r"""Solves a system of controlled differential equations.
146 |
147 | Solves the controlled problem:
148 | ```
149 | z_t = z_{t_0} + \int_{t_0}^t f(s, z_s) dX_s
150 | ```
151 | where z is a tensor of any shape, and X is some controlling signal.
152 |
153 | Arguments:
154 | X: The control. This should be a instance of `torch.nn.Module`, with a `derivative` method. For example
155 | `torchcde.CubicSpline`. This represents a continuous path derived from the data. The
156 | derivative at a point will be computed via `X.derivative(t)`, where t is a scalar tensor. The returned
157 | tensor should have shape (..., input_channels), where '...' is some number of batch dimensions and
158 | input_channels is the number of channels in the input path.
159 | func: Should be a callable describing the vector field f(t, z). If using `adjoint=True` (the default), then
160 | should be an instance of `torch.nn.Module`, to collect the parameters for the adjoint pass. Will be called
161 | with a scalar tensor t and a tensor z of shape (..., hidden_channels), and should return a tensor of shape
162 | (..., hidden_channels, input_channels), where hidden_channels and input_channels are integers defined by the
163 | `hidden_shape` and `X` arguments as above. The '...' corresponds to some number of batch dimensions. If it
164 | has a method `prod` then that will be called to calculate the matrix-vector product f(t, z) dX_t/dt, via
165 | `func.prod(t, z, dXdt)`.
166 | z0: The initial state of the solution. It should have shape (..., hidden_channels), where '...' is some number
167 | of batch dimensions.
168 | t: a one dimensional tensor describing the times to range of times to integrate over and output the results at.
169 | The initial time will be t[0] and the final time will be t[-1].
170 | adjoint: A boolean; whether to use the adjoint method to backpropagate. Defaults to True.
171 | backend: Either "torchdiffeq" or "torchsde". Which library to use for the solvers. Note that if using torchsde
172 | that the Brownian motion component is completely ignored -- so it's still reducing the CDE to an ODE --
173 | but it makes it possible to e.g. use an SDE solver there as the ODE/CDE solver here, if for some reason
174 | that's desired.
175 | **kwargs: Any additional kwargs to pass to the odeint solver of torchdiffeq (the most common are `rtol`, `atol`,
176 | `method`, `options`) or the sdeint solver of torchsde.
177 |
178 | Returns:
179 | The value of each z_{t_i} of the solution to the CDE z_t = z_{t_0} + \int_0^t f(s, z_s)dX_s, where t_i = t[i].
180 | This will be a tensor of shape (..., len(t), hidden_channels).
181 |
182 | Raises:
183 | ValueError for malformed inputs.
184 |
185 | Note:
186 | Supports tupled input, i.e. z0 can be a tuple of tensors, and X.derivative and func can return tuples of tensors
187 | of the same length.
188 |
189 | Warnings:
190 | Note that the returned tensor puts the sequence dimension second-to-last, rather than first like in
191 | `torchdiffeq.odeint` or `torchsde.sdeint`.
192 | """
193 |
194 | # Reduce the default values for the tolerances because CDEs are difficult to solve with the default high tolerances.
195 | if 'atol' not in kwargs:
196 | kwargs['atol'] = 1e-6
197 | if 'rtol' not in kwargs:
198 | kwargs['rtol'] = 1e-4
199 | if adjoint:
200 | if "adjoint_atol" not in kwargs:
201 | kwargs["adjoint_atol"] = kwargs["atol"]
202 | if "adjoint_rtol" not in kwargs:
203 | kwargs["adjoint_rtol"] = kwargs["rtol"]
204 |
205 | is_tensor, is_prod = _check_compatability(X, func, z0, t)
206 |
207 | if adjoint and 'adjoint_params' not in kwargs:
208 | for buffer in X.buffers():
209 | # Compare based on id to avoid PyTorch not playing well with using `in` on tensors.
210 | if buffer.requires_grad:
211 | warnings.warn("One of the inputs to the control path X requires gradients but "
212 | "`kwargs['adjoint_params']` has not been passed. This is probably a mistake: these "
213 | "inputs will not receive a gradient when using the adjoint method. Either have the input "
214 | "not require gradients (if that was unintended), or include it (and every other "
215 | "parameter needing gradients) in `adjoint_params`. For example:\n"
216 | "```\n"
217 | "coeffs = ...\n"
218 | "func = ...\n"
219 | "X = CubicSpline(coeffs)\n"
220 | "adjoint_params = tuple(func.parameters()) + (coeffs,)\n"
221 | "cdeint(X=X, func=func, ..., adjoint_params=adjoint_params)\n"
222 | "```")
223 |
224 | vector_field = _VectorField(X=X, func=func, is_tensor=is_tensor, is_prod=is_prod)
225 | if backend == "torchdiffeq":
226 | odeint = torchdiffeq.odeint_adjoint if adjoint else torchdiffeq.odeint
227 | out = odeint(func=vector_field, y0=z0, t=t, **kwargs)
228 | elif backend == "torchsde":
229 | sdeint = torchsde.sdeint_adjoint if adjoint else torchsde.sdeint
230 | out = sdeint(sde=vector_field, y0=z0, ts=t, **kwargs)
231 | else:
232 | raise ValueError(f"Unrecognised backend={backend}")
233 |
234 | if is_tensor:
235 | batch_dims = range(1, len(out.shape) - 1)
236 | out = out.permute(*batch_dims, 0, -1)
237 | else:
238 | out_ = []
239 | for outi in out:
240 | batch_dims = range(1, len(outi.shape) - 1)
241 | outi = outi.permute(*batch_dims, 0, -1)
242 | out_.append(outi)
243 | out = tuple(out_)
244 |
245 | return out
246 |
--------------------------------------------------------------------------------