├── .gitignore
├── CITATION.cff
├── FAQ.md
├── FURTHER_DOCUMENTATION.md
├── LICENSE
├── README.md
├── assets
│   ├── bouncing_ball.png
│   ├── cnf_demo.gif
│   ├── ode_demo.gif
│   ├── odenet_0_viz.png
│   └── resnet_0_viz.png
├── examples
│   ├── README.md
│   ├── bouncing_ball.py
│   ├── cnf.py
│   ├── latent_ode.py
│   ├── learn_physics.py
│   ├── ode_demo.py
│   └── odenet_mnist.py
├── setup.py
├── tests
│   ├── DETEST
│   │   ├── detest.py
│   │   └── run.py
│   ├── api_tests.py
│   ├── event_tests.py
│   ├── gradient_tests.py
│   ├── norm_tests.py
│   ├── odeint_tests.py
│   ├── problems.py
│   └── run_all.py
└── torchdiffeq
    ├── __init__.py
    └── _impl
        ├── __init__.py
        ├── adaptive_heun.py
        ├── adjoint.py
        ├── bosh3.py
        ├── dopri5.py
        ├── dopri8.py
        ├── event_handling.py
        ├── fehlberg2.py
        ├── fixed_adams.py
        ├── fixed_grid.py
        ├── fixed_grid_implicit.py
        ├── interp.py
        ├── misc.py
        ├── odeint.py
        ├── rk_common.py
        ├── scipy_wrapper.py
        ├── solvers.py
        └── tsit5.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info/
2 | .installed.cfg
3 | *.egg
4 | *__pycache__*
5 | *.pyc
6 | .vscode
7 | build
8 | dist
9 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # YAML 1.2
2 | ---
3 | abstract: |
4 |     "This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. We also allow terminating an ODE solution based on an event function, with exact gradients computed.
5 |
6 |     As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU."
7 | authors:
8 |   -
9 |     family-names: Chen
10 |     given-names: "Ricky T. Q."
11 | cff-version: "1.1.0"
12 | date-released: 2021-06-02
13 | license: MIT
14 | message: "PyTorch Implementation of Differentiable ODE Solvers"
15 | repository-code: "https://github.com/rtqichen/torchdiffeq"
16 | title: torchdiffeq
17 | version: "0.2.2"
18 | ...
19 |
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | # Frequently Asked Questions (FAQ)
2 |
3 | **What are good resources to understand how ODEs can be solved?**
4 | *Solving Ordinary Differential Equations I Nonstiff Problems* by Hairer et al.
5 | [ODE solver selection in MatLab](https://blogs.mathworks.com/loren/2015/09/23/ode-solver-selection-in-matlab/)
6 |
7 | **What are the ODE solvers available in this repo?**
8 |
9 | - Adaptive-step:
10 |   - `dopri8` Runge-Kutta 7(8) of Dormand-Prince-Shampine.
11 |   - `dopri5` Runge-Kutta 4(5) of Dormand-Prince **[default]**.
12 |   - `bosh3` Runge-Kutta 2(3) of Bogacki-Shampine.
13 |   - `adaptive_heun` Runge-Kutta 1(2).
14 |
15 | - Fixed-step:
16 |   - `euler` Euler method.
17 |   - `midpoint` Midpoint method.
18 |   - `rk4` Fourth-order Runge-Kutta with 3/8 rule.
19 |   - `explicit_adams` Explicit Adams.
20 |   - `implicit_adams` Implicit Adams.
21 |
22 | - `scipy_solver`: Wraps a SciPy solver.
23 |
24 |
25 | **What are `NFE-F` and `NFE-B`?**
26 | Number of function evaluations for forward and backward pass.
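One simple way to measure NFE-F is to count calls to the dynamics function. The sketch below wraps a plain Python callable in a counter and drives it with a hypothetical fixed-step midpoint integrator (written for illustration, not torchdiffeq's `midpoint` solver), which evaluates the function twice per step:

```python
class CountingFunc:
    """Wraps a dynamics function f(t, y) and counts how often it is evaluated."""

    def __init__(self, f):
        self.f = f
        self.nfe = 0

    def __call__(self, t, y):
        self.nfe += 1
        return self.f(t, y)


def midpoint_solve(func, y0, t0, t1, num_steps):
    """Hypothetical fixed-step midpoint integrator: two evaluations per step."""
    h = (t1 - t0) / num_steps
    t, y = t0, y0
    for _ in range(num_steps):
        k1 = func(t, y)                        # slope at the start of the step
        k2 = func(t + h / 2, y + h / 2 * k1)   # slope at the midpoint
        y = y + h * k2
        t = t + h
    return y


func = CountingFunc(lambda t, y: -y)   # dy/dt = -y
y1 = midpoint_solve(func, 1.0, 0.0, 1.0, num_steps=10)
print(func.nfe)  # 20: 10 steps x 2 evaluations per step
```

The same counter idea applies to the backward pass: resetting `nfe` before calling backward and reading it afterwards gives NFE-B.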
27 |
28 | **What are `rtol` and `atol`?**
29 | They refer to relative `rtol` and absolute `atol` error tolerance.
30 |
31 | **What is the role of error tolerance in adaptive solvers?**
32 | The basic idea is that each adaptive solver produces an error estimate for the current step. If the error is greater than the tolerance, the step is redone with a smaller step size, and this repeats until the error is smaller than the provided tolerance.
33 | [Error Tolerances for Variable-Step Solvers](https://www.mathworks.com/help/simulink/ug/types-of-solvers.html#f11-44943)
34 |
35 | **How is the error tolerance calculated?**
36 | The error tolerance is [calculated](https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/misc.py#L74) as `atol + rtol * norm of current state`, where the norm being used is a mixed L-infinity/RMS norm.
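As a minimal sketch of the formula above, with plain Python lists standing in for tensors, the mixed norm takes an RMS norm of each tensor and then a max over the tuple, and a step passes when the error estimate stays within `atol + rtol * norm(state)`. The function names are hypothetical and this is not the code in `misc.py`, which works on tensors and may scale the error differently.

```python
import math


def rms_norm(x):
    """Root-mean-square norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in x) / len(x))


def mixed_norm(xs):
    """Mixed L-infinity/RMS norm: RMS norm of each tensor, then max over the tuple."""
    return max(rms_norm(x) for x in xs)


def accept_step(error_estimate, state, rtol, atol):
    """Accept the step if the error estimate is within atol + rtol * norm(state)."""
    tolerance = atol + rtol * mixed_norm(state)
    return mixed_norm(error_estimate) <= tolerance


state = ([3.0, 4.0], [1.0])   # tupled state made of two "tensors"
print(mixed_norm(state))      # max(sqrt(12.5), 1.0), about 3.5355
print(accept_step(([1e-4, 1e-4], [1e-5]), state, rtol=1e-3, atol=1e-6))  # True
```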
37 |
38 | **Where is the code that computes the error tolerance?**
39 | It is computed [here.](https://github.com/rtqichen/torchdiffeq/blob/c4c9c61c939c630b9b88267aa56ddaaec319cb16/torchdiffeq/_impl/misc.py#L94)
40 |
41 | **How many states must a Neural ODE solver store during a forward pass with the adjoint method?**
42 | The number of states required to be stored in memory during a forward pass is solver dependent. For example, `dopri5` requires 6 intermediate states to be stored.
43 |
44 | **How many function evaluations are there per ODE step on adaptive solvers?**
45 |
46 | - `dopri5`
47 | The `dopri5` ODE solver stores at least 6 evaluations of the ODE, then takes a step using a linear combination of them. The diagram below illustrates it: the evaluations marked with `o` are on the estimated path, the others with `x` are not. The first two are for selecting the initial step size.
48 |
49 | ```
50 | 0 1 | 2 3 4 5 6 7 | 8 9 10 11 12 13
51 | o x | x x x x x o | x x x  x  x  o
52 | ```
53 |
54 |
55 | **How do I obtain evaluations on the estimated path when using an adaptive solver?**
56 | The argument `t` of `odeint` specifies the times at which the ODE solver should output the solution.
57 | ```odeint(func, x0, t=torch.linspace(0, 1, 50))```
58 |
59 | Note that the ODE solver always integrates from the first to the last value of `t` (here from 0 to 1), and the intermediate values of `t` have no effect on how the ODE is solved. Intermediate values are computed using polynomial interpolation and have very small cost.
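To illustrate why these intermediate values are cheap: between two accepted steps the solver already knows the state and the slope at both ends, so an output in between is just a polynomial evaluation with no extra calls to `func`. The sketch below uses cubic Hermite interpolation, a simpler swap-in chosen for illustration; it is not necessarily the interpolant torchdiffeq uses for each solver.

```python
def hermite_interp(t0, y0, f0, t1, y1, f1, t):
    """Cubic Hermite interpolant through (t0, y0) and (t1, y1) with slopes f0 and f1."""
    h = t1 - t0
    s = (t - t0) / h                     # normalized position in [0, 1]
    h00 = 2 * s**3 - 3 * s**2 + 1        # basis weights for y0, slope f0, y1, slope f1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return h00 * y0 + h10 * h * f0 + h01 * y1 + h11 * h * f1


# y(t) = t**2 has slope 2t; a cubic interpolant reproduces it exactly.
print(hermite_interp(0.0, 0.0, 0.0, 1.0, 1.0, 2.0, 0.5))  # 0.25
```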
60 |
61 | **What non-linearities should I use in my Neural ODE?**
62 | Avoid non-smooth non-linearities such as ReLU and LeakyReLU.
63 | Prefer non-linearities with a theoretically unique adjoint/gradient such as Softplus.
64 |
65 | **Where is backpropagation for the Neural ODE defined?**
66 | It's defined [here](https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py) if you use the adjoint method `odeint_adjoint`.
67 |
68 | **What are Tableaus?**
69 | Tableaus are a way to describe the coefficients of [RK methods](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods). The particular set of coefficients used in this repo was taken from [here](https://www.ams.org/journals/mcom/1986-46-173/S0025-5718-1986-0815836-3/).
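As an example of how a tableau is used, the sketch below takes one step of a generic explicit RK method from its coefficients `C`, `A`, `B`, plus an embedded lower-order row `B_LOW` whose difference from `B` gives the error estimate. The coefficients are the Bogacki-Shampine 3(2) pair, the family behind `bosh3`; the helper `explicit_rk_step` is hypothetical and works on scalars only.

```python
import math

# Bogacki-Shampine 3(2) tableau.
C = [0.0, 1 / 2, 3 / 4, 1.0]
A = [[], [1 / 2], [0.0, 3 / 4], [2 / 9, 1 / 3, 4 / 9]]
B = [2 / 9, 1 / 3, 4 / 9, 0.0]            # 3rd-order solution weights
B_LOW = [7 / 24, 1 / 4, 1 / 3, 1 / 8]     # embedded 2nd-order weights


def explicit_rk_step(f, t, y, h):
    """One explicit RK step; returns (new state, error estimate)."""
    k = []
    for i in range(len(C)):
        # Stage i only uses the slopes k[0..i-1] computed so far.
        y_stage = y + h * sum(a_ij * k_j for a_ij, k_j in zip(A[i], k))
        k.append(f(t + C[i] * h, y_stage))
    y_new = y + h * sum(b_i * k_i for b_i, k_i in zip(B, k))
    error = h * sum((b_i - bl_i) * k_i for b_i, bl_i, k_i in zip(B, B_LOW, k))
    return y_new, error


y1, err = explicit_rk_step(lambda t, y: y, 0.0, 1.0, 0.1)   # dy/dt = y
print(abs(y1 - math.exp(0.1)), abs(err))
```

Because the last row of `A` equals `B`, the final stage is the slope at the new state, so it can be reused as the first stage of the next step (the FSAL, first-same-as-last, property).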
70 |
71 | **How do I install the repo on Windows?**
72 | Try downloading the code directly and running `python setup.py install`.
73 | https://stackoverflow.com/questions/52528955/installing-a-python-module-from-github-in-windows-10
74 |
75 | **What is the most memory-expensive operation during training?**
76 | The most memory-expensive operation is the single [backward call](https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py#L75) made to the network.
77 |
78 | **My Neural ODE's numerical solution is farther away from the target than the initial value**
79 | Most tricks for initializing residual nets (like zeroing the weights of the last layer) should help for ODEs as well. Zeroing the last layer makes the initial dynamics (close to) zero, so the ODE's solution map starts out as the identity.
80 |
81 |
82 | **My Neural ODE takes too long to train**
83 | This might be because you're running on CPU. Being extremely slow on CPU is expected, as training requires evaluating a neural net multiple times.
84 |
85 |
86 | **My Neural ODE produces underflow in dt when using adaptive solvers like `dopri5`**
87 | This usually means the ODE has become stiff: it behaves too erratically in some region, so the step size shrinks so close to zero that the solver can make no progress. We were able to avoid this with regularization such as weight decay and by using "nice" activation functions, but YMMV. Other options are to accept a larger error by increasing `atol` and `rtol`, or to switch to a fixed-step solver.
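To see why stiffness forces tiny steps, consider the textbook test problem `dy/dt = -lam * y` with a large `lam`. Explicit Euler multiplies the state by `1 - h * lam` at every step, so it only decays (as the true solution does) when `h < 2 / lam`; an adaptive explicit solver is held to roughly the same bound, because larger steps yield large error estimates and get rejected. This plain-Python sketch illustrates the effect and is not torchdiffeq code.

```python
def euler(lam, y0, h, num_steps):
    """Explicit Euler on dy/dt = -lam * y: each step multiplies y by (1 - h * lam)."""
    y = y0
    for _ in range(num_steps):
        y = y + h * (-lam * y)
    return y


lam = 1000.0                                       # stiff: very fast-decaying mode
print(euler(lam, 1.0, h=0.01, num_steps=100))      # |1 - 10| = 9 > 1: blows up
print(euler(lam, 1.0, h=0.0005, num_steps=2000))   # |1 - 0.5| < 1: decays toward 0
```

Both runs integrate to `t = 1`, but only the much smaller step size stays stable, even though the true solution is essentially zero for most of the interval.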
--------------------------------------------------------------------------------
/FURTHER_DOCUMENTATION.md:
--------------------------------------------------------------------------------
1 | # Further documentation
2 |
3 | ## Solver options
4 |
5 | Adaptive and fixed solvers all support several options. Also shown are their default values.
6 |
7 | **Adaptive solvers (dopri8, dopri5, bosh3, adaptive_heun):**
8 | For these solvers, `rtol` and `atol` correspond to the tolerances for accepting/rejecting an adaptive step.
9 |
10 | - `first_step=None`: What size the first step of the solver should be; by default this is selected empirically.
11 |
12 | - `safety=0.9, ifactor=10.0, dfactor=0.2`: How the next optimal step size is calculated; see E. Hairer, S. P. Norsett, G. Wanner, *Solving Ordinary Differential Equations I: Nonstiff Problems*, Sec. II.4. Roughly speaking, `safety` is a factor slightly below 1 that makes each proposed step size a little more conservative, `ifactor` is the most that the step size can grow by, and `dfactor` is the most that it can shrink by.
13 |
14 | - `max_num_steps=2 ** 31 - 1`: The maximum number of steps the solver is allowed to take.
15 |
16 | - `dtype=torch.float64`: what dtype to use for timelike quantities. Setting this to `torch.float32` will improve speed but may produce underflow errors more easily.
17 |
18 | - `step_t=None`: Times that a step must be made to. In particular this is useful when `func` has kinks (derivative discontinuities) at these times, as the solver then does not need to (slowly) discover these for itself. If passed, this should be a `torch.Tensor`.
19 |
20 | - `jump_t=None`: Times that a step must be made to, and `func` re-evaluated at. In particular this is useful when `func` has discontinuities at these times, as then the solver knows that the final function evaluation of the previous step is not equal to the first function evaluation of this step (i.e. the FSAL property does not hold at this point). If passed, this should be a `torch.Tensor`. Note that this may not be efficient when using PyTorch 1.6.0 or earlier.
21 |
22 | - `norm`: What norm to compute the accept/reject criterion with respect to. Given tensor input, this defaults to an RMS norm. Given tupled input, this defaults to computing an RMS norm over each tensor, and then taking a max over the tuple, producing a mixed L-infinity/RMS norm. If passed, this should be a function that consumes a tensor/tuple with the same shape as `y0` and returns a scalar corresponding to its norm. When passed as part of `adjoint_options`, then the special value `"seminorm"` may be used to zero out the contribution from the parameters, as per the ["Hey, that's not an ODE"](https://arxiv.org/abs/2009.09457) paper.
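The step-size update controlled by `safety`, `ifactor` and `dfactor` can be sketched with the standard controller from Hairer et al.: scale the last step by `safety * error_ratio ** (-1 / (q + 1))`, where `error_ratio` is the normed error divided by the tolerance (a step is accepted when it is at most 1) and `q` is the lower order of the embedded pair, then clamp the factor to `[dfactor, ifactor]`. The function below is a hypothetical plain-Python version; torchdiffeq's own implementation may differ in details such as the exponent it uses.

```python
def next_step_size(last_step, error_ratio, order, safety=0.9, ifactor=10.0, dfactor=0.2):
    """Propose the next step size from the current error ratio.

    error_ratio = norm(error estimate / tolerance); the step is accepted if <= 1.
    order is q, the lower order of the embedded pair (e.g. 4 for a 4(5) pair).
    """
    if error_ratio == 0:
        return last_step * ifactor               # no measurable error: grow maximally
    factor = safety * error_ratio ** (-1.0 / (order + 1))
    factor = min(ifactor, max(dfactor, factor))  # never grow or shrink too fast
    return last_step * factor


print(next_step_size(0.1, error_ratio=1.0, order=4))    # 0.1 * safety, about 0.09
print(next_step_size(0.1, error_ratio=1e10, order=4))   # clamped by dfactor, about 0.02
print(next_step_size(0.1, error_ratio=1e-10, order=4))  # clamped by ifactor: 1.0
```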
23 |
24 | **Fixed solvers (euler, midpoint, rk4, explicit_adams, implicit_adams):**
25 |
26 | - `step_size=None`: How large each discrete step should be. If not passed then this defaults to stepping between the values of `t`. Note that if using `t` just to specify the start and end of the regions of integration, then it is very important to specify this argument! It is mutually exclusive with the `grid_constructor` argument, below.
27 |
28 | - `grid_constructor=None`: A more fine-grained way of setting the steps, by setting these particular locations as the locations of the steps. Should be a callable `func, y0, t -> grid`, transforming the arguments `func, y0, t` of `odeint` into the desired grid (which should be a one dimensional tensor).
29 |
30 | - `perturb`: Defaults to False. If True, then automatically add small perturbations to the start and end of each step, so that stepping to discontinuities works. Note that this may not be efficient when using PyTorch 1.6.0 or earlier.
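How a `step_size` turns into a grid, and how outputs at the requested `t` are then read off, can be sketched as follows. This is a hypothetical plain-Python version using explicit Euler and linear interpolation between grid points; the names `make_grid` and `euler_on_grid` are illustrative, and torchdiffeq's own grid construction and interpolation may differ.

```python
import math


def make_grid(t_start, t_end, step_size):
    """Uniform grid from t_start with spacing step_size; the last point is t_end."""
    num_steps = max(1, math.ceil((t_end - t_start) / step_size - 1e-12))
    return [t_start + i * step_size for i in range(num_steps)] + [t_end]


def euler_on_grid(f, y0, t_out, step_size):
    """Explicit Euler on the grid; values at the (increasing) times t_out are interpolated."""
    grid = make_grid(t_out[0], t_out[-1], step_size)
    out, j, y = [y0], 1, y0
    for t0, t1 in zip(grid[:-1], grid[1:]):
        y_next = y + (t1 - t0) * f(t0, y)
        # Emit every requested time that falls inside this step.
        while j < len(t_out) and t_out[j] <= t1:
            s = (t_out[j] - t0) / (t1 - t0)
            out.append(y + s * (y_next - y))
            j += 1
        y = y_next
    return out


print(make_grid(0.0, 1.0, 0.3))  # [0.0, 0.3, 0.6, 0.8999999999999999, 1.0]
ys = euler_on_grid(lambda t, y: 1.0, 0.0, [0.0, 0.25, 1.0], step_size=0.1)
print(ys)                        # approximately [0.0, 0.25, 1.0]
```

The requested output times never change the grid here, which is why `step_size` matters so much when `t` only contains the start and end points.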
31 |
32 | Individual solvers also offer certain options.
33 |
34 | **explicit_adams:**
35 | For this solver, `rtol` and `atol` are ignored. This solver also supports:
36 |
37 | - `max_order`: The maximum order of the Adams-Bashforth predictor.
38 |
39 | **implicit_adams:**
40 | For this solver, `rtol` and `atol` correspond to the tolerance for convergence of the Adams-Moulton corrector. This solver also supports:
41 |
42 | - `max_order`: The maximum order of the Adams-Bashforth-Moulton predictor-corrector.
43 |
44 | - `max_iters`: The maximum number of iterations to run the Adams-Moulton corrector for.
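For intuition about the Adams family, the sketch below implements the two-step Adams-Bashforth predictor, which reuses the slope from the previous step instead of evaluating `func` at new intermediate points, so it needs only one new evaluation per step. It is a hypothetical fixed second-order version bootstrapped with one Heun step; torchdiffeq's `explicit_adams` supports higher orders via `max_order` and may be implemented differently.

```python
import math


def adams_bashforth2(f, y0, t0, t1, num_steps):
    """Two-step Adams-Bashforth: y[n+1] = y[n] + h * (3/2 * f[n] - 1/2 * f[n-1])."""
    h = (t1 - t0) / num_steps
    # Bootstrap the first step with Heun's method (also second order).
    f_prev = f(t0, y0)
    y_pred = y0 + h * f_prev
    y = y0 + h / 2 * (f_prev + f(t0 + h, y_pred))
    t = t0 + h
    for _ in range(num_steps - 1):
        f_curr = f(t, y)                          # one new evaluation per step
        y = y + h * (1.5 * f_curr - 0.5 * f_prev)
        f_prev = f_curr
        t = t + h
    return y


approx = adams_bashforth2(lambda t, y: y, 1.0, 0.0, 1.0, num_steps=100)  # dy/dt = y
print(abs(approx - math.e))  # second-order error, on the order of 1e-4
```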
45 |
46 | **scipy_solver:**
47 | - `solver`: which SciPy solver to use; corresponds to the `'method'` argument of `scipy.integrate.solve_ivp`.
48 |
49 | ## Adjoint options
50 |
51 | The function `odeint_adjoint` offers some adjoint-specific options.
52 |
53 | - `adjoint_rtol`, `adjoint_atol`, `adjoint_method`, `adjoint_options`: The `rtol, atol, method, options` to use for the backward pass. Defaults to the values used for the forward pass.
54 |
55 | - `adjoint_options` has the special key-value pair `{"norm": "seminorm"}` that provides a potentially more efficient adjoint solve when using adaptive step solvers, as described in the ["Hey, that's not an ODE"](https://arxiv.org/abs/2009.09457) paper.
56 |
57 | - `adjoint_params`: The parameters to compute gradients with respect to in the backward pass. Should be a tuple of tensors. Defaults to `tuple(func.parameters())`.
58 |   - If passed, then `func` does not have to be a `torch.nn.Module`.
59 |   - If `func` has no parameters, `adjoint_params=()` must be specified.
60 |
61 |
62 | ## Callbacks
63 |
64 | Callbacks can be triggered during the solve. Callbacks should be specified as methods of the `func` argument to `odeint` and `odeint_adjoint`.
65 |
66 | At the moment support for this is minimal: let us know if you'd find additional callbacks useful.
67 |
68 | **callback_step(self, t0, y0, dt):**
69 | This is called immediately before taking a step of size `dt`, at time `t0`, with current solution value `y0`. This is supported by every solver except `scipy_solver`.
70 |
71 | **callback_accept_step(self, t0, y0, dt):**
72 | This is called when accepting a step of size `dt` at time `t0`, with current solution value `y0`. This is supported by the adaptive solvers (dopri8, dopri5, bosh3, adaptive_heun).
73 |
74 | **callback_reject_step(self, t0, y0, dt):**
75 | As `callback_accept_step`, except called when rejecting steps.
76 |
77 | In addition, callbacks can be triggered during the adjoint pass by adding `_adjoint` to the name of any one of the supported callbacks, e.g. `callback_step_adjoint`.
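The callback protocol amounts to the solver looking up optional methods on `func`. The sketch below shows a dynamics object that defines `callback_step` with the documented signature, plus a hypothetical toy Euler loop (`toy_euler`, not torchdiffeq's solver) that calls it immediately before each step, as described above.

```python
class Dynamics:
    """dy/dt = -y, recording every step the solver announces."""

    def __init__(self):
        self.steps = []

    def __call__(self, t, y):
        return -y

    def callback_step(self, t0, y0, dt):
        # Called immediately before a step of size dt from time t0.
        self.steps.append((t0, y0, dt))


def toy_euler(func, y0, t0, t1, num_steps):
    """Hypothetical fixed-step Euler loop that honours an optional callback_step."""
    callback = getattr(func, "callback_step", None)
    dt = (t1 - t0) / num_steps
    t, y = t0, y0
    for _ in range(num_steps):
        if callback is not None:
            callback(t, y, dt)
        y = y + dt * func(t, y)
        t = t + dt
    return y


func = Dynamics()
toy_euler(func, 1.0, 0.0, 1.0, num_steps=4)
print(len(func.steps))  # 4: one callback per step
print(func.steps[0])    # (0.0, 1.0, 0.25)
```

With torchdiffeq itself, `func` would typically be a `torch.nn.Module` whose `forward` computes the dynamics, with the callback methods defined alongside it.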
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Ricky Tian Qi Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch Implementation of Differentiable ODE Solvers
2 |
3 | This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. For usage of ODE solvers in deep learning applications, see reference [1].
4 |
5 | As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU.
6 |
7 | ## Installation
8 |
9 | To install latest stable version:
10 | ```
11 | pip install torchdiffeq
12 | ```
13 |
14 | To install latest on GitHub:
15 | ```
16 | pip install git+https://github.com/rtqichen/torchdiffeq
17 | ```
18 |
19 | ## Examples
20 | Examples are placed in the [`examples`](./examples) directory.
21 |
22 | We encourage those who are interested in using this library to take a look at [`examples/ode_demo.py`](./examples/ode_demo.py) for understanding how to use `torchdiffeq` to fit a simple spiral ODE.
23 |
24 |
25 |
26 |
27 |
28 | ## Basic usage
29 | This library provides one main interface `odeint` which contains general-purpose algorithms for solving initial value problems (IVP), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value,
30 | ```
31 | dy/dt = f(t, y) y(t_0) = y_0.
32 | ```
33 | The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition.
34 |
35 | To solve an IVP using the default solver:
36 | ```
37 | from torchdiffeq import odeint
38 |
39 | odeint(func, y0, t)
40 | ```
41 | where `func` is any callable implementing the ordinary differential equation `f(t, y)`, `y0` is an _any_-D Tensor representing the initial values, and `t` is a 1-D Tensor containing the evaluation points. The initial time is taken to be `t[0]`.
42 |
43 | Backpropagation through `odeint` goes through the internals of the solver. Note that this is not numerically stable for all solvers (but should probably be fine with the default `dopri5` method). Instead, we encourage the use of the adjoint method explained in [1], which will allow solving with as many steps as necessary due to O(1) memory usage.
44 |
45 | To use the adjoint method:
46 | ```
47 | from torchdiffeq import odeint_adjoint as odeint
48 |
49 | odeint(func, y0, t)
50 | ```
51 | `odeint_adjoint` simply wraps around `odeint`, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call.
52 |
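53 | The biggest **gotcha** is that `func` must be a `nn.Module` when using the adjoint method, unless the parameters are passed explicitly via `adjoint_params` (see the [further documentation](FURTHER_DOCUMENTATION.md)). The module is used to collect the parameters of the differential equation.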
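To make the adjoint method concrete, the sketch below computes `dL/dtheta` for the scalar ODE `dy/dt = theta * y` with loss `L = y(T)`. Only the final state is kept from the forward solve; the backward pass integrates `y` (reconstructed backwards), the adjoint `a = dL/dy`, and the accumulated parameter gradient `g` from `T` back to `0`, following the augmented system described in [1]. It uses a hand-written RK4 step in plain Python and is a sketch of the idea, not torchdiffeq's implementation.

```python
import math


def rk4_step(f, t, s, h):
    """Classical RK4 step for a state given as a list of floats."""
    k1 = f(t, s)
    k2 = f(t + h / 2, [si + h / 2 * ki for si, ki in zip(s, k1)])
    k3 = f(t + h / 2, [si + h / 2 * ki for si, ki in zip(s, k2)])
    k4 = f(t + h, [si + h * ki for si, ki in zip(s, k3)])
    return [si + h / 6 * (a + 2 * b + 2 * c + d)
            for si, a, b, c, d in zip(s, k1, k2, k3, k4)]


def solve(f, s0, t0, t1, num_steps):
    """Integrate from t0 to t1 (t1 < t0 is allowed) keeping only the final state."""
    h = (t1 - t0) / num_steps
    s, t = s0, t0
    for _ in range(num_steps):
        s, t = rk4_step(f, t, s, h), t + h
    return s


def grad_theta_adjoint(theta, y0, T, num_steps=100):
    # Forward pass: dy/dt = theta * y; only y(T) is stored.
    (yT,) = solve(lambda t, s: [theta * s[0]], [y0], 0.0, T, num_steps)

    # Augmented backward dynamics for (y, a, g) with a = dL/dy and L = y(T):
    #   dy/dt = theta * y,  da/dt = -a * df/dy = -theta * a,  dg/dt = -a * df/dtheta = -a * y
    def augmented(t, s):
        y, a, g = s
        return [theta * y, -theta * a, -a * y]

    # Start from a(T) = dL/dy(T) = 1 and g(T) = 0, then integrate back to t = 0.
    _, _, g0 = solve(augmented, [yT, 1.0, 0.0], T, 0.0, num_steps)
    return yT, g0


yT, grad = grad_theta_adjoint(theta=0.5, y0=2.0, T=1.0)
exact = 1.0 * 2.0 * math.exp(0.5)   # dL/dtheta = T * y0 * exp(theta * T)
print(grad, exact)
```

Memory stays constant in the number of steps because nothing from the forward pass except `y(T)` is kept; the price is a second solve of a larger system in the backward pass.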
54 |
55 | ## Differentiable event handling
56 |
57 | We allow terminating an ODE solution based on an event function. Backpropagation through most solvers is supported. For usage of event handling in deep learning applications, see reference [2].
58 |
59 | This can be invoked with `odeint_event`:
60 | ```
61 | from torchdiffeq import odeint_event
62 | odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs)
63 | ```
64 | - `func` and `y0` are the same as `odeint`.
65 | - `t0` is a scalar representing the initial time value.
66 | - `event_fn(t, y)` returns a tensor, and is a required keyword argument.
67 | - `reverse_time` is a boolean specifying whether we should solve in reverse time. Default is `False`.
68 | - `odeint_interface` is one of `odeint` or `odeint_adjoint`, specifying whether adjoint mode should be used for differentiating through the ODE solution. Default is `odeint`.
69 | - `**kwargs`: any remaining keyword arguments are passed to `odeint_interface`.
70 |
71 | The solve is terminated at an event time `t` and state `y` when an element of `event_fn(t, y)` is equal to zero. Multiple outputs from `event_fn` can be used to specify multiple event functions, of which the first to trigger will terminate the solve.
72 |
73 | Both the event time and final state are returned from `odeint_event`, and can be differentiated. Gradients will be backpropagated through the event function. **NOTE**: parameters for the event function must be in the state itself to obtain gradients.
74 |
75 | The numerical precision for the event time is determined by the `atol` argument.
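The root-finding idea behind event handling can be sketched without a solver. The bouncing ball from `examples/bouncing_ball.py` (starting height 10, zero velocity, gravity 9.8, radius 0.2) has the closed-form trajectory `pos(t) = 10 - 4.9 t^2`, and the event fires when `pos - radius` crosses zero. The hypothetical helper `find_event` scans for the first sign change of any event output and then bisects until the bracket is narrower than `atol`. Using a closed-form trajectory in place of the solver's own solution, and bisection as the root finder, are assumptions made to keep the example short.

```python
import math


def find_event(event_fn, t0, t_max, scan_dt=0.01, atol=1e-9):
    """Return (t_event, index) for the earliest zero crossing of any event_fn output."""
    t_lo, v_lo = t0, event_fn(t0)
    while t_lo < t_max:
        t_hi = min(t_lo + scan_dt, t_max)
        v_hi = event_fn(t_hi)
        # Outputs whose sign changed (or hit zero) within [t_lo, t_hi].
        crossed = [i for i, (a, b) in enumerate(zip(v_lo, v_hi)) if a * b <= 0]
        if crossed:
            candidates = []
            for i in crossed:
                lo, hi = t_lo, t_hi
                while hi - lo > atol:                 # bisect down to precision atol
                    mid = 0.5 * (lo + hi)
                    if event_fn(lo)[i] * event_fn(mid)[i] <= 0:
                        hi = mid
                    else:
                        lo = mid
                candidates.append((hi, i))
            return min(candidates)                    # the first event to trigger wins
        t_lo, v_lo = t_hi, v_hi
    return None


def position(t):
    return 10.0 - 0.5 * 9.8 * t**2   # closed-form free fall from height 10


radius = 0.2
t_hit, idx = find_event(lambda t: [position(t) - radius], 0.0, 5.0)
print(t_hit, math.sqrt(2.0))   # the ball reaches the ground at t = sqrt(2)
```

Returning several outputs from the event function, e.g. `[position(t) - radius, position(t) - 5.0]`, makes whichever crossing happens first terminate the search, matching the multiple-event behaviour described above.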
76 |
77 | See example of simulating and differentiating through a bouncing ball in [`examples/bouncing_ball.py`](./examples/bouncing_ball.py). See example code for learning a simple event function in [`examples/learn_physics.py`](./examples/learn_physics.py).
78 |
79 |
80 |
81 |
82 |
83 | ## Keyword arguments for odeint(_adjoint)
84 |
85 | #### Keyword arguments:
86 | - `rtol` Relative tolerance.
87 | - `atol` Absolute tolerance.
88 | - `method` One of the solvers listed below.
89 | - `options` A dictionary of solver-specific options, see the [further documentation](FURTHER_DOCUMENTATION.md).
90 |
91 | #### List of ODE Solvers:
92 |
93 | Adaptive-step:
94 | - `dopri8` Runge-Kutta of order 8 of Dormand-Prince-Shampine.
95 | - `dopri5` Runge-Kutta of order 5 of Dormand-Prince-Shampine **[default]**.
96 | - `bosh3` Runge-Kutta of order 3 of Bogacki-Shampine.
97 | - `fehlberg2` Runge-Kutta-Fehlberg of order 2.
98 | - `adaptive_heun` Runge-Kutta of order 2.
99 |
100 | Fixed-step:
101 | - `euler` Euler method.
102 | - `midpoint` Midpoint method.
103 | - `rk4` Fourth-order Runge-Kutta with 3/8 rule.
104 | - `explicit_adams` Explicit Adams-Bashforth.
105 | - `implicit_adams` Implicit Adams-Bashforth-Moulton.
106 |
107 | Additionally, all solvers available through SciPy are wrapped for use with `scipy_solver`.
108 |
109 | For most problems, good choices are the default `dopri5`, or to use `rk4` with `options=dict(step_size=...)` set appropriately small. Adjusting the tolerances (adaptive solvers) or the step size (fixed solvers) allows trade-offs between speed and accuracy.
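For reference, one step of the `rk4` method listed above (the 3/8-rule variant of fourth-order Runge-Kutta) looks like this in plain Python. It is a textbook sketch; torchdiffeq's own fixed-grid implementation operates on tensors and may differ in details.

```python
import math


def rk4_38_step(f, t, y, h):
    """One step of the fourth-order Runge-Kutta 3/8 rule."""
    k1 = f(t, y)
    k2 = f(t + h / 3, y + h * k1 / 3)
    k3 = f(t + 2 * h / 3, y + h * (-k1 / 3 + k2))
    k4 = f(t + h, y + h * (k1 - k2 + k3))
    return y + h * (k1 + 3 * k2 + 3 * k3 + k4) / 8


def solve_fixed(f, y0, t0, t1, step_size):
    """Fixed-step integration; assumes (t1 - t0) is a multiple of step_size."""
    num_steps = round((t1 - t0) / step_size)
    t, y = t0, y0
    for _ in range(num_steps):
        y = rk4_38_step(f, t, y, step_size)
        t += step_size
    return y


y1 = solve_fixed(lambda t, y: y, 1.0, 0.0, 1.0, step_size=0.1)   # dy/dt = y, y(1) = e
print(abs(y1 - math.e))  # fourth-order accurate: roughly 2e-6 here
```

Halving `step_size` should cut the error by roughly a factor of 16, which is the speed/accuracy trade-off the paragraph above refers to.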
110 |
111 | ## Frequently Asked Questions
112 | Take a look at our [FAQ](FAQ.md) for frequently asked questions.
113 |
114 | ## Further documentation
115 | For details of the adjoint-specific and solver-specific options, check out the [further documentation](FURTHER_DOCUMENTATION.md).
116 |
117 | ## References
118 |
119 | Applications of differentiable ODE solvers and event handling are discussed in these two papers:
120 |
121 | Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Information Processing Systems.* 2018. [[arxiv]](https://arxiv.org/abs/1806.07366)
122 |
123 | ```
124 | @article{chen2018neuralode,
125 | title={Neural Ordinary Differential Equations},
126 | author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David},
127 | journal={Advances in Neural Information Processing Systems},
128 | year={2018}
129 | }
130 | ```
131 |
132 | Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." *International Conference on Learning Representations.* 2021. [[arxiv]](https://arxiv.org/abs/2011.03902)
133 |
134 | ```
135 | @article{chen2021eventfn,
136 | title={Learning Neural Event Functions for Ordinary Differential Equations},
137 | author={Chen, Ricky T. Q. and Amos, Brandon and Nickel, Maximilian},
138 | journal={International Conference on Learning Representations},
139 | year={2021}
140 | }
141 | ```
142 |
143 | The seminorm option for computing adjoints is discussed in
144 |
145 | Patrick Kidger, Ricky T. Q. Chen, Terry Lyons. "'Hey, that’s not an ODE': Faster ODE Adjoints via Seminorms." *International Conference on Machine
146 | Learning.* 2021. [[arxiv]](https://arxiv.org/abs/2009.09457)
147 | ```
148 | @article{kidger2021hey,
149 | title={"Hey, that's not an ODE": Faster ODE Adjoints via Seminorms.},
150 | author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry J.},
151 | journal={International Conference on Machine Learning},
152 | year={2021}
153 | }
154 | ```
155 |
156 | ---
157 |
158 | If you found this library useful in your research, please consider citing.
159 | ```
160 | @misc{torchdiffeq,
161 | author={Chen, Ricky T. Q.},
162 | title={torchdiffeq},
163 | year={2018},
164 | url={https://github.com/rtqichen/torchdiffeq},
165 | }
166 | ```
167 |
--------------------------------------------------------------------------------
/assets/bouncing_ball.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/bouncing_ball.png
--------------------------------------------------------------------------------
/assets/cnf_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/cnf_demo.gif
--------------------------------------------------------------------------------
/assets/ode_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/ode_demo.gif
--------------------------------------------------------------------------------
/assets/odenet_0_viz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/odenet_0_viz.png
--------------------------------------------------------------------------------
/assets/resnet_0_viz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/resnet_0_viz.png
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Overview of Examples
2 |
3 | This `examples` directory contains cleaned up code regarding the usage of adaptive ODE solvers in machine learning. The scripts in this directory assume that `torchdiffeq` is installed following instructions from the main directory.
4 |
5 | ## Demo
6 | The `ode_demo.py` file contains a short implementation of learning a dynamics model to mimic a spiral ODE.
7 |
8 | To visualize the training progress, run
9 | ```
10 | python ode_demo.py --viz
11 | ```
12 | The training should look similar to this:
13 |
14 |
15 | ![ODE demo training progress](../assets/ode_demo.gif)
16 |
17 |
18 | ## ODEnet for MNIST
19 | The `odenet_mnist.py` file contains a reproduction of the MNIST experiments in our Neural ODE paper. Notably not just the architecture but the ODE solver library and integration method are different from our original experiments, though the results are similar to those we report in the paper.
20 |
21 | We can use an adaptive ODE solver to approximate our continuous-depth network while still backpropagating through the network.
22 | ```
23 | python odenet_mnist.py --network odenet
24 | ```
25 | However, the memory requirements for this will blow up very fast, especially for more complex problems where the number of function evaluations can reach nearly a thousand.
26 |
27 | For applications that require solving complex trajectories, we recommend using the adjoint method.
28 | ```
29 | python odenet_mnist.py --network odenet --adjoint True
30 | ```
31 | The adjoint method can be slower when using an adaptive ODE solver as it involves another solve in the backward pass with a much larger system, so experimenting on small systems with direct backpropagation first is recommended.
32 |
33 | Thankfully, it is extremely easy to write code for both adjoint and non-adjoint backpropagation, as they use the same interface.
34 | ```
35 | if adjoint:
36 |     from torchdiffeq import odeint_adjoint as odeint
37 | else:
38 |     from torchdiffeq import odeint
39 | ```
40 | The main gotcha is that `odeint_adjoint` requires implementing the dynamics network as a `nn.Module` while `odeint` can work with any callable in Python.
41 |
42 | ## Continuous Normalizing Flows
43 |
44 | The `cnf.py` file contains a simple CNF implementation for learning the density of a concentric circles dataset.
45 |
46 | To train a CNF and visualise the resulting dynamics, run
47 | ```
48 | python cnf.py --viz
49 | ```
50 | The result should look similar to this:
51 |
52 |
53 | ![CNF demo](../assets/cnf_demo.gif)
54 |
55 |
56 | More comprehensive code for continuous normalizing flows (CNFs) has its own public repository. Tools for training, evaluating, and visualizing CNFs for reversible generative modeling are provided along with FFJORD, a linear cost stochastic approximation of CNFs.
57 |
58 | Find the code in https://github.com/rtqichen/ffjord. This code contains some advanced tricks for `torchdiffeq`.
59 |
--------------------------------------------------------------------------------
/examples/bouncing_ball.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import matplotlib.pyplot as plt
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from torchdiffeq import odeint, odeint_adjoint
9 | from torchdiffeq import odeint_event
10 |
11 | torch.set_default_dtype(torch.float64)
12 |
13 |
14 | class BouncingBallExample(nn.Module):
15 | def __init__(self, radius=0.2, gravity=9.8, adjoint=False):
16 | super().__init__()
17 | self.gravity = nn.Parameter(torch.as_tensor([gravity]))
18 | self.log_radius = nn.Parameter(torch.log(torch.as_tensor([radius])))
19 | self.t0 = nn.Parameter(torch.tensor([0.0]))
20 | self.init_pos = nn.Parameter(torch.tensor([10.0]))
21 | self.init_vel = nn.Parameter(torch.tensor([0.0]))
22 | self.absorption = nn.Parameter(torch.tensor([0.2]))
23 | self.odeint = odeint_adjoint if adjoint else odeint
24 |
25 | def forward(self, t, state):
26 | pos, vel, log_radius = state
27 | dpos = vel
28 | dvel = -self.gravity
29 | return dpos, dvel, torch.zeros_like(log_radius)
30 |
31 | def event_fn(self, t, state):
32 | # positive if ball in mid-air, negative if ball within ground.
33 | pos, _, log_radius = state
34 | return pos - torch.exp(log_radius)
35 |
36 | def get_initial_state(self):
37 | state = (self.init_pos, self.init_vel, self.log_radius)
38 | return self.t0, state
39 |
40 | def state_update(self, state):
41 | """Updates state based on an event (collision)."""
42 | pos, vel, log_radius = state
43 | pos = (
44 | pos + 1e-7
45 | ) # need to add a small eps so as not to trigger the event function immediately.
46 | vel = -vel * (1 - self.absorption)
47 | return (pos, vel, log_radius)
48 |
49 | def get_collision_times(self, nbounces=1):
50 |
51 | event_times = []
52 |
53 | t0, state = self.get_initial_state()
54 |
55 | for i in range(nbounces):
56 | event_t, solution = odeint_event(
57 | self,
58 | state,
59 | t0,
60 | event_fn=self.event_fn,
61 | reverse_time=False,
62 | atol=1e-8,
63 | rtol=1e-8,
64 | odeint_interface=self.odeint,
65 | )
66 | event_times.append(event_t)
67 |
68 | state = self.state_update(tuple(s[-1] for s in solution))
69 | t0 = event_t
70 |
71 | return event_times
72 |
73 | def simulate(self, nbounces=1):
74 | event_times = self.get_collision_times(nbounces)
75 |
76 | # get dense path
77 | t0, state = self.get_initial_state()
78 | trajectory = [state[0][None]]
79 | velocity = [state[1][None]]
80 | times = [t0.reshape(-1)]
81 | for event_t in event_times:
82 | tt = torch.linspace(
83 | float(t0), float(event_t), int((float(event_t) - float(t0)) * 50)
84 | )[1:-1]
85 | tt = torch.cat([t0.reshape(-1), tt, event_t.reshape(-1)])
86 | solution = odeint(self, state, tt, atol=1e-8, rtol=1e-8)
87 |
88 | trajectory.append(solution[0][1:])
89 | velocity.append(solution[1][1:])
90 | times.append(tt[1:])
91 |
92 | state = self.state_update(tuple(s[-1] for s in solution))
93 | t0 = event_t
94 |
95 | return (
96 | torch.cat(times),
97 | torch.cat(trajectory, dim=0).reshape(-1),
98 | torch.cat(velocity, dim=0).reshape(-1),
99 | event_times,
100 | )
101 |
102 |
103 | def gradcheck(nbounces):
104 |
105 | system = BouncingBallExample()
106 |
107 | variables = {
108 | "init_pos": system.init_pos,
109 | "init_vel": system.init_vel,
110 | "t0": system.t0,
111 | "gravity": system.gravity,
112 | "log_radius": system.log_radius,
113 | }
114 |
115 | event_t = system.get_collision_times(nbounces)[-1]
116 | event_t.backward()
117 |
118 | analytical_grads = {}
119 | for name, p in system.named_parameters():
120 | for var in variables.keys():
121 | if var in name:
122 | analytical_grads[var] = p.grad
123 |
124 | eps = 1e-3
125 |
126 | fd_grads = {}
127 |
128 | for var, param in variables.items():
129 | orig = param.data
130 | param.data = orig - eps
131 | f_meps = system.get_collision_times(nbounces)[-1]
132 | param.data = orig + eps
133 | f_peps = system.get_collision_times(nbounces)[-1]
134 | param.data = orig
135 | fd = (f_peps - f_meps) / (2 * eps)
136 | fd_grads[var] = fd
137 |
138 | success = True
139 | for var in variables.keys():
140 | analytical = analytical_grads[var]
141 | fd = fd_grads[var]
142 | if torch.norm(analytical - fd) > 1e-4:
143 | success = False
144 | print(
145 | f"Got analytical grad {analytical.item()} for {var} param but finite difference is {fd.item()}"
146 | )
147 |
148 | if not success:
149 | raise Exception("Gradient check failed.")
150 |
151 | print("Gradient check passed.")
152 |
153 |
154 | if __name__ == "__main__":
155 |
156 | parser = argparse.ArgumentParser(description="Bouncing ball: event-handling gradient check and simulation.")
157 | parser.add_argument("nbounces", type=int, nargs="?", default=10)
158 | parser.add_argument("--adjoint", action="store_true")
159 | args = parser.parse_args()
160 |
161 | gradcheck(args.nbounces)
162 |
163 | system = BouncingBallExample()
164 | times, trajectory, velocity, event_times = system.simulate(nbounces=args.nbounces)
165 | times = times.detach().cpu().numpy()
166 | trajectory = trajectory.detach().cpu().numpy()
167 | velocity = velocity.detach().cpu().numpy()
168 | event_times = torch.stack(event_times).detach().cpu().numpy()
169 |
170 | plt.figure(figsize=(7, 3.5))
171 |
172 | # Event locations.
173 | for event_t in event_times:
174 | plt.plot(
175 | event_t,
176 | 0.0,
177 | color="C0",
178 | marker="o",
179 | markersize=7,
180 | fillstyle="none",
181 | linestyle="",
182 | )
183 |
184 | (vel,) = plt.plot(
185 | times, velocity, color="C1", alpha=0.7, linestyle="--", linewidth=2.0
186 | )
187 | (pos,) = plt.plot(times, trajectory, color="C0", linewidth=2.0)
188 |
189 | plt.hlines(0, 0, 100)
190 | plt.xlim([times[0], times[-1]])
191 | plt.ylim([velocity.min() - 0.02, velocity.max() + 0.02])
192 | plt.ylabel("Markov State", fontsize=16)
193 | plt.xlabel("Time", fontsize=13)
194 | plt.legend([pos, vel], ["Position", "Velocity"], fontsize=16)
195 |
196 | plt.gca().xaxis.set_tick_params(
197 | direction="in", which="both"
198 | ) # Draw x-axis ticks (major and minor) inside the plot area.
199 | plt.gca().yaxis.set_tick_params(
200 | direction="in", which="both"
201 | ) # Draw y-axis ticks (major and minor) inside the plot area.
202 |
203 | # Hide the right and top spines
204 | plt.gca().spines["right"].set_visible(False)
205 | plt.gca().spines["top"].set_visible(False)
206 |
207 | # Only show ticks on the left and bottom spines
208 | plt.gca().yaxis.set_ticks_position("left")
209 | plt.gca().xaxis.set_ticks_position("bottom")
210 |
211 | plt.tight_layout()
212 | plt.savefig("bouncing_ball.png")
213 |
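The `gradcheck` function above compares the autograd gradient of a collision time with a central finite difference, `(f(x + eps) - f(x - eps)) / (2 * eps)`. The same check can be sketched without torchdiffeq on a case that has a closed form. Assumptions (not taken from `BouncingBallExample`, whose constructor is not shown here): a point ball released from rest at height `h` under constant gravity `g`, with the first collision when the position reaches zero, so `t = sqrt(2h/g)`. The helper names are hypothetical.

```python
import math


def first_collision_time(h, g):
    """Hypothetical closed form: time for a point ball released from rest at height h to reach 0."""
    return math.sqrt(2.0 * h / g)


def analytic_grads(h, g):
    """Exact partial derivatives of t = sqrt(2h/g) with respect to h and g."""
    t = first_collision_time(h, g)
    return {"h": t / (2.0 * h), "g": -t / (2.0 * g)}


def finite_difference_grads(h, g, eps=1e-3):
    """Central differences, mirroring the eps-perturbation loop in gradcheck."""
    params = {"h": h, "g": g}
    grads = {}
    for name in params:
        plus = dict(params)
        minus = dict(params)
        plus[name] += eps
        minus[name] -= eps
        f_peps = first_collision_time(plus["h"], plus["g"])
        f_meps = first_collision_time(minus["h"], minus["g"])
        grads[name] = (f_peps - f_meps) / (2.0 * eps)
    return grads


def check(h=10.0, g=9.8, tol=1e-4):
    """Return True if analytic and finite-difference gradients agree within tol."""
    exact = analytic_grads(h, g)
    approx = finite_difference_grads(h, g)
    ok = True
    for name in exact:
        if abs(exact[name] - approx[name]) > tol:
            ok = False
            print(f"{name}: analytic {exact[name]:.6f} vs finite difference {approx[name]:.6f}")
    return ok


if __name__ == "__main__":
    print("Gradient check passed." if check() else "Gradient check failed.")
```

Central differences have O(eps^2) truncation error, which is why a fairly large `eps` such as `1e-3` still gives agreement well inside the `1e-4` tolerance used by `gradcheck`.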
--------------------------------------------------------------------------------
/examples/cnf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import glob
4 | from PIL import Image
5 | import numpy as np
6 | import matplotlib
7 | matplotlib.use('agg')
8 | import matplotlib.pyplot as plt
9 | from sklearn.datasets import make_circles
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--adjoint', action='store_true')
17 | parser.add_argument('--viz', action='store_true')
18 | parser.add_argument('--niters', type=int, default=1000)
19 | parser.add_argument('--lr', type=float, default=1e-3)
20 | parser.add_argument('--num_samples', type=int, default=512)
21 | parser.add_argument('--width', type=int, default=64)
22 | parser.add_argument('--hidden_dim', type=int, default=32)
23 | parser.add_argument('--gpu', type=int, default=0)
24 | parser.add_argument('--train_dir', type=str, default=None)
25 | parser.add_argument('--results_dir', type=str, default="./results")
26 | args = parser.parse_args()
27 |
28 | if args.adjoint:
29 | from torchdiffeq import odeint_adjoint as odeint
30 | else:
31 | from torchdiffeq import odeint
32 |
33 |
34 | class CNF(nn.Module):
35 | """Adapted from the NumPy implementation at:
36 | https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52
37 | """
38 | def __init__(self, in_out_dim, hidden_dim, width):
39 | super().__init__()
40 | self.in_out_dim = in_out_dim
41 | self.hidden_dim = hidden_dim
42 | self.width = width
43 | self.hyper_net = HyperNetwork(in_out_dim, hidden_dim, width)
44 |
45 | def forward(self, t, states):
46 | z = states[0]
47 | logp_z = states[1]
48 |
49 | batchsize = z.shape[0]
50 |
51 | with torch.set_grad_enabled(True):
52 | z.requires_grad_(True)
53 |
54 | W, B, U = self.hyper_net(t)
55 |
56 | Z = torch.unsqueeze(z, 0).repeat(self.width, 1, 1)
57 |
58 | h = torch.tanh(torch.matmul(Z, W) + B)
59 | dz_dt = torch.matmul(h, U).mean(0)
60 |
61 | dlogp_z_dt = -trace_df_dz(dz_dt, z).view(batchsize, 1)
62 |
63 | return (dz_dt, dlogp_z_dt)
64 |
65 |
66 | def trace_df_dz(f, z):
67 | """Calculates the trace of the Jacobian df/dz.
68 | Adapted from: https://github.com/rtqichen/ffjord/blob/master/lib/layers/odefunc.py#L13
69 | """
70 | sum_diag = 0.
71 | for i in range(z.shape[1]):
72 | sum_diag += torch.autograd.grad(f[:, i].sum(), z, create_graph=True)[0].contiguous()[:, i].contiguous()
73 |
74 | return sum_diag.contiguous()
75 |
76 |
77 | class HyperNetwork(nn.Module):
78 | """Hyper-network allowing f(z(t), t) to change with time.
79 |
80 | Adapted from the NumPy implementation at:
81 | https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52
82 | """
83 | def __init__(self, in_out_dim, hidden_dim, width):
84 | super().__init__()
85 |
86 | blocksize = width * in_out_dim
87 |
88 | self.fc1 = nn.Linear(1, hidden_dim)
89 | self.fc2 = nn.Linear(hidden_dim, hidden_dim)
90 | self.fc3 = nn.Linear(hidden_dim, 3 * blocksize + width)
91 |
92 | self.in_out_dim = in_out_dim
93 | self.hidden_dim = hidden_dim
94 | self.width = width
95 | self.blocksize = blocksize
96 |
97 | def forward(self, t):
98 | # predict params
99 | params = t.reshape(1, 1)
100 | params = torch.tanh(self.fc1(params))
101 | params = torch.tanh(self.fc2(params))
102 | params = self.fc3(params)
103 |
104 | # restructure
105 | params = params.reshape(-1)
106 | W = params[:self.blocksize].reshape(self.width, self.in_out_dim, 1)
107 |
108 | U = params[self.blocksize:2 * self.blocksize].reshape(self.width, 1, self.in_out_dim)
109 |
110 | G = params[2 * self.blocksize:3 * self.blocksize].reshape(self.width, 1, self.in_out_dim)
111 | U = U * torch.sigmoid(G)
112 |
113 | B = params[3 * self.blocksize:].reshape(self.width, 1, 1)
114 | return [W, B, U]
115 |
116 |
117 | class RunningAverageMeter(object):
118 | """Computes and stores the average and current value"""
119 |
120 | def __init__(self, momentum=0.99):
121 | self.momentum = momentum
122 | self.reset()
123 |
124 | def reset(self):
125 | self.val = None
126 | self.avg = 0
127 |
128 | def update(self, val):
129 | if self.val is None:
130 | self.avg = val
131 | else:
132 | self.avg = self.avg * self.momentum + val * (1 - self.momentum)
133 | self.val = val
134 |
135 |
136 | def get_batch(num_samples):
137 | points, _ = make_circles(n_samples=num_samples, noise=0.06, factor=0.5)
138 | x = torch.tensor(points).type(torch.float32).to(device)
139 | logp_diff_t1 = torch.zeros(num_samples, 1).type(torch.float32).to(device)
140 |
141 | return(x, logp_diff_t1)
142 |
143 |
144 | if __name__ == '__main__':
145 | t0 = 0
146 | t1 = 10
147 | device = torch.device('cuda:' + str(args.gpu)
148 | if torch.cuda.is_available() else 'cpu')
149 |
150 | # model
151 | func = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width).to(device)
152 | optimizer = optim.Adam(func.parameters(), lr=args.lr)
153 | p_z0 = torch.distributions.MultivariateNormal(
154 | loc=torch.tensor([0.0, 0.0]).to(device),
155 | covariance_matrix=torch.tensor([[0.1, 0.0], [0.0, 0.1]]).to(device)
156 | )
157 | loss_meter = RunningAverageMeter()
158 |
159 | if args.train_dir is not None:
160 | if not os.path.exists(args.train_dir):
161 | os.makedirs(args.train_dir)
162 | ckpt_path = os.path.join(args.train_dir, 'ckpt.pth')
163 | if os.path.exists(ckpt_path):
164 | checkpoint = torch.load(ckpt_path)
165 | func.load_state_dict(checkpoint['func_state_dict'])
166 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
167 | print('Loaded ckpt from {}'.format(ckpt_path))
168 |
169 | try:
170 | for itr in range(1, args.niters + 1):
171 | optimizer.zero_grad()
172 |
173 | x, logp_diff_t1 = get_batch(args.num_samples)
174 |
175 | z_t, logp_diff_t = odeint(
176 | func,
177 | (x, logp_diff_t1),
178 | torch.tensor([t1, t0]).type(torch.float32).to(device),
179 | atol=1e-5,
180 | rtol=1e-5,
181 | method='dopri5',
182 | )
183 |
184 | z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]
185 |
186 | logp_x = p_z0.log_prob(z_t0).to(device) - logp_diff_t0.view(-1)
187 | loss = -logp_x.mean(0)
188 |
189 | loss.backward()
190 | optimizer.step()
191 |
192 | loss_meter.update(loss.item())
193 |
194 | print('Iter: {}, running avg loss: {:.4f}'.format(itr, loss_meter.avg))
195 |
196 | except KeyboardInterrupt:
197 | if args.train_dir is not None:
198 | ckpt_path = os.path.join(args.train_dir, 'ckpt.pth')
199 | torch.save({
200 | 'func_state_dict': func.state_dict(),
201 | 'optimizer_state_dict': optimizer.state_dict(),
202 | }, ckpt_path)
203 | print('Stored ckpt at {}'.format(ckpt_path))
204 | print('Training complete after {} iters.'.format(itr))
205 |
206 | if args.viz:
207 | viz_samples = 30000
208 | viz_timesteps = 41
209 | target_sample, _ = get_batch(viz_samples)
210 |
211 | if not os.path.exists(args.results_dir):
212 | os.makedirs(args.results_dir)
213 | with torch.no_grad():
214 | # Generate evolution of samples
215 | z_t0 = p_z0.sample([viz_samples]).to(device)
216 | logp_diff_t0 = torch.zeros(viz_samples, 1).type(torch.float32).to(device)
217 |
218 | z_t_samples, _ = odeint(
219 | func,
220 | (z_t0, logp_diff_t0),
221 | torch.tensor(np.linspace(t0, t1, viz_timesteps)).to(device),
222 | atol=1e-5,
223 | rtol=1e-5,
224 | method='dopri5',
225 | )
226 |
227 | # Generate evolution of density
228 | x = np.linspace(-1.5, 1.5, 100)
229 | y = np.linspace(-1.5, 1.5, 100)
230 | points = np.vstack(np.meshgrid(x, y)).reshape([2, -1]).T
231 |
232 | z_t1 = torch.tensor(points).type(torch.float32).to(device)
233 | logp_diff_t1 = torch.zeros(z_t1.shape[0], 1).type(torch.float32).to(device)
234 |
235 | z_t_density, logp_diff_t = odeint(
236 | func,
237 | (z_t1, logp_diff_t1),
238 | torch.tensor(np.linspace(t1, t0, viz_timesteps)).to(device),
239 | atol=1e-5,
240 | rtol=1e-5,
241 | method='dopri5',
242 | )
243 |
244 | # Create plots for each timestep
245 | for (t, z_sample, z_density, logp_diff) in zip(
246 | np.linspace(t0, t1, viz_timesteps),
247 | z_t_samples, z_t_density, logp_diff_t
248 | ):
249 | fig = plt.figure(figsize=(12, 4), dpi=200)
250 | plt.tight_layout()
251 | plt.axis('off')
252 | plt.margins(0, 0)
253 | fig.suptitle(f'{t:.2f}s')
254 |
255 | ax1 = fig.add_subplot(1, 3, 1)
256 | ax1.set_title('Target')
257 | ax1.get_xaxis().set_ticks([])
258 | ax1.get_yaxis().set_ticks([])
259 | ax2 = fig.add_subplot(1, 3, 2)
260 | ax2.set_title('Samples')
261 | ax2.get_xaxis().set_ticks([])
262 | ax2.get_yaxis().set_ticks([])
263 | ax3 = fig.add_subplot(1, 3, 3)
264 | ax3.set_title('Log Probability')
265 | ax3.get_xaxis().set_ticks([])
266 | ax3.get_yaxis().set_ticks([])
267 |
268 | ax1.hist2d(*target_sample.detach().cpu().numpy().T, bins=300, density=True,
269 | range=[[-1.5, 1.5], [-1.5, 1.5]])
270 |
271 | ax2.hist2d(*z_sample.detach().cpu().numpy().T, bins=300, density=True,
272 | range=[[-1.5, 1.5], [-1.5, 1.5]])
273 |
274 | logp = p_z0.log_prob(z_density) - logp_diff.view(-1)
275 | ax3.tricontourf(*z_t1.detach().cpu().numpy().T,
276 | np.exp(logp.detach().cpu().numpy()), 200)
277 |
278 | plt.savefig(os.path.join(args.results_dir, f"cnf-viz-{int(t*1000):05d}.jpg"),
279 | pad_inches=0.2, bbox_inches='tight')
280 | plt.close()
281 |
282 | img, *imgs = [Image.open(f) for f in sorted(glob.glob(os.path.join(args.results_dir, "cnf-viz-*.jpg")))]
283 | img.save(fp=os.path.join(args.results_dir, "cnf-viz.gif"), format='GIF', append_images=imgs,
284 | save_all=True, duration=250, loop=0)
285 |
286 | print('Saved visualization animation at {}'.format(os.path.join(args.results_dir, "cnf-viz.gif")))
287 |
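`trace_df_dz` builds the Jacobian trace one diagonal entry at a time: for each dimension `i` it differentiates `f[:, i]` with respect to `z` and keeps column `i`. `CNF.forward` then returns `-trace` as the rate of change of the log-density (the instantaneous change of variables). The sketch below mirrors that per-dimension loop in plain Python, but swaps autograd for central finite differences, so it is an approximation for illustration only. The vector field `example_f` and the helper names are hypothetical.

```python
import math


def example_f(z):
    """Hypothetical 2-D vector field f(z); its Jacobian diagonal is (2*z0, z0*cos(z1))."""
    z0, z1 = z
    return [z0 ** 2 + z1, z0 * math.sin(z1)]


def trace_jacobian_fd(f, z, h=1e-5):
    """Approximate tr(df/dz) by summing d f_i / d z_i, one dimension at a time."""
    total = 0.0
    for i in range(len(z)):
        z_plus = list(z)
        z_minus = list(z)
        z_plus[i] += h
        z_minus[i] -= h
        # Only the i-th output component is needed for the i-th diagonal entry.
        total += (f(z_plus)[i] - f(z_minus)[i]) / (2.0 * h)
    return total


def dlogp_dt(f, z):
    """Instantaneous change of variables: d log p(z(t)) / dt = -tr(df/dz)."""
    return -trace_jacobian_fd(f, z)


z = [1.0, 0.5]
print(trace_jacobian_fd(example_f, z))  # close to 2 + cos(0.5)
print(dlogp_dt(example_f, z))
```

As with the autograd version, the cost grows linearly with the dimension of `z`, since each diagonal entry needs its own pass.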
--------------------------------------------------------------------------------
/examples/latent_ode.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | import time
5 | import numpy as np
6 | import numpy.random as npr
7 | import matplotlib
8 | matplotlib.use('agg')
9 | import matplotlib.pyplot as plt
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | import torch.nn.functional as F
14 |
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--adjoint', type=eval, default=False)
18 | parser.add_argument('--visualize', type=eval, default=False)
19 | parser.add_argument('--niters', type=int, default=2000)
20 | parser.add_argument('--lr', type=float, default=0.01)
21 | parser.add_argument('--gpu', type=int, default=0)
22 | parser.add_argument('--train_dir', type=str, default=None)
23 | args = parser.parse_args()
24 |
25 | if args.adjoint:
26 | from torchdiffeq import odeint_adjoint as odeint
27 | else:
28 | from torchdiffeq import odeint
29 |
30 |
31 | def generate_spiral2d(nspiral=1000,
32 | ntotal=500,
33 | nsample=100,
34 | start=0.,
35 | stop=1, # __main__ passes stop=6 * np.pi instead
36 | noise_std=.1,
37 | a=0.,
38 | b=1.,
39 | savefig=True):
40 | """Parametric formula for 2d spiral is `r = a + b * theta`.
41 |
42 | Args:
43 | nspiral: number of spirals, i.e. batch dimension
44 | ntotal: total number of datapoints per spiral
45 | nsample: number of sampled datapoints for model fitting per spiral
46 | start: spiral starting theta value
47 | stop: spiral ending theta value
48 | noise_std: observation noise standard deviation
49 | a, b: parameters of the Archimedean spiral
50 | savefig: plot the ground truth for sanity check
51 |
52 | Returns:
53 | Tuple where first element is true trajectory of size (nspiral, ntotal, 2),
54 | second element is noisy observations of size (nspiral, nsample, 2),
55 | third element is timestamps of size (ntotal,),
56 | and fourth element is timestamps of size (nsample,)
57 | """
58 |
59 | # the clockwise spiral below adds 1 to all timestamps to avoid division by 0
60 | orig_ts = np.linspace(start, stop, num=ntotal)
61 | samp_ts = orig_ts[:nsample]
62 |
63 | # generate clock-wise and counter clock-wise spirals in observation space
64 | # with two sets of time-invariant latent dynamics
65 | zs_cw = stop + 1. - orig_ts
66 | rs_cw = a + b * 50. / zs_cw
67 | xs, ys = rs_cw * np.cos(zs_cw) - 5., rs_cw * np.sin(zs_cw)
68 | orig_traj_cw = np.stack((xs, ys), axis=1)
69 |
70 | zs_cc = orig_ts
71 | rw_cc = a + b * zs_cc
72 | xs, ys = rw_cc * np.cos(zs_cc) + 5., rw_cc * np.sin(zs_cc)
73 | orig_traj_cc = np.stack((xs, ys), axis=1)
74 |
75 | if savefig:
76 | plt.figure()
77 | plt.plot(orig_traj_cw[:, 0], orig_traj_cw[:, 1], label='clock')
78 | plt.plot(orig_traj_cc[:, 0], orig_traj_cc[:, 1], label='counter clock')
79 | plt.legend()
80 | plt.savefig('./ground_truth.png', dpi=500)
81 | print('Saved ground truth spiral at {}'.format('./ground_truth.png'))
82 |
83 | # sample starting timestamps
84 | orig_trajs = []
85 | samp_trajs = []
86 | for _ in range(nspiral):
87 | # don't sample t0 very near the start or the end
88 | t0_idx = npr.multinomial(
89 | 1, [1. / (ntotal - 2. * nsample)] * (ntotal - int(2 * nsample)))
90 | t0_idx = np.argmax(t0_idx) + nsample
91 |
92 | cc = bool(npr.rand() > .5) # uniformly select rotation
93 | orig_traj = orig_traj_cc if cc else orig_traj_cw
94 | orig_trajs.append(orig_traj)
95 |
96 | samp_traj = orig_traj[t0_idx:t0_idx + nsample, :].copy()
97 | samp_traj += npr.randn(*samp_traj.shape) * noise_std
98 | samp_trajs.append(samp_traj)
99 |
100 | # batching for sample trajectories is good for RNN; batching for original
101 | # trajectories only for ease of indexing
102 | orig_trajs = np.stack(orig_trajs, axis=0)
103 | samp_trajs = np.stack(samp_trajs, axis=0)
104 |
105 | return orig_trajs, samp_trajs, orig_ts, samp_ts
106 |
107 |
108 | class LatentODEfunc(nn.Module):
109 |
110 | def __init__(self, latent_dim=4, nhidden=20):
111 | super(LatentODEfunc, self).__init__()
112 | self.elu = nn.ELU(inplace=True)
113 | self.fc1 = nn.Linear(latent_dim, nhidden)
114 | self.fc2 = nn.Linear(nhidden, nhidden)
115 | self.fc3 = nn.Linear(nhidden, latent_dim)
116 | self.nfe = 0
117 |
118 | def forward(self, t, x):
119 | self.nfe += 1
120 | out = self.fc1(x)
121 | out = self.elu(out)
122 | out = self.fc2(out)
123 | out = self.elu(out)
124 | out = self.fc3(out)
125 | return out
126 |
127 |
128 | class RecognitionRNN(nn.Module):
129 |
130 | def __init__(self, latent_dim=4, obs_dim=2, nhidden=25, nbatch=1):
131 | super(RecognitionRNN, self).__init__()
132 | self.nhidden = nhidden
133 | self.nbatch = nbatch
134 | self.i2h = nn.Linear(obs_dim + nhidden, nhidden)
135 | self.h2o = nn.Linear(nhidden, latent_dim * 2)
136 |
137 | def forward(self, x, h):
138 | combined = torch.cat((x, h), dim=1)
139 | h = torch.tanh(self.i2h(combined))
140 | out = self.h2o(h)
141 | return out, h
142 |
143 | def initHidden(self):
144 | return torch.zeros(self.nbatch, self.nhidden)
145 |
146 |
147 | class Decoder(nn.Module):
148 |
149 | def __init__(self, latent_dim=4, obs_dim=2, nhidden=20):
150 | super(Decoder, self).__init__()
151 | self.relu = nn.ReLU(inplace=True)
152 | self.fc1 = nn.Linear(latent_dim, nhidden)
153 | self.fc2 = nn.Linear(nhidden, obs_dim)
154 |
155 | def forward(self, z):
156 | out = self.fc1(z)
157 | out = self.relu(out)
158 | out = self.fc2(out)
159 | return out
160 |
161 |
162 | class RunningAverageMeter(object):
163 | """Computes and stores the average and current value"""
164 |
165 | def __init__(self, momentum=0.99):
166 | self.momentum = momentum
167 | self.reset()
168 |
169 | def reset(self):
170 | self.val = None
171 | self.avg = 0
172 |
173 | def update(self, val):
174 | if self.val is None:
175 | self.avg = val
176 | else:
177 | self.avg = self.avg * self.momentum + val * (1 - self.momentum)
178 | self.val = val
179 |
180 |
181 | def log_normal_pdf(x, mean, logvar):
182 | const = torch.from_numpy(np.array([2. * np.pi])).float().to(x.device)
183 | const = torch.log(const)
184 | return -.5 * (const + logvar + (x - mean) ** 2. / torch.exp(logvar))
185 |
186 |
187 | def normal_kl(mu1, lv1, mu2, lv2):
188 | v1 = torch.exp(lv1)
189 | v2 = torch.exp(lv2)
190 | lstd1 = lv1 / 2.
191 | lstd2 = lv2 / 2.
192 |
193 | kl = lstd2 - lstd1 + ((v1 + (mu1 - mu2) ** 2.) / (2. * v2)) - .5
194 | return kl
195 |
196 |
197 | if __name__ == '__main__':
198 | latent_dim = 4
199 | nhidden = 20
200 | rnn_nhidden = 25
201 | obs_dim = 2
202 | nspiral = 1000
203 | start = 0.
204 | stop = 6 * np.pi
205 | noise_std = .3
206 | a = 0.
207 | b = .3
208 | ntotal = 1000
209 | nsample = 100
210 | device = torch.device('cuda:' + str(args.gpu)
211 | if torch.cuda.is_available() else 'cpu')
212 |
213 | # generate toy spiral data
214 | orig_trajs, samp_trajs, orig_ts, samp_ts = generate_spiral2d(
215 | nspiral=nspiral,
216 | start=start,
217 | stop=stop,
218 | noise_std=noise_std,
219 | a=a, b=b
220 | )
221 | orig_trajs = torch.from_numpy(orig_trajs).float().to(device)
222 | samp_trajs = torch.from_numpy(samp_trajs).float().to(device)
223 | samp_ts = torch.from_numpy(samp_ts).float().to(device)
224 |
225 | # model
226 | func = LatentODEfunc(latent_dim, nhidden).to(device)
227 | rec = RecognitionRNN(latent_dim, obs_dim, rnn_nhidden, nspiral).to(device)
228 | dec = Decoder(latent_dim, obs_dim, nhidden).to(device)
229 | params = (list(func.parameters()) + list(dec.parameters()) + list(rec.parameters()))
230 | optimizer = optim.Adam(params, lr=args.lr)
231 | loss_meter = RunningAverageMeter()
232 |
233 | if args.train_dir is not None:
234 | if not os.path.exists(args.train_dir):
235 | os.makedirs(args.train_dir)
236 | ckpt_path = os.path.join(args.train_dir, 'ckpt.pth')
237 | if os.path.exists(ckpt_path):
238 | checkpoint = torch.load(ckpt_path)
239 | func.load_state_dict(checkpoint['func_state_dict'])
240 | rec.load_state_dict(checkpoint['rec_state_dict'])
241 | dec.load_state_dict(checkpoint['dec_state_dict'])
242 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
243 | orig_trajs = checkpoint['orig_trajs']
244 | samp_trajs = checkpoint['samp_trajs']
245 | orig_ts = checkpoint['orig_ts']
246 | samp_ts = checkpoint['samp_ts']
247 | print('Loaded ckpt from {}'.format(ckpt_path))
248 |
249 | try:
250 | for itr in range(1, args.niters + 1):
251 | optimizer.zero_grad()
252 | # backward in time to infer q(z_0)
253 | h = rec.initHidden().to(device)
254 | for t in reversed(range(samp_trajs.size(1))):
255 | obs = samp_trajs[:, t, :]
256 | out, h = rec.forward(obs, h)
257 | qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
258 | epsilon = torch.randn(qz0_mean.size()).to(device)
259 | z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
260 |
261 | # forward in time and solve ode for reconstructions
262 | pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
263 | pred_x = dec(pred_z)
264 |
265 | # compute loss
266 | noise_std_ = torch.zeros(pred_x.size()).to(device) + noise_std
267 | noise_logvar = 2. * torch.log(noise_std_).to(device)
268 | logpx = log_normal_pdf(
269 | samp_trajs, pred_x, noise_logvar).sum(-1).sum(-1)
270 | pz0_mean = pz0_logvar = torch.zeros(z0.size()).to(device)
271 | analytic_kl = normal_kl(qz0_mean, qz0_logvar,
272 | pz0_mean, pz0_logvar).sum(-1)
273 | loss = torch.mean(-logpx + analytic_kl, dim=0)
274 | loss.backward()
275 | optimizer.step()
276 | loss_meter.update(loss.item())
277 |
278 | print('Iter: {}, running avg elbo: {:.4f}'.format(itr, -loss_meter.avg))
279 |
280 | except KeyboardInterrupt:
281 | if args.train_dir is not None:
282 | ckpt_path = os.path.join(args.train_dir, 'ckpt.pth')
283 | torch.save({
284 | 'func_state_dict': func.state_dict(),
285 | 'rec_state_dict': rec.state_dict(),
286 | 'dec_state_dict': dec.state_dict(),
287 | 'optimizer_state_dict': optimizer.state_dict(),
288 | 'orig_trajs': orig_trajs,
289 | 'samp_trajs': samp_trajs,
290 | 'orig_ts': orig_ts,
291 | 'samp_ts': samp_ts,
292 | }, ckpt_path)
293 | print('Stored ckpt at {}'.format(ckpt_path))
294 | print('Training complete after {} iters.'.format(itr))
295 |
296 | if args.visualize:
297 | with torch.no_grad():
298 | # sample from the trajectories' approx. posterior
299 | h = rec.initHidden().to(device)
300 | for t in reversed(range(samp_trajs.size(1))):
301 | obs = samp_trajs[:, t, :]
302 | out, h = rec.forward(obs, h)
303 | qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
304 | epsilon = torch.randn(qz0_mean.size()).to(device)
305 | z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
306 | orig_ts = torch.from_numpy(orig_ts).float().to(device)
307 |
308 | # take first trajectory for visualization
309 | z0 = z0[0]
310 |
311 | ts_pos = np.linspace(0., 2. * np.pi, num=2000)
312 | ts_neg = np.linspace(-np.pi, 0., num=2000)[::-1].copy()
313 | ts_pos = torch.from_numpy(ts_pos).float().to(device)
314 | ts_neg = torch.from_numpy(ts_neg).float().to(device)
315 |
316 | zs_pos = odeint(func, z0, ts_pos)
317 | zs_neg = odeint(func, z0, ts_neg)
318 |
319 | xs_pos = dec(zs_pos)
320 | xs_neg = torch.flip(dec(zs_neg), dims=[0])
321 |
322 | xs_pos = xs_pos.cpu().numpy()
323 | xs_neg = xs_neg.cpu().numpy()
324 | orig_traj = orig_trajs[0].cpu().numpy()
325 | samp_traj = samp_trajs[0].cpu().numpy()
326 |
327 | plt.figure()
328 | plt.plot(orig_traj[:, 0], orig_traj[:, 1],
329 | 'g', label='true trajectory')
330 | plt.plot(xs_pos[:, 0], xs_pos[:, 1], 'r',
331 | label='learned trajectory (t>0)')
332 | plt.plot(xs_neg[:, 0], xs_neg[:, 1], 'c',
333 | label='learned trajectory (t<0)')
334 | plt.scatter(samp_traj[:, 0], samp_traj[
335 | :, 1], label='sampled data', s=3)
336 | plt.legend()
337 | plt.savefig('./vis.png', dpi=500)
338 | print('Saved visualization figure at {}'.format('./vis.png'))
339 |
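The training loss combines `log_normal_pdf` (the Gaussian log-likelihood of the noisy observations) with `normal_kl`, the closed-form KL divergence between two diagonal Gaussians parameterised by mean and log-variance: `KL = log(sigma2) - log(sigma1) + (sigma1^2 + (mu1 - mu2)^2) / (2 * sigma2^2) - 1/2`. A scalar, pure-Python sketch of both functions (an assumption: plain floats instead of tensors) lets the closed form be checked against a direct numerical integral of `q(x) * (log q(x) - log p(x))`. `numerical_kl` is a hypothetical helper added only for this check.

```python
import math


def log_normal_pdf(x, mean, logvar):
    """Scalar version of log N(x; mean, exp(logvar))."""
    return -0.5 * (math.log(2.0 * math.pi) + logvar + (x - mean) ** 2 / math.exp(logvar))


def normal_kl(mu1, lv1, mu2, lv2):
    """Closed-form KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2)))."""
    v1, v2 = math.exp(lv1), math.exp(lv2)
    lstd1, lstd2 = lv1 / 2.0, lv2 / 2.0
    return lstd2 - lstd1 + (v1 + (mu1 - mu2) ** 2) / (2.0 * v2) - 0.5


def numerical_kl(mu1, lv1, mu2, lv2, n=20000, width=10.0):
    """Midpoint-rule integral of q(x) * (log q(x) - log p(x)) over mu1 +/- width * sd1."""
    sd1 = math.exp(0.5 * lv1)
    start = mu1 - width * sd1
    dx = 2.0 * width * sd1 / n
    total = 0.0
    for k in range(n):
        x = start + (k + 0.5) * dx
        log_q = log_normal_pdf(x, mu1, lv1)
        log_p = log_normal_pdf(x, mu2, lv2)
        total += math.exp(log_q) * (log_q - log_p) * dx
    return total


print(normal_kl(1.0, 0.0, 0.0, 0.0))  # 0.5 for N(1, 1) vs N(0, 1)
print(normal_kl(0.5, math.log(0.25), -0.3, math.log(2.0)))
print(numerical_kl(0.5, math.log(0.25), -0.3, math.log(2.0)))
```

The last two printed values should agree closely, which confirms that `normal_kl` computes KL(q || p) with `q` as the first argument, matching how the script passes the approximate posterior `(qz0_mean, qz0_logvar)` first and the standard-normal prior second.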
--------------------------------------------------------------------------------
/examples/learn_physics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import os
5 | import math
6 | import matplotlib.pyplot as plt
7 | import torch
8 | import torch.nn as nn
9 | from torchdiffeq import odeint, odeint_event
10 |
11 | from bouncing_ball import BouncingBallExample
12 |
13 |
14 | class HamiltonianDynamics(nn.Module):
15 | def __init__(self):
16 | super().__init__()
17 | self.dvel = nn.Linear(1, 1)
18 | self.scale = nn.Parameter(torch.tensor(10.0))
19 |
20 | def forward(self, t, state):
21 | pos, vel, *rest = state
22 | dpos = vel
23 | dvel = torch.tanh(self.dvel(torch.zeros_like(vel))) * self.scale
24 | return (dpos, dvel, *[torch.zeros_like(r) for r in rest])
25 |
26 |
27 | class EventFn(nn.Module):
28 | def __init__(self):
29 | super().__init__()
30 | self.radius = nn.Parameter(torch.rand(1))
31 |
32 | def parameters(self):
33 | return [self.radius]
34 |
35 | def forward(self, t, state):
36 | # IMPORTANT: event computation must use variables from the state.
37 | pos, _, radius = state
38 | return pos - radius.reshape_as(pos) ** 2
39 |
40 |
41 | class InstantaneousStateChange(nn.Module):
42 | def __init__(self):
43 | super().__init__()
44 | self.net = nn.Linear(1, 1)
45 |
46 | def forward(self, t, state):
47 | pos, vel, *rest = state
48 | vel = -torch.sigmoid(self.net(torch.ones_like(vel))) * vel
49 | return (pos, vel, *rest)
50 |
51 |
52 | class NeuralPhysics(nn.Module):
53 | def __init__(self):
54 | super().__init__()
55 | self.initial_pos = nn.Parameter(torch.tensor([10.0]))
56 | self.initial_vel = nn.Parameter(torch.tensor([0.0]))
57 | self.dynamics_fn = HamiltonianDynamics()
58 | self.event_fn = EventFn()
59 | self.inst_update = InstantaneousStateChange()
60 |
61 | def simulate(self, times):
62 |
63 | t0 = torch.tensor([0.0]).to(times)
64 |
65 | # Add a terminal time to the event function.
66 | def event_fn(t, state):
67 | if t > times[-1] + 1e-7:
68 | return torch.zeros_like(t)
69 | event_fval = self.event_fn(t, state)
70 | return event_fval
71 |
72 | # IMPORTANT: for gradients of odeint_event to be computed, parameters of the event function
73 | # must appear in the state in the current implementation.
74 | state = (self.initial_pos, self.initial_vel, *self.event_fn.parameters())
75 |
76 | event_times = []
77 |
78 | trajectory = [state[0][None]]
79 |
80 | n_events = 0
81 | max_events = 20
82 |
83 | while t0 < times[-1] and n_events < max_events:
84 | last = n_events == max_events - 1
85 |
86 | if not last:
87 | event_t, solution = odeint_event(
88 | self.dynamics_fn,
89 | state,
90 | t0,
91 | event_fn=event_fn,
92 | atol=1e-8,
93 | rtol=1e-8,
94 | method="dopri5",
95 | )
96 | else:
97 | event_t = times[-1]
98 |
99 | interval_ts = times[times > t0]
100 | interval_ts = interval_ts[interval_ts <= event_t]
101 | interval_ts = torch.cat([t0.reshape(-1), interval_ts.reshape(-1)])
102 |
103 | solution_ = odeint(
104 | self.dynamics_fn, state, interval_ts, atol=1e-8, rtol=1e-8
105 | )
106 | traj_ = solution_[0][1:] # [0] for position; [1:] to remove initial state.
107 | trajectory.append(traj_)
108 |
109 | if event_t < times[-1]:
110 | state = tuple(s[-1] for s in solution)
111 |
112 | # update velocity instantaneously.
113 | state = self.inst_update(event_t, state)
114 |
115 | # advance the position a little bit to avoid re-triggering the event fn.
116 | pos, *rest = state
117 | pos = pos + 1e-7 * self.dynamics_fn(event_t, state)[0]
118 | state = pos, *rest
119 |
120 | event_times.append(event_t)
121 | t0 = event_t
122 |
123 | n_events += 1
124 |
125 |
126 |
127 | trajectory = torch.cat(trajectory, dim=0).reshape(-1)
128 | return trajectory, event_times
129 |
130 |
131 | class Sine(nn.Module):
132 | def forward(self, x):
133 | return torch.sin(x)
134 |
135 |
136 | class NeuralODE(nn.Module):
137 | def __init__(self, aug_dim=2):
138 | super().__init__()
139 | self.initial_pos = nn.Parameter(torch.tensor([10.0]))
140 | self.initial_aug = nn.Parameter(torch.zeros(aug_dim))
141 | self.odefunc = mlp(
142 | input_dim=1 + aug_dim,
143 | hidden_dim=64,
144 | output_dim=1 + aug_dim,
145 | hidden_depth=2,
146 | act=Sine,
147 | )
148 |
149 | def init(m):
150 | if isinstance(m, nn.Linear):
151 | std = 1.0 / math.sqrt(m.weight.size(1))
152 | m.weight.data.uniform_(-2.0 * std, 2.0 * std)
153 | m.bias.data.zero_()
154 |
155 | self.odefunc.apply(init)
156 |
157 | def forward(self, t, state):
158 | return self.odefunc(state)
159 |
160 | def simulate(self, times):
161 | x0 = torch.cat([self.initial_pos, self.initial_aug]).reshape(-1)
162 | solution = odeint(self, x0, times, atol=1e-8, rtol=1e-8, method="dopri5")
163 | trajectory = solution[:, 0]
164 | return trajectory, []
165 |
166 |
167 | def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None, act=nn.ReLU):
168 | if hidden_depth == 0:
169 | mods = [nn.Linear(input_dim, output_dim)]
170 | else:
171 | mods = [nn.Linear(input_dim, hidden_dim), act()]
172 | for i in range(hidden_depth - 1):
173 | mods += [nn.Linear(hidden_dim, hidden_dim), act()]
174 | mods.append(nn.Linear(hidden_dim, output_dim))
175 | if output_mod is not None:
176 | mods.append(output_mod)
177 | trunk = nn.Sequential(*mods)
178 | return trunk
179 |
180 |
181 | def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0):
182 | global_step = min(global_step, decay_steps)
183 | cosine_decay = 0.5 * (1 + math.cos(math.pi * global_step / decay_steps))
184 | decayed = (1 - alpha) * cosine_decay + alpha
185 | return learning_rate * decayed
186 |
187 |
188 | def learning_rate_schedule(
189 | global_step, warmup_steps, base_learning_rate, lr_scaling, train_steps
190 | ):
191 | warmup_steps = int(round(warmup_steps))
192 | scaled_lr = base_learning_rate * lr_scaling
193 | if warmup_steps:
194 | learning_rate = global_step / warmup_steps * scaled_lr
195 | else:
196 | learning_rate = scaled_lr
197 |
198 | if global_step < warmup_steps:
199 | pass # keep the linear-warmup value computed above
200 | else:
201 | learning_rate = cosine_decay(
202 | scaled_lr, global_step - warmup_steps, train_steps - warmup_steps
203 | )
204 | return learning_rate
205 |
206 |
207 | def set_learning_rate(optimizer, lr):
208 | for group in optimizer.param_groups:
209 | group["lr"] = lr
210 |
211 |
212 | if __name__ == "__main__":
213 |
214 | parser = argparse.ArgumentParser()
215 | parser.add_argument("--base_lr", type=float, default=0.1)
216 | parser.add_argument("--num_iterations", type=int, default=1000)
217 | parser.add_argument("--no_events", action="store_true")
218 | parser.add_argument("--save", type=str, default="figs")
219 | args = parser.parse_args()
220 |
221 | torch.manual_seed(0)
222 |
223 | torch.set_default_dtype(torch.float64)
224 |
225 | with torch.no_grad():
226 | system = BouncingBallExample()
227 | obs_times, gt_trajectory, _, _ = system.simulate(nbounces=4)
228 |
229 | obs_times = obs_times[:300]
230 | gt_trajectory = gt_trajectory[:300]
231 |
232 | if args.no_events:
233 | model = NeuralODE()
234 | else:
235 | model = NeuralPhysics()
236 | optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr)
237 |
238 | decay = 1.0
239 |
240 | model.train()
241 | for itr in range(args.num_iterations):
242 | optimizer.zero_grad()
243 | trajectory, event_times = model.simulate(obs_times)
244 | weights = decay**obs_times
245 | loss = (
246 | ((trajectory - gt_trajectory) / (gt_trajectory + 1e-3))
247 | .abs()
248 | .mul(weights)
249 | .mean()
250 | )
251 | loss.backward()
252 |
253 | lr = learning_rate_schedule(itr, 0, args.base_lr, 1.0, args.num_iterations)
254 | set_learning_rate(optimizer, lr)
255 | optimizer.step()
256 |
257 | if itr % 10 == 0:
258 | print(itr, loss.item(), len(event_times))
259 |
260 | if itr % 10 == 0:
261 | plt.figure()
262 | plt.plot(
263 | obs_times.detach().cpu().numpy(),
264 | gt_trajectory.detach().cpu().numpy(),
265 | label="Target",
266 | )
267 | plt.plot(
268 | obs_times.detach().cpu().numpy(),
269 | trajectory.detach().cpu().numpy(),
270 | label="Learned",
271 | )
272 | plt.tight_layout()
273 | os.makedirs(args.save, exist_ok=True)
274 | plt.savefig(f"{args.save}/{itr:05d}.png")
275 | plt.close()
276 |
277 | if (itr + 1) % 100 == 0:
278 | torch.save(
279 | {
280 | "state_dict": model.state_dict(),
281 | },
282 | f"{args.save}/model.pt",
283 | )
284 |
285 | del trajectory, loss
286 |
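The `learning_rate_schedule` helper in `bouncing_ball.py` first ramps the rate up linearly during warmup. It then follows a cosine curve down to `alpha` times the scaled base rate. Below is a minimal stdlib-only sketch of the same arithmetic. It folds `lr_scaling` into `base_lr`, and `warmup_cosine_lr` is a hypothetical name, not something defined in the script.

```python
import math


def warmup_cosine_lr(step, warmup_steps, base_lr, total_steps, alpha=0.0):
    """Linear warmup to base_lr, then cosine decay towards alpha * base_lr.

    Follows the arithmetic of cosine_decay / learning_rate_schedule in
    bouncing_ball.py, with lr_scaling already multiplied into base_lr.
    """
    if warmup_steps and step < warmup_steps:
        # Linear ramp from 0 up to base_lr.
        return step / warmup_steps * base_lr
    decay_steps = total_steps - warmup_steps
    # Clamp so the rate stays at its floor after total_steps.
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return base_lr * ((1 - alpha) * cosine + alpha)


if __name__ == "__main__":
    # The training loop calls the schedule with warmup_steps=0 and base_lr=0.1.
    for step in (0, 250, 500, 750, 1000):
        print(step, round(warmup_cosine_lr(step, 0, 0.1, 1000), 6))
```

The training loop above uses `warmup_steps=0`, so the rate starts at `base_lr` and falls towards zero by the final iteration. Passing `alpha > 0` would instead keep a floor of `alpha * base_lr`.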
--------------------------------------------------------------------------------
/examples/ode_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import time
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 |
10 | parser = argparse.ArgumentParser('ODE demo')
11 | parser.add_argument('--method', type=str, choices=['dopri5', 'adams'], default='dopri5')
12 | parser.add_argument('--data_size', type=int, default=1000)
13 | parser.add_argument('--batch_time', type=int, default=10)
14 | parser.add_argument('--batch_size', type=int, default=20)
15 | parser.add_argument('--niters', type=int, default=2000)
16 | parser.add_argument('--test_freq', type=int, default=20)
17 | parser.add_argument('--viz', action='store_true')
18 | parser.add_argument('--gpu', type=int, default=0)
19 | parser.add_argument('--adjoint', action='store_true')
20 | args = parser.parse_args()
21 |
22 | if args.adjoint:
23 | from torchdiffeq import odeint_adjoint as odeint
24 | else:
25 | from torchdiffeq import odeint
26 |
27 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
28 |
29 | true_y0 = torch.tensor([[2., 0.]]).to(device)
30 | t = torch.linspace(0., 25., args.data_size).to(device)
31 | true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)
32 |
33 |
34 | class Lambda(nn.Module):
35 |
36 | def forward(self, t, y):
37 | return torch.mm(y**3, true_A)
38 |
39 |
40 | with torch.no_grad():
41 | true_y = odeint(Lambda(), true_y0, t, method='dopri5')
42 |
43 |
44 | def get_batch():
45 | s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), args.batch_size, replace=False))
46 | batch_y0 = true_y[s] # (M, D)
47 | batch_t = t[:args.batch_time] # (T)
48 | batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0) # (T, M, D)
49 | return batch_y0.to(device), batch_t.to(device), batch_y.to(device)
50 |
51 |
52 | def makedirs(dirname):
53 | if not os.path.exists(dirname):
54 | os.makedirs(dirname)
55 |
56 |
57 | if args.viz:
58 | makedirs('png')
59 | import matplotlib.pyplot as plt
60 | fig = plt.figure(figsize=(12, 4), facecolor='white')
61 | ax_traj = fig.add_subplot(131, frameon=False)
62 | ax_phase = fig.add_subplot(132, frameon=False)
63 | ax_vecfield = fig.add_subplot(133, frameon=False)
64 | plt.show(block=False)
65 |
66 |
67 | def visualize(true_y, pred_y, odefunc, itr):
68 |
69 | if args.viz:
70 |
71 | ax_traj.cla()
72 | ax_traj.set_title('Trajectories')
73 | ax_traj.set_xlabel('t')
74 | ax_traj.set_ylabel('x,y')
75 | ax_traj.plot(t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 0], t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 1], 'g-')
76 | ax_traj.plot(t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 0], '--', t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 1], 'b--')
77 | ax_traj.set_xlim(t.cpu().min(), t.cpu().max())
78 | ax_traj.set_ylim(-2, 2)
79 | ax_traj.legend()
80 |
81 | ax_phase.cla()
82 | ax_phase.set_title('Phase Portrait')
83 | ax_phase.set_xlabel('x')
84 | ax_phase.set_ylabel('y')
85 | ax_phase.plot(true_y.cpu().numpy()[:, 0, 0], true_y.cpu().numpy()[:, 0, 1], 'g-')
86 | ax_phase.plot(pred_y.cpu().numpy()[:, 0, 0], pred_y.cpu().numpy()[:, 0, 1], 'b--')
87 | ax_phase.set_xlim(-2, 2)
88 | ax_phase.set_ylim(-2, 2)
89 |
90 | ax_vecfield.cla()
91 | ax_vecfield.set_title('Learned Vector Field')
92 | ax_vecfield.set_xlabel('x')
93 | ax_vecfield.set_ylabel('y')
94 |
95 | y, x = np.mgrid[-2:2:21j, -2:2:21j]
96 | dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)).cpu().detach().numpy()
97 | mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
98 | dydt = (dydt / mag)
99 | dydt = dydt.reshape(21, 21, 2)
100 |
101 | ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
102 | ax_vecfield.set_xlim(-2, 2)
103 | ax_vecfield.set_ylim(-2, 2)
104 |
105 | fig.tight_layout()
106 | plt.savefig('png/{:03d}'.format(itr))
107 | plt.draw()
108 | plt.pause(0.001)
109 |
110 |
111 | class ODEFunc(nn.Module):
112 |
113 | def __init__(self):
114 | super(ODEFunc, self).__init__()
115 |
116 | self.net = nn.Sequential(
117 | nn.Linear(2, 50),
118 | nn.Tanh(),
119 | nn.Linear(50, 2),
120 | )
121 |
122 | for m in self.net.modules():
123 | if isinstance(m, nn.Linear):
124 | nn.init.normal_(m.weight, mean=0, std=0.1)
125 | nn.init.constant_(m.bias, val=0)
126 |
127 | def forward(self, t, y):
128 | return self.net(y**3)
129 |
130 |
131 | class RunningAverageMeter(object):
132 | """Computes and stores the average and current value"""
133 |
134 | def __init__(self, momentum=0.99):
135 | self.momentum = momentum
136 | self.reset()
137 |
138 | def reset(self):
139 | self.val = None
140 | self.avg = 0
141 |
142 | def update(self, val):
143 | if self.val is None:
144 | self.avg = val
145 | else:
146 | self.avg = self.avg * self.momentum + val * (1 - self.momentum)
147 | self.val = val
148 |
149 |
150 | if __name__ == '__main__':
151 |
152 | ii = 0
153 |
154 | func = ODEFunc().to(device)
155 |
156 | optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
157 | end = time.time()
158 |
159 | time_meter = RunningAverageMeter(0.97)
160 |
161 | loss_meter = RunningAverageMeter(0.97)
162 |
163 | for itr in range(1, args.niters + 1):
164 | optimizer.zero_grad()
165 | batch_y0, batch_t, batch_y = get_batch()
166 | pred_y = odeint(func, batch_y0, batch_t).to(device)
167 | loss = torch.mean(torch.abs(pred_y - batch_y))
168 | loss.backward()
169 | optimizer.step()
170 |
171 | time_meter.update(time.time() - end)
172 | loss_meter.update(loss.item())
173 |
174 | if itr % args.test_freq == 0:
175 | with torch.no_grad():
176 | pred_y = odeint(func, true_y0, t)
177 | loss = torch.mean(torch.abs(pred_y - true_y))
178 | print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
179 | visualize(true_y, pred_y, func, ii)
180 | ii += 1
181 |
182 | end = time.time()
183 |
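`get_batch` in `ode_demo.py` draws `batch_size` distinct start indices. For each one, it takes `batch_time` consecutive observations, producing `batch_y` with shape (T, M, D). Every window reuses `t[:batch_time]` as its time grid. This works because `t` is uniform and neither `Lambda` nor `ODEFunc` uses `t`, so integrating from `t[0]` covers the same elapsed time as each window's actual span. The stdlib-only sketch below reproduces the indexing, using nested lists in place of tensors as an illustrative assumption.

```python
import random


def get_batch(true_y, data_size, batch_time, batch_size, rng=random):
    """List-based mirror of ode_demo.get_batch.

    true_y holds one state per time point. Returns (batch_y0, batch_t_idx,
    batch_y), where batch_y is indexed [T][M] like the (T, M, D) tensor.
    """
    # Distinct start indices, each leaving room for batch_time observations.
    s = rng.sample(range(data_size - batch_time), batch_size)
    batch_y0 = [true_y[i] for i in s]  # (M, D)
    batch_t_idx = list(range(batch_time))  # indices of t[:batch_time]
    batch_y = [[true_y[i + k] for i in s] for k in range(batch_time)]  # (T, M, D)
    return batch_y0, batch_t_idx, batch_y


random.seed(0)
states = [(float(i), -float(i)) for i in range(1000)]  # toy 2-D states
y0, t_idx, y = get_batch(states, data_size=1000, batch_time=10, batch_size=20)
print(len(y), len(y[0]), len(y[0][0]))  # T, M, D
```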
--------------------------------------------------------------------------------
/examples/odenet_mnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | import time
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from torch.utils.data import DataLoader
9 | import torchvision.datasets as datasets
10 | import torchvision.transforms as transforms
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
14 | parser.add_argument('--tol', type=float, default=1e-3)
15 | parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
16 | parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
17 | parser.add_argument('--nepochs', type=int, default=160)
18 | parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
19 | parser.add_argument('--lr', type=float, default=0.1)
20 | parser.add_argument('--batch_size', type=int, default=128)
21 | parser.add_argument('--test_batch_size', type=int, default=1000)
22 |
23 | parser.add_argument('--save', type=str, default='./experiment1')
24 | parser.add_argument('--debug', action='store_true')
25 | parser.add_argument('--gpu', type=int, default=0)
26 | args = parser.parse_args()
27 |
28 | if args.adjoint:
29 | from torchdiffeq import odeint_adjoint as odeint
30 | else:
31 | from torchdiffeq import odeint
32 |
33 |
34 | def conv3x3(in_planes, out_planes, stride=1):
35 | """3x3 convolution with padding"""
36 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
37 |
38 |
39 | def conv1x1(in_planes, out_planes, stride=1):
40 | """1x1 convolution"""
41 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
42 |
43 |
44 | def norm(dim):
45 | return nn.GroupNorm(min(32, dim), dim)
46 |
47 |
48 | class ResBlock(nn.Module):
49 | expansion = 1
50 |
51 | def __init__(self, inplanes, planes, stride=1, downsample=None):
52 | super(ResBlock, self).__init__()
53 | self.norm1 = norm(inplanes)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.downsample = downsample
56 | self.conv1 = conv3x3(inplanes, planes, stride)
57 | self.norm2 = norm(planes)
58 | self.conv2 = conv3x3(planes, planes)
59 |
60 | def forward(self, x):
61 | shortcut = x
62 |
63 | out = self.relu(self.norm1(x))
64 |
65 | if self.downsample is not None:
66 | shortcut = self.downsample(out)
67 |
68 | out = self.conv1(out)
69 | out = self.norm2(out)
70 | out = self.relu(out)
71 | out = self.conv2(out)
72 |
73 | return out + shortcut
74 |
75 |
76 | class ConcatConv2d(nn.Module):
77 |
78 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
79 | super(ConcatConv2d, self).__init__()
80 | module = nn.ConvTranspose2d if transpose else nn.Conv2d
81 | self._layer = module(
82 | dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
83 | bias=bias
84 | )
85 |
86 | def forward(self, t, x):
87 | tt = torch.ones_like(x[:, :1, :, :]) * t
88 | ttx = torch.cat([tt, x], 1)
89 | return self._layer(ttx)
90 |
91 |
92 | class ODEfunc(nn.Module):
93 |
94 | def __init__(self, dim):
95 | super(ODEfunc, self).__init__()
96 | self.norm1 = norm(dim)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
99 | self.norm2 = norm(dim)
100 | self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
101 | self.norm3 = norm(dim)
102 | self.nfe = 0
103 |
104 | def forward(self, t, x):
105 | self.nfe += 1
106 | out = self.norm1(x)
107 | out = self.relu(out)
108 | out = self.conv1(t, out)
109 | out = self.norm2(out)
110 | out = self.relu(out)
111 | out = self.conv2(t, out)
112 | out = self.norm3(out)
113 | return out
114 |
115 |
116 | class ODEBlock(nn.Module):
117 |
118 | def __init__(self, odefunc):
119 | super(ODEBlock, self).__init__()
120 | self.odefunc = odefunc
121 | self.integration_time = torch.tensor([0, 1]).float()
122 |
123 | def forward(self, x):
124 | self.integration_time = self.integration_time.type_as(x)
125 | out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
126 | return out[1]
127 |
128 | @property
129 | def nfe(self):
130 | return self.odefunc.nfe
131 |
132 | @nfe.setter
133 | def nfe(self, value):
134 | self.odefunc.nfe = value
135 |
136 |
137 | class Flatten(nn.Module):
138 |
139 | def __init__(self):
140 | super(Flatten, self).__init__()
141 |
142 | def forward(self, x):
143 | shape = torch.prod(torch.tensor(x.shape[1:])).item()
144 | return x.view(-1, shape)
145 |
146 |
147 | class RunningAverageMeter(object):
148 | """Computes and stores the average and current value"""
149 |
150 | def __init__(self, momentum=0.99):
151 | self.momentum = momentum
152 | self.reset()
153 |
154 | def reset(self):
155 | self.val = None
156 | self.avg = 0
157 |
158 | def update(self, val):
159 | if self.val is None:
160 | self.avg = val
161 | else:
162 | self.avg = self.avg * self.momentum + val * (1 - self.momentum)
163 | self.val = val
164 |
165 |
166 | def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
167 | if data_aug:
168 | transform_train = transforms.Compose([
169 | transforms.RandomCrop(28, padding=4),
170 | transforms.ToTensor(),
171 | ])
172 | else:
173 | transform_train = transforms.Compose([
174 | transforms.ToTensor(),
175 | ])
176 |
177 | transform_test = transforms.Compose([
178 | transforms.ToTensor(),
179 | ])
180 |
181 | train_loader = DataLoader(
182 | datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
183 | shuffle=True, num_workers=2, drop_last=True
184 | )
185 |
186 | train_eval_loader = DataLoader(
187 | datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
188 | batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
189 | )
190 |
191 | test_loader = DataLoader(
192 | datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
193 | batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
194 | )
195 |
196 | return train_loader, test_loader, train_eval_loader
197 |
198 |
199 | def inf_generator(iterable):
200 | """Allows training with DataLoaders in a single infinite loop:
201 | for i, (x, y) in enumerate(inf_generator(train_loader)):
202 | """
203 | iterator = iterable.__iter__()
204 | while True:
205 | try:
206 | yield iterator.__next__()
207 | except StopIteration:
208 | iterator = iterable.__iter__()
209 |
210 |
211 | def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
212 | initial_learning_rate = args.lr * batch_size / batch_denom
213 |
214 | boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
215 | vals = [initial_learning_rate * decay for decay in decay_rates]
216 |
217 | def learning_rate_fn(itr):
218 | lt = [itr < b for b in boundaries] + [True]
219 | i = np.argmax(lt)
220 | return vals[i]
221 |
222 | return learning_rate_fn
223 |
224 |
225 | def one_hot(x, K):
226 | return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)
227 |
228 |
229 | def accuracy(model, dataset_loader):
230 | total_correct = 0
231 | for x, y in dataset_loader:
232 | x = x.to(device)
233 | y = one_hot(np.array(y.numpy()), 10)
234 |
235 | target_class = np.argmax(y, axis=1)
236 | predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
237 | total_correct += np.sum(predicted_class == target_class)
238 | return total_correct / len(dataset_loader.dataset)
239 |
240 |
241 | def count_parameters(model):
242 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
243 |
244 |
245 | def makedirs(dirname):
246 | if not os.path.exists(dirname):
247 | os.makedirs(dirname)
248 |
249 |
250 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
251 | logger = logging.getLogger()
252 | if debug:
253 | level = logging.DEBUG
254 | else:
255 | level = logging.INFO
256 | logger.setLevel(level)
257 | if saving:
258 | info_file_handler = logging.FileHandler(logpath, mode="a")
259 | info_file_handler.setLevel(level)
260 | logger.addHandler(info_file_handler)
261 | if displaying:
262 | console_handler = logging.StreamHandler()
263 | console_handler.setLevel(level)
264 | logger.addHandler(console_handler)
265 | logger.info(filepath)
266 | with open(filepath, "r") as f:
267 | logger.info(f.read())
268 |
269 | for f in package_files:
270 | logger.info(f)
271 | with open(f, "r") as package_f:
272 | logger.info(package_f.read())
273 |
274 | return logger
275 |
276 |
277 | if __name__ == '__main__':
278 |
279 | makedirs(args.save)
280 | logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
281 | logger.info(args)
282 |
283 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
284 |
285 | is_odenet = args.network == 'odenet'
286 |
287 | if args.downsampling_method == 'conv':
288 | downsampling_layers = [
289 | nn.Conv2d(1, 64, 3, 1),
290 | norm(64),
291 | nn.ReLU(inplace=True),
292 | nn.Conv2d(64, 64, 4, 2, 1),
293 | norm(64),
294 | nn.ReLU(inplace=True),
295 | nn.Conv2d(64, 64, 4, 2, 1),
296 | ]
297 | elif args.downsampling_method == 'res':
298 | downsampling_layers = [
299 | nn.Conv2d(1, 64, 3, 1),
300 | ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
301 | ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
302 | ]
303 |
304 | feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
305 | fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]
306 |
307 | model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)
308 |
309 | logger.info(model)
310 | logger.info('Number of parameters: {}'.format(count_parameters(model)))
311 |
312 | criterion = nn.CrossEntropyLoss().to(device)
313 |
314 | train_loader, test_loader, train_eval_loader = get_mnist_loaders(
315 | args.data_aug, args.batch_size, args.test_batch_size
316 | )
317 |
318 | data_gen = inf_generator(train_loader)
319 | batches_per_epoch = len(train_loader)
320 |
321 | lr_fn = learning_rate_with_decay(
322 | args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
323 | decay_rates=[1, 0.1, 0.01, 0.001]
324 | )
325 |
326 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
327 |
328 | best_acc = 0
329 | batch_time_meter = RunningAverageMeter()
330 | f_nfe_meter = RunningAverageMeter()
331 | b_nfe_meter = RunningAverageMeter()
332 | end = time.time()
333 |
334 | for itr in range(args.nepochs * batches_per_epoch):
335 |
336 | for param_group in optimizer.param_groups:
337 | param_group['lr'] = lr_fn(itr)
338 |
339 | optimizer.zero_grad()
340 | x, y = data_gen.__next__()
341 | x = x.to(device)
342 | y = y.to(device)
343 | logits = model(x)
344 | loss = criterion(logits, y)
345 |
346 | if is_odenet:
347 | nfe_forward = feature_layers[0].nfe
348 | feature_layers[0].nfe = 0
349 |
350 | loss.backward()
351 | optimizer.step()
352 |
353 | if is_odenet:
354 | nfe_backward = feature_layers[0].nfe
355 | feature_layers[0].nfe = 0
356 |
357 | batch_time_meter.update(time.time() - end)
358 | if is_odenet:
359 | f_nfe_meter.update(nfe_forward)
360 | b_nfe_meter.update(nfe_backward)
361 | end = time.time()
362 |
363 | if itr % batches_per_epoch == 0:
364 | with torch.no_grad():
365 | train_acc = accuracy(model, train_eval_loader)
366 | val_acc = accuracy(model, test_loader)
367 | if val_acc > best_acc:
368 | torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
369 | best_acc = val_acc
370 | logger.info(
371 | "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
372 | "Train Acc {:.4f} | Test Acc {:.4f}".format(
373 | itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
374 | b_nfe_meter.avg, train_acc, val_acc
375 | )
376 | )
377 |
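`learning_rate_with_decay` in `odenet_mnist.py` builds a piecewise-constant schedule. The base rate `args.lr * batch_size / batch_denom` is multiplied by `decay_rates[i]`, where `i` counts how many epoch boundaries the iteration has already passed. The sketch below performs the same lookup with the stdlib `bisect` module instead of `np.argmax`. The name `step_decay_lr` and its explicit `lr` parameter are illustrative; the script itself reads `args.lr` from a global.

```python
import bisect


def step_decay_lr(lr, batch_size, batch_denom, batches_per_epoch,
                  boundary_epochs, decay_rates):
    """Return a function itr -> learning rate with piecewise-constant decay."""
    initial_lr = lr * batch_size / batch_denom
    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_lr * decay for decay in decay_rates]

    def lr_fn(itr):
        # bisect_right counts the boundaries b with b <= itr, which is the
        # index np.argmax([itr < b for b in boundaries] + [True]) selects.
        return vals[bisect.bisect_right(boundaries, itr)]

    return lr_fn


# Script defaults: lr=0.1, batch_size=128; the drop_last MNIST loader
# yields 60000 // 128 = 468 batches per epoch.
lr_fn = step_decay_lr(0.1, 128, 128, 468, [60, 100, 140], [1, 0.1, 0.01, 0.001])
for epoch in (0, 59, 60, 100, 140, 159):
    print(epoch, lr_fn(epoch * 468))
```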
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import setuptools
4 |
5 |
6 | # for simplicity we actually store the version in the __version__ attribute in the source
7 | here = os.path.realpath(os.path.dirname(__file__))
8 | with open(os.path.join(here, 'torchdiffeq', '__init__.py')) as f:
9 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
10 | if meta_match:
11 | version = meta_match.group(1)
12 | else:
13 | raise RuntimeError("Unable to find __version__ string.")
14 |
15 |
16 | setuptools.setup(
17 | name="torchdiffeq",
18 | version=version,
19 | author="Ricky Tian Qi Chen",
20 | author_email="rtqichen@cs.toronto.edu",
21 | description="ODE solvers and adjoint sensitivity analysis in PyTorch.",
22 | url="https://github.com/rtqichen/torchdiffeq",
23 | packages=setuptools.find_packages(),
24 | install_requires=['torch>=1.5.0', 'scipy>=1.4.0'],
25 | python_requires='~=3.6',
26 | classifiers=[
27 | "Programming Language :: Python :: 3",
28 | "License :: OSI Approved :: MIT License",
29 | ],
30 | )
31 |
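`setup.py` reads the version by regex-matching the `__version__` line in `torchdiffeq/__init__.py` rather than importing the package. The pattern can be exercised on an in-memory string; the sample file contents below are made up for illustration.

```python
import re

# Same pattern setup.py applies to torchdiffeq/__init__.py.
VERSION_RE = r"^__version__ = ['\"]([^'\"]*)['\"]"

# Made-up file contents; re.M lets ^ match at the start of every line.
sample = '"""Package docstring."""\n__version__ = "0.0.0"\n'

match = re.search(VERSION_RE, sample, re.M)
if match is None:
    raise RuntimeError("Unable to find __version__ string.")
print(match.group(1))
```

Some variants would not match and would raise the `RuntimeError`: an indented assignment, or one written without the spaces around `=` (for example `__version__='1.2.3'`).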
--------------------------------------------------------------------------------
/tests/DETEST/detest.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 |
5 | ####################################
6 | # Problem Class A. Single equations.
7 | ####################################
8 | def A1():
9 | diffeq = lambda t, y: -y
10 | init = lambda: (torch.tensor(0.), torch.tensor(1.))
11 | solution = lambda t: torch.exp(-t)
12 | return diffeq, init, solution
13 |
14 |
15 | def A2():
16 | diffeq = lambda t, y: -y**3 / 2
17 | init = lambda: (torch.tensor(0.), torch.tensor(1.))
18 | solution = lambda t: 1 / torch.sqrt(t + 1)
19 | return diffeq, init, solution
20 |
21 |
22 | def A3():
23 | diffeq = lambda t, y: y * torch.cos(t)
24 | init = lambda: (torch.tensor(0.), torch.tensor(1.))
25 | solution = lambda t: torch.exp(torch.sin(t))
26 | return diffeq, init, solution
27 |
28 |
29 | def A4():
30 | diffeq = lambda t, y: y / 4 * (1 - y / 20)
31 | init = lambda: (torch.tensor(0.), torch.tensor(1.))
32 | solution = lambda t: 20 / (1 + 19 * torch.exp(-t / 4))
33 | return diffeq, init, solution
34 |
35 |
36 | def A5():
37 | diffeq = lambda t, y: (y - t) / (y + t)
38 | init = lambda: (torch.tensor(0.), torch.tensor(4.))
39 | return diffeq, init, None
40 |
41 |
42 | #################################
43 | # Problem Class B. Small systems.
44 | #################################
45 | def B1():
46 |
47 | def diffeq(t, y):
48 | dy0 = 2 * (y[0] - y[0] * y[1])
49 | dy1 = -(y[1] - y[0] * y[1])
50 | return torch.stack([dy0, dy1])
51 |
52 | def init():
53 | return torch.tensor(0.), torch.tensor([1., 3.])
54 |
55 | return diffeq, init, None
56 |
57 |
58 | def B2():
59 |
60 | A = torch.tensor([[-1., 1., 0.], [1., -2., 1.], [0., 1., -1.]])
61 |
62 | def diffeq(t, y):
63 | dy = torch.mv(A, y)
64 | return dy
65 |
66 | def init():
67 | return torch.tensor(0.), torch.tensor([2., 0., 1.])
68 |
69 | return diffeq, init, None
70 |
71 |
72 | def B3():
73 |
74 | def diffeq(t, y):
75 | dy0 = -y[0]
76 | dy1 = y[0] - y[1] * y[1]
77 | dy2 = y[1] * y[1]
78 | return torch.stack([dy0, dy1, dy2])
79 |
80 | def init():
81 | return torch.tensor(0.), torch.tensor([1., 0., 0.])
82 |
83 | return diffeq, init, None
84 |
85 |
86 | def B4():
87 |
88 | def diffeq(t, y):
89 | a = torch.sqrt(y[0] * y[0] + y[1] * y[1])
90 | dy0 = -y[1] - y[0] * y[2] / a
91 | dy1 = y[0] - y[1] * y[2] / a
92 | dy2 = y[0] / a
93 | return torch.stack([dy0, dy1, dy2])
94 |
95 | def init():
96 | return torch.tensor(0.), torch.tensor([3., 0., 0.])
97 |
98 | return diffeq, init, None
99 |
100 |
101 | def B5():
102 |
103 | def diffeq(t, y):
104 | dy0 = y[1] * y[2]
105 | dy1 = -y[0] * y[2]
106 | dy2 = -0.51 * y[0] * y[1]
107 | return torch.stack([dy0, dy1, dy2])
108 |
109 | def init():
110 | return torch.tensor(0.), torch.tensor([0., 1., 1.])
111 |
112 | return diffeq, init, None
113 |
114 |
115 | ####################################
116 | # Problem Class C. Moderate systems.
117 | ####################################
118 | def C1():
119 |
120 | A = torch.zeros(10, 10)
121 | A.view(-1)[:-1:11] = -1
122 | A.view(-1)[10::11] = 1
123 |
124 | def diffeq(t, y):
125 | return torch.mv(A, y)
126 |
127 | def init():
128 | y0 = torch.zeros(10)
129 | y0[0] = 1
130 | return torch.tensor(0.), y0
131 |
132 | return diffeq, init, None
133 |
134 |
135 | def C2():
136 |
137 | A = torch.zeros(10, 10)
138 | A.view(-1)[:-1:11] = torch.linspace(-1, -9, 9)
139 | A.view(-1)[10::11] = torch.linspace(1, 9, 9)
140 |
141 | def diffeq(t, y):
142 | return torch.mv(A, y)
143 |
144 | def init():
145 | y0 = torch.zeros(10)
146 | y0[0] = 1
147 | return torch.tensor(0.), y0
148 |
149 | return diffeq, init, None
150 |
151 |
152 | def C3():
153 | n = 10
154 | A = torch.zeros(n, n)
155 | A.view(-1)[::n + 1] = -2
156 | A.view(-1)[n::n + 1] = 1
157 | A.view(-1)[1::n + 1] = 1
158 |
159 | def diffeq(t, y):
160 | return torch.mv(A, y)
161 |
162 | def init():
163 | y0 = torch.zeros(n)
164 | y0[0] = 1
165 | return torch.tensor(0.), y0
166 |
167 | return diffeq, init, None
168 |
169 |
170 | def C4():
171 | n = 51
172 | A = torch.zeros(n, n)
173 | A.view(-1)[::n + 1] = -2
174 | A.view(-1)[n::n + 1] = 1
175 | A.view(-1)[1::n + 1] = 1
176 |
177 | def diffeq(t, y):
178 | return torch.mv(A, y)
179 |
180 | def init():
181 | y0 = torch.zeros(n)
182 | y0[0] = 1
183 | return torch.tensor(0.), y0
184 |
185 | return diffeq, init, None
186 |
187 |
188 | def C5():
189 |
190 | k2 = torch.tensor(2.95912208286)
191 | m0 = torch.tensor(1.00000597682)
192 | m = torch.tensor([
193 | 0.000954786104043,
194 | 0.000285583733151,
195 | 0.0000437273164546,
196 | 0.0000517759138449,
197 | 0.00000277777777778,
198 | ]).view(1, 5)
199 |
200 | def diffeq(t, y):
201 | # y is 2 x 3 x 5
202 |         # y[0] contains y, y[1] contains y'
203 | # second axis indexes space (x,y,z).
204 | # third axis indexes 5 bodies.
205 |
206 | dy = y[1, :, :]
207 | y = y[0]
208 | r = torch.sqrt(torch.sum(y**2, 0)).view(1, 5)
209 | d = torch.sqrt(torch.sum((y[:, :, None] - y[:, None, :])**2, 0))
210 | F = m.view(1, 1, 5) * ((y[:, None, :] - y[:, :, None]) / (d * d * d).view(1, 5, 5) + y.view(3, 1, 5) /
211 | (r * r * r).view(1, 1, 5))
212 | F.view(3, 5 * 5)[:, ::6] = 0
213 | ddy = k2 * (-(m0 + m) * y / (r * r * r)) + F.sum(2)
214 | return torch.stack([dy, ddy], 0)
215 |
216 | def init():
217 | y0 = torch.tensor([
218 | 3.42947415189, 3.35386959711, 1.35494901715, 6.64145542550, 5.97156957878, 2.18231499728, 11.2630437207,
219 |             14.6952576794, 6.27960525067, -30.1552268759, 1.65699966404, 1.43785752721, -21.1238353380, 28.4465098142,
220 | 15.388265967
221 | ]).view(5, 3).transpose(0, 1)
222 |
223 | dy0 = torch.tensor([
224 | -.557160570446, .505696783289, .230578543901, -.415570776342, .365682722812, .169143213293, -.325325669158,
225 | .189706021964, .0877265322780, -.0240476254170, -.287659532608, -.117219543175, -.176860753121,
226 | -.216393453025, -.0148647893090
227 | ]).view(5, 3).transpose(0, 1)
228 |
229 | return torch.tensor(0.), torch.stack([y0, dy0], 0)
230 |
231 | return diffeq, init, None
232 |
233 |
234 | ###################################
235 | # Problem Class D. Orbit equations.
236 | ###################################
237 | def _DTemplate(eps):
238 |
239 | def diffeq(t, y):
240 | r = (y[0]**2 + y[1]**2)**(3 / 2)
241 | dy0 = y[2]
242 | dy1 = y[3]
243 | dy2 = -y[0] / r
244 | dy3 = -y[1] / r
245 | return torch.stack([dy0, dy1, dy2, dy3])
246 |
247 | def init():
248 | return torch.tensor(0.), torch.tensor([1 - eps, 0, 0, math.sqrt((1 + eps) / (1 - eps))])
249 |
250 | return diffeq, init, None
251 |
252 |
253 | D1 = lambda: _DTemplate(0.1)
254 | D2 = lambda: _DTemplate(0.3)
255 | D3 = lambda: _DTemplate(0.5)
256 | D4 = lambda: _DTemplate(0.7)
257 | D5 = lambda: _DTemplate(0.9)
258 |
259 |
260 | ##########################################
261 | # Problem Class E. Higher order equations.
262 | ##########################################
263 | def E1():
264 |
265 | def diffeq(t, y):
266 | dy0 = y[1]
267 | dy1 = -(y[1] / (t + 1) + (1 - 0.25 / (t + 1)**2) * y[0])
268 | return torch.stack([dy0, dy1])
269 |
270 | def init():
271 | return torch.tensor(0.), torch.tensor([.671396707141803, .0954005144474744])
272 |
273 | return diffeq, init, None
274 |
275 |
276 | def E2():
277 |
278 | def diffeq(t, y):
279 | dy0 = y[1]
280 | dy1 = (1 - y[0]**2) * y[1] - y[0]
281 | return torch.stack([dy0, dy1])
282 |
283 | def init():
284 | return torch.tensor(0.), torch.tensor([2., 0.])
285 |
286 | return diffeq, init, None
287 |
288 |
289 | def E3():
290 |
291 | def diffeq(t, y):
292 | dy0 = y[1]
293 | dy1 = y[0]**3 / 6 - y[0] + 2 * torch.sin(2.78535 * t)
294 | return torch.stack([dy0, dy1])
295 |
296 | def init():
297 | return torch.tensor(0.), torch.tensor([0., 0.])
298 |
299 | return diffeq, init, None
300 |
301 |
302 | def E4():
303 |
304 | def diffeq(t, y):
305 | dy0 = y[1]
306 | dy1 = .32 - .4 * y[1]**2
307 | return torch.stack([dy0, dy1])
308 |
309 | def init():
310 | return torch.tensor(0.), torch.tensor([30., 0.])
311 |
312 | return diffeq, init, None
313 |
314 |
315 | def E5():
316 |
317 | def diffeq(t, y):
318 | dy0 = y[1]
319 | dy1 = torch.sqrt(1 + y[1]**2) / (25 - t)
320 | return torch.stack([dy0, dy1])
321 |
322 | def init():
323 | return torch.tensor(0.), torch.tensor([0., 0.])
324 |
325 | return diffeq, init, None
326 |
327 |
328 | ###################
329 | # Helper functions.
330 | ###################
331 | def _to_tensor(x):
332 | if not torch.is_tensor(x):
333 | x = torch.tensor(x)
334 | return x
335 |
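Each DETEST problem in `detest.py` is returned as a `(diffeq, init, solution)` triple, with `solution` set to `None` when no closed form is available. For the class-A problems that do ship a closed form, the triple can be checked directly against a numerical solve. The stdlib-only sketch below re-declares A4 with `math` in place of `torch`, an assumption made so it runs without PyTorch. It then integrates A4 to `t = 20`, the end time `run.py` uses, with a fixed-step classical RK4 loop. This is an illustrative solver, not one of torchdiffeq's methods.

```python
import math


def A4():
    """Logistic equation y' = y/4 * (1 - y/20), y(0) = 1 (DETEST A4)."""
    diffeq = lambda t, y: y / 4 * (1 - y / 20)
    init = lambda: (0.0, 1.0)
    solution = lambda t: 20 / (1 + 19 * math.exp(-t / 4))
    return diffeq, init, solution


def rk4(func, t0, y0, t1, steps):
    """Integrate a scalar ODE from t0 to t1 with `steps` classical RK4 steps."""
    h = (t1 - t0) / steps
    t, y = t0, y0
    for _ in range(steps):
        k1 = func(t, y)
        k2 = func(t + h / 2, y + h / 2 * k1)
        k3 = func(t + h / 2, y + h / 2 * k2)
        k4 = func(t + h, y + h * k3)
        y += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y


diffeq, init, solution = A4()
t0, y0 = init()
est = rk4(diffeq, t0, y0, 20.0, 200)
print(est, solution(20.0), abs(est - solution(20.0)))
```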
--------------------------------------------------------------------------------
/tests/DETEST/run.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from scipy.stats.mstats import gmean
4 | import torch
5 | from torchdiffeq import odeint
6 | import detest
7 |
8 | torch.set_default_dtype(torch.float64)
9 |
10 |
11 | class NFEDiffEq:
12 |
13 | def __init__(self, diffeq):
14 | self.diffeq = diffeq
15 | self.nfe = 0
16 |
17 | def __call__(self, t, y):
18 | self.nfe += 1
19 | return self.diffeq(t, y)
20 |
21 |
22 | def main():
23 |
24 | sol = dict()
25 | for method in ['dopri5', 'adams']:
26 | for tol in [1e-3, 1e-6, 1e-9]:
27 | print('======= {} | tol={:e} ======='.format(method, tol))
28 | nfes = []
29 | times = []
30 | errs = []
31 | for c in ['A', 'B', 'C', 'D', 'E']:
32 | for i in ['1', '2', '3', '4', '5']:
33 | diffeq, init, _ = getattr(detest, c + i)()
34 | t0, y0 = init()
35 | diffeq = NFEDiffEq(diffeq)
36 |
37 |                     if c + i not in sol:
38 | sol[c + i] = odeint(
39 | diffeq, y0, torch.stack([t0, torch.tensor(20.)]), atol=1e-12, rtol=1e-12, method='dopri5'
40 | )[1]
41 | diffeq.nfe = 0
42 |
43 | start_time = time.time()
44 | est = odeint(diffeq, y0, torch.stack([t0, torch.tensor(20.)]), atol=tol, rtol=tol, method=method)
45 | time_spent = time.time() - start_time
46 |
47 | error = torch.sqrt(torch.mean((sol[c + i] - est[1])**2))
48 |
49 | errs.append(error.item())
50 | nfes.append(diffeq.nfe)
51 | times.append(time_spent)
52 |
53 | print('{}: NFE {} | Time {} | Err {:e}'.format(c + i, diffeq.nfe, time_spent, error.item()))
54 |
55 | print('Total NFE {} | Total Time {} | GeomAvg Error {:e}'.format(np.sum(nfes), np.sum(times), gmean(errs)))
56 |
57 |
58 | if __name__ == '__main__':
59 | main()
60 |
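`run.py` wraps each right-hand side in `NFEDiffEq` to count function evaluations. It then summarises the per-problem errors with a geometric mean, a natural choice when errors span many orders of magnitude, since it averages their logarithms. The stdlib-only sketch below covers both pieces. `statistics.geometric_mean` (Python 3.8+) stands in for `scipy.stats.mstats.gmean`. The fixed-step Euler solver is purely illustrative and is not one of the methods the script benchmarks.

```python
import math
import statistics


class NFEDiffEq:
    """Counts how many times the wrapped right-hand side is evaluated."""

    def __init__(self, diffeq):
        self.diffeq = diffeq
        self.nfe = 0

    def __call__(self, t, y):
        self.nfe += 1
        return self.diffeq(t, y)


def euler(func, t0, y0, t1, steps):
    """Explicit Euler: exactly one function evaluation per step."""
    h = (t1 - t0) / steps
    y = y0
    for i in range(steps):
        y = y + h * func(t0 + i * h, y)
    return y


counted = NFEDiffEq(lambda t, y: -y)  # DETEST A1: y' = -y, y(0) = 1
errs = []
for steps in (10, 100, 1000):
    counted.nfe = 0  # reset between runs, as run.py does after the reference solve
    est = euler(counted, 0.0, 1.0, 1.0, steps)
    errs.append(abs(est - math.exp(-1.0)))
    print(steps, counted.nfe, errs[-1])
print("GeomAvg Error", statistics.geometric_mean(errs))
```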
--------------------------------------------------------------------------------
/tests/api_tests.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchdiffeq
4 |
5 | from problems import construct_problem, DTYPES, DEVICES, ADAPTIVE_METHODS
6 |
7 |
8 | EPS = {torch.float32: 1e-4, torch.float64: 1e-12, torch.complex64: 1e-4}
9 |
10 |
11 | class TestCollectionState(unittest.TestCase):
12 | def test_forward(self):
13 | for dtype in DTYPES:
14 | eps = EPS[dtype]
15 | for device in DEVICES:
16 | f, y0, t_points, sol = construct_problem(dtype=dtype, device=device)
17 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
18 | tuple_y0 = (y0, y0)
19 | for method in ADAPTIVE_METHODS:
20 |
21 | with self.subTest(dtype=dtype, device=device, method=method):
22 | tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method=method)
23 | max_error0 = (sol - tuple_y[0]).abs().max()
24 | max_error1 = (sol - tuple_y[1]).abs().max()
25 | self.assertLess(max_error0, eps)
26 | self.assertLess(max_error1, eps)
27 |
28 | def test_gradient(self):
29 | for device in DEVICES:
30 | f, y0, t_points, sol = construct_problem(device=device)
31 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
32 | for method in ADAPTIVE_METHODS:
33 | if method == "scipy_solver":
34 | continue
35 |
36 | with self.subTest(device=device, method=method):
37 | for i in range(2):
38 | func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method=method)[i]
39 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
40 |
41 |
42 | if __name__ == '__main__':
43 | unittest.main()
44 |
--------------------------------------------------------------------------------
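`test_forward` passes a tuple state `(y0, y0)` and a tuple-valued dynamics function directly to `odeint`. In `odeint_adjoint`, the `shapes` returned by `_check_inputs` and the later `_flat_to_shape(solution, (len(t),), shapes)` call suggest the following design: tuple states are packed into one flat tensor for the solver and split back afterwards. A minimal pure-Python sketch of that packing idea follows. `flatten`, `unflatten` and `euler_tuple` are hypothetical helpers, not torchdiffeq functions, and plain lists stand in for tensors.

```python
import math


def flatten(parts):
    """Concatenate a tuple of flat lists into one list and remember each part's length."""
    sizes = [len(p) for p in parts]
    flat = [x for p in parts for x in p]
    return flat, sizes


def unflatten(flat, sizes):
    """Inverse of flatten: split the flat list back into a tuple of lists."""
    out, i = [], 0
    for n in sizes:
        out.append(flat[i:i + n])
        i += n
    return tuple(out)


def euler_tuple(func, y0, t0, t1, steps=1000):
    """Fixed-step Euler on a tuple state; the solver loop itself only sees the flat list."""
    flat, sizes = flatten(y0)
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        # func works on the tuple view; its tuple output is flattened the same way.
        dy = flatten(func(t, unflatten(flat, sizes)))[0]
        flat = [x + h * d for x, d in zip(flat, dy)]
        t += h
    return unflatten(flat, sizes)


# Same dynamics y' = -y applied to each component, analogous to tuple_f in the test.
decay = lambda t, y: tuple([-v for v in part] for part in y)
ya, yb = euler_tuple(decay, ([1.0], [1.0, 2.0]), 0.0, 1.0)
print(ya, yb, math.exp(-1.0))
```

Since both components follow the same dynamics from the same starting value, `ya[0]` and `yb[0]` come out identical, which is the property `test_forward` checks against `sol` for `tuple_y[0]` and `tuple_y[1]`.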
/tests/event_tests.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchdiffeq
4 |
5 | from problems import construct_problem, DTYPES, DEVICES, METHODS, FIXED_METHODS
6 |
7 |
8 | def rel_error(true, estimate):
9 | return ((true - estimate) / true).abs().max()
10 |
11 |
12 | class TestEventHandling(unittest.TestCase):
13 |
14 | def test_odeint(self):
15 | for reverse in (False, True):
16 | for dtype in DTYPES:
17 | for device in DEVICES:
18 | for method in METHODS:
19 |
20 | # TODO: remove after event handling gets enabled.
21 | if method == 'scipy_solver':
22 | continue
23 |
24 | for ode in ("constant", "sine"):
25 | with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, method=method):
26 | if method == "explicit_adams":
27 | tol = 7e-2
28 | elif method == "euler" or method == "implicit_euler":
29 | tol = 5e-3
30 | elif method == "gl6":
31 | tol = 2e-3
32 | else:
33 | tol = 1e-4
34 |
35 | f, y0, t_points, sol = construct_problem(dtype=dtype, device=device, ode=ode,
36 | reverse=reverse)
37 |
38 | def event_fn(t, y):
39 | return torch.sum(y - sol[2]).real
40 |
41 | if method in FIXED_METHODS:
42 | options = {"step_size": 0.01, "interp": "cubic"}
43 | else:
44 | options = {}
45 |
46 | t, y = torchdiffeq.odeint(f, y0, t_points[0:2], event_fn=event_fn, method=method, options=options)
47 | y = y[-1]
48 | self.assertLess(rel_error(sol[2], y), tol)
49 | self.assertLess(rel_error(t_points[2], t), tol)
50 |
51 | def test_adjoint(self):
52 | f, y0, t_points, sol = construct_problem(device="cpu", ode="constant")
53 |
54 | def event_fn(t, y):
55 | return torch.sum(y - sol[-1])
56 |
57 | t, y = torchdiffeq.odeint_adjoint(f, y0, t_points[0:2], event_fn=event_fn, method="dopri5")
58 | y = y[-1]
59 | self.assertLess(rel_error(sol[-1], y), 1e-4)
60 | self.assertLess(rel_error(t_points[-1], t), 1e-4)
61 |
62 | # Make sure adjoint mode backward code can still be run.
63 | t.backward(retain_graph=True)
64 | y.sum().backward()
65 |
66 |
67 | if __name__ == '__main__':
68 | unittest.main()
69 |
--------------------------------------------------------------------------------
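`test_odeint` in event_tests.py integrates until `event_fn(t, y) = sum(y - sol[2])` crosses zero. It then checks that both the returned event time and the returned state match the analytic solution. The underlying idea can be sketched for a scalar ODE: step forward until the event function changes sign, then narrow the crossing down. This is a minimal illustration under stated assumptions. `rk4_step` and `solve_until_event` are hypothetical helpers, and the root is refined by plain bisection with re-integration from the start of the bracketing step. It makes no claim to reproduce `torchdiffeq/_impl/event_handling.py`.

```python
import math


def rk4_step(f, t, y, h):
    """One classic RK4 step for a scalar ODE y' = f(t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def solve_until_event(f, event_fn, t0, y0, h=0.01, t_max=100.0, tol=1e-12):
    """Step forward until event_fn changes sign, then bisect on the sub-step length."""
    t, y = t0, y0
    g = event_fn(t, y)
    while t < t_max:
        y_next = rk4_step(f, t, y, h)
        g_next = event_fn(t + h, y_next)
        if g * g_next <= 0:  # the event lies somewhere in [t, t + h]
            lo, hi = 0.0, h
            while hi - lo > tol:
                mid = 0.5 * (lo + hi)
                if g * event_fn(t + mid, rk4_step(f, t, y, mid)) <= 0:
                    hi = mid  # sign change within [t, t + mid]
                else:
                    lo = mid  # sign change within (t + mid, t + h]
            return t + hi, rk4_step(f, t, y, hi)
        t, y, g = t + h, y_next, g_next
    raise RuntimeError("event_fn never changed sign before t_max")


# y' = y with y(0) = 1 reaches e exactly at t = 1.
t_event, y_event = solve_until_event(lambda t, y: y, lambda t, y: y - math.e, 0.0, 1.0)
print(t_event, y_event)
```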
/tests/gradient_tests.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchdiffeq
4 |
5 | from problems import construct_problem, PROBLEMS, DEVICES, METHODS
6 |
7 |
8 | def max_abs(tensor):
9 | return torch.max(torch.abs(tensor))
10 |
11 |
12 | class TestGradient(unittest.TestCase):
13 | def test_odeint(self):
14 | for device in DEVICES:
15 | for method in METHODS:
16 |
17 | if method == 'scipy_solver':
18 | continue
19 |
20 | with self.subTest(device=device, method=method):
21 | f, y0, t_points, _ = construct_problem(device=device)
22 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method=method)
23 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
24 |
25 | def test_adjoint(self):
26 | for device in DEVICES:
27 | for method in METHODS:
28 |
29 | with self.subTest(device=device, method=method):
30 | f, y0, t_points, _ = construct_problem(device=device)
31 | func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method=method)
32 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
33 |
34 | def test_adjoint_against_odeint(self):
35 | """
36 | Compare odeint_adjoint gradients against backprop through odeint (both use the default dopri5 solver).
37 | """
38 | for device in DEVICES:
39 | for ode in PROBLEMS:
40 | for t_grad in (True, False):
41 | if ode == 'constant':
42 | eps = 1e-12
43 | elif ode == 'linear':
44 | eps = 1e-5
45 | elif ode == 'sine':
46 | eps = 5e-3
47 | elif ode == 'exp':
48 | eps = 1e-2
49 | else:
50 | raise RuntimeError
51 |
52 | with self.subTest(device=device, ode=ode, t_grad=t_grad):
53 | f, y0, t_points, _ = construct_problem(device=device, ode=ode)
54 | t_points.requires_grad_(t_grad)
55 |
56 | ys = torchdiffeq.odeint(f, y0, t_points, rtol=1e-9, atol=1e-12)
57 | torch.manual_seed(0)
58 | gradys = torch.rand_like(ys)
59 | ys.backward(gradys)
60 |
61 | reg_y0_grad = y0.grad.clone()
62 | reg_t_grad = t_points.grad.clone() if t_grad else None
63 | reg_params_grads = []
64 | for param in f.parameters():
65 | reg_params_grads.append(param.grad.clone())
66 |
67 | y0.grad.zero_()
68 | if t_grad:
69 | t_points.grad.zero_()
70 | for param in f.parameters():
71 | param.grad.zero_()
72 |
73 | ys = torchdiffeq.odeint_adjoint(f, y0, t_points, rtol=1e-9, atol=1e-12)
74 | ys.backward(gradys)
75 |
76 | adj_y0_grad = y0.grad
77 | adj_t_grad = t_points.grad if t_grad else None
78 | adj_params_grads = []
79 | for param in f.parameters():
80 | adj_params_grads.append(param.grad)
81 |
82 | self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps)
83 | if t_grad:
84 | self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps)
85 | for reg_grad, adj_grad in zip(reg_params_grads, adj_params_grads):
86 | self.assertLess(max_abs(reg_grad - adj_grad), eps)
87 |
88 |
89 | class TestCompareAdjointGradient(unittest.TestCase):
90 |
91 | def problem(self, device):
92 | class Odefunc(torch.nn.Module):
93 | def __init__(self):
94 | super(Odefunc, self).__init__()
95 | self.A = torch.nn.Parameter(torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]))
96 | self.unused_module = torch.nn.Linear(2, 5)
97 |
98 | def forward(self, t, y):
99 | return torch.mm(y**3, self.A)
100 |
101 | y0 = torch.tensor([[2., 0.]], device=device, requires_grad=True)
102 | t_points = torch.linspace(0., 25., 10, device=device, requires_grad=True)
103 | func = Odefunc().to(device)
104 | return func, y0, t_points
105 |
106 | def test_against_dopri5(self):
107 | method_eps = {
108 | 'dopri5': (3e-4, 1e-4, 2e-3),
109 | 'scipy_solver': (3e-4, 1e-4, 2e-3),
110 | }
111 | for device in DEVICES:
112 | for method, eps in method_eps.items():
113 | for t_grad in (True, False):
114 | with self.subTest(device=device, method=method, t_grad=t_grad):
115 | func, y0, t_points = self.problem(device=device)
116 | t_points.requires_grad_(t_grad)
117 |
118 | ys = torchdiffeq.odeint_adjoint(func, y0, t_points, method=method)
119 | gradys = torch.rand_like(ys) * 0.1
120 | ys.backward(gradys)
121 |
122 | adj_y0_grad = y0.grad
123 | adj_t_grad = t_points.grad if t_grad else None
124 | adj_A_grad = func.A.grad
125 | self.assertEqual(max_abs(func.unused_module.weight.grad), 0)
126 | self.assertEqual(max_abs(func.unused_module.bias.grad), 0)
127 |
128 | func, y0, t_points = self.problem(device=device)
129 | ys = torchdiffeq.odeint(func, y0, t_points, method='dopri5')
130 | ys.backward(gradys)
131 |
132 | self.assertLess(max_abs(y0.grad - adj_y0_grad), eps[0])
133 | if t_grad:
134 | self.assertLess(max_abs(t_points.grad - adj_t_grad), eps[1])
135 | self.assertLess(max_abs(func.A.grad - adj_A_grad), eps[2])
136 |
137 |
138 | if __name__ == '__main__':
139 | unittest.main()
140 |
--------------------------------------------------------------------------------
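The adjoint tests compare gradients from `odeint_adjoint` with those obtained by backpropagating through `odeint`. The mechanism is visible in `augmented_dynamics` in adjoint.py. The augmented state (y, the adjoint of y, and a parameter-gradient accumulator) is integrated backwards in time, with vector-Jacobian products taken against `-adj_y`. Below is a minimal pure-Python sketch for the scalar ODE y' = a·y with loss L = y(T). Here the exact gradients are known: dL/dy0 = e^{aT} and dL/da = T·y0·e^{aT}. `rk4_integrate` and `adjoint_grads` are hypothetical helpers, and the Jacobians are written out by hand instead of coming from autograd.

```python
import math


def rk4_integrate(f, s0, t0, t1, steps=200):
    """Fixed-step RK4 for a list-valued state; t1 < t0 integrates backwards in time."""
    h = (t1 - t0) / steps
    t, s = t0, list(s0)
    for _ in range(steps):
        k1 = f(t, s)
        k2 = f(t + h / 2, [si + h / 2 * ki for si, ki in zip(s, k1)])
        k3 = f(t + h / 2, [si + h / 2 * ki for si, ki in zip(s, k2)])
        k4 = f(t + h, [si + h * ki for si, ki in zip(s, k3)])
        s = [si + h / 6 * (p + 2 * q + 2 * r + w) for si, p, q, r, w in zip(s, k1, k2, k3, k4)]
        t += h
    return s


def adjoint_grads(a, y0, T, steps=200):
    """Gradients of L = y(T) for y' = a * y via the continuous adjoint method."""
    # Forward pass: only the final state is kept, no intermediate activations.
    (y_T,) = rk4_integrate(lambda t, s: [a * s[0]], [y0], 0.0, T, steps)

    def augmented(t, s):
        y, adj, _ = s
        # d/dt [y, adj, g] = [f, -adj * df/dy, -adj * df/da] with f = a * y
        return [a * y, -adj * a, -adj * y]

    # Backward pass from T to 0, starting from dL/dy(T) = 1 and a zero accumulator.
    _, dL_dy0, dL_da = rk4_integrate(augmented, [y_T, 1.0, 0.0], T, 0.0, steps)
    return y_T, dL_dy0, dL_da


a, y0, T = 0.5, 2.0, 1.0
y_T, dL_dy0, dL_da = adjoint_grads(a, y0, T)
print(dL_dy0, math.exp(a * T))          # adjoint vs exact dL/dy0
print(dL_da, T * y0 * math.exp(a * T))  # adjoint vs exact dL/da
```

Integrating the accumulator from T down to 0 with derivative `-adj * df/da` yields the integral of `adj * df/da` over [0, T], which is the parameter gradient. This is the role the `vjp_params` entries play in `aug_state`.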
/tests/norm_tests.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import unittest
3 |
4 | import torch
5 | import torchdiffeq
6 |
7 | from problems import (DTYPES, DEVICES, ADAPTIVE_METHODS)
8 |
9 |
10 | @contextlib.contextmanager
11 | def random_seed_torch(seed):
12 | cpu_rng_state = torch.get_rng_state()
13 | torch.manual_seed(seed)
14 |
15 | try:
16 | yield
17 | finally:
18 | torch.set_rng_state(cpu_rng_state)
19 |
20 |
21 | class _NeuralF(torch.nn.Module):
22 | def __init__(self, width, oscillate):
23 | super(_NeuralF, self).__init__()
24 |
25 | # Use the same set of random weights for every instance.
26 | with random_seed_torch(0):
27 | self.linears = torch.nn.Sequential(torch.nn.Linear(2, width),
28 | torch.nn.Tanh(),
29 | torch.nn.Linear(width, 2),
30 | torch.nn.Tanh())
31 | self.nfe = 0
32 | self.oscillate = oscillate
33 |
34 | def forward(self, t, x):
35 | self.nfe += 1
36 | out = self.linears(x)
37 | if self.oscillate:
38 | out = out * t.mul(2).sin()
39 | return out
40 |
41 |
42 | class TestNorms(unittest.TestCase):
43 | def test_norm(self):
44 | def f(t, x):
45 | return x
46 | t = torch.tensor([0., 1.])
47 |
48 | # First test that tensor input appears in the norm.
49 | is_called = False
50 |
51 | def norm(state):
52 | nonlocal is_called
53 | is_called = True
54 | self.assertIsInstance(state, torch.Tensor)
55 | self.assertEqual(state.shape, ())
56 | return state.pow(2).mean().sqrt()
57 | x0 = torch.tensor(1.)
58 | torchdiffeq.odeint(f, x0, t, options=dict(norm=norm))
59 | self.assertTrue(is_called)
60 |
61 | # Now test that tupled input appears in the norm
62 | is_called = False
63 |
64 | def norm(state):
65 | nonlocal is_called
66 | is_called = True
67 | self.assertIsInstance(state, tuple)
68 | self.assertEqual(len(state), 1)
69 | state, = state
70 | self.assertEqual(state.shape, ())
71 | return state.pow(2).mean().sqrt()
72 | x0 = (torch.tensor(1.),)
73 | torchdiffeq.odeint(f, x0, t, options=dict(norm=norm))
74 | self.assertTrue(is_called)
75 |
76 | is_called = False
77 |
78 | def norm(state):
79 | nonlocal is_called
80 | is_called = True
81 | self.assertIsInstance(state, tuple)
82 | self.assertEqual(len(state), 2)
83 | state1, state2 = state
84 | self.assertEqual(state1.shape, ())
85 | self.assertEqual(state2.shape, (2, 2))
86 | return state1.pow(2).mean().sqrt()
87 | x0 = (torch.tensor(1.), torch.tensor([[0.5, 0.5], [0.1, 0.1]]))
88 | torchdiffeq.odeint(f, x0, t, options=dict(norm=norm))
89 | self.assertTrue(is_called)
90 |
91 | def test_adjoint_norm(self):
92 | def f(t, x):
93 | return x
94 | t = torch.tensor([0., 1.])
95 | adjoint_params = (torch.rand(7, requires_grad=True), torch.rand((), requires_grad=True))
96 |
97 | def make_spy_on_adjoint_norm(adjoint_norm, actual_norm):
98 | is_spy_called = [False]
99 |
100 | def spy_on_adjoint_norm(tensor):
101 | nonlocal is_spy_called
102 | is_spy_called[0] = True
103 | norm_result = adjoint_norm(tensor)
104 | true_norm_result = actual_norm(tensor)
105 | self.assertIsInstance(norm_result, torch.Tensor)
106 | self.assertEqual(norm_result.shape, true_norm_result.shape)
107 | self.assertLess((norm_result - true_norm_result).abs().max(), 1e-6)
108 | return norm_result
109 |
110 | return spy_on_adjoint_norm, is_spy_called
111 |
112 | # Test the various auto-constructed adjoint norms with tensor (not tuple) state
113 | for shape in ((), (1,), (2, 2)):
114 | for use_adjoint_options, seminorm in ((False, False), (True, False), (True, True)):
115 | with self.subTest(shape=shape, use_adjoint_options=use_adjoint_options, seminorm=seminorm):
116 | x0 = torch.full(shape, 1.)
117 | if use_adjoint_options:
118 | if seminorm:
119 | # Test passing adjoint_options and wanting the seminorm
120 | kwargs = dict(adjoint_options=dict(norm='seminorm'))
121 | else:
122 | # Test passing adjoint_options without specifying the adjoint norm
123 | kwargs = dict(adjoint_options={})
124 | else:
125 | # Test not passing adjoint_options at all.
126 | kwargs = {}
127 | xs = torchdiffeq.odeint_adjoint(f, x0, t, adjoint_params=adjoint_params, **kwargs)
128 | _adjoint_norm = xs.grad_fn.adjoint_options['norm']
129 |
130 | is_called = False
131 |
132 | def actual_norm(tensor_tuple):
133 | nonlocal is_called
134 | is_called = True
135 | self.assertIsInstance(tensor_tuple, tuple)
136 | t, y, adj_y, adj_param1, adj_param2 = tensor_tuple
137 | self.assertEqual(t.shape, ())
138 | self.assertEqual(y.shape, shape)
139 | self.assertEqual(adj_y.shape, shape)
140 | self.assertEqual(adj_param1.shape, (7,))
141 | self.assertEqual(adj_param2.shape, ())
142 | out = max(t.abs(), y.pow(2).mean().sqrt(), adj_y.pow(2).mean().sqrt())
143 | if not seminorm:
144 | out = max(out, adj_param1.pow(2).mean().sqrt(), adj_param2.abs())
145 | return out
146 |
147 | xs.grad_fn.adjoint_options['norm'], is_spy_called = make_spy_on_adjoint_norm(_adjoint_norm,
148 | actual_norm)
149 | xs.sum().backward()
150 | self.assertTrue(is_called)
151 | self.assertTrue(is_spy_called[0])
152 |
153 | # Test the various auto-constructed adjoint norms with tuple (not tensor) state
154 | for use_adjoint_options, seminorm in ((False, False), (True, False), (True, True)):
155 | with self.subTest(use_adjoint_options=use_adjoint_options, seminorm=seminorm):
156 | x0 = torch.tensor(1.), torch.tensor([[0.5, 0.5], [0.1, 0.1]])
157 | if use_adjoint_options:
158 | if seminorm:
159 | # Test passing adjoint_options and wanting the seminorm
160 | kwargs = dict(adjoint_options=dict(norm='seminorm'))
161 | else:
162 | # Test passing adjoint_options without specifying the adjoint norm
163 | kwargs = dict(adjoint_options={})
164 | else:
165 | # Test not passing adjoint_options at all.
166 | kwargs = {}
167 | xs = torchdiffeq.odeint_adjoint(f, x0, t, adjoint_params=adjoint_params, **kwargs)
168 | adjoint_options_dict = xs[0].grad_fn.next_functions[0][0].next_functions[0][0].adjoint_options
169 | _adjoint_norm = adjoint_options_dict['norm']
170 |
171 | is_called = False
172 |
173 | def actual_norm(tensor_tuple):
174 | nonlocal is_called
175 | is_called = True
176 | self.assertIsInstance(tensor_tuple, tuple)
177 | t, y, adj_y, adj_param1, adj_param2 = tensor_tuple
178 | self.assertEqual(t.shape, ())
179 | self.assertEqual(y.shape, (5,))
180 | self.assertEqual(adj_y.shape, (5,))
181 | self.assertEqual(adj_param1.shape, (7,))
182 | self.assertEqual(adj_param2.shape, ())
183 | ya = y[0]
184 | yb = y[1:]
185 | adj_ya = adj_y[0]
186 | adj_yb = adj_y[1:4]
187 | out = max(t.abs(), ya.abs(), yb.pow(2).mean().sqrt(), adj_ya.abs(), adj_yb.pow(2).mean().sqrt())
188 | if not seminorm:
189 | out = max(out, adj_param1.pow(2).mean().sqrt(), adj_param2.abs())
190 | return out
191 |
192 | spy_on_adjoint_norm, is_spy_called = make_spy_on_adjoint_norm(_adjoint_norm, actual_norm)
193 | adjoint_options_dict['norm'] = spy_on_adjoint_norm
194 | xs[0].sum().backward()
195 | self.assertTrue(is_called)
196 | self.assertTrue(is_spy_called[0])
197 |
198 | # Test user-passed adjoint norms with tensor (not tuple) state
199 | is_called = False
200 |
201 | def adjoint_norm(tensor_tuple):
202 | nonlocal is_called
203 | is_called = True
204 | self.assertIsInstance(tensor_tuple, tuple)
205 | t, y, adj_y, adj_param1, adj_param2 = tensor_tuple
206 | self.assertEqual(t.shape, ())
207 | self.assertEqual(y.shape, ())
208 | self.assertEqual(adj_y.shape, ())
209 | self.assertEqual(adj_param1.shape, (7,))
210 | self.assertEqual(adj_param2.shape, ())
211 | return max(t.abs(), y.pow(2).mean().sqrt(), adj_y.pow(2).mean().sqrt(), adj_param1.pow(2).mean().sqrt(),
212 | adj_param2.abs())
213 |
214 | x0 = torch.tensor(1.)
215 | xs = torchdiffeq.odeint_adjoint(f, x0, t, adjoint_params=adjoint_params,
216 | adjoint_options=dict(norm=adjoint_norm))
217 | xs.sum().backward()
218 | self.assertTrue(is_called)
219 |
220 | # Test user-passed adjoint norms with tuple (not tensor) state
221 | is_called = False
222 |
223 | def adjoint_norm(tensor_tuple):
224 | nonlocal is_called
225 | is_called = True
226 | self.assertIsInstance(tensor_tuple, tuple)
227 | t, ya, yb, adj_ya, adj_yb, adj_param1, adj_param2 = tensor_tuple
228 | self.assertEqual(t.shape, ())
229 | self.assertEqual(ya.shape, ())
230 | self.assertEqual(yb.shape, (2, 2))
231 | self.assertEqual(adj_ya.shape, ())
232 | self.assertEqual(adj_yb.shape, (2, 2))
233 | self.assertEqual(adj_param1.shape, (7,))
234 | self.assertEqual(adj_param2.shape, ())
235 | return max(t.abs(), ya.abs(), yb.pow(2).mean().sqrt(), adj_ya.abs(), adj_yb.pow(2).mean().sqrt(),
236 | adj_param1.pow(2).mean().sqrt(), adj_param2.abs())
237 |
238 | x0 = torch.tensor(1.), torch.tensor([[0.5, 0.5], [0.1, 0.1]])
239 | xs = torchdiffeq.odeint_adjoint(f, x0, t, adjoint_params=adjoint_params,
240 | adjoint_options=dict(norm=adjoint_norm))
241 | xs[0].sum().backward()
242 | self.assertTrue(is_called)
243 |
244 | def test_large_norm(self):
245 |
246 | def norm(tensor):
247 | return tensor.abs().max()
248 |
249 | def large_norm(tensor):
250 | return 10 * tensor.abs().max()
251 |
252 | for dtype in DTYPES:
253 | for device in DEVICES:
254 | for method in ADAPTIVE_METHODS:
255 | if dtype == torch.float32 and method == 'dopri8':
256 | continue
257 |
258 | with self.subTest(dtype=dtype, device=device, method=method):
259 | x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype)
260 | t = torch.tensor([0., 1.0], device=device, dtype=torch.float64)
261 |
262 | norm_f = _NeuralF(width=10, oscillate=True).to(device, dtype)
263 | torchdiffeq.odeint(norm_f, x0, t, method=method, options=dict(norm=norm))
264 | large_norm_f = _NeuralF(width=10, oscillate=True).to(device, dtype)
265 | with torch.no_grad():
266 | for norm_param, large_norm_param in zip(norm_f.parameters(), large_norm_f.parameters()):
267 | large_norm_param.copy_(norm_param)
268 | torchdiffeq.odeint(large_norm_f, x0, t, method=method, options=dict(norm=large_norm))
269 |
270 | self.assertLessEqual(norm_f.nfe, large_norm_f.nfe)
271 |
272 | def test_seminorm(self):
273 | for dtype in DTYPES:
274 | for device in DEVICES:
275 | for method in ADAPTIVE_METHODS:
276 | # Tests with known failures
277 | if (
278 | dtype in [torch.float32] and
279 | method in ['tsit5']
280 | ):
281 | continue
282 |
283 | with self.subTest(dtype=dtype, device=device, method=method):
284 |
285 | if dtype == torch.float64:
286 | tol = 1e-8
287 | else:
288 | tol = 1e-6
289 |
290 | x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype)
291 | t = torch.tensor([0., 1.0], device=device, dtype=torch.float64)
292 |
293 | ode_f = _NeuralF(width=1024, oscillate=True).to(device, dtype)
294 |
295 | out = torchdiffeq.odeint_adjoint(ode_f, x0, t, atol=tol, rtol=tol, method=method)
296 | ode_f.nfe = 0
297 | out.sum().backward()
298 | default_nfe = ode_f.nfe
299 |
300 | out = torchdiffeq.odeint_adjoint(ode_f, x0, t, atol=tol, rtol=tol, method=method,
301 | adjoint_options=dict(norm='seminorm'))
302 | ode_f.nfe = 0
303 | out.sum().backward()
304 | seminorm_nfe = ode_f.nfe
305 |
306 | self.assertLessEqual(seminorm_nfe, default_nfe)
307 |
308 |
309 | if __name__ == '__main__':
310 | unittest.main()
311 |
--------------------------------------------------------------------------------
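The `actual_norm` reference functions in `test_adjoint_norm` spell out the norm the adjoint solve is expected to use. It is the maximum of the per-component RMS norms over (t, y, adj_y, adjoint parameters). With `norm='seminorm'`, the parameter-adjoint terms are dropped, so they no longer influence step-size control. The following is a pure-Python restatement of that expectation for a single (non-tuple) state. `rms`, `adjoint_mixed_norm` and its `seminorm` flag are illustrative names, not the library's internals.

```python
import math


def rms(values):
    """Root-mean-square of a flat list; for a single number this is just abs(x)."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def adjoint_mixed_norm(t, y, adj_y, adj_params, seminorm=False):
    """Max of per-component RMS norms over the augmented adjoint state.

    With seminorm=True the parameter adjoints are left out, matching the
    `if not seminorm:` branch of actual_norm in the tests.
    """
    parts = [[t], y, adj_y]
    if not seminorm:
        parts.extend(adj_params)
    return max(rms(p) for p in parts)


t, y, adj_y = 0.5, [1.0, 1.0], [0.2, 0.2]
adj_params = [[3.0] * 7, [0.1]]  # sizes 7 and 1, like the (7,) and () params in the test
print(adjoint_mixed_norm(t, y, adj_y, adj_params))                 # 3.0
print(adjoint_mixed_norm(t, y, adj_y, adj_params, seminorm=True))  # 1.0
```

A maximum over a subset of components can never exceed the maximum over all of them, so the seminorm is always at most the full norm. Steps are therefore accepted at least as readily, which is what `test_seminorm` checks through `seminorm_nfe <= default_nfe`.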
/tests/problems.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import scipy.linalg
4 | import torch
5 |
6 |
7 | class ConstantODE(torch.nn.Module):
8 |
9 | def __init__(self):
10 | super(ConstantODE, self).__init__()
11 | self.a = torch.nn.Parameter(torch.tensor(0.2))
12 | self.b = torch.nn.Parameter(torch.tensor(3.0))
13 |
14 | def forward(self, t, y):
15 | return self.a + (y - (self.a * t + self.b))**5
16 |
17 | def y_exact(self, t):
18 | return self.a * t + self.b
19 |
20 |
21 | class SineODE(torch.nn.Module):
22 | def forward(self, t, y):
23 | return 2 * y / t + t**4 * torch.sin(2 * t) - t**2 + 4 * t**3
24 |
25 | def y_exact(self, t):
26 | return -0.5 * t**4 * torch.cos(2 * t) + 0.5 * t**3 * torch.sin(2 * t) + 0.25 * t**2 * torch.cos(
27 | 2 * t
28 | ) - t**3 + 2 * t**4 + (math.pi - 0.25) * t**2
29 |
30 |
31 | class LinearODE(torch.nn.Module):
32 |
33 | def __init__(self, dim=10):
34 | super(LinearODE, self).__init__()
35 | torch.manual_seed(0)
36 | self.dim = dim
37 | U = torch.randn(dim, dim) * 0.1
38 | A = 2 * U - (U + U.transpose(0, 1))
39 | self.A = torch.nn.Parameter(A)
40 | self.initial_val = np.ones((dim, 1))
41 | self.nfe = 0
42 |
43 | def forward(self, t, y):
44 | self.nfe += 1
45 | return torch.mm(self.A, y.reshape(self.dim, 1)).reshape(-1)
46 |
47 | def y_exact(self, t):
48 | t_numpy = t.detach().cpu().numpy()
49 | A_np = self.A.detach().cpu().numpy()
50 | ans = []
51 | for t_i in t_numpy:
52 | ans.append(np.matmul(scipy.linalg.expm(A_np * t_i), self.initial_val))
53 | return torch.stack([torch.tensor(ans_) for ans_ in ans]).reshape(len(t_numpy), self.dim).to(t)
54 |
55 |
56 | class ExpODE(torch.nn.Module):
57 | def forward(self, t, y):
58 | return -0.1 * self.y_exact(t)
59 |
60 | def y_exact(self, t):
61 | return torch.exp(-0.1 * t)
62 |
63 |
64 | PROBLEMS = {'constant': ConstantODE, 'linear': LinearODE, 'sine': SineODE, 'exp': ExpODE}
65 | DTYPES = (torch.float32, torch.float64)
66 | DEVICES = ['cpu']
67 | if torch.cuda.is_available():
68 | DEVICES.append('cuda')
69 | FIXED_EXPLICIT_METHODS = ('euler', 'midpoint', 'heun2', 'heun3', 'rk4', 'explicit_adams', 'implicit_adams')
70 | FIXED_IMPLICIT_METHODS = ('implicit_euler', 'implicit_midpoint', 'trapezoid', 'radauIIA3', 'gl4', 'radauIIA5', 'gl6', 'sdirk2', 'trbdf2')
71 | FIXED_METHODS = FIXED_EXPLICIT_METHODS + FIXED_IMPLICIT_METHODS
72 | ADAMS_METHODS = ('explicit_adams', 'implicit_adams')
73 | ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'tsit5', 'dopri5', 'dopri8')
74 | SCIPY_METHODS = ('scipy_solver',)
75 | IMPLICIT_METHODS = FIXED_IMPLICIT_METHODS
76 | METHODS = FIXED_METHODS + ADAPTIVE_METHODS + SCIPY_METHODS
77 |
78 |
79 | def construct_problem(device, npts=10, ode='constant', reverse=False, dtype=torch.float64):
80 |
81 | f = PROBLEMS[ode]().to(dtype=dtype, device=device)
82 |
83 | t_points = torch.linspace(1, 8, npts, dtype=torch.float64, device=device, requires_grad=True)
84 | sol = f.y_exact(t_points).to(dtype)
85 |
86 | def _flip(x, dim):
87 | indices = [slice(None)] * x.dim()
88 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=device)
89 | return x[tuple(indices)]
90 |
91 | if reverse:
92 | t_points = _flip(t_points, 0).clone().detach()
93 | sol = _flip(sol, 0).clone().detach()
94 |
95 | return f, sol[0].detach().requires_grad_(True), t_points, sol
96 |
97 |
98 | if __name__ == '__main__':
99 | f = SineODE().cpu()
100 | t_points = torch.linspace(1, 8, 100, device='cpu')
101 | sol = f.y_exact(t_points)
102 |
103 | import matplotlib.pyplot as plt
104 | plt.plot(t_points, sol)
105 | plt.show()
106 |
--------------------------------------------------------------------------------
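`construct_problem` depends on each problem's `y_exact` actually solving its `forward`. For `SineODE` this can be spot-checked numerically. A central finite difference of `y_exact` should match `forward(t, y_exact(t))` on the test interval [1, 8]. The plain-Python sketch below re-types the formulas from `SineODE`, using `math` in place of `torch`. The step `h`, the sample points and the relative-residual measure are illustrative choices.

```python
import math


def sine_rhs(t, y):
    # SineODE.forward: dy/dt = 2y/t + t^4 sin(2t) - t^2 + 4t^3
    return 2 * y / t + t**4 * math.sin(2 * t) - t**2 + 4 * t**3


def sine_exact(t):
    # SineODE.y_exact
    return (-0.5 * t**4 * math.cos(2 * t) + 0.5 * t**3 * math.sin(2 * t)
            + 0.25 * t**2 * math.cos(2 * t) - t**3 + 2 * t**4 + (math.pi - 0.25) * t**2)


def max_relative_residual(ts, h=1e-4):
    """Largest |finite-difference derivative - rhs| / max(1, |rhs|) over the sample times."""
    worst = 0.0
    for t in ts:
        fd = (sine_exact(t + h) - sine_exact(t - h)) / (2 * h)
        rhs = sine_rhs(t, sine_exact(t))
        worst = max(worst, abs(fd - rhs) / max(1.0, abs(rhs)))
    return worst


ts = [1 + 7 * k / 20 for k in range(21)]  # 21 evenly spaced points on [1, 8]
print(max_relative_residual(ts))
```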
/tests/run_all.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from api_tests import *
3 | from event_tests import *
4 | from gradient_tests import *
5 | from norm_tests import *
6 | from odeint_tests import *
7 |
8 | if __name__ == '__main__':
9 | unittest.main()
10 |
--------------------------------------------------------------------------------
/torchdiffeq/__init__.py:
--------------------------------------------------------------------------------
1 | from ._impl import odeint
2 | from ._impl import odeint_adjoint
3 | from ._impl import odeint_event
4 | from ._impl import odeint_dense
5 | __version__ = "0.2.5"
6 |
--------------------------------------------------------------------------------
/torchdiffeq/_impl/__init__.py:
--------------------------------------------------------------------------------
1 | from .odeint import odeint, odeint_dense, odeint_event
2 | from .adjoint import odeint_adjoint
3 |
--------------------------------------------------------------------------------
/torchdiffeq/_impl/adaptive_heun.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
3 |
4 |
5 | _ADAPTIVE_HEUN_TABLEAU = _ButcherTableau(
6 | alpha=torch.tensor([1.], dtype=torch.float64),
7 | beta=[
8 | torch.tensor([1.], dtype=torch.float64),
9 | ],
10 | c_sol=torch.tensor([0.5, 0.5], dtype=torch.float64),
11 | c_error=torch.tensor([
12 | 0.5,
13 | -0.5,
14 | ], dtype=torch.float64),
15 | )
16 |
17 | _AH_C_MID = torch.tensor([
18 | 0.5, 0.
19 | ], dtype=torch.float64)
20 |
21 |
22 | class AdaptiveHeunSolver(RKAdaptiveStepsizeODESolver):
23 | order = 2
24 | tableau = _ADAPTIVE_HEUN_TABLEAU
25 | mid = _AH_C_MID
26 |
--------------------------------------------------------------------------------
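The tableau above reads as a two-stage embedded pair. `alpha=[1]` and `beta=[[1]]` place the second stage at t + h using y + h·k1, and `c_sol=[0.5, 0.5]` gives Heun's second-order update. Because `c_sol + c_error = [1, 0]` is the explicit Euler update, `h·(0.5·k1 - 0.5·k2)` equals the Euler-minus-Heun difference, which can serve as the local error estimate. Below is a scalar, pure-Python sketch of one such step plus a simple accept/reject controller. The controller details are assumptions for illustration, not values read from `rk_common.py`: the `safety` factor, the 0.2/5 clamps, and the exponent 1/2.

```python
import math


def heun_euler_step(f, t, y, h):
    """One adaptive-Heun step: returns the Heun update and the Euler-minus-Heun error estimate."""
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)              # alpha = [1], beta = [[1]]
    y_new = y + h * (0.5 * k1 + 0.5 * k2)  # c_sol = [0.5, 0.5]
    err = h * (0.5 * k1 - 0.5 * k2)        # c_error = [0.5, -0.5]
    return y_new, err


def integrate_adaptive(f, t0, y0, t1, rtol=1e-6, atol=1e-9, h=0.1, safety=0.9):
    """Integrate a scalar ODE from t0 to t1, returning (y(t1), number of f evaluations)."""
    t, y, nfe = t0, y0, 0
    while t < t1:
        last = h >= t1 - t
        if last:
            h = t1 - t  # land exactly on t1
        y_new, err = heun_euler_step(f, t, y, h)
        nfe += 2
        ratio = abs(err) / (atol + rtol * max(abs(y), abs(y_new)))
        if ratio <= 1.0:  # accept the step
            t = t1 if last else t + h
            y = y_new
        # The error estimate scales like h**2, so rescale h by ratio**(-1/2), clamped.
        h *= min(5.0, max(0.2, safety * max(ratio, 1e-10) ** -0.5))
    return y, nfe


y1, nfe = integrate_adaptive(lambda t, y: -y, 0.0, 1.0, 1.0)
print(y1, math.exp(-1.0), nfe)
```

The solution advances with the second-order Heun update, while the cheaper first-order Euler update is used only to judge the step. This is why the class declares `order = 2`.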
/torchdiffeq/_impl/adjoint.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import torch
3 | import torch.nn as nn
4 | from .odeint import SOLVERS, odeint
5 | from .misc import _check_inputs, _flat_to_shape, _mixed_norm, _all_callback_names, _all_adjoint_callback_names
6 |
7 |
8 | class OdeintAdjointMethod(torch.autograd.Function):
9 |
10 | @staticmethod
11 | def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method,
12 | adjoint_options, t_requires_grad, *adjoint_params):
13 |
14 | ctx.shapes = shapes
15 | ctx.func = func
16 | ctx.adjoint_rtol = adjoint_rtol
17 | ctx.adjoint_atol = adjoint_atol
18 | ctx.adjoint_method = adjoint_method
19 | ctx.adjoint_options = adjoint_options
20 | ctx.t_requires_grad = t_requires_grad
21 | ctx.event_mode = event_fn is not None
22 |
23 | with torch.no_grad():
24 | ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)
25 |
26 | if event_fn is None:
27 | y = ans
28 | ctx.save_for_backward(t, y, *adjoint_params)
29 | else:
30 | event_t, y = ans
31 | ctx.save_for_backward(t, y, event_t, *adjoint_params)
32 |
33 | return ans
34 |
35 | @staticmethod
36 | def backward(ctx, *grad_y):
37 | with torch.no_grad():
38 | func = ctx.func
39 | adjoint_rtol = ctx.adjoint_rtol
40 | adjoint_atol = ctx.adjoint_atol
41 | adjoint_method = ctx.adjoint_method
42 | adjoint_options = ctx.adjoint_options
43 | t_requires_grad = ctx.t_requires_grad
44 |
45 | # Backprop as if integrating up to event time.
46 | # Does NOT backpropagate through the event time.
47 | event_mode = ctx.event_mode
48 | if event_mode:
49 | t, y, event_t, *adjoint_params = ctx.saved_tensors
50 | _t = t
51 | t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
52 | grad_y = grad_y[1]
53 | else:
54 | t, y, *adjoint_params = ctx.saved_tensors
55 | grad_y = grad_y[0]
56 |
57 | adjoint_params = tuple(adjoint_params)
58 |
59 | ##################################
60 | # Set up initial state #
61 | ##################################
62 |
63 | # [-1] because y and grad_y are both of shape (len(t), *y0.shape)
64 | aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]] # vjp_t, y, vjp_y
65 | aug_state.extend([torch.zeros_like(param) for param in adjoint_params]) # vjp_params
66 |
67 | ##################################
68 | # Set up backward ODE func #
69 | ##################################
70 |
71 | # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
72 | def augmented_dynamics(t, y_aug):
73 | # Dynamics of the original system augmented with
74 | # the adjoint wrt y, and an integrator wrt t and args.
75 | y = y_aug[1]
76 | adj_y = y_aug[2]
77 | # ignore gradients wrt time and parameters
78 |
79 | with torch.enable_grad():
80 | t_ = t.detach()
81 | t = t_.requires_grad_(True)
82 | y = y.detach().requires_grad_(True)
83 |
84 | # If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
85 | # doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
86 | # wrt t here means we won't compute that if we don't need it.
87 | func_eval = func(t if t_requires_grad else t_, y)
88 |
89 | # Workaround for PyTorch bug #39784
90 | _t = torch.as_strided(t, (), ()) # noqa
91 | _y = torch.as_strided(y, (), ()) # noqa
92 | _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params) # noqa
93 |
94 | vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
95 | func_eval, (t, y) + adjoint_params, -adj_y,
96 | allow_unused=True, retain_graph=True
97 | )
98 |
99 | # autograd.grad returns None if no gradient, set to zero.
100 | vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
101 | vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
102 | vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
103 | for param, vjp_param in zip(adjoint_params, vjp_params)]
104 |
105 | return (vjp_t, func_eval, vjp_y, *vjp_params)
106 |
107 | # Add adjoint callbacks
108 | for callback_name, adjoint_callback_name in zip(_all_callback_names, _all_adjoint_callback_names):
109 | try:
110 | callback = getattr(func, adjoint_callback_name)
111 | except AttributeError:
112 | pass
113 | else:
114 | setattr(augmented_dynamics, callback_name, callback)
115 |
116 | ##################################
117 | # Solve adjoint ODE #
118 | ##################################
119 |
120 | if t_requires_grad:
121 | time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
122 | else:
123 | time_vjps = None
124 | for i in range(len(t) - 1, 0, -1):
125 | if t_requires_grad:
126 | # Compute the effect of moving the current time measurement point.
127 | # We don't compute this unless we need to, to save some computation.
128 | func_eval = func(t[i], y[i])
129 | dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1))
130 | aug_state[0] -= dLd_cur_t
131 | time_vjps[i] = dLd_cur_t
132 |
133 | # Run the augmented system backwards in time.
134 | aug_state = odeint(
135 | augmented_dynamics, tuple(aug_state),
136 | t[i - 1:i + 1].flip(0),
137 | rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
138 | )
139 | aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value
140 | aug_state[1] = y[i - 1] # update to use our forward-pass estimate of the state
141 | aug_state[2] += grad_y[i - 1] # update any gradients wrt state at this time point
142 |
143 | if t_requires_grad:
144 | time_vjps[0] = aug_state[0]
145 |
146 | # Only compute gradient wrt initial time when in event handling mode.
147 | if event_mode and t_requires_grad:
148 | time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])])
149 |
150 | adj_y = aug_state[2]
151 | adj_params = aug_state[3:]
152 |
153 | return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params)
154 |
155 |
156 | def odeint_adjoint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None,
157 | adjoint_rtol=None, adjoint_atol=None, adjoint_method=None, adjoint_options=None, adjoint_params=None):
158 |
159 | # We need this in order to access the variables inside this module,
160 | # since we have no other way of getting variables along the execution path.
161 | if adjoint_params is None and not isinstance(func, nn.Module):
162 | raise ValueError('func must be an instance of nn.Module to specify the adjoint parameters; alternatively they '
163 | 'can be specified explicitly via the `adjoint_params` argument. If there are no parameters '
164 | 'then it is allowable to set `adjoint_params=()`.')
165 |
166 | # Must come before _check_inputs as we don't want to use normalised input (in particular any changes to options)
167 | if adjoint_rtol is None:
168 | adjoint_rtol = rtol
169 | if adjoint_atol is None:
170 | adjoint_atol = atol
171 | if adjoint_method is None:
172 | adjoint_method = method
173 |
174 | if adjoint_method != method and options is not None and adjoint_options is None:
175 | raise ValueError("If `adjoint_method != method` then we cannot infer `adjoint_options` from `options`. So as "
176 | "`options` has been passed then `adjoint_options` must be passed as well.")
177 |
178 | if adjoint_options is None:
179 | adjoint_options = {k: v for k, v in options.items() if k != "norm"} if options is not None else {}
180 | else:
181 | # Avoid in-place modifying a user-specified dict.
182 | adjoint_options = adjoint_options.copy()
183 |
184 | if adjoint_params is None:
185 | adjoint_params = tuple(find_parameters(func))
186 | else:
187 | adjoint_params = tuple(adjoint_params) # in case adjoint_params is a generator.
188 |
189 | # Filter params that don't require gradients.
190 | oldlen_ = len(adjoint_params)
191 | adjoint_params = tuple(p for p in adjoint_params if p.requires_grad)
192 | if len(adjoint_params) != oldlen_:
193 | # Some params were excluded.
194 | # Issue a warning if a user-specified norm is specified.
195 | if 'norm' in adjoint_options and callable(adjoint_options['norm']):
196 | warnings.warn("An adjoint parameter was passed without requiring gradient. For efficiency this will be "
197 | "excluded from the adjoint pass, and will not appear as a tensor in the adjoint norm.")
198 |
199 | # Convert to flattened state.
200 | shapes, func, y0, t, rtol, atol, method, options, event_fn, decreasing_time = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
201 |
202 | # Handle the adjoint norm function.
203 | state_norm = options["norm"]
204 | handle_adjoint_norm_(adjoint_options, shapes, state_norm)
205 |
206 | ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
207 | adjoint_method, adjoint_options, t.requires_grad, *adjoint_params)
208 |
209 | if event_fn is None:
210 | solution = ans
211 | else:
212 | event_t, solution = ans
213 | event_t = event_t.to(t)
214 | if decreasing_time:
215 | event_t = -event_t
216 |
217 | if shapes is not None:
218 | solution = _flat_to_shape(solution, (len(t),), shapes)
219 |
220 | if event_fn is None:
221 | return solution
222 | else:
223 | return event_t, solution
224 |
225 |
226 | def find_parameters(module):
227 |
228 | assert isinstance(module, nn.Module)
229 |
230 | # If called within DataParallel, parameters won't appear in module.parameters().
231 | if getattr(module, '_is_replica', False):
232 |
233 | def find_tensor_attributes(module):
234 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) and v.requires_grad]
235 | return tuples
236 |
237 | gen = module._named_members(get_members_fn=find_tensor_attributes)
238 | return [param for _, param in gen]
239 | else:
240 | return list(module.parameters())
241 |
242 |
243 | def handle_adjoint_norm_(adjoint_options, shapes, state_norm):
244 | """In-place modifies the adjoint options to choose or wrap the norm function."""
245 |
246 | # This is the default adjoint norm on the backward pass: a mixed norm over the tuple of inputs.
247 | def default_adjoint_norm(tensor_tuple):
248 | t, y, adj_y, *adj_params = tensor_tuple
249 | # (If the state is actually a flattened tuple then this will be unpacked again in state_norm.)
250 | return max(t.abs(), state_norm(y), state_norm(adj_y), _mixed_norm(adj_params))
251 |
252 | if "norm" not in adjoint_options:
253 | # The user did not specify an adjoint norm. Use the default norm.
254 | adjoint_options["norm"] = default_adjoint_norm
255 | else:
256 | # The user specified a norm in `adjoint_options`...
257 | try:
258 | adjoint_norm = adjoint_options['norm']
259 | except KeyError:
260 | # ...which cannot actually be missing here, given the check above. Fall back to the default norm anyway.
261 | adjoint_options['norm'] = default_adjoint_norm
262 | else:
263 | # ...so dispatch on which norm they asked for.
264 | if adjoint_norm == 'seminorm':
265 | # They asked for a seminorm: the same as the default norm,
266 | # but ignoring the parameter adjoints.
267 | def adjoint_seminorm(tensor_tuple):
268 | t, y, adj_y, *adj_params = tensor_tuple
269 | # (If the state is actually a flattened tuple then this will be unpacked again in state_norm.)
270 | return max(t.abs(), state_norm(y), state_norm(adj_y))
271 | adjoint_options['norm'] = adjoint_seminorm
272 | else:
273 | # And they're using their own custom norm.
274 | if shapes is None:
275 | # The state on the forward pass was a tensor, not a tuple. Nothing to do: the user's norm will already
276 | # receive the full adjoint state as (t, y, adj_y, *adj_params).
277 | pass # this branch included for clarity
278 | else:
279 | # This is the bit that is tuple/tensor abstraction-breaking, because the odeint machinery
280 | # doesn't know about the tupled nature of the forward state. We need to tell the user's adjoint
281 | # norm about that ourselves.
282 |
283 | def _adjoint_norm(tensor_tuple):
284 | t, y, adj_y, *adj_params = tensor_tuple
285 | y = _flat_to_shape(y, (), shapes)
286 | adj_y = _flat_to_shape(adj_y, (), shapes)
287 | return adjoint_norm((t, *y, *adj_y, *adj_params))
288 | adjoint_options['norm'] = _adjoint_norm
289 |
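The default adjoint norm and the `'seminorm'` option in `handle_adjoint_norm_` differ only in whether the parameter adjoints contribute. The following is a plain-Python sketch of that difference, not library code: it assumes the forward-state norm is an RMS norm, models `_mixed_norm` as the maximum of per-tensor RMS norms, and uses lists of floats in place of tensors.

```python
import math
from typing import Sequence


def rms_norm(xs: Sequence[float]) -> float:
    """Root-mean-square norm of a flat list of numbers (stand-in for a tensor)."""
    return math.sqrt(sum(x * x for x in xs) / len(xs))


def mixed_norm(tensors: Sequence[Sequence[float]]) -> float:
    """Max of the per-tensor RMS norms; 0 for an empty tuple (assumed behaviour)."""
    if len(tensors) == 0:
        return 0.0
    return max(rms_norm(x) for x in tensors)


def default_adjoint_norm(t, y, adj_y, adj_params):
    # Every component of the augmented state counts, including parameter adjoints.
    return max(abs(t), rms_norm(y), rms_norm(adj_y), mixed_norm(adj_params))


def adjoint_seminorm(t, y, adj_y, adj_params):
    # adj_params is accepted but deliberately ignored.
    return max(abs(t), rms_norm(y), rms_norm(adj_y))


t, y, adj_y = 0.5, [3.0, 4.0], [1.0, 1.0]
adj_params = [[10.0, -10.0]]
print(default_adjoint_norm(t, y, adj_y, adj_params))  # dominated by the parameter adjoint
print(adjoint_seminorm(t, y, adj_y, adj_params))      # ignores it
```

Here the default norm is set by the parameter adjoint, while the seminorm is not. Because the adaptive step-size controller divides the error by such a norm, excluding parameter adjoints can let the backward pass take larger steps.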
--------------------------------------------------------------------------------
/torchdiffeq/_impl/bosh3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
3 |
4 |
5 | _BOGACKI_SHAMPINE_TABLEAU = _ButcherTableau(
6 | alpha=torch.tensor([1 / 2, 3 / 4, 1.], dtype=torch.float64),
7 | beta=[
8 | torch.tensor([1 / 2], dtype=torch.float64),
9 | torch.tensor([0., 3 / 4], dtype=torch.float64),
10 | torch.tensor([2 / 9, 1 / 3, 4 / 9], dtype=torch.float64)
11 | ],
12 | c_sol=torch.tensor([2 / 9, 1 / 3, 4 / 9, 0.], dtype=torch.float64),
13 | c_error=torch.tensor([2 / 9 - 7 / 24, 1 / 3 - 1 / 4, 4 / 9 - 1 / 3, -1 / 8], dtype=torch.float64),
14 | )
15 |
16 | _BS_C_MID = torch.tensor([0., 0.5, 0., 0.], dtype=torch.float64)
17 |
18 |
19 | class Bosh3Solver(RKAdaptiveStepsizeODESolver):
20 | order = 3
21 | tableau = _BOGACKI_SHAMPINE_TABLEAU
22 | mid = _BS_C_MID
23 |
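The tableau is consumed by `RKAdaptiveStepsizeODESolver` (defined in `rk_common.py`). As a rough, hypothetical illustration of what one step does with these numbers, here is a plain-Python embedded Runge-Kutta step for a scalar ODE using the Bogacki-Shampine coefficients above. The last row of `beta` equals the first three entries of `c_sol`, so the fourth stage is evaluated at the new solution (first-same-as-last), and `c_error` combines the stages into the difference between the third-order and the embedded second-order solutions. `embedded_rk_step` is an illustrative name, not a torchdiffeq function.

```python
import math

# Bogacki-Shampine coefficients from the tensors above, as plain floats.
ALPHA = [1 / 2, 3 / 4, 1.0]
BETA = [[1 / 2], [0.0, 3 / 4], [2 / 9, 1 / 3, 4 / 9]]
C_SOL = [2 / 9, 1 / 3, 4 / 9, 0.0]
C_ERROR = [2 / 9 - 7 / 24, 1 / 3 - 1 / 4, 4 / 9 - 1 / 3, -1 / 8]


def embedded_rk_step(f, t0, y0, dt, alpha=ALPHA, beta=BETA, c_sol=C_SOL, c_error=C_ERROR):
    """One explicit embedded RK step for a scalar ODE; returns (y1, error_estimate)."""
    k = [f(t0, y0)]
    for a_i, row in zip(alpha, beta):
        # Stage input built from the stages computed so far.
        y_stage = y0 + dt * sum(b_ij * k_j for b_ij, k_j in zip(row, k))
        k.append(f(t0 + a_i * dt, y_stage))
    y1 = y0 + dt * sum(c * k_i for c, k_i in zip(c_sol, k))
    err = dt * sum(c * k_i for c, k_i in zip(c_error, k))
    return y1, err


# y' = y, y(0) = 1: one step of size 0.1 versus the exact value e^0.1.
y1, err = embedded_rk_step(lambda t, y: y, 0.0, 1.0, 0.1)
print(y1, math.exp(0.1), err)
```

The error estimate describes the lower-order solution, so it is typically larger than the actual error of the third-order value that is propagated.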
--------------------------------------------------------------------------------
/torchdiffeq/_impl/dopri5.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
3 |
4 |
5 | _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau(
6 | alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64),
7 | beta=[
8 | torch.tensor([1 / 5], dtype=torch.float64),
9 | torch.tensor([3 / 40, 9 / 40], dtype=torch.float64),
10 | torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=torch.float64),
11 | torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64),
12 | torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=torch.float64),
13 | torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=torch.float64),
14 | ],
15 | c_sol=torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=torch.float64),
16 | c_error=torch.tensor([
17 | 35 / 384 - 1951 / 21600,
18 | 0,
19 | 500 / 1113 - 22642 / 50085,
20 | 125 / 192 - 451 / 720,
21 | -2187 / 6784 - -12231 / 42400,
22 | 11 / 84 - 649 / 6300,
23 | -1. / 60.,
24 | ], dtype=torch.float64),
25 | )
26 |
27 | DPS_C_MID = torch.tensor([
28 | 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2,
29 | 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2
30 | ], dtype=torch.float64)
31 |
32 |
33 | class Dopri5Solver(RKAdaptiveStepsizeODESolver):
34 | order = 5
35 | tableau = _DORMAND_PRINCE_SHAMPINE_TABLEAU
36 | mid = DPS_C_MID
37 |
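A cheap sanity check on a Butcher tableau is that each row of `beta` sums to its node in `alpha`, and that the solution weights integrate polynomials exactly up to the method's order, i.e. sum_i b_i c_i^(k-1) = 1/k for k = 1..order. These are necessary, not sufficient, conditions for order 5. The sketch below checks them for the Dormand-Prince coefficients above using exact `fractions.Fraction` arithmetic; `check_tableau` is a hypothetical helper, not part of torchdiffeq.

```python
from fractions import Fraction as F

# Dormand-Prince coefficients from the tensors above, as exact fractions.
ALPHA = [F(1, 5), F(3, 10), F(4, 5), F(8, 9), F(1), F(1)]
BETA = [
    [F(1, 5)],
    [F(3, 40), F(9, 40)],
    [F(44, 45), F(-56, 15), F(32, 9)],
    [F(19372, 6561), F(-25360, 2187), F(64448, 6561), F(-212, 729)],
    [F(9017, 3168), F(-355, 33), F(46732, 5247), F(49, 176), F(-5103, 18656)],
    [F(35, 384), F(0), F(500, 1113), F(125, 192), F(-2187, 6784), F(11, 84)],
]
C_SOL = [F(35, 384), F(0), F(500, 1113), F(125, 192), F(-2187, 6784), F(11, 84), F(0)]


def check_tableau(alpha, beta, c_sol, order):
    """Return a list of violated conditions (empty if all checks pass)."""
    problems = []
    # Row-sum condition: stage i is evaluated at t0 + alpha_i * dt.
    for i, (a_i, row) in enumerate(zip(alpha, beta), start=2):
        if sum(row) != a_i:
            problems.append(f"stage {i}: row sum {sum(row)} != alpha {a_i}")
    # Quadrature conditions on the solution weights.
    nodes = [F(0)] + list(alpha)
    for k in range(1, order + 1):
        q = sum(b * c ** (k - 1) for b, c in zip(c_sol, nodes))
        if q != F(1, k):
            problems.append(f"sum b c^{k - 1} = {q}, expected 1/{k}")
    return problems


print(check_tableau(ALPHA, BETA, C_SOL, order=5))
```

Asking for one order more than the method has makes the last quadrature condition fail, which is a quick way to confirm the check is not vacuous.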
--------------------------------------------------------------------------------
/torchdiffeq/_impl/dopri8.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
3 |
4 |
5 | A = [1 / 18, 1 / 12, 1 / 8, 5 / 16, 3 / 8, 59 / 400, 93 / 200, 5490023248 / 9719169821, 13 / 20, 1201146811 / 1299019798, 1, 1, 1]
6 |
7 | B = [
8 | [1 / 18],
9 |
10 | [1 / 48, 1 / 16],
11 |
12 | [1 / 32, 0, 3 / 32],
13 |
14 | [5 / 16, 0, -75 / 64, 75 / 64],
15 |
16 | [3 / 80, 0, 0, 3 / 16, 3 / 20],
17 |
18 | [29443841 / 614563906, 0, 0, 77736538 / 692538347, -28693883 / 1125000000, 23124283 / 1800000000],
19 |
20 | [16016141 / 946692911, 0, 0, 61564180 / 158732637, 22789713 / 633445777, 545815736 / 2771057229, -180193667 / 1043307555],
21 |
22 | [39632708 / 573591083, 0, 0, -433636366 / 683701615, -421739975 / 2616292301, 100302831 / 723423059, 790204164 / 839813087, 800635310 / 3783071287],
23 |
24 | [246121993 / 1340847787, 0, 0, -37695042795 / 15268766246, -309121744 / 1061227803, -12992083 / 490766935, 6005943493 / 2108947869, 393006217 / 1396673457, 123872331 / 1001029789],
25 |
26 | [-1028468189 / 846180014, 0, 0, 8478235783 / 508512852, 1311729495 / 1432422823, -10304129995 / 1701304382, -48777925059 / 3047939560, 15336726248 / 1032824649, -45442868181 / 3398467696, 3065993473 / 597172653],
27 |
28 | [185892177 / 718116043, 0, 0, -3185094517 / 667107341, -477755414 / 1098053517, -703635378 / 230739211, 5731566787 / 1027545527, 5232866602 / 850066563, -4093664535 / 808688257, 3962137247 / 1805957418, 65686358 / 487910083],
29 |
30 | [403863854 / 491063109, 0, 0, -5068492393 / 434740067, -411421997 / 543043805, 652783627 / 914296604, 11173962825 / 925320556, -13158990841 / 6184727034, 3936647629 / 1978049680, -160528059 / 685178525, 248638103 / 1413531060, 0],
31 |
32 | [14005451 / 335480064, 0, 0, 0, 0, -59238493 / 1068277825, 181606767 / 758867731, 561292985 / 797845732, -1041891430 / 1371343529, 760417239 / 1151165299, 118820643 / 751138087, -528747749 / 2220607170, 1 / 4]
33 | ]
34 |
35 | C_sol = [14005451 / 335480064, 0, 0, 0, 0, -59238493 / 1068277825, 181606767 / 758867731, 561292985 / 797845732, -1041891430 / 1371343529, 760417239 / 1151165299, 118820643 / 751138087, -528747749 / 2220607170, 1 / 4, 0]
36 |
37 | C_err = [14005451 / 335480064 - 13451932 / 455176623, 0, 0, 0, 0, -59238493 / 1068277825 - -808719846 / 976000145, 181606767 / 758867731 - 1757004468 / 5645159321, 561292985 / 797845732 - 656045339 / 265891186, -1041891430 / 1371343529 - -3867574721 / 1518517206, 760417239 / 1151165299 - 465885868 / 322736535, 118820643 / 751138087 - 53011238 / 667516719, -528747749 / 2220607170 - 2 / 45, 1 / 4, 0]
38 |
39 | h = 1 / 2
40 |
41 | C_mid = [0.] * 14
42 |
43 | C_mid[0] = (- 6.3448349392860401388 * (h**5) + 22.1396504998094068976 * (h**4) - 30.0610568289666450593 * (h**3) + 19.9990069333683970610 * (h**2) - 6.6910181737837595697 * h + 1.0) / (1 / h)
44 |
45 | C_mid[5] = (- 39.6107919852202505218 * (h**5) + 116.4422149550342161651 * (h**4) - 121.4999627731334642623 * (h**3) + 52.2273532792945524050 * (h**2) - 7.6142658045872677172 * h) / (1 / h)
46 |
47 | C_mid[6] = (20.3761213808791436958 * (h**5) - 67.1451318825957197185 * (h**4) + 83.1721004639847717481 * (h**3) - 46.8919164181093621583 * (h**2) + 10.7281392630428866124 * h) / (1 / h)
48 |
49 | C_mid[7] = (7.3347098826795362023 * (h**5) - 16.5672243527496524646 * (h**4) + 9.5724507555993664382 * (h**3) - 0.1890893225010595467 * (h**2) + 0.5526637063753648783 * h) / (1 / h)
50 |
51 | C_mid[8] = (32.8801774352459155182 * (h**5) - 89.9916014847245016028 * (h**4) + 87.8406057677205645007 * (h**3) - 35.7075975946222072821 * (h**2) + 4.2186562625665153803 * h) / (1 / h)
52 |
53 | C_mid[9] = (- 10.1588990526426760954 * (h**5) + 22.6237489648532849093 * (h**4) - 17.4152107770762969005 * (h**3) + 6.2736448083240352160 * (h**2) - 0.6627209125361597559 * h) / (1 / h)
54 |
55 | C_mid[10] = (- 12.5401268098782561200 * (h**5) + 32.2362340167355370113 * (h**4) - 28.5903289514790976966 * (h**3) + 10.3160881272450748458 * (h**2) - 1.2636789001135462218 * h) / (1 / h)
56 |
57 | C_mid[11] = (29.5553001484516038033 * (h**5) - 82.1020315488359848644 * (h**4) + 81.6630950584341412934 * (h**3) - 34.7650769866611817349 * (h**2) + 5.4106037898590422230 * h) / (1 / h)
58 |
59 | C_mid[12] = (- 41.7923486424390588923 * (h**5) + 116.2662185791119533462 * (h**4) - 114.9375291377009418170 * (h**3) + 47.7457971078225540396 * (h**2) - 7.0321379067945741781 * h) / (1 / h)
60 |
61 | C_mid[13] = (20.3006925822100825485 * (h**5) - 53.9020777466385396792 * (h**4) + 50.2558364226176017553 * (h**3) - 19.0082099341608028453 * (h**2) + 2.3537586759714983486 * h) / (1 / h)
62 |
63 |
64 | A = torch.tensor(A, dtype=torch.float64)
65 | B = [torch.tensor(B_, dtype=torch.float64) for B_ in B]
66 | C_sol = torch.tensor(C_sol, dtype=torch.float64)
67 | C_err = torch.tensor(C_err, dtype=torch.float64)
68 | _C_mid = torch.tensor(C_mid, dtype=torch.float64)
69 |
70 | _DOPRI8_TABLEAU = _ButcherTableau(alpha=A, beta=B, c_sol=C_sol, c_error=C_err)
71 |
72 |
73 | class Dopri8Solver(RKAdaptiveStepsizeODESolver):
74 | order = 8
75 | tableau = _DOPRI8_TABLEAU
76 | mid = _C_mid
77 |
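Each `C_mid[i]` line above has the form `p_i(h) / (1 / h)`, i.e. `h * p_i(h)` with `p_i` a quintic, evaluated at `h = 1/2`. A plausible reading, which is an assumption since the source does not say so, is that these are dense-output weights b_i(theta) evaluated at the step midpoint theta = 1/2. The check below supports that reading for `i = 0`: the same polynomial evaluated at theta = 1 reproduces `C_sol[0]`. `dense_weight` is a hypothetical helper that evaluates the polynomial with Horner's rule.

```python
# Coefficients of C_mid[0] above, highest power first (h**5 down to h**0).
C_MID0_POLY = [
    -6.3448349392860401388, 22.1396504998094068976, -30.0610568289666450593,
    19.9990069333683970610, -6.6910181737837595697, 1.0,
]


def dense_weight(poly_high_to_low, theta):
    """Evaluate theta * p(theta), with p given highest power first."""
    acc = 0.0
    for coeff in poly_high_to_low:
        acc = acc * theta + coeff  # Horner's rule
    return theta * acc  # same as dividing by (1 / theta)


print(dense_weight(C_MID0_POLY, 0.5))  # matches C_mid[0]
print(dense_weight(C_MID0_POLY, 1.0), 14005451 / 335480064)  # matches C_sol[0]
```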
--------------------------------------------------------------------------------
/torchdiffeq/_impl/event_handling.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 |
5 | def find_event(interp_fn, sign0, t0, t1, event_fn, tol):
6 | with torch.no_grad():
7 |
8 | # Number of bisection iterations needed to shrink the bracket [t0, t1] below `tol`.
9 | nitrs = torch.ceil(torch.log((t1 - t0) / tol) / math.log(2.0))
10 |
11 | for _ in range(nitrs.long()):
12 | t_mid = (t1 + t0) / 2.0
13 | y_mid = interp_fn(t_mid)
14 | sign_mid = torch.sign(event_fn(t_mid, y_mid))
15 | same_as_sign0 = (sign0 == sign_mid)
16 | t0 = torch.where(same_as_sign0, t_mid, t0)
17 | t1 = torch.where(same_as_sign0, t1, t_mid)
18 | event_t = (t0 + t1) / 2.0
19 |
20 | return event_t, interp_fn(event_t)
21 |
22 |
23 | def combine_event_functions(event_fn, t0, y0):
24 | """
25 | We ensure all event functions are initially positive,
26 | so then we can combine them by taking a min.
27 | """
28 | with torch.no_grad():
29 | initial_signs = torch.sign(event_fn(t0, y0))
30 |
31 | def combined_event_fn(t, y):
32 | c = event_fn(t, y)
33 | return torch.min(c * initial_signs)
34 |
35 | return combined_event_fn
36 |
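`find_event` locates the event by bisection: it halves the bracket `[t0, t1]` `ceil(log2((t1 - t0) / tol))` times, keeping the half across which the event function changes sign, so the returned time is within `tol / 2` of the sign change (assuming exactly one in the bracket). `combine_event_functions` flips each event function so that it starts positive and then takes the minimum, turning several events into one scalar function whose first zero is the first event to fire. The sketch below mirrors both in plain Python with scalars and tuples in place of tensors; it is an illustration, not the library code.

```python
import math


def sign(x):
    """-1, 0 or 1, like torch.sign for a scalar."""
    return (x > 0) - (x < 0)


def find_event(interp_fn, sign0, t0, t1, event_fn, tol):
    """Bisection on [t0, t1]; sign0 is the sign of the event function at t0."""
    nitrs = math.ceil(math.log((t1 - t0) / tol) / math.log(2.0))
    for _ in range(nitrs):
        t_mid = (t0 + t1) / 2.0
        sign_mid = sign(event_fn(t_mid, interp_fn(t_mid)))
        if sign_mid == sign0:
            t0 = t_mid  # no sign change yet: the event lies to the right
        else:
            t1 = t_mid  # sign changed: the event lies to the left (or at t_mid)
    event_t = (t0 + t1) / 2.0
    return event_t, interp_fn(event_t)


def combine_event_functions(event_fn, t0, y0):
    """Make every event function positive at t0, then combine them with min."""
    initial_signs = [sign(v) for v in event_fn(t0, y0)]

    def combined_event_fn(t, y):
        return min(v * s for v, s in zip(event_fn(t, y), initial_signs))

    return combined_event_fn


# Two events on y(t) = t**2: y reaching 0.3 from below, y reaching 0.8 from below.
interp = lambda t: t * t
combined = combine_event_functions(lambda t, y: (y - 0.3, 0.8 - y), 0.0, interp(0.0))
sign0 = sign(combined(0.0, interp(0.0)))
event_t, y_event = find_event(interp, sign0, 0.0, 1.0, combined, tol=1e-6)
print(event_t, y_event)  # first event: t close to sqrt(0.3)
```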
--------------------------------------------------------------------------------
/torchdiffeq/_impl/fehlberg2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
3 |
4 | _FEHLBERG2_TABLEAU = _ButcherTableau(
5 | alpha=torch.tensor([1 / 2, 1.0], dtype=torch.float64),
6 | beta=[
7 | torch.tensor([1 / 2], dtype=torch.float64),
8 | torch.tensor([1 / 256, 255 / 256], dtype=torch.float64),
9 | ],
10 | c_sol=torch.tensor([1 / 512, 255 / 256, 1 / 512], dtype=torch.float64),
11 | c_error=torch.tensor(
12 | [-1 / 512, 0, 1 / 512], dtype=torch.float64
13 | ),
14 | )
15 |
16 | _FE_C_MID = torch.tensor([0.0, 0.5, 0.0], dtype=torch.float64)
17 |
18 |
19 | class Fehlberg2(RKAdaptiveStepsizeODESolver):
20 | order = 2
21 | tableau = _FEHLBERG2_TABLEAU
22 | mid = _FE_C_MID
23 |
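In these tableaux `c_error` is written as the solution weights minus the embedded lower-order weights (for example `2 / 9 - 7 / 24` in `bosh3.py`), so the lower-order weights can be recovered as `c_sol - c_error`. For Fehlberg's 2(1) pair this gives `[1/256, 255/256, 0]`, which is exactly the last row of `beta`: the first-order solution is also the input to the final stage. The exact-arithmetic check below confirms that `c_sol` integrates linear functions exactly (order 2), while the recovered weights integrate only constants (order 1). `quadrature` is an illustrative helper.

```python
from fractions import Fraction as F

# Fehlberg 2(1) coefficients from the tensors above, as exact fractions.
ALPHA = [F(1, 2), F(1)]
BETA = [[F(1, 2)], [F(1, 256), F(255, 256)]]
C_SOL = [F(1, 512), F(255, 256), F(1, 512)]
C_ERROR = [F(-1, 512), F(0), F(1, 512)]

# c_error = c_sol - b_low, hence b_low = c_sol - c_error.
b_low = [s - e for s, e in zip(C_SOL, C_ERROR)]
nodes = [F(0)] + ALPHA


def quadrature(weights, k):
    """sum_i w_i * c_i**(k-1); equals 1/k when the weights are exact for degree k-1."""
    return sum(w * c ** (k - 1) for w, c in zip(weights, nodes))


print(b_low)
print([quadrature(C_SOL, k) for k in (1, 2)], [quadrature(b_low, k) for k in (1, 2)])
```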
--------------------------------------------------------------------------------
/torchdiffeq/_impl/fixed_adams.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import sys
3 | import torch
4 | import warnings
5 | from .solvers import FixedGridODESolver
6 | from .misc import _compute_error_ratio, _linf_norm
7 | from .misc import Perturb
8 | from .rk_common import rk4_alt_step_func
9 |
10 | _BASHFORTH_COEFFICIENTS = [
11 | [], # order 0
12 | [11],
13 | [3, -1],
14 | [23, -16, 5],
15 | [55, -59, 37, -9],
16 | [1901, -2774, 2616, -1274, 251],
17 | [4277, -7923, 9982, -7298, 2877, -475],
18 | [198721, -447288, 705549, -688256, 407139, -134472, 19087],
19 | [434241, -1152169, 2183877, -2664477, 2102243, -1041723, 295767, -36799],
20 | [14097247, -43125206, 95476786, -139855262, 137968480, -91172642, 38833486, -9664106, 1070017],
21 | [30277247, -104995189, 265932680, -454661776, 538363838, -444772162, 252618224, -94307320, 20884811, -2082753],
22 | [
23 | 2132509567, -8271795124, 23591063805, -46113029016, 63716378958, -63176201472, 44857168434, -22329634920,
24 | 7417904451, -1479574348, 134211265
25 | ],
26 | [
27 | 4527766399, -19433810163, 61633227185, -135579356757, 214139355366, -247741639374, 211103573298, -131365867290,
28 | 58189107627, -17410248271, 3158642445, -262747265
29 | ],
30 | [
31 | 13064406523627, -61497552797274, 214696591002612, -524924579905150, 932884546055895, -1233589244941764,
32 | 1226443086129408, -915883387152444, 507140369728425, -202322913738370, 55060974662412, -9160551085734,
33 | 703604254357
34 | ],
35 | [
36 | 27511554976875, -140970750679621, 537247052515662, -1445313351681906, 2854429571790805, -4246767353305755,
37 | 4825671323488452, -4204551925534524, 2793869602879077, -1393306307155755, 505586141196430, -126174972681906,
38 | 19382853593787, -1382741929621
39 | ],
40 | [
41 | 173233498598849, -960122866404112, 3966421670215481, -11643637530577472, 25298910337081429, -41825269932507728,
42 | 53471026659940509, -53246738660646912, 41280216336284259, -24704503655607728, 11205849753515179,
43 | -3728807256577472, 859236476684231, -122594813904112, 8164168737599
44 | ],
45 | [
46 | 362555126427073, -2161567671248849, 9622096909515337, -30607373860520569, 72558117072259733,
47 | -131963191940828581, 187463140112902893, -210020588912321949, 186087544263596643, -129930094104237331,
48 | 70724351582843483, -29417910911251819, 9038571752734087, -1934443196892599, 257650275915823, -16088129229375
49 | ],
50 | [
51 | 192996103681340479, -1231887339593444974, 5878428128276811750, -20141834622844109630, 51733880057282977010,
52 | -102651404730855807942, 160414858999474733422, -199694296833704562550, 199061418623907202560,
53 | -158848144481581407370, 100878076849144434322, -50353311405771659322, 19338911944324897550,
54 | -5518639984393844930, 1102560345141059610, -137692773163513234, 8092989203533249
55 | ],
56 | [
57 | 401972381695456831, -2735437642844079789, 13930159965811142228, -51150187791975812900, 141500575026572531760,
58 | -304188128232928718008, 518600355541383671092, -710171024091234303204, 786600875277595877750,
59 | -706174326992944287370, 512538584122114046748, -298477260353977522892, 137563142659866897224,
60 | -49070094880794267600, 13071639236569712860, -2448689255584545196, 287848942064256339, -15980174332775873
61 | ],
62 | [
63 | 333374427829017307697, -2409687649238345289684, 13044139139831833251471, -51099831122607588046344,
64 | 151474888613495715415020, -350702929608291455167896, 647758157491921902292692, -967713746544629658690408,
65 | 1179078743786280451953222, -1176161829956768365219840, 960377035444205950813626, -639182123082298748001432,
66 | 343690461612471516746028, -147118738993288163742312, 48988597853073465932820, -12236035290567356418552,
67 | 2157574942881818312049, -239560589366324764716, 12600467236042756559
68 | ],
69 | [
70 | 691668239157222107697, -5292843584961252933125, 30349492858024727686755, -126346544855927856134295,
71 | 399537307669842150996468, -991168450545135070835076, 1971629028083798845750380, -3191065388846318679544380,
72 | 4241614331208149947151790, -4654326468801478894406214, 4222756879776354065593786, -3161821089800186539248210,
73 | 1943018818982002395655620, -970350191086531368649620, 387739787034699092364924, -121059601023985433003532,
74 | 28462032496476316665705, -4740335757093710713245, 498669220956647866875, -24919383499187492303
75 | ],
76 | ]
77 |
78 | _MOULTON_COEFFICIENTS = [
79 | [], # order 0
80 | [1],
81 | [1, 1],
82 | [5, 8, -1],
83 | [9, 19, -5, 1],
84 | [251, 646, -264, 106, -19],
85 | [475, 1427, -798, 482, -173, 27],
86 | [19087, 65112, -46461, 37504, -20211, 6312, -863],
87 | [36799, 139849, -121797, 123133, -88547, 41499, -11351, 1375],
88 | [1070017, 4467094, -4604594, 5595358, -5033120, 3146338, -1291214, 312874, -33953],
89 | [2082753, 9449717, -11271304, 16002320, -17283646, 13510082, -7394032, 2687864, -583435, 57281],
90 | [
91 | 134211265, 656185652, -890175549, 1446205080, -1823311566, 1710774528, -1170597042, 567450984, -184776195,
92 | 36284876, -3250433
93 | ],
94 | [
95 | 262747265, 1374799219, -2092490673, 3828828885, -5519460582, 6043521486, -4963166514, 3007739418, -1305971115,
96 | 384709327, -68928781, 5675265
97 | ],
98 | [
99 | 703604254357, 3917551216986, -6616420957428, 13465774256510, -21847538039895, 27345870698436, -26204344465152,
100 | 19058185652796, -10344711794985, 4063327863170, -1092096992268, 179842822566, -13695779093
101 | ],
102 | [
103 | 1382741929621, 8153167962181, -15141235084110, 33928990133618, -61188680131285, 86180228689563, -94393338653892,
104 | 80101021029180, -52177910882661, 25620259777835, -9181635605134, 2268078814386, -345457086395, 24466579093
105 | ],
106 | [
107 | 8164168737599, 50770967534864, -102885148956217, 251724894607936, -499547203754837, 781911618071632,
108 | -963605400824733, 934600833490944, -710312834197347, 418551804601264, -187504936597931, 61759426692544,
109 | -14110480969927, 1998759236336, -132282840127
110 | ],
111 | [
112 | 16088129229375, 105145058757073, -230992163723849, 612744541065337, -1326978663058069, 2285168598349733,
113 | -3129453071993581, 3414941728852893, -2966365730265699, 2039345879546643, -1096355235402331, 451403108933483,
114 | -137515713789319, 29219384284087, -3867689367599, 240208245823
115 | ],
116 | [
117 | 8092989203533249, 55415287221275246, -131240807912923110, 375195469874202430, -880520318434977010,
118 | 1654462865819232198, -2492570347928318318, 3022404969160106870, -2953729295811279360, 2320851086013919370,
119 | -1455690451266780818, 719242466216944698, -273894214307914510, 77597639915764930, -15407325991235610,
120 | 1913813460537746, -111956703448001
121 | ],
122 | [
123 | 15980174332775873, 114329243705491117, -290470969929371220, 890337710266029860, -2250854333681641520,
124 | 4582441343348851896, -7532171919277411636, 10047287575124288740, -10910555637627652470, 9644799218032932490,
125 | -6913858539337636636, 3985516155854664396, -1821304040326216520, 645008976643217360, -170761422500096220,
126 | 31816981024600492, -3722582669836627, 205804074290625
127 | ],
128 | [
129 | 12600467236042756559, 93965550344204933076, -255007751875033918095, 834286388106402145800,
130 | -2260420115705863623660, 4956655592790542146968, -8827052559979384209108, 12845814402199484797800,
131 | -15345231910046032448070, 15072781455122686545920, -12155867625610599812538, 8008520809622324571288,
132 | -4269779992576330506540, 1814584564159445787240, -600505972582990474260, 149186846171741510136,
133 | -26182538841925312881, 2895045518506940460, -151711881512390095
134 | ],
135 | [
136 | 24919383499187492303, 193280569173472261637, -558160720115629395555, 1941395668950986461335,
137 | -5612131802364455926260, 13187185898439270330756, -25293146116627869170796, 39878419226784442421820,
138 | -51970649453670274135470, 56154678684618739939910, -50320851025594566473146, 37297227252822858381906,
139 | -22726350407538133839300, 11268210124987992327060, -4474886658024166985340, 1389665263296211699212,
140 | -325187970422032795497, 53935307402575440285, -5652892248087175675, 281550972898020815
141 | ],
142 | ]
143 |
144 | _DIVISOR = [
145 | None, 11, 2, 12, 24, 720, 1440, 60480, 120960, 3628800, 7257600, 479001600, 958003200, 2615348736000, 5230697472000,
146 | 31384184832000, 62768369664000, 32011868528640000, 64023737057280000, 51090942171709440000, 102181884343418880000
147 | ]
148 |
149 | _BASHFORTH_DIVISOR = [torch.tensor([b / divisor for b in bashforth], dtype=torch.float64)
150 | for bashforth, divisor in zip(_BASHFORTH_COEFFICIENTS, _DIVISOR)]
151 | _MOULTON_DIVISOR = [torch.tensor([m / divisor for m in moulton], dtype=torch.float64)
152 | for moulton, divisor in zip(_MOULTON_COEFFICIENTS, _DIVISOR)]
153 |
154 | _MIN_ORDER = 4
155 | _MAX_ORDER = 12
156 | _MAX_ITERS = 4
157 |
158 |
159 | # TODO: replace this with PyTorch operations (a little hard because y is a deque being used as a circular buffer)
160 | def _dot_product(x, y):
161 | return sum(xi * yi for xi, yi in zip(x, y))
162 |
163 |
164 | class AdamsBashforthMoulton(FixedGridODESolver):
165 | order = 4
166 |
167 | def __init__(self, func, y0, rtol=1e-3, atol=1e-4, implicit=True, max_iters=_MAX_ITERS, max_order=_MAX_ORDER,
168 | **kwargs):
169 | super(AdamsBashforthMoulton, self).__init__(func, y0, rtol=rtol, atol=atol, **kwargs)
170 | assert max_order <= _MAX_ORDER, "max_order must be at most {}".format(_MAX_ORDER)
171 | if max_order < _MIN_ORDER:
172 | warnings.warn("max_order is below {}, so the solver reduces to `rk4`.".format(_MIN_ORDER))
173 |
174 | self.rtol = torch.as_tensor(rtol, dtype=y0.dtype, device=y0.device)
175 | self.atol = torch.as_tensor(atol, dtype=y0.dtype, device=y0.device)
176 | self.implicit = implicit
177 | self.max_iters = max_iters
178 | self.max_order = int(max_order)
179 | self.prev_f = collections.deque(maxlen=self.max_order - 1)
180 | self.prev_t = None
181 |
182 | self.bashforth = [x.to(y0.device) for x in _BASHFORTH_DIVISOR]
183 | self.moulton = [x.to(y0.device) for x in _MOULTON_DIVISOR]
184 |
185 | def _update_history(self, t, f):
186 | if self.prev_t is None or self.prev_t != t:
187 | self.prev_f.appendleft(f)
188 | self.prev_t = t
189 |
190 | def _has_converged(self, y0, y1):
191 | """Checks that each element is within the error tolerance."""
192 | error_ratio = _compute_error_ratio(torch.abs(y0 - y1), self.rtol, self.atol, y0, y1, _linf_norm)
193 | return error_ratio < 1
194 |
195 | def _step_func(self, func, t0, dt, t1, y0):
196 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
197 | self._update_history(t0, f0)
198 | order = min(len(self.prev_f), self.max_order - 1)
199 | if order < _MIN_ORDER - 1:
200 | # Compute using RK4.
201 | return rk4_alt_step_func(func, t0, dt, t1, y0, f0=self.prev_f[0], perturb=self.perturb), f0
202 | else:
203 | # Adams-Bashforth predictor.
204 | bashforth_coeffs = self.bashforth[order]
205 | dy = _dot_product(dt * bashforth_coeffs, self.prev_f).type_as(y0) # bashforth is float64 so cast back
206 |
207 | # Adams-Moulton corrector.
208 | if self.implicit:
209 | moulton_coeffs = self.moulton[order + 1]
210 | delta = dt * _dot_product(moulton_coeffs[1:], self.prev_f).type_as(y0) # moulton is float64 so cast back
211 | converged = False
212 | for _ in range(self.max_iters):
213 | dy_old = dy
214 | f = func(t1, y0 + dy, perturb=Perturb.PREV if self.perturb else Perturb.NONE)
215 | dy = (dt * (moulton_coeffs[0]) * f).type_as(y0) + delta # moulton is float64 so cast back
216 | converged = self._has_converged(dy_old, dy)
217 | if converged:
218 | break
219 | if not converged:
220 | warnings.warn('Functional iteration did not converge. Solution may be incorrect.')
221 | self.prev_f.pop()
222 | self._update_history(t0, f)
223 | return dy, f0
224 |
225 |
226 | class AdamsBashforth(AdamsBashforthMoulton):
227 | def __init__(self, func, y0, **kwargs):
228 | super(AdamsBashforth, self).__init__(func, y0, implicit=False, **kwargs)
229 |
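The coefficient tables store integer numerators, with `_DIVISOR[k]` as the shared denominator for the order-k entries (12 for `[23, -16, 5]`, 24 for `[9, 19, -5, 1]`). When `_step_func` has three past derivatives (`order = 3`), it pairs the Adams-Bashforth predictor `bashforth[3]` with the Adams-Moulton corrector `moulton[4]` and applies the corrector as a fixed-point iteration. The sketch below reproduces that pairing for a scalar ODE in plain Python. Unlike the solver, it runs a fixed number of corrector iterations instead of checking `_has_converged`, and it seeds the history with exact values instead of RK4 steps; `abm_step` and `solve` are hypothetical names.

```python
import math
from collections import deque
from fractions import Fraction

# Numerators over _DIVISOR from the tables above.
AB3 = [Fraction(n, 12) for n in (23, -16, 5)]     # Adams-Bashforth, 3 steps
AM4 = [Fraction(n, 24) for n in (9, 19, -5, 1)]   # Adams-Moulton, 3 steps (order 4)


def abm_step(f, t0, y0, dt, history, iters=4):
    """history holds [f_n, f_{n-1}, f_{n-2}], newest first (like prev_f.appendleft)."""
    # Adams-Bashforth predictor.
    dy = dt * sum(float(c) * fi for c, fi in zip(AB3, history))
    # Adams-Moulton corrector: the part that does not depend on f_{n+1}.
    delta = dt * sum(float(c) * fi for c, fi in zip(AM4[1:], history))
    for _ in range(iters):
        dy = dt * float(AM4[0]) * f(t0 + dt, y0 + dy) + delta
    return y0 + dy


def solve(f, exact, dt, n_steps):
    start = [0.0, dt, 2 * dt]
    # Newest derivative first; maxlen drops the oldest entry on appendleft.
    history = deque((f(t, exact(t)) for t in reversed(start)), maxlen=3)
    t, y = start[-1], exact(start[-1])
    for _ in range(n_steps):
        y = abm_step(f, t, y, dt, history)
        t += dt
        history.appendleft(f(t, y))
    return t, y


t, y = solve(lambda t, y: -y, lambda t: math.exp(-t), dt=0.1, n_steps=8)
print(t, y, math.exp(-t))
```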
--------------------------------------------------------------------------------
/torchdiffeq/_impl/fixed_grid.py:
--------------------------------------------------------------------------------
1 | from .solvers import FixedGridODESolver
2 | from .rk_common import rk4_alt_step_func, rk3_step_func, rk2_step_func
3 | from .misc import Perturb
4 |
5 |
6 | class Euler(FixedGridODESolver):
7 | order = 1
8 |
9 | def _step_func(self, func, t0, dt, t1, y0):
10 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
11 | return dt * f0, f0
12 |
13 |
14 | class Midpoint(FixedGridODESolver):
15 | order = 2
16 |
17 | def _step_func(self, func, t0, dt, t1, y0):
18 | half_dt = 0.5 * dt
19 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
20 | y_mid = y0 + f0 * half_dt
21 | return dt * func(t0 + half_dt, y_mid), f0
22 |
23 |
24 | class RK4(FixedGridODESolver):
25 | order = 4
26 |
27 | def _step_func(self, func, t0, dt, t1, y0):
28 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
29 | return rk4_alt_step_func(func, t0, dt, t1, y0, f0=f0, perturb=self.perturb), f0
30 |
31 |
32 | class Heun3(FixedGridODESolver):
33 | order = 3
34 |
35 | def _step_func(self, func, t0, dt, t1, y0):
36 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
37 |
38 | butcher_tableu = [
39 | [0.0, 0.0, 0.0, 0.0],
40 | [1/3, 1/3, 0.0, 0.0],
41 | [2/3, 0.0, 2/3, 0.0],
42 | [0.0, 1/4, 0.0, 3/4],
43 | ]
44 |
45 | return rk3_step_func(func, t0, dt, t1, y0, butcher_tableu=butcher_tableu, f0=f0, perturb=self.perturb), f0
46 |
47 |
48 | class Heun2(FixedGridODESolver):
49 | order = 2
50 |
51 | def _step_func(self, func, t0, dt, t1, y0):
52 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
53 |
54 | butcher_tableu = [
55 | [0.0, 0.0, 0.0],
56 | [1.0, 1.0, 0.0],
57 | [0.0, 1/2, 1/2],
58 | ]
59 |
60 | return rk2_step_func(func, t0, dt, t1, y0, butcher_tableu=butcher_tableu, f0=f0, perturb=self.perturb), f0
61 |
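The `butcher_tableu` lists used by `Heun3` and `Heun2` appear to follow a compact layout: each stage row is `[c_i, a_i1, a_i2, ...]` and the last row is `[0, b_1, b_2, ...]`. That reading is an assumption inferred from the numbers (it reproduces Heun's second- and third-order methods); the helper `tableau_step` below is hypothetical and does not reproduce `rk2_step_func` or `rk3_step_func`. Halving the step should reduce the global error by roughly 2^order.

```python
import math

# Rows copied from Heun2/Heun3 above. Assumed layout: stage rows are
# [c_i, a_i1, a_i2, ...]; the final row is [0, b_1, b_2, ...].
HEUN2 = [
    [0.0, 0.0, 0.0],
    [1.0, 1.0, 0.0],
    [0.0, 1 / 2, 1 / 2],
]
HEUN3 = [
    [0.0, 0.0, 0.0, 0.0],
    [1 / 3, 1 / 3, 0.0, 0.0],
    [2 / 3, 0.0, 2 / 3, 0.0],
    [0.0, 1 / 4, 0.0, 3 / 4],
]


def tableau_step(f, t0, y0, dt, tableau):
    """One explicit RK step for a scalar ODE under the assumed row layout."""
    *stage_rows, weight_row = tableau
    k = []
    for row in stage_rows:
        c_i, a_i = row[0], row[1:]
        # zip stops at the stages computed so far (explicit method).
        y_stage = y0 + dt * sum(a * k_j for a, k_j in zip(a_i, k))
        k.append(f(t0 + c_i * dt, y_stage))
    return y0 + dt * sum(b * k_j for b, k_j in zip(weight_row[1:], k))


def global_error(tableau, dt, t_end=1.0):
    """Integrate y' = y, y(0) = 1 on [0, t_end]; return |y_N - exp(t_end)|."""
    n_steps = round(t_end / dt)
    y = 1.0
    for i in range(n_steps):
        y = tableau_step(lambda t, y: y, i * dt, y, dt, tableau)
    return abs(y - math.exp(t_end))


for name, tableau in [("Heun2", HEUN2), ("Heun3", HEUN3)]:
    ratio = global_error(tableau, 0.1) / global_error(tableau, 0.05)
    print(f"{name}: error ratio when halving dt = {ratio:.2f}")
```

The ratios come out at roughly 4 and 8, consistent with `order = 2` and `order = 3` on the classes above.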
--------------------------------------------------------------------------------
/torchdiffeq/_impl/fixed_grid_implicit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .rk_common import FixedGridFIRKODESolver, FixedGridDIRKODESolver
3 | from .rk_common import _ButcherTableau
4 |
5 | _sqrt_2 = torch.sqrt(torch.tensor(2, dtype=torch.float64)).item()
6 | _sqrt_3 = torch.sqrt(torch.tensor(3, dtype=torch.float64)).item()
7 | _sqrt_6 = torch.sqrt(torch.tensor(6, dtype=torch.float64)).item()
8 | _sqrt_15 = torch.sqrt(torch.tensor(15, dtype=torch.float64)).item()
9 |
10 | _IMPLICIT_EULER_TABLEAU = _ButcherTableau(
11 | alpha=torch.tensor([1], dtype=torch.float64),
12 | beta=[
13 | torch.tensor([1], dtype=torch.float64),
14 | ],
15 | c_sol=torch.tensor([1], dtype=torch.float64),
16 | c_error=torch.tensor([], dtype=torch.float64),
17 | )
18 |
19 | class ImplicitEuler(FixedGridFIRKODESolver):
20 | order = 1
21 | tableau = _IMPLICIT_EULER_TABLEAU
22 |
23 | _IMPLICIT_MIDPOINT_TABLEAU = _ButcherTableau(
24 | alpha=torch.tensor([1 / 2], dtype=torch.float64),
25 | beta=[
26 | torch.tensor([1 / 2], dtype=torch.float64),
27 |
28 | ],
29 | c_sol=torch.tensor([1], dtype=torch.float64),
30 | c_error=torch.tensor([], dtype=torch.float64),
31 | )
32 |
33 | class ImplicitMidpoint(FixedGridFIRKODESolver):
34 | order = 2
35 | tableau = _IMPLICIT_MIDPOINT_TABLEAU
36 |
37 | _GAUSS_LEGENDRE_4_TABLEAU = _ButcherTableau(
38 | alpha=torch.tensor([1 / 2 - _sqrt_3 / 6, 1 / 2 + _sqrt_3 / 6], dtype=torch.float64),
39 | beta=[
40 | torch.tensor([1 / 4, 1 / 4 - _sqrt_3 / 6], dtype=torch.float64),
41 | torch.tensor([1 / 4 + _sqrt_3 / 6, 1 / 4], dtype=torch.float64),
42 | ],
43 | c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64),
44 | c_error=torch.tensor([], dtype=torch.float64),
45 | )
46 |
47 | _TRAPEZOID_TABLEAU = _ButcherTableau(
48 | alpha=torch.tensor([0, 1], dtype=torch.float64),
49 | beta=[
50 | torch.tensor([0, 0], dtype=torch.float64),
51 | torch.tensor([1 / 2, 1 / 2], dtype=torch.float64),
52 | ],
53 | c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64),
54 | c_error=torch.tensor([], dtype=torch.float64),
55 | )
56 |
57 | class Trapezoid(FixedGridFIRKODESolver):
58 | order = 2
59 | tableau = _TRAPEZOID_TABLEAU
60 |
61 |
62 | class GaussLegendre4(FixedGridFIRKODESolver):
63 | order = 4
64 | tableau = _GAUSS_LEGENDRE_4_TABLEAU
65 |
66 | _GAUSS_LEGENDRE_6_TABLEAU = _ButcherTableau(
67 | alpha=torch.tensor([1 / 2 - _sqrt_15 / 10, 1 / 2, 1 / 2 + _sqrt_15 / 10], dtype=torch.float64),
68 | beta=[
69 | torch.tensor([5 / 36 , 2 / 9 - _sqrt_15 / 15, 5 / 36 - _sqrt_15 / 30], dtype=torch.float64),
70 | torch.tensor([5 / 36 + _sqrt_15 / 24, 2 / 9 , 5 / 36 - _sqrt_15 / 24], dtype=torch.float64),
71 | torch.tensor([5 / 36 + _sqrt_15 / 30, 2 / 9 + _sqrt_15 / 15, 5 / 36 ], dtype=torch.float64),
72 | ],
73 | c_sol=torch.tensor([5 / 18, 4 / 9, 5 / 18], dtype=torch.float64),
74 | c_error=torch.tensor([], dtype=torch.float64),
75 | )
76 |
77 | class GaussLegendre6(FixedGridFIRKODESolver):
78 | order = 6
79 | tableau = _GAUSS_LEGENDRE_6_TABLEAU
80 |
81 | _RADAU_IIA_3_TABLEAU = _ButcherTableau(
82 | alpha=torch.tensor([1 / 3, 1], dtype=torch.float64),
83 | beta=[
84 | torch.tensor([5 / 12, -1 / 12], dtype=torch.float64),
85 | torch.tensor([3 / 4, 1 / 4], dtype=torch.float64)
86 | ],
87 | c_sol=torch.tensor([3 / 4, 1 / 4], dtype=torch.float64),
88 | c_error=torch.tensor([], dtype=torch.float64)
89 | )
90 |
91 | class RadauIIA3(FixedGridFIRKODESolver):
92 | order = 3
93 | tableau = _RADAU_IIA_3_TABLEAU
94 |
95 | _RADAU_IIA_5_TABLEAU = _ButcherTableau(
96 | alpha=torch.tensor([2 / 5 - _sqrt_6 / 10, 2 / 5 + _sqrt_6 / 10, 1], dtype=torch.float64),
97 | beta=[
98 | torch.tensor([11 / 45 - 7 * _sqrt_6 / 360 , 37 / 225 - 169 * _sqrt_6 / 1800, -2 / 225 + _sqrt_6 / 75], dtype=torch.float64),
99 | torch.tensor([37 / 225 + 169 * _sqrt_6 / 1800, 11 / 45 + 7 * _sqrt_6 / 360 , -2 / 225 - _sqrt_6 / 75], dtype=torch.float64),
100 | torch.tensor([4 / 9 - _sqrt_6 / 36 , 4 / 9 + _sqrt_6 / 36 , 1 / 9], dtype=torch.float64)
101 | ],
102 | c_sol=torch.tensor([4 / 9 - _sqrt_6 / 36, 4 / 9 + _sqrt_6 / 36, 1 / 9], dtype=torch.float64),
103 | c_error=torch.tensor([], dtype=torch.float64)
104 | )
105 |
106 | class RadauIIA5(FixedGridFIRKODESolver):
107 | order = 5
108 | tableau = _RADAU_IIA_5_TABLEAU
109 |
110 | gamma = (2. - _sqrt_2) / 2.
111 | _SDIRK_2_TABLEAU = _ButcherTableau(
112 | alpha = torch.tensor([gamma, 1], dtype=torch.float64),
113 | beta=[
114 | torch.tensor([gamma], dtype=torch.float64),
115 | torch.tensor([1 - gamma, gamma], dtype=torch.float64),
116 | ],
117 | c_sol=torch.tensor([1 - gamma, gamma], dtype=torch.float64),
118 | c_error=torch.tensor([], dtype=torch.float64)
119 | )
120 |
121 | class SDIRK2(FixedGridDIRKODESolver):
122 | order = 2
123 | tableau = _SDIRK_2_TABLEAU
124 |
125 | gamma = 1. - _sqrt_2 / 2.
126 | beta = _sqrt_2 / 4.
127 | _TRBDF_2_TABLEAU = _ButcherTableau(
128 | alpha = torch.tensor([0, 2 * gamma, 1], dtype=torch.float64),
129 | beta=[
130 | torch.tensor([0], dtype=torch.float64),
131 | torch.tensor([gamma, gamma], dtype=torch.float64),
132 | torch.tensor([beta, beta, gamma], dtype=torch.float64),
133 | ],
134 | c_sol=torch.tensor([beta, beta, gamma], dtype=torch.float64),
135 | c_error=torch.tensor([], dtype=torch.float64)
136 | )
137 |
138 | class TRBDF2(FixedGridDIRKODESolver):
139 | order = 2
140 | tableau = _TRBDF_2_TABLEAU
141 |
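The tableaux above can be partially sanity-checked without running a solver. The sketch below transcribes two of them to plain Python floats and checks two necessary conditions: each node in `alpha` equals the sum of its row in `beta`, and the weights `c_sol` integrate monomials exactly, i.e. sum_i b_i * c_i**(k - 1) == 1 / k. The helper names `row_sums_match` and `quadrature_order` are hypothetical, not torchdiffeq functions. The quadrature conditions are necessary but not sufficient for order p, so treat this as a smoke test rather than a full order verification.

```python
import math

SQRT_3 = math.sqrt(3.0)

# Two-stage Gauss-Legendre tableau (GL4), transcribed to plain floats.
GL4_ALPHA = [1 / 2 - SQRT_3 / 6, 1 / 2 + SQRT_3 / 6]  # nodes c_i
GL4_BETA = [
    [1 / 4, 1 / 4 - SQRT_3 / 6],  # row i of the stage matrix A
    [1 / 4 + SQRT_3 / 6, 1 / 4],
]
GL4_C_SOL = [1 / 2, 1 / 2]  # weights b_i

# Two-stage Radau IIA tableau, for comparison.
RADAU3_ALPHA = [1 / 3, 1.0]
RADAU3_BETA = [[5 / 12, -1 / 12], [3 / 4, 1 / 4]]
RADAU3_C_SOL = [3 / 4, 1 / 4]


def row_sums_match(alpha, beta, tol=1e-12):
    """Check the consistency condition c_i == sum_j a_ij for every stage."""
    return all(abs(sum(row) - node) < tol for node, row in zip(alpha, beta))


def quadrature_order(alpha, c_sol, max_order=10, tol=1e-12):
    """Largest p with sum_i b_i * c_i**(k - 1) == 1 / k for all k = 1..p."""
    p = 0
    for k in range(1, max_order + 1):
        moment = sum(b * node ** (k - 1) for b, node in zip(c_sol, alpha))
        if abs(moment - 1 / k) > tol:
            break
        p = k
    return p
```

For the GL4 tableau the quadrature conditions hold up to k = 4, consistent with `order = 4`; for the Radau IIA tableau they hold up to k = 3, consistent with `order = 3`. A node vector with both entries equal to `1 / 2 - _sqrt_3 / 6` fails the row-sum check.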
--------------------------------------------------------------------------------
/torchdiffeq/_impl/interp.py:
--------------------------------------------------------------------------------
1 | def _interp_fit(y0, y1, y_mid, f0, f1, dt):
2 | """Fit coefficients for 4th order polynomial interpolation.
3 |
4 | Args:
5 | y0: function value at the start of the interval.
6 | y1: function value at the end of the interval.
7 | y_mid: function value at the mid-point of the interval.
8 | f0: derivative value at the start of the interval.
9 | f1: derivative value at the end of the interval.
10 | dt: width of the interval.
11 |
12 | Returns:
13 | List of coefficients `[e, d, c, b, a]` (lowest degree first) for interpolating with the polynomial
14 | `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
15 | between 0 (start of interval) and 1 (end of interval).
16 | """
17 | a = 2 * dt * (f1 - f0) - 8 * (y1 + y0) + 16 * y_mid
18 | b = dt * (5 * f0 - 3 * f1) + 18 * y0 + 14 * y1 - 32 * y_mid
19 | c = dt * (f1 - 4 * f0) - 11 * y0 - 5 * y1 + 16 * y_mid
20 | d = dt * f0
21 | e = y0
22 | return [e, d, c, b, a]
23 |
24 |
25 | def _interp_evaluate(coefficients, t0, t1, t):
26 | """Evaluate polynomial interpolation at the given time point.
27 |
28 | Args:
29 | coefficients: list of Tensor coefficients as created by `_interp_fit`, lowest degree first.
30 | t0: scalar float64 Tensor giving the start of the interval.
31 | t1: scalar float64 Tensor giving the end of the interval.
32 | t: scalar float64 Tensor giving the desired interpolation point.
33 |
34 | Returns:
35 | Polynomial interpolation of the coefficients at time `t`.
36 | """
37 |
38 | assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)
39 | x = (t - t0) / (t1 - t0)
40 | x = x.to(coefficients[0].dtype)
41 |
42 | total = coefficients[0] + x * coefficients[1]
43 | x_power = x
44 | for coefficient in coefficients[2:]:
45 | x_power = x_power * x
46 | total = total + x_power * coefficient
47 |
48 | return total
49 |
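`_interp_fit` builds the unique quartic in x on [0, 1] that satisfies five conditions: p(0) = y0, p(1) = y1, p(1/2) = y_mid, p'(0) = dt * f0 and p'(1) = dt * f1 (the factor dt appears because x = (t - t0) / dt). The sketch below re-types the same formulas on plain floats and evaluates the ascending-order coefficient list with Horner's rule, which is equivalent to the power loop in `_interp_evaluate`. The names `interp_fit`, `evaluate` and `derivative` are illustrative, not library functions. Because the five conditions determine a quartic uniquely, a trajectory that is itself quartic in t is reproduced exactly up to rounding.

```python
def interp_fit(y0, y1, y_mid, f0, f1, dt):
    """Same formulas as _interp_fit, on floats; returns [e, d, c, b, a]."""
    a = 2 * dt * (f1 - f0) - 8 * (y1 + y0) + 16 * y_mid
    b = dt * (5 * f0 - 3 * f1) + 18 * y0 + 14 * y1 - 32 * y_mid
    c = dt * (f1 - 4 * f0) - 11 * y0 - 5 * y1 + 16 * y_mid
    d = dt * f0
    e = y0
    return [e, d, c, b, a]


def evaluate(coefficients, x):
    """Horner's rule over ascending-order coefficients, for x in [0, 1]."""
    total = 0.0
    for coefficient in reversed(coefficients):
        total = total * x + coefficient
    return total


def derivative(coefficients, x):
    """dp/dx at x; equals dt * dy/dt at the corresponding time."""
    return sum(k * coef * x ** (k - 1) for k, coef in enumerate(coefficients) if k > 0)
```

For example, with y(t) = t**4 - 3 * t**2 + t on [1, 3] (so dt = 2), `evaluate(coeffs, x)` agrees with y(1 + 2 * x) across [0, 1], and `derivative` returns 2 * y'(t) at both endpoints.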
--------------------------------------------------------------------------------
/torchdiffeq/_impl/misc.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | import math
3 | import numpy as np
4 | import torch
5 | import warnings
6 | from .event_handling import combine_event_functions
7 |
8 |
9 | _all_callback_names = ['callback_step', 'callback_accept_step', 'callback_reject_step']
10 | _all_adjoint_callback_names = [name + '_adjoint' for name in _all_callback_names]
11 | _null_callback = lambda *args, **kwargs: None
12 |
13 | def _handle_unused_kwargs(solver, unused_kwargs):
14 | if len(unused_kwargs) > 0:
15 | warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs))
16 |
17 |
18 | def _linf_norm(tensor):
19 | return tensor.abs().max()
20 |
21 |
22 | def _rms_norm(tensor):
23 | return tensor.abs().pow(2).mean().sqrt()
24 |
25 |
26 | def _zero_norm(tensor):
27 | return 0.
28 |
29 |
30 | def _mixed_norm(tensor_tuple):
31 | if len(tensor_tuple) == 0:
32 | return 0.
33 | return max([_rms_norm(tensor) for tensor in tensor_tuple])
34 |
35 |
36 | def _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0=None):
37 | """Empirically select a good initial step.
38 |
39 | The algorithm is described in [1]_.
40 |
41 | References
42 | ----------
43 | .. [1] E. Hairer, S. P. Norsett, G. Wanner, "Solving Ordinary Differential
44 | Equations I: Nonstiff Problems", Sec. II.4, 2nd edition.
45 | """
46 |
47 | dtype = y0.dtype
48 | device = y0.device
49 | t_dtype = t0.dtype
50 | t0 = t0.to(t_dtype)
51 |
52 | if f0 is None:
53 | f0 = func(t0, y0)
54 |
55 | scale = atol + torch.abs(y0) * rtol
56 |
57 | d0 = norm(y0 / scale).abs()
58 | d1 = norm(f0 / scale).abs()
59 |
60 | if d0 < 1e-5 or d1 < 1e-5:
61 | h0 = torch.tensor(1e-6, dtype=dtype, device=device)
62 | else:
63 | h0 = 0.01 * d0 / d1
64 | h0 = h0.abs()
65 |
66 | y1 = y0 + h0 * f0
67 | f1 = func(t0 + h0, y1)
68 |
69 | d2 = torch.abs(norm((f1 - f0) / scale) / h0)
70 |
71 | if d1 <= 1e-15 and d2 <= 1e-15:
72 | h1 = torch.max(torch.tensor(1e-6, dtype=dtype, device=device), h0 * 1e-3)
73 | else:
74 | h1 = (0.01 / max(d1, d2)) ** (1. / float(order + 1))
75 | h1 = h1.abs()
76 |
77 | return torch.min(100 * h0, h1).to(t_dtype)
78 |
79 |
80 | def _compute_error_ratio(error_estimate, rtol, atol, y0, y1, norm):
81 | error_tol = atol + rtol * torch.max(y0.abs(), y1.abs())
82 | return norm(error_estimate / error_tol).abs()
83 |
84 |
85 | @torch.no_grad()
86 | def _optimal_step_size(last_step, error_ratio, safety, ifactor, dfactor, order):
87 | """Calculate the optimal size for the next step."""
88 | if error_ratio == 0:
89 | return last_step * ifactor
90 | if error_ratio < 1:
91 | dfactor = torch.ones((), dtype=last_step.dtype, device=last_step.device)
92 | error_ratio = error_ratio.type_as(last_step)
93 | exponent = torch.tensor(order, dtype=last_step.dtype, device=last_step.device).reciprocal()
94 | factor = torch.min(ifactor, torch.max(safety / error_ratio ** exponent, dfactor))
95 | return last_step * factor
96 |
97 |
98 | def _decreasing(t):
99 | return (t[1:] < t[:-1]).all()
100 |
101 |
102 | def _assert_one_dimensional(name, t):
103 | assert t.ndimension() == 1, "{} must be one dimensional".format(name)
104 |
105 |
106 | def _assert_increasing(name, t):
107 | assert (t[1:] > t[:-1]).all(), '{} must be strictly increasing or decreasing'.format(name)
108 |
109 |
110 | def _assert_floating(name, t):
111 | if not torch.is_floating_point(t):
112 | raise TypeError('`{}` must be a floating point Tensor but is a {}'.format(name, t.type()))
113 |
114 |
115 | def _tuple_tol(name, tol, shapes):
116 | try:
117 | iter(tol)
118 | except TypeError:
119 | return tol
120 | tol = tuple(tol)
121 | assert len(tol) == len(shapes), "If using tupled {} it must have the same length as the tuple y0".format(name)
122 | tol = [torch.as_tensor(tol_).expand(shape.numel()) for tol_, shape in zip(tol, shapes)]
123 | return torch.cat(tol)
124 |
125 |
126 | def _flat_to_shape(tensor, length, shapes):
127 | tensor_list = []
128 | total = 0
129 | for shape in shapes:
130 | next_total = total + shape.numel()
131 | # It's important that this be view((...)), not view(...). Else when length=(), shape=() it fails.
132 | tensor_list.append(tensor[..., total:next_total].view((*length, *shape)))
133 | total = next_total
134 | return tuple(tensor_list)
135 |
136 |
137 | class _TupleFunc(torch.nn.Module):
138 | def __init__(self, base_func, shapes):
139 | super(_TupleFunc, self).__init__()
140 | self.base_func = base_func
141 | self.shapes = shapes
142 |
143 | def forward(self, t, y):
144 | f = self.base_func(t, _flat_to_shape(y, (), self.shapes))
145 | return torch.cat([f_.reshape(-1) for f_ in f])
146 |
147 |
148 | class _TupleInputOnlyFunc(torch.nn.Module):
149 | def __init__(self, base_func, shapes):
150 | super(_TupleInputOnlyFunc, self).__init__()
151 | self.base_func = base_func
152 | self.shapes = shapes
153 |
154 | def forward(self, t, y):
155 | return self.base_func(t, _flat_to_shape(y, (), self.shapes))
156 |
157 |
158 | class _ReverseFunc(torch.nn.Module):
159 | def __init__(self, base_func, mul=1.0):
160 | super(_ReverseFunc, self).__init__()
161 | self.base_func = base_func
162 | self.mul = mul
163 |
164 | def forward(self, t, y):
165 | return self.mul * self.base_func(-t, y)
166 |
167 |
168 | class Perturb(Enum):
169 | NONE = 0
170 | PREV = 1
171 | NEXT = 2
172 |
173 |
174 | class _PerturbFunc(torch.nn.Module):
175 |
176 | def __init__(self, base_func):
177 | super(_PerturbFunc, self).__init__()
178 | self.base_func = base_func
179 |
180 | def forward(self, t, y, *, perturb=Perturb.NONE):
181 | assert isinstance(perturb, Perturb), "perturb argument must be of type Perturb enum"
182 | # This dtype change here might be buggy.
183 | # The exact time value should be determined inside the solver,
184 | # but this can slightly change it due to numerical differences during casting.
185 | if torch.is_complex(t):
186 | t = t.real
187 | t = t.to(y.abs().dtype)
188 | if perturb is Perturb.NEXT:
189 | # Replace with the smallest representable value greater than t.
190 | t = _nextafter(t, t + 1)
191 | elif perturb is Perturb.PREV:
192 | # Replace with the largest representable value less than t.
193 | t = _nextafter(t, t - 1)
194 | else:
195 | # Do nothing.
196 | pass
197 | return self.base_func(t, y)
198 |
199 |
200 | def _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS):
201 |
202 | if event_fn is not None:
203 | if len(t) != 2:
204 | raise ValueError(f"We require len(t) == 2 when in event handling mode, but got len(t)={len(t)}.")
205 |
206 | # Combine event functions if the output is multivariate.
207 | event_fn = combine_event_functions(event_fn, t[0], y0)
208 |
209 | # Keep reference to original func as passed in
210 | original_func = func
211 |
212 | # Normalise to tensor (non-tupled) input
213 | shapes = None
214 | is_tuple = not isinstance(y0, torch.Tensor)
215 | if is_tuple:
216 | assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
217 | shapes = [y0_.shape for y0_ in y0]
218 | rtol = _tuple_tol('rtol', rtol, shapes)
219 | atol = _tuple_tol('atol', atol, shapes)
220 | y0 = torch.cat([y0_.reshape(-1) for y0_ in y0])
221 | func = _TupleFunc(func, shapes)
222 | if event_fn is not None:
223 | event_fn = _TupleInputOnlyFunc(event_fn, shapes)
224 |
225 | # Normalise method and options
226 | if options is None:
227 | options = {}
228 | else:
229 | options = options.copy()
230 | if method is None:
231 | method = 'dopri5'
232 | if method not in SOLVERS:
233 | raise ValueError('Invalid method "{}". Must be one of {}'.format(method,
234 | '{"' + '", "'.join(SOLVERS.keys()) + '"}.'))
235 |
236 | if is_tuple:
237 | # We accept tupled input. This is an abstraction that is hidden from the rest of odeint (except when
238 | # returning values), so here we need to maintain the abstraction by wrapping norm functions.
239 |
240 | if 'norm' in options:
241 | # If the user passed a norm then get that...
242 | norm = options['norm']
243 | else:
244 | # ...otherwise we default to a mixed norm over tupled input: the max (Linf) of the per-tensor RMS norms.
245 | norm = _mixed_norm
246 |
247 | # In either case, norm(...) is assumed to take a tuple of tensors as input. (As that's what the state looks
248 | # like from the point of view of the user.)
249 | # So here we take the tensor that the machinery of odeint has given us, and turn it into the tuple that the
250 | # norm function is expecting.
251 | def _norm(tensor):
252 | y = _flat_to_shape(tensor, (), shapes)
253 | return norm(y)
254 | options['norm'] = _norm
255 |
256 | else:
257 | if 'norm' in options:
258 | # No need to change the norm function.
259 | pass
260 | else:
261 | # Else just use the default norm.
262 | # Technically we don't need to set that here (RKAdaptiveStepsizeODESolver has it as a default), but it
263 | # makes it easier to reason about, in the adjoint norm logic, if we know that options['norm'] is
264 | # definitely set to something.
265 | options['norm'] = _rms_norm
266 |
267 | # Normalise time
268 | _check_timelike('t', t, True)
269 | t_is_reversed = False
270 | if len(t) > 1 and t[0] > t[1]:
271 | t_is_reversed = True
272 |
273 | if t_is_reversed:
274 | # Change the integration times to ascending order.
275 | # We do this by negating the time values and all associated arguments.
276 | t = -t
277 |
278 | # Ensure time values are un-negated when calling functions.
279 | func = _ReverseFunc(func, mul=-1.0)
280 | if event_fn is not None:
281 | event_fn = _ReverseFunc(event_fn)
282 |
283 | # For fixed step solvers.
284 | try:
285 | _grid_constructor = options['grid_constructor']
286 | except KeyError:
287 | pass
288 | else:
289 | options['grid_constructor'] = lambda func, y0, t: -_grid_constructor(func, y0, -t)
290 |
291 | # For RK solvers.
292 | _flip_option(options, 'step_t')
293 | _flip_option(options, 'jump_t')
294 |
295 | # Can only do after having normalised time
296 | _assert_increasing('t', t)
297 |
298 | # Tol checking
299 | if torch.is_tensor(rtol):
300 | assert not rtol.requires_grad, "rtol cannot require gradient"
301 | if torch.is_tensor(atol):
302 | assert not atol.requires_grad, "atol cannot require gradient"
303 |
304 | # Backward compatibility: Allow t and y0 to be on different devices
305 | if t.device != y0.device:
306 | warnings.warn("t is not on the same device as y0. Coercing to y0.device.")
307 | t = t.to(y0.device)
308 | # ~Backward compatibility
309 |
310 | # Add perturb argument to func.
311 | func = _PerturbFunc(func)
312 |
313 | # Add callbacks to wrapped_func
314 | callback_names = set()
315 | for callback_name in _all_callback_names:
316 | try:
317 | callback = getattr(original_func, callback_name)
318 | except AttributeError:
319 | setattr(func, callback_name, _null_callback)
320 | else:
321 | if callback is not _null_callback:
322 | callback_names.add(callback_name)
323 | # At the moment all callbacks have the arguments (t0, y0, dt).
324 | # These will need adjusting on a per-callback basis if that changes in the future.
325 | if is_tuple:
326 | def callback(t0, y0, dt, _callback=callback):
327 | y0 = _flat_to_shape(y0, (), shapes)
328 | return _callback(t0, y0, dt)
329 | if t_is_reversed:
330 | def callback(t0, y0, dt, _callback=callback):
331 | return _callback(-t0, y0, dt)
332 | setattr(func, callback_name, callback)
333 | for callback_name in _all_adjoint_callback_names:
334 | try:
335 | callback = getattr(original_func, callback_name)
336 | except AttributeError:
337 | pass
338 | else:
339 | setattr(func, callback_name, callback)
340 |
341 | invalid_callbacks = callback_names - SOLVERS[method].valid_callbacks()
342 | if len(invalid_callbacks) > 0:
343 | warnings.warn("Solver '{}' does not support callbacks {}".format(method, invalid_callbacks))
344 |
345 | return shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed
346 |
347 |
348 | class _StitchGradient(torch.autograd.Function):
349 | @staticmethod
350 | def forward(ctx, x1, out):
351 | return out
352 |
353 | @staticmethod
354 | def backward(ctx, grad_out):
355 | return grad_out, None
356 |
357 |
358 | def _nextafter(x1, x2):
359 | with torch.no_grad():
360 | if hasattr(torch, "nextafter"):
361 | out = torch.nextafter(x1, x2)
362 | else:
363 | out = np_nextafter(x1, x2)
364 | return _StitchGradient.apply(x1, out)
365 |
366 |
367 | def np_nextafter(x1, x2):
368 | warnings.warn("torch.nextafter is only available in PyTorch 1.7 or newer. "
369 | "Falling back to numpy.nextafter. Upgrade PyTorch to remove this warning.")
370 | x1_np = x1.detach().cpu().numpy()
371 | x2_np = x2.detach().cpu().numpy()
372 | out = torch.tensor(np.nextafter(x1_np, x2_np)).to(x1)
373 | return out
374 |
375 |
376 | def _check_timelike(name, timelike, can_grad):
377 | assert isinstance(timelike, torch.Tensor), '{} must be a torch.Tensor'.format(name)
378 | _assert_floating(name, timelike)
379 | assert timelike.ndimension() == 1, "{} must be one dimensional".format(name)
380 | if not can_grad:
381 | assert not timelike.requires_grad, "{} cannot require gradient".format(name)
382 | diff = timelike[1:] > timelike[:-1]
383 | assert diff.all() or (~diff).all(), '{} must be strictly increasing or decreasing'.format(name)
384 |
385 |
386 | def _flip_option(options, option_name):
387 | try:
388 | option_value = options[option_name]
389 | except KeyError:
390 | pass
391 | else:
392 | if isinstance(option_value, torch.Tensor):
393 | options[option_name] = -option_value
394 | # else: an error will be raised when the option is attempted to be used in Solver.__init__, but we defer raising
395 | # the error until then to keep things tidy.
396 |
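`_optimal_step_size` scales the last step by safety / error_ratio**(1 / order), clamped to the interval [dfactor, ifactor]. When the step was accepted (error_ratio < 1), the lower clamp is raised to 1, so an accepted step never causes the next one to shrink. Below is a scalar sketch of the same rule on Python floats; `optimal_step_size` is a hypothetical name, and the default values of `safety`, `ifactor` and `dfactor` are assumptions chosen for illustration, since the real defaults are supplied by the calling solver, which is not part of this file.

```python
def optimal_step_size(last_step, error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
    """Scalar sketch of the update rule in _optimal_step_size (illustrative defaults)."""
    if error_ratio == 0:
        # No measurable error: grow by the maximum allowed factor.
        return last_step * ifactor
    if error_ratio < 1:
        # Accepted step: raise the lower clamp to 1 so the step does not shrink.
        dfactor = 1.0
    exponent = 1.0 / order
    factor = min(ifactor, max(safety / error_ratio ** exponent, dfactor))
    return last_step * factor
```

With these values, a badly rejected step (error_ratio = 1e6) is cut to dfactor * last_step, error_ratio = 0.99 leaves the step unchanged, and a tiny error ratio grows the step by at most ifactor.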
--------------------------------------------------------------------------------
/torchdiffeq/_impl/odeint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd.functional import vjp
3 | from .dopri5 import Dopri5Solver
4 | from .bosh3 import Bosh3Solver
5 | from .adaptive_heun import AdaptiveHeunSolver
6 | from .fehlberg2 import Fehlberg2
7 | from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4
8 | from .fixed_grid_implicit import ImplicitEuler, ImplicitMidpoint, Trapezoid
9 | from .fixed_grid_implicit import GaussLegendre4, GaussLegendre6
10 | from .fixed_grid_implicit import RadauIIA3, RadauIIA5
11 | from .fixed_grid_implicit import SDIRK2, TRBDF2
12 | from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton
13 | from .dopri8 import Dopri8Solver
14 | from .tsit5 import Tsit5Solver
15 | from .scipy_wrapper import ScipyWrapperODESolver
16 | from .misc import _check_inputs, _flat_to_shape
17 | from .interp import _interp_evaluate
18 |
19 | SOLVERS = {
20 | 'dopri8': Dopri8Solver,
21 | 'dopri5': Dopri5Solver,
22 | 'tsit5': Tsit5Solver,
23 | 'bosh3': Bosh3Solver,
24 | 'fehlberg2': Fehlberg2,
25 | 'adaptive_heun': AdaptiveHeunSolver,
26 | 'euler': Euler,
27 | 'midpoint': Midpoint,
28 | 'heun2': Heun2,
29 | 'heun3': Heun3,
30 | 'rk4': RK4,
31 | 'explicit_adams': AdamsBashforth,
32 | 'implicit_adams': AdamsBashforthMoulton,
33 | 'implicit_euler': ImplicitEuler,
34 | 'implicit_midpoint': ImplicitMidpoint,
35 | 'trapezoid': Trapezoid,
36 | 'radauIIA3': RadauIIA3,
37 | 'gl4': GaussLegendre4,
38 | 'radauIIA5': RadauIIA5,
39 | 'gl6': GaussLegendre6,
40 | 'sdirk2': SDIRK2,
41 | 'trbdf2': TRBDF2,
42 | # Backward compatibility: use the same name as before
43 | 'fixed_adams': AdamsBashforthMoulton,
44 | # ~Backward compatibility
45 | 'scipy_solver': ScipyWrapperODESolver,
46 | }
47 |
48 |
49 | def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None):
50 | """Integrate a system of ordinary differential equations.
51 |
52 | Solves the initial value problem for a system of first order ODEs (the default `dopri5` method targets non-stiff problems):
53 | ```
54 | dy/dt = func(t, y), y(t[0]) = y0
55 | ```
56 | where y is a Tensor or tuple of Tensors of any shape.
57 |
58 | Output dtypes and numerical precision are based on the dtype of the input `y0`.
59 |
60 | Args:
61 | func: Function that maps a scalar Tensor `t` and a Tensor holding the state `y`
62 | into a Tensor of state derivatives with respect to time. Optionally, `y`
63 | can also be a tuple of Tensors.
64 | y0: N-D Tensor giving starting value of `y` at time point `t[0]`. Optionally, `y0`
65 | can also be a tuple of Tensors.
66 | t: 1-D Tensor holding a sequence of time points for which to solve for
67 | `y`, in either increasing or decreasing order. The first element of
68 | this sequence is taken to be the initial time point.
69 | rtol: optional float64 Tensor specifying an upper bound on relative error,
70 | per element of `y`.
71 | atol: optional float64 Tensor specifying an upper bound on absolute error,
72 | per element of `y`.
73 | method: optional string indicating the integration method to use.
74 | options: optional dict of configuring options for the indicated integration
75 | method. If `method` is not set, the options are passed to the default `dopri5` solver.
76 | event_fn: Function that maps the time `t` and state `y` to a Tensor. The solve
77 | terminates when event_fn evaluates to zero. If this is not None, `t` must contain
78 | exactly two elements, and all but the first element are ignored.
79 |
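80 | Returns:
81 | y: Tensor (or tuple of Tensors, matching `y0`), where the first dimension corresponds to
82 | different time points. Contains the solved value of y for each desired time point in
83 | `t`, with the initial value `y0` being the first element along the first dimension.
84 | If `event_fn` is given, returns `(event_t, y)`, where `y` holds the states at `t[0]` and `event_t`.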
85 |
86 | Raises:
87 | ValueError: if an invalid `method` is provided, or if `event_fn` is given and `len(t) != 2`.
88 | """
89 |
90 | shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
91 |
92 | solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options)
93 |
94 | if event_fn is None:
95 | solution = solver.integrate(t)
96 | else:
97 | event_t, solution = solver.integrate_until_event(t[0], event_fn)
98 | event_t = event_t.to(t)
99 | if t_is_reversed:
100 | event_t = -event_t
101 |
102 | if shapes is not None:
103 | solution = _flat_to_shape(solution, (len(t),), shapes)
104 |
105 | if event_fn is None:
106 | return solution
107 | else:
108 | return event_t, solution
109 |
110 |
111 | def odeint_dense(func, y0, t0, t1, *, rtol=1e-7, atol=1e-9, method=None, options=None):
112 |
113 | assert torch.is_tensor(y0) # TODO: handle tuple of tensors
114 |
115 | t = torch.tensor([t0, t1]).to(t0)
116 |
117 | shapes, func, y0, t, rtol, atol, method, options, _, _ = _check_inputs(func, y0, t, rtol, atol, method, options, None, SOLVERS)
118 |
119 | assert method == "dopri5"
120 |
121 | solver = Dopri5Solver(func=func, y0=y0, rtol=rtol, atol=atol, **options)
122 |
123 | # The integration loop
124 | solution = torch.empty(len(t), *solver.y0.shape, dtype=solver.y0.dtype, device=solver.y0.device)
125 | solution[0] = solver.y0
126 | t = t.to(solver.dtype)
127 | solver._before_integrate(t)
128 | t0 = solver.rk_state.t0
129 |
130 | times = [t0]
131 | interp_coeffs = []
132 |
133 | for i in range(1, len(t)):
134 | next_t = t[i]
135 | while next_t > solver.rk_state.t1:
136 | solver.rk_state = solver._adaptive_step(solver.rk_state)
137 | t1 = solver.rk_state.t1
138 |
139 | if t1 != t0:
140 | # Step accepted.
141 | t0 = t1
142 | times.append(t1)
143 | interp_coeffs.append(torch.stack(solver.rk_state.interp_coeff))
144 |
145 | solution[i] = _interp_evaluate(solver.rk_state.interp_coeff, solver.rk_state.t0, solver.rk_state.t1, next_t)
146 |
147 | times = torch.stack(times).reshape(-1).cpu()
148 | interp_coeffs = torch.stack(interp_coeffs)
149 |
150 | def dense_output_fn(t_eval):
151 | idx = torch.searchsorted(times, t_eval, side="right")
152 | t0 = times[idx - 1]
153 | t1 = times[idx]
154 | coef = [interp_coeffs[idx - 1][i] for i in range(interp_coeffs.shape[1])]
155 | return _interp_evaluate(coef, t0, t1, t_eval)
156 |
157 | return dense_output_fn
158 |
159 |
160 | def odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs):
161 | """Integrate until `event_fn` reaches zero, routing gradients through the event time via the implicit function theorem."""
162 |
163 | if reverse_time:
164 | t = torch.cat([t0.reshape(-1), t0.reshape(-1).detach() - 1.0])
165 | else:
166 | t = torch.cat([t0.reshape(-1), t0.reshape(-1).detach() + 1.0])
167 |
168 | event_t, solution = odeint_interface(func, y0, t, event_fn=event_fn, **kwargs)
169 |
170 | # Dummy values for rtol, atol, method, and options.
171 | shapes, _func, _, t, _, _, _, _, event_fn, _ = _check_inputs(func, y0, t, 0.0, 0.0, None, None, event_fn, SOLVERS)
172 |
173 | if shapes is not None:
174 | state_t = torch.cat([s[-1].reshape(-1) for s in solution])
175 | else:
176 | state_t = solution[-1]
177 |
178 | # Event_fn takes in negated time value if reverse_time is True.
179 | if reverse_time:
180 | event_t = -event_t
181 |
182 | event_t, state_t = ImplicitFnGradientRerouting.apply(_func, event_fn, event_t, state_t)
183 |
184 | # Return the user expected time value.
185 | if reverse_time:
186 | event_t = -event_t
187 |
188 | if shapes is not None:
189 | state_t = _flat_to_shape(state_t, (), shapes)
190 | solution = tuple(torch.cat([s[:-1], s_t[None]], dim=0) for s, s_t in zip(solution, state_t))
191 | else:
192 | solution = torch.cat([solution[:-1], state_t[None]], dim=0)
193 |
194 | return event_t, solution
195 |
196 |
197 | class ImplicitFnGradientRerouting(torch.autograd.Function):
198 |
199 | @staticmethod
200 | def forward(ctx, func, event_fn, event_t, state_t):
201 | """ event_t is the solution to event_fn """
202 | ctx.func = func
203 | ctx.event_fn = event_fn
204 | ctx.save_for_backward(event_t, state_t)
205 | return event_t.detach(), state_t.detach()
206 |
207 | @staticmethod
208 | def backward(ctx, grad_t, grad_state):
209 | func = ctx.func
210 | event_fn = ctx.event_fn
211 | event_t, state_t = ctx.saved_tensors
212 |
213 | event_t = event_t.detach().clone().requires_grad_(True)
214 | state_t = state_t.detach().clone().requires_grad_(True)
215 |
216 | f_val = func(event_t, state_t)
217 |
218 | with torch.enable_grad():
219 | c, (par_dt, dstate) = vjp(event_fn, (event_t, state_t))
220 |
221 | # Total derivative of event_fn wrt t evaluated at event_t.
222 | dcdt = par_dt + torch.sum(dstate * f_val)
223 |
224 | # Add the gradient from final state to final time value as if a regular odeint was called.
225 | grad_t = grad_t + torch.sum(grad_state * f_val)
226 |
227 | dstate = dstate * (-grad_t / (dcdt + 1e-12)).reshape_as(c)
228 |
229 | grad_state = grad_state + dstate
230 |
231 | return None, None, None, grad_state
232 |
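`ImplicitFnGradientRerouting.backward` obtains the event-time gradient from the implicit function theorem: if g(t*, y(t*)) = 0, then dt*/dy0 = -(dg/dy * dy(t*)/dy0) / (dg/dt + dg/dy * f(t*, y(t*))), where the denominator is the quantity `dcdt` in the code. The scalar sketch below applies that formula to a case with a closed-form answer, exponential decay dy/dt = -k * y with the event g(t, y) = y - c, and compares it against central finite differences of t* = ln(y0 / c) / k. It is a hand-written illustration of the formula, not a call into torchdiffeq; all function names are hypothetical.

```python
import math


def event_time(y0, k, c):
    """Closed-form time at which y(t) = y0 * exp(-k * t) reaches the level c."""
    return math.log(y0 / c) / k


def event_time_grad(y0, k, c):
    """d(t*)/d(y0) via the implicit function theorem for g(t, y) = y - c."""
    t_star = event_time(y0, k, c)
    y_star = y0 * math.exp(-k * t_star)  # state at the event (equals c)
    f_star = -k * y_star                 # dy/dt at the event
    dg_dt, dg_dy = 0.0, 1.0              # partial derivatives of g(t, y) = y - c
    dcdt = dg_dt + dg_dy * f_star        # total derivative of g along the trajectory
    dy_star_dy0 = math.exp(-k * t_star)  # sensitivity of y(t*) to y0 with t* held fixed
    return -(dg_dy * dy_star_dy0) / dcdt


def finite_difference_grad(y0, k, c, h=1e-6):
    """Central-difference estimate of d(t*)/d(y0) from the closed form."""
    return (event_time(y0 + h, k, c) - event_time(y0 - h, k, c)) / (2 * h)
```

Both estimates reproduce the analytic value 1 / (k * y0). In the library code, the extra term `grad_t + torch.sum(grad_state * f_val)` additionally accounts for the final state's own dependence on t*, as its inline comment describes.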
--------------------------------------------------------------------------------
/torchdiffeq/_impl/scipy_wrapper.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | from scipy.integrate import solve_ivp
4 | from .misc import _handle_unused_kwargs
5 |
6 |
7 | class ScipyWrapperODESolver(metaclass=abc.ABCMeta):
8 |
9 | def __init__(self, func, y0, rtol, atol, min_step=0, max_step=float('inf'), solver="LSODA", **unused_kwargs):
10 | unused_kwargs.pop('norm', None)
11 | unused_kwargs.pop('grid_points', None)
12 | unused_kwargs.pop('eps', None)
13 | _handle_unused_kwargs(self, unused_kwargs)
14 | del unused_kwargs
15 |
16 | self.dtype = y0.dtype
17 | self.device = y0.device
18 | self.shape = y0.shape
19 | self.y0 = y0.detach().cpu().numpy().reshape(-1)
20 | self.rtol = rtol
21 | self.atol = atol
22 | self.min_step = min_step
23 | self.max_step = max_step
24 | self.solver = solver
25 | self.func = convert_func_to_numpy(func, self.shape, self.device, self.dtype)
26 |
27 | def integrate(self, t):
28 | if t.numel() == 1:
29 | return torch.tensor(self.y0)[None].to(self.device, self.dtype)
30 | t = t.detach().cpu().numpy()
31 | sol = solve_ivp(
32 | self.func,
33 | t_span=[t.min(), t.max()],
34 | y0=self.y0,
35 | t_eval=t,
36 | method=self.solver,
37 | rtol=self.rtol,
38 | atol=self.atol,
39 | min_step=self.min_step,
40 | max_step=self.max_step
41 | )
42 | sol = torch.tensor(sol.y).T.to(self.device, self.dtype)
43 | sol = sol.reshape(-1, *self.shape)
44 | return sol
45 |
46 | @classmethod
47 | def valid_callbacks(cls):
48 | return set()
49 |
50 |
51 | def convert_func_to_numpy(func, shape, device, dtype):
52 |
53 | def np_func(t, y):
54 | t = torch.tensor(t).to(device, dtype)
55 | y = torch.reshape(torch.tensor(y).to(device, dtype), shape)
56 | with torch.no_grad():
57 | f = func(t, y)
58 | return f.detach().cpu().numpy().reshape(-1)
59 |
60 | return np_func
61 |
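`convert_func_to_numpy` adapts a shaped PyTorch right-hand side to the flat-vector interface that `scipy.integrate.solve_ivp` expects: it reshapes the flat `y` to `shape`, calls `func` under `torch.no_grad()` (so no gradients flow through this solver), and flattens the result with `reshape(-1)`. The pure-Python sketch below mirrors that reshape/call/flatten pattern for a 2-D state stored as nested lists in row-major order, the same ordering torch and NumPy use by default. `convert_func_to_flat` and `rotation_rhs` are hypothetical stand-ins, not part of torchdiffeq.

```python
def convert_func_to_flat(func, shape):
    """Wrap func(t, Y), with Y a rows x cols nested list, as a flat-vector RHS."""
    rows, cols = shape

    def flat_func(t, y_flat):
        # Row-major reshape of the flat state, as torch.reshape would do.
        y = [list(y_flat[r * cols:(r + 1) * cols]) for r in range(rows)]
        dy = func(t, y)
        # Row-major flatten of the derivative, as .reshape(-1) would do.
        return [value for row in dy for value in row]

    return flat_func


def rotation_rhs(t, y):
    """Example RHS dY/dt = A @ Y with A = [[0, 1], [-1, 0]]."""
    a = [[0.0, 1.0], [-1.0, 0.0]]
    return [[sum(a[i][k] * y[k][j] for k in range(2)) for j in range(2)] for i in range(2)]
```

A function wrapped this way has the `fun(t, y)` signature that `solve_ivp(fun, t_span, y0, ...)` calls, which is the role `self.func` plays in `integrate` above.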
--------------------------------------------------------------------------------
/torchdiffeq/_impl/solvers.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | from .event_handling import find_event
4 | from .misc import _handle_unused_kwargs
5 |
6 |
7 | class AdaptiveStepsizeODESolver(metaclass=abc.ABCMeta):
8 | def __init__(self, dtype, y0, norm, **unused_kwargs):
9 | _handle_unused_kwargs(self, unused_kwargs)
10 | del unused_kwargs
11 |
12 | self.y0 = y0
13 | self.dtype = dtype
14 |
15 | self.norm = norm
16 |
17 | def _before_integrate(self, t):
18 | pass
19 |
20 | @abc.abstractmethod
21 | def _advance(self, next_t):
22 | raise NotImplementedError
23 |
24 | @classmethod
25 | def valid_callbacks(cls):
26 | return set()
27 |
28 | def integrate(self, t):
29 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
30 | solution[0] = self.y0
31 | t = t.to(self.dtype)
32 | self._before_integrate(t)
33 | for i in range(1, len(t)):
34 | solution[i] = self._advance(t[i])
35 | return solution
36 |
37 |
38 | class AdaptiveStepsizeEventODESolver(AdaptiveStepsizeODESolver, metaclass=abc.ABCMeta):
39 |
40 | @abc.abstractmethod
41 | def _advance_until_event(self, event_fn):
42 | raise NotImplementedError
43 |
44 | def integrate_until_event(self, t0, event_fn):
45 | t0 = t0.to(self.y0.device, self.dtype)
46 | self._before_integrate(t0.reshape(-1))
47 | event_time, y1 = self._advance_until_event(event_fn)
48 | solution = torch.stack([self.y0, y1], dim=0)
49 | return event_time, solution
50 |
51 |
52 | class FixedGridODESolver(metaclass=abc.ABCMeta):
53 | order: int
54 |
55 | def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, **unused_kwargs):
56 | self.atol = unused_kwargs.pop('atol')
57 | unused_kwargs.pop('rtol', None)
58 | unused_kwargs.pop('norm', None)
59 | _handle_unused_kwargs(self, unused_kwargs)
60 | del unused_kwargs
61 |
62 | self.func = func
63 | self.y0 = y0
64 | self.dtype = y0.dtype
65 | self.device = y0.device
66 | self.step_size = step_size
67 | self.interp = interp
68 | self.perturb = perturb
69 |
70 | if step_size is None:
71 | if grid_constructor is None:
72 | self.grid_constructor = lambda f, y0, t: t
73 | else:
74 | self.grid_constructor = grid_constructor
75 | else:
76 | if grid_constructor is None:
77 | self.grid_constructor = self._grid_constructor_from_step_size(step_size)
78 | else:
79 | raise ValueError("step_size and grid_constructor are mutually exclusive arguments.")
80 |
81 | @classmethod
82 | def valid_callbacks(cls):
83 | return {'callback_step'}
84 |
85 | @staticmethod
86 | def _grid_constructor_from_step_size(step_size):
87 | def _grid_constructor(func, y0, t):
88 | start_time = t[0]
89 | end_time = t[-1]
90 |
91 | niters = torch.ceil((end_time - start_time) / step_size + 1).item()
92 | t_infer = torch.arange(0, niters, dtype=t.dtype, device=t.device) * step_size + start_time
93 | t_infer[-1] = t[-1]
94 |
95 | return t_infer
96 | return _grid_constructor
97 |
98 | @abc.abstractmethod
99 | def _step_func(self, func, t0, dt, t1, y0):
100 | pass
101 |
102 |     def integrate(self, t):
103 |         time_grid = self.grid_constructor(self.func, self.y0, t)
104 |         assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
105 |
106 |         solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
107 |         solution[0] = self.y0
108 |
109 |         j = 1
110 |         y0 = self.y0
111 |         for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
112 |             dt = t1 - t0
113 |             self.func.callback_step(t0, y0, dt)
114 |             dy, f0 = self._step_func(self.func, t0, dt, t1, y0)
115 |             y1 = y0 + dy
116 |
117 |             while j < len(t) and t1 >= t[j]:
118 |                 if self.interp == "linear":
119 |                     solution[j] = self._linear_interp(t0, t1, y0, y1, t[j])
120 |                 elif self.interp == "cubic":
121 |                     f1 = self.func(t1, y1)
122 |                     solution[j] = self._cubic_hermite_interp(t0, y0, f0, t1, y1, f1, t[j])
123 |                 else:
124 |                     raise ValueError(f"Unknown interpolation method {self.interp}")
125 |                 j += 1
126 |             y0 = y1
127 |
128 |         return solution
129 |
130 |     def integrate_until_event(self, t0, event_fn):
131 |         assert self.step_size is not None, "Event handling for fixed step solvers currently requires `step_size` to be provided in options."
132 |
133 |         t0 = t0.type_as(self.y0.abs())  # .abs() yields a real dtype even when y0 is complex
134 |         y0 = self.y0
135 |         dt = self.step_size
136 |
137 |         sign0 = torch.sign(event_fn(t0, y0))
138 |         max_itrs = 20000
139 |         itr = 0
140 |         while True:
141 |             itr += 1
142 |             t1 = t0 + dt
143 |             dy, f0 = self._step_func(self.func, t0, dt, t1, y0)
144 |             y1 = y0 + dy
145 |
146 |             sign1 = torch.sign(event_fn(t1, y1))
147 |
148 |             if sign0 != sign1:
149 |                 if self.interp == "linear":
150 |                     interp_fn = lambda t: self._linear_interp(t0, t1, y0, y1, t)
151 |                 elif self.interp == "cubic":
152 |                     f1 = self.func(t1, y1)
153 |                     interp_fn = lambda t: self._cubic_hermite_interp(t0, y0, f0, t1, y1, f1, t)
154 |                 else:
155 |                     raise ValueError(f"Unknown interpolation method {self.interp}")
156 |                 event_time, y1 = find_event(interp_fn, sign0, t0, t1, event_fn, float(self.atol))
157 |                 break
158 |             else:
159 |                 t0, y0 = t1, y1
160 |
161 |             if itr >= max_itrs:
162 |                 raise RuntimeError(f"Reached maximum number of iterations {max_itrs}.")
163 |         solution = torch.stack([self.y0, y1], dim=0)
164 |         return event_time, solution
165 |
166 |     def _cubic_hermite_interp(self, t0, y0, f0, t1, y1, f1, t):
167 |         h = (t - t0) / (t1 - t0)
168 |         h00 = (1 + 2 * h) * (1 - h) * (1 - h)
169 |         h10 = h * (1 - h) * (1 - h)
170 |         h01 = h * h * (3 - 2 * h)
171 |         h11 = h * h * (h - 1)
172 |         dt = (t1 - t0)
173 |         return h00 * y0 + h10 * dt * f0 + h01 * y1 + h11 * dt * f1
174 |
175 |     def _linear_interp(self, t0, t1, y0, y1, t):
176 |         if t == t0:
177 |             return y0
178 |         if t == t1:
179 |             return y1
180 |         slope = (t - t0) / (t1 - t0)
181 |         return y0 + slope * (y1 - y0)
182 |
--------------------------------------------------------------------------------
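The fixed-grid logic in `FixedGridODESolver` can be exercised without PyTorch. The following is a minimal sketch on plain Python floats, not torchdiffeq code: `grid_from_step_size`, `linear_interp` and `cubic_hermite` mirror `_grid_constructor_from_step_size`, `_linear_interp` and `_cubic_hermite_interp`, while `euler_until_event` is a hypothetical stand-in for `integrate_until_event` that assumes explicit Euler as the step function and plain bisection in place of `find_event` (defined in `event_handling.py`, not shown here).

```python
import math


def grid_from_step_size(t_start, t_end, step_size):
    """Plain-float analogue of `_grid_constructor_from_step_size`."""
    # ceil(span / step_size + 1) grid points, spaced step_size apart from t_start.
    niters = math.ceil((t_end - t_start) / step_size + 1)
    grid = [t_start + i * step_size for i in range(niters)]
    # The last point may overshoot t_end, so it is replaced by t_end itself;
    # the final interval can therefore be shorter than step_size.
    grid[-1] = t_end
    return grid


def linear_interp(t0, t1, y0, y1, t):
    """Straight line through (t0, y0) and (t1, y1)."""
    return y0 + (t - t0) / (t1 - t0) * (y1 - y0)


def cubic_hermite(t0, y0, f0, t1, y1, f1, t):
    """Cubic matching the values y0, y1 and the slopes f0, f1 at both ends."""
    h = (t - t0) / (t1 - t0)
    h00 = (1 + 2 * h) * (1 - h) ** 2
    h10 = h * (1 - h) ** 2
    h01 = h * h * (3 - 2 * h)
    h11 = h * h * (h - 1)
    dt = t1 - t0
    return h00 * y0 + h10 * dt * f0 + h01 * y1 + h11 * dt * f1


def sign(v):
    """-1, 0 or +1, like torch.sign on a scalar."""
    return (v > 0) - (v < 0)


def euler_until_event(f, t0, y0, dt, event_fn, tol=1e-9, max_itrs=20000):
    """Hypothetical fixed-step event search: take Euler steps until event_fn
    changes sign, then bisect on the linear interpolant of that last step."""
    sign0 = sign(event_fn(t0, y0))
    for _ in range(max_itrs):
        t1 = t0 + dt
        y1 = y0 + dt * f(t0, y0)  # explicit Euler step
        if sign(event_fn(t1, y1)) != sign0:
            lo, hi = t0, t1
            while hi - lo > tol:
                mid = 0.5 * (lo + hi)
                if sign(event_fn(mid, linear_interp(t0, t1, y0, y1, mid))) == sign0:
                    lo = mid  # no crossing yet: the root lies in [mid, hi]
                else:
                    hi = mid  # crossing already happened: the root lies in [lo, mid]
            t_event = 0.5 * (lo + hi)
            return t_event, linear_interp(t0, t1, y0, y1, t_event)
        t0, y0 = t1, y1
    raise RuntimeError(f"Reached maximum number of iterations {max_itrs}.")


if __name__ == "__main__":
    print(grid_from_step_size(0.0, 1.0, 0.3))
    # For y(t) = t**3 the cubic Hermite interpolant is exact; linear interpolation is not.
    print(cubic_hermite(0.0, 0.0, 0.0, 2.0, 8.0, 12.0, 1.0), linear_interp(0.0, 2.0, 0.0, 8.0, 1.0))
    # dy/dt = -1 from y(0) = 1: the event y = 0.25 fires at t = 0.75.
    print(euler_until_event(lambda t, y: -1.0, 0.0, 1.0, 0.1, lambda t, y: y - 0.25))
```

For y(t) = t**3 on [0, 2], the Hermite interpolant returns exactly 1.0 at t = 1 because it also matches the end slopes f0 and f1, while the straight line returns 4.0. That accuracy is what `interp="cubic"` buys; in `integrate` above it costs one extra call to `func` at `t1` for every requested output time that falls inside a step.
--------------------------------------------------------------------------------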
/torchdiffeq/_impl/tsit5.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
3 | # https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/lib/OrdinaryDiffEqTsit5/src/tsit_tableaus.jl
4 | # https://github.com/patrick-kidger/diffrax/blob/14baa1edddcacf27c0483962b3c9cf2e86e6e5b6/diffrax/_solver/tsit5.py#L158
5 |
6 | _TSITOURAS_TABLEAU = _ButcherTableau(
7 | alpha=torch.tensor([
8 | 161 / 1000,
9 | 327 / 1000,
10 | 9 / 10,
11 | .9800255409045096857298102862870245954942137979563024768854764293221195950761080302604,
12 | 1,
13 | 1
14 | ], dtype=torch.float64),
15 | beta=[
16 | torch.tensor([161 / 1000], dtype=torch.float64),
17 | torch.tensor([
18 | -.8480655492356988544426874250230774675121177393430391537369234245294192976164141156943e-2,
19 | .3354806554923569885444268742502307746751211773934303915373692342452941929761641411569
20 | ], dtype=torch.float64),
21 | torch.tensor([
22 | 2.897153057105493432130432594192938764924887287701866490314866693455023795137503079289,
23 | -6.359448489975074843148159912383825625952700647415626703305928850207288721235210244366,
24 | 4.362295432869581411017727318190886861027813359713760212991062156752264926097707165077,
25 | ], dtype=torch.float64),
26 | torch.tensor([
27 | 5.325864828439256604428877920840511317836476253097040101202360397727981648835607691791,
28 | -11.74888356406282787774717033978577296188744178259862899288666928009020615663593781589,
29 | 7.495539342889836208304604784564358155658679161518186721010132816213648793440552049753,
30 | -.9249506636175524925650207933207191611349983406029535244034750452930469056411389539635e-1
31 | ], dtype=torch.float64),
32 | torch.tensor([
33 | 5.861455442946420028659251486982647890394337666164814434818157239052507339770711679748,
34 | -12.92096931784710929170611868178335939541780751955743459166312250439928519268343184452,
35 | 8.159367898576158643180400794539253485181918321135053305748355423955009222648673734986,
36 | -.7158497328140099722453054252582973869127213147363544882721139659546372402303777878835e-1,
37 | -.2826905039406838290900305721271224146717633626879770007617876201276764571291579142206e-1
38 | ], dtype=torch.float64),
39 | torch.tensor([
40 | .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1,
41 | 1 / 100,
42 | .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509,
43 | 1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331,
44 | -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677,
45 | 2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841
46 | ], dtype=torch.float64),
47 | ],
48 | c_sol=torch.tensor([  # fifth-order weights b_i: the last row of beta with b_7 = 0 (first same as last)
49 | .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1,
50 | 1 / 100,
51 | .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509,
52 | 1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331,
53 | -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677,
54 | 2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841,
55 | 0
56 | ], dtype=torch.float64),
57 | c_error=torch.tensor([  # bhat_i - b_i: embedded fourth-order weights minus fifth-order weights
58 | -1.780011052225771443378550607539534775944678804333659557637450799792588061629796e-03,
59 | -8.164344596567469032236360633546862401862537590159047610940604670770447527463931e-04,
60 | 7.880878010261996010314727672526304238628733777103128603258129604952959142646516e-03,
61 | -1.44711007173262907537165147972635116720922712343167677619514233896760819649515e-01,
62 | 5.823571654525552250199376106520421794260781239567387797673045438803694038950012e-01,
63 | -4.580821059291869466616365188325542974428047279788398179474684434732070620889539e-01,
64 | 1 / 66
65 | ], dtype=torch.float64),
66 | )
67 |
68 | x = 1 / 2  # theta = 1/2: TSIT_C_MID holds the continuous-extension weights b_i(theta) at the step midpoint
69 | TSIT_C_MID = torch.tensor([
70 | -1.0530884977290216*x*(x-1.329989018975412)*(x*x-1.4364028541716351*x+0.7139816917074209),
71 | 0.1017*x*x*(x*x-2.1966568338249754*x+1.2949852507374631),
72 | 2.490627285651252793*x*x*(x*x-2.38535645472061657*x+1.57803468208092486),
73 | -16.54810288924490272*(x-1.21712927295533244)*(x-0.61620406037800089)*x*x,
74 | 47.37952196281928122*(x-1.203071208372362603)*(x-0.658047292653547382)*x*x,
75 | -34.87065786149660974*(x-1.2)*(x-2/3)*x*x,
76 | 2.5*(x-1)*(x-0.6)*x*x
77 | ], dtype=torch.float64)
78 |
79 | class Tsit5Solver(RKAdaptiveStepsizeODESolver):
80 |     order = 5
81 |     tableau = _TSITOURAS_TABLEAU
82 |     mid = TSIT_C_MID
83 |
--------------------------------------------------------------------------------
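The Tsitouras tableau can be sanity-checked numerically. The sketch below works on plain Python floats; the names `C`, `A`, `B`, `C_ERROR`, `BHAT`, `quadrature_residuals` and `dense_weights` are illustrative, not part of torchdiffeq. It rounds the `alpha`, `beta` and `c_error` entries above to double precision and checks that each row of `beta` sums to its node, that the fifth-order weights b_i (taken as the last row of `beta` with b_7 = 0) satisfy the quadrature conditions sum_i b_i c_i^k = 1/(k+1) for k = 0..4, that the embedded weights bhat = b + c_error satisfy them for k = 0..3, and that the continuous extension written out in `TSIT_C_MID` sums to theta and reproduces b at theta = 1.

```python
# Nodes c_1..c_7: c_1 = 0 followed by the `alpha` entries, rounded to double precision.
C = [0.0, 0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0]

# Rows 2..7 of the Runge-Kutta matrix (the `beta` entries), rounded to double precision.
A = [
    [0.161],
    [-0.008480655492356989, 0.335480655492357],
    [2.897153057105493, -6.359448489975075, 4.362295432869581],
    [5.325864828439257, -11.748883564062828, 7.495539342889836, -0.09249506636175525],
    [5.86145544294642, -12.92096931784711, 8.159367898576159, -0.071584973281401,
     -0.028269050394068383],
    [0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742,
     -3.290069515436081, 2.324710524099774],
]

# Fifth-order weights: the last row of A with b_7 = 0 (first same as last).
B = A[-1] + [0.0]

# `c_error` entries; taking bhat = b + c_error gives the embedded fourth-order weights.
C_ERROR = [-0.0017800110522257714, -0.0008164344596567469, 0.007880878010261996,
           -0.1447110071732629, 0.5823571654525552, -0.4580821059291869, 1 / 66]
BHAT = [b + e for b, e in zip(B, C_ERROR)]


def quadrature_residuals(weights, max_k):
    """sum_i w_i * c_i**k - 1/(k+1) for k = 0..max_k; all near 0 if the conditions hold."""
    return [sum(w * c ** k for w, c in zip(weights, C)) - 1 / (k + 1) for k in range(max_k + 1)]


def dense_weights(theta):
    """Continuous-extension weights b_i(theta): the TSIT_C_MID expressions with x = theta."""
    x = theta
    return [
        -1.0530884977290216 * x * (x - 1.329989018975412) * (x * x - 1.4364028541716351 * x + 0.7139816917074209),
        0.1017 * x * x * (x * x - 2.1966568338249754 * x + 1.2949852507374631),
        2.490627285651252793 * x * x * (x * x - 2.38535645472061657 * x + 1.57803468208092486),
        -16.54810288924490272 * (x - 1.21712927295533244) * (x - 0.61620406037800089) * x * x,
        47.37952196281928122 * (x - 1.203071208372362603) * (x - 0.658047292653547382) * x * x,
        -34.87065786149660974 * (x - 1.2) * (x - 2 / 3) * x * x,
        2.5 * (x - 1) * (x - 0.6) * x * x,
    ]


if __name__ == "__main__":
    print("row sums - nodes:", [sum(row) - c for row, c in zip(A, C[1:])])
    print("order-5 quadrature residuals:", quadrature_residuals(B, 4))
    print("order-4 (embedded) residuals:", quadrature_residuals(BHAT, 3))
    print("sum of b_i(1/2):", sum(dense_weights(0.5)))
    print("max |b_i(1) - b_i|:", max(abs(w - b) for w, b in zip(dense_weights(1.0), B)))
```

All residuals should come out at round-off level. These are necessary conditions only: full fifth order also requires the non-quadrature order conditions, which this sketch does not check.
--------------------------------------------------------------------------------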