├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── examples └── STP │ ├── asset │ ├── aircraft.png │ └── quad.png │ ├── config │ ├── default.yaml │ ├── ilqgame.yaml │ └── target_reach.yaml │ └── example.ipynb ├── misc ├── cover.jpg └── examples.png └── wpf ├── STP ├── __init__.py ├── cost.py ├── dynamics.py ├── helpers.py ├── ilq_game.py ├── ilqr.py ├── multiplayer_dynamical_system.py ├── stp.py └── utils.py ├── __init__.py ├── bnp.py └── io.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log 3 | *.asv 4 | *.DS_Store 5 | 6 | experiments/ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Safe Robotics Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Who Plays First? Optimizing the Order of Play in Stackelberg Games with Many Robots 2 | 3 | [![License][license-shield]][license-url] 4 | [![Python 3.9](https://img.shields.io/badge/python-3.9-blue)](https://www.python.org/downloads/) 5 | [![Website][homepage-shield]][homepage-url] 6 | [![Paper][paper-shield]][paper-url] 7 | 8 | [Haimin Hu](https://haiminhu.org/)1, 9 | [Gabriele Dragotto](https://dragotto.net/)1, 10 | [Zixu Zhang](https://zzx9636.github.io/), 11 | [Kaiqu Liang](https://kaiquliang.github.io/), 12 | [Bartolomeo Stellato](https://stellato.io/), 13 | [Jaime F. Fisac](https://saferobotics.princeton.edu/jaime) 14 | 15 | 1equal contribution 16 | 17 | Published as a conference paper at RSS'2024. 18 | 19 | 20 | 21 |
22 |

23 | 24 | Logo 25 | 26 |

27 |

28 |

29 | 30 | 31 | 32 |
33 |

Table of Contents

34 |
    35 |
  1. About The Project
  2. 36 |
  3. Example
  4. 37 |
  5. License
  6. 38 |
  7. Contact
  8. 39 |
  9. Citation
  10. 40 |
41 |
42 | 43 | 44 | 45 | ## About The Project 46 | 47 | This repository implements **Branch and Play (B&P)**, an *efficient and exact* game-theoretic algorithm that provably converges to a *socially optimal* order of play and its Stackelberg (leader-follower) equilibrium. 48 | As a subroutine for B&P, we also implement sequential trajectory planning (STP) as a game solver to scalably compute a valid local Stackelberg equilibrium for any given order of play. 49 | The repository is primarily developed and maintained by [Haimin Hu](https://haiminhu.org/) and [Gabriele Dragotto](https://dragotto.net/). 50 | 51 | Click to watch our spotlight video: 52 | [![Watch the video](misc/cover.jpg)](https://haiminhu.org/wp-content/uploads/2024/06/rss_wpf.mp4) 53 | 54 | 55 | ## Example 56 | We provide an air traffic control (ATC) example in the [Notebook](https://github.com/SafeRoboticsLab/Who_Plays_First/blob/main/examples/STP/example.ipynb). 57 | This Notebook comprises three sections, each dedicated to a closed-loop simulation using a different method: Branch and Play, first-come-first-served baseline, and Nash ILQ Game baseline. 58 | 59 | 60 | 61 | ## License 62 | 63 | Distributed under the MIT License. See `LICENSE` for more information. 64 | 65 | 66 | 67 | ## Contact 68 | 69 | - Haimin Hu - [@HaiminHu](https://x.com/HaiminHu) - haiminh@princeton.edu 70 | - Gabriele Dragotto - [@GabrieleDrag8](https://x.com/GabrieleDrag8) - hello@dragotto.net 71 | 72 | 73 | 74 | ## Citation 75 | 76 | If you found this repository helpful, please consider citing our paper. 77 | 78 | ```tex 79 | @inproceedings{hu2024plays, 80 | title={Who Plays First? 
Optimizing the Order of Play in Stackelberg Games with Many Robots}, 81 | author={Hu, Haimin and Dragotto, Gabriele and Zhang, Zixu and Liang, Kaiqu and Stellato, Bartolomeo and Fisac, Jaime F}, 82 | booktitle={Proceedings of Robotics: Science and Systems}, 83 | year={2024} 84 | } 85 | ``` 86 | 87 | 88 | 89 | 90 | [license-shield]: https://img.shields.io/badge/License-MIT-blue.svg 91 | [license-url]: https://opensource.org/licenses/MIT 92 | [homepage-shield]: https://img.shields.io/badge/-Website-orange 93 | [homepage-url]: https://saferobotics.princeton.edu/research/who-plays-first 94 | [paper-shield]: https://img.shields.io/badge/-Paper-green 95 | [paper-url]: https://arxiv.org/abs/2402.09246 96 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: wpf 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=2_gnu 8 | - asttokens=2.4.1=pyhd8ed1ab_0 9 | - backcall=0.2.0=pyh9f0ad1d_0 10 | - blas=1.0=mkl 11 | - brotli=1.1.0=hd590300_1 12 | - brotli-bin=1.1.0=hd590300_1 13 | - bzip2=1.0.8=hd590300_5 14 | - ca-certificates=2024.6.2=hbcca054_0 15 | - comm=0.2.2=pyhd8ed1ab_0 16 | - contourpy=1.2.1=py39h7633fee_0 17 | - cycler=0.12.1=pyhd8ed1ab_0 18 | - dbus=1.13.18=hb2f20db_0 19 | - debugpy=1.8.1=py39h3d6467e_0 20 | - decorator=5.1.1=pyhd8ed1ab_0 21 | - executing=2.0.1=pyhd8ed1ab_0 22 | - expat=2.6.2=h59595ed_0 23 | - fontconfig=2.14.1=hef1e5e3_0 24 | - fonttools=4.53.0=py39hd3abc70_0 25 | - freetype=2.10.4=h0708190_1 26 | - glib=2.78.4=h6a678d5_0 27 | - glib-tools=2.78.4=h6a678d5_0 28 | - gst-plugins-base=1.14.1=h6a678d5_1 29 | - gstreamer=1.14.1=h5eee18b_1 30 | - icu=58.2=hf484d3e_1000 31 | - imageio=2.31.1=py39h06a4308_0 32 | - importlib_metadata=7.2.1=hd8ed1ab_0 33 | - importlib_resources=6.4.0=pyhd8ed1ab_0 34 | - intel-openmp=2023.1.0=hdb19cb5_46306 35 | - 
ipykernel=6.29.4=pyh3099207_0 36 | - ipython=8.12.0=pyh41d4057_0 37 | - jedi=0.19.1=pyhd8ed1ab_0 38 | - jpeg=9e=h0b41bf4_3 39 | - jupyter_client=8.6.2=pyhd8ed1ab_0 40 | - jupyter_core=5.7.2=py39hf3d152e_0 41 | - keyutils=1.6.1=h166bdaf_0 42 | - kiwisolver=1.4.5=py39h7633fee_1 43 | - krb5=1.20.1=h81ceb04_0 44 | - lcms2=2.12=h3be6417_0 45 | - ld_impl_linux-64=2.40=hf3520f5_7 46 | - lerc=3.0=h295c915_0 47 | - libbrotlicommon=1.1.0=hd590300_1 48 | - libbrotlidec=1.1.0=hd590300_1 49 | - libbrotlienc=1.1.0=hd590300_1 50 | - libclang=10.0.1=default_hb85057a_2 51 | - libdeflate=1.17=h5eee18b_1 52 | - libedit=3.1.20191231=he28a2e2_2 53 | - libevent=2.1.12=hdbd6064_1 54 | - libexpat=2.6.2=h59595ed_0 55 | - libffi=3.4.2=h7f98852_5 56 | - libgcc-ng=13.2.0=h77fa898_13 57 | - libgfortran-ng=13.2.0=h69a702a_13 58 | - libgfortran5=13.2.0=h3d2ce59_13 59 | - libglib=2.78.4=hdc74915_0 60 | - libgomp=13.2.0=h77fa898_13 61 | - libiconv=1.17=hd590300_2 62 | - libllvm10=10.0.1=he513fc3_3 63 | - libpng=1.6.39=h5eee18b_0 64 | - libpq=12.17=hdbd6064_0 65 | - libsodium=1.0.18=h36c2ea0_1 66 | - libstdcxx-ng=13.2.0=hc0a3c3a_13 67 | - libtiff=4.5.1=h6a678d5_0 68 | - libwebp-base=1.4.0=hd590300_0 69 | - libxcb=1.16=hd590300_0 70 | - libxkbcommon=1.0.1=hfa300c1_0 71 | - libxml2=2.9.14=h74e7548_0 72 | - lz4-c=1.9.4=hcb278e6_0 73 | - matplotlib=3.7.2=py39h06a4308_0 74 | - matplotlib-base=3.7.2=py39h1128e8f_0 75 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0 76 | - mkl=2023.1.0=h213fc3f_46344 77 | - mkl-service=2.4.0=py39h5eee18b_1 78 | - mkl_fft=1.3.8=py39h5eee18b_0 79 | - mkl_random=1.2.4=py39hdb19cb5_0 80 | - munkres=1.1.4=pyh9f0ad1d_0 81 | - ncurses=6.5=h59595ed_0 82 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 83 | - nspr=4.35=h6a678d5_0 84 | - nss=3.89.1=h6a678d5_0 85 | - numpy=1.25.2=py39h5f9d8c6_0 86 | - numpy-base=1.25.2=py39hb5e798b_0 87 | - openjpeg=2.4.0=h3ad879b_0 88 | - openssl=3.3.1=h4ab18f5_0 89 | - packaging=24.1=pyhd8ed1ab_0 90 | - parso=0.8.4=pyhd8ed1ab_0 91 | - pcre2=10.42=hebb0a14_1 92 | - 
pexpect=4.9.0=pyhd8ed1ab_0 93 | - pickleshare=0.7.5=py_1003 94 | - pillow=10.3.0=py39h5eee18b_0 95 | - pip=24.0=pyhd8ed1ab_0 96 | - platformdirs=4.2.2=pyhd8ed1ab_0 97 | - ply=3.11=pyhd8ed1ab_2 98 | - prompt-toolkit=3.0.47=pyha770c72_0 99 | - prompt_toolkit=3.0.47=hd8ed1ab_0 100 | - psutil=5.9.8=py39hd1e30aa_0 101 | - pthread-stubs=0.4=h36c2ea0_1001 102 | - ptyprocess=0.7.0=pyhd3deb0d_0 103 | - pure_eval=0.2.2=pyhd8ed1ab_0 104 | - pygments=2.18.0=pyhd8ed1ab_0 105 | - pyparsing=3.0.9=pyhd8ed1ab_0 106 | - pyqt=5.15.10=py39h6a678d5_0 107 | - pyqt5-sip=12.13.0=py39h5eee18b_0 108 | - python=3.9.19=h955ad1f_1 109 | - python-dateutil=2.9.0=pyhd8ed1ab_0 110 | - python_abi=3.9=2_cp39 111 | - pyyaml=6.0=py39h5eee18b_1 112 | - pyzmq=26.0.3=py39ha1047a2_0 113 | - qt-main=5.15.2=h327a75a_7 114 | - readline=8.2=h8228510_1 115 | - scipy=1.11.1=py39h5f9d8c6_0 116 | - setuptools=70.1.0=pyhd8ed1ab_0 117 | - sip=6.7.12=py39h3d6467e_0 118 | - six=1.16.0=pyh6c4a22f_0 119 | - sqlite=3.45.3=h5eee18b_0 120 | - stack_data=0.6.2=pyhd8ed1ab_0 121 | - tbb=2021.8.0=hdb19cb5_0 122 | - tk=8.6.14=h39e8969_0 123 | - tomli=2.0.1=pyhd8ed1ab_0 124 | - tornado=6.4.1=py39hd3abc70_0 125 | - traitlets=5.14.3=pyhd8ed1ab_0 126 | - typing_extensions=4.12.2=pyha770c72_0 127 | - unicodedata2=15.1.0=py39hd1e30aa_0 128 | - wcwidth=0.2.13=pyhd8ed1ab_0 129 | - wheel=0.43.0=pyhd8ed1ab_1 130 | - xorg-libxau=1.0.11=hd590300_0 131 | - xorg-libxdmcp=1.1.3=h7f98852_0 132 | - xz=5.4.6=h5eee18b_1 133 | - yaml=0.2.5=h7b6447c_0 134 | - zeromq=4.3.5=h59595ed_1 135 | - zipp=3.19.2=pyhd8ed1ab_0 136 | - zlib=1.2.13=h5eee18b_1 137 | - zstd=1.5.5=hc292b87_2 138 | - pip: 139 | - gurobipy==10.0.3 140 | - importlib-metadata==6.8.0 141 | - jax==0.4.16 142 | - jaxlib==0.4.16 143 | - ml-dtypes==0.3.0 144 | - opt-einsum==3.3.0 145 | - pandas==2.1.4 146 | - pytz==2023.3.post1 147 | - seaborn==0.13.1 148 | - tzdata==2023.4 149 | - yapf==0.40.2 150 | -------------------------------------------------------------------------------- 
/examples/STP/asset/aircraft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SafeRoboticsLab/Who_Plays_First/0f6b0a77ac6add52d4d49adcdbc1cc68bc5581db/examples/STP/asset/aircraft.png -------------------------------------------------------------------------------- /examples/STP/asset/quad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SafeRoboticsLab/Who_Plays_First/0f6b0a77ac6add52d4d49adcdbc1cc68bc5581db/examples/STP/asset/quad.png -------------------------------------------------------------------------------- /examples/STP/config/default.yaml: -------------------------------------------------------------------------------- 1 | # region: SYSTEM 2 | N_AGENT: 4 3 | DIM_X: 4 4 | DIM_U: 2 5 | WIDTH: 0.1 6 | PX_DIM: 0 7 | PY_DIM: 1 8 | DYNAMICS: 'DubinsCar' 9 | COST: 'CostDubinsCar' 10 | # endregion 11 | 12 | # region: COST PARAM 13 | W_X: 0.01 14 | W_Y: 0.01 15 | W_V: 1. 16 | W_PSI: 0. 17 | W_X_T: 1. 18 | W_Y_T: 1. 19 | W_V_T: 0. 20 | W_PSI_T: 0. 21 | W_ACCEL: 0.01 22 | W_DELTA: 0.1 23 | 24 | # CONSTRAINT PARAM 25 | PROX_SEP: 0.65 # safe separation 26 | PROX_SEP_CHECK: 0.2 27 | V_REF: 0.3 28 | V_MIN: 0.1 29 | V_MAX: 0.6 30 | A_MIN: -1. 31 | A_MAX: 0.5 32 | DELTA_MIN: -1.5 33 | DELTA_MAX: 1.5 34 | Q1_V: 0.4 35 | Q2_V: 15. 36 | Q1_PROX: 10. 37 | Q2_PROX: 80. 38 | BARRIER_THR: 15. 39 | # endregion 40 | 41 | # region: STP SOLVER 42 | N: 41 43 | T: 4. 
44 | RHC: False # Whether to use RHC for subgames 45 | MAX_ITER: 100 # ILQR max iterations 46 | RHC_STEPS: 41 # Horizon in RHC mode 47 | OPEN_LOOP_STEP: 5 # int (>= 1), RHC mode only 48 | INIT_WITH_PARENT: True 49 | COL_CHECK_N: 31 # Collision check horizon 50 | 51 | # BNP PARAMS 52 | BRANCHING: 'depthfirst' # branching strategies: 'depthfirst', 'bestfirst' 53 | MAX_NODES: 10000 # max number of explored nodes 54 | FEAS_TOL: 0.0001 55 | MIN_GAP: 0.0 # optimality threshold 56 | MAX_ITER_BNP: 100 57 | VERBOSE_BNP: False 58 | OPTION_B: True 59 | 60 | # SIMULATION 61 | SIM_STEPS: 350 # max simulation time steps 62 | OUT_FOLDER: experiments/default 63 | RANDOM_SEED: 35 64 | PLOT_RES: True 65 | 66 | # ATC 67 | ATC_X: 0. 68 | ATC_Y: 0. 69 | ATC_RADIUS: 2.5 70 | TARGET_RADIUS: 0.6 71 | REACH_RADIUS: 0.1 72 | # endregion 73 | -------------------------------------------------------------------------------- /examples/STP/config/ilqgame.yaml: -------------------------------------------------------------------------------- 1 | # region: SYSTEM 2 | N_AGENT: 4 3 | DIM_X: 4 4 | DIM_U: 2 5 | WIDTH: 0.1 6 | PX_DIM: 0 7 | PY_DIM: 1 8 | DYNAMICS: 'DubinsCar' 9 | COST: 'CostDubinsCarGame' 10 | # endregion 11 | 12 | # region: COST PARAM 13 | W_X: 0.1 14 | W_Y: 0.1 15 | W_V: 1. 16 | W_PSI: 0. 17 | W_X_T: 10. 18 | W_Y_T: 10. 19 | W_V_T: 0. 20 | W_PSI_T: 0. 21 | W_ACCEL: 0.01 22 | W_DELTA: 0.1 23 | 24 | # CONSTRAINT PARAM 25 | PROX_SEP: 0.7 # safe separation 26 | PROX_SEP_CHECK: 0.2 27 | V_REF: 0.3 28 | V_MIN: 0.1 29 | V_MAX: 0.6 30 | A_MIN: -1. 31 | A_MAX: 1. 32 | DELTA_MIN: -2.5 33 | DELTA_MAX: 2.5 34 | Q1_V: 10. 35 | Q2_V: 30. 36 | Q1_PROX: 10. 37 | Q2_PROX: 80. 38 | BARRIER_THR: 15. 39 | # endregion 40 | 41 | # region: STP SOLVER 42 | N: 41 43 | T: 4.
44 | RHC: False # Whether to use RHC for subgames 45 | MAX_ITER: 100 # ILQR max iterations 46 | RHC_STEPS: 41 # Horizon in RHC mode 47 | OPEN_LOOP_STEP: 5 # int (>= 1), RHC mode only 48 | INIT_WITH_PARENT: True 49 | COL_CHECK_N: 21 # Collision check horizon 50 | 51 | # BNP PARAMS 52 | BRANCHING: 'depthfirst' # branching strategies: 'depthfirst', 'bestfirst' 53 | MAX_NODES: 10000 # max number of explored nodes 54 | FEAS_TOL: 0.0001 55 | MIN_GAP: 0.0 # optimality threshold 56 | MAX_ITER_BNP: 100 57 | VERBOSE_BNP: False 58 | 59 | # SIMULATION 60 | SIM_STEPS: 350 # max simulation time steps 61 | OUT_FOLDER: experiments/default 62 | RANDOM_SEED: 35 63 | PLOT_RES: True 64 | 65 | # ATC 66 | ATC_X: 0. 67 | ATC_Y: 0. 68 | ATC_RADIUS: 2.5 69 | TARGET_RADIUS: 0.6 70 | REACH_RADIUS: 0.1 71 | # endregion 72 | -------------------------------------------------------------------------------- /examples/STP/config/target_reach.yaml: -------------------------------------------------------------------------------- 1 | # SYSTEM 2 | N_AGENT: 4 3 | DIM_X: 4 4 | DIM_U: 2 5 | WIDTH: 0.1 6 | PX_DIM: 0 7 | PY_DIM: 1 8 | DYNAMICS: 'DubinsCar' 9 | COST: 'CostDubinsCar' 10 | 11 | # COST PARAM 12 | W_X: 1. 13 | W_Y: 1. 14 | W_V: 0. 15 | W_PSI: 0. 16 | W_X_T: 50. 17 | W_Y_T: 50. 18 | W_V_T: 0. 19 | W_PSI_T: 0. 20 | W_ACCEL: 0.01 21 | W_DELTA: 0.1 22 | 23 | # CONSTRAINT PARAM 24 | PROX_SEP: 0.6 # safe separation 25 | PROX_SEP_CHECK: 0.2 26 | V_REF: 0.3 27 | V_MIN: -0.05 28 | V_MAX: 0.6 29 | A_MIN: -2. 30 | A_MAX: 0.3 31 | DELTA_MIN: -2.5 32 | DELTA_MAX: 2.5 33 | Q1_V: 0.4 34 | Q2_V: 15. 35 | Q1_PROX: 10. 36 | Q2_PROX: 80. 37 | BARRIER_THR: 15. 38 | 39 | # STP SOLVER 40 | N: 41 41 | T: 4.
42 | RHC: False # Whether to use RHC for subgames 43 | MAX_ITER: 100 # ILQR max iterations 44 | RHC_STEPS: 41 # Horizon in RHC mode 45 | OPEN_LOOP_STEP: 5 # int (>= 1), RHC mode only 46 | INIT_WITH_PARENT: True 47 | COL_CHECK_N: 31 # Collision check horizon 48 | -------------------------------------------------------------------------------- /examples/STP/example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a5882712", 6 | "metadata": {}, 7 | "source": [ 8 | "---\n", 9 | "\n", 10 | "This Notebook was developed by [Haimin Hu](https://haiminhu.org/) for the RSS'24 paper [_Who Plays First? Optimizing the Order of Play in Stackelberg Games with Many Robots_](https://saferobotics.princeton.edu/research/who-plays-first).\n", 11 | "\n", 12 | "Instructions:\n", 13 | "* Run the cells to initiate closed-loop simulation for each method.\n", 14 | "* The simulation results are automatically displayed as an animation within the Notebook.\n", 15 | "* The demo was created for $N=4$ agents, but can be straightforwardly extended to more agents by changing the initial conditions and config files.\n", 16 | "\n", 17 | "---" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "id": "379cd4f0", 23 | "metadata": {}, 24 | "source": [ 25 | "##### Who Plays First" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 1, 31 | "id": "e628f67d", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "# region: Imports and initialization\n", 36 | "import os, sys, jax\n", 37 | "import numpy as np\n", 38 | "from IPython.display import display, HTML\n", 39 | "from copy import deepcopy\n", 40 | "if len(os.getcwd()) > 0:\n", 41 | " sys.path.insert(1, \"../../\")\n", 42 | " import wpf\n", 43 | " from wpf.STP.utils import *\n", 44 | " from wpf.STP import *\n", 45 | "\n", 46 | "jax.config.update('jax_platform_name', 'cpu')\n", 47 | "jax.config.update('jax_enable_x64', 
True)\n", 48 | "np.set_printoptions(suppress=True)\n", 49 | "\n", 50 | "# Loads the config and specifies the folder to save figures.\n", 51 | "config = load_config(\"config/default.yaml\")\n", 52 | "config_target = load_config(\"config/target_reach.yaml\")\n", 53 | "dsolver = iLQR(load_config(\"config/default.yaml\")) # default solver\n", 54 | "fig_prog_folder = os.path.join(config.OUT_FOLDER, \"progress_wpf\")\n", 55 | "os.makedirs(fig_prog_folder, exist_ok=True)\n", 56 | "quad_img = plt.imread('asset/aircraft.png', format=\"png\")\n", 57 | "\n", 58 | "# Sets the initial and target states.\n", 59 | "# x y v psi\n", 60 | "x_init_R1 = np.array([-2.0, 2.5, 0., -0.79]) # Initial state of R1.\n", 61 | "x_init_R2 = np.array([2.5, 2.0, 0., 3.93]) # Initial state of R2.\n", 62 | "x_init_R3 = np.array([2.8, -2.5, 0., 2.36]) # Initial state of R3.\n", 63 | "x_init_R4 = np.array([-2.0, -2.8, 0., 0.79]) # Initial state of R4.\n", 64 | "x_init = [x_init_R1, x_init_R2, x_init_R3, x_init_R4]\n", 65 | "\n", 66 | "# x y v psi\n", 67 | "target_R1 = np.array([1.5, -1.5, config.V_REF, 2.36]) # Target state of R1.\n", 68 | "target_R2 = np.array([-1.5, -1.5, config.V_REF, 0.79]) # Target state of R2.\n", 69 | "target_R3 = np.array([-1.5, 1.5, config.V_REF, -0.79]) # Target state of R3.\n", 70 | "target_R4 = np.array([1.5, 1.5, config.V_REF, 3.93]) # Target state of R4.\n", 71 | "targets = [target_R1, target_R2, target_R3, target_R4]\n", 72 | "\n", 73 | "# Initializes stats.\n", 74 | "num_agent = config.N_AGENT\n", 75 | "zero_control = [np.zeros((2, config.N)) for _ in range(num_agent)]\n", 76 | "init_control = deepcopy(zero_control)\n", 77 | "sim_steps = config.SIM_STEPS\n", 78 | "state_hist = [np.zeros((4, sim_steps)) for _ in range(num_agent)]\n", 79 | "ctrl_hist = [np.zeros((2, sim_steps)) for _ in range(num_agent)]\n", 80 | "order_hist = []\n", 81 | "zonef_hist = []\n", 82 | "\n", 83 | "# Initializes Branch-and-play.\n", 84 | "x_cur = deepcopy(x_init)\n", 85 | "STP_instance = 
(x_cur, init_control, targets, config)\n", 86 | "\n", 87 | "if config.BRANCHING == 'depthfirst':\n", 88 | " branching_strategy = wpf.depthfirst_cb\n", 89 | "elif config.BRANCHING == 'bestfirst':\n", 90 | " branching_strategy = wpf.bestfirst_cb\n", 91 | "\n", 92 | "bnb = wpf.BranchAndPlay(\n", 93 | " num_agent, STP_instance, branching_strategy, stp_solve_cb, lambda *args: None, stp_branching,\n", 94 | " wpf.Settings(\n", 95 | " config.MAX_NODES, config.FEAS_TOL, config.MIN_GAP, config.MAX_ITER_BNP,\n", 96 | " verbose=config.VERBOSE_BNP\n", 97 | " ), stp_initializer\n", 98 | ")\n", 99 | "\n", 100 | "# Initializes the ATC Zone.\n", 101 | "zone = ATCZone(config, targets, STP(config), iLQR(config_target))\n", 102 | "# endregion" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "633fd71b", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# region: Receding horizon planning\n", 113 | "dh = display(HTML('
WPF starts.
'), display_id=True)\n", 114 | "k = 0\n", 115 | "while not all(zone.is_reach_target(x_cur)) and k < sim_steps:\n", 116 | "\n", 117 | " # region: Solves BnP.\n", 118 | " if k > 0:\n", 119 | " Js_prev = _results.custom_statistics[\"Js_prev\"]\n", 120 | " bnb.update_problem(instance_data=(x_cur, init_control, targets, Js_prev, config), mode=\"update\")\n", 121 | "\n", 122 | " bnb.solve()\n", 123 | " _results = bnb.results\n", 124 | " _stime = sum(_results.custom_statistics[\"jax_compile_time\"])\n", 125 | " # endregion\n", 126 | "\n", 127 | " # region: Locks the ordering when collision check fails.\n", 128 | " _bnb_state = jnp.stack(_results.incumbent[0], axis=2)\n", 129 | " if (zone.oz_planner.pairwise_collision_check(_bnb_state[:, :config.COL_CHECK_N, :])).any():\n", 130 | " _states, _controls = zone.plan_stp(x_cur, init_control, targets, order_hist[-1])\n", 131 | " x_new = [_states[ii][:, 1] for ii in range(num_agent)]\n", 132 | " order_hist.append(order_hist[-1])\n", 133 | " us_cur = _controls\n", 134 | " else:\n", 135 | " x_new = [_results.incumbent[0][ii][:, 1] for ii in range(num_agent)]\n", 136 | " order_hist.append(list(bnb.results.incumbent_permutation))\n", 137 | " us_cur = _results.incumbent[1]\n", 138 | " for ii in range(num_agent):\n", 139 | " init_control[ii][:, :-1] = us_cur[ii][:, 1:]\n", 140 | " # endregion\n", 141 | "\n", 142 | " # region: Checks ATC zone.\n", 143 | " _, us_new, zone_flags = zone.check_zone(x_cur, x_new, us_cur, zero_control)\n", 144 | " zonef_hist.append(zone_flags)\n", 145 | " col_flag = zone.is_collision(x_cur)\n", 146 | " # endregion\n", 147 | "\n", 148 | " # region: Updates and reports stats.\n", 149 | " for ii in range(num_agent):\n", 150 | " _xii, _ = dsolver.dynamics.integrate_forward(x_cur[ii], us_new[ii][:, 0])\n", 151 | " x_cur[ii] = _xii\n", 152 | "\n", 153 | " for ii in range(num_agent):\n", 154 | " state_hist[ii][:, k], ctrl_hist[ii][:, k] = x_cur[ii], us_new[ii][:, 0]\n", 155 | " k += 1\n", 156 | "\n", 157 | " _info = 
[\n", 158 | " 'step: ', k, ' | stime: ', '{:04.2f}'.format(_stime), ' | Objective: ',\n", 159 | " '{:03.1f}'.format(_results.global_ub), ' | Permutation: ', _results.incumbent_permutation,\n", 160 | " ' | Col: ', col_flag\n", 161 | " ]\n", 162 | " _info = [str(_item) for _item in _info]\n", 163 | " dh.update(''.join(_info))\n", 164 | " # endregion\n", 165 | "# endregion\n", 166 | "\n", 167 | "# region: Wraps up.\n", 168 | "state_hist = [state_hist[ii][:, :k] for ii in range(num_agent)]\n", 169 | "ctrl_hist = [ctrl_hist[ii][:, :k] for ii in range(num_agent)]\n", 170 | "\n", 171 | "# Plots the optimal trajectory.\n", 172 | "if config.PLOT_RES:\n", 173 | " plot_trajectory(\n", 174 | " state_hist, config, fig_prog_folder, orders=order_hist, targets=targets,\n", 175 | " colors=['r', 'g', 'b', 'm'], xlim=(-3, 3), ylim=(-3, 3), figsize=(20, 20), fontsize=35,\n", 176 | " image=quad_img, plot_arrow=False, zone=zone, zone_flags=zonef_hist\n", 177 | " )\n", 178 | " img = make_animation(state_hist[0].shape[1], config, fig_prog_folder)\n", 179 | " display(img)\n", 180 | "# endregion" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "id": "2b4db988", 186 | "metadata": {}, 187 | "source": [ 188 | "##### First-come-first-served" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "id": "82fda483", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "# region: Initialization\n", 199 | "# Loads the config and specifies the folder to save figures.\n", 200 | "fig_prog_folder = os.path.join(config.OUT_FOLDER, \"progress_fcfs\")\n", 201 | "os.makedirs(fig_prog_folder, exist_ok=True)\n", 202 | "\n", 203 | "# Initializes stats.\n", 204 | "x_cur = deepcopy(x_init)\n", 205 | "num_agent = config.N_AGENT\n", 206 | "zero_control = [np.zeros((2, config.N)) for _ in range(num_agent)]\n", 207 | "init_control = deepcopy(zero_control)\n", 208 | "sim_steps = config.SIM_STEPS\n", 209 | "state_hist = [np.zeros((4, sim_steps)) for _ in 
range(num_agent)]\n", 210 | "ctrl_hist = [np.zeros((2, sim_steps)) for _ in range(num_agent)]\n", 211 | "order_hist = []\n", 212 | "zonef_hist = []\n", 213 | "dh = display(HTML('
FCFS starts.
'), display_id=True)\n", 214 | "\n", 215 | "perm = [None] * num_agent\n", 216 | "assigned_flags = [False] * num_agent\n", 217 | "order = 0\n", 218 | "\n", 219 | "# Initializes the ATC Zone.\n", 220 | "zone = ATCZone(config, targets, STP(config), iLQR(config_target))\n", 221 | "# endregion\n", 222 | "\n", 223 | "# region: Receding horizon planning\n", 224 | "k = 0\n", 225 | "while not all(zone.is_reach_target(x_cur)) and k < sim_steps:\n", 226 | "\n", 227 | " # region: Determines the order on a first-come-first-serve basis.\n", 228 | " if not all(assigned_flags):\n", 229 | " zone_flags = zone.is_in_zone(x_cur)\n", 230 | " for ii in range(num_agent):\n", 231 | " if not assigned_flags[ii] and zone_flags[ii]:\n", 232 | " perm[order] = ii\n", 233 | " assigned_flags[ii] = True\n", 234 | " order += 1\n", 235 | " # endregion\n", 236 | "\n", 237 | " # region: Plans STP.\n", 238 | " _states, _controls = zone.plan_stp(x_cur, init_control, targets, perm)\n", 239 | " x_new = [_states[ii][:, 1] for ii in range(num_agent)]\n", 240 | " order_hist.append(perm)\n", 241 | " us_cur = _controls\n", 242 | "\n", 243 | " for ii in range(num_agent):\n", 244 | " init_control[ii][:, :-1] = us_cur[ii][:, 1:]\n", 245 | " # endregion\n", 246 | "\n", 247 | " # region: Checks ATC zone.\n", 248 | " _, us_new, zone_flags = zone.check_zone(x_cur, x_new, us_cur, zero_control)\n", 249 | " zonef_hist.append(zone_flags)\n", 250 | " col_flag = zone.is_collision(x_cur)\n", 251 | " # endregion\n", 252 | "\n", 253 | " # region: Updates and reports stats.\n", 254 | " for ii in range(num_agent):\n", 255 | " _xii, _ = dsolver.dynamics.integrate_forward(x_cur[ii], us_new[ii][:, 0])\n", 256 | " x_cur[ii] = _xii\n", 257 | "\n", 258 | " for ii in range(num_agent):\n", 259 | " state_hist[ii][:, k], ctrl_hist[ii][:, k] = x_cur[ii], us_new[ii][:, 0]\n", 260 | " k += 1\n", 261 | "\n", 262 | " _info = ['step: ', k, ' | Permutation: ', perm, ' | Col: ', col_flag]\n", 263 | " _info = [str(_item) for _item in _info]\n", 
264 | " dh.update(''.join(_info))\n", 265 | " # endregion\n", 266 | "# endregion\n", 267 | "\n", 268 | "# region: Wraps up.\n", 269 | "state_hist = [state_hist[ii][:, :k] for ii in range(num_agent)]\n", 270 | "ctrl_hist = [ctrl_hist[ii][:, :k] for ii in range(num_agent)]\n", 271 | "\n", 272 | "# Plots the optimal trajectory.\n", 273 | "if config.PLOT_RES:\n", 274 | " plot_trajectory(\n", 275 | " state_hist, config, fig_prog_folder, orders=order_hist, targets=targets,\n", 276 | " colors=['r', 'g', 'b', 'm'], xlim=(-3, 3), ylim=(-3, 3), figsize=(20, 20), fontsize=35,\n", 277 | " image=quad_img, plot_arrow=False, zone=zone, zone_flags=zonef_hist\n", 278 | " )\n", 279 | " img = make_animation(state_hist[0].shape[1], config, fig_prog_folder, name=\"rollout_fcfs.gif\")\n", 280 | " display(img)\n", 281 | "# endregion" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "id": "f0867a8e", 287 | "metadata": {}, 288 | "source": [ 289 | "##### Nash ILQ Game" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "ab10ccf2", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "# region: Initialization\n", 300 | "# Loads the config and specifies the folder to save figures.\n", 301 | "config = load_config(\"config/ilqgame.yaml\")\n", 302 | "config_ilqr = load_config(\"config/default.yaml\")\n", 303 | "config_ilqr.N, config_ilqr.T = config.N, config.T\n", 304 | "config_target = load_config(\"config/target_reach.yaml\")\n", 305 | "config_target.N, config_target.T = config.N, config.T\n", 306 | "fig_prog_folder = os.path.join(config.OUT_FOLDER, \"progress_wpf\")\n", 307 | "os.makedirs(fig_prog_folder, exist_ok=True)\n", 308 | "quad_img = plt.imread('asset/aircraft.png', format=\"png\")\n", 309 | "\n", 310 | "# Loads the config and specifies the folder to save figures.\n", 311 | "fig_prog_folder = os.path.join(config.OUT_FOLDER, \"progress_ilq\")\n", 312 | "os.makedirs(fig_prog_folder, exist_ok=True)\n", 313 | "\n", 314 
| "# Initializes stats.\n", 315 | "x_cur = deepcopy(x_init)\n", 316 | "num_agent = config.N_AGENT\n", 317 | "zero_control = [np.zeros((2, config.N)) for _ in range(num_agent)]\n", 318 | "init_control = deepcopy(zero_control)\n", 319 | "sim_steps = config.SIM_STEPS\n", 320 | "state_hist = [np.zeros((4, sim_steps)) for _ in range(num_agent)]\n", 321 | "ctrl_hist = [np.zeros((2, sim_steps)) for _ in range(num_agent)]\n", 322 | "order_hist = []\n", 323 | "zonef_hist = []\n", 324 | "dh = display(HTML('
ILQGame starts.
'), display_id=True)\n", 325 | "\n", 326 | "perm = [None] * num_agent\n", 327 | "assigned_flags = [False] * num_agent\n", 328 | "order = 0\n", 329 | "\n", 330 | "# Initializes the ATC Zone.\n", 331 | "zone = ATCZone(config, targets, STP(config_ilqr), iLQR(config_target))\n", 332 | "\n", 333 | "# Sets up the iLQGame solver.\n", 334 | "dummy_ilqr = iLQR(config_ilqr)\n", 335 | "jnt_sys = ProductMultiPlayerDynamicalSystem([dummy_ilqr.dynamics] * num_agent)\n", 336 | "solver = ILQGame(config, jnt_sys, verbose=False)\n", 337 | "# endregion\n", 338 | "\n", 339 | "# region: Receding horizon planning\n", 340 | "k = 0\n", 341 | "while not all(zone.is_reach_target(x_cur)) and k < sim_steps:\n", 342 | "\n", 343 | " # region: Plans ILQGame.\n", 344 | " _states, _controls, _, _ = solver.solve(x_cur, init_control, targets)\n", 345 | " x_new = [_states[ii][:, 1] for ii in range(num_agent)]\n", 346 | " order_hist.append(perm)\n", 347 | " us_cur = _controls\n", 348 | "\n", 349 | " for ii in range(num_agent):\n", 350 | " init_control[ii][:, :-1] = us_cur[ii][:, 1:]\n", 351 | " # endregion\n", 352 | "\n", 353 | " # region: Shielding\n", 354 | " _states_cc = jnp.stack(_states, axis=2)\n", 355 | " if (zone.oz_planner.pairwise_collision_check(_states_cc[:, :config.COL_CHECK_N, :])).any():\n", 356 | " us_cur = deepcopy(us_cur)\n", 357 | " _xs_bk, _us_bk = zone.plan_stp(x_cur, init_control, targets, list(range(num_agent)))\n", 358 | " _xs_bk_cc = jnp.stack(_xs_bk, axis=2)\n", 359 | " if (zone.oz_planner.pairwise_collision_check(_xs_bk_cc[:, :config.COL_CHECK_N, :])).any():\n", 360 | " for ii in range(num_agent):\n", 361 | " if ii > 0:\n", 362 | " us_cur[ii][:, 0] = -x_cur[ii][2] / dsolver.dynamics.dt\n", 363 | " else:\n", 364 | " for ii in range(num_agent):\n", 365 | " us_cur[ii][:, 0] = _us_bk[ii][:, 0]\n", 366 | " # endregion\n", 367 | "\n", 368 | " # region: Checks ATC zone.\n", 369 | " _, us_new, zone_flags = zone.check_zone(x_cur, x_new, us_cur, zero_control)\n", 370 | " 
zonef_hist.append(zone_flags)\n", 371 | " col_flag = zone.is_collision(x_cur)\n", 372 | " # endregion\n", 373 | "\n", 374 | " # region: Updates and reports stats.\n", 375 | " for ii in range(num_agent):\n", 376 | " _xii, _ = solver.dynamics._subsystem.integrate_forward_norev(x_cur[ii], us_new[ii][:, 0])\n", 377 | " x_cur[ii] = _xii\n", 378 | "\n", 379 | " for ii in range(num_agent):\n", 380 | " state_hist[ii][:, k], ctrl_hist[ii][:, k] = x_cur[ii], us_new[ii][:, 0]\n", 381 | " k += 1\n", 382 | "\n", 383 | " print('step: ', k, ' | Permutation: ', perm, ' | Col: ', col_flag)\n", 384 | " _info = ['step: ', k, ' | Permutation: ', perm, ' | Col: ', col_flag]\n", 385 | " _info = [str(_item) for _item in _info]\n", 386 | " dh.update(''.join(_info))\n", 387 | " # endregion\n", 388 | "# endregion\n", 389 | "\n", 390 | "# region: Wraps up.\n", 391 | "state_hist = [state_hist[ii][:, :k] for ii in range(num_agent)]\n", 392 | "ctrl_hist = [ctrl_hist[ii][:, :k] for ii in range(num_agent)]\n", 393 | "\n", 394 | "# Plots the optimal trajectory.\n", 395 | "if config.PLOT_RES:\n", 396 | " plot_trajectory(\n", 397 | " state_hist, config, fig_prog_folder, orders=order_hist, targets=targets,\n", 398 | " colors=['r', 'g', 'b', 'm'], xlim=(-3, 3), ylim=(-3, 3), figsize=(20, 20), fontsize=35,\n", 399 | " image=quad_img, plot_arrow=False, zone=zone, zone_flags=zonef_hist\n", 400 | " )\n", 401 | " img = make_animation(state_hist[0].shape[1], config, fig_prog_folder, name=\"rollout_ilq.gif\")\n", 402 | " display(img)\n", 403 | "# endregion" 404 | ] 405 | } 406 | ], 407 | "metadata": { 408 | "colab": { 409 | "provenance": [] 410 | }, 411 | "kernelspec": { 412 | "display_name": "Python 3", 413 | "name": "python3" 414 | }, 415 | "language_info": { 416 | "codemirror_mode": { 417 | "name": "ipython", 418 | "version": 3 419 | }, 420 | "file_extension": ".py", 421 | "mimetype": "text/x-python", 422 | "name": "python", 423 | "nbconvert_exporter": "python", 424 | "pygments_lexer": "ipython3", 425 | 
"version": "3.9.19" 426 | } 427 | }, 428 | "nbformat": 4, 429 | "nbformat_minor": 5 430 | } 431 | -------------------------------------------------------------------------------- /misc/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SafeRoboticsLab/Who_Plays_First/0f6b0a77ac6add52d4d49adcdbc1cc68bc5581db/misc/cover.jpg -------------------------------------------------------------------------------- /misc/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SafeRoboticsLab/Who_Plays_First/0f6b0a77ac6add52d4d49adcdbc1cc68bc5581db/misc/examples.png -------------------------------------------------------------------------------- /wpf/STP/__init__.py: -------------------------------------------------------------------------------- 1 | from .cost import * 2 | from .dynamics import * 3 | from .ilqr import * 4 | from .stp import * 5 | from .utils import * 6 | from .helpers import * 7 | from .multiplayer_dynamical_system import * 8 | from .ilq_game import * 9 | -------------------------------------------------------------------------------- /wpf/STP/cost.py: -------------------------------------------------------------------------------- 1 | """ 2 | Costs, gradients, and Hessians. 3 | 4 | Please contact the author(s) of this library if you have any questions. 
5 | Author: Haimin Hu (haiminh@princeton.edu) 6 | """ 7 | 8 | import numpy as np 9 | from typing import Tuple 10 | from abc import ABC, abstractmethod 11 | from scipy.linalg import solve_discrete_are as dare 12 | 13 | from functools import partial 14 | from jax import jit, lax, vmap, jacfwd, hessian 15 | from jaxlib.xla_extension import ArrayImpl 16 | import jax.numpy as jnp 17 | 18 | from .utils import Struct 19 | from .dynamics import * 20 | 21 | 22 | class Cost(ABC): 23 | 24 | def __init__(self, config: Struct, num_leaders: float): 25 | self.config = config 26 | 27 | # Planning parameters. 28 | self.N = config.N # number of planning steps 29 | self.num_leaders = num_leaders 30 | 31 | # System parameters. 32 | self.dim_x = config.DIM_X 33 | self.dim_u = config.DIM_U 34 | self.px_dim = config.PX_DIM 35 | self.py_dim = config.PY_DIM 36 | 37 | @abstractmethod 38 | def get_cost( 39 | self, states: ArrayImpl, controls: ArrayImpl, leader_trajs: ArrayImpl, target: ArrayImpl 40 | ) -> ArrayImpl: 41 | """ 42 | Calculates the cost given planned states and controls. 43 | 44 | Args: 45 | states (ArrayImpl): (dim_x, N) planned trajectory. 46 | controls (ArrayImpl): (dim_u, N) planned control sequence. 47 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs. 48 | target (ArrayImpl): (dim_x,) target state. 49 | 50 | Returns: 51 | float: total cost. 52 | """ 53 | raise NotImplementedError 54 | 55 | @abstractmethod 56 | def get_derivatives( 57 | self, states: ArrayImpl, controls: ArrayImpl, leader_trajs: ArrayImpl, target: ArrayImpl 58 | ) -> Tuple[ArrayImpl, ArrayImpl, ArrayImpl, ArrayImpl]: 59 | """ 60 | Calculates gradients and Hessian of the overall cost using Jax. 61 | 62 | Args: 63 | states (ArrayImpl): (dim_x, N) planned trajectory. 64 | controls (ArrayImpl): (dim_u, N) planned control sequence. 65 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs. 66 | target (ArrayImpl): (dim_x,) target state. 
67 | 68 | Returns: 69 | ArrayImpl: lxs of the shape (dim_x, N). 70 | ArrayImpl: Hxxs of the shape (dim_x, dim_x, N). 71 | ArrayImpl: lus of the shape (dim_u, N). 72 | ArrayImpl: Huus of the shape (dim_u, dim_u, N). 73 | """ 74 | raise NotImplementedError 75 | 76 | 77 | class CostDubinsCar(Cost): 78 | 79 | def __init__(self, config: Struct, num_leaders: float): 80 | Cost.__init__(self, config, num_leaders) 81 | 82 | # Standard LQ weighting matrices. 83 | self.W_state = np.diag((config.W_X, config.W_Y, config.W_V, config.W_PSI)) 84 | self.W_control = np.diag((config.W_ACCEL, config.W_DELTA)) 85 | self.W_terminal = np.diag((config.W_X_T, config.W_Y_T, config.W_V_T, config.W_PSI_T)) 86 | 87 | # Soft constraint parameters. 88 | self.q1_v = config.Q1_V 89 | self.q2_v = config.Q2_V 90 | self.q1_prox = config.Q1_PROX 91 | self.q2_prox = config.Q2_PROX 92 | self.barrier_thr = config.BARRIER_THR 93 | self.prox_sep = config.PROX_SEP 94 | self.v_min = config.V_MIN 95 | self.v_max = config.V_MAX 96 | 97 | @partial(jit, static_argnames="self") 98 | def get_cost( 99 | self, states: ArrayImpl, controls: ArrayImpl, leader_trajs: ArrayImpl, target: ArrayImpl 100 | ) -> ArrayImpl: 101 | """ 102 | Calculates the cost given planned states and controls. 103 | 104 | Args: 105 | states (ArrayImpl): (dim_x, N) planned trajectory. 106 | controls (ArrayImpl): (dim_u, N) planned control sequence. 107 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs. 108 | target (ArrayImpl): (dim_x,) target state. 109 | 110 | Returns: 111 | float: total cost. 112 | """ 113 | # vmap all costs. 114 | c_state_vmap = vmap(self.state_cost_stage, in_axes=(1, None), out_axes=(0)) 115 | c_cntrl_vmap = vmap(self.control_cost_stage, in_axes=(1), out_axes=(0)) 116 | c_velbd_vmap = vmap(self.vel_bound_cost_stage, in_axes=(1), out_axes=(0)) 117 | c_proxi_vmap = vmap(self.proximity_cost_stage, in_axes=(1, 1), out_axes=(0)) 118 | 119 | # Evaluates all cost terms. 
120 | c_state = c_state_vmap(states, target) 121 | c_cntrl = c_cntrl_vmap(controls) 122 | c_velbd = c_velbd_vmap(states) 123 | c_proxi = c_proxi_vmap(states, leader_trajs) 124 | c_termi = self.state_cost_terminal(states[:, -1], target) 125 | 126 | # Sums up all cost terms. 127 | J = jnp.sum(c_state + c_cntrl + c_velbd + c_proxi) + c_termi 128 | 129 | return J 130 | 131 | @partial(jit, static_argnames="self") 132 | def get_derivatives( 133 | self, states: ArrayImpl, controls: ArrayImpl, leader_trajs: ArrayImpl, target: ArrayImpl 134 | ) -> Tuple[ArrayImpl, ArrayImpl, ArrayImpl, ArrayImpl]: 135 | """ 136 | Calculates gradients and Hessian of the overall cost using Jax. 137 | 138 | Args: 139 | states (ArrayImpl): (dim_x, N) planned trajectory. 140 | controls (ArrayImpl): (dim_u, N) planned control sequence. 141 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs. 142 | target (ArrayImpl): (dim_x,) target state. 143 | 144 | Returns: 145 | ArrayImpl: lxs of the shape (dim_x, N). 146 | ArrayImpl: Hxxs of the shape (dim_x, dim_x, N). 147 | ArrayImpl: lus of the shape (dim_u, N). 148 | ArrayImpl: Huus of the shape (dim_u, dim_u, N). 149 | """ 150 | # Creates cost gradient functions. 151 | lx_state_fn = jacfwd(self.state_cost_stage, argnums=0) 152 | lx_velbd_fn = jacfwd(self.vel_bound_cost_stage, argnums=0) 153 | lx_proxi_fn = jacfwd(self.proximity_cost_stage, argnums=0) 154 | lu_cntrl_fn = jacfwd(self.control_cost_stage, argnums=0) 155 | lx_termi_fn = jacfwd(self.state_cost_terminal, argnums=0) 156 | 157 | # Creates cost Hessian functions. 158 | Hxx_state_fn = hessian(self.state_cost_stage, argnums=0) 159 | Hxx_velbd_fn = hessian(self.vel_bound_cost_stage, argnums=0) 160 | Hxx_proxi_fn = hessian(self.proximity_cost_stage, argnums=0) 161 | Huu_cntrl_fn = hessian(self.control_cost_stage, argnums=0) 162 | Hxx_termi_fn = hessian(self.state_cost_terminal, argnums=0) 163 | 164 | # vmap all gradients and Hessians. 
165 | lx_state_vmap = vmap(lx_state_fn, in_axes=(1, None), out_axes=(1)) 166 | lx_velbd_vmap = vmap(lx_velbd_fn, in_axes=(1), out_axes=(1)) 167 | lx_proxi_vmap = vmap(lx_proxi_fn, in_axes=(1, 1), out_axes=(1)) 168 | lu_cntrl_vmap = vmap(lu_cntrl_fn, in_axes=(1), out_axes=(1)) 169 | 170 | Hxx_state_vmap = vmap(Hxx_state_fn, in_axes=(1, None), out_axes=(2)) 171 | Hxx_velbd_vmap = vmap(Hxx_velbd_fn, in_axes=(1), out_axes=(2)) 172 | Hxx_proxi_vmap = vmap(Hxx_proxi_fn, in_axes=(1, 1), out_axes=(2)) 173 | Huu_cntrl_vmap = vmap(Huu_cntrl_fn, in_axes=(1), out_axes=(2)) 174 | 175 | # Evaluates all cost gradients and Hessians. 176 | lx_state = lx_state_vmap(states, target) 177 | lx_velbd = lx_velbd_vmap(states) 178 | lx_proxi = lx_proxi_vmap(states, leader_trajs) 179 | lu_cntrl = lu_cntrl_vmap(controls) 180 | lx_termi = lx_termi_fn(states[:, -1], target) 181 | 182 | Hxx_state = Hxx_state_vmap(states, target) 183 | Hxx_velbd = Hxx_velbd_vmap(states) 184 | Hxx_proxi = Hxx_proxi_vmap(states, leader_trajs) 185 | Huu_cntrl = Huu_cntrl_vmap(controls) 186 | Hxx_termi = Hxx_termi_fn(states[:, -1], target) 187 | 188 | lxs = lx_state + lx_velbd + lx_proxi 189 | lus = lu_cntrl 190 | Hxxs = Hxx_state + Hxx_velbd + Hxx_proxi 191 | Huus = Huu_cntrl 192 | 193 | lxs = lxs.at[:, -1].set(lxs[:, -1] + lx_termi) 194 | Hxxs = Hxxs.at[:, :, -1].set(Hxxs[:, :, -1] + Hxx_termi) 195 | 196 | return lxs, Hxxs, lus, Huus 197 | 198 | # --------------------------- Running performance cost terms --------------------------- 199 | @partial(jit, static_argnames="self") 200 | def state_cost_stage(self, state: ArrayImpl, target: ArrayImpl) -> ArrayImpl: 201 | """ 202 | Computes the stage state cost. 
203 | 204 | Args: 205 | state (ArrayImpl): (4,) 206 | target (ArrayImpl): (2,) 207 | 208 | Returns: 209 | ArrayImpl: cost (scalar) 210 | """ 211 | return (state - target).T @ self.W_state @ (state-target) 212 | 213 | @partial(jit, static_argnames="self") 214 | def control_cost_stage(self, control: ArrayImpl) -> ArrayImpl: 215 | """ 216 | Computes the stage control cost. 217 | 218 | Args: 219 | control (ArrayImpl): (2,) 220 | 221 | Returns: 222 | ArrayImpl: cost (scalar) 223 | """ 224 | return control.T @ self.W_control @ control 225 | 226 | @partial(jit, static_argnames="self") 227 | def state_cost_terminal(self, state: ArrayImpl, target: ArrayImpl) -> ArrayImpl: 228 | """ 229 | Computes the terminal state cost. 230 | HACK: terminal velocity is set to 0. 231 | 232 | Args: 233 | state (ArrayImpl): (4,) [x, y, v, psi] 234 | target (ArrayImpl): (4,) [x, y, v, psi] 235 | 236 | Returns: 237 | ArrayImpl: cost (scalar) 238 | """ 239 | t_state = jnp.array([target[0], target[1], 0., target[3]]) 240 | return (state - t_state).T @ self.W_terminal @ (state-t_state) 241 | 242 | # ----------------------- Running soft constraint cost terms ----------------------- 243 | @partial(jit, static_argnames="self") 244 | def vel_bound_cost_stage(self, state: ArrayImpl) -> ArrayImpl: 245 | """ 246 | Calculates the velocity bound soft constraint cost. 247 | 248 | Args: 249 | state (ArrayImpl): (4,) 250 | 251 | Returns: 252 | ArrayImpl: cost (scalar) 253 | """ 254 | cons_v_min = self.v_min - state[2] 255 | cons_v_max = state[2] - self.v_max 256 | barrier_v_min = self.q1_v * jnp.exp(jnp.clip(self.q2_v * cons_v_min, None, self.barrier_thr)) 257 | barrier_v_max = self.q1_v * jnp.exp(jnp.clip(self.q2_v * cons_v_max, None, self.barrier_thr)) 258 | return barrier_v_min + barrier_v_max 259 | 260 | @partial(jit, static_argnames="self") 261 | def proximity_cost_stage(self, state: ArrayImpl, leader_state: ArrayImpl) -> ArrayImpl: 262 | """ 263 | Calculates the proximity soft constraint cost. 
264 | 265 | Args: 266 | state (ArrayImpl): (4,) ego state 267 | leader_state (ArrayImpl): (4, N_leader) leaders' optimized state 268 | 269 | Returns: 270 | ArrayImpl: cost (scalar) 271 | """ 272 | 273 | def _looper(ii, cost_vec): 274 | dx = state[self.px_dim] - leader_state[self.px_dim, ii] 275 | dy = state[self.py_dim] - leader_state[self.py_dim, ii] 276 | penalty_prox = -jnp.minimum(jnp.sqrt(dx**2 + dy**2) - self.prox_sep, 0.0) 277 | cost_vec = cost_vec.at[ii].set( 278 | self.q1_prox * jnp.exp(jnp.clip(self.q2_prox * penalty_prox, None, self.barrier_thr)) 279 | ) 280 | return cost_vec 281 | 282 | if self.num_leaders == 0: 283 | return 0.0 284 | else: 285 | cost_vec = jnp.zeros((self.num_leaders,)) 286 | cost_vec = lax.fori_loop(0, self.num_leaders, _looper, leader_state) 287 | # return jnp.sum(cost_vec) 288 | return jnp.max(cost_vec) 289 | 290 | 291 | class CostDoubleIntegrator(Cost): 292 | 293 | def __init__(self, config: Struct, num_leaders: float): 294 | Cost.__init__(self, config, num_leaders) 295 | 296 | # Standard LQ weighting matrices. 297 | self.W_state = np.diag((config.W_X, config.W_Y, config.W_V)) 298 | self.W_control = np.diag((config.W_AX, config.W_AY)) 299 | 300 | # Soft constraint parameters. 301 | self.q1_v = config.Q1_V 302 | self.q2_v = config.Q2_V 303 | self.q1_prox = config.Q1_PROX 304 | self.q2_prox = config.Q2_PROX 305 | self.barrier_thr = config.BARRIER_THR 306 | self.prox_sep = config.PROX_SEP 307 | self.v_min = config.V_MIN 308 | self.v_max = config.V_MAX 309 | _DI_sys = DoubleIntegrator(config) 310 | self.compute_terminal_cost(np.asarray(_DI_sys.Ad), np.asarray(_DI_sys.Bd)) 311 | 312 | def compute_terminal_cost(self, Ad: np.array, Bd: np.array): 313 | """ 314 | Computes the terminal cost-to-go via solving DARE. 
315 | 316 | Args: 317 | Ad (np.array): system matrix 318 | Bd (np.array): system matrix 319 | """ 320 | _Q = np.diag((self.config.W_X_T, self.config.W_Y_T, self.config.W_V_T, self.config.W_V_T)) 321 | _Q_terminal = dare(a=Ad, b=Bd, q=_Q, r=self.W_control) 322 | self.Q_terminal = jnp.asarray(_Q_terminal) 323 | _K_lqr = np.linalg.inv(Bd.T @ _Q_terminal @ Bd + self.W_control) @ (Bd.T @ _Q_terminal @ Ad) 324 | self.K_lqr = jnp.asarray(_K_lqr) 325 | 326 | @partial(jit, static_argnames="self") 327 | def get_cost( 328 | self, states: ArrayImpl, controls: ArrayImpl, leader_trajs: ArrayImpl, target: ArrayImpl 329 | ) -> ArrayImpl: 330 | """ 331 | Calculates the cost given planned states and controls. 332 | 333 | Args: 334 | states (ArrayImpl): (dim_x, N) planned trajectory. 335 | controls (ArrayImpl): (dim_u, N) planned control sequence. 336 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs. 337 | target (ArrayImpl): (dim_x,) target state. 338 | 339 | Returns: 340 | float: total cost. 341 | """ 342 | # vmap all costs. 343 | c_state_vmap = vmap(self.state_cost_stage, in_axes=(1, None), out_axes=(0)) 344 | c_cntrl_vmap = vmap(self.control_cost_stage, in_axes=(1), out_axes=(0)) 345 | c_velbd_vmap = vmap(self.vel_bound_cost_stage, in_axes=(1), out_axes=(0)) 346 | c_proxi_vmap = vmap(self.proximity_cost_stage, in_axes=(1, 1), out_axes=(0)) 347 | 348 | # Evaluates all cost terms. 349 | c_state = c_state_vmap(states, target) 350 | c_cntrl = c_cntrl_vmap(controls) 351 | c_velbd = c_velbd_vmap(states) 352 | c_proxi = c_proxi_vmap(states, leader_trajs) 353 | c_termi = self.state_cost_terminal(states[:, -1], target) 354 | 355 | # Sums up all cost terms. 
356 | J = jnp.sum(c_state + c_cntrl + c_velbd + c_proxi) + c_termi 357 | 358 | return J 359 | 360 | @partial(jit, static_argnames="self") 361 | def get_derivatives( 362 | self, states: ArrayImpl, controls: ArrayImpl, leader_trajs: ArrayImpl, target: ArrayImpl 363 | ) -> Tuple[ArrayImpl, ArrayImpl, ArrayImpl, ArrayImpl]: 364 | """ 365 | Calculates gradients and Hessian of the overall cost using Jax. 366 | 367 | Args: 368 | states (ArrayImpl): (dim_x, N) planned trajectory. 369 | controls (ArrayImpl): (dim_u, N) planned control sequence. 370 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs. 371 | target (ArrayImpl): (dim_x,) target state. 372 | 373 | Returns: 374 | ArrayImpl: lxs of the shape (dim_x, N). 375 | ArrayImpl: Hxxs of the shape (dim_x, dim_x, N). 376 | ArrayImpl: lus of the shape (dim_u, N). 377 | ArrayImpl: Huus of the shape (dim_u, dim_u, N). 378 | """ 379 | # Creates cost gradient functions. 380 | lx_state_fn = jacfwd(self.state_cost_stage, argnums=0) 381 | lx_velbd_fn = jacfwd(self.vel_bound_cost_stage, argnums=0) 382 | lx_proxi_fn = jacfwd(self.proximity_cost_stage, argnums=0) 383 | lu_cntrl_fn = jacfwd(self.control_cost_stage, argnums=0) 384 | lx_termi_fn = jacfwd(self.state_cost_terminal, argnums=0) 385 | 386 | # Creates cost Hessian functions. 387 | Hxx_state_fn = hessian(self.state_cost_stage, argnums=0) 388 | Hxx_velbd_fn = hessian(self.vel_bound_cost_stage, argnums=0) 389 | Hxx_proxi_fn = hessian(self.proximity_cost_stage, argnums=0) 390 | Huu_cntrl_fn = hessian(self.control_cost_stage, argnums=0) 391 | Hxx_termi_fn = hessian(self.state_cost_terminal, argnums=0) 392 | 393 | # vmap all gradients and Hessians. 
394 | lx_state_vmap = vmap(lx_state_fn, in_axes=(1, None), out_axes=(1)) 395 | lx_velbd_vmap = vmap(lx_velbd_fn, in_axes=(1), out_axes=(1)) 396 | lx_proxi_vmap = vmap(lx_proxi_fn, in_axes=(1, 1), out_axes=(1)) 397 | lu_cntrl_vmap = vmap(lu_cntrl_fn, in_axes=(1), out_axes=(1)) 398 | 399 | Hxx_state_vmap = vmap(Hxx_state_fn, in_axes=(1, None), out_axes=(2)) 400 | Hxx_velbd_vmap = vmap(Hxx_velbd_fn, in_axes=(1), out_axes=(2)) 401 | Hxx_proxi_vmap = vmap(Hxx_proxi_fn, in_axes=(1, 1), out_axes=(2)) 402 | Huu_cntrl_vmap = vmap(Huu_cntrl_fn, in_axes=(1), out_axes=(2)) 403 | 404 | # Evaluates all cost gradients and Hessians. 405 | lx_state = lx_state_vmap(states, target) 406 | lx_velbd = lx_velbd_vmap(states) 407 | lx_proxi = lx_proxi_vmap(states, leader_trajs) 408 | lu_cntrl = lu_cntrl_vmap(controls) 409 | lx_termi = lx_termi_fn(states[:, -1], target) 410 | 411 | Hxx_state = Hxx_state_vmap(states, target) 412 | Hxx_velbd = Hxx_velbd_vmap(states) 413 | Hxx_proxi = Hxx_proxi_vmap(states, leader_trajs) 414 | Huu_cntrl = Huu_cntrl_vmap(controls) 415 | Hxx_termi = Hxx_termi_fn(states[:, -1], target) 416 | 417 | lxs = lx_state + lx_velbd + lx_proxi 418 | lus = lu_cntrl 419 | Hxxs = Hxx_state + Hxx_velbd + Hxx_proxi 420 | Huus = Huu_cntrl 421 | 422 | lxs = lxs.at[:, -1].set(lxs[:, -1] + lx_termi) 423 | Hxxs = Hxxs.at[:, :, -1].set(Hxxs[:, :, -1] + Hxx_termi) 424 | 425 | return lxs, Hxxs, lus, Huus 426 | 427 | # --------------------------- Running performance cost terms --------------------------- 428 | @partial(jit, static_argnames="self") 429 | def state_cost_stage(self, state: ArrayImpl, target: ArrayImpl) -> ArrayImpl: 430 | """ 431 | Computes the stage state cost. 
432 | 433 | Args: 434 | state (ArrayImpl): (4,) [x, y, vx, vy] 435 | target (ArrayImpl): (3,) [x, y, v] 436 | 437 | Returns: 438 | ArrayImpl: cost (scalar) 439 | """ 440 | _xyv = jnp.array((state[self.px_dim], state[self.py_dim], jnp.linalg.norm(state[2:]))) 441 | return (_xyv - target).T @ self.W_state @ (_xyv-target) 442 | 443 | @partial(jit, static_argnames="self") 444 | def control_cost_stage(self, control: ArrayImpl) -> ArrayImpl: 445 | """ 446 | Computes the stage control cost. 447 | 448 | Args: 449 | control (ArrayImpl): (dim_u,) 450 | 451 | Returns: 452 | ArrayImpl: cost (scalar) 453 | """ 454 | return control.T @ self.W_control @ control 455 | 456 | @partial(jit, static_argnames="self") 457 | def state_cost_terminal(self, state: ArrayImpl, target: ArrayImpl) -> ArrayImpl: 458 | """ 459 | Computes the terminal state cost. 460 | HACK: terminal velocity is set to 0. 461 | 462 | Args: 463 | state (ArrayImpl): (4,) [x, y, vx, vy] 464 | target (ArrayImpl): (3,) [x, y, v] 465 | 466 | Returns: 467 | ArrayImpl: cost (scalar) 468 | """ 469 | t_state = jnp.array([target[0], target[1], 0., 0.]) 470 | return (state - t_state).T @ self.Q_terminal @ (state-t_state) 471 | 472 | # ----------------------- Running soft constraint cost terms ----------------------- 473 | @partial(jit, static_argnames="self") 474 | def vel_bound_cost_stage(self, state: ArrayImpl) -> ArrayImpl: 475 | """ 476 | Calculates the velocity bound soft constraint cost. 
477 | 478 | Args: 479 | state (ArrayImpl): (4,) 480 | 481 | Returns: 482 | ArrayImpl: cost (scalar) 483 | """ 484 | v = jnp.linalg.norm(state[2:]) 485 | # cons_v_min = self.v_min - v 486 | cons_v_max = v - self.v_max 487 | # barrier_v_min = self.q1_v * jnp.exp(jnp.clip(self.q2_v * cons_v_min, None, self.barrier_thr)) 488 | barrier_v_max = self.q1_v * jnp.exp(jnp.clip(self.q2_v * cons_v_max, None, self.barrier_thr)) 489 | return barrier_v_max 490 | 491 | @partial(jit, static_argnames="self") 492 | def proximity_cost_stage(self, state: ArrayImpl, leader_state: ArrayImpl) -> ArrayImpl: 493 | """ 494 | Calculates the proximity soft constraint cost. 495 | 496 | Args: 497 | state (ArrayImpl): (dim_x,) ego state 498 | leader_state (ArrayImpl): (dim_x, N_leader) leaders' optimized state 499 | 500 | Returns: 501 | ArrayImpl: cost (scalar) 502 | """ 503 | 504 | def _looper(ii, cost_vec): 505 | dx = state[self.px_dim] - leader_state[self.px_dim, ii] 506 | dy = state[self.py_dim] - leader_state[self.py_dim, ii] 507 | penalty_prox = -jnp.minimum(jnp.sqrt(dx**2 + dy**2) - self.prox_sep, 0.0) 508 | cost_vec = cost_vec.at[ii].set( 509 | self.q1_prox * jnp.exp(jnp.clip(self.q2_prox * penalty_prox, None, self.barrier_thr)) 510 | ) 511 | return cost_vec 512 | 513 | if self.num_leaders == 0: 514 | return 0.0 515 | else: 516 | cost_vec = jnp.zeros((self.num_leaders,)) 517 | cost_vec = lax.fori_loop(0, self.num_leaders, _looper, leader_state) 518 | # return jnp.sum(cost_vec) 519 | return jnp.max(cost_vec) 520 | 521 | 522 | class CostDoubleIntegratorMovingTargets(CostDoubleIntegrator): 523 | 524 | def __init__(self, config: Struct, num_leaders: float): 525 | CostDoubleIntegrator.__init__(self, config, num_leaders) 526 | 527 | self.W_state = np.diag((config.W_X, config.W_Y, config.W_V, config.W_V)) 528 | 529 | @partial(jit, static_argnames="self") 530 | def get_cost( 531 | self, states: ArrayImpl, controls: ArrayImpl, leader_trajs: ArrayImpl, target: ArrayImpl 532 | ) -> ArrayImpl: 533 | 
""" 534 | Calculates the cost given planned states and controls. 535 | 536 | Args: 537 | states (ArrayImpl): (dim_x, N) planned trajectory. 538 | controls (ArrayImpl): (dim_u, N) planned control sequence. 539 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs. 540 | target (ArrayImpl): (dim_x, N) target trajectory. 541 | 542 | Returns: 543 | float: total cost. 544 | """ 545 | # vmap all costs. 546 | c_state_vmap = vmap(self.state_cost_stage, in_axes=(1, 1), out_axes=(0)) 547 | c_cntrl_vmap = vmap(self.control_cost_stage, in_axes=(1), out_axes=(0)) 548 | c_velbd_vmap = vmap(self.vel_bound_cost_stage, in_axes=(1), out_axes=(0)) 549 | c_proxi_vmap = vmap(self.proximity_cost_stage, in_axes=(1, 1), out_axes=(0)) 550 | 551 | # Evaluates all cost terms. 552 | c_state = c_state_vmap(states, target) 553 | c_cntrl = c_cntrl_vmap(controls) 554 | c_velbd = c_velbd_vmap(states) 555 | c_proxi = c_proxi_vmap(states, leader_trajs) 556 | c_termi = self.state_cost_terminal(states[:, -1], target[:, -1]) 557 | 558 | # Sums up all cost terms. 559 | J = jnp.sum(c_state + c_cntrl + c_velbd + c_proxi) + c_termi 560 | 561 | return J 562 | 563 | @partial(jit, static_argnames="self") 564 | def get_derivatives( 565 | self, states: ArrayImpl, controls: ArrayImpl, leader_trajs: ArrayImpl, target: ArrayImpl 566 | ) -> Tuple[ArrayImpl, ArrayImpl, ArrayImpl, ArrayImpl]: 567 | """ 568 | Calculates gradients and Hessian of the overall cost using Jax. 569 | 570 | Args: 571 | states (ArrayImpl): (dim_x, N) planned trajectory. 572 | controls (ArrayImpl): (dim_u, N) planned control sequence. 573 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs. 574 | target (ArrayImpl): (dim_x,) target state. 575 | 576 | Returns: 577 | ArrayImpl: lxs of the shape (dim_x, N). 578 | ArrayImpl: Hxxs of the shape (dim_x, dim_x, N). 579 | ArrayImpl: lus of the shape (dim_u, N). 580 | ArrayImpl: Huus of the shape (dim_u, dim_u, N). 
581 | """ 582 | # Creates cost gradient functions. 583 | lx_state_fn = jacfwd(self.state_cost_stage, argnums=0) 584 | lx_velbd_fn = jacfwd(self.vel_bound_cost_stage, argnums=0) 585 | lx_proxi_fn = jacfwd(self.proximity_cost_stage, argnums=0) 586 | lu_cntrl_fn = jacfwd(self.control_cost_stage, argnums=0) 587 | lx_termi_fn = jacfwd(self.state_cost_terminal, argnums=0) 588 | 589 | # Creates cost Hessian functions. 590 | Hxx_state_fn = hessian(self.state_cost_stage, argnums=0) 591 | Hxx_velbd_fn = hessian(self.vel_bound_cost_stage, argnums=0) 592 | Hxx_proxi_fn = hessian(self.proximity_cost_stage, argnums=0) 593 | Huu_cntrl_fn = hessian(self.control_cost_stage, argnums=0) 594 | Hxx_termi_fn = hessian(self.state_cost_terminal, argnums=0) 595 | 596 | # vmap all gradients and Hessians. 597 | lx_state_vmap = vmap(lx_state_fn, in_axes=(1, 1), out_axes=(1)) 598 | lx_velbd_vmap = vmap(lx_velbd_fn, in_axes=(1), out_axes=(1)) 599 | lx_proxi_vmap = vmap(lx_proxi_fn, in_axes=(1, 1), out_axes=(1)) 600 | lu_cntrl_vmap = vmap(lu_cntrl_fn, in_axes=(1), out_axes=(1)) 601 | 602 | Hxx_state_vmap = vmap(Hxx_state_fn, in_axes=(1, 1), out_axes=(2)) 603 | Hxx_velbd_vmap = vmap(Hxx_velbd_fn, in_axes=(1), out_axes=(2)) 604 | Hxx_proxi_vmap = vmap(Hxx_proxi_fn, in_axes=(1, 1), out_axes=(2)) 605 | Huu_cntrl_vmap = vmap(Huu_cntrl_fn, in_axes=(1), out_axes=(2)) 606 | 607 | # Evaluates all cost gradients and Hessians. 
608 | lx_state = lx_state_vmap(states, target) 609 | lx_velbd = lx_velbd_vmap(states) 610 | lx_proxi = lx_proxi_vmap(states, leader_trajs) 611 | lu_cntrl = lu_cntrl_vmap(controls) 612 | lx_termi = lx_termi_fn(states[:, -1], target[:, -1]) 613 | 614 | Hxx_state = Hxx_state_vmap(states, target) 615 | Hxx_velbd = Hxx_velbd_vmap(states) 616 | Hxx_proxi = Hxx_proxi_vmap(states, leader_trajs) 617 | Huu_cntrl = Huu_cntrl_vmap(controls) 618 | Hxx_termi = Hxx_termi_fn(states[:, -1], target[:, -1]) 619 | 620 | lxs = lx_state + lx_velbd + lx_proxi 621 | lus = lu_cntrl 622 | Hxxs = Hxx_state + Hxx_velbd + Hxx_proxi 623 | Huus = Huu_cntrl 624 | 625 | lxs = lxs.at[:, -1].set(lxs[:, -1] + lx_termi) 626 | Hxxs = Hxxs.at[:, :, -1].set(Hxxs[:, :, -1] + Hxx_termi) 627 | 628 | return lxs, Hxxs, lus, Huus 629 | 630 | @partial(jit, static_argnames="self") 631 | def state_cost_stage(self, state: ArrayImpl, target: ArrayImpl) -> ArrayImpl: 632 | """ 633 | Computes the stage state cost. 634 | 635 | Args: 636 | state (ArrayImpl): (4,) [x, y, vx, vy] 637 | target (ArrayImpl): (4,) [x, y, vx, vy] 638 | 639 | Returns: 640 | ArrayImpl: cost (scalar) 641 | """ 642 | return (state - target).T @ self.W_state @ (state-target) 643 | 644 | @partial(jit, static_argnames="self") 645 | def state_cost_terminal(self, state: ArrayImpl, target: ArrayImpl) -> ArrayImpl: 646 | """ 647 | Computes the terminal state cost. 648 | 649 | Args: 650 | state (ArrayImpl): (4,) [x, y, vx, vy] 651 | target (ArrayImpl): (4,) [x, y, vx, vy] 652 | 653 | Returns: 654 | ArrayImpl: cost (scalar) 655 | """ 656 | return (state - target).T @ self.Q_terminal @ (state-target) 657 | 658 | 659 | class CostDubinsCarILQGame(Cost): 660 | 661 | def __init__(self, config: Struct, LMx: np.ndarray, num_players: int): 662 | Cost.__init__(self, config, num_leaders=0) 663 | 664 | # Standard LQ weighting matrices. 
665 | self.W_state = np.diag((config.W_X, config.W_Y, config.W_V, config.W_PSI)) 666 | self.W_control = np.diag((config.W_ACCEL, config.W_DELTA)) 667 | self.W_terminal = np.diag((config.W_X_T, config.W_Y_T, config.W_V_T, config.W_PSI_T)) 668 | 669 | # Soft constraint parameters. 670 | self.q1_v = config.Q1_V 671 | self.q2_v = config.Q2_V 672 | self.q1_prox = config.Q1_PROX 673 | self.q2_prox = config.Q2_PROX 674 | self.barrier_thr = config.BARRIER_THR 675 | self.prox_sep = config.PROX_SEP 676 | self.v_min = config.V_MIN 677 | self.v_max = config.V_MAX 678 | 679 | # Lifting matrix 680 | self.LMx = LMx 681 | 682 | # Problem dimensions. 683 | self.dim_xi = LMx.shape[0] 684 | self.dim_x = LMx.shape[1] 685 | self.num_players = num_players 686 | 687 | @partial(jit, static_argnames="self") 688 | def get_cost(self, states: ArrayImpl, controls: ArrayImpl, target: ArrayImpl) -> ArrayImpl: 689 | """ 690 | Calculates the cost given planned states and controls. 691 | 692 | Args: 693 | states (ArrayImpl): (dim_x, N) planned trajectory. 694 | controls (ArrayImpl): (dim_u, N) planned control sequence. 695 | target (ArrayImpl): (dim_xi,) target state. 696 | 697 | Returns: 698 | float: total cost. 699 | """ 700 | # vmap all costs. 701 | c_state_vmap = vmap(self.state_cost_stage, in_axes=(1, None), out_axes=(0)) 702 | c_cntrl_vmap = vmap(self.control_cost_stage, in_axes=(1), out_axes=(0)) 703 | c_velbd_vmap = vmap(self.vel_bound_cost_stage, in_axes=(1), out_axes=(0)) 704 | c_proxi_vmap = vmap(self.proximity_cost_stage, in_axes=(1), out_axes=(0)) 705 | 706 | # Evaluates all cost terms. 707 | c_state = c_state_vmap(states, target) 708 | c_cntrl = c_cntrl_vmap(controls) 709 | c_velbd = c_velbd_vmap(states) 710 | c_proxi = c_proxi_vmap(states) 711 | c_termi = self.state_cost_terminal(states[:, -1], target) 712 | 713 | # Sums up all cost terms. 
714 | J = jnp.sum(c_state + c_cntrl + c_velbd + c_proxi) + c_termi 715 | 716 | return J 717 | 718 | @partial(jit, static_argnames="self") 719 | def get_derivatives(self, states: ArrayImpl, controls: ArrayImpl, 720 | target: ArrayImpl) -> Tuple[ArrayImpl, ArrayImpl, ArrayImpl, ArrayImpl]: 721 | """ 722 | Calculates gradients and Hessian of the overall cost using Jax. 723 | 724 | Args: 725 | states (ArrayImpl): (dim_x, N) planned trajectory. 726 | controls (ArrayImpl): (dim_u, N) planned control sequence. 727 | target (ArrayImpl): (dim_x,) target state. 728 | 729 | Returns: 730 | ArrayImpl: lxs of the shape (dim_x, N). 731 | ArrayImpl: Hxxs of the shape (dim_x, dim_x, N). 732 | ArrayImpl: lus of the shape (dim_u, N). 733 | ArrayImpl: Huus of the shape (dim_u, dim_u, N). 734 | """ 735 | # Creates cost gradient functions. 736 | lx_state_fn = jacfwd(self.state_cost_stage, argnums=0) 737 | lx_velbd_fn = jacfwd(self.vel_bound_cost_stage, argnums=0) 738 | lx_proxi_fn = jacfwd(self.proximity_cost_stage, argnums=0) 739 | lu_cntrl_fn = jacfwd(self.control_cost_stage, argnums=0) 740 | lx_termi_fn = jacfwd(self.state_cost_terminal, argnums=0) 741 | 742 | # Creates cost Hessian functions. 743 | Hxx_state_fn = hessian(self.state_cost_stage, argnums=0) 744 | Hxx_velbd_fn = hessian(self.vel_bound_cost_stage, argnums=0) 745 | Hxx_proxi_fn = hessian(self.proximity_cost_stage, argnums=0) 746 | Huu_cntrl_fn = hessian(self.control_cost_stage, argnums=0) 747 | Hxx_termi_fn = hessian(self.state_cost_terminal, argnums=0) 748 | 749 | # vmap all gradients and Hessians. 
750 | lx_state_vmap = vmap(lx_state_fn, in_axes=(1, None), out_axes=(1)) 751 | lx_velbd_vmap = vmap(lx_velbd_fn, in_axes=(1), out_axes=(1)) 752 | lx_proxi_vmap = vmap(lx_proxi_fn, in_axes=(1), out_axes=(1)) 753 | lu_cntrl_vmap = vmap(lu_cntrl_fn, in_axes=(1), out_axes=(1)) 754 | 755 | Hxx_state_vmap = vmap(Hxx_state_fn, in_axes=(1, None), out_axes=(2)) 756 | Hxx_velbd_vmap = vmap(Hxx_velbd_fn, in_axes=(1), out_axes=(2)) 757 | Hxx_proxi_vmap = vmap(Hxx_proxi_fn, in_axes=(1), out_axes=(2)) 758 | Huu_cntrl_vmap = vmap(Huu_cntrl_fn, in_axes=(1), out_axes=(2)) 759 | 760 | # Evaluates all cost gradients and Hessians. 761 | lx_state = lx_state_vmap(states, target) 762 | lx_velbd = lx_velbd_vmap(states) 763 | lx_proxi = lx_proxi_vmap(states) 764 | lu_cntrl = lu_cntrl_vmap(controls) 765 | lx_termi = lx_termi_fn(states[:, -1], target) 766 | 767 | Hxx_state = Hxx_state_vmap(states, target) 768 | Hxx_velbd = Hxx_velbd_vmap(states) 769 | Hxx_proxi = Hxx_proxi_vmap(states) 770 | Huu_cntrl = Huu_cntrl_vmap(controls) 771 | Hxx_termi = Hxx_termi_fn(states[:, -1], target) 772 | 773 | lxs = lx_state + lx_velbd + lx_proxi 774 | lus = lu_cntrl 775 | Hxxs = Hxx_state + Hxx_velbd + Hxx_proxi 776 | Huus = Huu_cntrl 777 | 778 | lxs = lxs.at[:, -1].set(lxs[:, -1] + lx_termi) 779 | Hxxs = Hxxs.at[:, :, -1].set(Hxxs[:, :, -1] + Hxx_termi) 780 | 781 | return lxs, lus, Hxxs, Huus 782 | 783 | # --------------------------- Running performance cost terms --------------------------- 784 | @partial(jit, static_argnames="self") 785 | def state_cost_stage(self, state: ArrayImpl, target: ArrayImpl) -> ArrayImpl: 786 | """ 787 | Computes the stage state cost. 
788 | 789 | Args: 790 | state (ArrayImpl): (4 * num_players,) [x, y, v, psi] 791 | target (ArrayImpl): (4,) 792 | 793 | Returns: 794 | ArrayImpl: cost (scalar) 795 | """ 796 | return (self.LMx @ state - target).T @ self.W_state @ (self.LMx @ state - target) 797 | 798 | @partial(jit, static_argnames="self") 799 | def control_cost_stage(self, control: ArrayImpl) -> ArrayImpl: 800 | """ 801 | Computes the stage control cost. 802 | 803 | Args: 804 | control (ArrayImpl): (2,) 805 | 806 | Returns: 807 | ArrayImpl: cost (scalar) 808 | """ 809 | return control.T @ self.W_control @ control 810 | 811 | @partial(jit, static_argnames="self") 812 | def state_cost_terminal(self, state: ArrayImpl, target: ArrayImpl) -> ArrayImpl: 813 | """ 814 | Computes the terminal state cost. 815 | HACK: terminal velocity is set to 0. 816 | 817 | Args: 818 | state (ArrayImpl): (4 * num_players,) 819 | target (ArrayImpl): (4,) 820 | 821 | Returns: 822 | ArrayImpl: cost (scalar) 823 | """ 824 | t_state = jnp.array([target[0], target[1], 0., target[3]]) 825 | return (self.LMx @ state - t_state).T @ self.W_terminal @ (self.LMx @ state - t_state) 826 | 827 | # ----------------------- Running soft constraint cost terms ----------------------- 828 | @partial(jit, static_argnames="self") 829 | def vel_bound_cost_stage(self, state: ArrayImpl) -> ArrayImpl: 830 | """ 831 | Calculates the velocity bound soft constraint cost. 
832 | 833 | Args: 834 | state (ArrayImpl): (4 * num_players,) 835 | 836 | Returns: 837 | ArrayImpl: cost (scalar) 838 | """ 839 | state = self.LMx @ state 840 | cons_v_min = self.v_min - state[2] 841 | cons_v_max = state[2] - self.v_max 842 | barrier_v_min = self.q1_v * jnp.exp(jnp.clip(self.q2_v * cons_v_min, None, self.barrier_thr)) 843 | barrier_v_max = self.q1_v * jnp.exp(jnp.clip(self.q2_v * cons_v_max, None, self.barrier_thr)) 844 | return barrier_v_min + barrier_v_max 845 | 846 | @partial(jit, static_argnames="self") 847 | def proximity_cost_stage(self, state: ArrayImpl) -> ArrayImpl: 848 | """ 849 | Calculates the proximity soft constraint cost. 850 | 851 | Args: 852 | state (ArrayImpl): (4 * num_players,) ego state 853 | 854 | Returns: 855 | ArrayImpl: cost (scalar) 856 | """ 857 | 858 | def _looper(ii, cost_vec): 859 | 860 | def true_fn(penalty_prox): 861 | res = self.q1_prox * jnp.exp(jnp.clip(self.q2_prox * penalty_prox, None, self.barrier_thr)) 862 | return res 863 | 864 | def false_fn(penalty_prox): 865 | return 0. 866 | 867 | dx = state_ego[self.px_dim] - state[self.px_dim, ii] 868 | dy = state_ego[self.py_dim] - state[self.py_dim, ii] 869 | sep_sq = dx**2 + dy**2 870 | penalty_prox = -jnp.minimum(jnp.sqrt(sep_sq) - self.prox_sep, 0.0) 871 | pred = sep_sq > 1e-3 872 | cost_vec = cost_vec.at[ii].set(lax.cond(pred, true_fn, false_fn, penalty_prox)) 873 | return cost_vec 874 | 875 | state_ego = self.LMx @ state 876 | state = state.reshape(self.dim_xi, self.num_players) 877 | cost_vec = jnp.zeros((self.num_players,)) 878 | cost_vec = lax.fori_loop(0, self.num_players, _looper, state) 879 | # return jnp.sum(cost_vec) 880 | return jnp.max(cost_vec) 881 | -------------------------------------------------------------------------------- /wpf/STP/dynamics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamics. 3 | 4 | Please contact the author(s) of this library if you have any questions. 
5 | Author: Haimin Hu (haiminh@princeton.edu) 6 | """ 7 | 8 | import numpy as np 9 | from typing import Tuple 10 | from abc import ABC, abstractmethod 11 | 12 | from functools import partial 13 | from jax import jit, jacfwd 14 | from jaxlib.xla_extension import ArrayImpl 15 | import jax.numpy as jnp 16 | import jax 17 | from scipy.signal import cont2discrete as c2d 18 | 19 | from .utils import Struct 20 | 21 | 22 | class Dynamics(ABC): 23 | 24 | def __init__(self, config: Struct): 25 | self.dim_x = config.DIM_X 26 | self.dim_u = config.DIM_U 27 | 28 | self.T = config.T # planning time horizon 29 | self.N = config.N # number of planning steps 30 | self.dt = self.T / (self.N - 1) # time step for each planning step 31 | 32 | # Useful constants. 33 | self.zeros = np.zeros((self.N)) 34 | self.ones = np.ones((self.N)) 35 | 36 | # Computes Jacobian matrices using Jax. 37 | self.jac_f = jit(jacfwd(self.dct_time_dyn, argnums=[0, 1])) 38 | 39 | # Vectorizes Jacobians using Jax. 40 | self.jac_f = jit(jax.vmap(self.jac_f, in_axes=(1, 1), out_axes=(2, 2))) 41 | 42 | @abstractmethod 43 | def dct_time_dyn(self, state: ArrayImpl, control: ArrayImpl) -> ArrayImpl: 44 | """ 45 | Computes the one-step time evolution of the system with the forward Euler method. 46 | Args: 47 | state (ArrayImpl): (nx,) 48 | control (ArrayImpl): (nu,) 49 | 50 | Returns: 51 | ArrayImpl: (nx,) next state. 52 | """ 53 | raise NotImplementedError 54 | 55 | @abstractmethod 56 | def integrate_forward(self, state: ArrayImpl, control: ArrayImpl) -> Tuple[ArrayImpl, ArrayImpl]: 57 | """ 58 | Computes the next state. 
59 | 60 | Args: 61 | state (ArrayImpl): (nx,) 62 | control (ArrayImpl): (nu,) 63 | 64 | Returns: 65 | state_next: (nx,) ArrayImpl 66 | control_next: (nu,) ArrayImpl 67 | """ 68 | raise NotImplementedError 69 | 70 | @partial(jit, static_argnames="self") 71 | def get_AB_matrix(self, nominal_states: ArrayImpl, 72 | nominal_controls: ArrayImpl) -> Tuple[ArrayImpl, ArrayImpl]: 73 | """ 74 | Returns the linearized 'A' and 'B' matrix of the ego vehicle around 75 | nominal states and controls. 76 | 77 | Args: 78 | nominal_states (ArrayImpl): (nx, N) states along the nominal traj. 79 | nominal_controls (ArrayImpl): (nu, N) controls along the traj. 80 | 81 | Returns: 82 | ArrayImpl: the Jacobian of next state w.r.t. the current state. 83 | ArrayImpl: the Jacobian of next state w.r.t. the current control. 84 | """ 85 | A, B = self.jac_f(nominal_states, nominal_controls) 86 | return A, B 87 | 88 | 89 | class DubinsCar(Dynamics): 90 | 91 | def __init__(self, config: Struct): 92 | Dynamics.__init__(self, config) 93 | self.delta_min = config.DELTA_MIN # min turn rate (rad/s) 94 | self.delta_max = config.DELTA_MAX # max turn rate (rad/s) 95 | self.a_min = config.A_MIN # min longitudial accel 96 | self.a_max = config.A_MAX # max longitudial accel 97 | self.v_min = config.V_MIN # min velocity 98 | self.v_max = config.V_MAX # max velocity 99 | 100 | @partial(jit, static_argnames="self") 101 | def dct_time_dyn(self, state: ArrayImpl, control: ArrayImpl) -> ArrayImpl: 102 | """ 103 | Computes the one-step time evolution of the system with the forward Euler method. 104 | Dynamics: 105 | \dot{x} = v cos(psi) 106 | \dot{y} = v sin(psi) 107 | \dot{v} = a 108 | \dot{psi} = delta 109 | 110 | Args: 111 | 0 1 2 3 112 | state (ArrayImpl): [x, y, v, psi] 113 | control (ArrayImpl): [a, delta] 114 | 115 | Returns: 116 | ArrayImpl: (4,) next state. 
117 | """ 118 | d_x = state[2] * jnp.cos(state[3]) 119 | d_y = state[2] * jnp.sin(state[3]) 120 | d_v = control[0] 121 | d_psi = control[1] 122 | state_next = state + jnp.hstack((d_x, d_y, d_v, d_psi)) * self.dt 123 | return state_next 124 | 125 | @partial(jit, static_argnames="self") 126 | def integrate_forward(self, state: ArrayImpl, control: ArrayImpl) -> Tuple[ArrayImpl, ArrayImpl]: 127 | """ 128 | Computes the next state. 129 | 130 | Args: 131 | state (ArrayImpl): (4,) jnp array [x, y, v, psi]. 132 | control (ArrayImpl): (2,) jnp array [a, delta]. 133 | 134 | Returns: 135 | state_next: ArrayImpl 136 | control_next: ArrayImpl 137 | """ 138 | # Clips the control values with their limits. 139 | accel = jnp.clip(control[0], self.a_min, self.a_max) 140 | delta = jnp.clip(control[1], self.delta_min, self.delta_max) 141 | 142 | # Integrates the system one-step forward in time using the Euler method. 143 | control_clip = jnp.hstack((accel, delta)) 144 | state_next = self.dct_time_dyn(state, control_clip) 145 | 146 | return state_next, control_clip 147 | 148 | @partial(jit, static_argnames="self") 149 | def integrate_forward_norev(self, state: ArrayImpl, 150 | control: ArrayImpl) -> Tuple[ArrayImpl, ArrayImpl]: 151 | """ 152 | Computes the next state. The velocity is clipped so that the car cannot back up. 153 | 154 | Args: 155 | state (ArrayImpl): (4,) jnp array [x, y, v, psi]. 156 | control (ArrayImpl): (2,) jnp array [a, delta]. 157 | 158 | Returns: 159 | state_next: ArrayImpl 160 | control_next: ArrayImpl 161 | """ 162 | # Clips the control values with their limits. 163 | accel = jnp.clip(control[0], self.a_min, self.a_max) 164 | delta = jnp.clip(control[1], self.delta_min, self.delta_max) 165 | 166 | # Integrates the system one-step forward in time using the Euler method. 
167 | control_clip = jnp.hstack((accel, delta)) 168 | state_next = self.dct_time_dyn(state, control_clip) 169 | state_next = state_next.at[2].set(jnp.maximum(state_next[2], 0.0)) # car cannot back up. 170 | 171 | return state_next, control_clip 172 | 173 | 174 | class DoubleIntegrator(Dynamics): 175 | 176 | def __init__(self, config: Struct): 177 | Dynamics.__init__(self, config) 178 | self.ax_min = config.AX_MIN # min x accel 179 | self.ax_max = config.AX_MAX # max x accel 180 | self.ay_min = config.AY_MIN # min y accel 181 | self.ay_max = config.AY_MAX # max y accel 182 | self.Ac = np.array(([0., 0., 1., 0.], [0., 0., 0., 1.], [0., 0., 0., 0.], [0., 0., 0., 0.])) 183 | self.Bc = np.array(([0., 0.], [0., 0.], [1., 0.], [0., 1.])) 184 | _Ad, _Bd, _, _, _ = c2d(system=(self.Ac, self.Bc, None, None), dt=self.dt, method='zoh') 185 | self.Ad, self.Bd = jnp.asarray(_Ad), jnp.asarray(_Bd) 186 | 187 | @partial(jit, static_argnames="self") 188 | def dct_time_dyn(self, state: ArrayImpl, control: ArrayImpl) -> ArrayImpl: 189 | """ 190 | Computes the one-step time evolution of the system with the forward Euler method. 191 | Dynamics: 192 | \dot{x} = vx 193 | \dot{y} = vy 194 | \dot{vx} = ax 195 | \dot{vy} = ay 196 | 197 | Args: 198 | 0 1 2 3 199 | state (ArrayImpl): [x, y, vx, vy] 200 | control (ArrayImpl): [ax, ay] 201 | 202 | Returns: 203 | ArrayImpl: (4,) next state. 204 | """ 205 | state_next = self.Ad @ state + self.Bd @ control 206 | return state_next 207 | 208 | @partial(jit, static_argnames="self") 209 | def integrate_forward(self, state: ArrayImpl, control: ArrayImpl) -> Tuple[ArrayImpl, ArrayImpl]: 210 | """ 211 | Computes the next state. 212 | 213 | Args: 214 | state (ArrayImpl): (4,) jnp array [x, y, vx, vy] 215 | control (ArrayImpl): (2,) jnp array [ax, ay] 216 | 217 | Returns: 218 | state_next: ArrayImpl 219 | control_next: ArrayImpl 220 | """ 221 | # Clips the control values with their limits. 
222 | ax = jnp.clip(control[0], self.ax_min, self.ax_max) 223 | ay = jnp.clip(control[1], self.ay_min, self.ay_max) 224 | 225 | # Integrates the system one-step forward in time using the Euler method. 226 | control_clip = jnp.hstack((ax, ay)) 227 | state_next = self.dct_time_dyn(state, control_clip) 228 | 229 | return state_next, control_clip 230 | -------------------------------------------------------------------------------- /wpf/STP/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ..bnp import Node 3 | from .stp import * 4 | import itertools 5 | 6 | 7 | def stp_initializer(instance_data, nodes, stats, settings, mode="init"): 8 | """ 9 | Initializes the BnP for the STP algorithm. 10 | In particular, it sets some custom settings, statistics, and node data 11 | 12 | mode = "init" or "update" 13 | """ 14 | if mode == "init": 15 | x_cur, init_control, targets, config = instance_data 16 | settings.custom_settings = { 17 | "solver": STP(config), 18 | "RHC": config.RHC, 19 | "init_parent": config.INIT_WITH_PARENT, 20 | "jax_compile_threshold": 0.5, 21 | "x_root": x_cur, 22 | "control_root": init_control, 23 | "targets_root": targets, 24 | "enable_custom_heuristic": False, 25 | } 26 | elif mode == "update": 27 | x_cur, init_control, targets, Js_prev, config = instance_data 28 | settings.custom_settings["x_root"] = x_cur 29 | settings.custom_settings["control_root"] = init_control 30 | settings.custom_settings["targets_root"] = targets 31 | settings.custom_settings["Js_prev"] = Js_prev 32 | settings.custom_settings["enable_custom_heuristic"] = True 33 | stats.custom_statistics = {"jax_compile_time": [], "Js_prev": []} 34 | nodes[0].custom_data = {"collision_matrix": None} 35 | 36 | 37 | def stp_solve_cb(data, node: Node, settings, stats): 38 | """ 39 | This method solves the local node of the STP 40 | """ 41 | # First, try to warmstart with the parent node's solution 42 | if (node.solution is 
None) or (not settings.custom_settings["init_parent"]): 43 | ctrls_ws = settings.custom_settings["control_root"] 44 | else: 45 | ctrls_ws = node.solution[1] 46 | 47 | # Solves STP. 48 | if settings.custom_settings["RHC"]: 49 | states, controls, Js, ts = settings.custom_settings["solver"].solve_rhc( 50 | settings.custom_settings["x_root"], 51 | ctrls_ws, 52 | settings.custom_settings["targets_root"], 53 | node.permutation, 54 | ) 55 | else: 56 | states, controls, Js, ts = settings.custom_settings["solver"].solve( 57 | settings.custom_settings["x_root"], 58 | ctrls_ws, 59 | settings.custom_settings["targets_root"], 60 | node.permutation, 61 | ) 62 | 63 | _total_time = sum(ts) 64 | if _total_time < settings.custom_settings["jax_compile_threshold"]: 65 | stats.custom_statistics["jax_compile_time"].append(_total_time) 66 | 67 | node.custom_data = { 68 | "collision_matrix": 69 | np.array( 70 | settings.custom_settings["solver"].pairwise_collision_check( 71 | np.array(np.stack(states, axis=2)) 72 | ) 73 | ) 74 | } 75 | 76 | node.solution = (states, controls, ts) 77 | node.lb = max(np.sum(Js), node.lb) 78 | node.Js = Js 79 | if None not in node.permutation: 80 | node.ub = node.lb 81 | 82 | 83 | def stp_branching(nodes, node, n, stats, settings): 84 | """ 85 | Defines the custom branching strategy for STP. 
86 | Note that this method performs an in-place operation on nodes 87 | :param nodes: List of nodes 88 | :param node: Incumbent node 89 | :param n: Number of players 90 | """ 91 | undefined_locations = np.where(node.permutation == None)[0] 92 | if len(undefined_locations) > 0: 93 | location = undefined_locations[0] 94 | undefined = [] 95 | 96 | for player in range(n): 97 | if player not in node.permutation: 98 | undefined.append(player) 99 | 100 | no_collisions = True 101 | for player1 in undefined: 102 | for player2 in undefined: 103 | if player1 != player2: 104 | if node.custom_data["collision_matrix"][player1, player2]: 105 | no_collisions = False 106 | break 107 | 108 | # If no collisions, arbitrary order is fine 109 | if no_collisions: 110 | # This is a feasible solution 111 | undefined_completions = itertools.permutations(undefined) 112 | 113 | for index, completion in enumerate(undefined_completions): 114 | child_permutation = np.copy(node.permutation) 115 | # Fill remaining spots with undefined players 116 | undefined_locations = np.where(node.permutation == None)[0] 117 | for i in range(len(undefined)): 118 | child_permutation[undefined_locations[i]] = completion[i] 119 | if index == 0: 120 | if abs(node.lb) != np.inf: 121 | # Feasible solution 122 | stats.num_feasible_solutions += 1 123 | node.ub = node.lb 124 | if stats.global_ub > node.ub: 125 | stats.incumbent_permutation = child_permutation 126 | stats.incumbent = node.solution 127 | stats.custom_statistics["Js_prev"] = node.Js 128 | # Update upper bound 129 | stats.global_ub = node.ub 130 | if settings.verbose: 131 | print( 132 | "\tNo collisions. Found feasible solution. Pruned symmetrics:", 133 | len(list(undefined_completions)) - 1, 134 | ) 135 | else: 136 | nodes.append( 137 | Node( 138 | child_permutation, 139 | node.depth + 1, 140 | node.solution, 141 | node.lb, 142 | -np.inf, 143 | node.custom_data, 144 | ) 145 | ) 146 | if settings.verbose: 147 | print( 148 | "\tNo collisions. 
Pruned symmetrics:", 149 | len(list(undefined_completions)) - 1, 150 | ) 151 | else: 152 | stats.explored_permutations.append(child_permutation.tolist()) 153 | # We are done 154 | return 155 | 156 | else: 157 | # We have to branch on all possible permutations 158 | skipped = 0 159 | for index, player in enumerate(undefined): 160 | child_permutation = np.copy(node.permutation) 161 | child_permutation[location] = player 162 | 163 | undefined_locations = np.where(child_permutation == None)[0] 164 | 165 | # If only one is missing 166 | if len(undefined_locations) == 1: 167 | child_permutation[undefined_locations[0]] = undefined[(index+1) % len(undefined)] 168 | child_permutation_list = child_permutation.tolist() 169 | if child_permutation.tolist() not in stats.explored_permutations: 170 | nodes.append( 171 | Node( 172 | child_permutation, 173 | node.depth + 1, 174 | node.solution, 175 | node.lb, 176 | -np.inf, 177 | node.custom_data, 178 | ) 179 | ) 180 | # Handle symmetry 181 | # For players that do not collide 182 | # print("\tstarting parent", node.permutation) 183 | # print("\tstarting child", child_permutation_list) 184 | for player1, player2 in zip(*np.nonzero(node.custom_data["collision_matrix"] - 1)): 185 | if player1 != player2: 186 | # Check for congtiguos players 187 | l_1 = np.where(child_permutation == player1)[0] 188 | l_2 = np.where(child_permutation == player2)[0] 189 | # If the two players are in the child permutation 190 | if len(l_1) == 1 and len(l_2) == 1: 191 | l_1 = l_1[0] 192 | l_2 = l_2[0] 193 | # If their location is adjacent and were undefined in parent 194 | if ( 195 | abs(l_1 - l_2) == 1 and l_1 in undefined_locations and l_2 in undefined_locations 196 | ): 197 | # print("removed something") 198 | child_copy = child_permutation_list.copy() 199 | child_copy[l_1] = player2 200 | child_copy[l_2] = player1 201 | skipped += 1 202 | # print("\t\tplayers", player1, player2) 203 | # print("\t\tskipping", child_copy) 204 | 
stats.explored_permutations.append(child_copy) 205 | if skipped > 0: 206 | print("\tPruned symmetric nodes:", skipped) 207 | -------------------------------------------------------------------------------- /wpf/STP/ilq_game.py: -------------------------------------------------------------------------------- 1 | """ 2 | Jaxified differentiable iterative LQ Game solver that computes a locally approximate Nash 3 | equilibrium solution to a general-sum trajectory game. 4 | 5 | Please contact the author(s) of this library if you have any questions. 6 | Author: Haimin Hu (haiminh@princeton.edu) 7 | """ 8 | 9 | import time 10 | import numpy as np 11 | from typing import List, Tuple 12 | from functools import partial 13 | 14 | from jax import jit, lax, vmap 15 | from jaxlib.xla_extension import ArrayImpl 16 | import jax.numpy as jnp 17 | 18 | from .utils import Struct 19 | from .cost import CostDubinsCarILQGame 20 | from .multiplayer_dynamical_system import MultiPlayerDynamicalSystem 21 | 22 | 23 | class ILQGame(object): 24 | 25 | def __init__(self, config: Struct, dynamics: MultiPlayerDynamicalSystem, verbose: str = False): 26 | """ 27 | Initializer. 28 | """ 29 | 30 | self.config = config 31 | self.horizon = config.N 32 | self.max_iter = config.MAX_ITER 33 | self.dynamics = dynamics 34 | self.num_players = dynamics._num_players 35 | # self.line_search_scaling = np.linspace(config.LS_MAX, config.LS_MIN, config.LS_NUM) 36 | self.line_search_scaling = 1.1**(-np.arange(10)**2) 37 | 38 | self.dim_x = self.dynamics._x_dim 39 | self.dim_x_ss = self.dynamics._subsystem.dim_x 40 | self.dim_u_ss = self.dynamics._subsystem.dim_u 41 | 42 | # Create costs for each player. 
43 | self.costs = [ 44 | CostDubinsCarILQGame(config, self.dynamics._LMx[:, :, ii], self.num_players) 45 | for ii in range(self.num_players) 46 | ] 47 | 48 | self.verbose = verbose 49 | self.reset() 50 | 51 | def solve( 52 | self, cur_state: List[np.ndarray], us_warmstart: List[np.ndarray], targets: List[np.ndarray] 53 | ): 54 | """ 55 | Runs the iLQGame algorithm. 56 | 57 | Args: 58 | cur_state (List[np.ndarray]): [(nx,)] Current state. 59 | us_warmstart (List[np.ndarray]): [(nui, N)] Warmstart controls. 60 | 61 | Returns: 62 | states: List[np.ndarray] 63 | controls: List[np.ndarray] 64 | t_process: float 65 | status: int 66 | """ 67 | status = 0 68 | time0 = time.time() 69 | cur_state = np.concatenate(cur_state) 70 | us_warmstart = np.stack(us_warmstart, axis=2) 71 | targets = np.stack(targets, axis=1) 72 | 73 | # Initial forward pass. 74 | x0, us = jnp.asarray(cur_state), jnp.asarray(us_warmstart) 75 | xs, cost_init = self.initial_forward_pass(x0, us, targets) 76 | cost_best = cost_init 77 | 78 | self.reset(_current_x=xs, _current_u=us, _current_J=cost_init) 79 | 80 | # Main loop. 81 | for iter in range(self.max_iter): 82 | t_start = time.time() 83 | 84 | # region: Forward & backward passes. 85 | As, Bs = self.linearize_dynamics(xs, us) 86 | lxs, lus, Hxxs, Huus = self.quadraticize_costs(xs, us, targets) 87 | Ps, alphas_bpass, _, _ = self.backward_pass(As, Bs, lxs, lus, Hxxs, Huus) 88 | for line_search_scaling in self.line_search_scaling: 89 | alphas_ls = alphas_bpass * line_search_scaling 90 | xs_ls, us_ls, cost_ls = self.compute_operating_point( 91 | xs, us, Ps, alphas_ls, cur_state, targets 92 | ) 93 | if cost_ls < cost_best: 94 | xs = xs_ls 95 | us = us_ls 96 | cost_best = cost_ls 97 | break 98 | # print("iter", iter, " | cost_best: ", cost_best) 99 | # endregion 100 | 101 | # region: Updates operating points. 
102 | self.last_operating_point = self.current_operating_point 103 | self.current_operating_point = (xs, us) 104 | 105 | self.last_social_cost = self.current_social_cost 106 | self.current_social_cost = cost_best 107 | 108 | if self.current_social_cost < self.best_social_cost: 109 | self.best_operating_point = self.current_operating_point 110 | self.best_social_cost = self.current_social_cost 111 | # endregion 112 | 113 | # region: Checks convergence. 114 | if self.is_converged_cost(): 115 | status = 1 116 | if self.verbose: 117 | print( 118 | "[iLQGame] Social cost (", round(self.current_social_cost, 2), ") has converged! \n" 119 | ) 120 | break 121 | # endregion 122 | 123 | t_iter = time.time() - t_start 124 | if self.verbose: 125 | print( 126 | "[iLQGame] Iteration", iter, "| Social cost: ", round(self.current_social_cost, 2), 127 | " | Iter. time: ", t_iter 128 | ) 129 | 130 | t_process = time.time() - time0 131 | 132 | xs = np.asarray(xs).reshape(self.dim_x_ss, self.horizon, self.num_players) 133 | xs = [xs[:, :, ii] for ii in range(self.num_players)] 134 | us = np.asarray(us) 135 | us = [us[:, :, ii] for ii in range(self.num_players)] 136 | 137 | return xs, us, t_process, status 138 | 139 | def is_converged_cost(self): 140 | """ 141 | Checks convergence based on social cost difference. 142 | """ 143 | TOLERANCE_RATE = 0.005 144 | # COST_LB = 1e6 145 | 146 | if self.last_social_cost is None: 147 | return False 148 | 149 | cost_diff_rate = np.abs( 150 | (self.current_social_cost - self.last_social_cost) / self.last_social_cost 151 | ) 152 | 153 | if cost_diff_rate > TOLERANCE_RATE: #or self.current_social_cost > COST_LB: 154 | return False 155 | else: 156 | return True 157 | 158 | def reset(self, _current_x=None, _current_u=None, _current_J=None): 159 | """ 160 | Resets the solver and warmstarts it if possible. 
161 | """ 162 | 163 | if _current_x is None: 164 | _current_x = jnp.zeros((self.dim_x, self.horizon)) 165 | 166 | if _current_u is None: 167 | _current_u = jnp.zeros((self.dim_u_ss, self.horizon, self.num_players)) 168 | 169 | self.last_operating_point = None 170 | self.current_operating_point = (_current_x, _current_u) 171 | self.best_operating_point = (_current_x, _current_u) 172 | 173 | self.last_social_cost = np.Inf 174 | self.current_social_cost = _current_J 175 | self.best_social_cost = _current_J 176 | 177 | def initial_forward_pass(self, cur_state: ArrayImpl, controls: ArrayImpl, targets: ArrayImpl): 178 | """ 179 | Performs the initial forward pass given warmstart controls. 180 | 181 | Args: 182 | cur_state (ArrayImpl) (nx,) 183 | controls (ArrayImpl) (nui, N, num_players) 184 | targets: (ArrayImpl) (nxi, num_players) 185 | 186 | Returns: 187 | states (ArrayImpl): states (nx, N) 188 | social_cost (float): sum of all players costs 189 | """ 190 | # Forward simulation. 191 | states = jnp.zeros((self.dim_x, self.horizon)) 192 | states = states.at[:, 0].set(cur_state) 193 | for k in range(1, self.horizon): 194 | states_next, _ = self.dynamics.integrate_forward(states[:, k - 1], controls[:, k - 1, :]) 195 | states = states.at[:, k].set(states_next) 196 | 197 | # Evaluates costs. 198 | cost_sum = self._evaluate_costs(states, controls, targets) 199 | 200 | return states, cost_sum 201 | 202 | def compute_operating_point( 203 | self, current_xs: ArrayImpl, current_us: ArrayImpl, Ps: ArrayImpl, alphas: ArrayImpl, 204 | cur_state: ArrayImpl, targets: ArrayImpl 205 | ) -> Tuple[ArrayImpl, ArrayImpl, float, ArrayImpl, ArrayImpl, ArrayImpl]: 206 | """ 207 | Computes current operating point by propagating through dynamics. 
208 | This function is a wrapper of _compute_operating_point_jax() 209 | 210 | Args: 211 | current_xs (ArrayImpl): (nx, N) current state traj, used as nominal 212 | current_us (ArrayImpl): (nui, N, num_players) current player controls, used as nominal 213 | Ps (ArrayImpl): (nui, nx, N, num_players) 214 | alphas (ArrayImpl): (nui, N, num_players) 215 | cur_state (ArrayImpl): (nx,) current (initial) state 216 | targets (ArrayImpl): (nxi, num_players) 217 | 218 | Returns: 219 | xs (ArrayImpl): updated states (nx, N) 220 | us (ArrayImpl): updated player controls (nui, N, num_players) 221 | social_cost (float): sum of all players costs 222 | """ 223 | # Computes track info. 224 | xs, us = self._compute_operating_point_jax(current_xs, current_us, Ps, alphas, cur_state) 225 | 226 | # Evaluates costs. 227 | cost_sum = self._evaluate_costs(xs, us, targets) 228 | 229 | return xs, us, cost_sum 230 | 231 | @partial(jit, static_argnames='self') 232 | def _evaluate_costs(self, xs, us, targets): 233 | costs = jnp.zeros((self.num_players)) 234 | for ii in range(self.num_players): 235 | costs = costs.at[ii].set(self.costs[ii].get_cost(xs, us[:, :, ii], targets[:, ii])) 236 | return jnp.sum(costs) 237 | 238 | @partial(jit, static_argnames='self') 239 | def _compute_operating_point_jax( 240 | self, nominal_states: ArrayImpl, nominal_controls: ArrayImpl, Ps: ArrayImpl, 241 | alphas: ArrayImpl, cur_state: ArrayImpl 242 | ) -> Tuple[ArrayImpl, ArrayImpl]: 243 | """ 244 | Computes current operating point by propagating through dynamics. 245 | 246 | Args: 247 | nominal_states (ArrayImpl): (nx, N) 248 | nominal_controls (ArrayImpl): (nui, N, num_players) 249 | Ps (ArrayImpl): (nui, nx, N, num_players) 250 | alphas (ArrayImpl): (nui, N, num_players) 251 | cur_state (ArrayImpl): (nx,) current init. 
state 252 | 253 | Returns: 254 | xs (ArrayImpl): updated states (nx, N) 255 | us (ArrayImpl): updated player controls (nui, N, num_players) 256 | """ 257 | 258 | def forward_pass_looper(k, _carry): 259 | 260 | def compute_agent_control(x, x_ref, uii_ref, Pii, alphaii): 261 | return uii_ref - Pii @ (x-x_ref) - alphaii 262 | 263 | compute_all_agents_controls = vmap( 264 | compute_agent_control, in_axes=(None, None, 1, 2, 1), out_axes=(1) 265 | ) 266 | 267 | xs, us = _carry 268 | us_tmp = compute_all_agents_controls( 269 | xs[:, k], nominal_states[:, k], nominal_controls[:, k, :], Ps[:, :, k, :], alphas[:, k, :] 270 | ) 271 | X_next, U_next = self.dynamics.integrate_forward(xs[:, k], us_tmp) 272 | xs = xs.at[:, k + 1].set(X_next) 273 | us = us.at[:, k, :].set(U_next) 274 | return xs, us 275 | 276 | xs = jnp.zeros_like(nominal_states) 277 | us = jnp.zeros_like(nominal_controls) 278 | xs = xs.at[:, 0].set(cur_state) 279 | xs, us = lax.fori_loop(0, self.horizon - 1, forward_pass_looper, (xs, us)) 280 | return xs, us 281 | 282 | @partial(jit, static_argnames='self') 283 | def linearize_dynamics(self, xs: ArrayImpl, us: ArrayImpl) -> Tuple[ArrayImpl, ArrayImpl]: 284 | """ 285 | Linearizes dynamics at the current operating point. 
286 | 287 | Args: 288 | xs (ArrayImpl): (nx, N) nominal state traj 289 | us (ArrayImpl): (nui, N, num_players) nominal player controls 290 | 291 | Returns: 292 | As (ArrayImpl): (nx, nx, N) A matrices 293 | Bs (ArrayImpl): (nx, nui, N, num_players) B matrices 294 | """ 295 | 296 | def linearize_single_time(x, u): 297 | A, B = self.dynamics.linearize_discrete_jitted(x, u) 298 | return A, B 299 | 300 | linearize_along_horizon = vmap(linearize_single_time, in_axes=(1, 1), out_axes=(2, 2)) 301 | As, Bs = linearize_along_horizon(xs, us) 302 | 303 | return As, Bs 304 | 305 | @partial(jit, static_argnames='self') 306 | def quadraticize_costs( 307 | self, 308 | xs: ArrayImpl, 309 | us: ArrayImpl, 310 | targets: ArrayImpl, 311 | ) -> Tuple[ArrayImpl, ArrayImpl, ArrayImpl, ArrayImpl]: 312 | """ 313 | Quadraticizes costs of all players at the current operating point. 314 | 315 | Args: 316 | xs (ArrayImpl): (nx, N) nominal state trajectory 317 | us (ArrayImpl): (nui, N, num_players) nominal player controls 318 | targets (ArrayImpl): (nxi, num_players) 319 | 320 | Returns: 321 | lxs (ArrayImpl): (nx, N, num_players) gradients lx = dc/dx of all playes 322 | lus (ArrayImpl): (nui, N, num_players) gradients lu = dc/du of all playes 323 | Hxxs (ArrayImpl): (nx, nx, N, num_players) Hessians Hxx of all playes 324 | Huus (ArrayImpl): (nui, nui, N, num_players) Hessians Huu of all playes 325 | """ 326 | lxs = jnp.zeros((self.dim_x, self.horizon, self.num_players)) 327 | lus = jnp.zeros((self.dim_u_ss, self.horizon, self.num_players)) 328 | Hxxs = jnp.zeros((self.dim_x, self.dim_x, self.horizon, self.num_players)) 329 | Huus = jnp.zeros((self.dim_u_ss, self.dim_u_ss, self.horizon, self.num_players)) 330 | 331 | for ii in range(self.num_players): 332 | lxs_ii, lus_ii, Hxxs_ii, Huus_ii = self.costs[ii].get_derivatives( 333 | xs, us[:, :, ii], targets[:, ii] 334 | ) 335 | lxs = lxs.at[:, :, ii].set(lxs_ii) 336 | lus = lus.at[:, :, ii].set(lus_ii) 337 | Hxxs = Hxxs.at[:, :, :, 
ii].set(Hxxs_ii) 338 | Huus = Huus.at[:, :, :, ii].set(Huus_ii) 339 | return lxs, lus, Hxxs, Huus 340 | 341 | @partial(jit, static_argnames='self') 342 | def backward_pass( 343 | self, As: ArrayImpl, Bs: ArrayImpl, lxs: ArrayImpl, lus: ArrayImpl, Hxxs: ArrayImpl, 344 | Huus: ArrayImpl 345 | ) -> Tuple[ArrayImpl, ArrayImpl, ArrayImpl, ArrayImpl]: 346 | """ 347 | Solves a time-varying, finite horizon LQ game (finds closed-loop Nash 348 | feedback strategies for both players). 349 | Assumes that dynamics are given by 350 | ``` dx_{k+1} = A_k dx_k + \sum_i Bs[i]_k du[i]_k ``` 351 | 352 | Derivation can be found in: 353 | https://github.com/HJReachability/ilqgames/blob/master/derivations/feedback_lq_nash.pdf 354 | 355 | Args: 356 | As (ArrayImpl): (nx, nx, N) A matrices 357 | Bs (ArrayImpl): (nui, nui, N, num_players) B matrices 358 | lxs (ArrayImpl): (nx, N, num_players) gradients lx = dc/dx of all playes 359 | lus (ArrayImpl): (nui, N, num_players) gradients lu = dc/du of all playes 360 | Hxxs (ArrayImpl): (nx, nx, N, num_players) Hessians Hxx of all playes 361 | Huus (ArrayImpl): (nui, nui, N, num_players) Hessians Huu of all playes 362 | 363 | Returns: 364 | ArrayImpl: Ps (dim_u_ss, dim_x, N-1, num_players) 365 | ArrayImpl: alphas (dim_u_ss, N-1, num_players) 366 | """ 367 | 368 | @jit 369 | def backward_pass_looper(k, _carry): 370 | Ps, alphas, Z, zeta = _carry 371 | n = horizon - 1 - k 372 | 373 | # Computes Ps given previously computed Z. 
374 | S = jnp.array(()).reshape(0, sum(self.dynamics._u_dims)) 375 | Y1 = jnp.array(()).reshape(0, dim_x) 376 | for ii in range(num_players): 377 | Sii = jnp.array(()).reshape(dim_u_ss, 0) 378 | for jj in range(num_players): 379 | if jj == ii: 380 | Sii = jnp.hstack( 381 | (Sii, Bs[:, :, n, ii].T @ Z[:, :, ii] @ Bs[:, :, n, jj] + Huus[:, :, n, ii]) 382 | ) 383 | else: 384 | Sii = jnp.hstack((Sii, Bs[:, :, n, ii].T @ Z[:, :, ii] @ Bs[:, :, n, jj])) 385 | S = jnp.vstack((S, Sii)) 386 | 387 | Y1ii = Bs[:, :, n, ii].T @ Z[:, :, ii] @ As[:, :, n] 388 | Y1 = jnp.vstack((Y1, Y1ii)) 389 | 390 | P, _, _, _ = jnp.linalg.lstsq(a=S, b=Y1, rcond=None) 391 | # Sinv = jnp.linalg.pinv(S) 392 | # P = Sinv @ Y1 393 | 394 | for ii in range(num_players): 395 | Pii = self.dynamics._LMu[:, :, ii] @ P 396 | Ps = Ps.at[:, :, n, ii].set(Pii) 397 | 398 | # Computes F_k = A_k - B1_k P1_k - B2_k P2_k -... 399 | F = As[:, :, n] 400 | for ii in range(num_players): 401 | F -= Bs[:, :, n, ii] @ Ps[:, :, n, ii] 402 | 403 | # Computes alphas using previously computed zetas. 404 | Y2 = jnp.array(()).reshape(0, 1) 405 | for ii in range(num_players): 406 | # Y2ii = (Bs[:, :, n, ii].T @ zeta[:, ii]).reshape((dim_u_ss, 1)) 407 | Y2ii = (Bs[:, :, n, ii].T @ zeta[:, ii] + lus[:, n, ii]).reshape((dim_u_ss, 1)) 408 | Y2 = jnp.vstack((Y2, Y2ii)) 409 | 410 | alpha, _, _, _ = jnp.linalg.lstsq(a=S, b=Y2, rcond=None) 411 | # alpha = Sinv @ Y2 412 | 413 | for ii in range(num_players): 414 | alphaii = self.dynamics._LMu[:, :, ii] @ alpha 415 | alphas = alphas.at[:, n, ii].set(alphaii[:, 0]) 416 | 417 | # Computes beta_k = -B1_k alpha1 - B2_k alpha2_k -... 418 | beta = 0. 419 | for ii in range(num_players): 420 | beta -= Bs[:, :, n, ii] @ alphas[:, n, ii] 421 | 422 | # Updates zeta. 
423 | for ii in range(num_players): 424 | _FZb = F.T @ (zeta[:, ii] + Z[:, :, ii] @ beta) 425 | _PRa = Ps[:, :, n, ii].T @ Huus[:, :, n, ii] @ alphas[:, n, ii] 426 | zeta = zeta.at[:, ii].set(_FZb + _PRa + lxs[:, n, ii]) 427 | 428 | # Updates Z. 429 | for ii in range(num_players): 430 | _FZF = F.T @ Z[:, :, ii] @ F 431 | _PRP = Ps[:, :, n, ii].T @ Huus[:, :, n, ii] @ Ps[:, :, n, ii] 432 | Z = Z.at[:, :, ii].set(_FZF + _PRP + Hxxs[:, :, n, ii]) 433 | 434 | return Ps, alphas, Z, zeta 435 | 436 | # Unpacks horizon and number of players. 437 | horizon = self.horizon 438 | num_players = self.num_players 439 | 440 | # Caches dimensions of state and controls for each player. 441 | dim_x = self.dim_x 442 | dim_u_ss = self.dim_u_ss 443 | 444 | # Recursively computes all intermediate and final variables. 445 | Z = Hxxs[:, :, -1, :] 446 | zeta = lxs[:, -1, :] 447 | 448 | # Initializes strategy matrices. 449 | Ps = jnp.zeros((dim_u_ss, dim_x, horizon, num_players)) 450 | alphas = jnp.zeros((dim_u_ss, horizon, num_players)) 451 | 452 | # Backward pass. 453 | Ps, alphas, Z, zeta = lax.fori_loop( 454 | 0, self.horizon, backward_pass_looper, (Ps, alphas, Z, zeta) 455 | ) 456 | 457 | return Ps, alphas, Z, zeta 458 | -------------------------------------------------------------------------------- /wpf/STP/ilqr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Jaxified iterative Linear Quadrative Regulator (iLQR). 3 | 4 | Please contact the author(s) of this library if you have any questions. 
5 | Author: Haimin Hu (haiminh@princeton.edu) 6 | """ 7 | 8 | import time 9 | import numpy as np 10 | from typing import Tuple 11 | 12 | from .dynamics import * 13 | from .cost import * 14 | 15 | from functools import partial 16 | from jax import jit, lax 17 | from jaxlib.xla_extension import ArrayImpl 18 | import jax.numpy as jnp 19 | 20 | 21 | class iLQR: 22 | 23 | def __init__(self, config, num_leaders: float = 0): 24 | self.horizon = config.N 25 | self.max_iter = config.MAX_ITER 26 | self.dim_x = config.DIM_X 27 | self.dim_u = config.DIM_U 28 | self.tol = 1e-2 29 | self.lambad_init = 10.0 30 | self.lambad_min = 1e-3 31 | self.alphas = 1.1**(-np.arange(10)**2) 32 | 33 | self.dynamics = globals()[config.DYNAMICS](config) 34 | self.cost = globals()[config.COST](config, num_leaders) 35 | 36 | def solve( 37 | self, 38 | cur_state: ArrayImpl, 39 | controls: ArrayImpl, 40 | leader_trajs: ArrayImpl, 41 | target: ArrayImpl, 42 | ) -> Tuple[np.ndarray, np.ndarray, float, float, int]: 43 | """ 44 | Solves the iLQR-STP problem. 45 | 46 | Args: 47 | cur_state (ArrayImpl): (dim_x,) 48 | controls (ArrayImpl): (self.dim_u, N - 1) 49 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs 50 | target (ArrayImpl): (dim_x,) target state 51 | 52 | Returns: 53 | states: np.ndarray 54 | controls: np.ndarray 55 | J: float 56 | t_process: float 57 | status: int 58 | """ 59 | status = 0 60 | time0 = time.time() 61 | 62 | self.lambad = self.lambad_init 63 | 64 | # Initial forward pass. 65 | if controls is None: 66 | controls = jnp.zeros((self.dim_u, self.horizon)) 67 | states = jnp.zeros((self.dim_x, self.horizon)) 68 | states = states.at[:, 0].set(cur_state) 69 | 70 | for i in range(1, self.horizon): 71 | _states_i, _ = self.dynamics.integrate_forward(states[:, i - 1], controls[:, i - 1]) 72 | states = states.at[:, i].set(_states_i) 73 | J = self.cost.get_cost(states, controls, leader_trajs, target) 74 | 75 | converged = False 76 | 77 | # Main loop. 
78 | for i in range(self.max_iter): 79 | # Backward pass. 80 | Ks, ks = self.backward_pass(states, controls, leader_trajs, target) 81 | 82 | # Linesearch 83 | updated = False 84 | for alpha in self.alphas: 85 | X_new, U_new, J_new = self.forward_pass( 86 | states, controls, leader_trajs, Ks, ks, alpha, target 87 | ) 88 | if J_new <= J: 89 | if jnp.abs((J-J_new) / J) < self.tol: 90 | converged = True 91 | J = J_new 92 | states = X_new 93 | controls = U_new 94 | updated = True 95 | break 96 | if updated: 97 | self.lambad *= 0.7 98 | else: 99 | status = 2 100 | break 101 | self.lambad = max(self.lambad_min, self.lambad) 102 | 103 | if converged: 104 | status = 1 105 | break 106 | t_process = time.time() - time0 107 | 108 | # print(t_process) 109 | 110 | return np.asarray(states), np.asarray(controls), J, t_process, status 111 | 112 | # ----------------------------- Jitted functions ----------------------------- 113 | @partial(jit, static_argnames="self") 114 | def forward_pass( 115 | self, 116 | nominal_states: ArrayImpl, 117 | nominal_controls: ArrayImpl, 118 | leader_trajs: ArrayImpl, 119 | Ks: ArrayImpl, 120 | ks: ArrayImpl, 121 | alpha: float, 122 | target: ArrayImpl, 123 | ) -> Tuple[ArrayImpl, ArrayImpl, float]: 124 | """ 125 | Jitted forward pass looped computation. 
126 | 127 | Args: 128 | nominal_states (ArrayImpl): (dim_x, N) 129 | nominal_controls (ArrayImpl): (dim_u, N) 130 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs 131 | Ks (ArrayImpl): gain matrices (dim_u, dim_x, N - 1) 132 | ks (ArrayImpl): gain vectors (dim_u, N - 1) 133 | alpha (float): scalar parameter 134 | target (ArrayImpl): (dim_x,) target state 135 | 136 | Returns: 137 | Xs (ArrayImpl): (dim_x, N) 138 | Us (ArrayImpl): (dim_u, N) 139 | J (float): total cost 140 | """ 141 | 142 | @jit 143 | def forward_pass_looper(i, _carry): 144 | Xs, Us = _carry 145 | u = ( 146 | nominal_controls[:, i] + alpha * ks[:, i] 147 | + Ks[:, :, i] @ (Xs[:, i] - nominal_states[:, i]) 148 | ) 149 | X_next, U_next = self.dynamics.integrate_forward(Xs[:, i], u) 150 | Xs = Xs.at[:, i + 1].set(X_next) 151 | Us = Us.at[:, i].set(U_next) 152 | return Xs, Us 153 | 154 | # Computes trajectories. 155 | Xs = jnp.zeros((self.dim_x, self.horizon)) 156 | Us = jnp.zeros((self.dim_u, self.horizon)) 157 | Xs = Xs.at[:, 0].set(nominal_states[:, 0]) 158 | Xs, Us = lax.fori_loop(0, self.horizon - 1, forward_pass_looper, (Xs, Us)) 159 | 160 | # Computes the total cost. 161 | J = self.cost.get_cost(Xs, Us, leader_trajs, target) 162 | 163 | return Xs, Us, J 164 | 165 | @partial(jit, static_argnames="self") 166 | def backward_pass( 167 | self, 168 | nominal_states: ArrayImpl, 169 | nominal_controls: ArrayImpl, 170 | leader_trajs: ArrayImpl, 171 | target: ArrayImpl, 172 | ) -> Tuple[ArrayImpl, ArrayImpl]: 173 | """ 174 | Jitted backward pass looped computation. 
175 | 176 | Args: 177 | nominal_states (ArrayImpl): (dim_x, N) 178 | nominal_controls (ArrayImpl): (dim_u, N) 179 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs 180 | target (ArrayImpl): (dim_x,) target state 181 | 182 | Returns: 183 | Ks (ArrayImpl): gain matrices (dim_u, dim_x, N - 1) 184 | ks (ArrayImpl): gain vectors (dim_u, N - 1) 185 | """ 186 | 187 | @jit 188 | def backward_pass_looper(i, _carry): 189 | V_x, V_xx, ks, Ks = _carry 190 | n = self.horizon - 2 - i 191 | 192 | Q_x = L_x[:, n] + fx[:, :, n].T @ V_x 193 | Q_u = L_u[:, n] + fu[:, :, n].T @ V_x 194 | Q_xx = L_xx[:, :, n] + fx[:, :, n].T @ V_xx @ fx[:, :, n] 195 | Q_ux = fu[:, :, n].T @ V_xx @ fx[:, :, n] 196 | Q_uu = L_uu[:, :, n] + fu[:, :, n].T @ V_xx @ fu[:, :, n] 197 | 198 | Q_uu_inv = jnp.linalg.inv(Q_uu + reg_mat) 199 | 200 | Ks = Ks.at[:, :, n].set(-Q_uu_inv @ Q_ux) 201 | ks = ks.at[:, n].set(-Q_uu_inv @ Q_u) 202 | 203 | V_x = Q_x - Ks[:, :, n].T @ Q_uu @ ks[:, n] 204 | V_xx = Q_xx - Ks[:, :, n].T @ Q_uu @ Ks[:, :, n] 205 | 206 | return V_x, V_xx, ks, Ks 207 | 208 | # Computes cost derivatives. 209 | L_x, L_xx, L_u, L_uu = self.cost.get_derivatives( 210 | nominal_states, nominal_controls, leader_trajs, target 211 | ) 212 | 213 | # Computes dynamics Jacobians. 214 | fx, fu = self.dynamics.get_AB_matrix(nominal_states, nominal_controls) 215 | 216 | # Computes the control policy. 
217 | Ks = jnp.zeros((self.dim_u, self.dim_x, self.horizon - 1)) 218 | ks = jnp.zeros((self.dim_u, self.horizon - 1)) 219 | V_x = L_x[:, -1] 220 | V_xx = L_xx[:, :, -1] 221 | reg_mat = self.lambad * jnp.eye(self.dim_u) 222 | 223 | V_x, V_xx, ks, Ks = lax.fori_loop( 224 | 0, self.horizon - 1, backward_pass_looper, (V_x, V_xx, ks, Ks) 225 | ) 226 | return Ks, ks 227 | 228 | @partial(jit, static_argnames="self") 229 | def compute_cost( 230 | self, 231 | nominal_states: ArrayImpl, 232 | nominal_controls: ArrayImpl, 233 | leader_trajs: ArrayImpl, 234 | target: ArrayImpl, 235 | ) -> float: 236 | """ 237 | Computes accumulated cost along a trajectory. 238 | 239 | Args: 240 | nominal_states (ArrayImpl): (dim_x, N) 241 | nominal_controls (ArrayImpl): (dim_u, N) 242 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs 243 | target (ArrayImpl): (dim_x,) target state 244 | 245 | Returns: 246 | J (float): total cost 247 | """ 248 | return self.cost.get_cost(nominal_states, nominal_controls, leader_trajs, target) 249 | 250 | 251 | class LQR_shielded: 252 | 253 | def __init__(self, config, num_leaders: float = 0): 254 | self.horizon = config.N 255 | self.dim_x = config.DIM_X 256 | self.dim_u = config.DIM_U 257 | self.config = config 258 | self.dynamics = globals()[config.DYNAMICS](config) 259 | self.cost = globals()[config.COST](config, num_leaders) 260 | self.dt = self.dynamics.dt 261 | 262 | def solve( 263 | self, cur_state: ArrayImpl, ctrl_ws: ArrayImpl, leader_trajs: ArrayImpl, target: ArrayImpl 264 | ) -> Tuple[np.ndarray, np.ndarray, float, float, int]: 265 | time0 = time.time() 266 | states, controls, J, status = self.solve_jax(cur_state, leader_trajs, target) 267 | t_process = time.time() - time0 268 | return np.asarray(states), np.asarray(controls), float(J), t_process, status 269 | 270 | @partial(jit, static_argnames="self") 271 | def solve_jax(self, cur_state: ArrayImpl, leader_trajs: ArrayImpl, 272 | target: ArrayImpl) -> 
Tuple[np.ndarray, np.ndarray, float, float, int]: 273 | """ 274 | Solves the shielded LQR problem. 275 | 276 | Args: 277 | cur_state (ArrayImpl): (dim_x) 278 | leader_trajs (ArrayImpl): (dim_x, N, N_leader) all leaders' optimized state trajs 279 | target (ArrayImpl): (dim_x, N) target states 280 | 281 | Returns: 282 | states: np.ndarray 283 | controls: np.ndarray 284 | J: float 285 | t_process: float 286 | status: int 287 | """ 288 | 289 | @jit 290 | def _looper(i, _carry): 291 | 292 | def true_fn(state, control): 293 | return jnp.array((-state[0] / self.dt, -state[1] / self.dt)) # Shielding action 294 | 295 | def false_fn(state, control): 296 | return control 297 | 298 | def _check_two_agents(state1, state2): 299 | _dx = self.config.PX_DIM 300 | _dy = self.config.PY_DIM 301 | pxpy_diff = jnp.stack((state1[_dx] - state2[_dx], state1[_dy] - state2[_dy])) 302 | sep = jnp.linalg.norm(pxpy_diff, axis=0) 303 | return sep < self.config.PROX_SEP_CHECK 304 | 305 | states, controls = _carry 306 | 307 | lqr_ctrl = -K_lqr @ (states[:, i] - target[:, i]) 308 | _state_lqr, _ = self.dynamics.integrate_forward(states[:, i], lqr_ctrl) 309 | _check_two_agents_vmap = vmap(_check_two_agents, in_axes=(None, 1), out_axes=(0)) 310 | pred = jnp.any(_check_two_agents_vmap(_state_lqr, leader_trajs[:, i, :])) 311 | shielded_ctrl = lax.cond(pred, true_fn, false_fn, states[:, i], lqr_ctrl) 312 | _state_nxt, _ctrl_nxt = self.dynamics.integrate_forward(states[:, i], shielded_ctrl) 313 | states = states.at[:, i + 1].set(_state_nxt) 314 | controls = controls.at[:, i].set(_ctrl_nxt) 315 | return states, controls 316 | 317 | status = 0 318 | K_lqr = self.cost.K_lqr 319 | states = jnp.zeros((self.dim_x, self.horizon)) 320 | states = states.at[:, 0].set(cur_state) 321 | controls = jnp.zeros((self.dim_u, self.horizon)) 322 | 323 | states, controls = lax.fori_loop(0, self.horizon, _looper, (states, controls)) 324 | 325 | J = self.cost.get_cost(states, controls, leader_trajs, target) 326 | 327 | return 
states, controls, J, status 328 | -------------------------------------------------------------------------------- /wpf/STP/multiplayer_dynamical_system.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multiplayer dynamical systems. 3 | 4 | Please contact the author(s) of this library if you have any questions. 5 | Author: Haimin Hu (haiminh@princeton.edu) 6 | """ 7 | from typing import List, Tuple 8 | from abc import abstractmethod 9 | 10 | from functools import partial 11 | from jax import jit, jacfwd, vmap 12 | from jaxlib.xla_extension import ArrayImpl 13 | import jax.numpy as jnp 14 | 15 | from .ilqr import Dynamics 16 | 17 | 18 | class MultiPlayerDynamicalSystem(object): 19 | """ 20 | Base class for all multiplayer continuous-time dynamical systems. Supports 21 | numrical integration and linearization. 22 | """ 23 | 24 | def __init__(self, x_dim, u_dims, T=None): 25 | """ 26 | Initialize with number of state/control dimensions. 27 | 28 | Args: 29 | x_dim (int): number of state dimensions 30 | u_dims ([int]): liset of number of control dimensions for each player 31 | T (float): time interval 32 | """ 33 | self._x_dim = x_dim 34 | self._u_dims = u_dims 35 | self._T = T 36 | self._num_players = len(u_dims) 37 | 38 | @abstractmethod 39 | def cont_time_dyn(self, x: ArrayImpl, u_list: list, k: int = 0) -> list: 40 | """ 41 | Computes the time derivative of state for a particular state/control. 42 | 43 | Args: 44 | x (ArrayImpl): joint state (nx,) 45 | u_list (list of ArrayImpl): list of controls [(nu_0,), (nu_1,), ...] 46 | 47 | Returns: 48 | list of ArrayImpl: list of next states [(nx_0,), (nx_1,), ...] 49 | """ 50 | raise NotImplementedError 51 | 52 | @partial(jit, static_argnames='self') 53 | def disc_time_dyn(self, x: ArrayImpl, us: ArrayImpl, k: int = 0) -> ArrayImpl: 54 | """ 55 | Computes the one-step evolution of the system in discrete time with Euler integration. 
56 | 57 | Args: 58 | x (ArrayImpl): joint state (nx,) 59 | us (ArrayImpl): agent controls (nui, N_sys) 60 | 61 | Returns: 62 | ArrayImpl: next state (nx,) 63 | """ 64 | x_dot = self.cont_time_dyn(x, us, k) 65 | return x + self._T * x_dot 66 | 67 | @abstractmethod 68 | def linearize_discrete_jitted(self, x: ArrayImpl, us: ArrayImpl, 69 | k: int = 0) -> Tuple[ArrayImpl, ArrayImpl]: 70 | """ 71 | Compute the Jacobian linearization of the dynamics for a particular state x and control us. 72 | Outputs A and B matrices of a discrete-time linear system. 73 | 74 | Args: 75 | x (ArrayImpl): joint state (nx,) 76 | us (ArrayImpl): agent controls (nui, N_sys) 77 | 78 | Returns: 79 | ArrayImpl: the Jacobian of next state w.r.t. x (nx, nx) 80 | ArrayImpl: the Jacobians of next state w.r.t. us (nx, nui, N_sys) 81 | """ 82 | raise NotImplementedError 83 | 84 | 85 | class ProductMultiPlayerDynamicalSystem(MultiPlayerDynamicalSystem): 86 | 87 | def __init__(self, subsystem_list: List[Dynamics], T: float = 0.1): 88 | """ 89 | Implements a multiplayer dynamical system who's dynamics decompose into a Cartesian product of 90 | single-player dynamical systems. 91 | 92 | Initialize with a list of dynamical systems. 93 | 94 | NOTE: 95 | - Assumes that all subsystems have the same state and control dimension. 96 | 97 | Args: 98 | subsystem_list ([DynamicalSystem]): single-player dynamical system 99 | T (float): discretization time interval 100 | """ 101 | self._N_sys = len(subsystem_list) 102 | self._subsystem = subsystem_list[0] 103 | self._subsystems = subsystem_list 104 | 105 | self._x_dims = [subsys.dim_x for subsys in subsystem_list] 106 | self._x_dim = sum(self._x_dims) 107 | self._u_dims = [subsys.dim_u for subsys in subsystem_list] 108 | self._u_dim = sum(self._u_dims) 109 | 110 | super(ProductMultiPlayerDynamicalSystem, self).__init__(self._x_dim, self._u_dims, T) 111 | 112 | self.update_lifting_matrices() 113 | 114 | # Pre-computes Jacobian matrices. 
115 | self.jac_f = jit(jacfwd(self.disc_time_dyn, argnums=[0, 1])) 116 | 117 | def update_lifting_matrices(self): 118 | """ 119 | Updates the lifting matrices. 120 | """ 121 | # Creates lifting matrices LMx_i for subsystem i such that LMx_i @ x = xi. 122 | _split_index = jnp.hstack((0, jnp.cumsum(jnp.asarray(self._x_dims)))) 123 | self._LMx = jnp.zeros((self._subsystem.dim_x, self._x_dim, self._N_sys)) 124 | _id_mat = jnp.eye(self._subsystem.dim_x) 125 | for i in range(self._N_sys): 126 | self._LMx = self._LMx.at[:, _split_index[i]:_split_index[i + 1], i].set(_id_mat) 127 | 128 | # Creates lifting matrices LMu_i for subsystem i such that LMu_i @ u = ui. 129 | _split_index = jnp.hstack((0, jnp.cumsum(jnp.asarray(self._u_dims)))) 130 | self._LMu = jnp.zeros((self._subsystem.dim_u, self._u_dim, self._N_sys)) 131 | _id_mat = jnp.eye(self._subsystem.dim_u) 132 | for i in range(self._N_sys): 133 | self._LMu = self._LMu.at[:, _split_index[i]:_split_index[i + 1], i].set(_id_mat) 134 | 135 | @partial(jit, static_argnames='self') 136 | def split_joint_state(self, x: ArrayImpl) -> ArrayImpl: 137 | """ 138 | Splits the joint state. 139 | 140 | Args: 141 | x (ArrayImpl): joint state (nx,) 142 | 143 | Returns: 144 | ArrayImpl: states (nxi, N_sys) 145 | """ 146 | _split = lambda LMx, x: LMx @ x 147 | 148 | _split_vmap = vmap(_split, in_axes=(2, None), out_axes=(1)) 149 | return _split_vmap(self._LMx, x) 150 | 151 | @partial(jit, static_argnames='self') 152 | def disc_time_dyn(self, x: ArrayImpl, us: ArrayImpl, k: int = 0) -> ArrayImpl: 153 | """ 154 | Computes the one-step time evolution of the system. 155 | 156 | Args: 157 | x (ArrayImpl): joint state (nx,) 158 | us (ArrayImpl): agent controls (nui, N_sys) 159 | 160 | Returns: 161 | ArrayImpl: next joint state (nx,) 162 | """ 163 | xs = self.split_joint_state(x) 164 | 165 | # vmap over subsystems. 
166 | _disc_time_dyn_vmap = vmap(self._subsystem.dct_time_dyn, in_axes=(1, 1), out_axes=(1)) 167 | xs_next = _disc_time_dyn_vmap(xs, us) 168 | 169 | return xs_next.flatten('F') 170 | 171 | @partial(jit, static_argnames='self') 172 | def integrate_forward(self, x: ArrayImpl, us: ArrayImpl, k: int = 0) -> ArrayImpl: 173 | """ 174 | Computes the one-step time evolution of the system. 175 | 176 | Args: 177 | x (ArrayImpl): joint state (nx,) 178 | us (ArrayImpl): agent controls (nui, N_sys) 179 | 180 | Returns: 181 | ArrayImpl: next joint state (nx,) 182 | ArrayImpl: next clipped controls (nui, num_players) 183 | """ 184 | xs = self.split_joint_state(x) 185 | 186 | # vmap over subsystems. 187 | _integrate_forward_vmap = vmap( 188 | self._subsystem.integrate_forward, in_axes=(1, 1), out_axes=(1, 1) 189 | ) 190 | xs_next, us_next_clipped = _integrate_forward_vmap(xs, us) 191 | 192 | return xs_next.flatten('F'), us_next_clipped 193 | 194 | @partial(jit, static_argnames='self') 195 | def linearize_discrete_jitted(self, x: ArrayImpl, us: ArrayImpl, 196 | k: int = 0) -> Tuple[ArrayImpl, ArrayImpl]: 197 | """ 198 | Compute the Jacobian linearization of the dynamics for a particular state x and control us. 199 | Outputs A and B matrices of a discrete-time linear system. 200 | 201 | Args: 202 | x (ArrayImpl): joint state (nx,) 203 | us (ArrayImpl): agent controls (nui, N_sys) 204 | 205 | Returns: 206 | ArrayImpl: the Jacobian of next state w.r.t. x (nx, nx) 207 | ArrayImpl: the Jacobians of next state w.r.t. us (nx, nui, N_sys) 208 | """ 209 | A_disc, Bs_disc = self.jac_f(x, us, k) 210 | return A_disc, Bs_disc 211 | -------------------------------------------------------------------------------- /wpf/STP/stp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sequential ILQR-based trajectory planning. 3 | 4 | Please contact the author(s) of this library if you have any questions. 
5 | Author: Haimin Hu (haiminh@princeton.edu) 6 | """ 7 | 8 | import numpy as np 9 | import jax.numpy as jnp 10 | from functools import partial 11 | from jax import jit, vmap, lax 12 | from jaxlib.xla_extension import ArrayImpl 13 | from typing import Tuple, List, Dict 14 | from copy import deepcopy 15 | 16 | from .ilqr import iLQR, LQR_shielded 17 | from .utils import Struct 18 | 19 | 20 | class STP: 21 | 22 | def __init__(self, config: Struct): 23 | """ 24 | Initializer. 25 | 26 | Args: 27 | config (Struct): config file 28 | order (List): list of int or None 29 | """ 30 | 31 | self.config = config 32 | self.horizon = config.N 33 | self.max_iter = config.MAX_ITER 34 | self.num_agents = config.N_AGENT 35 | 36 | # Pre-compile all sub-iLQR problems 37 | self.solvers = [] 38 | for iagent in range(self.num_agents): 39 | self.solvers.append(iLQR(config, num_leaders=iagent)) 40 | 41 | def get_init_ctrl_from_LQR( 42 | self, N: float, cur_state: List[np.ndarray], targets: List[np.ndarray] 43 | ) -> Tuple[List[np.ndarray], List[np.ndarray]]: 44 | 45 | @jit 46 | def get_LQR_ctrl(x_init, target): 47 | 48 | def _looper(i, carry): 49 | states, controls = carry 50 | control_LQR = -K_lqr @ (states[:, i] - _target) 51 | state_next, control = self.solvers[0].dynamics.integrate_forward(states[:, i], control_LQR) 52 | return states.at[:, i + 1].set(state_next), controls.at[:, i].set(control) 53 | 54 | states = jnp.zeros((self.config.DIM_X, N)) 55 | states = states.at[:, 0].set(x_init) 56 | controls = jnp.zeros((self.config.DIM_U, N)) 57 | _target = jnp.array([target[0], target[1], 0., 0.]) 58 | states, controls = lax.fori_loop(0, N, _looper, (states, controls)) 59 | return states, controls 60 | 61 | K_lqr = self.solvers[0].cost.K_lqr 62 | state_list, ctrl_list = [], [] 63 | for ii in range(len(cur_state)): 64 | states_ii, controls_ii = get_LQR_ctrl(cur_state[ii], targets[ii]) 65 | state_list.append(states_ii) 66 | ctrl_list.append(controls_ii) 67 | return state_list, ctrl_list 68 | 
69 | def solve( 70 | self, cur_state: List[np.ndarray], ctrls_ws: List[np.ndarray], targets: List[np.ndarray], 71 | order: np.ndarray 72 | ) -> Tuple[List[np.ndarray], List[np.ndarray], List[float], List[float]]: 73 | """ 74 | Solves the iLQR-STP problem. 75 | 76 | Args: 77 | cur_state (List[np.ndarray]): [(dim_x,)] 78 | ctrls_ws (List[np.ndarray]): [(self.dim_u, N - 1)] warmstart control sequences 79 | targets (List[np.ndarray]): List of agents' target states 80 | order (np.ndarray): [id of player who plays 1st, id of player who plays 2nd, ...] 81 | 82 | Returns: 83 | states: List[np.ndarray] 84 | controls: List[np.ndarray] 85 | Js: List[float] 86 | ts: List[float] 87 | """ 88 | _Na = self.num_agents 89 | states, controls, Js, ts = ( 90 | [None] * _Na, 91 | [None] * _Na, 92 | [None] * _Na, 93 | [None] * _Na, 94 | ) 95 | leader_trajs = [] # arranged according to orders but agnostic to agent id 96 | 97 | cur_state = [jnp.asarray(_cur_state) for _cur_state in cur_state] 98 | ctrls_ws = [jnp.asarray(_ctrls_ws) for _ctrls_ws in ctrls_ws] 99 | targets = [jnp.asarray(_target) for _target in targets] 100 | 101 | # Computes trajectories for agents who are assigned order of play. 
102 | assert len(order) == self.num_agents 103 | _assigned_agents = [] 104 | for i_order in range(len(order)): 105 | agent_id = order[i_order] 106 | if (agent_id is None): # assuming order = [agent assigned order | agent not assigned order] 107 | break 108 | else: 109 | _assigned_agents.append(agent_id) 110 | if len(leader_trajs) > 0: 111 | _leader_trajs = np.stack(leader_trajs, axis=2) 112 | else: 113 | _leader_trajs = None 114 | states_tmp, controls_tmp, J_tmp, t_tmp, _ = self.solvers[i_order].solve( 115 | cur_state[agent_id], 116 | ctrls_ws[agent_id], 117 | _leader_trajs, 118 | targets[agent_id], 119 | ) 120 | states[agent_id], controls[agent_id], Js[agent_id], ts[agent_id] = ( 121 | states_tmp, 122 | controls_tmp, 123 | J_tmp, 124 | t_tmp, 125 | ) 126 | leader_trajs.append(states[agent_id]) 127 | 128 | # Computes trajectories for unassigned agents. 129 | if len(leader_trajs) > 0: 130 | _leader_trajs = np.stack(leader_trajs, axis=2) 131 | else: 132 | _leader_trajs = None 133 | 134 | if len(_assigned_agents) < self.num_agents: 135 | _order_unassigned = len(_assigned_agents) 136 | assert _order_unassigned == len(leader_trajs) 137 | for agent_id in range(self.num_agents): 138 | if agent_id not in _assigned_agents: 139 | # print('Unassigned agent detected, id =', agent_id) 140 | 141 | # -> Option A. Unassigned players play last. 142 | if not (hasattr(self.config, "OPTION_B") and self.config.OPTION_B): 143 | states_tmp, controls_tmp, J_tmp, t_tmp, _ = self.solvers[_order_unassigned].solve( 144 | cur_state[agent_id], ctrls_ws[agent_id], _leader_trajs, targets[agent_id] 145 | ) 146 | 147 | # -> Option B. Unassigned players play blindly. 
148 | else: 149 | states_tmp, controls_tmp, J_tmp, t_tmp, _ = self.solvers[0].solve( 150 | cur_state[agent_id], ctrls_ws[agent_id], None, targets[agent_id] 151 | ) 152 | 153 | states[agent_id], controls[agent_id], Js[agent_id], ts[agent_id] = ( 154 | states_tmp, 155 | controls_tmp, 156 | J_tmp, 157 | t_tmp, 158 | ) 159 | 160 | states = [np.asarray(_states) for _states in states] 161 | controls = [np.asarray(_controls) for _controls in controls] 162 | Js = [float(_J) for _J in Js] 163 | 164 | return states, controls, Js, ts 165 | 166 | def solve_rhc( 167 | self, 168 | cur_state: List[np.ndarray], 169 | ctrls_ws: List[np.ndarray], 170 | targets: List[np.ndarray], 171 | order: np.ndarray, 172 | ) -> Tuple[List[np.ndarray], List[np.ndarray], List[float], List[float]]: 173 | """ 174 | Solves the iLQR-STP problem with receding horizon control. 175 | 176 | Args: 177 | cur_state (List[np.ndarray]): [(dim_x,)] 178 | ctrls_ws (List[np.ndarray]): [(self.dim_u, N - 1)] warmstart control sequences 179 | targets (List[np.ndarray]): List of agents' target states 180 | order (np.ndarray): [id of player who plays 1st, id of player who plays 2nd, ...] 
181 | 182 | Returns: 183 | states: List[np.ndarray] 184 | controls: List[np.ndarray] 185 | Js: List[float] 186 | ts: List[float] 187 | """ 188 | 189 | def _update_rhc(states_tmp, controls_tmp, t_tmp, leader_trajs, agent_id, i_order, horizon): 190 | if leader_trajs is not None: 191 | leader_trajs = leader_trajs[:, :_ols] 192 | 193 | if horizon == 0: 194 | states[agent_id], controls[agent_id], ts[agent_id] = ( 195 | states_tmp[:, :_ols], 196 | controls_tmp[:, :_ols], 197 | t_tmp, 198 | ) 199 | Js[agent_id] = self.solvers[i_order].compute_cost( 200 | states_tmp[:, :_ols], 201 | controls_tmp[:, :_ols], 202 | leader_trajs, 203 | targets[agent_id], 204 | ) 205 | 206 | else: 207 | states[agent_id] = jnp.concatenate((states[agent_id], states_tmp[:, 1:_ols + 1]), axis=1) 208 | controls[agent_id] = jnp.concatenate((controls[agent_id], controls_tmp[:, :_ols]), axis=1) 209 | ts[agent_id] += t_tmp 210 | Js[agent_id] += self.solvers[i_order].compute_cost( 211 | states_tmp[:, :_ols], 212 | controls_tmp[:, :_ols], 213 | leader_trajs, 214 | targets[agent_id], 215 | ) 216 | 217 | ctrls_ws[agent_id] = ctrls_ws[agent_id].at[:, :-1].set(controls_tmp[:, 1:]) 218 | 219 | _N = self.config.N 220 | _Na = self.num_agents 221 | _ols = (self.config.OPEN_LOOP_STEP) # Open-loop simulation steps between two RHC cycles. 222 | 223 | states, controls, Js, ts = ( 224 | [None] * _Na, 225 | [None] * _Na, 226 | [None] * _Na, 227 | [None] * _Na, 228 | ) 229 | cur_state = [jnp.asarray(_cur_state) for _cur_state in cur_state] 230 | ctrls_ws = [jnp.asarray(_ctrls_ws[:, :_N]) for _ctrls_ws in ctrls_ws] 231 | targets = [jnp.asarray(_target) for _target in targets] 232 | 233 | horizon = 0 234 | 235 | while horizon < self.config.RHC_STEPS: 236 | leader_trajs = [] # arranged according to orders but agnostic to agent id 237 | 238 | # Computes trajectories for agents who are assigned order of play. 
239 | assert len(order) == self.num_agents 240 | _assigned_agents = [] 241 | for i_order in range(len(order)): 242 | agent_id = order[i_order] 243 | if (agent_id is None): # assuming order = [agent assigned order | agent not assigned order] 244 | break 245 | else: 246 | _assigned_agents.append(agent_id) 247 | if len(leader_trajs) > 0: 248 | _leader_trajs = np.stack(leader_trajs, axis=2) 249 | else: 250 | _leader_trajs = None 251 | states_tmp, controls_tmp, _, t_tmp, _ = self.solvers[i_order].solve( 252 | cur_state[agent_id], 253 | ctrls_ws[agent_id], 254 | _leader_trajs, 255 | targets[agent_id], 256 | ) 257 | _update_rhc( 258 | states_tmp, 259 | controls_tmp, 260 | t_tmp, 261 | _leader_trajs, 262 | agent_id, 263 | i_order, 264 | horizon, 265 | ) 266 | leader_trajs.append(states_tmp) 267 | 268 | # Computes trajectories for unassigned agents. 269 | if len(leader_trajs) > 0: 270 | _leader_trajs = np.stack(leader_trajs, axis=2) 271 | else: 272 | _leader_trajs = None 273 | 274 | if len(_assigned_agents) < self.num_agents: 275 | _order_unassigned = len(_assigned_agents) 276 | assert _order_unassigned == len(leader_trajs) 277 | for agent_id in range(self.num_agents): 278 | if agent_id not in _assigned_agents: 279 | # print('Unassigned agent detected, id =', agent_id) 280 | 281 | # -> Option A. Unassigned players play last. 282 | if not (hasattr(self.config, "OPTION_B") and self.config.OPTION_B): 283 | states_tmp, controls_tmp, _, t_tmp, _ = self.solvers[_order_unassigned].solve( 284 | cur_state[agent_id], ctrls_ws[agent_id], _leader_trajs, targets[agent_id] 285 | ) 286 | _update_rhc( 287 | states_tmp, controls_tmp, t_tmp, _leader_trajs, agent_id, _order_unassigned, 288 | horizon 289 | ) 290 | 291 | # -> Option B. Unassigned players play blindly. 
292 | else: 293 | states_tmp, controls_tmp, _, t_tmp, _ = self.solvers[0].solve( 294 | cur_state[agent_id], ctrls_ws[agent_id], None, targets[agent_id] 295 | ) 296 | _update_rhc(states_tmp, controls_tmp, t_tmp, None, agent_id, 0, horizon) 297 | 298 | # Updates current states. 299 | cur_state = [_state[:, -1] for _state in states] 300 | 301 | horizon += _ols 302 | 303 | # Prepares return values. 304 | states = [np.asarray(_states) for _states in states] 305 | controls = [np.asarray(_controls) for _controls in controls] 306 | Js = [float(_J) for _J in Js] 307 | 308 | return states, controls, Js, ts 309 | 310 | @partial(jit, static_argnames="self") 311 | def pairwise_collision_check(self, states: ArrayImpl): 312 | """ 313 | Checks collisions pairwise for all agents. 314 | 315 | Args: 316 | states (ArrayImpl): (dim_x, N, N_agents) 317 | """ 318 | 319 | def _check_two_agents(state1, state2): 320 | _dx = self.config.PX_DIM 321 | _dy = self.config.PY_DIM 322 | pxpy_diff = jnp.stack((state1[_dx, :] - state2[_dx, :], state1[_dy, :] - state2[_dy, :])) 323 | sep = jnp.linalg.norm(pxpy_diff, axis=0) 324 | return jnp.any(sep < self.config.PROX_SEP_CHECK) 325 | 326 | def _looper(i, col_mat): 327 | _check_two_agents_vmap = vmap(_check_two_agents, in_axes=(None, 2), out_axes=(0)) 328 | _col_i = _check_two_agents_vmap(states[:, :, i], states) 329 | _col_i = _col_i.at[i].set(False) 330 | return col_mat.at[i, :].set(_col_i) 331 | 332 | col_mat = jnp.zeros((self.num_agents, self.num_agents), dtype=bool) 333 | col_mat = lax.fori_loop(0, self.num_agents, _looper, col_mat) 334 | return col_mat 335 | 336 | 337 | class STPMovingTargets(STP): 338 | 339 | def __init__(self, config: Struct): 340 | STP.__init__(self, config) 341 | 342 | self.solvers = [] 343 | for iagent in range(self.num_agents): 344 | self.solvers.append(iLQR(config, num_leaders=iagent + 1)) 345 | 346 | @staticmethod 347 | def generate_circle_positions(r, N): 348 | angles = np.linspace(-np.pi / 2., np.pi / 2., N, 
endpoint=True) 349 | positions_x = r * np.cos(angles) 350 | positions_y = r * np.sin(angles) 351 | return positions_x, positions_y 352 | 353 | def get_init_ctrl_from_LQR( 354 | self, N: float, cur_state: List[np.ndarray], history: Dict, perm: List[int] 355 | ) -> Tuple[List[np.ndarray], List[np.ndarray]]: 356 | 357 | @jit 358 | def get_LQR_ctrl(x_init, target): 359 | 360 | def _looper(i, carry): 361 | states, controls, = carry 362 | control_LQR = -K_lqr @ (states[:, i] - target[:, i]) 363 | state_next, control = self.solvers[0].dynamics.integrate_forward(states[:, i], control_LQR) 364 | return states.at[:, i + 1].set(state_next), controls.at[:, i].set(control) 365 | 366 | states = jnp.zeros((self.config.DIM_X, N)) 367 | states = states.at[:, 0].set(x_init) 368 | controls = jnp.zeros((self.config.DIM_U, N)) 369 | states, controls = lax.fori_loop(0, N, _looper, (states, controls)) 370 | return states, controls 371 | 372 | def set_targets(agent_id): 373 | _targets = deepcopy(history["alpha"]) 374 | _targets[0, :] += tar_xs[agent_id] 375 | _targets[1, :] += tar_ys[agent_id] 376 | return _targets 377 | 378 | tar_xs, tar_ys = self.generate_circle_positions(self.config.TAR_RADIUS, self.num_agents) 379 | K_lqr = self.solvers[0].cost.K_lqr 380 | state_list, ctrl_list = [], [] 381 | for ip in range(len(perm)): 382 | ii = perm[ip] 383 | _targets = set_targets(ii) 384 | # if ip == 0: 385 | # states_ii, controls_ii = get_LQR_ctrl(cur_state[ii], history["alpha"]) 386 | # else: 387 | # states_ii, controls_ii = get_LQR_ctrl(cur_state[ii], history["stp"][ip - 1]) 388 | states_ii, controls_ii = get_LQR_ctrl(cur_state[ii], _targets) 389 | state_list.append(states_ii) 390 | ctrl_list.append(controls_ii) 391 | return state_list, ctrl_list 392 | 393 | @partial(jit, static_argnames="self") 394 | def pairwise_collision_check_alpha(self, states: ArrayImpl): 395 | """ 396 | Checks collisions pairwise for all agents (including the alpha). 
397 | 398 | Args: 399 | states (ArrayImpl): (dim_x, N, N_agents + 1) 400 | """ 401 | 402 | def _check_two_agents(state1, state2): 403 | _dx = self.config.PX_DIM 404 | _dy = self.config.PY_DIM 405 | pxpy_diff = jnp.stack((state1[_dx, :] - state2[_dx, :], state1[_dy, :] - state2[_dy, :])) 406 | sep = jnp.linalg.norm(pxpy_diff, axis=0) 407 | return jnp.any(sep < self.config.PROX_SEP_CHECK) 408 | 409 | def _looper(i, col_mat): 410 | _check_two_agents_vmap = vmap(_check_two_agents, in_axes=(None, 2), out_axes=(0)) 411 | _col_i = _check_two_agents_vmap(states[:, :, i], states) 412 | _col_i = _col_i.at[i].set(False) 413 | return col_mat.at[i, :].set(_col_i) 414 | 415 | col_mat = jnp.zeros((self.num_agents + 1, self.num_agents + 1), dtype=bool) 416 | col_mat = lax.fori_loop(0, self.num_agents + 1, _looper, col_mat) 417 | return col_mat 418 | 419 | def solve( 420 | self, cur_state: List[np.ndarray], ctrls_ws: List[np.ndarray], alpha_traj: np.ndarray, 421 | history: Dict, order: np.ndarray 422 | ) -> Tuple[List[np.ndarray], List[np.ndarray], List[float], List[float]]: 423 | """ 424 | Solves the iLQR-STP problem 425 | 426 | Args: 427 | cur_state (List[np.ndarray]): [(dim_x,)] 428 | ctrls_ws (List[np.ndarray]): [(self.dim_u, N - 1)] warmstart control sequences 429 | alpha_traj (np.ndarray): The alpha's planned trajectory 430 | history (Dict): Agent's historic trajectory (keys: "alpha", "stp") 431 | order (np.ndarray): [id of player who plays 1st, id of player who plays 2nd, ...] 
432 | 433 | Returns: 434 | states: List[np.ndarray] 435 | controls: List[np.ndarray] 436 | Js: List[float] 437 | ts: List[float] 438 | """ 439 | 440 | def shift_by_delay(states, delay): 441 | states = deepcopy(states) 442 | states[:, :-delay] = states[:, delay:] 443 | states[:, -delay:] = states[:, -delay - 1:-delay] 444 | return states 445 | 446 | def set_targets(agent_id): 447 | # if agent_id == 0: 448 | # _targets = history["alpha"] 449 | # else: 450 | # _targets = history["stp"][agent_id - 1] 451 | # _targets = shift_by_delay(_targets, _delay) 452 | _targets = deepcopy(history["alpha"]) 453 | # _targets = deepcopy(alpha_traj) 454 | if _delay > 0: 455 | _targets = shift_by_delay(_targets, _delay) 456 | _targets[0, :] += tar_xs[agent_id] 457 | _targets[1, :] += tar_ys[agent_id] 458 | return _targets 459 | 460 | _Na = self.num_agents 461 | _delay = self.config.DELAY 462 | tar_xs, tar_ys = self.generate_circle_positions(self.config.TAR_RADIUS, _Na) 463 | 464 | states, controls, Js, ts = ([None] * _Na, [None] * _Na, [None] * _Na, [None] * _Na) 465 | leader_trajs = [alpha_traj] # arranged according to orders but agnostic to agent id 466 | 467 | cur_state = [jnp.asarray(_cur_state) for _cur_state in cur_state] 468 | ctrls_ws = [jnp.asarray(_ctrls_ws) for _ctrls_ws in ctrls_ws] 469 | 470 | # Computes trajectories for agents who are assigned order of play. 471 | assert len(order) == self.num_agents 472 | _assigned_agents = [] 473 | for i_order in range(len(order)): 474 | agent_id = order[i_order] 475 | if (agent_id is None): # assuming order = [agent assigned order | agent not assigned order] 476 | break 477 | else: 478 | _assigned_agents.append(agent_id) 479 | 480 | # Sets moving obstacles. 481 | _leader_trajs = np.stack(leader_trajs, axis=2) 482 | 483 | # Sets moving targets. 484 | _targets = set_targets(agent_id) 485 | 486 | # Solves iLQR/LQR. 
487 | states_tmp, controls_tmp, J_tmp, t_tmp, _ = self.solvers[i_order].solve( 488 | cur_state[agent_id], ctrls_ws[agent_id], _leader_trajs, _targets 489 | ) 490 | states[agent_id], controls[agent_id], Js[agent_id], ts[agent_id] = ( 491 | states_tmp, controls_tmp, J_tmp, t_tmp 492 | ) 493 | leader_trajs.append(states[agent_id]) 494 | 495 | # Computes trajectories for unassigned agents 496 | _leader_trajs = np.stack(leader_trajs, axis=2) 497 | 498 | if len(_assigned_agents) < self.num_agents: 499 | _order_unassigned = len(_assigned_agents) 500 | assert _order_unassigned == len(leader_trajs) - 1 # alpha traj is always there 501 | for agent_id in range(self.num_agents): 502 | if agent_id not in _assigned_agents: 503 | # print('Unassigned agent detected, id =', agent_id) 504 | 505 | # Sets moving targets. 506 | _targets = set_targets(agent_id) 507 | 508 | # -> Option A. Unassigned players play last. 509 | if not (hasattr(self.config, "OPTION_B") and self.config.OPTION_B): 510 | states_tmp, controls_tmp, J_tmp, t_tmp, _ = self.solvers[_order_unassigned + 1].solve( 511 | cur_state[agent_id], ctrls_ws[agent_id], _leader_trajs, _targets 512 | ) 513 | 514 | # -> Option B. Unassigned players play blindly. 
515 | else: 516 | states_tmp, controls_tmp, J_tmp, t_tmp, _ = self.solvers[0].solve( 517 | cur_state[agent_id], ctrls_ws[agent_id], None, _targets 518 | ) 519 | 520 | states[agent_id], controls[agent_id], Js[agent_id], ts[agent_id] = ( 521 | states_tmp, controls_tmp, J_tmp, t_tmp 522 | ) 523 | 524 | states = [np.asarray(_states) for _states in states] 525 | controls = [np.asarray(_controls) for _controls in controls] 526 | Js = [float(_J) for _J in Js] 527 | 528 | return states, controls, Js, ts 529 | 530 | def solve_rhc( 531 | self, cur_state: List[np.ndarray], ctrls_ws: List[np.ndarray], targets: List[np.ndarray], 532 | order: np.ndarray 533 | ) -> Tuple[List[np.ndarray], List[np.ndarray], List[float], List[float]]: 534 | return NotImplementedError 535 | -------------------------------------------------------------------------------- /wpf/STP/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions. 3 | 4 | Please contact the author(s) of this library if you have any questions. 5 | Author: Haimin Hu (haiminh@princeton.edu) 6 | """ 7 | 8 | import os 9 | import yaml 10 | import numpy as np 11 | import imageio.v2 as imageio 12 | from typing import List 13 | from IPython.display import Image 14 | from matplotlib import cm 15 | from matplotlib import pyplot as plt 16 | from matplotlib.patches import Circle 17 | from matplotlib.transforms import Affine2D 18 | 19 | plt.rcParams["font.family"] = "serif" 20 | plt.rcParams["font.serif"] = ["Times New Roman"] + plt.rcParams["font.serif"] 21 | 22 | 23 | class ATCZone: 24 | """ 25 | Air traffic control (ATC) zone. 
26 | """ 27 | 28 | def __init__(self, config, targets: List[np.ndarray], oz_planner, tr_planner=None) -> None: 29 | self.center = [config.ATC_X, config.ATC_Y] 30 | self.radius = config.ATC_RADIUS 31 | self.targets = targets 32 | self.target_radius = config.TARGET_RADIUS 33 | self.reach_radius = config.REACH_RADIUS 34 | self.oz_planner = oz_planner # out-of-zone planner 35 | self.tr_planner = tr_planner # target-reaching planner 36 | self.num_agent = config.N_AGENT 37 | self.reach_flags = [False] * self.num_agent 38 | self.zone_flags = [False] * self.num_agent 39 | 40 | def reset(self): 41 | self.reach_flags = [False] * self.num_agent 42 | self.zone_flags = [False] * self.num_agent 43 | 44 | def is_in_zone(self, states: List[np.ndarray]) -> List[bool]: 45 | """ 46 | Checks if agents are in the ATC zone. 47 | """ 48 | return [ 49 | (s[0] - self.center[0])**2 + (s[1] - self.center[1])**2 <= self.radius**2 for s in states 50 | ] 51 | 52 | def is_near_target(self, states: List[np.ndarray]) -> List[bool]: 53 | return [(s[0] - t[0])**2 + (s[1] - t[1])**2 <= self.target_radius**2 54 | for (s, t) in zip(states, self.targets)] 55 | 56 | def is_reach_target(self, states: List[np.ndarray]) -> List[bool]: 57 | _reach_flags = [(s[0] - t[0])**2 + (s[1] - t[1])**2 <= self.reach_radius**2 58 | for (s, t) in zip(states, self.targets)] 59 | self.reach_flags = [_a or _b for _a, _b in zip(_reach_flags, self.reach_flags)] 60 | return self.reach_flags 61 | 62 | def is_collision(self, states: List[np.ndarray]) -> List[bool]: 63 | if len(states[0].shape) == 1: 64 | states = [s[:, np.newaxis] for s in states] 65 | states = np.stack(states, axis=2) 66 | return (self.oz_planner.pairwise_collision_check(states)).any() 67 | 68 | def plan_stp(self, x_cur, init_control, targets, order=None): 69 | """ 70 | Plan STP trajectories. 
71 | """ 72 | if order is None: 73 | order = [None] * len(x_cur) 74 | states, controls, _, _ = self.oz_planner.solve(x_cur, init_control, targets, order) 75 | return states, controls 76 | 77 | def plan_stp_rhc(self, x_cur, init_control, targets, order=None): 78 | """ 79 | Plan STP trajectories with the RHC mode. 80 | """ 81 | if order is None: 82 | order = [None] * len(x_cur) 83 | states, controls, Js, _ = self.oz_planner.solve_rhc(x_cur, init_control, targets, order) 84 | return states, controls, Js 85 | 86 | def check_sep(self, x_cur, thresh): 87 | for ii in range(self.num_agent): 88 | for jj in range(self.num_agent): 89 | if ii != jj: 90 | sep = (x_cur[ii][0] - x_cur[jj][0])**2 + (x_cur[ii][1] - x_cur[jj][1])**2 91 | if sep < thresh**2: 92 | return True 93 | return False 94 | 95 | def check_zone(self, x_cur, x_new, us_cur, init_control): 96 | """ 97 | Checks agent location in zone. 98 | Out-of-zone agents use simple planning. Agents near target use target-reaching policy. 99 | """ 100 | 101 | # Checks out-of-zone. 102 | zone_flags = self.is_in_zone(x_cur) 103 | self.zone_flags = [_a or _b for _a, _b in zip(zone_flags, self.zone_flags)] 104 | if not all(self.zone_flags): 105 | _states, _controls = self.plan_stp(x_cur, init_control, self.targets) 106 | for ii in range(len(x_cur)): 107 | if not self.zone_flags[ii]: 108 | x_new[ii] = _states[ii][:, 1] 109 | us_cur[ii] = _controls[ii] 110 | 111 | # Checks target-reaching. 
112 | target_flags = self.is_near_target(x_cur) 113 | for ii in range(len(x_cur)): 114 | if target_flags[ii]: 115 | _states, _controls, _, _, _ = self.tr_planner.solve( 116 | x_cur[ii], init_control[ii], [], self.targets[ii] 117 | ) 118 | x_new[ii] = _states[:, 1] 119 | us_cur[ii] = _controls 120 | 121 | return x_new, us_cur, zone_flags 122 | 123 | def generate_init_states(self, centers: List[np.ndarray], ranges: np.ndarray, rng, N): 124 | """ 125 | Generates initial conditions 126 | """ 127 | num_agent = len(centers) 128 | nx = centers[0].shape[0] 129 | init_states = [] 130 | for ii in range(num_agent): 131 | _l = ranges[:, np.newaxis] 132 | _init_s = centers[ii][:, np.newaxis] + rng.uniform(low=-_l, high=_l, size=(nx, N)) 133 | init_states.append(_init_s) 134 | return init_states 135 | 136 | 137 | class Struct: 138 | """ 139 | Struct for managing parameters. 140 | """ 141 | 142 | def __init__(self, data) -> None: 143 | for key, value in data.items(): 144 | setattr(self, key, value) 145 | 146 | 147 | def load_config(file_path): 148 | """ 149 | Loads the config file. 150 | 151 | Args: 152 | file_path (string): path to the parameter file. 153 | 154 | Returns: 155 | Struct: parameters. 156 | """ 157 | with open(file_path) as f: 158 | data = yaml.safe_load(f) 159 | config = Struct(data) 160 | return config 161 | 162 | 163 | def plot_receding_horizon( 164 | states, state_hist, k, config, fig_prog_folder, colors, xlim, ylim, figsize=(15, 15), 165 | fontsize=50, plot_pred=False 166 | ): 167 | """ 168 | Plot receding horizon planned trajectory. 
169 | """ 170 | num_agent = len(colors) 171 | plt.figure(figsize=figsize) 172 | ax = plt.gca() 173 | for ii in range(num_agent): 174 | _xii, _yii, _pii = states[ii][0, 0], states[ii][1, 0], states[ii][3, 0] 175 | ax.add_patch(Circle((_xii, _yii), config.WIDTH, color=colors[ii], fill=False)) # footprint 176 | _len = config.WIDTH * 1.2 177 | plt.arrow( 178 | _xii, 179 | _yii, 180 | _len * np.cos(_pii), 181 | _len * np.sin(_pii), 182 | width=0.02, 183 | color=colors[ii], 184 | ) 185 | if plot_pred: 186 | plt.plot(states[ii][0, :], states[ii][1, :], linewidth=2, c="k", linestyle="--") # prediction 187 | plt.scatter( 188 | state_hist[ii][0, :k + 1], state_hist[ii][1, :k + 1], s=80, c=state_hist[ii][2, :k + 1], 189 | cmap=cm.jet, vmin=config.V_MIN, vmax=config.V_MAX, edgecolor="none", marker="o" 190 | ) # trajectory history 191 | # cbar = plt.colorbar(sc) 192 | # if ii < num_agent - 1: 193 | # cbar.remove() 194 | # cbar.set_label(r"velocity [m/s]") 195 | plt.axis("equal") 196 | plt.xlim(xlim) 197 | plt.ylim(ylim) 198 | plt.rcParams.update({"font.size": fontsize}) 199 | # plt.title(str(int(Js[0])) + ' | ' + str(int(Js[1])) + ' | ' + str(int(Js[2]))) 200 | plt.savefig(os.path.join(fig_prog_folder, str(k) + ".png"), dpi=50) 201 | plt.close() 202 | 203 | 204 | def generate_rgb_values(n): 205 | rgb_list = [] 206 | 207 | # Warm colors 208 | for i in range(n // 2): 209 | ratio = i / (n//2 - 1) 210 | rgb = [1.0, ratio, 0.0] # Red to Yellow 211 | rgb_list.append(rgb) 212 | 213 | # Cool colors 214 | for i in range(n // 2): 215 | ratio = i / (n//2 - 1) 216 | rgb = [0.0, 1.0 - ratio, 1.0] # Blue to Violet 217 | rgb_list.append(rgb) 218 | 219 | return rgb_list 220 | 221 | 222 | def plot_trajectory( 223 | states, config, fig_prog_folder, colors, xlim, ylim, figsize=(15, 15), fontsize=50, linewidth=5, 224 | image=None, targets=None, orders=None, plot_arrow=True, zone: ATCZone = None, zone_flags=None 225 | ): 226 | """ 227 | Plots the planned trajectory. 
228 | """ 229 | step = states[0].shape[1] 230 | num_agent = len(colors) 231 | 232 | if orders is not None: 233 | rgb_list = generate_rgb_values(num_agent) 234 | 235 | for k in range(step): 236 | 237 | plt.figure(figsize=figsize) 238 | ax = plt.gca() 239 | 240 | # Plots the ATC Zone. 241 | if zone is not None: 242 | atc_circ = Circle((zone.center[0], zone.center[1]), zone.radius, color=[.7, .7, .7], 243 | fill=False, linestyle="--", linewidth=linewidth) 244 | ax.add_patch(atc_circ) 245 | 246 | for ii in range(num_agent): 247 | 248 | # Plots targets. 249 | if targets is not None: 250 | tar_circ = Circle((targets[ii][0], targets[ii][1]), config.WIDTH, color=colors[ii], 251 | fill=False, linestyle="-", linewidth=linewidth) 252 | ax.add_patch(tar_circ) 253 | 254 | # Plots agent footprints and headings. 255 | _xii, _yii, _pii = states[ii][0, k], states[ii][1, k], states[ii][3, k] 256 | fpt_circ = Circle((_xii, _yii), config.WIDTH, color=colors[ii], fill=False, linestyle="--", 257 | linewidth=linewidth / 2.) 258 | ax.add_patch(fpt_circ) 259 | if plot_arrow: 260 | _len = config.WIDTH * 1.2 261 | plt.arrow( 262 | _xii, _yii, _len * np.cos(_pii), _len * np.sin(_pii), width=0.015, color=colors[ii] 263 | ) 264 | 265 | # Plots trajectory history. 266 | if orders is not None: 267 | for tau in range(k): 268 | if zone_flags is not None and not zone_flags[tau][ii]: # Outside ATC zone 269 | _color = [.7, .7, .7] 270 | else: 271 | try: 272 | _order_tau = orders[tau] 273 | _color = rgb_list[_order_tau.index(ii)] 274 | except: 275 | _color = [.7, .7, .7] 276 | plt.scatter( 277 | states[ii][0, tau], states[ii][1, tau], s=figsize[0] * 7.5, color=_color, 278 | edgecolor="none", marker="o" 279 | ) 280 | else: 281 | plt.scatter( 282 | states[ii][0, :k], states[ii][1, :k], s=figsize[0] * 7.5, c=states[ii][2, :k], 283 | cmap=cm.jet, vmin=config.V_MIN, vmax=config.V_MAX, edgecolor="none", marker="o" 284 | ) 285 | 286 | # Plots agent images. 
287 | if image is not None: 288 | transform_data = ( 289 | Affine2D().rotate_deg_around(*(_xii, _yii), _pii / np.pi * 180) + plt.gca().transData 290 | ) 291 | plt.imshow( 292 | image, transform=transform_data, interpolation="none", origin="lower", extent=[ 293 | _xii - config.WIDTH, _xii + config.WIDTH, _yii - config.WIDTH, _yii + config.WIDTH 294 | ], alpha=1.0, zorder=10.0, clip_on=True 295 | ) 296 | 297 | # Figure setup. 298 | plt.axis("equal") 299 | plt.xlim(xlim) 300 | plt.ylim(ylim) 301 | plt.axis('off') 302 | plt.rcParams.update({"font.size": fontsize}) 303 | plt.savefig(os.path.join(fig_prog_folder, str(k) + ".png"), dpi=50) 304 | plt.close() 305 | 306 | 307 | def make_animation(steps, config, fig_prog_folder, name="rollout.gif"): 308 | gif_path = os.path.join(config.OUT_FOLDER, name) 309 | with imageio.get_writer(gif_path, mode="I", loop=0) as writer: 310 | for j in range(1, steps): 311 | filename = os.path.join(fig_prog_folder, str(j) + ".png") 312 | image = imageio.imread(filename) 313 | writer.append_data(image) 314 | return Image(open(gif_path, "rb").read()) 315 | -------------------------------------------------------------------------------- /wpf/__init__.py: -------------------------------------------------------------------------------- 1 | from .bnp import * 2 | from .io import * 3 | from .STP import * 4 | -------------------------------------------------------------------------------- /wpf/bnp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main functions of Branch and Play (B&P). 3 | 4 | Please contact the author(s) of this library if you have any questions. 
5 | Authors: Gabriele Dragotto (hello@dragotto.net), Haimin Hu (haiminh@princeton.edu) 6 | """ 7 | 8 | import time 9 | import numpy as np 10 | from dataclasses import dataclass 11 | from typing import Any 12 | 13 | from .io import print_log_head, print_iteration 14 | 15 | 16 | @dataclass 17 | class Node(object): 18 | """ 19 | Branch-and-Play node. 20 | """ 21 | 22 | permutation: np.ndarray 23 | depth: int 24 | # Note that the solution is initialized to the one of the parent node 25 | solution: np.ndarray 26 | lb: float 27 | ub: float 28 | custom_data: Any = None # custom data for each node 29 | 30 | 31 | @dataclass 32 | class Settings(object): 33 | """ 34 | Branch-and-Play settings. 35 | """ 36 | 37 | max_nodes: int = 1e20 38 | feas_tolerance: float = 1e-5 39 | min_gap: float = 0 40 | max_iter: int = 1e20 41 | random_heuristic_iterations: int = 1 42 | verbose: bool = True 43 | custom_settings: Any = None # custom settings 44 | 45 | 46 | @dataclass 47 | class Statistics(object): 48 | """ 49 | Branch-and-Play statistics. 50 | """ 51 | 52 | explored_nodes = 0 53 | fathomed_nodes = 0 54 | num_feasible_solutions = 0 55 | time = 0.0 56 | global_ub = np.inf 57 | global_ub_history = np.ndarray 58 | global_lb = np.inf 59 | global_lb_history = np.ndarray 60 | gap_history = np.ndarray 61 | explored_permutations = np.ndarray 62 | incumbent = np.ndarray 63 | incumbent_permutation = np.ndarray 64 | custom_statistics: Any = None # custom statistics 65 | 66 | 67 | class BranchAndPlay(object): 68 | """ 69 | Branch-and-Play tree. 70 | """ 71 | 72 | def __init__( 73 | self, n, instance_data, exploration_cb, solver_cb, fathoming_cb, branching_cb, settings, 74 | initializer=None 75 | ): 76 | """ 77 | Initializes the fields of the Branch and Play. 
78 | :param n: Number of players 79 | :param instance_data: The instance data (passed to the solver_cb method) 80 | :param exploration_cb: Callback to select the next node to explore (arguments: vector of nodes and stats; returns the index of the selected node) 81 | :param solver_cb: Callback to solve the node (arguments: instance_data, node, settings and stats; writes in-place in node (no return)) 82 | :param fathoming_cb: Callback to fathom nodes (arguments: vector of nodes, upper bound, incumbent and incumbent permutation; write in-place in nodes) 83 | :param branching_cb: Callback to create children nodes (arguments: vector of nodes and current node, n; write in-place in nodes) 84 | :param settings: Settings object with the solver settings 85 | :param initialzier: An optional function that will take, as arguments, the instance data, the nodes, the statistics and the settings. It 86 | writes in place in these objects 87 | :return The object Statistics (which includes the solution) 88 | """ 89 | self.n = n 90 | self.data = instance_data 91 | self.exploration = exploration_cb 92 | self.solver = solver_cb 93 | self.fathoming = fathoming_cb 94 | self.branching = branching_cb 95 | self.settings = settings 96 | self.initializer = initializer 97 | 98 | self.update_problem(instance_data) 99 | 100 | @property 101 | def results(self): 102 | return self._stats 103 | 104 | def update_problem(self, instance_data, mode="init"): 105 | """ 106 | Updates the problem data. 
107 | 108 | mode = "init" or "update" 109 | """ 110 | self.nodes = [] 111 | self._stats = Statistics() 112 | self.create_root_node() 113 | 114 | if self.initializer is not None: 115 | self.initializer(instance_data, self.nodes, self._stats, self.settings, mode) 116 | 117 | # Store intermediate results 118 | self._stats.iter = 0 119 | self._stats.global_lb_history = [] 120 | self._stats.global_ub_history = [] 121 | self._stats.gap_history = [] 122 | self._stats.explored_permutations = [] 123 | 124 | def create_root_node(self): 125 | """ 126 | Creates the root node for the tree. 127 | """ 128 | self.nodes.append( 129 | Node(np.array([None for _ in range(self.n)]), 0, None, -np.inf, +np.inf, None) 130 | ) 131 | 132 | def can_continue(self): 133 | """ 134 | Determines whether the exploration can continue or not. 135 | """ 136 | if len(self.nodes) == 0: 137 | return False 138 | 139 | if self._stats.explored_nodes > self.settings.max_nodes: 140 | return False 141 | 142 | if self._stats.iter > self.settings.max_iter: 143 | return False 144 | 145 | gap = np.inf 146 | if (abs(self._stats.global_ub) != np.inf and abs(self._stats.global_lb) != np.inf): 147 | gap = abs(self._stats.global_ub - self._stats.global_lb) / abs(self._stats.global_lb) 148 | self._stats.gap_history.append(gap) 149 | if gap <= self.settings.min_gap: 150 | print("Minimum gap reached. Terminating exploration with gap:", gap) 151 | return False 152 | 153 | return True 154 | 155 | def prune(self): 156 | """ 157 | Performs pruning via bound and user-specified fathoming rule. 
158 | """ 159 | if self._stats.global_ub != np.inf: 160 | init_len = len(self.nodes) 161 | self.nodes = [ 162 | node for node in self.nodes 163 | if self._stats.global_ub >= node.lb + self.settings.feas_tolerance 164 | ] 165 | self._stats.fathomed_nodes += init_len - len(self.nodes) 166 | # Custom fathoming rule 167 | self.fathoming( 168 | self.nodes, 169 | self._stats.global_ub, 170 | self._stats.incumbent, 171 | self._stats.incumbent_permutation, 172 | ) 173 | 174 | def solve_node(self, node): 175 | """ 176 | Solves the selected node and updates the statistics. 177 | :param node: Node to solve 178 | """ 179 | prev_lb = lb = node.lb 180 | self.solver(self.data, node, self.settings, self._stats) 181 | feasible = False 182 | lb = node.lb 183 | 184 | assert lb >= prev_lb, "Lower bound cannot improve" 185 | 186 | # Check if node contains a feasible solution 187 | if lb != np.inf: 188 | if None not in node.permutation: 189 | # The permutation is complete (i.e., no missing elements) 190 | # Update the solution and statistics 191 | self._stats.num_feasible_solutions += 1 192 | feasible = True 193 | 194 | assert node.ub >= lb, "Upper bound cannot be lower than lower bound" 195 | 196 | if self._stats.global_ub > node.ub: 197 | self._stats.incumbent_permutation = node.permutation 198 | self._stats.incumbent = node.solution 199 | # Update upper bound 200 | self._stats.global_ub = node.ub 201 | else: 202 | self.branching(self.nodes, node, self.n, self._stats, self.settings) 203 | 204 | self._stats.explored_permutations.append(node.permutation.tolist()) 205 | return feasible, lb 206 | 207 | def init_perm_heuristic(self, verbose=True): 208 | """ 209 | Proposes an initial feasible permutation. 
210 | """ 211 | if ( 212 | hasattr(self.settings, 'custom_settings') 213 | and self.settings.custom_settings["enable_custom_heuristic"] 214 | ): 215 | # Picks a permutation that prioritizes players with higher costs 216 | sorted_Js = sorted( 217 | enumerate(self.settings.custom_settings["Js_prev"]), key=lambda x: x[1], reverse=True 218 | ) 219 | permutation = [index for index, _ in sorted_Js] 220 | if verbose: 221 | print("\tCustom heuristic is exploring permutation", permutation) 222 | else: 223 | # Randomly pick a permutation 224 | for _ in range(self.settings.random_heuristic_iterations): 225 | permutation = np.random.permutation(self.n) 226 | # Create the node 227 | node = Node(permutation, 0, None, self.nodes[0].lb, -np.inf, None) 228 | # Solve the node 229 | feasible, lb = self.solve_node(node) 230 | if verbose: 231 | print("\tRandom heuristic is exploring permutation", permutation) 232 | 233 | def update_lb(self): 234 | """ 235 | Updates the lower bound in the statistics. 236 | """ 237 | minimum = np.inf 238 | for node in self.nodes: 239 | minimum = min(node.lb, minimum) 240 | self._stats.global_lb = minimum 241 | 242 | def solve(self): 243 | """ 244 | Main method for the user to solve the problem via Branch-and-Play. 
245 | """ 246 | verbose = self.settings.verbose 247 | 248 | # Start the timer 249 | start = time.time() 250 | if verbose: 251 | print_log_head() 252 | 253 | self.init_perm_heuristic(verbose) 254 | 255 | while self.can_continue(): 256 | # Pick the next node 257 | index = self.exploration(self.nodes, self._stats) 258 | node = self.nodes[index] 259 | 260 | # Solve the node 261 | feasible, lb = self.solve_node(node) 262 | 263 | # Remove the node 264 | del self.nodes[index] 265 | 266 | # Fathom and prune nodes 267 | self.update_lb() 268 | self.prune() 269 | 270 | # Print iteration 271 | if verbose: 272 | print_iteration( 273 | self._stats.explored_nodes, 274 | time.time() - start, self._stats.fathomed_nodes, len(self.nodes), lb, 275 | self._stats.global_lb, self._stats.global_ub, feasible, node.permutation 276 | ) 277 | 278 | # Store intermediate data for later plots 279 | if self._stats.global_ub == np.inf: 280 | self._stats.global_ub_history.append(None) 281 | self._stats.gap_history.append(None) 282 | else: 283 | self._stats.global_ub_history.append(self._stats.global_ub) 284 | self._stats.global_ub_history.append(self._stats.global_lb) 285 | 286 | # Update stats 287 | self._stats.explored_nodes += 1 288 | self._stats.iter += 1 289 | 290 | self._stats.time = time.time() - start 291 | 292 | 293 | def bestfirst_cb(nodes, stats): 294 | """ 295 | Defines the best-first exploration strategy. 296 | :param nodes: List of nodes 297 | :param stats: The statistics object 298 | :return The selected node index 299 | """ 300 | best_index = 0 301 | best_bound = np.inf 302 | for index, node in enumerate(nodes): 303 | if node.lb < best_bound: 304 | best_index = index 305 | best_bound = node.lb 306 | return best_index 307 | 308 | 309 | def depthfirst_cb(nodes, stats): 310 | """ 311 | Defines the dept-first exploration strategy. 
312 | :param nodes: List of nodes 313 | :param stats: The statistics object 314 | :return The selected node index 315 | """ 316 | return len(nodes) - 1 317 | 318 | 319 | def branchall_cb(nodes, node, n, stats, settings): 320 | """ 321 | Defines the simple branching strategy of creating n children from the first unknown player. 322 | Note that this method performs an in-place operation on nodes. 323 | :param nodes: List of nodes 324 | :param node: Incumbent node 325 | :param stats: Statistics 326 | :param settings: Settings 327 | :param n: Number of players 328 | """ 329 | 330 | undefined_locations = np.where(node.permutation == None)[0] 331 | if len(undefined_locations) > 0: 332 | location = undefined_locations[0] 333 | undefined = [] 334 | 335 | for player in range(n): 336 | if player not in node.permutation: 337 | undefined.append(player) 338 | 339 | for index, player in enumerate(undefined): 340 | child_permutation = np.copy(node.permutation) 341 | child_permutation[location] = player 342 | 343 | undefined_locations = np.where(child_permutation == None)[0] 344 | # If only one is missing 345 | if len(undefined_locations) == 1: 346 | child_permutation[undefined_locations[0]] = undefined[(index+1) % len(undefined)] 347 | if child_permutation.tolist() not in stats.explored_permutations: 348 | nodes.append( 349 | Node( 350 | child_permutation, 351 | node.depth + 1, 352 | node.solution, 353 | node.lb, 354 | -np.inf, 355 | node.custom_data, 356 | ) 357 | ) 358 | -------------------------------------------------------------------------------- /wpf/io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger functions for Branch and Play (B&P). 3 | 4 | Please contact the author(s) of this library if you have any questions. 
5 | Authors: Gabriele Dragotto (hello@dragotto.net), Haimin Hu (haiminh@princeton.edu) 6 | """ 7 | 8 | import logging 9 | import sys 10 | import numpy as np 11 | 12 | 13 | class CustomFormatter(logging.Formatter): 14 | grey = "\x1b[1;36m" 15 | yellow = "\x1b[33;20m" 16 | red = "\x1b[31;20m" 17 | bold_red = "\x1b[31;1m" 18 | reset = "\x1b[0m" 19 | format = ("%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)") 20 | 21 | FORMATS = { 22 | logging.DEBUG: grey + format + reset, 23 | logging.INFO: grey + format + reset, 24 | logging.WARNING: yellow + format + reset, 25 | logging.ERROR: red + format + reset, 26 | logging.CRITICAL: bold_red + format + reset, 27 | } 28 | 29 | def format(self, record): 30 | log_fmt = self.FORMATS.get(record.levelno) 31 | formatter = logging.Formatter(log_fmt) 32 | return formatter.format(record) 33 | 34 | 35 | logger = logging.getLogger("wpf") 36 | logger.setLevel(logging.INFO) 37 | ch = logging.StreamHandler(sys.stdout) 38 | ch.setLevel(logging.INFO) 39 | ch.setFormatter(CustomFormatter()) 40 | logger.addHandler(ch) 41 | logger.propagate = False 42 | 43 | 44 | def print_log_head(): 45 | print("🎮👾🎯🎰🎲🎳🃏🎱⚽️🎮👾🎯🎰🎲🎳🃏🎱⚽️ Who Plays First? 
🎮👾🎯🎰🎲🎳🃏🎱⚽️🎮👾🎯🎰🎲🎳🃏🎱⚽️") 46 | print(("{:<7}\t{:<7}\t{:<7}\t{:<7}\t{:<10}\t{:<10}\t{:<10}\t{:<10}\t{:<10}\t{:<7}").format( 47 | "Node", 48 | "Time", 49 | "Pruned", 50 | "Active", 51 | "NodeBnd", 52 | "LowerBnd", 53 | "UpperBnd", 54 | "Gap", 55 | "Feasible", 56 | "Permutation", 57 | )) 58 | 59 | 60 | def print_iteration(node, time, fathomed, active, bound, best_lb, best_ub, feasible, permutation): 61 | _spec = "{:<7.0f}\t{:<7.3f}\t{:<7.0f}\t{:<7.0f}\t{:<10.2f}\t{:<10.2f}\t{:<10.2f}\t{:<10.4f}\t{:<10}\t{:<10}" 62 | print( 63 | _spec.format( 64 | node, 65 | time, 66 | fathomed, 67 | active, 68 | bound, 69 | best_lb, 70 | best_ub, 71 | np.inf if 72 | (best_ub == np.inf or best_lb == -np.inf) else abs(best_ub - best_lb) / abs(best_lb), 73 | "✅" if feasible else "❌", 74 | str(permutation), 75 | ) 76 | ) 77 | --------------------------------------------------------------------------------