├── .gitignore ├── LICENSE ├── README.md ├── gcbfplus ├── __init__.py ├── algo │ ├── __init__.py │ ├── base.py │ ├── centralized_cbf.py │ ├── dec_share_cbf.py │ ├── gcbf.py │ ├── gcbf_plus.py │ ├── module │ │ ├── __init__.py │ │ ├── cbf.py │ │ ├── distribution.py │ │ ├── policy.py │ │ └── value.py │ └── utils.py ├── env │ ├── __init__.py │ ├── base.py │ ├── crazyflie.py │ ├── double_integrator.py │ ├── dubins_car.py │ ├── linear_drone.py │ ├── obstacle.py │ ├── plot.py │ ├── single_integrator.py │ └── utils.py ├── nn │ ├── __init__.py │ ├── gnn.py │ ├── mlp.py │ └── utils.py ├── trainer │ ├── __init__.py │ ├── buffer.py │ ├── data.py │ ├── trainer.py │ └── utils.py └── utils │ ├── __init__.py │ ├── graph.py │ ├── typing.py │ └── utils.py ├── media ├── DoubleIntegrator_512_2x.gif ├── Obstacle2D_32.gif ├── Obstacle2D_512_2x.gif └── cbf1.gif ├── pretrained ├── CrazyFlie │ └── gcbf+ │ │ ├── config.yaml │ │ └── models │ │ └── 1000 │ │ ├── actor.pkl │ │ └── cbf.pkl ├── DoubleIntegrator │ └── gcbf+ │ │ ├── config.yaml │ │ └── models │ │ └── 1000 │ │ ├── actor.pkl │ │ └── cbf.pkl ├── DubinsCar │ └── gcbf+ │ │ ├── config.yaml │ │ └── models │ │ └── 1000 │ │ ├── actor.pkl │ │ └── cbf.pkl ├── LinearDrone │ └── gcbf+ │ │ ├── config.yaml │ │ └── models │ │ └── 1000 │ │ ├── actor.pkl │ │ └── cbf.pkl └── SingleIntegrator │ └── gcbf+ │ ├── config.yaml │ └── models │ └── 1000 │ ├── actor.pkl │ └── cbf.pkl ├── requirements.txt ├── settings.yaml ├── setup.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | __pycache__/ 4 | wandb 5 | logs 6 | figs 7 | videos/ 8 | wandb/ 9 | test_log.csv 10 | *.egg-info -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Songyuan Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # GCBF+ 4 | 5 | [![Paper](https://img.shields.io/badge/T--RO-Accepted-success)](https://mit-realm.github.io/gcbfplus-website/) 6 | 7 | Official JAX implementation of the T-RO paper: [Songyuan Zhang*](https://syzhang092218-source.github.io), [Oswin So*](https://oswinso.xyz/), [Kunal Garg](https://kunalgarg.mit.edu/), and [Chuchu Fan](https://chuchu.mit.edu): "[GCBF+: A Neural Graph Control Barrier Function Framework for Distributed Safe Multi-Agent Control](https://mit-realm.github.io/gcbfplus-website/)". 8 | 9 | [Dependencies](#Dependencies) • 10 | [Installation](#Installation) • 11 | [Run](#Run) 12 | 13 |
14 | 15 | A much improved version of [GCBFv0](https://mit-realm.github.io/gcbf-website/)! 16 | 17 |
18 | [demo GIF: LidarSpread] 19 | [demo GIF: LidarLine] 20 | [demo GIF: VMASReverseTransport] 21 | [demo GIF: VMASWheel] 22 |
23 | 24 | ## Dependencies 25 | 26 | We recommend using [CONDA](https://www.anaconda.com/) to install the requirements: 27 | 28 | ```bash 29 | conda create -n gcbfplus python=3.10 30 | conda activate gcbfplus 31 | cd gcbfplus 32 | ``` 33 | 34 | Then install JAX following the [official instructions](https://github.com/google/jax#installation), and install the rest of the dependencies: 35 | ```bash 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | ## Installation 40 | 41 | Install GCBF+: 42 | 43 | ```bash 44 | pip install -e . 45 | ``` 46 | 47 | ## Run 48 | 49 | ### Environments 50 | 51 | We provide three 2D environments (`SingleIntegrator`, `DoubleIntegrator`, and `DubinsCar`) and two 3D environments (`LinearDrone` and `CrazyFlie`). 52 | 53 | ### Algorithms 54 | 55 | We provide the following algorithms: GCBF+ (`gcbf+`), GCBF (`gcbf`), centralized CBF-QP (`centralized_cbf`), and decentralized CBF-QP (`dec_share_cbf`). Use `--algo` to specify the algorithm. 56 | 57 | ### Hyper-parameters 58 | 59 | To reproduce the results shown in our paper, refer to [`settings.yaml`](./settings.yaml) for the hyper-parameters we used. 60 | 61 | ### Train 62 | 63 | To train the model (only GCBF+ and GCBF need training), use: 64 | 65 | ```bash 66 | python train.py --algo gcbf+ --env DoubleIntegrator -n 8 --area-size 4 --loss-action-coef 1e-4 --n-env-train 16 --lr-actor 1e-5 --lr-cbf 1e-5 --horizon 32 67 | ``` 68 | 69 | In our paper, we use 8 agents with 1000 training steps. The training logs will be saved in `./logs/<env>/<algo>/seed_<seed>`. We also provide the following flags: 70 | 71 | - `-n`: number of agents 72 | - `--env`: environment, including `SingleIntegrator`, `DoubleIntegrator`, `DubinsCar`, `LinearDrone`, and `CrazyFlie` 73 | - `--algo`: algorithm, including `gcbf`, `gcbf+` 74 | - `--seed`: random seed 75 | - `--steps`: number of training steps 76 | - `--name`: name of the experiment 77 | - `--debug`: debug mode (no recording, no saving) 78 | - `--obs`: number of obstacles 79 | - `--n-rays`: number of LiDAR rays 80 | - `--area-size`: side length of the environment 81 | - `--n-env-train`: number of environments for training 82 | - `--n-env-test`: number of environments for testing 83 | - `--log-dir`: path to save the training logs 84 | - `--eval-interval`: interval of evaluation 85 | - `--eval-epi`: number of episodes for evaluation 86 | - `--save-interval`: interval of saving the model 87 | 88 | In addition, use the following flags to specify the hyper-parameters: 89 | - `--alpha`: GCBF alpha 90 | - `--horizon`: GCBF+ look-forward horizon 91 | - `--lr-actor`: learning rate of the actor 92 | - `--lr-cbf`: learning rate of the CBF 93 | - `--loss-action-coef`: coefficient of the action loss 94 | - `--loss-h-dot-coef`: coefficient of the h_dot loss 95 | - `--loss-safe-coef`: coefficient of the safe loss 96 | - `--loss-unsafe-coef`: coefficient of the unsafe loss 97 | - `--buffer-size`: size of the replay buffer 98 | 99 | ### Test 100 | 101 | To test the learned model, use: 102 | 103 | ```bash 104 | python test.py --path <path-to-log> --epi 5 --area-size 4 -n 16 --obs 0 105 | ``` 106 | 107 | This should report the safety rate, goal-reaching rate, and success rate of the learned model, and generate videos of the learned model in `<path-to-log>/videos`. 
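For example, to evaluate one of the bundled pre-trained checkpoints (a hedged invocation: it assumes `test.py` accepts a pretrained folder as `--path`, since that folder mirrors the `config.yaml`/`models` layout of a training log): ```bash python test.py --path pretrained/DoubleIntegrator/gcbf+ --epi 5 --area-size 4 -n 16 --obs 0 ```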
Use the following flags to customize the test: 108 | 109 | - `-n`: number of agents 110 | - `--obs`: number of obstacles 111 | - `--area-size`: side length of the environment 112 | - `--max-step`: maximum number of steps for each episode (increase this for large environments) 113 | - `--path`: path to the log folder 114 | - `--n-rays`: number of LiDAR rays 115 | - `--alpha`: CBF alpha (used by the centralized and decentralized CBF-QPs) 116 | - `--max-travel`: maximum travel distance of agents 117 | - `--cbf`: plot the CBF contour of this agent (only supported in 2D environments) 118 | - `--seed`: random seed 119 | - `--debug`: debug mode 120 | - `--cpu`: use CPU 121 | - `--u-ref`: test the nominal controller 122 | - `--env`: test environment (not needed if the log folder is specified) 123 | - `--algo`: test algorithm (not needed if the log folder is specified) 124 | - `--step`: test step (not needed if testing the last saved model) 125 | - `--epi`: number of episodes to test 126 | - `--offset`: offset of the random seeds 127 | - `--no-video`: do not generate videos 128 | - `--log`: log the results to a file 129 | - `--dpi`: DPI of the video 130 | - `--nojit-rollout`: do not use JIT to speed up the rollout (useful for large-scale tests) 131 | 132 | To test the nominal controller, use: 133 | 134 | ```bash 135 | python test.py --env SingleIntegrator -n 16 --u-ref --epi 1 --area-size 4 --obs 0 136 | ``` 137 | 138 | To test the CBF-QPs, use: 139 | 140 | ```bash 141 | python test.py --env SingleIntegrator -n 16 --algo dec_share_cbf --epi 1 --area-size 4 --obs 0 --alpha 1 142 | ``` 143 | 144 | ### Pre-trained models 145 | 146 | We provide pre-trained models in the folder [`pretrained`](pretrained). However, their performance may depend on the GPU/CUDA/JAX versions used. We highly recommend retraining a model yourself. 147 | 148 | ## Citation 149 | 150 | ``` 151 | @ARTICLE{zhang2025gcbf+, 152 | author={Zhang, Songyuan and So, Oswin and Garg, Kunal and Fan, Chuchu}, 153 | journal={IEEE Transactions on Robotics}, 154 | title={{GCBF}+: A Neural Graph Control Barrier Function Framework for Distributed Safe Multiagent Control}, 155 | year={2025}, 156 | volume={41}, 157 | pages={1533-1552}, 158 | doi={10.1109/TRO.2025.3530348} 159 | } 160 | ``` 161 | 162 | ## Acknowledgement 163 | 164 | The developers were partially supported by MITRE during the project. 
165 | 166 | © 2024 MIT 167 | 168 | © 2024 The MITRE Corporation 169 | -------------------------------------------------------------------------------- /gcbfplus/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/__init__.py -------------------------------------------------------------------------------- /gcbfplus/algo/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import MultiAgentController 2 | from .dec_share_cbf import DecShareCBF 3 | from .gcbf import GCBF 4 | from .gcbf_plus import GCBFPlus 5 | from .centralized_cbf import CentralizedCBF 6 | 7 | 8 | def make_algo(algo: str, **kwargs) -> MultiAgentController: 9 | if algo == 'gcbf': 10 | return GCBF(**kwargs) 11 | elif algo == 'gcbf+': 12 | return GCBFPlus(**kwargs) 13 | elif algo == 'centralized_cbf': 14 | return CentralizedCBF(**kwargs) 15 | elif algo == 'dec_share_cbf': 16 | return DecShareCBF(**kwargs) 17 | else: 18 | raise ValueError(f'Unknown algorithm: {algo}') 19 | -------------------------------------------------------------------------------- /gcbfplus/algo/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod, abstractproperty 2 | from typing import Optional, Tuple 3 | 4 | from gcbfplus.utils.typing import Action, Params, PRNGKey, Array 5 | from gcbfplus.utils.graph import GraphsTuple 6 | from gcbfplus.trainer.data import Rollout 7 | from gcbfplus.env.base import MultiAgentEnv 8 | 9 | 10 | class MultiAgentController(ABC): 11 | 12 | def __init__( 13 | self, 14 | env: MultiAgentEnv, 15 | node_dim: int, 16 | edge_dim: int, 17 | action_dim: int, 18 | n_agents: int 19 | ): 20 | self._env = env 21 | self._node_dim = node_dim 22 | self._edge_dim = edge_dim 23 | self._action_dim = action_dim 24 | self._n_agents = n_agents 25 | 26 | @property 27 | def node_dim(self) -> int: 28 | return self._node_dim 29 | 30 | @property 31 | def edge_dim(self) -> int: 32 | return self._edge_dim 33 | 34 | @property 35 | def action_dim(self) -> int: 36 | return self._action_dim 37 | 38 | @property 39 | def n_agents(self) -> int: 40 | return self._n_agents 41 | 42 | @abstractproperty 43 | def config(self) -> dict: 44 | pass 45 | 46 | @abstractproperty 47 | def actor_params(self) -> Params: 48 | pass 49 | 50 | @abstractmethod 51 | def act(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action: 52 | pass 53 | 54 | @abstractmethod 55 | def step(self, graph: GraphsTuple, key: PRNGKey, params: Optional[Params] = None) -> Tuple[Action, Array]: 56 | pass 57 | 58 | @abstractmethod 59 | def update(self, rollout: Rollout, step: int) -> dict: 60 | pass 61 | 62 | @abstractmethod 63 | def save(self, save_dir: str, step: int): 64 | pass 65 | 66 | @abstractmethod 67 | def load(self, load_dir: str, step: int): 68 | pass 69 | -------------------------------------------------------------------------------- /gcbfplus/algo/centralized_cbf.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax 3 | import einops as ei 4 | 5 | from typing import Optional, Tuple 6 | from jaxproxqp.jaxproxqp import JaxProxQP 7 | 8 | from gcbfplus.utils.typing import Action, Params, PRNGKey, Array, State 9 | from gcbfplus.utils.graph import GraphsTuple 10 | from gcbfplus.utils.utils import mask2index 11 | from gcbfplus.trainer.data 
import Rollout 12 | from gcbfplus.env.base import MultiAgentEnv 13 | from .utils import get_pwise_cbf_fn 14 | from .base import MultiAgentController 15 | 16 | 17 | class CentralizedCBF(MultiAgentController): 18 | 19 | def __init__( 20 | self, 21 | env: MultiAgentEnv, 22 | node_dim: int, 23 | edge_dim: int, 24 | state_dim: int, 25 | action_dim: int, 26 | n_agents: int, 27 | alpha: float = 1.0, 28 | **kwargs 29 | ): 30 | super(CentralizedCBF, self).__init__( 31 | env=env, 32 | node_dim=node_dim, 33 | edge_dim=edge_dim, 34 | action_dim=action_dim, 35 | n_agents=n_agents 36 | ) 37 | 38 | self.alpha = alpha 39 | self.k = 3 40 | self.cbf = get_pwise_cbf_fn(env, self.k) 41 | 42 | @property 43 | def config(self) -> dict: 44 | return { 45 | 'alpha': self.alpha, 46 | } 47 | 48 | @property 49 | def actor_params(self) -> Params: 50 | raise NotImplementedError 51 | 52 | def step(self, graph: GraphsTuple, key: PRNGKey, params: Optional[Params] = None) -> Tuple[Action, Array]: 53 | raise NotImplementedError 54 | 55 | def get_cbf(self, graph: GraphsTuple) -> Array: 56 | return self.cbf(graph)[0] 57 | 58 | def update(self, rollout: Rollout, step: int) -> dict: 59 | raise NotImplementedError 60 | 61 | def act(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action: 62 | return self.get_qp_action(graph)[0] 63 | 64 | def get_qp_action(self, graph: GraphsTuple, relax_penalty: float = 1e3) -> Tuple[Action, Array]: 65 | assert graph.is_single # only a single (non-batched) graph is supported 66 | agent_node_mask = graph.node_type == 0 67 | agent_node_id = mask2index(agent_node_mask, self.n_agents) 68 | 69 | def h_aug(new_agent_state: State) -> Array: 70 | new_state = graph.states.at[agent_node_id].set(new_agent_state) 71 | new_graph = graph._replace(edges=new_state[graph.receivers] - new_state[graph.senders], states=new_state) 72 | val = self.get_cbf(new_graph) 73 | assert val.shape == (self.n_agents, self.k) 74 | return val 75 | 76 | agent_state = graph.type_states(type_idx=0, n_type=self.n_agents) 77 | h = h_aug(agent_state) # (n_agents, k) 78 | h_x = jax.jacfwd(h_aug)(agent_state) # (n_agents, k | n_agents, nx) 79 | h = h.reshape(-1) # (n_agents * k,) 80 | 81 | dyn_f, dyn_g = self._env.control_affine_dyn(agent_state) 82 | Lf_h = ei.einsum(h_x, dyn_f, "agent_i k agent_j nx, agent_j nx -> agent_i k") 83 | Lg_h = ei.einsum(h_x, dyn_g, "agent_i k agent_j nx, agent_j nx nu -> agent_i k agent_j nu") 84 | Lf_h = Lf_h.reshape(-1) # (n_agents * k,) 85 | Lg_h = Lg_h.reshape((self.n_agents * self.k, -1)) # (n_agents * k, n_agents * nu) 86 | 87 | u_lb, u_ub = self._env.action_lim() 88 | u_lb = u_lb[None, :].repeat(self.n_agents, axis=0).reshape(-1) 89 | u_ub = u_ub[None, :].repeat(self.n_agents, axis=0).reshape(-1) 90 | u_ref = self._env.u_ref(graph).reshape(-1) 91 | 92 | # construct the QP over x = [u, r]: min 0.5 ||u||^2 - u_ref^T u + 5 ||r||^2 + relax_penalty * sum(r), s.t. Lf_h + Lg_h @ u + alpha * h + r >= 0 (rows of C x <= b), r >= 0, u_lb <= u <= u_ub 93 | H = jnp.eye(self._env.action_dim * self.n_agents + self.n_agents * self.k, dtype=jnp.float32) 94 | H = H.at[-self.n_agents * self.k:, -self.n_agents * self.k:].set( 95 | H[-self.n_agents * self.k:, -self.n_agents * self.k:] * 10.0) 96 | g = jnp.concatenate([-u_ref, relax_penalty * jnp.ones(self.n_agents * self.k)]) 97 | C = -jnp.concatenate([Lg_h, jnp.eye(self.n_agents * self.k)], axis=1) 98 | b = Lf_h + self.alpha * h # (n_agents * k,) 99 | 100 | r_lb = jnp.array([0.]
* self.n_agents * self.k, dtype=jnp.float32) 101 | r_ub = jnp.array([jnp.inf] * self.n_agents * self.k, dtype=jnp.float32) 102 | 103 | l_box = jnp.concatenate([u_lb, r_lb], axis=0) 104 | u_box = jnp.concatenate([u_ub, r_ub], axis=0) 105 | 106 | qp = JaxProxQP.QPModel.create(H, g, C, b, l_box, u_box) 107 | settings = JaxProxQP.Settings.default() 108 | settings.max_iter = 100 109 | settings.dua_gap_thresh_abs = None 110 | solver = JaxProxQP(qp, settings) 111 | sol = solver.solve() 112 | 113 | assert sol.x.shape == (self.action_dim * self.n_agents + self.n_agents * self.k,) 114 | u_opt, r = sol.x[:self.action_dim * self.n_agents], sol.x[-self.n_agents * self.k:] 115 | u_opt = u_opt.reshape(self.n_agents, -1) 116 | 117 | return u_opt, r 118 | 119 | def save(self, save_dir: str, step: int): 120 | raise NotImplementedError 121 | 122 | def load(self, load_dir: str, step: int): 123 | raise NotImplementedError 124 | -------------------------------------------------------------------------------- /gcbfplus/algo/dec_share_cbf.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import einops as ei 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from jaxproxqp.jaxproxqp import JaxProxQP 7 | from typing import Optional, Tuple 8 | 9 | from gcbfplus.env.base import MultiAgentEnv 10 | from gcbfplus.trainer.data import Rollout 11 | from gcbfplus.utils.graph import GraphsTuple 12 | from gcbfplus.utils.typing import Action, Array, Params, PRNGKey, State 13 | from gcbfplus.utils.utils import mask2index, jax_vmap 14 | from .base import MultiAgentController 15 | from .utils import get_pwise_cbf_fn 16 | 17 | 18 | class DecShareCBF(MultiAgentController): 19 | """Decentralized CBF-QP controller that takes the k closest agents/obstacles into account.""" 20 | 21 | def __init__( 22 | self, 23 | env: MultiAgentEnv, 24 | node_dim: int, 25 | edge_dim: int, 26 | state_dim: int, 27 | action_dim: int, 28 | n_agents: int, 29 | alpha: float = 1.0, 30 | **kwargs 31 | ): 32 | super().__init__(env=env, node_dim=node_dim, edge_dim=edge_dim, action_dim=action_dim, n_agents=n_agents) 33 | 34 | if hasattr(env, "enable_stop"): 35 | env.enable_stop = False 36 | 37 | self.cbf_alpha = alpha 38 | self.k = 3 39 | self.cbf = get_pwise_cbf_fn(env, self.k) 40 | 41 | @property 42 | def config(self) -> dict: 43 | return { 44 | "alpha": self.cbf_alpha, 45 | } 46 | 47 | @property 48 | def actor_params(self) -> Params: 49 | raise NotImplementedError 50 | 51 | def step(self, graph: GraphsTuple, key: PRNGKey, params: Optional[Params] = None) -> Tuple[Action, Array]: 52 | raise NotImplementedError 53 | 54 | def get_cbf(self, graph: GraphsTuple) -> tuple[Array, Array]: 55 | ak_h0, ak_isobs = self.cbf(graph) 56 | return ak_h0, ak_isobs 57 | 58 | def update(self, rollout: Rollout, step: int) -> dict: 59 | raise NotImplementedError 60 | 61 | def act(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action: 62 | return self.get_qp_action(graph)[0] 63 | 64 | def get_qp_action(self, graph: GraphsTuple, relax_penalty: float = 1e3) -> Tuple[Action, Array]: 65 | assert graph.is_single # only a single (non-batched) graph is supported 66 | agent_node_mask = graph.node_type == 0 67 | agent_node_id = mask2index(agent_node_mask, self.n_agents) 68 | 69 | def h_aug(new_agent_state: State) -> tuple[Array, Array]: 70 | new_state = graph.states.at[agent_node_id].set(new_agent_state) 71 | new_graph = graph._replace(edges=new_state[graph.receivers] - new_state[graph.senders], states=new_state) 72 | ak_h_, ak_isobs_ = self.get_cbf(new_graph) 73
| assert ak_h_.shape == (self.n_agents, self.k) 74 | assert ak_isobs_.shape == (self.n_agents, self.k) 75 | return ak_h_, ak_isobs_ 76 | 77 | def h(new_agent_state: State) -> Array: 78 | return h_aug(new_agent_state)[0] 79 | 80 | agent_state = graph.type_states(type_idx=0, n_type=self.n_agents) 81 | # (n_agents, k) 82 | ak_h, ak_isobs = h_aug(agent_state) 83 | # (n_agents, k | n_agents, nx) 84 | ak_hx = jax.jacfwd(h)(agent_state) 85 | 86 | a_dyn_f, a_dyn_g = self._env.control_affine_dyn(agent_state) 87 | ak_Lf_h = ei.einsum(ak_hx, a_dyn_f, "agent_i k agent_j nx, agent_j nx -> agent_i k") 88 | aka_Lg_h: Array = ei.einsum(ak_hx, a_dyn_g, "agent_i k agent_j nx, agent_j nx nu -> agent_i k agent_j nu") 89 | 90 | def index_fn(idx: int): 91 | k_Lg_h = aka_Lg_h[idx, :, idx] 92 | assert k_Lg_h.shape == (self.k, self.action_dim) 93 | return k_Lg_h 94 | 95 | ak_Lg_h_self = jax_vmap(index_fn)(jnp.arange(self.n_agents)) 96 | 97 | au_ref = self._env.u_ref(graph) 98 | assert au_ref.shape == (self.n_agents, self.action_dim) 99 | 100 | # (n_agents, ). 1 if agent-obs, 0.5 if agent-agent. 101 | ak_resp = jnp.where(ak_isobs, 1.0, 0.5) 102 | 103 | # construct QP 104 | au_opt, ar = jax_vmap(ft.partial(self._solve_qp_single, relax_penalty=relax_penalty))( 105 | ak_h, ak_Lf_h, ak_Lg_h_self, au_ref, ak_resp 106 | ) 107 | return au_opt, ar 108 | 109 | def _solve_qp_single(self, k_h, k_Lf_h, k_Lg_h, u_ref, k_responsibility: float, relax_penalty: float = 1e3): 110 | n_qp_x = self._env.action_dim + self.k 111 | 112 | assert k_h.shape == (self.k,) 113 | assert k_Lf_h.shape == (self.k,) 114 | assert k_Lg_h.shape == (self.k, self._env.action_dim) 115 | 116 | u_lb, u_ub = self._env.action_lim() 117 | assert u_lb.shape == u_ub.shape == (self.action_dim,) 118 | 119 | ########### 120 | 121 | H = jnp.eye(n_qp_x, dtype=jnp.float32) 122 | H = H.at[-self.k :, -self.k :].set(10.0) 123 | g = jnp.concatenate([-u_ref, relax_penalty * jnp.ones(self.k)], axis=0) 124 | assert g.shape == (n_qp_x,) 125 | 126 | k_C = -jnp.concatenate([k_Lg_h, jnp.eye(self.k)], axis=1) 127 | assert k_C.shape == (self.k, n_qp_x) 128 | 129 | # Responsibility is one if this is agent-obs, half if this is agent-agent. 
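# Assuming JaxProxQP's inequality convention C x <= b, each row below reads -(Lg_h @ u + r) <= resp * (Lf_h + alpha * h), i.e. Lg_h @ u + r + resp * (Lf_h + alpha * h) >= 0: with resp = 0.5, the two agents of an agent-agent pair each enforce half of the shared decrease condition, and the halves sum to the full CBF condition h_dot + alpha * h >= 0.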
130 | k_b = k_responsibility * (k_Lf_h + self.cbf_alpha * k_h) 131 | assert k_b.shape == (self.k,) 132 | 133 | r_lb = jnp.full(self.k, 0.0, dtype=jnp.float32) 134 | r_ub = jnp.full(self.k, jnp.inf, dtype=jnp.float32) 135 | 136 | l_box = jnp.concatenate([u_lb, r_lb], axis=0) 137 | u_box = jnp.concatenate([u_ub, r_ub], axis=0) 138 | assert l_box.shape == u_box.shape == (n_qp_x,) 139 | 140 | qp = JaxProxQP.QPModel.create(H, g, k_C, k_b, l_box, u_box) 141 | settings = JaxProxQP.Settings.default() 142 | settings.max_iter = 100 143 | settings.dua_gap_thresh_abs = None 144 | solver = JaxProxQP(qp, settings) 145 | sol = solver.solve() 146 | 147 | assert sol.x.shape == (n_qp_x,) 148 | u_opt, r = sol.x[: self.action_dim], sol.x[-self.k :] 149 | 150 | return u_opt, r 151 | 152 | def save(self, save_dir: str, step: int): 153 | raise NotImplementedError 154 | 155 | def load(self, load_dir: str, step: int): 156 | raise NotImplementedError 157 | -------------------------------------------------------------------------------- /gcbfplus/algo/gcbf.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.random as jr 3 | import optax 4 | import os 5 | import jax 6 | import functools as ft 7 | import jax.tree_util as jtu 8 | import numpy as np 9 | import pickle 10 | 11 | from typing import Optional, Tuple 12 | from flax.training.train_state import TrainState 13 | 14 | from gcbfplus.utils.typing import Action, Params, PRNGKey, Array 15 | from gcbfplus.utils.graph import GraphsTuple 16 | from gcbfplus.utils.utils import merge01, jax_vmap, tree_merge 17 | from gcbfplus.trainer.data import Rollout 18 | from gcbfplus.trainer.buffer import ReplayBuffer 19 | from gcbfplus.trainer.utils import has_any_nan, compute_norm_and_clip 20 | from gcbfplus.env.base import MultiAgentEnv 21 | from gcbfplus.algo.module.cbf import CBF 22 | from gcbfplus.algo.module.policy import DeterministicPolicy 23 | from .base import MultiAgentController 24 | 25 | 26 | class GCBF(MultiAgentController): 27 | 28 | def __init__( 29 | self, 30 | env: MultiAgentEnv, 31 | node_dim: int, 32 | edge_dim: int, 33 | state_dim: int, 34 | action_dim: int, 35 | n_agents: int, 36 | gnn_layers: int, 37 | batch_size: int, 38 | buffer_size: int, 39 | lr_actor: float = 3e-5, 40 | lr_cbf: float = 3e-5, 41 | alpha: float = 1.0, 42 | eps: float = 0.02, 43 | inner_epoch: int = 8, 44 | loss_action_coef: float = 0.001, 45 | loss_unsafe_coef: float = 1., 46 | loss_safe_coef: float = 1., 47 | loss_h_dot_coef: float = 0.2, 48 | max_grad_norm: float = 2., 49 | seed: int = 0, 50 | online_pol_refine: bool = False, 51 | **kwargs 52 | ): 53 | super(GCBF, self).__init__( 54 | env=env, 55 | node_dim=node_dim, 56 | edge_dim=edge_dim, 57 | action_dim=action_dim, 58 | n_agents=n_agents 59 | ) 60 | 61 | # set hyperparameters 62 | self.batch_size = batch_size 63 | self.lr_actor = lr_actor 64 | self.lr_cbf = lr_cbf 65 | self.alpha = alpha 66 | self.eps = eps 67 | self.inner_epoch = inner_epoch 68 | self.loss_action_coef = loss_action_coef 69 | self.loss_unsafe_coef = loss_unsafe_coef 70 | self.loss_safe_coef = loss_safe_coef 71 | self.loss_h_dot_coef = loss_h_dot_coef 72 | self.gnn_layers = gnn_layers 73 | self.max_grad_norm = max_grad_norm 74 | self.seed = seed 75 | self.online_pol_refine = online_pol_refine 76 | 77 | # set nominal graph for initialization of the neural networks 78 | nominal_graph = GraphsTuple( 79 | nodes=jnp.zeros((n_agents, node_dim)), 80 | edges=jnp.zeros((n_agents, edge_dim)), 81 | 
states=jnp.zeros((n_agents, state_dim)), 82 | n_node=jnp.array(n_agents), 83 | n_edge=jnp.array(n_agents), 84 | senders=jnp.arange(n_agents), 85 | receivers=jnp.arange(n_agents), 86 | node_type=jnp.zeros((n_agents,)), 87 | env_states=jnp.zeros((n_agents,)), 88 | ) 89 | self.nominal_graph = nominal_graph 90 | 91 | # set up CBF 92 | self.cbf = CBF( 93 | node_dim=node_dim, 94 | edge_dim=edge_dim, 95 | n_agents=n_agents, 96 | gnn_layers=gnn_layers 97 | ) 98 | key = jr.PRNGKey(seed) 99 | cbf_key, key = jr.split(key) 100 | cbf_params = self.cbf.net.init(cbf_key, nominal_graph, self.n_agents) 101 | cbf_optim = optax.adam(learning_rate=lr_cbf) 102 | self.cbf_optim = optax.apply_if_finite(cbf_optim, 1_000_000) 103 | self.cbf_train_state = TrainState.create( 104 | apply_fn=self.cbf.get_cbf, 105 | params=cbf_params, 106 | tx=self.cbf_optim 107 | ) 108 | 109 | # set up actor 110 | self.actor = DeterministicPolicy( 111 | node_dim=node_dim, 112 | edge_dim=edge_dim, 113 | action_dim=action_dim, 114 | n_agents=n_agents 115 | ) 116 | actor_key, key = jr.split(key) 117 | actor_params = self.actor.net.init(actor_key, nominal_graph, self.n_agents) 118 | actor_optim = optax.adam(learning_rate=lr_actor) 119 | self.actor_optim = optax.apply_if_finite(actor_optim, 1_000_000) 120 | self.actor_train_state = TrainState.create( 121 | apply_fn=self.actor.sample_action, 122 | params=actor_params, 123 | tx=self.actor_optim 124 | ) 125 | 126 | # set up key 127 | self.key = key 128 | self.buffer = ReplayBuffer(size=buffer_size) 129 | self.unsafe_buffer = ReplayBuffer(size=buffer_size // 2) 130 | 131 | @property 132 | def config(self) -> dict: 133 | return { 134 | 'batch_size': self.batch_size, 135 | 'lr_actor': self.lr_actor, 136 | 'lr_cbf': self.lr_cbf, 137 | 'alpha': self.alpha, 138 | 'eps': self.eps, 139 | 'inner_epoch': self.inner_epoch, 140 | 'loss_action_coef': self.loss_action_coef, 141 | 'loss_unsafe_coef': self.loss_unsafe_coef, 142 | 'loss_safe_coef': self.loss_safe_coef, 143 | 'loss_h_dot_coef': self.loss_h_dot_coef, 144 | 'gnn_layers': self.gnn_layers, 145 | 'seed': self.seed, 146 | 'max_grad_norm': self.max_grad_norm 147 | } 148 | 149 | @property 150 | def actor_params(self) -> Params: 151 | return self.actor_train_state.params 152 | 153 | def act(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action: 154 | if self.online_pol_refine: 155 | return self.online_policy_refinement(graph, params) 156 | if params is None: 157 | params = self.actor_params 158 | nn_action = 2 * self.actor.get_action(params, graph) + self._env.u_ref(graph) 159 | return nn_action 160 | 161 | def online_policy_refinement(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action: 162 | if params is None: 163 | params = self.actor_params 164 | h = self.get_cbf(graph) 165 | 166 | # try u_ref first 167 | u_ref = self._env.u_ref(graph) 168 | next_graph_u_ref = self._env.forward_graph(graph, u_ref) 169 | h_next_u_ref = self.get_cbf(next_graph_u_ref) 170 | h_dot_u_ref = (h_next_u_ref - h) / self._env.dt 171 | max_val_h_dot_u_ref = jax.nn.relu(-h_dot_u_ref - self.alpha * h) 172 | nn_action = 2 * self.actor.get_action(params, graph) + u_ref 173 | nn_action = jnp.where(max_val_h_dot_u_ref > 0, nn_action, u_ref) 174 | 175 | max_iter = 30 176 | lr = 0.1 177 | 178 | def do_refinement(inp): 179 | i_iter, action, prev_h_dot_val = inp 180 | 181 | def h_dot_cond_val(a: Action): 182 | next_graph = self._env.forward_graph(graph, a) 183 | h_next = self.get_cbf(next_graph) 184 | h_dot = (h_next - h) / self._env.dt 185 | 
max_val_h_dot = jax.nn.relu(-h_dot - self.alpha * h).mean() 186 | return max_val_h_dot 187 | 188 | h_dot_val, grad = jax.value_and_grad(h_dot_cond_val)(action) 189 | action = action - lr * grad 190 | i_iter += 1 191 | return i_iter, action, h_dot_val 192 | 193 | def continue_refinement(inp): 194 | i_iter, action, h_dot_val = inp 195 | return (h_dot_val > 0) & (i_iter < max_iter) 196 | 197 | _, nn_action, _ = jax.lax.while_loop( 198 | continue_refinement, do_refinement, init_val=(0, nn_action, 1.0) 199 | ) 200 | 201 | return nn_action 202 | 203 | def step(self, graph: GraphsTuple, key: PRNGKey, params: Optional[Params] = None) -> Tuple[Action, Array]: 204 | if params is None: 205 | params = self.actor_params 206 | action, log_pi = self.actor_train_state.apply_fn(params, graph, key) 207 | return 2 * action + self._env.u_ref(graph), log_pi 208 | 209 | def get_cbf(self, graph: GraphsTuple, params: Optional[Params] = None) -> Array: 210 | if params is None: 211 | params = self.cbf_train_state.params 212 | return self.cbf.get_cbf(params, graph) 213 | 214 | def update(self, rollout: Rollout, step: int) -> dict: 215 | key, self.key = jr.split(self.key) 216 | 217 | if self.buffer.length > self.batch_size: 218 | # sample from memory and unsafe_memory 219 | memory = self.buffer.sample(rollout.length // 2) 220 | unsafe_memory = self.unsafe_buffer.sample(rollout.length * rollout.time_horizon) 221 | 222 | # append new data to memory and unsafe_memory 223 | unsafe_mask = jax_vmap(jax_vmap(self._env.unsafe_mask))(rollout.graph).max(axis=-1) 224 | self.unsafe_buffer.append(jtu.tree_map(lambda x: x[unsafe_mask], rollout)) 225 | self.buffer.append(rollout) 226 | 227 | # get update data 228 | rollout = tree_merge([memory, rollout]) 229 | rollout = jtu.tree_map(lambda x: merge01(x), rollout) 230 | rollout = tree_merge([unsafe_memory, rollout]) 231 | else: 232 | self.buffer.append(rollout) 233 | unsafe_mask = jax_vmap(jax_vmap(self._env.unsafe_mask))(rollout.graph).max(axis=-1) 234 | self.unsafe_buffer.append(jtu.tree_map(lambda x: x[unsafe_mask], rollout)) 235 | rollout = jtu.tree_map(lambda x: merge01(x), rollout) 236 | 237 | # inner loop 238 | update_info = {} 239 | for i_epoch in range(self.inner_epoch): 240 | idx = np.arange(rollout.length) 241 | np.random.shuffle(idx) 242 | batch_idx = jnp.array(jnp.array_split(idx, idx.shape[0] // self.batch_size)) 243 | 244 | cbf_train_state, actor_train_state, update_info = self.update_inner( 245 | self.cbf_train_state, self.actor_train_state, rollout, batch_idx 246 | ) 247 | self.cbf_train_state = cbf_train_state 248 | self.actor_train_state = actor_train_state 249 | 250 | return update_info 251 | 252 | @ft.partial(jax.jit, static_argnums=(0,), donate_argnums=(1, 2)) 253 | def update_inner( 254 | self, cbf_train_state: TrainState, actor_train_state: TrainState, rollout: Rollout, batch_idx: Array 255 | ) -> Tuple[TrainState, TrainState, dict]: 256 | 257 | def update_fn(carry, idx): 258 | """Update the actor and the CBF network for a single batch given the batch index.""" 259 | cbf, actor = carry 260 | rollout_batch = jtu.tree_map(lambda x: x[idx], rollout) 261 | 262 | def get_loss(cbf_params: Params, actor_params: Params) -> Tuple[Array, dict]: 263 | # get CBF values 264 | cbf_fn = jax_vmap(ft.partial(self.cbf.get_cbf, cbf_params)) 265 | h = cbf_fn(rollout_batch.graph).squeeze() 266 | h = merge01(h) 267 | 268 | # unsafe region h(x) < 0 269 | unsafe_mask = merge01(jax_vmap(self._env.unsafe_mask)(rollout_batch.graph)) 270 | unsafe_data_ratio = jnp.mean(unsafe_mask) 271 | 
h_unsafe = jnp.where(unsafe_mask, h, -jnp.ones_like(h) * self.eps * 2) 272 | max_val_unsafe = jax.nn.relu(h_unsafe + self.eps) 273 | loss_unsafe = jnp.sum(max_val_unsafe) / (jnp.count_nonzero(unsafe_mask) + 1e-6) 274 | acc_unsafe_mask = jnp.where(unsafe_mask, h, jnp.ones_like(h)) 275 | acc_unsafe = (jnp.sum(jnp.less(acc_unsafe_mask, 0)) + 1e-6) / (jnp.count_nonzero(unsafe_mask) + 1e-6) 276 | 277 | # safe region h(x) > 0 278 | safe_mask = merge01(jax_vmap(self._env.safe_mask)(rollout_batch.graph)) 279 | h_safe = jnp.where(safe_mask, h, jnp.ones_like(h) * self.eps * 2) 280 | max_val_safe = jax.nn.relu(-h_safe + self.eps) 281 | loss_safe = jnp.sum(max_val_safe) / (jnp.count_nonzero(safe_mask) + 1e-6) 282 | acc_safe_mask = jnp.where(safe_mask, h, -jnp.ones_like(h)) 283 | acc_safe = (jnp.sum(jnp.greater(acc_safe_mask, 0)) + 1e-6) / (jnp.count_nonzero(safe_mask) + 1e-6) 284 | 285 | # get actions 286 | action_fn = jax.vmap(ft.partial(self.actor.get_action, actor_params)) 287 | action = action_fn(rollout_batch.graph) 288 | 289 | # get next graph 290 | forward_fn = jax_vmap(self._env.forward_graph) 291 | next_graph = forward_fn(rollout_batch.graph, action) 292 | h_next = merge01(cbf_fn(next_graph)) 293 | h_dot = (h_next - h) / self._env.dt 294 | 295 | # h_dot + alpha * h > 0 296 | max_val_h_dot = jax.nn.relu(-h_dot - self.alpha * h + self.eps) 297 | loss_h_dot = jnp.mean(max_val_h_dot) # + jnp.max(max_val_h_dot) 298 | acc_h_dot = jnp.mean(jnp.greater(h_dot + self.alpha * h, 0)) 299 | 300 | # action loss 301 | u_ref = jax_vmap(self._env.u_ref)(rollout_batch.graph) 302 | loss_action = jnp.mean(jnp.square(action - u_ref).sum(axis=-1)) 303 | 304 | # total loss 305 | total_loss = ( 306 | self.loss_action_coef * loss_action 307 | + self.loss_unsafe_coef * loss_unsafe 308 | + self.loss_safe_coef * loss_safe 309 | + self.loss_h_dot_coef * loss_h_dot 310 | ) 311 | 312 | return total_loss, {'loss/action': loss_action, 313 | 'loss/unsafe': loss_unsafe, 314 | 'loss/safe': loss_safe, 315 | 'loss/h_dot': loss_h_dot, 316 | 'loss/total': total_loss, 317 | 'acc/unsafe': acc_unsafe, 318 | 'acc/safe': acc_safe, 319 | 'acc/h_dot': acc_h_dot, 320 | 'acc/unsafe_data_ratio': unsafe_data_ratio} 321 | 322 | (loss, loss_info), (grad_cbf, grad_actor) = jax.value_and_grad( 323 | get_loss, has_aux=True, argnums=(0, 1))(cbf.params, actor.params) 324 | grad_cbf_has_nan = has_any_nan(grad_cbf).astype(jnp.float32) 325 | grad_actor_has_nan = has_any_nan(grad_actor).astype(jnp.float32) 326 | grad_cbf, grad_cbf_norm = compute_norm_and_clip(grad_cbf, self.max_grad_norm) 327 | grad_actor, grad_actor_norm = compute_norm_and_clip(grad_actor, self.max_grad_norm) 328 | cbf = cbf.apply_gradients(grads=grad_cbf) 329 | actor = actor.apply_gradients(grads=grad_actor) 330 | return (cbf, actor), {'grad_norm/cbf': grad_cbf_norm, 331 | 'grad_norm/actor': grad_actor_norm, 332 | 'grad_has_nan/cbf': grad_cbf_has_nan, 333 | 'grad_has_nan/actor': grad_actor_has_nan} | loss_info 334 | 335 | (cbf_train_state, actor_train_state), info = jax.lax.scan( 336 | update_fn, (cbf_train_state, actor_train_state), batch_idx 337 | ) 338 | 339 | # get training info of the last epoch 340 | info = jtu.tree_map(lambda x: x[-1], info) 341 | 342 | return cbf_train_state, actor_train_state, info 343 | 344 | def save(self, save_dir: str, step: int): 345 | model_dir = os.path.join(save_dir, str(step)) 346 | if not os.path.exists(model_dir): 347 | os.makedirs(model_dir) 348 | pickle.dump(self.actor_train_state.params, open(os.path.join(model_dir, 'actor.pkl'), 'wb')) 349 | 
pickle.dump(self.cbf_train_state.params, open(os.path.join(model_dir, 'cbf.pkl'), 'wb')) 350 | 351 | def load(self, load_dir: str, step: int): 352 | path = os.path.join(load_dir, str(step)) 353 | 354 | self.actor_train_state = \ 355 | self.actor_train_state.replace(params=pickle.load(open(os.path.join(path, 'actor.pkl'), 'rb'))) 356 | self.cbf_train_state = \ 357 | self.cbf_train_state.replace(params=pickle.load(open(os.path.join(path, 'cbf.pkl'), 'rb'))) 358 | -------------------------------------------------------------------------------- /gcbfplus/algo/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/algo/module/__init__.py -------------------------------------------------------------------------------- /gcbfplus/algo/module/cbf.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import functools as ft 3 | 4 | from typing import Type 5 | from ...nn.gnn import GNN 6 | from ...nn.mlp import MLP 7 | from ...nn.utils import default_nn_init 8 | from ...utils.typing import Array, Params 9 | from ...utils.graph import GraphsTuple 10 | 11 | 12 | class CBFNet(nn.Module): 13 | gnn_cls: Type[GNN] 14 | head_cls: Type[nn.Module] 15 | 16 | @nn.compact 17 | def __call__(self, obs: GraphsTuple, n_agents: int, *args, **kwargs) -> Array: 18 | x = self.gnn_cls()(obs, node_type=0, n_type=n_agents) 19 | x = self.head_cls()(x) 20 | x = nn.tanh(nn.Dense(1, kernel_init=default_nn_init())(x)) 21 | return x 22 | 23 | 24 | class CBF: 25 | 26 | def __init__(self, node_dim: int, edge_dim: int, n_agents: int, gnn_layers: int): 27 | self.node_dim = node_dim 28 | self.edge_dim = edge_dim 29 | self.n_agents = n_agents 30 | 31 | self.cbf_gnn = ft.partial( 32 | GNN, 33 | msg_dim=128, 34 | hid_size_msg=(256, 256), 35 | hid_size_aggr=(128, 128), 36 | hid_size_update=(256, 256), 37 | out_dim=128, 38 | n_layers=gnn_layers 39 | ) 40 | self.cbf_head = ft.partial( 41 | MLP, 42 | hid_sizes=(256, 256), 43 | act=nn.relu, 44 | act_final=False, 45 | name='CBFHead' 46 | ) 47 | self.net = CBFNet( 48 | gnn_cls=self.cbf_gnn, 49 | head_cls=self.cbf_head 50 | ) 51 | 52 | def get_cbf(self, params: Params, obs: GraphsTuple) -> Array: 53 | return self.net.apply(params, obs, self.n_agents) 54 | -------------------------------------------------------------------------------- /gcbfplus/algo/module/distribution.py: -------------------------------------------------------------------------------- 1 | import tensorflow_probability.substrates.jax as tfp 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import jax.random as jr 5 | 6 | tfd = tfp.distributions 7 | tfb = tfp.bijectors 8 | 9 | 10 | class TanhTransformedDistribution(tfd.TransformedDistribution): 11 | 12 | def __init__(self, distribution: tfd.Distribution, threshold: float = 0.999, validate_args: bool = False): 13 | super().__init__(distribution=distribution, bijector=tfb.Tanh(), validate_args=validate_args) 14 | self._threshold = threshold 15 | self.inverse_threshold = self.bijector.inverse(threshold) 16 | 17 | inverse_threshold = self.bijector.inverse(threshold) 18 | # average(pdf) = p / epsilon 19 | # So log(average(pdf)) = log(p) - log(epsilon) 20 | log_epsilon = np.log(1.0 - threshold) 21 | 22 | self._log_prob_left = self.distribution.log_cdf(-inverse_threshold) - log_epsilon 23 | self._log_prob_right = self.distribution.log_survival_function(inverse_threshold) - 
log_epsilon 24 | 25 | def log_prob(self, value, name='log_prob', **kwargs): 26 | # Without this clip there would be NaNs in the inner tf.where and that 27 | # causes issues for some reasons. 28 | value = jnp.clip(value, -self._threshold, self._threshold) 29 | # The inverse image of {threshold} is the interval [atanh(threshold), inf] 30 | # which has a probability of "log_prob_right" under the given distribution. 31 | return jnp.where( 32 | value <= -self._threshold, 33 | self._log_prob_left, 34 | jnp.where(value >= self._threshold, self._log_prob_right, super().log_prob(value)), 35 | ) 36 | 37 | def entropy(self, name='entropy', **kwargs): 38 | # We return an estimation using a single sample of the log_det_jacobian. 39 | # We can still do some backpropagation with this estimate. 40 | seed = np.random.randint(0, 102400) 41 | return self.distribution.entropy() + self.bijector.forward_log_det_jacobian( 42 | self.distribution.sample(seed=jr.PRNGKey(seed)), event_ndims=0 43 | ) 44 | 45 | def _mode(self) -> jnp.ndarray: 46 | return self.bijector.forward(self.distribution.mode()) 47 | 48 | @classmethod 49 | def _parameter_properties(cls, dtype, num_classes=None): 50 | td_properties = super()._parameter_properties(dtype, num_classes=num_classes) 51 | del td_properties["bijector"] 52 | return td_properties 53 | 54 | @property 55 | def experimental_is_sharded(self): 56 | raise NotImplementedError 57 | 58 | def _sample_n(self, n, seed=None, **kwargs): 59 | pass 60 | 61 | def _variance(self, **kwargs): 62 | pass 63 | 64 | @classmethod 65 | def _maximum_likelihood_parameters(cls, value): 66 | pass 67 | -------------------------------------------------------------------------------- /gcbfplus/algo/module/policy.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import functools as ft 3 | import numpy as np 4 | import jax.nn as jnn 5 | import jax.numpy as jnp 6 | 7 | from typing import Type, Tuple 8 | from abc import ABC, abstractproperty, abstractmethod 9 | 10 | from .distribution import TanhTransformedDistribution, tfd 11 | from ...utils.typing import Action, Array 12 | from ...utils.graph import GraphsTuple 13 | from ...nn.utils import default_nn_init, scaled_init 14 | from ...nn.gnn import GNN 15 | from ...nn.mlp import MLP 16 | from ...utils.typing import PRNGKey, Params 17 | 18 | 19 | class PolicyDistribution(nn.Module, ABC): 20 | 21 | @abstractmethod 22 | def __call__(self, *args, **kwargs) -> tfd.Distribution: 23 | pass 24 | 25 | @abstractproperty 26 | def nu(self) -> int: 27 | pass 28 | 29 | 30 | class TanhNormal(PolicyDistribution): 31 | base_cls: Type[GNN] 32 | _nu: int 33 | scale_final: float = 0.01 34 | std_dev_min: float = 1e-5 35 | std_dev_init: float = 0.5 36 | 37 | @property 38 | def std_dev_init_inv(self): 39 | # inverse of log(sum(exp())). 
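# Concretely, softplus(x) = log(1 + exp(x)) = logaddexp(x, 0), so its inverse is log(exp(y) - 1); the assert below checks that softplus(inv) recovers std_dev_init.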
40 | inv = np.log(np.exp(self.std_dev_init) - 1) 41 | assert np.allclose(np.logaddexp(inv, 0), self.std_dev_init) 42 | return inv 43 | 44 | @nn.compact 45 | def __call__(self, obs: GraphsTuple, n_agents: int, *args, **kwargs) -> tfd.Distribution: 46 | x = self.base_cls()(obs, node_type=0, n_type=n_agents) 47 | # x = x.nodes 48 | scaler_init = scaled_init(default_nn_init(), self.scale_final) 49 | feats_scaled = nn.Dense(256, kernel_init=scaler_init, name="ScaleHid")(x) 50 | 51 | means = nn.Dense(self.nu, kernel_init=default_nn_init(), name="OutputDenseMean")(feats_scaled) 52 | stds_trans = nn.Dense(self.nu, kernel_init=default_nn_init(), name="OutputDenseStdTrans")(feats_scaled) 53 | stds = jnn.softplus(stds_trans + self.std_dev_init_inv) + self.std_dev_min 54 | 55 | distribution = tfd.Normal(loc=means, scale=stds) 56 | return tfd.Independent(TanhTransformedDistribution(distribution), reinterpreted_batch_ndims=1) 57 | 58 | @property 59 | def nu(self): 60 | return self._nu 61 | 62 | 63 | class Deterministic(nn.Module): 64 | base_cls: Type[GNN] 65 | head_cls: Type[nn.Module] 66 | _nu: int 67 | 68 | @nn.compact 69 | def __call__(self, obs: GraphsTuple, n_agents: int, *args, **kwargs) -> Action: 70 | x = self.base_cls()(obs, node_type=0, n_type=n_agents) 71 | x = self.head_cls()(x) 72 | x = nn.tanh(nn.Dense(self._nu, kernel_init=default_nn_init(), name="OutputDense")(x)) 73 | return x 74 | 75 | 76 | class MultiAgentPolicy(ABC): 77 | 78 | def __init__(self, node_dim: int, edge_dim: int, n_agents: int, action_dim: int): 79 | self.node_dim = node_dim 80 | self.edge_dim = edge_dim 81 | self.n_agents = n_agents 82 | self.action_dim = action_dim 83 | 84 | @abstractmethod 85 | def get_action(self, params: Params, obs: GraphsTuple) -> Action: 86 | pass 87 | 88 | @abstractmethod 89 | def sample_action(self, params: Params, obs: GraphsTuple, key: PRNGKey) -> Tuple[Action, Array]: 90 | pass 91 | 92 | @abstractmethod 93 | def eval_action(self, params: Params, obs: GraphsTuple, action: Action, key: PRNGKey) -> Tuple[Array, Array]: 94 | pass 95 | 96 | 97 | class DeterministicPolicy(MultiAgentPolicy): 98 | 99 | def __init__( 100 | self, 101 | node_dim: int, 102 | edge_dim: int, 103 | n_agents: int, 104 | action_dim: int, 105 | gnn_layers: int = 1, 106 | ): 107 | super().__init__(node_dim, edge_dim, n_agents, action_dim) 108 | self.policy_base = ft.partial( 109 | GNN, 110 | msg_dim=128, 111 | hid_size_msg=(256, 256), 112 | hid_size_aggr=(128, 128), 113 | hid_size_update=(256, 256), 114 | out_dim=128, 115 | n_layers=gnn_layers 116 | ) 117 | self.policy_head = ft.partial( 118 | MLP, 119 | hid_sizes=(256, 256), 120 | act=nn.relu, 121 | act_final=False, 122 | name='PolicyHead' 123 | ) 124 | self.net = Deterministic(base_cls=self.policy_base, head_cls=self.policy_head, _nu=action_dim) 125 | self.std = 0.1 126 | 127 | def get_action(self, params: Params, obs: GraphsTuple) -> Action: 128 | return self.net.apply(params, obs, self.n_agents) 129 | 130 | def sample_action(self, params: Params, obs: GraphsTuple, key: PRNGKey) -> Tuple[Action, Array]: 131 | action = self.get_action(params, obs) 132 | log_pi = jnp.zeros_like(action) 133 | return action, log_pi 134 | 135 | def eval_action(self, params: Params, obs: GraphsTuple, action: Action, key: PRNGKey) -> Tuple[Array, Array]: 136 | raise NotImplementedError 137 | 138 | 139 | class PPOPolicy(MultiAgentPolicy): 140 | 141 | def __init__( 142 | self, 143 | node_dim: int, 144 | edge_dim: int, 145 | n_agents: int, 146 | action_dim: int, 147 | gnn_layers: int = 1, 148 | ): 149 
| super().__init__(node_dim, edge_dim, n_agents, action_dim) 150 | self.dist_base = ft.partial( 151 | GNN, 152 | msg_dim=64, 153 | hid_size_msg=(128, 128), 154 | hid_size_aggr=(128, 128), 155 | hid_size_update=(128, 128), 156 | out_dim=64, 157 | n_layers=gnn_layers 158 | ) 159 | self.dist = TanhNormal(base_cls=self.dist_base, _nu=action_dim) 160 | 161 | def get_action(self, params: Params, obs: GraphsTuple) -> Action: 162 | dist = self.dist.apply(params, obs, n_agents=self.n_agents) 163 | action = dist.mode() 164 | return action 165 | 166 | def sample_action(self, params: Params, obs: GraphsTuple, key: PRNGKey) -> Tuple[Action, Array]: 167 | dist = self.dist.apply(params, obs, n_agents=self.n_agents) 168 | action = dist.sample(seed=key) 169 | log_pi = dist.log_prob(action) 170 | return action, log_pi 171 | 172 | def eval_action(self, params: Params, obs: GraphsTuple, action: Action, key: PRNGKey) -> Tuple[Array, Array]: 173 | dist = self.dist.apply(params, obs, n_agents=self.n_agents) 174 | log_pi = dist.log_prob(action) 175 | entropy = dist.entropy(seed=key) 176 | return log_pi, entropy 177 | -------------------------------------------------------------------------------- /gcbfplus/algo/module/value.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import flax.linen as nn 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from typing import Type 7 | 8 | from ...nn.mlp import MLP 9 | from ...nn.gnn import GNN 10 | from ...nn.utils import default_nn_init 11 | from ...utils.typing import Array, Params 12 | from ...utils.graph import GraphsTuple 13 | 14 | 15 | class StateFn(nn.Module): 16 | gnn_cls: Type[GNN] 17 | aggr_cls: Type[nn.Module] 18 | head_cls: Type[nn.Module] 19 | 20 | @nn.compact 21 | def __call__(self, obs: GraphsTuple, n_agents: int, *args, **kwargs) -> Array: 22 | # get node features 23 | x = self.gnn_cls()(obs, node_type=0, n_type=n_agents) 24 | 25 | # aggregate information using attention 26 | gate_feats = self.aggr_cls()(x) 27 | gate_feats = nn.Dense(1, kernel_init=default_nn_init())(gate_feats).squeeze(-1) 28 | attn = jax.nn.softmax(gate_feats, axis=-1) 29 | x = jnp.sum(attn[:, None] * x, axis=0) 30 | 31 | # pass through head class 32 | x = self.head_cls()(x) 33 | x = nn.Dense(1, kernel_init=default_nn_init())(x) 34 | 35 | return x 36 | 37 | 38 | class ValueNet: 39 | 40 | def __init__(self, node_dim: int, edge_dim: int, n_agents: int, gnn_layers: int = 1): 41 | self.node_dim = node_dim 42 | self.edge_dim = edge_dim 43 | self.n_agents = n_agents 44 | self.value_gnn = ft.partial( 45 | GNN, 46 | msg_dim=64, 47 | hid_size_msg=(128, 128), 48 | hid_size_aggr=(128, 128), 49 | hid_size_update=(128, 128), 50 | out_dim=64, 51 | n_layers=gnn_layers 52 | ) 53 | self.value_attn = ft.partial( 54 | MLP, 55 | hid_sizes=(128, 128), 56 | act=nn.relu, 57 | act_final=False, 58 | name='ValueAttn' 59 | ) 60 | self.value_head = ft.partial( 61 | MLP, 62 | hid_sizes=(128, 128), 63 | act=nn.relu, 64 | act_final=False, 65 | name='ValueHead' 66 | ) 67 | # self.net = StateFn(, _nu=1) 68 | self.net = StateFn( 69 | gnn_cls=self.value_gnn, 70 | aggr_cls=self.value_attn, 71 | head_cls=self.value_head 72 | ) 73 | 74 | def get_value(self, params: Params, obs: GraphsTuple) -> Array: 75 | values = self.net.apply(params, obs, self.n_agents) 76 | return values.squeeze() 77 | -------------------------------------------------------------------------------- /gcbfplus/algo/utils.py: 
-------------------------------------------------------------------------------- 1 | import functools as ft 2 | import einops as ei 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from typing import Tuple 7 | 8 | from ..env.base import MultiAgentEnv 9 | from ..env.double_integrator import DoubleIntegrator 10 | from ..env.dubins_car import DubinsCar 11 | from ..env.linear_drone import LinearDrone 12 | from ..env.single_integrator import SingleIntegrator 13 | from ..env.crazyflie import CrazyFlie 14 | from ..utils.graph import GraphsTuple 15 | from ..utils.typing import Array, Done, Reward 16 | 17 | 18 | def compute_gae_fn( 19 | values: Array, rewards: Reward, dones: Done, next_values: Array, gamma: float, gae_lambda: float 20 | ) -> Tuple[Array, Array]: 21 | """ 22 | Compute generalized advantage estimation. 23 | """ 24 | deltas = rewards + gamma * next_values * (1 - dones) - values 25 | gaes = deltas 26 | 27 | def scan_fn(gae, inp): 28 | delta, done = inp 29 | gae_prev = delta + gamma * gae_lambda * (1 - done) * gae 30 | return gae_prev, gae_prev 31 | 32 | _, gaes_prev = jax.lax.scan(scan_fn, gaes[-1], (deltas[:-1], dones[:-1]), reverse=True) 33 | gaes = jnp.concatenate([gaes_prev, gaes[-1, None]], axis=0) 34 | 35 | return gaes + values, (gaes - gaes.mean()) / (gaes.std() + 1e-8) 36 | 37 | 38 | def compute_gae( 39 | values: Array, rewards: Reward, dones: Done, next_values: Array, gamma: float, gae_lambda: float 40 | ) -> Tuple[Array, Array]: 41 | return jax.vmap(ft.partial(compute_gae_fn, gamma=gamma, gae_lambda=gae_lambda))(values, rewards, dones, next_values) 42 | 43 | 44 | def pwise_cbf_single_integrator_(pos: Array, agent_idx: int, o_obs_pos: Array, a_pos: Array, r: float, k: int): 45 | n_agent = len(a_pos) 46 | 47 | all_obs_pos = jnp.concatenate([a_pos, o_obs_pos], axis=0) 48 | 49 | # Only consider the k obstacles. 50 | o_dist_sq = ((pos - all_obs_pos) ** 2).sum(axis=-1) 51 | # Remove self collisions 52 | o_dist_sq = o_dist_sq.at[agent_idx].set(1e2) 53 | 54 | # Take the k closest obstacles. 55 | k_idx = jnp.argsort(o_dist_sq)[:k] 56 | k_dist_sq = o_dist_sq[k_idx] 57 | # Take radius into account. Add some epsilon for qp solver error. 58 | k_dist_sq = k_dist_sq - 4 * (1.01 * r) ** 2 59 | 60 | k_h0 = k_dist_sq 61 | k_isobs = k_idx >= n_agent 62 | 63 | return k_h0, k_isobs 64 | 65 | 66 | def pwise_cbf_single_integrator(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int): 67 | # (n_agents, 2) 68 | a_states = graph.type_states(type_idx=0, n_type=n_agent) 69 | # (n_obs, 2) 70 | obs_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays) 71 | a_obs_states = ei.rearrange(obs_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent) 72 | 73 | agent_idx = jnp.arange(n_agent) 74 | fn = jax.vmap(ft.partial(pwise_cbf_single_integrator_, r=r, k=k), in_axes=(0, 0, 0, None)) 75 | ak_h0, ak_isobs = fn(a_states, agent_idx, a_obs_states, a_states) 76 | return ak_h0, ak_isobs 77 | 78 | 79 | def pwise_cbf_double_integrator_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int): 80 | n_agent = len(a_state) 81 | 82 | pos = state[:2] 83 | all_obs_state = jnp.concatenate([a_state, o_obs_state], axis=0) 84 | all_obs_pos = all_obs_state[:, :2] 85 | del o_obs_state 86 | 87 | # Only consider the k closest obstacles. 88 | o_dist_sq = ((pos - all_obs_pos) ** 2).sum(axis=-1) 89 | # Remove self collisions 90 | o_dist_sq = o_dist_sq.at[agent_idx].set(1e2) 91 | # Take the k closest obstacles. 
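# For the double integrator, h0 = ||dx||^2 - (2r)^2 has relative degree 2, so the value returned below is the higher-order barrier h1 = h0_dot + 10 * h0, with h0_dot = 2 * dx . dv.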
92 | k_idx = jnp.argsort(o_dist_sq)[:k] 93 | k_dist_sq = o_dist_sq[k_idx] 94 | # Take radius into account. 95 | k_dist_sq = k_dist_sq - 4 * r ** 2 96 | 97 | k_h0 = k_dist_sq 98 | assert k_h0.shape == (k,) 99 | 100 | k_xdiff = state[:2] - all_obs_state[k_idx][:, :2] 101 | k_vdiff = state[2:] - all_obs_state[k_idx][:, 2:] 102 | assert k_xdiff.shape == k_vdiff.shape == (k, 2) 103 | 104 | k_h0_dot = 2 * (k_xdiff * k_vdiff).sum(axis=-1) 105 | assert k_h0_dot.shape == (k,) 106 | 107 | k_h1 = k_h0_dot + 10.0 * k_h0 108 | 109 | k_isobs = k_idx >= n_agent 110 | 111 | return k_h1, k_isobs 112 | 113 | 114 | def pwise_cbf_double_integrator(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int): 115 | # (n_agents, 4) 116 | a_states = graph.type_states(type_idx=0, n_type=n_agent) 117 | # (n_obs, 4) 118 | obs_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays) 119 | a_obs_states = ei.rearrange(obs_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent) 120 | 121 | agent_idx = jnp.arange(n_agent) 122 | fn = jax.vmap(ft.partial(pwise_cbf_double_integrator_, r=r, k=k), in_axes=(0, 0, 0, None)) 123 | ak_h0, ak_isobs = fn(a_states, agent_idx, a_obs_states, a_states) 124 | return ak_h0, ak_isobs 125 | 126 | 127 | def pwise_cbf_dubins_car_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int): 128 | n_agent = len(a_state) 129 | n_obs = len(o_obs_state) 130 | 131 | pos = state[:2] 132 | vel = state[3] * jnp.array([jnp.cos(state[2]), jnp.sin(state[2])]) 133 | assert vel.shape == (2,) 134 | 135 | agent_vel = a_state[:, 3, None] * jnp.stack([jnp.cos(a_state[:, 2]), jnp.sin(a_state[:, 2])], axis=-1) 136 | assert agent_vel.shape == (n_agent, 2) 137 | 138 | all_obs_pos = jnp.concatenate([a_state[:, :2], o_obs_state[:, :2]], axis=0) 139 | all_obs_vel = jnp.concatenate([agent_vel, jnp.zeros((n_obs, 2))], axis=0) 140 | del o_obs_state 141 | 142 | # Only consider the k closest obstacles. 143 | o_dist_sq = ((pos - all_obs_pos) ** 2).sum(axis=-1) 144 | # Remove self collisions 145 | o_dist_sq = o_dist_sq.at[agent_idx].set(1e2) 146 | # Take the k closest obstacles. 147 | k_idx = jnp.argsort(o_dist_sq)[:k] 148 | k_dist_sq = o_dist_sq[k_idx] 149 | # Take radius into account. Add some epsilon for qp solver error. 
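# (Note: the Dubins-car case uses the plain radius; no 1.01 margin factor is applied here, unlike the single-integrator and drone variants.)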
150 | k_dist_sq = k_dist_sq - 4 * r ** 2 151 | 152 | k_h0 = k_dist_sq 153 | assert k_h0.shape == (k,) 154 | 155 | k_xdiff = state[:2] - all_obs_pos[k_idx] 156 | k_vdiff = agent_vel[agent_idx] - all_obs_vel[k_idx] 157 | assert k_xdiff.shape == k_vdiff.shape == (k, 2) 158 | 159 | k_h0_dot = 2 * (k_xdiff * k_vdiff).sum(axis=-1) 160 | assert k_h0_dot.shape == (k,) 161 | 162 | k_h1 = k_h0_dot + 5.0 * k_h0 163 | 164 | k_isobs = k_idx >= n_agent 165 | 166 | return k_h1, k_isobs 167 | 168 | 169 | def pwise_cbf_dubins_car(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int): 170 | # (n_agents, 4) 171 | a_states = graph.type_states(type_idx=0, n_type=n_agent) 172 | # (n_obs, 4) 173 | obs_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays) 174 | a_obs_states = ei.rearrange(obs_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent) 175 | 176 | agent_idx = jnp.arange(n_agent) 177 | fn = jax.vmap(ft.partial(pwise_cbf_dubins_car_, r=r, k=k), in_axes=(0, 0, 0, None)) 178 | ak_h0, ak_isobs = fn(a_states, agent_idx, a_obs_states, a_states) 179 | return ak_h0, ak_isobs 180 | 181 | 182 | def pwise_cbf_crazyflie_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int): 183 | # state: ( 12, ) 184 | n_agent = len(a_state) 185 | 186 | pos = state[:3] 187 | all_obs_state = jnp.concatenate([a_state, o_obs_state], axis=0) 188 | all_obs_pos = all_obs_state[:, :3] 189 | del o_obs_state 190 | 191 | # Only consider the k closest obstacles. 192 | o_dist_sq = ((pos - all_obs_pos) ** 2).sum(axis=-1) 193 | # Remove self collisions 194 | o_dist_sq = o_dist_sq.at[agent_idx].set(1e2) 195 | # Take the k closest obstacles. 196 | k_idx = jnp.argsort(o_dist_sq)[:k] 197 | k_dist_sq = o_dist_sq[k_idx] 198 | # Take radius into account. Add some epsilon for qp solver error. 199 | k_dist_sq = k_dist_sq - 4 * (1.01 * r) ** 2 200 | 201 | k_h0 = k_dist_sq 202 | assert k_h0.shape == (k,) 203 | 204 | # all_obs_state = all_obs_state[k_idx] 205 | 206 | def crazyflie_f_(x: Array) -> Array: 207 | Ixx, Iyy, Izz = CrazyFlie.PARAMS["Ixx"], CrazyFlie.PARAMS["Iyy"], CrazyFlie.PARAMS["Izz"] 208 | I = jnp.array([Ixx, Iyy, Izz]) 209 | # roll, pitch, yaw 210 | phi, theta, psi = x[CrazyFlie.PHI], x[CrazyFlie.THETA], x[CrazyFlie.PSI] 211 | c_phi, s_phi = jnp.cos(phi), jnp.sin(phi) 212 | c_th, s_th = jnp.cos(theta), jnp.sin(theta) 213 | c_psi, s_psi = jnp.cos(psi), jnp.sin(psi) 214 | t_th = jnp.tan(theta) 215 | 216 | u, v, w = x[CrazyFlie.U], x[CrazyFlie.V], x[CrazyFlie.W] 217 | uvw = jnp.array([u, v, w]) 218 | 219 | p, q, r = x[CrazyFlie.P], x[CrazyFlie.Q], x[CrazyFlie.R] 220 | pqr = jnp.array([p, q, r]) 221 | 222 | # Linear velocity 223 | R_W_cf = jnp.array( 224 | [ 225 | [c_psi * c_th, c_psi * s_th * s_phi - s_psi * c_phi, c_psi * s_th * c_phi + s_psi * s_phi], 226 | [s_psi * c_th, s_psi * s_th * s_phi + c_psi * c_phi, s_psi * s_th * c_phi - c_psi * s_phi], 227 | [-s_th, c_th * s_phi, c_th * c_phi], 228 | ] 229 | ) 230 | v_Wcf_cf = jnp.array([u, v, w]) 231 | v_Wcf_W = R_W_cf @ v_Wcf_cf 232 | assert v_Wcf_W.shape == (3,) 233 | 234 | # Euler angle dynamics. 235 | mat = jnp.array( 236 | [ 237 | [0, s_phi / c_th, c_phi / c_th], 238 | [0, c_phi, -s_phi], 239 | [1, s_phi * t_th, c_phi * t_th], 240 | ] 241 | ) 242 | deuler_rpy = mat @ pqr 243 | deuler_ypr = deuler_rpy[::-1] 244 | 245 | # Body frame linear acceleration. 246 | acc_cf_g = -R_W_cf[2, :] * 9.81 247 | acc_cf = -jnp.cross(pqr, uvw) + acc_cf_g 248 | 249 | # Body frame angular acceleration. 
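# Euler's rigid-body equations with diagonal inertia, omega_dot = I^{-1} (-omega x (I omega)), evaluated elementwise since I is stored as the vector (Ixx, Iyy, Izz).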
250 | pqr_dot = -jnp.cross(pqr, I * pqr) / I 251 | rpq_dot = pqr_dot[::-1] 252 | assert pqr_dot.shape == (3,) 253 | 254 | x_dot = jnp.concatenate([v_Wcf_W, deuler_ypr, acc_cf, rpq_dot], axis=0) 255 | return x_dot 256 | 257 | def h0(x, obs_x): 258 | k_xdiff = x[:3] - obs_x[:, :3] 259 | dist = jnp.square(k_xdiff).sum(axis=-1) 260 | dist = dist.at[agent_idx].set(1e2) 261 | return dist[k_idx] - 4 * r ** 2 # (k,) 262 | 263 | def h1(x, obs_x): 264 | x_dot = crazyflie_f_(x) # (nx,) 265 | obs_x_dot = jax.vmap(crazyflie_f_)(obs_x) # (k, nx) 266 | 267 | h0_x = jax.jacfwd(h0, argnums=0)(x, obs_x) # (k, nx) 268 | h0_obs_x = jax.jacfwd(h0, argnums=1)(x, obs_x) # (k, k, nx) 269 | h0_dot = ei.einsum(h0_x, x_dot, 'k nx, nx -> k') + \ 270 | ei.einsum(h0_obs_x, obs_x_dot, 'k1 k2 nx, k2 nx -> k1') # (k,) 271 | return h0_dot + 30.0 * h0(x, obs_x) 272 | 273 | def h2(x, obs_x): 274 | x_dot = crazyflie_f_(x) 275 | obs_x_dot = jax.vmap(crazyflie_f_)(obs_x) 276 | h1_x = jax.jacfwd(h1, argnums=0)(x, obs_x) # (k, nx) 277 | h1_obs_x = jax.jacfwd(h1, argnums=1)(x, obs_x) # (k, k, nx) 278 | h1_dot = ei.einsum(h1_x, x_dot, 'k nx, nx -> k') + \ 279 | ei.einsum(h1_obs_x, obs_x_dot, 'k1 k2 nx, k2 nx -> k1') # (k,) 280 | return h1_dot + 50.0 * h1(x, obs_x) 281 | 282 | k_h2 = h2(state, all_obs_state) 283 | assert k_h2.shape == (k,) 284 | 285 | k_isobs = k_idx >= n_agent 286 | 287 | return k_h2, k_isobs 288 | 289 | 290 | def pwise_cbf_crazyflie(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int): 291 | # (n_agents, 4) 292 | a_states = graph.type_states(type_idx=0, n_type=n_agent) 293 | # (n_obs, 4) 294 | obs_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays) 295 | a_obs_states = ei.rearrange(obs_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent) 296 | 297 | agent_idx = jnp.arange(n_agent) 298 | fn = jax.vmap(ft.partial(pwise_cbf_crazyflie_, r=r, k=k), in_axes=(0, 0, 0, None)) 299 | ak_h0, ak_isobs = fn(a_states, agent_idx, a_obs_states, a_states) 300 | return ak_h0, ak_isobs 301 | 302 | 303 | def pwise_cbf_linear_drone_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int): 304 | # state: ( 6, ) 305 | n_agent = len(a_state) 306 | 307 | pos = state[:3] 308 | all_obs_state = jnp.concatenate([a_state, o_obs_state], axis=0) 309 | all_obs_pos = all_obs_state[:, :3] 310 | del o_obs_state 311 | 312 | # Only consider the k closest obstacles. 313 | o_dist_sq = ((pos - all_obs_pos) ** 2).sum(axis=-1) 314 | # Remove self collisions 315 | o_dist_sq = o_dist_sq.at[agent_idx].set(1e2) 316 | # Take the k closest obstacles. 317 | k_idx = jnp.argsort(o_dist_sq)[:k] 318 | k_dist_sq = o_dist_sq[k_idx] 319 | # Take radius into account. Add some epsilon for qp solver error. 
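    # Editor's note (comment sketch): the 1.01 factor inflates the radius by about 1%;
    # this is the "epsilon" slack mentioned above, so small QP solver violations are
    # not immediately flagged as collisions.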
320 |     k_dist_sq = k_dist_sq - 4 * (1.01 * r) ** 2
321 | 
322 |     k_h0 = k_dist_sq
323 |     assert k_h0.shape == (k,)
324 | 
325 |     k_xdiff = state[:3] - all_obs_state[k_idx][:, :3]
326 |     k_vdiff = state[3:6] - all_obs_state[k_idx][:, 3:6]
327 |     assert k_xdiff.shape == k_vdiff.shape == (k, 3)
328 | 
329 |     k_h0_dot = 2 * (k_xdiff * k_vdiff).sum(axis=-1)
330 |     assert k_h0_dot.shape == (k,)
331 | 
332 |     k_h1 = k_h0_dot + 3.0 * k_h0
333 | 
334 |     k_isobs = k_idx >= n_agent
335 | 
336 |     return k_h1, k_isobs
337 | 
338 | 
339 | def pwise_cbf_linear_drone(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int):
340 |     # (n_agents, 6)
341 |     a_states = graph.type_states(type_idx=0, n_type=n_agent)
342 |     # (n_obs, 6)
343 |     obs_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays)
344 |     a_obs_states = ei.rearrange(obs_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent)
345 | 
346 |     agent_idx = jnp.arange(n_agent)
347 |     fn = jax.vmap(ft.partial(pwise_cbf_linear_drone_, r=r, k=k), in_axes=(0, 0, 0, None))
348 |     ak_h0, ak_isobs = fn(a_states, agent_idx, a_obs_states, a_states)
349 |     return ak_h0, ak_isobs
350 | 
351 | 
352 | def pwise_cbf_cfhl_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int):
353 |     # state: (12,)
354 |     n_agent = len(a_state)
355 | 
356 |     pos = state[:3]
357 |     all_obs_state = jnp.concatenate([a_state, o_obs_state], axis=0)
358 |     all_obs_pos = all_obs_state[:, :3]
359 |     del o_obs_state
360 | 
361 |     # Only consider the k closest obstacles.
362 |     o_dist_sq = ((pos - all_obs_pos) ** 2).sum(axis=-1)
363 |     # Remove self collisions
364 |     o_dist_sq = o_dist_sq.at[agent_idx].set(1e2)
365 |     # Take the k closest obstacles.
366 |     k_idx = jnp.argsort(o_dist_sq)[:k]
367 |     k_dist_sq = o_dist_sq[k_idx]
368 |     # Take radius into account. Add some epsilon for qp solver error.
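    # Editor's note (comment sketch): same 1% inflated radius as the other drone
    # variants; the velocities used for h0_dot are recovered in the world frame below
    # via the CrazyFlie rotation matrix, since the 12-dim state stores body-frame
    # velocities (u, v, w).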
369 |     k_dist_sq = k_dist_sq - 4 * (1.01 * r) ** 2
370 | 
371 |     k_h0 = k_dist_sq
372 |     assert k_h0.shape == (k,)
373 | 
374 |     k_xdiff = state[:3] - all_obs_state[k_idx][:, :3]
375 | 
376 |     def get_v_single_(x):
377 |         u = x[CrazyFlie.U]
378 |         v = x[CrazyFlie.V]
379 |         w = x[CrazyFlie.W]
380 | 
381 |         R_W_cf = CrazyFlie.get_rotation_mat(x)
382 |         v_Wcf_cf = jnp.array([u, v, w])
383 |         v_Wcf_W = R_W_cf @ v_Wcf_cf  # world frame velocity
384 |         return v_Wcf_W
385 | 
386 |     k_vdiff = get_v_single_(state) - jax.vmap(get_v_single_)(all_obs_state[k_idx])
387 | 
388 |     assert k_xdiff.shape == k_vdiff.shape == (k, 3)
389 | 
390 |     k_h0_dot = 2 * (k_xdiff * k_vdiff).sum(axis=-1)
391 |     assert k_h0_dot.shape == (k,)
392 | 
393 |     k_h1 = k_h0_dot + 3 * k_h0
394 | 
395 |     k_isobs = k_idx >= n_agent
396 | 
397 |     return k_h1, k_isobs
398 | 
399 | 
400 | def pwise_cbf_cfhl(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int):
401 |     # (n_agents, 12)
402 |     a_states = graph.type_states(type_idx=0, n_type=n_agent)
403 |     # (n_obs, 12)
404 |     obs_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays)
405 |     a_obs_states = ei.rearrange(obs_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent)
406 | 
407 |     agent_idx = jnp.arange(n_agent)
408 |     fn = jax.vmap(ft.partial(pwise_cbf_cfhl_, r=r, k=k), in_axes=(0, 0, 0, None))
409 |     ak_h0, ak_isobs = fn(a_states, agent_idx, a_obs_states, a_states)
410 |     return ak_h0, ak_isobs
411 | 
412 | 
413 | def get_pwise_cbf_fn(env: MultiAgentEnv, k: int = 3):
414 |     if isinstance(env, SingleIntegrator):
415 |         n_agent = env.num_agents
416 |         n_rays = env.params["n_rays"]
417 |         r = env.params["car_radius"]
418 |         return ft.partial(pwise_cbf_single_integrator, r=r, n_agent=n_agent, n_rays=n_rays, k=k)
419 |     elif isinstance(env, DoubleIntegrator):
420 |         n_agent = env.num_agents
421 |         n_rays = env.params["n_rays"]
422 |         r = env.params["car_radius"]
423 |         return ft.partial(pwise_cbf_double_integrator, r=r, n_agent=n_agent, n_rays=n_rays, k=k)
424 |     elif isinstance(env, DubinsCar):
425 |         r = env.params["car_radius"]
426 |         n_agent = env.num_agents
427 |         n_rays = env.params["n_rays"]
428 |         return ft.partial(pwise_cbf_dubins_car, r=r, n_agent=n_agent, n_rays=n_rays, k=k)
429 |     elif isinstance(env, CrazyFlie):
430 |         return ft.partial(pwise_cbf_crazyflie, r=env.params["drone_radius"], n_agent=env.num_agents, n_rays=env.n_rays,
431 |                           k=k)
432 |     elif isinstance(env, LinearDrone):
433 |         return ft.partial(pwise_cbf_linear_drone, r=env.params["drone_radius"], n_agent=env.num_agents,
434 |                           n_rays=env.n_rays, k=k)
435 |     elif isinstance(env, CrazyFlie):  # NOTE: unreachable; shadowed by the CrazyFlie branch above, so cfhl is never selected.
436 |         return ft.partial(pwise_cbf_cfhl, r=env.params["drone_radius"], n_agent=env.num_agents,
437 |                           n_rays=env.n_rays, k=k)
438 | 
439 |     raise NotImplementedError(f"No pairwise CBF implemented for environment {type(env).__name__}.")
440 | 
-------------------------------------------------------------------------------- /gcbfplus/env/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | from .base import MultiAgentEnv
4 | from .single_integrator import SingleIntegrator
5 | from .double_integrator import DoubleIntegrator
6 | from .linear_drone import LinearDrone
7 | from .dubins_car import DubinsCar
8 | from .crazyflie import CrazyFlie
9 | 
10 | 
11 | ENV = {
12 |     'SingleIntegrator': SingleIntegrator,
13 |     'DoubleIntegrator': DoubleIntegrator,
14 |     'LinearDrone': LinearDrone,
15 |     'DubinsCar': DubinsCar,
16 |     'CrazyFlie': CrazyFlie,
17 | }
18 | 
19 | 
20 | DEFAULT_MAX_STEP = 256
21 | 
22 | 
23 | def make_env(
24 |     env_id: str,
25 |     num_agents: int,
26
| area_size: float = None, 27 | max_step: int = None, 28 | max_travel: Optional[float] = None, 29 | num_obs: Optional[int] = None, 30 | n_rays: Optional[int] = None, 31 | ) -> MultiAgentEnv: 32 | assert env_id in ENV.keys(), f'Environment {env_id} not implemented.' 33 | params = ENV[env_id].PARAMS 34 | max_step = DEFAULT_MAX_STEP if max_step is None else max_step 35 | if num_obs is not None: 36 | params['n_obs'] = num_obs 37 | if n_rays is not None: 38 | params['n_rays'] = n_rays 39 | return ENV[env_id]( 40 | num_agents=num_agents, 41 | area_size=area_size, 42 | max_step=max_step, 43 | max_travel=max_travel, 44 | dt=0.03, 45 | params=params 46 | ) 47 | -------------------------------------------------------------------------------- /gcbfplus/env/base.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import numpy as np 3 | import pathlib 4 | import jax 5 | import jax.lax as lax 6 | import jax.numpy as jnp 7 | import tqdm 8 | 9 | from abc import ABC, abstractmethod, abstractproperty 10 | from typing import Callable, NamedTuple, Optional, Tuple 11 | 12 | from ..utils.graph import GraphsTuple 13 | from ..utils.typing import Action, Array, Cost, Done, Info, PRNGKey, Reward, State 14 | from ..utils.utils import jax2np, jax_jit_np, tree_concat_at_front, tree_stack 15 | 16 | 17 | class StepResult(NamedTuple): 18 | graph: GraphsTuple 19 | reward: Reward 20 | cost: Cost 21 | done: Done 22 | info: Info 23 | 24 | 25 | class RolloutResult(NamedTuple): 26 | Tp1_graph: GraphsTuple 27 | T_action: Action 28 | T_reward: Reward 29 | T_cost: Cost 30 | T_done: Done 31 | T_info: Info 32 | 33 | 34 | class MultiAgentEnv(ABC): 35 | 36 | PARAMS = {} 37 | 38 | def __init__( 39 | self, 40 | num_agents: int, 41 | area_size: float, 42 | max_step: int = 256, 43 | max_travel: float = None, 44 | dt: float = 0.03, 45 | params: dict = None 46 | ): 47 | super(MultiAgentEnv, self).__init__() 48 | self._num_agents = num_agents 49 | self._dt = dt 50 | if params is None: 51 | params = self.PARAMS 52 | self._params = params 53 | self._t = 0 54 | self._max_step = max_step 55 | self._max_travel = max_travel 56 | self._area_size = area_size 57 | 58 | @property 59 | def params(self) -> dict: 60 | return self._params 61 | 62 | @property 63 | def num_agents(self) -> int: 64 | return self._num_agents 65 | 66 | @property 67 | def max_travel(self) -> float: 68 | return self._max_travel 69 | 70 | @property 71 | def area_size(self) -> float: 72 | return self._area_size 73 | 74 | @property 75 | def dt(self) -> float: 76 | return self._dt 77 | 78 | @property 79 | def max_episode_steps(self) -> int: 80 | return self._max_step 81 | 82 | def clip_state(self, state: State) -> State: 83 | lower_limit, upper_limit = self.state_lim(state) 84 | return jnp.clip(state, lower_limit, upper_limit) 85 | 86 | def clip_action(self, action: Action) -> Action: 87 | lower_limit, upper_limit = self.action_lim() 88 | return jnp.clip(action, lower_limit, upper_limit) 89 | 90 | @abstractproperty 91 | def state_dim(self) -> int: 92 | pass 93 | 94 | @abstractproperty 95 | def node_dim(self) -> int: 96 | pass 97 | 98 | @abstractproperty 99 | def edge_dim(self) -> int: 100 | pass 101 | 102 | @abstractproperty 103 | def action_dim(self) -> int: 104 | pass 105 | 106 | @abstractmethod 107 | def reset(self, key: Array) -> GraphsTuple: 108 | pass 109 | 110 | def reset_np(self, key: Array) -> GraphsTuple: 111 | """Reset, but without the constraint that it has to be jittable.""" 112 | return self.reset(key) 113 
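    # Editor's note (comment sketch): the abstract methods below define the
    # per-environment interface; `step` advances the simulation by one control
    # period and returns a StepResult (graph, reward, cost, done, info) as
    # defined at the top of this file.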
| 114 | @abstractmethod 115 | def step(self, graph: GraphsTuple, action: Action, get_eval_info: bool = False) -> StepResult: 116 | pass 117 | 118 | @abstractmethod 119 | def state_lim(self, state: Optional[State] = None) -> Tuple[State, State]: 120 | """ 121 | Returns 122 | ------- 123 | lower_limit, upper_limit: Tuple[State, State], 124 | limits of the state 125 | """ 126 | pass 127 | 128 | @abstractmethod 129 | def action_lim(self) -> Tuple[Action, Action]: 130 | """ 131 | Returns 132 | ------- 133 | lower_limit, upper_limit: Tuple[Action, Action], 134 | limits of the action 135 | """ 136 | pass 137 | 138 | @abstractmethod 139 | def control_affine_dyn(self, state: State) -> [Array, Array]: 140 | pass 141 | 142 | @abstractmethod 143 | def add_edge_feats(self, graph: GraphsTuple, state: State) -> GraphsTuple: 144 | pass 145 | 146 | @abstractmethod 147 | def get_graph(self, state: State) -> GraphsTuple: 148 | pass 149 | 150 | @abstractmethod 151 | def u_ref(self, graph: GraphsTuple) -> Action: 152 | pass 153 | 154 | @abstractmethod 155 | def forward_graph(self, graph: GraphsTuple, action: Action) -> GraphsTuple: 156 | pass 157 | 158 | @abstractmethod 159 | @ft.partial(jax.jit, static_argnums=(0,)) 160 | def safe_mask(self, graph: GraphsTuple) -> Array: 161 | pass 162 | 163 | @abstractmethod 164 | @ft.partial(jax.jit, static_argnums=(0,)) 165 | def unsafe_mask(self, graph: GraphsTuple) -> Array: 166 | pass 167 | 168 | @abstractmethod 169 | def collision_mask(self, graph: GraphsTuple) -> Array: 170 | pass 171 | 172 | def rollout_fn(self, policy: Callable, rollout_length: int = None) -> Callable[[PRNGKey], RolloutResult]: 173 | rollout_length = rollout_length or self.max_episode_steps 174 | 175 | def body(graph, _): 176 | action = policy(graph) 177 | graph_new, reward, cost, done, info = self.step(graph, action, get_eval_info=True) 178 | return graph_new, (graph_new, action, reward, cost, done, info) 179 | 180 | def fn(key: PRNGKey) -> RolloutResult: 181 | graph0 = self.reset(key) 182 | graph_final, (T_graph, T_action, T_reward, T_cost, T_done, T_info) = lax.scan( 183 | body, graph0, None, length=rollout_length 184 | ) 185 | Tp1_graph = tree_concat_at_front(graph0, T_graph, axis=0) 186 | 187 | return RolloutResult(Tp1_graph, T_action, T_reward, T_cost, T_done, T_info) 188 | 189 | return fn 190 | 191 | def rollout_fn_jitstep( 192 | self, policy: Callable, rollout_length: int = None, noedge: bool = False, nograph: bool = False 193 | ): 194 | rollout_length = rollout_length or self.max_episode_steps 195 | 196 | def body(graph, _): 197 | action = policy(graph) 198 | graph_new, reward, cost, done, info = self.step(graph, action, get_eval_info=True) 199 | return graph_new, (graph_new, action, reward, cost, done, info) 200 | 201 | jit_body = jax.jit(body) 202 | 203 | is_unsafe_fn = jax_jit_np(self.collision_mask) 204 | is_finish_fn = jax_jit_np(self.finish_mask) 205 | 206 | def fn(key: PRNGKey) -> [RolloutResult, Array, Array]: 207 | graph0 = self.reset_np(key) 208 | graph = graph0 209 | T_output = [] 210 | is_unsafes = [] 211 | is_finishes = [] 212 | 213 | is_unsafes.append(is_unsafe_fn(graph0)) 214 | is_finishes.append(is_finish_fn(graph0)) 215 | graph0 = jax2np(graph0) 216 | 217 | for kk in tqdm.trange(rollout_length, ncols=80): 218 | graph, output = jit_body(graph, None) 219 | 220 | is_unsafes.append(is_unsafe_fn(graph)) 221 | is_finishes.append(is_finish_fn(graph)) 222 | 223 | output = jax2np(output) 224 | if noedge: 225 | output = (output[0].without_edge(), *output[1:]) 226 | if nograph: 227 | 
output = (None, *output[1:]) 228 | T_output.append(output) 229 | 230 | # Concatenate everything together. 231 | T_graph = [o[0] for o in T_output] 232 | if noedge: 233 | T_graph = [graph0.without_edge()] + T_graph 234 | else: 235 | T_graph = [graph0] + T_graph 236 | del graph0 237 | T_action = [o[1] for o in T_output] 238 | T_reward = [o[2] for o in T_output] 239 | T_cost = [o[3] for o in T_output] 240 | T_done = [o[4] for o in T_output] 241 | T_info = [o[5] for o in T_output] 242 | del T_output 243 | 244 | if nograph: 245 | T_graph = None 246 | else: 247 | T_graph = tree_stack(T_graph) 248 | T_action = tree_stack(T_action) 249 | T_reward = tree_stack(T_reward) 250 | T_cost = tree_stack(T_cost) 251 | T_done = tree_stack(T_done) 252 | T_info = tree_stack(T_info) 253 | 254 | Tp1_graph = T_graph 255 | 256 | rollout_result = jax2np(RolloutResult(Tp1_graph, T_action, T_reward, T_cost, T_done, T_info)) 257 | return rollout_result, np.stack(is_unsafes, axis=0), np.stack(is_finishes, axis=0) 258 | 259 | return fn 260 | 261 | @abstractmethod 262 | def render_video( 263 | self, rollout: RolloutResult, video_path: pathlib.Path, Ta_is_unsafe=None, viz_opts: dict = None, **kwargs 264 | ) -> None: 265 | pass 266 | 267 | @abstractmethod 268 | def finish_mask(self, graph: GraphsTuple) -> Array: 269 | pass 270 | -------------------------------------------------------------------------------- /gcbfplus/env/linear_drone.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import pathlib 3 | import jax 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | import numpy as np 7 | import scipy 8 | 9 | from typing import NamedTuple, Tuple, Optional 10 | 11 | from ..utils.graph import EdgeBlock, GetGraph, GraphsTuple 12 | from ..utils.typing import Action, AgentState, Array, Cost, Done, Info, Pos3d, Reward, State 13 | from ..utils.utils import merge01 14 | from .base import MultiAgentEnv, RolloutResult 15 | from .obstacle import Obstacle, Sphere 16 | from .plot import render_video 17 | from .utils import get_lidar, inside_obstacles, lqr, get_node_goal_rng 18 | 19 | 20 | class LinearDrone(MultiAgentEnv): 21 | AGENT = 0 22 | GOAL = 1 23 | OBS = 2 24 | 25 | class EnvState(NamedTuple): 26 | agent: AgentState 27 | goal: State 28 | obstacle: Obstacle 29 | 30 | @property 31 | def n_agent(self) -> int: 32 | return self.agent.shape[0] 33 | 34 | EnvGraphsTuple = GraphsTuple[State, EnvState] 35 | 36 | PARAMS = { 37 | "drone_radius": 0.05, 38 | "comm_radius": 0.5, 39 | "n_rays": 32, 40 | "obs_len_range": [0.15, 0.3], 41 | "n_obs": 4 42 | } 43 | 44 | def __init__( 45 | self, 46 | num_agents: int, 47 | area_size: float, 48 | max_step: int = 256, 49 | max_travel: float = None, 50 | dt: float = 0.03, 51 | params: dict = None 52 | ): 53 | super(LinearDrone, self).__init__(num_agents, area_size, max_step, max_travel, dt, params) 54 | 55 | self._A = np.zeros((self.state_dim, self.state_dim)) 56 | self._A[0, 3] = 1. 57 | self._A[1, 4] = 1. 58 | self._A[2, 5] = 1. 59 | self._A[3, 3] = -1.1 60 | self._A[4, 4] = -1.1 61 | self._A[5, 5] = -6. 62 | A_discrete = scipy.linalg.expm(self._A * self._dt) 63 | 64 | self._B = np.zeros((self.state_dim, self.action_dim)) 65 | self._B[3, 0] = 10. 66 | self._B[4, 1] = 10. 67 | self._B[5, 2] = 10. 
# 6.0 68 | 69 | self._Q = np.diag([5e1, 5e1, 5e1, 1., 1., 1.]) 70 | self._R = np.eye(self.action_dim) 71 | self._K = lqr(A_discrete, self._B, self._Q, self._R) 72 | self.create_obstacles = jax.vmap(Sphere.create) 73 | self.n_rays = 16 # consider top k rays 74 | 75 | @property 76 | def state_dim(self) -> int: 77 | return 6 # x, y, z, vx, vy, vz 78 | 79 | @property 80 | def node_dim(self) -> int: 81 | return 3 # indicator: agent: 001, goal: 010, obstacle: 100 82 | 83 | @property 84 | def edge_dim(self) -> int: 85 | return 6 # x_rel, y_rel, z_rel, vx_rel, vy_rel, vz_rel 86 | 87 | @property 88 | def action_dim(self) -> int: 89 | return 3 # ax, ay, az 90 | 91 | def reset(self, key: Array) -> GraphsTuple: 92 | self._t = 0 93 | 94 | # randomly generate obstacles 95 | n_rng_obs = self._params["n_obs"] 96 | assert n_rng_obs >= 0 97 | obstacle_key, key = jr.split(key, 2) 98 | obs_pos = jr.uniform(obstacle_key, (n_rng_obs, 3), minval=0, maxval=self.area_size) 99 | 100 | r_key, key = jr.split(key, 2) 101 | obs_radius = jr.uniform(r_key, (n_rng_obs,), 102 | minval=self._params["obs_len_range"][0] / 2, 103 | maxval=self._params["obs_len_range"][1] / 2) 104 | obstacles = self.create_obstacles(obs_pos, obs_radius) 105 | 106 | # randomly generate agent and goal 107 | states, goals = get_node_goal_rng( 108 | key, self.area_size, 3, obstacles, self.num_agents, 4 * self.params["drone_radius"], self.max_travel) 109 | 110 | # add zero velocity 111 | states = jnp.concatenate([states, jnp.zeros((self.num_agents, self.state_dim - 3))], axis=1) 112 | goals = jnp.concatenate([goals, jnp.zeros((self.num_agents, self.state_dim - 3))], axis=1) 113 | 114 | env_states = self.EnvState(states, goals, obstacles) 115 | 116 | return self.get_graph(env_states) 117 | 118 | def clip_action(self, action: Action) -> Action: 119 | lower_limit, upper_limit = self.action_lim() 120 | return jnp.clip(action, lower_limit, upper_limit) 121 | 122 | def agent_step_euler(self, agent_state: Array, action: Array) -> Array: 123 | assert action.shape == (self.num_agents, self.action_dim) 124 | assert agent_state.shape == (self.num_agents, self.state_dim) 125 | x_dot = self.agent_xdot(agent_state, action) 126 | n_state_agent_new = agent_state + x_dot * self.dt 127 | assert n_state_agent_new.shape == (self.num_agents, self.state_dim) 128 | return self.clip_state(n_state_agent_new) 129 | 130 | def agent_xdot(self, agent_states: AgentState, action: Action) -> AgentState: 131 | assert action.shape == (self.num_agents, self.action_dim) 132 | assert agent_states.shape == (self.num_agents, self.state_dim) 133 | 134 | return jnp.matmul(agent_states, self._A.T) + jnp.matmul(action, self._B.T) 135 | 136 | def step( 137 | self, graph: EnvGraphsTuple, action: Action, get_eval_info: bool = False 138 | ) -> Tuple[EnvGraphsTuple, Reward, Cost, Done, Info]: 139 | self._t += 1 140 | 141 | # calculate next graph 142 | agent_states = graph.type_states(type_idx=0, n_type=self.num_agents) 143 | goal_states = graph.type_states(type_idx=1, n_type=self.num_agents) 144 | obstacles = graph.env_states.obstacle 145 | action = self.clip_action(action) 146 | assert action.shape == (self.num_agents, self.action_dim) 147 | assert agent_states.shape == (self.num_agents, self.state_dim) 148 | next_agent_states = self.agent_step_euler(agent_states, action) 149 | 150 | # the episode ends when reaching max_episode_steps 151 | done = jnp.array(False) 152 | 153 | # compute reward and cost 154 | reward = jnp.zeros(()).astype(jnp.float32) 155 | reward -= (jnp.linalg.norm(action - 
self.u_ref(graph), axis=1) ** 2).mean() 156 | cost = self.get_cost(graph) 157 | 158 | assert reward.shape == tuple() 159 | assert cost.shape == tuple() 160 | assert done.shape == tuple() 161 | 162 | next_state = self.EnvState(next_agent_states, goal_states, obstacles) 163 | 164 | info = {} 165 | if get_eval_info: 166 | # collision between agents and obstacles 167 | agent_pos = agent_states[:, :3] 168 | info["inside_obstacles"] = inside_obstacles(agent_pos, obstacles, r=self._params["drone_radius"]) 169 | 170 | return self.get_graph(next_state), reward, cost, done, info 171 | 172 | def get_cost(self, graph: EnvGraphsTuple) -> Cost: 173 | agent_states = graph.type_states(type_idx=0, n_type=self.num_agents) 174 | obstacles = graph.env_states.obstacle 175 | 176 | # collision between agents 177 | agent_pos = agent_states[:, :3] 178 | dist = jnp.linalg.norm(jnp.expand_dims(agent_pos, 1) - jnp.expand_dims(agent_pos, 0), axis=-1) 179 | dist += jnp.eye(self.num_agents) * 1e6 180 | collision = (self._params["drone_radius"] * 2 > dist).any(axis=1) 181 | cost = collision.mean() 182 | 183 | # collision between agents and obstacles 184 | collision = inside_obstacles(agent_pos, obstacles, r=self._params["drone_radius"]) 185 | cost += collision.mean() 186 | 187 | return cost 188 | 189 | def render_video( 190 | self, 191 | rollout: RolloutResult, 192 | video_path: pathlib.Path, 193 | Ta_is_unsafe=None, 194 | viz_opts: dict = None, 195 | dpi: int = 100, 196 | **kwargs 197 | ) -> None: 198 | render_video( 199 | rollout=rollout, 200 | video_path=video_path, 201 | side_length=self.area_size, 202 | dim=3, 203 | n_agent=self.num_agents, 204 | n_rays=self.n_rays, 205 | r=self.params["drone_radius"], 206 | Ta_is_unsafe=Ta_is_unsafe, 207 | viz_opts=viz_opts, 208 | dpi=dpi, 209 | **kwargs 210 | ) 211 | 212 | def edge_blocks(self, state: EnvState, lidar_data: Pos3d) -> list[EdgeBlock]: 213 | n_hits = self.num_agents * self.n_rays 214 | 215 | # agent - agent connection 216 | agent_pos = state.agent[:, :3] 217 | pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :] # [i, j]: i -> j 218 | dist = jnp.linalg.norm(pos_diff, axis=-1) 219 | dist += jnp.eye(dist.shape[1]) * (self._params["comm_radius"] + 1) 220 | state_diff = state.agent[:, None, :] - state.agent[None, :, :] 221 | agent_agent_mask = jnp.less(dist, self._params["comm_radius"]) 222 | id_agent = jnp.arange(self.num_agents) 223 | agent_agent_edges = EdgeBlock(state_diff, agent_agent_mask, id_agent, id_agent) 224 | 225 | # agent - goal connection, clipped to avoid too long edges 226 | id_goal = jnp.arange(self.num_agents, self.num_agents * 2) 227 | agent_goal_mask = jnp.eye(self.num_agents) 228 | agent_goal_feats = state.agent[:, None, :] - state.goal[None, :, :] 229 | feats_norm = jnp.sqrt(1e-6 + jnp.sum(agent_goal_feats[:, :3] ** 2, axis=-1, keepdims=True)) 230 | comm_radius = self._params["comm_radius"] 231 | safe_feats_norm = jnp.maximum(feats_norm, comm_radius) 232 | coef = jnp.where(feats_norm > comm_radius, comm_radius / safe_feats_norm, 1.0) 233 | agent_goal_feats = agent_goal_feats.at[:, :3].set(agent_goal_feats[:, :3] * coef) 234 | agent_goal_edges = EdgeBlock( 235 | agent_goal_feats, agent_goal_mask, id_agent, id_goal 236 | ) 237 | 238 | # agent - obs connection 239 | id_obs = jnp.arange(self.num_agents * 2, self.num_agents * 2 + n_hits) 240 | agent_obs_edges = [] 241 | for i in range(self.num_agents): 242 | id_hits = jnp.arange(i * self.n_rays, (i + 1) * self.n_rays) 243 | lidar_pos = agent_pos[i, :] - lidar_data[id_hits, :3] 244 | lidar_feats = 
state.agent[i, :] - lidar_data[id_hits, :] 245 | lidar_dist = jnp.linalg.norm(lidar_pos, axis=-1) 246 | active_lidar = jnp.less(lidar_dist, self._params["comm_radius"] - 1e-1) 247 | agent_obs_mask = jnp.ones((1, self.n_rays)) 248 | agent_obs_mask = jnp.logical_and(agent_obs_mask, active_lidar) 249 | agent_obs_edges.append( 250 | EdgeBlock(lidar_feats[None, :, :], agent_obs_mask, id_agent[i][None], id_obs[id_hits]) 251 | ) 252 | 253 | return [agent_agent_edges, agent_goal_edges] + agent_obs_edges 254 | 255 | def control_affine_dyn(self, state: State) -> [Array, Array]: 256 | assert state.ndim == 2 257 | f = jnp.matmul(state, self._A.T) 258 | g = self._B 259 | g = jnp.expand_dims(g, axis=0).repeat(f.shape[0], axis=0) 260 | assert f.shape == state.shape 261 | assert g.shape == (state.shape[0], self.state_dim, self.action_dim) 262 | return f, g 263 | 264 | def add_edge_feats(self, graph: GraphsTuple, state: State) -> GraphsTuple: 265 | assert graph.is_single 266 | assert state.ndim == 2 267 | 268 | edge_feats = state[graph.receivers] - state[graph.senders] 269 | feats_norm = jnp.sqrt(1e-6 + jnp.sum(edge_feats[:, :3] ** 2, axis=-1, keepdims=True)) 270 | comm_radius = self._params["comm_radius"] 271 | safe_feats_norm = jnp.maximum(feats_norm, comm_radius) 272 | coef = jnp.where(feats_norm > comm_radius, comm_radius / safe_feats_norm, 1.0) 273 | edge_feats = edge_feats.at[:, :3].set(edge_feats[:, :3] * coef) 274 | 275 | return graph._replace(edges=edge_feats, states=state) 276 | 277 | def get_graph(self, state: EnvState, adjacency: Array = None) -> GraphsTuple: 278 | # node features 279 | n_hits = self.n_rays * self.num_agents 280 | n_nodes = 2 * self.num_agents + n_hits 281 | node_feats = jnp.zeros((self.num_agents * 2 + n_hits, 3)) 282 | node_feats = node_feats.at[: self.num_agents, 2].set(1) # agent feats 283 | node_feats = node_feats.at[self.num_agents: self.num_agents * 2, 1].set(1) # goal feats 284 | node_feats = node_feats.at[-n_hits:, 0].set(1) # obs feats 285 | 286 | node_type = jnp.zeros(n_nodes, dtype=jnp.int32) 287 | node_type = node_type.at[self.num_agents: self.num_agents * 2].set(LinearDrone.GOAL) 288 | node_type = node_type.at[-n_hits:].set(LinearDrone.OBS) 289 | 290 | get_lidar_vmap = jax.vmap( 291 | ft.partial( 292 | get_lidar, 293 | obstacles=state.obstacle, 294 | num_beams=self.params['n_rays'], 295 | sense_range=self._params["comm_radius"], 296 | max_returns=self.n_rays, 297 | ) 298 | ) 299 | lidar_data = merge01(get_lidar_vmap(state.agent[:, :3])) 300 | lidar_data = jnp.concatenate([lidar_data, jnp.zeros_like(lidar_data)], axis=-1) 301 | edge_blocks = self.edge_blocks(state, lidar_data) 302 | 303 | # create graph 304 | return GetGraph( 305 | nodes=node_feats, 306 | node_type=node_type, 307 | edge_blocks=edge_blocks, 308 | env_states=state, 309 | states=jnp.concatenate([state.agent, state.goal, lidar_data], axis=0), 310 | ).to_padded() 311 | 312 | def state_lim(self, state: Optional[State] = None) -> Tuple[State, State]: 313 | low_lim = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf, -0.5, -0.5, -0.5]) 314 | up_lim = jnp.array([jnp.inf, jnp.inf, jnp.inf, 0.5, 0.5, 0.5]) 315 | return low_lim, up_lim 316 | 317 | def action_lim(self) -> Tuple[Action, Action]: 318 | low_lim = jnp.array([-1., -1., -1.]) 319 | up_lim = jnp.array([1., 1., 1.]) 320 | return low_lim, up_lim 321 | 322 | def u_ref(self, graph: GraphsTuple) -> Action: 323 | agent = graph.type_states(type_idx=0, n_type=self.num_agents) 324 | goal = graph.type_states(type_idx=1, n_type=self.num_agents) 325 | error = goal - agent 
326 | error_max = jnp.abs(error / jnp.linalg.norm(error, axis=-1, keepdims=True) * self._params["comm_radius"]) 327 | error = jnp.clip(error, -error_max, error_max) 328 | return self.clip_action(error @ self._K.T) 329 | 330 | def forward_graph(self, graph: GraphsTuple, action: Action) -> GraphsTuple: 331 | # calculate next graph 332 | agent_states = graph.type_states(type_idx=0, n_type=self.num_agents) 333 | goal_states = graph.type_states(type_idx=1, n_type=self.num_agents) 334 | obs_states = graph.type_states(type_idx=2, n_type=self._params["n_rays"] * self.num_agents) 335 | action = self.clip_action(action) 336 | 337 | assert action.shape == (self.num_agents, self.action_dim) 338 | assert agent_states.shape == (self.num_agents, self.state_dim) 339 | 340 | next_agent_states = self.agent_step_euler(agent_states, action) 341 | next_states = jnp.concatenate([next_agent_states, goal_states, obs_states], axis=0) 342 | 343 | next_graph = self.add_edge_feats(graph, next_states) 344 | return next_graph 345 | 346 | def safe_mask(self, graph: GraphsTuple) -> Array: 347 | agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)[:, :3] 348 | agent_vel = graph.type_states(type_idx=0, n_type=self.num_agents)[:, 3:] 349 | 350 | # agents are not colliding 351 | pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :] # [i, j]: i -> j 352 | dist = jnp.linalg.norm(pos_diff, axis=-1) 353 | dist = dist + jnp.eye(dist.shape[1]) * (self._params["drone_radius"] * 2 + 1) # remove self connection 354 | safe_agent = jnp.greater(dist, self._params["drone_radius"] * 4) 355 | 356 | safe_agent = jnp.min(safe_agent, axis=1) 357 | 358 | safe_obs = jnp.logical_not( 359 | inside_obstacles(agent_pos, graph.env_states.obstacle, self._params["drone_radius"] * 2) 360 | ) 361 | 362 | safe_mask = jnp.logical_and(safe_agent, safe_obs) 363 | 364 | return safe_mask 365 | 366 | def unsafe_mask(self, graph: GraphsTuple) -> Array: 367 | agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)[:, :3] 368 | 369 | # agents are colliding 370 | pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :] # [i, j]: i -> j 371 | dist = jnp.linalg.norm(pos_diff, axis=-1) 372 | dist = dist + jnp.eye(dist.shape[1]) * (self._params["drone_radius"] * 2 + 1) # remove self connection 373 | unsafe_agent = jnp.less(dist, self._params["drone_radius"] * 2.5) 374 | unsafe_agent = jnp.max(unsafe_agent, axis=1) 375 | 376 | # agents are colliding with obstacles 377 | unsafe_obs = inside_obstacles(agent_pos, graph.env_states.obstacle, self._params["drone_radius"] * 1.5) 378 | 379 | collision_mask = jnp.logical_or(unsafe_agent, unsafe_obs) 380 | 381 | return collision_mask 382 | 383 | def collision_mask(self, graph: GraphsTuple) -> Array: 384 | agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)[:, :3] 385 | 386 | # agents are colliding 387 | pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :] # [i, j]: i -> j 388 | dist = jnp.linalg.norm(pos_diff, axis=-1) 389 | dist = dist + jnp.eye(dist.shape[1]) * (self._params["drone_radius"] * 2 + 1) # remove self connection 390 | unsafe_agent = jnp.less(dist, self._params["drone_radius"] * 2) 391 | unsafe_agent = jnp.max(unsafe_agent, axis=1) 392 | 393 | # agents are colliding with obstacles 394 | unsafe_obs = inside_obstacles(agent_pos, graph.env_states.obstacle, self._params["drone_radius"]) 395 | 396 | collision_mask = jnp.logical_or(unsafe_agent, unsafe_obs) 397 | 398 | return collision_mask 399 | 400 | def finish_mask(self, graph: GraphsTuple) -> Array: 401 | agent_pos = 
graph.type_states(type_idx=0, n_type=self.num_agents)[:, :3] 402 | goal_pos = graph.env_states.goal[:, :3] 403 | reach = jnp.linalg.norm(agent_pos - goal_pos, axis=1) < self._params["drone_radius"] * 2 404 | return reach 405 | -------------------------------------------------------------------------------- /gcbfplus/env/obstacle.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from typing import NamedTuple, Protocol 5 | from jax.scipy.spatial.transform import Rotation 6 | from ..utils.typing import Pos2d, Pos3d, Pos 7 | from ..utils.typing import Array, ObsType, ObsWidth, ObsHeight, ObsTheta, Radius, ObsLength, ObsQuaternion, BoolScalar 8 | 9 | RECTANGLE = jnp.zeros(1) 10 | CUBOID = jnp.ones(1) 11 | SPHERE = jnp.ones(1) * 2 12 | 13 | 14 | class Obstacle(Protocol): 15 | type: ObsType 16 | center: Pos 17 | 18 | def inside(self, point: Pos, r: Radius = 0.) -> BoolScalar: 19 | pass 20 | 21 | def raytracing(self, start: Pos, end: Pos) -> Array: 22 | pass 23 | 24 | 25 | class Rectangle(NamedTuple): 26 | type: ObsType 27 | center: Pos2d 28 | width: ObsWidth 29 | height: ObsHeight 30 | theta: ObsTheta 31 | points: Array 32 | 33 | @staticmethod 34 | def create(center: Pos2d, width: ObsWidth, height: ObsHeight, theta: ObsTheta) -> "Rectangle": 35 | bbox = jnp.array([ 36 | [width / 2, height / 2], 37 | [-width / 2, height / 2], 38 | [-width / 2, -height / 2], 39 | [width / 2, -height / 2], 40 | ]).T # (2, 4) 41 | 42 | rot = jnp.array([ 43 | [jnp.cos(theta), -jnp.sin(theta)], 44 | [jnp.sin(theta), jnp.cos(theta)] 45 | ]) 46 | 47 | trans = center[:, None] 48 | points = jnp.dot(rot, bbox) + trans 49 | points = points.T 50 | 51 | return Rectangle(RECTANGLE, center, width, height, theta, points) 52 | 53 | def inside(self, point: Pos2d, r: Radius = 0.) 
-> BoolScalar:
54 |         rel_x = point[0] - self.center[0]
55 |         rel_y = point[1] - self.center[1]
56 |         rel_xx = jnp.abs(rel_x * jnp.cos(self.theta) + rel_y * jnp.sin(self.theta)) - self.width / 2
57 |         rel_yy = jnp.abs(rel_x * jnp.sin(self.theta) - rel_y * jnp.cos(self.theta)) - self.height / 2
58 |         is_in_down = jnp.logical_and(rel_xx < r, rel_yy < 0)
59 |         is_in_up = jnp.logical_and(rel_xx < 0, rel_yy < r)
60 |         is_out_corner = jnp.logical_and(rel_xx > 0, rel_yy > 0)
61 |         is_in_circle = jnp.sqrt(rel_xx ** 2 + rel_yy ** 2) < r
62 |         is_in = jnp.logical_or(jnp.logical_or(is_in_down, is_in_up), jnp.logical_and(is_out_corner, is_in_circle))
63 |         return is_in
64 | 
65 |     def raytracing(self, start: Pos2d, end: Pos2d) -> Array:
66 |         # beam
67 |         x1 = start[0]
68 |         y1 = start[1]
69 |         x2 = end[0]
70 |         y2 = end[1]
71 | 
72 |         # edges
73 |         x3 = self.points[:, 0]
74 |         y3 = self.points[:, 1]
75 |         x4 = self.points[[-1, 0, 1, 2], 0]
76 |         y4 = self.points[[-1, 0, 1, 2], 1]
77 | 
78 |         '''
79 |         # solve the equation
80 |         # x = x1 + alpha * (x2 - x1) = x3 + beta * (x4 - x3)
81 |         # y = y1 + alpha * (y2 - y1) = y3 + beta * (y4 - y3)
82 |         # equivalent to solve
83 |         # [x1-x2 x4-x3] alpha = [x1-x3]
84 |         # [y1-y2 y4-y3] beta    [y1-y3]
85 |         # solve by (alpha beta)^T = A^{-1} b
86 |         '''
87 | 
88 |         det = (x1 - x2) * (y4 - y3) - (y1 - y2) * (x4 - x3)
89 |         # clip det for numerical issues
90 |         det = jnp.sign(det) * jnp.clip(jnp.abs(det), 1e-7, 1e7)
91 |         alphas = ((y4 - y3) * (x1 - x3) - (x4 - x3) * (y1 - y3)) / det
92 |         betas = (-(y1 - y2) * (x1 - x3) + (x1 - x2) * (y1 - y3)) / det
93 |         valids = jnp.logical_and(jnp.logical_and(alphas <= 1, alphas >= 0), jnp.logical_and(betas <= 1, betas >= 0))
94 |         alphas = valids * alphas + (1 - valids) * 1e6
95 |         alphas = jnp.min(alphas)  # reduce the polygon edges dimension
96 |         return alphas
97 | 
98 | 
99 | class Cuboid(NamedTuple):
100 |     type: ObsType
101 |     center: Pos3d
102 |     length: ObsLength
103 |     width: ObsWidth
104 |     height: ObsHeight
105 |     rotation: Rotation
106 |     points: Array
107 | 
108 |     @staticmethod
109 |     def create(
110 |             center: Pos3d, length: ObsLength, width: ObsWidth, height: ObsHeight, quaternion: ObsQuaternion
111 |     ) -> "Cuboid":
112 |         bbox = jnp.array([
113 |             [-length / 2, -width / 2, -height / 2],
114 |             [length / 2, -width / 2, -height / 2],
115 |             [length / 2, width / 2, -height / 2],
116 |             [-length / 2, width / 2, -height / 2],
117 |             [-length / 2, -width / 2, height / 2],
118 |             [length / 2, -width / 2, height / 2],
119 |             [length / 2, width / 2, height / 2],
120 |             [-length / 2, width / 2, height / 2],
121 |         ])  # (8, 3)
122 | 
123 |         rotation = Rotation.from_quat(quaternion)
124 |         points = rotation.apply(bbox) + center
125 |         return Cuboid(CUBOID, center, length, width, height, rotation, points)
126 | 
127 |     def inside(self, point: Pos3d, r: Radius = 0.)
-> BoolScalar: 128 | # transform the point to the cuboid frame 129 | rot = self.rotation.as_matrix() 130 | rot_inv = jnp.linalg.inv(rot) 131 | point = jnp.dot(rot_inv, point - self.center) 132 | 133 | # check if the point is inside the cuboid 134 | is_in_height = ((-self.length / 2 < point[0]) & (point[0] < self.length / 2)) & \ 135 | ((-self.width / 2 < point[1]) & (point[1] < self.width / 2)) & \ 136 | ((-self.height / 2 - r < point[2]) & (point[2] < self.height / 2 + r)) 137 | is_in_length = ((-self.length / 2 - r < point[0]) & (point[0] < self.length / 2 + r)) & \ 138 | ((-self.width / 2 < point[1]) & (point[1] < self.width / 2)) & \ 139 | ((-self.height / 2 < point[2]) & (point[2] < self.height / 2)) 140 | is_in_width = ((-self.length / 2 < point[0]) & (point[0] < self.length / 2)) & \ 141 | ((-self.width / 2 - r < point[1]) & (point[1] < self.width / 2 + r)) & \ 142 | ((-self.height / 2 < point[2]) & (point[2] < self.height / 2)) 143 | is_in = is_in_height | is_in_length | is_in_width 144 | 145 | # check if the sphere intersects with the edges 146 | edge_order = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0], 147 | [4, 5], [5, 6], [6, 7], [7, 4], 148 | [0, 4], [1, 5], [2, 6], [3, 7]]) 149 | edges = self.points[edge_order] 150 | 151 | def intersect_edge(edge: Array) -> BoolScalar: 152 | assert edge.shape == (2, 3) 153 | dot_prod = jnp.dot(edge[1] - edge[0], point - edge[0]) 154 | frac = dot_prod / ((jnp.linalg.norm(edge[1] - edge[0])) ** 2) 155 | frac = jnp.clip(frac, 0, 1) 156 | closest_point = edge[0] + frac * (edge[1] - edge[0]) 157 | dist = jnp.linalg.norm(closest_point - point) 158 | return dist <= r 159 | 160 | is_intersect = jnp.any(jax.vmap(intersect_edge)(edges)) 161 | return is_in | is_intersect 162 | 163 | def raytracing(self, start: Pos3d, end: Pos3d) -> Array: 164 | # beams 165 | x1, y1, z1 = start[0], start[1], start[2] 166 | x2, y2, z2 = end[0], end[1], end[2] 167 | 168 | # those are for edges 169 | # point order for the base 0~7 (first 0-x-xy-y for the lower level, then 0-x-xy-y for the upper level) 170 | # face order: bottom, left, right, upper, outer left, outer right 171 | x3 = self.points[[0, 0, 0, 6, 6, 6], 0] 172 | y3 = self.points[[0, 0, 0, 6, 6, 6], 1] 173 | z3 = self.points[[0, 0, 0, 6, 6, 6], 2] 174 | 175 | x4 = self.points[[1, 1, 3, 5, 5, 7], 0] 176 | y4 = self.points[[1, 1, 3, 5, 5, 7], 1] 177 | z4 = self.points[[1, 1, 3, 5, 5, 7], 2] 178 | 179 | x5 = self.points[[3, 4, 4, 7, 2, 2], 0] 180 | y5 = self.points[[3, 4, 4, 7, 2, 2], 1] 181 | z5 = self.points[[3, 4, 4, 7, 2, 2], 2] 182 | 183 | ''' 184 | # solve the equation 185 | # x = x1 + alpha * (x2 - x1) = x3 + beta * (x4 - x3) + gamma * (x5 - x3) 186 | # y = y1 + alpha * (y2 - y1) = y3 + beta * (y4 - y3) + gamma * (y5 - y3) 187 | # z = z1 + alpha * (z2 - z1) = z3 + beta * (z4 - z3) + gamma * (z5 - z3) 188 | # equivalent to solve 189 | # [x1 - x2 x4 - x3 x5 - x3] alpha = [x1 - x3] 190 | # [y1 - y2 y4 - y3 y5 - y3] beta [y1 - y3] 191 | # [z1 - z2 z4 - z3 z5 - z3] gamma [z1 - z3] 192 | # solve by (alpha beta gamma)^T = A^{-1} b 193 | 194 | # A^{-1} = 1/det * [(y4-y3)*(z5-z3)-(y5-y3)*(z4-z3) -[(x4-x3)*(z5-z3)-(z4-z3)*(x5-x3)] (x4-x3)*(y5-y3)-(y4-y3)*(x5-x3)] 195 | # [-[(y1-y2)*(z5-z3)-(z1-z2)*(y5-y3)] (x1-x2)*(z5-z3)-(z1-z2)*(x5-x3) -[(x1-x2)*(y5-y3)-(y1-y2)*(x5-x3)]] 196 | # [(y1-y2)*(z4-z3)-(y4-y3)*(z1-z2) -[(x1-x2)*(z4-z3)-(z1-z2)*(x4-x3)] (x1-x2)*(y4-y3)-(y1-y2)*(x4-x3)] 197 | ''' 198 | 199 | det = (x1 - x2) * (y4 - y3) * (z5 - z3) + (x4 - x3) * (y5 - y3) * (z1 - z2) + (y1 - y2) * (z4 - z3) * ( 200 | x5 - x3) - 
(y1 - y2) * (x4 - x3) * (z5 - z3) - (z4 - z3) * (y5 - y3) * (x1 - x2) - (x5 - x3) * ( 201 | y4 - y3) * (z1 - z2) 202 | # clip det for numerical issues 203 | det = jnp.sign(det) * jnp.clip(jnp.abs(det), 1e-7, 1e7) 204 | adj_00 = (y4 - y3) * (z5 - z3) - (y5 - y3) * (z4 - z3) 205 | adj_01 = -((x4 - x3) * (z5 - z3) - (z4 - z3) * (x5 - x3)) 206 | adj_02 = (x4 - x3) * (y5 - y3) - (y4 - y3) * (x5 - x3) 207 | adj_10 = -((y1 - y2) * (z5 - z3) - (z1 - z2) * (y5 - y3)) 208 | adj_11 = (x1 - x2) * (z5 - z3) - (z1 - z2) * (x5 - x3) 209 | adj_12 = -((x1 - x2) * (y5 - y3) - (y1 - y2) * (x5 - x3)) 210 | adj_20 = (y1 - y2) * (z4 - z3) - (y4 - y3) * (z1 - z2) 211 | adj_21 = -((x1 - x2) * (z4 - z3) - (z1 - z2) * (x4 - x3)) 212 | adj_22 = (x1 - x2) * (y4 - y3) - (y1 - y2) * (x4 - x3) 213 | alphas = 1 / det * (adj_00 * (x1 - x3) + adj_01 * (y1 - y3) + adj_02 * (z1 - z3)) 214 | betas = 1 / det * (adj_10 * (x1 - x3) + adj_11 * (y1 - y3) + adj_12 * (z1 - z3)) 215 | gammas = 1 / det * (adj_20 * (x1 - x3) + adj_21 * (y1 - y3) + adj_22 * (z1 - z3)) 216 | valids = jnp.logical_and( 217 | jnp.logical_and(jnp.logical_and(alphas <= 1, alphas >= 0), jnp.logical_and(betas <= 1, betas >= 0)), 218 | jnp.logical_and(gammas <= 1, gammas >= 0) 219 | ) 220 | alphas = valids * alphas + (1 - valids) * 1e6 221 | alphas = jnp.min(alphas) # reduce the polygon edges dimension 222 | return alphas 223 | 224 | 225 | class Sphere(NamedTuple): 226 | type: ObsType 227 | center: Pos3d 228 | radius: Radius 229 | 230 | @staticmethod 231 | def create(center: Pos3d, radius: Radius) -> "Sphere": 232 | return Sphere(SPHERE, center, radius) 233 | 234 | def inside(self, point: Pos3d, r: Radius = 0.) -> BoolScalar: 235 | return jnp.linalg.norm(point - self.center) <= self.radius + r 236 | 237 | def raytracing(self, start: Pos3d, end: Pos3d) -> Array: 238 | x1, y1, z1 = start[0], start[1], start[2] 239 | x2, y2, z2 = end[0], end[1], end[2] 240 | xc, yc, zc = self.center[0], self.center[1], self.center[2] 241 | r = self.radius 242 | 243 | ''' 244 | # solve the equation 245 | # x = x1 + alpha * (x2 - x1) = xc + r * sin(gamma) * cos(theta) 246 | # y = y1 + alpha * (y2 - y1) = yc + r * sin(gamma) * sin(theta) 247 | # z = z1 + alpha * (z2 - z1) = zc + r * cos(gamma) 248 | # equivalent to solve (eliminate theta using sin^2(sin^2+cos^2) +cos^2 ...=1) 249 | # [(x2-x1)^2+(y2-y1)^2+(z2-z1)^2]alpha^2+2[(x2-x1)(x1-xc)+(y2-y1)(y1-yc)+(z2-z1)(z1-zc)]alpha+(x1-xc)^2+(y1-yc)^2+(z1-zc)^2-r^2=0 250 | # A alpha^2 + B alpha + C = 0 251 | # check delta = B^2-4AC 252 | # alpha = ... 
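        # Editor's note (comment sketch): by the quadratic formula,
        # alpha = (-B +- sqrt(B^2 - 4AC)) / (2A), with A = ||end - start||^2
        # (including the z term) and B, C as defined below.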
253 | # take valid min 254 | ''' 255 | lidar_rmax = jnp.linalg.norm(end - start) 256 | A = lidar_rmax ** 2 # (x2-x1)^2+(y2-y1)^2 257 | B = 2 * ((x2 - x1) * (x1 - xc) + (y2 - y1) * (y1 - yc) + (z2 - z1) * (z1 - zc)) 258 | C = (x1 - xc) ** 2 + (y1 - yc) ** 2 + (z1 - zc) ** 2 - r ** 2 259 | 260 | delta = B ** 2 - 4 * A * C 261 | valid1 = delta >= 0 262 | 263 | alpha1 = (-B - jnp.sqrt(delta * valid1)) / (2 * A) * valid1 + (1 - valid1) 264 | alpha2 = (-B + jnp.sqrt(delta * valid1)) / (2 * A) * valid1 + (1 - valid1) 265 | alpha1_tilde = (alpha1 >= 0) * alpha1 + (alpha1 < 0) * 1 266 | alpha2_tilde = (alpha2 >= 0) * alpha2 + (alpha2 < 0) * 1 267 | alphas = jnp.minimum(alpha1_tilde, alpha2_tilde) 268 | alphas = jnp.clip(alphas, 0, 1) 269 | alphas = valid1 * alphas + (1 - valid1) * 1e6 270 | return alphas 271 | -------------------------------------------------------------------------------- /gcbfplus/env/plot.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import jax 5 | import pathlib 6 | 7 | from colour import hsl2hex 8 | from matplotlib.animation import FuncAnimation 9 | from matplotlib.collections import LineCollection, PatchCollection 10 | from matplotlib.colors import LinearSegmentedColormap 11 | from matplotlib.pyplot import Axes 12 | from matplotlib.patches import Polygon 13 | from mpl_toolkits.mplot3d import proj3d, Axes3D 14 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection 15 | from typing import List, Optional, Union 16 | 17 | from ..trainer.utils import centered_norm 18 | from ..utils.typing import EdgeIndex, Pos2d, Pos3d, Array 19 | from ..utils.utils import merge01, tree_index, MutablePatchCollection, save_anim 20 | from .obstacle import Cuboid, Sphere, Obstacle, Rectangle 21 | from .base import RolloutResult 22 | 23 | 24 | def plot_graph( 25 | ax: Axes, 26 | pos: Pos2d, 27 | radius: Union[float, List[float]], 28 | color: Union[str, List[str]], 29 | with_label: Union[bool, List[bool]] = True, 30 | plot_edge: bool = False, 31 | edge_index: Optional[EdgeIndex] = None, 32 | edge_color: Union[str, List[str]] = 'k', 33 | alpha: float = 1.0, 34 | obstacle_color: str = '#000000', 35 | ) -> Axes: 36 | if isinstance(radius, float): 37 | radius = np.ones(pos.shape[0]) * radius 38 | if isinstance(radius, list): 39 | radius = np.array(radius) 40 | if isinstance(color, str): 41 | color = [color for _ in range(pos.shape[0])] 42 | if isinstance(with_label, bool): 43 | with_label = [with_label for _ in range(pos.shape[0])] 44 | circles = [] 45 | for i in range(pos.shape[0]): 46 | circles.append(plt.Circle((float(pos[i, 0]), float(pos[i, 1])), 47 | radius=radius[i], color=color[i], clip_on=False, alpha=alpha, linewidth=0.0)) 48 | if with_label[i]: 49 | ax.text(float(pos[i, 0]), float(pos[i, 1]), f'{i}', size=12, color="k", 50 | family="sans-serif", weight="normal", horizontalalignment="center", verticalalignment="center", 51 | transform=ax.transData, clip_on=True) 52 | circles = PatchCollection(circles, match_original=True) 53 | ax.add_collection(circles) 54 | 55 | if plot_edge and edge_index is not None: 56 | if isinstance(edge_color, str): 57 | edge_color = [edge_color for _ in range(edge_index.shape[1])] 58 | start, end = pos[edge_index[0, :]], pos[edge_index[1, :]] 59 | direction = (end - start) / jnp.linalg.norm(end - start, axis=1, keepdims=True) 60 | start = start + direction * radius[edge_index[0, :]][:, None] 61 | end = end - direction * 
radius[edge_index[1, :]][:, None] 62 | widths = (radius[edge_index[0, :]] + radius[edge_index[1, :]]) * 20 63 | lines = np.stack([start, end], axis=1) 64 | edges = LineCollection(lines, colors=edge_color, linewidths=widths, alpha=0.5) 65 | ax.add_collection(edges) 66 | return ax 67 | 68 | 69 | def plot_node_3d(ax, pos: Pos3d, r: float, color: str, alpha: float, grid: int = 10) -> Axes: 70 | u = np.linspace(0, 2 * np.pi, grid) 71 | v = np.linspace(0, np.pi, grid) 72 | x = r * np.outer(np.cos(u), np.sin(v)) + pos[0] 73 | y = r * np.outer(np.sin(u), np.sin(v)) + pos[1] 74 | z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + pos[2] 75 | ax.plot_surface(x, y, z, color=color, alpha=alpha) 76 | return ax 77 | 78 | 79 | def plot_graph_3d( 80 | ax, 81 | pos: Pos3d, 82 | radius: float, 83 | color: Union[str, List[str]], 84 | with_label: bool = True, 85 | plot_edge: bool = False, 86 | edge_index: Optional[EdgeIndex] = None, 87 | edge_color: Union[str, List[str]] = 'k', 88 | alpha: float = 1.0, 89 | obstacle_color: str = '#000000', 90 | ): 91 | if isinstance(color, str): 92 | color = [color for _ in range(pos.shape[0])] 93 | for i in range(pos.shape[0]): 94 | plot_node_3d(ax, pos[i], radius, color[i], alpha) 95 | if with_label: 96 | ax.text(pos[i, 0], pos[i, 1], pos[i, 2], f'{i}', size=12, color="k", family="sans-serif", weight="normal", 97 | horizontalalignment="center", verticalalignment="center") 98 | if plot_edge: 99 | if isinstance(edge_color, str): 100 | edge_color = [edge_color for _ in range(edge_index.shape[1])] 101 | for i_edge in range(edge_index.shape[1]): 102 | i = edge_index[0, i_edge] 103 | j = edge_index[1, i_edge] 104 | vec = pos[i, :] - pos[j, :] 105 | x = [pos[i, 0] - 2 * radius * vec[0], pos[j, 0] + 2 * radius * vec[0]] 106 | y = [pos[i, 1] - 2 * radius * vec[1], pos[j, 1] + 2 * radius * vec[1]] 107 | z = [pos[i, 2] - 2 * radius * vec[2], pos[j, 2] + 2 * radius * vec[2]] 108 | ax.plot(x, y, z, linewidth=1.0, color=edge_color[i_edge]) 109 | return ax 110 | 111 | 112 | def get_BuRd(): 113 | # blue = "#3182bd" 114 | # blue = hsl2hex([0.57, 0.59, 0.47]) 115 | blue = hsl2hex([0.57, 0.5, 0.55]) 116 | light_blue = hsl2hex([0.5, 1.0, 0.995]) 117 | 118 | # Tint it to orange a bit. 
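    # Editor's note (comment sketch): the blues above fill the lower half and the
    # reds below the upper half of the diverging "SDF" colormap assembled at the end
    # of this function, with a hard break at 0.5 so the zero level set stands out.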
119 | # red = "#de2d26" 120 | # red = hsl2hex([0.04, 0.74, 0.51]) 121 | red = hsl2hex([0.028, 0.62, 0.59]) 122 | light_red = hsl2hex([0.098, 1.0, 0.995]) 123 | 124 | sdf_cm = LinearSegmentedColormap.from_list("SDF", [(0, light_blue), (0.5, blue), (0.5, red), (1, light_red)], N=256) 125 | return sdf_cm 126 | 127 | 128 | def get_faces_cuboid(points: Pos3d) -> Array: 129 | point_id = jnp.array([[0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 5, 4], [2, 3, 7, 6], [0, 3, 7, 4], [1, 2, 6, 5]]) 130 | faces = points[point_id] 131 | return faces 132 | 133 | 134 | def get_cuboid_collection( 135 | obstacles: Cuboid, alpha: float = 0.8, linewidth: float = 1.0, edgecolor: str = 'k', facecolor: str = 'r' 136 | ) -> Poly3DCollection: 137 | get_faces_vmap = jax.vmap(get_faces_cuboid) 138 | cuboid_col = Poly3DCollection( 139 | merge01(get_faces_vmap(obstacles.points)), 140 | alpha=alpha, 141 | linewidth=linewidth, 142 | edgecolor=edgecolor, 143 | facecolor=facecolor 144 | ) 145 | return cuboid_col 146 | 147 | 148 | def get_sphere_collection( 149 | obstacles: Sphere, alpha: float = 0.8, facecolor: str = 'r' 150 | ) -> Poly3DCollection: 151 | def get_sphere(inp): 152 | center = inp[:3] 153 | radius = inp[3] 154 | u = np.linspace(0, 2 * np.pi, 30) 155 | v = np.linspace(0, np.pi, 30) 156 | x = radius * np.outer(np.cos(u), np.sin(v)) + center[0] 157 | y = radius * np.outer(np.sin(u), np.sin(v)) + center[1] 158 | z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) + center[2] 159 | return jnp.stack([x, y, z], axis=-1) 160 | 161 | get_sphere_vmap = jax.vmap(get_sphere) 162 | sphere_col = Poly3DCollection( 163 | merge01(get_sphere_vmap(jnp.concatenate([obstacles.center, obstacles.radius[:, None]], axis=-1))), 164 | alpha=alpha, 165 | linewidth=0.0, 166 | edgecolor='k', 167 | facecolor=facecolor 168 | ) 169 | 170 | return sphere_col 171 | 172 | 173 | def get_obs_collection( 174 | obstacles: Obstacle, color: str, alpha: float 175 | ): 176 | if isinstance(obstacles, Rectangle): 177 | n_obs = len(obstacles.center) 178 | obs_polys = [Polygon(obstacles.points[ii]) for ii in range(n_obs)] 179 | obs_col = PatchCollection(obs_polys, color=color, alpha=1.0, zorder=99) 180 | elif isinstance(obstacles, Cuboid): 181 | obs_col = get_cuboid_collection(obstacles, alpha=alpha, facecolor=color) 182 | elif isinstance(obstacles, Sphere): 183 | obs_col = get_sphere_collection(obstacles, alpha=alpha, facecolor=color) 184 | else: 185 | raise NotImplementedError 186 | return obs_col 187 | 188 | 189 | def render_video( 190 | rollout: RolloutResult, 191 | video_path: pathlib.Path, 192 | side_length: float, 193 | dim: int, 194 | n_agent: int, 195 | n_rays: int, 196 | r: float, 197 | Ta_is_unsafe=None, 198 | viz_opts: dict = None, 199 | dpi: int = 100, 200 | **kwargs 201 | ): 202 | assert dim == 2 or dim == 3 203 | 204 | # set up visualization option 205 | if dim == 2: 206 | ax: Axes 207 | fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=dpi) 208 | else: 209 | fig = plt.figure(figsize=(10, 10), dpi=dpi) 210 | ax: Axes3D = fig.add_subplot(projection='3d') 211 | ax.set_xlim(0., side_length) 212 | ax.set_ylim(0., side_length) 213 | if dim == 3: 214 | ax.set_zlim(0., side_length) 215 | ax.set(aspect="equal") 216 | if dim == 2: 217 | plt.axis("off") 218 | 219 | if viz_opts is None: 220 | viz_opts = {} 221 | 222 | # plot the first frame 223 | T_graph = rollout.Tp1_graph 224 | graph0 = tree_index(T_graph, 0) 225 | 226 | agent_color = "#0068ff" 227 | goal_color = "#2fdd00" 228 | obs_color = "#8a0000" 229 | edge_goal_color = goal_color 230 | 231 | # plot 
obstacles 232 | obs = graph0.env_states.obstacle 233 | ax.add_collection(get_obs_collection(obs, obs_color, alpha=0.8)) 234 | 235 | # plot agents 236 | n_hits = n_agent * n_rays 237 | n_color = [agent_color] * n_agent + [goal_color] * n_agent 238 | n_pos = graph0.states[:n_agent * 2, :dim] 239 | n_radius = np.array([r] * n_agent * 2) 240 | if dim == 2: 241 | agent_circs = [plt.Circle(n_pos[ii], n_radius[ii], color=n_color[ii], linewidth=0.0) 242 | for ii in range(n_agent * 2)] 243 | agent_col = MutablePatchCollection([i for i in reversed(agent_circs)], match_original=True, zorder=6) 244 | ax.add_collection(agent_col) 245 | else: 246 | plot_r = ax.transData.transform([r, 0])[0] - ax.transData.transform([0, 0])[0] 247 | agent_col = ax.scatter(n_pos[:, 0], n_pos[:, 1], n_pos[:, 2], 248 | s=plot_r, c=n_color, zorder=5) # todo: the size of the agent might not be correct 249 | 250 | # plot edges 251 | all_pos = graph0.states[:n_agent * 2 + n_hits, :dim] 252 | edge_index = np.stack([graph0.senders, graph0.receivers], axis=0) 253 | is_pad = np.any(edge_index == n_agent * 2 + n_hits, axis=0) 254 | e_edge_index = edge_index[:, ~is_pad] 255 | e_start, e_end = all_pos[e_edge_index[0, :]], all_pos[e_edge_index[1, :]] 256 | e_lines = np.stack([e_start, e_end], axis=1) # (e, n_pts, dim) 257 | e_is_goal = (n_agent <= graph0.senders) & (graph0.senders < n_agent * 2) 258 | e_is_goal = e_is_goal[~is_pad] 259 | e_colors = [edge_goal_color if e_is_goal[ii] else "0.2" for ii in range(len(e_start))] 260 | if dim == 2: 261 | edge_col = LineCollection(e_lines, colors=e_colors, linewidths=2, alpha=0.5, zorder=3) 262 | else: 263 | edge_col = Line3DCollection(e_lines, colors=e_colors, linewidths=2, alpha=0.5, zorder=3) 264 | ax.add_collection(edge_col) 265 | 266 | # text for cost and reward 267 | text_font_opts = dict( 268 | size=16, 269 | color="k", 270 | family="cursive", 271 | weight="normal", 272 | transform=ax.transAxes, 273 | ) 274 | if dim == 2: 275 | cost_text = ax.text(0.02, 1.04, "Cost: 1.0, Reward: 1.0", va="bottom", **text_font_opts) 276 | else: 277 | cost_text = ax.text2D(0.02, 1.04, "Cost: 1.0, Reward: 1.0", va="bottom", **text_font_opts) 278 | 279 | # text for safety 280 | safe_text = [] 281 | if Ta_is_unsafe is not None: 282 | if dim == 2: 283 | safe_text = [ax.text(0.02, 1.00, "Unsafe: {}", va="bottom", **text_font_opts)] 284 | else: 285 | safe_text = [ax.text2D(0.02, 1.00, "Unsafe: {}", va="bottom", **text_font_opts)] 286 | 287 | # text for time step 288 | if dim == 2: 289 | kk_text = ax.text(0.99, 0.99, "kk=0", va="top", ha="right", **text_font_opts) 290 | else: 291 | kk_text = ax.text2D(0.99, 0.99, "kk=0", va="top", ha="right", **text_font_opts) 292 | 293 | # add agent labels 294 | label_font_opts = dict( 295 | size=20, 296 | color="k", 297 | family="cursive", 298 | weight="normal", 299 | ha="center", 300 | va="center", 301 | transform=ax.transData, 302 | clip_on=True, 303 | zorder=7, 304 | ) 305 | agent_labels = [] 306 | if dim == 2: 307 | agent_labels = [ax.text(n_pos[ii, 0], n_pos[ii, 1], f"{ii}", **label_font_opts) for ii in range(n_agent)] 308 | else: 309 | for ii in range(n_agent): 310 | pos2d = proj3d.proj_transform(n_pos[ii, 0], n_pos[ii, 1], n_pos[ii, 2], ax.get_proj())[:2] 311 | agent_labels.append(ax.text2D(pos2d[0], pos2d[1], f"{ii}", **label_font_opts)) 312 | 313 | # plot cbf 314 | cnt_col = [] 315 | if "cbf" in viz_opts: 316 | if dim == 3: 317 | print('Warning: CBF visualization is not supported in 3D.') 318 | else: 319 | Tb_xs, Tb_ys, Tbb_h, cbf_num = viz_opts["cbf"] 320 | bb_Xs, 
bb_Ys = np.meshgrid(Tb_xs[0], Tb_ys[0]) 321 | norm = centered_norm(Tbb_h.min(), Tbb_h.max()) 322 | levels = np.linspace(norm.vmin, norm.vmax, 15) 323 | 324 | cmap = get_BuRd().reversed() 325 | contour_opts = dict(cmap=cmap, norm=norm, levels=levels, alpha=0.9) 326 | cnt = ax.contourf(bb_Xs, bb_Ys, Tbb_h[0], **contour_opts) 327 | 328 | contour_line_opts = dict(levels=[0.0], colors=["k"], linewidths=3.0) 329 | cnt_line = ax.contour(bb_Xs, bb_Ys, Tbb_h[0], **contour_line_opts) 330 | 331 | cbar = fig.colorbar(cnt, ax=ax) 332 | cbar.add_lines(cnt_line) 333 | cbar.ax.tick_params(labelsize=36, labelfontfamily="Times New Roman") 334 | 335 | cnt_col = [*cnt.collections, *cnt_line.collections] 336 | 337 | ax.text(0.5, 1.0, "CBF for {}".format(cbf_num), transform=ax.transAxes, va="bottom") 338 | 339 | # init function for animation 340 | def init_fn() -> list[plt.Artist]: 341 | return [agent_col, edge_col, *agent_labels, cost_text, *safe_text, *cnt_col, kk_text] 342 | 343 | # update function for animation 344 | def update(kk: int) -> list[plt.Artist]: 345 | graph = tree_index(T_graph, kk) 346 | n_pos_t = graph.states[:-1, :dim] 347 | 348 | # update agent positions 349 | if dim == 2: 350 | for ii in range(n_agent): 351 | agent_circs[ii].set_center(tuple(n_pos_t[ii])) 352 | else: 353 | agent_col.set_offsets(n_pos_t[:n_agent * 2, :2]) 354 | agent_col.set_3d_properties(n_pos_t[:n_agent * 2, 2], zdir='z') 355 | 356 | # update edges 357 | e_edge_index_t = np.stack([graph.senders, graph.receivers], axis=0) 358 | is_pad_t = np.any(e_edge_index_t == n_agent * 2 + n_hits, axis=0) 359 | e_edge_index_t = e_edge_index_t[:, ~is_pad_t] 360 | e_start_t, e_end_t = n_pos_t[e_edge_index_t[0, :]], n_pos_t[e_edge_index_t[1, :]] 361 | e_is_goal_t = (n_agent <= graph.senders) & (graph.senders < n_agent * 2) 362 | e_is_goal_t = e_is_goal_t[~is_pad_t] 363 | e_colors_t = [edge_goal_color if e_is_goal_t[ii] else "0.2" for ii in range(len(e_start_t))] 364 | e_lines_t = np.stack([e_start_t, e_end_t], axis=1) 365 | edge_col.set_segments(e_lines_t) 366 | edge_col.set_colors(e_colors_t) 367 | 368 | # update agent labels 369 | for ii in range(n_agent): 370 | if dim == 2: 371 | agent_labels[ii].set_position(n_pos_t[ii]) 372 | else: 373 | text_pos = proj3d.proj_transform(n_pos_t[ii, 0], n_pos_t[ii, 1], n_pos_t[ii, 2], ax.get_proj())[:2] 374 | agent_labels[ii].set_position(text_pos) 375 | 376 | # update cost and safe labels 377 | if kk < len(rollout.T_cost): 378 | cost_text.set_text("Cost: {:5.4f}, Reward: {:5.4f}".format(rollout.T_cost[kk], rollout.T_reward[kk])) 379 | else: 380 | cost_text.set_text("") 381 | if kk < len(Ta_is_unsafe): 382 | a_is_unsafe = Ta_is_unsafe[kk] 383 | unsafe_idx = np.where(a_is_unsafe)[0] 384 | safe_text[0].set_text("Unsafe: {}".format(unsafe_idx)) 385 | else: 386 | safe_text[0].set_text("Unsafe: {}") 387 | 388 | # Update the contourf. 
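        # Editor's note (comment sketch): matplotlib contour artists cannot be
        # mutated in place, so the old collections are removed and the contours
        # are redrawn from scratch at every frame.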
389 | nonlocal cnt, cnt_line 390 | if "cbf" in viz_opts and dim == 2: 391 | for c in cnt.collections: 392 | c.remove() 393 | for c in cnt_line.collections: 394 | c.remove() 395 | 396 | bb_Xs_t, bb_Ys_t = np.meshgrid(Tb_xs[kk], Tb_ys[kk]) 397 | cnt = ax.contourf(bb_Xs_t, bb_Ys_t, Tbb_h[kk], **contour_opts) 398 | cnt_line = ax.contour(bb_Xs_t, bb_Ys_t, Tbb_h[kk], **contour_line_opts) 399 | 400 | cnt_col_t = [*cnt.collections, *cnt_line.collections] 401 | else: 402 | cnt_col_t = [] 403 | 404 | kk_text.set_text("kk={:04}".format(kk)) 405 | 406 | return [agent_col, edge_col, *agent_labels, cost_text, *safe_text, *cnt_col_t, kk_text] 407 | 408 | fps = 30.0 409 | spf = 1 / fps 410 | mspf = 1_000 * spf 411 | anim_T = len(T_graph.n_node) 412 | ani = FuncAnimation(fig, update, frames=anim_T, init_func=init_fn, interval=mspf, blit=True) 413 | save_anim(ani, video_path) 414 | -------------------------------------------------------------------------------- /gcbfplus/env/single_integrator.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import pathlib 3 | import jax 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | import numpy as np 7 | 8 | from typing import NamedTuple, Tuple, Optional 9 | 10 | from ..utils.graph import EdgeBlock, GetGraph, GraphsTuple 11 | from ..utils.typing import Action, Array, Cost, Done, Info, Pos2d, Reward, State, AgentState 12 | from ..utils.utils import merge01, jax_vmap 13 | from .base import MultiAgentEnv, RolloutResult 14 | from .obstacle import Obstacle, Rectangle 15 | from .plot import render_video 16 | from .utils import get_lidar, inside_obstacles, lqr, get_node_goal_rng 17 | 18 | 19 | class SingleIntegrator(MultiAgentEnv): 20 | 21 | AGENT = 0 22 | GOAL = 1 23 | OBS = 2 24 | 25 | class EnvState(NamedTuple): 26 | agent: State 27 | goal: State 28 | obstacle: Obstacle 29 | 30 | @property 31 | def n_agent(self) -> int: 32 | return self.agent.shape[0] 33 | 34 | EnvGraphsTuple = GraphsTuple[State, EnvState] 35 | 36 | PARAMS = { 37 | "car_radius": 0.05, 38 | "comm_radius": 0.5, 39 | "n_rays": 32, 40 | "obs_len_range": [0.1, 0.6], 41 | "n_obs": 8, 42 | } 43 | 44 | def __init__( 45 | self, 46 | num_agents: int, 47 | area_size: float, 48 | max_step: int = 256, 49 | max_travel: float = None, 50 | dt: float = 0.03, 51 | params: dict = None 52 | ): 53 | super(SingleIntegrator, self).__init__(num_agents, area_size, max_step, max_travel, dt, params) 54 | self._A = np.zeros((self.state_dim, self.state_dim), dtype=np.float32) * self._dt + np.eye(self.state_dim) 55 | self._B = np.array([[1.0, 0.0], [0.0, 1.0]]) * self._dt 56 | self._Q = np.eye(self.state_dim) * 2 57 | self._R = np.eye(self.action_dim) 58 | self._K = jnp.array(lqr(self._A, self._B, self._Q, self._R)) 59 | self.create_obstacles = jax_vmap(Rectangle.create) 60 | 61 | @property 62 | def state_dim(self) -> int: 63 | return 2 # x, y 64 | 65 | @property 66 | def node_dim(self) -> int: 67 | return 3 # indicator: agent: 001, goal: 010, obstacle: 100 68 | 69 | @property 70 | def edge_dim(self) -> int: 71 | return 2 # x_rel, y_rel 72 | 73 | @property 74 | def action_dim(self) -> int: 75 | return 2 # vx, vy 76 | 77 | def reset(self, key: Array) -> GraphsTuple: 78 | self._t = 0 79 | 80 | # randomly generate obstacles 81 | n_rng_obs = self._params["n_obs"] 82 | assert n_rng_obs >= 0 83 | obstacle_key, key = jr.split(key, 2) 84 | obs_pos = jr.uniform(obstacle_key, (n_rng_obs, 2), minval=0, maxval=self.area_size) 85 | length_key, key = jr.split(key, 2) 86 | 
obs_len = jr.uniform( 87 | length_key, 88 | (self._params["n_obs"], 2), 89 | minval=self._params["obs_len_range"][0], 90 | maxval=self._params["obs_len_range"][1], 91 | ) 92 | theta_key, key = jr.split(key, 2) 93 | obs_theta = jr.uniform(theta_key, (n_rng_obs,), minval=0, maxval=2 * np.pi) 94 | obstacles = self.create_obstacles(obs_pos, obs_len[:, 0], obs_len[:, 1], obs_theta) 95 | 96 | # randomly generate agent and goal 97 | states, goals = get_node_goal_rng( 98 | key, self.area_size, 2, obstacles, self.num_agents, 4 * self.params["car_radius"], self.max_travel) 99 | 100 | env_states = self.EnvState(states, goals, obstacles) 101 | 102 | return self.get_graph(env_states) 103 | 104 | def agent_step_euler(self, agent_states: AgentState, action: Action) -> AgentState: 105 | assert action.shape == (self.num_agents, self.action_dim) 106 | assert agent_states.shape == (self.num_agents, self.state_dim) 107 | x_dot = action 108 | n_state_agent_new = x_dot * self.dt + agent_states 109 | assert n_state_agent_new.shape == (self.num_agents, self.state_dim) 110 | return self.clip_state(n_state_agent_new) 111 | 112 | def step( 113 | self, graph: EnvGraphsTuple, action: Action, get_eval_info: bool = False 114 | ) -> Tuple[EnvGraphsTuple, Reward, Cost, Done, Info]: 115 | self._t += 1 116 | 117 | # calculate next graph 118 | agent_states = graph.type_states(type_idx=0, n_type=self.num_agents) 119 | goals = graph.type_states(type_idx=1, n_type=self.num_agents) 120 | obstacles = graph.env_states.obstacle 121 | action = self.clip_action(action) 122 | 123 | assert action.shape == (self.num_agents, self.action_dim) 124 | assert agent_states.shape == (self.num_agents, self.state_dim) 125 | 126 | next_agent_states = self.agent_step_euler(agent_states, action) 127 | 128 | # the episode ends when reaching max_episode_steps 129 | done = jnp.array(False) 130 | 131 | # compute reward and cost 132 | reward = jnp.zeros(()).astype(jnp.float32) 133 | reward -= (jnp.linalg.norm(action - self.u_ref(graph), axis=1) ** 2).mean() 134 | cost = self.get_cost(graph) 135 | 136 | assert reward.shape == tuple() 137 | assert cost.shape == tuple() 138 | assert done.shape == tuple() 139 | 140 | next_state = self.EnvState(next_agent_states, goals, obstacles) 141 | 142 | info = {} 143 | if get_eval_info: 144 | # collision between agents and obstacles 145 | agent_pos = agent_states 146 | info["inside_obstacles"] = inside_obstacles(agent_pos, obstacles, r=self._params["car_radius"]) 147 | 148 | return self.get_graph(next_state), reward, cost, done, info 149 | 150 | def get_cost(self, graph: GraphsTuple) -> Cost: 151 | agent_states = graph.type_states(type_idx=0, n_type=self.num_agents) 152 | obstacles = graph.env_states.obstacle 153 | 154 | # collision between agents 155 | agent_pos = agent_states 156 | dist = jnp.linalg.norm(jnp.expand_dims(agent_pos, 1) - jnp.expand_dims(agent_pos, 0), axis=-1) 157 | dist += jnp.eye(self.num_agents) * 1e6 158 | collision = (self._params["car_radius"] * 2 > dist).any(axis=1) 159 | cost = collision.mean() 160 | 161 | # collision between agents and obstacles 162 | collision = inside_obstacles(agent_pos, obstacles, r=self._params["car_radius"]) 163 | cost += collision.mean() 164 | 165 | return cost 166 | 167 | def render_video( 168 | self, 169 | rollout: RolloutResult, 170 | video_path: pathlib.Path, 171 | Ta_is_unsafe=None, 172 | viz_opts: dict = None, 173 | dpi: int = 100, 174 | **kwargs 175 | ) -> None: 176 | render_video( 177 | rollout=rollout, 178 | video_path=video_path, 179 | 
side_length=self.area_size, 180 | dim=2, 181 | n_agent=self.num_agents, 182 | n_rays=self.params["n_rays"], 183 | r=self.params["car_radius"], 184 | Ta_is_unsafe=Ta_is_unsafe, 185 | viz_opts=viz_opts, 186 | dpi=dpi, 187 | **kwargs 188 | ) 189 | 190 | def edge_blocks(self, state: EnvState, lidar_data: Pos2d) -> list[EdgeBlock]: 191 | n_hits = self._params["n_rays"] * self.num_agents 192 | 193 | # agent - agent connection 194 | agent_pos = state.agent 195 | pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :] # [i, j]: i -> j 196 | dist = jnp.linalg.norm(pos_diff, axis=-1) 197 | dist += jnp.eye(dist.shape[1]) * (self._params["comm_radius"] + 1) 198 | agent_agent_mask = jnp.less(dist, self._params["comm_radius"]) 199 | id_agent = jnp.arange(self.num_agents) 200 | agent_agent_edges = EdgeBlock(pos_diff, agent_agent_mask, id_agent, id_agent) 201 | 202 | # agent - goal connection, clipped to avoid too long edges 203 | id_goal = jnp.arange(self.num_agents, self.num_agents * 2) 204 | agent_goal_mask = jnp.eye(self.num_agents) 205 | agent_goal_feats = state.agent[:, None, :] - state.goal[None, :, :] 206 | feats_norm = jnp.sqrt(1e-6 + jnp.sum(agent_goal_feats[:, :2] ** 2, axis=-1, keepdims=True)) 207 | comm_radius = self._params["comm_radius"] 208 | safe_feats_norm = jnp.maximum(feats_norm, comm_radius) 209 | coef = jnp.where(feats_norm > comm_radius, comm_radius / safe_feats_norm, 1.0) 210 | agent_goal_feats = agent_goal_feats.at[:, :2].set(agent_goal_feats[:, :2] * coef) 211 | agent_goal_edges = EdgeBlock( 212 | agent_goal_feats, agent_goal_mask, id_agent, id_goal 213 | ) 214 | 215 | # agent - obs connection 216 | id_obs = jnp.arange(self.num_agents * 2, self.num_agents * 2 + n_hits) 217 | agent_obs_edges = [] 218 | for i in range(self.num_agents): 219 | id_hits = jnp.arange(i * self._params["n_rays"], (i + 1) * self._params["n_rays"]) 220 | lidar_feats = agent_pos[i, :] - lidar_data[id_hits, :] 221 | lidar_dist = jnp.linalg.norm(lidar_feats, axis=-1) 222 | active_lidar = jnp.less(lidar_dist, self._params["comm_radius"] - 1e-1) 223 | agent_obs_mask = jnp.ones((1, self._params["n_rays"])) 224 | agent_obs_mask = jnp.logical_and(agent_obs_mask, active_lidar) 225 | agent_obs_edges.append( 226 | EdgeBlock(lidar_feats[None, :, :], agent_obs_mask, id_agent[i][None], id_obs[id_hits]) 227 | ) 228 | 229 | return [agent_agent_edges, agent_goal_edges] + agent_obs_edges 230 | 231 | def control_affine_dyn(self, state: State) -> [Array, Array]: 232 | assert state.ndim == 2 233 | f = jnp.zeros_like(state) 234 | g = jnp.eye(state.shape[1]) 235 | g = jnp.expand_dims(g, axis=0).repeat(f.shape[0], axis=0) 236 | assert f.shape == state.shape 237 | assert g.shape == (state.shape[0], self.state_dim, self.action_dim) 238 | return f, g 239 | 240 | def add_edge_feats(self, graph: GraphsTuple, state: State) -> GraphsTuple: 241 | assert graph.is_single 242 | assert state.ndim == 2 243 | 244 | edge_feats = state[graph.receivers] - state[graph.senders] 245 | feats_norm = jnp.sqrt(1e-6 + jnp.sum(edge_feats[:, :2] ** 2, axis=-1, keepdims=True)) 246 | comm_radius = self._params["comm_radius"] 247 | safe_feats_norm = jnp.maximum(feats_norm, comm_radius) 248 | coef = jnp.where(feats_norm > comm_radius, comm_radius / safe_feats_norm, 1.0) 249 | edge_feats = edge_feats.at[:, :2].set(edge_feats[:, :2] * coef) 250 | 251 | return graph._replace(edges=edge_feats, states=state) 252 | 253 | def get_graph(self, state: EnvState) -> GraphsTuple: 254 | # node features 255 | n_hits = self._params["n_rays"] * self.num_agents 256 | n_nodes = 
2 * self.num_agents + n_hits 257 | node_feats = jnp.zeros((self.num_agents * 2 + n_hits, 3)) 258 | node_feats = node_feats.at[: self.num_agents, 2].set(1) # agent feats 259 | node_feats = node_feats.at[self.num_agents: self.num_agents * 2, 1].set(1) # goal feats 260 | node_feats = node_feats.at[-n_hits:, 0].set(1) # obs feats 261 | 262 | # node type 263 | node_type = jnp.zeros(n_nodes, dtype=jnp.int32) 264 | node_type = node_type.at[self.num_agents: self.num_agents * 2].set(SingleIntegrator.GOAL) 265 | node_type = node_type.at[-n_hits:].set(SingleIntegrator.OBS) 266 | 267 | # edge blocks 268 | get_lidar_vmap = jax_vmap( 269 | ft.partial( 270 | get_lidar, 271 | obstacles=state.obstacle, 272 | num_beams=self._params["n_rays"], 273 | sense_range=self._params["comm_radius"], 274 | ) 275 | ) 276 | lidar_data = merge01(get_lidar_vmap(state.agent)) 277 | edge_blocks = self.edge_blocks(state, lidar_data) 278 | 279 | # create graph 280 | return GetGraph( 281 | nodes=node_feats, 282 | node_type=node_type, 283 | edge_blocks=edge_blocks, 284 | env_states=state, 285 | states=jnp.concatenate([state.agent, state.goal, lidar_data], axis=0), 286 | ).to_padded() 287 | 288 | def state_lim(self, state: Optional[State] = None) -> Tuple[State, State]: 289 | lower_lim = jnp.ones(2) * -jnp.inf 290 | upper_lim = jnp.ones(2) * jnp.inf 291 | return lower_lim, upper_lim 292 | 293 | def action_lim(self) -> Tuple[Action, Action]: 294 | lower_lim = jnp.ones(2) * -1.0 295 | upper_lim = jnp.ones(2) 296 | return lower_lim, upper_lim 297 | 298 | def u_ref(self, graph: GraphsTuple) -> Action: 299 | agent = graph.type_states(type_idx=0, n_type=self.num_agents) 300 | goal = graph.type_states(type_idx=1, n_type=self.num_agents) 301 | error = goal - agent 302 | error_max = jnp.abs(error / jnp.linalg.norm(error, axis=-1, keepdims=True) * self._params["comm_radius"]) 303 | error = jnp.clip(error, -error_max, error_max) 304 | return self.clip_action(error @ self._K.T) 305 | 306 | def forward_graph(self, graph: GraphsTuple, action: Action) -> GraphsTuple: 307 | # calculate next graph 308 | agent_states = graph.type_states(type_idx=0, n_type=self.num_agents) 309 | goal_states = graph.type_states(type_idx=1, n_type=self.num_agents) 310 | obs_states = graph.type_states(type_idx=2, n_type=self._params["n_rays"] * self.num_agents) 311 | action = self.clip_action(action) 312 | 313 | assert action.shape == (self.num_agents, self.action_dim) 314 | assert agent_states.shape == (self.num_agents, self.state_dim) 315 | 316 | next_agent_states = self.agent_step_euler(agent_states, action) 317 | next_states = jnp.concatenate([next_agent_states, goal_states, obs_states], axis=0) 318 | 319 | next_graph = self.add_edge_feats(graph, next_states) 320 | 321 | return next_graph 322 | 323 | @ft.partial(jax.jit, static_argnums=(0,)) 324 | def safe_mask(self, graph: GraphsTuple) -> Array: 325 | agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents) 326 | 327 | # agents are not colliding 328 | pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :] # [i, j]: i -> j 329 | dist = jnp.linalg.norm(pos_diff, axis=-1) 330 | dist = dist + jnp.eye(dist.shape[1]) * (self._params["car_radius"] * 2 + 1) # remove self connection 331 | safe_agent = jnp.greater(dist, self._params["car_radius"] * 2.5) 332 | safe_agent = jnp.min(safe_agent, axis=1) 333 | 334 | safe_obs = jnp.logical_not( 335 | inside_obstacles(agent_pos, graph.env_states.obstacle, self._params["car_radius"] * 1.5) 336 | ) 337 | 338 | safe_mask = jnp.logical_and(safe_agent, safe_obs) 339 | 340 | 
return safe_mask 341 | 342 | @ft.partial(jax.jit, static_argnums=(0,)) 343 | def unsafe_mask(self, graph: GraphsTuple) -> Array: 344 | agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents) 345 | 346 | # agents are colliding 347 | pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :] # [i, j]: i -> j 348 | dist = jnp.linalg.norm(pos_diff, axis=-1) 349 | dist = dist + jnp.eye(dist.shape[1]) * (self._params["car_radius"] * 2 + 1) # remove self connection 350 | unsafe_agent = jnp.less(dist, self._params["car_radius"] * 2) 351 | unsafe_agent = jnp.max(unsafe_agent, axis=1) 352 | 353 | # agents are colliding with obstacles 354 | unsafe_obs = inside_obstacles(agent_pos, graph.env_states.obstacle, self._params["car_radius"]) 355 | 356 | unsafe_mask = jnp.logical_or(unsafe_agent, unsafe_obs) 357 | 358 | return unsafe_mask 359 | 360 | def collision_mask(self, graph: GraphsTuple) -> Array: 361 | return self.unsafe_mask(graph) 362 | 363 | def finish_mask(self, graph: GraphsTuple) -> Array: 364 | agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)[:, :2] 365 | goal_pos = graph.env_states.goal[:, :2] 366 | reach = jnp.linalg.norm(agent_pos - goal_pos, axis=1) < self._params["car_radius"] * 2 367 | return reach 368 | -------------------------------------------------------------------------------- /gcbfplus/env/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.numpy as jnp 3 | import functools as ft 4 | import jax 5 | import jax.random as jr 6 | 7 | from scipy.linalg import inv, solve_discrete_are 8 | from typing import Callable, Tuple 9 | from jax.lax import while_loop 10 | 11 | from ..utils.typing import Array, Radius, BoolScalar, Pos, State, Action, PRNGKey 12 | from ..utils.utils import merge01 13 | from .obstacle import Obstacle, Rectangle, Cuboid, Sphere 14 | 15 | 16 | def RK4_step(x_dot_fn: Callable, x: State, u: Action, dt: float) -> Array: 17 | k1 = x_dot_fn(x, u) 18 | k2 = x_dot_fn(x + 0.5 * dt * k1, u) 19 | k3 = x_dot_fn(x + 0.5 * dt * k2, u) 20 | k4 = x_dot_fn(x + dt * k3, u) 21 | return x + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4) 22 | 23 | 24 | def lqr( 25 | A: np.ndarray, 26 | B: np.ndarray, 27 | Q: np.ndarray, 28 | R: np.ndarray, 29 | ): 30 | """ 31 | Solve the discrete time lqr controller. 
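The gain K is computed from the stabilizing solution X of the discrete algebraic Riccati equation for the plant and cost below: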
32 | x_{t+1} = A x_t + B u_t 33 | cost = sum x.T*Q*x + u.T*R*u 34 | Code adapted from Mark Wilfred Mueller's continuous LQR code at 35 | https://www.mwm.im/lqr-controllers-with-python/ 36 | Based on Bertsekas, p.151 37 | Yields the control law u = -K x 38 | """ 39 | 40 | # first, try to solve the Riccati equation 41 | X = solve_discrete_are(A, B, Q, R) 42 | 43 | # compute the LQR gain 44 | K = inv(B.T @ X @ B + R) @ (B.T @ X @ A) 45 | 46 | return K 47 | 48 | 49 | def get_lidar(start_point: Pos, obstacles: Obstacle, num_beams: int, sense_range: float, max_returns: int = 32): 50 | if isinstance(obstacles, Rectangle): 51 | thetas = jnp.linspace(-np.pi, np.pi - 2 * np.pi / num_beams, num_beams) 52 | starts = start_point[None, :].repeat(num_beams, axis=0) 53 | ends = jnp.stack( 54 | [starts[..., 0] + jnp.cos(thetas) * sense_range, starts[..., 1] + jnp.sin(thetas) * sense_range], 55 | axis=-1) 56 | elif isinstance(obstacles, Cuboid) or isinstance(obstacles, Sphere): 57 | thetas = jnp.linspace(-np.pi / 2 + 2 * np.pi / num_beams, np.pi / 2 - 2 * np.pi / num_beams, num_beams // 2) 58 | phis = jnp.linspace(-np.pi, np.pi - 2 * np.pi / num_beams, num_beams) 59 | starts = start_point[None, :].repeat(thetas.shape[0] * phis.shape[0] + 2, axis=0) 60 | 61 | def get_end_point(theta, phi): 62 | return jnp.array([ 63 | start_point[0] + jnp.cos(theta) * jnp.cos(phi) * sense_range, 64 | start_point[1] + jnp.cos(theta) * jnp.sin(phi) * sense_range, 65 | start_point[2] + jnp.sin(theta) * sense_range 66 | ]) 67 | 68 | def get_end_point_theta(theta): 69 | return jax.vmap(lambda phi: get_end_point(theta, phi))(phis) 70 | 71 | ends = merge01(jax.vmap(get_end_point_theta)(thetas)) 72 | ends = jnp.concatenate([ends, 73 | start_point[None, :] + jnp.array([[0., 0., sense_range]]), 74 | start_point[None, :] + jnp.array([[0., 0., -sense_range]])], axis=0) 75 | else: 76 | raise NotImplementedError 77 | sensor_data = raytracing(starts, ends, obstacles, max_returns) 78 | 79 | return sensor_data 80 | 81 | 82 | def inside_obstacles(points: Pos, obstacles: Obstacle, r: Radius = 0.) -> BoolScalar: 83 | """ 84 | points: (n, n_dim) or (n_dim, ) 85 | obstacles: tree_stacked obstacles. 86 | 87 | Returns: (n, ) or (,). True if in collision, false otherwise. 
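Example (shapes only, assuming obstacles is a tree-stacked Rectangle set as built by the envs): inside_obstacles(jnp.zeros((4, 2)), obstacles, r=0.05) returns a (4,) boolean array.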
88 | """ 89 | # one point inside one obstacle 90 | def inside(point: Pos, obstacle: Obstacle): 91 | return obstacle.inside(point, r) 92 | 93 | # one point inside any obstacle 94 | def inside_any(point: Pos, obstacle: Obstacle): 95 | return jax.vmap(ft.partial(inside, point))(obstacle).max() 96 | 97 | # any point inside any obstacle 98 | if points.ndim == 1: 99 | if obstacles.center.shape[0] == 0: 100 | return jnp.zeros((), dtype=bool) 101 | is_in = inside_any(points, obstacles) 102 | else: 103 | if obstacles.center.shape[0] == 0: 104 | return jnp.zeros(points.shape[0], dtype=bool) 105 | is_in = jax.vmap(ft.partial(inside_any, obstacle=obstacles))(points) 106 | 107 | return is_in 108 | 109 | 110 | def raytracing(starts: Pos, ends: Pos, obstacles: Obstacle, max_returns: int) -> Pos: 111 | # if the start point if inside the obstacle, return the start point 112 | is_in = inside_obstacles(starts, obstacles) 113 | 114 | def raytracing_single(start: Pos, end: Pos, obstacle: Obstacle): 115 | return obstacle.raytracing(start, end) 116 | 117 | def raytracing_any(start: Pos, end: Pos, obstacle: Obstacle): 118 | return jax.vmap(ft.partial(raytracing_single, start, end))(obstacle).min() 119 | 120 | if obstacles.center.shape[0] == 0: 121 | alphas = jnp.ones(starts.shape[0]) * 1e6 122 | else: 123 | alphas = jax.vmap(ft.partial(raytracing_any, obstacle=obstacles))(starts, ends) 124 | alphas *= (1 - is_in) 125 | 126 | # assert max_returns <= alphas.shape[0] 127 | alphas_return = jnp.argsort(alphas)[:max_returns] 128 | 129 | hitting_points = starts + (ends - starts) * (alphas[..., None]) 130 | 131 | return hitting_points[alphas_return] 132 | 133 | 134 | def get_node_goal_rng( 135 | key: PRNGKey, 136 | side_length: float, 137 | dim: int, 138 | obstacles: Obstacle, 139 | n: int, 140 | min_dist: float, 141 | max_travel: float = None 142 | ) -> [Pos, Pos]: 143 | max_iter = 1024 # maximum number of iterations to find a valid initial state/goal 144 | states = jnp.zeros((n, dim)) 145 | goals = jnp.zeros((n, dim)) 146 | 147 | def get_node(reset_input: Tuple[int, Array, Array, Array]): # key, node, all nodes 148 | i_iter, this_key, _, all_nodes = reset_input 149 | use_key, this_key = jr.split(this_key, 2) 150 | i_iter += 1 151 | return i_iter, this_key, jr.uniform(use_key, (dim,), minval=0, maxval=side_length), all_nodes 152 | 153 | def non_valid_node(reset_input: Tuple[int, Array, Array, Array]): # key, node, all nodes 154 | i_iter, _, node, all_nodes = reset_input 155 | dist_min = jnp.linalg.norm(all_nodes - node, axis=1).min() 156 | collide = dist_min <= min_dist 157 | inside = inside_obstacles(node, obstacles, r=min_dist) 158 | valid = ~(collide | inside) | (i_iter >= max_iter) 159 | return ~valid 160 | 161 | def get_goal(reset_input: Tuple[int, Array, Array, Array, Array]): 162 | # key, goal_candidate, agent_start_pos, all_goals 163 | i_iter, this_key, _, agent, all_goals = reset_input 164 | use_key, this_key = jr.split(this_key, 2) 165 | i_iter += 1 166 | if max_travel is None: 167 | return i_iter, this_key, jr.uniform(use_key, (dim,), minval=0, maxval=side_length), agent, all_goals 168 | else: 169 | return i_iter, this_key, jr.uniform( 170 | use_key, (dim,), minval=-max_travel, maxval=max_travel) + agent, agent, all_goals 171 | 172 | def non_valid_goal(reset_input: Tuple[int, Array, Array, Array, Array]): 173 | # key, goal_candidate, agent_start_pos, all_goals 174 | i_iter, _, goal, agent, all_goals = reset_input 175 | dist_min = jnp.linalg.norm(all_goals - goal, axis=1).min() 176 | collide = dist_min <= 
min_dist 177 | inside = inside_obstacles(goal, obstacles, r=min_dist) 178 | outside = jnp.any(goal < 0) | jnp.any(goal > side_length) 179 | if max_travel is None: 180 | too_long = np.array(False, dtype=bool) 181 | else: 182 | too_long = jnp.linalg.norm(goal - agent) > max_travel 183 | valid = (~collide & ~inside & ~outside & ~too_long) | (i_iter >= max_iter) 184 | out = ~valid 185 | assert out.shape == tuple() and out.dtype == jnp.bool_ 186 | return out 187 | 188 | def reset_body(reset_input: Tuple[int, Array, Array, Array]): 189 | # agent_id, key, states, goals 190 | agent_id, this_key, all_states, all_goals = reset_input 191 | agent_key, goal_key, this_key = jr.split(this_key, 3) 192 | agent_candidate = jr.uniform(agent_key, (dim,), minval=0, maxval=side_length) 193 | n_iter_agent, _, agent_candidate, _ = while_loop( 194 | cond_fun=non_valid_node, body_fun=get_node, 195 | init_val=(0, agent_key, agent_candidate, all_states) 196 | ) 197 | all_states = all_states.at[agent_id].set(agent_candidate) 198 | 199 | if max_travel is None: 200 | goal_candidate = jr.uniform(goal_key, (dim,), minval=0, maxval=side_length) 201 | else: 202 | goal_candidate = jr.uniform(goal_key, (dim,), minval=0, maxval=max_travel) + agent_candidate 203 | 204 | n_iter_goal, _, goal_candidate, _, _ = while_loop( 205 | cond_fun=non_valid_goal, body_fun=get_goal, 206 | init_val=(0, goal_key, goal_candidate, agent_candidate, all_goals) 207 | ) 208 | all_goals = all_goals.at[agent_id].set(goal_candidate) 209 | agent_id += 1 210 | 211 | # if no solution is found, start over 212 | agent_id = (1 - (n_iter_agent >= max_iter)) * (1 - (n_iter_goal >= max_iter)) * agent_id 213 | all_states = (1 - (n_iter_agent >= max_iter)) * (1 - (n_iter_goal >= max_iter)) * all_states 214 | all_goals = (1 - (n_iter_agent >= max_iter)) * (1 - (n_iter_goal >= max_iter)) * all_goals 215 | 216 | return agent_id, this_key, all_states, all_goals 217 | 218 | def reset_not_terminate(reset_input: Tuple[int, Array, Array, Array]): 219 | # agent_id, key, states, goals 220 | agent_id, this_key, all_states, all_goals = reset_input 221 | return agent_id < n 222 | 223 | _, _, states, goals = while_loop( 224 | cond_fun=reset_not_terminate, body_fun=reset_body, init_val=(0, key, states, goals)) 225 | 226 | return states, goals 227 | -------------------------------------------------------------------------------- /gcbfplus/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/nn/__init__.py -------------------------------------------------------------------------------- /gcbfplus/nn/gnn.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import functools as ft 3 | import jax.numpy as jnp 4 | import jraph 5 | import jax.tree_util as jtu 6 | 7 | from typing import Type, NamedTuple, Callable, Tuple 8 | 9 | from ..utils.typing import EdgeAttr, Node, Array 10 | from ..utils.graph import GraphsTuple 11 | from .mlp import MLP, default_nn_init 12 | from .utils import safe_get 13 | 14 | save_attn = False 15 | 16 | 17 | def save_set_attn(v): 18 | global save_attn 19 | save_attn = v 20 | 21 | 22 | class GNNUpdate(NamedTuple): 23 | message: Callable[[EdgeAttr, Node, Node], Array] 24 | aggregate: Callable[[Array, Array, int], Array] 25 | update: Callable[[Node, Array], Array] 26 | 27 | def __call__(self, graph: GraphsTuple) -> GraphsTuple: 28 | assert 
graph.n_node.shape == tuple() 29 | node_feats_send = jtu.tree_map(lambda n: safe_get(n, graph.senders), graph.nodes) 30 | node_feats_recv = jtu.tree_map(lambda n: safe_get(n, graph.receivers), graph.nodes) 31 | 32 | # message passing 33 | edges = self.message(graph.edges, node_feats_send, node_feats_recv) 34 | 35 | # aggregate messages 36 | aggr_msg = jtu.tree_map(lambda edge: self.aggregate(edge, graph.receivers, graph.nodes.shape[0]), edges) 37 | 38 | # update nodes 39 | new_node_feats = self.update(graph.nodes, aggr_msg) 40 | 41 | return graph._replace(nodes=new_node_feats) 42 | 43 | 44 | class GNNLayer(nn.Module): 45 | msg_net_cls: Type[nn.Module] 46 | aggr_net_cls: Type[nn.Module] 47 | update_net_cls: Type[nn.Module] 48 | msg_dim: int 49 | out_dim: int 50 | 51 | @nn.compact 52 | def __call__(self, graph: GraphsTuple) -> GraphsTuple: 53 | def message(edge_feats: EdgeAttr, sender_feats: Node, receiver_feats: Node) -> Array: 54 | feats = jnp.concatenate([edge_feats, sender_feats, receiver_feats], axis=-1) 55 | feats = self.msg_net_cls()(feats) 56 | feats = nn.Dense(self.msg_dim, kernel_init=default_nn_init())(feats) 57 | return feats 58 | 59 | def update(node_feats: Node, msgs: Array) -> Array: 60 | feats = jnp.concatenate([node_feats, msgs], axis=-1) 61 | feats = self.update_net_cls()(feats) 62 | feats = nn.Dense(self.out_dim, kernel_init=default_nn_init())(feats) 63 | return feats 64 | 65 | def aggregate(msgs: Array, recv_idx: Array, num_segments: int) -> Array: 66 | gate_feats = self.aggr_net_cls()(msgs) 67 | gate_feats = nn.Dense(1, kernel_init=default_nn_init())(gate_feats).squeeze(-1) 68 | attn = jraph.segment_softmax(gate_feats, segment_ids=recv_idx, num_segments=num_segments) 69 | assert attn.shape[0] == msgs.shape[0] 70 | 71 | aggr_msg = jraph.segment_sum(attn[:, None] * msgs, segment_ids=recv_idx, num_segments=num_segments) 72 | return aggr_msg 73 | 74 | update_fn = GNNUpdate(message, aggregate, update) 75 | return update_fn(graph) 76 | 77 | 78 | class GNN(nn.Module): 79 | msg_dim: int 80 | hid_size_msg: Tuple[int, ...] 81 | hid_size_aggr: Tuple[int, ...] 82 | hid_size_update: Tuple[int, ...] 
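# out_dim: node feature width after the final layer; n_layers: number of message-passing rounds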
83 | out_dim: int 84 | n_layers: int 85 | 86 | @nn.compact 87 | def __call__(self, graph: GraphsTuple, node_type: int = None, n_type: int = None) -> Array: 88 | for i in range(self.n_layers): 89 | out_dim = self.out_dim if i == self.n_layers - 1 else self.msg_dim 90 | msg_net = ft.partial(MLP, hid_sizes=self.hid_size_msg, act=nn.relu, act_final=False, name="msg") 91 | attn_net = ft.partial(MLP, hid_sizes=self.hid_size_aggr, act=nn.relu, act_final=False, name="attn") 92 | update_net = ft.partial(MLP, hid_sizes=self.hid_size_update, act=nn.relu, act_final=False, name="update") 93 | gnn_layer = GNNLayer( 94 | msg_net_cls=msg_net, 95 | aggr_net_cls=attn_net, 96 | update_net_cls=update_net, 97 | msg_dim=self.msg_dim, 98 | out_dim=out_dim, 99 | ) 100 | graph = gnn_layer(graph) 101 | if node_type is None: 102 | return graph.nodes 103 | else: 104 | return graph.type_nodes(node_type, n_type) 105 | -------------------------------------------------------------------------------- /gcbfplus/nn/mlp.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | 3 | from .utils import default_nn_init, scaled_init, AnyFloat, HidSizes, ActFn, signal_last_enumerate 4 | 5 | 6 | class MLP(nn.Module): 7 | hid_sizes: HidSizes 8 | act: ActFn = nn.relu 9 | act_final: bool = True 10 | use_layernorm: bool = False 11 | scale_final: float | None = None 12 | dropout_rate: float | None = None 13 | 14 | @nn.compact 15 | def __call__(self, x: AnyFloat, apply_dropout: bool = False) -> AnyFloat: 16 | nn_init = default_nn_init 17 | for is_last_layer, ii, hid_size in signal_last_enumerate(self.hid_sizes): 18 | if is_last_layer and self.scale_final is not None: 19 | x = nn.Dense(hid_size, kernel_init=scaled_init(nn_init(), self.scale_final))(x) 20 | else: 21 | x = nn.Dense(hid_size, kernel_init=nn_init())(x) 22 | 23 | no_activation = is_last_layer and not self.act_final 24 | if not no_activation: 25 | if self.dropout_rate is not None and self.dropout_rate > 0: 26 | x = nn.Dropout(rate=self.dropout_rate, deterministic=not apply_dropout)(x) 27 | if self.use_layernorm: 28 | x = nn.LayerNorm()(x) 29 | x = self.act(x) 30 | return x 31 | -------------------------------------------------------------------------------- /gcbfplus/nn/utils.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | from typing import Any, Callable, Literal, Sequence, Iterable, Generator, TypeVar 5 | from jaxtyping import Array, Bool, Float, Int, Shaped 6 | 7 | 8 | ActFn = Callable[[Array], Array] 9 | PRGNKey = Float[Array, '2'] 10 | AnyFloat = Float[Array, '*'] 11 | Shape = tuple[int, ...] 
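# InitFn mirrors the flax initializer signature: (rng_key, shape, dtype) -> Array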
12 | InitFn = Callable[[PRGNKey, Shape, Any], Any] 13 | HidSizes = Sequence[int] 14 | 15 | 16 | _Elem = TypeVar("_Elem") 17 | 18 | 19 | default_nn_init = nn.initializers.xavier_uniform 20 | 21 | 22 | def scaled_init(initializer: nn.initializers.Initializer, scale: float) -> nn.initializers.Initializer: 23 | def scaled_init_inner(*args, **kwargs) -> AnyFloat: 24 | return scale * initializer(*args, **kwargs) 25 | 26 | return scaled_init_inner 27 | 28 | 29 | ActLiteral = Literal["relu", "tanh", "elu", "swish", "silu", "gelu", "softplus"] 30 | 31 | 32 | def get_act_from_str(act_str: ActLiteral) -> ActFn: 33 | act_dict: dict[Literal, ActFn] = dict( 34 | relu=nn.relu, tanh=nn.tanh, elu=nn.elu, swish=nn.swish, silu=nn.silu, gelu=nn.gelu, softplus=nn.softplus 35 | ) 36 | return act_dict[act_str] 37 | 38 | 39 | def signal_last_enumerate(it: Iterable[_Elem]) -> Generator[tuple[bool, int, _Elem], None, None]: 40 | iterable = iter(it) 41 | count = 0 42 | ret_var = next(iterable) 43 | for val in iterable: 44 | yield False, count, ret_var 45 | count += 1 46 | ret_var = val 47 | yield True, count, ret_var 48 | 49 | 50 | def safe_get(arr, idx): 51 | return arr.at[idx].get(mode='fill', fill_value=jnp.nan) 52 | -------------------------------------------------------------------------------- /gcbfplus/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/trainer/__init__.py -------------------------------------------------------------------------------- /gcbfplus/trainer/buffer.py: -------------------------------------------------------------------------------- 1 | import jax.tree_util as jtu 2 | import numpy as np 3 | 4 | from abc import ABC, abstractproperty, abstractmethod 5 | from .data import Rollout 6 | from .utils import jax2np, np2jax 7 | from ..utils.utils import tree_merge 8 | from ..utils.typing import Array 9 | 10 | 11 | class Buffer(ABC): 12 | 13 | def __init__(self, size: int): 14 | self._size = size 15 | 16 | @abstractmethod 17 | def append(self, rollout: Rollout): 18 | pass 19 | 20 | @abstractmethod 21 | def sample(self, batch_size: int) -> Rollout: 22 | pass 23 | 24 | @abstractproperty 25 | def length(self) -> int: 26 | pass 27 | 28 | 29 | class ReplayBuffer(Buffer): 30 | 31 | def __init__(self, size: int): 32 | super(ReplayBuffer, self).__init__(size) 33 | self._buffer = None 34 | 35 | def append(self, rollout: Rollout): 36 | if self._buffer is None: 37 | self._buffer = jax2np(rollout) 38 | else: 39 | self._buffer = tree_merge([self._buffer, jax2np(rollout)]) 40 | if self._buffer.length > self._size: 41 | self._buffer = jtu.tree_map(lambda x: x[-self._size:], self._buffer) 42 | 43 | def sample(self, batch_size: int) -> Rollout: 44 | idx = np.random.randint(0, self._buffer.length, batch_size) 45 | return np2jax(self.get_data(idx)) 46 | 47 | def get_data(self, idx: np.ndarray) -> Rollout: 48 | return jtu.tree_map(lambda x: x[idx], self._buffer) 49 | 50 | @property 51 | def length(self) -> int: 52 | if self._buffer is None: 53 | return 0 54 | return self._buffer.n_data 55 | 56 | 57 | class MaskedReplayBuffer: 58 | 59 | def __init__(self, size: int): 60 | self._size = size 61 | # (b, T) 62 | self._buffer = None 63 | self._safe_mask = None 64 | self._unsafe_mask = None 65 | 66 | def append(self, rollout: Rollout, safe_mask: Array, unsafe_mask: Array): 67 | if self._buffer is None: 68 | self._buffer = jax2np(rollout) 69 | self._safe_mask = 
jax2np(safe_mask) 70 | self._unsafe_mask = jax2np(unsafe_mask) 71 | # self._mid_mask = jax2np(mid_mask) 72 | else: 73 | self._buffer = tree_merge([self._buffer, jax2np(rollout)]) 74 | self._safe_mask = tree_merge([self._safe_mask, jax2np(safe_mask)]) 75 | self._unsafe_mask = tree_merge([self._unsafe_mask, jax2np(unsafe_mask)]) 76 | if self._buffer.length > self._size: 77 | self._buffer = jtu.tree_map(lambda x: x[-self._size:], self._buffer) 78 | self._safe_mask = jtu.tree_map(lambda x: x[-self._size:], self._safe_mask) 79 | self._unsafe_mask = jtu.tree_map(lambda x: x[-self._size:], self._unsafe_mask) 80 | 81 | def sample(self, batch_size: int) -> [Rollout, Array, Array]: 82 | idx = np.random.randint(0, self._buffer.length, batch_size) 83 | rollout, safe_mask, unsafe_mask = self.get_data(idx) 84 | return rollout, safe_mask, unsafe_mask 85 | 86 | def get_data(self, idx: np.ndarray) -> [Rollout, Array, Array]: 87 | return jtu.tree_map(lambda x: x[idx], self._buffer), self._safe_mask[idx], self._unsafe_mask[idx] 88 | 89 | @property 90 | def length(self) -> int: 91 | if self._buffer is None: 92 | return 0 93 | return self._buffer.n_data 94 | -------------------------------------------------------------------------------- /gcbfplus/trainer/data.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | from ..utils.typing import Array 4 | from ..utils.typing import Action, Reward, Cost, Done 5 | from ..utils.graph import GraphsTuple 6 | 7 | 8 | class Rollout(NamedTuple): 9 | graph: GraphsTuple 10 | actions: Action 11 | rewards: Reward 12 | costs: Cost 13 | dones: Done 14 | log_pis: Array 15 | next_graph: GraphsTuple 16 | 17 | @property 18 | def length(self) -> int: 19 | return self.rewards.shape[0] 20 | 21 | @property 22 | def time_horizon(self) -> int: 23 | return self.rewards.shape[1] 24 | 25 | @property 26 | def num_agents(self) -> int: 27 | return self.rewards.shape[2] 28 | 29 | @property 30 | def n_data(self) -> int: 31 | return self.length * self.time_horizon 32 | -------------------------------------------------------------------------------- /gcbfplus/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import os 3 | import numpy as np 4 | import jax 5 | import jax.random as jr 6 | import functools as ft 7 | 8 | from time import time 9 | from tqdm import tqdm 10 | 11 | from .data import Rollout 12 | from .utils import rollout 13 | from ..env import MultiAgentEnv 14 | from ..algo.base import MultiAgentController 15 | from ..utils.utils import jax_vmap 16 | 17 | 18 | class Trainer: 19 | 20 | def __init__( 21 | self, 22 | env: MultiAgentEnv, 23 | env_test: MultiAgentEnv, 24 | algo: MultiAgentController, 25 | n_env_train: int, 26 | n_env_test: int, 27 | log_dir: str, 28 | seed: int, 29 | params: dict, 30 | save_log: bool = True 31 | ): 32 | self.env = env 33 | self.env_test = env_test 34 | self.algo = algo 35 | self.n_env_train = n_env_train 36 | self.n_env_test = n_env_test 37 | self.log_dir = log_dir 38 | self.seed = seed 39 | 40 | if Trainer._check_params(params): 41 | self.params = params 42 | 43 | # make dir for the models 44 | if save_log: 45 | if not os.path.exists(log_dir): 46 | os.mkdir(log_dir) 47 | self.model_dir = os.path.join(log_dir, 'models') 48 | if not os.path.exists(self.model_dir): 49 | os.mkdir(self.model_dir) 50 | 51 | wandb.login() 52 | wandb.init(name=params['run_name'], project='gcbf+', dir=self.log_dir) 53 | 54 | self.save_log = 
save_log 55 | 56 | self.steps = params['training_steps'] 57 | self.eval_interval = params['eval_interval'] 58 | self.eval_epi = params['eval_epi'] 59 | self.save_interval = params['save_interval'] 60 | 61 | self.update_steps = 0 62 | self.key = jax.random.PRNGKey(seed) 63 | 64 | @staticmethod 65 | def _check_params(params: dict) -> bool: 66 | assert 'run_name' in params, 'run_name not found in params' 67 | assert 'training_steps' in params, 'training_steps not found in params' 68 | assert 'eval_interval' in params, 'eval_interval not found in params' 69 | assert params['eval_interval'] > 0, 'eval_interval must be positive' 70 | assert 'eval_epi' in params, 'eval_epi not found in params' 71 | assert params['eval_epi'] >= 1, 'eval_epi must be greater than or equal to 1' 72 | assert 'save_interval' in params, 'save_interval not found in params' 73 | assert params['save_interval'] > 0, 'save_interval must be positive' 74 | return True 75 | 76 | def train(self): 77 | # record start time 78 | start_time = time() 79 | 80 | # preprocess the rollout function 81 | def rollout_fn_single(params, key): 82 | return rollout(self.env, ft.partial(self.algo.step, params=params), key) 83 | 84 | def rollout_fn(params, keys): 85 | return jax.vmap(ft.partial(rollout_fn_single, params))(keys) 86 | 87 | rollout_fn = jax.jit(rollout_fn) 88 | 89 | # preprocess the test function 90 | def test_fn_single(params, key): 91 | return rollout(self.env_test, lambda graph, k: (self.algo.act(graph, params), None), key) 92 | 93 | def test_fn(params, keys): 94 | return jax.vmap(ft.partial(test_fn_single, params))(keys) 95 | 96 | test_fn = jax.jit(test_fn) 97 | 98 | # start training 99 | test_key = jr.PRNGKey(self.seed) 100 | test_keys = jr.split(test_key, 1_000)[:self.n_env_test] 101 | 102 | pbar = tqdm(total=self.steps, ncols=80) 103 | for step in range(0, self.steps + 1): 104 | # evaluate the algorithm 105 | if step % self.eval_interval == 0: 106 | test_rollouts: Rollout = test_fn(self.algo.actor_params, test_keys) 107 | total_reward = test_rollouts.rewards.sum(axis=-1) 108 | assert total_reward.shape == (self.n_env_test,) 109 | reward_min, reward_max = total_reward.min(), total_reward.max() 110 | reward_mean = np.mean(total_reward) 111 | reward_final = np.mean(test_rollouts.rewards[:, -1]) 112 | finish_fun = jax_vmap(jax_vmap(self.env_test.finish_mask)) 113 | finish = finish_fun(test_rollouts.graph).max(axis=1).mean() 114 | cost = test_rollouts.costs.sum(axis=-1).mean() 115 | unsafe_frac = np.mean(test_rollouts.costs.max(axis=-1) >= 1e-6) 116 | eval_info = { 117 | "eval/reward": reward_mean, 118 | "eval/reward_final": reward_final, 119 | "eval/cost": cost, 120 | "eval/unsafe_frac": unsafe_frac, 121 | "eval/finish": finish, 122 | "step": step, 123 | } 124 | wandb.log(eval_info, step=self.update_steps) 125 | time_since_start = time() - start_time 126 | eval_verbose = (f'step: {step:3}, time: {time_since_start:5.0f}s, reward: {reward_mean:9.4f}, ' 127 | f'min/max reward: {reward_min:7.2f}/{reward_max:7.2f}, cost: {cost:8.4f}, ' 128 | f'unsafe_frac: {unsafe_frac:6.2f}, finish: {finish:6.2f}') 129 | tqdm.write(eval_verbose) 130 | if self.save_log and step % self.save_interval == 0: 131 | self.algo.save(os.path.join(self.model_dir), step) 132 | 133 | # collect rollouts 134 | key_x0, self.key = jax.random.split(self.key) 135 | key_x0 = jax.random.split(key_x0, self.n_env_train) 136 | rollouts: Rollout = rollout_fn(self.algo.actor_params, key_x0) 137 | 138 | # update the algorithm 139 | update_info = self.algo.update(rollouts, step) 
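# log the loss/metric dict returned by the algorithm update at the current update step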
140 | wandb.log(update_info, step=self.update_steps) 141 | self.update_steps += 1 142 | 143 | pbar.update(1) 144 | -------------------------------------------------------------------------------- /gcbfplus/trainer/utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.tree_util as jtu 3 | import jax 4 | import numpy as np 5 | import socket 6 | import matplotlib.pyplot as plt 7 | import functools as ft 8 | import seaborn as sns 9 | import optax 10 | 11 | from typing import Callable, TYPE_CHECKING 12 | from matplotlib.colors import CenteredNorm 13 | 14 | from ..utils.typing import PRNGKey 15 | from ..utils.graph import GraphsTuple 16 | from .data import Rollout 17 | 18 | 19 | if TYPE_CHECKING: 20 | from ..env import MultiAgentEnv 21 | else: 22 | MultiAgentEnv = None 23 | 24 | 25 | def rollout( 26 | env: MultiAgentEnv, 27 | actor: Callable, 28 | key: PRNGKey 29 | ) -> Rollout: 30 | """ 31 | Get a rollout from the environment using the actor. 32 | 33 | Parameters 34 | ---------- 35 | env: MultiAgentEnv 36 | actor: Callable, [GraphsTuple, PRNGKey] -> [Action, LogPi] 37 | key: PRNGKey 38 | 39 | Returns 40 | ------- 41 | data: Rollout 42 | """ 43 | key_x0, key = jax.random.split(key) 44 | init_graph = env.reset(key_x0) 45 | 46 | def body(graph, key_): 47 | action, log_pi = actor(graph, key_) 48 | next_graph, reward, cost, done, info = env.step(graph, action) 49 | return next_graph, (graph, action, reward, cost, done, log_pi, next_graph) 50 | 51 | keys = jax.random.split(key, env.max_episode_steps) 52 | final_graph, (graphs, actions, rewards, costs, dones, log_pis, next_graphs) = \ 53 | jax.lax.scan(body, init_graph, keys, length=env.max_episode_steps) 54 | data = Rollout(graphs, actions, rewards, costs, dones, log_pis, next_graphs) 55 | return data 56 | 57 | 58 | def has_nan(x): 59 | return jtu.tree_map(lambda y: jnp.isnan(y).any(), x) 60 | 61 | 62 | def has_any_nan(x): 63 | return jnp.array(jtu.tree_flatten(has_nan(x))[0]).any() 64 | 65 | 66 | def compute_norm(grad): 67 | return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in jtu.tree_leaves(grad))) 68 | 69 | 70 | def compute_norm_and_clip(grad, max_norm: float): 71 | g_norm = compute_norm(grad) 72 | clipped_g_norm = jnp.maximum(max_norm, g_norm) 73 | clipped_grad = jtu.tree_map(lambda t: (t / clipped_g_norm) * max_norm, grad) 74 | 75 | return clipped_grad, g_norm 76 | 77 | 78 | def tree_copy(tree): 79 | return jtu.tree_map(lambda x: x.copy(), tree) 80 | 81 | 82 | def empty_grad_tx() -> optax.GradientTransformation: 83 | def init_fn(params): 84 | return optax.EmptyState() 85 | 86 | def update_fn(updates, state, params=None): 87 | return None, None 88 | 89 | return optax.GradientTransformation(init_fn, update_fn) 90 | 91 | 92 | def jax2np(x): 93 | return jtu.tree_map(lambda y: np.array(y), x) 94 | 95 | 96 | def np2jax(x): 97 | return jtu.tree_map(lambda y: jnp.array(y), x) 98 | 99 | 100 | def is_connected(): 101 | try: 102 | sock = socket.create_connection(("www.google.com", 80)) 103 | if sock is not None: 104 | sock.close() 105 | return True 106 | except OSError: 107 | pass 108 | print('No internet connection') 109 | return False 110 | 111 | 112 | def plot_cbf( 113 | fig: plt.Figure, 114 | cbf: Callable, 115 | env: MultiAgentEnv, 116 | graph: GraphsTuple, 117 | agent_id: int, 118 | x_dim: int, 119 | y_dim: int, 120 | ) -> plt.Figure: 121 | ax = fig.gca() 122 | n_mesh = 30 123 | low_lim, high_lim = env.state_lim(graph.states) 124 | x, y = jnp.meshgrid( 125 | 
jnp.linspace(low_lim[x_dim], high_lim[x_dim], n_mesh), 126 | jnp.linspace(low_lim[y_dim], high_lim[y_dim], n_mesh) 127 | ) 128 | states = graph.states 129 | 130 | # generate new states 131 | plot_states = states[None, None, :, :].repeat(n_mesh, axis=0).repeat(n_mesh, axis=1) 132 | plot_states = plot_states.at[:, :, agent_id, x_dim].set(x) 133 | plot_states = plot_states.at[:, :, agent_id, y_dim].set(y) 134 | 135 | get_new_graph = env.add_edge_feats 136 | get_new_graph_vmap = jax.vmap(jax.vmap(ft.partial(get_new_graph, graph))) 137 | new_graph = get_new_graph_vmap(plot_states) 138 | h = jax.vmap(jax.vmap(cbf))(new_graph)[:, :, agent_id, :].squeeze(-1) 139 | plt.contourf(x, y, h, cmap=sns.color_palette("rocket", as_cmap=True), levels=15, alpha=0.5) 140 | plt.colorbar() 141 | plt.contour(x, y, h, levels=[0.0], colors='blue') 142 | ax.set_xlim(low_lim[0], high_lim[0]) 143 | ax.set_ylim(low_lim[1], high_lim[1]) 144 | plt.axis('off') 145 | 146 | return fig 147 | 148 | 149 | def get_bb_cbf(cbf: Callable, env: MultiAgentEnv, graph: GraphsTuple, agent_id: int, x_dim: int, y_dim: int): 150 | n_mesh = 20 151 | low_lim = jnp.array([0, 0]) 152 | high_lim = jnp.array([env.area_size, env.area_size]) 153 | b_xs = jnp.linspace(low_lim[x_dim], high_lim[x_dim], n_mesh) 154 | b_ys = jnp.linspace(low_lim[y_dim], high_lim[y_dim], n_mesh) 155 | bb_Xs, bb_Ys = jnp.meshgrid(b_xs, b_ys) 156 | states = graph.states 157 | 158 | # generate new states 159 | bb_plot_states = states[None, None, :, :].repeat(n_mesh, axis=0).repeat(n_mesh, axis=1) 160 | bb_plot_states = bb_plot_states.at[:, :, agent_id, x_dim].set(bb_Xs) 161 | bb_plot_states = bb_plot_states.at[:, :, agent_id, y_dim].set(bb_Ys) 162 | 163 | get_new_graph = env.add_edge_feats 164 | get_new_graph_vmap = jax.vmap(jax.vmap(ft.partial(get_new_graph, graph))) 165 | bb_new_graph = get_new_graph_vmap(bb_plot_states) 166 | bb_h = jax.vmap(jax.vmap(cbf))(bb_new_graph)[:, :, agent_id, :].squeeze(-1) 167 | assert bb_h.shape == (n_mesh, n_mesh) 168 | return b_xs, b_ys, bb_h 169 | 170 | 171 | def centered_norm(vmin: float | list[float], vmax: float | list[float]): 172 | if isinstance(vmin, list): 173 | vmin = min(vmin) 174 | if isinstance(vmax, list): 175 | vmax = max(vmax) 176 | halfrange = max(abs(vmin), abs(vmax)) 177 | return CenteredNorm(0, halfrange) 178 | -------------------------------------------------------------------------------- /gcbfplus/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/utils/__init__.py -------------------------------------------------------------------------------- /gcbfplus/utils/graph.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, NamedTuple, TypeVar, get_type_hints 2 | 3 | import einops as ei 4 | import jax.numpy as jnp 5 | import jax.tree_util as jtu 6 | from jax._src.tree_util import GetAttrKey 7 | 8 | from ..utils.typing import Any, Array, Bool, Float, Int 9 | from .utils import merge01 10 | 11 | _State = TypeVar("_State") 12 | _EnvState = TypeVar("_EnvState") 13 | 14 | 15 | class EdgeBlock(NamedTuple): 16 | edge_feats: Float[Array, "n_recv n_send n_edge_feat"] 17 | edge_mask: Bool[Array, "n_recv n_send"] 18 | ids_recv: Int[Array, "n_recv"] 19 | ids_send: Int[Array, "n_send"] 20 | 21 | @property 22 | def n_recv(self): 23 | assert self.edge_feats.shape[0] == self.edge_mask.shape[0] == len(self.ids_recv) 24 | return 
len(self.ids_recv) 25 | 26 | @property 27 | def n_send(self): 28 | assert self.edge_feats.shape[1] == self.edge_mask.shape[1] == len(self.ids_send) 29 | return len(self.ids_send) 30 | 31 | @property 32 | def n_edges(self): 33 | return self.n_recv * self.n_send 34 | 35 | def make_edges(self, pad_id: int, edge_mask: Bool[Array, "n_recv n_send"] = None): 36 | id_recv_rep = ei.repeat(self.ids_recv, "n_recv -> n_recv n_send", n_send=self.n_send) 37 | id_send_rep = ei.repeat(self.ids_send, "n_send -> n_recv n_send", n_recv=self.n_recv) 38 | edge_mask = self.edge_mask if edge_mask is None else edge_mask 39 | e_recvs = merge01(jnp.where(edge_mask, id_recv_rep, pad_id)) 40 | e_sends = merge01(jnp.where(edge_mask, id_send_rep, pad_id)) 41 | e_edge_feats = merge01(self.edge_feats) 42 | assert e_recvs.shape == e_sends.shape == e_edge_feats.shape[:1] == (self.n_edges,) 43 | 44 | return e_edge_feats, e_recvs, e_sends 45 | 46 | 47 | @jtu.register_pytree_with_keys_class 48 | class GraphsTuple(tuple, Generic[_State, _EnvState]): 49 | n_node: Int[Array, "n_graph"] # number of nodes in each subgraph 50 | n_edge: Int[Array, "n_graph"] # number of edges in each subgraph 51 | 52 | nodes: Float[Array, "sum_n_node ..."] # node features 53 | edges: Float[Array, "sum_n_edge ..."] # edge features 54 | states: _State # node state features 55 | receivers: Int[Array, "sum_n_edge"] 56 | senders: Int[Array, "sum_n_edge"] 57 | node_type: Int[Array, "sum_n_node"] # by default, 0 is agent, -1 is padding 58 | env_states: _EnvState # environment state features 59 | connectivity: Int[Array, "sum_n_node sum_n_node"] = None # desired connectivity matrix 60 | 61 | def __new__( 62 | cls, 63 | n_node, 64 | n_edge, 65 | nodes, 66 | edges, 67 | states: _State, 68 | receivers, 69 | senders, 70 | node_type, 71 | env_states: _EnvState, 72 | connectivity=None, 73 | ): 74 | tup = (n_node, n_edge, nodes, edges, states, receivers, senders, node_type, env_states, connectivity) 75 | self = tuple.__new__(cls, tup) 76 | self.n_node = n_node 77 | self.n_edge = n_edge 78 | self.nodes = nodes 79 | self.edges = edges 80 | self.states = states 81 | self.receivers = receivers 82 | self.senders = senders 83 | self.node_type = node_type 84 | self.env_states = env_states 85 | self.connectivity = connectivity 86 | return self 87 | 88 | def tree_flatten_with_keys(self): 89 | flat_contents = [(GetAttrKey(k), getattr(self, k)) for k in get_type_hints(GraphsTuple).keys()] 90 | aux_data = None 91 | return flat_contents, aux_data 92 | 93 | @classmethod 94 | def tree_unflatten(cls, aux_data, children): 95 | return cls(*children) 96 | 97 | @property 98 | def is_single(self) -> bool: 99 | return self.n_node.ndim == 0 100 | 101 | @property 102 | def n_graphs(self) -> int: 103 | if self.n_node.ndim == 0: 104 | return 1 105 | assert len(self.n_node) == len(self.n_edge) 106 | return len(self.n_node) 107 | 108 | @property 109 | def batch_shape(self): 110 | return self.n_node.shape 111 | 112 | def type_nodes(self, type_idx: int, n_type: int) -> Float[Array, "... 
n_type n_feats"]: 113 | assert self.nodes.ndim == 2 114 | n_feats = self.nodes.shape[1] 115 | 116 | n_is_type = self.node_type == type_idx 117 | idx = jnp.cumsum(n_is_type) - 1 118 | 119 | sum_n_type = self.n_graphs * n_type 120 | type_feats = jnp.zeros((sum_n_type, n_feats)) 121 | type_feats = type_feats.at[idx, :].add(n_is_type[:, None] * self.nodes) 122 | 123 | out = type_feats.reshape(self.batch_shape + (n_type, n_feats)) 124 | return out 125 | 126 | def type_states(self, type_idx: int, n_type: int) -> Float[Array, "... n_type n_states"]: 127 | assert self.states.ndim == 2 128 | n_states = self.states.shape[1] 129 | 130 | n_is_type = self.node_type == type_idx 131 | idx = jnp.cumsum(n_is_type) - 1 132 | 133 | sum_n_type = self.n_graphs * n_type 134 | type_feats = jnp.zeros((sum_n_type, n_states)) 135 | type_feats = type_feats.at[idx, :].add(n_is_type[:, None] * self.states) 136 | 137 | out = type_feats.reshape(self.batch_shape + (n_type, n_states)) 138 | return out 139 | 140 | def __str__(self) -> str: 141 | node_repr = str(self.nodes) 142 | edge_repr = str(self.edges) 143 | 144 | return "n_node={}, n_edge={}, \n{}\n---------\n{}\n-------\n{}\n | \n{}".format( 145 | self.n_node, self.n_edge, node_repr, edge_repr, self.senders, self.receivers 146 | ) 147 | 148 | def _replace( 149 | self, 150 | n_node=None, 151 | n_edge=None, 152 | nodes=None, 153 | edges=None, 154 | states: _State = None, 155 | receivers=None, 156 | senders=None, 157 | node_type=None, 158 | env_states: _EnvState = None, 159 | connectivity=None, 160 | ) -> "GraphsTuple": 161 | return GraphsTuple( 162 | self.n_node if n_node is None else n_node, 163 | self.n_edge if n_edge is None else n_edge, 164 | self.nodes if nodes is None else nodes, 165 | self.edges if edges is None else edges, 166 | self.states if states is None else states, 167 | self.receivers if receivers is None else receivers, 168 | self.senders if senders is None else senders, 169 | self.node_type if node_type is None else node_type, 170 | self.env_states if env_states is None else env_states, 171 | self.connectivity if connectivity is None else connectivity, 172 | ) 173 | 174 | def without_edge(self): 175 | return GraphsTuple( 176 | self.n_node, 177 | self.n_edge, 178 | self.nodes, 179 | None, 180 | self.states, 181 | self.receivers, 182 | self.senders, 183 | self.node_type, 184 | self.env_states, 185 | self.connectivity, 186 | ) 187 | 188 | 189 | class GetGraph(NamedTuple): 190 | nodes: Float[Array, "n_nodes n_node_feat"] # node features 191 | node_type: Int[Array, "n_nodes"] # by default, 0 is agent 192 | edge_blocks: list[EdgeBlock] 193 | env_states: Any 194 | states: Float[Array, "n_nodes n_state"] # node state features 195 | connectivity: Int[Array, "n_node n_node"] = None # desired connectivity matrix 196 | 197 | @property 198 | def n_nodes(self): 199 | return self.nodes.shape[0] 200 | 201 | @property 202 | def node_dim(self) -> int: 203 | return self.nodes.shape[1] 204 | 205 | @property 206 | def state_dim(self) -> int: 207 | return self.states.shape[1] 208 | 209 | def to_padded(self) -> GraphsTuple: 210 | # make a dummy node for creating fake self edges. 211 | node_feat_dummy = jnp.zeros(self.node_dim) 212 | node_feats_pad = jnp.concatenate([self.nodes, node_feat_dummy[None]], axis=0) 213 | node_type_pad = jnp.concatenate([self.node_type, jnp.full(1, -1)], axis=0) 214 | state_dummy = jnp.ones(self.state_dim) * -1 215 | state_pad = jnp.concatenate([self.states, state_dummy[None]], axis=0) 216 | 217 | # Construct edge list. 
218 | pad_id = self.n_nodes 219 | edge_feats_lst, recv_list, send_list = [], [], [] 220 | for edge_block in self.edge_blocks: 221 | e_edge_feats, e_recvs, e_sends = edge_block.make_edges(pad_id) 222 | edge_feats_lst.append(e_edge_feats) 223 | recv_list.append(e_recvs) 224 | send_list.append(e_sends) 225 | e_edge_feats = jnp.concatenate(edge_feats_lst, axis=0) 226 | e_recv, e_send = jnp.concatenate(recv_list), jnp.concatenate(send_list) 227 | 228 | n_nodes, n_edges = self.n_nodes + 1, e_edge_feats.shape[0] 229 | assert e_recv.shape == e_send.shape == (n_edges,) 230 | n_nodes = jnp.array(n_nodes, dtype=jnp.int32) 231 | n_edges = jnp.array(n_edges, dtype=jnp.int32) 232 | 233 | return GraphsTuple( 234 | n_nodes, 235 | n_edges, 236 | node_feats_pad, 237 | e_edge_feats, 238 | state_pad, 239 | e_recv, 240 | e_send, 241 | node_type_pad, 242 | self.env_states, 243 | self.connectivity, 244 | ) 245 | -------------------------------------------------------------------------------- /gcbfplus/utils/typing.py: -------------------------------------------------------------------------------- 1 | from flax import core, struct 2 | from jaxtyping import Array, Bool, Float, Int, Shaped 3 | from typing import Dict, TypeVar, Any, List 4 | from numpy import ndarray 5 | 6 | 7 | # jax types 8 | PRNGKey = Float[Array, '2'] 9 | 10 | BoolScalar = Bool[Array, ""] 11 | ABool = Bool[Array, "num_agents"] 12 | 13 | # environment types 14 | Action = Float[Array, 'num_agents action_dim'] 15 | Reward = Float[Array, ''] 16 | Cost = Float[Array, ''] 17 | Done = BoolScalar 18 | Info = Dict[str, Shaped[Array, '']] 19 | EdgeIndex = Float[Array, '2 n_edge'] 20 | AgentState = Float[Array, 'num_agents agent_state_dim'] 21 | State = Float[Array, 'num_states state_dim'] 22 | Node = Float[Array, 'num_nodes node_dim'] 23 | EdgeAttr = Float[Array, 'num_edges edge_dim'] 24 | Pos2d = Float[Array, '2'] | Float[ndarray, '2'] 25 | Pos3d = Float[Array, '3'] | Float[ndarray, '3'] 26 | Pos = Pos2d | Pos3d 27 | Radius = Float[Array, ''] | float 28 | 29 | 30 | # neural network types 31 | Params = TypeVar("Params", bound=core.FrozenDict[str, Any]) 32 | 33 | # obstacles 34 | ObsType = Int[Array, ''] 35 | ObsWidth = Float[Array, ''] 36 | ObsHeight = Float[Array, ''] 37 | ObsLength = Float[Array, ''] 38 | ObsTheta = Float[Array, ''] 39 | ObsQuaternion = Float[Array, '4'] 40 | -------------------------------------------------------------------------------- /gcbfplus/utils/utils.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import jax.lax as lax 3 | import einops as ei 4 | import jax 5 | import jax.numpy as jnp 6 | import jax.tree_util as jtu 7 | import matplotlib.collections as mcollections 8 | import numpy as np 9 | import functools as ft 10 | 11 | from datetime import timedelta 12 | from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, Tuple, List, NamedTuple 13 | 14 | from jax import numpy as jnp, tree_util as jtu 15 | from jax._src.lib import xla_client as xc 16 | from matplotlib.animation import FuncAnimation 17 | from rich.progress import Progress, ProgressColumn 18 | from rich.text import Text 19 | from .typing import Array 20 | 21 | 22 | def merge01(x): 23 | return ei.rearrange(x, "n1 n2 ... 
-> (n1 n2) ...") 24 | 25 | 26 | _P = ParamSpec("_P") 27 | _R = TypeVar("_R") 28 | _Fn = Callable[_P, _R] 29 | 30 | _PyTree = TypeVar("_PyTree") 31 | 32 | 33 | def jax_vmap(fn: _Fn, in_axes: int | Sequence[Any] = 0, out_axes: Any = 0) -> _Fn: 34 | return jax.vmap(fn, in_axes, out_axes) 35 | 36 | 37 | def concat_at_front(arr1: jnp.ndarray, arr2: jnp.ndarray, axis: int) -> jnp.ndarray: 38 | """ 39 | :param arr1: (nx, ) 40 | :param arr2: (T, nx) 41 | :param axis: Which axis for arr2 to concat under. 42 | :return: (T + 1, nx) with [arr1 arr2] 43 | """ 44 | # The shapes of arr1 and arr2 should be the same without the dim at axis for arr1. 45 | arr2_shape = list(arr2.shape) 46 | del arr2_shape[axis] 47 | assert np.all(np.array(arr2_shape) == np.array(arr1.shape)) 48 | 49 | if isinstance(arr1, np.ndarray): 50 | return np.concatenate([np.expand_dims(arr1, axis=axis), arr2], axis=axis) 51 | else: 52 | return jnp.concatenate([jnp.expand_dims(arr1, axis=axis), arr2], axis=axis) 53 | 54 | 55 | def tree_concat_at_front(tree1: _PyTree, tree2: _PyTree, axis: int) -> _PyTree: 56 | def tree_concat_at_front_inner(arr1: jnp.ndarray, arr2: jnp.ndarray): 57 | return concat_at_front(arr1, arr2, axis=axis) 58 | 59 | return jtu.tree_map(tree_concat_at_front_inner, tree1, tree2) 60 | 61 | 62 | def tree_index(tree: _PyTree, idx: int) -> _PyTree: 63 | return jtu.tree_map(lambda x: x[idx], tree) 64 | 65 | 66 | def jax2np(pytree: _PyTree) -> _PyTree: 67 | return jtu.tree_map(np.array, pytree) 68 | 69 | 70 | def np2jax(pytree: _PyTree) -> _PyTree: 71 | return jtu.tree_map(jnp.array, pytree) 72 | 73 | 74 | def mask2index(mask: jnp.ndarray, n_true: int) -> jnp.ndarray: 75 | idx = lax.top_k(mask, n_true)[1] 76 | return idx 77 | 78 | 79 | def jax_jit_np( 80 | fn: _Fn, 81 | static_argnums: int | Sequence[int] | None = None, 82 | static_argnames: str | Iterable[str] | None = None, 83 | donate_argnums: int | Sequence[int] = (), 84 | device: xc.Device = None, 85 | *args, 86 | **kwargs, 87 | ) -> _Fn: 88 | jit_fn = jax.jit(fn, static_argnums, static_argnames, donate_argnums, device, *args, **kwargs) 89 | 90 | def wrapper(*args, **kwargs) -> _R: 91 | return jax2np(jit_fn(*args, **kwargs)) 92 | 93 | return wrapper 94 | 95 | 96 | def chunk_vmap(fn: _Fn, chunks: int) -> _Fn: 97 | fn_jit_vmap = jax_jit_np(jax.vmap(fn)) 98 | 99 | def wrapper(*args) -> _R: 100 | args = list(args) 101 | # 1: Get the batch size. 102 | batch_size = len(jtu.tree_leaves(args[0])[0]) 103 | chunk_idxs = np.array_split(np.arange(batch_size), chunks) 104 | 105 | out = [] 106 | for idxs in chunk_idxs: 107 | chunk_input = jtu.tree_map(lambda x: x[idxs], args) 108 | out.append(fn_jit_vmap(*chunk_input)) 109 | 110 | # 2: Concatenate the output. 
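# tree_merge (defined below) concatenates each leaf of the per-chunk outputs along axis 0, restoring the full batch dimension.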
111 | out = tree_merge(out) 112 | return out 113 | 114 | return wrapper 115 | 116 | class MutablePatchCollection(mcollections.PatchCollection): 117 | def __init__(self, patches, *args, **kwargs): 118 | self._paths = None 119 | self.patches = patches 120 | mcollections.PatchCollection.__init__(self, patches, *args, **kwargs) 121 | 122 | def get_paths(self): 123 | self.set_paths(self.patches) 124 | return self._paths 125 | 126 | 127 | class CustomTimeElapsedColumn(ProgressColumn): 128 | """Renders time elapsed.""" 129 | 130 | def render(self, task: "Task") -> Text: 131 | """Show time elapsed.""" 132 | elapsed = task.finished_time if task.finished else task.elapsed 133 | if elapsed is None: 134 | return Text("-:--:--", style="progress.elapsed") 135 | delta = timedelta(seconds=elapsed) 136 | delta = timedelta(seconds=delta.seconds, milliseconds=round(delta.microseconds // 1000)) 137 | delta_str = str(delta) 138 | return Text(delta_str, style="progress.elapsed") 139 | 140 | 141 | def save_anim(ani: FuncAnimation, path: pathlib.Path): 142 | pbar = Progress(*Progress.get_default_columns(), CustomTimeElapsedColumn()) 143 | pbar.start() 144 | task = pbar.add_task("Animating", total=ani._save_count) 145 | 146 | def progress_callback(curr_frame: int, total_frames: int): 147 | pbar.update(task, advance=1) 148 | 149 | ani.save(path, progress_callback=progress_callback) 150 | pbar.stop() 151 | 152 | 153 | def tree_merge(data: List[NamedTuple]): 154 | def body(*x): 155 | x = list(x) 156 | if isinstance(x[0], np.ndarray): 157 | return np.concatenate(x, axis=0) 158 | else: 159 | return jnp.concatenate(x, axis=0) 160 | out = jtu.tree_map(body, *data) 161 | return out 162 | 163 | 164 | def tree_stack(trees: list): 165 | def tree_stack_inner(*arrs): 166 | arrs = list(arrs) 167 | if isinstance(arrs[0], np.ndarray): 168 | return np.stack(arrs, axis=0) 169 | return jnp.stack(arrs, axis=0) 170 | 171 | return jtu.tree_map(tree_stack_inner, *trees) 172 | -------------------------------------------------------------------------------- /media/DoubleIntegrator_512_2x.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/media/DoubleIntegrator_512_2x.gif -------------------------------------------------------------------------------- /media/Obstacle2D_32.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/media/Obstacle2D_32.gif -------------------------------------------------------------------------------- /media/Obstacle2D_512_2x.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/media/Obstacle2D_512_2x.gif -------------------------------------------------------------------------------- /media/cbf1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/media/cbf1.gif -------------------------------------------------------------------------------- /pretrained/CrazyFlie/gcbf+/config.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:argparse.Namespace 2 | algo: gcbf+ 3 | alpha: 1.0 4 | area_size: 2.0 5 | batch_size: 256 6 | buffer_size: 512 7 | cost_weight: 1.0 8 |
debug: false 9 | env: CrazyFlie 10 | eval_epi: 1 11 | eval_interval: 1 12 | gnn_layers: 1 13 | horizon: 32 14 | log_dir: ./logs 15 | loss_action_coef: 3.0e-05 16 | loss_h_dot_coef: 0.01 17 | loss_safe_coef: 1.0 18 | loss_unsafe_coef: 1.0 19 | lr_actor: 1.0e-05 20 | lr_cbf: 0.0001 21 | lr_critic: 0.0001 22 | n_env_test: 32 23 | n_env_train: 16 24 | n_rays: 32 25 | name: null 26 | num_agents: 8 27 | obs: 0 28 | save_interval: 10 29 | seed: 0 30 | steps: 1000 31 | alpha: 1.0 32 | batch_size: 256 33 | eps: 0.02 34 | gnn_layers: 1 35 | horizon: 32 36 | inner_epoch: 8 37 | loss_action_coef: 3.0e-05 38 | loss_h_dot_coef: 0.01 39 | loss_safe_coef: 1.0 40 | loss_unsafe_coef: 1.0 41 | lr_actor: 1.0e-05 42 | lr_cbf: 0.0001 43 | max_grad_norm: 2.0 44 | seed: 0 45 | -------------------------------------------------------------------------------- /pretrained/CrazyFlie/gcbf+/models/1000/actor.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/CrazyFlie/gcbf+/models/1000/actor.pkl -------------------------------------------------------------------------------- /pretrained/CrazyFlie/gcbf+/models/1000/cbf.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/CrazyFlie/gcbf+/models/1000/cbf.pkl -------------------------------------------------------------------------------- /pretrained/DoubleIntegrator/gcbf+/config.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:argparse.Namespace 2 | algo: gcbf+ 3 | batch_size: 256 4 | buffer_size: 512 5 | debug: false 6 | env: DoubleIntegrator 7 | eval_epi: 1 8 | eval_interval: 1 9 | gnn_layers: 1 10 | log_dir: ./logs 11 | loss_action_coef: 0.0001 12 | loss_h_dot_coef: 0.01 13 | loss_safe_coef: 1.0 14 | loss_unsafe_coef: 1.0 15 | lr_actor: 1.0e-05 16 | lr_cbf: 1.0e-05 17 | n_env_test: 32 18 | n_env_train: 16 19 | name: null 20 | num_agents: 8 21 | save_interval: 10 22 | seed: 2 23 | steps: 1000 24 | alpha: 1.0 25 | batch_size: 256 26 | eps: 0.02 27 | gnn_layers: 1 28 | horizon: 32 29 | inner_epoch: 8 30 | loss_action_coef: 0.0001 31 | loss_h_dot_coef: 0.01 32 | loss_safe_coef: 1.0 33 | loss_unsafe_coef: 1.0 34 | lr_actor: 1.0e-05 35 | lr_cbf: 1.0e-05 36 | max_grad_norm: 2.0 37 | seed: 2 38 | -------------------------------------------------------------------------------- /pretrained/DoubleIntegrator/gcbf+/models/1000/actor.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/DoubleIntegrator/gcbf+/models/1000/actor.pkl -------------------------------------------------------------------------------- /pretrained/DoubleIntegrator/gcbf+/models/1000/cbf.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/DoubleIntegrator/gcbf+/models/1000/cbf.pkl -------------------------------------------------------------------------------- /pretrained/DubinsCar/gcbf+/config.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:argparse.Namespace 2 | algo: gcbf+ 3 | alpha: 1.0 4 | area_size: 4.0 5 | batch_size: 256 6 | buffer_size: 512 7 | 
cost_weight: 1.0 8 | debug: false 9 | env: DubinsCar 10 | eval_epi: 1 11 | eval_interval: 1 12 | gnn_layers: 1 13 | horizon: 32 14 | log_dir: ./logs 15 | loss_action_coef: 1.0e-05 16 | loss_h_dot_coef: 0.01 17 | loss_safe_coef: 1.0 18 | loss_unsafe_coef: 1.0 19 | lr_actor: 2.0e-05 20 | lr_cbf: 2.0e-05 21 | lr_critic: 0.0001 22 | n_env_test: 32 23 | n_env_train: 12 24 | n_rays: 32 25 | name: null 26 | num_agents: 8 27 | obs: 4 28 | save_interval: 10 29 | seed: 1 30 | steps: 1000 31 | alpha: 1.0 32 | batch_size: 256 33 | eps: 0.02 34 | gnn_layers: 1 35 | horizon: 32 36 | inner_epoch: 8 37 | loss_action_coef: 1.0e-05 38 | loss_h_dot_coef: 0.01 39 | loss_safe_coef: 1.0 40 | loss_unsafe_coef: 1.0 41 | lr_actor: 2.0e-05 42 | lr_cbf: 2.0e-05 43 | max_grad_norm: 2.0 44 | seed: 1 45 | -------------------------------------------------------------------------------- /pretrained/DubinsCar/gcbf+/models/1000/actor.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/DubinsCar/gcbf+/models/1000/actor.pkl -------------------------------------------------------------------------------- /pretrained/DubinsCar/gcbf+/models/1000/cbf.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/DubinsCar/gcbf+/models/1000/cbf.pkl -------------------------------------------------------------------------------- /pretrained/LinearDrone/gcbf+/config.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:argparse.Namespace 2 | algo: gcbf+ 3 | alpha: 1.0 4 | area_size: 2.0 5 | batch_size: 256 6 | buffer_size: 512 7 | cost_weight: 1.0 8 | debug: false 9 | env: LinearDrone 10 | eval_epi: 1 11 | eval_interval: 1 12 | gnn_layers: 1 13 | horizon: 32 14 | log_dir: ./logs 15 | loss_action_coef: 0.0001 16 | loss_h_dot_coef: 0.01 17 | loss_safe_coef: 1.0 18 | loss_unsafe_coef: 1.0 19 | lr_actor: 1.0e-05 20 | lr_cbf: 1.0e-05 21 | lr_critic: 0.0001 22 | n_env_test: 32 23 | n_env_train: 12 24 | n_rays: 32 25 | name: null 26 | num_agents: 8 27 | obs: 8 28 | save_interval: 10 29 | seed: 2 30 | steps: 1000 31 | alpha: 1.0 32 | batch_size: 256 33 | eps: 0.02 34 | gnn_layers: 1 35 | horizon: 32 36 | inner_epoch: 8 37 | loss_action_coef: 0.0001 38 | loss_h_dot_coef: 0.01 39 | loss_safe_coef: 1.0 40 | loss_unsafe_coef: 1.0 41 | lr_actor: 1.0e-05 42 | lr_cbf: 1.0e-05 43 | max_grad_norm: 2.0 44 | seed: 2 45 | -------------------------------------------------------------------------------- /pretrained/LinearDrone/gcbf+/models/1000/actor.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/LinearDrone/gcbf+/models/1000/actor.pkl -------------------------------------------------------------------------------- /pretrained/LinearDrone/gcbf+/models/1000/cbf.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/LinearDrone/gcbf+/models/1000/cbf.pkl -------------------------------------------------------------------------------- /pretrained/SingleIntegrator/gcbf+/config.yaml: -------------------------------------------------------------------------------- 1 | 
!!python/object:argparse.Namespace 2 | algo: gcbf+ 3 | alpha: 1.0 4 | batch_size: 256 5 | buffer_size: 512 6 | debug: false 7 | env: SingleIntegrator 8 | eval_epi: 1 9 | eval_interval: 1 10 | gnn_layers: 1 11 | horizon: 1 12 | log_dir: ./logs 13 | loss_action_coef: 0.0001 14 | loss_h_dot_coef: 0.01 15 | loss_safe_coef: 1.0 16 | loss_unsafe_coef: 1.0 17 | lr_actor: 1.0e-05 18 | lr_cbf: 1.0e-05 19 | n_env_test: 32 20 | n_env_train: 16 21 | name: null 22 | num_agents: 8 23 | obs: null 24 | save_interval: 10 25 | seed: 0 26 | steps: 1000 27 | alpha: 1.0 28 | batch_size: 256 29 | eps: 0.02 30 | gnn_layers: 1 31 | horizon: 1 32 | inner_epoch: 8 33 | loss_action_coef: 0.0001 34 | loss_h_dot_coef: 0.01 35 | loss_safe_coef: 1.0 36 | loss_unsafe_coef: 1.0 37 | lr_actor: 1.0e-05 38 | lr_cbf: 1.0e-05 39 | max_grad_norm: 2.0 40 | seed: 0 41 | -------------------------------------------------------------------------------- /pretrained/SingleIntegrator/gcbf+/models/1000/actor.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/SingleIntegrator/gcbf+/models/1000/actor.pkl -------------------------------------------------------------------------------- /pretrained/SingleIntegrator/gcbf+/models/1000/cbf.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/SingleIntegrator/gcbf+/models/1000/cbf.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flax>=0.7.2 2 | jax>=0.4.14 3 | jraph>=0.0.6.dev0 4 | jaxtyping>=0.2.21 5 | numpy>=1.25.2 6 | einops>=0.6.1 7 | matplotlib>=3.7.2 8 | opencv-python>=4.8.0.76 9 | tqdm>=4.66.1 10 | tensorflow-probability>=0.21.0 11 | optax>=0.1.7 12 | scipy>=1.11.2 13 | wandb>=0.15.8 14 | pyyaml>=6.0.1 15 | orbax_checkpoint>=0.3.5 16 | seaborn>=0.12.2 17 | equinox>=0.11.0 18 | loguru>=0.7.2 19 | attrs>=23.1.0 20 | rich>=13.5.3 21 | ipdb>=0.13.13 22 | colour>=0.1.5 23 | control>=0.9.4 24 | jaxproxqp @ git+https://github.com/oswinso/jaxproxqp.git -------------------------------------------------------------------------------- /settings.yaml: -------------------------------------------------------------------------------- 1 | DubinsCar: 2 | --env: DubinsCar 3 | --lr-actor: 3e-5 4 | --lr-cbf: 3e-5 5 | --loss-action-coef: 1e-5 6 | --n-env-train: 16 7 | --horizon: 32 8 | --area-size: 4 9 | DoubleIntegrator: 10 | --env: DoubleIntegrator 11 | --lr-actor: 1e-5 12 | --lr-cbf: 1e-5 13 | --loss-action-coef: 1e-4 14 | --n-env-train: 16 15 | --horizon: 32 16 | --area-size: 4 17 | SingleIntegrator: 18 | --env: SingleIntegrator 19 | --lr-actor: 1e-5 20 | --lr-cbf: 1e-5 21 | --loss-action-coef: 1e-4 22 | --n-env-train: 16 23 | --horizon: 1 24 | --area-size: 4 25 | LinearDrone: 26 | --env: LinearDrone 27 | --lr-actor: 1e-5 28 | --lr-cbf: 1e-5 29 | --loss-action-coef: 1e-3 30 | --n-env-train: 16 31 | --horizon: 32 32 | --area-size: 2 33 | CrazyFlie: 34 | --env: CrazyFlie 35 | --lr-actor: 1e-5 36 | --lr-cbf: 1e-4 37 | --loss-action-coef: 3e-5 38 | --n-env-train: 16 39 | --horizon: 32 40 | --area-size: 2 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 
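# Packaging metadata for the gcbfplus library; runtime dependencies are pinned in requirements.txt rather than listed in install_requires.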
2 | 3 | from setuptools import setup, find_packages 4 | 5 | 6 | setup( 7 | name="gcbfplus", 8 | version="0.0.0", 9 | description='Jax Official Implementation of T-RO Paper: S Zhang, O So, K Garg, C Fan: ' 10 | '"GCBF+: A Neural Graph Control Barrier Function Framework for Distributed Safe Multi-Agent Control"', 11 | author="Songyuan Zhang", 12 | author_email="szhang21@mit.edu", 13 | url="https://github.com/MIT-REALM/gcbfplus", 14 | install_requires=[], 15 | packages=find_packages(), 16 | ) 17 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import functools as ft 4 | import os 5 | import pathlib 6 | import ipdb 7 | import jax 8 | import jax.numpy as jnp 9 | import jax.random as jr 10 | import jax.tree_util as jtu 11 | import numpy as np 12 | import yaml 13 | 14 | from gcbfplus.algo import GCBF, GCBFPlus, make_algo, CentralizedCBF, DecShareCBF 15 | from gcbfplus.env import make_env 16 | from gcbfplus.env.base import RolloutResult 17 | from gcbfplus.trainer.utils import get_bb_cbf 18 | from gcbfplus.utils.graph import GraphsTuple 19 | from gcbfplus.utils.utils import jax_jit_np, tree_index, chunk_vmap, merge01, jax_vmap 20 | 21 | 22 | def test(args): 23 | print(f"> Running test.py {args}") 24 | 25 | stamp_str = datetime.datetime.now().strftime("%m%d-%H%M") 26 | 27 | # set up environment variables and seed 28 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 29 | if args.cpu: 30 | os.environ["JAX_PLATFORM_NAME"] = "cpu" 31 | if args.debug: 32 | jax.config.update("jax_disable_jit", True) 33 | np.random.seed(args.seed) 34 | 35 | # load config 36 | if not args.u_ref and args.path is not None: 37 | with open(os.path.join(args.path, "config.yaml"), "r") as f: 38 | config = yaml.load(f, Loader=yaml.UnsafeLoader) 39 | 40 | # create environments 41 | num_agents = config.num_agents if args.num_agents is None else args.num_agents 42 | env = make_env( 43 | env_id=config.env if args.env is None else args.env, 44 | num_agents=num_agents, 45 | num_obs=args.obs, 46 | area_size=args.area_size, 47 | max_step=args.max_step, 48 | max_travel=args.max_travel, 49 | ) 50 | 51 | if not args.u_ref: 52 | if args.path is not None: 53 | path = args.path 54 | model_path = os.path.join(path, "models") 55 | if args.step is None: 56 | models = os.listdir(model_path) 57 | step = max([int(model) for model in models if model.isdigit()]) 58 | else: 59 | step = args.step 60 | print("step: ", step) 61 | 62 | algo = make_algo( 63 | algo=config.algo, 64 | env=env, 65 | node_dim=env.node_dim, 66 | edge_dim=env.edge_dim, 67 | state_dim=env.state_dim, 68 | action_dim=env.action_dim, 69 | n_agents=env.num_agents, 70 | gnn_layers=config.gnn_layers, 71 | batch_size=config.batch_size, 72 | buffer_size=config.buffer_size, 73 | horizon=config.horizon, 74 | lr_actor=config.lr_actor, 75 | lr_cbf=config.lr_cbf, 76 | alpha=config.alpha, 77 | eps=0.02, 78 | inner_epoch=8, 79 | loss_action_coef=config.loss_action_coef, 80 | loss_unsafe_coef=config.loss_unsafe_coef, 81 | loss_safe_coef=config.loss_safe_coef, 82 | loss_h_dot_coef=config.loss_h_dot_coef, 83 | max_grad_norm=2.0, 84 | seed=config.seed 85 | ) 86 | algo.load(model_path, step) 87 | act_fn = jax.jit(algo.act) 88 | else: 89 | algo = make_algo( 90 | algo=args.algo, 91 | env=env, 92 | node_dim=env.node_dim, 93 | edge_dim=env.edge_dim, 94 | state_dim=env.state_dim, 95 | action_dim=env.action_dim, 96 |
n_agents=env.num_agents, 97 | alpha=args.alpha, 98 | ) 99 | act_fn = jax.jit(algo.act) 100 | path = os.path.join(f"./logs/{args.env}/{args.algo}") 101 | if not os.path.exists(path): 102 | os.makedirs(path) 103 | step = None 104 | else: 105 | assert args.env is not None 106 | path = os.path.join(f"./logs/{args.env}/nominal") 107 | if not os.path.exists("./logs"): 108 | os.mkdir("./logs") 109 | if not os.path.exists(os.path.join("./logs", args.env)): 110 | os.mkdir(os.path.join("./logs", args.env)) 111 | if not os.path.exists(path): 112 | os.mkdir(path) 113 | algo = None 114 | act_fn = jax.jit(env.u_ref) 115 | step = 0 116 | 117 | test_key = jr.PRNGKey(args.seed) 118 | test_keys = jr.split(test_key, 1_000)[: args.epi] 119 | test_keys = test_keys[args.offset:] 120 | 121 | algo_is_cbf = isinstance(algo, (CentralizedCBF, DecShareCBF)) 122 | 123 | if args.cbf is not None: 124 | assert isinstance(algo, GCBF) or isinstance(algo, GCBFPlus) or isinstance(algo, CentralizedCBF) 125 | get_bb_cbf_fn_ = ft.partial(get_bb_cbf, algo.get_cbf, env, agent_id=args.cbf, x_dim=0, y_dim=1) 126 | get_bb_cbf_fn_ = jax_jit_np(get_bb_cbf_fn_) 127 | 128 | def get_bb_cbf_fn(T_graph: GraphsTuple): 129 | T = len(T_graph.states) 130 | outs = [get_bb_cbf_fn_(tree_index(T_graph, kk)) for kk in range(T)] 131 | Tb_x, Tb_y, Tbb_h = jtu.tree_map(lambda *x: jnp.stack(list(x), axis=0), *outs) 132 | return Tb_x, Tb_y, Tbb_h 133 | else: 134 | get_bb_cbf_fn = None 135 | cbf_fn = None 136 | 137 | if args.nojit_rollout: 138 | print("Only jit step, no jit rollout!") 139 | rollout_fn = env.rollout_fn_jitstep(act_fn, args.max_step, noedge=True, nograph=args.no_video) 140 | 141 | is_unsafe_fn = None 142 | is_finish_fn = None 143 | else: 144 | print("jit rollout!") 145 | rollout_fn = jax_jit_np(env.rollout_fn(act_fn, args.max_step)) 146 | 147 | is_unsafe_fn = jax_jit_np(jax_vmap(env.collision_mask)) 148 | is_finish_fn = jax_jit_np(jax_vmap(env.finish_mask)) 149 | 150 | rewards = [] 151 | costs = [] 152 | rollouts = [] 153 | is_unsafes = [] 154 | is_finishes = [] 155 | rates = [] 156 | cbfs = [] 157 | for i_epi in range(args.epi): 158 | key_x0, _ = jr.split(test_keys[i_epi], 2) 159 | 160 | if args.nojit_rollout: 161 | rollout: RolloutResult 162 | rollout, is_unsafe, is_finish = rollout_fn(key_x0) 163 | # if not jnp.isnan(rollout.T_reward).any(): 164 | is_unsafes.append(is_unsafe) 165 | is_finishes.append(is_finish) 166 | else: 167 | rollout: RolloutResult = rollout_fn(key_x0) 168 | # if not jnp.isnan(rollout.T_reward).any(): 169 | is_unsafes.append(is_unsafe_fn(rollout.Tp1_graph)) 170 | is_finishes.append(is_finish_fn(rollout.Tp1_graph)) 171 | 172 | epi_reward = rollout.T_reward.sum() 173 | epi_cost = rollout.T_cost.sum() 174 | rewards.append(epi_reward) 175 | costs.append(epi_cost) 176 | rollouts.append(rollout) 177 | 178 | if args.cbf is not None: 179 | cbfs.append(get_bb_cbf_fn(rollout.Tp1_graph)) 180 | else: 181 | cbfs.append(None) 182 | if len(is_unsafes) == 0: 183 | continue 184 | safe_rate = 1 - is_unsafes[-1].max(axis=0).mean() 185 | finish_rate = is_finishes[-1].max(axis=0).mean() 186 | success_rate = ((1 - is_unsafes[-1].max(axis=0)) * is_finishes[-1].max(axis=0)).mean() 187 | print(f"epi: {i_epi}, reward: {epi_reward:.3f}, cost: {epi_cost:.3f}, " 188 | f"safe rate: {safe_rate * 100:.3f}%," 189 | f"finish rate: {finish_rate * 100:.3f}%, " 190 | f"success rate: {success_rate * 100:.3f}%") 191 | 192 | rates.append(np.array([safe_rate, finish_rate, success_rate])) 193 | is_unsafe = np.max(np.stack(is_unsafes), axis=1) 194 | is_finish 
= np.max(np.stack(is_finishes), axis=1) 195 | 196 | safe_mean, safe_std = (1 - is_unsafe).mean(), (1 - is_unsafe).std() 197 | finish_mean, finish_std = is_finish.mean(), is_finish.std() 198 | success_mean, success_std = ((1 - is_unsafe) * is_finish).mean(), ((1 - is_unsafe) * is_finish).std() 199 | 200 | print( 201 | f"reward: {np.mean(rewards):.3f}, min/max reward: {np.min(rewards):.3f}/{np.max(rewards):.3f}, " 202 | f"cost: {np.mean(costs):.3f}, min/max cost: {np.min(costs):.3f}/{np.max(costs):.3f}, " 203 | f"safe_rate: {safe_mean * 100:.3f}%, " 204 | f"finish_rate: {finish_mean * 100:.3f}%, " 205 | f"success_rate: {success_mean * 100:.3f}%" 206 | ) 207 | 208 | # save results 209 | if args.log: 210 | with open(os.path.join(path, "test_log.csv"), "a") as f: 211 | f.write(f"{env.num_agents},{args.epi},{env.max_episode_steps}," 212 | f"{env.area_size},{env.params['n_obs']}," 213 | f"{safe_mean * 100:.3f},{safe_std * 100:.3f}," 214 | f"{finish_mean * 100:.3f},{finish_std * 100:.3f}," 215 | f"{success_mean * 100:.3f},{success_std * 100:.3f}\n") 216 | 217 | # make video 218 | if args.no_video: 219 | return 220 | 221 | videos_dir = pathlib.Path(path) / "videos" 222 | videos_dir.mkdir(exist_ok=True, parents=True) 223 | for ii, (rollout, Ta_is_unsafe, cbf) in enumerate(zip(rollouts, is_unsafes, cbfs)): 224 | if algo_is_cbf: 225 | safe_rate, finish_rate, success_rate = rates[ii] * 100 226 | video_name = f"n{num_agents}_epi{ii:02}_sr{safe_rate:.0f}_fr{finish_rate:.0f}_scr{success_rate:.0f}" 227 | else: 228 | video_name = f"n{num_agents}_step{step}_epi{ii:02}_reward{rewards[ii]:.3f}_cost{costs[ii]:.3f}" 229 | 230 | viz_opts = {} 231 | if args.cbf is not None: 232 | video_name += f"_cbf{args.cbf}" 233 | viz_opts["cbf"] = [*cbf, args.cbf] 234 | 235 | video_path = videos_dir / f"{stamp_str}_{video_name}.mp4" 236 | env.render_video(rollout, video_path, Ta_is_unsafe, viz_opts, dpi=args.dpi) 237 | 238 | 239 | def main(): 240 | parser = argparse.ArgumentParser() 241 | parser.add_argument("-n", "--num-agents", type=int, default=None) 242 | parser.add_argument("--obs", type=int, default=0) 243 | parser.add_argument("--area-size", type=float, required=True) 244 | parser.add_argument("--max-step", type=int, default=None) 245 | parser.add_argument("--path", type=str, default=None) 246 | parser.add_argument("--n-rays", type=int, default=32) 247 | parser.add_argument("--alpha", type=float, default=1.0) 248 | parser.add_argument("--max-travel", type=float, default=None) 249 | parser.add_argument("--cbf", type=int, default=None) 250 | 251 | parser.add_argument("--seed", type=int, default=1234) 252 | parser.add_argument("--debug", action="store_true", default=False) 253 | parser.add_argument("--cpu", action="store_true", default=False) 254 | parser.add_argument("--u-ref", action="store_true", default=False) 255 | parser.add_argument("--env", type=str, default=None) 256 | parser.add_argument("--algo", type=str, default=None) 257 | parser.add_argument("--step", type=int, default=None) 258 | parser.add_argument("--epi", type=int, default=5) 259 | parser.add_argument("--offset", type=int, default=0) 260 | parser.add_argument("--no-video", action="store_true", default=False) 261 | parser.add_argument("--nojit-rollout", action="store_true", default=False) 262 | parser.add_argument("--log", action="store_true", default=False) 263 | parser.add_argument("--dpi", type=int, default=100) 264 | 265 | args = parser.parse_args() 266 | test(args) 267 | 268 | 269 | if __name__ == "__main__": 270 | with
ipdb.launch_ipdb_on_exception(): 271 | main() 272 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import ipdb 5 | import numpy as np 6 | import wandb 7 | import yaml 8 | 9 | from gcbfplus.algo import make_algo 10 | from gcbfplus.env import make_env 11 | from gcbfplus.trainer.trainer import Trainer 12 | from gcbfplus.trainer.utils import is_connected 13 | 14 | 15 | def train(args): 16 | print(f"> Running train.py {args}") 17 | 18 | # set up environment variables and seed 19 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 20 | if not is_connected(): 21 | os.environ["WANDB_MODE"] = "offline" 22 | np.random.seed(args.seed) 23 | if args.debug: 24 | os.environ["WANDB_MODE"] = "disabled" 25 | os.environ["JAX_DISABLE_JIT"] = "True" 26 | 27 | # create environments 28 | env = make_env( 29 | env_id=args.env, 30 | num_agents=args.num_agents, 31 | num_obs=args.obs, 32 | n_rays=args.n_rays, 33 | area_size=args.area_size 34 | ) 35 | env_test = make_env( 36 | env_id=args.env, 37 | num_agents=args.num_agents, 38 | num_obs=args.obs, 39 | n_rays=args.n_rays, 40 | area_size=args.area_size 41 | ) 42 | 43 | # create low level controller 44 | algo = make_algo( 45 | algo=args.algo, 46 | env=env, 47 | node_dim=env.node_dim, 48 | edge_dim=env.edge_dim, 49 | state_dim=env.state_dim, 50 | action_dim=env.action_dim, 51 | n_agents=env.num_agents, 52 | gnn_layers=args.gnn_layers, 53 | batch_size=256, 54 | buffer_size=args.buffer_size, 55 | horizon=args.horizon, 56 | lr_actor=args.lr_actor, 57 | lr_cbf=args.lr_cbf, 58 | alpha=args.alpha, 59 | eps=0.02, 60 | inner_epoch=8, 61 | loss_action_coef=args.loss_action_coef, 62 | loss_unsafe_coef=args.loss_unsafe_coef, 63 | loss_safe_coef=args.loss_safe_coef, 64 | loss_h_dot_coef=args.loss_h_dot_coef, 65 | max_grad_norm=2.0, 66 | seed=args.seed, 67 | ) 68 | 69 | # set up logger 70 | start_time = datetime.datetime.now() 71 | start_time = start_time.strftime("%Y%m%d%H%M%S") 72 | if not os.path.exists(args.log_dir): 73 | os.makedirs(args.log_dir) 74 | if not os.path.exists(f"{args.log_dir}/{args.env}"): 75 | os.makedirs(f"{args.log_dir}/{args.env}") 76 | if not os.path.exists(f"{args.log_dir}/{args.env}/{args.algo}"): 77 | os.makedirs(f"{args.log_dir}/{args.env}/{args.algo}") 78 | log_dir = f"{args.log_dir}/{args.env}/{args.algo}/seed{args.seed}_{start_time}" 79 | run_name = f"{args.algo}_{args.env}_{start_time}" if args.name is None else args.name 80 | 81 | # get training parameters 82 | train_params = { 83 | "run_name": run_name, 84 | "training_steps": args.steps, 85 | "eval_interval": args.eval_interval, 86 | "eval_epi": args.eval_epi, 87 | "save_interval": args.save_interval, 88 | } 89 | 90 | # create trainer 91 | trainer = Trainer( 92 | env=env, 93 | env_test=env_test, 94 | algo=algo, 95 | log_dir=log_dir, 96 | n_env_train=args.n_env_train, 97 | n_env_test=args.n_env_test, 98 | seed=args.seed, 99 | params=train_params, 100 | save_log=not args.debug, 101 | ) 102 | 103 | # save config 104 | wandb.config.update(args) 105 | wandb.config.update(algo.config) 106 | if not args.debug: 107 | with open(f"{log_dir}/config.yaml", "w") as f: 108 | yaml.dump(args, f) 109 | yaml.dump(algo.config, f) 110 | 111 | # start training 112 | trainer.train() 113 | 114 | 115 | def main(): 116 | parser = argparse.ArgumentParser() 117 | 118 | # custom arguments 119 | parser.add_argument("-n", "--num-agents", 
type=int, default=8) 120 | parser.add_argument("--algo", type=str, default="gcbf+") 121 | parser.add_argument("--env", type=str, default="DoubleIntegrator") 122 | parser.add_argument("--seed", type=int, default=0) 123 | parser.add_argument("--steps", type=int, default=1000) 124 | parser.add_argument("--name", type=str, default=None) 125 | parser.add_argument("--debug", action="store_true", default=False) 126 | parser.add_argument("--obs", type=int, default=None) 127 | parser.add_argument("--n-rays", type=int, default=32) 128 | parser.add_argument("--area-size", type=float, required=True) 129 | 130 | # gcbf / gcbf+ arguments 131 | parser.add_argument("--gnn-layers", type=int, default=1) 132 | parser.add_argument("--alpha", type=float, default=1.0) 133 | parser.add_argument("--horizon", type=int, default=32) 134 | parser.add_argument("--lr-actor", type=float, default=3e-5) 135 | parser.add_argument("--lr-cbf", type=float, default=3e-5) 136 | parser.add_argument("--loss-action-coef", type=float, default=0.0001) 137 | parser.add_argument("--loss-unsafe-coef", type=float, default=1.0) 138 | parser.add_argument("--loss-safe-coef", type=float, default=1.0) 139 | parser.add_argument("--loss-h-dot-coef", type=float, default=0.01) 140 | parser.add_argument("--buffer-size", type=int, default=512) 141 | 142 | # default arguments 143 | parser.add_argument("--n-env-train", type=int, default=16) 144 | parser.add_argument("--n-env-test", type=int, default=32) 145 | parser.add_argument("--log-dir", type=str, default="./logs") 146 | parser.add_argument("--eval-interval", type=int, default=1) 147 | parser.add_argument("--eval-epi", type=int, default=1) 148 | parser.add_argument("--save-interval", type=int, default=10) 149 | 150 | args = parser.parse_args() 151 | train(args) 152 | 153 | 154 | if __name__ == "__main__": 155 | with ipdb.launch_ipdb_on_exception(): 156 | main() 157 | --------------------------------------------------------------------------------
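For orientation, the sketch below strings the pieces above together: it mirrors the pretrained-model branch of test.py, so the make_env/make_algo keyword arguments and the algo.load/env.rollout_fn calls come from that file, while the checkpoint path, obstacle count, area size, and episode length are illustrative assumptions, not values prescribed by the repository.

# A minimal sketch (assumptions marked below): evaluate the shipped
# DoubleIntegrator GCBF+ checkpoint for one rollout, mirroring test.py.
import os

import jax
import jax.random as jr
import yaml

from gcbfplus.algo import make_algo
from gcbfplus.env import make_env
from gcbfplus.utils.utils import jax_jit_np

path = "./pretrained/DoubleIntegrator/gcbf+"  # assumed checkpoint directory
with open(os.path.join(path, "config.yaml"), "r") as f:
    config = yaml.load(f, Loader=yaml.UnsafeLoader)

env = make_env(
    env_id=config.env,
    num_agents=config.num_agents,
    num_obs=0,        # assumed: no obstacles
    area_size=4.0,    # assumed: matches the DoubleIntegrator entry in settings.yaml
    max_step=256,     # assumed episode length
    max_travel=None,
)

# Hyperparameters are read back from the saved training config, as in test.py.
algo = make_algo(
    algo=config.algo,
    env=env,
    node_dim=env.node_dim,
    edge_dim=env.edge_dim,
    state_dim=env.state_dim,
    action_dim=env.action_dim,
    n_agents=env.num_agents,
    gnn_layers=config.gnn_layers,
    batch_size=config.batch_size,
    buffer_size=config.buffer_size,
    horizon=config.horizon,
    lr_actor=config.lr_actor,
    lr_cbf=config.lr_cbf,
    alpha=config.alpha,
    eps=0.02,
    inner_epoch=8,
    loss_action_coef=config.loss_action_coef,
    loss_unsafe_coef=config.loss_unsafe_coef,
    loss_safe_coef=config.loss_safe_coef,
    loss_h_dot_coef=config.loss_h_dot_coef,
    max_grad_norm=2.0,
    seed=config.seed,
)
algo.load(os.path.join(path, "models"), 1000)  # 1000 = saved training step

# Roll out the jitted policy once and report the episode totals.
act_fn = jax.jit(algo.act)
rollout = jax_jit_np(env.rollout_fn(act_fn, 256))(jr.PRNGKey(0))
print("total reward:", rollout.T_reward.sum(), "total cost:", rollout.T_cost.sum())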