├── .gitignore
├── LICENSE
├── README.md
├── cars
│   ├── config.py
│   ├── core.py
│   ├── evaluate.py
│   ├── models
│   │   ├── model_save.data-00000-of-00001
│   │   ├── model_save.index
│   │   └── model_save.meta
│   └── train.py
├── drones
│   ├── config.py
│   ├── core.py
│   ├── dijkstra.py
│   ├── evaluate.py
│   ├── models
│   │   ├── model_save.data-00000-of-00001
│   │   ├── model_save.index
│   │   └── model_save.meta
│   └── train.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | trajectory
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Zengyi Qin
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Learning Safe Multi-Agent Control with Decentralized Neural Barrier Certificates
2 | 
3 | [Zengyi Qin](http://www.qinzy.tech/), [Kaiqing Zhang](https://kzhang66.github.io/), [Yuxiao Chen](http://www.its.caltech.edu/~chenyx/), [Jingkai Chen](https://jkchengh.github.io/), [Chuchu Fan](https://chuchu.mit.edu/)
4 | 
5 | This repository contains the official implementation of [Learning Safe Multi-Agent Control with Decentralized Neural Barrier Certificates](https://arxiv.org/abs/2101.05436) published at the **International Conference on Learning Representations (ICLR)**, 2021.
6 | 
7 | ### Installation
8 | Create a virtual environment with Anaconda:
9 | ```bash
10 | conda create -n macbf python=3.6
11 | ```
12 | Activate the virtual environment:
13 | ```bash
14 | source activate macbf
15 | ```
16 | Clone this repository:
17 | ```bash
18 | git clone https://github.com/Zengyi-Qin/macbf.git
19 | ```
20 | Enter the main folder and install the dependencies:
21 | ```bash
22 | pip install -r requirements.txt
23 | ```
24 | 
25 | ### Cars
26 | In `cars`, we provide a multi-agent collision avoidance example with the double integrator dynamics. First enter the directory:
27 | ```bash
28 | cd cars
29 | ```
30 | To evaluate the pretrained neural network CBF and controller, run:
31 | ```bash
32 | python evaluate.py --num_agents 32 --model_path models/model_save --vis 1
33 | ```
34 | `--num_agents` specifies the number of agents in the environment. `--model_path` points to the prefix of the pretrained neural network weights. The visualization is disabled by default and will be enabled when `--vis` is set to 1.
35 | 
36 | To train the neural network CBF and controller from scratch, run:
37 | ```bash
38 | python train.py --num_agents 32
39 | ```
40 | To initialize training from a pretrained model, add the `--model_path` argument pointing to its weight prefix.
41 | 
42 | 
43 | ### Drones
44 | In `drones`, we consider drone dynamics with an 8-dimensional state space. Details of the dynamics can be found in Appendix C of our paper. To experiment with this example, first enter the directory:
45 | ```bash
46 | cd drones
47 | ```
48 | To evaluate the pretrained neural network CBF and controller, run:
49 | ```bash
50 | python evaluate.py --num_agents 32 --model_path models/model_save --vis 1
51 | ```
52 | To train the neural network CBF and controller from scratch, run:
53 | ```bash
54 | python train.py --num_agents 32
55 | ```
56 | The arguments are the same as in the cars example.
--------------------------------------------------------------------------------
/cars/config.py:
--------------------------------------------------------------------------------
1 | TIME_STEP = 1e-1
2 | WEIGHT_DECAY = 1e-8
3 | 
4 | ALPHA_CBF = 1.0
5 | 
6 | DIST_MIN_THRES = 0.07
7 | DIST_MIN_CHECK = 0.07
8 | DIST_MIN_ENLARGED = 0.07
9 | 
10 | OBS_RADIUS = 1.0
11 | TOP_K = 12
12 | 
13 | TIME_TO_COLLISION = 2.0
14 | TIME_TO_COLLISION_CHECK = 0.1
15 | 
16 | TRAIN_STEPS = 70000
17 | EVALUATE_STEPS = 10
18 | INNER_LOOPS = 50
19 | REFINE_LOOPS = 50
20 | REFINE_LEARNING_RATE = 0.3
21 | 
22 | LEARNING_RATE = 1e-4
23 | DISPLAY_STEPS = 200
24 | SAVE_STEPS = 1000
25 | 
26 | ADD_NOISE_PROB = 0.0
27 | NOISE_SCALE = 0.3
28 | 
--------------------------------------------------------------------------------
/cars/core.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | import config
5 | 
6 | 
7 | def generate_obstacle_circle(center, radius, num=12):
8 |     theta = np.linspace(0, np.pi*2, num=num, endpoint=False).reshape(-1, 1)
9 |     unit_circle = np.concatenate([np.cos(theta), np.sin(theta)], axis=1)
10 |     circle = np.array(center) + unit_circle * radius
11 |     return circle
12 | 
13 | 
14 | def generate_obstacle_rectangle(center, sides, num=12):
15 |     # calculate the number of points on each side of the rectangle
16 |     a, b = sides # side lengths
17 |     n_side_1 = int(num // 2 * a / (a+b))
18 |     n_side_2 = num // 2 - n_side_1
19 |     n_side_3 = n_side_1
20 |     n_side_4 = num - n_side_1 - n_side_2 - n_side_3
21 |     # top
22 |     side_1 = np.concatenate([
23 |         np.linspace(-a/2, a/2, n_side_1, endpoint=False).reshape(-1, 1),
24 |         b/2 * np.ones(n_side_1).reshape(-1, 1)], axis=1)
25 |     # right
26 |     side_2 = np.concatenate([
27 |         a/2 * np.ones(n_side_2).reshape(-1, 1),
28 |         np.linspace(b/2, -b/2, n_side_2, endpoint=False).reshape(-1, 1)], axis=1)
29 |     # bottom
30 |     side_3 = np.concatenate([
31 |         np.linspace(a/2, -a/2, n_side_3, endpoint=False).reshape(-1, 1),
32 |         -b/2 * np.ones(n_side_3).reshape(-1, 1)], axis=1)
33 |     # left
34 |     side_4 = np.concatenate([
35 |         -a/2 * np.ones(n_side_4).reshape(-1, 1),
36 |         np.linspace(-b/2, b/2, n_side_4, endpoint=False).reshape(-1, 1)], axis=1)
37 | 
38 |     rectangle = np.concatenate([side_1, side_2, side_3, side_4], axis=0)
39 |     rectangle = rectangle + np.array(center)
40 |     return rectangle
41 | 
42 | 
43 | def generate_data(num_agents, dist_min_thres):
44 |     side_length = np.sqrt(max(1.0, num_agents / 8.0))
45 |     states = 
np.zeros(shape=(num_agents, 2), dtype=np.float32) 46 | goals = np.zeros(shape=(num_agents, 2), dtype=np.float32) 47 | 48 | i = 0 49 | while i < num_agents: 50 | candidate = np.random.uniform(size=(2,)) * side_length 51 | dist_min = np.linalg.norm(states - candidate, axis=1).min() 52 | if dist_min <= dist_min_thres: 53 | continue 54 | states[i] = candidate 55 | i = i + 1 56 | 57 | i = 0 58 | while i < num_agents: 59 | candidate = np.random.uniform(-0.5, 0.5, size=(2,)) + states[i] 60 | dist_min = np.linalg.norm(goals - candidate, axis=1).min() 61 | if dist_min <= dist_min_thres: 62 | continue 63 | goals[i] = candidate 64 | i = i + 1 65 | 66 | states = np.concatenate( 67 | [states, np.zeros(shape=(num_agents, 2), dtype=np.float32)], axis=1) 68 | return states, goals 69 | 70 | 71 | def network_cbf(x, r, indices=None): 72 | d_norm = tf.sqrt( 73 | tf.reduce_sum(tf.square(x[:, :, :2]) + 1e-4, axis=2)) 74 | x = tf.concat([x, 75 | tf.expand_dims(tf.eye(tf.shape(x)[0]), 2), 76 | tf.expand_dims(d_norm - r, 2)], axis=2) 77 | x, indices = remove_distant_agents(x=x, k=config.TOP_K, indices=indices) 78 | dist = tf.sqrt( 79 | tf.reduce_sum(tf.square(x[:, :, :2]) + 1e-4, axis=2, keepdims=True)) 80 | mask = tf.cast(tf.less_equal(dist, config.OBS_RADIUS), tf.float32) 81 | x = tf.contrib.layers.conv1d(inputs=x, 82 | num_outputs=64, 83 | kernel_size=1, 84 | reuse=tf.AUTO_REUSE, 85 | scope='cbf/conv_1', 86 | activation_fn=tf.nn.relu) 87 | x = tf.contrib.layers.conv1d(inputs=x, 88 | num_outputs=128, 89 | kernel_size=1, 90 | reuse=tf.AUTO_REUSE, 91 | scope='cbf/conv_2', 92 | activation_fn=tf.nn.relu) 93 | x = tf.contrib.layers.conv1d(inputs=x, 94 | num_outputs=64, 95 | kernel_size=1, 96 | reuse=tf.AUTO_REUSE, 97 | scope='cbf/conv_3', 98 | activation_fn=tf.nn.relu) 99 | x = tf.contrib.layers.conv1d(inputs=x, 100 | num_outputs=1, 101 | kernel_size=1, 102 | reuse=tf.AUTO_REUSE, 103 | scope='cbf/conv_4', 104 | activation_fn=None) 105 | x = x * mask 106 | return x, mask, indices 107 | 108 | 109 | def network_action(s, g, obs_radius=1.0, indices=None): 110 | x = tf.expand_dims(s, 1) - tf.expand_dims(s, 0) 111 | x = tf.concat([x, 112 | tf.expand_dims(tf.eye(tf.shape(x)[0]), 2)], axis=2) 113 | x, _ = remove_distant_agents(x=x, k=config.TOP_K, indices=indices) 114 | dist = tf.norm(x[:, :, :2], axis=2, keepdims=True) 115 | mask = tf.cast(tf.less(dist, obs_radius), tf.float32) 116 | x = tf.contrib.layers.conv1d(inputs=x, 117 | num_outputs=64, 118 | kernel_size=1, 119 | reuse=tf.AUTO_REUSE, 120 | scope='action/conv_1', 121 | activation_fn=tf.nn.relu) 122 | x = tf.contrib.layers.conv1d(inputs=x, 123 | num_outputs=128, 124 | kernel_size=1, 125 | reuse=tf.AUTO_REUSE, 126 | scope='action/conv_2', 127 | activation_fn=tf.nn.relu) 128 | x = tf.reduce_max(x * mask, axis=1) 129 | x = tf.concat([x, s[:, :2] - g, s[:, 2:]], axis=1) 130 | x = tf.contrib.layers.fully_connected(inputs=x, 131 | num_outputs=64, 132 | reuse=tf.AUTO_REUSE, 133 | scope='action/fc_1', 134 | activation_fn=tf.nn.relu) 135 | x = tf.contrib.layers.fully_connected(inputs=x, 136 | num_outputs=128, 137 | reuse=tf.AUTO_REUSE, 138 | scope='action/fc_2', 139 | activation_fn=tf.nn.relu) 140 | x = tf.contrib.layers.fully_connected(inputs=x, 141 | num_outputs=64, 142 | reuse=tf.AUTO_REUSE, 143 | scope='action/fc_3', 144 | activation_fn=tf.nn.relu) 145 | x = tf.contrib.layers.fully_connected(inputs=x, 146 | num_outputs=4, 147 | reuse=tf.AUTO_REUSE, 148 | scope='action/fc_4', 149 | activation_fn=None) 150 | x = 2.0 * tf.nn.sigmoid(x) + 0.2 151 | k_1, k_2, k_3, k_4 = 
tf.split(x, 4, axis=1) 152 | zeros = tf.zeros_like(k_1) 153 | gain_x = -tf.concat([k_1, zeros, k_2, zeros], axis=1) 154 | gain_y = -tf.concat([zeros, k_3, zeros, k_4], axis=1) 155 | state = tf.concat([s[:, :2] - g, s[:, 2:]], axis=1) 156 | a_x = tf.reduce_sum(state * gain_x, axis=1, keepdims=True) 157 | a_y = tf.reduce_sum(state * gain_y, axis=1, keepdims=True) 158 | a = tf.concat([a_x, a_y], axis=1) 159 | return a 160 | 161 | 162 | def dynamics(s, a): 163 | """ The ground robot dynamics. 164 | 165 | Args: 166 | s (N, 4): The current state. 167 | a (N, 2): The acceleration taken by each agent. 168 | Returns: 169 | dsdt (N, 4): The time derivative of s. 170 | """ 171 | dsdt = tf.concat([s[:, 2:], a], axis=1) 172 | return dsdt 173 | 174 | 175 | def loss_barrier(h, s, r, ttc, indices=None, eps=[1e-3, 0]): 176 | """ Build the loss function for the control barrier functions. 177 | 178 | Args: 179 | h (N, N, 1): The control barrier function. 180 | s (N, 4): The current state of N agents. 181 | r (float): The radius of the safe regions. 182 | ttc (float): The threshold of time to collision. 183 | """ 184 | 185 | h_reshape = tf.reshape(h, [-1]) 186 | dang_mask = ttc_dangerous_mask(s, r=r, ttc=ttc, indices=indices) 187 | dang_mask_reshape = tf.reshape(dang_mask, [-1]) 188 | safe_mask_reshape = tf.logical_not(dang_mask_reshape) 189 | 190 | dang_h = tf.boolean_mask(h_reshape, dang_mask_reshape) 191 | safe_h = tf.boolean_mask(h_reshape, safe_mask_reshape) 192 | 193 | num_dang = tf.cast(tf.shape(dang_h)[0], tf.float32) 194 | num_safe = tf.cast(tf.shape(safe_h)[0], tf.float32) 195 | 196 | loss_dang = tf.reduce_sum( 197 | tf.math.maximum(dang_h + eps[0], 0)) / (1e-5 + num_dang) 198 | loss_safe = tf.reduce_sum( 199 | tf.math.maximum(-safe_h + eps[1], 0)) / (1e-5 + num_safe) 200 | 201 | acc_dang = tf.reduce_sum(tf.cast( 202 | tf.less_equal(dang_h, 0), tf.float32)) / (1e-5 + num_dang) 203 | acc_safe = tf.reduce_sum(tf.cast( 204 | tf.greater(safe_h, 0), tf.float32)) / (1e-5 + num_safe) 205 | 206 | acc_dang = tf.cond( 207 | tf.greater(num_dang, 0), lambda: acc_dang, lambda: -tf.constant(1.0)) 208 | acc_safe = tf.cond( 209 | tf.greater(num_safe, 0), lambda: acc_safe, lambda: -tf.constant(1.0)) 210 | 211 | return loss_dang, loss_safe, acc_dang, acc_safe 212 | 213 | 214 | def loss_derivatives(s, a, h, x, r, ttc, alpha, indices=None, eps=[1e-3, 0]): 215 | dsdt = dynamics(s, a) 216 | s_next = s + dsdt * config.TIME_STEP 217 | 218 | x_next = tf.expand_dims(s_next, 1) - tf.expand_dims(s_next, 0) 219 | h_next, mask_next, _ = network_cbf(x=x_next, r=config.DIST_MIN_THRES, indices=indices) 220 | 221 | deriv = h_next - h + config.TIME_STEP * alpha * h 222 | 223 | deriv_reshape = tf.reshape(deriv, [-1]) 224 | dang_mask = ttc_dangerous_mask(s=s, r=r, ttc=ttc, indices=indices) 225 | dang_mask_reshape = tf.reshape(dang_mask, [-1]) 226 | safe_mask_reshape = tf.logical_not(dang_mask_reshape) 227 | 228 | dang_deriv = tf.boolean_mask(deriv_reshape, dang_mask_reshape) 229 | safe_deriv = tf.boolean_mask(deriv_reshape, safe_mask_reshape) 230 | 231 | num_dang = tf.cast(tf.shape(dang_deriv)[0], tf.float32) 232 | num_safe = tf.cast(tf.shape(safe_deriv)[0], tf.float32) 233 | 234 | loss_dang_deriv = tf.reduce_sum( 235 | tf.math.maximum(-dang_deriv + eps[0], 0)) / (1e-5 + num_dang) 236 | loss_safe_deriv = tf.reduce_sum( 237 | tf.math.maximum(-safe_deriv + eps[1], 0)) / (1e-5 + num_safe) 238 | 239 | acc_dang_deriv = tf.reduce_sum(tf.cast( 240 | tf.greater_equal(dang_deriv, 0), tf.float32)) / (1e-5 + num_dang) 241 | acc_safe_deriv = 
tf.reduce_sum(tf.cast( 242 | tf.greater_equal(safe_deriv, 0), tf.float32)) / (1e-5 + num_safe) 243 | 244 | acc_dang_deriv = tf.cond( 245 | tf.greater(num_dang, 0), lambda: acc_dang_deriv, lambda: -tf.constant(1.0)) 246 | acc_safe_deriv = tf.cond( 247 | tf.greater(num_safe, 0), lambda: acc_safe_deriv, lambda: -tf.constant(1.0)) 248 | 249 | return loss_dang_deriv, loss_safe_deriv, acc_dang_deriv, acc_safe_deriv 250 | 251 | 252 | def loss_actions(s, g, a, r, ttc): 253 | state_gain = -tf.constant( 254 | np.eye(2, 4) + np.eye(2, 4, k=2) * np.sqrt(3), dtype=tf.float32) 255 | s_ref = tf.concat([s[:, :2] - g, s[:, 2:]], axis=1) 256 | action_ref = tf.linalg.matmul(s_ref, state_gain, False, True) 257 | action_ref_norm = tf.reduce_sum(tf.square(action_ref), axis=1) 258 | action_net_norm = tf.reduce_sum(tf.square(a), axis=1) 259 | norm_diff = tf.abs(action_net_norm - action_ref_norm) 260 | loss = tf.reduce_mean(norm_diff) 261 | return loss 262 | 263 | 264 | def statics(s, a, h, alpha, indices=None): 265 | dsdt = dynamics(s, a) 266 | s_next = s + dsdt * config.TIME_STEP 267 | 268 | x_next = tf.expand_dims(s_next, 1) - tf.expand_dims(s_next, 0) 269 | h_next, mask_next, _ = network_cbf(x=x_next, r=config.DIST_MIN_THRES, indices=indices) 270 | 271 | deriv = h_next - h + config.TIME_STEP * alpha * h 272 | 273 | mean_deriv = tf.reduce_mean(deriv) 274 | std_deriv = tf.sqrt(tf.reduce_mean(tf.square(deriv - mean_deriv))) 275 | prob_neg = tf.reduce_mean(tf.cast(tf.less(deriv, 0), tf.float32)) 276 | 277 | return mean_deriv, std_deriv, prob_neg 278 | 279 | 280 | def ttc_dangerous_mask(s, r, ttc, indices=None): 281 | s_diff = tf.expand_dims(s, 1) - tf.expand_dims(s, 0) 282 | s_diff = tf.concat( 283 | [s_diff, tf.expand_dims(tf.eye(tf.shape(s)[0]), 2)], axis=2) 284 | s_diff, _ = remove_distant_agents(s_diff, config.TOP_K, indices) 285 | x, y, vx, vy, eye = tf.split(s_diff, 5, axis=2) 286 | x = x + eye 287 | y = y + eye 288 | alpha = vx ** 2 + vy ** 2 289 | beta = 2 * (x * vx + y * vy) 290 | gamma = x ** 2 + y ** 2 - r ** 2 291 | dist_dangerous = tf.less(gamma, 0) 292 | 293 | has_two_positive_roots = tf.logical_and( 294 | tf.greater(beta ** 2 - 4 * alpha * gamma, 0), 295 | tf.logical_and(tf.greater(gamma, 0), tf.less(beta, 0))) 296 | root_less_than_ttc = tf.logical_or( 297 | tf.less(-beta - 2 * alpha * ttc, 0), 298 | tf.less((beta + 2 * alpha * ttc) ** 2, beta ** 2 - 4 * alpha * gamma)) 299 | has_root_less_than_ttc = tf.logical_and(has_two_positive_roots, root_less_than_ttc) 300 | ttc_dangerous = tf.logical_or(dist_dangerous, has_root_less_than_ttc) 301 | 302 | return ttc_dangerous 303 | 304 | 305 | def ttc_dangerous_mask_np(s, r, ttc): 306 | s_diff = np.expand_dims(s, 1) - np.expand_dims(s, 0) 307 | x, y, vx, vy = np.split(s_diff, 4, axis=2) 308 | x = x + np.expand_dims(np.eye(np.shape(s)[0]), 2) 309 | y = y + np.expand_dims(np.eye(np.shape(s)[0]), 2) 310 | alpha = vx ** 2 + vy ** 2 311 | beta = 2 * (x * vx + y * vy) 312 | gamma = x ** 2 + y ** 2 - r ** 2 313 | dist_dangerous = np.less(gamma, 0) 314 | 315 | has_two_positive_roots = np.logical_and( 316 | np.greater(beta ** 2 - 4 * alpha * gamma, 0), 317 | np.logical_and(np.greater(gamma, 0), np.less(beta, 0))) 318 | root_less_than_ttc = np.logical_or( 319 | np.less(-beta - 2 * alpha * ttc, 0), 320 | np.less((beta + 2 * alpha * ttc) ** 2, beta ** 2 - 4 * alpha * gamma)) 321 | has_root_less_than_ttc = np.logical_and(has_two_positive_roots, root_less_than_ttc) 322 | ttc_dangerous = np.logical_or(dist_dangerous, has_root_less_than_ttc) 323 | 324 | return ttc_dangerous 
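# [Editor's note] The sketch below is illustrative and not part of the original
# file. It is a quick numeric check of ttc_dangerous_mask_np above, using
# hypothetical values. Each state row is [x, y, vx, vy]: two agents one unit
# apart and closing at relative speed 2 reach distance zero at t = 0.5, so the
# pair is flagged once the time-to-collision horizon `ttc` covers 0.5:
#
#     s = np.array([[0.0, 0.0, 1.0, 0.0],
#                   [1.0, 0.0, -1.0, 0.0]], dtype=np.float32)
#     mask = ttc_dangerous_mask_np(s, r=config.DIST_MIN_CHECK, ttc=0.5)
#     assert mask[0, 1, 0] and mask[1, 0, 0]  # mask has shape (N, N, 1)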
325 | 
326 | 
327 | def remove_distant_agents(x, k, indices=None):
328 |     n, _, c = x.get_shape().as_list()
329 |     if n <= k:
330 |         return x, False
331 |     d_norm = tf.sqrt(tf.reduce_sum(tf.square(x[:, :, :2]) + 1e-6, axis=2))
332 |     if indices is not None:
333 |         x = tf.reshape(tf.gather_nd(x, indices), [n, k, c])
334 |         return x, indices
335 |     _, indices = tf.nn.top_k(-d_norm, k=k)
336 |     row_indices = tf.expand_dims(
337 |         tf.range(tf.shape(indices)[0]), 1) * tf.ones_like(indices)
338 |     row_indices = tf.reshape(row_indices, [-1, 1])
339 |     column_indices = tf.reshape(indices, [-1, 1])
340 |     indices = tf.concat([row_indices, column_indices], axis=1)
341 |     x = tf.reshape(tf.gather_nd(x, indices), [n, k, c])
342 |     return x, indices
--------------------------------------------------------------------------------
/cars/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.dont_write_bytecode = True
3 | 
4 | import os
5 | import time
6 | import argparse
7 | import numpy as np
8 | import tensorflow as tf
9 | import matplotlib.pyplot as plt
10 | 
11 | import core
12 | import config
13 | 
14 | def parse_args():
15 |     parser = argparse.ArgumentParser()
16 |     parser.add_argument('--num_agents', type=int, required=True)
17 |     parser.add_argument('--model_path', type=str, default=None)
18 |     parser.add_argument('--vis', type=int, default=0)
19 |     parser.add_argument('--gpu', type=str, default='0')
20 |     args = parser.parse_args()
21 |     return args
22 | 
23 | 
24 | def build_evaluation_graph(num_agents):
25 |     # s is the state vectors of the agents
26 |     s = tf.placeholder(tf.float32, [num_agents, 4])
27 |     # g is the goal states
28 |     g = tf.placeholder(tf.float32, [num_agents, 2])
29 |     # x is the difference between the state of each agent and the other agents
30 |     x = tf.expand_dims(s, 1) - tf.expand_dims(s, 0)
31 |     # h is the CBF value of shape [num_agents, TOP_K, 1], where TOP_K represents
32 |     # the K nearest agents
33 |     h, mask, indices = core.network_cbf(x=x, r=config.DIST_MIN_THRES)
34 |     # a is the control action of each agent, with shape [num_agents, 2]
35 |     a = core.network_action(s=s, g=g, obs_radius=config.OBS_RADIUS, indices=indices)
36 |     # a_res is delta a. When a does not satisfy the CBF conditions, we want to compute
37 |     # an a_res such that a + a_res satisfies the CBF conditions
38 |     a_res = tf.Variable(tf.zeros_like(a), name='a_res')
39 |     loop_count = tf.Variable(0, name='loop_count')
40 | 
41 |     def opt_body(a_res, loop_count):
42 |         # a loop of updating a_res
43 |         # compute s_next under a + a_res
44 |         dsdt = core.dynamics(s, a + a_res)
45 |         s_next = s + dsdt * config.TIME_STEP
46 |         x_next = tf.expand_dims(s_next, 1) - tf.expand_dims(s_next, 0)
47 |         h_next, mask_next, _ = core.network_cbf(
48 |             x=x_next, r=config.DIST_MIN_THRES, indices=indices)
49 |         # deriv should be >= 0. 
if not, we update a_res by gradient descent 50 | deriv = h_next - h + config.TIME_STEP * config.ALPHA_CBF * h 51 | deriv = deriv * mask * mask_next 52 | error = tf.reduce_sum(tf.math.maximum(-deriv, 0), axis=1) 53 | # compute the gradient to update a_res 54 | error_gradient = tf.gradients(error, a_res)[0] 55 | a_res = a_res - config.REFINE_LEARNING_RATE * error_gradient 56 | loop_count = loop_count + 1 57 | return a_res, loop_count 58 | 59 | def opt_cond(a_res, loop_count): 60 | # update u_res for REFINE_LOOPS 61 | cond = tf.less(loop_count, config.REFINE_LOOPS) 62 | return cond 63 | 64 | with tf.control_dependencies([ 65 | a_res.assign(tf.zeros_like(a)), loop_count.assign(0)]): 66 | a_res, _ = tf.while_loop(opt_cond, opt_body, [a_res, loop_count]) 67 | a_opt = a + a_res 68 | 69 | dsdt = core.dynamics(s, a_opt) 70 | s_next = s + dsdt * config.TIME_STEP 71 | x_next = tf.expand_dims(s_next, 1) - tf.expand_dims(s_next, 0) 72 | h_next, mask_next, _ = core.network_cbf(x=x_next, r=config.DIST_MIN_THRES, indices=indices) 73 | 74 | # compute the value of loss functions and the accuracies 75 | # loss_dang is for h(s) < 0, s in dangerous set 76 | # loss safe is for h(s) >=0, s in safe set 77 | # acc_dang is the accuracy that h(s) < 0, s in dangerous set is satisfied 78 | # acc_safe is the accuracy that h(s) >=0, s in safe set is satisfied 79 | (loss_dang, loss_safe, acc_dang, acc_safe) = core.loss_barrier( 80 | h=h_next, s=s_next, r=config.DIST_MIN_THRES, 81 | ttc=config.TIME_TO_COLLISION, eps=[0, 0]) 82 | # loss_dang_deriv is for doth(s) + alpha h(s) >=0 for s in dangerous set 83 | # loss_safe_deriv is for doth(s) + alpha h(s) >=0 for s in safe set 84 | # loss_medium_deriv is for doth(s) + alpha h(s) >=0 for s not in the dangerous 85 | # or the safe set 86 | (loss_dang_deriv, loss_safe_deriv, acc_dang_deriv, acc_safe_deriv 87 | ) = core.loss_derivatives(s=s_next, a=a_opt, h=h_next, x=x_next, 88 | r=config.DIST_MIN_THRES, ttc=config.TIME_TO_COLLISION, alpha=config.ALPHA_CBF, indices=indices) 89 | # the distance between the u_opt and the nominal u 90 | loss_action = core.loss_actions(s, g, a, r=config.DIST_MIN_THRES, ttc=config.TIME_TO_COLLISION) 91 | 92 | loss_list = [loss_dang, loss_safe, loss_dang_deriv, loss_safe_deriv, loss_action] 93 | acc_list = [acc_dang, acc_safe, acc_dang_deriv, acc_safe_deriv] 94 | 95 | return s, g, a_opt, loss_list, acc_list 96 | 97 | 98 | def print_accuracy(accuracy_lists): 99 | acc = np.array(accuracy_lists) 100 | acc_list = [] 101 | for i in range(acc.shape[1]): 102 | acc_i = acc[:, i] 103 | acc_list.append(np.mean(acc_i[acc_i > 0])) 104 | print('Accuracy: {}'.format(acc_list)) 105 | 106 | 107 | def render_init(): 108 | fig = plt.figure(figsize=(9, 4)) 109 | return fig 110 | 111 | 112 | def main(): 113 | args = parse_args() 114 | s, g, a, loss_list, acc_list = build_evaluation_graph(args.num_agents) 115 | # loads the pretrained weights 116 | vars = tf.trainable_variables() 117 | vars_restore = [] 118 | for v in vars: 119 | if 'action' in v.name or 'cbf' in v.name: 120 | vars_restore.append(v) 121 | # initialize the tensorflow Session 122 | sess = tf.Session() 123 | sess.run(tf.global_variables_initializer()) 124 | saver = tf.train.Saver(var_list=vars_restore) 125 | saver.restore(sess, args.model_path) 126 | 127 | safety_ratios_epoch = [] 128 | safety_ratios_epoch_lqr = [] 129 | 130 | dist_errors = [] 131 | init_dist_errors = [] 132 | accuracy_lists = [] 133 | 134 | safety_reward = [] 135 | dist_reward = [] 136 | safety_reward_baseline = [] 137 | dist_reward_baseline = 
[] 138 | 139 | if args.vis: 140 | plt.ion() 141 | plt.close() 142 | fig = render_init() 143 | 144 | for istep in range(config.EVALUATE_STEPS): 145 | start_time = time.time() 146 | 147 | safety_info = [] 148 | safety_info_baseline = [] 149 | 150 | # randomly generate the initial conditions s_np_ori and the goal states g_np_ori 151 | s_np_ori, g_np_ori = core.generate_data(args.num_agents, config.DIST_MIN_THRES * 1.5) 152 | 153 | s_np, g_np = np.copy(s_np_ori), np.copy(g_np_ori) 154 | init_dist_errors.append(np.mean(np.linalg.norm(s_np[:, :2] - g_np, axis=1))) 155 | # store the trajectory for visualization 156 | s_np_ours = [] 157 | s_np_lqr = [] 158 | 159 | safety_ours = [] 160 | safety_lqr = [] 161 | # run INNER_LOOPS steps to reach the current goals 162 | for i in range(config.INNER_LOOPS): 163 | # compute the control input 164 | a_network, acc_list_np = sess.run([a, acc_list], feed_dict={s: s_np, g: g_np}) 165 | dsdt = np.concatenate([s_np[:, 2:], a_network], axis=1) 166 | # simulate the system for one step 167 | s_np = s_np + dsdt * config.TIME_STEP 168 | s_np_ours.append(s_np) 169 | safety_ratio = 1 - np.mean(core.ttc_dangerous_mask_np( 170 | s_np, config.DIST_MIN_CHECK, config.TIME_TO_COLLISION_CHECK), axis=1) 171 | safety_ours.append(safety_ratio) 172 | safety_info.append((safety_ratio == 1).astype(np.float32).reshape((1, -1))) 173 | safety_ratio = np.mean(safety_ratio == 1) 174 | safety_ratios_epoch.append(safety_ratio) 175 | accuracy_lists.append(acc_list_np) 176 | 177 | if args.vis: 178 | # break if the agents are already very close to their goals 179 | if np.amax( 180 | np.linalg.norm(s_np[:, :2] - g_np, axis=1) 181 | ) < config.DIST_MIN_CHECK / 3: 182 | time.sleep(1) 183 | break 184 | # if the agents are very close to their goals, safely switch to LQR 185 | if np.mean( 186 | np.linalg.norm(s_np[:, :2] - g_np, axis=1) 187 | ) < config.DIST_MIN_CHECK / 2: 188 | K = np.eye(2, 4) + np.eye(2, 4, k=2) * np.sqrt(3) 189 | s_ref = np.concatenate([s_np[:, :2] - g_np, s_np[:, 2:]], axis=1) 190 | a_lqr = -s_ref.dot(K.T) 191 | s_np = s_np + np.concatenate( 192 | [s_np[:, 2:], a_lqr], axis=1) * config.TIME_STEP 193 | else: 194 | if np.mean( 195 | np.linalg.norm(s_np[:, :2] - g_np, axis=1) 196 | ) < config.DIST_MIN_CHECK: 197 | break 198 | 199 | dist_errors.append(np.mean(np.linalg.norm(s_np[:, :2] - g_np, axis=1))) 200 | safety_reward.append(np.mean(np.sum(np.concatenate(safety_info, axis=0) - 1, axis=0))) 201 | dist_reward.append(np.mean( 202 | (np.linalg.norm(s_np[:, :2] - g_np, axis=1) < 0.2).astype(np.float32) * 10)) 203 | 204 | # run the simulation using LQR controller without considering collision 205 | s_np, g_np = np.copy(s_np_ori), np.copy(g_np_ori) 206 | for i in range(config.INNER_LOOPS): 207 | K = np.eye(2, 4) + np.eye(2, 4, k=2) * np.sqrt(3) 208 | s_ref = np.concatenate([s_np[:, :2] - g_np, s_np[:, 2:]], axis=1) 209 | a_lqr = -s_ref.dot(K.T) 210 | s_np = s_np + np.concatenate([s_np[:, 2:], a_lqr], axis=1) * config.TIME_STEP 211 | s_np_lqr.append(s_np) 212 | safety_ratio = 1 - np.mean(core.ttc_dangerous_mask_np( 213 | s_np, config.DIST_MIN_CHECK, config.TIME_TO_COLLISION_CHECK), axis=1) 214 | safety_lqr.append(safety_ratio) 215 | safety_info_baseline.append((safety_ratio == 1).astype(np.float32).reshape((1, -1))) 216 | safety_ratio = np.mean(safety_ratio == 1) 217 | safety_ratios_epoch_lqr.append(safety_ratio) 218 | if np.mean( 219 | np.linalg.norm(s_np[:, :2] - g_np, axis=1) 220 | ) < config.DIST_MIN_CHECK / 3: 221 | break 222 | 223 | safety_reward_baseline.append(np.mean( 224 | 
np.sum(np.concatenate(safety_info_baseline, axis=0) - 1, axis=0))) 225 | dist_reward_baseline.append(np.mean( 226 | (np.linalg.norm(s_np[:, :2] - g_np, axis=1) < 0.2).astype(np.float32) * 10)) 227 | 228 | if args.vis: 229 | # visualize the trajectories 230 | vis_range = max(1, np.amax(np.abs(s_np_ori[:, :2]))) 231 | agent_size = 100 / vis_range ** 2 232 | g_np = g_np / vis_range 233 | for j in range(max(len(s_np_ours), len(s_np_lqr))): 234 | plt.clf() 235 | 236 | plt.subplot(121) 237 | j_ours = min(j, len(s_np_ours)-1) 238 | s_np = s_np_ours[j_ours] / vis_range 239 | plt.scatter(s_np[:, 0], s_np[:, 1], 240 | color='darkorange', 241 | s=agent_size, label='Agent', alpha=0.6) 242 | plt.scatter(g_np[:, 0], g_np[:, 1], 243 | color='deepskyblue', 244 | s=agent_size, label='Target', alpha=0.6) 245 | safety = np.squeeze(safety_ours[j_ours]) 246 | plt.scatter(s_np[safety<1, 0], s_np[safety<1, 1], 247 | color='red', 248 | s=agent_size, label='Collision', alpha=0.9) 249 | plt.xlim(-0.5, 1.5) 250 | plt.ylim(-0.5, 1.5) 251 | ax = plt.gca() 252 | for side in ax.spines.keys(): 253 | ax.spines[side].set_linewidth(2) 254 | ax.spines[side].set_color('grey') 255 | ax.set_xticklabels([]) 256 | ax.set_yticklabels([]) 257 | plt.legend(loc='upper right', fontsize=14) 258 | plt.title('Ours: Safety Rate = {:.3f}'.format( 259 | np.mean(safety_ratios_epoch)), fontsize=14) 260 | 261 | plt.subplot(122) 262 | j_lqr = min(j, len(s_np_lqr)-1) 263 | s_np = s_np_lqr[j_lqr] / vis_range 264 | plt.scatter(s_np[:, 0], s_np[:, 1], 265 | color='darkorange', 266 | s=agent_size, label='Agent', alpha=0.6) 267 | plt.scatter(g_np[:, 0], g_np[:, 1], 268 | color='deepskyblue', 269 | s=agent_size, label='Target', alpha=0.6) 270 | safety = np.squeeze(safety_lqr[j_lqr]) 271 | plt.scatter(s_np[safety<1, 0], s_np[safety<1, 1], 272 | color='red', 273 | s=agent_size, label='Collision', alpha=0.9) 274 | plt.xlim(-0.5, 1.5) 275 | plt.ylim(-0.5, 1.5) 276 | ax = plt.gca() 277 | for side in ax.spines.keys(): 278 | ax.spines[side].set_linewidth(2) 279 | ax.spines[side].set_color('grey') 280 | ax.set_xticklabels([]) 281 | ax.set_yticklabels([]) 282 | plt.legend(loc='upper right', fontsize=14) 283 | plt.title('LQR: Safety Rate = {:.3f}'.format( 284 | np.mean(safety_ratios_epoch_lqr)), fontsize=14) 285 | 286 | fig.canvas.draw() 287 | plt.pause(0.01) 288 | plt.clf() 289 | 290 | 291 | end_time = time.time() 292 | print('Evaluation Step: {} | {}, Time: {:.4f}'.format( 293 | istep + 1, config.EVALUATE_STEPS, end_time - start_time)) 294 | 295 | print_accuracy(accuracy_lists) 296 | print('Distance Error (Final | Inititial): {:.4f} | {:.4f}'.format( 297 | np.mean(dist_errors), np.mean(init_dist_errors))) 298 | print('Mean Safety Ratio (Learning | LQR): {:.4f} | {:.4f}'.format( 299 | np.mean(safety_ratios_epoch), np.mean(safety_ratios_epoch_lqr))) 300 | print('Reward Safety (Learning | LQR): {:.4f} | {:.4f}, Reward Distance: {:.4f} | {:.4f}'.format( 301 | np.mean(safety_reward), np.mean(safety_reward_baseline), 302 | np.mean(dist_reward), np.mean(dist_reward_baseline))) 303 | 304 | 305 | if __name__ == '__main__': 306 | main() 307 | -------------------------------------------------------------------------------- /cars/models/model_save.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-REALM/macbf/7e364b7a65801debaccc2f60673323385e6f05d3/cars/models/model_save.data-00000-of-00001 -------------------------------------------------------------------------------- 
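Editor's note: the nominal controller used by `cars/evaluate.py` above and `cars/train.py` below is the fixed state-feedback gain `K = np.eye(2, 4) + np.eye(2, 4, k=2) * np.sqrt(3)`. A minimal standalone sketch, assuming only `numpy` (this script is not part of the repository), of why this gain stabilizes the double-integrator dynamics:

```python
import numpy as np

# Double integrator: state s = [x, y, vx, vy], ds/dt = A s + B a.
A = np.zeros((4, 4))
A[0, 2] = A[1, 3] = 1.0   # position derivatives are the velocities
B = np.zeros((4, 2))
B[2, 0] = B[3, 1] = 1.0   # the action is the acceleration

# The nominal gain from cars/evaluate.py and cars/train.py:
# a = -K s_ref with K = [[1, 0, sqrt(3), 0], [0, 1, 0, sqrt(3)]].
K = np.eye(2, 4) + np.eye(2, 4, k=2) * np.sqrt(3)

# Closed loop A - B K: each axis obeys x'' + sqrt(3) x' + x = 0, i.e.
# natural frequency 1 and damping ratio sqrt(3)/2 ~= 0.87.
poles = np.linalg.eigvals(A - B.dot(K))
print(poles)                   # -0.866 +/- 0.5j, once per axis
assert np.all(poles.real < 0)  # stable: the goal-tracking error decays
```

This is why the evaluation loop can safely hand control over to this gain once the agents are close to their goals: near the goal the collision-avoidance terms are inactive and the nominal feedback alone drives the tracking error to zero.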
/cars/models/model_save.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/macbf/7e364b7a65801debaccc2f60673323385e6f05d3/cars/models/model_save.index
--------------------------------------------------------------------------------
/cars/models/model_save.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/macbf/7e364b7a65801debaccc2f60673323385e6f05d3/cars/models/model_save.meta
--------------------------------------------------------------------------------
/cars/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.dont_write_bytecode = True
3 | 
4 | import os
5 | import h5py
6 | import argparse
7 | import numpy as np
8 | import tensorflow as tf
9 | 
10 | import core
11 | import config
12 | 
13 | np.set_printoptions(4)
14 | 
15 | def parse_args():
16 |     parser = argparse.ArgumentParser()
17 |     parser.add_argument('--num_agents', type=int, required=True)
18 |     parser.add_argument('--model_path', type=str, default=None)
19 |     parser.add_argument('--gpu', type=str, default='0')
20 |     args = parser.parse_args()
21 |     return args
22 | 
23 | 
24 | def build_optimizer(loss):
25 |     optimizer = tf.train.AdamOptimizer(learning_rate=config.LEARNING_RATE)
26 |     trainable_vars = tf.trainable_variables()
27 | 
28 |     # tensor to accumulate gradients over multiple steps
29 |     accumulators = [
30 |         tf.Variable(
31 |             tf.zeros_like(tv.initialized_value()),
32 |             trainable=False
33 |         ) for tv in trainable_vars]
34 | 
35 |     # count how many steps we have accumulated
36 |     accumulation_counter = tf.Variable(0.0, trainable=False)
37 |     grad_pairs = optimizer.compute_gradients(loss, trainable_vars)
38 |     # add the gradient to the accumulation tensor
39 |     accumulate_ops = [
40 |         accumulator.assign_add(
41 |             grad
42 |         ) for (accumulator, (grad, var)) in zip(accumulators, grad_pairs)]
43 | 
44 |     accumulate_ops.append(accumulation_counter.assign_add(1.0))
45 |     # divide the accumulated gradient by the number of accumulation steps
46 |     gradient_vars = [(accumulator / accumulation_counter, var) \
47 |         for (accumulator, (grad, var)) in zip(accumulators, grad_pairs)]
48 |     # separate the gradients of the CBF and the controller
49 |     gradient_vars_h = []
50 |     gradient_vars_a = []
51 |     for accumulate_grad, var in gradient_vars:
52 |         if 'cbf' in var.name:
53 |             gradient_vars_h.append((accumulate_grad, var))
54 |         elif 'action' in var.name:
55 |             gradient_vars_a.append((accumulate_grad, var))
56 |         else:
57 |             raise ValueError
58 | 
59 |     train_step_h = optimizer.apply_gradients(gradient_vars_h)
60 |     train_step_a = optimizer.apply_gradients(gradient_vars_a)
61 |     # re-initialize the accumulation tensor and accumulation step to zero
62 |     zero_ops = [
63 |         accumulator.assign(
64 |             tf.zeros_like(tv)
65 |         ) for (accumulator, tv) in zip(accumulators, trainable_vars)]
66 |     zero_ops.append(accumulation_counter.assign(0.0))
67 | 
68 |     return zero_ops, accumulate_ops, train_step_h, train_step_a
69 | 
70 | 
71 | def build_training_graph(num_agents):
72 |     # s is the state vectors of the agents
73 |     s = tf.placeholder(tf.float32, [num_agents, 4])
74 |     # g is the goal states
75 |     g = tf.placeholder(tf.float32, [num_agents, 2])
76 |     # x is the difference between the state of each agent and the other agents
77 |     x = tf.expand_dims(s, 1) - tf.expand_dims(s, 0)
78 |     # h is the CBF value of shape [num_agents, TOP_K, 1], where TOP_K represents
79 |     # the K nearest agents
80 | 
h, mask, indices = core.network_cbf(x=x, r=config.DIST_MIN_THRES, indices=None) 81 | # a is the control action of each agent, with shape [num_agents, 2] 82 | a = core.network_action(s=s, g=g, obs_radius=config.OBS_RADIUS, indices=indices) 83 | # compute the value of loss functions and the accuracies 84 | # loss_dang is for h(s) < 0, s in dangerous set 85 | # loss safe is for h(s) >=0, s in safe set 86 | # acc_dang is the accuracy that h(s) < 0, s in dangerous set is satisfied 87 | # acc_safe is the accuracy that h(s) >=0, s in safe set is satisfied 88 | (loss_dang, loss_safe, acc_dang, acc_safe) = core.loss_barrier( 89 | h=h, s=s, r=config.DIST_MIN_THRES, ttc=config.TIME_TO_COLLISION, indices=indices) 90 | # loss_dang_deriv is for doth(s) + alpha h(s) >=0 for s in dangerous set 91 | # loss_safe_deriv is for doth(s) + alpha h(s) >=0 for s in safe set 92 | # loss_medium_deriv is for doth(s) + alpha h(s) >=0 for s not in the dangerous 93 | # or the safe set 94 | (loss_dang_deriv, loss_safe_deriv, acc_dang_deriv, acc_safe_deriv 95 | ) = core.loss_derivatives(s=s, a=a, h=h, x=x, r=config.DIST_MIN_THRES, 96 | indices=indices, ttc=config.TIME_TO_COLLISION, alpha=config.ALPHA_CBF) 97 | # the distance between the a and the nominal a 98 | loss_action = core.loss_actions( 99 | s=s, g=g, a=a, r=config.DIST_MIN_THRES, ttc=config.TIME_TO_COLLISION) 100 | 101 | loss_list = [2 * loss_dang, loss_safe, 2 * loss_dang_deriv, loss_safe_deriv, 0.01 * loss_action] 102 | acc_list = [acc_dang, acc_safe, acc_dang_deriv, acc_safe_deriv] 103 | 104 | weight_loss = [ 105 | config.WEIGHT_DECAY * tf.nn.l2_loss(v) for v in tf.trainable_variables()] 106 | loss = 10 * tf.math.add_n(loss_list + weight_loss) 107 | 108 | return s, g, a, loss_list, loss, acc_list 109 | 110 | 111 | def count_accuracy(accuracy_lists): 112 | acc = np.array(accuracy_lists) 113 | acc_list = [] 114 | for i in range(acc.shape[1]): 115 | acc_list.append(np.mean(acc[acc[:, i] >= 0, i])) 116 | return acc_list 117 | 118 | 119 | def main(): 120 | args = parse_args() 121 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 122 | 123 | s, g, a, loss_list, loss, acc_list = build_training_graph(args.num_agents) 124 | zero_ops, accumulate_ops, train_step_h, train_step_a = build_optimizer(loss) 125 | accumulate_ops.append(loss_list) 126 | accumulate_ops.append(acc_list) 127 | 128 | accumulation_steps = config.INNER_LOOPS 129 | 130 | with tf.Session() as sess: 131 | sess.run(tf.global_variables_initializer()) 132 | saver = tf.train.Saver() 133 | 134 | if args.model_path: 135 | saver.restore(sess, args.model_path) 136 | 137 | state_gain = np.eye(2, 4) + np.eye(2, 4, k=2) * np.sqrt(3) 138 | 139 | loss_lists_np = [] 140 | acc_lists_np = [] 141 | dist_errors_np = [] 142 | init_dist_errors_np = [] 143 | 144 | safety_ratios_epoch = [] 145 | safety_ratios_epoch_lqr = [] 146 | 147 | for istep in range(config.TRAIN_STEPS): 148 | # randomly generate the initial states and goals 149 | s_np, g_np = core.generate_data(args.num_agents, config.DIST_MIN_THRES) 150 | s_np_lqr, g_np_lqr = np.copy(s_np), np.copy(g_np) 151 | init_dist_errors_np.append(np.mean(np.linalg.norm(s_np[:, :2] - g_np, axis=1))) 152 | sess.run(zero_ops) 153 | # run the system with the safe controller 154 | for i in range(accumulation_steps): 155 | # computes the control input a_np using the safe controller 156 | a_np, out = sess.run([a, accumulate_ops], feed_dict={s:s_np, g: g_np}) 157 | if np.random.uniform() < config.ADD_NOISE_PROB: 158 | noise = np.random.normal(size=np.shape(a_np)) * config.NOISE_SCALE 159 | 
a_np = a_np + noise
160 |                 # simulate the system for one step
161 |                 s_np = s_np + np.concatenate([s_np[:, 2:], a_np], axis=1) * config.TIME_STEP
162 |                 # computes the safety rate
163 |                 safety_ratio = 1 - np.mean(core.ttc_dangerous_mask_np(
164 |                     s_np, config.DIST_MIN_CHECK, config.TIME_TO_COLLISION_CHECK), axis=1)
165 |                 safety_ratio = np.mean(safety_ratio == 1)
166 |                 safety_ratios_epoch.append(safety_ratio)
167 |                 loss_list_np, acc_list_np = out[-2], out[-1]
168 |                 loss_lists_np.append(loss_list_np)
169 |                 acc_lists_np.append(acc_list_np)
170 | 
171 |                 if np.mean(
172 |                     np.linalg.norm(s_np[:, :2] - g_np, axis=1)
173 |                     ) < config.DIST_MIN_CHECK:
174 |                     break
175 |             # run the system with the LQR controller without collision avoidance as the baseline
176 |             for i in range(accumulation_steps):
177 |                 state_gain = np.eye(2, 4) + np.eye(2, 4, k=2) * np.sqrt(3)
178 |                 s_ref_lqr = np.concatenate([s_np_lqr[:, :2] - g_np_lqr, s_np_lqr[:, 2:]], axis=1)
179 |                 a_lqr = -s_ref_lqr.dot(state_gain.T)
180 |                 s_np_lqr = s_np_lqr + np.concatenate([s_np_lqr[:, 2:], a_lqr], axis=1) * config.TIME_STEP
181 |                 s_np_lqr[:, :2] = np.clip(s_np_lqr[:, :2], 0, 1)
182 |                 safety_ratio_lqr = 1 - np.mean(core.ttc_dangerous_mask_np(
183 |                     s_np_lqr, config.DIST_MIN_CHECK, config.TIME_TO_COLLISION_CHECK), axis=1)
184 |                 safety_ratio_lqr = np.mean(safety_ratio_lqr == 1)
185 |                 safety_ratios_epoch_lqr.append(safety_ratio_lqr)
186 | 
187 |                 if np.mean(
188 |                     np.linalg.norm(s_np_lqr[:, :2] - g_np_lqr, axis=1)
189 |                     ) < config.DIST_MIN_CHECK:
190 |                     break
191 |             dist_errors_np.append(np.mean(np.linalg.norm(s_np[:, :2] - g_np, axis=1)))
192 | 
193 |             if np.mod(istep // 10, 2) == 0:
194 |                 sess.run(train_step_h)
195 |             else:
196 |                 sess.run(train_step_a)
197 | 
198 |             if np.mod(istep, config.DISPLAY_STEPS) == 0:
199 |                 print('Step: {}, Loss: {}, Accuracy: {}'.format(
200 |                     istep, np.mean(loss_lists_np, axis=0),
201 |                     np.array(count_accuracy(acc_lists_np))))
202 |                 loss_lists_np, acc_lists_np, dist_errors_np, safety_ratios_epoch, safety_ratios_epoch_lqr = [], [], [], [], []
203 | 
204 |             if np.mod(istep, config.SAVE_STEPS) == 0 or istep + 1 == config.TRAIN_STEPS:
205 |                 saver.save(sess, 'models/model_iter_{}'.format(istep))
206 | 
207 | 
208 | if __name__ == '__main__':
209 |     main()
--------------------------------------------------------------------------------
/drones/config.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.dont_write_bytecode = True
3 | 
4 | A_MAT = [[0, 0, 0, 1, 0, 0, 0, 0],
5 |          [0, 0, 0, 0, 1, 0, 0, 0],
6 |          [0, 0, 0, 0, 0, 1, 0, 0],
7 |          [0, 0, 0, 0, 0, 0, 1, 0],
8 |          [0, 0, 0, 0, 0, 0, 0, 1],
9 |          [0, 0, 0, 0, 0, 0, 0, 0],
10 |          [0, 0, 0, 0, 0, 0, 0, 0],
11 |          [0, 0, 0, 0, 0, 0, 0, 0]]
12 | B_MAT = [[0, 0, 0],
13 |          [0, 0, 0],
14 |          [0, 0, 0],
15 |          [0, 0, 0],
16 |          [0, 0, 0],
17 |          [0, 0, 1],
18 |          [1, 0, 0],
19 |          [0, 1, 0]]
20 | 
21 | K_MAT = [[1, 0, 0, 3, 0, 0, 3, 0],
22 |          [0, 1, 0, 0, 3, 0, 0, 3],
23 |          [0, 0, 1, 0, 0, 3, 0, 0]]
24 | 
25 | TIME_STEP = 0.05
26 | TIME_STEP_EVAL = 0.05
27 | TOP_K = 8
28 | OBS_RADIUS = 1.0
29 | 
30 | DIST_MIN_THRES = 0.8
31 | DIST_MIN_CHECK = 0.6
32 | DIST_SAFE = 1.0
33 | DIST_TOLERATE = 0.7
34 | 
35 | ALPHA_CBF = 1.0
36 | WEIGHT_DECAY = 1e-8
37 | 
38 | TRAIN_STEPS = 70000
39 | EVALUATE_STEPS = 5
40 | INNER_LOOPS = 40
41 | INNER_LOOPS_EVAL = 40
42 | DISPLAY_STEPS = 10
43 | SAVE_STEPS = 200
44 | 
45 | LEARNING_RATE = 1e-4
46 | REFINE_LEARNING_RATE = 1.0
47 | REFINE_LOOPS = 40
48 | 
--------------------------------------------------------------------------------
/drones/core.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | sys.dont_write_bytecode = True 3 | import os 4 | from yaml import load, Loader 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import time 10 | import config 11 | import pickle 12 | from yaml import load, Loader 13 | from scipy import interpolate 14 | from dijkstra import Dijkstra 15 | 16 | 17 | class Cityscape(object): 18 | OBSTACLES = [[0, 4, 2, 6], # x1, y1, x2, y2 19 | [4, 4, 6, 6], 20 | [8, 4, 10, 6], 21 | ] 22 | 23 | def __init__(self, num_agents, area_size=10, max_steps=12, show=False): 24 | self.num_agents = num_agents 25 | self.area_size = area_size 26 | self.show = show 27 | self.max_steps = max_steps 28 | 29 | def reset(self): 30 | self.init_obstacles() 31 | self.init_broadlines() 32 | self.reset_starting_points() 33 | self.reset_end_points() 34 | self.reset_reference_paths() 35 | 36 | def init_obstacles(self): 37 | self.obstacle_points = [] 38 | for x1, y1, x2, y2 in self.OBSTACLES: 39 | xrnd = np.linspace(x1, x2, int(abs(x1-x2) * 3))[:, np.newaxis] 40 | for r in np.linspace(0, 1, int(abs(y1-y2) * 3)): 41 | yrnd = np.ones_like(xrnd) * (r * y1 + (1-r) * y2) 42 | self.obstacle_points.append(np.concatenate([xrnd, yrnd], axis=1)) 43 | 44 | yrnd = np.linspace(y1, y2, int(abs(y1-y2) * 3))[:, np.newaxis] 45 | for r in np.linspace(0, 1, int(abs(x1-x2) * 3)): 46 | xrnd = np.ones_like(yrnd) * (r * x1 + (1-r) * x2) 47 | self.obstacle_points.append(np.concatenate([xrnd, yrnd], axis=1)) 48 | self.obstacle_points = np.concatenate(self.obstacle_points, axis=0) 49 | 50 | def init_broadlines(self): 51 | area_size = self.area_size 52 | bx, by = [], [] 53 | bx.append(np.linspace(-1, area_size + 1, int(area_size * 2))) 54 | by.append(-np.ones(int(area_size * 2))) 55 | bx.append(np.linspace(-1, area_size + 1, int(area_size * 2))) 56 | by.append(np.ones(int(area_size * 2)) + area_size) 57 | by.append(np.linspace(-1, area_size + 1, int(area_size * 2))) 58 | bx.append(-np.ones(int(area_size * 2))) 59 | by.append(np.linspace(-1, area_size + 1, int(area_size * 2))) 60 | bx.append(np.ones(int(area_size * 2)) + area_size) 61 | bx = np.concatenate(bx, axis=0)[:, np.newaxis] 62 | by = np.concatenate(by, axis=0)[:, np.newaxis] 63 | self.broadline_points = np.concatenate([bx, by], axis=1) 64 | 65 | def reset_starting_points(self): 66 | start_points = np.zeros((self.num_agents, 3), dtype=np.float32) 67 | for i in range(self.num_agents): 68 | random_start = np.random.randint( 69 | low=1, high=self.area_size-1, size=(1, 2)) 70 | dist = min(np.amin(np.linalg.norm( 71 | start_points[:, :2] - random_start, axis=1)), 72 | np.amin(np.linalg.norm( 73 | self.obstacle_points - random_start, axis=1))) 74 | while dist < 1: 75 | random_start = np.random.randint( 76 | low=1, high=self.area_size-1, size=(1, 2)) 77 | dist = min(np.amin(np.linalg.norm( 78 | start_points[:, :2] - random_start, axis=1)), 79 | np.amin(np.linalg.norm( 80 | self.obstacle_points - random_start, axis=1))) 81 | start_points[i, :2] = random_start 82 | self.start_points = start_points 83 | 84 | def reset_end_points(self): 85 | end_points = np.zeros((self.num_agents, 3), dtype=np.float32) 86 | for i in range(self.num_agents): 87 | random_end = np.random.randint( 88 | low=1, high=self.area_size-1, size=(1, 3)) 89 | random_end[0, 2] = min(6, random_end[0, 2]) 90 | dist = min(np.amin(np.linalg.norm( 91 | self.obstacle_points - random_end[0, :2], axis=1)), 92 | np.linalg.norm( 93 | self.start_points[i, :2] - 
random_end[0, :2]) / 5) 94 | while dist < 1: 95 | random_end = np.random.randint( 96 | low=1, high=self.area_size-1, size=(1, 3)) 97 | random_end[0, 2] = min(6, random_end[0, 2]) 98 | dist = min(np.amin(np.linalg.norm( 99 | self.obstacle_points - random_end[0, :2], axis=1)), 100 | np.linalg.norm( 101 | self.start_points[i, :2] - random_end[0, :2]) / 5) 102 | end_points[i] = random_end 103 | self.end_points = end_points 104 | 105 | def reset_reference_paths(self): 106 | max_time = 0 107 | reference_paths = [] 108 | o = np.concatenate( 109 | [self.obstacle_points, self.broadline_points], axis=0) 110 | dijkstra = Dijkstra(o[:, 0], o[:, 1], 1.0, 0.5) 111 | if self.show: 112 | plt.ion() 113 | fig = plt.figure() 114 | for i in range(self.num_agents): 115 | sx, sy, sz = self.start_points[i] 116 | gx, gy, gz = self.end_points[i] 117 | rx, ry = dijkstra.planning(sx, sy, gx, gy) 118 | rx, ry = np.reshape(rx[::-1], (-1, 1)), np.reshape(ry[::-1], (-1, 1)) 119 | rz = np.reshape(np.linspace(sz, gz, rx.shape[0]), (-1, 1)) 120 | path = np.concatenate([rx, ry, rz], axis=1) 121 | reference_paths.append(path) 122 | max_time = max(max_time, rx.shape[0]) 123 | if self.show: 124 | plt.clf() 125 | plt.scatter(o[:, 0], o[:, 1], color='red', alpha=0.2) 126 | plt.scatter(sx, sy, color='orange', s=100) 127 | plt.scatter(gx, gy, color='darkred', s=100) 128 | plt.scatter(rx, ry, color='grey') 129 | fig.canvas.draw() 130 | time.sleep(1) 131 | 132 | waypoints = [] 133 | for i in range(self.num_agents): 134 | path = reference_paths[i] 135 | path_extend = np.zeros(shape=(max_time, 3), dtype=np.float32) 136 | path_extend[:path.shape[0]] = path 137 | path_extend[path.shape[0]:] = path[-1] 138 | waypoints.append(path_extend[np.newaxis]) 139 | waypoints = np.concatenate(waypoints, axis=0) 140 | self.waypoints = np.transpose(waypoints, [1, 0, 2]) 141 | self.max_time = max_time 142 | self.steps = min(self.max_steps, max_time) 143 | 144 | 145 | class Maze(Cityscape): 146 | 147 | def __init__(self, num_agents=64, max_steps=12, show=False): 148 | 149 | scale = np.sqrt(max(1.0, num_agents / 64.0)) 150 | area_size = 20 * scale 151 | 152 | Cityscape.__init__(self, num_agents, area_size, max_steps, show) 153 | self.OBSTACLES = np.array([ 154 | [3, 3, 6, 4], [3, 4, 4, 16], [3, 17, 6, 16], 155 | [8, 7, 12, 8], [8, 13, 12, 12], 156 | [11, 3, 12, 8], [11, 17, 12, 12], 157 | [10, 3, 17, 4], [10, 17, 17, 16], 158 | [14, 9, 16, 11], [16, 7, 17, 13] 159 | ]) * scale 160 | self.external_traj = False 161 | 162 | def write(self, data_path, num_episodes): 163 | start_points = [] 164 | end_points = [] 165 | for _ in range(num_episodes): 166 | self.reset() 167 | start_points.append(self.start_points[np.newaxis]) 168 | end_points.append(self.end_points[np.newaxis]) 169 | write_dict = {} 170 | write_dict['start_points'] = np.concatenate(start_points, axis=0) 171 | write_dict['end_points'] = np.concatenate(end_points, axis=0) 172 | obstacles = np.array(self.OBSTACLES, dtype=np.float32) 173 | obstacles = np.concatenate( 174 | [obstacles, 6 * np.ones_like(obstacles[:, :1])], axis=1) 175 | write_dict['obstacles'] = obstacles 176 | pickle.dump(write_dict, open(data_path, 'wb')) 177 | 178 | def write_trajectory(self, data_path, traj): 179 | write_dict = {} 180 | write_dict['trajectory'] = traj 181 | write_dict['start_points'] = np.array([t[0, :, :3] for t in traj]) 182 | write_dict['end_points'] = np.array([t[-1, :, :3] for t in traj]) 183 | obstacles = np.array(self.OBSTACLES, dtype=np.float32) 184 | obstacles = np.concatenate( 185 | [obstacles, 6 * 
np.ones_like(obstacles[:, :1])], axis=1) 186 | write_dict['obstacles'] = obstacles 187 | pickle.dump(write_dict, open(data_path, 'wb')) 188 | 189 | def read(self, traj_dir): 190 | traj_files = os.listdir(traj_dir) 191 | self.episodes = [] 192 | for traj_file in traj_files: 193 | self.episodes.append( 194 | load(open(os.path.join(traj_dir, traj_file)), Loader=Loader)) 195 | self.external_traj = True 196 | 197 | def reset(self): 198 | if self.external_traj is False: 199 | super().reset() 200 | return 201 | self.init_obstacles() 202 | self.init_broadlines() 203 | episode = np.random.choice(self.episodes) 204 | agents = list(episode.keys())[:self.num_agents] 205 | self.start_points = np.zeros((self.num_agents, 3), dtype=np.float32) 206 | self.end_points = np.zeros((self.num_agents, 3), dtype=np.float32) 207 | max_time = 0 208 | for i, agent in enumerate(agents): 209 | episode_agent = np.array(episode[agent]) 210 | self.start_points[i] = episode_agent[0, 1:] 211 | self.end_points[i] = episode_agent[-1, 1:] 212 | max_time = max(max_time, episode_agent[-1, 0]) 213 | max_time = int(np.rint(max_time)) 214 | 215 | waypoints = [] 216 | for agent in agents: 217 | path = np.array(episode[agent]) 218 | ts, path = path[:, 0], path[:, 1:] 219 | ts_extend = np.arange(max_time) 220 | 221 | xs = np.interp(ts_extend, ts, path[:, 0])[:, np.newaxis] 222 | ys = np.interp(ts_extend, ts, path[:, 1])[:, np.newaxis] 223 | zs = np.interp(ts_extend, ts, path[:, 2])[:, np.newaxis] 224 | 225 | path_extend = np.concatenate([xs, ys, zs], axis=1) 226 | waypoints.append(path_extend[np.newaxis]) 227 | waypoints = np.concatenate(waypoints, axis=0) 228 | self.waypoints = np.transpose(waypoints, [1, 0, 2]) 229 | self.steps = min(self.max_steps, max_time) 230 | 231 | 232 | def quadrotor_dynamics_np(s, u): 233 | dsdt = s.dot(np.array(config.A_MAT).T) + u.dot(np.array(config.B_MAT).T) 234 | return dsdt 235 | 236 | 237 | def quadrotor_dynamics_tf(s, u): 238 | A = tf.constant(np.array(config.A_MAT).T, dtype=tf.float32) 239 | B = tf.constant(np.array(config.B_MAT).T, dtype=tf.float32) 240 | dsdt = tf.matmul(s, A) + tf.matmul(u, B) 241 | return dsdt 242 | 243 | 244 | def quadrotor_controller_np(s, s_ref): 245 | u = (s_ref - s).dot(np.array(config.K_MAT).T) 246 | return u 247 | 248 | 249 | def quadrotor_controller_tf(s, s_ref): 250 | K = tf.constant(np.array(config.K_MAT).T, dtype=tf.float32) 251 | u = tf.matmul(s_ref - s, K) 252 | return u 253 | 254 | 255 | def network_cbf(x, r, indices=None): 256 | """ Control barrier function as a neural network. 257 | Args: 258 | x (N, N, 8): The state difference of N agents. 259 | r (float): The radius of the dangerous zone. 260 | indices (N, K): The indices of K nearest agents of each agent. 261 | Returns: 262 | h (N, K, 1): The CBF of N agents with K neighbouring agents. 263 | mask (N, K, 1): The mask of agents within the observation radius. 264 | indices (N, K): The indices of K nearest agents of each agent. 
265 | """ 266 | d_norm = tf.sqrt( 267 | tf.reduce_sum(tf.square(x[:, :, :3]) + 1e-4, axis=2)) 268 | x = tf.concat([x, 269 | tf.expand_dims(tf.eye(tf.shape(x)[0]), 2), 270 | tf.expand_dims(d_norm - r, 2)], axis=2) 271 | x, indices = remove_distant_agents(x=x, k=config.TOP_K, indices=indices) 272 | dist = tf.sqrt( 273 | tf.reduce_sum(tf.square(x[:, :, :2]) + 1e-4, axis=2, keepdims=True)) 274 | mask = tf.cast(tf.less_equal(dist, config.OBS_RADIUS), tf.float32) 275 | x = tf.contrib.layers.conv1d(inputs=x, 276 | num_outputs=64, 277 | kernel_size=1, 278 | reuse=tf.AUTO_REUSE, 279 | scope='cbf/conv_1', 280 | activation_fn=tf.nn.relu) 281 | x = tf.contrib.layers.conv1d(inputs=x, 282 | num_outputs=128, 283 | kernel_size=1, 284 | reuse=tf.AUTO_REUSE, 285 | scope='cbf/conv_2', 286 | activation_fn=tf.nn.relu) 287 | x = tf.contrib.layers.conv1d(inputs=x, 288 | num_outputs=64, 289 | kernel_size=1, 290 | reuse=tf.AUTO_REUSE, 291 | scope='cbf/conv_3', 292 | activation_fn=tf.nn.relu) 293 | x = tf.contrib.layers.conv1d(inputs=x, 294 | num_outputs=1, 295 | kernel_size=1, 296 | reuse=tf.AUTO_REUSE, 297 | scope='cbf/conv_4', 298 | activation_fn=None) 299 | h = x * mask 300 | return h, mask, indices 301 | 302 | 303 | def network_action(s, s_ref, obs_radius=1.0, indices=None): 304 | """ Controller as a neural network. 305 | Args: 306 | s (N, 8): The current state of N agents. 307 | s_ref (N, 8): The reference location, velocity and acceleration. 308 | obs_radius (float): The observation radius. 309 | indices (N, K): The indices of K nearest agents of each agent. 310 | Returns: 311 | u (N, 2): The control action. 312 | """ 313 | x = tf.expand_dims(s, 1) - tf.expand_dims(s, 0) 314 | x = tf.concat([x, 315 | tf.expand_dims(tf.eye(tf.shape(x)[0]), 2)], axis=2) 316 | x, _ = remove_distant_agents(x=x, k=config.TOP_K, indices=indices) 317 | dist = tf.norm(x[:, :, :3], axis=2, keepdims=True) 318 | mask = tf.cast(tf.less(dist, obs_radius), tf.float32) 319 | x = tf.contrib.layers.conv1d(inputs=x, 320 | num_outputs=64, 321 | kernel_size=1, 322 | reuse=tf.AUTO_REUSE, 323 | scope='action/conv_1', 324 | activation_fn=tf.nn.relu) 325 | x = tf.contrib.layers.conv1d(inputs=x, 326 | num_outputs=128, 327 | kernel_size=1, 328 | reuse=tf.AUTO_REUSE, 329 | scope='action/conv_2', 330 | activation_fn=tf.nn.relu) 331 | x = tf.reduce_max(x * mask, axis=1) 332 | x = tf.concat([x, s - s_ref], axis=1) 333 | x = tf.contrib.layers.fully_connected(inputs=x, 334 | num_outputs=64, 335 | reuse=tf.AUTO_REUSE, 336 | scope='action/fc_1', 337 | activation_fn=tf.nn.relu) 338 | x = tf.contrib.layers.fully_connected(inputs=x, 339 | num_outputs=128, 340 | reuse=tf.AUTO_REUSE, 341 | scope='action/fc_2', 342 | activation_fn=tf.nn.relu) 343 | x = tf.contrib.layers.fully_connected(inputs=x, 344 | num_outputs=64, 345 | reuse=tf.AUTO_REUSE, 346 | scope='action/fc_3', 347 | activation_fn=tf.nn.relu) 348 | x = tf.contrib.layers.fully_connected(inputs=x, 349 | num_outputs=3, 350 | reuse=tf.AUTO_REUSE, 351 | scope='action/fc_4', 352 | activation_fn=None) 353 | u_ref = quadrotor_controller_tf(s, s_ref) 354 | u = x + u_ref 355 | return u 356 | 357 | 358 | def loss_barrier(h, s, indices=None, eps=[5e-2, 1e-3]): 359 | """ Build the loss function for the control barrier functions. 360 | Args: 361 | h (N, N, 1): The control barrier function. 362 | s (N, 8): The current state of N agents. 363 | indices (N, K): The indices of K nearest agents of each agent. 364 | eps (2, ): The margin factors. 365 | Returns: 366 | loss_dang (float): The barrier loss for dangerous states. 
367 |         loss_safe (float): The barrier loss for safe states.
368 |         acc_dang (float): The accuracy of h(dangerous states) <= 0.
369 |         acc_safe (float): The accuracy of h(safe states) >= 0.
370 |     """
371 |     h_reshape = tf.reshape(h, [-1])
372 |     dang_mask = compute_dangerous_mask(
373 |         s, r=config.DIST_MIN_THRES, indices=indices)
374 |     dang_mask_reshape = tf.reshape(dang_mask, [-1])
375 |     safe_mask = compute_safe_mask(s, r=config.DIST_SAFE, indices=indices)
376 |     safe_mask_reshape = tf.reshape(safe_mask, [-1])
377 | 
378 |     dang_h = tf.boolean_mask(h_reshape, dang_mask_reshape)
379 |     safe_h = tf.boolean_mask(h_reshape, safe_mask_reshape)
380 | 
381 |     num_dang = tf.cast(tf.shape(dang_h)[0], tf.float32)
382 |     num_safe = tf.cast(tf.shape(safe_h)[0], tf.float32)
383 | 
384 |     loss_dang = tf.reduce_sum(
385 |         tf.math.maximum(dang_h + eps[0], 0)) / (1e-5 + num_dang)
386 |     loss_safe = tf.reduce_sum(
387 |         tf.math.maximum(-safe_h + eps[1], 0)) / (1e-5 + num_safe)
388 | 
389 |     acc_dang = tf.reduce_sum(tf.cast(
390 |         tf.less_equal(dang_h, 0), tf.float32)) / (1e-5 + num_dang)
391 |     acc_safe = tf.reduce_sum(tf.cast(
392 |         tf.greater_equal(safe_h, 0), tf.float32)) / (1e-5 + num_safe)
393 | 
394 |     acc_dang = tf.cond(
395 |         tf.greater(num_dang, 0), lambda: acc_dang, lambda: -tf.constant(1.0))
396 |     acc_safe = tf.cond(
397 |         tf.greater(num_safe, 0), lambda: acc_safe, lambda: -tf.constant(1.0))
398 | 
399 |     return loss_dang, loss_safe, acc_dang, acc_safe
400 | 
401 | 
402 | def loss_derivatives(s, u, h, x, indices=None, eps=[8e-2, 0, 3e-2]):
403 |     """ Build the loss function for the derivatives of the CBF.
404 |     Args:
405 |         s (N, 8): The current state of N agents.
406 |         u (N, 3): The control action.
407 |         h (N, N, 1): The control barrier function.
408 |         x (N, N, 8): The state difference of N agents.
409 |         indices (N, K): The indices of K nearest agents of each agent.
410 |         eps (3, ): The margin factors.
411 |     Returns:
412 |         loss_dang_deriv (float): The derivative loss of dangerous states.
413 |         loss_safe_deriv (float): The derivative loss of safe states.
414 |         loss_medium_deriv (float): The derivative loss of medium states.
415 |         acc_dang_deriv (float): The derivative accuracy of dangerous states.
416 |         acc_safe_deriv (float): The derivative accuracy of safe states.
417 |         acc_medium_deriv (float): The derivative accuracy of medium states.
418 | 
418 | """ 419 | dsdt = quadrotor_dynamics_tf(s, u) 420 | s_next = s + dsdt * config.TIME_STEP 421 | 422 | x_next = tf.expand_dims(s_next, 1) - tf.expand_dims(s_next, 0) 423 | h_next, mask_next, _ = network_cbf( 424 | x=x_next, r=config.DIST_MIN_THRES, indices=indices) 425 | 426 | deriv = h_next - h + config.TIME_STEP * config.ALPHA_CBF * h 427 | 428 | deriv_reshape = tf.reshape(deriv, [-1]) 429 | dang_mask = compute_dangerous_mask( 430 | s, r=config.DIST_MIN_THRES, indices=indices) 431 | dang_mask_reshape = tf.reshape(dang_mask, [-1]) 432 | safe_mask = compute_safe_mask(s, r=config.DIST_SAFE, indices=indices) 433 | safe_mask_reshape = tf.reshape(safe_mask, [-1]) 434 | medium_mask_reshape = tf.logical_not( 435 | tf.logical_or(dang_mask_reshape, safe_mask_reshape)) 436 | 437 | dang_deriv = tf.boolean_mask(deriv_reshape, dang_mask_reshape) 438 | safe_deriv = tf.boolean_mask(deriv_reshape, safe_mask_reshape) 439 | medium_deriv = tf.boolean_mask(deriv_reshape, medium_mask_reshape) 440 | 441 | num_dang = tf.cast(tf.shape(dang_deriv)[0], tf.float32) 442 | num_safe = tf.cast(tf.shape(safe_deriv)[0], tf.float32) 443 | num_medium = tf.cast(tf.shape(medium_deriv)[0], tf.float32) 444 | 445 | loss_dang_deriv = tf.reduce_sum( 446 | tf.math.maximum(-dang_deriv + eps[0], 0)) / (1e-5 + num_dang) 447 | loss_safe_deriv = tf.reduce_sum( 448 | tf.math.maximum(-safe_deriv + eps[1], 0)) / (1e-5 + num_safe) 449 | loss_medium_deriv = tf.reduce_sum( 450 | tf.math.maximum(-medium_deriv + eps[2], 0)) / (1e-5 + num_medium) 451 | 452 | acc_dang_deriv = tf.reduce_sum(tf.cast( 453 | tf.greater_equal(dang_deriv, 0), tf.float32)) / (1e-5 + num_dang) 454 | acc_safe_deriv = tf.reduce_sum(tf.cast( 455 | tf.greater_equal(safe_deriv, 0), tf.float32)) / (1e-5 + num_safe) 456 | acc_medium_deriv = tf.reduce_sum(tf.cast( 457 | tf.greater_equal(medium_deriv, 0), tf.float32)) / (1e-5 + num_medium) 458 | 459 | acc_dang_deriv = tf.cond( 460 | tf.greater(num_dang, 0), lambda: acc_dang_deriv, lambda: -tf.constant(1.0)) 461 | acc_safe_deriv = tf.cond( 462 | tf.greater(num_safe, 0), lambda: acc_safe_deriv, lambda: -tf.constant(1.0)) 463 | acc_medium_deriv = tf.cond( 464 | tf.greater(num_medium, 0), lambda: acc_medium_deriv, lambda: -tf.constant(1.0)) 465 | 466 | return (loss_dang_deriv, loss_safe_deriv, loss_medium_deriv, 467 | acc_dang_deriv, acc_safe_deriv, acc_medium_deriv) 468 | 469 | 470 | def loss_actions(s, u, s_ref, indices): 471 | """ Build the loss function for control actions. 472 | Args: 473 | s (N, 8): The current state of N agents. 474 | u (N, 2): The control action. 475 | z_ref (N, 6): The reference trajectory. 476 | indices (N, K): The indices of K nearest agents of each agent. 477 | Returns: 478 | loss (float): The loss function for control actions. 479 | """ 480 | u_ref = quadrotor_controller_tf(s, s_ref) 481 | loss = tf.minimum(tf.abs(u - u_ref), (u - u_ref)**2) 482 | safe_mask = compute_safe_mask(s, config.DIST_SAFE, indices) 483 | safe_mask = tf.reduce_mean(tf.cast(safe_mask, tf.float32), axis=1) 484 | safe_mask = tf.cast(tf.equal(safe_mask, 1), tf.float32) 485 | loss = tf.reduce_sum(loss * safe_mask) / (1e-4 + tf.reduce_sum(safe_mask)) 486 | return loss 487 | 488 | 489 | def compute_dangerous_mask(s, r, indices=None): 490 | """ Identify the agents within the dangerous radius. 491 | Args: 492 | s (N, 8): The current state of N agents. 493 | r (float): The dangerous radius. 494 | indices (N, K): The indices of K nearest agents of each agent. 
495 | Returns:
496 | mask (N, K): 1 for agents inside the dangerous radius and 0 otherwise.
497 | """
498 | s_diff = tf.expand_dims(s, 1) - tf.expand_dims(s, 0)
499 | s_diff = tf.concat(
500 | [s_diff, tf.expand_dims(tf.eye(tf.shape(s)[0]), 2)], axis=2)
501 | s_diff, _ = remove_distant_agents(s_diff, config.TOP_K, indices)
502 | z_diff, eye = s_diff[:, :, :3], s_diff[:, :, -1:]
503 | z_diff = tf.norm(z_diff, axis=2, keepdims=True)
504 | mask = tf.logical_and(tf.less(z_diff, r), tf.equal(eye, 0))  # the identity channel excludes an agent from its own mask
505 | return mask
506 | 
507 | 
508 | def compute_safe_mask(s, r, indices=None):
509 | """ Identify the agents outside the safe radius.
510 | Args:
511 | s (N, 8): The current state of N agents.
512 | r (float): The safe radius.
513 | indices (N, K): The indices of K nearest agents of each agent.
514 | Returns:
515 | mask (N, K): 1 for agents outside the safe radius and 0 otherwise.
516 | """
517 | s_diff = tf.expand_dims(s, 1) - tf.expand_dims(s, 0)
518 | s_diff = tf.concat(
519 | [s_diff, tf.expand_dims(tf.eye(tf.shape(s)[0]), 2)], axis=2)
520 | s_diff, _ = remove_distant_agents(s_diff, config.TOP_K, indices)
521 | z_diff, eye = s_diff[:, :, :3], s_diff[:, :, -1:]
522 | z_diff = tf.norm(z_diff, axis=2, keepdims=True)
523 | mask = tf.logical_or(tf.greater(z_diff, r), tf.equal(eye, 1))  # an agent is always "safe" with respect to itself
524 | return mask
525 | 
526 | 
527 | def dangerous_mask_np(s, r):
528 | """ Identify the agents within the dangerous radius.
529 | Args:
530 | s (N, 8): The current state of N agents.
531 | r (float): The dangerous radius.
532 | Returns:
533 | mask (N, N): 1 for agents inside the dangerous radius and 0 otherwise.
534 | """
535 | s_diff = np.expand_dims(s, 1) - np.expand_dims(s, 0)
536 | s_diff = np.linalg.norm(s_diff[:, :, :3], axis=2, keepdims=False)
537 | eye = np.eye(s_diff.shape[0])
538 | mask = np.logical_and(s_diff < r, eye == 0)
539 | return mask
540 | 
541 | 
542 | def remove_distant_agents(x, k, indices=None):
543 | """ Remove the distant agents.
544 | Args:
545 | x (N, N, C): The state difference of N agents.
546 | k (int): The K nearest agents to keep. Precomputed indices are reused when given.
547 | Returns:
548 | x (N, K, C): The K nearest agents.
549 | indices (N, K): The indices of K nearest agents of each agent.
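Note:
    When indices is None, the K nearest neighbors are selected with
    tf.nn.top_k on the negated pairwise distances, and (row, column)
    index pairs are assembled for tf.gather_nd; passing the returned
    indices back in avoids recomputing the neighbor selection.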
550 | """ 551 | n, _, c = x.get_shape().as_list() 552 | if n <= k: 553 | return x, False 554 | d_norm = tf.sqrt(tf.reduce_sum(tf.square(x[:, :, :3]) + 1e-6, axis=2)) 555 | if indices is not None: 556 | x = tf.reshape(tf.gather_nd(x, indices), [n, k, c]) 557 | return x, indices 558 | _, indices = tf.nn.top_k(-d_norm, k=k) 559 | row_indices = tf.expand_dims( 560 | tf.range(tf.shape(indices)[0]), 1) * tf.ones_like(indices) 561 | row_indices = tf.reshape(row_indices, [-1, 1]) 562 | column_indices = tf.reshape(indices, [-1, 1]) 563 | indices = tf.concat([row_indices, column_indices], axis=1) 564 | x = tf.reshape(tf.gather_nd(x, indices), [n, k, c]) 565 | return x, indices 566 | -------------------------------------------------------------------------------- /drones/dijkstra.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Grid based Dijkstra planning 4 | 5 | author: Atsushi Sakai(@Atsushi_twi) 6 | 7 | """ 8 | import sys 9 | sys.dont_write_bytecode = True 10 | 11 | import matplotlib.pyplot as plt 12 | import math 13 | 14 | show_animation = False 15 | 16 | 17 | class Dijkstra: 18 | 19 | def __init__(self, ox, oy, resolution, robot_radius): 20 | """ 21 | Initialize map for a star planning 22 | 23 | ox: x position list of Obstacles [m] 24 | oy: y position list of Obstacles [m] 25 | resolution: grid resolution [m] 26 | rr: robot radius[m] 27 | """ 28 | 29 | self.min_x = None 30 | self.min_y = None 31 | self.max_x = None 32 | self.max_y = None 33 | self.x_width = None 34 | self.y_width = None 35 | self.obstacle_map = None 36 | 37 | self.resolution = resolution 38 | self.robot_radius = robot_radius 39 | self.calc_obstacle_map(ox, oy) 40 | self.motion = self.get_motion_model() 41 | 42 | class Node: 43 | def __init__(self, x, y, cost, parent_index): 44 | self.x = x # index of grid 45 | self.y = y # index of grid 46 | self.cost = cost 47 | self.parent_index = parent_index # index of previous Node 48 | 49 | def __str__(self): 50 | return str(self.x) + "," + str(self.y) + "," + str( 51 | self.cost) + "," + str(self.parent_index) 52 | 53 | def planning(self, sx, sy, gx, gy): 54 | """ 55 | dijkstra path search 56 | 57 | input: 58 | s_x: start x position [m] 59 | s_y: start y position [m] 60 | gx: goal x position [m] 61 | gx: goal x position [m] 62 | 63 | output: 64 | rx: x position list of the final path 65 | ry: y position list of the final path 66 | """ 67 | 68 | start_node = self.Node(self.calc_xy_index(sx, self.min_x), 69 | self.calc_xy_index(sy, self.min_y), 0.0, -1) 70 | goal_node = self.Node(self.calc_xy_index(gx, self.min_x), 71 | self.calc_xy_index(gy, self.min_y), 0.0, -1) 72 | 73 | open_set, closed_set = dict(), dict() 74 | open_set[self.calc_index(start_node)] = start_node 75 | 76 | while True: 77 | c_id = min(open_set, key=lambda o: open_set[o].cost) 78 | current = open_set[c_id] 79 | 80 | # show graph 81 | if show_animation: # pragma: no cover 82 | plt.plot(self.calc_position(current.x, self.min_x), 83 | self.calc_position(current.y, self.min_y), "xc") 84 | # for stopping simulation with the esc key. 
85 | plt.gcf().canvas.mpl_connect( 86 | 'key_release_event', 87 | lambda event: [exit(0) if event.key == 'escape' else None]) 88 | if len(closed_set.keys()) % 10 == 0: 89 | plt.pause(0.001) 90 | 91 | if current.x == goal_node.x and current.y == goal_node.y: 92 | goal_node.parent_index = current.parent_index 93 | goal_node.cost = current.cost 94 | break 95 | 96 | # Remove the item from the open set 97 | del open_set[c_id] 98 | 99 | # Add it to the closed set 100 | closed_set[c_id] = current 101 | 102 | # expand search grid based on motion model 103 | for move_x, move_y, move_cost in self.motion: 104 | node = self.Node(current.x + move_x, 105 | current.y + move_y, 106 | current.cost + move_cost, c_id) 107 | n_id = self.calc_index(node) 108 | 109 | if n_id in closed_set: 110 | continue 111 | 112 | if not self.verify_node(node): 113 | continue 114 | 115 | if n_id not in open_set: 116 | open_set[n_id] = node # Discover a new node 117 | else: 118 | if open_set[n_id].cost >= node.cost: 119 | # This path is the best until now. record it! 120 | open_set[n_id] = node 121 | 122 | rx, ry = self.calc_final_path(goal_node, closed_set) 123 | 124 | return rx, ry 125 | 126 | def calc_final_path(self, goal_node, closed_set): 127 | # generate final course 128 | rx, ry = [self.calc_position(goal_node.x, self.min_x)], [ 129 | self.calc_position(goal_node.y, self.min_y)] 130 | parent_index = goal_node.parent_index 131 | while parent_index != -1: 132 | n = closed_set[parent_index] 133 | rx.append(self.calc_position(n.x, self.min_x)) 134 | ry.append(self.calc_position(n.y, self.min_y)) 135 | parent_index = n.parent_index 136 | 137 | return rx, ry 138 | 139 | def calc_position(self, index, minp): 140 | pos = index * self.resolution + minp 141 | return pos 142 | 143 | def calc_xy_index(self, position, minp): 144 | return round((position - minp) / self.resolution) 145 | 146 | def calc_index(self, node): 147 | return (node.y - self.min_y) * self.x_width + (node.x - self.min_x) 148 | 149 | def verify_node(self, node): 150 | px = self.calc_position(node.x, self.min_x) 151 | py = self.calc_position(node.y, self.min_y) 152 | 153 | if px < self.min_x: 154 | return False 155 | if py < self.min_y: 156 | return False 157 | if px >= self.max_x: 158 | return False 159 | if py >= self.max_y: 160 | return False 161 | 162 | if self.obstacle_map[int(node.x)][int(node.y)]: 163 | return False 164 | 165 | return True 166 | 167 | def calc_obstacle_map(self, ox, oy): 168 | 169 | self.min_x = round(min(ox)) 170 | self.min_y = round(min(oy)) 171 | self.max_x = round(max(ox)) 172 | self.max_y = round(max(oy)) 173 | 174 | self.x_width = int(round((self.max_x - self.min_x) / self.resolution)) 175 | self.y_width = int(round((self.max_y - self.min_y) / self.resolution)) 176 | 177 | # obstacle map generation 178 | self.obstacle_map = [[False for _ in range(self.y_width)] 179 | for _ in range(self.x_width)] 180 | for ix in range(self.x_width): 181 | x = self.calc_position(ix, self.min_x) 182 | for iy in range(self.y_width): 183 | y = self.calc_position(iy, self.min_y) 184 | for iox, ioy in zip(ox, oy): 185 | d = math.hypot(iox - x, ioy - y) 186 | if d <= self.robot_radius: 187 | self.obstacle_map[ix][iy] = True 188 | break 189 | 190 | @staticmethod 191 | def get_motion_model(): 192 | # dx, dy, cost 193 | motion = [[1, 0, 1], 194 | [0, 1, 1], 195 | [-1, 0, 1], 196 | [0, -1, 1], 197 | [-1, -1, math.sqrt(2)], 198 | [-1, 1, math.sqrt(2)], 199 | [1, -1, math.sqrt(2)], 200 | [1, 1, math.sqrt(2)]] 201 | 202 | return motion 203 | 204 | 205 | def 
main(): 206 | print(__file__ + " start!!") 207 | 208 | # start and goal position 209 | sx = -5.0 # [m] 210 | sy = -5.0 # [m] 211 | gx = 50.0 # [m] 212 | gy = 50.0 # [m] 213 | grid_size = 5.0 # [m] 214 | robot_radius = 1.0 # [m] 215 | 216 | # set obstacle positions 217 | ox, oy = [], [] 218 | for i in range(-10, 60): 219 | ox.append(i) 220 | oy.append(-10.0) 221 | for i in range(-10, 60): 222 | ox.append(60.0) 223 | oy.append(i) 224 | for i in range(-10, 61): 225 | ox.append(i) 226 | oy.append(60.0) 227 | for i in range(-10, 61): 228 | ox.append(-10.0) 229 | oy.append(i) 230 | for i in range(-10, 40): 231 | ox.append(20.0) 232 | oy.append(i) 233 | for i in range(0, 40): 234 | ox.append(40.0) 235 | oy.append(60.0 - i) 236 | 237 | if show_animation: # pragma: no cover 238 | plt.plot(ox, oy, ".k") 239 | plt.plot(sx, sy, "og") 240 | plt.plot(gx, gy, "xb") 241 | plt.grid(True) 242 | plt.axis("equal") 243 | 244 | dijkstra = Dijkstra(ox, oy, grid_size, robot_radius) 245 | rx, ry = dijkstra.planning(sx, sy, gx, gy) 246 | 247 | if show_animation: # pragma: no cover 248 | plt.plot(rx, ry, "-r") 249 | plt.pause(0.01) 250 | plt.show() 251 | 252 | 253 | if __name__ == '__main__': 254 | main() 255 | -------------------------------------------------------------------------------- /drones/evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.dont_write_bytecode = True 3 | 4 | import os 5 | import time 6 | import argparse 7 | import numpy as np 8 | import tensorflow as tf 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.mplot3d import Axes3D 11 | import pickle 12 | 13 | import core 14 | import config 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--num_agents', type=int, default=64) 20 | parser.add_argument('--max_steps', type=int, default=12) 21 | parser.add_argument('--model_path', type=str, default=None) 22 | parser.add_argument('--vis', type=int, default=0) 23 | parser.add_argument('--gpu', type=str, default='0') 24 | parser.add_argument('--ref', type=str, default=None) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def build_evaluation_graph(num_agents): 30 | # s is the state vectors of the agents 31 | s = tf.placeholder(tf.float32, [num_agents, 8]) 32 | # s_ref is the goal states 33 | s_ref = tf.placeholder(tf.float32, [num_agents, 8]) 34 | # x is difference between the state of each agent and other agents 35 | x = tf.expand_dims(s, 1) - tf.expand_dims(s, 0) 36 | # h is the CBF value of shape [num_agents, TOP_K, 1], where TOP_K represents 37 | # the K nearest agents 38 | h, mask, indices = core.network_cbf( 39 | x=x, r=config.DIST_MIN_THRES, indices=None) 40 | # u is the control action of each agent, with shape [num_agents, 3] 41 | u = core.network_action( 42 | s=s, s_ref=s_ref, obs_radius=config.OBS_RADIUS, indices=indices) 43 | safe_mask = core.compute_safe_mask(s, r=config.DIST_SAFE, indices=indices) 44 | # check if each agent is safe 45 | is_safe = tf.equal(tf.reduce_mean(tf.cast(safe_mask, tf.float32)), 1) 46 | 47 | # u_res is delta u. 
when u does not satisfy the CBF conditions, we want to compute
48 | # a u_res such that u + u_res satisfies the CBF conditions
49 | u_res = tf.Variable(tf.zeros_like(u), name='u_res')
50 | loop_count = tf.Variable(0, name='loop_count')
51 | 
52 | def opt_body(u_res, loop_count, is_safe):
53 | # a loop of updating u_res
54 | # compute s_next under u + u_res
55 | dsdt = core.quadrotor_dynamics_tf(s, u + u_res)
56 | s_next = s + dsdt * config.TIME_STEP_EVAL
57 | x_next = tf.expand_dims(s_next, 1) - tf.expand_dims(s_next, 0)
58 | h_next, mask_next, _ = core.network_cbf(
59 | x=x_next, r=config.DIST_MIN_THRES, indices=indices)
60 | # deriv should be >= 0. If not, we update u_res by gradient descent
61 | deriv = h_next - h + config.TIME_STEP_EVAL * config.ALPHA_CBF * h
62 | deriv = deriv * mask * mask_next
63 | error = tf.reduce_sum(tf.math.maximum(-deriv, 0), axis=1)
64 | # compute the gradient to update u_res
65 | error_gradient = tf.gradients(error, u_res)[0]
66 | u_res = u_res - config.REFINE_LEARNING_RATE * error_gradient
67 | loop_count = loop_count + 1
68 | return u_res, loop_count, is_safe
69 | 
70 | def opt_cond(u_res, loop_count, is_safe):
71 | # update u_res for at most REFINE_LOOPS iterations, and only while unsafe
72 | cond = tf.logical_and(
73 | tf.less(loop_count, config.REFINE_LOOPS),
74 | tf.logical_not(is_safe))
75 | return cond
76 | 
77 | with tf.control_dependencies([
78 | u_res.assign(tf.zeros_like(u)), loop_count.assign(0)]):
79 | u_res, _, _ = tf.while_loop(opt_cond, opt_body, [u_res, loop_count, is_safe])
80 | u_opt = u + u_res
81 | 
82 | # compute the value of loss functions and the accuracies
83 | # loss_dang is for h(s) < 0, s in dangerous set
84 | # loss_safe is for h(s) >= 0, s in safe set
85 | # acc_dang is the accuracy that h(s) < 0, s in dangerous set is satisfied
86 | # acc_safe is the accuracy that h(s) >= 0, s in safe set is satisfied
87 | loss_dang, loss_safe, acc_dang, acc_safe = core.loss_barrier(
88 | h=h, s=s, indices=indices)
89 | # loss_dang_deriv is for dh/dt(s) + alpha * h(s) >= 0 for s in dangerous set
90 | # loss_safe_deriv is for dh/dt(s) + alpha * h(s) >= 0 for s in safe set
91 | # loss_medium_deriv is for dh/dt(s) + alpha * h(s) >= 0 for s not in the dangerous
92 | # or the safe set
93 | (loss_dang_deriv, loss_safe_deriv, loss_medium_deriv, acc_dang_deriv,
94 | acc_safe_deriv, acc_medium_deriv) = core.loss_derivatives(
95 | s=s, u=u_opt, h=h, x=x, indices=indices)
96 | # the distance between u_opt and the nominal LQR control
97 | loss_action = core.loss_actions(s=s, u=u_opt, s_ref=s_ref, indices=indices)
98 | 
99 | loss_list = [loss_dang, loss_safe, loss_dang_deriv,
100 | loss_safe_deriv, loss_medium_deriv, loss_action]
101 | acc_list = [acc_dang, acc_safe, acc_dang_deriv, acc_safe_deriv, acc_medium_deriv]
102 | 
103 | return s, s_ref, u_opt, loss_list, acc_list
104 | 
105 | 
106 | def print_accuracy(accuracy_lists):
107 | acc = np.array(accuracy_lists)
108 | acc_list = []
109 | for i in range(acc.shape[1]):
110 | acc_i = acc[:, i]
111 | acc_list.append(np.mean(acc_i[acc_i > 0]))
112 | print('Accuracy: {}'.format(acc_list))
113 | 
114 | 
115 | def render_init(num_agents):
116 | fig = plt.figure(figsize=(10, 7))
117 | return fig
118 | 
119 | 
120 | def show_obstacles(obs, ax, z=[0, 6], alpha=0.6, color='deepskyblue'):
121 | for x1, y1, x2, y2 in obs:
122 | xs, ys = np.meshgrid([x1, x2], [y1, y2])
123 | zs = np.ones_like(xs)
124 | ax.plot_surface(xs, ys, zs * z[0], alpha=alpha, color=color)  # bottom face
125 | ax.plot_surface(xs, ys, zs * z[1], alpha=alpha, color=color)  # top face
126 | 
127 | xs, zs = np.meshgrid([x1, x2], z)
128 | ys = np.ones_like(xs)
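# the four vertical side faces of the box: two at y = y1 and y = y2, then two at x = x1 and x = x2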
129 | ax.plot_surface(xs, ys * y1, zs, alpha=alpha, color=color)
130 | ax.plot_surface(xs, ys * y2, zs, alpha=alpha, color=color)
131 | 
132 | ys, zs = np.meshgrid([y1, y2], z)
133 | xs = np.ones_like(ys)
134 | ax.plot_surface(xs * x1, ys, zs, alpha=alpha, color=color)
135 | ax.plot_surface(xs * x2, ys, zs, alpha=alpha, color=color)
136 | 
137 | 
138 | def clip_norm(x, thres):
139 | norm = np.linalg.norm(x, axis=1, keepdims=True)
140 | mask = (norm > thres).astype(np.float32)
141 | x = x * (1 - mask) + x * mask * thres / (1e-6 + norm)  # rescale rows whose norm exceeds thres down to norm thres
142 | return x
143 | 
144 | 
145 | def clip_state(s, x_thres, v_thres=0.1, h_thres=6):
146 | x, v, r = s[:, :3], s[:, 3:6], s[:, 6:]
147 | x = np.concatenate([np.clip(x[:, :2], 0, x_thres),
148 | np.clip(x[:, 2:], 0, h_thres)], axis=1)
149 | v = clip_norm(v, v_thres)
150 | s = np.concatenate([x, v, r], axis=1)
151 | return s
152 | 
153 | 
154 | def main():
155 | args = parse_args()
156 | s, s_ref, u, loss_list, acc_list = build_evaluation_graph(args.num_agents)
157 | # load the pretrained weights
158 | all_vars = tf.trainable_variables()
159 | vars_restore = []
160 | for v in all_vars:
161 | if 'action' in v.name or 'cbf' in v.name:
162 | vars_restore.append(v)
163 | # initialize the tensorflow Session
164 | sess = tf.Session()
165 | sess.run(tf.global_variables_initializer())
166 | saver = tf.train.Saver(var_list=vars_restore)
167 | saver.restore(sess, args.model_path)
168 | 
169 | safety_ratios_epoch = []
170 | safety_ratios_epoch_baseline = []
171 | 
172 | dist_errors = []
173 | dist_errors_baseline = []
174 | accuracy_lists = []
175 | 
176 | if args.vis > 0:
177 | plt.ion()
178 | plt.close()
179 | fig = render_init(args.num_agents)
180 | # initialize the environment
181 | scene = core.Maze(args.num_agents, max_steps=args.max_steps)
182 | if args.ref is not None:
183 | scene.read(args.ref)
184 | 
185 | if not os.path.exists('trajectory'):
186 | os.mkdir('trajectory')
187 | traj_dict = {'ours': [], 'baseline': [], 'obstacles': [np.array(scene.OBSTACLES)]}
188 | 
189 | 
190 | safety_reward = []
191 | dist_reward = []
192 | 
193 | for istep in range(config.EVALUATE_STEPS):
194 | if args.vis > 0:
195 | plt.clf()
196 | ax_1 = fig.add_subplot(121, projection='3d')
197 | ax_2 = fig.add_subplot(122, projection='3d')
198 | safety_ours = []
199 | safety_baseline = []
200 | 
201 | scene.reset()
202 | start_time = time.time()
203 | s_np = np.concatenate(
204 | [scene.start_points, np.zeros((args.num_agents, 5))], axis=1)
205 | s_traj = []
206 | safety_info = np.zeros(args.num_agents, dtype=np.float32)
207 | # a scene has a sequence of goal states for each agent; in each scene.step,
208 | # we move to a new goal state
209 | for t in range(scene.steps):
210 | # the goal states
211 | s_ref_np = np.concatenate(
212 | [scene.waypoints[t], np.zeros((args.num_agents, 5))], axis=1)
213 | # run INNER_LOOPS_EVAL steps to reach the goal state
214 | for i in range(config.INNER_LOOPS_EVAL):
215 | u_np, acc_list_np = sess.run(
216 | [u, acc_list], feed_dict={s: s_np, s_ref: s_ref_np})
217 | if args.vis == 1:
218 | u_ref_np = core.quadrotor_controller_np(s_np, s_ref_np)
219 | u_np = clip_norm(u_np - u_ref_np, 100.0) + u_ref_np  # bound the deviation from the nominal control
220 | dsdt = core.quadrotor_dynamics_np(s_np, u_np)
221 | s_np = s_np + dsdt * config.TIME_STEP_EVAL
222 | safety_ratio = 1 - np.mean(
223 | core.dangerous_mask_np(s_np, config.DIST_MIN_CHECK), axis=1)
224 | individual_safety = safety_ratio == 1
225 | safety_ours.append(individual_safety)
226 | safety_info = safety_info + individual_safety - 1
227 | safety_ratio = np.mean(individual_safety)
228 | safety_ratios_epoch.append(safety_ratio)
229 | accuracy_lists.append(acc_list_np)
230 | if np.mean(
231 | np.linalg.norm(s_np[:, :3] - s_ref_np[:, :3], axis=1)
232 | ) < config.DIST_TOLERATE:
233 | break
234 | 
235 | s_traj.append(np.expand_dims(s_np[:, [0, 1, 2, 6, 7]], axis=0))
236 | safety_reward.append(np.mean(safety_info))
237 | dist_reward.append(np.mean((np.linalg.norm(
238 | s_np[:, :3] - s_ref_np[:, :3], axis=1) < 1.5).astype(np.float32) * 10))
239 | dist_errors.append(
240 | np.mean(np.linalg.norm(s_np[:, :3] - s_ref_np[:, :3], axis=1)))
241 | traj_dict['ours'].append(np.concatenate(s_traj, axis=0))
242 | end_time = time.time()
243 | 
244 | # reach the same goals using an LQR controller without collision avoidance (baseline)
245 | s_np = np.concatenate(
246 | [scene.start_points, np.zeros((args.num_agents, 5))], axis=1)
247 | s_traj = []
248 | for t in range(scene.steps):
249 | s_ref_np = np.concatenate(
250 | [scene.waypoints[t], np.zeros((args.num_agents, 5))], axis=1)
251 | for i in range(config.INNER_LOOPS_EVAL):
252 | u_np = core.quadrotor_controller_np(s_np, s_ref_np)
253 | dsdt = core.quadrotor_dynamics_np(s_np, u_np)
254 | s_np = s_np + dsdt * config.TIME_STEP_EVAL
255 | safety_ratio = 1 - np.mean(
256 | core.dangerous_mask_np(s_np, config.DIST_MIN_CHECK), axis=1)
257 | individual_safety = safety_ratio == 1
258 | safety_baseline.append(individual_safety)
259 | safety_ratio = np.mean(individual_safety)
260 | safety_ratios_epoch_baseline.append(safety_ratio)
261 | s_traj.append(np.expand_dims(s_np[:, [0, 1, 2, 6, 7]], axis=0))
262 | dist_errors_baseline.append(np.mean(np.linalg.norm(s_np[:, :3] - s_ref_np[:, :3], axis=1)))
263 | traj_dict['baseline'].append(np.concatenate(s_traj, axis=0))
264 | 
265 | if args.vis > 0:
266 | # visualize the trajectories
267 | s_traj_ours = traj_dict['ours'][-1]
268 | s_traj_baseline = traj_dict['baseline'][-1]
269 | 
270 | for j in range(0, max(s_traj_ours.shape[0], s_traj_baseline.shape[0]), 10):
271 | ax_1.clear()
272 | ax_1.view_init(elev=80, azim=-45)
273 | ax_1.axis('off')
274 | show_obstacles(scene.OBSTACLES, ax_1)
275 | j_ours = min(j, s_traj_ours.shape[0]-1)
276 | s_np = s_traj_ours[j_ours]
277 | safety = safety_ours[j_ours]
278 | 
279 | ax_1.set_xlim(0, 20)
280 | ax_1.set_ylim(0, 20)
281 | ax_1.set_zlim(0, 10)
282 | ax_1.scatter(s_np[:, 0], s_np[:, 1], s_np[:, 2],
283 | color='darkorange', label='Agent')
284 | ax_1.scatter(s_np[safety<1, 0], s_np[safety<1, 1], s_np[safety<1, 2],
285 | color='red', label='Collision')
286 | ax_1.set_title('Ours: Safety Rate = {:.4f}'.format(
287 | np.mean(safety_ratios_epoch)), fontsize=16)
288 | 
289 | ax_2.clear()
290 | ax_2.view_init(elev=80, azim=-45)
291 | ax_2.axis('off')
292 | show_obstacles(scene.OBSTACLES, ax_2)
293 | j_baseline = min(j, s_traj_baseline.shape[0]-1)
294 | s_np = s_traj_baseline[j_baseline]
295 | safety = safety_baseline[j_baseline]
296 | 
297 | ax_2.set_xlim(0, 20)
298 | ax_2.set_ylim(0, 20)
299 | ax_2.set_zlim(0, 10)
300 | ax_2.scatter(s_np[:, 0], s_np[:, 1], s_np[:, 2],
301 | color='darkorange', label='Agent')
302 | ax_2.scatter(s_np[safety<1, 0], s_np[safety<1, 1], s_np[safety<1, 2],
303 | color='red', label='Collision')
304 | ax_2.set_title('LQR: Safety Rate = {:.4f}'.format(
305 | np.mean(safety_ratios_epoch_baseline)), fontsize=16)
306 | plt.legend(loc='lower right')
307 | 
308 | fig.canvas.draw()
309 | plt.pause(0.001)
310 | 
311 | 
312 | 
313 | print('Evaluation Step: {} | {}, Time: {:.4f}'.format(
314 | istep + 1, config.EVALUATE_STEPS, end_time - start_time))
315 | 
316 | print_accuracy(accuracy_lists)
317 | print('Distance Error (Learning | Baseline): {:.4f} | {:.4f}'.format(
318 | np.mean(dist_errors), np.mean(dist_errors_baseline)))
319 | print('Mean Safety Ratio (Learning | Baseline): {:.4f} | {:.4f}'.format(
320 | np.mean(safety_ratios_epoch), np.mean(safety_ratios_epoch_baseline)))
321 | 
322 | safety_reward = np.mean(safety_reward)
323 | dist_reward = np.mean(dist_reward)
324 | print('Safety Reward: {:.4f}, Dist Reward: {:.4f}, Reward: {:.4f}'.format(
325 | safety_reward, dist_reward, 9 + 0.1 * (safety_reward + dist_reward)))
326 | 
327 | pickle.dump(traj_dict, open('trajectory/traj_eval.pkl', 'wb'))
328 | scene.write_trajectory('trajectory/env_traj_eval.pkl', traj_dict['ours'])
329 | 
330 | 
331 | if __name__ == '__main__':
332 | main()
333 | 
--------------------------------------------------------------------------------
/drones/models/model_save.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/macbf/7e364b7a65801debaccc2f60673323385e6f05d3/drones/models/model_save.data-00000-of-00001
--------------------------------------------------------------------------------
/drones/models/model_save.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/macbf/7e364b7a65801debaccc2f60673323385e6f05d3/drones/models/model_save.index
--------------------------------------------------------------------------------
/drones/models/model_save.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-REALM/macbf/7e364b7a65801debaccc2f60673323385e6f05d3/drones/models/model_save.meta
--------------------------------------------------------------------------------
/drones/train.py:
--------------------------------------------------------------------------------
1 | import config
2 | import core
3 | import tensorflow as tf
4 | import numpy as np
5 | import argparse
6 | import time
7 | import os
8 | import sys
9 | sys.dont_write_bytecode = True
10 | 
11 | 
12 | np.set_printoptions(3)
13 | 
14 | 
15 | def parse_args():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--num_agents', type=int, default=16)
18 | parser.add_argument('--model_path', type=str, default=None)
19 | parser.add_argument('--gpu', type=str, default='0')
20 | parser.add_argument('--tag', type=str, default='default')
21 | args = parser.parse_args()
22 | return args
23 | 
24 | 
25 | def build_optimizer(loss):
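""" Build the accumulate/apply train ops for the CBF and the controller.

Gradients are accumulated over multiple simulation steps and averaged
before being applied, and variables under the 'cbf' and 'action' scopes
get separate apply ops so the two networks can be updated alternately.
A usage sketch, mirroring main() below (`loss` is the scalar training loss):

    zero_ops, accumulate_ops, train_step_h, train_step_a = build_optimizer(loss)
    sess.run(zero_ops)  # reset the accumulators
    for _ in range(config.INNER_LOOPS):
        sess.run(accumulate_ops, feed_dict={s: s_np, s_ref: s_ref_np})
    sess.run(train_step_h)  # or train_step_a, alternating every 10 steps
"""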
26 | optimizer = tf.train.AdamOptimizer(learning_rate=config.LEARNING_RATE)
27 | trainable_vars = tf.trainable_variables()
28 | 
29 | # tensors to accumulate gradients over multiple steps
30 | accumulators = [
31 | tf.Variable(
32 | tf.zeros_like(tv.initialized_value()),
33 | trainable=False
34 | ) for tv in trainable_vars]
35 | # count how many steps we have accumulated
36 | accumulation_counter = tf.Variable(0.0, trainable=False)
37 | grad_pairs = optimizer.compute_gradients(loss, trainable_vars)
38 | # add the gradients to the accumulation tensors
39 | accumulate_ops = [
40 | accumulator.assign_add(
41 | grad
42 | ) for (accumulator, (grad, var)) in zip(accumulators, grad_pairs)]
43 | 
44 | accumulate_ops.append(accumulation_counter.assign_add(1.0))
45 | # divide the accumulated gradients by the number of accumulation steps
46 | gradient_vars = [(accumulator / accumulation_counter, var)
47 | for (accumulator, (grad, var)) in zip(accumulators, grad_pairs)]
48 | # separate the gradients of the CBF and the controller
49 | gradient_vars_h = []
50 | gradient_vars_a = []
51 | for accumulate_grad, var in gradient_vars:
52 | if 'cbf' in var.name:
53 | gradient_vars_h.append((accumulate_grad, var))
54 | elif 'action' in var.name:
55 | gradient_vars_a.append((accumulate_grad, var))
56 | else:
57 | raise ValueError('Unexpected variable scope: {}'.format(var.name))
58 | 
59 | train_step_h = optimizer.apply_gradients(gradient_vars_h)
60 | train_step_a = optimizer.apply_gradients(gradient_vars_a)
61 | # re-initialize the accumulation tensors and the accumulation counter to zero
62 | zero_ops = [
63 | accumulator.assign(
64 | tf.zeros_like(tv)
65 | ) for (accumulator, tv) in zip(accumulators, trainable_vars)]
66 | zero_ops.append(accumulation_counter.assign(0.0))
67 | 
68 | return zero_ops, accumulate_ops, train_step_h, train_step_a
69 | 
70 | 
71 | def build_training_graph(num_agents):
72 | # s is the state vectors of the agents
73 | s = tf.placeholder(tf.float32, [num_agents, 8])
74 | # s_ref is the goal states
75 | s_ref = tf.placeholder(tf.float32, [num_agents, 8])
76 | # x is difference between the state of each agent and other agents
77 | x = tf.expand_dims(s, 1) - tf.expand_dims(s, 0)
78 | # h is the CBF value of shape [num_agents, TOP_K, 1], where TOP_K represents
79 | # the K nearest agents
80 | h, mask, indices = core.network_cbf(
81 | x=x, r=config.DIST_MIN_THRES, indices=None)
82 | # u is the control action of each agent, with shape [num_agents, 3]
83 | u = core.network_action(
84 | s=s, s_ref=s_ref, obs_radius=config.OBS_RADIUS, indices=indices)
85 | # compute the value of loss functions and the accuracies
86 | # loss_dang is for h(s) < 0, s in dangerous set
87 | # loss_safe is for h(s) >= 0, s in safe set
88 | # acc_dang is the accuracy that h(s) < 0, s in dangerous set is satisfied
89 | # acc_safe is the accuracy that h(s) >= 0, s in safe set is satisfied
90 | loss_dang, loss_safe, acc_dang, acc_safe = core.loss_barrier(
91 | h=h, s=s, indices=indices)
92 | # loss_dang_deriv is for dh/dt(s) + alpha * h(s) >= 0 for s in dangerous set
93 | # loss_safe_deriv is for dh/dt(s) + alpha * h(s) >= 0 for s in safe set
94 | # loss_medium_deriv is for dh/dt(s) + alpha * h(s) >= 0 for s not in the dangerous
95 | # or the safe set
96 | (loss_dang_deriv, loss_safe_deriv, loss_medium_deriv, acc_dang_deriv,
97 | acc_safe_deriv, acc_medium_deriv) = core.loss_derivatives(
98 | s=s, u=u, h=h, x=x, indices=indices)
99 | # the distance between u and the nominal LQR control
100 | loss_action = core.loss_actions(s=s, u=u, s_ref=s_ref, indices=indices)
101 | 
102 | # the weight of each loss item requires 
careful tuning 103 | loss_list = [loss_dang, loss_safe, 3 * loss_dang_deriv, 104 | loss_safe_deriv, 2 * loss_medium_deriv, 0.5 * loss_action] 105 | acc_list = [acc_dang, acc_safe, acc_dang_deriv, 106 | acc_safe_deriv, acc_medium_deriv] 107 | 108 | weight_loss = [ 109 | config.WEIGHT_DECAY * tf.nn.l2_loss(v) for v in tf.trainable_variables()] 110 | loss = 10 * tf.math.add_n(loss_list + weight_loss) 111 | 112 | return s, s_ref, u, loss_list, loss, acc_list 113 | 114 | 115 | def count_accuracy(accuracy_lists): 116 | acc = np.array(accuracy_lists) 117 | acc_list = [] 118 | for i in range(acc.shape[1]): 119 | acc_list.append(np.mean(acc[acc[:, i] >= 0, i])) 120 | return acc_list 121 | 122 | 123 | def main(): 124 | args = parse_args() 125 | 126 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 127 | 128 | s, s_ref, u, loss_list, loss, acc_list = build_training_graph( 129 | args.num_agents) 130 | zero_ops, accumulate_ops, train_step_h, train_step_a = build_optimizer( 131 | loss) 132 | accumulate_ops.append(loss_list) 133 | accumulate_ops.append(acc_list) 134 | 135 | accumulation_steps = config.INNER_LOOPS 136 | 137 | with tf.Session() as sess: 138 | sess.run(tf.global_variables_initializer()) 139 | saver = tf.train.Saver() 140 | 141 | if args.model_path: 142 | saver.restore(sess, args.model_path) 143 | 144 | loss_lists_np = [] 145 | acc_lists_np = [] 146 | dist_errors_np = [] 147 | dist_errors_baseline_np = [] 148 | 149 | safety_ratios_epoch = [] 150 | safety_ratios_epoch_baseline = [] 151 | 152 | scene = core.Cityscape(args.num_agents) 153 | start_time = time.time() 154 | 155 | for istep in range(config.TRAIN_STEPS): 156 | scene.reset() 157 | # the initial state 158 | s_np = np.concatenate( 159 | [scene.start_points, np.zeros((args.num_agents, 5))], axis=1) 160 | # set the gradient accumulation tensor to zero 161 | sess.run(zero_ops) 162 | for t in range(scene.steps): 163 | # for each scene.step, we have a goal (reference) state for each agent 164 | s_ref_np = np.concatenate( 165 | [scene.waypoints[t], np.zeros((args.num_agents, 5))], axis=1) 166 | # run the system and accumulate gradient over multiple steps 167 | for i in range(accumulation_steps): 168 | u_np, out = sess.run([u, accumulate_ops], feed_dict={ 169 | s: s_np, s_ref: s_ref_np}) 170 | dsdt = core.quadrotor_dynamics_np(s_np, u_np) 171 | s_np = s_np + dsdt * config.TIME_STEP 172 | safety_ratio = 1 - np.mean( 173 | core.dangerous_mask_np(s_np, config.DIST_MIN_CHECK), axis=1) 174 | safety_ratio = np.mean(safety_ratio == 1) 175 | safety_ratios_epoch.append(safety_ratio) 176 | loss_list_np, acc_list_np = out[-2], out[-1] 177 | loss_lists_np.append(loss_list_np) 178 | acc_lists_np.append(acc_list_np) 179 | 180 | dist_errors_np.append(np.mean(np.linalg.norm( 181 | s_np[:, :3] - s_ref_np[:, :3], axis=1))) 182 | 183 | # achieve the same goal states using LQR controllers without considering collision 184 | s_np = np.concatenate( 185 | [scene.start_points, np.zeros((args.num_agents, 5))], axis=1) 186 | for t in range(scene.steps): 187 | s_ref_np = np.concatenate( 188 | [scene.waypoints[t], np.zeros((args.num_agents, 5))], axis=1) 189 | for i in range(accumulation_steps): 190 | u_np = core.quadrotor_controller_np(s_np, s_ref_np) 191 | dsdt = core.quadrotor_dynamics_np(s_np, u_np) 192 | s_np = s_np + dsdt * config.TIME_STEP 193 | safety_ratio = 1 - np.mean( 194 | core.dangerous_mask_np(s_np, config.DIST_MIN_CHECK), axis=1) 195 | safety_ratio = np.mean(safety_ratio == 1) 196 | safety_ratios_epoch_baseline.append(safety_ratio) 197 | 
dist_errors_baseline_np.append( 198 | np.mean(np.linalg.norm(s_np[:, :3] - s_ref_np[:, :3], axis=1))) 199 | 200 | if np.mod(istep // 10, 2) == 0: 201 | sess.run(train_step_h) 202 | else: 203 | sess.run(train_step_a) 204 | 205 | if np.mod(istep, config.DISPLAY_STEPS) == 0: 206 | print('Step: {}, Time: {:.1f}, Loss: {}, Dist: {:.3f}, Safety Rate: {:.3f}'.format( 207 | istep, time.time() - start_time, np.mean(loss_lists_np, axis=0), 208 | np.mean(dist_errors_np), np.mean(safety_ratios_epoch))) 209 | start_time = time.time() 210 | (loss_lists_np, acc_lists_np, dist_errors_np, dist_errors_baseline_np, safety_ratios_epoch, 211 | safety_ratios_epoch_baseline) = [], [], [], [], [], [] 212 | 213 | if np.mod(istep, config.SAVE_STEPS) == 0 or istep + 1 == config.TRAIN_STEPS: 214 | saver.save( 215 | sess, 'models/model_{}_iter_{}'.format(args.tag, istep)) 216 | 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.14.0 2 | scipy 3 | pyyaml 4 | matplotlib --------------------------------------------------------------------------------