├── immrax ├── _vendor │ └── __init__.py ├── refinement │ ├── __init__.py │ └── factories.py ├── system │ ├── __init__.py │ ├── trajectory.py │ └── system.py ├── parametric │ ├── __init__.py │ ├── sets │ │ ├── annulus.py │ │ ├── polytope.py │ │ ├── normotope.py │ │ └── ellipsoid.py │ ├── parametope.py │ └── param_reach.py ├── __init__.py ├── inclusion │ ├── __init__.py │ ├── polynomial.py │ ├── cubic_spline.py │ ├── custom_if.py │ └── interval.py ├── control.py ├── utils.py └── embedding.py ├── docs ├── source │ ├── examples │ ├── installation │ │ ├── install-jax.rst │ │ └── install-cyipopt-and-hsl.rst │ ├── api.rst │ ├── modules │ │ ├── utils.rst │ │ ├── neural.rst │ │ ├── system.rst │ │ ├── control.rst │ │ ├── embedding.rst │ │ ├── inclusion.rst │ │ ├── parametric.rst │ │ └── refinement.rst │ ├── examples.rst │ ├── index.rst │ ├── quickstart.ipynb │ ├── conf.py │ └── gettingstarted.md ├── requirements.txt ├── jupyter_execute │ ├── 3e79a9d4b9395e56d83b24723bd7d4ac3f51567e99634fab72ca962ae25368fb.png │ ├── 40633ce9f207092a3a44cadba8641883305d7a8bff83ff44d68815d71eb34610.png │ └── quickstart.ipynb └── make.bat ├── examples ├── vehicle │ ├── 100r100r2 │ │ ├── arch.txt │ │ ├── model.pt │ │ ├── model.eqx │ │ └── model.npy │ ├── vehicle.pdf │ └── show.py ├── pendulum │ ├── figures │ │ ├── pendulum.mp4 │ │ ├── pendulum.pdf │ │ └── frames │ │ │ ├── pendulum_00000.pdf │ │ │ ├── pendulum_00001.pdf │ │ │ ├── pendulum_00002.pdf │ │ │ ├── pendulum_00003.pdf │ │ │ ├── pendulum_00004.pdf │ │ │ ├── pendulum_00005.pdf │ │ │ ├── pendulum_00006.pdf │ │ │ ├── pendulum_00007.pdf │ │ │ ├── pendulum_00008.pdf │ │ │ ├── pendulum_00009.pdf │ │ │ ├── pendulum_00010.pdf │ │ │ ├── pendulum_00011.pdf │ │ │ ├── pendulum_00012.pdf │ │ │ ├── pendulum_00013.pdf │ │ │ ├── pendulum_00014.pdf │ │ │ ├── pendulum_00015.pdf │ │ │ ├── pendulum_00016.pdf │ │ │ ├── pendulum_00017.pdf │ │ │ ├── pendulum_00018.pdf │ │ │ ├── pendulum_00019.pdf │ │ │ ├── pendulum_00020.pdf │ │ │ ├── pendulum_00021.pdf │ │ │ 
├── pendulum_00022.pdf │ │ │ ├── pendulum_00023.pdf │ │ │ ├── pendulum_00024.pdf │ │ │ ├── pendulum_00025.pdf │ │ │ ├── pendulum_00026.pdf │ │ │ ├── pendulum_00027.pdf │ │ │ ├── pendulum_00028.pdf │ │ │ ├── pendulum_00029.pdf │ │ │ ├── pendulum_00030.pdf │ │ │ ├── pendulum_00031.pdf │ │ │ ├── pendulum_00032.pdf │ │ │ ├── pendulum_00033.pdf │ │ │ ├── pendulum_00034.pdf │ │ │ ├── pendulum_00035.pdf │ │ │ ├── pendulum_00036.pdf │ │ │ ├── pendulum_00037.pdf │ │ │ ├── pendulum_00038.pdf │ │ │ ├── pendulum_00039.pdf │ │ │ ├── pendulum_00040.pdf │ │ │ ├── pendulum_00041.pdf │ │ │ ├── pendulum_00042.pdf │ │ │ ├── pendulum_00043.pdf │ │ │ ├── pendulum_00044.pdf │ │ │ ├── pendulum_00045.pdf │ │ │ ├── pendulum_00046.pdf │ │ │ ├── pendulum_00047.pdf │ │ │ ├── pendulum_00048.pdf │ │ │ ├── pendulum_00049.pdf │ │ │ ├── pendulum_00050.pdf │ │ │ ├── pendulum_00051.pdf │ │ │ ├── pendulum_00052.pdf │ │ │ ├── pendulum_00053.pdf │ │ │ ├── pendulum_00054.pdf │ │ │ ├── pendulum_00055.pdf │ │ │ ├── pendulum_00056.pdf │ │ │ ├── pendulum_00057.pdf │ │ │ ├── pendulum_00058.pdf │ │ │ ├── pendulum_00059.pdf │ │ │ ├── pendulum_00060.pdf │ │ │ ├── pendulum_00061.pdf │ │ │ ├── pendulum_00062.pdf │ │ │ ├── pendulum_00063.pdf │ │ │ ├── pendulum_00064.pdf │ │ │ └── pendulum_00065.pdf │ └── README.md ├── compare.py └── auxillary_vars │ ├── CVDP23.py │ └── aux-var.py ├── .gitmodules ├── .readthedocs.yaml ├── tests ├── utils.py ├── test_inclusion.py ├── test_polynomial.py ├── test_system_continuous.py ├── test_system_discrete.py └── test_cubic_spline.py ├── pyproject.toml ├── README.md └── .gitignore /immrax/_vendor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/examples: -------------------------------------------------------------------------------- 1 | ../../examples -------------------------------------------------------------------------------- 
/examples/vehicle/100r100r2/arch.txt: -------------------------------------------------------------------------------- 1 | 4 100 ReLU 100 ReLU 2 -------------------------------------------------------------------------------- /docs/source/installation/install-jax.rst: -------------------------------------------------------------------------------- 1 | Installing JAX 2 | ============== -------------------------------------------------------------------------------- /examples/vehicle/vehicle.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/vehicle/vehicle.pdf -------------------------------------------------------------------------------- /examples/vehicle/100r100r2/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/vehicle/100r100r2/model.pt -------------------------------------------------------------------------------- /examples/vehicle/100r100r2/model.eqx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/vehicle/100r100r2/model.eqx -------------------------------------------------------------------------------- /examples/vehicle/100r100r2/model.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/vehicle/100r100r2/model.npy -------------------------------------------------------------------------------- /examples/pendulum/figures/pendulum.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/pendulum.mp4 -------------------------------------------------------------------------------- /examples/pendulum/figures/pendulum.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/pendulum.pdf -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | === 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: MODULES 7 | :glob: 8 | 9 | modules/* 10 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "immrax/_vendor/jax-verify"] 2 | path = immrax/_vendor/jax_verify 3 | url = git@github.com:Akash-Harapanahalli/jax_verify.git 4 | -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00000.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00001.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00001.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00002.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00002.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00003.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00003.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00004.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00004.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00005.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00005.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00006.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00006.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00007.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00007.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00008.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00008.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00009.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00009.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00010.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00010.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00011.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00011.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00012.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00012.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00013.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00013.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00014.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00014.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00015.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00015.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00016.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00016.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00017.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00017.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00018.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00018.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00019.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00019.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00020.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00020.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00021.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00021.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00022.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00022.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00023.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00023.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00024.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00024.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00025.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00025.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00026.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00026.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00027.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00027.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00028.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00028.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00029.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00029.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00030.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00030.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00031.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00031.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00032.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00032.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00033.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00033.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00034.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00034.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00035.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00035.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00036.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00036.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00037.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00037.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00038.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00038.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00039.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00039.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00040.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00040.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00041.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00041.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00042.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00042.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00043.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00043.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00044.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00044.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00045.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00045.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00046.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00046.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00047.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00047.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00048.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00048.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00049.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00049.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00050.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00050.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00051.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00051.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00052.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00052.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00053.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00053.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00054.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00054.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00055.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00055.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00056.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00056.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00057.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00057.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00058.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00058.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00059.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00059.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00060.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00060.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00061.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00061.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00062.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00062.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00063.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00063.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00064.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00064.pdf -------------------------------------------------------------------------------- /examples/pendulum/figures/frames/pendulum_00065.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/examples/pendulum/figures/frames/pendulum_00065.pdf -------------------------------------------------------------------------------- /docs/source/modules/utils.rst: -------------------------------------------------------------------------------- 1 | immrax.utils 2 | ============ 3 | 4 | .. automodule:: immrax.utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/modules/neural.rst: -------------------------------------------------------------------------------- 1 | immrax.neural 2 | ============= 3 | 4 | .. automodule:: immrax.neural 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/modules/system.rst: -------------------------------------------------------------------------------- 1 | immrax.system 2 | ============= 3 | 4 | .. automodule:: immrax.system 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/modules/control.rst: -------------------------------------------------------------------------------- 1 | immrax.control 2 | ============== 3 | 4 | .. 
automodule:: immrax.control 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/modules/embedding.rst: -------------------------------------------------------------------------------- 1 | immrax.embedding 2 | ================ 3 | 4 | .. automodule:: immrax.embedding 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | # sphinx-rtd-theme 3 | sphinx-book-theme 4 | # nbsphinx 5 | myst-nb 6 | numpydoc 7 | pygments 8 | mock 9 | jax[cpu] 10 | jaxlib 11 | jaxtyping 12 | ./ -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | .. _examples: 2 | 3 | Examples 4 | ======== 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | examples/vehicle/vehicle 10 | examples/pendulum/pendulum 11 | -------------------------------------------------------------------------------- /docs/source/modules/inclusion.rst: -------------------------------------------------------------------------------- 1 | immrax.inclusion 2 | ================ 3 | 4 | .. 
automodule:: immrax.inclusion 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | -------------------------------------------------------------------------------- /immrax/refinement/__init__.py: -------------------------------------------------------------------------------- 1 | from .factories import SampleRefinement, LinProgRefinement, NullVecRefinement 2 | # Public API of immrax.refinement: exactly the three factories imported above. 3 | __all__ = ["SampleRefinement", "LinProgRefinement", "NullVecRefinement"] 4 | -------------------------------------------------------------------------------- /docs/source/modules/parametric.rst: -------------------------------------------------------------------------------- 1 | immrax.parametric 2 | ================= 3 | 4 | .. automodule:: immrax.parametric 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/source/modules/refinement.rst: -------------------------------------------------------------------------------- 1 | immrax.refinement 2 | ================= 3 | 4 | .. 
automodule:: immrax.refinement 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/source/installation/install-cyipopt-and-hsl.rst: -------------------------------------------------------------------------------- 1 | Installing `cyipopt` and `coinhsl` 2 | ================================== 3 | 4 | This is only necessary for running the pendulum example in :ref:`examples/pendulum/pendulum` -------------------------------------------------------------------------------- /docs/jupyter_execute/3e79a9d4b9395e56d83b24723bd7d4ac3f51567e99634fab72ca962ae25368fb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/docs/jupyter_execute/3e79a9d4b9395e56d83b24723bd7d4ac3f51567e99634fab72ca962ae25368fb.png -------------------------------------------------------------------------------- /docs/jupyter_execute/40633ce9f207092a3a44cadba8641883305d7a8bff83ff44d68815d71eb34610.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtfactslab/immrax/HEAD/docs/jupyter_execute/40633ce9f207092a3a44cadba8641883305d7a8bff83ff44d68815d71eb34610.png -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.12" 7 | 8 | python: 9 | install: 10 | - requirements: docs/requirements.txt 11 | 12 | sphinx: 13 | configuration: docs/source/conf.py 14 | -------------------------------------------------------------------------------- /examples/compare.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import immrax 4 | 5 | f = lambda x: jnp.array([(x[0] + x[1]) ** 2, 
jnp.sin(x[0] + x[1] + 2 * x[1] * x[2])]) 6 | Fnat = jax.jit(immrax.natif(f)) 7 | Fjac = jax.jit(immrax.jacif(f)) 8 | Fmix = jax.jit(immrax.mjacif(f)) 9 | x0 = immrax.icentpert(jnp.zeros(2), 0.1) 10 | for F in [Fnat, Fjac, Fmix]: 11 | F(x0) # JIT Compile 12 | ret, times = immrax.utils.run_times(10000, F, x0) 13 | print(ret) 14 | print(f"{times.mean():.3e} \u00b1 {times.std():.3e}") 15 | -------------------------------------------------------------------------------- /immrax/system/__init__.py: -------------------------------------------------------------------------------- 1 | from .system import ( 2 | System, 3 | ReversedSystem, 4 | LinearTransformedSystem, 5 | LiftedSystem, 6 | OpenLoopSystem, 7 | ) 8 | from .trajectory import ( 9 | RawTrajectory, 10 | RawContinuousTrajectory, 11 | RawDiscreteTrajectory, 12 | Trajectory, 13 | ContinuousTrajectory, 14 | DiscreteTrajectory, 15 | ) 16 | 17 | __all__ = [ 18 | "System", 19 | "ReversedSystem", 20 | "LinearTransformedSystem", 21 | "LiftedSystem", 22 | "OpenLoopSystem", 23 | "RawTrajectory", 24 | "RawContinuousTrajectory", 25 | "RawDiscreteTrajectory", 26 | "Trajectory", 27 | "ContinuousTrajectory", 28 | "DiscreteTrajectory", 29 | ] 30 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Interval Analysis and Mixed Monotone Reachability in JAX 2 | ======================================================== 3 | 4 | `immrax` is a tool for interval analysis and mixed monotone reachability analysis in JAX. 5 | 6 | Inclusion function transformations are composable with existing JAX transformations, allowing the use of Automatic Differentiation to learn relationships between inputs and outputs, as well as parallelization and GPU capabilities for quick, accurate reachable set estimation. 7 | 8 | .. 
toctree:: 9 | :maxdepth: 2 10 | :caption: Contents: 11 | 12 | gettingstarted 13 | examples 14 | api 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /immrax/parametric/__init__.py: -------------------------------------------------------------------------------- 1 | from .parametope import ( 2 | Parametope, 3 | hParametope, 4 | ) 5 | from .param_reach import ( 6 | ParametopeEmbedding, 7 | AdjointEmbedding, 8 | FastlinAdjointEmbedding, 9 | ) 10 | 11 | from .sets.ellipsoid import ( 12 | Ellipsoid, 13 | ) 14 | from .sets.polytope import ( 15 | Polytope, 16 | ) 17 | 18 | # from .sets.annulus import ( 19 | # LpAnnulus, 20 | # ) 21 | from .sets.normotope import ( 22 | Normotope, 23 | LinfNormotope, 24 | L2Normotope, 25 | ) 26 | 27 | __all__ = [ 28 | "Parametope", 29 | "hParametope", 30 | "ParametopeEmbedding", 31 | "AdjointEmbedding", 32 | "FastlinAdjointEmbedding", 33 | "Ellipsoid", 34 | "Polytope", 35 | "Normotope", 36 | "LinfNormotope", 37 | "L2Normotope", 38 | ] 39 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /examples/vehicle/show.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import immrax as irx 3 | 4 | net = irx.NeuralNetwork("100r100r2") 5 | 6 | 7 | class Vehicle(irx.OpenLoopSystem): 8 | def __init__(self) -> None: 9 | self.evolution = "continuous" 10 | self.xlen = 4 11 | 12 | def f( 13 | self, t: jnp.ndarray, x: jnp.ndarray, u: jnp.ndarray, w: jnp.ndarray 14 | ) -> jnp.ndarray: 15 | px, py, psi, v = x.ravel() 16 | u1, u2 = u.ravel() 17 | beta = jnp.arctan(jnp.tan(u2) / 2) 18 | return jnp.array( 19 | [v * jnp.cos(psi + beta), v * jnp.sin(psi + beta), v * jnp.sin(beta), u1] 20 | ) 21 | 22 | 23 | olsys = Vehicle() 24 | net = irx.NeuralNetwork("100r100r2") 25 | clsys = irx.ControlledSystem(olsys, net) 26 | 27 | 28 | print(net(jnp.zeros(4))) 29 | 30 | crown_net = irx.crown(net) 31 | fastlin_net = irx.fastlin(net) 32 | ix = irx.icentpert(jnp.zeros(4), 1.0) 33 | print(f"{ix=}") 34 | print(crown_net(ix)) 35 | res = fastlin_net(ix) 36 | print(res.C) 37 | -------------------------------------------------------------------------------- /immrax/__init__.py: -------------------------------------------------------------------------------- 1 | from . import inclusion as inclusion 2 | from .inclusion import * 3 | 4 | from . import system as system 5 | from .system import * 6 | 7 | from . 
import control as control 8 | from .control import * 9 | 10 | import sys 11 | import os 12 | 13 | jax_verify_path = os.path.join( 14 | os.path.dirname(__file__), 15 | "_vendor", 16 | "jax_verify", 17 | ) 18 | try: 19 | sys.path.insert(0, jax_verify_path) 20 | import jax_verify 21 | 22 | from . import neural as neural 23 | from .neural import * 24 | 25 | from . import parametric as parametric 26 | from .parametric import * 27 | except ImportError: 28 | print( 29 | "WARN (immrax): Failed to import jax_verify. Some neural and parametric features may not be available." 30 | ) 31 | print("WARN (immrax): Did you remember to initialize all git submodules?") 32 | finally: 33 | sys.path.remove(jax_verify_path) 34 | 35 | from . import embedding as embedding 36 | from .embedding import * 37 | 38 | 39 | from . import refinement as refinement 40 | from . import utils as utils 41 | -------------------------------------------------------------------------------- /docs/source/quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Installation and Quickstart\n", 8 | "\n", 9 | "## JAX\n", 10 | "For a quick, CPU only install,\n", 11 | "```\n", 12 | "pip install --upgrade jax[cpu]\n", 13 | "```\n", 14 | "For a GPU install, follow instructions from [https://jax.readthedocs.io/en/latest/installation.html](https://jax.readthedocs.io/en/latest/installation.html). \n", 15 | "For a local CUDA installation, this link may be helpful [https://gist.github.com/denguir/b21aa66ae7fb1089655dd9de8351a202](https://gist.github.com/denguir/b21aa66ae7fb1089655dd9de8351a202).\n", 16 | "\n", 17 | "## Install `immrax`\n", 18 | "\n", 19 | "For now, `immrax` can be installed by directly cloning the repository. 
We plan to upload the latest stable version to PyPI.\n", 20 | "```\n", 21 | "git clone https://github.com/gtfactslab/immrax.git\n", 22 | "cd immrax && pip install .\n", 23 | "```\n" 24 | ] 25 | } 26 | ], 27 | "metadata": { 28 | "language_info": { 29 | "name": "python" 30 | } 31 | }, 32 | "nbformat": 4, 33 | "nbformat_minor": 2 34 | } 35 | -------------------------------------------------------------------------------- /docs/jupyter_execute/quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Installation and Quickstart\n", 8 | "\n", 9 | "## JAX\n", 10 | "For a quick, CPU only install,\n", 11 | "```\n", 12 | "pip install --upgrade jax[cpu]\n", 13 | "```\n", 14 | "For a GPU install, follow instructions from [https://jax.readthedocs.io/en/latest/installation.html](https://jax.readthedocs.io/en/latest/installation.html). \n", 15 | "For a local CUDA installation, this link may be helpful [https://gist.github.com/denguir/b21aa66ae7fb1089655dd9de8351a202](https://gist.github.com/denguir/b21aa66ae7fb1089655dd9de8351a202).\n", 16 | "\n", 17 | "## Install `immrax`\n", 18 | "\n", 19 | "For now, `immrax` can be installed by directly cloning the repository. We plan to upload the latest stable version to PyPI.\n", 20 | "```\n", 21 | "git clone https://github.com/gtfactslab/immrax.git\n", 22 | "cd immrax && pip install .\n", 23 | "```\n" 24 | ] 25 | } 26 | ], 27 | "metadata": { 28 | "language_info": { 29 | "name": "python" 30 | } 31 | }, 32 | "nbformat": 4, 33 | "nbformat_minor": 2 34 | } -------------------------------------------------------------------------------- /examples/pendulum/README.md: -------------------------------------------------------------------------------- 1 | # Pendulum Optimal Control 2 | 3 | ## `Immrax` Set-Up 4 | 5 | Please ensure you have read and completed the set-up instructions in the [main project README](../../README.md).
6 | 7 | ## Installing `cyipopt` and `coinhsl` 8 | 9 | If you would like to run this example, you need to install IPOPT and the MA57 linear solver from HSL. 10 | 11 | To use the MA57 solver, you'll first need to acquire a package from [HSL](https://www.hsl.rl.ac.uk/). While there are instructions [here](https://cyipopt.readthedocs.io/en/stable/install.html#conda-forge-binaries-with-hsl), we highly recommend using [ThirdParty-HSL](https://github.com/coin-or-tools/ThirdParty-HSL) instead to install HSL globally. 12 | Then, use a symbolic link to help the `conda` environment locate it. 13 | 14 | ```shell 15 | ln -s /usr/local/lib/libcoinhsl.so $CONDA_PREFIX/lib/libcoinhsl.so 16 | ``` 17 | 18 | Finally, install `cyipopt` into your `immrax` conda environment (more instructions [here](https://cyipopt.readthedocs.io/en/stable/install.html)). 19 | 20 | ```shell 21 | conda install -c conda-forge cyipopt 22 | ``` 23 | 24 | This command can take a while to fully resolve. 25 | -------------------------------------------------------------------------------- /immrax/inclusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .interval import ( 2 | Interval, 3 | interval, 4 | icopy, 5 | icentpert, 6 | i2centpert, 7 | i2lu, 8 | lu2i, 9 | i2ut, 10 | ut2i, 11 | iconcatenate, 12 | izeros, 13 | interval_intersect, 14 | interval_union, 15 | ) 16 | 17 | from .
import nif as nif 18 | from .nif import natif 19 | from .custom_if import custom_if 20 | 21 | from .jacobian import ( 22 | jacif, 23 | jacM, 24 | Permutation, 25 | standard_permutation, 26 | two_permutations, 27 | all_permutations, 28 | Corner, 29 | bot_corner, 30 | top_corner, 31 | two_corners, 32 | all_corners, 33 | get_corner, 34 | get_corners, 35 | get_sparse_corners, 36 | mjacif, 37 | mjacM, 38 | ) 39 | 40 | __all__ = [ 41 | "Interval", 42 | "interval", 43 | "icopy", 44 | "icentpert", 45 | "i2centpert", 46 | "i2lu", 47 | "lu2i", 48 | "i2ut", 49 | "ut2i", 50 | "iconcatenate", 51 | "izeros", 52 | "interval_intersect", 53 | "interval_union", 54 | "nif", 55 | "natif", 56 | "jacM", 57 | "jacif", 58 | "custom_if", 59 | "Permutation", 60 | "standard_permutation", 61 | "two_permutations", 62 | "all_permutations", 63 | "Corner", 64 | "bot_corner", 65 | "top_corner", 66 | "two_corners", 67 | "all_corners", 68 | "get_corner", 69 | "get_corners", 70 | "get_sparse_corners", 71 | "mjacif", 72 | "mjacM", 73 | ] 74 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | def validate_overapproximation_nd(func, input_interval, output_interval): 6 | """ 7 | Validates that the output_interval overapproximates the true values of the 8 | function over the input_interval by sampling along the diagonal for N-D intervals. 
9 | """ 10 | # Generate 100 sample points along the line from lower to upper bound 11 | sample_points = jnp.linspace(input_interval.lower, input_interval.upper, 100) 12 | 13 | # Apply the original function to these sample points 14 | true_values = func(sample_points) 15 | 16 | # Check that all resulting values are within the computed interval bounds 17 | # Note: This check is along a line, not the full interval hyper-rectangle for vector inputs 18 | assert jnp.all(true_values >= output_interval.lower) 19 | assert jnp.all(true_values <= output_interval.upper) 20 | 21 | 22 | def validate_overapproximation_1d_list(func, input_interval, output_interval): 23 | """ 24 | Validates that the output_interval overapproximates the true values of the 25 | function over a list of 1D input_intervals by sampling. 26 | """ 27 | for i in range(len(input_interval.lower)): 28 | sample_points = jnp.linspace( 29 | input_interval.lower[i], input_interval.upper[i], 100 30 | ) 31 | true_values = jax.vmap(func)(sample_points) 32 | assert jnp.all(true_values >= output_interval.lower[i].squeeze()) 33 | assert jnp.all(true_values <= output_interval.upper[i].squeeze()) 34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.metadata] 6 | allow-direct-references = true 7 | 8 | [tool.ruff] 9 | 10 | exclude = [ 11 | "immrax/_vendor/**" 12 | ] 13 | 14 | [tool.ruff.lint] 15 | 16 | ignore = ["E731"] # We like lambda functions 17 | 18 | [project] 19 | name = "immrax" 20 | version = "0.3.4" 21 | readme = {file = "README.md", content-type = "text/markdown"} 22 | 23 | dependencies = [ 24 | "numpy ~= 2.3.2", 25 | "scipy ~= 1.16.1", 26 | "shapely ~= 2.0.6", 27 | "matplotlib ~= 3.9.2", 28 | "pypoman ~= 1.1.0", 29 | "tabulate ~= 0.9.0", 30 | "jax ~= 0.6.1", 31 | "jaxtyping ~= 
0.2.33", 32 | "diffrax ~= 0.7.0", 33 | "equinox ~= 0.12.2", 34 | "immutabledict ~= 4.2.0", 35 | "linrax ~= 0.1.0", 36 | # jax_verify dependencies. 37 | "absl-py", 38 | "cvxpy", 39 | "dm-tree", 40 | "optax", 41 | "dm-haiku", 42 | "einshape", 43 | "ml_collections", 44 | ] 45 | 46 | [project.urls] 47 | "Homepage" = "https://github.com/gtfactslab/immrax" 48 | "Documentation" = "https://immrax.readthedocs.io" 49 | "Bug Tracker" = "https://github.com/gtfactslab/immrax/issues" 50 | "Source Code" = "https://github.com/gtfactslab/immrax" 51 | 52 | [project.optional-dependencies] 53 | cuda = [ 54 | "jax[cuda12] ~= 0.6.1", 55 | "linrax[cuda]", 56 | ] 57 | 58 | examples = [ 59 | "casadi ~= 3.6.7", 60 | "control ~= 0.10.1", 61 | # "cyipopt ~= 1.5.0", # This is a pain to install manually, better to do with conda 62 | "ipykernel ~= 6.29.5", 63 | "ipympl ~= 0.9.7" 64 | ] 65 | -------------------------------------------------------------------------------- /immrax/parametric/sets/annulus.py: -------------------------------------------------------------------------------- 1 | from ..parametope import hParametope 2 | import jax.numpy as jnp 3 | from jaxtyping import ArrayLike 4 | 5 | # import transforms from matplotlib 6 | from jax.tree_util import register_pytree_node_class 7 | from ...inclusion import Interval, interval, icentpert 8 | 9 | 10 | @register_pytree_node_class 11 | class LpAnnulus(hParametope): 12 | p: float 13 | 14 | def __init__(self, ox, H, ly, uy, p=2.0): 15 | # ly = jnp.zeros_like(uy) if uy is not None else None 16 | super().__init__(ox, [H], [ly], [uy]) 17 | self.p = p 18 | 19 | @classmethod 20 | def from_parametope(cls, pt: hParametope): 21 | return LpAnnulus(pt.ox, pt.alpha, pt.y) 22 | 23 | def g(self, i: int, a: ArrayLike): 24 | if i != 0: 25 | raise Exception(f"Ellipsoid has only one constraint, got {i=}") 26 | return jnp.sum(jnp.abs(a) ** self.p) ** (1 / self.p) 27 | 28 | def ginv(self, i: int, iy: Interval): 29 | # Returns a box containing the preimage of the 
constraint over iy 30 | 31 | if i != 0: 32 | raise Exception(f"Annulus has only one constraint, got {i=}") 33 | 34 | n = len(self.ox) 35 | 36 | # |x|_inf \leq |x|_p \leq n^{1/p} |x|_inf 37 | return icentpert(jnp.zeros(n), iy.upper * jnp.ones(n)) 38 | 39 | def iover(self): 40 | return self.ginv(self.H @ interval(self.ly, self.uy)) 41 | 42 | @property 43 | def P(self): 44 | return self.H[0].T @ self.H[0] 45 | 46 | def V(self, x: ArrayLike): 47 | return self.g(0, self.H[0] @ (x - self.ox)) 48 | 49 | # def plot_projection (self, ax, xi=0, yi=1, rescale=False, **kwargs) : 50 | 51 | def __repr__(self): 52 | return f"Ellipsoid(ox={self.ox}, H={self.H}, uy={self.uy})" 53 | 54 | def __str__(self): 55 | return f"Ellipsoid(ox={self.ox}, H={self.H}, uy={self.uy})" 56 | -------------------------------------------------------------------------------- /immrax/inclusion/polynomial.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import core 3 | import jax.numpy as jnp 4 | from immrax.inclusion.interval import interval, Interval 5 | from immrax.inclusion import nif, custom_if 6 | 7 | 8 | @custom_if 9 | def polynomial(a, x): 10 | return jnp.polyval(a, x) 11 | 12 | 13 | @polynomial.defif 14 | def polynomial_inclusion(a, x): 15 | # if not isinstance(a, Interval) or jnp.allclose(a.lower, a.upper) : 16 | # TODO: This only works for constant coefficients. Try to take inclusion for x, natif into inclusion for both 17 | if True: 18 | # TODO: Can we make this static wrt a somehow, to avoid recomputing the critical points? 19 | """Minimal inclusion function for constant coefficient a. 20 | 21 | ulf = min_{x \\in [ulx, olx]} f(a, x) 22 | olf = max_{x \\in [ulx, olx]} f(a, x) 23 | 24 | Since f is a polynomial, check critical points and endpoints. 
25 | """ 26 | if isinstance(a, Interval): 27 | a = a.lower 28 | 29 | ad = jnp.polyder(a) 30 | ad_roots = jnp.roots(ad, strip_zeros=False) 31 | # Critical points and values 32 | # crit = jnp.real(ad_roots[jnp.isreal(ad_roots)]) 33 | # crit = ad_roots[jnp.where(jnp.isreal(ad_roots))] 34 | # crit = ad_roots 35 | 36 | crit_vals = jnp.real(jax.vmap(jnp.polyval, in_axes=(None, 0))(a, ad_roots)) 37 | crit_in_x = jax.vmap( 38 | lambda crit: jnp.logical_and( 39 | jnp.logical_and( 40 | crit > jnp.atleast_1d(x.lower), crit < jnp.atleast_1d(x.upper) 41 | ), 42 | jnp.isreal(crit), 43 | ) 44 | )(ad_roots) 45 | 46 | end_vals = jnp.array( 47 | [ 48 | jnp.polyval(a, jnp.atleast_1d(x.lower)), 49 | jnp.polyval(a, jnp.atleast_1d(x.upper)), 50 | ] 51 | ) 52 | 53 | # print(f"{ad_roots.shape=}") 54 | # print(f"{end_vals.shape=}") 55 | # print(f"{crit_vals[:, None].shape=},\n{crit_in_x.shape=}") 56 | l_vals = jnp.concatenate( 57 | (end_vals, jnp.where(crit_in_x, crit_vals[:, None], jnp.inf)) 58 | ) 59 | u_vals = jnp.concatenate( 60 | (end_vals, jnp.where(crit_in_x, crit_vals[:, None], -jnp.inf)) 61 | ) 62 | 63 | return interval(jnp.min(l_vals, axis=0), jnp.max(u_vals, axis=0)) 64 | 65 | else: 66 | """Otherwise, simply use natural inclusion function.""" 67 | print("Using natural inclusion function for polynomial primitive.") 68 | return nif.natif(polynomial_impl)(a, x) 69 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import os 10 | import sys 11 | 12 | sys.path.insert(0, os.path.abspath("../..")) 13 | 14 | from mock import Mock as MagicMock 15 | 16 | 17 | class Mock(MagicMock): 18 | @classmethod 19 | def __getattr__(cls, name): 20 | return MagicMock() 21 | 22 | 23 | # MOCK_MODULES = ['jax', 'jax.lax', 'jax.numpy', 'jax.core', 'jax.experimental.compilation_cache', 'jax._src', 'jax._src.util', 'jax._src.api', 'jax._src.traceback_util', 'jax.tree_util', 24 | # 'jax.typing', 'jaxtyping', 'sympy', 'jax_verify', 'jax_verify.src', 'jax_verify.src.linear', 'sympy2jax', 'diffrax', 'equinox', 'equinox.nn', 25 | # 'numpy', 'shapely', 'shapely.geometry', 'shapely.ops'] 26 | # MOCK_MODULES = ['sympy', 'jax_verify', 'jax_verify.src', 'jax_verify.src.linear', 'sympy2jax', 'diffrax', 'equinox', 'equinox.nn', 27 | # 'numpy', 'shapely', 'shapely.geometry', 'shapely.ops'] 28 | MOCK_MODULES = ["jax_verify", "jax_verify.src", "jax_verify.src.linear"] 29 | # MOCK_MODULES = [] 30 | sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) 31 | 32 | project = "immrax" 33 | copyright = "2023, Akash Harapanahalli" 34 | author = "Akash Harapanahalli" 35 | 36 | # -- General configuration --------------------------------------------------- 37 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 38 | 39 | extensions = [ 40 | "sphinx.ext.autodoc", 41 | "sphinx.ext.mathjax", 42 | # 'nbsphinx', 43 | "myst_nb", 44 | "numpydoc", 45 | ] 46 | 47 | templates_path = ["_templates"] 48 | exclude_patterns = [] 49 | 50 | autoclass_content = "both" 51 | 52 | autodoc_member_order = "bysource" 53 | 54 | nb_execution_mode = "off" 55 | 56 | # 
myst_enable_extensions=["dollarmath","amsmath"] 57 | # myst_enable_extensions = ["dollarmath"] 58 | # nb_myst_enable_extensions = ["dollarmath"] 59 | myst_enable_extensions = [ 60 | "amsmath", 61 | "colon_fence", 62 | "deflist", 63 | "dollarmath", 64 | "html_image", 65 | ] 66 | myst_url_schemes = ("http", "https", "mailto") 67 | # myst_update_mathjax = False 68 | 69 | # -- Options for HTML output ------------------------------------------------- 70 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 71 | 72 | # html_theme = 'alabaster' 73 | # html_theme = 'sphinx_rtd_theme' 74 | html_theme = "sphinx_book_theme" 75 | html_static_path = ["_static"] 76 | -------------------------------------------------------------------------------- /examples/auxillary_vars/CVDP23.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as onp 3 | from pypoman import plot_polygon 4 | from scipy.spatial import HalfspaceIntersection 5 | import matplotlib.pyplot as plt 6 | import immrax as irx 7 | from immrax.utils import angular_sweep, run_times 8 | from immrax.embedding import AuxVarEmbedding 9 | 10 | 11 | class CVDP(irx.System): 12 | def __init__(self, mu: float = 1) -> None: 13 | self.evolution = "continuous" 14 | self.xlen = 5 15 | self.mu = mu 16 | 17 | def f(self, t, x: jnp.ndarray) -> jnp.ndarray: 18 | x1, y1, x2, y2, b = x.ravel() 19 | return jnp.array( 20 | [ 21 | y1, 22 | self.mu * (1 - x1**2) * y1 + b * (x2 - x1) - x1, 23 | y2, 24 | self.mu * (1 - x2**2) * y2 - b * (x2 - x1) - x2, 25 | 0, 26 | ] 27 | ) 28 | 29 | 30 | x0 = irx.interval( 31 | jnp.array([1.25, 2.35, 1.25, 2.35, 1]), jnp.array([1.55, 2.45, 1.55, 2.45, 3]) 32 | ) 33 | N = 2 # WARN: don't choose odd numbers here, makes redundant aux vars 34 | sweep = angular_sweep(N) 35 | couplings = [(0, 1), (2, 3), (1, 2), (0, 4)] 36 | H = jnp.eye(5) 37 | for coupling in couplings: 38 | permuted_sweep = jnp.zeros([N, 
len(x0)]) 39 | for var, idx in enumerate(coupling): 40 | permuted_sweep = permuted_sweep.at[:, idx].set(sweep[:, var]) 41 | H = jnp.vstack([H, permuted_sweep]) 42 | 43 | 44 | x0_lifted = irx.interval(H) @ x0 45 | t0 = 0.0 46 | tf = 2.0 47 | 48 | sys = CVDP() 49 | embsys = AuxVarEmbedding(sys, H) 50 | print("Compiling...") 51 | traj = embsys.compute_trajectory(t0, 0.01, irx.i2ut(x0_lifted)) 52 | print("Compiled.\nComputing trajectory...") 53 | traj, time = run_times(1, embsys.compute_trajectory, t0, tf, irx.i2ut(x0_lifted)) 54 | traj = traj.to_convenience() 55 | print( 56 | f"Computing trajectory took {time.item():.4g} s, {((tf - t0) / 0.01 / time).item():.4g} it/s" 57 | ) 58 | ys_int = [irx.ut2i(y) for y in traj.ys] 59 | print(f"Final bound:\n{ys_int[-1][:5]}") 60 | 61 | # Plot the trajectory 62 | axs = [] 63 | for var_pair in range(2): 64 | plt.figure() 65 | plt.axhline(y=2.75, color="red", linestyle="--") 66 | plt.xlabel(f"x{var_pair + 1}") 67 | plt.ylabel(f"y{var_pair + 1}") 68 | axs.append(plt.gca()) 69 | 70 | for bound in ys_int: 71 | cons = onp.hstack( 72 | ( 73 | onp.vstack((-H, H)), 74 | onp.concatenate((bound.lower, -bound.upper)).reshape(-1, 1), 75 | ) 76 | ) 77 | hs = HalfspaceIntersection(cons, bound.center[0:5]) 78 | vertices = hs.intersections 79 | 80 | plt.sca(axs[0]) 81 | plot_polygon(vertices[:, 0:2], fill=False, resize=True, color="tab:blue") 82 | plt.sca(axs[1]) 83 | plot_polygon(vertices[:, 2:4], fill=False, resize=True, color="tab:blue") 84 | 85 | plt.show() 86 | -------------------------------------------------------------------------------- /docs/source/gettingstarted.md: -------------------------------------------------------------------------------- 1 | (Getting Started)= 2 | 3 | # Getting Started 4 | 5 | ## Dependencies 6 | 7 | `immrax` depends on the library `pypoman`, which internally uses `pycddlib` as a wrapper around [the cdd library](https://people.inf.ethz.ch/fukudak/cdd_home/). 
For this wrapper to function properly, you must install `cdd` on your system. On Ubuntu, the relevant packages can be installed with 8 | 9 | ```bash 10 | apt-get install -y libcdd-dev libgmp-dev 11 | ``` 12 | 13 | On Arch Linux, you can use 14 | 15 | ```bash 16 | pacman -S cddlib 17 | ``` 18 | 19 | 20 | ## Installation 21 | 22 | ### Setting up a `conda` environment 23 | 24 | We recommend installing JAX and `immrax` into a `conda` environment ([miniconda](https://docs.conda.io/projects/miniconda/en/latest/)). 25 | 26 | ```shell 27 | conda create -n immrax python=3.12 28 | conda activate immrax 29 | ``` 30 | 31 | ### Installing immrax 32 | 33 | For now, manually clone the GitHub repository and `pip install` it. We plan to release a stable version on PyPI soon. 34 | 35 | ```shell 36 | git clone https://github.com/gtfactslab/immrax.git 37 | cd immrax 38 | pip install . 39 | ``` 40 | 41 | 42 | If you have CUDA-enabled hardware you wish to utilize, please install the `cuda` optional dependency group. 43 | 44 | ```shell 45 | ... 46 | pip install .[cuda] 47 | ``` 48 | 49 | To test if the installation process worked, run the `compare.py` example. 50 | 51 | ```shell 52 | cd examples 53 | python compare.py 54 | ``` 55 | 56 | This should return the outputs of different inclusion functions as well as their runtimes. 57 | 58 | 59 | ## Citation 60 | 61 | If you find this library useful, please cite our paper with the following BibTeX entry.
62 | 63 | ``` 64 | @article{immrax, 65 | title = {immrax: A Parallelizable and Differentiable Toolbox for Interval Analysis and Mixed Monotone Reachability in {JAX}}, 66 | journal = {IFAC-PapersOnLine}, 67 | volume = {58}, 68 | number = {11}, 69 | pages = {75-80}, 70 | year = {2024}, 71 | note = {8th IFAC Conference on Analysis and Design of Hybrid Systems ADHS 2024}, 72 | issn = {2405-8963}, 73 | doi = {https://doi.org/10.1016/j.ifacol.2024.07.428}, 74 | url = {https://www.sciencedirect.com/science/article/pii/S2405896324005275}, 75 | author = {Akash Harapanahalli and Saber Jafarpour and Samuel Coogan}, 76 | keywords = {Interval analysis, Reachability analysis, Automatic differentiation, Parallel computation, Computational tools, Optimal control, Robust control}, 77 | abstract = {We present an implementation of interval analysis and mixed monotone interval reachability analysis as function transforms in Python, fully composable with the computational framework JAX. The resulting toolbox inherits several key features from JAX, including computational efficiency through Just-In-Time Compilation, GPU acceleration for quick parallelized computations, and Automatic Differentiability We demonstrate the toolbox’s performance on several case studies, including a reachability problem on a vehicle model controlled by a neural network, and a robust closed-loop optimal control problem for a swinging pendulum.} 78 | } 79 | ``` 80 | -------------------------------------------------------------------------------- /immrax/control.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import jax 3 | from jaxtyping import Integer, Float 4 | from typing import Union 5 | from immrax.system import System, OpenLoopSystem 6 | 7 | __all__ = [ 8 | "Control", 9 | "ControlledSystem", 10 | "LinearControl", 11 | ] 12 | 13 | 14 | class Control(abc.ABC): 15 | """Control 16 | A feedback controller of the form 
:math:`u:\\mathbb{R}\\times\\mathbb{R}^n\\to\\mathbb{R}^p`. 17 | """ 18 | 19 | @abc.abstractmethod 20 | def u(self, t: Union[Integer, Float], x: jax.Array) -> jax.Array: 21 | """Feedback Control Output 22 | 23 | Parameters 24 | ---------- 25 | t:Union[Integer, Float] : 26 | 27 | x:jax.Array : 28 | 29 | 30 | Returns 31 | ------- 32 | 33 | """ 34 | 35 | 36 | class LinearControl(Control): 37 | K: jax.Array 38 | 39 | def __init__(self, K: jax.Array) -> None: 40 | self.K = K 41 | 42 | def u(self, t: Union[Integer, Float], x: jax.Array) -> jax.Array: 43 | return self.K @ x 44 | 45 | 46 | class ControlledSystem(System): 47 | """ControlledSystem 48 | A closed-loop nonlinear dynamical system of the form 49 | 50 | .. math:: 51 | 52 | \\dot{x} = f^{\\textsf{c}}(x,w) = f(x,N(x),w), 53 | 54 | where :math:`N:\\mathbb{R}^n \\to \\mathbb{R}^p`. 55 | """ 56 | 57 | olsystem: OpenLoopSystem 58 | control: Control 59 | 60 | def __init__(self, olsystem: OpenLoopSystem, control: Control) -> None: 61 | self.olsystem = olsystem 62 | self.control = control 63 | self.evolution = olsystem.evolution 64 | self.xlen = olsystem.xlen 65 | 66 | def f(self, t: Union[Integer, Float], x: jax.Array, w: jax.Array) -> jax.Array: 67 | """Returns the value of the closed loop system 68 | 69 | Parameters 70 | ---------- 71 | t : Union[Integer, Float] 72 | time value 73 | x : jax.Array 74 | state value 75 | w : jax.Array 76 | disturbance value 77 | 78 | Returns 79 | ------- 80 | jax.Array 81 | :math:`f^{\\textsf{c}}(x,w) = f(x,N(x),w)` 82 | 83 | """ 84 | # x = jnp.asarray(x); w = jnp.asarray(w) 85 | return self.olsystem.f(t, x, self.control.u(t, x), w) 86 | 87 | 88 | # class FOHControlledSystem (ControlledSystem) : 89 | # """FOHControlledSystem 90 | # A system in closed-loop with a First Order Hold Controller of the form 91 | 92 | # .. math:: 93 | 94 | # \\dot{x} = f^{\\textsf{c}}(x,w) = f(x,N(x(\\tau)),w), 95 | 96 | # where :math:`N:\\mathbb{R}^n \\to \\mathbb{R}^p`. 
The functon `step` is used to set :math:`x(\\tau)`. 97 | # """ 98 | 99 | # ut: jax.Array 100 | # def __init__(self, system: System, control: Control) -> None: 101 | # super().__init__(system, control) 102 | 103 | # def step(self, x:jax.Array) -> None : 104 | # self.ut = self.control(x) 105 | 106 | # def fc(self, x: jax.Array, w: jax.Array) -> jax.Array: 107 | # return self.system.f(x, self.ut, w) 108 | -------------------------------------------------------------------------------- /immrax/parametric/parametope.py: -------------------------------------------------------------------------------- 1 | from jax.tree_util import register_pytree_node_class 2 | import jax.numpy as jnp 3 | from jaxtyping import ArrayLike 4 | from ..inclusion import Interval 5 | 6 | 7 | @register_pytree_node_class 8 | class Parametope: 9 | r"""Parametope. Defines the set 10 | 11 | .. math:: 12 | {x : g(\alpha, x - \mathring{x}) <= y} 13 | 14 | """ 15 | 16 | ox: ArrayLike # Center 17 | alpha: ArrayLike # Parameters 18 | y: ArrayLike # Offset 19 | 20 | def __init__(self, ox, alpha, y): 21 | self.ox = ox 22 | self.alpha = alpha 23 | self.y = y 24 | 25 | def g(self, x: ArrayLike): 26 | r"""Evaluates the nonlinearity :math:`g(\alpha, x - \mathring{x})` at x 27 | 28 | Parameters 29 | ---------- 30 | alpha : ArrayLike 31 | _description_ 32 | x : ArrayLike 33 | _description_ 34 | """ 35 | raise NotImplementedError("Subclasses must implement the g method.") 36 | 37 | # Always flatten parametope data into (ox, alpha, y) 38 | def tree_flatten(self): 39 | return ((self.ox, self.alpha, self.y), type(self).__name__) 40 | 41 | # Override in subclasses to unpack the flattened data 42 | @classmethod 43 | def from_parametope(cls, pt: "Parametope"): 44 | return pt 45 | 46 | @classmethod 47 | def tree_unflatten(cls, aux_data, children): 48 | return cls.from_parametope(Parametope(*children)) 49 | 50 | @property 51 | def dtype(self) -> jnp.dtype: 52 | return self.ox.dtype 53 | 54 | def __str__(self): 55 | return 
f"Parametope(ox={self.ox}, alpha={self.alpha}, y={self.y})" 56 | 57 | 58 | @register_pytree_node_class 59 | class hParametope(Parametope): 60 | r"""Defines a parametope with the particular structured nonlinearity 61 | 62 | .. math:: 63 | g(\alpha, x - \mathring{x}) = (-h(\alpha (x - \mathring{x})), h(\alpha (x - \mathring{x}))) 64 | 65 | and y split into lower and upper bounds y = (ly, uy). 66 | """ 67 | 68 | def h(self, z: ArrayLike): 69 | """Evaluates the nonlinearity h at z 70 | 71 | Parameters 72 | ---------- 73 | z : ArrayLike 74 | Input to the nonlinearity 75 | """ 76 | pass 77 | 78 | def g(self, x: ArrayLike): 79 | """Evaluates the nonlinearity g at alpha, x 80 | 81 | Parameters 82 | ---------- 83 | z : ArrayLike 84 | Input to the nonlinearity 85 | """ 86 | return ( 87 | -self.h(jnp.dot(self.alpha, x - self.ox)), 88 | self.h(jnp.dot(self.alpha, x - self.ox)), 89 | ) 90 | 91 | def hinv(self, iy: Interval): 92 | """Overapproximating inverse image of the nonlinearity h 93 | 94 | Parameters 95 | ---------- 96 | iy : ArrayLike 97 | _description_ 98 | """ 99 | pass 100 | 101 | def k_face(self, k: int) -> Interval: 102 | """Overapproximate the k-face of the hParametope""" 103 | pass 104 | 105 | # Override in subclasses to unpack the flattened data 106 | @classmethod 107 | def from_parametope(cls, pt: "hParametope"): 108 | return pt 109 | 110 | @classmethod 111 | def tree_unflatten(cls, aux_data, children): 112 | return cls.from_parametope(hParametope(*children)) 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # immrax 2 | 3 | `immrax` is a tool for interval analysis and mixed monotone reachability analysis in JAX. 
4 | 5 | Inclusion function transformations are composable with existing JAX transformations, allowing the use of Automatic Differentiation to learn relationships between inputs and outputs, as well as parallelization and GPU capabilities for quick, accurate reachable set estimation. 6 | 7 | For more information, please see the full [documentation](https://immrax.readthedocs.io). 8 | 9 | ## Dependencies 10 | 11 | `immrax` depends on the library `pypoman`, which internally uses `pycddlib` as a wrapper around [the cdd library](https://people.inf.ethz.ch/fukudak/cdd_home/). For this wrapper to function properly, you must install `cdd` on your system. On Ubuntu, the relevant packages can be installed with 12 | 13 | ```bash 14 | apt-get install -y libcdd-dev libgmp-dev 15 | ``` 16 | 17 | On Arch Linux, you can use 18 | 19 | ```bash 20 | pacman -S cddlib 21 | ``` 22 | 23 | ## Installation 24 | 25 | ### Setting up a `conda` environment 26 | 27 | We recommend installing JAX and `immrax` into a `conda` environment ([miniconda](https://docs.conda.io/projects/miniconda/en/latest/)). 28 | 29 | ```shell 30 | conda create -n immrax python=3.12 31 | conda activate immrax 32 | ``` 33 | 34 | ### Installing immrax 35 | 36 | `immrax` is available as a package on PyPI and can be installed with `pip`. 37 | 38 | ```shell 39 | pip install immrax 40 | ``` 41 | 42 | If you have CUDA-enabled hardware you wish to use, please install the `cuda` optional dependency group. 43 | 44 | ```shell 45 | ... 46 | pip install "immrax[cuda]" 47 | ``` 48 | 49 | To test whether the installation worked, run the `compare.py` example. The additional `examples` optional dependency group contains some dependencies needed for the more complex examples; be sure to also install it (`pip install "immrax[examples]"`) if you want to run the others. 50 | 51 | ```shell 52 | cd examples 53 | python compare.py 54 | ``` 55 | 56 | This should print the outputs of different inclusion functions as well as their runtimes.
57 | 58 | ## Citation 59 | 60 | If you find this library useful, please cite our paper with the following BibTeX entry. 61 | 62 | ```bibtex 63 | @article{immrax, 64 | title = {immrax: A Parallelizable and Differentiable Toolbox for Interval Analysis and Mixed Monotone Reachability in {JAX}}, 65 | journal = {IFAC-PapersOnLine}, 66 | volume = {58}, 67 | number = {11}, 68 | pages = {75--80}, 69 | year = {2024}, 70 | note = {8th IFAC Conference on Analysis and Design of Hybrid Systems ADHS 2024}, 71 | issn = {2405-8963}, 72 | doi = {10.1016/j.ifacol.2024.07.428}, 73 | url = {https://www.sciencedirect.com/science/article/pii/S2405896324005275}, 74 | author = {Akash Harapanahalli and Saber Jafarpour and Samuel Coogan}, 75 | keywords = {Interval analysis, Reachability analysis, Automatic differentiation, Parallel computation, Computational tools, Optimal control, Robust control}, 76 | abstract = {We present an implementation of interval analysis and mixed monotone interval reachability analysis as function transforms in Python, fully composable with the computational framework JAX.
The resulting toolbox inherits several key features from JAX, including computational efficiency through Just-In-Time Compilation, GPU acceleration for quick parallelized computations, and Automatic Differentiability. We demonstrate the toolbox’s performance on several case studies, including a reachability problem on a vehicle model controlled by a neural network, and a robust closed-loop optimal control problem for a swinging pendulum.} 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | data/ 3 | pngs/ 4 | cache/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/ 166 | -------------------------------------------------------------------------------- /examples/auxillary_vars/aux-var.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import Literal, Tuple 3 | import time 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import matplotlib.pyplot as plt 8 | import matplotlib.patches as mpatches 9 | import matplotlib.lines as mlines 10 | 11 | import immrax as irx 12 | from immrax.embedding import AuxVarEmbedding 13 | from immrax.system import Trajectory 14 | from immrax.utils import ( 15 | angular_sweep, 16 | run_times, 17 | draw_trajectory_2d, 18 | draw_refined_trajectory_2d, 19 | gen_ics, 20 | ) 21 | 22 | 23 | class HarmOsc(irx.System): 24 | def __init__(self) -> None: 25 | self.evolution = "continuous" 26 | self.xlen = 2 27 | self.name = "Harmonic Oscillator" 28 | 29 | def f(self, t, x: jnp.ndarray) -> jnp.ndarray: 30 | x1, x2 = x.ravel() 31 | return jnp.array([-x2, x1]) 32 | 33 | 34 | class VanDerPolOsc(irx.System): 35 | def __init__(self, mu: float = 1) -> None: 36 | self.evolution = "continuous" 37 | self.xlen = 2 38 | self.name = "Van der Pol Oscillator" 39 | self.mu = mu 40 | 41 | def f(self, t, x: jnp.ndarray) -> jnp.ndarray: 42 | x1, x2 = x.ravel() 43 | return jnp.array([self.mu * (x1 - 1 / 3 * x1**3 - x2), x1 / self.mu]) 44 | 45 | 46 | def angular_refined_trajectory( 47 | num_aux_vars: int, mode: Literal["sample", "linprog"], save: bool = False 48 | ) -> Tuple[Trajectory, jax.Array, Trajectory]: 49 | # Generate angular sweep aux vars 50 | # Odd num_aux_var is not a good choice, as it will generate angle theta=pi/2, which is redundant with the actual state vars 51 | aux_vars = angular_sweep(num_aux_vars) 52 | H = jnp.vstack([jnp.eye(2), aux_vars]) 53 | lifted_x0_int = irx.interval(H) @ x0_int 54 | 55 | # Compute refined trajectory 56 | auxsys = AuxVarEmbedding(sys, H, mode=mode) 57 | print("Compiling...") 58 | start = time.time() 59 | 
get_traj = jax.jit( 60 | lambda t0, tf, x0: auxsys.compute_trajectory(t0, tf, x0, solver="euler"), 61 | backend="cpu", 62 | ) 63 | get_traj(0.0, 0.01, irx.i2ut(lifted_x0_int)) 64 | print(f"Compilation took: {time.time() - start:.4g}s") 65 | print("Compiled.\nComputing trajectory...") 66 | traj, comp_time = run_times( 67 | 10, 68 | get_traj, 69 | 0.0, 70 | sim_len, 71 | irx.i2ut(lifted_x0_int), 72 | ) 73 | traj = traj.to_convenience() 74 | print( 75 | f"Computing trajectory with {mode} refinement for {num_aux_vars} aux vars took: {comp_time.mean():.4g} ± {comp_time.std():.4g}s" 76 | ) 77 | 78 | ys_int = [irx.ut2i(y) for y in traj.ys] 79 | final_bound = ys_int[-1][2:] 80 | final_bound_size = (final_bound[0].upper - final_bound[0].lower) * ( 81 | final_bound[1].upper - final_bound[1].lower 82 | ) 83 | print(f"Final bound: \n{final_bound}, size: {final_bound_size}") 84 | 85 | if save: 86 | pickle.dump(ys_int, open(f"{mode}_traj_{num_aux_vars}.pkl", "wb")) 87 | 88 | mc_x0s = gen_ics(x0_int, 30) 89 | mc_traj = jax.vmap( 90 | lambda x0: sys.compute_trajectory(0, sim_len, x0, solver="euler"), 91 | )(mc_x0s) 92 | 93 | return traj, H, mc_traj 94 | 95 | 96 | def plot_angular_refined_trajectory(traj: Trajectory, H: jax.Array): 97 | fig = plt.figure() 98 | # fig, axs = plt.subplots(int(jnp.ceil(N / 3)), 3, figsize=(5, 5)) 99 | fig.suptitle(f"Reachable Sets of the {sys.name}") 100 | plt.gca().set_xlabel(r"$x_1$") 101 | plt.gca().set_ylabel(r"$x_2$") 102 | # axs = axs.reshape(-1) 103 | 104 | draw_refined_trajectory_2d(traj, H) 105 | 106 | 107 | x0_int = irx.icentpert(jnp.array([1.0, 0.0]), jnp.array([0.1, 0.1])) 108 | sim_len = 2 * jnp.pi 109 | 110 | plt.rcParams.update({"text.usetex": True, "font.family": "CMU Serif", "font.size": 14}) 111 | # plt.figure() 112 | 113 | 114 | # Trajectory of unrefined system 115 | sys = VanDerPolOsc() # Can use an arbitrary system here 116 | embsys = irx.mjacemb(sys) 117 | traj = embsys.compute_trajectory( 118 | 0.0, 119 | sim_len, 120 | 
irx.i2ut(x0_int), 121 | ) 122 | plt.gcf().suptitle(f"{sys.name} with Uncertainty (No Refinement)") 123 | draw_trajectory_2d(traj) 124 | 125 | 126 | traj_s, H, mc_traj = angular_refined_trajectory(6, "sample") 127 | traj_lp, H, mc_traj = angular_refined_trajectory(6, "linprog") 128 | plot_angular_refined_trajectory(traj_lp, H) 129 | fig = plt.gcf() 130 | x = mc_traj.ys[:, :, 0].T 131 | y = mc_traj.ys[:, :, 1].T 132 | plt.plot(x, y, alpha=0.5, color="gray", linewidth=0.5) 133 | 134 | blue_rectangle = mpatches.Patch( 135 | edgecolor="tab:blue", facecolor="none", alpha=0.4, label="Reachable Set Bounds" 136 | ) 137 | gray_line = mlines.Line2D( 138 | [], [], color="gray", alpha=0.5, label="Monte Carlo Trajectories" 139 | ) 140 | plt.legend(handles=[blue_rectangle, gray_line], loc="lower left") 141 | 142 | plt.show() 143 | -------------------------------------------------------------------------------- /immrax/parametric/sets/polytope.py: -------------------------------------------------------------------------------- 1 | from ..parametope import hParametope 2 | import jax.numpy as jnp 3 | import matplotlib.pyplot as plt 4 | from jax.tree_util import register_pytree_node_class 5 | from ...inclusion import interval, i2centpert 6 | import numpy as onp 7 | from pypoman import plot_polygon, compute_polytope_vertices, project_polytope 8 | 9 | 10 | def _lu2y(l, u): 11 | return jnp.concatenate((-l, u)) 12 | 13 | 14 | def _y2lu(y): 15 | print(y) 16 | return -y[: len(y) // 2], y[len(y) // 2 :] 17 | 18 | 19 | @register_pytree_node_class 20 | class Polytope(hParametope): 21 | # def __init__ (self, ox, H, ) : 22 | # super().__init__(ox, H, _lu2y(ly, uy)) 23 | # # m_ks = [H.shape[0], len(ly), len(uy)] 24 | # # if len(set(m_ks)) != 1 : 25 | # # raise Exception(f"Dimension mismatch: {m_ks}") 26 | # # self.mk = m_ks[0] 27 | # # if len(ox) != H.shape[1] : 28 | # # raise Exception(f"Dimension mismatch: {ox.shape=} {H.shape=}") 29 | 30 | # @classmethod 31 | # def from_parametope (cls, 
pt:hParametope) : 32 | # # print(ds.H) 33 | # return cls(pt.ox, pt.alpha, *_y2lu(pt.y)) 34 | 35 | def h(self, z): 36 | # Identity nonlinearity 37 | return jnp.concatenate((-z, z)) 38 | 39 | def hinv(self, y): 40 | # Inverse image is also identity 41 | return interval(-y[: len(y) // 2], y[len(y) // 2 :]) 42 | 43 | @property 44 | def H(self): 45 | return self.alpha 46 | 47 | @property 48 | def ly(self): 49 | return -self.y[: len(self.y) // 2] 50 | 51 | @property 52 | def uy(self): 53 | return self.y[len(self.y) // 2 :] 54 | 55 | @property 56 | def iy(self): 57 | return interval(self.ly, self.uy) 58 | 59 | def get_vertices(self): 60 | Hi = jnp.vstack((-self.H, self.H)) 61 | bi = jnp.hstack((-self.ly, self.uy)) 62 | return jnp.asarray(compute_polytope_vertices(Hi, bi)) + self.ox 63 | 64 | def plot_projection(self, ax, xi=0, yi=1, rescale=False, **kwargs): 65 | Hi = onp.vstack((-self.H, self.H)) 66 | bi = onp.hstack((-self.ly, self.uy)) 67 | if Hi.shape[1] == 2: 68 | V = compute_polytope_vertices(Hi, bi) 69 | elif Hi.shape[1] > 2: 70 | E = onp.zeros((2, self.H.shape[1])) 71 | E[0, xi] = 1 72 | E[1, yi] = 1 73 | Hi = onp.vstack((-self.H, self.H)) 74 | bi = onp.hstack((-self.ly, self.uy)) 75 | # print(Hi.shape, bi.shape) 76 | V = project_polytope((E, jnp.zeros(2)), (Hi, bi)) 77 | plt.sca(ax) 78 | kwargs.setdefault("alpha", 1.0) 79 | kwargs.setdefault("fill", False) 80 | plot_polygon([v + self.ox[(xi, yi),] for v in V], **kwargs) 81 | # plot_polygon(V, **kwargs) 82 | 83 | def one_d_proj(self, yi=0, rescale=False, **kwargs): 84 | # 1D projection onto xi, time. 
Plotted as a tube 85 | Hi = onp.vstack((-self.H, self.H)) 86 | bi = onp.hstack((-self.ly, self.uy)) 87 | # V = compute_polytope_vertices(Hi, bi) 88 | E = onp.zeros((1, len(self.H))) 89 | E[0, yi] = 1 90 | return project_polytope((E, jnp.zeros(1)), (Hi, bi)) 91 | 92 | # @classmethod 93 | # def from_Hpolytope (H, uy, ox=jnp.zeros(2)) : 94 | # return Polytope(ox, H, -jnp.inf*jnp.ones_like(uy), uy) 95 | 96 | @classmethod 97 | def from_interval(cls, *args): 98 | cent, pert = i2centpert(interval(*args)) 99 | return Polytope(cent, jnp.eye(len(cent)), jnp.concatenate((pert, pert))) 100 | 101 | def add_rows(self, Haug, Hp): 102 | yaug = interval(Haug @ Hp) @ self.hinv(self.y) 103 | return Polytope( 104 | self.ox, 105 | jnp.vstack((self.H, Haug)), 106 | _lu2y( 107 | jnp.concatenate((self.ly, yaug.lower)), 108 | jnp.concatenate((self.uy, yaug.upper)), 109 | ), 110 | ) 111 | 112 | # Override in subclasses to unpack the flattened data 113 | @classmethod 114 | def from_parametope(cls, pt: "hParametope"): 115 | return Polytope(pt.ox, pt.alpha, pt.y) 116 | 117 | # @classmethod 118 | # def tree_unflatten (cls, aux_data, children) : 119 | # return cls.from_parametope(hParametope(*children)) 120 | 121 | 122 | # @register_pytree_node_class 123 | # class IntervalDualStar (Polytope) : 124 | # def __init__ (self, I:Interval, ox=None) : 125 | # ox = (I.upper + I.lower) / 2 if ox is None else ox 126 | # super().__init__(ox, jnp.eye(len(ox)), I.lower - ox, I.upper - ox) 127 | 128 | # def ds_add_interval (ds:DualStar) : 129 | # """Add an interval to the DualStar 130 | 131 | # Parameters 132 | # ---------- 133 | # ds : DualStar 134 | # The DualStar to add the interval to 135 | # iy : Interval 136 | # The interval to add 137 | # """ 138 | # ids = IntervalDualStar(ds.iover(), ds.ox) 139 | # # List addition here 140 | # return DualStar(ds.ox, ds.H + ids.H, ds.ly + ids.ly, ds.uy + ids.uy) 141 | -------------------------------------------------------------------------------- 
/immrax/inclusion/cubic_spline.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | import immrax as irx 5 | from immrax.inclusion.polynomial import polynomial 6 | from immrax.inclusion import custom_if 7 | 8 | 9 | def create_cubic_spline_coeffs(points): 10 | """ 11 | Computes the coefficients for a natural cubic spline from a list of points. 12 | This implementation is designed to be JAX-traceable. 13 | 14 | Args: 15 | points: A JAX numpy array of (x, y) coordinates. 16 | 17 | Returns: 18 | A tuple containing: 19 | - x_knots: The x-coordinates of the knots. 20 | - coeffs: A tuple of (a, b, c, d) coefficients for the n-1 splines. 21 | """ 22 | # Sort points by x-values, as JAX requires static shapes for many operations 23 | # and sorting inside a jitted function can be tricky if not handled carefully. 24 | # Here we assume points are sorted or we sort them before passing to a jitted function. 25 | sorted_indices = jnp.argsort(points[:, 0]) 26 | points = points[sorted_indices] 27 | 28 | x = points[:, 0] 29 | y = points[:, 1] 30 | n = len(x) 31 | 32 | if n < 3: 33 | raise ValueError("At least 3 points are required to build a cubic spline.") 34 | 35 | h = jnp.diff(x) 36 | 37 | # Setup the tridiagonal system for the second derivatives (related to 'c' coeffs) 38 | # For a natural spline, the second derivatives at the endpoints are 0. 39 | # This leaves us with n-2 unknowns to solve for. 
40 | 41 | # Main diagonal of the tridiagonal matrix 42 | diag = 2 * (h[:-1] + h[1:]) 43 | 44 | # Off-diagonals 45 | off_diag = h[1:-1] 46 | 47 | # Construct the (n-2)x(n-2) tridiagonal matrix A 48 | A = jnp.diag(diag) + jnp.diag(off_diag, k=1) + jnp.diag(off_diag, k=-1) 49 | 50 | # Right-hand side vector B 51 | B = 3 * (jnp.diff(y[1:]) / h[1:] - jnp.diff(y[:-1]) / h[:-1]) 52 | 53 | # Solve Ac = B for the inner c coefficients (c_1 to c_{n-2}) 54 | # Using jnp.linalg.solve which is traceable 55 | c_inner = jnp.linalg.solve(A, B) 56 | 57 | # Combine with boundary conditions (c_0 = c_{n-1} = 0 for natural spline) 58 | c = jnp.concatenate([jnp.array([0.0]), c_inner, jnp.array([0.0])]) 59 | 60 | # Calculate remaining coefficients a, b, d from c 61 | # a_i = y_i 62 | a = y[:-1] 63 | 64 | # b_i = (y_{i+1} - y_i)/h_i - h_i/3 * (2*c_i + c_{i+1}) 65 | b = (jnp.diff(y) / h) - (h / 3) * (2 * c[:-1] + c[1:]) 66 | 67 | # d_i = (c_{i+1} - c_i) / (3*h_i) 68 | d = jnp.diff(c) / (3 * h) 69 | 70 | # The coefficients are for n-1 polynomials. 71 | # c is of length n, so we take c[:-1] for the c_i coefficient of the i-th polynomial. 72 | coeffs = (a, b, c[:-1], d) 73 | x_knots = x 74 | 75 | return x_knots, coeffs 76 | 77 | 78 | def make_spline_eval_fn(x_knots, coeffs): 79 | """ 80 | Creates a JAX-traceable function that evaluates the cubic spline. 81 | 82 | Args: 83 | x_knots: The x-coordinates of the knots. 84 | coeffs: A tuple of (a, b, c, d) spline coefficients. 85 | 86 | Returns: 87 | A function that takes a single x-value or an array of x-values and 88 | returns the corresponding interpolated y-value(s). 89 | """ 90 | a, b, c, d = coeffs 91 | 92 | @custom_if 93 | def eval_fn(x_eval): 94 | """ 95 | Evaluates the spline at given x-values. 
96 | """ 97 | # Vectorized search for the correct interval for each x_eval 98 | i = jnp.clip(jnp.searchsorted(x_knots, x_eval, side="right") - 1, 0, len(a) - 1) 99 | 100 | # Evaluate the polynomial for each x_eval using its interval index i 101 | dx = x_eval - x_knots[i] 102 | my_eval = polynomial(jnp.array([d[i], c[i], b[i], a[i]]).squeeze(), dx) 103 | 104 | return my_eval 105 | 106 | @eval_fn.defif 107 | def incl_fn(int_eval): 108 | # For extrapolation, we extend the first and last domains to infinity. 109 | lower_bins = jnp.concatenate((-jnp.array([jnp.inf]), x_knots[1:-1])) 110 | upper_bins = jnp.concatenate((x_knots[1:-1], jnp.array([jnp.inf]))) 111 | bins = irx.interval(lower_bins, upper_bins) 112 | 113 | # int_eval_partitioned = jax.vmap(lambda x, y: x and y, in_axes=(0, None))( 114 | int_eval_partitioned = jax.vmap( 115 | lambda x, y: x & y, 116 | in_axes=(0, None), 117 | )(bins, int_eval) 118 | 119 | # The polynomial is evaluated on dx = x - x_knots[i]. 120 | # We need to compute the interval for dx for each bin. 121 | dx_intervals = int_eval_partitioned - x_knots[:-1] 122 | 123 | out_partitioned = jax.vmap(irx.natif(polynomial), out_axes=(-1))( 124 | jnp.array([d, c, b, a]).T, dx_intervals 125 | ) 126 | 127 | # Identify empty intervals in the input 128 | is_empty = int_eval_partitioned.lower >= int_eval_partitioned.upper 129 | 130 | # Use jnp.where to replace empty intervals with neutral elements for the reduction. 
131 | valid_lowers = jnp.where(is_empty, jnp.inf, out_partitioned.lower) 132 | valid_uppers = jnp.where(is_empty, -jnp.inf, out_partitioned.upper) 133 | 134 | # jax.debug.print( 135 | # "int_eval_partitioned: {}\n dx_intervals: {}\nis_empty: {},\nout_partitioned: {},\nvalid_lowers: {},\nvalid_uppers: {}", 136 | # int_eval_partitioned, 137 | # dx_intervals, 138 | # is_empty, 139 | # out_partitioned, 140 | # valid_lowers, 141 | # valid_uppers, 142 | # ) 143 | 144 | return irx.interval(jnp.min(valid_lowers), jnp.max(valid_uppers)) 145 | 146 | return eval_fn 147 | -------------------------------------------------------------------------------- /immrax/inclusion/custom_if.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax.interpreters import ad, batching, mlir 7 | 8 | from immrax.inclusion import nif 9 | 10 | 11 | class custom_if: 12 | """A decorator to define a custom inclusion function for a JAX-traceable function. 13 | 14 | This is analogous to `jax.custom_jvp`. 15 | 16 | This annotation will: 17 | 1. Create a custom primitive that is bound to the implementation of the given function. 18 | 2. Associate the primitive with default batching, lowering, and jvp rules. 19 | 3. Derive the correct shape for abstract evaluation, and associate this shape with the new primitive's abstract eval. 20 | 4. Associate the primitive with a custom inclusion function, which can be defined using the `@f.defif` decorator. 21 | 22 | For example, to create a primitive for `jnp.polyval` with a custom inclusion function: 23 | 24 | .. code-block:: python 25 | 26 | from immrax.inclusion.custom_if import custom_if 27 | import jax.numpy as jnp 28 | 29 | @custom_if 30 | def polyval(a, x): 31 | return jnp.polyval(a, x) 32 | 33 | @polyval.defif 34 | def polyval_inclusion(a, x): 35 | # custom inclusion logic 36 | ... 
37 | 38 | Now `polyval` can be used in computations, and `nif.natif` will dispatch 39 | to `polyval_inclusion` when it encounters this primitive. 40 | """ 41 | 42 | def __init__(self, fun): 43 | self.fun = fun 44 | self._if = None 45 | self.primitive = self._create_primitive() 46 | functools.update_wrapper(self, fun) 47 | 48 | def _create_primitive(self): 49 | primitive_name = f"{self.fun.__name__}_p" 50 | primitive = jax.extend.core.Primitive(primitive_name) 51 | 52 | # 1. Implementation 53 | primitive.def_impl(self.fun) 54 | 55 | # 2. Abstract evaluation (shape inference) 56 | def default_abstract_eval(*args_aval): 57 | try: 58 | shape_dtype = jax.eval_shape(self.fun, *args_aval) 59 | # TODO: I am not entirely sure if this will respect the device / ref counting behavior of the wrapped function 60 | # Should look here first if those types of problems come up 61 | return jax.core.ShapedArray(shape_dtype.shape, shape_dtype.dtype) 62 | except Exception as e: 63 | raise TypeError( 64 | f"Automatic shape inference for '{self.fun.__name__}' failed. " 65 | ) from e 66 | 67 | primitive.def_abstract_eval(default_abstract_eval) 68 | 69 | # 3. JIT lowering 70 | # Assuming the wrapped function returns a single result, as in polynomial.py. 71 | lowering = mlir.lower_fun(self.fun, multiple_results=False) 72 | mlir.register_lowering(primitive, lowering) 73 | 74 | # 4. Batching rule 75 | def batching_rule(vector_arg_values, batch_axes): 76 | res = jax.vmap(self.fun, in_axes=batch_axes)(*vector_arg_values) 77 | 78 | if isinstance(res, (list, tuple)): 79 | return res, tuple([0] * len(res)) 80 | else: 81 | return res, 0 82 | 83 | batching.primitive_batchers[primitive] = batching_rule 84 | 85 | # 5. 
JVP rules for autodiff 86 | try: 87 | sig = inspect.signature(self.fun) 88 | num_args = sum( 89 | 1 90 | for param in sig.parameters.values() 91 | if param.kind 92 | in ( 93 | inspect.Parameter.POSITIONAL_ONLY, 94 | inspect.Parameter.POSITIONAL_OR_KEYWORD, 95 | ) 96 | ) 97 | except (ValueError, TypeError): 98 | # Fallback for C functions or other callables without a clear signature 99 | # This part might need to be adjusted if those functions are used. 100 | # For now, we won't define JVP rules if we can't inspect the signature. 101 | num_args = 0 102 | 103 | jvprules = [] 104 | for i in range(num_args): 105 | 106 | def make_jvp_rule(arg_num): 107 | def jvp_rule(tangent, *primals): 108 | # Create zero tangents for all other arguments 109 | tangents = [ 110 | tangent 111 | if i == arg_num 112 | else jax.tree_util.tree_map(jnp.zeros_like, primal) 113 | for i, primal in enumerate(primals) 114 | ] 115 | _primals_out, tangents_out = jax.jvp( 116 | self.fun, primals, tuple(tangents) 117 | ) 118 | return tangents_out 119 | 120 | return jvp_rule 121 | 122 | jvprules.append(make_jvp_rule(i)) 123 | 124 | if jvprules: 125 | ad.defjvp(primitive, *jvprules) 126 | 127 | # 6. Register inclusion function 128 | def inclusion_dispatcher(*args, **kwargs): 129 | if self._if is None: 130 | raise NotImplementedError( 131 | f"No inclusion function defined for '{self.fun.__name__}'. " 132 | f"Use '@{self.fun.__name__}.defif' to define it." 133 | ) 134 | return self._if(*args, **kwargs) 135 | 136 | nif.inclusion_registry[primitive] = inclusion_dispatcher 137 | 138 | return primitive 139 | 140 | def __call__(self, *args, **kwargs): 141 | if kwargs: 142 | raise TypeError( 143 | f"Primitive '{self.primitive.name}' does not support keyword arguments. " 144 | "Consider using functools.partial to bind keyword arguments." 
145 | ) 146 | return self.primitive.bind(*args) 147 | 148 | def defif(self, if_fun): 149 | """Decorator to define the inclusion function for a custom_if function.""" 150 | self._if = if_fun 151 | return if_fun 152 | -------------------------------------------------------------------------------- /immrax/system/trajectory.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Union, List 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax.tree_util import register_pytree_node_class 7 | 8 | from diffrax import Solution 9 | 10 | import numpy as np 11 | 12 | __all__ = [ 13 | "RawTrajectory", 14 | "RawContinuousTrajectory", 15 | "RawDiscreteTrajectory", 16 | "Trajectory", 17 | "ContinuousTrajectory", 18 | "DiscreteTrajectory", 19 | ] 20 | 21 | 22 | class RawTrajectory(abc.ABC): 23 | """Abstract base class for raw trajectories. 24 | 25 | These trajectories directly store the padded arrays used in JAX computations. 26 | """ 27 | 28 | ts: jax.Array 29 | ys: jax.Array 30 | 31 | @abc.abstractmethod 32 | def to_convenience(self) -> "Trajectory": 33 | """Converts a raw trajectory to a convenience trajectory.""" 34 | ... 
35 | 36 | 37 | @register_pytree_node_class 38 | class RawContinuousTrajectory(RawTrajectory): 39 | """Raw continuous trajectory, wrapping a diffrax.Solution.""" 40 | 41 | def __init__(self, solution: Solution): 42 | self.solution = solution 43 | 44 | def __getattr__(self, name: str) -> Any: 45 | return getattr(self.solution, name) 46 | 47 | def tree_flatten(self): 48 | children = (self.solution,) 49 | aux_data = None 50 | return children, aux_data 51 | 52 | @classmethod 53 | def tree_unflatten(cls, aux_data, children): 54 | return cls(children[0]) 55 | 56 | def to_convenience(self) -> "ContinuousTrajectory": 57 | return ContinuousTrajectory(self) 58 | 59 | 60 | @register_pytree_node_class 61 | class RawDiscreteTrajectory(RawTrajectory): 62 | """Raw discrete trajectory.""" 63 | 64 | def __init__(self, ts: jax.Array, ys: jax.Array): 65 | self.ts = ts 66 | self.ys = ys 67 | 68 | def tree_flatten(self): 69 | children = (self.ts, self.ys) 70 | aux_data = None 71 | return children, aux_data 72 | 73 | @classmethod 74 | def tree_unflatten(cls, aux_data, children): 75 | return cls(*children) 76 | 77 | def to_convenience(self) -> "DiscreteTrajectory": 78 | return DiscreteTrajectory(self) 79 | 80 | 81 | class Trajectory: 82 | """Convenience wrapper for trajectories. 83 | 84 | This class provides access to the valid, unpadded data from a raw trajectory. 85 | """ 86 | 87 | ts: Union[jax.Array, List[jax.Array]] 88 | ys: Union[jax.Array, List[jax.Array]] 89 | 90 | def __init__(self, raw_trajectory: RawTrajectory): 91 | """Initializes a convenience trajectory wrapper from a raw trajectory. 92 | 93 | This constructor handles both single and batched trajectories. For batched 94 | trajectories, it can handle both cases where trajectories have the same 95 | length (non-ragged) and different lengths (ragged). 96 | 97 | For single trajectories, `ts` and `ys` will be `jax.Array`. 
98 | For non-ragged batched trajectories, `ts` and `ys` will be `jax.Array` 99 | with leading batch dimensions. 100 | For ragged batched trajectories, `ts` and `ys` will be lists of `jax.Array`, 101 | where each array in the list corresponds to a single trajectory in the batch. 102 | """ 103 | raw_ts = raw_trajectory.ts 104 | raw_ys = raw_trajectory.ys 105 | 106 | if raw_ts is None: # diffrax can return None ts 107 | self.ts = jnp.empty(0) 108 | self.ys = ( 109 | jnp.empty((0, raw_ys.shape[-1])) 110 | if raw_ys is not None 111 | else jnp.empty((0, 0)) 112 | ) 113 | return 114 | 115 | tfinite_mask = jnp.isfinite(raw_ts) 116 | 117 | if raw_ts.ndim <= 1: 118 | self.ts = raw_ts[tfinite_mask] 119 | self.ys = raw_ys[tfinite_mask] 120 | return 121 | 122 | time_axis = -1 123 | num_finite = jnp.sum(tfinite_mask, axis=time_axis) 124 | 125 | is_ragged = ( 126 | not jnp.all(num_finite == num_finite.flatten()[0]) 127 | if num_finite.size > 0 128 | else False 129 | ) 130 | 131 | if not is_ragged: 132 | k = num_finite.flatten()[0].item() if num_finite.size > 0 else 0 133 | self.ts = raw_ts[..., :k] 134 | self.ys = raw_ys[..., :k, :] 135 | else: 136 | self.ts = [] 137 | self.ys = [] 138 | batch_shape = raw_ts.shape[:-1] 139 | for batch_idx in np.ndindex(batch_shape): 140 | n = num_finite[batch_idx].item() 141 | self.ts.append(raw_ts[batch_idx, :n].squeeze()) 142 | self.ys.append(raw_ys[batch_idx, :n].squeeze()) 143 | 144 | def is_ragged(self) -> bool: 145 | """Returns True if the trajectory is ragged, meaning `ts` and `ys` are lists of arrays. 146 | 147 | When a `RawTrajectory` is created from a `vmap`'d computation over parameters 148 | that affect the trajectory length (e.g., final time), the resulting batch of 149 | trajectories may be "ragged" - that is, different trajectories in the batch 150 | may have different numbers of time steps. 
To handle this, the `ts` and `ys` 151 | attributes of the `Trajectory` object will be lists of arrays, where each 152 | array corresponds to a single trajectory. The `is_ragged` method can be used 153 | to check for this case. 154 | """ 155 | return isinstance(self.ts, list) 156 | 157 | 158 | class ContinuousTrajectory(Trajectory): 159 | """Convenience wrapper for continuous trajectories.""" 160 | 161 | def __init__(self, raw_trajectory: RawTrajectory): 162 | assert isinstance(raw_trajectory, RawContinuousTrajectory) 163 | super().__init__(raw_trajectory) 164 | 165 | 166 | class DiscreteTrajectory(Trajectory): 167 | """Convenience wrapper for discrete trajectories.""" 168 | 169 | def __init__(self, raw_trajectory: RawTrajectory): 170 | assert isinstance(raw_trajectory, RawDiscreteTrajectory) 171 | super().__init__(raw_trajectory) 172 | -------------------------------------------------------------------------------- /tests/test_inclusion.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pytest 4 | 5 | import immrax as irx 6 | from tests.utils import validate_overapproximation_nd 7 | 8 | # --- Test Case Flags --- 9 | TEST_NATIF = True 10 | TEST_JACIF = True 11 | TEST_MJACIF = True 12 | TEST_LIN_SYS = True 13 | 14 | 15 | # --- Helper Functions --- 16 | def square_sin(x): 17 | """A simple non-linear function for testing.""" 18 | return jnp.sin(x**2) 19 | 20 | 21 | def exp_add(x): 22 | """A simple non-linear function for testing.""" 23 | return jnp.exp(x) + x 24 | 25 | 26 | def lin_sys(x, u): 27 | A = jnp.array([[1.0, 2.0], [1.0, 1.0]]) 28 | B = jnp.array([[0.0], [1.0]]) 29 | return A @ x[:] + B @ u[:] 30 | 31 | 32 | # --- Fixtures for test data --- 33 | 34 | 35 | @pytest.fixture( 36 | params=[ 37 | pytest.param(square_sin, id="sin(x^2)"), 38 | pytest.param(exp_add, id="exp(x)+x"), 39 | ] 40 | ) 41 | def jax_fn(request): 42 | """Parametrized fixture for JAX-traceable functions.""" 
43 | return request.param 44 | 45 | 46 | @pytest.fixture( 47 | params=[ 48 | # pytest.param(jnp.array(1.0), id="scalar"), # FIXME: jacobian-based IFs change this shape 49 | pytest.param(jnp.array([1.0]), id="(1,) array"), 50 | pytest.param(jnp.array([1.0, 2.0]), id="(2,) array"), 51 | ] 52 | ) 53 | def eval_point(request): 54 | """Parametrized fixture for evaluation points.""" 55 | return request.param 56 | 57 | 58 | @pytest.fixture 59 | def eval_interval(eval_point): 60 | """Fixture for evaluation intervals, derived from eval_point.""" 61 | return irx.icentpert(eval_point, 0.1) 62 | 63 | 64 | # --- Test Functions --- 65 | 66 | 67 | @pytest.mark.skipif(not TEST_NATIF, reason="natif tests are disabled") 68 | def test_natif(jax_fn, eval_point, eval_interval): 69 | """Tests the natural interval extension (natif).""" 70 | inclusion_fn = irx.natif(jax_fn) 71 | result = inclusion_fn(eval_interval) 72 | 73 | # 1. Check that the output is an Interval 74 | assert isinstance(result, irx.Interval) 75 | 76 | # 2. Check that the shape is correct 77 | expected_shape = jax.eval_shape(jax_fn, eval_point).shape 78 | assert result.shape == expected_shape 79 | 80 | # 3. Check that the interval bounds are valid 81 | assert jnp.all(result.lower <= result.upper) 82 | 83 | # 4. Validate that the interval overapproximates the true function range 84 | validate_overapproximation_nd(jax_fn, eval_interval, result) 85 | 86 | 87 | @pytest.mark.skipif(not TEST_JACIF, reason="jacif tests are disabled") 88 | def test_jacif(jax_fn, eval_point, eval_interval): 89 | """Tests the Jacobian-based inclusion function (jacif).""" 90 | inclusion_fn = irx.jacif(jax_fn) 91 | result = inclusion_fn(eval_interval) 92 | 93 | # 1. Check that the output is an Interval 94 | assert isinstance(result, irx.Interval) 95 | 96 | # 2. Check that the shape is correct 97 | expected_shape = jax.eval_shape(jax_fn, eval_point).shape 98 | assert result.shape == expected_shape 99 | 100 | # 3. 
Check that the interval bounds are valid 101 | assert jnp.all(result.lower <= result.upper) 102 | 103 | # 4. Validate that the interval overapproximates the true Jacobian range 104 | validate_overapproximation_nd(jax_fn, eval_interval, result) 105 | 106 | 107 | @pytest.mark.skipif(not TEST_MJACIF, reason="mjacif tests are disabled") 108 | def test_mjacif(jax_fn, eval_point, eval_interval): 109 | """Tests the mixed-order Jacobian-based inclusion function (mjacif).""" 110 | inclusion_fn = irx.mjacif(jax_fn) 111 | result = inclusion_fn(eval_interval) 112 | 113 | # 1. Check that the output is an Interval 114 | assert isinstance(result, irx.Interval) 115 | 116 | # 2. Check that the shape is correct 117 | expected_shape = jax.eval_shape(jax_fn, eval_point).shape 118 | assert result.shape == expected_shape 119 | 120 | # 3. Check that the interval bounds are valid 121 | assert jnp.all(result.lower <= result.upper) 122 | 123 | # 4. Validate that the interval overapproximates the true Jacobian range 124 | validate_overapproximation_nd(jax_fn, eval_interval, result) 125 | 126 | 127 | # --- Tests for lin_sys --- 128 | 129 | 130 | @pytest.fixture 131 | def x_vec(): 132 | """Fixture for the state vector 'x' for lin_sys.""" 133 | return jnp.array([1.0, 2.0]) 134 | 135 | 136 | @pytest.fixture 137 | def u_vec(): 138 | """Fixture for the control vector 'u' for lin_sys.""" 139 | return jnp.array([0.5]) 140 | 141 | 142 | @pytest.fixture 143 | def x_interval(x_vec): 144 | """Fixture for the state interval 'x' for lin_sys.""" 145 | return irx.icentpert(x_vec, 0.1) 146 | 147 | 148 | @pytest.fixture 149 | def u_interval(u_vec): 150 | """Fixture for the control interval 'u' for lin_sys.""" 151 | return irx.icentpert(u_vec, 0.05) 152 | 153 | 154 | @pytest.mark.skipif(not TEST_LIN_SYS, reason="lin_sys tests are disabled") 155 | def test_lin_sys_natif(x_vec, u_vec, x_interval, u_interval): 156 | """Tests the natural interval extension (natif) for a multi-argument function.""" 157 | 
inclusion_fn = irx.natif(lin_sys) 158 | result = inclusion_fn(x_interval, u_interval) 159 | 160 | # 1. Check that the output is an Interval 161 | assert isinstance(result, irx.Interval) 162 | 163 | # 2. Check that the shape is correct 164 | expected_shape = jax.eval_shape(lin_sys, x_vec, u_vec).shape 165 | assert result.shape == expected_shape 166 | 167 | # 3. Check that the interval bounds are valid 168 | assert jnp.all(result.lower <= result.upper) 169 | 170 | # 4. Validate that the interval overapproximates the true function range 171 | def wrapped_lin_sys(xu_vec): 172 | x = xu_vec[:, :2] 173 | u = xu_vec[:, 2:] 174 | # Vectorize lin_sys to handle the batch of 100 samples 175 | vmapped_lin_sys = jax.vmap(lin_sys, in_axes=(0, 0)) 176 | return vmapped_lin_sys(x, u) 177 | 178 | combined_lower = jnp.concatenate([x_interval.lower, u_interval.lower]) 179 | combined_upper = jnp.concatenate([x_interval.upper, u_interval.upper]) 180 | combined_interval = irx.interval(combined_lower, combined_upper) 181 | 182 | validate_overapproximation_nd(wrapped_lin_sys, combined_interval, result) 183 | -------------------------------------------------------------------------------- /immrax/parametric/sets/normotope.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from ..parametope import Parametope 3 | from .polytope import Polytope 4 | from ...inclusion import interval, icentpert, i2centpert 5 | from jax.tree_util import register_pytree_node_class 6 | from .ellipsoid import Ellipsoid 7 | from math import sqrt 8 | 9 | 10 | @register_pytree_node_class 11 | class Normotope(Parametope): 12 | r"""Defines the set 13 | 14 | .. math:: 15 | {x : \|H(x - \ox)\| \leq y} 16 | 17 | where :math:`\|\cdot\|` is a norm, :math:`\ox` is the center, :math:`H` is a shaping matrix, and :math:`y` is the offset. 18 | 19 | Define :math:`h` as the norm in subclasses, and :math:`\mu` as the logarithmic norm associated to :math:`h`. 
20 | """ 21 | 22 | def g(self, x): 23 | return self.h(jnp.dot(self.alpha, x - self.ox)) 24 | 25 | def h(self, z): 26 | """The norm associated to the normotope.""" 27 | raise NotImplementedError("Subclasses must implement the h method.") 28 | 29 | def hinv(self, y): 30 | """An interval overapproximation of the inverse image of y under h.""" 31 | raise NotImplementedError("Subclasses must implement the hinv method.") 32 | 33 | @classmethod 34 | def induced_norm(cls, A): 35 | """Computes the induced norm of A.""" 36 | raise NotImplementedError("Subclasses must implement the induced_norm method.") 37 | 38 | @classmethod 39 | def logarithmic_norm(cls, A): 40 | """The logarithmic norm associated to h.""" 41 | raise NotImplementedError( 42 | "Subclasses must implement the logarithmic_norm method." 43 | ) 44 | 45 | @classmethod 46 | def mu(cls, A): 47 | """Alias for the logarithmic norm.""" 48 | return cls.logarithmic_norm(A) 49 | 50 | def plot_projection(self, ax, xi=0, yi=1, rescale=False, **kwargs): 51 | """Plot the projection of the normotope onto the xi-yi plane.""" 52 | raise NotImplementedError("Will implement a sampling based thing in the future") 53 | 54 | @property 55 | def H(self): 56 | return self.alpha 57 | 58 | @classmethod 59 | def from_parametope(cls, pt: Parametope): 60 | return Normotope(pt.ox, pt.alpha, pt.y) 61 | 62 | def __getitem__(self, item): 63 | """Allows indexing into the normotope's parameters.""" 64 | return self.__class__.from_parametope( 65 | Parametope(self.ox[item], self.alpha[item], self.y[item]) 66 | ) 67 | 68 | def vec(self): 69 | """Vectorizes the normotope into a vector.""" 70 | return jnp.concatenate( 71 | (self.ox, self.alpha.reshape(-1), jnp.atleast_1d(self.y)) 72 | ) 73 | 74 | @classmethod 75 | # @partial(jax.jit, static_argnames=('n',)) 76 | def unvec(cls, vec, n=None): 77 | """Unvectorizes a vector into a normotope.""" 78 | y = vec[-1] 79 | N = len(vec) - 1 80 | if n is None: 81 | # Assume alpha is nxn, so N = n*n + n = n*(n+1) 
82 | # QF: n^2 + n - N = 0 ==> n = (-1 + sqrt(1 + 4*N)) / 2 83 | n = int((sqrt(1 + 4 * N) - 1) // 2) 84 | # if alpha is mxn, N = m*n + n = m*(n+1) 85 | 86 | ox = vec[:n] 87 | alpha = vec[n:N].reshape(-1, n) 88 | return cls(ox, alpha, y) 89 | 90 | 91 | @register_pytree_node_class 92 | class LinfNormotope(Normotope): 93 | r"""Defines the set 94 | 95 | .. math:: 96 | {x : \|H(x - \ox)\|_\infty \leq y} 97 | 98 | """ 99 | 100 | def h(self, z): 101 | """The infinity norm""" 102 | return jnp.max(jnp.abs(z)) 103 | 104 | def hinv(self, y): 105 | n = self.alpha.shape[0] 106 | return icentpert(jnp.zeros(n), y * jnp.ones(n)) 107 | 108 | @classmethod 109 | def induced_norm(cls, A): 110 | r"""Computes the induced :math:`\ell_\infty` norm of A""" 111 | # Maximum row sum of |A| 112 | return jnp.max(jnp.sum(jnp.abs(A), axis=1)) 113 | 114 | @classmethod 115 | def logarithmic_norm(cls, A): 116 | r"""Computes the logarithmic :math:`\ell_\infty` norm of A""" 117 | # Maximum row sum of A_M (Metzlerized) 118 | A_M = jnp.where(jnp.eye(A.shape[0], dtype=bool), A, jnp.abs(A)) 119 | return jnp.max(jnp.sum(A_M, axis=1)) 120 | 121 | def to_polytope(self) -> Polytope: 122 | n = self.alpha.shape[0] 123 | return Polytope(self.ox, self.alpha, jnp.ones(2 * n) * self.y) 124 | 125 | def plot_projection(self, ax, xi=0, yi=1, rescale=False, **kwargs): 126 | self.to_polytope().plot_projection(ax, xi, yi, rescale, **kwargs) 127 | 128 | @classmethod 129 | def from_interval(cls, *args): 130 | cent, pert = i2centpert(interval(*args)) 131 | return LinfNormotope(cent, jnp.diag(1 / pert), 1.0) 132 | 133 | @classmethod 134 | def from_normotope(cls, nt: Normotope): 135 | return LinfNormotope(nt.ox, nt.alpha, nt.y) 136 | 137 | 138 | @register_pytree_node_class 139 | class L2Normotope(Normotope): 140 | r"""Defines the set 141 | 142 | .. 
math:: 143 | {x : \|H(x - \ox)\|_2 \leq y} 144 | 145 | """ 146 | 147 | def h(self, z): 148 | """The L_2 norm""" 149 | return jnp.sum(z**2) ** 0.5 150 | 151 | def hinv(self, y): 152 | n = self.alpha.shape[0] 153 | return icentpert(jnp.zeros(n), y * jnp.ones(n)) 154 | 155 | @classmethod 156 | def induced_norm(cls, A): 157 | r"""Computes the induced :math:`\ell_2` norm of A""" 158 | return jnp.linalg.norm(A, ord=2) 159 | 160 | @classmethod 161 | def logarithmic_norm(cls, A): 162 | r"""Computes the :math:`\ell_2` logarithmic norm of A""" 163 | return jnp.max(jnp.linalg.eigvalsh((A + A.T) / 2)) 164 | 165 | # def to_polytope (self) -> Polytope : 166 | # n = self.alpha.shape[0] 167 | # return Polytope (self.ox, self.alpha, jnp.ones(2*n)*self.y) 168 | 169 | def plot_projection(self, ax, xi=0, yi=1, rescale=False, **kwargs): 170 | # self.to_polytope().plot_projection(ax, xi, yi, rescale, **kwargs) 171 | Ellipsoid(self.ox, self.alpha, jnp.array([0.0, self.y**2])).plot_projection( 172 | ax, xi, yi, rescale, **kwargs 173 | ) 174 | 175 | @classmethod 176 | def from_interval(cls, *args): 177 | cent, pert = i2centpert(interval(*args)) 178 | rn = jnp.sqrt(len(cent)) 179 | # rn = 1. 180 | return L2Normotope(cent, jnp.diag(1 / (rn * pert)), 1.0) 181 | 182 | @classmethod 183 | def from_normotope(cls, nt: Normotope): 184 | return L2Normotope(nt.ox, nt.alpha, nt.y) 185 | -------------------------------------------------------------------------------- /tests/test_polynomial.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pytest 4 | 5 | import immrax as irx 6 | from immrax.inclusion.polynomial import polynomial 7 | from tests.utils import validate_overapproximation_nd 8 | 9 | # --- Test Case Flags --- 10 | # Set these to True to run the corresponding tests. 11 | # You can then run tests using `pytest` in your terminal. 
12 | TEST_VECTOR_INPUTS = True 13 | TEST_INCLUSION_FUNCTIONS = True 14 | TEST_JACFWD = True 15 | TEST_JACREV = True 16 | TEST_JIT_COMPILATION = True 17 | 18 | # --- Fixtures for test data --- 19 | 20 | coeff_params = [ 21 | pytest.param(jnp.array([1.0, 4]), id="2nd-order"), 22 | pytest.param(jnp.array([1.0, 4, -5]), id="3rd-order"), 23 | pytest.param(jnp.array([1.0, 4, -5, -3]), id="4th-order"), 24 | ] 25 | 26 | # --- Helper Functions and Dynamic Parameters --- 27 | 28 | 29 | # Build the list of parameters for the evaluation points 30 | eval_point_params = [ 31 | pytest.param(1.0, id="scalar_float"), 32 | pytest.param(jnp.array([1.0]), id="scalar_array"), 33 | ] 34 | if TEST_VECTOR_INPUTS: 35 | eval_point_params.append(pytest.param(jnp.array([1.0, 2.0, 3.0]), id="vector")) 36 | 37 | # Build the list of parameters for interval evaluation 38 | 39 | eval_interval_params = [] 40 | 41 | if TEST_INCLUSION_FUNCTIONS: 42 | eval_interval_params.extend( 43 | [ 44 | pytest.param(irx.icentpert(1.0, 0.1), id="scalar_interval"), 45 | pytest.param( 46 | irx.icentpert(jnp.array([1.0]), 0.1), id="scalar_array_interval" 47 | ), 48 | ] 49 | ) 50 | 51 | if TEST_VECTOR_INPUTS: 52 | eval_interval_params.append( 53 | pytest.param( 54 | irx.icentpert(jnp.array([1.0, 2.0, 3.0]), 0.1), id="vector_interval" 55 | ) 56 | ) 57 | 58 | 59 | # --- Parametrized Fixtures --- 60 | 61 | 62 | @pytest.fixture(params=coeff_params) 63 | def poly_coeff(request): 64 | """Parametrized fixture for polynomial coefficients.""" 65 | 66 | return request.param 67 | 68 | 69 | @pytest.fixture(params=eval_point_params) 70 | def eval_point(request): 71 | """Parametrized fixture for evaluation points (scalar/vector).""" 72 | 73 | return request.param 74 | 75 | 76 | @pytest.fixture(params=eval_interval_params) 77 | def eval_interval(request): 78 | """Parametrized fixture for evaluation intervals.""" 79 | 80 | return request.param 81 | 82 | 83 | # --- Generic Test Functions --- 84 | 85 | 86 | def 
test_polynomial_evaluation(poly_coeff, eval_point): 87 | """Tests polynomial evaluation for various dynamically-provided input types.""" 88 | 89 | result = polynomial(poly_coeff, eval_point) 90 | 91 | expected = jnp.polyval(poly_coeff, eval_point) 92 | 93 | assert jnp.all(result == expected) 94 | 95 | 96 | @pytest.mark.skipif( 97 | not TEST_INCLUSION_FUNCTIONS, reason="Inclusion function tests are disabled" 98 | ) 99 | def test_inclusion_function(poly_coeff, eval_interval): 100 | """Tests the natif of the inclusion function for various interval types.""" 101 | poly_natif = irx.natif(polynomial) 102 | result = poly_natif(poly_coeff, eval_interval) 103 | 104 | # Check the type first 105 | assert isinstance(result, irx.Interval) 106 | 107 | # Validate the overapproximation by sampling 108 | validate_overapproximation_nd( 109 | lambda x: polynomial(poly_coeff, x), eval_interval, result 110 | ) 111 | 112 | 113 | @pytest.mark.skipif(not TEST_JACFWD, reason="JACFWD tests are disabled") 114 | def test_jacfwd(poly_coeff, eval_point): 115 | """Tests forward-mode AD for various dynamically-provided input types.""" 116 | pd_fwd = jax.jacfwd(polynomial, argnums=1) 117 | der_ad = pd_fwd(poly_coeff, eval_point) 118 | 119 | der_sym = jnp.polyval(jnp.polyder(poly_coeff), eval_point) 120 | 121 | # print() 122 | # print(f"{eval_point=}") 123 | # print(f"{der_ad=}") 124 | # print(f"{der_sym=}") 125 | 126 | if jnp.ndim(eval_point) == 0: 127 | expected_jacobian = der_sym 128 | else: 129 | expected_jacobian = jnp.diag(der_sym) 130 | 131 | assert jnp.allclose(der_ad, expected_jacobian) 132 | 133 | 134 | @pytest.mark.skipif(not TEST_JACREV, reason="JACREV tests are disabled") 135 | def test_jacrev(poly_coeff, eval_point): 136 | """Tests reverse-mode AD for various dynamically-provided input types.""" 137 | pd_rev = jax.jacrev(polynomial, argnums=1) 138 | jacobian = pd_rev(poly_coeff, eval_point) 139 | 140 | deriv_vals = jnp.polyval(jnp.polyder(poly_coeff), eval_point) 141 | 142 | if 
jnp.ndim(eval_point) == 0: 143 | expected_jacobian = deriv_vals 144 | else: 145 | expected_jacobian = jnp.diag(deriv_vals) 146 | 147 | assert jnp.allclose(jacobian, expected_jacobian) 148 | 149 | 150 | @pytest.mark.skipif( 151 | not (TEST_JIT_COMPILATION and TEST_INCLUSION_FUNCTIONS), 152 | reason="JIT inclusion tests are disabled", 153 | ) 154 | def test_jit_inclusion(poly_coeff, eval_interval): 155 | """Tests the JIT-compiled inclusion function for various interval types.""" 156 | poly_natif = irx.natif(polynomial) 157 | poly_natif_jit = jax.jit(poly_natif) 158 | result = poly_natif_jit(poly_coeff, eval_interval) 159 | 160 | # Check the type first 161 | assert isinstance(result, irx.Interval) 162 | 163 | # Validate the overapproximation by sampling 164 | validate_overapproximation_nd( 165 | lambda x: polynomial(poly_coeff, x), eval_interval, result 166 | ) 167 | 168 | 169 | if __name__ == "__main__": 170 | a = jnp.array([1, 2, 3]) 171 | # The major axis determines the degree of the polynomial (i.e. num rows = degree + 1) 172 | # The minor axis determines the number of polynomials (i.e. num cols = num polynomials) 173 | a_multiple = jnp.array( 174 | [ 175 | [1.0, 4, 2], 176 | [-3, 2, 2], 177 | [1, 2, 2], 178 | ] 179 | ) 180 | x = 2 181 | x_multiple = 1 + jnp.arange( 182 | a_multiple.shape[1] 183 | ) # WARN: this only works if x_multiple.shape() = a_multiple.shape()[1] 184 | 185 | # Confirmed same behavior as jnp.polyval 186 | print(polynomial(a, x)) 187 | print(polynomial(a_multiple, x)) 188 | 189 | print(polynomial(a, x_multiple)) 190 | print(polynomial(a_multiple, x_multiple)) 191 | # This batches over both arguments simultaneously (NOT product wise) 192 | # That is, the first polynomial is ONLY evaluated at the first point, the second at the second, etc. 
193 | 194 | # Checking polyder 195 | print(jnp.polyder(a)) 196 | # print(jnp.polyder(a_multiple)) # This is not supported by jnp or numpy 197 | 198 | # print(jnp.polyder(jnp.array([[1]]))) 199 | # print(jnp.polyder(jnp.array([[1], [2]]))) 200 | # print(jnp.polyder(jnp.array([[1], [2], [3]]))) 201 | # print(jnp.polyder(jnp.array([[1], [2], [3], [4]]).squeeze())) 202 | -------------------------------------------------------------------------------- /tests/test_system_continuous.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pytest 4 | from immrax import System, RawTrajectory, RawContinuousTrajectory, ContinuousTrajectory 5 | 6 | A = jnp.array([[0, 1], [0, 0]]) 7 | B = jnp.array([[0], [1]]) 8 | 9 | 10 | class LinearSys(System): 11 | def __init__(self) -> None: 12 | self.evolution = "continuous" 13 | self.xlen = 2 14 | self.name = "Double Integrator" 15 | 16 | def f(self, t, x: jax.Array, u: jax.Array) -> jax.Array: 17 | return A @ x + B @ u 18 | 19 | 20 | class HarmOsc(System): 21 | def __init__(self) -> None: 22 | self.evolution = "continuous" 23 | self.xlen = 2 24 | self.name = "Harmonic Oscillator" 25 | 26 | def f(self, t, x: jax.Array) -> jax.Array: 27 | x1, x2 = x.ravel() 28 | return jnp.array([-x2, x1]) 29 | 30 | 31 | class VanDerPolOsc(System): 32 | def __init__(self, mu: float = 1) -> None: 33 | self.evolution = "continuous" 34 | self.xlen = 2 35 | self.name = "Van der Pol Oscillator" 36 | self.mu = mu 37 | 38 | def f(self, t, x: jax.Array) -> jax.Array: 39 | x1, x2 = x.ravel() 40 | return jnp.array([self.mu * (x1 - 1 / 3 * x1**3 - x2), x1 / self.mu]) 41 | 42 | 43 | class Vehicle(System): 44 | def __init__(self) -> None: 45 | self.evolution = "continuous" 46 | self.xlen = 4 47 | 48 | def f(self, t: jax.Array, x: jax.Array, u: jax.Array, w: jax.Array) -> jax.Array: 49 | px, py, psi, v = x.ravel() 50 | u1, u2 = u.ravel() 51 | beta = jnp.arctan(jnp.tan(u2) / 2) 52 | return 
jnp.array( 53 | [v * jnp.cos(psi + beta), v * jnp.sin(psi + beta), v * jnp.sin(beta), u1] 54 | ) 55 | 56 | 57 | # --- Helper Functions --- 58 | 59 | 60 | def validate_trajectory(traj_raw: RawTrajectory): 61 | """Validates the properties of a computed trajectory.""" 62 | assert isinstance(traj_raw, RawContinuousTrajectory) 63 | assert traj_raw is not None 64 | 65 | t_finite = jnp.isfinite(traj_raw.ts) 66 | computed_ys = traj_raw.ys[jnp.where(t_finite)] 67 | padding_ys = traj_raw.ys[jnp.where(~t_finite)] 68 | 69 | assert jnp.isfinite(computed_ys).all() 70 | assert jnp.isinf(padding_ys).all() 71 | assert computed_ys.shape[1:] == traj_raw.ys.shape[1:] 72 | assert padding_ys.shape[1:] == traj_raw.ys.shape[1:] 73 | 74 | traj = traj_raw.to_convenience() 75 | assert isinstance(traj, ContinuousTrajectory) 76 | assert jnp.equal(traj.ys, computed_ys).all() 77 | 78 | 79 | # --- Fixtures for 2D Systems (no inputs) --- 80 | 81 | 82 | @pytest.fixture( 83 | params=[ 84 | # pytest.param(DoubleIntegrator(), id="HarmOsc"), 85 | pytest.param(HarmOsc(), id="HarmOsc"), 86 | pytest.param(VanDerPolOsc(), id="VanDerPolOsc"), 87 | ] 88 | ) 89 | def system_2d(request): 90 | return request.param 91 | 92 | 93 | @pytest.fixture( 94 | params=[ 95 | pytest.param(jnp.array([1.0, 0.0]), id="unit_x"), 96 | pytest.param(jnp.array([0.0, 1.0]), id="unit_y"), 97 | pytest.param(jnp.zeros((2,)), id="zeros"), 98 | pytest.param(jnp.ones((2,)), id="ones"), 99 | pytest.param(jnp.array([0.5, -0.5]), id="pos_neg"), 100 | ] 101 | ) 102 | def x0_2d(request): 103 | return request.param 104 | 105 | 106 | # --- Fixtures for 4D Systems (with inputs) --- 107 | 108 | 109 | @pytest.fixture 110 | def system_4d(): 111 | return Vehicle() 112 | 113 | 114 | @pytest.fixture 115 | def x0_4d(): 116 | return jnp.array([0.0, 0.0, 0.0, 1.0]) 117 | 118 | 119 | @pytest.fixture( 120 | params=[ 121 | pytest.param( 122 | (lambda t, x: jnp.array([0.1, 0.1]), lambda x, t: jnp.zeros(1)), 123 | id="const_input_1", 124 | ), 125 | 
pytest.param( 126 | (lambda t, x: jnp.array([-0.1, 0.2]), lambda x, t: jnp.zeros(1)), 127 | id="const_input_2", 128 | ), 129 | ] 130 | ) 131 | def vehicle_inputs(request): 132 | """Fixture for Vehicle system inputs.""" 133 | return request.param 134 | 135 | 136 | @pytest.fixture 137 | def system_linear(): 138 | """Fixture for the LinearSys system.""" 139 | return LinearSys() 140 | 141 | 142 | # --- Test Functions --- 143 | 144 | 145 | def test_compute_trajectory_2d(system_2d, x0_2d): 146 | """Tests trajectory computation for 2D systems without inputs.""" 147 | traj_diffrax = system_2d.compute_trajectory(t0=0, tf=1, x0=x0_2d) 148 | validate_trajectory(traj_diffrax) 149 | 150 | 151 | def test_compute_trajectory_4d(system_4d, x0_4d, vehicle_inputs): 152 | """Tests trajectory computation for the 4D Vehicle system with inputs.""" 153 | traj_diffrax = system_4d.compute_trajectory( 154 | t0=0, tf=1, x0=x0_4d, inputs=vehicle_inputs 155 | ) 156 | validate_trajectory(traj_diffrax) 157 | 158 | 159 | def test_linear_sys_stabilization(system_linear, x0_2d): 160 | """Tests stabilization of a linear system with feedback control.""" 161 | n = A.shape[0] 162 | 163 | # 1. Compute controllability matrix 164 | C = jnp.hstack([B] + [jnp.linalg.matrix_power(A, i) @ B for i in range(1, n)]) 165 | 166 | # 2. Desired characteristic polynomial (poles at -1, -2) 167 | # p(s) = s^2 + 3s + 2 168 | # p(A) = A^2 + 3A + 2I 169 | pA = jnp.linalg.matrix_power(A, 2) + 3 * A + 2 * jnp.eye(n) 170 | 171 | # 3. 
Compute gain K using Ackermann's formula 172 | e_n_T = jnp.zeros((1, n)).at[0, -1].set(1.0) 173 | K = e_n_T @ jnp.linalg.inv(C) @ pA 174 | 175 | def controller(t, x): 176 | return -K @ x 177 | 178 | inputs = (controller,) 179 | 180 | traj_diffrax = system_linear.compute_trajectory( 181 | t0=0, tf=10.0, x0=x0_2d, inputs=inputs 182 | ) 183 | 184 | validate_trajectory(traj_diffrax) 185 | 186 | # Assert that the final state is close to zero 187 | final_state = traj_diffrax.ys[jnp.where(jnp.isfinite(traj_diffrax.ts))][-1] 188 | assert jnp.allclose(final_state, jnp.zeros_like(final_state), atol=1e-2) 189 | 190 | # Sanity check: With a zero controller, the system should not stabilize. 191 | if not jnp.allclose(x0_2d, jnp.zeros_like(x0_2d)): 192 | zero_controller = (lambda t, x: jnp.zeros((1,)),) 193 | traj_uncontrolled = system_linear.compute_trajectory( 194 | t0=0, tf=10.0, x0=x0_2d, inputs=zero_controller 195 | ) 196 | validate_trajectory(traj_uncontrolled) 197 | final_state_uncontrolled = traj_uncontrolled.ys[ 198 | jnp.where(jnp.isfinite(traj_uncontrolled.ts)) 199 | ][-1] 200 | assert not jnp.allclose( 201 | final_state_uncontrolled, 202 | jnp.zeros_like(final_state_uncontrolled), 203 | atol=1e-1, 204 | ) 205 | 206 | 207 | def test_ragged_trajectory_continuous(system_2d, x0_2d): 208 | """Tests ragged trajectory creation for continuous systems.""" 209 | tfs = jnp.arange(1.0, 2.0, 0.2) 210 | 211 | # vmap compute_trajectory over tf 212 | # We expect this to create a ragged trajectory, as each `tf` is different. 213 | compute_traj_vmap = jax.vmap(system_2d.compute_trajectory, in_axes=(None, 0, None)) 214 | 215 | raw_traj = compute_traj_vmap(0.0, tfs, x0_2d) 216 | 217 | traj = raw_traj.to_convenience() 218 | 219 | # 1. is_ragged should return true 220 | assert traj.is_ragged() 221 | 222 | # 2. The list of ts should have the same length as the range of tfs 223 | assert len(traj.ts) == len(tfs) 224 | 225 | # 3. 
Each ts should correspond to a ys of the same length 226 | for i in range(len(tfs)): 227 | assert len(traj.ts[i]) == len(traj.ys[i]) 228 | 229 | # 4. Each element of ys should be finite 230 | for i in range(len(tfs)): 231 | assert jnp.all(jnp.isfinite(traj.ys[i])) 232 | -------------------------------------------------------------------------------- /immrax/refinement/factories.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from itertools import permutations 3 | from typing import Callable 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from linrax import linprog 9 | from immrax.inclusion import Interval, icopy, interval 10 | from immrax.utils import angular_sweep, null_space 11 | 12 | 13 | class Refinement(abc.ABC): 14 | def __init__(self) -> None: 15 | pass 16 | 17 | @abc.abstractmethod 18 | def get_refine_func(self) -> Callable[[Interval], Interval]: 19 | pass 20 | 21 | 22 | class LinProgRefinement(Refinement): 23 | H: jax.Array 24 | 25 | def __init__(self, H: jax.Array) -> None: 26 | self.H = H 27 | super().__init__() 28 | 29 | def get_refine_func(self) -> Callable[[Interval], Interval]: 30 | A_ub = jnp.vstack((self.H, -self.H)) 31 | 32 | def var_refine(idx: int, ret: Interval) -> Interval: 33 | # I update b_eq and b_ub here because ret is shrinking 34 | b_ub = jnp.concatenate( 35 | (ret.upper, -ret.lower) 36 | ) # TODO: try adding buffer region *inside* the bounds to collapsed face 37 | obj_vec_i = self.H[idx] 38 | 39 | sol_min, sol_type_min = linprog( 40 | c=obj_vec_i, 41 | A_ub=A_ub, 42 | b_ub=b_ub, 43 | unbounded=True, 44 | ) 45 | 46 | sol_max, sol_type_max = linprog( 47 | c=-obj_vec_i, 48 | A_ub=A_ub, 49 | b_ub=b_ub, 50 | unbounded=True, 51 | ) 52 | 53 | # If a vector that gives extra info on this var is found, refine bounds 54 | new_lower_i = jnp.where( 55 | sol_type_min.success, 56 | jnp.maximum(sol_min.fun, ret.lower[idx]), 57 | ret.lower[idx], 58 | )[0] 59 | retl = 
ret.lower.at[idx].set(new_lower_i) 60 | new_upper_i = jnp.where( 61 | sol_type_max.success, 62 | jnp.minimum(-sol_max.fun, ret.upper[idx]), 63 | ret.upper[idx], 64 | )[0] 65 | retu = ret.upper.at[idx].set(new_upper_i) 66 | 67 | return interval(retl, retu) 68 | 69 | def I_r(y: Interval) -> Interval: 70 | # for i in range(n): 71 | # ret = var_refine(ret, i) 72 | 73 | return jax.lax.fori_loop(0, y.shape[0], var_refine, icopy(y)) 74 | 75 | return I_r 76 | 77 | 78 | class SampleRefinement(Refinement): 79 | H: jax.Array 80 | Hp: jax.Array 81 | N: jax.Array 82 | A_lib: jax.Array 83 | num_samples: int 84 | 85 | def __init__(self, H: jax.Array, num_samples: int = 10) -> None: 86 | self.num_samples = num_samples 87 | self.H = H 88 | # self.Hp = jnp.linalg.pinv(H) 89 | self.Hp = jnp.hstack( 90 | (jnp.eye(H.shape[1]), jnp.zeros((H.shape[1], H.shape[0] - H.shape[1]))) 91 | ) 92 | 93 | self.N = jnp.array( 94 | [ 95 | jnp.squeeze( 96 | null_space(jnp.vstack([jnp.eye(H.shape[1]), aug_var]).T, dim_null=1) 97 | ) 98 | for aug_var in H[H.shape[1] :] 99 | ] 100 | ).T 101 | # assert not jnp.any(jnp.isnan(self.N)) 102 | self.N = jnp.vstack([self.N[: H.shape[1]], jnp.diag(self.N[-1])]) 103 | 104 | # Sample aux vars independently 105 | self.A_lib = self.N.T 106 | 107 | # Sample aux vars pairwise 108 | if self.N.shape[1] > 1: 109 | points = angular_sweep(num_samples) 110 | extended_points = jnp.hstack( 111 | [ 112 | points, 113 | jnp.zeros((points.shape[0], self.N.shape[1] - points.shape[1])), 114 | ] 115 | ) 116 | 117 | def permutation(mat: jax.Array, perm): 118 | permuted_matrix = jnp.zeros_like(mat) 119 | for i, p in enumerate(perm): 120 | permuted_matrix = permuted_matrix.at[:, p].set(mat[:, i]) 121 | return permuted_matrix 122 | 123 | points_permutations = jax.vmap(permutation, in_axes=(None, 0))( 124 | extended_points, 125 | jnp.array(list(permutations(range(self.N.shape[1]), 2))), 126 | ) 127 | points_permutations = points_permutations.reshape( 128 | -1, 
points_permutations.shape[-1] 129 | ) 130 | points_permutations = points_permutations @ self.N.T 131 | self.A_lib = jnp.vstack([self.A_lib, points_permutations]) 132 | # assert jnp.allclose(self.A_lib @ self.H, 0, atol=1e-6) 133 | 134 | super().__init__() 135 | 136 | def get_refine_func(self) -> Callable[[Interval], Interval]: 137 | def vec_refine(null_vector: jax.Array, var_index: jax.Array, y: Interval): 138 | ret = icopy(y) 139 | 140 | # Set up linear algebra computations for the refinement 141 | bounding_vars = interval(null_vector.at[var_index].set(0)) 142 | ref_var = interval(null_vector[var_index]) 143 | b1 = lambda: ((-bounding_vars @ ret) / ref_var) & ret[var_index] 144 | b2 = lambda: ret[var_index] 145 | 146 | # Compute refinement based on null vector, if possible 147 | ndb0 = jnp.abs(null_vector[var_index]) > 1e-10 148 | ret = jax.lax.cond(ndb0, b1, b2) 149 | 150 | # fix fpe problem with upper < lower 151 | retu = jnp.where(ret.upper >= ret.lower, ret.upper, ret.lower) 152 | return interval(ret.lower, retu) 153 | 154 | mat_refine = jax.vmap(vec_refine, in_axes=(0, None, None), out_axes=0) 155 | mat_refine_all = jax.vmap(mat_refine, in_axes=(None, 0, None), out_axes=1) 156 | 157 | def best_refinement(y: Interval): 158 | refinements = mat_refine_all(self.A_lib, jnp.arange(len(y)), y) 159 | lower = jnp.fmax( 160 | y.lower, refinements.lower 161 | ) # Some refinements don't work, we need to ignore nans 162 | upper = jnp.fmin(y.upper, refinements.upper) 163 | 164 | return interval( 165 | jnp.max(lower, axis=0), # TODO: I need to cap this at upper boun 166 | jnp.min(upper, axis=0), 167 | ) 168 | 169 | return best_refinement 170 | 171 | 172 | class NullVecRefinement(Refinement): 173 | null_vec: jax.Array 174 | 175 | def __init__(self, null_vec: jax.Array) -> None: 176 | self.null_vec = null_vec 177 | super().__init__() 178 | 179 | def get_refine_func(self) -> Callable[[Interval], Interval]: 180 | def vec_refine(null_vector: jax.Array, var_index: jax.Array, y: 
Interval): 181 | ret = icopy(y) 182 | 183 | # Set up linear algebra computations for the refinement 184 | bounding_vars = interval(null_vector.at[var_index].set(0)) 185 | ref_var = interval(null_vector[var_index]) 186 | b1 = lambda: ((-bounding_vars @ ret) / ref_var) & ret[var_index] 187 | b2 = lambda: ret[var_index] 188 | 189 | # Compute refinement based on null vector, if possible 190 | ndb0 = jnp.abs(null_vector[var_index]) > 1e-10 191 | ret = jax.lax.cond(ndb0, b1, b2) 192 | 193 | # fix fpe problem with upper < lower 194 | retu = jnp.where(ret.upper >= ret.lower, ret.upper, ret.lower) 195 | return interval(ret.lower, retu) 196 | 197 | vec_refine_all = jax.vmap(vec_refine, in_axes=(None, 0, None), out_axes=1) 198 | 199 | def refinement(y: Interval): 200 | return vec_refine_all(self.null_vec, jnp.arange(len(y)), y) 201 | 202 | return refinement 203 | -------------------------------------------------------------------------------- /tests/test_system_discrete.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pytest 4 | from immrax import System, RawTrajectory, RawDiscreteTrajectory, DiscreteTrajectory 5 | 6 | # --- Systems --- 7 | 8 | # Nilpotent system 9 | A_nilpotent = jnp.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]]) 10 | n_nilpotent = A_nilpotent.shape[0] 11 | 12 | 13 | class NilpotentSys(System): 14 | def __init__(self) -> None: 15 | self.evolution = "discrete" 16 | self.xlen = n_nilpotent 17 | self.name = "Nilpotent System" 18 | 19 | def f(self, t, x: jax.Array) -> jax.Array: 20 | return A_nilpotent @ x 21 | 22 | 23 | # Linear system 24 | A_linear = jnp.array([[1.1, 0.2], [-0.2, 1.1]]) # Unstable system 25 | B_linear = jnp.array([[0.1], [1.0]]) 26 | 27 | 28 | class LinearSysDiscrete(System): 29 | def __init__(self) -> None: 30 | self.evolution = "discrete" 31 | self.xlen = 2 32 | self.name = "Discrete Linear System" 33 | 34 | def f(self, t, x: jax.Array, u: jax.Array) -> 
jax.Array: 35 | return A_linear @ x + B_linear @ u 36 | 37 | 38 | # Nonlinear system 39 | class LogisticMap(System): 40 | def __init__(self, r: float = 3.9) -> None: 41 | self.evolution = "discrete" 42 | self.xlen = 1 43 | self.name = "Logistic Map" 44 | self.r = r 45 | 46 | def f(self, t, x: jax.Array) -> jax.Array: 47 | return self.r * x * (1 - x) 48 | 49 | 50 | # --- Helper Functions --- 51 | 52 | 53 | def validate_trajectory(traj_raw: RawTrajectory): 54 | """Validates a raw discrete trajectory: states at finite times are finite, and its convenience form keeps exactly those states.""" 55 | assert isinstance(traj_raw, RawDiscreteTrajectory) 56 | assert traj_raw is not None # NOTE(review): already implied by the isinstance check above 57 | 58 | t_finite = jnp.isfinite(traj_raw.ts) 59 | computed_ys = traj_raw.ys[jnp.where(t_finite)] 60 | padding_ys = traj_raw.ys[jnp.where(~t_finite)] 61 | 62 | assert jnp.isfinite(computed_ys).all() 63 | # assert jnp.isinf(padding_ys).all() -- disabled: System.compute_trajectory pads with inf * x0, which is nan wherever x0 has a zero entry 64 | assert computed_ys.shape[1:] == traj_raw.ys.shape[1:] 65 | assert padding_ys.shape[1:] == traj_raw.ys.shape[1:] 66 | 67 | traj = traj_raw.to_convenience() 68 | assert isinstance(traj, DiscreteTrajectory) 69 | assert jnp.equal(traj.ys, computed_ys).all() 70 | 71 | 72 | # --- Fixtures --- 73 | 74 | 75 | @pytest.fixture(params=[pytest.param(LogisticMap(), id="LogisticMap")]) 76 | def system_1d(request): 77 | return request.param 78 | 79 | 80 | @pytest.fixture( 81 | params=[ 82 | pytest.param(jnp.array([0.1]), id="0.1"), 83 | pytest.param(jnp.array([0.5]), id="0.5"), 84 | ] 85 | ) 86 | def x0_1d(request): 87 | return request.param 88 | 89 | 90 | @pytest.fixture 91 | def system_linear_discrete(): 92 | return LinearSysDiscrete() 93 | 94 | 95 | @pytest.fixture( 96 | params=[ 97 | pytest.param(jnp.array([1.0, 0.0]), id="unit_x"), 98 | pytest.param(jnp.array([0.0, 1.0]), id="unit_y"), 99 | pytest.param(jnp.ones(2), id="ones"), 100 | ] 101 | ) 102 | def x0_2d(request): 103 | return request.param 104 | 105 | 106 | @pytest.fixture 107 | def system_nilpotent(): 108 | return NilpotentSys() 109 | 110 | 111 | @pytest.fixture 112 | def
x0_3d(): 113 | return jnp.array([1.0, 2.0, 3.0]) 114 | 115 | 116 | # --- Test Functions --- 117 | 118 | 119 | def test_compute_trajectory_1d(system_1d, x0_1d): 120 | """Tests trajectory computation for 1D systems without inputs.""" 121 | traj = system_1d.compute_trajectory(t0=0, tf=10, x0=x0_1d) 122 | validate_trajectory(traj) 123 | traj = traj.to_convenience() 124 | assert jnp.equal(traj.ys[0], x0_1d).all() 125 | assert traj.ys.shape == (11, 1) 126 | 127 | 128 | def test_compute_trajectory_2d(system_linear_discrete, x0_2d): 129 | """Tests trajectory computation for 2D systems with constant input.""" 130 | 131 | def controller(t, x): 132 | return jnp.array([0.1]) 133 | 134 | inputs = (controller,) 135 | traj = system_linear_discrete.compute_trajectory( 136 | t0=0, tf=10, x0=x0_2d, inputs=inputs 137 | ) 138 | validate_trajectory(traj) 139 | traj = traj.to_convenience() 140 | assert jnp.equal(traj.ys[0], x0_2d).all() 141 | assert traj.ys.shape == (11, 2) 142 | 143 | 144 | def test_nilpotent_convergence(system_nilpotent, x0_3d): 145 | """Tests that a nilpotent system converges to zero in n steps.""" 146 | n = system_nilpotent.xlen 147 | traj = system_nilpotent.compute_trajectory(t0=0, tf=n, x0=x0_3d) 148 | validate_trajectory(traj) 149 | traj = traj.to_convenience() 150 | assert jnp.equal(traj.ys[0], x0_3d).all() 151 | assert traj.ys.shape[0] == n + 1 152 | 153 | # State at time n should be zero 154 | final_state = traj.ys[-1] 155 | assert jnp.allclose(final_state, jnp.zeros_like(final_state)) 156 | 157 | # For sanity, check that A^n is the zero matrix 158 | An = jnp.linalg.matrix_power(A_nilpotent, n) 159 | assert jnp.allclose(An, jnp.zeros_like(An)) 160 | 161 | 162 | def test_linear_sys_stabilization_discrete(system_linear_discrete, x0_2d): 163 | """Tests stabilization of a discrete linear system with feedback control.""" 164 | n = system_linear_discrete.xlen 165 | A, B = A_linear, B_linear 166 | 167 | # 1.
Compute controllability matrix 168 | C = jnp.hstack([B] + [jnp.linalg.matrix_power(A, i) @ B for i in range(1, n)]) 169 | assert jnp.linalg.matrix_rank(C) == n, "System is not controllable" 170 | 171 | # 2. Desired characteristic polynomial (poles at 0.5, 0.6) 172 | # p(z) = (z - 0.5)(z - 0.6) = z^2 - 1.1z + 0.3 173 | # p(A) = A^2 - 1.1A + 0.3I 174 | pA = jnp.linalg.matrix_power(A, 2) - 1.1 * A + 0.3 * jnp.eye(n) 175 | 176 | # 3. Compute gain K using Ackermann's formula 177 | e_n_T = jnp.zeros((1, n)).at[0, -1].set(1.0) 178 | K = e_n_T @ jnp.linalg.inv(C) @ pA 179 | # K has shape (1, n), so -K @ x is the length-1 control input expected by B_linear 180 | def controller(t, x): 181 | return -K @ x 182 | 183 | inputs = (controller,) 184 | 185 | traj = system_linear_discrete.compute_trajectory( 186 | t0=0, tf=50, x0=x0_2d, inputs=inputs 187 | ) 188 | 189 | validate_trajectory(traj) 190 | traj = traj.to_convenience() 191 | assert jnp.equal(traj.ys[0], x0_2d).all() 192 | 193 | # Assert that the final state is close to zero 194 | final_state = traj.ys[-1] 195 | assert jnp.allclose(final_state, jnp.zeros_like(final_state), atol=1e-2) 196 | 197 | # Sanity check: With a zero controller, the system should not stabilize. 198 | if not jnp.allclose(x0_2d, jnp.zeros_like(x0_2d)): 199 | zero_controller = (lambda t, x: jnp.zeros((1,)),) 200 | traj_uncontrolled = system_linear_discrete.compute_trajectory( 201 | t0=0, tf=50, x0=x0_2d, inputs=zero_controller 202 | ) 203 | validate_trajectory(traj_uncontrolled) 204 | traj_uncontrolled = traj_uncontrolled.to_convenience() 205 | final_state_uncontrolled = traj_uncontrolled.ys[-1] 206 | assert not jnp.allclose( 207 | final_state_uncontrolled, 208 | jnp.zeros_like(final_state_uncontrolled), 209 | atol=1e-1, 210 | ) 211 | 212 | 213 | def test_ragged_trajectory_discrete(system_1d, x0_1d): 214 | """Tests ragged trajectory creation for discrete systems.""" 215 | tfs = jnp.arange(5, 10) 216 | 217 | # vmap compute_trajectory over tf 218 | # We expect this to create a ragged trajectory, as each `tf` is different.
219 | compute_traj_vmap = jax.vmap(system_1d.compute_trajectory, in_axes=(None, 0, None)) 220 | 221 | raw_traj = compute_traj_vmap(0, tfs, x0_1d) 222 | 223 | traj = raw_traj.to_convenience() 224 | 225 | # 1. is_ragged should return true 226 | assert traj.is_ragged() 227 | 228 | # 2. The list of ts should have the same length as the range of tfs 229 | assert len(traj.ts) == len(tfs) 230 | 231 | # 3. Each ts should correspond to a ys of the same length 232 | for i in range(len(tfs)): 233 | assert len(traj.ts[i]) == len(traj.ys[i]) 234 | # For discrete systems, trajectory length is tf + 1 (for t0=0) 235 | assert len(traj.ts[i]) == tfs[i] + 1 236 | 237 | # 4. Each element of ys should be finite 238 | for i in range(len(tfs)): 239 | assert jnp.all(jnp.isfinite(traj.ys[i])) 240 | -------------------------------------------------------------------------------- /immrax/system/system.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from functools import partial 3 | from typing import Any, Callable, List, Literal, Union 4 | 5 | from diffrax import ( 6 | AbstractSolver, 7 | Dopri5, 8 | Euler, 9 | ODETerm, 10 | SaveAt, 11 | Tsit5, 12 | diffeqsolve, 13 | ) 14 | from immutabledict import immutabledict 15 | import jax 16 | import jax.numpy as jnp 17 | from jaxtyping import Float, Integer 18 | 19 | from .trajectory import ( 20 | RawTrajectory, 21 | RawContinuousTrajectory, 22 | RawDiscreteTrajectory, 23 | ) 24 | 25 | # NOTE(review): not raised anywhere in this module; compute_trajectory raises a plain Exception for time-type mismatches. 26 | class EvolutionError(Exception): 27 | def __init__(self, t: Any, evolution: Literal["continuous", "discrete"]) -> None: 28 | super().__init__( 29 | f"Time {t} of type {type(t)} does not match evolution type {evolution}" 30 | ) 31 | 32 | 33 | class System(abc.ABC): 34 | r"""System 35 | 36 | A dynamical system of one of the following forms: 37 | 38 | .. math:: 39 | \dot{x} = f(t, x, \dots), \text{ or } x^+ = f(t, x, \dots).
40 | 41 | where :math:`t\in T\in\{\mathbb{Z},\mathbb{R}\}` is a discrete or continuous time variable, :math:`x\in\mathbb{R}^n` is the state of the system, and :math:`\dots` are some other inputs, perhaps control and disturbance. 42 | 43 | There are two main attributes that need to be defined in a subclass: 44 | 45 | - `evolution` : Literal['continuous', 'discrete'], which specifies whether the system is continuous or discrete. 46 | - `xlen` : int, which specifies the dimension of the state space. 47 | 48 | The main method that needs to be defined is `f(t, x, *args, **kwargs)`, which returns the time evolution of the state at time `t` and state `x`. 49 | """ 50 | 51 | evolution: Literal["continuous", "discrete"] 52 | xlen: int 53 | 54 | @abc.abstractmethod 55 | def f(self, t: Union[Integer, Float], x: jax.Array, *args, **kwargs) -> jax.Array: 56 | """The right hand side of the system 57 | 58 | Parameters 59 | ---------- 60 | t : Union[Integer, Float] 61 | The time of the system 62 | x : jax.Array 63 | The state of the system 64 | *args : 65 | Inputs (control, disturbance, etc.) as positional arguments, depending on the subclass (e.g. OpenLoopSystem takes u and w). 66 | **kwargs : 67 | Other keyword arguments, depending on the subclass.
68 | 69 | Returns 70 | ------- 71 | jax.Array 72 | The time evolution of the state 73 | 74 | """ 75 | 76 | def __call__(self, *args: Any, **kwargs: Any) -> Any: 77 | return self.f(*args, **kwargs) 78 | 79 | @partial( 80 | jax.jit, static_argnums=(0, 4), static_argnames=("solver", "f_kwargs", "inputs") 81 | ) 82 | def compute_trajectory( 83 | self, 84 | t0: Union[Integer, Float], 85 | tf: Union[Integer, Float], 86 | x0: jax.Array, 87 | inputs: List[Callable[[int, jax.Array], jax.Array]] = [], 88 | dt: float = 0.01, 89 | *, 90 | solver: Union[Literal["euler", "rk45", "tsit5"], AbstractSolver] = "tsit5", 91 | f_kwargs: immutabledict = immutabledict({}), 92 | **kwargs, 93 | ) -> RawTrajectory: 94 | """Computes the trajectory of the system from time t0 to tf with initial condition x0. 95 | 96 | Parameters 97 | ---------- 98 | t0 : Union[Integer,Float] 99 | Initial time 100 | tf : Union[Integer,Float] 101 | Final time 102 | x0 : jax.Array 103 | Initial condition 104 | inputs : Tuple[Callable[[int,jax.Array], jax.Array]], optional 105 | A tuple of Callables u(t,x) returning time/state varying inputs as positional arguments into f, by default () 106 | dt : float, optional 107 | Time step, by default 0.01 108 | solver : Union[Literal['euler', 'rk45', 'tsit5'], AbstractSolver], optional 109 | Solver to use for diffrax, by default 'tsit5' 110 | f_kwargs : immutabledict, optional 111 | An immutabledict to pass as keyword arguments to the dynamics f, by default {} 112 | **kwargs : 113 | Additional kwargs to pass to the solver from diffrax 114 | 115 | Returns 116 | ------- 117 | RawTrajectory 118 | Flow line / trajectory of the system from the initial condition x0 to the final time tf 119 | """ 120 | 121 | def func(t, x, args): 122 | # Unpack the inputs 123 | return self.f(t, x, *[u(t, x) for u in inputs], **f_kwargs) 124 | 125 | if self.evolution == "continuous": 126 | term = ODETerm(func) 127 | if solver == "euler": 128 | solver = Euler() 129 | elif solver == "rk45": 130 | 
solver = Dopri5() 131 | elif solver == "tsit5": 132 | solver = Tsit5() 133 | elif isinstance(solver, AbstractSolver): 134 | pass 135 | else: 136 | raise Exception(f"{solver=} is not a valid solver") 137 | 138 | saveat = SaveAt(t0=True, t1=True, steps=True) 139 | sol = diffeqsolve(term, solver, t0, tf, dt, x0, saveat=saveat, **kwargs) 140 | return RawContinuousTrajectory(sol) 141 | 142 | elif self.evolution == "discrete": 143 | if not ( 144 | jnp.issubdtype(jnp.array(t0).dtype, jnp.integer) 145 | and jnp.issubdtype(jnp.array(tf).dtype, jnp.integer) 146 | ): 147 | raise Exception( 148 | f"Times {t0=} and {tf=} must be integers for discrete evolution, got {type(t0)=} and {type(tf)=}" 149 | ) 150 | 151 | max_steps = 4096 # NOTE(review): fixed buffer size; when tf - t0 >= max_steps the trajectory is silently truncated 152 | times = jnp.where( 153 | jnp.arange(max_steps) <= tf - t0, 154 | t0 + jnp.arange(max_steps), 155 | jnp.inf * jnp.ones(max_steps), 156 | ) 157 | 158 | # Use jax.lax.scan to compute the trajectory of the discrete system 159 | def step(x, t): 160 | xtp1 = jax.lax.cond( 161 | t < tf, lambda: func(t, x, None), lambda: jnp.inf * x0 162 | ) 163 | return xtp1, xtp1 164 | # ys[0] = x0 and ys[k] = state after k steps; the last scan output is dropped so ys lines up with times 165 | _, traj = jax.lax.scan(step, x0, times) 166 | return RawDiscreteTrajectory(times, jnp.vstack((x0, traj[:-1]))) 167 | else: 168 | raise Exception( 169 | f"Evolution needs to be 'continuous' or 'discrete', got {self.evolution=}" 170 | ) 171 | 172 | 173 | class ReversedSystem(System): 174 | """ReversedSystem 175 | A system with time reversed dynamics, i.e. :math:`\\dot{x} = -f(t,x,...)`.
176 | """ 177 | 178 | sys: System 179 | 180 | def __init__(self, sys: System) -> None: 181 | self.evolution = sys.evolution 182 | self.xlen = sys.xlen 183 | self.sys = sys 184 | # NOTE(review): evolution is copied from sys; for a discrete sys this gives x+ = -f(...), which is not an inverse map 185 | def f(self, t: Union[Integer, Float], x: jax.Array, *args, **kwargs) -> jax.Array: 186 | return -self.sys.f(t, x, *args, **kwargs) 187 | 188 | 189 | class LinearTransformedSystem(System): 190 | """Linear Transformed System 191 | A system with dynamics :math:`\\dot{x} = Tf(t, T^{-1}x, ...)` where :math:`T` is an invertible linear transformation. 192 | """ 193 | 194 | sys: System 195 | 196 | def __init__(self, sys: System, T: jax.Array) -> None: 197 | self.evolution = sys.evolution 198 | self.xlen = sys.xlen 199 | self.sys = sys 200 | self.T = T 201 | self.Tinv = jnp.linalg.inv(T) 202 | 203 | def f(self, t: Union[Integer, Float], x: jax.Array, *args, **kwargs) -> jax.Array: 204 | return self.T @ self.sys.f(t, self.Tinv @ x, *args, **kwargs) 205 | 206 | 207 | class LiftedSystem(System): 208 | """Lifted System 209 | A system with dynamics :math:`\\dot{x} = Hf(t, H^+x, ...)` where :math:`H^+Hx = x`. 210 | """ 211 | 212 | sys: System 213 | H: jax.Array 214 | Hp: jax.Array 215 | 216 | def __init__(self, sys: System, H: jax.Array, Hp: jax.Array) -> None: 217 | self.evolution = sys.evolution 218 | self.xlen = H.shape[0] 219 | self.sys = sys 220 | self.H = H 221 | self.Hp = Hp 222 | 223 | def f(self, t: Union[Integer, Float], x: jax.Array, *args, **kwargs) -> jax.Array: 224 | return self.H @ self.sys.f(t, self.Hp @ x, *args, **kwargs) 225 | 226 | 227 | class OpenLoopSystem(System, abc.ABC): 228 | """OpenLoopSystem 229 | An open-loop nonlinear dynamical system of the form 230 | 231 | .. math:: 232 | 233 | \\dot{x} = f(t,x,u,w), 234 | 235 | where :math:`x\\in\\mathbb{R}^n` is the state of the system, :math:`u\\in\\mathbb{R}^p` is a control input to the system, and :math:`w\\in\\mathbb{R}^q` is a disturbance input.
236 | """ 237 | 238 | @abc.abstractmethod 239 | def f( 240 | self, t: Union[Integer, Float], x: jax.Array, u: jax.Array, w: jax.Array 241 | ) -> jax.Array: 242 | """The right hand side of the open-loop system 243 | 244 | Parameters 245 | ---------- 246 | t : Union[Integer, Float] 247 | The time of the system 248 | x : jax.Array 249 | The state of the system 250 | u : jax.Array 251 | The control input to the system 252 | w : jax.Array 253 | The disturbance input to the system 254 | 255 | Returns 256 | ------- 257 | jax.Array 258 | The time evolution of the state 259 | 260 | """ 261 | -------------------------------------------------------------------------------- /tests/test_cubic_spline.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import pytest 5 | import matplotlib.pyplot as plt 6 | 7 | import immrax as irx 8 | from immrax.inclusion.cubic_spline import ( 9 | create_cubic_spline_coeffs, 10 | make_spline_eval_fn, 11 | ) 12 | from immrax.inclusion.polynomial import polynomial 13 | from tests.utils import validate_overapproximation_1d_list 14 | 15 | # --- Test Case Flags --- 16 | TEST_INCLUSION_FUNCTIONS = True 17 | TEST_JIT_COMPILATION = True 18 | TEST_JACFWD = True 19 | TEST_JACREV = True 20 | 21 | # --- Fixtures for test data --- 22 | 23 | 24 | @pytest.fixture(scope="module") 25 | def spline_raw_data(): 26 | """Pytest fixture to generate raw data for spline tests.""" 27 | num_points = 20 28 | x_range = (0.0, 50.0) 29 | x_coords = np.linspace(x_range[0], x_range[1], num_points) 30 | y_range = (-10.0, 10.0) 31 | y_coords = np.random.uniform(y_range[0], y_range[1], size=num_points) 32 | points = jnp.array(np.vstack((x_coords, y_coords)).T) 33 | return points 34 | 35 | 36 | @pytest.fixture(scope="module") 37 | def spline_coeffs(spline_raw_data): 38 | """Pytest fixture to compute spline coefficients.""" 39 | return create_cubic_spline_coeffs(spline_raw_data) 40 
| 41 | 42 | @pytest.fixture(scope="module") 43 | def spline_eval_fn(spline_coeffs): 44 | """Pytest fixture to create a spline evaluation function.""" 45 | x_knots, coeffs = spline_coeffs 46 | return make_spline_eval_fn(x_knots, coeffs) 47 | 48 | 49 | @pytest.fixture( 50 | params=[ 51 | pytest.param(10.5, id="scalar_float"), 52 | pytest.param(jnp.array([10.5]), id="scalar_array"), 53 | pytest.param(jnp.array([10.5, 20.5, 30.5]), id="vector"), 54 | ] 55 | ) 56 | def eval_point(request): 57 | """Parametrized fixture for evaluation points.""" 58 | return request.param 59 | 60 | 61 | # --- Helper Functions --- 62 | 63 | 64 | def make_spline_derivative_eval_fn(x_knots, coeffs): 65 | """Creates a function that evaluates the derivative of the cubic spline.""" 66 | a, b, c, d = coeffs # a (the constant term) drops out of the derivative and is unused 67 | 68 | def derivative_eval_fn(x_eval): 69 | """Evaluates the spline derivative at given x-values.""" 70 | i = jnp.sum( 71 | jax.vmap(lambda x: x_knots[1:] < x)(jnp.atleast_1d(x_eval)), 72 | axis=1, 73 | ) 74 | dx = x_eval - x_knots[i] 75 | 76 | # Derivative of the polynomial: 3*d*dx^2 + 2*c*dx + b 77 | der_coeffs = jnp.array([3 * d[i], 2 * c[i], b[i]]).squeeze() 78 | return polynomial(der_coeffs, dx) 79 | 80 | return derivative_eval_fn 81 | 82 | 83 | # --- Test Functions --- 84 | 85 | 86 | def test_spline_evaluation_at_knots(spline_raw_data, spline_eval_fn): 87 | """Tests if the spline evaluates to the original y-values at the knots.""" 88 | original_x = spline_raw_data[:, 0] 89 | original_y = spline_raw_data[:, 1] 90 | 91 | evaluated_y = spline_eval_fn(original_x) 92 | 93 | assert jnp.allclose(original_y, evaluated_y, atol=1e-6) 94 | 95 | 96 | @pytest.mark.skipif( 97 | not TEST_INCLUSION_FUNCTIONS, reason="Inclusion function tests are disabled" 98 | ) 99 | def test_inclusion_function(spline_coeffs, spline_eval_fn): 100 | """Tests the natural inclusion function (natif) of the spline on sub-intervals strictly between consecutive knots.""" 101 | x_knots, _ = spline_coeffs 102 | 103 | eval_interval = irx.interval(x_knots[:-1] + 0.1, x_knots[1:] -
0.1) 104 | 105 | spline_natif = jax.vmap(irx.natif(spline_eval_fn)) 106 | result = spline_natif(eval_interval) 107 | 108 | assert isinstance(result, irx.Interval) 109 | 110 | validate_overapproximation_1d_list(spline_eval_fn, eval_interval, result) 111 | 112 | 113 | @pytest.mark.skipif( 114 | not (TEST_JIT_COMPILATION and TEST_INCLUSION_FUNCTIONS), 115 | reason="JIT inclusion tests are disabled", 116 | ) 117 | def test_jit_inclusion(spline_coeffs, spline_eval_fn): 118 | """Tests the JIT-compiled inclusion function for the spline.""" 119 | x_knots, _ = spline_coeffs 120 | 121 | eval_interval = irx.interval(x_knots[:-1] + 0.1, x_knots[1:] - 0.1) 122 | 123 | spline_natif_jit = jax.jit(jax.vmap(irx.natif(spline_eval_fn))) 124 | result = spline_natif_jit(eval_interval) 125 | 126 | assert isinstance(result, irx.Interval) 127 | 128 | validate_overapproximation_1d_list(spline_eval_fn, eval_interval, result) 129 | 130 | 131 | @pytest.mark.skipif(not TEST_JACFWD, reason="JACFWD tests are disabled") 132 | def test_jacfwd(spline_coeffs, spline_eval_fn, eval_point): 133 | """Tests forward-mode AD for the spline evaluation function.""" 134 | x_knots, coeffs = spline_coeffs 135 | 136 | spline_jacfwd = jax.jacfwd(spline_eval_fn) 137 | der_ad = spline_jacfwd(eval_point) 138 | 139 | spline_der_fn = make_spline_derivative_eval_fn(x_knots, coeffs) 140 | der_sym = spline_der_fn(eval_point) 141 | 142 | if jnp.ndim(eval_point) == 0: 143 | expected_jacobian = der_sym 144 | else: 145 | expected_jacobian = jnp.diag(der_sym) # the spline acts elementwise, so its Jacobian is diagonal 146 | 147 | assert jnp.allclose(der_ad, expected_jacobian) 148 | 149 | 150 | @pytest.mark.skipif(not TEST_JACREV, reason="JACREV tests are disabled") 151 | def test_jacrev(spline_coeffs, spline_eval_fn, eval_point): 152 | """Tests reverse-mode AD for the spline evaluation function.""" 153 | x_knots, coeffs = spline_coeffs 154 | 155 | spline_jacrev = jax.jacrev(spline_eval_fn) 156 | der_ad = spline_jacrev(eval_point) 157 | 158 | spline_der_fn =
make_spline_derivative_eval_fn(x_knots, coeffs) 159 | der_sym = spline_der_fn(eval_point) 160 | 161 | if jnp.ndim(eval_point) == 0: 162 | expected_jacobian = der_sym 163 | else: 164 | expected_jacobian = jnp.diag(der_sym) 165 | 166 | assert jnp.allclose(der_ad, expected_jacobian) 167 | 168 | 169 | def plot_interval_bounds(ax, interval_bounds, output_bounds, color, label): 170 | """ 171 | Adds one shaded rectangle per input interval to ax, spanning the input interval horizontally and the matching output interval vertically; only the first rectangle carries the legend label. 172 | """ 173 | for i in range(len(interval_bounds.lower)): 174 | rect = plt.Rectangle( 175 | (interval_bounds.lower[i], output_bounds.lower[i]), 176 | interval_bounds.upper[i] - interval_bounds.lower[i], 177 | output_bounds.upper[i] - output_bounds.lower[i], 178 | facecolor=color, 179 | alpha=0.4, 180 | label=label if i == 0 else "", 181 | ) 182 | ax.add_patch(rect) 183 | 184 | 185 | if __name__ == "__main__": 186 | # Example of calling the function and comparing to input values. 187 | # This example uses JAX for computation and matplotlib for plotting. 188 | # You can install them with: pip install jax jaxlib numpy matplotlib 189 | 190 | # 1.
Define the input data points (x, y) 191 | # Procedurally generate points to test scaling 192 | num_points = 20 # Change this value to test with different numbers of points 193 | 194 | # Generate sorted x-coordinates over a fixed range 195 | x_range = (0.0, 50.0) 196 | x_coords = np.linspace(x_range[0], x_range[1], num_points) 197 | 198 | # Generate y-coordinates randomly from a uniform distribution 199 | y_range = (-10.0, 10.0) 200 | y_coords = np.random.uniform(y_range[0], y_range[1], size=num_points) 201 | 202 | # Combine into a single array of points 203 | input_points_np = np.vstack((x_coords, y_coords)).T 204 | 205 | # Convert to JAX array for spline computation 206 | input_points_jax = jnp.array(input_points_np) 207 | 208 | # Sort for comparison and plotting 209 | sorted_indices = np.argsort(input_points_np[:, 0]) 210 | input_points_np_sorted = input_points_np[sorted_indices] 211 | 212 | # 2. Create the cubic spline coefficients 213 | x_knots, coeffs = create_cubic_spline_coeffs(input_points_jax) 214 | 215 | # 3. Create the evaluation function 216 | spline_eval_fn = make_spline_eval_fn(x_knots, coeffs) 217 | 218 | # 4. Compare to input values 219 | original_x = input_points_np_sorted[:, 0] 220 | original_y = input_points_np_sorted[:, 1] 221 | 222 | # 5.
Demonstrate JAX traceability and interval bounds 223 | # Generate points for plotting 224 | x_smooth = np.linspace(x_knots[0], x_knots[-1], 200) 225 | y_smooth = spline_eval_fn(x_smooth) 226 | 227 | # Define interval bounds for testing 228 | # N evenly spaced intervals, each 1.2x the center spacing wide, so neighbouring intervals overlap 229 | N = 25 230 | w = 1.2 * (x_knots[-1] - x_knots[0]) / (N - 1) 231 | centers = np.linspace(x_knots[0], x_knots[-1], N) 232 | interval_bounds_overlapping = irx.interval( 233 | jnp.array(centers - w / 2), jnp.array(centers + w / 2) 234 | ) 235 | 236 | # Test interval bounds using immrax: natif 237 | spline_inclusion_fn_natif = jax.vmap(irx.natif(spline_eval_fn)) 238 | output_bounds_natif_overlapping = spline_inclusion_fn_natif( 239 | interval_bounds_overlapping 240 | ) 241 | 242 | # 6. Plot the results for visual comparison 243 | plt.figure(figsize=(12, 8)) 244 | ax = plt.gca() 245 | 246 | # Plot the main spline curve 247 | plt.plot( 248 | x_smooth, y_smooth, label="Cubic Spline", color="blue", linewidth=2.5, zorder=5 249 | ) 250 | 251 | # Plot the original data points 252 | plt.plot( 253 | original_x, 254 | original_y, 255 | "o", 256 | label="Original Data Points", 257 | color="red", 258 | markersize=8, 259 | zorder=10, 260 | ) 261 | 262 | plot_interval_bounds( 263 | ax, 264 | interval_bounds_overlapping, 265 | output_bounds_natif_overlapping, 266 | "purple", 267 | "Overlapping", 268 | ) 269 | 270 | plt.title("Cubic Spline Interpolation") 271 | plt.xlabel("x") 272 | plt.ylabel("y") 273 | plt.legend() 274 | plt.grid(True) 275 | plt.tight_layout() 276 | plt.show() 277 | -------------------------------------------------------------------------------- /immrax/utils.py: -------------------------------------------------------------------------------- 1 | from math import exp, floor, log 2 | import time 3 | from typing import Callable, Tuple 4 | 5 | import jax 6 | from jax._src.traceback_util import api_boundary 7 | from jax._src.util import wraps 8 | import jax.numpy as jnp 9 | import
matplotlib.pyplot as plt 10 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 11 | import numpy as onp 12 | from pypoman import plot_polygon 13 | from scipy.spatial import HalfspaceIntersection 14 | import shapely.geometry as sg 15 | import shapely.ops as so 16 | 17 | import immrax as irx 18 | from immrax.inclusion import Corner, Interval, all_corners, i2lu, i2ut, ut2i 19 | from immrax.system import Trajectory 20 | 21 | # ================================================================================ 22 | # Function wrappers 23 | # ================================================================================ 24 | 25 | # Wraps f so each call returns (result, elapsed wall-clock seconds), blocking until f's JAX outputs are ready before stopping the clock. 26 | def timed(f: Callable): 27 | @wraps(f) 28 | @api_boundary 29 | def f_timed(*args, **kwargs): 30 | t0 = time.time() 31 | ret = jax.block_until_ready(f(*args, **kwargs)) 32 | tf = time.time() 33 | return ret, (tf - t0) 34 | 35 | return f_timed 36 | 37 | # Calls f N times with the same arguments; returns the last result and an array of the N durations. NOTE(review): N must be >= 1, otherwise ret is unbound at the return. 38 | def run_times(N: int, f: Callable, *args, **kwargs): 39 | f_timed = timed(f) 40 | times = [] 41 | for i in range(N): 42 | ret, dt = f_timed(*args, **kwargs) 43 | times.append(dt) 44 | return ret, jnp.array(times) 45 | 46 | 47 | # ================================================================================ 48 | # Plotting 49 | # ================================================================================ 50 | 51 | sg_box = lambda x, xi=0, yi=1: sg.box( 52 | x[xi].lower, x[yi].lower, x[xi].upper, x[yi].upper 53 | ) 54 | sg_boxes = lambda xx, xi=0, yi=1: [sg_box(x, xi, yi) for x in xx] 55 | 56 | 57 | def draw_sg_union(ax, boxes, **kwargs): 58 | shape = so.unary_union(boxes) 59 | xs, ys = shape.exterior.xy 60 | kwargs.setdefault("ec", "tab:blue") 61 | kwargs.setdefault("fc", "none") 62 | kwargs.setdefault("lw", 2) 63 | kwargs.setdefault("alpha", 1) 64 | ax.fill(xs, ys, **kwargs) 65 | 66 | 67 | def draw_iarray(ax, x, xi=0, yi=1, **kwargs): 68 | return draw_sg_union(ax, [sg_box(x, xi, yi)], **kwargs) 69 | 70 | 71 | def draw_iarrays(ax, xx, xi=0, yi=1, **kwargs): 72 | return
draw_sg_union(ax, sg_boxes(xx, xi, yi), **kwargs) 73 | 74 | 75 | def draw_iarray_3d(ax, x, xi=0, yi=1, zi=2, **kwargs): 76 | Xl, Yl, Zl = x.lower[(xi, yi, zi),] 77 | Xu, Yu, Zu = x.upper[(xi, yi, zi),] 78 | poly_alpha = kwargs.pop("poly_alpha", 0.0) 79 | kwargs.setdefault("color", "tab:blue") 80 | kwargs.setdefault("lw", 0.75) 81 | faces = [ 82 | onp.array( 83 | [[Xl, Yl, Zl], [Xu, Yl, Zl], [Xu, Yu, Zl], [Xl, Yu, Zl], [Xl, Yl, Zl]] 84 | ), 85 | onp.array( 86 | [[Xl, Yl, Zu], [Xu, Yl, Zu], [Xu, Yu, Zu], [Xl, Yu, Zu], [Xl, Yl, Zu]] 87 | ), 88 | onp.array( 89 | [[Xl, Yl, Zl], [Xu, Yl, Zl], [Xu, Yl, Zu], [Xl, Yl, Zu], [Xl, Yl, Zl]] 90 | ), 91 | onp.array( 92 | [[Xl, Yu, Zl], [Xu, Yu, Zl], [Xu, Yu, Zu], [Xl, Yu, Zu], [Xl, Yu, Zl]] 93 | ), 94 | onp.array( 95 | [[Xl, Yl, Zl], [Xl, Yu, Zl], [Xl, Yu, Zu], [Xl, Yl, Zu], [Xl, Yl, Zl]] 96 | ), 97 | onp.array( 98 | [[Xu, Yl, Zl], [Xu, Yu, Zl], [Xu, Yu, Zu], [Xu, Yl, Zu], [Xu, Yl, Zl]] 99 | ), 100 | ] 101 | for face in faces: 102 | ax.plot3D(face[:, 0], face[:, 1], face[:, 2], **kwargs) 103 | kwargs["alpha"] = poly_alpha 104 | ax.add_collection3d(Poly3DCollection([face], **kwargs)) 105 | 106 | 107 | def draw_iarrays_3d(ax, xx, xi=0, yi=1, zi=2, color="tab:blue"): 108 | for x in xx: 109 | draw_iarray_3d(ax, x, xi, yi, zi, color) 110 | 111 | 112 | def plot_interval_t(ax, tt, x, **kwargs): 113 | xl, xu = i2lu(x) 114 | alpha = kwargs.pop("alpha", 0.25) 115 | label = kwargs.pop("label", None) 116 | ax.fill_between(tt, xl, xu, alpha=alpha, label=label, **kwargs) 117 | ax.plot(tt, xl, **kwargs) 118 | ax.plot(tt, xu, **kwargs) 119 | 120 | 121 | def draw_trajectory_2d(traj: Trajectory, vars=(0, 1), **kwargs): 122 | n = traj.ys[0].shape[0] // 2 123 | y_int = [ 124 | irx.ut2i(jnp.array([y[vars[0]], y[vars[1]], y[vars[0] + n], y[vars[1] + n]])) 125 | for y in traj.ys 126 | ] # TODO: fix indexing 127 | alpha = kwargs.pop("alpha", 0.4) 128 | label = kwargs.pop("label", None) 129 | for bound in y_int: 130 | draw_iarray(plt.gca(), bound, 
alpha=alpha, label=label, **kwargs) 131 | label = "_nolegend_" # Only label the first plot 132 | 133 | 134 | def draw_refined_trajectory_2d(traj: Trajectory, H: jnp.ndarray, vars=(0, 1), **kwargs): 135 | ys_int = [irx.ut2i(y) for y in traj.ys] 136 | color = kwargs.pop("color", "tab:blue") 137 | for bound in ys_int: 138 | dx = 1e-3 * jnp.ones_like(bound.lower) 139 | cons = onp.hstack( 140 | ( 141 | onp.vstack((-H, H)), 142 | onp.concatenate((bound.lower - dx, -bound.upper - dx)).reshape(-1, 1), 143 | ) 144 | ) 145 | hs = HalfspaceIntersection(cons, bound.center[0 : H.shape[1]]) 146 | # try: 147 | # hs = HalfspaceIntersection(cons, bound.center[0 : H.shape[1]]) 148 | # except Exception: 149 | # x = bound.center[0 : H.shape[1]] 150 | # print(bound.lower[0 : H.shape[1]], H @ x, bound.upper[0 : H.shape[1]]) 151 | 152 | vertices = hs.intersections[:, 0:2] 153 | vertices = onp.vstack( 154 | (hs.intersections[:, vars[0]], hs.intersections[:, vars[1]]) 155 | ).T 156 | 157 | plot_polygon(vertices, fill=False, resize=True, color=color, **kwargs) 158 | 159 | 160 | def get_half_intervals(x: Interval, N=1, ut=False): 161 | _xx_0 = i2ut(x) if ut is False else x 162 | n = len(_xx_0) // 2 163 | ret = [_xx_0] 164 | for i in range(N): 165 | newret = [] 166 | for _xx_ in ret: 167 | cent = (_xx_[:n] + _xx_[n:]) / 2 168 | for part_i in range(2**n): 169 | part = jnp.copy(_xx_) 170 | for ind in range(n): 171 | part = part.at[ind + n * ((part_i >> ind) % 2)].set(cent[ind]) 172 | newret.append(part) 173 | ret = newret 174 | if ut: 175 | return ret 176 | else: 177 | return [ut2i(part) for part in ret] 178 | 179 | 180 | # ================================================================================ 181 | # Math 182 | # ================================================================================ 183 | 184 | 185 | # @partial(jax.jit,static_argnums=(1,)) 186 | def get_partitions_ut(x: jax.Array, N: int) -> jax.Array: 187 | n = len(x) // 2 188 | # c^n = N 189 | c = floor(exp(log(N) / n) + 
1e-10) 190 | _x = x[:n] 191 | x_ = x[n:] 192 | xc = [] 193 | for i in range(c + 1): 194 | xc.append(_x + i * (x_ - _x) / c) 195 | l = onp.arange(c) 196 | A = onp.array(onp.meshgrid(*[l for i in range(n)])).reshape((n, -1)).T 197 | ret = [] 198 | for i in range(len(A)): 199 | _part = jnp.array([xc[A[i, j]][j] for j in range(n)]) 200 | part_ = jnp.array([xc[A[i, j] + 1][j] for j in range(n)]) 201 | ret.append(jnp.concatenate((_part, part_))) 202 | return jnp.array(ret) 203 | 204 | 205 | def gen_ics(x0, N, key=jax.random.key(0)): 206 | # X = np.empty((N, len(x0))) 207 | X = [] 208 | keys = jax.random.split(key, len(x0)) 209 | for i in range(len(x0)): 210 | # X[:,i] = uniform_disjoint(range, N) 211 | X.append( 212 | jax.random.uniform( 213 | key=keys[i], shape=(N,), minval=x0.lower[i], maxval=x0.upper[i] 214 | ) 215 | ) 216 | return jnp.array(X).T 217 | 218 | 219 | def set_columns_from_corner(corner: Corner, A: Interval): 220 | _Jx = jnp.where(jnp.asarray(corner) == 0, A.lower, A.upper) 221 | J_x = jnp.where(jnp.asarray(corner) == 0, A.upper, A.lower) 222 | return _Jx, J_x 223 | 224 | 225 | def get_corners(x: Interval, corners: Tuple[Corner] | None = None): 226 | corners = all_corners(len(x)) if corners is None else corners 227 | xut = i2ut(x) 228 | return jnp.array( 229 | [ 230 | jnp.array([x.lower[i] if c[i] == 0 else x.upper[i] for i in range(len(x))]) 231 | for c in corners 232 | ] 233 | ) 234 | 235 | 236 | def null_space(A, rcond=None, dim_null: int | None = None): 237 | """Taken from scipy, with some modifications to use jax.numpy""" 238 | u, s, vh = jnp.linalg.svd(A, full_matrices=True) 239 | M, N = u.shape[0], vh.shape[1] 240 | if rcond is None: 241 | rcond = jnp.finfo(s.dtype).eps * max(M, N) 242 | tol = jnp.amax(s) * rcond 243 | num = jnp.sum(s > tol, dtype=int) if dim_null is None else len(s) - dim_null + 1 244 | # num = jnp.sum(s > tol, dtype=int) 245 | # print(num) 246 | Q = vh[num:, :].T.conj() 247 | 248 | return Q 249 | 250 | 251 | def angular_sweep(N: 
int): 252 | """ 253 | Returns an array of points on the unit circle, evenly spaced in angle, which is in [0, pi] 254 | Both 0 and pi are excluded. 255 | 256 | Args: 257 | N: The number of points to generate 258 | 259 | Returns: 260 | jnp.array of points 261 | """ 262 | return jnp.array( 263 | [ 264 | [jnp.cos(n * jnp.pi / (N + 1)), jnp.sin(n * jnp.pi / (N + 1))] 265 | for n in range(1, N + 1) 266 | ] 267 | ) 268 | 269 | 270 | def check_containment(x, y): 271 | """Checks if the interval x is contained in the interval y. 272 | 273 | Returns 274 | ------- 275 | int 276 | 1 if x is fully contained in y 277 | -1 if x is fully outside of y 278 | 0 if x is partially contained in y 279 | """ 280 | fully_contained = jnp.logical_and( 281 | jnp.all(x.lower >= y.lower), jnp.all(x.upper <= y.upper) 282 | ).astype(int) 283 | fully_outside = jnp.logical_or( 284 | jnp.any(x.lower > y.upper), jnp.any(x.upper < y.lower) 285 | ).astype(int) 286 | return fully_contained - fully_outside 287 | 288 | 289 | def d_metzler(A): 290 | diag = jnp.diag_indices_from(A) 291 | Am = jnp.clip(A, 0, jnp.inf).at[diag].set(A[diag]) 292 | return Am, A - Am 293 | 294 | 295 | def d_positive(B): 296 | return jnp.clip(B, 0, jnp.inf), jnp.clip(B, -jnp.inf, 0) 297 | -------------------------------------------------------------------------------- /immrax/inclusion/interval.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import jax 3 | from jax.tree_util import register_pytree_node_class 4 | import jax.numpy as jnp 5 | from typing import Tuple, Iterable 6 | from jaxtyping import ArrayLike 7 | import numpy as onp 8 | 9 | 10 | @register_pytree_node_class 11 | class Interval: 12 | """Interval: A class to represent an interval in :math:`\\mathbb{R}^n`. 13 | 14 | Use the helper functions :func:`interval`, :func:`icentpert`, :func:`i2centpert`, :func:`i2lu`, :func:`i2ut`, and :func:`ut2i` to create and manipulate intervals. 
15 | 16 | Use the transforms :func:`natif`, :func:`jacif`, :func:`mjacif`, :func:`mjacM`, to create inclusion functions. 17 | 18 | Composable with typical jax transforms, such as :func:`jax.jit`, :func:`jax.grad`, and :func:`jax.vmap`. 19 | """ 20 | 21 | lower: jax.Array 22 | upper: jax.Array 23 | 24 | def __init__(self, lower: jax.Array, upper: jax.Array) -> None: 25 | self.lower = lower 26 | self.upper = upper 27 | 28 | def tree_flatten(self): 29 | return ((self.lower, self.upper), "Interval") 30 | 31 | @classmethod 32 | def tree_unflatten(cls, _, children): 33 | return cls(*children) 34 | 35 | @property 36 | def dtype(self) -> jnp.dtype: 37 | return self.lower.dtype 38 | 39 | @property 40 | def shape(self) -> Tuple[int, ...]: 41 | return self.lower.shape 42 | 43 | @property 44 | def size(self) -> int: 45 | return self.lower.size 46 | 47 | @property 48 | def width(self) -> jax.Array: 49 | return self.upper - self.lower 50 | 51 | @property 52 | def center(self) -> jax.Array: 53 | return (self.lower + self.upper) / 2 54 | 55 | @property 56 | def pert(self) -> jax.Array: 57 | return (self.upper - self.lower) / 2 58 | 59 | def __matmul__(self, _: "Interval") -> "Interval": ... 60 | def __truediv__(self, _: "Interval") -> "Interval": ... 61 | def __neg__(self) -> "Interval": ... 
62 | 63 | def __len__(self) -> int: 64 | return len(self.lower) 65 | 66 | def reshape(self, *args, **kwargs): 67 | return interval( 68 | self.lower.reshape(*args, **kwargs), self.upper.reshape(*args, **kwargs) 69 | ) 70 | 71 | def ravel(self) -> List["Interval"]: 72 | return [interval(l, u) for l, u in zip(self.lower.ravel(), self.upper.ravel())] 73 | 74 | def atleast_1d(self) -> "Interval": 75 | return interval(jnp.atleast_1d(self.lower), jnp.atleast_1d(self.upper)) 76 | 77 | def atleast_2d(self) -> "Interval": 78 | return interval(jnp.atleast_2d(self.lower), jnp.atleast_2d(self.upper)) 79 | 80 | def atleast_3d(self) -> "Interval": 81 | return interval(jnp.atleast_3d(self.lower), jnp.atleast_3d(self.upper)) 82 | 83 | @property 84 | def ndim(self) -> int: 85 | return self.lower.ndim 86 | 87 | def transpose(self, *args) -> "Interval": 88 | return Interval(self.lower.transpose(*args), self.upper.transpose(*args)) 89 | 90 | @property 91 | def T(self) -> "Interval": 92 | return self.transpose() 93 | 94 | def __and__(self, other: "Interval") -> "Interval": 95 | return interval( 96 | jnp.maximum(self.lower, other.lower), jnp.minimum(self.upper, other.upper) 97 | ) 98 | 99 | def __or__(self, other: "Interval") -> "Interval": 100 | return interval( 101 | jnp.minimum(self.lower, other.lower), jnp.maximum(self.upper, other.upper) 102 | ) 103 | 104 | def __str__(self) -> str: 105 | # return ( 106 | # onp.array( 107 | # [ 108 | # [(l, u)] 109 | # for (l, u) in zip(self.lower.reshape(-1), self.upper.reshape(-1)) 110 | # ], 111 | # dtype=onp.dtype([("f1", float), ("f2", float)]), 112 | # ) 113 | # .reshape(self.shape + (1,)) 114 | # .__str__() 115 | # ) 116 | return self.lower.__str__() + " <= x <= " + self.upper.__str__() 117 | 118 | def __repr__(self) -> str: 119 | # return onp.array([[(l,u)] for (l,u) in 120 | # zip(self.lower.reshape(-1),self.upper.reshape(-1))], 121 | # dtype=onp.dtype([('f1',float), ('f2', float)])).reshape(self.shape + (1,)).__str__() 122 | # 
dtype=np.dtype([('f1',float), ('f2', float)])).reshape(self.shape + (1,)).__repr__() 123 | return self.lower.__str__() + " <= x <= " + self.upper.__str__() 124 | 125 | def __getitem__(self, i: slice | ArrayLike) -> "Interval": 126 | return Interval(self.lower[i], self.upper[i]) 127 | 128 | def __iter__(self): 129 | """Return an iterator over the interval elements.""" 130 | # Use the actual length to create a proper iterator 131 | # This avoids the infinite loop issue by using explicit indexing 132 | length = int(len(self)) 133 | return (self[i] for i in range(length)) 134 | 135 | 136 | # HELPER FUNCTIONS 137 | 138 | 139 | def interval(lower: ArrayLike, upper: ArrayLike | None = None) -> Interval: 140 | """interval: Helper to create a Interval from a lower and upper bound. 141 | 142 | Parameters 143 | ---------- 144 | lower : ArrayLike 145 | Lower bound of the interval. 146 | upper : ArrayLike 147 | Upper bound of the interval. Set to lower bound if None. Defaults to None. 148 | lower:ArrayLike : 149 | 150 | Returns 151 | ------- 152 | Interval 153 | [lower, upper], or [lower, lower] if upper is None. 154 | 155 | """ 156 | if isinstance(lower, Interval) and upper is None: 157 | return lower 158 | if upper is None: 159 | return Interval(jnp.asarray(lower), jnp.asarray(lower)) 160 | lower = jnp.asarray(lower) 161 | upper = jnp.asarray(upper) 162 | if lower.dtype != upper.dtype: 163 | raise Exception( 164 | f"lower and upper dtype should match, {lower.dtype} != {upper.dtype}" 165 | ) 166 | if lower.shape != upper.shape: 167 | raise Exception( 168 | f"lower and upper shape should match, {lower.shape} != {upper.shape}" 169 | ) 170 | return Interval(jnp.asarray(lower), jnp.asarray(upper)) 171 | 172 | 173 | def icopy(i: Interval) -> Interval: 174 | """icopy: Helper to copy an interval. 
175 | 176 | Parameters 177 | ---------- 178 | i : Interval 179 | interval to copy 180 | 181 | Returns 182 | ------- 183 | Interval 184 | copy of the interval 185 | 186 | """ 187 | return Interval(jnp.copy(i.lower), jnp.copy(i.upper)) 188 | 189 | 190 | def icentpert(cent: ArrayLike, pert: ArrayLike) -> Interval: 191 | """icentpert: Helper to create a Interval from a center of an interval and a perturbation. 192 | 193 | Parameters 194 | ---------- 195 | cent : ArrayLike 196 | Center of the interval, i.e., (l + u)/2 197 | pert : ArrayLike 198 | l-inf perturbation from the center, i.e., (u - l)/2 199 | 200 | Returns 201 | ------- 202 | Interval 203 | Interval [cent - pert, cent + pert] 204 | 205 | """ 206 | cent = jnp.asarray(cent) 207 | pert = jnp.asarray(pert) 208 | return interval(cent - pert, cent + pert) 209 | 210 | 211 | centpert2i = icentpert 212 | 213 | 214 | def i2centpert(i: Interval) -> Tuple[jax.Array, jax.Array]: 215 | """i2centpert: Helper to get the center and perturbation from the center of a Interval. 216 | 217 | Parameters 218 | ---------- 219 | i : Interval 220 | _description_ 221 | 222 | Returns 223 | ------- 224 | Tuple[jax.Array, jax.Array] 225 | ((l + u)/2, (u - l)/2) 226 | 227 | """ 228 | return (i.lower + i.upper) / 2, (i.upper - i.lower) / 2 229 | 230 | 231 | def interval_intersect(Is: Iterable[Interval]) -> Interval: 232 | """interval_intersect: Helper to get the intersection of a list of intervals. 233 | 234 | Parameters 235 | ---------- 236 | Is : Iterable[Interval] 237 | list of intervals 238 | 239 | Returns 240 | ------- 241 | Interval 242 | intersection of the intervals 243 | 244 | """ 245 | l = jnp.max(jnp.array([i.lower for i in Is]), axis=0) 246 | u = jnp.min(jnp.array([i.upper for i in Is]), axis=0) 247 | return interval(l, u) 248 | 249 | 250 | def interval_union(Is: Iterable[Interval]) -> Interval: 251 | """interval_union: Helper to get the union of a list of intervals. 
252 | 253 | Parameters 254 | ---------- 255 | Is : Iterable[Interval] 256 | list of intervals 257 | 258 | Returns 259 | ------- 260 | Interval 261 | union of the intervals 262 | 263 | """ 264 | l = jnp.min(jnp.array([i.lower for i in Is]), axis=0) 265 | u = jnp.max(jnp.array([i.upper for i in Is]), axis=0) 266 | return interval(l, u) 267 | 268 | 269 | def i2lu(i: Interval) -> Tuple[jax.Array, jax.Array]: 270 | """i2lu: Helper to get the lower and upper bound of a Interval. 271 | 272 | Parameters 273 | ---------- 274 | interval : Interval 275 | _description_ 276 | 277 | Returns 278 | ------- 279 | Tuple[jax.Array, jax.Array] 280 | (l, u) 281 | 282 | """ 283 | return (i.lower, i.upper) 284 | 285 | 286 | def lu2i(l: jax.Array, u: jax.Array) -> Interval: 287 | """lu2i: Helper to create a Interval from a lower and upper bound. 288 | 289 | Parameters 290 | ---------- 291 | l : jax.Array 292 | Lower bound of the interval. 293 | u : jax.Array 294 | Upper bound of the interval. 295 | 296 | Returns 297 | ------- 298 | Interval 299 | [l, u] 300 | 301 | """ 302 | return interval(l, u) 303 | 304 | 305 | def i2ut(i: Interval) -> jax.Array: 306 | """i2ut: Helper to convert an interval to an upper triangular coordinate in :math:`\\mathbb{R}\\times\\mathbb{R}`. 307 | 308 | Parameters 309 | ---------- 310 | interval : Interval 311 | interval to convert 312 | 313 | Returns 314 | ------- 315 | jax.Array 316 | upper triangular coordinate in :math:`\\mathbb{R}\\times\\mathbb{R}` 317 | 318 | """ 319 | return jnp.concatenate((i.lower, i.upper)) 320 | 321 | 322 | def ut2i(coordinate: jax.Array, n: int | None = None) -> Interval: 323 | """ut2i: Helper to convert an upper triangular coordinate in :math:`\\mathbb{R}\\times\\mathbb{R}` to an interval. 324 | 325 | Parameters 326 | ---------- 327 | coordinate : jax.Array 328 | upper triangular coordinate to convert 329 | n : int 330 | length of interval, automatically determined if None. Defaults to None. 
331 | 332 | Returns 333 | ------- 334 | Interval 335 | interval representation of the coordinate 336 | 337 | """ 338 | if n is None: 339 | n = len(coordinate) // 2 340 | return interval(coordinate[:n], coordinate[n:]) 341 | 342 | 343 | def izeros(shape: Tuple[int], dtype: onp.dtype = jnp.float32) -> Interval: 344 | """izeros: Helper to create a Interval of zeros. 345 | 346 | Parameters 347 | ---------- 348 | shape : Tuple[int] 349 | shape of the interval 350 | dtype : np.dtype 351 | dtype of the interval. Defaults to jnp.float32. 352 | 353 | Returns 354 | ------- 355 | Interval 356 | interval of zeros 357 | 358 | """ 359 | return interval(jnp.zeros(shape, dtype), jnp.zeros(shape, dtype)) 360 | 361 | 362 | def iconcatenate(intervals: Iterable[Interval], axis: int = 0) -> Interval: 363 | """iconcatenate: Helper to concatenate intervals (cartesian product). 364 | 365 | Parameters 366 | ---------- 367 | intervals : Iterable[Interval] 368 | intervals to concatenate 369 | axis : int 370 | axis to concatenate on. Defaults to 0. 
371 | 372 | Returns 373 | ------- 374 | Interval 375 | concatenated interval 376 | 377 | """ 378 | return interval( 379 | jnp.concatenate([i.lower for i in intervals], axis=axis), 380 | jnp.concatenate([i.upper for i in intervals], axis=axis), 381 | ) 382 | -------------------------------------------------------------------------------- /immrax/embedding.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from functools import partial 3 | from typing import Any, Callable, Literal, Union 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | from jaxtyping import Float, Integer, Bool, Array 8 | 9 | from .refinement import SampleRefinement, LinProgRefinement 10 | from .inclusion import Interval, i2ut, interval, jacif, mjacif, natif, ut2i 11 | from .system import LiftedSystem, System 12 | 13 | __all__ = [ 14 | "EmbeddingSystem", 15 | "InclusionEmbedding", 16 | "TransformEmbedding", 17 | "ifemb", 18 | "natemb", 19 | "jacemb", 20 | "mjacemb", 21 | "embed", 22 | "get_faces", 23 | ] 24 | 25 | 26 | class EmbeddingSystem(System, abc.ABC): 27 | """EmbeddingSystem 28 | 29 | Embeds a System 30 | 31 | ..math:: 32 | \\mathbb{R}^n \\times \\text{inputs} \\to\\mathbb{R}^n` 33 | 34 | into an Embedding System evolving on the upper triangle 35 | 36 | ..math:: 37 | \\mathcal{T}^{2n} \\times \\text{embedded inputs} \\to \\mathbb{T}^{2n}. 38 | """ 39 | 40 | sys: System 41 | 42 | @abc.abstractmethod 43 | def E(self, t: Union[Integer, Float], x: jax.Array, *args, **kwargs) -> jax.Array: 44 | """The right hand side of the embedding system. 45 | 46 | Parameters 47 | ---------- 48 | t : Union[Integer, Float] 49 | The time of the embedding system. 50 | x : jax.Array 51 | The state of the embedding system. 52 | *args : 53 | interval-valued control inputs, disturbance inputs, etc. Depends on parent class. 
54 | **kwargs : 55 | 56 | 57 | Returns 58 | ------- 59 | jax.Array 60 | The time evolution of the state on the upper triangle 61 | 62 | """ 63 | 64 | def f(self, t: Union[Integer, Float], x: jax.Array, *args, **kwargs) -> jax.Array: 65 | return self.E(t, x, *args, **kwargs) 66 | 67 | 68 | class InclusionEmbedding(EmbeddingSystem): 69 | """EmbeddingSystem 70 | 71 | Embeds a System 72 | 73 | ..math:: 74 | \\mathbb{R}^n \\times \\text{inputs} \\to\\mathbb{R}^n`, 75 | 76 | into an Embedding System evolving on the upper triangle 77 | 78 | ..math:: 79 | \\mathcal{T}^{2n} \\times \\text{embedded inputs} \\to \\mathbb{T}^{2n}, 80 | 81 | using an Inclusion Function for the dynamics f. 82 | """ 83 | 84 | sys: System 85 | F: Callable[..., Interval] 86 | Fi: Callable[..., Interval] 87 | 88 | def __init__( 89 | self, 90 | sys: System, 91 | F: Callable[..., Interval], 92 | Fi: Callable[..., Interval] | None = None, 93 | ) -> None: 94 | """Initialize an EmbeddingSystem using a System and an inclusion function for f. 95 | 96 | Args: 97 | sys (System): The system to be embedded 98 | if_transform (InclusionFunction): An inclusion function for f. 
99 | """ 100 | self.sys = sys 101 | self.F = F 102 | self.evolution = sys.evolution 103 | self.xlen = sys.xlen * 2 104 | 105 | def E( 106 | self, 107 | t: Any, 108 | x: jax.Array, 109 | *args, 110 | refine: Callable[[Interval], Interval] | None = None, 111 | **kwargs, 112 | ) -> jax.Array: 113 | t = interval(t) 114 | # jax.debug.print("isnan: {0}", jnp.isnan(x).any()) 115 | 116 | if refine is not None: 117 | convert = lambda x: refine(ut2i(x)) 118 | Fkwargs = lambda t, x, *args: self.F(t, refine(x), *args, **kwargs) 119 | else: 120 | convert = ut2i 121 | Fkwargs = partial(self.F, **kwargs) 122 | 123 | x_int = convert(x) 124 | # jax.debug.print( 125 | # "lower: {0}, upper: {1}", 126 | # jnp.isnan(x_int.lower).any(), 127 | # jnp.isnan(x_int.upper).any(), 128 | # ) 129 | 130 | if self.evolution == "continuous": 131 | n = self.sys.xlen 132 | _x = x_int.lower 133 | x_ = x_int.upper 134 | 135 | # Computing F on the faces of the hyperrectangle 136 | 137 | _X = interval( 138 | jnp.tile(_x, (n, 1)), jnp.where(jnp.eye(n), _x, jnp.tile(x_, (n, 1))) 139 | ) 140 | _E = jax.vmap(Fkwargs, (None, 0) + (None,) * len(args))(t, _X, *args) 141 | 142 | X_ = interval( 143 | jnp.where(jnp.eye(n), x_, jnp.tile(_x, (n, 1))), jnp.tile(x_, (n, 1)) 144 | ) 145 | E_ = jax.vmap(Fkwargs, (None, 0) + (None,) * len(args))(t, X_, *args) 146 | 147 | # return jnp.concatenate((_E, E_)) 148 | output = jnp.concatenate((jnp.diag(_E.lower), jnp.diag(E_.upper))) 149 | # jax.debug.print("output isnan: {0}", jnp.isnan(output).any()) 150 | return jnp.concatenate((jnp.diag(_E.lower), jnp.diag(E_.upper))) 151 | 152 | elif self.evolution == "discrete": 153 | # Convert x from ut to i, compute through F, convert back to ut. 
154 | return i2ut(self.F(interval(t), x_int, *args, **kwargs)) 155 | else: 156 | raise Exception("evolution needs to be 'continuous' or 'discrete'") 157 | 158 | 159 | def ifemb(sys: System, F: Callable[..., Interval]): 160 | """Creates an EmbeddingSystem using an inclusion function for the dynamics of a System. 161 | 162 | Parameters 163 | ---------- 164 | sys : System 165 | System to embed 166 | F : Callable[..., Interval] 167 | Inclusion function for the dynamics of sys. 168 | 169 | Returns 170 | ------- 171 | EmbeddingSystem 172 | Embedding system from the inclusion function transform. 173 | 174 | """ 175 | return InclusionEmbedding(sys, F) 176 | 177 | 178 | def embed(F: Callable[..., Interval]): 179 | def E( 180 | t: Any, 181 | x: jax.Array, 182 | *args, 183 | refine: Callable[[Interval], Interval] | None = None, 184 | **kwargs, 185 | ): 186 | n = len(x) // 2 187 | _x = x[:n] 188 | x_ = x[n:] 189 | 190 | if refine is not None: 191 | Fkwargs = lambda t, x, *args: F(t, refine(x), *args, **kwargs) 192 | else: 193 | Fkwargs = partial(F, **kwargs) 194 | 195 | # Computing F on the faces of the hyperrectangle 196 | 197 | if n > 1: 198 | _X = interval( 199 | jnp.tile(_x, (n, 1)), jnp.where(jnp.eye(n), _x, jnp.tile(x_, (n, 1))) 200 | ) 201 | _E = interval( 202 | jax.vmap(Fkwargs, (None, 0) + (None,) * len(args))(t, _X, *args) 203 | ) 204 | 205 | X_ = interval( 206 | jnp.where(jnp.eye(n), x_, jnp.tile(_x, (n, 1))), jnp.tile(x_, (n, 1)) 207 | ) 208 | E_ = interval( 209 | jax.vmap(Fkwargs, (None, 0) + (None,) * len(args))(t, X_, *args) 210 | ) 211 | return jnp.concatenate((jnp.diag(_E.lower), jnp.diag(E_.upper))) 212 | else: 213 | _E = Fkwargs(t, interval(_x)).lower 214 | E_ = Fkwargs(t, interval(x_)).upper 215 | return jnp.array([_E, E_]) 216 | 217 | return E 218 | 219 | 220 | def get_faces(ix: Interval) -> tuple[Interval, Interval]: 221 | n = len(ix) 222 | 223 | _x = ix.lower 224 | x_ = ix.upper 225 | 226 | # _X = interval( 227 | # , jnp.where(jnp.eye(n), _x, 
jnp.tile(x_, (n, 1))) 228 | # ) 229 | 230 | # X_ = interval( 231 | # , jnp.tile(x_, (n, 1)) 232 | # ) 233 | 234 | X = interval( 235 | jnp.vstack( 236 | (jnp.tile(_x, (n, 1)), jnp.where(jnp.eye(n), x_, jnp.tile(_x, (n, 1)))) 237 | ), 238 | jnp.vstack( 239 | (jnp.where(jnp.eye(n), _x, jnp.tile(x_, (n, 1))), jnp.tile(x_, (n, 1))) 240 | ), 241 | ) 242 | 243 | return X 244 | 245 | 246 | class TransformEmbedding(InclusionEmbedding): 247 | def __init__(self, sys: System, if_transform=natif) -> None: 248 | """Initialize an EmbeddingSystem using a System and an inclusion function transform. 249 | 250 | Parameters 251 | ---------- 252 | sys : System 253 | _description_ 254 | if_transform : IFTransform 255 | _description_. Defaults to natif. 256 | 257 | Returns 258 | ------- 259 | 260 | """ 261 | F = if_transform(sys.f) 262 | # Fi = [if_transform(sys.fi[i]) for i in range(sys.xlen)] 263 | super().__init__(sys, F) 264 | 265 | 266 | def natemb(sys: System): 267 | """Creates an EmbeddingSystem using the natural inclusion function of the dynamics of a System. 268 | 269 | Parameters 270 | ---------- 271 | sys : System 272 | System to embed 273 | 274 | Returns 275 | ------- 276 | EmbeddingSystem 277 | Embedding system from the natural inclusion function transform. 278 | 279 | """ 280 | return TransformEmbedding(sys, if_transform=natif) 281 | 282 | 283 | def jacemb(sys: System): 284 | """Creates an EmbeddingSystem using the Jacobian-based inclusion function of the dynamics of a System. 285 | 286 | Parameters 287 | ---------- 288 | sys : System 289 | System to embed 290 | 291 | Returns 292 | ------- 293 | EmbeddingSystem 294 | Embedding system from the Jacobian-based inclusion function transform. 295 | 296 | """ 297 | return TransformEmbedding(sys, if_transform=jacif) 298 | 299 | 300 | def mjacemb(sys: System): 301 | """Creates an EmbeddingSystem using the Mixed Jacobian-based inclusion function of the dynamics of a System. 
302 | 303 | Parameters 304 | ---------- 305 | sys : System 306 | System to embed 307 | 308 | Returns 309 | ------- 310 | EmbeddingSystem 311 | Embedding system from the Mixed Jacobian-based inclusion function transform. 312 | 313 | """ 314 | return TransformEmbedding(sys, if_transform=mjacif) 315 | 316 | 317 | class AuxVarEmbedding(InclusionEmbedding): 318 | """ 319 | Embedding system defined by auxiliary variables.n 320 | 321 | Attributes: 322 | H: Matrix of auxiliary variables to add 323 | Hp: psuedo-inverse of H 324 | """ 325 | 326 | def __init__( 327 | self, 328 | sys: System, 329 | H: Float[Array, "lstates bstates"], 330 | *, 331 | base_invariants: Bool[Array, "lstates"] | None = None, 332 | if_transform: Callable[[Callable[..., jnp.ndarray]], Callable[..., Interval]] 333 | | None = None, 334 | F: Callable[..., Interval] | None = None, 335 | mode: Literal["sample", "linprog"] = "sample", 336 | num_samples: int = 10, 337 | ) -> None: 338 | """ 339 | Embedding system defined by auxiliary variables. Given a base system with dimension n 340 | and matrix H m by n, the base system is first lifted to dimension m by adding m-n 341 | auxiliary variables. Each aux var is a linear combination of some of the real state 342 | variables, defined by the coefficients of the rows of H. Because of this, the subspace 343 | defined by y = Hx is invariant in the lifted state under the base system dynamics. 344 | 345 | The lifted system is then embedded onto the upper triangle with either the inclusion 346 | function F or if_transform given. 347 | 348 | The intervals of the embedded system can then be refined by the subspace invariance of 349 | the lifted system. There are two methods to do this, "sample" and "linprog", and the 350 | method is chosen by the mode argument. 351 | 352 | Args: 353 | sys: Base system to embed 354 | H: Matrix of auxiliary variables to add 355 | mode: Whether to refine by sampling or solving a LP. 
Defaults to sample 356 | if_transform: How to construct the inclusion function for the embedding system 357 | F: For greater control, allows you to pass an inclusion function directly. NOTE: 358 | is required to be an inclusion function for the *lifted* system, not the bases system 359 | num_samples (): How many samples to take for sampling refinement. Defaults to 10 360 | """ 361 | self.H = H 362 | # self.Hp = jnp.linalg.pinv(H) 363 | H_inv = jnp.linalg.inv(H[: H.shape[1]]) 364 | self.Hp = jnp.hstack([H_inv, jnp.zeros((H.shape[1], H.shape[0] - H.shape[1]))]) 365 | 366 | liftsys = LiftedSystem(sys, self.H, self.Hp) 367 | 368 | if mode == "sample": 369 | self.IH = SampleRefinement(H, num_samples).get_refine_func() 370 | elif mode == "linprog": 371 | self.IH = LinProgRefinement(H).get_refine_func() 372 | else: 373 | raise ValueError( 374 | "Invalid mode argument. Mode must be either 'sample' or 'linprog'." 375 | ) 376 | 377 | def liftf(t, x, *args, **kwargs): 378 | dx = liftsys.f(t, x, *args, **kwargs) 379 | if base_invariants is not None: 380 | dx = dx.at[base_invariants].set(0) 381 | return dx 382 | 383 | if F is None and if_transform is None: 384 | F = natif(liftf) # default to natif 385 | elif if_transform is not None and F is None: 386 | F = if_transform(liftf) 387 | elif F is not None and if_transform is None: 388 | pass # do nothing, take F as given 389 | else: 390 | raise ValueError( 391 | "Cannot specify both an inclusion function F and if_transform" 392 | ) 393 | 394 | super().__init__(liftsys, F) 395 | 396 | def E( 397 | self, 398 | t: Any, 399 | x: jax.Array, 400 | *args, 401 | refine: Callable[[Interval], Interval] | None = None, 402 | **kwargs, 403 | ) -> jax.Array: 404 | if refine is not None: 405 | raise ( 406 | Exception( 407 | "Class AuxVarEmbedding does not support passing refine as an argument, since the refinement is calculated from the auxillary variables." 
408 | ) 409 | ) 410 | 411 | return super().E(t, x, *args, refine=self.IH, **kwargs) 412 | -------------------------------------------------------------------------------- /immrax/parametric/sets/ellipsoid.py: -------------------------------------------------------------------------------- 1 | from ..parametope import hParametope 2 | import jax.numpy as jnp 3 | from jaxtyping import ArrayLike 4 | from ...utils import null_space 5 | from matplotlib.patches import Ellipse, Patch 6 | from matplotlib.axes import Axes 7 | import numpy as onp 8 | from matplotlib.path import Path 9 | 10 | # import transforms from matplotlib 11 | from matplotlib import transforms 12 | from jax.tree_util import register_pytree_node_class 13 | from ...inclusion import icentpert 14 | 15 | 16 | @register_pytree_node_class 17 | class Ellipsoid(hParametope): 18 | def __init__(self, ox, alpha, y): 19 | # ly = jnp.zeros_like(uy) if uy is not None else None 20 | super().__init__(ox, alpha, y) 21 | 22 | @classmethod 23 | def from_parametope(cls, pt: hParametope): 24 | return Ellipsoid(pt.ox, pt.alpha, pt.y) 25 | 26 | def h(self, a: ArrayLike): 27 | return jnp.array([-a.T @ a, a.T @ a]) 28 | 29 | def hinv(self, y): 30 | # Returns a box containing the preimage of the constraint over iy 31 | n = len(self.ox) 32 | yu = y[1] 33 | 34 | # |x|_inf \leq |x|_2 \leq \sqrt{n} |x|_inf 35 | return icentpert(jnp.zeros(n), jnp.sqrt(yu) * jnp.ones(n)) 36 | # return icentpert(jnp.zeros(n), yu*jnp.ones(n)) 37 | 38 | # def iover (self) : 39 | # return self.ginv(self.H@interval(self.ly, self.uy)) 40 | 41 | @property 42 | def P(self): 43 | return self.alpha.T @ self.alpha 44 | 45 | def V(self, x: ArrayLike): 46 | ax = self.alpha @ (x - self.ox) 47 | return ax.T @ ax 48 | 49 | def plot_projection(self, ax, xi=0, yi=1, rescale=False, **kwargs): 50 | P = self.P / self.y[1] 51 | n = P.shape[0] 52 | if n == 2: 53 | _plot_ellipse(P, self.ox, ax, rescale, **kwargs) 54 | return 55 | ind = [k for k in range(n) if k not in [xi, 
yi]] 56 | Phat = P[ind, :] 57 | N = null_space(Phat) 58 | M = N[(xi, yi), :] # Since M is guaranteed 2x2, 59 | Minv = (1 / (M[0, 0] * M[1, 1] - M[0, 1] * M[1, 0])) * jnp.array( 60 | [[M[1, 1], -M[0, 1]], [-M[1, 0], M[0, 0]]] 61 | ) 62 | Q = Minv.T @ N.T @ P @ N @ Minv 63 | _plot_ellipse(Q, self.ox[(xi, yi),], ax, rescale, **kwargs) 64 | 65 | # def __repr__(self) : 66 | # return f'Ellipsoid(ox={self.ox}, H={self.H}, uy={self.uy})' 67 | 68 | # def __str__(self) : 69 | # return f'Ellipsoid(ox={self.ox}, H={self.H}, uy={self.uy})' 70 | 71 | 72 | # def iover (e:Ellipsoid) -> irx.Interval : 73 | # """Interval over-approximation of an Ellipsoid""" 74 | # overpert = jnp.sqrt(jnp.diag(e.Pinv)) 75 | # return irx.icentpert(e.xc, overpert) 76 | 77 | # def eover (ix:irx.Interval, P:jax.Array) -> Ellipsoid : 78 | # """Ellipsoid over-approximation of an Interval""" 79 | # xc, xp = irx.i2centpert(ix) 80 | # corns = irx.get_corners(ix - xc) 81 | # m = jnp.max(jnp.array([norm_P(c, P) for c in corns])) 82 | # return Ellipsoid(P/m, xc) 83 | 84 | # @register_pytree_node_class 85 | # class EllipsoidAnnulus (Ellipsoid) : 86 | # def __init__ (self, ox, H, ly, uy) : 87 | # super().__init__(ox, H, ly, uy) 88 | 89 | # def plot_projection (self, ax, xi=0, yi=1, rescale=False, **kwargs) : 90 | # P = self.H[0].T @ self.H[0] / self.uy[0] 91 | # n = P.shape[0] 92 | # if n == 2 : 93 | # print(jnp.sqrt(self.ly[0])/jnp.sqrt(self.uy[0])) 94 | # _plot_annulus (P, self.ox, jnp.sqrt(self.ly[0])/jnp.sqrt(self.uy[0]), ax, rescale, **kwargs) 95 | # return 96 | # ind = [k for k in range(n) if k not in [xi,yi]] 97 | # Phat = P[ind,:] 98 | # N = null_space(Phat) 99 | # M = N[(xi,yi),:] # Since M is guaranteed 2x2, 100 | # Minv = (1/(M[0,0]*M[1,1] - M[0,1]*M[1,0]))*jnp.array([[M[1,1], -M[0,1]], [-M[1,0], M[0,0]]]) 101 | # Q = Minv.T@N.T@P@N@Minv 102 | # _plot_annulus(Q, self.ox[(xi,yi),], (self.ly[0]/self.uy[0]), ax, rescale, **kwargs) 103 | 104 | # @classmethod 105 | # def from_ds (cls, ds:DualStar) : 106 | # 
return EllipsoidAnnulus(ds.ox, ds.H[0], ds.ly[0], ds.uy[0]) 107 | 108 | # class EllipsoidAnnulus (DualStar) : 109 | # def __init__ (self, ox, H, ly, uy) : 110 | # super().__init__(ox, [H], [ly], [uy]) 111 | 112 | # def g (self, i:int, a:ArrayLike) : 113 | # if i > 0 : 114 | # raise Exception("Something has gone horribly wrong---Ellipsoid has only one constraint") 115 | # return a.T @ a 116 | 117 | # @classmethod 118 | # def from_ds (cls, ds:DualStar) : 119 | # return Ellipsoid(ds.ox, ds.H[0], ds.ly[0], ds.uy[0]) 120 | 121 | # def V (self, x:ArrayLike) : 122 | # P = self.H[0].T @ self.H[0] 123 | # return (x - self.ox).T @ P @ (x - self.ox) 124 | 125 | # def plot_projection (self, ax, xi=0, yi=1, rescale=False, **kwargs) : 126 | # P = self.H[0].T @ self.H[0] / self.uy[0] 127 | # n = P.shape[0] 128 | # if n == 2 : 129 | # _plot_annulus (P, self.ox, self.ly[0]/self.uy[0], ax, rescale, **kwargs) 130 | # return 131 | # ind = [k for k in range(n) if k not in [xi,yi]] 132 | # Phat = P[ind,:] 133 | # N = null_space(Phat) 134 | # M = N[(xi,yi),:] # Since M is guaranteed 2x2, 135 | # Minv = (1/(M[0,0]*M[1,1] - M[0,1]*M[1,0]))*jnp.array([[M[1,1], -M[0,1]], [-M[1,0], M[0,0]]]) 136 | # Q = Minv.T@N.T@P@N@Minv 137 | # _plot_annulus(Q, self.ox[(xi,yi),], self.ly[0]/self.uy[0], ax, rescale, **kwargs) 138 | 139 | # def __repr__(self) : 140 | # return f'EllipsoidAnnulus(ox={self.ox}, H={self.H}, ly={self.ly}, uy={self.uy})' 141 | 142 | # def __str__(self) : 143 | # return f'EllipsoidAnnulus(ox={self.ox}, H={self.H}, ly={self.ly}, uy={self.uy})' 144 | 145 | 146 | def _plot_ellipse( 147 | Q: ArrayLike, 148 | xc: ArrayLike = jnp.zeros(2), 149 | ax: Axes | None = None, 150 | rescale: bool = False, 151 | **kwargs, 152 | ): 153 | """ 154 | Parameters 155 | ---------- 156 | Q : ArrayLike 157 | PD matrix defining the ellipse 158 | xc : ArrayLike, optional 159 | Center of the ellipse, by default jnp.zeros(2) 160 | ax : Axes | None, optional 161 | Matplotlib Axes object to plot the ellipse on, 
plt.gca() if None, by default None 162 | rescale : bool, optional 163 | Rescales the axes to fit the ellipse, by default False 164 | 165 | Raises 166 | ------ 167 | ValueError 168 | Q must be a 2x2 matrix 169 | """ 170 | n = Q.shape[0] 171 | if n != 2: 172 | raise ValueError( 173 | "Use _plot_ellipse for 2D ellipses, see Ellipsoid.plot_projection" 174 | ) 175 | 176 | S, U = jnp.linalg.eigh(Q) 177 | Sinv = 1 / S 178 | 179 | kwargs.setdefault("color", "k") 180 | kwargs.setdefault("fill", False) 181 | width, height = 2 * jnp.sqrt(Sinv) 182 | angle = jnp.arctan2(U[1, 0], U[0, 0]) * 180 / jnp.pi 183 | ellipse = Ellipse(xy=xc, width=width, height=height, angle=angle, **kwargs) 184 | ax.add_patch(ellipse) 185 | 186 | if rescale: 187 | ax.set_xlim(xc[0] - 1.5 * width, xc[0] + 1.5 * width) 188 | ax.set_ylim(xc[1] - 1.5 * height, xc[1] + 1.5 * height) 189 | 190 | 191 | def _plot_annulus( 192 | Q: ArrayLike, 193 | xc: ArrayLike = jnp.zeros(2), 194 | inner: float = 0.5, 195 | ax: Axes | None = None, 196 | rescale: bool = False, 197 | **kwargs, 198 | ): 199 | """ 200 | Parameters 201 | ---------- 202 | Q : ArrayLike 203 | PD matrix defining the ellipse 204 | xc : ArrayLike, optional 205 | Center of the ellipse, by default jnp.zeros(2) 206 | inner: float, optional 207 | Inner radius of the annulus, by default 0.5 208 | ax : Axes | None, optional 209 | Matplotlib Axes object to plot the ellipse on, plt.gca() if None, by default None 210 | rescale : bool, optional 211 | Rescales the axes to fit the ellipse, by default False 212 | 213 | Raises 214 | ------ 215 | ValueError 216 | Q must be a 2x2 matrix 217 | """ 218 | n = Q.shape[0] 219 | if n != 2: 220 | raise ValueError( 221 | "Use _plot_ellipse for 2D ellipses, see Ellipsoid.plot_projection" 222 | ) 223 | 224 | S, U = jnp.linalg.eigh(Q) 225 | Sinv = 1 / S 226 | 227 | kwargs.setdefault("color", "k") 228 | kwargs.setdefault("fill", False) 229 | width, height = 2 * jnp.sqrt(Sinv) 230 | angle = jnp.arctan2(U[1, 0], U[0, 0]) * 180 / 
jnp.pi 231 | # outer_ellipse = Ellipse(xy=xc, width=width, height=height, angle=angle, **kwargs) 232 | # inner_ellipse = Ellipse(xy=xc, width=inner*width, height=inner*height, angle=angle, **kwargs) 233 | annulus = _AnnulusP(xy=xc, r=jnp.sqrt(Sinv), width=inner, angle=angle, **kwargs) 234 | ax.add_patch(annulus) 235 | 236 | if rescale: 237 | ax.set_xlim(xc[0] - 1.5 * width, xc[0] + 1.5 * width) 238 | ax.set_ylim(xc[1] - 1.5 * height, xc[1] + 1.5 * height) 239 | 240 | 241 | # def _plot_3d_ellipsoid () 242 | 243 | 244 | class _AnnulusP(Patch): 245 | """ 246 | An elliptical annulus. 247 | Most of the following code is from matplotlib.patches.Annulus. 248 | There are small modification to make the inner ellipse defined as 249 | a percentage of the outer ellipse---scaling major and minor axes 250 | as a multiplier of the outer ellipse's major and minor axes instead of additive. 251 | """ 252 | 253 | # @_docstring.interpd 254 | def __init__(self, xy, r, width, angle=0.0, **kwargs): 255 | """ 256 | Parameters 257 | ---------- 258 | xy : (float, float) 259 | xy coordinates of annulus centre. 260 | r : float or (float, float) 261 | The radius, or semi-axes: 262 | 263 | - If float: radius of the outer circle. 264 | - If two floats: semi-major and -minor axes of outer ellipse. 265 | width : float 266 | Width (thickness) of the annular ring. The width is measured inward 267 | from the outer ellipse so that for the inner ellipse the semi-axes 268 | are given by ``r - width``. *width* must be less than or equal to 269 | the semi-minor axis. 270 | angle : float, default: 0 271 | Rotation angle in degrees (anti-clockwise from the positive 272 | x-axis). Ignored for circular annuli (i.e., if *r* is a scalar). 
273 | **kwargs 274 | Keyword arguments control the `Patch` properties: 275 | 276 | %(Patch:kwdoc)s 277 | """ 278 | super().__init__(**kwargs) 279 | 280 | self.set_radii(r) 281 | self.center = xy 282 | self.width = width 283 | self.angle = angle 284 | self._path = None 285 | 286 | def __str__(self): 287 | if self.a == self.b: 288 | r = self.a 289 | else: 290 | r = (self.a, self.b) 291 | 292 | return "Annulus(xy=(%s, %s), r=%s, width=%s, angle=%s)" % ( 293 | *self.center, 294 | r, 295 | self.width, 296 | self.angle, 297 | ) 298 | 299 | def set_center(self, xy): 300 | """ 301 | Set the center of the annulus. 302 | 303 | Parameters 304 | ---------- 305 | xy : (float, float) 306 | """ 307 | self._center = xy 308 | self._path = None 309 | self.stale = True 310 | 311 | def get_center(self): 312 | """Return the center of the annulus.""" 313 | return self._center 314 | 315 | center = property(get_center, set_center) 316 | 317 | def set_width(self, width): 318 | """ 319 | Set the width (thickness) of the annulus ring. 320 | 321 | The width is measured as a percent of both minor and major axes. 322 | 323 | Parameters 324 | ---------- 325 | width : float 326 | """ 327 | if width > 1 or width < 0: 328 | raise ValueError("Width of annulus must be a float between 0 and 1.") 329 | 330 | self._width = width 331 | self._path = None 332 | self.stale = True 333 | 334 | def get_width(self): 335 | """Return the width (thickness) of the annulus ring.""" 336 | return self._width 337 | 338 | width = property(get_width, set_width) 339 | 340 | def set_angle(self, angle): 341 | """ 342 | Set the tilt angle of the annulus. 
343 | 344 | Parameters 345 | ---------- 346 | angle : float 347 | """ 348 | self._angle = angle 349 | self._path = None 350 | self.stale = True 351 | 352 | def get_angle(self): 353 | """Return the angle of the annulus.""" 354 | return self._angle 355 | 356 | angle = property(get_angle, set_angle) 357 | 358 | def set_semimajor(self, a): 359 | """ 360 | Set the semi-major axis *a* of the annulus. 361 | 362 | Parameters 363 | ---------- 364 | a : float 365 | """ 366 | self.a = float(a) 367 | self._path = None 368 | self.stale = True 369 | 370 | def set_semiminor(self, b): 371 | """ 372 | Set the semi-minor axis *b* of the annulus. 373 | 374 | Parameters 375 | ---------- 376 | b : float 377 | """ 378 | self.b = float(b) 379 | self._path = None 380 | self.stale = True 381 | 382 | def set_radii(self, r): 383 | """ 384 | Set the semi-major (*a*) and semi-minor radii (*b*) of the annulus. 385 | 386 | Parameters 387 | ---------- 388 | r : float or (float, float) 389 | The radius, or semi-axes: 390 | 391 | - If float: radius of the outer circle. 392 | - If two floats: semi-major and -minor axes of outer ellipse. 
393 | """ 394 | if onp.shape(r) == (2,): 395 | self.a, self.b = r 396 | elif onp.shape(r) == (): 397 | self.a = self.b = float(r) 398 | else: 399 | raise ValueError("Parameter 'r' must be one or two floats.") 400 | 401 | self._path = None 402 | self.stale = True 403 | 404 | def get_radii(self): 405 | """Return the semi-major and semi-minor radii of the annulus.""" 406 | return self.a, self.b 407 | 408 | radii = property(get_radii, set_radii) 409 | 410 | def _transform_verts(self, verts, a, b): 411 | return ( 412 | transforms.Affine2D() 413 | .scale(*self._convert_xy_units((a, b))) 414 | .rotate_deg(self.angle) 415 | .translate(*self._convert_xy_units(self.center)) 416 | .transform(verts) 417 | ) 418 | 419 | def _recompute_path(self): 420 | # circular arc 421 | arc = Path.arc(0, 360) 422 | 423 | # annulus needs to draw an outer ring 424 | # followed by a reversed and scaled inner ring 425 | a, b, w = self.a, self.b, self.width 426 | v1 = self._transform_verts(arc.vertices, a, b) 427 | v2 = self._transform_verts(arc.vertices[::-1], a * w, b * w) 428 | v = onp.vstack([v1, v2, v1[0, :], (0, 0)]) 429 | c = onp.hstack( 430 | [arc.codes, Path.MOVETO, arc.codes[1:], Path.MOVETO, Path.CLOSEPOLY] 431 | ) 432 | self._path = Path(v, c) 433 | 434 | def get_path(self): 435 | if self._path is None: 436 | self._recompute_path() 437 | return self._path 438 | -------------------------------------------------------------------------------- /immrax/parametric/param_reach.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from ..system import System 4 | from .parametope import Parametope, hParametope 5 | from abc import ABC, abstractmethod 6 | from immutabledict import immutabledict 7 | from jaxtyping import Integer, Float, ArrayLike 8 | from typing import Tuple, Union, List, Callable, Literal 9 | from diffrax import AbstractSolver, ODETerm, Euler, Dopri5, Tsit5, SaveAt, diffeqsolve 10 | from ..inclusion 
import ( 11 | Interval, 12 | standard_permutation, 13 | mjacM, 14 | interval, 15 | i2ut, 16 | natif, 17 | jacM, 18 | icopy, 19 | ) 20 | from ..embedding import embed 21 | from ..neural import fastlin 22 | from jax.experimental.jet import jet 23 | 24 | 25 | class ParametopeEmbedding(ABC): 26 | sys: System 27 | 28 | def __init__(self, sys: System): 29 | self.sys = sys 30 | 31 | @abstractmethod 32 | def _initialize(self, pt0: Parametope) -> ArrayLike: 33 | """Initialize the Embedding System for a particular initial set pt0 34 | 35 | Parameters 36 | ---------- 37 | pt0 : hParametope 38 | _description_ 39 | 40 | Returns 41 | ------- 42 | ArrayLike 43 | aux0: Auxilliary states to evolve with the embedding system 44 | """ 45 | 46 | @abstractmethod 47 | def _dynamics(self, t, state, *args): 48 | """Embedding dynamics 49 | 50 | Parameters 51 | ---------- 52 | t : _type_ 53 | _description_ 54 | state : _type_ 55 | _description_ 56 | """ 57 | 58 | # @partial(jax.jit, static_argnums=(0, 4), static_argnames=("solver", "f_kwargs")) 59 | def compute_reachset( 60 | self, 61 | t0: Union[Integer, Float], 62 | tf: Union[Integer, Float], 63 | pt0: Parametope, 64 | inputs: List[Callable[[int, jax.Array], jax.Array]] = [], 65 | dt: float = 0.01, 66 | *, 67 | solver: Union[Literal["euler", "rk45", "tsit5"], AbstractSolver] = "tsit5", 68 | f_kwargs: immutabledict = immutabledict({}), 69 | **kwargs, 70 | ): 71 | def func(t, x, args): 72 | # Unpack the inputs 73 | return self._dynamics(t, x, *[u(t, x) for u in inputs], **f_kwargs) 74 | 75 | term = ODETerm(func) 76 | if solver == "euler": 77 | solver = Euler() 78 | 79 | elif solver == "rk45": 80 | solver = Dopri5() 81 | elif solver == "tsit5": 82 | solver = Tsit5() 83 | elif isinstance(solver, AbstractSolver): 84 | pass 85 | else: 86 | raise Exception(f"{solver=} is not a valid solver") 87 | 88 | aux0 = self._initialize(pt0) 89 | 90 | saveat = SaveAt(t0=True, t1=True, steps=True) 91 | return diffeqsolve( 92 | term, solver, t0, tf, dt, (pt0, 
aux0), saveat=saveat, **kwargs 93 | ) 94 | 95 | 96 | class AdjointEmbedding(ParametopeEmbedding): 97 | def __init__(self, sys, alpha_p0, N0, kap: float = 0.1, permutation=None): 98 | # refine_factory:Callable[[ArrayLike], Callable]=partial(SampleRefinement, num_samples=10)): 99 | super().__init__(sys) 100 | self.Jf_x = jax.jacfwd(sys.f, 1) 101 | self.Mf = mjacM(sys.f) 102 | self.Jf = jacM(sys.f) 103 | self.kap = kap 104 | self.alpha_p0 = alpha_p0 105 | self.N0 = N0 106 | self.permutation = permutation 107 | 108 | def _initialize(self, pt0: hParametope) -> ArrayLike: 109 | if not isinstance(pt0, hParametope): 110 | raise ValueError(f"{pt0=} is not a hParametope needed for AdjointEmbedding") 111 | 112 | # Setup refinement 113 | alpha = pt0.alpha 114 | # _id = lambda x : x 115 | # self.refine = self.refine_factory(alpha).get_refine_func() if alpha.shape[0] > alpha.shape[1] else _id 116 | # alpha_p = jnp.linalg.pinv(alpha) 117 | # N = null_space(alpha.T) 118 | # print(N@alpha) 119 | # return (alpha_p, N) 120 | return (self.alpha_p0, self.N0) 121 | 122 | def _dynamics(self, t, state: Tuple[hParametope, ArrayLike], *args, **kwargs): 123 | pt, aux = state 124 | ox = pt.ox 125 | 126 | K = len(pt.y) // 2 127 | # ly = -pt.y[:K] # negative for lower bound 128 | # uy = pt.y[K:] 129 | # iy = lu2i(ly, uy) 130 | y = pt.y 131 | alpha = pt.alpha 132 | alpha_p, N = aux 133 | 134 | ## Adjoint dynamics + LICQ CBF 135 | 136 | args_centers = (arg.center for arg in args) 137 | centers = (jnp.array([t]), ox) + tuple(args_centers) 138 | 139 | J = self.Jf_x(*centers) 140 | u0 = -alpha @ J 141 | u0flat = u0.reshape(-1) 142 | 143 | # CBF: Enforce pairwise independence on the rows of alpha 144 | 145 | PENALTY = 1 146 | 147 | if PENALTY == 0: 148 | ustar = jnp.zeros_like(u0) 149 | 150 | elif PENALTY == 1: 151 | 152 | def soft_overmax(x, eps=1e-5): 153 | return jnp.max(jnp.exp(x) / jnp.sum(jnp.exp(x))) 154 | 155 | def barrier_LICQ(alpha): 156 | # Normalize rows of alpha 157 | return 
jnp.linalg.det( 158 | alpha / jnp.linalg.norm(alpha, axis=1, keepdims=True) 159 | ) 160 | # return jnp.linalg.slogdet(alpha / jnp.linalg.norm(alpha, axis=1, keepdims=True))[1] 161 | 162 | balpha = barrier_LICQ(alpha) 163 | k = self.kap * balpha**3 164 | 165 | pLfh, Lfh = jax.jvp(barrier_LICQ, (alpha,), (u0,)) 166 | unroll = lambda v: jax.jvp( 167 | barrier_LICQ, (alpha,), (v.reshape(alpha.shape),) 168 | ) 169 | pLgh, Lgh = jax.vmap(unroll)(jnp.eye(alpha.size)) 170 | 171 | # Solution to QP 172 | ustar = jnp.where( 173 | Lfh + Lgh @ u0flat + k >= 0.0, 174 | jnp.zeros_like(u0flat), # constraint inactive 175 | -(Lfh + Lgh @ u0flat + k) * Lgh.T / (Lgh @ Lgh.T), 176 | ).reshape(alpha.shape) 177 | elif PENALTY == 2: 178 | ustar = jnp.zeros_like(u0) 179 | 180 | def soft(H): 181 | HHT = H @ H.T 182 | return jnp.sum((HHT - jnp.eye(H.shape[0])) ** 2) 183 | 184 | ustar = -self.kap * jax.grad(soft)(alpha) 185 | 186 | ## Offset Dynamics given ustar 187 | 188 | MJACM = True 189 | 190 | # For properly handling signs in lower offsets 191 | K_2 = len(y) // 2 192 | mul = jnp.concatenate((-jnp.ones(K_2), jnp.ones(K_2))) 193 | big_iz = pt.hinv(pt.y) 194 | Jh = natif(jax.jacfwd(lambda z: jnp.asarray(pt.h(z)))) 195 | 196 | def refine(y: Interval): 197 | if len(N) > 0: 198 | refinements = _mat_refine_all(N, jnp.arange(len(y)), y) 199 | return interval( 200 | jnp.max(refinements.lower, axis=0), 201 | jnp.min(refinements.upper, axis=0), 202 | ) 203 | else: 204 | return y 205 | 206 | if MJACM: 207 | if self.permutation is None: 208 | lenperm = sum([len(arg) for arg in centers]) 209 | self.permutation = standard_permutation(lenperm) 210 | 211 | MM = self.Mf( 212 | t, 213 | interval(alpha_p) @ big_iz + ox, 214 | *args, 215 | centers=(centers,), 216 | permutations=self.permutation, 217 | )[0] 218 | ls = [] 219 | us = [] 220 | for M, arg in zip(MM[2:], args): 221 | term = interval(M) @ arg 222 | ls.append(term.lower) 223 | us.append(term.upper) 224 | dist = interval(jnp.sum(jnp.asarray(ls)), 
jnp.sum(jnp.asarray(us))) 225 | 226 | def F(t, iy, *args): 227 | iy = refine(iy) 228 | iz = pt.hinv(i2ut(iy) * mul) 229 | 230 | # _, iJx = self.Jf(interval(t), interval(Hp)@iz + ox, *args) 231 | # MM = self.Mf(t, interval(alpha_p)@big_iz + ox, *args, \ 232 | # centers=(centers,), permutations=self.permutation)[0] 233 | Mx = MM[1] 234 | 235 | empty = jnp.any(iy.lower > iy.upper) 236 | 237 | def _zero(): 238 | return interval(jnp.zeros_like(iz.lower)) 239 | 240 | def _ret(): 241 | # Post first order cancellation 242 | PH = Jh(iz) 243 | return interval(PH[len(PH) // 2 :, :]) @ ( 244 | (interval(alpha) @ (Mx - J) + ustar) @ (interval(alpha_p) @ iz) 245 | + dist 246 | ) 247 | 248 | return jax.lax.cond(empty, _zero, _ret) 249 | 250 | E = embed(F) 251 | else: 252 | 253 | def F_second(t, iy, *args): 254 | iy = refine(iy) 255 | iz = pt.hinv(i2ut(iy) * mul) 256 | 257 | empty = jnp.any(iy.lower > iy.upper) 258 | 259 | def _zero(): 260 | return interval(jnp.zeros_like(iz.lower)) 261 | 262 | def _ret(): 263 | def _get_second(oz, z): 264 | primals = (t, alpha_p @ oz + ox) 265 | series = ( 266 | (0.0, 0.0), 267 | (alpha_p @ z, jnp.zeros_like(alpha_p @ z)), 268 | ) 269 | _, coeffs = jet(self.sys.f, primals, series) 270 | return coeffs[1] 271 | 272 | res = natif(_get_second)(big_iz, iz) 273 | 274 | # Post first order cancellation 275 | PH = Jh(iz) 276 | return interval(PH[len(PH) // 2 :, :]) @ ( 277 | interval(ustar) @ alpha_p @ iz + interval(alpha) @ res 278 | ) 279 | 280 | return jax.lax.cond(empty, _zero, _ret) 281 | 282 | E = embed(F_second) 283 | 284 | E_res = E(t, y * mul, *args) * mul 285 | # E_res = jnp.zeros_like(mul) 286 | 287 | # hParametope dynamics in same pytree structure as pt 288 | pt_dot = pt.from_parametope( 289 | hParametope(self.sys.f(*centers), u0 + ustar, E_res) 290 | ) 291 | # jnp.where(jnp.logical_and(y <= 1e-2, E_res <= 0), jnp.zeros_like(y), E_res))) 292 | # jnp.where(E_res <= 0, jnp.zeros_like(y), E_res))) 293 | 294 | # sets d/dt [alpha_p @ alpha] = 0, so 
alpha_p @ alpha = I 295 | alpha_p_dot = -alpha_p @ (u0 + ustar) @ alpha_p 296 | # alpha_p_dot = J@alpha_p 297 | 298 | # sets d/dt [N @ alpha] = 0, so N @ alpha = 0 299 | N_dot = -N @ (u0 + ustar) @ alpha_p 300 | # N_dot = jnp.zeros_like(N) 301 | 302 | return (pt_dot, (alpha_p_dot, N_dot)) 303 | 304 | 305 | class FastlinAdjointEmbedding(ParametopeEmbedding): 306 | def __init__( 307 | self, sys, alpha_p0, N0, permutation=None, ustars=None, tt=None, kap=None 308 | ): 309 | # refine_factory:Callable[[ArrayLike], Callable]=partial(SampleRefinement, num_samples=10)): 310 | super().__init__(sys) 311 | self.Jf_x = jax.jacfwd(sys.olsystem.f, 1) 312 | self.Jf_u = jax.jacfwd(sys.olsystem.f, 2) 313 | self.Mf = mjacM(sys.olsystem.f) 314 | self.Jf = jacM(sys.olsystem.f) 315 | self.alpha_p0 = alpha_p0 316 | self.N0 = N0 317 | self.permutation = permutation 318 | self.ustars = ustars 319 | self.tt = tt 320 | self.kap = kap 321 | 322 | def _initialize(self, pt0: hParametope) -> ArrayLike: 323 | if not isinstance(pt0, hParametope): 324 | raise ValueError(f"{pt0=} is not a hParametope needed for AdjointEmbedding") 325 | 326 | # Setup refinement 327 | # alpha = pt0.alpha 328 | # _id = lambda x : x 329 | # self.refine = self.refine_factory(alpha).get_refine_func() if alpha.shape[0] > alpha.shape[1] else _id 330 | # alpha_p = jnp.linalg.pinv(alpha) 331 | # N = null_space(alpha.T) 332 | # print(N@alpha) 333 | # return (alpha_p, N) 334 | return (self.alpha_p0, self.N0) 335 | 336 | def _dynamics(self, t, state: Tuple[hParametope, ArrayLike], *args, **kwargs): 337 | pt, aux = state 338 | ox = pt.ox 339 | 340 | K = len(pt.y) // 2 341 | # ly = -pt.y[:K] # negative for lower bound 342 | # uy = pt.y[K:] 343 | # iy = lu2i(ly, uy) 344 | y = pt.y 345 | alpha = pt.alpha 346 | alpha_p, N = aux 347 | 348 | # Global fastlin 349 | 350 | big_iz = pt.hinv(pt.y) 351 | Jh = natif(jax.jacfwd(lambda z: jnp.asarray(pt.h(z)))) 352 | 353 | def lifted_net(z): 354 | return self.sys.control(alpha_p @ z) 355 | 356 
| lifted_net.out_len = self.sys.control.out_len 357 | lifted_net.u = lambda t, y: lifted_net(y) 358 | 359 | # fastlin_res = fastlin(self.sys.control)(interval(alpha_p)@(big_iz + alpha@ox)) 360 | fastlin_res = fastlin(lifted_net)(big_iz + alpha @ ox) 361 | C = fastlin_res.C 362 | # C = jax.jacfwd(lifted_net)(alpha@ox) 363 | 364 | big_iu = fastlin_res(big_iz + alpha @ ox) 365 | CHox = C @ alpha @ ox 366 | ou = self.sys.control(ox) 367 | 368 | ## Adjoint dynamics + LICQ CBF 369 | 370 | args_centers = (arg.center for arg in args) 371 | centers = (jnp.array([t]), ox, ou) + tuple(args_centers) 372 | 373 | J_x = self.Jf_x(*centers) 374 | J_u = self.Jf_u(*centers) 375 | # u0 = -alpha@(J_x@alpha_p + J_u@C) 376 | u0 = -alpha @ (J_x + J_u @ C @ alpha) 377 | u0flat = u0.reshape(-1) 378 | 379 | if self.kap is None: 380 | ustar = jnp.zeros_like(u0) 381 | 382 | else: 383 | 384 | def barrier_LICQ(alpha): 385 | # return jax.jit(jnp.linalg.det, backend='cpu')(alpha / jnp.linalg.norm(alpha, axis=1, keepdims=True)) 386 | return jnp.linalg.det( 387 | alpha / jnp.linalg.norm(alpha, axis=1, keepdims=True) 388 | ) 389 | # return jnp.linalg.slogdet(alpha / jnp.linalg.norm(alpha, axis=1, keepdims=True))[1] 390 | 391 | balpha = barrier_LICQ(alpha) 392 | k = self.kap * balpha**3 393 | 394 | pLfh, Lfh = jax.jvp(barrier_LICQ, (alpha,), (u0,)) 395 | unroll = lambda v: jax.jvp( 396 | barrier_LICQ, (alpha,), (v.reshape(alpha.shape),) 397 | ) 398 | pLgh, Lgh = jax.vmap(unroll)(jnp.eye(alpha.size)) 399 | 400 | # Solution to QP 401 | ustar = jnp.where( 402 | Lfh + Lgh @ u0flat + k >= 0.0, 403 | jnp.zeros_like(u0flat), # constraint inactive 404 | -(Lfh + Lgh @ u0flat + k) * Lgh.T / (Lgh @ Lgh.T), 405 | ).reshape(alpha.shape) 406 | 407 | ## Offset Dynamics given ustar 408 | 409 | # For properly handling signs in lower offsets 410 | K_2 = len(y) // 2 411 | mul = jnp.concatenate((-jnp.ones(K_2), jnp.ones(K_2))) 412 | 413 | def refine(y: Interval): 414 | if len(N) > 0: 415 | refinements = 
_mat_refine_all(N, jnp.arange(len(y)), y) 416 | return interval( 417 | jnp.max(refinements.lower, axis=0), 418 | jnp.min(refinements.upper, axis=0), 419 | ) 420 | else: 421 | return y 422 | 423 | if self.permutation is None: 424 | lenperm = sum([len(arg) for arg in centers]) 425 | self.permutation = standard_permutation(lenperm) 426 | 427 | MM = self.Mf( 428 | t, 429 | interval(alpha_p) @ big_iz + ox, 430 | big_iu, 431 | *args, 432 | centers=(centers,), 433 | permutations=self.permutation, 434 | )[0] 435 | 436 | ls = [] 437 | us = [] 438 | 439 | def F(t, iy, *args): 440 | # Bound outputs along h^{-1}([iy.lower, iy.upper]) 441 | 442 | iy = refine(iy) 443 | iz = pt.hinv(i2ut(iy) * mul) 444 | 445 | Mx = MM[1] 446 | Mu = MM[2] 447 | 448 | empty = jnp.any(iy.lower > iy.upper) 449 | 450 | def _zero(): 451 | return interval(jnp.zeros_like(iz.lower)) 452 | 453 | def _ret(): 454 | # Post first order cancellation 455 | PH = Jh(iz) 456 | return interval(PH[len(PH) // 2 :, :]) @ ( 457 | interval(alpha) 458 | @ ( 459 | ((Mx - J_x) + (Mu - J_u) @ C @ alpha) @ alpha_p @ iz 460 | + interval(Mu) @ (fastlin_res.lud + CHox - ou) 461 | ) 462 | + interval(ustar) @ alpha_p @ iz 463 | ) 464 | 465 | return jax.lax.cond(empty, _zero, _ret) 466 | 467 | E = embed(F) 468 | E_res = E(t, y * mul, *args) * mul 469 | 470 | # hParametope dynamics in same pytree structure as pt 471 | pt_dot = pt.from_parametope( 472 | hParametope(self.sys.olsystem.f(*centers), u0 + ustar, E_res) 473 | ) 474 | 475 | # sets d/dt [alpha_p @ alpha] = 0, so alpha_p @ alpha = I 476 | alpha_p_dot = -alpha_p @ (u0 + ustar) @ alpha_p 477 | 478 | # sets d/dt [N @ alpha] = 0, so N @ alpha = 0 479 | N_dot = -N @ (u0 + ustar) @ alpha_p 480 | 481 | return (pt_dot, (alpha_p_dot, N_dot)) 482 | 483 | 484 | def _vec_refine(null_vector: jax.Array, var_index: jax.Array, y: Interval): 485 | ret = icopy(y) 486 | 487 | # Set up linear algebra computations for the refinement 488 | bounding_vars = interval(null_vector.at[var_index].set(0)) 
489 | ref_var = interval(null_vector[var_index]) 490 | b1 = lambda: ((-bounding_vars @ ret) / ref_var) & ret[var_index] 491 | b2 = lambda: ret[var_index] 492 | 493 | # Compute refinement based on null vector, if possible 494 | ndb0 = jnp.abs(null_vector[var_index]) > 1e-10 495 | ret = jax.lax.cond(ndb0, b1, b2) 496 | 497 | # fix fpe problem with upper < lower 498 | retu = jnp.where(ret.upper >= ret.lower, ret.upper, ret.lower) 499 | return interval(ret.lower, retu) 500 | 501 | 502 | _mat_refine = jax.vmap(_vec_refine, in_axes=(0, None, None), out_axes=0) 503 | _mat_refine_all = jax.vmap(_mat_refine, in_axes=(None, 0, None), out_axes=1) 504 | --------------------------------------------------------------------------------