├── .gitignore ├── CITATION.cff ├── FAQ.md ├── FURTHER_DOCUMENTATION.md ├── LICENSE ├── README.md ├── assets ├── bouncing_ball.png ├── cnf_demo.gif ├── ode_demo.gif ├── odenet_0_viz.png └── resnet_0_viz.png ├── examples ├── README.md ├── bouncing_ball.py ├── cnf.py ├── latent_ode.py ├── learn_physics.py ├── ode_demo.py └── odenet_mnist.py ├── setup.py ├── tests ├── DETEST │ ├── detest.py │ └── run.py ├── api_tests.py ├── event_tests.py ├── gradient_tests.py ├── norm_tests.py ├── odeint_tests.py ├── problems.py └── run_all.py └── torchdiffeq ├── __init__.py └── _impl ├── __init__.py ├── adaptive_heun.py ├── adjoint.py ├── bosh3.py ├── dopri5.py ├── dopri8.py ├── event_handling.py ├── fehlberg2.py ├── fixed_adams.py ├── fixed_grid.py ├── fixed_grid_implicit.py ├── interp.py ├── misc.py ├── odeint.py ├── rk_common.py ├── scipy_wrapper.py ├── solvers.py └── tsit5.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | .installed.cfg 3 | *.egg 4 | *__pycache__* 5 | *.pyc 6 | .vscode 7 | build 8 | dist 9 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # YAML 1.2 2 | --- 3 | abstract: | 4 | "This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. We also allow terminating an ODE solution based on an event function, with exact gradient computed. 5 | 6 | As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU." 7 | authors: 8 | - 9 | family-names: Chen 10 | given-names: "Ricky T. Q." 
11 | cff-version: "1.1.0" 12 | date-released: 2021-06-02 13 | license: MIT 14 | message: "PyTorch Implementation of Differentiable ODE Solvers" 15 | repository-code: "https://github.com/rtqichen/torchdiffeq" 16 | title: torchdiffeq 17 | version: "0.2.2" 18 | ... 19 | -------------------------------------------------------------------------------- /FAQ.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions (FAQ) 2 | 3 | **What are good resources to understand how ODEs can be solved?**
4 | *Solving Ordinary Differential Equations I: Nonstiff Problems* by Hairer et al.
5 | [ODE solver selection in MATLAB](https://blogs.mathworks.com/loren/2015/09/23/ode-solver-selection-in-matlab/)
6 | 7 | **What are the ODE solvers available in this repo?**
8 | 9 | - Adaptive-step: 10 | - `dopri8` Runge-Kutta 7(8) of Dormand-Prince-Shampine 11 | - `dopri5` Runge-Kutta 4(5) of Dormand-Prince **[default]**. 12 | - `bosh3` Runge-Kutta 2(3) of Bogacki-Shampine 13 | - `adaptive_heun` Runge-Kutta 1(2) 14 | 15 | - Fixed-step: 16 | - `euler` Euler method. 17 | - `midpoint` Midpoint method. 18 | - `rk4` Fourth-order Runge-Kutta with 3/8 rule. 19 | - `explicit_adams` Explicit Adams. 20 | - `implicit_adams` Implicit Adams. 21 | 22 | - `scipy_solver`: Wraps a SciPy solver. 23 | 24 | 25 | **What are `NFE-F` and `NFE-B`?**
26 | The number of function evaluations in the forward and backward passes, respectively. 27 | 28 | **What are `rtol` and `atol`?**
29 | They are the relative (`rtol`) and absolute (`atol`) error tolerances. 30 | 31 | **What is the role of error tolerance in adaptive solvers?**
32 | The basic idea is that each adaptive solver produces an error estimate for the current step. If the error exceeds the tolerance, the step is redone with a smaller step size, and this repeats until the error is within the tolerance.
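The accept/reject loop described above can be sketched in pure Python. This is an illustration, not torchdiffeq's implementation: the hypothetical `adaptive_heun_sketch` uses an embedded Heun-Euler 1(2) pair (the same orders as `adaptive_heun`), measures the error estimate against a per-component tolerance `atol + rtol * max(|y|, |y_new|)` under an RMS norm (one common way of realising the rule given in the next answer), and adapts the step with the `safety`, `ifactor` and `dfactor` controls described in FURTHER_DOCUMENTATION.md.

```python
import math


def rms_norm(values):
    """Root-mean-square of a list of floats."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def adaptive_heun_sketch(f, t0, y0, t1, rtol=1e-6, atol=1e-6, first_step=0.1,
                         safety=0.9, ifactor=10.0, dfactor=0.2):
    """Integrate dy/dt = f(t, y) from t0 to t1 with an embedded Heun-Euler 1(2) pair.

    Hypothetical helper for illustration only. Returns
    (y at t1, number of accepted steps, number of rejected steps).
    """
    t, y, h = t0, list(y0), first_step
    accepted = rejected = 0
    while t < t1:
        h = min(h, t1 - t)  # never step past the end point
        k1 = f(t, y)
        y_low = [yi + h * a for yi, a in zip(y, k1)]  # Euler step, order 1
        k2 = f(t + h, y_low)
        y_high = [yi + 0.5 * h * (a + b) for yi, a, b in zip(y, k1, k2)]  # Heun step, order 2
        # Per-component tolerance atol + rtol * max(|y|, |y_new|); the error
        # estimate is measured in units of that tolerance under an RMS norm.
        scale = [atol + rtol * max(abs(a), abs(b)) for a, b in zip(y, y_high)]
        err = rms_norm([(b - a) / s for a, b, s in zip(y_low, y_high, scale)])
        if err <= 1.0:
            t, y = t + h, y_high  # accept, continue from the higher-order solution
            accepted += 1
        else:
            rejected += 1  # reject: retry from the same t with a smaller step
        # The error estimate scales like h**2, so the ideal rescaling is err**(-1/2);
        # safety damps it, and ifactor/dfactor cap how fast h may grow or shrink.
        factor = ifactor if err == 0.0 else safety * err ** -0.5
        h *= min(ifactor, max(dfactor, factor))
    return y, accepted, rejected


y1, n_accepted, n_rejected = adaptive_heun_sketch(lambda t, y: [-y[0]], 0.0, [1.0], 1.0)
print(y1[0], math.exp(-1.0), n_accepted, n_rejected)
```

On `dy/dt = -y`, the initial step of 0.1 is far too large for a tolerance of 1e-6, so it is rejected and shrunk (by at most a factor of `dfactor` per rejection) before steps start being accepted.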
33 | [Error Tolerances for Variable-Step Solvers](https://www.mathworks.com/help/simulink/ug/types-of-solvers.html#f11-44943) 34 | 35 | **How is the error tolerance calculated?**
36 | The error tolerance is [calculated](https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/misc.py#L74) as `atol + rtol * norm of current state`, where the norm being used is a mixed L-infinity/RMS norm. 37 | 38 | **Where is the code that computes the error tolerance?**
39 | It is computed [here.](https://github.com/rtqichen/torchdiffeq/blob/c4c9c61c939c630b9b88267aa56ddaaec319cb16/torchdiffeq/_impl/misc.py#L94) 40 | 41 | **How many states must a Neural ODE solver store during a forward pass with the adjoint method?**
42 | The number of states required to be stored in memory during a forward pass is solver dependent. For example, `dopri5` requires 6 intermediate states to be stored. 43 | 44 | **How many function evaluations are there per ODE step on adaptive solvers?**
45 | 46 | - `dopri5`
47 | The `dopri5` ODE solver stores at least 6 evaluations of the ODE, then takes a step using a linear combination of them. The diagram below illustrates this: the evaluations marked with `o` are on the estimated path, and those marked with `x` are not. The first two are for selecting the initial step size. 48 | 49 | ``` 50 | 0 1 | 2 3 4 5 6 7 | 8 9 10 11 12 13 51 | o x | x x x x x o | x x x x x o 52 | ``` 53 | 54 | 55 | **How do I obtain evaluations on the estimated path when using an adaptive solver?**
56 | The argument `t` of `odeint` specifies the times at which the ODE solver should output the solution.
57 | ```odeint(func, x0, t=torch.linspace(0, 1, 50))``` 58 | 59 | Note that the ODE solver always integrates from the first to the last value of `t` (here from 0 to 1), and the intermediate values of `t` have no effect on how the ODE is solved. Outputs at intermediate times are computed by polynomial interpolation, at very small cost. 60 | 61 | **What non-linearities should I use in my Neural ODE?**
62 | Avoid non-smooth non-linearities such as ReLU and LeakyReLU.
63 | Prefer non-linearities with a theoretically unique adjoint/gradient such as Softplus. 64 | 65 | **Where is backpropagation for the Neural ODE defined?**
66 | It's defined [here](https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py) if you use the adjoint method `odeint_adjoint`. 67 | 68 | **What are Tableaus?**
69 | Tableaus are a compact way to describe the coefficients of [RK methods](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods). The particular set of coefficients used in this repo was taken from [here](https://www.ams.org/journals/mcom/1986-46-173/S0025-5718-1986-0815836-3/). 70 | 71 | **How do I install the repo on Windows?**
72 | Try downloading the code directly and running `python setup.py install`; see also: 73 | https://stackoverflow.com/questions/52528955/installing-a-python-module-from-github-in-windows-10 74 | 75 | **What is the most memory-expensive operation during training?**
76 | The most memory-expensive operation is the single [backward call](https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py#L75) made to the network. 77 | 78 | **My Neural ODE's numerical solution is farther away from the target than the initial value**
79 | Most tricks for initializing residual nets (like zeroing the weights of the last layer) should help for ODEs as well. This will initialize the ODE as an identity. 80 | 81 | 82 | **My Neural ODE takes too long to train**
83 | This might be because you're running on CPU. Being extremely slow on CPU is expected, as training requires evaluating a neural net multiple times. 84 | 85 | 86 | **My Neural ODE produces underflow in dt when using adaptive solvers like `dopri5`**
87 | This is caused by the ODE becoming stiff: the dynamics behave too erratically in some region, so the step size shrinks so close to zero that the solver can make no progress. We were able to avoid this with regularization such as weight decay and by using "nice" activation functions, but YMMV. Other options are to accept a larger error by increasing `atol` or `rtol`, or to switch to a fixed-step solver. -------------------------------------------------------------------------------- /FURTHER_DOCUMENTATION.md: -------------------------------------------------------------------------------- 1 | # Further documentation 2 | 3 | ## Solver options 4 | 5 | Adaptive and fixed solvers all support several options, listed below together with their default values. 6 | 7 | **Adaptive solvers (dopri8, dopri5, bosh3, adaptive_heun):**
8 | For these solvers, `rtol` and `atol` correspond to the tolerances for accepting/rejecting an adaptive step. 9 | 10 | - `first_step=None`: What size the first step of the solver should be; by default this is selected empirically. 11 | 12 | - `safety=0.9, ifactor=10.0, dfactor=0.2`: How the next optimal step size is calculated; see E. Hairer, S. P. Nørsett, G. Wanner, *Solving Ordinary Differential Equations I: Nonstiff Problems*, Sec. II.4. Roughly speaking, `safety` shrinks each proposed step size slightly by multiplying it by this factor, `ifactor` is the most that the step size can grow by, and `dfactor` is the most that it can shrink by. 13 | 14 | - `max_num_steps=2 ** 31 - 1`: The maximum number of steps the solver is allowed to take. 15 | 16 | - `dtype=torch.float64`: What dtype to use for timelike quantities. Setting this to `torch.float32` will improve speed but may produce underflow errors more easily. 17 | 18 | - `step_t=None`: Times that a step must be made to. In particular, this is useful when `func` has kinks (derivative discontinuities) at these times, as the solver then does not need to (slowly) discover these for itself. If passed, this should be a `torch.Tensor`. 19 | 20 | - `jump_t=None`: Times that a step must be made to, and `func` re-evaluated at. In particular, this is useful when `func` has discontinuities at these times, as then the solver knows that the final function evaluation of the previous step is not equal to the first function evaluation of this step (i.e. the FSAL property does not hold at this point). If passed, this should be a `torch.Tensor`. Note that this may not be efficient when using PyTorch 1.6.0 or earlier. 21 | 22 | - `norm`: What norm to compute the accept/reject criterion with respect to. Given tensor input, this defaults to an RMS norm. Given tupled input, this defaults to computing an RMS norm over each tensor, and then taking a max over the tuple, producing a mixed L-infinity/RMS norm.
If passed, this should be a function consuming a tensor/tuple with the same shape as `y0` and returning a scalar corresponding to its norm. When passed as part of `adjoint_options`, the special value `"seminorm"` may be used to zero out the contribution from the parameters, as per the ["Hey, that's not an ODE"](https://arxiv.org/abs/2009.09457) paper. 23 | 24 | **Fixed solvers (euler, midpoint, rk4, explicit_adams, implicit_adams):**
25 | 26 | - `step_size=None`: How large each discrete step should be. If not passed, this defaults to stepping between the values of `t`. Note that if `t` only specifies the start and end of the region of integration, it is very important to specify this argument! It is mutually exclusive with the `grid_constructor` argument, below. 27 | 28 | - `grid_constructor=None`: A more fine-grained way of choosing the steps, by specifying the exact locations of the grid points. Should be a callable `func, y0, t -> grid`, transforming the arguments `func, y0, t` of `odeint` into the desired grid (which should be a one-dimensional tensor). 29 | 30 | - `perturb`: Defaults to False. If True, then automatically add small perturbations to the start and end of each step, so that stepping to discontinuities works. Note that this may not be efficient when using PyTorch 1.6.0 or earlier. 31 | 32 | Individual solvers also offer certain options. 33 | 34 | **explicit_adams:**
35 | For this solver, `rtol` and `atol` are ignored. This solver also supports: 36 | 37 | - `max_order`: The maximum order of the Adams-Bashforth predictor. 38 | 39 | **implicit_adams:**
40 | For this solver, `rtol` and `atol` correspond to the tolerance for convergence of the Adams-Moulton corrector. This solver also supports: 41 | 42 | - `max_order`: The maximum order of the Adams-Bashforth-Moulton predictor-corrector. 43 | 44 | - `max_iters`: The maximum number of iterations to run the Adams-Moulton corrector for. 45 | 46 | **scipy_solver:**
47 | - `solver`: which SciPy solver to use; corresponds to the `'method'` argument of `scipy.integrate.solve_ivp`. 48 | 49 | ## Adjoint options 50 | 51 | The function `odeint_adjoint` offers some adjoint-specific options. 52 | 53 | - `adjoint_rtol`,
`adjoint_atol`,
`adjoint_method`,
`adjoint_options`:
The `rtol, atol, method, options` to use for the backward pass. Defaults to the values used for the forward pass. 54 | 55 | - `adjoint_options` has the special key-value pair `{"norm": "seminorm"}` that provides a potentially more efficient adjoint solve when using adaptive step solvers, as described in the ["Hey, that's not an ODE"](https://arxiv.org/abs/2009.09457) paper. 56 | 57 | - `adjoint_params`: The parameters to compute gradients with respect to in the backward pass. Should be a tuple of tensors. Defaults to `tuple(func.parameters())`. 58 | - If passed then `func` does not have to be a `torch.nn.Module`. 59 | - If `func` has no parameters, `adjoint_params=()` must be specified. 60 | 61 | 62 | ## Callbacks 63 | 64 | Callbacks can be triggered during the solve. Callbacks should be specified as methods of the `func` argument to `odeint` and `odeint_adjoint`. 65 | 66 | At the moment support for this is minimal: let us know if you'd find additional callbacks useful. 67 | 68 | **callback_step(self, t0, y0, dt):**
69 | This is called immediately before taking a step of size `dt`, at time `t0`, with current solution value `y0`. This is supported by every solver except `scipy_solver`. 70 | 71 | **callback_accept_step(self, t0, y0, dt):**
72 | This is called when accepting a step of size `dt` at time `t0`, with current solution value `y0`. This is supported by the adaptive solvers (dopri8, dopri5, bosh3, adaptive_heun). 73 | 74 | **callback_reject_step(self, t0, y0, dt):**
75 | As `callback_accept_step`, except called when rejecting steps. 76 | 77 | In addition, callbacks can be triggered during the adjoint pass by adding `_adjoint` to the name of any one of the supported callbacks, e.g. `callback_step_adjoint`. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ricky Tian Qi Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of Differentiable ODE Solvers 2 | 3 | This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. 
Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. For usage of ODE solvers in deep learning applications, see reference [1]. 4 | 5 | As the solvers are implemented in PyTorch, the algorithms in this repository can run entirely on the GPU. 6 | 7 | ## Installation 8 | 9 | To install the latest stable version: 10 | ``` 11 | pip install torchdiffeq 12 | ``` 13 | 14 | To install the latest version from GitHub: 15 | ``` 16 | pip install git+https://github.com/rtqichen/torchdiffeq 17 | ``` 18 | 19 | ## Examples 20 | Examples are placed in the [`examples`](./examples) directory. 21 | 22 | We encourage those who are interested in using this library to take a look at [`examples/ode_demo.py`](./examples/ode_demo.py) to see how to use `torchdiffeq` to fit a simple spiral ODE. 23 | 24 |

25 | ODE Demo 26 |
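For readers new to the `odeint(func, y0, t)` convention, the sketch below integrates a 2-D spiral in pure Python with fixed-step, 3/8-rule fourth-order Runge-Kutta (the scheme named for `rk4` in the solver list below) and returns the state at every requested time in `t`. The linear dynamics matrix, the helper names and the equal-substep grid are assumptions for illustration; `examples/ode_demo.py` itself uses PyTorch and learns the dynamics with a network.

```python
import math


def rk4_38_step(f, t, y, h):
    """One step of the 3/8-rule fourth-order Runge-Kutta method."""
    k1 = f(t, y)
    k2 = f(t + h / 3, [yi + h * a / 3 for yi, a in zip(y, k1)])
    k3 = f(t + 2 * h / 3, [yi + h * (-a / 3 + b) for yi, a, b in zip(y, k1, k2)])
    k4 = f(t + h, [yi + h * (a - b + c) for yi, a, b, c in zip(y, k1, k2, k3)])
    return [yi + h * (a + 3 * b + 3 * c + d) / 8
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]


def odeint_fixed_sketch(f, y0, t, step_size):
    """Hypothetical helper: return [y(t[0]), y(t[1]), ...] using steps of at most step_size."""
    y = list(y0)
    ys = [list(y0)]
    for t_start, t_end in zip(t[:-1], t[1:]):
        n = max(1, math.ceil(abs(t_end - t_start) / step_size))
        h = (t_end - t_start) / n  # equal substeps that land exactly on t_end
        for i in range(n):
            y = rk4_38_step(f, t_start + i * h, y, h)
        ys.append(y)
    return ys


def spiral(t, y):
    # Assumed linear spiral dy/dt = A y with A = [[-0.1, 2.0], [-2.0, -0.1]].
    return [-0.1 * y[0] + 2.0 * y[1], -2.0 * y[0] - 0.1 * y[1]]


times = [0.5 * k for k in range(11)]  # 0.0, 0.5, ..., 5.0
traj = odeint_fixed_sketch(spiral, [2.0, 0.0], times, step_size=0.01)
for tk, yk in zip(times, traj):
    print(f"t={tk:.1f}  y={yk[0]: .6f} {yk[1]: .6f}")
```

For this linear system the exact solution is `y(t) = 2 exp(-0.1 t) (cos 2t, -sin 2t)`, which makes the sketch easy to check. Splitting each interval `[t[i], t[i+1]]` into equal substeps guarantees that every requested time is hit exactly; torchdiffeq's fixed-grid solvers build their grid from `step_size` or `grid_constructor` and may place steps differently.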

27 | 28 | ## Basic usage 29 | This library provides one main interface, `odeint`, which contains general-purpose algorithms for solving initial value problems (IVP), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value, 30 | ``` 31 | dy/dt = f(t, y),    y(t_0) = y_0. 32 | ``` 33 | The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition. 34 | 35 | To solve an IVP using the default solver: 36 | ``` 37 | from torchdiffeq import odeint 38 | 39 | odeint(func, y0, t) 40 | ``` 41 | where `func` is any callable implementing the ordinary differential equation `f(t, y)`, `y0` is an _any_-D Tensor representing the initial values, and `t` is a 1-D Tensor containing the evaluation points. The initial time is taken to be `t[0]`. 42 | 43 | Backpropagation through `odeint` goes through the internals of the solver. Note that this is not numerically stable for all solvers (but should probably be fine with the default `dopri5` method). Alternatively, we encourage the use of the adjoint method explained in [1], which uses O(1) memory and therefore allows solving with as many steps as necessary. 44 | 45 | To use the adjoint method: 46 | ``` 47 | from torchdiffeq import odeint_adjoint as odeint 48 | 49 | odeint(func, y0, t) 50 | ``` 51 | `odeint_adjoint` simply wraps around `odeint`, but uses only O(1) memory in exchange for solving an adjoint ODE in the backward call. 52 | 53 | The biggest **gotcha** is that `func` must be a `nn.Module` when using the adjoint method, unless `adjoint_params` is passed explicitly (see the [further documentation](FURTHER_DOCUMENTATION.md)). The module is used to collect the parameters of the differential equation. 54 | 55 | ## Differentiable event handling 56 | 57 | We allow terminating an ODE solution based on an event function. Backpropagation through most solvers is supported. For usage of event handling in deep learning applications, see reference [2].
58 | 59 | This can be invoked with `odeint_event`: 60 | ``` 61 | from torchdiffeq import odeint_event 62 | odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs) 63 | ``` 64 | - `func` and `y0` are the same as `odeint`. 65 | - `t0` is a scalar representing the initial time value. 66 | - `event_fn(t, y)` returns a tensor, and is a required keyword argument. 67 | - `reverse_time` is a boolean specifying whether we should solve in reverse time. Default is `False`. 68 | - `odeint_interface` is one of `odeint` or `odeint_adjoint`, specifying whether adjoint mode should be used for differentiating through the ODE solution. Default is `odeint`. 69 | - `**kwargs`: any remaining keyword arguments are passed to `odeint_interface`. 70 | 71 | The solve is terminated at an event time `t` and state `y` when an element of `event_fn(t, y)` is equal to zero. Multiple outputs from `event_fn` can be used to specify multiple event functions, of which the first to trigger will terminate the solve. 72 | 73 | Both the event time and final state are returned from `odeint_event`, and can be differentiated. Gradients will be backpropagated through the event function. **NOTE**: parameters for the event function must be in the state itself to obtain gradients. 74 | 75 | The numerical precision for the event time is determined by the `atol` argument. 76 | 77 | See example of simulating and differentiating through a bouncing ball in [`examples/bouncing_ball.py`](./examples/bouncing_ball.py). See example code for learning a simple event function in [`examples/learn_physics.py`](./examples/learn_physics.py). 78 | 79 |

80 | Bouncing Ball 81 |
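To illustrate what terminating a solve at a zero of `event_fn` means, here is a pure-Python sketch (not torchdiffeq's algorithm) that steps a falling ball with classical RK4 and, once the event function changes sign within a step, bisects that step to locate the event time. The physical constants mirror the defaults in `examples/bouncing_ball.py` (initial height 10, radius 0.2, gravity 9.8); the function names and the bisection strategy are assumptions. For these values the first contact happens at exactly `t = sqrt(2)`, since `10 - 4.9 t^2 = 0.2`.

```python
import math


def rk4_step(f, t, y, h):
    """One step of the classical fourth-order Runge-Kutta method."""
    k1 = f(t, y)
    k2 = f(t + h / 2, [yi + h * a / 2 for yi, a in zip(y, k1)])
    k3 = f(t + h / 2, [yi + h * b / 2 for yi, b in zip(y, k2)])
    k4 = f(t + h, [yi + h * c for yi, c in zip(y, k3)])
    return [yi + h * (a + 2 * b + 2 * c + d) / 6
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]


def solve_until_event(f, event_fn, t0, y0, h=0.01, tol=1e-10, t_max=100.0):
    """Hypothetical helper: step until event_fn(t, y) crosses zero; return (event time, state)."""
    t, y = t0, list(y0)
    g0 = event_fn(t, y)  # if this is already 0, the event fires immediately
    while t < t_max:
        y_next = rk4_step(f, t, y, h)
        g1 = event_fn(t + h, y_next)
        if g0 * g1 <= 0:  # sign change somewhere in [t, t + h]
            lo, hi = 0.0, h
            while hi - lo > tol:  # bisect on the length of a single step from t
                mid = 0.5 * (lo + hi)
                if g0 * event_fn(t + mid, rk4_step(f, t, y, mid)) > 0:
                    lo = mid  # no sign change yet: the event lies after t + mid
                else:
                    hi = mid
            return t + hi, rk4_step(f, t, y, hi)
        t, y, g0 = t + h, y_next, g1
    raise RuntimeError("no event before t_max")


gravity, radius = 9.8, 0.2


def ball(t, state):  # state = [position, velocity]
    return [state[1], -gravity]


def hits_ground(t, state):  # positive in mid-air, zero at contact
    return state[0] - radius


t_event, state = solve_until_event(ball, hits_ground, 0.0, [10.0, 0.0])
print(t_event, math.sqrt(2.0), state)
```

As in `bouncing_ball.py`, restarting from the returned state requires nudging the position away from the surface (that script adds `1e-7`); otherwise the event function is already zero at the new start and the event fires immediately.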

82 | 83 | ## Keyword arguments for odeint(_adjoint) 84 | 85 | #### Keyword arguments: 86 | - `rtol` Relative tolerance. 87 | - `atol` Absolute tolerance. 88 | - `method` One of the solvers listed below. 89 | - `options` A dictionary of solver-specific options; see the [further documentation](FURTHER_DOCUMENTATION.md). 90 | 91 | #### List of ODE Solvers: 92 | 93 | Adaptive-step: 94 | - `dopri8` Runge-Kutta of order 8 of Dormand-Prince-Shampine. 95 | - `dopri5` Runge-Kutta of order 5 of Dormand-Prince-Shampine **[default]**. 96 | - `bosh3` Runge-Kutta of order 3 of Bogacki-Shampine. 97 | - `fehlberg2` Runge-Kutta-Fehlberg of order 2. 98 | - `adaptive_heun` Runge-Kutta of order 2. 99 | 100 | Fixed-step: 101 | - `euler` Euler method. 102 | - `midpoint` Midpoint method. 103 | - `rk4` Fourth-order Runge-Kutta with the 3/8 rule. 104 | - `explicit_adams` Explicit Adams-Bashforth. 105 | - `implicit_adams` Implicit Adams-Bashforth-Moulton. 106 | 107 | Additionally, all solvers available through SciPy are wrapped for use with `scipy_solver`. 108 | 109 | For most problems, good choices are the default `dopri5`, or `rk4` with `options=dict(step_size=...)` set appropriately small. Adjusting the tolerances (adaptive solvers) or the step size (fixed solvers) trades off speed against accuracy. 110 | 111 | ## Frequently Asked Questions 112 | Take a look at our [FAQ](FAQ.md) for frequently asked questions. 113 | 114 | ## Further documentation 115 | For details of the adjoint-specific and solver-specific options, check out the [further documentation](FURTHER_DOCUMENTATION.md). 116 | 117 | ## References 118 | 119 | Applications of differentiable ODE solvers and event handling are discussed in these two papers: 120 | 121 | Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Information Processing Systems.* 2018.
[[arxiv]](https://arxiv.org/abs/1806.07366) 122 | 123 | ``` 124 | @article{chen2018neuralode, 125 | title={Neural Ordinary Differential Equations}, 126 | author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David}, 127 | journal={Advances in Neural Information Processing Systems}, 128 | year={2018} 129 | } 130 | ``` 131 | 132 | Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." *International Conference on Learning Representations.* 2021. [[arxiv]](https://arxiv.org/abs/2011.03902) 133 | 134 | ``` 135 | @article{chen2021eventfn, 136 | title={Learning Neural Event Functions for Ordinary Differential Equations}, 137 | author={Chen, Ricky T. Q. and Amos, Brandon and Nickel, Maximilian}, 138 | journal={International Conference on Learning Representations}, 139 | year={2021} 140 | } 141 | ``` 142 | 143 | The seminorm option for computing adjoints is discussed in 144 | 145 | Patrick Kidger, Ricky T. Q. Chen, Terry Lyons. "'Hey, that’s not an ODE': Faster ODE Adjoints via Seminorms." *International Conference on Machine 146 | Learning.* 2021. [[arxiv]](https://arxiv.org/abs/2009.09457) 147 | ``` 148 | @article{kidger2021hey, 149 | title={"Hey, that's not an ODE": Faster ODE Adjoints via Seminorms.}, 150 | author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry J.}, 151 | journal={International Conference on Machine Learning}, 152 | year={2021} 153 | } 154 | ``` 155 | 156 | --- 157 | 158 | If you found this library useful in your research, please consider citing. 159 | ``` 160 | @misc{torchdiffeq, 161 | author={Chen, Ricky T. 
Q.}, 162 | title={torchdiffeq}, 163 | year={2018}, 164 | url={https://github.com/rtqichen/torchdiffeq}, 165 | } 166 | ``` 167 | -------------------------------------------------------------------------------- /assets/bouncing_ball.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/bouncing_ball.png -------------------------------------------------------------------------------- /assets/cnf_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/cnf_demo.gif -------------------------------------------------------------------------------- /assets/ode_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/ode_demo.gif -------------------------------------------------------------------------------- /assets/odenet_0_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/odenet_0_viz.png -------------------------------------------------------------------------------- /assets/resnet_0_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtqichen/torchdiffeq/657943acefa826ef04c025ebeb1ff5e9d60dc268/assets/resnet_0_viz.png -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Overview of Examples 2 | 3 | This `examples` directory contains cleaned up code regarding the usage of adaptive ODE solvers in machine learning. 
The scripts in this directory assume that `torchdiffeq` is installed following instructions from the main directory. 4 | 5 | ## Demo 6 | The `ode_demo.py` file contains a short implementation of learning a dynamics model to mimic a spiral ODE. 7 | 8 | To visualize the training progress, run 9 | ``` 10 | python ode_demo.py --viz 11 | ``` 12 | The training should look similar to this: 13 | 14 |

15 | ODE Demo 16 |

17 | 18 | ## ODEnet for MNIST 19 | The `odenet_mnist.py` file contains a reproduction of the MNIST experiments in our Neural ODE paper. Notably, not just the architecture but also the ODE solver library and integration method differ from our original experiments, though the results are similar to those we report in the paper. 20 | 21 | We can use an adaptive ODE solver to approximate our continuous-depth network while still backpropagating through the network. 22 | ``` 23 | python odenet_mnist.py --network odenet 24 | ``` 25 | However, the memory requirements for this grow very quickly, especially for more complex problems where the number of function evaluations can reach nearly a thousand. 26 | 27 | For applications that require solving complex trajectories, we recommend using the adjoint method. 28 | ``` 29 | python odenet_mnist.py --network odenet --adjoint True 30 | ``` 31 | The adjoint method can be slower when using an adaptive ODE solver, as it involves another solve in the backward pass with a much larger system, so we recommend experimenting on small systems with direct backpropagation first. 32 | 33 | Thankfully, it is extremely easy to write code for both adjoint and non-adjoint backpropagation, as they use the same interface. 34 | ``` 35 | if adjoint: 36 | from torchdiffeq import odeint_adjoint as odeint 37 | else: 38 | from torchdiffeq import odeint 39 | ``` 40 | The main gotcha is that `odeint_adjoint` requires implementing the dynamics network as a `nn.Module`, while `odeint` can work with any Python callable. 41 | 42 | ## Continuous Normalizing Flows 43 | 44 | The `cnf.py` file contains a simple CNF implementation for learning the density of a concentric circles dataset. 45 | 46 | To train a CNF and visualize the resulting dynamics, run 47 | ``` 48 | python cnf.py --viz 49 | ``` 50 | The result should look similar to this: 51 | 52 |

53 | CNF Demo 54 |
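The key quantity in `cnf.py` is the line `dlogp_z_dt = -trace_df_dz(dz_dt, z)`: the log-density of a point carried by the flow changes at a rate equal to minus the trace of the Jacobian of the dynamics. The pure-Python sketch below integrates that augmented system for linear dynamics `dz/dt = A z`, where the trace is constant and the change in log-density after time `T` is exactly `-trace(A) * T`. The helper names, the matrix `A`, the finite-difference trace (instead of the autograd trace in `cnf.py`) and the fixed-step classical RK4 are all assumptions chosen to keep the sketch dependency-free.

```python
def trace_jacobian_fd(f, t, z, eps=1e-6):
    """Hypothetical helper: trace of df/dz estimated with central differences."""
    total = 0.0
    for i in range(len(z)):
        z_plus, z_minus = list(z), list(z)
        z_plus[i] += eps
        z_minus[i] -= eps
        total += (f(t, z_plus)[i] - f(t, z_minus)[i]) / (2 * eps)
    return total


def augmented_dynamics(f):
    """Dynamics of the state [z..., log p] under dz/dt = f(t, z)."""
    def g(t, state):
        z = state[:-1]
        return f(t, z) + [-trace_jacobian_fd(f, t, z)]  # d log p / dt = -tr(df/dz)
    return g


def rk4_integrate(g, t0, t1, y0, n_steps=200):
    """Classical RK4 with n_steps equal steps from t0 to t1."""
    h = (t1 - t0) / n_steps
    t, y = t0, list(y0)
    for _ in range(n_steps):
        k1 = g(t, y)
        k2 = g(t + h / 2, [a + h * b / 2 for a, b in zip(y, k1)])
        k3 = g(t + h / 2, [a + h * b / 2 for a, b in zip(y, k2)])
        k4 = g(t + h, [a + h * b for a, b in zip(y, k3)])
        y = [a + h * (b1 + 2 * b2 + 2 * b3 + b4) / 6
             for a, b1, b2, b3, b4 in zip(y, k1, k2, k3, k4)]
        t += h
    return y


A = [[-0.5, 1.0], [-1.0, 0.3]]  # illustrative choice with trace(A) = -0.2


def linear(t, z):
    return [A[0][0] * z[0] + A[0][1] * z[1], A[1][0] * z[0] + A[1][1] * z[1]]


final = rk4_integrate(augmented_dynamics(linear), 0.0, 2.0, [1.0, 0.5, 0.0])
delta_logp = final[-1]  # close to -trace(A) * 2.0 = 0.4
print(final[:-1], delta_logp)
```

For nonlinear dynamics such as the hypernetwork in `cnf.py`, the trace varies along the trajectory, which is why it is recomputed at every function evaluation rather than once.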

55 | 56 | More comprehensive code for continuous normalizing flows (CNFs) has its own public repository. Tools for training, evaluating, and visualizing CNFs for reversible generative modeling are provided along with FFJORD, a linear cost stochastic approximation of CNFs. 57 | 58 | Find the code in https://github.com/rtqichen/ffjord. This code contains some advanced tricks for `torchdiffeq`. 59 | -------------------------------------------------------------------------------- /examples/bouncing_ball.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from torchdiffeq import odeint, odeint_adjoint 9 | from torchdiffeq import odeint_event 10 | 11 | torch.set_default_dtype(torch.float64) 12 | 13 | 14 | class BouncingBallExample(nn.Module): 15 | def __init__(self, radius=0.2, gravity=9.8, adjoint=False): 16 | super().__init__() 17 | self.gravity = nn.Parameter(torch.as_tensor([gravity])) 18 | self.log_radius = nn.Parameter(torch.log(torch.as_tensor([radius]))) 19 | self.t0 = nn.Parameter(torch.tensor([0.0])) 20 | self.init_pos = nn.Parameter(torch.tensor([10.0])) 21 | self.init_vel = nn.Parameter(torch.tensor([0.0])) 22 | self.absorption = nn.Parameter(torch.tensor([0.2])) 23 | self.odeint = odeint_adjoint if adjoint else odeint 24 | 25 | def forward(self, t, state): 26 | pos, vel, log_radius = state 27 | dpos = vel 28 | dvel = -self.gravity 29 | return dpos, dvel, torch.zeros_like(log_radius) 30 | 31 | def event_fn(self, t, state): 32 | # positive if ball in mid-air, negative if ball within ground. 
33 | pos, _, log_radius = state 34 | return pos - torch.exp(log_radius) 35 | 36 | def get_initial_state(self): 37 | state = (self.init_pos, self.init_vel, self.log_radius) 38 | return self.t0, state 39 | 40 | def state_update(self, state): 41 | """Updates state based on an event (collision).""" 42 | pos, vel, log_radius = state 43 | pos = ( 44 | pos + 1e-7 45 | ) # need to add a small eps so as not to trigger the event function immediately. 46 | vel = -vel * (1 - self.absorption) 47 | return (pos, vel, log_radius) 48 | 49 | def get_collision_times(self, nbounces=1): 50 | 51 | event_times = [] 52 | 53 | t0, state = self.get_initial_state() 54 | 55 | for i in range(nbounces): 56 | event_t, solution = odeint_event( 57 | self, 58 | state, 59 | t0, 60 | event_fn=self.event_fn, 61 | reverse_time=False, 62 | atol=1e-8, 63 | rtol=1e-8, 64 | odeint_interface=self.odeint, 65 | ) 66 | event_times.append(event_t) 67 | 68 | state = self.state_update(tuple(s[-1] for s in solution)) 69 | t0 = event_t 70 | 71 | return event_times 72 | 73 | def simulate(self, nbounces=1): 74 | event_times = self.get_collision_times(nbounces) 75 | 76 | # get dense path 77 | t0, state = self.get_initial_state() 78 | trajectory = [state[0][None]] 79 | velocity = [state[1][None]] 80 | times = [t0.reshape(-1)] 81 | for event_t in event_times: 82 | tt = torch.linspace( 83 | float(t0), float(event_t), int((float(event_t) - float(t0)) * 50) 84 | )[1:-1] 85 | tt = torch.cat([t0.reshape(-1), tt, event_t.reshape(-1)]) 86 | solution = odeint(self, state, tt, atol=1e-8, rtol=1e-8) 87 | 88 | trajectory.append(solution[0][1:]) 89 | velocity.append(solution[1][1:]) 90 | times.append(tt[1:]) 91 | 92 | state = self.state_update(tuple(s[-1] for s in solution)) 93 | t0 = event_t 94 | 95 | return ( 96 | torch.cat(times), 97 | torch.cat(trajectory, dim=0).reshape(-1), 98 | torch.cat(velocity, dim=0).reshape(-1), 99 | event_times, 100 | ) 101 | 102 | 103 | def gradcheck(nbounces): 104 | 105 | system = 
BouncingBallExample() 106 | 107 | variables = { 108 | "init_pos": system.init_pos, 109 | "init_vel": system.init_vel, 110 | "t0": system.t0, 111 | "gravity": system.gravity, 112 | "log_radius": system.log_radius, 113 | } 114 | 115 | event_t = system.get_collision_times(nbounces)[-1] 116 | event_t.backward() 117 | 118 | analytical_grads = {} 119 | for name, p in system.named_parameters(): 120 | for var in variables.keys(): 121 | if var in name: 122 | analytical_grads[var] = p.grad 123 | 124 | eps = 1e-3 125 | 126 | fd_grads = {} 127 | 128 | for var, param in variables.items(): 129 | orig = param.data 130 | param.data = orig - eps 131 | f_meps = system.get_collision_times(nbounces)[-1] 132 | param.data = orig + eps 133 | f_peps = system.get_collision_times(nbounces)[-1] 134 | param.data = orig 135 | fd = (f_peps - f_meps) / (2 * eps) 136 | fd_grads[var] = fd 137 | 138 | success = True 139 | for var in variables.keys(): 140 | analytical = analytical_grads[var] 141 | fd = fd_grads[var] 142 | if torch.norm(analytical - fd) > 1e-4: 143 | success = False 144 | print( 145 | f"Got analytical grad {analytical.item()} for {var} param but finite difference is {fd.item()}" 146 | ) 147 | 148 | if not success: 149 | raise Exception("Gradient check failed.") 150 | 151 | print("Gradient check passed.") 152 | 153 | 154 | if __name__ == "__main__": 155 | 156 | parser = argparse.ArgumentParser(description="Simulate a bouncing ball and check gradients of its collision times.") 157 | parser.add_argument("nbounces", type=int, nargs="?", default=10) 158 | parser.add_argument("--adjoint", action="store_true") 159 | args = parser.parse_args() 160 | 161 | gradcheck(args.nbounces) 162 | 163 | system = BouncingBallExample(adjoint=args.adjoint) 164 | times, trajectory, velocity, event_times = system.simulate(nbounces=args.nbounces) 165 | times = times.detach().cpu().numpy() 166 | trajectory = trajectory.detach().cpu().numpy() 167 | velocity = velocity.detach().cpu().numpy() 168 | event_times = torch.stack(event_times).detach().cpu().numpy() 169 | 170 |
plt.figure(figsize=(7, 3.5)) 171 | 172 | # Event locations. 173 | for event_t in event_times: 174 | plt.plot( 175 | event_t, 176 | 0.0, 177 | color="C0", 178 | marker="o", 179 | markersize=7, 180 | fillstyle="none", 181 | linestyle="", 182 | ) 183 | 184 | (vel,) = plt.plot( 185 | times, velocity, color="C1", alpha=0.7, linestyle="--", linewidth=2.0 186 | ) 187 | (pos,) = plt.plot(times, trajectory, color="C0", linewidth=2.0) 188 | 189 | plt.hlines(0, 0, 100) 190 | plt.xlim([times[0], times[-1]]) 191 | plt.ylim([velocity.min() - 0.02, velocity.max() + 0.02]) 192 | plt.ylabel("Markov State", fontsize=16) 193 | plt.xlabel("Time", fontsize=13) 194 | plt.legend([pos, vel], ["Position", "Velocity"], fontsize=16) 195 | 196 | plt.gca().xaxis.set_tick_params( 197 | direction="in", which="both" 198 | ) # The bottom will maintain the default of 'out' 199 | plt.gca().yaxis.set_tick_params( 200 | direction="in", which="both" 201 | ) # The bottom will maintain the default of 'out' 202 | 203 | # Hide the right and top spines 204 | plt.gca().spines["right"].set_visible(False) 205 | plt.gca().spines["top"].set_visible(False) 206 | 207 | # Only show ticks on the left and bottom spines 208 | plt.gca().yaxis.set_ticks_position("left") 209 | plt.gca().xaxis.set_ticks_position("bottom") 210 | 211 | plt.tight_layout() 212 | plt.savefig("bouncing_ball.png") 213 | -------------------------------------------------------------------------------- /examples/cnf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import glob 4 | from PIL import Image 5 | import numpy as np 6 | import matplotlib 7 | matplotlib.use('agg') 8 | import matplotlib.pyplot as plt 9 | from sklearn.datasets import make_circles 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--adjoint', action='store_true') 17 | parser.add_argument('--viz', 
action='store_true') 18 | parser.add_argument('--niters', type=int, default=1000) 19 | parser.add_argument('--lr', type=float, default=1e-3) 20 | parser.add_argument('--num_samples', type=int, default=512) 21 | parser.add_argument('--width', type=int, default=64) 22 | parser.add_argument('--hidden_dim', type=int, default=32) 23 | parser.add_argument('--gpu', type=int, default=0) 24 | parser.add_argument('--train_dir', type=str, default=None) 25 | parser.add_argument('--results_dir', type=str, default="./results") 26 | args = parser.parse_args() 27 | 28 | if args.adjoint: 29 | from torchdiffeq import odeint_adjoint as odeint 30 | else: 31 | from torchdiffeq import odeint 32 | 33 | 34 | class CNF(nn.Module): 35 | """Adapted from the NumPy implementation at: 36 | https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52 37 | """ 38 | def __init__(self, in_out_dim, hidden_dim, width): 39 | super().__init__() 40 | self.in_out_dim = in_out_dim 41 | self.hidden_dim = hidden_dim 42 | self.width = width 43 | self.hyper_net = HyperNetwork(in_out_dim, hidden_dim, width) 44 | 45 | def forward(self, t, states): 46 | z = states[0] 47 | logp_z = states[1] 48 | 49 | batchsize = z.shape[0] 50 | 51 | with torch.set_grad_enabled(True): 52 | z.requires_grad_(True) 53 | 54 | W, B, U = self.hyper_net(t) 55 | 56 | Z = torch.unsqueeze(z, 0).repeat(self.width, 1, 1) 57 | 58 | h = torch.tanh(torch.matmul(Z, W) + B) 59 | dz_dt = torch.matmul(h, U).mean(0) 60 | 61 | dlogp_z_dt = -trace_df_dz(dz_dt, z).view(batchsize, 1) 62 | 63 | return (dz_dt, dlogp_z_dt) 64 | 65 | 66 | def trace_df_dz(f, z): 67 | """Calculates the trace of the Jacobian df/dz. 68 | Stolen from: https://github.com/rtqichen/ffjord/blob/master/lib/layers/odefunc.py#L13 69 | """ 70 | sum_diag = 0. 
71 | for i in range(z.shape[1]): 72 | sum_diag += torch.autograd.grad(f[:, i].sum(), z, create_graph=True)[0].contiguous()[:, i].contiguous() 73 | 74 | return sum_diag.contiguous() 75 | 76 | 77 | class HyperNetwork(nn.Module): 78 | """Hyper-network allowing f(z(t), t) to change with time. 79 | 80 | Adapted from the NumPy implementation at: 81 | https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52 82 | """ 83 | def __init__(self, in_out_dim, hidden_dim, width): 84 | super().__init__() 85 | 86 | blocksize = width * in_out_dim 87 | 88 | self.fc1 = nn.Linear(1, hidden_dim) 89 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 90 | self.fc3 = nn.Linear(hidden_dim, 3 * blocksize + width) 91 | 92 | self.in_out_dim = in_out_dim 93 | self.hidden_dim = hidden_dim 94 | self.width = width 95 | self.blocksize = blocksize 96 | 97 | def forward(self, t): 98 | # predict params 99 | params = t.reshape(1, 1) 100 | params = torch.tanh(self.fc1(params)) 101 | params = torch.tanh(self.fc2(params)) 102 | params = self.fc3(params) 103 | 104 | # restructure 105 | params = params.reshape(-1) 106 | W = params[:self.blocksize].reshape(self.width, self.in_out_dim, 1) 107 | 108 | U = params[self.blocksize:2 * self.blocksize].reshape(self.width, 1, self.in_out_dim) 109 | 110 | G = params[2 * self.blocksize:3 * self.blocksize].reshape(self.width, 1, self.in_out_dim) 111 | U = U * torch.sigmoid(G) 112 | 113 | B = params[3 * self.blocksize:].reshape(self.width, 1, 1) 114 | return [W, B, U] 115 | 116 | 117 | class RunningAverageMeter(object): 118 | """Computes and stores the average and current value""" 119 | 120 | def __init__(self, momentum=0.99): 121 | self.momentum = momentum 122 | self.reset() 123 | 124 | def reset(self): 125 | self.val = None 126 | self.avg = 0 127 | 128 | def update(self, val): 129 | if self.val is None: 130 | self.avg = val 131 | else: 132 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 133 | self.val = val 134 | 135 | 136 | def 
get_batch(num_samples): 137 | points, _ = make_circles(n_samples=num_samples, noise=0.06, factor=0.5) 138 | x = torch.tensor(points).type(torch.float32).to(device) 139 | logp_diff_t1 = torch.zeros(num_samples, 1).type(torch.float32).to(device) 140 | 141 | return(x, logp_diff_t1) 142 | 143 | 144 | if __name__ == '__main__': 145 | t0 = 0 146 | t1 = 10 147 | device = torch.device('cuda:' + str(args.gpu) 148 | if torch.cuda.is_available() else 'cpu') 149 | 150 | # model 151 | func = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width).to(device) 152 | optimizer = optim.Adam(func.parameters(), lr=args.lr) 153 | p_z0 = torch.distributions.MultivariateNormal( 154 | loc=torch.tensor([0.0, 0.0]).to(device), 155 | covariance_matrix=torch.tensor([[0.1, 0.0], [0.0, 0.1]]).to(device) 156 | ) 157 | loss_meter = RunningAverageMeter() 158 | 159 | if args.train_dir is not None: 160 | if not os.path.exists(args.train_dir): 161 | os.makedirs(args.train_dir) 162 | ckpt_path = os.path.join(args.train_dir, 'ckpt.pth') 163 | if os.path.exists(ckpt_path): 164 | checkpoint = torch.load(ckpt_path) 165 | func.load_state_dict(checkpoint['func_state_dict']) 166 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 167 | print('Loaded ckpt from {}'.format(ckpt_path)) 168 | 169 | try: 170 | for itr in range(1, args.niters + 1): 171 | optimizer.zero_grad() 172 | 173 | x, logp_diff_t1 = get_batch(args.num_samples) 174 | 175 | z_t, logp_diff_t = odeint( 176 | func, 177 | (x, logp_diff_t1), 178 | torch.tensor([t1, t0]).type(torch.float32).to(device), 179 | atol=1e-5, 180 | rtol=1e-5, 181 | method='dopri5', 182 | ) 183 | 184 | z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1] 185 | 186 | logp_x = p_z0.log_prob(z_t0).to(device) - logp_diff_t0.view(-1) 187 | loss = -logp_x.mean(0) 188 | 189 | loss.backward() 190 | optimizer.step() 191 | 192 | loss_meter.update(loss.item()) 193 | 194 | print('Iter: {}, running avg loss: {:.4f}'.format(itr, loss_meter.avg)) 195 | 196 | except 
KeyboardInterrupt: 197 | if args.train_dir is not None: 198 | ckpt_path = os.path.join(args.train_dir, 'ckpt.pth') 199 | torch.save({ 200 | 'func_state_dict': func.state_dict(), 201 | 'optimizer_state_dict': optimizer.state_dict(), 202 | }, ckpt_path) 203 | print('Stored ckpt at {}'.format(ckpt_path)) 204 | print('Training complete after {} iters.'.format(itr)) 205 | 206 | if args.viz: 207 | viz_samples = 30000 208 | viz_timesteps = 41 209 | target_sample, _ = get_batch(viz_samples) 210 | 211 | if not os.path.exists(args.results_dir): 212 | os.makedirs(args.results_dir) 213 | with torch.no_grad(): 214 | # Generate evolution of samples 215 | z_t0 = p_z0.sample([viz_samples]).to(device) 216 | logp_diff_t0 = torch.zeros(viz_samples, 1).type(torch.float32).to(device) 217 | 218 | z_t_samples, _ = odeint( 219 | func, 220 | (z_t0, logp_diff_t0), 221 | torch.tensor(np.linspace(t0, t1, viz_timesteps)).to(device), 222 | atol=1e-5, 223 | rtol=1e-5, 224 | method='dopri5', 225 | ) 226 | 227 | # Generate evolution of density 228 | x = np.linspace(-1.5, 1.5, 100) 229 | y = np.linspace(-1.5, 1.5, 100) 230 | points = np.vstack(np.meshgrid(x, y)).reshape([2, -1]).T 231 | 232 | z_t1 = torch.tensor(points).type(torch.float32).to(device) 233 | logp_diff_t1 = torch.zeros(z_t1.shape[0], 1).type(torch.float32).to(device) 234 | 235 | z_t_density, logp_diff_t = odeint( 236 | func, 237 | (z_t1, logp_diff_t1), 238 | torch.tensor(np.linspace(t1, t0, viz_timesteps)).to(device), 239 | atol=1e-5, 240 | rtol=1e-5, 241 | method='dopri5', 242 | ) 243 | 244 | # Create plots for each timestep 245 | for (t, z_sample, z_density, logp_diff) in zip( 246 | np.linspace(t0, t1, viz_timesteps), 247 | z_t_samples, z_t_density, logp_diff_t 248 | ): 249 | fig = plt.figure(figsize=(12, 4), dpi=200) 250 | plt.tight_layout() 251 | plt.axis('off') 252 | plt.margins(0, 0) 253 | fig.suptitle(f'{t:.2f}s') 254 | 255 | ax1 = fig.add_subplot(1, 3, 1) 256 | ax1.set_title('Target') 257 | ax1.get_xaxis().set_ticks([]) 258 | 
ax1.get_yaxis().set_ticks([]) 259 | ax2 = fig.add_subplot(1, 3, 2) 260 | ax2.set_title('Samples') 261 | ax2.get_xaxis().set_ticks([]) 262 | ax2.get_yaxis().set_ticks([]) 263 | ax3 = fig.add_subplot(1, 3, 3) 264 | ax3.set_title('Log Probability') 265 | ax3.get_xaxis().set_ticks([]) 266 | ax3.get_yaxis().set_ticks([]) 267 | 268 | ax1.hist2d(*target_sample.detach().cpu().numpy().T, bins=300, density=True, 269 | range=[[-1.5, 1.5], [-1.5, 1.5]]) 270 | 271 | ax2.hist2d(*z_sample.detach().cpu().numpy().T, bins=300, density=True, 272 | range=[[-1.5, 1.5], [-1.5, 1.5]]) 273 | 274 | logp = p_z0.log_prob(z_density) - logp_diff.view(-1) 275 | ax3.tricontourf(*z_t1.detach().cpu().numpy().T, 276 | np.exp(logp.detach().cpu().numpy()), 200) 277 | 278 | plt.savefig(os.path.join(args.results_dir, f"cnf-viz-{int(t*1000):05d}.jpg"), 279 | pad_inches=0.2, bbox_inches='tight') 280 | plt.close() 281 | 282 | img, *imgs = [Image.open(f) for f in sorted(glob.glob(os.path.join(args.results_dir, f"cnf-viz-*.jpg")))] 283 | img.save(fp=os.path.join(args.results_dir, "cnf-viz.gif"), format='GIF', append_images=imgs, 284 | save_all=True, duration=250, loop=0) 285 | 286 | print('Saved visualization animation at {}'.format(os.path.join(args.results_dir, "cnf-viz.gif"))) 287 | -------------------------------------------------------------------------------- /examples/latent_ode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import time 5 | import numpy as np 6 | import numpy.random as npr 7 | import matplotlib 8 | matplotlib.use('agg') 9 | import matplotlib.pyplot as plt 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--adjoint', type=eval, default=False) 18 | parser.add_argument('--visualize', type=eval, default=False) 19 | parser.add_argument('--niters', type=int, 
default=2000) 20 | parser.add_argument('--lr', type=float, default=0.01) 21 | parser.add_argument('--gpu', type=int, default=0) 22 | parser.add_argument('--train_dir', type=str, default=None) 23 | args = parser.parse_args() 24 | 25 | if args.adjoint: 26 | from torchdiffeq import odeint_adjoint as odeint 27 | else: 28 | from torchdiffeq import odeint 29 | 30 | 31 | def generate_spiral2d(nspiral=1000, 32 | ntotal=500, 33 | nsample=100, 34 | start=0., 35 | stop=1, # approximately equal to 6pi 36 | noise_std=.1, 37 | a=0., 38 | b=1., 39 | savefig=True): 40 | """Parametric formula for 2d spiral is `r = a + b * theta`. 41 | 42 | Args: 43 | nspiral: number of spirals, i.e. batch dimension 44 | ntotal: total number of datapoints per spiral 45 | nsample: number of sampled datapoints for model fitting per spiral 46 | start: spiral starting theta value 47 | stop: spiral ending theta value 48 | noise_std: observation noise standard deviation 49 | a, b: parameters of the Archimedean spiral 50 | savefig: plot the ground truth for sanity check 51 | 52 | Returns: 53 | Tuple where first element is true trajectory of size (nspiral, ntotal, 2), 54 | second element is noisy observations of size (nspiral, nsample, 2), 55 | third element is timestamps of size (ntotal,), 56 | and fourth element is timestamps of size (nsample,) 57 | """ 58 | 59 | # add 1 to all timestamps to avoid division by 0 60 | orig_ts = np.linspace(start, stop, num=ntotal) 61 | samp_ts = orig_ts[:nsample] 62 | 63 | # generate clock-wise and counter clock-wise spirals in observation space 64 | # with two sets of time-invariant latent dynamics 65 | zs_cw = stop + 1. - orig_ts 66 | rs_cw = a + b * 50.
/ zs_cw 67 | xs, ys = rs_cw * np.cos(zs_cw) - 5., rs_cw * np.sin(zs_cw) 68 | orig_traj_cw = np.stack((xs, ys), axis=1) 69 | 70 | zs_cc = orig_ts 71 | rw_cc = a + b * zs_cc 72 | xs, ys = rw_cc * np.cos(zs_cc) + 5., rw_cc * np.sin(zs_cc) 73 | orig_traj_cc = np.stack((xs, ys), axis=1) 74 | 75 | if savefig: 76 | plt.figure() 77 | plt.plot(orig_traj_cw[:, 0], orig_traj_cw[:, 1], label='clock') 78 | plt.plot(orig_traj_cc[:, 0], orig_traj_cc[:, 1], label='counter clock') 79 | plt.legend() 80 | plt.savefig('./ground_truth.png', dpi=500) 81 | print('Saved ground truth spiral at {}'.format('./ground_truth.png')) 82 | 83 | # sample starting timestamps 84 | orig_trajs = [] 85 | samp_trajs = [] 86 | for _ in range(nspiral): 87 | # don't sample t0 very near the start or the end 88 | t0_idx = npr.multinomial( 89 | 1, [1. / (ntotal - 2. * nsample)] * (ntotal - int(2 * nsample))) 90 | t0_idx = np.argmax(t0_idx) + nsample 91 | 92 | cc = bool(npr.rand() > .5) # uniformly select rotation 93 | orig_traj = orig_traj_cc if cc else orig_traj_cw 94 | orig_trajs.append(orig_traj) 95 | 96 | samp_traj = orig_traj[t0_idx:t0_idx + nsample, :].copy() 97 | samp_traj += npr.randn(*samp_traj.shape) * noise_std 98 | samp_trajs.append(samp_traj) 99 | 100 | # batching for sample trajectories is good for RNN; batching for original 101 | # trajectories only for ease of indexing 102 | orig_trajs = np.stack(orig_trajs, axis=0) 103 | samp_trajs = np.stack(samp_trajs, axis=0) 104 | 105 | return orig_trajs, samp_trajs, orig_ts, samp_ts 106 | 107 | 108 | class LatentODEfunc(nn.Module): 109 | 110 | def __init__(self, latent_dim=4, nhidden=20): 111 | super(LatentODEfunc, self).__init__() 112 | self.elu = nn.ELU(inplace=True) 113 | self.fc1 = nn.Linear(latent_dim, nhidden) 114 | self.fc2 = nn.Linear(nhidden, nhidden) 115 | self.fc3 = nn.Linear(nhidden, latent_dim) 116 | self.nfe = 0 117 | 118 | def forward(self, t, x): 119 | self.nfe += 1 120 | out = self.fc1(x) 121 | out = self.elu(out) 122 | out = 
self.fc2(out) 123 | out = self.elu(out) 124 | out = self.fc3(out) 125 | return out 126 | 127 | 128 | class RecognitionRNN(nn.Module): 129 | 130 | def __init__(self, latent_dim=4, obs_dim=2, nhidden=25, nbatch=1): 131 | super(RecognitionRNN, self).__init__() 132 | self.nhidden = nhidden 133 | self.nbatch = nbatch 134 | self.i2h = nn.Linear(obs_dim + nhidden, nhidden) 135 | self.h2o = nn.Linear(nhidden, latent_dim * 2) 136 | 137 | def forward(self, x, h): 138 | combined = torch.cat((x, h), dim=1) 139 | h = torch.tanh(self.i2h(combined)) 140 | out = self.h2o(h) 141 | return out, h 142 | 143 | def initHidden(self): 144 | return torch.zeros(self.nbatch, self.nhidden) 145 | 146 | 147 | class Decoder(nn.Module): 148 | 149 | def __init__(self, latent_dim=4, obs_dim=2, nhidden=20): 150 | super(Decoder, self).__init__() 151 | self.relu = nn.ReLU(inplace=True) 152 | self.fc1 = nn.Linear(latent_dim, nhidden) 153 | self.fc2 = nn.Linear(nhidden, obs_dim) 154 | 155 | def forward(self, z): 156 | out = self.fc1(z) 157 | out = self.relu(out) 158 | out = self.fc2(out) 159 | return out 160 | 161 | 162 | class RunningAverageMeter(object): 163 | """Computes and stores the average and current value""" 164 | 165 | def __init__(self, momentum=0.99): 166 | self.momentum = momentum 167 | self.reset() 168 | 169 | def reset(self): 170 | self.val = None 171 | self.avg = 0 172 | 173 | def update(self, val): 174 | if self.val is None: 175 | self.avg = val 176 | else: 177 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 178 | self.val = val 179 | 180 | 181 | def log_normal_pdf(x, mean, logvar): 182 | const = torch.from_numpy(np.array([2. * np.pi])).float().to(x.device) 183 | const = torch.log(const) 184 | return -.5 * (const + logvar + (x - mean) ** 2. / torch.exp(logvar)) 185 | 186 | 187 | def normal_kl(mu1, lv1, mu2, lv2): 188 | v1 = torch.exp(lv1) 189 | v2 = torch.exp(lv2) 190 | lstd1 = lv1 / 2. 191 | lstd2 = lv2 / 2. 192 | 193 | kl = lstd2 - lstd1 + ((v1 + (mu1 - mu2) ** 2.) 
/ (2. * v2)) - .5 194 | return kl 195 | 196 | 197 | if __name__ == '__main__': 198 | latent_dim = 4 199 | nhidden = 20 200 | rnn_nhidden = 25 201 | obs_dim = 2 202 | nspiral = 1000 203 | start = 0. 204 | stop = 6 * np.pi 205 | noise_std = .3 206 | a = 0. 207 | b = .3 208 | ntotal = 1000 209 | nsample = 100 210 | device = torch.device('cuda:' + str(args.gpu) 211 | if torch.cuda.is_available() else 'cpu') 212 | 213 | # generate toy spiral data 214 | orig_trajs, samp_trajs, orig_ts, samp_ts = generate_spiral2d( 215 | nspiral=nspiral, 216 | start=start, 217 | stop=stop, 218 | noise_std=noise_std, 219 | a=a, b=b 220 | ) 221 | orig_trajs = torch.from_numpy(orig_trajs).float().to(device) 222 | samp_trajs = torch.from_numpy(samp_trajs).float().to(device) 223 | samp_ts = torch.from_numpy(samp_ts).float().to(device) 224 | 225 | # model 226 | func = LatentODEfunc(latent_dim, nhidden).to(device) 227 | rec = RecognitionRNN(latent_dim, obs_dim, rnn_nhidden, nspiral).to(device) 228 | dec = Decoder(latent_dim, obs_dim, nhidden).to(device) 229 | params = (list(func.parameters()) + list(dec.parameters()) + list(rec.parameters())) 230 | optimizer = optim.Adam(params, lr=args.lr) 231 | loss_meter = RunningAverageMeter() 232 | 233 | if args.train_dir is not None: 234 | if not os.path.exists(args.train_dir): 235 | os.makedirs(args.train_dir) 236 | ckpt_path = os.path.join(args.train_dir, 'ckpt.pth') 237 | if os.path.exists(ckpt_path): 238 | checkpoint = torch.load(ckpt_path) 239 | func.load_state_dict(checkpoint['func_state_dict']) 240 | rec.load_state_dict(checkpoint['rec_state_dict']) 241 | dec.load_state_dict(checkpoint['dec_state_dict']) 242 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 243 | orig_trajs = checkpoint['orig_trajs'] 244 | samp_trajs = checkpoint['samp_trajs'] 245 | orig_ts = checkpoint['orig_ts'] 246 | samp_ts = checkpoint['samp_ts'] 247 | print('Loaded ckpt from {}'.format(ckpt_path)) 248 | 249 | try: 250 | for itr in range(1, args.niters + 1): 251 | 
optimizer.zero_grad() 252 | # backward in time to infer q(z_0) 253 | h = rec.initHidden().to(device) 254 | for t in reversed(range(samp_trajs.size(1))): 255 | obs = samp_trajs[:, t, :] 256 | out, h = rec.forward(obs, h) 257 | qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:] 258 | epsilon = torch.randn(qz0_mean.size()).to(device) 259 | z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean 260 | 261 | # forward in time and solve ode for reconstructions 262 | pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2) 263 | pred_x = dec(pred_z) 264 | 265 | # compute loss 266 | noise_std_ = torch.zeros(pred_x.size()).to(device) + noise_std 267 | noise_logvar = 2. * torch.log(noise_std_).to(device) 268 | logpx = log_normal_pdf( 269 | samp_trajs, pred_x, noise_logvar).sum(-1).sum(-1) 270 | pz0_mean = pz0_logvar = torch.zeros(z0.size()).to(device) 271 | analytic_kl = normal_kl(qz0_mean, qz0_logvar, 272 | pz0_mean, pz0_logvar).sum(-1) 273 | loss = torch.mean(-logpx + analytic_kl, dim=0) 274 | loss.backward() 275 | optimizer.step() 276 | loss_meter.update(loss.item()) 277 | 278 | print('Iter: {}, running avg elbo: {:.4f}'.format(itr, -loss_meter.avg)) 279 | 280 | except KeyboardInterrupt: 281 | if args.train_dir is not None: 282 | ckpt_path = os.path.join(args.train_dir, 'ckpt.pth') 283 | torch.save({ 284 | 'func_state_dict': func.state_dict(), 285 | 'rec_state_dict': rec.state_dict(), 286 | 'dec_state_dict': dec.state_dict(), 287 | 'optimizer_state_dict': optimizer.state_dict(), 288 | 'orig_trajs': orig_trajs, 289 | 'samp_trajs': samp_trajs, 290 | 'orig_ts': orig_ts, 291 | 'samp_ts': samp_ts, 292 | }, ckpt_path) 293 | print('Stored ckpt at {}'.format(ckpt_path)) 294 | print('Training complete after {} iters.'.format(itr)) 295 | 296 | if args.visualize: 297 | with torch.no_grad(): 298 | # sample from trajectories' approx.
posterior 299 | h = rec.initHidden().to(device) 300 | for t in reversed(range(samp_trajs.size(1))): 301 | obs = samp_trajs[:, t, :] 302 | out, h = rec.forward(obs, h) 303 | qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:] 304 | epsilon = torch.randn(qz0_mean.size()).to(device) 305 | z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean 306 | orig_ts = torch.from_numpy(orig_ts).float().to(device) 307 | 308 | # take first trajectory for visualization 309 | z0 = z0[0] 310 | 311 | ts_pos = np.linspace(0., 2. * np.pi, num=2000) 312 | ts_neg = np.linspace(-np.pi, 0., num=2000)[::-1].copy() 313 | ts_pos = torch.from_numpy(ts_pos).float().to(device) 314 | ts_neg = torch.from_numpy(ts_neg).float().to(device) 315 | 316 | zs_pos = odeint(func, z0, ts_pos) 317 | zs_neg = odeint(func, z0, ts_neg) 318 | 319 | xs_pos = dec(zs_pos) 320 | xs_neg = torch.flip(dec(zs_neg), dims=[0]) 321 | 322 | xs_pos = xs_pos.cpu().numpy() 323 | xs_neg = xs_neg.cpu().numpy() 324 | orig_traj = orig_trajs[0].cpu().numpy() 325 | samp_traj = samp_trajs[0].cpu().numpy() 326 | 327 | plt.figure() 328 | plt.plot(orig_traj[:, 0], orig_traj[:, 1], 329 | 'g', label='true trajectory') 330 | plt.plot(xs_pos[:, 0], xs_pos[:, 1], 'r', 331 | label='learned trajectory (t>0)') 332 | plt.plot(xs_neg[:, 0], xs_neg[:, 1], 'c', 333 | label='learned trajectory (t<0)') 334 | plt.scatter(samp_traj[:, 0], samp_traj[ 335 | :, 1], label='sampled data', s=3) 336 | plt.legend() 337 | plt.savefig('./vis.png', dpi=500) 338 | print('Saved visualization figure at {}'.format('./vis.png')) 339 | -------------------------------------------------------------------------------- /examples/learn_physics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | import math 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.nn as nn 9 | from torchdiffeq import odeint, odeint_event 10 | 11 | from bouncing_ball 
import BouncingBallExample 12 | 13 | 14 | class HamiltonianDynamics(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | self.dvel = nn.Linear(1, 1) 18 | self.scale = nn.Parameter(torch.tensor(10.0)) 19 | 20 | def forward(self, t, state): 21 | pos, vel, *rest = state 22 | dpos = vel 23 | dvel = torch.tanh(self.dvel(torch.zeros_like(vel))) * self.scale 24 | return (dpos, dvel, *[torch.zeros_like(r) for r in rest]) 25 | 26 | 27 | class EventFn(nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | self.radius = nn.Parameter(torch.rand(1)) 31 | 32 | def parameters(self): 33 | return [self.radius] 34 | 35 | def forward(self, t, state): 36 | # IMPORTANT: event computation must use variables from the state. 37 | pos, _, radius = state 38 | return pos - radius.reshape_as(pos) ** 2 39 | 40 | 41 | class InstantaneousStateChange(nn.Module): 42 | def __init__(self): 43 | super().__init__() 44 | self.net = nn.Linear(1, 1) 45 | 46 | def forward(self, t, state): 47 | pos, vel, *rest = state 48 | vel = -torch.sigmoid(self.net(torch.ones_like(vel))) * vel 49 | return (pos, vel, *rest) 50 | 51 | 52 | class NeuralPhysics(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | self.initial_pos = nn.Parameter(torch.tensor([10.0])) 56 | self.initial_vel = nn.Parameter(torch.tensor([0.0])) 57 | self.dynamics_fn = HamiltonianDynamics() 58 | self.event_fn = EventFn() 59 | self.inst_update = InstantaneousStateChange() 60 | 61 | def simulate(self, times): 62 | 63 | t0 = torch.tensor([0.0]).to(times) 64 | 65 | # Add a terminal time to the event function. 66 | def event_fn(t, state): 67 | if t > times[-1] + 1e-7: 68 | return torch.zeros_like(t) 69 | event_fval = self.event_fn(t, state) 70 | return event_fval 71 | 72 | # IMPORTANT: for gradients of odeint_event to be computed, parameters of the event function 73 | # must appear in the state in the current implementation. 
74 | state = (self.initial_pos, self.initial_vel, *self.event_fn.parameters()) 75 | 76 | event_times = [] 77 | 78 | trajectory = [state[0][None]] 79 | 80 | n_events = 0 81 | max_events = 20 82 | 83 | while t0 < times[-1] and n_events < max_events: 84 | last = n_events == max_events - 1 85 | 86 | if not last: 87 | event_t, solution = odeint_event( 88 | self.dynamics_fn, 89 | state, 90 | t0, 91 | event_fn=event_fn, 92 | atol=1e-8, 93 | rtol=1e-8, 94 | method="dopri5", 95 | ) 96 | else: 97 | event_t = times[-1] 98 | 99 | interval_ts = times[times > t0] 100 | interval_ts = interval_ts[interval_ts <= event_t] 101 | interval_ts = torch.cat([t0.reshape(-1), interval_ts.reshape(-1)]) 102 | 103 | solution_ = odeint( 104 | self.dynamics_fn, state, interval_ts, atol=1e-8, rtol=1e-8 105 | ) 106 | traj_ = solution_[0][1:] # [0] for position; [1:] to remove initial state. 107 | trajectory.append(traj_) 108 | 109 | if event_t < times[-1]: 110 | state = tuple(s[-1] for s in solution) 111 | 112 | # update velocity instantaneously. 113 | state = self.inst_update(event_t, state) 114 | 115 | # advance the position a little bit to avoid re-triggering the event fn.
116 | pos, *rest = state 117 | pos = pos + 1e-7 * self.dynamics_fn(event_t, state)[0] 118 | state = pos, *rest 119 | 120 | event_times.append(event_t) 121 | t0 = event_t 122 | 123 | n_events += 1 124 | 125 | # print(event_t.item(), state[0].item(), state[1].item(), self.event_fn.mod(pos).item()) 126 | 127 | trajectory = torch.cat(trajectory, dim=0).reshape(-1) 128 | return trajectory, event_times 129 | 130 | 131 | class Sine(nn.Module): 132 | def forward(self, x): 133 | return torch.sin(x) 134 | 135 | 136 | class NeuralODE(nn.Module): 137 | def __init__(self, aug_dim=2): 138 | super().__init__() 139 | self.initial_pos = nn.Parameter(torch.tensor([10.0])) 140 | self.initial_aug = nn.Parameter(torch.zeros(aug_dim)) 141 | self.odefunc = mlp( 142 | input_dim=1 + aug_dim, 143 | hidden_dim=64, 144 | output_dim=1 + aug_dim, 145 | hidden_depth=2, 146 | act=Sine, 147 | ) 148 | 149 | def init(m): 150 | if isinstance(m, nn.Linear): 151 | std = 1.0 / math.sqrt(m.weight.size(1)) 152 | m.weight.data.uniform_(-2.0 * std, 2.0 * std) 153 | m.bias.data.zero_() 154 | 155 | self.odefunc.apply(init) 156 | 157 | def forward(self, t, state): 158 | return self.odefunc(state) 159 | 160 | def simulate(self, times): 161 | x0 = torch.cat([self.initial_pos, self.initial_aug]).reshape(-1) 162 | solution = odeint(self, x0, times, atol=1e-8, rtol=1e-8, method="dopri5") 163 | trajectory = solution[:, 0] 164 | return trajectory, [] 165 | 166 | 167 | def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None, act=nn.ReLU): 168 | if hidden_depth == 0: 169 | mods = [nn.Linear(input_dim, output_dim)] 170 | else: 171 | mods = [nn.Linear(input_dim, hidden_dim), act()] 172 | for i in range(hidden_depth - 1): 173 | mods += [nn.Linear(hidden_dim, hidden_dim), act()] 174 | mods.append(nn.Linear(hidden_dim, output_dim)) 175 | if output_mod is not None: 176 | mods.append(output_mod) 177 | trunk = nn.Sequential(*mods) 178 | return trunk 179 | 180 | 181 | def cosine_decay(learning_rate, 
global_step, decay_steps, alpha=0.0): 182 | global_step = min(global_step, decay_steps) 183 | cosine_decay = 0.5 * (1 + math.cos(math.pi * global_step / decay_steps)) 184 | decayed = (1 - alpha) * cosine_decay + alpha 185 | return learning_rate * decayed 186 | 187 | 188 | def learning_rate_schedule( 189 | global_step, warmup_steps, base_learning_rate, lr_scaling, train_steps 190 | ): 191 | warmup_steps = int(round(warmup_steps)) 192 | scaled_lr = base_learning_rate * lr_scaling 193 | if warmup_steps: 194 | learning_rate = global_step / warmup_steps * scaled_lr 195 | else: 196 | learning_rate = scaled_lr 197 | 198 | if global_step < warmup_steps: 199 | learning_rate = learning_rate 200 | else: 201 | learning_rate = cosine_decay( 202 | scaled_lr, global_step - warmup_steps, train_steps - warmup_steps 203 | ) 204 | return learning_rate 205 | 206 | 207 | def set_learning_rate(optimizer, lr): 208 | for group in optimizer.param_groups: 209 | group["lr"] = lr 210 | 211 | 212 | if __name__ == "__main__": 213 | 214 | parser = argparse.ArgumentParser() 215 | parser.add_argument("--base_lr", type=float, default=0.1) 216 | parser.add_argument("--num_iterations", type=int, default=1000) 217 | parser.add_argument("--no_events", action="store_true") 218 | parser.add_argument("--save", type=str, default="figs") 219 | args = parser.parse_args() 220 | 221 | torch.manual_seed(0) 222 | 223 | torch.set_default_dtype(torch.float64) 224 | 225 | with torch.no_grad(): 226 | system = BouncingBallExample() 227 | obs_times, gt_trajectory, _, _ = system.simulate(nbounces=4) 228 | 229 | obs_times = obs_times[:300] 230 | gt_trajectory = gt_trajectory[:300] 231 | 232 | if args.no_events: 233 | model = NeuralODE() 234 | else: 235 | model = NeuralPhysics() 236 | optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr) 237 | 238 | decay = 1.0 239 | 240 | model.train() 241 | for itr in range(args.num_iterations): 242 | optimizer.zero_grad() 243 | trajectory, event_times = 
model.simulate(obs_times) 244 | weights = decay**obs_times 245 | loss = ( 246 | ((trajectory - gt_trajectory) / (gt_trajectory + 1e-3)) 247 | .abs() 248 | .mul(weights) 249 | .mean() 250 | ) 251 | loss.backward() 252 | 253 | lr = learning_rate_schedule(itr, 0, args.base_lr, 1.0, args.num_iterations) 254 | set_learning_rate(optimizer, lr) 255 | optimizer.step() 256 | 257 | if itr % 10 == 0: 258 | print(itr, loss.item(), len(event_times)) 259 | 260 | if itr % 10 == 0: 261 | plt.figure() 262 | plt.plot( 263 | obs_times.detach().cpu().numpy(), 264 | gt_trajectory.detach().cpu().numpy(), 265 | label="Target", 266 | ) 267 | plt.plot( 268 | obs_times.detach().cpu().numpy(), 269 | trajectory.detach().cpu().numpy(), 270 | label="Learned", 271 | ) 272 | plt.tight_layout() 273 | os.makedirs(args.save, exist_ok=True) 274 | plt.savefig(f"{args.save}/{itr:05d}.png") 275 | plt.close() 276 | 277 | if (itr + 1) % 100 == 0: 278 | torch.save( 279 | { 280 | "state_dict": model.state_dict(), 281 | }, 282 | f"{args.save}/model.pt", 283 | ) 284 | 285 | del trajectory, loss 286 | -------------------------------------------------------------------------------- /examples/ode_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | 10 | parser = argparse.ArgumentParser('ODE demo') 11 | parser.add_argument('--method', type=str, choices=['dopri5', 'adams'], default='dopri5') 12 | parser.add_argument('--data_size', type=int, default=1000) 13 | parser.add_argument('--batch_time', type=int, default=10) 14 | parser.add_argument('--batch_size', type=int, default=20) 15 | parser.add_argument('--niters', type=int, default=2000) 16 | parser.add_argument('--test_freq', type=int, default=20) 17 | parser.add_argument('--viz', action='store_true') 18 | parser.add_argument('--gpu', type=int, default=0) 19 | 
parser.add_argument('--adjoint', action='store_true') 20 | args = parser.parse_args() 21 | 22 | if args.adjoint: 23 | from torchdiffeq import odeint_adjoint as odeint 24 | else: 25 | from torchdiffeq import odeint 26 | 27 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 28 | 29 | true_y0 = torch.tensor([[2., 0.]]).to(device) 30 | t = torch.linspace(0., 25., args.data_size).to(device) 31 | true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device) 32 | 33 | 34 | class Lambda(nn.Module): 35 | 36 | def forward(self, t, y): 37 | return torch.mm(y**3, true_A) 38 | 39 | 40 | with torch.no_grad(): 41 | true_y = odeint(Lambda(), true_y0, t, method='dopri5') 42 | 43 | 44 | def get_batch(): 45 | s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), args.batch_size, replace=False)) 46 | batch_y0 = true_y[s] # (M, D) 47 | batch_t = t[:args.batch_time] # (T) 48 | batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0) # (T, M, D) 49 | return batch_y0.to(device), batch_t.to(device), batch_y.to(device) 50 | 51 | 52 | def makedirs(dirname): 53 | if not os.path.exists(dirname): 54 | os.makedirs(dirname) 55 | 56 | 57 | if args.viz: 58 | makedirs('png') 59 | import matplotlib.pyplot as plt 60 | fig = plt.figure(figsize=(12, 4), facecolor='white') 61 | ax_traj = fig.add_subplot(131, frameon=False) 62 | ax_phase = fig.add_subplot(132, frameon=False) 63 | ax_vecfield = fig.add_subplot(133, frameon=False) 64 | plt.show(block=False) 65 | 66 | 67 | def visualize(true_y, pred_y, odefunc, itr): 68 | 69 | if args.viz: 70 | 71 | ax_traj.cla() 72 | ax_traj.set_title('Trajectories') 73 | ax_traj.set_xlabel('t') 74 | ax_traj.set_ylabel('x,y') 75 | ax_traj.plot(t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 0], t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 1], 'g-') 76 | ax_traj.plot(t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 0], '--', t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 1], 'b--') 77 
| ax_traj.set_xlim(t.cpu().min(), t.cpu().max()) 78 | ax_traj.set_ylim(-2, 2) 79 | ax_traj.legend() 80 | 81 | ax_phase.cla() 82 | ax_phase.set_title('Phase Portrait') 83 | ax_phase.set_xlabel('x') 84 | ax_phase.set_ylabel('y') 85 | ax_phase.plot(true_y.cpu().numpy()[:, 0, 0], true_y.cpu().numpy()[:, 0, 1], 'g-') 86 | ax_phase.plot(pred_y.cpu().numpy()[:, 0, 0], pred_y.cpu().numpy()[:, 0, 1], 'b--') 87 | ax_phase.set_xlim(-2, 2) 88 | ax_phase.set_ylim(-2, 2) 89 | 90 | ax_vecfield.cla() 91 | ax_vecfield.set_title('Learned Vector Field') 92 | ax_vecfield.set_xlabel('x') 93 | ax_vecfield.set_ylabel('y') 94 | 95 | y, x = np.mgrid[-2:2:21j, -2:2:21j] 96 | dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)).cpu().detach().numpy() 97 | mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1) 98 | dydt = (dydt / mag) 99 | dydt = dydt.reshape(21, 21, 2) 100 | 101 | ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black") 102 | ax_vecfield.set_xlim(-2, 2) 103 | ax_vecfield.set_ylim(-2, 2) 104 | 105 | fig.tight_layout() 106 | plt.savefig('png/{:03d}'.format(itr)) 107 | plt.draw() 108 | plt.pause(0.001) 109 | 110 | 111 | class ODEFunc(nn.Module): 112 | 113 | def __init__(self): 114 | super(ODEFunc, self).__init__() 115 | 116 | self.net = nn.Sequential( 117 | nn.Linear(2, 50), 118 | nn.Tanh(), 119 | nn.Linear(50, 2), 120 | ) 121 | 122 | for m in self.net.modules(): 123 | if isinstance(m, nn.Linear): 124 | nn.init.normal_(m.weight, mean=0, std=0.1) 125 | nn.init.constant_(m.bias, val=0) 126 | 127 | def forward(self, t, y): 128 | return self.net(y**3) 129 | 130 | 131 | class RunningAverageMeter(object): 132 | """Computes and stores the average and current value""" 133 | 134 | def __init__(self, momentum=0.99): 135 | self.momentum = momentum 136 | self.reset() 137 | 138 | def reset(self): 139 | self.val = None 140 | self.avg = 0 141 | 142 | def update(self, val): 143 | if self.val is None: 144 | self.avg = val 145 | else: 146 | 
self.avg = self.avg * self.momentum + val * (1 - self.momentum) 147 | self.val = val 148 | 149 | 150 | if __name__ == '__main__': 151 | 152 | ii = 0 153 | 154 | func = ODEFunc().to(device) 155 | 156 | optimizer = optim.RMSprop(func.parameters(), lr=1e-3) 157 | end = time.time() 158 | 159 | time_meter = RunningAverageMeter(0.97) 160 | 161 | loss_meter = RunningAverageMeter(0.97) 162 | 163 | for itr in range(1, args.niters + 1): 164 | optimizer.zero_grad() 165 | batch_y0, batch_t, batch_y = get_batch() 166 | pred_y = odeint(func, batch_y0, batch_t).to(device) 167 | loss = torch.mean(torch.abs(pred_y - batch_y)) 168 | loss.backward() 169 | optimizer.step() 170 | 171 | time_meter.update(time.time() - end) 172 | loss_meter.update(loss.item()) 173 | 174 | if itr % args.test_freq == 0: 175 | with torch.no_grad(): 176 | pred_y = odeint(func, true_y0, t) 177 | loss = torch.mean(torch.abs(pred_y - true_y)) 178 | print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item())) 179 | visualize(true_y, pred_y, func, ii) 180 | ii += 1 181 | 182 | end = time.time() 183 | -------------------------------------------------------------------------------- /examples/odenet_mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import time 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader 9 | import torchvision.datasets as datasets 10 | import torchvision.transforms as transforms 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet') 14 | parser.add_argument('--tol', type=float, default=1e-3) 15 | parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False]) 16 | parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res']) 17 | parser.add_argument('--nepochs', type=int, default=160) 18 | 
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False]) 19 | parser.add_argument('--lr', type=float, default=0.1) 20 | parser.add_argument('--batch_size', type=int, default=128) 21 | parser.add_argument('--test_batch_size', type=int, default=1000) 22 | 23 | parser.add_argument('--save', type=str, default='./experiment1') 24 | parser.add_argument('--debug', action='store_true') 25 | parser.add_argument('--gpu', type=int, default=0) 26 | args = parser.parse_args() 27 | 28 | if args.adjoint: 29 | from torchdiffeq import odeint_adjoint as odeint 30 | else: 31 | from torchdiffeq import odeint 32 | 33 | 34 | def conv3x3(in_planes, out_planes, stride=1): 35 | """3x3 convolution with padding""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 37 | 38 | 39 | def conv1x1(in_planes, out_planes, stride=1): 40 | """1x1 convolution""" 41 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 42 | 43 | 44 | def norm(dim): 45 | return nn.GroupNorm(min(32, dim), dim) 46 | 47 | 48 | class ResBlock(nn.Module): 49 | expansion = 1 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(ResBlock, self).__init__() 53 | self.norm1 = norm(inplanes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.downsample = downsample 56 | self.conv1 = conv3x3(inplanes, planes, stride) 57 | self.norm2 = norm(planes) 58 | self.conv2 = conv3x3(planes, planes) 59 | 60 | def forward(self, x): 61 | shortcut = x 62 | 63 | out = self.relu(self.norm1(x)) 64 | 65 | if self.downsample is not None: 66 | shortcut = self.downsample(out) 67 | 68 | out = self.conv1(out) 69 | out = self.norm2(out) 70 | out = self.relu(out) 71 | out = self.conv2(out) 72 | 73 | return out + shortcut 74 | 75 | 76 | class ConcatConv2d(nn.Module): 77 | 78 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False): 79 | super(ConcatConv2d, 
self).__init__() 80 | module = nn.ConvTranspose2d if transpose else nn.Conv2d 81 | self._layer = module( 82 | dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups, 83 | bias=bias 84 | ) 85 | 86 | def forward(self, t, x): 87 | tt = torch.ones_like(x[:, :1, :, :]) * t 88 | ttx = torch.cat([tt, x], 1) 89 | return self._layer(ttx) 90 | 91 | 92 | class ODEfunc(nn.Module): 93 | 94 | def __init__(self, dim): 95 | super(ODEfunc, self).__init__() 96 | self.norm1 = norm(dim) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1) 99 | self.norm2 = norm(dim) 100 | self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1) 101 | self.norm3 = norm(dim) 102 | self.nfe = 0 103 | 104 | def forward(self, t, x): 105 | self.nfe += 1 106 | out = self.norm1(x) 107 | out = self.relu(out) 108 | out = self.conv1(t, out) 109 | out = self.norm2(out) 110 | out = self.relu(out) 111 | out = self.conv2(t, out) 112 | out = self.norm3(out) 113 | return out 114 | 115 | 116 | class ODEBlock(nn.Module): 117 | 118 | def __init__(self, odefunc): 119 | super(ODEBlock, self).__init__() 120 | self.odefunc = odefunc 121 | self.integration_time = torch.tensor([0, 1]).float() 122 | 123 | def forward(self, x): 124 | self.integration_time = self.integration_time.type_as(x) 125 | out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol) 126 | return out[1] 127 | 128 | @property 129 | def nfe(self): 130 | return self.odefunc.nfe 131 | 132 | @nfe.setter 133 | def nfe(self, value): 134 | self.odefunc.nfe = value 135 | 136 | 137 | class Flatten(nn.Module): 138 | 139 | def __init__(self): 140 | super(Flatten, self).__init__() 141 | 142 | def forward(self, x): 143 | shape = torch.prod(torch.tensor(x.shape[1:])).item() 144 | return x.view(-1, shape) 145 | 146 | 147 | class RunningAverageMeter(object): 148 | """Computes and stores the average and current value""" 149 | 150 | def __init__(self, momentum=0.99): 151 | 
self.momentum = momentum 152 | self.reset() 153 | 154 | def reset(self): 155 | self.val = None 156 | self.avg = 0 157 | 158 | def update(self, val): 159 | if self.val is None: 160 | self.avg = val 161 | else: 162 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 163 | self.val = val 164 | 165 | 166 | def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0): 167 | if data_aug: 168 | transform_train = transforms.Compose([ 169 | transforms.RandomCrop(28, padding=4), 170 | transforms.ToTensor(), 171 | ]) 172 | else: 173 | transform_train = transforms.Compose([ 174 | transforms.ToTensor(), 175 | ]) 176 | 177 | transform_test = transforms.Compose([ 178 | transforms.ToTensor(), 179 | ]) 180 | 181 | train_loader = DataLoader( 182 | datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size, 183 | shuffle=True, num_workers=2, drop_last=True 184 | ) 185 | 186 | train_eval_loader = DataLoader( 187 | datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test), 188 | batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True 189 | ) 190 | 191 | test_loader = DataLoader( 192 | datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test), 193 | batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True 194 | ) 195 | 196 | return train_loader, test_loader, train_eval_loader 197 | 198 | 199 | def inf_generator(iterable): 200 | """Allows training with DataLoaders in a single infinite loop: 201 | for i, (x, y) in enumerate(inf_generator(train_loader)): 202 | """ 203 | iterator = iterable.__iter__() 204 | while True: 205 | try: 206 | yield iterator.__next__() 207 | except StopIteration: 208 | iterator = iterable.__iter__() 209 | 210 | 211 | def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates): 212 | initial_learning_rate = args.lr * batch_size / batch_denom 
213 | 214 | boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs] 215 | vals = [initial_learning_rate * decay for decay in decay_rates] 216 | 217 | def learning_rate_fn(itr): 218 | lt = [itr < b for b in boundaries] + [True] 219 | i = np.argmax(lt) 220 | return vals[i] 221 | 222 | return learning_rate_fn 223 | 224 | 225 | def one_hot(x, K): 226 | return np.array(x[:, None] == np.arange(K)[None, :], dtype=int) 227 | 228 | 229 | def accuracy(model, dataset_loader): 230 | total_correct = 0 231 | for x, y in dataset_loader: 232 | x = x.to(device) 233 | y = one_hot(np.array(y.numpy()), 10) 234 | 235 | target_class = np.argmax(y, axis=1) 236 | predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1) 237 | total_correct += np.sum(predicted_class == target_class) 238 | return total_correct / len(dataset_loader.dataset) 239 | 240 | 241 | def count_parameters(model): 242 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 243 | 244 | 245 | def makedirs(dirname): 246 | if not os.path.exists(dirname): 247 | os.makedirs(dirname) 248 | 249 | 250 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False): 251 | logger = logging.getLogger() 252 | if debug: 253 | level = logging.DEBUG 254 | else: 255 | level = logging.INFO 256 | logger.setLevel(level) 257 | if saving: 258 | info_file_handler = logging.FileHandler(logpath, mode="a") 259 | info_file_handler.setLevel(level) 260 | logger.addHandler(info_file_handler) 261 | if displaying: 262 | console_handler = logging.StreamHandler() 263 | console_handler.setLevel(level) 264 | logger.addHandler(console_handler) 265 | logger.info(filepath) 266 | with open(filepath, "r") as f: 267 | logger.info(f.read()) 268 | 269 | for f in package_files: 270 | logger.info(f) 271 | with open(f, "r") as package_f: 272 | logger.info(package_f.read()) 273 | 274 | return logger 275 | 276 | 277 | if __name__ == '__main__': 278 | 279 | makedirs(args.save) 280 | logger = 
get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 281 | logger.info(args) 282 | 283 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 284 | 285 | is_odenet = args.network == 'odenet' 286 | 287 | if args.downsampling_method == 'conv': 288 | downsampling_layers = [ 289 | nn.Conv2d(1, 64, 3, 1), 290 | norm(64), 291 | nn.ReLU(inplace=True), 292 | nn.Conv2d(64, 64, 4, 2, 1), 293 | norm(64), 294 | nn.ReLU(inplace=True), 295 | nn.Conv2d(64, 64, 4, 2, 1), 296 | ] 297 | elif args.downsampling_method == 'res': 298 | downsampling_layers = [ 299 | nn.Conv2d(1, 64, 3, 1), 300 | ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)), 301 | ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)), 302 | ] 303 | 304 | feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)] 305 | fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)] 306 | 307 | model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device) 308 | 309 | logger.info(model) 310 | logger.info('Number of parameters: {}'.format(count_parameters(model))) 311 | 312 | criterion = nn.CrossEntropyLoss().to(device) 313 | 314 | train_loader, test_loader, train_eval_loader = get_mnist_loaders( 315 | args.data_aug, args.batch_size, args.test_batch_size 316 | ) 317 | 318 | data_gen = inf_generator(train_loader) 319 | batches_per_epoch = len(train_loader) 320 | 321 | lr_fn = learning_rate_with_decay( 322 | args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140], 323 | decay_rates=[1, 0.1, 0.01, 0.001] 324 | ) 325 | 326 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9) 327 | 328 | best_acc = 0 329 | batch_time_meter = RunningAverageMeter() 330 | f_nfe_meter = RunningAverageMeter() 331 | b_nfe_meter = RunningAverageMeter() 332 | end = time.time() 333 | 334 | for itr in 
range(args.nepochs * batches_per_epoch): 335 | 336 | for param_group in optimizer.param_groups: 337 | param_group['lr'] = lr_fn(itr) 338 | 339 | optimizer.zero_grad() 340 | x, y = data_gen.__next__() 341 | x = x.to(device) 342 | y = y.to(device) 343 | logits = model(x) 344 | loss = criterion(logits, y) 345 | 346 | if is_odenet: 347 | nfe_forward = feature_layers[0].nfe 348 | feature_layers[0].nfe = 0 349 | 350 | loss.backward() 351 | optimizer.step() 352 | 353 | if is_odenet: 354 | nfe_backward = feature_layers[0].nfe 355 | feature_layers[0].nfe = 0 356 | 357 | batch_time_meter.update(time.time() - end) 358 | if is_odenet: 359 | f_nfe_meter.update(nfe_forward) 360 | b_nfe_meter.update(nfe_backward) 361 | end = time.time() 362 | 363 | if itr % batches_per_epoch == 0: 364 | with torch.no_grad(): 365 | train_acc = accuracy(model, train_eval_loader) 366 | val_acc = accuracy(model, test_loader) 367 | if val_acc > best_acc: 368 | torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth')) 369 | best_acc = val_acc 370 | logger.info( 371 | "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | " 372 | "Train Acc {:.4f} | Test Acc {:.4f}".format( 373 | itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg, 374 | b_nfe_meter.avg, train_acc, val_acc 375 | ) 376 | ) 377 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import setuptools 4 | 5 | 6 | # for simplicity we actually store the version in the __version__ attribute in the source 7 | here = os.path.realpath(os.path.dirname(__file__)) 8 | with open(os.path.join(here, 'torchdiffeq', '__init__.py')) as f: 9 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) 10 | if meta_match: 11 | version = meta_match.group(1) 12 | else: 13 | raise RuntimeError("Unable to 
find __version__ string.") 14 | 15 | 16 | setuptools.setup( 17 | name="torchdiffeq", 18 | version=version, 19 | author="Ricky Tian Qi Chen", 20 | author_email="rtqichen@cs.toronto.edu", 21 | description="ODE solvers and adjoint sensitivity analysis in PyTorch.", 22 | url="https://github.com/rtqichen/torchdiffeq", 23 | packages=setuptools.find_packages(), 24 | install_requires=['torch>=1.5.0', 'scipy>=1.4.0'], 25 | python_requires='~=3.6', 26 | classifiers=[ 27 | "Programming Language :: Python :: 3", 28 | "License :: OSI Approved :: MIT License", 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /tests/DETEST/detest.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | #################################### 6 | # Problem Class A. Single equations. 7 | #################################### 8 | def A1(): 9 | diffeq = lambda t, y: -y 10 | init = lambda: (torch.tensor(0.), torch.tensor(1.)) 11 | solution = lambda t: torch.exp(-t) 12 | return diffeq, init, solution 13 | 14 | 15 | def A2(): 16 | diffeq = lambda t, y: -y**3 / 2 17 | init = lambda: (torch.tensor(0.), torch.tensor(1.)) 18 | solution = lambda t: 1 / torch.sqrt(t + 1) 19 | return diffeq, init, solution 20 | 21 | 22 | def A3(): 23 | diffeq = lambda t, y: y * torch.cos(t) 24 | init = lambda: (torch.tensor(0.), torch.tensor(1.)) 25 | solution = lambda t: torch.exp(torch.sin(t)) 26 | return diffeq, init, solution 27 | 28 | 29 | def A4(): 30 | diffeq = lambda t, y: y / 4 * (1 - y / 20) 31 | init = lambda: (torch.tensor(0.), torch.tensor(1.)) 32 | solution = lambda t: 20 / (1 + 19 * torch.exp(-t / 4)) 33 | return diffeq, init, solution 34 | 35 | 36 | def A5(): 37 | diffeq = lambda t, y: (y - t) / (y + t) 38 | init = lambda: (torch.tensor(0.), torch.tensor(4.)) 39 | return diffeq, init, None 40 | 41 | 42 | ################################# 43 | # Problem Class B. Small systems. 
44 | ################################# 45 | def B1(): 46 | 47 | def diffeq(t, y): 48 | dy0 = 2 * (y[0] - y[0] * y[1]) 49 | dy1 = -(y[1] - y[0] * y[1]) 50 | return torch.stack([dy0, dy1]) 51 | 52 | def init(): 53 | return torch.tensor(0.), torch.tensor([1., 3.]) 54 | 55 | return diffeq, init, None 56 | 57 | 58 | def B2(): 59 | 60 | A = torch.tensor([[-1., 1., 0.], [1., -2., 1.], [0., 1., -1.]]) 61 | 62 | def diffeq(t, y): 63 | dy = torch.mv(A, y) 64 | return dy 65 | 66 | def init(): 67 | return torch.tensor(0.), torch.tensor([2., 0., 1.]) 68 | 69 | return diffeq, init, None 70 | 71 | 72 | def B3(): 73 | 74 | def diffeq(t, y): 75 | dy0 = -y[0] 76 | dy1 = y[0] - y[1] * y[1] 77 | dy2 = y[1] * y[1] 78 | return torch.stack([dy0, dy1, dy2]) 79 | 80 | def init(): 81 | return torch.tensor(0.), torch.tensor([1., 0., 0.]) 82 | 83 | return diffeq, init, None 84 | 85 | 86 | def B4(): 87 | 88 | def diffeq(t, y): 89 | a = torch.sqrt(y[0] * y[0] + y[1] * y[1]) 90 | dy0 = -y[1] - y[0] * y[2] / a 91 | dy1 = y[0] - y[1] * y[2] / a 92 | dy2 = y[0] / a 93 | return torch.stack([dy0, dy1, dy2]) 94 | 95 | def init(): 96 | return torch.tensor(0.), torch.tensor([3., 0., 0.]) 97 | 98 | return diffeq, init, None 99 | 100 | 101 | def B5(): 102 | 103 | def diffeq(t, y): 104 | dy0 = y[1] * y[2] 105 | dy1 = -y[0] * y[2] 106 | dy2 = -0.51 * y[0] * y[1] 107 | return torch.stack([dy0, dy1, dy2]) 108 | 109 | def init(): 110 | return torch.tensor(0.), torch.tensor([0., 1., 1.]) 111 | 112 | return diffeq, init, None 113 | 114 | 115 | #################################### 116 | # Problem Class C. Moderate systems. 
117 | #################################### 118 | def C1(): 119 | 120 | A = torch.zeros(10, 10) 121 | A.view(-1)[:-1:11] = -1 122 | A.view(-1)[10::11] = 1 123 | 124 | def diffeq(t, y): 125 | return torch.mv(A, y) 126 | 127 | def init(): 128 | y0 = torch.zeros(10) 129 | y0[0] = 1 130 | return torch.tensor(0.), y0 131 | 132 | return diffeq, init, None 133 | 134 | 135 | def C2(): 136 | 137 | A = torch.zeros(10, 10) 138 | A.view(-1)[:-1:11] = torch.linspace(-1, -9, 9) 139 | A.view(-1)[10::11] = torch.linspace(1, 9, 9) 140 | 141 | def diffeq(t, y): 142 | return torch.mv(A, y) 143 | 144 | def init(): 145 | y0 = torch.zeros(10) 146 | y0[0] = 1 147 | return torch.tensor(0.), y0 148 | 149 | return diffeq, init, None 150 | 151 | 152 | def C3(): 153 | n = 10 154 | A = torch.zeros(n, n) 155 | A.view(-1)[::n + 1] = -2 156 | A.view(-1)[n::n + 1] = 1 157 | A.view(-1)[1::n + 1] = 1 158 | 159 | def diffeq(t, y): 160 | return torch.mv(A, y) 161 | 162 | def init(): 163 | y0 = torch.zeros(n) 164 | y0[0] = 1 165 | return torch.tensor(0.), y0 166 | 167 | return diffeq, init, None 168 | 169 | 170 | def C4(): 171 | n = 51 172 | A = torch.zeros(n, n) 173 | A.view(-1)[::n + 1] = -2 174 | A.view(-1)[n::n + 1] = 1 175 | A.view(-1)[1::n + 1] = 1 176 | 177 | def diffeq(t, y): 178 | return torch.mv(A, y) 179 | 180 | def init(): 181 | y0 = torch.zeros(n) 182 | y0[0] = 1 183 | return torch.tensor(0.), y0 184 | 185 | return diffeq, init, None 186 | 187 | 188 | def C5(): 189 | 190 | k2 = torch.tensor(2.95912208286) 191 | m0 = torch.tensor(1.00000597682) 192 | m = torch.tensor([ 193 | 0.000954786104043, 194 | 0.000285583733151, 195 | 0.0000437273164546, 196 | 0.0000517759138449, 197 | 0.00000277777777778, 198 | ]).view(1, 5) 199 | 200 | def diffeq(t, y): 201 | # y is 2 x 3 x 5 202 | # y[0] contains y, y[1] contains y' 203 | # second axis indexes space (x,y,z). 204 | # third axis indexes 5 bodies.
205 | 206 | dy = y[1, :, :] 207 | y = y[0] 208 | r = torch.sqrt(torch.sum(y**2, 0)).view(1, 5) 209 | d = torch.sqrt(torch.sum((y[:, :, None] - y[:, None, :])**2, 0)) 210 | F = m.view(1, 1, 5) * ((y[:, None, :] - y[:, :, None]) / (d * d * d).view(1, 5, 5) - y.view(3, 1, 5) / 211 | (r * r * r).view(1, 1, 5)) 212 | F.view(3, 5 * 5)[:, ::6] = 0 213 | ddy = k2 * (-(m0 + m) * y / (r * r * r) + F.sum(2)) 214 | return torch.stack([dy, ddy], 0) 215 | 216 | def init(): 217 | y0 = torch.tensor([ 218 | 3.42947415189, 3.35386959711, 1.35494901715, 6.64145542550, 5.97156957878, 2.18231499728, 11.2630437207, 219 | 14.6952576794, 6.27960525067, -30.1552268759, 1.65699966404, 1.43785752721, -21.1238353380, 28.4465098142, 220 | 15.388265967 221 | ]).view(5, 3).transpose(0, 1) 222 | 223 | dy0 = torch.tensor([ 224 | -.557160570446, .505696783289, .230578543901, -.415570776342, .365682722812, .169143213293, -.325325669158, 225 | .189706021964, .0877265322780, -.0240476254170, -.287659532608, -.117219543175, -.176860753121, 226 | -.216393453025, -.0148647893090 227 | ]).view(5, 3).transpose(0, 1) 228 | 229 | return torch.tensor(0.), torch.stack([y0, dy0], 0) 230 | 231 | return diffeq, init, None 232 | 233 | 234 | ################################### 235 | # Problem Class D. Orbit equations. 236 | ################################### 237 | def _DTemplate(eps): 238 | 239 | def diffeq(t, y): 240 | r = (y[0]**2 + y[1]**2)**(3 / 2) 241 | dy0 = y[2] 242 | dy1 = y[3] 243 | dy2 = -y[0] / r 244 | dy3 = -y[1] / r 245 | return torch.stack([dy0, dy1, dy2, dy3]) 246 | 247 | def init(): 248 | return torch.tensor(0.), torch.tensor([1 - eps, 0, 0, math.sqrt((1 + eps) / (1 - eps))]) 249 | 250 | return diffeq, init, None 251 | 252 | 253 | D1 = lambda: _DTemplate(0.1) 254 | D2 = lambda: _DTemplate(0.3) 255 | D3 = lambda: _DTemplate(0.5) 256 | D4 = lambda: _DTemplate(0.7) 257 | D5 = lambda: _DTemplate(0.9) 258 | 259 | 260 | ########################################## 261 | # Problem Class E.
Higher order equations. 262 | ########################################## 263 | def E1(): 264 | 265 | def diffeq(t, y): 266 | dy0 = y[1] 267 | dy1 = -(y[1] / (t + 1) + (1 - 0.25 / (t + 1)**2) * y[0]) 268 | return torch.stack([dy0, dy1]) 269 | 270 | def init(): 271 | return torch.tensor(0.), torch.tensor([.671396707141803, .0954005144474744]) 272 | 273 | return diffeq, init, None 274 | 275 | 276 | def E2(): 277 | 278 | def diffeq(t, y): 279 | dy0 = y[1] 280 | dy1 = (1 - y[0]**2) * y[1] - y[0] 281 | return torch.stack([dy0, dy1]) 282 | 283 | def init(): 284 | return torch.tensor(0.), torch.tensor([2., 0.]) 285 | 286 | return diffeq, init, None 287 | 288 | 289 | def E3(): 290 | 291 | def diffeq(t, y): 292 | dy0 = y[1] 293 | dy1 = y[0]**3 / 6 - y[0] + 2 * torch.sin(2.78535 * t) 294 | return torch.stack([dy0, dy1]) 295 | 296 | def init(): 297 | return torch.tensor(0.), torch.tensor([0., 0.]) 298 | 299 | return diffeq, init, None 300 | 301 | 302 | def E4(): 303 | 304 | def diffeq(t, y): 305 | dy0 = y[1] 306 | dy1 = .32 - .4 * y[1]**2 307 | return torch.stack([dy0, dy1]) 308 | 309 | def init(): 310 | return torch.tensor(0.), torch.tensor([30., 0.]) 311 | 312 | return diffeq, init, None 313 | 314 | 315 | def E5(): 316 | 317 | def diffeq(t, y): 318 | dy0 = y[1] 319 | dy1 = torch.sqrt(1 + y[1]**2) / (25 - t) 320 | return torch.stack([dy0, dy1]) 321 | 322 | def init(): 323 | return torch.tensor(0.), torch.tensor([0., 0.]) 324 | 325 | return diffeq, init, None 326 | 327 | 328 | ################### 329 | # Helper functions. 
330 | ################### 331 | def _to_tensor(x): 332 | if not torch.is_tensor(x): 333 | x = torch.tensor(x) 334 | return x 335 | -------------------------------------------------------------------------------- /tests/DETEST/run.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from scipy.stats.mstats import gmean 4 | import torch 5 | from torchdiffeq import odeint 6 | import detest 7 | 8 | torch.set_default_dtype(torch.float64) 9 | 10 | 11 | class NFEDiffEq: 12 | 13 | def __init__(self, diffeq): 14 | self.diffeq = diffeq 15 | self.nfe = 0 16 | 17 | def __call__(self, t, y): 18 | self.nfe += 1 19 | return self.diffeq(t, y) 20 | 21 | 22 | def main(): 23 | 24 | sol = dict() 25 | for method in ['dopri5', 'adams']: 26 | for tol in [1e-3, 1e-6, 1e-9]: 27 | print('======= {} | tol={:e} ======='.format(method, tol)) 28 | nfes = [] 29 | times = [] 30 | errs = [] 31 | for c in ['A', 'B', 'C', 'D', 'E']: 32 | for i in ['1', '2', '3', '4', '5']: 33 | diffeq, init, _ = getattr(detest, c + i)() 34 | t0, y0 = init() 35 | diffeq = NFEDiffEq(diffeq) 36 | 37 | if c + i not in sol: 38 | sol[c + i] = odeint( 39 | diffeq, y0, torch.stack([t0, torch.tensor(20.)]), atol=1e-12, rtol=1e-12, method='dopri5' 40 | )[1] 41 | diffeq.nfe = 0 42 | 43 | start_time = time.time() 44 | est = odeint(diffeq, y0, torch.stack([t0, torch.tensor(20.)]), atol=tol, rtol=tol, method=method) 45 | time_spent = time.time() - start_time 46 | 47 | error = torch.sqrt(torch.mean((sol[c + i] - est[1])**2)) 48 | 49 | errs.append(error.item()) 50 | nfes.append(diffeq.nfe) 51 | times.append(time_spent) 52 | 53 | print('{}: NFE {} | Time {} | Err {:e}'.format(c + i, diffeq.nfe, time_spent, error.item())) 54 | 55 | print('Total NFE {} | Total Time {} | GeomAvg Error {:e}'.format(np.sum(nfes), np.sum(times), gmean(errs))) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | 
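The DETEST driver in `tests/DETEST/run.py` measures cost by wrapping each right-hand side in `NFEDiffEq` to count function evaluations (NFE), and summarizes accuracy as the geometric mean of per-problem RMS errors against a tight-tolerance reference solve. The following is a minimal, dependency-free sketch of that same bookkeeping under stated assumptions: it uses a hand-written fixed-step classic RK4 integrator (not one of torchdiffeq's solvers), covers only problem A1 (y' = -y, y(0) = 1, exact solution exp(-t)) integrated to t = 20 as in `run.py`, and uses the exact solution in place of a reference solve. The names `NFECounter`, `rk4`, `geometric_mean`, and `benchmark` are hypothetical, chosen for illustration.

```python
import math


class NFECounter:
    """Wrap a right-hand side f(t, y) and count its evaluations (same idea as NFEDiffEq)."""

    def __init__(self, f):
        self.f = f
        self.nfe = 0

    def __call__(self, t, y):
        self.nfe += 1
        return self.f(t, y)


def rk4(f, y0, t0, t1, n_steps):
    """Classic fixed-step 4th-order Runge-Kutta for a scalar ODE (4 evaluations per step)."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + h / 2, y + h / 2 * k2)
        k4 = f(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return y


def geometric_mean(values):
    """Geometric mean via the mean of logarithms; assumes every value is > 0."""
    return math.exp(sum(math.log(v) for v in values) / len(values))


def benchmark(step_counts, t1=20.0):
    """Solve A1 (y' = -y, y(0) = 1) for each step count; return (steps, nfe, abs_error)."""
    exact = math.exp(-t1)
    rows = []
    for n in step_counts:
        f = NFECounter(lambda t, y: -y)
        estimate = rk4(f, 1.0, 0.0, t1, n)
        # For a scalar state, the RMS error used in run.py reduces to the absolute error.
        rows.append((n, f.nfe, abs(estimate - exact)))
    return rows


if __name__ == "__main__":
    rows = benchmark([100, 200, 400])
    for n, nfe, err in rows:
        print("steps {} | NFE {} | Err {:e}".format(n, nfe, err))
    print("GeomAvg Error {:e}".format(geometric_mean([err for _, _, err in rows])))
```

Because RK4 spends exactly four evaluations per step, NFE here grows linearly with the step count, while the error should shrink by roughly a factor of 2^4 = 16 each time the step count doubles. Adaptive solvers such as `dopri5` instead pick their steps from `rtol`/`atol`, so their NFE varies from problem to problem, which is why the driver reports it alongside the error.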
-------------------------------------------------------------------------------- /tests/api_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchdiffeq 4 | 5 | from problems import construct_problem, DTYPES, DEVICES, ADAPTIVE_METHODS 6 | 7 | 8 | EPS = {torch.float32: 1e-4, torch.float64: 1e-12, torch.complex64: 1e-4} 9 | 10 | 11 | class TestCollectionState(unittest.TestCase): 12 | def test_forward(self): 13 | for dtype in DTYPES: 14 | eps = EPS[dtype] 15 | for device in DEVICES: 16 | f, y0, t_points, sol = construct_problem(dtype=dtype, device=device) 17 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1])) 18 | tuple_y0 = (y0, y0) 19 | for method in ADAPTIVE_METHODS: 20 | 21 | with self.subTest(dtype=dtype, device=device, method=method): 22 | tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method=method) 23 | max_error0 = (sol - tuple_y[0]).abs().max() 24 | max_error1 = (sol - tuple_y[1]).abs().max() 25 | self.assertLess(max_error0, eps) 26 | self.assertLess(max_error1, eps) 27 | 28 | def test_gradient(self): 29 | for device in DEVICES: 30 | f, y0, t_points, sol = construct_problem(device=device) 31 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1])) 32 | for method in ADAPTIVE_METHODS: 33 | if method == "scipy_solver": 34 | continue 35 | 36 | with self.subTest(device=device, method=method): 37 | for i in range(2): 38 | func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method=method)[i] 39 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points))) 40 | 41 | 42 | if __name__ == '__main__': 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /tests/event_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchdiffeq 4 | 5 | from problems import construct_problem, DTYPES, DEVICES, METHODS, 
FIXED_METHODS 6 | 7 | 8 | def rel_error(true, estimate): 9 | return ((true - estimate) / true).abs().max() 10 | 11 | 12 | class TestEventHandling(unittest.TestCase): 13 | 14 | def test_odeint(self): 15 | for reverse in (False, True): 16 | for dtype in DTYPES: 17 | for device in DEVICES: 18 | for method in METHODS: 19 | 20 | # TODO: remove after event handling gets enabled. 21 | if method == 'scipy_solver': 22 | continue 23 | 24 | for ode in ("constant", "sine"): 25 | with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, method=method): 26 | if method == "explicit_adams": 27 | tol = 7e-2 28 | elif method == "euler" or method == "implicit_euler": 29 | tol = 5e-3 30 | elif method == "gl6": 31 | tol = 2e-3 32 | else: 33 | tol = 1e-4 34 | 35 | f, y0, t_points, sol = construct_problem(dtype=dtype, device=device, ode=ode, 36 | reverse=reverse) 37 | 38 | def event_fn(t, y): 39 | return torch.sum(y - sol[2]).real 40 | 41 | if method in FIXED_METHODS: 42 | options = {"step_size": 0.01, "interp": "cubic"} 43 | else: 44 | options = {} 45 | 46 | t, y = torchdiffeq.odeint(f, y0, t_points[0:2], event_fn=event_fn, method=method, options=options) 47 | y = y[-1] 48 | self.assertLess(rel_error(sol[2], y), tol) 49 | self.assertLess(rel_error(t_points[2], t), tol) 50 | 51 | def test_adjoint(self): 52 | f, y0, t_points, sol = construct_problem(device="cpu", ode="constant") 53 | 54 | def event_fn(t, y): 55 | return torch.sum(y - sol[-1]) 56 | 57 | t, y = torchdiffeq.odeint_adjoint(f, y0, t_points[0:2], event_fn=event_fn, method="dopri5") 58 | y = y[-1] 59 | self.assertLess(rel_error(sol[-1], y), 1e-4) 60 | self.assertLess(rel_error(t_points[-1], t), 1e-4) 61 | 62 | # Make sure adjoint mode backward code can still be run. 
63 | t.backward(retain_graph=True) 64 | y.sum().backward() 65 | 66 | 67 | if __name__ == '__main__': 68 | unittest.main() 69 | -------------------------------------------------------------------------------- /tests/gradient_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchdiffeq 4 | 5 | from problems import construct_problem, PROBLEMS, DEVICES, METHODS 6 | 7 | 8 | def max_abs(tensor): 9 | return torch.max(torch.abs(tensor)) 10 | 11 | 12 | class TestGradient(unittest.TestCase): 13 | def test_odeint(self): 14 | for device in DEVICES: 15 | for method in METHODS: 16 | 17 | if method == 'scipy_solver': 18 | continue 19 | 20 | with self.subTest(device=device, method=method): 21 | f, y0, t_points, _ = construct_problem(device=device) 22 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method=method) 23 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points))) 24 | 25 | def test_adjoint(self): 26 | for device in DEVICES: 27 | for method in METHODS: 28 | 29 | with self.subTest(device=device, method=method): 30 | f, y0, t_points, _ = construct_problem(device=device) 31 | func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method=method) 32 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points))) 33 | 34 | def test_adjoint_against_odeint(self): 35 | """ 36 | Test against dopri5 37 | """ 38 | for device in DEVICES: 39 | for ode in PROBLEMS: 40 | for t_grad in (True, False): 41 | if ode == 'constant': 42 | eps = 1e-12 43 | elif ode == 'linear': 44 | eps = 1e-5 45 | elif ode == 'sine': 46 | eps = 5e-3 47 | elif ode == 'exp': 48 | eps = 1e-2 49 | else: 50 | raise RuntimeError 51 | 52 | with self.subTest(device=device, ode=ode, t_grad=t_grad): 53 | f, y0, t_points, _ = construct_problem(device=device, ode=ode) 54 | t_points.requires_grad_(t_grad) 55 | 56 | ys = torchdiffeq.odeint(f, y0, t_points, rtol=1e-9, atol=1e-12) 57 | 
torch.manual_seed(0) 58 | gradys = torch.rand_like(ys) 59 | ys.backward(gradys) 60 | 61 | reg_y0_grad = y0.grad.clone() 62 | reg_t_grad = t_points.grad.clone() if t_grad else None 63 | reg_params_grads = [] 64 | for param in f.parameters(): 65 | reg_params_grads.append(param.grad.clone()) 66 | 67 | y0.grad.zero_() 68 | if t_grad: 69 | t_points.grad.zero_() 70 | for param in f.parameters(): 71 | param.grad.zero_() 72 | 73 | ys = torchdiffeq.odeint_adjoint(f, y0, t_points, rtol=1e-9, atol=1e-12) 74 | ys.backward(gradys) 75 | 76 | adj_y0_grad = y0.grad 77 | adj_t_grad = t_points.grad if t_grad else None 78 | adj_params_grads = [] 79 | for param in f.parameters(): 80 | adj_params_grads.append(param.grad) 81 | 82 | self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps) 83 | if t_grad: 84 | self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps) 85 | for reg_grad, adj_grad in zip(reg_params_grads, adj_params_grads): 86 | self.assertLess(max_abs(reg_grad - adj_grad), eps) 87 | 88 | 89 | class TestCompareAdjointGradient(unittest.TestCase): 90 | 91 | def problem(self, device): 92 | class Odefunc(torch.nn.Module): 93 | def __init__(self): 94 | super(Odefunc, self).__init__() 95 | self.A = torch.nn.Parameter(torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])) 96 | self.unused_module = torch.nn.Linear(2, 5) 97 | 98 | def forward(self, t, y): 99 | return torch.mm(y**3, self.A) 100 | 101 | y0 = torch.tensor([[2., 0.]], device=device, requires_grad=True) 102 | t_points = torch.linspace(0., 25., 10, device=device, requires_grad=True) 103 | func = Odefunc().to(device) 104 | return func, y0, t_points 105 | 106 | def test_against_dopri5(self): 107 | method_eps = { 108 | 'dopri5': (3e-4, 1e-4, 2e-3), 109 | 'scipy_solver': (3e-4, 1e-4, 2e-3), 110 | } 111 | for device in DEVICES: 112 | for method, eps in method_eps.items(): 113 | for t_grad in (True, False): 114 | with self.subTest(device=device, method=method): 115 | func, y0, t_points = self.problem(device=device) 116 | 
t_points.requires_grad_(t_grad) 117 | 118 | ys = torchdiffeq.odeint_adjoint(func, y0, t_points, method=method) 119 | gradys = torch.rand_like(ys) * 0.1 120 | ys.backward(gradys) 121 | 122 | adj_y0_grad = y0.grad 123 | adj_t_grad = t_points.grad if t_grad else None 124 | adj_A_grad = func.A.grad 125 | self.assertEqual(max_abs(func.unused_module.weight.grad), 0) 126 | self.assertEqual(max_abs(func.unused_module.bias.grad), 0) 127 | 128 | func, y0, t_points = self.problem(device=device) 129 | ys = torchdiffeq.odeint(func, y0, t_points, method='dopri5') 130 | ys.backward(gradys) 131 | 132 | self.assertLess(max_abs(y0.grad - adj_y0_grad), eps[0]) 133 | if t_grad: 134 | self.assertLess(max_abs(t_points.grad - adj_t_grad), eps[1]) 135 | self.assertLess(max_abs(func.A.grad - adj_A_grad), eps[2]) 136 | 137 | 138 | if __name__ == '__main__': 139 | unittest.main() 140 | -------------------------------------------------------------------------------- /tests/norm_tests.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import unittest 3 | 4 | import torch 5 | import torchdiffeq 6 | 7 | from problems import (DTYPES, DEVICES, ADAPTIVE_METHODS) 8 | 9 | 10 | @contextlib.contextmanager 11 | def random_seed_torch(seed): 12 | cpu_rng_state = torch.get_rng_state() 13 | torch.manual_seed(seed) 14 | 15 | try: 16 | yield 17 | finally: 18 | torch.set_rng_state(cpu_rng_state) 19 | 20 | 21 | class _NeuralF(torch.nn.Module): 22 | def __init__(self, width, oscillate): 23 | super(_NeuralF, self).__init__() 24 | 25 | # Use the same set of random weights for every instance. 
26 | with random_seed_torch(0): 27 | self.linears = torch.nn.Sequential(torch.nn.Linear(2, width), 28 | torch.nn.Tanh(), 29 | torch.nn.Linear(width, 2), 30 | torch.nn.Tanh()) 31 | self.nfe = 0 32 | self.oscillate = oscillate 33 | 34 | def forward(self, t, x): 35 | self.nfe += 1 36 | out = self.linears(x) 37 | if self.oscillate: 38 | out = out * t.mul(2).sin() 39 | return out 40 | 41 | 42 | class TestNorms(unittest.TestCase): 43 | def test_norm(self): 44 | def f(t, x): 45 | return x 46 | t = torch.tensor([0., 1.]) 47 | 48 | # First test that tensor input appears in the norm. 49 | is_called = False 50 | 51 | def norm(state): 52 | nonlocal is_called 53 | is_called = True 54 | self.assertIsInstance(state, torch.Tensor) 55 | self.assertEqual(state.shape, ()) 56 | return state.pow(2).mean().sqrt() 57 | x0 = torch.tensor(1.) 58 | torchdiffeq.odeint(f, x0, t, options=dict(norm=norm)) 59 | self.assertTrue(is_called) 60 | 61 | # Now test that tupled input appears in the norm 62 | is_called = False 63 | 64 | def norm(state): 65 | nonlocal is_called 66 | is_called = True 67 | self.assertIsInstance(state, tuple) 68 | self.assertEqual(len(state), 1) 69 | state, = state 70 | self.assertEqual(state.shape, ()) 71 | return state.pow(2).mean().sqrt() 72 | x0 = (torch.tensor(1.),) 73 | torchdiffeq.odeint(f, x0, t, options=dict(norm=norm)) 74 | self.assertTrue(is_called) 75 | 76 | is_called = False 77 | 78 | def norm(state): 79 | nonlocal is_called 80 | is_called = True 81 | self.assertIsInstance(state, tuple) 82 | self.assertEqual(len(state), 2) 83 | state1, state2 = state 84 | self.assertEqual(state1.shape, ()) 85 | self.assertEqual(state2.shape, (2, 2)) 86 | return state1.pow(2).mean().sqrt() 87 | x0 = (torch.tensor(1.), torch.tensor([[0.5, 0.5], [0.1, 0.1]])) 88 | torchdiffeq.odeint(f, x0, t, options=dict(norm=norm)) 89 | self.assertTrue(is_called) 90 | 91 | def test_adjoint_norm(self): 92 | def f(t, x): 93 | return x 94 | t = torch.tensor([0., 1.]) 95 | adjoint_params = 
(torch.rand(7, requires_grad=True), torch.rand((), requires_grad=True)) 96 | 97 | def make_spy_on_adjoint_norm(adjoint_norm, actual_norm): 98 | is_spy_called = [False] 99 | 100 | def spy_on_adjoint_norm(tensor): 101 | nonlocal is_spy_called 102 | is_spy_called[0] = True 103 | norm_result = adjoint_norm(tensor) 104 | true_norm_result = actual_norm(tensor) 105 | self.assertIsInstance(norm_result, torch.Tensor) 106 | self.assertEqual(norm_result.shape, true_norm_result.shape) 107 | self.assertLess((norm_result - true_norm_result).abs().max(), 1e-6) 108 | return norm_result 109 | 110 | return spy_on_adjoint_norm, is_spy_called 111 | 112 | # Test the various auto-constructed adjoint norms with tensor (not tuple) state 113 | for shape in ((), (1,), (2, 2)): 114 | for use_adjoint_options, seminorm in ((False, False), (True, False), (True, True)): 115 | with self.subTest(shape=shape, use_adjoint_options=use_adjoint_options, seminorm=seminorm): 116 | x0 = torch.full(shape, 1.) 117 | if use_adjoint_options: 118 | if seminorm: 119 | # Test passing adjoint_options and wanting the seminorm 120 | kwargs = dict(adjoint_options=dict(norm='seminorm')) 121 | else: 122 | # Test passing adjoint_options but not specifying the adjoint norm 123 | kwargs = dict(adjoint_options={}) 124 | else: 125 | # Test not passing adjoint_options at all.
126 | kwargs = {} 127 | xs = torchdiffeq.odeint_adjoint(f, x0, t, adjoint_params=adjoint_params, **kwargs) 128 | _adjoint_norm = xs.grad_fn.adjoint_options['norm'] 129 | 130 | is_called = False 131 | 132 | def actual_norm(tensor_tuple): 133 | nonlocal is_called 134 | is_called = True 135 | self.assertIsInstance(tensor_tuple, tuple) 136 | t, y, adj_y, adj_param1, adj_param2 = tensor_tuple 137 | self.assertEqual(t.shape, ()) 138 | self.assertEqual(y.shape, shape) 139 | self.assertEqual(adj_y.shape, shape) 140 | self.assertEqual(adj_param1.shape, (7,)) 141 | self.assertEqual(adj_param2.shape, ()) 142 | out = max(t.abs(), y.pow(2).mean().sqrt(), adj_y.pow(2).mean().sqrt()) 143 | if not seminorm: 144 | out = max(out, adj_param1.pow(2).mean().sqrt(), adj_param2.abs()) 145 | return out 146 | 147 | xs.grad_fn.adjoint_options['norm'], is_spy_called = make_spy_on_adjoint_norm(_adjoint_norm, 148 | actual_norm) 149 | xs.sum().backward() 150 | self.assertTrue(is_called) 151 | self.assertTrue(is_spy_called[0]) 152 | 153 | # Test the various auto-constructed adjoint norms with tuple (not tensor) state 154 | for use_adjoint_options, seminorm in ((False, False), (True, False), (True, True)): 155 | with self.subTest(use_adjoint_options=use_adjoint_options, seminorm=seminorm): 156 | x0 = torch.tensor(1.), torch.tensor([[0.5, 0.5], [0.1, 0.1]]) 157 | if use_adjoint_options: 158 | if seminorm: 159 | # Test passing adjoint_options and wanting the seminorm 160 | kwargs = dict(adjoint_options=dict(norm='seminorm')) 161 | else: 162 | # Test passing adjoint_options but not specifying the adjoint norm 163 | kwargs = dict(adjoint_options={}) 164 | else: 165 | # Test not passing adjoint_options at all.
166 | kwargs = {} 167 | xs = torchdiffeq.odeint_adjoint(f, x0, t, adjoint_params=adjoint_params, **kwargs) 168 | adjoint_options_dict = xs[0].grad_fn.next_functions[0][0].next_functions[0][0].adjoint_options 169 | _adjoint_norm = adjoint_options_dict['norm'] 170 | 171 | is_called = False 172 | 173 | def actual_norm(tensor_tuple): 174 | nonlocal is_called 175 | is_called = True 176 | self.assertIsInstance(tensor_tuple, tuple) 177 | t, y, adj_y, adj_param1, adj_param2 = tensor_tuple 178 | self.assertEqual(t.shape, ()) 179 | self.assertEqual(y.shape, (5,)) 180 | self.assertEqual(adj_y.shape, (5,)) 181 | self.assertEqual(adj_param1.shape, (7,)) 182 | self.assertEqual(adj_param2.shape, ()) 183 | ya = y[0] 184 | yb = y[1:] 185 | adj_ya = adj_y[0] 186 | adj_yb = adj_y[1:] 187 | out = max(t.abs(), ya.abs(), yb.pow(2).mean().sqrt(), adj_ya.abs(), adj_yb.pow(2).mean().sqrt()) 188 | if not seminorm: 189 | out = max(out, adj_param1.pow(2).mean().sqrt(), adj_param2.abs()) 190 | return out 191 | 192 | spy_on_adjoint_norm, is_spy_called = make_spy_on_adjoint_norm(_adjoint_norm, actual_norm) 193 | adjoint_options_dict['norm'] = spy_on_adjoint_norm 194 | xs[0].sum().backward() 195 | self.assertTrue(is_called) 196 | self.assertTrue(is_spy_called[0]) 197 | 198 | # Test user-passed adjoint norms with tensor (not tuple) state 199 | is_called = False 200 | 201 | def adjoint_norm(tensor_tuple): 202 | nonlocal is_called 203 | is_called = True 204 | self.assertIsInstance(tensor_tuple, tuple) 205 | t, y, adj_y, adj_param1, adj_param2 = tensor_tuple 206 | self.assertEqual(t.shape, ()) 207 | self.assertEqual(y.shape, ()) 208 | self.assertEqual(adj_y.shape, ()) 209 | self.assertEqual(adj_param1.shape, (7,)) 210 | self.assertEqual(adj_param2.shape, ()) 211 | return max(t.abs(), y.pow(2).mean().sqrt(), adj_y.pow(2).mean().sqrt(), adj_param1.pow(2).mean().sqrt(), 212 | adj_param2.abs()) 213 | 214 | x0 = torch.tensor(1.)
215 | xs = torchdiffeq.odeint_adjoint(f, x0, t, adjoint_params=adjoint_params, 216 | adjoint_options=dict(norm=adjoint_norm)) 217 | xs.sum().backward() 218 | self.assertTrue(is_called) 219 | 220 | # Test user-passed adjoint norms with tuple (not tensor) state 221 | is_called = False 222 | 223 | def adjoint_norm(tensor_tuple): 224 | nonlocal is_called 225 | is_called = True 226 | self.assertIsInstance(tensor_tuple, tuple) 227 | t, ya, yb, adj_ya, adj_yb, adj_param1, adj_param2 = tensor_tuple 228 | self.assertEqual(t.shape, ()) 229 | self.assertEqual(ya.shape, ()) 230 | self.assertEqual(yb.shape, (2, 2)) 231 | self.assertEqual(adj_ya.shape, ()) 232 | self.assertEqual(adj_yb.shape, (2, 2)) 233 | self.assertEqual(adj_param1.shape, (7,)) 234 | self.assertEqual(adj_param2.shape, ()) 235 | return max(t.abs(), ya.abs(), yb.pow(2).mean().sqrt(), adj_ya.abs(), adj_yb.pow(2).mean().sqrt(), 236 | adj_param1.pow(2).mean().sqrt(), adj_param2.abs()) 237 | 238 | x0 = torch.tensor(1.), torch.tensor([[0.5, 0.5], [0.1, 0.1]]) 239 | xs = torchdiffeq.odeint_adjoint(f, x0, t, adjoint_params=adjoint_params, 240 | adjoint_options=dict(norm=adjoint_norm)) 241 | xs[0].sum().backward() 242 | self.assertTrue(is_called) 243 | 244 | def test_large_norm(self): 245 | 246 | def norm(tensor): 247 | return tensor.abs().max() 248 | 249 | def large_norm(tensor): 250 | return 10 * tensor.abs().max() 251 | 252 | for dtype in DTYPES: 253 | for device in DEVICES: 254 | for method in ADAPTIVE_METHODS: 255 | if dtype == torch.float32 and method == 'dopri8': 256 | continue 257 | 258 | with self.subTest(dtype=dtype, device=device, method=method): 259 | x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype) 260 | t = torch.tensor([0., 1.0], device=device, dtype=torch.float64) 261 | 262 | norm_f = _NeuralF(width=10, oscillate=True).to(device, dtype) 263 | torchdiffeq.odeint(norm_f, x0, t, method=method, options=dict(norm=norm)) 264 | large_norm_f = _NeuralF(width=10, oscillate=True).to(device, dtype) 265 | 
with torch.no_grad(): 266 | for norm_param, large_norm_param in zip(norm_f.parameters(), large_norm_f.parameters()): 267 | large_norm_param.copy_(norm_param) 268 | torchdiffeq.odeint(large_norm_f, x0, t, method=method, options=dict(norm=large_norm)) 269 | 270 | self.assertLessEqual(norm_f.nfe, large_norm_f.nfe) 271 | 272 | def test_seminorm(self): 273 | for dtype in DTYPES: 274 | for device in DEVICES: 275 | for method in ADAPTIVE_METHODS: 276 | # Tests with known failures 277 | if ( 278 | dtype in [torch.float32] and 279 | method in ['tsit5'] 280 | ): 281 | continue 282 | 283 | with self.subTest(dtype=dtype, device=device, method=method): 284 | 285 | if dtype == torch.float64: 286 | tol = 1e-8 287 | else: 288 | tol = 1e-6 289 | 290 | x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype) 291 | t = torch.tensor([0., 1.0], device=device, dtype=torch.float64) 292 | 293 | ode_f = _NeuralF(width=1024, oscillate=True).to(device, dtype) 294 | 295 | out = torchdiffeq.odeint_adjoint(ode_f, x0, t, atol=tol, rtol=tol, method=method) 296 | ode_f.nfe = 0 297 | out.sum().backward() 298 | default_nfe = ode_f.nfe 299 | 300 | out = torchdiffeq.odeint_adjoint(ode_f, x0, t, atol=tol, rtol=tol, method=method, 301 | adjoint_options=dict(norm='seminorm')) 302 | ode_f.nfe = 0 303 | out.sum().backward() 304 | seminorm_nfe = ode_f.nfe 305 | 306 | self.assertLessEqual(seminorm_nfe, default_nfe) 307 | 308 | 309 | if __name__ == '__main__': 310 | unittest.main() 311 | -------------------------------------------------------------------------------- /tests/problems.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import scipy.linalg 4 | import torch 5 | 6 | 7 | class ConstantODE(torch.nn.Module): 8 | 9 | def __init__(self): 10 | super(ConstantODE, self).__init__() 11 | self.a = torch.nn.Parameter(torch.tensor(0.2)) 12 | self.b = torch.nn.Parameter(torch.tensor(3.0)) 13 | 14 | def forward(self, t, y): 15 | return 
self.a + (y - (self.a * t + self.b))**5 16 | 17 | def y_exact(self, t): 18 | return self.a * t + self.b 19 | 20 | 21 | class SineODE(torch.nn.Module): 22 | def forward(self, t, y): 23 | return 2 * y / t + t**4 * torch.sin(2 * t) - t**2 + 4 * t**3 24 | 25 | def y_exact(self, t): 26 | return -0.5 * t**4 * torch.cos(2 * t) + 0.5 * t**3 * torch.sin(2 * t) + 0.25 * t**2 * torch.cos( 27 | 2 * t 28 | ) - t**3 + 2 * t**4 + (math.pi - 0.25) * t**2 29 | 30 | 31 | class LinearODE(torch.nn.Module): 32 | 33 | def __init__(self, dim=10): 34 | super(LinearODE, self).__init__() 35 | torch.manual_seed(0) 36 | self.dim = dim 37 | U = torch.randn(dim, dim) * 0.1 38 | A = 2 * U - (U + U.transpose(0, 1)) 39 | self.A = torch.nn.Parameter(A) 40 | self.initial_val = np.ones((dim, 1)) 41 | self.nfe = 0 42 | 43 | def forward(self, t, y): 44 | self.nfe += 1 45 | return torch.mm(self.A, y.reshape(self.dim, 1)).reshape(-1) 46 | 47 | def y_exact(self, t): 48 | t_numpy = t.detach().cpu().numpy() 49 | A_np = self.A.detach().cpu().numpy() 50 | ans = [] 51 | for t_i in t_numpy: 52 | ans.append(np.matmul(scipy.linalg.expm(A_np * t_i), self.initial_val)) 53 | return torch.stack([torch.tensor(ans_) for ans_ in ans]).reshape(len(t_numpy), self.dim).to(t) 54 | 55 | 56 | class ExpODE(torch.nn.Module): 57 | def forward(self, t, y): 58 | return -0.1 * self.y_exact(t) 59 | 60 | def y_exact(self, t): 61 | return torch.exp(-0.1 * t) 62 | 63 | 64 | PROBLEMS = {'constant': ConstantODE, 'linear': LinearODE, 'sine': SineODE, 'exp': ExpODE} 65 | DTYPES = (torch.float32, torch.float64) 66 | DEVICES = ['cpu'] 67 | if torch.cuda.is_available(): 68 | DEVICES.append('cuda') 69 | FIXED_EXPLICIT_METHODS = ('euler', 'midpoint', 'heun2', 'heun3', 'rk4', 'explicit_adams', 'implicit_adams') 70 | FIXED_IMPLICIT_METHODS = ('implicit_euler', 'implicit_midpoint', 'trapezoid', 'radauIIA3', 'gl4', 'radauIIA5', 'gl6', 'sdirk2', 'trbdf2') 71 | FIXED_METHODS = FIXED_EXPLICIT_METHODS + FIXED_IMPLICIT_METHODS 72 | ADAMS_METHODS = 
('explicit_adams', 'implicit_adams') 73 | ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'tsit5', 'dopri5', 'dopri8') 74 | SCIPY_METHODS = ('scipy_solver',) 75 | IMPLICIT_METHODS = FIXED_IMPLICIT_METHODS 76 | METHODS = FIXED_METHODS + ADAPTIVE_METHODS + SCIPY_METHODS 77 | 78 | 79 | def construct_problem(device, npts=10, ode='constant', reverse=False, dtype=torch.float64): 80 | 81 | f = PROBLEMS[ode]().to(dtype=dtype, device=device) 82 | 83 | t_points = torch.linspace(1, 8, npts, dtype=torch.float64, device=device, requires_grad=True) 84 | sol = f.y_exact(t_points).to(dtype) 85 | 86 | def _flip(x, dim): 87 | indices = [slice(None)] * x.dim() 88 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=device) 89 | return x[tuple(indices)] 90 | 91 | if reverse: 92 | t_points = _flip(t_points, 0).clone().detach() 93 | sol = _flip(sol, 0).clone().detach() 94 | 95 | return f, sol[0].detach().requires_grad_(True), t_points, sol 96 | 97 | 98 | if __name__ == '__main__': 99 | f = SineODE().cpu() 100 | t_points = torch.linspace(1, 8, 100, device='cpu') 101 | sol = f.y_exact(t_points) 102 | 103 | import matplotlib.pyplot as plt 104 | plt.plot(t_points, sol) 105 | plt.show() 106 | -------------------------------------------------------------------------------- /tests/run_all.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from api_tests import * 3 | from event_tests import * 4 | from gradient_tests import * 5 | from norm_tests import * 6 | from odeint_tests import * 7 | 8 | if __name__ == '__main__': 9 | unittest.main() 10 | -------------------------------------------------------------------------------- /torchdiffeq/__init__.py: -------------------------------------------------------------------------------- 1 | from ._impl import odeint 2 | from ._impl import odeint_adjoint 3 | from ._impl import odeint_event 4 | from ._impl import odeint_dense 5 | __version__ = "0.2.5" 6 | 
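The test problems in `tests/problems.py` are only meaningful if each `y_exact` really solves the ODE defined by `forward`. For `SineODE` this can be checked without torch: compare a central finite difference of the closed form against the right-hand side on the interval `construct_problem` uses (t in [1, 8]). The sketch below is a hand-ported, stdlib-only illustration with scalar floats in place of tensors; `sine_rhs`, `sine_exact`, and `max_relative_residual` are hypothetical helper names, not part of torchdiffeq.

```python
import math


def sine_rhs(t, y):
    """Right-hand side ported from SineODE.forward: 2y/t + t^4 sin(2t) - t^2 + 4t^3."""
    return 2 * y / t + t**4 * math.sin(2 * t) - t**2 + 4 * t**3


def sine_exact(t):
    """Closed form ported from SineODE.y_exact."""
    return (-0.5 * t**4 * math.cos(2 * t) + 0.5 * t**3 * math.sin(2 * t)
            + 0.25 * t**2 * math.cos(2 * t) - t**3 + 2 * t**4 + (math.pi - 0.25) * t**2)


def max_relative_residual(ts, h=1e-5):
    """Largest |y'(t) - f(t, y(t))| / max(1, |f|) over ts, with y' from a central difference."""
    worst = 0.0
    for t in ts:
        dydt = (sine_exact(t + h) - sine_exact(t - h)) / (2 * h)
        rhs = sine_rhs(t, sine_exact(t))
        worst = max(worst, abs(dydt - rhs) / max(1.0, abs(rhs)))
    return worst


# Same interval as construct_problem (torch.linspace(1, 8, npts)), here with 100 points.
grid = [1 + 7 * i / 99 for i in range(100)]
print(f"max relative residual on [1, 8]: {max_relative_residual(grid):.2e}")
```

A residual near finite-difference noise indicates the closed form and the right-hand side agree; a typo in either function (for example a wrong sign on the `(math.pi - 0.25) * t**2` term) would show up as a residual many orders of magnitude larger.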
-------------------------------------------------------------------------------- /torchdiffeq/_impl/__init__.py: -------------------------------------------------------------------------------- 1 | from .odeint import odeint, odeint_dense, odeint_event 2 | from .adjoint import odeint_adjoint 3 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/adaptive_heun.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver 3 | 4 | 5 | _ADAPTIVE_HEUN_TABLEAU = _ButcherTableau( 6 | alpha=torch.tensor([1.], dtype=torch.float64), 7 | beta=[ 8 | torch.tensor([1.], dtype=torch.float64), 9 | ], 10 | c_sol=torch.tensor([0.5, 0.5], dtype=torch.float64), 11 | c_error=torch.tensor([ 12 | 0.5, 13 | -0.5, 14 | ], dtype=torch.float64), 15 | ) 16 | 17 | _AH_C_MID = torch.tensor([ 18 | 0.5, 0. 19 | ], dtype=torch.float64) 20 | 21 | 22 | class AdaptiveHeunSolver(RKAdaptiveStepsizeODESolver): 23 | order = 2 24 | tableau = _ADAPTIVE_HEUN_TABLEAU 25 | mid = _AH_C_MID 26 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/adjoint.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn as nn 4 | from .odeint import SOLVERS, odeint 5 | from .misc import _check_inputs, _flat_to_shape, _mixed_norm, _all_callback_names, _all_adjoint_callback_names 6 | 7 | 8 | class OdeintAdjointMethod(torch.autograd.Function): 9 | 10 | @staticmethod 11 | def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, 12 | adjoint_options, t_requires_grad, *adjoint_params): 13 | 14 | ctx.shapes = shapes 15 | ctx.func = func 16 | ctx.adjoint_rtol = adjoint_rtol 17 | ctx.adjoint_atol = adjoint_atol 18 | ctx.adjoint_method = adjoint_method 19 | ctx.adjoint_options = 
adjoint_options 20 | ctx.t_requires_grad = t_requires_grad 21 | ctx.event_mode = event_fn is not None 22 | 23 | with torch.no_grad(): 24 | ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn) 25 | 26 | if event_fn is None: 27 | y = ans 28 | ctx.save_for_backward(t, y, *adjoint_params) 29 | else: 30 | event_t, y = ans 31 | ctx.save_for_backward(t, y, event_t, *adjoint_params) 32 | 33 | return ans 34 | 35 | @staticmethod 36 | def backward(ctx, *grad_y): 37 | with torch.no_grad(): 38 | func = ctx.func 39 | adjoint_rtol = ctx.adjoint_rtol 40 | adjoint_atol = ctx.adjoint_atol 41 | adjoint_method = ctx.adjoint_method 42 | adjoint_options = ctx.adjoint_options 43 | t_requires_grad = ctx.t_requires_grad 44 | 45 | # Backprop as if integrating up to event time. 46 | # Does NOT backpropagate through the event time. 47 | event_mode = ctx.event_mode 48 | if event_mode: 49 | t, y, event_t, *adjoint_params = ctx.saved_tensors 50 | _t = t 51 | t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)]) 52 | grad_y = grad_y[1] 53 | else: 54 | t, y, *adjoint_params = ctx.saved_tensors 55 | grad_y = grad_y[0] 56 | 57 | adjoint_params = tuple(adjoint_params) 58 | 59 | ################################## 60 | # Set up initial state # 61 | ################################## 62 | 63 | # [-1] because y and grad_y are both of shape (len(t), *y0.shape) 64 | aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]] # vjp_t, y, vjp_y 65 | aug_state.extend([torch.zeros_like(param) for param in adjoint_params]) # vjp_params 66 | 67 | ################################## 68 | # Set up backward ODE func # 69 | ################################## 70 | 71 | # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives. 72 | def augmented_dynamics(t, y_aug): 73 | # Dynamics of the original system augmented with 74 | # the adjoint wrt y, and an integrator wrt t and args. 
75 | y = y_aug[1] 76 | adj_y = y_aug[2] 77 | # ignore gradients wrt time and parameters 78 | 79 | with torch.enable_grad(): 80 | t_ = t.detach() 81 | t = t_.requires_grad_(True) 82 | y = y.detach().requires_grad_(True) 83 | 84 | # If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which 85 | # doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients 86 | # wrt t here means we won't compute that if we don't need it. 87 | func_eval = func(t if t_requires_grad else t_, y) 88 | 89 | # Workaround for PyTorch bug #39784 90 | _t = torch.as_strided(t, (), ()) # noqa 91 | _y = torch.as_strided(y, (), ()) # noqa 92 | _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params) # noqa 93 | 94 | vjp_t, vjp_y, *vjp_params = torch.autograd.grad( 95 | func_eval, (t, y) + adjoint_params, -adj_y, 96 | allow_unused=True, retain_graph=True 97 | ) 98 | 99 | # autograd.grad returns None if no gradient, set to zero. 
100 | vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t 101 | vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y 102 | vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param 103 | for param, vjp_param in zip(adjoint_params, vjp_params)] 104 | 105 | return (vjp_t, func_eval, vjp_y, *vjp_params) 106 | 107 | # Add adjoint callbacks 108 | for callback_name, adjoint_callback_name in zip(_all_callback_names, _all_adjoint_callback_names): 109 | try: 110 | callback = getattr(func, adjoint_callback_name) 111 | except AttributeError: 112 | pass 113 | else: 114 | setattr(augmented_dynamics, callback_name, callback) 115 | 116 | ################################## 117 | # Solve adjoint ODE # 118 | ################################## 119 | 120 | if t_requires_grad: 121 | time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device) 122 | else: 123 | time_vjps = None 124 | for i in range(len(t) - 1, 0, -1): 125 | if t_requires_grad: 126 | # Compute the effect of moving the current time measurement point. 127 | # We don't compute this unless we need to, to save some computation. 128 | func_eval = func(t[i], y[i]) 129 | dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1)) 130 | aug_state[0] -= dLd_cur_t 131 | time_vjps[i] = dLd_cur_t 132 | 133 | # Run the augmented system backwards in time. 134 | aug_state = odeint( 135 | augmented_dynamics, tuple(aug_state), 136 | t[i - 1:i + 1].flip(0), 137 | rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options 138 | ) 139 | aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value 140 | aug_state[1] = y[i - 1] # update to use our forward-pass estimate of the state 141 | aug_state[2] += grad_y[i - 1] # update any gradients wrt state at this time point 142 | 143 | if t_requires_grad: 144 | time_vjps[0] = aug_state[0] 145 | 146 | # Only compute gradient wrt initial time when in event handling mode. 
147 | if event_mode and t_requires_grad: 148 | time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])]) 149 | 150 | adj_y = aug_state[2] 151 | adj_params = aug_state[3:] 152 | 153 | return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params) 154 | 155 | 156 | def odeint_adjoint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None, 157 | adjoint_rtol=None, adjoint_atol=None, adjoint_method=None, adjoint_options=None, adjoint_params=None): 158 | 159 | # We need this in order to access the variables inside this module, 160 | # since we have no other way of getting variables along the execution path. 161 | if adjoint_params is None and not isinstance(func, nn.Module): 162 | raise ValueError('func must be an instance of nn.Module to specify the adjoint parameters; alternatively they ' 163 | 'can be specified explicitly via the `adjoint_params` argument. If there are no parameters ' 164 | 'then it is allowable to set `adjoint_params=()`.') 165 | 166 | # Must come before _check_inputs as we don't want to use normalised input (in particular any changes to options) 167 | if adjoint_rtol is None: 168 | adjoint_rtol = rtol 169 | if adjoint_atol is None: 170 | adjoint_atol = atol 171 | if adjoint_method is None: 172 | adjoint_method = method 173 | 174 | if adjoint_method != method and options is not None and adjoint_options is None: 175 | raise ValueError("If `adjoint_method != method` then we cannot infer `adjoint_options` from `options`. So as " 176 | "`options` has been passed then `adjoint_options` must be passed as well.") 177 | 178 | if adjoint_options is None: 179 | adjoint_options = {k: v for k, v in options.items() if k != "norm"} if options is not None else {} 180 | else: 181 | # Avoid in-place modifying a user-specified dict. 
182 | adjoint_options = adjoint_options.copy() 183 | 184 | if adjoint_params is None: 185 | adjoint_params = tuple(find_parameters(func)) 186 | else: 187 | adjoint_params = tuple(adjoint_params) # in case adjoint_params is a generator. 188 | 189 | # Filter params that don't require gradients. 190 | oldlen_ = len(adjoint_params) 191 | adjoint_params = tuple(p for p in adjoint_params if p.requires_grad) 192 | if len(adjoint_params) != oldlen_: 193 | # Some params were excluded. 194 | # Issue a warning if a user-specified norm is specified. 195 | if 'norm' in adjoint_options and callable(adjoint_options['norm']): 196 | warnings.warn("An adjoint parameter was passed without requiring gradient. For efficiency this will be " 197 | "excluded from the adjoint pass, and will not appear as a tensor in the adjoint norm.") 198 | 199 | # Convert to flattened state. 200 | shapes, func, y0, t, rtol, atol, method, options, event_fn, decreasing_time = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS) 201 | 202 | # Handle the adjoint norm function. 203 | state_norm = options["norm"] 204 | handle_adjoint_norm_(adjoint_options, shapes, state_norm) 205 | 206 | ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, 207 | adjoint_method, adjoint_options, t.requires_grad, *adjoint_params) 208 | 209 | if event_fn is None: 210 | solution = ans 211 | else: 212 | event_t, solution = ans 213 | event_t = event_t.to(t) 214 | if decreasing_time: 215 | event_t = -event_t 216 | 217 | if shapes is not None: 218 | solution = _flat_to_shape(solution, (len(t),), shapes) 219 | 220 | if event_fn is None: 221 | return solution 222 | else: 223 | return event_t, solution 224 | 225 | 226 | def find_parameters(module): 227 | 228 | assert isinstance(module, nn.Module) 229 | 230 | # If called within DataParallel, parameters won't appear in module.parameters(). 
231 | if getattr(module, '_is_replica', False): 232 | 233 | def find_tensor_attributes(module): 234 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) and v.requires_grad] 235 | return tuples 236 | 237 | gen = module._named_members(get_members_fn=find_tensor_attributes) 238 | return [param for _, param in gen] 239 | else: 240 | return list(module.parameters()) 241 | 242 | 243 | def handle_adjoint_norm_(adjoint_options, shapes, state_norm): 244 | """In-place modifies the adjoint options to choose or wrap the norm function.""" 245 | 246 | # This is the default adjoint norm on the backward pass: a mixed norm over the tuple of inputs. 247 | def default_adjoint_norm(tensor_tuple): 248 | t, y, adj_y, *adj_params = tensor_tuple 249 | # (If the state is actually a flattened tuple then this will be unpacked again in state_norm.) 250 | return max(t.abs(), state_norm(y), state_norm(adj_y), _mixed_norm(adj_params)) 251 | 252 | if "norm" not in adjoint_options: 253 | # `adjoint_options` was not explicitly specified by the user. Use the default norm. 254 | adjoint_options["norm"] = default_adjoint_norm 255 | else: 256 | # `adjoint_options` was explicitly specified by the user... 257 | try: 258 | adjoint_norm = adjoint_options['norm'] 259 | except KeyError: 260 | # ...but they did not specify the norm argument. Back to plan A: use the default norm. 261 | adjoint_options['norm'] = default_adjoint_norm 262 | else: 263 | # ...and they did specify the norm argument. 264 | if adjoint_norm == 'seminorm': 265 | # They told us they want to use seminorms. Slight modification to plan A: use the default norm, 266 | # but ignore the parameter state 267 | def adjoint_seminorm(tensor_tuple): 268 | t, y, adj_y, *adj_params = tensor_tuple 269 | # (If the state is actually a flattened tuple then this will be unpacked again in state_norm.) 
270 | return max(t.abs(), state_norm(y), state_norm(adj_y)) 271 | adjoint_options['norm'] = adjoint_seminorm 272 | else: 273 | # And they're using their own custom norm. 274 | if shapes is None: 275 | # The state on the forward pass was a tensor, not a tuple. We don't need to do anything, they're 276 | # already going to get given the full adjoint state as (t, y, adj_y, adj_params) 277 | pass # this branch included for clarity 278 | else: 279 | # This is the bit that is tuple/tensor abstraction-breaking, because the odeint machinery 280 | # doesn't know about the tupled nature of the forward state. We need to tell the user's adjoint 281 | # norm about that ourselves. 282 | 283 | def _adjoint_norm(tensor_tuple): 284 | t, y, adj_y, *adj_params = tensor_tuple 285 | y = _flat_to_shape(y, (), shapes) 286 | adj_y = _flat_to_shape(adj_y, (), shapes) 287 | return adjoint_norm((t, *y, *adj_y, *adj_params)) 288 | adjoint_options['norm'] = _adjoint_norm 289 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/bosh3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver 3 | 4 | 5 | _BOGACKI_SHAMPINE_TABLEAU = _ButcherTableau( 6 | alpha=torch.tensor([1 / 2, 3 / 4, 1.], dtype=torch.float64), 7 | beta=[ 8 | torch.tensor([1 / 2], dtype=torch.float64), 9 | torch.tensor([0., 3 / 4], dtype=torch.float64), 10 | torch.tensor([2 / 9, 1 / 3, 4 / 9], dtype=torch.float64) 11 | ], 12 | c_sol=torch.tensor([2 / 9, 1 / 3, 4 / 9, 0.], dtype=torch.float64), 13 | c_error=torch.tensor([2 / 9 - 7 / 24, 1 / 3 - 1 / 4, 4 / 9 - 1 / 3, -1 / 8], dtype=torch.float64), 14 | ) 15 | 16 | _BS_C_MID = torch.tensor([0., 0.5, 0., 0.], dtype=torch.float64) 17 | 18 | 19 | class Bosh3Solver(RKAdaptiveStepsizeODESolver): 20 | order = 3 21 | tableau = _BOGACKI_SHAMPINE_TABLEAU 22 | mid = _BS_C_MID 23 | 
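The Bogacki-Shampine tableau in bosh3.py above can be checked outside torchdiffeq. The sketch below uses only the Python standard library. It re-types the coefficients as exact fractions and tests them against the standard Runge-Kutta row-sum and order conditions. The names `NODES`, `BETA`, `check_tableau` and `bs_step` are hypothetical. `bs_step` is a plain scalar version of the explicit RK stage recurrence for illustration; it is not torchdiffeq's own stepping code in `RKAdaptiveStepsizeODESolver`.

```python
from fractions import Fraction as F
import math

# Bogacki-Shampine coefficients re-typed from bosh3.py as exact fractions.
# torchdiffeq's `alpha` omits the first node c_1 = 0, so it is prepended here.
NODES = [F(0), F(1, 2), F(3, 4), F(1)]
BETA = [[F(1, 2)], [F(0), F(3, 4)], [F(2, 9), F(1, 3), F(4, 9)]]
C_SOL = [F(2, 9), F(1, 3), F(4, 9), F(0)]
C_ERROR = [F(2, 9) - F(7, 24), F(1, 3) - F(1, 4), F(4, 9) - F(1, 3), F(-1, 8)]


def check_tableau():
    """Return named consistency checks for the tableau above; all should be True."""
    b = C_SOL
    b_hat = [bi - ei for bi, ei in zip(C_SOL, C_ERROR)]  # embedded 2nd-order weights
    return {
        # Row-sum condition: stage i is evaluated at t0 + c_i * dt.
        "row_sums": all(sum(row) == c for row, c in zip(BETA, NODES[1:])),
        # Order conditions up to order 3 for the propagated solution.
        "order1": sum(b) == 1,
        "order2": sum(bi * ci for bi, ci in zip(b, NODES)) == F(1, 2),
        "order3a": sum(bi * ci ** 2 for bi, ci in zip(b, NODES)) == F(1, 3),
        "order3b": sum(b[i] * sum(a * NODES[j] for j, a in enumerate(BETA[i - 1]))
                       for i in range(1, 4)) == F(1, 6),
        # The embedded weights only need to satisfy the conditions up to order 2.
        "embedded_order2": sum(b_hat) == 1
        and sum(bi * ci for bi, ci in zip(b_hat, NODES)) == F(1, 2),
    }


def bs_step(f, t0, y0, dt):
    """One explicit Bogacki-Shampine step for a scalar ODE y' = f(t, y).

    Returns (y1, error_estimate), where error_estimate = dt * sum(C_ERROR[i] * k_i).
    """
    k = [f(t0, y0)]
    for c, row in zip(NODES[1:], BETA):
        y_stage = y0 + dt * sum(float(a) * ki for a, ki in zip(row, k))
        k.append(f(t0 + float(c) * dt, y_stage))
    y1 = y0 + dt * sum(float(b) * ki for b, ki in zip(C_SOL, k))
    err = dt * sum(float(e) * ki for e, ki in zip(C_ERROR, k))
    return y1, err


if __name__ == "__main__":
    print(check_tableau())
    y1, err = bs_step(lambda t, y: y, 0.0, 1.0, 0.1)  # y' = y, y(0) = 1
    print(y1, math.exp(0.1), err)
```

`C_SOL` equals the last row of `BETA` with a trailing zero, and the last node is 1. The final stage is therefore evaluated at the new solution point itself, which is why that stage's derivative can be reused as the first stage of the next step. The error weights are the difference between the third-order weights and the embedded weights `[7/24, 1/4, 1/3, 1/8]` that appear in the `c_error` expressions of bosh3.py.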
-------------------------------------------------------------------------------- /torchdiffeq/_impl/dopri5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver 3 | 4 | 5 | _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( 6 | alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64), 7 | beta=[ 8 | torch.tensor([1 / 5], dtype=torch.float64), 9 | torch.tensor([3 / 40, 9 / 40], dtype=torch.float64), 10 | torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=torch.float64), 11 | torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64), 12 | torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=torch.float64), 13 | torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=torch.float64), 14 | ], 15 | c_sol=torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=torch.float64), 16 | c_error=torch.tensor([ 17 | 35 / 384 - 1951 / 21600, 18 | 0, 19 | 500 / 1113 - 22642 / 50085, 20 | 125 / 192 - 451 / 720, 21 | -2187 / 6784 - -12231 / 42400, 22 | 11 / 84 - 649 / 6300, 23 | -1. 
/ 60., 24 | ], dtype=torch.float64), 25 | ) 26 | 27 | DPS_C_MID = torch.tensor([ 28 | 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, 29 | 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 30 | ], dtype=torch.float64) 31 | 32 | 33 | class Dopri5Solver(RKAdaptiveStepsizeODESolver): 34 | order = 5 35 | tableau = _DORMAND_PRINCE_SHAMPINE_TABLEAU 36 | mid = DPS_C_MID 37 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/dopri8.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver 3 | 4 | 5 | A = [1 / 18, 1 / 12, 1 / 8, 5 / 16, 3 / 8, 59 / 400, 93 / 200, 5490023248 / 9719169821, 13 / 20, 1201146811 / 1299019798, 1, 1, 1] 6 | 7 | B = [ 8 | [1 / 18], 9 | 10 | [1 / 48, 1 / 16], 11 | 12 | [1 / 32, 0, 3 / 32], 13 | 14 | [5 / 16, 0, -75 / 64, 75 / 64], 15 | 16 | [3 / 80, 0, 0, 3 / 16, 3 / 20], 17 | 18 | [29443841 / 614563906, 0, 0, 77736538 / 692538347, -28693883 / 1125000000, 23124283 / 1800000000], 19 | 20 | [16016141 / 946692911, 0, 0, 61564180 / 158732637, 22789713 / 633445777, 545815736 / 2771057229, -180193667 / 1043307555], 21 | 22 | [39632708 / 573591083, 0, 0, -433636366 / 683701615, -421739975 / 2616292301, 100302831 / 723423059, 790204164 / 839813087, 800635310 / 3783071287], 23 | 24 | [246121993 / 1340847787, 0, 0, -37695042795 / 15268766246, -309121744 / 1061227803, -12992083 / 490766935, 6005943493 / 2108947869, 393006217 / 1396673457, 123872331 / 1001029789], 25 | 26 | [-1028468189 / 846180014, 0, 0, 8478235783 / 508512852, 1311729495 / 1432422823, -10304129995 / 1701304382, -48777925059 / 3047939560, 15336726248 / 1032824649, -45442868181 / 3398467696, 3065993473 / 597172653], 27 | 28 | [185892177 / 718116043, 0, 0, -3185094517 / 667107341, -477755414 / 1098053517, -703635378 / 230739211, 5731566787 / 
1027545527, 5232866602 / 850066563, -4093664535 / 808688257, 3962137247 / 1805957418, 65686358 / 487910083], 29 | 30 | [403863854 / 491063109, 0, 0, -5068492393 / 434740067, -411421997 / 543043805, 652783627 / 914296604, 11173962825 / 925320556, -13158990841 / 6184727034, 3936647629 / 1978049680, -160528059 / 685178525, 248638103 / 1413531060, 0], 31 | 32 | [14005451 / 335480064, 0, 0, 0, 0, -59238493 / 1068277825, 181606767 / 758867731, 561292985 / 797845732, -1041891430 / 1371343529, 760417239 / 1151165299, 118820643 / 751138087, -528747749 / 2220607170, 1 / 4] 33 | ] 34 | 35 | C_sol = [14005451 / 335480064, 0, 0, 0, 0, -59238493 / 1068277825, 181606767 / 758867731, 561292985 / 797845732, -1041891430 / 1371343529, 760417239 / 1151165299, 118820643 / 751138087, -528747749 / 2220607170, 1 / 4, 0] 36 | 37 | C_err = [14005451 / 335480064 - 13451932 / 455176623, 0, 0, 0, 0, -59238493 / 1068277825 - -808719846 / 976000145, 181606767 / 758867731 - 1757004468 / 5645159321, 561292985 / 797845732 - 656045339 / 265891186, -1041891430 / 1371343529 - -3867574721 / 1518517206, 760417239 / 1151165299 - 465885868 / 322736535, 118820643 / 751138087 - 53011238 / 667516719, -528747749 / 2220607170 - 2 / 45, 1 / 4, 0] 38 | 39 | h = 1 / 2 40 | 41 | C_mid = [0.] 
* 14 42 | 43 | C_mid[0] = (- 6.3448349392860401388 * (h**5) + 22.1396504998094068976 * (h**4) - 30.0610568289666450593 * (h**3) + 19.9990069333683970610 * (h**2) - 6.6910181737837595697 * h + 1.0) / (1 / h) 44 | 45 | C_mid[5] = (- 39.6107919852202505218 * (h**5) + 116.4422149550342161651 * (h**4) - 121.4999627731334642623 * (h**3) + 52.2273532792945524050 * (h**2) - 7.6142658045872677172 * h) / (1 / h) 46 | 47 | C_mid[6] = (20.3761213808791436958 * (h**5) - 67.1451318825957197185 * (h**4) + 83.1721004639847717481 * (h**3) - 46.8919164181093621583 * (h**2) + 10.7281392630428866124 * h) / (1 / h) 48 | 49 | C_mid[7] = (7.3347098826795362023 * (h**5) - 16.5672243527496524646 * (h**4) + 9.5724507555993664382 * (h**3) - 0.1890893225010595467 * (h**2) + 0.5526637063753648783 * h) / (1 / h) 50 | 51 | C_mid[8] = (32.8801774352459155182 * (h**5) - 89.9916014847245016028 * (h**4) + 87.8406057677205645007 * (h**3) - 35.7075975946222072821 * (h**2) + 4.2186562625665153803 * h) / (1 / h) 52 | 53 | C_mid[9] = (- 10.1588990526426760954 * (h**5) + 22.6237489648532849093 * (h**4) - 17.4152107770762969005 * (h**3) + 6.2736448083240352160 * (h**2) - 0.6627209125361597559 * h) / (1 / h) 54 | 55 | C_mid[10] = (- 12.5401268098782561200 * (h**5) + 32.2362340167355370113 * (h**4) - 28.5903289514790976966 * (h**3) + 10.3160881272450748458 * (h**2) - 1.2636789001135462218 * h) / (1 / h) 56 | 57 | C_mid[11] = (29.5553001484516038033 * (h**5) - 82.1020315488359848644 * (h**4) + 81.6630950584341412934 * (h**3) - 34.7650769866611817349 * (h**2) + 5.4106037898590422230 * h) / (1 / h) 58 | 59 | C_mid[12] = (- 41.7923486424390588923 * (h**5) + 116.2662185791119533462 * (h**4) - 114.9375291377009418170 * (h**3) + 47.7457971078225540396 * (h**2) - 7.0321379067945741781 * h) / (1 / h) 60 | 61 | C_mid[13] = (20.3006925822100825485 * (h**5) - 53.9020777466385396792 * (h**4) + 50.2558364226176017553 * (h**3) - 19.0082099341608028453 * (h**2) + 2.3537586759714983486 * h) / (1 / h) 62 | 63 | 64 | A = 
torch.tensor(A, dtype=torch.float64) 65 | B = [torch.tensor(B_, dtype=torch.float64) for B_ in B] 66 | C_sol = torch.tensor(C_sol, dtype=torch.float64) 67 | C_err = torch.tensor(C_err, dtype=torch.float64) 68 | _C_mid = torch.tensor(C_mid, dtype=torch.float64) 69 | 70 | _DOPRI8_TABLEAU = _ButcherTableau(alpha=A, beta=B, c_sol=C_sol, c_error=C_err) 71 | 72 | 73 | class Dopri8Solver(RKAdaptiveStepsizeODESolver): 74 | order = 8 75 | tableau = _DOPRI8_TABLEAU 76 | mid = _C_mid 77 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/event_handling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def find_event(interp_fn, sign0, t0, t1, event_fn, tol): 6 | with torch.no_grad(): 7 | 8 | # Number of bisection iterations needed to shrink the bracket [t0, t1] below `tol`. 9 | nitrs = torch.ceil(torch.log((t1 - t0) / tol) / math.log(2.0)) 10 | 11 | for _ in range(nitrs.long()): 12 | t_mid = (t1 + t0) / 2.0 13 | y_mid = interp_fn(t_mid) 14 | sign_mid = torch.sign(event_fn(t_mid, y_mid)) 15 | same_as_sign0 = (sign0 == sign_mid) 16 | t0 = torch.where(same_as_sign0, t_mid, t0) 17 | t1 = torch.where(same_as_sign0, t1, t_mid) 18 | event_t = (t0 + t1) / 2.0 19 | 20 | return event_t, interp_fn(event_t) 21 | 22 | 23 | def combine_event_functions(event_fn, t0, y0): 24 | """ 25 | Flip the sign of each event function so that all are initially positive; 26 | the combined event function is then their minimum, which reaches zero as soon as any one event fires.
27 | """ 28 | with torch.no_grad(): 29 | initial_signs = torch.sign(event_fn(t0, y0)) 30 | 31 | def combined_event_fn(t, y): 32 | c = event_fn(t, y) 33 | return torch.min(c * initial_signs) 34 | 35 | return combined_event_fn 36 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/fehlberg2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver 3 | 4 | _FEHLBERG2_TABLEAU = _ButcherTableau( 5 | alpha=torch.tensor([1 / 2, 1.0], dtype=torch.float64), 6 | beta=[ 7 | torch.tensor([1 / 2], dtype=torch.float64), 8 | torch.tensor([1 / 256, 255 / 256], dtype=torch.float64), 9 | ], 10 | c_sol=torch.tensor([1 / 512, 255 / 256, 1 / 512], dtype=torch.float64), 11 | c_error=torch.tensor( 12 | [-1 / 512, 0, 1 / 512], dtype=torch.float64 13 | ), 14 | ) 15 | 16 | _FE_C_MID = torch.tensor([0.0, 0.5, 0.0], dtype=torch.float64) 17 | 18 | 19 | class Fehlberg2(RKAdaptiveStepsizeODESolver): 20 | order = 2 21 | tableau = _FEHLBERG2_TABLEAU 22 | mid = _FE_C_MID 23 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/fixed_adams.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import sys 3 | import torch 4 | import warnings 5 | from .solvers import FixedGridODESolver 6 | from .misc import _compute_error_ratio, _linf_norm 7 | from .misc import Perturb 8 | from .rk_common import rk4_alt_step_func 9 | 10 | _BASHFORTH_COEFFICIENTS = [ 11 | [], # order 0 12 | [11], 13 | [3, -1], 14 | [23, -16, 5], 15 | [55, -59, 37, -9], 16 | [1901, -2774, 2616, -1274, 251], 17 | [4277, -7923, 9982, -7298, 2877, -475], 18 | [198721, -447288, 705549, -688256, 407139, -134472, 19087], 19 | [434241, -1152169, 2183877, -2664477, 2102243, -1041723, 295767, -36799], 20 | [14097247, -43125206, 95476786, -139855262, 137968480, -91172642, 
38833486, -9664106, 1070017], 21 | [30277247, -104995189, 265932680, -454661776, 538363838, -444772162, 252618224, -94307320, 20884811, -2082753], 22 | [ 23 | 2132509567, -8271795124, 23591063805, -46113029016, 63716378958, -63176201472, 44857168434, -22329634920, 24 | 7417904451, -1479574348, 134211265 25 | ], 26 | [ 27 | 4527766399, -19433810163, 61633227185, -135579356757, 214139355366, -247741639374, 211103573298, -131365867290, 28 | 58189107627, -17410248271, 3158642445, -262747265 29 | ], 30 | [ 31 | 13064406523627, -61497552797274, 214696591002612, -524924579905150, 932884546055895, -1233589244941764, 32 | 1226443086129408, -915883387152444, 507140369728425, -202322913738370, 55060974662412, -9160551085734, 33 | 703604254357 34 | ], 35 | [ 36 | 27511554976875, -140970750679621, 537247052515662, -1445313351681906, 2854429571790805, -4246767353305755, 37 | 4825671323488452, -4204551925534524, 2793869602879077, -1393306307155755, 505586141196430, -126174972681906, 38 | 19382853593787, -1382741929621 39 | ], 40 | [ 41 | 173233498598849, -960122866404112, 3966421670215481, -11643637530577472, 25298910337081429, -41825269932507728, 42 | 53471026659940509, -53246738660646912, 41280216336284259, -24704503655607728, 11205849753515179, 43 | -3728807256577472, 859236476684231, -122594813904112, 8164168737599 44 | ], 45 | [ 46 | 362555126427073, -2161567671248849, 9622096909515337, -30607373860520569, 72558117072259733, 47 | -131963191940828581, 187463140112902893, -210020588912321949, 186087544263596643, -129930094104237331, 48 | 70724351582843483, -29417910911251819, 9038571752734087, -1934443196892599, 257650275915823, -16088129229375 49 | ], 50 | [ 51 | 192996103681340479, -1231887339593444974, 5878428128276811750, -20141834622844109630, 51733880057282977010, 52 | -102651404730855807942, 160414858999474733422, -199694296833704562550, 199061418623907202560, 53 | -158848144481581407370, 100878076849144434322, -50353311405771659322, 19338911944324897550, 54 | 
-5518639984393844930, 1102560345141059610, -137692773163513234, 8092989203533249 55 | ], 56 | [ 57 | 401972381695456831, -2735437642844079789, 13930159965811142228, -51150187791975812900, 141500575026572531760, 58 | -304188128232928718008, 518600355541383671092, -710171024091234303204, 786600875277595877750, 59 | -706174326992944287370, 512538584122114046748, -298477260353977522892, 137563142659866897224, 60 | -49070094880794267600, 13071639236569712860, -2448689255584545196, 287848942064256339, -15980174332775873 61 | ], 62 | [ 63 | 333374427829017307697, -2409687649238345289684, 13044139139831833251471, -51099831122607588046344, 64 | 151474888613495715415020, -350702929608291455167896, 647758157491921902292692, -967713746544629658690408, 65 | 1179078743786280451953222, -1176161829956768365219840, 960377035444205950813626, -639182123082298748001432, 66 | 343690461612471516746028, -147118738993288163742312, 48988597853073465932820, -12236035290567356418552, 67 | 2157574942881818312049, -239560589366324764716, 12600467236042756559 68 | ], 69 | [ 70 | 691668239157222107697, -5292843584961252933125, 30349492858024727686755, -126346544855927856134295, 71 | 399537307669842150996468, -991168450545135070835076, 1971629028083798845750380, -3191065388846318679544380, 72 | 4241614331208149947151790, -4654326468801478894406214, 4222756879776354065593786, -3161821089800186539248210, 73 | 1943018818982002395655620, -970350191086531368649620, 387739787034699092364924, -121059601023985433003532, 74 | 28462032496476316665705, -4740335757093710713245, 498669220956647866875, -24919383499187492303 75 | ], 76 | ] 77 | 78 | _MOULTON_COEFFICIENTS = [ 79 | [], # order 0 80 | [1], 81 | [1, 1], 82 | [5, 8, -1], 83 | [9, 19, -5, 1], 84 | [251, 646, -264, 106, -19], 85 | [475, 1427, -798, 482, -173, 27], 86 | [19087, 65112, -46461, 37504, -20211, 6312, -863], 87 | [36799, 139849, -121797, 123133, -88547, 41499, -11351, 1375], 88 | [1070017, 4467094, -4604594, 5595358, -5033120, 3146338, 
-1291214, 312874, -33953], 89 | [2082753, 9449717, -11271304, 16002320, -17283646, 13510082, -7394032, 2687864, -583435, 57281], 90 | [ 91 | 134211265, 656185652, -890175549, 1446205080, -1823311566, 1710774528, -1170597042, 567450984, -184776195, 92 | 36284876, -3250433 93 | ], 94 | [ 95 | 262747265, 1374799219, -2092490673, 3828828885, -5519460582, 6043521486, -4963166514, 3007739418, -1305971115, 96 | 384709327, -68928781, 5675265 97 | ], 98 | [ 99 | 703604254357, 3917551216986, -6616420957428, 13465774256510, -21847538039895, 27345870698436, -26204344465152, 100 | 19058185652796, -10344711794985, 4063327863170, -1092096992268, 179842822566, -13695779093 101 | ], 102 | [ 103 | 1382741929621, 8153167962181, -15141235084110, 33928990133618, -61188680131285, 86180228689563, -94393338653892, 104 | 80101021029180, -52177910882661, 25620259777835, -9181635605134, 2268078814386, -345457086395, 24466579093 105 | ], 106 | [ 107 | 8164168737599, 50770967534864, -102885148956217, 251724894607936, -499547203754837, 781911618071632, 108 | -963605400824733, 934600833490944, -710312834197347, 418551804601264, -187504936597931, 61759426692544, 109 | -14110480969927, 1998759236336, -132282840127 110 | ], 111 | [ 112 | 16088129229375, 105145058757073, -230992163723849, 612744541065337, -1326978663058069, 2285168598349733, 113 | -3129453071993581, 3414941728852893, -2966365730265699, 2039345879546643, -1096355235402331, 451403108933483, 114 | -137515713789319, 29219384284087, -3867689367599, 240208245823 115 | ], 116 | [ 117 | 8092989203533249, 55415287221275246, -131240807912923110, 375195469874202430, -880520318434977010, 118 | 1654462865819232198, -2492570347928318318, 3022404969160106870, -2953729295811279360, 2320851086013919370, 119 | -1455690451266780818, 719242466216944698, -273894214307914510, 77597639915764930, -15407325991235610, 120 | 1913813460537746, -111956703448001 121 | ], 122 | [ 123 | 15980174332775873, 114329243705491117, -290470969929371220, 
890337710266029860, -2250854333681641520, 124 | 4582441343348851896, -7532171919277411636, 10047287575124288740, -10910555637627652470, 9644799218032932490, 125 | -6913858539337636636, 3985516155854664396, -1821304040326216520, 645008976643217360, -170761422500096220, 126 | 31816981024600492, -3722582669836627, 205804074290625 127 | ], 128 | [ 129 | 12600467236042756559, 93965550344204933076, -255007751875033918095, 834286388106402145800, 130 | -2260420115705863623660, 4956655592790542146968, -8827052559979384209108, 12845814402199484797800, 131 | -15345231910046032448070, 15072781455122686545920, -12155867625610599812538, 8008520809622324571288, 132 | -4269779992576330506540, 1814584564159445787240, -600505972582990474260, 149186846171741510136, 133 | -26182538841925312881, 2895045518506940460, -151711881512390095 134 | ], 135 | [ 136 | 24919383499187492303, 193280569173472261637, -558160720115629395555, 1941395668950986461335, 137 | -5612131802364455926260, 13187185898439270330756, -25293146116627869170796, 39878419226784442421820, 138 | -51970649453670274135470, 56154678684618739939910, -50320851025594566473146, 37297227252822858381906, 139 | -22726350407538133839300, 11268210124987992327060, -4474886658024166985340, 1389665263296211699212, 140 | -325187970422032795497, 53935307402575440285, -5652892248087175675, 281550972898020815 141 | ], 142 | ] 143 | 144 | _DIVISOR = [ 145 | None, 11, 2, 12, 24, 720, 1440, 60480, 120960, 3628800, 7257600, 479001600, 958003200, 2615348736000, 5230697472000, 146 | 31384184832000, 62768369664000, 32011868528640000, 64023737057280000, 51090942171709440000, 102181884343418880000 147 | ] 148 | 149 | _BASHFORTH_DIVISOR = [torch.tensor([b / divisor for b in bashforth], dtype=torch.float64) 150 | for bashforth, divisor in zip(_BASHFORTH_COEFFICIENTS, _DIVISOR)] 151 | _MOULTON_DIVISOR = [torch.tensor([m / divisor for m in moulton], dtype=torch.float64) 152 | for moulton, divisor in zip(_MOULTON_COEFFICIENTS, _DIVISOR)] 153 | 154 | 
_MIN_ORDER = 4 155 | _MAX_ORDER = 12 156 | _MAX_ITERS = 4 157 | 158 | 159 | # TODO: replace this with PyTorch operations (a little hard because y is a deque being used as a circular buffer) 160 | def _dot_product(x, y): 161 | return sum(xi * yi for xi, yi in zip(x, y)) 162 | 163 | 164 | class AdamsBashforthMoulton(FixedGridODESolver): 165 | order = 4 166 | 167 | def __init__(self, func, y0, rtol=1e-3, atol=1e-4, implicit=True, max_iters=_MAX_ITERS, max_order=_MAX_ORDER, 168 | **kwargs): 169 | super(AdamsBashforthMoulton, self).__init__(func, y0, rtol=rtol, atol=atol, **kwargs) 170 | assert max_order <= _MAX_ORDER, "max_order must be at most {}".format(_MAX_ORDER) 171 | if max_order < _MIN_ORDER: 172 | warnings.warn("max_order is below {}, so the solver reduces to `rk4`.".format(_MIN_ORDER)) 173 | 174 | self.rtol = torch.as_tensor(rtol, dtype=y0.dtype, device=y0.device) 175 | self.atol = torch.as_tensor(atol, dtype=y0.dtype, device=y0.device) 176 | self.implicit = implicit 177 | self.max_iters = max_iters 178 | self.max_order = int(max_order) 179 | self.prev_f = collections.deque(maxlen=self.max_order - 1) 180 | self.prev_t = None 181 | 182 | self.bashforth = [x.to(y0.device) for x in _BASHFORTH_DIVISOR] 183 | self.moulton = [x.to(y0.device) for x in _MOULTON_DIVISOR] 184 | 185 | def _update_history(self, t, f): 186 | if self.prev_t is None or self.prev_t != t: 187 | self.prev_f.appendleft(f) 188 | self.prev_t = t 189 | 190 | def _has_converged(self, y0, y1): 191 | """Checks that each element is within the error tolerance.""" 192 | error_ratio = _compute_error_ratio(torch.abs(y0 - y1), self.rtol, self.atol, y0, y1, _linf_norm) 193 | return error_ratio < 1 194 | 195 | def _step_func(self, func, t0, dt, t1, y0): 196 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) 197 | self._update_history(t0, f0) 198 | order = min(len(self.prev_f), self.max_order - 1) 199 | if order < _MIN_ORDER - 1: 200 | # Compute using RK4.
201 | return rk4_alt_step_func(func, t0, dt, t1, y0, f0=self.prev_f[0], perturb=self.perturb), f0 202 | else: 203 | # Adams-Bashforth predictor. 204 | bashforth_coeffs = self.bashforth[order] 205 | dy = _dot_product(dt * bashforth_coeffs, self.prev_f).type_as(y0) # bashforth is float64 so cast back 206 | 207 | # Adams-Moulton corrector. 208 | if self.implicit: 209 | moulton_coeffs = self.moulton[order + 1] 210 | delta = dt * _dot_product(moulton_coeffs[1:], self.prev_f).type_as(y0) # moulton is float64 so cast back 211 | converged = False 212 | for _ in range(self.max_iters): 213 | dy_old = dy 214 | f = func(t1, y0 + dy, perturb=Perturb.PREV if self.perturb else Perturb.NONE) 215 | dy = (dt * (moulton_coeffs[0]) * f).type_as(y0) + delta # moulton is float64 so cast back 216 | converged = self._has_converged(dy_old, dy) 217 | if converged: 218 | break 219 | if not converged: 220 | warnings.warn('Functional iteration did not converge. Solution may be incorrect.') 221 | self.prev_f.pop() 222 | self._update_history(t0, f) 223 | return dy, f0 224 | 225 | 226 | class AdamsBashforth(AdamsBashforthMoulton): 227 | def __init__(self, func, y0, **kwargs): 228 | super(AdamsBashforth, self).__init__(func, y0, implicit=False, **kwargs) 229 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/fixed_grid.py: -------------------------------------------------------------------------------- 1 | from .solvers import FixedGridODESolver 2 | from .rk_common import rk4_alt_step_func, rk3_step_func, rk2_step_func 3 | from .misc import Perturb 4 | 5 | 6 | class Euler(FixedGridODESolver): 7 | order = 1 8 | 9 | def _step_func(self, func, t0, dt, t1, y0): 10 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) 11 | return dt * f0, f0 12 | 13 | 14 | class Midpoint(FixedGridODESolver): 15 | order = 2 16 | 17 | def _step_func(self, func, t0, dt, t1, y0): 18 | half_dt = 0.5 * dt 19 | f0 = func(t0, y0, perturb=Perturb.NEXT if 
self.perturb else Perturb.NONE) 20 | y_mid = y0 + f0 * half_dt 21 | return dt * func(t0 + half_dt, y_mid), f0 22 | 23 | 24 | class RK4(FixedGridODESolver): 25 | order = 4 26 | 27 | def _step_func(self, func, t0, dt, t1, y0): 28 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) 29 | return rk4_alt_step_func(func, t0, dt, t1, y0, f0=f0, perturb=self.perturb), f0 30 | 31 | 32 | class Heun3(FixedGridODESolver): 33 | order = 3 34 | 35 | def _step_func(self, func, t0, dt, t1, y0): 36 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) 37 | 38 | butcher_tableu = [ 39 | [0.0, 0.0, 0.0, 0.0], 40 | [1/3, 1/3, 0.0, 0.0], 41 | [2/3, 0.0, 2/3, 0.0], 42 | [0.0, 1/4, 0.0, 3/4], 43 | ] 44 | 45 | return rk3_step_func(func, t0, dt, t1, y0, butcher_tableu=butcher_tableu, f0=f0, perturb=self.perturb), f0 46 | 47 | 48 | class Heun2(FixedGridODESolver): 49 | order = 2 50 | 51 | def _step_func(self, func, t0, dt, t1, y0): 52 | f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) 53 | 54 | butcher_tableu = [ 55 | [0.0, 0.0, 0.0], 56 | [1.0, 1.0, 0.0], 57 | [0.0, 1/2, 1/2], 58 | ] 59 | 60 | return rk2_step_func(func, t0, dt, t1, y0, butcher_tableu=butcher_tableu, f0=f0, perturb=self.perturb), f0 61 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/fixed_grid_implicit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .rk_common import FixedGridFIRKODESolver, FixedGridDIRKODESolver 3 | from .rk_common import _ButcherTableau 4 | 5 | _sqrt_2 = torch.sqrt(torch.tensor(2, dtype=torch.float64)).item() 6 | _sqrt_3 = torch.sqrt(torch.tensor(3, dtype=torch.float64)).item() 7 | _sqrt_6 = torch.sqrt(torch.tensor(6, dtype=torch.float64)).item() 8 | _sqrt_15 = torch.sqrt(torch.tensor(15, dtype=torch.float64)).item() 9 | 10 | _IMPLICIT_EULER_TABLEAU = _ButcherTableau( 11 | alpha=torch.tensor([1], dtype=torch.float64), 12 | 
beta=[ 13 | torch.tensor([1], dtype=torch.float64), 14 | ], 15 | c_sol=torch.tensor([1], dtype=torch.float64), 16 | c_error=torch.tensor([], dtype=torch.float64), 17 | ) 18 | 19 | class ImplicitEuler(FixedGridFIRKODESolver): 20 | order = 1 21 | tableau = _IMPLICIT_EULER_TABLEAU 22 | 23 | _IMPLICIT_MIDPOINT_TABLEAU = _ButcherTableau( 24 | alpha=torch.tensor([1 / 2], dtype=torch.float64), 25 | beta=[ 26 | torch.tensor([1 / 2], dtype=torch.float64), 27 | 28 | ], 29 | c_sol=torch.tensor([1], dtype=torch.float64), 30 | c_error=torch.tensor([], dtype=torch.float64), 31 | ) 32 | 33 | class ImplicitMidpoint(FixedGridFIRKODESolver): 34 | order = 2 35 | tableau = _IMPLICIT_MIDPOINT_TABLEAU 36 | 37 | _GAUSS_LEGENDRE_4_TABLEAU = _ButcherTableau( 38 | alpha=torch.tensor([1 / 2 - _sqrt_3 / 6, 1 / 2 + _sqrt_3 / 6], dtype=torch.float64), 39 | beta=[ 40 | torch.tensor([1 / 4, 1 / 4 - _sqrt_3 / 6], dtype=torch.float64), 41 | torch.tensor([1 / 4 + _sqrt_3 / 6, 1 / 4], dtype=torch.float64), 42 | ], 43 | c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64), 44 | c_error=torch.tensor([], dtype=torch.float64), 45 | ) 46 | 47 | _TRAPEZOID_TABLEAU = _ButcherTableau( 48 | alpha=torch.tensor([0, 1], dtype=torch.float64), 49 | beta=[ 50 | torch.tensor([0, 0], dtype=torch.float64), 51 | torch.tensor([1 / 2, 1 / 2], dtype=torch.float64), 52 | ], 53 | c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64), 54 | c_error=torch.tensor([], dtype=torch.float64), 55 | ) 56 | 57 | class Trapezoid(FixedGridFIRKODESolver): 58 | order = 2 59 | tableau = _TRAPEZOID_TABLEAU 60 | 61 | 62 | class GaussLegendre4(FixedGridFIRKODESolver): 63 | order = 4 64 | tableau = _GAUSS_LEGENDRE_4_TABLEAU 65 | 66 | _GAUSS_LEGENDRE_6_TABLEAU = _ButcherTableau( 67 | alpha=torch.tensor([1 / 2 - _sqrt_15 / 10, 1 / 2, 1 / 2 + _sqrt_15 / 10], dtype=torch.float64), 68 | beta=[ 69 | torch.tensor([5 / 36 , 2 / 9 - _sqrt_15 / 15, 5 / 36 - _sqrt_15 / 30], dtype=torch.float64), 70 | torch.tensor([5 / 36 + _sqrt_15 / 24, 2 / 9 , 5 /
36 - _sqrt_15 / 24], dtype=torch.float64), 71 | torch.tensor([5 / 36 + _sqrt_15 / 30, 2 / 9 + _sqrt_15 / 15, 5 / 36 ], dtype=torch.float64), 72 | ], 73 | c_sol=torch.tensor([5 / 18, 4 / 9, 5 / 18], dtype=torch.float64), 74 | c_error=torch.tensor([], dtype=torch.float64), 75 | ) 76 | 77 | class GaussLegendre6(FixedGridFIRKODESolver): 78 | order = 6 79 | tableau = _GAUSS_LEGENDRE_6_TABLEAU 80 | 81 | _RADAU_IIA_3_TABLEAU = _ButcherTableau( 82 | alpha=torch.tensor([1 / 3, 1], dtype=torch.float64), 83 | beta=[ 84 | torch.tensor([5 / 12, -1 / 12], dtype=torch.float64), 85 | torch.tensor([3 / 4, 1 / 4], dtype=torch.float64) 86 | ], 87 | c_sol=torch.tensor([3 / 4, 1 / 4], dtype=torch.float64), 88 | c_error=torch.tensor([], dtype=torch.float64) 89 | ) 90 | 91 | class RadauIIA3(FixedGridFIRKODESolver): 92 | order = 3 93 | tableau = _RADAU_IIA_3_TABLEAU 94 | 95 | _RADAU_IIA_5_TABLEAU = _ButcherTableau( 96 | alpha=torch.tensor([2 / 5 - _sqrt_6 / 10, 2 / 5 + _sqrt_6 / 10, 1], dtype=torch.float64), 97 | beta=[ 98 | torch.tensor([11 / 45 - 7 * _sqrt_6 / 360 , 37 / 225 - 169 * _sqrt_6 / 1800, -2 / 225 + _sqrt_6 / 75], dtype=torch.float64), 99 | torch.tensor([37 / 225 + 169 * _sqrt_6 / 1800, 11 / 45 + 7 * _sqrt_6 / 360 , -2 / 225 - _sqrt_6 / 75], dtype=torch.float64), 100 | torch.tensor([4 / 9 - _sqrt_6 / 36 , 4 / 9 + _sqrt_6 / 36 , 1 / 9], dtype=torch.float64) 101 | ], 102 | c_sol=torch.tensor([4 / 9 - _sqrt_6 / 36, 4 / 9 + _sqrt_6 / 36, 1 / 9], dtype=torch.float64), 103 | c_error=torch.tensor([], dtype=torch.float64) 104 | ) 105 | 106 | class RadauIIA5(FixedGridFIRKODESolver): 107 | order = 5 108 | tableau = _RADAU_IIA_5_TABLEAU 109 | 110 | gamma = (2. - _sqrt_2) / 2. 
111 | _SDIRK_2_TABLEAU = _ButcherTableau( 112 | alpha = torch.tensor([gamma, 1], dtype=torch.float64), 113 | beta=[ 114 | torch.tensor([gamma], dtype=torch.float64), 115 | torch.tensor([1 - gamma, gamma], dtype=torch.float64), 116 | ], 117 | c_sol=torch.tensor([1 - gamma, gamma], dtype=torch.float64), 118 | c_error=torch.tensor([], dtype=torch.float64) 119 | ) 120 | 121 | class SDIRK2(FixedGridDIRKODESolver): 122 | order = 2 123 | tableau = _SDIRK_2_TABLEAU 124 | 125 | gamma = 1. - _sqrt_2 / 2. 126 | beta = _sqrt_2 / 4. 127 | _TRBDF_2_TABLEAU = _ButcherTableau( 128 | alpha = torch.tensor([0, 2 * gamma, 1], dtype=torch.float64), 129 | beta=[ 130 | torch.tensor([0], dtype=torch.float64), 131 | torch.tensor([gamma, gamma], dtype=torch.float64), 132 | torch.tensor([beta, beta, gamma], dtype=torch.float64), 133 | ], 134 | c_sol=torch.tensor([beta, beta, gamma], dtype=torch.float64), 135 | c_error=torch.tensor([], dtype=torch.float64) 136 | ) 137 | 138 | class TRBDF2(FixedGridDIRKODESolver): 139 | order = 2 140 | tableau = _TRBDF_2_TABLEAU 141 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/interp.py: -------------------------------------------------------------------------------- 1 | def _interp_fit(y0, y1, y_mid, f0, f1, dt): 2 | """Fit coefficients for 4th order polynomial interpolation. 3 | 4 | Args: 5 | y0: function value at the start of the interval. 6 | y1: function value at the end of the interval. 7 | y_mid: function value at the mid-point of the interval. 8 | f0: derivative value at the start of the interval. 9 | f1: derivative value at the end of the interval. 10 | dt: width of the interval. 11 | 12 | Returns: 13 | List of coefficients `[e, d, c, b, a]` (lowest degree first) for interpolating with the polynomial 14 | `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x` 15 | between 0 (start of interval) and 1 (end of interval).
16 | """ 17 | a = 2 * dt * (f1 - f0) - 8 * (y1 + y0) + 16 * y_mid 18 | b = dt * (5 * f0 - 3 * f1) + 18 * y0 + 14 * y1 - 32 * y_mid 19 | c = dt * (f1 - 4 * f0) - 11 * y0 - 5 * y1 + 16 * y_mid 20 | d = dt * f0 21 | e = y0 22 | return [e, d, c, b, a] 23 | 24 | 25 | def _interp_evaluate(coefficients, t0, t1, t): 26 | """Evaluate polynomial interpolation at the given time point. 27 | 28 | Args: 29 | coefficients: list of Tensor coefficients as created by `_interp_fit`. 30 | t0: scalar float64 Tensor giving the start of the interval. 31 | t1: scalar float64 Tensor giving the end of the interval. 32 | t: scalar float64 Tensor giving the desired interpolation point. 33 | 34 | Returns: 35 | Value of the interpolating polynomial at time `t`. 36 | """ 37 | 38 | assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1) 39 | x = (t - t0) / (t1 - t0) 40 | x = x.to(coefficients[0].dtype) 41 | 42 | total = coefficients[0] + x * coefficients[1] 43 | x_power = x 44 | for coefficient in coefficients[2:]: 45 | x_power = x_power * x 46 | total = total + x_power * coefficient 47 | 48 | return total 49 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/misc.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import math 3 | import numpy as np 4 | import torch 5 | import warnings 6 | from .event_handling import combine_event_functions 7 | 8 | 9 | _all_callback_names = ['callback_step', 'callback_accept_step', 'callback_reject_step'] 10 | _all_adjoint_callback_names = [name + '_adjoint' for name in _all_callback_names] 11 | _null_callback = lambda *args, **kwargs: None 12 | 13 | def _handle_unused_kwargs(solver, unused_kwargs): 14 | if len(unused_kwargs) > 0: 15 | warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs)) 16 | 17 | 18 | def _linf_norm(tensor): 19 | return
tensor.abs().max() 20 | 21 | 22 | def _rms_norm(tensor): 23 | return tensor.abs().pow(2).mean().sqrt() 24 | 25 | 26 | def _zero_norm(tensor): 27 | return 0. 28 | 29 | 30 | def _mixed_norm(tensor_tuple): 31 | if len(tensor_tuple) == 0: 32 | return 0. 33 | return max([_rms_norm(tensor) for tensor in tensor_tuple]) 34 | 35 | 36 | def _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0=None): 37 | """Empirically select a good initial step. 38 | 39 | The algorithm is described in [1]_. 40 | 41 | References 42 | ---------- 43 | .. [1] E. Hairer, S. P. Norsett, G. Wanner, "Solving Ordinary Differential 44 | Equations I: Nonstiff Problems", Sec. II.4, 2nd edition. 45 | """ 46 | 47 | dtype = y0.dtype 48 | device = y0.device 49 | t_dtype = t0.dtype 50 | t0 = t0.to(t_dtype) 51 | 52 | if f0 is None: 53 | f0 = func(t0, y0) 54 | 55 | scale = atol + torch.abs(y0) * rtol 56 | 57 | d0 = norm(y0 / scale).abs() 58 | d1 = norm(f0 / scale).abs() 59 | 60 | if d0 < 1e-5 or d1 < 1e-5: 61 | h0 = torch.tensor(1e-6, dtype=dtype, device=device) 62 | else: 63 | h0 = 0.01 * d0 / d1 64 | h0 = h0.abs() 65 | 66 | y1 = y0 + h0 * f0 67 | f1 = func(t0 + h0, y1) 68 | 69 | d2 = torch.abs(norm((f1 - f0) / scale) / h0) 70 | 71 | if d1 <= 1e-15 and d2 <= 1e-15: 72 | h1 = torch.max(torch.tensor(1e-6, dtype=dtype, device=device), h0 * 1e-3) 73 | else: 74 | h1 = (0.01 / max(d1, d2)) ** (1.
/ float(order + 1)) 75 | h1 = h1.abs() 76 | 77 | return torch.min(100 * h0, h1).to(t_dtype) 78 | 79 | 80 | def _compute_error_ratio(error_estimate, rtol, atol, y0, y1, norm): 81 | error_tol = atol + rtol * torch.max(y0.abs(), y1.abs()) 82 | return norm(error_estimate / error_tol).abs() 83 | 84 | 85 | @torch.no_grad() 86 | def _optimal_step_size(last_step, error_ratio, safety, ifactor, dfactor, order): 87 | """Calculate the optimal size for the next step.""" 88 | if error_ratio == 0: 89 | return last_step * ifactor 90 | if error_ratio < 1: 91 | dfactor = torch.ones((), dtype=last_step.dtype, device=last_step.device) 92 | error_ratio = error_ratio.type_as(last_step) 93 | exponent = torch.tensor(order, dtype=last_step.dtype, device=last_step.device).reciprocal() 94 | factor = torch.min(ifactor, torch.max(safety / error_ratio ** exponent, dfactor)) 95 | return last_step * factor 96 | 97 | 98 | def _decreasing(t): 99 | return (t[1:] < t[:-1]).all() 100 | 101 | 102 | def _assert_one_dimensional(name, t): 103 | assert t.ndimension() == 1, "{} must be one dimensional".format(name) 104 | 105 | 106 | def _assert_increasing(name, t): 107 | assert (t[1:] > t[:-1]).all(), '{} must be strictly increasing or decreasing'.format(name) 108 | 109 | 110 | def _assert_floating(name, t): 111 | if not torch.is_floating_point(t): 112 | raise TypeError('`{}` must be a floating point Tensor but is a {}'.format(name, t.type())) 113 | 114 | 115 | def _tuple_tol(name, tol, shapes): 116 | try: 117 | iter(tol) 118 | except TypeError: 119 | return tol 120 | tol = tuple(tol) 121 | assert len(tol) == len(shapes), "If using tupled {} it must have the same length as the tuple y0".format(name) 122 | tol = [torch.as_tensor(tol_).expand(shape.numel()) for tol_, shape in zip(tol, shapes)] 123 | return torch.cat(tol) 124 | 125 | 126 | def _flat_to_shape(tensor, length, shapes): 127 | tensor_list = [] 128 | total = 0 129 | for shape in shapes: 130 | next_total = total + shape.numel() 131 | # It's important 
that this be view((...)), not view(...). Else when length=(), shape=() it fails. 132 | tensor_list.append(tensor[..., total:next_total].view((*length, *shape))) 133 | total = next_total 134 | return tuple(tensor_list) 135 | 136 | 137 | class _TupleFunc(torch.nn.Module): 138 | def __init__(self, base_func, shapes): 139 | super(_TupleFunc, self).__init__() 140 | self.base_func = base_func 141 | self.shapes = shapes 142 | 143 | def forward(self, t, y): 144 | f = self.base_func(t, _flat_to_shape(y, (), self.shapes)) 145 | return torch.cat([f_.reshape(-1) for f_ in f]) 146 | 147 | 148 | class _TupleInputOnlyFunc(torch.nn.Module): 149 | def __init__(self, base_func, shapes): 150 | super(_TupleInputOnlyFunc, self).__init__() 151 | self.base_func = base_func 152 | self.shapes = shapes 153 | 154 | def forward(self, t, y): 155 | return self.base_func(t, _flat_to_shape(y, (), self.shapes)) 156 | 157 | 158 | class _ReverseFunc(torch.nn.Module): 159 | def __init__(self, base_func, mul=1.0): 160 | super(_ReverseFunc, self).__init__() 161 | self.base_func = base_func 162 | self.mul = mul 163 | 164 | def forward(self, t, y): 165 | return self.mul * self.base_func(-t, y) 166 | 167 | 168 | class Perturb(Enum): 169 | NONE = 0 170 | PREV = 1 171 | NEXT = 2 172 | 173 | 174 | class _PerturbFunc(torch.nn.Module): 175 | 176 | def __init__(self, base_func): 177 | super(_PerturbFunc, self).__init__() 178 | self.base_func = base_func 179 | 180 | def forward(self, t, y, *, perturb=Perturb.NONE): 181 | assert isinstance(perturb, Perturb), "perturb argument must be of type Perturb enum" 182 | # This dtype change here might be buggy. 183 | # The exact time value should be determined inside the solver, 184 | # but this can slightly change it due to numerical differences during casting. 185 | if torch.is_complex(t): 186 | t = t.real 187 | t = t.to(y.abs().dtype) 188 | if perturb is Perturb.NEXT: 189 | # Replace with next smallest representable value. 
190 | t = _nextafter(t, t + 1) 191 | elif perturb is Perturb.PREV: 192 | # Replace with prev largest representable value. 193 | t = _nextafter(t, t - 1) 194 | else: 195 | # Do nothing. 196 | pass 197 | return self.base_func(t, y) 198 | 199 | 200 | def _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS): 201 | 202 | if event_fn is not None: 203 | if len(t) != 2: 204 | raise ValueError(f"We require len(t) == 2 when in event handling mode, but got len(t)={len(t)}.") 205 | 206 | # Combine event functions if the output is multivariate. 207 | event_fn = combine_event_functions(event_fn, t[0], y0) 208 | 209 | # Keep reference to original func as passed in 210 | original_func = func 211 | 212 | # Normalise to tensor (non-tupled) input 213 | shapes = None 214 | is_tuple = not isinstance(y0, torch.Tensor) 215 | if is_tuple: 216 | assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple' 217 | shapes = [y0_.shape for y0_ in y0] 218 | rtol = _tuple_tol('rtol', rtol, shapes) 219 | atol = _tuple_tol('atol', atol, shapes) 220 | y0 = torch.cat([y0_.reshape(-1) for y0_ in y0]) 221 | func = _TupleFunc(func, shapes) 222 | if event_fn is not None: 223 | event_fn = _TupleInputOnlyFunc(event_fn, shapes) 224 | 225 | # Normalise method and options 226 | if options is None: 227 | options = {} 228 | else: 229 | options = options.copy() 230 | if method is None: 231 | method = 'dopri5' 232 | if method not in SOLVERS: 233 | raise ValueError('Invalid method "{}". Must be one of {}'.format(method, 234 | '{"' + '", "'.join(SOLVERS.keys()) + '"}.')) 235 | 236 | if is_tuple: 237 | # We accept tupled input. This is an abstraction that is hidden from the rest of odeint (except when 238 | # returning values), so here we need to maintain the abstraction by wrapping norm functions. 239 | 240 | if 'norm' in options: 241 | # If the user passed a norm then get that...
242 | norm = options['norm'] 243 | else: 244 | # ...otherwise we default to a mixed norm over tupled input: the max (Linf) of the per-tensor RMS norms. 245 | norm = _mixed_norm 246 | 247 | # In either case, norm(...) is assumed to take a tuple of tensors as input. (As that's what the state looks 248 | # like from the point of view of the user.) 249 | # So here we take the tensor that the machinery of odeint has given us, and turn it into the tuple that the 250 | # norm function is expecting. 251 | def _norm(tensor): 252 | y = _flat_to_shape(tensor, (), shapes) 253 | return norm(y) 254 | options['norm'] = _norm 255 | 256 | else: 257 | if 'norm' in options: 258 | # No need to change the norm function. 259 | pass 260 | else: 261 | # Else just use the default norm. 262 | # Technically we don't need to set that here (RKAdaptiveStepsizeODESolver has it as a default), but it 263 | # makes it easier to reason about, in the adjoint norm logic, if we know that options['norm'] is 264 | # definitely set to something. 265 | options['norm'] = _rms_norm 266 | 267 | # Normalise time 268 | _check_timelike('t', t, True) 269 | t_is_reversed = False 270 | if len(t) > 1 and t[0] > t[1]: 271 | t_is_reversed = True 272 | 273 | if t_is_reversed: 274 | # Change the integration times to ascending order. 275 | # We do this by negating the time values and all associated arguments. 276 | t = -t 277 | 278 | # Ensure time values are un-negated when calling functions. 279 | func = _ReverseFunc(func, mul=-1.0) 280 | if event_fn is not None: 281 | event_fn = _ReverseFunc(event_fn) 282 | 283 | # For fixed step solvers. 284 | try: 285 | _grid_constructor = options['grid_constructor'] 286 | except KeyError: 287 | pass 288 | else: 289 | options['grid_constructor'] = lambda func, y0, t: -_grid_constructor(func, y0, -t) 290 | 291 | # For RK solvers.
292 | _flip_option(options, 'step_t') 293 | _flip_option(options, 'jump_t') 294 | 295 | # Can only do after having normalised time 296 | _assert_increasing('t', t) 297 | 298 | # Tol checking 299 | if torch.is_tensor(rtol): 300 | assert not rtol.requires_grad, "rtol cannot require gradient" 301 | if torch.is_tensor(atol): 302 | assert not atol.requires_grad, "atol cannot require gradient" 303 | 304 | # Backward compatibility: Allow t and y0 to be on different devices 305 | if t.device != y0.device: 306 | warnings.warn("t is not on the same device as y0. Coercing to y0.device.") 307 | t = t.to(y0.device) 308 | # ~Backward compatibility 309 | 310 | # Add perturb argument to func. 311 | func = _PerturbFunc(func) 312 | 313 | # Add callbacks to wrapped_func 314 | callback_names = set() 315 | for callback_name in _all_callback_names: 316 | try: 317 | callback = getattr(original_func, callback_name) 318 | except AttributeError: 319 | setattr(func, callback_name, _null_callback) 320 | else: 321 | if callback is not _null_callback: 322 | callback_names.add(callback_name) 323 | # At the moment all callbacks have the arguments (t0, y0, dt). 324 | # These will need adjusting on a per-callback basis if that changes in the future. 
325 | if is_tuple: 326 | def callback(t0, y0, dt, _callback=callback): 327 | y0 = _flat_to_shape(y0, (), shapes) 328 | return _callback(t0, y0, dt) 329 | if t_is_reversed: 330 | def callback(t0, y0, dt, _callback=callback): 331 | return _callback(-t0, y0, dt) 332 | setattr(func, callback_name, callback) 333 | for callback_name in _all_adjoint_callback_names: 334 | try: 335 | callback = getattr(original_func, callback_name) 336 | except AttributeError: 337 | pass 338 | else: 339 | setattr(func, callback_name, callback) 340 | 341 | invalid_callbacks = callback_names - SOLVERS[method].valid_callbacks() 342 | if len(invalid_callbacks) > 0: 343 | warnings.warn("Solver '{}' does not support callbacks {}".format(method, invalid_callbacks)) 344 | 345 | return shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed 346 | 347 | 348 | class _StitchGradient(torch.autograd.Function): 349 | @staticmethod 350 | def forward(ctx, x1, out): 351 | return out 352 | 353 | @staticmethod 354 | def backward(ctx, grad_out): 355 | return grad_out, None 356 | 357 | 358 | def _nextafter(x1, x2): 359 | with torch.no_grad(): 360 | if hasattr(torch, "nextafter"): 361 | out = torch.nextafter(x1, x2) 362 | else: 363 | out = np_nextafter(x1, x2) 364 | return _StitchGradient.apply(x1, out) 365 | 366 | 367 | def np_nextafter(x1, x2): 368 | warnings.warn("torch.nextafter is only available in PyTorch 1.7 or newer. " 369 | "Falling back to numpy.nextafter.
Upgrade PyTorch to remove this warning.") 370 | x1_np = x1.detach().cpu().numpy() 371 | x2_np = x2.detach().cpu().numpy() 372 | out = torch.tensor(np.nextafter(x1_np, x2_np)).to(x1) 373 | return out 374 | 375 | 376 | def _check_timelike(name, timelike, can_grad): 377 | assert isinstance(timelike, torch.Tensor), '{} must be a torch.Tensor'.format(name) 378 | _assert_floating(name, timelike) 379 | assert timelike.ndimension() == 1, "{} must be one dimensional".format(name) 380 | if not can_grad: 381 | assert not timelike.requires_grad, "{} cannot require gradient".format(name) 382 | diff = timelike[1:] > timelike[:-1] 383 | assert diff.all() or (~diff).all(), '{} must be strictly increasing or decreasing'.format(name) 384 | 385 | 386 | def _flip_option(options, option_name): 387 | try: 388 | option_value = options[option_name] 389 | except KeyError: 390 | pass 391 | else: 392 | if isinstance(option_value, torch.Tensor): 393 | options[option_name] = -option_value 394 | # else: an error will be raised when the option is attempted to be used in Solver.__init__, but we defer raising 395 | # the error until then to keep things tidy. 
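The step-size update in `_optimal_step_size` above can be traced with plain floats. The following is a minimal, dependency-free sketch that mirrors that logic; the name `optimal_step_size_sketch` and the defaults `safety=0.9`, `ifactor=10.0`, `dfactor=0.2`, `order=5` are hypothetical values chosen for illustration, not values read from this file. It shows the two rules the tensor version encodes: a step that would be accepted (`error_ratio < 1`) is never shrunk, and the safety-scaled factor is clamped to `[dfactor, ifactor]`.

```python
def optimal_step_size_sketch(last_step, error_ratio, safety=0.9, ifactor=10.0,
                             dfactor=0.2, order=5):
    """Float-only sketch of the controller in misc._optimal_step_size.

    The default parameter values are illustrative assumptions.
    """
    if error_ratio == 0:
        # No measurable error: grow by the maximum allowed factor.
        return last_step * ifactor
    if error_ratio < 1:
        # The step would be accepted, so the next step may not shrink.
        dfactor = 1.0
    exponent = 1.0 / order
    # Safety-scaled update, clamped to the interval [dfactor, ifactor].
    factor = min(ifactor, max(safety / error_ratio ** exponent, dfactor))
    return last_step * factor


if __name__ == "__main__":
    for ratio in (0.0, 1e-10, 0.5, 0.99, 1.0, 1e6):
        print(ratio, optimal_step_size_sketch(0.1, ratio))
```

With these assumed defaults, a ratio of `1e-10` would suggest a factor of 90 but is capped at `ifactor`, while a ratio of `1e6` is floored at `dfactor`; a ratio just below 1 leaves the step unchanged because of the "never shrink an accepted step" rule.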
396 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/odeint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.functional import vjp 3 | from .dopri5 import Dopri5Solver 4 | from .bosh3 import Bosh3Solver 5 | from .adaptive_heun import AdaptiveHeunSolver 6 | from .fehlberg2 import Fehlberg2 7 | from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4 8 | from .fixed_grid_implicit import ImplicitEuler, ImplicitMidpoint, Trapezoid 9 | from .fixed_grid_implicit import GaussLegendre4, GaussLegendre6 10 | from .fixed_grid_implicit import RadauIIA3, RadauIIA5 11 | from .fixed_grid_implicit import SDIRK2, TRBDF2 12 | from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton 13 | from .dopri8 import Dopri8Solver 14 | from .tsit5 import Tsit5Solver 15 | from .scipy_wrapper import ScipyWrapperODESolver 16 | from .misc import _check_inputs, _flat_to_shape 17 | from .interp import _interp_evaluate 18 | 19 | SOLVERS = { 20 | 'dopri8': Dopri8Solver, 21 | 'dopri5': Dopri5Solver, 22 | 'tsit5': Tsit5Solver, 23 | 'bosh3': Bosh3Solver, 24 | 'fehlberg2': Fehlberg2, 25 | 'adaptive_heun': AdaptiveHeunSolver, 26 | 'euler': Euler, 27 | 'midpoint': Midpoint, 28 | 'heun2': Heun2, 29 | 'heun3': Heun3, 30 | 'rk4': RK4, 31 | 'explicit_adams': AdamsBashforth, 32 | 'implicit_adams': AdamsBashforthMoulton, 33 | 'implicit_euler': ImplicitEuler, 34 | 'implicit_midpoint': ImplicitMidpoint, 35 | 'trapezoid': Trapezoid, 36 | 'radauIIA3': RadauIIA3, 37 | 'gl4': GaussLegendre4, 38 | 'radauIIA5': RadauIIA5, 39 | 'gl6': GaussLegendre6, 40 | 'sdirk2': SDIRK2, 41 | 'trbdf2': TRBDF2, 42 | # Backward compatibility: use the same name as before 43 | 'fixed_adams': AdamsBashforthMoulton, 44 | # ~Backwards compatibility 45 | 'scipy_solver': ScipyWrapperODESolver, 46 | } 47 | 48 | 49 | def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None): 50 | 
"""Integrate a system of ordinary differential equations. 51 | 52 | Solves the initial value problem for a non-stiff system of first order ODEs: 53 | ``` 54 | dy/dt = func(t, y), y(t[0]) = y0 55 | ``` 56 | where y is a Tensor or tuple of Tensors of any shape. 57 | 58 | Output dtypes and numerical precision are based on the dtype(s) of the input `y0`. 59 | 60 | Args: 61 | func: Function that maps a scalar Tensor `t` and a Tensor holding the state `y` 62 | into a Tensor of state derivatives with respect to time. Optionally, `y` 63 | can also be a tuple of Tensors. 64 | y0: N-D Tensor giving the starting value of `y` at time point `t[0]`. Optionally, `y0` 65 | can also be a tuple of Tensors. 66 | t: 1-D Tensor holding a sequence of time points for which to solve for 67 | `y`, in either increasing or decreasing order. The first element of 68 | this sequence is taken to be the initial time point. 69 | rtol: optional float64 Tensor specifying an upper bound on relative error, 70 | per element of `y`. 71 | atol: optional float64 Tensor specifying an upper bound on absolute error, 72 | per element of `y`. 73 | method: optional string indicating the integration method to use. 74 | options: optional dict of configuration options for the indicated integration 75 | method. Can only be provided if a `method` is explicitly set. 76 | event_fn: Function that maps a scalar Tensor `t` and the state `y` to a Tensor. The solve terminates when 77 | event_fn evaluates to zero. If this is not None, `t` must contain exactly two elements: the first is the 78 | initial time point and the second only sets the direction of integration. 79 | 80 | Returns: 81 | y: Tensor (or tuple of Tensors, if `y0` is a tuple), where the first dimension corresponds to different 82 | time points. Contains the solved value of y for each desired time point in 83 | `t`, with the initial value `y0` being the first element along the first 84 | dimension. If `event_fn` is not None, a tuple `(event_t, y)` is returned instead, where `event_t` is the event time and `y` holds the states at `t[0]` and `event_t`. 85 | 86 | Raises: 87 | ValueError: if an invalid `method` is provided.
88 | """ 89 | 90 | shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS) 91 | 92 | solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options) 93 | 94 | if event_fn is None: 95 | solution = solver.integrate(t) 96 | else: 97 | event_t, solution = solver.integrate_until_event(t[0], event_fn) 98 | event_t = event_t.to(t) 99 | if t_is_reversed: 100 | event_t = -event_t 101 | 102 | if shapes is not None: 103 | solution = _flat_to_shape(solution, (len(t),), shapes) 104 | 105 | if event_fn is None: 106 | return solution 107 | else: 108 | return event_t, solution 109 | 110 | 111 | def odeint_dense(func, y0, t0, t1, *, rtol=1e-7, atol=1e-9, method=None, options=None): 112 | 113 | assert torch.is_tensor(y0) # TODO: handle tuple of tensors 114 | 115 | t = torch.tensor([t0, t1]).to(t0) 116 | 117 | shapes, func, y0, t, rtol, atol, method, options, _, _ = _check_inputs(func, y0, t, rtol, atol, method, options, None, SOLVERS) 118 | 119 | assert method == "dopri5" 120 | 121 | solver = Dopri5Solver(func=func, y0=y0, rtol=rtol, atol=atol, **options) 122 | 123 | # The integration loop 124 | solution = torch.empty(len(t), *solver.y0.shape, dtype=solver.y0.dtype, device=solver.y0.device) 125 | solution[0] = solver.y0 126 | t = t.to(solver.dtype) 127 | solver._before_integrate(t) 128 | t0 = solver.rk_state.t0 129 | 130 | times = [t0] 131 | interp_coeffs = [] 132 | 133 | for i in range(1, len(t)): 134 | next_t = t[i] 135 | while next_t > solver.rk_state.t1: 136 | solver.rk_state = solver._adaptive_step(solver.rk_state) 137 | t1 = solver.rk_state.t1 138 | 139 | if t1 != t0: 140 | # Step accepted. 
141 | t0 = t1 142 | times.append(t1) 143 | interp_coeffs.append(torch.stack(solver.rk_state.interp_coeff)) 144 | 145 | solution[i] = _interp_evaluate(solver.rk_state.interp_coeff, solver.rk_state.t0, solver.rk_state.t1, next_t) 146 | 147 | times = torch.stack(times).reshape(-1).cpu() 148 | interp_coeffs = torch.stack(interp_coeffs) 149 | 150 | def dense_output_fn(t_eval): 151 | idx = torch.searchsorted(times, t_eval, side="right").clamp(1, len(times) - 1) # Clamp so t_eval == times[-1] uses the last interval. 152 | t0 = times[idx - 1] 153 | t1 = times[idx] 154 | coef = [interp_coeffs[idx - 1][i] for i in range(interp_coeffs.shape[1])] 155 | return _interp_evaluate(coef, t0, t1, t_eval) 156 | 157 | return dense_output_fn 158 | 159 | 160 | def odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs): 161 | """Automatically links up the gradient from the event time.""" 162 | 163 | if reverse_time: 164 | t = torch.cat([t0.reshape(-1), t0.reshape(-1).detach() - 1.0]) 165 | else: 166 | t = torch.cat([t0.reshape(-1), t0.reshape(-1).detach() + 1.0]) 167 | 168 | event_t, solution = odeint_interface(func, y0, t, event_fn=event_fn, **kwargs) 169 | 170 | # Dummy values for rtol, atol, method, and options. 171 | shapes, _func, _, t, _, _, _, _, event_fn, _ = _check_inputs(func, y0, t, 0.0, 0.0, None, None, event_fn, SOLVERS) 172 | 173 | if shapes is not None: 174 | state_t = torch.cat([s[-1].reshape(-1) for s in solution]) 175 | else: 176 | state_t = solution[-1] 177 | 178 | # Event_fn takes in negated time value if reverse_time is True. 179 | if reverse_time: 180 | event_t = -event_t 181 | 182 | event_t, state_t = ImplicitFnGradientRerouting.apply(_func, event_fn, event_t, state_t) 183 | 184 | # Return the user expected time value.
185 | if reverse_time: 186 | event_t = -event_t 187 | 188 | if shapes is not None: 189 | state_t = _flat_to_shape(state_t, (), shapes) 190 | solution = tuple(torch.cat([s[:-1], s_t[None]], dim=0) for s, s_t in zip(solution, state_t)) 191 | else: 192 | solution = torch.cat([solution[:-1], state_t[None]], dim=0) 193 | 194 | return event_t, solution 195 | 196 | 197 | class ImplicitFnGradientRerouting(torch.autograd.Function): 198 | 199 | @staticmethod 200 | def forward(ctx, func, event_fn, event_t, state_t): 201 | """ event_t is the solution to event_fn """ 202 | ctx.func = func 203 | ctx.event_fn = event_fn 204 | ctx.save_for_backward(event_t, state_t) 205 | return event_t.detach(), state_t.detach() 206 | 207 | @staticmethod 208 | def backward(ctx, grad_t, grad_state): 209 | func = ctx.func 210 | event_fn = ctx.event_fn 211 | event_t, state_t = ctx.saved_tensors 212 | 213 | event_t = event_t.detach().clone().requires_grad_(True) 214 | state_t = state_t.detach().clone().requires_grad_(True) 215 | 216 | f_val = func(event_t, state_t) 217 | 218 | with torch.enable_grad(): 219 | c, (par_dt, dstate) = vjp(event_fn, (event_t, state_t)) 220 | 221 | # Total derivative of event_fn wrt t evaluated at event_t. 222 | dcdt = par_dt + torch.sum(dstate * f_val) 223 | 224 | # Add the gradient from final state to final time value as if a regular odeint was called. 
225 | grad_t = grad_t + torch.sum(grad_state * f_val) 226 | 227 | dstate = dstate * (-grad_t / (dcdt + 1e-12)).reshape_as(c) 228 | 229 | grad_state = grad_state + dstate 230 | 231 | return None, None, None, grad_state 232 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/scipy_wrapper.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | from scipy.integrate import solve_ivp 4 | from .misc import _handle_unused_kwargs 5 | 6 | 7 | class ScipyWrapperODESolver(metaclass=abc.ABCMeta): 8 | 9 | def __init__(self, func, y0, rtol, atol, min_step=0, max_step=float('inf'), solver="LSODA", **unused_kwargs): 10 | unused_kwargs.pop('norm', None) 11 | unused_kwargs.pop('grid_points', None) 12 | unused_kwargs.pop('eps', None) 13 | _handle_unused_kwargs(self, unused_kwargs) 14 | del unused_kwargs 15 | 16 | self.dtype = y0.dtype 17 | self.device = y0.device 18 | self.shape = y0.shape 19 | self.y0 = y0.detach().cpu().numpy().reshape(-1) 20 | self.rtol = rtol 21 | self.atol = atol 22 | self.min_step = min_step 23 | self.max_step = max_step 24 | self.solver = solver 25 | self.func = convert_func_to_numpy(func, self.shape, self.device, self.dtype) 26 | 27 | def integrate(self, t): 28 | if t.numel() == 1: 29 | return torch.tensor(self.y0).reshape(1, *self.shape).to(self.device, self.dtype) 30 | t = t.detach().cpu().numpy() 31 | sol = solve_ivp( 32 | self.func, 33 | t_span=[t.min(), t.max()], 34 | y0=self.y0, 35 | t_eval=t, 36 | method=self.solver, 37 | rtol=self.rtol, 38 | atol=self.atol, 39 | min_step=self.min_step, 40 | max_step=self.max_step 41 | ) 42 | sol = torch.tensor(sol.y).T.to(self.device, self.dtype) 43 | sol = sol.reshape(-1, *self.shape) 44 | return sol 45 | 46 | @classmethod 47 | def valid_callbacks(cls): 48 | return set() 49 | 50 | 51 | def convert_func_to_numpy(func, shape, device, dtype): 52 | 53 | def np_func(t, y): 54 | t = torch.tensor(t).to(device, dtype) 55 | y =
torch.reshape(torch.tensor(y).to(device, dtype), shape) 56 | with torch.no_grad(): 57 | f = func(t, y) 58 | return f.detach().cpu().numpy().reshape(-1) 59 | 60 | return np_func 61 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/solvers.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | from .event_handling import find_event 4 | from .misc import _handle_unused_kwargs 5 | 6 | 7 | class AdaptiveStepsizeODESolver(metaclass=abc.ABCMeta): 8 | def __init__(self, dtype, y0, norm, **unused_kwargs): 9 | _handle_unused_kwargs(self, unused_kwargs) 10 | del unused_kwargs 11 | 12 | self.y0 = y0 13 | self.dtype = dtype 14 | 15 | self.norm = norm 16 | 17 | def _before_integrate(self, t): 18 | pass 19 | 20 | @abc.abstractmethod 21 | def _advance(self, next_t): 22 | raise NotImplementedError 23 | 24 | @classmethod 25 | def valid_callbacks(cls): 26 | return set() 27 | 28 | def integrate(self, t): 29 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) 30 | solution[0] = self.y0 31 | t = t.to(self.dtype) 32 | self._before_integrate(t) 33 | for i in range(1, len(t)): 34 | solution[i] = self._advance(t[i]) 35 | return solution 36 | 37 | 38 | class AdaptiveStepsizeEventODESolver(AdaptiveStepsizeODESolver, metaclass=abc.ABCMeta): 39 | 40 | @abc.abstractmethod 41 | def _advance_until_event(self, event_fn): 42 | raise NotImplementedError 43 | 44 | def integrate_until_event(self, t0, event_fn): 45 | t0 = t0.to(self.y0.device, self.dtype) 46 | self._before_integrate(t0.reshape(-1)) 47 | event_time, y1 = self._advance_until_event(event_fn) 48 | solution = torch.stack([self.y0, y1], dim=0) 49 | return event_time, solution 50 | 51 | 52 | class FixedGridODESolver(metaclass=abc.ABCMeta): 53 | order: int 54 | 55 | def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, **unused_kwargs): 56 | 
self.atol = unused_kwargs.pop('atol') 57 | unused_kwargs.pop('rtol', None) 58 | unused_kwargs.pop('norm', None) 59 | _handle_unused_kwargs(self, unused_kwargs) 60 | del unused_kwargs 61 | 62 | self.func = func 63 | self.y0 = y0 64 | self.dtype = y0.dtype 65 | self.device = y0.device 66 | self.step_size = step_size 67 | self.interp = interp 68 | self.perturb = perturb 69 | 70 | if step_size is None: 71 | if grid_constructor is None: 72 | self.grid_constructor = lambda f, y0, t: t 73 | else: 74 | self.grid_constructor = grid_constructor 75 | else: 76 | if grid_constructor is None: 77 | self.grid_constructor = self._grid_constructor_from_step_size(step_size) 78 | else: 79 | raise ValueError("step_size and grid_constructor are mutually exclusive arguments.") 80 | 81 | @classmethod 82 | def valid_callbacks(cls): 83 | return {'callback_step'} 84 | 85 | @staticmethod 86 | def _grid_constructor_from_step_size(step_size): 87 | def _grid_constructor(func, y0, t): 88 | start_time = t[0] 89 | end_time = t[-1] 90 | 91 | niters = torch.ceil((end_time - start_time) / step_size + 1).item() 92 | t_infer = torch.arange(0, niters, dtype=t.dtype, device=t.device) * step_size + start_time 93 | t_infer[-1] = t[-1] 94 | 95 | return t_infer 96 | return _grid_constructor 97 | 98 | @abc.abstractmethod 99 | def _step_func(self, func, t0, dt, t1, y0): 100 | pass 101 | 102 | def integrate(self, t): 103 | time_grid = self.grid_constructor(self.func, self.y0, t) 104 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1] 105 | 106 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) 107 | solution[0] = self.y0 108 | 109 | j = 1 110 | y0 = self.y0 111 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]): 112 | dt = t1 - t0 113 | self.func.callback_step(t0, y0, dt) 114 | dy, f0 = self._step_func(self.func, t0, dt, t1, y0) 115 | y1 = y0 + dy 116 | 117 | while j < len(t) and t1 >= t[j]: 118 | if self.interp == "linear": 119 | solution[j] = 
self._linear_interp(t0, t1, y0, y1, t[j]) 120 | elif self.interp == "cubic": 121 | f1 = self.func(t1, y1) 122 | solution[j] = self._cubic_hermite_interp(t0, y0, f0, t1, y1, f1, t[j]) 123 | else: 124 | raise ValueError(f"Unknown interpolation method {self.interp}") 125 | j += 1 126 | y0 = y1 127 | 128 | return solution 129 | 130 | def integrate_until_event(self, t0, event_fn): 131 | assert self.step_size is not None, "Event handling for fixed step solvers currently requires `step_size` to be provided in options." 132 | 133 | t0 = t0.type_as(self.y0.abs()) 134 | y0 = self.y0 135 | dt = self.step_size 136 | 137 | sign0 = torch.sign(event_fn(t0, y0)) 138 | max_itrs = 20000 139 | itr = 0 140 | while True: 141 | itr += 1 142 | t1 = t0 + dt 143 | dy, f0 = self._step_func(self.func, t0, dt, t1, y0) 144 | y1 = y0 + dy 145 | 146 | sign1 = torch.sign(event_fn(t1, y1)) 147 | 148 | if sign0 != sign1: 149 | if self.interp == "linear": 150 | interp_fn = lambda t: self._linear_interp(t0, t1, y0, y1, t) 151 | elif self.interp == "cubic": 152 | f1 = self.func(t1, y1) 153 | interp_fn = lambda t: self._cubic_hermite_interp(t0, y0, f0, t1, y1, f1, t) 154 | else: 155 | raise ValueError(f"Unknown interpolation method {self.interp}") 156 | event_time, y1 = find_event(interp_fn, sign0, t0, t1, event_fn, float(self.atol)) 157 | break 158 | else: 159 | t0, y0 = t1, y1 160 | 161 | if itr >= max_itrs: 162 | raise RuntimeError(f"Reached maximum number of iterations {max_itrs}.") 163 | solution = torch.stack([self.y0, y1], dim=0) 164 | return event_time, solution 165 | 166 | def _cubic_hermite_interp(self, t0, y0, f0, t1, y1, f1, t): 167 | h = (t - t0) / (t1 - t0) 168 | h00 = (1 + 2 * h) * (1 - h) * (1 - h) 169 | h10 = h * (1 - h) * (1 - h) 170 | h01 = h * h * (3 - 2 * h) 171 | h11 = h * h * (h - 1) 172 | dt = (t1 - t0) 173 | return h00 * y0 + h10 * dt * f0 + h01 * y1 + h11 * dt * f1 174 | 175 | def _linear_interp(self, t0, t1, y0, y1, t): 176 | if t == t0: 177 | return y0 178 | if t == t1: 179 
| return y1 180 | slope = (t - t0) / (t1 - t0) 181 | return y0 + slope * (y1 - y0) 182 | -------------------------------------------------------------------------------- /torchdiffeq/_impl/tsit5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver 3 | # https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/lib/OrdinaryDiffEqTsit5/src/tsit_tableaus.jl 4 | # https://github.com/patrick-kidger/diffrax/blob/14baa1edddcacf27c0483962b3c9cf2e86e6e5b6/diffrax/_solver/tsit5.py#L158 5 | 6 | _TSITOURAS_TABLEAU = _ButcherTableau( 7 | alpha=torch.tensor([ 8 | 161 / 1000, 9 | 327 / 1000, 10 | 9 / 10, 11 | .9800255409045096857298102862870245954942137979563024768854764293221195950761080302604, 12 | 1, 13 | 1 14 | ], dtype=torch.float64), 15 | beta=[ 16 | torch.tensor([161 / 1000], dtype=torch.float64), 17 | torch.tensor([ 18 | -.8480655492356988544426874250230774675121177393430391537369234245294192976164141156943e-2, 19 | .3354806554923569885444268742502307746751211773934303915373692342452941929761641411569 20 | ], dtype=torch.float64), 21 | torch.tensor([ 22 | 2.897153057105493432130432594192938764924887287701866490314866693455023795137503079289, 23 | -6.359448489975074843148159912383825625952700647415626703305928850207288721235210244366, 24 | 4.362295432869581411017727318190886861027813359713760212991062156752264926097707165077, 25 | ], dtype=torch.float64), 26 | torch.tensor([ 27 | 5.325864828439256604428877920840511317836476253097040101202360397727981648835607691791, 28 | -11.74888356406282787774717033978577296188744178259862899288666928009020615663593781589, 29 | 7.495539342889836208304604784564358155658679161518186721010132816213648793440552049753, 30 | -.9249506636175524925650207933207191611349983406029535244034750452930469056411389539635e-1 31 | ], dtype=torch.float64), 32 | torch.tensor([ 33 | 
5.861455442946420028659251486982647890394337666164814434818157239052507339770711679748, 34 | -12.92096931784710929170611868178335939541780751955743459166312250439928519268343184452, 35 | 8.159367898576158643180400794539253485181918321135053305748355423955009222648673734986, 36 | -.7158497328140099722453054252582973869127213147363544882721139659546372402303777878835e-1, 37 | -.2826905039406838290900305721271224146717633626879770007617876201276764571291579142206e-1 38 | ], dtype=torch.float64), 39 | torch.tensor([ 40 | .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1, 41 | 1 / 100, 42 | .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509, 43 | 1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331, 44 | -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677, 45 | 2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841 46 | ], dtype=torch.float64), 47 | ], 48 | c_sol=torch.tensor([ 49 | .9468075576583945807478876255758922856117527357724631226139574065785592789071067303271e-1, 50 | .9183565540343253096776363936645313759813746240984095238905939532922955247253608687270e-2, 51 | .4877705284247615707855642599631228241516691959761363774365216240304071651579571959813, 52 | 1.234297566930478985655109673884237654035539930748192848315425833500484878378061439761, 53 | -2.707712349983525454881109975059321670689605166938197378763992255714444407154902012702, 54 | 1.866628418170587035753719399566211498666255505244122593996591602841258328965767580089, 55 | 1 / 66 56 | ], dtype=torch.float64), 57 | c_error=torch.tensor([ 58 | -1.780011052225771443378550607539534775944678804333659557637450799792588061629796e-03, 59 | -8.164344596567469032236360633546862401862537590159047610940604670770447527463931e-04, 60 | 7.880878010261996010314727672526304238628733777103128603258129604952959142646516e-03, 61 | 
-1.44711007173262907537165147972635116720922712343167677619514233896760819649515e-01, 62 | 5.823571654525552250199376106520421794260781239567387797673045438803694038950012e-01, 63 | -4.580821059291869466616365188325542974428047279788398179474684434732070620889539e-01, 64 | 1 / 66 65 | ], dtype=torch.float64), 66 | ) 67 | 68 | x = 1 / 2 69 | TSIT_C_MID = torch.tensor([ 70 | -1.0530884977290216*x*(x-1.329989018975412)*(x*x-1.4364028541716351*x+0.7139816917074209), 71 | 0.1017*x*x*(x*x-2.1966568338249754*x+1.2949852507374631), 72 | 2.490627285651252793*x*x*(x*x-2.38535645472061657*x+1.57803468208092486), 73 | -16.54810288924490272*(x-1.21712927295533244)*(x-0.61620406037800089)*x*x, 74 | 47.37952196281928122*(x-1.203071208372362603)*(x-0.658047292653547382)*x*x, 75 | -34.87065786149660974*(x-1.2)*(x-2/3)*x*x, 76 | 2.5*(x-1)*(x-0.6)*x*x 77 | ], dtype=torch.float64) 78 | 79 | class Tsit5Solver(RKAdaptiveStepsizeODESolver): 80 | order = 5 81 | tableau = _TSITOURAS_TABLEAU 82 | mid = TSIT_C_MID 83 | --------------------------------------------------------------------------------
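`TSIT_C_MID` above evaluates the seven Tsit5 dense-output weight polynomials at the midpoint `x = 1/2`. Keeping `x` as a free variable θ in [0, 1] gives the weights b_i(θ) themselves. The sketch below copies the seven expressions verbatim into a hypothetical plain-Python helper `tsit5_weights(theta)` (no torch), plus a hypothetical scalar `dense_output` that applies them in the standard continuous Runge-Kutta form y(t0 + θ·dt) ≈ y0 + dt·Σ b_i(θ)·k_i. How `rk_common` actually combines these weights with the stages is not shown in this section, so `dense_output` is an assumption rather than the library's code path. Two properties can be checked: the weights sum to θ, so a constant derivative is integrated exactly, and at θ = 1 they reproduce the final `beta` row (the seventh stage).

```python
def tsit5_weights(theta):
    """Tsit5 dense-output weights b_i(theta), i = 1..7, for theta in [0, 1].

    The expressions are copied from TSIT_C_MID with x replaced by theta.
    """
    x = theta
    return [
        -1.0530884977290216*x*(x-1.329989018975412)*(x*x-1.4364028541716351*x+0.7139816917074209),
        0.1017*x*x*(x*x-2.1966568338249754*x+1.2949852507374631),
        2.490627285651252793*x*x*(x*x-2.38535645472061657*x+1.57803468208092486),
        -16.54810288924490272*(x-1.21712927295533244)*(x-0.61620406037800089)*x*x,
        47.37952196281928122*(x-1.203071208372362603)*(x-0.658047292653547382)*x*x,
        -34.87065786149660974*(x-1.2)*(x-2/3)*x*x,
        2.5*(x-1)*(x-0.6)*x*x,
    ]


def dense_output(y0, dt, k, theta):
    """Scalar sketch of y(t0 + theta * dt) ~ y0 + dt * sum_i b_i(theta) * k_i.

    `k` is assumed to hold the seven stage derivatives of one accepted step.
    """
    return y0 + dt * sum(b * k_i for b, k_i in zip(tsit5_weights(theta), k))


if __name__ == "__main__":
    for theta in (0.0, 0.25, 0.5, 0.75, 1.0):
        print(theta, sum(tsit5_weights(theta)))
```

Evaluating `tsit5_weights(0.5)` gives the same seven numbers that `TSIT_C_MID` stores, so the tensor above can be read as a cached evaluation of these polynomials at the midpoint.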