├── .github └── workflows │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.rst ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── .gitignore │ ├── api.rst │ ├── bibliography.bib │ ├── conf.py │ ├── examples.rst │ ├── index.rst │ ├── quick-start.rst │ ├── references.rst │ └── user-guide.rst ├── dynax ├── __init__.py ├── custom_types.py ├── derivative.py ├── estimation.py ├── evolution.py ├── example_models.py ├── interpolation.py ├── linearize.py ├── structident.py ├── system.py └── util.py ├── examples ├── fit_initial_state.py ├── fit_long_input.py ├── fit_multiple_shooting.py ├── fit_nonlinear_ode.ipynb ├── linearize_ode.py └── linearize_recurrent_network.py ├── pyproject.toml └── tests ├── conftest.py ├── test_ad.py ├── test_estimation.py ├── test_evolution.py ├── test_examples.py ├── test_linearize.py └── test_systems.py /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: master 6 | pull_request: 7 | branches: master 8 | 9 | jobs: 10 | run-tests: 11 | strategy: 12 | matrix: 13 | python-version: [ "3.10", "3.11", "3.12" ] 14 | os: [ ubuntu-latest ] 15 | fail-fast: false 16 | runs-on: ${{ matrix.os }} 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | 26 | - name: Checks with pre-commit 27 | uses: pre-commit/action@v3.0.1 28 | 29 | - name: Test with pytest 30 | run: | 31 | python -m pip install .[dev] 32 | python -m pytest --runslow --durations=0 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.egg-info 3 | experiments 4 | 
.vscode 5 | .coverage 6 | htmlcov 7 | build 8 | _build 9 | docs/generated 10 | *.pytest_cache 11 | .pytype 12 | .ruff_cache 13 | .mypy_cache 14 | .ipynb_checkpoints 15 | .virtual_documents 16 | .helix 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: 'v0.5.3' 4 | hooks: 5 | - id: ruff 6 | args: [--fix] 7 | types_or: [ python, pyi, jupyter ] 8 | - id: ruff-format 9 | types_or: [ python, pyi, jupyter ] 10 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.12" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions (pull requests) are very welcome! Here's how to get started. 4 | 5 | --- 6 | 7 | **Getting started** 8 | 9 | First fork the library on GitHub. 10 | 11 | Then clone and install the library in development mode: 12 | 13 | ```bash 14 | git clone https://github.com/your-username-here/dynax.git 15 | cd dynax 16 | pip install -e .[dev] 17 | ``` 18 | 19 | Then install the pre-commit hook: 20 | 21 | ```bash 22 | pip install pre-commit 23 | pre-commit install 24 | ``` 25 | 26 | These hooks use Black and isort to format the code, and flake8 to lint it. 27 | 28 | --- 29 | 30 | **If you're making changes to the code:** 31 | 32 | Now make your changes. Make sure to include additional tests if necessary. 
33 | 34 | Next verify the tests all pass: 35 | 36 | ```bash 37 | pip install pytest 38 | pytest 39 | ``` 40 | 41 | Then push your changes back to your fork of the repository: 42 | 43 | ```bash 44 | git push 45 | ``` 46 | 47 | Finally, open a pull request on GitHub! 48 | 49 | --- 50 | 51 | **If you're making changes to the documentation:** 52 | 53 | Make your changes. You can then build the documentation by doing 54 | 55 | ```bash 56 | cd docs 57 | make livehtml 58 | ``` 59 | 60 | You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser. 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright © 2022 Technical University of Denmark 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Dynax 2 | ===== 3 | 4 | *"Dynamical systems in JAX"* 5 | 6 | |workflow_badge| |doc_badge| 7 | 8 | .. |workflow_badge| image:: https://github.com/fhchl/dynax/actions/workflows/run_tests.yml/badge.svg 9 | :target: https://github.com/fhchl/dynax/actions/workflows/run_tests.yml 10 | .. |doc_badge| image:: https://readthedocs.org/projects/dynax/badge/?version=latest 11 | :target: https://dynax.readthedocs.io/en/latest/?badge=latest 12 | 13 | **This is WIP. Expect things to break!** 14 | 15 | This package allows for straight-forward simulation, fitting and linearization of dynamical systems 16 | by combing `JAX`_, `Diffrax`_, `Equinox`_, and `scipy.optimize`_. Its main features 17 | include: 18 | 19 | - estimation of ODE parameters and their covariance via the prediction-error method (`example `_) 20 | - estimation of the initial state (`example `_) 21 | - estimation of linear ODE parameters via matching of frequency-response functions (`example `_) 22 | - estimation from multiple experiments 23 | - estimation with a poor man's multiple shooting (`example `_) 24 | - input-output linearization of continuous-time input affine systems (`example `_) 25 | - input-output linearization of discrete-time systems (`example `_) 26 | - estimation of a system's relative-degree (`example `_) 27 | 28 | Documentation is on its way. Until then, have a look at the `example `_ and `test `_ folders. 29 | 30 | 31 | Installing 32 | ---------- 33 | 34 | Requires Python 3.9+, JAX 0.4.23+, Equinox 0.11+ and Diffrax 0.5+. With a 35 | suitable version of jaxlib installed: 36 | 37 | :: 38 | 39 | pip install . 
40 | 41 | 42 | Testing 43 | ------- 44 | 45 | Install with 46 | 47 | :: 48 | 49 | pip install .[dev] 50 | 51 | and run 52 | 53 | :: 54 | 55 | pytest 56 | 57 | To also test the examples, do 58 | 59 | :: 60 | 61 | pytest --runslow 62 | 63 | 64 | Related software 65 | ---------------- 66 | 67 | - `nlgreyfast`_: Matlab library for fitting ODE's with mutliple shooting 68 | - `dynamax`_: inference and learning for probablistic state-space models 69 | 70 | .. _scipy.optimize: https://docs.scipy.org/doc/scipy/reference/optimize.html 71 | .. _dynamax: https://github.com/probml/dynamax 72 | .. _nlgreyfast: https://github.com/meco-group/nlgreyfast 73 | .. _jax: https://github.com/google/jax 74 | .. _diffrax: https://github.com/patrick-kidger/diffrax 75 | .. _equinox: https://github.com/patrick-kidger/equinox 76 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | WATCHDIR = "../dynax" 11 | 12 | # Put it first so that "make" without argument is like "make help". 13 | help: 14 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 15 | 16 | .PHONY: help Makefile 17 | 18 | # Catch-all target: route all unknown targets to Sphinx using the new 19 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
20 | %: Makefile 21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | 23 | livehtml: 24 | sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) --watch $(WATCHDIR) $(O) 25 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | nbsphinx 2 | sphinx-autobuild 3 | sphinx-autodoc-typehints 4 | sphinx-rtd-theme 5 | sphinx 6 | sphinxcontrib-bibtex 7 | sphinxcontrib-aafig 8 | furo 9 | ./ 10 | jaxlib 11 | 12 | -------------------------------------------------------------------------------- /docs/source/.gitignore: -------------------------------------------------------------------------------- 1 | generated/ 2 | -------------------------------------------------------------------------------- /docs/source/api.rst: 
-------------------------------------------------------------------------------- 1 | API documentation 2 | ================= 3 | 4 | .. autosummary:: 5 | :toctree: generated 6 | :recursive: 7 | 8 | dynax.system 9 | dynax.estimation 10 | dynax.evolution 11 | dynax.linearize 12 | dynax.derivative 13 | dynax.interpolation 14 | dynax.util 15 | -------------------------------------------------------------------------------- /docs/source/bibliography.bib: -------------------------------------------------------------------------------- 1 | @article{robenackComputationLieDerivatives2005, 2 | title = {Computation of {{Lie Derivatives}} of {{Tensor Fields Required}} for {{Nonlinear Controller}} and {{Observer Design Employing Automatic Differentiation}}}, 3 | author = {R\"obenack, Klaus}, 4 | year = {2005}, 5 | journal = {PAMM}, 6 | volume = {5}, 7 | number = {1}, 8 | pages = {181--184}, 9 | doi = {10.1002/pamm.200510069} 10 | } 11 | 12 | @book{leeLinearizationNonlinearControl2022, 13 | title = {Linearization of {{Nonlinear Control Systems}}}, 14 | author = {Lee, Hong-Gi}, 15 | year = {2022}, 16 | publisher = {{Springer Nature Singapore}}, 17 | doi = {10.1007/978-981-19-3643-2}, 18 | } 19 | 20 | @book{sastry2013nonlinear, 21 | title={Nonlinear systems: analysis, stability, and control}, 22 | author={Sastry, Shankar}, 23 | volume={10}, 24 | year={2013}, 25 | publisher={Springer Science \& Business Media} 26 | } 27 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | 4 | html_theme = "sphinx_rtd_theme" 5 | # html_static_path = ["_static"] 6 | 7 | project = "Dynax" 8 | copyright = "2023, Franz M. Heuchel" 9 | author = "Franz M. 
Heuchel" 10 | 11 | # -- General configuration --------------------------------------------------- 12 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 13 | 14 | extensions = [ 15 | "sphinx.ext.autodoc", 16 | "sphinx.ext.autosummary", 17 | "sphinx.ext.intersphinx", 18 | "sphinx.ext.todo", 19 | "sphinx.ext.viewcode", 20 | "sphinxcontrib.bibtex", 21 | "sphinxcontrib.aafig", 22 | "nbsphinx", 23 | "sphinx.ext.napoleon", 24 | # FIXME: sphinx_autodoc_typehints is not working together with autodoc_type_aliases, 25 | # see https://github.com/tox-dev/sphinx-autodoc-typehints/issues/284 26 | # For now, I will just use jaxtyping.ArrayLike in the docs. Sadly, that one does 27 | # not intersphinx-link to the docs. 28 | "sphinx_autodoc_typehints", 29 | ] 30 | 31 | bibtex_bibfiles = ["bibliography.bib"] 32 | templates_path = ["_templates"] 33 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 34 | 35 | autosummary_generate = True 36 | autodoc_member_order = "bysource" 37 | autodoc_default_options = { 38 | "members": True, 39 | "inherited-members": False, 40 | "show-inheritance": True, 41 | "special-members": "__call__", 42 | } 43 | autoclass_content = "both" 44 | autodoc_typehints = "signature" 45 | autodoc_preserve_defaults = True 46 | autodoc_type_aliases = { 47 | a: a 48 | for a in [ 49 | # "VectorFunc", 50 | # "ScalarFunc", 51 | "VectorField", 52 | "OutputFunc", 53 | # "ArrayLike" 54 | ] 55 | } 56 | # } | { 57 | # "jax.typing.ArrayLike": "jax.typing.ArrayLike", 58 | # "ArrayLike": "dynax.custom_types.ArrayLike", 59 | # "jax._src_basearray.ArrayLike": "ArrayLike", 60 | # } 61 | 62 | # TODO: I want to stop ArrayLike from exploiding, but above doesn't seem to work :/ 63 | 64 | # For sphinx_autodoc_typehints. 
65 | typehints_use_rtype = False 66 | always_use_bars_union = True 67 | 68 | napoleon_numpy_docstring = False 69 | napoleon_google_docstring = True 70 | napoleon_include_init_with_doc = False 71 | napoleon_include_special_with_doc = False 72 | napoleon_preprocess_types = True 73 | napoleon_attr_annotations = True 74 | napoleon_use_rtype = False 75 | 76 | # TODO: __init__ should not pop up in the docs 77 | # TODO: remove dynax.evolution... path from docs 78 | # TODO: intersphinx should pick up Module, Array and ArrayLike, but doesn't :( 79 | 80 | 81 | intersphinx_mapping = { 82 | "python": ("https://docs.python.org/3/", None), 83 | "numpy": ("https://numpy.org/doc/stable/", None), 84 | "scipy": ("https://docs.scipy.org/doc/scipy/", None), 85 | "jaxtyping": ("https://docs.kidger.site/jaxtyping/", None), 86 | "diffrax": ("https://docs.kidger.site/diffrax/", None), 87 | "equinox": ("https://docs.kidger.site/equinox/", None), 88 | "optimistix": ("https://docs.kidger.site/optimistix/", None), 89 | "lineax": ("https://docs.kidger.site/lineax/", None), 90 | "jax": ("https://jax.readthedocs.io/en/latest/", None), 91 | } 92 | 93 | 94 | # def autodoc_process_docstring(app, what, name, obj, options, lines): 95 | # for i in range(len(lines)): 96 | # if lines[i] 97 | # # # lines[i] = lines[i].replace("np.", "~numpy.") # For shorter links 98 | # # lines[i] = lines[i].replace("F.", "torch.nn.functional.") 99 | # # lines[i] = lines[i].replace("List[", "~typing.List[") 100 | 101 | 102 | # def setup(app): 103 | # app.connect("autodoc-process-docstring", autodoc_process_docstring) 104 | 105 | # Short type docs for jaxtyping's types 106 | # https://github.com/patrick-kidger/pytkdocs_tweaks/blob/2a7ce453e315f526d792f689e61d56ecaa4ab000/pytkdocs_tweaks/__init__.py#L283 107 | typing.GENERATING_DOCUMENTATION = True # pyright: ignore 108 | -------------------------------------------------------------------------------- /docs/source/examples.rst: 
-------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | Have a look at the notebooks on the left or the following scripts. 8 | 9 | 10 | .. _example_models: 11 | 12 | Declaring ODE systems in Dynax 13 | ------------------------------ 14 | 15 | .. literalinclude:: ../../dynax/example_models.py 16 | 17 | 18 | .. _example-fit-ode: 19 | 20 | Fit a system of ordinary differential equations 21 | ----------------------------------------------- 22 | 23 | .. literalinclude:: ../../examples/fit_ode.ipynb 24 | 25 | 26 | .. _example-fit-multiple-shooting: 27 | 28 | Fit a system with multiple shooting 29 | ----------------------------------- 30 | 31 | .. literalinclude:: ../../examples/fit_multiple_shooting.py 32 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Dynax documentation master file, created by 2 | sphinx-quickstart on Fri Jan 27 10:54:38 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Dynax's documentation! 7 | ================================= 8 | 9 | **Dynax** is a Python package for straight-forward simulation, fitting and 10 | linearization of dynamical systems. It combines `JAX`_, 11 | `Diffrax`_, `Equinox`_ and `SciPy`_ optimizers. 12 | 13 | .. _JAX: https://github.com/google/jax 14 | .. _Diffrax: https://github.com/patrick-kidger/diffrax 15 | .. _Equinox: https://github.com/patrick-kidger/equinox 16 | .. _Scipy: https://docs.scipy.org/doc/scipy/index.html 17 | 18 | .. note:: 19 | 20 | This project is under active development. The documentation is not there yet. 21 | 22 | .. 
toctree:: 23 | :maxdepth: 2 24 | :caption: Contents: 25 | 26 | quick-start 27 | user-guide 28 | examples 29 | api 30 | references 31 | -------------------------------------------------------------------------------- /docs/source/quick-start.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | 3 | 4 | Quick start 5 | =========== 6 | 7 | Defining a dynamical system 8 | --------------------------- 9 | 10 | Parameter fitting 11 | ----------------- 12 | 13 | Input-output linearization 14 | -------------------------- 15 | 16 | 17 | -------------------------------------------------------------------------------- /docs/source/references.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | 3 | Bibliography 4 | ============ 5 | 6 | .. bibliography:: 7 | -------------------------------------------------------------------------------- /docs/source/user-guide.rst: -------------------------------------------------------------------------------- 1 | .. 
toctree:: 2 | 3 | User guide 4 | ========== 5 | 6 | Defining dynamical systems 7 | -------------------------- 8 | 9 | Constraining parameters 10 | ----------------------- 11 | 12 | 13 | -------------------------------------------------------------------------------- /dynax/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | import jax as _jax 4 | 5 | from .derivative import lie_derivative as lie_derivative 6 | from .estimation import ( 7 | fit_csd_matching as fit_csd_matching, 8 | fit_least_squares as fit_least_squares, 9 | fit_multiple_shooting as fit_multiple_shooting, 10 | transfer_function as transfer_function, 11 | ) 12 | from .evolution import AbstractEvolution as AbstractEvolution, Flow as Flow, Map as Map 13 | from .interpolation import spline_it as spline_it 14 | from .linearize import ( 15 | discrete_input_output_linearize as discrete_input_output_linearize, 16 | discrete_relative_degree as discrete_relative_degree, 17 | DiscreteLinearizingSystem as DiscreteLinearizingSystem, 18 | input_output_linearize as input_output_linearize, 19 | LinearizingSystem as LinearizingSystem, 20 | relative_degree as relative_degree, 21 | ) 22 | from .system import ( 23 | AbstractControlAffine as AbstractControlAffine, 24 | AbstractSystem as AbstractSystem, 25 | boxed_field as boxed_field, 26 | DynamicStateFeedbackSystem as DynamicStateFeedbackSystem, 27 | FeedbackSystem as FeedbackSystem, 28 | field as field, 29 | LinearSystem as LinearSystem, 30 | non_negative_field as non_negative_field, 31 | SeriesSystem as SeriesSystem, 32 | static_field as static_field, 33 | StaticStateFeedbackSystem as StaticStateFeedbackSystem, 34 | ) 35 | from .util import _monkeypatch_pretty_print, pretty as pretty 36 | 37 | 38 | # TODO: leave out or make clear somewhere 39 | print("Setting jax_enable_x64 to True.") 40 | _jax.config.update("jax_enable_x64", True) 41 | 42 | _monkeypatch_pretty_print() 43 | 44 | 
__version__ = importlib.metadata.version("dynax") 45 | -------------------------------------------------------------------------------- /dynax/custom_types.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import Callable, TypeAlias, Union 3 | 4 | import jaxtyping 5 | import numpy as np 6 | 7 | 8 | generating_docs = getattr(typing, "GENERATING_DOCUMENTATION", False) 9 | 10 | if typing.TYPE_CHECKING: 11 | # In the editor. 12 | from jax import Array as Array 13 | from jax.typing import ArrayLike as ArrayLike 14 | 15 | Scalar: TypeAlias = Array 16 | ScalarLike: TypeAlias = ArrayLike 17 | FloatScalarLike = Union[float, Array, np.ndarray] 18 | elif generating_docs: 19 | # In the docs. 20 | class Scalar: 21 | pass 22 | 23 | class ScalarLike: 24 | pass 25 | 26 | class Array: 27 | pass 28 | 29 | class ArrayLike: 30 | pass 31 | 32 | FloatScalarLike = float 33 | 34 | for cls in (Scalar, ScalarLike, Array, ArrayLike): 35 | cls.__module__ = "builtins" 36 | cls.__qualname__ = cls.__name__ 37 | else: 38 | # At runtime. 
39 | from jax import Array 40 | from jax.typing import ArrayLike as ArrayLike 41 | 42 | Scalar = jaxtyping.Shaped[Array, ""] 43 | ScalarLike = jaxtyping.Shaped[ArrayLike, ""] 44 | FloatScalarLike = jaxtyping.Float[ArrayLike, ""] 45 | 46 | 47 | VectorFunc: TypeAlias = Callable[[Array], Array] 48 | ScalarFunc: TypeAlias = Callable[[Array], Scalar] 49 | # VectorField: TypeAlias = Callable[[Array, Scalar], Array] 50 | # OutputFunc: TypeAlias = Callable[[Array, Scalar], Array] 51 | -------------------------------------------------------------------------------- /dynax/derivative.py: -------------------------------------------------------------------------------- 1 | """Various functions for computing Lie derivatives.""" 2 | 3 | from __future__ import annotations # delayed evaluation of annotations 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from jax import Array 9 | from jax.experimental.jet import jet 10 | 11 | from .custom_types import Scalar, ScalarFunc, VectorFunc 12 | 13 | 14 | def lie_derivative(f: VectorFunc, h: ScalarFunc, n: int = 1) -> ScalarFunc: 15 | r"""Return the Lie (or directional) derivative of `h` along `f`. 16 | 17 | The Lie derivative of order `n` is recursively defined as 18 | 19 | .. math:: 20 | 21 | L_f^0 h(x) &= h(x) \\ 22 | L_f^n h(x) &= (\nabla_x L_f^{n-1} h)(x)^T f(x) 23 | 24 | Args: 25 | f: Function from :math:`\mathbb{R}^n` to :math:`\mathbb{R}^n`. 26 | h: Function from :math:`\mathbb{R}^n` to :math:`\mathbb{R}`. 27 | n: Order of the Lie derivative. 28 | 29 | Returns: 30 | The `n`-th order Lie derivative (a function from :math:`\mathbb{R}^n` to 31 | :math:`\mathbb{R}`). 
32 | 33 | """ 34 | if n < 0: 35 | raise ValueError(f"n must be non-negative but is {n}") 36 | if n == 0: 37 | return h 38 | else: 39 | lie_der = lie_derivative(f, h, n=n - 1) 40 | return lambda x: jax.jvp( 41 | lie_der, 42 | (x,), 43 | (f(x),), 44 | )[1] 45 | 46 | 47 | def lie_derivative_jet(f: VectorFunc, h: ScalarFunc, n: int = 1) -> ScalarFunc: 48 | """Compute the Lie derivative of `h` along `f` using Taylor-mode differentiation. 49 | 50 | Same parameters as :py:func:`lie_derivative`. Uses :py:func:`lie_derivatives_jet`. 51 | 52 | """ 53 | 54 | def liefun(x: Array) -> Scalar: 55 | return lie_derivatives_jet(f, h, n)(x)[-1] 56 | 57 | return liefun 58 | 59 | 60 | def lie_derivatives_jet(f: VectorFunc, h: ScalarFunc, n: int = 1) -> VectorFunc: 61 | """Return all Lie derivatives up to order `n` using Taylor-mode differentiation. 62 | 63 | Uses :py:func:`jax.experimental.jet.jet`, which currently does not compose 64 | with :py:func:`jax.grad`. 65 | 66 | See :cite:p:`robenackComputationLieDerivatives2005`. 67 | 68 | """ 69 | fac = jax.scipy.special.factorial(np.arange(n + 1)) 70 | 71 | def liefun(x: Array) -> Array: 72 | # Taylor coefficients of x(t) = ϕₜ(x_0) 73 | x_primals = [x] 74 | x_series = [jnp.zeros_like(x) for _ in range(n)] 75 | for k in range(n): 76 | # Taylor coefficients of z(t) = f(x(t)) 77 | z_primals, z_series = jet(f, x_primals, (x_series,)) 78 | z = [z_primals] + z_series 79 | # Build xₖ from zₖ: ẋ(t) = z(t) = f(x(t)) 80 | x_series[k] = z[k] / (k + 1) 81 | # Taylor coefficients of y(t) = h(x(t)) = h(ϕₜ(x_0)) 82 | y_primals, y_series = jet(h, x_primals, (x_series,)) 83 | Lfh = fac * jnp.array((y_primals, *y_series)) 84 | return Lfh 85 | 86 | return liefun 87 | -------------------------------------------------------------------------------- /dynax/estimation.py: -------------------------------------------------------------------------------- 1 | """Functions for estimating parameters of dynamical systems. 
2 | 3 | Parameters of `model.system` can be constrained via the `*_field` functions. 4 | """ 5 | 6 | from __future__ import annotations # delayed evaluation of annotations 7 | 8 | import warnings 9 | from dataclasses import fields 10 | from typing import Any, Callable, cast, Literal, Optional 11 | 12 | import diffrax as dfx 13 | import equinox as eqx 14 | import jax 15 | import jax.numpy as jnp 16 | import jax.tree_util as jtu 17 | import numpy as np 18 | import scipy.signal as sig 19 | from jax import Array 20 | from jax.flatten_util import ravel_pytree 21 | from scipy.linalg import pinvh 22 | from scipy.optimize import least_squares, OptimizeResult 23 | from scipy.optimize._optimize import MemoizeJac 24 | 25 | from .custom_types import ArrayLike 26 | from .evolution import AbstractEvolution 27 | from .system import AbstractSystem 28 | from .util import broadcast_right, mse, nmse, nrmse, value_and_jacfwd 29 | 30 | 31 | def _get_bounds(module: eqx.Module) -> tuple[list[float], list[float]]: 32 | """Build flattened arrays of lower and upper parameter bounds.""" 33 | lower_bounds = [] 34 | upper_bounds = [] 35 | for field_ in fields(module): 36 | name = field_.name 37 | value = module.__dict__.get(name, None) 38 | if value is None: 39 | continue 40 | # elif field_.metadata.get("static", False): 41 | # continue 42 | elif isinstance(value, eqx.Module): 43 | lbs, ubs = _get_bounds(value) 44 | lower_bounds.extend(lbs) 45 | upper_bounds.extend(ubs) 46 | elif constraint := field_.metadata.get("constrained", False): 47 | assert isinstance(value, jax.Array) 48 | _, (lb, ub) = constraint # type: ignore 49 | size = np.asarray(value).size 50 | lower_bounds.extend([lb] * size) 51 | upper_bounds.extend([ub] * size) 52 | elif isinstance(value, jax.Array): 53 | size = np.asarray(value).size 54 | lower_bounds.extend([-np.inf] * size) 55 | upper_bounds.extend([np.inf] * size) 56 | else: 57 | continue 58 | return list(lower_bounds), list(upper_bounds) 59 | 60 | 61 | def 
_key_paths(tree: Any, root: str = "tree") -> list[str]: 62 | """List key_paths to trainable fields of pytree including elements of JAX arrays.""" 63 | arr_to_list = lambda x: x.tolist() if isinstance(x, jax.Array) else x 64 | params, _ = eqx.partition(tree, lambda x: isinstance(x, jax.Array)) 65 | flattened, _ = jtu.tree_flatten_with_path(jtu.tree_map(arr_to_list, params)) 66 | return [f"{root}{jtu.keystr(kp)}" for kp, _ in flattened] 67 | 68 | 69 | def _compute_covariance( 70 | jac, cost, absolute_sigma: bool, cov_prior: Optional[np.ndarray] = None 71 | ) -> np.ndarray: 72 | """Compute covariance matrix from least-squares result.""" 73 | rsize, xsize = jac.shape 74 | rtol = np.finfo(float).eps * max(rsize, xsize) 75 | hess = jac.T @ jac 76 | if cov_prior is not None: 77 | # pcov = inv(JJ^T + Σₚ⁻¹) 78 | hess += np.linalg.inv(cov_prior) 79 | pcov = cast(np.ndarray, pinvh(hess, rtol=rtol)) 80 | 81 | warn_cov = False 82 | if not absolute_sigma: 83 | if rsize > xsize: 84 | s_sq = cost / (rsize - xsize) 85 | pcov = pcov * s_sq 86 | else: 87 | warn_cov = True 88 | 89 | if np.isnan(pcov).any(): 90 | warn_cov = True 91 | 92 | if warn_cov: 93 | pcov.fill(np.inf) 94 | warnings.warn( 95 | "Covariance of the parameters could not be estimated", stacklevel=2 96 | ) 97 | 98 | return pcov 99 | 100 | 101 | def _least_squares( 102 | f: Callable[[Array], Array], 103 | init_params: Array, 104 | bounds: tuple[ArrayLike, ArrayLike], 105 | reg_term: Optional[Callable[[Array], Array]] = None, 106 | x_scale: bool = True, 107 | verbose_mse: bool = True, 108 | **kwargs: Any, 109 | ) -> OptimizeResult: 110 | """Least-squares with jit, autodiff, parameter scaling and regularization.""" 111 | 112 | if reg_term is not None: 113 | # Add regularization term 114 | _f = f 115 | _reg_term = reg_term # https://github.com/python/mypy/issues/7268 116 | f = lambda params: jnp.concatenate((_f(params), _reg_term(params))) 117 | 118 | if verbose_mse: 119 | # Scale cost to mean-squared error 120 | __f = f 

    def f(params):
        # Scale residuals so that the cost reported by `least_squares` equals
        # the mean-squared error. NOTE(review): the inverse scaling is applied
        # below under `verbose_mse`; presumably this wrapper is only installed
        # in that case -- confirm upstream of this chunk.
        res = __f(params)
        return res * np.sqrt(2 / res.size)

    if x_scale:
        # Scale parameters and bounds by initial values
        norm = np.where(np.asarray(init_params) != 0, np.abs(init_params), 1)
        init_params = init_params / norm
        ___f = f
        f = lambda params: ___f(params * norm)
        bounds = (np.array(bounds[0]) / norm, np.array(bounds[1]) / norm)

    # Memoize value and forward-mode Jacobian so `fun` and `jac` share one
    # evaluation per parameter vector.
    fun = MemoizeJac(eqx.filter_jit(lambda x: value_and_jacfwd(f, x)))
    jac = fun.derivative
    res = least_squares(
        fun,
        init_params,
        bounds=bounds,
        jac=jac,  # type: ignore
        x_scale="jac",  # type: ignore
        **kwargs,
    )

    if x_scale:
        # Unscale parameters
        res.x = res.x * norm

    if verbose_mse:
        # Rescale to Least Squares cost
        mse_scaling = np.sqrt(2 / res.fun.size)
        res.fun = res.fun / mse_scaling
        res.jac = res.jac / mse_scaling

    if reg_term is not None:
        # Remove regularization from residuals and Jacobian and cost
        res.fun = res.fun[: -len(init_params)]
        res.jac = res.jac[: -len(init_params)]
        res.cost = np.sum(res.fun**2) / 2

    return res


def ravel_and_bounds(pytree):
    """Flatten the trainable array leaves of `pytree` and collect their bounds.

    Args:
        pytree: Module whose `jax.Array` leaves are the free parameters.

    Returns:
        Tuple `(params_flat, bounds, unravel)` of the flattened parameters,
        the `(lower, upper)` bounds for each flat parameter, and a function
        that maps a flat parameter vector back to a pytree like the input.

    """
    params, static = eqx.partition(pytree, lambda x: isinstance(x, jax.Array))
    params_flat, _unravel = ravel_pytree(params)
    bounds = _get_bounds(params)

    def unravel(params_flat: np.ndarray):
        # Recombine the updated parameters with the static (non-trainable) part.
        params = _unravel(params_flat)
        pytree = eqx.combine(params, static)
        return pytree

    return params_flat, bounds, unravel


def fit_least_squares(
    model: AbstractEvolution,
    t: ArrayLike,
    y: ArrayLike,
    u: Optional[ArrayLike] = None,
    batched: bool = False,
    sigma: Optional[ArrayLike] = None,
    absolute_sigma: bool = False,
    reg_val: float = 0,
    reg_bias: Optional[Literal["initial"]] = None,
    verbose_mse: bool = True,
    **kwargs,
) -> OptimizeResult:
    """Fit evolution model with regularized, box-constrained nonlinear least-squares.

    For an example, see :ref:`example-fit-ode`.

    Args:
        model: A concrete evolution object.
        t: Time signal.
        y: Output signal with time dimension along the first axis.
        u: Optional input signal with time along the first axis.
        batched: Whether `t`, `y`, and `u` have an additional first axis of equal
            length holding several trajectories. The loss is then computed over all
            trajectories.
        sigma: The measurement standard deviation which is broadcasted
            against `y`. If `None`, it is assumed that the outputs have equal
            signal-to-noise ratios.
        absolute_sigma: Whether `sigma` is used in an absolute sense and the estimated
            parameter covariance reflects these absolute values. Otherwise, only
            the relative magnitudes of the sigma values matter. See also
            :func:`scipy.optimize.curve_fit`.
        reg_val: Weight of the :math:`L_2` regularization.
        reg_bias: Subtractive bias term in the :math:`L_2` regularization. If
            `initial`, uses the initial parameters.
        verbose_mse: Whether the cost is scaled to the mean-squared error during logging
            with `verbose=2`.
        kwargs: Optional parameters for :py:func:`scipy.optimize.least_squares`.

    Returns:
        A Result object with the following additional attributes.

        - `result`: Fitted model.
        - `pcov`: Covariance matrix of the predicted parameters.
        - `y_pred`: Predicted outputs.
        - `key_paths`: Paths to free parameters of the model, see
          :py:func:`jax.tree_util.tree_flatten_with_path`.
        - `mse`: Mean-squared error.
        - `nmse`: Normalized mean-squared error.
        - `nrmse`: Normalized root mean-squared error.

    """
    t = jnp.asarray(t)
    y = jnp.asarray(y)

    if batched:
        # First axis holds experiments, second axis holds time.
        std_y = np.std(y, axis=1, keepdims=True)
        calc_coeffs = jax.vmap(dfx.backward_hermite_coefficients)
    else:
        # First axis holds time.
        std_y = np.std(y, axis=0, keepdims=True)
        calc_coeffs = dfx.backward_hermite_coefficients

    # Weight the residuals by the inverse output scale (or user-provided sigma).
    if sigma is None:
        weight = 1 / std_y
    else:
        sigma = np.asarray(sigma)
        weight = 1 / sigma

    if u is not None:
        u = jnp.asarray(u)
        ucoeffs = calc_coeffs(t, u)
    else:
        ucoeffs = None

    init_params, bounds, unravel = ravel_and_bounds(model)

    param_bias = 0
    if reg_bias == "initial":
        param_bias = init_params

    # NOTE(review): `np.any` suggests `reg_val` may also be a per-parameter
    # array, not just the annotated float -- confirm.
    is_regularized = np.any(reg_val != 0)
    if is_regularized:
        cov_prior = np.diag(1 / reg_val * np.ones(len(init_params)))
        reg_term = lambda params: reg_val * (params - param_bias)
    else:
        cov_prior = None
        reg_term = None

    def residual_term(params_flat):
        model = unravel(params_flat)
        if batched:
            # this can use pmap, if batch size is smaller than CPU cores
            model = jax.vmap(model)
        # FIXME: ucoeffs not supported for Map
        _, pred_y = model(t=t, ucoeffs=ucoeffs)
        res = (y - pred_y) * weight
        return res.reshape(-1)

    res = _least_squares(
        residual_term,
        init_params,
        bounds,
        reg_term=reg_term,
        verbose_mse=verbose_mse,
        **kwargs,
    )

    res.result = unravel(res.x)
    res.pcov = _compute_covariance(res.jac, res.cost, absolute_sigma, cov_prior)
    # Reconstruct the prediction from the weighted residuals.
    res.y_pred = y - res.fun.reshape(y.shape) / weight
    res.key_paths = _key_paths(model, root=model.__class__.__name__)
    res.mse = np.atleast_1d(mse(y, res.y_pred))
    res.nmse = np.atleast_1d(nmse(y, res.y_pred))
    res.nrmse = np.atleast_1d(nrmse(y, res.y_pred))

    return res


def _moving_window(a: Array, size: int, stride: int):
    """Return overlapping windows of `a` with shape `(num_windows, size)`."""
    start_idx = jnp.arange(0, len(a) - size + 1, stride)[:, None]
    inner_idx = jnp.arange(size)[None, :]
    return a[start_idx + inner_idx]


def fit_multiple_shooting(
    model: AbstractEvolution,
    t: ArrayLike,
    y: ArrayLike,
    u: Optional[ArrayLike] = None,
    num_shots: int = 1,
    continuity_penalty: float = 0.1,
    **kwargs,
) -> OptimizeResult:
    """Fit evolution model with multiple shooting.

    Multiple shooting subdivides the training problem into shooting segments and fits
    the initial states of the segments and the model parameters by minimizing the
    output error and a continuity loss of the states along the segment boundaries.

    For an example, see :ref:`example-fit-multiple-shooting`.

    Args:
        model: Concrete evolution object.
        t: Time signal.
        y: Outputs with time dimension along the first axis.
        u: Optional inputs with time along the first axis.
        num_shots: Number of shooting segments the training problem is divided into.
            If the length of the signals is not divisible by `num_shots`, the last few
            samples are ignored.
        continuity_penalty: Weights the penalty for discontinuities of the solution
            along shooting segment boundaries.
        kwargs: Optional parameters for :py:func:`scipy.optimize.least_squares`.

    Returns:
        Result object with the following additional attributes

        - `result`: The fitted model.
        - `x0s`: The initial states of the shooting segments.
        - `ts`: The times of the segments.
        - `ts0`: The times of the segments relative to the start of each segment.
        - `us`: The inputs of the segments. Only returned if `u` is not `None`.

    """
    t = jnp.asarray(t)
    y = jnp.asarray(y)

    if u is None:
        msg = (
            f"t, y must have same number of samples, but have shapes "
            f"{t.shape}, {y.shape}"
        )
        assert t.shape[0] == y.shape[0], msg
    else:
        u = jnp.asarray(u)
        msg = (
            f"t, y, u must have same number of samples, but have shapes "
            f"{t.shape}, {y.shape} and {u.shape}"
        )
        assert t.shape[0] == y.shape[0] == u.shape[0], msg

    # Compute number of samples per segment. Remove samples at end if total
    # number is not divisible by num_shots.
    num_samples = len(t)
    # Segments overlap by one sample (stride below is size - 1), so the segment
    # size accounts for the (num_shots - 1) shared boundary samples.
    num_samples_per_segment = int(np.floor((num_samples + (num_shots - 1)) / num_shots))
    leftover_samples = num_samples - (num_samples_per_segment * num_shots)
    if leftover_samples:
        print("Warning: removing last ", leftover_samples, "samples.")
    num_samples -= leftover_samples
    t = t[:num_samples]
    y = y[:num_samples]

    n_states = len(model.system.initial_state)

    # TODO: use numpy for everything that is not jitted
    # Divide signals into segments.
    ts = _moving_window(t, num_samples_per_segment, num_samples_per_segment - 1)
    ys = _moving_window(y, num_samples_per_segment, num_samples_per_segment - 1)
    # Free initial states for all but the first segment (which uses the model's).
    x0s = np.zeros((num_shots - 1, n_states))

    ucoeffs = None
    if u is not None:
        us = u[:num_samples]
        us = _moving_window(us, num_samples_per_segment, num_samples_per_segment - 1)
        compute_coeffs = lambda t, u: jnp.stack(dfx.backward_hermite_coefficients(t, u))
        ucoeffs = jax.vmap(compute_coeffs)(ts, us)

    # Each segment's time starts at 0.
    ts0 = ts - ts[:, :1]

    # Prepare optimization.
    model_params, param_bounds, unravel_model = ravel_and_bounds(model)
    init_params, unravel_params = ravel_pytree((x0s, model_params))

    # Segment initial states are unconstrained; model parameters keep theirs.
    state_bounds = (
        (num_shots - 1) * n_states * [-np.inf],
        (num_shots - 1) * n_states * [np.inf],
    )
    bounds = (
        state_bounds[0] + param_bounds[0],
        state_bounds[1] + param_bounds[1],
    )
    std_y = np.std(y, axis=0)

    def residuals(params):
        x0s, model_params = unravel_params(params)
        model = unravel_model(model_params)
        x0s = jnp.concatenate((model.system.initial_state[None], x0s), axis=0)
        xs_pred, ys_pred = jax.vmap(model)(t=ts0, ucoeffs=ucoeffs, initial_state=x0s)
        # output residual
        res_y = ((ys - ys_pred) / std_y).reshape(-1)
        res_y = res_y / np.sqrt(len(res_y))
        # continuity residual
        std_x = jnp.std(xs_pred, axis=(0, 1))
        res_x0 = ((x0s[1:] - xs_pred[:-1, -1]) / std_x).reshape(-1)
        res_x0 = res_x0 / np.sqrt(len(res_x0))
        return jnp.concatenate((res_y, continuity_penalty * res_x0))

    res = _least_squares(residuals, init_params, bounds, x_scale=False, **kwargs)

    x0s, model_params = unravel_params(res.x)
    res.result = unravel_model(model_params)
    res.x0s = jnp.concatenate((res.result.system.initial_state[None], x0s), axis=0)
    res.ts = np.asarray(ts)
    res.ts0 = np.asarray(ts0)

    if u is not None:
        res.us = np.asarray(us)

    return res


def transfer_function(
    sys: AbstractSystem, to_states: bool = False, **kwargs
) -> Callable[[complex], Array]:
    """Compute transfer-function :math:`H(s)` of the linearized system.

    Args:
        sys: Concrete dynamical system.
        to_states: Whether to return the transfer-function between input and states.
            Otherwise compute it between input and output.
        kwargs: Optional arguments for :any:`AbstractSystem.linearize`.
440 | 441 | Returns: 442 | A function that computes the transfer-function at a given complex frequency. 443 | 444 | """ 445 | linsys = sys.linearize(**kwargs) 446 | A, B, C, D = linsys.A, linsys.B, linsys.C, linsys.D 447 | 448 | def H(s: complex): 449 | """Transfer-function at s.""" 450 | identity = np.eye(linsys.initial_state.size) 451 | phi_B = jnp.linalg.solve(s * identity - A, B) 452 | if to_states: 453 | return phi_B 454 | return C.dot(phi_B) + D 455 | 456 | return H 457 | 458 | 459 | def estimate_spectra( 460 | u: np.ndarray, y: np.ndarray, sr: float, nperseg: int 461 | ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: 462 | """Estimate cross- and autospectral densities using Welch's method. 463 | 464 | Args: 465 | u: Input signal. 466 | y: Output signal. 467 | sr: Sampling rate. 468 | nperseg: Number of samples per segment. 469 | 470 | Returns: 471 | Tuple `(f, S_yu, S_uu)` of frequencies, cross- and autospectral densities. 472 | 473 | """ 474 | u_ = np.asarray(u) 475 | y_ = np.asarray(y) 476 | # Prep for correct broadcasting in sig.csd 477 | if u_.ndim == 1: 478 | u_ = u_[:, None] 479 | if y_.ndim == 1: 480 | y_ = y_[:, None] 481 | f, S_yu = sig.csd(u_[:, None, :], y_[:, :, None], fs=sr, nperseg=nperseg, axis=0) 482 | _, S_uu = sig.welch(u, fs=sr, nperseg=nperseg, axis=0) 483 | # Reshape back with dimensions of arguments 484 | S_yu = S_yu.reshape((nperseg // 2 + 1,) + y.shape[1:] + u.shape[1:]) 485 | S_uu = S_uu.reshape((nperseg // 2 + 1,) + u.shape[1:]) 486 | return f, S_yu, S_uu 487 | 488 | 489 | def fit_csd_matching( 490 | sys: AbstractSystem, 491 | u: ArrayLike, 492 | y: ArrayLike, 493 | sr: float = 1.0, 494 | nperseg: int = 1024, 495 | reg: float = 0, 496 | x_scale: bool = True, 497 | verbose_mse: bool = True, 498 | absolute_sigma: bool = False, 499 | fit_dc: bool = False, 500 | linearize_kwargs: dict | None = None, 501 | **kwargs, 502 | ) -> OptimizeResult: 503 | """Estimate parameters of linearized system by matching cross-spectral densities. 
504 | 505 | Args: 506 | sys: Concrete dynamical system. 507 | u: Input signal. 508 | y: Output signal. 509 | sr: Sampling rate. 510 | nperseg: Number of samples per segment. 511 | reg: Weight of the :math:`L_2` regularization. 512 | x_scale: Whether parameters are scaled by the initial values. 513 | verbose_mse: Whether the cost is scaled to the mean-squared error during logging 514 | with `verbose=2`. 515 | absolute_sigma: Whether `sigma` is used in an absolute sense and the estimated 516 | parameter covariance reflects these absolute values. Otherwise, only 517 | the relative magnitudes of the sigma values matter. See also 518 | :func:`scipy.optimize.curve_fit`. 519 | fit_dc: Whether to fit the DC term. 520 | linearize_kwargs: Arguments passed to 521 | :py:meth:`~dynax.system.AbstractSystem.linearize`. 522 | kwargs: Optional parameters for `scipy.optimize.least_squares`. 523 | 524 | Returns: 525 | Result object with these additional attributes. 526 | 527 | - `result`: Fitted model. 528 | - `pcov`: Estimated covariance of the parameters. 529 | - `key_paths`: Paths to free parameters of the model, see 530 | :py:func:`jax.tree_util.tree_flatten_with_path`. 531 | - `mse`: Mean-squared error. 532 | - `nmse`: Normalized mean-squared error. 533 | - `nrmse`: Normalized root mean-squared error. 
534 | 535 | """ 536 | if linearize_kwargs is None: 537 | linearize_kwargs = {} 538 | 539 | f, Syu, Suu = estimate_spectra(u, y, sr=sr, nperseg=nperseg) 540 | 541 | if not fit_dc: 542 | # remove dc term 543 | f = f[1:] 544 | Syu = Syu[1:] 545 | Suu = Suu[1:] 546 | 547 | s = 2 * np.pi * f * 1j 548 | weight = 1 / np.std(Syu, axis=0) 549 | init_params, bounds, unravel = ravel_and_bounds(sys) 550 | 551 | is_regularized = np.any(reg != 0) 552 | if is_regularized: 553 | cov_prior = np.diag(1 / reg * np.ones(len(init_params))) 554 | reg_term = lambda params: params * reg 555 | else: 556 | cov_prior = None 557 | reg_term = None 558 | 559 | def residuals(params): 560 | sys = unravel(params) 561 | H = transfer_function(sys, **linearize_kwargs) 562 | Gyu_pred = jax.vmap(H)(s) 563 | Syu_pred = Gyu_pred * broadcast_right(Suu, Gyu_pred) 564 | r = (Syu - Syu_pred) * weight 565 | r = jnp.concatenate((jnp.real(r), jnp.imag(r))) 566 | return r.reshape(-1) 567 | 568 | res = _least_squares( 569 | residuals, 570 | init_params, 571 | bounds, 572 | reg_term=reg_term, 573 | x_scale=x_scale, 574 | verbose_mse=verbose_mse, 575 | **kwargs, 576 | ) 577 | 578 | Syu_pred_real, Syu_pred_imag = res.fun[: Syu.size], res.fun[Syu.size :] 579 | Syu_pred = Syu - (Syu_pred_real + 1j * Syu_pred_imag).reshape(Syu.shape) / weight 580 | 581 | res.result = unravel(res.x) 582 | res.pcov = _compute_covariance( 583 | res.jac, res.cost, absolute_sigma, cov_prior=cov_prior 584 | ) 585 | res.key_paths = _key_paths(sys, root=sys.__class__.__name__) 586 | res.mse = np.atleast_1d(mse(Syu, Syu_pred)) 587 | res.nmse = np.atleast_1d(nmse(Syu, Syu_pred)) 588 | res.nrmse = np.atleast_1d(nrmse(Syu, Syu_pred)) 589 | 590 | return res 591 | -------------------------------------------------------------------------------- /dynax/evolution.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Callable, cast, Optional 3 | 4 | import jax 5 | 
import jax.numpy as jnp 6 | from diffrax import ( 7 | AbstractAdaptiveSolver, 8 | AbstractStepSizeController, 9 | ConstantStepSize, 10 | CubicInterpolation, 11 | diffeqsolve, 12 | DirectAdjoint, 13 | Dopri5, 14 | ODETerm, 15 | SaveAt, 16 | ) 17 | from equinox import filter_eval_shape, Module, static_field 18 | from jax import Array 19 | from jaxtyping import PyTree 20 | 21 | from .interpolation import spline_it 22 | from .system import AbstractSystem 23 | from .util import broadcast_right, dim2shape 24 | 25 | 26 | class AbstractEvolution(Module): 27 | """Abstract base-class for evolutions. 28 | 29 | Evolutions combine dynamical systems with a solver. They simulate the evolution of 30 | the system state and output over time given an initial and, possibly, an input 31 | sequence. 32 | 33 | """ 34 | 35 | system: AbstractSystem 36 | 37 | @abstractmethod 38 | def __call__( 39 | self, t: Array, u: Optional[Array], initial_state: Optional[Array] 40 | ) -> tuple[Array, Array]: 41 | """Evolve an initial state along the vector field and compute output. 42 | 43 | Args: 44 | t: Times at which to evaluate the evolution. 45 | u: Optional input sequence of same length. 46 | initial_state: Optional, fixed initial state used instead of 47 | :py:attr:`AbstractSystem.initial_state`. 48 | 49 | Returns: 50 | Tuple `(x, y)` of state and output sequences. 51 | 52 | """ 53 | raise NotImplementedError 54 | 55 | 56 | class Flow(AbstractEvolution): 57 | """Evolution for continous-time dynamical systems. 58 | 59 | Args: 60 | system: Dynamical system. 61 | solver: Differential equation solver. Defaults to :py:class:`diffrax.Dopri5`. 62 | stepsize_controller: Stepsize controller. Defaults to 63 | :py:class:`diffrax.ConstantStepSize`. 
64 | 65 | """ 66 | 67 | solver: AbstractAdaptiveSolver = static_field(default_factory=Dopri5) 68 | stepsize_controller: AbstractStepSizeController = static_field( 69 | default_factory=ConstantStepSize 70 | ) 71 | 72 | def __call__( 73 | self, 74 | t: Array, 75 | u: Optional[Array] = None, 76 | initial_state: Optional[Array] = None, 77 | *, 78 | ufun: Optional[Callable[[float], Array]] = None, 79 | ucoeffs: Optional[tuple[PyTree, PyTree, PyTree, PyTree]] = None, 80 | **diffeqsolve_kwargs, 81 | ) -> tuple[Array, Array]: 82 | r"""Evolve an initial state along the vector field and compute output. 83 | 84 | Args: 85 | t: Times at which to evaluate the evolution. 86 | u: Optional input sequence of same length. 87 | initial_state: Optional, fixed initial state used instead of 88 | :py:attr:`AbstractSystem.initial_state`. 89 | ufun: A function :math:`t \mapsto u`. Can be used instead of `u` or 90 | `ucoeffs`. 91 | ucoeffs: Precomputed spline coefficients of the input passed to 92 | :py:class:`diffrax.CubicInterpolation`. Can be used instead of `u` or 93 | `ufun`. 94 | **diffeqsolve_kwargs: Additional arguments passed to 95 | :py:meth:`diffrax.diffeqsolve`. 96 | 97 | Returns: 98 | Tuple `(x, y)` of state and output sequences. 99 | 100 | """ 101 | # Parse inputs. 102 | t = jnp.asarray(t) 103 | 104 | if initial_state is not None: 105 | x = jnp.asarray(initial_state) 106 | if initial_state.shape != self.system.initial_state.shape: 107 | raise ValueError("Initial state dimenions do not match.") 108 | else: 109 | initial_state = self.system.initial_state 110 | 111 | # Prepare input function. 
112 | if ucoeffs is not None: 113 | path = CubicInterpolation(t, ucoeffs) 114 | _ufun = path.evaluate 115 | elif callable(ufun): 116 | _ufun = ufun 117 | elif u is not None: 118 | u = jnp.asarray(u) 119 | if len(t) != u.shape[0]: 120 | raise ValueError("t and u must have matching first dimension.") 121 | _ufun = spline_it(t, u) 122 | elif self.system.n_inputs == 0: 123 | _ufun = lambda t: jnp.empty((0,)) 124 | else: 125 | raise ValueError("Must specify one of u, ufun, or ucoeffs.") 126 | 127 | # Check shape of ufun return values. 128 | _u = filter_eval_shape(_ufun, 0.0) 129 | if not isinstance(_u, jax.ShapeDtypeStruct): 130 | raise ValueError(f"ufun must return Arrays, not {type(_u)}.") 131 | else: 132 | if not _u.shape == dim2shape(self.system.n_inputs): 133 | raise ValueError( 134 | f"Input dimensions do not match: inputs have shape {_u.shape}, but" 135 | f"system.n_inputs is {self.system.n_inputs}" 136 | ) 137 | 138 | # Solve ODE. 139 | diffeqsolve_default_options = dict( 140 | solver=self.solver, 141 | stepsize_controller=self.stepsize_controller, 142 | saveat=SaveAt(ts=t), 143 | max_steps=50 * len(t), # completely arbitrary number of steps 144 | adjoint=DirectAdjoint(), 145 | dt0=( 146 | t[1] if isinstance(self.stepsize_controller, ConstantStepSize) else None 147 | ), 148 | ) 149 | diffeqsolve_default_options |= diffeqsolve_kwargs 150 | vector_field = lambda t, x, self: self.system.vector_field(x, _ufun(t), t) 151 | term = ODETerm(vector_field) 152 | x = diffeqsolve( 153 | term, 154 | t0=t[0], 155 | t1=t[-1], 156 | y0=initial_state, 157 | args=self, # https://github.com/patrick-kidger/diffrax/issues/135 158 | **diffeqsolve_default_options, # type: ignore 159 | ).ys 160 | # Could be in general a Pytree, but we only allow Array states. 161 | x = cast(Array, x) 162 | 163 | # Compute output. 164 | y = jax.vmap(self.system.output)(x, u, t) 165 | 166 | return x, y 167 | 168 | 169 | class Map(AbstractEvolution): 170 | """Evolution for discrete-time dynamical systems. 
171 | 172 | Args: 173 | system: Dynamical system. 174 | 175 | """ 176 | 177 | def __call__( 178 | self, 179 | t: Optional[Array] = None, 180 | u: Optional[Array] = None, 181 | initial_state: Optional[Array] = None, 182 | *, 183 | num_steps: Optional[int] = None, 184 | ) -> tuple[Array, Array]: 185 | """Evolve an initial state along the vector field and compute output. 186 | 187 | Args: 188 | t: Times at which to evaluate the evolution. 189 | u: Optional input sequence of same length. 190 | initial_state: Optional, fixed initial state used instead of 191 | :py:attr:`AbstractSystem.initial_state`. 192 | num_steps: Number of steps to compute if `t` and `u` are not specified. 193 | 194 | Returns: 195 | Tuple `(x, y)` of state and output sequences. 196 | 197 | """ 198 | 199 | # Parse inputs. 200 | if initial_state is not None: 201 | x = jnp.asarray(initial_state) 202 | if initial_state.shape != self.system.initial_state.shape: 203 | raise ValueError("Initial state dimenions do not match.") 204 | else: 205 | initial_state = self.system.initial_state 206 | 207 | if t is not None: 208 | t = jnp.asarray(t) 209 | elif u is not None: 210 | u = jnp.asarray(u) 211 | elif num_steps is not None: 212 | t = jnp.arange(num_steps) 213 | else: 214 | raise ValueError("must specify one of num_steps, t, or u.") 215 | 216 | if t is not None and u is not None: 217 | if t.shape[0] != u.shape[0]: 218 | raise ValueError("t and u must have the same first dimension.") 219 | inputs = jnp.stack((broadcast_right(t, u), u), axis=1) 220 | unpack = lambda input: (input[0], input[1]) 221 | elif t is not None: 222 | inputs = t 223 | unpack = lambda input: (input, None) 224 | else: 225 | inputs = u 226 | unpack = lambda input: (None, input) 227 | 228 | # Evolve. 
229 | def scan_fun(state, input): 230 | t, u = unpack(input) 231 | next_state = self.system.vector_field(state, u, t) 232 | return next_state, state 233 | 234 | _, x = jax.lax.scan(scan_fun, initial_state, inputs, length=num_steps) 235 | 236 | # Compute output. 237 | y = jax.vmap(self.system.output)(x, u, t) 238 | 239 | return x, y 240 | -------------------------------------------------------------------------------- /dynax/example_models.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | 4 | from .custom_types import Array, Scalar 5 | from .system import ( 6 | AbstractControlAffine, 7 | AbstractSystem, 8 | boxed_field, 9 | field, 10 | non_negative_field, 11 | static_field, 12 | ) 13 | 14 | 15 | # Define a general dynamical system by subclassing `AbstractSystem`. 16 | class SpringMassDamper(AbstractSystem): 17 | """Forced linear spring-mass-damper system. 18 | 19 | .. math:: m x'' + r x' + k x = u. 20 | 21 | """ 22 | 23 | # Define the system parameters as data fields. 24 | m: float = field() 25 | """Mass.""" 26 | r: float = field() 27 | """Linear drag.""" 28 | k: float = field() 29 | """Stiffness.""" 30 | 31 | # The following two fields are aleady defined in `AbstractSystem`. Thus, their 32 | # type declarations can be left out. 33 | initial_state = np.zeros(2) 34 | n_inputs = "scalar" 35 | 36 | # Define the vector field of the system by implementing the `vector_field` method. 37 | def vector_field(self, x: Array, u: Scalar, t=None) -> Array: 38 | """The vector field. 39 | 40 | .. math:: ẋ = [x_2, (u - r x_2 - k x_1) / m]^T. 41 | 42 | Args: 43 | x: State vector. 44 | u: Optional input vector. 45 | 46 | Returns: 47 | State derivative. 48 | 49 | """ 50 | x1, x2 = x 51 | return jnp.array([x2, (u - self.r * x2 - self.k * x1) / self.m]) 52 | 53 | # This class does not override the `AbstractSystem.output` method. The output is 54 | # then the full state vector by default. 


# Systems that have a control affine structure can subclass `AbstractControlAffine` and
# implement the `f`, `g`, and `h` methods. Such systems can often be input-output
# linearized with the functions in `dynax.linearize`.
class NonlinearDrag(AbstractControlAffine):
    """Forced spring-mass-damper system with nonlinear drag.

    .. math:: m x'' + r x' + r_2 x'|x'| + k x = u.

    """

    r: Array = field()
    """Linear drag."""
    r2: Array = field()
    """Nonlinear drag."""
    k: Array = field()
    """Stiffness."""
    m: Array = field()
    """Mass."""

    # We can define additional dataclass fields that do not represent trainable
    # model parameters using the `static_field` function. This function tells JAX that
    # the field is a constant and should not be differentiated by.
    outputs: tuple[int, ...] = static_field(default=(0,))
    """Indices of state vectors that are outputs. Defaults to `(0,)`."""

    initial_state = np.zeros(2)
    n_inputs = "scalar"

    def f(self, x: Array) -> Array:
        """Constant-input part of the vector field.

        .. math: f(x) = [x_2, (-r x_2 - r_2 |x_2| x_2 - k x_1) / m]^T.

        """
        x1, x2 = x
        return jnp.array(
            [x2, (-self.r * x2 - self.r2 * jnp.abs(x2) * x2 - self.k * x1) / self.m]
        )

    def g(self, x: Array) -> Array:
        """Input-proportional part of the vector field.

        .. math: g(x) = [0, 1 / m]^T.

        """
        return jnp.array([0.0, 1.0 / self.m])

    def h(self, x: Array) -> Array:
        """Output function.

        .. math: y = h(x) = {x_j | j ∈ outputs}.

        """
        return x[np.array(self.outputs)]


class Sastry9_9(AbstractControlAffine):
    r"""Example 9.9 in :cite:t:`sastry2013nonlinear`.

    .. math::

        x_1' &= e^{x_1} u \\
        x_2' &= x_1 + x_2^2 + e^{x_1} u \\
        x_3' &= x_1 - x_2 \\
        y &= x_3 \\

    """

    initial_state = np.zeros(3)
    n_inputs = "scalar"

    def f(self, x: Array) -> Array:
        return jnp.array([0.0, x[0] + x[1] ** 2, x[0] - x[1]])

    def g(self, x: Array) -> Array:
        return jnp.array([jnp.exp(x[1]), jnp.exp(x[1]), 0.0])

    def h(self, x: Array) -> Scalar:
        return x[2]


class LotkaVolterra(AbstractSystem):
    r"""The notorious predator-prey model.

    .. math::

        x_1' &= α x_1 - β x_1 x_2 \\
        x_2' &= δ x_1 x_2 - γ x_2 \\
        y &= [x_1, x_2]^T

    """

    # The values of parameters can be constrained by initializing them with the
    # `non_negative_field` and `boxed_field` functions
    alpha: float = boxed_field(0.0, jnp.inf, default=0.0)
    beta: float = boxed_field(0.0, jnp.inf, default=0.0)
    gamma: float = boxed_field(0.0, jnp.inf, default=0.0)
    delta: float = non_negative_field(default=0.0)  # same as boxed_field(0, jnp.inf)

    initial_state = np.ones(2) * 0.5

    # Systems without inputs should set n_inputs to zero.
    n_inputs = 0

    def vector_field(self, x, u=None, t=None):
        # Unpack prey (x) and predator (y) populations from the state vector.
        x, y = x
        return jnp.array(
            [self.alpha * x - self.beta * x * y, self.delta * x * y - self.gamma * y]
        )


# We can also subclass already defined systems to further change their behaviour.
class LotkaVolterraWithTrainableInitialState(LotkaVolterra):
    # We can release parameter constraints with `field`. This will remove
    # the metadata on the corresponding field, indicating that this parameter is
    # unconstrained.
    alpha: float = field(default=1.0)

    # In contrast, the following line will not change the constraint on the parameter,
    # only its default value. The metadata of the field is unchanged.
    beta = 1.0

    # Here we redeclare the initial_state field to be trainable. When default values
    # with the field functions are set to mutable values (which includes
    # jax.Array), one must use the `default_factory` argument.
    initial_state: Array = field(default_factory=lambda: jnp.ones(2) * 0.5)
--------------------------------------------------------------------------------
/dynax/interpolation.py:
--------------------------------------------------------------------------------
import diffrax as dfx
import equinox
import jax.numpy as jnp
from jax import Array


class InterpolationFunction(equinox.Module):
    """Interpolating cubic-spline function."""

    path: dfx.CubicInterpolation

    def __init__(self, ts: Array, xs: Array):
        ts = jnp.asarray(ts)
        xs = jnp.asarray(xs)
        if len(ts) != xs.shape[0]:
            raise ValueError("time and data must have same number of samples")
        coeffs = dfx.backward_hermite_coefficients(ts, xs)
        self.path = dfx.CubicInterpolation(ts, coeffs)

    def __call__(self, t: float) -> Array:
        """Evaluate the interpolating function at time `t`."""
        return self.path.evaluate(t)


def spline_it(ts: Array, xs: Array) -> InterpolationFunction:
    """Create an interpolating cubic-spline function.

    Args:
        ts: Time sequence.
        xs: Data points with first axis having the same length as `ts`.

    Returns:
        A function `f(t)` that computes the interpolated value at time `t`.

    """
    return InterpolationFunction(ts, xs)
--------------------------------------------------------------------------------
/dynax/linearize.py:
--------------------------------------------------------------------------------
"""Functions related to input-output linearization of nonlinear systems."""

from collections.abc import Callable
from functools import partial
from typing import Optional, Sequence

import jax
import jax.numpy as jnp
import numpy as np
import optimistix as optx
from jax import Array

from .custom_types import Scalar
from .derivative import lie_derivative
from .system import (
    _CoupledSystemMixin,
    AbstractControlAffine,
    AbstractSystem,
    DynamicStateFeedbackSystem,
    LinearSystem,
)


def relative_degree(
    sys: AbstractControlAffine, xs: Array, output: Optional[int] = None
) -> int:
    """Estimate the relative degree of a SISO control-affine system.

    Tests that the Lie derivatives of the output are zero exactly up to, but not
    including, the relative-degree'th order for each state in `xs`.

    Args:
        sys: Continuous time control-affine system with well defined relative degree
            and single input and output.
        xs: Samples of the state space stacked along the first axis.
        output: Optional index of the output if `sys` has multiple outputs.

    Returns:
        Estimated relative degree of the system.

    """
    if sys.n_inputs not in ["scalar", 1]:
        raise ValueError("System must be single input.")
    if output is None:
        # Make sure system has single output
        if sys.n_outputs not in ["scalar", 1]:
            raise ValueError(f"Output is None, but system has {sys.n_outputs} outputs.")
        h = sys.h
    else:
        # Restrict the output map to the selected output component.
        h = lambda *args, **kwargs: sys.h(*args, **kwargs)[output]

    # The relative degree of a well-posed SISO system is at most the state size.
    max_reldeg = jnp.size(sys.initial_state)
    for n in range(0, max_reldeg + 1):
        if n == 0:
            # NOTE(review): `sys.i` is presumably the direct feedthrough term of
            # the control-affine system -- confirm in dynax.system.
            res = jax.vmap(sys.i)(xs)
        else:
            LgLfn1h = lie_derivative(sys.g, lie_derivative(sys.f, h, n - 1))
            res = jax.vmap(LgLfn1h)(xs)

        if np.all(res == 0.0):
            continue
        elif np.all(res != 0.0):
            return n
        # Mixed zero/non-zero derivatives across the samples: the relative
        # degree is not uniform over `xs`; keep looking and fail below if no
        # later order is uniformly non-zero.

    raise RuntimeError("sys has ill-defined relative degree.")


# TODO: remove?
def is_controllable(A, B) -> bool:
    """Test controllability of linear system."""
    # Kalman rank criterion: [B, AB, ..., A^(n-1)B] must have full row rank.
    n = A.shape[0]
    contrmat = np.hstack([np.linalg.matrix_power(A, ni).dot(B) for ni in range(n)])
    return np.linalg.matrix_rank(contrmat) == n


# TODO: Adapt to general nonlinear reference system.
def input_output_linearize(
    sys: AbstractControlAffine,
    reldeg: int,
    ref: LinearSystem,
    output: Optional[int] = None,
    asymptotic: Optional[Sequence] = None,
    reg: Optional[float] = None,
) -> Callable[[Array, Array, float], Scalar]:
    """Construct an input-output linearizing feedback law.

    Args:
        sys: Continous time control-affine system with well defined relative degree and
            single input and output.
        reldeg: Relative degree of `sys` and lower bound of relative degree of `ref`.
        ref: Linear target system with single input and output.
        output: Optional index of the output if `sys` has multiple outputs.
        asymptotic: If `None`, compute the exactly linearizing law. Otherwise, compute
            an asymptotically linearizing law.
Then `asymptotic` is interpreted as the 95 | sequence of length `reldeg` of coefficients of the characteristic polynomial 96 | of the tracking error system. 97 | reg: Regularization parameter that controls the linearization effort. Only 98 | effective if asymptotic is not `None`. 99 | 100 | Returns: 101 | Feedback law `u = u(x, z, v)` that input-output linearizes the system. 102 | 103 | """ 104 | assert sys.n_inputs == ref.n_inputs, "systems have same input dimension" 105 | assert sys.n_inputs in [1, "scalar"] 106 | 107 | if output is None: 108 | assert sys.n_outputs == ref.n_outputs, "systems must have same output dimension" 109 | assert sys.n_outputs in [1, "scalar"] 110 | h = sys.h 111 | A, b, c = ref.A, ref.B, ref.C 112 | else: 113 | h = lambda x, t=None: sys.h(x)[output] 114 | A, b, c = ref.A, ref.B, ref.C[output] 115 | 116 | Lfnh = lie_derivative(sys.f, h, reldeg) 117 | LgLfnm1h = lie_derivative(sys.g, lie_derivative(sys.f, h, reldeg - 1)) 118 | cAn = c.dot(np.linalg.matrix_power(A, reldeg)) 119 | cAnm1b = c.dot(np.linalg.matrix_power(A, reldeg - 1)).dot(b) 120 | 121 | if asymptotic is None: 122 | 123 | def feedbacklaw(x: Array, z: Array, v: float) -> Scalar: 124 | y_reldeg_ref = cAn.dot(z) + cAnm1b * v 125 | y_reldeg = Lfnh(x) 126 | out = (y_reldeg_ref - y_reldeg) / LgLfnm1h(x) 127 | return out if sys.n_inputs != "scalar" else out.squeeze() 128 | 129 | else: 130 | if len(asymptotic) != reldeg: 131 | raise ValueError( 132 | f"asymptotic must be of length {reldeg=} but, {len(asymptotic)=}" 133 | ) 134 | 135 | coeffs = np.concatenate(([1], asymptotic)) 136 | if not np.all(np.real(np.roots(coeffs)) <= 0): 137 | raise ValueError("Polynomial must be Hurwitz") 138 | 139 | alphas = asymptotic 140 | 141 | cAis = [c.dot(np.linalg.matrix_power(A, i)) for i in range(reldeg)] 142 | Lfihs = [lie_derivative(sys.f, h, i) for i in range(reldeg)] 143 | 144 | def feedbacklaw(x: Array, z: Array, v: float) -> Scalar: 145 | y_reldeg_ref = cAn.dot(z) + cAnm1b * v 146 | y_reldeg = 
Lfnh(x) 147 | ae0s = jnp.array( 148 | [ 149 | ai * (cAi.dot(z) - Lfih(x)) 150 | for ai, Lfih, cAi in zip(alphas, Lfihs, cAis, strict=True) 151 | ] 152 | ) 153 | error = y_reldeg_ref - y_reldeg + jnp.sum(ae0s) 154 | if reg is None: 155 | out = error / LgLfnm1h(x) 156 | else: 157 | l = LgLfnm1h(x) 158 | out = error * l / (l + reg) 159 | return out if sys.n_inputs != "scalar" else out.squeeze() 160 | 161 | return feedbacklaw 162 | 163 | 164 | def _propagate(f: Callable[[Array, float], Array], n: int, x: Array, u: float) -> Array: 165 | # Propagates system for n <= discrete_relative_degree(sys) steps.""" 166 | def fun(x, _): 167 | return f(x, u), None 168 | 169 | xn, _ = jax.lax.scan(fun, x, jnp.arange(n)) 170 | return xn 171 | 172 | 173 | def discrete_relative_degree( 174 | sys: AbstractSystem, 175 | xs: Array, 176 | us: Array, 177 | output: Optional[int] = None, 178 | ): 179 | """Estimate the relative degree of a SISO discrete-time system. 180 | 181 | Tests that exactly the first relative-degree - 1 output samples are independent of 182 | the input for each `(x, u)` for the initial state and input samples `(xs, us)`. In 183 | this way, the discrete relative-degree can be interpreted as a system delay. 184 | 185 | Args: 186 | sys: Discrete-time dynamical system with well defined relative degree and 187 | single input and output. 188 | xs: Initial state samples stacked along the first axis. 189 | us: Initial input samples stacked along the first axis. 190 | output: Optional index of the output if the system has multiple outputs. 191 | 192 | Returns: 193 | The discrete-time relative degree of the system. 194 | 195 | See :cite:p:`leeLinearizationNonlinearControl2022{def 7.7.}`. 
196 | 197 | """ 198 | if sys.n_inputs not in ["scalar", 1]: 199 | raise ValueError("System must be single input.") 200 | if output is None: 201 | # Make sure system has single output 202 | if sys.n_outputs not in ["scalar", 1]: 203 | raise ValueError(f"Output is None, but system has {sys.n_outputs} outputs.") 204 | h = sys.output 205 | else: 206 | h = lambda *args, **kwargs: sys.output(*args, **kwargs)[output] 207 | 208 | f = sys.vector_field 209 | y = lambda n, x, u: h(_propagate(f, n, x, u), u) 210 | y_depends_u = jax.grad(y, 2) 211 | 212 | max_reldeg = jnp.size(sys.initial_state) 213 | for n in range(0, max_reldeg + 1): 214 | res = jax.vmap(partial(y_depends_u, n))(xs, us) 215 | if np.all(res == 0): 216 | continue 217 | elif np.all(res != 0): 218 | return n 219 | raise RuntimeError("sys has ill defined relative degree.") 220 | 221 | 222 | def discrete_input_output_linearize( 223 | sys: AbstractSystem, 224 | reldeg: int, 225 | ref: AbstractSystem, 226 | output: Optional[int] = None, 227 | solver: Optional[optx.AbstractRootFinder] = None, 228 | ) -> Callable[[Array, Array, float, float], float]: 229 | """Construct the input-output linearizing feedback for a discrete-time system. 230 | 231 | This is similar to model-predictive control with a horizon of a single time 232 | step and without constraints. The reference system can be nonlinear, in 233 | which case the feedback law implements an exact tracking controller. 234 | 235 | Args: 236 | sys: Discrete-time dynamical system with well defined relative degree and 237 | single input and output. 238 | reldeg: Relative degree of `sys` and lower bound of relative degree of `ref`. 239 | ref: Discrete-time reference system. 240 | output: Optional index of the output if the `sys` has multiple outputs. 241 | solver: Root finding algorithm to solve the feedback law. Defaults to 242 | :py:class:`optimistix.Newton` with absolute and relative tolerance `1e-6`. 
243 | 244 | Returns: 245 | Feedback law :math:`u_n = u(x_n, z_n, v_n, u_{n-1})` that input-output 246 | linearizes the system. 247 | 248 | See :cite:p:`leeLinearizationNonlinearControl2022{def 7.4.}`. 249 | 250 | """ 251 | f = lambda x, u: sys.vector_field(x, u) 252 | h = sys.output 253 | if sys.n_inputs != ref.n_inputs != 1: 254 | raise ValueError("Systems must have single input.") 255 | if output is None: 256 | if not (sys.n_outputs == ref.n_outputs and sys.n_outputs in ["scalar", 1]): 257 | raise ValueError("Systems must be single output and `output` is None.") 258 | _output = lambda x: x 259 | else: 260 | _output = lambda x: x[output] 261 | 262 | if solver is None: 263 | solver = optx.Newton(rtol=1e-6, atol=1e-6) 264 | 265 | def y_reldeg_ref(z, v): 266 | if isinstance(ref, LinearSystem): 267 | # A little faster for the linear case (if this is not optimized by jit) 268 | A, b, c = ref.A, ref.B, ref.C 269 | A_reldeg = c.dot(np.linalg.matrix_power(A, reldeg)) 270 | B_reldeg = c.dot(np.linalg.matrix_power(A, reldeg - 1)).dot(b) 271 | return _output(A_reldeg.dot(z) + B_reldeg.dot(v)) 272 | else: 273 | _output(ref.output(_propagate(ref.vector_field, reldeg, z, v))) 274 | 275 | def feedbacklaw(x: Array, z: Array, v: float, u_prev: float) -> float: 276 | def fn(u, _): 277 | return ( 278 | _output(h(_propagate(f, reldeg, x, u))) - y_reldeg_ref(z, v) 279 | ).squeeze() 280 | 281 | u = optx.root_find(fn, solver, u_prev).value 282 | return u 283 | 284 | return feedbacklaw 285 | 286 | 287 | class DiscreteLinearizingSystem(AbstractSystem, _CoupledSystemMixin): 288 | r"""Coupled discrete-time system of dynamics, reference and linearizing feedback. 289 | 290 | .. math:: 291 | 292 | x_{n+1} &= f^{sys}(x_n, v_n) \\ 293 | z_{n+1} &= f^{ref}(z_n, u_n) \\ 294 | y_n &= v_n = v(x_n, z_n, u_n) 295 | 296 | where :math:`v` is such that :math:`y_n^{sys} = h^{sys}(x_n, u_n)` equals 297 | :math:`y^{ref}_n = h^{ref}(z_n, u_n)`. 
298 | 299 | Args: 300 | sys: Discrete-time dynamical system with well defined relative degree and 301 | single input and output. 302 | ref: Discrete-time reference system. 303 | reldeg: Discrete relative degree of `sys` and lower bound of discrete relative 304 | degree of `ref`. 305 | fb_kwargs: Additional keyword arguments passed to 306 | :py:func:`discrete_input_output_linearize`. 307 | 308 | """ 309 | 310 | _v: Callable 311 | 312 | n_inputs = "scalar" 313 | 314 | def __init__( 315 | self, 316 | sys: AbstractSystem, 317 | ref: AbstractSystem, 318 | reldeg: int, 319 | **fb_kwargs, 320 | ): 321 | if sys.n_inputs != "scalar": 322 | raise ValueError("Only single input systems supported.") 323 | self._sys1 = sys 324 | self._sys2 = ref 325 | self.initial_state = jnp.append( 326 | self._pack_states(self._sys1.initial_state, self._sys2.initial_state), 0.0 327 | ) 328 | self._v = discrete_input_output_linearize(sys, reldeg, ref, **fb_kwargs) 329 | 330 | def vector_field(self, x, u=None, t=None): 331 | (x, z), v_last = self._unpack_states(x[:-1]), x[-1] 332 | v = self._v(x, z, u, v_last) 333 | xn = self._sys1.vector_field(x, v) 334 | zn = self._sys2.vector_field(z, u) 335 | return jnp.append(self._pack_states(xn, zn), v) 336 | 337 | def output(self, x, u=None, t=None): 338 | (x, z), v_last = self._unpack_states(x[:-1]), x[-1] 339 | v = self._v(x, z, u, v_last) # NOTE: feedback law is computed twice 340 | return v 341 | 342 | 343 | class LinearizingSystem(DynamicStateFeedbackSystem): 344 | r"""Coupled ODE of nonlinear dynamics, linear reference and linearizing feedback. 345 | 346 | .. math:: 347 | 348 | ẋ &= f(x) + g(x)v \\ 349 | ż &= Az + Bu \\ 350 | y &= v = v(x, z, u) 351 | 352 | where :math:`v` is such that :math:`y^{sys} = h(x) + i(x)v` equals 353 | :math:`y^{ref} = Cz + Du`. 354 | 355 | Args: 356 | sys: Continous time control-affine system with well defined relative degree and 357 | single input and output. 358 | ref: Linear target system with single input and output. 
359 | reldeg: Relative degree of `sys` and lower bound of relative degree of `ref`. 360 | fb_kwargs: Additional keyword arguments passed to 361 | :py:func:`input_output_linearize`. 362 | 363 | """ 364 | 365 | n_inputs = "scalar" 366 | 367 | def __init__( 368 | self, 369 | sys: AbstractControlAffine, 370 | ref: LinearSystem, 371 | reldeg: int, 372 | **fb_kwargs, 373 | ): 374 | v = input_output_linearize(sys, reldeg, ref, **fb_kwargs) 375 | super().__init__(sys, ref, v) 376 | 377 | def output(self, x, u, t=None): 378 | x, z = self._unpack_states(x) 379 | v = self._v(x, z, u) 380 | return v 381 | -------------------------------------------------------------------------------- /dynax/structident.py: -------------------------------------------------------------------------------- 1 | # TODO: This is an old draft, update and test this file. 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from .derivative import extended_lie_derivative, lie_derivative 7 | 8 | 9 | def obs_ident_mat(sys, x0, u=None, t=None): 10 | """Generalized observability-identifiability matrix for constant input. 11 | 12 | Villaverde, 2017. 13 | """ 14 | params, treedef = jax.tree_util.tree_flatten(sys) 15 | 16 | def f(x, p): 17 | """Vector-field for argumented state vector xp = [x, p].""" 18 | model = treedef.unflatten(p) 19 | return model.vector_field(x, u, t) 20 | 21 | def g(x, p): 22 | """Output function for argumented state vector xp = [x, p].""" 23 | model = treedef.unflatten(p) 24 | return model.output(x, t) 25 | 26 | params = jnp.array(params) 27 | O_i = jnp.vstack( 28 | [ 29 | jnp.hstack(jax.jacfwd(lie_derivative(f, g, n), (0, 1))(x0, params)) 30 | for n in range(sys.n_states + sys.n_params) 31 | ] 32 | ) 33 | 34 | return O_i 35 | 36 | 37 | def extended_obs_ident_mat(sys, x0, u, t=None): 38 | """Generalized observability-identifiability matrix for constant input. 39 | 40 | Villaverde, 2017. 
41 | """ 42 | params, treedef = jax.tree_util.tree_flatten(sys) 43 | 44 | def f(x, u, p): 45 | """Vector-field for argumented state vector xp = [x, p].""" 46 | model = treedef.unflatten(p) 47 | return model.vector_field(x, u, t) 48 | 49 | def g(x, p): 50 | """Output function for argumented state vector xp = [x, p].""" 51 | model = treedef.unflatten(p) 52 | return model.output(x, t) 53 | 54 | params = jnp.array(params) 55 | u = jnp.array(u) 56 | lies = [ 57 | extended_lie_derivative(f, g, n) for n in range(sys.n_states + sys.n_params) 58 | ] 59 | grad_of_outputs = [jnp.hstack(jax.jacfwd(l, (0, 2))(x0, u, params)) for l in lies] 60 | O_i = jnp.vstack(grad_of_outputs) 61 | return O_i 62 | -------------------------------------------------------------------------------- /dynax/system.py: -------------------------------------------------------------------------------- 1 | """Classes representing dynamical systems.""" 2 | 3 | from abc import abstractmethod 4 | from collections.abc import Callable 5 | from dataclasses import Field 6 | from typing import Any, Literal, TypeVar 7 | 8 | import equinox as eqx 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | from jax import Array 13 | 14 | from .custom_types import FloatScalarLike 15 | from .util import dim2shape, pretty 16 | 17 | 18 | def _linearize(f, h, x0, u0, t0): 19 | """Linearize dx=f(x,u,t), y=h(x,u,t) around x0, u0, t0.""" 20 | A = jax.jacfwd(f, argnums=0)(x0, u0, t0) 21 | B = jax.jacfwd(f, argnums=1)(x0, u0, t0) 22 | C = jax.jacfwd(h, argnums=0)(x0, u0, t0) 23 | D = jax.jacfwd(h, argnums=1)(x0, u0, t0) 24 | return A, B, C, D 25 | 26 | 27 | T = TypeVar("T") 28 | 29 | 30 | def _to_static_array(x: T) -> np.ndarray | T: 31 | if isinstance(x, jax.Array): 32 | return np.asarray(x) 33 | else: 34 | return x 35 | 36 | 37 | def field(**kwargs: Any) -> Field: 38 | """Mark an attribute value as trainable and unconstrained. 
39 | 40 | Args: 41 | **kwargs: Keyword arguments passed to :py:func:`dataclasses.field`. 42 | 43 | """ 44 | try: 45 | metadata = dict(kwargs["metadata"]) 46 | except KeyError: 47 | metadata = kwargs["metadata"] = {} 48 | metadata["constrained"] = False 49 | return eqx.field(converter=jnp.asarray, **kwargs) 50 | 51 | 52 | def static_field(**kwargs: Any) -> Field: 53 | """Mark an attribute value as non-trainable. 54 | 55 | Like :py:func:`equinox.field`, but removes constraints if they exist and converts 56 | JAX arrays to Numpy arrays. 57 | 58 | Args: 59 | **kwargs: Keyword arguments passed to :py:func:`eqx.field`. 60 | 61 | """ 62 | try: 63 | metadata = dict(kwargs["metadata"]) 64 | except KeyError: 65 | metadata = kwargs["metadata"] = {} 66 | metadata["constrained"] = False 67 | return eqx.field(converter=_to_static_array, **kwargs) 68 | 69 | 70 | def boxed_field(lower: float, upper: float, **kwargs: Any) -> Field: 71 | """Mark an attribute value as trainable and box-constrained on `[lower, upper]`. 72 | 73 | Args: 74 | lower: Lower bound. 75 | upper: Upper bound. 76 | **kwargs: Keyword arguments passed to :py:func:`dataclasses.field`. 77 | 78 | """ 79 | try: 80 | metadata = dict(kwargs["metadata"]) 81 | except KeyError: 82 | metadata = kwargs["metadata"] = {} 83 | metadata["constrained"] = ("boxed", (lower, upper)) 84 | return field(**kwargs) 85 | 86 | 87 | def non_negative_field(min_val: float = 0.0, **kwargs: Any) -> Field: 88 | """Mark an attribute value as trainable and non-negative. 89 | 90 | Args: 91 | min_val: Minimum value. 92 | **kwargs: Keyword arguments passed to :py:func:`dataclasses.field`. 93 | 94 | """ 95 | return boxed_field(lower=min_val, upper=np.inf, **kwargs) 96 | 97 | 98 | class AbstractSystem(eqx.Module): 99 | r"""Base class for dynamical systems. 100 | 101 | Any dynamical system in Dynax must inherit from this class. Subclasses can define 102 | continous-time 103 | 104 | .. 
math::

        ẋ &= f(x, u, t) \\
        y &= h(x, u, t)

    or discrete-time

    .. math::

        x_{k+1} &= f(x_k, u_k, t) \\
        y_k &= h(x_k, u_k, t)

    systems. The distinction between the two is only made when instances of subclasses
    are passed to objects such as :py:class:`dynax.evolution.Flow`,
    :py:class:`dynax.evolution.Map`, :py:func:`dynax.linearize.input_output_linearize`,
    or :py:func:`dynax.linearize.discrete_input_output_linearize`.

    Subclasses must set values for the `n_inputs` and `initial_state` attributes
    and implement the `vector_field` method. The `output` method describes the
    measurement equations. By default, the full state vector is returned as output.

    Example::

        class IntegratorAndGain(AbstractSystem):
            n_states = 1
            n_inputs = "scalar"
            gain: float

            def vector_field(self, x, u, t):
                dx = u
                return dx

            def output(self, x, u, t):
                return self.gain*x


    `AbstractSystem` is a dataclass and as such defines a default constructor which can
    make it necessary to implement a custom `__init__` method.

    """

    # TODO: make these abstract vars?
146 | initial_state: np.ndarray = static_field(init=False) 147 | """Initial state vector.""" 148 | n_inputs: int | Literal["scalar"] = static_field(init=False) 149 | """Number of inputs.""" 150 | 151 | def __check_init__(self): 152 | # Check that required attributes are initialized 153 | required_attrs = ["initial_state", "n_inputs"] 154 | for attr in required_attrs: 155 | if not hasattr(self, attr): 156 | raise AttributeError(f"Attribute '{attr}' not initialized.") 157 | 158 | with jax.ensure_compile_time_eval(): 159 | # Check that vector_field and output returns Arrays or scalars - not PyTrees 160 | x = jax.ShapeDtypeStruct(self.initial_state.shape, jnp.float64) 161 | u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64) 162 | try: 163 | dx = eqx.filter_eval_shape(self.vector_field, x, u, t=1.0) 164 | y = eqx.filter_eval_shape(self.output, x, u, t=1.0) 165 | except Exception as e: 166 | raise ValueError( 167 | "Can not evaluate output shapes. Check your definitions!" 168 | ) from e 169 | for val, func in zip((dx, y), ("vector_field, output")): # noqa: B905 170 | if not isinstance(val, jax.ShapeDtypeStruct): 171 | raise ValueError( 172 | f"{func} must return arrays or scalars, not {type(val)}" 173 | ) 174 | 175 | @abstractmethod 176 | def vector_field( 177 | self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None 178 | ) -> Array: 179 | """Compute state derivative. 180 | 181 | Args: 182 | x: State vector. 183 | u: Optional input vector. 184 | t: Optional time. 185 | 186 | Returns: 187 | State derivative. 188 | 189 | """ 190 | raise NotImplementedError 191 | 192 | def output( 193 | self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None 194 | ) -> Array: 195 | """Compute output. 196 | 197 | Args: 198 | x: State vector. 199 | u: Optional input vector. 200 | t: Optional time. 201 | 202 | Returns: 203 | System output. 

        """
        return x

    @property
    def n_outputs(self) -> int | Literal["scalar"]:
        """The size of the output vector."""
        with jax.ensure_compile_time_eval():
            # Shape-only evaluation: no actual numerics are run here.
            x = jax.ShapeDtypeStruct(self.initial_state.shape, jnp.float64)
            u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64)
            y = eqx.filter_eval_shape(self.output, x, u, t=1.0)
            n_out = "scalar" if y.ndim == 0 else y.shape[0]
            return n_out

    def linearize(
        self,
        x0: Array | None = None,
        u0: Array | None = None,
        t: FloatScalarLike | None = None,
    ) -> "LinearSystem":
        """Compute the Jacobian linearization around a point.

        Args:
            x0: State at which to linearize. Defaults to `initial_state`.
            u0: Input at which to linearize. Defaults to zero input.
            t: Time at which to linearize.

        Returns:
            Linearized system.

        """
        if x0 is None:
            x0 = self.initial_state
        if u0 is None:
            u0 = jnp.zeros(dim2shape(self.n_inputs))
        A, B, C, D = _linearize(self.vector_field, self.output, x0, u0, t)
        return LinearSystem(A, B, C, D)

    def pretty(self) -> str:
        """Return a pretty formatted string representation.

        The string includes the constraints of all trainable parameters and the
        values of all parameters.
        """
        return pretty(self)


class AbstractControlAffine(AbstractSystem):
    r"""Base class for control-affine dynamical systems.

    Both in continuous-time

    .. math::

        ẋ &= f(x) + g(x)u \\
        y &= h(x) + i(x)u

    or the discrete-time equivalent.

    Subclasses must implement the `f` and `g` methods that characterize the vector
    field. Optionally, the `h` and `i` methods can be implemented to describe the
    measurement equations. By default, the full state vector is returned as output.
266 | 267 | """ 268 | 269 | @abstractmethod 270 | def f(self, x: Array) -> Array: 271 | """The constant-input part of the vector field.""" 272 | pass 273 | 274 | @abstractmethod 275 | def g(self, x: Array) -> Array: 276 | """The input-proportional part of the vector field.""" 277 | pass 278 | 279 | def h(self, x: Array) -> Array: 280 | """The constant-input part of the output equation.""" 281 | return x 282 | 283 | def i(self, x: Array) -> Array: 284 | """The input-proportional part of the output equation.""" 285 | return jnp.array(0.0) 286 | 287 | def vector_field(self, x, u=None, t=None): 288 | out = self.f(x) 289 | if u is not None: 290 | out += self.g(x).dot(u) 291 | return out 292 | 293 | def output(self, x, u=None, t=None): 294 | out = self.h(x) 295 | if u is not None: 296 | out += self.i(x).dot(u) 297 | return out 298 | 299 | 300 | class LinearSystem(AbstractControlAffine): 301 | r"""A linear, time-invariant dynamical system. 302 | 303 | .. math:: 304 | 305 | ẋ &= Ax + Bu \\ 306 | y &= Cx + Du 307 | 308 | Args: 309 | A, B, C, D: System matrices of appropriate shape. 
310 | 311 | """ 312 | 313 | A: Array 314 | """State matrix.""" 315 | B: Array 316 | """Input matrix.""" 317 | C: Array 318 | """Output matrix.""" 319 | D: Array 320 | """Feedthrough matrix.""" 321 | 322 | def __post_init__(self): 323 | # Without this context manager, `initial_state` will leak later 324 | with jax.ensure_compile_time_eval(): 325 | self.initial_state = ( 326 | jnp.array(0) if self.A.ndim == 0 else jnp.zeros(self.A.shape[0]) 327 | ) 328 | if self.initial_state.ndim == 0: 329 | if self.B.ndim == 0: 330 | self.n_inputs = "scalar" 331 | elif self.B.ndim == 1: 332 | self.n_inputs = self.B.size 333 | else: 334 | raise ValueError("Dimension mismatch.") 335 | else: 336 | if self.B.ndim == 1: 337 | self.n_inputs = "scalar" 338 | elif self.B.ndim == 2: 339 | self.n_inputs = self.B.shape[1] 340 | else: 341 | raise ValueError("Dimension mismatch.") 342 | 343 | def f(self, x: Array) -> Array: 344 | return self.A.dot(x) 345 | 346 | def g(self, x: Array) -> Array: 347 | return self.B 348 | 349 | def h(self, x: Array) -> Array: 350 | return self.C.dot(x) 351 | 352 | def i(self, x: Array) -> Array: 353 | return self.D 354 | 355 | 356 | class _CoupledSystemMixin(eqx.Module): 357 | _sys1: AbstractSystem 358 | _sys2: AbstractSystem 359 | 360 | def _pack_states(self, x1: Array, x2: Array) -> Array: 361 | return jnp.concatenate( 362 | ( 363 | jnp.atleast_1d(x1), 364 | jnp.atleast_1d(x2), 365 | ) 366 | ) 367 | 368 | def _unpack_states(self, x: Array) -> tuple[Array, Array]: 369 | sys1_size = ( 370 | 1 371 | if jnp.ndim(self._sys1.initial_state) == 0 372 | else self._sys1.initial_state.size 373 | ) 374 | return ( 375 | x[:sys1_size].reshape(self._sys1.initial_state.shape), 376 | x[sys1_size:].reshape(self._sys2.initial_state.shape), 377 | ) 378 | 379 | 380 | class SeriesSystem(AbstractSystem, _CoupledSystemMixin): 381 | r"""Two systems in series. 382 | 383 | .. 
math:: 384 | 385 | ẋ_1 &= f_1(x_1, u, t) \\ 386 | y_1 &= h_1(x_1, u, t) \\ 387 | ẋ_2 &= f_2(x_2, y1, t) \\ 388 | y_2 &= h_2(x_2, y1, t) 389 | 390 | .. aafig:: 391 | 392 | +------+ +------+ 393 | u --+->+ sys1 +--y1->+ sys2 +--> y2 394 | +------+ +------+ 395 | 396 | Args: 397 | sys1: System with :math:`n` outputs. 398 | sys2: System with :math:`n` inputs. 399 | 400 | """ 401 | 402 | def __init__(self, sys1: AbstractSystem, sys2: AbstractSystem): 403 | self._sys1 = sys1 404 | self._sys2 = sys2 405 | self.initial_state = self._pack_states(sys1.initial_state, sys2.initial_state) 406 | self.n_inputs = sys1.n_inputs 407 | 408 | def vector_field( 409 | self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None 410 | ) -> Array: 411 | x1, x2 = self._unpack_states(x) 412 | y1 = self._sys1.output(x1, u, t) 413 | dx1 = self._sys1.vector_field(x1, u, t) 414 | dx2 = self._sys2.vector_field(x2, y1, t) 415 | return self._pack_states(dx1, dx2) 416 | 417 | def output( 418 | self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None 419 | ) -> Array: 420 | x1, x2 = self._unpack_states(x) 421 | y1 = self._sys1.output(x1, u, t) 422 | y2 = self._sys2.output(x2, y1, t) 423 | return y2 424 | 425 | 426 | class FeedbackSystem(AbstractSystem, _CoupledSystemMixin): 427 | r"""Two systems connected via feedback. 428 | 429 | .. math:: 430 | 431 | ẋ_1 &= f_1(x_1, u + y_2, t) \\ 432 | y_1 &= h_1(x_1, t) \\ 433 | ẋ_2 &= f_2(x_2, y_1, t) \\ 434 | y_2 &= h_2(x_2, y_1, t) 435 | 436 | .. aafig:: 437 | 438 | +------+ 439 | u --+->+ sys1 +--+-> y1 440 | ^ +------+ | 441 | | | 442 | y2| +------+ | 443 | +--+ sys2 |<-+ 444 | +------+ 445 | 446 | Args: 447 | sys1: System in forward path with :math:`n` inputs. 448 | sys2: System in feedback path with :math:`n` outputs. 
449 | 450 | """ 451 | 452 | def __init__(self, sys1: AbstractSystem, sys2: AbstractSystem): 453 | self._sys1 = sys1 454 | self._sys2 = sys2 455 | self.initial_state = self._pack_states(sys1.initial_state, sys2.initial_state) 456 | self.n_inputs = sys1.n_inputs 457 | 458 | def vector_field( 459 | self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None 460 | ) -> Array: 461 | if u is None: 462 | u = jnp.zeros(dim2shape(self._sys1.n_inputs)) 463 | x1, x2 = self._unpack_states(x) 464 | y1 = self._sys1.output(x1, None, t) 465 | y2 = self._sys2.output(x2, y1, t) 466 | dx1 = self._sys1.vector_field(x1, u + y2, t) 467 | dx2 = self._sys2.vector_field(x2, y1, t) 468 | dx = self._pack_states(dx1, dx2) 469 | return dx 470 | 471 | def output( 472 | self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None 473 | ) -> Array: 474 | x1, _ = self._unpack_states(x) 475 | y = self._sys1.output(x1, None, t) 476 | return y 477 | 478 | 479 | class StaticStateFeedbackSystem(AbstractSystem): 480 | r"""System with static state-feedback. 481 | 482 | .. math:: 483 | 484 | ẋ &= f(x, v(x), t) \\ 485 | y &= h(x, u, t) 486 | 487 | .. aafig:: 488 | 489 | +-----+ 490 | u --+------------->+ sys +----> y 491 | ^ +--+--+ 492 | | | 493 | | | x 494 | | +--------+ | 495 | +--+ "v(x)" +<----+ 496 | +--------+ 497 | 498 | Args: 499 | sys: System with vector field :math:`f` and output :math:`h`. 500 | v: Static feedback law :math:`v`. 
501 | 502 | """ 503 | 504 | _sys: AbstractSystem 505 | _v: Callable[[Array], Array] 506 | 507 | def __init__(self, sys: AbstractSystem, v: Callable[[Array], Array]): 508 | self._sys = sys 509 | self._v = staticmethod(v) 510 | self.initial_state = sys.initial_state 511 | self.n_inputs = sys.n_inputs 512 | 513 | def vector_field(self, x, u=None, t=None): 514 | v = self._v(x) 515 | dx = self._sys.vector_field(x, v, t) 516 | return dx 517 | 518 | def output(self, x, u=None, t=None): 519 | y = self._sys.output(x, u, t) 520 | return y 521 | 522 | 523 | class DynamicStateFeedbackSystem(AbstractSystem, _CoupledSystemMixin): 524 | r"""System with dynamic state-feedback. 525 | 526 | .. math:: 527 | 528 | ẋ_1 &= f_1(x_1, v(x_1, x_2, u), t) \\ 529 | ẋ_2 &= f_2(x_2, u, t) \\ 530 | y &= h_1(x_1, u, t) 531 | 532 | .. aafig:: 533 | 534 | +--------------+ +-----+ 535 | u -+->+ v(x1, x2, u) +--v->+ sys +-> y 536 | | +-+-------+----+ +--+--+ 537 | | ^ ^ | 538 | | | x2 | x1 | 539 | | | +-------------+ 540 | | +------+ 541 | +->+ sys2 | 542 | +------+ 543 | 544 | Args: 545 | sys1: System with vector field :math:`f_1` and output :math:`h`. 546 | sys2: System with vector field :math:`f_2`. 547 | v: dynamic feedback law :math:`v`. 
548 | 549 | """ 550 | 551 | _v: Callable[[Array, Array, float], float] 552 | 553 | def __init__( 554 | self, 555 | sys1: AbstractSystem, 556 | sys2: AbstractSystem, 557 | v: Callable[[Array, Array, Array | float], float], 558 | ): 559 | self._sys1 = sys1 560 | self._sys2 = sys2 561 | self._v = staticmethod(v) 562 | self.initial_state = self._pack_states(sys1.initial_state, sys2.initial_state) 563 | self.n_inputs = sys1.n_inputs 564 | 565 | def vector_field(self, x, u=None, t=None): 566 | if u is None: 567 | u = np.zeros(dim2shape(self._sys1.n_inputs)) 568 | x1, x2 = self._unpack_states(x) 569 | v = self._v(x1, x2, u) 570 | dx = self._sys1.vector_field(x1, v, t) 571 | dz = self._sys2.vector_field(x2, u, t) 572 | return jnp.concatenate((dx, dz)) 573 | 574 | def output(self, x, u=None, t=None): 575 | x1, _ = self._unpack_states(x) 576 | y = self._sys1.output(x1, u, t) 577 | return y 578 | -------------------------------------------------------------------------------- /dynax/util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Callable, Literal 3 | 4 | import equinox 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from .custom_types import Array, Scalar 9 | 10 | 11 | def value_and_jacfwd(fun: Callable, x: Array) -> tuple[Array, Callable]: 12 | """Create a function that evaluates both fun and its foward-mode jacobian. 13 | 14 | Args: 15 | fun: Function whose Jacobian is to be computed. 16 | x: Point at which function and Jacobian is evaluated. 17 | 18 | From this `issue `_. 19 | """ 20 | pushfwd = functools.partial(jax.jvp, fun, (x,)) 21 | basis = jnp.eye(x.size, dtype=x.dtype) 22 | y, jac = jax.vmap(pushfwd, out_axes=(None, 1))((basis,)) 23 | return y, jac 24 | 25 | 26 | def value_and_jacrev(fun: Callable, x: Array) -> tuple[Array, Callable]: 27 | """Create a function that evaluates both fun and its reverse-mode jacobian. 

    Args:
        fun: Function whose Jacobian is to be computed.
        x: Point at which function and Jacobian is evaluated.

    From this `issue `_.
    """
    y, pullback = jax.vjp(fun, x)
    # One VJP per output dimension, batched with vmap over the identity basis.
    basis = jnp.eye(y.size, dtype=y.dtype)
    jac = jax.vmap(pullback)(basis)
    return y, jac


def mse(target: Array, prediction: Array, axis: int = 0) -> Scalar:
    """Compute mean-squared error."""
    # abs() keeps the metric real-valued for complex-valued signals as well.
    return jnp.mean(jnp.abs(target - prediction) ** 2, axis=axis)


def nmse(target: Array, prediction: Array, axis: int = 0) -> Scalar:
    """Compute normalized mean-squared error."""
    # Normalized by the mean signal power of the target.
    return mse(target, prediction, axis) / jnp.mean(jnp.abs(target) ** 2, axis=axis)


def nrmse(target: Array, prediction: Array, axis: int = 0) -> Scalar:
    """Compute normalized root mean-squared error."""
    return jnp.sqrt(nmse(target, prediction, axis))


def _monkeypatch_pretty_print():
    # Patch equinox's dataclass pretty-printer so that field names are annotated
    # with their kind: "(static)" or the constraint stored in the field metadata
    # (as written by the field helpers in `dynax.system`).
    from equinox._pretty_print import named_objs, bracketed, pp, dataclasses  # noqa

    def _pformat_dataclass(obj, **kwargs):
        def field_kind(field):
            # E.g. "(static)" or "(boxed: (0.0, 1.0))"; empty for plain fields.
            if field.metadata.get("static", False):
                return "(static)"
            elif constr := field.metadata.get("constrained", False):
                return f"({constr[0]}: {constr[1]})"
            return ""

        objs = named_objs(
            [
                (
                    field.name + field_kind(field),
                    getattr(obj, field.name, ""),
                )
                for field in dataclasses.fields(obj)
                if field.repr
            ],
            **kwargs,
        )
        return bracketed(
            name=pp.text(obj.__class__.__name__),
            indent=kwargs["indent"],
            objs=objs,
            lbracket="(",
            rbracket=")",
        )

    equinox._pretty_print._pformat_dataclass = _pformat_dataclass


def pretty(tree):
    # short_arrays=False: include full array values in the output.
    return equinox.tree_pformat(tree, short_arrays=False)


def broadcast_right(arr, target):
    # Append singleton axes to `arr` until it matches `target`'s rank.
    return arr.reshape(arr.shape + (1,) * (target.ndim - arr.ndim))


def dim2shape(x: int | 
Literal["scalar"]) -> tuple: 98 | return () if x == "scalar" else (x,) 99 | -------------------------------------------------------------------------------- /examples/fit_initial_state.py: -------------------------------------------------------------------------------- 1 | """Fit a parameters and initial values of an ODE.""" 2 | 3 | import jax 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | from dynax import field, fit_least_squares, Flow, pretty 8 | from dynax.example_models import NonlinearDrag 9 | 10 | 11 | # Declare the initial values as trainable using the `field` function. 12 | class NonlinearDragWithInitialValues(NonlinearDrag): 13 | initial_state: jax.Array = field(init=True) 14 | 15 | 16 | # Initiate a dynamical system representing the some "true" parameters. 17 | true_system = NonlinearDragWithInitialValues( 18 | m=1.0, r=2.0, r2=0.1, k=4.0, initial_state=(0.5, 1.0) 19 | ) 20 | 21 | # Combine ODE system and ODE solver (Dopri5 and constant stepsize by default). 22 | true_model = Flow(true_system) 23 | print("true system:", pretty(true_system)) 24 | 25 | # Create some training data using the true model. This could be your measurement data. 26 | t_train = np.linspace(0, 2, 100) 27 | samplerate = 1 / t_train[1] 28 | np.random.seed(42) 29 | u_train = np.random.normal(size=len(t_train)) 30 | x_train, y_train = true_model(t_train, u_train) 31 | 32 | # Create our model system with some initial parameters. 33 | initial_sys = NonlinearDragWithInitialValues( 34 | m=1.0, r=1.0, r2=1.0, k=1.0, initial_state=(0.0, 0.0) 35 | ) 36 | print("initial system:", pretty(initial_sys)) 37 | 38 | # Combine the ODE with an ODE solver. 39 | init_model = Flow(initial_sys) 40 | 41 | # Fit the parameters of the nonlinear system including the initial state. 42 | pred_model = fit_least_squares( 43 | model=init_model, t=t_train, y=y_train, u=u_train, verbose=0 44 | ).result 45 | print("fitted system:", pretty(pred_model.system)) 46 | 47 | # Check the results. 
# Instantiate a dynamical system with some "true" parameters.
"""Example: fit a Lotka-Volterra system to data using multiple shooting."""
# Model ODE with initial parameter guesses (to be fitted below).
initial_sys = LotkaVolterra(alpha=0.5, beta=0.5, gamma=0.5, delta=0.5)
print("initial system:", pretty(initial_sys))

# Combine the ODE with an ODE solver.
init_model = Flow(initial_sys)

# Fitting with single shooting fails: the optimizer gets stuck in local minima.
num_shots = 1
res = fit_multiple_shooting(
    model=init_model,
    t=t_train,
    y=y_train,
    verbose=2,
    num_shots=num_shots,
)
# The result bundles the fitted model, the estimated initial state of each
# shooting segment (x0s), and the segment time axes (ts, ts0 — presumably ts0
# is ts shifted to start at zero; confirm against fit_multiple_shooting docs).
model = res.result
x0s = res.x0s
ts = res.ts
ts0 = res.ts0
print("single shooting:", pretty(model.system))

# Plot the target trajectory against the per-segment predictions; segment i and
# its initial-state markers share color C{i}.
plt.figure()
plt.title("single shooting")
_, ys_pred = jax.vmap(model)(ts0, initial_state=x0s)
plt.plot(t_train, y_train, "k--", label="target")
for i in range(num_shots):
    plt.plot(ts[i], ys_pred[i], label="fitted", color=f"C{i}")
    for j in range(x0s.shape[1]):
        plt.scatter(ts[i, 0], x0s[i, j], c=f"C{i}")
plt.plot()
plt.legend()

# Multiple shooting to the rescue.
# We want the nonlinear system's output to be equal to the reference system's
# output when driven with the following input.
# Estimate the relative degree by evaluating at 100 random test states.
reldeg = relative_degree(
    sys=system, xs=np.random.normal(size=(100, len(system.initial_state)))
)

# The input signal that forces the outputs of the nonlinear and reference
# systems to be equal is computed by solving a coupled ODE system constructed by
# `dynax.LinearizingSystem`.
linearizing_system = LinearizingSystem(system, reference_system, reldeg)

# The output of this system when driven with the reference input is the
# linearizing input.
_, linearizing_inputs = Flow(linearizing_system)(t=t, u=u)

# Let's simulate the original system,
states_orig, output_orig = Flow(system)(t=t, u=u)
# the linear reference system,
_, output_ref = Flow(reference_system)(t=t, u=u)
# and the nonlinear system driven with the linearizing signal.
_, output_linearized = Flow(system)(t=t, u=linearizing_inputs)

# Compare the three responses and the two input signals.
plt.plot(t, output_orig, label="nonlinear drag")
plt.plot(t, output_ref, label="linear reference")
plt.plot(t, output_linearized, "--", label="input-output linearized")
plt.legend()
plt.figure()
plt.plot(t, u, label="input to reference system")
plt.plot(t, linearizing_inputs, label="linearizing input")
plt.legend()
plt.show()

# The output of the linearized system is equal to the output of the reference system!
# The relative degree of the reference system can be greater than or equal to the
# relative degree of the nonlinear system. Here we test for the relative degree
# with a set of points and inputs.
53 | reldeg = discrete_relative_degree( 54 | system, np.random.normal(size=(len(u),) + system.initial_state.shape), u 55 | ) 56 | print("Relative degree of nonlinear system:", reldeg) 57 | print( 58 | "Relative degree of reference system:", 59 | discrete_relative_degree( 60 | reference_system, 61 | np.random.normal(size=(len(u),) + reference_system.initial_state.shape), 62 | u, 63 | ), 64 | ) 65 | 66 | # We compute the input signal that forces the outputs of the nonlinear and reference 67 | # systems to be equal by solving a coupled system that is constructed by 68 | # `dynax.DiscreteLinearizingSystem` 69 | linearizing_system = DiscreteLinearizingSystem(system, reference_system, reldeg) 70 | 71 | # The output of this system when driven with the reference input is the linearizing 72 | # input. 73 | _, linearizing_inputs = Map(linearizing_system)(u=u) 74 | 75 | # Lets simulate the original system, 76 | states_orig, output_orig = Map(system)(u=u) 77 | # the linear reference system, 78 | _, output_ref = Map(reference_system)(u=u) 79 | # and the nonlinear system driven with the linearizing signal. 80 | _, output_linearized = Map(system)(u=linearizing_inputs) 81 | 82 | # The output of the linearized system is equal to the output of the reference system! 83 | assert np.allclose(output_ref, output_linearized) 84 | 85 | plt.plot(output_orig, label="GRUCell") 86 | plt.plot(output_ref, label="linear reference") 87 | plt.plot(output_linearized, "--", label="input-output linearized GRU") 88 | plt.legend() 89 | plt.figure() 90 | plt.plot(u, label="input to reference system") 91 | plt.plot(linearizing_inputs, label="linearizing input") 92 | plt.legend() 93 | plt.show() 94 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "dynax" 3 | version = "0.0.3" 4 | description = "Dynamical systems with JAX!" 
5 | readme = "README.rst" 6 | requires-python = ">=3.10" 7 | license = { file = "LICENSE" } 8 | authors = [{ name = "Franz M. Heuchel", email = "franz.heuchel@pm.me" }] 9 | keywords = [ 10 | "jax", 11 | "dynamical-systems", 12 | "system-identification", 13 | "linearization", 14 | ] 15 | urls = { repository = "https://github.com/fhchl/dynax" } 16 | dependencies = ["jax<=0.4.33", "diffrax<=0.6"] 17 | 18 | [project.optional-dependencies] 19 | dev = ["pytest", "jupyter", "matplotlib", "pre-commit", "ruff"] 20 | docs = [ 21 | "nbsphinx", 22 | "sphinx-autobuild", 23 | "sphinx-autodoc-typehints", 24 | "sphinx-rtd-theme", 25 | "sphinx", 26 | "sphinxcontrib-bibtex", 27 | "sphinxcontrib-aafig", 28 | "furo", 29 | ] 30 | 31 | [tool.pytest.ini_options] 32 | addopts = [ 33 | "--pdbcls=IPython.terminal.debugger:Pdb", 34 | # "--jaxtyping-packages=dynax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" 35 | ] 36 | 37 | [build-system] 38 | requires = ["hatchling"] 39 | build-backend = "hatchling.build" 40 | 41 | [tool.ruff] 42 | line-length = 88 43 | src = ["dynax", "tests", "examples"] 44 | force-exclude = true 45 | extend-include = ["*.ipynb"] 46 | 47 | [tool.ruff.lint] 48 | select = ["E", "F", "I001", "B"] 49 | ignore = [ 50 | "E402", # Module level import not at top of file 51 | "E721", # Do not compare types, use 'isinstance()' 52 | "E731", # Do not assign a lambda expression, use a def (E731) 53 | "E741", # Do not use variables named 'I', 'O', or 'l' 54 | ] 55 | fixable = ["I001", "F401"] 56 | 57 | [tool.ruff.lint.isort] 58 | combine-as-imports = true 59 | lines-after-imports = 2 60 | order-by-type = false 61 | known-first-party = ["dynax"] 62 | 63 | [tool.uv] 64 | dev-dependencies = [ 65 | "dynax[dev, docs]", 66 | ] 67 | 68 | [tool.uv.sources] 69 | dynax = { workspace = true } 70 | -------------------------------------------------------------------------------- /tests/conftest.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def pytest_configure(config): 5 | config.addinivalue_line("markers", "slow: run slow tests") 6 | 7 | 8 | def pytest_addoption(parser): 9 | parser.addoption( 10 | "--runslow", action="store_true", default=False, help="run slow tests" 11 | ) 12 | 13 | 14 | def pytest_collection_modifyitems(config, items): 15 | if config.getoption("--runslow"): 16 | # --runslow given in cli: do not skip slow tests 17 | return 18 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 19 | for item in items: 20 | if "slow" in item.keywords: 21 | item.add_marker(skip_slow) 22 | -------------------------------------------------------------------------------- /tests/test_ad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.testing as npt 3 | 4 | from dynax.derivative import lie_derivative, lie_derivative_jet, lie_derivatives_jet 5 | from dynax.example_models import Sastry9_9 6 | 7 | 8 | def test_lie_derivative(): 9 | sys = Sastry9_9() 10 | f = sys.f 11 | g = sys.g 12 | h = sys.h 13 | 14 | np.random.seed(0) 15 | xs = np.random.normal(size=(10, 3)) 16 | for x in xs: 17 | x1, x2, x3 = x 18 | npt.assert_allclose(lie_derivative(f, h, n=1)(x), x1 - x2) 19 | npt.assert_allclose(lie_derivative(f, h, n=2)(x), -x1 - x2**2) 20 | npt.assert_allclose(lie_derivative(f, h, n=3)(x), -2 * x2 * (x1 + x2**2)) 21 | npt.assert_allclose(lie_derivative(g, h, n=1)(x), 0) 22 | npt.assert_allclose(lie_derivative(g, lie_derivative(f, h, n=1))(x), 0) 23 | npt.assert_allclose( 24 | lie_derivative(g, lie_derivative(f, h, n=2))(x), 25 | -(1 + 2 * x2) * np.exp(x2), 26 | rtol=1e-6, 27 | ) 28 | 29 | 30 | def test_lie_derivative2(): 31 | sys = Sastry9_9() 32 | f = sys.f 33 | g = sys.g 34 | h = sys.h 35 | 36 | np.random.seed(0) 37 | xs = np.random.normal(size=(10, 3)) 38 | tol = dict(atol=1e-8, rtol=1e-6) 39 | 40 | for x in xs: 
41 | x1, x2, _ = x 42 | npt.assert_allclose( 43 | lie_derivatives_jet(f, h, n=3)(x), 44 | [h(x), x1 - x2, -x1 - x2**2, -2 * x2 * (x1 + x2**2)], 45 | **tol, 46 | ) 47 | npt.assert_allclose(lie_derivative_jet(g, h, n=1)(x), 0, **tol) 48 | npt.assert_allclose(lie_derivative_jet(g, lie_derivative_jet(f, h, n=1))(x), 0) 49 | npt.assert_allclose( 50 | lie_derivative_jet(g, lie_derivative_jet(f, h, n=2))(x), 51 | -(1 + 2 * x2) * np.exp(x2), 52 | rtol=1e-5, 53 | ) 54 | -------------------------------------------------------------------------------- /tests/test_estimation.py: -------------------------------------------------------------------------------- 1 | import equinox as eqx 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import numpy.testing as npt 6 | import pytest 7 | from diffrax import Kvaerno5, PIDController 8 | from jax import Array 9 | 10 | from dynax import ( 11 | AbstractSystem, 12 | field, 13 | fit_csd_matching, 14 | fit_least_squares, 15 | fit_multiple_shooting, 16 | Flow, 17 | non_negative_field, 18 | transfer_function, 19 | ) 20 | from dynax.example_models import LotkaVolterra, NonlinearDrag, SpringMassDamper 21 | 22 | 23 | tols = {"rtol": 1e-02, "atol": 1e-04} 24 | 25 | 26 | @pytest.mark.parametrize("outputs", [(0,), (0, 1)]) 27 | def test_fit_least_squares(outputs): 28 | # data 29 | t = np.linspace(0, 1, 100) 30 | u = ( 31 | 0.1 * np.sin(1 * 2 * np.pi * t) 32 | + np.sin(0.1 * 2 * np.pi * t) 33 | + np.sin(10 * 2 * np.pi * t) 34 | ) 35 | true_model = Flow( 36 | NonlinearDrag(1.0, 2.0, 3.0, 4.0, outputs), 37 | ) 38 | _, y_true = true_model(t, u) 39 | # fit 40 | init_model = Flow(NonlinearDrag(1.0, 2.0, 3.0, 4.0, outputs)) 41 | pred_model = fit_least_squares(init_model, t, y_true, u, verbose=2).result 42 | # check result 43 | _, y_pred = pred_model(t, u) 44 | npt.assert_allclose(y_pred, y_true, **tols) 45 | assert eqx.tree_equal(pred_model, true_model, **tols) 46 | 47 | 48 | def test_fit_least_squares_on_batch(): 49 | # data 50 | t = 
np.linspace(0, 1, 100) 51 | us = np.stack( 52 | ( 53 | np.sin(1 * 2 * np.pi * t), 54 | np.sin(0.1 * 2 * np.pi * t), 55 | np.sin(10 * 2 * np.pi * t), 56 | ), 57 | axis=0, 58 | ) 59 | ts = np.repeat(t[None], us.shape[0], axis=0) 60 | true_model = Flow( 61 | NonlinearDrag(1.0, 2.0, 3.0, 4.0), 62 | ) 63 | _, ys = jax.vmap(true_model)(ts, us) 64 | # fit 65 | init_model = Flow( 66 | NonlinearDrag(1.0, 2.0, 3.0, 4.0), 67 | ) 68 | pred_model = fit_least_squares(init_model, ts, ys, us, batched=True).result 69 | # check result 70 | _, ys_pred = jax.vmap(pred_model)(ts, us) 71 | npt.assert_allclose(ys_pred, ys, **tols) 72 | assert eqx.tree_equal(pred_model, true_model, **tols) 73 | 74 | 75 | def test_can_compute_jacfwd_with_implicit_methods(): 76 | # don't get caught by https://github.com/patrick-kidger/diffrax/issues/135 77 | t = jnp.linspace(0, 1, 10) 78 | x0 = jnp.array([1.0, 0.0]) 79 | solver_opt = dict( 80 | solver=Kvaerno5(), stepsize_controller=PIDController(atol=1e-6, rtol=1e-3) 81 | ) 82 | 83 | def fun(m, r, k, x0=x0, solver_opt=solver_opt, t=t): 84 | model = Flow(SpringMassDamper(m, r, k), **solver_opt) 85 | x_true, _ = model(t, u=jnp.zeros_like(t), initial_state=x0) 86 | return x_true 87 | 88 | jac = jax.jacfwd(fun, argnums=(0, 1, 2)) 89 | jac(1.0, 2.0, 3.0) 90 | 91 | 92 | def test_fit_with_bounded_parameters(): 93 | # data 94 | t = jnp.linspace(0, 1, 100) 95 | solver_opt = dict(stepsize_controller=PIDController(rtol=1e-5, atol=1e-7)) 96 | true_model = Flow( 97 | LotkaVolterra(alpha=2 / 3, beta=4 / 3, gamma=1.0, delta=1.0), **solver_opt 98 | ) 99 | x_true, _ = true_model(t) 100 | # fit 101 | init_model = Flow( 102 | LotkaVolterra(alpha=1.0, beta=1.0, gamma=1.5, delta=2.0), **solver_opt 103 | ) 104 | pred_model = fit_least_squares(init_model, t, x_true).result 105 | # check result 106 | x_pred, _ = pred_model(t) 107 | npt.assert_allclose(x_pred, x_true, **tols) 108 | assert eqx.tree_equal(pred_model, true_model, **tols) 109 | 110 | 111 | def 
test_fit_with_bounded_parameters_and_ndarrays(): 112 | # model 113 | class LotkaVolterraBounded(AbstractSystem): 114 | alpha: float = field() 115 | beta: float = field() 116 | delta_gamma: Array = non_negative_field() 117 | 118 | initial_state = np.array((0.5, 0.5)) 119 | n_inputs = 0 120 | 121 | def vector_field(self, x, u=None, t=None): 122 | x, y = x 123 | gamma, delta = self.delta_gamma 124 | return jnp.array( 125 | [self.alpha * x - self.beta * x * y, delta * x * y - gamma * y] 126 | ) 127 | 128 | # data 129 | t = jnp.linspace(0, 1, 100) 130 | solver_opt = dict(stepsize_controller=PIDController(rtol=1e-5, atol=1e-7)) 131 | true_model = Flow( 132 | LotkaVolterraBounded( 133 | alpha=2 / 3, beta=4 / 3, delta_gamma=jnp.array([1.0, 1.0]) 134 | ), 135 | **solver_opt, 136 | ) 137 | x_true, _ = true_model(t) 138 | # fit 139 | init_model = Flow( 140 | LotkaVolterraBounded(alpha=1.0, beta=1.0, delta_gamma=jnp.array([1.5, 2])), 141 | **solver_opt, 142 | ) 143 | pred_model = fit_least_squares(init_model, t, x_true).result 144 | # check result 145 | x_pred, _ = pred_model(t) 146 | assert eqx.tree_equal(pred_model, true_model, **tols) 147 | npt.assert_allclose(x_pred, x_true, **tols) 148 | 149 | 150 | @pytest.mark.parametrize("num_shots", [1, 2, 3]) 151 | def test_fit_multiple_shooting_with_input(num_shots): 152 | # data 153 | t = jnp.linspace(0, 1, 200) 154 | u = jnp.sin(1 * 2 * np.pi * t) 155 | true_model = Flow(SpringMassDamper(1.0, 2.0, 3.0)) 156 | x_true, _ = true_model(t, u) 157 | # fit 158 | init_model = Flow(SpringMassDamper(1.0, 1.0, 1.0)) 159 | pred_model = fit_multiple_shooting( 160 | init_model, 161 | t, 162 | x_true, 163 | u, 164 | continuity_penalty=1, 165 | num_shots=num_shots, 166 | verbose=2, 167 | ).result 168 | # check result 169 | x_pred, _ = pred_model(t, u) 170 | npt.assert_allclose(x_pred, x_true, **tols) 171 | assert eqx.tree_equal(pred_model, true_model, **tols) 172 | 173 | 174 | @pytest.mark.parametrize("num_shots", [1, 2, 3]) 175 | def 
test_fit_multiple_shooting_without_input(num_shots): 176 | # data 177 | t = jnp.linspace(0, 1, 200) 178 | solver_opt = dict(stepsize_controller=PIDController(rtol=1e-3, atol=1e-6)) 179 | true_model = Flow( 180 | LotkaVolterra(alpha=2 / 3, beta=4 / 3, gamma=1.0, delta=1.0), **solver_opt 181 | ) 182 | x_true, _ = true_model(t) 183 | # fit 184 | init_model = Flow( 185 | LotkaVolterra(alpha=1.0, beta=1.0, gamma=1.5, delta=2.0), **solver_opt 186 | ) 187 | pred_model = fit_multiple_shooting( 188 | init_model, t, x_true, num_shots=num_shots, continuity_penalty=1 189 | ).result 190 | # check result 191 | x_pred, _ = pred_model(t) 192 | npt.assert_allclose(x_pred, x_true, atol=1e-3, rtol=1e-3) 193 | assert eqx.tree_equal( 194 | pred_model, 195 | true_model, 196 | atol=1e-2, 197 | rtol=1e-2, 198 | ) 199 | 200 | 201 | def test_transfer_function(): 202 | sys = SpringMassDamper(1.0, 1.0, 1.0) 203 | sr = 100 204 | f = jnp.linspace(0, sr / 2, 100) 205 | s = 2 * np.pi * f * 1j 206 | H = jax.vmap(transfer_function(sys))(s)[:, 0] 207 | H_true = 1 / (sys.m * s**2 + sys.r * s + sys.k) 208 | npt.assert_array_almost_equal(H, H_true) 209 | 210 | 211 | # FIXME: this test fails in jax>0.4.23 when run with others, but succeeds alone ... 
212 | def test_csd_matching(): 213 | np.random.seed(123) 214 | # model 215 | sys = SpringMassDamper(1.0, 1.0, 1.0) 216 | model = Flow(sys, stepsize_controller=PIDController(rtol=1e-4, atol=1e-6)) 217 | # input 218 | duration = 1000 219 | sr = 50 220 | t = np.arange(int(duration * sr)) / sr 221 | u = np.random.normal(size=len(t)) 222 | # output 223 | _, y = model(t, u) 224 | # fit 225 | init_sys = SpringMassDamper(1.0, 1.0, 1.0) 226 | fitted_sys = fit_csd_matching(init_sys, u, y, sr, nperseg=1024, verbose=1).result 227 | 228 | assert eqx.tree_equal( 229 | fitted_sys, 230 | sys, 231 | rtol=1e-1, 232 | atol=1e-1, 233 | ) 234 | 235 | 236 | def test_estimate_initial_state(): 237 | class NonlinearDragFreeInitialState(NonlinearDrag): 238 | initial_state: Array = field(init=False) 239 | 240 | def __post_init__(self): 241 | self.initial_state = jnp.zeros(2) 242 | 243 | # data 244 | t = np.linspace(0, 2, 200) 245 | u = ( 246 | np.sin(1 * 2 * np.pi * t) 247 | + np.sin(0.1 * 2 * np.pi * t) 248 | + np.sin(10 * 2 * np.pi * t) 249 | ) 250 | 251 | # True model has nonzero initial state 252 | true_initial_state = jnp.array([1.0, 0.5]) 253 | true_model = Flow(NonlinearDragFreeInitialState(1.0, 2.0, 3.0, 4.0, outputs=(0, 1))) 254 | true_model = eqx.tree_at( 255 | lambda t: t.system.initial_state, true_model, true_initial_state 256 | ) 257 | _, y_true = true_model(t, u, true_initial_state) 258 | 259 | # fit 260 | init_model = Flow(NonlinearDragFreeInitialState(1.0, 1.0, 1.0, 1.0, outputs=(0, 1))) 261 | pred_model = fit_least_squares(init_model, t, y_true, u=u).result 262 | 263 | # check result 264 | _, y_pred = pred_model(t, u) 265 | npt.assert_allclose(y_pred, y_true, **tols) 266 | npt.assert_allclose( 267 | pred_model.system.initial_state, 268 | true_initial_state, 269 | **tols, 270 | ) 271 | -------------------------------------------------------------------------------- /tests/test_evolution.py: -------------------------------------------------------------------------------- 1 | 
import diffrax as dfx 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import numpy.testing as npt 5 | from scipy.signal import dlsim, dlti 6 | 7 | from dynax import AbstractSystem, Flow, LinearSystem, Map 8 | 9 | 10 | tols = dict(rtol=1e-04, atol=1e-06) 11 | 12 | 13 | class SecondOrder(AbstractSystem): 14 | """Second-order, linear system with constant coefficients.""" 15 | 16 | b: float 17 | c: float 18 | 19 | n_inputs = 0 20 | initial_state = np.array([0.0, 0.0]) 21 | 22 | def vector_field(self, x, u=None, t=None): 23 | """ddx + b dx + c x = u as first order with x1=x and x2=dx.""" 24 | x1, x2 = x 25 | dx1 = x2 26 | dx2 = -self.b * x2 - self.c * x1 27 | return jnp.array([dx1, dx2]) 28 | 29 | def output(self, x, u=None, t=None): 30 | x1, _ = x 31 | return x1 32 | 33 | 34 | def test_forward_model_crit_damp(): 35 | b = 2 36 | c = 1 # critical damping as b**2 == 4*c 37 | sys = SecondOrder(b, c) 38 | 39 | def x(t, x0, dx0): 40 | """Solution to critically damped linear second-order system.""" 41 | C2 = x0 42 | C1 = b / 2 * C2 43 | return np.exp(-b * t / 2) * (C1 * t + C2) 44 | 45 | x0 = jnp.array([1, 0]) # x(t=0)=1, dx(t=0)=0 46 | t = jnp.linspace(0, 1) 47 | model = Flow(sys, stepsize_controller=dfx.PIDController(rtol=1e-7, atol=1e-9)) 48 | x_pred = model(t, initial_state=x0)[1] 49 | x_true = x(t, *x0) 50 | assert np.allclose(x_true, x_pred) 51 | 52 | 53 | def test_forward_model_lin_sys(): 54 | b = 2 55 | c = 1 # critical damping as b**2 == 4*c 56 | uconst = 1 57 | 58 | A = jnp.array([[0, 1], [-c, -b]]) 59 | B = jnp.array([[0], [1]]) 60 | C = jnp.array([[1, 0]]) 61 | D = jnp.zeros((1, 1)) 62 | sys = LinearSystem(A, B, C, D) 63 | 64 | def x(t, x0, dx0, uconst): 65 | """Solution to critically damped linear second-order system.""" 66 | C2 = x0 - uconst / c 67 | C1 = b / 2 * C2 68 | return np.exp(-b * t / 2) * (C1 * t + C2) + uconst / c 69 | 70 | x0 = jnp.array([1, 0]) # x(t=0)=1, dx(t=0)=0 71 | t = jnp.linspace(0, 1) 72 | u = jnp.ones(t.shape + (1,)) * uconst 73 | 
model = Flow(sys, stepsize_controller=dfx.PIDController(rtol=1e-7, atol=1e-9)) 74 | x_pred = model(t, u, initial_state=x0)[1] 75 | x_true = x(t, x0[0], x0[1], uconst) 76 | assert np.allclose(x_true, x_pred) 77 | 78 | 79 | def test_discrete_forward_model(): 80 | b = 2 81 | c = 1 # critical damping as b**2 == 4*c 82 | t = jnp.arange(50) 83 | u = jnp.sin(1 / len(t) * 2 * np.pi * t)[:, None] # single input 84 | x0 = jnp.array([1.0, 0.0]) 85 | A = jnp.array([[0, 1], [-c, -b]]) 86 | B = jnp.array([[0], [1]]) 87 | C = jnp.array([[1, 0]]) 88 | D = jnp.zeros((1, 1)) 89 | # test just input 90 | sys = LinearSystem(A, B, C, D) 91 | model = Map(sys) 92 | x, y = model(u=u, initial_state=x0) # ours 93 | scipy_sys = dlti(A, B, C, D) 94 | _, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0) 95 | npt.assert_allclose(scipy_y, y, **tols) 96 | npt.assert_allclose(scipy_x, x, **tols) 97 | # test input and time (results should be same) 98 | x, y = model(u=u, t=t, initial_state=x0) 99 | scipy_t, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0, t=t) 100 | npt.assert_allclose(scipy_y, y, **tols) 101 | npt.assert_allclose(scipy_x, x, **tols) 102 | 103 | 104 | def test_initial_state(): 105 | class Sys(AbstractSystem): 106 | n_inputs = "scalar" 107 | initial_state = jnp.array(1.0) 108 | 109 | def vector_field(self, x, u, t=None): 110 | return x * 0.1 + u 111 | 112 | t = jnp.arange(5) 113 | u = jnp.zeros(5) 114 | x, y = Flow(Sys())(t, u) 115 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import runpy 2 | from pathlib import Path 3 | 4 | import nbformat 5 | import pytest 6 | from nbconvert.preprocessors import ExecutePreprocessor 7 | 8 | 9 | example_dir = Path(__file__, "..", "..", "examples").resolve() 10 | examples = [str(p) for p in example_dir.glob("*.py")] 11 | notebooks = [str(p) for p in example_dir.resolve().glob("*.ipynb")] 12 | 13 | 14 | 
@pytest.mark.slow 15 | @pytest.mark.parametrize("example", examples, ids=lambda x: Path(x).name) 16 | def test_examples_run_without_error(example): 17 | runpy.run_path(example) 18 | 19 | 20 | @pytest.mark.slow 21 | @pytest.mark.parametrize("notebook", notebooks, ids=lambda x: Path(x).name) 22 | def test_notebooks_dont_change(notebook): 23 | with open(notebook) as f: 24 | nb = nbformat.read(f, as_version=4) 25 | try: 26 | ExecutePreprocessor(timeout=60).preprocess(nb) 27 | except Exception as e: 28 | raise Exception(f"Running the notebook {notebook} failed") from e 29 | -------------------------------------------------------------------------------- /tests/test_linearize.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import numpy.testing as npt 4 | 5 | from dynax import ( 6 | AbstractControlAffine, 7 | AbstractSystem, 8 | discrete_relative_degree, 9 | DiscreteLinearizingSystem, 10 | DynamicStateFeedbackSystem, 11 | Flow, 12 | input_output_linearize, 13 | LinearSystem, 14 | Map, 15 | relative_degree, 16 | ) 17 | from dynax.example_models import NonlinearDrag, Sastry9_9 18 | from dynax.linearize import ( 19 | is_controllable, 20 | ) 21 | 22 | 23 | tols = dict(rtol=1e-04, atol=1e-06) 24 | 25 | 26 | class Allpass(AbstractControlAffine): 27 | initial_state = jnp.zeros(0) 28 | n_inputs = "scalar" 29 | 30 | def f(self, x): 31 | return jnp.array(0.0) 32 | 33 | def g(self, x): 34 | return jnp.array(0.0) 35 | 36 | def h(self, x): 37 | return jnp.array(0.0) 38 | 39 | def i(self, x): 40 | return jnp.array(1.0) 41 | 42 | 43 | class SpringMassDamperWithOutput(AbstractControlAffine): 44 | m: float = 0.1 45 | r: float = 0.1 46 | k: float = 0.1 47 | out: int = 0 48 | 49 | initial_state = jnp.zeros(2) 50 | n_inputs = "scalar" 51 | 52 | def f(self, x): 53 | x1, x2 = x 54 | return jnp.array([x2, (-self.r * x2 - self.k * x1) / self.m]) 55 | 56 | def g(self, x): 57 | return jnp.array([0, 1 / 
self.m]) 58 | 59 | def h(self, x): 60 | return x[np.array(self.out)] 61 | 62 | 63 | def test_relative_degree(): 64 | xs = np.random.normal(size=(100, 2)) 65 | # output is position 66 | sys = SpringMassDamperWithOutput(out=0) 67 | assert relative_degree(sys, xs) == 2 68 | # output is velocity 69 | sys = SpringMassDamperWithOutput(out=1) 70 | assert relative_degree(sys, xs) == 1 71 | 72 | xs = np.random.normal(size=100) 73 | assert relative_degree(Allpass(), xs) == 0 74 | 75 | 76 | def test_discrete_relative_degree(): 77 | xs = np.random.normal(size=(100, 2)) 78 | us = np.random.normal(size=(100)) 79 | 80 | sys = SpringMassDamperWithOutput(out=0) 81 | assert discrete_relative_degree(sys, xs, us) == 2 82 | 83 | sys = SpringMassDamperWithOutput(out=1) 84 | assert discrete_relative_degree(sys, xs, us) == 1 85 | 86 | xs = np.random.normal(size=100) 87 | assert discrete_relative_degree(Allpass(), xs, us) == 0 88 | 89 | 90 | def test_is_controllable(): 91 | n = 3 92 | A = np.diag(np.arange(n)) 93 | B = np.ones((n, 1)) 94 | assert is_controllable(A, B) 95 | 96 | A[1, :] = A[0, :] 97 | assert not is_controllable(A, B) 98 | 99 | 100 | def test_linearize_lin2lin(): 101 | n, m, p = 3, 2, 1 102 | A = jnp.array(np.random.normal(size=(n, n))) 103 | B = jnp.array(np.random.normal(size=(n, m))) 104 | C = jnp.array(np.random.normal(size=(p, n))) 105 | D = jnp.array(np.random.normal(size=(p, m))) 106 | sys = LinearSystem(A, B, C, D) 107 | linsys = sys.linearize() 108 | assert np.allclose(A, linsys.A) 109 | assert np.allclose(B, linsys.B) 110 | assert np.allclose(C, linsys.C) 111 | assert np.allclose(D, linsys.D) 112 | 113 | 114 | def test_linearize_dyn2lin(): 115 | class ScalarScalar(AbstractSystem): 116 | initial_state = jnp.array(0.0) 117 | n_inputs = "scalar" 118 | 119 | def vector_field(self, x, u, t): 120 | return -1 * x + 2 * u 121 | 122 | def output(self, x, u, t): 123 | return 3 * x + 4 * u 124 | 125 | sys = ScalarScalar() 126 | linsys = sys.linearize() 127 | assert 
np.array_equal(linsys.A, -1.0) 128 | assert np.array_equal(linsys.B, 2.0) 129 | assert np.array_equal(linsys.C, 3.0) 130 | assert np.array_equal(linsys.D, 4.0) 131 | 132 | 133 | def test_linearize_sastry9_9(): 134 | """Linearize should return 2d-arrays. Refererence computed by hand.""" 135 | sys = Sastry9_9() 136 | linsys = sys.linearize() 137 | assert np.array_equal(linsys.A, [[0, 0, 0], [1, 0, 0], [1, -1, 0]]) 138 | assert np.array_equal(linsys.B, [1, 1, 0]) 139 | assert np.array_equal(linsys.C, [0, 0, 1]) 140 | assert np.array_equal(linsys.D, 0.0) 141 | 142 | 143 | def test_input_output_linearize_single_output(): 144 | """Feedback linearized system equals system linearized around x0.""" 145 | sys = NonlinearDrag(0.1, 0.1, 0.1, 0.1) 146 | ref = sys.linearize() 147 | xs = np.random.normal(size=(100,) + sys.initial_state.shape) 148 | reldeg = relative_degree(sys, xs) 149 | feedbacklaw = input_output_linearize(sys, reldeg, ref) 150 | feedback_sys = DynamicStateFeedbackSystem(sys, ref, feedbacklaw) 151 | t = jnp.linspace(0, 0.1) 152 | u = jnp.sin(t) 153 | npt.assert_allclose( 154 | Flow(ref)(t, u)[1], 155 | Flow(feedback_sys)(t, u)[1], 156 | **tols, 157 | ) 158 | 159 | 160 | def test_input_output_linearize_multiple_outputs(): 161 | """Can select an output for linearization.""" 162 | sys = SpringMassDamperWithOutput(out=[0, 1]) 163 | ref = sys.linearize() 164 | for out_idx in range(2): 165 | out_idx = 1 166 | xs = np.random.normal(size=(100,) + sys.initial_state.shape) 167 | reldeg = relative_degree(sys, xs, output=out_idx) 168 | feedbacklaw = input_output_linearize(sys, reldeg, ref, output=out_idx) 169 | feedback_sys = DynamicStateFeedbackSystem(sys, ref, feedbacklaw) 170 | t = jnp.linspace(0, 1) 171 | u = jnp.sin(t) * 0.1 172 | y_ref = Flow(ref)(t, u)[1] 173 | y = Flow(feedback_sys)(t, u)[1] 174 | npt.assert_allclose(y_ref[:, out_idx], y[:, out_idx], **tols) 175 | 176 | 177 | class Lee7_4_5(AbstractSystem): 178 | initial_state = jnp.zeros(2) 179 | n_inputs = 
"scalar" 180 | 181 | def vector_field(self, x, u, t=None): 182 | x1, x2 = x 183 | return 0.1 * jnp.array([x1 + x1**3 + x2, x2 + x2**3 + u]) 184 | 185 | def output(self, x, u=None, t=None): 186 | return x[0] 187 | 188 | 189 | def test_discrete_input_output_linearize(): 190 | sys = Lee7_4_5() 191 | refsys = sys.linearize() 192 | xs = np.random.normal(size=(100, 2)) 193 | us = np.random.normal(size=100) 194 | reldeg = discrete_relative_degree(sys, xs, us) 195 | assert reldeg == 2 196 | 197 | feedback_sys = DiscreteLinearizingSystem(sys, refsys, reldeg) 198 | t = jnp.linspace(0, 0.001, 10) 199 | u = jnp.cos(t) * 0.1 200 | _, v = Map(feedback_sys)(t, u) 201 | _, y = Map(sys)(t, u) 202 | _, y_ref = Map(refsys)(t, u) 203 | 204 | npt.assert_allclose(y_ref, y, **tols) 205 | -------------------------------------------------------------------------------- /tests/test_systems.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import numpy.testing as npt 4 | 5 | from dynax import FeedbackSystem, LinearSystem, SeriesSystem 6 | 7 | 8 | def test_series(): 9 | n1, m1, p1 = 4, 3, 2 10 | A1 = jnp.array(np.random.randint(-5, 5, size=(n1, n1))) 11 | B1 = jnp.array(np.random.randint(-5, 5, size=(n1, m1))) 12 | C1 = jnp.array(np.random.randint(-5, 5, size=(p1, n1))) 13 | D1 = jnp.array(np.random.randint(-5, 5, size=(p1, m1))) 14 | sys1 = LinearSystem(A1, B1, C1, D1) 15 | n2, m2, p2 = 5, p1, 3 16 | A2 = jnp.array(np.random.randint(-5, 5, size=(n2, n2))) 17 | B2 = jnp.array(np.random.randint(-5, 5, size=(n2, m2))) 18 | C2 = jnp.array(np.random.randint(-5, 5, size=(p2, n2))) 19 | D2 = jnp.array(np.random.randint(-5, 5, size=(p2, m2))) 20 | sys2 = LinearSystem(A2, B2, C2, D2) 21 | sys = SeriesSystem(sys1, sys2) 22 | linsys = sys.linearize() 23 | npt.assert_array_equal( 24 | linsys.A, np.block([[A1, np.zeros((n1, n2))], [B2.dot(C1), A2]]) 25 | ) 26 | npt.assert_array_equal(linsys.B, np.block([[B1], 
[B2.dot(D1)]])) 27 | npt.assert_array_equal(linsys.C, np.block([[D2.dot(C1), C2]])) 28 | npt.assert_array_equal(linsys.D, D2.dot(D1)) 29 | 30 | 31 | def test_feedback(): 32 | n1, m1, p1 = 4, 3, 2 33 | A1 = jnp.array(np.random.randint(-5, 5, size=(n1, n1))) 34 | B1 = jnp.array(np.random.randint(-5, 5, size=(n1, m1))) 35 | C1 = jnp.array(np.random.randint(-5, 5, size=(p1, n1))) 36 | D1 = jnp.array(np.zeros((p1, m1))) 37 | sys1 = LinearSystem(A1, B1, C1, D1) 38 | n2, m2, p2 = 5, p1, 3 39 | A2 = jnp.array(np.random.randint(-5, 5, size=(n2, n2))) 40 | B2 = jnp.array(np.random.randint(-5, 5, size=(n2, m2))) 41 | C2 = jnp.array(np.random.randint(-5, 5, size=(p2, n2))) 42 | D2 = jnp.array(np.random.randint(-5, 5, size=(p2, m2))) 43 | sys2 = LinearSystem(A2, B2, C2, D2) 44 | sys = FeedbackSystem(sys1, sys2) 45 | linsys = sys.linearize() 46 | npt.assert_array_equal( 47 | linsys.A, np.block([[A1 + B1 @ D2 @ C1, B1 @ C2], [B2 @ C1, A2]]) 48 | ) 49 | npt.assert_array_equal(linsys.B, np.block([[B1], [np.zeros((n2, m1))]])) 50 | npt.assert_array_equal(linsys.C, np.block([[C1, np.zeros((p1, n2))]])) 51 | npt.assert_array_equal(linsys.D, np.zeros((p1, m1))) 52 | --------------------------------------------------------------------------------