├── .gitignore
├── LICENSE
├── README.md
├── gcbfplus
├── __init__.py
├── algo
│ ├── __init__.py
│ ├── base.py
│ ├── centralized_cbf.py
│ ├── dec_share_cbf.py
│ ├── gcbf.py
│ ├── gcbf_plus.py
│ ├── module
│ │ ├── __init__.py
│ │ ├── cbf.py
│ │ ├── distribution.py
│ │ ├── policy.py
│ │ └── value.py
│ └── utils.py
├── env
│ ├── __init__.py
│ ├── base.py
│ ├── crazyflie.py
│ ├── double_integrator.py
│ ├── dubins_car.py
│ ├── linear_drone.py
│ ├── obstacle.py
│ ├── plot.py
│ ├── single_integrator.py
│ └── utils.py
├── nn
│ ├── __init__.py
│ ├── gnn.py
│ ├── mlp.py
│ └── utils.py
├── trainer
│ ├── __init__.py
│ ├── buffer.py
│ ├── data.py
│ ├── trainer.py
│ └── utils.py
└── utils
│ ├── __init__.py
│ ├── graph.py
│ ├── typing.py
│ └── utils.py
├── media
├── DoubleIntegrator_512_2x.gif
├── Obstacle2D_32.gif
├── Obstacle2D_512_2x.gif
└── cbf1.gif
├── pretrained
├── CrazyFlie
│ └── gcbf+
│ │ ├── config.yaml
│ │ └── models
│ │ └── 1000
│ │ ├── actor.pkl
│ │ └── cbf.pkl
├── DoubleIntegrator
│ └── gcbf+
│ │ ├── config.yaml
│ │ └── models
│ │ └── 1000
│ │ ├── actor.pkl
│ │ └── cbf.pkl
├── DubinsCar
│ └── gcbf+
│ │ ├── config.yaml
│ │ └── models
│ │ └── 1000
│ │ ├── actor.pkl
│ │ └── cbf.pkl
├── LinearDrone
│ └── gcbf+
│ │ ├── config.yaml
│ │ └── models
│ │ └── 1000
│ │ ├── actor.pkl
│ │ └── cbf.pkl
└── SingleIntegrator
│ └── gcbf+
│ ├── config.yaml
│ └── models
│ └── 1000
│ ├── actor.pkl
│ └── cbf.pkl
├── requirements.txt
├── settings.yaml
├── setup.py
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea/
3 | __pycache__/
4 | wandb
5 | logs
6 | figs
7 | videos/
8 | wandb/
9 | test_log.csv
10 | *.egg-info
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Songyuan Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # GCBF+
4 |
5 | [](https://mit-realm.github.io/gcbfplus-website/)
6 |
7 | Jax Official Implementation of T-RO Paper: [Songyuan Zhang*](https://syzhang092218-source.github.io), [Oswin So*](https://oswinso.xyz/), [Kunal Garg](https://kunalgarg.mit.edu/), and [Chuchu Fan](https://chuchu.mit.edu): "[GCBF+: A Neural Graph Control Barrier Function Framework for Distributed Safe Multi-Agent Control](https://mit-realm.github.io/gcbfplus-website/)".
8 |
9 | [Dependencies](#Dependencies) •
10 | [Installation](#Installation) •
11 | [Run](#Run)
12 |
13 |
14 |
15 | A much improved version of [GCBFv0](https://mit-realm.github.io/gcbf-website/)!
16 |
17 |
23 |
24 | ## Dependencies
25 |
26 | We recommend using [Conda](https://www.anaconda.com/) to install the requirements:
27 |
28 | ```bash
29 | conda create -n gcbfplus python=3.10
30 | conda activate gcbfplus
31 | cd gcbfplus
32 | ```
33 |
34 | Then install jax following the [official instructions](https://github.com/google/jax#installation), and then install the rest of the dependencies:
35 | ```bash
36 | pip install -r requirements.txt
37 | ```
38 |
39 | ## Installation
40 |
41 | Install GCBF:
42 |
43 | ```bash
44 | pip install -e .
45 | ```
46 |
47 | ## Run
48 |
49 | ### Environments
50 |
51 | We provide three 2D environments (`SingleIntegrator`, `DoubleIntegrator`, and `DubinsCar`) and two 3D environments (`LinearDrone` and `CrazyFlie`).
52 |
53 | ### Algorithms
54 |
55 | We provide algorithms including GCBF+ (`gcbf+`), GCBF (`gcbf`), centralized CBF-QP (`centralized_cbf`), and decentralized CBF-QP (`dec_share_cbf`). Use `--algo` to specify the algorithm.
56 |
57 | ### Hyper-parameters
58 |
59 | To reproduce the results shown in our paper, one can refer to [`settings.yaml`](./settings.yaml).
60 |
61 | ### Train
62 |
63 | To train the model (only GCBF+ and GCBF need training), use:
64 |
65 | ```bash
66 | python train.py --algo gcbf+ --env DoubleIntegrator -n 8 --area-size 4 --loss-action-coef 1e-4 --n-env-train 16 --lr-actor 1e-5 --lr-cbf 1e-5 --horizon 32
67 | ```
68 |
69 | In our paper, we use 8 agents with 1000 training steps. The training logs will be saved in the folder `./logs/<env>/<algo>/seed<seed>_<training-start-time>`. We also provide the following flags:
70 |
71 | - `-n`: number of agents
72 | - `--env`: environment, including `SingleIntegrator`, `DoubleIntegrator`, `DubinsCar`, `LinearDrone`, and `CrazyFlie`
73 | - `--algo`: algorithm, including `gcbf`, `gcbf+`
74 | - `--seed`: random seed
75 | - `--steps`: number of training steps
76 | - `--name`: name of the experiment
77 | - `--debug`: debug mode: no recording, no saving
78 | - `--obs`: number of obstacles
79 | - `--n-rays`: number of LiDAR rays
80 | - `--area-size`: side length of the environment
81 | - `--n-env-train`: number of environments for training
82 | - `--n-env-test`: number of environments for testing
83 | - `--log-dir`: path to save the training logs
84 | - `--eval-interval`: interval of evaluation
85 | - `--eval-epi`: number of episodes for evaluation
86 | - `--save-interval`: interval of saving the model
87 |
88 | In addition, use the following flags to specify the hyper-parameters:
89 | - `--alpha`: GCBF alpha
90 | - `--horizon`: GCBF+ look forward horizon
91 | - `--lr-actor`: learning rate of the actor
92 | - `--lr-cbf`: learning rate of the CBF
93 | - `--loss-action-coef`: coefficient of the action loss
94 | - `--loss-h-dot-coef`: coefficient of the h_dot loss
95 | - `--loss-safe-coef`: coefficient of the safe loss
96 | - `--loss-unsafe-coef`: coefficient of the unsafe loss
97 | - `--buffer-size`: size of the replay buffer
98 |
99 | ### Test
100 |
101 | To test the learned model, use:
102 |
103 | ```bash
104 | python test.py --path <path-to-log> --epi 5 --area-size 4 -n 16 --obs 0
105 | ```
106 |
107 | This should report the safety rate, goal reaching rate, and success rate of the learned model, and generate videos of the learned model in `<path-to-log>/videos`. Use the following flags to customize the test:
108 |
109 | - `-n`: number of agents
110 | - `--obs`: number of obstacles
111 | - `--area-size`: side length of the environment
112 | - `--max-step`: maximum number of steps for each episode, increase this if you have a large environment
113 | - `--path`: path to the log folder
114 | - `--n-rays`: number of LiDAR rays
115 | - `--alpha`: CBF alpha, used in centralized CBF-QP and decentralized CBF-QP
116 | - `--max-travel`: maximum travel distance of agents
117 | - `--cbf`: plot the CBF contour of this agent, only support 2D environments
118 | - `--seed`: random seed
119 | - `--debug`: debug mode
120 | - `--cpu`: use CPU
121 | - `--u-ref`: test the nominal controller
122 | - `--env`: test environment (not needed if the log folder is specified)
123 | - `--algo`: test algorithm (not needed if the log folder is specified)
124 | - `--step`: test step (not needed if testing the last saved model)
125 | - `--epi`: number of episodes to test
126 | - `--offset`: offset of the random seeds
127 | - `--no-video`: do not generate videos
128 | - `--log`: log the results to a file
129 | - `--dpi`: dpi of the video
130 | - `--nojit-rollout`: do not use jit to speed up the rollout, used for large-scale tests
131 |
132 | To test the nominal controller, use:
133 |
134 | ```bash
135 | python test.py --env SingleIntegrator -n 16 --u-ref --epi 1 --area-size 4 --obs 0
136 | ```
137 |
138 | To test the CBF-QPs, use:
139 |
140 | ```bash
141 | python test.py --env SingleIntegrator -n 16 --algo dec_share_cbf --epi 1 --area-size 4 --obs 0 --alpha 1
142 | ```
143 |
144 | ### Pre-trained models
145 |
146 | We provide pre-trained models in folder [`pretrained`](pretrained). However, their performance may depend on the GPU/CUDA/Jax versions. We highly recommend retraining a model yourself.
147 |
148 | ## Citation
149 |
150 | ```
151 | @ARTICLE{zhang2025gcbf+,
152 | author={Zhang, Songyuan and So, Oswin and Garg, Kunal and Fan, Chuchu},
153 | journal={IEEE Transactions on Robotics},
154 | title={{GCBF}+: A Neural Graph Control Barrier Function Framework for Distributed Safe Multiagent Control},
155 | year={2025},
156 | volume={41},
157 | pages={1533-1552},
158 | doi={10.1109/TRO.2025.3530348}
159 | }
160 | ```
161 |
162 | ## Acknowledgement
163 |
164 | The developers were partially supported by MITRE during the project.
165 |
166 | © 2024 MIT
167 |
168 | © 2024 The MITRE Corporation
169 |
--------------------------------------------------------------------------------
/gcbfplus/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/__init__.py
--------------------------------------------------------------------------------
/gcbfplus/algo/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import MultiAgentController
2 | from .dec_share_cbf import DecShareCBF
3 | from .gcbf import GCBF
4 | from .gcbf_plus import GCBFPlus
5 | from .centralized_cbf import CentralizedCBF
6 |
7 |
def make_algo(algo: str, **kwargs) -> MultiAgentController:
    """Instantiate the controller registered under the name ``algo``.

    Args:
        algo: One of ``'gcbf'``, ``'gcbf+'``, ``'centralized_cbf'`` or ``'dec_share_cbf'``.
        **kwargs: Forwarded unchanged to the selected controller's constructor.

    Returns:
        The newly constructed controller.

    Raises:
        ValueError: If ``algo`` is not one of the known names.
    """
    # Each entry is a thunk so that only the requested controller is ever constructed.
    builders = {
        'gcbf': lambda: GCBF(**kwargs),
        'gcbf+': lambda: GCBFPlus(**kwargs),
        'centralized_cbf': lambda: CentralizedCBF(**kwargs),
        'dec_share_cbf': lambda: DecShareCBF(**kwargs),
    }
    if algo not in builders:
        raise ValueError(f'Unknown algorithm: {algo}')
    return builders[algo]()
19 |
--------------------------------------------------------------------------------
/gcbfplus/algo/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod, abstractproperty
2 | from typing import Optional, Tuple
3 |
4 | from gcbfplus.utils.typing import Action, Params, PRNGKey, Array
5 | from gcbfplus.utils.graph import GraphsTuple
6 | from gcbfplus.trainer.data import Rollout
7 | from gcbfplus.env.base import MultiAgentEnv
8 |
9 |
class MultiAgentController(ABC):
    """Abstract interface implemented by every controller in ``gcbfplus.algo``.

    The base class only stores the environment handle and the problem dimensions
    given at construction time and exposes the dimensions as read-only properties.
    Everything else (acting, stepping, learning, checkpointing) is declared
    abstract and must be supplied by the concrete controller.
    """

    def __init__(
            self,
            env: MultiAgentEnv,
            node_dim: int,
            edge_dim: int,
            action_dim: int,
            n_agents: int
    ):
        """Store the environment and the graph/action dimensions.

        Args:
            env: Environment the controller acts in.
            node_dim: Node feature dimension.
            edge_dim: Edge feature dimension.
            action_dim: Per-agent action dimension.
            n_agents: Number of agents.
        """
        self._env = env
        self._node_dim = node_dim
        self._edge_dim = edge_dim
        self._action_dim = action_dim
        self._n_agents = n_agents

    @property
    def node_dim(self) -> int:
        """Node feature dimension given at construction."""
        return self._node_dim

    @property
    def edge_dim(self) -> int:
        """Edge feature dimension given at construction."""
        return self._edge_dim

    @property
    def action_dim(self) -> int:
        """Per-agent action dimension given at construction."""
        return self._action_dim

    @property
    def n_agents(self) -> int:
        """Number of agents given at construction."""
        return self._n_agents

    # NOTE: ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on ``@abstractmethod`` is the supported equivalent.
    @property
    @abstractmethod
    def config(self) -> dict:
        """Hyper-parameters of the controller as a plain dictionary."""

    @property
    @abstractmethod
    def actor_params(self) -> Params:
        """Current parameters of the policy (actor)."""

    @abstractmethod
    def act(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action:
        """Return the action for ``graph``, optionally using explicit ``params``."""

    @abstractmethod
    def step(self, graph: GraphsTuple, key: PRNGKey, params: Optional[Params] = None) -> Tuple[Action, Array]:
        """Return an action for ``graph`` together with a second array.

        ``key`` is a PRNG key for implementations that sample their action.
        """

    @abstractmethod
    def update(self, rollout: Rollout, step: int) -> dict:
        """Update the controller from ``rollout`` and return a dictionary of update info."""

    @abstractmethod
    def save(self, save_dir: str, step: int):
        """Persist the controller state for training step ``step`` under ``save_dir``."""

    @abstractmethod
    def load(self, load_dir: str, step: int):
        """Restore the controller state of training step ``step`` from ``load_dir``."""
69 |
--------------------------------------------------------------------------------
/gcbfplus/algo/centralized_cbf.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import jax
3 | import einops as ei
4 |
5 | from typing import Optional, Tuple
6 | from jaxproxqp.jaxproxqp import JaxProxQP
7 |
8 | from gcbfplus.utils.typing import Action, Params, PRNGKey, Array, State
9 | from gcbfplus.utils.graph import GraphsTuple
10 | from gcbfplus.utils.utils import mask2index
11 | from gcbfplus.trainer.data import Rollout
12 | from gcbfplus.env.base import MultiAgentEnv
13 | from .utils import get_pwise_cbf_fn
14 | from .base import MultiAgentController
15 |
16 |
class CentralizedCBF(MultiAgentController):
    """CBF-QP baseline that solves ONE quadratic program over the joint action of all agents.

    The barrier values come from ``get_pwise_cbf_fn(env, k)`` with ``k = 3``; they are
    asserted to have shape ``(n_agents, k)``. This controller has no learnable
    parameters, so the training/checkpoint hooks raise ``NotImplementedError``.
    """

    def __init__(
            self,
            env: MultiAgentEnv,
            node_dim: int,
            edge_dim: int,
            state_dim: int,
            action_dim: int,
            n_agents: int,
            alpha: float = 1.0,
            **kwargs
    ):
        """Set up the hand-crafted pairwise CBF.

        Args:
            env: Environment providing dynamics, action limits and the nominal controller.
            node_dim: Node feature dimension.
            edge_dim: Edge feature dimension.
            state_dim: Accepted for a uniform constructor signature; not used here.
            action_dim: Per-agent action dimension.
            n_agents: Number of agents.
            alpha: Coefficient multiplying ``h`` in the CBF condition.
            **kwargs: Ignored; accepted so all controllers can share one call site.
        """
        super().__init__(
            env=env,
            node_dim=node_dim,
            edge_dim=edge_dim,
            action_dim=action_dim,
            n_agents=n_agents
        )

        self.alpha = alpha
        self.k = 3  # number of barrier values per agent
        self.cbf = get_pwise_cbf_fn(env, self.k)

    @property
    def config(self) -> dict:
        """Hyper-parameters of this controller (only ``alpha``)."""
        return {
            'alpha': self.alpha,
        }

    @property
    def actor_params(self) -> Params:
        """Not available: this controller has no actor network."""
        raise NotImplementedError

    def step(self, graph: GraphsTuple, key: PRNGKey, params: Optional[Params] = None) -> Tuple[Action, Array]:
        """Not available: this controller is not used to collect training rollouts."""
        raise NotImplementedError

    def get_cbf(self, graph: GraphsTuple) -> Array:
        """Return the first output of the pairwise CBF function for ``graph``."""
        return self.cbf(graph)[0]

    def update(self, rollout: Rollout, step: int) -> dict:
        """Not available: nothing to train."""
        raise NotImplementedError

    def act(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action:
        """Return the QP-filtered action for every agent; ``params`` is ignored."""
        return self.get_qp_action(graph)[0]

    def get_qp_action(self, graph: GraphsTuple, relax_penalty: float = 1e3) -> Tuple[Action, Array]:
        """Solve the joint CBF-QP for a single (unbatched) graph.

        The decision vector is ``[u (n_agents * nu), r (n_agents * k)]`` where ``r`` are
        non-negative slack variables.

        Args:
            graph: A single graph; batched graphs are rejected by the assertion below.
            relax_penalty: Linear cost on each slack variable.

        Returns:
            ``(u_opt, r)``: actions reshaped to ``(n_agents, -1)`` and the flat slack vector.
        """
        assert graph.is_single  # consider single graph
        agent_node_mask = graph.node_type == 0
        agent_node_id = mask2index(agent_node_mask, self.n_agents)
        n_slack = self.n_agents * self.k  # one slack per (agent, barrier) pair

        def h_aug(new_agent_state: State) -> Array:
            # Re-evaluate the CBF with the agent states replaced, so it can be
            # differentiated with respect to the agent states only.
            new_state = graph.states.at[agent_node_id].set(new_agent_state)
            new_graph = graph._replace(edges=new_state[graph.receivers] - new_state[graph.senders], states=new_state)
            val = self.get_cbf(new_graph)
            assert val.shape == (self.n_agents, self.k)
            return val

        agent_state = graph.type_states(type_idx=0, n_type=self.n_agents)
        h = h_aug(agent_state)  # (n_agents, k)
        h_x = jax.jacfwd(h_aug)(agent_state)  # (n_agents, k | n_agents, nx)
        h = h.reshape(-1)  # (n_agents * k,)

        # Lie derivatives of h along the control-affine dynamics x_dot = f(x) + g(x) u.
        dyn_f, dyn_g = self._env.control_affine_dyn(agent_state)
        Lf_h = ei.einsum(h_x, dyn_f, "agent_i k agent_j nx, agent_j nx -> agent_i k")
        Lg_h = ei.einsum(h_x, dyn_g, "agent_i k agent_j nx, agent_j nx nu -> agent_i k agent_j nu")
        Lf_h = Lf_h.reshape(-1)  # (n_agents * k,)
        Lg_h = Lg_h.reshape((n_slack, -1))  # (n_agents * k, n_agents * nu)

        u_lb, u_ub = self._env.action_lim()
        u_lb = u_lb[None, :].repeat(self.n_agents, axis=0).reshape(-1)
        u_ub = u_ub[None, :].repeat(self.n_agents, axis=0).reshape(-1)
        u_ref = self._env.u_ref(graph).reshape(-1)

        # construct QP
        H = jnp.eye(self._env.action_dim * self.n_agents + n_slack, dtype=jnp.float32)
        # Scale the slack block of the (identity) Hessian by 10.
        H = H.at[-n_slack:, -n_slack:].set(H[-n_slack:, -n_slack:] * 10.0)
        g = jnp.concatenate([-u_ref, relax_penalty * jnp.ones(n_slack)])
        # NOTE(review): presumably JaxProxQP enforces C x <= b, which makes this
        # Lf_h + Lg_h u + alpha * h + r >= 0 -- confirm against the solver docs.
        C = -jnp.concatenate([Lg_h, jnp.eye(n_slack)], axis=1)
        b = Lf_h + self.alpha * h  # (n_agents * k,)

        # Slack variables are bounded below by zero and unbounded above.
        r_lb = jnp.zeros(n_slack, dtype=jnp.float32)
        r_ub = jnp.full(n_slack, jnp.inf, dtype=jnp.float32)

        l_box = jnp.concatenate([u_lb, r_lb], axis=0)
        u_box = jnp.concatenate([u_ub, r_ub], axis=0)

        qp = JaxProxQP.QPModel.create(H, g, C, b, l_box, u_box)
        settings = JaxProxQP.Settings.default()
        settings.max_iter = 100
        settings.dua_gap_thresh_abs = None
        solver = JaxProxQP(qp, settings)
        sol = solver.solve()

        assert sol.x.shape == (self.action_dim * self.n_agents + n_slack,)
        u_opt, r = sol.x[:self.action_dim * self.n_agents], sol.x[-n_slack:]
        u_opt = u_opt.reshape(self.n_agents, -1)

        return u_opt, r

    def save(self, save_dir: str, step: int):
        """Not available: nothing to save."""
        raise NotImplementedError

    def load(self, load_dir: str, step: int):
        """Not available: nothing to load."""
        raise NotImplementedError
124 |
--------------------------------------------------------------------------------
/gcbfplus/algo/dec_share_cbf.py:
--------------------------------------------------------------------------------
1 | import functools as ft
2 | import einops as ei
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | from jaxproxqp.jaxproxqp import JaxProxQP
7 | from typing import Optional, Tuple
8 |
9 | from gcbfplus.env.base import MultiAgentEnv
10 | from gcbfplus.trainer.data import Rollout
11 | from gcbfplus.utils.graph import GraphsTuple
12 | from gcbfplus.utils.typing import Action, Array, Params, PRNGKey, State
13 | from gcbfplus.utils.utils import mask2index, jax_vmap
14 | from .base import MultiAgentController
15 | from .utils import get_pwise_cbf_fn
16 |
17 |
class DecShareCBF(MultiAgentController):
    """Decentralized CBF-QP baseline: every agent solves its own small QP.

    Each agent gets ``k = 3`` barrier values from ``get_pwise_cbf_fn(env, k)`` and only
    the Lie derivative with respect to its own action. Constraints are scaled by a
    responsibility factor: 1.0 for entries flagged as obstacle, 0.5 otherwise.
    This controller has no learnable parameters, so the training/checkpoint hooks
    raise ``NotImplementedError``.
    """

    def __init__(
            self,
            env: MultiAgentEnv,
            node_dim: int,
            edge_dim: int,
            state_dim: int,
            action_dim: int,
            n_agents: int,
            alpha: float = 1.0,
            **kwargs
    ):
        """Set up the pairwise CBF; ``state_dim`` and ``**kwargs`` are accepted but unused."""
        super().__init__(env=env, node_dim=node_dim, edge_dim=edge_dim, action_dim=action_dim, n_agents=n_agents)

        # Side effect on the caller's env: switch off `enable_stop` when the env has it.
        if hasattr(env, "enable_stop"):
            env.enable_stop = False

        self.cbf_alpha = alpha
        self.k = 3
        self.cbf = get_pwise_cbf_fn(env, self.k)

    @property
    def config(self) -> dict:
        """Hyper-parameters of this controller (only ``alpha``)."""
        return {
            "alpha": self.cbf_alpha,
        }

    @property
    def actor_params(self) -> Params:
        """Not available: this controller has no actor network."""
        raise NotImplementedError

    def step(self, graph: GraphsTuple, key: PRNGKey, params: Optional[Params] = None) -> Tuple[Action, Array]:
        """Not available for this controller."""
        raise NotImplementedError

    def get_cbf(self, graph: GraphsTuple) -> tuple[Array, Array]:
        """Return the barrier values and the is-obstacle flags produced by the pairwise CBF."""
        ak_h0, ak_isobs = self.cbf(graph)
        return ak_h0, ak_isobs

    def update(self, rollout: Rollout, step: int) -> dict:
        """Not available: nothing to train."""
        raise NotImplementedError

    def act(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action:
        """Return the QP-filtered action for every agent; ``params`` is ignored."""
        return self.get_qp_action(graph)[0]

    def get_qp_action(self, graph: GraphsTuple, relax_penalty: float = 1e3) -> Tuple[Action, Array]:
        """Solve one QP per agent for a single (unbatched) graph.

        Returns ``(au_opt, ar)``: the per-agent actions and per-agent slack vectors,
        stacked over agents by the vmap below.
        """
        assert graph.is_single  # consider single graph
        agent_node_mask = graph.node_type == 0
        agent_node_id = mask2index(agent_node_mask, self.n_agents)

        def h_aug(new_agent_state: State) -> tuple[Array, Array]:
            # Re-evaluate the CBF with the agent states replaced so that it can be
            # differentiated with respect to the agent states only.
            new_state = graph.states.at[agent_node_id].set(new_agent_state)
            new_graph = graph._replace(edges=new_state[graph.receivers] - new_state[graph.senders], states=new_state)
            ak_h_, ak_isobs_ = self.get_cbf(new_graph)
            assert ak_h_.shape == (self.n_agents, self.k)
            assert ak_isobs_.shape == (self.n_agents, self.k)
            return ak_h_, ak_isobs_

        def h(new_agent_state: State) -> Array:
            # Barrier values only, so jacfwd differentiates a single array output.
            return h_aug(new_agent_state)[0]

        agent_state = graph.type_states(type_idx=0, n_type=self.n_agents)
        # (n_agents, k)
        ak_h, ak_isobs = h_aug(agent_state)
        # (n_agents, k | n_agents, nx)
        ak_hx = jax.jacfwd(h)(agent_state)

        a_dyn_f, a_dyn_g = self._env.control_affine_dyn(agent_state)
        ak_Lf_h = ei.einsum(ak_hx, a_dyn_f, "agent_i k agent_j nx, agent_j nx -> agent_i k")
        aka_Lg_h: Array = ei.einsum(ak_hx, a_dyn_g, "agent_i k agent_j nx, agent_j nx nu -> agent_i k agent_j nu")

        def index_fn(idx: int):
            # Keep only agent idx's barrier derivatives w.r.t. its OWN action.
            k_Lg_h = aka_Lg_h[idx, :, idx]
            assert k_Lg_h.shape == (self.k, self.action_dim)
            return k_Lg_h

        ak_Lg_h_self = jax_vmap(index_fn)(jnp.arange(self.n_agents))

        au_ref = self._env.u_ref(graph)
        assert au_ref.shape == (self.n_agents, self.action_dim)

        # (n_agents, k). 1 if agent-obs, 0.5 if agent-agent.
        ak_resp = jnp.where(ak_isobs, 1.0, 0.5)

        # construct QP
        au_opt, ar = jax_vmap(ft.partial(self._solve_qp_single, relax_penalty=relax_penalty))(
            ak_h, ak_Lf_h, ak_Lg_h_self, au_ref, ak_resp
        )
        return au_opt, ar

    def _solve_qp_single(self, k_h, k_Lf_h, k_Lg_h, u_ref, k_responsibility: Array, relax_penalty: float = 1e3):
        """Solve the QP of one agent.

        The decision vector is ``[u (action_dim), r (k)]`` with non-negative slacks ``r``.
        ``k_responsibility`` is this agent's row of the responsibility factors built in
        ``get_qp_action``. Returns ``(u_opt, r)``.
        """
        n_qp_x = self._env.action_dim + self.k

        assert k_h.shape == (self.k,)
        assert k_Lf_h.shape == (self.k,)
        assert k_Lg_h.shape == (self.k, self._env.action_dim)

        u_lb, u_ub = self._env.action_lim()
        assert u_lb.shape == u_ub.shape == (self.action_dim,)

        ###########

        H = jnp.eye(n_qp_x, dtype=jnp.float32)
        # NOTE(review): this writes 10.0 into EVERY entry of the k x k slack block
        # (off-diagonals included), whereas CentralizedCBF scales only the diagonal
        # of its slack block by 10 -- confirm which form is intended.
        H = H.at[-self.k :, -self.k :].set(10.0)
        g = jnp.concatenate([-u_ref, relax_penalty * jnp.ones(self.k)], axis=0)
        assert g.shape == (n_qp_x,)

        k_C = -jnp.concatenate([k_Lg_h, jnp.eye(self.k)], axis=1)
        assert k_C.shape == (self.k, n_qp_x)

        # Responsibility is one if this is agent-obs, half if this is agent-agent.
        k_b = k_responsibility * (k_Lf_h + self.cbf_alpha * k_h)
        assert k_b.shape == (self.k,)

        # Slack variables are bounded below by zero and unbounded above.
        r_lb = jnp.full(self.k, 0.0, dtype=jnp.float32)
        r_ub = jnp.full(self.k, jnp.inf, dtype=jnp.float32)

        l_box = jnp.concatenate([u_lb, r_lb], axis=0)
        u_box = jnp.concatenate([u_ub, r_ub], axis=0)
        assert l_box.shape == u_box.shape == (n_qp_x,)

        qp = JaxProxQP.QPModel.create(H, g, k_C, k_b, l_box, u_box)
        settings = JaxProxQP.Settings.default()
        settings.max_iter = 100
        settings.dua_gap_thresh_abs = None
        solver = JaxProxQP(qp, settings)
        sol = solver.solve()

        assert sol.x.shape == (n_qp_x,)
        u_opt, r = sol.x[: self.action_dim], sol.x[-self.k :]

        return u_opt, r

    def save(self, save_dir: str, step: int):
        """Not available: nothing to save."""
        raise NotImplementedError

    def load(self, load_dir: str, step: int):
        """Not available: nothing to load."""
        raise NotImplementedError
157 |
--------------------------------------------------------------------------------
/gcbfplus/algo/gcbf.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import jax.random as jr
3 | import optax
4 | import os
5 | import jax
6 | import functools as ft
7 | import jax.tree_util as jtu
8 | import numpy as np
9 | import pickle
10 |
11 | from typing import Optional, Tuple
12 | from flax.training.train_state import TrainState
13 |
14 | from gcbfplus.utils.typing import Action, Params, PRNGKey, Array
15 | from gcbfplus.utils.graph import GraphsTuple
16 | from gcbfplus.utils.utils import merge01, jax_vmap, tree_merge
17 | from gcbfplus.trainer.data import Rollout
18 | from gcbfplus.trainer.buffer import ReplayBuffer
19 | from gcbfplus.trainer.utils import has_any_nan, compute_norm_and_clip
20 | from gcbfplus.env.base import MultiAgentEnv
21 | from gcbfplus.algo.module.cbf import CBF
22 | from gcbfplus.algo.module.policy import DeterministicPolicy
23 | from .base import MultiAgentController
24 |
25 |
26 | class GCBF(MultiAgentController):
27 |
def __init__(
        self,
        env: MultiAgentEnv,
        node_dim: int,
        edge_dim: int,
        state_dim: int,
        action_dim: int,
        n_agents: int,
        gnn_layers: int,
        batch_size: int,
        buffer_size: int,
        lr_actor: float = 3e-5,
        lr_cbf: float = 3e-5,
        alpha: float = 1.0,
        eps: float = 0.02,
        inner_epoch: int = 8,
        loss_action_coef: float = 0.001,
        loss_unsafe_coef: float = 1.,
        loss_safe_coef: float = 1.,
        loss_h_dot_coef: float = 0.2,
        max_grad_norm: float = 2.,
        seed: int = 0,
        online_pol_refine: bool = False,
        **kwargs
):
    """Create the CBF and actor networks, their optimizers and the replay buffers.

    All hyper-parameters are stored verbatim on the instance. Both networks are
    initialised on an all-zero placeholder graph with one self-loop edge per agent.
    Extra ``**kwargs`` are accepted and ignored.
    """
    super(GCBF, self).__init__(
        env=env,
        node_dim=node_dim,
        edge_dim=edge_dim,
        action_dim=action_dim,
        n_agents=n_agents
    )

    # --- hyper-parameters -------------------------------------------------
    self.seed = seed
    self.gnn_layers = gnn_layers
    self.batch_size = batch_size
    self.inner_epoch = inner_epoch
    self.max_grad_norm = max_grad_norm
    self.online_pol_refine = online_pol_refine
    self.lr_actor, self.lr_cbf = lr_actor, lr_cbf
    self.alpha, self.eps = alpha, eps
    self.loss_action_coef = loss_action_coef
    self.loss_unsafe_coef = loss_unsafe_coef
    self.loss_safe_coef = loss_safe_coef
    self.loss_h_dot_coef = loss_h_dot_coef

    # --- placeholder graph used only to initialise network parameters -----
    agent_ids = jnp.arange(n_agents)
    self.nominal_graph = GraphsTuple(
        nodes=jnp.zeros((n_agents, node_dim)),
        edges=jnp.zeros((n_agents, edge_dim)),
        states=jnp.zeros((n_agents, state_dim)),
        n_node=jnp.array(n_agents),
        n_edge=jnp.array(n_agents),
        senders=agent_ids,
        receivers=agent_ids,
        node_type=jnp.zeros((n_agents,)),
        env_states=jnp.zeros((n_agents,)),
    )

    rng = jr.PRNGKey(seed)

    # --- CBF network ------------------------------------------------------
    cbf_rng, rng = jr.split(rng)
    self.cbf = CBF(
        node_dim=node_dim,
        edge_dim=edge_dim,
        n_agents=n_agents,
        gnn_layers=gnn_layers
    )
    # apply_if_finite wraps Adam; see the optax documentation for its exact semantics.
    self.cbf_optim = optax.apply_if_finite(optax.adam(learning_rate=lr_cbf), 1_000_000)
    self.cbf_train_state = TrainState.create(
        apply_fn=self.cbf.get_cbf,
        params=self.cbf.net.init(cbf_rng, self.nominal_graph, self.n_agents),
        tx=self.cbf_optim
    )

    # --- actor network ----------------------------------------------------
    actor_rng, rng = jr.split(rng)
    self.actor = DeterministicPolicy(
        node_dim=node_dim,
        edge_dim=edge_dim,
        action_dim=action_dim,
        n_agents=n_agents
    )
    self.actor_optim = optax.apply_if_finite(optax.adam(learning_rate=lr_actor), 1_000_000)
    self.actor_train_state = TrainState.create(
        apply_fn=self.actor.sample_action,
        params=self.actor.net.init(actor_rng, self.nominal_graph, self.n_agents),
        tx=self.actor_optim
    )

    # --- remaining PRNG key and replay buffers ----------------------------
    self.key = rng
    self.buffer = ReplayBuffer(size=buffer_size)
    # The unsafe-sample buffer gets half the capacity of the main one.
    self.unsafe_buffer = ReplayBuffer(size=buffer_size // 2)
130 |
@property
def config(self) -> dict:
    """Hyper-parameters of this controller, keyed by their attribute names."""
    exported = (
        'batch_size',
        'lr_actor',
        'lr_cbf',
        'alpha',
        'eps',
        'inner_epoch',
        'loss_action_coef',
        'loss_unsafe_coef',
        'loss_safe_coef',
        'loss_h_dot_coef',
        'gnn_layers',
        'seed',
        'max_grad_norm',
    )
    # Every exported key is also the name of the attribute holding its value.
    return {name: getattr(self, name) for name in exported}
148 |
@property
def actor_params(self) -> Params:
    """Current actor parameters, read from the actor train state."""
    train_state = self.actor_train_state
    return train_state.params
152 |
def act(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action:
    """Return the deterministic action for ``graph``.

    When ``online_pol_refine`` is set, the call is delegated to
    ``online_policy_refinement``. Otherwise the action is twice the actor output
    plus the environment's nominal control ``u_ref``. ``params`` defaults to the
    current actor parameters.
    """
    if self.online_pol_refine:
        return self.online_policy_refinement(graph, params)
    actor_weights = self.actor_params if params is None else params
    residual = self.actor.get_action(actor_weights, graph)
    return 2 * residual + self._env.u_ref(graph)
160 |
def online_policy_refinement(self, graph: GraphsTuple, params: Optional[Params] = None) -> Action:
    """Return an action refined by gradient descent on the CBF-condition violation.

    The nominal control ``u_ref`` is kept wherever it produces no violation of
    ``h_dot + alpha * h >= 0`` (finite-difference ``h_dot`` over one env step);
    elsewhere the network action is used as the starting point. The action is then
    updated by gradient steps on the mean violation until the violation is zero or
    the iteration cap is reached.
    """
    actor_weights = self.actor_params if params is None else params
    h_now = self.get_cbf(graph)

    def violation(candidate: Action):
        # relu(-(h_dot + alpha * h)): positive exactly where the CBF condition fails.
        h_after = self.get_cbf(self._env.forward_graph(graph, candidate))
        h_rate = (h_after - h_now) / self._env.dt
        return jax.nn.relu(-h_rate - self.alpha * h_now)

    # try u_ref first
    u_nominal = self._env.u_ref(graph)
    nominal_violation = violation(u_nominal)
    learned_action = 2 * self.actor.get_action(actor_weights, graph) + u_nominal
    start_action = jnp.where(nominal_violation > 0, learned_action, u_nominal)

    iteration_cap = 30
    step_size = 0.1

    def mean_violation(candidate: Action):
        return violation(candidate).mean()

    def refine_once(carry):
        n_done, current, _ = carry
        loss, direction = jax.value_and_grad(mean_violation)(current)
        return n_done + 1, current - step_size * direction, loss

    def keep_going(carry):
        n_done, _, loss = carry
        return (loss > 0) & (n_done < iteration_cap)

    # The initial loss of 1.0 guarantees that at least one refinement step runs.
    _, refined_action, _ = jax.lax.while_loop(
        keep_going, refine_once, init_val=(0, start_action, 1.0)
    )

    return refined_action
202 |
def step(self, graph: GraphsTuple, key: PRNGKey, params: Optional[Params] = None) -> Tuple[Action, Array]:
    """Sample an action for ``graph`` with the actor's ``apply_fn``.

    Returns twice the sampled output plus the nominal control ``u_ref``, together
    with the second value returned by ``apply_fn``. ``params`` defaults to the
    current actor parameters.
    """
    actor_weights = self.actor_params if params is None else params
    sampled, log_prob = self.actor_train_state.apply_fn(actor_weights, graph, key)
    shifted = 2 * sampled + self._env.u_ref(graph)
    return shifted, log_prob
208 |
def get_cbf(self, graph: GraphsTuple, params: Optional[Params] = None) -> Array:
    """Evaluate the CBF network on ``graph``.

    ``params`` defaults to the parameters held in the CBF train state.
    """
    cbf_weights = self.cbf_train_state.params if params is None else params
    return self.cbf.get_cbf(cbf_weights, graph)
213 |
def update(self, rollout: Rollout, step: int) -> dict:
    """Run ``inner_epoch`` passes of mini-batch updates on new and replayed data.

    The fresh rollout is stored in the replay buffer and its unsafe samples in the
    unsafe buffer. Once the main buffer holds more than ``batch_size`` entries,
    replayed samples from both buffers are merged into the training data.

    Args:
        rollout: Newly collected rollout.
        step: Current training step (unused here; part of the controller interface).

    Returns:
        The info dictionary of the last inner epoch (empty if ``inner_epoch`` is 0).
    """
    # Advance the stored PRNG key; the other half of the split is not needed here.
    _, self.key = jr.split(self.key)

    # Reduce the per-agent unsafe mask over the last axis, then keep only the
    # flagged samples for the unsafe buffer.
    # NOTE(review): the two vmapped leading axes are presumably (env, time) -- confirm.
    unsafe_mask = jax_vmap(jax_vmap(self._env.unsafe_mask))(rollout.graph).max(axis=-1)
    unsafe_rollout = jtu.tree_map(lambda x: x[unsafe_mask], rollout)

    if self.buffer.length > self.batch_size:
        # sample from memory and unsafe_memory BEFORE the new data is appended
        memory = self.buffer.sample(rollout.length // 2)
        unsafe_memory = self.unsafe_buffer.sample(rollout.length * rollout.time_horizon)

        # append new data to memory and unsafe_memory
        self.unsafe_buffer.append(unsafe_rollout)
        self.buffer.append(rollout)

        # get update data
        rollout = tree_merge([memory, rollout])
        rollout = jtu.tree_map(merge01, rollout)
        rollout = tree_merge([unsafe_memory, rollout])
    else:
        self.buffer.append(rollout)
        self.unsafe_buffer.append(unsafe_rollout)
        rollout = jtu.tree_map(merge01, rollout)

    # Mini-batch layout. The previous `array_split(idx, n // batch_size)` failed
    # when n < batch_size (zero sections) and when n was not a multiple of
    # batch_size (ragged batches cannot be stacked). Use at least one batch and
    # drop the remainder of the shuffled indices so all batches have equal size.
    # When n is a multiple of batch_size the batches are identical to before.
    n_samples = rollout.length
    n_batch = max(n_samples // self.batch_size, 1)
    n_usable = (n_samples // n_batch) * n_batch

    # inner loop
    update_info = {}
    for _ in range(self.inner_epoch):
        idx = np.arange(n_samples)
        np.random.shuffle(idx)
        batch_idx = jnp.array(idx[:n_usable].reshape(n_batch, -1))

        cbf_train_state, actor_train_state, update_info = self.update_inner(
            self.cbf_train_state, self.actor_train_state, rollout, batch_idx
        )
        self.cbf_train_state = cbf_train_state
        self.actor_train_state = actor_train_state

    return update_info
251 |
@ft.partial(jax.jit, static_argnums=(0,), donate_argnums=(1, 2))
def update_inner(
    self, cbf_train_state: TrainState, actor_train_state: TrainState, rollout: Rollout, batch_idx: Array
) -> Tuple[TrainState, TrainState, dict]:
    """Run one pass of mini-batch gradient updates on the CBF and actor networks.

    JIT-compiled with ``self`` static. The two incoming train states are
    donated, so callers must continue with the returned states.

    Args:
        cbf_train_state: Train state of the CBF network.
        actor_train_state: Train state of the actor network.
        rollout: Training data; every leaf is indexed along its leading axis
            by the rows of ``batch_idx``.
        batch_idx: Index array that is scanned row by row, one row per
            mini-batch.

    Returns:
        The updated CBF train state, the updated actor train state, and the
        metrics of the last mini-batch.
    """

    def update_fn(carry, idx):
        """Update the actor and the CBF network for a single batch given the batch index."""
        cbf, actor = carry
        rollout_batch = jtu.tree_map(lambda x: x[idx], rollout)

        def get_loss(cbf_params: Params, actor_params: Params) -> Tuple[Array, dict]:
            # get CBF values, flattened to one entry per (sample, agent)
            cbf_fn = jax_vmap(ft.partial(self.cbf.get_cbf, cbf_params))
            h = cbf_fn(rollout_batch.graph).squeeze()
            h = merge01(h)

            # unsafe region h(x) < 0
            # Entries outside the mask are replaced by -2 * eps so that the
            # hinge below evaluates to zero for them (for eps > 0).
            unsafe_mask = merge01(jax_vmap(self._env.unsafe_mask)(rollout_batch.graph))
            unsafe_data_ratio = jnp.mean(unsafe_mask)
            h_unsafe = jnp.where(unsafe_mask, h, -jnp.ones_like(h) * self.eps * 2)
            max_val_unsafe = jax.nn.relu(h_unsafe + self.eps)
            loss_unsafe = jnp.sum(max_val_unsafe) / (jnp.count_nonzero(unsafe_mask) + 1e-6)
            acc_unsafe_mask = jnp.where(unsafe_mask, h, jnp.ones_like(h))
            acc_unsafe = (jnp.sum(jnp.less(acc_unsafe_mask, 0)) + 1e-6) / (jnp.count_nonzero(unsafe_mask) + 1e-6)

            # safe region h(x) > 0 (mirror image of the unsafe term)
            safe_mask = merge01(jax_vmap(self._env.safe_mask)(rollout_batch.graph))
            h_safe = jnp.where(safe_mask, h, jnp.ones_like(h) * self.eps * 2)
            max_val_safe = jax.nn.relu(-h_safe + self.eps)
            loss_safe = jnp.sum(max_val_safe) / (jnp.count_nonzero(safe_mask) + 1e-6)
            acc_safe_mask = jnp.where(safe_mask, h, -jnp.ones_like(h))
            acc_safe = (jnp.sum(jnp.greater(acc_safe_mask, 0)) + 1e-6) / (jnp.count_nonzero(safe_mask) + 1e-6)

            # get actions
            # NOTE(review): this uses jax.vmap while the rest of the block uses
            # the project's jax_vmap helper; confirm the difference is intended.
            action_fn = jax.vmap(ft.partial(self.actor.get_action, actor_params))
            action = action_fn(rollout_batch.graph)

            # get next graph
            forward_fn = jax_vmap(self._env.forward_graph)
            next_graph = forward_fn(rollout_batch.graph, action)
            # h_next must be processed exactly like h above. Without the
            # squeeze, a trailing singleton axis on the network output would
            # make `h_next - h` broadcast into a pairwise matrix instead of an
            # element-wise difference.
            h_next = merge01(cbf_fn(next_graph).squeeze())
            h_dot = (h_next - h) / self._env.dt

            # h_dot + alpha * h > 0
            max_val_h_dot = jax.nn.relu(-h_dot - self.alpha * h + self.eps)
            loss_h_dot = jnp.mean(max_val_h_dot)  # + jnp.max(max_val_h_dot)
            acc_h_dot = jnp.mean(jnp.greater(h_dot + self.alpha * h, 0))

            # action loss: squared deviation from the reference control
            u_ref = jax_vmap(self._env.u_ref)(rollout_batch.graph)
            loss_action = jnp.mean(jnp.square(action - u_ref).sum(axis=-1))

            # total loss
            total_loss = (
                self.loss_action_coef * loss_action
                + self.loss_unsafe_coef * loss_unsafe
                + self.loss_safe_coef * loss_safe
                + self.loss_h_dot_coef * loss_h_dot
            )

            return total_loss, {'loss/action': loss_action,
                                'loss/unsafe': loss_unsafe,
                                'loss/safe': loss_safe,
                                'loss/h_dot': loss_h_dot,
                                'loss/total': total_loss,
                                'acc/unsafe': acc_unsafe,
                                'acc/safe': acc_safe,
                                'acc/h_dot': acc_h_dot,
                                'acc/unsafe_data_ratio': unsafe_data_ratio}

        # One joint backward pass gives the gradients for both networks.
        (loss, loss_info), (grad_cbf, grad_actor) = jax.value_and_grad(
            get_loss, has_aux=True, argnums=(0, 1))(cbf.params, actor.params)
        grad_cbf_has_nan = has_any_nan(grad_cbf).astype(jnp.float32)
        grad_actor_has_nan = has_any_nan(grad_actor).astype(jnp.float32)
        grad_cbf, grad_cbf_norm = compute_norm_and_clip(grad_cbf, self.max_grad_norm)
        grad_actor, grad_actor_norm = compute_norm_and_clip(grad_actor, self.max_grad_norm)
        cbf = cbf.apply_gradients(grads=grad_cbf)
        actor = actor.apply_gradients(grads=grad_actor)
        return (cbf, actor), {'grad_norm/cbf': grad_cbf_norm,
                              'grad_norm/actor': grad_actor_norm,
                              'grad_has_nan/cbf': grad_cbf_has_nan,
                              'grad_has_nan/actor': grad_actor_has_nan} | loss_info

    (cbf_train_state, actor_train_state), info = jax.lax.scan(
        update_fn, (cbf_train_state, actor_train_state), batch_idx
    )

    # keep only the training info of the last mini-batch of the scan
    info = jtu.tree_map(lambda x: x[-1], info)

    return cbf_train_state, actor_train_state, info
343 |
def save(self, save_dir: str, step: int):
    """Pickle the actor and CBF parameters into ``<save_dir>/<step>/``.

    The directory is created when missing. Two files are written:
    ``actor.pkl`` and ``cbf.pkl``.

    Args:
        save_dir: Root directory for checkpoints.
        step: Training step; used as the name of the sub-directory.
    """
    model_dir = os.path.join(save_dir, str(step))
    os.makedirs(model_dir, exist_ok=True)
    # Context managers make sure each file is flushed and closed even when
    # pickling fails part-way.
    with open(os.path.join(model_dir, 'actor.pkl'), 'wb') as actor_file:
        pickle.dump(self.actor_train_state.params, actor_file)
    with open(os.path.join(model_dir, 'cbf.pkl'), 'wb') as cbf_file:
        pickle.dump(self.cbf_train_state.params, cbf_file)
350 |
def load(self, load_dir: str, step: int):
    """Restore the actor and CBF parameters from ``<load_dir>/<step>/``.

    Reads ``actor.pkl`` and ``cbf.pkl`` and swaps the loaded parameters into
    the existing train states via ``replace``. Both files are read before any
    state is modified, so a missing or corrupt file leaves the object
    untouched.

    Args:
        load_dir: Root directory for checkpoints.
        step: Training step; name of the sub-directory to read from.
    """
    path = os.path.join(load_dir, str(step))

    # NOTE: unpickling can execute arbitrary code; only load checkpoints
    # that come from a trusted source.
    with open(os.path.join(path, 'actor.pkl'), 'rb') as actor_file:
        actor_params = pickle.load(actor_file)
    with open(os.path.join(path, 'cbf.pkl'), 'rb') as cbf_file:
        cbf_params = pickle.load(cbf_file)

    self.actor_train_state = self.actor_train_state.replace(params=actor_params)
    self.cbf_train_state = self.cbf_train_state.replace(params=cbf_params)
358 |
--------------------------------------------------------------------------------
/gcbfplus/algo/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/algo/module/__init__.py
--------------------------------------------------------------------------------
/gcbfplus/algo/module/cbf.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | import functools as ft
3 |
4 | from typing import Type
5 | from ...nn.gnn import GNN
6 | from ...nn.mlp import MLP
7 | from ...nn.utils import default_nn_init
8 | from ...utils.typing import Array, Params
9 | from ...utils.graph import GraphsTuple
10 |
11 |
class CBFNet(nn.Module):
    """Flax module producing one tanh-squashed scalar per type-0 node.

    ``gnn_cls`` and ``head_cls`` are zero-argument factories for the GNN
    encoder and the head network respectively.
    """
    gnn_cls: Type[GNN]
    head_cls: Type[nn.Module]

    @nn.compact
    def __call__(self, obs: GraphsTuple, n_agents: int, *args, **kwargs) -> Array:
        # Sub-modules are created in the order GNN -> head -> output layer.
        node_feats = self.gnn_cls()(obs, node_type=0, n_type=n_agents)
        head_feats = self.head_cls()(node_feats)
        out_layer = nn.Dense(1, kernel_init=default_nn_init())
        # tanh keeps the output inside (-1, 1)
        return nn.tanh(out_layer(head_feats))
22 |
23 |
class CBF:
    """Owns the ``CBFNet`` architecture and evaluates it for given parameters."""

    def __init__(self, node_dim: int, edge_dim: int, n_agents: int, gnn_layers: int):
        """Store the graph dimensions and assemble the network definition.

        ``node_dim`` and ``edge_dim`` are only stored; ``gnn_layers`` sets the
        number of GNN layers.
        """
        self.node_dim, self.edge_dim, self.n_agents = node_dim, edge_dim, n_agents

        gnn_config = dict(
            msg_dim=128,
            hid_size_msg=(256, 256),
            hid_size_aggr=(128, 128),
            hid_size_update=(256, 256),
            out_dim=128,
            n_layers=gnn_layers,
        )
        head_config = dict(
            hid_sizes=(256, 256),
            act=nn.relu,
            act_final=False,
            name='CBFHead',
        )
        self.cbf_gnn = ft.partial(GNN, **gnn_config)
        self.cbf_head = ft.partial(MLP, **head_config)
        self.net = CBFNet(gnn_cls=self.cbf_gnn, head_cls=self.cbf_head)

    def get_cbf(self, params: Params, obs: GraphsTuple) -> Array:
        """Apply the network with ``params`` to ``obs`` for ``self.n_agents`` agents."""
        return self.net.apply(params, obs, self.n_agents)
54 |
--------------------------------------------------------------------------------
/gcbfplus/algo/module/distribution.py:
--------------------------------------------------------------------------------
1 | import tensorflow_probability.substrates.jax as tfp
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import jax.random as jr
5 |
6 | tfd = tfp.distributions
7 | tfb = tfp.bijectors
8 |
9 |
class TanhTransformedDistribution(tfd.TransformedDistribution):
    """A base distribution pushed through ``tanh``, with clipped log-probabilities.

    Values whose magnitude reaches ``threshold`` are assigned the log of the
    base distribution's tail mass beyond ``atanh(threshold)`` divided by
    ``1 - threshold`` instead of the (numerically unstable) exact density.
    """

    def __init__(self, distribution: tfd.Distribution, threshold: float = 0.999, validate_args: bool = False):
        """
        Args:
            distribution: Base distribution to be squashed.
            threshold: Magnitude at which ``log_prob`` switches to the
                tail-mass approximation.
            validate_args: Forwarded to ``TransformedDistribution``.
        """
        super().__init__(distribution=distribution, bijector=tfb.Tanh(), validate_args=validate_args)
        self._threshold = threshold

        # atanh(threshold); computed once and also exposed as an attribute.
        inverse_threshold = self.bijector.inverse(threshold)
        self.inverse_threshold = inverse_threshold

        # average(pdf) = p / epsilon
        # So log(average(pdf)) = log(p) - log(epsilon)
        log_epsilon = np.log(1.0 - threshold)

        self._log_prob_left = self.distribution.log_cdf(-inverse_threshold) - log_epsilon
        self._log_prob_right = self.distribution.log_survival_function(inverse_threshold) - log_epsilon

    def log_prob(self, value, name='log_prob', **kwargs):
        """Log-probability of ``value``, using the tail approximation at the clip boundaries."""
        # Without this clip there would be NaNs in the inner where, which
        # would also poison the gradients of the selected branch.
        value = jnp.clip(value, -self._threshold, self._threshold)
        # The inverse image of {threshold} is the interval [atanh(threshold), inf]
        # which has a probability of "log_prob_right" under the given distribution.
        return jnp.where(
            value <= -self._threshold,
            self._log_prob_left,
            jnp.where(value >= self._threshold, self._log_prob_right, super().log_prob(value)),
        )

    def entropy(self, name='entropy', **kwargs):
        """Single-sample estimate of the entropy.

        Accepts an optional ``seed`` keyword (a JAX PRNG key) for the sample
        of the base distribution. When it is absent, a key is derived from
        NumPy's global RNG; that value is drawn on the host, so inside a
        jitted function it is fixed at trace time. Pass ``seed`` to get a
        fresh sample per call.
        """
        # We return an estimation using a single sample of the log_det_jacobian.
        # We can still do some backpropagation with this estimate.
        seed = kwargs.get('seed')
        if seed is None:
            seed = jr.PRNGKey(np.random.randint(0, 102400))
        return self.distribution.entropy() + self.bijector.forward_log_det_jacobian(
            self.distribution.sample(seed=seed), event_ndims=0
        )

    def _mode(self) -> jnp.ndarray:
        """Mode of the base distribution mapped through the bijector."""
        return self.bijector.forward(self.distribution.mode())

    @classmethod
    def _parameter_properties(cls, dtype, num_classes=None):
        # The bijector is fixed to Tanh in __init__, so it is not a parameter.
        td_properties = super()._parameter_properties(dtype, num_classes=num_classes)
        del td_properties["bijector"]
        return td_properties

    @property
    def experimental_is_sharded(self):
        raise NotImplementedError

    # The stubs below intentionally do nothing and return None.

    def _sample_n(self, n, seed=None, **kwargs):
        pass

    def _variance(self, **kwargs):
        pass

    @classmethod
    def _maximum_likelihood_parameters(cls, value):
        pass
67 |
--------------------------------------------------------------------------------
/gcbfplus/algo/module/policy.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | import functools as ft
3 | import numpy as np
4 | import jax.nn as jnn
5 | import jax.numpy as jnp
6 |
7 | from typing import Type, Tuple
8 | from abc import ABC, abstractproperty, abstractmethod
9 |
10 | from .distribution import TanhTransformedDistribution, tfd
11 | from ...utils.typing import Action, Array
12 | from ...utils.graph import GraphsTuple
13 | from ...nn.utils import default_nn_init, scaled_init
14 | from ...nn.gnn import GNN
15 | from ...nn.mlp import MLP
16 | from ...utils.typing import PRNGKey, Params
17 |
18 |
class PolicyDistribution(nn.Module, ABC):
    """Abstract Flax module that maps its inputs to an action distribution."""

    @abstractmethod
    def __call__(self, *args, **kwargs) -> tfd.Distribution:
        """Return the action distribution for the given inputs."""

    # `abc.abstractproperty` is deprecated; stacking `property` on
    # `abstractmethod` is the supported equivalent.
    @property
    @abstractmethod
    def nu(self) -> int:
        """Dimension of the action."""
28 |
29 |
class TanhNormal(PolicyDistribution):
    """Tanh-squashed diagonal Normal whose mean and std are predicted from GNN features."""
    base_cls: Type[GNN]
    _nu: int
    scale_final: float = 0.01
    std_dev_min: float = 1e-5
    std_dev_init: float = 0.5

    @property
    def std_dev_init_inv(self):
        """Pre-softplus offset for which ``softplus(offset) == std_dev_init``."""
        offset = np.log(np.exp(self.std_dev_init) - 1)
        # softplus(x) == logaddexp(x, 0); verify the inversion round-trips.
        assert np.allclose(np.logaddexp(offset, 0), self.std_dev_init)
        return offset

    @nn.compact
    def __call__(self, obs: GraphsTuple, n_agents: int, *args, **kwargs) -> tfd.Distribution:
        node_feats = self.base_cls()(obs, node_type=0, n_type=n_agents)

        hidden_init = scaled_init(default_nn_init(), self.scale_final)
        hidden = nn.Dense(256, kernel_init=hidden_init, name="ScaleHid")(node_feats)

        loc = nn.Dense(self.nu, kernel_init=default_nn_init(), name="OutputDenseMean")(hidden)
        raw_scale = nn.Dense(self.nu, kernel_init=default_nn_init(), name="OutputDenseStdTrans")(hidden)
        # softplus keeps the std positive; std_dev_min is a hard floor.
        scale = jnn.softplus(raw_scale + self.std_dev_init_inv) + self.std_dev_min

        squashed = TanhTransformedDistribution(tfd.Normal(loc=loc, scale=scale))
        # Treat the last axis (action components) as a single event.
        return tfd.Independent(squashed, reinterpreted_batch_ndims=1)

    @property
    def nu(self):
        """Dimension of the action."""
        return self._nu
61 |
62 |
class Deterministic(nn.Module):
    """Flax module producing a tanh-bounded action of size ``_nu`` per type-0 node."""
    base_cls: Type[GNN]
    head_cls: Type[nn.Module]
    _nu: int

    @nn.compact
    def __call__(self, obs: GraphsTuple, n_agents: int, *args, **kwargs) -> Action:
        node_feats = self.base_cls()(obs, node_type=0, n_type=n_agents)
        head_feats = self.head_cls()(node_feats)
        out_layer = nn.Dense(self._nu, kernel_init=default_nn_init(), name="OutputDense")
        # tanh bounds every action component to (-1, 1)
        return nn.tanh(out_layer(head_feats))
74 |
75 |
class MultiAgentPolicy(ABC):
    """Abstract interface of a multi-agent policy operating on graph observations."""

    def __init__(self, node_dim: int, edge_dim: int, n_agents: int, action_dim: int):
        """Record the graph feature sizes, the agent count and the action size."""
        self.node_dim, self.edge_dim = node_dim, edge_dim
        self.n_agents, self.action_dim = n_agents, action_dim

    @abstractmethod
    def get_action(self, params: Params, obs: GraphsTuple) -> Action:
        """Return the policy's action for ``obs`` under ``params``."""

    @abstractmethod
    def sample_action(self, params: Params, obs: GraphsTuple, key: PRNGKey) -> Tuple[Action, Array]:
        """Return an action for ``obs`` together with a second array (log-probability in stochastic policies)."""

    @abstractmethod
    def eval_action(self, params: Params, obs: GraphsTuple, action: Action, key: PRNGKey) -> Tuple[Array, Array]:
        """Evaluate ``action`` under the policy for ``obs``; returns a pair of arrays."""
95 |
96 |
class DeterministicPolicy(MultiAgentPolicy):
    """Policy that maps a graph observation directly to an action via ``Deterministic``."""

    def __init__(
            self,
            node_dim: int,
            edge_dim: int,
            n_agents: int,
            action_dim: int,
            gnn_layers: int = 1,
    ):
        """Assemble the GNN encoder, the MLP head and the policy network definition."""
        super().__init__(node_dim, edge_dim, n_agents, action_dim)

        gnn_config = dict(
            msg_dim=128,
            hid_size_msg=(256, 256),
            hid_size_aggr=(128, 128),
            hid_size_update=(256, 256),
            out_dim=128,
            n_layers=gnn_layers,
        )
        head_config = dict(
            hid_sizes=(256, 256),
            act=nn.relu,
            act_final=False,
            name='PolicyHead',
        )
        self.policy_base = ft.partial(GNN, **gnn_config)
        self.policy_head = ft.partial(MLP, **head_config)
        self.net = Deterministic(base_cls=self.policy_base, head_cls=self.policy_head, _nu=action_dim)
        self.std = 0.1

    def get_action(self, params: Params, obs: GraphsTuple) -> Action:
        """Apply the network with ``params`` to ``obs``."""
        return self.net.apply(params, obs, self.n_agents)

    def sample_action(self, params: Params, obs: GraphsTuple, key: PRNGKey) -> Tuple[Action, Array]:
        """Return the deterministic action and an all-zero array of the same shape; ``key`` is unused."""
        action = self.get_action(params, obs)
        return action, jnp.zeros_like(action)

    def eval_action(self, params: Params, obs: GraphsTuple, action: Action, key: PRNGKey) -> Tuple[Array, Array]:
        """Not supported for the deterministic policy."""
        raise NotImplementedError
137 |
138 |
class PPOPolicy(MultiAgentPolicy):
    """Stochastic policy backed by the ``TanhNormal`` distribution module."""

    def __init__(
            self,
            node_dim: int,
            edge_dim: int,
            n_agents: int,
            action_dim: int,
            gnn_layers: int = 1,
    ):
        """Assemble the GNN encoder and the distribution module."""
        super().__init__(node_dim, edge_dim, n_agents, action_dim)
        gnn_config = dict(
            msg_dim=64,
            hid_size_msg=(128, 128),
            hid_size_aggr=(128, 128),
            hid_size_update=(128, 128),
            out_dim=64,
            n_layers=gnn_layers,
        )
        self.dist_base = ft.partial(GNN, **gnn_config)
        self.dist = TanhNormal(base_cls=self.dist_base, _nu=action_dim)

    def _action_dist(self, params: Params, obs: GraphsTuple):
        """Build the action distribution for ``obs`` under ``params``."""
        return self.dist.apply(params, obs, n_agents=self.n_agents)

    def get_action(self, params: Params, obs: GraphsTuple) -> Action:
        """Return the mode of the action distribution."""
        return self._action_dist(params, obs).mode()

    def sample_action(self, params: Params, obs: GraphsTuple, key: PRNGKey) -> Tuple[Action, Array]:
        """Sample an action with ``key`` and return it with its log-probability."""
        action_dist = self._action_dist(params, obs)
        sampled = action_dist.sample(seed=key)
        return sampled, action_dist.log_prob(sampled)

    def eval_action(self, params: Params, obs: GraphsTuple, action: Action, key: PRNGKey) -> Tuple[Array, Array]:
        """Return the log-probability of ``action`` and the distribution's entropy (called with ``seed=key``)."""
        action_dist = self._action_dist(params, obs)
        log_pi = action_dist.log_prob(action)
        return log_pi, action_dist.entropy(seed=key)
177 |
--------------------------------------------------------------------------------
/gcbfplus/algo/module/value.py:
--------------------------------------------------------------------------------
1 | import functools as ft
2 | import flax.linen as nn
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | from typing import Type
7 |
8 | from ...nn.mlp import MLP
9 | from ...nn.gnn import GNN
10 | from ...nn.utils import default_nn_init
11 | from ...utils.typing import Array, Params
12 | from ...utils.graph import GraphsTuple
13 |
14 |
class StateFn(nn.Module):
    """Flax module that pools per-node GNN features with learned softmax weights into one scalar."""
    gnn_cls: Type[GNN]
    aggr_cls: Type[nn.Module]
    head_cls: Type[nn.Module]

    @nn.compact
    def __call__(self, obs: GraphsTuple, n_agents: int, *args, **kwargs) -> Array:
        # Per-node features from the GNN encoder.
        node_feats = self.gnn_cls()(obs, node_type=0, n_type=n_agents)

        # One scalar score per node, normalised by a softmax over the nodes.
        scores = self.aggr_cls()(node_feats)
        scores = nn.Dense(1, kernel_init=default_nn_init())(scores).squeeze(-1)
        weights = jax.nn.softmax(scores, axis=-1)

        # Weighted sum over the node axis.
        pooled = jnp.sum(weights[:, None] * node_feats, axis=0)

        # Head network followed by a linear read-out of size 1.
        head_out = self.head_cls()(pooled)
        return nn.Dense(1, kernel_init=default_nn_init())(head_out)
36 |
37 |
class ValueNet:
    """Owns the ``StateFn`` architecture and evaluates it for given parameters."""

    def __init__(self, node_dim: int, edge_dim: int, n_agents: int, gnn_layers: int = 1):
        """Store the graph dimensions and assemble the value network definition."""
        self.node_dim, self.edge_dim, self.n_agents = node_dim, edge_dim, n_agents

        gnn_config = dict(
            msg_dim=64,
            hid_size_msg=(128, 128),
            hid_size_aggr=(128, 128),
            hid_size_update=(128, 128),
            out_dim=64,
            n_layers=gnn_layers,
        )
        attn_config = dict(hid_sizes=(128, 128), act=nn.relu, act_final=False, name='ValueAttn')
        head_config = dict(hid_sizes=(128, 128), act=nn.relu, act_final=False, name='ValueHead')

        self.value_gnn = ft.partial(GNN, **gnn_config)
        self.value_attn = ft.partial(MLP, **attn_config)
        self.value_head = ft.partial(MLP, **head_config)
        self.net = StateFn(
            gnn_cls=self.value_gnn,
            aggr_cls=self.value_attn,
            head_cls=self.value_head,
        )

    def get_value(self, params: Params, obs: GraphsTuple) -> Array:
        """Apply the network with ``params`` to ``obs`` and drop all size-1 axes."""
        return self.net.apply(params, obs, self.n_agents).squeeze()
77 |
--------------------------------------------------------------------------------
/gcbfplus/algo/utils.py:
--------------------------------------------------------------------------------
1 | import functools as ft
2 | import einops as ei
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | from typing import Tuple
7 |
8 | from ..env.base import MultiAgentEnv
9 | from ..env.double_integrator import DoubleIntegrator
10 | from ..env.dubins_car import DubinsCar
11 | from ..env.linear_drone import LinearDrone
12 | from ..env.single_integrator import SingleIntegrator
13 | from ..env.crazyflie import CrazyFlie
14 | from ..utils.graph import GraphsTuple
15 | from ..utils.typing import Array, Done, Reward
16 |
17 |
def compute_gae_fn(
        values: Array, rewards: Reward, dones: Done, next_values: Array, gamma: float, gae_lambda: float
) -> Tuple[Array, Array]:
    """Compute generalized advantage estimates for one trajectory.

    All array arguments are indexed by time along their leading axis.

    Returns:
        A pair ``(targets, advantages)``: ``targets`` is the raw advantage
        plus ``values``; ``advantages`` is the raw advantage standardised to
        zero mean and unit standard deviation (with a 1e-8 stabiliser).
    """
    td_errors = rewards + gamma * next_values * (1 - dones) - values
    last_adv = td_errors[-1]

    def backward_step(next_adv, step_inputs):
        # A `done` at this step cuts the bootstrap from the following step.
        td_error, done = step_inputs
        adv = td_error + gamma * gae_lambda * (1 - done) * next_adv
        return adv, adv

    # Walk backwards in time over all steps but the last, seeded with the
    # last step's TD error.
    _, earlier_advs = jax.lax.scan(
        backward_step, last_adv, (td_errors[:-1], dones[:-1]), reverse=True
    )
    advs = jnp.concatenate([earlier_advs, td_errors[-1:]], axis=0)

    targets = advs + values
    normalized = (advs - advs.mean()) / (advs.std() + 1e-8)
    return targets, normalized
36 |
37 |
def compute_gae(
        values: Array, rewards: Reward, dones: Done, next_values: Array, gamma: float, gae_lambda: float
) -> Tuple[Array, Array]:
    """Batched GAE: maps ``compute_gae_fn`` over the leading axis of the four arrays."""
    per_trajectory = ft.partial(compute_gae_fn, gamma=gamma, gae_lambda=gae_lambda)
    batched = jax.vmap(per_trajectory)
    return batched(values, rewards, dones, next_values)
42 |
43 |
def pwise_cbf_single_integrator_(pos: Array, agent_idx: int, o_obs_pos: Array, a_pos: Array, r: float, k: int):
    """Distance-based barrier values of one agent w.r.t. its ``k`` nearest neighbours.

    Candidates are all agents (``a_pos``) followed by ``o_obs_pos``.

    Returns:
        ``(k_h0, k_isobs)``: squared distance minus ``4 * (1.01 * r) ** 2``
        for each of the ``k`` nearest candidates, and a flag telling whether
        that candidate comes from ``o_obs_pos`` rather than ``a_pos``.
    """
    num_agents = len(a_pos)
    candidates = jnp.concatenate([a_pos, o_obs_pos], axis=0)

    sq_dist = jnp.sum((pos - candidates) ** 2, axis=-1)
    # Overwrite the agent's own entry with a large constant so it is not
    # picked as its own neighbour.
    sq_dist = sq_dist.at[agent_idx].set(1e2)

    nearest = jnp.argsort(sq_dist)[:k]

    # Two radii, inflated by 1% as slack for QP solver error.
    margin = 4 * (1.01 * r) ** 2
    k_h0 = sq_dist[nearest] - margin

    # Indices past the agent block belong to o_obs_pos.
    k_isobs = nearest >= num_agents
    return k_h0, k_isobs
64 |
65 |
def pwise_cbf_single_integrator(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int):
    """Evaluate ``pwise_cbf_single_integrator_`` for every agent in ``graph``.

    Reads the type-0 states (one per agent) and the type-2 states
    (``n_rays`` per agent; presumably the ray hit points, confirm against
    the environment) and returns the per-agent ``(values, is_obstacle)`` pair.
    """
    agent_states = graph.type_states(type_idx=0, n_type=n_agent)
    ray_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays)
    # Group the flat ray list by owning agent.
    per_agent_rays = ei.rearrange(ray_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent)

    # Map over agents; the full agent state array is shared by all of them.
    per_agent_cbf = jax.vmap(
        ft.partial(pwise_cbf_single_integrator_, r=r, k=k),
        in_axes=(0, 0, 0, None),
    )
    return per_agent_cbf(agent_states, jnp.arange(n_agent), per_agent_rays, agent_states)
77 |
78 |
def pwise_cbf_double_integrator_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int):
    """Barrier values of one agent w.r.t. its ``k`` nearest neighbours.

    The first two state entries are used as position and the remaining ones
    as velocity. Candidates are all agents (``a_state``) followed by
    ``o_obs_state``.

    Returns:
        ``(k_h1, k_isobs)`` with ``k_h1 = d/dt(h0) + 10 * h0`` where
        ``h0 = squared distance - 4 * r ** 2``, and a flag telling whether
        each selected candidate comes from ``o_obs_state``.
    """
    num_agents = len(a_state)

    candidates = jnp.concatenate([a_state, o_obs_state], axis=0)
    sq_dist = jnp.sum((state[:2] - candidates[:, :2]) ** 2, axis=-1)
    # Overwrite the agent's own entry so it is not picked as its own neighbour.
    sq_dist = sq_dist.at[agent_idx].set(1e2)

    k_idx = jnp.argsort(sq_dist)[:k]
    k_h0 = sq_dist[k_idx] - 4 * r ** 2
    assert k_h0.shape == (k,)

    # Relative position and velocity to the selected candidates.
    nearest = candidates[k_idx]
    rel_pos = state[:2] - nearest[:, :2]
    rel_vel = state[2:] - nearest[:, 2:]
    assert rel_pos.shape == rel_vel.shape == (k, 2)

    # Time derivative of the squared distance: 2 * <dx, dv>.
    k_h0_dot = 2 * jnp.sum(rel_pos * rel_vel, axis=-1)
    assert k_h0_dot.shape == (k,)

    k_h1 = k_h0_dot + 10.0 * k_h0
    k_isobs = k_idx >= num_agents
    return k_h1, k_isobs
112 |
113 |
def pwise_cbf_double_integrator(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int):
    """Evaluate ``pwise_cbf_double_integrator_`` for every agent in ``graph``.

    Reads the type-0 states (one per agent) and the type-2 states
    (``n_rays`` per agent; presumably the ray hit points, confirm against
    the environment) and returns the per-agent ``(values, is_obstacle)`` pair.
    """
    agent_states = graph.type_states(type_idx=0, n_type=n_agent)
    ray_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays)
    # Group the flat ray list by owning agent.
    per_agent_rays = ei.rearrange(ray_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent)

    # Map over agents; the full agent state array is shared by all of them.
    per_agent_cbf = jax.vmap(
        ft.partial(pwise_cbf_double_integrator_, r=r, k=k),
        in_axes=(0, 0, 0, None),
    )
    return per_agent_cbf(agent_states, jnp.arange(n_agent), per_agent_rays, agent_states)
125 |
126 |
def pwise_cbf_dubins_car_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int):
    """Barrier values of one agent w.r.t. its ``k`` nearest neighbours.

    State entries 0-1 are used as position, entry 2 as heading angle and
    entry 3 as speed. Candidates are all agents (``a_state``) followed by
    ``o_obs_state``; the latter are treated as having zero velocity.

    Returns:
        ``(k_h1, k_isobs)`` with ``k_h1 = d/dt(h0) + 5 * h0`` where
        ``h0 = squared distance - 4 * r ** 2``, and a flag telling whether
        each selected candidate comes from ``o_obs_state``.
    """
    num_agents = len(a_state)
    num_obs = len(o_obs_state)

    # Planar velocity of this agent (only used for the shape check).
    own_vel = state[3] * jnp.array([jnp.cos(state[2]), jnp.sin(state[2])])
    assert own_vel.shape == (2,)

    # Planar velocities of all agents: speed times unit heading vector.
    headings = jnp.stack([jnp.cos(a_state[:, 2]), jnp.sin(a_state[:, 2])], axis=-1)
    agent_vel = a_state[:, 3, None] * headings
    assert agent_vel.shape == (num_agents, 2)

    cand_pos = jnp.concatenate([a_state[:, :2], o_obs_state[:, :2]], axis=0)
    cand_vel = jnp.concatenate([agent_vel, jnp.zeros((num_obs, 2))], axis=0)

    sq_dist = jnp.sum((state[:2] - cand_pos) ** 2, axis=-1)
    # Overwrite the agent's own entry so it is not picked as its own neighbour.
    sq_dist = sq_dist.at[agent_idx].set(1e2)

    k_idx = jnp.argsort(sq_dist)[:k]
    k_h0 = sq_dist[k_idx] - 4 * r ** 2
    assert k_h0.shape == (k,)

    rel_pos = state[:2] - cand_pos[k_idx]
    rel_vel = agent_vel[agent_idx] - cand_vel[k_idx]
    assert rel_pos.shape == rel_vel.shape == (k, 2)

    # Time derivative of the squared distance: 2 * <dx, dv>.
    k_h0_dot = 2 * jnp.sum(rel_pos * rel_vel, axis=-1)
    assert k_h0_dot.shape == (k,)

    k_h1 = k_h0_dot + 5.0 * k_h0
    k_isobs = k_idx >= num_agents
    return k_h1, k_isobs
167 |
168 |
def pwise_cbf_dubins_car(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int):
    """Evaluate ``pwise_cbf_dubins_car_`` for every agent in ``graph``.

    Reads the type-0 states (one per agent) and the type-2 states
    (``n_rays`` per agent; presumably the ray hit points, confirm against
    the environment) and returns the per-agent ``(values, is_obstacle)`` pair.
    """
    agent_states = graph.type_states(type_idx=0, n_type=n_agent)
    ray_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays)
    # Group the flat ray list by owning agent.
    per_agent_rays = ei.rearrange(ray_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent)

    # Map over agents; the full agent state array is shared by all of them.
    per_agent_cbf = jax.vmap(
        ft.partial(pwise_cbf_dubins_car_, r=r, k=k),
        in_axes=(0, 0, 0, None),
    )
    return per_agent_cbf(agent_states, jnp.arange(n_agent), per_agent_rays, agent_states)
180 |
181 |
def pwise_cbf_crazyflie_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int):
    """Barrier values of one CrazyFlie agent w.r.t. its ``k`` nearest neighbours.

    Candidates are all agents (``a_state``) followed by ``o_obs_state``. The
    first three state entries are used as position. The returned value is
    built in three stages by differentiating along ``crazyflie_f_``:

        h0 = squared distance - 4 * r ** 2
        h1 = d/dt(h0) + 30 * h0
        h2 = d/dt(h1) + 50 * h1

    Returns:
        ``(k_h2, k_isobs)``: ``h2`` for each of the ``k`` nearest candidates
        and a flag telling whether that candidate comes from ``o_obs_state``.
    """
    # state: ( 12, )
    n_agent = len(a_state)

    pos = state[:3]
    all_obs_state = jnp.concatenate([a_state, o_obs_state], axis=0)
    all_obs_pos = all_obs_state[:, :3]
    del o_obs_state

    # Only consider the k closest obstacles.
    o_dist_sq = ((pos - all_obs_pos) ** 2).sum(axis=-1)
    # Remove self collisions: the own entry gets a large constant distance.
    o_dist_sq = o_dist_sq.at[agent_idx].set(1e2)
    # Take the k closest obstacles.
    k_idx = jnp.argsort(o_dist_sq)[:k]
    k_dist_sq = o_dist_sq[k_idx]
    # Take radius into account. The 1% inflation is slack for qp solver error.
    k_dist_sq = k_dist_sq - 4 * (1.01 * r) ** 2

    # NOTE: k_h0 is only used for this shape check. The returned value is
    # built from h0() below, which uses the un-inflated radius.
    k_h0 = k_dist_sq
    assert k_h0.shape == (k,)

    # all_obs_state is deliberately NOT restricted to the k nearest rows here;
    # h0() below selects them via k_idx after computing all distances.

    def crazyflie_f_(x: Array) -> Array:
        # State derivative computed from the state alone (no control input
        # appears). Presumably the drift term of the CrazyFlie dynamics;
        # TODO confirm against the CrazyFlie environment.
        Ixx, Iyy, Izz = CrazyFlie.PARAMS["Ixx"], CrazyFlie.PARAMS["Iyy"], CrazyFlie.PARAMS["Izz"]
        I = jnp.array([Ixx, Iyy, Izz])
        # roll, pitch, yaw
        phi, theta, psi = x[CrazyFlie.PHI], x[CrazyFlie.THETA], x[CrazyFlie.PSI]
        c_phi, s_phi = jnp.cos(phi), jnp.sin(phi)
        c_th, s_th = jnp.cos(theta), jnp.sin(theta)
        c_psi, s_psi = jnp.cos(psi), jnp.sin(psi)
        t_th = jnp.tan(theta)

        u, v, w = x[CrazyFlie.U], x[CrazyFlie.V], x[CrazyFlie.W]
        uvw = jnp.array([u, v, w])

        # NOTE: this local `r` shadows the radius argument only inside
        # crazyflie_f_; h0() below still sees the radius.
        p, q, r = x[CrazyFlie.P], x[CrazyFlie.Q], x[CrazyFlie.R]
        pqr = jnp.array([p, q, r])

        # Linear velocity: rotate (u, v, w) by R_W_cf.
        R_W_cf = jnp.array(
            [
                [c_psi * c_th, c_psi * s_th * s_phi - s_psi * c_phi, c_psi * s_th * c_phi + s_psi * s_phi],
                [s_psi * c_th, s_psi * s_th * s_phi + c_psi * c_phi, s_psi * s_th * c_phi - c_psi * s_phi],
                [-s_th, c_th * s_phi, c_th * c_phi],
            ]
        )
        v_Wcf_cf = jnp.array([u, v, w])
        v_Wcf_W = R_W_cf @ v_Wcf_cf
        assert v_Wcf_W.shape == (3,)

        # Euler angle dynamics. The result is reversed before it is placed
        # into x_dot.
        mat = jnp.array(
            [
                [0, s_phi / c_th, c_phi / c_th],
                [0, c_phi, -s_phi],
                [1, s_phi * t_th, c_phi * t_th],
            ]
        )
        deuler_rpy = mat @ pqr
        deuler_ypr = deuler_rpy[::-1]

        # Body frame linear acceleration: gravity (9.81) term plus the
        # -pqr x uvw term.
        acc_cf_g = -R_W_cf[2, :] * 9.81
        acc_cf = -jnp.cross(pqr, uvw) + acc_cf_g

        # Body frame angular acceleration (reversed before use, like the
        # Euler rates above).
        pqr_dot = -jnp.cross(pqr, I * pqr) / I
        rpq_dot = pqr_dot[::-1]
        assert pqr_dot.shape == (3,)

        # Four blocks of three entries each.
        x_dot = jnp.concatenate([v_Wcf_W, deuler_ypr, acc_cf, rpq_dot], axis=0)
        return x_dot

    def h0(x, obs_x):
        # obs_x holds ALL candidates; the k nearest are selected with the
        # k_idx computed above (a constant w.r.t. differentiation).
        k_xdiff = x[:3] - obs_x[:, :3]
        dist = jnp.square(k_xdiff).sum(axis=-1)
        dist = dist.at[agent_idx].set(1e2)
        return dist[k_idx] - 4 * r ** 2  # (k,)

    def h1(x, obs_x):
        # d/dt h0 by the chain rule over both arguments, plus 30 * h0.
        x_dot = crazyflie_f_(x)  # (nx,)
        obs_x_dot = jax.vmap(crazyflie_f_)(obs_x)  # (n_candidates, nx)

        h0_x = jax.jacfwd(h0, argnums=0)(x, obs_x)  # (k, nx)
        h0_obs_x = jax.jacfwd(h0, argnums=1)(x, obs_x)  # (k, n_candidates, nx)
        h0_dot = ei.einsum(h0_x, x_dot, 'k nx, nx -> k') + \
            ei.einsum(h0_obs_x, obs_x_dot, 'k1 k2 nx, k2 nx -> k1')  # (k,)
        return h0_dot + 30.0 * h0(x, obs_x)

    def h2(x, obs_x):
        # Same construction one level up: d/dt h1 plus 50 * h1.
        x_dot = crazyflie_f_(x)
        obs_x_dot = jax.vmap(crazyflie_f_)(obs_x)
        h1_x = jax.jacfwd(h1, argnums=0)(x, obs_x)  # (k, nx)
        h1_obs_x = jax.jacfwd(h1, argnums=1)(x, obs_x)  # (k, n_candidates, nx)
        h1_dot = ei.einsum(h1_x, x_dot, 'k nx, nx -> k') + \
            ei.einsum(h1_obs_x, obs_x_dot, 'k1 k2 nx, k2 nx -> k1')  # (k,)
        return h1_dot + 50.0 * h1(x, obs_x)

    k_h2 = h2(state, all_obs_state)
    assert k_h2.shape == (k,)

    # Indices past the agent block belong to o_obs_state.
    k_isobs = k_idx >= n_agent

    return k_h2, k_isobs
288 |
289 |
def pwise_cbf_crazyflie(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int):
    """Evaluate ``pwise_cbf_crazyflie_`` for every agent in ``graph``.

    Reads the type-0 states (one per agent) and the type-2 states
    (``n_rays`` per agent; presumably the ray hit points, confirm against
    the environment) and returns the per-agent ``(values, is_obstacle)`` pair.
    """
    agent_states = graph.type_states(type_idx=0, n_type=n_agent)
    ray_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays)
    # Group the flat ray list by owning agent.
    per_agent_rays = ei.rearrange(ray_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent)

    # Map over agents; the full agent state array is shared by all of them.
    per_agent_cbf = jax.vmap(
        ft.partial(pwise_cbf_crazyflie_, r=r, k=k),
        in_axes=(0, 0, 0, None),
    )
    return per_agent_cbf(agent_states, jnp.arange(n_agent), per_agent_rays, agent_states)
301 |
302 |
def pwise_cbf_linear_drone_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int):
    """Barrier values of one agent w.r.t. its ``k`` nearest neighbours.

    State entries 0-2 are used as position and entries 3-5 as velocity.
    Candidates are all agents (``a_state``) followed by ``o_obs_state``.

    Returns:
        ``(k_h1, k_isobs)`` with ``k_h1 = d/dt(h0) + 3 * h0`` where
        ``h0 = squared distance - 4 * (1.01 * r) ** 2``, and a flag telling
        whether each selected candidate comes from ``o_obs_state``.
    """
    num_agents = len(a_state)

    candidates = jnp.concatenate([a_state, o_obs_state], axis=0)
    sq_dist = jnp.sum((state[:3] - candidates[:, :3]) ** 2, axis=-1)
    # Overwrite the agent's own entry so it is not picked as its own neighbour.
    sq_dist = sq_dist.at[agent_idx].set(1e2)

    k_idx = jnp.argsort(sq_dist)[:k]
    # Two radii, inflated by 1% as slack for QP solver error.
    k_h0 = sq_dist[k_idx] - 4 * (1.01 * r) ** 2
    assert k_h0.shape == (k,)

    # Relative position and velocity to the selected candidates.
    nearest = candidates[k_idx]
    rel_pos = state[:3] - nearest[:, :3]
    rel_vel = state[3:6] - nearest[:, 3:6]
    assert rel_pos.shape == rel_vel.shape == (k, 3)

    # Time derivative of the squared distance: 2 * <dx, dv>.
    k_h0_dot = 2 * jnp.sum(rel_pos * rel_vel, axis=-1)
    assert k_h0_dot.shape == (k,)

    k_h1 = k_h0_dot + 3.0 * k_h0
    k_isobs = k_idx >= num_agents
    return k_h1, k_isobs
337 |
338 |
def pwise_cbf_linear_drone(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int):
    """Evaluate ``pwise_cbf_linear_drone_`` for every agent in ``graph``.

    Reads the type-0 states (one per agent) and the type-2 states
    (``n_rays`` per agent; presumably the ray hit points, confirm against
    the environment) and returns the per-agent ``(values, is_obstacle)`` pair.
    """
    agent_states = graph.type_states(type_idx=0, n_type=n_agent)
    ray_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays)
    # Group the flat ray list by owning agent.
    per_agent_rays = ei.rearrange(ray_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent)

    # Map over agents; the full agent state array is shared by all of them.
    per_agent_cbf = jax.vmap(
        ft.partial(pwise_cbf_linear_drone_, r=r, k=k),
        in_axes=(0, 0, 0, None),
    )
    return per_agent_cbf(agent_states, jnp.arange(n_agent), per_agent_rays, agent_states)
350 |
351 |
def pwise_cbf_cfhl_(state: Array, agent_idx: int, o_obs_state: Array, a_state: Array, r: float, k: int):
    """Barrier values of one agent w.r.t. its ``k`` nearest neighbours.

    The first three state entries are used as position; velocities are the
    ``CrazyFlie.U/V/W`` entries rotated by ``CrazyFlie.get_rotation_mat``.
    Candidates are all agents (``a_state``) followed by ``o_obs_state``.

    Returns:
        ``(k_h1, k_isobs)`` with ``k_h1 = d/dt(h0) + 3 * h0`` where
        ``h0 = squared distance - 4 * (1.01 * r) ** 2``, and a flag telling
        whether each selected candidate comes from ``o_obs_state``.
    """
    num_agents = len(a_state)

    candidates = jnp.concatenate([a_state, o_obs_state], axis=0)
    sq_dist = jnp.sum((state[:3] - candidates[:, :3]) ** 2, axis=-1)
    # Overwrite the agent's own entry so it is not picked as its own neighbour.
    sq_dist = sq_dist.at[agent_idx].set(1e2)

    k_idx = jnp.argsort(sq_dist)[:k]
    # Two radii, inflated by 1% as slack for QP solver error.
    k_h0 = sq_dist[k_idx] - 4 * (1.01 * r) ** 2
    assert k_h0.shape == (k,)

    nearest = candidates[k_idx]
    rel_pos = state[:3] - nearest[:, :3]

    def rotated_velocity(x):
        # (u, v, w) entries of the state rotated by the state's rotation matrix.
        body_vel = jnp.array([x[CrazyFlie.U], x[CrazyFlie.V], x[CrazyFlie.W]])
        return CrazyFlie.get_rotation_mat(x) @ body_vel

    rel_vel = rotated_velocity(state) - jax.vmap(rotated_velocity)(nearest)

    assert rel_pos.shape == rel_vel.shape == (k, 3)

    # Time derivative of the squared distance: 2 * <dx, dv>.
    k_h0_dot = 2 * jnp.sum(rel_pos * rel_vel, axis=-1)
    assert k_h0_dot.shape == (k,)

    k_h1 = k_h0_dot + 3 * k_h0
    k_isobs = k_idx >= num_agents
    return k_h1, k_isobs
398 |
399 |
def pwise_cbf_cfhl(graph: GraphsTuple, r: float, n_agent: int, n_rays: int, k: int):
    """Evaluate ``pwise_cbf_cfhl_`` for every agent of the graph.

    Parameters
    ----------
    graph: GraphsTuple,
        graph whose type-0 nodes are agents and type-2 nodes are obstacle (lidar) nodes
    r: float,
        agent radius
    n_agent: int,
        number of agents in the graph
    n_rays: int,
        number of obstacle nodes owned by each agent
    k: int,
        number of closest neighbours considered per agent

    Returns
    -------
    ak_h, ak_isobs:
        per-agent stacks of the CBF values and of the "neighbour is an obstacle" flags
    """
    # One row per agent node.
    agent_states = graph.type_states(type_idx=0, n_type=n_agent)
    # One row per obstacle node, regrouped so that axis 0 indexes the owning agent.
    lidar_states = graph.type_states(type_idx=2, n_type=n_agent * n_rays)
    per_agent_lidar = ei.rearrange(lidar_states, "(n_agent n_ray) d -> n_agent n_ray d", n_agent=n_agent)

    # vmap over (own state, own index, own rays); all agent states are broadcast.
    single_agent_fn = ft.partial(pwise_cbf_cfhl_, r=r, k=k)
    batched_fn = jax.vmap(single_agent_fn, in_axes=(0, 0, 0, None))
    ak_h, ak_isobs = batched_fn(agent_states, jnp.arange(n_agent), per_agent_lidar, agent_states)
    return ak_h, ak_isobs
411 |
412 |
def get_pwise_cbf_fn(env: MultiAgentEnv, k: int = 3):
    """Return the hand-crafted pairwise CBF function matching the environment type.

    Parameters
    ----------
    env: MultiAgentEnv,
        environment instance; its type selects the CBF implementation and its
        parameters provide the agent radius, agent count and ray count
    k: int,
        number of closest neighbours each agent considers

    Returns
    -------
    A partial of the environment-specific ``pwise_cbf_*`` function that only
    needs the graph as argument.

    Raises
    ------
    NotImplementedError
        if no pairwise CBF exists for the type of ``env``
    """
    if isinstance(env, SingleIntegrator):
        n_agent = env.num_agents
        n_rays = env.params["n_rays"]
        r = env.params["car_radius"]
        return ft.partial(pwise_cbf_single_integrator, r=r, n_agent=n_agent, n_rays=n_rays, k=k)
    elif isinstance(env, DoubleIntegrator):
        n_agent = env.num_agents
        n_rays = env.params["n_rays"]
        r = env.params["car_radius"]
        return ft.partial(pwise_cbf_double_integrator, r=r, n_agent=n_agent, n_rays=n_rays, k=k)
    elif isinstance(env, DubinsCar):
        r = env.params["car_radius"]
        n_agent = env.num_agents
        n_rays = env.params["n_rays"]
        return ft.partial(pwise_cbf_dubins_car, r=r, n_agent=n_agent, n_rays=n_rays, k=k)
    elif isinstance(env, CrazyFlie):
        # The 3D environments expose the number of obstacle nodes per agent as
        # ``env.n_rays`` rather than through ``env.params``.
        return ft.partial(pwise_cbf_crazyflie, r=env.params["drone_radius"], n_agent=env.num_agents,
                          n_rays=env.n_rays, k=k)
    elif isinstance(env, LinearDrone):
        return ft.partial(pwise_cbf_linear_drone, r=env.params["drone_radius"], n_agent=env.num_agents,
                          n_rays=env.n_rays, k=k)
    # NOTE: a second ``isinstance(env, CrazyFlie)`` branch returning ``pwise_cbf_cfhl``
    # used to follow here. It could never be reached because the CrazyFlie check above
    # always matches first, so it has been dropped without changing behaviour.

    raise NotImplementedError(f"No pairwise CBF is implemented for environment type {type(env).__name__}.")
440 |
--------------------------------------------------------------------------------
/gcbfplus/env/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from .base import MultiAgentEnv
4 | from .single_integrator import SingleIntegrator
5 | from .double_integrator import DoubleIntegrator
6 | from .linear_drone import LinearDrone
7 | from .dubins_car import DubinsCar
8 | from .crazyflie import CrazyFlie
9 |
10 |
# Registry mapping the environment id accepted by ``make_env`` to its environment class.
ENV = {
    'SingleIntegrator': SingleIntegrator,
    'DoubleIntegrator': DoubleIntegrator,
    'LinearDrone': LinearDrone,
    'DubinsCar': DubinsCar,
    'CrazyFlie': CrazyFlie,
}


# Episode length used by ``make_env`` when ``max_step`` is not given.
DEFAULT_MAX_STEP = 256
21 |
22 |
def make_env(
        env_id: str,
        num_agents: int,
        area_size: Optional[float] = None,
        max_step: Optional[int] = None,
        max_travel: Optional[float] = None,
        num_obs: Optional[int] = None,
        n_rays: Optional[int] = None,
) -> MultiAgentEnv:
    """Build one of the registered environments.

    Parameters
    ----------
    env_id: str,
        key of the environment in ``ENV``
    num_agents: int,
        number of agents
    area_size: float,
        forwarded to the environment constructor
    max_step: int,
        episode length; ``DEFAULT_MAX_STEP`` when omitted
    max_travel: float,
        forwarded to the environment constructor
    num_obs: int,
        if given, overrides the ``n_obs`` entry of the environment parameters
    n_rays: int,
        if given, overrides the ``n_rays`` entry of the environment parameters

    Returns
    -------
    env: MultiAgentEnv,
        the constructed environment (always created with ``dt=0.03``)
    """
    assert env_id in ENV, f'Environment {env_id} not implemented.'
    # Work on a copy: writing the overrides straight into the class-level PARAMS
    # dict would silently change the defaults of every environment of this class
    # created afterwards.
    params = dict(ENV[env_id].PARAMS)
    max_step = DEFAULT_MAX_STEP if max_step is None else max_step
    if num_obs is not None:
        params['n_obs'] = num_obs
    if n_rays is not None:
        params['n_rays'] = n_rays
    return ENV[env_id](
        num_agents=num_agents,
        area_size=area_size,
        max_step=max_step,
        max_travel=max_travel,
        dt=0.03,
        params=params
    )
47 |
--------------------------------------------------------------------------------
/gcbfplus/env/base.py:
--------------------------------------------------------------------------------
1 | import functools as ft
2 | import numpy as np
3 | import pathlib
4 | import jax
5 | import jax.lax as lax
6 | import jax.numpy as jnp
7 | import tqdm
8 |
9 | from abc import ABC, abstractmethod, abstractproperty
10 | from typing import Callable, NamedTuple, Optional, Tuple
11 |
12 | from ..utils.graph import GraphsTuple
13 | from ..utils.typing import Action, Array, Cost, Done, Info, PRNGKey, Reward, State
14 | from ..utils.utils import jax2np, jax_jit_np, tree_concat_at_front, tree_stack
15 |
16 |
class StepResult(NamedTuple):
    """Bundle of values produced by one environment step."""
    graph: GraphsTuple  # graph after the step
    reward: Reward  # reward of the step
    cost: Cost  # cost of the step
    done: Done  # episode termination flag
    info: Info  # extra evaluation information
23 |
24 |
class RolloutResult(NamedTuple):
    """Stacked outputs of a rollout of T steps."""
    Tp1_graph: GraphsTuple  # initial graph followed by the T graphs produced by the steps
    T_action: Action  # action applied at each step
    T_reward: Reward  # reward of each step
    T_cost: Cost  # cost of each step
    T_done: Done  # done flag of each step
    T_info: Info  # info of each step
32 |
33 |
class MultiAgentEnv(ABC):
    """Abstract base class of the multi-agent environments.

    It stores the configuration shared by all environments (agent count, time
    step, area size, episode length, parameter dictionary), declares the
    interface every concrete environment has to implement, and provides two
    rollout helpers built on ``reset`` / ``step``.
    """

    # Default parameters; used when ``params`` is not passed to ``__init__``.
    PARAMS = {}

    def __init__(
            self,
            num_agents: int,
            area_size: float,
            max_step: int = 256,
            max_travel: Optional[float] = None,
            dt: float = 0.03,
            params: Optional[dict] = None
    ):
        super(MultiAgentEnv, self).__init__()
        self._num_agents = num_agents
        self._dt = dt
        if params is None:
            # NOTE: this is the class-level dictionary itself, not a copy.
            params = self.PARAMS
        self._params = params
        self._t = 0
        self._max_step = max_step
        self._max_travel = max_travel
        self._area_size = area_size

    @property
    def params(self) -> dict:
        """Parameter dictionary of this environment."""
        return self._params

    @property
    def num_agents(self) -> int:
        """Number of agents."""
        return self._num_agents

    @property
    def max_travel(self) -> float:
        """The ``max_travel`` value given at construction (may be None)."""
        return self._max_travel

    @property
    def area_size(self) -> float:
        """The ``area_size`` value given at construction."""
        return self._area_size

    @property
    def dt(self) -> float:
        """Simulation time step."""
        return self._dt

    @property
    def max_episode_steps(self) -> int:
        """Default number of steps of a rollout."""
        return self._max_step

    def clip_state(self, state: State) -> State:
        """Clip ``state`` element-wise to the limits returned by ``state_lim``."""
        lower_limit, upper_limit = self.state_lim(state)
        return jnp.clip(state, lower_limit, upper_limit)

    def clip_action(self, action: Action) -> Action:
        """Clip ``action`` element-wise to the limits returned by ``action_lim``."""
        lower_limit, upper_limit = self.action_lim()
        return jnp.clip(action, lower_limit, upper_limit)

    # ``abc.abstractproperty`` is deprecated; stacking ``property`` on top of
    # ``abstractmethod`` is the documented replacement and behaves the same.
    @property
    @abstractmethod
    def state_dim(self) -> int:
        pass

    @property
    @abstractmethod
    def node_dim(self) -> int:
        pass

    @property
    @abstractmethod
    def edge_dim(self) -> int:
        pass

    @property
    @abstractmethod
    def action_dim(self) -> int:
        pass

    @abstractmethod
    def reset(self, key: Array) -> GraphsTuple:
        pass

    def reset_np(self, key: Array) -> GraphsTuple:
        """Reset, but without the constraint that it has to be jittable."""
        return self.reset(key)

    @abstractmethod
    def step(self, graph: GraphsTuple, action: Action, get_eval_info: bool = False) -> StepResult:
        pass

    @abstractmethod
    def state_lim(self, state: Optional[State] = None) -> Tuple[State, State]:
        """
        Returns
        -------
        lower_limit, upper_limit: Tuple[State, State],
            limits of the state
        """
        pass

    @abstractmethod
    def action_lim(self) -> Tuple[Action, Action]:
        """
        Returns
        -------
        lower_limit, upper_limit: Tuple[Action, Action],
            limits of the action
        """
        pass

    @abstractmethod
    def control_affine_dyn(self, state: State) -> Tuple[Array, Array]:
        pass

    @abstractmethod
    def add_edge_feats(self, graph: GraphsTuple, state: State) -> GraphsTuple:
        pass

    @abstractmethod
    def get_graph(self, state: State) -> GraphsTuple:
        pass

    @abstractmethod
    def u_ref(self, graph: GraphsTuple) -> Action:
        pass

    @abstractmethod
    def forward_graph(self, graph: GraphsTuple, action: Action) -> GraphsTuple:
        pass

    @abstractmethod
    @ft.partial(jax.jit, static_argnums=(0,))
    def safe_mask(self, graph: GraphsTuple) -> Array:
        pass

    @abstractmethod
    @ft.partial(jax.jit, static_argnums=(0,))
    def unsafe_mask(self, graph: GraphsTuple) -> Array:
        pass

    @abstractmethod
    def collision_mask(self, graph: GraphsTuple) -> Array:
        pass

    def rollout_fn(self, policy: Callable, rollout_length: int = None) -> Callable[[PRNGKey], RolloutResult]:
        """Build a function that rolls out ``policy`` with a single ``lax.scan``.

        Parameters
        ----------
        policy: Callable,
            maps a graph to an action
        rollout_length: int,
            number of steps; ``max_episode_steps`` when None (or 0)

        Returns
        -------
        fn: Callable[[PRNGKey], RolloutResult],
            resets the environment with the given key and returns the rollout
        """
        rollout_length = rollout_length or self.max_episode_steps

        def body(graph, _):
            action = policy(graph)
            graph_new, reward, cost, done, info = self.step(graph, action, get_eval_info=True)
            return graph_new, (graph_new, action, reward, cost, done, info)

        def fn(key: PRNGKey) -> RolloutResult:
            graph0 = self.reset(key)
            graph_final, (T_graph, T_action, T_reward, T_cost, T_done, T_info) = lax.scan(
                body, graph0, None, length=rollout_length
            )
            # Put the initial graph in front of the T graphs produced by the scan.
            Tp1_graph = tree_concat_at_front(graph0, T_graph, axis=0)

            return RolloutResult(Tp1_graph, T_action, T_reward, T_cost, T_done, T_info)

        return fn

    def rollout_fn_jitstep(
            self, policy: Callable, rollout_length: int = None, noedge: bool = False, nograph: bool = False
    ):
        """Build a rollout function that jits a single step and loops in Python.

        Unlike ``rollout_fn`` the environment is reset with ``reset_np``, and the
        collision / finish masks are recorded for the initial graph and after
        every step.

        Parameters
        ----------
        policy: Callable,
            maps a graph to an action
        rollout_length: int,
            number of steps; ``max_episode_steps`` when None (or 0)
        noedge: bool,
            if True, the stored graphs are stripped with ``without_edge()``
        nograph: bool,
            if True, no graphs are stored and ``Tp1_graph`` of the result is None

        Returns
        -------
        fn: Callable,
            maps a PRNG key to ``(rollout_result, is_unsafes, is_finishes)``, the
            latter two stacked along a leading axis of length ``rollout_length + 1``
        """
        rollout_length = rollout_length or self.max_episode_steps

        def body(graph, _):
            action = policy(graph)
            graph_new, reward, cost, done, info = self.step(graph, action, get_eval_info=True)
            return graph_new, (graph_new, action, reward, cost, done, info)

        jit_body = jax.jit(body)

        is_unsafe_fn = jax_jit_np(self.collision_mask)
        is_finish_fn = jax_jit_np(self.finish_mask)

        def fn(key: PRNGKey) -> Tuple[RolloutResult, Array, Array]:
            graph0 = self.reset_np(key)
            graph = graph0
            T_output = []
            is_unsafes = []
            is_finishes = []

            is_unsafes.append(is_unsafe_fn(graph0))
            is_finishes.append(is_finish_fn(graph0))
            # Only the stored copy is converted; ``graph`` keeps feeding the jitted step.
            graph0 = jax2np(graph0)

            for _ in tqdm.trange(rollout_length, ncols=80):
                graph, output = jit_body(graph, None)

                is_unsafes.append(is_unsafe_fn(graph))
                is_finishes.append(is_finish_fn(graph))

                output = jax2np(output)
                if noedge:
                    output = (output[0].without_edge(), *output[1:])
                if nograph:
                    output = (None, *output[1:])
                T_output.append(output)

            # Concatenate everything together.
            T_graph = [o[0] for o in T_output]
            if noedge:
                T_graph = [graph0.without_edge()] + T_graph
            else:
                T_graph = [graph0] + T_graph
            del graph0
            T_action = [o[1] for o in T_output]
            T_reward = [o[2] for o in T_output]
            T_cost = [o[3] for o in T_output]
            T_done = [o[4] for o in T_output]
            T_info = [o[5] for o in T_output]
            del T_output

            if nograph:
                T_graph = None
            else:
                T_graph = tree_stack(T_graph)
            T_action = tree_stack(T_action)
            T_reward = tree_stack(T_reward)
            T_cost = tree_stack(T_cost)
            T_done = tree_stack(T_done)
            T_info = tree_stack(T_info)

            Tp1_graph = T_graph

            rollout_result = jax2np(RolloutResult(Tp1_graph, T_action, T_reward, T_cost, T_done, T_info))
            return rollout_result, np.stack(is_unsafes, axis=0), np.stack(is_finishes, axis=0)

        return fn

    @abstractmethod
    def render_video(
            self, rollout: RolloutResult, video_path: pathlib.Path, Ta_is_unsafe=None, viz_opts: dict = None, **kwargs
    ) -> None:
        pass

    @abstractmethod
    def finish_mask(self, graph: GraphsTuple) -> Array:
        pass
270 |
--------------------------------------------------------------------------------
/gcbfplus/env/linear_drone.py:
--------------------------------------------------------------------------------
1 | import functools as ft
2 | import pathlib
3 | import jax
4 | import jax.numpy as jnp
5 | import jax.random as jr
6 | import numpy as np
7 | import scipy
8 |
9 | from typing import NamedTuple, Tuple, Optional
10 |
11 | from ..utils.graph import EdgeBlock, GetGraph, GraphsTuple
12 | from ..utils.typing import Action, AgentState, Array, Cost, Done, Info, Pos3d, Reward, State
13 | from ..utils.utils import merge01
14 | from .base import MultiAgentEnv, RolloutResult
15 | from .obstacle import Obstacle, Sphere
16 | from .plot import render_video
17 | from .utils import get_lidar, inside_obstacles, lqr, get_node_goal_rng
18 |
19 |
class LinearDrone(MultiAgentEnv):
    """Multi-agent 3D environment with linear drone dynamics and spherical obstacles.

    Every agent has the state (x, y, z, vx, vy, vz) and is driven by a
    3-dimensional action through ``x_dot = A x + B u``. The graph contains one
    node per agent, one per goal, and ``n_rays`` obstacle (lidar hit) nodes per
    agent.
    """

    # Node type indices stored in the ``node_type`` array of the graph.
    AGENT = 0
    GOAL = 1
    OBS = 2

    class EnvState(NamedTuple):
        """Raw environment state attached to every graph."""
        agent: AgentState
        goal: State
        obstacle: Obstacle

        @property
        def n_agent(self) -> int:
            return self.agent.shape[0]

    EnvGraphsTuple = GraphsTuple[State, EnvState]

    PARAMS = {
        "drone_radius": 0.05,
        "comm_radius": 0.5,
        "n_rays": 32,
        "obs_len_range": [0.15, 0.3],
        "n_obs": 4
    }

    def __init__(
            self,
            num_agents: int,
            area_size: float,
            max_step: int = 256,
            max_travel: float = None,
            dt: float = 0.03,
            params: dict = None
    ):
        super(LinearDrone, self).__init__(num_agents, area_size, max_step, max_travel, dt, params)

        # Continuous-time dynamics x_dot = A x + B u: the position integrates the
        # velocity, the velocity is damped and driven by the action.
        self._A = np.zeros((self.state_dim, self.state_dim))
        self._A[0, 3] = 1.
        self._A[1, 4] = 1.
        self._A[2, 5] = 1.
        self._A[3, 3] = -1.1
        self._A[4, 4] = -1.1
        self._A[5, 5] = -6.
        A_discrete = scipy.linalg.expm(self._A * self._dt)

        self._B = np.zeros((self.state_dim, self.action_dim))
        self._B[3, 0] = 10.
        self._B[4, 1] = 10.
        self._B[5, 2] = 10.  # 6.0

        # Gain of the reference (LQR) controller used by ``u_ref``.
        self._Q = np.diag([5e1, 5e1, 5e1, 1., 1., 1.])
        self._R = np.eye(self.action_dim)
        self._K = lqr(A_discrete, self._B, self._Q, self._R)
        self.create_obstacles = jax.vmap(Sphere.create)
        # Number of obstacle nodes kept per agent (``max_returns`` of the lidar);
        # ``params["n_rays"]`` is the number of beams that are cast.
        self.n_rays = 16  # consider top k rays

    @property
    def state_dim(self) -> int:
        return 6  # x, y, z, vx, vy, vz

    @property
    def node_dim(self) -> int:
        return 3  # indicator: agent: 001, goal: 010, obstacle: 100

    @property
    def edge_dim(self) -> int:
        return 6  # x_rel, y_rel, z_rel, vx_rel, vy_rel, vz_rel

    @property
    def action_dim(self) -> int:
        return 3  # ax, ay, az

    def reset(self, key: Array) -> GraphsTuple:
        """Sample random obstacles, agents and goals and return the initial graph."""
        self._t = 0

        # randomly generate obstacles
        n_rng_obs = self._params["n_obs"]
        assert n_rng_obs >= 0
        obstacle_key, key = jr.split(key, 2)
        obs_pos = jr.uniform(obstacle_key, (n_rng_obs, 3), minval=0, maxval=self.area_size)

        r_key, key = jr.split(key, 2)
        # ``obs_len_range`` is halved to obtain the range of the sphere radius.
        obs_radius = jr.uniform(r_key, (n_rng_obs,),
                                minval=self._params["obs_len_range"][0] / 2,
                                maxval=self._params["obs_len_range"][1] / 2)
        obstacles = self.create_obstacles(obs_pos, obs_radius)

        # randomly generate agent and goal
        states, goals = get_node_goal_rng(
            key, self.area_size, 3, obstacles, self.num_agents, 4 * self.params["drone_radius"], self.max_travel)

        # add zero velocity
        states = jnp.concatenate([states, jnp.zeros((self.num_agents, self.state_dim - 3))], axis=1)
        goals = jnp.concatenate([goals, jnp.zeros((self.num_agents, self.state_dim - 3))], axis=1)

        env_states = self.EnvState(states, goals, obstacles)

        return self.get_graph(env_states)

    def clip_action(self, action: Action) -> Action:
        """Clip ``action`` element-wise to the limits returned by ``action_lim``."""
        lower_limit, upper_limit = self.action_lim()
        return jnp.clip(action, lower_limit, upper_limit)

    def agent_step_euler(self, agent_state: Array, action: Array) -> Array:
        """Advance all agents by one explicit Euler step and clip the result to the state limits."""
        assert action.shape == (self.num_agents, self.action_dim)
        assert agent_state.shape == (self.num_agents, self.state_dim)
        x_dot = self.agent_xdot(agent_state, action)
        n_state_agent_new = agent_state + x_dot * self.dt
        assert n_state_agent_new.shape == (self.num_agents, self.state_dim)
        return self.clip_state(n_state_agent_new)

    def agent_xdot(self, agent_states: AgentState, action: Action) -> AgentState:
        """Time derivative ``A x + B u`` of all agent states (row-wise)."""
        assert action.shape == (self.num_agents, self.action_dim)
        assert agent_states.shape == (self.num_agents, self.state_dim)

        return jnp.matmul(agent_states, self._A.T) + jnp.matmul(action, self._B.T)

    def step(
            self, graph: EnvGraphsTuple, action: Action, get_eval_info: bool = False
    ) -> Tuple[EnvGraphsTuple, Reward, Cost, Done, Info]:
        """Apply ``action`` (after clipping) and return ``(next_graph, reward, cost, done, info)``.

        The reward is minus the mean squared deviation of the clipped action from
        ``u_ref``; the cost is ``get_cost`` of the graph before the step; ``done``
        is always False. With ``get_eval_info`` the info dictionary holds the
        ``inside_obstacles`` mask of the agents before the step.
        """
        self._t += 1

        # calculate next graph
        agent_states = graph.type_states(type_idx=0, n_type=self.num_agents)
        goal_states = graph.type_states(type_idx=1, n_type=self.num_agents)
        obstacles = graph.env_states.obstacle
        action = self.clip_action(action)
        assert action.shape == (self.num_agents, self.action_dim)
        assert agent_states.shape == (self.num_agents, self.state_dim)
        next_agent_states = self.agent_step_euler(agent_states, action)

        # the episode ends when reaching max_episode_steps
        done = jnp.array(False)

        # compute reward and cost
        reward = jnp.zeros(()).astype(jnp.float32)
        reward -= (jnp.linalg.norm(action - self.u_ref(graph), axis=1) ** 2).mean()
        cost = self.get_cost(graph)

        assert reward.shape == tuple()
        assert cost.shape == tuple()
        assert done.shape == tuple()

        next_state = self.EnvState(next_agent_states, goal_states, obstacles)

        info = {}
        if get_eval_info:
            # collision between agents and obstacles
            agent_pos = agent_states[:, :3]
            info["inside_obstacles"] = inside_obstacles(agent_pos, obstacles, r=self._params["drone_radius"])

        return self.get_graph(next_state), reward, cost, done, info

    def get_cost(self, graph: EnvGraphsTuple) -> Cost:
        """Fraction of agents colliding with another agent plus fraction colliding with an obstacle."""
        agent_states = graph.type_states(type_idx=0, n_type=self.num_agents)
        obstacles = graph.env_states.obstacle

        # collision between agents
        agent_pos = agent_states[:, :3]
        dist = jnp.linalg.norm(jnp.expand_dims(agent_pos, 1) - jnp.expand_dims(agent_pos, 0), axis=-1)
        # A large diagonal keeps an agent from colliding with itself.
        dist += jnp.eye(self.num_agents) * 1e6
        collision = (self._params["drone_radius"] * 2 > dist).any(axis=1)
        cost = collision.mean()

        # collision between agents and obstacles
        collision = inside_obstacles(agent_pos, obstacles, r=self._params["drone_radius"])
        cost += collision.mean()

        return cost

    def render_video(
            self,
            rollout: RolloutResult,
            video_path: pathlib.Path,
            Ta_is_unsafe=None,
            viz_opts: dict = None,
            dpi: int = 100,
            **kwargs
    ) -> None:
        """Render ``rollout`` to ``video_path`` with the shared 3D plotting helper."""
        render_video(
            rollout=rollout,
            video_path=video_path,
            side_length=self.area_size,
            dim=3,
            n_agent=self.num_agents,
            n_rays=self.n_rays,
            r=self.params["drone_radius"],
            Ta_is_unsafe=Ta_is_unsafe,
            viz_opts=viz_opts,
            dpi=dpi,
            **kwargs
        )

    def edge_blocks(self, state: EnvState, lidar_data: Pos3d) -> list[EdgeBlock]:
        """Build the agent-agent, agent-goal and agent-obstacle edge blocks of the graph.

        Parameters
        ----------
        state: EnvState,
            current environment state
        lidar_data: Array,
            one row per obstacle node (``n_rays`` consecutive rows per agent), with
            the same number of columns as an agent state

        Returns
        -------
        blocks: list[EdgeBlock],
            the agent-agent block, the agent-goal block, then one agent-obstacle
            block per agent
        """
        n_hits = self.num_agents * self.n_rays

        # agent - agent connection
        agent_pos = state.agent[:, :3]
        pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :]  # [i, j]: i -> j
        dist = jnp.linalg.norm(pos_diff, axis=-1)
        # Push the diagonal beyond the communication radius to drop self connections.
        dist += jnp.eye(dist.shape[1]) * (self._params["comm_radius"] + 1)
        state_diff = state.agent[:, None, :] - state.agent[None, :, :]
        agent_agent_mask = jnp.less(dist, self._params["comm_radius"])
        id_agent = jnp.arange(self.num_agents)
        agent_agent_edges = EdgeBlock(state_diff, agent_agent_mask, id_agent, id_agent)

        # agent - goal connection, clipped to avoid too long edges
        id_goal = jnp.arange(self.num_agents, self.num_agents * 2)
        agent_goal_mask = jnp.eye(self.num_agents)
        agent_goal_feats = state.agent[:, None, :] - state.goal[None, :, :]
        # ``agent_goal_feats`` has the layout (agent, goal, feature), so the position
        # components have to be selected on the LAST axis. Slicing with ``[:, :3]``
        # would pick the first three goals instead, leaving the goal edge of every
        # agent with index >= 3 unclipped and disagreeing with ``add_edge_feats``.
        feats_norm = jnp.sqrt(1e-6 + jnp.sum(agent_goal_feats[..., :3] ** 2, axis=-1, keepdims=True))
        comm_radius = self._params["comm_radius"]
        # ``safe_feats_norm`` is never below comm_radius, so the division is well defined.
        safe_feats_norm = jnp.maximum(feats_norm, comm_radius)
        coef = jnp.where(feats_norm > comm_radius, comm_radius / safe_feats_norm, 1.0)
        agent_goal_feats = agent_goal_feats.at[..., :3].set(agent_goal_feats[..., :3] * coef)
        agent_goal_edges = EdgeBlock(
            agent_goal_feats, agent_goal_mask, id_agent, id_goal
        )

        # agent - obs connection
        id_obs = jnp.arange(self.num_agents * 2, self.num_agents * 2 + n_hits)
        agent_obs_edges = []
        for i in range(self.num_agents):
            # Rows of ``lidar_data`` that belong to agent i.
            id_hits = jnp.arange(i * self.n_rays, (i + 1) * self.n_rays)
            lidar_pos = agent_pos[i, :] - lidar_data[id_hits, :3]
            lidar_feats = state.agent[i, :] - lidar_data[id_hits, :]
            lidar_dist = jnp.linalg.norm(lidar_pos, axis=-1)
            # Only hits closer than the communication radius minus a 0.1 margin are connected.
            active_lidar = jnp.less(lidar_dist, self._params["comm_radius"] - 1e-1)
            agent_obs_mask = jnp.ones((1, self.n_rays))
            agent_obs_mask = jnp.logical_and(agent_obs_mask, active_lidar)
            agent_obs_edges.append(
                EdgeBlock(lidar_feats[None, :, :], agent_obs_mask, id_agent[i][None], id_obs[id_hits])
            )

        return [agent_agent_edges, agent_goal_edges] + agent_obs_edges

    def control_affine_dyn(self, state: State) -> Tuple[Array, Array]:
        """Return ``(f, g)`` of the control-affine form ``x_dot = f(x) + g(x) u`` for each row of ``state``."""
        assert state.ndim == 2
        f = jnp.matmul(state, self._A.T)
        g = self._B
        # ``g`` does not depend on the state; repeat it once per row.
        g = jnp.expand_dims(g, axis=0).repeat(f.shape[0], axis=0)
        assert f.shape == state.shape
        assert g.shape == (state.shape[0], self.state_dim, self.action_dim)
        return f, g

    def add_edge_feats(self, graph: GraphsTuple, state: State) -> GraphsTuple:
        """Recompute the edge features of ``graph`` from the node states ``state``.

        The positional part of every edge is clipped to length ``comm_radius``;
        the returned graph carries the new edges and ``state`` as its states.
        """
        assert graph.is_single
        assert state.ndim == 2

        edge_feats = state[graph.receivers] - state[graph.senders]
        feats_norm = jnp.sqrt(1e-6 + jnp.sum(edge_feats[:, :3] ** 2, axis=-1, keepdims=True))
        comm_radius = self._params["comm_radius"]
        safe_feats_norm = jnp.maximum(feats_norm, comm_radius)
        coef = jnp.where(feats_norm > comm_radius, comm_radius / safe_feats_norm, 1.0)
        edge_feats = edge_feats.at[:, :3].set(edge_feats[:, :3] * coef)

        return graph._replace(edges=edge_feats, states=state)

    def get_graph(self, state: EnvState, adjacency: Array = None) -> GraphsTuple:
        """Build the padded graph of an environment state.

        ``adjacency`` is accepted but not used.
        """
        # node features
        n_hits = self.n_rays * self.num_agents
        n_nodes = 2 * self.num_agents + n_hits
        node_feats = jnp.zeros((self.num_agents * 2 + n_hits, 3))
        node_feats = node_feats.at[: self.num_agents, 2].set(1)  # agent feats
        node_feats = node_feats.at[self.num_agents: self.num_agents * 2, 1].set(1)  # goal feats
        node_feats = node_feats.at[-n_hits:, 0].set(1)  # obs feats

        # Agents keep the initial value 0 (== LinearDrone.AGENT).
        node_type = jnp.zeros(n_nodes, dtype=jnp.int32)
        node_type = node_type.at[self.num_agents: self.num_agents * 2].set(LinearDrone.GOAL)
        node_type = node_type.at[-n_hits:].set(LinearDrone.OBS)

        get_lidar_vmap = jax.vmap(
            ft.partial(
                get_lidar,
                obstacles=state.obstacle,
                num_beams=self.params['n_rays'],
                sense_range=self._params["comm_radius"],
                max_returns=self.n_rays,
            )
        )
        lidar_data = merge01(get_lidar_vmap(state.agent[:, :3]))
        # Pad the lidar rows with zeros so they have as many columns as an agent state.
        lidar_data = jnp.concatenate([lidar_data, jnp.zeros_like(lidar_data)], axis=-1)
        edge_blocks = self.edge_blocks(state, lidar_data)

        # create graph
        return GetGraph(
            nodes=node_feats,
            node_type=node_type,
            edge_blocks=edge_blocks,
            env_states=state,
            states=jnp.concatenate([state.agent, state.goal, lidar_data], axis=0),
        ).to_padded()

    def state_lim(self, state: Optional[State] = None) -> Tuple[State, State]:
        """Return the state limits: positions are unbounded, velocities lie in [-0.5, 0.5]."""
        low_lim = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf, -0.5, -0.5, -0.5])
        up_lim = jnp.array([jnp.inf, jnp.inf, jnp.inf, 0.5, 0.5, 0.5])
        return low_lim, up_lim

    def action_lim(self) -> Tuple[Action, Action]:
        """Return the action limits: every component lies in [-1, 1]."""
        low_lim = jnp.array([-1., -1., -1.])
        up_lim = jnp.array([1., 1., 1.])
        return low_lim, up_lim

    def u_ref(self, graph: GraphsTuple) -> Action:
        """Reference action: clipped LQR feedback on the (length-limited) goal error."""
        agent = graph.type_states(type_idx=0, n_type=self.num_agents)
        goal = graph.type_states(type_idx=1, n_type=self.num_agents)
        error = goal - agent
        # Limit the error to the vector of length comm_radius pointing in the same direction.
        # NOTE(review): the norm is zero when an agent state equals its goal state
        # exactly, which makes this division produce NaN - confirm whether a guard is wanted.
        error_max = jnp.abs(error / jnp.linalg.norm(error, axis=-1, keepdims=True) * self._params["comm_radius"])
        error = jnp.clip(error, -error_max, error_max)
        return self.clip_action(error @ self._K.T)

    def forward_graph(self, graph: GraphsTuple, action: Action) -> GraphsTuple:
        """Step the agents with ``action`` while keeping the graph topology, goals and obstacle nodes fixed."""
        # calculate next graph
        agent_states = graph.type_states(type_idx=0, n_type=self.num_agents)
        goal_states = graph.type_states(type_idx=1, n_type=self.num_agents)
        # ``get_graph`` creates ``self.n_rays`` obstacle nodes per agent (the lidar's
        # ``max_returns``), not ``params["n_rays"]`` (the number of beams), so the
        # obstacle states must be requested with the same count.
        obs_states = graph.type_states(type_idx=2, n_type=self.n_rays * self.num_agents)
        action = self.clip_action(action)

        assert action.shape == (self.num_agents, self.action_dim)
        assert agent_states.shape == (self.num_agents, self.state_dim)

        next_agent_states = self.agent_step_euler(agent_states, action)
        next_states = jnp.concatenate([next_agent_states, goal_states, obs_states], axis=0)

        next_graph = self.add_edge_feats(graph, next_states)
        return next_graph

    def safe_mask(self, graph: GraphsTuple) -> Array:
        """Per-agent mask: farther than 4 radii from every other agent and not within 2 radii of an obstacle."""
        agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)[:, :3]

        # agents are not colliding
        pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :]  # [i, j]: i -> j
        dist = jnp.linalg.norm(pos_diff, axis=-1)
        dist = dist + jnp.eye(dist.shape[1]) * (self._params["drone_radius"] * 2 + 1)  # remove self connection
        safe_agent = jnp.greater(dist, self._params["drone_radius"] * 4)

        # min over booleans == "all other agents are far enough"
        safe_agent = jnp.min(safe_agent, axis=1)

        safe_obs = jnp.logical_not(
            inside_obstacles(agent_pos, graph.env_states.obstacle, self._params["drone_radius"] * 2)
        )

        safe_mask = jnp.logical_and(safe_agent, safe_obs)

        return safe_mask

    def unsafe_mask(self, graph: GraphsTuple) -> Array:
        """Per-agent mask: closer than 2.5 radii to another agent or within 1.5 radii of an obstacle."""
        agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)[:, :3]

        # agents are colliding
        pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :]  # [i, j]: i -> j
        dist = jnp.linalg.norm(pos_diff, axis=-1)
        dist = dist + jnp.eye(dist.shape[1]) * (self._params["drone_radius"] * 2 + 1)  # remove self connection
        unsafe_agent = jnp.less(dist, self._params["drone_radius"] * 2.5)
        unsafe_agent = jnp.max(unsafe_agent, axis=1)

        # agents are colliding with obstacles
        unsafe_obs = inside_obstacles(agent_pos, graph.env_states.obstacle, self._params["drone_radius"] * 1.5)

        collision_mask = jnp.logical_or(unsafe_agent, unsafe_obs)

        return collision_mask

    def collision_mask(self, graph: GraphsTuple) -> Array:
        """Per-agent mask: closer than 2 radii to another agent or within 1 radius of an obstacle."""
        agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)[:, :3]

        # agents are colliding
        pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :]  # [i, j]: i -> j
        dist = jnp.linalg.norm(pos_diff, axis=-1)
        dist = dist + jnp.eye(dist.shape[1]) * (self._params["drone_radius"] * 2 + 1)  # remove self connection
        unsafe_agent = jnp.less(dist, self._params["drone_radius"] * 2)
        unsafe_agent = jnp.max(unsafe_agent, axis=1)

        # agents are colliding with obstacles
        unsafe_obs = inside_obstacles(agent_pos, graph.env_states.obstacle, self._params["drone_radius"])

        collision_mask = jnp.logical_or(unsafe_agent, unsafe_obs)

        return collision_mask

    def finish_mask(self, graph: GraphsTuple) -> Array:
        """Per-agent mask: the agent is within 2 radii of its goal position."""
        agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)[:, :3]
        goal_pos = graph.env_states.goal[:, :3]
        reach = jnp.linalg.norm(agent_pos - goal_pos, axis=1) < self._params["drone_radius"] * 2
        return reach
405 |
--------------------------------------------------------------------------------
/gcbfplus/env/obstacle.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 | from typing import NamedTuple, Protocol
5 | from jax.scipy.spatial.transform import Rotation
6 | from ..utils.typing import Pos2d, Pos3d, Pos
7 | from ..utils.typing import Array, ObsType, ObsWidth, ObsHeight, ObsTheta, Radius, ObsLength, ObsQuaternion, BoolScalar
8 |
# Obstacle type tags (one-element arrays) stored in the ``type`` field of an obstacle.
RECTANGLE = jnp.zeros(1)
CUBOID = jnp.ones(1)
SPHERE = jnp.ones(1) * 2
12 |
13 |
class Obstacle(Protocol):
    """Structural interface shared by the obstacle shapes of this module."""
    type: ObsType
    center: Pos

    def inside(self, point: Pos, r: Radius = 0.) -> BoolScalar:
        """Membership test of ``point`` with margin ``r``; see the concrete shapes for the exact rule."""
        ...

    def raytracing(self, start: Pos, end: Pos) -> Array:
        """Ray test for the segment from ``start`` to ``end``; see the concrete shapes for the result."""
        ...
23 |
24 |
class Rectangle(NamedTuple):
    """Rotated 2D rectangle obstacle."""
    type: ObsType
    center: Pos2d
    width: ObsWidth
    height: ObsHeight
    theta: ObsTheta
    points: Array  # the four corners, one per row

    @staticmethod
    def create(center: Pos2d, width: ObsWidth, height: ObsHeight, theta: ObsTheta) -> "Rectangle":
        """Build a rectangle from its center, side lengths and rotation angle ``theta``."""
        bbox = jnp.array([
            [width / 2, height / 2],
            [-width / 2, height / 2],
            [-width / 2, -height / 2],
            [width / 2, -height / 2],
        ]).T  # (2, 4)

        rot = jnp.array([
            [jnp.cos(theta), -jnp.sin(theta)],
            [jnp.sin(theta), jnp.cos(theta)]
        ])

        # Rotate the axis-aligned corners about the origin, then move them to the center.
        trans = center[:, None]
        points = jnp.dot(rot, bbox) + trans
        points = points.T

        return Rectangle(RECTANGLE, center, width, height, theta, points)

    def inside(self, point: Pos2d, r: Radius = 0.) -> BoolScalar:
        """Return True if ``point`` lies inside the rectangle inflated by ``r`` (with rounded corners)."""
        rel_x = point[0] - self.center[0]
        rel_y = point[1] - self.center[1]
        # Distance beyond the half extents along the rectangle's own axes
        # (negative when the point is between the two faces).
        rel_xx = jnp.abs(rel_x * jnp.cos(self.theta) + rel_y * jnp.sin(self.theta)) - self.width / 2
        rel_yy = jnp.abs(rel_x * jnp.sin(self.theta) - rel_y * jnp.cos(self.theta)) - self.height / 2
        is_in_down = jnp.logical_and(rel_xx < r, rel_yy < 0)
        is_in_up = jnp.logical_and(rel_xx < 0, rel_yy < r)
        # Beyond both half extents: inside only if within ``r`` of the corner.
        is_out_corner = jnp.logical_and(rel_xx > 0, rel_yy > 0)
        is_in_circle = jnp.sqrt(rel_xx ** 2 + rel_yy ** 2) < r
        is_in = jnp.logical_or(jnp.logical_or(is_in_down, is_in_up), jnp.logical_and(is_out_corner, is_in_circle))
        return is_in

    def raytracing(self, start: Pos2d, end: Pos2d) -> Array:
        """Intersect the segment ``start -> end`` with the four edges of the rectangle.

        Returns
        -------
        alpha: Array,
            the smallest fraction in [0, 1] along the segment at which an edge is
            hit, or 1e6 if no edge is hit
        """
        # beam
        x1 = start[0]
        y1 = start[1]
        x2 = end[0]
        y2 = end[1]

        # edges: each corner paired with the previous corner
        x3 = self.points[:, 0]
        y3 = self.points[:, 1]
        x4 = self.points[[-1, 0, 1, 2], 0]
        y4 = self.points[[-1, 0, 1, 2], 1]

        # solve the equation
        #   x = x1 + alpha * (x2 - x1) = x3 + beta * (x4 - x3)
        #   y = y1 + alpha * (y2 - y1) = y3 + beta * (y4 - y3)
        # equivalent to solve
        #   [x1-x2  x4-x3] [alpha]   [x1-x3]
        #   [y1-y2  y4-y3] [beta ] = [y1-y3]
        # solve by (alpha beta)^T = A^{-1} b

        raw_det = (x1 - x2) * (y4 - y3) - (y1 - y2) * (x4 - x3)
        # A beam exactly parallel to an edge has no unique intersection with it.
        is_parallel = raw_det == 0
        # clip det for numerical issues; the sign is taken with ``where`` rather than
        # ``sign`` so that an exactly zero determinant cannot produce a 0 / 0.
        det = jnp.where(raw_det < 0, -1.0, 1.0) * jnp.clip(jnp.abs(raw_det), 1e-7, 1e7)
        alphas = ((y4 - y3) * (x1 - x3) - (x4 - x3) * (y1 - y3)) / det
        betas = (-(y1 - y2) * (x1 - x3) + (x1 - x2) * (y1 - y3)) / det
        valids = jnp.logical_and(jnp.logical_and(alphas <= 1, alphas >= 0), jnp.logical_and(betas <= 1, betas >= 0))
        valids = jnp.logical_and(valids, jnp.logical_not(is_parallel))
        # Selecting (instead of multiplying by the mask) keeps non-finite values of
        # rejected edges out of the result.
        alphas = jnp.where(valids, alphas, 1e6)
        alphas = jnp.min(alphas)  # reduce the polygon edges dimension
        return alphas
97 |
98 |
class Cuboid(NamedTuple):
    """Oriented box obstacle in 3D.

    ``points`` holds the eight world-frame corners produced by ``create``:
    indices 0-3 are the lower face (-x-y, +x-y, +x+y, -x+y in the box frame)
    and 4-7 the upper face in the same order.
    """
    type: ObsType
    center: Pos3d
    length: ObsLength
    width: ObsWidth
    height: ObsHeight
    rotation: Rotation
    points: Array

    @staticmethod
    def create(
            center: Pos3d, length: ObsLength, width: ObsWidth, height: ObsHeight, quaternion: ObsQuaternion
    ) -> "Cuboid":
        """Build a cuboid from its centre, side lengths and orientation quaternion."""
        bbox = jnp.array([
            [-length / 2, -width / 2, -height / 2],
            [length / 2, -width / 2, -height / 2],
            [length / 2, width / 2, -height / 2],
            [-length / 2, width / 2, -height / 2],
            [-length / 2, -width / 2, height / 2],
            [length / 2, -width / 2, height / 2],
            [length / 2, width / 2, height / 2],
            [-length / 2, width / 2, height / 2],
        ])  # (8, 3)

        rotation = Rotation.from_quat(quaternion)
        points = rotation.apply(bbox) + center
        return Cuboid(CUBOID, center, length, width, height, rotation, points)

    def inside(self, point: Pos3d, r: Radius = 0.) -> BoolScalar:
        """Return whether a sphere of radius ``r`` centred at ``point`` touches the cuboid.

        The check has two parts: the box inflated by ``r`` along each single
        axis (done in the cuboid frame), and the distance from ``point`` to
        each of the twelve edges (done in the world frame, where
        ``self.points`` lives).
        """
        # transform the point to the cuboid frame; keep the world-frame point
        # untouched because the edge test below uses world-frame corners
        rot = self.rotation.as_matrix()
        rot_inv = jnp.linalg.inv(rot)
        local = jnp.dot(rot_inv, point - self.center)
        px, py, pz = local[0], local[1], local[2]

        half_l, half_w, half_h = self.length / 2, self.width / 2, self.height / 2

        # check if the point is inside the cuboid inflated along one axis at a time
        is_in_height = ((-half_l < px) & (px < half_l)) & \
                       ((-half_w < py) & (py < half_w)) & \
                       ((-half_h - r < pz) & (pz < half_h + r))
        is_in_length = ((-half_l - r < px) & (px < half_l + r)) & \
                       ((-half_w < py) & (py < half_w)) & \
                       ((-half_h < pz) & (pz < half_h))
        is_in_width = ((-half_l < px) & (px < half_l)) & \
                      ((-half_w - r < py) & (py < half_w + r)) & \
                      ((-half_h < pz) & (pz < half_h))
        is_in = is_in_height | is_in_length | is_in_width

        # check if the sphere intersects with the edges
        edge_order = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0],
                                [4, 5], [5, 6], [6, 7], [7, 4],
                                [0, 4], [1, 5], [2, 6], [3, 7]])
        edges = self.points[edge_order]

        def intersect_edge(edge: Array) -> BoolScalar:
            # distance from the world-frame point to the closest point on the segment
            assert edge.shape == (2, 3)
            dot_prod = jnp.dot(edge[1] - edge[0], point - edge[0])
            frac = dot_prod / ((jnp.linalg.norm(edge[1] - edge[0])) ** 2)
            frac = jnp.clip(frac, 0, 1)
            closest_point = edge[0] + frac * (edge[1] - edge[0])
            dist = jnp.linalg.norm(closest_point - point)
            return dist <= r

        is_intersect = jnp.any(jax.vmap(intersect_edge)(edges))
        return is_in | is_intersect

    def raytracing(self, start: Pos3d, end: Pos3d) -> Array:
        """Intersect one beam with the six faces and return the nearest hit.

        The beam is ``start + alpha * (end - start)`` and each face is
        ``p3 + beta * (p4 - p3) + gamma * (p5 - p3)``. A hit requires
        ``alpha``, ``beta`` and ``gamma`` to all lie in ``[0, 1]``.

        Returns:
            Scalar: the smallest valid ``alpha`` over the faces, or ``1e6``
            when no face is hit. Faces exactly parallel to the beam are
            treated as misses instead of producing NaN.
        """
        # beams
        x1, y1, z1 = start[0], start[1], start[2]
        x2, y2, z2 = end[0], end[1], end[2]

        # point order for the base 0~7 (first 0-x-xy-y for the lower level, then 0-x-xy-y for the upper level)
        # face order: bottom, left, right, upper, outer left, outer right
        # each face is spanned from a corner p3 by the two corners p4 and p5
        x3 = self.points[[0, 0, 0, 6, 6, 6], 0]
        y3 = self.points[[0, 0, 0, 6, 6, 6], 1]
        z3 = self.points[[0, 0, 0, 6, 6, 6], 2]

        x4 = self.points[[1, 1, 3, 5, 5, 7], 0]
        y4 = self.points[[1, 1, 3, 5, 5, 7], 1]
        z4 = self.points[[1, 1, 3, 5, 5, 7], 2]

        x5 = self.points[[3, 4, 4, 7, 2, 2], 0]
        y5 = self.points[[3, 4, 4, 7, 2, 2], 1]
        z5 = self.points[[3, 4, 4, 7, 2, 2], 2]

        # Solve
        #   x = x1 + alpha * (x2 - x1) = x3 + beta * (x4 - x3) + gamma * (x5 - x3)
        #   y = y1 + alpha * (y2 - y1) = y3 + beta * (y4 - y3) + gamma * (y5 - y3)
        #   z = z1 + alpha * (z2 - z1) = z3 + beta * (z4 - z3) + gamma * (z5 - z3)
        # i.e.
        #   [x1-x2  x4-x3  x5-x3] [alpha]   [x1-x3]
        #   [y1-y2  y4-y3  y5-y3] [beta ] = [y1-y3]
        #   [z1-z2  z4-z3  z5-z3] [gamma]   [z1-z3]
        # via (alpha, beta, gamma)^T = adj(A) b / det(A); adj_ij below are the
        # entries of the adjugate.
        raw_det = (x1 - x2) * (y4 - y3) * (z5 - z3) + (x4 - x3) * (y5 - y3) * (z1 - z2) + (y1 - y2) * (z4 - z3) * (
                x5 - x3) - (y1 - y2) * (x4 - x3) * (z5 - z3) - (z4 - z3) * (y5 - y3) * (x1 - x2) - (x5 - x3) * (
                          y4 - y3) * (z1 - z2)
        is_parallel = raw_det == 0

        # Clip |det| for numerical stability. jnp.sign(0) == 0 would cancel the
        # clip and reintroduce a division by zero, so use a sign that is never 0.
        det_sign = jnp.where(raw_det < 0, -1.0, 1.0)
        det = det_sign * jnp.clip(jnp.abs(raw_det), 1e-7, 1e7)

        adj_00 = (y4 - y3) * (z5 - z3) - (y5 - y3) * (z4 - z3)
        adj_01 = -((x4 - x3) * (z5 - z3) - (z4 - z3) * (x5 - x3))
        adj_02 = (x4 - x3) * (y5 - y3) - (y4 - y3) * (x5 - x3)
        adj_10 = -((y1 - y2) * (z5 - z3) - (z1 - z2) * (y5 - y3))
        adj_11 = (x1 - x2) * (z5 - z3) - (z1 - z2) * (x5 - x3)
        adj_12 = -((x1 - x2) * (y5 - y3) - (y1 - y2) * (x5 - x3))
        adj_20 = (y1 - y2) * (z4 - z3) - (y4 - y3) * (z1 - z2)
        adj_21 = -((x1 - x2) * (z4 - z3) - (z1 - z2) * (x4 - x3))
        adj_22 = (x1 - x2) * (y4 - y3) - (y1 - y2) * (x4 - x3)
        alphas = 1 / det * (adj_00 * (x1 - x3) + adj_01 * (y1 - y3) + adj_02 * (z1 - z3))
        betas = 1 / det * (adj_10 * (x1 - x3) + adj_11 * (y1 - y3) + adj_12 * (z1 - z3))
        gammas = 1 / det * (adj_20 * (x1 - x3) + adj_21 * (y1 - y3) + adj_22 * (z1 - z3))
        valids = jnp.logical_and(
            jnp.logical_and(jnp.logical_and(alphas <= 1, alphas >= 0), jnp.logical_and(betas <= 1, betas >= 0)),
            jnp.logical_and(gammas <= 1, gammas >= 0)
        )
        valids = jnp.logical_and(valids, jnp.logical_not(is_parallel))

        # Select rather than multiply by the mask: 0 * inf would be NaN.
        alphas = jnp.where(valids, alphas, 1e6)
        return jnp.min(alphas)  # reduce the faces dimension
223 |
224 |
class Sphere(NamedTuple):
    """Spherical obstacle described by its centre and radius."""
    type: ObsType
    center: Pos3d
    radius: Radius

    @staticmethod
    def create(center: Pos3d, radius: Radius) -> "Sphere":
        """Build a sphere tagged with the ``SPHERE`` obstacle type."""
        return Sphere(SPHERE, center, radius)

    def inside(self, point: Pos3d, r: Radius = 0.) -> BoolScalar:
        """Return whether a ball of radius ``r`` around ``point`` touches the sphere."""
        centre_distance = jnp.linalg.norm(point - self.center)
        return centre_distance <= self.radius + r

    def raytracing(self, start: Pos3d, end: Pos3d) -> Array:
        """Intersect the beam ``start + alpha * (end - start)`` with the sphere.

        Substituting the beam into ``|p - center|^2 = r^2`` gives the quadratic
        ``A alpha^2 + B alpha + C = 0`` with

            A = |end - start|^2
            B = 2 (end - start) . (start - center)
            C = |start - center|^2 - r^2

        Returns ``1e6`` when the discriminant is negative; otherwise the
        smaller non-negative root (a negative root counts as 1), clipped to
        ``[0, 1]``.
        """
        # beam direction and offset of the beam origin from the sphere centre
        dx, dy, dz = end[0] - start[0], end[1] - start[1], end[2] - start[2]
        ox, oy, oz = start[0] - self.center[0], start[1] - self.center[1], start[2] - self.center[2]
        r = self.radius

        lidar_rmax = jnp.linalg.norm(end - start)
        A = lidar_rmax ** 2  # (x2-x1)^2 + (y2-y1)^2 + (z2-z1)^2
        B = 2 * (dx * ox + dy * oy + dz * oz)
        C = ox ** 2 + oy ** 2 + oz ** 2 - r ** 2

        delta = B ** 2 - 4 * A * C
        valid1 = delta >= 0
        no_root = 1 - valid1

        # masking delta keeps the sqrt argument from being negative
        root = jnp.sqrt(delta * valid1)
        alpha1 = (-B - root) / (2 * A) * valid1 + no_root
        alpha2 = (-B + root) / (2 * A) * valid1 + no_root

        def behind_to_one(alpha):
            # roots behind the beam origin are replaced by 1 (end of the beam)
            return (alpha >= 0) * alpha + (alpha < 0) * 1

        alphas = jnp.minimum(behind_to_one(alpha1), behind_to_one(alpha2))
        alphas = jnp.clip(alphas, 0, 1)
        return valid1 * alphas + no_root * 1e6
271 |
--------------------------------------------------------------------------------
/gcbfplus/env/plot.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import jax
5 | import pathlib
6 |
7 | from colour import hsl2hex
8 | from matplotlib.animation import FuncAnimation
9 | from matplotlib.collections import LineCollection, PatchCollection
10 | from matplotlib.colors import LinearSegmentedColormap
11 | from matplotlib.pyplot import Axes
12 | from matplotlib.patches import Polygon
13 | from mpl_toolkits.mplot3d import proj3d, Axes3D
14 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
15 | from typing import List, Optional, Union
16 |
17 | from ..trainer.utils import centered_norm
18 | from ..utils.typing import EdgeIndex, Pos2d, Pos3d, Array
19 | from ..utils.utils import merge01, tree_index, MutablePatchCollection, save_anim
20 | from .obstacle import Cuboid, Sphere, Obstacle, Rectangle
21 | from .base import RolloutResult
22 |
23 |
def plot_graph(
        ax: Axes,
        pos: Pos2d,
        radius: Union[float, List[float]],
        color: Union[str, List[str]],
        with_label: Union[bool, List[bool]] = True,
        plot_edge: bool = False,
        edge_index: Optional[EdgeIndex] = None,
        edge_color: Union[str, List[str]] = 'k',
        alpha: float = 1.0,
        obstacle_color: str = '#000000',
) -> Axes:
    """Draw graph nodes as circles on ``ax`` and, optionally, the edges between them.

    Scalar ``radius`` / ``color`` / ``with_label`` / ``edge_color`` values are
    repeated for every node (or edge). Edges are only drawn when ``plot_edge``
    is set and ``edge_index`` is given; each edge is shortened so it starts and
    ends on the circle boundaries. ``obstacle_color`` is accepted but not used
    here.

    Returns:
        The same ``ax``, for chaining.
    """
    n_node = pos.shape[0]

    # expand scalar options to one entry per node
    if isinstance(radius, float):
        radius = np.ones(n_node) * radius
    if isinstance(radius, list):
        radius = np.array(radius)
    if isinstance(color, str):
        color = [color] * n_node
    if isinstance(with_label, bool):
        with_label = [with_label] * n_node

    node_patches = []
    for idx in range(n_node):
        cx, cy = float(pos[idx, 0]), float(pos[idx, 1])
        node_patches.append(
            plt.Circle((cx, cy), radius=radius[idx], color=color[idx], clip_on=False, alpha=alpha, linewidth=0.0)
        )
        if with_label[idx]:
            ax.text(cx, cy, f'{idx}', size=12, color="k",
                    family="sans-serif", weight="normal", horizontalalignment="center", verticalalignment="center",
                    transform=ax.transData, clip_on=True)
    ax.add_collection(PatchCollection(node_patches, match_original=True))

    if not plot_edge or edge_index is None:
        return ax

    if isinstance(edge_color, str):
        edge_color = [edge_color] * edge_index.shape[1]
    senders, receivers = edge_index[0, :], edge_index[1, :]
    start, end = pos[senders], pos[receivers]
    gap = end - start
    direction = gap / jnp.linalg.norm(gap, axis=1, keepdims=True)
    # pull both end points in by the node radius so lines stop at the circles
    start = start + direction * radius[senders][:, None]
    end = end - direction * radius[receivers][:, None]
    widths = (radius[senders] + radius[receivers]) * 20
    segments = np.stack([start, end], axis=1)
    ax.add_collection(LineCollection(segments, colors=edge_color, linewidths=widths, alpha=0.5))
    return ax
67 |
68 |
def plot_node_3d(ax, pos: Pos3d, r: float, color: str, alpha: float, grid: int = 10) -> Axes:
    """Draw a sphere of radius ``r`` centred at ``pos`` via ``ax.plot_surface``.

    The surface is sampled on a ``grid`` x ``grid`` mesh of azimuth
    (0 to 2*pi) and polar (0 to pi) angles. Returns ``ax``.
    """
    azimuth = np.linspace(0, 2 * np.pi, grid)
    polar = np.linspace(0, np.pi, grid)
    sin_polar = np.sin(polar)

    xs = r * np.outer(np.cos(azimuth), sin_polar) + pos[0]
    ys = r * np.outer(np.sin(azimuth), sin_polar) + pos[1]
    zs = r * np.outer(np.ones(np.size(azimuth)), np.cos(polar)) + pos[2]

    ax.plot_surface(xs, ys, zs, color=color, alpha=alpha)
    return ax
77 |
78 |
def plot_graph_3d(
        ax,
        pos: Pos3d,
        radius: float,
        color: Union[str, List[str]],
        with_label: bool = True,
        plot_edge: bool = False,
        edge_index: Optional[EdgeIndex] = None,
        edge_color: Union[str, List[str]] = 'k',
        alpha: float = 1.0,
        obstacle_color: str = '#000000',
):
    """Draw graph nodes as spheres on a 3D ``ax`` and, optionally, the edges.

    A single ``color`` / ``edge_color`` string is repeated for every node /
    edge. Edges are only drawn when ``plot_edge`` is set AND ``edge_index`` is
    given, matching ``plot_graph``; previously ``plot_edge=True`` with the
    default ``edge_index=None`` raised ``AttributeError``. ``obstacle_color``
    is accepted but not used here.

    Returns:
        The same ``ax``, for chaining.
    """
    if isinstance(color, str):
        color = [color for _ in range(pos.shape[0])]
    for i in range(pos.shape[0]):
        plot_node_3d(ax, pos[i], radius, color[i], alpha)
        if with_label:
            ax.text(pos[i, 0], pos[i, 1], pos[i, 2], f'{i}', size=12, color="k", family="sans-serif", weight="normal",
                    horizontalalignment="center", verticalalignment="center")
    if plot_edge and edge_index is not None:
        if isinstance(edge_color, str):
            edge_color = [edge_color for _ in range(edge_index.shape[1])]
        for i_edge in range(edge_index.shape[1]):
            i = edge_index[0, i_edge]
            j = edge_index[1, i_edge]
            # vec points from node j to node i; both end points are shifted
            # along it by 2 * radius * vec (vec is not normalised)
            vec = pos[i, :] - pos[j, :]
            x = [pos[i, 0] - 2 * radius * vec[0], pos[j, 0] + 2 * radius * vec[0]]
            y = [pos[i, 1] - 2 * radius * vec[1], pos[j, 1] + 2 * radius * vec[1]]
            z = [pos[i, 2] - 2 * radius * vec[2], pos[j, 2] + 2 * radius * vec[2]]
            ax.plot(x, y, z, linewidth=1.0, color=edge_color[i_edge])
    return ax
110 |
111 |
def get_BuRd():
    """Build the blue/red "SDF" colormap.

    The map runs light blue -> blue on [0, 0.5] and red -> light red on
    [0.5, 1], with a hard colour break at 0.5, sampled at 256 levels.
    """
    # Earlier candidates for blue: "#3182bd", hsl(0.57, 0.59, 0.47)
    light_blue = hsl2hex([0.5, 1.0, 0.995])
    blue = hsl2hex([0.57, 0.5, 0.55])

    # Red is tinted slightly towards orange.
    # Earlier candidates for red: "#de2d26", hsl(0.04, 0.74, 0.51)
    red = hsl2hex([0.028, 0.62, 0.59])
    light_red = hsl2hex([0.098, 1.0, 0.995])

    anchors = [(0, light_blue), (0.5, blue), (0.5, red), (1, light_red)]
    return LinearSegmentedColormap.from_list("SDF", anchors, N=256)
126 |
127 |
def get_faces_cuboid(points: Pos3d) -> Array:
    """Gather the six quadrilateral faces of a cuboid from its eight corners.

    Indexes ``points`` along its first axis with a (6, 4) table of corner
    ids, so the result has one row of four corners per face.
    """
    corner_ids = jnp.array([
        [0, 1, 2, 3],
        [4, 5, 6, 7],
        [0, 1, 5, 4],
        [2, 3, 7, 6],
        [0, 3, 7, 4],
        [1, 2, 6, 5],
    ])
    return points[corner_ids]
132 |
133 |
def get_cuboid_collection(
        obstacles: Cuboid, alpha: float = 0.8, linewidth: float = 1.0, edgecolor: str = 'k', facecolor: str = 'r'
) -> Poly3DCollection:
    """Create one ``Poly3DCollection`` holding the faces of a batch of cuboids.

    ``get_faces_cuboid`` is vmapped over ``obstacles.points`` and the result
    is passed through ``merge01`` before being handed to matplotlib.
    """
    faces_per_cuboid = jax.vmap(get_faces_cuboid)(obstacles.points)
    all_faces = merge01(faces_per_cuboid)
    style = dict(alpha=alpha, linewidth=linewidth, edgecolor=edgecolor, facecolor=facecolor)
    return Poly3DCollection(all_faces, **style)
146 |
147 |
def get_sphere_collection(
        obstacles: Sphere, alpha: float = 0.8, facecolor: str = 'r'
) -> Poly3DCollection:
    """Create one ``Poly3DCollection`` approximating a batch of spheres.

    Each sphere is sampled on a 30 x 30 azimuth/polar grid. Centres and radii
    are packed into one array per obstacle so the mesh builder can be vmapped,
    and the stacked meshes are passed through ``merge01``.
    """
    # the angular grid is identical for every sphere
    azimuth = np.linspace(0, 2 * np.pi, 30)
    polar = np.linspace(0, np.pi, 30)

    def sphere_mesh(center_and_radius):
        center = center_and_radius[:3]
        radius = center_and_radius[3]
        xs = radius * np.outer(np.cos(azimuth), np.sin(polar)) + center[0]
        ys = radius * np.outer(np.sin(azimuth), np.sin(polar)) + center[1]
        zs = radius * np.outer(np.ones(np.size(azimuth)), np.cos(polar)) + center[2]
        return jnp.stack([xs, ys, zs], axis=-1)

    packed = jnp.concatenate([obstacles.center, obstacles.radius[:, None]], axis=-1)
    meshes = jax.vmap(sphere_mesh)(packed)
    return Poly3DCollection(
        merge01(meshes),
        alpha=alpha,
        linewidth=0.0,
        edgecolor='k',
        facecolor=facecolor
    )
171 |
172 |
def get_obs_collection(
        obstacles: Obstacle, color: str, alpha: float
):
    """Return the matplotlib collection that draws ``obstacles``.

    Rectangles become a 2D ``PatchCollection`` (always drawn fully opaque,
    ``alpha`` is not applied to them); cuboids and spheres are delegated to
    their 3D collection builders.

    Raises:
        NotImplementedError: for any other obstacle type.
    """
    if isinstance(obstacles, Rectangle):
        polygons = [Polygon(obstacles.points[ii]) for ii in range(len(obstacles.center))]
        return PatchCollection(polygons, color=color, alpha=1.0, zorder=99)
    if isinstance(obstacles, Cuboid):
        return get_cuboid_collection(obstacles, alpha=alpha, facecolor=color)
    if isinstance(obstacles, Sphere):
        return get_sphere_collection(obstacles, alpha=alpha, facecolor=color)
    raise NotImplementedError
187 |
188 |
def render_video(
        rollout: RolloutResult,
        video_path: pathlib.Path,
        side_length: float,
        dim: int,
        n_agent: int,
        n_rays: int,
        r: float,
        Ta_is_unsafe=None,
        viz_opts: dict = None,
        dpi: int = 100,
        **kwargs
):
    """Animate a rollout in 2D or 3D and save it to ``video_path``.

    Args:
        rollout: Rollout whose ``Tp1_graph`` supplies one graph per frame and
            whose ``T_cost`` / ``T_reward`` feed the cost/reward text.
        video_path: Destination handed to ``save_anim``.
        side_length: Axis limits are ``[0, side_length]`` on every axis.
        dim: Spatial dimension, must be 2 or 3.
        n_agent: Number of agents; the first ``2 * n_agent`` graph nodes are
            drawn as agents followed by goals.
        n_rays: LiDAR rays per agent (``n_agent * n_rays`` hit nodes).
        r: Agent radius used for drawing.
        Ta_is_unsafe: Optional per-frame, per-agent unsafe flags. When ``None``
            no "Unsafe" text is drawn or updated.
        viz_opts: Optional extras. ``viz_opts["cbf"]`` is unpacked as
            ``(Tb_xs, Tb_ys, Tbb_h, cbf_num)`` and drawn as contours (2D only).
        dpi: Figure resolution.
        **kwargs: Accepted for call-site compatibility; not used here.
    """
    assert dim == 2 or dim == 3

    # set up visualization option
    if dim == 2:
        ax: Axes
        fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=dpi)
    else:
        fig = plt.figure(figsize=(10, 10), dpi=dpi)
        ax: Axes3D = fig.add_subplot(projection='3d')
    ax.set_xlim(0., side_length)
    ax.set_ylim(0., side_length)
    if dim == 3:
        ax.set_zlim(0., side_length)
    ax.set(aspect="equal")
    if dim == 2:
        plt.axis("off")

    if viz_opts is None:
        viz_opts = {}

    # plot the first frame
    T_graph = rollout.Tp1_graph
    graph0 = tree_index(T_graph, 0)

    agent_color = "#0068ff"
    goal_color = "#2fdd00"
    obs_color = "#8a0000"
    edge_goal_color = goal_color

    # plot obstacles
    obs = graph0.env_states.obstacle
    ax.add_collection(get_obs_collection(obs, obs_color, alpha=0.8))

    # plot agents
    n_hits = n_agent * n_rays
    n_color = [agent_color] * n_agent + [goal_color] * n_agent
    n_pos = graph0.states[:n_agent * 2, :dim]
    n_radius = np.array([r] * n_agent * 2)
    if dim == 2:
        agent_circs = [plt.Circle(n_pos[ii], n_radius[ii], color=n_color[ii], linewidth=0.0)
                       for ii in range(n_agent * 2)]
        agent_col = MutablePatchCollection([i for i in reversed(agent_circs)], match_original=True, zorder=6)
        ax.add_collection(agent_col)
    else:
        plot_r = ax.transData.transform([r, 0])[0] - ax.transData.transform([0, 0])[0]
        agent_col = ax.scatter(n_pos[:, 0], n_pos[:, 1], n_pos[:, 2],
                               s=plot_r, c=n_color, zorder=5)  # todo: the size of the agent might not be correct

    # plot edges; edges touching the node with index 2 * n_agent + n_hits
    # (the padding node) are dropped
    all_pos = graph0.states[:n_agent * 2 + n_hits, :dim]
    edge_index = np.stack([graph0.senders, graph0.receivers], axis=0)
    is_pad = np.any(edge_index == n_agent * 2 + n_hits, axis=0)
    e_edge_index = edge_index[:, ~is_pad]
    e_start, e_end = all_pos[e_edge_index[0, :]], all_pos[e_edge_index[1, :]]
    e_lines = np.stack([e_start, e_end], axis=1)  # (e, n_pts, dim)
    e_is_goal = (n_agent <= graph0.senders) & (graph0.senders < n_agent * 2)
    e_is_goal = e_is_goal[~is_pad]
    e_colors = [edge_goal_color if e_is_goal[ii] else "0.2" for ii in range(len(e_start))]
    if dim == 2:
        edge_col = LineCollection(e_lines, colors=e_colors, linewidths=2, alpha=0.5, zorder=3)
    else:
        edge_col = Line3DCollection(e_lines, colors=e_colors, linewidths=2, alpha=0.5, zorder=3)
    ax.add_collection(edge_col)

    # text for cost and reward
    text_font_opts = dict(
        size=16,
        color="k",
        family="cursive",
        weight="normal",
        transform=ax.transAxes,
    )
    if dim == 2:
        cost_text = ax.text(0.02, 1.04, "Cost: 1.0, Reward: 1.0", va="bottom", **text_font_opts)
    else:
        cost_text = ax.text2D(0.02, 1.04, "Cost: 1.0, Reward: 1.0", va="bottom", **text_font_opts)

    # text for safety (only created when unsafe flags were supplied)
    safe_text = []
    if Ta_is_unsafe is not None:
        if dim == 2:
            safe_text = [ax.text(0.02, 1.00, "Unsafe: {}", va="bottom", **text_font_opts)]
        else:
            safe_text = [ax.text2D(0.02, 1.00, "Unsafe: {}", va="bottom", **text_font_opts)]

    # text for time step
    if dim == 2:
        kk_text = ax.text(0.99, 0.99, "kk=0", va="top", ha="right", **text_font_opts)
    else:
        kk_text = ax.text2D(0.99, 0.99, "kk=0", va="top", ha="right", **text_font_opts)

    # add agent labels
    label_font_opts = dict(
        size=20,
        color="k",
        family="cursive",
        weight="normal",
        ha="center",
        va="center",
        transform=ax.transData,
        clip_on=True,
        zorder=7,
    )
    agent_labels = []
    if dim == 2:
        agent_labels = [ax.text(n_pos[ii, 0], n_pos[ii, 1], f"{ii}", **label_font_opts) for ii in range(n_agent)]
    else:
        for ii in range(n_agent):
            pos2d = proj3d.proj_transform(n_pos[ii, 0], n_pos[ii, 1], n_pos[ii, 2], ax.get_proj())[:2]
            agent_labels.append(ax.text2D(pos2d[0], pos2d[1], f"{ii}", **label_font_opts))

    # plot cbf
    cnt_col = []
    if "cbf" in viz_opts:
        if dim == 3:
            print('Warning: CBF visualization is not supported in 3D.')
        else:
            Tb_xs, Tb_ys, Tbb_h, cbf_num = viz_opts["cbf"]
            bb_Xs, bb_Ys = np.meshgrid(Tb_xs[0], Tb_ys[0])
            norm = centered_norm(Tbb_h.min(), Tbb_h.max())
            levels = np.linspace(norm.vmin, norm.vmax, 15)

            cmap = get_BuRd().reversed()
            contour_opts = dict(cmap=cmap, norm=norm, levels=levels, alpha=0.9)
            cnt = ax.contourf(bb_Xs, bb_Ys, Tbb_h[0], **contour_opts)

            contour_line_opts = dict(levels=[0.0], colors=["k"], linewidths=3.0)
            cnt_line = ax.contour(bb_Xs, bb_Ys, Tbb_h[0], **contour_line_opts)

            cbar = fig.colorbar(cnt, ax=ax)
            cbar.add_lines(cnt_line)
            cbar.ax.tick_params(labelsize=36, labelfontfamily="Times New Roman")

            cnt_col = [*cnt.collections, *cnt_line.collections]

            ax.text(0.5, 1.0, "CBF for {}".format(cbf_num), transform=ax.transAxes, va="bottom")

    # init function for animation
    def init_fn() -> list[plt.Artist]:
        return [agent_col, edge_col, *agent_labels, cost_text, *safe_text, *cnt_col, kk_text]

    # update function for animation
    def update(kk: int) -> list[plt.Artist]:
        graph = tree_index(T_graph, kk)
        n_pos_t = graph.states[:-1, :dim]  # drop the last node row

        # update agent positions (in 2D only the agent circles move)
        if dim == 2:
            for ii in range(n_agent):
                agent_circs[ii].set_center(tuple(n_pos_t[ii]))
        else:
            agent_col.set_offsets(n_pos_t[:n_agent * 2, :2])
            agent_col.set_3d_properties(n_pos_t[:n_agent * 2, 2], zdir='z')

        # update edges
        e_edge_index_t = np.stack([graph.senders, graph.receivers], axis=0)
        is_pad_t = np.any(e_edge_index_t == n_agent * 2 + n_hits, axis=0)
        e_edge_index_t = e_edge_index_t[:, ~is_pad_t]
        e_start_t, e_end_t = n_pos_t[e_edge_index_t[0, :]], n_pos_t[e_edge_index_t[1, :]]
        e_is_goal_t = (n_agent <= graph.senders) & (graph.senders < n_agent * 2)
        e_is_goal_t = e_is_goal_t[~is_pad_t]
        e_colors_t = [edge_goal_color if e_is_goal_t[ii] else "0.2" for ii in range(len(e_start_t))]
        e_lines_t = np.stack([e_start_t, e_end_t], axis=1)
        edge_col.set_segments(e_lines_t)
        edge_col.set_colors(e_colors_t)

        # update agent labels
        for ii in range(n_agent):
            if dim == 2:
                agent_labels[ii].set_position(n_pos_t[ii])
            else:
                text_pos = proj3d.proj_transform(n_pos_t[ii, 0], n_pos_t[ii, 1], n_pos_t[ii, 2], ax.get_proj())[:2]
                agent_labels[ii].set_position(text_pos)

        # update cost and safe labels
        if kk < len(rollout.T_cost):
            cost_text.set_text("Cost: {:5.4f}, Reward: {:5.4f}".format(rollout.T_cost[kk], rollout.T_reward[kk]))
        else:
            cost_text.set_text("")
        # The safety text only exists when unsafe flags were supplied; without
        # this guard len(None) raised TypeError on the first frame.
        if Ta_is_unsafe is not None:
            if kk < len(Ta_is_unsafe):
                a_is_unsafe = Ta_is_unsafe[kk]
                unsafe_idx = np.where(a_is_unsafe)[0]
                safe_text[0].set_text("Unsafe: {}".format(unsafe_idx))
            else:
                safe_text[0].set_text("Unsafe: {}")

        # Update the contourf: remove the previous frame's artists and redraw.
        nonlocal cnt, cnt_line
        if "cbf" in viz_opts and dim == 2:
            for c in cnt.collections:
                c.remove()
            for c in cnt_line.collections:
                c.remove()

            bb_Xs_t, bb_Ys_t = np.meshgrid(Tb_xs[kk], Tb_ys[kk])
            cnt = ax.contourf(bb_Xs_t, bb_Ys_t, Tbb_h[kk], **contour_opts)
            cnt_line = ax.contour(bb_Xs_t, bb_Ys_t, Tbb_h[kk], **contour_line_opts)

            cnt_col_t = [*cnt.collections, *cnt_line.collections]
        else:
            cnt_col_t = []

        kk_text.set_text("kk={:04}".format(kk))

        return [agent_col, edge_col, *agent_labels, cost_text, *safe_text, *cnt_col_t, kk_text]

    fps = 30.0
    spf = 1 / fps
    mspf = 1_000 * spf
    anim_T = len(T_graph.n_node)
    ani = FuncAnimation(fig, update, frames=anim_T, init_func=init_fn, interval=mspf, blit=True)
    save_anim(ani, video_path)
414 |
--------------------------------------------------------------------------------
/gcbfplus/env/single_integrator.py:
--------------------------------------------------------------------------------
1 | import functools as ft
2 | import pathlib
3 | import jax
4 | import jax.numpy as jnp
5 | import jax.random as jr
6 | import numpy as np
7 |
8 | from typing import NamedTuple, Tuple, Optional
9 |
10 | from ..utils.graph import EdgeBlock, GetGraph, GraphsTuple
11 | from ..utils.typing import Action, Array, Cost, Done, Info, Pos2d, Reward, State, AgentState
12 | from ..utils.utils import merge01, jax_vmap
13 | from .base import MultiAgentEnv, RolloutResult
14 | from .obstacle import Obstacle, Rectangle
15 | from .plot import render_video
16 | from .utils import get_lidar, inside_obstacles, lqr, get_node_goal_rng
17 |
18 |
19 | class SingleIntegrator(MultiAgentEnv):
20 |
21 | AGENT = 0
22 | GOAL = 1
23 | OBS = 2
24 |
    class EnvState(NamedTuple):
        """Environment state bundle: agent states, goal states and obstacles."""
        agent: State  # agent states; the leading axis indexes agents (see n_agent)
        goal: State  # goal states, passed alongside the agents
        obstacle: Obstacle  # obstacles present in the environment

        @property
        def n_agent(self) -> int:
            """Number of agents, read from the leading axis of ``agent``."""
            return self.agent.shape[0]
33 |
34 | EnvGraphsTuple = GraphsTuple[State, EnvState]
35 |
36 | PARAMS = {
37 | "car_radius": 0.05,
38 | "comm_radius": 0.5,
39 | "n_rays": 32,
40 | "obs_len_range": [0.1, 0.6],
41 | "n_obs": 8,
42 | }
43 |
44 | def __init__(
45 | self,
46 | num_agents: int,
47 | area_size: float,
48 | max_step: int = 256,
49 | max_travel: float = None,
50 | dt: float = 0.03,
51 | params: dict = None
52 | ):
53 | super(SingleIntegrator, self).__init__(num_agents, area_size, max_step, max_travel, dt, params)
54 | self._A = np.zeros((self.state_dim, self.state_dim), dtype=np.float32) * self._dt + np.eye(self.state_dim)
55 | self._B = np.array([[1.0, 0.0], [0.0, 1.0]]) * self._dt
56 | self._Q = np.eye(self.state_dim) * 2
57 | self._R = np.eye(self.action_dim)
58 | self._K = jnp.array(lqr(self._A, self._B, self._Q, self._R))
59 | self.create_obstacles = jax_vmap(Rectangle.create)
60 |
    @property
    def state_dim(self) -> int:
        """Dimension of one agent state: position ``(x, y)``."""
        return 2  # x, y
64 |
    @property
    def node_dim(self) -> int:
        """Dimension of the one-hot node-type feature (agent / goal / obstacle)."""
        return 3  # indicator: agent: 001, goal: 010, obstacle: 100
68 |
    @property
    def edge_dim(self) -> int:
        """Dimension of one edge feature: relative position ``(x_rel, y_rel)``."""
        return 2  # x_rel, y_rel
72 |
    @property
    def action_dim(self) -> int:
        """Dimension of one agent action: velocity ``(vx, vy)``."""
        return 2  # vx, vy
76 |
77 | def reset(self, key: Array) -> GraphsTuple:
78 | self._t = 0
79 |
80 | # randomly generate obstacles
81 | n_rng_obs = self._params["n_obs"]
82 | assert n_rng_obs >= 0
83 | obstacle_key, key = jr.split(key, 2)
84 | obs_pos = jr.uniform(obstacle_key, (n_rng_obs, 2), minval=0, maxval=self.area_size)
85 | length_key, key = jr.split(key, 2)
86 | obs_len = jr.uniform(
87 | length_key,
88 | (self._params["n_obs"], 2),
89 | minval=self._params["obs_len_range"][0],
90 | maxval=self._params["obs_len_range"][1],
91 | )
92 | theta_key, key = jr.split(key, 2)
93 | obs_theta = jr.uniform(theta_key, (n_rng_obs,), minval=0, maxval=2 * np.pi)
94 | obstacles = self.create_obstacles(obs_pos, obs_len[:, 0], obs_len[:, 1], obs_theta)
95 |
96 | # randomly generate agent and goal
97 | states, goals = get_node_goal_rng(
98 | key, self.area_size, 2, obstacles, self.num_agents, 4 * self.params["car_radius"], self.max_travel)
99 |
100 | env_states = self.EnvState(states, goals, obstacles)
101 |
102 | return self.get_graph(env_states)
103 |
104 | def agent_step_euler(self, agent_states: AgentState, action: Action) -> AgentState:
105 | assert action.shape == (self.num_agents, self.action_dim)
106 | assert agent_states.shape == (self.num_agents, self.state_dim)
107 | x_dot = action
108 | n_state_agent_new = x_dot * self.dt + agent_states
109 | assert n_state_agent_new.shape == (self.num_agents, self.state_dim)
110 | return self.clip_state(n_state_agent_new)
111 |
112 | def step(
113 | self, graph: EnvGraphsTuple, action: Action, get_eval_info: bool = False
114 | ) -> Tuple[EnvGraphsTuple, Reward, Cost, Done, Info]:
115 | self._t += 1
116 |
117 | # calculate next graph
118 | agent_states = graph.type_states(type_idx=0, n_type=self.num_agents)
119 | goals = graph.type_states(type_idx=1, n_type=self.num_agents)
120 | obstacles = graph.env_states.obstacle
121 | action = self.clip_action(action)
122 |
123 | assert action.shape == (self.num_agents, self.action_dim)
124 | assert agent_states.shape == (self.num_agents, self.state_dim)
125 |
126 | next_agent_states = self.agent_step_euler(agent_states, action)
127 |
128 | # the episode ends when reaching max_episode_steps
129 | done = jnp.array(False)
130 |
131 | # compute reward and cost
132 | reward = jnp.zeros(()).astype(jnp.float32)
133 | reward -= (jnp.linalg.norm(action - self.u_ref(graph), axis=1) ** 2).mean()
134 | cost = self.get_cost(graph)
135 |
136 | assert reward.shape == tuple()
137 | assert cost.shape == tuple()
138 | assert done.shape == tuple()
139 |
140 | next_state = self.EnvState(next_agent_states, goals, obstacles)
141 |
142 | info = {}
143 | if get_eval_info:
144 | # collision between agents and obstacles
145 | agent_pos = agent_states
146 | info["inside_obstacles"] = inside_obstacles(agent_pos, obstacles, r=self._params["car_radius"])
147 |
148 | return self.get_graph(next_state), reward, cost, done, info
149 |
150 | def get_cost(self, graph: GraphsTuple) -> Cost:
151 | agent_states = graph.type_states(type_idx=0, n_type=self.num_agents)
152 | obstacles = graph.env_states.obstacle
153 |
154 | # collision between agents
155 | agent_pos = agent_states
156 | dist = jnp.linalg.norm(jnp.expand_dims(agent_pos, 1) - jnp.expand_dims(agent_pos, 0), axis=-1)
157 | dist += jnp.eye(self.num_agents) * 1e6
158 | collision = (self._params["car_radius"] * 2 > dist).any(axis=1)
159 | cost = collision.mean()
160 |
161 | # collision between agents and obstacles
162 | collision = inside_obstacles(agent_pos, obstacles, r=self._params["car_radius"])
163 | cost += collision.mean()
164 |
165 | return cost
166 |
167 | def render_video(
168 | self,
169 | rollout: RolloutResult,
170 | video_path: pathlib.Path,
171 | Ta_is_unsafe=None,
172 | viz_opts: dict = None,
173 | dpi: int = 100,
174 | **kwargs
175 | ) -> None:
176 | render_video(
177 | rollout=rollout,
178 | video_path=video_path,
179 | side_length=self.area_size,
180 | dim=2,
181 | n_agent=self.num_agents,
182 | n_rays=self.params["n_rays"],
183 | r=self.params["car_radius"],
184 | Ta_is_unsafe=Ta_is_unsafe,
185 | viz_opts=viz_opts,
186 | dpi=dpi,
187 | **kwargs
188 | )
189 |
190 | def edge_blocks(self, state: EnvState, lidar_data: Pos2d) -> list[EdgeBlock]:
191 | n_hits = self._params["n_rays"] * self.num_agents
192 |
193 | # agent - agent connection
194 | agent_pos = state.agent
195 | pos_diff = agent_pos[:, None, :] - agent_pos[None, :, :] # [i, j]: i -> j
196 | dist = jnp.linalg.norm(pos_diff, axis=-1)
197 | dist += jnp.eye(dist.shape[1]) * (self._params["comm_radius"] + 1)
198 | agent_agent_mask = jnp.less(dist, self._params["comm_radius"])
199 | id_agent = jnp.arange(self.num_agents)
200 | agent_agent_edges = EdgeBlock(pos_diff, agent_agent_mask, id_agent, id_agent)
201 |
202 | # agent - goal connection, clipped to avoid too long edges
203 | id_goal = jnp.arange(self.num_agents, self.num_agents * 2)
204 | agent_goal_mask = jnp.eye(self.num_agents)
205 | agent_goal_feats = state.agent[:, None, :] - state.goal[None, :, :]
206 | feats_norm = jnp.sqrt(1e-6 + jnp.sum(agent_goal_feats[:, :2] ** 2, axis=-1, keepdims=True))
207 | comm_radius = self._params["comm_radius"]
208 | safe_feats_norm = jnp.maximum(feats_norm, comm_radius)
209 | coef = jnp.where(feats_norm > comm_radius, comm_radius / safe_feats_norm, 1.0)
210 | agent_goal_feats = agent_goal_feats.at[:, :2].set(agent_goal_feats[:, :2] * coef)
211 | agent_goal_edges = EdgeBlock(
212 | agent_goal_feats, agent_goal_mask, id_agent, id_goal
213 | )
214 |
215 | # agent - obs connection
216 | id_obs = jnp.arange(self.num_agents * 2, self.num_agents * 2 + n_hits)
217 | agent_obs_edges = []
218 | for i in range(self.num_agents):
219 | id_hits = jnp.arange(i * self._params["n_rays"], (i + 1) * self._params["n_rays"])
220 | lidar_feats = agent_pos[i, :] - lidar_data[id_hits, :]
221 | lidar_dist = jnp.linalg.norm(lidar_feats, axis=-1)
222 | active_lidar = jnp.less(lidar_dist, self._params["comm_radius"] - 1e-1)
223 | agent_obs_mask = jnp.ones((1, self._params["n_rays"]))
224 | agent_obs_mask = jnp.logical_and(agent_obs_mask, active_lidar)
225 | agent_obs_edges.append(
226 | EdgeBlock(lidar_feats[None, :, :], agent_obs_mask, id_agent[i][None], id_obs[id_hits])
227 | )
228 |
229 | return [agent_agent_edges, agent_goal_edges] + agent_obs_edges
230 |
231 | def control_affine_dyn(self, state: State) -> [Array, Array]:
232 | assert state.ndim == 2
233 | f = jnp.zeros_like(state)
234 | g = jnp.eye(state.shape[1])
235 | g = jnp.expand_dims(g, axis=0).repeat(f.shape[0], axis=0)
236 | assert f.shape == state.shape
237 | assert g.shape == (state.shape[0], self.state_dim, self.action_dim)
238 | return f, g
239 |
def add_edge_feats(self, graph: GraphsTuple, state: State) -> GraphsTuple:
    """
    Recompute the edge features of `graph` from `state` and store both on the graph.

    Each edge feature is state[receiver] - state[sender]. The first two
    components are rescaled so that their norm does not exceed
    self._params["comm_radius"]; shorter edges are left unchanged.

    Parameters
    ----------
    graph: GraphsTuple, must satisfy graph.is_single
    state: State, 2-D array indexed by graph.receivers / graph.senders

    Returns
    -------
    graph: GraphsTuple, copy of `graph` with `edges` and `states` replaced
    """
    assert graph.is_single
    assert state.ndim == 2

    radius = self._params["comm_radius"]
    rel_state = state[graph.receivers] - state[graph.senders]
    rel_pos = rel_state[:, :2]

    # the 1e-6 keeps the square root away from an exact zero
    dist = jnp.sqrt(1e-6 + jnp.sum(rel_pos ** 2, axis=-1, keepdims=True))
    # the denominator is clamped so the unused branch of `where` stays finite
    denom = jnp.maximum(dist, radius)
    scale = jnp.where(dist > radius, radius / denom, 1.0)
    rel_state = rel_state.at[:, :2].set(rel_pos * scale)

    return graph._replace(edges=rel_state, states=state)
252 |
def get_graph(self, state: EnvState) -> GraphsTuple:
    """
    Build the padded graph for an environment state.

    Nodes are ordered as agents, goals, then lidar hit points
    (n_rays per agent). Node features are one-hot over three columns:
    column 2 for agents, column 1 for goals, column 0 for lidar hits.

    Parameters
    ----------
    state: EnvState, provides `agent`, `goal` and `obstacle`

    Returns
    -------
    graph: GraphsTuple, result of GetGraph(...).to_padded()
    """
    n_agents = self.num_agents
    n_rays = self._params["n_rays"]
    n_hits = n_rays * n_agents
    n_nodes = 2 * n_agents + n_hits

    # node features
    node_feats = jnp.zeros((n_nodes, 3))
    node_feats = node_feats.at[:n_agents, 2].set(1)  # agent feats
    node_feats = node_feats.at[n_agents: 2 * n_agents, 1].set(1)  # goal feats
    node_feats = node_feats.at[-n_hits:, 0].set(1)  # obs feats

    # node type: agent nodes keep the initial value 0
    node_type = jnp.zeros(n_nodes, dtype=jnp.int32)
    node_type = node_type.at[n_agents: 2 * n_agents].set(SingleIntegrator.GOAL)
    node_type = node_type.at[-n_hits:].set(SingleIntegrator.OBS)

    # lidar scan for every agent, flattened over the first two axes by merge01
    lidar_fn = ft.partial(
        get_lidar,
        obstacles=state.obstacle,
        num_beams=n_rays,
        sense_range=self._params["comm_radius"],
    )
    lidar_data = merge01(jax_vmap(lidar_fn)(state.agent))

    # edge blocks
    edge_blocks = self.edge_blocks(state, lidar_data)

    # create graph
    all_states = jnp.concatenate([state.agent, state.goal, lidar_data], axis=0)
    graph = GetGraph(
        nodes=node_feats,
        node_type=node_type,
        edge_blocks=edge_blocks,
        env_states=state,
        states=all_states,
    )
    return graph.to_padded()
287 |
def state_lim(self, state: Optional[State] = None) -> Tuple[State, State]:
    """
    Return the (lower, upper) state limits.

    Both limits are length-2 arrays and are infinite; `state` is not used.
    """
    upper_lim = jnp.ones(2) * jnp.inf
    lower_lim = -upper_lim
    return lower_lim, upper_lim
292 |
def action_lim(self) -> Tuple[Action, Action]:
    """
    Return the (lower, upper) action limits.

    Both limits are length-2 arrays: -1 for the lower bound, +1 for the upper bound.
    """
    upper_lim = jnp.ones(2)
    lower_lim = -1.0 * upper_lim
    return lower_lim, upper_lim
297 |
def u_ref(self, graph: GraphsTuple) -> Action:
    """
    Compute the reference action that drives each agent towards its goal.

    The goal error is clipped component-wise so that its norm is at most
    self._params["comm_radius"], multiplied by the transposed gain self._K,
    and finally passed through self.clip_action.

    Parameters
    ----------
    graph: GraphsTuple, provides agent states (type 0) and goal states (type 1)

    Returns
    -------
    action: Action, one row per agent
    """
    agent = graph.type_states(type_idx=0, n_type=self.num_agents)
    goal = graph.type_states(type_idx=1, n_type=self.num_agents)
    error = goal - agent

    error_norm = jnp.linalg.norm(error, axis=-1, keepdims=True)
    # An agent sitting exactly on its goal has a zero norm; dividing by it gave
    # NaN bounds and therefore a NaN action. With the floor, the bound for such
    # an agent is zero and its (zero) error passes through unchanged. For any
    # non-degenerate error the result is the same as before.
    safe_norm = jnp.maximum(error_norm, 1e-6)
    error_max = jnp.abs(error / safe_norm * self._params["comm_radius"])
    error = jnp.clip(error, -error_max, error_max)

    return self.clip_action(error @ self._K.T)
305 |
def forward_graph(self, graph: GraphsTuple, action: Action) -> GraphsTuple:
    """
    Compute the graph after applying `action` for one step.

    Only the agent states are advanced (via self.agent_step_euler on the
    clipped action); goal and obstacle states are carried over. Edge
    features are then refreshed with self.add_edge_feats.

    Parameters
    ----------
    graph: GraphsTuple, current graph
    action: Action, shape (num_agents, action_dim)

    Returns
    -------
    next_graph: GraphsTuple
    """
    n_agents = self.num_agents
    n_obs = self._params["n_rays"] * n_agents

    agent_states = graph.type_states(type_idx=0, n_type=n_agents)
    goal_states = graph.type_states(type_idx=1, n_type=n_agents)
    obs_states = graph.type_states(type_idx=2, n_type=n_obs)
    clipped_action = self.clip_action(action)

    assert clipped_action.shape == (n_agents, self.action_dim)
    assert agent_states.shape == (n_agents, self.state_dim)

    stepped_agents = self.agent_step_euler(agent_states, clipped_action)
    next_states = jnp.concatenate([stepped_agents, goal_states, obs_states], axis=0)

    return self.add_edge_feats(graph, next_states)
322 |
@ft.partial(jax.jit, static_argnums=(0,))
def safe_mask(self, graph: GraphsTuple) -> Array:
    """
    Return a per-agent boolean mask that is True where the agent is safe.

    An agent is safe when every other agent is farther than 2.5 * car_radius
    away and the agent is not inside any obstacle inflated by 1.5 * car_radius.
    """
    radius = self._params["car_radius"]
    agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)

    # pairwise distances, [i, j]: i -> j
    pairwise = jnp.linalg.norm(agent_pos[:, None, :] - agent_pos[None, :, :], axis=-1)
    # push the diagonal away from zero so an agent is not compared with itself
    pairwise = pairwise + jnp.eye(pairwise.shape[1]) * (radius * 2 + 1)
    far_from_agents = jnp.min(jnp.greater(pairwise, radius * 2.5), axis=1)

    near_obstacle = inside_obstacles(agent_pos, graph.env_states.obstacle, radius * 1.5)
    clear_of_obstacles = jnp.logical_not(near_obstacle)

    return jnp.logical_and(far_from_agents, clear_of_obstacles)
341 |
@ft.partial(jax.jit, static_argnums=(0,))
def unsafe_mask(self, graph: GraphsTuple) -> Array:
    """
    Return a per-agent boolean mask that is True where the agent is in collision.

    An agent is unsafe when another agent is closer than 2 * car_radius or
    when it is inside an obstacle inflated by car_radius.
    """
    radius = self._params["car_radius"]
    agent_pos = graph.type_states(type_idx=0, n_type=self.num_agents)

    # pairwise distances, [i, j]: i -> j
    pairwise = jnp.linalg.norm(agent_pos[:, None, :] - agent_pos[None, :, :], axis=-1)
    # push the diagonal away from zero so an agent is not compared with itself
    pairwise = pairwise + jnp.eye(pairwise.shape[1]) * (radius * 2 + 1)
    hits_agent = jnp.max(jnp.less(pairwise, radius * 2), axis=1)

    # agents are colliding with obstacles
    hits_obstacle = inside_obstacles(agent_pos, graph.env_states.obstacle, radius)

    return jnp.logical_or(hits_agent, hits_obstacle)
359 |
def collision_mask(self, graph: GraphsTuple) -> Array:
    """Return the collision mask of `graph`; this is exactly self.unsafe_mask(graph)."""
    colliding = self.unsafe_mask(graph)
    return colliding
362 |
def finish_mask(self, graph: GraphsTuple) -> Array:
    """
    Return a per-agent boolean mask that is True where the agent has reached its goal.

    An agent has reached its goal when the distance between the first two
    components of its state and of graph.env_states.goal is below 2 * car_radius.
    """
    agent_xy = graph.type_states(type_idx=0, n_type=self.num_agents)[:, :2]
    goal_xy = graph.env_states.goal[:, :2]
    dist_to_goal = jnp.linalg.norm(agent_xy - goal_xy, axis=1)
    return dist_to_goal < self._params["car_radius"] * 2
368 |
--------------------------------------------------------------------------------
/gcbfplus/env/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax.numpy as jnp
3 | import functools as ft
4 | import jax
5 | import jax.random as jr
6 |
7 | from scipy.linalg import inv, solve_discrete_are
8 | from typing import Callable, Tuple
9 | from jax.lax import while_loop
10 |
11 | from ..utils.typing import Array, Radius, BoolScalar, Pos, State, Action, PRNGKey
12 | from ..utils.utils import merge01
13 | from .obstacle import Obstacle, Rectangle, Cuboid, Sphere
14 |
15 |
def RK4_step(x_dot_fn: Callable, x: State, u: Action, dt: float) -> Array:
    """
    Advance `x` by one classical fourth-order Runge-Kutta step of size `dt`.

    Parameters
    ----------
    x_dot_fn: Callable, (x, u) -> time derivative of x
    x: State, current state
    u: Action, control input, held constant over the step
    dt: float, step size

    Returns
    -------
    x_next: Array, state after the step
    """
    half_dt = 0.5 * dt
    slope_start = x_dot_fn(x, u)
    slope_mid_a = x_dot_fn(x + half_dt * slope_start, u)
    slope_mid_b = x_dot_fn(x + half_dt * slope_mid_a, u)
    slope_end = x_dot_fn(x + dt * slope_mid_b, u)
    weighted_sum = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return x + dt / 6.0 * weighted_sum
22 |
23 |
def lqr(
        A: np.ndarray,
        B: np.ndarray,
        Q: np.ndarray,
        R: np.ndarray,
):
    """
    Compute the discrete-time LQR gain.
    x_{t+1} = A x_t + B u_t
    cost = sum x.T*Q*x + u.T*R*u
    Code adapted from Mark Wilfred Mueller's continuous LQR code at
    https://www.mwm.im/lqr-controllers-with-python/
    Based on Bertsekas, p.151
    Yields the control law u = -K x
    """
    # solution of the discrete algebraic Riccati equation
    riccati = solve_discrete_are(A, B, Q, R)

    # LQR gain: (B^T X B + R)^-1 (B^T X A)
    bt_x = B.T @ riccati
    gain = inv(bt_x @ B + R) @ (bt_x @ A)

    return gain
47 |
48 |
def get_lidar(start_point: Pos, obstacles: Obstacle, num_beams: int, sense_range: float, max_returns: int = 32):
    """
    Cast rays of length `sense_range` from `start_point` and return the ray-traced points.

    For Rectangle obstacles, `num_beams` rays are spread over the plane.
    For Cuboid or Sphere obstacles, rays cover a grid of (num_beams // 2)
    elevation angles times `num_beams` azimuth angles, plus one ray straight
    up and one straight down.

    Parameters
    ----------
    start_point: Pos, origin of all rays
    obstacles: Obstacle, Rectangle, Cuboid or Sphere obstacles
    num_beams: int, number of angular samples
    sense_range: float, length of each ray
    max_returns: int, forwarded to `raytracing`

    Returns
    -------
    sensor_data: result of `raytracing` on the generated rays

    Raises
    ------
    NotImplementedError: if `obstacles` is of another type
    """
    if isinstance(obstacles, Rectangle):
        beam_step = 2 * np.pi / num_beams
        thetas = jnp.linspace(-np.pi, np.pi - beam_step, num_beams)
        starts = start_point[None, :].repeat(num_beams, axis=0)
        end_x = starts[..., 0] + jnp.cos(thetas) * sense_range
        end_y = starts[..., 1] + jnp.sin(thetas) * sense_range
        ends = jnp.stack([end_x, end_y], axis=-1)
    elif isinstance(obstacles, (Cuboid, Sphere)):
        beam_step = 2 * np.pi / num_beams
        thetas = jnp.linspace(-np.pi / 2 + beam_step, np.pi / 2 - beam_step, num_beams // 2)
        phis = jnp.linspace(-np.pi, np.pi - beam_step, num_beams)
        # grid rays plus the two vertical rays appended below
        starts = start_point[None, :].repeat(thetas.shape[0] * phis.shape[0] + 2, axis=0)

        def beam_end(theta, phi):
            return jnp.array([
                start_point[0] + jnp.cos(theta) * jnp.cos(phi) * sense_range,
                start_point[1] + jnp.cos(theta) * jnp.sin(phi) * sense_range,
                start_point[2] + jnp.sin(theta) * sense_range
            ])

        def beam_ends_for_theta(theta):
            return jax.vmap(lambda phi: beam_end(theta, phi))(phis)

        grid_ends = merge01(jax.vmap(beam_ends_for_theta)(thetas))
        end_up = start_point[None, :] + jnp.array([[0., 0., sense_range]])
        end_down = start_point[None, :] + jnp.array([[0., 0., -sense_range]])
        ends = jnp.concatenate([grid_ends, end_up, end_down], axis=0)
    else:
        raise NotImplementedError

    return raytracing(starts, ends, obstacles, max_returns)
80 |
81 |
def inside_obstacles(points: Pos, obstacles: Obstacle, r: Radius = 0.) -> BoolScalar:
    """
    Check whether points lie inside any obstacle.

    points: (n, n_dim) or (n_dim, )
    obstacles: tree_stacked obstacles.
    r: radius forwarded to each obstacle's `inside` check.

    Returns: (n, ) or (,). True if in collision, false otherwise.
    """
    single_point = points.ndim == 1

    # no obstacles at all: nothing can be in collision
    if obstacles.center.shape[0] == 0:
        if single_point:
            return jnp.zeros((), dtype=bool)
        return jnp.zeros(points.shape[0], dtype=bool)

    # one point against every obstacle, reduced with max (i.e. "any")
    def hits_any(point: Pos):
        return jax.vmap(lambda obstacle: obstacle.inside(point, r))(obstacles).max()

    if single_point:
        return hits_any(points)
    return jax.vmap(hits_any)(points)
108 |
109 |
def raytracing(starts: Pos, ends: Pos, obstacles: Obstacle, max_returns: int) -> Pos:
    """
    Trace rays from `starts` towards `ends` and return up to `max_returns` points.

    For every ray a scale factor alpha is computed as the minimum of
    obstacle.raytracing(start, end) over all obstacles (1e6 when there are
    no obstacles). Each returned point is start + (end - start) * alpha,
    and the points with the smallest alpha are returned first.
    """
    # if the start point is inside an obstacle, alpha becomes 0 and the start point is returned
    started_inside = inside_obstacles(starts, obstacles)

    if obstacles.center.shape[0] == 0:
        alphas = jnp.ones(starts.shape[0]) * 1e6
    else:
        def nearest_hit(start: Pos, end: Pos):
            return jax.vmap(lambda obstacle: obstacle.raytracing(start, end))(obstacles).min()

        alphas = jax.vmap(nearest_hit)(starts, ends)
    alphas = alphas * (1 - started_inside)

    # keep the rays with the smallest alpha
    order = jnp.argsort(alphas)[:max_returns]
    hit_points = starts + (ends - starts) * (alphas[..., None])

    return hit_points[order]
132 |
133 |
def get_node_goal_rng(
        key: PRNGKey,
        side_length: float,
        dim: int,
        obstacles: Obstacle,
        n: int,
        min_dist: float,
        max_travel: float = None
) -> [Pos, Pos]:
    """
    Sample `n` start positions and `n` goal positions by rejection sampling.

    Agents are placed one at a time. A start candidate is rejected while it is
    within `min_dist` of any row of the current start array or inside an
    obstacle inflated by `min_dist`. A goal candidate is additionally rejected
    when it lies outside [0, side_length]^dim or, if `max_travel` is given,
    farther than `max_travel` from its agent's start. Each inner search gives
    up after 1024 iterations; when that happens all placements are cleared and
    the procedure restarts from the first agent.

    Parameters
    ----------
    key: PRNGKey
    side_length: float, positions are sampled in [0, side_length]^dim
    dim: int, dimension of a position
    obstacles: Obstacle, obstacles to avoid
    n: int, number of agents
    min_dist: float, minimum separation between starts and between goals
    max_travel: float, optional bound on the start-to-goal distance

    Returns
    -------
    states: Pos, (n, dim) start positions
    goals: Pos, (n, dim) goal positions
    """
    max_iter = 1024  # maximum number of iterations to find a valid initial state/goal
    # NOTE: rows that are not assigned yet are zeros and still take part in the distance checks below
    states = jnp.zeros((n, dim))
    goals = jnp.zeros((n, dim))

    def get_node(reset_input: Tuple[int, Array, Array, Array]):  # key, node, all nodes
        # draw a fresh start candidate and advance the iteration counter
        i_iter, this_key, _, all_nodes = reset_input
        use_key, this_key = jr.split(this_key, 2)
        i_iter += 1
        return i_iter, this_key, jr.uniform(use_key, (dim,), minval=0, maxval=side_length), all_nodes

    def non_valid_node(reset_input: Tuple[int, Array, Array, Array]):  # key, node, all nodes
        # loop condition: True while the start candidate must be re-drawn
        i_iter, _, node, all_nodes = reset_input
        dist_min = jnp.linalg.norm(all_nodes - node, axis=1).min()
        collide = dist_min <= min_dist
        inside = inside_obstacles(node, obstacles, r=min_dist)
        valid = ~(collide | inside) | (i_iter >= max_iter)
        return ~valid

    def get_goal(reset_input: Tuple[int, Array, Array, Array, Array]):
        # key, goal_candidate, agent_start_pos, all_goals
        # draw a fresh goal candidate and advance the iteration counter
        i_iter, this_key, _, agent, all_goals = reset_input
        use_key, this_key = jr.split(this_key, 2)
        i_iter += 1
        if max_travel is None:
            return i_iter, this_key, jr.uniform(use_key, (dim,), minval=0, maxval=side_length), agent, all_goals
        else:
            return i_iter, this_key, jr.uniform(
                use_key, (dim,), minval=-max_travel, maxval=max_travel) + agent, agent, all_goals

    def non_valid_goal(reset_input: Tuple[int, Array, Array, Array, Array]):
        # key, goal_candidate, agent_start_pos, all_goals
        # loop condition: True while the goal candidate must be re-drawn
        i_iter, _, goal, agent, all_goals = reset_input
        dist_min = jnp.linalg.norm(all_goals - goal, axis=1).min()
        collide = dist_min <= min_dist
        inside = inside_obstacles(goal, obstacles, r=min_dist)
        outside = jnp.any(goal < 0) | jnp.any(goal > side_length)
        if max_travel is None:
            too_long = np.array(False, dtype=bool)
        else:
            too_long = jnp.linalg.norm(goal - agent) > max_travel
        valid = (~collide & ~inside & ~outside & ~too_long) | (i_iter >= max_iter)
        out = ~valid
        assert out.shape == tuple() and out.dtype == jnp.bool_
        return out

    def reset_body(reset_input: Tuple[int, Array, Array, Array]):
        # agent_id, key, states, goals
        agent_id, this_key, all_states, all_goals = reset_input
        agent_key, goal_key, this_key = jr.split(this_key, 3)
        agent_candidate = jr.uniform(agent_key, (dim,), minval=0, maxval=side_length)
        n_iter_agent, _, agent_candidate, _ = while_loop(
            cond_fun=non_valid_node, body_fun=get_node,
            init_val=(0, agent_key, agent_candidate, all_states)
        )
        all_states = all_states.at[agent_id].set(agent_candidate)

        if max_travel is None:
            goal_candidate = jr.uniform(goal_key, (dim,), minval=0, maxval=side_length)
        else:
            # Sample the first candidate from the same symmetric box as the
            # retries in get_goal; using [0, max_travel] here biased accepted
            # goals towards positive offsets from the agent.
            goal_candidate = jr.uniform(
                goal_key, (dim,), minval=-max_travel, maxval=max_travel) + agent_candidate

        n_iter_goal, _, goal_candidate, _, _ = while_loop(
            cond_fun=non_valid_goal, body_fun=get_goal,
            init_val=(0, goal_key, goal_candidate, agent_candidate, all_goals)
        )
        all_goals = all_goals.at[agent_id].set(goal_candidate)
        agent_id += 1

        # if no solution is found, start over: the factor is 0 when either search
        # hit max_iter, which resets the agent counter and clears all placements
        agent_id = (1 - (n_iter_agent >= max_iter)) * (1 - (n_iter_goal >= max_iter)) * agent_id
        all_states = (1 - (n_iter_agent >= max_iter)) * (1 - (n_iter_goal >= max_iter)) * all_states
        all_goals = (1 - (n_iter_agent >= max_iter)) * (1 - (n_iter_goal >= max_iter)) * all_goals

        return agent_id, this_key, all_states, all_goals

    def reset_not_terminate(reset_input: Tuple[int, Array, Array, Array]):
        # agent_id, key, states, goals
        agent_id, this_key, all_states, all_goals = reset_input
        return agent_id < n

    _, _, states, goals = while_loop(
        cond_fun=reset_not_terminate, body_fun=reset_body, init_val=(0, key, states, goals))

    return states, goals
227 |
--------------------------------------------------------------------------------
/gcbfplus/nn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/nn/__init__.py
--------------------------------------------------------------------------------
/gcbfplus/nn/gnn.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | import functools as ft
3 | import jax.numpy as jnp
4 | import jraph
5 | import jax.tree_util as jtu
6 |
7 | from typing import Type, NamedTuple, Callable, Tuple
8 |
9 | from ..utils.typing import EdgeAttr, Node, Array
10 | from ..utils.graph import GraphsTuple
11 | from .mlp import MLP, default_nn_init
12 | from .utils import safe_get
13 |
# Module-level flag; it is only changed through save_set_attn().
save_attn = False


def save_set_attn(v):
    """Set the module-level `save_attn` flag to `v`."""
    global save_attn
    save_attn = v
20 |
21 |
class GNNUpdate(NamedTuple):
    """
    One message-passing step assembled from three callables.

    Calling the instance on a single graph computes a message per edge,
    aggregates the messages per receiving node and updates the node features.
    """
    # (edge features, sender node features, receiver node features) -> message
    message: Callable[[EdgeAttr, Node, Node], Array]
    # (messages, receiver indices, number of nodes) -> aggregated message per node
    aggregate: Callable[[Array, Array, int], Array]
    # (node features, aggregated messages) -> new node features
    update: Callable[[Node, Array], Array]

    def __call__(self, graph: GraphsTuple) -> GraphsTuple:
        """Apply message, aggregate and update (in that order) and return the graph with new nodes."""
        assert graph.n_node.shape == tuple()

        def gather_nodes(indices):
            # node features looked up at `indices` through safe_get
            return jtu.tree_map(lambda feats: safe_get(feats, indices), graph.nodes)

        sender_feats = gather_nodes(graph.senders)
        receiver_feats = gather_nodes(graph.receivers)

        # message passing
        msgs = self.message(graph.edges, sender_feats, receiver_feats)

        # aggregate messages
        def pool(msg):
            return self.aggregate(msg, graph.receivers, graph.nodes.shape[0])

        pooled_msgs = jtu.tree_map(pool, msgs)

        # update nodes
        updated_nodes = self.update(graph.nodes, pooled_msgs)

        return graph._replace(nodes=updated_nodes)
42 |
43 |
class GNNLayer(nn.Module):
    """
    A single attention-based message-passing layer.

    Messages are computed from concatenated edge, sender and receiver
    features, weighted by a softmax over the edges sharing a receiver,
    summed per receiver and used to update the node features.
    """
    msg_net_cls: Type[nn.Module]  # network applied to the concatenated message inputs
    aggr_net_cls: Type[nn.Module]  # network producing the attention gate features
    update_net_cls: Type[nn.Module]  # network applied to [node features, aggregated message]
    msg_dim: int  # output size of the message projection
    out_dim: int  # output size of the node update projection

    @nn.compact
    def __call__(self, graph: GraphsTuple) -> GraphsTuple:
        def message(edge_feats: EdgeAttr, sender_feats: Node, receiver_feats: Node) -> Array:
            msg_in = jnp.concatenate([edge_feats, sender_feats, receiver_feats], axis=-1)
            hidden = self.msg_net_cls()(msg_in)
            return nn.Dense(self.msg_dim, kernel_init=default_nn_init())(hidden)

        def update(node_feats: Node, msgs: Array) -> Array:
            update_in = jnp.concatenate([node_feats, msgs], axis=-1)
            hidden = self.update_net_cls()(update_in)
            return nn.Dense(self.out_dim, kernel_init=default_nn_init())(hidden)

        def aggregate(msgs: Array, recv_idx: Array, num_segments: int) -> Array:
            # one scalar gate per message, normalised over messages with the same receiver
            gate_hidden = self.aggr_net_cls()(msgs)
            gate = nn.Dense(1, kernel_init=default_nn_init())(gate_hidden).squeeze(-1)
            attn = jraph.segment_softmax(gate, segment_ids=recv_idx, num_segments=num_segments)
            assert attn.shape[0] == msgs.shape[0]

            weighted_msgs = attn[:, None] * msgs
            return jraph.segment_sum(weighted_msgs, segment_ids=recv_idx, num_segments=num_segments)

        layer_step = GNNUpdate(message, aggregate, update)
        return layer_step(graph)
76 |
77 |
class GNN(nn.Module):
    """
    A stack of `n_layers` GNNLayer modules.

    Every layer outputs `msg_dim` node features except the last one, which
    outputs `out_dim`.
    """
    msg_dim: int
    hid_size_msg: Tuple[int, ...]  # hidden sizes of the message MLP
    hid_size_aggr: Tuple[int, ...]  # hidden sizes of the attention MLP
    hid_size_update: Tuple[int, ...]  # hidden sizes of the update MLP
    out_dim: int
    n_layers: int

    @nn.compact
    def __call__(self, graph: GraphsTuple, node_type: int = None, n_type: int = None) -> Array:
        """
        Run all layers and return node features.

        Returns all node features when `node_type` is None, otherwise
        graph.type_nodes(node_type, n_type).
        """
        # the three MLP factories are identical for every layer
        msg_net = ft.partial(MLP, hid_sizes=self.hid_size_msg, act=nn.relu, act_final=False, name="msg")
        attn_net = ft.partial(MLP, hid_sizes=self.hid_size_aggr, act=nn.relu, act_final=False, name="attn")
        update_net = ft.partial(MLP, hid_sizes=self.hid_size_update, act=nn.relu, act_final=False, name="update")

        last_layer = self.n_layers - 1
        for layer_idx in range(self.n_layers):
            layer_out_dim = self.out_dim if layer_idx == last_layer else self.msg_dim
            graph = GNNLayer(
                msg_net_cls=msg_net,
                aggr_net_cls=attn_net,
                update_net_cls=update_net,
                msg_dim=self.msg_dim,
                out_dim=layer_out_dim,
            )(graph)

        if node_type is not None:
            return graph.type_nodes(node_type, n_type)
        return graph.nodes
105 |
--------------------------------------------------------------------------------
/gcbfplus/nn/mlp.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 |
3 | from .utils import default_nn_init, scaled_init, AnyFloat, HidSizes, ActFn, signal_last_enumerate
4 |
5 |
class MLP(nn.Module):
    """
    Multi-layer perceptron with one Dense layer per entry of `hid_sizes`.

    After each Dense layer, optional dropout, optional layer norm and the
    activation are applied in that order. They are skipped after the last
    layer when `act_final` is False.
    """
    hid_sizes: HidSizes
    act: ActFn = nn.relu
    act_final: bool = True
    use_layernorm: bool = False
    scale_final: float | None = None  # if set, scales the initial kernel of the last layer
    dropout_rate: float | None = None  # dropout is only added when this is > 0

    @nn.compact
    def __call__(self, x: AnyFloat, apply_dropout: bool = False) -> AnyFloat:
        has_dropout = self.dropout_rate is not None and self.dropout_rate > 0

        for is_last_layer, _, hid_size in signal_last_enumerate(self.hid_sizes):
            kernel_init = default_nn_init()
            if is_last_layer and self.scale_final is not None:
                kernel_init = scaled_init(kernel_init, self.scale_final)
            x = nn.Dense(hid_size, kernel_init=kernel_init)(x)

            # the last layer stays linear unless act_final is set
            if is_last_layer and not self.act_final:
                continue

            if has_dropout:
                x = nn.Dropout(rate=self.dropout_rate, deterministic=not apply_dropout)(x)
            if self.use_layernorm:
                x = nn.LayerNorm()(x)
            x = self.act(x)
        return x
31 |
--------------------------------------------------------------------------------
/gcbfplus/nn/utils.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | import jax.numpy as jnp
3 |
4 | from typing import Any, Callable, Literal, Sequence, Iterable, Generator, TypeVar
5 | from jaxtyping import Array, Bool, Float, Int, Shaped
6 |
7 |
# Type aliases used in annotations (mlp.py imports AnyFloat, HidSizes and ActFn from here).
# Activation function: array in, array out.
ActFn = Callable[[Array], Array]
# NOTE(review): the name looks like a misspelling of "PRNGKey", and it is annotated as a
# float array of shape (2,) - confirm the intended dtype before relying on it.
PRGNKey = Float[Array, '2']
AnyFloat = Float[Array, '*']
Shape = tuple[int, ...]
# Initializer signature as used by InitFn: (key, shape, third argument) -> values.
InitFn = Callable[[PRGNKey, Shape, Any], Any]
# Layer widths of an MLP.
HidSizes = Sequence[int]


# Element type of the iterable passed to signal_last_enumerate.
_Elem = TypeVar("_Elem")


# Initializer factory: call it (default_nn_init()) to obtain the kernel initializer.
default_nn_init = nn.initializers.xavier_uniform
20 |
21 |
def scaled_init(initializer: nn.initializers.Initializer, scale: float) -> nn.initializers.Initializer:
    """
    Wrap `initializer` so that every value it produces is multiplied by `scale`.

    The returned callable forwards all positional and keyword arguments to
    `initializer` unchanged.
    """
    def scaled_init_inner(*args, **kwargs) -> AnyFloat:
        base_values = initializer(*args, **kwargs)
        return scale * base_values

    return scaled_init_inner
27 |
28 |
# Names accepted by get_act_from_str; keep in sync with the table inside it.
ActLiteral = Literal["relu", "tanh", "elu", "swish", "silu", "gelu", "softplus"]


def get_act_from_str(act_str: ActLiteral) -> ActFn:
    """
    Return the activation function registered under `act_str`.

    Raises KeyError when `act_str` is not one of the supported names.
    """
    act_dict: dict[Literal, ActFn] = {
        "relu": nn.relu,
        "tanh": nn.tanh,
        "elu": nn.elu,
        "swish": nn.swish,
        "silu": nn.silu,
        "gelu": nn.gelu,
        "softplus": nn.softplus,
    }
    return act_dict[act_str]
37 |
38 |
def signal_last_enumerate(it: Iterable[_Elem]) -> Generator[tuple[bool, int, _Elem], None, None]:
    """
    Enumerate `it`, additionally flagging the final element.

    Yields (is_last, index, element) tuples; `is_last` is True only for the
    last element. An empty iterable yields nothing.
    """
    iterable = iter(it)
    try:
        ret_var = next(iterable)
    except StopIteration:
        # Empty input: finish cleanly. Letting StopIteration escape a generator
        # body is converted into RuntimeError (PEP 479).
        return

    count = 0
    # Stay one element behind the iterator so the last element is known when it is yielded.
    for val in iterable:
        yield False, count, ret_var
        count += 1
        ret_var = val
    yield True, count, ret_var
48 |
49 |
def safe_get(arr, idx):
    """Gather `arr` at `idx` using fill mode, so out-of-bounds lookups return NaN."""
    gathered = arr.at[idx].get(mode='fill', fill_value=jnp.nan)
    return gathered
52 |
--------------------------------------------------------------------------------
/gcbfplus/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/trainer/__init__.py
--------------------------------------------------------------------------------
/gcbfplus/trainer/buffer.py:
--------------------------------------------------------------------------------
1 | import jax.tree_util as jtu
2 | import numpy as np
3 |
4 | from abc import ABC, abstractproperty, abstractmethod
5 | from .data import Rollout
6 | from .utils import jax2np, np2jax
7 | from ..utils.utils import tree_merge
8 | from ..utils.typing import Array
9 |
10 |
class Buffer(ABC):
    """Abstract interface of a rollout buffer with a fixed capacity."""

    def __init__(self, size: int):
        # capacity of the buffer; how it is enforced is up to the subclass
        self._size = size

    @abstractmethod
    def append(self, rollout: Rollout):
        """Add `rollout` to the buffer."""
        pass

    @abstractmethod
    def sample(self, batch_size: int) -> Rollout:
        """Return `batch_size` samples from the buffer."""
        pass

    # `abstractproperty` is deprecated; stacking property on abstractmethod is the
    # supported spelling and keeps `length` an abstract read-only property.
    @property
    @abstractmethod
    def length(self) -> int:
        """Amount of data currently stored."""
        pass
27 |
28 |
class ReplayBuffer(Buffer):
    """
    Replay buffer that stores rollouts as numpy data and samples uniformly.

    When the stored rollout's `length` exceeds the capacity, only the most
    recent `size` entries along the leading axis are kept.
    """

    def __init__(self, size: int):
        super().__init__(size)
        self._buffer = None

    def append(self, rollout: Rollout):
        """Convert `rollout` with jax2np and merge it into the stored data."""
        incoming = jax2np(rollout)
        if self._buffer is None:
            self._buffer = incoming
        else:
            self._buffer = tree_merge([self._buffer, incoming])

        # drop the oldest entries once the capacity is exceeded
        if self._buffer.length > self._size:
            capacity = self._size
            self._buffer = jtu.tree_map(lambda x: x[-capacity:], self._buffer)

    def sample(self, batch_size: int) -> Rollout:
        """Draw `batch_size` indices uniformly (with replacement) and return the data via np2jax."""
        picked = np.random.randint(0, self._buffer.length, batch_size)
        return np2jax(self.get_data(picked))

    def get_data(self, idx: np.ndarray) -> Rollout:
        """Index every leaf of the stored rollout with `idx`."""
        return jtu.tree_map(lambda x: x[idx], self._buffer)

    @property
    def length(self) -> int:
        """`n_data` of the stored rollout, or 0 while the buffer is empty."""
        return 0 if self._buffer is None else self._buffer.n_data
55 |
56 |
class MaskedReplayBuffer:
    """
    Replay buffer that stores a safe mask and an unsafe mask next to each rollout.

    The rollout and both masks are merged, truncated and indexed together,
    so the same index refers to the same entry in all three.
    """

    def __init__(self, size: int):
        self._size = size
        # (b, T)
        self._buffer = None
        self._safe_mask = None
        self._unsafe_mask = None

    def append(self, rollout: Rollout, safe_mask: Array, unsafe_mask: Array):
        """Convert the inputs with jax2np and merge them into the stored data."""
        if self._buffer is None:
            self._buffer = jax2np(rollout)
            self._safe_mask = jax2np(safe_mask)
            self._unsafe_mask = jax2np(unsafe_mask)
        else:
            self._buffer = tree_merge([self._buffer, jax2np(rollout)])
            self._safe_mask = tree_merge([self._safe_mask, jax2np(safe_mask)])
            self._unsafe_mask = tree_merge([self._unsafe_mask, jax2np(unsafe_mask)])

        # drop the oldest entries once the capacity is exceeded
        if self._buffer.length > self._size:
            capacity = self._size

            def keep_latest(x):
                return x[-capacity:]

            self._buffer = jtu.tree_map(keep_latest, self._buffer)
            self._safe_mask = jtu.tree_map(keep_latest, self._safe_mask)
            self._unsafe_mask = jtu.tree_map(keep_latest, self._unsafe_mask)

    def sample(self, batch_size: int) -> [Rollout, Array, Array]:
        """Draw `batch_size` indices uniformly (with replacement) and return (rollout, safe_mask, unsafe_mask)."""
        picked = np.random.randint(0, self._buffer.length, batch_size)
        rollout, safe_mask, unsafe_mask = self.get_data(picked)
        return rollout, safe_mask, unsafe_mask

    def get_data(self, idx: np.ndarray) -> [Rollout, Array, Array]:
        """Index the stored rollout and both masks with `idx`."""
        rollout = jtu.tree_map(lambda x: x[idx], self._buffer)
        return rollout, self._safe_mask[idx], self._unsafe_mask[idx]

    @property
    def length(self) -> int:
        """`n_data` of the stored rollout, or 0 while the buffer is empty."""
        return 0 if self._buffer is None else self._buffer.n_data
94 |
--------------------------------------------------------------------------------
/gcbfplus/trainer/data.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 | from ..utils.typing import Array
4 | from ..utils.typing import Action, Reward, Cost, Done
5 | from ..utils.graph import GraphsTuple
6 |
7 |
class Rollout(NamedTuple):
    """
    Batch of transitions collected from the environment.

    All size properties are read from the shape of `rewards`, whose first
    three axes are interpreted as (length, time_horizon, num_agents).
    """
    graph: GraphsTuple
    actions: Action
    rewards: Reward
    costs: Cost
    dones: Done
    log_pis: Array
    next_graph: GraphsTuple

    @property
    def length(self) -> int:
        """Size of the first axis of `rewards`."""
        return self.rewards.shape[0]

    @property
    def time_horizon(self) -> int:
        """Size of the second axis of `rewards`."""
        return self.rewards.shape[1]

    @property
    def num_agents(self) -> int:
        """Size of the third axis of `rewards`."""
        return self.rewards.shape[2]

    @property
    def n_data(self) -> int:
        """Product of `length` and `time_horizon`."""
        n_rollouts = self.length
        n_steps = self.time_horizon
        return n_rollouts * n_steps
32 |
--------------------------------------------------------------------------------
/gcbfplus/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import os
3 | import numpy as np
4 | import jax
5 | import jax.random as jr
6 | import functools as ft
7 |
8 | from time import time
9 | from tqdm import tqdm
10 |
11 | from .data import Rollout
12 | from .utils import rollout
13 | from ..env import MultiAgentEnv
14 | from ..algo.base import MultiAgentController
15 | from ..utils.utils import jax_vmap
16 |
17 |
class Trainer:
    """
    Training driver: alternates evaluation, rollout collection and algorithm
    updates, and logs the resulting metrics to wandb.
    """

    def __init__(
            self,
            env: MultiAgentEnv,
            env_test: MultiAgentEnv,
            algo: MultiAgentController,
            n_env_train: int,
            n_env_test: int,
            log_dir: str,
            seed: int,
            params: dict,
            save_log: bool = True
    ):
        """
        Parameters
        ----------
        env: MultiAgentEnv, environment used to collect training rollouts
        env_test: MultiAgentEnv, environment used for evaluation
        algo: MultiAgentController, algorithm being trained
        n_env_train: int, number of rollouts collected per training step
        n_env_test: int, number of rollouts used per evaluation
        log_dir: str, directory for wandb files and, if `save_log`, the models
        seed: int, seed of the training and evaluation PRNG keys
        params: dict, must contain the keys checked by `_check_params`
        save_log: bool, whether to create the model directory and save models
        """
        self.env = env
        self.env_test = env_test
        self.algo = algo
        self.n_env_train = n_env_train
        self.n_env_test = n_env_test
        self.log_dir = log_dir
        self.seed = seed

        if Trainer._check_params(params):
            self.params = params

        # make dir for the models
        if save_log:
            self.model_dir = os.path.join(log_dir, 'models')
            # makedirs also creates a missing log_dir (and its parents) and,
            # with exist_ok, does not fail when the directory already exists
            os.makedirs(self.model_dir, exist_ok=True)

        wandb.login()
        wandb.init(name=params['run_name'], project='gcbf+', dir=self.log_dir)

        self.save_log = save_log

        self.steps = params['training_steps']
        self.eval_interval = params['eval_interval']
        self.eval_epi = params['eval_epi']
        self.save_interval = params['save_interval']

        self.update_steps = 0
        self.key = jax.random.PRNGKey(seed)

    @staticmethod
    def _check_params(params: dict) -> bool:
        """Assert that `params` holds all required keys with valid values; returns True."""
        assert 'run_name' in params, 'run_name not found in params'
        assert 'training_steps' in params, 'training_steps not found in params'
        assert 'eval_interval' in params, 'eval_interval not found in params'
        assert params['eval_interval'] > 0, 'eval_interval must be positive'
        assert 'eval_epi' in params, 'eval_epi not found in params'
        assert params['eval_epi'] >= 1, 'eval_epi must be greater than or equal to 1'
        assert 'save_interval' in params, 'save_interval not found in params'
        assert params['save_interval'] > 0, 'save_interval must be positive'
        return True

    def train(self):
        """
        Run the training loop for `self.steps + 1` iterations.

        Every `eval_interval` steps the current actor is evaluated on a fixed
        set of test keys and the metrics are logged; within such a step the
        model is also saved if `save_log` is set and the step is a multiple of
        `save_interval`. Each iteration then collects `n_env_train` rollouts
        and calls `algo.update`.
        """
        # record start time
        start_time = time()

        # preprocess the rollout function
        def rollout_fn_single(params, key):
            return rollout(self.env, ft.partial(self.algo.step, params=params), key)

        def rollout_fn(params, keys):
            return jax.vmap(ft.partial(rollout_fn_single, params))(keys)

        rollout_fn = jax.jit(rollout_fn)

        # preprocess the test function
        def test_fn_single(params, key):
            return rollout(self.env_test, lambda graph, k: (self.algo.act(graph, params), None), key)

        def test_fn(params, keys):
            return jax.vmap(ft.partial(test_fn_single, params))(keys)

        test_fn = jax.jit(test_fn)

        # start training; the test keys are fixed so every evaluation uses the same keys
        test_key = jr.PRNGKey(self.seed)
        test_keys = jr.split(test_key, 1_000)[:self.n_env_test]

        # the context manager closes the progress bar even if training raises
        with tqdm(total=self.steps, ncols=80) as pbar:
            for step in range(0, self.steps + 1):
                # evaluate the algorithm
                if step % self.eval_interval == 0:
                    test_rollouts: Rollout = test_fn(self.algo.actor_params, test_keys)
                    total_reward = test_rollouts.rewards.sum(axis=-1)
                    assert total_reward.shape == (self.n_env_test,)
                    reward_min, reward_max = total_reward.min(), total_reward.max()
                    reward_mean = np.mean(total_reward)
                    reward_final = np.mean(test_rollouts.rewards[:, -1])
                    finish_fun = jax_vmap(jax_vmap(self.env_test.finish_mask))
                    finish = finish_fun(test_rollouts.graph).max(axis=1).mean()
                    cost = test_rollouts.costs.sum(axis=-1).mean()
                    unsafe_frac = np.mean(test_rollouts.costs.max(axis=-1) >= 1e-6)
                    eval_info = {
                        "eval/reward": reward_mean,
                        "eval/reward_final": reward_final,
                        "eval/cost": cost,
                        "eval/unsafe_frac": unsafe_frac,
                        "eval/finish": finish,
                        "step": step,
                    }
                    wandb.log(eval_info, step=self.update_steps)
                    time_since_start = time() - start_time
                    eval_verbose = (f'step: {step:3}, time: {time_since_start:5.0f}s, reward: {reward_mean:9.4f}, '
                                    f'min/max reward: {reward_min:7.2f}/{reward_max:7.2f}, cost: {cost:8.4f}, '
                                    f'unsafe_frac: {unsafe_frac:6.2f}, finish: {finish:6.2f}')
                    tqdm.write(eval_verbose)
                    if self.save_log and step % self.save_interval == 0:
                        self.algo.save(os.path.join(self.model_dir), step)

                # collect rollouts
                key_x0, self.key = jax.random.split(self.key)
                key_x0 = jax.random.split(key_x0, self.n_env_train)
                rollouts: Rollout = rollout_fn(self.algo.actor_params, key_x0)

                # update the algorithm
                update_info = self.algo.update(rollouts, step)
                wandb.log(update_info, step=self.update_steps)
                self.update_steps += 1

                pbar.update(1)
144 |
--------------------------------------------------------------------------------
/gcbfplus/trainer/utils.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import jax.tree_util as jtu
3 | import jax
4 | import numpy as np
5 | import socket
6 | import matplotlib.pyplot as plt
7 | import functools as ft
8 | import seaborn as sns
9 | import optax
10 |
11 | from typing import Callable, TYPE_CHECKING
12 | from matplotlib.colors import CenteredNorm
13 |
14 | from ..utils.typing import PRNGKey
15 | from ..utils.graph import GraphsTuple
16 | from .data import Rollout
17 |
18 |
19 | if TYPE_CHECKING:
20 | from ..env import MultiAgentEnv
21 | else:
22 | MultiAgentEnv = None
23 |
24 |
def rollout(
        env: MultiAgentEnv,
        actor: Callable,
        key: PRNGKey
) -> Rollout:
    """
    Roll the actor out in the environment for one full episode.

    Parameters
    ----------
    env: MultiAgentEnv
        environment providing `reset`, `step` and `max_episode_steps`
    actor: Callable, [GraphsTuple, PRNGKey] -> [Action, LogPi]
    key: PRNGKey
        split once into a reset key and a key that is further split per step

    Returns
    -------
    data: Rollout
        graphs, actions, rewards, costs, dones, log_pis and next graphs,
        each stacked over the `env.max_episode_steps` scan steps
    """
    reset_key, step_key = jax.random.split(key)
    graph0 = env.reset(reset_key)
    n_steps = env.max_episode_steps

    def step_fn(cur_graph, actor_key):
        action, log_pi = actor(cur_graph, actor_key)
        # the info returned by the environment is not stored in the rollout
        nxt_graph, reward, cost, done, _ = env.step(cur_graph, action)
        transition = (cur_graph, action, reward, cost, done, log_pi, nxt_graph)
        return nxt_graph, transition

    step_keys = jax.random.split(step_key, n_steps)
    _, transitions = jax.lax.scan(step_fn, graph0, step_keys, length=n_steps)
    graphs, actions, rewards, costs, dones, log_pis, next_graphs = transitions
    return Rollout(graphs, actions, rewards, costs, dones, log_pis, next_graphs)
56 |
57 |
def has_nan(x):
    """Return a pytree shaped like `x` whose leaves tell whether that leaf contains a NaN."""
    def leaf_has_nan(leaf):
        return jnp.isnan(leaf).any()

    return jtu.tree_map(leaf_has_nan, x)
60 |
61 |
def has_any_nan(x):
    """Return a scalar boolean array that is true when any leaf of `x` contains a NaN."""
    leaf_flags, _ = jtu.tree_flatten(has_nan(x))
    return jnp.array(leaf_flags).any()
64 |
65 |
def compute_norm(grad):
    """Return the global L2 norm taken over every leaf of the pytree `grad`."""
    squared_sums = [jnp.sum(jnp.square(leaf)) for leaf in jtu.tree_leaves(grad)]
    total = sum(squared_sums)
    return jnp.sqrt(total)
68 |
69 |
def compute_norm_and_clip(grad, max_norm: float):
    """
    Clip the pytree `grad` to a global L2 norm of at most `max_norm`.

    Returns
    -------
    clipped_grad: pytree with the same structure as `grad`
    g_norm: the global norm of `grad` before clipping
    """
    g_norm = compute_norm(grad)
    # dividing by max(max_norm, g_norm) leaves small gradients untouched
    # and scales large ones down to exactly max_norm
    denom = jnp.maximum(max_norm, g_norm)

    def rescale(leaf):
        return (leaf / denom) * max_norm

    return jtu.tree_map(rescale, grad), g_norm
76 |
77 |
def tree_copy(tree):
    """Return a pytree with the same structure as `tree` where each leaf is `leaf.copy()`."""
    def copy_leaf(leaf):
        return leaf.copy()

    return jtu.tree_map(copy_leaf, tree)
80 |
81 |
def empty_grad_tx() -> optax.GradientTransformation:
    """
    Build a placeholder gradient transformation.

    Its init function returns `optax.EmptyState()` and its update function
    ignores all inputs and returns `(None, None)`.
    """
    def _init(params):
        return optax.EmptyState()

    def _update(updates, state, params=None):
        # no updates and no new state are produced
        return None, None

    return optax.GradientTransformation(_init, _update)
90 |
91 |
def jax2np(x):
    """Convert every leaf of the pytree `x` to a NumPy array via `np.array`."""
    return jtu.tree_map(np.array, x)
94 |
95 |
def np2jax(x):
    """Convert every leaf of the pytree `x` to a JAX array via `jnp.array`."""
    return jtu.tree_map(jnp.array, x)
98 |
99 |
def is_connected(timeout: float = 5.0):
    """
    Check for internet access by opening a TCP connection to www.google.com:80.

    Parameters
    ----------
    timeout: float
        seconds to wait for the connection before giving up, so that the
        check cannot block for the (much longer) operating-system default

    Returns
    -------
    connected: bool
        True if the connection was established; otherwise False, after
        printing 'No internet connection'
    """
    try:
        # the context manager closes the socket as soon as the probe succeeds
        with socket.create_connection(("www.google.com", 80), timeout=timeout):
            return True
    except OSError:
        # covers DNS failures, refused connections and timeouts
        pass
    print('No internet connection')
    return False
110 |
111 |
def plot_cbf(
        fig: plt.Figure,
        cbf: Callable,
        env: MultiAgentEnv,
        graph: GraphsTuple,
        agent_id: int,
        x_dim: int,
        y_dim: int,
) -> plt.Figure:
    """
    Draw the CBF value of one agent over a 2D slice of its state space.

    The state dimensions `x_dim` and `y_dim` of agent `agent_id` are swept
    over a 30 x 30 grid between the limits given by `env.state_lim`, all other
    states are kept as in `graph`. For every grid point the graph is rebuilt
    with `env.add_edge_feats` and `cbf` is evaluated on it.

    Parameters
    ----------
    fig: plt.Figure
        figure to draw on
    cbf: Callable
        called on a graph; its output is indexed as [agent, 1] here
    env: MultiAgentEnv
        provides `state_lim` and `add_edge_feats`
    graph: GraphsTuple
        graph whose states are used as the base for the sweep
    agent_id: int
        index of the agent whose state is swept and whose CBF value is shown
    x_dim: int
        state dimension shown on the horizontal axis
    y_dim: int
        state dimension shown on the vertical axis

    Returns
    -------
    fig: plt.Figure
        the figure that was passed in
    """
    ax = fig.gca()
    n_mesh = 30
    low_lim, high_lim = env.state_lim(graph.states)
    x, y = jnp.meshgrid(
        jnp.linspace(low_lim[x_dim], high_lim[x_dim], n_mesh),
        jnp.linspace(low_lim[y_dim], high_lim[y_dim], n_mesh)
    )
    states = graph.states

    # generate new states: tile the states over the grid, then overwrite the swept agent's coordinates
    plot_states = states[None, None, :, :].repeat(n_mesh, axis=0).repeat(n_mesh, axis=1)
    plot_states = plot_states.at[:, :, agent_id, x_dim].set(x)
    plot_states = plot_states.at[:, :, agent_id, y_dim].set(y)

    get_new_graph = env.add_edge_feats
    get_new_graph_vmap = jax.vmap(jax.vmap(ft.partial(get_new_graph, graph)))
    new_graph = get_new_graph_vmap(plot_states)
    h = jax.vmap(jax.vmap(cbf))(new_graph)[:, :, agent_id, :].squeeze(-1)
    # NOTE(review): the pyplot calls below draw on pyplot's current axes;
    # this presumably is the axes of `fig` -- confirm against the caller.
    plt.contourf(x, y, h, cmap=sns.color_palette("rocket", as_cmap=True), levels=15, alpha=0.5)
    plt.colorbar()
    # zero level set of the CBF
    plt.contour(x, y, h, levels=[0.0], colors='blue')
    # the axis limits must follow the plotted dimensions, not always dims 0 and 1,
    # otherwise the contour can lie outside the visible range
    ax.set_xlim(low_lim[x_dim], high_lim[x_dim])
    ax.set_ylim(low_lim[y_dim], high_lim[y_dim])
    plt.axis('off')

    return fig
147 |
148 |
def get_bb_cbf(cbf: Callable, env: MultiAgentEnv, graph: GraphsTuple, agent_id: int, x_dim: int, y_dim: int):
    """
    Evaluate the CBF of one agent on a 20 x 20 grid over [0, env.area_size]^2.

    The state dimensions `x_dim` and `y_dim` of agent `agent_id` are replaced
    by the grid coordinates, the graph is rebuilt with `env.add_edge_feats`
    for every grid point and `cbf` is evaluated on it.

    Returns
    -------
    b_xs: grid coordinates along `x_dim`, shape (20,)
    b_ys: grid coordinates along `y_dim`, shape (20,)
    bb_h: CBF values of the agent on the grid, shape (20, 20)
    """
    n_mesh = 20
    lower = jnp.array([0, 0])
    upper = jnp.array([env.area_size, env.area_size])
    b_xs = jnp.linspace(lower[x_dim], upper[x_dim], n_mesh)
    b_ys = jnp.linspace(lower[y_dim], upper[y_dim], n_mesh)
    grid_x, grid_y = jnp.meshgrid(b_xs, b_ys)

    # tile the states over the grid, then overwrite the chosen agent's two coordinates
    tiled = graph.states[None, None, :, :]
    tiled = tiled.repeat(n_mesh, axis=0).repeat(n_mesh, axis=1)
    tiled = tiled.at[:, :, agent_id, x_dim].set(grid_x)
    tiled = tiled.at[:, :, agent_id, y_dim].set(grid_y)

    rebuild_graphs = jax.vmap(jax.vmap(ft.partial(env.add_edge_feats, graph)))
    eval_cbf = jax.vmap(jax.vmap(cbf))
    h_all = eval_cbf(rebuild_graphs(tiled))
    bb_h = h_all[:, :, agent_id, :].squeeze(-1)
    assert bb_h.shape == (n_mesh, n_mesh)
    return b_xs, b_ys, bb_h
169 |
170 |
def centered_norm(vmin: float | list[float], vmax: float | list[float]):
    """
    Build a colour normalisation centred at zero that covers both bounds.

    Parameters
    ----------
    vmin: float | list[float]
        lower bound, or a list of lower bounds of which the minimum is used
    vmax: float | list[float]
        upper bound, or a list of upper bounds of which the maximum is used

    Returns
    -------
    norm: CenteredNorm
        centred at 0 with half range max(|vmin|, |vmax|)
    """
    if isinstance(vmin, list):
        vmin = min(vmin)
    if isinstance(vmax, list):
        # the reduced value must be stored in vmax; writing it to vmin would
        # discard the lower bound and leave vmax a list, breaking abs() below
        vmax = max(vmax)
    halfrange = max(abs(vmin), abs(vmax))
    return CenteredNorm(0, halfrange)
178 |
--------------------------------------------------------------------------------
/gcbfplus/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/gcbfplus/utils/__init__.py
--------------------------------------------------------------------------------
/gcbfplus/utils/graph.py:
--------------------------------------------------------------------------------
1 | from typing import Generic, NamedTuple, TypeVar, get_type_hints
2 |
3 | import einops as ei
4 | import jax.numpy as jnp
5 | import jax.tree_util as jtu
6 | from jax._src.tree_util import GetAttrKey
7 |
8 | from ..utils.typing import Any, Array, Bool, Float, Int
9 | from .utils import merge01
10 |
11 | _State = TypeVar("_State")
12 | _EnvState = TypeVar("_EnvState")
13 |
14 |
class EdgeBlock(NamedTuple):
    """Dense block of candidate edges between a set of receiver nodes and a set of sender nodes."""
    edge_feats: Float[Array, "n_recv n_send n_edge_feat"]
    edge_mask: Bool[Array, "n_recv n_send"]
    ids_recv: Int[Array, "n_recv"]
    ids_send: Int[Array, "n_send"]

    @property
    def n_recv(self):
        """Number of receivers; the leading axes of the features and mask must agree with it."""
        count = len(self.ids_recv)
        assert self.edge_feats.shape[0] == self.edge_mask.shape[0] == count
        return count

    @property
    def n_send(self):
        """Number of senders; the second axes of the features and mask must agree with it."""
        count = len(self.ids_send)
        assert self.edge_feats.shape[1] == self.edge_mask.shape[1] == count
        return count

    @property
    def n_edges(self):
        """Number of (receiver, sender) pairs in the block, masked or not."""
        return self.n_recv * self.n_send

    def make_edges(self, pad_id: int, edge_mask: Bool[Array, "n_recv n_send"] = None):
        """
        Flatten the block into edge lists of fixed length `n_edges`.

        Pairs whose mask entry is false get `pad_id` as both receiver and
        sender id. `edge_mask` overrides the stored mask when given.

        Returns
        -------
        e_edge_feats, e_recvs, e_sends: flattened features, receiver ids and sender ids
        """
        mask = self.edge_mask if edge_mask is None else edge_mask
        recv_grid = ei.repeat(self.ids_recv, "n_recv -> n_recv n_send", n_send=self.n_send)
        send_grid = ei.repeat(self.ids_send, "n_send -> n_recv n_send", n_recv=self.n_recv)
        e_recvs = merge01(jnp.where(mask, recv_grid, pad_id))
        e_sends = merge01(jnp.where(mask, send_grid, pad_id))
        e_edge_feats = merge01(self.edge_feats)
        assert e_recvs.shape == e_sends.shape == e_edge_feats.shape[:1] == (self.n_edges,)

        return e_edge_feats, e_recvs, e_sends
45 |
46 |
@jtu.register_pytree_with_keys_class
class GraphsTuple(tuple, Generic[_State, _EnvState]):
    """
    Graph container that is both a tuple and a registered JAX pytree.

    The fields are stored twice: as the tuple contents and as instance
    attributes of the same names. The pytree children are the attributes
    listed in the class annotations, in annotation order, which matches the
    positional order of `__new__`.
    """
    n_node: Int[Array, "n_graph"]  # number of nodes in each subgraph
    n_edge: Int[Array, "n_graph"]  # number of edges in each subgraph

    nodes: Float[Array, "sum_n_node ..."]  # node features
    edges: Float[Array, "sum_n_edge ..."]  # edge features
    states: _State  # node state features
    receivers: Int[Array, "sum_n_edge"]
    senders: Int[Array, "sum_n_edge"]
    node_type: Int[Array, "sum_n_node"]  # by default, 0 is agent, -1 is padding
    env_states: _EnvState  # environment state features
    connectivity: Int[Array, "sum_n_node sum_n_node"] = None  # desired connectivity matrix

    def __new__(
            cls,
            n_node,
            n_edge,
            nodes,
            edges,
            states: _State,
            receivers,
            senders,
            node_type,
            env_states: _EnvState,
            connectivity=None,
    ):
        """Create the tuple and mirror every field as an attribute of the same name."""
        tup = (n_node, n_edge, nodes, edges, states, receivers, senders, node_type, env_states, connectivity)
        self = tuple.__new__(cls, tup)
        self.n_node = n_node
        self.n_edge = n_edge
        self.nodes = nodes
        self.edges = edges
        self.states = states
        self.receivers = receivers
        self.senders = senders
        self.node_type = node_type
        self.env_states = env_states
        self.connectivity = connectivity
        return self

    def tree_flatten_with_keys(self):
        """Return the (key, value) pairs of all annotated fields as pytree children; no aux data is used."""
        # the annotation order of the class decides the child order, which tree_unflatten relies on
        flat_contents = [(GetAttrKey(k), getattr(self, k)) for k in get_type_hints(GraphsTuple).keys()]
        aux_data = None
        return flat_contents, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance by passing the children positionally to the constructor."""
        return cls(*children)

    @property
    def is_single(self) -> bool:
        """True when `n_node` is a scalar (zero-dimensional)."""
        return self.n_node.ndim == 0

    @property
    def n_graphs(self) -> int:
        """1 when `n_node` is a scalar, otherwise the length of `n_node`."""
        if self.n_node.ndim == 0:
            return 1
        assert len(self.n_node) == len(self.n_edge)
        return len(self.n_node)

    @property
    def batch_shape(self):
        """Shape of `n_node`."""
        return self.n_node.shape

    def type_nodes(self, type_idx: int, n_type: int) -> Float[Array, "... n_type n_feats"]:
        """
        Gather the features of all nodes whose `node_type` equals `type_idx`.

        `n_type` is the number of such nodes per graph; the result has shape
        `batch_shape + (n_type, n_feats)`. Requires 2D `nodes`.
        """
        assert self.nodes.ndim == 2
        n_feats = self.nodes.shape[1]

        n_is_type = self.node_type == type_idx
        # running count of matching nodes minus one = output row of each matching node
        idx = jnp.cumsum(n_is_type) - 1

        sum_n_type = self.n_graphs * n_type
        type_feats = jnp.zeros((sum_n_type, n_feats))
        # rows of non-matching nodes are multiplied by zero, so they add nothing
        type_feats = type_feats.at[idx, :].add(n_is_type[:, None] * self.nodes)

        out = type_feats.reshape(self.batch_shape + (n_type, n_feats))
        return out

    def type_states(self, type_idx: int, n_type: int) -> Float[Array, "... n_type n_states"]:
        """
        Gather the states of all nodes whose `node_type` equals `type_idx`.

        Same procedure as `type_nodes`, applied to `states`; the result has
        shape `batch_shape + (n_type, n_states)`. Requires 2D `states`.
        """
        assert self.states.ndim == 2
        n_states = self.states.shape[1]

        n_is_type = self.node_type == type_idx
        # running count of matching nodes minus one = output row of each matching node
        idx = jnp.cumsum(n_is_type) - 1

        sum_n_type = self.n_graphs * n_type
        type_feats = jnp.zeros((sum_n_type, n_states))
        # rows of non-matching nodes are multiplied by zero, so they add nothing
        type_feats = type_feats.at[idx, :].add(n_is_type[:, None] * self.states)

        out = type_feats.reshape(self.batch_shape + (n_type, n_states))
        return out

    def __str__(self) -> str:
        """Text dump of node/edge counts, node features, edge features, senders and receivers."""
        node_repr = str(self.nodes)
        edge_repr = str(self.edges)

        return "n_node={}, n_edge={}, \n{}\n---------\n{}\n-------\n{}\n | \n{}".format(
            self.n_node, self.n_edge, node_repr, edge_repr, self.senders, self.receivers
        )

    def _replace(
            self,
            n_node=None,
            n_edge=None,
            nodes=None,
            edges=None,
            states: _State = None,
            receivers=None,
            senders=None,
            node_type=None,
            env_states: _EnvState = None,
            connectivity=None,
    ) -> "GraphsTuple":
        """
        Return a copy with the given fields replaced.

        An argument left as None keeps the current value, so a field cannot
        be set to None through this method.
        """
        return GraphsTuple(
            self.n_node if n_node is None else n_node,
            self.n_edge if n_edge is None else n_edge,
            self.nodes if nodes is None else nodes,
            self.edges if edges is None else edges,
            self.states if states is None else states,
            self.receivers if receivers is None else receivers,
            self.senders if senders is None else senders,
            self.node_type if node_type is None else node_type,
            self.env_states if env_states is None else env_states,
            self.connectivity if connectivity is None else connectivity,
        )

    def without_edge(self):
        """Return a copy whose `edges` field is None; all other fields are kept."""
        return GraphsTuple(
            self.n_node,
            self.n_edge,
            self.nodes,
            None,
            self.states,
            self.receivers,
            self.senders,
            self.node_type,
            self.env_states,
            self.connectivity,
        )
187 |
188 |
class GetGraph(NamedTuple):
    """Unpadded description of a single graph that can be converted to a padded `GraphsTuple`."""
    nodes: Float[Array, "n_nodes n_node_feat"]  # node features
    node_type: Int[Array, "n_nodes"]  # by default, 0 is agent
    edge_blocks: list[EdgeBlock]
    env_states: Any
    states: Float[Array, "n_nodes n_state"]  # node state features
    connectivity: Int[Array, "n_node n_node"] = None  # desired connectivity matrix

    @property
    def n_nodes(self):
        """Number of rows of `nodes`."""
        n, _ = self.nodes.shape[0], None
        return n

    @property
    def node_dim(self) -> int:
        """Size of the second axis of `nodes`."""
        return self.nodes.shape[1]

    @property
    def state_dim(self) -> int:
        """Size of the second axis of `states`."""
        return self.states.shape[1]

    def to_padded(self) -> GraphsTuple:
        """
        Convert to a `GraphsTuple` with one extra padding node.

        The padding node has zero features, type -1 and state -1 and gets the
        id `n_nodes`. Masked-out entries of every edge block point at it, so
        the edge list has a fixed length.
        """
        # one dummy node appended at the end serves as the target of the fake edges
        pad_node = jnp.zeros(self.node_dim)[None]
        pad_state = (jnp.ones(self.state_dim) * -1)[None]
        padded_nodes = jnp.concatenate([self.nodes, pad_node], axis=0)
        padded_types = jnp.concatenate([self.node_type, jnp.full(1, -1)], axis=0)
        padded_states = jnp.concatenate([self.states, pad_state], axis=0)

        # flatten every edge block, then join the blocks into one edge list
        pad_id = self.n_nodes
        per_block = [block.make_edges(pad_id) for block in self.edge_blocks]
        all_feats = jnp.concatenate([entry[0] for entry in per_block], axis=0)
        all_recv = jnp.concatenate([entry[1] for entry in per_block])
        all_send = jnp.concatenate([entry[2] for entry in per_block])

        edge_count = all_feats.shape[0]
        assert all_recv.shape == all_send.shape == (edge_count,)
        node_count = self.n_nodes + 1

        return GraphsTuple(
            jnp.array(node_count, dtype=jnp.int32),
            jnp.array(edge_count, dtype=jnp.int32),
            padded_nodes,
            all_feats,
            padded_states,
            all_recv,
            all_send,
            padded_types,
            self.env_states,
            self.connectivity,
        )
245 |
--------------------------------------------------------------------------------
/gcbfplus/utils/typing.py:
--------------------------------------------------------------------------------
1 | from flax import core, struct
2 | from jaxtyping import Array, Bool, Float, Int, Shaped
3 | from typing import Dict, TypeVar, Any, List
4 | from numpy import ndarray
5 |
6 |
# jax types
# NOTE(review): declared as a Float array of shape (2,); raw JAX PRNG keys are
# presumably unsigned integers -- confirm that this alias is never checked at runtime.
PRNGKey = Float[Array, '2']

BoolScalar = Bool[Array, ""]  # zero-dimensional boolean
ABool = Bool[Array, "num_agents"]  # one boolean per agent

# environment types
Action = Float[Array, 'num_agents action_dim']
Reward = Float[Array, '']  # scalar
Cost = Float[Array, '']  # scalar
Done = BoolScalar
Info = Dict[str, Shaped[Array, '']]  # named scalars of any dtype
# NOTE(review): declared Float although the name suggests integer indices -- confirm.
EdgeIndex = Float[Array, '2 n_edge']
AgentState = Float[Array, 'num_agents agent_state_dim']
State = Float[Array, 'num_states state_dim']
Node = Float[Array, 'num_nodes node_dim']
EdgeAttr = Float[Array, 'num_edges edge_dim']
Pos2d = Float[Array, '2'] | Float[ndarray, '2']  # JAX or NumPy array
Pos3d = Float[Array, '3'] | Float[ndarray, '3']  # JAX or NumPy array
Pos = Pos2d | Pos3d
Radius = Float[Array, ''] | float


# neural network types
Params = TypeVar("Params", bound=core.FrozenDict[str, Any])

# obstacles (all scalars except the quaternion)
ObsType = Int[Array, '']
ObsWidth = Float[Array, '']
ObsHeight = Float[Array, '']
ObsLength = Float[Array, '']
ObsTheta = Float[Array, '']
ObsQuaternion = Float[Array, '4']
40 |
--------------------------------------------------------------------------------
/gcbfplus/utils/utils.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import jax.lax as lax
3 | import einops as ei
4 | import jax
5 | import jax.numpy as jnp
6 | import jax.tree_util as jtu
7 | import matplotlib.collections as mcollections
8 | import numpy as np
9 | import functools as ft
10 |
11 | from datetime import timedelta
12 | from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, Tuple, List, NamedTuple
13 |
14 | from jax import numpy as jnp, tree_util as jtu
15 | from jax._src.lib import xla_client as xc
16 | from matplotlib.animation import FuncAnimation
17 | from rich.progress import Progress, ProgressColumn
18 | from rich.text import Text
19 | from .typing import Array
20 |
21 |
def merge01(x):
    """Merge the first two axes of `x` into one; all trailing axes are kept."""
    merged = ei.rearrange(x, "n1 n2 ... -> (n1 n2) ...")
    return merged
24 |
25 |
# Typing helpers used by the function wrappers in this module.
_P = ParamSpec("_P")  # parameters of a wrapped callable
_R = TypeVar("_R")  # return type of a wrapped callable
_Fn = Callable[_P, _R]  # a callable with parameters _P returning _R

_PyTree = TypeVar("_PyTree")  # used to annotate pytree arguments and results
31 |
32 |
def jax_vmap(fn: _Fn, in_axes: int | Sequence[Any] = 0, out_axes: Any = 0) -> _Fn:
    """Thin typed wrapper that forwards `fn`, `in_axes` and `out_axes` to `jax.vmap`."""
    vmapped = jax.vmap(fn, in_axes=in_axes, out_axes=out_axes)
    return vmapped
35 |
36 |
def concat_at_front(arr1: jnp.ndarray, arr2: jnp.ndarray, axis: int) -> jnp.ndarray:
    """
    Prepend `arr1` to `arr2` along `axis`.

    :param arr1: (nx, )
    :param arr2: (T, nx)
    :param axis: Which axis for arr2 to concat under.
    :return: (T + 1, nx) with [arr1 arr2]
    """
    # Dropping `axis` from arr2's shape must give exactly arr1's shape.
    remaining_shape = list(arr2.shape)
    remaining_shape.pop(axis)
    assert np.all(np.array(remaining_shape) == np.array(arr1.shape))

    # stay in NumPy when arr1 is a NumPy array, use jax.numpy otherwise
    xp = np if isinstance(arr1, np.ndarray) else jnp
    expanded = xp.expand_dims(arr1, axis=axis)
    return xp.concatenate([expanded, arr2], axis=axis)
53 |
54 |
def tree_concat_at_front(tree1: _PyTree, tree2: _PyTree, axis: int) -> _PyTree:
    """Apply `concat_at_front` leaf by leaf: every leaf of `tree1` is prepended to the matching leaf of `tree2`."""
    concat_leaves = ft.partial(concat_at_front, axis=axis)
    return jtu.tree_map(concat_leaves, tree1, tree2)
60 |
61 |
def tree_index(tree: _PyTree, idx: int) -> _PyTree:
    """Index every leaf of `tree` with `idx` and return the resulting pytree."""
    def pick(leaf):
        return leaf[idx]

    return jtu.tree_map(pick, tree)
64 |
65 |
def jax2np(pytree: _PyTree) -> _PyTree:
    """Convert every leaf of `pytree` to a NumPy array via `np.array`."""
    def to_numpy(leaf):
        return np.array(leaf)

    return jtu.tree_map(to_numpy, pytree)
68 |
69 |
def np2jax(pytree: _PyTree) -> _PyTree:
    """Convert every leaf of `pytree` to a JAX array via `jnp.array`."""
    def to_jax(leaf):
        return jnp.array(leaf)

    return jtu.tree_map(to_jax, pytree)
72 |
73 |
def mask2index(mask: jnp.ndarray, n_true: int) -> jnp.ndarray:
    """Return the indices of the `n_true` largest entries of `mask`, as given by `lax.top_k`."""
    _, top_indices = lax.top_k(mask, n_true)
    return top_indices
77 |
78 |
def jax_jit_np(
        fn: _Fn,
        static_argnums: int | Sequence[int] | None = None,
        static_argnames: str | Iterable[str] | None = None,
        donate_argnums: int | Sequence[int] = (),
        device: xc.Device = None,
        *args,
        **kwargs,
) -> _Fn:
    """
    Jit `fn` and convert every leaf of its output to a NumPy array.

    The jit options are handed to `jax.jit` by keyword. Passing them
    positionally is not safe: in the JAX versions this project requires the
    positional slots after the function belong to other options
    (`in_shardings`, `out_shardings`), and newer releases only accept
    keywords there.

    Parameters
    ----------
    fn: function to compile
    static_argnums, static_argnames, donate_argnums: forwarded to `jax.jit`
    device: forwarded to `jax.jit` only when it is not None
    *args: forwarded positionally to `jax.jit` directly after `fn`
    **kwargs: further keyword options for `jax.jit`

    Returns
    -------
    wrapper: callable with the call signature of `fn` whose result is passed
        through `jax2np`
    """
    jit_options = dict(
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        donate_argnums=donate_argnums,
    )
    if device is not None:
        # only mention `device` when requested, so that the default call does
        # not depend on this option being available
        jit_options["device"] = device
    jit_fn = jax.jit(fn, *args, **jit_options, **kwargs)

    def wrapper(*call_args, **call_kwargs) -> _R:
        return jax2np(jit_fn(*call_args, **call_kwargs))

    return wrapper
94 |
95 |
def chunk_vmap(fn: _Fn, chunks: int) -> _Fn:
    """
    Vectorise `fn` over the leading axis, evaluating the batch in `chunks` pieces.

    The returned wrapper splits the batch indices with `np.array_split`, runs
    the jitted, vmapped function on each piece and merges the NumPy results
    along the leading axis with `tree_merge`.
    """
    batched_fn = jax_jit_np(jax.vmap(fn))

    def wrapper(*args) -> _R:
        arg_list = list(args)
        # 1: The batch size is the length of the first leaf of the first argument.
        first_leaf = jtu.tree_leaves(arg_list[0])[0]
        index_chunks = np.array_split(np.arange(len(first_leaf)), chunks)

        pieces = []
        for chunk_ids in index_chunks:
            sliced_args = jtu.tree_map(lambda leaf: leaf[chunk_ids], arg_list)
            pieces.append(batched_fn(*sliced_args))

        # 2: Concatenate the per-chunk outputs.
        return tree_merge(pieces)

    return wrapper
115 |
class MutablePatchCollection(mcollections.PatchCollection):
    """Patch collection that rebuilds its paths from the stored patches on every `get_paths` call."""

    def __init__(self, patches, *args, **kwargs):
        self._paths = None
        # keep a reference to the patches so that get_paths can re-read them
        self.patches = patches
        super().__init__(patches, *args, **kwargs)

    def get_paths(self):
        """Refresh the paths from `self.patches` and return them."""
        self.set_paths(self.patches)
        return self._paths
125 |
126 |
class CustomTimeElapsedColumn(ProgressColumn):
    """Renders time elapsed."""

    def render(self, task: "Task") -> Text:
        """Show time elapsed, truncated to whole milliseconds."""
        elapsed = task.finished_time if task.finished else task.elapsed
        if elapsed is None:
            return Text("-:--:--", style="progress.elapsed")
        raw = timedelta(seconds=elapsed)
        whole_ms = raw.microseconds // 1000
        # NOTE: only the seconds-within-day part of `raw` is used; a days part is not carried over
        truncated = timedelta(seconds=raw.seconds, milliseconds=whole_ms)
        return Text(str(truncated), style="progress.elapsed")
139 |
140 |
def save_anim(ani: FuncAnimation, path: pathlib.Path):
    """
    Save the animation to `path` while showing a progress bar.

    The progress bar is used as a context manager so that the live display is
    stopped even when `ani.save` raises.

    Parameters
    ----------
    ani: FuncAnimation
        animation to save; its `_save_count` is used as the total of the bar
    path: pathlib.Path
        output file passed to `ani.save`
    """
    with Progress(*Progress.get_default_columns(), CustomTimeElapsedColumn()) as pbar:
        task = pbar.add_task("Animating", total=ani._save_count)

        def progress_callback(curr_frame: int, total_frames: int):
            # called by ani.save; advance the bar by one step per call
            pbar.update(task, advance=1)

        ani.save(path, progress_callback=progress_callback)
151 |
152 |
def tree_merge(data: List[NamedTuple]):
    """
    Concatenate a list of equally structured pytrees leaf by leaf along axis 0.

    NumPy leaves are joined with NumPy, all other leaves with jax.numpy; the
    choice is made from the leaf of the first pytree.
    """
    def concat_leaves(*leaves):
        xp = np if isinstance(leaves[0], np.ndarray) else jnp
        return xp.concatenate(list(leaves), axis=0)

    return jtu.tree_map(concat_leaves, *data)
162 |
163 |
def tree_stack(trees: list):
    """
    Stack a list of equally structured pytrees leaf by leaf along a new axis 0.

    NumPy leaves are stacked with NumPy; all other leaves are stacked with
    jax.numpy so that JAX arrays are not silently converted to NumPy. The
    choice is made from the leaf of the first pytree, as in `tree_merge`.
    """
    def tree_stack_inner(*arrs):
        arrs = list(arrs)
        if isinstance(arrs[0], np.ndarray):
            return np.stack(arrs, axis=0)
        return jnp.stack(arrs, axis=0)

    return jtu.tree_map(tree_stack_inner, *trees)
172 |
--------------------------------------------------------------------------------
/media/DoubleIntegrator_512_2x.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/media/DoubleIntegrator_512_2x.gif
--------------------------------------------------------------------------------
/media/Obstacle2D_32.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/media/Obstacle2D_32.gif
--------------------------------------------------------------------------------
/media/Obstacle2D_512_2x.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/media/Obstacle2D_512_2x.gif
--------------------------------------------------------------------------------
/media/cbf1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/media/cbf1.gif
--------------------------------------------------------------------------------
/pretrained/CrazyFlie/gcbf+/config.yaml:
--------------------------------------------------------------------------------
1 | !!python/object:argparse.Namespace
2 | algo: gcbf+
3 | alpha: 1.0
4 | area_size: 2.0
5 | batch_size: 256
6 | buffer_size: 512
7 | cost_weight: 1.0
8 | debug: false
9 | env: CrazyFlie
10 | eval_epi: 1
11 | eval_interval: 1
12 | gnn_layers: 1
13 | horizon: 32
14 | log_dir: ./logs
15 | loss_action_coef: 3.0e-05
16 | loss_h_dot_coef: 0.01
17 | loss_safe_coef: 1.0
18 | loss_unsafe_coef: 1.0
19 | lr_actor: 1.0e-05
20 | lr_cbf: 0.0001
21 | lr_critic: 0.0001
22 | n_env_test: 32
23 | n_env_train: 16
24 | n_rays: 32
25 | name: null
26 | num_agents: 8
27 | obs: 0
28 | save_interval: 10
29 | seed: 0
30 | steps: 1000
31 | alpha: 1.0
32 | batch_size: 256
33 | eps: 0.02
34 | gnn_layers: 1
35 | horizon: 32
36 | inner_epoch: 8
37 | loss_action_coef: 3.0e-05
38 | loss_h_dot_coef: 0.01
39 | loss_safe_coef: 1.0
40 | loss_unsafe_coef: 1.0
41 | lr_actor: 1.0e-05
42 | lr_cbf: 0.0001
43 | max_grad_norm: 2.0
44 | seed: 0
45 |
--------------------------------------------------------------------------------
/pretrained/CrazyFlie/gcbf+/models/1000/actor.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/CrazyFlie/gcbf+/models/1000/actor.pkl
--------------------------------------------------------------------------------
/pretrained/CrazyFlie/gcbf+/models/1000/cbf.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/CrazyFlie/gcbf+/models/1000/cbf.pkl
--------------------------------------------------------------------------------
/pretrained/DoubleIntegrator/gcbf+/config.yaml:
--------------------------------------------------------------------------------
1 | !!python/object:argparse.Namespace
2 | algo: gcbf+
3 | batch_size: 256
4 | buffer_size: 512
5 | debug: false
6 | env: DoubleIntegrator
7 | eval_epi: 1
8 | eval_interval: 1
9 | gnn_layers: 1
10 | log_dir: ./logs
11 | loss_action_coef: 0.0001
12 | loss_h_dot_coef: 0.01
13 | loss_safe_coef: 1.0
14 | loss_unsafe_coef: 1.0
15 | lr_actor: 1.0e-05
16 | lr_cbf: 1.0e-05
17 | n_env_test: 32
18 | n_env_train: 16
19 | name: null
20 | num_agents: 8
21 | save_interval: 10
22 | seed: 2
23 | steps: 1000
24 | alpha: 1.0
25 | batch_size: 256
26 | eps: 0.02
27 | gnn_layers: 1
28 | horizon: 32
29 | inner_epoch: 8
30 | loss_action_coef: 0.0001
31 | loss_h_dot_coef: 0.01
32 | loss_safe_coef: 1.0
33 | loss_unsafe_coef: 1.0
34 | lr_actor: 1.0e-05
35 | lr_cbf: 1.0e-05
36 | max_grad_norm: 2.0
37 | seed: 2
38 |
--------------------------------------------------------------------------------
/pretrained/DoubleIntegrator/gcbf+/models/1000/actor.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/DoubleIntegrator/gcbf+/models/1000/actor.pkl
--------------------------------------------------------------------------------
/pretrained/DoubleIntegrator/gcbf+/models/1000/cbf.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/DoubleIntegrator/gcbf+/models/1000/cbf.pkl
--------------------------------------------------------------------------------
/pretrained/DubinsCar/gcbf+/config.yaml:
--------------------------------------------------------------------------------
1 | !!python/object:argparse.Namespace
2 | algo: gcbf+
3 | alpha: 1.0
4 | area_size: 4.0
5 | batch_size: 256
6 | buffer_size: 512
7 | cost_weight: 1.0
8 | debug: false
9 | env: DubinsCar
10 | eval_epi: 1
11 | eval_interval: 1
12 | gnn_layers: 1
13 | horizon: 32
14 | log_dir: ./logs
15 | loss_action_coef: 1.0e-05
16 | loss_h_dot_coef: 0.01
17 | loss_safe_coef: 1.0
18 | loss_unsafe_coef: 1.0
19 | lr_actor: 2.0e-05
20 | lr_cbf: 2.0e-05
21 | lr_critic: 0.0001
22 | n_env_test: 32
23 | n_env_train: 12
24 | n_rays: 32
25 | name: null
26 | num_agents: 8
27 | obs: 4
28 | save_interval: 10
29 | seed: 1
30 | steps: 1000
31 | alpha: 1.0
32 | batch_size: 256
33 | eps: 0.02
34 | gnn_layers: 1
35 | horizon: 32
36 | inner_epoch: 8
37 | loss_action_coef: 1.0e-05
38 | loss_h_dot_coef: 0.01
39 | loss_safe_coef: 1.0
40 | loss_unsafe_coef: 1.0
41 | lr_actor: 2.0e-05
42 | lr_cbf: 2.0e-05
43 | max_grad_norm: 2.0
44 | seed: 1
45 |
--------------------------------------------------------------------------------
/pretrained/DubinsCar/gcbf+/models/1000/actor.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/DubinsCar/gcbf+/models/1000/actor.pkl
--------------------------------------------------------------------------------
/pretrained/DubinsCar/gcbf+/models/1000/cbf.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/DubinsCar/gcbf+/models/1000/cbf.pkl
--------------------------------------------------------------------------------
/pretrained/LinearDrone/gcbf+/config.yaml:
--------------------------------------------------------------------------------
1 | !!python/object:argparse.Namespace
2 | algo: gcbf+
3 | alpha: 1.0
4 | area_size: 2.0
5 | batch_size: 256
6 | buffer_size: 512
7 | cost_weight: 1.0
8 | debug: false
9 | env: LinearDrone
10 | eval_epi: 1
11 | eval_interval: 1
12 | gnn_layers: 1
13 | horizon: 32
14 | log_dir: ./logs
15 | loss_action_coef: 0.0001
16 | loss_h_dot_coef: 0.01
17 | loss_safe_coef: 1.0
18 | loss_unsafe_coef: 1.0
19 | lr_actor: 1.0e-05
20 | lr_cbf: 1.0e-05
21 | lr_critic: 0.0001
22 | n_env_test: 32
23 | n_env_train: 12
24 | n_rays: 32
25 | name: null
26 | num_agents: 8
27 | obs: 8
28 | save_interval: 10
29 | seed: 2
30 | steps: 1000
31 | alpha: 1.0
32 | batch_size: 256
33 | eps: 0.02
34 | gnn_layers: 1
35 | horizon: 32
36 | inner_epoch: 8
37 | loss_action_coef: 0.0001
38 | loss_h_dot_coef: 0.01
39 | loss_safe_coef: 1.0
40 | loss_unsafe_coef: 1.0
41 | lr_actor: 1.0e-05
42 | lr_cbf: 1.0e-05
43 | max_grad_norm: 2.0
44 | seed: 2
45 |
--------------------------------------------------------------------------------
/pretrained/LinearDrone/gcbf+/models/1000/actor.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/LinearDrone/gcbf+/models/1000/actor.pkl
--------------------------------------------------------------------------------
/pretrained/LinearDrone/gcbf+/models/1000/cbf.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/LinearDrone/gcbf+/models/1000/cbf.pkl
--------------------------------------------------------------------------------
/pretrained/SingleIntegrator/gcbf+/config.yaml:
--------------------------------------------------------------------------------
1 | !!python/object:argparse.Namespace
2 | algo: gcbf+
3 | alpha: 1.0
4 | batch_size: 256
5 | buffer_size: 512
6 | debug: false
7 | env: SingleIntegrator
8 | eval_epi: 1
9 | eval_interval: 1
10 | gnn_layers: 1
11 | horizon: 1
12 | log_dir: ./logs
13 | loss_action_coef: 0.0001
14 | loss_h_dot_coef: 0.01
15 | loss_safe_coef: 1.0
16 | loss_unsafe_coef: 1.0
17 | lr_actor: 1.0e-05
18 | lr_cbf: 1.0e-05
19 | n_env_test: 32
20 | n_env_train: 16
21 | name: null
22 | num_agents: 8
23 | obs: null
24 | save_interval: 10
25 | seed: 0
26 | steps: 1000
27 | alpha: 1.0
28 | batch_size: 256
29 | eps: 0.02
30 | gnn_layers: 1
31 | horizon: 1
32 | inner_epoch: 8
33 | loss_action_coef: 0.0001
34 | loss_h_dot_coef: 0.01
35 | loss_safe_coef: 1.0
36 | loss_unsafe_coef: 1.0
37 | lr_actor: 1.0e-05
38 | lr_cbf: 1.0e-05
39 | max_grad_norm: 2.0
40 | seed: 0
41 |
--------------------------------------------------------------------------------
/pretrained/SingleIntegrator/gcbf+/models/1000/actor.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/SingleIntegrator/gcbf+/models/1000/actor.pkl
--------------------------------------------------------------------------------
/pretrained/SingleIntegrator/gcbf+/models/1000/cbf.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/gcbfplus/fb449907bdbf981aa10f0edfecca02663ddc8037/pretrained/SingleIntegrator/gcbf+/models/1000/cbf.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | flax>=0.7.2
2 | jax>=0.4.14
3 | jraph>=0.0.6.dev0
4 | jaxtyping>=0.2.21
5 | numpy>=1.25.2
6 | einops>=0.6.1
7 | matplotlib>=3.7.2
8 | opencv-python>=4.8.0.76
9 | tqdm>=4.66.1
10 | tensorflow-probability>=0.21.0
11 | optax>=0.1.7
12 | scipy>=1.11.2
13 | wandb>=0.15.8
14 | pyyaml>=6.0.1
15 | orbax_checkpoint>=0.3.5
16 | seaborn>=0.12.2
17 | equinox>=0.11.0
18 | loguru>=0.7.2
19 | attrs>=23.1.0
20 | rich>=13.5.3
21 | ipdb>=0.13.13
22 | colour>=0.1.5
23 | control>=0.9.4
24 | jaxproxqp @ git+https://github.com/oswinso/jaxproxqp.git
--------------------------------------------------------------------------------
/settings.yaml:
--------------------------------------------------------------------------------
1 | DubinsCar:
2 | --env: DubinsCar
3 | --lr-actor: 3e-5
4 | --lr-cbf: 3e-5
5 | --loss-action-coef: 1e-5
6 | --n-env-train: 16
7 | --horizon: 32
8 | --area-size: 4
9 | DoubleIntegrator:
10 | --env: DoubleIntegrator
11 | --lr-actor: 1e-5
12 | --lr-cbf: 1e-5
13 | --loss-action-coef: 1e-4
14 | --n-env-train: 16
15 | --horizon: 32
16 | --area-size: 4
17 | SingleIntegrator:
18 | --env: SingleIntegrator
19 | --lr-actor: 1e-5
20 | --lr-cbf: 1e-5
21 | --loss-action-coef: 1e-4
22 | --n-env-train: 16
23 | --horizon: 1
24 | --area-size: 4
25 | LinearDrone:
26 | --env: LinearDrone
27 | --lr-actor: 1e-5
28 | --lr-cbf: 1e-5
29 | --loss-action-coef: 1e-3
30 | --n-env-train: 16
31 | --horizon: 32
32 | --area-size: 2
33 | CrazyFlie:
34 | --env: CrazyFlie
35 | --lr-actor: 1e-5
36 | --lr-cbf: 1e-4
37 | --loss-action-coef: 3e-5
38 | --n-env-train: 16
39 | --horizon: 32
40 | --area-size: 2
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup, find_packages
4 |
5 |
# Package metadata for the gcbfplus distribution. No runtime dependencies are
# declared here (install_requires is empty); the repository lists them in
# requirements.txt instead.
setup(
    name="gcbfplus",
    version="0.0.0",
    # The two adjacent string literals are concatenated into one description.
    description='Jax Official Implementation of CoRL Paper: S Zhang, O So, K Garg, C Fan: '
                '"GCBF+: A Neural Graph Control Barrier Function Framework for Distributed Safe Multi-Agent Control"',
    author="Songyuan Zhang",
    author_email="szhang21@mit.edu",
    url="https://github.com/MIT-REALM/gcbfplus",
    install_requires=[],
    packages=find_packages(),
)
17 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import functools as ft
4 | import os
5 | import pathlib
6 | import ipdb
7 | import jax
8 | import jax.numpy as jnp
9 | import jax.random as jr
10 | import jax.tree_util as jtu
11 | import numpy as np
12 | import yaml
13 |
14 | from gcbfplus.algo import GCBF, GCBFPlus, make_algo, CentralizedCBF, DecShareCBF
15 | from gcbfplus.env import make_env
16 | from gcbfplus.env.base import RolloutResult
17 | from gcbfplus.trainer.utils import get_bb_cbf
18 | from gcbfplus.utils.graph import GraphsTuple
19 | from gcbfplus.utils.utils import jax_jit_np, tree_index, chunk_vmap, merge01, jax_vmap
20 |
21 |
def test(args):
    """Roll out a controller for several episodes, report metrics, optionally render videos.

    The controller is chosen from ``args``:

    * ``args.u_ref`` set: the environment's nominal controller ``env.u_ref``.
    * ``args.path`` set: an algorithm rebuilt from ``<path>/config.yaml`` with
      weights loaded from ``<path>/models/<step>`` (largest numeric directory
      name when ``args.step`` is None).
    * otherwise: a fresh ``make_algo(algo=args.algo, ...)`` instance.

    Episodes are seeded from ``args.seed``. The seeds are the first
    ``args.epi`` sub-keys with the first ``args.offset`` of them dropped, so
    ``args.epi - args.offset`` episodes are run.

    Args:
        args: parsed command-line namespace as produced by ``main`` in this file.

    Raises:
        ValueError: if no run config is loaded and ``--env`` / ``--num-agents``
            are not both given, or if ``--offset`` leaves no episode to run.
    """
    print(f"> Running test.py {args}")

    stamp_str = datetime.datetime.now().strftime("%m%d-%H%M")

    # set up environment variables and seed
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if args.cpu:
        os.environ["JAX_PLATFORM_NAME"] = "cpu"
        # jax has already been imported at this point, so the environment
        # variable alone is presumably read too late; set the config value too.
        jax.config.update("jax_platform_name", "cpu")
    if args.debug:
        jax.config.update("jax_disable_jit", True)
    np.random.seed(args.seed)

    # load config
    config = None
    if not args.u_ref and args.path is not None:
        with open(os.path.join(args.path, "config.yaml"), "r") as f:
            # NOTE(security): UnsafeLoader can construct arbitrary Python
            # objects; only point --path at run directories you trust.
            config = yaml.load(f, Loader=yaml.UnsafeLoader)

    # Without a loaded config there is nothing to fall back on for these two.
    if config is None and (args.env is None or args.num_agents is None):
        raise ValueError(
            "--env and --num-agents are required when no run config is loaded "
            "(i.e. with --u-ref or without --path)"
        )

    # create environments
    num_agents = config.num_agents if args.num_agents is None else args.num_agents
    env = make_env(
        env_id=config.env if args.env is None else args.env,
        num_agents=num_agents,
        num_obs=args.obs,
        area_size=args.area_size,
        max_step=args.max_step,
        max_travel=args.max_travel,
    )

    if not args.u_ref:
        if args.path is not None:
            path = args.path
            model_path = os.path.join(path, "models")
            if args.step is None:
                # checkpoints live in directories named by their step number
                models = os.listdir(model_path)
                step = max(int(model) for model in models if model.isdigit())
            else:
                step = args.step
            print("step: ", step)

            algo = make_algo(
                algo=config.algo,
                env=env,
                node_dim=env.node_dim,
                edge_dim=env.edge_dim,
                state_dim=env.state_dim,
                action_dim=env.action_dim,
                n_agents=env.num_agents,
                gnn_layers=config.gnn_layers,
                batch_size=config.batch_size,
                buffer_size=config.buffer_size,
                horizon=config.horizon,
                lr_actor=config.lr_actor,
                lr_cbf=config.lr_cbf,
                alpha=config.alpha,
                eps=0.02,
                inner_epoch=8,
                loss_action_coef=config.loss_action_coef,
                loss_unsafe_coef=config.loss_unsafe_coef,
                loss_safe_coef=config.loss_safe_coef,
                loss_h_dot_coef=config.loss_h_dot_coef,
                max_grad_norm=2.0,
                seed=config.seed
            )
            algo.load(model_path, step)
            act_fn = jax.jit(algo.act)
        else:
            algo = make_algo(
                algo=args.algo,
                env=env,
                node_dim=env.node_dim,
                edge_dim=env.edge_dim,
                state_dim=env.state_dim,
                action_dim=env.action_dim,
                n_agents=env.num_agents,
                alpha=args.alpha,
            )
            act_fn = jax.jit(algo.act)
            path = os.path.join(f"./logs/{args.env}/{args.algo}")
            os.makedirs(path, exist_ok=True)
            step = None
    else:
        # args.env is guaranteed non-None here by the check above
        path = os.path.join(f"./logs/{args.env}/nominal")
        os.makedirs(path, exist_ok=True)
        algo = None
        act_fn = jax.jit(env.u_ref)
        step = 0

    test_key = jr.PRNGKey(args.seed)
    test_keys = jr.split(test_key, 1_000)[: args.epi]
    test_keys = test_keys[args.offset:]
    # Iterate over the keys that actually remain after the offset. Looping
    # range(args.epi) would index past the end, which JAX clamps, silently
    # re-running the last seed.
    n_epi = len(test_keys)
    if n_epi == 0:
        raise ValueError("no episodes to run: --offset must be smaller than --epi")

    algo_is_cbf = isinstance(algo, (CentralizedCBF, DecShareCBF))

    if args.cbf is not None:
        assert isinstance(algo, (GCBF, GCBFPlus, CentralizedCBF))
        get_bb_cbf_fn_ = ft.partial(get_bb_cbf, algo.get_cbf, env, agent_id=args.cbf, x_dim=0, y_dim=1)
        get_bb_cbf_fn_ = jax_jit_np(get_bb_cbf_fn_)

        def get_bb_cbf_fn(T_graph: GraphsTuple):
            # evaluate the CBF helper one time step at a time, then stack over time
            T = len(T_graph.states)
            outs = [get_bb_cbf_fn_(tree_index(T_graph, kk)) for kk in range(T)]
            Tb_x, Tb_y, Tbb_h = jtu.tree_map(lambda *x: jnp.stack(list(x), axis=0), *outs)
            return Tb_x, Tb_y, Tbb_h
    else:
        get_bb_cbf_fn = None

    if args.nojit_rollout:
        print("Only jit step, no jit rollout!")
        rollout_fn = env.rollout_fn_jitstep(act_fn, args.max_step, noedge=True, nograph=args.no_video)

        # this rollout variant returns the unsafe / finish masks itself
        is_unsafe_fn = None
        is_finish_fn = None
    else:
        print("jit rollout!")
        rollout_fn = jax_jit_np(env.rollout_fn(act_fn, args.max_step))

        is_unsafe_fn = jax_jit_np(jax_vmap(env.collision_mask))
        is_finish_fn = jax_jit_np(jax_vmap(env.finish_mask))

    rewards = []
    costs = []
    rollouts = []
    is_unsafes = []
    is_finishes = []
    rates = []
    cbfs = []
    for i_epi in range(n_epi):
        key_x0, _ = jr.split(test_keys[i_epi], 2)

        if args.nojit_rollout:
            rollout: RolloutResult
            rollout, is_unsafe, is_finish = rollout_fn(key_x0)
            is_unsafes.append(is_unsafe)
            is_finishes.append(is_finish)
        else:
            rollout: RolloutResult = rollout_fn(key_x0)
            is_unsafes.append(is_unsafe_fn(rollout.Tp1_graph))
            is_finishes.append(is_finish_fn(rollout.Tp1_graph))

        epi_reward = rollout.T_reward.sum()
        epi_cost = rollout.T_cost.sum()
        rewards.append(epi_reward)
        costs.append(epi_cost)
        rollouts.append(rollout)

        if args.cbf is not None:
            cbfs.append(get_bb_cbf_fn(rollout.Tp1_graph))
        else:
            cbfs.append(None)

        # per-episode rates: reduce over axis 0 first, then average the rest
        safe_rate = 1 - is_unsafes[-1].max(axis=0).mean()
        finish_rate = is_finishes[-1].max(axis=0).mean()
        success_rate = ((1 - is_unsafes[-1].max(axis=0)) * is_finishes[-1].max(axis=0)).mean()
        print(f"epi: {i_epi}, reward: {epi_reward:.3f}, cost: {epi_cost:.3f}, "
              f"safe rate: {safe_rate * 100:.3f}%,"
              f"finish rate: {finish_rate * 100:.3f}%, "
              f"success rate: {success_rate * 100:.3f}%")

        rates.append(np.array([safe_rate, finish_rate, success_rate]))

    # aggregate over all episodes (episodes are stacked on axis 0)
    is_unsafe = np.max(np.stack(is_unsafes), axis=1)
    is_finish = np.max(np.stack(is_finishes), axis=1)

    safe_mean, safe_std = (1 - is_unsafe).mean(), (1 - is_unsafe).std()
    finish_mean, finish_std = is_finish.mean(), is_finish.std()
    success_mean, success_std = ((1 - is_unsafe) * is_finish).mean(), ((1 - is_unsafe) * is_finish).std()

    print(
        f"reward: {np.mean(rewards):.3f}, min/max reward: {np.min(rewards):.3f}/{np.max(rewards):.3f}, "
        f"cost: {np.mean(costs):.3f}, min/max cost: {np.min(costs):.3f}/{np.max(costs):.3f}, "
        f"safe_rate: {safe_mean * 100:.3f}%, "
        f"finish_rate: {finish_mean * 100:.3f}%, "
        f"success_rate: {success_mean * 100:.3f}%"
    )

    # save results
    if args.log:
        # appended row: agents, episodes run, max steps, area size, obstacles,
        # then mean/std (in percent) of the safe, finish and success rates
        with open(os.path.join(path, "test_log.csv"), "a") as f:
            f.write(f"{env.num_agents},{n_epi},{env.max_episode_steps},"
                    f"{env.area_size},{env.params['n_obs']},"
                    f"{safe_mean * 100:.3f},{safe_std * 100:.3f},"
                    f"{finish_mean * 100:.3f},{finish_std * 100:.3f},"
                    f"{success_mean * 100:.3f},{success_std * 100:.3f}\n")

    # make video
    if args.no_video:
        return

    videos_dir = pathlib.Path(path) / "videos"
    videos_dir.mkdir(exist_ok=True, parents=True)
    for ii, (rollout, Ta_is_unsafe, cbf) in enumerate(zip(rollouts, is_unsafes, cbfs)):
        if algo_is_cbf:
            safe_rate, finish_rate, success_rate = rates[ii] * 100
            video_name = f"n{num_agents}_epi{ii:02}_sr{safe_rate:.0f}_fr{finish_rate:.0f}_sr{success_rate:.0f}"
        else:
            video_name = f"n{num_agents}_step{step}_epi{ii:02}_reward{rewards[ii]:.3f}_cost{costs[ii]:.3f}"

        viz_opts = {}
        if args.cbf is not None:
            video_name += f"_cbf{args.cbf}"
            viz_opts["cbf"] = [*cbf, args.cbf]

        video_path = videos_dir / f"{stamp_str}_{video_name}.mp4"
        env.render_video(rollout, video_path, Ta_is_unsafe, viz_opts, dpi=args.dpi)
237 |
238 |
def main():
    """Build the command-line interface of test.py and hand the parsed options to ``test``."""
    cli = argparse.ArgumentParser()

    # environment and evaluation setup (options without an explicit default are None)
    cli.add_argument("-n", "--num-agents", type=int)
    cli.add_argument("--obs", type=int, default=0)
    cli.add_argument("--area-size", type=float, required=True)
    cli.add_argument("--max-step", type=int)
    cli.add_argument("--path", type=str)
    # NOTE(review): --n-rays is parsed here but test() in this file never reads it.
    cli.add_argument("--n-rays", type=int, default=32)
    cli.add_argument("--alpha", type=float, default=1.0)
    cli.add_argument("--max-travel", type=float)
    cli.add_argument("--cbf", type=int)

    cli.add_argument("--seed", type=int, default=1234)
    # boolean switches, off unless given
    for switch in ("--debug", "--cpu", "--u-ref"):
        cli.add_argument(switch, action="store_true")
    cli.add_argument("--env", type=str)
    cli.add_argument("--algo", type=str)
    cli.add_argument("--step", type=int)
    cli.add_argument("--epi", type=int, default=5)
    cli.add_argument("--offset", type=int, default=0)
    for switch in ("--no-video", "--nojit-rollout", "--log"):
        cli.add_argument(switch, action="store_true")
    cli.add_argument("--dpi", type=int, default=100)

    test(cli.parse_args())
267 |
268 |
# Script entry point: run main() inside ipdb's context manager so that an
# uncaught exception opens a post-mortem debugger session.
if __name__ == "__main__":
    with ipdb.launch_ipdb_on_exception():
        main()
272 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 | import ipdb
5 | import numpy as np
6 | import wandb
7 | import yaml
8 |
9 | from gcbfplus.algo import make_algo
10 | from gcbfplus.env import make_env
11 | from gcbfplus.trainer.trainer import Trainer
12 | from gcbfplus.trainer.utils import is_connected
13 |
14 |
def train(args):
    """Build environments, algorithm and trainer from ``args`` and run training.

    Two identically configured environments are created (one for training, one
    for evaluation), an algorithm is built with ``make_algo``, and a
    ``Trainer`` is run with its log directory set to
    ``<log_dir>/<env>/<algo>/seed<seed>_<timestamp>``. Unless ``args.debug`` is
    set, the parsed arguments and ``algo.config`` are dumped to ``config.yaml``
    in that directory.

    Args:
        args: parsed command-line namespace as produced by ``main`` in this file.
    """
    print(f"> Running train.py {args}")

    # set up environment variables and seed
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if not is_connected():
        os.environ["WANDB_MODE"] = "offline"
    np.random.seed(args.seed)
    if args.debug:
        os.environ["WANDB_MODE"] = "disabled"
        os.environ["JAX_DISABLE_JIT"] = "True"
        # The gcbfplus imports at the top of this file have presumably already
        # imported jax, in which case the environment variable is read too
        # late; set the config value directly as well.
        import jax

        jax.config.update("jax_disable_jit", True)

    # create environments (training and evaluation use the same settings)
    env_kwargs = dict(
        env_id=args.env,
        num_agents=args.num_agents,
        num_obs=args.obs,
        n_rays=args.n_rays,
        area_size=args.area_size
    )
    env = make_env(**env_kwargs)
    env_test = make_env(**env_kwargs)

    # create low level controller
    algo = make_algo(
        algo=args.algo,
        env=env,
        node_dim=env.node_dim,
        edge_dim=env.edge_dim,
        state_dim=env.state_dim,
        action_dim=env.action_dim,
        n_agents=env.num_agents,
        gnn_layers=args.gnn_layers,
        batch_size=256,
        buffer_size=args.buffer_size,
        horizon=args.horizon,
        lr_actor=args.lr_actor,
        lr_cbf=args.lr_cbf,
        alpha=args.alpha,
        eps=0.02,
        inner_epoch=8,
        loss_action_coef=args.loss_action_coef,
        loss_unsafe_coef=args.loss_unsafe_coef,
        loss_safe_coef=args.loss_safe_coef,
        loss_h_dot_coef=args.loss_h_dot_coef,
        max_grad_norm=2.0,
        seed=args.seed,
    )

    # set up logger
    start_time = datetime.datetime.now()
    start_time = start_time.strftime("%Y%m%d%H%M%S")
    # creates <log_dir>, <log_dir>/<env> and <log_dir>/<env>/<algo> as needed
    algo_log_dir = f"{args.log_dir}/{args.env}/{args.algo}"
    os.makedirs(algo_log_dir, exist_ok=True)
    # the per-run directory itself is not created here
    log_dir = f"{algo_log_dir}/seed{args.seed}_{start_time}"
    run_name = f"{args.algo}_{args.env}_{start_time}" if args.name is None else args.name

    # get training parameters
    train_params = {
        "run_name": run_name,
        "training_steps": args.steps,
        "eval_interval": args.eval_interval,
        "eval_epi": args.eval_epi,
        "save_interval": args.save_interval,
    }

    # create trainer
    trainer = Trainer(
        env=env,
        env_test=env_test,
        algo=algo,
        log_dir=log_dir,
        n_env_train=args.n_env_train,
        n_env_test=args.n_env_test,
        seed=args.seed,
        params=train_params,
        save_log=not args.debug,
    )

    # save config
    wandb.config.update(args)
    wandb.config.update(algo.config)
    if not args.debug:
        # both dumps go into the same file, one after the other
        with open(f"{log_dir}/config.yaml", "w") as f:
            yaml.dump(args, f)
            yaml.dump(algo.config, f)

    # start training
    trainer.train()
113 |
114 |
def main():
    """Build the command-line interface of train.py and hand the parsed options to ``train``."""
    cli = argparse.ArgumentParser()

    # custom arguments
    cli.add_argument("-n", "--num-agents", type=int, default=8)
    cli.add_argument("--algo", type=str, default="gcbf+")
    # NOTE(review): confirm "SimpleCar" is an env id that make_env accepts.
    cli.add_argument("--env", type=str, default="SimpleCar")
    cli.add_argument("--seed", type=int, default=0)
    cli.add_argument("--steps", type=int, default=1000)
    cli.add_argument("--name", type=str)
    cli.add_argument("--debug", action="store_true")
    cli.add_argument("--obs", type=int)
    cli.add_argument("--n-rays", type=int, default=32)
    cli.add_argument("--area-size", type=float, required=True)

    # gcbf / gcbf+ arguments
    cli.add_argument("--gnn-layers", type=int, default=1)
    cli.add_argument("--alpha", type=float, default=1.0)
    cli.add_argument("--horizon", type=int, default=32)
    for lr_flag in ("--lr-actor", "--lr-cbf"):
        cli.add_argument(lr_flag, type=float, default=3e-5)
    cli.add_argument("--loss-action-coef", type=float, default=0.0001)
    for coef_flag in ("--loss-unsafe-coef", "--loss-safe-coef"):
        cli.add_argument(coef_flag, type=float, default=1.0)
    cli.add_argument("--loss-h-dot-coef", type=float, default=0.01)
    cli.add_argument("--buffer-size", type=int, default=512)

    # default arguments
    cli.add_argument("--n-env-train", type=int, default=16)
    cli.add_argument("--n-env-test", type=int, default=32)
    cli.add_argument("--log-dir", type=str, default="./logs")
    for interval_flag in ("--eval-interval", "--eval-epi"):
        cli.add_argument(interval_flag, type=int, default=1)
    cli.add_argument("--save-interval", type=int, default=10)

    train(cli.parse_args())
152 |
153 |
# Script entry point: run main() inside ipdb's context manager so that an
# uncaught exception opens a post-mortem debugger session.
if __name__ == "__main__":
    with ipdb.launch_ipdb_on_exception():
        main()
157 |
--------------------------------------------------------------------------------